mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Apply minor fixes to README
PiperOrigin-RevId: 279101132
This commit is contained in:
committed by
Diego de Las Casas
parent
f309feabee
commit
656f93a37e
@@ -0,0 +1,35 @@
|
||||
# Continual Unsupervised Representation Learning (CURL)
|
||||
|
||||
This repository contains code to accompany the NeurIPS 2019 submission on
|
||||
Continual Unsupervised Representation Learning (CURL).
|
||||
|
||||
The experiments in the paper can be reproduced by running one of the three
|
||||
different training scripts:
|
||||
|
||||
|
||||
`train_sup.py`: to run the supervised continual learning benchmark
|
||||
|
||||
`train_unsup.py`: to run the unsupervised i.i.d. learning benchmark
|
||||
|
||||
`train_main.py`: to run all other experiments in the paper (with details in the
|
||||
file on what to change)
|
||||
|
||||
In each of these cases, the cluster accuracy / purity and k-NN error are logged
|
||||
to the terminal, and other quantities can be accessed from training.py
|
||||
(e.g. the confusion matrix can be found in `results['test_confusion']`).
|
||||
|
||||
We recommend running these scripts in a Python
|
||||
[virtual environment](https://docs.python.org/3/tutorial/venv.html):
|
||||
|
||||
(Assuming python3-dev is installed in your system)
|
||||
|
||||
```console
|
||||
python3 -m venv .curl_venv
|
||||
source .curl_venv/bin/activate
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
|
||||
PYTHONPATH=`pwd`/..:$PYTHONPATH python3 train_main.py --dataset='mnist'
|
||||
|
||||
Run `deactivate` to exit the virtual environment.
|
||||
```
|
||||
+120
@@ -0,0 +1,120 @@
|
||||
################################################################################
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
################################################################################
|
||||
"""Custom layers for CURL."""
|
||||
|
||||
from absl import logging
|
||||
import sonnet as snt
|
||||
import tensorflow as tf
|
||||
|
||||
tfc = tf.compat.v1
|
||||
|
||||
|
||||
class ResidualStack(snt.AbstractModule):
|
||||
"""A stack of ResNet V2 blocks."""
|
||||
|
||||
def __init__(self,
|
||||
num_hiddens,
|
||||
num_residual_layers,
|
||||
num_residual_hiddens,
|
||||
filter_size=3,
|
||||
initializers=None,
|
||||
data_format='NHWC',
|
||||
activation=tf.nn.relu,
|
||||
name='residual_stack'):
|
||||
"""Instantiate a ResidualStack."""
|
||||
super(ResidualStack, self).__init__(name=name)
|
||||
self._num_hiddens = num_hiddens
|
||||
self._num_residual_layers = num_residual_layers
|
||||
self._num_residual_hiddens = num_residual_hiddens
|
||||
self._filter_size = filter_size
|
||||
self._initializers = initializers
|
||||
self._data_format = data_format
|
||||
self._activation = activation
|
||||
|
||||
def _build(self, h):
|
||||
for i in range(self._num_residual_layers):
|
||||
h_i = self._activation(h)
|
||||
|
||||
h_i = snt.Conv2D(
|
||||
output_channels=self._num_residual_hiddens,
|
||||
kernel_shape=(self._filter_size, self._filter_size),
|
||||
stride=(1, 1),
|
||||
initializers=self._initializers,
|
||||
data_format=self._data_format,
|
||||
name='res_nxn_%d' % i)(
|
||||
h_i)
|
||||
h_i = self._activation(h_i)
|
||||
|
||||
h_i = snt.Conv2D(
|
||||
output_channels=self._num_hiddens,
|
||||
kernel_shape=(1, 1),
|
||||
stride=(1, 1),
|
||||
initializers=self._initializers,
|
||||
data_format=self._data_format,
|
||||
name='res_1x1_%d' % i)(
|
||||
h_i)
|
||||
h += h_i
|
||||
return self._activation(h)
|
||||
|
||||
|
||||
class SharedConvModule(snt.AbstractModule):
  """Convolutional encoder shared across tasks.

  Applies a sequence of strided 2-D convolutions followed by a single
  fully-connected layer, and records the intermediate feature-map shapes in
  `self.conv_shapes` so that a matching deconvolutional decoder can be
  constructed elsewhere.

  NOTE: the original docstring said "Convolutional decoder", but the module
  name ('shared_conv_encoder') and the layer names ('enc_conv_*', 'enc_mlp')
  show this is the encoder.
  """

  def __init__(self,
               filters,
               kernel_size,
               activation,
               strides,
               name='shared_conv_encoder'):
    """Instantiate a SharedConvModule.

    Args:
      filters: list of ints, output channels of each conv layer; the last
        entry is used as the output size of the final MLP instead.
      kernel_size: int or pair of ints, spatial kernel size of every conv.
      activation: callable nonlinearity applied after each layer.
      strides: list of ints, one stride per conv layer, so it must be one
        element shorter than `filters`.
      name: str, name of the Sonnet module.
    """
    super(SharedConvModule, self).__init__(name=name)

    self._filters = filters
    self._kernel_size = kernel_size
    self._activation = activation
    self.strides = strides
    assert len(strides) == len(filters) - 1, (
        'Need exactly one stride per conv layer (the last filters entry '
        'parameterises the final MLP, not a conv).')
    # Populated on the first call to _build; consumed by the deconv decoder.
    self.conv_shapes = None

  def _build(self, x, is_training=True):
    """Encode a batch of images.

    Args:
      x: rank-4 input tensor, shape (batch, height, width, channels).
      is_training: unused; kept for interface compatibility with callers.

    Returns:
      A rank-2 tensor of encodings with width `filters[-1]`.
    """
    with tf.control_dependencies([tfc.assert_rank(x, 4)]):

      self.conv_shapes = [x.shape.as_list()]  # Needed by deconv module
      conv = x
      for i, (filter_i,
              stride_i) in enumerate(zip(self._filters, self.strides), 1):
        conv = tf.layers.Conv2D(
            filters=filter_i,
            kernel_size=self._kernel_size,
            padding='same',
            activation=self._activation,
            strides=stride_i,
            name='enc_conv_%d' % i)(conv)
        self.conv_shapes.append(conv.shape.as_list())
      conv_flat = snt.BatchFlatten()(conv)

      # A single fully-connected layer on top of the flattened conv stack.
      enc_mlp = snt.nets.MLP(
          name='enc_mlp',
          output_sizes=[self._filters[-1]],
          activation=self._activation,
          activate_final=True)
      h = enc_mlp(conv_flat)

      logging.info('Shared conv module layer shapes:')
      logging.info('\n'.join([str(el) for el in self.conv_shapes]))
      logging.info(h.shape.as_list())

    return h
|
||||
+797
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,10 @@
|
||||
absl-py==0.8.0
|
||||
dm-sonnet==1.35
|
||||
gast<0.3
|
||||
numpy==1.16.4
|
||||
scikit-learn==0.20.4
|
||||
setuptools>=41.0.0
|
||||
six==1.12.0
|
||||
tensorflow==1.14.0
|
||||
tensorflow-datasets==1.2.0
|
||||
tensorflow-probability==0.7.0
|
||||
@@ -0,0 +1,79 @@
|
||||
################################################################################
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
################################################################################
|
||||
"""Training file to run most of the experiments in the paper.
|
||||
|
||||
The default parameters correspond to the first set of experiments in Section
|
||||
4.2.
|
||||
|
||||
For the expansion ablation, run with different ll_thresh values as in the paper.
|
||||
Note that n_y_active represents the number of *active* components at the
|
||||
start, and should be set to 1, while n_y represents the maximum number of
|
||||
components allowed, and should be set sufficiently high (e.g. n_y = 100).
|
||||
|
||||
For the MGR ablation, setting use_sup_replay = True switches to using SMGR,
|
||||
and the gen_replay_type flag can switch between fixed and dynamic replay. The
|
||||
generative snapshot period is set automatically in the train_curl.py file based
|
||||
on these settings (i.e. the data_period variable), so the 0.1T runs can be
|
||||
reproduced by dividing this value by 10.
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from curl import training
|
||||
|
||||
flags.DEFINE_enum('dataset', 'mnist', ['mnist', 'omniglot'], 'Dataset.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(unused_argv):
  """Run the default experiment from Section 4.2 of the paper."""
  # Hyperparameters for the first set of experiments in Section 4.2; see the
  # module docstring for how to adapt them for the ablation studies.
  run_config = {
      'dataset': FLAGS.dataset,
      'output_type': 'bernoulli',
      'n_y': 30,
      'n_y_active': 1,
      'training_data_type': 'sequential',
      'n_concurrent_classes': 1,
      'lr_init': 1e-3,
      'lr_factor': 1.,
      'lr_schedule': [1],
      'blend_classes': False,
      'train_supervised': False,
      'n_steps': 100000,
      'report_interval': 10000,
      'knn_values': [10],
      'random_seed': 1,
      'encoder_kwargs': {
          'encoder_type': 'multi',
          'n_enc': [1200, 600, 300, 150],
          'enc_strides': [1],
      },
      'decoder_kwargs': {
          'decoder_type': 'single',
          'n_dec': [500, 500],
          'dec_up_strides': None,
      },
      'n_z': 32,
      'dynamic_expansion': True,
      'll_thresh': -200.0,
      'classify_with_samples': False,
      'gen_replay_type': 'fixed',
      'use_supervised_replay': False,
  }
  training.run_training(**run_config)


if __name__ == '__main__':
  app.run(main)
|
||||
@@ -0,0 +1,64 @@
|
||||
################################################################################
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
################################################################################
|
||||
"""Runs the supervised CL benchmark experiments in the paper."""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from curl import training
|
||||
|
||||
flags.DEFINE_enum('dataset', 'mnist', ['mnist', 'omniglot'], 'Dataset.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(unused_argv):
  """Run the supervised continual-learning benchmark."""
  # All ten classes are active from the start; classes arrive sequentially,
  # two at a time, and the model is trained with supervision.
  run_config = {
      'dataset': FLAGS.dataset,
      'output_type': 'bernoulli',
      'n_y': 10,
      'n_y_active': 10,
      'training_data_type': 'sequential',
      'n_concurrent_classes': 2,
      'lr_init': 1e-3,
      'lr_factor': 1.,
      'lr_schedule': [1],
      'train_supervised': True,
      'blend_classes': False,
      'n_steps': 100000,
      'report_interval': 10000,
      'knn_values': [],
      'random_seed': 1,
      'encoder_kwargs': {
          'encoder_type': 'multi',
          'n_enc': [400, 400],
          'enc_strides': [1],
      },
      'decoder_kwargs': {
          'decoder_type': 'single',
          'n_dec': [400, 400],
          'dec_up_strides': None,
      },
      'n_z': 32,
      'dynamic_expansion': False,
      'll_thresh': -10000.0,
      'classify_with_samples': False,
      'gen_replay_type': 'fixed',
      'use_supervised_replay': False,
  }
  training.run_training(**run_config)


if __name__ == '__main__':
  app.run(main)
|
||||
@@ -0,0 +1,73 @@
|
||||
################################################################################
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
################################################################################
|
||||
"""Runs the unsupervised i.i.d benchmark experiments in the paper."""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from curl import training
|
||||
|
||||
flags.DEFINE_enum('dataset', 'mnist', ['mnist', 'omniglot'], 'Dataset.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(unused_argv):
  """Run the unsupervised i.i.d. benchmark."""
  # Model capacity is chosen per dataset.
  if FLAGS.dataset == 'mnist':
    n_y, n_y_active, n_z = 25, 25, 50
  else:  # omniglot
    n_y, n_y_active, n_z = 100, 100, 100

  run_config = {
      'dataset': FLAGS.dataset,
      'n_y': n_y,
      'n_y_active': n_y_active,
      'n_z': n_z,
      'output_type': 'bernoulli',
      'training_data_type': 'iid',
      'n_concurrent_classes': 1,
      'lr_init': 5e-4,
      'lr_factor': 1.,
      'lr_schedule': [1],
      'blend_classes': False,
      'train_supervised': False,
      'n_steps': 100000,
      'report_interval': 10000,
      'knn_values': [3, 5, 10],
      'random_seed': 1,
      'encoder_kwargs': {
          'encoder_type': 'multi',
          'n_enc': [500, 500],
          'enc_strides': [1],
      },
      'decoder_kwargs': {
          'decoder_type': 'single',
          'n_dec': [500],
          'dec_up_strides': None,
      },
      'dynamic_expansion': False,
      'll_thresh': -0.0,
      'classify_with_samples': True,
      'gen_replay_type': None,
      'use_supervised_replay': False,
  }
  training.run_training(**run_config)


if __name__ == '__main__':
  app.run(main)
|
||||
+1169
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,67 @@
|
||||
################################################################################
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
################################################################################
|
||||
"""Tests for curl."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from curl import training
|
||||
|
||||
|
||||
class TrainingTest(absltest.TestCase):
  """Smoke test for the CURL training loop."""

  def testRunTraining(self):
    # Run a short (1000-step) training job end-to-end to check that the
    # full pipeline executes without crashing.
    run_config = {
        'dataset': 'mnist',
        'output_type': 'bernoulli',
        'n_y': 10,
        'n_y_active': 1,
        'training_data_type': 'sequential',
        'n_concurrent_classes': 1,
        'lr_init': 1e-3,
        'lr_factor': 1.,
        'lr_schedule': [1],
        'blend_classes': False,
        'train_supervised': False,
        'n_steps': 1000,
        'report_interval': 1000,
        'knn_values': [3],
        'random_seed': 1,
        'encoder_kwargs': {
            'encoder_type': 'multi',
            'n_enc': [1200, 600, 300, 150],
            'enc_strides': [1],
        },
        'decoder_kwargs': {
            'decoder_type': 'single',
            'n_dec': [500, 500],
            'dec_up_strides': None,
        },
        'n_z': 32,
        'dynamic_expansion': True,
        'll_thresh': -200.0,
        'classify_with_samples': False,
        'gen_replay_type': 'fixed',
        'use_supervised_replay': False,
    }
    training.run_training(**run_config)


if __name__ == '__main__':
  absltest.main()
|
||||
@@ -0,0 +1,85 @@
|
||||
################################################################################
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
################################################################################
|
||||
"""Some common utils."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
|
||||
def generate_gaussian(logits, sigma_nonlin, sigma_param):
  """Build a diagonal Gaussian from raw network outputs.

  The logits are split in half along axis 1: the first half is the mean, the
  second half is mapped through `sigma_nonlin` to a positive value, which is
  interpreted as either the std or the variance per `sigma_param`.

  Args:
    logits: rank-2 tensor whose second axis is split into (mu, raw_sigma).
    sigma_nonlin: str, 'exp' or 'softplus' - nonlinearity for the scale.
    sigma_param: str, 'std' or 'var' - whether the transformed value is the
      standard deviation or the variance (in which case sqrt is taken).

  Returns:
    A tfp.distributions.Normal with the resulting loc and scale.

  Raises:
    ValueError: if sigma_nonlin or sigma_param is not recognised.
  """
  mu, raw_sigma = tf.split(value=logits, num_or_size_splits=2, axis=1)

  # Map the raw values to a positive scale parameter.
  nonlinearities = {'exp': tf.exp, 'softplus': tf.nn.softplus}
  if sigma_nonlin not in nonlinearities:
    raise ValueError('Unknown sigma_nonlin {}'.format(sigma_nonlin))
  sigma = nonlinearities[sigma_nonlin](raw_sigma)

  # Interpret the positive value as std or variance.
  if sigma_param not in ('var', 'std'):
    raise ValueError('Unknown sigma_param {}'.format(sigma_param))
  if sigma_param == 'var':
    sigma = tf.sqrt(sigma)

  return tfp.distributions.Normal(loc=mu, scale=sigma)
|
||||
|
||||
|
||||
def construct_prior_probs(batch_size, n_y, n_y_active):
  """Construct the uniform prior probabilities.

  Mass is spread uniformly over the currently active components; inactive
  components receive a tiny constant (1e-12) so downstream logs stay finite.

  Args:
    batch_size: int, the size of the batch.
    n_y: int, the number of categorical cluster components.
    n_y_active: tf.Variable, the number of components that are currently in use.

  Returns:
    Tensor representing the prior probability matrix, size of [batch_size, n_y].
  """
  uniform_active = tf.ones((batch_size, n_y_active)) / tf.cast(
      n_y_active, dtype=tf.float32)
  # Pad columns n_y_active..n_y with near-zero mass; rows are not padded.
  zero = tf.constant(0)
  paddings = tf.stack(
      [tf.stack([zero, zero], axis=0),
       tf.stack([zero, n_y - n_y_active], axis=0)],
      axis=1)
  probs = tf.pad(uniform_active, paddings, constant_values=1e-12)
  # The padded width is dynamic, so pin the static shape for later ops.
  probs.set_shape((batch_size, n_y))
  logging.info('Prior shape: %s', str(probs.shape))
  return probs
|
||||
|
||||
|
||||
def maybe_center_crop(layer, target_hw):
  """Center crop the layer to match a target spatial shape.

  Args:
    layer: rank-4 tensor-like of shape (batch, height, width, channels),
      supporting `.shape.as_list()` and slicing.
    target_hw: pair (target_height, target_width); each must be no larger
      than the corresponding dimension of `layer`.

  Returns:
    `layer` cropped symmetrically towards `target_hw`. If a margin is odd,
    the same border is removed from both sides, so the result is one larger
    than the target in that dimension and a warning is logged.
  """
  l_height, l_width = layer.shape.as_list()[1:3]
  t_height, t_width = target_hw
  assert t_height <= l_height and t_width <= l_width

  if (l_height - t_height) % 2 != 0 or (l_width - t_width) % 2 != 0:
    # logging.warn is a deprecated alias; use logging.warning.
    logging.warning(
        'It is impossible to center-crop [%d, %d] into [%d, %d].'
        ' Crop will be uneven.', t_height, t_width, l_height, l_width)

  # Integer division: the margins are ints, so avoid float round-tripping.
  border = (l_height - t_height) // 2
  x_0, x_1 = border, l_height - border
  border = (l_width - t_width) // 2
  y_0, y_1 = border, l_width - border
  layer_cropped = layer[:, x_0:x_1, y_0:y_1, :]
  return layer_cropped
|
||||
Reference in New Issue
Block a user