Initial commit for NFNets and AGC

PiperOrigin-RevId: 357022514
This commit is contained in:
Andy Brock
2021-02-11 19:38:33 +00:00
committed by Louise Deason
parent abc9ba0635
commit 68a1754d29
7 changed files with 646 additions and 20 deletions
+62
View File
@@ -429,6 +429,68 @@ class Fromage(Optimizer):
return out, state
def compute_norm(x, axis, keepdims):
"""Returns norm over arbitrary axis."""
norm = jnp.sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5
return norm
def unitwise_norm(x):
"""Computes norms of each output unit separately, assuming (HW)IO weights."""
if len(jnp.squeeze(x).shape) <= 1: # Scalars and vectors
axis = None
keepdims = False
elif len(x.shape) in [2, 3]: # Linear layers of shape IO
axis = 0
keepdims = True
elif len(x.shape) == 4: # Conv kernels of shape HWIO
axis = [0, 1, 2,]
keepdims = True
else:
raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
return compute_norm(x, axis, keepdims)
class SGD_AGC(Optimizer): # pylint:disable=invalid-name
"""SGD with Unit-Adaptive Gradient-Clipping.
References:
[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
Recognition Without Normalization.
"""
defaults = {'weight_decay': None, 'momentum': None, 'dampening': 0,
'nesterov': None, 'clipping': 0.01, 'eps': 1e-3}
def __init__(self, params, lr, weight_decay=None,
momentum=None, dampening=0, nesterov=None,
clipping=0.01, eps=1e-3):
super().__init__(
params, defaults={'lr': lr, 'weight_decay': weight_decay,
'momentum': momentum, 'dampening': dampening,
'clipping': clipping, 'nesterov': nesterov,
'eps': eps})
def create_buffers(self, name, param):
return SGD.create_buffers(self, name, param)
def update_param(self, param, grad, state, opt_params):
"""Clips grads if necessary, then applies the optimizer update."""
if param is None:
return param, state
if opt_params['clipping'] is not None:
param_norm = jnp.maximum(unitwise_norm(param), opt_params['eps'])
grad_norm = unitwise_norm(grad)
max_norm = param_norm * opt_params['clipping']
# If grad norm > clipping * param_norm, rescale
trigger = grad_norm > max_norm
# Note the max(||G||, 1e-6) is technically unnecessary here, as
# the clipping shouldn't trigger if the grad norm is zero,
# but we include it in practice as a "just-in-case".
clipped_grad = grad * (max_norm / jnp.maximum(grad_norm, 1e-6))
grad = jnp.where(trigger, clipped_grad, grad)
return SGD.update_param(self, param, grad, state, opt_params)
class Hybrid(Optimizer):
"""Optimizer which permits passing param groups with different base opts.