Suppress some pytype errors related to jnp.DeviceArray == jax.Array.

PiperOrigin-RevId: 514986169
This commit is contained in:
Peter Hawkins
2023-03-08 10:57:57 +00:00
committed by Saran Tunyasuvunakool
parent 6f0ddef7da
commit 784f67565e
3 changed files with 4 additions and 4 deletions
+2 -2
View File
@@ -62,7 +62,7 @@ def multinomial_mode(
joint-highest.
"""
if isinstance(distribution_or_probs, tfd.Distribution):
probs = distribution_or_probs.probs_parameter()
probs = distribution_or_probs.probs_parameter() # pytype: disable=attribute-error # jax-devicearray
else:
probs = distribution_or_probs
max_prob = jnp.max(probs, axis=1, keepdims=True)
@@ -86,7 +86,7 @@ def multinomial_class(
probability.
"""
if isinstance(distribution_or_probs, tfd.Distribution):
return jnp.argmax(distribution_or_probs.logits_parameter(), axis=1)
return jnp.argmax(distribution_or_probs.logits_parameter(), axis=1) # pytype: disable=attribute-error # jax-devicearray
return jnp.argmax(distribution_or_probs, axis=1)
+1 -1
View File
@@ -161,7 +161,7 @@ class SUnit3D(hk.Module):
x = self._activation_fn(x)
if self._self_gating_fn:
x = self._self_gating_fn(x)
return x
return x # pytype: disable=bad-return-type # jax-devicearray
class InceptionBlockV13D(hk.Module):
+1 -1
View File
@@ -211,7 +211,7 @@ class TSMResNetUnit(hk.Module):
num_frames=self._num_frames,
name=f'block_{idx_block}')(
net, is_training=is_training)
return net
return net # pytype: disable=bad-return-type # jax-devicearray
class TSMResNetV2(hk.Module):