Replaces references to jax.numpy.DeviceArray with jax.Array.

PiperOrigin-RevId: 515678285
This commit is contained in:
Peter Hawkins
2023-03-10 18:42:12 +00:00
committed by Saran Tunyasuvunakool
parent 0824c28deb
commit d988ff1bf2
5 changed files with 21 additions and 19 deletions
+9 -8
View File
@@ -16,6 +16,7 @@
from typing import Optional, Union
import jax
from jax import random
import jax.numpy as jnp
import pandas as pd
@@ -42,8 +43,8 @@ def get_dataset(dataset: pd.DataFrame,
def multinomial_mode(
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
distribution_or_probs: Union[tfd.Distribution, jax.Array]
) -> jax.Array:
"""Calculates the (one-hot) mode of a multinomial distribution.
Args:
@@ -71,8 +72,8 @@ def multinomial_mode(
def multinomial_class(
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
distribution_or_probs: Union[tfd.Distribution, jax.Array]
) -> jax.Array:
"""Computes the mode class of a multinomial distribution.
Args:
@@ -90,7 +91,7 @@ def multinomial_class(
return jnp.argmax(distribution_or_probs, axis=1)
def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
def multinomial_mode_ndarray(probs: jax.Array) -> jax.Array:
"""Calculates the (one-hot) mode from an ndarray of class probabilities.
Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays
@@ -109,7 +110,7 @@ def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
data: jnp.DeviceArray) -> jnp.DeviceArray:
data: jax.Array) -> jax.Array:
"""Compute the accuracy, averaged over a batch of data.
Args:
@@ -126,7 +127,7 @@ def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))
def softmax_ndarray(logits: jnp.DeviceArray) -> jnp.DeviceArray:
def softmax_ndarray(logits: jax.Array) -> jax.Array:
"""Softmax function, implemented for numpy ndarrays."""
assert len(logits.shape) == 2
# Normalise for better stability.
@@ -166,7 +167,7 @@ def get_samples(distribution, num_samples, seed=None):
def mmd_loss(distribution: tfd.Distribution,
is_a: jnp.DeviceArray,
is_a: jax.Array,
num_samples: int,
rng: jnp.ndarray,
num_random_features: int = 50,
+2 -2
View File
@@ -17,6 +17,7 @@
from typing import Callable, Iterable, Optional
import haiku as hk
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
@@ -31,8 +32,7 @@ class Variational(hk.Module):
def __init__(self,
common_layer_sizes: Iterable[int],
activation: Callable[[jnp.DeviceArray],
jnp.DeviceArray] = jnp.tanh,
activation: Callable[[jax.Array], jax.Array] = jnp.tanh,
output_dim: int = 1,
name: Optional[str] = None):
"""Initialises a `Variational` instance.
+2 -1
View File
@@ -18,6 +18,7 @@ from typing import Tuple
import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from scipy.ndimage import interpolation
@@ -72,7 +73,7 @@ def load_deskewed_mnist(*a, **k):
class MeanStdEstimator(hk.Module):
"""Online mean and standard deviation estimator using Welford's algorithm."""
def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]:
def __call__(self, sample: jax.Array) -> Tuple[Array, Array]:
if len(sample.shape) > 1:
raise ValueError("sample must be a rank 0 or 1 DeviceArray.")
+2 -2
View File
@@ -17,11 +17,11 @@
from typing import Callable, Tuple, Union
import jax.numpy as jnp
import jax
import numpy as np
import optax
TensorLike = Union[np.ndarray, jnp.DeviceArray]
TensorLike = Union[np.ndarray, jax.Array]
ActivationFn = Callable[[TensorLike], TensorLike]
GatingFn = Callable[[TensorLike], TensorLike]
+6 -6
View File
@@ -46,18 +46,18 @@ class RegressionLossConfig:
def _sigmoid_cross_entropy(
logits: jnp.DeviceArray,
labels: jnp.DeviceArray,
) -> jnp.DeviceArray:
logits: jax.Array,
labels: jax.Array,
) -> jax.Array:
log_p = jax.nn.log_sigmoid(logits)
log_not_p = jax.nn.log_sigmoid(-logits)
return -labels * log_p - (1. - labels) * log_not_p
def _softmax_cross_entropy(
logits: jnp.DeviceArray,
targets: jnp.DeviceArray,
) -> jnp.DeviceArray:
logits: jax.Array,
targets: jax.Array,
) -> jax.Array:
logits = jax.nn.log_softmax(logits)
return -jnp.sum(targets * logits, axis=-1)