Make this code compatible with Python 3.10.

PiperOrigin-RevId: 473457191
This commit is contained in:
Yilei Yang
2022-09-10 13:13:36 +00:00
committed by Diego de las Casas
parent 586a1c55de
commit 3af71dd9a7

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities functions for Jax."""
import collections
from collections import abc
import functools
from typing import Any, Callable, Dict, Mapping, Union
@@ -176,7 +176,7 @@ def flatten_dict(d, parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
if isinstance(v, abc.MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))