mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 21:15:21 +08:00
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
PiperOrigin-RevId: 511294746
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
797ea3c71d
commit
c051e6a51d
@@ -31,8 +31,8 @@ def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Filter to exclude biaises and normalizations weights."""
|
||||
del val
|
||||
if path[-1] == "b" or "norm" in path[-2]:
|
||||
return False
|
||||
return True
|
||||
return False # pytype: disable=bad-return-type # jax-ndarray
|
||||
return True # pytype: disable=bad-return-type # jax-ndarray
|
||||
|
||||
|
||||
def _partial_update(updates: optax.Updates,
|
||||
|
||||
Reference in New Issue
Block a user