Replaces references to jax.numpy.DeviceArray with jax.Array.

PiperOrigin-RevId: 515678285
This commit is contained in:
Peter Hawkins
2023-03-10 18:42:12 +00:00
committed by Saran Tunyasuvunakool
parent 0824c28deb
commit d988ff1bf2
5 changed files with 21 additions and 19 deletions
+6 -6
View File
@@ -46,18 +46,18 @@ class RegressionLossConfig:
def _sigmoid_cross_entropy(
logits: jnp.DeviceArray,
labels: jnp.DeviceArray,
) -> jnp.DeviceArray:
logits: jax.Array,
labels: jax.Array,
) -> jax.Array:
log_p = jax.nn.log_sigmoid(logits)
log_not_p = jax.nn.log_sigmoid(-logits)
return -labels * log_p - (1. - labels) * log_not_p
def _softmax_cross_entropy(
logits: jnp.DeviceArray,
targets: jnp.DeviceArray,
) -> jnp.DeviceArray:
logits: jax.Array,
targets: jax.Array,
) -> jax.Array:
logits = jax.nn.log_softmax(logits)
return -jnp.sum(targets * logits, axis=-1)