mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-02 14:45:25 +08:00
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
PiperOrigin-RevId: 511294746
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
797ea3c71d
commit
c051e6a51d
@@ -180,11 +180,11 @@ class ByolExperiment:
|
||||
return outputs
|
||||
|
||||
if is_training:
|
||||
outputs_view1 = apply_once_fn(inputs['view1'], '_view1')
|
||||
outputs_view2 = apply_once_fn(inputs['view2'], '_view2')
|
||||
outputs_view1 = apply_once_fn(inputs['view1'], '_view1') # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
outputs_view2 = apply_once_fn(inputs['view2'], '_view2') # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
return {**outputs_view1, **outputs_view2}
|
||||
else:
|
||||
return apply_once_fn(inputs['images'], '')
|
||||
return apply_once_fn(inputs['images'], '') # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
|
||||
def _optimizer(self, learning_rate: float) -> optax.GradientTransformation:
|
||||
"""Build optimizer from config."""
|
||||
|
||||
@@ -432,8 +432,8 @@ class EvalExperiment:
|
||||
logits = self.forward_classif.apply(classif_params, embeddings)
|
||||
labels = hk.one_hot(inputs['labels'], self._num_classes)
|
||||
loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
|
||||
top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1)
|
||||
top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5)
|
||||
top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1) # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5) # pytype: disable=wrong-arg-types # jax-ndarray
|
||||
# NOTE: Returned values will be summed and finally divided by num_samples.
|
||||
return {
|
||||
'eval_loss': loss,
|
||||
|
||||
@@ -31,8 +31,8 @@ def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Filter to exclude biaises and normalizations weights."""
|
||||
del val
|
||||
if path[-1] == "b" or "norm" in path[-2]:
|
||||
return False
|
||||
return True
|
||||
return False # pytype: disable=bad-return-type # jax-ndarray
|
||||
return True # pytype: disable=bad-return-type # jax-ndarray
|
||||
|
||||
|
||||
def _partial_update(updates: optax.Updates,
|
||||
|
||||
Reference in New Issue
Block a user