mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-01 21:56:38 +08:00
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
PiperOrigin-RevId: 512349622
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
c051e6a51d
commit
6f0ddef7da
@@ -220,7 +220,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel):
|
||||
is_training=is_training)
|
||||
return p_x, z0, decoder_z
|
||||
|
||||
def training_objectives(
|
||||
def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray
|
||||
self,
|
||||
params: hk.Params,
|
||||
state: hk.State,
|
||||
@@ -300,7 +300,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel):
|
||||
include_z0=False,
|
||||
)[0]
|
||||
|
||||
def gt_state_and_latents(
|
||||
def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray
|
||||
self,
|
||||
params: hk.Params,
|
||||
rng: jnp.ndarray,
|
||||
@@ -336,7 +336,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel):
|
||||
) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
|
||||
return dict(), dict()
|
||||
|
||||
def _init_latent_system(
|
||||
def _init_latent_system( # pytype: disable=signature-mismatch # jax-ndarray
|
||||
self,
|
||||
rng: jnp.ndarray,
|
||||
z: jnp.ndarray,
|
||||
|
||||
@@ -341,7 +341,7 @@ class SequenceModel(abc.ABC, Generic[T]):
|
||||
params = hk.data_structures.to_immutable_dict(params)
|
||||
state = hk.data_structures.to_immutable_dict(state)
|
||||
|
||||
return params, state
|
||||
return params, state # pytype: disable=bad-return-type # jax-ndarray
|
||||
|
||||
def init(
|
||||
self,
|
||||
|
||||
@@ -155,8 +155,8 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
|
||||
|
||||
def process_latents_for_decoder(self, z: _ArrayOrPhase) -> jnp.ndarray:
|
||||
if self.latent_dynamics_type == "Physics":
|
||||
return z.q if self.render_from_q_only else z.single_state
|
||||
return z
|
||||
return z.q if self.render_from_q_only else z.single_state # pytype: disable=attribute-error # jax-ndarray
|
||||
return z # pytype: disable=bad-return-type # jax-ndarray
|
||||
|
||||
@property
|
||||
def inferred_index(self) -> int:
|
||||
@@ -327,7 +327,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
|
||||
if num_steps_backward > 0 and not self.can_run_backwards:
|
||||
raise ValueError("This model can not be unrolled backward in time.")
|
||||
|
||||
def unroll_latent_dynamics(
|
||||
def unroll_latent_dynamics( # pytype: disable=signature-mismatch # jax-ndarray
|
||||
self,
|
||||
z: phase_space.PhaseSpace,
|
||||
params: hk.Params,
|
||||
@@ -393,7 +393,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
|
||||
z = z.single_state if isinstance(z, phase_space.PhaseSpace) else z
|
||||
return p_x, q_z, self.prior(), z0, z, dyn_stats
|
||||
|
||||
def training_objectives(
|
||||
def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray
|
||||
self,
|
||||
params: utils.Params,
|
||||
state: hk.State,
|
||||
@@ -532,7 +532,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
|
||||
include_z0=True,
|
||||
)[0]
|
||||
|
||||
def gt_state_and_latents(
|
||||
def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray
|
||||
self,
|
||||
params: hk.Params,
|
||||
rng: jnp.ndarray,
|
||||
|
||||
@@ -610,12 +610,12 @@ class PhysicsSimulationNetwork(hk.Module):
|
||||
y.q, y.p, **nets_kwargs)
|
||||
# Special Haiku magic to avoid tracer issues
|
||||
if hk.running_init():
|
||||
return self.lagrangian(y0, **nets_kwargs)
|
||||
return self.lagrangian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray
|
||||
else:
|
||||
hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs)
|
||||
dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
|
||||
if hk.running_init():
|
||||
return self.hamiltonian(y0, **nets_kwargs)
|
||||
return self.hamiltonian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray
|
||||
|
||||
# Optionally switch coordinate frame
|
||||
if self.input_space == "velocity" and self.simulation_space == "momentum":
|
||||
|
||||
Reference in New Issue
Block a user