Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.

PiperOrigin-RevId: 511294746
This commit is contained in:
Peter Hawkins
2023-02-21 21:45:32 +00:00
committed by Saran Tunyasuvunakool
parent 797ea3c71d
commit c051e6a51d
7 changed files with 12 additions and 12 deletions
+1 -1
View File
@@ -291,7 +291,7 @@ class Experiment(experiment.AbstractExperiment):
# \__|_| \__,_|_|_| |_|
#
def step(self, global_step: int, rng: jnp.ndarray,
def step(self, global_step: int, rng: jnp.ndarray, # pytype: disable=signature-mismatch # jax-ndarray
*unused_args, **unused_kwargs):
"""See base class."""