# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Common utilities."""
from typing import Optional, Union
import jax
from jax import random
import jax.numpy as jnp
import pandas as pd
import tensorflow as tf
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions


def get_dataset(dataset: pd.DataFrame,
                batch_size: int,
                shuffle_size: int = 10000,
                num_epochs: Optional[int] = None) -> tf.data.Dataset:
  """Makes a tf.data.Dataset with correct preprocessing."""
  dataset_copy = dataset.copy()
  for column in dataset.columns:
    if dataset[column].dtype.name == 'category':
      dataset_copy.loc[:, column] = dataset[column].cat.codes
  ds = tf.data.Dataset.from_tensor_slices(dataset_copy.values)
  if shuffle_size > 0:
    ds = ds.shuffle(shuffle_size, reshuffle_each_iteration=True)
  ds = ds.repeat(num_epochs)
  return ds.batch(batch_size, drop_remainder=True)
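

# Example usage for `get_dataset` (a sketch, not part of the original module;
# the DataFrame below is hypothetical). Categorical columns are replaced by
# their integer codes before being handed to TensorFlow:
#
#   df = pd.DataFrame({'colour': pd.Categorical(['red', 'blue', 'red']),
#                      'value': [1.0, 2.0, 3.0]})
#   ds = get_dataset(df, batch_size=2, shuffle_size=0, num_epochs=1)
#   for batch in ds:  # One batch: a float64 Tensor of shape (2, 2).
#     print(batch)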


def multinomial_mode(
    distribution_or_probs: Union[tfd.Distribution, jax.Array]
) -> jax.Array:
  """Calculates the (one-hot) mode of a multinomial distribution.

  Args:
    distribution_or_probs:
      `tfp.distributions.Distribution` | `jax.Array`.
      If the former, it is assumed to have a `probs` property and to
      represent a distribution over categories. If the latter, it is taken
      to be the array of category probabilities directly. In either case,
      `probs` is assumed to have shape (batch_size, dim).

  Returns:
    `jax.Array`, float32, (batch_size, dim).
    The mode of the distribution in one-hot form; it will contain multiple
    non-zero entries in the event that more than one probability is
    joint-highest.
  """
  if isinstance(distribution_or_probs, tfd.Distribution):
    probs = distribution_or_probs.probs_parameter()  # pytype: disable=attribute-error  # jax-devicearray
  else:
    probs = distribution_or_probs
  max_prob = jnp.max(probs, axis=1, keepdims=True)
  mode = jnp.int32(jnp.equal(probs, max_prob))
  return jnp.float32(mode / jnp.sum(mode, axis=1, keepdims=True))
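

# Example for `multinomial_mode` (a sketch, not part of the original module):
# ties split the probability mass evenly across the joint-highest classes.
#
#   probs = jnp.array([[0.4, 0.4, 0.2],
#                      [0.1, 0.2, 0.7]])
#   multinomial_mode(probs)
#   # -> [[0.5, 0.5, 0.0],
#   #     [0.0, 0.0, 1.0]]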


def multinomial_class(
    distribution_or_probs: Union[tfd.Distribution, jax.Array]
) -> jax.Array:
  """Computes the mode class of a multinomial distribution.

  Args:
    distribution_or_probs:
      `tfp.distributions.Distribution` | `jax.Array`.
      As for `multinomial_mode`.

  Returns:
    `jax.Array`, int32, (batch_size,).
    For each element in the batch, the index of the class with the highest
    probability.
  """
  if isinstance(distribution_or_probs, tfd.Distribution):
    return jnp.argmax(distribution_or_probs.logits_parameter(), axis=1)  # pytype: disable=attribute-error  # jax-devicearray
  return jnp.argmax(distribution_or_probs, axis=1)


def multinomial_mode_ndarray(probs: jax.Array) -> jax.Array:
  """Calculates the (one-hot) mode from an array of class probabilities.

  Equivalent to `multinomial_mode` above, but operates on an array of
  probabilities directly rather than accepting a distribution object.

  Args:
    probs: `jax.Array`, (batch_size, dim). Probabilities for each class, for
      each element in a batch.

  Returns:
    `jax.Array`, float32, (batch_size, dim).
  """
  max_prob = jnp.amax(probs, axis=1, keepdims=True)
  mode = jnp.equal(probs, max_prob).astype(jnp.int32)
  return (mode / jnp.sum(mode, axis=1, keepdims=True)).astype(jnp.float32)


def multinomial_accuracy(
    distribution_or_probs: Union[tfd.Distribution, jax.Array],
    data: jax.Array) -> jax.Array:
  """Computes the accuracy, averaged over a batch of data.

  Args:
    distribution_or_probs:
      `tfp.distributions.Distribution` | `jax.Array`.
      As for the functions above.
    data: `jax.Array`. One-hot reference data, of shape (batch_size, dim).

  Returns:
    `jax.Array`, float32, ().
    Overall scalar accuracy.
  """
  return jnp.mean(
      jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))


def softmax_ndarray(logits: jax.Array) -> jax.Array:
  """Numerically stable softmax over the last axis of a rank-2 array."""
  assert len(logits.shape) == 2
  # Subtract the row-wise maximum for numerical stability.
  s = jnp.max(logits, axis=1, keepdims=True)
  e_x = jnp.exp(logits - s)
  return e_x / jnp.sum(e_x, axis=1, keepdims=True)
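

# Example for `softmax_ndarray` (a sketch, not part of the original module):
#
#   softmax_ndarray(jnp.array([[1., 2., 3.]]))
#   # -> approximately [[0.090, 0.245, 0.665]]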


def get_samples(distribution, num_samples, seed=None):
  """Given a batched distribution, computes samples and reshapes along batch.

  That is, we have a distribution of shape (batch_size, ...), where each
  element of the tensor is independent. We then draw num_samples from each
  component, giving a tensor of shape (num_samples, batch_size, ...), and
  flatten the two leading dimensions.

  Args:
    distribution: `tfp.distributions.Distribution`. The distribution from
      which to sample.
    num_samples: `Integral` | `jax.Array`, int32, (). The number of samples.
    seed: `Integral` | `None`. The seed that will be forwarded to the call to
      distribution.sample. Defaults to `None`.

  Returns:
    `jax.Array`, float32, (num_samples * batch_size, ...).
    Samples for each element of the batch.
  """
  # Obtain the samples from the distribution; they will be of shape
  # [num_samples] + batch_shape + event_shape.
  sample = distribution.sample(num_samples, seed=seed)
  # Combine the first two dimensions through a reshape, so the result is
  # of shape (num_samples * batch_size,) + shape_tail.
  return sample.reshape((-1, sample.shape[-1]))
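

# Example for `get_samples` (a sketch, not part of the original module): a
# batch of two 3-dimensional Gaussians, with the samples flattened to shape
# (num_samples * batch_size, dim).
#
#   dist = tfd.MultivariateNormalDiag(loc=jnp.zeros((2, 3)))
#   samples = get_samples(dist, num_samples=5, seed=random.PRNGKey(0))
#   assert samples.shape == (10, 3)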


def mmd_loss(distribution: tfd.Distribution,
             is_a: jax.Array,
             num_samples: int,
             rng: jax.Array,
             num_random_features: int = 50,
             gamma: float = 1.):
  """Computes the Maximum Mean Discrepancy (MMD) between two groups.

  The batch is split into groups A and B according to `is_a`, and the MMD
  between samples from the two groups is estimated. More exactly, this uses
  the 'FastMMD' approximation, a.k.a. 'Random Fourier Features'. See the
  description, for example, in sections 2.3.1 and 2.4 of
  https://arxiv.org/pdf/1511.00830.pdf.

  Args:
    distribution: Distribution whose `sample()` method will return a
      `jax.Array` of shape (batch_size, dim).
    is_a: A boolean array indicating which elements of the batch correspond
      to class A (the remaining indices correspond to class B).
    num_samples: The number of samples to draw from `distribution`.
    rng: Random key provided by the user.
    num_random_features: The number of random Fourier features used in the
      expansion.
    gamma: The value of gamma in the Gaussian MMD kernel.

  Returns:
    `jax.Array`, shape ().
    The scalar MMD value for samples taken from the given distributions.
  """
  if distribution.event_shape == ():  # pylint: disable=g-explicit-bool-comparison
    dim_x = distribution.batch_shape[1]
  else:
    dim_x, = distribution.event_shape
  # Use independent keys for the samples and for the random features.
  rng_sample, rng_w, rng_b = random.split(rng, 3)
  # Obtain samples from the distribution; they will be of shape
  # [num_samples] + batch_shape + event_shape.
  samples = distribution.sample(num_samples, seed=rng_sample)
  w = random.normal(rng_w, shape=(dim_x, num_random_features))
  b = random.uniform(rng_b, shape=(num_random_features,),
                     minval=0, maxval=2 * jnp.pi)

  def features(x):
    """Computes the random 'kitchen sink' features."""
    # Contract the last axis of x with the first axis of w via tensordot.
    # The result has shape (num_samples, batch_size, num_random_features).
    return jnp.sqrt(2 / num_random_features) * jnp.cos(
        jnp.sqrt(2 / gamma) * jnp.tensordot(x, w, axes=1) + b)

  # Compute the features of the samples. The first axis indexes the samples
  # from the distribution, the second the batch elements; the shape is
  # (num_samples, batch_size, num_random_features).
  exp_features = features(samples)
  # Swap axes so that batch_size is the last dimension, to be compatible
  # with the shape of is_a at the next step. Current dimensions:
  # (num_samples, num_random_features, batch_size).
  exp_features_reshaped = jnp.swapaxes(exp_features, 1, 2)
  exp_features_reshaped_a = jnp.where(is_a, exp_features_reshaped, 0)
  exp_features_reshaped_b = jnp.where(is_a, 0, exp_features_reshaped)
  # Expected feature vectors; each has shape (num_random_features,).
  exp_features_a = jnp.mean(exp_features_reshaped_a, axis=(0, 2))
  exp_features_b = jnp.mean(exp_features_reshaped_b, axis=(0, 2))
  assert exp_features_a.shape == (num_random_features,)
  difference = exp_features_a - exp_features_b
  # Squared norm of the difference. Shape ().
  return jnp.tensordot(difference, difference, axes=1)
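

# Example usage for `mmd_loss` (a sketch, not part of the original module):
# a batch of four 2-dimensional Gaussians split into two groups of two.
#
#   dist = tfd.MultivariateNormalDiag(loc=jnp.zeros((4, 2)))
#   is_a = jnp.array([True, True, False, False])
#   loss = mmd_loss(dist, is_a, num_samples=100, rng=random.PRNGKey(0))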


def mmd_loss_exact(distribution_a, distribution_b, num_samples, gamma=1.):
  """Exact (quadratic-time) estimate of the MMD."""
  assert distribution_a.event_shape == distribution_b.event_shape
  assert distribution_a.batch_shape[1:] == distribution_b.batch_shape[1:]
  # shape (num_samples * batch_size_a, dim_x)
  samples_a = get_samples(distribution_a, num_samples)
  # shape (num_samples * batch_size_b, dim_x)
  samples_b = get_samples(distribution_b, num_samples)

  def kernel_mean(x, y):
    """Mean of the Gaussian kernel over all pairs from x and y."""
    # Pairwise differences, of shape (size_x, size_y, dim_x), where
    # size_x = x.shape[0] and size_y = y.shape[0].
    diff = jnp.expand_dims(x, axis=1) - jnp.expand_dims(y, axis=0)
    # Contract over dim_x. This has shape (size_x, size_y).
    exponent = -jnp.einsum('ijk,ijk->ij', diff, diff) / gamma
    kernel_matrix = jnp.exp(exponent)
    # Shape ().
    return jnp.mean(kernel_matrix)

  # Equation 7 from arXiv:1511.00830.
  return (
      kernel_mean(samples_a, samples_a)
      + kernel_mean(samples_b, samples_b)
      - 2 * kernel_mean(samples_a, samples_b))


def scalar_log_prob(distribution, val):
  """Computes the log_prob per batch entry.

  It is conceptually similar to:

    jnp.sum(distribution.log_prob(val), axis=1)

  However, classes like `tfp.distributions.Multinomial` have a log_prob which
  returns a tensor of shape (batch_size,), which will cause the above
  incantation to fail. In these cases we fall back to returning just:

    distribution.log_prob(val)

  Args:
    distribution: `tfp.distributions.Distribution` which implements log_prob.
    val: `jax.Array`, (batch_size, dim).

  Returns:
    `jax.Array`, (batch_size,).
    If the result of log_prob has a trailing dimension, we perform a
    reduce_sum over it.

  Raises:
    ValueError: If distribution.log_prob(val) has an unsupported shape.
  """
  log_prob_val = distribution.log_prob(val)
  if len(log_prob_val.shape) == 1:
    return log_prob_val
  elif len(log_prob_val.shape) > 2:
    raise ValueError('log_prob_val has unexpected shape {}.'.format(
        log_prob_val.shape))
  return jnp.sum(log_prob_val, axis=1)
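

# Example for `scalar_log_prob` (a sketch, not part of the original module):
# both a distribution with per-dimension log_probs and one with per-batch
# log_probs reduce to shape (batch_size,).
#
#   normals = tfd.Normal(loc=jnp.zeros((8, 3)), scale=1.)
#   counts = tfd.Multinomial(total_count=5., probs=jnp.ones((8, 3)) / 3)
#   x = jnp.tile(jnp.array([2., 2., 1.]), (8, 1))
#   scalar_log_prob(normals, x).shape  # (8,): summed over the trailing dim.
#   scalar_log_prob(counts, x).shape   # (8,): already one value per entry.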