Fix pytype failures related to teaching pytype about NumPy scalar types.

PiperOrigin-RevId: 519205179
This commit is contained in:
Peter Hawkins
2023-03-24 19:06:40 +00:00
committed by Saran Tunyasuvunakool
parent e2d21540e8
commit 82a347438f
5 changed files with 8 additions and 8 deletions
+2 -2
View File
@@ -103,7 +103,7 @@ def weight_decay(params: hk.Params,
any(re.match(regex, name) for regex in regex_ignore)): any(re.match(regex, name) for regex in regex_ignore)):
continue continue
l2_norm += jnp.sum(jnp.square(param)) l2_norm += jnp.sum(jnp.square(param))
return .5 * l2_norm return .5 * l2_norm # pytype: disable=bad-return-type # numpy-scalars
def ema_update(step: chex.Array, def ema_update(step: chex.Array,
@@ -113,7 +113,7 @@ def ema_update(step: chex.Array,
warmup_steps: int = 0, warmup_steps: int = 0,
dynamic_decay: bool = True) -> chex.ArrayTree: dynamic_decay: bool = True) -> chex.ArrayTree:
"""Applies an exponential moving average.""" """Applies an exponential moving average."""
factor = (step >= warmup_steps).astype(jnp.float32) factor = (step >= warmup_steps).astype(jnp.float32) # pytype: disable=attribute-error # numpy-scalars
if dynamic_decay: if dynamic_decay:
# Uses TF-style EMA. # Uses TF-style EMA.
delta = step - warmup_steps delta = step - warmup_steps
+2 -2
View File
@@ -271,7 +271,7 @@ class EvalExperiment:
classif_params = self.forward_classif.init(rng, embeddings) classif_params = self.forward_classif.init(rng, embeddings)
classif_opt_state = self._optimizer(0.).init(classif_params) classif_opt_state = self._optimizer(0.).init(classif_params)
return _EvalExperimentState( return _EvalExperimentState( # pytype: disable=wrong-arg-types # numpy-scalars
backbone_params=backbone_params, backbone_params=backbone_params,
classif_params=classif_params, classif_params=classif_params,
backbone_state=backbone_state, backbone_state=backbone_state,
@@ -389,7 +389,7 @@ class EvalExperiment:
new_backbone_params = optax.apply_updates( new_backbone_params = optax.apply_updates(
experiment_state.backbone_params, backbone_updates) experiment_state.backbone_params, backbone_updates)
experiment_state = _EvalExperimentState( experiment_state = _EvalExperimentState( # pytype: disable=wrong-arg-types # numpy-scalars
new_backbone_params, new_backbone_params,
new_classif_params, new_classif_params,
new_backbone_state, new_backbone_state,
+1 -1
View File
@@ -135,7 +135,7 @@ def add_weight_decay(
new_updates = _partial_update(updates, new_updates, params, filter_fn) new_updates = _partial_update(updates, new_updates, params, filter_fn)
return new_updates, state return new_updates, state
return optax.GradientTransformation(init_fn, update_fn) return optax.GradientTransformation(init_fn, update_fn) # pytype: disable=wrong-arg-types # numpy-scalars
LarsState = List # Type for the lars optimizer LarsState = List # Type for the lars optimizer
+2 -2
View File
@@ -444,9 +444,9 @@ class Experiment(experiment.AbstractExperiment):
params = optax.apply_updates(params, updates) params = optax.apply_updates(params, updates)
n_params = 0 n_params = 0
for k in params.keys(): for k in params.keys(): # pytype: disable=attribute-error # numpy-scalars
for l in params[k]: for l in params[k]:
n_params = n_params + np.prod(params[k][l].shape) n_params = n_params + np.prod(params[k][l].shape) # pytype: disable=attribute-error # numpy-scalars
# Scalars to log (note: we log the mean across all hosts/devices). # Scalars to log (note: we log the mean across all hosts/devices).
scalars = {'learning_rate': learning_rate, scalars = {'learning_rate': learning_rate,
+1 -1
View File
@@ -121,7 +121,7 @@ def training_statistics(
# Note that "_over_time" stats are getting normalised by time here # Note that "_over_time" stats are getting normalised by time here
stats = jax.tree_map(lambda x: x / scale, stats) stats = jax.tree_map(lambda x: x / scale, stats)
if p_x_learned_sigma: if p_x_learned_sigma:
stats["p_x_sigma"] = p_x.variance().reshape([-1])[0] stats["p_x_sigma"] = p_x.variance().reshape([-1])[0] # pytype: disable=attribute-error # numpy-scalars
if q_z is not None: if q_z is not None:
stats["small_latents"] = calculate_small_latents(q_z) stats["small_latents"] = calculate_small_latents(q_z)
return stats return stats