Mirror of https://github.com/google-deepmind/deepmind-research.git
Initial commit for NFNets and AGC
PiperOrigin-RevId: 357022514
Committed by Louise Deason
Parent: abc9ba0635
Commit: 68a1754d29
@@ -429,6 +429,68 @@ class Fromage(Optimizer):
    return out, state


def compute_norm(x, axis, keepdims):
  """Returns norm over arbitrary axis."""
  norm = jnp.sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5
  return norm

def unitwise_norm(x):
  """Computes norms of each output unit separately, assuming (HW)IO weights."""
  if len(jnp.squeeze(x).shape) <= 1:  # Scalars and vectors
    axis = None
    keepdims = False
  elif len(x.shape) in [2, 3]:  # Linear layers of shape IO
    axis = 0
    keepdims = True
  elif len(x.shape) == 4:  # Conv kernels of shape HWIO
    axis = [0, 1, 2,]
    keepdims = True
  else:
    raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
  return compute_norm(x, axis, keepdims)

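Note (not part of the commit above): a minimal usage sketch of unitwise_norm, showing the shape of the per-unit norms it returns for each supported parameter rank. The module name optim in the import is a guess for where these helpers live, not something stated in this commit.

import jax.numpy as jnp
from optim import unitwise_norm  # hypothetical import path

print(unitwise_norm(jnp.ones((3,))).shape)            # (): biases/vectors get a single scalar norm
print(unitwise_norm(jnp.ones((16, 32))).shape)        # (1, 32): one norm per output unit of an IO matrix
print(unitwise_norm(jnp.ones((3, 3, 16, 32))).shape)  # (1, 1, 1, 32): one norm per output channel of an HWIO kernel
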
class SGD_AGC(Optimizer):  # pylint:disable=invalid-name
  """SGD with Unit-Adaptive Gradient-Clipping.

  References:
    [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
    Recognition Without Normalization.
  """
  defaults = {'weight_decay': None, 'momentum': None, 'dampening': 0,
              'nesterov': None, 'clipping': 0.01, 'eps': 1e-3}

  def __init__(self, params, lr, weight_decay=None,
               momentum=None, dampening=0, nesterov=None,
               clipping=0.01, eps=1e-3):
    super().__init__(
        params, defaults={'lr': lr, 'weight_decay': weight_decay,
                          'momentum': momentum, 'dampening': dampening,
                          'clipping': clipping, 'nesterov': nesterov,
                          'eps': eps})

  def create_buffers(self, name, param):
    return SGD.create_buffers(self, name, param)

  def update_param(self, param, grad, state, opt_params):
    """Clips grads if necessary, then applies the optimizer update."""
    if param is None:
      return param, state
    if opt_params['clipping'] is not None:
      param_norm = jnp.maximum(unitwise_norm(param), opt_params['eps'])
      grad_norm = unitwise_norm(grad)
      max_norm = param_norm * opt_params['clipping']
      # If grad norm > clipping * param_norm, rescale
      trigger = grad_norm > max_norm
      # Note the max(||G||, 1e-6) is technically unnecessary here, as
      # the clipping shouldn't trigger if the grad norm is zero,
      # but we include it in practice as a "just-in-case".
      clipped_grad = grad * (max_norm / jnp.maximum(grad_norm, 1e-6))
      grad = jnp.where(trigger, clipped_grad, grad)
    return SGD.update_param(self, param, grad, state, opt_params)

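Note (not part of the commit above): the clipping branch of update_param applies the AGC rule unit by unit. Whenever ||G|| > clipping * max(||W||, eps), the gradient is rescaled to clipping * max(||W||, eps) * G / ||G||, so each unit's gradient norm never exceeds a fixed fraction of its parameter norm. A minimal standalone sketch of just that step, assuming unitwise_norm from above is in scope; the helper name agc_clip is made up here, not from this commit.

import jax.numpy as jnp

def agc_clip(param, grad, clipping=0.01, eps=1e-3):
  # Unit-wise norms, with the param norm floored at eps so that
  # near-zero (e.g. freshly initialised) weights are not frozen.
  param_norm = jnp.maximum(unitwise_norm(param), eps)
  grad_norm = unitwise_norm(grad)
  max_norm = param_norm * clipping
  # Rescale only the units whose gradient norm exceeds max_norm.
  clipped = grad * (max_norm / jnp.maximum(grad_norm, 1e-6))
  return jnp.where(grad_norm > max_norm, clipped, grad)

Applying agc_clip to each (param, grad) pair before an ordinary SGD step reproduces what SGD_AGC.update_param does when 'clipping' is not None.
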
class Hybrid(Optimizer):
  """Optimizer which permits passing param groups with different base opts.