mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-21 22:58:56 +08:00
Replace references to deprecated jax.curry function.
This is deprecated as of https://github.com/google/jax/pull/15263 PiperOrigin-RevId: 520269385
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
82a347438f
commit
9d01171d43
@@ -14,6 +14,7 @@
|
||||
|
||||
"""Dataset utilities."""
|
||||
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
|
||||
import jax
|
||||
@@ -33,6 +34,9 @@ import conformer_utils
|
||||
import datasets
|
||||
|
||||
|
||||
curry = lambda f: functools.partial(functools.partial, f)
|
||||
|
||||
|
||||
def build_dataset_iterator(
|
||||
data_root: str,
|
||||
split: str,
|
||||
@@ -196,7 +200,7 @@ def _sample_uniform_categorical(num: int, size: int) -> tf.Tensor:
|
||||
return tf.random.categorical(tf.math.log([[1 / size] * size]), num)[0]
|
||||
|
||||
|
||||
@jax.curry(jax.tree_map)
|
||||
@curry(jax.tree_map)
|
||||
def _downcast_ints(x):
|
||||
if x.dtype == tf.int64:
|
||||
return tf.cast(x, tf.int32)
|
||||
|
||||
Reference in New Issue
Block a user