Files
deepmind-research/curl/layers.py
T
2020-01-15 18:20:53 +00:00

121 lines
3.8 KiB
Python

################################################################################
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Custom layers for CURL."""
from absl import logging
import sonnet as snt
import tensorflow.compat.v1 as tf
tfc = tf.compat.v1
class ResidualStack(snt.AbstractModule):
"""A stack of ResNet V2 blocks."""
def __init__(self,
num_hiddens,
num_residual_layers,
num_residual_hiddens,
filter_size=3,
initializers=None,
data_format='NHWC',
activation=tf.nn.relu,
name='residual_stack'):
"""Instantiate a ResidualStack."""
super(ResidualStack, self).__init__(name=name)
self._num_hiddens = num_hiddens
self._num_residual_layers = num_residual_layers
self._num_residual_hiddens = num_residual_hiddens
self._filter_size = filter_size
self._initializers = initializers
self._data_format = data_format
self._activation = activation
def _build(self, h):
for i in range(self._num_residual_layers):
h_i = self._activation(h)
h_i = snt.Conv2D(
output_channels=self._num_residual_hiddens,
kernel_shape=(self._filter_size, self._filter_size),
stride=(1, 1),
initializers=self._initializers,
data_format=self._data_format,
name='res_nxn_%d' % i)(
h_i)
h_i = self._activation(h_i)
h_i = snt.Conv2D(
output_channels=self._num_hiddens,
kernel_shape=(1, 1),
stride=(1, 1),
initializers=self._initializers,
data_format=self._data_format,
name='res_1x1_%d' % i)(
h_i)
h += h_i
return self._activation(h)
class SharedConvModule(snt.AbstractModule):
"""Convolutional decoder."""
def __init__(self,
filters,
kernel_size,
activation,
strides,
name='shared_conv_encoder'):
super(SharedConvModule, self).__init__(name=name)
self._filters = filters
self._kernel_size = kernel_size
self._activation = activation
self.strides = strides
assert len(strides) == len(filters) - 1
self.conv_shapes = None
def _build(self, x, is_training=True):
with tf.control_dependencies([tfc.assert_rank(x, 4)]):
self.conv_shapes = [x.shape.as_list()] # Needed by deconv module
conv = x
for i, (filter_i,
stride_i) in enumerate(zip(self._filters, self.strides), 1):
conv = tf.layers.Conv2D(
filters=filter_i,
kernel_size=self._kernel_size,
padding='same',
activation=self._activation,
strides=stride_i,
name='enc_conv_%d' % i)(
conv)
self.conv_shapes.append(conv.shape.as_list())
conv_flat = snt.BatchFlatten()(conv)
enc_mlp = snt.nets.MLP(
name='enc_mlp',
output_sizes=[self._filters[-1]],
activation=self._activation,
activate_final=True)
h = enc_mlp(conv_flat)
logging.info('Shared conv module layer shapes:')
logging.info('\n'.join([str(el) for el in self.conv_shapes]))
logging.info(h.shape.as_list())
return h