diff --git a/tandem_dqn/agent.py b/tandem_dqn/agent.py index 3018c33..1cfb6cc 100644 --- a/tandem_dqn/agent.py +++ b/tandem_dqn/agent.py @@ -338,8 +338,7 @@ class TandemDqn(parts.Agent): def statistics(self) -> Mapping[Text, float]: """Returns current agent statistics as a dictionary.""" # Check for DeviceArrays in values as this can be very slow. - assert all( - not isinstance(x, jnp.DeviceArray) for x in self._statistics.values()) + assert all(not isinstance(x, jax.Array) for x in self._statistics.values()) return self._statistics @property