mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-28 02:35:47 +08:00
Replaces references to jax.numpy.DeviceArray with jax.Array.
PiperOrigin-RevId: 515678285
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
0824c28deb
commit
d988ff1bf2
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import jax
|
||||||
from jax import random
|
from jax import random
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -42,8 +43,8 @@ def get_dataset(dataset: pd.DataFrame,
|
|||||||
|
|
||||||
|
|
||||||
def multinomial_mode(
|
def multinomial_mode(
|
||||||
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
|
distribution_or_probs: Union[tfd.Distribution, jax.Array]
|
||||||
) -> jnp.DeviceArray:
|
) -> jax.Array:
|
||||||
"""Calculates the (one-hot) mode of a multinomial distribution.
|
"""Calculates the (one-hot) mode of a multinomial distribution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -71,8 +72,8 @@ def multinomial_mode(
|
|||||||
|
|
||||||
|
|
||||||
def multinomial_class(
|
def multinomial_class(
|
||||||
distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
|
distribution_or_probs: Union[tfd.Distribution, jax.Array]
|
||||||
) -> jnp.DeviceArray:
|
) -> jax.Array:
|
||||||
"""Computes the mode class of a multinomial distribution.
|
"""Computes the mode class of a multinomial distribution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -90,7 +91,7 @@ def multinomial_class(
|
|||||||
return jnp.argmax(distribution_or_probs, axis=1)
|
return jnp.argmax(distribution_or_probs, axis=1)
|
||||||
|
|
||||||
|
|
||||||
def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
|
def multinomial_mode_ndarray(probs: jax.Array) -> jax.Array:
|
||||||
"""Calculates the (one-hot) mode from an ndarray of class probabilities.
|
"""Calculates the (one-hot) mode from an ndarray of class probabilities.
|
||||||
|
|
||||||
Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays
|
Equivalent to `multinomial_mode` above, but implemented for numpy ndarrays
|
||||||
@@ -109,7 +110,7 @@ def multinomial_mode_ndarray(probs: jnp.DeviceArray) -> jnp.DeviceArray:
|
|||||||
|
|
||||||
|
|
||||||
def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
|
def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
|
||||||
data: jnp.DeviceArray) -> jnp.DeviceArray:
|
data: jax.Array) -> jax.Array:
|
||||||
"""Compute the accuracy, averaged over a batch of data.
|
"""Compute the accuracy, averaged over a batch of data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -126,7 +127,7 @@ def multinomial_accuracy(distribution_or_probs: tfd.Distribution,
|
|||||||
jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))
|
jnp.sum(multinomial_mode(distribution_or_probs) * data, axis=1))
|
||||||
|
|
||||||
|
|
||||||
def softmax_ndarray(logits: jnp.DeviceArray) -> jnp.DeviceArray:
|
def softmax_ndarray(logits: jax.Array) -> jax.Array:
|
||||||
"""Softmax function, implemented for numpy ndarrays."""
|
"""Softmax function, implemented for numpy ndarrays."""
|
||||||
assert len(logits.shape) == 2
|
assert len(logits.shape) == 2
|
||||||
# Normalise for better stability.
|
# Normalise for better stability.
|
||||||
@@ -166,7 +167,7 @@ def get_samples(distribution, num_samples, seed=None):
|
|||||||
|
|
||||||
|
|
||||||
def mmd_loss(distribution: tfd.Distribution,
|
def mmd_loss(distribution: tfd.Distribution,
|
||||||
is_a: jnp.DeviceArray,
|
is_a: jax.Array,
|
||||||
num_samples: int,
|
num_samples: int,
|
||||||
rng: jnp.ndarray,
|
rng: jnp.ndarray,
|
||||||
num_random_features: int = 50,
|
num_random_features: int = 50,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
from typing import Callable, Iterable, Optional
|
from typing import Callable, Iterable, Optional
|
||||||
|
|
||||||
import haiku as hk
|
import haiku as hk
|
||||||
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from tensorflow_probability.substrates import jax as tfp
|
from tensorflow_probability.substrates import jax as tfp
|
||||||
|
|
||||||
@@ -31,8 +32,7 @@ class Variational(hk.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
common_layer_sizes: Iterable[int],
|
common_layer_sizes: Iterable[int],
|
||||||
activation: Callable[[jnp.DeviceArray],
|
activation: Callable[[jax.Array], jax.Array] = jnp.tanh,
|
||||||
jnp.DeviceArray] = jnp.tanh,
|
|
||||||
output_dim: int = 1,
|
output_dim: int = 1,
|
||||||
name: Optional[str] = None):
|
name: Optional[str] = None):
|
||||||
"""Initialises a `Variational` instance.
|
"""Initialises a `Variational` instance.
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing import Tuple
|
|||||||
|
|
||||||
import chex
|
import chex
|
||||||
import haiku as hk
|
import haiku as hk
|
||||||
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.ndimage import interpolation
|
from scipy.ndimage import interpolation
|
||||||
@@ -72,7 +73,7 @@ def load_deskewed_mnist(*a, **k):
|
|||||||
class MeanStdEstimator(hk.Module):
|
class MeanStdEstimator(hk.Module):
|
||||||
"""Online mean and standard deviation estimator using Welford's algorithm."""
|
"""Online mean and standard deviation estimator using Welford's algorithm."""
|
||||||
|
|
||||||
def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]:
|
def __call__(self, sample: jax.Array) -> Tuple[Array, Array]:
|
||||||
if len(sample.shape) > 1:
|
if len(sample.shape) > 1:
|
||||||
raise ValueError("sample must be a rank 0 or 1 DeviceArray.")
|
raise ValueError("sample must be a rank 0 or 1 DeviceArray.")
|
||||||
|
|
||||||
|
|||||||
+2
-2
@@ -17,11 +17,11 @@
|
|||||||
|
|
||||||
from typing import Callable, Tuple, Union
|
from typing import Callable, Tuple, Union
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import optax
|
import optax
|
||||||
|
|
||||||
TensorLike = Union[np.ndarray, jnp.DeviceArray]
|
TensorLike = Union[np.ndarray, jax.Array]
|
||||||
|
|
||||||
ActivationFn = Callable[[TensorLike], TensorLike]
|
ActivationFn = Callable[[TensorLike], TensorLike]
|
||||||
GatingFn = Callable[[TensorLike], TensorLike]
|
GatingFn = Callable[[TensorLike], TensorLike]
|
||||||
|
|||||||
@@ -46,18 +46,18 @@ class RegressionLossConfig:
|
|||||||
|
|
||||||
|
|
||||||
def _sigmoid_cross_entropy(
|
def _sigmoid_cross_entropy(
|
||||||
logits: jnp.DeviceArray,
|
logits: jax.Array,
|
||||||
labels: jnp.DeviceArray,
|
labels: jax.Array,
|
||||||
) -> jnp.DeviceArray:
|
) -> jax.Array:
|
||||||
log_p = jax.nn.log_sigmoid(logits)
|
log_p = jax.nn.log_sigmoid(logits)
|
||||||
log_not_p = jax.nn.log_sigmoid(-logits)
|
log_not_p = jax.nn.log_sigmoid(-logits)
|
||||||
return -labels * log_p - (1. - labels) * log_not_p
|
return -labels * log_p - (1. - labels) * log_not_p
|
||||||
|
|
||||||
|
|
||||||
def _softmax_cross_entropy(
|
def _softmax_cross_entropy(
|
||||||
logits: jnp.DeviceArray,
|
logits: jax.Array,
|
||||||
targets: jnp.DeviceArray,
|
targets: jax.Array,
|
||||||
) -> jnp.DeviceArray:
|
) -> jax.Array:
|
||||||
logits = jax.nn.log_softmax(logits)
|
logits = jax.nn.log_softmax(logits)
|
||||||
return -jnp.sum(targets * logits, axis=-1)
|
return -jnp.sum(targets * logits, axis=-1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user