################################################################################
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Implementation of Continual Unsupervised Representation Learning model."""
from absl import logging
import numpy as np
import sonnet as snt
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from curl import layers
from curl import utils
tfc = tf.compat.v1
# pylint: disable=g-long-lambda
# pylint: disable=redefined-outer-name
class SharedEncoder(snt.AbstractModule):
  """The shared encoder module, mapping input x to hiddens."""

  def __init__(self, encoder_type, n_enc, enc_strides, name='shared_encoder'):
    """The shared encoder function, mapping input x to hiddens.

    Args:
      encoder_type: str, type of encoder, either 'conv' or 'multi'.
      n_enc: list, number of hidden units per layer in the encoder.
      enc_strides: list, stride in each layer (only for 'conv' encoder_type).
      name: str, module name used for tf scope.
    """
    super(SharedEncoder, self).__init__(name=name)
    self._encoder_type = encoder_type

    if encoder_type == 'conv':
      self.shared_encoder = layers.SharedConvModule(
          filters=n_enc,
          strides=enc_strides,
          kernel_size=3,
          activation=tf.nn.relu)
    elif encoder_type == 'multi':
      self.shared_encoder = snt.nets.MLP(
          name='mlp_shared_encoder',
          output_sizes=n_enc,
          activation=tf.nn.relu,
          activate_final=True)
    else:
      raise ValueError('Unknown encoder_type {}'.format(encoder_type))

  def _build(self, x, is_training):
    if self._encoder_type == 'multi':
      self.conv_shapes = None
      x = snt.BatchFlatten()(x)
      return self.shared_encoder(x)
    else:
      output = self.shared_encoder(x)
      self.conv_shapes = self.shared_encoder.conv_shapes
      return output

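# Illustrative usage of the shared encoder (a sketch with made-up layer sizes,
# not part of the library):
#
#   encoder = SharedEncoder('multi', n_enc=[500, 500], enc_strides=[])
#   hiddens = encoder(images, is_training=True)
#
# The encoder functions below assert that `hiddens` is 2D, i.e. [B, D], before
# adding their output heads.
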
def cluster_encoder_fn(hiddens, n_y_active, n_y, is_training=True):
  """The cluster encoder function, modelling q(y | x).

  Args:
    hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.
    n_y_active: Tensor, the number of active components.
    n_y: int, number of maximum components allowed (used for tensor size).
    is_training: Boolean, whether to build the training graph or an evaluation
      graph.

  Returns:
    The distribution `q(y | x)`.
  """
  del is_training  # unused for now
  with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
    lin = snt.Linear(n_y, name='mlp_cluster_encoder_final')
    logits = lin(hiddens)

    # Only use the first n_y_active components, and set the remaining to zero.
    if n_y > 1:
      probs = tf.nn.softmax(logits[:, :n_y_active])
      logging.info('Cluster softmax active probs shape: %s', str(probs.shape))
      paddings1 = tf.stack([tf.constant(0), tf.constant(0)], axis=0)
      paddings2 = tf.stack([tf.constant(0), n_y - n_y_active], axis=0)
      paddings = tf.stack([paddings1, paddings2], axis=1)
      probs = tf.pad(probs, paddings) + 0.0 * logits + 1e-12
    else:
      probs = tf.ones_like(logits)
    logging.info('Cluster softmax probs shape: %s', str(probs.shape))

    return tfp.distributions.OneHotCategorical(probs=probs)

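# A worked example of the padding above, with illustrative sizes n_y=10 and
# n_y_active=3: `paddings` evaluates to [[0, 0], [0, 7]], so the [B, 3] softmax
# over active components is zero-padded back to [B, 10]. Adding `0.0 * logits`
# ties the padded probs back to the full logits tensor (fixing the static
# shape), and the `1e-12` floor avoids log(0) in downstream log-probs.
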
def latent_encoder_fn(hiddens, y, n_y, n_z, is_training=True):
  """The latent encoder function, modelling q(z | x, y).

  Args:
    hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.
    y: Categorical cluster variable, `Tensor` of size `[B, n_y]`.
    n_y: int, number of dims of y.
    n_z: int, number of dims of z.
    is_training: Boolean, whether to build the training graph or an evaluation
      graph.

  Returns:
    The Gaussian distribution `q(z | x, y)`.
  """
  del is_training  # unused for now
  with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
    # Logits for both mean and variance.
    n_logits = 2 * n_z

    all_logits = []
    for k in range(n_y):
      lin = snt.Linear(n_logits, name='mlp_latent_encoder_' + str(k))
      all_logits.append(lin(hiddens))

    # Sum over cluster components.
    all_logits = tf.stack(all_logits)  # [n_y, B, n_logits]
    logits = tf.einsum('ij,jik->ik', y, all_logits)

    # Compute distribution from logits.
    return utils.generate_gaussian(
        logits=logits, sigma_nonlin='softplus', sigma_param='var')

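# The einsum above contracts y [B, n_y] with the stacked per-component heads
# [n_y, B, 2*n_z]: logits[i, k] = sum_j y[i, j] * all_logits[j, i, k]. For a
# one-hot y this selects the encoder head of each example's cluster; the first
# n_z outputs parameterise the mean and the remaining n_z the (softplus)
# variance inside utils.generate_gaussian.
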
def data_decoder_fn(z,
                    y,
                    output_type,
                    output_shape,
                    decoder_type,
                    n_dec,
                    dec_up_strides,
                    n_x,
                    n_y,
                    shared_encoder_conv_shapes=None,
                    is_training=True,
                    test_local_stats=True):
  """The data decoder function, modelling p(x | z).

  Args:
    z: Latent variables, `Tensor` of size `[B, n_z]`.
    y: Categorical cluster variable, `Tensor` of size `[B, n_y]`.
    output_type: str, output distribution ('bernoulli' or 'quantized_normal').
    output_shape: list, shape of output (not including batch dimension).
    decoder_type: str, 'single', 'multi', or 'deconv'.
    n_dec: list, number of hidden units per layer in the decoder.
    dec_up_strides: list, stride in each layer (only for 'deconv'
      decoder_type).
    n_x: int, number of dims of x.
    n_y: int, number of dims of y.
    shared_encoder_conv_shapes: the shapes of the activations of the
      intermediate layers of the encoder.
    is_training: Boolean, whether to build the training graph or an evaluation
      graph.
    test_local_stats: Boolean, whether to use the test batch statistics at test
      time for batch norm (default) or the moving averages.

  Returns:
    The Bernoulli distribution `p(x | z)`.
  """
  if output_type == 'bernoulli':
    output_dist = lambda x: tfp.distributions.Bernoulli(logits=x)
    n_out_factor = 1
    out_shape = list(output_shape)
  else:
    raise NotImplementedError

  if len(z.shape) != 2:
    raise NotImplementedError('The data decoder function expects `z` to be '
                              '2D, but its shape was %s instead.' %
                              str(z.shape))
  if len(y.shape) != 2:
    raise NotImplementedError('The data decoder function expects `y` to be '
                              '2D, but its shape was %s instead.' %
                              str(y.shape))

  # Upsample layer (deconvolutional, bilinear, ..).
  if decoder_type == 'deconv':
    # First, check that the encoder is convolutional too (needed for batchnorm)
    if shared_encoder_conv_shapes is None:
      raise ValueError('Shared encoder does not contain conv_shapes.')

    num_output_channels = output_shape[-1]
    conv_decoder = UpsampleModule(
        filters=n_dec,
        kernel_size=3,
        activation=tf.nn.relu,
        dec_up_strides=dec_up_strides,
        enc_conv_shapes=shared_encoder_conv_shapes,
        n_c=num_output_channels * n_out_factor,
        method=decoder_type)
    logits = conv_decoder(
        z, is_training=is_training, test_local_stats=test_local_stats)
    logits = tf.reshape(logits, [-1] + out_shape)  # n_out_factor in last dim

  # Multiple MLP decoders, one for each component.
  elif decoder_type == 'multi':
    all_logits = []
    for k in range(n_y):
      mlp_decoding = snt.nets.MLP(
          name='mlp_latent_decoder_' + str(k),
          output_sizes=n_dec + [n_x * n_out_factor],
          activation=tf.nn.relu,
          activate_final=False)
      logits = mlp_decoding(z)
      all_logits.append(logits)

    all_logits = tf.stack(all_logits)
    logits = tf.einsum('ij,jik->ik', y, all_logits)
    logits = tf.reshape(logits, [-1] + out_shape)  # Back to 4D

  # Single (shared among components) MLP decoder.
  elif decoder_type == 'single':
    mlp_decoding = snt.nets.MLP(
        name='mlp_latent_decoder',
        output_sizes=n_dec + [n_x * n_out_factor],
        activation=tf.nn.relu,
        activate_final=False)
    logits = mlp_decoding(z)
    logits = tf.reshape(logits, [-1] + out_shape)  # Back to 4D
  else:
    raise ValueError('Unknown decoder_type {}'.format(decoder_type))

  return output_dist(logits)

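# A minimal call sketch for the 'single' decoder on binarised 28x28 images
# (hypothetical sizes, mirroring how the functions above fit together):
#
#   p_x = data_decoder_fn(z, y, output_type='bernoulli',
#                         output_shape=[28, 28, 1], decoder_type='single',
#                         n_dec=[500], dec_up_strides=None, n_x=784, n_y=10)
#   log_p = p_x.log_prob(x)  # [B, 28, 28, 1]; callers reduce this to [B].
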
def latent_decoder_fn(y, n_z, is_training=True):
  """The latent decoder function, modelling p(z | y).

  Args:
    y: Categorical cluster variable, `Tensor` of size `[B, n_y]`.
    n_z: int, number of dims of z.
    is_training: Boolean, whether to build the training graph or an evaluation
      graph.

  Returns:
    The Gaussian distribution `p(z | y)`.
  """
  del is_training  # Unused for now.

  if len(y.shape) != 2:
    raise NotImplementedError('The latent decoder function expects `y` to be '
                              '2D, but its shape was %s instead.' %
                              str(y.shape))

  lin_mu = snt.Linear(n_z, name='latent_prior_mu')
  lin_sigma = snt.Linear(n_z, name='latent_prior_sigma')

  mu = lin_mu(y)
  sigma = lin_sigma(y)

  logits = tf.concat([mu, sigma], axis=1)

  return utils.generate_gaussian(
      logits=logits, sigma_nonlin='softplus', sigma_param='var')

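# Since y is one-hot, lin_mu(y) and lin_sigma(y) simply select the row of the
# weight matrix belonging to cluster k (plus a shared bias), so every cluster
# owns its own Gaussian prior over z; marginally over p(y), the prior on z is
# a mixture of (up to n_y_active) Gaussians.
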
class Curl(object):
  """CURL model class."""

  def __init__(self,
               prior,
               latent_decoder,
               data_decoder,
               shared_encoder,
               cluster_encoder,
               latent_encoder,
               n_y_active,
               kly_over_batch=False,
               is_training=True,
               name='curl'):
    self.scope_name = name
    self._shared_encoder = shared_encoder
    self._prior = prior
    self._latent_decoder = latent_decoder
    self._data_decoder = data_decoder
    self._cluster_encoder = cluster_encoder
    self._latent_encoder = latent_encoder
    self._n_y_active = n_y_active
    self._kly_over_batch = kly_over_batch
    self._is_training = is_training
    self._cache = {}

  def sample(self, sample_shape=(), y=None, mean=False):
    """Draws a sample from the learnt distribution p(x).

    Args:
      sample_shape: `int` or 0D `Tensor` giving the number of samples to
        return. If empty tuple (default value), 1 sample will be returned.
      y: Optional, the one hot label on which to condition the sample.
      mean: Boolean, if True the expected value of the output distribution is
        returned, otherwise samples from the output distribution.

    Returns:
      Sample tensor of shape `[B * N, ...]` where `B` is the batch size of
      the prior, `N` is the number of samples requested, and `...` represents
      the shape of the observations.

    Raises:
      ValueError: If `sample_shape` has rank > 0 or if `sample_shape`
        is an int that is < 1.
    """
    with tf.name_scope('{}_sample'.format(self.scope_name)):
      if y is None:
        y = tf.to_float(self.compute_prior().sample(sample_shape))

      if y.shape.ndims > 2:
        y = snt.MergeDims(start=0, size=y.shape.ndims - 1, name='merge_y')(y)

      z = self._latent_decoder(y, is_training=self._is_training)
      if mean:
        samples = self.predict(z.sample(), y).mean()
      else:
        samples = self.predict(z.sample(), y).sample()
    return samples

  def reconstruct(self, x, use_mode=True, use_mean=False):
    """Reconstructs the given observations.

    Args:
      x: Observed `Tensor`.
      use_mode: Boolean, if true, take the argmax over q(y|x).
      use_mean: Boolean, if true, use pixel-mean for reconstructions.

    Returns:
      The reconstructed samples x ~ p(x | y~q(y|x), z~q(z|x, y)).
    """
    hiddens = self._shared_encoder(x, is_training=self._is_training)
    qy = self.infer_cluster(hiddens)
    y_sample = qy.mode() if use_mode else qy.sample()
    y_sample = tf.to_float(y_sample)
    qz = self.infer_latent(hiddens, y_sample)
    p = self.predict(qz.sample(), y_sample)

    if use_mean:
      return p.mean()
    else:
      return p.sample()

  def log_prob(self, x):
    """Redirects to log_prob_elbo with a warning."""
    logging.warn('log_prob is actually a lower bound')
    return self.log_prob_elbo(x)

  def log_prob_elbo(self, x):
    """Returns evidence lower bound."""
    log_p_x, kl_y, kl_z = self.log_prob_elbo_components(x)[:3]
    return log_p_x - kl_y - kl_z

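  # In the notation of the methods above, the bound returned here is
  #   ELBO(x) = E_{q(y|x) q(z|x,y)}[log p(x|y,z)]
  #             - KL[q(y|x) || p(y)]
  #             - E_{q(y|x)}[KL[q(z|x,y) || p(z|y)]],
  # i.e. log_p_x - kl_y - kl_z.
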
  def log_prob_elbo_components(self, x, y=None, reduce_op=tf.reduce_sum):
    """Returns the components used in calculating the evidence lower bound.

    Args:
      x: Observed variables, `Tensor` of size `[B, I]` where `I` is the size of
        a flattened input.
      y: Optional labels, `Tensor` of size `[B]` containing integer cluster
        labels (used to build the supervised loss terms).
      reduce_op: The op to use for reducing across non-batch dimensions.
        Typically either `tf.reduce_sum` or `tf.reduce_mean`.

    Returns:
      `log p(x|y,z)` of shape `[B]` where `B` is the batch size.
      `KL[q(y|x) || p(y)]` of shape `[B]` where `B` is the batch size.
      `KL[q(z|x,y) || p(z|y)]` of shape `[B]` where `B` is the batch size.
    """
    cache_key = (x,)

    # Checks if the output graph for this input has already been computed.
    if cache_key in self._cache:
      return self._cache[cache_key]

    with tf.name_scope('{}_log_prob_elbo'.format(self.scope_name)):
      hiddens = self._shared_encoder(x, is_training=self._is_training)
      # 1) Compute KL[q(y|x) || p(y)] from x, and keep distribution q_y around.
      kl_y, q_y = self._kl_and_qy(hiddens)  # [B], distribution

      # For the next two terms, we need to marginalise over all y.

      # First, construct every possible y indexing (as a one hot) and repeat it
      # for every element in the batch [n_y_active, B, n_y].
      # Note that the one-hots have the dimensionality of all y, while only the
      # codes corresponding to active components are instantiated.
      bs, n_y = q_y.probs.shape
      all_y = tf.tile(
          tf.expand_dims(tf.one_hot(tf.range(self._n_y_active),
                                    n_y), axis=1),
          multiples=[1, bs, 1])

      # 2) Compute KL[q(z|x,y) || p(z|y)] (for all possible y), and keep z's
      # around [n_y, B] and [n_y, B, n_z].
      kl_z_all, z_all = tf.map_fn(
          fn=lambda y: self._kl_and_z(hiddens, y),
          elems=all_y,
          dtype=(tf.float32, tf.float32),
          name='elbo_components_z_map')
      kl_z_all = tf.transpose(kl_z_all, name='kl_z_all')

      # Now take the expectation over y (scale by q(y|x)).
      y_logits = q_y.logits[:, :self._n_y_active]  # [B, n_y]
      y_probs = q_y.probs[:, :self._n_y_active]  # [B, n_y]
      y_probs = y_probs / tf.reduce_sum(y_probs, axis=1, keepdims=True)
      kl_z = tf.reduce_sum(y_probs * kl_z_all, axis=1)

      # 3) Evaluate logp and recon, i.e., log and mean of p(x|z,[y])
      # (conditioning on y only in the `multi` decoder_type case, when
      # train_supervised is True). Here we take the reconstruction from each
      # possible component y and take its log prob. [n_y, B, Ix, Iy, Iz]
      log_p_x_all = tf.map_fn(
          fn=lambda val: self.predict(val[0], val[1]).log_prob(x),
          elems=(z_all, all_y),
          dtype=tf.float32,
          name='elbo_components_logpx_map')

      # Sum log probs over all dimensions apart from the first two (n_y, B),
      # i.e., over I. Use einsum to construct higher order multiplication.
      log_p_x_all = snt.BatchFlatten(preserve_dims=2)(log_p_x_all)  # [n_y,B,I]

      # Note, this is E_{q(y|x)} [ log p(x | z, y)], i.e., we scale log_p_x_all
      # by q(y|x).
      log_p_x = tf.einsum('ij,jik->ik', y_probs, log_p_x_all)  # [B, I]

      # We may also use a supervised loss for some samples [B, n_y].
      if y is not None:
        self.y_label = tf.one_hot(y, n_y)
      else:
        self.y_label = tfc.placeholder(
            shape=[bs, n_y], dtype=tf.float32, name='y_label')

      # This is computing log p(x | z, y=true_y), which is basically equivalent
      # to indexing into the correct element of `log_p_x_all`.
      log_p_x_sup = tf.einsum('ij,jik->ik',
                              self.y_label[:, :self._n_y_active],
                              log_p_x_all)  # [B, I]
      kl_z_sup = tf.einsum('ij,ij->i',
                           self.y_label[:, :self._n_y_active],
                           kl_z_all)  # [B]
      # -log q(y=y_true | x)
      kl_y_sup = tf.nn.sparse_softmax_cross_entropy_with_logits(  # [B]
          labels=tf.argmax(self.y_label[:, :self._n_y_active], axis=1),
          logits=y_logits)

      # Reduce over all dimensions except batch.
      dims_x = [k for k in range(1, log_p_x.shape.ndims)]
      log_p_x = reduce_op(log_p_x, dims_x, name='log_p_x')
      log_p_x_sup = reduce_op(log_p_x_sup, dims_x, name='log_p_x_sup')

      # Store values needed externally.
      self.q_y = q_y
      self.log_p_x_all = tf.transpose(
          reduce_op(
              log_p_x_all,
              -1,  # [B, n_y]
              name='log_p_x_all'))
      self.kl_z_all = kl_z_all
      self.y_probs = y_probs

    self._cache[cache_key] = (log_p_x, kl_y, kl_z, log_p_x_sup, kl_y_sup,
                              kl_z_sup)
    return log_p_x, kl_y, kl_z, log_p_x_sup, kl_y_sup, kl_z_sup

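  # Note that the marginalisation in log_prob_elbo_components runs the latent
  # encoder, latent prior and data decoder once per active component (via the
  # tf.map_fn calls over `all_y`), so graph size and compute grow linearly
  # with n_y_active.
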
  def _kl_and_qy(self, hiddens):
    """Returns analytical or sampled KL div and the distribution q(y | x).

    Args:
      hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.

    Returns:
      Pair `(kl, y)`, where `kl` is the KL divergence (a `Tensor` with shape
      `[B]`, where `B` is the batch size), and `y` is a sample from the
      categorical encoding distribution.
    """
    with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
      q = self.infer_cluster(hiddens)  # q(y|x)
      p = self.compute_prior()  # p(y)

      try:
        # Take the average proportions over the whole batch, then repeat them
        # in each row before computing the KL.
        if self._kly_over_batch:
          probs = tf.reduce_mean(
              q.probs, axis=0, keepdims=True) * tf.ones_like(q.probs)
          qmean = tfp.distributions.OneHotCategorical(probs=probs)
          kl = tfp.distributions.kl_divergence(qmean, p)
        else:
          kl = tfp.distributions.kl_divergence(q, p)
      except NotImplementedError:
        y = q.sample(name='y_sample')
        logging.warn('Using sampling KLD for y')
        log_p_y = p.log_prob(y, name='log_p_y')
        log_q_y = q.log_prob(y, name='log_q_y')

        # Reduce over all dimensions except batch.
        sum_axis_p = [k for k in range(1, log_p_y.get_shape().ndims)]
        log_p_y = tf.reduce_sum(log_p_y, sum_axis_p)
        sum_axis_q = [k for k in range(1, log_q_y.get_shape().ndims)]
        log_q_y = tf.reduce_sum(log_q_y, sum_axis_q)

        kl = log_q_y - log_p_y

      # Reduce over all dimensions except batch.
      sum_axis_kl = [k for k in range(1, kl.get_shape().ndims)]
      kl = tf.reduce_sum(kl, sum_axis_kl, name='kl')
      return kl, q

  def _kl_and_z(self, hiddens, y):
    """Returns KL[q(z|y,x) || p(z|y)] and a sample for z from q(z|y,x).

    Returns the analytical KL divergence KL[q(z|y,x) || p(z|y)] if one is
    available (as registered with `kullback_leibler.RegisterKL`), or a sampled
    KL divergence otherwise (in this case the returned sample is the one used
    for the KL divergence).

    Args:
      hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.
      y: Categorical cluster random variable, `Tensor` of size `[B, n_y]`.

    Returns:
      Pair `(kl, z)`, where `kl` is the KL divergence (a `Tensor` with shape
      `[B]`, where `B` is the batch size), and `z` is a sample from the
      encoding distribution.
    """
    with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
      q = self.infer_latent(hiddens, y)  # q(z|x,y)
      p = self.generate_latent(y)  # p(z|y)
      z = q.sample(name='z')
      try:
        kl = tfp.distributions.kl_divergence(q, p)
      except NotImplementedError:
        logging.warn('Using sampling KLD for z')
        log_p_z = p.log_prob(z, name='log_p_z_y')
        log_q_z = q.log_prob(z, name='log_q_z_xy')

        # Reduce over all dimensions except batch.
        sum_axis_p = [k for k in range(1, log_p_z.get_shape().ndims)]
        log_p_z = tf.reduce_sum(log_p_z, sum_axis_p)
        sum_axis_q = [k for k in range(1, log_q_z.get_shape().ndims)]
        log_q_z = tf.reduce_sum(log_q_z, sum_axis_q)

        kl = log_q_z - log_p_z

      # Reduce over all dimensions except batch.
      sum_axis_kl = [k for k in range(1, kl.get_shape().ndims)]
      kl = tf.reduce_sum(kl, sum_axis_kl, name='kl')
      return kl, z

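  # Both q(z|x,y) and p(z|y) are diagonal Gaussians built by
  # utils.generate_gaussian, so the analytical KL registered with TFP normally
  # applies; the sampled-KL branch is only a fallback for distribution pairs
  # without a registered closed form.
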
  def infer_latent(self, hiddens, y=None, use_mean_y=False):
    """Performs inference over the latent variable z.

    Args:
      hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.
      y: Categorical cluster variable, `Tensor` of size `[B, ...]`.
      use_mean_y: Boolean, whether to take the mean encoding over all y.

    Returns:
      The distribution `q(z|x, y)`, which on sample produces tensors of size
      `[N, B, ...]` where `B` is the batch size of `x` and `y`, and `N` is the
      number of samples and `...` represents the shape of the latent variables.
    """
    with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
      if y is None:
        y = tf.to_float(self.infer_cluster(hiddens).mode())

      if use_mean_y:
        # If use_mean_y, then y must be probabilities.
        all_y = tf.tile(
            tf.expand_dims(tf.one_hot(tf.range(y.shape[1]), y.shape[1]),
                           axis=1),
            multiples=[1, y.shape[0], 1])

        # Compute z KL from x (for all possible y), and keep z's around.
        z_all = tf.map_fn(
            fn=lambda y: self._latent_encoder(
                hiddens, y, is_training=self._is_training).mean(),
            elems=all_y,
            dtype=tf.float32)
        return tf.einsum('ij,jik->ik', y, z_all)
      else:
        return self._latent_encoder(hiddens, y, is_training=self._is_training)

  def generate_latent(self, y):
    """Use the generative model to compute latent variable z, given a y.

    Args:
      y: Categorical cluster variable, `Tensor` of size `[B, ...]`.

    Returns:
      The distribution `p(z|y)`, which on sample produces tensors of size
      `[N, B, ...]` where `B` is the batch size of `x`, and `N` is the number
      of samples asked and `...` represents the shape of the latent variables.
    """
    return self._latent_decoder(y, is_training=self._is_training)

  def get_shared_rep(self, x, is_training):
    """Gets the shared representation from a given input x.

    Args:
      x: Observed variables, `Tensor` of size `[B, I]` where `I` is the size of
        a flattened input.
      is_training: bool, whether this constitutes training data or not.

    Returns:
      The shared encoder activations, 2D `Tensor` of size `[B, ...]`.
    """
    return self._shared_encoder(x, is_training)

  def infer_cluster(self, hiddens):
    """Performs inference over the categorical variable y.

    Args:
      hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.

    Returns:
      The distribution `q(y|x)`, which on sample produces tensors of size
      `[N, B, ...]` where `B` is the batch size of `x`, and `N` is the number
      of samples asked and `...` represents the shape of the latent variables.
    """
    with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
      return self._cluster_encoder(hiddens, is_training=self._is_training)

  def predict(self, z, y):
    """Computes prediction over the observed variables.

    Args:
      z: Latent variables, `Tensor` of size `[B, ...]`.
      y: Categorical cluster variable, `Tensor` of size `[B, ...]`.

    Returns:
      The distribution `p(x|z)`, which on sample produces tensors of size
      `[N, B, ...]` where `N` is the number of samples asked.
    """
    encoder_conv_shapes = getattr(self._shared_encoder, 'conv_shapes', None)
    return self._data_decoder(
        z,
        y,
        shared_encoder_conv_shapes=encoder_conv_shapes,
        is_training=self._is_training)

  def compute_prior(self):
    """Computes prior over the latent variables.

    Returns:
      The distribution `p(y)`, which on sample produces tensors of size
      `[N, ...]` where `N` is the number of samples asked and `...` represents
      the shape of the latent variables.
    """
    return self._prior()

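# A minimal wiring sketch for the class above (hypothetical module names; in
# the actual pipeline the *_fn functions are wrapped into Sonnet modules and
# configured by the training script):
#
#   model = Curl(prior=prior_module,
#                latent_decoder=latent_decoder_module,
#                data_decoder=data_decoder_module,
#                shared_encoder=SharedEncoder('multi', [500, 500], []),
#                cluster_encoder=cluster_encoder_module,
#                latent_encoder=latent_encoder_module,
#                n_y_active=n_y_active_tensor)
#   log_p_x, kl_y, kl_z = model.log_prob_elbo_components(x_batch)[:3]
#   elbo = log_p_x - kl_y - kl_z  # Maximised, i.e. -elbo is minimised.
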
class UpsampleModule(snt.AbstractModule):
  """Convolutional decoder.

  If `method` is 'deconv' apply transposed convolutions with stride 2,
  otherwise apply the `method` upsampling function and then smooth with a
  stride 1x1 convolution.

  Params:
  -------
  filters: list, where the first element is the number of filters of the
    initial MLP layer and the remaining elements are the number of filters of
    the upsampling layers.
  kernel_size: the size of the convolutional kernels. The same size will be
    used in all convolutions.
  activation: an activation function, applied to all layers but the last.
  dec_up_strides: list, the upsampling factors of each upsampling convolutional
    layer.
  enc_conv_shapes: list, the shapes of the input and of all the intermediate
    feature maps of the convolutional layers in the encoder.
  n_c: the number of output channels.
  """

  def __init__(self,
               filters,
               kernel_size,
               activation,
               dec_up_strides,
               enc_conv_shapes,
               n_c,
               method='nn',
               name='upsample_module'):
    super(UpsampleModule, self).__init__(name=name)

    assert len(filters) == len(dec_up_strides) + 1, (
        'The decoder\'s filters should contain one element more than the '
        'decoder\'s up stride list, but has %d elements instead of %d.\n'
        'Decoder filters: %s\nDecoder up strides: %s' %
        (len(filters), len(dec_up_strides) + 1, str(filters),
         str(dec_up_strides)))

    self._filters = filters
    self._kernel_size = kernel_size
    self._activation = activation

    self._dec_up_strides = dec_up_strides
    self._enc_conv_shapes = enc_conv_shapes
    self._n_c = n_c

    if method == 'deconv':
      self._conv_layer = tf.layers.Conv2DTranspose
      self._method = method
    else:
      self._conv_layer = tf.layers.Conv2D
      self._method = getattr(tf.image.ResizeMethod, method.upper())
    self._method_str = method.capitalize()

  def _build(self, z, is_training=True, test_local_stats=True, use_bn=False):
    batch_norm_args = {
        'is_training': is_training,
        'test_local_stats': test_local_stats
    }

    method = self._method
    # Cycle over the encoder shapes backwards, to build a symmetrical decoder.
    enc_conv_shapes = self._enc_conv_shapes[::-1]
    strides = self._dec_up_strides
    # We store the heights and widths of the encoder feature maps that are
    # unique, i.e., the ones right after a layer with stride != 1. These will
    # be used as a target to potentially crop the upsampled feature maps.
    unique_hw = np.unique([(el[1], el[2]) for el in enc_conv_shapes], axis=0)
    unique_hw = unique_hw.tolist()[::-1]
    unique_hw.pop()  # Drop the initial shape.

    # The first filter is an MLP.
    mlp_filter, conv_filters = self._filters[0], self._filters[1:]
    # The first shape is used after the MLP to go to 4D.

    layers = [z]
    # The shape of the first enc is used after the MLP to go back to 4D.
    dec_mlp = snt.nets.MLP(
        name='dec_mlp_projection',
        output_sizes=[mlp_filter, np.prod(enc_conv_shapes[0][1:])],
        use_bias=not use_bn,
        activation=self._activation,
        activate_final=True)

    upsample_mlp_flat = dec_mlp(z)
    if use_bn:
      upsample_mlp_flat = snt.BatchNorm(scale=True)(upsample_mlp_flat,
                                                    **batch_norm_args)
    layers.append(upsample_mlp_flat)
    upsample = tf.reshape(upsample_mlp_flat, enc_conv_shapes[0])
    layers.append(upsample)

    for i, (filter_i, stride_i) in enumerate(zip(conv_filters, strides), 1):
      if method != 'deconv' and stride_i > 1:
        upsample = tf.image.resize_images(
            upsample, [stride_i * el for el in upsample.shape.as_list()[1:3]],
            method=method,
            name='upsample_' + str(i))
      upsample = self._conv_layer(
          filters=filter_i,
          kernel_size=self._kernel_size,
          padding='same',
          use_bias=not use_bn,
          activation=self._activation,
          strides=stride_i if method == 'deconv' else 1,
          name='upsample_conv_' + str(i))(
              upsample)
      if use_bn:
        upsample = snt.BatchNorm(scale=True)(upsample, **batch_norm_args)
      if stride_i > 1:
        hw = unique_hw.pop()
        upsample = utils.maybe_center_crop(upsample, hw)
      layers.append(upsample)

    # Final layer, no upsampling.
    x_logits = tf.layers.Conv2D(
        filters=self._n_c,
        kernel_size=self._kernel_size,
        padding='same',
        use_bias=not use_bn,
        activation=None,
        strides=1,
        name='logits')(
            upsample)
    if use_bn:
      x_logits = snt.BatchNorm(scale=True)(x_logits, **batch_norm_args)
    layers.append(x_logits)

    logging.info('%s upsampling module layer shapes', self._method_str)
    logging.info('\n'.join([str(v.shape.as_list()) for v in layers]))

    return x_logits
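
# Shape sketch for a hypothetical 28x28 input with dec_up_strides=[2, 2], so
# that enc_conv_shapes, reversed, starts at a 7x7 feature map:
#   z [B, n_z] -> dec_mlp -> [B, 7*7*C] -> reshape -> [B, 7, 7, C]
#   -> upsample+conv -> [B, 14, 14, f1] -> upsample+conv -> [B, 28, 28, f2]
#   -> stride-1 conv -> x_logits [B, 28, 28, n_c],
# with a centre crop after each strided step whenever the upsampled height or
# width overshoots the matching encoder feature map.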