Internal change

PiperOrigin-RevId: 417034217
This commit is contained in:
Nimrod Gileadi
2021-12-17 18:27:35 +00:00
committed by Diego de Las Casas
parent fba48d1e44
commit 3f0d5ed1a0
41 changed files with 6676 additions and 0 deletions
+59
View File
@@ -0,0 +1,59 @@
# The Autoencoding Variational Autoencoder
This is the code for the models in NeurIPS Submission [AVAE](https://papers.nips.cc/paper/2020/file/ac10ff1941c540cd87c107330996f4f6-Paper.pdf)
Folder contains code to train AVAE model in JAX, and we will be uploading
evaluation setup soon.
Code files in the folder
- checkpointer.py: Checkpointing abstraction
- data_iterators.py: Datasets to be used
- decoders.py: VAE decoder network architectures
- encoders.py: VAE encoder network architectures
- kl.py: KL computation between 2 gaussians
- train.py: Function to train given ELBO, network and data
- train_main.py: Main file to train AVAE
- vae.py: VAE model defining various ELBOs
## Setup
To set up a Python3 virtual environment with the required dependencies, run:
```shell
python -m venv avae_env
source avae_env/bin/activate
pip install --upgrade pip
pip install -r avae/requirements.txt
```
## Running AVAE training
Following command will run AVAE training for ColorMnist dataset using MLP
network architectures.
```shell
python -m avae.train_main \
--dataset='color_mnist' \
--latent_dim=64 \
--checkpoint_dir='/tmp/avae_checkpoints' \
--checkpoint_filename='color_mnist_mlp_avae' \
--rho=0.975 \
--encoder='color_mnist_mlp_encoder' \
--decoder='color_mnist_mlp_decoder'
```
## References
### Citing our work
If you use that code for your research, please consider citing our paper:
```bibtex
@article{cemgil2020autoencoding,
title={The Autoencoding Variational Autoencoder},
author={Cemgil, Taylan and Ghaisas, Sumedh and Dvijotham, Krishnamurthy and Gowal, Sven and Kohli, Pushmeet},
journal={Advances in Neural Information Processing Systems},
volume={33},
year={2020}
}
```
Binary file not shown.

After

Width:  |  Height:  |  Size: 164 KiB

+88
View File
@@ -0,0 +1,88 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpointing functionality."""
import os
from typing import Any, Mapping, Optional
from absl import logging
import dill
import jax
import jax.numpy as jnp
class Checkpointer:
"""A checkpoint saving and loading class."""
def __init__(self, checkpoint_dir: str, filename: str):
"""Class initializer.
Args:
checkpoint_dir: Checkpoint directory.
filename: Filename of checkpoint in checkpoint directory.
"""
self._checkpoint_dir = checkpoint_dir
if not os.path.isdir(self._checkpoint_dir):
os.mkdir(self._checkpoint_dir)
self._checkpoint_path = os.path.join(self._checkpoint_dir, filename)
def save_checkpoint(
self,
experiment_state: Mapping[str, jnp.ndarray],
opt_state: Mapping[str, jnp.ndarray],
step: int,
extra_checkpoint_info: Optional[Mapping[str, Any]] = None) -> None:
"""Save checkpoint with experiment state and step information.
Args:
experiment_state: Experiment params to be stored.
opt_state: Optimizer state to be stored.
step: Training iteration step.
extra_checkpoint_info: Extra information to be stored.
"""
if jax.host_id() != 0:
return
checkpoint_data = dict(
experiment_state=jax.tree_map(jax.device_get, experiment_state),
opt_state=jax.tree_map(jax.device_get, opt_state),
step=step)
if extra_checkpoint_info is not None:
for key in extra_checkpoint_info:
checkpoint_data[key] = extra_checkpoint_info[key]
with open(self._checkpoint_path, 'wb') as checkpoint_file:
dill.dump(checkpoint_data, checkpoint_file, protocol=2)
def load_checkpoint(
self) -> Optional[Mapping[str, Mapping[str, jnp.ndarray]]]:
"""Load and return checkpoint data.
Returns:
Loaded checkpoint if it exists else returns None.
"""
if os.path.exists(self._checkpoint_path):
with open(self._checkpoint_path, 'rb') as checkpoint_file:
checkpoint_data = dill.load(checkpoint_file)
logging.info('Loading checkpoint from %s, saved at step %d',
self._checkpoint_path, checkpoint_data['step'])
return checkpoint_data
else:
logging.warning('No pre-saved checkpoint found at %s',
self._checkpoint_path)
return None
+97
View File
@@ -0,0 +1,97 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset iterators Mnist, ColorMnist."""
import enum
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from avae import types
class Dataset(enum.Enum):
color_mnist = enum.auto()
class MnistDataIterator(object):
"""Mnist data iterator class.
Data is obtained as dataclass, types.LabelledData.
"""
def __init__(self, subset: str, batch_size: int):
"""Class initializer.
Args:
subset: Subset of dataset.
batch_size: Batch size of the returned dataset iterator.
"""
datasets = tfds.load('mnist')
train_dataset = datasets[subset]
def _map_fun(x):
return {'data': tf.cast(x['image'], tf.float32) / 255.0,
'label': x['label']}
train_dataset = train_dataset.map(_map_fun)
train_dataset = train_dataset.batch(batch_size,
drop_remainder=True).repeat()
self._iterator = iter(tfds.as_numpy(train_dataset))
self._batch_size = batch_size
def __iter__(self):
return self
def __next__(self) -> types.LabelledData:
return types.LabelledData(**next(self._iterator))
class ColorMnistDataIterator(MnistDataIterator):
"""Color Mnist data iterator.
Each ColorMnist image is of shape (28, 28, 3). ColorMnist digit can have 7
different colors by permution of RGB channels (turning on and off RGB
channels, except for all channels off permutation).
Data is obtained as dataclass, util_dataclasses.LabelledData.
"""
def __next__(self) -> types.LabelledData:
mnist_batch = next(self._iterator)
mnist_image = mnist_batch['data']
# Colors are supported by turning off and on RGB channels.
# Thus possible colors are
# [black (excluded), red, green, yellow, blue, magenta, cyan, white]
# color_id takes value from [1,8)
color_id = np.random.randint(7, size=self._batch_size) + 1
red_channel_bool = np.mod(color_id, 2)
red_channel_bool = jnp.reshape(red_channel_bool, [-1, 1, 1, 1])
blue_channel_bool = np.floor_divide(color_id, 4)
blue_channel_bool = jnp.reshape(blue_channel_bool, [-1, 1, 1, 1])
green_channel_bool = np.mod(np.floor_divide(color_id, 2), 2)
green_channel_bool = jnp.reshape(green_channel_bool, [-1, 1, 1, 1])
color_mnist_image = jnp.stack([
jnp.multiply(red_channel_bool, mnist_image),
jnp.multiply(blue_channel_bool, mnist_image),
jnp.multiply(green_channel_bool, mnist_image)], axis=3)
color_mnist_image = jnp.reshape(color_mnist_image, [-1, 28, 28, 3])
# Color id takes value [1, 8)
# Although to make classification code easier `color` label attached to data
# takes value in [0, 7) (by subtracting 1 from color id)
return types.LabelledData(
data=color_mnist_image, label=mnist_batch['label'])
+88
View File
@@ -0,0 +1,88 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder architectures to be used with VAE."""
import abc
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
class DecoderBase(hk.Module):
"""Base class for decoder network classes."""
def __init__(self, obs_var: float):
"""Class initializer.
Args:
obs_var: oversation variance of the dataset.
"""
super().__init__()
self._obs_var = obs_var
@abc.abstractmethod
def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
"""Reconstruct from a given latent sample.
Args:
z: latent samples of shape (batch_size, latent_dim)
Returns:
Reconstruction with shape (batch_size, ...).
"""
def data_fidelity(
self,
input_data: jnp.ndarray,
recons: jnp.ndarray,
) -> jnp.ndarray:
"""Compute Data fidelity (recons loss) for given input and recons.
Args:
input_data: Input batch of shape (batch_size, ...).
recons: Reconstruction of the input data. An array with the same shape as
`input_data.data`.
Returns:
Computed data fidelity term across batch of data. An array of shape
`(batch_size,)`.
"""
error = (input_data - recons).reshape(input_data.shape[0], -1)
return -0.5 * jnp.sum(jnp.square(error), axis=1) / self._obs_var
class ColorMnistMLPDecoder(DecoderBase):
"""MLP decoder for Color Mnist."""
_hidden_units = (200, 200, 200, 200)
_image_dims = (28, 28, 3) # Dimensions of a single MNIST image.
def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
"""Reconstruct with given latent sample.
Args:
z: latent samples of shape (batch_size, latent_dim)
Returns:
Reconstructions data of shape (batch_size, 28, 28, 3).
"""
out = z
for units in self._hidden_units:
out = hk.Linear(units)(out)
out = jax.nn.relu(out)
out = hk.Linear(np.product(self._image_dims))(out)
out = jax.nn.sigmoid(out)
return jnp.reshape(out, (-1,) + self._image_dims)
+112
View File
@@ -0,0 +1,112 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoder architectures to be used with VAE."""
import abc
from typing import Generic, TypeVar
import haiku as hk
import jax
import jax.numpy as jnp
from avae import types
_Params = TypeVar('_Params')
class EncoderBase(hk.Module, Generic[_Params]):
"""Abstract class for encoder architectures."""
def __init__(self, latent_dim: int):
"""Class initializer.
Args:
latent_dim: Latent dimensions of the model.
"""
super().__init__()
self._latent_dim = latent_dim
@abc.abstractmethod
def __call__(self, input_data: jnp.ndarray) -> _Params:
"""Return posterior distribution over latents.
Args:
input_data: Input batch of shape (batch_size, ...).
Returns:
Parameters of the posterior distribution over the latents.
"""
@abc.abstractmethod
def sample(self, posterior: _Params, key: jnp.ndarray) -> jnp.ndarray:
"""Sample from the given posterior distribution.
Args:
posterior: Parameters of posterior distribution over the latents.
key: Random number generator key.
Returns:
Sample from the posterior distribution over latents,
shape[batch_size, latent_dim]
"""
class ColorMnistMLPEncoder(EncoderBase[types.NormalParams]):
"""MLP encoder for ColorMnist."""
_hidden_units = (200, 200, 200, 200)
def __call__(
self, input_data: jnp.ndarray) -> types.NormalParams:
"""Return posterior distribution over latents.
Args:
input_data: Input batch of shape (batch_size, ...).
Returns:
Posterior distribution over the latents.
"""
out = hk.Flatten()(input_data)
for units in self._hidden_units:
out = hk.Linear(units)(out)
out = jax.nn.relu(out)
out = hk.Linear(2 * self._latent_dim)(out)
return _normal_params_from_logits(out)
def sample(
self,
posterior: types.NormalParams,
key: jnp.ndarray,
) -> jnp.ndarray:
"""Sample from the given normal posterior (mean, var) distribution.
Args:
posterior: Posterior over the latents.
key: Random number generator key.
Returns:
Sample from the posterior distribution over latents,
shape[batch_size, latent_dim]
"""
eps = jax.random.normal(
key, shape=(posterior.mean.shape[0], self._latent_dim))
return posterior.mean + eps * posterior.variance
def _normal_params_from_logits(
logits: jnp.ndarray) -> types.NormalParams:
"""Construct mean and variance of normal distribution from given logits."""
mean, log_variance = jnp.split(logits, 2, axis=1)
variance = jnp.exp(log_variance)
return types.NormalParams(mean=mean, variance=variance)
+43
View File
@@ -0,0 +1,43 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various KL implementations in JAX."""
import jax.numpy as jnp
def kl_p_with_uniform_normal(mean: jnp.ndarray,
variance: jnp.ndarray) -> jnp.ndarray:
r"""KL between p_dist with uniform normal prior.
Args:
mean: Mean of the gaussian distribution, shape (latent_dims,)
variance: Variance of the gaussian distribution, shape (latent_dims,)
Returns:
KL divergence KL(P||N(0, 1)) shape ()
"""
if len(variance.shape) == 2:
# If `variance` is a full covariance matrix
variance_trace = jnp.trace(variance)
_, ldet1 = jnp.linalg.slogdet(variance)
else:
variance_trace = jnp.sum(variance)
ldet1 = jnp.sum(jnp.log(variance))
mean_contribution = jnp.sum(jnp.square(mean))
res = -ldet1
res += variance_trace + mean_contribution - mean.shape[0]
return res * 0.5
+8
View File
@@ -0,0 +1,8 @@
absl-py>0.9
dill==0.3.2
dm-haiku
jax
optax
tensorflow
tensorflow_datasets
Executable
+31
View File
@@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python -m venv avae_env
source avae_env/bin/activate
pip install --upgrade pip
pip install -r avae/requirements.txt
# Start training AVAE model.
python -m avae.train_main \
--dataset='color_mnist' \
--latent_dim=64 \
--checkpoint_dir='/tmp/avae_checkpoints' \
--checkpoint_filename='color_mnist_mlp_avae' \
--rho=0.975 \
--encoder=color_mnist_mlp_encoder \
--decoder=color_mnist_mlp_decoder \
--iterations=1
+117
View File
@@ -0,0 +1,117 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VAE style training."""
from typing import Any, Callable, Iterator, Sequence, Mapping, Tuple, Optional
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from avae import checkpointer
from avae import types
def train(
train_data_iterator: Iterator[types.LabelledData],
test_data_iterator: Iterator[types.LabelledData],
elbo_fun: hk.Transformed,
learning_rate: float,
checkpoint_dir: str,
checkpoint_filename: str,
checkpoint_every: int,
test_every: int,
iterations: int,
rng_seed: int,
test_functions: Optional[Sequence[Callable[[Mapping[str, jnp.ndarray]],
Tuple[str, float]]]] = None,
extra_checkpoint_info: Optional[Mapping[str, Any]] = None):
"""Train VAE with given data iterator and elbo definition.
Args:
train_data_iterator: Iterator of batched training data.
test_data_iterator: Iterator of batched testing data.
elbo_fun: Haiku transfomed function returning elbo.
learning_rate: Learning rate to be used with optimizer.
checkpoint_dir: Path of the checkpoint directory.
checkpoint_filename: Filename of the checkpoint.
checkpoint_every: Checkpoint every N iterations.
test_every: Test and log results every N iterations.
iterations: Number of training iterations to perform.
rng_seed: Seed for random number generator.
test_functions: Test function iterable, each function takes test data and
outputs extra info to print at test and log time.
extra_checkpoint_info: Extra info to put inside saved checkpoint.
"""
rng_seq = hk.PRNGSequence(jax.random.PRNGKey(rng_seed))
opt_init, opt_update = optax.chain(
# Set the parameters of Adam. Note the learning_rate is not here.
optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
# Put a minus sign to *minimise* the loss.
optax.scale(-learning_rate))
@jax.jit
def loss(params, key, data):
elbo_outputs = elbo_fun.apply(params, key, data)
return -jnp.mean(elbo_outputs.elbo)
@jax.jit
def loss_test(params, key, data):
elbo_output = elbo_fun.apply(params, key, data)
return (-jnp.mean(elbo_output.elbo), jnp.mean(elbo_output.data_fidelity),
jnp.mean(elbo_output.kl))
@jax.jit
def update_step(params, key, data, opt_state):
grads = jax.grad(loss, has_aux=False)(params, key, data)
updates, opt_state = opt_update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state
exp_checkpointer = checkpointer.Checkpointer(
checkpoint_dir, checkpoint_filename)
experiment_data = exp_checkpointer.load_checkpoint()
if experiment_data is not None:
start = experiment_data['step']
params = experiment_data['experiment_state']
opt_state = experiment_data['opt_state']
else:
start = 0
params = elbo_fun.init(
next(rng_seq), next(train_data_iterator).data)
opt_state = opt_init(params)
for step in range(start, iterations, 1):
if step % test_every == 0:
test_loss, ll, kl = loss_test(params, next(rng_seq),
next(test_data_iterator).data)
output_message = (f'Step {step} elbo {-test_loss:0.2f} LL {ll:0.2f} '
f'KL {kl:0.2f}')
if test_functions:
for test_function in test_functions:
name, test_output = test_function(params)
output_message += f' {name}: {test_output:0.2f}'
print(output_message)
params, opt_state = update_step(params, next(rng_seq),
next(train_data_iterator).data, opt_state)
if step % checkpoint_every == 0:
exp_checkpointer.save_checkpoint(
params, opt_state, step, extra_checkpoint_info)
+127
View File
@@ -0,0 +1,127 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main file for training VAE/AVAE."""
import enum
from absl import app
from absl import flags
import haiku as hk
from avae import data_iterators
from avae import decoders
from avae import encoders
from avae import train
from avae import vae
class Model(enum.Enum):
vae = enum.auto()
avae = enum.auto()
class EncoderArch(enum.Enum):
color_mnist_mlp_encoder = 'ColorMnistMLPEncoder'
class DecoderArch(enum.Enum):
color_mnist_mlp_decoder = 'ColorMnistMLPDecoder'
_DATASET = flags.DEFINE_enum_class(
'dataset', data_iterators.Dataset.color_mnist, data_iterators.Dataset,
'Dataset to train on')
_LATENT_DIM = flags.DEFINE_integer('latent_dim', 32,
'Number of latent dimensions.')
_TRAIN_BATCH_SIZE = flags.DEFINE_integer('train_batch_size', 64,
'Train batch size.')
_TEST_BATCH_SIZE = flags.DEFINE_integer('test_batch_size', 64,
'Testing batch size.')
_TEST_EVERY = flags.DEFINE_integer('test_every', 1000,
'Test every N iterations.')
_ITERATIONS = flags.DEFINE_integer('iterations', 102000,
'Number of training iterations.')
_OBS_VAR = flags.DEFINE_float('obs_var', 0.5,
'Observation variance of the data. (Default 0.5)')
_MODEL = flags.DEFINE_enum_class('model', Model.avae, Model,
'Model used for training.')
_RHO = flags.DEFINE_float('rho', 0.8, 'Rho parameter used with AVAE or SE.')
_LEARNING_RATE = flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.')
_RNG_SEED = flags.DEFINE_integer('rng_seed', 0,
'Seed for random number generator.')
_CHECKPOINT_DIR = flags.DEFINE_string('checkpoint_dir', '/tmp/',
'Directory for checkpointing.')
_CHECKPOINT_FILENAME = flags.DEFINE_string(
'checkpoint_filename', 'color_mnist_avae_mlp', 'Checkpoint filename.')
_CHECKPOINT_EVERY = flags.DEFINE_integer(
'checkpoint_every', 1000, 'Checkpoint every N steps.')
_ENCODER = flags.DEFINE_enum_class(
'encoder', EncoderArch.color_mnist_mlp_encoder, EncoderArch,
'Encoder class name.')
_DECODER = flags.DEFINE_enum_class(
'decoder', DecoderArch.color_mnist_mlp_decoder, DecoderArch,
'Decoder class name.')
def main(_):
if _DATASET.value is data_iterators.Dataset.color_mnist:
train_data_iterator = iter(
data_iterators.ColorMnistDataIterator('train', _TRAIN_BATCH_SIZE.value))
test_data_iterator = iter(
data_iterators.ColorMnistDataIterator('test', _TEST_BATCH_SIZE.value))
def _elbo_fun(input_data):
if _ENCODER.value is EncoderArch.color_mnist_mlp_encoder:
encoder = encoders.ColorMnistMLPEncoder(_LATENT_DIM.value)
if _DECODER.value is DecoderArch.color_mnist_mlp_decoder:
decoder = decoders.ColorMnistMLPDecoder(_OBS_VAR.value)
vae_obj = vae.VAE(encoder, decoder, _RHO.value)
if _MODEL.value is Model.vae:
return vae_obj.vae_elbo(input_data, hk.next_rng_key())
else:
return vae_obj.avae_elbo(input_data, hk.next_rng_key())
elbo_fun = hk.transform(_elbo_fun)
extra_checkpoint_info = {
'dataset': _DATASET.value.name,
'encoder': _ENCODER.value.name,
'decoder': _DECODER.value.name,
'obs_var': _OBS_VAR.value,
'rho': _RHO.value,
'latent_dim': _LATENT_DIM.value,
}
train.train(
train_data_iterator=train_data_iterator,
test_data_iterator=test_data_iterator,
elbo_fun=elbo_fun,
learning_rate=_LEARNING_RATE.value,
checkpoint_dir=_CHECKPOINT_DIR.value,
checkpoint_filename=_CHECKPOINT_FILENAME.value,
checkpoint_every=_CHECKPOINT_EVERY.value,
test_every=_TEST_EVERY.value,
iterations=_ITERATIONS.value,
rng_seed=_RNG_SEED.value,
extra_checkpoint_info=extra_checkpoint_info)
if __name__ == '__main__':
app.run(main)
+53
View File
@@ -0,0 +1,53 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Useful dataclasses types used across the code."""
from typing import Optional
import dataclasses
import jax.numpy as jnp
import numpy as np
@dataclasses.dataclass(frozen=True)
class ELBOOutputs:
elbo: jnp.ndarray
data_fidelity: jnp.ndarray
kl: jnp.ndarray
@dataclasses.dataclass(frozen=True)
class LabelledData:
"""A batch of labelled examples.
Attributes:
data: Array of shape (batch_size, ...).
label: Array of shape (batch_size, ...).
"""
data: np.ndarray
label: Optional[np.ndarray]
@dataclasses.dataclass(frozen=True)
class NormalParams:
"""Parameters of a normal distribution.
Attributes:
mean: Array of shape (batch_size, latent_dim).
variance: Array of shape (batch_size, latent_dim).
"""
mean: jnp.ndarray
variance: jnp.ndarray
+133
View File
@@ -0,0 +1,133 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Standard VAE class."""
from typing import Optional
import jax
import jax.numpy as jnp
from avae import decoders
from avae import encoders
from avae import kl
from avae import types
class VAE:
"""VAE class.
This class defines the ELBO used in training VAE models. It also adds function
for forward passing data through VAE.
"""
def __init__(self, encoder: encoders.EncoderBase,
decoder: decoders.DecoderBase, rho: Optional[float] = None):
"""Class initializer.
Args:
encoder: Encoder network architecture.
decoder: Decoder network architecture.
rho: Rho parameter used in AVAE training.
"""
self._encoder = encoder
self._decoder = decoder
self._rho = rho
def vae_elbo(
self, input_data: jnp.ndarray,
key: jnp.ndarray) -> types.ELBOOutputs:
"""ELBO for training VAE.
Args:
input_data: Input batch of shape (batch_size, ...).
key: Key for random number generator.
Returns:
Computed VAE Elbo as type util_dataclasses.ELBOOutputs
"""
posterior = self._encoder(input_data)
samples = self._encoder.sample(posterior, key)
kls = jax.vmap(kl.kl_p_with_uniform_normal, [0])(
posterior.mean, posterior.variance)
recons = self._decoder(samples)
data_fidelity = self._decoder.data_fidelity(input_data, recons)
elbo = data_fidelity - kls
return types.ELBOOutputs(elbo, data_fidelity, kls)
def avae_elbo(
self, input_data: jnp.ndarray,
key: jnp.ndarray) -> types.ELBOOutputs:
"""ELBO for training AVAE model.
Args:
input_data: Input batch of shape (batch_size, ...).
key: Key for random number generator.
Returns:
Computed AVAE Elbo in nested tuple (Elbo, (data_fidelity, KL)). All arrays
have batch dimension intact.
"""
aux_images = jax.lax.stop_gradient(self(input_data, key))
posterior = self._encoder(input_data)
samples = self._encoder.sample(posterior, key)
kls = jax.vmap(kl.kl_p_with_uniform_normal, [0, 0])(
posterior.mean, posterior.variance)
recons = self._decoder(samples)
data_fidelity = self._decoder.data_fidelity(input_data, recons)
elbo = data_fidelity - kls
aux_posterior = self._encoder(aux_images)
latent_mean = posterior.mean
latent_var = posterior.variance
aux_latent_mean = aux_posterior.mean
aux_latent_var = aux_posterior.variance
latent_dim = latent_mean.shape[1]
def _reduce(x):
return jnp.mean(jnp.sum(x, axis=1))
# Computation of <log p(Z_aux | Z)>.
expected_log_conditional = (
aux_latent_var + jnp.square(self._rho) * latent_var +
jnp.square(aux_latent_mean - self._rho * latent_mean))
expected_log_conditional = _reduce(expected_log_conditional)
expected_log_conditional /= 2.0 * (1.0 - jnp.square(self._rho))
expected_log_conditional = (latent_dim *
jnp.log(1.0 / (2 * jnp.pi)) -
expected_log_conditional)
elbo += expected_log_conditional
# Entropy of Z_aux
elbo += _reduce(0.5 * jnp.log(2 * jnp.pi * jnp.e * aux_latent_var))
return types.ELBOOutputs(elbo, data_fidelity, kls)
def __call__(
self, input_data: jnp.ndarray,
key: jnp.ndarray) -> jnp.ndarray:
"""Reconstruction of the input data.
Args:
input_data: Input batch of shape (batch_size, ...).
key: Key for random number generator.
Returns:
Reconstruction of the input data as jnp.ndarray of shape
[batch_dim, observation_dims].
"""
posterior = self._encoder(input_data)
samples = self._encoder.sample(posterior, key)
recons = self._decoder(samples)
return recons