Adding counterfactual_fairness to deepmind-research repo

PiperOrigin-RevId: 396665335
This commit is contained in:
Ira Ktena
2021-09-14 21:11:14 +01:00
committed by Saran Tunyasuvunakool
parent f49eb487a7
commit cff83be778
11 changed files with 1680 additions and 0 deletions
+1
View File
@@ -60,6 +60,7 @@ https://deepmind.com/research/publications/
* [Disentangling by Subspace Diffusion (GEOMANCER)](geomancer), NeurIPS 2020
* [What can I do here? A theory of affordances in reinforcement learning](affordances_theory), ICML 2020
* [Scaling data-driven robotics with reward sketching and batch reinforcement learning](sketchy), RSS 2020
* [Path-Specific Counterfactual Fairness](counterfactual_fairness), AAAI 2019
* [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019
* [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020
* [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020
+57
View File
@@ -0,0 +1,57 @@
# Path-Specific Counterfactual Fairness (AAAI 2019)
Paper: [Path-Specific Counterfactual Fairness](https://ojs.aaai.org/index.php/AAAI/article/view/4777/4655)
If you use the code here please cite this paper:
@inproceedings{chiappa2019path,
title={Path-specific counterfactual fairness},
author={Chiappa, Silvia},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={33},
number={01},
pages={7801--7808},
year={2019}
}
### Overview
This release contains the path-specific counterfactual fairness method used
in the paper, as well as utility functions for loading the *Adult* dataset.
The following gives a brief overview of the contents, more detailed
documentation is available within each file:
* __causal_network.py__: Defines the `Node` class, instances of which can be
combined into a directed graph. Associated with each node is a
`distribution_module`, a haiku module which builds a tensorflow
`Distribution` instance as a function of the node's parents.
* __util.py__: Miscellaneous utility functions.
* __variational.py__: Class for performing variational
inference. The `Variational` haiku module is a general-purpose approximate
posterior, using an MLP to map from arbitrary inputs to the parameters
of a Gaussian distribution.
* __adult.py__: Utility functions for the *Adult* dataset.
* __adult_pscf.py__: Training and 'fair' prediction process (using
path-specific counterfactual fairness) on the *Adult* dataset.
* __adult_pscf_config.py__: Configuration file with default training
parameters.
### Dataset
The *Adult* dataset can be downloaded from
[https://archive.ics.uci.edu/ml/datasets/adult](https://archive.ics.uci.edu/ml/datasets/adult). You may use the following
command to download both necessary files to run the training script:
`sh download_dataset.sh ${OUTPUT_DIR}`
### Experiments
To download the dataset and run the main experiment reported in the paper,
you may run:
`sh run_adult_pscf.sh`
### Acknowledgements
Credits to Thomas P. S. Gillam for the original TF1 implementation.
+78
View File
@@ -0,0 +1,78 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Adult dataset.
See https://archive.ics.uci.edu/ml/datasets/adult.
"""
import os
from absl import logging
import pandas as pd
_COLUMNS = ('age', 'workclass', 'final-weight', 'education', 'education-num',
'marital-status', 'occupation', 'relationship', 'race', 'sex',
'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
'income')
_CATEGORICAL_COLUMNS = ('workclass', 'education', 'marital-status',
'occupation', 'race', 'relationship', 'sex',
'native-country', 'income')
def _read_data(
name,
data_path=''):
with os.path.join(data_path, name) as data_file:
data = pd.read_csv(data_file, header=None, index_col=False,
names=_COLUMNS, skipinitialspace=True, na_values='?')
for categorical in _CATEGORICAL_COLUMNS:
data[categorical] = data[categorical].astype('category')
return data
def _combine_category_coding(df_1, df_2):
"""Combines the categories between dataframes df_1 and df_2.
This is used to ensure that training and test data use the same category
coding, so that the one-hot vectors representing the values are compatible
between training and test data.
Args:
df_1: Pandas DataFrame.
df_2: Pandas DataFrame. Must have the same columns as df_1.
"""
for column in df_1.columns:
if df_1[column].dtype.name == 'category':
categories_1 = set(df_1[column].cat.categories)
categories_2 = set(df_2[column].cat.categories)
categories = sorted(categories_1 | categories_2)
df_1[column].cat.set_categories(categories, inplace=True)
df_2[column].cat.set_categories(categories, inplace=True)
def read_all_data(root_dir, remove_missing=True):
"""Return (train, test) dataframes, optionally removing incomplete rows."""
train_data = _read_data('adult.data', root_dir)
test_data = _read_data('adult.test', root_dir)
_combine_category_coding(train_data, test_data)
if remove_missing:
train_data = train_data.dropna()
test_data = test_data.dropna()
logging.info('Training data dtypes: %s', train_data.dtypes)
return train_data, test_data
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,65 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config for adult PSCF experiment."""
import ml_collections
def get_config():
"""Return the default configuration."""
config = ml_collections.ConfigDict()
config.num_steps = 10000 # Number of training steps to perform.
config.batch_size = 128 # Batch size.
config.learning_rate = 0.01 # Learning rate
# Number of samples to draw for prediction.
config.num_prediction_samples = 500
# Batch size to use for prediction. Ideally as big as possible, but may need
# to be reduced for memory reasons depending on the value of
# `num_prediction_samples`.
config.prediction_batch_size = 500
# Multiplier for the likelihood term in the loss
config.likelihood_multiplier = 5.
# Multiplier for the MMD constraint term in the loss
config.constraint_multiplier = 0.
# Scaling factor to use in KL term.
config.beta = 1.0
# The number of samples we draw from each latent variable distribution.
config.mmd_sample_size = 100
# Directory into which results should be placed. By default it is the empty
# string, in which case no saving will occur. The directory specified will be
# created if it does not exist.
config.output_dir = ''
# The index of the step at which to turn on the constraint multiplier. For
# steps prior to this the multiplier will be zero.
config.constraint_turn_on_step = 0
# The random seed for tensorflow that is applied to the graph iff the value is
# non-negative. By default the seed is not constrained.
config.seed = -1
# When doing fair inference, don't sample when given a sample for the baseline
# gender.
config.baseline_passthrough = False
return config
+405
View File
@@ -0,0 +1,405 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Construct DAGs representing causal graphs, and perform inference on them."""
import collections
import haiku as hk
import jax
import jax.numpy as jnp
import pandas as pd
from tensorflow_probability.substrates import jax as tfp
import tree
class Node:
"""A node in a graphical model.
Conceptually, this represents a random variable in a causal probabilistic
model. It knows about its 'parents', i.e. other Nodes upon which this Node
causally depends. The user is responsible for ensuring that any graph built
with this class is acyclic.
A node knows how to represent its probability density, conditional on the
values of its parents.
The node needs to have a 'name', corresponding to the series within the
dataframe that should be used to populate it.
"""
def __init__(self, distribution_module, parents=(), hidden=False):
"""Initialise a `Node` instance.
Args:
distribution_module: An instance of `DistributionModule`, a Haiku module
that is suitable for modelling the conditional distribution of this
node given any parents.
parents: `Iterable`, optional. A (potentially nested) collection of nodes
which are direct ancestors of `self`.
hidden: `bool`, optional. Whether this node is hidden. Hidden nodes are
permitted to not have corresponding observations.
"""
parents = tree.flatten(parents)
self._distribution_module = distribution_module
self._column = distribution_module.column
self._index = distribution_module.index
self._hidden = hidden
self._observed_value = None
# When implementing the path-specific counterfactual fairness algorithm,
# we need the concept of a distribution conditional on the 'corrected'
# values of the parents. This is achieved via the 'node_to_replacement'
# argument of make_distribution.
# However, in order to work with the `fix_categorical` and `fix_continuous`
# functions, we need to assign counterfactual values for parents at
# evaluation time.
self._parent_to_value = collections.OrderedDict(
(parent, None) for parent in parents)
# This is the conditional distribution using no replacements, i.e. it is
# conditioned on the observed values of parents.
self._distribution = None
def __repr__(self):
return 'Node<{}>'.format(self.name)
@property
def dim(self):
"""The dimensionality of this node."""
return self._distribution_module.dim
@property
def name(self):
return self._column
@property
def hidden(self):
return self._hidden
@property
def observed_value(self):
return self._observed_value
def find_ancestor(self, name):
"""Returns an ancestor node with the given name."""
if self.name == name:
return self
for parent in self.parents:
found = parent.find_ancestor(name)
if found is not None:
return found
@property
def parents(self):
return tuple(self._parent_to_value)
@property
def distribution_module(self):
return self._distribution_module
@property
def distribution(self):
self._distribution = self.make_distribution()
return self._distribution
def make_distribution(self, node_to_replacement=None):
"""Make a conditional distribution for this node | parents.
By default we use values (representing 'real data') from the parent
nodes as inputs to the distribution, however we can alternatively swap out
any of these for arbitrary arrays by specifying `node_to_replacement`.
Args:
node_to_replacement: `None`, `dict: Node -> DeviceArray`.
If specified, use the indicated array.
Returns:
`tfp.distributions.Distribution`
"""
cur_parent_to_value = self._parent_to_value
self._parent_to_value = collections.OrderedDict(
(parent, parent.observed_value) for parent in cur_parent_to_value.keys()
)
if node_to_replacement is None:
parent_values = self._parent_to_value.values()
return self._distribution_module(*parent_values)
args = []
for node, value in self._parent_to_value.items():
if node in node_to_replacement:
replacement = node_to_replacement[node]
args.append(replacement)
else:
args.append(value)
return self._distribution_module(*args)
def populate(self, data, node_to_replacement=None):
"""Given a dataframe, populate node data.
If the Node does not have data present, this is taken to be a sign of
a) An error if the node is not hidden.
b) Fine if the node is hidden.
In case a) an exception will be raised, and in case b) observed)v will
not be mutated.
Args:
data: tf.data.Dataset
node_to_replacement: None | dict(Node -> array). If not None, use the
given ndarray data rather than extracting data from the frame. This is
only considered when looking at the inputs to a distribution.
Raises:
RuntimeError: If `data` doesn't contain the necessary feature, and the
node is not hidden.
"""
column = self._column
hidden = self._hidden
replacement = None
if node_to_replacement is not None and self in node_to_replacement:
replacement = node_to_replacement[self]
if replacement is not None:
# If a replacement is present, this takes priority over any other
# consideration.
self._observed_value = replacement
return
if self._index < 0:
if not hidden:
raise RuntimeError(
'Node {} is not hidden, and column {} is not in the frame.'.format(
self, column))
# Nothing to do - there is no data, and the node is hidden.
return
# Produce the observed value for this node.
self._observed_value = self._distribution_module.prepare_data(data)
class DistributionModule(hk.Module):
"""Common base class for a Haiku module representing a distribution.
This provides some additional functionality common to all modules that would
be used as arguments to the `Node` class above.
"""
def __init__(self, column, index, dim):
"""Initialise a `DistributionModule` instance.
Args:
column: `string`. The name of the random variable to which this
distribution corresponds, and should match the name of the series in
the pandas dataframe.
index: `int`. The index of the corresponding feature in the dataset.
dim: `int`. The output dimensionality of the distribution.
"""
super().__init__(name=column.replace('-', '_'))
self._column = column
self._index = index
self._dim = dim
@property
def dim(self):
"""The output dimensionality of this distribution."""
return self._dim
@property
def column(self):
return self._column
@property
def index(self):
return self._index
def prepare_data(self, data):
"""Given a general tensor, return an ndarray if required.
This method implements the functionality delegated from
`Node._prepare_data`, and it is expected that subclasses will override the
implementation appropriately.
Args:
data: A tf.data.Dataset.
Returns:
`np.ndarray` of appropriately converted values for this series.
"""
return data[:, [self._index]]
def _package_args(self, args):
"""Concatenate args into a single tensor.
Args:
args: `List[DeviceArray]`, length > 0.
Each array is of shape (batch_size, ?) or (batch_size,). The former
will occur if looking at e.g. a one-hot encoded categorical variable,
and the latter in the case of a continuous variable.
Returns:
`DeviceArray`, (batch_size, num_values).
"""
return jnp.concatenate(args, axis=1)
class Gaussian(DistributionModule):
"""A Haiku module that maps some inputs into a normal distribution."""
def __init__(self, column, index, dim=1, hidden_shape=(),
hidden_activation=jnp.tanh, scale=None):
"""Initialise a `Gaussian` instance with some dimensionality."""
super(Gaussian, self).__init__(column, index, dim)
self._hidden_shape = tuple(hidden_shape)
self._hidden_activation = hidden_activation
self._scale = scale
self._loc_net = hk.nets.MLP(self._hidden_shape + (self._dim,),
activation=self._hidden_activation)
def __call__(self, *args):
if args:
# There are arguments - these represent the variables on which we are
# conditioning. We set the mean of the output distribution to be a
# function of these values, parameterised with an MLP.
concatenated_inputs = self._package_args(args)
loc = self._loc_net(concatenated_inputs)
else:
# There are no arguments, so instead have a learnable location parameter.
loc = hk.get_parameter('loc', shape=[self._dim], init=jnp.zeros)
if self._scale is None:
# The scale has not been explicitly specified, in which case it is left
# to be single value, i.e. not a function of the conditioning set.
log_var = hk.get_parameter('log_var', shape=[self._dim], init=jnp.ones)
scale = jnp.sqrt(jnp.exp(log_var))
else:
scale = jnp.float32(self._scale)
return tfp.distributions.Normal(loc=loc, scale=scale)
def prepare_data(self, data):
# For continuous data, we ensure the data is of dtype float32, and
# additionally that the resulant shape is (num_examples, 1)
# Note that this implementation only works for dim=1, however this is
# currently also enforced by the fact that pandas series cannot be
# multidimensional.
result = data[:, [self.index]].astype(jnp.float32)
if len(result.shape) == 1:
result = jnp.expand_dims(result, axis=1)
return result
class GaussianMixture(DistributionModule):
"""A Haiku module that maps some inputs into a mixture of normals."""
def __init__(self, column, num_components, dim=1):
"""Initialise a `GaussianMixture` instance with some dimensionality.
Args:
column: `string`. The name of the column.
num_components: `int`. The number of gaussians in this mixture.
dim: `int`. The dimensionality of the variable.
"""
super().__init__(column, -1, dim)
self._num_components = num_components
self._loc_net = hk.nets.MLP([self._dim])
self._categorical_logits_module = hk.nets.MLP([self._num_components])
def __call__(self, *args):
# Define component Gaussians to be independent functions of args.
locs = []
scales = []
for _ in range(self._num_components):
loc = hk.get_parameter('loc', shape=[self._dim], init=jnp.zeros)
log_var = hk.get_parameter('log_var', shape=[self._dim], init=jnp.ones)
scale = jnp.sqrt(jnp.exp(log_var))
locs.extend(loc)
scales.extend(scale)
# Define the Categorical distribution which switches between these
categorical_logits = hk.get_parameter('categorical_logits',
shape=[self._num_components],
init=jnp.zeros)
# Enforce positivity in the logits
categorical_logits = jax.nn.sigmoid(categorical_logits)
# If we have a multidimensional node, then the normal distributions above
# have a batch shape of (dim,). We want to select between these using the
# categorical distribution, so tile the logits to match this shape
expanded_logits = jnp.repeat(categorical_logits, self._dim)
categorical = tfp.distributions.Categorical(logits=expanded_logits)
return tfp.distributions.MixtureSameFamily(
mixture_distribution=categorical,
components_distribution=tfp.distributions.Normal(
loc=locs, scale=scales))
class MLPMultinomial(DistributionModule):
"""A Haiku module that consists of an MLP + multinomial distribution."""
def __init__(self, column, index, dim, hidden_shape=(),
hidden_activation=jnp.tanh):
"""Initialise an MLPMultinomial instance.
Args:
column: `string`. Name of the corresponding dataframe column.
index: `int`. The index of the input data for this feature.
dim: `int`. Number of categories.
hidden_shape: `Iterable`, optional. Shape of hidden layers.
hidden_activation: `Callable`, optional. Non-linearity for hidden
layers.
"""
super(MLPMultinomial, self).__init__(column, index, dim)
self._hidden_shape = tuple(hidden_shape)
self._hidden_activation = hidden_activation
self._logit_net = hk.nets.MLP(self._hidden_shape + (self.dim,),
activation=self._hidden_activation)
@classmethod
def from_frame(cls, data, column, hidden_shape=()):
"""Create an MLPMultinomial instance from a pandas dataframe and column."""
# Helper method that uses the dataframe to work out how many categories
# are in the given column. The dataframe is not used for any other purpose.
if not isinstance(data[column].dtype, pd.api.types.CategoricalDtype):
raise ValueError('{} is not categorical.'.format(column))
index = list(data.columns).index(column)
num_categories = len(data[column].cat.categories)
return cls(column, index, num_categories, hidden_shape)
def __call__(self, *args):
if args:
concatenated_inputs = self._package_args(args)
logits = self._logit_net(concatenated_inputs)
else:
logits = hk.get_parameter('b', shape=[self.dim], init=jnp.zeros)
return tfp.distributions.Multinomial(logits=logits, total_count=1.0)
def prepare_data(self, data):
# For categorical data, we convert to a one-hot representation using the
# pandas category 'codes'. These are integers, and will have a definite
# ordering that is identical between runs.
codes = data[:, self.index]
codes = codes.astype(jnp.int32)
return jnp.eye(self.dim)[codes]
def populate(nodes, dataframe, node_to_replacement=None):
"""Populate observed values for nodes."""
for node in nodes:
node.populate(dataframe, node_to_replacement=node_to_replacement)
@@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Usage:
# sh download_dataset.sh ${OUTPUT_DIR}
# Example:
# sh download_dataset.sh uci_adult
set -e
OUTPUT_DIR="${1}"
BASE_URL="https://archive.ics.uci.edu/ml/machine-learning-databases/adult/"
mkdir -p ${OUTPUT_DIR}
for file in adult.data adult.test
do
wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}"
done
+12
View File
@@ -0,0 +1,12 @@
wheel
absl-py>=0.12.0
dm-haiku>=0.0.4
optax>=0.0.8
ml_collections
numpy>=1.16.4
tensorflow>=2.5.0
tensorflow-datasets>=4.3.0
tensorflow_probability
pandas
sklearn
jaxlib>=0.1.67+cuda110
+40
View File
@@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Fail on any error.
set -e
# Display commands being run.
set -x
TMP_DIR=`mktemp -d`
virtualenv --python=python3.6 "${TMP_DIR}/env"
source "${TMP_DIR}/env/bin/activate"
# Install dependencies.
pip install --upgrade -r counterfactual_fairness/requirements.txt
# Download minimal dataset
DATA_DIR="${TMP_DIR}/adult"
bash counterfactual_fairness/download_dataset.sh ${TMP_DIR}
# Train for a few steps.
python -m counterfactual_fairness.adult_pscf --config="counterfactual_fairness/adult_pscf_config.py" --dataset_dir=${DATA_DIR} --num_steps=10
# Clean up.
rm -r ${TMP_DIR}
echo "Test run complete."
+311
View File
@@ -0,0 +1,311 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Common utilities."""
from typing import Optional, Union
from jax import random
import jax.numpy as jnp
import pandas as pd
import tensorflow as tf
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
def get_dataset(dataset: pd.DataFrame,
batch_size: int,
shuffle_size: int = 10000,
num_epochs: Optional[int] = None) -> tf.data.Dataset:
"""Makes a tf.Dataset with correct preprocessing."""
dataset_copy = dataset.copy()
for column in dataset.columns:
if dataset[column].dtype.name == 'category':
dataset_copy.loc[:, column] = dataset[column].cat.codes
ds = tf.data.Dataset.from_tensor_slices(dataset_copy.values)
if shuffle_size > 0:
ds = ds.shuffle(shuffle_size, reshuffle_each_iteration=True)
ds = ds.repeat(num_epochs)
return ds.batch(batch_size, drop_remainder=True)
def multinomial_mode(
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
"""Calculates the (one-hot) mode of a multinomial distribution.
Args:
distribution_or_probs:
`tfp.distributions.Distribution` | List[tensors].
If the former, it is assumed that it has a `probs` property, and
represents a distribution over categories. If the latter, these are
taken to be the probabilities of categories directly.
In either case, it is assumed that `probs` will be shape
(batch_size, dim).
Returns:
`DeviceArray`, float32, (batch_size, dim).
The mode of the distribution - this will be in one-hot form, but contain
multiple non-zero entries in the event that more than one probability is
joint-highest.
"""
if isinstance(distribution_or_probs, tfd.Distribution):
probs = distribution_or_probs.probs_parameter()
else:
probs = distribution_or_probs
max_prob = jnp.max(probs, axis=1, keepdims=True)
mode = jnp.int32(jnp.equal(probs, max_prob))
return jnp.float32(mode / jnp.sum(mode, axis=1, keepdims=True))
def multinomial_class(
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
"""Computes the mode class of a multinomial distribution.
Args:
distribution_or_probs:
`tfp.distributions.Distribution` | DeviceArray.
As for `multinomial_mode`.
Returns:
`DeviceArray`, float32, (batch_size,).
For each element in the batch, the index of the class with highest
probability.
"""
if isinstance(distribution_or_probs, tfd.Distribution):
return jnp.argmax(distribution_or_probs.logits_parameter(), axis=1)
return jnp.argmax(distribution_or_probs, axis=1)
def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
"""Calculates the (one-hot) mode from an ndarray of class probabilities.
Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays
rather than Tensors.
Args:
probs: `DeviceArray`, (batch_size, dim). Probabilities for each class, for
each element in a batch.
Returns:
`DeviceArray`, (batch_size, dim).
"""
max_prob = jnp.amax(probs, axis=1, keepdims=True)
mode = jnp.equal(probs, max_prob).astype(jnp.int32)
return (mode / jnp.sum(mode, axis=1, keepdims=True)).astype(jnp.float32)
def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
data: jnp.DeviceArray) -> jnp.DeviceArray:
"""Compute the accuracy, averaged over a batch of data.
Args:
distribution_or_probs:
`tfp.distributions.Distribution` | List[tensors].
As for functions above.
data: `DeviceArray`. Reference data, of shape (batch_size, dim).
Returns:
`DeviceArray`, float32, ().
Overall scalar accuracy.
"""
return jnp.mean(
jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))
def softmax_ndarray(logits: jnp.DeviceArray) -> jnp.DeviceArray:
"""Softmax function, implemented for numpy ndarrays."""
assert len(logits.shape) == 2
# Normalise for better stability.
s = jnp.max(logits, axis=1, keepdims=True)
e_x = jnp.exp(logits - s)
return e_x / jnp.sum(e_x, axis=1, keepdims=True)
def get_samples(distribution, num_samples, seed=None):
"""Given a batched distribution, compute samples and reshape along batch.
That is, we have a distribution of shape (batch_size, ...), where each element
of the tensor is independent. We then draw num_samples from each component, to
give a tensor of shape:
(num_samples, batch_size, ...)
Args:
distribution: `tfp.distributions.Distribution`. The distribution from which
to sample.
num_samples: `Integral` | `DeviceArray`, int32, (). The number of samples.
seed: `Integral` | `None`. The seed that will be forwarded to the call to
distribution.sample. Defaults to `None`.
Returns:
`DeviceArray`, float32, (batch_size * num_samples, ...).
Samples for each element of the batch.
"""
# Obtain the sample from the distribution, which will be of shape
# [num_samples] + batch_shape + event_shape.
sample = distribution.sample(num_samples, seed=seed)
sample = sample.reshape((-1, sample.shape[-1]))
# Combine the first two dimensions through a reshape, so the result will
# be of shape (num_samples * batch_size,) + shape_tail.
return sample
def mmd_loss(distribution: tfd.Distribution,
is_a: jnp.DeviceArray,
num_samples: int,
rng: jnp.ndarray,
num_random_features: int = 50,
gamma: float = 1.):
"""Given two distributions, compute the Maximum Mean Discrepancy (MMD).
More exactly, this uses the 'FastMMD' approximation, a.k.a. 'Random Fourier
Features'. See the description, for example, in sections 2.3.1 and 2.4 of
https://arxiv.org/pdf/1511.00830.pdf.
Args:
distribution: Distribution whose `sample()` method will return a
DeviceArray of shape (batch_size, dim).
is_a: A boolean array indicating which elements of the batch correspond
to class A (the remaining indices correspond to class B).
num_samples: The number of samples to draw from `distribution`.
rng: Random seed provided by the user.
num_random_features: The number of random fourier features
used in the expansion.
gamma: The value of gamma in the Gaussian MMD kernel.
Returns:
`DeviceArray`, shape ().
The scalar MMD value for samples taken from the given distributions.
"""
if distribution.event_shape == (): # pylint: disable=g-explicit-bool-comparison
dim_x = distribution.batch_shape[1]
else:
dim_x, = distribution.event_shape
# Obtain samples from the distribution, which will be of shape
# [num_samples] + batch_shape + event_shape.
samples = distribution.sample(num_samples, seed=rng)
w = random.normal(rng, shape=((dim_x, num_random_features)))
b = random.uniform(rng, shape=(num_random_features,),
minval=0, maxval=2*jnp.pi)
def features(x):
"""Compute the kitchen sink feature."""
# We need to contract last axis of x with first of W - do this with
# tensordot. The result has shape:
# (?, ?, num_random_features)
return jnp.sqrt(2 / num_random_features) * jnp.cos(
jnp.sqrt(2 / gamma) * jnp.tensordot(x, w, axes=1) + b)
# Compute the expected values of the given features.
# The first axis represents the samples from the distribution,
# second axis represents the batch_size.
# Each of these now has shape (num_random_features,)
exp_features = features(samples)
# Swap axes so that batch_size is the last dimension to be compatible
# with is_a and is_b shape at the next step
exp_features_reshaped = jnp.swapaxes(exp_features, 1, 2)
# Current dimensions [num_samples, num_random_features, batch_size]
exp_features_reshaped_a = jnp.where(is_a, exp_features_reshaped, 0)
exp_features_reshaped_b = jnp.where(is_a, 0, exp_features_reshaped)
exp_features_a = jnp.mean(exp_features_reshaped_a, axis=(0, 2))
exp_features_b = jnp.mean(exp_features_reshaped_b, axis=(0, 2))
assert exp_features_a.shape == (num_random_features,)
difference = exp_features_a - exp_features_b
# Compute the squared norm. Shape ().
return jnp.tensordot(difference, difference, axes=1)
def mmd_loss_exact(distribution_a, distribution_b, num_samples, gamma=1.):
"""Exact estimate of MMD."""
assert distribution_a.event_shape == distribution_b.event_shape
assert distribution_a.batch_shape[1:] == distribution_b.batch_shape[1:]
# shape (num_samples * batch_size_a, dim_x)
samples_a = get_samples(distribution_a, num_samples)
# shape (num_samples * batch_size_b, dim_x)
samples_b = get_samples(distribution_b, num_samples)
# Make matrices of shape
# (size_b, size_a, dim_x)
# where:
# size_a = num_samples * batch_size_a
# size_b = num_samples * batch_size_b
size_a = samples_a.shape[0]
size_b = samples_b.shape[0]
x_a = jnp.expand_dims(samples_a, axis=0)
x_a = jnp.tile(x_a, (size_b, 1, 1))
x_b = jnp.expand_dims(samples_b, axis=1)
x_b = jnp.tile(x_b, (1, size_a, 1))
def kernel_mean(x, y):
"""Gaussian kernel mean."""
diff = x - y
# Contract over dim_x.
exponent = - jnp.einsum('ijk,ijk->ij', diff, diff) / gamma
# This has shape (size_b, size_a).
kernel_matrix = jnp.exp(exponent)
# Shape ().
return jnp.mean(kernel_matrix)
# Equation 7 from arxiv 1511.00830
return (
kernel_mean(x_a, x_a)
+ kernel_mean(x_b, x_b)
- 2 * kernel_mean(x_a, x_b))
def scalar_log_prob(distribution, val):
"""Compute the log_prob per batch entry.
It is conceptually similar to:
jnp.sum(distribution.log_prob(val), axis=1)
However, classes like `tfp.distributions.Multinomial` have a log_prob which
returns a tensor of shape (batch_size,), which will cause the above
incantation to fail. In these cases we fall back to returning just:
distribution.log_prob(val)
Args:
distribution: `tfp.distributions.Distribution` which implements log_prob.
val: `DeviceArray`, (batch_size, dim).
Returns:
`DeviceArray`, (batch_size,).
If the result of log_prob has a trailing dimension, we perform a reduce_sum
over it.
Raises:
ValueError: If distribution.log_prob(val) has an unsupported shape.
"""
log_prob_val = distribution.log_prob(val)
if len(log_prob_val.shape) == 1:
return log_prob_val
elif len(log_prob_val.shape) > 2:
raise ValueError('log_prob_val has unexpected shape {}.'.format(
log_prob_val.shape))
return jnp.sum(log_prob_val, axis=1)
+87
View File
@@ -0,0 +1,87 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions and classes for performing variational inference."""
from typing import Callable, Iterable, Optional
import haiku as hk
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
class Variational(hk.Module):
"""A module representing the variational distribution q(H | *O).
H is assumed to be a continuous variable.
"""
def __init__(self,
common_layer_sizes: Iterable[int],
activation: Callable[[jnp.DeviceArray],
jnp.DeviceArray] = jnp.tanh,
output_dim: int = 1,
name: Optional[str] = None):
"""Initialises a `Variational` instance.
Args:
common_layer_sizes: The number of hidden units in the shared dense
network layers.
activation: Nonlinearity function to apply to each of the
common layers.
output_dim: The dimensionality of `H`.
name: A name to assign to the module instance.
"""
super().__init__(name=name)
self._common_layer_sizes = common_layer_sizes
self._activation = activation
self._output_dim = output_dim
self._linear_layers = [
hk.Linear(layer_size)
for layer_size in self._common_layer_sizes
]
self._mean_output = hk.Linear(self._output_dim)
self._log_var_output = hk.Linear(self._output_dim)
def __call__(self, *args) -> tfd.Distribution:
"""Create a distribution for q(H | *O).
Args:
*args: `List[DeviceArray]`. Corresponds to the values of whatever
variables are in the conditional set *O.
Returns:
`tfp.distributions.NormalDistribution` instance.
"""
# Stack all inputs, ensuring that shapes are consistent and that they are
# all of dtype float32.
input_ = [hk.Flatten()(arg) for arg in args]
input_ = jnp.concatenate(input_, axis=1)
# Create a common set of layers, then final layer separates mean & log_var
for layer in self._linear_layers:
input_ = layer(input_)
input_ = self._activation(input_)
# input_ now represents a tensor of shape (batch_size, final_layer_size).
# This is now put through two final layers, one for the computation of each
# of the mean and standard deviation of the resultant distribution.
mean = self._mean_output(input_)
log_var = self._log_var_output(input_)
std = jnp.sqrt(jnp.exp(log_var))
return tfd.Normal(mean, std)