diff --git a/README.md b/README.md index e4bae3f..04f93ae 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ https://deepmind.com/research/publications/ * [Disentangling by Subspace Diffusion (GEOMANCER)](geomancer), NeurIPS 2020 * [What can I do here? A theory of affordances in reinforcement learning](affordances_theory), ICML 2020 * [Scaling data-driven robotics with reward sketching and batch reinforcement learning](sketchy), RSS 2020 +* [Path-Specific Counterfactual Fairness](counterfactual_fairness), AAAI 2019 * [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019 * [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020 * [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020 diff --git a/counterfactual_fairness/README.md b/counterfactual_fairness/README.md new file mode 100644 index 0000000..e5e5032 --- /dev/null +++ b/counterfactual_fairness/README.md @@ -0,0 +1,57 @@ +# Path-Specific Counterfactual Fairness (AAAI 2019) + +Paper: [Path-Specific Counterfactual Fairness](https://ojs.aaai.org/index.php/AAAI/article/view/4777/4655) + +If you use the code here please cite this paper: + + @inproceedings{chiappa2019path, + title={Path-specific counterfactual fairness}, + author={Chiappa, Silvia}, + booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, + volume={33}, + number={01}, + pages={7801--7808}, + year={2019} + } + +### Overview + +This release contains the path-specific counterfactual fairness method used +in the paper, as well as utility functions for loading the *Adult* dataset. + +The following gives a brief overview of the contents, more detailed +documentation is available within each file: + +* __causal_network.py__: Defines the `Node` class, instances of which can be + combined into a directed graph. 
Associated with each node is a + `distribution_module`, a Haiku module which builds a TensorFlow Probability + `Distribution` instance as a function of the node's parents. +* __utils.py__: Miscellaneous utility functions. +* __variational.py__: Class for performing variational + inference. The `Variational` Haiku module is a general-purpose approximate + posterior, using an MLP to map from arbitrary inputs to the parameters + of a Gaussian distribution. +* __adult.py__: Utility functions for the *Adult* dataset. +* __adult_pscf.py__: Training and 'fair' prediction process (using + path-specific counterfactual fairness) on the *Adult* dataset. +* __adult_pscf_config.py__: Configuration file with default training + parameters. + +### Dataset + +The *Adult* dataset can be downloaded from +[https://archive.ics.uci.edu/ml/datasets/adult](https://archive.ics.uci.edu/ml/datasets/adult). You may use the following +command to download both necessary files to run the training script: + +`sh download_dataset.sh ${OUTPUT_DIR}` + +### Experiments + +To download the dataset and run the main experiment reported in the paper, +you may run: + +`sh run_adult_pscf.sh` + +### Acknowledgements + +Credits to Thomas P. S. Gillam for the original TF1 implementation. diff --git a/counterfactual_fairness/adult.py b/counterfactual_fairness/adult.py new file mode 100644 index 0000000..fd2787e --- /dev/null +++ b/counterfactual_fairness/adult.py @@ -0,0 +1,78 @@ +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# Column names of the Adult CSV files, in file order (the files have no header).
_COLUMNS = ('age', 'workclass', 'final-weight', 'education', 'education-num',
            'marital-status', 'occupation', 'relationship', 'race', 'sex',
            'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
            'income')
# Columns that are converted to the pandas `category` dtype after loading.
_CATEGORICAL_COLUMNS = ('workclass', 'education', 'marital-status',
                        'occupation', 'race', 'relationship', 'sex',
                        'native-country', 'income')


def _read_data(name, data_path=''):
  """Reads one Adult CSV file into a dataframe.

  Args:
    name: File name, e.g. 'adult.data'.
    data_path: Directory containing the file. The empty string means the
      current working directory.

  Returns:
    A Pandas DataFrame with columns `_COLUMNS`, in which '?' entries are
    parsed as missing values and the `_CATEGORICAL_COLUMNS` have `category`
    dtype.
  """
  # `pd.read_csv` opens and closes the file itself. A path string is not a
  # context manager, so it must not be wrapped in a `with` statement.
  data = pd.read_csv(os.path.join(data_path, name), header=None,
                     index_col=False, names=_COLUMNS, skipinitialspace=True,
                     na_values='?')
  for categorical in _CATEGORICAL_COLUMNS:
    data[categorical] = data[categorical].astype('category')
  return data


def _combine_category_coding(df_1, df_2):
  """Combines the categories between dataframes df_1 and df_2.

  This is used to ensure that training and test data use the same category
  coding, so that the one-hot vectors representing the values are compatible
  between training and test data. Both dataframes are modified in place.

  Args:
    df_1: Pandas DataFrame.
    df_2: Pandas DataFrame. Must have the same columns as df_1.
  """
  for column in df_1.columns:
    if isinstance(df_1[column].dtype, pd.CategoricalDtype):
      categories_1 = set(df_1[column].cat.categories)
      categories_2 = set(df_2[column].cat.categories)
      # Sorting gives a deterministic coding shared by both dataframes.
      categories = sorted(categories_1 | categories_2)
      # Assign the result back instead of using `inplace=True`, which recent
      # pandas releases no longer accept.
      df_1[column] = df_1[column].cat.set_categories(categories)
      df_2[column] = df_2[column].cat.set_categories(categories)


def read_all_data(root_dir, remove_missing=True):
  """Return (train, test) dataframes, optionally removing incomplete rows.

  Args:
    root_dir: Directory containing the 'adult.data' and 'adult.test' files.
    remove_missing: If True, rows with any missing value are dropped from both
      dataframes.

  Returns:
    A `(train_data, test_data)` tuple of Pandas DataFrames whose categorical
    columns share the same category coding.
  """
  train_data = _read_data('adult.data', root_dir)
  test_data = _read_data('adult.test', root_dir)
  _combine_category_coding(train_data, test_data)
  if remove_missing:
    train_data = train_data.dropna()
    test_data = test_data.dropna()

  logging.info('Training data dtypes: %s', train_data.dtypes)
  return train_data, test_data
+# ============================================================================ +"""Training script for causal model for Adult dataset, using PSCF.""" + +import functools +import time +from typing import Any, List, Mapping, NamedTuple, Sequence + +from absl import app +from absl import flags +from absl import logging +import haiku as hk +import jax +import jax.numpy as jnp +from ml_collections.config_flags import config_flags +import numpy as np +import optax +import pandas as pd +from sklearn import metrics +import tensorflow as tf +import tensorflow_datasets as tfds +from tensorflow_probability.substrates import jax as tfp +from counterfactual_fairness import adult +from counterfactual_fairness import causal_network +from counterfactual_fairness import utils +from counterfactual_fairness import variational + +FLAGS = flags.FLAGS +config_flags.DEFINE_config_file( + 'config', 'adult_pscf_config.py', 'Training configuration.') + +LOG_EVERY = 100 + +# These are all aliases to callables which will return instances of +# particular distribution modules, or a Node itself. This is used to make +# subsequent code more legible. 
Node = causal_network.Node
Gaussian = causal_network.Gaussian
MLPMultinomial = causal_network.MLPMultinomial

# `main` reads the dataset location from this flag. The guard keeps the module
# importable if another module has already registered a flag of this name.
if 'dataset_dir' not in FLAGS:
  flags.DEFINE_string(
      'dataset_dir', '',
      'Directory containing the Adult `adult.data` and `adult.test` files. '
      'Defaults to the current working directory.')


def build_input(train_data: pd.DataFrame, batch_size: int,
                training_steps: int, shuffle_size: int = 10000):
  """Builds an iterator over NumPy batches of the given dataframe.

  Args:
    train_data: Dataframe to draw batches from.
    batch_size: Number of rows per batch.
    training_steps: Number of steps the iterator is expected to serve. It is
      only used to derive the number of epochs requested from
      `utils.get_dataset`.
    shuffle_size: Shuffle buffer size forwarded to `utils.get_dataset`.

  Returns:
    An iterator over the batches of the dataset, converted to NumPy.
  """
  # NOTE(review): the epoch count is derived from the batch size rather than
  # from the dataset size; confirm it always yields at least `training_steps`
  # batches for the datasets used.
  num_epochs = (training_steps // batch_size) + 1
  ds = utils.get_dataset(train_data, batch_size, shuffle_size,
                         num_epochs=num_epochs)
  ds = ds.prefetch(tf.data.AUTOTUNE)
  return iter(tfds.as_numpy(ds))


class CausalNetOutput(NamedTuple):
  """Outputs of the model's forward pass, consumed by the loss functions."""
  # Approximate posteriors q(H | observations), one per hidden node.
  q_hidden_obs: Sequence[tfp.distributions.Distribution]
  # Prior distributions p(H), one per hidden node.
  p_hidden: Sequence[tfp.distributions.Distribution]
  # One sample from each entry of `q_hidden_obs`.
  hidden_samples: Sequence[jnp.ndarray]
  # Weighted log-likelihood of the observations given `hidden_samples`.
  log_p_obs_hidden: jnp.ndarray
  # Indicates which elements of the batch correspond to male individuals.
  is_male: jnp.ndarray


def build_causal_graph(train_data: pd.DataFrame, column_names: List[str],
                       inputs: jnp.ndarray):
  """Build the causal graph of the model.

  Must be called inside a Haiku transform, since it creates Haiku modules.

  Args:
    train_data: Dataframe used to configure the multinomial nodes.
    column_names: Names of the dataframe columns; used to look up the index of
      each continuous column.
    inputs: Batch of data used to populate the observable nodes.

  Returns:
    A tuple `(observable_nodes, hidden_nodes, q_hidden_obs)` where
    `observable_nodes` is (A, C1, C2, L, M, R1, R2, R3, Y), `hidden_nodes` is
    (Hm, Hl, Hr1, Hr2, Hr3) and `q_hidden_obs` holds the variational
    distribution for each hidden node, in the same order.
  """
  make_multinomial = functools.partial(
      causal_network.MLPMultinomial.from_frame, hidden_shape=(100,))
  make_gaussian = functools.partial(
      causal_network.Gaussian, hidden_shape=(100,))

  # Construct the graphical model. Each random variable is represented by an
  # instance of the `Node` class, as discussed in that class's docstring.
  # The following nodes have no parents, and thus the distribution modules
  # will not be conditional on anything -- they simply represent priors.
  node_a = Node(MLPMultinomial.from_frame(train_data, 'sex'))
  node_c1 = Node(MLPMultinomial.from_frame(train_data, 'native-country'))
  node_c2 = Node(Gaussian('age', column_names.index('age')))

  # These are all hidden nodes, that do not correspond to any actual data in
  # pandas dataframe loaded previously. We therefore are permitted to control
  # the dimensionality of these nodes as we wish (with the `dim` argument).
  # The distribution module here should be interpreted as saying that we are
  # imposing a multi-modal prior (a mixture of Gaussians) on each latent
  # variable.
  node_hm = Node(causal_network.GaussianMixture('hm', 10, dim=2), hidden=True)
  node_hl = Node(causal_network.GaussianMixture('hl', 10, dim=2), hidden=True)
  node_hr1 = Node(
      causal_network.GaussianMixture('hr1', 10, dim=2), hidden=True)
  node_hr2 = Node(
      causal_network.GaussianMixture('hr2', 10, dim=2), hidden=True)
  node_hr3 = Node(
      causal_network.GaussianMixture('hr3', 10, dim=2), hidden=True)

  # The rest of the graph is now constructed; the order of construction is
  # important, so we can inform each node of its parents.
  # Note that in the paper we simply have one node called "R", but here it is
  # separated into three separate `Node` instances. This is necessary since
  # each node can only represent a single quantity in the dataframe.
  node_m = Node(
      make_multinomial(train_data, 'marital-status'),
      [node_a, node_hm, node_c1, node_c2])
  node_l = Node(
      make_gaussian('education-num', column_names.index('education-num')),
      [node_a, node_hl, node_c1, node_c2, node_m])
  # NOTE(review): unlike M and L, the R nodes below do not list their hidden
  # node (Hr1/Hr2/Hr3) among their parents, although fair inference passes a
  # replacement value for it. Confirm whether this is intentional.
  node_r1 = Node(
      make_multinomial(train_data, 'occupation'),
      [node_a, node_c1, node_c2, node_m, node_l])
  node_r2 = Node(
      make_gaussian('hours-per-week', column_names.index('hours-per-week')),
      [node_a, node_c1, node_c2, node_m, node_l])
  node_r3 = Node(
      make_multinomial(train_data, 'workclass'),
      [node_a, node_c1, node_c2, node_m, node_l])
  node_y = Node(
      MLPMultinomial.from_frame(train_data, 'income'),
      [node_a, node_c1, node_c2, node_m, node_l, node_r1, node_r2, node_r3])

  # We now construct several (self-explanatory) collections of nodes. These
  # will be used at various points later in the code, and serve to provide
  # greater semantic interpretability.
  observable_nodes = (node_a, node_c1, node_c2, node_l, node_m, node_r1,
                      node_r2, node_r3, node_y)

  # The nodes on which each latent variable is conditionally dependent.
  # Note that Y is not in this list, since all of its dependencies are
  # included below, and further it does not depend directly on Hm.
  nodes_on_which_hm_depends = (node_a, node_c1, node_c2, node_m)
  nodes_on_which_hl_depends = (node_a, node_c1, node_c2, node_m, node_l)
  nodes_on_which_hr1_depends = (node_a, node_c1, node_c2, node_m, node_l,
                                node_r1)
  nodes_on_which_hr2_depends = (node_a, node_c1, node_c2, node_m, node_l,
                                node_r2)
  nodes_on_which_hr3_depends = (node_a, node_c1, node_c2, node_m, node_l,
                                node_r3)
  hidden_nodes = (node_hm, node_hl, node_hr1, node_hr2, node_hr3)

  # Function to create the distribution needed for variational inference. This
  # is the same for each latent variable.
  def make_q_x_obs_module(node):
    """Make a Variational module for the given hidden variable."""
    assert node.hidden
    return variational.Variational(
        common_layer_sizes=(20, 20), output_dim=node.dim)

  # For each latent variable, we first construct a Haiku module (using the
  # function above), and then connect it to the graph using the node's
  # value. As described in more detail in the documentation for `Node`,
  # these values represent actual observed data. Therefore we will later
  # be connecting these same modules to the graph in different ways in order
  # to perform fair inference.
  q_hm_obs_module = make_q_x_obs_module(node_hm)
  q_hl_obs_module = make_q_x_obs_module(node_hl)
  q_hr1_obs_module = make_q_x_obs_module(node_hr1)
  q_hr2_obs_module = make_q_x_obs_module(node_hr2)
  q_hr3_obs_module = make_q_x_obs_module(node_hr3)

  causal_network.populate(observable_nodes, inputs)

  q_hm_obs = q_hm_obs_module(
      *(node.observed_value for node in nodes_on_which_hm_depends))
  q_hl_obs = q_hl_obs_module(
      *(node.observed_value for node in nodes_on_which_hl_depends))
  q_hr1_obs = q_hr1_obs_module(
      *(node.observed_value for node in nodes_on_which_hr1_depends))
  q_hr2_obs = q_hr2_obs_module(
      *(node.observed_value for node in nodes_on_which_hr2_depends))
  q_hr3_obs = q_hr3_obs_module(
      *(node.observed_value for node in nodes_on_which_hr3_depends))
  q_hidden_obs = (q_hm_obs, q_hl_obs, q_hr1_obs, q_hr2_obs, q_hr3_obs)

  return observable_nodes, hidden_nodes, q_hidden_obs


def build_forward_fn(train_data: pd.DataFrame, column_names: List[str],
                     likelihood_multiplier: float):
  """Create the model's forward pass.

  Args:
    train_data: Dataframe used to configure the causal graph.
    column_names: Names of the dataframe columns.
    likelihood_multiplier: Weight applied to the log-likelihood of every
      observable node except the target Y.

  Returns:
    A tuple `(forward_fn, fair_inference_fn)` of untransformed functions; both
    must be wrapped with `hk.transform` before use.
  """

  def forward_fn(inputs: jnp.ndarray) -> CausalNetOutput:
    """Forward pass."""
    observable_nodes, hidden_nodes, q_hidden = build_causal_graph(
        train_data, column_names, inputs)
    (node_hm, node_hl, node_hr1, node_hr2, node_hr3) = hidden_nodes
    node_a = observable_nodes[0]
    node_y = observable_nodes[-1]

    # Log-likelihood function.
    def log_p_obs_h(hm_value, hl_value, hr1_value, hr2_value, hr3_value):
      """Compute log P(A, C, M, L, R, Y | H)."""
      # In order to create distributions like P(M | H_m, A, C), we need
      # the value of H_m that we've been provided as an argument, rather than
      # the value stored on H_m (which, in fact, will never be populated
      # since H_m is unobserved).
      # For compactness, we first construct the complete list of replacements.
      node_to_replacement = {
          node_hm: hm_value,
          node_hl: hl_value,
          node_hr1: hr1_value,
          node_hr2: hr2_value,
          node_hr3: hr3_value,
      }

      def log_prob_for_node(node):
        """Given a node, compute its log probability for the given latents."""
        log_prob = jnp.squeeze(
            node.make_distribution(node_to_replacement).log_prob(
                node.observed_value))
        return log_prob

      # We apply the likelihood multiplier to all likelihood terms except that
      # for Y, the target. This is then added on separately in the line below.
      sum_no_y = likelihood_multiplier * sum(
          log_prob_for_node(node)
          for node in observable_nodes
          if node is not node_y)

      return sum_no_y + log_prob_for_node(node_y)

    q_hidden_obs = tuple(q_hidden)
    p_hidden = tuple(node.distribution for node in hidden_nodes)
    # Draw each latent with its own key: reusing a single key would feed the
    # same underlying random draw to every latent variable.
    hidden_samples = tuple(
        q.sample(seed=hk.next_rng_key()) for q in q_hidden_obs)
    log_p_obs_hidden = log_p_obs_h(*hidden_samples)

    # We need to split our batch of data into male and female parts.
    # NOTE(review): this assumes index 1 of the one-hot 'sex' encoding is
    # 'Male' -- confirm against the category coding.
    is_male = jnp.equal(node_a.observed_value[:, 1], 1)

    return CausalNetOutput(
        q_hidden_obs=q_hidden_obs,
        p_hidden=p_hidden,
        hidden_samples=hidden_samples,
        log_p_obs_hidden=log_p_obs_hidden,
        is_male=is_male)

  def fair_inference_fn(inputs: jnp.ndarray, batch_size: int,
                        num_prediction_samples: int):
    """Get the fair and unfair predictions for the given input.

    Args:
      inputs: Batch of data; its leading dimension must equal `batch_size`.
      batch_size: Number of rows in `inputs`.
      num_prediction_samples: Number of latent samples drawn per row.

    Returns:
      A tuple `(predicted_class_y_fair, predicted_class_y_unfair)`.
    """
    observable_nodes, hidden_nodes, q_hidden_obs = build_causal_graph(
        train_data, column_names, inputs)
    (node_hm, node_hl, node_hr1, node_hr2, node_hr3) = hidden_nodes
    (node_a, node_c1, node_c2, node_l, node_m, node_r1, node_r2, node_r3,
     node_y) = observable_nodes
    (q_hm_obs, q_hl_obs, q_hr1_obs, q_hr2_obs, q_hr3_obs) = q_hidden_obs

    # *** FAIR INFERENCE ***

    # To predict Y in a fair sense:
    #  * Infer Hm given observations.
    #  * Infer M using inferred Hm, baseline A, real C
    #  * Infer L using inferred Hl, M, real A, C
    #  * Infer Y using inferred M, baseline A, real C
    # This is done by numerical integration, i.e. draw samples from
    # p_fair(Y | A, C, M, L).
    a_all_male = jnp.concatenate(
        (jnp.zeros((batch_size, 1)), jnp.ones((batch_size, 1))),
        axis=1)

    # Here we take `num_prediction_samples` per observation. This results in
    # an array of shape:
    #     (num_samples, batch_size, hidden_dim).
    # However, the forward pass is easier after reshaping to:
    #     (num_samples * batch_size, hidden_dim).
    # The hidden dimension is inferred (-1) rather than hard-coded. Every
    # draw uses a fresh key so that the samples are independent.
    def expanded_sample(distribution):
      return distribution.sample(
          num_prediction_samples, seed=hk.next_rng_key()).reshape(
              (batch_size * num_prediction_samples, -1))

    hm_pred_sample = expanded_sample(q_hm_obs)
    hl_pred_sample = expanded_sample(q_hl_obs)
    hr1_pred_sample = expanded_sample(q_hr1_obs)
    hr2_pred_sample = expanded_sample(q_hr2_obs)
    hr3_pred_sample = expanded_sample(q_hr3_obs)

    # The values of the observed nodes need to be tiled to match the dims
    # of the above hidden samples. The `expand` function achieves this.
    def expand(observed_value):
      return jnp.tile(observed_value, (num_prediction_samples, 1))

    expanded_a = expand(node_a.observed_value)
    expanded_a_baseline = expand(a_all_male)
    expanded_c1 = expand(node_c1.observed_value)
    expanded_c2 = expand(node_c2.observed_value)

    # For M, and all subsequent variables, we only generate one sample. This
    # is because we already have *many* samples from the latent variables, and
    # all we require is an independent sample from the distribution.
    m_pred_sample = node_m.make_distribution({
        node_a: expanded_a_baseline,
        node_hm: hm_pred_sample,
        node_c1: expanded_c1,
        node_c2: expanded_c2}).sample(seed=hk.next_rng_key())

    l_pred_sample = node_l.make_distribution({
        node_a: expanded_a,
        node_hl: hl_pred_sample,
        node_c1: expanded_c1,
        node_c2: expanded_c2,
        node_m: m_pred_sample}).sample(seed=hk.next_rng_key())

    r1_pred_sample = node_r1.make_distribution({
        node_a: expanded_a,
        node_hr1: hr1_pred_sample,
        node_c1: expanded_c1,
        node_c2: expanded_c2,
        node_m: m_pred_sample,
        node_l: l_pred_sample}).sample(seed=hk.next_rng_key())
    r2_pred_sample = node_r2.make_distribution({
        node_a: expanded_a,
        node_hr2: hr2_pred_sample,
        node_c1: expanded_c1,
        node_c2: expanded_c2,
        node_m: m_pred_sample,
        node_l: l_pred_sample}).sample(seed=hk.next_rng_key())
    r3_pred_sample = node_r3.make_distribution({
        node_a: expanded_a,
        node_hr3: hr3_pred_sample,
        node_c1: expanded_c1,
        node_c2: expanded_c2,
        node_m: m_pred_sample,
        node_l: l_pred_sample}).sample(seed=hk.next_rng_key())

    # Finally, we sample from the distribution for Y. Like above, we only
    # draw one sample per element in the array.
    y_pred_sample = node_y.make_distribution({
        node_a: expanded_a_baseline,
        node_c1: expanded_c1,
        node_c2: expanded_c2,
        node_m: m_pred_sample,
        node_l: l_pred_sample,
        node_r1: r1_pred_sample,
        node_r2: r2_pred_sample,
        node_r3: r3_pred_sample}).sample(seed=hk.next_rng_key())

    # Reshape back to (num_samples, batch_size, y_dim), undoing the expanding
    # operation used for sampling.
    y_pred_sample = y_pred_sample.reshape(
        (num_prediction_samples, batch_size, -1))

    # Now form an array of shape (batch_size, y_dim) by taking an expectation
    # over the sample dimension. This represents the probability that the
    # result is in each class.
    y_pred_expectation = jnp.mean(y_pred_sample, axis=0)

    # Find out the predicted y, for later use in a confusion matrix.
    predicted_class_y_fair = utils.multinomial_class(y_pred_expectation)

    # *** NAIVE INFERENCE ***
    predicted_class_y_unfair = utils.multinomial_class(node_y.distribution)

    return predicted_class_y_fair, predicted_class_y_unfair

  return forward_fn, fair_inference_fn


def _loss_fn(
    forward_fn,
    beta: float,
    mmd_sample_size: int,
    constraint_multiplier: float,
    constraint_ratio: float,
    params: hk.Params,
    rng: jnp.ndarray,
    inputs: jnp.ndarray,
) -> jnp.ndarray:
  """Computes the training loss: the KLqp loss plus the MMD penalty.

  Args:
    forward_fn: Applied forward pass, called as `forward_fn(params, rng,
      inputs)` and returning a `CausalNetOutput`.
    beta: Scaling factor of the KL term.
    mmd_sample_size: Sample size forwarded to `utils.mmd_loss`.
    constraint_multiplier: Weight of the MMD penalty.
    constraint_ratio: Additional factor on the penalty; `Updater.update`
      passes 0. or 1. depending on the current step.
    params: Model parameters.
    rng: PRNG key for this evaluation of the loss.
    inputs: Batch of data.

  Returns:
    The scalar loss.
  """
  # Use separate keys for the forward pass and for each penalty term.
  forward_rng, mmd_rng = jax.random.split(rng)
  outputs = forward_fn(params, forward_rng, inputs)
  loss = _loss_klqp(outputs, beta)

  # The penalty is always computed, even when its weight is zero: the ratio is
  # a traced value inside the jitted update, so it cannot drive a Python
  # branch.
  penalty_weight = constraint_ratio * constraint_multiplier
  mmd_rngs = jax.random.split(mmd_rng, len(outputs.q_hidden_obs))
  constraint_loss = 0.
  for distribution, distribution_rng in zip(outputs.q_hidden_obs, mmd_rngs):
    constraint_loss += penalty_weight * utils.mmd_loss(
        distribution, outputs.is_male, mmd_sample_size, distribution_rng)

  return loss + constraint_loss


def _evaluate(
    fair_inference_fn,
    params: hk.Params,
    rng: jnp.ndarray,
    inputs: jnp.ndarray,
    batch_size: int,
    num_prediction_samples: int,
):
  """Perform evaluation of fair inference.

  Args:
    fair_inference_fn: Applied fair-inference function.
    params: Model parameters.
    rng: PRNG key.
    inputs: Batch of data.
    batch_size: Number of rows in `inputs`.
    num_prediction_samples: Number of latent samples drawn per row.

  Returns:
    Whatever `fair_inference_fn` returns: the fair and unfair predictions.
  """
  return fair_inference_fn(params, rng, inputs,
                           batch_size, num_prediction_samples)


def _loss_klqp(outputs: CausalNetOutput, beta: float) -> jnp.ndarray:
  """Compute the loss on data wrt params.

  Args:
    outputs: Result of the forward pass.
    beta: Scaling factor applied to the KL term.

  Returns:
    The batch mean of `beta * (log q(h|x) - log p(h)) - log p(x|h)`.
  """
  # For log probabilities computed from distributions, we need to sum along
  # the last axis, which takes the product of distributions for
  # multi-dimensional hidden variables.
  expected_log_q_hidden_obs = sum(
      jnp.sum(q_hidden_obs.log_prob(hidden_sample), axis=1)
      for q_hidden_obs, hidden_sample in zip(
          outputs.q_hidden_obs, outputs.hidden_samples))

  assert expected_log_q_hidden_obs.ndim == 1

  log_p_hidden = sum(
      jnp.sum(p_hidden.log_prob(hidden_sample), axis=1)
      for p_hidden, hidden_sample in zip(
          outputs.p_hidden, outputs.hidden_samples))

  assert outputs.log_p_obs_hidden.ndim == 1

  per_example_loss = (
      beta * (expected_log_q_hidden_obs - log_p_hidden) -
      outputs.log_p_obs_hidden)
  return jnp.mean(per_example_loss)


class Updater:
  """A stateless abstraction around an init_fn/update_fn pair.

  This extracts some common boilerplate from the training loop.
  """

  def __init__(self, net_init, loss_fn, eval_fn,
               optimizer: optax.GradientTransformation,
               constraint_turn_on_step):
    """Initialises the updater.

    Args:
      net_init: Function `(rng, data) -> params` initialising the network.
      loss_fn: Function `(constraint_ratio, params, rng, data) -> loss`.
      eval_fn: Function `(params, rng, inputs, batch_size,
        num_prediction_samples) -> (fair_pred, unfair_pred)`.
      optimizer: Optax optimizer used to update the parameters.
      constraint_turn_on_step: The constraint penalty is only active for
        steps strictly greater than this value.
    """
    self._net_init = net_init
    self._loss_fn = loss_fn
    self._eval_fn = eval_fn
    self._opt = optimizer
    self._constraint_turn_on_step = constraint_turn_on_step

  @functools.partial(jax.jit, static_argnums=0)
  def init(self, init_rng, data):
    """Initializes state of the updater."""
    params = self._net_init(init_rng, data)
    opt_state = self._opt.init(params)
    out = dict(
        step=np.array(0),
        rng=init_rng,
        opt_state=opt_state,
        params=params,
    )
    return out

  @functools.partial(jax.jit, static_argnums=0)
  def update(self, state: Mapping[str, Any], data: jnp.ndarray):
    """Updates the state using some data and returns metrics."""
    # Advance the key every step so that successive steps do not reuse the
    # same randomness.
    rng, step_rng = jax.random.split(state['rng'])
    params = state['params']

    constraint_ratio = (state['step'] > self._constraint_turn_on_step).astype(
        float)
    loss, g = jax.value_and_grad(self._loss_fn, argnums=1)(
        constraint_ratio, params, step_rng, data)
    updates, opt_state = self._opt.update(g, state['opt_state'])
    params = optax.apply_updates(params, updates)

    new_state = {
        'step': state['step'] + 1,
        'rng': rng,
        'opt_state': opt_state,
        'params': params,
    }

    new_metrics = {
        'step': state['step'],
        'loss': loss,
    }
    return new_state, new_metrics

  @functools.partial(jax.jit, static_argnums=(0, 3, 4))
  def evaluate(self, state: Mapping[str, Any], inputs: jnp.ndarray,
               batch_size: int, num_prediction_samples: int):
    """Evaluate fair inference."""
    rng = state['rng']
    params = state['params']
    fair_pred, unfair_pred = self._eval_fn(params, rng, inputs, batch_size,
                                           num_prediction_samples)
    return fair_pred, unfair_pred


def main(_):
  flags_config = FLAGS.config
  # Create the dataset.
  train_data, test_data = adult.read_all_data(FLAGS.dataset_dir)
  column_names = list(train_data.columns)
  train_input = build_input(train_data, flags_config.batch_size,
                            flags_config.num_steps)

  # Set up the model, loss, and updater.
  forward_fn, fair_inference_fn = build_forward_fn(
      train_data, column_names, flags_config.likelihood_multiplier)
  forward_fn = hk.transform(forward_fn)
  fair_inference_fn = hk.transform(fair_inference_fn)

  loss_fn = functools.partial(_loss_fn, forward_fn.apply,
                              flags_config.beta,
                              flags_config.mmd_sample_size,
                              flags_config.constraint_multiplier)
  eval_fn = functools.partial(_evaluate, fair_inference_fn.apply)

  optimizer = optax.adam(flags_config.learning_rate)

  updater = Updater(forward_fn.init, loss_fn, eval_fn,
                    optimizer, flags_config.constraint_turn_on_step)

  # Initialize parameters. The first batch is only used to trace shapes; it
  # is kept under its own name so the training dataframe is not shadowed.
  logging.info('Initializing parameters...')
  rng = jax.random.PRNGKey(42)
  init_batch = next(train_input)
  state = updater.init(rng, init_batch)

  # Training loop.
  logging.info('Starting train loop...')
  prev_time = time.time()
  for step in range(flags_config.num_steps):
    batch = next(train_input)
    state, stats = updater.update(state, batch)
    if step % LOG_EVERY == 0:
      steps_per_sec = LOG_EVERY / (time.time() - prev_time)
      prev_time = time.time()
      stats.update({'steps_per_sec': steps_per_sec})
      logging.info('%s', {k: float(v) for k, v in stats.items()})

  # Evaluate.
  # NOTE(review): evaluation uses `batch_size`; the config also defines
  # `prediction_batch_size`, which is not read here -- confirm which one is
  # intended.
  logging.info('Starting evaluation...')
  test_input = build_input(test_data, flags_config.batch_size,
                           training_steps=0,
                           shuffle_size=0)

  predicted_test_y = []
  corrected_test_y = []
  for eval_data in test_input:
    # Now run the fair prediction; this projects the input to the latent space
    # and then performs sampling.
    predicted_class_y_fair, predicted_class_y_unfair = updater.evaluate(
        state, eval_data, flags_config.batch_size,
        flags_config.num_prediction_samples)
    predicted_test_y.append(predicted_class_y_unfair)
    corrected_test_y.append(predicted_class_y_fair)
  logging.info('Finished evaluation')

  # Join together the predictions from each batch.
  test_y = np.concatenate(predicted_test_y, axis=0)
  tweaked_test_y = np.concatenate(corrected_test_y, axis=0)

  # Note the true values for computing accuracy and confusion matrices.
  y_true = test_data['income'].cat.codes
  # Make sure y_true is the same size as the predictions.
  y_true = y_true[:len(test_y)]

  test_accuracy = metrics.accuracy_score(y_true, test_y)
  tweaked_test_accuracy = metrics.accuracy_score(
      y_true, tweaked_test_y)

  # Print out accuracy and confusion matrices.
  logging.info('Accuracy (full model):                  %f', test_accuracy)
  logging.info('Confusion matrix:')
  logging.info('%s', metrics.confusion_matrix(y_true, test_y))
  logging.info('')

  logging.info('Accuracy (tweaked with baseline: Male): %f',
               tweaked_test_accuracy)
  logging.info('Confusion matrix:')
  logging.info('%s', metrics.confusion_matrix(y_true, tweaked_test_y))


if __name__ == '__main__':
  app.run(main)
def get_config():
  """Build and return the default experiment configuration."""
  # (name, value) pairs, assigned in this order onto a fresh ConfigDict.
  defaults = (
      # Number of training steps to perform.
      ('num_steps', 10000),
      # Batch size.
      ('batch_size', 128),
      # Learning rate.
      ('learning_rate', 0.01),
      # Number of samples to draw for prediction.
      ('num_prediction_samples', 500),
      # Batch size to use for prediction. Ideally as big as possible, but may
      # need to be reduced for memory reasons depending on the value of
      # `num_prediction_samples`.
      ('prediction_batch_size', 500),
      # Multiplier for the likelihood term in the loss.
      ('likelihood_multiplier', 5.),
      # Multiplier for the MMD constraint term in the loss.
      ('constraint_multiplier', 0.),
      # Scaling factor to use in KL term.
      ('beta', 1.0),
      # The number of samples we draw from each latent variable distribution.
      ('mmd_sample_size', 100),
      # Directory into which results should be placed. By default it is the
      # empty string, in which case no saving will occur. The directory
      # specified will be created if it does not exist.
      ('output_dir', ''),
      # The index of the step at which to turn on the constraint multiplier.
      # For steps prior to this the multiplier will be zero.
      ('constraint_turn_on_step', 0),
      # The random seed for tensorflow that is applied to the graph iff the
      # value is non-negative. By default the seed is not constrained.
      ('seed', -1),
      # When doing fair inference, don't sample when given a sample for the
      # baseline gender.
      ('baseline_passthrough', False),
  )

  config = ml_collections.ConfigDict()
  for field_name, default_value in defaults:
    setattr(config, field_name, default_value)
  return config
+ + The node needs to have a 'name', corresponding to the series within the + dataframe that should be used to populate it. + """ + + def __init__(self, distribution_module, parents=(), hidden=False): + """Initialise a `Node` instance. + + Args: + distribution_module: An instance of `DistributionModule`, a Haiku module + that is suitable for modelling the conditional distribution of this + node given any parents. + parents: `Iterable`, optional. A (potentially nested) collection of nodes + which are direct ancestors of `self`. + hidden: `bool`, optional. Whether this node is hidden. Hidden nodes are + permitted to not have corresponding observations. + """ + parents = tree.flatten(parents) + + self._distribution_module = distribution_module + self._column = distribution_module.column + self._index = distribution_module.index + self._hidden = hidden + self._observed_value = None + + # When implementing the path-specific counterfactual fairness algorithm, + # we need the concept of a distribution conditional on the 'corrected' + # values of the parents. This is achieved via the 'node_to_replacement' + # argument of make_distribution. + + # However, in order to work with the `fix_categorical` and `fix_continuous` + # functions, we need to assign counterfactual values for parents at + # evaluation time. + self._parent_to_value = collections.OrderedDict( + (parent, None) for parent in parents) + + # This is the conditional distribution using no replacements, i.e. it is + # conditioned on the observed values of parents. 
+ self._distribution = None + + def __repr__(self): + return 'Node<{}>'.format(self.name) + + @property + def dim(self): + """The dimensionality of this node.""" + return self._distribution_module.dim + + @property + def name(self): + return self._column + + @property + def hidden(self): + return self._hidden + + @property + def observed_value(self): + return self._observed_value + + def find_ancestor(self, name): + """Returns an ancestor node with the given name.""" + if self.name == name: + return self + for parent in self.parents: + found = parent.find_ancestor(name) + if found is not None: + return found + + @property + def parents(self): + return tuple(self._parent_to_value) + + @property + def distribution_module(self): + return self._distribution_module + + @property + def distribution(self): + self._distribution = self.make_distribution() + return self._distribution + + def make_distribution(self, node_to_replacement=None): + """Make a conditional distribution for this node | parents. + + By default we use values (representing 'real data') from the parent + nodes as inputs to the distribution, however we can alternatively swap out + any of these for arbitrary arrays by specifying `node_to_replacement`. + + Args: + node_to_replacement: `None`, `dict: Node -> DeviceArray`. + If specified, use the indicated array. 
+ + Returns: + `tfp.distributions.Distribution` + """ + cur_parent_to_value = self._parent_to_value + self._parent_to_value = collections.OrderedDict( + (parent, parent.observed_value) for parent in cur_parent_to_value.keys() + ) + + if node_to_replacement is None: + parent_values = self._parent_to_value.values() + return self._distribution_module(*parent_values) + args = [] + for node, value in self._parent_to_value.items(): + if node in node_to_replacement: + replacement = node_to_replacement[node] + args.append(replacement) + else: + args.append(value) + return self._distribution_module(*args) + + def populate(self, data, node_to_replacement=None): + """Given a dataframe, populate node data. + + If the Node does not have data present, this is taken to be a sign of + a) An error if the node is not hidden. + b) Fine if the node is hidden. + In case a) an exception will be raised, and in case b) observed)v will + not be mutated. + + Args: + data: tf.data.Dataset + node_to_replacement: None | dict(Node -> array). If not None, use the + given ndarray data rather than extracting data from the frame. This is + only considered when looking at the inputs to a distribution. + + Raises: + RuntimeError: If `data` doesn't contain the necessary feature, and the + node is not hidden. + """ + column = self._column + hidden = self._hidden + replacement = None + if node_to_replacement is not None and self in node_to_replacement: + replacement = node_to_replacement[self] + + if replacement is not None: + # If a replacement is present, this takes priority over any other + # consideration. + self._observed_value = replacement + return + + if self._index < 0: + if not hidden: + raise RuntimeError( + 'Node {} is not hidden, and column {} is not in the frame.'.format( + self, column)) + # Nothing to do - there is no data, and the node is hidden. + return + + # Produce the observed value for this node. 
+ self._observed_value = self._distribution_module.prepare_data(data) + + +class DistributionModule(hk.Module): + """Common base class for a Haiku module representing a distribution. + + This provides some additional functionality common to all modules that would + be used as arguments to the `Node` class above. + """ + + def __init__(self, column, index, dim): + """Initialise a `DistributionModule` instance. + + Args: + column: `string`. The name of the random variable to which this + distribution corresponds, and should match the name of the series in + the pandas dataframe. + index: `int`. The index of the corresponding feature in the dataset. + dim: `int`. The output dimensionality of the distribution. + """ + super().__init__(name=column.replace('-', '_')) + self._column = column + self._index = index + self._dim = dim + + @property + def dim(self): + """The output dimensionality of this distribution.""" + return self._dim + + @property + def column(self): + return self._column + + @property + def index(self): + return self._index + + def prepare_data(self, data): + """Given a general tensor, return an ndarray if required. + + This method implements the functionality delegated from + `Node._prepare_data`, and it is expected that subclasses will override the + implementation appropriately. + + Args: + data: A tf.data.Dataset. + + Returns: + `np.ndarray` of appropriately converted values for this series. + """ + return data[:, [self._index]] + + def _package_args(self, args): + """Concatenate args into a single tensor. + + Args: + args: `List[DeviceArray]`, length > 0. + Each array is of shape (batch_size, ?) or (batch_size,). The former + will occur if looking at e.g. a one-hot encoded categorical variable, + and the latter in the case of a continuous variable. + + Returns: + `DeviceArray`, (batch_size, num_values). 
+ """ + return jnp.concatenate(args, axis=1) + + +class Gaussian(DistributionModule): + """A Haiku module that maps some inputs into a normal distribution.""" + + def __init__(self, column, index, dim=1, hidden_shape=(), + hidden_activation=jnp.tanh, scale=None): + """Initialise a `Gaussian` instance with some dimensionality.""" + super(Gaussian, self).__init__(column, index, dim) + self._hidden_shape = tuple(hidden_shape) + self._hidden_activation = hidden_activation + self._scale = scale + + self._loc_net = hk.nets.MLP(self._hidden_shape + (self._dim,), + activation=self._hidden_activation) + + def __call__(self, *args): + if args: + # There are arguments - these represent the variables on which we are + # conditioning. We set the mean of the output distribution to be a + # function of these values, parameterised with an MLP. + concatenated_inputs = self._package_args(args) + loc = self._loc_net(concatenated_inputs) + else: + # There are no arguments, so instead have a learnable location parameter. + loc = hk.get_parameter('loc', shape=[self._dim], init=jnp.zeros) + + if self._scale is None: + # The scale has not been explicitly specified, in which case it is left + # to be single value, i.e. not a function of the conditioning set. + log_var = hk.get_parameter('log_var', shape=[self._dim], init=jnp.ones) + scale = jnp.sqrt(jnp.exp(log_var)) + else: + scale = jnp.float32(self._scale) + + return tfp.distributions.Normal(loc=loc, scale=scale) + + def prepare_data(self, data): + # For continuous data, we ensure the data is of dtype float32, and + # additionally that the resulant shape is (num_examples, 1) + # Note that this implementation only works for dim=1, however this is + # currently also enforced by the fact that pandas series cannot be + # multidimensional. 
+ result = data[:, [self.index]].astype(jnp.float32) + if len(result.shape) == 1: + result = jnp.expand_dims(result, axis=1) + return result + + +class GaussianMixture(DistributionModule): + """A Haiku module that maps some inputs into a mixture of normals.""" + + def __init__(self, column, num_components, dim=1): + """Initialise a `GaussianMixture` instance with some dimensionality. + + Args: + column: `string`. The name of the column. + num_components: `int`. The number of gaussians in this mixture. + dim: `int`. The dimensionality of the variable. + """ + super().__init__(column, -1, dim) + self._num_components = num_components + self._loc_net = hk.nets.MLP([self._dim]) + self._categorical_logits_module = hk.nets.MLP([self._num_components]) + + def __call__(self, *args): + # Define component Gaussians to be independent functions of args. + locs = [] + scales = [] + for _ in range(self._num_components): + loc = hk.get_parameter('loc', shape=[self._dim], init=jnp.zeros) + log_var = hk.get_parameter('log_var', shape=[self._dim], init=jnp.ones) + scale = jnp.sqrt(jnp.exp(log_var)) + locs.extend(loc) + scales.extend(scale) + + # Define the Categorical distribution which switches between these + categorical_logits = hk.get_parameter('categorical_logits', + shape=[self._num_components], + init=jnp.zeros) + + # Enforce positivity in the logits + categorical_logits = jax.nn.sigmoid(categorical_logits) + + # If we have a multidimensional node, then the normal distributions above + # have a batch shape of (dim,). 
We want to select between these using the + # categorical distribution, so tile the logits to match this shape + expanded_logits = jnp.repeat(categorical_logits, self._dim) + + categorical = tfp.distributions.Categorical(logits=expanded_logits) + + return tfp.distributions.MixtureSameFamily( + mixture_distribution=categorical, + components_distribution=tfp.distributions.Normal( + loc=locs, scale=scales)) + + +class MLPMultinomial(DistributionModule): + """A Haiku module that consists of an MLP + multinomial distribution.""" + + def __init__(self, column, index, dim, hidden_shape=(), + hidden_activation=jnp.tanh): + """Initialise an MLPMultinomial instance. + + Args: + column: `string`. Name of the corresponding dataframe column. + index: `int`. The index of the input data for this feature. + dim: `int`. Number of categories. + hidden_shape: `Iterable`, optional. Shape of hidden layers. + hidden_activation: `Callable`, optional. Non-linearity for hidden + layers. + """ + super(MLPMultinomial, self).__init__(column, index, dim) + self._hidden_shape = tuple(hidden_shape) + self._hidden_activation = hidden_activation + self._logit_net = hk.nets.MLP(self._hidden_shape + (self.dim,), + activation=self._hidden_activation) + + @classmethod + def from_frame(cls, data, column, hidden_shape=()): + """Create an MLPMultinomial instance from a pandas dataframe and column.""" + # Helper method that uses the dataframe to work out how many categories + # are in the given column. The dataframe is not used for any other purpose. 
+ if not isinstance(data[column].dtype, pd.api.types.CategoricalDtype): + raise ValueError('{} is not categorical.'.format(column)) + index = list(data.columns).index(column) + num_categories = len(data[column].cat.categories) + return cls(column, index, num_categories, hidden_shape) + + def __call__(self, *args): + if args: + concatenated_inputs = self._package_args(args) + logits = self._logit_net(concatenated_inputs) + else: + logits = hk.get_parameter('b', shape=[self.dim], init=jnp.zeros) + return tfp.distributions.Multinomial(logits=logits, total_count=1.0) + + def prepare_data(self, data): + # For categorical data, we convert to a one-hot representation using the + # pandas category 'codes'. These are integers, and will have a definite + # ordering that is identical between runs. + codes = data[:, self.index] + codes = codes.astype(jnp.int32) + return jnp.eye(self.dim)[codes] + + +def populate(nodes, dataframe, node_to_replacement=None): + """Populate observed values for nodes.""" + for node in nodes: + node.populate(dataframe, node_to_replacement=node_to_replacement) diff --git a/counterfactual_fairness/download_dataset.sh b/counterfactual_fairness/download_dataset.sh new file mode 100644 index 0000000..aaa305c --- /dev/null +++ b/counterfactual_fairness/download_dataset.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#
# Usage:
#   sh download_dataset.sh ${OUTPUT_DIR}
# Example:
#   sh download_dataset.sh uci_adult

set -e

# Abort with a usage message when the output directory argument is missing,
# instead of calling mkdir/wget with an empty path.
OUTPUT_DIR="${1:?Usage: download_dataset.sh OUTPUT_DIR}"

BASE_URL="https://archive.ics.uci.edu/ml/machine-learning-databases/adult/"

# Quoted so that directories containing spaces work.
mkdir -p "${OUTPUT_DIR}"
for file in adult.data adult.test
do
  wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}"
done

# ----------------------------------------------------------------------------
# File: counterfactual_fairness/requirements.txt (new file, mode 100644)
# ----------------------------------------------------------------------------
wheel
absl-py>=0.12.0
dm-haiku>=0.0.4
optax>=0.0.8
ml_collections
numpy>=1.16.4
tensorflow>=2.5.0
tensorflow-datasets>=4.3.0
tensorflow_probability
pandas
# `sklearn` is only a deprecated placeholder on PyPI; the real distribution is
# `scikit-learn`.
scikit-learn
# A local version label (`+cuda110`) is not permitted with `>=` (PEP 440).
# Install a CUDA build explicitly from the JAX wheel index if one is needed.
jaxlib>=0.1.67

# ----------------------------------------------------------------------------
# File: counterfactual_fairness/run_adult_pscf.sh (new file, mode 100755)
# ----------------------------------------------------------------------------
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Fail on any error.
set -e

# Display commands being run.
set -x

TMP_DIR="$(mktemp -d)"
# Remove the temporary directory even when a later command fails under
# `set -e`; `-f` keeps this harmless if the directory is already gone.
trap 'rm -rf "${TMP_DIR}"' EXIT

virtualenv --python=python3.6 "${TMP_DIR}/env"
source "${TMP_DIR}/env/bin/activate"

# Install dependencies.
pip install --upgrade -r counterfactual_fairness/requirements.txt

# Download minimal dataset. The files must land in DATA_DIR because that is
# the directory handed to the training script below.
DATA_DIR="${TMP_DIR}/adult"
bash counterfactual_fairness/download_dataset.sh "${DATA_DIR}"

# Train for a few steps.
python -m counterfactual_fairness.adult_pscf \
  --config="counterfactual_fairness/adult_pscf_config.py" \
  --dataset_dir="${DATA_DIR}" \
  --num_steps=10

# Clean up.
rm -r "${TMP_DIR}"
echo "Test run complete."

# ----------------------------------------------------------------------------
# File: counterfactual_fairness/utils.py (new file, mode 100644)
# ----------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Common utilities."""

from typing import Optional, Union

from jax import random
import jax.numpy as jnp
import pandas as pd
import tensorflow as tf
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


def get_dataset(dataset: pd.DataFrame,
                batch_size: int,
                shuffle_size: int = 10000,
                num_epochs: Optional[int] = None) -> tf.data.Dataset:
  """Makes a tf.Dataset with correct preprocessing.

  Categorical columns are replaced by their integer category codes, then the
  frame's values are sliced row-wise, optionally shuffled, repeated and
  batched.

  Args:
    dataset: The frame to convert. It is not modified.
    batch_size: Number of rows per batch; a final partial batch is dropped.
    shuffle_size: Size of the shuffle buffer. No shuffling if not positive.
    num_epochs: Number of repetitions, forwarded to `Dataset.repeat`.

  Returns:
    The batched `tf.data.Dataset`.
  """
  dataset_copy = dataset.copy()
  for column in dataset.columns:
    if dataset[column].dtype.name == 'category':
      # Replace the whole column. Assigning through `.loc[:, column]` writes
      # in place on recent pandas versions, which keeps the categorical dtype
      # and therefore cannot hold the integer codes.
      dataset_copy[column] = dataset[column].cat.codes
  ds = tf.data.Dataset.from_tensor_slices(dataset_copy.values)
  if shuffle_size > 0:
    ds = ds.shuffle(shuffle_size, reshuffle_each_iteration=True)
  ds = ds.repeat(num_epochs)
  return ds.batch(batch_size, drop_remainder=True)


def multinomial_mode(
    distribution_or_probs: Union[tfd.Distribution, jnp.ndarray]
) -> jnp.ndarray:
  """Calculates the (one-hot) mode of a multinomial distribution.

  Args:
    distribution_or_probs:
      `tfp.distributions.Distribution` | array.
      If the former, its `probs_parameter()` is used, and it is taken to
      represent a distribution over categories. If the latter, these are
      taken to be the probabilities of categories directly.
      In either case, it is assumed that `probs` will be shape
      (batch_size, dim).

  Returns:
    array, float32, (batch_size, dim).
    The mode of the distribution - this will be in one-hot form, but contain
    multiple non-zero entries (each 1/k) in the event that k probabilities
    are joint-highest.
  """
  if isinstance(distribution_or_probs, tfd.Distribution):
    probs = distribution_or_probs.probs_parameter()
  else:
    probs = distribution_or_probs
  max_prob = jnp.max(probs, axis=1, keepdims=True)
  mode = jnp.int32(jnp.equal(probs, max_prob))
  # Normalise so that every row sums to one, even when there are ties.
  return jnp.float32(mode / jnp.sum(mode, axis=1, keepdims=True))


def multinomial_class(
    distribution_or_probs: Union[tfd.Distribution, jnp.ndarray]
) -> jnp.ndarray:
  """Computes the mode class of a multinomial distribution.

  Args:
    distribution_or_probs:
      `tfp.distributions.Distribution` | array.
      As for `multinomial_mode`. For a distribution the argmax is taken over
      its `logits_parameter()`.

  Returns:
    Integer array, (batch_size,).
    For each element in the batch, the index of the class with highest
    probability (the first such index in the event of ties).
  """
  if isinstance(distribution_or_probs, tfd.Distribution):
    return jnp.argmax(distribution_or_probs.logits_parameter(), axis=1)
  return jnp.argmax(distribution_or_probs, axis=1)


def multinomial_mode_ndarray(probs: jnp.ndarray) -> jnp.ndarray:
  """Calculates the (one-hot) mode from an ndarray of class probabilities.

  Performs the same computation as the array branch of `multinomial_mode`.

  Args:
    probs: array, (batch_size, dim). Probabilities for each class, for
      each element in a batch.

  Returns:
    array, float32, (batch_size, dim).
  """
  max_prob = jnp.amax(probs, axis=1, keepdims=True)
  mode = jnp.equal(probs, max_prob).astype(jnp.int32)
  return (mode / jnp.sum(mode, axis=1, keepdims=True)).astype(jnp.float32)


def multinomial_accuracy(
    distribution_or_probs: Union[tfd.Distribution, jnp.ndarray],
    data: jnp.ndarray) -> jnp.ndarray:
  """Compute the accuracy, averaged over a batch of data.

  Args:
    distribution_or_probs:
      `tfp.distributions.Distribution` | array.
      As for functions above.
    data: array. Reference data, of shape (batch_size, dim); presumably
      one-hot labels.

  Returns:
    array, float32, ().
    Overall scalar accuracy.
  """
  return jnp.mean(
      jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))


def softmax_ndarray(logits: jnp.ndarray) -> jnp.ndarray:
  """Softmax over axis 1 of a rank-2 array of logits."""
  assert len(logits.shape) == 2
  # Subtract the row-wise maximum for better numerical stability.
  s = jnp.max(logits, axis=1, keepdims=True)
  e_x = jnp.exp(logits - s)
  return e_x / jnp.sum(e_x, axis=1, keepdims=True)


def get_samples(distribution, num_samples, seed=None):
  """Given a batched distribution, compute samples and reshape along batch.

  That is, we have a distribution of shape (batch_size, ...), where each element
  of the tensor is independent. We draw num_samples from each component, which
  gives a tensor of shape (num_samples, batch_size, ...), and then merge all
  leading axes into one.

  Args:
    distribution: `tfp.distributions.Distribution`. The distribution from which
      to sample.
    num_samples: `Integral`. The number of samples.
    seed: The seed that is forwarded to the call to `distribution.sample`.
      Defaults to `None`.
      NOTE(review): the JAX substrate of TFP presumably requires a PRNG key
      here - confirm that callers always pass one.

  Returns:
    array, (num_samples * batch_size, last_dim).
    Samples for each element of the batch.
  """
  # Obtain the sample from the distribution, which will be of shape
  # [num_samples] + batch_shape + event_shape.
  sample = distribution.sample(num_samples, seed=seed)

  # Keep the last axis and combine all leading axes through a reshape.
  return sample.reshape((-1, sample.shape[-1]))


def mmd_loss(distribution: tfd.Distribution,
             is_a: jnp.ndarray,
             num_samples: int,
             rng: jnp.ndarray,
             num_random_features: int = 50,
             gamma: float = 1.):
  """Given two distributions, compute the Maximum Mean Discrepancy (MMD).

  More exactly, this uses the 'FastMMD' approximation, a.k.a. 'Random Fourier
  Features'. See the description, for example, in sections 2.3.1 and 2.4 of
  https://arxiv.org/pdf/1511.00830.pdf.

  Args:
    distribution: Distribution whose `sample()` method will return a
      DeviceArray of shape (batch_size, dim).
    is_a: A boolean array indicating which elements of the batch correspond
      to class A (the remaining indices correspond to class B).
    num_samples: The number of samples to draw from `distribution`.
    rng: JAX PRNG key provided by the user.
    num_random_features: The number of random fourier features
      used in the expansion.
    gamma: The value of gamma in the Gaussian MMD kernel.

  Returns:
    array, shape ().
    The scalar MMD value for samples taken from the given distributions.
  """
  if distribution.event_shape == ():  # pylint: disable=g-explicit-bool-comparison
    dim_x = distribution.batch_shape[1]
  else:
    dim_x, = distribution.event_shape

  # A JAX key must not be consumed more than once, otherwise the draws are
  # correlated; derive one independent key per random quantity.
  sample_key, w_key, b_key = random.split(rng, 3)

  # Obtain samples from the distribution, which will be of shape
  # [num_samples] + batch_shape + event_shape.
  samples = distribution.sample(num_samples, seed=sample_key)

  w = random.normal(w_key, shape=(dim_x, num_random_features))
  b = random.uniform(b_key, shape=(num_random_features,),
                     minval=0, maxval=2 * jnp.pi)

  def features(x):
    """Compute the kitchen sink feature."""
    # We need to contract last axis of x with first of W - do this with
    # tensordot. The result has shape:
    #   (?, ?, num_random_features)
    return jnp.sqrt(2 / num_random_features) * jnp.cos(
        jnp.sqrt(2 / gamma) * jnp.tensordot(x, w, axes=1) + b)

  # The first axis represents the samples from the distribution,
  # second axis represents the batch_size.
  exp_features = features(samples)
  # Swap axes so that batch_size is the last dimension to be compatible
  # with the shape of is_a at the next step.
  exp_features_reshaped = jnp.swapaxes(exp_features, 1, 2)
  # Current dimensions [num_samples, num_random_features, batch_size].
  exp_features_reshaped_a = jnp.where(is_a, exp_features_reshaped, 0)
  exp_features_reshaped_b = jnp.where(is_a, 0, exp_features_reshaped)
  # NOTE(review): these means divide by the full batch size, with the other
  # class contributing zeros, rather than by the number of elements in each
  # class. Confirm this scaling is intended for unbalanced batches.
  # Each of these has shape (num_random_features,).
  exp_features_a = jnp.mean(exp_features_reshaped_a, axis=(0, 2))
  exp_features_b = jnp.mean(exp_features_reshaped_b, axis=(0, 2))

  assert exp_features_a.shape == (num_random_features,)
  difference = exp_features_a - exp_features_b

  # Compute the squared norm. Shape ().
  return jnp.tensordot(difference, difference, axes=1)


def mmd_loss_exact(distribution_a, distribution_b, num_samples, gamma=1.,
                   rng=None):
  """Exact estimate of MMD.

  Args:
    distribution_a: First distribution; see `get_samples` for shapes.
    distribution_b: Second distribution, with the same event shape and the
      same trailing batch shape as `distribution_a`.
    num_samples: The number of samples to draw from each distribution.
    gamma: The value of gamma in the Gaussian kernel exp(-|x - y|^2 / gamma).
    rng: Optional JAX PRNG key. When given, it is split into one key per
      distribution; when `None`, `None` is forwarded as the sampling seed.

  Returns:
    array, shape (). The MMD estimate of equation 7 of arxiv 1511.00830.
  """
  assert distribution_a.event_shape == distribution_b.event_shape
  assert distribution_a.batch_shape[1:] == distribution_b.batch_shape[1:]

  if rng is None:
    key_a = key_b = None
  else:
    key_a, key_b = random.split(rng)

  # shape (num_samples * batch_size_a, dim_x)
  samples_a = get_samples(distribution_a, num_samples, seed=key_a)
  # shape (num_samples * batch_size_b, dim_x)
  samples_b = get_samples(distribution_b, num_samples, seed=key_b)

  def kernel_mean(x, y):
    """Mean of the Gaussian kernel over all pairs (x_i, y_j)."""
    # Broadcast to all pairs: (size_x, 1, dim_x) - (1, size_y, dim_x).
    # Every term of the estimator needs its own pairwise matrix; comparing a
    # tensor with itself element-wise would give a constant kernel of 1.
    diff = jnp.expand_dims(x, axis=1) - jnp.expand_dims(y, axis=0)

    # Contract over dim_x.
    exponent = - jnp.einsum('ijk,ijk->ij', diff, diff) / gamma

    # This has shape (size_x, size_y).
    kernel_matrix = jnp.exp(exponent)

    # Shape ().
    return jnp.mean(kernel_matrix)

  # Equation 7 from arxiv 1511.00830
  return (
      kernel_mean(samples_a, samples_a)
      + kernel_mean(samples_b, samples_b)
      - 2 * kernel_mean(samples_a, samples_b))


def scalar_log_prob(distribution, val):
  """Compute the log_prob per batch entry.

  It is conceptually similar to:

    jnp.sum(distribution.log_prob(val), axis=1)

  However, classes like `tfp.distributions.Multinomial` have a log_prob which
  returns a tensor of shape (batch_size,), which will cause the above
  incantation to fail. In these cases we fall back to returning just:

    distribution.log_prob(val)

  Args:
    distribution: `tfp.distributions.Distribution` which implements log_prob.
    val: array, (batch_size, dim).

  Returns:
    array, (batch_size,).
    If the result of log_prob has a trailing dimension, we perform a sum
    over it.

  Raises:
    ValueError: If distribution.log_prob(val) is not of rank 1 or 2.
  """
  log_prob_val = distribution.log_prob(val)
  rank = len(log_prob_val.shape)
  if rank == 1:
    return log_prob_val
  if rank != 2:
    raise ValueError('log_prob_val has unexpected shape {}.'.format(
        log_prob_val.shape))
  return jnp.sum(log_prob_val, axis=1)


# ----------------------------------------------------------------------------
# File: counterfactual_fairness/variational.py (new file, mode 100644)
# ----------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions and classes for performing variational inference."""

from typing import Callable, Iterable, Optional

import haiku as hk
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


class Variational(hk.Module):
  """A module representing the variational distribution q(H | *O).

  H is assumed to be a continuous variable.
  """

  def __init__(self,
               common_layer_sizes: Iterable[int],
               activation: Callable[[jnp.ndarray],
                                    jnp.ndarray] = jnp.tanh,
               output_dim: int = 1,
               name: Optional[str] = None):
    """Initialises a `Variational` instance.

    Args:
      common_layer_sizes: The number of hidden units in the shared dense
        network layers.
      activation: Nonlinearity function to apply to each of the
        common layers.
      output_dim: The dimensionality of `H`.
      name: A name to assign to the module instance.
    """
    super().__init__(name=name)
    # Materialise the sizes so that a one-shot iterable (e.g. a generator)
    # is not left exhausted on the instance.
    self._common_layer_sizes = tuple(common_layer_sizes)
    self._activation = activation
    self._output_dim = output_dim

    self._linear_layers = [
        hk.Linear(layer_size)
        for layer_size in self._common_layer_sizes
    ]

    self._mean_output = hk.Linear(self._output_dim)
    self._log_var_output = hk.Linear(self._output_dim)

  def __call__(self, *args) -> tfd.Distribution:
    """Create a distribution for q(H | *O).

    Args:
      *args: `List[array]`. Corresponds to the values of whatever
        variables are in the conditional set *O.

    Returns:
      `tfd.Normal` instance whose mean and standard deviation both have shape
      (batch_size, output_dim).
    """
    # Flatten every input to (batch_size, -1) and join them along the feature
    # axis. No dtype conversion is performed here.
    input_ = [hk.Flatten()(arg) for arg in args]
    input_ = jnp.concatenate(input_, axis=1)

    # Create a common set of layers, then final layer separates mean & log_var
    for layer in self._linear_layers:
      input_ = layer(input_)
      input_ = self._activation(input_)

    # input_ now represents a tensor of shape (batch_size, final_layer_size).
    # This is now put through two final layers, one for the computation of each
    # of the mean and standard deviation of the resultant distribution.
    mean = self._mean_output(input_)
    log_var = self._log_var_output(input_)
    # exp(log_var / 2) equals sqrt(exp(log_var)) but only overflows for
    # log-variances twice as large.
    std = jnp.exp(0.5 * log_var)
    return tfd.Normal(mean, std)