Mirror of https://github.com/google-deepmind/deepmind-research.git,
synced 2026-05-10 05:17:46 +08:00, commit 656f93a37e
(PiperOrigin-RevId: 279101132).
File: Python, 798 lines, 29 KiB.
################################################################################
|
|
# Copyright 2019 DeepMind Technologies Limited
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
################################################################################
|
|
"""Implementation of Continual Unsupervised Representation Learning model."""
|
|
|
|
from absl import logging
|
|
import numpy as np
|
|
import sonnet as snt
|
|
import tensorflow as tf
|
|
import tensorflow_probability as tfp
|
|
|
|
from curl import layers
|
|
from curl import utils
|
|
|
|
tfc = tf.compat.v1
|
|
|
|
# pylint: disable=g-long-lambda
|
|
# pylint: disable=redefined-outer-name
|
|
|
|
|
|
class SharedEncoder(snt.AbstractModule):
  """The shared encoder module, mapping input x to hiddens."""

  def __init__(self, encoder_type, n_enc, enc_strides, name='shared_encoder'):
    """Builds the shared encoder, mapping input x to hiddens.

    Args:
      encoder_type: str, type of encoder, either 'conv' or 'multi'
      n_enc: list, number of hidden units per layer in the encoder
      enc_strides: list, stride in each layer (only for 'conv' encoder_type)
      name: str, module name used for tf scope.

    Raises:
      ValueError: if `encoder_type` is neither 'conv' nor 'multi'.
    """
    super(SharedEncoder, self).__init__(name=name)
    self._encoder_type = encoder_type

    if encoder_type == 'multi':
      # Plain MLP over flattened inputs.
      self.shared_encoder = snt.nets.MLP(
          name='mlp_shared_encoder',
          output_sizes=n_enc,
          activation=tf.nn.relu,
          activate_final=True)
    elif encoder_type == 'conv':
      # Convolutional tower; records intermediate feature-map shapes so that
      # a symmetrical deconvolutional decoder can be built later.
      self.shared_encoder = layers.SharedConvModule(
          filters=n_enc,
          strides=enc_strides,
          kernel_size=3,
          activation=tf.nn.relu)
    else:
      raise ValueError('Unknown encoder_type {}'.format(encoder_type))

  def _build(self, x, is_training):
    if self._encoder_type != 'multi':
      # Convolutional path: keep the conv feature-map shapes around for any
      # decoder that wants to mirror the encoder.
      hiddens = self.shared_encoder(x)
      self.conv_shapes = self.shared_encoder.conv_shapes
      return hiddens
    # MLP path: flatten the input first; there are no conv shapes to record.
    self.conv_shapes = None
    flat_x = snt.BatchFlatten()(x)
    return self.shared_encoder(flat_x)
|
|
|
|
|
|
def cluster_encoder_fn(hiddens, n_y_active, n_y, is_training=True):
  """The cluster encoder function, modelling q(y | x).

  Args:
    hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.
    n_y_active: Tensor, the number of active components.
    n_y: int, number of maximum components allowed (used for tensor size)
    is_training: Boolean, whether to build the training graph or an evaluation
      graph.

  Returns:
    The distribution `q(y | x)`.
  """
  del is_training  # unused for now
  with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
    logits = snt.Linear(n_y, name='mlp_cluster_encoder_final')(hiddens)

  if n_y > 1:
    # Softmax only over the currently-active components; the remaining
    # (inactive) components get zero probability mass.
    active_probs = tf.nn.softmax(logits[:, :n_y_active])
    logging.info('Cluster softmax active probs shape: %s',
                 str(active_probs.shape))
    # Pad the active probabilities back out to n_y columns (pad after, on
    # the component axis only).
    no_pad = tf.stack([tf.constant(0), tf.constant(0)], axis=0)
    pad_tail = tf.stack([tf.constant(0), n_y - n_y_active], axis=0)
    paddings = tf.stack([no_pad, pad_tail], axis=1)
    # `0.0 * logits` keeps all linear weights on the gradient path; 1e-12
    # guards downstream log-probabilities against exact zeros.
    probs = tf.pad(active_probs, paddings) + 0.0 * logits + 1e-12
  else:
    probs = tf.ones_like(logits)
  logging.info('Cluster softmax probs shape: %s', str(probs.shape))

  return tfp.distributions.OneHotCategorical(probs=probs)
|
|
|
|
|
|
def latent_encoder_fn(hiddens, y, n_y, n_z, is_training=True):
  """The latent encoder function, modelling q(z | x, y).

  Args:
    hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.
    y: Categorical cluster variable, `Tensor` of size `[B, n_y]`.
    n_y: int, number of dims of y.
    n_z: int, number of dims of z.
    is_training: Boolean, whether to build the training graph or an evaluation
      graph.

  Returns:
    The Gaussian distribution `q(z | x, y)`.
  """
  del is_training  # unused for now

  with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
    # Each component has its own head producing mean and variance parameters.
    n_logits = 2 * n_z

    per_component_logits = [
        snt.Linear(n_logits, name='mlp_latent_encoder_' + str(k))(hiddens)
        for k in range(n_y)
    ]

    # Soft-select the head(s) indicated by y: [n_y, B, n_logits] -> [B, n_logits].
    stacked_logits = tf.stack(per_component_logits)
    logits = tf.einsum('ij,jik->ik', y, stacked_logits)

    # Compute distribution from logits.
    return utils.generate_gaussian(
        logits=logits, sigma_nonlin='softplus', sigma_param='var')
|
|
|
|
|
|
def data_decoder_fn(z,
                    y,
                    output_type,
                    output_shape,
                    decoder_type,
                    n_dec,
                    dec_up_strides,
                    n_x,
                    n_y,
                    shared_encoder_conv_shapes=None,
                    is_training=True,
                    test_local_stats=True):
  """The data decoder function, modelling p(x | z).

  Args:
    z: Latent variables, `Tensor` of size `[B, n_z]`.
    y: Categorical cluster variable, `Tensor` of size `[B, n_y]`.
    output_type: str, output distribution ('bernoulli' is the only value
      currently implemented; anything else raises NotImplementedError).
    output_shape: list, shape of output (not including batch dimension).
    decoder_type: str, 'single', 'multi', or 'deconv'.
    n_dec: list, number of hidden units per layer in the decoder
    dec_up_strides: list, stride in each layer (only for 'deconv' decoder_type).
    n_x: int, number of dims of x.
    n_y: int, number of dims of y.
    shared_encoder_conv_shapes: the shapes of the activations of the
      intermediate layers of the encoder. Required when decoder_type is
      'deconv' so the decoder can mirror the encoder.
    is_training: Boolean, whether to build the training graph or an evaluation
      graph.
    test_local_stats: Boolean, whether to use the test batch statistics at test
      time for batch norm (default) or the moving averages.

  Returns:
    The Bernoulli distribution `p(x | z)`.

  Raises:
    NotImplementedError: if `output_type` is not 'bernoulli', or if `z` or `y`
      is not 2D.
    ValueError: if `decoder_type` is 'deconv' but no encoder conv shapes were
      provided, or if `decoder_type` is unknown.
  """

  if output_type == 'bernoulli':
    output_dist = lambda x: tfp.distributions.Bernoulli(logits=x)
    n_out_factor = 1
    out_shape = list(output_shape)
  else:
    raise NotImplementedError
  if len(z.shape) != 2:
    raise NotImplementedError('The data decoder function expects `z` to be '
                              '2D, but its shape was %s instead.' %
                              str(z.shape))
  if len(y.shape) != 2:
    raise NotImplementedError('The data decoder function expects `y` to be '
                              '2D, but its shape was %s instead.' %
                              str(y.shape))

  # Upsample layer (deconvolutional, bilinear, ..).
  if decoder_type == 'deconv':

    # First, check that the encoder is convolutional too (needed for batchnorm)
    if shared_encoder_conv_shapes is None:
      raise ValueError('Shared encoder does not contain conv_shapes.')

    num_output_channels = output_shape[-1]
    # Note: `method=decoder_type` means the UpsampleModule uses transposed
    # convolutions ('deconv') rather than image-resize upsampling.
    conv_decoder = UpsampleModule(
        filters=n_dec,
        kernel_size=3,
        activation=tf.nn.relu,
        dec_up_strides=dec_up_strides,
        enc_conv_shapes=shared_encoder_conv_shapes,
        n_c=num_output_channels * n_out_factor,
        method=decoder_type)
    logits = conv_decoder(
        z, is_training=is_training, test_local_stats=test_local_stats)
    logits = tf.reshape(logits, [-1] + out_shape)  # n_out_factor in last dim

  # Multiple MLP decoders, one for each component.
  elif decoder_type == 'multi':
    all_logits = []
    for k in range(n_y):
      mlp_decoding = snt.nets.MLP(
          name='mlp_latent_decoder_' + str(k),
          output_sizes=n_dec + [n_x * n_out_factor],
          activation=tf.nn.relu,
          activate_final=False)
      logits = mlp_decoding(z)
      all_logits.append(logits)

    # Soft-select the decoder output(s) indicated by y:
    # [n_y, B, I] weighted by [B, n_y] -> [B, I].
    all_logits = tf.stack(all_logits)
    logits = tf.einsum('ij,jik->ik', y, all_logits)
    logits = tf.reshape(logits, [-1] + out_shape)  # Back to 4D

  # Single (shared among components) MLP decoder.
  elif decoder_type == 'single':
    mlp_decoding = snt.nets.MLP(
        name='mlp_latent_decoder',
        output_sizes=n_dec + [n_x * n_out_factor],
        activation=tf.nn.relu,
        activate_final=False)
    logits = mlp_decoding(z)
    logits = tf.reshape(logits, [-1] + out_shape)  # Back to 4D
  else:
    raise ValueError('Unknown decoder_type {}'.format(decoder_type))

  return output_dist(logits)
|
|
|
|
|
|
def latent_decoder_fn(y, n_z, is_training=True):
  """The latent decoder function, modelling p(z | y).

  Args:
    y: Categorical cluster variable, `Tensor` of size `[B, n_y]`.
    n_z: int, number of dims of z.
    is_training: Boolean, whether to build the training graph or an evaluation
      graph.

  Returns:
    The Gaussian distribution `p(z | y)`.

  Raises:
    NotImplementedError: if `y` is not a 2D tensor.
  """
  del is_training  # Unused for now.
  if len(y.shape) != 2:
    raise NotImplementedError('The latent decoder function expects `y` to be '
                              '2D, but its shape was %s instead.' %
                              str(y.shape))

  # Separate linear maps produce the mean and (pre-nonlinearity) variance
  # parameters of the conditional prior.
  mu = snt.Linear(n_z, name='latent_prior_mu')(y)
  sigma = snt.Linear(n_z, name='latent_prior_sigma')(y)

  return utils.generate_gaussian(
      logits=tf.concat([mu, sigma], axis=1),
      sigma_nonlin='softplus',
      sigma_param='var')
|
|
|
|
|
|
class Curl(object):
  """CURL model class.

  Ties together the prior, encoders and decoders into a single generative
  model, and exposes the ELBO components used for training. Graph pieces are
  cached per-input (see `log_prob_elbo_components`) so repeated calls with the
  same tensor reuse the same subgraph.
  """

  def __init__(self,
               prior,
               latent_decoder,
               data_decoder,
               shared_encoder,
               cluster_encoder,
               latent_encoder,
               n_y_active,
               kly_over_batch=False,
               is_training=True,
               name='curl'):
    """Constructs the CURL model from its component modules.

    Args:
      prior: callable returning the distribution p(y).
      latent_decoder: callable modelling p(z | y).
      data_decoder: callable modelling p(x | z, y).
      shared_encoder: module mapping x to shared hiddens.
      cluster_encoder: callable modelling q(y | x).
      latent_encoder: callable modelling q(z | x, y).
      n_y_active: Tensor, number of currently active mixture components.
      kly_over_batch: Boolean; if True, KL[q(y|x)||p(y)] is computed on the
        batch-averaged q(y|x) proportions instead of per-example.
      is_training: Boolean, whether to build the training graph.
      name: str, name used for tf scopes.
    """
    self.scope_name = name
    self._shared_encoder = shared_encoder
    self._prior = prior
    self._latent_decoder = latent_decoder
    self._data_decoder = data_decoder
    self._cluster_encoder = cluster_encoder
    self._latent_encoder = latent_encoder
    self._n_y_active = n_y_active
    self._kly_over_batch = kly_over_batch
    self._is_training = is_training
    # Maps input tensors to already-built ELBO subgraphs (see
    # log_prob_elbo_components).
    self._cache = {}

  def sample(self, sample_shape=(), y=None, mean=False):
    """Draws a sample from the learnt distribution p(x).

    Args:
      sample_shape: `int` or 0D `Tensor` giving the number of samples to return.
        If empty tuple (default value), 1 sample will be returned.
      y: Optional, the one hot label on which to condition the sample.
      mean: Boolean, if True the expected value of the output distribution is
        returned, otherwise samples from the output distribution.

    Returns:
      Sample tensor of shape `[B * N, ...]` where `B` is the batch size of
      the prior, `N` is the number of samples requested, and `...` represents
      the shape of the observations.
    """
    with tf.name_scope('{}_sample'.format(self.scope_name)):
      if y is None:
        # Sample cluster assignments from the prior when not given.
        y = tf.to_float(self.compute_prior().sample(sample_shape))

      # Collapse any leading sample dimensions into the batch dimension.
      if y.shape.ndims > 2:
        y = snt.MergeDims(start=0, size=y.shape.ndims - 1, name='merge_y')(y)

      z = self._latent_decoder(y, is_training=self._is_training)
      if mean:
        samples = self.predict(z.sample(), y).mean()
      else:
        samples = self.predict(z.sample(), y).sample()
    return samples

  def reconstruct(self, x, use_mode=True, use_mean=False):
    """Reconstructs the given observations.

    Args:
      x: Observed `Tensor`.
      use_mode: Boolean, if true, take the argmax over q(y|x)
      use_mean: Boolean, if true, use pixel-mean for reconstructions.

    Returns:
      The reconstructed samples x ~ p(x | y~q(y|x), z~q(z|x, y)).
    """

    hiddens = self._shared_encoder(x, is_training=self._is_training)
    qy = self.infer_cluster(hiddens)
    y_sample = qy.mode() if use_mode else qy.sample()
    y_sample = tf.to_float(y_sample)
    qz = self.infer_latent(hiddens, y_sample)
    p = self.predict(qz.sample(), y_sample)

    if use_mean:
      return p.mean()
    else:
      return p.sample()

  def log_prob(self, x):
    """Redirects to log_prob_elbo with a warning."""
    logging.warn('log_prob is actually a lower bound')
    return self.log_prob_elbo(x)

  def log_prob_elbo(self, x):
    """Returns evidence lower bound."""
    log_p_x, kl_y, kl_z = self.log_prob_elbo_components(x)[:3]
    return log_p_x - kl_y - kl_z

  def log_prob_elbo_components(self, x, y=None, reduce_op=tf.reduce_sum):
    """Returns the components used in calculating the evidence lower bound.

    Args:
      x: Observed variables, `Tensor` of size `[B, I]` where `I` is the size of
        a flattened input.
      y: Optional labels, `Tensor` of size `[B, I]` where `I` is the size of a
        flattened input.
      reduce_op: The op to use for reducing across non-batch dimensions.
        Typically either `tf.reduce_sum` or `tf.reduce_mean`.

    Returns:
      `log p(x|y,z)` of shape `[B]` where `B` is the batch size.
      `KL[q(y|x) || p(y)]` of shape `[B]` where `B` is the batch size.
      `KL[q(z|x,y) || p(z|y)]` of shape `[B]` where `B` is the batch size.
      Plus the supervised counterparts `log p(x|z,y=true_y)`, `-log q(y_true|x)`
      and `KL[q(z|x,y_true) || p(z|y_true)]`, each of shape `[B]`.
    """
    # NOTE: keyed on the input tensor object, so only exact-same-tensor calls
    # hit the cache.
    cache_key = (x,)

    # Checks if the output graph for this inputs has already been computed.
    if cache_key in self._cache:
      return self._cache[cache_key]

    with tf.name_scope('{}_log_prob_elbo'.format(self.scope_name)):

      hiddens = self._shared_encoder(x, is_training=self._is_training)
      # 1) Compute KL[q(y|x) || p(y)] from x, and keep distribution q_y around
      kl_y, q_y = self._kl_and_qy(hiddens)  # [B], distribution

      # For the next two terms, we need to marginalise over all y.

      # First, construct every possible y indexing (as a one hot) and repeat it
      # for every element in the batch [n_y_active, B, n_y].
      # Note that the onehot have dimension of all y, while only the codes
      # corresponding to active components are instantiated
      bs, n_y = q_y.probs.shape
      all_y = tf.tile(
          tf.expand_dims(tf.one_hot(tf.range(self._n_y_active),
                                    n_y), axis=1),
          multiples=[1, bs, 1])

      # 2) Compute KL[q(z|x,y) || p(z|y)] (for all possible y), and keep z's
      # around [n_y, B] and [n_y, B, n_z]
      kl_z_all, z_all = tf.map_fn(
          fn=lambda y: self._kl_and_z(hiddens, y),
          elems=all_y,
          dtype=(tf.float32, tf.float32),
          name='elbo_components_z_map')
      kl_z_all = tf.transpose(kl_z_all, name='kl_z_all')

      # Now take the expectation over y (scale by q(y|x))
      y_logits = q_y.logits[:, :self._n_y_active]  # [B, n_y]
      y_probs = q_y.probs[:, :self._n_y_active]  # [B, n_y]
      # Renormalise over the active components only.
      y_probs = y_probs / tf.reduce_sum(y_probs, axis=1, keepdims=True)
      kl_z = tf.reduce_sum(y_probs * kl_z_all, axis=1)

      # 3) Evaluate logp and recon, i.e., log and mean of p(x|z,[y])
      # (conditioning on y only in the `multi` decoder_type case, when
      # train_supervised is True). Here we take the reconstruction from each
      # possible component y and take its log prob. [n_y, B, Ix, Iy, Iz]
      log_p_x_all = tf.map_fn(
          fn=lambda val: self.predict(val[0], val[1]).log_prob(x),
          elems=(z_all, all_y),
          dtype=tf.float32,
          name='elbo_components_logpx_map')

      # Sum log probs over all dimensions apart from the first two (n_y, B),
      # i.e., over I. Use einsum to construct higher order multiplication.
      log_p_x_all = snt.BatchFlatten(preserve_dims=2)(log_p_x_all)  # [n_y,B,I]
      # Note, this is E_{q(y|x)} [ log p(x | z, y)], i.e., we scale log_p_x_all
      # by q(y|x).
      log_p_x = tf.einsum('ij,jik->ik', y_probs, log_p_x_all)  # [B, I]

      # We may also use a supervised loss for some samples [B, n_y]
      if y is not None:
        self.y_label = tf.one_hot(y, n_y)
      else:
        # Placeholder so the supervised terms can be fed labels at run time.
        self.y_label = tfc.placeholder(
            shape=[bs, n_y], dtype=tf.float32, name='y_label')

      # This is computing log p(x | z, y=true_y)], which is basically equivalent
      # to indexing into the correct element of `log_p_x_all`.
      log_p_x_sup = tf.einsum('ij,jik->ik',
                              self.y_label[:, :self._n_y_active],
                              log_p_x_all)  # [B, I]
      kl_z_sup = tf.einsum('ij,ij->i',
                           self.y_label[:, :self._n_y_active],
                           kl_z_all)  # [B]
      # -log q(y=y_true | x)
      kl_y_sup = tf.nn.sparse_softmax_cross_entropy_with_logits(  # [B]
          labels=tf.argmax(self.y_label[:, :self._n_y_active], axis=1),
          logits=y_logits)

      # Reduce over all dimension except batch.
      dims_x = [k for k in range(1, log_p_x.shape.ndims)]
      log_p_x = reduce_op(log_p_x, dims_x, name='log_p_x')
      log_p_x_sup = reduce_op(log_p_x_sup, dims_x, name='log_p_x_sup')

      # Store values needed externally
      self.q_y = q_y
      self.log_p_x_all = tf.transpose(
          reduce_op(
              log_p_x_all,
              -1,  # [B, n_y]
              name='log_p_x_all'))
      self.kl_z_all = kl_z_all
      self.y_probs = y_probs

    self._cache[cache_key] = (log_p_x, kl_y, kl_z, log_p_x_sup, kl_y_sup,
                              kl_z_sup)
    return log_p_x, kl_y, kl_z, log_p_x_sup, kl_y_sup, kl_z_sup

  def _kl_and_qy(self, hiddens):
    """Returns analytical or sampled KL div and the distribution q(y | x).

    Args:
      hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.

    Returns:
      Pair `(kl, q)`, where `kl` is the KL divergence (a `Tensor` with shape
      `[B]`, where `B` is the batch size), and `q` is the categorical encoding
      distribution q(y | x).
    """
    with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
      q = self.infer_cluster(hiddens)  # q(y|x)
    p = self.compute_prior()  # p(y)
    try:
      # Take the average proportions over whole batch then repeat it in each row
      # before computing the KL
      if self._kly_over_batch:
        probs = tf.reduce_mean(
            q.probs, axis=0, keepdims=True) * tf.ones_like(q.probs)
        qmean = tfp.distributions.OneHotCategorical(probs=probs)
        kl = tfp.distributions.kl_divergence(qmean, p)
      else:
        kl = tfp.distributions.kl_divergence(q, p)
    except NotImplementedError:
      # No registered analytical KL; fall back to a single-sample estimate.
      y = q.sample(name='y_sample')
      logging.warn('Using sampling KLD for y')
      log_p_y = p.log_prob(y, name='log_p_y')
      log_q_y = q.log_prob(y, name='log_q_y')

      # Reduce over all dimension except batch.
      sum_axis_p = [k for k in range(1, log_p_y.get_shape().ndims)]
      log_p_y = tf.reduce_sum(log_p_y, sum_axis_p)
      sum_axis_q = [k for k in range(1, log_q_y.get_shape().ndims)]
      log_q_y = tf.reduce_sum(log_q_y, sum_axis_q)

      kl = log_q_y - log_p_y

    # Reduce over all dimension except batch.
    sum_axis_kl = [k for k in range(1, kl.get_shape().ndims)]
    kl = tf.reduce_sum(kl, sum_axis_kl, name='kl')
    return kl, q

  def _kl_and_z(self, hiddens, y):
    """Returns KL[q(z|y,x) || p(z|y)] and a sample for z from q(z|y,x).

    Returns the analytical KL divergence KL[q(z|y,x) || p(z|y)] if one is
    available (as registered with `kullback_leibler.RegisterKL`), or a sampled
    KL divergence otherwise (in this case the returned sample is the one used
    for the KL divergence).

    Args:
      hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.
      y: Categorical cluster random variable, `Tensor` of size `[B, n_y]`.

    Returns:
      Pair `(kl, z)`, where `kl` is the KL divergence (a `Tensor` with shape
      `[B]`, where `B` is the batch size), and `z` is a sample from the encoding
      distribution.
    """
    with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
      q = self.infer_latent(hiddens, y)  # q(z|x,y)
    p = self.generate_latent(y)  # p(z|y)
    z = q.sample(name='z')
    try:
      kl = tfp.distributions.kl_divergence(q, p)
    except NotImplementedError:
      # No registered analytical KL; fall back to a single-sample estimate
      # using the same z that is returned to the caller.
      logging.warn('Using sampling KLD for z')
      log_p_z = p.log_prob(z, name='log_p_z_y')
      log_q_z = q.log_prob(z, name='log_q_z_xy')

      # Reduce over all dimension except batch.
      sum_axis_p = [k for k in range(1, log_p_z.get_shape().ndims)]
      log_p_z = tf.reduce_sum(log_p_z, sum_axis_p)
      sum_axis_q = [k for k in range(1, log_q_z.get_shape().ndims)]
      log_q_z = tf.reduce_sum(log_q_z, sum_axis_q)

      kl = log_q_z - log_p_z

    # Reduce over all dimension except batch.
    sum_axis_kl = [k for k in range(1, kl.get_shape().ndims)]
    kl = tf.reduce_sum(kl, sum_axis_kl, name='kl')
    return kl, z

  def infer_latent(self, hiddens, y=None, use_mean_y=False):
    """Performs inference over the latent variable z.

    Args:
      hiddens: The shared encoder activations, 4D `Tensor` of size `[B, ...]`.
      y: Categorical cluster variable, `Tensor` of size `[B, ...]`. If None,
        the mode of q(y|x) is used.
      use_mean_y: Boolean, whether to take the mean encoding over all y. If
        True, `y` must contain probabilities rather than a one-hot sample.

    Returns:
      The distribution `q(z|x, y)` (or, when `use_mean_y` is True, the
      y-weighted mean encoding tensor), which on sample produces tensors of
      size `[N, B, ...]` where `B` is the batch size of `x` and `y`, and `N` is
      the number of samples and `...` represents the shape of the latent
      variables.
    """
    with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
      if y is None:
        y = tf.to_float(self.infer_cluster(hiddens).mode())

    if use_mean_y:
      # If use_mean_y, then y must be probabilities
      all_y = tf.tile(
          tf.expand_dims(tf.one_hot(tf.range(y.shape[1]), y.shape[1]), axis=1),
          multiples=[1, y.shape[0], 1])

      # Compute z KL from x (for all possible y), and keep z's around
      z_all = tf.map_fn(
          fn=lambda y: self._latent_encoder(
              hiddens, y, is_training=self._is_training).mean(),
          elems=all_y,
          dtype=tf.float32)
      # Expectation of the mean encodings under the y probabilities.
      return tf.einsum('ij,jik->ik', y, z_all)
    else:
      return self._latent_encoder(hiddens, y, is_training=self._is_training)

  def generate_latent(self, y):
    """Use the generative model to compute latent variable z, given a y.

    Args:
      y: Categorical cluster variable, `Tensor` of size `[B, ...]`.

    Returns:
      The distribution `p(z|y)`, which on sample produces tensors of size
      `[N, B, ...]` where `B` is the batch size of `x`, and `N` is the number of
      samples asked and `...` represents the shape of the latent variables.
    """
    return self._latent_decoder(y, is_training=self._is_training)

  def get_shared_rep(self, x, is_training):
    """Gets the shared representation from a given input x.

    Args:
      x: Observed variables, `Tensor` of size `[B, I]` where `I` is the size of
        a flattened input.
      is_training: bool, whether this constitutes training data or not.

    Returns:
      The output of the shared encoder applied to `x`.
    """
    return self._shared_encoder(x, is_training)

  def infer_cluster(self, hiddens):
    """Performs inference over the categorical variable y.

    Args:
      hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`.

    Returns:
      The distribution `q(y|x)`, which on sample produces tensors of size
      `[N, B, ...]` where `B` is the batch size of `x`, and `N` is the number of
      samples asked and `...` represents the shape of the latent variables.
    """
    with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]):
      return self._cluster_encoder(hiddens, is_training=self._is_training)

  def predict(self, z, y):
    """Computes prediction over the observed variables.

    Args:
      z: Latent variables, `Tensor` of size `[B, ...]`.
      y: Categorical cluster variable, `Tensor` of size `[B, ...]`.

    Returns:
      The distribution `p(x|z)`, which on sample produces tensors of size
      `[N, B, ...]` where `N` is the number of samples asked.
    """
    # Pass the encoder's conv shapes through (if any) so a deconv decoder can
    # mirror the encoder; None for non-convolutional encoders.
    encoder_conv_shapes = getattr(self._shared_encoder, 'conv_shapes', None)
    return self._data_decoder(
        z,
        y,
        shared_encoder_conv_shapes=encoder_conv_shapes,
        is_training=self._is_training)

  def compute_prior(self):
    """Computes prior over the latent variables.

    Returns:
      The distribution `p(y)`, which on sample produces tensors of size
      `[N, ...]` where `N` is the number of samples asked and `...` represents
      the shape of the latent variables.
    """
    return self._prior()
|
|
|
|
|
|
class UpsampleModule(snt.AbstractModule):
  """Convolutional decoder.

  If `method` is 'deconv' apply transposed convolutions with stride 2,
  otherwise apply the `method` upsampling function and then smooth with a
  stride 1x1 convolution.

  Params:
  -------
  filters: list, where the first element is the number of filters of the initial
    MLP layer and the remaining elements are the number of filters of the
    upsampling layers.
  kernel_size: the size of the convolutional kernels. The same size will be
    used in all convolutions.
  activation: an activation function, applied to all layers but the last.
  dec_up_strides: list, the upsampling factors of each upsampling convolutional
    layer.
  enc_conv_shapes: list, the shapes of the input and of all the intermediate
    feature maps of the convolutional layers in the encoder.
  n_c: the number of output channels.
  """

  def __init__(self,
               filters,
               kernel_size,
               activation,
               dec_up_strides,
               enc_conv_shapes,
               n_c,
               method='nn',
               name='upsample_module'):
    super(UpsampleModule, self).__init__(name=name)

    # One extra filter count vs strides: the first entry feeds the initial MLP.
    assert len(filters) == len(dec_up_strides) + 1, (
        'The decoder\'s filters should contain one element more than the '
        'decoder\'s up stride list, but has %d elements instead of %d.\n'
        'Decoder filters: %s\nDecoder up strides: %s' %
        (len(filters), len(dec_up_strides) + 1, str(filters),
         str(dec_up_strides)))

    self._filters = filters
    self._kernel_size = kernel_size
    self._activation = activation

    self._dec_up_strides = dec_up_strides
    self._enc_conv_shapes = enc_conv_shapes
    self._n_c = n_c
    if method == 'deconv':
      self._conv_layer = tf.layers.Conv2DTranspose
      self._method = method
    else:
      # Resize-then-convolve: `method` must name a tf.image.ResizeMethod
      # (e.g. 'nn' -> NN, 'bilinear' -> BILINEAR).
      self._conv_layer = tf.layers.Conv2D
      self._method = getattr(tf.image.ResizeMethod, method.upper())
    self._method_str = method.capitalize()

  def _build(self, z, is_training=True, test_local_stats=True, use_bn=False):
    """Builds the decoder graph.

    Args:
      z: Latent code, 2D `Tensor` of size `[B, n_z]`.
      is_training: Boolean, used by batch norm (only when `use_bn` is True).
      test_local_stats: Boolean, whether batch norm uses local batch statistics
        at test time (only when `use_bn` is True).
      use_bn: Boolean, whether to apply batch norm after each layer (which also
        disables the layers' biases).

    Returns:
      The output logits, a 4D `Tensor` with `self._n_c` channels.
    """
    batch_norm_args = {
        'is_training': is_training,
        'test_local_stats': test_local_stats
    }

    method = self._method
    # Cycle over the encoder shapes backwards, to build a symmetrical decoder.
    enc_conv_shapes = self._enc_conv_shapes[::-1]
    strides = self._dec_up_strides
    # We store the heights and widths of the encoder feature maps that are
    # unique, i.e., the ones right after a layer with stride != 1. These will be
    # used as a target to potentially crop the upsampled feature maps.
    unique_hw = np.unique([(el[1], el[2]) for el in enc_conv_shapes], axis=0)
    unique_hw = unique_hw.tolist()[::-1]
    unique_hw.pop()  # Drop the initial shape

    # The first filter is an MLP.
    mlp_filter, conv_filters = self._filters[0], self._filters[1:]
    # The first shape is used after the MLP to go to 4D.

    # `layers` accumulates every intermediate tensor purely for shape logging.
    layers = [z]
    # The shape of the first enc is used after the MLP to go back to 4D.
    dec_mlp = snt.nets.MLP(
        name='dec_mlp_projection',
        output_sizes=[mlp_filter, np.prod(enc_conv_shapes[0][1:])],
        use_bias=not use_bn,
        activation=self._activation,
        activate_final=True)

    upsample_mlp_flat = dec_mlp(z)
    if use_bn:
      upsample_mlp_flat = snt.BatchNorm(scale=True)(upsample_mlp_flat,
                                                    **batch_norm_args)
    layers.append(upsample_mlp_flat)
    upsample = tf.reshape(upsample_mlp_flat, enc_conv_shapes[0])
    layers.append(upsample)

    for i, (filter_i, stride_i) in enumerate(zip(conv_filters, strides), 1):
      # Non-deconv methods first resize the feature map, then smooth it with a
      # stride-1 convolution; deconv folds the upsampling into the conv stride.
      if method != 'deconv' and stride_i > 1:
        upsample = tf.image.resize_images(
            upsample, [stride_i * el for el in upsample.shape.as_list()[1:3]],
            method=method,
            name='upsample_' + str(i))
      upsample = self._conv_layer(
          filters=filter_i,
          kernel_size=self._kernel_size,
          padding='same',
          use_bias=not use_bn,
          activation=self._activation,
          strides=stride_i if method == 'deconv' else 1,
          name='upsample_conv_' + str(i))(
              upsample)
      if use_bn:
        upsample = snt.BatchNorm(scale=True)(upsample, **batch_norm_args)
      if stride_i > 1:
        # Crop (if needed) to exactly match the corresponding encoder shape.
        hw = unique_hw.pop()
        upsample = utils.maybe_center_crop(upsample, hw)
      layers.append(upsample)

    # Final layer, no upsampling.
    x_logits = tf.layers.Conv2D(
        filters=self._n_c,
        kernel_size=self._kernel_size,
        padding='same',
        use_bias=not use_bn,
        activation=None,
        strides=1,
        name='logits')(
            upsample)
    if use_bn:
      x_logits = snt.BatchNorm(scale=True)(x_logits, **batch_norm_args)
    layers.append(x_logits)

    logging.info('%s upsampling module layer shapes', self._method_str)
    logging.info('\n'.join([str(v.shape.as_list()) for v in layers]))

    return x_logits
|