Files
deepmind-research/density_functional_approximation_dm21/BUILD.bazel
T
2021-12-09 17:41:31 +00:00

126 lines
3.4 KiB
Python

"""DM21 functionals."""
load("@external_py_deps//:requirements.bzl", "requirement")
load("@org_tensorflow//tensorflow/python/tools:tools.bzl", "saved_model_compile_aot")
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
licenses(["notice"])
# Python library target wrapping compute_hfx_density.py.
# Third-party packages are resolved through requirement() from
# @external_py_deps.
py_library(
    name = "compute_hfx_density",
    srcs_version = "PY3",
    srcs = ["density_functional_approximation_dm21/compute_hfx_density.py"],
    deps = [
        requirement("numpy"),
        requirement("pyscf"),
    ],
)
# Test target for :compute_hfx_density.
# py_test is taken from @rules_python//python:defs.bzl, the same place the
# other Python rules in this file are loaded from, so the target keeps working
# when the native Python rules are disabled.
py_test(
    name = "compute_hfx_density_test",
    srcs = ["density_functional_approximation_dm21/compute_hfx_density_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        # Code under test.
        ":compute_hfx_density",
        # Pip packages resolved through @external_py_deps.
        requirement("attrs"),
        requirement("numpy"),
        requirement("scipy"),
        requirement("pyscf"),
        # Abseil test helpers.
        "@io_abseil_py//absl/testing:absltest",
        "@io_abseil_py//absl/testing:parameterized",
    ],
)
# Every file found under the checkpoints/ directory, bundled so that other
# targets can list it in their `data` attribute.
filegroup(
    name = "dm21_checkpoints",
    srcs = glob(
        ["density_functional_approximation_dm21/checkpoints/**"],
    ),
)
# Python library target wrapping neural_numint.py.
# The checkpoint files are attached as runtime data.
py_library(
    name = "neural_numint",
    srcs_version = "PY3",
    srcs = ["density_functional_approximation_dm21/neural_numint.py"],
    data = [":dm21_checkpoints"],
    deps = [
        ":compute_hfx_density",
        requirement("attrs"),
        requirement("keras"),
        requirement("numpy"),
        requirement("pyscf"),
        # TensorFlow comes from the @org_tensorflow repository rather than
        # from a pip requirement.
        "@org_tensorflow//tensorflow:tensorflow_py",
        requirement("tensorflow-hub"),
    ],
)
# Test target for :neural_numint.
# py_test is taken from @rules_python//python:defs.bzl, the same place the
# other Python rules in this file are loaded from, so the target keeps working
# when the native Python rules are disabled.
py_test(
    name = "neural_numint_test",
    srcs = ["density_functional_approximation_dm21/neural_numint_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        # Code under test.
        ":neural_numint",
        # Pip packages resolved through @external_py_deps.
        requirement("attrs"),
        requirement("pyscf"),
        "@io_abseil_py//absl/testing:parameterized",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)
# Executable wrapping export_saved_model.py. It is run as a build tool by the
# create_model_for_aot_compile genrule in this file.
py_binary(
    name = "export_saved_model",
    python_version = "PY3",
    srcs_version = "PY3",
    srcs = ["density_functional_approximation_dm21/export_saved_model.py"],
    # Checkpoint files made available to the binary at run time.
    data = [":dm21_checkpoints"],
    deps = [":neural_numint"],
)
# Package-relative paths of the files declared as outputs of the model export
# genrule; all of them live under the DM21/ output directory.
EXPORTED_SAVED_MODEL_OBJECTS = [
    "DM21/saved_model.pb",
    "DM21/variables/variables.index",
    "DM21/variables/variables.data-00000-of-00001",
]
# Runs :export_saved_model at build time and writes its output into the DM21/
# directory of this package's output tree.
genrule(
    name = "create_model_for_aot_compile",
    tools = [":export_saved_model"],
    outs = EXPORTED_SAVED_MODEL_OBJECTS,
    # To build dm21_aot_compiled_example against a different functional,
    # pass that functional's name to --functional below.
    cmd = "$(location :export_saved_model) --functional DM21 --batch_size 1000 --out_dir $(@D)/DM21",
)
# Groups the exported model files under a single label, which is passed to
# saved_model_compile_aot via its `filegroups` attribute.
filegroup(
    name = "dm21_exported_model",
    srcs = EXPORTED_SAVED_MODEL_OBJECTS,
)
# Calls the saved_model_compile_aot macro (loaded from TensorFlow's tools.bzl)
# on the exported DM21 model files. The generated C++ class is named
# dm21::functional and the target triple is fixed to x86_64 Linux.
# NOTE(review): the exact meaning of each attribute is defined by the macro,
# which is not visible here - consult tools.bzl before changing values.
saved_model_compile_aot(
    name = "aot_compiled_dm21",
    # Where the model comes from.
    directory = ":DM21",
    filegroups = [":dm21_exported_model"],
    # Which part of the model to compile.
    signature_def = "default",
    tag_set = "''",
    # What to generate and for which platform.
    cpp_class = "dm21::functional",
    target_triple = "x86_64-pc-linux",
    multithreading = False,
    force_without_xla_support_flag = False,
)
# C++ library built from the example source and header under cc/, linked
# against the :aot_compiled_dm21 target.
cc_library(
    name = "dm21_aot_compiled_example",
    hdrs = ["cc/dm21_aot_compiled_example.h"],
    srcs = ["cc/dm21_aot_compiled_example.cc"],
    deps = [":aot_compiled_dm21"],
)
# Executable built from cc/run_dm21_aot_compiled_example.cc on top of the
# :dm21_aot_compiled_example library.
cc_binary(
    name = "run_dm21_aot_compiled_example",
    srcs = ["cc/run_dm21_aot_compiled_example.cc"],
    # Defines the XLA_AVAILABLE preprocessor macro for this binary's sources.
    copts = ["-DXLA_AVAILABLE"],
    deps = [":dm21_aot_compiled_example"],
)