diff --git a/transporter/README.md b/transporter/README.md new file mode 100644 index 0000000..67d3cf0 --- /dev/null +++ b/transporter/README.md @@ -0,0 +1,25 @@ +# Transporter - Unsupervised Learning of Object Keypoints for Perception and Control + +This directory contains a [Sonnet](https://sonnet.dev) implementation of +the Transporter architecture. + +The Transporter is a neural network architecture for discovering concise +geometric object representations in terms of keypoints or image-space +coordinates. Our method learns from raw video frames in a fully unsupervised +manner, by transporting learnt image features between video frames using a +keypoint bottleneck. The discovered keypoints track objects and object parts +across long time-horizons more accurately than recent similar methods. + +For details, see our +paper [Unsupervised Learning of Object Keypoints for Perception and Control](https://arxiv.org/abs/1906.11883). + +## Contributors +* Tejas Kulkarni +* Ankush Gupta +* Catalin Ionescu +* Sebastian Borgeaud +* Malcolm Reynolds +* Andrew Zisserman +* Volodymyr Mnih + + diff --git a/transporter/requirements.txt b/transporter/requirements.txt new file mode 100644 index 0000000..48877b5 --- /dev/null +++ b/transporter/requirements.txt @@ -0,0 +1,22 @@ +absl-py==0.7.1 +astor==0.7.1 +contextlib2==0.5.5 +dm-sonnet==1.32 +gast==0.2.2 +grpcio==1.20.1 +h5py==2.9.0 +Keras-Applications==1.0.7 +Keras-Preprocessing==1.0.9 +Markdown==3.1 +mock==3.0.5 +numpy==1.16.3 +protobuf==3.7.1 +semantic-version==2.6.0 +six==1.12.0 +tensorboard==1.13.1 +tensorflow==1.13.1 +tensorflow-estimator==1.13.0 +tensorflow-probability==0.6.0 +termcolor==1.1.0 +Werkzeug==0.15.4 +wrapt==1.11.1 diff --git a/transporter/run.sh b/transporter/run.sh new file mode 100755 index 0000000..b015760 --- /dev/null +++ b/transporter/run.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# Copyright 2019 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +python3 -m venv transporter-venv +source transporter-venv/bin/activate +pip install -r requirements.txt +python3 transporter_test.py +deactivate diff --git a/transporter/transporter.py b/transporter/transporter.py new file mode 100644 index 0000000..22564e3 --- /dev/null +++ b/transporter/transporter.py @@ -0,0 +1,461 @@ +# Copyright 2019 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transporter architecture in Sonnet/TF 1: https://arxiv.org/abs/1906.11883.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import sonnet as snt +import tensorflow as tf + +nest = tf.contrib.framework.nest + +# Paper submission used BatchNorm, but we have since found that Layer & Instance +# norm can be quite a lot more stable. +_NORMALIZATION_CTORS = { + "layer": snt.LayerNorm, + "instance": functools.partial(snt.LayerNorm, axis=[1, 2]), + "batch": snt.BatchNormV2, +} + + +def _connect_module_with_kwarg_if_supported(module, + input_tensor, + kwarg_name, + kwarg_value): + """Connects a module to some input, plus a kwarg= if supported by module.""" + if snt.supports_kwargs(module, kwarg_name) == "supported": + kwargs = {kwarg_name: kwarg_value} + else: + kwargs = {} + return module(input_tensor, **kwargs) + + +class Transporter(snt.AbstractModule): + """Sonnet module implementing the Transporter architecture.""" + + def __init__( + self, + encoder, + keypointer, + decoder, + name="transporter"): + """Initialize the Transporter module. + + Args: + encoder: `snt.AbstractModule` mapping images to features (see `Encoder`) + keypointer: `snt.AbstractModule` mapping images to keypoint masks (see + `KeyPointer`) + decoder: `snt.AbstractModule` decoding features to images (see `Decoder`) + name: `str` module name + """ + super(Transporter, self).__init__(name=name) + + self._encoder = encoder + self._decoder = decoder + self._keypointer = keypointer + + def _build(self, image_a, image_b, is_training): + """Reconstructs image_b using feature transport from image_a. + + This approaches matches the NeurIPS submission. + + Args: + image_a: Tensor of shape [B, H, W, C] containing a batch of images. + image_b: Tensor of shape [B, H, W, C] containing a batch of images. + is_training: `bool` indication whether the model is in training mode. + + Returns: + A dict containing keys: + 'reconstructed_image_b': Reconstruction of image_b, with the same shape. + 'features_a': Tensor of shape [B, F_h, F_w, N] of the extracted features + for `image_a`. + 'features_b': Tensor of shape [B, F_h, F_w, N] of the extracted features + for `image_b`. + 'keypoints_a': The result of the keypointer module on image_a, with stop + gradients applied. + 'keypoints_b': The result of the keypointer module on image_b. + """ + # Process both images. All gradients related to image_a are stopped. + image_a_features = tf.stop_gradient( + self._encoder(image_a, is_training=is_training)) + image_a_keypoints = nest.map_structure( + tf.stop_gradient, self._keypointer(image_a, is_training=is_training)) + + image_b_features = self._encoder(image_b, is_training=is_training) + image_b_keypoints = self._keypointer(image_b, is_training=is_training) + + # Transport features + num_keypoints = image_a_keypoints["heatmaps"].shape[-1] + transported_features = image_a_features + for k in range(num_keypoints): + mask_a = image_a_keypoints["heatmaps"][Ellipsis, k, None] + mask_b = image_b_keypoints["heatmaps"][Ellipsis, k, None] + + # suppress features from image a, around both keypoint locations. + transported_features = ( + (1 - mask_a) * (1 - mask_b) * transported_features) + + # copy features from image b around keypoints for image b. + transported_features += (mask_b * image_b_features) + + reconstructed_image_b = self._decoder( + transported_features, is_training=is_training) + + return { + "reconstructed_image_b": reconstructed_image_b, + "features_a": image_a_features, + "features_b": image_b_features, + "keypoints_a": image_a_keypoints, + "keypoints_b": image_b_keypoints, + } + + +def reconstruction_loss(image, predicted_image, loss_type="l2"): + """Returns the reconstruction loss between the image and the predicted_image. + + Args: + image: target image tensor of shape [B, H, W, C] + predicted_image: reconstructed image as returned by the model + loss_type: `str` reconstruction loss, either `l2` (default) or `l1`. + + Returns: + The reconstruction loss + """ + + if loss_type == "l2": + return tf.reduce_mean(tf.square(image - predicted_image)) + elif loss_type == "l1": + return tf.reduce_mean(tf.abs(image - predicted_image)) + else: + raise ValueError("Unknown loss type: {}".format(loss_type)) + + +class Encoder(snt.AbstractModule): + """Encoder module mapping an image to features. + + The encoder is a standard convolutional network with ReLu activations. + """ + + def __init__( + self, + filters=(16, 16, 32, 32), + kernel_sizes=(7, 3, 3, 3), + strides=(1, 1, 2, 1), + norm_type="batch", + name="encoder"): + """Initialize the Encoder. + + Args: + filters: tuple of `int`. The ith layer of the encoder will + consist of `filters[i]` filters. + kernel_sizes: tuple of `int` kernel sizes for each layer + strides: tuple of `int` strides for each layer + norm_type: string, one of 'instance', 'layer', 'batch'. + name: `str` name of the module. + """ + super(Encoder, self).__init__(name=name) + if len({len(filters), len(kernel_sizes), len(strides)}) != 1: + raise ValueError( + "length of filters/kernel_sizes/strides lists must be the same") + self._filters = filters + self._kernels = kernel_sizes + self._strides = strides + self._norm_ctor = _NORMALIZATION_CTORS[norm_type] + + def _build(self, image, is_training): + """Connect the Encoder. + + Args: + image: A batch of images of shape [B, H, W, C] + is_training: `bool` indicating if the model is in training mode. + + Returns: + A tensor of features of shape [B, F_h, F_w, N] where F_h and F_w are the + height and width of the feature map and N = 4 * `self._filters` + """ + regularizers = {"w": tf.contrib.layers.l2_regularizer(1.0)} + + features = image + for l in range(len(self._filters)): + with tf.variable_scope("conv_{}".format(l + 1)): + conv = snt.Conv2D( + self._filters[l], + self._kernels[l], + self._strides[l], + padding=snt.SAME, + regularizers=regularizers, + name="conv_{}".format(l+1)) + norm_module = self._norm_ctor(name="normalization") + + features = conv(features) + features = _connect_module_with_kwarg_if_supported( + norm_module, features, "is_training", is_training) + features = tf.nn.relu(features) + + return features + + +class KeyPointer(snt.AbstractModule): + """Module for extracting keypoints from an image.""" + + def __init__(self, + num_keypoints, + gauss_std, + keypoint_encoder, + custom_getter=None, + name="key_pointer"): + """Iniitialize the keypointer. + + Args: + num_keypoints: `int` number of keypoints to extract + gauss_std: `float` size of the keypoints, relative to the image dimensions + normalized to the range [-1, 1] + keypoint_encoder: sonnet Module which produces a feature map. Must accept + an is_training kwarg. When used in the Transporter, the output spatial + resolution of this encoder should match the output spatial resolution + of the other encoder, although these two encoders should not share + weights. + custom_getter: optional custom getter for variables in this module. + name: `str` name of the module + """ + super(KeyPointer, self).__init__(name=name, custom_getter=custom_getter) + self._num_keypoints = num_keypoints + self._gauss_std = gauss_std + self._keypoint_encoder = keypoint_encoder + + def _build(self, image, is_training): + """Compute the gaussian keypoints for the image. + + Args: + image: Image tensor of shape [B, H, W, C] + is_training: `bool` whether the model is in training or evaluation mode + + Returns: + a dict with keys: + 'centers': A tensor of shape [B, K, 2] of the center locations for each + of the K keypoints. + 'heatmaps': A tensor of shape [B, F_h, F_w, K] of gaussian maps over the + keypoints, where [F_h, F_w] is the size of the keypoint_encoder + feature maps. + """ + conv = snt.Conv2D( + self._num_keypoints, [1, 1], + stride=1, + regularizers={"w": tf.contrib.layers.l2_regularizer(1.0)}, + name="conv_1/conv_1") + + image_features = self._keypoint_encoder(image, is_training=is_training) + keypoint_features = conv(image_features) + return get_keypoint_data_from_feature_map( + keypoint_features, self._gauss_std) + + +def get_keypoint_data_from_feature_map(feature_map, gauss_std): + """Returns keypoint information from a feature map. + + Args: + feature_map: [B, H, W, K] Tensor, should be activations from a convnet. + gauss_std: float, the standard deviation of the gaussians to be put around + the keypoints. + + Returns: + a dict with keys: + 'centers': A tensor of shape [B, K, 2] of the center locations for each + of the K keypoints. + 'heatmaps': A tensor of shape [B, H, W, K] of gaussian maps over the + keypoints. + """ + gauss_mu = _get_keypoint_mus(feature_map) + map_size = feature_map.shape.as_list()[1:3] + gauss_maps = _get_gaussian_maps(gauss_mu, map_size, 1.0 / gauss_std) + + return { + "centers": gauss_mu, + "heatmaps": gauss_maps, + } + + +def _get_keypoint_mus(keypoint_features): + """Returns the keypoint center points. + + Args: + keypoint_features: A tensor of shape [B, F_h, F_w, K] where K is the number + of keypoints to extract. + + Returns: + A tensor of shape [B, K, 2] of the y, x center points of each keypoint. Each + center point are in the range [-1, 1]^2. Note: the first element is the y + coordinate, the second is the x coordinate. + """ + gauss_y = _get_coord(keypoint_features, 1) + gauss_x = _get_coord(keypoint_features, 2) + gauss_mu = tf.stack([gauss_y, gauss_x], axis=2) + return gauss_mu + + +def _get_coord(features, axis): + """Returns the keypoint coordinate encoding for the given axis. + + Args: + features: A tensor of shape [B, F_h, F_w, K] where K is the number of + keypoints to extract. + axis: `int` which axis to extract the coordinate for. Has to be axis 1 or 2. + + Returns: + A tensor of shape [B, K] containing the keypoint centers along the given + axis. The location is given in the range [-1, 1]. + """ + if axis != 1 and axis != 2: + raise ValueError("Axis needs to be 1 or 2.") + + other_axis = 1 if axis == 2 else 2 + axis_size = features.shape[axis] + + # Compute the normalized weight for each row/column along the axis + g_c_prob = tf.reduce_mean(features, axis=other_axis) + g_c_prob = tf.nn.softmax(g_c_prob, axis=1) + + # Linear combination of the interval [-1, 1] using the normalized weights to + # give a single coordinate in the same interval [-1, 1] + scale = tf.cast(tf.linspace(-1.0, 1.0, axis_size), tf.float32) + scale = tf.reshape(scale, [1, axis_size, 1]) + coordinate = tf.reduce_sum(g_c_prob * scale, axis=1) + return coordinate + + +def _get_gaussian_maps(mu, map_size, inv_std, power=2): + """Transforms the keypoint center points to a gaussian masks.""" + mu_y, mu_x = mu[:, :, 0:1], mu[:, :, 1:2] + + y = tf.cast(tf.linspace(-1.0, 1.0, map_size[0]), tf.float32) + x = tf.cast(tf.linspace(-1.0, 1.0, map_size[1]), tf.float32) + + mu_y, mu_x = tf.expand_dims(mu_y, -1), tf.expand_dims(mu_x, -1) + + y = tf.reshape(y, [1, 1, map_size[0], 1]) + x = tf.reshape(x, [1, 1, 1, map_size[1]]) + + g_y = tf.pow(y - mu_y, power) + g_x = tf.pow(x - mu_x, power) + dist = (g_y + g_x) * tf.pow(inv_std, power) + g_yx = tf.exp(-dist) + + g_yx = tf.transpose(g_yx, perm=[0, 2, 3, 1]) + return g_yx + + +class Decoder(snt.AbstractModule): + """Decoder reconstruction network. + + The decoder is a standard convolutional network with ReLu activations. + """ + + def __init__(self, initial_filters, output_size, + output_channels=3, + norm_type="batch", + name="decoder"): + """Initialize the decoder. + + Args: + initial_filters: `int` number of initial filters used in the decoder + output_size: tuple of `int` height and width of the reconstructed image + output_channels: `int` number of output channels, for RGB use 3 (default) + norm_type: string, one of 'instance', 'layer', 'batch'. + name: `str` name of the module + """ + super(Decoder, self).__init__(name=name) + self._initial_filters = initial_filters + self._output_height = output_size[0] + self._output_width = output_size[1] + self._output_channels = output_channels + self._norm_ctor = _NORMALIZATION_CTORS[norm_type] + + def _build(self, features, is_training): + """Connect the Decoder. + + Args: + features: Tensor of shape [B, F_h, F_w, N] + is_training: `bool` whether the module is in training mode. + + Returns: + A reconstructed image tensor of shape [B, output_height, output_width, + output_channels] + """ + height, width = features.shape.as_list()[1:3] + + filters = self._initial_filters + regularizers = {"w": tf.contrib.layers.l2_regularizer(1.0)} + + layer = 0 + + while height <= self._output_height: + layer += 1 + with tf.variable_scope("conv_{}".format(layer)): + conv1 = snt.Conv2D( + filters, + [3, 3], + stride=1, + regularizers=regularizers, + name="conv_{}".format(layer)) + norm_module = self._norm_ctor(name="normalization") + + features = conv1(features) + features = _connect_module_with_kwarg_if_supported( + norm_module, features, "is_training", is_training) + features = tf.nn.relu(features) + + if height == self._output_height: + layer += 1 + with tf.variable_scope("conv_{}".format(layer)): + conv2 = snt.Conv2D( + self._output_channels, + [3, 3], + stride=1, + regularizers=regularizers, + name="conv_{}".format(layer)) + features = conv2(features) + break + else: + layer += 1 + with tf.variable_scope("conv_{}".format(layer)): + conv2 = snt.Conv2D( + filters, + [3, 3], + stride=1, + regularizers=regularizers, + name="conv_{}".format(layer)) + norm_module = self._norm_ctor(name="normalization") + + features = conv2(features) + features = _connect_module_with_kwarg_if_supported( + norm_module, features, "is_training", is_training) + features = tf.nn.relu(features) + + height *= 2 + width *= 2 + features = tf.image.resize(features, [height, width]) + + if filters >= 8: + filters /= 2 + + assert height == self._output_height + assert width == self._output_width + + return features + diff --git a/transporter/transporter_test.py b/transporter/transporter_test.py new file mode 100644 index 0000000..c18ef6c --- /dev/null +++ b/transporter/transporter_test.py @@ -0,0 +1,192 @@ +# Copyright 2019 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for the Transporter module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import tensorflow as tf +from transporter import transporter + + +IMAGE_H = 16 +IMAGE_W = 16 +IMAGE_C = 3 +BATCH_SIZE = 4 +IMAGE_BATCH_SHAPE = (BATCH_SIZE, IMAGE_H, IMAGE_W, IMAGE_C) + +FILTERS = (16, 16, 32, 32, 64, 64) +STRIDES = (1, 1, 2, 1, 2, 1) +KERNEL_SIZES = (7, 3, 3, 3, 3, 3) + + +class TransporterTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.parameters( + {'norm_type': 'batch'}, + {'norm_type': 'layer'}, + {'norm_type': 'instance'}) + def test_output_shape(self, norm_type): + encoder_ctor = transporter.Encoder + encoder_kwargs = { + 'filters': FILTERS, + 'strides': STRIDES, + 'kernel_sizes': KERNEL_SIZES, + 'norm_type': norm_type, + } + decoder_filters = 4 + num_keypoints = 5 + gauss_std = 0.1 + + encoder = encoder_ctor(name='encoder', **encoder_kwargs) + keypoint_encoder = encoder_ctor(name='keypoint_encoder', **encoder_kwargs) + keypointer = transporter.KeyPointer(keypoint_encoder=keypoint_encoder, + num_keypoints=num_keypoints, + gauss_std=gauss_std) + + decoder = transporter.Decoder(initial_filters=decoder_filters, + output_size=[IMAGE_H, IMAGE_W], + output_channels=IMAGE_C, + norm_type=norm_type) + model = transporter.Transporter(encoder=encoder, + decoder=decoder, + keypointer=keypointer) + + image_a = tf.random.normal(IMAGE_BATCH_SHAPE) + image_b = tf.random.normal(IMAGE_BATCH_SHAPE) + + transporter_results = model(image_a, image_b, is_training=True) + reconstructed_image_b = transporter_results['reconstructed_image_b'] + + self.assertEqual(reconstructed_image_b.shape, IMAGE_BATCH_SHAPE) + + def testIncorrectEncoderShapes(self): + """Test that a possible misconfiguration throws an error as expected. + + If the two encoders used produce different spatial sizes for their + feature maps, this should cause an error when multiplying tensors together. + """ + decoder_filters = 4 + num_keypoints = 5 + gauss_std = 0.1 + + encoder = transporter.Encoder( + filters=FILTERS, + strides=STRIDES, + kernel_sizes=KERNEL_SIZES) + # Use less conv layers in this, in particular one less stride 2 layer, so + # we will get a different spatial output resolution. + keypoint_encoder = transporter.Encoder( + filters=FILTERS[:-2], + strides=STRIDES[:-2], + kernel_sizes=KERNEL_SIZES[:-2]) + + keypointer = transporter.KeyPointer( + keypoint_encoder=keypoint_encoder, + num_keypoints=num_keypoints, + gauss_std=gauss_std) + + decoder = transporter.Decoder( + initial_filters=decoder_filters, + output_size=[IMAGE_H, IMAGE_W], + output_channels=IMAGE_C) + model = transporter.Transporter( + encoder=encoder, + decoder=decoder, + keypointer=keypointer) + + with self.assertRaisesRegexp(ValueError, 'Dimensions must be equal'): + model(tf.random.normal(IMAGE_BATCH_SHAPE), + tf.random.normal(IMAGE_BATCH_SHAPE), + is_training=True) + + +class EncoderTest(tf.test.TestCase): + + def test_output_shape(self): + image_batch = tf.random.normal(shape=IMAGE_BATCH_SHAPE) + + filters = (4, 4, 8, 8, 16, 16) + encoder = transporter.Encoder(filters=filters, + strides=STRIDES, + kernel_sizes=KERNEL_SIZES) + + features = encoder(image_batch, is_training=True) + + self.assertEqual(features.shape, (BATCH_SIZE, + IMAGE_H // 4, + IMAGE_W // 4, + filters[-1])) + + +class KeyPointerTest(tf.test.TestCase): + + def test_output_shape(self): + image_batch = tf.random.normal(shape=IMAGE_BATCH_SHAPE) + num_keypoints = 6 + gauss_std = 0.1 + + keypoint_encoder = transporter.Encoder(filters=FILTERS, + strides=STRIDES, + kernel_sizes=KERNEL_SIZES) + keypointer = transporter.KeyPointer(keypoint_encoder=keypoint_encoder, + num_keypoints=num_keypoints, + gauss_std=gauss_std) + + keypointer_results = keypointer(image_batch, is_training=True) + + self.assertEqual(keypointer_results['centers'].shape, + (BATCH_SIZE, num_keypoints, 2)) + self.assertEqual(keypointer_results['heatmaps'].shape, + (BATCH_SIZE, IMAGE_H // 4, IMAGE_W // 4, num_keypoints)) + + +class DecoderTest(tf.test.TestCase): + + def test_output_shape(self): + feature_batch = tf.random.normal(shape=(BATCH_SIZE, + IMAGE_H // 4, + IMAGE_W // 4, + 64)) + + decoder = transporter.Decoder(initial_filters=64, + output_size=[IMAGE_H, IMAGE_W], + output_channels=IMAGE_C) + + reconstructed_image_batch = decoder(feature_batch, is_training=True) + + self.assertEqual(reconstructed_image_batch.shape, IMAGE_BATCH_SHAPE) + + def test_encoder_decoder_output_shape(self): + image_batch = tf.random.normal(shape=IMAGE_BATCH_SHAPE) + + encoder = transporter.Encoder(filters=FILTERS, + strides=STRIDES, + kernel_sizes=KERNEL_SIZES) + decoder = transporter.Decoder(initial_filters=4, + output_size=[IMAGE_H, IMAGE_W], + output_channels=IMAGE_C) + + features = encoder(image_batch, is_training=True) + reconstructed_images = decoder(features, is_training=True) + + self.assertEqual(reconstructed_images.shape, IMAGE_BATCH_SHAPE) + + +if __name__ == '__main__': + tf.test.main() +