Files
deepmind-research/curl/training.py
T
Dushyant Rao 656f93a37e Apply minor fixes to README
PiperOrigin-RevId: 279101132
2019-11-19 11:34:24 +00:00

1170 lines
44 KiB
Python

################################################################################
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Script to train CURL."""
import collections
import functools
from absl import logging
import numpy as np
from sklearn import neighbors
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
from curl import model
from curl import utils
tfc = tf.compat.v1
# pylint: disable=g-long-lambda
# Bundle of graph ops built by `setup_training_and_eval_graphs` for one data
# split. That function fills the fields positionally, so the field order here
# must stay in sync with the order used at its return statement.
MainOps = collections.namedtuple('MainOps', [
    'elbo', 'll', 'log_p_x', 'kl_y', 'kl_z', 'elbo_supervised', 'll_supervised',
    'log_p_x_supervised', 'kl_y_supervised', 'kl_z_supervised',
    'cat_probs', 'confusion', 'purity', 'latents'
])
# Iterators, `get_next()` batches and dataset info returned by
# `get_data_sources`. `valid_iter` and `valid_data` are None when loading the
# validation split raises KeyError.
DatasetTuple = collections.namedtuple('DatasetTuple', [
    'train_data', 'train_iter_for_clf', 'train_data_for_clf',
    'valid_iter', 'valid_data', 'test_iter', 'test_data', 'ds_info'
])
def compute_purity(confusion):
  """Compute cluster purity from a confusion matrix.

  Args:
    confusion: 2D array-like of counts.

  Returns:
    The sum of the per-column maxima divided by the sum of all entries.
  """
  # Largest count in each column (axis 0 is reduced away).
  column_peaks = np.max(confusion, axis=0)
  n_dominant = np.sum(column_peaks).astype(float)
  n_total = np.sum(confusion)
  return n_dominant / n_total
def process_dataset(iterator,
                    ops_to_run,
                    sess,
                    feed_dict=None,
                    aggregation_ops=np.stack,
                    processing_ops=None):
  """Run a set of ops over a whole dataset and combine the per-batch outputs.

  The iterator is re-initialised, `ops_to_run` is evaluated repeatedly until
  the iterator signals the end of the data, and the collected per-batch values
  are then aggregated and post-processed key by key.

  Args:
    iterator: initializable dataset iterator (its `initializer` op is run).
    ops_to_run: dict, tf ops to evaluate on every batch.
    sess: tf.Session to use.
    feed_dict: dict, placeholder values passed to every `sess.run` call.
    aggregation_ops: fn or dict of fns; applied to the list of per-batch values
      of each key. A falsy entry leaves the list as it is.
    processing_ops: fn or dict of fns; called as `fn(value, axis=0)` on the
      aggregated value of each key. A falsy entry skips this step.

  Returns:
    An OrderedDict mapping each key of `ops_to_run` to its combined result.

  Raises:
    TypeError: if `ops_to_run` is not a dict.
  """
  if not isinstance(ops_to_run, dict):
    raise TypeError('ops_to_run must be specified as a dict')

  def _per_key(spec):
    # A single callable (or None) is shared by every key in ops_to_run.
    if isinstance(spec, dict):
      return spec
    return dict.fromkeys(ops_to_run, spec)

  aggregators = _per_key(aggregation_ops)
  processors = _per_key(processing_ops)

  collected = collections.OrderedDict()
  sess.run(iterator.initializer)

  # Drain the iterator, keeping one list of batch outputs per key.
  while True:
    try:
      batch_outputs = sess.run(ops_to_run, feed_dict=feed_dict)
    except tf.errors.OutOfRangeError:  # end of dataset iterator
      break
    for name, batch_value in batch_outputs.items():
      collected.setdefault(name, []).append(batch_value)

  # Combine the batches of each key, then apply the optional reduction.
  for name in list(collected):
    combined = collected[name]
    aggregate = aggregators[name]
    if aggregate:
      combined = aggregate(combined)
    process = processors[name]
    if process:
      combined = process(combined, axis=0)
    collected[name] = combined

  return collected
def get_data_sources(dataset, dataset_kwargs, batch_size, test_batch_size,
                     training_data_type, n_concurrent_classes, image_key,
                     label_key):
  """Create and return data sources for training, validation, and testing.

  Args:
    dataset: str, name of dataset ('mnist', 'omniglot', etc).
    dataset_kwargs: dict, extra kwargs forwarded to every `tfds.load` call.
      It must not repeat a keyword that is passed explicitly here (`name`,
      `split`, and for the training load `with_info` / `as_dataset_kwargs`),
      otherwise the call raises TypeError.
    batch_size: int, batch size used for training.
    test_batch_size: int, batch size used for evaluation. Must evenly divide
      the number of validation and test examples.
    training_data_type: str, how training data is seen ('sequential', or
      anything else for the iid pipeline).
    n_concurrent_classes: int, # classes seen at a time (ignored for 'iid').
    image_key: str, name of image key in dataset.
    label_key: str, name of label key in dataset.

  Returns:
    A DatasetTuple containing all of the dataset iterators and batches. For
    'sequential' training `train_data` is a list with one batch per data
    period index; otherwise it is a single batch. `valid_iter` / `valid_data`
    are None when loading the validation split raises KeyError.
  """
  # Load training data sources
  ds_train, ds_info = tfds.load(
      name=dataset,
      split=tfds.Split.TRAIN,
      with_info=True,
      as_dataset_kwargs={'shuffle_files': False},
      **dataset_kwargs)

  # Validate assumption that data is in [0, 255]
  assert ds_info.features[image_key].dtype == tf.uint8

  n_classes = ds_info.features[label_key].num_classes
  num_train_examples = ds_info.splits['train'].num_examples

  def preprocess_data(x):
    """Convert images from uint8 in [0, 255] to float in [0, 1]."""
    x[image_key] = tf.image.convert_image_dtype(x[image_key], tf.float32)
    return x

  if training_data_type == 'sequential':
    # `filter_fn` reads `c` from this enclosing scope. The loop below rebinds
    # `c` before each `filter` call, so every filtered dataset is built with
    # its own index.
    # NOTE(review): this relies on tf.data tracing the predicate at the time
    # `filter` is called -- confirm before refactoring.
    c = None  # The index of the class number, None for now and updated later
    if n_concurrent_classes == 1:
      filter_fn = lambda v: tf.equal(v[label_key], c)
    else:
      # Define the lowest and highest class number at each data period.
      assert n_classes % n_concurrent_classes == 0, (
          'Number of total classes must be divisible by '
          'number of concurrent classes')
      cmin = []
      cmax = []
      # Each group's bounds are repeated n_concurrent_classes times so that
      # cmin / cmax can be indexed by any c in range(n_classes).
      for i in range(n_classes // n_concurrent_classes):
        for _ in range(n_concurrent_classes):
          cmin.append(i * n_concurrent_classes)
          cmax.append((i + 1) * n_concurrent_classes)

      filter_fn = lambda v: tf.logical_and(
          tf.greater_equal(v[label_key], cmin[c]), tf.less(
              v[label_key], cmax[c]))

    # Set up data sources/queues (one for each class).
    train_datasets = []
    train_iterators = []
    train_data = []

    full_ds = ds_train.repeat().shuffle(num_train_examples, seed=0)
    full_ds = full_ds.map(preprocess_data)
    for c in range(n_classes):
      filtered_ds = full_ds.filter(filter_fn).batch(
          batch_size, drop_remainder=True)
      train_datasets.append(filtered_ds)
      train_iterators.append(train_datasets[-1].make_one_shot_iterator())
      train_data.append(train_iterators[-1].get_next())

  else:  # not sequential
    full_ds = ds_train.repeat().shuffle(num_train_examples, seed=0)
    full_ds = full_ds.map(preprocess_data)
    train_datasets = full_ds.batch(batch_size, drop_remainder=True)
    train_data = train_datasets.make_one_shot_iterator().get_next()

  # Set up data source to get full training set for classifier training
  full_ds = ds_train.repeat(1).shuffle(num_train_examples, seed=0)
  full_ds = full_ds.map(preprocess_data)
  train_datasets_for_classifier = full_ds.batch(
      test_batch_size, drop_remainder=True)
  train_iter_for_classifier = (
      train_datasets_for_classifier.make_initializable_iterator())
  train_data_for_classifier = train_iter_for_classifier.get_next()

  # Load validation dataset.
  try:
    valid_dataset = tfds.load(
        name=dataset, split=tfds.Split.VALIDATION, **dataset_kwargs)
    num_valid_examples = ds_info.splits[tfds.Split.VALIDATION].num_examples
    # Batches use drop_remainder=True, so a non-divisor batch size would
    # silently discard examples from the evaluation.
    assert num_valid_examples % test_batch_size == 0, (
        'test_batch_size (%d) must evenly divide the number of validation '
        'examples (%d)' % (test_batch_size, num_valid_examples))
    valid_dataset = valid_dataset.repeat(1).batch(
        test_batch_size, drop_remainder=True)
    valid_dataset = valid_dataset.map(preprocess_data)
    valid_iter = valid_dataset.make_initializable_iterator()
    valid_data = valid_iter.get_next()
  except KeyError:
    logging.warning('No validation set!!')
    valid_iter = None
    valid_data = None

  # Load test dataset.
  test_dataset = tfds.load(
      name=dataset, split=tfds.Split.TEST, **dataset_kwargs)
  num_test_examples = ds_info.splits['test'].num_examples
  assert num_test_examples % test_batch_size == 0, (
      'test_batch_size (%d) must evenly divide the number of test '
      'examples (%d)' % (test_batch_size, num_test_examples))
  test_dataset = test_dataset.repeat(1).batch(
      test_batch_size, drop_remainder=True)
  test_dataset = test_dataset.map(preprocess_data)
  test_iter = test_dataset.make_initializable_iterator()
  test_data = test_iter.get_next()
  logging.info('Loaded %s data', dataset)

  return DatasetTuple(train_data, train_iter_for_classifier,
                      train_data_for_classifier, valid_iter, valid_data,
                      test_iter, test_data, ds_info)
def setup_training_and_eval_graphs(x, label, y, n_y, curl_model,
                                   classify_with_samples, is_training, name):
  """Build the loss, summary and evaluation ops for one data split.

  Args:
    x: tf tensor or placeholder holding the input images.
    label: tf tensor or placeholder holding the ground truth labels.
    y: tf tensor or placeholder for some self-supervised label/prediction
      (may be None).
    n_y: int, dimensionality of discrete latent variable y.
    curl_model: snt.AbstractModule representing the CURL model.
    classify_with_samples: bool, whether to *sample* latents for classification
      (otherwise the cluster mode and the latent mean are used).
    is_training: bool, whether this graph is the training graph.
    name: str, graph name; used as prefix of the confusion matrix op name.

  Returns:
    A MainOps namedtuple with the required graph ops to perform training or
    evaluation. Note that `elbo` and `elbo_supervised` are the *negated* batch
    means of `ll` and `ll_supervised`, and that `log_p_x` is returned without
    the batch-mean reduction applied to the other summary terms.
  """
  # ELBO terms from the model; kl_y_supervised is -log q(y=y_true | x).
  elbo_terms = curl_model.log_prob_elbo_components(x, y)
  (log_p_x, kl_y, kl_z,
   log_p_x_supervised, kl_y_supervised, kl_z_supervised) = elbo_terms

  # Unsupervised objective.
  ll = log_p_x - kl_y - kl_z
  elbo = -tf.reduce_mean(ll)

  # Supervised loss, either for SMGR, or adaptation to supervised benchmark.
  ll_supervised = log_p_x_supervised - kl_y_supervised - kl_z_supervised
  elbo_supervised = -tf.reduce_mean(ll_supervised)

  # Summaries: reduce each remaining term to its batch mean.
  (kl_y, kl_z, log_p_x_supervised, kl_y_supervised, kl_z_supervised) = [
      tf.reduce_mean(term)
      for term in (kl_y, kl_z, log_p_x_supervised, kl_y_supervised,
                   kl_z_supervised)
  ]

  # Evaluation: cluster assignment quality against the ground truth labels.
  hiddens = curl_model.get_shared_rep(x, is_training=is_training)
  cat = curl_model.infer_cluster(hiddens)
  cat_probs = cat.probs
  predicted_cluster = tf.argmax(cat_probs, axis=1)
  confusion = tf.confusion_matrix(
      label, predicted_cluster, num_classes=n_y, name=name + '_confusion')
  dominant_counts = tf.reduce_max(confusion, axis=0)
  purity = tf.reduce_sum(dominant_counts) / tf.reduce_sum(confusion)

  # Latent code used for downstream classification.
  if classify_with_samples:
    y_assignment = cat.sample()
  else:
    y_assignment = cat.mode()
  latent_posterior = curl_model.infer_latent(
      hiddens=hiddens, y=tf.to_float(y_assignment))
  if classify_with_samples:
    latents = latent_posterior.sample()
  else:
    latents = latent_posterior.mean()

  return MainOps(
      elbo=elbo,
      ll=ll,
      log_p_x=log_p_x,
      kl_y=kl_y,
      kl_z=kl_z,
      elbo_supervised=elbo_supervised,
      ll_supervised=ll_supervised,
      log_p_x_supervised=log_p_x_supervised,
      kl_y_supervised=kl_y_supervised,
      kl_z_supervised=kl_z_supervised,
      cat_probs=cat_probs,
      confusion=confusion,
      purity=purity,
      latents=latents)
def get_generated_data(sess, gen_op, y_input, gen_buffer_size,
                       component_counts):
  """Get generated model data (in place of saving a model snapshot).

  Args:
    sess: tf.Session.
    gen_op: tf op representing a batch of generated data.
    y_input: tf tensor of static shape (batch_size, n_y) that is fed with the
      one-hot mixture components to generate from.
    gen_buffer_size: int, number of batches to generate.
    component_counts: np.array, unnormalised usage counts per component; they
      are normalised here into the component sampling distribution.

  Returns:
    A tuple of two numpy arrays
      The generated data, all batches stacked with np.vstack
      The corresponding component labels, concatenated over all batches
  """
  batch_size, n_y = y_input.shape.as_list()

  # Sample based on the history of all components used. Counts are floored at
  # a tiny positive value before being normalised.
  sample_probs = np.maximum(1e-12, component_counts.astype(float))
  sample_probs = sample_probs / np.sum(sample_probs)

  # Now generate the data based on the specified cluster prior.
  components = np.arange(n_y)
  row_index = np.arange(batch_size)
  image_batches = []
  label_batches = []
  for _ in range(gen_buffer_size):
    batch_labels = np.random.choice(
        components, size=(batch_size,), replace=True, p=sample_probs)
    # One-hot encode the sampled components for the generator input.
    one_hot = np.zeros((batch_size, n_y))
    one_hot[row_index, batch_labels] = 1
    image_batches.append(sess.run(gen_op, feed_dict={y_input: one_hot}))
    label_batches.append(batch_labels)

  return np.vstack(image_batches), np.concatenate(label_batches)
def setup_dynamic_ops(n_y):
  """Set up ops to move / copy mixture component weights for dynamic expansion.

  All tensors are looked up by name in the default graph, so the model must
  already have been built when this is called.

  Args:
    n_y: int, dimensionality of discrete latent variable y.

  Returns:
    A dict containing all of the ops required for dynamic updating: the
    `ind_from` / `ind_to` placeholders, the latent encoder tensors together
    with their assignment ops and placeholders, and the update ops for the
    cluster encoder and the latent prior.
  """
  graph = tf.get_default_graph()

  # 1) Ops to get and set latent encoder params (entire tensors).
  latent_enc_tensors = {}
  for component in range(n_y):
    scope = 'latent_encoder/mlp_latent_encoder_{}'.format(component)
    for param in ('w', 'b'):
      key = 'latent_{}_{}'.format(param, component)
      latent_enc_tensors[key] = graph.get_tensor_by_name(
          '{}/{}:0'.format(scope, param))

  latent_enc_phs = {}
  latent_enc_assign_ops = {}
  for key, tensor in latent_enc_tensors.items():
    new_value_ph = tfc.placeholder(tensor.dtype, tensor.shape)
    latent_enc_phs[key] = new_value_ph
    latent_enc_assign_ops[key] = tf.assign(tensor, new_value_ph)

  # 2) Ops to get and set cluster encoder params (columns of a tensor).
  # Column ind_from of the weights (and element ind_from of the bias) is
  # copied to position ind_to.
  cluster_w = graph.get_tensor_by_name(
      'cluster_encoder/mlp_cluster_encoder_final/w:0')
  cluster_b = graph.get_tensor_by_name(
      'cluster_encoder/mlp_cluster_encoder_final/b:0')

  ind_from = tfc.placeholder(dtype=tf.int32)
  ind_to = tfc.placeholder(dtype=tf.int32)

  # One (row, ind_to) index pair per row of the cluster weights.
  n_w_rows = cluster_w.shape[0]
  w_indices = tf.transpose(
      tf.stack([
          tf.range(n_w_rows, dtype=tf.int32),
          ind_to * tf.ones(shape=(n_w_rows,), dtype=tf.int32)
      ]))
  b_indices = ind_to

  # The values to write: source column of the weights, source bias element.
  source_column = tf.slice(
      cluster_w, begin=(0, ind_from), size=(n_w_rows, 1))
  cluster_w_updates = tf.squeeze(source_column)
  cluster_b_updates = cluster_b[ind_from]

  cluster_w_update_op = tf.scatter_nd_update(cluster_w, w_indices,
                                             cluster_w_updates)
  cluster_b_update_op = tf.scatter_update(cluster_b, b_indices,
                                          cluster_b_updates)

  # 3) Ops to get and set latent prior params. Here the slice starts at
  # (ind_from, 0) and the indices are (ind_to, col) pairs, i.e. row ind_from
  # is copied to row ind_to.
  def _row_copy_op(weights):
    """Return an op that overwrites row ind_to of weights with row ind_from."""
    n_cols = weights.shape[1]
    row_indices = tf.transpose(
        tf.stack([
            ind_to * tf.ones(shape=(n_cols,), dtype=tf.int32),
            tf.range(n_cols, dtype=tf.int32)
        ]))
    source_row = tf.slice(weights, begin=(ind_from, 0), size=(1, n_cols))
    return tf.scatter_nd_update(weights, row_indices, tf.squeeze(source_row))

  latent_prior_mu_w = graph.get_tensor_by_name(
      'latent_decoder/latent_prior_mu/w:0')
  latent_prior_sigma_w = graph.get_tensor_by_name(
      'latent_decoder/latent_prior_sigma/w:0')
  mu_update_op = _row_copy_op(latent_prior_mu_w)
  sigma_update_op = _row_copy_op(latent_prior_sigma_w)

  return {
      'ind_from_ph': ind_from,
      'ind_to_ph': ind_to,
      'latent_enc_tensors': latent_enc_tensors,
      'latent_enc_assign_ops': latent_enc_assign_ops,
      'latent_enc_phs': latent_enc_phs,
      'cluster_w_update_op': cluster_w_update_op,
      'cluster_b_update_op': cluster_b_update_op,
      'mu_update_op': mu_update_op,
      'sigma_update_op': sigma_update_op
  }
def copy_component_params(ind_from, ind_to, sess, ind_from_ph, ind_to_ph,
                          latent_enc_tensors, latent_enc_assign_ops,
                          latent_enc_phs,
                          cluster_w_update_op, cluster_b_update_op,
                          mu_update_op, sigma_update_op):
  """Copy parameters from component i to component j.

  The source latent encoder weights are first read out of the session, then
  all assignments and in-graph copy ops are executed in one `sess.run` call.

  Args:
    ind_from: int, component index to copy from.
    ind_to: int, component index to copy to.
    sess: tf.Session.
    ind_from_ph: tf placeholder for component to copy from.
    ind_to_ph: tf placeholder for component to copy to.
    latent_enc_tensors: dict, tensors in the latent posterior encoder.
    latent_enc_assign_ops: dict, assignment ops for latent posterior encoder.
    latent_enc_phs: dict, placeholders for assignment ops.
    cluster_w_update_op: op for updating weights of cluster encoder.
    cluster_b_update_op: op for updating biases of cluster encoder.
    mu_update_op: op for updating mu weights of latent prior.
    sigma_update_op: op for updating sigma weights of latent prior.
  """
  source_w_key = 'latent_w_{}'.format(ind_from)
  source_b_key = 'latent_b_{}'.format(ind_from)
  target_w_key = 'latent_w_{}'.format(ind_to)
  target_b_key = 'latent_b_{}'.format(ind_to)

  # Latent encoder: fetch the current source values so they can be fed back
  # into the target component's assignment placeholders.
  source_w_val, source_b_val = sess.run([
      latent_enc_tensors[source_w_key],
      latent_enc_tensors[source_b_key],
  ])

  update_ops = [
      # Latent encoder.
      latent_enc_assign_ops[target_w_key],
      latent_enc_assign_ops[target_b_key],
      # Cluster encoder softmax.
      cluster_w_update_op,
      cluster_b_update_op,
      # Latent prior.
      mu_update_op,
      sigma_update_op,
  ]
  feed_dict = {
      latent_enc_phs[target_w_key]: source_w_val,
      latent_enc_phs[target_b_key]: source_b_val,
      # The index placeholders are shared by the cluster encoder and latent
      # prior copy ops.
      ind_from_ph: ind_from,
      ind_to_ph: ind_to,
  }
  sess.run(update_ops, feed_dict)
def run_training(
dataset,
training_data_type,
n_concurrent_classes,
blend_classes,
train_supervised,
n_steps,
random_seed,
lr_init,
lr_factor,
lr_schedule,
output_type,
n_y,
n_y_active,
n_z,
encoder_kwargs,
decoder_kwargs,
dynamic_expansion,
ll_thresh,
classify_with_samples,
report_interval,
knn_values,
gen_replay_type,
use_supervised_replay):
"""Run training script.
Args:
dataset: str, name of the dataset.
training_data_type: str, type of training run ('iid' or 'sequential').
n_concurrent_classes: int, # of classes seen at a time (ignored for 'iid').
blend_classes: bool, whether to blend in samples from the next class.
train_supervised: bool, whether to use supervision during training.
n_steps: int, number of total training steps.
random_seed: int, seed for tf and numpy RNG.
lr_init: float, initial learning rate.
lr_factor: float, learning rate decay factor.
lr_schedule: float, epochs at which the decay should be applied.
output_type: str, output distribution (currently only 'bernoulli').
n_y: int, maximum possible dimensionality of discrete latent variable y.
n_y_active: int, starting dimensionality of discrete latent variable y.
n_z: int, dimensionality of continuous latent variable z.
encoder_kwargs: dict, parameters to specify encoder.
decoder_kwargs: dict, parameters to specify decoder.
dynamic_expansion: bool, whether to perform dynamic expansion.
ll_thresh: float, log-likelihood threshold below which to keep poor samples.
classify_with_samples: bool, whether to sample latents when classifying.
report_interval: int, number of steps after which to evaluate and report.
knn_values: list of ints, k values for different k-NN classifiers to run
(values of 3, 5, and 10 were used in different parts of the paper).
gen_replay_type: str, 'fixed', 'dynamic', or None.
use_supervised_replay: str, whether to use supervised replay (aka 'SMGR').
"""
# Set tf random seed.
tfc.set_random_seed(random_seed)
np.set_printoptions(precision=2, suppress=True)
# First set up the data source(s) and get dataset info.
if dataset == 'mnist':
batch_size = 100
test_batch_size = 1000
dataset_kwargs = {}
image_key = 'image'
label_key = 'label'
elif dataset == 'omniglot':
batch_size = 15
test_batch_size = 8115
dataset_kwargs = {'split': 'instance', 'label': 'alphabet'}
image_key = 'image'
label_key = 'alphabet'
else:
raise NotImplementedError
dataset_ops = get_data_sources(dataset, dataset_kwargs, batch_size,
test_batch_size, training_data_type,
n_concurrent_classes, image_key, label_key)
train_data = dataset_ops.train_data
train_data_for_clf = dataset_ops.train_data_for_clf
valid_data = dataset_ops.valid_data
test_data = dataset_ops.test_data
output_shape = dataset_ops.ds_info.features[image_key].shape
n_x = np.prod(output_shape)
n_classes = dataset_ops.ds_info.features[label_key].num_classes
num_train_examples = dataset_ops.ds_info.splits['train'].num_examples
# Check that the number of classes is compatible with the training scenario
assert n_classes % n_concurrent_classes == 0
assert n_steps % (n_classes / n_concurrent_classes) == 0
# Set specific params depending on the type of gen replay
if gen_replay_type == 'fixed':
data_period = data_period = int(n_steps /
(n_classes / n_concurrent_classes))
gen_every_n = 2 # Blend in a gen replay batch every 2 steps
gen_refresh_period = data_period # How often to refresh the batches of
# generated data (equivalent to snapshotting a generative model)
gen_refresh_on_expansion = False # Don't refresh on dyn expansion
elif gen_replay_type == 'dynamic':
gen_every_n = 2 # Blend in a gen replay batch every 2 steps
gen_refresh_period = 1e8 # Never refresh generated data periodically
gen_refresh_on_expansion = True # Refresh on dyn expansion instead
elif gen_replay_type is None:
gen_every_n = 0 # Don't use any gen replay batches
gen_refresh_period = 1e8 # Never refresh generated data periodically
gen_refresh_on_expansion = False # Don't refresh on dyn expansion
else:
raise NotImplementedError
max_gen_batches = 5000 # Max num of gen batches (proxy for storing a model)
# Set dynamic expansion parameters
exp_wait_steps = 100 # Steps to wait after expansion before eligible again
exp_burn_in = 100 # Steps to wait at start of learning before eligible
exp_buffer_size = 100 # Size of the buffer of poorly explained data
num_buffer_train_steps = 20 # Num steps to train component on buffer
# Define a global tf variable for the number of active components.
n_y_active_np = n_y_active
n_y_active = tfc.get_variable(
initializer=tf.constant(n_y_active_np, dtype=tf.int32),
trainable=False,
name='n_y_active',
dtype=tf.int32)
logging.info('Starting CURL script on %s data.', dataset)
# Set up placeholders for training.
x_train_raw = tfc.placeholder(
dtype=tf.float32, shape=(batch_size,) + output_shape)
label_train = tfc.placeholder(dtype=tf.int32, shape=(batch_size,))
def binarize_fn(x):
"""Binarize a Bernoulli by rounding the probabilities.
Args:
x: tf tensor, input image.
Returns:
A tf tensor with the binarized image
"""
return tf.cast(tf.greater(x, 0.5 * tf.ones_like(x)), tf.float32)
if dataset == 'mnist':
x_train = binarize_fn(x_train_raw)
x_valid = binarize_fn(valid_data[image_key]) if valid_data else None
x_test = binarize_fn(test_data[image_key])
x_train_for_clf = binarize_fn(train_data_for_clf[image_key])
elif 'cifar' in dataset or dataset == 'omniglot':
x_train = x_train_raw
x_valid = valid_data[image_key] if valid_data else None
x_test = test_data[image_key]
x_train_for_clf = train_data_for_clf[image_key]
else:
raise ValueError('Unknown dataset {}'.format(dataset))
label_valid = valid_data[label_key] if valid_data else None
label_test = test_data[label_key]
# Set up CURL modules.
shared_encoder = model.SharedEncoder(name='shared_encoder', **encoder_kwargs)
latent_encoder = functools.partial(model.latent_encoder_fn, n_y=n_y, n_z=n_z)
latent_encoder = snt.Module(latent_encoder, name='latent_encoder')
latent_decoder = functools.partial(model.latent_decoder_fn, n_z=n_z)
latent_decoder = snt.Module(latent_decoder, name='latent_decoder')
cluster_encoder = functools.partial(
model.cluster_encoder_fn, n_y_active=n_y_active, n_y=n_y)
cluster_encoder = snt.Module(cluster_encoder, name='cluster_encoder')
data_decoder = functools.partial(
model.data_decoder_fn,
output_type=output_type,
output_shape=output_shape,
n_x=n_x,
n_y=n_y,
**decoder_kwargs)
data_decoder = snt.Module(data_decoder, name='data_decoder')
# Uniform prior over y.
prior_train_probs = utils.construct_prior_probs(batch_size, n_y, n_y_active)
prior_train = snt.Module(
lambda: tfp.distributions.OneHotCategorical(probs=prior_train_probs),
name='prior_unconditional_train')
prior_test_probs = utils.construct_prior_probs(test_batch_size, n_y,
n_y_active)
prior_test = snt.Module(
lambda: tfp.distributions.OneHotCategorical(probs=prior_test_probs),
name='prior_unconditional_test')
model_train = model.Curl(
prior_train,
latent_decoder,
data_decoder,
shared_encoder,
cluster_encoder,
latent_encoder,
n_y_active,
is_training=True,
name='curl_train')
model_eval = model.Curl(
prior_test,
latent_decoder,
data_decoder,
shared_encoder,
cluster_encoder,
latent_encoder,
n_y_active,
is_training=False,
name='curl_test')
# Set up training graph
y_train = label_train if train_supervised else None
y_valid = label_valid if train_supervised else None
y_test = label_test if train_supervised else None
train_ops = setup_training_and_eval_graphs(
x_train,
label_train,
y_train,
n_y,
model_train,
classify_with_samples,
is_training=True,
name='train')
hiddens_for_clf = model_eval.get_shared_rep(x_train_for_clf,
is_training=False)
cat_for_clf = model_eval.infer_cluster(hiddens_for_clf)
if classify_with_samples:
latents_for_clf = model_eval.infer_latent(
hiddens=hiddens_for_clf, y=tf.to_float(cat_for_clf.sample())).sample()
else:
latents_for_clf = model_eval.infer_latent(
hiddens=hiddens_for_clf, y=tf.to_float(cat_for_clf.mode())).mean()
# Set up validation graph
if valid_data is not None:
valid_ops = setup_training_and_eval_graphs(
x_valid,
label_valid,
y_valid,
n_y,
model_eval,
classify_with_samples,
is_training=False,
name='valid')
# Set up test graph
test_ops = setup_training_and_eval_graphs(
x_test,
label_test,
y_test,
n_y,
model_eval,
classify_with_samples,
is_training=False,
name='test')
# Set up optimizer (with scheduler).
global_step = tf.train.get_or_create_global_step()
lr_schedule = [
tf.cast(el * num_train_examples / batch_size, tf.int64)
for el in lr_schedule
]
num_schedule_steps = tf.reduce_sum(
tf.cast(global_step >= lr_schedule, tf.float32))
lr = float(lr_init) * float(lr_factor)**num_schedule_steps
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
train_step = optimizer.minimize(train_ops.elbo)
train_step_supervised = optimizer.minimize(train_ops.elbo_supervised)
# Set up ops for generative replay
if gen_every_n > 0:
# How many generative batches will we use each period?
gen_buffer_size = min(
int(gen_refresh_period / gen_every_n), max_gen_batches)
# Class each sample should be drawn from (default to uniform prior)
y_gen = tfp.distributions.OneHotCategorical(
probs=np.ones((batch_size, n_y)) / n_y,
dtype=tf.float32,
name='extra_train_classes').sample()
gen_samples = model_train.sample(y=y_gen, mean=True)
if dataset == 'mnist' or dataset == 'omniglot':
gen_samples = binarize_fn(gen_samples)
# Set up ops to dynamically modify parameters (for dynamic expansion)
dynamic_ops = setup_dynamic_ops(n_y)
logging.info('Created computation graph.')
n_steps_per_class = n_steps / n_classes # pylint: disable=invalid-name
cumulative_component_counts = np.array([0] * n_y).astype(float)
recent_component_counts = np.array([0] * n_y).astype(float)
gen_buffer_ind = 0
# Buffer of poorly explained data (if we're doing dynamic expansion).
poor_data_buffer = []
poor_data_labels = []
all_full_poor_data_buffers = []
all_full_poor_data_labels = []
has_expanded = False
steps_since_expansion = 0
gen_buffer_ind = 0
eligible_for_expansion = False # Flag to ensure we wait a bit after expansion
# Set up basic ops to run and quantities to log.
ops_to_run = {
'train_ELBO': train_ops.elbo,
'train_log_p_x': train_ops.log_p_x,
'train_kl_y': train_ops.kl_y,
'train_kl_z': train_ops.kl_z,
'train_ll': train_ops.ll,
'train_batch_purity': train_ops.purity,
'train_probs': train_ops.cat_probs,
'n_y_active': n_y_active
}
if valid_data is not None:
valid_ops_to_run = {
'valid_ELBO': valid_ops.elbo,
'valid_kl_y': valid_ops.kl_y,
'valid_kl_z': valid_ops.kl_z,
'valid_confusion': valid_ops.confusion
}
else:
valid_ops_to_run = {}
test_ops_to_run = {
'test_ELBO': test_ops.elbo,
'test_kl_y': test_ops.kl_y,
'test_kl_z': test_ops.kl_z,
'test_confusion': test_ops.confusion
}
to_log = ['train_batch_purity']
to_log_eval = ['test_purity', 'test_ELBO', 'test_kl_y', 'test_kl_z']
if valid_data is not None:
to_log_eval += ['valid_ELBO', 'valid_purity']
if train_supervised:
# Track supervised losses, train on supervised loss.
ops_to_run.update({
'train_ELBO_supervised': train_ops.elbo_supervised,
'train_log_p_x_supervised': train_ops.log_p_x_supervised,
'train_kl_y_supervised': train_ops.kl_y_supervised,
'train_kl_z_supervised': train_ops.kl_z_supervised,
'train_ll_supervised': train_ops.ll_supervised
})
default_train_step = train_step_supervised
to_log += [
'train_ELBO_supervised', 'train_log_p_x_supervised',
'train_kl_y_supervised', 'train_kl_z_supervised'
]
else:
# Track unsupervised losses, train on unsupervised loss.
ops_to_run.update({
'train_ELBO': train_ops.elbo,
'train_kl_y': train_ops.kl_y,
'train_kl_z': train_ops.kl_z,
'train_ll': train_ops.ll
})
default_train_step = train_step
to_log += ['train_ELBO', 'train_kl_y', 'train_kl_z']
with tf.train.SingularMonitoredSession() as sess:
for step in range(n_steps):
feed_dict = {}
# Use the default training loss, but vary it each step depending on the
# training scenario (eg. for supervised gen replay, we alternate losses)
ops_to_run['train_step'] = default_train_step
### 1) PERIODICALLY TAKE SNAPSHOTS FOR GENERATIVE REPLAY ###
if (gen_refresh_period and step % gen_refresh_period == 0 and
gen_every_n > 0):
# First, increment cumulative count and reset recent probs count.
cumulative_component_counts += recent_component_counts
recent_component_counts = np.zeros(n_y)
# Generate enough samples for the rest of the next period
# (Functionally equivalent to storing and sampling from the model).
gen_buffer_images, gen_buffer_labels = get_generated_data(
sess=sess,
gen_op=gen_samples,
y_input=y_gen,
gen_buffer_size=gen_buffer_size,
component_counts=cumulative_component_counts)
### 2) DECIDE WHICH DATA SOURCE TO USE (GENERATIVE OR REAL DATA) ###
periodic_refresh_started = (
gen_refresh_period and step >= gen_refresh_period)
refresh_on_expansion_started = (gen_refresh_on_expansion and has_expanded)
if ((periodic_refresh_started or refresh_on_expansion_started) and
gen_every_n > 0 and step % gen_every_n == 1):
# Use generated data for the training batch
used_real_data = False
s = gen_buffer_ind * batch_size
e = (gen_buffer_ind + 1) * batch_size
gen_data_array = {
'image': gen_buffer_images[s:e],
'label': gen_buffer_labels[s:e]
}
gen_buffer_ind = (gen_buffer_ind + 1) % gen_buffer_size
# Feed it as x_train because it's already reshaped and binarized.
feed_dict.update({
x_train: gen_data_array['image'],
label_train: gen_data_array['label']
})
if use_supervised_replay:
# Convert label to one-hot before feeding in.
gen_label_onehot = np.eye(n_y)[gen_data_array['label']]
feed_dict.update({model_train.y_label: gen_label_onehot})
ops_to_run['train_step'] = train_step_supervised
else:
# Else use the standard training data sources.
used_real_data = True
# Select appropriate data source for iid or sequential setup.
if training_data_type == 'sequential':
current_data_period = int(
min(step / n_steps_per_class, len(train_data) - 1))
# If training supervised, set n_y_active directly based on how many
# classes have been seen
if train_supervised:
assert not dynamic_expansion
n_y_active_np = n_concurrent_classes * (
current_data_period // n_concurrent_classes +1)
n_y_active.load(n_y_active_np, sess)
train_data_array = sess.run(train_data[current_data_period])
# If we are blending classes, figure out where we are in the data
# period and add some fraction of other samples.
if blend_classes:
# If in the first quarter, blend in examples from the previous class
if (step % n_steps_per_class < n_steps_per_class / 4 and
current_data_period > 0):
other_train_data_array = sess.run(
train_data[current_data_period - 1])
num_other = int(
(n_steps_per_class / 2 - 2 *
(step % n_steps_per_class)) * batch_size / n_steps_per_class)
other_inds = np.random.permutation(batch_size)[:num_other]
train_data_array[image_key][:num_other] = other_train_data_array[
image_key][other_inds]
train_data_array[label_key][:num_other] = other_train_data_array[
label_key][other_inds]
# If in the last quarter, blend in examples from the next class
elif (step % n_steps_per_class > 3 * n_steps_per_class / 4 and
current_data_period < n_classes - 1):
other_train_data_array = sess.run(train_data[current_data_period +
1])
num_other = int(
(2 * (step % n_steps_per_class) - 3 * n_steps_per_class / 2) *
batch_size / n_steps_per_class)
other_inds = np.random.permutation(batch_size)[:num_other]
train_data_array[image_key][:num_other] = other_train_data_array[
image_key][other_inds]
train_data_array['label'][:num_other] = other_train_data_array[
label_key][other_inds]
# Otherwise, just use the current class
else:
train_data_array = sess.run(train_data)
feed_dict.update({
x_train_raw: train_data_array[image_key],
label_train: train_data_array[label_key]
})
### 3) PERFORM A GRADIENT STEP ###
results = sess.run(ops_to_run, feed_dict=feed_dict)
del results['train_step']
### 4) COMPUTE ADDITIONAL DIAGNOSTIC OPS ON VALIDATION/TEST SETS. ###
if (step+1) % report_interval == 0:
if valid_data is not None:
logging.info('Evaluating on validation and test set!')
proc_ops = {
k: (np.sum if 'confusion' in k
else np.mean) for k in valid_ops_to_run
}
results.update(
process_dataset(
dataset_ops.valid_iter,
valid_ops_to_run,
sess,
feed_dict=feed_dict,
processing_ops=proc_ops))
results['valid_purity'] = compute_purity(results['valid_confusion'])
else:
logging.info('Evaluating on test set!')
proc_ops = {
k: (np.sum if 'confusion' in k
else np.mean) for k in test_ops_to_run
}
results.update(process_dataset(dataset_ops.test_iter,
test_ops_to_run,
sess,
feed_dict=feed_dict,
processing_ops=proc_ops))
results['test_purity'] = compute_purity(results['test_confusion'])
curr_to_log = to_log + to_log_eval
else:
curr_to_log = list(to_log) # copy to prevent in-place modifications
### 5) DYNAMIC EXPANSION ###
if dynamic_expansion and used_real_data:
# If we're doing dynamic expansion and below max capacity then add
# poorly defined data points to a buffer.
# First check whether the model is eligible for expansion (the model
# becomes ineligible for a fixed time after each expansion, and when
# it has hit max capacity).
if (steps_since_expansion >= exp_wait_steps and step >= exp_burn_in and
n_y_active_np < n_y):
eligible_for_expansion = True
steps_since_expansion += 1
if eligible_for_expansion:
# Add poorly explained data samples to a buffer.
poor_inds = results['train_ll'] < ll_thresh
poor_data_buffer.extend(feed_dict[x_train_raw][poor_inds])
poor_data_labels.extend(feed_dict[label_train][poor_inds])
n_poor_data = len(poor_data_buffer)
# If buffer is big enough, then add a new component and train just the
# new component with several steps of gradient descent.
# (We just feed in a onehot cluster vector to indicate which
# component).
if n_poor_data >= exp_buffer_size:
# Dump the buffers so we can log them.
all_full_poor_data_buffers.append(poor_data_buffer)
all_full_poor_data_labels.append(poor_data_labels)
# Take a new generative snapshot if specified.
if gen_refresh_on_expansion and gen_every_n > 0:
# Increment cumulative count and reset recent probs count.
cumulative_component_counts += recent_component_counts
recent_component_counts = np.zeros(n_y)
gen_buffer_images, gen_buffer_labels = get_generated_data(
sess=sess,
gen_op=gen_samples,
y_input=y_gen,
gen_buffer_size=gen_buffer_size,
component_counts=cumulative_component_counts)
# Cull to a multiple of batch_size (keep the later data samples).
n_poor_batches = int(n_poor_data / batch_size)
poor_data_buffer = poor_data_buffer[-(n_poor_batches * batch_size):]
poor_data_labels = poor_data_labels[-(n_poor_batches * batch_size):]
# Find most probable component (on poor batch).
poor_cprobs = []
for bs in range(n_poor_batches):
poor_cprobs.append(
sess.run(
train_ops.cat_probs,
feed_dict={
x_train_raw:
poor_data_buffer[bs * batch_size:(bs + 1) *
batch_size]
}))
best_cluster = np.argmax(np.sum(np.vstack(poor_cprobs), axis=0))
# Initialize parameters of the new component from most prob
# existing.
new_cluster = n_y_active_np
copy_component_params(best_cluster, new_cluster, sess,
**dynamic_ops)
# Increment mixture component count n_y_active.
n_y_active_np += 1
n_y_active.load(n_y_active_np, sess)
# Perform a number of steps of gradient descent on the data buffer,
# training only the new component (supervised loss).
for _ in range(num_buffer_train_steps):
for bs in range(n_poor_batches):
x_batch = poor_data_buffer[bs * batch_size:(bs + 1) *
batch_size]
label_batch = poor_data_labels[bs * batch_size:(bs + 1) *
batch_size]
label_onehot_batch = np.eye(n_y)[label_batch]
_ = sess.run(
train_step_supervised,
feed_dict={
x_train_raw: x_batch,
model_train.y_label: label_onehot_batch
})
# Empty the buffer.
poor_data_buffer = []
poor_data_labels = []
# Reset the threshold flag so we have a burn in before the next
# component.
eligible_for_expansion = False
has_expanded = True
steps_since_expansion = 0
# Accumulate counts.
if used_real_data:
train_cat_probs_vals = results['train_probs']
recent_component_counts += np.sum(
train_cat_probs_vals, axis=0).astype(float)
### 6) LOGGING AND EVALUATION ###
cleanup_for_print = lambda x: ', {}: %.{}f'.format(
x.capitalize().replace('_', ' '), 3)
log_str = 'Iteration %d'
log_str += ''.join([cleanup_for_print(el) for el in curr_to_log])
log_str += ' n_active: %d'
logging.info(
log_str,
*([step] + [results[el] for el in curr_to_log] + [n_y_active_np]))
# Periodically perform evaluation
if (step + 1) % report_interval == 0:
# Report test purity and related measures
logging.info(
'Iteration %d, Test purity: %.3f, Test ELBO: %.3f, Test '
'KLy: %.3f, Test KLz: %.3f', step, results['test_purity'],
results['test_ELBO'], results['test_kl_y'], results['test_kl_z'])
# Flush data only once in a while to allow buffering of data for more
# efficient writes.
results['all_full_poor_data_buffers'] = all_full_poor_data_buffers
results['all_full_poor_data_labels'] = all_full_poor_data_labels
logging.info('Also training a classifier in latent space')
# Perform knn classification from latents, to evaluate discriminability.
# Get and encode training and test datasets.
clf_train_vals = process_dataset(
dataset_ops.train_iter_for_clf, {
'latents': latents_for_clf,
'labels': train_data_for_clf[label_key]
},
sess,
feed_dict,
aggregation_ops=np.concatenate)
clf_test_vals = process_dataset(
dataset_ops.test_iter, {
'latents': test_ops.latents,
'labels': test_data[label_key]
},
sess,
aggregation_ops=np.concatenate)
# Perform knn classification.
knn_models = []
for nval in knn_values:
# Fit training dataset.
clf = neighbors.KNeighborsClassifier(n_neighbors=nval)
clf.fit(clf_train_vals['latents'], clf_train_vals['labels'])
knn_models.append(clf)
results['train_' + str(nval) + 'nn_acc'] = clf.score(
clf_train_vals['latents'], clf_train_vals['labels'])
# Get test performance.
results['test_' + str(nval) + 'nn_acc'] = clf.score(
clf_test_vals['latents'], clf_test_vals['labels'])
logging.info(
'Iteration %d %d-NN classifier accuracies, Training: '
'%.3f, Test: %.3f', step, nval,
results['train_' + str(nval) + 'nn_acc'],
results['test_' + str(nval) + 'nn_acc'])