From 656f93a37efd545d5a5d9b27e6f16687221ef5aa Mon Sep 17 00:00:00 2001 From: Dushyant Rao Date: Thu, 7 Nov 2019 17:17:54 +0000 Subject: [PATCH] Apply minor fixes to README PiperOrigin-RevId: 279101132 --- curl/README.md | 35 ++ curl/layers.py | 120 +++++ curl/model.py | 797 ++++++++++++++++++++++++++++ curl/requirements.txt | 10 + curl/train_main.py | 79 +++ curl/train_sup.py | 64 +++ curl/train_unsup.py | 73 +++ curl/training.py | 1169 +++++++++++++++++++++++++++++++++++++++++ curl/unit_test.py | 67 +++ curl/utils.py | 85 +++ 10 files changed, 2499 insertions(+) create mode 100644 curl/README.md create mode 100644 curl/layers.py create mode 100644 curl/model.py create mode 100644 curl/requirements.txt create mode 100644 curl/train_main.py create mode 100644 curl/train_sup.py create mode 100644 curl/train_unsup.py create mode 100644 curl/training.py create mode 100644 curl/unit_test.py create mode 100644 curl/utils.py diff --git a/curl/README.md b/curl/README.md new file mode 100644 index 0000000..dde40d2 --- /dev/null +++ b/curl/README.md @@ -0,0 +1,35 @@ +# Continual Unsupervised Representation Learning (CURL) + +This repository contains code to accompany the NeurIPS 2019 submission on +Continual Unsupervised Representation Learning (CURL). + +The experiments in the paper can be reproduced by running one of the three +different training scripts: + + +`train_sup.py`: to run the supervised continual learning benchmark + +`train_unsup.py`: to run the unsupervised i.i.d learning benchmark + +`train_main.py`: to run all other experiments in the paper (with details in the +file on what to change) + +In each of these cases, the cluster accuracy / purity and k-NN error are logged +to the terminal, and other quantities can be accessed from training.py +(e.g. the confusion matrix can be found in `results['test_confusion']`). 
+ +We recommend running these scripts in a Python +[virtual environment](https://docs.python.org/3/tutorial/venv.html): + +(Assuming python3-dev is installed in your system) + +```console +python3 -m venv .curl_venv +source .curl_venv/bin/activate +pip install wheel +pip install -r requirements.txt + +PYTHONPATH=`pwd`/..:$PYTHONPATH python3 train_main.py --dataset='mnist' + +Run `deactivate` to exit the virtual environment. +``` diff --git a/curl/layers.py b/curl/layers.py new file mode 100644 index 0000000..2b13e57 --- /dev/null +++ b/curl/layers.py @@ -0,0 +1,120 @@ +################################################################################ +# Copyright 2019 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
################################################################################
"""Custom layers for CURL."""

from absl import logging
import sonnet as snt
import tensorflow as tf

tfc = tf.compat.v1


class ResidualStack(snt.AbstractModule):
  """A stack of ResNet V2 blocks applied to a 4D feature map."""

  def __init__(self,
               num_hiddens,
               num_residual_layers,
               num_residual_hiddens,
               filter_size=3,
               initializers=None,
               data_format='NHWC',
               activation=tf.nn.relu,
               name='residual_stack'):
    """Instantiate a ResidualStack.

    Args:
      num_hiddens: int, output channels of each block's final 1x1 convolution
        (must match the channels of the input `h` for the residual add).
      num_residual_layers: int, number of residual blocks in the stack.
      num_residual_hiddens: int, channels of each block's nxn convolution.
      filter_size: int, spatial size of the first convolution in each block.
      initializers: optional dict of initializers forwarded to `snt.Conv2D`.
      data_format: str, 'NHWC' (default) or 'NCHW'.
      activation: callable, the pre-activation nonlinearity (ResNet V2 order:
        activation is applied *before* each convolution).
      name: str, module name used for the tf variable scope.
    """
    super(ResidualStack, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens
    self._filter_size = filter_size
    self._initializers = initializers
    self._data_format = data_format
    self._activation = activation

  def _build(self, h):
    """Apply the residual blocks to `h` and return the activated result."""
    for i in range(self._num_residual_layers):
      # ResNet V2: pre-activate, nxn conv, activate, 1x1 conv, then add.
      h_i = self._activation(h)

      h_i = snt.Conv2D(
          output_channels=self._num_residual_hiddens,
          kernel_shape=(self._filter_size, self._filter_size),
          stride=(1, 1),
          initializers=self._initializers,
          data_format=self._data_format,
          name='res_nxn_%d' % i)(
              h_i)
      h_i = self._activation(h_i)

      h_i = snt.Conv2D(
          output_channels=self._num_hiddens,
          kernel_shape=(1, 1),
          stride=(1, 1),
          initializers=self._initializers,
          data_format=self._data_format,
          name='res_1x1_%d' % i)(
              h_i)
      h += h_i
    # Final activation on the accumulated residual sum.
    return self._activation(h)


class SharedConvModule(snt.AbstractModule):
  """Convolutional encoder producing a flat shared representation.

  NOTE(review): the original docstring said "Convolutional decoder", but this
  module is the shared *encoder* (scope name 'shared_conv_encoder'); corrected.
  """

  def __init__(self,
               filters,
               kernel_size,
               activation,
               strides,
               name='shared_conv_encoder'):
    """Instantiate a SharedConvModule.

    Args:
      filters: list of ints; all but the last are conv output channels, the
        last element sizes the final MLP layer.
      kernel_size: int, spatial size of every convolution kernel.
      activation: callable, nonlinearity used after each layer.
      strides: list of ints, one stride per conv layer; must have exactly one
        element fewer than `filters`.
      name: str, module name used for the tf variable scope.
    """
    super(SharedConvModule, self).__init__(name=name)

    self._filters = filters
    self._kernel_size = kernel_size
    self._activation = activation
    self.strides = strides
    # Validate eagerly with a real exception (an `assert` here would be
    # silently stripped when Python runs with -O).
    if len(strides) != len(filters) - 1:
      raise ValueError(
          'Expected exactly one more filter than strides (the final filter '
          'sizes the output MLP), but got %d filters and %d strides.' %
          (len(filters), len(strides)))
    self.conv_shapes = None

  def _build(self, x, is_training=True):
    """Encode a 4D batch `x` into a 2D hidden representation.

    Args:
      x: 4D input `Tensor` [B, H, W, C].
      is_training: unused; kept for interface parity with other encoders.

    Returns:
      2D `Tensor` [B, filters[-1]] of encoded hiddens.
    """
    with tf.control_dependencies([tfc.assert_rank(x, 4)]):
      self.conv_shapes = [x.shape.as_list()]  # Needed by deconv module
      conv = x
      # zip() stops at len(strides), so the last entry of self._filters is
      # deliberately left for the MLP below.
      for i, (filter_i,
              stride_i) in enumerate(zip(self._filters, self.strides), 1):
        conv = tf.layers.Conv2D(
            filters=filter_i,
            kernel_size=self._kernel_size,
            padding='same',
            activation=self._activation,
            strides=stride_i,
            name='enc_conv_%d' % i)(
                conv)
        self.conv_shapes.append(conv.shape.as_list())
      conv_flat = snt.BatchFlatten()(conv)

      enc_mlp = snt.nets.MLP(
          name='enc_mlp',
          output_sizes=[self._filters[-1]],
          activation=self._activation,
          activate_final=True)
      h = enc_mlp(conv_flat)

      logging.info('Shared conv module layer shapes:')
      logging.info('\n'.join([str(el) for el in self.conv_shapes]))
      logging.info(h.shape.as_list())

    return h
+################################################################################ +"""Implementation of Continual Unsupervised Representation Learning model.""" + +from absl import logging +import numpy as np +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +from curl import layers +from curl import utils + +tfc = tf.compat.v1 + +# pylint: disable=g-long-lambda +# pylint: disable=redefined-outer-name + + +class SharedEncoder(snt.AbstractModule): + """The shared encoder module, mapping input x to hiddens.""" + + def __init__(self, encoder_type, n_enc, enc_strides, name='shared_encoder'): + """The shared encoder function, mapping input x to hiddens. + + Args: + encoder_type: str, type of encoder, either 'conv' or 'multi' + n_enc: list, number of hidden units per layer in the encoder + enc_strides: list, stride in each layer (only for 'conv' encoder_type) + name: str, module name used for tf scope. + """ + super(SharedEncoder, self).__init__(name=name) + self._encoder_type = encoder_type + + if encoder_type == 'conv': + self.shared_encoder = layers.SharedConvModule( + filters=n_enc, + strides=enc_strides, + kernel_size=3, + activation=tf.nn.relu) + elif encoder_type == 'multi': + self.shared_encoder = snt.nets.MLP( + name='mlp_shared_encoder', + output_sizes=n_enc, + activation=tf.nn.relu, + activate_final=True) + else: + raise ValueError('Unknown encoder_type {}'.format(encoder_type)) + + def _build(self, x, is_training): + if self._encoder_type == 'multi': + self.conv_shapes = None + x = snt.BatchFlatten()(x) + return self.shared_encoder(x) + else: + output = self.shared_encoder(x) + self.conv_shapes = self.shared_encoder.conv_shapes + return output + + +def cluster_encoder_fn(hiddens, n_y_active, n_y, is_training=True): + """The cluster encoder function, modelling q(y | x). + + Args: + hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`. + n_y_active: Tensor, the number of active components. 
+ n_y: int, number of maximum components allowed (used for tensor size) + is_training: Boolean, whether to build the training graph or an evaluation + graph. + + Returns: + The distribution `q(y | x)`. + """ + del is_training # unused for now + with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]): + lin = snt.Linear(n_y, name='mlp_cluster_encoder_final') + logits = lin(hiddens) + + # Only use the first n_y_active components, and set the remaining to zero. + if n_y > 1: + probs = tf.nn.softmax(logits[:, :n_y_active]) + logging.info('Cluster softmax active probs shape: %s', str(probs.shape)) + paddings1 = tf.stack([tf.constant(0), tf.constant(0)], axis=0) + paddings2 = tf.stack([tf.constant(0), n_y - n_y_active], axis=0) + paddings = tf.stack([paddings1, paddings2], axis=1) + probs = tf.pad(probs, paddings) + 0.0 * logits + 1e-12 + else: + probs = tf.ones_like(logits) + logging.info('Cluster softmax probs shape: %s', str(probs.shape)) + + return tfp.distributions.OneHotCategorical(probs=probs) + + +def latent_encoder_fn(hiddens, y, n_y, n_z, is_training=True): + """The latent encoder function, modelling q(z | x, y). + + Args: + hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`. + y: Categorical cluster variable, `Tensor` of size `[B, n_y]`. + n_y: int, number of dims of y. + n_z: int, number of dims of z. + is_training: Boolean, whether to build the training graph or an evaluation + graph. + + Returns: + The Gaussian distribution `q(z | x, y)`. + """ + del is_training # unused for now + + with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]): + # Logits for both mean and variance + n_logits = 2 * n_z + + all_logits = [] + for k in range(n_y): + lin = snt.Linear(n_logits, name='mlp_latent_encoder_' + str(k)) + all_logits.append(lin(hiddens)) + + # Sum over cluster components. + all_logits = tf.stack(all_logits) # [n_y, B, n_logits] + logits = tf.einsum('ij,jik->ik', y, all_logits) + + # Compute distribution from logits. 
+ return utils.generate_gaussian( + logits=logits, sigma_nonlin='softplus', sigma_param='var') + + +def data_decoder_fn(z, + y, + output_type, + output_shape, + decoder_type, + n_dec, + dec_up_strides, + n_x, + n_y, + shared_encoder_conv_shapes=None, + is_training=True, + test_local_stats=True): + """The data decoder function, modelling p(x | z). + + Args: + z: Latent variables, `Tensor` of size `[B, n_z]`. + y: Categorical cluster variable, `Tensor` of size `[B, n_y]`. + output_type: str, output distribution ('bernoulli' or 'quantized_normal'). + output_shape: list, shape of output (not including batch dimension). + decoder_type: str, 'single', 'multi', or 'deconv'. + n_dec: list, number of hidden units per layer in the decoder + dec_up_strides: list, stride in each layer (only for 'deconv' decoder_type). + n_x: int, number of dims of x. + n_y: int, number of dims of y. + shared_encoder_conv_shapes: the shapes of the activations of the + intermediate layers of the encoder, + is_training: Boolean, whether to build the training graph or an evaluation + graph. + test_local_stats: Boolean, whether to use the test batch statistics at test + time for batch norm (default) or the moving averages. + + Returns: + The Bernoulli distribution `p(x | z)`. + """ + + if output_type == 'bernoulli': + output_dist = lambda x: tfp.distributions.Bernoulli(logits=x) + n_out_factor = 1 + out_shape = list(output_shape) + else: + raise NotImplementedError + if len(z.shape) != 2: + raise NotImplementedError('The data decoder function expects `z` to be ' + '2D, but its shape was %s instead.' % + str(z.shape)) + if len(y.shape) != 2: + raise NotImplementedError('The data decoder function expects `y` to be ' + '2D, but its shape was %s instead.' % + str(y.shape)) + + # Upsample layer (deconvolutional, bilinear, ..). 
+ if decoder_type == 'deconv': + + # First, check that the encoder is convolutional too (needed for batchnorm) + if shared_encoder_conv_shapes is None: + raise ValueError('Shared encoder does not contain conv_shapes.') + + num_output_channels = output_shape[-1] + conv_decoder = UpsampleModule( + filters=n_dec, + kernel_size=3, + activation=tf.nn.relu, + dec_up_strides=dec_up_strides, + enc_conv_shapes=shared_encoder_conv_shapes, + n_c=num_output_channels * n_out_factor, + method=decoder_type) + logits = conv_decoder( + z, is_training=is_training, test_local_stats=test_local_stats) + logits = tf.reshape(logits, [-1] + out_shape) # n_out_factor in last dim + + # Multiple MLP decoders, one for each component. + elif decoder_type == 'multi': + all_logits = [] + for k in range(n_y): + mlp_decoding = snt.nets.MLP( + name='mlp_latent_decoder_' + str(k), + output_sizes=n_dec + [n_x * n_out_factor], + activation=tf.nn.relu, + activate_final=False) + logits = mlp_decoding(z) + all_logits.append(logits) + + all_logits = tf.stack(all_logits) + logits = tf.einsum('ij,jik->ik', y, all_logits) + logits = tf.reshape(logits, [-1] + out_shape) # Back to 4D + + # Single (shared among components) MLP decoder. + elif decoder_type == 'single': + mlp_decoding = snt.nets.MLP( + name='mlp_latent_decoder', + output_sizes=n_dec + [n_x * n_out_factor], + activation=tf.nn.relu, + activate_final=False) + logits = mlp_decoding(z) + logits = tf.reshape(logits, [-1] + out_shape) # Back to 4D + else: + raise ValueError('Unknown decoder_type {}'.format(decoder_type)) + + return output_dist(logits) + + +def latent_decoder_fn(y, n_z, is_training=True): + """The latent decoder function, modelling p(z | y). + + Args: + y: Categorical cluster variable, `Tensor` of size `[B, n_y]`. + n_z: int, number of dims of z. + is_training: Boolean, whether to build the training graph or an evaluation + graph. + + Returns: + The Gaussian distribution `p(z | y)`. + """ + del is_training # Unused for now. 
+ if len(y.shape) != 2: + raise NotImplementedError('The latent decoder function expects `y` to be ' + '2D, but its shape was %s instead.' % + str(y.shape)) + + lin_mu = snt.Linear(n_z, name='latent_prior_mu') + lin_sigma = snt.Linear(n_z, name='latent_prior_sigma') + + mu = lin_mu(y) + sigma = lin_sigma(y) + + logits = tf.concat([mu, sigma], axis=1) + + return utils.generate_gaussian( + logits=logits, sigma_nonlin='softplus', sigma_param='var') + + +class Curl(object): + """CURL model class.""" + + def __init__(self, + prior, + latent_decoder, + data_decoder, + shared_encoder, + cluster_encoder, + latent_encoder, + n_y_active, + kly_over_batch=False, + is_training=True, + name='curl'): + self.scope_name = name + self._shared_encoder = shared_encoder + self._prior = prior + self._latent_decoder = latent_decoder + self._data_decoder = data_decoder + self._cluster_encoder = cluster_encoder + self._latent_encoder = latent_encoder + self._n_y_active = n_y_active + self._kly_over_batch = kly_over_batch + self._is_training = is_training + self._cache = {} + + def sample(self, sample_shape=(), y=None, mean=False): + """Draws a sample from the learnt distribution p(x). + + Args: + sample_shape: `int` or 0D `Tensor` giving the number of samples to return. + If empty tuple (default value), 1 sample will be returned. + y: Optional, the one hot label on which to condition the sample. + mean: Boolean, if True the expected value of the output distribution is + returned, otherwise samples from the output distribution. + + Returns: + Sample tensor of shape `[B * N, ...]` where `B` is the batch size of + the prior, `N` is the number of samples requested, and `...` represents + the shape of the observations. + + Raises: + ValueError: If both `sample_shape` and `n` are provided. + ValueError: If `sample_shape` has rank > 0 or if `sample_shape` + is an int that is < 1. 
+ """ + with tf.name_scope('{}_sample'.format(self.scope_name)): + if y is None: + y = tf.to_float(self.compute_prior().sample(sample_shape)) + + if y.shape.ndims > 2: + y = snt.MergeDims(start=0, size=y.shape.ndims - 1, name='merge_y')(y) + + z = self._latent_decoder(y, is_training=self._is_training) + if mean: + samples = self.predict(z.sample(), y).mean() + else: + samples = self.predict(z.sample(), y).sample() + return samples + + def reconstruct(self, x, use_mode=True, use_mean=False): + """Reconstructs the given observations. + + Args: + x: Observed `Tensor`. + use_mode: Boolean, if true, take the argmax over q(y|x) + use_mean: Boolean, if true, use pixel-mean for reconstructions. + + Returns: + The reconstructed samples x ~ p(x | y~q(y|x), z~q(z|x, y)). + """ + + hiddens = self._shared_encoder(x, is_training=self._is_training) + qy = self.infer_cluster(hiddens) + y_sample = qy.mode() if use_mode else qy.sample() + y_sample = tf.to_float(y_sample) + qz = self.infer_latent(hiddens, y_sample) + p = self.predict(qz.sample(), y_sample) + + if use_mean: + return p.mean() + else: + return p.sample() + + def log_prob(self, x): + """Redirects to log_prob_elbo with a warning.""" + logging.warn('log_prob is actually a lower bound') + return self.log_prob_elbo(x) + + def log_prob_elbo(self, x): + """Returns evidence lower bound.""" + log_p_x, kl_y, kl_z = self.log_prob_elbo_components(x)[:3] + return log_p_x - kl_y - kl_z + + def log_prob_elbo_components(self, x, y=None, reduce_op=tf.reduce_sum): + """Returns the components used in calculating the evidence lower bound. + + Args: + x: Observed variables, `Tensor` of size `[B, I]` where `I` is the size of + a flattened input. + y: Optional labels, `Tensor` of size `[B, I]` where `I` is the size of a + flattened input. + reduce_op: The op to use for reducing across non-batch dimensions. + Typically either `tf.reduce_sum` or `tf.reduce_mean`. + + Returns: + `log p(x|y,z)` of shape `[B]` where `B` is the batch size. 
+ `KL[q(y|x) || p(y)]` of shape `[B]` where `B` is the batch size. + `KL[q(z|x,y) || p(z|y)]` of shape `[B]` where `B` is the batch size. + """ + cache_key = (x,) + + # Checks if the output graph for this inputs has already been computed. + if cache_key in self._cache: + return self._cache[cache_key] + + with tf.name_scope('{}_log_prob_elbo'.format(self.scope_name)): + + hiddens = self._shared_encoder(x, is_training=self._is_training) + # 1) Compute KL[q(y|x) || p(y)] from x, and keep distribution q_y around + kl_y, q_y = self._kl_and_qy(hiddens) # [B], distribution + + # For the next two terms, we need to marginalise over all y. + + # First, construct every possible y indexing (as a one hot) and repeat it + # for every element in the batch [n_y_active, B, n_y]. + # Note that the onehot have dimension of all y, while only the codes + # corresponding to active components are instantiated + bs, n_y = q_y.probs.shape + all_y = tf.tile( + tf.expand_dims(tf.one_hot(tf.range(self._n_y_active), + n_y), axis=1), + multiples=[1, bs, 1]) + + # 2) Compute KL[q(z|x,y) || p(z|y)] (for all possible y), and keep z's + # around [n_y, B] and [n_y, B, n_z] + kl_z_all, z_all = tf.map_fn( + fn=lambda y: self._kl_and_z(hiddens, y), + elems=all_y, + dtype=(tf.float32, tf.float32), + name='elbo_components_z_map') + kl_z_all = tf.transpose(kl_z_all, name='kl_z_all') + + # Now take the expectation over y (scale by q(y|x)) + y_logits = q_y.logits[:, :self._n_y_active] # [B, n_y] + y_probs = q_y.probs[:, :self._n_y_active] # [B, n_y] + y_probs = y_probs / tf.reduce_sum(y_probs, axis=1, keepdims=True) + kl_z = tf.reduce_sum(y_probs * kl_z_all, axis=1) + + # 3) Evaluate logp and recon, i.e., log and mean of p(x|z,[y]) + # (conditioning on y only in the `multi` decoder_type case, when + # train_supervised is True). Here we take the reconstruction from each + # possible component y and take its log prob. 
[n_y, B, Ix, Iy, Iz] + log_p_x_all = tf.map_fn( + fn=lambda val: self.predict(val[0], val[1]).log_prob(x), + elems=(z_all, all_y), + dtype=tf.float32, + name='elbo_components_logpx_map') + + # Sum log probs over all dimensions apart from the first two (n_y, B), + # i.e., over I. Use einsum to construct higher order multiplication. + log_p_x_all = snt.BatchFlatten(preserve_dims=2)(log_p_x_all) # [n_y,B,I] + # Note, this is E_{q(y|x)} [ log p(x | z, y)], i.e., we scale log_p_x_all + # by q(y|x). + log_p_x = tf.einsum('ij,jik->ik', y_probs, log_p_x_all) # [B, I] + + # We may also use a supervised loss for some samples [B, n_y] + if y is not None: + self.y_label = tf.one_hot(y, n_y) + else: + self.y_label = tfc.placeholder( + shape=[bs, n_y], dtype=tf.float32, name='y_label') + + # This is computing log p(x | z, y=true_y)], which is basically equivalent + # to indexing into the correct element of `log_p_x_all`. + log_p_x_sup = tf.einsum('ij,jik->ik', + self.y_label[:, :self._n_y_active], + log_p_x_all) # [B, I] + kl_z_sup = tf.einsum('ij,ij->i', + self.y_label[:, :self._n_y_active], + kl_z_all) # [B] + # -log q(y=y_true | x) + kl_y_sup = tf.nn.sparse_softmax_cross_entropy_with_logits( # [B] + labels=tf.argmax(self.y_label[:, :self._n_y_active], axis=1), + logits=y_logits) + + # Reduce over all dimension except batch. + dims_x = [k for k in range(1, log_p_x.shape.ndims)] + log_p_x = reduce_op(log_p_x, dims_x, name='log_p_x') + log_p_x_sup = reduce_op(log_p_x_sup, dims_x, name='log_p_x_sup') + + # Store values needed externally + self.q_y = q_y + self.log_p_x_all = tf.transpose( + reduce_op( + log_p_x_all, + -1, # [B, n_y] + name='log_p_x_all')) + self.kl_z_all = kl_z_all + self.y_probs = y_probs + + self._cache[cache_key] = (log_p_x, kl_y, kl_z, log_p_x_sup, kl_y_sup, + kl_z_sup) + return log_p_x, kl_y, kl_z, log_p_x_sup, kl_y_sup, kl_z_sup + + def _kl_and_qy(self, hiddens): + """Returns analytical or sampled KL div and the distribution q(y | x). 
+ + Args: + hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`. + + Returns: + Pair `(kl, y)`, where `kl` is the KL divergence (a `Tensor` with shape + `[B]`, where `B` is the batch size), and `y` is a sample from the + categorical encoding distribution. + """ + with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]): + q = self.infer_cluster(hiddens) # q(y|x) + p = self.compute_prior() # p(y) + try: + # Take the average proportions over whole batch then repeat it in each row + # before computing the KL + if self._kly_over_batch: + probs = tf.reduce_mean( + q.probs, axis=0, keepdims=True) * tf.ones_like(q.probs) + qmean = tfp.distributions.OneHotCategorical(probs=probs) + kl = tfp.distributions.kl_divergence(qmean, p) + else: + kl = tfp.distributions.kl_divergence(q, p) + except NotImplementedError: + y = q.sample(name='y_sample') + logging.warn('Using sampling KLD for y') + log_p_y = p.log_prob(y, name='log_p_y') + log_q_y = q.log_prob(y, name='log_q_y') + + # Reduce over all dimension except batch. + sum_axis_p = [k for k in range(1, log_p_y.get_shape().ndims)] + log_p_y = tf.reduce_sum(log_p_y, sum_axis_p) + sum_axis_q = [k for k in range(1, log_q_y.get_shape().ndims)] + log_q_y = tf.reduce_sum(log_q_y, sum_axis_q) + + kl = log_q_y - log_p_y + + # Reduce over all dimension except batch. + sum_axis_kl = [k for k in range(1, kl.get_shape().ndims)] + kl = tf.reduce_sum(kl, sum_axis_kl, name='kl') + return kl, q + + def _kl_and_z(self, hiddens, y): + """Returns KL[q(z|y,x) || p(z|y)] and a sample for z from q(z|y,x). + + Returns the analytical KL divergence KL[q(z|y,x) || p(z|y)] if one is + available (as registered with `kullback_leibler.RegisterKL`), or a sampled + KL divergence otherwise (in this case the returned sample is the one used + for the KL divergence). + + Args: + hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`. + y: Categorical cluster random variable, `Tensor` of size `[B, n_y]`. 
+ + Returns: + Pair `(kl, z)`, where `kl` is the KL divergence (a `Tensor` with shape + `[B]`, where `B` is the batch size), and `z` is a sample from the encoding + distribution. + """ + with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]): + q = self.infer_latent(hiddens, y) # q(z|x,y) + p = self.generate_latent(y) # p(z|y) + z = q.sample(name='z') + try: + kl = tfp.distributions.kl_divergence(q, p) + except NotImplementedError: + logging.warn('Using sampling KLD for z') + log_p_z = p.log_prob(z, name='log_p_z_y') + log_q_z = q.log_prob(z, name='log_q_z_xy') + + # Reduce over all dimension except batch. + sum_axis_p = [k for k in range(1, log_p_z.get_shape().ndims)] + log_p_z = tf.reduce_sum(log_p_z, sum_axis_p) + sum_axis_q = [k for k in range(1, log_q_z.get_shape().ndims)] + log_q_z = tf.reduce_sum(log_q_z, sum_axis_q) + + kl = log_q_z - log_p_z + + # Reduce over all dimension except batch. + sum_axis_kl = [k for k in range(1, kl.get_shape().ndims)] + kl = tf.reduce_sum(kl, sum_axis_kl, name='kl') + return kl, z + + def infer_latent(self, hiddens, y=None, use_mean_y=False): + """Performs inference over the latent variable z. + + Args: + hiddens: The shared encoder activations, 4D `Tensor` of size `[B, ...]`. + y: Categorical cluster variable, `Tensor` of size `[B, ...]`. + use_mean_y: Boolean, whether to take the mean encoding over all y. + + Returns: + The distribution `q(z|x, y)`, which on sample produces tensors of size + `[N, B, ...]` where `B` is the batch size of `x` and `y`, and `N` is the + number of samples and `...` represents the shape of the latent variables. 
+ """ + with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]): + if y is None: + y = tf.to_float(self.infer_cluster(hiddens).mode()) + + if use_mean_y: + # If use_mean_y, then y must be probabilities + all_y = tf.tile( + tf.expand_dims(tf.one_hot(tf.range(y.shape[1]), y.shape[1]), axis=1), + multiples=[1, y.shape[0], 1]) + + # Compute z KL from x (for all possible y), and keep z's around + z_all = tf.map_fn( + fn=lambda y: self._latent_encoder( + hiddens, y, is_training=self._is_training).mean(), + elems=all_y, + dtype=tf.float32) + return tf.einsum('ij,jik->ik', y, z_all) + else: + return self._latent_encoder(hiddens, y, is_training=self._is_training) + + def generate_latent(self, y): + """Use the generative model to compute latent variable z, given a y. + + Args: + y: Categorical cluster variable, `Tensor` of size `[B, ...]`. + + Returns: + The distribution `p(z|y)`, which on sample produces tensors of size + `[N, B, ...]` where `B` is the batch size of `x`, and `N` is the number of + samples asked and `...` represents the shape of the latent variables. + """ + return self._latent_decoder(y, is_training=self._is_training) + + def get_shared_rep(self, x, is_training): + """Gets the shared representation from a given input x. + + Args: + x: Observed variables, `Tensor` of size `[B, I]` where `I` is the size of + a flattened input. + is_training: bool, whether this constitutes training data or not. + + Returns: + `log p(x|y,z)` of shape `[B]` where `B` is the batch size. + `KL[q(y|x) || p(y)]` of shape `[B]` where `B` is the batch size. + `KL[q(z|x,y) || p(z|y)]` of shape `[B]` where `B` is the batch size. + """ + return self._shared_encoder(x, is_training) + + def infer_cluster(self, hiddens): + """Performs inference over the categorical variable y. + + Args: + hiddens: The shared encoder activations, 2D `Tensor` of size `[B, ...]`. 
+ + Returns: + The distribution `q(y|x)`, which on sample produces tensors of size + `[N, B, ...]` where `B` is the batch size of `x`, and `N` is the number of + samples asked and `...` represents the shape of the latent variables. + """ + with tf.control_dependencies([tfc.assert_rank(hiddens, 2)]): + return self._cluster_encoder(hiddens, is_training=self._is_training) + + def predict(self, z, y): + """Computes prediction over the observed variables. + + Args: + z: Latent variables, `Tensor` of size `[B, ...]`. + y: Categorical cluster variable, `Tensor` of size `[B, ...]`. + + Returns: + The distribution `p(x|z)`, which on sample produces tensors of size + `[N, B, ...]` where `N` is the number of samples asked. + """ + encoder_conv_shapes = getattr(self._shared_encoder, 'conv_shapes', None) + return self._data_decoder( + z, + y, + shared_encoder_conv_shapes=encoder_conv_shapes, + is_training=self._is_training) + + def compute_prior(self): + """Computes prior over the latent variables. + + Returns: + The distribution `p(y)`, which on sample produces tensors of size + `[N, ...]` where `N` is the number of samples asked and `...` represents + the shape of the latent variables. + """ + return self._prior() + + +class UpsampleModule(snt.AbstractModule): + """Convolutional decoder. + + If `method` is 'deconv' apply transposed convolutions with stride 2, + otherwise apply the `method` upsampling function and then smooth with a + stride 1x1 convolution. + + Params: + ------- + filters: list, where the first element is the number of filters of the initial + MLP layer and the remaining elements are the number of filters of the + upsampling layers. + kernel_size: the size of the convolutional kernels. The same size will be + used in all convolutions. + activation: an activation function, applied to all layers but the last. + dec_up_strides: list, the upsampling factors of each upsampling convolutional + layer. 
+ enc_conv_shapes: list, the shapes of the input and of all the intermediate + feature maps of the convolutional layers in the encoder. + n_c: the number of output channels. + """ + + def __init__(self, + filters, + kernel_size, + activation, + dec_up_strides, + enc_conv_shapes, + n_c, + method='nn', + name='upsample_module'): + super(UpsampleModule, self).__init__(name=name) + + assert len(filters) == len(dec_up_strides) + 1, ( + 'The decoder\'s filters should contain one element more than the ' + 'decoder\'s up stride list, but has %d elements instead of %d.\n' + 'Decoder filters: %s\nDecoder up strides: %s' % + (len(filters), len(dec_up_strides) + 1, str(filters), + str(dec_up_strides))) + + self._filters = filters + self._kernel_size = kernel_size + self._activation = activation + + self._dec_up_strides = dec_up_strides + self._enc_conv_shapes = enc_conv_shapes + self._n_c = n_c + if method == 'deconv': + self._conv_layer = tf.layers.Conv2DTranspose + self._method = method + else: + self._conv_layer = tf.layers.Conv2D + self._method = getattr(tf.image.ResizeMethod, method.upper()) + self._method_str = method.capitalize() + + def _build(self, z, is_training=True, test_local_stats=True, use_bn=False): + batch_norm_args = { + 'is_training': is_training, + 'test_local_stats': test_local_stats + } + + method = self._method + # Cycle over the encoder shapes backwards, to build a symmetrical decoder. + enc_conv_shapes = self._enc_conv_shapes[::-1] + strides = self._dec_up_strides + # We store the heights and widths of the encoder feature maps that are + # unique, i.e., the ones right after a layer with stride != 1. These will be + # used as a target to potentially crop the upsampled feature maps. + unique_hw = np.unique([(el[1], el[2]) for el in enc_conv_shapes], axis=0) + unique_hw = unique_hw.tolist()[::-1] + unique_hw.pop() # Drop the initial shape + + # The first filter is an MLP. 
+ mlp_filter, conv_filters = self._filters[0], self._filters[1:] + # The first shape is used after the MLP to go to 4D. + + layers = [z] + # The shape of the first enc is used after the MLP to go back to 4D. + dec_mlp = snt.nets.MLP( + name='dec_mlp_projection', + output_sizes=[mlp_filter, np.prod(enc_conv_shapes[0][1:])], + use_bias=not use_bn, + activation=self._activation, + activate_final=True) + + upsample_mlp_flat = dec_mlp(z) + if use_bn: + upsample_mlp_flat = snt.BatchNorm(scale=True)(upsample_mlp_flat, + **batch_norm_args) + layers.append(upsample_mlp_flat) + upsample = tf.reshape(upsample_mlp_flat, enc_conv_shapes[0]) + layers.append(upsample) + + for i, (filter_i, stride_i) in enumerate(zip(conv_filters, strides), 1): + if method != 'deconv' and stride_i > 1: + upsample = tf.image.resize_images( + upsample, [stride_i * el for el in upsample.shape.as_list()[1:3]], + method=method, + name='upsample_' + str(i)) + upsample = self._conv_layer( + filters=filter_i, + kernel_size=self._kernel_size, + padding='same', + use_bias=not use_bn, + activation=self._activation, + strides=stride_i if method == 'deconv' else 1, + name='upsample_conv_' + str(i))( + upsample) + if use_bn: + upsample = snt.BatchNorm(scale=True)(upsample, **batch_norm_args) + if stride_i > 1: + hw = unique_hw.pop() + upsample = utils.maybe_center_crop(upsample, hw) + layers.append(upsample) + + # Final layer, no upsampling. 
+    x_logits = tf.layers.Conv2D(
+        filters=self._n_c,
+        kernel_size=self._kernel_size,
+        padding='same',
+        use_bias=not use_bn,
+        activation=None,
+        strides=1,
+        name='logits')(
+            upsample)
+    if use_bn:
+      x_logits = snt.BatchNorm(scale=True)(x_logits, **batch_norm_args)
+    layers.append(x_logits)
+
+    logging.info('%s upsampling module layer shapes', self._method_str)
+    logging.info('\n'.join([str(v.shape.as_list()) for v in layers]))
+
+    return x_logits
diff --git a/curl/requirements.txt b/curl/requirements.txt
new file mode 100644
index 0000000..652c195
--- /dev/null
+++ b/curl/requirements.txt
@@ -0,0 +1,10 @@
+absl-py==0.8.0
+dm-sonnet==1.35
+gast<0.3
+numpy==1.16.4
+scikit-learn==0.20.4
+setuptools>=41.0.0
+six==1.12.0
+tensorflow==1.14.0
+tensorflow-datasets==1.2.0
+tensorflow-probability==0.7.0
diff --git a/curl/train_main.py b/curl/train_main.py
new file mode 100644
index 0000000..0874132
--- /dev/null
+++ b/curl/train_main.py
@@ -0,0 +1,79 @@
+################################################################################
+# Copyright 2019 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+"""Training file to run most of the experiments in the paper.
+
+The default parameters correspond to the first set of experiments in Section
+4.2.
+
+For the expansion ablation, run with different ll_thresh values as in the paper.
+Note that n_y_active represents the number of *active* components at the
+start, and should be set to 1, while n_y represents the maximum number of
+components allowed, and should be set sufficiently high (e.g. n_y = 100).
+
+For the MGR ablation, setting use_supervised_replay = True switches to using
+SMGR, and the gen_replay_type flag can switch between fixed and dynamic replay.
+The generative snapshot period is set automatically in the training.py file
+based on these settings (i.e. the data_period variable), so the 0.1T runs can
+be reproduced by dividing this value by 10.
+"""
+
+from absl import app
+from absl import flags
+
+from curl import training
+
+flags.DEFINE_enum('dataset', 'mnist', ['mnist', 'omniglot'], 'Dataset.')
+
+FLAGS = flags.FLAGS
+
+
+def main(unused_argv):
+  training.run_training(
+      dataset=FLAGS.dataset,
+      output_type='bernoulli',
+      n_y=30,
+      n_y_active=1,
+      training_data_type='sequential',
+      n_concurrent_classes=1,
+      lr_init=1e-3,
+      lr_factor=1.,
+      lr_schedule=[1],
+      blend_classes=False,
+      train_supervised=False,
+      n_steps=100000,
+      report_interval=10000,
+      knn_values=[10],
+      random_seed=1,
+      encoder_kwargs={
+          'encoder_type': 'multi',
+          'n_enc': [1200, 600, 300, 150],
+          'enc_strides': [1],
+      },
+      decoder_kwargs={
+          'decoder_type': 'single',
+          'n_dec': [500, 500],
+          'dec_up_strides': None,
+      },
+      n_z=32,
+      dynamic_expansion=True,
+      ll_thresh=-200.0,
+      classify_with_samples=False,
+      gen_replay_type='fixed',
+      use_supervised_replay=False,
+  )
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/curl/train_sup.py b/curl/train_sup.py
new file mode 100644
index 0000000..9c68286
--- /dev/null
+++ b/curl/train_sup.py
@@ -0,0 +1,64 @@
+################################################################################
+# Copyright 2019 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +"""Runs the supervised CL benchmark experiments in the paper.""" + +from absl import app +from absl import flags + +from curl import training + +flags.DEFINE_enum('dataset', 'mnist', ['mnist', 'omniglot'], 'Dataset.') + +FLAGS = flags.FLAGS + + +def main(unused_argv): + training.run_training( + dataset=FLAGS.dataset, + output_type='bernoulli', + n_y=10, + n_y_active=10, + training_data_type='sequential', + n_concurrent_classes=2, + lr_init=1e-3, + lr_factor=1., + lr_schedule=[1], + train_supervised=True, + blend_classes=False, + n_steps=100000, + report_interval=10000, + knn_values=[], + random_seed=1, + encoder_kwargs={ + 'encoder_type': 'multi', + 'n_enc': [400, 400], + 'enc_strides': [1], + }, + decoder_kwargs={ + 'decoder_type': 'single', + 'n_dec': [400, 400], + 'dec_up_strides': None, + }, + n_z=32, + dynamic_expansion=False, + ll_thresh=-10000.0, + classify_with_samples=False, + gen_replay_type='fixed', + use_supervised_replay=False, + ) + +if __name__ == '__main__': + app.run(main) diff --git a/curl/train_unsup.py b/curl/train_unsup.py new file mode 100644 index 0000000..48e7ad2 --- /dev/null +++ b/curl/train_unsup.py @@ -0,0 +1,73 @@ +################################################################################ +# Copyright 2019 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +"""Runs the unsupervised i.i.d benchmark experiments in the paper.""" + +from absl import app +from absl import flags + +from curl import training + +flags.DEFINE_enum('dataset', 'mnist', ['mnist', 'omniglot'], 'Dataset.') + +FLAGS = flags.FLAGS + + +def main(unused_argv): + if FLAGS.dataset == 'mnist': + n_y = 25 + n_y_active = 25 + n_z = 50 + else: # omniglot + n_y = 100 + n_y_active = 100 + n_z = 100 + + training.run_training( + dataset=FLAGS.dataset, + n_y=n_y, + n_y_active=n_y_active, + n_z=n_z, + output_type='bernoulli', + training_data_type='iid', + n_concurrent_classes=1, + lr_init=5e-4, + lr_factor=1., + lr_schedule=[1], + blend_classes=False, + train_supervised=False, + n_steps=100000, + report_interval=10000, + knn_values=[3, 5, 10], + random_seed=1, + encoder_kwargs={ + 'encoder_type': 'multi', + 'n_enc': [500, 500], + 'enc_strides': [1], + }, + decoder_kwargs={ + 'decoder_type': 'single', + 'n_dec': [500], + 'dec_up_strides': None, + }, + dynamic_expansion=False, + ll_thresh=-0.0, + classify_with_samples=True, + gen_replay_type=None, + use_supervised_replay=False, + ) + +if __name__ == '__main__': + app.run(main) diff --git a/curl/training.py b/curl/training.py new file mode 100644 index 0000000..086d068 --- /dev/null +++ b/curl/training.py @@ -0,0 +1,1169 @@ +################################################################################ +# Copyright 2019 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +"""Script to train CURL.""" + +import collections +import functools +from absl import logging + +import numpy as np +from sklearn import neighbors +import sonnet as snt +import tensorflow as tf +import tensorflow_datasets as tfds +import tensorflow_probability as tfp + +from curl import model +from curl import utils + +tfc = tf.compat.v1 + +# pylint: disable=g-long-lambda + +MainOps = collections.namedtuple('MainOps', [ + 'elbo', 'll', 'log_p_x', 'kl_y', 'kl_z', 'elbo_supervised', 'll_supervised', + 'log_p_x_supervised', 'kl_y_supervised', 'kl_z_supervised', + 'cat_probs', 'confusion', 'purity', 'latents' +]) + +DatasetTuple = collections.namedtuple('DatasetTuple', [ + 'train_data', 'train_iter_for_clf', 'train_data_for_clf', + 'valid_iter', 'valid_data', 'test_iter', 'test_data', 'ds_info' +]) + + +def compute_purity(confusion): + return np.sum(np.max(confusion, axis=0)).astype(float) / np.sum(confusion) + + +def process_dataset(iterator, + ops_to_run, + sess, + feed_dict=None, + aggregation_ops=np.stack, + processing_ops=None): + """Process a dataset by computing ops and accumulating batch by batch. + + Args: + iterator: iterator through the dataset. + ops_to_run: dict, tf ops to run as part of dataset processing. + sess: tf.Session to use. + feed_dict: dict, required placeholders. + aggregation_ops: fn or dict of fns, aggregation op to apply for each op. 
+ processing_ops: fn or dict of fns, extra processing op to apply for each op. + + Returns: + Results accumulated over dataset. + """ + + if not isinstance(ops_to_run, dict): + raise TypeError('ops_to_run must be specified as a dict') + + if not isinstance(aggregation_ops, dict): + aggregation_ops = {k: aggregation_ops for k in ops_to_run} + if not isinstance(processing_ops, dict): + processing_ops = {k: processing_ops for k in ops_to_run} + + out_results = collections.OrderedDict() + sess.run(iterator.initializer) + while True: + # Iterate over the whole dataset and append the results to a per-key list. + try: + outs = sess.run(ops_to_run, feed_dict=feed_dict) + for key, value in outs.items(): + out_results.setdefault(key, []).append(value) + + except tf.errors.OutOfRangeError: # end of dataset iterator + break + + # Aggregate and process results. + for key, value in out_results.items(): + if aggregation_ops[key]: + out_results[key] = aggregation_ops[key](value) + if processing_ops[key]: + out_results[key] = processing_ops[key](out_results[key], axis=0) + + return out_results + + +def get_data_sources(dataset, dataset_kwargs, batch_size, test_batch_size, + training_data_type, n_concurrent_classes, image_key, + label_key): + """Create and return data sources for training, validation, and testing. + + Args: + dataset: str, name of dataset ('mnist', 'omniglot', etc). + dataset_kwargs: dict, kwargs used in tf dataset constructors. + batch_size: int, batch size used for training. + test_batch_size: int, batch size used for evaluation. + training_data_type: str, how training data is seen ('iid', or 'sequential'). + n_concurrent_classes: int, # classes seen at a time (ignored for 'iid'). + image_key: str, name if image key in dataset. + label_key: str, name of label key in dataset. + + Returns: + A namedtuple containing all of the dataset iterators and batches. 
+ + """ + + # Load training data sources + ds_train, ds_info = tfds.load( + name=dataset, + split=tfds.Split.TRAIN, + with_info=True, + as_dataset_kwargs={'shuffle_files': False}, + **dataset_kwargs) + + # Validate assumption that data is in [0, 255] + assert ds_info.features[image_key].dtype == tf.uint8 + + n_classes = ds_info.features[label_key].num_classes + num_train_examples = ds_info.splits['train'].num_examples + + def preprocess_data(x): + """Convert images from uint8 in [0, 255] to float in [0, 1].""" + x[image_key] = tf.image.convert_image_dtype(x[image_key], tf.float32) + return x + + if training_data_type == 'sequential': + c = None # The index of the class number, None for now and updated later + if n_concurrent_classes == 1: + filter_fn = lambda v: tf.equal(v[label_key], c) + else: + # Define the lowest and highest class number at each data period. + assert n_classes % n_concurrent_classes == 0, ( + 'Number of total classes must be divisible by ' + 'number of concurrent classes') + cmin = [] + cmax = [] + for i in range(int(n_classes / n_concurrent_classes)): + for _ in range(n_concurrent_classes): + cmin.append(i * n_concurrent_classes) + cmax.append((i + 1) * n_concurrent_classes) + + filter_fn = lambda v: tf.logical_and( + tf.greater_equal(v[label_key], cmin[c]), tf.less( + v[label_key], cmax[c])) + + # Set up data sources/queues (one for each class). 
+ train_datasets = [] + train_iterators = [] + train_data = [] + + full_ds = ds_train.repeat().shuffle(num_train_examples, seed=0) + full_ds = full_ds.map(preprocess_data) + for c in range(n_classes): + filtered_ds = full_ds.filter(filter_fn).batch( + batch_size, drop_remainder=True) + train_datasets.append(filtered_ds) + train_iterators.append(train_datasets[-1].make_one_shot_iterator()) + train_data.append(train_iterators[-1].get_next()) + + else: # not sequential + full_ds = ds_train.repeat().shuffle(num_train_examples, seed=0) + full_ds = full_ds.map(preprocess_data) + train_datasets = full_ds.batch(batch_size, drop_remainder=True) + train_data = train_datasets.make_one_shot_iterator().get_next() + + # Set up data source to get full training set for classifier training + full_ds = ds_train.repeat(1).shuffle(num_train_examples, seed=0) + full_ds = full_ds.map(preprocess_data) + train_datasets_for_classifier = full_ds.batch( + test_batch_size, drop_remainder=True) + train_iter_for_classifier = ( + train_datasets_for_classifier.make_initializable_iterator()) + train_data_for_classifier = train_iter_for_classifier.get_next() + + # Load validation dataset. + try: + valid_dataset = tfds.load( + name=dataset, split=tfds.Split.VALIDATION, **dataset_kwargs) + num_valid_examples = ds_info.splits[tfds.Split.VALIDATION].num_examples + assert (num_valid_examples % + test_batch_size == 0), ('test_batch_size must be a multiple of %d' % + num_valid_examples) + valid_dataset = valid_dataset.repeat(1).batch( + test_batch_size, drop_remainder=True) + valid_dataset = valid_dataset.map(preprocess_data) + valid_iter = valid_dataset.make_initializable_iterator() + valid_data = valid_iter.get_next() + except KeyError: + logging.warning('No validation set!!') + valid_iter = None + valid_data = None + + # Load test dataset. 
+ test_dataset = tfds.load( + name=dataset, split=tfds.Split.TEST, **dataset_kwargs) + num_test_examples = ds_info.splits['test'].num_examples + assert (num_test_examples % + test_batch_size == 0), ('test_batch_size must be a multiple of %d' % + num_test_examples) + test_dataset = test_dataset.repeat(1).batch( + test_batch_size, drop_remainder=True) + test_dataset = test_dataset.map(preprocess_data) + test_iter = test_dataset.make_initializable_iterator() + test_data = test_iter.get_next() + logging.info('Loaded %s data', dataset) + + return DatasetTuple(train_data, train_iter_for_classifier, + train_data_for_classifier, valid_iter, valid_data, + test_iter, test_data, ds_info) + + +def setup_training_and_eval_graphs(x, label, y, n_y, curl_model, + classify_with_samples, is_training, name): + """Set up the graph and return ops for training or evaluation. + + Args: + x: tf placeholder for image. + label: tf placeholder for ground truth label. + y: tf placeholder for some self-supervised label/prediction. + n_y: int, dimensionality of discrete latent variable y. + curl_model: snt.AbstractModule representing the CURL model. + classify_with_samples: bool, whether to *sample* latents for classification. + is_training: bool, whether this graph is the training graph. + name: str, graph name. + + Returns: + A namedtuple with the required graph ops to perform training or evaluation. + + """ + # kl_y_supervised is -log q(y=y_true | x) + (log_p_x, kl_y, kl_z, log_p_x_supervised, kl_y_supervised, + kl_z_supervised) = curl_model.log_prob_elbo_components(x, y) + + ll = log_p_x - kl_y - kl_z + elbo = -tf.reduce_mean(ll) + + # Supervised loss, either for SMGR, or adaptation to supervised benchmark. 
+ ll_supervised = log_p_x_supervised - kl_y_supervised - kl_z_supervised + elbo_supervised = -tf.reduce_mean(ll_supervised) + + # Summaries + kl_y = tf.reduce_mean(kl_y) + kl_z = tf.reduce_mean(kl_z) + log_p_x_supervised = tf.reduce_mean(log_p_x_supervised) + kl_y_supervised = tf.reduce_mean(kl_y_supervised) + kl_z_supervised = tf.reduce_mean(kl_z_supervised) + + # Evaluation. + hiddens = curl_model.get_shared_rep(x, is_training=is_training) + cat = curl_model.infer_cluster(hiddens) + cat_probs = cat.probs + + confusion = tf.confusion_matrix(label, tf.argmax(cat_probs, axis=1), + num_classes=n_y, name=name + '_confusion') + purity = (tf.reduce_sum(tf.reduce_max(confusion, axis=0)) + / tf.reduce_sum(confusion)) + + if classify_with_samples: + latents = curl_model.infer_latent( + hiddens=hiddens, y=tf.to_float(cat.sample())).sample() + else: + latents = curl_model.infer_latent( + hiddens=hiddens, y=tf.to_float(cat.mode())).mean() + + return MainOps(elbo, ll, log_p_x, kl_y, kl_z, elbo_supervised, ll_supervised, + log_p_x_supervised, kl_y_supervised, kl_z_supervised, + cat_probs, confusion, purity, latents) + + +def get_generated_data(sess, gen_op, y_input, gen_buffer_size, + component_counts): + """Get generated model data (in place of saving a model snapshot). + + Args: + sess: tf.Session. + gen_op: tf op representing a batch of generated data. + y_input: tf placeholder for which mixture components to generate from. + gen_buffer_size: int, number of data points to generate. + component_counts: np.array, prior probabilities over components. + + Returns: + A tuple of two numpy arrays + The generated data + The corresponding labels + """ + + batch_size, n_y = y_input.shape.as_list() + + # Sample based on the history of all components used. 
+ cluster_sample_probs = component_counts.astype(float) + cluster_sample_probs = np.maximum(1e-12, cluster_sample_probs) + cluster_sample_probs = cluster_sample_probs / np.sum(cluster_sample_probs) + + # Now generate the data based on the specified cluster prior. + gen_buffer_images = [] + gen_buffer_labels = [] + for _ in range(gen_buffer_size): + gen_label = np.random.choice( + np.arange(n_y), + size=(batch_size,), + replace=True, + p=cluster_sample_probs) + y_gen_posterior_vals = np.zeros((batch_size, n_y)) + y_gen_posterior_vals[np.arange(batch_size), gen_label] = 1 + gen_image = sess.run(gen_op, feed_dict={y_input: y_gen_posterior_vals}) + gen_buffer_images.append(gen_image) + gen_buffer_labels.append(gen_label) + + gen_buffer_images = np.vstack(gen_buffer_images) + gen_buffer_labels = np.concatenate(gen_buffer_labels) + + return gen_buffer_images, gen_buffer_labels + + +def setup_dynamic_ops(n_y): + """Set up ops to move / copy mixture component weights for dynamic expansion. + + Args: + n_y: int, dimensionality of discrete latent variable y. + + Returns: + A dict containing all of the ops required for dynamic updating. + + """ + # Set up graph ops to dynamically modify component params. + graph = tf.get_default_graph() + + # 1) Ops to get and set latent encoder params (entire tensors) + latent_enc_tensors = {} + for k in range(n_y): + latent_enc_tensors['latent_w_' + str(k)] = graph.get_tensor_by_name( + 'latent_encoder/mlp_latent_encoder_{}/w:0'.format(k)) + latent_enc_tensors['latent_b_' + str(k)] = graph.get_tensor_by_name( + 'latent_encoder/mlp_latent_encoder_{}/b:0'.format(k)) + + latent_enc_assign_ops = {} + latent_enc_phs = {} + for key, tensor in latent_enc_tensors.items(): + latent_enc_phs[key] = tfc.placeholder(tensor.dtype, tensor.shape) + latent_enc_assign_ops[key] = tf.assign(tensor, latent_enc_phs[key]) + + # 2) Ops to get and set cluster encoder params (columns of a tensor) + # We will be copying column ind_from to column ind_to. 
+ cluster_w = graph.get_tensor_by_name( + 'cluster_encoder/mlp_cluster_encoder_final/w:0') + cluster_b = graph.get_tensor_by_name( + 'cluster_encoder/mlp_cluster_encoder_final/b:0') + + ind_from = tfc.placeholder(dtype=tf.int32) + ind_to = tfc.placeholder(dtype=tf.int32) + + # Determine indices of cluster encoder weights and biases to be updated + w_indices = tf.transpose( + tf.stack([ + tf.range(cluster_w.shape[0], dtype=tf.int32), + ind_to * tf.ones(shape=(cluster_w.shape[0],), dtype=tf.int32) + ])) + b_indices = ind_to + # Determine updates themselves + cluster_w_updates = tf.squeeze( + tf.slice(cluster_w, begin=(0, ind_from), size=(cluster_w.shape[0], 1))) + cluster_b_updates = cluster_b[ind_from] + # Create update ops + cluster_w_update_op = tf.scatter_nd_update(cluster_w, w_indices, + cluster_w_updates) + cluster_b_update_op = tf.scatter_update(cluster_b, b_indices, + cluster_b_updates) + + # 3) Ops to get and set latent prior params (columns of a tensor) + # We will be copying column ind_from to column ind_to. 
+ latent_prior_mu_w = graph.get_tensor_by_name( + 'latent_decoder/latent_prior_mu/w:0') + latent_prior_sigma_w = graph.get_tensor_by_name( + 'latent_decoder/latent_prior_sigma/w:0') + + mu_indices = tf.transpose( + tf.stack([ + ind_to * tf.ones(shape=(latent_prior_mu_w.shape[1],), dtype=tf.int32), + tf.range(latent_prior_mu_w.shape[1], dtype=tf.int32) + ])) + mu_updates = tf.squeeze( + tf.slice( + latent_prior_mu_w, + begin=(ind_from, 0), + size=(1, latent_prior_mu_w.shape[1]))) + mu_update_op = tf.scatter_nd_update(latent_prior_mu_w, mu_indices, mu_updates) + sigma_indices = tf.transpose( + tf.stack([ + ind_to * + tf.ones(shape=(latent_prior_sigma_w.shape[1],), dtype=tf.int32), + tf.range(latent_prior_sigma_w.shape[1], dtype=tf.int32) + ])) + sigma_updates = tf.squeeze( + tf.slice( + latent_prior_sigma_w, + begin=(ind_from, 0), + size=(1, latent_prior_sigma_w.shape[1]))) + sigma_update_op = tf.scatter_nd_update(latent_prior_sigma_w, sigma_indices, + sigma_updates) + + dynamic_ops = { + 'ind_from_ph': ind_from, + 'ind_to_ph': ind_to, + 'latent_enc_tensors': latent_enc_tensors, + 'latent_enc_assign_ops': latent_enc_assign_ops, + 'latent_enc_phs': latent_enc_phs, + 'cluster_w_update_op': cluster_w_update_op, + 'cluster_b_update_op': cluster_b_update_op, + 'mu_update_op': mu_update_op, + 'sigma_update_op': sigma_update_op + } + + return dynamic_ops + + +def copy_component_params(ind_from, ind_to, sess, ind_from_ph, ind_to_ph, + latent_enc_tensors, latent_enc_assign_ops, + latent_enc_phs, + cluster_w_update_op, cluster_b_update_op, + mu_update_op, sigma_update_op): + """Copy parameters from component i to component j. + + Args: + ind_from: int, component index to copy from. + ind_to: int, component index to copy to. + sess: tf.Session. + ind_from_ph: tf placeholder for component to copy from. + ind_to_ph: tf placeholder for component to copy to. + latent_enc_tensors: dict, tensors in the latent posterior encoder. 
+    latent_enc_assign_ops: dict, assignment ops for latent posterior encoder.
+    latent_enc_phs: dict, placeholders for assignment ops.
+    cluster_w_update_op: op for updating weights of cluster encoder.
+    cluster_b_update_op: op for updating biases of cluster encoder.
+    mu_update_op: op for updating mu weights of latent prior.
+    sigma_update_op: op for updating sigma weights of latent prior.
+
+  """
+  update_ops = []
+  feed_dict = {}
+  # Copy for latent encoder.
+  new_w_val, new_b_val = sess.run([
+      latent_enc_tensors['latent_w_' + str(ind_from)],
+      latent_enc_tensors['latent_b_' + str(ind_from)]
+  ])
+  update_ops.extend([
+      latent_enc_assign_ops['latent_w_' + str(ind_to)],
+      latent_enc_assign_ops['latent_b_' + str(ind_to)]
+  ])
+  feed_dict.update({
+      latent_enc_phs['latent_w_' + str(ind_to)]: new_w_val,
+      latent_enc_phs['latent_b_' + str(ind_to)]: new_b_val
+  })
+
+  # Copy for cluster encoder softmax.
+  update_ops.extend([cluster_w_update_op, cluster_b_update_op])
+  feed_dict.update({ind_from_ph: ind_from, ind_to_ph: ind_to})
+
+  # Copy for latent prior.
+  update_ops.extend([mu_update_op, sigma_update_op])
+  feed_dict.update({ind_from_ph: ind_from, ind_to_ph: ind_to})
+  sess.run(update_ops, feed_dict)
+
+
+def run_training(
+    dataset,
+    training_data_type,
+    n_concurrent_classes,
+    blend_classes,
+    train_supervised,
+    n_steps,
+    random_seed,
+    lr_init,
+    lr_factor,
+    lr_schedule,
+    output_type,
+    n_y,
+    n_y_active,
+    n_z,
+    encoder_kwargs,
+    decoder_kwargs,
+    dynamic_expansion,
+    ll_thresh,
+    classify_with_samples,
+    report_interval,
+    knn_values,
+    gen_replay_type,
+    use_supervised_replay):
+  """Run training script.
+
+  Args:
+    dataset: str, name of the dataset.
+    training_data_type: str, type of training run ('iid' or 'sequential').
+    n_concurrent_classes: int, # of classes seen at a time (ignored for 'iid').
+    blend_classes: bool, whether to blend in samples from the next class.
+    train_supervised: bool, whether to use supervision during training.
+ n_steps: int, number of total training steps. + random_seed: int, seed for tf and numpy RNG. + lr_init: float, initial learning rate. + lr_factor: float, learning rate decay factor. + lr_schedule: float, epochs at which the decay should be applied. + output_type: str, output distribution (currently only 'bernoulli'). + n_y: int, maximum possible dimensionality of discrete latent variable y. + n_y_active: int, starting dimensionality of discrete latent variable y. + n_z: int, dimensionality of continuous latent variable z. + encoder_kwargs: dict, parameters to specify encoder. + decoder_kwargs: dict, parameters to specify decoder. + dynamic_expansion: bool, whether to perform dynamic expansion. + ll_thresh: float, log-likelihood threshold below which to keep poor samples. + classify_with_samples: bool, whether to sample latents when classifying. + report_interval: int, number of steps after which to evaluate and report. + knn_values: list of ints, k values for different k-NN classifiers to run + (values of 3, 5, and 10 were used in different parts of the paper). + gen_replay_type: str, 'fixed', 'dynamic', or None. + use_supervised_replay: str, whether to use supervised replay (aka 'SMGR'). + """ + + # Set tf random seed. + tfc.set_random_seed(random_seed) + np.set_printoptions(precision=2, suppress=True) + + # First set up the data source(s) and get dataset info. 
+ if dataset == 'mnist': + batch_size = 100 + test_batch_size = 1000 + dataset_kwargs = {} + image_key = 'image' + label_key = 'label' + elif dataset == 'omniglot': + batch_size = 15 + test_batch_size = 8115 + dataset_kwargs = {'split': 'instance', 'label': 'alphabet'} + image_key = 'image' + label_key = 'alphabet' + else: + raise NotImplementedError + + dataset_ops = get_data_sources(dataset, dataset_kwargs, batch_size, + test_batch_size, training_data_type, + n_concurrent_classes, image_key, label_key) + train_data = dataset_ops.train_data + train_data_for_clf = dataset_ops.train_data_for_clf + valid_data = dataset_ops.valid_data + test_data = dataset_ops.test_data + + output_shape = dataset_ops.ds_info.features[image_key].shape + n_x = np.prod(output_shape) + n_classes = dataset_ops.ds_info.features[label_key].num_classes + num_train_examples = dataset_ops.ds_info.splits['train'].num_examples + + # Check that the number of classes is compatible with the training scenario + assert n_classes % n_concurrent_classes == 0 + assert n_steps % (n_classes / n_concurrent_classes) == 0 + + # Set specific params depending on the type of gen replay + if gen_replay_type == 'fixed': + data_period = data_period = int(n_steps / + (n_classes / n_concurrent_classes)) + gen_every_n = 2 # Blend in a gen replay batch every 2 steps + gen_refresh_period = data_period # How often to refresh the batches of + # generated data (equivalent to snapshotting a generative model) + gen_refresh_on_expansion = False # Don't refresh on dyn expansion + elif gen_replay_type == 'dynamic': + gen_every_n = 2 # Blend in a gen replay batch every 2 steps + gen_refresh_period = 1e8 # Never refresh generated data periodically + gen_refresh_on_expansion = True # Refresh on dyn expansion instead + elif gen_replay_type is None: + gen_every_n = 0 # Don't use any gen replay batches + gen_refresh_period = 1e8 # Never refresh generated data periodically + gen_refresh_on_expansion = False # Don't refresh on dyn 
expansion + else: + raise NotImplementedError + + max_gen_batches = 5000 # Max num of gen batches (proxy for storing a model) + + # Set dynamic expansion parameters + exp_wait_steps = 100 # Steps to wait after expansion before eligible again + exp_burn_in = 100 # Steps to wait at start of learning before eligible + exp_buffer_size = 100 # Size of the buffer of poorly explained data + num_buffer_train_steps = 20 # Num steps to train component on buffer + + # Define a global tf variable for the number of active components. + n_y_active_np = n_y_active + n_y_active = tfc.get_variable( + initializer=tf.constant(n_y_active_np, dtype=tf.int32), + trainable=False, + name='n_y_active', + dtype=tf.int32) + + logging.info('Starting CURL script on %s data.', dataset) + + # Set up placeholders for training. + + x_train_raw = tfc.placeholder( + dtype=tf.float32, shape=(batch_size,) + output_shape) + label_train = tfc.placeholder(dtype=tf.int32, shape=(batch_size,)) + + def binarize_fn(x): + """Binarize a Bernoulli by rounding the probabilities. + + Args: + x: tf tensor, input image. + + Returns: + A tf tensor with the binarized image + """ + return tf.cast(tf.greater(x, 0.5 * tf.ones_like(x)), tf.float32) + + if dataset == 'mnist': + x_train = binarize_fn(x_train_raw) + x_valid = binarize_fn(valid_data[image_key]) if valid_data else None + x_test = binarize_fn(test_data[image_key]) + x_train_for_clf = binarize_fn(train_data_for_clf[image_key]) + elif 'cifar' in dataset or dataset == 'omniglot': + x_train = x_train_raw + x_valid = valid_data[image_key] if valid_data else None + x_test = test_data[image_key] + x_train_for_clf = train_data_for_clf[image_key] + else: + raise ValueError('Unknown dataset {}'.format(dataset)) + + label_valid = valid_data[label_key] if valid_data else None + label_test = test_data[label_key] + + # Set up CURL modules. 
+ shared_encoder = model.SharedEncoder(name='shared_encoder', **encoder_kwargs) + latent_encoder = functools.partial(model.latent_encoder_fn, n_y=n_y, n_z=n_z) + latent_encoder = snt.Module(latent_encoder, name='latent_encoder') + latent_decoder = functools.partial(model.latent_decoder_fn, n_z=n_z) + latent_decoder = snt.Module(latent_decoder, name='latent_decoder') + cluster_encoder = functools.partial( + model.cluster_encoder_fn, n_y_active=n_y_active, n_y=n_y) + cluster_encoder = snt.Module(cluster_encoder, name='cluster_encoder') + data_decoder = functools.partial( + model.data_decoder_fn, + output_type=output_type, + output_shape=output_shape, + n_x=n_x, + n_y=n_y, + **decoder_kwargs) + data_decoder = snt.Module(data_decoder, name='data_decoder') + + # Uniform prior over y. + prior_train_probs = utils.construct_prior_probs(batch_size, n_y, n_y_active) + prior_train = snt.Module( + lambda: tfp.distributions.OneHotCategorical(probs=prior_train_probs), + name='prior_unconditional_train') + prior_test_probs = utils.construct_prior_probs(test_batch_size, n_y, + n_y_active) + prior_test = snt.Module( + lambda: tfp.distributions.OneHotCategorical(probs=prior_test_probs), + name='prior_unconditional_test') + + model_train = model.Curl( + prior_train, + latent_decoder, + data_decoder, + shared_encoder, + cluster_encoder, + latent_encoder, + n_y_active, + is_training=True, + name='curl_train') + model_eval = model.Curl( + prior_test, + latent_decoder, + data_decoder, + shared_encoder, + cluster_encoder, + latent_encoder, + n_y_active, + is_training=False, + name='curl_test') + + # Set up training graph + y_train = label_train if train_supervised else None + y_valid = label_valid if train_supervised else None + y_test = label_test if train_supervised else None + + train_ops = setup_training_and_eval_graphs( + x_train, + label_train, + y_train, + n_y, + model_train, + classify_with_samples, + is_training=True, + name='train') + + hiddens_for_clf = 
model_eval.get_shared_rep(x_train_for_clf, + is_training=False) + cat_for_clf = model_eval.infer_cluster(hiddens_for_clf) + + if classify_with_samples: + latents_for_clf = model_eval.infer_latent( + hiddens=hiddens_for_clf, y=tf.to_float(cat_for_clf.sample())).sample() + else: + latents_for_clf = model_eval.infer_latent( + hiddens=hiddens_for_clf, y=tf.to_float(cat_for_clf.mode())).mean() + + # Set up validation graph + if valid_data is not None: + valid_ops = setup_training_and_eval_graphs( + x_valid, + label_valid, + y_valid, + n_y, + model_eval, + classify_with_samples, + is_training=False, + name='valid') + + # Set up test graph + test_ops = setup_training_and_eval_graphs( + x_test, + label_test, + y_test, + n_y, + model_eval, + classify_with_samples, + is_training=False, + name='test') + + # Set up optimizer (with scheduler). + global_step = tf.train.get_or_create_global_step() + lr_schedule = [ + tf.cast(el * num_train_examples / batch_size, tf.int64) + for el in lr_schedule + ] + num_schedule_steps = tf.reduce_sum( + tf.cast(global_step >= lr_schedule, tf.float32)) + lr = float(lr_init) * float(lr_factor)**num_schedule_steps + optimizer = tf.train.AdamOptimizer(learning_rate=lr) + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + train_step = optimizer.minimize(train_ops.elbo) + train_step_supervised = optimizer.minimize(train_ops.elbo_supervised) + + # Set up ops for generative replay + if gen_every_n > 0: + # How many generative batches will we use each period? 
+ gen_buffer_size = min( + int(gen_refresh_period / gen_every_n), max_gen_batches) + + # Class each sample should be drawn from (default to uniform prior) + y_gen = tfp.distributions.OneHotCategorical( + probs=np.ones((batch_size, n_y)) / n_y, + dtype=tf.float32, + name='extra_train_classes').sample() + + gen_samples = model_train.sample(y=y_gen, mean=True) + if dataset == 'mnist' or dataset == 'omniglot': + gen_samples = binarize_fn(gen_samples) + + # Set up ops to dynamically modify parameters (for dynamic expansion) + dynamic_ops = setup_dynamic_ops(n_y) + + logging.info('Created computation graph.') + + n_steps_per_class = n_steps / n_classes # pylint: disable=invalid-name + + cumulative_component_counts = np.array([0] * n_y).astype(float) + recent_component_counts = np.array([0] * n_y).astype(float) + + gen_buffer_ind = 0 + + # Buffer of poorly explained data (if we're doing dynamic expansion). + poor_data_buffer = [] + poor_data_labels = [] + all_full_poor_data_buffers = [] + all_full_poor_data_labels = [] + has_expanded = False + steps_since_expansion = 0 + gen_buffer_ind = 0 + eligible_for_expansion = False # Flag to ensure we wait a bit after expansion + + # Set up basic ops to run and quantities to log. 
+ ops_to_run = { + 'train_ELBO': train_ops.elbo, + 'train_log_p_x': train_ops.log_p_x, + 'train_kl_y': train_ops.kl_y, + 'train_kl_z': train_ops.kl_z, + 'train_ll': train_ops.ll, + 'train_batch_purity': train_ops.purity, + 'train_probs': train_ops.cat_probs, + 'n_y_active': n_y_active + } + if valid_data is not None: + valid_ops_to_run = { + 'valid_ELBO': valid_ops.elbo, + 'valid_kl_y': valid_ops.kl_y, + 'valid_kl_z': valid_ops.kl_z, + 'valid_confusion': valid_ops.confusion + } + else: + valid_ops_to_run = {} + test_ops_to_run = { + 'test_ELBO': test_ops.elbo, + 'test_kl_y': test_ops.kl_y, + 'test_kl_z': test_ops.kl_z, + 'test_confusion': test_ops.confusion + } + to_log = ['train_batch_purity'] + to_log_eval = ['test_purity', 'test_ELBO', 'test_kl_y', 'test_kl_z'] + if valid_data is not None: + to_log_eval += ['valid_ELBO', 'valid_purity'] + + if train_supervised: + # Track supervised losses, train on supervised loss. + ops_to_run.update({ + 'train_ELBO_supervised': train_ops.elbo_supervised, + 'train_log_p_x_supervised': train_ops.log_p_x_supervised, + 'train_kl_y_supervised': train_ops.kl_y_supervised, + 'train_kl_z_supervised': train_ops.kl_z_supervised, + 'train_ll_supervised': train_ops.ll_supervised + }) + default_train_step = train_step_supervised + to_log += [ + 'train_ELBO_supervised', 'train_log_p_x_supervised', + 'train_kl_y_supervised', 'train_kl_z_supervised' + ] + else: + # Track unsupervised losses, train on unsupervised loss. + ops_to_run.update({ + 'train_ELBO': train_ops.elbo, + 'train_kl_y': train_ops.kl_y, + 'train_kl_z': train_ops.kl_z, + 'train_ll': train_ops.ll + }) + default_train_step = train_step + to_log += ['train_ELBO', 'train_kl_y', 'train_kl_z'] + + with tf.train.SingularMonitoredSession() as sess: + + for step in range(n_steps): + feed_dict = {} + + # Use the default training loss, but vary it each step depending on the + # training scenario (eg. 
for supervised gen replay, we alternate losses) + ops_to_run['train_step'] = default_train_step + + ### 1) PERIODICALLY TAKE SNAPSHOTS FOR GENERATIVE REPLAY ### + if (gen_refresh_period and step % gen_refresh_period == 0 and + gen_every_n > 0): + + # First, increment cumulative count and reset recent probs count. + cumulative_component_counts += recent_component_counts + recent_component_counts = np.zeros(n_y) + + # Generate enough samples for the rest of the next period + # (Functionally equivalent to storing and sampling from the model). + gen_buffer_images, gen_buffer_labels = get_generated_data( + sess=sess, + gen_op=gen_samples, + y_input=y_gen, + gen_buffer_size=gen_buffer_size, + component_counts=cumulative_component_counts) + + ### 2) DECIDE WHICH DATA SOURCE TO USE (GENERATIVE OR REAL DATA) ### + periodic_refresh_started = ( + gen_refresh_period and step >= gen_refresh_period) + refresh_on_expansion_started = (gen_refresh_on_expansion and has_expanded) + if ((periodic_refresh_started or refresh_on_expansion_started) and + gen_every_n > 0 and step % gen_every_n == 1): + # Use generated data for the training batch + used_real_data = False + + s = gen_buffer_ind * batch_size + e = (gen_buffer_ind + 1) * batch_size + + gen_data_array = { + 'image': gen_buffer_images[s:e], + 'label': gen_buffer_labels[s:e] + } + gen_buffer_ind = (gen_buffer_ind + 1) % gen_buffer_size + + # Feed it as x_train because it's already reshaped and binarized. + feed_dict.update({ + x_train: gen_data_array['image'], + label_train: gen_data_array['label'] + }) + + if use_supervised_replay: + # Convert label to one-hot before feeding in. + gen_label_onehot = np.eye(n_y)[gen_data_array['label']] + feed_dict.update({model_train.y_label: gen_label_onehot}) + ops_to_run['train_step'] = train_step_supervised + + else: + # Else use the standard training data sources. + used_real_data = True + + # Select appropriate data source for iid or sequential setup. 
+ if training_data_type == 'sequential': + current_data_period = int( + min(step / n_steps_per_class, len(train_data) - 1)) + + # If training supervised, set n_y_active directly based on how many + # classes have been seen + if train_supervised: + assert not dynamic_expansion + n_y_active_np = n_concurrent_classes * ( + current_data_period // n_concurrent_classes +1) + n_y_active.load(n_y_active_np, sess) + + train_data_array = sess.run(train_data[current_data_period]) + + # If we are blending classes, figure out where we are in the data + # period and add some fraction of other samples. + if blend_classes: + # If in the first quarter, blend in examples from the previous class + if (step % n_steps_per_class < n_steps_per_class / 4 and + current_data_period > 0): + other_train_data_array = sess.run( + train_data[current_data_period - 1]) + + num_other = int( + (n_steps_per_class / 2 - 2 * + (step % n_steps_per_class)) * batch_size / n_steps_per_class) + other_inds = np.random.permutation(batch_size)[:num_other] + + train_data_array[image_key][:num_other] = other_train_data_array[ + image_key][other_inds] + train_data_array[label_key][:num_other] = other_train_data_array[ + label_key][other_inds] + + # If in the last quarter, blend in examples from the next class + elif (step % n_steps_per_class > 3 * n_steps_per_class / 4 and + current_data_period < n_classes - 1): + other_train_data_array = sess.run(train_data[current_data_period + + 1]) + + num_other = int( + (2 * (step % n_steps_per_class) - 3 * n_steps_per_class / 2) * + batch_size / n_steps_per_class) + other_inds = np.random.permutation(batch_size)[:num_other] + + train_data_array[image_key][:num_other] = other_train_data_array[ + image_key][other_inds] + train_data_array['label'][:num_other] = other_train_data_array[ + label_key][other_inds] + + # Otherwise, just use the current class + + else: + train_data_array = sess.run(train_data) + + feed_dict.update({ + x_train_raw: train_data_array[image_key], + 
label_train: train_data_array[label_key] + }) + + ### 3) PERFORM A GRADIENT STEP ### + results = sess.run(ops_to_run, feed_dict=feed_dict) + del results['train_step'] + + ### 4) COMPUTE ADDITIONAL DIAGNOSTIC OPS ON VALIDATION/TEST SETS. ### + if (step+1) % report_interval == 0: + if valid_data is not None: + logging.info('Evaluating on validation and test set!') + proc_ops = { + k: (np.sum if 'confusion' in k + else np.mean) for k in valid_ops_to_run + } + results.update( + process_dataset( + dataset_ops.valid_iter, + valid_ops_to_run, + sess, + feed_dict=feed_dict, + processing_ops=proc_ops)) + results['valid_purity'] = compute_purity(results['valid_confusion']) + else: + logging.info('Evaluating on test set!') + proc_ops = { + k: (np.sum if 'confusion' in k + else np.mean) for k in test_ops_to_run + } + results.update(process_dataset(dataset_ops.test_iter, + test_ops_to_run, + sess, + feed_dict=feed_dict, + processing_ops=proc_ops)) + results['test_purity'] = compute_purity(results['test_confusion']) + curr_to_log = to_log + to_log_eval + else: + curr_to_log = list(to_log) # copy to prevent in-place modifications + + ### 5) DYNAMIC EXPANSION ### + if dynamic_expansion and used_real_data: + # If we're doing dynamic expansion and below max capacity then add + # poorly defined data points to a buffer. + + # First check whether the model is eligible for expansion (the model + # becomes ineligible for a fixed time after each expansion, and when + # it has hit max capacity). + if (steps_since_expansion >= exp_wait_steps and step >= exp_burn_in and + n_y_active_np < n_y): + eligible_for_expansion = True + + steps_since_expansion += 1 + + if eligible_for_expansion: + # Add poorly explained data samples to a buffer. 
+ poor_inds = results['train_ll'] < ll_thresh + poor_data_buffer.extend(feed_dict[x_train_raw][poor_inds]) + poor_data_labels.extend(feed_dict[label_train][poor_inds]) + + n_poor_data = len(poor_data_buffer) + + # If buffer is big enough, then add a new component and train just the + # new component with several steps of gradient descent. + # (We just feed in a onehot cluster vector to indicate which + # component). + if n_poor_data >= exp_buffer_size: + # Dump the buffers so we can log them. + all_full_poor_data_buffers.append(poor_data_buffer) + all_full_poor_data_labels.append(poor_data_labels) + + # Take a new generative snapshot if specified. + if gen_refresh_on_expansion and gen_every_n > 0: + # Increment cumulative count and reset recent probs count. + cumulative_component_counts += recent_component_counts + recent_component_counts = np.zeros(n_y) + + gen_buffer_images, gen_buffer_labels = get_generated_data( + sess=sess, + gen_op=gen_samples, + y_input=y_gen, + gen_buffer_size=gen_buffer_size, + component_counts=cumulative_component_counts) + + # Cull to a multiple of batch_size (keep the later data samples). + n_poor_batches = int(n_poor_data / batch_size) + poor_data_buffer = poor_data_buffer[-(n_poor_batches * batch_size):] + poor_data_labels = poor_data_labels[-(n_poor_batches * batch_size):] + + # Find most probable component (on poor batch). + poor_cprobs = [] + for bs in range(n_poor_batches): + poor_cprobs.append( + sess.run( + train_ops.cat_probs, + feed_dict={ + x_train_raw: + poor_data_buffer[bs * batch_size:(bs + 1) * + batch_size] + })) + best_cluster = np.argmax(np.sum(np.vstack(poor_cprobs), axis=0)) + + # Initialize parameters of the new component from most prob + # existing. + new_cluster = n_y_active_np + + copy_component_params(best_cluster, new_cluster, sess, + **dynamic_ops) + + # Increment mixture component count n_y_active. 
+ n_y_active_np += 1 + n_y_active.load(n_y_active_np, sess) + + # Perform a number of steps of gradient descent on the data buffer, + # training only the new component (supervised loss). + for _ in range(num_buffer_train_steps): + for bs in range(n_poor_batches): + x_batch = poor_data_buffer[bs * batch_size:(bs + 1) * + batch_size] + label_batch = poor_data_labels[bs * batch_size:(bs + 1) * + batch_size] + label_onehot_batch = np.eye(n_y)[label_batch] + _ = sess.run( + train_step_supervised, + feed_dict={ + x_train_raw: x_batch, + model_train.y_label: label_onehot_batch + }) + + # Empty the buffer. + poor_data_buffer = [] + poor_data_labels = [] + + # Reset the threshold flag so we have a burn in before the next + # component. + eligible_for_expansion = False + has_expanded = True + steps_since_expansion = 0 + + # Accumulate counts. + if used_real_data: + train_cat_probs_vals = results['train_probs'] + recent_component_counts += np.sum( + train_cat_probs_vals, axis=0).astype(float) + + ### 6) LOGGING AND EVALUATION ### + cleanup_for_print = lambda x: ', {}: %.{}f'.format( + x.capitalize().replace('_', ' '), 3) + log_str = 'Iteration %d' + log_str += ''.join([cleanup_for_print(el) for el in curr_to_log]) + log_str += ' n_active: %d' + logging.info( + log_str, + *([step] + [results[el] for el in curr_to_log] + [n_y_active_np])) + + # Periodically perform evaluation + if (step + 1) % report_interval == 0: + + # Report test purity and related measures + logging.info( + 'Iteration %d, Test purity: %.3f, Test ELBO: %.3f, Test ' + 'KLy: %.3f, Test KLz: %.3f', step, results['test_purity'], + results['test_ELBO'], results['test_kl_y'], results['test_kl_z']) + # Flush data only once in a while to allow buffering of data for more + # efficient writes. 
+ results['all_full_poor_data_buffers'] = all_full_poor_data_buffers + results['all_full_poor_data_labels'] = all_full_poor_data_labels + logging.info('Also training a classifier in latent space') + + # Perform knn classification from latents, to evaluate discriminability. + + # Get and encode training and test datasets. + clf_train_vals = process_dataset( + dataset_ops.train_iter_for_clf, { + 'latents': latents_for_clf, + 'labels': train_data_for_clf[label_key] + }, + sess, + feed_dict, + aggregation_ops=np.concatenate) + clf_test_vals = process_dataset( + dataset_ops.test_iter, { + 'latents': test_ops.latents, + 'labels': test_data[label_key] + }, + sess, + aggregation_ops=np.concatenate) + + # Perform knn classification. + knn_models = [] + for nval in knn_values: + # Fit training dataset. + clf = neighbors.KNeighborsClassifier(n_neighbors=nval) + clf.fit(clf_train_vals['latents'], clf_train_vals['labels']) + knn_models.append(clf) + + results['train_' + str(nval) + 'nn_acc'] = clf.score( + clf_train_vals['latents'], clf_train_vals['labels']) + + # Get test performance. + results['test_' + str(nval) + 'nn_acc'] = clf.score( + clf_test_vals['latents'], clf_test_vals['labels']) + + logging.info( + 'Iteration %d %d-NN classifier accuracies, Training: ' + '%.3f, Test: %.3f', step, nval, + results['train_' + str(nval) + 'nn_acc'], + results['test_' + str(nval) + 'nn_acc']) diff --git a/curl/unit_test.py b/curl/unit_test.py new file mode 100644 index 0000000..592a55a --- /dev/null +++ b/curl/unit_test.py @@ -0,0 +1,67 @@ +################################################################################ +# Copyright 2019 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+"""Tests for curl."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+
+from curl import training
+
+
+class TrainingTest(absltest.TestCase):
+  """Smoke test for the CURL training loop."""
+
+  def testRunTraining(self):
+
+    # Run a short (1000-step) unsupervised training job on sequential MNIST
+    # with dynamic expansion enabled, starting from a single active component
+    # (n_y_active=1 of n_y=10). report_interval == n_steps, so evaluation
+    # happens once, at the end. This only checks that training runs
+    # end-to-end without raising; no accuracy thresholds are asserted.
+    training.run_training(
+        dataset='mnist',
+        output_type='bernoulli',
+        n_y=10,
+        n_y_active=1,
+        training_data_type='sequential',
+        n_concurrent_classes=1,
+        lr_init=1e-3,
+        lr_factor=1.,
+        lr_schedule=[1],
+        blend_classes=False,
+        train_supervised=False,
+        n_steps=1000,
+        report_interval=1000,
+        knn_values=[3],
+        random_seed=1,
+        encoder_kwargs={
+            'encoder_type': 'multi',
+            'n_enc': [1200, 600, 300, 150],
+            'enc_strides': [1],
+        },
+        decoder_kwargs={
+            'decoder_type': 'single',
+            'n_dec': [500, 500],
+            'dec_up_strides': None,
+        },
+        n_z=32,
+        dynamic_expansion=True,
+        ll_thresh=-200.0,
+        classify_with_samples=False,
+        gen_replay_type='fixed',
+        use_supervised_replay=False,
+    )
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/curl/utils.py b/curl/utils.py
new file mode 100644
index 0000000..33ea298
--- /dev/null
+++ b/curl/utils.py
+################################################################################
+# Copyright 2019 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+"""Some common utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+import tensorflow as tf
+import tensorflow_probability as tfp
+
+
+def generate_gaussian(logits, sigma_nonlin, sigma_param):
+  """Generate a Gaussian distribution given a selected parameterisation.
+
+  The first half of `logits` along axis 1 is taken as the mean; the second
+  half parameterises the scale, via `sigma_nonlin` and `sigma_param`.
+
+  Args:
+    logits: tf tensor of shape [batch_size, 2 * n_dim], the raw
+      distribution parameters.
+    sigma_nonlin: str, either 'exp' or 'softplus'; the nonlinearity applied
+      to the raw scale logits to make them positive.
+    sigma_param: str, either 'var' or 'std'; whether the positive values
+      parameterise the variance or the standard deviation.
+
+  Returns:
+    A tfp.distributions.Normal with the computed mean and std deviation.
+
+  Raises:
+    ValueError: if `sigma_nonlin` or `sigma_param` is not recognised.
+  """
+
+  mu, sigma = tf.split(value=logits, num_or_size_splits=2, axis=1)
+
+  if sigma_nonlin == 'exp':
+    sigma = tf.exp(sigma)
+  elif sigma_nonlin == 'softplus':
+    sigma = tf.nn.softplus(sigma)
+  else:
+    raise ValueError('Unknown sigma_nonlin {}'.format(sigma_nonlin))
+
+  if sigma_param == 'var':
+    # The positive values are variances, so take the sqrt to get the std.
+    sigma = tf.sqrt(sigma)
+  elif sigma_param != 'std':
+    raise ValueError('Unknown sigma_param {}'.format(sigma_param))
+
+  return tfp.distributions.Normal(loc=mu, scale=sigma)
+
+
+def construct_prior_probs(batch_size, n_y, n_y_active):
+  """Construct the uniform prior probabilities.
+
+  The first `n_y_active` components share the probability mass uniformly;
+  the remaining components are padded with a tiny constant (1e-12) rather
+  than zero, so that downstream log-probabilities stay finite.
+
+  Args:
+    batch_size: int, the size of the batch.
+    n_y: int, the number of categorical cluster components.
+    n_y_active: tf.Variable, the number of components that are currently in
+      use.
+
+  Returns:
+    Tensor representing the prior probability matrix, size of
+    [batch_size, n_y].
+  """
+  probs = tf.ones((batch_size, n_y_active)) / tf.cast(
+      n_y_active, dtype=tf.float32)
+  # Pad the last axis from n_y_active up to n_y columns (no row padding).
+  paddings1 = tf.stack([tf.constant(0), tf.constant(0)], axis=0)
+  paddings2 = tf.stack([tf.constant(0), n_y - n_y_active], axis=0)
+  paddings = tf.stack([paddings1, paddings2], axis=1)
+  probs = tf.pad(probs, paddings, constant_values=1e-12)
+  probs.set_shape((batch_size, n_y))
+  logging.info('Prior shape: %s', str(probs.shape))
+  return probs
+
+
+def maybe_center_crop(layer, target_hw):
+  """Center crop the layer spatially to match a target shape.
+
+  Args:
+    layer: 4D tf tensor, [batch, height, width, channels] (spatial dims are
+      axes 1 and 2).
+    target_hw: sequence of two ints, the target (height, width).
+
+  Returns:
+    The center-cropped tensor, with spatial size exactly `target_hw`. When
+    a margin is odd the crop sits one pixel closer to the top/left edge.
+  """
+  l_height, l_width = layer.shape.as_list()[1:3]
+  t_height, t_width = target_hw
+  assert t_height <= l_height and t_width <= l_width
+
+  if (l_height - t_height) % 2 != 0 or (l_width - t_width) % 2 != 0:
+    # logging.warn is a deprecated absl alias; use logging.warning.
+    logging.warning(
+        'It is impossible to center-crop [%d, %d] into [%d, %d].'
+        ' Crop will be uneven.', t_height, t_width, l_height, l_width)
+
+  # Crop to exactly the target size. (The previous version computed the
+  # upper bound as `l - border`, which yielded one extra row/column whenever
+  # the margin was odd, so the output did not match target_hw.)
+  x_0 = (l_height - t_height) // 2
+  x_1 = x_0 + t_height
+  y_0 = (l_width - t_width) // 2
+  y_1 = y_0 + t_width
+  layer_cropped = layer[:, x_0:x_1, y_0:y_1, :]
+  return layer_cropped