diff --git a/ogb_lsc/pcq/dataset_utils.py b/ogb_lsc/pcq/dataset_utils.py index fe5a470..ce66a70 100644 --- a/ogb_lsc/pcq/dataset_utils.py +++ b/ogb_lsc/pcq/dataset_utils.py @@ -14,6 +14,7 @@ """Dataset utilities.""" +import functools from typing import List, Optional import jax @@ -33,6 +34,9 @@ import conformer_utils import datasets +curry = lambda f: functools.partial(functools.partial, f) + + def build_dataset_iterator( data_root: str, split: str, @@ -196,7 +200,7 @@ def _sample_uniform_categorical(num: int, size: int) -> tf.Tensor: return tf.random.categorical(tf.math.log([[1 / size] * size]), num)[0] -@jax.curry(jax.tree_map) +@curry(jax.tree_map) def _downcast_ints(x): if x.dtype == tf.int64: return tf.cast(x, tf.int32)