diff --git a/README.md b/README.md
index 43d7ce7..5183689 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/

 ## Projects

+* [Self-Supervised MultiModal Versatile Networks](mmv), NeurIPS 2020
 * [ODE-GAN: Training GANs by Solving Ordinary Differential Equations](ode_gan), NeurIPS 2020
 * [Algorithms for Causal Reasoning in Probability Trees](causal_reasoning)
 * [Gated Linear Networks](gated_linear_networks), NeurIPS 2020

diff --git a/mmv/README.md b/mmv/README.md
new file mode 100644
index 0000000..35449ac
--- /dev/null
+++ b/mmv/README.md
@@ -0,0 +1,83 @@
+# Self-supervised Multimodal Versatile Networks
+
+This is the code for the MMV models: https://arxiv.org/abs/2006.16228.
+
+![MMV model overview](imgs/mmv_fig.png)
+
+We also make available the code for linear evaluation of a pre-trained model
+on UCF101, as well as the JAX checkpoints for our best models.
+
+We use different parameters for video compression in UCF101 than the ones
+used in `tensorflow_datasets`, so we provide the code to download and
+preprocess the dataset. The `eval_ucf101.py` script reproduces the results we
+report in Table 2 of the paper, using the checkpoints provided below.
+
+Visual Backbone | Training Dataset | Linear UCF101 top-1 (%)
+------- | -------- | --------
+S3D-G | AudioSet + HowTo | 89.6
+ResNet-50 TSM | AudioSet + HowTo | 91.5
+ResNet-50 TSM (x2) | AudioSet + HowTo | 91.8
+
+
+## Setup
+
+To set up a Python virtual environment with the required dependencies, run:
+
+```shell
+python3 -m venv mmv_env
+source mmv_env/bin/activate
+pip install --upgrade pip setuptools wheel
+pip install -r mmv/requirements.txt --use-feature=2020-resolver
+```
+
+
+### Linear evaluation
+
+The linear evaluation on UCF101 can be run using:
+
+```shell
+python -m mmv.eval_ucf101 \
+  --checkpoint_path=<path_to_checkpoint> \
+  --dataset_folder=<path_to_ucf101_dataset>
+```
+
+## Checkpoints
+
+We provide three checkpoints containing the best pre-trained weights for each
+of the visual backbones we use in the paper, i.e., S3D-G, ResNet-50 TSM,
+and ResNet-50 TSM x2.
+
+- [S3D-G](https://storage.googleapis.com/deepmind-research-mmv/mmv_s3d.pkl)
+- [ResNet-50 TSM](https://storage.googleapis.com/deepmind-research-mmv/mmv_tsm_resnet_x1.pkl)
+- [ResNet-50 TSMx2](https://storage.googleapis.com/deepmind-research-mmv/mmv_tsm_resnet_x2.pkl)
+
+## References
+
+### Citing our work
+
+If you use this code for your research, please consider citing our paper:
+
+```bibtex
+@inproceedings{alayrac2020self,
+  title={{S}elf-{S}upervised {M}ulti{M}odal {V}ersatile {N}etworks},
+  author={Alayrac, Jean-Baptiste and Recasens, Adri{\`a} and Schneider, Rosalia and Arandjelovi{\'c}, Relja and Ramapuram, Jason and De Fauw, Jeffrey and Smaira, Lucas and Dieleman, Sander and Zisserman, Andrew},
+  booktitle={NeurIPS},
+  year={2020}
+}
+```
+
+### Models in TF
+
+You may also be interested in our models released on TF-Hub, available at:
+
+- [S3D-G](https://tfhub.dev/deepmind/mmv/s3d/1)
+- [ResNet-50 TSM](https://tfhub.dev/deepmind/mmv/tsm-resnet50/1)
+- [ResNet-50 TSMx2](https://tfhub.dev/deepmind/mmv/tsm-resnet50x2/1)
+
+## License
+
+While the code is licensed under the Apache 2.0 License, the checkpoint
+weights are made available for non-commercial use only, under the terms of the
+Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
+license. You can find the details at:
+https://creativecommons.org/licenses/by-nc/4.0/legalcode.
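+
+## Example: inspecting a checkpoint
+
+As a minimal sketch (not part of the released code), a downloaded checkpoint
+can be opened with plain `pickle`, assuming, as `eval_ucf101.py` does, that
+each file holds a dict with Haiku `params` and `state` trees:
+
+```python
+import pickle
+
+# Hypothetical local path to one of the downloaded checkpoints.
+with open('mmv_s3d.pkl', 'rb') as f:
+  ckpt = pickle.load(f)
+
+params, state = ckpt['params'], ckpt['state']
+print(sorted(params)[:5])  # A few of the module names.
+```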
diff --git a/mmv/config.py b/mmv/config.py new file mode 100644 index 0000000..f877acb --- /dev/null +++ b/mmv/config.py @@ -0,0 +1,85 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration parameters for MMV.""" + + +def get_model_config(ckpt_path): + """Returns the model configuration to be used with each checkpoint.""" + + config = { + 'audio_backbone': 'resnet50', + 'audio_model_kwargs': { + 'bn_config': { + 'create_offset': True, + 'create_scale': True, + 'decay_rate': 0.9, + 'eps': 1.0e-5 + } + }, + 'bn_config_proj': { + 'create_offset': True, + 'create_scale': True, + 'decay_rate': 0.9, + 'eps': 1.0e-5 + }, + 'config_audio_text': { + 'embedding_dim': 512, + 'toaud_bn_after_proj': False, + 'toaud_head_mode': 'linear', + 'totxt_bn_after_proj': False, + 'totxt_head_mode': 'linear' + }, + 'config_video_audio': { + 'embedding_dim': 512, + 'toaud_bn_after_proj': True, + 'toaud_head_mode': 'mlp@512', + 'tovid_bn_after_proj': False, + 'tovid_head_mode': 'linear' + }, + 'config_video_text': { + 'embedding_dim': 256, + 'totxt_bn_after_proj': True, + 'totxt_head_mode': 'linear', + 'tovid_bn_after_proj': False, + 'tovid_head_mode': 'linear' + }, + 'mm_embedding_graph': 'fac_relu', + 'name': 'text_audio_video', + 'sentence_dim': 2048, + 'use_xreplica_bn': True, + 'vision_model_kwargs': { + 'bn_config': { + 'create_offset': True, + 'create_scale': True, + 'decay_rate': 0.9, + 'eps': 1.0e-5 + }, + 'n_frames': 32, + 'width_mult': 1, + }, + } + + if 's3d' in ckpt_path: + config['visual_backbone'] = 's3d' + + if 'tsm_resnet_x1' in ckpt_path: + config['visual_backbone'] = 'resnet50tsm' + + if 'tsm_resnet_x2' in ckpt_path: + config['visual_backbone'] = 'resnet50tsm' + config['vision_model_kwargs']['width_mult'] = 2 + + return config diff --git a/mmv/eval_ucf101.py b/mmv/eval_ucf101.py new file mode 100644 index 0000000..bf28f5b --- /dev/null +++ b/mmv/eval_ucf101.py @@ -0,0 +1,465 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
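+
+# Example invocation (see mmv/README.md; the paths are placeholders):
+#   python -m mmv.eval_ucf101 --checkpoint_path=<path_to_checkpoint> \
+#       --dataset_folder=<path_to_ucf101_dataset>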
+ +"""UCF101 linear evaluation.""" + +import functools +from typing import Any, Dict, Optional + +from absl import app +from absl import flags +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import sklearn +from sklearn import preprocessing +import sklearn.linear_model +import sklearn.svm +import tensorflow as tf +import tensorflow_datasets as tfds + +from mmv import config +from mmv.models import mm_embeddings +from mmv.utils import checkpoint +from mmv.utils import ucf101_dataset + + +flags.DEFINE_string('checkpoint_path', '~/tmp/mmv_s3d.pkl', + 'The directory to load pre-trained weights from.') +flags.DEFINE_string('dataset_folder', '/tmp/ucf101', + 'The directory with the ucf101 dataset.') + +flags.DEFINE_integer('eval_batch_size', 1, + 'The batch size for evaluation.') +flags.DEFINE_integer('train_batch_size', 16, + 'The batch size for training.') +flags.DEFINE_integer('num_train_epochs', 10, + 'How many epochs to collect features during training.') +flags.DEFINE_integer('num_test_windows', 10, + 'How many windows to average on during test.') +flags.DEFINE_integer('min_resize', 224, + 'Min value to resize images to during preprocessing.') +flags.DEFINE_integer('crop_size', 224, + 'Value to resize images to during preprocessing.') +flags.DEFINE_integer('num_frames', 32, + 'Number of video frames.') +flags.DEFINE_integer('stride', 2, + 'Stride for video frames.') +flags.DEFINE_integer('ucf101_split', 1, + 'Which split of ucf101 to use.') + + +FLAGS = flags.FLAGS + + +def get_sampling_offset(sequence: tf.Tensor, + num_steps: Optional[int], + is_training: bool, + stride: int = 1, + seed: Optional[int] = None) -> tf.Tensor: + """Calculates the initial offset for a sequence where all steps will fit. + + Args: + sequence: any tensor where the first dimension is timesteps. + num_steps: The number of timesteps we will output. If None, + deterministically start at the first frame. + is_training: A boolean indicates whether the graph is for training or not. + If False, the starting frame always the first frame. + stride: distance to sample between timesteps. + seed: a deterministic seed to use when sampling. + Returns: + The first index to begin sampling from. A best effort is made to provide a + starting index such that all requested steps fit within the sequence (i.e. + offset + 1 + (num_steps - 1) * stride < len(sequence)). If this is not + satisfied, the starting index is chosen randomly from the full sequence. + """ + if num_steps is None or not is_training: + return tf.constant(0) + sequence_length = tf.shape(sequence)[0] + max_offset = tf.cond( + tf.greater(sequence_length, (num_steps - 1) * stride), + lambda: sequence_length - (num_steps - 1) * stride, + lambda: sequence_length) + offset = tf.random.uniform( + (), + maxval=tf.cast(max_offset, tf.int32), + dtype=tf.int32, + seed=seed) + return offset + + +def sample_or_pad_sequence_indices(sequence: tf.Tensor, + num_steps: Optional[int], + is_training: bool, + repeat_sequence: bool = True, + stride: int = 1, + offset: Optional[int] = None) -> tf.Tensor: + """Returns indices to take for sampling or padding a sequence to fixed size. + + Samples num_steps from the sequence. If the sequence is shorter than + num_steps, the sequence loops. If the sequence is longer than num_steps and + is_training is True, then we seek to a random offset before sampling. If + offset is provided, we use that deterministic offset. 
+
+  This method is appropriate for sampling from a tensor where you want every
+  timestep between a start and end time. See sample_stacked_sequence_indices
+  for more flexibility.
+
+  Args:
+    sequence: Any tensor where the first dimension is timesteps.
+    num_steps: How many steps (e.g. frames) to take. If None, all steps from
+      start to end are considered and `is_training` has no effect.
+    is_training: A boolean indicating whether the graph is in training mode.
+      If False, the starting frame is deterministic.
+    repeat_sequence: A boolean indicating whether the sequence will repeat to
+      have enough steps for sampling. If False, a runtime error is thrown if
+      num_steps * stride is longer than the sequence length.
+    stride: Distance to sample between timesteps.
+    offset: A deterministic offset to use regardless of the is_training value.
+
+  Returns:
+    Indices to gather from the sequence Tensor to get a fixed size sequence.
+  """
+  sequence_length = tf.shape(sequence)[0]
+  sel_idx = tf.range(sequence_length)
+
+  if num_steps:
+    if offset is None:
+      offset = get_sampling_offset(sequence, num_steps, is_training, stride)
+
+    if repeat_sequence:
+      # Repeats the sequence until num_steps are available in total.
+      num_repeats = tf.cast(
+          tf.math.ceil(
+              tf.math.divide(
+                  tf.cast(num_steps * stride + offset, tf.float32),
+                  tf.cast(sequence_length, tf.float32)
+              )), tf.int32)
+      sel_idx = tf.tile(sel_idx, [num_repeats])
+    steps = tf.range(offset, offset + num_steps * stride, stride)
+  else:
+    steps = tf.range(0, sequence_length, stride)
+  return tf.gather(sel_idx, steps)
+
+
+def random_sample_sequence(sequence: tf.Tensor,
+                           num_steps: int,
+                           stride: int = 1) -> tf.Tensor:
+  """Randomly samples a segment of size num_steps from a given sequence."""
+
+  indices = sample_or_pad_sequence_indices(
+      sequence=sequence,
+      num_steps=num_steps,
+      is_training=True,  # Random sample.
+      repeat_sequence=True,  # Will repeat the sequence if more is requested.
+      stride=stride,
+      offset=None)
+  indices.set_shape((num_steps,))
+  output = tf.gather(sequence, indices)
+  return output
+
+
+def sample_linspace_sequence(sequence: tf.Tensor,
+                             num_windows: int,
+                             num_steps: int,
+                             stride: int = 1) -> tf.Tensor:
+  """Samples num_windows segments from sequence with linearly spaced offsets.
+
+  The samples are concatenated in a single Tensor in order to have the same
+  format structure per timestep (e.g. a single frame). If num_steps * stride
+  is bigger than the number of timesteps, the sequence is repeated. This
+  function can be used during evaluation to extract enough segments to span
+  the entire sequence.
+
+  Args:
+    sequence: Any tensor where the first dimension is timesteps.
+    num_windows: Number of windows retrieved from the sequence.
+    num_steps: Number of steps (e.g. frames) to take.
+    stride: Distance to sample between timesteps.
+
+  Returns:
+    A single Tensor with first dimension num_windows * num_steps. The Tensor
+    contains the concatenated list of num_windows tensors whose offsets have
+    been linearly spaced across the input.
+  """
+  sequence_length = tf.shape(sequence)[0]
+  max_offset = tf.maximum(0, sequence_length - num_steps * stride)
+  offsets = tf.linspace(0.0, tf.cast(max_offset, tf.float32), num_windows)
+  offsets = tf.cast(offsets, tf.int32)
+
+  all_indices = []
+  for i in range(num_windows):
+    all_indices.append(
+        sample_or_pad_sequence_indices(
+            sequence=sequence,
+            num_steps=num_steps,
+            is_training=False,
+            repeat_sequence=True,  # Will repeat the sequence if more is requested.
+            stride=stride,
+            offset=offsets[i]))
+
+  indices = tf.concat(all_indices, axis=0)
+  indices.set_shape((num_windows * num_steps,))
+  output = tf.gather(sequence, indices)
+
+  return output
+
+
+def resize_smallest(frames: tf.Tensor, min_resize: int) -> tf.Tensor:
+  """Resizes frames so that min(height, width) is equal to min_resize.
+
+  This function will not do anything if the min(height, width) is already
+  equal to min_resize, which saves compute time.
+
+  Args:
+    frames: A Tensor of dimension [timesteps, input_h, input_w, channels].
+    min_resize: Minimum size of the final image dimensions.
+  Returns:
+    A Tensor of shape [timesteps, output_h, output_w, channels] of type
+    frames.dtype where min(output_h, output_w) = min_resize.
+  """
+  shape = tf.shape(frames)
+  input_h = shape[1]
+  input_w = shape[2]
+
+  output_h = tf.maximum(min_resize, (input_h * min_resize) // input_w)
+  output_w = tf.maximum(min_resize, (input_w * min_resize) // input_h)
+
+  def resize_fn():
+    frames_resized = tf.image.resize(frames, (output_h, output_w))
+    return tf.cast(frames_resized, frames.dtype)
+
+  should_resize = tf.math.logical_or(tf.not_equal(input_w, output_w),
+                                     tf.not_equal(input_h, output_h))
+  frames = tf.cond(should_resize, resize_fn, lambda: frames)
+
+  return frames
+
+
+def process_samples(features_dict, num_frames=32, stride=1, is_training=True,
+                    min_resize=224, crop_size=224, num_windows=1):
+  """Processes video frames."""
+
+  video = features_dict['video']
+
+  if is_training:
+    assert num_windows == 1
+    video = random_sample_sequence(video, num_frames, stride)
+    is_flipped = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
+    video = tf.cond(tf.equal(is_flipped, 1),
+                    true_fn=lambda: tf.image.flip_left_right(video),
+                    false_fn=lambda: video)
+  else:
+    video = sample_linspace_sequence(video, num_windows, num_frames, stride)
+
+  # Resize the smallest side.
+  video = resize_smallest(video, min_resize)
+
+  if is_training:
+    # Random crop.
+    video = tf.image.random_crop(video, [num_frames, crop_size, crop_size, 3])
+  else:
+    # Central crop.
+    video = tf.image.resize_with_crop_or_pad(video, crop_size, crop_size)
+
+  video = tf.cast(video, tf.float32)
+  video /= 255.0  # Scale values into [0, 1].
+
+  features_dict['video'] = video
+  return features_dict
+
+
+def space_to_depth_batch(features_dict):
+  images = features_dict['video']
+  _, l, h, w, c = images.shape
+  images = tf.reshape(images, [-1, l // 2, 2, h // 2, 2, w // 2, 2, c])
+  images = tf.transpose(images, [0, 1, 3, 5, 2, 4, 6, 7])
+  images = tf.reshape(images, [-1, l // 2, h // 2, w // 2, 8 * c])
+  features_dict['video'] = images
+  return features_dict
+
+
+def reshape_windows(features_dict, num_frames):
+  x = features_dict['video']
+  x = tf.reshape(x, (-1, num_frames, x.shape[2], x.shape[3], x.shape[4]))
+  features_dict['video'] = x
+  return features_dict
+
+
+def compute_accuracy_metrics(pred, gt, prefix=''):
+  order_pred = np.argsort(pred, axis=1)
+  assert len(gt.shape) == len(order_pred.shape) == 2
+  top1_pred = order_pred[:, -1:]
+  top5_pred = order_pred[:, -5:]
+  top1_acc = np.mean(top1_pred == gt)
+  top5_acc = np.mean(np.max(top5_pred == gt, 1))
+  return {prefix + 'top1': top1_acc,
+          prefix + 'top5': top5_acc}
+
+
+def forward_fn(images: jnp.ndarray,
+               audio_spectrogram: jnp.ndarray,
+               word_ids: jnp.ndarray,
+               is_training: bool,
+               model_config: Dict[str, Any]):
+  """Forward pass of the model."""
+
+  # This embedding matrix is only a placeholder; the real pre-trained word
+  # embeddings are loaded from the checkpoint.
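+  # Only the video branch ('vid_repr') is returned below, so the text tower
+  # never reads these embeddings; the zero matrix merely fixes the shapes.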
+  language_model_vocab_size = 65536
+  word_embedding_dim = 300
+  dummy_embedding_matrix = jnp.zeros(shape=(language_model_vocab_size,
+                                            word_embedding_dim))
+
+  module = mm_embeddings.AudioTextVideoEmbedding(
+      **model_config,
+      word_embedding_matrix=dummy_embedding_matrix)
+  return module(images=images,
+                audio_spectrogram=audio_spectrogram,
+                word_ids=word_ids,
+                is_training=is_training)['vid_repr']
+
+
+def main(argv):
+  del argv
+
+  sklearn_reg = 0.001
+  model_config = config.get_model_config(FLAGS.checkpoint_path)
+
+  forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
+  forward_apply = jax.jit(functools.partial(forward.apply,
+                                            is_training=False,
+                                            model_config=model_config))
+
+  # Get the UCF101 config.
+  dset_config = tfds.video.ucf101.Ucf101.BUILDER_CONFIGS[FLAGS.ucf101_split]
+
+  builder = ucf101_dataset.ModUcf101(
+      data_dir=FLAGS.dataset_folder,
+      config=dset_config)
+  # Create the tfrecord files (a no-op if they already exist).
+  dl_config = tfds.download.DownloadConfig(verify_ssl=False)
+  builder.download_and_prepare(download_config=dl_config)
+
+  # Generate the training dataset.
+  train_ds = builder.as_dataset(split='train', shuffle_files=False)
+  train_ds = train_ds.map(lambda x: process_samples(  # pylint: disable=g-long-lambda
+      x, num_frames=FLAGS.num_frames, stride=FLAGS.stride, is_training=True,
+      min_resize=FLAGS.min_resize, crop_size=FLAGS.crop_size))
+  train_ds = train_ds.batch(batch_size=FLAGS.train_batch_size)
+  if model_config['visual_backbone'] == 's3d':
+    train_ds = train_ds.map(space_to_depth_batch)
+  train_ds = train_ds.repeat(FLAGS.num_train_epochs)
+
+  # Generate the test dataset.
+  test_ds = builder.as_dataset(split='test', shuffle_files=False)
+  test_ds = test_ds.map(lambda x: process_samples(  # pylint: disable=g-long-lambda
+      x, num_frames=FLAGS.num_frames, stride=FLAGS.stride, is_training=False,
+      min_resize=FLAGS.min_resize, crop_size=FLAGS.crop_size,
+      num_windows=FLAGS.num_test_windows))
+  test_ds = test_ds.batch(batch_size=FLAGS.eval_batch_size)
+  test_ds = test_ds.map(lambda x: reshape_windows(  # pylint: disable=g-long-lambda
+      x, num_frames=FLAGS.num_frames))
+
+  if model_config['visual_backbone'] == 's3d':
+    test_ds = test_ds.map(space_to_depth_batch)
+  test_ds = test_ds.repeat(1)
+
+  pretrained_weights = checkpoint.load_checkpoint(FLAGS.checkpoint_path)
+  params = pretrained_weights['params']
+  state = pretrained_weights['state']
+
+  # Collect training samples.
+  audio_frames = 96
+  mel_filters = 40
+  num_tokens = 16
+  dummy_audio = jnp.zeros(
+      shape=(FLAGS.train_batch_size, audio_frames, mel_filters, 1))
+  dummy_word_ids = jnp.zeros(
+      shape=(FLAGS.train_batch_size, num_tokens), dtype=jnp.int32)
+
+  train_features = []
+  train_labels = []
+  print('Computing features on train')
+  training_examples = iter(tfds.as_numpy(train_ds))
+  for train_ex in training_examples:
+    vid_representation, _ = forward_apply(params=params,
+                                          state=state,
+                                          images=train_ex['video'],
+                                          audio_spectrogram=dummy_audio,
+                                          word_ids=dummy_word_ids)
+    train_labels.append(train_ex['label'])
+    train_features.append(vid_representation)
+    if len(train_labels) % 50 == 0:
+      print(f'Processed {len(train_labels)} examples.')
+
+  train_labels = np.concatenate(train_labels, axis=0)
+  train_features = np.concatenate(train_features, axis=0)
+  print(f'Finished collecting train features of shape {train_features.shape}')
+
+  # Collect test samples.
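+  # Each test clip was expanded into FLAGS.num_test_windows linearly spaced
+  # windows by sample_linspace_sequence; their scores are averaged below.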
+  dummy_audio = jnp.zeros(
+      shape=(FLAGS.eval_batch_size, audio_frames, mel_filters, 1))
+  dummy_word_ids = jnp.zeros(
+      shape=(FLAGS.eval_batch_size, num_tokens), dtype=jnp.int32)
+
+  test_features = []
+  test_labels = []
+  print('Computing features on test')
+  test_examples = iter(tfds.as_numpy(test_ds))
+  for test_ex in test_examples:
+    vid_representation_test, _ = forward_apply(params=params,
+                                               state=state,
+                                               images=test_ex['video'],
+                                               audio_spectrogram=dummy_audio,
+                                               word_ids=dummy_word_ids)
+    test_labels.append(test_ex['label'])
+    test_features.append(vid_representation_test)
+    if len(test_labels) % 50 == 0:
+      print(f'Processed {len(test_labels)} examples.')
+
+  test_features = np.concatenate(test_features, axis=0)
+  test_labels = np.concatenate(test_labels, axis=0)
+  print(f'Finished collecting test features of shape {test_features.shape}')
+
+  # Train the classifier.
+  print('Training linear classifier!')
+  classifier = sklearn.svm.LinearSVC(C=sklearn_reg)
+  scaler = preprocessing.StandardScaler().fit(train_features)
+  train_features = scaler.transform(train_features)
+  classifier.fit(train_features, train_labels.ravel())
+  print('Training done!')
+
+  # Evaluation.
+  test_features = scaler.transform(test_features)
+  print('Running inference on train')
+  pred_train = classifier.decision_function(train_features)
+  print('Running inference on test')
+  pred_test = classifier.decision_function(test_features)
+  if FLAGS.num_test_windows > 1:
+    pred_test = np.reshape(
+        pred_test, (test_labels.shape[0], -1, pred_test.shape[1]))
+    pred_test = pred_test.mean(axis=1)
+
+  # Compute accuracies.
+  metrics = compute_accuracy_metrics(pred_train, train_labels[:, None],
+                                     prefix='train_')
+  metrics.update(
+      compute_accuracy_metrics(pred_test, test_labels[:, None], prefix='test_'))
+  print(metrics)
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/mmv/imgs/mmv_fig.png b/mmv/imgs/mmv_fig.png
new file mode 100644
index 0000000..6a5551f
Binary files /dev/null and b/mmv/imgs/mmv_fig.png differ
diff --git a/mmv/models/mm_embeddings.py b/mmv/models/mm_embeddings.py
new file mode 100644
index 0000000..0de0e15
--- /dev/null
+++ b/mmv/models/mm_embeddings.py
@@ -0,0 +1,519 @@
+# Copyright 2020 DeepMind Technologies Limited.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3.
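+# This module implements the multimodal embedding graphs used in the paper:
+# "shared", "disjoint", and the fine-and-coarse ("fac") variants.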
+"""Model for text-video-audio embeddings.""" + +from typing import Any, Dict + +import haiku as hk +import jax +import jax.numpy as jnp + +from mmv.models import normalization +from mmv.models import resnet +from mmv.models import s3d +from mmv.models import tsm_resnet + + +_DEFAULT_CFG_AUDTXT = { + "totxt_head_mode": "linear", + "toaud_head_mode": "linear", + "toaud_bn_after_proj": False, + "totxt_bn_after_proj": False, + "embedding_dim": 512} + +_DEFAULT_CFG_VIDAUD = { + "tovid_head_mode": "linear", + "toaud_head_mode": "mlp@512", + "tovid_bn_after_proj": False, + "toaud_bn_after_proj": True, + "embedding_dim": 512} + +_DEFAULT_CFG_VIDTXT = { + "tovid_head_mode": "linear", + "totxt_head_mode": "mlp@512", + "tovid_bn_after_proj": False, + "totxt_bn_after_proj": True, + "embedding_dim": 512} + +_DEFAULT_CFG_BN = {"decay_rate": 0.9, "eps": 1e-5, + "create_scale": True, "create_offset": True} + + +def _setkey_if_not_exists(d, key, value): + if key not in d: + d[key] = value + + +class AudioTextVideoEmbedding(hk.Module): + """Module to fuse audio, text and video for joint embedding learning.""" + + def __init__( + self, + # Language parameters. + word_embedding_matrix, + sentence_dim=2048, + # Audio parameters. + audio_backbone="resnet18", + audio_model_kwargs=None, + # Vision parameters. + visual_backbone="s3d", + vision_model_kwargs=None, + # Common parameters. + mm_embedding_graph="fac_relu", + use_xreplica_bn=True, + bn_config_proj=None, + config_video_text=None, + config_video_audio=None, + config_audio_text=None, + use_audio_text=False, + name="audio_text_video_model"): + """Initialize the AudioTextVideoEmbedding class. + + Args: + word_embedding_matrix: 2d matrix [vocab_size, embed_size] to embed words. + sentence_dim: The dimension of the sentence representation. + audio_backbone: Backbone for audio. + audio_model_kwargs: Other specific parameters to pass to the audio + module. + visual_backbone: The video backbone. + vision_model_kwargs: Other specific parameters to pass to the vision + module. + mm_embedding_graph: Embedding graph merging strategy. + Can be `shared`, `disjoint` or `fac` (fac can be followed by an + activation function name e.g. `fac_relu`). + use_xreplica_bn: Whether or not to use the cross replica batch norm. + bn_config_proj: BN config of the projection heads. + config_video_text: Config for the video and the text branches. + config_video_audio: Config for the video and the audio branches. + config_audio_text: Config for the audio and the text branches. + use_audio_text: Whether or not the audio text branch is used during + training. + name: graph name. + """ + super(AudioTextVideoEmbedding, self).__init__(name=name) + # Audio parameters. + self._audio_backbone = audio_backbone + self._audio_model_kwargs = audio_model_kwargs + + # Language parameters. + self._sentence_dim = sentence_dim + self._word_embedding_matrix = word_embedding_matrix + + # Vision parameters. + self._visual_backbone = visual_backbone + self._vision_model_kwargs = vision_model_kwargs + + # Joint parameters. + self._use_xreplica_bn = use_xreplica_bn + if self._use_xreplica_bn: + self._normalizer_name = "cross_replica_batch_norm" + else: + self._normalizer_name = "batch_norm" + + # Projection head parameters. 
+    if config_video_text is None:
+      config_video_text = _DEFAULT_CFG_VIDTXT
+    for k, v in _DEFAULT_CFG_VIDTXT.items():
+      _setkey_if_not_exists(config_video_text, k, v)
+    self._cfg_vid_txt = config_video_text
+
+    if config_video_audio is None:
+      config_video_audio = _DEFAULT_CFG_VIDAUD
+    for k, v in _DEFAULT_CFG_VIDAUD.items():
+      _setkey_if_not_exists(config_video_audio, k, v)
+    self._cfg_vid_aud = config_video_audio
+
+    if config_audio_text is None:
+      config_audio_text = _DEFAULT_CFG_AUDTXT
+    for k, v in _DEFAULT_CFG_AUDTXT.items():
+      _setkey_if_not_exists(config_audio_text, k, v)
+    self._cfg_aud_txt = config_audio_text
+    self._use_audio_text = use_audio_text
+
+    self._mm_embedding_graph = mm_embedding_graph
+    self._use_separate_heads = (
+        mm_embedding_graph == "disjoint" or
+        mm_embedding_graph.startswith("fac"))
+
+    self._bn_config_proj = bn_config_proj or _DEFAULT_CFG_BN
+
+  def _get_pair_embedding_heads(self,
+                                embedding_dim_1, embedding_dim_2,
+                                mode1, mode2,
+                                use_bn_out1, use_bn_out2,
+                                name1, name2):
+    embd1_module = EmbeddingModule(
+        embedding_dim_1,
+        mode=mode1,
+        use_bn_out=use_bn_out1,
+        bn_config=self._bn_config_proj,
+        use_xreplica_bn=self._use_xreplica_bn,
+        name=name1)
+    if self._use_separate_heads:
+      embd2_module = EmbeddingModule(
+          embedding_dim_2,
+          mode=mode2,
+          use_bn_out=use_bn_out2,
+          use_xreplica_bn=self._use_xreplica_bn,
+          bn_config=self._bn_config_proj,
+          name=name2)
+    else:
+      assert embedding_dim_1 == embedding_dim_2, (
+          "Using shared heads but inconsistent embedding dims were provided.")
+      assert mode1 == mode2, (
+          "Using shared heads but inconsistent modes were provided.")
+      assert use_bn_out1 == use_bn_out2, (
+          "Using shared heads but inconsistent bn configs were provided.")
+      embd2_module = embd1_module
+    return embd1_module, embd2_module
+
+  def _activate_interaction(self, inputs, activation_fn, is_training,
+                            activation_module=None):
+    """Activation function for the interaction modules."""
+    if activation_fn == "relu":
+      inputs = jax.nn.relu(inputs)
+    elif activation_fn == "bnrelu":
+      if activation_module is None:
+        activation_module = normalization.get_normalize_fn(
+            normalizer_name=self._normalizer_name,
+            normalizer_kwargs=self._bn_config_proj)
+      inputs = activation_module(inputs, is_training=is_training)
+      inputs = jax.nn.relu(inputs)
+    else:
+      raise ValueError(f"{activation_fn} not supported.")
+    return inputs, activation_module
+
+  def __call__(self,
+               images,
+               audio_spectrogram,
+               word_ids,
+               is_training,
+               return_intermediate_audio=False):
+    """Computes video, text and audio embeddings.
+
+    Args:
+      images: The videos tensor of shape [B1, T, H, W, 3] where B1 is the
+        batch size, T is the number of frames per clip, H the height, W the
+        width and 3 the rgb channels.
+      audio_spectrogram: The audio tensor of shape [B2, T', F] where B2 is
+        the batch size, T' is the number of temporal frames, F is the number
+        of frequency bins.
+      word_ids: The word indices, of expected shape [B3, N] where B3 is the
+        batch size and N the maximum number of words per sentence; the word
+        embeddings are computed from these indices within the model graph.
+      is_training: Whether or not to activate the graph in training mode.
+      return_intermediate_audio: Return audio intermediate representation.
+
+    Returns:
+      if return_intermediate_audio = True
+        audio_representation: the 4-dim audio representation taken before
+          averaging over spatial dims in the Resnet.
+      else
+        visual_embd: a dict containing the video embeddings in audio and text
+          of shape [B1, d_embd].
+        audio_embd: a dict containing the audio embeddings in video and text
+          of shape [B2, d_embd].
+        txt_embd: a dict containing the text embeddings in video and audio
+          of shape [B3, d_embd].
+        visual_representation: the video rep of shape [B1, d_visual].
+        audio_representation: the audio rep of shape [B2, d_audio].
+    """
+    # Computes the visual representation.
+    video_cnn = VisualModule(backbone=self._visual_backbone,
+                             use_xreplica_bn=self._use_xreplica_bn,
+                             model_kwargs=self._vision_model_kwargs)
+    visual_representation = video_cnn(images, is_training=is_training)
+
+    # Projection heads: Video -> Text and Video -> Audio.
+    vid2txt_embd_module, vid2aud_embd_module = self._get_pair_embedding_heads(
+        embedding_dim_1=self._cfg_vid_txt["embedding_dim"],
+        embedding_dim_2=self._cfg_vid_aud["embedding_dim"],
+        mode1=self._cfg_vid_txt["totxt_head_mode"],
+        mode2=self._cfg_vid_aud["toaud_head_mode"],
+        use_bn_out1=self._cfg_vid_txt["totxt_bn_after_proj"],
+        use_bn_out2=self._cfg_vid_aud["toaud_bn_after_proj"],
+        name1="vis_embd",
+        name2="vid2audio_embd")
+
+    video_embd = {}
+    if self._mm_embedding_graph in ["shared", "disjoint"]:
+      video_embd["toaud"] = vid2aud_embd_module(visual_representation,
+                                                is_training=is_training)
+      video_embd["totxt"] = vid2txt_embd_module(visual_representation,
+                                                is_training=is_training)
+    elif self._mm_embedding_graph.startswith("fac"):
+      # Activation function if specified in the name, e.g. fac_relu.
+      activation_fn = None
+      if len(self._mm_embedding_graph.split("_")) == 2:
+        activation_fn = self._mm_embedding_graph.split("_")[1]
+
+      video_embd["toaud"] = vid2aud_embd_module(visual_representation,
+                                                is_training=is_training)
+      fine_rep = video_embd["toaud"]
+      # Optionally activate the fine-grained representation.
+      if activation_fn:
+        fine_rep, activation_module = self._activate_interaction(
+            inputs=fine_rep, activation_fn=activation_fn,
+            is_training=is_training)
+
+      video_embd["totxt"] = vid2txt_embd_module(fine_rep,
+                                                is_training=is_training)
+    else:
+      raise ValueError(
+          f"{self._mm_embedding_graph} is not a valid MM embedding graph.")
+
+    # Computes the audio representation.
+    audio_cnn = AudioModule(backbone=self._audio_backbone,
+                            use_xreplica_bn=self._use_xreplica_bn,
+                            model_kwargs=self._audio_model_kwargs)
+    if return_intermediate_audio:
+      return audio_cnn(audio_spectrogram,
+                       is_training=is_training,
+                       return_intermediate=True)
+
+    audio_representation = audio_cnn(audio_spectrogram, is_training=is_training)
+
+    # Projection heads: Audio -> Video and Audio -> Text.
+    aud2vid_embd_module, aud2txt_embd_module = self._get_pair_embedding_heads(
+        embedding_dim_1=self._cfg_vid_aud["embedding_dim"],
+        embedding_dim_2=self._cfg_aud_txt["embedding_dim"],
+        mode1=self._cfg_vid_aud["tovid_head_mode"],
+        mode2=self._cfg_aud_txt["totxt_head_mode"],
+        use_bn_out1=self._cfg_vid_aud["tovid_bn_after_proj"],
+        use_bn_out2=self._cfg_aud_txt["totxt_bn_after_proj"],
+        name1="audio_embd",
+        name2="audio2txt_embd")
+    audio_embd = {}
+
+    audio_embd["tovid"] = aud2vid_embd_module(audio_representation,
+                                              is_training=is_training)
+
+    # Computes the projection to the text domain depending on the MM graph mode.
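+    # In the `fac` graph, audio reaches the text space through the
+    # video->text head, i.e. audio -> (shared audio-video space) -> text.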
+ if (self._mm_embedding_graph.startswith("fac") and + (self._use_audio_text or (not is_training))): + # In case the audio text branch is not used during training, we do that + # only at eval time (is_training=False) in order to not pollute the BN + # stats in vid2txt_embd_module with audio features during training. + fine_rep_audio = audio_embd["tovid"] + if activation_fn: + fine_rep_audio, _ = self._activate_interaction( + inputs=fine_rep_audio, activation_fn=activation_fn, + is_training=is_training, activation_module=activation_module) + audio_embd["totxt"] = vid2txt_embd_module(fine_rep_audio, + is_training=is_training) + else: + audio_embd["totxt"] = aud2txt_embd_module(audio_representation, + is_training=is_training) + + # Computes the text representation. + txt_representation = TextModule( + sentence_dim=self._sentence_dim, + word_embedding_matrix=self._word_embedding_matrix)( + word_ids, is_training=is_training) + + # Projection heads: Text -> Video and Text -> Audio. + txt2vid_embd_module, txt2aud_embd_module = self._get_pair_embedding_heads( + embedding_dim_1=self._cfg_vid_txt["embedding_dim"], + embedding_dim_2=self._cfg_aud_txt["embedding_dim"], + mode1=self._cfg_vid_txt["tovid_head_mode"], + mode2=self._cfg_aud_txt["toaud_head_mode"], + use_bn_out1=self._cfg_vid_txt["tovid_bn_after_proj"], + use_bn_out2=self._cfg_aud_txt["toaud_bn_after_proj"], + name1="txt_embd", + name2="txt2audio_embd") + txt_embd = {} + txt_embd["tovid"] = txt2vid_embd_module(txt_representation, + is_training=is_training) + txt_embd["toaud"] = txt2aud_embd_module(txt_representation, + is_training=is_training) + + return { + "vid_embd": video_embd, + "aud_embd": audio_embd, + "txt_embd": txt_embd, + "vid_repr": visual_representation, + "aud_repr": audio_representation, + } + + +class EmbeddingModule(hk.Module): + """Final Embedding module.""" + + def __init__(self, + embedding_dim: int, + mode: str = "linear", + use_bn_out: bool = False, + bn_config: Dict[str, Any] = None, + use_xreplica_bn: bool = True, + name="embedding_module"): + self._embedding_dim = embedding_dim + self._use_bn_out = use_bn_out + self._mode = mode + # Set default BN config. + bn_config = bn_config or _DEFAULT_CFG_BN + if use_xreplica_bn: + normalizer_name = "cross_replica_batch_norm" + else: + normalizer_name = "batch_norm" + self._batch_norm = normalization.get_normalize_fn( + normalizer_name=normalizer_name, + normalizer_kwargs=bn_config) + + super(EmbeddingModule, self).__init__(name=name) + + def __call__(self, input_feature, is_training): + if self._mode == "linear": + proj = hk.Linear(self._embedding_dim, name="final_projection") + embedding = proj(input_feature) + elif self._mode.startswith("mlp"): + if "@" not in self._mode: + raise ValueError( + ("Please specify the inner dimensions of the MLP with `@` symbol" + "e.g. mlp@512 or mlp@512@256 for a 2 layer MLP.")) + inner_dims = [int(dim) for dim in self._mode.split("@")[1:]] + embedding = input_feature + for inner_dim in inner_dims: + embedding = hk.Linear(inner_dim, with_bias=True, + name="final_projection_inner")(embedding) + if not self._mode.startswith("mlp_nobn"): + embedding = self._batch_norm(embedding, is_training=is_training) + embedding = jax.nn.relu(embedding) + + # Final projection. 
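+      # When a batch norm follows (`use_bn_out`), its learned offset makes a
+      # linear bias redundant, hence `with_bias=not self._use_bn_out` below.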
+ embedding = hk.Linear(self._embedding_dim, name="final_projection", + with_bias=not self._use_bn_out)(embedding) + else: + raise NotImplementedError + + if self._use_bn_out: + embedding = self._batch_norm(embedding, is_training=is_training) + return embedding + + +class VisualModule(hk.Module): + """The visual module selects which CNN backbone to connect to the graph.""" + + def __init__(self, + use_xreplica_bn=True, + backbone="s3d", + model_kwargs=None, + name="visual_module"): + self._backbone = backbone + super(VisualModule, self).__init__(name=name) + if model_kwargs is None: + model_kwargs = {} + bn_config = model_kwargs.get("bn_config", _DEFAULT_CFG_BN) + if use_xreplica_bn: + normalizer_name = "cross_replica_batch_norm" + else: + normalizer_name = "batch_norm" + + normalize_fn = normalization.get_normalize_fn( + normalizer_name=normalizer_name, + normalizer_kwargs=bn_config) + if backbone == "s3d": + self._cnn = s3d.S3D(normalize_fn=normalize_fn) + elif backbone == "resnet50tsm": + width_mult = model_kwargs.get("width_mult", 1) + self._cnn = tsm_resnet.TSMResNetV2( + normalize_fn=normalize_fn, + depth=50, + num_frames=model_kwargs["n_frames"], + width_mult=width_mult) + else: + raise NotImplementedError + + def __call__(self, images, is_training): + """Connects graph to images.""" + features = self._cnn(images, is_training=is_training) + return features + + +class AudioModule(hk.Module): + """The audio module selects which CNN backbone to connect to the graph.""" + + def __init__(self, + backbone="resnet18", + use_xreplica_bn=True, + model_kwargs=None, + name="audio_module"): + super(AudioModule, self).__init__(name=name) + model_kwargs = model_kwargs or {} + bn_config = model_kwargs.get("bn_config", _DEFAULT_CFG_BN) + backbone_to_depth = { + "resnet18": 18, + "resnet34": 34, + "resnet50": 50, + "resnet101": 101 + } + assert backbone in backbone_to_depth, ( + f"backbone should be in {backbone_to_depth.keys()}") + + if use_xreplica_bn: + normalizer_name = "cross_replica_batch_norm" + else: + normalizer_name = "batch_norm" + + self._cnn = resnet.ResNetV2( + depth=backbone_to_depth[backbone], + normalize_fn=normalization.get_normalize_fn( + normalizer_name=normalizer_name, + normalizer_kwargs=bn_config), + num_classes=None) + + def __call__(self, + audio_spectrogram, + is_training, + return_intermediate=False): + """Connects graph to audio spectrogram.""" + final_endpoint = "output" + if return_intermediate: + final_endpoint = "last_conv" + + return self._cnn(audio_spectrogram, + is_training=is_training, + final_endpoint=final_endpoint) + + +class TextModule(hk.Module): + """Text module computes the sentences representation.""" + + def __init__(self, + word_embedding_matrix, + sentence_dim=1024, + name="text_module"): + """Initialize text module. + + Args: + word_embedding_matrix: 2d matrix [vocab_size, embed_size] to embed words. + sentence_dim: dimension of sentence representation. + name: module name. 
+ """ + super(TextModule, self).__init__(name=name) + self._word_embedding_module = hk.Embed( + embedding_matrix=word_embedding_matrix) + self._conv1d_module = hk.Conv1D(sentence_dim, 1, name="text_conv1") + + def __call__(self, word_ids, is_training): + """Connects graph to sentence representation.""" + word_embeddings = self._word_embedding_module(word_ids) + word_embeddings = jax.lax.stop_gradient(word_embeddings) + output = self._conv1d_module(word_embeddings) + output = jax.nn.relu(output) + output = jnp.amax(output, axis=1) + return output diff --git a/mmv/models/normalization.py b/mmv/models/normalization.py new file mode 100644 index 0000000..639833f --- /dev/null +++ b/mmv/models/normalization.py @@ -0,0 +1,143 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Normalize functions constructors.""" + +from typing import Any, Dict, Optional, Sequence, Union + +import haiku as hk +from jax import numpy as jnp + +from mmv.models import types + + +class _BatchNorm(hk.BatchNorm): + """A `hk.BatchNorm` with adapted default arguments.""" + + def __init__(self, + create_scale: bool = True, + create_offset: bool = True, + decay_rate: float = 0.9, + eps: float = 1e-5, + test_local_stats: bool = False, + **kwargs): + # Check args. + if kwargs.get('cross_replica_axis', None) is not None: + raise ValueError( + 'Attempting to use \'batch_norm\' normalizer, but specifying ' + '`cross_replica_axis`. If you want this behavior use ' + '`normalizer=\'cross_replica_batch_norm\'` directly.') + + self._test_local_stats = test_local_stats + super().__init__(create_scale=create_scale, + create_offset=create_offset, + decay_rate=decay_rate, + eps=eps, + **kwargs) + + def __call__(self, + x: types.TensorLike, + is_training: bool) -> jnp.ndarray: + return super().__call__(x, is_training, + test_local_stats=self._test_local_stats) + + +class _CrossReplicaBatchNorm(hk.BatchNorm): + """A `hk.BatchNorm` with adapted default arguments for cross replica.""" + + def __init__(self, + create_scale: bool = True, + create_offset: bool = True, + decay_rate: float = 0.9, + eps: float = 1e-5, + test_local_stats: bool = False, + **kwargs): + # Check args. + if 'cross_replica_axis' in kwargs and kwargs['cross_replica_axis'] is None: + raise ValueError( + 'Attempting to use \'cross_replica_batch_norm\' normalizer, but ' + 'specifying `cross_replica_axis` to be None. 
If you want this ' + 'behavior use `normalizer=\'batch_norm\'` directly.') + + self._test_local_stats = test_local_stats + kwargs['cross_replica_axis'] = kwargs.get('cross_replica_axis', 'i') + super().__init__(create_scale=create_scale, + create_offset=create_offset, + decay_rate=decay_rate, + eps=eps, + **kwargs) + + def __call__(self, + x: types.TensorLike, + is_training: bool) -> jnp.ndarray: + return super().__call__(x, is_training, + test_local_stats=self._test_local_stats) + + +class _LayerNorm(hk.LayerNorm): + """A `hk.LayerNorm` accepting (and discarding) an `is_training` argument.""" + + def __init__(self, + axis: Union[int, Sequence[int]] = (1, 2), + create_scale: bool = True, + create_offset: bool = True, + **kwargs): + super().__init__(axis=axis, + create_scale=create_scale, + create_offset=create_offset, + **kwargs) + + def __call__(self, + x: types.TensorLike, + is_training: bool) -> jnp.ndarray: + del is_training # Unused. + return super().__call__(x) + + +_NORMALIZER_NAME_TO_CLASS = { + 'batch_norm': _BatchNorm, + 'cross_replica_batch_norm': _CrossReplicaBatchNorm, + 'layer_norm': _LayerNorm, +} + + +def get_normalize_fn( + normalizer_name: str = 'batch_norm', + normalizer_kwargs: Optional[Dict[str, Any]] = None, +) -> types.NormalizeFn: + """Handles NormalizeFn creation. + + These functions are expected to be used as part of Haiku model. On each + application of the returned normalization_fn, a new Haiku layer will be added + to the model. + + Args: + normalizer_name: The name of the normalizer to be constructed. + normalizer_kwargs: The kwargs passed to the normalizer constructor. + + Returns: + A `types.NormalizeFn` that when applied will create a new layer. + + Raises: + ValueError: If `normalizer_name` is unknown. + """ + # Check args. + if normalizer_name not in _NORMALIZER_NAME_TO_CLASS: + raise ValueError(f'Unrecognized `normalizer_name` {normalizer_name}.') + + normalizer_class = _NORMALIZER_NAME_TO_CLASS[normalizer_name] + normalizer_kwargs = normalizer_kwargs or dict() + + return lambda *a, **k: normalizer_class(**normalizer_kwargs)(*a, **k) # pylint: disable=unnecessary-lambda diff --git a/mmv/models/resnet.py b/mmv/models/resnet.py new file mode 100644 index 0000000..a395593 --- /dev/null +++ b/mmv/models/resnet.py @@ -0,0 +1,329 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3. +"""ResNet V2 modules. + + Equivalent to hk.Resnet except accepting a final_endpoint to return + intermediate activations. 
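+  Valid endpoints are listed in `ResNetV2.VALID_ENDPOINTS`.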
+""" + +from typing import Optional, Sequence, Text, Type, Union + +import haiku as hk +import jax +import jax.numpy as jnp + +from mmv.models import types + + +class BottleneckBlock(hk.Module): + """Implements a bottleneck residual block (ResNet50 and ResNet101).""" + + # pylint:disable=g-bare-generic + def __init__(self, + channels: int, + stride: Union[int, Sequence[int]], + use_projection: bool, + normalize_fn: Optional[types.NormalizeFn] = None, + name: Optional[Text] = None): + super(BottleneckBlock, self).__init__(name=name) + self._channels = channels + self._stride = stride + self._use_projection = use_projection + self._normalize_fn = normalize_fn + + if self._use_projection: + self._proj_conv = hk.Conv2D( + output_channels=channels, + kernel_shape=1, + stride=stride, + with_bias=False, + padding='SAME', + name='shortcut_conv') + + self._conv_0 = hk.Conv2D( + output_channels=channels // 4, + kernel_shape=1, + stride=1, + with_bias=False, + padding='SAME', + name='conv_0') + + self._conv_1 = hk.Conv2D( + output_channels=channels // 4, + kernel_shape=3, + stride=stride, + with_bias=False, + padding='SAME', + name='conv_1') + + self._conv_2 = hk.Conv2D( + output_channels=channels, + kernel_shape=1, + stride=1, + with_bias=False, + padding='SAME', + name='conv_2') + + def __call__(self, + inputs, + is_training): + net = inputs + shortcut = inputs + + for i, conv_i in enumerate([self._conv_0, self._conv_1, self._conv_2]): + if self._normalize_fn is not None: + net = self._normalize_fn(net, is_training=is_training) + net = jax.nn.relu(net) + if i == 0 and self._use_projection: + shortcut = self._proj_conv(net) + + # Now do the convs. + net = conv_i(net) + + return net + shortcut + + +class BasicBlock(hk.Module): + """Implements a basic residual block (ResNet18 and ResNet34).""" + + # pylint:disable=g-bare-generic + def __init__(self, + channels: int, + stride: Union[int, Sequence[int]], + use_projection: bool, + normalize_fn: Optional[types.NormalizeFn] = None, + name: Optional[Text] = None): + super(BasicBlock, self).__init__(name=name) + self._channels = channels + self._stride = stride + self._use_projection = use_projection + self._normalize_fn = normalize_fn + + if self._use_projection: + self._proj_conv = hk.Conv2D( + output_channels=channels, + kernel_shape=1, + stride=stride, + with_bias=False, + padding='SAME', + name='shortcut_conv') + + self._conv_0 = hk.Conv2D( + output_channels=channels, + kernel_shape=1, + stride=1, + with_bias=False, + padding='SAME', + name='conv_0') + + self._conv_1 = hk.Conv2D( + output_channels=channels, + kernel_shape=3, + stride=stride, + with_bias=False, + padding='SAME', + name='conv_1') + + def __call__(self, + inputs, + is_training): + net = inputs + shortcut = inputs + + for i, conv_i in enumerate([self._conv_0, self._conv_1]): + if self._normalize_fn is not None: + net = self._normalize_fn(net, is_training=is_training) + net = jax.nn.relu(net) + if i == 0 and self._use_projection: + shortcut = self._proj_conv(net) + + # Now do the convs. 
+      net = conv_i(net)
+
+    return net + shortcut
+
+
+class ResNetUnit(hk.Module):
+  """Unit (group of blocks) for ResNet."""
+
+  # pylint:disable=g-bare-generic
+  def __init__(self,
+               channels: int,
+               num_blocks: int,
+               stride: Union[int, Sequence[int]],
+               block_module: Type[BottleneckBlock],
+               normalize_fn: Optional[types.NormalizeFn] = None,
+               name: Optional[Text] = None,
+               remat: bool = False):
+    super(ResNetUnit, self).__init__(name=name)
+    self._channels = channels
+    self._num_blocks = num_blocks
+    self._stride = stride
+    self._normalize_fn = normalize_fn
+    self._block_module = block_module
+    self._remat = remat
+
+  def __call__(self,
+               inputs,
+               is_training):
+
+    input_channels = inputs.shape[-1]
+
+    self._blocks = []
+    for id_block in range(self._num_blocks):
+      use_projection = id_block == 0 and self._channels != input_channels
+      self._blocks.append(
+          self._block_module(
+              channels=self._channels,
+              stride=self._stride if id_block == 0 else 1,
+              use_projection=use_projection,
+              normalize_fn=self._normalize_fn,
+              name='block_%d' % id_block))
+
+    net = inputs
+    for block in self._blocks:
+      if self._remat:
+        # Note: we can ignore cell-var-from-loop because the lambda is evaluated
+        # inside every iteration of the loop. This is needed to work around the
+        # way variables are passed to jax.remat.
+        net = hk.remat(lambda x: block(x, is_training=is_training))(net)  # pylint: disable=cell-var-from-loop
+      else:
+        net = block(net, is_training=is_training)
+    return net
+
+
+class ResNetV2(hk.Module):
+  """ResNetV2 model."""
+
+  # Endpoints of the model in order.
+  VALID_ENDPOINTS = (
+      'resnet_stem',
+      'resnet_unit_0',
+      'resnet_unit_1',
+      'resnet_unit_2',
+      'resnet_unit_3',
+      'last_conv',
+      'output',
+  )
+
+  # pylint:disable=g-bare-generic
+  def __init__(self,
+               depth=50,
+               num_classes: Optional[int] = 1000,
+               width_mult: int = 1,
+               normalize_fn: Optional[types.NormalizeFn] = None,
+               name: Optional[Text] = None,
+               remat: bool = False):
+    """Creates a ResNetV2 Haiku module.
+
+    Args:
+      depth: Depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
+      num_classes: (int) Number of outputs in the final layer. If None, no
+        classification head is added and the output embedding is returned.
+      width_mult: Multiplier for the channel width.
+      normalize_fn: Normalization function, see helpers/utils.py.
+      name: Name of the module.
+      remat: Whether to rematerialize intermediate activations (saves memory).
+ """ + super(ResNetV2, self).__init__(name=name) + self._normalize_fn = normalize_fn + self._num_classes = num_classes + self._width_mult = width_mult + + self._strides = [1, 2, 2, 2] + num_blocks = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + 200: [3, 24, 36, 3], + } + if depth not in num_blocks: + raise ValueError( + f'`depth` should be in {list(num_blocks.keys())} ({depth} given).') + self._num_blocks = num_blocks[depth] + + if depth >= 50: + self._block_module = BottleneckBlock + self._channels = [256, 512, 1024, 2048] + else: + self._block_module = BasicBlock + self._channels = [64, 128, 256, 512] + + self._initial_conv = hk.Conv2D( + output_channels=64 * self._width_mult, + kernel_shape=7, + stride=2, + with_bias=False, + padding='SAME', + name='initial_conv') + + if remat: + self._initial_conv = hk.remat(self._initial_conv) + + self._block_groups = [] + for i in range(4): + self._block_groups.append( + ResNetUnit( + channels=self._channels[i] * self._width_mult, + num_blocks=self._num_blocks[i], + block_module=self._block_module, + stride=self._strides[i], + normalize_fn=self._normalize_fn, + name='block_group_%d' % i, + remat=remat)) + + if num_classes is not None: + self._logits_layer = hk.Linear( + output_size=num_classes, w_init=jnp.zeros, name='logits') + + def __call__(self, inputs, is_training, final_endpoint='output'): + self._final_endpoint = final_endpoint + net = self._initial_conv(inputs) + net = hk.max_pool( + net, window_shape=(1, 3, 3, 1), + strides=(1, 2, 2, 1), + padding='SAME') + end_point = 'resnet_stem' + if self._final_endpoint == end_point: + return net + + for i_group, block_group in enumerate(self._block_groups): + net = block_group(net, is_training=is_training) + end_point = f'resnet_unit_{i_group}' + if self._final_endpoint == end_point: + return net + + end_point = 'last_conv' + if self._final_endpoint == end_point: + return net + + if self._normalize_fn is not None: + net = self._normalize_fn(net, is_training=is_training) + net = jax.nn.relu(net) + + # The actual representation + net = jnp.mean(net, axis=[1, 2]) + + assert self._final_endpoint == 'output' + if self._num_classes is None: + # If num_classes was None, we just return the output + # of the last block, without fully connected layer. + return net + + return self._logits_layer(net) diff --git a/mmv/models/s3d.py b/mmv/models/s3d.py new file mode 100644 index 0000000..db32fbc --- /dev/null +++ b/mmv/models/s3d.py @@ -0,0 +1,503 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Haiku S3D model.""" + +import collections +from typing import Optional, Sequence + +import haiku as hk +import jax +from jax import numpy as jnp + +from mmv.models import types + + +class _MaxPool(hk.MaxPool): + """A `hk.MaxPool` accepting (and discarding) an `is_training` argument.""" + + def __call__(self, + x: types.TensorLike, + is_training: bool = True) -> jnp.ndarray: + del is_training # Unused. 
+    return super().__call__(x)
+
+
+def self_gating(inputs: types.TensorLike) -> jnp.ndarray:
+  """Feature gating as used in S3D-G.
+
+  Transforms the input features by aggregating features from all spatial and
+  temporal locations, and applying gating conditioned on the aggregated
+  features. More details can be found at: https://arxiv.org/abs/1712.04851.
+
+  Args:
+    inputs: A 5-D float array of shape `[B, T, H, W, C]`.
+
+  Returns:
+    A tensor with the same shape as `inputs`.
+
+  Raises:
+    ValueError: If `inputs` has the wrong shape.
+  """
+  if inputs.ndim != 5:
+    raise ValueError(
+        f'Expected an input of shape `[B, T, H, W, C]` but got {inputs.shape}.')
+
+  input_shape = inputs.shape
+  num_channels = input_shape[4]
+  spatiotemporal_average = jnp.mean(inputs, axis=(1, 2, 3))
+  weights = hk.Linear(num_channels, name='self_gating')(spatiotemporal_average)
+  weights = jax.nn.sigmoid(weights)
+  return jnp.multiply(weights[:, None, None, None, :], inputs)
+
+
+class SUnit3D(hk.Module):
+  """Base 3d Unit combining Conv3d + Batch Norm + non-linearity."""
+
+  def __init__(
+      self,
+      output_channels: int,
+      kernel_shape: Sequence[int] = (1, 1, 1),
+      stride: Sequence[int] = (1, 1, 1),
+      with_bias: bool = False,
+      separable: bool = False,
+      normalize_fn: Optional[types.NormalizeFn] = None,
+      activation_fn: Optional[types.ActivationFn] = jax.nn.relu,
+      self_gating_fn: Optional[types.GatingFn] = None,
+      name='SUnit3D'):
+    """Initializes the SUnit3D module.
+
+    Args:
+      output_channels: Number of output channels.
+      kernel_shape: The shape of the kernel. A sequence of length 3.
+      stride: Stride for the kernel. A sequence of length 3.
+      with_bias: Whether to add a bias to the convolution.
+      separable: Whether to use a separable convolution (a spatial-only
+        convolution followed by a temporal-only one).
+      normalize_fn: Function used for normalization.
+      activation_fn: Function used as non-linearity.
+      self_gating_fn: Function used for self-gating.
+      name: The name of the module.
+
+    Raises:
+      ValueError: If `kernel_shape` or `stride` has the wrong shape.
+    """
+    super().__init__(name=name)
+
+    # Check args.
+    if len(kernel_shape) != 3:
+      raise ValueError(
+          'Given `kernel_shape` must have length 3 but has length '
+          f'{len(kernel_shape)}.')
+    if len(stride) != 3:
+      raise ValueError(
+          f'Given `stride` must have length 3 but has length {len(stride)}.')
+
+    self._normalize_fn = normalize_fn
+    self._activation_fn = activation_fn
+    self._self_gating_fn = self_gating_fn
+
+    k0, k1, k2 = kernel_shape
+    if separable and k1 != 1:
+      spatial_kernel_shape = [1, k1, k2]
+      temporal_kernel_shape = [k0, 1, 1]
+      s0, s1, s2 = stride
+      spatial_stride = [1, s1, s2]
+      temporal_stride = [s0, 1, 1]
+      self._convolutions = [
+          hk.Conv3D(
+              output_channels=output_channels,
+              kernel_shape=spatial_kernel_shape,
+              stride=spatial_stride,
+              padding='SAME',
+              with_bias=with_bias),
+          hk.Conv3D(
+              output_channels=output_channels,
+              kernel_shape=temporal_kernel_shape,
+              stride=temporal_stride,
+              padding='SAME',
+              with_bias=with_bias)
+      ]
+
+    else:
+      self._convolutions = [
+          hk.Conv3D(
+              output_channels=output_channels,
+              kernel_shape=kernel_shape,
+              stride=stride,
+              padding='SAME',
+              with_bias=with_bias)]
+
+  def __call__(
+      self,
+      inputs: types.TensorLike,
+      is_training: bool) -> jnp.ndarray:
+    """Connects the module to inputs.
+
+    Args:
+      inputs: A 5-D float array of shape `[B, T, H, W, C]`.
+      is_training: Whether to use training mode.
+
+    Returns:
+      A 5-D float array of shape `[B, new_t, new_h, new_w, output_channels]`.
+ """ + x = inputs + for conv in self._convolutions: + x = conv(x) + if self._normalize_fn is not None: + x = self._normalize_fn(x, is_training=is_training) + if self._activation_fn is not None: + x = self._activation_fn(x) + if self._self_gating_fn: + x = self._self_gating_fn(x) + return x + + +class InceptionBlockV13D(hk.Module): + """A 3D Inception v1 block. + + This allows use of separable 3D convolutions and self-gating, as described in: + + Rethinking Spatiotemporal Feature Learning For Video Understanding. + Saining Xie, Chen Sun, Jonathan Huang, Zhuowen Tu and Kevin Murphy. + https://arxiv.org/abs/1712.04851. + """ + + def __init__(self, + output_channels: Sequence[int], + normalize_fn: Optional[types.NormalizeFn], + temporal_kernel_size: int = 3, + self_gating_fn: Optional[types.GatingFn] = None, + name: str = 'InceptionBlockV13D'): + """Initializes the InceptionBlockV13D module. + + Args: + output_channels: The size of the output channels of each block, ordered as + [Conv2d_0a_1x1, Conv2d_0a_1x1, Conv2d_0b_3x3, Conv2d_0a_1x1, + Conv2d_0b_3x3, Conv2d_0b_1x1] + normalize_fn: Function used for normalization. + temporal_kernel_size: The size of the temporal convolutional filters in + the conv3d_spatiotemporal blocks. + self_gating_fn: Function which optionally performs self-gating. If `None`, + no self-gating is applied. + name: The name of the module. + + Raises: + ValueError: If `output_channels` has the wrong shape. + """ + super().__init__(name=name) + + # Check args. + if len(output_channels) != 6: + raise ValueError( + 'Given `output_channels` must have length 6 but has length ' + f'{len(output_channels)}.') + + self._output_channels = output_channels + self._normalize_fn = normalize_fn + self._temporal_kernel_size = temporal_kernel_size + + if self_gating_fn is None: + self._self_gating_fn = lambda x: x + else: + self._self_gating_fn = self_gating_fn + + def __call__( + self, + inputs: types.TensorLike, + is_training: bool) -> jnp.ndarray: + """Connects the module to inputs. + + Args: + inputs: A 5-D float array of shape `[B, T, H, W, C]`. + is_training: Whether to use training mode. + + Returns: + A 5-D float array of shape + `[B, new_t, new_h, new_w, sum(output_channels)]`. 
+ """ + # Branch 0 + branch_0 = SUnit3D( + output_channels=self._output_channels[0], + kernel_shape=(1, 1, 1), + separable=False, + normalize_fn=self._normalize_fn, + self_gating_fn=self._self_gating_fn, + name='Branch_0_Conv2d_0a_1x1')( + inputs, is_training=is_training) + + # Branch 1 + branch_1 = SUnit3D( + output_channels=self._output_channels[1], + kernel_shape=(1, 1, 1), + separable=False, + normalize_fn=self._normalize_fn, + self_gating_fn=None, + name='Branch_1_Conv2d_0a_1x1')( + inputs, is_training=is_training) + branch_1 = SUnit3D( + output_channels=self._output_channels[2], + kernel_shape=(self._temporal_kernel_size, 3, 3), + separable=True, + normalize_fn=self._normalize_fn, + self_gating_fn=self._self_gating_fn, + name='Branch_1_Conv2d_0b_3x3')( + branch_1, is_training=is_training) + + # Branch 2 + branch_2 = SUnit3D( + output_channels=self._output_channels[3], + kernel_shape=(1, 1, 1), + separable=False, + normalize_fn=self._normalize_fn, + self_gating_fn=None, + name='Branch_2_Conv2d_0a_1x1')( + inputs, is_training=is_training) + branch_2 = SUnit3D( + output_channels=self._output_channels[4], + kernel_shape=(self._temporal_kernel_size, 3, 3), + separable=True, + normalize_fn=self._normalize_fn, + self_gating_fn=self._self_gating_fn, + name='Branch_2_Conv2d_0b_3x3')( + branch_2, is_training=is_training) + + # Branch 3 + branch_3 = hk.MaxPool( + window_shape=(1, 3, 3, 3, 1), + strides=(1, 1, 1, 1, 1), + padding='SAME', + name='Branch_3_MaxPool_0a_3x3')( + inputs) + branch_3 = SUnit3D( + output_channels=self._output_channels[5], + kernel_shape=(1, 1, 1), + separable=False, + normalize_fn=self._normalize_fn, + self_gating_fn=self._self_gating_fn, + name='Branch_3_Conv2d_0b_1x1')( + branch_3, is_training=is_training) + + return jnp.concatenate((branch_0, branch_1, branch_2, branch_3), axis=4) + + +_Layer = collections.namedtuple('_Layer', ('name', 'module', 'kwargs')) + + +class S3D(hk.Module): + """S3D architecture. + + Any intermediary representation can be obtained by choosing one of the valid + `final_endpoint`s. The final value returned by this model (when 'Embeddings' + is used as `final_endpoint`) is a single 1-D representation for each video in + the batch. Another layer can be externally added on top of that to obtain + logits. + """ + + # Endpoints of the model in order. + VALID_ENDPOINTS = ( + 'Conv2d_1a_7x7', + 'MaxPool_2a_3x3', + 'Conv2d_2b_1x1', + 'Conv2d_2c_3x3', + 'MaxPool_3a_3x3', + 'Mixed_3b', + 'Mixed_3c', + 'MaxPool_4a_3x3', + 'Mixed_4b', + 'Mixed_4c', + 'Mixed_4d', + 'Mixed_4e', + 'Mixed_4f', + 'MaxPool_5a_2x2', + 'Mixed_5b', + 'Mixed_5c', + 'Embeddings', + ) + + def __init__(self, + normalize_fn: Optional[types.NormalizeFn] = None, + first_temporal_kernel_size: int = 7, + temporal_conv_startat: Optional[str] = 'Conv2d_2c_3x3', + gating_startat: Optional[str] = 'Conv2d_2c_3x3', + name='S3D'): + """Initializes the S3D module. + + Args: + normalize_fn: Function used for normalization. + first_temporal_kernel_size: Specifies the temporal kernel size for the + first conv3d filter. A larger value slows down the model but provides + little accuracy improvement. Must be set to one of 1, 3, 5 or 7. + temporal_conv_startat: Specifies the first conv block to use separable 3D + convs rather than 2D convs (implemented as [1, k, k] 3D conv). This is + used to construct the inverted pyramid models. 'Conv2d_2c_3x3' is the + first valid block to use separable 3D convs. If provided block name is + not present, all valid blocks will use separable 3D convs. 
+      gating_startat: Specifies the first conv block to use self gating.
+        'Conv2d_2c_3x3' is the first valid block to use self gating. If the
+        provided block name is not present, all valid blocks will use self
+        gating.
+      name: The name of the module.
+
+    Raises:
+      ValueError: If `temporal_conv_startat`, `gating_startat` or
+        `first_temporal_kernel_size` is not recognized.
+    """
+    super().__init__(name=name)
+    self._first_temporal_kernel_size = first_temporal_kernel_size
+    self._temporal_conv_startat = temporal_conv_startat
+    self._gating_startat = gating_startat
+    self._normalize_fn = normalize_fn
+
+    if (temporal_conv_startat not in self.VALID_ENDPOINTS
+        and temporal_conv_startat is not None):
+      raise ValueError(
+          f'Provided `temporal_conv_startat`: {temporal_conv_startat} not '
+          f'valid. It must be one of: {self.VALID_ENDPOINTS}, or `None`.')
+
+    if (gating_startat not in self.VALID_ENDPOINTS
+        and gating_startat is not None):
+      raise ValueError(
+          f'Provided `gating_startat`: {gating_startat} not valid. '
+          f'It must be one of: {self.VALID_ENDPOINTS}, or `None`.')
+
+    if first_temporal_kernel_size not in [1, 3, 5, 7]:
+      raise ValueError('`first_temporal_kernel_size` can only be 1, 3, 5 or 7.')
+
+  def __call__(self,
+               inputs: types.TensorLike,
+               is_training: bool,
+               final_endpoint: str = 'Embeddings') -> jnp.ndarray:
+    """Connects the model to inputs.
+
+    Args:
+      inputs: A 5-D float array of shape `[B, T, H, W, C]`.
+      is_training: Whether to use training mode.
+      final_endpoint: Up to which endpoint to run / return.
+
+    Returns:
+      Network output at location `final_endpoint`. A float array whose shape
+      depends on `final_endpoint`.
+
+    Raises:
+      ValueError: If `final_endpoint` is not recognized.
+    """
+    if final_endpoint not in self.VALID_ENDPOINTS:
+      raise ValueError(f'Provided final_endpoint: {final_endpoint} not valid.'
+                       f' It must be one of: {self.VALID_ENDPOINTS}')
+
+    x = inputs
+
+    # We define layers with tuples (name, module, kwargs).
+    # Not all kwargs are present, as we will need to fill in certain properties
+    # as we move down the network.
+    layers = []
+
+    # The first layer is conditional on the input data shape: the channel size
+    # is used to identify whether the `space_to_depth` transformation has been
+    # applied to the input. This is used to speed up computation on TPUs.
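+    # For example, a [B, 32, 224, 224, 3] input becomes
+    # [B, 16, 112, 112, 24] after a 2x2x2 space-to-depth transform
+    # (3 * 2 * 2 * 2 = 24 channels), so any channel size other than 3
+    # selects the space-to-depth branch below.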
+    if x.shape[-1] == 3:
+      layers.append(
+          _Layer('Conv2d_1a_7x7', SUnit3D,
+                 dict(output_channels=64, stride=(2, 2, 2), separable=False,
+                      kernel_shape=(self._first_temporal_kernel_size, 7, 7),
+                      normalize_fn=self._normalize_fn)))
+    else:
+      layers.append(
+          _Layer('Conv2d_1a_7x7', SUnit3D,
+                 dict(output_channels=64, kernel_shape=(2, 4, 4),
+                      stride=(1, 1, 1), separable=False,
+                      normalize_fn=self._normalize_fn)))
+
+    layers.extend([
+        _Layer('MaxPool_2a_3x3', _MaxPool,
+               dict(window_shape=(1, 1, 3, 3, 1), strides=(1, 1, 2, 2, 1),
+                    padding='SAME')),
+        _Layer('Conv2d_2b_1x1', SUnit3D,
+               dict(output_channels=64, kernel_shape=(1, 1, 1),
+                    normalize_fn=self._normalize_fn)),
+        _Layer('Conv2d_2c_3x3', SUnit3D,
+               dict(output_channels=192, separable=True,
+                    normalize_fn=self._normalize_fn)),
+        _Layer('MaxPool_3a_3x3', _MaxPool,
+               dict(window_shape=(1, 1, 3, 3, 1), strides=(1, 1, 2, 2, 1),
+                    padding='SAME')),
+        _Layer('Mixed_3b', InceptionBlockV13D,
+               dict(output_channels=(64, 96, 128, 16, 32, 32),
+                    normalize_fn=self._normalize_fn)),
+        _Layer('Mixed_3c', InceptionBlockV13D,
+               dict(output_channels=(128, 128, 192, 32, 96, 64),
+                    normalize_fn=self._normalize_fn)),
+        _Layer('MaxPool_4a_3x3', _MaxPool,
+               dict(window_shape=(1, 3, 3, 3, 1), strides=(1, 2, 2, 2, 1),
+                    padding='SAME')),
+        _Layer('Mixed_4b', InceptionBlockV13D,
+               dict(output_channels=(192, 96, 208, 16, 48, 64),
+                    normalize_fn=self._normalize_fn)),
+        _Layer('Mixed_4c', InceptionBlockV13D,
+               dict(output_channels=(160, 112, 224, 24, 64, 64),
+                    normalize_fn=self._normalize_fn)),
+        _Layer('Mixed_4d', InceptionBlockV13D,
+               dict(output_channels=(128, 128, 256, 24, 64, 64),
+                    normalize_fn=self._normalize_fn)),
+        _Layer('Mixed_4e', InceptionBlockV13D,
+               dict(output_channels=(112, 144, 288, 32, 64, 64),
+                    normalize_fn=self._normalize_fn)),
+        _Layer('Mixed_4f', InceptionBlockV13D,
+               dict(output_channels=(256, 160, 320, 32, 128, 128),
+                    normalize_fn=self._normalize_fn)),
+        _Layer('MaxPool_5a_2x2', _MaxPool,
+               dict(window_shape=(1, 2, 2, 2, 1), strides=(1, 2, 2, 2, 1),
+                    padding='SAME')),
+        _Layer('Mixed_5b', InceptionBlockV13D,
+               dict(output_channels=(256, 160, 320, 32, 128, 128),
+                    normalize_fn=self._normalize_fn)),
+        _Layer('Mixed_5c', InceptionBlockV13D,
+               dict(output_channels=(384, 192, 384, 48, 128, 128),
+                    normalize_fn=self._normalize_fn)),
+    ])
+
+    # These parameters may change throughout the computation.
+    self_gating_fn = None
+    temporal_kernel_size = 1
+
+    # Iterate over layers.
+    for layer in layers:
+      # Update the self-gating and temporal-kernel settings once we reach the
+      # configured start blocks.
+      if layer.name == self._gating_startat:
+        self_gating_fn = self_gating
+      if layer.name == self._temporal_conv_startat:
+        temporal_kernel_size = 3
+
+      kwargs = layer.kwargs
+
+      if layer.module is SUnit3D:
+        kwargs['self_gating_fn'] = self_gating_fn
+        if 'kernel_shape' not in kwargs:
+          kwargs['kernel_shape'] = (temporal_kernel_size, 3, 3)
+
+      elif layer.module is InceptionBlockV13D:
+        kwargs['self_gating_fn'] = self_gating_fn
+        kwargs['temporal_kernel_size'] = temporal_kernel_size
+
+      module = layer.module(name=layer.name, **kwargs)
+      x = module(x, is_training=is_training)
+      if final_endpoint == layer.name:
+        return x
+
+    assert final_endpoint == 'Embeddings'
+    return jnp.mean(x, axis=(1, 2, 3))
diff --git a/mmv/models/s3d_test.py b/mmv/models/s3d_test.py
new file mode 100644
index 0000000..74a0756
--- /dev/null
+++ b/mmv/models/s3d_test.py
@@ -0,0 +1,88 @@
+# Copyright 2020 DeepMind Technologies Limited.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for s3d."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import haiku as hk
+import jax
+import numpy as np
+
+from mmv.models import normalization
+from mmv.models import s3d
+
+
+class _CallableS3D:
+  """Wrapper around S3D that takes care of parameter bookkeeping."""
+
+  def __init__(self, *args, **kwargs):
+    self._model = hk.transform_with_state(
+        lambda *a, **k:  # pylint: disable=g-long-lambda,unnecessary-lambda
+        s3d.S3D(
+            normalize_fn=normalization.get_normalize_fn(),
+            *args, **kwargs)(*a, **k))
+    self._rng = jax.random.PRNGKey(42)
+    self._params, self._state = None, None
+
+  def init(self, inputs, **kwargs):
+    self._params, self._state = self._model.init(
+        self._rng, inputs, is_training=True, **kwargs)
+
+  def __call__(self, inputs, **kwargs):
+    if self._params is None:
+      self.init(inputs)
+    output, _ = self._model.apply(
+        self._params, self._state, self._rng, inputs, **kwargs)
+    return output
+
+
+class S3DTest(parameterized.TestCase):
+
+  # Testing all layers is quite slow; the intermediate endpoints are kept
+  # commented out for completeness.
+  @parameterized.parameters(
+      # dict(endpoint='Conv2d_1a_7x7', expected_size=(2, 8, 112, 112, 64)),
+      # dict(endpoint='MaxPool_2a_3x3', expected_size=(2, 8, 56, 56, 64)),
+      # dict(endpoint='Conv2d_2b_1x1', expected_size=(2, 8, 56, 56, 64)),
+      # dict(endpoint='Conv2d_2c_3x3', expected_size=(2, 8, 56, 56, 192)),
+      # dict(endpoint='MaxPool_3a_3x3', expected_size=(2, 8, 28, 28, 192)),
+      # dict(endpoint='Mixed_3b', expected_size=(2, 8, 28, 28, 256)),
+      # dict(endpoint='Mixed_3c', expected_size=(2, 8, 28, 28, 480)),
+      # dict(endpoint='MaxPool_4a_3x3', expected_size=(2, 4, 14, 14, 480)),
+      # dict(endpoint='Mixed_4b', expected_size=(2, 4, 14, 14, 512)),
+      # dict(endpoint='Mixed_4c', expected_size=(2, 4, 14, 14, 512)),
+      # dict(endpoint='Mixed_4d', expected_size=(2, 4, 14, 14, 512)),
+      # dict(endpoint='Mixed_4e', expected_size=(2, 4, 14, 14, 528)),
+      # dict(endpoint='Mixed_4f', expected_size=(2, 4, 14, 14, 832)),
+      # dict(endpoint='MaxPool_5a_2x2', expected_size=(2, 2, 7, 7, 832)),
+      # dict(endpoint='Mixed_5b', expected_size=(2, 2, 7, 7, 832)),
+      # dict(endpoint='Mixed_5c', expected_size=(2, 2, 7, 7, 1024)),
+      dict(endpoint='Embeddings', expected_size=(2, 1024)),
+  )
+  def test_endpoint_expected_output_dimensions(self, endpoint, expected_size):
+    inputs = np.random.normal(size=(2, 16, 224, 224, 3))
+    model = _CallableS3D()
+    output = model(inputs, is_training=False, final_endpoint=endpoint)
+    self.assertSameElements(output.shape, expected_size)
+
+  def test_space_to_depth(self):
+    inputs = np.random.normal(size=(2, 16//2, 224//2, 224//2, 3*2*2*2))
+    model = _CallableS3D()
+    output = model(inputs, is_training=False, final_endpoint='Conv2d_1a_7x7')
+    self.assertSameElements(output.shape, (2, 8, 112, 112, 64))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/mmv/models/tsm_resnet.py b/mmv/models/tsm_resnet.py
new file mode 100644
index 0000000..572541f
--- /dev/null
+++ b/mmv/models/tsm_resnet.py
@@ -0,0 +1,353 @@
+# Copyright 2020 DeepMind Technologies Limited.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Temporal Shift Module w/ ResNet-50 and ResNet-101.
+
+Based on:
+  TSM: Temporal Shift Module for Efficient Video Understanding
+  Ji Lin, Chuang Gan, Song Han
+  https://arxiv.org/pdf/1811.08383.pdf.
+"""
+
+from typing import Optional
+
+import haiku as hk
+import jax
+import jax.numpy as jnp
+
+from mmv.models import tsm_utils as tsmu
+from mmv.models import types
+
+
+class TSMResNetBlock(hk.Module):
+  """A ResNet subblock with Temporal Channel Shifting.
+
+  Combines a typical ResNetV2 block implementation
+  (see https://arxiv.org/abs/1512.03385) with a pre-convolution Temporal
+  Shift Module (see https://arxiv.org/pdf/1811.08383.pdf) in the residual.
+  """
+
+  def __init__(self,
+               output_channels: int,
+               stride: int,
+               use_projection: bool,
+               tsm_mode: str,
+               normalize_fn: Optional[types.NormalizeFn] = None,
+               channel_shift_fraction: float = 0.125,
+               num_frames: int = 8,
+               name: str = 'TSMResNetBlock'):
+    """Initializes the TSMResNetBlock module.
+
+    Args:
+      output_channels: Number of output channels.
+      stride: Stride used in convolutions.
+      use_projection: Whether to use a projection for the shortcut.
+      tsm_mode: Mode for TSM ('gpu' or 'tpu').
+      normalize_fn: Function used for normalization.
+      channel_shift_fraction: The fraction of temporally shifted channels. If
+        `channel_shift_fraction` is 0, the block is the same as a normal ResNet
+        block.
+      num_frames: Size of the frame dimension in a single batch example.
+      name: The name of the module.
+    """
+    super().__init__(name=name)
+    self._output_channels = output_channels
+    self._bottleneck_channels = output_channels // 4
+    self._stride = stride
+    self._use_projection = use_projection
+    self._normalize_fn = normalize_fn
+    self._tsm_mode = tsm_mode
+    self._channel_shift_fraction = channel_shift_fraction
+    self._num_frames = num_frames
+
+  def __call__(self,
+               inputs: types.TensorLike,
+               is_training: bool = True) -> jnp.ndarray:
+    """Connects the ResNetBlock module into the graph.
+
+    Args:
+      inputs: A 4-D float array of shape `[B * num_frames, H, W, C]`.
+      is_training: Whether to use training mode.
+
+    Returns:
+      A 4-D float array of shape
+      `[B * num_frames, new_h, new_w, output_channels]`.
+    """
+    # ResNet V2 uses pre-activation, where the batch norm and relu are before
+    # convolutions, rather than after as in ResNet V1.
+    preact = inputs
+    if self._normalize_fn is not None:
+      preact = self._normalize_fn(preact, is_training=is_training)
+    preact = jax.nn.relu(preact)
+
+    if self._use_projection:
+      shortcut = hk.Conv2D(
+          output_channels=self._output_channels,
+          kernel_shape=1,
+          stride=self._stride,
+          with_bias=False,
+          padding='SAME',
+          name='shortcut_conv')(
+              preact)
+    else:
+      shortcut = inputs
+
+    # Optionally apply the Temporal Shift Module to the pre-activations.
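+    # With the default channel_shift_fraction of 0.125, one eighth of the
+    # channels is shifted one frame backward in time and one eighth one frame
+    # forward (see tsm_utils.apply_temporal_shift), which gives the block a
+    # temporal receptive field at no extra convolution cost.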
+ if self._channel_shift_fraction != 0: + preact = tsmu.apply_temporal_shift( + preact, tsm_mode=self._tsm_mode, num_frames=self._num_frames, + channel_shift_fraction=self._channel_shift_fraction) + + # First convolution. + residual = hk.Conv2D( + self._bottleneck_channels, + kernel_shape=1, + stride=1, + with_bias=False, + padding='SAME', + name='conv_0')( + preact) + + # Second convolution. + if self._normalize_fn is not None: + residual = self._normalize_fn(residual, is_training=is_training) + residual = jax.nn.relu(residual) + residual = hk.Conv2D( + output_channels=self._bottleneck_channels, + kernel_shape=3, + stride=self._stride, + with_bias=False, + padding='SAME', + name='conv_1')( + residual) + + # Third convolution. + if self._normalize_fn is not None: + residual = self._normalize_fn(residual, is_training=is_training) + residual = jax.nn.relu(residual) + residual = hk.Conv2D( + output_channels=self._output_channels, + kernel_shape=1, + stride=1, + with_bias=False, + padding='SAME', + name='conv_2')( + residual) + + # NOTE: we do not use block multiplier. + output = shortcut + residual + return output + + +class TSMResNetUnit(hk.Module): + """Block group for TSM ResNet.""" + + def __init__(self, + output_channels: int, + num_blocks: int, + stride: int, + tsm_mode: str, + num_frames: int, + normalize_fn: Optional[types.NormalizeFn] = None, + channel_shift_fraction: float = 0.125, + name: str = 'tsm_resnet_unit'): + """Creates a TSMResNet Unit. + + Args: + output_channels: Number of output channels. + num_blocks: Number of ResNet blocks in the unit. + stride: Stride of the unit. + tsm_mode: Which temporal shift module to use. + num_frames: Size of frame dimension in a single batch example. + normalize_fn: Function used for normalization. + channel_shift_fraction: The fraction of temporally shifted channels. If + `channel_shift_fraction` is 0, the block is the same as a normal ResNet + block. + name: The name of the module. + """ + super().__init__(name=name) + self._output_channels = output_channels + self._num_blocks = num_blocks + self._normalize_fn = normalize_fn + self._stride = stride + self._tsm_mode = tsm_mode + self._channel_shift_fraction = channel_shift_fraction + self._num_frames = num_frames + + def __call__(self, + inputs: types.TensorLike, + is_training: bool) -> jnp.ndarray: + """Connects the module to inputs. + + Args: + inputs: A 4-D float array of shape `[B * num_frames, H, W, C]`. + is_training: Whether to use training mode. + + Returns: + A 4-D float array of shape + `[B * num_frames, H // stride, W // stride, output_channels]`. + """ + net = inputs + for idx_block in range(self._num_blocks): + net = TSMResNetBlock( + self._output_channels, + stride=self._stride if idx_block == 0 else 1, + use_projection=idx_block == 0, + normalize_fn=self._normalize_fn, + tsm_mode=self._tsm_mode, + channel_shift_fraction=self._channel_shift_fraction, + num_frames=self._num_frames, + name=f'block_{idx_block}')( + net, is_training=is_training) + return net + + +class TSMResNetV2(hk.Module): + """TSM based on ResNet V2 as described in https://arxiv.org/abs/1603.05027.""" + + # Endpoints of the model in order. 
+  VALID_ENDPOINTS = (
+      'tsm_resnet_stem',
+      'tsm_resnet_unit_0',
+      'tsm_resnet_unit_1',
+      'tsm_resnet_unit_2',
+      'tsm_resnet_unit_3',
+      'last_conv',
+      'Embeddings',
+  )
+
+  def __init__(self,
+               normalize_fn: Optional[types.NormalizeFn] = None,
+               depth: int = 50,
+               num_frames: int = 16,
+               channel_shift_fraction: float = 0.125,
+               width_mult: int = 1,
+               name: str = 'TSMResNetV2'):
+    """Constructs a ResNet model.
+
+    Args:
+      normalize_fn: Function used for normalization.
+      depth: Depth of the desired ResNet.
+      num_frames: Number of frames (used in TPU mode).
+      channel_shift_fraction: Fraction of channels that are temporally shifted;
+        if `channel_shift_fraction` is 0, a regular ResNet is returned.
+      width_mult: Integer multiplier applied to the number of channels in
+        every layer (1 gives the standard network width).
+      name: The name of the module.
+
+    Raises:
+      ValueError: If `channel_shift_fraction` or `depth` has an invalid value.
+    """
+    super().__init__(name=name)
+
+    if not 0. <= channel_shift_fraction <= 1.0:
+      raise ValueError(
+          f'channel_shift_fraction ({channel_shift_fraction})'
+          ' has to be in [0, 1].')
+
+    self._num_frames = num_frames
+
+    self._channels = (256, 512, 1024, 2048)
+    self._strides = (1, 2, 2, 2)
+
+    num_blocks = {
+        50: (3, 4, 6, 3),
+        101: (3, 4, 23, 3),
+        152: (3, 8, 36, 3),
+        200: (3, 24, 36, 3),
+    }
+    if depth not in num_blocks:
+      raise ValueError(
+          f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
+    self._num_blocks = num_blocks[depth]
+
+    self._width_mult = width_mult
+    self._channel_shift_fraction = channel_shift_fraction
+    self._normalize_fn = normalize_fn
+
+  def __call__(
+      self,
+      inputs: types.TensorLike,
+      is_training: bool = True,
+      final_endpoint: str = 'Embeddings') -> jnp.ndarray:
+    """Connects the TSM ResNetV2 module into the graph.
+
+    Args:
+      inputs: A 4-D float array of shape `[B * T, H, W, C]` (TPU mode, where
+        the time dimension has been folded into the batch) or a 5-D float
+        array of shape `[B, T, H, W, C]` (GPU mode).
+      is_training: Whether to use training mode.
+      final_endpoint: Up to which endpoint to run / return.
+
+    Returns:
+      Network output at location `final_endpoint`. A float array whose shape
+      depends on `final_endpoint`.
+
+    Raises:
+      ValueError: If `final_endpoint` is not recognized.
+    """
+
+    # Prepare inputs for TSM.
+    inputs, tsm_mode, num_frames = tsmu.prepare_inputs(inputs)
+    num_frames = num_frames or self._num_frames
+
+    self._final_endpoint = final_endpoint
+    if self._final_endpoint not in self.VALID_ENDPOINTS:
+      raise ValueError(f'Unknown final endpoint {self._final_endpoint}')
+
+    # Stem convolution.
+    end_point = 'tsm_resnet_stem'
+    net = hk.Conv2D(
+        output_channels=64 * self._width_mult,
+        kernel_shape=7,
+        stride=2,
+        with_bias=False,
+        name=end_point,
+        padding='SAME')(
+            inputs)
+    net = hk.MaxPool(
+        window_shape=(1, 3, 3, 1),
+        strides=(1, 2, 2, 1),
+        padding='SAME')(
+            net)
+    if self._final_endpoint == end_point:
+      return net
+
+    # Residual block groups.
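+    # For depth=50 this builds four units with (3, 4, 6, 3) blocks and
+    # (256, 512, 1024, 2048) * width_mult output channels; the first unit
+    # keeps the spatial resolution and each subsequent unit halves it.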
+ for unit_id, (channels, num_blocks, stride) in enumerate( + zip(self._channels, self._num_blocks, self._strides)): + end_point = f'tsm_resnet_unit_{unit_id}' + net = TSMResNetUnit( + output_channels=channels * self._width_mult, + num_blocks=num_blocks, + stride=stride, + normalize_fn=self._normalize_fn, + channel_shift_fraction=self._channel_shift_fraction, + num_frames=num_frames, + tsm_mode=tsm_mode, + name=end_point)( + net, is_training=is_training) + if self._final_endpoint == end_point: + return net + + if self._normalize_fn is not None: + net = self._normalize_fn(net, is_training=is_training) + net = jax.nn.relu(net) + + end_point = 'last_conv' + if self._final_endpoint == end_point: + return net + net = jnp.mean(net, axis=(1, 2)) + # Prepare embedding outputs for TSM (temporal average of features). + net = tsmu.prepare_outputs(net, tsm_mode, num_frames) + assert self._final_endpoint == 'Embeddings' + return net diff --git a/mmv/models/tsm_resnet_test.py b/mmv/models/tsm_resnet_test.py new file mode 100644 index 0000000..6d47d08 --- /dev/null +++ b/mmv/models/tsm_resnet_test.py @@ -0,0 +1,65 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TSM ResNet model.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import haiku as hk +import jax +import jax.numpy as jnp + +from mmv.models import tsm_resnet + + +class TSMResNetTest(parameterized.TestCase): + + @parameterized.parameters( + ('tsm_resnet_stem', (2 * 32, 56, 56, 64)), + ('tsm_resnet_unit_0', (2 * 32, 56, 56, 256)), + ('tsm_resnet_unit_1', (2 * 32, 28, 28, 512)), + ('tsm_resnet_unit_2', (2 * 32, 14, 14, 1024)), + ('tsm_resnet_unit_3', (2 * 32, 7, 7, 2048)), + ('last_conv', (2 * 32, 7, 7, 2048)), + ('Embeddings', (2, 2048)), + ) + def test_output_dimension(self, final_endpoint, expected_shape): + input_shape = (2, 32, 224, 224, 3) + + def f(): + data = jnp.zeros(input_shape) + net = tsm_resnet.TSMResNetV2() + return net(data, final_endpoint=final_endpoint) + + init_fn, apply_fn = hk.transform(f) + out = apply_fn(init_fn(jax.random.PRNGKey(42)), None) + self.assertEqual(out.shape, expected_shape) + + def test_tpu_mode(self): + input_shape = (32 * 2, 224, 224, 3) + + def f(): + data = jnp.zeros(input_shape) + net = tsm_resnet.TSMResNetV2(num_frames=32) + return net(data, final_endpoint='Embeddings') + + init_fn, apply_fn = hk.transform(f) + out = apply_fn(init_fn(jax.random.PRNGKey(42)), None) + self.assertEqual(out.shape, (2, 2048)) + + +if __name__ == '__main__': + absltest.main() diff --git a/mmv/models/tsm_utils.py b/mmv/models/tsm_utils.py new file mode 100644 index 0000000..13531c3 --- /dev/null +++ b/mmv/models/tsm_utils.py @@ -0,0 +1,177 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils functions for TSM.""" + +from typing import Tuple + +import jax +import jax.numpy as jnp + +from mmv.models import types + + +def prepare_inputs( + inputs: types.TensorLike) -> Tuple[jnp.ndarray, str, int]: + """Deduces input mode for TSM.""" + # Deduce if we run on TPU based on input shape. + if len(inputs.shape) == 5: + # Input is given in the standard [B, T, H, W, 3] format. + tsm_mode = 'gpu' + num_frames = inputs.shape[1] + inputs = jnp.reshape(inputs, [-1] + list(inputs.shape[2:])) + else: + # Input is given in the [T * B, H, W, 3] format. + tsm_mode = 'tpu' + num_frames = None + return inputs, tsm_mode, num_frames + + +def prepare_outputs(outputs: types.TensorLike, + tsm_mode: str, + num_frames: int) -> jnp.ndarray: + """Processes output of TSM by averaging representations over time axis.""" + n_channels = outputs.shape[-1] + if tsm_mode == 'tpu': + outputs = jnp.reshape(outputs, [num_frames, -1, n_channels]) + outputs = jnp.mean(outputs, axis=0) + elif tsm_mode == 'gpu': + outputs = jnp.reshape(outputs, [-1, num_frames, n_channels]) + outputs = jnp.mean(outputs, axis=1) + else: + raise ValueError( + f'`tsm_mode` should be \'tpu\' or \'gpu\' ({tsm_mode} given)') + return outputs + + +def apply_temporal_shift( + x: types.TensorLike, + tsm_mode: str, + num_frames: int, + channel_shift_fraction: float = 0.125) -> jnp.ndarray: + """Performs a temporal shift: https://arxiv.org/abs/1811.08383 with mode.""" + if tsm_mode == 'tpu': + outputs = temporal_shift_tpu(x, num_frames, channel_shift_fraction) + elif tsm_mode == 'gpu': + outputs = temporal_shift_gpu(x, num_frames, channel_shift_fraction) + else: + raise ValueError( + f'`tsm_mode` should be \'tpu\' or \'gpu\' ({tsm_mode} given)') + return outputs + + +def temporal_shift_gpu( + x: types.TensorLike, + num_frames: int, + channel_shift_fraction: float = 0.125) -> jnp.ndarray: + """Performs a temporal shift: https://arxiv.org/abs/1811.08383.""" + # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels + # Input is (B * T, H, W, C) + orig_shp = tuple(x.shape) + reshaped_x = jnp.reshape(x, (-1, num_frames) + orig_shp[1:]) + n_channels = orig_shp[-1] + n_shift = int(n_channels * channel_shift_fraction) + + new_shp = tuple(reshaped_x.shape) + + # shifted_backward = reshaped_x[:, 1:, :, :, -n_shift:] + shifted_backward = jax.lax.slice( + reshaped_x, (0, 1, 0, 0, new_shp[4] - n_shift), + (new_shp[0], new_shp[1], new_shp[2], new_shp[3], new_shp[4])) + shifted_backward_padding = ((0, 0), (0, 1), (0, 0), (0, 0), (0, 0)) + shifted_backward = jnp.pad(shifted_backward, shifted_backward_padding) + + # shifted_forward = reshaped_x[:, :-1, :, :, :n_shift] + shifted_forward = jax.lax.slice( + reshaped_x, (0, 0, 0, 0, 0), + (new_shp[0], new_shp[1] - 1, new_shp[2], new_shp[3], n_shift)) + shifted_forward_padding = ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0)) + shifted_forward = jnp.pad(shifted_forward, shifted_forward_padding) + + no_shift = reshaped_x[:, :, :, :, n_shift:-n_shift] + shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward], + axis=4) + return jnp.reshape(shifted_x, (-1,) + orig_shp[1:]) + + 
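+# Illustrative note (added for exposition): with 8 channels and the default
+# channel_shift_fraction=0.125, n_shift = 1, so at each time step one output
+# channel is taken from the next frame and one from the previous frame
+# (zero-padded at the clip boundaries), while the remaining six channels pass
+# through unchanged. The channels end up reordered relative to the input,
+# which is harmless because the next convolution sees all channels.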
+def temporal_shift_tpu(
+    x: types.TensorLike,
+    num_frames: int,
+    channel_shift_fraction: float = 0.125) -> jnp.ndarray:
+  """Performs a temporal shift: https://arxiv.org/abs/1811.08383.
+
+  TPU optimized version of TSM. A reshape is avoided by keeping the images
+  laid out as [T * B, :], so that frames corresponding to the same time step
+  are contiguous in memory. Thanks to cr/288510308, which allows fusing
+  pad->slice into a convolution, we reformulate the slice of a pad as a pad
+  followed by a slice. Finally, to avoid a concatenate (which prevents some
+  fusions from happening), we simply sum masked versions of the features.
+
+  Args:
+    x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped
+      from a time major version of the input).
+    num_frames: number of frames T per video.
+    channel_shift_fraction: fraction of the channels to shift forward and
+      backward.
+
+  Returns:
+    The temporally shifted version of x.
+  """
+  # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels
+  # Input is (T * B, H, W, C)
+  original_shape = list(x.shape)
+
+  batch_size = int(original_shape[0] / num_frames)
+  n_channels = int(original_shape[-1])
+  n_shift = int(n_channels * channel_shift_fraction)
+
+  # Cast to bfloat16.
+  x = x.astype(jnp.bfloat16)
+
+  # For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1.
+  # Shift backward, we first pad by zeros [x1, x2, x3, 0, 0].
+  orig_shp = list(x.shape)
+
+  shifted_backward_padding = ((0, batch_size, 0), (0, 0, 0), (0, 0, 0),
+                              (0, n_channels - n_shift, 0))
+  x_backward_padding = jax.lax.pad(
+      x,
+      padding_value=jnp.bfloat16(0.),
+      padding_config=shifted_backward_padding)
+  # The following shift gets to [x3^+1, 0, 0] (where +1 means from the future).
+  shifted_backward = jax.lax.slice(x_backward_padding,
+                                   (batch_size, 0, 0, n_channels - n_shift),
+                                   (orig_shp[0] + batch_size, orig_shp[1],
+                                    orig_shp[2], 2 * n_channels - n_shift))
+  # Shift forward, we first pad by zeros [0, 0, x1, x2, x3].
+  shifted_forward_padding = ((batch_size, 0, 0), (0, 0, 0), (0, 0, 0),
+                             (n_channels - n_shift, 0, 0))
+  x_forward_padding = jax.lax.pad(
+      x,
+      padding_value=jnp.bfloat16(0.),
+      padding_config=shifted_forward_padding)
+  # The following shift gets to [0, 0, x1^-1] (where -1 means from the past).
+  shifted_forward = jax.lax.slice(
+      x_forward_padding, (0, 0, 0, 0),
+      (orig_shp[0], orig_shp[1], orig_shp[2], n_channels))
+  # No shift is in the middle, this gets [0, x2, 0].
+  mask_noshift = (jnp.reshape((jnp.arange(n_channels) >= n_shift) &
+                              (jnp.arange(n_channels) < n_channels - n_shift),
+                              (1, 1, 1, -1))).astype(jnp.bfloat16)
+  no_shift = mask_noshift * x
+  # By summing everything together, we end up with [x3^+1, x2, x1^-1].
+  # Note: channels have been reordered, but that doesn't matter for the model.
+  shifted_x = shifted_backward + shifted_forward + no_shift
+
+  return shifted_x.astype(jnp.float32)
diff --git a/mmv/models/tsm_utils_test.py b/mmv/models/tsm_utils_test.py
new file mode 100644
index 0000000..b070610
--- /dev/null
+++ b/mmv/models/tsm_utils_test.py
@@ -0,0 +1,60 @@
+# Copyright 2020 DeepMind Technologies Limited.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for tsm_utils.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import jax.numpy as jnp +import numpy as np + +from mmv.models import tsm_utils + + +class TsmUtilsTest(parameterized.TestCase): + + @parameterized.parameters( + ((2, 32, 224, 224, 3), 'gpu', (2 * 32, 224, 224, 3), 32), + ((32, 224, 224, 3), 'tpu', (32, 224, 224, 3), None), + ) + def test_prepare_inputs(self, input_shape, expected_mode, expected_shape, + expected_num_frames): + + data = jnp.zeros(input_shape) + out, mode, num_frames = tsm_utils.prepare_inputs(data) + self.assertEqual(out.shape, expected_shape) + self.assertEqual(mode, expected_mode) + self.assertEqual(num_frames, expected_num_frames) + + def test_prepare_outputs(self): + data = jnp.concatenate([jnp.zeros(4), jnp.ones(4)]).reshape(4, 2) + out_gpu = tsm_utils.prepare_outputs(data, 'gpu', 2) + out_tpu = tsm_utils.prepare_outputs(data, 'tpu', 2) + expected_gpu = np.concatenate([np.zeros(2), np.ones(2)]).reshape(2, 2) + expected_tpu = 0.5 * jnp.ones((2, 2)) + np.testing.assert_allclose(out_gpu, expected_gpu) + np.testing.assert_allclose(out_tpu, expected_tpu) + + def test_apply_tsm(self): + shape = (32, 224, 224, 16) + data = jnp.zeros(shape) + out_gpu = tsm_utils.apply_temporal_shift(data, 'gpu', 16) + out_tpu = tsm_utils.apply_temporal_shift(data, 'tpu', 16) + self.assertEqual(out_gpu.shape, shape) + self.assertEqual(out_tpu.shape, shape) + +if __name__ == '__main__': + absltest.main() diff --git a/mmv/models/types.py b/mmv/models/types.py new file mode 100644 index 0000000..bac7e52 --- /dev/null +++ b/mmv/models/types.py @@ -0,0 +1,36 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Type Aliases.""" + +from typing import Callable, Tuple, Union + +import jax.numpy as jnp +import numpy as np +import optax + +TensorLike = Union[np.ndarray, jnp.DeviceArray] + +ActivationFn = Callable[[TensorLike], TensorLike] +GatingFn = Callable[[TensorLike], TensorLike] +NetworkFn = Callable[[TensorLike], TensorLike] + +# Callable doesn't allow kwargs to be used, and we often want to +# pass in is_training=..., so ignore the arguments for the sake of pytype. 
+NormalizeFn = Callable[..., TensorLike] + +OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState] + + diff --git a/mmv/requirements.txt b/mmv/requirements.txt new file mode 100644 index 0000000..5be8814 --- /dev/null +++ b/mmv/requirements.txt @@ -0,0 +1,9 @@ +dm-haiku +dm-tree +jax +jaxlib +numpy>=1.16 +optax +sklearn +tensorflow +tensorflow_datasets diff --git a/mmv/utils/checkpoint.py b/mmv/utils/checkpoint.py new file mode 100644 index 0000000..e8ff30d --- /dev/null +++ b/mmv/utils/checkpoint.py @@ -0,0 +1,29 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint restoring utilities.""" + +from absl import logging +import dill + + +def load_checkpoint(checkpoint_path): + try: + with open(checkpoint_path, 'rb') as checkpoint_file: + checkpoint_data = dill.load(checkpoint_file) + logging.info('Loading checkpoint from %s', checkpoint_path) + return checkpoint_data + except FileNotFoundError: + return None diff --git a/mmv/utils/ucf101_dataset.py b/mmv/utils/ucf101_dataset.py new file mode 100644 index 0000000..41304b6 --- /dev/null +++ b/mmv/utils/ucf101_dataset.py @@ -0,0 +1,70 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ucf101 with custom decoding params.""" + +import tensorflow as tf +import tensorflow_datasets as tfds + +# Utilities functions. + +tf.compat.v1.enable_eager_execution() + +_CITATION = """\ +@article{DBLP:journals/corr/abs-1212-0402, + author = {Khurram Soomro and + Amir Roshan Zamir and + Mubarak Shah}, + title = {{UCF101:} {A} Dataset of 101 Human Actions Classes From Videos in + The Wild}, + journal = {CoRR}, + volume = {abs/1212.0402}, + year = {2012}, + url = {http://arxiv.org/abs/1212.0402}, + archivePrefix = {arXiv}, + eprint = {1212.0402}, + timestamp = {Mon, 13 Aug 2018 16:47:45 +0200}, + biburl = {https://dblp.org/rec/bib/journals/corr/abs-1212-0402}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +""" + +_LABELS_FNAME = 'video/ucf101_labels.txt' + + +class ModUcf101(tfds.video.Ucf101): + """Ucf101 action recognition dataset with better quality. 
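+
+  Compared to the default `tensorflow_datasets` UCF101 configuration, videos
+  are re-encoded at 25 fps with a higher JPEG quality (qscale 2) and
+  truncated to the first 20 seconds; see the ffmpeg arguments in `_info`.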
+ """ + + def _info(self): + + ffmpeg_extra_args = ('-qscale:v', '2', '-r', '25', '-t', '00:00:20') + + video_shape = ( + None, self.builder_config.height, self.builder_config.width, 3) + labels_names_file = tfds.core.tfds_path(_LABELS_FNAME) + features = tfds.features.FeaturesDict({ + 'video': tfds.features.Video(video_shape, + ffmpeg_extra_args=ffmpeg_extra_args, + encoding_format='jpeg'), + 'label': tfds.features.ClassLabel(names_file=labels_names_file), + }) + return tfds.core.DatasetInfo( + builder=self, + description='A 101-label video classification dataset.', + features=features, + homepage='https://www.crcv.ucf.edu/data-sets/ucf101/', + citation=_CITATION, + )