mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 13:05:40 +08:00
Replaces references to jax.numpy.DeviceArray with jax.Array.
PiperOrigin-RevId: 515678285
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
0824c28deb
commit
d988ff1bf2
@@ -46,18 +46,18 @@ class RegressionLossConfig:
|
||||
|
||||
|
||||
def _sigmoid_cross_entropy(
|
||||
logits: jnp.DeviceArray,
|
||||
labels: jnp.DeviceArray,
|
||||
) -> jnp.DeviceArray:
|
||||
logits: jax.Array,
|
||||
labels: jax.Array,
|
||||
) -> jax.Array:
|
||||
log_p = jax.nn.log_sigmoid(logits)
|
||||
log_not_p = jax.nn.log_sigmoid(-logits)
|
||||
return -labels * log_p - (1. - labels) * log_not_p
|
||||
|
||||
|
||||
def _softmax_cross_entropy(
|
||||
logits: jnp.DeviceArray,
|
||||
targets: jnp.DeviceArray,
|
||||
) -> jnp.DeviceArray:
|
||||
logits: jax.Array,
|
||||
targets: jax.Array,
|
||||
) -> jax.Array:
|
||||
logits = jax.nn.log_softmax(logits)
|
||||
return -jnp.sum(targets * logits, axis=-1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user