diff --git a/counterfactual_fairness/utils.py b/counterfactual_fairness/utils.py index 8162e9b..877e74c 100644 --- a/counterfactual_fairness/utils.py +++ b/counterfactual_fairness/utils.py @@ -16,6 +16,7 @@ from typing import Optional, Union +import jax from jax import random import jax.numpy as jnp import pandas as pd @@ -42,8 +43,8 @@ def get_dataset(dataset: pd.DataFrame, def multinomial_mode( - distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray] - ) -> jnp.DeviceArray: + distribution_or_probs: Union[tfd.Distribution, jax.Array] + ) -> jax.Array: """Calculates the (one-hot) mode of a multinomial distribution. Args: @@ -71,8 +72,8 @@ def multinomial_mode( def multinomial_class( - distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray] -) -> jnp.DeviceArray: + distribution_or_probs: Union[tfd.Distribution, jax.Array] +) -> jax.Array: """Computes the mode class of a multinomial distribution. Args: @@ -90,7 +91,7 @@ def multinomial_class( return jnp.argmax(distribution_or_probs, axis=1) -def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray: +def multinomial_mode_ndarray(probs: jax.Array) -> jax.Array: """Calculates the (one-hot) mode from an ndarray of class probabilities. Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays @@ -109,7 +110,7 @@ def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray: def multinomial_accuracy(distribution_or_probs: tfd.Distribution, - data: jnp.DeviceArray) -> jnp.DeviceArray: + data: jax.Array) -> jax.Array: """Compute the accuracy, averaged over a batch of data. Args: @@ -126,7 +127,7 @@ def multinomial_accuracy(distribution_or_probs: tfd.Distribution, jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1)) -def softmax_ndarray(logits: jnp.DeviceArray) -> jnp.DeviceArray: +def softmax_ndarray(logits: jax.Array) -> jax.Array: """Softmax function, implemented for numpy ndarrays.""" assert len(logits.shape) == 2 # Normalise for better stability. @@ -166,7 +167,7 @@ def get_samples(distribution, num_samples, seed=None): def mmd_loss(distribution: tfd.Distribution, - is_a: jnp.DeviceArray, + is_a: jax.Array, num_samples: int, rng: jnp.ndarray, num_random_features: int = 50, diff --git a/counterfactual_fairness/variational.py b/counterfactual_fairness/variational.py index c40f03f..617fc32 100644 --- a/counterfactual_fairness/variational.py +++ b/counterfactual_fairness/variational.py @@ -17,6 +17,7 @@ from typing import Callable, Iterable, Optional import haiku as hk +import jax import jax.numpy as jnp from tensorflow_probability.substrates import jax as tfp @@ -31,8 +32,7 @@ class Variational(hk.Module): def __init__(self, common_layer_sizes: Iterable[int], - activation: Callable[[jnp.DeviceArray], - jnp.DeviceArray] = jnp.tanh, + activation: Callable[[jax.Array], jax.Array] = jnp.tanh, output_dim: int = 1, name: Optional[str] = None): """Initialises a `Variational` instance. diff --git a/gated_linear_networks/examples/utils.py b/gated_linear_networks/examples/utils.py index 2557e37..0997489 100644 --- a/gated_linear_networks/examples/utils.py +++ b/gated_linear_networks/examples/utils.py @@ -18,6 +18,7 @@ from typing import Tuple import chex import haiku as hk +import jax import jax.numpy as jnp import numpy as np from scipy.ndimage import interpolation @@ -72,7 +73,7 @@ def load_deskewed_mnist(*a, **k): class MeanStdEstimator(hk.Module): """Online mean and standard deviation estimator using Welford's algorithm.""" - def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]: + def __call__(self, sample: jax.Array) -> Tuple[Array, Array]: if len(sample.shape) > 1: raise ValueError("sample must be a rank 0 or 1 DeviceArray.") diff --git a/mmv/models/types.py b/mmv/models/types.py index bac7e52..82c3675 100644 --- a/mmv/models/types.py +++ b/mmv/models/types.py @@ -17,11 +17,11 @@ from typing import Callable, Tuple, Union -import jax.numpy as jnp +import jax import numpy as np import optax -TensorLike = Union[np.ndarray, jnp.DeviceArray] +TensorLike = Union[np.ndarray, jax.Array] ActivationFn = Callable[[TensorLike], TensorLike] GatingFn = Callable[[TensorLike], TensorLike] diff --git a/ogb_lsc/pcq/model.py b/ogb_lsc/pcq/model.py index 6bb2b7b..e361de7 100644 --- a/ogb_lsc/pcq/model.py +++ b/ogb_lsc/pcq/model.py @@ -46,18 +46,18 @@ class RegressionLossConfig: def _sigmoid_cross_entropy( - logits: jnp.DeviceArray, - labels: jnp.DeviceArray, -) -> jnp.DeviceArray: + logits: jax.Array, + labels: jax.Array, +) -> jax.Array: log_p = jax.nn.log_sigmoid(logits) log_not_p = jax.nn.log_sigmoid(-logits) return -labels * log_p - (1. - labels) * log_not_p def _softmax_cross_entropy( - logits: jnp.DeviceArray, - targets: jnp.DeviceArray, -) -> jnp.DeviceArray: + logits: jax.Array, + targets: jax.Array, +) -> jax.Array: logits = jax.nn.log_softmax(logits) return -jnp.sum(targets * logits, axis=-1)