Fix pytype failures related to teaching pytype about NumPy scalar types.

PiperOrigin-RevId: 519205179
This commit is contained in:
Peter Hawkins
2023-03-24 19:06:40 +00:00
committed by Saran Tunyasuvunakool
parent e2d21540e8
commit 82a347438f
5 changed files with 8 additions and 8 deletions
+1 -1
View File
@@ -135,7 +135,7 @@ def add_weight_decay(
new_updates = _partial_update(updates, new_updates, params, filter_fn)
return new_updates, state
return optax.GradientTransformation(init_fn, update_fn)
return optax.GradientTransformation(init_fn, update_fn) # pytype: disable=wrong-arg-types # numpy-scalars
LarsState = List # Type for the lars optimizer