Initial release of "mmv".

PiperOrigin-RevId: 346305536
Louise Deason
2020-12-08 13:54:10 +00:00
parent 7ed0b0508d
commit c146166d4b
18 changed files with 3015 additions and 0 deletions
+1
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
## Projects
* [Self-Supervised MultiModal Versatile Networks](mmv), NeurIPS 2020
* [ODE-GAN: Training GANs by Solving Ordinary Differential Equations](ode_gan), NeurIPS 2020
* [Algorithms for Causal Reasoning in Probability Trees](causal_reasoning)
* [Gated Linear Networks](gated_linear_networks), NeurIPS 2020
+83
@@ -0,0 +1,83 @@
# Self-supervised Multimodal Versatile Networks
This is the code for the models in MMV - https://arxiv.org/abs/2006.16228.
<img src="./imgs/mmv_fig.png" width="50%">
We also provide code for linear evaluation of a pre-trained model
on UCF101, together with JAX checkpoints for our best models.
We use different parameters for video compression in UCF101 than the ones
used in `tensorflow_datasets`, so we provide the code to download and
preprocess the dataset accordingly. The `eval_ucf101.py` script reproduces the
results we report in Table 2 of the paper, using the checkpoints provided below.
Visual Backbone | Training Dataset | Results on Linear UCF101
------- | -------- | --------
S3D-G | AudioSet + HowTo | 89.6
Resnet TSM-50 | AudioSet + HowTo | 91.5
Resnet TSM-50 (x2) | AudioSet + HowTo | 91.8
## Setup
To set up a Python virtual environment with the required dependencies, run:
```shell
python3 -m venv mmv_env
source mmv_env/bin/activate
pip install --upgrade pip setuptools wheel
pip install -r mmv/requirements.txt --use-feature=2020-resolver
```
### Linear evaluation
The linear evaluation on UCF101 can be run using:
```shell
python -m mmv.eval_ucf101 \
--checkpoint_path=</path/to/the/checkpointing/folder> \
--dataset_folder=</path/to/dataset/folder>
```
## Checkpoints
We provide three checkpoints containing the best pre-trained weights for each
of the visual backbones we use in the paper, i.e., S3D-G, Resnet-50 TSM,
and Resnet-50 TSMx2; a minimal loading sketch follows the links.
- [S3D-G](https://storage.googleapis.com/deepmind-research-mmv/mmv_s3d.pkl)
- [Resnet-50 TSM](https://storage.googleapis.com/deepmind-research-mmv/mmv_tsm_resnet_x1.pkl)
- [Resnet-50 TSMx2](https://storage.googleapis.com/deepmind-research-mmv/mmv_tsm_resnet_x2.pkl)
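The `.pkl` files are dill pickles holding a dictionary with `params` and
`state` entries, as consumed by `eval_ucf101.py`. A minimal loading sketch
(the path is illustrative):
```python
from mmv.utils import checkpoint

# Load pre-trained weights (a dill pickle with 'params' and 'state' entries).
pretrained = checkpoint.load_checkpoint('/path/to/mmv_s3d.pkl')
params, state = pretrained['params'], pretrained['state']
```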
## References
### Citing our work
If you use this code for your research, please consider citing our paper:
```bibtex
@inproceedings{alayrac2020self,
title={{S}elf-{S}upervised {M}ulti{M}odal {V}ersatile {N}etworks},
author={Alayrac, Jean-Baptiste and Recasens, Adri{\`a} and Schneider, Rosalia and Arandjelovi{\'c}, Relja and Ramapuram, Jason and De Fauw, Jeffrey and Smaira, Lucas and Dieleman, Sander and Zisserman, Andrew},
booktitle={NeurIPS},
year={2020}
}
```
### Models in TF
You may also be interested in our models released on TF-Hub, available at the
links below (a loading sketch follows):
- [S3D-G](https://tfhub.dev/deepmind/mmv/s3d/1)
- [Resnet-50 TSM](https://tfhub.dev/deepmind/mmv/tsm-resnet50/1)
- [Resnet-50 TSMx2](https://tfhub.dev/deepmind/mmv/tsm-resnet50x2/1)
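If you go this route, a hedged loading sketch with `tensorflow_hub` (the exact
input signatures are documented on each TF-Hub module page and are not assumed
here):
```python
import tensorflow_hub as hub

# Load the S3D-G MMV module from TF-Hub; consult the module page for its
# expected video/audio/text input signatures.
module = hub.load('https://tfhub.dev/deepmind/mmv/s3d/1')
```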
## License
While the code is licensed under the Apache 2.0 License, the checkpoint weights
are made available for non-commercial use only under the terms of the
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
license. You can find details at:
https://creativecommons.org/licenses/by-nc/4.0/legalcode.
+85
@@ -0,0 +1,85 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration parameters for MMV."""
def get_model_config(ckpt_path):
"""Returns the model configuration to be used with each checkpoint."""
config = {
'audio_backbone': 'resnet50',
'audio_model_kwargs': {
'bn_config': {
'create_offset': True,
'create_scale': True,
'decay_rate': 0.9,
'eps': 1.0e-5
}
},
'bn_config_proj': {
'create_offset': True,
'create_scale': True,
'decay_rate': 0.9,
'eps': 1.0e-5
},
'config_audio_text': {
'embedding_dim': 512,
'toaud_bn_after_proj': False,
'toaud_head_mode': 'linear',
'totxt_bn_after_proj': False,
'totxt_head_mode': 'linear'
},
'config_video_audio': {
'embedding_dim': 512,
'toaud_bn_after_proj': True,
'toaud_head_mode': 'mlp@512',
'tovid_bn_after_proj': False,
'tovid_head_mode': 'linear'
},
'config_video_text': {
'embedding_dim': 256,
'totxt_bn_after_proj': True,
'totxt_head_mode': 'linear',
'tovid_bn_after_proj': False,
'tovid_head_mode': 'linear'
},
'mm_embedding_graph': 'fac_relu',
'name': 'text_audio_video',
'sentence_dim': 2048,
'use_xreplica_bn': True,
'vision_model_kwargs': {
'bn_config': {
'create_offset': True,
'create_scale': True,
'decay_rate': 0.9,
'eps': 1.0e-5
},
'n_frames': 32,
'width_mult': 1,
},
}
if 's3d' in ckpt_path:
config['visual_backbone'] = 's3d'
if 'tsm_resnet_x1' in ckpt_path:
config['visual_backbone'] = 'resnet50tsm'
if 'tsm_resnet_x2' in ckpt_path:
config['visual_backbone'] = 'resnet50tsm'
config['vision_model_kwargs']['width_mult'] = 2
return config
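A brief usage sketch (the path below is hypothetical): `get_model_config`
selects the visual backbone from substrings of the checkpoint path, so the
provided checkpoint filenames pick the matching architecture automatically.
```python
from mmv import config

# 'tsm_resnet_x2' in the path selects the TSM ResNet-50 backbone at 2x width.
model_config = config.get_model_config('/path/to/mmv_tsm_resnet_x2.pkl')
assert model_config['visual_backbone'] == 'resnet50tsm'
assert model_config['vision_model_kwargs']['width_mult'] == 2
```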
+465
@@ -0,0 +1,465 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UCF101 linear evaluation."""
import functools
from typing import Any, Dict, Optional
from absl import app
from absl import flags
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import sklearn
from sklearn import preprocessing
import sklearn.linear_model
import sklearn.svm
import tensorflow as tf
import tensorflow_datasets as tfds
from mmv import config
from mmv.models import mm_embeddings
from mmv.utils import checkpoint
from mmv.utils import ucf101_dataset
flags.DEFINE_string('checkpoint_path', '~/tmp/mmv_s3d.pkl',
'The directory to load pre-trained weights from.')
flags.DEFINE_string('dataset_folder', '/tmp/ucf101',
'The directory with the ucf101 dataset.')
flags.DEFINE_integer('eval_batch_size', 1,
'The batch size for evaluation.')
flags.DEFINE_integer('train_batch_size', 16,
'The batch size for training.')
flags.DEFINE_integer('num_train_epochs', 10,
'How many epochs to collect features during training.')
flags.DEFINE_integer('num_test_windows', 10,
'How many windows to average on during test.')
flags.DEFINE_integer('min_resize', 224,
'Min value to resize images to during preprocessing.')
flags.DEFINE_integer('crop_size', 224,
'Size of the spatial crop applied during preprocessing.')
flags.DEFINE_integer('num_frames', 32,
'Number of video frames.')
flags.DEFINE_integer('stride', 2,
'Stride for video frames.')
flags.DEFINE_integer('ucf101_split', 1,
'Which split of ucf101 to use.')
FLAGS = flags.FLAGS
def get_sampling_offset(sequence: tf.Tensor,
num_steps: Optional[int],
is_training: bool,
stride: int = 1,
seed: Optional[int] = None) -> tf.Tensor:
"""Calculates the initial offset for a sequence where all steps will fit.
Args:
sequence: any tensor where the first dimension is timesteps.
num_steps: The number of timesteps we will output. If None,
deterministically start at the first frame.
is_training: A boolean indicates whether the graph is for training or not.
If False, the starting frame always the first frame.
stride: distance to sample between timesteps.
seed: a deterministic seed to use when sampling.
Returns:
The first index to begin sampling from. A best effort is made to provide a
starting index such that all requested steps fit within the sequence (i.e.
offset + 1 + (num_steps - 1) * stride < len(sequence)). If this is not
satisfied, the starting index is chosen randomly from the full sequence.
"""
if num_steps is None or not is_training:
return tf.constant(0)
sequence_length = tf.shape(sequence)[0]
max_offset = tf.cond(
tf.greater(sequence_length, (num_steps - 1) * stride),
lambda: sequence_length - (num_steps - 1) * stride,
lambda: sequence_length)
offset = tf.random.uniform(
(),
maxval=tf.cast(max_offset, tf.int32),
dtype=tf.int32,
seed=seed)
return offset
def sample_or_pad_sequence_indices(sequence: tf.Tensor,
num_steps: Optional[int],
is_training: bool,
repeat_sequence: bool = True,
stride: int = 1,
offset: Optional[int] = None) -> tf.Tensor:
"""Returns indices to take for sampling or padding a sequence to fixed size.
Samples num_steps from the sequence. If the sequence is shorter than
num_steps, the sequence loops. If the sequence is longer than num_steps and
is_training is True, then we seek to a random offset before sampling. If
offset is provided, we use that deterministic offset.
This method is appropriate for sampling from a tensor where you want every
timestep between a start and end time. See sample_stacked_sequence_indices for
more flexibility.
Args:
sequence: any tensor where the first dimension is timesteps.
num_steps: how many steps (e.g. frames) to take. If None, all steps from
start to end are considered and `is_training` has no effect.
is_training: A boolean indicating whether the graph is for training.
If False, the starting frame is deterministic.
repeat_sequence: A boolean indicating whether the sequence will repeat to
have enough steps for sampling. If False, a runtime error is thrown if
num_steps * stride is longer than the sequence length.
stride: distance to sample between timesteps.
offset: a deterministic offset to use regardless of the is_training value.
Returns:
Indices to gather from the sequence Tensor to get a fixed size sequence.
"""
sequence_length = tf.shape(sequence)[0]
sel_idx = tf.range(sequence_length)
if num_steps:
if offset is None:
offset = get_sampling_offset(sequence, num_steps, is_training, stride)
if repeat_sequence:
# Repeats sequence until num_steps are available in total.
num_repeats = tf.cast(
tf.math.ceil(
tf.math.divide(
tf.cast(num_steps * stride + offset, tf.float32),
tf.cast(sequence_length, tf.float32)
)), tf.int32)
sel_idx = tf.tile(sel_idx, [num_repeats])
steps = tf.range(offset, offset + num_steps * stride, stride)
else:
steps = tf.range(0, sequence_length, stride)
return tf.gather(sel_idx, steps)
def random_sample_sequence(sequence: tf.Tensor,
num_steps: int,
stride: int = 1) -> tf.Tensor:
"""Randomly sample a segment of size num_steps from a given sequence."""
indices = sample_or_pad_sequence_indices(
sequence=sequence,
num_steps=num_steps,
is_training=True, # Random sample.
repeat_sequence=True,  # Will repeat the sequence if more steps are requested.
stride=stride,
offset=None)
indices.set_shape((num_steps,))
output = tf.gather(sequence, indices)
return output
def sample_linspace_sequence(sequence: tf.Tensor,
num_windows: int,
num_steps: int,
stride: int = 1) -> tf.Tensor:
"""Samples num_windows segments from sequence with linearly spaced offsets.
The samples are concatenated into a single Tensor so that each timestep keeps
the same format structure (e.g. a single frame). If num_steps * stride is
bigger than the number of timesteps, the sequence is repeated. This function
can be used in evaluation to extract enough segments to span the entire
sequence.
Args:
sequence: Any tensor where the first dimension is timesteps.
num_windows: Number of windows retrieved from the sequence.
num_steps: Number of steps (e.g. frames) to take.
stride: Distance to sample between timesteps.
Returns:
A single Tensor with first dimension num_windows * num_steps, containing the
concatenation of num_windows sampled windows whose offsets have been
linearly spaced across the input.
"""
sequence_length = tf.shape(sequence)[0]
max_offset = tf.maximum(0, sequence_length - num_steps * stride)
offsets = tf.linspace(0.0, tf.cast(max_offset, tf.float32), num_windows)
offsets = tf.cast(offsets, tf.int32)
all_indices = []
for i in range(num_windows):
all_indices.append(
sample_or_pad_sequence_indices(
sequence=sequence,
num_steps=num_steps,
is_training=False,
repeat_sequence=True,  # Will repeat the sequence if more steps are requested.
stride=stride,
offset=offsets[i]))
indices = tf.concat(all_indices, axis=0)
indices.set_shape((num_windows * num_steps,))
output = tf.gather(sequence, indices)
return output
def resize_smallest(frames: tf.Tensor, min_resize: int) -> tf.Tensor:
"""Resizes frames so that min(height, width) is equal to min_resize.
This function does nothing if min(height, width) is already equal to
min_resize, which saves compute time.
Args:
frames: A Tensor of dimension [timesteps, input_h, input_w, channels].
min_resize: Minimum size of the final image dimensions.
Returns:
A Tensor of shape [timesteps, output_h, output_w, channels] of type
frames.dtype where min(output_h, output_w) = min_resize.
"""
shape = tf.shape(frames)
input_h = shape[1]
input_w = shape[2]
output_h = tf.maximum(min_resize, (input_h * min_resize) // input_w)
output_w = tf.maximum(min_resize, (input_w * min_resize) // input_h)
def resize_fn():
frames_resized = tf.image.resize(frames, (output_h, output_w))
return tf.cast(frames_resized, frames.dtype)
should_resize = tf.math.logical_or(tf.not_equal(input_w, output_w),
tf.not_equal(input_h, output_h))
frames = tf.cond(should_resize, resize_fn, lambda: frames)
return frames
def process_samples(features_dict, num_frames=32, stride=1, is_training=True,
min_resize=224, crop_size=224, num_windows=1):
"""Process video frames."""
video = features_dict['video']
if is_training:
assert num_windows == 1
video = random_sample_sequence(video, num_frames, stride)
is_flipped = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
video = tf.cond(tf.equal(is_flipped, 1),
true_fn=lambda: tf.image.flip_left_right(video),
false_fn=lambda: video)
else:
video = sample_linspace_sequence(video, num_windows, num_frames, stride)
# Resize smallest side.
video = resize_smallest(video, min_resize)
if is_training:
# Random crop.
video = tf.image.random_crop(video, [num_frames, crop_size, crop_size, 3])
else:
# Central crop.
video = tf.image.resize_with_crop_or_pad(video, crop_size, crop_size)
video = tf.cast(video, tf.float32)
video /= 255.0  # Scale to [0, 1].
features_dict['video'] = video
return features_dict
def space_to_depth_batch(features_dict):
"""Folds 2x2x2 patches into channels: [B,T,H,W,C] -> [B,T/2,H/2,W/2,8C]."""
images = features_dict['video']
_, l, h, w, c = images.shape
images = tf.reshape(images, [-1, l // 2, 2, h // 2, 2, w // 2, 2, c])
images = tf.transpose(images, [0, 1, 3, 5, 2, 4, 6, 7])
images = tf.reshape(images, [-1, l // 2, h // 2, w // 2, 8 * c])
features_dict['video'] = images
return features_dict
def reshape_windows(features_dict, num_frames):
"""Folds test windows into the batch dimension, num_frames frames each."""
x = features_dict['video']
x = tf.reshape(x, (-1, num_frames, x.shape[2], x.shape[3], x.shape[4]))
features_dict['video'] = x
return features_dict
def compute_accuracy_metrics(pred, gt, prefix=''):
"""Computes top-1 and top-5 accuracies from scores `pred` and labels `gt`."""
order_pred = np.argsort(pred, axis=1)
assert len(gt.shape) == len(order_pred.shape) == 2
top1_pred = order_pred[:, -1:]
top5_pred = order_pred[:, -5:]
top1_acc = np.mean(top1_pred == gt)
top5_acc = np.mean(np.max(top5_pred == gt, 1))
return {prefix + 'top1': top1_acc,
prefix + 'top5': top5_acc}
def forward_fn(images: jnp.ndarray,
audio_spectrogram: jnp.ndarray,
word_ids: jnp.ndarray,
is_training: bool,
model_config: Dict[str, Any]):
"""Forward pass of the model."""
# The word embedding matrix would normally contain the pre-trained weights. We
# set it to zeros here because the real values are loaded from the checkpoint.
language_model_vocab_size = 65536
word_embedding_dim = 300
dummy_embedding_matrix = jnp.zeros(shape=(language_model_vocab_size,
word_embedding_dim))
module = mm_embeddings.AudioTextVideoEmbedding(
**model_config,
word_embedding_matrix=dummy_embedding_matrix)
return module(images=images,
audio_spectrogram=audio_spectrogram,
word_ids=word_ids,
is_training=is_training)['vid_repr']
def main(argv):
del argv
sklearn_reg = 0.001
model_config = config.get_model_config(FLAGS.checkpoint_path)
forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
forward_apply = jax.jit(functools.partial(forward.apply,
is_training=False,
model_config=model_config))
# Get the UCF101 config.
dset_config = tfds.video.ucf101.Ucf101.BUILDER_CONFIGS[FLAGS.ucf101_split]
builder = ucf101_dataset.ModUcf101(
data_dir=FLAGS.dataset_folder,
config=dset_config)
# Create the tfrecord files (a no-op if they already exist).
dl_config = tfds.download.DownloadConfig(verify_ssl=False)
builder.download_and_prepare(download_config=dl_config)
# Generate the training dataset.
train_ds = builder.as_dataset(split='train', shuffle_files=False)
train_ds = train_ds.map(lambda x: process_samples( # pylint: disable=g-long-lambda
x, num_frames=FLAGS.num_frames, stride=FLAGS.stride, is_training=True,
min_resize=FLAGS.min_resize, crop_size=FLAGS.crop_size))
train_ds = train_ds.batch(batch_size=FLAGS.train_batch_size)
if model_config['visual_backbone'] == 's3d':
train_ds = train_ds.map(space_to_depth_batch)
train_ds = train_ds.repeat(FLAGS.num_train_epochs)
# Generate the test dataset.
test_ds = builder.as_dataset(split='test', shuffle_files=False)
test_ds = test_ds.map(lambda x: process_samples( # pylint: disable=g-long-lambda
x, num_frames=FLAGS.num_frames, stride=FLAGS.stride, is_training=False,
min_resize=FLAGS.min_resize, crop_size=FLAGS.crop_size,
num_windows=FLAGS.num_test_windows))
test_ds = test_ds.batch(batch_size=FLAGS.eval_batch_size)
test_ds = test_ds.map(lambda x: reshape_windows( # pylint: disable=g-long-lambda
x, num_frames=FLAGS.num_frames))
if model_config['visual_backbone'] == 's3d':
test_ds = test_ds.map(space_to_depth_batch)
test_ds = test_ds.repeat(1)
pretrained_weights = checkpoint.load_checkpoint(FLAGS.checkpoint_path)
params = pretrained_weights['params']
state = pretrained_weights['state']
# Collect training samples.
audio_frames = 96
mel_filters = 40
num_tokens = 16
dummy_audio = jnp.zeros(
shape=(FLAGS.train_batch_size, audio_frames, mel_filters, 1))
dummy_word_ids = jnp.zeros(
shape=(FLAGS.train_batch_size, num_tokens), dtype=jnp.int32)
train_features = []
train_labels = []
print('Computing features on train')
training_examples = iter(tfds.as_numpy(train_ds))
for train_ex in training_examples:
vid_representation, _ = forward_apply(params=params,
state=state,
images=train_ex['video'],
audio_spectrogram=dummy_audio,
word_ids=dummy_word_ids)
train_labels.append(train_ex['label'])
train_features.append(vid_representation)
if len(train_labels) % 50 == 0:
print(f'Processed {len(train_labels)} examples.')
train_labels = np.concatenate(train_labels, axis=0)
train_features = np.concatenate(train_features, axis=0)
print(f'Finished collecting train features of shape {train_features.shape}')
# Collect test samples.
dummy_audio = jnp.zeros(
shape=(FLAGS.eval_batch_size, audio_frames, mel_filters, 1))
dummy_word_ids = jnp.zeros(
shape=(FLAGS.eval_batch_size, num_tokens), dtype=jnp.int32)
test_features = []
test_labels = []
print('Computing features on test')
test_examples = iter(tfds.as_numpy(test_ds))
for test_ex in test_examples:
vid_representation_test, _ = forward_apply(params=params,
state=state,
images=test_ex['video'],
audio_spectrogram=dummy_audio,
word_ids=dummy_word_ids)
test_labels.append(test_ex['label'])
test_features.append(vid_representation_test)
if len(test_labels) % 50 == 0:
print(f'Processed {len(test_labels)} examples.')
test_features = np.concatenate(test_features, axis=0)
test_labels = np.concatenate(test_labels, axis=0)
print(f'Finished collecting test features of shape {test_features.shape}')
# Train classifier
print('Training linear classifier!')
classifier = sklearn.svm.LinearSVC(C=sklearn_reg)
scaler = preprocessing.StandardScaler().fit(train_features)
train_features = scaler.transform(train_features)
classifier.fit(train_features, train_labels.ravel())
print('Training done!')
# Evaluation.
test_features = scaler.transform(test_features)
print('Running inference on train')
pred_train = classifier.decision_function(train_features)
print('Running inference on test')
pred_test = classifier.decision_function(test_features)
if FLAGS.num_test_windows > 1:
pred_test = np.reshape(
pred_test, (test_labels.shape[0], -1, pred_test.shape[1]))
pred_test = pred_test.mean(axis=1)
# Compute accuracies.
metrics = compute_accuracy_metrics(pred_train, train_labels[:, None],
prefix='train_')
metrics.update(
compute_accuracy_metrics(pred_test, test_labels[:, None], prefix='test_'))
print(metrics)
if __name__ == '__main__':
app.run(main)
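As a sanity check on `compute_accuracy_metrics`, a small hand-made example
(the numbers are illustrative, not from the paper):
```python
import numpy as np
from mmv.eval_ucf101 import compute_accuracy_metrics

pred = np.array([[0.1, 0.2, 0.7],
                 [0.3, 0.1, 0.6]])  # Scores for 2 examples over 3 classes.
gt = np.array([[2], [0]])           # Ground-truth labels, shape [N, 1].
print(compute_accuracy_metrics(pred, gt))
# {'top1': 0.5, 'top5': 1.0} -- with only 3 classes, "top5" spans all of them.
```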
Binary file not shown (image, 53 KiB).
File diff suppressed because it is too large.
+143
@@ -0,0 +1,143 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize functions constructors."""
from typing import Any, Dict, Optional, Sequence, Union
import haiku as hk
from jax import numpy as jnp
from mmv.models import types
class _BatchNorm(hk.BatchNorm):
"""A `hk.BatchNorm` with adapted default arguments."""
def __init__(self,
create_scale: bool = True,
create_offset: bool = True,
decay_rate: float = 0.9,
eps: float = 1e-5,
test_local_stats: bool = False,
**kwargs):
# Check args.
if kwargs.get('cross_replica_axis', None) is not None:
raise ValueError(
'Attempting to use \'batch_norm\' normalizer, but specifying '
'`cross_replica_axis`. If you want this behavior use '
'`normalizer=\'cross_replica_batch_norm\'` directly.')
self._test_local_stats = test_local_stats
super().__init__(create_scale=create_scale,
create_offset=create_offset,
decay_rate=decay_rate,
eps=eps,
**kwargs)
def __call__(self,
x: types.TensorLike,
is_training: bool) -> jnp.ndarray:
return super().__call__(x, is_training,
test_local_stats=self._test_local_stats)
class _CrossReplicaBatchNorm(hk.BatchNorm):
"""A `hk.BatchNorm` with adapted default arguments for cross replica."""
def __init__(self,
create_scale: bool = True,
create_offset: bool = True,
decay_rate: float = 0.9,
eps: float = 1e-5,
test_local_stats: bool = False,
**kwargs):
# Check args.
if 'cross_replica_axis' in kwargs and kwargs['cross_replica_axis'] is None:
raise ValueError(
'Attempting to use \'cross_replica_batch_norm\' normalizer, but '
'specifying `cross_replica_axis` to be None. If you want this '
'behavior use `normalizer=\'batch_norm\'` directly.')
self._test_local_stats = test_local_stats
kwargs['cross_replica_axis'] = kwargs.get('cross_replica_axis', 'i')
super().__init__(create_scale=create_scale,
create_offset=create_offset,
decay_rate=decay_rate,
eps=eps,
**kwargs)
def __call__(self,
x: types.TensorLike,
is_training: bool) -> jnp.ndarray:
return super().__call__(x, is_training,
test_local_stats=self._test_local_stats)
class _LayerNorm(hk.LayerNorm):
"""A `hk.LayerNorm` accepting (and discarding) an `is_training` argument."""
def __init__(self,
axis: Union[int, Sequence[int]] = (1, 2),
create_scale: bool = True,
create_offset: bool = True,
**kwargs):
super().__init__(axis=axis,
create_scale=create_scale,
create_offset=create_offset,
**kwargs)
def __call__(self,
x: types.TensorLike,
is_training: bool) -> jnp.ndarray:
del is_training # Unused.
return super().__call__(x)
_NORMALIZER_NAME_TO_CLASS = {
'batch_norm': _BatchNorm,
'cross_replica_batch_norm': _CrossReplicaBatchNorm,
'layer_norm': _LayerNorm,
}
def get_normalize_fn(
normalizer_name: str = 'batch_norm',
normalizer_kwargs: Optional[Dict[str, Any]] = None,
) -> types.NormalizeFn:
"""Handles NormalizeFn creation.
These functions are expected to be used as part of a Haiku model. On each
application of the returned normalization_fn, a new Haiku layer will be added
to the model.
Args:
normalizer_name: The name of the normalizer to be constructed.
normalizer_kwargs: The kwargs passed to the normalizer constructor.
Returns:
A `types.NormalizeFn` that when applied will create a new layer.
Raises:
ValueError: If `normalizer_name` is unknown.
"""
# Check args.
if normalizer_name not in _NORMALIZER_NAME_TO_CLASS:
raise ValueError(f'Unrecognized `normalizer_name` {normalizer_name}.')
normalizer_class = _NORMALIZER_NAME_TO_CLASS[normalizer_name]
normalizer_kwargs = normalizer_kwargs or dict()
return lambda *a, **k: normalizer_class(**normalizer_kwargs)(*a, **k) # pylint: disable=unnecessary-lambda
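A minimal usage sketch (the toy `forward` below is hypothetical, not part of
this commit): each application of the returned normalize_fn adds a fresh Haiku
layer, so it can be called repeatedly inside a model.
```python
import haiku as hk
import jax
import jax.numpy as jnp
from mmv.models import normalization

def forward(x, is_training):
  normalize_fn = normalization.get_normalize_fn('batch_norm')
  x = hk.Conv2D(output_channels=16, kernel_shape=3)(x)
  x = normalize_fn(x, is_training=is_training)  # Adds a new _BatchNorm layer.
  return jax.nn.relu(x)

model = hk.transform_with_state(forward)
params, state = model.init(
    jax.random.PRNGKey(0), jnp.zeros((1, 8, 8, 3)), is_training=True)
```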
+329
@@ -0,0 +1,329 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3.
"""ResNet V2 modules.
Equivalent to Haiku's ResNet, except it accepts a `final_endpoint` to return
intermediate activations.
"""
from typing import Optional, Sequence, Text, Type, Union
import haiku as hk
import jax
import jax.numpy as jnp
from mmv.models import types
class BottleneckBlock(hk.Module):
"""Implements a bottleneck residual block (ResNet50 and ResNet101)."""
# pylint:disable=g-bare-generic
def __init__(self,
channels: int,
stride: Union[int, Sequence[int]],
use_projection: bool,
normalize_fn: Optional[types.NormalizeFn] = None,
name: Optional[Text] = None):
super(BottleneckBlock, self).__init__(name=name)
self._channels = channels
self._stride = stride
self._use_projection = use_projection
self._normalize_fn = normalize_fn
if self._use_projection:
self._proj_conv = hk.Conv2D(
output_channels=channels,
kernel_shape=1,
stride=stride,
with_bias=False,
padding='SAME',
name='shortcut_conv')
self._conv_0 = hk.Conv2D(
output_channels=channels // 4,
kernel_shape=1,
stride=1,
with_bias=False,
padding='SAME',
name='conv_0')
self._conv_1 = hk.Conv2D(
output_channels=channels // 4,
kernel_shape=3,
stride=stride,
with_bias=False,
padding='SAME',
name='conv_1')
self._conv_2 = hk.Conv2D(
output_channels=channels,
kernel_shape=1,
stride=1,
with_bias=False,
padding='SAME',
name='conv_2')
def __call__(self,
inputs,
is_training):
net = inputs
shortcut = inputs
for i, conv_i in enumerate([self._conv_0, self._conv_1, self._conv_2]):
if self._normalize_fn is not None:
net = self._normalize_fn(net, is_training=is_training)
net = jax.nn.relu(net)
if i == 0 and self._use_projection:
shortcut = self._proj_conv(net)
# Now do the convs.
net = conv_i(net)
return net + shortcut
class BasicBlock(hk.Module):
"""Implements a basic residual block (ResNet18 and ResNet34)."""
# pylint:disable=g-bare-generic
def __init__(self,
channels: int,
stride: Union[int, Sequence[int]],
use_projection: bool,
normalize_fn: Optional[types.NormalizeFn] = None,
name: Optional[Text] = None):
super(BasicBlock, self).__init__(name=name)
self._channels = channels
self._stride = stride
self._use_projection = use_projection
self._normalize_fn = normalize_fn
if self._use_projection:
self._proj_conv = hk.Conv2D(
output_channels=channels,
kernel_shape=1,
stride=stride,
with_bias=False,
padding='SAME',
name='shortcut_conv')
self._conv_0 = hk.Conv2D(
output_channels=channels,
kernel_shape=1,
stride=1,
with_bias=False,
padding='SAME',
name='conv_0')
self._conv_1 = hk.Conv2D(
output_channels=channels,
kernel_shape=3,
stride=stride,
with_bias=False,
padding='SAME',
name='conv_1')
def __call__(self,
inputs,
is_training):
net = inputs
shortcut = inputs
for i, conv_i in enumerate([self._conv_0, self._conv_1]):
if self._normalize_fn is not None:
net = self._normalize_fn(net, is_training=is_training)
net = jax.nn.relu(net)
if i == 0 and self._use_projection:
shortcut = self._proj_conv(net)
# Now do the convs.
net = conv_i(net)
return net + shortcut
class ResNetUnit(hk.Module):
"""Unit (group of blocks) for ResNet."""
# pylint:disable=g-bare-generic
def __init__(self,
channels: int,
num_blocks: int,
stride: Union[int, Sequence[int]],
block_module: Type[BottleneckBlock],
normalize_fn: Optional[types.NormalizeFn] = None,
name: Optional[Text] = None,
remat: bool = False):
super(ResNetUnit, self).__init__(name=name)
self._channels = channels
self._num_blocks = num_blocks
self._stride = stride
self._normalize_fn = normalize_fn
self._block_module = block_module
self._remat = remat
def __call__(self,
inputs,
is_training):
input_channels = inputs.shape[-1]
self._blocks = []
for id_block in range(self._num_blocks):
use_projection = id_block == 0 and self._channels != input_channels
self._blocks.append(
self._block_module(
channels=self._channels,
stride=self._stride if id_block == 0 else 1,
use_projection=use_projection,
normalize_fn=self._normalize_fn,
name='block_%d' % id_block))
net = inputs
for block in self._blocks:
if self._remat:
# Note: we can ignore cell-var-from-loop because the lambda is evaluated
# inside every iteration of the loop. This is needed to go around the
# way variables are passed to jax.remat.
net = hk.remat(lambda x: block(x, is_training=is_training))(net) # pylint: disable=cell-var-from-loop
else:
net = block(net, is_training=is_training)
return net
class ResNetV2(hk.Module):
"""ResNetV2 model."""
# Endpoints of the model in order.
VALID_ENDPOINTS = (
'resnet_stem',
'resnet_unit_0',
'resnet_unit_1',
'resnet_unit_2',
'resnet_unit_3',
'last_conv',
'output',
)
# pylint:disable=g-bare-generic
def __init__(self,
depth=50,
num_classes: Optional[int] = 1000,
width_mult: int = 1,
normalize_fn: Optional[types.NormalizeFn] = None,
name: Optional[Text] = None,
remat: bool = False):
"""Creates ResNetV2 Haiku module.
Args:
depth: depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
num_classes: (int) Number of outputs in final layer. If None will not add
a classification head and will return the output embedding.
width_mult: multiplier for channel width.
normalize_fn: normalization function, see models/normalization.py.
name: Name of the module.
remat: Whether to rematerialize intermediate activations (saves memory).
"""
super(ResNetV2, self).__init__(name=name)
self._normalize_fn = normalize_fn
self._num_classes = num_classes
self._width_mult = width_mult
self._strides = [1, 2, 2, 2]
num_blocks = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
200: [3, 24, 36, 3],
}
if depth not in num_blocks:
raise ValueError(
f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
self._num_blocks = num_blocks[depth]
if depth >= 50:
self._block_module = BottleneckBlock
self._channels = [256, 512, 1024, 2048]
else:
self._block_module = BasicBlock
self._channels = [64, 128, 256, 512]
self._initial_conv = hk.Conv2D(
output_channels=64 * self._width_mult,
kernel_shape=7,
stride=2,
with_bias=False,
padding='SAME',
name='initial_conv')
if remat:
self._initial_conv = hk.remat(self._initial_conv)
self._block_groups = []
for i in range(4):
self._block_groups.append(
ResNetUnit(
channels=self._channels[i] * self._width_mult,
num_blocks=self._num_blocks[i],
block_module=self._block_module,
stride=self._strides[i],
normalize_fn=self._normalize_fn,
name='block_group_%d' % i,
remat=remat))
if num_classes is not None:
self._logits_layer = hk.Linear(
output_size=num_classes, w_init=jnp.zeros, name='logits')
def __call__(self, inputs, is_training, final_endpoint='output'):
self._final_endpoint = final_endpoint
net = self._initial_conv(inputs)
net = hk.max_pool(
net, window_shape=(1, 3, 3, 1),
strides=(1, 2, 2, 1),
padding='SAME')
end_point = 'resnet_stem'
if self._final_endpoint == end_point:
return net
for i_group, block_group in enumerate(self._block_groups):
net = block_group(net, is_training=is_training)
end_point = f'resnet_unit_{i_group}'
if self._final_endpoint == end_point:
return net
end_point = 'last_conv'
if self._final_endpoint == end_point:
return net
if self._normalize_fn is not None:
net = self._normalize_fn(net, is_training=is_training)
net = jax.nn.relu(net)
# The actual representation
net = jnp.mean(net, axis=[1, 2])
assert self._final_endpoint == 'output'
if self._num_classes is None:
# If num_classes was None, we just return the output
# of the last block, without fully connected layer.
return net
return self._logits_layer(net)
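A hedged sketch of extracting an intermediate endpoint (assuming the module is
importable as `mmv.models.resnet`; the transform boilerplate mirrors the tests
in this commit):
```python
import haiku as hk
import jax
import jax.numpy as jnp
from mmv.models import normalization, resnet

def f(images):
  net = resnet.ResNetV2(depth=50, num_classes=None,
                        normalize_fn=normalization.get_normalize_fn())
  # Stop at the third unit: a [B, 14, 14, 1024] feature map for 224x224 inputs.
  return net(images, is_training=False, final_endpoint='resnet_unit_2')

init_fn, apply_fn = hk.transform_with_state(f)
images = jnp.zeros((1, 224, 224, 3))
params, state = init_fn(jax.random.PRNGKey(0), images)
features, _ = apply_fn(params, state, None, images)
```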
+503
File diff suppressed because it is too large.
+88
@@ -0,0 +1,88 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for s3d."""
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import numpy as np
from mmv.models import normalization
from mmv.models import s3d
class _CallableS3D:
"""Wrapper around S3D that take care of parameter book keeping."""
def __init__(self, *args, **kwargs):
self._model = hk.transform_with_state(
lambda *a, **k: # pylint: disable=g-long-lambda,unnecessary-lambda
s3d.S3D(
normalize_fn=normalization.get_normalize_fn(),
*args, **kwargs)(*a, **k))
self._rng = jax.random.PRNGKey(42)
self._params, self._state = None, None
def init(self, inputs, **kwargs):
self._params, self._state = self._model.init(
self._rng, inputs, is_training=True, **kwargs)
def __call__(self, inputs, **kwargs):
if self._params is None:
self.init(inputs)
output, _ = self._model.apply(
self._params, self._state, self._rng, inputs, **kwargs)
return output
class S3DTest(parameterized.TestCase):
# Testing all layers is quite slow; they are kept as comments for completeness.
@parameterized.parameters(
# dict(endpoint='Conv2d_1a_7x7', expected_size=(2, 8, 112, 112, 64)),
# dict(endpoint='MaxPool_2a_3x3', expected_size=(2, 8, 56, 56, 64)),
# dict(endpoint='Conv2d_2b_1x1', expected_size=(2, 8, 56, 56, 64)),
# dict(endpoint='Conv2d_2c_3x3', expected_size=(2, 8, 56, 56, 192)),
# dict(endpoint='MaxPool_3a_3x3', expected_size=(2, 8, 28, 28, 192)),
# dict(endpoint='Mixed_3b', expected_size=(2, 8, 28, 28, 256)),
# dict(endpoint='Mixed_3c', expected_size=(2, 8, 28, 28, 480)),
# dict(endpoint='MaxPool_4a_3x3', expected_size=(2, 4, 14, 14, 480)),
# dict(endpoint='Mixed_4b', expected_size=(2, 4, 14, 14, 512)),
# dict(endpoint='Mixed_4c', expected_size=(2, 4, 14, 14, 512)),
# dict(endpoint='Mixed_4d', expected_size=(2, 4, 14, 14, 512)),
# dict(endpoint='Mixed_4e', expected_size=(2, 4, 14, 14, 528)),
# dict(endpoint='Mixed_4f', expected_size=(2, 4, 14, 14, 832)),
# dict(endpoint='MaxPool_5a_2x2', expected_size=(2, 2, 7, 7, 832)),
# dict(endpoint='Mixed_5b', expected_size=(2, 2, 7, 7, 832)),
# dict(endpoint='Mixed_5c', expected_size=(2, 2, 7, 7, 1024)),
dict(endpoint='Embeddings', expected_size=(2, 1024)),
)
def test_endpoint_expected_output_dimensions(self, endpoint, expected_size):
inputs = np.random.normal(size=(2, 16, 224, 224, 3))
model = _CallableS3D()
output = model(inputs, is_training=False, final_endpoint=endpoint)
self.assertSameElements(output.shape, expected_size)
def test_space_to_depth(self):
inputs = np.random.normal(size=(2, 16//2, 224//2, 224//2, 3*2*2*2))
model = _CallableS3D()
output = model(inputs, is_training=False, final_endpoint='Conv2d_1a_7x7')
self.assertSameElements(output.shape, (2, 8, 112, 112, 64))
if __name__ == '__main__':
absltest.main()
+353
@@ -0,0 +1,353 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Temporal Shift Module w/ ResNet-50 and ResNet-101.
Based on:
TSM: Temporal Shift Module for Efficient Video Understanding
Ji Lin, Chuang Gan, Song Han
https://arxiv.org/pdf/1811.08383.pdf.
"""
from typing import Optional
import haiku as hk
import jax
import jax.numpy as jnp
from mmv.models import tsm_utils as tsmu
from mmv.models import types
class TSMResNetBlock(hk.Module):
"""A ResNet subblock with Temporal Channel Shifting.
Combines a typical ResNetV2 block implementation
(see https://arxiv.org/abs/1512.03385) with a pre-convolution Temporal
Shift Module (see https://arxiv.org/pdf/1811.08383.pdf) in the residual.
"""
def __init__(self,
output_channels: int,
stride: int,
use_projection: bool,
tsm_mode: str,
normalize_fn: Optional[types.NormalizeFn] = None,
channel_shift_fraction: float = 0.125,
num_frames: int = 8,
name: str = 'TSMResNetBlock'):
"""Initializes the TSMResNetBlock module.
Args:
output_channels: Number of output channels.
stride: Stride used in convolutions.
use_projection: Whether to use a projection for the shortcut.
tsm_mode: Mode for TSM ('gpu' or 'tpu').
normalize_fn: Function used for normalization.
channel_shift_fraction: The fraction of temporally shifted channels. If
`channel_shift_fraction` is 0, the block is the same as a normal ResNet
block.
num_frames: Size of frame dimension in a single batch example.
name: The name of the module.
"""
super().__init__(name=name)
self._output_channels = output_channels
self._bottleneck_channels = output_channels // 4
self._stride = stride
self._use_projection = use_projection
self._normalize_fn = normalize_fn
self._tsm_mode = tsm_mode
self._channel_shift_fraction = channel_shift_fraction
self._num_frames = num_frames
def __call__(self,
inputs: types.TensorLike,
is_training: bool = True) -> jnp.ndarray:
"""Connects the ResNetBlock module into the graph.
Args:
inputs: A 4-D float array of shape `[B * num_frames, H, W, C]`.
is_training: Whether to use training mode.
Returns:
A 4-D float array of shape
`[B * num_frames, new_h, new_w, output_channels]`.
"""
# ResNet V2 uses pre-activation, where the batch norm and relu are before
# convolutions, rather than after as in ResNet V1.
preact = inputs
if self._normalize_fn is not None:
preact = self._normalize_fn(preact, is_training=is_training)
preact = jax.nn.relu(preact)
if self._use_projection:
shortcut = hk.Conv2D(
output_channels=self._output_channels,
kernel_shape=1,
stride=self._stride,
with_bias=False,
padding='SAME',
name='shortcut_conv')(
preact)
else:
shortcut = inputs
# Eventually applies Temporal Shift Module.
if self._channel_shift_fraction != 0:
preact = tsmu.apply_temporal_shift(
preact, tsm_mode=self._tsm_mode, num_frames=self._num_frames,
channel_shift_fraction=self._channel_shift_fraction)
# First convolution.
residual = hk.Conv2D(
self._bottleneck_channels,
kernel_shape=1,
stride=1,
with_bias=False,
padding='SAME',
name='conv_0')(
preact)
# Second convolution.
if self._normalize_fn is not None:
residual = self._normalize_fn(residual, is_training=is_training)
residual = jax.nn.relu(residual)
residual = hk.Conv2D(
output_channels=self._bottleneck_channels,
kernel_shape=3,
stride=self._stride,
with_bias=False,
padding='SAME',
name='conv_1')(
residual)
# Third convolution.
if self._normalize_fn is not None:
residual = self._normalize_fn(residual, is_training=is_training)
residual = jax.nn.relu(residual)
residual = hk.Conv2D(
output_channels=self._output_channels,
kernel_shape=1,
stride=1,
with_bias=False,
padding='SAME',
name='conv_2')(
residual)
# NOTE: we do not use block multiplier.
output = shortcut + residual
return output
class TSMResNetUnit(hk.Module):
"""Block group for TSM ResNet."""
def __init__(self,
output_channels: int,
num_blocks: int,
stride: int,
tsm_mode: str,
num_frames: int,
normalize_fn: Optional[types.NormalizeFn] = None,
channel_shift_fraction: float = 0.125,
name: str = 'tsm_resnet_unit'):
"""Creates a TSMResNet Unit.
Args:
output_channels: Number of output channels.
num_blocks: Number of ResNet blocks in the unit.
stride: Stride of the unit.
tsm_mode: Which temporal shift module to use.
num_frames: Size of frame dimension in a single batch example.
normalize_fn: Function used for normalization.
channel_shift_fraction: The fraction of temporally shifted channels. If
`channel_shift_fraction` is 0, the block is the same as a normal ResNet
block.
name: The name of the module.
"""
super().__init__(name=name)
self._output_channels = output_channels
self._num_blocks = num_blocks
self._normalize_fn = normalize_fn
self._stride = stride
self._tsm_mode = tsm_mode
self._channel_shift_fraction = channel_shift_fraction
self._num_frames = num_frames
def __call__(self,
inputs: types.TensorLike,
is_training: bool) -> jnp.ndarray:
"""Connects the module to inputs.
Args:
inputs: A 4-D float array of shape `[B * num_frames, H, W, C]`.
is_training: Whether to use training mode.
Returns:
A 4-D float array of shape
`[B * num_frames, H // stride, W // stride, output_channels]`.
"""
net = inputs
for idx_block in range(self._num_blocks):
net = TSMResNetBlock(
self._output_channels,
stride=self._stride if idx_block == 0 else 1,
use_projection=idx_block == 0,
normalize_fn=self._normalize_fn,
tsm_mode=self._tsm_mode,
channel_shift_fraction=self._channel_shift_fraction,
num_frames=self._num_frames,
name=f'block_{idx_block}')(
net, is_training=is_training)
return net
class TSMResNetV2(hk.Module):
"""TSM based on ResNet V2 as described in https://arxiv.org/abs/1603.05027."""
# Endpoints of the model in order.
VALID_ENDPOINTS = (
'tsm_resnet_stem',
'tsm_resnet_unit_0',
'tsm_resnet_unit_1',
'tsm_resnet_unit_2',
'tsm_resnet_unit_3',
'last_conv',
'Embeddings',
)
def __init__(self,
normalize_fn: Optional[types.NormalizeFn] = None,
depth: int = 50,
num_frames: int = 16,
channel_shift_fraction: float = 0.125,
width_mult: int = 1,
name: str = 'TSMResNetV2'):
"""Constructs a ResNet model.
Args:
normalize_fn: Function used for normalization.
depth: Depth of the desired ResNet.
num_frames: Number of frames (used in TPU mode).
channel_shift_fraction: Fraction of channels that are temporally shifted,
if `channel_shift_fraction` is 0, a regular ResNet is returned.
width_mult: Multiplier for channel width.
name: The name of the module.
Raises:
ValueError: If `channel_shift_fraction` or `depth` has invalid value.
"""
super().__init__(name=name)
if not 0. <= channel_shift_fraction <= 1.0:
raise ValueError(
f'channel_shift_fraction ({channel_shift_fraction})'
' has to be in [0, 1].')
self._num_frames = num_frames
self._channels = (256, 512, 1024, 2048)
self._strides = (1, 2, 2, 2)
num_blocks = {
50: (3, 4, 6, 3),
101: (3, 4, 23, 3),
152: (3, 8, 36, 3),
200: (3, 24, 36, 3),
}
if depth not in num_blocks:
raise ValueError(
f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
self._num_blocks = num_blocks[depth]
self._width_mult = width_mult
self._channel_shift_fraction = channel_shift_fraction
self._normalize_fn = normalize_fn
def __call__(
self,
inputs: types.TensorLike,
is_training: bool = True,
final_endpoint: str = 'Embeddings') -> jnp.ndarray:
"""Connects the TSM ResNetV2 module into the graph.
Args:
inputs: A 5-D float array of shape `[B, T, H, W, C]` (GPU mode) or a 4-D
float array of shape `[T * B, H, W, C]` (TPU mode); see
`tsm_utils.prepare_inputs`.
is_training: Whether to use training mode.
final_endpoint: Up to which endpoint to run / return.
Returns:
Network output at location `final_endpoint`. A float array which shape
depends on `final_endpoint`.
Raises:
ValueError: If `final_endpoint` is not recognized.
"""
# Prepare inputs for TSM.
inputs, tsm_mode, num_frames = tsmu.prepare_inputs(inputs)
num_frames = num_frames or self._num_frames
self._final_endpoint = final_endpoint
if self._final_endpoint not in self.VALID_ENDPOINTS:
raise ValueError(f'Unknown final endpoint {self._final_endpoint}')
# Stem convolution.
end_point = 'tsm_resnet_stem'
net = hk.Conv2D(
output_channels=64 * self._width_mult,
kernel_shape=7,
stride=2,
with_bias=False,
name=end_point,
padding='SAME')(
inputs)
net = hk.MaxPool(
window_shape=(1, 3, 3, 1),
strides=(1, 2, 2, 1),
padding='SAME')(
net)
if self._final_endpoint == end_point:
return net
# Residual block.
for unit_id, (channels, num_blocks, stride) in enumerate(
zip(self._channels, self._num_blocks, self._strides)):
end_point = f'tsm_resnet_unit_{unit_id}'
net = TSMResNetUnit(
output_channels=channels * self._width_mult,
num_blocks=num_blocks,
stride=stride,
normalize_fn=self._normalize_fn,
channel_shift_fraction=self._channel_shift_fraction,
num_frames=num_frames,
tsm_mode=tsm_mode,
name=end_point)(
net, is_training=is_training)
if self._final_endpoint == end_point:
return net
if self._normalize_fn is not None:
net = self._normalize_fn(net, is_training=is_training)
net = jax.nn.relu(net)
end_point = 'last_conv'
if self._final_endpoint == end_point:
return net
net = jnp.mean(net, axis=(1, 2))
# Prepare embedding outputs for TSM (temporal average of features).
net = tsmu.prepare_outputs(net, tsm_mode, num_frames)
assert self._final_endpoint == 'Embeddings'
return net
+65
@@ -0,0 +1,65 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TSM ResNet model."""
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
from mmv.models import tsm_resnet
class TSMResNetTest(parameterized.TestCase):
@parameterized.parameters(
('tsm_resnet_stem', (2 * 32, 56, 56, 64)),
('tsm_resnet_unit_0', (2 * 32, 56, 56, 256)),
('tsm_resnet_unit_1', (2 * 32, 28, 28, 512)),
('tsm_resnet_unit_2', (2 * 32, 14, 14, 1024)),
('tsm_resnet_unit_3', (2 * 32, 7, 7, 2048)),
('last_conv', (2 * 32, 7, 7, 2048)),
('Embeddings', (2, 2048)),
)
def test_output_dimension(self, final_endpoint, expected_shape):
input_shape = (2, 32, 224, 224, 3)
def f():
data = jnp.zeros(input_shape)
net = tsm_resnet.TSMResNetV2()
return net(data, final_endpoint=final_endpoint)
init_fn, apply_fn = hk.transform(f)
out = apply_fn(init_fn(jax.random.PRNGKey(42)), None)
self.assertEqual(out.shape, expected_shape)
def test_tpu_mode(self):
input_shape = (32 * 2, 224, 224, 3)
def f():
data = jnp.zeros(input_shape)
net = tsm_resnet.TSMResNetV2(num_frames=32)
return net(data, final_endpoint='Embeddings')
init_fn, apply_fn = hk.transform(f)
out = apply_fn(init_fn(jax.random.PRNGKey(42)), None)
self.assertEqual(out.shape, (2, 2048))
if __name__ == '__main__':
absltest.main()
+177
@@ -0,0 +1,177 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils functions for TSM."""
from typing import Tuple
import jax
import jax.numpy as jnp
from mmv.models import types
def prepare_inputs(
inputs: types.TensorLike) -> Tuple[jnp.ndarray, str, int]:
"""Deduces input mode for TSM."""
# Deduce if we run on TPU based on input shape.
if len(inputs.shape) == 5:
# Input is given in the standard [B, T, H, W, 3] format.
tsm_mode = 'gpu'
num_frames = inputs.shape[1]
inputs = jnp.reshape(inputs, [-1] + list(inputs.shape[2:]))
else:
# Input is given in the [T * B, H, W, 3] format.
tsm_mode = 'tpu'
num_frames = None
return inputs, tsm_mode, num_frames
def prepare_outputs(outputs: types.TensorLike,
tsm_mode: str,
num_frames: int) -> jnp.ndarray:
"""Processes output of TSM by averaging representations over time axis."""
n_channels = outputs.shape[-1]
if tsm_mode == 'tpu':
outputs = jnp.reshape(outputs, [num_frames, -1, n_channels])
outputs = jnp.mean(outputs, axis=0)
elif tsm_mode == 'gpu':
outputs = jnp.reshape(outputs, [-1, num_frames, n_channels])
outputs = jnp.mean(outputs, axis=1)
else:
raise ValueError(
f'`tsm_mode` should be \'tpu\' or \'gpu\' ({tsm_mode} given)')
return outputs
def apply_temporal_shift(
x: types.TensorLike,
tsm_mode: str,
num_frames: int,
channel_shift_fraction: float = 0.125) -> jnp.ndarray:
"""Performs a temporal shift: https://arxiv.org/abs/1811.08383 with mode."""
if tsm_mode == 'tpu':
outputs = temporal_shift_tpu(x, num_frames, channel_shift_fraction)
elif tsm_mode == 'gpu':
outputs = temporal_shift_gpu(x, num_frames, channel_shift_fraction)
else:
raise ValueError(
f'`tsm_mode` should be \'tpu\' or \'gpu\' ({tsm_mode} given)')
return outputs
def temporal_shift_gpu(
x: types.TensorLike,
num_frames: int,
channel_shift_fraction: float = 0.125) -> jnp.ndarray:
"""Performs a temporal shift: https://arxiv.org/abs/1811.08383."""
# B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels
# Input is (B * T, H, W, C)
orig_shp = tuple(x.shape)
reshaped_x = jnp.reshape(x, (-1, num_frames) + orig_shp[1:])
n_channels = orig_shp[-1]
n_shift = int(n_channels * channel_shift_fraction)
new_shp = tuple(reshaped_x.shape)
# shifted_backward = reshaped_x[:, 1:, :, :, -n_shift:]
shifted_backward = jax.lax.slice(
reshaped_x, (0, 1, 0, 0, new_shp[4] - n_shift),
(new_shp[0], new_shp[1], new_shp[2], new_shp[3], new_shp[4]))
shifted_backward_padding = ((0, 0), (0, 1), (0, 0), (0, 0), (0, 0))
shifted_backward = jnp.pad(shifted_backward, shifted_backward_padding)
# shifted_forward = reshaped_x[:, :-1, :, :, :n_shift]
shifted_forward = jax.lax.slice(
reshaped_x, (0, 0, 0, 0, 0),
(new_shp[0], new_shp[1] - 1, new_shp[2], new_shp[3], n_shift))
shifted_forward_padding = ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0))
shifted_forward = jnp.pad(shifted_forward, shifted_forward_padding)
no_shift = reshaped_x[:, :, :, :, n_shift:-n_shift]
shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward],
axis=4)
return jnp.reshape(shifted_x, (-1,) + orig_shp[1:])
def temporal_shift_tpu(
x: types.TensorLike,
num_frames: int,
channel_shift_fraction: float = 0.125) -> jnp.ndarray:
"""Performs a temporal shift: https://arxiv.org/abs/1811.08383.
TPU optimized version of TSM. A reshape is avoided by laying out the images
as [T * B, :] so that frames corresponding to the same time step are
contiguous in memory. Thanks to cr/288510308, which allows a pad->slice to
be fused into convolution, we reformulate the slice-then-pad as a pad
followed by a slice. Finally, to avoid a concatenate that would prevent
some fusion from happening, we simply sum masked versions of the features.
Args:
x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped
from a time-major version of the input).
num_frames: number of frames T per video.
channel_shift_fraction: fraction of the channels to shift forward and
backward.
Returns:
The temporal shifted version of x.
"""
# B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels
# Input is (T * B, H, W, C)
original_shape = list(x.shape)
batch_size = int(original_shape[0] / num_frames)
n_channels = int(original_shape[-1])
n_shift = int(n_channels * channel_shift_fraction)
# Cast to bfloat16.
x = x.astype(jnp.bfloat16)
# For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1.
# Shift backward, we first pad by zeros [x1, x2, x3, 0, 0].
orig_shp = list(x.shape)
shifted_backward_padding = ((0, batch_size, 0), (0, 0, 0), (0, 0, 0),
(0, n_channels - n_shift, 0))
x_backward_padding = jax.lax.pad(
x,
padding_value=jnp.bfloat16(0.),
padding_config=shifted_backward_padding)
# The following shift gets to [x3^+1, 0, 0] (where +1 means from the future).
shifted_backward = jax.lax.slice(x_backward_padding,
(batch_size, 0, 0, n_channels - n_shift),
(orig_shp[0] + batch_size, orig_shp[1],
orig_shp[2], 2 * n_channels - n_shift))
# Shift forward, we first pad by zeros [0, 0, x1, x2, x3].
shifted_forward_padding = ((batch_size, 0, 0), (0, 0, 0), (0, 0, 0),
(n_channels - n_shift, 0, 0))
x_forward_padding = jax.lax.pad(
x,
padding_value=jnp.bfloat16(0.),
padding_config=shifted_forward_padding)
# The following shift gets to [0, 0, x1^-1] (where -1 means from the past).
shifted_forward = jax.lax.slice(
x_forward_padding, (0, 0, 0, 0),
(orig_shp[0], orig_shp[1], orig_shp[2], n_channels))
# No shift is in the middle, this gets [0, x2, 0].
mask_noshift = (jnp.reshape((jnp.arange(n_channels) >= n_shift) &
(jnp.arange(n_channels) < n_channels - n_shift),
(1, 1, 1, -1))).astype(jnp.bfloat16)
no_shift = mask_noshift * x
# By summing everything together, we end up with [x3^+1, x2, x1^-1].
# Note: channels have been reordered but that doesn't matter for the model.
shifted_x = shifted_backward + shifted_forward + no_shift
return shifted_x.astype(jnp.float32)
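A small hand-made demonstration of the GPU-mode shift (shapes chosen for
readability, not from the tests): with 4 frames, 8 channels and the default
`channel_shift_fraction=0.125`, exactly one channel is shifted backward in
time, one forward, and the remaining six stay in place.
```python
import jax.numpy as jnp
from mmv.models import tsm_utils

# [B * T, H, W, C] with B=1, T=4, C=8; frame t, channel c holds value 8*t + c.
x = jnp.arange(4 * 8, dtype=jnp.float32).reshape(4, 1, 1, 8)
out = tsm_utils.apply_temporal_shift(x, tsm_mode='gpu', num_frames=4)
print(out.shape)        # (4, 1, 1, 8): the shape is preserved.
# The output concatenates [backward-shifted, unshifted middle, forward-shifted]
# channels, so channel 0 of frame 0 holds frame 1's last channel:
print(out[0, 0, 0, 0])  # 15.0
```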
+60
@@ -0,0 +1,60 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tsm_utils."""
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
import numpy as np
from mmv.models import tsm_utils
class TsmUtilsTest(parameterized.TestCase):
@parameterized.parameters(
((2, 32, 224, 224, 3), 'gpu', (2 * 32, 224, 224, 3), 32),
((32, 224, 224, 3), 'tpu', (32, 224, 224, 3), None),
)
def test_prepare_inputs(self, input_shape, expected_mode, expected_shape,
expected_num_frames):
data = jnp.zeros(input_shape)
out, mode, num_frames = tsm_utils.prepare_inputs(data)
self.assertEqual(out.shape, expected_shape)
self.assertEqual(mode, expected_mode)
self.assertEqual(num_frames, expected_num_frames)
def test_prepare_outputs(self):
data = jnp.concatenate([jnp.zeros(4), jnp.ones(4)]).reshape(4, 2)
out_gpu = tsm_utils.prepare_outputs(data, 'gpu', 2)
out_tpu = tsm_utils.prepare_outputs(data, 'tpu', 2)
expected_gpu = np.concatenate([np.zeros(2), np.ones(2)]).reshape(2, 2)
expected_tpu = 0.5 * jnp.ones((2, 2))
np.testing.assert_allclose(out_gpu, expected_gpu)
np.testing.assert_allclose(out_tpu, expected_tpu)
def test_apply_tsm(self):
shape = (32, 224, 224, 16)
data = jnp.zeros(shape)
out_gpu = tsm_utils.apply_temporal_shift(data, 'gpu', 16)
out_tpu = tsm_utils.apply_temporal_shift(data, 'tpu', 16)
self.assertEqual(out_gpu.shape, shape)
self.assertEqual(out_tpu.shape, shape)
if __name__ == '__main__':
absltest.main()
+36
@@ -0,0 +1,36 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Type Aliases."""
from typing import Callable, Tuple, Union
import jax.numpy as jnp
import numpy as np
import optax
TensorLike = Union[np.ndarray, jnp.DeviceArray]
ActivationFn = Callable[[TensorLike], TensorLike]
GatingFn = Callable[[TensorLike], TensorLike]
NetworkFn = Callable[[TensorLike], TensorLike]
# Callable doesn't allow kwargs to be used, and we often want to
# pass in is_training=..., so ignore the arguments for the sake of pytype.
NormalizeFn = Callable[..., TensorLike]
OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState]
+9
@@ -0,0 +1,9 @@
dm-haiku
dm-tree
jax
jaxlib
numpy>=1.16
optax
scikit-learn
tensorflow
tensorflow_datasets
+29
@@ -0,0 +1,29 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint restoring utilities."""
from absl import logging
import dill
def load_checkpoint(checkpoint_path):
"""Loads a dill-pickled checkpoint; returns None if the file is missing."""
try:
with open(checkpoint_path, 'rb') as checkpoint_file:
checkpoint_data = dill.load(checkpoint_file)
logging.info('Loading checkpoint from %s', checkpoint_path)
return checkpoint_data
except FileNotFoundError:
return None
+70
@@ -0,0 +1,70 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ucf101 with custom decoding params."""
import tensorflow as tf
import tensorflow_datasets as tfds
# Utility functions.
tf.compat.v1.enable_eager_execution()
_CITATION = """\
@article{DBLP:journals/corr/abs-1212-0402,
author = {Khurram Soomro and
Amir Roshan Zamir and
Mubarak Shah},
title = {{UCF101:} {A} Dataset of 101 Human Actions Classes From Videos in
The Wild},
journal = {CoRR},
volume = {abs/1212.0402},
year = {2012},
url = {http://arxiv.org/abs/1212.0402},
archivePrefix = {arXiv},
eprint = {1212.0402},
timestamp = {Mon, 13 Aug 2018 16:47:45 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1212-0402},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
_LABELS_FNAME = 'video/ucf101_labels.txt'
class ModUcf101(tfds.video.Ucf101):
"""Ucf101 action recognition dataset with better quality.
"""
def _info(self):
ffmpeg_extra_args = ('-qscale:v', '2', '-r', '25', '-t', '00:00:20')
video_shape = (
None, self.builder_config.height, self.builder_config.width, 3)
labels_names_file = tfds.core.tfds_path(_LABELS_FNAME)
features = tfds.features.FeaturesDict({
'video': tfds.features.Video(video_shape,
ffmpeg_extra_args=ffmpeg_extra_args,
encoding_format='jpeg'),
'label': tfds.features.ClassLabel(names_file=labels_names_file),
})
return tfds.core.DatasetInfo(
builder=self,
description='A 101-label video classification dataset.',
features=features,
homepage='https://www.crcv.ucf.edu/data-sets/ucf101/',
citation=_CITATION,
)
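For reference, a hedged sketch of preparing the dataset with this builder
(mirroring the calls in `eval_ucf101.py`; the data folder is illustrative):
```python
import tensorflow_datasets as tfds
from mmv.utils import ucf101_dataset

dset_config = tfds.video.ucf101.Ucf101.BUILDER_CONFIGS[1]  # --ucf101_split=1.
builder = ucf101_dataset.ModUcf101(data_dir='/tmp/ucf101', config=dset_config)
# Downloads and re-encodes the videos (a no-op if the tfrecords already exist).
builder.download_and_prepare(
    download_config=tfds.download.DownloadConfig(verify_ssl=False))
train_ds = builder.as_dataset(split='train', shuffle_files=False)
```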