# mirror of https://github.com/google-deepmind/deepmind-research.git
# synced 2026-05-13 07:17:38 +08:00
# 72c72d530f
# PiperOrigin-RevId: 415267993
# 126 lines, 3.4 KiB, Python

"""DM21 functionals."""

load("@org_tensorflow//tensorflow/python/tools:tools.bzl", "saved_model_compile_aot")
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
load("@external_py_deps//:requirements.bzl", "requirement")

# Package-level license declaration ("notice" license type).
licenses(["notice"])

# Python library wrapping compute_hfx_density.py; it depends only on the
# third-party numpy and pyscf packages.
py_library(
    name = "compute_hfx_density",
    srcs = ["density_functional_approximation_dm21/compute_hfx_density.py"],
    srcs_version = "PY3",
    deps = [
        requirement("numpy"),
        requirement("pyscf"),
    ],
)

# Test target for compute_hfx_density_test.py. It depends on the library under
# test (:compute_hfx_density), the attrs/numpy/scipy/pyscf packages and the
# absl test helpers (absltest, parameterized).
py_test(
    name = "compute_hfx_density_test",
    srcs = ["density_functional_approximation_dm21/compute_hfx_density_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":compute_hfx_density",
        requirement("attrs"),
        requirement("numpy"),
        requirement("scipy"),
        requirement("pyscf"),
        "@io_abseil_py//absl/testing:absltest",
        "@io_abseil_py//absl/testing:parameterized",
    ],
)

# Every file under the checkpoints/ directory, grouped into one target so it
# can be attached as `data` to :neural_numint and :export_saved_model.
filegroup(
    name = "dm21_checkpoints",
    srcs = glob(["density_functional_approximation_dm21/checkpoints/**"]),
)

# Python library wrapping neural_numint.py. The checkpoint files are attached
# as `data` so they are available in the runfiles of anything depending on
# this library.
py_library(
    name = "neural_numint",
    srcs = ["density_functional_approximation_dm21/neural_numint.py"],
    data = [":dm21_checkpoints"],
    srcs_version = "PY3",
    deps = [
        ":compute_hfx_density",
        requirement("attrs"),
        requirement("keras"),
        requirement("numpy"),
        requirement("pyscf"),
        "@org_tensorflow//tensorflow:tensorflow_py",
        requirement("tensorflow-hub"),
    ],
)

# Test target for neural_numint_test.py. It depends on the library under test
# (:neural_numint), the attrs and pyscf packages, absl's parameterized test
# helper and TensorFlow.
py_test(
    name = "neural_numint_test",
    srcs = ["density_functional_approximation_dm21/neural_numint_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":neural_numint",
        requirement("attrs"),
        requirement("pyscf"),
        "@io_abseil_py//absl/testing:parameterized",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

# Executable built from export_saved_model.py. It is run as the tool of the
# :create_model_for_aot_compile genrule; the checkpoint files are attached as
# `data` so they are present in its runfiles.
py_binary(
    name = "export_saved_model",
    srcs = ["density_functional_approximation_dm21/export_saved_model.py"],
    data = [":dm21_checkpoints"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":neural_numint",
    ],
)

# Files making up the exported DM21 SavedModel. The list is shared by the
# :create_model_for_aot_compile genrule (as its declared outputs) and the
# :dm21_exported_model filegroup (as its sources), so both stay in sync.
EXPORTED_SAVED_MODEL_OBJECTS = [
    "DM21/saved_model.pb",
    "DM21/variables/variables.index",
    "DM21/variables/variables.data-00000-of-00001",
]

# Runs :export_saved_model to produce the SavedModel files listed in
# EXPORTED_SAVED_MODEL_OBJECTS. The --out_dir passed to the tool ends in
# "DM21", matching the "DM21/" prefix of every declared output.
genrule(
    name = "create_model_for_aot_compile",
    outs = EXPORTED_SAVED_MODEL_OBJECTS,
    # The functional used in dm21_aot_compiled_example can be changed by
    # setting the --functional flag to the desired functional.
    cmd = "$(location :export_saved_model) --functional DM21 --batch_size 1000 --out_dir $(@D)/DM21",
    tools = [":export_saved_model"],
)

# Groups the exported SavedModel files under a single label, which is what
# :aot_compiled_dm21 consumes through its `filegroups` attribute.
filegroup(
    name = "dm21_exported_model",
    srcs = EXPORTED_SAVED_MODEL_OBJECTS,
)

# Invokes TensorFlow's saved_model_compile_aot macro on the exported DM21
# model, requesting a generated C++ class named dm21::functional for the
# x86_64-pc-linux target with multithreading disabled.
saved_model_compile_aot(
    name = "aot_compiled_dm21",
    cpp_class = "dm21::functional",
    directory = ":DM21",
    filegroups = [":dm21_exported_model"],
    force_without_xla_support_flag = False,
    multithreading = False,
    signature_def = "default",
    # NOTE(review): the value is a pair of single quotes inside the string,
    # presumably so an empty tag set reaches the underlying command line
    # intact — confirm against the macro definition in tools.bzl.
    tag_set = "''",
    target_triple = "x86_64-pc-linux",
)

# C++ example library (one source file and one header) built on top of the
# AOT-compiled model target :aot_compiled_dm21.
cc_library(
    name = "dm21_aot_compiled_example",
    srcs = ["cc/dm21_aot_compiled_example.cc"],
    hdrs = ["cc/dm21_aot_compiled_example.h"],
    deps = [":aot_compiled_dm21"],
)

# Executable that links :dm21_aot_compiled_example. Its source is compiled
# with the XLA_AVAILABLE preprocessor macro defined.
cc_binary(
    name = "run_dm21_aot_compiled_example",
    srcs = ["cc/run_dm21_aot_compiled_example.cc"],
    copts = ["-DXLA_AVAILABLE"],
    deps = [":dm21_aot_compiled_example"],
)