mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-01 21:56:38 +08:00
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
PiperOrigin-RevId: 512349622
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
c051e6a51d
commit
6f0ddef7da
@@ -220,7 +220,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel):
|
|||||||
is_training=is_training)
|
is_training=is_training)
|
||||||
return p_x, z0, decoder_z
|
return p_x, z0, decoder_z
|
||||||
|
|
||||||
def training_objectives(
|
def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray
|
||||||
self,
|
self,
|
||||||
params: hk.Params,
|
params: hk.Params,
|
||||||
state: hk.State,
|
state: hk.State,
|
||||||
@@ -300,7 +300,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel):
|
|||||||
include_z0=False,
|
include_z0=False,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
def gt_state_and_latents(
|
def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray
|
||||||
self,
|
self,
|
||||||
params: hk.Params,
|
params: hk.Params,
|
||||||
rng: jnp.ndarray,
|
rng: jnp.ndarray,
|
||||||
@@ -336,7 +336,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel):
|
|||||||
) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
|
) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
|
||||||
return dict(), dict()
|
return dict(), dict()
|
||||||
|
|
||||||
def _init_latent_system(
|
def _init_latent_system( # pytype: disable=signature-mismatch # jax-ndarray
|
||||||
self,
|
self,
|
||||||
rng: jnp.ndarray,
|
rng: jnp.ndarray,
|
||||||
z: jnp.ndarray,
|
z: jnp.ndarray,
|
||||||
|
|||||||
@@ -341,7 +341,7 @@ class SequenceModel(abc.ABC, Generic[T]):
|
|||||||
params = hk.data_structures.to_immutable_dict(params)
|
params = hk.data_structures.to_immutable_dict(params)
|
||||||
state = hk.data_structures.to_immutable_dict(state)
|
state = hk.data_structures.to_immutable_dict(state)
|
||||||
|
|
||||||
return params, state
|
return params, state # pytype: disable=bad-return-type # jax-ndarray
|
||||||
|
|
||||||
def init(
|
def init(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -155,8 +155,8 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
|
|||||||
|
|
||||||
def process_latents_for_decoder(self, z: _ArrayOrPhase) -> jnp.ndarray:
|
def process_latents_for_decoder(self, z: _ArrayOrPhase) -> jnp.ndarray:
|
||||||
if self.latent_dynamics_type == "Physics":
|
if self.latent_dynamics_type == "Physics":
|
||||||
return z.q if self.render_from_q_only else z.single_state
|
return z.q if self.render_from_q_only else z.single_state # pytype: disable=attribute-error # jax-ndarray
|
||||||
return z
|
return z # pytype: disable=bad-return-type # jax-ndarray
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inferred_index(self) -> int:
|
def inferred_index(self) -> int:
|
||||||
@@ -327,7 +327,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
|
|||||||
if num_steps_backward > 0 and not self.can_run_backwards:
|
if num_steps_backward > 0 and not self.can_run_backwards:
|
||||||
raise ValueError("This model can not be unrolled backward in time.")
|
raise ValueError("This model can not be unrolled backward in time.")
|
||||||
|
|
||||||
def unroll_latent_dynamics(
|
def unroll_latent_dynamics( # pytype: disable=signature-mismatch # jax-ndarray
|
||||||
self,
|
self,
|
||||||
z: phase_space.PhaseSpace,
|
z: phase_space.PhaseSpace,
|
||||||
params: hk.Params,
|
params: hk.Params,
|
||||||
@@ -393,7 +393,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
|
|||||||
z = z.single_state if isinstance(z, phase_space.PhaseSpace) else z
|
z = z.single_state if isinstance(z, phase_space.PhaseSpace) else z
|
||||||
return p_x, q_z, self.prior(), z0, z, dyn_stats
|
return p_x, q_z, self.prior(), z0, z, dyn_stats
|
||||||
|
|
||||||
def training_objectives(
|
def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray
|
||||||
self,
|
self,
|
||||||
params: utils.Params,
|
params: utils.Params,
|
||||||
state: hk.State,
|
state: hk.State,
|
||||||
@@ -532,7 +532,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
|
|||||||
include_z0=True,
|
include_z0=True,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
def gt_state_and_latents(
|
def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray
|
||||||
self,
|
self,
|
||||||
params: hk.Params,
|
params: hk.Params,
|
||||||
rng: jnp.ndarray,
|
rng: jnp.ndarray,
|
||||||
|
|||||||
@@ -610,12 +610,12 @@ class PhysicsSimulationNetwork(hk.Module):
|
|||||||
y.q, y.p, **nets_kwargs)
|
y.q, y.p, **nets_kwargs)
|
||||||
# Special Haiku magic to avoid tracer issues
|
# Special Haiku magic to avoid tracer issues
|
||||||
if hk.running_init():
|
if hk.running_init():
|
||||||
return self.lagrangian(y0, **nets_kwargs)
|
return self.lagrangian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray
|
||||||
else:
|
else:
|
||||||
hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs)
|
hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs)
|
||||||
dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
|
dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
|
||||||
if hk.running_init():
|
if hk.running_init():
|
||||||
return self.hamiltonian(y0, **nets_kwargs)
|
return self.hamiltonian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray
|
||||||
|
|
||||||
# Optionally switch coordinate frame
|
# Optionally switch coordinate frame
|
||||||
if self.input_space == "velocity" and self.simulation_space == "momentum":
|
if self.input_space == "velocity" and self.simulation_space == "momentum":
|
||||||
|
|||||||
Reference in New Issue
Block a user