Mirror of https://github.com/google-deepmind/deepmind-research.git (last synced 2026-06-01 13:44:07 +08:00)
Fix pytype failures related to teaching pytype about NumPy scalar types.
PiperOrigin-RevId: 519205179
This commit is contained in:
Committed by: Saran Tunyasuvunakool
Parent: e2d21540e8
Commit: 82a347438f
@@ -103,7 +103,7 @@ def weight_decay(params: hk.Params,
|
|||||||
any(re.match(regex, name) for regex in regex_ignore)):
|
any(re.match(regex, name) for regex in regex_ignore)):
|
||||||
continue
|
continue
|
||||||
l2_norm += jnp.sum(jnp.square(param))
|
l2_norm += jnp.sum(jnp.square(param))
|
||||||
return .5 * l2_norm
|
return .5 * l2_norm # pytype: disable=bad-return-type # numpy-scalars
|
||||||
|
|
||||||
|
|
||||||
def ema_update(step: chex.Array,
|
def ema_update(step: chex.Array,
|
||||||
@@ -113,7 +113,7 @@ def ema_update(step: chex.Array,
|
|||||||
warmup_steps: int = 0,
|
warmup_steps: int = 0,
|
||||||
dynamic_decay: bool = True) -> chex.ArrayTree:
|
dynamic_decay: bool = True) -> chex.ArrayTree:
|
||||||
"""Applies an exponential moving average."""
|
"""Applies an exponential moving average."""
|
||||||
factor = (step >= warmup_steps).astype(jnp.float32)
|
factor = (step >= warmup_steps).astype(jnp.float32) # pytype: disable=attribute-error # numpy-scalars
|
||||||
if dynamic_decay:
|
if dynamic_decay:
|
||||||
# Uses TF-style EMA.
|
# Uses TF-style EMA.
|
||||||
delta = step - warmup_steps
|
delta = step - warmup_steps
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ class EvalExperiment:
|
|||||||
classif_params = self.forward_classif.init(rng, embeddings)
|
classif_params = self.forward_classif.init(rng, embeddings)
|
||||||
classif_opt_state = self._optimizer(0.).init(classif_params)
|
classif_opt_state = self._optimizer(0.).init(classif_params)
|
||||||
|
|
||||||
return _EvalExperimentState(
|
return _EvalExperimentState( # pytype: disable=wrong-arg-types # numpy-scalars
|
||||||
backbone_params=backbone_params,
|
backbone_params=backbone_params,
|
||||||
classif_params=classif_params,
|
classif_params=classif_params,
|
||||||
backbone_state=backbone_state,
|
backbone_state=backbone_state,
|
||||||
@@ -389,7 +389,7 @@ class EvalExperiment:
|
|||||||
new_backbone_params = optax.apply_updates(
|
new_backbone_params = optax.apply_updates(
|
||||||
experiment_state.backbone_params, backbone_updates)
|
experiment_state.backbone_params, backbone_updates)
|
||||||
|
|
||||||
experiment_state = _EvalExperimentState(
|
experiment_state = _EvalExperimentState( # pytype: disable=wrong-arg-types # numpy-scalars
|
||||||
new_backbone_params,
|
new_backbone_params,
|
||||||
new_classif_params,
|
new_classif_params,
|
||||||
new_backbone_state,
|
new_backbone_state,
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ def add_weight_decay(
|
|||||||
new_updates = _partial_update(updates, new_updates, params, filter_fn)
|
new_updates = _partial_update(updates, new_updates, params, filter_fn)
|
||||||
return new_updates, state
|
return new_updates, state
|
||||||
|
|
||||||
return optax.GradientTransformation(init_fn, update_fn)
|
return optax.GradientTransformation(init_fn, update_fn) # pytype: disable=wrong-arg-types # numpy-scalars
|
||||||
|
|
||||||
|
|
||||||
LarsState = List # Type for the lars optimizer
|
LarsState = List # Type for the lars optimizer
|
||||||
|
|||||||
@@ -444,9 +444,9 @@ class Experiment(experiment.AbstractExperiment):
|
|||||||
params = optax.apply_updates(params, updates)
|
params = optax.apply_updates(params, updates)
|
||||||
|
|
||||||
n_params = 0
|
n_params = 0
|
||||||
for k in params.keys():
|
for k in params.keys(): # pytype: disable=attribute-error # numpy-scalars
|
||||||
for l in params[k]:
|
for l in params[k]:
|
||||||
n_params = n_params + np.prod(params[k][l].shape)
|
n_params = n_params + np.prod(params[k][l].shape) # pytype: disable=attribute-error # numpy-scalars
|
||||||
|
|
||||||
# Scalars to log (note: we log the mean across all hosts/devices).
|
# Scalars to log (note: we log the mean across all hosts/devices).
|
||||||
scalars = {'learning_rate': learning_rate,
|
scalars = {'learning_rate': learning_rate,
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ def training_statistics(
|
|||||||
# Note that "_over_time" stats are getting normalised by time here
|
# Note that "_over_time" stats are getting normalised by time here
|
||||||
stats = jax.tree_map(lambda x: x / scale, stats)
|
stats = jax.tree_map(lambda x: x / scale, stats)
|
||||||
if p_x_learned_sigma:
|
if p_x_learned_sigma:
|
||||||
stats["p_x_sigma"] = p_x.variance().reshape([-1])[0]
|
stats["p_x_sigma"] = p_x.variance().reshape([-1])[0] # pytype: disable=attribute-error # numpy-scalars
|
||||||
if q_z is not None:
|
if q_z is not None:
|
||||||
stats["small_latents"] = calculate_small_latents(q_z)
|
stats["small_latents"] = calculate_small_latents(q_z)
|
||||||
return stats
|
return stats
|
||||||
|
|||||||
Reference in New Issue
Block a user