diff --git a/nfnets/README.md b/nfnets/README.md
index f32bfd2..d316760 100644
--- a/nfnets/README.md
+++ b/nfnets/README.md
@@ -1,7 +1,9 @@
-# Code for Normalizer-Free ResNets (Brock, De, Smith, ICLR2021)
+# Code for Normalizer-Free Networks
 This repository contains code for the ICLR 2021 paper "Characterizing signal
 propagation to close the performance gap in unnormalized
-ResNets," by Andrew Brock, Soham De, and Samuel L. Smith.
+ResNets," by Andrew Brock, Soham De, and Samuel L. Smith, and the arXiv preprint
+"High-Performance Large-Scale Image Recognition Without Normalization" by
+Andrew Brock, Soham De, Samuel L. Smith, and Karen Simonyan.
 
 ## Running this code
 
@@ -16,14 +18,29 @@ used in dataset.py in order to train on ImageNet.
 
 ## Giving Credit
 
-If you use this code in your work, we ask you to cite this paper:
+If you use this code in your work, we ask you to please cite one or both of the
+following papers.
+
+The reference for the Normalizer-Free structure and NF-ResNets or NF-RegNets:
 
 ```
 @inproceedings{brock2021characterizing,
-  author={Andrew Brock and Soham De and Samuel L. SMith},
+  author={Andrew Brock and Soham De and Samuel L. Smith},
   title={Characterizing signal propagation to close the performance gap in
   unnormalized ResNets},
   booktitle={9th International Conference on Learning Representations, {ICLR}},
   year={2021}
 }
 ```
+
+The reference for Adaptive Gradient Clipping (AGC) and the NFNets models:
+
+```
+@article{brock2021high,
+  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
+  title={High-Performance Large-Scale Image Recognition Without Normalization},
+  journal={arXiv preprint arXiv:},
+  year={2021}
+}
+```
+
diff --git a/nfnets/agc_optax.py b/nfnets/agc_optax.py
new file mode 100644
index 0000000..532f4d1
--- /dev/null
+++ b/nfnets/agc_optax.py
@@ -0,0 +1,78 @@
+# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adaptive gradient clipping transform for Optax."""
+import jax
+import jax.numpy as jnp
+import optax
+
+
+def compute_norm(x, axis, keepdims):
+  """Axis-wise Euclidean norm."""
+  return jnp.sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5
+
+
+def unitwise_norm(x):
+  """Compute norms of each output unit separately, also for linear layers."""
+  if len(jnp.squeeze(x).shape) <= 1:  # Scalars and vectors
+    axis = None
+    keepdims = False
+  elif len(x.shape) in [2, 3]:  # Linear layers of shape IO or multihead linear
+    axis = 0
+    keepdims = True
+  elif len(x.shape) == 4:  # Conv kernels of shape HWIO
+    axis = [0, 1, 2,]
+    keepdims = True
+  else:
+    raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
+  return compute_norm(x, axis, keepdims)
+
+
+def my_clip(g_norm, max_norm, grad):
+  """Applies unit-wise adaptive gradient clipping."""
+  trigger = g_norm < max_norm
+  # This little max(., 1e-6) is distinct from the normal eps and just prevents
+  # division by zero. It should technically never come into play, since the
+  # clipped branch is only selected when g_norm >= max_norm > 0.
+  clipped_grad = grad * (max_norm / jnp.maximum(g_norm, 1e-6))
+  return jnp.where(trigger, grad, clipped_grad)
+
+
+def adaptive_grad_clip(clip, eps=1e-3) -> optax.GradientTransformation:
+  """Clip updates to be at most clipping * parameter_norm.
+
+  References:
+    [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
+    Recognition Without Normalization.
+
+  Args:
+    clip: Maximum allowed ratio of update norm to parameter norm.
+    eps: Epsilon term to prevent clipping of zero-initialized params.
+
+  Returns:
+    An (init_fn, update_fn) tuple.
+  """
+
+  def init_fn(_):
+    return optax.ClipByGlobalNormState()
+
+  def update_fn(updates, state, params):
+    g_norm = jax.tree_map(unitwise_norm, updates)
+    p_norm = jax.tree_map(unitwise_norm, params)
+    # Maximum allowable norm
+    max_norm = jax.tree_map(lambda x: clip * jnp.maximum(x, eps), p_norm)
+    # If grad norm > clipping * param_norm, rescale
+    updates = jax.tree_multimap(my_clip, g_norm, max_norm, updates)
+    return updates, state
+
+  return optax.GradientTransformation(init_fn, update_fn)
diff --git a/nfnets/base.py b/nfnets/base.py
index a8365d1..1baaa8d 100644
--- a/nfnets/base.py
+++ b/nfnets/base.py
@@ -50,6 +50,50 @@ nf_regnet_params = {
       'drop_rate': 0.5},
 }
 
+nfnet_params = {}
+
+
+# F-series models
+nfnet_params.update(**{
+    'F0': {
+        'width': [256, 512, 1536, 1536], 'depth': [1, 2, 6, 3],
+        'train_imsize': 192, 'test_imsize': 256,
+        'RA_level': '405', 'drop_rate': 0.2},
+    'F1': {
+        'width': [256, 512, 1536, 1536], 'depth': [2, 4, 12, 6],
+        'train_imsize': 224, 'test_imsize': 320,
+        'RA_level': '410', 'drop_rate': 0.3},
+    'F2': {
+        'width': [256, 512, 1536, 1536], 'depth': [3, 6, 18, 9],
+        'train_imsize': 256, 'test_imsize': 352,
+        'RA_level': '410', 'drop_rate': 0.4},
+    'F3': {
+        'width': [256, 512, 1536, 1536], 'depth': [4, 8, 24, 12],
+        'train_imsize': 320, 'test_imsize': 416,
+        'RA_level': '415', 'drop_rate': 0.4},
+    'F4': {
+        'width': [256, 512, 1536, 1536], 'depth': [5, 10, 30, 15],
+        'train_imsize': 384, 'test_imsize': 512,
+        'RA_level': '415', 'drop_rate': 0.5},
+    'F5': {
+        'width': [256, 512, 1536, 1536], 'depth': [6, 12, 36, 18],
+        'train_imsize': 416, 'test_imsize': 544,
+        'RA_level': '415', 'drop_rate': 0.5},
+    'F6': {
+        'width': [256, 512, 1536, 1536], 'depth': [7, 14, 42, 21],
+        'train_imsize': 448, 'test_imsize': 576,
+        'RA_level': '415', 'drop_rate': 0.5},
+    'F7': {
+        'width': [256, 512, 1536, 1536], 'depth': [8, 16, 48, 24],
+        'train_imsize': 480, 'test_imsize': 608,
+        'RA_level': '415', 'drop_rate': 0.5},
+})
+
+# Minor variants FN+, slightly wider
+nfnet_params.update(**{
+    **{f'{key}+': {**nfnet_params[key], 'width': [384, 768, 2048, 2048],}
+       for key in nfnet_params}
+})
 
 # Nonlinearities with magic constants (gamma) baked in.
 # Note that not all nonlinearities will be stable, especially if they are
@@ -122,7 +166,7 @@ def signal_metrics(x, i):
 
 def count_conv_flops(in_ch, conv, h, w):
   """For a conv layer with in_ch inputs, count the FLOPS."""
-  # How many output floats are we producing?
+  # How many outputs are we producing? Note this is wrong for VALID padding.
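+  # With SAME padding the count is ceil(h / stride_h) * ceil(w / stride_w),
+  # approximated below as h * w / prod(stride); with VALID padding it would
+  # be (floor((h - kh) / stride_h) + 1) * (floor((w - kw) / stride_w) + 1).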
output_shape = conv.output_channels * (h * w) / np.prod(conv.stride) # At each OHW location we do computation equal to (I//G) * kh * kw flop_per_loc = (in_ch / conv.feature_group_count) diff --git a/nfnets/experiment_nfnets.py b/nfnets/experiment_nfnets.py new file mode 100644 index 0000000..1dad04c --- /dev/null +++ b/nfnets/experiment_nfnets.py @@ -0,0 +1,126 @@ +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""ImageNet experiment with NFNets.""" + +import haiku as hk +from ml_collections import config_dict +from nfnets import experiment +from nfnets import optim + + +def get_config(): + """Return config object for training.""" + config = experiment.get_config() + + # Experiment config. + train_batch_size = 4096 # Global batch size. + images_per_epoch = 1281167 + num_epochs = 360 + steps_per_epoch = images_per_epoch / train_batch_size + config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size) + config.random_seed = 0 + + config.experiment_kwargs = config_dict.ConfigDict( + dict( + config=dict( + lr=0.1, + num_epochs=num_epochs, + label_smoothing=0.1, + model='NFNet', + image_size=224, + use_ema=True, + ema_decay=0.99999, + ema_start=0, + augment_name=None, + augment_before_mix=False, + eval_preproc='resize_crop_32', + train_batch_size=train_batch_size, + eval_batch_size=50, + eval_subset='test', + num_classes=1000, + which_dataset='imagenet', + which_loss='softmax_cross_entropy', # One of softmax or sigmoid + bfloat16=True, + lr_schedule=dict( + name='WarmupCosineDecay', + kwargs=dict(num_steps=config.training_steps, + start_val=0, + min_val=0.0, + warmup_steps=5*steps_per_epoch), + ), + lr_scale_by_bs=True, + optimizer=dict( + name='SGD_AGC', + kwargs={'momentum': 0.9, 'nesterov': True, + 'weight_decay': 2e-5, + 'clipping': 0.01, 'eps': 1e-3}, + ), + model_kwargs=dict( + variant='F0', + width=1.0, + se_ratio=0.5, + alpha=0.2, + stochdepth_rate=0.25, + drop_rate=None, # Use native drop-rate + activation='gelu', + final_conv_mult=2, + final_conv_ch=None, + use_two_convs=True, + ), + ))) + + # Unlike NF-RegNets, use the same weight decay for all, but vary RA levels + variant = config.experiment_kwargs.config.model_kwargs.variant + # RandAugment levels (e.g. 
405 = 4 layers, magnitude 5, 205 = 2 layers, mag 5) + augment = {'F0': '405', 'F1': '410', 'F2': '410', 'F3': '415', + 'F4': '415', 'F5': '415', 'F6': '415', 'F7': '415'}[variant] + aug_base_name = 'cutmix_mixup_randaugment' + config.experiment_kwargs.config.augment_name = f'{aug_base_name}_{augment}' + + return config + + +class Experiment(experiment.Experiment): + """Experiment with correct parameter filtering for applying AGC.""" + + def _make_opt(self): + # Separate conv params and gains/biases + def pred_gb(mod, name, val): + del mod, val + return (name in ['scale', 'offset', 'b'] + or 'gain' in name or 'bias' in name) + gains_biases, weights = hk.data_structures.partition(pred_gb, self._params) + def pred_fc(mod, name, val): + del name, val + return 'linear' in mod and 'squeeze_excite' not in mod + fc_weights, weights = hk.data_structures.partition(pred_fc, weights) + # Lr schedule with batch-based LR scaling + if self.config.lr_scale_by_bs: + max_lr = (self.config.lr * self.config.train_batch_size) / 256 + else: + max_lr = self.config.lr + lr_sched_fn = getattr(optim, self.config.lr_schedule.name) + lr_schedule = lr_sched_fn(max_val=max_lr, **self.config.lr_schedule.kwargs) + # Optimizer; no need to broadcast! + opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()} + opt_kwargs['lr'] = lr_schedule + opt_module = getattr(optim, self.config.optimizer.name) + self.opt = opt_module([{'params': gains_biases, 'weight_decay': None,}, + {'params': fc_weights, 'clipping': None}, + {'params': weights}], **opt_kwargs) + if self._opt_state is None: + self._opt_state = self.opt.states() + else: + self.opt.plugin(self._opt_state) diff --git a/nfnets/nfnet.py b/nfnets/nfnet.py new file mode 100644 index 0000000..30ade03 --- /dev/null +++ b/nfnets/nfnet.py @@ -0,0 +1,272 @@ +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Norm-Free Nets.""" +# pylint: disable=unused-import +# pylint: disable=invalid-name + +import functools +import haiku as hk +import jax +import jax.numpy as jnp +import jax.random as jrandom +import numpy as np + + +from nfnets import base + + +class NFNet(hk.Module): + """Normalizer-Free Networks with an improved architecture. + + References: + [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image + Recognition Without Normalization. 
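+
+  The `variant` argument (one of 'F0'..'F7', or the slightly wider
+  'F0+'..'F7+' versions) selects the depth, width, and train/test image-size
+  configuration from base.nfnet_params.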
+ """ + + variant_dict = base.nfnet_params + + def __init__(self, num_classes, variant='F0', + width=1.0, se_ratio=0.5, + alpha=0.2, stochdepth_rate=0.1, drop_rate=None, + activation='gelu', fc_init=None, + final_conv_mult=2, final_conv_ch=None, + use_two_convs=True, + name='NFNet'): + super().__init__(name=name) + self.num_classes = num_classes + self.variant = variant + self.width = width + self.se_ratio = se_ratio + # Get variant info + block_params = self.variant_dict[self.variant] + self.train_imsize = block_params['train_imsize'] + self.test_imsize = block_params['test_imsize'] + self.width_pattern = block_params['width'] + self.depth_pattern = block_params['depth'] + self.bneck_pattern = block_params.get('expansion', [0.5] * 4) + self.group_pattern = block_params.get('group_width', [128] * 4) + self.big_pattern = block_params.get('big_width', [True] * 4) + self.activation = base.nonlinearities[activation] + if drop_rate is None: + self.drop_rate = block_params['drop_rate'] + else: + self.drop_rate = drop_rate + self.which_conv = base.WSConv2D + # Stem + ch = self.width_pattern[0] // 2 + self.stem = hk.Sequential([ + self.which_conv(16, kernel_shape=3, stride=2, + padding='SAME', name='stem_conv0'), + self.activation, + self.which_conv(32, kernel_shape=3, stride=1, + padding='SAME', name='stem_conv1'), + self.activation, + self.which_conv(64, kernel_shape=3, stride=1, + padding='SAME', name='stem_conv2'), + self.activation, + self.which_conv(ch, kernel_shape=3, stride=2, + padding='SAME', name='stem_conv3'), + ]) + + # Body + self.blocks = [] + expected_std = 1.0 + num_blocks = sum(self.depth_pattern) + index = 0 # Overall block index + stride_pattern = [1, 2, 2, 2] + block_args = zip(self.width_pattern, self.depth_pattern, self.bneck_pattern, + self.group_pattern, self.big_pattern, stride_pattern) + for (block_width, stage_depth, expand_ratio, + group_size, big_width, stride) in block_args: + for block_index in range(stage_depth): + # Scalar pre-multiplier so each block sees an N(0,1) input at init + beta = 1./ expected_std + # Block stochastic depth drop-rate + block_stochdepth_rate = stochdepth_rate * index / num_blocks + out_ch = (int(block_width * self.width)) + self.blocks += [NFBlock(ch, out_ch, + expansion=expand_ratio, se_ratio=se_ratio, + group_size=group_size, + stride=stride if block_index == 0 else 1, + beta=beta, alpha=alpha, + activation=self.activation, + which_conv=self.which_conv, + stochdepth_rate=block_stochdepth_rate, + big_width=big_width, + use_two_convs=use_two_convs, + )] + ch = out_ch + index += 1 + # Reset expected std but still give it 1 block of growth + if block_index == 0: + expected_std = 1.0 + expected_std = (expected_std **2 + alpha**2)**0.5 + + # Head + if final_conv_mult is None: + if final_conv_ch is None: + raise ValueError('Must provide one of final_conv_mult or final_conv_ch') + ch = final_conv_ch + else: + ch = int(final_conv_mult * ch) + self.final_conv = self.which_conv(ch, kernel_shape=1, + padding='SAME', name='final_conv') + # By default, initialize with N(0, 0.01) + if fc_init is None: + fc_init = hk.initializers.RandomNormal(mean=0, stddev=0.01) + self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True) + + def __call__(self, x, is_training=True, return_metrics=False): + """Return the output of the final layer without any [log-]softmax.""" + # Stem + outputs = {} + out = self.stem(x) + if return_metrics: + outputs.update(base.signal_metrics(out, 0)) + # Blocks + for i, block in enumerate(self.blocks): + out, res_avg_var = 
block(out, is_training=is_training)
+      if return_metrics:
+        outputs.update(base.signal_metrics(out, i + 1))
+        outputs[f'res_avg_var_{i}'] = res_avg_var
+    # Final-conv->activation, pool, dropout, classify
+    out = self.activation(self.final_conv(out))
+    pool = jnp.mean(out, [1, 2])
+    outputs['pool'] = pool
+    # Optionally apply dropout
+    if self.drop_rate > 0.0 and is_training:
+      pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
+    outputs['logits'] = self.fc(pool)
+    return outputs
+
+  def count_flops(self, h, w):
+    flops = []
+    ch = 3
+    for module in self.stem.layers:
+      if isinstance(module, hk.Conv2D):
+        flops += [base.count_conv_flops(ch, module, h, w)]
+        if any([item > 1 for item in module.stride]):
+          h, w = h / module.stride[0], w / module.stride[1]
+        ch = module.output_channels
+    # Body FLOPs
+    for block in self.blocks:
+      flops += [block.count_flops(h, w)]
+      if block.stride > 1:
+        h, w = h / block.stride, w / block.stride
+    # Head module FLOPs
+    out_ch = self.blocks[-1].out_ch
+    flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
+    # Count flops for classifier
+    flops += [self.final_conv.output_channels * self.fc.output_size]
+    return flops, sum(flops)
+
+
+class NFBlock(hk.Module):
+  """Normalizer-Free Net Block."""
+
+  def __init__(self, in_ch, out_ch, expansion=0.5, se_ratio=0.5,
+               kernel_shape=3, group_size=128, stride=1,
+               beta=1.0, alpha=0.2,
+               which_conv=base.WSConv2D, activation=jax.nn.gelu,
+               big_width=True, use_two_convs=True,
+               stochdepth_rate=None, name=None):
+    super().__init__(name=name)
+    self.in_ch, self.out_ch = in_ch, out_ch
+    self.expansion = expansion
+    self.se_ratio = se_ratio
+    self.kernel_shape = kernel_shape
+    self.activation = activation
+    self.beta, self.alpha = beta, alpha
+    # Mimic ResNet-style big-width scaling?
+    width = int((self.out_ch if big_width else self.in_ch) * expansion)
+    # Round expanded width based on group count
+    self.groups = width // group_size
+    self.width = group_size * self.groups
+    self.stride = stride
+    self.use_two_convs = use_two_convs
+    # Conv 0 (typically expansion conv)
+    self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
+                            name='conv0')
+    # Grouped NxN conv
+    self.conv1 = which_conv(self.width, kernel_shape=kernel_shape,
+                            stride=stride, padding='SAME',
+                            feature_group_count=self.groups, name='conv1')
+    if self.use_two_convs:
+      self.conv1b = which_conv(self.width, kernel_shape=kernel_shape,
+                               stride=1, padding='SAME',
+                               feature_group_count=self.groups, name='conv1b')
+    # Conv 2, typically projection conv
+    self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
+                            name='conv2')
+    # Use shortcut conv on channel change or downsample.
+    self.use_projection = stride > 1 or self.in_ch != self.out_ch
+    if self.use_projection:
+      self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
+                                      padding='SAME', name='conv_shortcut')
+    # Squeeze + Excite Module
+    self.se = base.SqueezeExcite(self.out_ch, self.out_ch, self.se_ratio)
+
+    # Are we using stochastic depth?
+    self._has_stochdepth = (stochdepth_rate is not None and
+                            stochdepth_rate > 0. and stochdepth_rate < 1.0)
+    if self._has_stochdepth:
+      self.stoch_depth = base.StochDepth(stochdepth_rate)
+
+  def __call__(self, x, is_training):
+    out = self.activation(x) * self.beta
+    if self.stride > 1:  # Average-pool downsample.
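+      # The shortcut is taken from the activated, beta-scaled input (`out`),
+      # pooled here, and then projected by conv_shortcut below if required.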
+ shortcut = hk.avg_pool(out, window_shape=(1, 2, 2, 1), + strides=(1, 2, 2, 1), padding='SAME') + if self.use_projection: + shortcut = self.conv_shortcut(shortcut) + elif self.use_projection: + shortcut = self.conv_shortcut(out) + else: + shortcut = x + out = self.conv0(out) + out = self.conv1(self.activation(out)) + if self.use_two_convs: + out = self.conv1b(self.activation(out)) + out = self.conv2(self.activation(out)) + out = (self.se(out) * 2) * out # Multiply by 2 for rescaling + # Get average residual standard deviation for reporting metrics. + res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2])) + # Apply stochdepth if applicable. + if self._has_stochdepth: + out = self.stoch_depth(out, is_training) + # SkipInit Gain + out = out * hk.get_parameter('skip_gain', (), out.dtype, init=jnp.zeros) + return out * self.alpha + shortcut, res_avg_var + + def count_flops(self, h, w): + # Count conv FLOPs based on input HW + expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w) + # If block is strided we decrease resolution here. + dw_flops = base.count_conv_flops(self.width, self.conv1, h, w) + if self.stride > 1: + h, w = h / self.stride, w / self.stride + if self.use_two_convs: + dw_flops += base.count_conv_flops(self.width, self.conv1b, h, w) + + if self.use_projection: + sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w) + else: + sc_flops = 0 + # SE flops happen on avg-pooled activations + se_flops = self.se.fc0.output_size * self.out_ch + se_flops += self.se.fc0.output_size * self.se.fc1.output_size + contract_flops = base.count_conv_flops(self.width, self.conv2, h, w) + return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops]) + diff --git a/nfnets/optim.py b/nfnets/optim.py index 4836d96..0f1e452 100644 --- a/nfnets/optim.py +++ b/nfnets/optim.py @@ -429,6 +429,68 @@ class Fromage(Optimizer): return out, state +def compute_norm(x, axis, keepdims): + """Returns norm over arbitrary axis.""" + norm = jnp.sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5 + return norm + + +def unitwise_norm(x): + """Computes norms of each output unit separately, assuming (HW)IO weights.""" + if len(jnp.squeeze(x).shape) <= 1: # Scalars and vectors + axis = None + keepdims = False + elif len(x.shape) in [2, 3]: # Linear layers of shape IO + axis = 0 + keepdims = True + elif len(x.shape) == 4: # Conv kernels of shape HWIO + axis = [0, 1, 2,] + keepdims = True + else: + raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}') + return compute_norm(x, axis, keepdims) + + +class SGD_AGC(Optimizer): # pylint:disable=invalid-name + """SGD with Unit-Adaptive Gradient-Clipping. + + References: + [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image + Recognition Without Normalization. 
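+
+  Passing clipping=None for a parameter group disables AGC for that group;
+  experiment_nfnets.py uses this to exempt the final linear layer.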
+ """ + defaults = {'weight_decay': None, 'momentum': None, 'dampening': 0, + 'nesterov': None, 'clipping': 0.01, 'eps': 1e-3} + + def __init__(self, params, lr, weight_decay=None, + momentum=None, dampening=0, nesterov=None, + clipping=0.01, eps=1e-3): + super().__init__( + params, defaults={'lr': lr, 'weight_decay': weight_decay, + 'momentum': momentum, 'dampening': dampening, + 'clipping': clipping, 'nesterov': nesterov, + 'eps': eps}) + + def create_buffers(self, name, param): + return SGD.create_buffers(self, name, param) + + def update_param(self, param, grad, state, opt_params): + """Clips grads if necessary, then applies the optimizer update.""" + if param is None: + return param, state + if opt_params['clipping'] is not None: + param_norm = jnp.maximum(unitwise_norm(param), opt_params['eps']) + grad_norm = unitwise_norm(grad) + max_norm = param_norm * opt_params['clipping'] + # If grad norm > clipping * param_norm, rescale + trigger = grad_norm > max_norm + # Note the max(||G||, 1e-6) is technically unnecessary here, as + # the clipping shouldn't trigger if the grad norm is zero, + # but we include it in practice as a "just-in-case". + clipped_grad = grad * (max_norm / jnp.maximum(grad_norm, 1e-6)) + grad = jnp.where(trigger, clipped_grad, grad) + return SGD.update_param(self, param, grad, state, opt_params) + + class Hybrid(Optimizer): """Optimizer which permits passing param groups with different base opts. diff --git a/nfnets/test.py b/nfnets/test.py index 262aaf6..4860a79 100644 --- a/nfnets/test.py +++ b/nfnets/test.py @@ -16,20 +16,47 @@ import jax import jax.numpy as jnp from nfnets import experiment +from nfnets import experiment_nfnets -config = experiment.get_config() -exp_config = config.experiment_kwargs.config -exp_config.train_batch_size = 2 -exp_config.eval_batch_size = 2 -exp_config.lr = 0.1 -exp_config.fake_data = True -exp_config.model_kwargs.width = 2 -print(exp_config.model_kwargs) -xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0)) -bcast = jax.pmap(lambda x: x) -global_step = bcast(jnp.zeros(jax.local_device_count())) -rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count())) -print('Taking a single experiment step for test purposes!') -result = xp.step(global_step, rng) -print(f'Step successfully taken, resulting metrics are {result}') +def test_experiment(): + """Tests the main experiment.""" + config = experiment.get_config() + exp_config = config.experiment_kwargs.config + exp_config.train_batch_size = 2 + exp_config.eval_batch_size = 2 + exp_config.lr = 0.1 + exp_config.fake_data = True + exp_config.model_kwargs.width = 2 + print(exp_config.model_kwargs) + + xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0)) + bcast = jax.pmap(lambda x: x) + global_step = bcast(jnp.zeros(jax.local_device_count())) + rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count())) + print('Taking a single experiment step for test purposes!') + result = xp.step(global_step, rng) + print(f'Step successfully taken, resulting metrics are {result}') + + +def test_nfnet_experiment(): + """Tests the NFNet experiment.""" + config = experiment_nfnets.get_config() + exp_config = config.experiment_kwargs.config + exp_config.train_batch_size = 2 + exp_config.eval_batch_size = 2 + exp_config.lr = 0.1 + exp_config.fake_data = True + exp_config.model_kwargs.width = 2 + print(exp_config.model_kwargs) + + xp = experiment_nfnets.Experiment('train', exp_config, jax.random.PRNGKey(0)) + bcast = jax.pmap(lambda x: x) + 
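+  # Make per-device copies of the global step and RNG key, matching the
+  # pmapped inputs that xp.step consumes.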
global_step = bcast(jnp.zeros(jax.local_device_count())) + rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count())) + print('Taking a single NFNet experiment step for test purposes!') + result = xp.step(global_step, rng) + print(f'NFNet Step successfully taken, resulting metrics are {result}') + +test_experiment() +test_nfnet_experiment()
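
For readers wiring this up outside the included experiment, here is a minimal
sketch (not part of the patch) of chaining the `adaptive_grad_clip` transform
with plain SGD via Optax. The toy parameter shapes, learning rate, and loss are
illustrative, not values from the paper:

```
import jax
import jax.numpy as jnp
import optax

from nfnets import agc_optax

# Toy parameters: a conv-style HWIO kernel and a bias.
params = {'w': jnp.ones((3, 3, 8, 16)), 'b': jnp.zeros((16,))}

# AGC composes with other transforms via optax.chain.
opt = optax.chain(agc_optax.adaptive_grad_clip(clip=0.01, eps=1e-3),
                  optax.sgd(learning_rate=0.1))
opt_state = opt.init(params)

grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2))(params)
# adaptive_grad_clip computes unit-wise parameter norms, so the current
# params must be passed through to update().
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```

Similarly, a sketch of a single NFNet forward pass under Haiku; the variant and
input shape are arbitrary choices:

```
import haiku as hk
import jax
import jax.numpy as jnp

from nfnets import nfnet

def forward(x, is_training):
  model = nfnet.NFNet(num_classes=1000, variant='F0')
  return model(x, is_training=is_training)['logits']

net = hk.transform(forward)
x = jnp.zeros((1, 224, 224, 3))
rng = jax.random.PRNGKey(0)
params = net.init(rng, x, is_training=True)
# An rng is required at apply time because dropout and stochastic depth
# draw keys during training.
logits = net.apply(params, rng, x, is_training=True)  # Shape (1, 1000).
```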