mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-06 03:32:18 +08:00
Update NFNet optim to reflect Haiku default of dict instead of FlatMapping.
PiperOrigin-RevId: 418005668
This commit is contained in:
committed by
Diego de Las Casas
parent
4b1583ae7a
commit
68473a2243
@@ -14,7 +14,7 @@
|
||||
# ==============================================================================
|
||||
"""Optimizers and Schedulers, inspired by the PyTorch API."""
|
||||
from collections import ChainMap # pylint:disable=g-importing-member
|
||||
from typing import Callable
|
||||
from typing import Callable, Mapping
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -38,12 +38,10 @@ class Optimizer(object):
|
||||
self.create_param_groups(params, defaults)
|
||||
# Join params at top-level if params is a list of groups
|
||||
if isinstance(params, list):
|
||||
flatmap = type(hk.data_structures.to_immutable_dict({}))
|
||||
if any([isinstance(group['params'], flatmap) for group in params]):
|
||||
params = hk.data_structures.merge(*[group['params']
|
||||
for group in params])
|
||||
if any(_is_non_empty_two_level_mapping(g['params']) for g in params):
|
||||
params = hk.data_structures.merge(*[g['params'] for g in params])
|
||||
else:
|
||||
params = dict(ChainMap(*[group['params'] for group in params]))
|
||||
params = dict(ChainMap(*[g['params'] for g in params]))
|
||||
# Prepare states
|
||||
create_buffers = lambda k, v: self.create_buffers('/'.join(k), v)
|
||||
self._states = tree.map_structure_with_path(create_buffers, params)
|
||||
@@ -153,6 +151,16 @@ class Optimizer(object):
|
||||
return utils.split_tree(outs, params, 2)
|
||||
|
||||
|
||||
def _is_non_empty_two_level_mapping(obj):
|
||||
instof = lambda t: lambda v: isinstance(v, t)
|
||||
# Basically: isinstance(obj, Mapping[str, Mapping[str, Any]]) ...
|
||||
return (isinstance(obj, Mapping) and all(map(instof(str), obj.keys())) and
|
||||
all(map(instof(Mapping), obj.values())) and
|
||||
all(map(lambda v: all(map(instof(str), v.keys())), obj.values())) and
|
||||
# ... and has at least one leaf.
|
||||
bool(obj) and any(map(bool, obj.values())))
|
||||
|
||||
|
||||
class Schedule(object):
|
||||
"""Hyperparameter scheduling objects."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user