Make this code compatible with Python 3.10.

PiperOrigin-RevId: 473457191
This commit is contained in:
Yilei Yang
2022-09-10 13:13:36 +00:00
committed by Diego de las Casas
parent 586a1c55de
commit 3af71dd9a7
+2 -2
View File
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Utilities functions for Jax.""" """Utilities functions for Jax."""
import collections from collections import abc
import functools import functools
from typing import Any, Callable, Dict, Mapping, Union from typing import Any, Callable, Dict, Mapping, Union
@@ -176,7 +176,7 @@ def flatten_dict(d, parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
items = [] items = []
for k, v in d.items(): for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping): if isinstance(v, abc.MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items()) items.extend(flatten_dict(v, new_key, sep=sep).items())
else: else:
items.append((new_key, v)) items.append((new_key, v))