Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.

PiperOrigin-RevId: 511294746
This commit is contained in:
Peter Hawkins
2023-02-21 21:45:32 +00:00
committed by Saran Tunyasuvunakool
parent 797ea3c71d
commit c051e6a51d
7 changed files with 12 additions and 12 deletions
+1 -1
View File
@@ -677,7 +677,7 @@ def solve_hamiltonian_ivp_t_eval(
if method == "adaptive":
dy_dt = phase_space.transform_symplectic_tangent_function_using_array(dy_dt)
return solve_ivp_t_eval(
return solve_ivp_t_eval( # pytype: disable=bad-return-type # jax-ndarray
fun=dy_dt,
t_span=t_span,
y0=y0,