Files
deepmind-research/learning_to_simulate/reading_utils.py
Yilei Yang 152ac280f2 Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 440896914
2022-05-26 17:43:58 +01:00

140 lines
5.0 KiB
Python

# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utilities for reading open sourced Learning Complex Physics data."""
import functools
import numpy as np
import tensorflow.compat.v1 as tf
# Per-step (sequence) features. Each is stored as a variable-length list of
# serialized byte strings, one entry per timestep; the raw bytes are decoded
# later by `convert_to_tensor`.
_FEATURE_DESCRIPTION = {
    'position': tf.io.VarLenFeature(tf.string),
}
# Same as above plus a per-step global context vector; used only for datasets
# whose metadata contains a 'context_mean' entry.
_FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT = _FEATURE_DESCRIPTION.copy()
_FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT['step_context'] = tf.io.VarLenFeature(
    tf.string)
# For each sequence feature: 'in' is the numpy dtype used to decode the raw
# bytes, 'out' is the tf dtype of the resulting tensor.
_FEATURE_DTYPES = {
    'position': {
        'in': np.float32,
        'out': tf.float32
    },
    'step_context': {
        'in': np.float32,
        'out': tf.float32
    }
}
# Per-trajectory (context) features that do not vary over time.
_CONTEXT_FEATURES = {
    'key': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'particle_type': tf.io.VarLenFeature(tf.string)
}
def convert_to_tensor(x, encoded_dtype):
  """Decodes a list of serialized byte tensors into a single tf.Tensor.

  Args:
    x: A list of scalar tf.string tensors, each holding raw bytes of
      `encoded_dtype` values.
    encoded_dtype: The numpy dtype the raw bytes are interpreted as.

  Returns:
    A tf.Tensor: 1-D when `x` has a single element, otherwise 2-D with one
    row per element of `x`.
  """
  if len(x) == 1:
    decoded = np.frombuffer(x[0].numpy(), dtype=encoded_dtype)
  else:
    decoded = [np.frombuffer(el.numpy(), dtype=encoded_dtype) for el in x]
  return tf.convert_to_tensor(np.array(decoded))
def parse_serialized_simulation_example(example_proto, metadata):
  """Parses a serialized simulation tf.SequenceExample.

  Args:
    example_proto: A string encoding of the tf.SequenceExample proto.
    metadata: A dict of metadata for the dataset. Must contain
      'sequence_length' and 'dim'; may contain 'context_mean' when the
      dataset carries per-step global context.

  Returns:
    context: A dict, with features that do not vary over the trajectory.
    parsed_features: A dict of tf.Tensors representing the parsed examples
      across time, where axis zero is the time axis.
  """
  if 'context_mean' in metadata:
    feature_description = _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT
  else:
    feature_description = _FEATURE_DESCRIPTION
  context, parsed_features = tf.io.parse_single_sequence_example(
      example_proto,
      context_features=_CONTEXT_FEATURES,
      sequence_features=feature_description)
  # Decode the raw per-step byte strings into tensors of the right dtype.
  for feature_key, item in parsed_features.items():
    convert_fn = functools.partial(
        convert_to_tensor, encoded_dtype=_FEATURE_DTYPES[feature_key]['in'])
    parsed_features[feature_key] = tf.py_function(
        convert_fn, inp=[item.values], Tout=_FEATURE_DTYPES[feature_key]['out'])
  # There is an extra frame at the beginning so we can calculate pos change
  # for all frames used in the paper.
  position_shape = [metadata['sequence_length'] + 1, -1, metadata['dim']]
  # Reshape positions to correct dim:
  parsed_features['position'] = tf.reshape(parsed_features['position'],
                                           position_shape)
  # Set correct shapes of the remaining tensors.
  sequence_length = metadata['sequence_length'] + 1
  if 'context_mean' in metadata:
    context_feat_len = len(metadata['context_mean'])
    parsed_features['step_context'] = tf.reshape(
        parsed_features['step_context'],
        [sequence_length, context_feat_len])
  # Decode particle type explicitly. Bind `convert_to_tensor` directly instead
  # of re-wrapping the loop-leaked `convert_fn` partial: the original relied on
  # the loop variable escaping its loop and on nested-partial keyword-override
  # semantics, which is fragile though behaviorally identical.
  context['particle_type'] = tf.py_function(
      functools.partial(convert_to_tensor, encoded_dtype=np.int64),
      inp=[context['particle_type'].values],
      Tout=[tf.int64])
  context['particle_type'] = tf.reshape(context['particle_type'], [-1])
  return context, parsed_features
def split_trajectory(context, features, window_length=7):
  """Splits trajectory into sliding windows."""
  # All leading dimensions must have the same size so the result can be fed
  # to from_tensor_slices.
  trajectory_length = features['position'].get_shape().as_list()[0]
  # Each example stacks `window_length` consecutive steps, so the number of
  # examples is trajectory_length - window_length + 1 (the +1 ensures the
  # final window is included).
  input_trajectory_length = trajectory_length - window_length + 1
  model_input_features = {}
  # Repeat the static per-particle context once per window.
  model_input_features['particle_type'] = tf.tile(
      tf.expand_dims(context['particle_type'], axis=0),
      [input_trajectory_length, 1])
  if 'step_context' in features:
    # Sliding windows over the per-step global context.
    model_input_features['step_context'] = tf.stack([
        features['step_context'][start:start + window_length]
        for start in range(input_trajectory_length)
    ])
  # Sliding windows over the particle positions.
  model_input_features['position'] = tf.stack([
      features['position'][start:start + window_length]
      for start in range(input_trajectory_length)
  ])
  return tf.data.Dataset.from_tensor_slices(model_input_features)