Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.

PiperOrigin-RevId: 511294746
This commit is contained in:
Peter Hawkins
2023-02-21 21:45:32 +00:00
committed by Saran Tunyasuvunakool
parent 797ea3c71d
commit c051e6a51d
7 changed files with 12 additions and 12 deletions
+2 -2
View File
@@ -31,8 +31,8 @@ def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray:
"""Filter to exclude biaises and normalizations weights."""
del val
if path[-1] == "b" or "norm" in path[-2]:
return False
return True
return False # pytype: disable=bad-return-type # jax-ndarray
return True # pytype: disable=bad-return-type # jax-ndarray
def _partial_update(updates: optax.Updates,