mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Initial release of "mmv".
PiperOrigin-RevId: 346305536
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/

## Projects

* [Self-Supervised MultiModal Versatile Networks](mmv), NeurIPS 2020
* [ODE-GAN: Training GANs by Solving Ordinary Differential Equations](ode_gan), NeurIPS 2020
* [Algorithms for Causal Reasoning in Probability Trees](causal_reasoning)
* [Gated Linear Networks](gated_linear_networks), NeurIPS 2020
@@ -0,0 +1,83 @@
# Self-supervised Multimodal Versatile Networks

This is the code for the models in MMV - https://arxiv.org/abs/2006.16228.

<img src="./imgs/mmv_fig.png" width="50%">

We also make available the code for linear evaluation of a pre-trained model
on UCF101, as well as the JAX checkpoints for our best models.

We use different video-compression parameters for UCF101 than the ones used in
`tensorflow_datasets`, so we provide the code to download and preprocess the
dataset. The `eval_ucf101.py` script reproduces the results we report in
Table 2 of the paper, using the checkpoints provided below.

Visual Backbone | Training Dataset | Results on Linear UCF101
------- | -------- | --------
S3D-G | AudioSet + HowTo | 89.6
Resnet TSM-50 | AudioSet + HowTo | 91.5
Resnet TSM-50 (x2) | AudioSet + HowTo | 91.8

## Setup

To set up a Python virtual environment with the required dependencies, run:

```shell
python3 -m venv mmv_env
source mmv_env/bin/activate
pip install --upgrade pip setuptools wheel
pip install -r mmv/requirements.txt --use-feature=2020-resolver
```

### Linear evaluation

The linear evaluation on UCF101 can be run using:

```shell
python -m mmv.eval_ucf101 \
  --checkpoint_path=</path/to/the/checkpointing/folder> \
  --dataset_folder=</path/to/dataset/folder>
```

## Checkpoints

We provide three checkpoints containing the best pre-trained weights for each
of the visual backbones we use in the paper, i.e. S3D-G, Resnet-50 TSM, and
Resnet-50 TSM x2.

- [S3D-G](https://storage.googleapis.com/deepmind-research-mmv/mmv_s3d.pkl)
- [Resnet-50 TSM](https://storage.googleapis.com/deepmind-research-mmv/mmv_tsm_resnet_x1.pkl)
- [Resnet-50 TSMx2](https://storage.googleapis.com/deepmind-research-mmv/mmv_tsm_resnet_x2.pkl)

## References

### Citing our work

If you use this code for your research, please consider citing our paper:

```bibtex
@inproceedings{alayrac2020self,
  title={{S}elf-{S}upervised {M}ulti{M}odal {V}ersatile {N}etworks},
  author={Alayrac, Jean-Baptiste and Recasens, Adri{\`a} and Schneider, Rosalia and Arandjelovi{\'c}, Relja and Ramapuram, Jason and De Fauw, Jeffrey and Smaira, Lucas and Dieleman, Sander and Zisserman, Andrew},
  booktitle={NeurIPS},
  year={2020}
}
```

### Models in TF

You may also be interested in our models released on TF-Hub:

- [S3D-G](https://tfhub.dev/deepmind/mmv/s3d/1)
- [Resnet-50 TSM](https://tfhub.dev/deepmind/mmv/tsm-resnet50/1)
- [Resnet-50 TSMx2](https://tfhub.dev/deepmind/mmv/tsm-resnet50x2/1)

## License

While the code is licensed under the Apache 2.0 License, the checkpoint weights
are made available for non-commercial use only, under the terms of the
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
license. Details are available at
https://creativecommons.org/licenses/by-nc/4.0/legalcode.

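As a quick sketch of how the checkpoint files can be inspected: assuming they are plain pickle files holding a dict with `params` and `state` entries (the keys read by `mmv/eval_ucf101.py`), a hypothetical loader could look like this. The function name `load_mmv_checkpoint` is ours, not part of the released API.

```python
import pickle


def load_mmv_checkpoint(path):
  """Loads a pickled MMV checkpoint and returns its params/state dicts.

  Assumes the checkpoint is a plain pickle of a dict with 'params' and
  'state' keys, as consumed by mmv/eval_ucf101.py.
  """
  with open(path, 'rb') as f:
    ckpt = pickle.load(f)
  return ckpt['params'], ckpt['state']
```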
@@ -0,0 +1,85 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration parameters for MMV."""


def get_model_config(ckpt_path):
  """Returns the model configuration to be used with each checkpoint."""

  config = {
      'audio_backbone': 'resnet50',
      'audio_model_kwargs': {
          'bn_config': {
              'create_offset': True,
              'create_scale': True,
              'decay_rate': 0.9,
              'eps': 1.0e-5
          }
      },
      'bn_config_proj': {
          'create_offset': True,
          'create_scale': True,
          'decay_rate': 0.9,
          'eps': 1.0e-5
      },
      'config_audio_text': {
          'embedding_dim': 512,
          'toaud_bn_after_proj': False,
          'toaud_head_mode': 'linear',
          'totxt_bn_after_proj': False,
          'totxt_head_mode': 'linear'
      },
      'config_video_audio': {
          'embedding_dim': 512,
          'toaud_bn_after_proj': True,
          'toaud_head_mode': 'mlp@512',
          'tovid_bn_after_proj': False,
          'tovid_head_mode': 'linear'
      },
      'config_video_text': {
          'embedding_dim': 256,
          'totxt_bn_after_proj': True,
          'totxt_head_mode': 'linear',
          'tovid_bn_after_proj': False,
          'tovid_head_mode': 'linear'
      },
      'mm_embedding_graph': 'fac_relu',
      'name': 'text_audio_video',
      'sentence_dim': 2048,
      'use_xreplica_bn': True,
      'vision_model_kwargs': {
          'bn_config': {
              'create_offset': True,
              'create_scale': True,
              'decay_rate': 0.9,
              'eps': 1.0e-5
          },
          'n_frames': 32,
          'width_mult': 1,
      },
  }

  if 's3d' in ckpt_path:
    config['visual_backbone'] = 's3d'

  if 'tsm_resnet_x1' in ckpt_path:
    config['visual_backbone'] = 'resnet50tsm'

  if 'tsm_resnet_x2' in ckpt_path:
    config['visual_backbone'] = 'resnet50tsm'
    config['vision_model_kwargs']['width_mult'] = 2

  return config

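The checkpoint filename is what selects the visual backbone in `get_model_config`. A minimal self-contained sketch of that dispatch (the helper `pick_backbone` is ours, used only to illustrate the substring matching):

```python
def pick_backbone(ckpt_path):
  """Mirrors the filename-based backbone selection in get_model_config."""
  if 's3d' in ckpt_path:
    return 's3d', 1
  if 'tsm_resnet_x1' in ckpt_path:
    return 'resnet50tsm', 1
  if 'tsm_resnet_x2' in ckpt_path:
    return 'resnet50tsm', 2  # The x2 model doubles the width multiplier.
  raise ValueError(f'Unrecognized checkpoint: {ckpt_path}')


print(pick_backbone('mmv_tsm_resnet_x2.pkl'))  # ('resnet50tsm', 2)
```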
@@ -0,0 +1,465 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""UCF101 linear evaluation."""

import functools
from typing import Any, Dict, Optional

from absl import app
from absl import flags
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import sklearn
from sklearn import preprocessing
import sklearn.linear_model
import sklearn.svm
import tensorflow as tf
import tensorflow_datasets as tfds

from mmv import config
from mmv.models import mm_embeddings
from mmv.utils import checkpoint
from mmv.utils import ucf101_dataset

flags.DEFINE_string('checkpoint_path', '~/tmp/mmv_s3d.pkl',
                    'The directory to load pre-trained weights from.')
flags.DEFINE_string('dataset_folder', '/tmp/ucf101',
                    'The directory with the ucf101 dataset.')

flags.DEFINE_integer('eval_batch_size', 1,
                     'The batch size for evaluation.')
flags.DEFINE_integer('train_batch_size', 16,
                     'The batch size for training.')
flags.DEFINE_integer('num_train_epochs', 10,
                     'How many epochs to collect features during training.')
flags.DEFINE_integer('num_test_windows', 10,
                     'How many windows to average over during test.')
flags.DEFINE_integer('min_resize', 224,
                     'Min value to resize images to during preprocessing.')
flags.DEFINE_integer('crop_size', 224,
                     'Value to resize images to during preprocessing.')
flags.DEFINE_integer('num_frames', 32,
                     'Number of video frames.')
flags.DEFINE_integer('stride', 2,
                     'Stride for video frames.')
flags.DEFINE_integer('ucf101_split', 1,
                     'Which split of ucf101 to use.')


FLAGS = flags.FLAGS


def get_sampling_offset(sequence: tf.Tensor,
                        num_steps: Optional[int],
                        is_training: bool,
                        stride: int = 1,
                        seed: Optional[int] = None) -> tf.Tensor:
  """Calculates the initial offset for a sequence where all steps will fit.

  Args:
    sequence: Any tensor where the first dimension is timesteps.
    num_steps: The number of timesteps we will output. If None,
      deterministically start at the first frame.
    is_training: A boolean indicating whether the graph is for training.
      If False, the starting frame is always the first frame.
    stride: Distance to sample between timesteps.
    seed: A deterministic seed to use when sampling.

  Returns:
    The first index to begin sampling from. A best effort is made to provide a
    starting index such that all requested steps fit within the sequence (i.e.
    offset + 1 + (num_steps - 1) * stride < len(sequence)). If this is not
    satisfied, the starting index is chosen randomly from the full sequence.
  """
  if num_steps is None or not is_training:
    return tf.constant(0)
  sequence_length = tf.shape(sequence)[0]
  max_offset = tf.cond(
      tf.greater(sequence_length, (num_steps - 1) * stride),
      lambda: sequence_length - (num_steps - 1) * stride,
      lambda: sequence_length)
  offset = tf.random.uniform(
      (),
      maxval=tf.cast(max_offset, tf.int32),
      dtype=tf.int32,
      seed=seed)
  return offset


def sample_or_pad_sequence_indices(sequence: tf.Tensor,
                                   num_steps: Optional[int],
                                   is_training: bool,
                                   repeat_sequence: bool = True,
                                   stride: int = 1,
                                   offset: Optional[int] = None) -> tf.Tensor:
  """Returns indices to take for sampling or padding a sequence to fixed size.

  Samples num_steps from the sequence. If the sequence is shorter than
  num_steps, the sequence loops. If the sequence is longer than num_steps and
  is_training is True, then we seek to a random offset before sampling. If
  offset is provided, we use that deterministic offset.

  This method is appropriate for sampling from a tensor where you want every
  timestep between a start and end time. See sample_stacked_sequence_indices
  for more flexibility.

  Args:
    sequence: Any tensor where the first dimension is timesteps.
    num_steps: How many steps (e.g. frames) to take. If None, all steps from
      start to end are considered and `is_training` has no effect.
    is_training: A boolean indicating whether the graph is for training.
      If False, the starting frame is deterministic.
    repeat_sequence: A boolean indicating whether the sequence will repeat to
      have enough steps for sampling. If False, a runtime error is thrown if
      num_steps * stride is longer than the sequence length.
    stride: Distance to sample between timesteps.
    offset: A deterministic offset to use regardless of the is_training value.

  Returns:
    Indices to gather from the sequence Tensor to get a fixed size sequence.
  """
  sequence_length = tf.shape(sequence)[0]
  sel_idx = tf.range(sequence_length)

  if num_steps:
    if offset is None:
      offset = get_sampling_offset(sequence, num_steps, is_training, stride)

    if repeat_sequence:
      # Repeats the sequence until num_steps are available in total.
      num_repeats = tf.cast(
          tf.math.ceil(
              tf.math.divide(
                  tf.cast(num_steps * stride + offset, tf.float32),
                  tf.cast(sequence_length, tf.float32)
              )), tf.int32)
      sel_idx = tf.tile(sel_idx, [num_repeats])
    steps = tf.range(offset, offset + num_steps * stride, stride)
  else:
    steps = tf.range(0, sequence_length, stride)
  return tf.gather(sel_idx, steps)


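The index arithmetic above can be sketched in plain NumPy. This is a simplified stand-in for the TF ops (fixed offset, always repeating), and `sample_indices` is our illustrative name, not part of the module:

```python
import numpy as np


def sample_indices(length, num_steps, stride=1, offset=0):
  """NumPy sketch of sample_or_pad_sequence_indices with a fixed offset."""
  sel = np.arange(length)
  # Ceil division: tile the index range until num_steps * stride + offset fit.
  reps = -(-(num_steps * stride + offset) // length)
  sel = np.tile(sel, reps)
  steps = np.arange(offset, offset + num_steps * stride, stride)
  return sel[steps]


# A 5-frame sequence sampled 4 times with stride 2 from offset 1 wraps around.
print(sample_indices(5, 4, stride=2, offset=1).tolist())  # [1, 3, 0, 2]
```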
def random_sample_sequence(sequence: tf.Tensor,
                           num_steps: int,
                           stride: int = 1) -> tf.Tensor:
  """Randomly samples a segment of size num_steps from a given sequence."""

  indices = sample_or_pad_sequence_indices(
      sequence=sequence,
      num_steps=num_steps,
      is_training=True,  # Random sample.
      repeat_sequence=True,  # Repeats the sequence if more steps are needed.
      stride=stride,
      offset=None)
  indices.set_shape((num_steps,))
  output = tf.gather(sequence, indices)
  return output


def sample_linspace_sequence(sequence: tf.Tensor,
                             num_windows: int,
                             num_steps: int,
                             stride: int = 1) -> tf.Tensor:
  """Samples num_windows segments from sequence with linearly spaced offsets.

  The samples are concatenated in a single Tensor in order to have the same
  format structure per timestep (e.g. a single frame). If num_steps * stride
  is bigger than the number of timesteps, the sequence is repeated. This
  function can be used during evaluation to extract enough segments to span
  the entire sequence.

  Args:
    sequence: Any tensor where the first dimension is timesteps.
    num_windows: Number of windows retrieved from the sequence.
    num_steps: Number of steps (e.g. frames) to take.
    stride: Distance to sample between timesteps.

  Returns:
    A single Tensor with first dimension num_windows * num_steps. The Tensor
    contains the concatenated list of num_windows tensors whose offsets have
    been linearly spaced from the input.
  """
  sequence_length = tf.shape(sequence)[0]
  max_offset = tf.maximum(0, sequence_length - num_steps * stride)
  offsets = tf.linspace(0.0, tf.cast(max_offset, tf.float32), num_windows)
  offsets = tf.cast(offsets, tf.int32)

  all_indices = []
  for i in range(num_windows):
    all_indices.append(
        sample_or_pad_sequence_indices(
            sequence=sequence,
            num_steps=num_steps,
            is_training=False,
            repeat_sequence=True,  # Repeats the sequence if more is needed.
            stride=stride,
            offset=offsets[i]))

  indices = tf.concat(all_indices, axis=0)
  indices.set_shape((num_windows * num_steps,))
  output = tf.gather(sequence, indices)

  return output


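For intuition, the linearly spaced window offsets computed above look like this in NumPy (toy numbers: 100 frames, 3 windows of 32 frames, stride 1; these values are illustrative, not from the paper):

```python
import numpy as np

# Offsets used by sample_linspace_sequence: window starts linearly spaced
# from 0 to the last offset at which a full window still fits.
num_frames, num_windows, num_steps, stride = 100, 3, 32, 1
max_offset = max(0, num_frames - num_steps * stride)
offsets = np.linspace(0.0, max_offset, num_windows).astype(np.int32)
print(offsets.tolist())  # [0, 34, 68]
```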
def resize_smallest(frames: tf.Tensor, min_resize: int) -> tf.Tensor:
  """Resizes frames so that min(height, width) is equal to min_resize.

  This function will not do anything if min(height, width) is already equal
  to min_resize, which saves compute time.

  Args:
    frames: A Tensor of dimension [timesteps, input_h, input_w, channels].
    min_resize: Minimum size of the final image dimensions.

  Returns:
    A Tensor of shape [timesteps, output_h, output_w, channels] of type
    frames.dtype where min(output_h, output_w) = min_resize.
  """
  shape = tf.shape(frames)
  input_h = shape[1]
  input_w = shape[2]

  output_h = tf.maximum(min_resize, (input_h * min_resize) // input_w)
  output_w = tf.maximum(min_resize, (input_w * min_resize) // input_h)

  def resize_fn():
    frames_resized = tf.image.resize(frames, (output_h, output_w))
    return tf.cast(frames_resized, frames.dtype)

  should_resize = tf.math.logical_or(tf.not_equal(input_w, output_w),
                                     tf.not_equal(input_h, output_h))
  frames = tf.cond(should_resize, resize_fn, lambda: frames)

  return frames


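The output-size arithmetic in `resize_smallest` can be checked with plain integers: the smaller of (h, w) becomes `min_resize` and the other side scales proportionally. The helper name `smallest_side_dims` is ours, for illustration only:

```python
def smallest_side_dims(h, w, min_resize):
  """Integer sketch of the output dimensions computed by resize_smallest."""
  out_h = max(min_resize, (h * min_resize) // w)
  out_w = max(min_resize, (w * min_resize) // h)
  return out_h, out_w


# A 480x640 frame: the short side becomes 224 and the long side scales.
print(smallest_side_dims(480, 640, 224))  # (224, 298)
```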
def process_samples(features_dict, num_frames=32, stride=1, is_training=True,
                    min_resize=224, crop_size=224, num_windows=1):
  """Processes video frames."""

  video = features_dict['video']

  if is_training:
    assert num_windows == 1
    video = random_sample_sequence(video, num_frames, stride)
    is_flipped = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
    video = tf.cond(tf.equal(is_flipped, 1),
                    true_fn=lambda: tf.image.flip_left_right(video),
                    false_fn=lambda: video)
  else:
    video = sample_linspace_sequence(video, num_windows, num_frames, stride)

  # Resize the smallest side.
  video = resize_smallest(video, min_resize)

  if is_training:
    # Random crop.
    video = tf.image.random_crop(video, [num_frames, crop_size, crop_size, 3])
  else:
    # Central crop.
    video = tf.image.resize_with_crop_or_pad(video, crop_size, crop_size)

  video = tf.cast(video, tf.float32)
  video /= 255.0  # Scale to [0, 1].

  features_dict['video'] = video
  return features_dict


def space_to_depth_batch(features_dict):
  images = features_dict['video']
  _, l, h, w, c = images.shape
  images = tf.reshape(images, [-1, l // 2, 2, h // 2, 2, w // 2, 2, c])
  images = tf.transpose(images, [0, 1, 3, 5, 2, 4, 6, 7])
  images = tf.reshape(images, [-1, l // 2, h // 2, w // 2, 8 * c])
  features_dict['video'] = images
  return features_dict


def reshape_windows(features_dict, num_frames):
  x = features_dict['video']
  x = tf.reshape(x, (-1, num_frames, x.shape[2], x.shape[3], x.shape[4]))
  features_dict['video'] = x
  return features_dict


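The reshape/transpose in `space_to_depth_batch` folds 2x2x2 spatio-temporal blocks into the channel dimension; a quick NumPy shape check of the same sequence of operations:

```python
import numpy as np

# A [B, L, H, W, C] video becomes [B, L/2, H/2, W/2, 8C]: each 2 (time) x
# 2 (height) x 2 (width) block of pixels is moved into the channel axis.
x = np.zeros((1, 32, 224, 224, 3))
b, l, h, w, c = x.shape
y = x.reshape(b, l // 2, 2, h // 2, 2, w // 2, 2, c)
y = y.transpose(0, 1, 3, 5, 2, 4, 6, 7)
y = y.reshape(b, l // 2, h // 2, w // 2, 8 * c)
print(y.shape)  # (1, 16, 112, 112, 24)
```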
def compute_accuracy_metrics(pred, gt, prefix=''):
  order_pred = np.argsort(pred, axis=1)
  assert len(gt.shape) == len(order_pred.shape) == 2
  top1_pred = order_pred[:, -1:]
  top5_pred = order_pred[:, -5:]
  top1_acc = np.mean(top1_pred == gt)
  top5_acc = np.mean(np.max(top5_pred == gt, 1))
  return {prefix + 'top1': top1_acc,
          prefix + 'top5': top5_acc}


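A small self-contained check of the argsort convention used above (toy scores for 2 examples over 6 classes; ground truth is a column vector of class indices):

```python
import numpy as np

# Toy check of the top-1/top-5 logic in compute_accuracy_metrics.
pred = np.array([[0.1, 0.9, 0.0, 0.2, 0.3, 0.4],
                 [0.5, 0.1, 0.2, 0.3, 0.4, 0.0]])
gt = np.array([[1], [4]])  # True class per example.

order = np.argsort(pred, axis=1)
top1 = order[:, -1:]   # Highest-scoring class per example.
top5 = order[:, -5:]   # Five highest-scoring classes per example.
print(np.mean(top1 == gt))             # Example 0 correct, example 1 not: 0.5
print(np.mean(np.max(top5 == gt, 1)))  # Both true classes in the top 5: 1.0
```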
def forward_fn(images: jnp.ndarray,
               audio_spectrogram: jnp.ndarray,
               word_ids: jnp.ndarray,
               is_training: bool,
               model_config: Dict[str, Any]):
  """Forward pass of the model."""

  # This should contain the pre-trained weights. We set it to zero because it
  # will be loaded from the checkpoint.
  language_model_vocab_size = 65536
  word_embedding_dim = 300
  dummy_embedding_matrix = jnp.zeros(shape=(language_model_vocab_size,
                                            word_embedding_dim))

  module = mm_embeddings.AudioTextVideoEmbedding(
      **model_config,
      word_embedding_matrix=dummy_embedding_matrix)
  return module(images=images,
                audio_spectrogram=audio_spectrogram,
                word_ids=word_ids,
                is_training=is_training)['vid_repr']


def main(argv):
  del argv

  sklearn_reg = 0.001
  model_config = config.get_model_config(FLAGS.checkpoint_path)

  forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
  forward_apply = jax.jit(functools.partial(forward.apply,
                                            is_training=False,
                                            model_config=model_config))

  # Get the UCF101 config.
  dset_config = tfds.video.ucf101.Ucf101.BUILDER_CONFIGS[FLAGS.ucf101_split]

  builder = ucf101_dataset.ModUcf101(
      data_dir=FLAGS.dataset_folder,
      config=dset_config)
  # Create the tfrecord files (no-op if they already exist).
  dl_config = tfds.download.DownloadConfig(verify_ssl=False)
  builder.download_and_prepare(download_config=dl_config)

  # Generate the training dataset.
  train_ds = builder.as_dataset(split='train', shuffle_files=False)
  train_ds = train_ds.map(lambda x: process_samples(  # pylint: disable=g-long-lambda
      x, num_frames=FLAGS.num_frames, stride=FLAGS.stride, is_training=True,
      min_resize=FLAGS.min_resize, crop_size=FLAGS.crop_size))
  train_ds = train_ds.batch(batch_size=FLAGS.train_batch_size)
  if model_config['visual_backbone'] == 's3d':
    train_ds = train_ds.map(space_to_depth_batch)
  train_ds = train_ds.repeat(FLAGS.num_train_epochs)

  # Generate the test dataset.
  test_ds = builder.as_dataset(split='test', shuffle_files=False)
  test_ds = test_ds.map(lambda x: process_samples(  # pylint: disable=g-long-lambda
      x, num_frames=FLAGS.num_frames, stride=FLAGS.stride, is_training=False,
      min_resize=FLAGS.min_resize, crop_size=FLAGS.crop_size,
      num_windows=FLAGS.num_test_windows))
  test_ds = test_ds.batch(batch_size=FLAGS.eval_batch_size)
  test_ds = test_ds.map(lambda x: reshape_windows(  # pylint: disable=g-long-lambda
      x, num_frames=FLAGS.num_frames))

  if model_config['visual_backbone'] == 's3d':
    test_ds = test_ds.map(space_to_depth_batch)
  test_ds = test_ds.repeat(1)

  pretrained_weights = checkpoint.load_checkpoint(FLAGS.checkpoint_path)
  params = pretrained_weights['params']
  state = pretrained_weights['state']

  # Collect training samples.
  audio_frames = 96
  mel_filters = 40
  num_tokens = 16
  dummy_audio = jnp.zeros(
      shape=(FLAGS.train_batch_size, audio_frames, mel_filters, 1))
  dummy_word_ids = jnp.zeros(
      shape=(FLAGS.train_batch_size, num_tokens), dtype=jnp.int32)

  train_features = []
  train_labels = []
  print('Computing features on train')
  training_examples = iter(tfds.as_numpy(train_ds))
  for train_ex in training_examples:
    vid_representation, _ = forward_apply(params=params,
                                          state=state,
                                          images=train_ex['video'],
                                          audio_spectrogram=dummy_audio,
                                          word_ids=dummy_word_ids)
    train_labels.append(train_ex['label'])
    train_features.append(vid_representation)
    if len(train_labels) % 50 == 0:
      print(f'Processed {len(train_labels)} examples.')

  train_labels = np.concatenate(train_labels, axis=0)
  train_features = np.concatenate(train_features, axis=0)
  print(f'Finished collecting train features of shape {train_features.shape}')

  # Collect test samples.
  dummy_audio = jnp.zeros(
      shape=(FLAGS.eval_batch_size, audio_frames, mel_filters, 1))
  dummy_word_ids = jnp.zeros(
      shape=(FLAGS.eval_batch_size, num_tokens), dtype=jnp.int32)

  test_features = []
  test_labels = []
  print('Computing features on test')
  test_examples = iter(tfds.as_numpy(test_ds))
  for test_ex in test_examples:
    vid_representation_test, _ = forward_apply(params=params,
                                               state=state,
                                               images=test_ex['video'],
                                               audio_spectrogram=dummy_audio,
                                               word_ids=dummy_word_ids)
    test_labels.append(test_ex['label'])
    test_features.append(vid_representation_test)
    if len(test_labels) % 50 == 0:
      print(f'Processed {len(test_labels)} examples.')

  test_features = np.concatenate(test_features, axis=0)
  test_labels = np.concatenate(test_labels, axis=0)
  print(f'Finished collecting test features of shape {test_features.shape}')

  # Train the classifier.
  print('Training linear classifier!')
  classifier = sklearn.svm.LinearSVC(C=sklearn_reg)
  scaler = preprocessing.StandardScaler().fit(train_features)
  train_features = scaler.transform(train_features)
  classifier.fit(train_features, train_labels.ravel())
  print('Training done!')

  # Evaluation.
  test_features = scaler.transform(test_features)
  print('Running inference on train')
  pred_train = classifier.decision_function(train_features)
  print('Running inference on test')
  pred_test = classifier.decision_function(test_features)
  if FLAGS.num_test_windows > 1:
    pred_test = np.reshape(
        pred_test, (test_labels.shape[0], -1, pred_test.shape[1]))
    pred_test = pred_test.mean(axis=1)

  # Compute accuracies.
  metrics = compute_accuracy_metrics(pred_train, train_labels[:, None],
                                     prefix='train_')
  metrics.update(
      compute_accuracy_metrics(pred_test, test_labels[:, None], prefix='test_'))
  print(metrics)


if __name__ == '__main__':
  app.run(main)

Binary file not shown. After Width: | Height: | Size: 53 KiB
File diff suppressed because it is too large.
@@ -0,0 +1,143 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Normalize function constructors."""

from typing import Any, Dict, Optional, Sequence, Union

import haiku as hk
from jax import numpy as jnp

from mmv.models import types


class _BatchNorm(hk.BatchNorm):
  """A `hk.BatchNorm` with adapted default arguments."""

  def __init__(self,
               create_scale: bool = True,
               create_offset: bool = True,
               decay_rate: float = 0.9,
               eps: float = 1e-5,
               test_local_stats: bool = False,
               **kwargs):
    # Check args.
    if kwargs.get('cross_replica_axis', None) is not None:
      raise ValueError(
          'Attempting to use \'batch_norm\' normalizer, but specifying '
          '`cross_replica_axis`. If you want this behavior use '
          '`normalizer=\'cross_replica_batch_norm\'` directly.')

    self._test_local_stats = test_local_stats
    super().__init__(create_scale=create_scale,
                     create_offset=create_offset,
                     decay_rate=decay_rate,
                     eps=eps,
                     **kwargs)

  def __call__(self,
               x: types.TensorLike,
               is_training: bool) -> jnp.ndarray:
    return super().__call__(x, is_training,
                            test_local_stats=self._test_local_stats)


class _CrossReplicaBatchNorm(hk.BatchNorm):
  """A `hk.BatchNorm` with adapted default arguments for cross replica."""

  def __init__(self,
               create_scale: bool = True,
               create_offset: bool = True,
               decay_rate: float = 0.9,
               eps: float = 1e-5,
               test_local_stats: bool = False,
               **kwargs):
    # Check args.
    if 'cross_replica_axis' in kwargs and kwargs['cross_replica_axis'] is None:
      raise ValueError(
          'Attempting to use \'cross_replica_batch_norm\' normalizer, but '
          'specifying `cross_replica_axis` to be None. If you want this '
          'behavior use `normalizer=\'batch_norm\'` directly.')

    self._test_local_stats = test_local_stats
    kwargs['cross_replica_axis'] = kwargs.get('cross_replica_axis', 'i')
    super().__init__(create_scale=create_scale,
                     create_offset=create_offset,
                     decay_rate=decay_rate,
                     eps=eps,
                     **kwargs)

  def __call__(self,
               x: types.TensorLike,
               is_training: bool) -> jnp.ndarray:
    return super().__call__(x, is_training,
                            test_local_stats=self._test_local_stats)


class _LayerNorm(hk.LayerNorm):
  """A `hk.LayerNorm` accepting (and discarding) an `is_training` argument."""

  def __init__(self,
               axis: Union[int, Sequence[int]] = (1, 2),
               create_scale: bool = True,
               create_offset: bool = True,
               **kwargs):
    super().__init__(axis=axis,
                     create_scale=create_scale,
                     create_offset=create_offset,
                     **kwargs)

  def __call__(self,
               x: types.TensorLike,
               is_training: bool) -> jnp.ndarray:
    del is_training  # Unused.
    return super().__call__(x)


_NORMALIZER_NAME_TO_CLASS = {
    'batch_norm': _BatchNorm,
    'cross_replica_batch_norm': _CrossReplicaBatchNorm,
    'layer_norm': _LayerNorm,
}


def get_normalize_fn(
    normalizer_name: str = 'batch_norm',
    normalizer_kwargs: Optional[Dict[str, Any]] = None,
) -> types.NormalizeFn:
  """Handles NormalizeFn creation.

  These functions are expected to be used as part of a Haiku model. On each
  application of the returned normalize_fn, a new Haiku layer will be added
  to the model.

  Args:
    normalizer_name: The name of the normalizer to be constructed.
    normalizer_kwargs: The kwargs passed to the normalizer constructor.

  Returns:
    A `types.NormalizeFn` that when applied will create a new layer.

  Raises:
    ValueError: If `normalizer_name` is unknown.
  """
  # Check args.
  if normalizer_name not in _NORMALIZER_NAME_TO_CLASS:
    raise ValueError(f'Unrecognized `normalizer_name` {normalizer_name}.')

  normalizer_class = _NORMALIZER_NAME_TO_CLASS[normalizer_name]
  normalizer_kwargs = normalizer_kwargs or dict()

  return lambda *a, **k: normalizer_class(**normalizer_kwargs)(*a, **k)  # pylint: disable=unnecessary-lambda

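The name-to-class dispatch in `get_normalize_fn` is a registry-plus-closure pattern: the returned function constructs a fresh normalizer instance on every call (in Haiku, each application adds a new layer). A minimal self-contained sketch without the Haiku dependency; `Identity`, `_REGISTRY`, and `get_fn` are hypothetical names used only to illustrate the pattern:

```python
class Identity:
  """Stand-in normalizer that just scales its input."""

  def __init__(self, scale=1.0):
    self.scale = scale

  def __call__(self, x, is_training):
    del is_training  # Unused, mirroring _LayerNorm above.
    return self.scale * x


_REGISTRY = {'identity': Identity}


def get_fn(name='identity', kwargs=None):
  """Returns a closure that builds a fresh instance on each application."""
  if name not in _REGISTRY:
    raise ValueError(f'Unrecognized normalizer {name}.')
  kwargs = kwargs or {}
  return lambda *a, **k: _REGISTRY[name](**kwargs)(*a, **k)


fn = get_fn('identity', {'scale': 2.0})
print(fn(3.0, is_training=True))  # 6.0
```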
@@ -0,0 +1,329 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3.
"""ResNet V2 modules.

Equivalent to hk.Resnet except accepting a final_endpoint to return
intermediate activations.
"""

from typing import Optional, Sequence, Text, Type, Union

import haiku as hk
import jax
import jax.numpy as jnp

from mmv.models import types


class BottleneckBlock(hk.Module):
  """Implements a bottleneck residual block (ResNet50 and ResNet101)."""

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None):
    super(BottleneckBlock, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn

    if self._use_projection:
      self._proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding='SAME',
          name='shortcut_conv')

    self._conv_0 = hk.Conv2D(
        output_channels=channels // 4,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_0')

    self._conv_1 = hk.Conv2D(
        output_channels=channels // 4,
        kernel_shape=3,
        stride=stride,
        with_bias=False,
        padding='SAME',
        name='conv_1')

    self._conv_2 = hk.Conv2D(
        output_channels=channels,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_2')

  def __call__(self,
               inputs,
               is_training):
    net = inputs
    shortcut = inputs

    for i, conv_i in enumerate([self._conv_0, self._conv_1, self._conv_2]):
      if self._normalize_fn is not None:
        net = self._normalize_fn(net, is_training=is_training)
      net = jax.nn.relu(net)
      if i == 0 and self._use_projection:
        shortcut = self._proj_conv(net)

      # Now do the convs.
      net = conv_i(net)

    return net + shortcut


class BasicBlock(hk.Module):
  """Implements a basic residual block (ResNet18 and ResNet34)."""

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None):
    super(BasicBlock, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn

    if self._use_projection:
      self._proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding='SAME',
          name='shortcut_conv')

    self._conv_0 = hk.Conv2D(
        output_channels=channels,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_0')

    self._conv_1 = hk.Conv2D(
        output_channels=channels,
        kernel_shape=3,
        stride=stride,
        with_bias=False,
        padding='SAME',
        name='conv_1')

  def __call__(self,
               inputs,
               is_training):
    net = inputs
    shortcut = inputs

    for i, conv_i in enumerate([self._conv_0, self._conv_1]):
      if self._normalize_fn is not None:
        net = self._normalize_fn(net, is_training=is_training)
      net = jax.nn.relu(net)
      if i == 0 and self._use_projection:
        shortcut = self._proj_conv(net)

      # Now do the convs.
      net = conv_i(net)

    return net + shortcut


class ResNetUnit(hk.Module):
  """Unit (group of blocks) for ResNet."""

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               num_blocks: int,
               stride: Union[int, Sequence[int]],
               block_module: Type[BottleneckBlock],
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None,
               remat: bool = False):
    super(ResNetUnit, self).__init__(name=name)
    self._channels = channels
    self._num_blocks = num_blocks
    self._stride = stride
    self._normalize_fn = normalize_fn
    self._block_module = block_module
    self._remat = remat

  def __call__(self,
               inputs,
               is_training):

    input_channels = inputs.shape[-1]

    self._blocks = []
    for id_block in range(self._num_blocks):
      use_projection = id_block == 0 and self._channels != input_channels
      self._blocks.append(
          self._block_module(
              channels=self._channels,
              stride=self._stride if id_block == 0 else 1,
              use_projection=use_projection,
              normalize_fn=self._normalize_fn,
              name='block_%d' % id_block))

    net = inputs
    for block in self._blocks:
      if self._remat:
        # Note: we can ignore cell-var-from-loop because the lambda is evaluated
        # inside every iteration of the loop. This is needed to go around the
        # way variables are passed to jax.remat.
        net = hk.remat(lambda x: block(x, is_training=is_training))(net)  # pylint: disable=cell-var-from-loop
      else:
        net = block(net, is_training=is_training)
    return net


class ResNetV2(hk.Module):
  """ResNetV2 model."""

  # Endpoints of the model in order.
  VALID_ENDPOINTS = (
      'resnet_stem',
      'resnet_unit_0',
      'resnet_unit_1',
      'resnet_unit_2',
      'resnet_unit_3',
      'last_conv',
      'output',
  )

  # pylint:disable=g-bare-generic
  def __init__(self,
               depth=50,
               num_classes: Optional[int] = 1000,
               width_mult: int = 1,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None,
               remat: bool = False):
    """Creates ResNetV2 Haiku module.

    Args:
      depth: depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
      num_classes: (int) Number of outputs in final layer. If None will not add
        a classification head and will return the output embedding.
      width_mult: multiplier for channel width.
      normalize_fn: normalization function, see helpers/utils.py
      name: Name of the module.
      remat: Whether to rematerialize intermediate activations (saves memory).
    """
    super(ResNetV2, self).__init__(name=name)
    self._normalize_fn = normalize_fn
    self._num_classes = num_classes
    self._width_mult = width_mult

    self._strides = [1, 2, 2, 2]
    num_blocks = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    if depth not in num_blocks:
      raise ValueError(
          f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
    self._num_blocks = num_blocks[depth]

    if depth >= 50:
      self._block_module = BottleneckBlock
      self._channels = [256, 512, 1024, 2048]
    else:
      self._block_module = BasicBlock
      self._channels = [64, 128, 256, 512]

    self._initial_conv = hk.Conv2D(
        output_channels=64 * self._width_mult,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding='SAME',
        name='initial_conv')

    if remat:
      self._initial_conv = hk.remat(self._initial_conv)

    self._block_groups = []
    for i in range(4):
      self._block_groups.append(
          ResNetUnit(
              channels=self._channels[i] * self._width_mult,
              num_blocks=self._num_blocks[i],
              block_module=self._block_module,
              stride=self._strides[i],
              normalize_fn=self._normalize_fn,
              name='block_group_%d' % i,
              remat=remat))

    if num_classes is not None:
      self._logits_layer = hk.Linear(
          output_size=num_classes, w_init=jnp.zeros, name='logits')

  def __call__(self, inputs, is_training, final_endpoint='output'):
    self._final_endpoint = final_endpoint
    net = self._initial_conv(inputs)
    net = hk.max_pool(
        net, window_shape=(1, 3, 3, 1),
        strides=(1, 2, 2, 1),
        padding='SAME')
    end_point = 'resnet_stem'
    if self._final_endpoint == end_point:
      return net

    for i_group, block_group in enumerate(self._block_groups):
      net = block_group(net, is_training=is_training)
      end_point = f'resnet_unit_{i_group}'
      if self._final_endpoint == end_point:
        return net

    end_point = 'last_conv'
    if self._final_endpoint == end_point:
      return net

    if self._normalize_fn is not None:
      net = self._normalize_fn(net, is_training=is_training)
    net = jax.nn.relu(net)

    # The actual representation
    net = jnp.mean(net, axis=[1, 2])

    assert self._final_endpoint == 'output'
    if self._num_classes is None:
      # If num_classes was None, we just return the output
      # of the last block, without fully connected layer.
      return net

    return self._logits_layer(net)
File diff suppressed because it is too large
@@ -0,0 +1,88 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for s3d."""

from absl.testing import absltest
from absl.testing import parameterized

import haiku as hk
import jax
import numpy as np

from mmv.models import normalization
from mmv.models import s3d


class _CallableS3D:
  """Wrapper around S3D that takes care of parameter bookkeeping."""

  def __init__(self, *args, **kwargs):
    self._model = hk.transform_with_state(
        lambda *a, **k:  # pylint: disable=g-long-lambda,unnecessary-lambda
        s3d.S3D(
            normalize_fn=normalization.get_normalize_fn(),
            *args, **kwargs)(*a, **k))
    self._rng = jax.random.PRNGKey(42)
    self._params, self._state = None, None

  def init(self, inputs, **kwargs):
    self._params, self._state = self._model.init(
        self._rng, inputs, is_training=True, **kwargs)

  def __call__(self, inputs, **kwargs):
    if self._params is None:
      self.init(inputs)
    output, _ = self._model.apply(
        self._params, self._state, self._rng, inputs, **kwargs)
    return output


class S3DTest(parameterized.TestCase):

  # Testing all layers is quite slow, added in comments for completeness.
  @parameterized.parameters(
      # dict(endpoint='Conv2d_1a_7x7', expected_size=(2, 8, 112, 112, 64)),
      # dict(endpoint='MaxPool_2a_3x3', expected_size=(2, 8, 56, 56, 64)),
      # dict(endpoint='Conv2d_2b_1x1', expected_size=(2, 8, 56, 56, 64)),
      # dict(endpoint='Conv2d_2c_3x3', expected_size=(2, 8, 56, 56, 192)),
      # dict(endpoint='MaxPool_3a_3x3', expected_size=(2, 8, 28, 28, 192)),
      # dict(endpoint='Mixed_3b', expected_size=(2, 8, 28, 28, 256)),
      # dict(endpoint='Mixed_3c', expected_size=(2, 8, 28, 28, 480)),
      # dict(endpoint='MaxPool_4a_3x3', expected_size=(2, 4, 14, 14, 480)),
      # dict(endpoint='Mixed_4b', expected_size=(2, 4, 14, 14, 512)),
      # dict(endpoint='Mixed_4c', expected_size=(2, 4, 14, 14, 512)),
      # dict(endpoint='Mixed_4d', expected_size=(2, 4, 14, 14, 512)),
      # dict(endpoint='Mixed_4e', expected_size=(2, 4, 14, 14, 528)),
      # dict(endpoint='Mixed_4f', expected_size=(2, 4, 14, 14, 832)),
      # dict(endpoint='MaxPool_5a_2x2', expected_size=(2, 2, 7, 7, 832)),
      # dict(endpoint='Mixed_5b', expected_size=(2, 2, 7, 7, 832)),
      # dict(endpoint='Mixed_5c', expected_size=(2, 2, 7, 7, 1024)),
      dict(endpoint='Embeddings', expected_size=(2, 1024)),
  )
  def test_endpoint_expected_output_dimensions(self, endpoint, expected_size):
    inputs = np.random.normal(size=(2, 16, 224, 224, 3))
    model = _CallableS3D()
    output = model(inputs, is_training=False, final_endpoint=endpoint)
    self.assertSameElements(output.shape, expected_size)

  def test_space_to_depth(self):
    inputs = np.random.normal(size=(2, 16//2, 224//2, 224//2, 3*2*2*2))
    model = _CallableS3D()
    output = model(inputs, is_training=False, final_endpoint='Conv2d_1a_7x7')
    self.assertSameElements(output.shape, (2, 8, 112, 112, 64))

if __name__ == '__main__':
  absltest.main()
@@ -0,0 +1,353 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Temporal Shift Module w/ ResNet-50 and ResNet-101.

Based on:
  TSM: Temporal Shift Module for Efficient Video Understanding
  Ji Lin, Chuang Gan, Song Han
  https://arxiv.org/pdf/1811.08383.pdf.
"""

from typing import Optional

import haiku as hk
import jax
import jax.numpy as jnp

from mmv.models import tsm_utils as tsmu
from mmv.models import types


class TSMResNetBlock(hk.Module):
  """A ResNet subblock with Temporal Channel Shifting.

  Combines a typical ResNetV2 block implementation
  (see https://arxiv.org/abs/1512.03385) with a pre-convolution Temporal
  Shift Module (see https://arxiv.org/pdf/1811.08383.pdf) in the residual.
  """

  def __init__(self,
               output_channels: int,
               stride: int,
               use_projection: bool,
               tsm_mode: str,
               normalize_fn: Optional[types.NormalizeFn] = None,
               channel_shift_fraction: float = 0.125,
               num_frames: int = 8,
               name: str = 'TSMResNetBlock'):
    """Initializes the TSMResNetBlock module.

    Args:
      output_channels: Number of output channels.
      stride: Stride used in convolutions.
      use_projection: Whether to use a projection for the shortcut.
      tsm_mode: Mode for TSM ('gpu' or 'tpu').
      normalize_fn: Function used for normalization.
      channel_shift_fraction: The fraction of temporally shifted channels. If
        `channel_shift_fraction` is 0, the block is the same as a normal ResNet
        block.
      num_frames: Size of frame dimension in a single batch example.
      name: The name of the module.
    """
    super().__init__(name=name)
    self._output_channels = output_channels
    self._bottleneck_channels = output_channels // 4
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn
    self._tsm_mode = tsm_mode
    self._channel_shift_fraction = channel_shift_fraction
    self._num_frames = num_frames

  def __call__(self,
               inputs: types.TensorLike,
               is_training: bool = True) -> jnp.ndarray:
    """Connects the ResNetBlock module into the graph.

    Args:
      inputs: A 4-D float array of shape `[B, H, W, C]`.
      is_training: Whether to use training mode.

    Returns:
      A 4-D float array of shape
      `[B * num_frames, new_h, new_w, output_channels]`.
    """
    # ResNet V2 uses pre-activation, where the batch norm and relu are before
    # convolutions, rather than after as in ResNet V1.
    preact = inputs
    if self._normalize_fn is not None:
      preact = self._normalize_fn(preact, is_training=is_training)
    preact = jax.nn.relu(preact)

    if self._use_projection:
      shortcut = hk.Conv2D(
          output_channels=self._output_channels,
          kernel_shape=1,
          stride=self._stride,
          with_bias=False,
          padding='SAME',
          name='shortcut_conv')(
              preact)
    else:
      shortcut = inputs

    # Eventually applies Temporal Shift Module.
    if self._channel_shift_fraction != 0:
      preact = tsmu.apply_temporal_shift(
          preact, tsm_mode=self._tsm_mode, num_frames=self._num_frames,
          channel_shift_fraction=self._channel_shift_fraction)

    # First convolution.
    residual = hk.Conv2D(
        self._bottleneck_channels,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_0')(
            preact)

    # Second convolution.
    if self._normalize_fn is not None:
      residual = self._normalize_fn(residual, is_training=is_training)
    residual = jax.nn.relu(residual)
    residual = hk.Conv2D(
        output_channels=self._bottleneck_channels,
        kernel_shape=3,
        stride=self._stride,
        with_bias=False,
        padding='SAME',
        name='conv_1')(
            residual)

    # Third convolution.
    if self._normalize_fn is not None:
      residual = self._normalize_fn(residual, is_training=is_training)
    residual = jax.nn.relu(residual)
    residual = hk.Conv2D(
        output_channels=self._output_channels,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_2')(
            residual)

    # NOTE: we do not use block multiplier.
    output = shortcut + residual
    return output


class TSMResNetUnit(hk.Module):
  """Block group for TSM ResNet."""

  def __init__(self,
               output_channels: int,
               num_blocks: int,
               stride: int,
               tsm_mode: str,
               num_frames: int,
               normalize_fn: Optional[types.NormalizeFn] = None,
               channel_shift_fraction: float = 0.125,
               name: str = 'tsm_resnet_unit'):
    """Creates a TSMResNet Unit.

    Args:
      output_channels: Number of output channels.
      num_blocks: Number of ResNet blocks in the unit.
      stride: Stride of the unit.
      tsm_mode: Which temporal shift module to use.
      num_frames: Size of frame dimension in a single batch example.
      normalize_fn: Function used for normalization.
      channel_shift_fraction: The fraction of temporally shifted channels. If
        `channel_shift_fraction` is 0, the block is the same as a normal ResNet
        block.
      name: The name of the module.
    """
    super().__init__(name=name)
    self._output_channels = output_channels
    self._num_blocks = num_blocks
    self._normalize_fn = normalize_fn
    self._stride = stride
    self._tsm_mode = tsm_mode
    self._channel_shift_fraction = channel_shift_fraction
    self._num_frames = num_frames

  def __call__(self,
               inputs: types.TensorLike,
               is_training: bool) -> jnp.ndarray:
    """Connects the module to inputs.

    Args:
      inputs: A 4-D float array of shape `[B * num_frames, H, W, C]`.
      is_training: Whether to use training mode.

    Returns:
      A 4-D float array of shape
      `[B * num_frames, H // stride, W // stride, output_channels]`.
    """
    net = inputs
    for idx_block in range(self._num_blocks):
      net = TSMResNetBlock(
          self._output_channels,
          stride=self._stride if idx_block == 0 else 1,
          use_projection=idx_block == 0,
          normalize_fn=self._normalize_fn,
          tsm_mode=self._tsm_mode,
          channel_shift_fraction=self._channel_shift_fraction,
          num_frames=self._num_frames,
          name=f'block_{idx_block}')(
              net, is_training=is_training)
    return net


class TSMResNetV2(hk.Module):
  """TSM based on ResNet V2 as described in https://arxiv.org/abs/1603.05027."""

  # Endpoints of the model in order.
  VALID_ENDPOINTS = (
      'tsm_resnet_stem',
      'tsm_resnet_unit_0',
      'tsm_resnet_unit_1',
      'tsm_resnet_unit_2',
      'tsm_resnet_unit_3',
      'last_conv',
      'Embeddings',
  )

  def __init__(self,
               normalize_fn: Optional[types.NormalizeFn] = None,
               depth: int = 50,
               num_frames: int = 16,
               channel_shift_fraction: float = 0.125,
               width_mult: int = 1,
               name: str = 'TSMResNetV2'):
    """Constructs a ResNet model.

    Args:
      normalize_fn: Function used for normalization.
      depth: Depth of the desired ResNet.
      num_frames: Number of frames (used in TPU mode).
      channel_shift_fraction: Fraction of channels that are temporally shifted;
        if `channel_shift_fraction` is 0, a regular ResNet is returned.
      width_mult: Multiplier for channel width.
      name: The name of the module.

    Raises:
      ValueError: If `channel_shift_fraction` or `depth` has an invalid value.
    """
    super().__init__(name=name)

    if not 0. <= channel_shift_fraction <= 1.0:
      raise ValueError(
          f'channel_shift_fraction ({channel_shift_fraction})'
          ' has to be in [0, 1].')

    self._num_frames = num_frames

    self._channels = (256, 512, 1024, 2048)
    self._strides = (1, 2, 2, 2)

    num_blocks = {
        50: (3, 4, 6, 3),
        101: (3, 4, 23, 3),
        152: (3, 8, 36, 3),
        200: (3, 24, 36, 3),
    }
    if depth not in num_blocks:
      raise ValueError(
          f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
    self._num_blocks = num_blocks[depth]

    self._width_mult = width_mult
    self._channel_shift_fraction = channel_shift_fraction
    self._normalize_fn = normalize_fn

  def __call__(
      self,
      inputs: types.TensorLike,
      is_training: bool = True,
      final_endpoint: str = 'Embeddings') -> jnp.ndarray:
    """Connects the TSM ResNetV2 module into the graph.

    Args:
      inputs: A 4-D float array of shape `[B, H, W, C]`.
      is_training: Whether to use training mode.
      final_endpoint: Up to which endpoint to run / return.

    Returns:
      Network output at location `final_endpoint`. A float array whose shape
      depends on `final_endpoint`.

    Raises:
      ValueError: If `final_endpoint` is not recognized.
    """

    # Prepare inputs for TSM.
    inputs, tsm_mode, num_frames = tsmu.prepare_inputs(inputs)
    num_frames = num_frames or self._num_frames

    self._final_endpoint = final_endpoint
    if self._final_endpoint not in self.VALID_ENDPOINTS:
      raise ValueError(f'Unknown final endpoint {self._final_endpoint}')

    # Stem convolution.
    end_point = 'tsm_resnet_stem'
    net = hk.Conv2D(
        output_channels=64 * self._width_mult,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        name=end_point,
        padding='SAME')(
            inputs)
    net = hk.MaxPool(
        window_shape=(1, 3, 3, 1),
        strides=(1, 2, 2, 1),
        padding='SAME')(
            net)
    if self._final_endpoint == end_point:
      return net

    # Residual block.
    for unit_id, (channels, num_blocks, stride) in enumerate(
        zip(self._channels, self._num_blocks, self._strides)):
      end_point = f'tsm_resnet_unit_{unit_id}'
      net = TSMResNetUnit(
          output_channels=channels * self._width_mult,
          num_blocks=num_blocks,
          stride=stride,
          normalize_fn=self._normalize_fn,
          channel_shift_fraction=self._channel_shift_fraction,
          num_frames=num_frames,
          tsm_mode=tsm_mode,
          name=end_point)(
              net, is_training=is_training)
      if self._final_endpoint == end_point:
        return net

    if self._normalize_fn is not None:
      net = self._normalize_fn(net, is_training=is_training)
    net = jax.nn.relu(net)

    end_point = 'last_conv'
    if self._final_endpoint == end_point:
      return net
    net = jnp.mean(net, axis=(1, 2))
    # Prepare embedding outputs for TSM (temporal average of features).
    net = tsmu.prepare_outputs(net, tsm_mode, num_frames)
    assert self._final_endpoint == 'Embeddings'
    return net
@@ -0,0 +1,65 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for TSM ResNet model."""

from absl.testing import absltest
from absl.testing import parameterized

import haiku as hk
import jax
import jax.numpy as jnp

from mmv.models import tsm_resnet


class TSMResNetTest(parameterized.TestCase):

  @parameterized.parameters(
      ('tsm_resnet_stem', (2 * 32, 56, 56, 64)),
      ('tsm_resnet_unit_0', (2 * 32, 56, 56, 256)),
      ('tsm_resnet_unit_1', (2 * 32, 28, 28, 512)),
      ('tsm_resnet_unit_2', (2 * 32, 14, 14, 1024)),
      ('tsm_resnet_unit_3', (2 * 32, 7, 7, 2048)),
      ('last_conv', (2 * 32, 7, 7, 2048)),
      ('Embeddings', (2, 2048)),
  )
  def test_output_dimension(self, final_endpoint, expected_shape):
    input_shape = (2, 32, 224, 224, 3)

    def f():
      data = jnp.zeros(input_shape)
      net = tsm_resnet.TSMResNetV2()
      return net(data, final_endpoint=final_endpoint)

    init_fn, apply_fn = hk.transform(f)
    out = apply_fn(init_fn(jax.random.PRNGKey(42)), None)
    self.assertEqual(out.shape, expected_shape)

  def test_tpu_mode(self):
    input_shape = (32 * 2, 224, 224, 3)

    def f():
      data = jnp.zeros(input_shape)
      net = tsm_resnet.TSMResNetV2(num_frames=32)
      return net(data, final_endpoint='Embeddings')

    init_fn, apply_fn = hk.transform(f)
    out = apply_fn(init_fn(jax.random.PRNGKey(42)), None)
    self.assertEqual(out.shape, (2, 2048))


if __name__ == '__main__':
  absltest.main()
@@ -0,0 +1,177 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for TSM."""

from typing import Tuple

import jax
import jax.numpy as jnp

from mmv.models import types


def prepare_inputs(
    inputs: types.TensorLike) -> Tuple[jnp.ndarray, str, int]:
  """Deduces input mode for TSM."""
  # Deduce if we run on TPU based on input shape.
  if len(inputs.shape) == 5:
    # Input is given in the standard [B, T, H, W, 3] format.
    tsm_mode = 'gpu'
    num_frames = inputs.shape[1]
    inputs = jnp.reshape(inputs, [-1] + list(inputs.shape[2:]))
  else:
    # Input is given in the [T * B, H, W, 3] format.
    tsm_mode = 'tpu'
    num_frames = None
  return inputs, tsm_mode, num_frames
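The shape-based mode deduction in `prepare_inputs` can be illustrated with a small NumPy sketch (illustrative only; the real function operates on JAX arrays and the shapes below are made up):

```python
import numpy as np


def prepare_inputs_sketch(inputs):
  # 5-D [B, T, H, W, C] input: GPU mode, fold time into the batch axis.
  if inputs.ndim == 5:
    num_frames = inputs.shape[1]
    return inputs.reshape((-1,) + inputs.shape[2:]), 'gpu', num_frames
  # 4-D [T * B, H, W, C] input: TPU mode, frame count must come from config.
  return inputs, 'tpu', None


videos = np.zeros((2, 8, 4, 4, 3))  # batch of 2 videos with 8 frames each
flat, mode, t = prepare_inputs_sketch(videos)
print(flat.shape, mode, t)  # (16, 4, 4, 3) gpu 8
```

In TPU mode the frame count cannot be recovered from the 4-D shape alone, which is why `TSMResNetV2` falls back on its `num_frames` constructor argument.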


def prepare_outputs(outputs: types.TensorLike,
                    tsm_mode: str,
                    num_frames: int) -> jnp.ndarray:
  """Processes output of TSM by averaging representations over time axis."""
  n_channels = outputs.shape[-1]
  if tsm_mode == 'tpu':
    outputs = jnp.reshape(outputs, [num_frames, -1, n_channels])
    outputs = jnp.mean(outputs, axis=0)
  elif tsm_mode == 'gpu':
    outputs = jnp.reshape(outputs, [-1, num_frames, n_channels])
    outputs = jnp.mean(outputs, axis=1)
  else:
    raise ValueError(
        f'`tsm_mode` should be \'tpu\' or \'gpu\' ({tsm_mode} given)')
  return outputs
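`prepare_outputs` undoes the batch/time folding before averaging; the two branches differ only in which axis carries time after the reshape. A NumPy sketch of the same arithmetic (shapes chosen purely for illustration):

```python
import numpy as np

feats = np.arange(12, dtype=np.float32).reshape(6, 2)  # 6 frame features, C=2

# gpu mode: rows are ordered [B, T] -> reshape to [B, T, C], mean over axis 1.
gpu_avg = feats.reshape(-1, 2, 2).mean(axis=1)   # B=3, T=2

# tpu mode: rows are ordered [T, B] -> reshape to [T, B, C], mean over axis 0.
tpu_avg = feats.reshape(2, -1, 2).mean(axis=0)   # T=2, B=3

print(gpu_avg.tolist())  # [[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]]
print(tpu_avg.tolist())  # [[3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
```

The same flat array averages to different per-video embeddings depending on the assumed memory layout, which is why the mode string is threaded through from `prepare_inputs`.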


def apply_temporal_shift(
    x: types.TensorLike,
    tsm_mode: str,
    num_frames: int,
    channel_shift_fraction: float = 0.125) -> jnp.ndarray:
  """Performs a temporal shift: https://arxiv.org/abs/1811.08383 with mode."""
  if tsm_mode == 'tpu':
    outputs = temporal_shift_tpu(x, num_frames, channel_shift_fraction)
  elif tsm_mode == 'gpu':
    outputs = temporal_shift_gpu(x, num_frames, channel_shift_fraction)
  else:
    raise ValueError(
        f'`tsm_mode` should be \'tpu\' or \'gpu\' ({tsm_mode} given)')
  return outputs


def temporal_shift_gpu(
    x: types.TensorLike,
    num_frames: int,
    channel_shift_fraction: float = 0.125) -> jnp.ndarray:
  """Performs a temporal shift: https://arxiv.org/abs/1811.08383."""
  # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels
  # Input is (B * T, H, W, C)
  orig_shp = tuple(x.shape)
  reshaped_x = jnp.reshape(x, (-1, num_frames) + orig_shp[1:])
  n_channels = orig_shp[-1]
  n_shift = int(n_channels * channel_shift_fraction)

  new_shp = tuple(reshaped_x.shape)

  # shifted_backward = reshaped_x[:, 1:, :, :, -n_shift:]
  shifted_backward = jax.lax.slice(
      reshaped_x, (0, 1, 0, 0, new_shp[4] - n_shift),
      (new_shp[0], new_shp[1], new_shp[2], new_shp[3], new_shp[4]))
  shifted_backward_padding = ((0, 0), (0, 1), (0, 0), (0, 0), (0, 0))
  shifted_backward = jnp.pad(shifted_backward, shifted_backward_padding)

  # shifted_forward = reshaped_x[:, :-1, :, :, :n_shift]
  shifted_forward = jax.lax.slice(
      reshaped_x, (0, 0, 0, 0, 0),
      (new_shp[0], new_shp[1] - 1, new_shp[2], new_shp[3], n_shift))
  shifted_forward_padding = ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0))
  shifted_forward = jnp.pad(shifted_forward, shifted_forward_padding)

  no_shift = reshaped_x[:, :, :, :, n_shift:-n_shift]
  shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward],
                              axis=4)
  return jnp.reshape(shifted_x, (-1,) + orig_shp[1:])
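The slicing and padding in `temporal_shift_gpu` implement the TSM channel shift: the output's first `n_shift` channels come from the *next* frame's last channels (zeros on the final frame), its last `n_shift` channels come from the *previous* frame's first channels (zeros on the first frame), and the middle channels pass through unshifted. A NumPy sketch of the same semantics, mirroring the code's `[backward, unshifted, forward]` concatenation order (shapes are illustrative):

```python
import numpy as np


def temporal_shift_np(x, num_frames, channel_shift_fraction=0.25):
  """NumPy sketch of the GPU temporal shift; x is [B * T, H, W, C]."""
  orig_shape = x.shape
  n = int(orig_shape[-1] * channel_shift_fraction)
  xr = x.reshape((-1, num_frames) + orig_shape[1:])
  # Backward shift: frame t receives frame t+1's last n channels.
  back = np.zeros_like(xr[..., -n:])
  back[:, :-1] = xr[:, 1:, ..., -n:]
  # Forward shift: frame t receives frame t-1's first n channels.
  fwd = np.zeros_like(xr[..., :n])
  fwd[:, 1:] = xr[:, :-1, ..., :n]
  out = np.concatenate([back, xr[..., n:-n], fwd], axis=-1)
  return out.reshape(orig_shape)


# One video, 3 frames, 1x1 spatial, 4 channels; frame t holds [t*10+1 .. t*10+4].
x = np.stack([np.full((1, 1, 4), t * 10) + np.arange(1, 5) for t in range(3)])
print(temporal_shift_np(x, num_frames=3)[:, 0, 0])
# frame 0 -> [14,  2,  3,  0]   (14 is frame 1's last channel; 0 pads the end)
# frame 1 -> [24, 12, 13,  1]
# frame 2 -> [ 0, 22, 23, 11]
```

Note that, like the code above, this places the shifted channels at the front and back of the channel axis rather than back in their original slots; successive blocks see a consistent layout, so this is purely a convention.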

def temporal_shift_tpu(
    x: types.TensorLike,
    num_frames: int,
    channel_shift_fraction: float = 0.125) -> jnp.ndarray:
  """Performs a temporal shift: https://arxiv.org/abs/1811.08383.

  TPU optimized version of TSM. Reshape is avoided by having the images
  reshaped in [T * B, :] so that frames corresponding to the same time step in
  videos are contiguous in memory. Thanks to cr/288510308, which allows fusing
  pad->slice into convolution, we reformulate the slice-pad into a pad then
  slice. Finally, to avoid a concatenate that prevents some fusions from
  happening, we simply sum masked versions of the features.

  Args:
    x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped
      from a time major version of the input).
    num_frames: number of frames T per video.
    channel_shift_fraction: fraction of the channel to shift forward and
      backward.

  Returns:
    The temporal shifted version of x.
  """
  # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels
  # Input is (T * B, H, W, C)
  original_shape = list(x.shape)

  batch_size = int(original_shape[0] / num_frames)
  n_channels = int(original_shape[-1])
  n_shift = int(n_channels * channel_shift_fraction)

  # Cast to bfloat16.
  x = x.astype(jnp.bfloat16)

  # For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1.
  # Shift backward, we first pad by zeros [x1, x2, x3, 0, 0].
  orig_shp = list(x.shape)

  shifted_backward_padding = ((0, batch_size, 0), (0, 0, 0), (0, 0, 0),
                              (0, n_channels - n_shift, 0))
  x_backward_padding = jax.lax.pad(
      x,
      padding_value=jnp.bfloat16(0.),
      padding_config=shifted_backward_padding)
  # The following shift gets to [x3^+1, 0, 0] (where +1 means from the future).
  shifted_backward = jax.lax.slice(x_backward_padding,
                                   (batch_size, 0, 0, n_channels - n_shift),
                                   (orig_shp[0] + batch_size, orig_shp[1],
                                    orig_shp[2], 2 * n_channels - n_shift))
  # Shift forward, we first pad by zeros [0, 0, x1, x2, x3].
  shifted_forward_padding = ((batch_size, 0, 0), (0, 0, 0), (0, 0, 0),
|
||||
(n_channels - n_shift, 0, 0))
|
||||
x_forward_padding = jax.lax.pad(
|
||||
x,
|
||||
padding_value=jnp.bfloat16(0.),
|
||||
padding_config=shifted_forward_padding)
|
||||
# The following shift gets to [0, 0, x1^-1] (where -1 means from the past).
|
||||
shifted_forward = jax.lax.slice(
|
||||
x_forward_padding, (0, 0, 0, 0),
|
||||
(orig_shp[0], orig_shp[1], orig_shp[2], n_channels))
|
||||
# No shift is in the middle, this gets [0, x2, 0].
|
||||
mask_noshift = (jnp.reshape((jnp.arange(n_channels) >= n_shift) &
|
||||
(jnp.arange(n_channels) < n_channels - n_shift),
|
||||
(1, 1, 1, -1))).astype(jnp.bfloat16)
|
||||
no_shift = mask_noshift * x
|
||||
# By summing everything together, we end up with [x3^+1, x2, x1^-1].
|
||||
# Note: channels have been reordered but that doesn't matter for the model.
|
||||
shifted_x = shifted_backward + shifted_forward + no_shift
|
||||
|
||||
return shifted_x.astype(jnp.float32)
|
||||
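The masked-sum trick used above can be seen in isolation with plain NumPy: rather than concatenating three channel slices, each candidate tensor keeps the full channel width and is zeroed outside the channels it owns, so a plain add recombines them. The names here are illustrative only:

```python
import numpy as np

n_channels, n_shift = 8, 1
x = np.arange(n_channels, dtype=np.float32)

idx = np.arange(n_channels)
mask_backward = (idx >= n_channels - n_shift)   # last n_shift channels
mask_forward = (idx < n_shift)                  # first n_shift channels
mask_noshift = ~(mask_backward | mask_forward)  # everything in between

# The three masks partition the channel axis, so the masked pieces sum back
# to the original tensor with no concatenate needed.
combined = x * mask_backward + x * mask_noshift + x * mask_forward
assert np.array_equal(combined, x)
```

Summing element-wise keeps the op fusable with the surrounding convolution, which is the point of this TPU variant.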
@@ -0,0 +1,60 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for tsm_utils."""

from absl.testing import absltest
from absl.testing import parameterized

import jax.numpy as jnp
import numpy as np

from mmv.models import tsm_utils


class TsmUtilsTest(parameterized.TestCase):

  @parameterized.parameters(
      ((2, 32, 224, 224, 3), 'gpu', (2 * 32, 224, 224, 3), 32),
      ((32, 224, 224, 3), 'tpu', (32, 224, 224, 3), None),
  )
  def test_prepare_inputs(self, input_shape, expected_mode, expected_shape,
                          expected_num_frames):
    data = jnp.zeros(input_shape)
    out, mode, num_frames = tsm_utils.prepare_inputs(data)
    self.assertEqual(out.shape, expected_shape)
    self.assertEqual(mode, expected_mode)
    self.assertEqual(num_frames, expected_num_frames)

  def test_prepare_outputs(self):
    data = jnp.concatenate([jnp.zeros(4), jnp.ones(4)]).reshape(4, 2)
    out_gpu = tsm_utils.prepare_outputs(data, 'gpu', 2)
    out_tpu = tsm_utils.prepare_outputs(data, 'tpu', 2)
    expected_gpu = np.concatenate([np.zeros(2), np.ones(2)]).reshape(2, 2)
    expected_tpu = 0.5 * jnp.ones((2, 2))
    np.testing.assert_allclose(out_gpu, expected_gpu)
    np.testing.assert_allclose(out_tpu, expected_tpu)

  def test_apply_tsm(self):
    shape = (32, 224, 224, 16)
    data = jnp.zeros(shape)
    out_gpu = tsm_utils.apply_temporal_shift(data, 'gpu', 16)
    out_tpu = tsm_utils.apply_temporal_shift(data, 'tpu', 16)
    self.assertEqual(out_gpu.shape, shape)
    self.assertEqual(out_tpu.shape, shape)


if __name__ == '__main__':
  absltest.main()
@@ -0,0 +1,36 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Type Aliases."""

from typing import Callable, Tuple, Union

import jax.numpy as jnp
import numpy as np
import optax

TensorLike = Union[np.ndarray, jnp.DeviceArray]

ActivationFn = Callable[[TensorLike], TensorLike]
GatingFn = Callable[[TensorLike], TensorLike]
NetworkFn = Callable[[TensorLike], TensorLike]

# Callable doesn't allow kwargs to be used, and we often want to
# pass in is_training=..., so ignore the arguments for the sake of pytype.
NormalizeFn = Callable[..., TensorLike]

OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState]
@@ -0,0 +1,9 @@
dm-haiku
dm-tree
jax
jaxlib
numpy>=1.16
optax
sklearn
tensorflow
tensorflow_datasets
@@ -0,0 +1,29 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Checkpoint restoring utilities."""

from absl import logging
import dill


def load_checkpoint(checkpoint_path):
  try:
    with open(checkpoint_path, 'rb') as checkpoint_file:
      checkpoint_data = dill.load(checkpoint_file)
      logging.info('Loading checkpoint from %s', checkpoint_path)
      return checkpoint_data
  except FileNotFoundError:
    return None
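A minimal round-trip sketch of how such a checkpoint file is produced and consumed; stdlib `pickle` stands in for `dill` here (they share the same `dump`/`load` interface), and the paths and dict contents are illustrative only:

```python
import os
import pickle
import tempfile

# A hypothetical checkpoint payload: a params pytree plus a step counter.
checkpoint = {'params': {'w': [1.0, 2.0]}, 'step': 100}

# Write it to disk the same way load_checkpoint expects to read it.
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.dill')
with open(path, 'wb') as f:
  pickle.dump(checkpoint, f)

# Read it back; a missing file would instead make load_checkpoint return None.
with open(path, 'rb') as f:
  restored = pickle.load(f)

assert restored == checkpoint
```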
@@ -0,0 +1,70 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ucf101 with custom decoding params."""

import tensorflow as tf
import tensorflow_datasets as tfds

# Utilities functions.
tf.compat.v1.enable_eager_execution()

_CITATION = """\
@article{DBLP:journals/corr/abs-1212-0402,
  author    = {Khurram Soomro and
               Amir Roshan Zamir and
               Mubarak Shah},
  title     = {{UCF101:} {A} Dataset of 101 Human Actions Classes From Videos in
               The Wild},
  journal   = {CoRR},
  volume    = {abs/1212.0402},
  year      = {2012},
  url       = {http://arxiv.org/abs/1212.0402},
  archivePrefix = {arXiv},
  eprint    = {1212.0402},
  timestamp = {Mon, 13 Aug 2018 16:47:45 +0200},
  biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1212-0402},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""

_LABELS_FNAME = 'video/ucf101_labels.txt'


class ModUcf101(tfds.video.Ucf101):
  """Ucf101 action recognition dataset with better quality."""

  def _info(self):
    ffmpeg_extra_args = ('-qscale:v', '2', '-r', '25', '-t', '00:00:20')

    video_shape = (
        None, self.builder_config.height, self.builder_config.width, 3)
    labels_names_file = tfds.core.tfds_path(_LABELS_FNAME)
    features = tfds.features.FeaturesDict({
        'video': tfds.features.Video(video_shape,
                                     ffmpeg_extra_args=ffmpeg_extra_args,
                                     encoding_format='jpeg'),
        'label': tfds.features.ClassLabel(names_file=labels_names_file),
    })
    return tfds.core.DatasetInfo(
        builder=self,
        description='A 101-label video classification dataset.',
        features=features,
        homepage='https://www.crcv.ucf.edu/data-sets/ucf101/',
        citation=_CITATION,
    )