mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-22 15:21:27 +08:00
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
PiperOrigin-RevId: 511294746
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
797ea3c71d
commit
c051e6a51d
@@ -180,11 +180,11 @@ class ByolExperiment:
|
||||
return outputs
|
||||
|
||||
if is_training:
|
||||
outputs_view1 = apply_once_fn(inputs['view1'], '_view1')
|
||||
outputs_view2 = apply_once_fn(inputs['view2'], '_view2')
|
||||
outputs_view1 = apply_once_fn(inputs['view1'], '_view1') # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
outputs_view2 = apply_once_fn(inputs['view2'], '_view2') # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
return {**outputs_view1, **outputs_view2}
|
||||
else:
|
||||
return apply_once_fn(inputs['images'], '')
|
||||
return apply_once_fn(inputs['images'], '') # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
|
||||
def _optimizer(self, learning_rate: float) -> optax.GradientTransformation:
|
||||
"""Build optimizer from config."""
|
||||
|
||||
@@ -432,8 +432,8 @@ class EvalExperiment:
|
||||
logits = self.forward_classif.apply(classif_params, embeddings)
|
||||
labels = hk.one_hot(inputs['labels'], self._num_classes)
|
||||
loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
|
||||
top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1)
|
||||
top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5)
|
||||
top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1) # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5) # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
# NOTE: Returned values will be summed and finally divided by num_samples.
|
||||
return {
|
||||
'eval_loss': loss,
|
||||
|
||||
@@ -31,8 +31,8 @@ def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Filter to exclude biaises and normalizations weights."""
|
||||
del val
|
||||
if path[-1] == "b" or "norm" in path[-2]:
|
||||
return False
|
||||
return True
|
||||
return False # pytype: disable=bad-return-type # jax-ndarray
|
||||
return True # pytype: disable=bad-return-type # jax-ndarray
|
||||
|
||||
|
||||
def _partial_update(updates: optax.Updates,
|
||||
|
||||
@@ -481,10 +481,10 @@ class Optimizer(utils.Stateful):
|
||||
self.step_counter = self.step_counter + 1
|
||||
|
||||
if self.value_func_has_state:
|
||||
return params, self.pop_state(), new_func_state, stats
|
||||
return params, self.pop_state(), new_func_state, stats # pytype: disable=bad-return-type # jax-ndarray
|
||||
else:
|
||||
assert new_func_state is None
|
||||
return params, self.pop_state(), stats
|
||||
return params, self.pop_state(), stats # pytype: disable=bad-return-type # jax-ndarray
|
||||
|
||||
def init(
|
||||
self,
|
||||
|
||||
@@ -728,7 +728,7 @@ class MultimodalPreprocessor(hk.Module):
|
||||
|
||||
# Apply a predictable ordering to the modalities
|
||||
padded_ls = [padded[k] for k in sorted(padded.keys())]
|
||||
return (jnp.concatenate(padded_ls, axis=1),
|
||||
return (jnp.concatenate(padded_ls, axis=1), # pytype: disable=bad-return-type # jax-ndarray
|
||||
modality_sizes,
|
||||
inputs_without_pos)
|
||||
|
||||
|
||||
@@ -291,7 +291,7 @@ class Experiment(experiment.AbstractExperiment):
|
||||
# \__|_| \__,_|_|_| |_|
|
||||
#
|
||||
|
||||
def step(self, global_step: int, rng: jnp.ndarray,
|
||||
def step(self, global_step: int, rng: jnp.ndarray, # pytype: disable=signature-mismatch # jax-ndarray
|
||||
*unused_args, **unused_kwargs):
|
||||
"""See base class."""
|
||||
|
||||
|
||||
@@ -677,7 +677,7 @@ def solve_hamiltonian_ivp_t_eval(
|
||||
if method == "adaptive":
|
||||
dy_dt = phase_space.transform_symplectic_tangent_function_using_array(dy_dt)
|
||||
|
||||
return solve_ivp_t_eval(
|
||||
return solve_ivp_t_eval( # pytype: disable=bad-return-type # jax-ndarray
|
||||
fun=dy_dt,
|
||||
t_span=t_span,
|
||||
y0=y0,
|
||||
|
||||
Reference in New Issue
Block a user