Replace references to deprecated jax.curry function.

This is deprecated as of https://github.com/google/jax/pull/15263

PiperOrigin-RevId: 520269385
This commit is contained in:
Jake VanderPlas
2023-03-29 10:03:36 +01:00
committed by Saran Tunyasuvunakool
parent 82a347438f
commit 9d01171d43
+5 -1
View File
@@ -14,6 +14,7 @@
"""Dataset utilities."""
import functools
from typing import List, Optional
import jax
@@ -33,6 +34,9 @@ import conformer_utils
import datasets
curry = lambda f: functools.partial(functools.partial, f)
def build_dataset_iterator(
data_root: str,
split: str,
@@ -196,7 +200,7 @@ def _sample_uniform_categorical(num: int, size: int) -> tf.Tensor:
return tf.random.categorical(tf.math.log([[1 / size] * size]), num)[0]
@jax.curry(jax.tree_map)
@curry(jax.tree_map)
def _downcast_ints(x):
if x.dtype == tf.int64:
return tf.cast(x, tf.int32)