Replace jnp.DeviceArray with the new public type jax.Array.

PiperOrigin-RevId: 477558995
This commit is contained in:
Yash Katariya
2022-09-28 21:53:07 +00:00
committed by Diego de las Casas
parent 58fb45db7e
commit da0f2de14d
+1 -2
View File
@@ -338,8 +338,7 @@ class TandemDqn(parts.Agent):
def statistics(self) -> Mapping[Text, float]:
"""Returns current agent statistics as a dictionary."""
# Check for DeviceArrays in values as this can be very slow.
assert all(
not isinstance(x, jnp.DeviceArray) for x in self._statistics.values())
assert all(not isinstance(x, jax.Array) for x in self._statistics.values())
return self._statistics
@property