diff --git a/README.md b/README.md index d63ea0e..e4bae3f 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ https://deepmind.com/research/publications/ ## Projects +* [Encoders and ensembles for continual learning](continual_learning) * [Towards mental time travel: a hierarchical memory for reinforcement learning agents](hierarchical_transformer_memory) * [Perceiver IO: A General Architecture for Structured Inputs & Outputs](perceiver) * [Solving Mixed Integer Programs Using Neural Networks](neural_mip_solving) diff --git a/continual_learning/README.md b/continual_learning/README.md new file mode 100644 index 0000000..11af121 --- /dev/null +++ b/continual_learning/README.md @@ -0,0 +1,35 @@ +# Continual learning with pre-trained encoders and ensembles of classifiers +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/continual_learning/encoders_and_ensembles.ipynb) + +This repository contains a notebook implementation of a classifier ensemble memory model that mitigates catastrophic forgetting. + +The code was written by Murray Shanahan. + +The model comprises +* a pre-trained encoder, trained on a different dataset from the target dataset, and +* a memory with fixed randomised keys and k-nearest neighbour lookup, where +* each memory location stores the parameters of a trainable local classifier, and +* the ensemble's output is the mean output of the k selected classifiers weighted according to the distance of their keys from the encoded input + +The model is demonstrated on MNIST, where the encoder is pre-trained on Omniglot. The continual learning setting is +* Task-free. The model doesn't know about task boundaries +* Online. The dataset is only seen once, and there are no epochs +* Incremental class learning. Evaluation is always on 10-way classification + +The code accompanies the paper: + +Shanahan, M., Kaplanis, C. & Mitrovic, J. (2021). 
Encoders and Ensembles for Task-Free Continual Learning. ArXiv preprint: https://arxiv.org/abs/2105.13327 + +## Running the experiments + +The easiest way to run the code is using the publicly available [Colab](https://colab.research.google.com) kernel. Colaboratory is a free Jupyter notebook environment provided by Google that requires no setup and runs entirely in the cloud. (A GPU runtime is needed to train in a reasonable time.) The notebook is self-contained, and will load all necessary libraries automatically if run in Colaboratory. + +Click "Run all" in the "Runtime" menu to train on 5-way split MNIST ("high data" setting), as described in the paper. Adjusting the "schedule_type" in the config will allow you to try out different benchmarks, such as a 10-way split. + +## Contact + +If you have any feedback, or would like to get in touch regarding the code or the architecture, you can reach out to mshanahan@deepmind.com. + +## Disclaimer + +This is not an officially supported Google product. 
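The README above describes the ensemble's output as a mean of the k selected classifiers, weighted by how close their keys are to the encoded input. A minimal sketch of that lookup follows. It uses plain Python lists rather than JAX so that it runs without dependencies. The names `ensemble_output` and `scaled_tanh` and the toy keys, weights and biases are illustrative assumptions, not part of the notebook. The notebook's own `Memory.lookup` is the reference implementation.

```python
import math


def l2_normalize(vec, epsilon=1e-12):
    """Scale a vector to unit length, guarding against a zero norm."""
    norm = math.sqrt(max(sum(v * v for v in vec), epsilon))
    return [v / norm for v in vec]


def scaled_tanh(value, vub=250.0):
    """Bounded activation: tanh rescaled to the range (-vub, vub)."""
    return math.tanh(value / vub) * vub


def ensemble_output(enc, keys, weights, biases, k):
    """Similarity-weighted mean of the k classifiers with the nearest keys.

    enc:     encoded input, length D
    keys:    M fixed keys, each of length D
    weights: M matrices of shape D x C (one linear classifier per location)
    biases:  M vectors of length C
    Returns the per-class output and the indices of the selected classifiers.
    """
    enc = l2_normalize(enc)
    # Cosine similarity between the input and every key
    sims = [sum(e * x for e, x in zip(enc, l2_normalize(key))) for key in keys]
    # Indices of the k most similar keys, most similar first
    top = sorted(range(len(keys)), key=lambda i: sims[i], reverse=True)[:k]
    num_classes = len(biases[0])
    totals = [0.0] * num_classes
    for i in top:
        for c in range(num_classes):
            logit = sum(enc[d] * weights[i][d][c] for d in range(len(enc)))
            totals[c] += sims[i] * scaled_tanh(logit + biases[i][c])
    # Assumes the selected similarities do not sum to zero
    sim_sum = sum(sims[i] for i in top)
    return [t / sim_sum for t in totals], top


# Toy memory: three locations, 2-d encodings, two classes (hypothetical values)
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
weights = [
    [[1.0, 0.0], [0.0, 0.0]],  # classifier 0 favours class 0
    [[0.0, 0.0], [0.0, 1.0]],  # classifier 1 favours class 1
    [[0.0, 1.0], [0.0, 0.0]],  # classifier 2 is too far away to be selected
]
biases = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]

output, selected = ensemble_output([1.0, 0.1], keys, weights, biases, k=2)
print(selected, output)
```

Because the keys are fixed, an input only ever reaches the classifiers stored near its encoding, so a gradient step for one region of encoding space leaves classifiers elsewhere untouched.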
diff --git a/continual_learning/encoders_and_ensembles.ipynb b/continual_learning/encoders_and_ensembles.ipynb new file mode 100644 index 0000000..d28515f --- /dev/null +++ b/continual_learning/encoders_and_ensembles.ipynb @@ -0,0 +1,1092 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "hoKLQnrnS73m" + }, + "source": [ + "Copyright 2021 DeepMind Technologies Limited.\n", + "\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + "https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License.\n", + "\n", + "# Continual learning with pre-trained encoders and ensembles of classifiers\n", + "\n", + "Murray Shanahan\n", + "\n", + "July 2021\n", + "\n", + "A classifier ensemble memory model that mitigates catastrophic forgetting. The model comprises\n", + "* a pre-trained encoder, trained on a different dataset from the target dataset, and\n", + "* a memory with fixed randomised keys and k-nearest neighbour lookup, where\n", + "* each memory location stores the parameters of a trainable local classifier, and\n", + "* the ensemble's output is the mean output of the k selected classifiers weighted according to the distance of their keys from the encoded input\n", + "\n", + "The model is demonstrated on MNIST, where the encoder is pre-trained on Omniglot. The continual learning setting is\n", + "* Task-free. The model doesn't know about task boundaries\n", + "* Online. The dataset is only seen once, and there are no epochs\n", + "* Incremental class learning. 
Evaluation is always on 10-way classification\n", + "\n", + "This Colab accompanies the paper:\n", + "\n", + "Shanahan, M., Kaplanis, C. \u0026 Mitrovic, J. (2021). Encoders and Ensembles for Task-Free Continual Learning. ArXiv preprint: https://arxiv.org/abs/2105.13327" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kfWWLMG6wN5J" + }, + "source": [ + "# Preliminaries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "thK1IH0gqT3V" + }, + "outputs": [], + "source": [ + "# Dependencies that may require pip installation\n", + "\n", + "!pip install dm-haiku\n", + "!pip install optax\n", + "!pip install dm-tree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LcMFWHXtSYWn" + }, + "outputs": [], + "source": [ + "# Imports\n", + "\n", + "import tensorflow.compat.v2 as tf\n", + "tf.enable_v2_behavior()\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import jit, grad, random\n", + "import optax\n", + "import haiku as hk\n", + "import tree\n", + "\n", + "from matplotlib import pyplot as plt\n", + "\n", + "import tensorflow_datasets as tfds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JYq-m_tSuc-M" + }, + "outputs": [], + "source": [ + "# Experiment parameters\n", + "\n", + "\n", + "# MNIST config (high data) - for comparison with Lee, et al. 
(2020)\n", + "\n", + "config = {\n", + " 'enc_size': 512, # size of latent encoding\n", + " 'mem_size': 1024, # number of memory locations (classifiers)\n", + " 'k': 32, # k nearest neighbour lookup parameter\n", + " 'vub': 250, # upper bound for activation function - was 100\n", + " 'res': 28, # resolution - 28 for MNIST \u0026 Omniglot\n", + " 'col_dims': 1, # 3 for RGB, 1 for greyscale\n", + " 'num_classes': 10, # number of classes\n", + " 'pretrain_n_batches': 10000, # number of batches in pre-training\n", + " 'pretrain_dataset': 'omniglot',\n", + " 'main_dataset': 'mnist',\n", + " 'batch_size': 60, # batch size for training main model\n", + " 'learning_rate': 1e-4, # learning rate for training main model\n", + " 'weight_decay': 1e-4, # optimiser weight decay\n", + " 'init_scale': 0.1, # baseline classifier initialiser variance scaling\n", + " 'log_every': 10, # interval for logging accuracies\n", + " 'report_every': 500, # interval for reporting accuracies\n", + " 'schedule_type': '5way_split', # training schedule (defining splits, etc)\n", + " 'n_runs': 20, # number of runs on the schedule\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a6JEJ1hYv-Um" + }, + "outputs": [], + "source": [ + "# Optimisers\n", + "\n", + "\n", + "class NaiveOptimiser():\n", + " \"\"\"Optimiser that discards magnitude of gradients and uses only their sign.\"\"\"\n", + "\n", + " def __init__(self, learning_rate, weight_decay):\n", + " self.learning_rate = learning_rate\n", + " self.weight_decay = weight_decay\n", + " self.state = None\n", + "\n", + " def init(self, params):\n", + " return None\n", + "\n", + " def update(self, grads, _, params=None):\n", + " step_size = self.learning_rate\n", + " weight_decay = self.weight_decay\n", + " updates = jax.tree_map(lambda g: -jnp.sign(g), grads)\n", + " updates = jax.tree_multimap(\n", + " lambda g, p: g + weight_decay * p, updates, params)\n", + " updates = jax.tree_map(lambda g: step_size 
* g, updates)\n", + " return (updates, self.state)\n", + "\n", + "\n", + "def make_encoder_optimiser():\n", + " opt = optax.adam(learning_rate=0.001) # learning rate was 0.001\n", + " return opt\n", + "\n", + "\n", + "def make_ensemble_optimiser():\n", + " opt = NaiveOptimiser(learning_rate=config['learning_rate'],\n", + " weight_decay=config['weight_decay'])\n", + " return opt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2yojukFKuJYr" + }, + "source": [ + "# Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gKLDMf_-NDPC" + }, + "outputs": [], + "source": [ + "# Datasets: MNIST and Omniglot\n", + "\n", + "\n", + "def get_dataset(dataset_name, train_or_test, batch_size, filter_labels=None):\n", + " filter_fn = lambda batch: tf.reduce_any(tf.equal(batch['label'],\n", + " filter_labels))\n", + " dataset = tfds.load(dataset_name, split=train_or_test,\n", + " as_supervised=False)\n", + " if filter_labels is not None:\n", + " dataset = dataset.filter(filter_fn)\n", + " dataset = dataset.shuffle(buffer_size=10000)\n", + " dataset = dataset.batch(batch_size)\n", + " dataset = dataset.repeat()\n", + " dataset = iter(dataset)\n", + " return dataset\n", + "\n", + "\n", + "def get_batch(dataset, dataset_name):\n", + " batch = next(dataset)\n", + " batch_size = batch['image'].shape[0]\n", + " images = batch['image']\n", + " if dataset_name == 'omniglot':\n", + " images = tf.image.resize(images, [config['res'], config['res']])\n", + " images = images[:, :, :, 0]\n", + " images = tf.reshape(images, [batch_size, 28, 28, 1]) / 255\n", + " if dataset_name == 'omniglot':\n", + " images = 1 - images # raw Omniglot characters are white (1) on black (0)\n", + " labels = batch['label']\n", + " one_hots = tf.one_hot(batch['label'], config['num_classes'])\n", + " return (images.numpy(), one_hots.numpy(), labels.numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uEtOt26W9O36" + }, + "source": [ + "# 
Plotting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HeUf0sJmT00v" + }, + "outputs": [], + "source": [ + "# Plotting accuracies\n", + "\n", + "\n", + "def smooth(data, degree=2):\n", + " \"\"\"Smooth out data for plotting.\"\"\"\n", + " triangle = jnp.array(list(range(degree)) + [degree] +\n", + " list(range(degree)[::-1])) + 1\n", + " # Copy last data point 'degree' times\n", + " data = jnp.append(data, jnp.array([data[-1] for _ in range(len(triangle))]))\n", + " smoothed = [data[0]]\n", + " for i in range(1, len(data) - len(triangle)):\n", + " point = data[i:i + len(triangle)] * triangle\n", + " smoothed.append(sum(point)/sum(triangle))\n", + " return jnp.array(smoothed)\n", + "\n", + "\n", + "def plot_x_accuracies(x_accuracies1, x_accuracies2, x_accuracies3, final=False):\n", + " \"\"\"Plot experiment mean accuracies with error bounds.\"\"\"\n", + "\n", + " # Find means and stds\n", + " x_accuracies1 = jnp.array(x_accuracies1, dtype=jnp.float64)\n", + " x_accuracies2 = jnp.array(x_accuracies2, dtype=jnp.float64)\n", + " x_accuracies3 = jnp.array(x_accuracies3, dtype=jnp.float64)\n", + " c_vanilla_means = jnp.mean(x_accuracies1, axis=(0, 1))\n", + " c_vanilla_stds = jnp.std(jnp.mean(x_accuracies1, axis=1), axis=0)\n", + " c_tanh_means = jnp.mean(x_accuracies2, axis=(0, 1))\n", + " c_tanh_stds = jnp.std(jnp.mean(x_accuracies2, axis=1), axis=0)\n", + " e_means = jnp.mean(x_accuracies3, axis=(0, 1))\n", + " e_stds = jnp.std(jnp.mean(x_accuracies3, axis=1), axis=0)\n", + " final_c_vanilla_mean = c_vanilla_means[-1]\n", + " final_c_vanilla_std = c_vanilla_stds[-1]\n", + " final_c_tanh_mean = c_tanh_means[-1]\n", + " final_c_tanh_std = c_tanh_stds[-1]\n", + " final_e_mean = e_means[-1]\n", + " final_e_std = e_stds[-1]\n", + "\n", + " # Smooth the data\n", + " c_vanilla_means = smooth(c_vanilla_means)\n", + " c_vanilla_stds = smooth(c_vanilla_stds)\n", + " c_tanh_means = smooth(c_tanh_means)\n", + " c_tanh_stds = 
smooth(c_tanh_stds)\n", + " e_means = smooth(e_means)\n", + " e_stds = smooth(e_stds)\n", + "\n", + " plt.figure(figsize=(8, 4))\n", + "\n", + " # Vanilla classifier accuracies\n", + " ax = plt.plot(range(len(c_vanilla_means)), c_vanilla_means)\n", + " colour = ax[-1].get_color()\n", + " plt.fill_between(range(len(c_vanilla_means)),\n", + " c_vanilla_means-c_vanilla_stds,\n", + " c_vanilla_means+c_vanilla_stds,\n", + " facecolor=colour, alpha=0.2)\n", + " # Tanh classifier accuracies\n", + " ax = plt.plot(range(len(c_tanh_means)), c_tanh_means)\n", + " colour = ax[-1].get_color()\n", + " plt.fill_between(range(len(c_tanh_means)),\n", + " c_tanh_means-c_tanh_stds,\n", + " c_tanh_means+c_tanh_stds,\n", + " facecolor=colour, alpha=0.2)\n", + " # Ensemble accuracies\n", + " ax = plt.plot(range(len(e_means)), e_means)\n", + " colour = ax[-1].get_color()\n", + " plt.fill_between(range(len(e_means)),\n", + " e_means-e_stds,\n", + " e_means+e_stds,\n", + " facecolor=colour, alpha=0.2)\n", + "\n", + " # Produce plots\n", + " plt.ylim([0.0, 1.0])\n", + " plt.xlabel('Batch x{}'.format(config['log_every']))\n", + " plt.ylabel('Accuracy')\n", + " plt.legend(['Vanilla classifier', 'Tanh classifier', 'Ensemble'])\n", + " plt.show()\n", + "\n", + " # Report accuracies\n", + " print('Vanilla classifier accuracy: {:.4f} \\u00b1 {:.4f}'.format(\n", + " final_c_vanilla_mean, final_c_vanilla_std))\n", + " print('Tanh classifier accuracy: {:.4f} \\u00b1 {:.4f}'.format(\n", + " final_c_tanh_mean, final_c_tanh_std))\n", + " print('Ensemble accuracy: {:.4f} \\u00b1 {:.4f}'.format(\n", + " final_e_mean, final_e_std))\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qb3GLfh0waig" + }, + "source": [ + "# Models and losses" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2bdSRg6jLxIz" + }, + "outputs": [], + "source": [ + "# Autoencoder (for pretraining)\n", + "\n", + "\n", + "class Autoencoder(hk.Module):\n", + " 
\"\"\"Autoencoder module.\"\"\"\n", + "\n", + "\n", + " def encode(self, image):\n", + " cnn = hk.Sequential([\n", + " hk.Conv2D(output_channels=16, kernel_shape=4, name='enc1'), jax.nn.relu,\n", + " hk.Conv2D(output_channels=16, kernel_shape=4, name='enc2'), jax.nn.relu,\n", + " hk.Flatten(),\n", + " ])\n", + " mlp1 = hk.Sequential([\n", + " hk.Linear(128, name='enc3'), jax.nn.relu,\n", + " hk.Linear(config['enc_size'], name='enc4'),\n", + " ])\n", + " mlp2 = hk.Sequential([\n", + " hk.Linear(128, name='enc5'), jax.nn.relu,\n", + " hk.Linear(config['enc_size'], name='enc6'),\n", + " ])\n", + " feats = cnn(image.reshape([-1, config['res'], config['res'],\n", + " config['col_dims']]))\n", + " enc_mean = jnp.tanh(mlp1(feats))\n", + " enc_sd = jax.nn.relu(mlp2(feats))\n", + " return (enc_mean, enc_sd)\n", + "\n", + "\n", + " def decode(self, latent):\n", + " dcnn = hk.Sequential(\n", + " [hk.Linear(128, name='dec1'), jax.nn.relu,\n", + " hk.Linear(config['res']*config['res']*16, name='dec2'), jax.nn.relu,\n", + " hk.Reshape((config['res'], config['res'], 16)),\n", + " hk.Conv2DTranspose(output_channels=16, kernel_shape=4, name='dec3'),\n", + " jax.nn.relu,\n", + " hk.Conv2DTranspose(output_channels=config['col_dims'],\n", + " kernel_shape=4, name='dec4'),\n", + " jax.nn.sigmoid])\n", + " image = dcnn(latent).reshape([-1, config['res'], config['res'],\n", + " config['col_dims']])\n", + " return image\n", + "\n", + "\n", + " def forward(self, rng, image):\n", + " (enc_mean, enc_sd) = self.encode(image)\n", + " # Sample\n", + " (rng2, rng) = random.split(rng)\n", + " eps = random.normal(rng2, jnp.shape(enc_mean))\n", + " enc = enc_mean + enc_sd * eps\n", + " image_dec = self.decode(enc)\n", + " out = {\n", + " 'enc_mean': enc_mean,\n", + " 'enc_sd': enc_sd,\n", + " 'image_dec': image_dec,\n", + " }\n", + " return out\n", + "\n", + "\n", + "def encoder(rng, image):\n", + " autoencoder = Autoencoder()\n", + " out = autoencoder.forward(rng, image)\n", + " return out" + ] + 
}, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ao_JIdjYMLP_" + }, + "outputs": [], + "source": [ + "# Autoencoder loss (for pretraining)\n", + "\n", + "\n", + "def kl_divergence(mean, sd):\n", + " kl = -0.5 * (1.0 + jnp.log(sd**2) - mean**2 - sd**2)\n", + " return kl\n", + "\n", + "\n", + "def autoencoder_losses(enc_params, rng, images):\n", + " encoder_net = hk.transform(encoder)\n", + " (rng2, rng) = random.split(rng)\n", + " autoencoder_out = encoder_net.apply(enc_params, rng, rng2, images)\n", + " enc_means = autoencoder_out['enc_mean']\n", + " enc_sds = autoencoder_out['enc_sd']\n", + " image_decs = autoencoder_out['image_dec']\n", + " # Decoder reconstruction loss\n", + " decoder_loss = jnp.mean((images-image_decs)**2)\n", + " # Decoder KL loss\n", + " kld = kl_divergence(enc_means, enc_sds + 1e-10) # add epsilon to avoid sd=0\n", + " kl_loss = jnp.mean(kld)\n", + " # Total loss\n", + " beta = 0.001 # weighting of KL term\n", + " tot_loss = decoder_loss + beta * kl_loss\n", + " losses = {\n", + " 'tot_loss': tot_loss,\n", + " 'decoder_loss': decoder_loss,\n", + " 'kl_loss': kl_loss,\n", + " }\n", + " return losses\n", + "\n", + "\n", + "def autoencoder_loss(enc_params, rng, images):\n", + " losses = autoencoder_losses(enc_params, rng, images)\n", + " return losses['tot_loss']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bawIEzugOiZG" + }, + "outputs": [], + "source": [ + "# Update autoencoder parameters\n", + "\n", + "@jit\n", + "def update_autoencoder(enc_params, rng, opt_state, images):\n", + " opt = make_encoder_optimiser()\n", + " grads = grad(autoencoder_loss)(enc_params, rng, images)\n", + " updates, opt_state = opt.update(grads, opt_state)\n", + " new_params = optax.apply_updates(enc_params, updates)\n", + " return new_params, opt_state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lKGZurl6fQBd" + }, + "outputs": [], + "source": [ + 
"# Ensemble memory (main model)\n", + "\n", + "\n", + "def activation(values):\n", + " \"\"\"Activation function for ensemble (scaled tanh).\"\"\"\n", + " out = jnp.tanh(values / config['vub']) * config['vub']\n", + " return out\n", + "\n", + "\n", + "def l2_normalize(x, axis=None, epsilon=1e-12):\n", + " \"\"\"l2 normalize a tensor on an axis with numerical stability.\"\"\"\n", + " square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)\n", + " x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))\n", + " return x * x_inv_norm\n", + "\n", + "\n", + "class Memory(hk.Module):\n", + " \"\"\"Memory module.\"\"\"\n", + "\n", + " def __init__(self, name=None):\n", + " super().__init__(name)\n", + " self.keys = hk.get_parameter('mem_keys', [config['mem_size'],\n", + " config['enc_size']],\n", + " init=hk.initializers.Constant(0))\n", + " self.weights = hk.get_parameter('mem_weights', [config['mem_size'],\n", + " config['enc_size'],\n", + " config['num_classes']],\n", + " init=hk.initializers.VarianceScaling())\n", + " self.biases = hk.get_parameter('mem_biases', [config['mem_size'], 1,\n", + " config['num_classes']],\n", + " init=hk.initializers.Constant(0))\n", + "\n", + "\n", + " def lookup(self, enc):\n", + " \"\"\"k-nearest neighbour lookup in ensemble memory.\"\"\"\n", + " enc = l2_normalize(enc, axis=1)\n", + " keys = l2_normalize(self.keys, axis=1)\n", + " sims = jnp.matmul(enc, jnp.transpose(keys)) # cosine similarities\n", + " (k_sims, idx) = jax.lax.top_k(sims, config['k']) # k nearest neighbours\n", + " # Keys\n", + " k_keys = jnp.take(self.keys, idx, axis=0)\n", + " mean_key = jnp.mean(k_keys, axis=1)\n", + " # Values\n", + " k_encs = jnp.expand_dims(enc, axis=(1, 2))\n", + " k_encs = jnp.tile(k_encs, (1, config['k'], 1, 1))\n", + " k_weights = jnp.take(self.weights, idx, axis=0)\n", + " k_biases = jnp.take(self.biases, idx, axis=0)\n", + " k_values = jnp.matmul(k_encs, k_weights) + k_biases\n", + " k_values = jnp.squeeze(k_values)\n", + " 
k_values = activation(k_values)\n", + " # Mean of values weighted by key similarity\n", + " k_sims2 = jax.lax.stop_gradient(jnp.expand_dims(k_sims, axis=2))\n", + " mean_value = jnp.sum(k_values * k_sims2, axis=1) / jnp.sum(k_sims2, axis=1)\n", + " return (mean_key, mean_value, k_sims, k_keys, k_values, idx)\n", + "\n", + "\n", + "class EnsembleModel(hk.Module):\n", + " \"\"\"Ensemble memory model.\"\"\"\n", + "\n", + "\n", + " def __init__(self, name=None):\n", + " super().__init__(name)\n", + " self.memory = Memory()\n", + "\n", + "\n", + " def enc_to_class_vanilla(self, enc_image):\n", + " mlp = hk.Sequential([\n", + " hk.Linear(config['num_classes'],\n", + " w_init=hk.initializers.VarianceScaling(\n", + " scale=config['init_scale']),\n", + " name='classifier1'),\n", + " jax.nn.log_softmax,\n", + " ])\n", + " pred = mlp(enc_image) # predicted class\n", + " return pred\n", + "\n", + "\n", + " def enc_to_class_tanh(self, enc_image):\n", + " mlp = hk.Sequential([\n", + " hk.Linear(config['num_classes'],\n", + " w_init=hk.initializers.VarianceScaling(\n", + " scale=config['init_scale']),\n", + " name='classifier2'),\n", + " ])\n", + " pred = activation(mlp(enc_image)) # predicted class\n", + " return pred\n", + "\n", + "\n", + " def forward(self, enc_image):\n", + " \"\"\"Memory lookup and classifier.\"\"\"\n", + " (mean_key, mean_value,\n", + " k_sims, k_keys, k_values, idx) = self.memory.lookup(enc_image)\n", + " classifier_out_vanilla = self.enc_to_class_vanilla(enc_image)\n", + " classifier_out_tanh = self.enc_to_class_tanh(enc_image)\n", + " out = {\n", + " 'classifier_out_vanilla': classifier_out_vanilla,\n", + " 'classifier_out_tanh': classifier_out_tanh,\n", + " 'k_sims': k_sims,\n", + " 'k_keys': k_keys,\n", + " 'k_values': k_values,\n", + " 'mean_key': mean_key,\n", + " 'mean_value': mean_value,\n", + " 'idx': idx,\n", + " }\n", + " return out\n", + "\n", + "\n", + "def model(enc_image):\n", + " ensemble = EnsembleModel()\n", + " out = 
ensemble.forward(enc_image)\n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sjFE4ml3gVwK" + }, + "outputs": [], + "source": [ + "# Loss and accuracy for ensemble memory (main model)\n", + "\n", + "\n", + "def model_loss(params, rng, enc_images, one_hots, labels):\n", + " model_net = hk.transform(model)\n", + " model_out = model_net.apply(params, rng, enc_images)\n", + " preds_vanilla = model_out['classifier_out_vanilla']\n", + " preds_tanh = model_out['classifier_out_tanh']\n", + " mean_values = model_out['mean_value']\n", + " # Classifier losses\n", + " classifier_loss_vanilla = -jnp.mean(jnp.sum(preds_vanilla * one_hots, axis=1))\n", + " classifier_loss_tanh = -jnp.mean(jnp.sum(preds_tanh * one_hots, axis=1))\n", + " # Memory loss\n", + " memory_loss = -jnp.mean(jnp.sum(mean_values * one_hots, axis=1))\n", + " # Total loss\n", + " loss = classifier_loss_vanilla + classifier_loss_tanh + memory_loss\n", + " return loss\n", + "\n", + "\n", + "@jit\n", + "def accuracy(params, rng, enc_images, labels):\n", + " \"\"\"Accuracies for each type of model.\"\"\"\n", + " model_net = hk.transform(model)\n", + " model_out = model_net.apply(params, rng, enc_images)\n", + " classifier_classes_van = jnp.argmax(model_out['classifier_out_vanilla'], axis=1)\n", + " classifier_acc_van = jnp.mean(classifier_classes_van == labels)\n", + " classifier_classes_tanh = jnp.argmax(model_out['classifier_out_tanh'], axis=1)\n", + " classifier_acc_tanh = jnp.mean(classifier_classes_tanh == labels)\n", + " ensemble_classes = jnp.argmax(model_out['mean_value'], axis=1)\n", + " ensemble_acc = jnp.mean(ensemble_classes == labels)\n", + " return (classifier_acc_van, classifier_acc_tanh, ensemble_acc)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7HPlrc-OsZHf" + }, + "outputs": [], + "source": [ + "# Update ensemble memory parameters\n", + "\n", + "@jit\n", + "def update_model(params, rng, 
opt_state,\n", + " enc_images, one_hots, labels):\n", + " opt = make_ensemble_optimiser()\n", + " grads = grad(model_loss)(params, rng, enc_images, one_hots, labels)\n", + " updates, opt_state = opt.update(grads, opt_state, params)\n", + " new_params = optax.apply_updates(params, updates)\n", + " return new_params, opt_state" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l6XB2jDcESyS" + }, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bK8cobd9ctcg" + }, + "outputs": [], + "source": [ + "def get_encoder():\n", + " encoder_net = hk.transform(encoder)\n", + " return encoder_net\n", + "\n", + "\n", + "def initialise_encoder(rng):\n", + " \"\"\"Initialise internal encoder for pre-training.\"\"\"\n", + " encoder_net = get_encoder()\n", + " # Get dummy batch\n", + " batch_size = 24\n", + " train_set = get_dataset(config['pretrain_dataset'], 'train', batch_size)\n", + " batch = get_batch(train_set, config['pretrain_dataset'])\n", + " (images, _, _) = batch\n", + " (rng2, rng) = random.split(rng)\n", + " (rng3, rng) = random.split(rng)\n", + " enc_params = encoder_net.init(rng2, rng3, images)\n", + " return enc_params\n", + "\n", + "\n", + "@jit\n", + "def apply_encoder(enc_params, rng, images):\n", + " encoder_net = get_encoder()\n", + " (rng2, rng) = random.split(rng)\n", + " encoder_out = encoder_net.apply(enc_params, rng, rng2, images)\n", + " enc = encoder_out['enc_mean']\n", + " return enc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NshNbe3gPA31" + }, + "outputs": [], + "source": [ + "# Autoencoder training\n", + "\n", + "\n", + "def pretrain_encoder(rng, n_batches, filter_labels=None):\n", + "\n", + " print('Encoder pre-training on {}'.format(config['pretrain_dataset']))\n", + " print()\n", + " # Get train and test data\n", + " train_batch_size = 48\n", + " test_batch_size = 256\n", + " train_set = 
get_dataset(config['pretrain_dataset'], 'train',\n", + " train_batch_size, filter_labels=filter_labels)\n", + " test_set = get_dataset(config['pretrain_dataset'], 'test',\n", + " test_batch_size, filter_labels=filter_labels)\n", + " test_batch = get_batch(test_set, config['pretrain_dataset'])\n", + " # Train encoders until a good enough one is found\n", + " success = False\n", + " loss_threshold = 0.025\n", + " while not success:\n", + " # Initialise parameters\n", + " (rng2, rng) = random.split(rng)\n", + " enc_params = initialise_encoder(rng2)\n", + " # Initialise optimiser\n", + " opt = make_encoder_optimiser()\n", + " opt_state = opt.init(enc_params)\n", + " # Training\n", + " for i in range(n_batches):\n", + " batch = get_batch(train_set, config['pretrain_dataset'])\n", + " (images, _, _) = batch\n", + " (rng2, rng) = random.split(rng)\n", + " (enc_params, opt_state) = update_autoencoder(enc_params, rng2,\n", + " opt_state, images)\n", + " (images, _, _) = test_batch\n", + " (rng2, rng) = random.split(rng)\n", + " losses = autoencoder_losses(enc_params, rng2, images)\n", + " print('Batch {}'.format(i+1))\n", + " print('Reconstruction loss {:.8f}'.format(losses['decoder_loss']))\n", + " print('KL loss {:.4f}'.format(losses['kl_loss']))\n", + " print()\n", + " success = losses['decoder_loss'] \u003c loss_threshold\n", + " if not success:\n", + " print('Reconstruction loss too high - retraining')\n", + " print()\n", + " (rng2, rng) = random.split(rng)\n", + "\n", + " return enc_params" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "axY5rDzycJZP" + }, + "outputs": [], + "source": [ + "# Model testing with recording of accuracies\n", + "\n", + "\n", + "def test_model(model_params, enc_params, test_labels, test_batches,\n", + " accuracies, batch_number):\n", + " rng = random.PRNGKey(42)\n", + " for i in range(len(test_labels)):\n", + " (images, _, labels) = test_batches[i]\n", + " # Encode images\n", + " (rng2, rng) = 
random.split(rng)\n",
+        "    enc_images = apply_encoder(enc_params, rng2, images)\n",
+        "    # Get accuracies\n",
+        "    (rng2, rng) = random.split(rng)\n",
+        "    (classifier_acc_vanilla,\n",
+        "     classifier_acc_tanh, ensemble_acc) = accuracy(model_params, rng2,\n",
+        "                                                   enc_images, labels)\n",
+        "    accuracies['accuracies_vanilla'][i].append(classifier_acc_vanilla)\n",
+        "    accuracies['accuracies_tanh'][i].append(classifier_acc_tanh)\n",
+        "    accuracies['accuracies_ensemble'][i].append(ensemble_acc)\n",
+        "  return accuracies"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "LespFNYzv50c"
+      },
+      "outputs": [],
+      "source": [
+        "# Model training with pre-trained encoder\n",
+        "\n",
+        "\n",
+        "def initialise_model(rng, enc_params):\n",
+        "  model_net = hk.transform(model)\n",
+        "  # Get dummy batch\n",
+        "  train_batch_size = 24\n",
+        "  train_set = get_dataset(config['main_dataset'], 'train', train_batch_size)\n",
+        "  batch = get_batch(train_set, config['main_dataset'])\n",
+        "  (images, _, _) = batch\n",
+        "  # Encode images\n",
+        "  (rng2, rng) = random.split(rng)\n",
+        "  enc_images = apply_encoder(enc_params, rng2, images)\n",
+        "  # Initialise the model\n",
+        "  (rng2, rng) = random.split(rng)\n",
+        "  model_params = model_net.init(rng2, enc_images)\n",
+        "  # Initialise memory keys according to encoding stats\n",
+        "  (rng2, rng) = random.split(rng)\n",
+        "  keys = jax.random.normal(rng2, [config['mem_size'], config['enc_size']])\n",
+        "  new_key_params = {'ensemble_model/~/memory': {'mem_keys': keys}}\n",
+        "  (old_key_params, rest_params) = hk.data_structures.partition(\n",
+        "      lambda m, n, p: (m == 'ensemble_model/~/memory' and\n",
+        "                       n == 'mem_keys'), model_params)\n",
+        "  model_params = hk.data_structures.merge(new_key_params, rest_params)\n",
+        "  return model_params\n",
+        "\n",
+        "\n",
+        "def train_model(model_params, enc_params, rng,\n",
+        "                label_set, test_labels, test_batch, accuracies, run,\n",
+        "                n_batches=0, tot_batches=0):\n",
+        "\n",
+        "  # Get train and test data\n",
+        "  train_batch_size = config['batch_size']\n",
+        "  train_set = get_dataset(config['main_dataset'], 'train', train_batch_size,\n",
+        "                          filter_labels=label_set)\n",
+        "  # Training\n",
+        "  for i in range(n_batches):\n",
+        "    # Re-initialise optimiser\n",
+        "    opt = make_ensemble_optimiser()\n",
+        "    opt_state = opt.init(model_params)\n",
+        "    tot_batches += 1\n",
+        "    batch = get_batch(train_set, config['main_dataset'])\n",
+        "    (images, one_hots, labels) = batch\n",
+        "    # Encode images\n",
+        "    (rng2, rng) = random.split(rng)\n",
+        "    enc_images = apply_encoder(enc_params, rng2, images)\n",
+        "    (rng2, rng) = random.split(rng)\n",
+        "    (model_params, opt_state) = update_model(model_params, rng2,\n",
+        "                                             opt_state, enc_images,\n",
+        "                                             one_hots, labels)\n",
+        "    # Log accuracies\n",
+        "    if tot_batches % config['log_every'] == 0:\n",
+        "      accuracies = test_model(model_params, enc_params, test_labels, test_batch,\n",
+        "                              accuracies, tot_batches)\n",
+        "  return (model_params, accuracies, tot_batches)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "byXZrXQ0ctwK"
+      },
+      "outputs": [],
+      "source": [
+        "def train_with_schedule(model_params, enc_params, rng,\n",
+        "                        schedule, test_labels, test_batch, run):\n",
+        "  \"\"\"Train the model with a given schedule of label sets (tasks).\"\"\"\n",
+        "\n",
+        "  # Initial accuracy\n",
+        "  tot_batches = 0\n",
+        "  accuracies = {\n",
+        "      'accuracies_vanilla': [[] for _ in range(len(test_labels))],\n",
+        "      'accuracies_tanh': [[] for _ in range(len(test_labels))],\n",
+        "      'accuracies_ensemble': [[] for _ in range(len(test_labels))]}\n",
+        "  accuracies = test_model(model_params, enc_params, test_labels, test_batch,\n",
+        "                          accuracies, tot_batches)\n",
+        "  # Go through the schedule\n",
+        "  for ep_no in range(len(schedule)):\n",
+        "    episode = schedule[ep_no]\n",
+        "    label_set = episode['label_set']\n",
+        "    n_batches = episode['n_batches']\n",
+        "    (rng2, rng) = random.split(rng)\n",
+        "    (model_params, accuracies, tot_batches) = train_model(\n",
+        "        model_params, enc_params, rng2,\n",
+        "        label_set, test_labels, test_batch, accuracies, run,\n",
+        "        n_batches=n_batches, tot_batches=tot_batches)\n",
+        "\n",
+        "  return (model_params, accuracies)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "d7RFUpZjE3ua"
+      },
+      "source": [
+        "# Schedules"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "t8gecZXnvRuD"
+      },
+      "outputs": [],
+      "source": [
+        "def gaussian(peak, width, x):\n",
+        "  out = jnp.exp(- ((x - peak)**2 / (2 * width**2)))\n",
+        "  return out\n",
+        "\n",
+        "\n",
+        "def gaussian_schedule(rng):\n",
+        "  \"\"\"Returns a schedule where one task blends smoothly into the next.\"\"\"\n",
+        "\n",
+        "  schedule_length = 1000  # schedule length in batches\n",
+        "  episode_length = 5  # episode length in batches\n",
+        "\n",
+        "  # Each class label appears according to a Gaussian probability distribution\n",
+        "  # with peaks spread evenly over the schedule\n",
+        "  peak_every = schedule_length // config['num_classes']\n",
+        "  width = 50  # width of Gaussian\n",
+        "  peaks = range(peak_every // 2, schedule_length, peak_every)\n",
+        "\n",
+        "  schedule = []\n",
+        "  labels = jnp.array(list(range(config['num_classes'])))\n",
+        "  labels = random.permutation(rng, labels)  # labels in random order\n",
+        "\n",
+        "  for ep_no in range(0, schedule_length // episode_length):\n",
+        "\n",
+        "    lbls = []\n",
+        "    while lbls == []:  # make sure lbls isn't empty\n",
+        "      for j in range(len(peaks)):\n",
+        "        peak = peaks[j]\n",
+        "        # Sample from a Gaussian with peak in the right place\n",
+        "        p = gaussian(peak, width, ep_no * episode_length)\n",
+        "        (rng2, rng) = jax.random.split(rng)\n",
+        "        add = jax.random.bernoulli(rng2, p=p)\n",
+        "        if add:\n",
+        "          lbls.append(int(labels[j]))\n",
+        "\n",
+        "    episode = {'label_set': lbls, 'n_batches': episode_length}\n",
+        "    schedule.append(episode)\n",
+        "\n",
+        "  return schedule"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "B2hXtb2YhxUZ"
+      },
+      "outputs": [],
+      "source": [
+        "def get_schedule(name, rng=None):\n",
+        "\n",
+        "  # Full set of labels\n",
+        "  test_labels = [[x] for x in range(10)]\n",
+        "\n",
+        "  if name == '1way_split':\n",
+        "    # 1-way split schedule (multi-task setting) (MNIST or CIFAR-10)\n",
+        "    schedule = [{'label_set': list(range(10)),\n",
+        "                 'n_batches': 1000}]\n",
+        "\n",
+        "  elif name == '2way_split':\n",
+        "    # Random 2-way split schedule (MNIST or CIFAR-10)\n",
+        "    lbls = jnp.array(list(range(10)))\n",
+        "    lbls = random.permutation(rng, lbls)\n",
+        "    lbls = jnp.reshape(lbls, (2, 5))\n",
+        "    schedule = [{'label_set': lbl.tolist(), 'n_batches': 500} for lbl in lbls]\n",
+        "\n",
+        "  elif name == '5way_split':\n",
+        "    # Random 5-way split schedule (MNIST or CIFAR-10)\n",
+        "    lbls = jnp.array(list(range(10)))\n",
+        "    lbls = random.permutation(rng, lbls)\n",
+        "    lbls = jnp.reshape(lbls, (5, 2))\n",
+        "    schedule = [{'label_set': lbl.tolist(), 'n_batches': 200} for lbl in lbls]\n",
+        "\n",
+        "  elif name == '10way_split':\n",
+        "    # Random 10-way split schedule (MNIST or CIFAR-10)\n",
+        "    lbls = jnp.array(list(range(10)))\n",
+        "    lbls = random.permutation(rng, lbls)\n",
+        "    lbls = jnp.reshape(lbls, (10, 1))\n",
+        "    schedule = [{'label_set': lbl.tolist(), 'n_batches': 100} for lbl in lbls]\n",
+        "\n",
+        "  elif name == 'gaussian_schedule':\n",
+        "    # Gaussian schedule\n",
+        "    schedule = gaussian_schedule(rng)\n",
+        "\n",
+        "  else:\n",
+        "    # Fail fast: 'schedule' would otherwise be unbound at the return below\n",
+        "    raise ValueError('No such schedule: {}'.format(name))\n",
+        "\n",
+        "  return (schedule, test_labels)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "PVembGzr_-Tc"
+      },
+      "source": [
+        "# Scripts"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "Iocm19OgdiDJ"
+      },
+      "outputs": [],
+      "source": [
+        "# Training scripts\n",
+        "\n",
+        "\n",
+        "def get_test_batches(test_labels):\n",
+        "  \"\"\"Split test batch up into label-wise batches according to 'test_labels'.\"\"\"\n",
+        "  batch_size = 10000 // len(test_labels)\n",
+        "  test_batches = []\n",
+        "  for labels in test_labels:\n",
+        "    dataset = get_dataset(config['main_dataset'], 'test', batch_size,\n",
+        "                          filter_labels=labels)\n",
+        "    test_batch = get_batch(dataset, config['main_dataset'])\n",
+        "    test_batches.append(test_batch)\n",
+        "  return test_batches\n",
+        "\n",
+        "\n",
+        "def main():\n",
+        "  \"\"\"Main script - multiple runs.\"\"\"\n",
+        "\n",
+        "  print('STARTING EXPERIMENT')\n",
+        "  print()\n",
+        "  schedule_name = config['schedule_type']\n",
+        "  rng = random.PRNGKey(78)\n",
+        "  n_runs = config['n_runs']\n",
+        "  # Ensemble training\n",
+        "  x_accuracies1 = []  # vanilla classifier accuracies\n",
+        "  x_accuracies2 = []  # tanh classifier accuracies\n",
+        "  x_accuracies3 = []  # ensemble accuracies\n",
+        "  for run in range(n_runs):\n",
+        "    print('RUN {} of {}'.format(run+1, n_runs))\n",
+        "    print()\n",
+        "    # Encoder pretraining - new encoder every run\n",
+        "    (rng2, rng) = random.split(rng)\n",
+        "    enc_params = pretrain_encoder(rng2, n_batches=config['pretrain_n_batches'])\n",
+        "    # Get a schedule\n",
+        "    (rng2, rng) = random.split(rng)\n",
+        "    (schedule, test_labels) = get_schedule(schedule_name, rng2)\n",
+        "    # Get batches for testing\n",
+        "    test_batches = get_test_batches(test_labels)\n",
+        "    # Train the model\n",
+        "    (rng2, rng) = random.split(rng)\n",
+        "    model_params = initialise_model(rng2, enc_params)\n",
+        "    # Carry out schedule\n",
+        "    print('Ensemble training on {}'.format(config['main_dataset']))\n",
+        "    print()\n",
+        "    (rng2, rng) = random.split(rng)\n",
+        "    (model_params, accuracies) = train_with_schedule(model_params, enc_params,\n",
+        "                                                     rng2, schedule,\n",
+        "                                                     test_labels, test_batches,\n",
+        "                                                     run)\n",
+        "    # Record the results\n",
+        "    x_accuracies1.append(accuracies['accuracies_vanilla'])\n",
+        "    x_accuracies2.append(accuracies['accuracies_tanh'])\n",
+        "    x_accuracies3.append(accuracies['accuracies_ensemble'])\n",
+        "    plot_x_accuracies(x_accuracies1, x_accuracies2, x_accuracies3)\n",
+        "  print('FINAL PLOT')\n",
+        "  print()\n",
+        "  plot_x_accuracies(x_accuracies1, x_accuracies2, x_accuracies3, final=True)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "arDSjo-nFFUO"
+      },
+      "outputs": [],
+      "source": [
+        "# Main script\n",
+        "\n",
+        "main()"
+      ]
+    }
+  ],
+  "metadata": {
+    "colab": {
+      "collapsed_sections": [],
+      "name": "Encoders and ensembles open source.ipynb",
+      "private_outputs": true,
+      "provenance": [],
+      "toc_visible": true
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}