mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-27 18:25:49 +08:00
Update NFNet optim to reflect Haiku default of dict instead of FlatMapping.
PiperOrigin-RevId: 418005668
This commit is contained in:
committed by
Diego de Las Casas
parent
4b1583ae7a
commit
68473a2243
+14
-6
@@ -14,7 +14,7 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Optimizers and Schedulers, inspired by the PyTorch API."""
|
"""Optimizers and Schedulers, inspired by the PyTorch API."""
|
||||||
from collections import ChainMap # pylint:disable=g-importing-member
|
from collections import ChainMap # pylint:disable=g-importing-member
|
||||||
from typing import Callable
|
from typing import Callable, Mapping
|
||||||
import haiku as hk
|
import haiku as hk
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -38,12 +38,10 @@ class Optimizer(object):
|
|||||||
self.create_param_groups(params, defaults)
|
self.create_param_groups(params, defaults)
|
||||||
# Join params at top-level if params is a list of groups
|
# Join params at top-level if params is a list of groups
|
||||||
if isinstance(params, list):
|
if isinstance(params, list):
|
||||||
flatmap = type(hk.data_structures.to_immutable_dict({}))
|
if any(_is_non_empty_two_level_mapping(g['params']) for g in params):
|
||||||
if any([isinstance(group['params'], flatmap) for group in params]):
|
params = hk.data_structures.merge(*[g['params'] for g in params])
|
||||||
params = hk.data_structures.merge(*[group['params']
|
|
||||||
for group in params])
|
|
||||||
else:
|
else:
|
||||||
params = dict(ChainMap(*[group['params'] for group in params]))
|
params = dict(ChainMap(*[g['params'] for g in params]))
|
||||||
# Prepare states
|
# Prepare states
|
||||||
create_buffers = lambda k, v: self.create_buffers('/'.join(k), v)
|
create_buffers = lambda k, v: self.create_buffers('/'.join(k), v)
|
||||||
self._states = tree.map_structure_with_path(create_buffers, params)
|
self._states = tree.map_structure_with_path(create_buffers, params)
|
||||||
@@ -153,6 +151,16 @@ class Optimizer(object):
|
|||||||
return utils.split_tree(outs, params, 2)
|
return utils.split_tree(outs, params, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_non_empty_two_level_mapping(obj):
|
||||||
|
instof = lambda t: lambda v: isinstance(v, t)
|
||||||
|
# Basically: isinstance(obj, Mapping[str, Mapping[str, Any]]) ...
|
||||||
|
return (isinstance(obj, Mapping) and all(map(instof(str), obj.keys())) and
|
||||||
|
all(map(instof(Mapping), obj.values())) and
|
||||||
|
all(map(lambda v: all(map(instof(str), v.keys())), obj.values())) and
|
||||||
|
# ... and has at least one leaf.
|
||||||
|
bool(obj) and any(map(bool, obj.values())))
|
||||||
|
|
||||||
|
|
||||||
class Schedule(object):
|
class Schedule(object):
|
||||||
"""Hyperparameter scheduling objects."""
|
"""Hyperparameter scheduling objects."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user