Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.

PiperOrigin-RevId: 511294746
This commit is contained in:
Peter Hawkins
2023-02-21 21:45:32 +00:00
committed by Saran Tunyasuvunakool
parent 797ea3c71d
commit c051e6a51d
7 changed files with 12 additions and 12 deletions
+3 -3
View File
@@ -180,11 +180,11 @@ class ByolExperiment:
return outputs
if is_training:
outputs_view1 = apply_once_fn(inputs['view1'], '_view1')
outputs_view2 = apply_once_fn(inputs['view2'], '_view2')
outputs_view1 = apply_once_fn(inputs['view1'], '_view1') # pytype: disable=wrong-arg-types # jax-ndarray
outputs_view2 = apply_once_fn(inputs['view2'], '_view2') # pytype: disable=wrong-arg-types # jax-ndarray
return {**outputs_view1, **outputs_view2}
else:
return apply_once_fn(inputs['images'], '')
return apply_once_fn(inputs['images'], '') # pytype: disable=wrong-arg-types # jax-ndarray
def _optimizer(self, learning_rate: float) -> optax.GradientTransformation:
"""Build optimizer from config."""
+2 -2
View File
@@ -432,8 +432,8 @@ class EvalExperiment:
logits = self.forward_classif.apply(classif_params, embeddings)
labels = hk.one_hot(inputs['labels'], self._num_classes)
loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1)
top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5)
top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1) # pytype: disable=wrong-arg-types # jax-ndarray
top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5) # pytype: disable=wrong-arg-types # jax-ndarray
# NOTE: Returned values will be summed and finally divided by num_samples.
return {
'eval_loss': loss,
+2 -2
View File
@@ -31,8 +31,8 @@ def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray:
"""Filter to exclude biaises and normalizations weights."""
del val
if path[-1] == "b" or "norm" in path[-2]:
return False
return True
return False # pytype: disable=bad-return-type # jax-ndarray
return True # pytype: disable=bad-return-type # jax-ndarray
def _partial_update(updates: optax.Updates,
+2 -2
View File
@@ -481,10 +481,10 @@ class Optimizer(utils.Stateful):
self.step_counter = self.step_counter + 1
if self.value_func_has_state:
return params, self.pop_state(), new_func_state, stats
return params, self.pop_state(), new_func_state, stats # pytype: disable=bad-return-type # jax-ndarray
else:
assert new_func_state is None
return params, self.pop_state(), stats
return params, self.pop_state(), stats # pytype: disable=bad-return-type # jax-ndarray
def init(
self,
+1 -1
View File
@@ -728,7 +728,7 @@ class MultimodalPreprocessor(hk.Module):
# Apply a predictable ordering to the modalities
padded_ls = [padded[k] for k in sorted(padded.keys())]
return (jnp.concatenate(padded_ls, axis=1),
return (jnp.concatenate(padded_ls, axis=1), # pytype: disable=bad-return-type # jax-ndarray
modality_sizes,
inputs_without_pos)
+1 -1
View File
@@ -291,7 +291,7 @@ class Experiment(experiment.AbstractExperiment):
# \__|_| \__,_|_|_| |_|
#
def step(self, global_step: int, rng: jnp.ndarray,
def step(self, global_step: int, rng: jnp.ndarray, # pytype: disable=signature-mismatch # jax-ndarray
*unused_args, **unused_kwargs):
"""See base class."""
+1 -1
View File
@@ -677,7 +677,7 @@ def solve_hamiltonian_ivp_t_eval(
if method == "adaptive":
dy_dt = phase_space.transform_symplectic_tangent_function_using_array(dy_dt)
return solve_ivp_t_eval(
return solve_ivp_t_eval( # pytype: disable=bad-return-type # jax-ndarray
fun=dy_dt,
t_span=t_span,
y0=y0,