Mirror of https://github.com/google-deepmind/deepmind-research.git
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utilities for reading open sourced Learning Complex Physics data."""

import functools

import numpy as np
import tensorflow.compat.v1 as tf

# Create a description of the features.
_FEATURE_DESCRIPTION = {
    'position': tf.io.VarLenFeature(tf.string),
}

_FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT = _FEATURE_DESCRIPTION.copy()
_FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT['step_context'] = tf.io.VarLenFeature(
    tf.string)

_FEATURE_DTYPES = {
    'position': {
        'in': np.float32,
        'out': tf.float32
    },
    'step_context': {
        'in': np.float32,
        'out': tf.float32
    }
}

_CONTEXT_FEATURES = {
    'key': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'particle_type': tf.io.VarLenFeature(tf.string)
}
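

# The sketch below is illustrative and not part of the original module. It
# shows one way a single trajectory could be serialized so that it matches
# `_CONTEXT_FEATURES` and `_FEATURE_DESCRIPTION` above (positions stored as
# raw float32 bytes per frame, particle types as raw int64 bytes). The
# function name and argument shapes are assumptions made for the example only.
def _encode_trajectory_sketch(positions, particle_types, key=0):
  """Builds a tf.train.SequenceExample for one trajectory (illustrative).

  Args:
    positions: Float array of shape [num_frames, num_particles, dim].
    particle_types: Int array of shape [num_particles].
    key: Integer id stored in the context.
  """
  # One bytes entry per frame in the 'position' feature list.
  position_frames = [
      tf.train.Feature(bytes_list=tf.train.BytesList(
          value=[np.asarray(frame, dtype=np.float32).tobytes()]))
      for frame in positions]
  # Per-trajectory context: integer key plus raw int64 particle types.
  context = tf.train.Features(feature={
      'key': tf.train.Feature(int64_list=tf.train.Int64List(value=[key])),
      'particle_type': tf.train.Feature(bytes_list=tf.train.BytesList(
          value=[np.asarray(particle_types, dtype=np.int64).tobytes()])),
  })
  feature_lists = tf.train.FeatureLists(feature_list={
      'position': tf.train.FeatureList(feature=position_frames)})
  return tf.train.SequenceExample(context=context, feature_lists=feature_lists)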


def convert_to_tensor(x, encoded_dtype):
  """Decodes a list of serialized byte strings into a tensor of `encoded_dtype`."""
  if len(x) == 1:
    out = np.frombuffer(x[0].numpy(), dtype=encoded_dtype)
  else:
    out = []
    for el in x:
      out.append(np.frombuffer(el.numpy(), dtype=encoded_dtype))
  out = tf.convert_to_tensor(np.array(out))
  return out
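

# Illustrative example, not part of the original module: decoding two frames
# of raw float32 bytes with `convert_to_tensor`. It assumes eager execution,
# which is also what `tf.py_function` provides at the call sites below; the
# values are made up for the example.
def _convert_to_tensor_example():
  """Returns a [2, 4] float32 tensor decoded from two serialized frames."""
  frames = [tf.constant(np.arange(4, dtype=np.float32).tobytes()),
            tf.constant(np.arange(4, 8, dtype=np.float32).tobytes())]
  return convert_to_tensor(frames, encoded_dtype=np.float32)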


def parse_serialized_simulation_example(example_proto, metadata):
  """Parses a serialized simulation tf.SequenceExample.

  Args:
    example_proto: A string encoding of the tf.SequenceExample proto.
    metadata: A dict of metadata for the dataset.

  Returns:
    context: A dict, with features that do not vary over the trajectory.
    parsed_features: A dict of tf.Tensors representing the parsed examples
      across time, where axis zero is the time axis.

  """
  if 'context_mean' in metadata:
    feature_description = _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT
  else:
    feature_description = _FEATURE_DESCRIPTION
  context, parsed_features = tf.io.parse_single_sequence_example(
      example_proto,
      context_features=_CONTEXT_FEATURES,
      sequence_features=feature_description)
  # Decode each sequence feature from raw bytes to a tensor of the right dtype.
  for feature_key, item in parsed_features.items():
    convert_fn = functools.partial(
        convert_to_tensor, encoded_dtype=_FEATURE_DTYPES[feature_key]['in'])
    parsed_features[feature_key] = tf.py_function(
        convert_fn, inp=[item.values], Tout=_FEATURE_DTYPES[feature_key]['out'])

  # There is an extra frame at the beginning so we can calculate pos change
  # for all frames used in the paper.
  position_shape = [metadata['sequence_length'] + 1, -1, metadata['dim']]

  # Reshape positions to the correct dim:
  parsed_features['position'] = tf.reshape(parsed_features['position'],
                                           position_shape)
  # Set correct shapes of the remaining tensors.
  sequence_length = metadata['sequence_length'] + 1
  if 'context_mean' in metadata:
    context_feat_len = len(metadata['context_mean'])
    parsed_features['step_context'] = tf.reshape(
        parsed_features['step_context'],
        [sequence_length, context_feat_len])
  # Decode particle type explicitly.
  context['particle_type'] = tf.py_function(
      functools.partial(convert_to_tensor, encoded_dtype=np.int64),
      inp=[context['particle_type'].values],
      Tout=[tf.int64])
  context['particle_type'] = tf.reshape(context['particle_type'], [-1])
  return context, parsed_features
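

# Illustrative usage sketch, not part of the original module: building a
# tf.data pipeline of parsed trajectories. The TFRecord path and the
# `metadata` dict (typically loaded from the dataset's metadata.json) are
# assumptions of the example.
def _parsed_trajectory_dataset_sketch(tfrecord_path, metadata):
  """Returns a dataset yielding (context, parsed_features) per trajectory."""
  ds = tf.data.TFRecordDataset([tfrecord_path])
  return ds.map(functools.partial(
      parse_serialized_simulation_example, metadata=metadata))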


def split_trajectory(context, features, window_length=7):
  """Splits trajectory into sliding windows."""
  # Our strategy is to make sure all the leading dimensions are the same size,
  # then we can use from_tensor_slices.

  trajectory_length = features['position'].get_shape().as_list()[0]

  # We then stack `window_length` frames per window, so the number of windows
  # is trajectory_length - window_length + 1 (the + 1 makes sure we get the
  # last split).
  input_trajectory_length = trajectory_length - window_length + 1

  model_input_features = {}
  # Prepare the context features per step.
  model_input_features['particle_type'] = tf.tile(
      tf.expand_dims(context['particle_type'], axis=0),
      [input_trajectory_length, 1])

  if 'step_context' in features:
    global_stack = []
    for idx in range(input_trajectory_length):
      global_stack.append(features['step_context'][idx:idx + window_length])
    model_input_features['step_context'] = tf.stack(global_stack)

  pos_stack = []
  for idx in range(input_trajectory_length):
    pos_stack.append(features['position'][idx:idx + window_length])
  # Get the corresponding positions.
  model_input_features['position'] = tf.stack(pos_stack)

  return tf.data.Dataset.from_tensor_slices(model_input_features)
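

# Illustrative continuation of the sketch above, not part of the original
# module: converting parsed trajectories into sliding-window model inputs.
def _windowed_dataset_sketch(parsed_dataset, window_length=7):
  """Flattens per-trajectory windows into a single dataset of examples."""
  # `split_trajectory` returns a dataset of windows for each trajectory;
  # flat_map concatenates these per-trajectory datasets into one stream.
  return parsed_dataset.flat_map(
      functools.partial(split_trajectory, window_length=window_length))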