mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-28 02:35:47 +08:00
Adding counterfactual_fairness to deepmind-research repo
PiperOrigin-RevId: 396665335
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
f49eb487a7
commit
cff83be778
@@ -60,6 +60,7 @@ https://deepmind.com/research/publications/
|
|||||||
* [Disentangling by Subspace Diffusion (GEOMANCER)](geomancer), NeurIPS 2020
|
* [Disentangling by Subspace Diffusion (GEOMANCER)](geomancer), NeurIPS 2020
|
||||||
* [What can I do here? A theory of affordances in reinforcement learning](affordances_theory), ICML 2020
|
* [What can I do here? A theory of affordances in reinforcement learning](affordances_theory), ICML 2020
|
||||||
* [Scaling data-driven robotics with reward sketching and batch reinforcement learning](sketchy), RSS 2020
|
* [Scaling data-driven robotics with reward sketching and batch reinforcement learning](sketchy), RSS 2020
|
||||||
|
* [Path-Specific Counterfactual Fairness](counterfactual_fairness), AAAI 2019
|
||||||
* [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019
|
* [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019
|
||||||
* [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020
|
* [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020
|
||||||
* [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020
|
* [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
# Path-Specific Counterfactual Fairness (AAAI 2019)
|
||||||
|
|
||||||
|
Paper: [Path-Specific Counterfactual Fairness](https://ojs.aaai.org/index.php/AAAI/article/view/4777/4655)
|
||||||
|
|
||||||
|
If you use the code here please cite this paper:
|
||||||
|
|
||||||
|
@inproceedings{chiappa2019path,
|
||||||
|
title={Path-specific counterfactual fairness},
|
||||||
|
author={Chiappa, Silvia},
|
||||||
|
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
||||||
|
volume={33},
|
||||||
|
number={01},
|
||||||
|
pages={7801--7808},
|
||||||
|
year={2019}
|
||||||
|
}
|
||||||
|
|
||||||
|
### Overview
|
||||||
|
|
||||||
|
This release contains the path-specific counterfactual fairness method used
|
||||||
|
in the paper, as well as utility functions for loading the *Adult* dataset.
|
||||||
|
|
||||||
|
The following gives a brief overview of the contents, more detailed
|
||||||
|
documentation is available within each file:
|
||||||
|
|
||||||
|
* __causal_network.py__: Defines the `Node` class, instances of which can be
|
||||||
|
combined into a directed graph. Associated with each node is a
|
||||||
|
`distribution_module`, a haiku module which builds a tensorflow
|
||||||
|
`Distribution` instance as a function of the node's parents.
|
||||||
|
* __util.py__: Miscellaneous utility functions.
|
||||||
|
* __variational.py__: Class for performing variational
|
||||||
|
inference. The `Variational` haiku module is a general-purpose approximate
|
||||||
|
posterior, using an MLP to map from arbitrary inputs to the parameters
|
||||||
|
of a Gaussian distribution.
|
||||||
|
* __adult.py__: Utility functions for the *Adult* dataset.
|
||||||
|
* __adult_pscf.py__: Training and 'fair' prediction process (using
|
||||||
|
path-specific counterfactual fairness) on the *Adult* dataset.
|
||||||
|
* __adult_pscf_config.py__: Configuration file with default training
|
||||||
|
parameters.
|
||||||
|
|
||||||
|
### Dataset
|
||||||
|
|
||||||
|
The *Adult* dataset can be downloaded from
|
||||||
|
[https://archive.ics.uci.edu/ml/datasets/adult](https://archive.ics.uci.edu/ml/datasets/adult). You may use the following
|
||||||
|
command to download both necessary files to run the training script:
|
||||||
|
|
||||||
|
`sh download_dataset.sh ${OUTPUT_DIR}`
|
||||||
|
|
||||||
|
### Experiments
|
||||||
|
|
||||||
|
To download the dataset and run the main experiment reported in the paper,
|
||||||
|
you may run:
|
||||||
|
|
||||||
|
`sh run_adult_pscf.sh`
|
||||||
|
|
||||||
|
### Acknowledgements
|
||||||
|
|
||||||
|
Credits to Thomas P. S. Gillam for the original TF1 implementation.
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Adult dataset.
|
||||||
|
|
||||||
|
See https://archive.ics.uci.edu/ml/datasets/adult.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl import logging
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
_COLUMNS = ('age', 'workclass', 'final-weight', 'education', 'education-num',
|
||||||
|
'marital-status', 'occupation', 'relationship', 'race', 'sex',
|
||||||
|
'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
|
||||||
|
'income')
|
||||||
|
_CATEGORICAL_COLUMNS = ('workclass', 'education', 'marital-status',
|
||||||
|
'occupation', 'race', 'relationship', 'sex',
|
||||||
|
'native-country', 'income')
|
||||||
|
|
||||||
|
|
||||||
|
def _read_data(
|
||||||
|
name,
|
||||||
|
data_path=''):
|
||||||
|
with os.path.join(data_path, name) as data_file:
|
||||||
|
data = pd.read_csv(data_file, header=None, index_col=False,
|
||||||
|
names=_COLUMNS, skipinitialspace=True, na_values='?')
|
||||||
|
for categorical in _CATEGORICAL_COLUMNS:
|
||||||
|
data[categorical] = data[categorical].astype('category')
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _combine_category_coding(df_1, df_2):
|
||||||
|
"""Combines the categories between dataframes df_1 and df_2.
|
||||||
|
|
||||||
|
This is used to ensure that training and test data use the same category
|
||||||
|
coding, so that the one-hot vectors representing the values are compatible
|
||||||
|
between training and test data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df_1: Pandas DataFrame.
|
||||||
|
df_2: Pandas DataFrame. Must have the same columns as df_1.
|
||||||
|
"""
|
||||||
|
for column in df_1.columns:
|
||||||
|
if df_1[column].dtype.name == 'category':
|
||||||
|
categories_1 = set(df_1[column].cat.categories)
|
||||||
|
categories_2 = set(df_2[column].cat.categories)
|
||||||
|
categories = sorted(categories_1 | categories_2)
|
||||||
|
df_1[column].cat.set_categories(categories, inplace=True)
|
||||||
|
df_2[column].cat.set_categories(categories, inplace=True)
|
||||||
|
|
||||||
|
|
||||||
|
def read_all_data(root_dir, remove_missing=True):
|
||||||
|
"""Return (train, test) dataframes, optionally removing incomplete rows."""
|
||||||
|
train_data = _read_data('adult.data', root_dir)
|
||||||
|
test_data = _read_data('adult.test', root_dir)
|
||||||
|
_combine_category_coding(train_data, test_data)
|
||||||
|
if remove_missing:
|
||||||
|
train_data = train_data.dropna()
|
||||||
|
test_data = test_data.dropna()
|
||||||
|
|
||||||
|
logging.info('Training data dtypes: %s', train_data.dtypes)
|
||||||
|
return train_data, test_data
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,65 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Config for adult PSCF experiment."""
|
||||||
|
|
||||||
|
import ml_collections
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
"""Return the default configuration."""
|
||||||
|
config = ml_collections.ConfigDict()
|
||||||
|
|
||||||
|
config.num_steps = 10000 # Number of training steps to perform.
|
||||||
|
config.batch_size = 128 # Batch size.
|
||||||
|
config.learning_rate = 0.01 # Learning rate
|
||||||
|
|
||||||
|
# Number of samples to draw for prediction.
|
||||||
|
config.num_prediction_samples = 500
|
||||||
|
|
||||||
|
# Batch size to use for prediction. Ideally as big as possible, but may need
|
||||||
|
# to be reduced for memory reasons depending on the value of
|
||||||
|
# `num_prediction_samples`.
|
||||||
|
config.prediction_batch_size = 500
|
||||||
|
|
||||||
|
# Multiplier for the likelihood term in the loss
|
||||||
|
config.likelihood_multiplier = 5.
|
||||||
|
|
||||||
|
# Multiplier for the MMD constraint term in the loss
|
||||||
|
config.constraint_multiplier = 0.
|
||||||
|
|
||||||
|
# Scaling factor to use in KL term.
|
||||||
|
config.beta = 1.0
|
||||||
|
|
||||||
|
# The number of samples we draw from each latent variable distribution.
|
||||||
|
config.mmd_sample_size = 100
|
||||||
|
|
||||||
|
# Directory into which results should be placed. By default it is the empty
|
||||||
|
# string, in which case no saving will occur. The directory specified will be
|
||||||
|
# created if it does not exist.
|
||||||
|
config.output_dir = ''
|
||||||
|
|
||||||
|
# The index of the step at which to turn on the constraint multiplier. For
|
||||||
|
# steps prior to this the multiplier will be zero.
|
||||||
|
config.constraint_turn_on_step = 0
|
||||||
|
|
||||||
|
# The random seed for tensorflow that is applied to the graph iff the value is
|
||||||
|
# non-negative. By default the seed is not constrained.
|
||||||
|
config.seed = -1
|
||||||
|
|
||||||
|
# When doing fair inference, don't sample when given a sample for the baseline
|
||||||
|
# gender.
|
||||||
|
config.baseline_passthrough = False
|
||||||
|
|
||||||
|
return config
|
||||||
@@ -0,0 +1,405 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Construct DAGs representing causal graphs, and perform inference on them."""
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import pandas as pd
|
||||||
|
from tensorflow_probability.substrates import jax as tfp
|
||||||
|
import tree
|
||||||
|
|
||||||
|
|
||||||
|
class Node:
|
||||||
|
"""A node in a graphical model.
|
||||||
|
|
||||||
|
Conceptually, this represents a random variable in a causal probabilistic
|
||||||
|
model. It knows about its 'parents', i.e. other Nodes upon which this Node
|
||||||
|
causally depends. The user is responsible for ensuring that any graph built
|
||||||
|
with this class is acyclic.
|
||||||
|
|
||||||
|
A node knows how to represent its probability density, conditional on the
|
||||||
|
values of its parents.
|
||||||
|
|
||||||
|
The node needs to have a 'name', corresponding to the series within the
|
||||||
|
dataframe that should be used to populate it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, distribution_module, parents=(), hidden=False):
|
||||||
|
"""Initialise a `Node` instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distribution_module: An instance of `DistributionModule`, a Haiku module
|
||||||
|
that is suitable for modelling the conditional distribution of this
|
||||||
|
node given any parents.
|
||||||
|
parents: `Iterable`, optional. A (potentially nested) collection of nodes
|
||||||
|
which are direct ancestors of `self`.
|
||||||
|
hidden: `bool`, optional. Whether this node is hidden. Hidden nodes are
|
||||||
|
permitted to not have corresponding observations.
|
||||||
|
"""
|
||||||
|
parents = tree.flatten(parents)
|
||||||
|
|
||||||
|
self._distribution_module = distribution_module
|
||||||
|
self._column = distribution_module.column
|
||||||
|
self._index = distribution_module.index
|
||||||
|
self._hidden = hidden
|
||||||
|
self._observed_value = None
|
||||||
|
|
||||||
|
# When implementing the path-specific counterfactual fairness algorithm,
|
||||||
|
# we need the concept of a distribution conditional on the 'corrected'
|
||||||
|
# values of the parents. This is achieved via the 'node_to_replacement'
|
||||||
|
# argument of make_distribution.
|
||||||
|
|
||||||
|
# However, in order to work with the `fix_categorical` and `fix_continuous`
|
||||||
|
# functions, we need to assign counterfactual values for parents at
|
||||||
|
# evaluation time.
|
||||||
|
self._parent_to_value = collections.OrderedDict(
|
||||||
|
(parent, None) for parent in parents)
|
||||||
|
|
||||||
|
# This is the conditional distribution using no replacements, i.e. it is
|
||||||
|
# conditioned on the observed values of parents.
|
||||||
|
self._distribution = None
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'Node<{}>'.format(self.name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dim(self):
|
||||||
|
"""The dimensionality of this node."""
|
||||||
|
return self._distribution_module.dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._column
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hidden(self):
|
||||||
|
return self._hidden
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observed_value(self):
|
||||||
|
return self._observed_value
|
||||||
|
|
||||||
|
def find_ancestor(self, name):
|
||||||
|
"""Returns an ancestor node with the given name."""
|
||||||
|
if self.name == name:
|
||||||
|
return self
|
||||||
|
for parent in self.parents:
|
||||||
|
found = parent.find_ancestor(name)
|
||||||
|
if found is not None:
|
||||||
|
return found
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parents(self):
|
||||||
|
return tuple(self._parent_to_value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def distribution_module(self):
|
||||||
|
return self._distribution_module
|
||||||
|
|
||||||
|
@property
|
||||||
|
def distribution(self):
|
||||||
|
self._distribution = self.make_distribution()
|
||||||
|
return self._distribution
|
||||||
|
|
||||||
|
def make_distribution(self, node_to_replacement=None):
|
||||||
|
"""Make a conditional distribution for this node | parents.
|
||||||
|
|
||||||
|
By default we use values (representing 'real data') from the parent
|
||||||
|
nodes as inputs to the distribution, however we can alternatively swap out
|
||||||
|
any of these for arbitrary arrays by specifying `node_to_replacement`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_to_replacement: `None`, `dict: Node -> DeviceArray`.
|
||||||
|
If specified, use the indicated array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tfp.distributions.Distribution`
|
||||||
|
"""
|
||||||
|
cur_parent_to_value = self._parent_to_value
|
||||||
|
self._parent_to_value = collections.OrderedDict(
|
||||||
|
(parent, parent.observed_value) for parent in cur_parent_to_value.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
if node_to_replacement is None:
|
||||||
|
parent_values = self._parent_to_value.values()
|
||||||
|
return self._distribution_module(*parent_values)
|
||||||
|
args = []
|
||||||
|
for node, value in self._parent_to_value.items():
|
||||||
|
if node in node_to_replacement:
|
||||||
|
replacement = node_to_replacement[node]
|
||||||
|
args.append(replacement)
|
||||||
|
else:
|
||||||
|
args.append(value)
|
||||||
|
return self._distribution_module(*args)
|
||||||
|
|
||||||
|
def populate(self, data, node_to_replacement=None):
|
||||||
|
"""Given a dataframe, populate node data.
|
||||||
|
|
||||||
|
If the Node does not have data present, this is taken to be a sign of
|
||||||
|
a) An error if the node is not hidden.
|
||||||
|
b) Fine if the node is hidden.
|
||||||
|
In case a) an exception will be raised, and in case b) observed)v will
|
||||||
|
not be mutated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: tf.data.Dataset
|
||||||
|
node_to_replacement: None | dict(Node -> array). If not None, use the
|
||||||
|
given ndarray data rather than extracting data from the frame. This is
|
||||||
|
only considered when looking at the inputs to a distribution.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If `data` doesn't contain the necessary feature, and the
|
||||||
|
node is not hidden.
|
||||||
|
"""
|
||||||
|
column = self._column
|
||||||
|
hidden = self._hidden
|
||||||
|
replacement = None
|
||||||
|
if node_to_replacement is not None and self in node_to_replacement:
|
||||||
|
replacement = node_to_replacement[self]
|
||||||
|
|
||||||
|
if replacement is not None:
|
||||||
|
# If a replacement is present, this takes priority over any other
|
||||||
|
# consideration.
|
||||||
|
self._observed_value = replacement
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._index < 0:
|
||||||
|
if not hidden:
|
||||||
|
raise RuntimeError(
|
||||||
|
'Node {} is not hidden, and column {} is not in the frame.'.format(
|
||||||
|
self, column))
|
||||||
|
# Nothing to do - there is no data, and the node is hidden.
|
||||||
|
return
|
||||||
|
|
||||||
|
# Produce the observed value for this node.
|
||||||
|
self._observed_value = self._distribution_module.prepare_data(data)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionModule(hk.Module):
|
||||||
|
"""Common base class for a Haiku module representing a distribution.
|
||||||
|
|
||||||
|
This provides some additional functionality common to all modules that would
|
||||||
|
be used as arguments to the `Node` class above.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, column, index, dim):
|
||||||
|
"""Initialise a `DistributionModule` instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: `string`. The name of the random variable to which this
|
||||||
|
distribution corresponds, and should match the name of the series in
|
||||||
|
the pandas dataframe.
|
||||||
|
index: `int`. The index of the corresponding feature in the dataset.
|
||||||
|
dim: `int`. The output dimensionality of the distribution.
|
||||||
|
"""
|
||||||
|
super().__init__(name=column.replace('-', '_'))
|
||||||
|
self._column = column
|
||||||
|
self._index = index
|
||||||
|
self._dim = dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dim(self):
|
||||||
|
"""The output dimensionality of this distribution."""
|
||||||
|
return self._dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def column(self):
|
||||||
|
return self._column
|
||||||
|
|
||||||
|
@property
|
||||||
|
def index(self):
|
||||||
|
return self._index
|
||||||
|
|
||||||
|
def prepare_data(self, data):
|
||||||
|
"""Given a general tensor, return an ndarray if required.
|
||||||
|
|
||||||
|
This method implements the functionality delegated from
|
||||||
|
`Node._prepare_data`, and it is expected that subclasses will override the
|
||||||
|
implementation appropriately.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: A tf.data.Dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`np.ndarray` of appropriately converted values for this series.
|
||||||
|
"""
|
||||||
|
return data[:, [self._index]]
|
||||||
|
|
||||||
|
def _package_args(self, args):
|
||||||
|
"""Concatenate args into a single tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: `List[DeviceArray]`, length > 0.
|
||||||
|
Each array is of shape (batch_size, ?) or (batch_size,). The former
|
||||||
|
will occur if looking at e.g. a one-hot encoded categorical variable,
|
||||||
|
and the latter in the case of a continuous variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`DeviceArray`, (batch_size, num_values).
|
||||||
|
"""
|
||||||
|
return jnp.concatenate(args, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
class Gaussian(DistributionModule):
|
||||||
|
"""A Haiku module that maps some inputs into a normal distribution."""
|
||||||
|
|
||||||
|
def __init__(self, column, index, dim=1, hidden_shape=(),
|
||||||
|
hidden_activation=jnp.tanh, scale=None):
|
||||||
|
"""Initialise a `Gaussian` instance with some dimensionality."""
|
||||||
|
super(Gaussian, self).__init__(column, index, dim)
|
||||||
|
self._hidden_shape = tuple(hidden_shape)
|
||||||
|
self._hidden_activation = hidden_activation
|
||||||
|
self._scale = scale
|
||||||
|
|
||||||
|
self._loc_net = hk.nets.MLP(self._hidden_shape + (self._dim,),
|
||||||
|
activation=self._hidden_activation)
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
if args:
|
||||||
|
# There are arguments - these represent the variables on which we are
|
||||||
|
# conditioning. We set the mean of the output distribution to be a
|
||||||
|
# function of these values, parameterised with an MLP.
|
||||||
|
concatenated_inputs = self._package_args(args)
|
||||||
|
loc = self._loc_net(concatenated_inputs)
|
||||||
|
else:
|
||||||
|
# There are no arguments, so instead have a learnable location parameter.
|
||||||
|
loc = hk.get_parameter('loc', shape=[self._dim], init=jnp.zeros)
|
||||||
|
|
||||||
|
if self._scale is None:
|
||||||
|
# The scale has not been explicitly specified, in which case it is left
|
||||||
|
# to be single value, i.e. not a function of the conditioning set.
|
||||||
|
log_var = hk.get_parameter('log_var', shape=[self._dim], init=jnp.ones)
|
||||||
|
scale = jnp.sqrt(jnp.exp(log_var))
|
||||||
|
else:
|
||||||
|
scale = jnp.float32(self._scale)
|
||||||
|
|
||||||
|
return tfp.distributions.Normal(loc=loc, scale=scale)
|
||||||
|
|
||||||
|
def prepare_data(self, data):
|
||||||
|
# For continuous data, we ensure the data is of dtype float32, and
|
||||||
|
# additionally that the resulant shape is (num_examples, 1)
|
||||||
|
# Note that this implementation only works for dim=1, however this is
|
||||||
|
# currently also enforced by the fact that pandas series cannot be
|
||||||
|
# multidimensional.
|
||||||
|
result = data[:, [self.index]].astype(jnp.float32)
|
||||||
|
if len(result.shape) == 1:
|
||||||
|
result = jnp.expand_dims(result, axis=1)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianMixture(DistributionModule):
|
||||||
|
"""A Haiku module that maps some inputs into a mixture of normals."""
|
||||||
|
|
||||||
|
def __init__(self, column, num_components, dim=1):
|
||||||
|
"""Initialise a `GaussianMixture` instance with some dimensionality.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: `string`. The name of the column.
|
||||||
|
num_components: `int`. The number of gaussians in this mixture.
|
||||||
|
dim: `int`. The dimensionality of the variable.
|
||||||
|
"""
|
||||||
|
super().__init__(column, -1, dim)
|
||||||
|
self._num_components = num_components
|
||||||
|
self._loc_net = hk.nets.MLP([self._dim])
|
||||||
|
self._categorical_logits_module = hk.nets.MLP([self._num_components])
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
# Define component Gaussians to be independent functions of args.
|
||||||
|
locs = []
|
||||||
|
scales = []
|
||||||
|
for _ in range(self._num_components):
|
||||||
|
loc = hk.get_parameter('loc', shape=[self._dim], init=jnp.zeros)
|
||||||
|
log_var = hk.get_parameter('log_var', shape=[self._dim], init=jnp.ones)
|
||||||
|
scale = jnp.sqrt(jnp.exp(log_var))
|
||||||
|
locs.extend(loc)
|
||||||
|
scales.extend(scale)
|
||||||
|
|
||||||
|
# Define the Categorical distribution which switches between these
|
||||||
|
categorical_logits = hk.get_parameter('categorical_logits',
|
||||||
|
shape=[self._num_components],
|
||||||
|
init=jnp.zeros)
|
||||||
|
|
||||||
|
# Enforce positivity in the logits
|
||||||
|
categorical_logits = jax.nn.sigmoid(categorical_logits)
|
||||||
|
|
||||||
|
# If we have a multidimensional node, then the normal distributions above
|
||||||
|
# have a batch shape of (dim,). We want to select between these using the
|
||||||
|
# categorical distribution, so tile the logits to match this shape
|
||||||
|
expanded_logits = jnp.repeat(categorical_logits, self._dim)
|
||||||
|
|
||||||
|
categorical = tfp.distributions.Categorical(logits=expanded_logits)
|
||||||
|
|
||||||
|
return tfp.distributions.MixtureSameFamily(
|
||||||
|
mixture_distribution=categorical,
|
||||||
|
components_distribution=tfp.distributions.Normal(
|
||||||
|
loc=locs, scale=scales))
|
||||||
|
|
||||||
|
|
||||||
|
class MLPMultinomial(DistributionModule):
|
||||||
|
"""A Haiku module that consists of an MLP + multinomial distribution."""
|
||||||
|
|
||||||
|
def __init__(self, column, index, dim, hidden_shape=(),
|
||||||
|
hidden_activation=jnp.tanh):
|
||||||
|
"""Initialise an MLPMultinomial instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: `string`. Name of the corresponding dataframe column.
|
||||||
|
index: `int`. The index of the input data for this feature.
|
||||||
|
dim: `int`. Number of categories.
|
||||||
|
hidden_shape: `Iterable`, optional. Shape of hidden layers.
|
||||||
|
hidden_activation: `Callable`, optional. Non-linearity for hidden
|
||||||
|
layers.
|
||||||
|
"""
|
||||||
|
super(MLPMultinomial, self).__init__(column, index, dim)
|
||||||
|
self._hidden_shape = tuple(hidden_shape)
|
||||||
|
self._hidden_activation = hidden_activation
|
||||||
|
self._logit_net = hk.nets.MLP(self._hidden_shape + (self.dim,),
|
||||||
|
activation=self._hidden_activation)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_frame(cls, data, column, hidden_shape=()):
|
||||||
|
"""Create an MLPMultinomial instance from a pandas dataframe and column."""
|
||||||
|
# Helper method that uses the dataframe to work out how many categories
|
||||||
|
# are in the given column. The dataframe is not used for any other purpose.
|
||||||
|
if not isinstance(data[column].dtype, pd.api.types.CategoricalDtype):
|
||||||
|
raise ValueError('{} is not categorical.'.format(column))
|
||||||
|
index = list(data.columns).index(column)
|
||||||
|
num_categories = len(data[column].cat.categories)
|
||||||
|
return cls(column, index, num_categories, hidden_shape)
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
if args:
|
||||||
|
concatenated_inputs = self._package_args(args)
|
||||||
|
logits = self._logit_net(concatenated_inputs)
|
||||||
|
else:
|
||||||
|
logits = hk.get_parameter('b', shape=[self.dim], init=jnp.zeros)
|
||||||
|
return tfp.distributions.Multinomial(logits=logits, total_count=1.0)
|
||||||
|
|
||||||
|
def prepare_data(self, data):
|
||||||
|
# For categorical data, we convert to a one-hot representation using the
|
||||||
|
# pandas category 'codes'. These are integers, and will have a definite
|
||||||
|
# ordering that is identical between runs.
|
||||||
|
codes = data[:, self.index]
|
||||||
|
codes = codes.astype(jnp.int32)
|
||||||
|
return jnp.eye(self.dim)[codes]
|
||||||
|
|
||||||
|
|
||||||
|
def populate(nodes, dataframe, node_to_replacement=None):
|
||||||
|
"""Populate observed values for nodes."""
|
||||||
|
for node in nodes:
|
||||||
|
node.populate(dataframe, node_to_replacement=node_to_replacement)
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# sh download_dataset.sh ${OUTPUT_DIR}
|
||||||
|
# Example:
|
||||||
|
# sh download_dataset.sh uci_adult
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
OUTPUT_DIR="${1}"
|
||||||
|
|
||||||
|
BASE_URL="https://archive.ics.uci.edu/ml/machine-learning-databases/adult/"
|
||||||
|
|
||||||
|
mkdir -p ${OUTPUT_DIR}
|
||||||
|
for file in adult.data adult.test
|
||||||
|
do
|
||||||
|
wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}"
|
||||||
|
done
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
wheel
|
||||||
|
absl-py>=0.12.0
|
||||||
|
dm-haiku>=0.0.4
|
||||||
|
optax>=0.0.8
|
||||||
|
ml_collections
|
||||||
|
numpy>=1.16.4
|
||||||
|
tensorflow>=2.5.0
|
||||||
|
tensorflow-datasets>=4.3.0
|
||||||
|
tensorflow_probability
|
||||||
|
pandas
|
||||||
|
sklearn
|
||||||
|
jaxlib>=0.1.67+cuda110
|
||||||
Executable
+40
@@ -0,0 +1,40 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Fail on any error.
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Display commands being run.
|
||||||
|
set -x
|
||||||
|
|
||||||
|
TMP_DIR=`mktemp -d`
|
||||||
|
|
||||||
|
virtualenv --python=python3.6 "${TMP_DIR}/env"
|
||||||
|
source "${TMP_DIR}/env/bin/activate"
|
||||||
|
|
||||||
|
# Install dependencies.
|
||||||
|
pip install --upgrade -r counterfactual_fairness/requirements.txt
|
||||||
|
|
||||||
|
# Download minimal dataset
|
||||||
|
DATA_DIR="${TMP_DIR}/adult"
|
||||||
|
bash counterfactual_fairness/download_dataset.sh ${TMP_DIR}
|
||||||
|
|
||||||
|
# Train for a few steps.
|
||||||
|
python -m counterfactual_fairness.adult_pscf --config="counterfactual_fairness/adult_pscf_config.py" --dataset_dir=${DATA_DIR} --num_steps=10
|
||||||
|
|
||||||
|
# Clean up.
|
||||||
|
rm -r ${TMP_DIR}
|
||||||
|
echo "Test run complete."
|
||||||
@@ -0,0 +1,311 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Common utilities."""
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from jax import random
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import pandas as pd
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_probability.substrates import jax as tfp
|
||||||
|
|
||||||
|
tfd = tfp.distributions
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(dataset: pd.DataFrame,
|
||||||
|
batch_size: int,
|
||||||
|
shuffle_size: int = 10000,
|
||||||
|
num_epochs: Optional[int] = None) -> tf.data.Dataset:
|
||||||
|
"""Makes a tf.Dataset with correct preprocessing."""
|
||||||
|
dataset_copy = dataset.copy()
|
||||||
|
for column in dataset.columns:
|
||||||
|
if dataset[column].dtype.name == 'category':
|
||||||
|
dataset_copy.loc[:, column] = dataset[column].cat.codes
|
||||||
|
ds = tf.data.Dataset.from_tensor_slices(dataset_copy.values)
|
||||||
|
if shuffle_size > 0:
|
||||||
|
ds = ds.shuffle(shuffle_size, reshuffle_each_iteration=True)
|
||||||
|
ds = ds.repeat(num_epochs)
|
||||||
|
return ds.batch(batch_size, drop_remainder=True)
|
||||||
|
|
||||||
|
|
||||||
|
def multinomial_mode(
|
||||||
|
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
|
||||||
|
) -> jnp.DeviceArray:
|
||||||
|
"""Calculates the (one-hot) mode of a multinomial distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distribution_or_probs:
|
||||||
|
`tfp.distributions.Distribution` | List[tensors].
|
||||||
|
If the former, it is assumed that it has a `probs` property, and
|
||||||
|
represents a distribution over categories. If the latter, these are
|
||||||
|
taken to be the probabilities of categories directly.
|
||||||
|
In either case, it is assumed that `probs` will be shape
|
||||||
|
(batch_size, dim).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`DeviceArray`, float32, (batch_size, dim).
|
||||||
|
The mode of the distribution - this will be in one-hot form, but contain
|
||||||
|
multiple non-zero entries in the event that more than one probability is
|
||||||
|
joint-highest.
|
||||||
|
"""
|
||||||
|
if isinstance(distribution_or_probs, tfd.Distribution):
|
||||||
|
probs = distribution_or_probs.probs_parameter()
|
||||||
|
else:
|
||||||
|
probs = distribution_or_probs
|
||||||
|
max_prob = jnp.max(probs, axis=1, keepdims=True)
|
||||||
|
mode = jnp.int32(jnp.equal(probs, max_prob))
|
||||||
|
return jnp.float32(mode / jnp.sum(mode, axis=1, keepdims=True))
|
||||||
|
|
||||||
|
|
||||||
|
def multinomial_class(
|
||||||
|
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
|
||||||
|
) -> jnp.DeviceArray:
|
||||||
|
"""Computes the mode class of a multinomial distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distribution_or_probs:
|
||||||
|
`tfp.distributions.Distribution` | DeviceArray.
|
||||||
|
As for `multinomial_mode`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`DeviceArray`, float32, (batch_size,).
|
||||||
|
For each element in the batch, the index of the class with highest
|
||||||
|
probability.
|
||||||
|
"""
|
||||||
|
if isinstance(distribution_or_probs, tfd.Distribution):
|
||||||
|
return jnp.argmax(distribution_or_probs.logits_parameter(), axis=1)
|
||||||
|
return jnp.argmax(distribution_or_probs, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
|
||||||
|
"""Calculates the (one-hot) mode from an ndarray of class probabilities.
|
||||||
|
|
||||||
|
Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays
|
||||||
|
rather than Tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
probs: `DeviceArray`, (batch_size, dim). Probabilities for each class, for
|
||||||
|
each element in a batch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`DeviceArray`, (batch_size, dim).
|
||||||
|
"""
|
||||||
|
max_prob = jnp.amax(probs, axis=1, keepdims=True)
|
||||||
|
mode = jnp.equal(probs, max_prob).astype(jnp.int32)
|
||||||
|
return (mode / jnp.sum(mode, axis=1, keepdims=True)).astype(jnp.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
|
||||||
|
data: jnp.DeviceArray) -> jnp.DeviceArray:
|
||||||
|
"""Compute the accuracy, averaged over a batch of data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distribution_or_probs:
|
||||||
|
`tfp.distributions.Distribution` | List[tensors].
|
||||||
|
As for functions above.
|
||||||
|
data: `DeviceArray`. Reference data, of shape (batch_size, dim).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`DeviceArray`, float32, ().
|
||||||
|
Overall scalar accuracy.
|
||||||
|
"""
|
||||||
|
return jnp.mean(
|
||||||
|
jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_ndarray(logits: jnp.DeviceArray) -> jnp.DeviceArray:
|
||||||
|
"""Softmax function, implemented for numpy ndarrays."""
|
||||||
|
assert len(logits.shape) == 2
|
||||||
|
# Normalise for better stability.
|
||||||
|
s = jnp.max(logits, axis=1, keepdims=True)
|
||||||
|
e_x = jnp.exp(logits - s)
|
||||||
|
return e_x / jnp.sum(e_x, axis=1, keepdims=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_samples(distribution, num_samples, seed=None):
|
||||||
|
"""Given a batched distribution, compute samples and reshape along batch.
|
||||||
|
|
||||||
|
That is, we have a distribution of shape (batch_size, ...), where each element
|
||||||
|
of the tensor is independent. We then draw num_samples from each component, to
|
||||||
|
give a tensor of shape:
|
||||||
|
|
||||||
|
(num_samples, batch_size, ...)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distribution: `tfp.distributions.Distribution`. The distribution from which
|
||||||
|
to sample.
|
||||||
|
num_samples: `Integral` | `DeviceArray`, int32, (). The number of samples.
|
||||||
|
seed: `Integral` | `None`. The seed that will be forwarded to the call to
|
||||||
|
distribution.sample. Defaults to `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`DeviceArray`, float32, (batch_size * num_samples, ...).
|
||||||
|
Samples for each element of the batch.
|
||||||
|
"""
|
||||||
|
# Obtain the sample from the distribution, which will be of shape
|
||||||
|
# [num_samples] + batch_shape + event_shape.
|
||||||
|
sample = distribution.sample(num_samples, seed=seed)
|
||||||
|
sample = sample.reshape((-1, sample.shape[-1]))
|
||||||
|
|
||||||
|
# Combine the first two dimensions through a reshape, so the result will
|
||||||
|
# be of shape (num_samples * batch_size,) + shape_tail.
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def mmd_loss(distribution: tfd.Distribution,
|
||||||
|
is_a: jnp.DeviceArray,
|
||||||
|
num_samples: int,
|
||||||
|
rng: jnp.ndarray,
|
||||||
|
num_random_features: int = 50,
|
||||||
|
gamma: float = 1.):
|
||||||
|
"""Given two distributions, compute the Maximum Mean Discrepancy (MMD).
|
||||||
|
|
||||||
|
More exactly, this uses the 'FastMMD' approximation, a.k.a. 'Random Fourier
|
||||||
|
Features'. See the description, for example, in sections 2.3.1 and 2.4 of
|
||||||
|
https://arxiv.org/pdf/1511.00830.pdf.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distribution: Distribution whose `sample()` method will return a
|
||||||
|
DeviceArray of shape (batch_size, dim).
|
||||||
|
is_a: A boolean array indicating which elements of the batch correspond
|
||||||
|
to class A (the remaining indices correspond to class B).
|
||||||
|
num_samples: The number of samples to draw from `distribution`.
|
||||||
|
rng: Random seed provided by the user.
|
||||||
|
num_random_features: The number of random fourier features
|
||||||
|
used in the expansion.
|
||||||
|
gamma: The value of gamma in the Gaussian MMD kernel.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`DeviceArray`, shape ().
|
||||||
|
The scalar MMD value for samples taken from the given distributions.
|
||||||
|
"""
|
||||||
|
if distribution.event_shape == (): # pylint: disable=g-explicit-bool-comparison
|
||||||
|
dim_x = distribution.batch_shape[1]
|
||||||
|
else:
|
||||||
|
dim_x, = distribution.event_shape
|
||||||
|
|
||||||
|
# Obtain samples from the distribution, which will be of shape
|
||||||
|
# [num_samples] + batch_shape + event_shape.
|
||||||
|
samples = distribution.sample(num_samples, seed=rng)
|
||||||
|
|
||||||
|
w = random.normal(rng, shape=((dim_x, num_random_features)))
|
||||||
|
b = random.uniform(rng, shape=(num_random_features,),
|
||||||
|
minval=0, maxval=2*jnp.pi)
|
||||||
|
|
||||||
|
def features(x):
|
||||||
|
"""Compute the kitchen sink feature."""
|
||||||
|
# We need to contract last axis of x with first of W - do this with
|
||||||
|
# tensordot. The result has shape:
|
||||||
|
# (?, ?, num_random_features)
|
||||||
|
return jnp.sqrt(2 / num_random_features) * jnp.cos(
|
||||||
|
jnp.sqrt(2 / gamma) * jnp.tensordot(x, w, axes=1) + b)
|
||||||
|
|
||||||
|
# Compute the expected values of the given features.
|
||||||
|
# The first axis represents the samples from the distribution,
|
||||||
|
# second axis represents the batch_size.
|
||||||
|
# Each of these now has shape (num_random_features,)
|
||||||
|
exp_features = features(samples)
|
||||||
|
# Swap axes so that batch_size is the last dimension to be compatible
|
||||||
|
# with is_a and is_b shape at the next step
|
||||||
|
exp_features_reshaped = jnp.swapaxes(exp_features, 1, 2)
|
||||||
|
# Current dimensions [num_samples, num_random_features, batch_size]
|
||||||
|
exp_features_reshaped_a = jnp.where(is_a, exp_features_reshaped, 0)
|
||||||
|
exp_features_reshaped_b = jnp.where(is_a, 0, exp_features_reshaped)
|
||||||
|
exp_features_a = jnp.mean(exp_features_reshaped_a, axis=(0, 2))
|
||||||
|
exp_features_b = jnp.mean(exp_features_reshaped_b, axis=(0, 2))
|
||||||
|
|
||||||
|
assert exp_features_a.shape == (num_random_features,)
|
||||||
|
difference = exp_features_a - exp_features_b
|
||||||
|
|
||||||
|
# Compute the squared norm. Shape ().
|
||||||
|
return jnp.tensordot(difference, difference, axes=1)
|
||||||
|
|
||||||
|
|
||||||
|
def mmd_loss_exact(distribution_a, distribution_b, num_samples, gamma=1.):
|
||||||
|
"""Exact estimate of MMD."""
|
||||||
|
assert distribution_a.event_shape == distribution_b.event_shape
|
||||||
|
assert distribution_a.batch_shape[1:] == distribution_b.batch_shape[1:]
|
||||||
|
|
||||||
|
# shape (num_samples * batch_size_a, dim_x)
|
||||||
|
samples_a = get_samples(distribution_a, num_samples)
|
||||||
|
# shape (num_samples * batch_size_b, dim_x)
|
||||||
|
samples_b = get_samples(distribution_b, num_samples)
|
||||||
|
|
||||||
|
# Make matrices of shape
|
||||||
|
# (size_b, size_a, dim_x)
|
||||||
|
# where:
|
||||||
|
# size_a = num_samples * batch_size_a
|
||||||
|
# size_b = num_samples * batch_size_b
|
||||||
|
size_a = samples_a.shape[0]
|
||||||
|
size_b = samples_b.shape[0]
|
||||||
|
x_a = jnp.expand_dims(samples_a, axis=0)
|
||||||
|
x_a = jnp.tile(x_a, (size_b, 1, 1))
|
||||||
|
x_b = jnp.expand_dims(samples_b, axis=1)
|
||||||
|
x_b = jnp.tile(x_b, (1, size_a, 1))
|
||||||
|
|
||||||
|
def kernel_mean(x, y):
|
||||||
|
"""Gaussian kernel mean."""
|
||||||
|
|
||||||
|
diff = x - y
|
||||||
|
|
||||||
|
# Contract over dim_x.
|
||||||
|
exponent = - jnp.einsum('ijk,ijk->ij', diff, diff) / gamma
|
||||||
|
|
||||||
|
# This has shape (size_b, size_a).
|
||||||
|
kernel_matrix = jnp.exp(exponent)
|
||||||
|
|
||||||
|
# Shape ().
|
||||||
|
return jnp.mean(kernel_matrix)
|
||||||
|
|
||||||
|
# Equation 7 from arxiv 1511.00830
|
||||||
|
return (
|
||||||
|
kernel_mean(x_a, x_a)
|
||||||
|
+ kernel_mean(x_b, x_b)
|
||||||
|
- 2 * kernel_mean(x_a, x_b))
|
||||||
|
|
||||||
|
|
||||||
|
def scalar_log_prob(distribution, val):
|
||||||
|
"""Compute the log_prob per batch entry.
|
||||||
|
|
||||||
|
It is conceptually similar to:
|
||||||
|
|
||||||
|
jnp.sum(distribution.log_prob(val), axis=1)
|
||||||
|
|
||||||
|
However, classes like `tfp.distributions.Multinomial` have a log_prob which
|
||||||
|
returns a tensor of shape (batch_size,), which will cause the above
|
||||||
|
incantation to fail. In these cases we fall back to returning just:
|
||||||
|
|
||||||
|
distribution.log_prob(val)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distribution: `tfp.distributions.Distribution` which implements log_prob.
|
||||||
|
val: `DeviceArray`, (batch_size, dim).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`DeviceArray`, (batch_size,).
|
||||||
|
If the result of log_prob has a trailing dimension, we perform a reduce_sum
|
||||||
|
over it.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If distribution.log_prob(val) has an unsupported shape.
|
||||||
|
"""
|
||||||
|
log_prob_val = distribution.log_prob(val)
|
||||||
|
if len(log_prob_val.shape) == 1:
|
||||||
|
return log_prob_val
|
||||||
|
elif len(log_prob_val.shape) > 2:
|
||||||
|
raise ValueError('log_prob_val has unexpected shape {}.'.format(
|
||||||
|
log_prob_val.shape))
|
||||||
|
return jnp.sum(log_prob_val, axis=1)
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Functions and classes for performing variational inference."""
|
||||||
|
|
||||||
|
from typing import Callable, Iterable, Optional
|
||||||
|
|
||||||
|
import haiku as hk
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from tensorflow_probability.substrates import jax as tfp
|
||||||
|
|
||||||
|
tfd = tfp.distributions
|
||||||
|
|
||||||
|
|
||||||
|
class Variational(hk.Module):
|
||||||
|
"""A module representing the variational distribution q(H | *O).
|
||||||
|
|
||||||
|
H is assumed to be a continuous variable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
common_layer_sizes: Iterable[int],
|
||||||
|
activation: Callable[[jnp.DeviceArray],
|
||||||
|
jnp.DeviceArray] = jnp.tanh,
|
||||||
|
output_dim: int = 1,
|
||||||
|
name: Optional[str] = None):
|
||||||
|
"""Initialises a `Variational` instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
common_layer_sizes: The number of hidden units in the shared dense
|
||||||
|
network layers.
|
||||||
|
activation: Nonlinearity function to apply to each of the
|
||||||
|
common layers.
|
||||||
|
output_dim: The dimensionality of `H`.
|
||||||
|
name: A name to assign to the module instance.
|
||||||
|
"""
|
||||||
|
super().__init__(name=name)
|
||||||
|
self._common_layer_sizes = common_layer_sizes
|
||||||
|
self._activation = activation
|
||||||
|
self._output_dim = output_dim
|
||||||
|
|
||||||
|
self._linear_layers = [
|
||||||
|
hk.Linear(layer_size)
|
||||||
|
for layer_size in self._common_layer_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
self._mean_output = hk.Linear(self._output_dim)
|
||||||
|
self._log_var_output = hk.Linear(self._output_dim)
|
||||||
|
|
||||||
|
def __call__(self, *args) -> tfd.Distribution:
|
||||||
|
"""Create a distribution for q(H | *O).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: `List[DeviceArray]`. Corresponds to the values of whatever
|
||||||
|
variables are in the conditional set *O.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tfp.distributions.NormalDistribution` instance.
|
||||||
|
"""
|
||||||
|
# Stack all inputs, ensuring that shapes are consistent and that they are
|
||||||
|
# all of dtype float32.
|
||||||
|
input_ = [hk.Flatten()(arg) for arg in args]
|
||||||
|
input_ = jnp.concatenate(input_, axis=1)
|
||||||
|
|
||||||
|
# Create a common set of layers, then final layer separates mean & log_var
|
||||||
|
for layer in self._linear_layers:
|
||||||
|
input_ = layer(input_)
|
||||||
|
input_ = self._activation(input_)
|
||||||
|
|
||||||
|
# input_ now represents a tensor of shape (batch_size, final_layer_size).
|
||||||
|
# This is now put through two final layers, one for the computation of each
|
||||||
|
# of the mean and standard deviation of the resultant distribution.
|
||||||
|
mean = self._mean_output(input_)
|
||||||
|
log_var = self._log_var_output(input_)
|
||||||
|
std = jnp.sqrt(jnp.exp(log_var))
|
||||||
|
return tfd.Normal(mean, std)
|
||||||
Reference in New Issue
Block a user