mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-21 14:33:42 +08:00
Replaces references to jax.numpy.DeviceArray with jax.Array.
PiperOrigin-RevId: 515678285
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
0824c28deb
commit
d988ff1bf2
@@ -16,6 +16,7 @@
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import jax
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
import pandas as pd
|
||||
@@ -42,8 +43,8 @@ def get_dataset(dataset: pd.DataFrame,
|
||||
|
||||
|
||||
def multinomial_mode(
|
||||
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
|
||||
) -> jnp.DeviceArray:
|
||||
distribution_or_probs: Union[tfd.Distribution, jax.Array]
|
||||
) -> jax.Array:
|
||||
"""Calculates the (one-hot) mode of a multinomial distribution.
|
||||
|
||||
Args:
|
||||
@@ -71,8 +72,8 @@ def multinomial_mode(
|
||||
|
||||
|
||||
def multinomial_class(
|
||||
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
|
||||
) -> jnp.DeviceArray:
|
||||
distribution_or_probs: Union[tfd.Distribution, jax.Array]
|
||||
) -> jax.Array:
|
||||
"""Computes the mode class of a multinomial distribution.
|
||||
|
||||
Args:
|
||||
@@ -90,7 +91,7 @@ def multinomial_class(
|
||||
return jnp.argmax(distribution_or_probs, axis=1)
|
||||
|
||||
|
||||
def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
|
||||
def multinomial_mode_ndarray(probs: jax.Array) -> jax.Array:
|
||||
"""Calculates the (one-hot) mode from an ndarray of class probabilities.
|
||||
|
||||
Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays
|
||||
@@ -109,7 +110,7 @@ def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
|
||||
|
||||
|
||||
def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
|
||||
data: jnp.DeviceArray) -> jnp.DeviceArray:
|
||||
data: jax.Array) -> jax.Array:
|
||||
"""Compute the accuracy, averaged over a batch of data.
|
||||
|
||||
Args:
|
||||
@@ -126,7 +127,7 @@ def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
|
||||
jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))
|
||||
|
||||
|
||||
def softmax_ndarray(logits: jnp.DeviceArray) -> jnp.DeviceArray:
|
||||
def softmax_ndarray(logits: jax.Array) -> jax.Array:
|
||||
"""Softmax function, implemented for numpy ndarrays."""
|
||||
assert len(logits.shape) == 2
|
||||
# Normalise for better stability.
|
||||
@@ -166,7 +167,7 @@ def get_samples(distribution, num_samples, seed=None):
|
||||
|
||||
|
||||
def mmd_loss(distribution: tfd.Distribution,
|
||||
is_a: jnp.DeviceArray,
|
||||
is_a: jax.Array,
|
||||
num_samples: int,
|
||||
rng: jnp.ndarray,
|
||||
num_random_features: int = 50,
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from tensorflow_probability.substrates import jax as tfp
|
||||
|
||||
@@ -31,8 +32,7 @@ class Variational(hk.Module):
|
||||
|
||||
def __init__(self,
|
||||
common_layer_sizes: Iterable[int],
|
||||
activation: Callable[[jnp.DeviceArray],
|
||||
jnp.DeviceArray] = jnp.tanh,
|
||||
activation: Callable[[jax.Array], jax.Array] = jnp.tanh,
|
||||
output_dim: int = 1,
|
||||
name: Optional[str] = None):
|
||||
"""Initialises a `Variational` instance.
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import Tuple
|
||||
|
||||
import chex
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from scipy.ndimage import interpolation
|
||||
@@ -72,7 +73,7 @@ def load_deskewed_mnist(*a, **k):
|
||||
class MeanStdEstimator(hk.Module):
|
||||
"""Online mean and standard deviation estimator using Welford's algorithm."""
|
||||
|
||||
def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]:
|
||||
def __call__(self, sample: jax.Array) -> Tuple[Array, Array]:
|
||||
if len(sample.shape) > 1:
|
||||
raise ValueError("sample must be a rank 0 or 1 DeviceArray.")
|
||||
|
||||
|
||||
+2
-2
@@ -17,11 +17,11 @@
|
||||
|
||||
from typing import Callable, Tuple, Union
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
import numpy as np
|
||||
import optax
|
||||
|
||||
TensorLike = Union[np.ndarray, jnp.DeviceArray]
|
||||
TensorLike = Union[np.ndarray, jax.Array]
|
||||
|
||||
ActivationFn = Callable[[TensorLike], TensorLike]
|
||||
GatingFn = Callable[[TensorLike], TensorLike]
|
||||
|
||||
@@ -46,18 +46,18 @@ class RegressionLossConfig:
|
||||
|
||||
|
||||
def _sigmoid_cross_entropy(
|
||||
logits: jnp.DeviceArray,
|
||||
labels: jnp.DeviceArray,
|
||||
) -> jnp.DeviceArray:
|
||||
logits: jax.Array,
|
||||
labels: jax.Array,
|
||||
) -> jax.Array:
|
||||
log_p = jax.nn.log_sigmoid(logits)
|
||||
log_not_p = jax.nn.log_sigmoid(-logits)
|
||||
return -labels * log_p - (1. - labels) * log_not_p
|
||||
|
||||
|
||||
def _softmax_cross_entropy(
|
||||
logits: jnp.DeviceArray,
|
||||
targets: jnp.DeviceArray,
|
||||
) -> jnp.DeviceArray:
|
||||
logits: jax.Array,
|
||||
targets: jax.Array,
|
||||
) -> jax.Array:
|
||||
logits = jax.nn.log_softmax(logits)
|
||||
return -jnp.sum(targets * logits, axis=-1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user