Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.

PiperOrigin-RevId: 512349622
This commit is contained in:
Peter Hawkins
2023-02-26 01:47:16 +00:00
committed by Saran Tunyasuvunakool
parent c051e6a51d
commit 6f0ddef7da
4 changed files with 11 additions and 11 deletions
@@ -220,7 +220,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel):
is_training=is_training) is_training=is_training)
return p_x, z0, decoder_z return p_x, z0, decoder_z
def training_objectives( def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray
self, self,
params: hk.Params, params: hk.Params,
state: hk.State, state: hk.State,
@@ -300,7 +300,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel):
include_z0=False, include_z0=False,
)[0] )[0]
def gt_state_and_latents( def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray
self, self,
params: hk.Params, params: hk.Params,
rng: jnp.ndarray, rng: jnp.ndarray,
@@ -336,7 +336,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel):
) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]: ) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
return dict(), dict() return dict(), dict()
def _init_latent_system( def _init_latent_system( # pytype: disable=signature-mismatch # jax-ndarray
self, self,
rng: jnp.ndarray, rng: jnp.ndarray,
z: jnp.ndarray, z: jnp.ndarray,
+1 -1
View File
@@ -341,7 +341,7 @@ class SequenceModel(abc.ABC, Generic[T]):
params = hk.data_structures.to_immutable_dict(params) params = hk.data_structures.to_immutable_dict(params)
state = hk.data_structures.to_immutable_dict(state) state = hk.data_structures.to_immutable_dict(state)
return params, state return params, state # pytype: disable=bad-return-type # jax-ndarray
def init( def init(
self, self,
@@ -155,8 +155,8 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
def process_latents_for_decoder(self, z: _ArrayOrPhase) -> jnp.ndarray: def process_latents_for_decoder(self, z: _ArrayOrPhase) -> jnp.ndarray:
if self.latent_dynamics_type == "Physics": if self.latent_dynamics_type == "Physics":
return z.q if self.render_from_q_only else z.single_state return z.q if self.render_from_q_only else z.single_state # pytype: disable=attribute-error # jax-ndarray
return z return z # pytype: disable=bad-return-type # jax-ndarray
@property @property
def inferred_index(self) -> int: def inferred_index(self) -> int:
@@ -327,7 +327,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
if num_steps_backward > 0 and not self.can_run_backwards: if num_steps_backward > 0 and not self.can_run_backwards:
raise ValueError("This model can not be unrolled backward in time.") raise ValueError("This model can not be unrolled backward in time.")
def unroll_latent_dynamics( def unroll_latent_dynamics( # pytype: disable=signature-mismatch # jax-ndarray
self, self,
z: phase_space.PhaseSpace, z: phase_space.PhaseSpace,
params: hk.Params, params: hk.Params,
@@ -393,7 +393,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
z = z.single_state if isinstance(z, phase_space.PhaseSpace) else z z = z.single_state if isinstance(z, phase_space.PhaseSpace) else z
return p_x, q_z, self.prior(), z0, z, dyn_stats return p_x, q_z, self.prior(), z0, z, dyn_stats
def training_objectives( def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray
self, self,
params: utils.Params, params: utils.Params,
state: hk.State, state: hk.State,
@@ -532,7 +532,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]):
include_z0=True, include_z0=True,
)[0] )[0]
def gt_state_and_latents( def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray
self, self,
params: hk.Params, params: hk.Params,
rng: jnp.ndarray, rng: jnp.ndarray,
+2 -2
View File
@@ -610,12 +610,12 @@ class PhysicsSimulationNetwork(hk.Module):
y.q, y.p, **nets_kwargs) y.q, y.p, **nets_kwargs)
# Special Haiku magic to avoid tracer issues # Special Haiku magic to avoid tracer issues
if hk.running_init(): if hk.running_init():
return self.lagrangian(y0, **nets_kwargs) return self.lagrangian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray
else: else:
hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs) hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs)
dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian) dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
if hk.running_init(): if hk.running_init():
return self.hamiltonian(y0, **nets_kwargs) return self.hamiltonian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray
# Optionally switch coordinate frame # Optionally switch coordinate frame
if self.input_space == "velocity" and self.simulation_space == "momentum": if self.input_space == "velocity" and self.simulation_space == "momentum":