deepmind-research/nfnets/optim.py

# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizers and Schedulers, inspired by the PyTorch API."""
from collections import ChainMap  # pylint:disable=g-importing-member
from typing import Callable

import haiku as hk
import jax
import jax.numpy as jnp
import tree

from nfnets import utils


class Optimizer(object):
  """Optimizer base class."""

  def __init__(self, params, defaults):
    # Flag indicating if parameters have been broadcasted
    self._broadcasted = False
    # Optimizer hyperparameters; this is a dict to support using param_groups
    self._hyperparameters = {}
    # Mapping from model parameters to optimizer hyperparameters
    self._params2hyperparams = {}
    # Assign defaults
    self._hyperparameters = dict(**defaults)
    # Prepare parameter groups and mappings
    self.create_param_groups(params, defaults)
    # Join params at top-level if params is a list of groups
    if isinstance(params, list):
      flatmap = type(hk.data_structures.to_immutable_dict({}))
      if any([isinstance(group['params'], flatmap) for group in params]):
        params = hk.data_structures.merge(*[group['params']
                                            for group in params])
      else:
        params = dict(ChainMap(*[group['params'] for group in params]))
    # Prepare states
    create_buffers = lambda k, v: self.create_buffers('/'.join(k), v)
    self._states = tree.map_structure_with_path(create_buffers, params)

  def add_hyperparam_group(self, group, suffix, defaults):
    """Adds new hyperparameters to the hyperparams dict."""
    # Use default hyperparams unless overridden by group hyperparams
    group_dict = {key: key for key in defaults if key not in group}
    for key in group:
      if key != 'params':  # Reserved keyword 'params'
        group_dict[key] = '%s_%s' % (key, suffix)
        self._hyperparameters[group_dict[key]] = group[key]
    # Set up params2hyperparams
    def set_p2h(k, _):
      self._params2hyperparams['/'.join(k)] = group_dict
    tree.map_structure_with_path(set_p2h, group['params'])

  def create_param_groups(self, params, defaults):
    """Creates param-hyperparam mappings."""
    if isinstance(params, list):
      for group_index, group in enumerate(params):
        # Add group to hyperparams and get this group's full hyperparameters
        self.add_hyperparam_group(group, group_index, defaults)
    else:
      mapping = {key: key for key in self._hyperparameters}
      def set_p2h(k, _):
        self._params2hyperparams['/'.join(k)] = mapping
      tree.map_structure_with_path(set_p2h, params)

  def create_buffers(self, name, params):
    """Method to be overridden by child classes."""
    pass

  def get_opt_params(self, param_name, itr):
    """Returns hyperparams corresponding to param_name."""
    mapping = self._params2hyperparams[param_name]
    output = {}
    for key in mapping:
      hyper = self._hyperparameters[mapping[key]]
      # Handle the case where a hyper is a class, for hybrids
      if isinstance(hyper, Callable) and not isinstance(hyper, type):
        output[key] = hyper(itr)
      else:
        output[key] = hyper
    return output

  def get_hyper(self, param_name, hyper_name):
    """Get an individual hyperparam for a given param."""
    mapping = self._params2hyperparams[param_name]
    return self._hyperparameters[mapping[hyper_name]]

  def plugin(self, states):
    self._states = states

  def states(self):
    return self._states

  def broadcast(self):
    """Broadcasts all buffers and parameters."""
    self._broadcasted = True
    for name, state in self._states.items():
      self._states[name] = {key: utils.broadcast(state[key]) for key in state}

  def gather(self):
    """Gathers state (if broadcasted) for saving."""
    states = {}
    for name in self._states:
      state = self._states[name]
      states[name] = {key: state[key] if state[key] is None else state[key][0]
                      for key in state}
    return states

  def __setattr__(self, name, value):
    """Overrides the object's set-attribute function to register states, etc."""
    if '_hyperparameters' in self.__dict__ and name in self._hyperparameters:
      self._hyperparameters[name] = value
    elif '_states' in self.__dict__ and name in self._states:
      self._states[name] = value
    else:
      object.__setattr__(self, name, value)

  def __getattr__(self, name):
    """Overrides the object's get-attribute function to return states, etc."""
    if '_hyperparameters' in self.__dict__ and name in self._hyperparameters:
      return self._hyperparameters[name]
    elif '_states' in self.__dict__ and name in self._states:
      return self._states[name]
    else:
      return object.__getattribute__(self, name)

  def step(self, params, grads, states, itr=None):
    """Takes a single optimizer step.

    Args:
      params: a dict containing the parameters to be updated.
      grads: a dict containing the gradients for each parameter in params.
      states: a dict containing any optimizer buffers (momentum, etc) for
        each parameter in params.
      itr: an optional integer indicating the current step, for scheduling.

    Returns:
      The updated params and optimizer buffers.
    """
    get_hyper = lambda k, v: self.get_opt_params('/'.join(k), itr)
    hypers = tree.map_structure_with_path(get_hyper, params)
    outs = tree.map_structure_up_to(params, self.update_param,
                                    params, grads, states, hypers)
    return utils.split_tree(outs, params, 2)
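
# Illustrative usage sketch of the step API above, using the SGD subclass
# defined below; the names `params` and `grads` are hypothetical placeholders
# for a (possibly nested) parameter dict and a matching gradient tree from
# jax.grad:
#
#   opt = SGD(params, lr=0.1, momentum=0.9, weight_decay=5e-5)
#   states = opt.states()
#   new_params, new_states = opt.step(params, grads, states, itr=0)
#   opt.plugin(new_states)  # Re-attach the updated buffers for the next step.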


class Schedule(object):
  """Hyperparameter scheduling objects."""


class CosineDecay(Schedule):
  """Cosine decay."""

  def __init__(self, min_val, max_val, num_steps):
    self.min_val = min_val
    self.max_val = max_val
    self.num_steps = num_steps

  def __call__(self, itr):
    cos = (1 + jnp.cos(jnp.pi * itr / self.num_steps))
    return 0.5 * (self.max_val - self.min_val) * cos + self.min_val


class WarmupCosineDecay(Schedule):
  """Cosine decay with linear warmup."""

  def __init__(self, start_val, min_val, max_val, num_steps, warmup_steps):
    self.start_val = start_val
    self.min_val = min_val
    self.max_val = max_val
    self.num_steps = num_steps
    self.warmup_steps = warmup_steps

  def __call__(self, itr):
    warmup_val = ((self.max_val - self.start_val) * (itr / self.warmup_steps)
                  + self.start_val)
    cos_itr = (itr - self.warmup_steps) / (self.num_steps - self.warmup_steps)
    cos = 1 + jnp.cos(jnp.pi * cos_itr)
    cos_val = 0.5 * (self.max_val - self.min_val) * cos + self.min_val
    # Select warmup_val if itr < warmup, else cosine val
    values = jnp.array([warmup_val, cos_val])
    index = jnp.sum(jnp.array(self.warmup_steps) < itr)
    return jnp.take(values, index)


class WarmupExpDecay(Schedule):
  """Exponential step decay with linear warmup."""

  def __init__(self, start_val, max_val, warmup_steps,
               decay_factor, decay_interval):
    self.start_val = start_val
    self.max_val = max_val
    self.warmup_steps = warmup_steps
    self.decay_factor = decay_factor
    self.decay_interval = decay_interval

  def __call__(self, itr):
    warmup_val = ((self.max_val - self.start_val) * (itr / self.warmup_steps)
                  + self.start_val)
    # How many decay steps have we taken?
    num_decays = jnp.floor((itr - self.warmup_steps) / self.decay_interval)
    exp_val = self.max_val * (self.decay_factor ** num_decays)
    # Select warmup_val if itr < warmup, else exp_val
    values = jnp.array([warmup_val, exp_val])
    index = jnp.sum(jnp.array(self.warmup_steps) < itr)
    return jnp.take(values, index)
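
# Schedules are plain callables, so they can be passed wherever a scalar
# hyperparameter is expected; get_opt_params evaluates any callable (non-class)
# hyperparameter at the current `itr`. Illustrative sketch with hypothetical
# step counts and a hypothetical `params` tree, using the SGD subclass defined
# below:
#
#   lr_schedule = WarmupCosineDecay(start_val=0.0, min_val=0.0, max_val=0.4,
#                                   num_steps=360_000, warmup_steps=5_000)
#   opt = SGD(params, lr=lr_schedule, momentum=0.9)
#   # The current step must then be passed as opt.step(..., itr=step) so the
#   # schedule can be evaluated.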


class SGD(Optimizer):
  """Standard SGD with (nesterov) momentum and weight decay.

  Attributes:
    params: Either a dict mapping param names to JAX tensors, or a list where
      each member of the list is a dict containing parameters and
      hyperparameters, allowing one to specify param-specific hyperparams.
    lr: Learning rate.
    weight_decay: Weight decay parameter. Note that this is decay, not L2 reg.
    momentum: Momentum parameter.
    dampening: Dampening parameter.
    nesterov: Bool indicating this optimizer will use the NAG formulation.
  """
  defaults = {'weight_decay': None, 'momentum': None, 'dampening': 0,
              'nesterov': None}

  def __init__(self, params, lr, weight_decay=None,
               momentum=None, dampening=0, nesterov=None):
    super().__init__(
        params, defaults={'lr': lr, 'weight_decay': weight_decay,
                          'momentum': momentum, 'dampening': dampening,
                          'nesterov': nesterov})

  def create_buffers(self, name, param):
    """Prepares all momentum buffers for each parameter."""
    state = {'step': jnp.zeros(jax.local_device_count())}
    if self.get_hyper(name, 'momentum') is not None:
      state['momentum'] = jnp.zeros_like(param)
    return state

  def update_param(self, param, grad, state, opt_params):
    """The actual update step for this optimizer."""
    if param is None:
      return param, state
    # Apply weight decay
    if opt_params.get('weight_decay') is not None:
      grad = grad + param * opt_params['weight_decay']
    # Update momentum buffers if needed
    if 'momentum' in state:
      state['momentum'] = (opt_params['momentum'] * state['momentum']
                           + (1 - opt_params['dampening']) * grad)
      if opt_params['nesterov'] is not None:
        grad = grad + opt_params['momentum'] * state['momentum']
      else:
        grad = state['momentum']
    state['step'] += 1
    return param - opt_params['lr'] * grad, state
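
# Illustrative sketch of per-group hyperparameters with SGD; the split of a
# hypothetical Haiku `params` tree into `weights` and `biases` subtrees
# (e.g. via hk.data_structures.partition) is assumed:
#
#   opt = SGD([{'params': weights, 'weight_decay': 5e-5},
#              {'params': biases, 'weight_decay': None}],
#             lr=0.1, momentum=0.9)
#
# Group-level values override the defaults passed to the constructor; here
# only the weights are decayed, while lr and momentum are shared.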


class Adam(Optimizer):
  """Adam optimizer, Kingma & Ba, arxiv.org/abs/1412.6980.

  Args:
    params (iterable): nested list of params to optimize
    lr (float, optional): learning rate (default: 1e-3)
    beta1, beta2 (float, optional): coefficients used for computing running
      averages of the gradient and its square (default: 0.9 and 0.999)
    eps (float, optional): term added to the denominator to improve
      numerical stability (default: 1e-8)
    weight_decay (float, optional): weight decay (default: 0)
    use_adamw (bool, optional): If not None, use decoupled weight decay
      as in arxiv.org/abs/1711.05101. The paper version adds an additional
      "schedule" hyperparameter eta, which we instead just replace with the
      learning rate following the PyTorch implementation.

  Note that this implementation will not instantiate a buffer if the
  beta term for that buffer is passed in as None, thus conserving memory.
  """
  defaults = {'beta1': 0.9, 'beta2': 0.999, 'weight_decay': None, 'eps': 1e-8,
              'use_adamw': None}

  def __init__(self, params, lr, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=None, use_adamw=None):
    super().__init__(params=params,
                     defaults={'lr': lr, 'beta1': beta1,
                               'beta2': beta2, 'eps': eps,
                               'weight_decay': weight_decay,
                               'use_adamw': use_adamw})

  def create_buffers(self, name, param):
    """Prepare exp_avg and exp_avg_sq buffers."""
    state = {'step': jnp.zeros(jax.local_device_count())}
    if self.get_hyper(name, 'beta1') is not None:
      state['exp_avg'] = jnp.zeros_like(param)
    if self.get_hyper(name, 'beta2') is not None:
      state['exp_avg_sq'] = jnp.zeros_like(param)
    return state

  def update_param(self, param, grad, state, opt_params):
    """The actual update step for this optimizer."""
    if param is None:
      return param, state
    state['step'] = state['step'] + 1
    # Apply weight decay
    if opt_params.get('weight_decay') is not None:
      if opt_params.get('use_adamw') is not None:
        param = param * (1 - opt_params['lr'] * opt_params['weight_decay'])
      else:
        grad = grad + param * opt_params['weight_decay']
    # First moment
    if 'exp_avg' in state:
      bias_correction1 = 1 - opt_params['beta1'] ** state['step']
      state['exp_avg'] = (state['exp_avg'] * opt_params['beta1']
                          + (1 - opt_params['beta1']) * grad)
      step_size = opt_params['lr'] * state['exp_avg'] / bias_correction1
    else:
      step_size = opt_params['lr'] * grad
    # Second moment
    if 'exp_avg_sq' in state:
      bias_correction2 = 1 - opt_params['beta2'] ** state['step']
      state['exp_avg_sq'] = (state['exp_avg_sq'] * opt_params['beta2']
                             + (1 - opt_params['beta2']) * grad * grad)
      denom = jnp.sqrt(state['exp_avg_sq']) * jax.lax.rsqrt(bias_correction2)
      denom = denom + opt_params['eps']
    else:
      denom = jnp.abs(grad) + opt_params['eps']  # Add eps to avoid divide-by-0
    return param - step_size / denom, state
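
# As the docstring notes, passing a beta as None skips allocating the
# corresponding buffer. Illustrative sketch (hypothetical `params` tree):
#
#   opt = Adam(params, lr=1e-3, beta1=None)  # no exp_avg (first-moment) buffer
#
# In this configuration the update uses the raw gradient over the
# bias-corrected root-mean-square denominator, saving one buffer per parameter.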


class RMSProp(Optimizer):
  """RMSProp optimizer, Tieleman and Hinton, ref: powerpoint slides.

  Implements RMSProp as
    rms = decay * rms{t-1} + (1 - decay) * gradient ** 2
    mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(rms + epsilon)
    param -= mom

  Note that the rms buffer is initialized with ones as in TF, as opposed to
  zeros as in all other implementations.

  Args:
    params (iterable): nested list of params to optimize
    lr (float): learning rate (default: 1e-3)
    decay (float): EMA decay rate for running estimate of squared gradient.
    momentum (float or None): Use heavy ball momentum instead of instant grad.
    eps (float, optional): term added to the denominator to improve
      numerical stability (default: 1e-8)
    weight_decay (float, optional): weight decay (NOT AdamW-style; default: 0)
  """
  defaults = {'weight_decay': None, 'eps': 1e-8}

  def __init__(self, params, lr, decay, momentum, weight_decay=None, eps=1e-8):
    super().__init__(params=params,
                     defaults={'lr': lr, 'decay': decay,
                               'momentum': momentum, 'eps': eps,
                               'weight_decay': weight_decay})

  def create_buffers(self, name, param):
    """Prepares the rms and momentum buffers for each parameter."""
    state = {'step': jnp.zeros(jax.local_device_count())}
    state['rms'] = jnp.ones_like(param)
    if self.get_hyper(name, 'momentum') is not None:
      state['momentum'] = jnp.zeros_like(param)
    return state

  def update_param(self, param, grad, state, opt_params):
    """The actual update step for this optimizer."""
    if param is None:
      return param, state
    state['step'] = state['step'] + 1
    # Apply weight decay
    if opt_params.get('weight_decay') is not None:
      grad = grad + param * opt_params['weight_decay']
    # EMA of the squared gradient
    state['rms'] = (state['rms'] * opt_params['decay']
                    + (1 - opt_params['decay']) * (grad ** 2))
    scaled_grad = (opt_params['lr'] * grad
                   / (state['rms'] + opt_params['eps']) ** 0.5)
    # Apply heavy-ball momentum only if a momentum buffer was created
    if 'momentum' in state:
      state['momentum'] = (state['momentum'] * opt_params['momentum']
                           + scaled_grad)
      step_size = state['momentum']
    else:
      step_size = scaled_grad
    return param - step_size, state


class Fromage(Optimizer):
  """Fromage optimizer, Bernstein et al. arXiv.org/abs/2002.03432.

  This version optionally includes weight decay.

  Attributes:
    params (iterable): nested list of params to optimize
    lr (float): learning rate.
    eps (float, optional): Minimum allowable norm, required in case
      parameters are zero-initialized (default: 1e-5).
    weight_decay (float, optional): weight decay (default: 0).
  """
  defaults = {'weight_decay': None, 'eps': 1e-5}

  def __init__(self, params, lr, weight_decay=None, eps=1e-5):
    super().__init__(
        params, defaults={'lr': lr, 'weight_decay': weight_decay, 'eps': eps})

  def create_buffers(self, name, param):  # pylint: disable=unused-argument
    """Prepares the step-count buffer for each parameter."""
    return {'step': jnp.zeros(1)}

  def update_param(self, param, grad, state, opt_params):
    """The actual update step for this optimizer."""
    if param is None:
      return param, state
    if opt_params['weight_decay'] is not None:
      grad = grad + param * opt_params['weight_decay']
    grad_norm = jnp.maximum(jnp.linalg.norm(grad), opt_params['eps'])
    param_norm = jnp.maximum(jnp.linalg.norm(param), opt_params['eps'])
    mult = jax.lax.rsqrt(1 + opt_params['lr'] ** 2)
    out = (param - opt_params['lr'] * grad * (param_norm / grad_norm)) * mult
    return out, state


class Hybrid(Optimizer):
  """Optimizer which permits passing param groups with different base opts.

  The API for this class follows that of any other optimizer where one
  specifies a list of dicts with separate hyperparams, but it additionally
  requires the user to specify an 'opt' key for each group, e.g.
  [{'params': params0, 'opt': optim.Adam, 'lr': 0.1}].

  The user must also provide a value for any argument of the selected
  optimizers which does not have an associated default.
  """

  def __init__(self, param_groups):
    if any(['opt' not in group for group in param_groups]):
      raise ValueError('All parameter groups must have an opt key!')
    self.defaults = ChainMap(*[group['opt'].defaults for group in param_groups])
    super().__init__(param_groups, defaults=dict(self.defaults))

  def create_buffers(self, name, param):
    return self.get_hyper(name, 'opt').create_buffers(self, name, param)

  def update_param(self, param, grad, state, opt_params):
    return opt_params['opt'].update_param(self, param, grad, state, opt_params)
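
# Illustrative sketch of a Hybrid optimizer; the split of the parameter tree
# into hypothetical `conv_params` and `head_params` groups is assumed:
#
#   opt = Hybrid([
#       {'params': conv_params, 'opt': SGD, 'lr': 0.1, 'momentum': 0.9},
#       {'params': head_params, 'opt': Adam, 'lr': 1e-3},
#   ])
#
# Each group names its 'opt' and must supply any argument of that optimizer
# with no entry in its `defaults` dict (here, 'lr' for both groups).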