mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-23 07:45:18 +08:00
Adding counterfactual_fairness to deepmind-research repo
PiperOrigin-RevId: 396665335
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
f49eb487a7
commit
cff83be778
@@ -60,6 +60,7 @@ https://deepmind.com/research/publications/
|
||||
* [Disentangling by Subspace Diffusion (GEOMANCER)](geomancer), NeurIPS 2020
|
||||
* [What can I do here? A theory of affordances in reinforcement learning](affordances_theory), ICML 2020
|
||||
* [Scaling data-driven robotics with reward sketching and batch reinforcement learning](sketchy), RSS 2020
|
||||
* [Path-Specific Counterfactual Fairness](counterfactual_fairness), AAAI 2019
|
||||
* [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019
|
||||
* [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020
|
||||
* [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
# Path-Specific Counterfactual Fairness (AAAI 2019)
|
||||
|
||||
Paper: [Path-Specific Counterfactual Fairness](https://ojs.aaai.org/index.php/AAAI/article/view/4777/4655)
|
||||
|
||||
If you use the code here please cite this paper:
|
||||
|
||||
@inproceedings{chiappa2019path,
|
||||
title={Path-specific counterfactual fairness},
|
||||
author={Chiappa, Silvia},
|
||||
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
||||
volume={33},
|
||||
number={01},
|
||||
pages={7801--7808},
|
||||
year={2019}
|
||||
}
|
||||
|
||||
### Overview
|
||||
|
||||
This release contains the path-specific counterfactual fairness method used
|
||||
in the paper, as well as utility functions for loading the *Adult* dataset.
|
||||
|
||||
The following gives a brief overview of the contents; more detailed
|
||||
documentation is available within each file:
|
||||
|
||||
* __causal_network.py__: Defines the `Node` class, instances of which can be
|
||||
combined into a directed graph. Associated with each node is a
|
||||
`distribution_module`, a Haiku module which builds a TensorFlow Probability
|
||||
`Distribution` instance as a function of the node's parents.
|
||||
* __util.py__: Miscellaneous utility functions.
|
||||
* __variational.py__: Class for performing variational
|
||||
inference. The `Variational` haiku module is a general-purpose approximate
|
||||
posterior, using an MLP to map from arbitrary inputs to the parameters
|
||||
of a Gaussian distribution.
|
||||
* __adult.py__: Utility functions for the *Adult* dataset.
|
||||
* __adult_pscf.py__: Training and 'fair' prediction process (using
|
||||
path-specific counterfactual fairness) on the *Adult* dataset.
|
||||
* __adult_pscf_config.py__: Configuration file with default training
|
||||
parameters.
|
||||
|
||||
### Dataset
|
||||
|
||||
The *Adult* dataset can be downloaded from
|
||||
[https://archive.ics.uci.edu/ml/datasets/adult](https://archive.ics.uci.edu/ml/datasets/adult). You may use the following
|
||||
command to download both necessary files to run the training script:
|
||||
|
||||
`sh download_dataset.sh ${OUTPUT_DIR}`
|
||||
|
||||
### Experiments
|
||||
|
||||
To download the dataset and run the main experiment reported in the paper,
|
||||
you may run:
|
||||
|
||||
`bash run_adult_pscf.sh`
|
||||
|
||||
### Acknowledgements
|
||||
|
||||
Credits to Thomas P. S. Gillam for the original TF1 implementation.
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Adult dataset.
|
||||
|
||||
See https://archive.ics.uci.edu/ml/datasets/adult.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from absl import logging
|
||||
import pandas as pd
|
||||
|
||||
|
||||
|
||||
|
||||
# Column names, in file order, for the header-less Adult CSV files.
_COLUMNS = ('age', 'workclass', 'final-weight', 'education', 'education-num',
            'marital-status', 'occupation', 'relationship', 'race', 'sex',
            'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
            'income')
# Subset of `_COLUMNS` that is converted to the pandas 'category' dtype.
_CATEGORICAL_COLUMNS = ('workclass', 'education', 'marital-status',
                        'occupation', 'race', 'relationship', 'sex',
                        'native-country', 'income')


def _read_data(
    name,
    data_path=''):
  """Reads one Adult CSV file into a dataframe.

  Args:
    name: File name of the CSV file, e.g. 'adult.data'.
    data_path: Directory containing the file. Defaults to the current
      directory.

  Returns:
    A `pd.DataFrame` whose columns are named after `_COLUMNS`. Entries equal to
    '?' are read as missing values, and every column listed in
    `_CATEGORICAL_COLUMNS` has the 'category' dtype.
  """
  # `os.path.join` only builds a path string; the file has to be opened
  # explicitly before it can be used as a context manager.
  with open(os.path.join(data_path, name)) as data_file:
    data = pd.read_csv(data_file, header=None, index_col=False,
                       names=_COLUMNS, skipinitialspace=True, na_values='?')
  for categorical in _CATEGORICAL_COLUMNS:
    data[categorical] = data[categorical].astype('category')
  return data
|
||||
|
||||
|
||||
def _combine_category_coding(df_1, df_2):
  """Combines the categories between dataframes df_1 and df_2.

  This is used to ensure that training and test data use the same category
  coding, so that the one-hot vectors representing the values are compatible
  between training and test data.

  Both dataframes are modified in place: each categorical column is replaced
  by a copy whose categories are the sorted union of the categories found in
  the two frames.

  Args:
    df_1: Pandas DataFrame.
    df_2: Pandas DataFrame. Must have the same columns as df_1.
  """
  for column in df_1.columns:
    if df_1[column].dtype.name != 'category':
      continue
    categories = sorted(
        set(df_1[column].cat.categories) | set(df_2[column].cat.categories))
    # `set_categories(..., inplace=True)` is not available in pandas >= 2.0,
    # so assign the re-coded column back to the frame instead.
    df_1[column] = df_1[column].cat.set_categories(categories)
    df_2[column] = df_2[column].cat.set_categories(categories)
|
||||
|
||||
|
||||
def read_all_data(root_dir, remove_missing=True):
  """Loads the Adult train and test splits with a shared category coding.

  Args:
    root_dir: Directory containing the 'adult.data' and 'adult.test' files.
    remove_missing: If true, rows containing any missing value are dropped
      from both splits after the category coding has been unified.

  Returns:
    A `(train, test)` tuple of dataframes.
  """
  train_data, test_data = (
      _read_data(file_name, root_dir)
      for file_name in ('adult.data', 'adult.test'))
  _combine_category_coding(train_data, test_data)

  if remove_missing:
    train_data, test_data = train_data.dropna(), test_data.dropna()

  logging.info('Training data dtypes: %s', train_data.dtypes)
  return train_data, test_data
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,65 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Config for adult PSCF experiment."""
|
||||
|
||||
import ml_collections
|
||||
|
||||
|
||||
def get_config():
  """Builds the default configuration for the Adult PSCF experiment.

  Returns:
    An `ml_collections.ConfigDict` holding the default hyper-parameters.
  """
  defaults = {
      # Optimisation: number of training steps, batch size and learning rate.
      'num_steps': 10000,
      'batch_size': 128,
      'learning_rate': 0.01,
      # How many samples are drawn when making a prediction.
      'num_prediction_samples': 500,
      # Prediction batch size. Larger is better, but memory use grows with
      # `num_prediction_samples`, so it may need to be lowered.
      'prediction_batch_size': 500,
      # Weight of the likelihood term in the loss.
      'likelihood_multiplier': 5.,
      # Weight of the MMD constraint term in the loss.
      'constraint_multiplier': 0.,
      # Scaling factor applied to the KL term.
      'beta': 1.0,
      # Number of samples drawn from each latent variable distribution.
      'mmd_sample_size': 100,
      # Output directory for results. The empty string (the default) disables
      # saving; a non-empty directory is created if it does not exist.
      'output_dir': '',
      # Step index from which the constraint multiplier is active; before this
      # step the multiplier is zero.
      'constraint_turn_on_step': 0,
      # Random seed, applied only when non-negative. Unconstrained by default.
      'seed': -1,
      # During fair inference, skip sampling when the given sample already
      # belongs to the baseline gender.
      'baseline_passthrough': False,
  }
  return ml_collections.ConfigDict(defaults)
|
||||
@@ -0,0 +1,405 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Construct DAGs representing causal graphs, and perform inference on them."""
|
||||
|
||||
import collections
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import pandas as pd
|
||||
from tensorflow_probability.substrates import jax as tfp
|
||||
import tree
|
||||
|
||||
|
||||
class Node:
  """A random variable in a causal graphical model.

  Each node keeps references to its parents, i.e. the nodes on which it
  causally depends, and can build its conditional distribution given the
  values of those parents. Nothing here checks for cycles: building an acyclic
  graph is the caller's responsibility.

  The node's `name` is the column name of its distribution module, which
  identifies the data series used to populate the node.
  """

  def __init__(self, distribution_module, parents=(), hidden=False):
    """Initialise a `Node` instance.

    Args:
      distribution_module: An instance of `DistributionModule`, a Haiku module
        modelling the conditional distribution of this node given its parents.
      parents: `Iterable`, optional. A (potentially nested) collection of the
        nodes that are direct ancestors of `self`; it is flattened before use.
      hidden: `bool`, optional. Whether this node is hidden. Hidden nodes are
        allowed to have no corresponding observations.
    """
    flat_parents = tree.flatten(parents)

    self._distribution_module = distribution_module
    self._column = distribution_module.column
    self._index = distribution_module.index
    self._hidden = hidden
    self._observed_value = None

    # Path-specific counterfactual fairness needs distributions conditioned on
    # 'corrected' parent values, which `make_distribution` supports through
    # its `node_to_replacement` argument. To cooperate with the
    # `fix_categorical` and `fix_continuous` functions, the values of the
    # parents are (re)assigned at evaluation time, so each parent starts out
    # mapped to `None`.
    self._parent_to_value = collections.OrderedDict.fromkeys(flat_parents)

    # Conditional distribution built without replacements, i.e. conditioned on
    # the observed values of the parents.
    self._distribution = None

  def __repr__(self):
    return f'Node<{self.name}>'

  @property
  def dim(self):
    """The dimensionality of this node."""
    return self._distribution_module.dim

  @property
  def name(self):
    """The column name taken from the distribution module."""
    return self._column

  @property
  def hidden(self):
    """Whether this node is hidden."""
    return self._hidden

  @property
  def observed_value(self):
    """The value most recently stored by `populate`, or `None`."""
    return self._observed_value

  def find_ancestor(self, name):
    """Returns this node or an ancestor with the given name, else `None`.

    The search is depth-first over the parents, in their stored order.

    Args:
      name: The node name to look for.
    """
    if name == self.name:
      return self
    for parent in self.parents:
      match = parent.find_ancestor(name)
      if match is not None:
        return match
    return None

  @property
  def parents(self):
    """The direct parents of this node, as a tuple."""
    return tuple(self._parent_to_value)

  @property
  def distribution_module(self):
    """The module used to build this node's distribution."""
    return self._distribution_module

  @property
  def distribution(self):
    """A freshly built distribution conditioned on the parents' observations."""
    built = self.make_distribution()
    self._distribution = built
    return built

  def make_distribution(self, node_to_replacement=None):
    """Make a conditional distribution for this node | parents.

    The observed values of the parent nodes are used as inputs by default.
    Any of them can be swapped for an arbitrary array through
    `node_to_replacement`.

    Args:
      node_to_replacement: `None`, `dict: Node -> DeviceArray`. Parents that
        appear as keys are fed the mapped array instead of their observed
        value.

    Returns:
      `tfp.distributions.Distribution`
    """
    # Refresh the cached parent values from the parents' current observations.
    refreshed = collections.OrderedDict()
    for parent in self._parent_to_value:
      refreshed[parent] = parent.observed_value
    self._parent_to_value = refreshed

    if node_to_replacement is None:
      return self._distribution_module(*refreshed.values())

    inputs = [
        node_to_replacement[parent] if parent in node_to_replacement
        else observed
        for parent, observed in refreshed.items()
    ]
    return self._distribution_module(*inputs)

  def populate(self, data, node_to_replacement=None):
    """Sets the observed value of this node.

    A non-`None` replacement for this node always wins. Otherwise, a node
    whose index is negative has no data: this is an error for a visible node,
    while for a hidden node `observed_value` is left untouched. In every other
    case the distribution module extracts the value from `data`.

    Args:
      data: tf.data.Dataset
      node_to_replacement: None | dict(Node -> array). If this node is a key
        with a non-`None` value, that value is stored instead of anything
        extracted from `data`.

    Raises:
      RuntimeError: If the node has no data (negative index) and is not
        hidden.
    """
    if node_to_replacement is not None and self in node_to_replacement:
      replacement = node_to_replacement[self]
      if replacement is not None:
        # An explicit replacement takes priority over everything else.
        self._observed_value = replacement
        return

    if self._index < 0:
      if self._hidden:
        # No data and the node is hidden: nothing to do.
        return
      raise RuntimeError(
          'Node {} is not hidden, and column {} is not in the frame.'.format(
              self, self._column))

    # Produce the observed value for this node.
    self._observed_value = self._distribution_module.prepare_data(data)
|
||||
|
||||
|
||||
class DistributionModule(hk.Module):
  """Base class for Haiku modules that represent a distribution.

  It bundles the bookkeeping (column name, feature index, dimensionality) and
  small helpers shared by the modules passed to the `Node` class.
  """

  def __init__(self, column, index, dim):
    """Initialise a `DistributionModule` instance.

    Args:
      column: `string`. The name of the random variable to which this
        distribution corresponds; it should match the name of the series in
        the pandas dataframe. Hyphens are replaced by underscores to form the
        Haiku module name.
      index: `int`. The index of the corresponding feature in the dataset.
      dim: `int`. The output dimensionality of the distribution.
    """
    module_name = '_'.join(column.split('-'))
    super().__init__(name=module_name)
    self._column = column
    self._index = index
    self._dim = dim

  @property
  def dim(self):
    """The output dimensionality of this distribution."""
    return self._dim

  @property
  def column(self):
    """The column name given at construction."""
    return self._column

  @property
  def index(self):
    """The feature index given at construction."""
    return self._index

  def prepare_data(self, data):
    """Extracts this module's feature from a batch of data.

    `Node.populate` delegates to this method; subclasses are expected to
    override it with an appropriate conversion.

    Args:
      data: A tf.data.Dataset.

    Returns:
      The column at `index`, selected with a list index so that the column
      axis is kept.
    """
    selected_columns = [self._index]
    return data[:, selected_columns]

  def _package_args(self, args):
    """Concatenate args into a single tensor.

    Args:
      args: `List[DeviceArray]`, length > 0. The arrays are joined along
        axis 1.

    Returns:
      `DeviceArray`, (batch_size, num_values).
    """
    packed = jnp.concatenate(args, axis=1)
    return packed
|
||||
|
||||
|
||||
class Gaussian(DistributionModule):
  """A Haiku module that maps some inputs into a normal distribution."""

  def __init__(self, column, index, dim=1, hidden_shape=(),
               hidden_activation=jnp.tanh, scale=None):
    """Initialise a `Gaussian` instance with some dimensionality.

    Args:
      column: `string`. Name of the corresponding column.
      index: `int`. The index of the input data for this feature.
      dim: `int`, optional. The dimensionality of the variable.
      hidden_shape: `Iterable`, optional. Sizes of the hidden layers of the
        MLP that produces the mean.
      hidden_activation: `Callable`, optional. Activation of that MLP.
      scale: Optional fixed scale. When `None`, a learnable log-variance
        parameter is used instead.
    """
    super().__init__(column, index, dim)
    self._hidden_shape = tuple(hidden_shape)
    self._hidden_activation = hidden_activation
    self._scale = scale

    layer_sizes = self._hidden_shape + (self._dim,)
    self._loc_net = hk.nets.MLP(layer_sizes,
                                activation=self._hidden_activation)

  def __call__(self, *args):
    """Builds the normal distribution, optionally conditioned on `args`."""
    if not args:
      # Nothing to condition on: the mean is a free, learnable parameter.
      loc = hk.get_parameter('loc', shape=[self._dim], init=jnp.zeros)
    else:
      # The arguments are the conditioning variables; the mean is an MLP
      # applied to their concatenation.
      loc = self._loc_net(self._package_args(args))

    if self._scale is not None:
      scale = jnp.float32(self._scale)
    else:
      # No explicit scale: learn a single log-variance that does not depend
      # on the conditioning set.
      log_var = hk.get_parameter('log_var', shape=[self._dim], init=jnp.ones)
      scale = jnp.sqrt(jnp.exp(log_var))

    return tfp.distributions.Normal(loc=loc, scale=scale)

  def prepare_data(self, data):
    """Returns this feature's column of `data` cast to float32."""
    # The list index keeps the column axis; a rank-1 result is expanded to
    # (num_examples, 1). Only dim=1 is supported by this implementation.
    column_values = data[:, [self.index]].astype(jnp.float32)
    if len(column_values.shape) != 1:
      return column_values
    return jnp.expand_dims(column_values, axis=1)
|
||||
|
||||
|
||||
class GaussianMixture(DistributionModule):
  """A Haiku module that maps some inputs into a mixture of normals."""

  def __init__(self, column, num_components, dim=1):
    """Initialise a `GaussianMixture` instance with some dimensionality.

    The feature index passed to the base class is always -1.

    Args:
      column: `string`. The name of the column.
      num_components: `int`. The number of gaussians in this mixture.
      dim: `int`. The dimensionality of the variable.
    """
    super().__init__(column, -1, dim)
    self._num_components = num_components
    # NOTE(review): neither sub-network below is used by `__call__` in the
    # code visible here - confirm whether they are still needed.
    self._loc_net = hk.nets.MLP([self._dim])
    self._categorical_logits_module = hk.nets.MLP([self._num_components])

  def __call__(self, *args):
    """Builds the mixture distribution; `args` is currently ignored."""
    # Collect the location and scale entries of every component.
    # NOTE(review): each iteration asks for the same parameter names ('loc'
    # and 'log_var'). Haiku presumably hands back the already-created
    # parameter for a repeated name, which would make all components share
    # their parameters - confirm this is intended.
    component_locs = []
    component_scales = []
    for _ in range(self._num_components):
      loc = hk.get_parameter('loc', shape=[self._dim], init=jnp.zeros)
      log_var = hk.get_parameter('log_var', shape=[self._dim], init=jnp.ones)
      component_locs.extend(loc)
      component_scales.extend(jnp.sqrt(jnp.exp(log_var)))

    # Logits of the categorical distribution that switches between the
    # components, squashed through a sigmoid so they are positive.
    raw_logits = hk.get_parameter('categorical_logits',
                                  shape=[self._num_components],
                                  init=jnp.zeros)
    positive_logits = jax.nn.sigmoid(raw_logits)

    # Repeat each logit `dim` times so the logits line up with the flattened
    # list of component entries built above.
    mixture = tfp.distributions.Categorical(
        logits=jnp.repeat(positive_logits, self._dim))
    components = tfp.distributions.Normal(
        loc=component_locs, scale=component_scales)

    return tfp.distributions.MixtureSameFamily(
        mixture_distribution=mixture,
        components_distribution=components)
|
||||
|
||||
|
||||
class MLPMultinomial(DistributionModule):
  """A Haiku module that consists of an MLP + multinomial distribution."""

  def __init__(self, column, index, dim, hidden_shape=(),
               hidden_activation=jnp.tanh):
    """Initialise an MLPMultinomial instance.

    Args:
      column: `string`. Name of the corresponding dataframe column.
      index: `int`. The index of the input data for this feature.
      dim: `int`. Number of categories.
      hidden_shape: `Iterable`, optional. Shape of hidden layers.
      hidden_activation: `Callable`, optional. Non-linearity for hidden
        layers.
    """
    super().__init__(column, index, dim)
    self._hidden_shape = tuple(hidden_shape)
    self._hidden_activation = hidden_activation
    layer_sizes = self._hidden_shape + (self.dim,)
    self._logit_net = hk.nets.MLP(layer_sizes,
                                  activation=self._hidden_activation)

  @classmethod
  def from_frame(cls, data, column, hidden_shape=()):
    """Create an MLPMultinomial instance from a pandas dataframe and column.

    The dataframe is only used to find the position of `column` and the
    number of categories it has.

    Args:
      data: Pandas DataFrame containing `column`.
      column: `string`. Name of a categorical column of `data`.
      hidden_shape: `Iterable`, optional. Shape of hidden layers.

    Returns:
      A new instance of `cls`.

    Raises:
      ValueError: If `column` does not have a categorical dtype.
    """
    series = data[column]
    if not isinstance(series.dtype, pd.api.types.CategoricalDtype):
      raise ValueError('{} is not categorical.'.format(column))
    position = list(data.columns).index(column)
    category_count = len(series.cat.categories)
    return cls(column, position, category_count, hidden_shape)

  def __call__(self, *args):
    """Builds the multinomial distribution, optionally conditioned on args."""
    if not args:
      # Unconditional case: the logits are a free, learnable parameter.
      logits = hk.get_parameter('b', shape=[self.dim], init=jnp.zeros)
    else:
      logits = self._logit_net(self._package_args(args))
    return tfp.distributions.Multinomial(logits=logits, total_count=1.0)

  def prepare_data(self, data):
    """Returns the one-hot encoding of this feature's integer codes."""
    # The column holds the pandas category codes; after casting them to
    # int32 they index the rows of an identity matrix, giving one-hot rows.
    category_codes = data[:, self.index].astype(jnp.int32)
    one_hot_table = jnp.eye(self.dim)
    return one_hot_table[category_codes]
|
||||
|
||||
|
||||
def populate(nodes, dataframe, node_to_replacement=None):
  """Populates the observed value of every node in `nodes`.

  Args:
    nodes: Iterable of nodes; each one has its `populate` method called, in
      iteration order.
    dataframe: The data passed on to each node.
    node_to_replacement: Optional replacement mapping, forwarded unchanged to
      each node.
  """
  for graph_node in nodes:
    graph_node.populate(dataframe, node_to_replacement=node_to_replacement)
|
||||
@@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Usage:
|
||||
# sh download_dataset.sh ${OUTPUT_DIR}
|
||||
# Example:
|
||||
# sh download_dataset.sh uci_adult
|
||||
|
||||
# Abort on the first failing command.
set -e

OUTPUT_DIR="${1:-}"

# Fail early with a usage message instead of an opaque mkdir/wget error.
if [ -z "${OUTPUT_DIR}" ]; then
  echo "Usage: sh download_dataset.sh OUTPUT_DIR" >&2
  exit 1
fi

BASE_URL="https://archive.ics.uci.edu/ml/machine-learning-databases/adult/"

# Quoted so that output directories containing spaces work.
mkdir -p "${OUTPUT_DIR}"
for file in adult.data adult.test
do
  wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}"
done
|
||||
@@ -0,0 +1,12 @@
|
||||
wheel
|
||||
absl-py>=0.12.0
|
||||
dm-haiku>=0.0.4
|
||||
optax>=0.0.8
|
||||
ml_collections
|
||||
numpy>=1.16.4
|
||||
tensorflow>=2.5.0
|
||||
tensorflow-datasets>=4.3.0
|
||||
tensorflow_probability
|
||||
pandas
|
||||
scikit-learn
|
||||
jaxlib>=0.1.67+cuda110
|
||||
Executable
+40
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Fail on any error.
set -e

# Display commands being run.
set -x

TMP_DIR="$(mktemp -d)"

virtualenv --python=python3.6 "${TMP_DIR}/env"
source "${TMP_DIR}/env/bin/activate"

# Install dependencies.
pip install --upgrade -r counterfactual_fairness/requirements.txt

# Download minimal dataset. The files must land in the same directory that is
# passed to the training script below.
DATA_DIR="${TMP_DIR}/adult"
bash counterfactual_fairness/download_dataset.sh "${DATA_DIR}"

# Train for a few steps.
python -m counterfactual_fairness.adult_pscf --config="counterfactual_fairness/adult_pscf_config.py" --dataset_dir="${DATA_DIR}" --num_steps=10

# Clean up.
rm -r "${TMP_DIR}"
echo "Test run complete."
|
||||
@@ -0,0 +1,311 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Common utilities."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
from tensorflow_probability.substrates import jax as tfp
|
||||
|
||||
tfd = tfp.distributions
|
||||
|
||||
|
||||
def get_dataset(dataset: pd.DataFrame,
                batch_size: int,
                shuffle_size: int = 10000,
                num_epochs: Optional[int] = None) -> tf.data.Dataset:
  """Makes a tf.Dataset with correct preprocessing.

  Categorical columns are replaced by their integer category codes before the
  frame is converted; the input dataframe itself is not modified.

  Args:
    dataset: The dataframe to convert.
    batch_size: Number of rows per batch. Incomplete final batches are
      dropped.
    shuffle_size: Size of the shuffle buffer. Shuffling is skipped when this
      is not positive.
    num_epochs: Passed to `Dataset.repeat`.

  Returns:
    A batched `tf.data.Dataset` built from the values of the encoded frame.
  """
  dataset_copy = dataset.copy()
  for column in dataset.columns:
    if dataset[column].dtype.name == 'category':
      # Replace the whole column. Writing through `.loc[:, column]` tries to
      # store the integer codes inside the existing categorical column on
      # recent pandas versions, rather than swapping the column out.
      dataset_copy[column] = dataset[column].cat.codes
  ds = tf.data.Dataset.from_tensor_slices(dataset_copy.values)
  if shuffle_size > 0:
    ds = ds.shuffle(shuffle_size, reshuffle_each_iteration=True)
  ds = ds.repeat(num_epochs)
  return ds.batch(batch_size, drop_remainder=True)
|
||||
|
||||
|
||||
def multinomial_mode(
    distribution_or_probs: Union[tfd.Distribution, jnp.ndarray]
) -> jnp.ndarray:
  """Calculates the (one-hot) mode of a multinomial distribution.

  Args:
    distribution_or_probs:
      `tfp.distributions.Distribution` | array of probabilities.
      If the former, its `probs_parameter()` is used and it is taken to
      represent a distribution over categories. If the latter, the values are
      taken to be the probabilities of categories directly.
      In either case, it is assumed that the probabilities have shape
      (batch_size, dim).

  Returns:
    Array, float32, (batch_size, dim).
    The mode of the distribution in one-hot form. If several probabilities
    are joint-highest, each of them gets an equal share so that every row
    sums to one.
  """
  # `jnp.ndarray` is used in the annotations because `jnp.DeviceArray` does
  # not exist in recent JAX releases and annotations are evaluated when the
  # function is defined.
  if isinstance(distribution_or_probs, tfd.Distribution):
    probs = distribution_or_probs.probs_parameter()
  else:
    probs = distribution_or_probs
  max_prob = jnp.max(probs, axis=1, keepdims=True)
  mode = jnp.int32(jnp.equal(probs, max_prob))
  # Normalise so that ties share the probability mass equally.
  return jnp.float32(mode / jnp.sum(mode, axis=1, keepdims=True))
|
||||
|
||||
|
||||
def multinomial_class(
    distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
  """Returns the index of the highest-scoring class for each batch element.

  Args:
    distribution_or_probs: Either a `tfp.distributions.Distribution`, in which
      case its `logits_parameter()` is used as the per-class score, or an array
      of per-class scores (e.g. probabilities) used directly. Scores are
      reduced over axis 1.

  Returns:
    An array of class indices (the result of `jnp.argmax` over axis 1), one
    per batch element.
  """
  if isinstance(distribution_or_probs, tfd.Distribution):
    scores = distribution_or_probs.logits_parameter()
  else:
    scores = distribution_or_probs
  return jnp.argmax(scores, axis=1)
|
||||
|
||||
|
||||
def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
  """Returns the row-wise mode of an array of class probabilities.

  Performs the same computation as `multinomial_mode`, but only accepts an
  array of probabilities (not a distribution object).

  Args:
    probs: Array of shape (batch_size, dim) holding, for every batch element,
      the probability of each class.

  Returns:
    A float32 array of shape (batch_size, dim). Each row is one-hot at the most
    probable class; tied maxima share the unit mass equally.
  """
  row_max = jnp.amax(probs, axis=1, keepdims=True)
  hits = jnp.equal(probs, row_max).astype(jnp.int32)
  # Dividing by the number of hits makes every row sum to one, even with ties.
  weights = hits / jnp.sum(hits, axis=1, keepdims=True)
  return weights.astype(jnp.float32)
|
||||
|
||||
|
||||
def multinomial_accuracy(
    distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray],
    data: jnp.DeviceArray) -> jnp.DeviceArray:
  """Computes the accuracy of the predicted mode, averaged over a batch.

  The prediction is obtained with `multinomial_mode`, multiplied element-wise
  with `data`, summed over axis 1 and then averaged over all rows.

  Args:
    distribution_or_probs: Either a `tfp.distributions.Distribution` or an
      array of class probabilities, exactly as accepted by `multinomial_mode`.
    data: Reference data of shape (batch_size, dim). Presumably one-hot
      labels, so that each row's product picks out the weight the predicted
      mode puts on the true class - confirm against the callers.

  Returns:
    A scalar array with the mean per-row score.
  """
  per_example = jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1)
  return jnp.mean(per_example)
|
||||
|
||||
|
||||
def softmax_ndarray(logits: jnp.DeviceArray) -> jnp.DeviceArray:
  """Applies a row-wise softmax to a rank-2 array of logits.

  Args:
    logits: Array of shape (batch_size, dim).

  Returns:
    Array of the same shape whose rows are the softmax of the input rows.
  """
  assert len(logits.shape) == 2
  # Subtracting the row maximum leaves the result unchanged but keeps the
  # exponentials from overflowing.
  shifted = logits - jnp.max(logits, axis=1, keepdims=True)
  unnormalised = jnp.exp(shifted)
  return unnormalised / jnp.sum(unnormalised, axis=1, keepdims=True)
|
||||
|
||||
|
||||
def get_samples(distribution, num_samples, seed=None):
  """Draws samples from a batched distribution and flattens the leading axes.

  `distribution.sample(num_samples, seed=seed)` is expected to return an array
  of shape [num_samples] + batch_shape + event_shape, e.g.
  (num_samples, batch_size, dim). Every axis except the last is then merged
  into a single leading axis.

  Args:
    distribution: Object exposing `sample(num_samples, seed=...)`, typically a
      `tfp.distributions.Distribution`.
    num_samples: Number of samples to draw for each batch element.
    seed: Forwarded unchanged to `distribution.sample`. Defaults to `None`.

  Returns:
    A rank-2 array of shape (-1, dim), where dim is the size of the sample's
    last axis; for a (num_samples, batch_size, dim) sample this is
    (num_samples * batch_size, dim), ordered sample-major.
  """
  draws = distribution.sample(num_samples, seed=seed)
  # Keep the trailing axis and collapse everything in front of it.
  trailing_dim = draws.shape[-1]
  return draws.reshape((-1, trailing_dim))
|
||||
|
||||
|
||||
def mmd_loss(distribution: tfd.Distribution,
             is_a: jnp.DeviceArray,
             num_samples: int,
             rng: jnp.ndarray,
             num_random_features: int = 50,
             gamma: float = 1.):
  """Estimates the Maximum Mean Discrepancy (MMD) between two batch groups.

  Samples are drawn for every batch element, the batch is partitioned into
  group A (`is_a` true) and group B (the remainder), and the squared distance
  between the two groups' mean random-feature embeddings is returned.

  This is the 'FastMMD' approximation, a.k.a. 'Random Fourier Features'. See
  the description, for example, in sections 2.3.1 and 2.4 of
  https://arxiv.org/pdf/1511.00830.pdf.

  Args:
    distribution: Distribution whose `sample(num_samples, seed=...)` returns an
      array of shape (num_samples, batch_size, dim). `dim` is read from
      `event_shape` when that is non-empty and from `batch_shape[1]` otherwise.
    is_a: Boolean array, broadcastable against a trailing batch axis (i.e.
      shape (batch_size,)), marking the batch elements that belong to group A.
      All other elements belong to group B.
    num_samples: The number of samples to draw from `distribution`.
    rng: JAX PRNG key. It is split internally so that the samples, the random
      projection and the random phases are each drawn with their own key.
    num_random_features: The number of random Fourier features used in the
      expansion.
    gamma: The value of gamma in the Gaussian MMD kernel.

  Returns:
    Scalar array (shape ()) holding the MMD estimate. A group with no members
    contributes a zero mean embedding.
  """
  if distribution.event_shape == ():  # pylint: disable=g-explicit-bool-comparison
    dim_x = distribution.batch_shape[1]
  else:
    dim_x, = distribution.event_shape

  # Never reuse a PRNG key: derive one independent key per random draw.
  sample_rng, w_rng, b_rng = random.split(rng, 3)

  # Shape [num_samples] + batch_shape + event_shape.
  samples = distribution.sample(num_samples, seed=sample_rng)

  w = random.normal(w_rng, shape=(dim_x, num_random_features))
  b = random.uniform(b_rng, shape=(num_random_features,),
                     minval=0, maxval=2 * jnp.pi)

  def features(x):
    """Computes the random kitchen sink features of `x`."""
    # Contract the last axis of x with the first axis of w; the result has
    # shape x.shape[:-1] + (num_random_features,).
    return jnp.sqrt(2 / num_random_features) * jnp.cos(
        jnp.sqrt(2 / gamma) * jnp.tensordot(x, w, axes=1) + b)

  # Shape (num_samples, batch_size, num_random_features).
  exp_features = features(samples)
  # Move the batch axis last so that `is_a` broadcasts against it:
  # (num_samples, num_random_features, batch_size).
  exp_features_reshaped = jnp.swapaxes(exp_features, 1, 2)

  in_a = jnp.asarray(is_a).astype(bool)
  in_b = jnp.logical_not(in_a)

  # Each group's mean must be taken over that group's own members. Averaging
  # the zero-masked array over the whole batch would scale each mean by the
  # group's share of the batch and bias the estimate for unequal group sizes.
  draws = exp_features_reshaped.shape[0]
  count_a = jnp.maximum(draws * jnp.sum(in_a), 1)  # Guard against empty group.
  count_b = jnp.maximum(draws * jnp.sum(in_b), 1)

  sum_a = jnp.sum(jnp.where(in_a, exp_features_reshaped, 0), axis=(0, 2))
  sum_b = jnp.sum(jnp.where(in_a, 0, exp_features_reshaped), axis=(0, 2))
  exp_features_a = sum_a / count_a
  exp_features_b = sum_b / count_b

  assert exp_features_a.shape == (num_random_features,)
  difference = exp_features_a - exp_features_b

  # Squared norm of the difference of mean embeddings. Shape ().
  return jnp.tensordot(difference, difference, axes=1)
|
||||
|
||||
|
||||
def mmd_loss_exact(distribution_a, distribution_b, num_samples, gamma=1.,
                   rng=None):
  """Computes the (biased) sample estimate of MMD with a Gaussian kernel.

  Samples are drawn from both distributions with `get_samples` and the
  estimate

    mean_ij k(a_i, a_j) + mean_ij k(b_i, b_j) - 2 * mean_ij k(a_i, b_j)

  is returned, where k(x, y) = exp(-||x - y||^2 / gamma).

  Args:
    distribution_a: First distribution; must expose `event_shape`,
      `batch_shape` and `sample`.
    distribution_b: Second distribution, with the same `event_shape` and the
      same `batch_shape[1:]` as `distribution_a`.
    num_samples: Number of samples drawn per batch element from each
      distribution.
    gamma: The value of gamma in the Gaussian MMD kernel.
    rng: Optional JAX PRNG key. When given, it is split so that the two
      distributions are sampled with independent keys. When `None`,
      `seed=None` is forwarded to the sampler.

  Returns:
    Scalar array (shape ()) holding the MMD estimate.
  """
  assert distribution_a.event_shape == distribution_b.event_shape
  assert distribution_a.batch_shape[1:] == distribution_b.batch_shape[1:]

  if rng is None:
    rng_a = rng_b = None
  else:
    rng_a, rng_b = random.split(rng)

  # Shape (num_samples * batch_size_a, dim_x).
  samples_a = get_samples(distribution_a, num_samples, seed=rng_a)
  # Shape (num_samples * batch_size_b, dim_x).
  samples_b = get_samples(distribution_b, num_samples, seed=rng_b)

  def kernel_mean(x, y):
    """Mean Gaussian kernel value over all pairs (x_i, y_j)."""
    # Broadcast to all pairwise differences: shape (len(x), len(y), dim_x).
    # The within-sample terms need every pair (x_i, x_j); subtracting an array
    # from an identical copy of itself would only ever give k = 1.
    diff = jnp.expand_dims(x, axis=1) - jnp.expand_dims(y, axis=0)

    # Contract over dim_x to get squared distances, shape (len(x), len(y)).
    exponent = -jnp.einsum('ijk,ijk->ij', diff, diff) / gamma

    # Shape ().
    return jnp.mean(jnp.exp(exponent))

  # Equation 7 from arxiv 1511.00830.
  return (
      kernel_mean(samples_a, samples_a)
      + kernel_mean(samples_b, samples_b)
      - 2 * kernel_mean(samples_a, samples_b))
|
||||
|
||||
|
||||
def scalar_log_prob(distribution, val):
  """Returns one log-probability per batch entry.

  Distributions differ in the shape of `log_prob(val)`:

    * If it is rank 1, i.e. already one value per batch entry (as for classes
      like `tfp.distributions.Multinomial`), it is returned unchanged.
    * If it is rank 2, i.e. (batch_size, dim), it is summed over axis 1.

  Args:
    distribution: Object implementing `log_prob`, typically a
      `tfp.distributions.Distribution`.
    val: Array of shape (batch_size, dim) at which to evaluate `log_prob`.

  Returns:
    Array of shape (batch_size,).

  Raises:
    ValueError: If `distribution.log_prob(val)` has more than two axes.
  """
  log_prob_val = distribution.log_prob(val)
  rank = len(log_prob_val.shape)

  if rank > 2:
    raise ValueError('log_prob_val has unexpected shape {}.'.format(
        log_prob_val.shape))
  if rank == 1:
    # Already reduced to one value per batch entry.
    return log_prob_val
  return jnp.sum(log_prob_val, axis=1)
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Functions and classes for performing variational inference."""
|
||||
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
import haiku as hk
|
||||
import jax.numpy as jnp
|
||||
from tensorflow_probability.substrates import jax as tfp
|
||||
|
||||
tfd = tfp.distributions
|
||||
|
||||
|
||||
class Variational(hk.Module):
  """A module representing the variational distribution q(H | *O).

  H is assumed to be a continuous variable. The distribution is a Normal whose
  mean and log-variance are produced by two linear heads on top of a shared
  stack of dense layers.
  """

  def __init__(self,
               common_layer_sizes: Iterable[int],
               activation: Callable[[jnp.DeviceArray],
                                    jnp.DeviceArray] = jnp.tanh,
               output_dim: int = 1,
               name: Optional[str] = None):
    """Initialises a `Variational` instance.

    Args:
      common_layer_sizes: The number of hidden units in each of the shared
        dense network layers, in order.
      activation: Nonlinearity function to apply after each of the common
        layers.
      output_dim: The dimensionality of `H`.
      name: A name to assign to the module instance.
    """
    super().__init__(name=name)
    # Materialise the sizes so a one-shot iterable can still be inspected
    # after the layers have been built.
    self._common_layer_sizes = tuple(common_layer_sizes)
    self._activation = activation
    self._output_dim = output_dim

    self._linear_layers = [
        hk.Linear(layer_size)
        for layer_size in self._common_layer_sizes
    ]

    # Separate heads for the mean and the log-variance of q(H | *O).
    self._mean_output = hk.Linear(self._output_dim)
    self._log_var_output = hk.Linear(self._output_dim)

  def __call__(self, *args) -> tfd.Distribution:
    """Creates a distribution for q(H | *O).

    Args:
      *args: Arrays holding the values of the variables in the conditioning
        set *O. Each is passed through `hk.Flatten` and the results are
        concatenated along axis 1, so all of them must share the same leading
        (batch) dimension.

    Returns:
      A `tfp.distributions.Normal` instance parameterised by the predicted
      mean and standard deviation.
    """
    # Flatten every input and join them into a single feature axis. No dtype
    # conversion is performed here.
    input_ = [hk.Flatten()(arg) for arg in args]
    input_ = jnp.concatenate(input_, axis=1)

    # Shared layers; the final heads separate the mean and the log-variance.
    for layer in self._linear_layers:
      input_ = layer(input_)
      input_ = self._activation(input_)

    # input_ now has shape (batch_size, final_layer_size) and feeds both heads.
    mean = self._mean_output(input_)
    log_var = self._log_var_output(input_)
    # std = sqrt(exp(log_var)) = exp(log_var / 2). Halving before
    # exponentiating avoids overflowing to inf for large log-variances whose
    # standard deviation is still representable.
    std = jnp.exp(0.5 * log_var)
    return tfd.Normal(mean, std)
|
||||
Reference in New Issue
Block a user