Fix pytype failures related to teaching pytype about NumPy scalar types.

PiperOrigin-RevId: 519205179
This commit is contained in:
Peter Hawkins
2023-03-24 19:06:40 +00:00
committed by Saran Tunyasuvunakool
parent e2d21540e8
commit 82a347438f
5 changed files with 8 additions and 8 deletions
+2 -2
View File
@@ -444,9 +444,9 @@ class Experiment(experiment.AbstractExperiment):
params = optax.apply_updates(params, updates)
n_params = 0
for k in params.keys():
for k in params.keys(): # pytype: disable=attribute-error # numpy-scalars
for l in params[k]:
n_params = n_params + np.prod(params[k][l].shape)
n_params = n_params + np.prod(params[k][l].shape) # pytype: disable=attribute-error # numpy-scalars
# Scalars to log (note: we log the mean across all hosts/devices).
scalars = {'learning_rate': learning_rate,