mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-02 06:06:39 +08:00
Internal change
PiperOrigin-RevId: 417034217
This commit is contained in:
committed by
Diego de Las Casas
parent
fba48d1e44
commit
3f0d5ed1a0
@@ -0,0 +1,59 @@
|
||||
# The Autoencoding Variational Autoencoder
|
||||
|
||||
This is the code for the models in the NeurIPS 2020 paper [The Autoencoding Variational Autoencoder (AVAE)](https://papers.nips.cc/paper/2020/file/ac10ff1941c540cd87c107330996f4f6-Paper.pdf).
|
||||
|
||||
This folder contains the code to train the AVAE model in JAX; we will be uploading the
|
||||
evaluation setup soon.
|
||||
|
||||
Code files in the folder
|
||||
- checkpointer.py: Checkpointing abstraction
|
||||
- data_iterators.py: Datasets to be used
|
||||
- decoders.py: VAE decoder network architectures
|
||||
- encoders.py: VAE encoder network architectures
|
||||
- kl.py: KL divergence between a Gaussian and the standard normal prior
|
||||
- train.py: Function to train given ELBO, network and data
|
||||
- train_main.py: Main file to train AVAE
|
||||
- vae.py: VAE model defining various ELBOs
|
||||
|
||||
## Setup
|
||||
|
||||
To set up a Python3 virtual environment with the required dependencies, run:
|
||||
|
||||
```shell
|
||||
python -m venv avae_env
|
||||
source avae_env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install -r avae/requirements.txt
|
||||
```
|
||||
|
||||
## Running AVAE training
|
||||
|
||||
The following command runs AVAE training on the ColorMnist dataset using MLP
|
||||
network architectures.
|
||||
|
||||
```shell
|
||||
python -m avae.train_main \
|
||||
--dataset='color_mnist' \
|
||||
--latent_dim=64 \
|
||||
--checkpoint_dir='/tmp/avae_checkpoints' \
|
||||
--checkpoint_filename='color_mnist_mlp_avae' \
|
||||
--rho=0.975 \
|
||||
--encoder='color_mnist_mlp_encoder' \
|
||||
--decoder='color_mnist_mlp_decoder'
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
### Citing our work
|
||||
|
||||
If you use this code for your research, please consider citing our paper:
|
||||
|
||||
```bibtex
|
||||
@article{cemgil2020autoencoding,
|
||||
title={The Autoencoding Variational Autoencoder},
|
||||
author={Cemgil, Taylan and Ghaisas, Sumedh and Dvijotham, Krishnamurthy and Gowal, Sven and Kohli, Pushmeet},
|
||||
journal={Advances in Neural Information Processing Systems},
|
||||
volume={33},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 164 KiB |
@@ -0,0 +1,88 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Checkpointing functionality."""
|
||||
|
||||
import os
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from absl import logging
|
||||
import dill
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Checkpointer:
  """Saves and restores training state as a single dill-serialized file.

  The checkpoint is a dict with the keys `experiment_state`, `opt_state` and
  `step`, plus any extra entries supplied by the caller.
  """

  def __init__(self, checkpoint_dir: str, filename: str):
    """Class initializer.

    Args:
      checkpoint_dir: Checkpoint directory. It is created, together with any
        missing parent directories, if it does not exist yet.
      filename: Filename of checkpoint in checkpoint directory.
    """
    self._checkpoint_dir = checkpoint_dir
    # `makedirs(exist_ok=True)` supports nested paths and avoids the
    # check-then-create race of `isdir()` followed by `mkdir()`.
    os.makedirs(self._checkpoint_dir, exist_ok=True)

    self._checkpoint_path = os.path.join(self._checkpoint_dir, filename)

  def save_checkpoint(
      self,
      experiment_state: Mapping[str, jnp.ndarray],
      opt_state: Mapping[str, jnp.ndarray],
      step: int,
      extra_checkpoint_info: Optional[Mapping[str, Any]] = None) -> None:
    """Save checkpoint with experiment state and step information.

    Only the process with index 0 writes; all other processes return
    immediately.

    Args:
      experiment_state: Experiment params to be stored.
      opt_state: Optimizer state to be stored.
      step: Training iteration step.
      extra_checkpoint_info: Extra information to be stored. Its entries are
        merged into the top level of the checkpoint dict, so a key named
        `experiment_state`, `opt_state` or `step` overrides the built-in one.
    """
    # `jax.process_index` replaces the deprecated `jax.host_id`.
    if jax.process_index() != 0:
      return

    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    checkpoint_data = dict(
        experiment_state=jax.tree_util.tree_map(
            jax.device_get, experiment_state),
        opt_state=jax.tree_util.tree_map(jax.device_get, opt_state),
        step=step)
    if extra_checkpoint_info is not None:
      checkpoint_data.update(extra_checkpoint_info)

    # Write to a temporary file and rename it into place, so an interrupted
    # write never leaves a truncated checkpoint at the final path.
    tmp_path = self._checkpoint_path + '.tmp'
    with open(tmp_path, 'wb') as checkpoint_file:
      dill.dump(checkpoint_data, checkpoint_file, protocol=2)
    os.replace(tmp_path, self._checkpoint_path)

  def load_checkpoint(
      self) -> Optional[Mapping[str, Mapping[str, jnp.ndarray]]]:
    """Load and return checkpoint data.

    Returns:
      Loaded checkpoint if it exists else returns None.
    """
    if not os.path.exists(self._checkpoint_path):
      logging.warning('No pre-saved checkpoint found at %s',
                      self._checkpoint_path)
      return None

    # SECURITY NOTE: dill (like pickle) can execute arbitrary code while
    # deserializing. Only load checkpoint files from trusted sources.
    with open(self._checkpoint_path, 'rb') as checkpoint_file:
      checkpoint_data = dill.load(checkpoint_file)
    logging.info('Loading checkpoint from %s, saved at step %d',
                 self._checkpoint_path, checkpoint_data['step'])
    return checkpoint_data
|
||||
@@ -0,0 +1,97 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataset iterators Mnist, ColorMnist."""
|
||||
|
||||
import enum
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from avae import types
|
||||
|
||||
|
||||
class Dataset(enum.Enum):
  """Names of the datasets for which a data iterator is available."""

  color_mnist = enum.auto()
|
||||
|
||||
|
||||
class MnistDataIterator(object):
  """Endless iterator over fixed-size batches of MNIST.

  Every batch is returned as a `types.LabelledData` dataclass.
  """

  def __init__(self, subset: str, batch_size: int):
    """Class initializer.

    Args:
      subset: Key selecting one split from the dict returned by
        `tfds.load('mnist')`.
      batch_size: Batch size of the returned dataset iterator.
    """

    def _to_example(record):
      # Cast the image to float32 and scale it by 1/255; keep the label as is.
      return {'data': tf.cast(record['image'], tf.float32) / 255.0,
              'label': record['label']}

    dataset = tfds.load('mnist')[subset]
    dataset = dataset.map(_to_example)
    # Incomplete final batches are dropped, so every batch has exactly
    # `batch_size` examples; `repeat()` makes the stream endless.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat()
    self._iterator = iter(tfds.as_numpy(dataset))
    self._batch_size = batch_size

  def __iter__(self):
    return self

  def __next__(self) -> types.LabelledData:
    batch = next(self._iterator)
    return types.LabelledData(**batch)
|
||||
|
||||
|
||||
class ColorMnistDataIterator(MnistDataIterator):
  """Color Mnist data iterator.

  Each ColorMnist image has shape (28, 28, 3). A digit is colored by switching
  each of the three channels on or off; the all-off combination is excluded,
  which leaves 7 possible colors.

  Data is obtained as dataclass, types.LabelledData.
  """

  def __next__(self) -> types.LabelledData:
    batch = next(self._iterator)
    grey_image = batch['data']

    # One color id per example, drawn uniformly from [1, 8). The three bits of
    # the id act as on/off switches:
    #   bit 0 -> "red", bit 1 -> "green", bit 2 -> "blue".
    color_id = np.random.randint(7, size=self._batch_size) + 1
    red_on = np.mod(color_id, 2)
    green_on = np.mod(np.floor_divide(color_id, 2), 2)
    blue_on = np.floor_divide(color_id, 4)

    def _as_mask(bits):
      # Shape (batch, 1, 1, 1) so the mask broadcasts over the whole image.
      return jnp.reshape(bits, [-1, 1, 1, 1])

    # NOTE(review): the masks are stacked in the order (red, blue, green), so
    # output channel 1 is gated by the "blue" bit and channel 2 by the "green"
    # bit. Confirm whether RGB order was intended.
    channels = [
        jnp.multiply(_as_mask(red_on), grey_image),
        jnp.multiply(_as_mask(blue_on), grey_image),
        jnp.multiply(_as_mask(green_on), grey_image),
    ]
    color_image = jnp.reshape(jnp.stack(channels, axis=3), [-1, 28, 28, 3])

    # NOTE(review): the sampled color id is not returned; the label attached
    # to the batch is the label that came with the underlying MNIST batch.
    return types.LabelledData(data=color_image, label=batch['label'])
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Decoder architectures to be used with VAE."""
|
||||
|
||||
import abc
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DecoderBase(hk.Module):
  """Common base for decoder networks; provides the reconstruction loss."""

  def __init__(self, obs_var: float):
    """Class initializer.

    Args:
      obs_var: Observation variance of the dataset.
    """
    super().__init__()
    self._obs_var = obs_var

  @abc.abstractmethod
  def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    """Reconstruct from a given latent sample.

    Args:
      z: latent samples of shape (batch_size, latent_dim)

    Returns:
      Reconstruction with shape (batch_size, ...).
    """

  def data_fidelity(
      self,
      input_data: jnp.ndarray,
      recons: jnp.ndarray,
  ) -> jnp.ndarray:
    """Compute the data fidelity (reconstruction) term for a batch.

    The value is `-0.5 * ||input - recons||^2 / obs_var`, computed separately
    for every example.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      recons: Reconstruction of the input data. An array with the same shape
        as `input_data`.

    Returns:
      Computed data fidelity term across batch of data. An array of shape
      `(batch_size,)`.
    """
    batch_size = input_data.shape[0]
    # Flatten every example so the squared error sums over all of its entries.
    residual = jnp.reshape(input_data - recons, (batch_size, -1))
    squared_error = jnp.sum(jnp.square(residual), axis=1)
    return -0.5 * squared_error / self._obs_var
|
||||
|
||||
|
||||
class ColorMnistMLPDecoder(DecoderBase):
  """MLP decoder for Color Mnist."""

  _hidden_units = (200, 200, 200, 200)
  _image_dims = (28, 28, 3)  # Dimensions of a single ColorMnist image.

  def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    """Reconstruct with given latent sample.

    Args:
      z: latent samples of shape (batch_size, latent_dim)

    Returns:
      Reconstructed data of shape (batch_size, 28, 28, 3). Values lie in
      (0, 1) because the output layer is followed by a sigmoid.
    """
    out = z
    for units in self._hidden_units:
      out = hk.Linear(units)(out)
      out = jax.nn.relu(out)
    # `np.product` was removed in NumPy 2.0; `np.prod` is the supported name.
    num_outputs = int(np.prod(self._image_dims))
    out = hk.Linear(num_outputs)(out)
    out = jax.nn.sigmoid(out)
    return jnp.reshape(out, (-1,) + self._image_dims)
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Encoder architectures to be used with VAE."""
|
||||
|
||||
import abc
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from avae import types
|
||||
|
||||
_Params = TypeVar('_Params')
|
||||
|
||||
|
||||
class EncoderBase(hk.Module, Generic[_Params]):
  """Interface shared by all encoder architectures.

  `_Params` is the type used to describe the posterior over the latents.
  """

  def __init__(self, latent_dim: int):
    """Class initializer.

    Args:
      latent_dim: Latent dimensions of the model.
    """
    super().__init__()
    self._latent_dim = latent_dim

  @abc.abstractmethod
  def __call__(self, input_data: jnp.ndarray) -> _Params:
    """Return posterior distribution over latents.

    Args:
      input_data: Input batch of shape (batch_size, ...).

    Returns:
      Parameters of the posterior distribution over the latents.
    """

  @abc.abstractmethod
  def sample(self, posterior: _Params, key: jnp.ndarray) -> jnp.ndarray:
    """Sample from the given posterior distribution.

    Args:
      posterior: Parameters of posterior distribution over the latents.
      key: Random number generator key.

    Returns:
      Sample from the posterior distribution over latents,
      shape[batch_size, latent_dim]
    """
|
||||
|
||||
|
||||
class ColorMnistMLPEncoder(EncoderBase[types.NormalParams]):
  """MLP encoder for ColorMnist."""

  _hidden_units = (200, 200, 200, 200)

  def __call__(
      self, input_data: jnp.ndarray) -> types.NormalParams:
    """Return posterior distribution over latents.

    Args:
      input_data: Input batch of shape (batch_size, ...).

    Returns:
      Posterior distribution over the latents.
    """
    out = hk.Flatten()(input_data)
    for units in self._hidden_units:
      out = hk.Linear(units)(out)
      out = jax.nn.relu(out)
    # The final layer emits 2 * latent_dim values per example; they are split
    # into mean and log-variance by `_normal_params_from_logits`.
    out = hk.Linear(2 * self._latent_dim)(out)
    return _normal_params_from_logits(out)

  def sample(
      self,
      posterior: types.NormalParams,
      key: jnp.ndarray,
  ) -> jnp.ndarray:
    """Sample from the given normal posterior (mean, var) distribution.

    Uses the reparameterization `mean + eps * sqrt(variance)` with
    `eps ~ N(0, I)`.

    Args:
      posterior: Posterior over the latents.
      key: Random number generator key.

    Returns:
      Sample from the posterior distribution over latents,
      shape[batch_size, latent_dim]
    """
    eps = jax.random.normal(
        key, shape=(posterior.mean.shape[0], self._latent_dim))
    # Scale the noise by the standard deviation, not the variance, so the
    # sample is distributed as N(mean, variance).
    return posterior.mean + eps * jnp.sqrt(posterior.variance)
|
||||
|
||||
|
||||
def _normal_params_from_logits(
    logits: jnp.ndarray) -> types.NormalParams:
  """Build normal-distribution parameters from raw network outputs.

  `logits` is split into two equal halves along axis 1. The first half is used
  as the mean; the second half is treated as a log-variance and exponentiated.
  """
  halves = jnp.split(logits, 2, axis=1)
  return types.NormalParams(mean=halves[0], variance=jnp.exp(halves[1]))
|
||||
+43
@@ -0,0 +1,43 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Various KL implementations in JAX."""
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def kl_p_with_uniform_normal(mean: jnp.ndarray,
                             variance: jnp.ndarray) -> jnp.ndarray:
  r"""KL divergence from a Gaussian P to the standard normal N(0, I).

  Computes 0.5 * (tr(S) + |m|^2 - k - log det S), where m is `mean`, S is the
  covariance described by `variance` and k is `mean.shape[0]`.

  Args:
    mean: Mean of the gaussian distribution, shape (latent_dims,)
    variance: Variance of the gaussian distribution. A 1-D array of shape
      (latent_dims,) is read as the diagonal of the covariance; a 2-D array is
      read as the full covariance matrix.

  Returns:
    KL divergence KL(P||N(0, 1)) shape ()
  """
  if variance.ndim == 2:
    # Full covariance matrix.
    trace_term = jnp.trace(variance)
    log_det = jnp.linalg.slogdet(variance)[1]
  else:
    # Diagonal covariance given as a vector of variances.
    trace_term = jnp.sum(variance)
    log_det = jnp.sum(jnp.log(variance))

  mean_sq_norm = jnp.sum(jnp.square(mean))
  num_dims = mean.shape[0]
  return (-log_det + (trace_term + mean_sq_norm - num_dims)) * 0.5
|
||||
@@ -0,0 +1,8 @@
|
||||
absl-py>0.9
|
||||
dill==0.3.2
|
||||
dm-haiku
|
||||
jax
|
||||
optax
|
||||
tensorflow
|
||||
tensorflow_datasets
|
||||
|
||||
Executable
+31
@@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Abort on the first failing command so that training is never started in a
# partially provisioned environment.
set -e

# Create and activate an isolated environment, then install dependencies.
python -m venv avae_env
source avae_env/bin/activate
pip install --upgrade pip
pip install -r avae/requirements.txt

# Start training AVAE model.
python -m avae.train_main \
  --dataset='color_mnist' \
  --latent_dim=64 \
  --checkpoint_dir='/tmp/avae_checkpoints' \
  --checkpoint_filename='color_mnist_mlp_avae' \
  --rho=0.975 \
  --encoder=color_mnist_mlp_encoder \
  --decoder=color_mnist_mlp_decoder \
  --iterations=1
|
||||
+117
@@ -0,0 +1,117 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""VAE style training."""
|
||||
|
||||
from typing import Any, Callable, Iterator, Sequence, Mapping, Tuple, Optional
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
|
||||
from avae import checkpointer
|
||||
from avae import types
|
||||
|
||||
|
||||
def train(
    train_data_iterator: Iterator[types.LabelledData],
    test_data_iterator: Iterator[types.LabelledData],
    elbo_fun: hk.Transformed,
    learning_rate: float,
    checkpoint_dir: str,
    checkpoint_filename: str,
    checkpoint_every: int,
    test_every: int,
    iterations: int,
    rng_seed: int,
    test_functions: Optional[Sequence[Callable[[Mapping[str, jnp.ndarray]],
                                               Tuple[str, float]]]] = None,
    extra_checkpoint_info: Optional[Mapping[str, Any]] = None):
  """Train VAE with given data iterator and elbo definition.

  If a checkpoint already exists at the given location, training resumes from
  it; otherwise parameters and optimizer state are freshly initialized.

  Args:
    train_data_iterator: Iterator of batched training data.
    test_data_iterator: Iterator of batched testing data.
    elbo_fun: Haiku transformed function returning elbo.
    learning_rate: Learning rate to be used with optimizer.
    checkpoint_dir: Path of the checkpoint directory.
    checkpoint_filename: Filename of the checkpoint.
    checkpoint_every: Checkpoint every N iterations.
    test_every: Test and log results every N iterations.
    iterations: Number of training iterations to perform.
    rng_seed: Seed for random number generator.
    test_functions: Test function iterable, each function takes the current
      parameters and returns a `(name, value)` pair that is appended to the
      message printed at test time.
    extra_checkpoint_info: Extra info to put inside saved checkpoint.
  """
  rng_seq = hk.PRNGSequence(jax.random.PRNGKey(rng_seed))

  opt_init, opt_update = optax.chain(
      # Set the parameters of Adam. Note the learning_rate is not here.
      optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
      # Put a minus sign to *minimise* the loss.
      optax.scale(-learning_rate))

  @jax.jit
  def loss(params, key, data):
    # Training objective: negative batch-mean ELBO.
    elbo_outputs = elbo_fun.apply(params, key, data)
    return -jnp.mean(elbo_outputs.elbo)

  @jax.jit
  def loss_test(params, key, data):
    # Same objective, plus the batch means of its two components for logging.
    elbo_output = elbo_fun.apply(params, key, data)
    return (-jnp.mean(elbo_output.elbo), jnp.mean(elbo_output.data_fidelity),
            jnp.mean(elbo_output.kl))

  @jax.jit
  def update_step(params, key, data, opt_state):
    grads = jax.grad(loss, has_aux=False)(params, key, data)
    updates, opt_state = opt_update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

  exp_checkpointer = checkpointer.Checkpointer(
      checkpoint_dir, checkpoint_filename)
  experiment_data = exp_checkpointer.load_checkpoint()

  if experiment_data is not None:
    # A checkpoint is written *after* the update of its step, so the stored
    # parameters already include that update. Resume with the next step
    # instead of repeating it.
    start = experiment_data['step'] + 1
    params = experiment_data['experiment_state']
    opt_state = experiment_data['opt_state']
  else:
    start = 0
    params = elbo_fun.init(
        next(rng_seq), next(train_data_iterator).data)
    opt_state = opt_init(params)

  last_checkpointed_step = None
  for step in range(start, iterations):
    if step % test_every == 0:
      test_loss, ll, kl = loss_test(params, next(rng_seq),
                                    next(test_data_iterator).data)
      output_message = (f'Step {step} elbo {-test_loss:0.2f} LL {ll:0.2f} '
                        f'KL {kl:0.2f}')
      if test_functions:
        for test_function in test_functions:
          name, test_output = test_function(params)
          output_message += f' {name}: {test_output:0.2f}'
      print(output_message)

    params, opt_state = update_step(params, next(rng_seq),
                                    next(train_data_iterator).data, opt_state)

    if step % checkpoint_every == 0:
      exp_checkpointer.save_checkpoint(
          params, opt_state, step, extra_checkpoint_info)
      last_checkpointed_step = step

  # Persist the final state as well; otherwise every update made after the
  # last periodic checkpoint would be lost when the function returns.
  final_step = iterations - 1
  if start < iterations and last_checkpointed_step != final_step:
    exp_checkpointer.save_checkpoint(
        params, opt_state, final_step, extra_checkpoint_info)
|
||||
@@ -0,0 +1,127 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Main file for training VAE/AVAE."""
|
||||
|
||||
import enum
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import haiku as hk
|
||||
|
||||
from avae import data_iterators
|
||||
from avae import decoders
|
||||
from avae import encoders
|
||||
from avae import train
|
||||
from avae import vae
|
||||
|
||||
|
||||
class Model(enum.Enum):
  """Training objective selectable through the `model` flag."""

  vae = enum.auto()
  avae = enum.auto()
|
||||
|
||||
|
||||
class EncoderArch(enum.Enum):
  """Encoder architectures selectable through the `encoder` flag."""

  color_mnist_mlp_encoder = 'ColorMnistMLPEncoder'
|
||||
|
||||
|
||||
class DecoderArch(enum.Enum):
  """Decoder architectures selectable through the `decoder` flag."""

  color_mnist_mlp_decoder = 'ColorMnistMLPDecoder'
|
||||
|
||||
|
||||
# Command-line flags. Each `flags.DEFINE_*` call returns a holder object whose
# `.value` attribute is read inside `main`.

# Data, model size and training schedule.
_DATASET = flags.DEFINE_enum_class(
    'dataset', data_iterators.Dataset.color_mnist, data_iterators.Dataset,
    'Dataset to train on')
_LATENT_DIM = flags.DEFINE_integer('latent_dim', 32,
                                   'Number of latent dimensions.')
_TRAIN_BATCH_SIZE = flags.DEFINE_integer('train_batch_size', 64,
                                         'Train batch size.')
_TEST_BATCH_SIZE = flags.DEFINE_integer('test_batch_size', 64,
                                        'Testing batch size.')
_TEST_EVERY = flags.DEFINE_integer('test_every', 1000,
                                   'Test every N iterations.')
_ITERATIONS = flags.DEFINE_integer('iterations', 102000,
                                   'Number of training iterations.')
_OBS_VAR = flags.DEFINE_float('obs_var', 0.5,
                              'Observation variance of the data. (Default 0.5)')


# Objective, optimizer, checkpointing and architecture selection.
_MODEL = flags.DEFINE_enum_class('model', Model.avae, Model,
                                 'Model used for training.')
_RHO = flags.DEFINE_float('rho', 0.8, 'Rho parameter used with AVAE or SE.')
_LEARNING_RATE = flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.')
_RNG_SEED = flags.DEFINE_integer('rng_seed', 0,
                                 'Seed for random number generator.')
_CHECKPOINT_DIR = flags.DEFINE_string('checkpoint_dir', '/tmp/',
                                      'Directory for checkpointing.')
_CHECKPOINT_FILENAME = flags.DEFINE_string(
    'checkpoint_filename', 'color_mnist_avae_mlp', 'Checkpoint filename.')
_CHECKPOINT_EVERY = flags.DEFINE_integer(
    'checkpoint_every', 1000, 'Checkpoint every N steps.')
_ENCODER = flags.DEFINE_enum_class(
    'encoder', EncoderArch.color_mnist_mlp_encoder, EncoderArch,
    'Encoder class name.')
_DECODER = flags.DEFINE_enum_class(
    'decoder', DecoderArch.color_mnist_mlp_decoder, DecoderArch,
    'Decoder class name.')
|
||||
|
||||
|
||||
def main(_):
  """Build data iterators and the ELBO function from flags, then train.

  Raises:
    ValueError: If the selected dataset, encoder or decoder has no
      implementation wired up here.
  """
  if _DATASET.value is data_iterators.Dataset.color_mnist:
    train_data_iterator = iter(
        data_iterators.ColorMnistDataIterator('train', _TRAIN_BATCH_SIZE.value))
    test_data_iterator = iter(
        data_iterators.ColorMnistDataIterator('test', _TEST_BATCH_SIZE.value))
  else:
    # Fail with a clear message instead of an UnboundLocalError further down.
    raise ValueError(f'Unsupported dataset: {_DATASET.value}')

  def _elbo_fun(input_data):
    # Modules are constructed inside this function because it is wrapped by
    # `hk.transform` below.
    if _ENCODER.value is EncoderArch.color_mnist_mlp_encoder:
      encoder = encoders.ColorMnistMLPEncoder(_LATENT_DIM.value)
    else:
      raise ValueError(f'Unsupported encoder: {_ENCODER.value}')

    if _DECODER.value is DecoderArch.color_mnist_mlp_decoder:
      decoder = decoders.ColorMnistMLPDecoder(_OBS_VAR.value)
    else:
      raise ValueError(f'Unsupported decoder: {_DECODER.value}')

    vae_obj = vae.VAE(encoder, decoder, _RHO.value)

    if _MODEL.value is Model.vae:
      return vae_obj.vae_elbo(input_data, hk.next_rng_key())
    else:
      return vae_obj.avae_elbo(input_data, hk.next_rng_key())

  elbo_fun = hk.transform(_elbo_fun)

  # Stored alongside the parameters in every checkpoint.
  extra_checkpoint_info = {
      'dataset': _DATASET.value.name,
      'encoder': _ENCODER.value.name,
      'decoder': _DECODER.value.name,
      'obs_var': _OBS_VAR.value,
      'rho': _RHO.value,
      'latent_dim': _LATENT_DIM.value,
  }

  train.train(
      train_data_iterator=train_data_iterator,
      test_data_iterator=test_data_iterator,
      elbo_fun=elbo_fun,
      learning_rate=_LEARNING_RATE.value,
      checkpoint_dir=_CHECKPOINT_DIR.value,
      checkpoint_filename=_CHECKPOINT_FILENAME.value,
      checkpoint_every=_CHECKPOINT_EVERY.value,
      test_every=_TEST_EVERY.value,
      iterations=_ITERATIONS.value,
      rng_seed=_RNG_SEED.value,
      extra_checkpoint_info=extra_checkpoint_info)
|
||||
|
||||
|
||||
# Script entry point: hand `main` to absl's `app.run`.
if __name__ == '__main__':
  app.run(main)
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Useful dataclasses types used across the code."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import dataclasses
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class ELBOOutputs:
  """Result of an ELBO computation.

  Attributes:
    elbo: The evidence lower bound.
    data_fidelity: The data fidelity (reconstruction) term of the ELBO.
    kl: The KL divergence term of the ELBO.
  """
  elbo: jnp.ndarray
  data_fidelity: jnp.ndarray
  kl: jnp.ndarray
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class LabelledData:
  """One batch of examples together with their labels.

  Attributes:
    data: Array of shape (batch_size, ...).
    label: Array of shape (batch_size, ...), or None. The field has no
      default, so it must always be passed explicitly.
  """
  data: np.ndarray
  label: Optional[np.ndarray]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class NormalParams:
  """Mean and variance describing a normal distribution.

  Attributes:
    mean: Array of shape (batch_size, latent_dim).
    variance: Array of shape (batch_size, latent_dim).
  """
  mean: jnp.ndarray
  variance: jnp.ndarray
|
||||
+133
@@ -0,0 +1,133 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Standard VAE class."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from avae import decoders
|
||||
from avae import encoders
|
||||
from avae import kl
|
||||
from avae import types
|
||||
|
||||
|
||||
class VAE:
  """VAE class.

  This class defines the ELBOs used in training VAE and AVAE models. It also
  provides a forward pass (encode, sample, decode) through the model.
  """

  def __init__(self, encoder: encoders.EncoderBase,
               decoder: decoders.DecoderBase, rho: Optional[float] = None):
    """Class initializer.

    Args:
      encoder: Encoder network architecture.
      decoder: Decoder network architecture.
      rho: Rho parameter used in AVAE training. Required by `avae_elbo`;
        unused by `vae_elbo`.
    """
    self._encoder = encoder
    self._decoder = decoder
    self._rho = rho

  def vae_elbo(
      self, input_data: jnp.ndarray,
      key: jnp.ndarray) -> types.ELBOOutputs:
    """ELBO for training VAE.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for random number generator.

    Returns:
      Computed VAE ELBO as `types.ELBOOutputs`.
    """
    posterior = self._encoder(input_data)
    samples = self._encoder.sample(posterior, key)
    # `in_axes` needs one entry per positional argument; a single-entry
    # specification is rejected by `jax.vmap` for this two-argument call.
    kls = jax.vmap(kl.kl_p_with_uniform_normal, (0, 0))(
        posterior.mean, posterior.variance)
    recons = self._decoder(samples)
    data_fidelity = self._decoder.data_fidelity(input_data, recons)
    elbo = data_fidelity - kls
    return types.ELBOOutputs(elbo, data_fidelity, kls)

  def avae_elbo(
      self, input_data: jnp.ndarray,
      key: jnp.ndarray) -> types.ELBOOutputs:
    """ELBO for training AVAE model.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for random number generator.

    Returns:
      Computed AVAE ELBO as `types.ELBOOutputs`. `data_fidelity` and `kl` are
      the plain VAE terms; `elbo` additionally contains the AVAE terms.

    Raises:
      ValueError: If the model was constructed without `rho`.
    """
    if self._rho is None:
      raise ValueError('`rho` must be set to compute the AVAE ELBO.')

    # Auxiliary images are the model's own reconstructions; gradients are not
    # propagated through their generation. The same `key` is used here and
    # for the posterior sample below.
    aux_images = jax.lax.stop_gradient(self(input_data, key))

    posterior = self._encoder(input_data)
    samples = self._encoder.sample(posterior, key)
    kls = jax.vmap(kl.kl_p_with_uniform_normal, (0, 0))(
        posterior.mean, posterior.variance)
    recons = self._decoder(samples)
    data_fidelity = self._decoder.data_fidelity(input_data, recons)
    elbo = data_fidelity - kls

    aux_posterior = self._encoder(aux_images)
    latent_mean = posterior.mean
    latent_var = posterior.variance
    aux_latent_mean = aux_posterior.mean
    aux_latent_var = aux_posterior.variance
    latent_dim = latent_mean.shape[1]

    def _reduce(x):
      # Sum over latent dimensions, then average over the batch (a scalar).
      return jnp.mean(jnp.sum(x, axis=1))

    # Computation of <log p(Z_aux | Z)>.
    expected_log_conditional = (
        aux_latent_var + jnp.square(self._rho) * latent_var +
        jnp.square(aux_latent_mean - self._rho * latent_mean))
    expected_log_conditional = _reduce(expected_log_conditional)
    expected_log_conditional /= 2.0 * (1.0 - jnp.square(self._rho))
    expected_log_conditional = (latent_dim *
                                jnp.log(1.0 / (2 * jnp.pi)) -
                                expected_log_conditional)
    # The scalar AVAE terms broadcast onto every entry of the batched `elbo`.
    elbo += expected_log_conditional
    # Entropy of Z_aux
    elbo += _reduce(0.5 * jnp.log(2 * jnp.pi * jnp.e * aux_latent_var))

    return types.ELBOOutputs(elbo, data_fidelity, kls)

  def __call__(
      self, input_data: jnp.ndarray,
      key: jnp.ndarray) -> jnp.ndarray:
    """Reconstruction of the input data.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for random number generator.

    Returns:
      Reconstruction of the input data as jnp.ndarray of shape
      [batch_dim, observation_dims].
    """
    posterior = self._encoder(input_data)
    samples = self._encoder.sample(posterior, key)
    return self._decoder(samples)
|
||||
|
||||
Reference in New Issue
Block a user