Files
deepmind-research/option_keyboard/smart_module.py
Yilei Yang 5501e3237b Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 441399134
2022-05-26 17:44:30 +01:00

227 lines
7.7 KiB
Python

# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Smart module export/import utilities."""
import inspect
import pickle
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.io import gfile
import tensorflow_hub as hub
import tree as nest
import wrapt
_ALLOWED_TYPES = (bool, float, int, str)
def _getcallargs(signature, *args, **kwargs):
bound_args = signature.bind(*args, **kwargs)
bound_args.apply_defaults()
inputs = bound_args.arguments
inputs.pop("self", None)
return inputs
def _to_placeholder(arg):
if arg is None or isinstance(arg, bool):
return arg
arg = tf.convert_to_tensor(arg)
return tf.placeholder(dtype=arg.dtype, shape=arg.shape)
class SmartModuleExport(object):
"""Helper class for exporting TF-Hub modules."""
def __init__(self, object_factory):
self._object_factory = object_factory
self._wrapped_object = self._object_factory()
self._variable_scope = tf.get_variable_scope()
self._captured_calls = {}
self._captured_attrs = {}
def _create_captured_method(self, method_name):
"""Creates a wrapped method that captures its inputs."""
with tf.variable_scope(self._variable_scope):
method_ = getattr(self._wrapped_object, method_name)
@wrapt.decorator
def wrapper(method, instance, args, kwargs):
"""Wrapped method to capture inputs."""
del instance
specs = inspect.signature(method)
inputs = _getcallargs(specs, *args, **kwargs)
with tf.variable_scope(self._variable_scope):
output = method(*args, **kwargs)
self._captured_calls[method_name] = [inputs, specs]
return output
return wrapper(method_) # pylint: disable=no-value-for-parameter
def __getattr__(self, name):
"""Helper method for accessing an attributes of the wrapped object."""
# if "_wrapped_object" not in self.__dict__:
# return super(ExportableModule, self).__getattr__(name)
with tf.variable_scope(self._variable_scope):
attr = getattr(self._wrapped_object, name)
if inspect.ismethod(attr) or inspect.isfunction(attr):
return self._create_captured_method(name)
else:
if all([isinstance(v, _ALLOWED_TYPES) for v in nest.flatten(attr)]):
self._captured_attrs[name] = attr
return attr
def __call__(self, *args, **kwargs):
return self._create_captured_method("__call__")(*args, **kwargs)
def export(self, path, session, overwrite=False):
"""Build the TF-Hub spec, module and sync ops."""
method_specs = {}
def module_fn():
"""A module_fn for use with hub.create_module_spec()."""
# We will use a copy of the original object to build the graph.
wrapped_object = self._object_factory()
for method_name, method_info in self._captured_calls.items():
captured_inputs, captured_specs = method_info
tensor_inputs = nest.map_structure(_to_placeholder, captured_inputs)
method_to_call = getattr(wrapped_object, method_name)
tensor_outputs = method_to_call(**tensor_inputs)
flat_tensor_inputs = nest.flatten(tensor_inputs)
flat_tensor_inputs = {
str(k): v for k, v in zip(
range(len(flat_tensor_inputs)), flat_tensor_inputs)
}
flat_tensor_outputs = nest.flatten(tensor_outputs)
flat_tensor_outputs = {
str(k): v for k, v in zip(
range(len(flat_tensor_outputs)), flat_tensor_outputs)
}
method_specs[method_name] = dict(
specs=captured_specs,
inputs=nest.map_structure(lambda _: None, tensor_inputs),
outputs=nest.map_structure(lambda _: None, tensor_outputs))
signature_name = ("default"
if method_name == "__call__" else method_name)
hub.add_signature(signature_name, flat_tensor_inputs,
flat_tensor_outputs)
hub.attach_message(
"methods", tf.train.BytesList(value=[pickle.dumps(method_specs)]))
hub.attach_message(
"properties",
tf.train.BytesList(value=[pickle.dumps(self._captured_attrs)]))
# Create the spec that will be later used in export.
hub_spec = hub.create_module_spec(module_fn, drop_collections=["sonnet"])
# Get variables values
module_weights = [
session.run(v) for v in self._wrapped_object.get_all_variables()
]
# create the sync ops
with tf.Graph().as_default():
hub_module = hub.Module(hub_spec, trainable=True, name="hub")
assign_ops = []
assign_phs = []
for _, v in sorted(hub_module.variable_map.items()):
ph = tf.placeholder(shape=v.shape, dtype=v.dtype)
assign_phs.append(ph)
assign_ops.append(tf.assign(v, ph))
with tf.Session() as module_session:
module_session.run(tf.local_variables_initializer())
module_session.run(tf.global_variables_initializer())
module_session.run(
assign_ops, feed_dict=dict(zip(assign_phs, module_weights)))
if overwrite and gfile.exists(path):
gfile.rmtree(path)
gfile.makedirs(path)
hub_module.export(path, module_session)
class SmartModuleImport(object):
"""A class for importing graph building objects from TF-Hub modules."""
def __init__(self, module):
self._module = module
self._method_specs = pickle.loads(
self._module.get_attached_message("methods",
tf.train.BytesList).value[0])
self._properties = pickle.loads(
self._module.get_attached_message("properties",
tf.train.BytesList).value[0])
def _create_wrapped_method(self, method):
"""Creates a wrapped method that converts nested inputs and outputs."""
def wrapped_method(*args, **kwargs):
"""A wrapped method around a TF-Hub module signature."""
inputs = _getcallargs(self._method_specs[method]["specs"], *args,
**kwargs)
nest.assert_same_structure(self._method_specs[method]["inputs"], inputs)
flat_inputs = nest.flatten(inputs)
flat_inputs = {
str(k): v for k, v in zip(range(len(flat_inputs)), flat_inputs)
}
signature = "default" if method == "__call__" else method
flat_outputs = self._module(
flat_inputs, signature=signature, as_dict=True)
flat_outputs = [v for _, v in sorted(flat_outputs.items())]
output_spec = self._method_specs[method]["outputs"]
if output_spec is None:
if len(flat_outputs) != 1:
raise ValueError(
"Expected output containing a single tensor, found {}".format(
flat_outputs))
outputs = flat_outputs[0]
else:
outputs = nest.unflatten_as(output_spec, flat_outputs)
return outputs
return wrapped_method
def __getattr__(self, name):
if name in self._method_specs:
return self._create_wrapped_method(name)
if name in self._properties:
return self._properties[name]
return getattr(self._module, name)
def __call__(self, *args, **kwargs):
return self._create_wrapped_method("__call__")(*args, **kwargs)