Replaces references to jax.numpy.DeviceArray with jax.Array.

PiperOrigin-RevId: 515678285
This commit is contained in:
Peter Hawkins
2023-03-10 18:42:12 +00:00
committed by Saran Tunyasuvunakool
parent 0824c28deb
commit d988ff1bf2
5 changed files with 21 additions and 19 deletions
+9 -8
View File
@@ -16,6 +16,7 @@
from typing import Optional, Union from typing import Optional, Union
import jax
from jax import random from jax import random
import jax.numpy as jnp import jax.numpy as jnp
import pandas as pd import pandas as pd
@@ -42,8 +43,8 @@ def get_dataset(dataset: pd.DataFrame,
def multinomial_mode( def multinomial_mode(
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray] distribution_or_probs: Union[tfd.Distribution, jax.Array]
) -> jnp.DeviceArray: ) -> jax.Array:
"""Calculates the (one-hot) mode of a multinomial distribution. """Calculates the (one-hot) mode of a multinomial distribution.
Args: Args:
@@ -71,8 +72,8 @@ def multinomial_mode(
def multinomial_class( def multinomial_class(
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray] distribution_or_probs: Union[tfd.Distribution, jax.Array]
) -> jnp.DeviceArray: ) -> jax.Array:
"""Computes the mode class of a multinomial distribution. """Computes the mode class of a multinomial distribution.
Args: Args:
@@ -90,7 +91,7 @@ def multinomial_class(
return jnp.argmax(distribution_or_probs, axis=1) return jnp.argmax(distribution_or_probs, axis=1)
def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray: def multinomial_mode_ndarray(probs: jax.Array) -> jax.Array:
"""Calculates the (one-hot) mode from an ndarray of class probabilities. """Calculates the (one-hot) mode from an ndarray of class probabilities.
Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays
@@ -109,7 +110,7 @@ def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
def multinomial_accuracy(distribution_or_probs: tfd.Distribution, def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
data: jnp.DeviceArray) -> jnp.DeviceArray: data: jax.Array) -> jax.Array:
"""Compute the accuracy, averaged over a batch of data. """Compute the accuracy, averaged over a batch of data.
Args: Args:
@@ -126,7 +127,7 @@ def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1)) jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))
def softmax_ndarray(logits: jnp.DeviceArray) -> jnp.DeviceArray: def softmax_ndarray(logits: jax.Array) -> jax.Array:
"""Softmax function, implemented for numpy ndarrays.""" """Softmax function, implemented for numpy ndarrays."""
assert len(logits.shape) == 2 assert len(logits.shape) == 2
# Normalise for better stability. # Normalise for better stability.
@@ -166,7 +167,7 @@ def get_samples(distribution, num_samples, seed=None):
def mmd_loss(distribution: tfd.Distribution, def mmd_loss(distribution: tfd.Distribution,
is_a: jnp.DeviceArray, is_a: jax.Array,
num_samples: int, num_samples: int,
rng: jnp.ndarray, rng: jnp.ndarray,
num_random_features: int = 50, num_random_features: int = 50,
+2 -2
View File
@@ -17,6 +17,7 @@
from typing import Callable, Iterable, Optional from typing import Callable, Iterable, Optional
import haiku as hk import haiku as hk
import jax
import jax.numpy as jnp import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp from tensorflow_probability.substrates import jax as tfp
@@ -31,8 +32,7 @@ class Variational(hk.Module):
def __init__(self, def __init__(self,
common_layer_sizes: Iterable[int], common_layer_sizes: Iterable[int],
activation: Callable[[jnp.DeviceArray], activation: Callable[[jax.Array], jax.Array] = jnp.tanh,
jnp.DeviceArray] = jnp.tanh,
output_dim: int = 1, output_dim: int = 1,
name: Optional[str] = None): name: Optional[str] = None):
"""Initialises a `Variational` instance. """Initialises a `Variational` instance.
+2 -1
View File
@@ -18,6 +18,7 @@ from typing import Tuple
import chex import chex
import haiku as hk import haiku as hk
import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from scipy.ndimage import interpolation from scipy.ndimage import interpolation
@@ -72,7 +73,7 @@ def load_deskewed_mnist(*a, **k):
class MeanStdEstimator(hk.Module): class MeanStdEstimator(hk.Module):
"""Online mean and standard deviation estimator using Welford's algorithm.""" """Online mean and standard deviation estimator using Welford's algorithm."""
def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]: def __call__(self, sample: jax.Array) -> Tuple[Array, Array]:
if len(sample.shape) > 1: if len(sample.shape) > 1:
raise ValueError("sample must be a rank 0 or 1 DeviceArray.") raise ValueError("sample must be a rank 0 or 1 DeviceArray.")
+2 -2
View File
@@ -17,11 +17,11 @@
from typing import Callable, Tuple, Union from typing import Callable, Tuple, Union
import jax.numpy as jnp import jax
import numpy as np import numpy as np
import optax import optax
TensorLike = Union[np.ndarray, jnp.DeviceArray] TensorLike = Union[np.ndarray, jax.Array]
ActivationFn = Callable[[TensorLike], TensorLike] ActivationFn = Callable[[TensorLike], TensorLike]
GatingFn = Callable[[TensorLike], TensorLike] GatingFn = Callable[[TensorLike], TensorLike]
+6 -6
View File
@@ -46,18 +46,18 @@ class RegressionLossConfig:
def _sigmoid_cross_entropy( def _sigmoid_cross_entropy(
logits: jnp.DeviceArray, logits: jax.Array,
labels: jnp.DeviceArray, labels: jax.Array,
) -> jnp.DeviceArray: ) -> jax.Array:
log_p = jax.nn.log_sigmoid(logits) log_p = jax.nn.log_sigmoid(logits)
log_not_p = jax.nn.log_sigmoid(-logits) log_not_p = jax.nn.log_sigmoid(-logits)
return -labels * log_p - (1. - labels) * log_not_p return -labels * log_p - (1. - labels) * log_not_p
def _softmax_cross_entropy( def _softmax_cross_entropy(
logits: jnp.DeviceArray, logits: jax.Array,
targets: jnp.DeviceArray, targets: jax.Array,
) -> jnp.DeviceArray: ) -> jax.Array:
logits = jax.nn.log_softmax(logits) logits = jax.nn.log_softmax(logits)
return -jnp.sum(targets * logits, axis=-1) return -jnp.sum(targets * logits, axis=-1)