Fix pytype failures related to teaching pytype about NumPy scalar types.

PiperOrigin-RevId: 519205179
This commit is contained in:
Peter Hawkins
2023-03-24 19:06:40 +00:00
committed by Saran Tunyasuvunakool
parent e2d21540e8
commit 82a347438f
5 changed files with 8 additions and 8 deletions
+1 -1
View File
@@ -121,7 +121,7 @@ def training_statistics(
# Note that "_over_time" stats are getting normalised by time here
stats = jax.tree_map(lambda x: x / scale, stats)
if p_x_learned_sigma:
stats["p_x_sigma"] = p_x.variance().reshape([-1])[0]
stats["p_x_sigma"] = p_x.variance().reshape([-1])[0] # pytype: disable=attribute-error # numpy-scalars
if q_z is not None:
stats["small_latents"] = calculate_small_latents(q_z)
return stats