mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-22 07:11:25 +08:00
Suppress some pytype errors related to jnp.DeviceArray == jax.Array.
PiperOrigin-RevId: 514986169
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
6f0ddef7da
commit
784f67565e
@@ -62,7 +62,7 @@ def multinomial_mode(
|
||||
joint-highest.
|
||||
"""
|
||||
if isinstance(distribution_or_probs, tfd.Distribution):
|
||||
probs = distribution_or_probs.probs_parameter()
|
||||
probs = distribution_or_probs.probs_parameter() # pytype: disable=attribute-error # jax-devicearray
|
||||
else:
|
||||
probs = distribution_or_probs
|
||||
max_prob = jnp.max(probs, axis=1, keepdims=True)
|
||||
@@ -86,7 +86,7 @@ def multinomial_class(
|
||||
probability.
|
||||
"""
|
||||
if isinstance(distribution_or_probs, tfd.Distribution):
|
||||
return jnp.argmax(distribution_or_probs.logits_parameter(), axis=1)
|
||||
return jnp.argmax(distribution_or_probs.logits_parameter(), axis=1) # pytype: disable=attribute-error # jax-devicearray
|
||||
return jnp.argmax(distribution_or_probs, axis=1)
|
||||
|
||||
|
||||
|
||||
+1
-1
@@ -161,7 +161,7 @@ class SUnit3D(hk.Module):
|
||||
x = self._activation_fn(x)
|
||||
if self._self_gating_fn:
|
||||
x = self._self_gating_fn(x)
|
||||
return x
|
||||
return x # pytype: disable=bad-return-type # jax-devicearray
|
||||
|
||||
|
||||
class InceptionBlockV13D(hk.Module):
|
||||
|
||||
@@ -211,7 +211,7 @@ class TSMResNetUnit(hk.Module):
|
||||
num_frames=self._num_frames,
|
||||
name=f'block_{idx_block}')(
|
||||
net, is_training=is_training)
|
||||
return net
|
||||
return net # pytype: disable=bad-return-type # jax-devicearray
|
||||
|
||||
|
||||
class TSMResNetV2(hk.Module):
|
||||
|
||||
Reference in New Issue
Block a user