diff --git a/physics_inspired_models/utils.py b/physics_inspired_models/utils.py index 7d2aad1..a308607 100644 --- a/physics_inspired_models/utils.py +++ b/physics_inspired_models/utils.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utilities functions for Jax.""" -import collections +from collections import abc import functools from typing import Any, Callable, Dict, Mapping, Union @@ -176,7 +176,7 @@ def flatten_dict(d, parent_key: str = "", sep: str = "_") -> Dict[str, Any]: items = [] for k, v in d.items(): new_key = parent_key + sep + k if parent_key else k - if isinstance(v, collections.MutableMapping): + if isinstance(v, abc.MutableMapping): items.extend(flatten_dict(v, new_key, sep=sep).items()) else: items.append((new_key, v))