Replaces references to jax.numpy.DeviceArray with jax.Array.

PiperOrigin-RevId: 515678285
This commit is contained in:
Peter Hawkins
2023-03-10 18:42:12 +00:00
committed by Saran Tunyasuvunakool
parent 0824c28deb
commit d988ff1bf2
5 changed files with 21 additions and 19 deletions
+9 -8
View File
@@ -16,6 +16,7 @@
from typing import Optional, Union
import jax
from jax import random
import jax.numpy as jnp
import pandas as pd
@@ -42,8 +43,8 @@ def get_dataset(dataset: pd.DataFrame,
def multinomial_mode(
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
distribution_or_probs: Union[tfd.Distribution, jax.Array]
) -> jax.Array:
"""Calculates the (one-hot) mode of a multinomial distribution.
Args:
@@ -71,8 +72,8 @@ def multinomial_mode(
def multinomial_class(
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
distribution_or_probs: Union[tfd.Distribution, jax.Array]
) -> jax.Array:
"""Computes the mode class of a multinomial distribution.
Args:
@@ -90,7 +91,7 @@ def multinomial_class(
return jnp.argmax(distribution_or_probs, axis=1)
def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
def multinomial_mode_ndarray(probs: jax.Array) -> jax.Array:
"""Calculates the (one-hot) mode from an ndarray of class probabilities.
Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays
@@ -109,7 +110,7 @@ def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
data: jnp.DeviceArray) -> jnp.DeviceArray:
data: jax.Array) -> jax.Array:
"""Compute the accuracy, averaged over a batch of data.
Args:
@@ -126,7 +127,7 @@ def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))
def softmax_ndarray(logits: jnp.DeviceArray) -> jnp.DeviceArray:
def softmax_ndarray(logits: jax.Array) -> jax.Array:
"""Softmax function, implemented for numpy ndarrays."""
assert len(logits.shape) == 2
# Normalise for better stability.
@@ -166,7 +167,7 @@ def get_samples(distribution, num_samples, seed=None):
def mmd_loss(distribution: tfd.Distribution,
is_a: jnp.DeviceArray,
is_a: jax.Array,
num_samples: int,
rng: jnp.ndarray,
num_random_features: int = 50,
+2 -2
View File
@@ -17,6 +17,7 @@
from typing import Callable, Iterable, Optional
import haiku as hk
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
@@ -31,8 +32,7 @@ class Variational(hk.Module):
def __init__(self,
common_layer_sizes: Iterable[int],
activation: Callable[[jnp.DeviceArray],
jnp.DeviceArray] = jnp.tanh,
activation: Callable[[jax.Array], jax.Array] = jnp.tanh,
output_dim: int = 1,
name: Optional[str] = None):
"""Initialises a `Variational` instance.