Update NFNet optim to reflect the Haiku default of dict params instead of FlatMapping.
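
Newer Haiku versions return params from hk.transform(...).init as a plain
nested dict rather than the old immutable FlatMapping, so checking
isinstance(group['params'], type(hk.data_structures.to_immutable_dict({})))
no longer recognizes Haiku params. Detect them structurally instead: treat
any non-empty two-level string-keyed mapping as a Haiku params tree. A
minimal sketch of the shape the check targets (the module name depends on
the network):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    params = hk.transform(lambda x: hk.Linear(4)(x)).init(
        jax.random.PRNGKey(0), jnp.ones([1, 8]))
    # Two-level mapping: module name -> parameter name -> array.
    print(type(params))  # plain dict on newer Haiku, FlatMapping on older ones
    print(params)        # {'linear': {'w': ..., 'b': ...}}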

PiperOrigin-RevId: 418005668
Authored by Loren Maggiore on 2021-12-23 15:34:25 +00:00
Committed by Diego de Las Casas
Parent: 4b1583ae7a
Commit: 68473a2243


@@ -14,7 +14,7 @@
 # ==============================================================================
 """Optimizers and Schedulers, inspired by the PyTorch API."""
 from collections import ChainMap  # pylint:disable=g-importing-member
-from typing import Callable
+from typing import Callable, Mapping
 import haiku as hk
 import jax
 import jax.numpy as jnp
@@ -38,12 +38,10 @@ class Optimizer(object):
     self.create_param_groups(params, defaults)
     # Join params at top-level if params is a list of groups
     if isinstance(params, list):
-      flatmap = type(hk.data_structures.to_immutable_dict({}))
-      if any([isinstance(group['params'], flatmap) for group in params]):
-        params = hk.data_structures.merge(*[group['params']
-                                            for group in params])
+      if any(_is_non_empty_two_level_mapping(g['params']) for g in params):
+        params = hk.data_structures.merge(*[g['params'] for g in params])
       else:
-        params = dict(ChainMap(*[group['params'] for group in params]))
+        params = dict(ChainMap(*[g['params'] for g in params]))
     # Prepare states
     create_buffers = lambda k, v: self.create_buffers('/'.join(k), v)
     self._states = tree.map_structure_with_path(create_buffers, params)
@@ -153,6 +151,16 @@ class Optimizer(object):
     return utils.split_tree(outs, params, 2)
 
 
+def _is_non_empty_two_level_mapping(obj):
+  instof = lambda t: lambda v: isinstance(v, t)
+  # Basically: isinstance(obj, Mapping[str, Mapping[str, Any]]) ...
+  return (isinstance(obj, Mapping) and all(map(instof(str), obj.keys())) and
+          all(map(instof(Mapping), obj.values())) and
+          all(map(lambda v: all(map(instof(str), v.keys())), obj.values())) and
+          # ... and has at least one leaf.
+          bool(obj) and any(map(bool, obj.values())))
+
+
 class Schedule(object):
   """Hyperparameter scheduling objects."""