diff --git a/physics_inspired_models/models/autoregressive.py b/physics_inspired_models/models/autoregressive.py index 0336b86..f34d3e8 100644 --- a/physics_inspired_models/models/autoregressive.py +++ b/physics_inspired_models/models/autoregressive.py @@ -220,7 +220,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel): is_training=is_training) return p_x, z0, decoder_z - def training_objectives( + def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray self, params: hk.Params, state: hk.State, @@ -300,7 +300,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel): include_z0=False, )[0] - def gt_state_and_latents( + def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray self, params: hk.Params, rng: jnp.ndarray, @@ -336,7 +336,7 @@ class TeacherForcingAutoregressiveModel(base.SequenceModel): ) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]: return dict(), dict() - def _init_latent_system( + def _init_latent_system( # pytype: disable=signature-mismatch # jax-ndarray self, rng: jnp.ndarray, z: jnp.ndarray, diff --git a/physics_inspired_models/models/base.py b/physics_inspired_models/models/base.py index f987052..31b0a0f 100644 --- a/physics_inspired_models/models/base.py +++ b/physics_inspired_models/models/base.py @@ -341,7 +341,7 @@ class SequenceModel(abc.ABC, Generic[T]): params = hk.data_structures.to_immutable_dict(params) state = hk.data_structures.to_immutable_dict(state) - return params, state + return params, state # pytype: disable=bad-return-type # jax-ndarray def init( self, diff --git a/physics_inspired_models/models/deterministic_vae.py b/physics_inspired_models/models/deterministic_vae.py index a477b97..0aaecea 100644 --- a/physics_inspired_models/models/deterministic_vae.py +++ b/physics_inspired_models/models/deterministic_vae.py @@ -155,8 +155,8 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]): def process_latents_for_decoder(self, z: _ArrayOrPhase) -> jnp.ndarray: if self.latent_dynamics_type == "Physics": - return z.q if self.render_from_q_only else z.single_state - return z + return z.q if self.render_from_q_only else z.single_state # pytype: disable=attribute-error # jax-ndarray + return z # pytype: disable=bad-return-type # jax-ndarray @property def inferred_index(self) -> int: @@ -327,7 +327,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]): if num_steps_backward > 0 and not self.can_run_backwards: raise ValueError("This model can not be unrolled backward in time.") - def unroll_latent_dynamics( + def unroll_latent_dynamics( # pytype: disable=signature-mismatch # jax-ndarray self, z: phase_space.PhaseSpace, params: hk.Params, @@ -393,7 +393,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]): z = z.single_state if isinstance(z, phase_space.PhaseSpace) else z return p_x, q_z, self.prior(), z0, z, dyn_stats - def training_objectives( + def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray self, params: utils.Params, state: hk.State, @@ -532,7 +532,7 @@ class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]): include_z0=True, )[0] - def gt_state_and_latents( + def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray self, params: hk.Params, rng: jnp.ndarray, diff --git a/physics_inspired_models/models/dynamics.py b/physics_inspired_models/models/dynamics.py index d5e08ca..d4f6b51 100644 --- a/physics_inspired_models/models/dynamics.py +++ b/physics_inspired_models/models/dynamics.py @@ -610,12 +610,12 @@ class PhysicsSimulationNetwork(hk.Module): y.q, y.p, **nets_kwargs) # Special Haiku magic to avoid tracer issues if hk.running_init(): - return self.lagrangian(y0, **nets_kwargs) + return self.lagrangian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray else: hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs) dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian) if hk.running_init(): - return self.hamiltonian(y0, **nets_kwargs) + return self.hamiltonian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray # Optionally switch coordinate frame if self.input_space == "velocity" and self.simulation_space == "momentum":