mirror of https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00

Initial commit for NFNets and AGC

PiperOrigin-RevId: 357022514

committed by Louise Deason
parent abc9ba0635
commit 68a1754d29

+21 -4
@@ -1,7 +1,9 @@
-# Code for Normalizer-Free ResNets (Brock, De, Smith, ICLR2021)
+# Code for Normalizer-Free Networks
 
 This repository contains code for the ICLR 2021 paper
 "Characterizing signal propagation to close the performance gap in unnormalized
-ResNets," by Andrew Brock, Soham De, and Samuel L. Smith.
+ResNets," by Andrew Brock, Soham De, and Samuel L. Smith, and the arXiv preprint
+"High-Performance Large-Scale Image Recognition Without Normalization" by
+Andrew Brock, Soham De, Samuel L. Smith, and Karen Simonyan.
 
 ## Running this code
@@ -16,14 +18,29 @@ used in dataset.py in order to train on ImageNet.
 
 ## Giving Credit
 
-If you use this code in your work, we ask you to cite this paper:
+If you use this code in your work, we ask you to please cite one or both of the
+following papers.
+
+The reference for the Normalizer-Free structure and NF-ResNets or NF-Regnets:
 
 ```
 @inproceedings{brock2021characterizing,
-  author={Andrew Brock and Soham De and Samuel L. SMith},
+  author={Andrew Brock and Soham De and Samuel L. Smith},
   title={Characterizing signal propagation to close the performance gap in
   unnormalized ResNets},
   booktitle={9th International Conference on Learning Representations, {ICLR}},
   year={2021}
 }
 ```
+
+The reference for Adaptive Gradient Clipping (AGC) and the NFNets models:
+
+```
+@article{brock2021high,
+  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
+  title={High-Performance Large-Scale Image Recognition Without Normalization},
+  journal={arXiv preprint arXiv:},
+  year={2021}
+}
+```
@@ -0,0 +1,78 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adaptive gradient clipping transform for Optax."""
import jax
import jax.numpy as jnp
import optax


def compute_norm(x, axis, keepdims):
  """Axis-wise euclidean norm."""
  return jnp.sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5


def unitwise_norm(x):
  """Compute norms of each output unit separately, also for linear layers."""
  if len(jnp.squeeze(x).shape) <= 1:  # Scalars and vectors
    axis = None
    keepdims = False
  elif len(x.shape) in [2, 3]:  # Linear layers of shape IO or multihead linear
    axis = 0
    keepdims = True
  elif len(x.shape) == 4:  # Conv kernels of shape HWIO
    axis = [0, 1, 2,]
    keepdims = True
  else:
    raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
  return compute_norm(x, axis, keepdims)


def my_clip(g_norm, max_norm, grad):
  """Applies gradient clipping unit-wise."""
  trigger = g_norm < max_norm
  # This little max(., 1e-6) is distinct from the normal eps and just prevents
  # division by zero. It technically should be impossible to engage.
  clipped_grad = grad * (max_norm / jnp.maximum(g_norm, 1e-6))
  return jnp.where(trigger, grad, clipped_grad)


def adaptive_grad_clip(clip, eps=1e-3) -> optax.GradientTransformation:
  """Clip updates to be at most clip * parameter_norm.

  References:
    [Brock, De, Smith, Simonyan 2021] High-Performance Large-Scale Image
    Recognition Without Normalization.

  Args:
    clip: Maximum allowed ratio of update norm to parameter norm.
    eps: Epsilon term to prevent clipping of zero-initialized params.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return optax.ClipByGlobalNormState()

  def update_fn(updates, state, params):
    g_norm = jax.tree_map(unitwise_norm, updates)
    p_norm = jax.tree_map(unitwise_norm, params)
    # Maximum allowable norm.
    max_norm = jax.tree_map(lambda x: clip * jnp.maximum(x, eps), p_norm)
    # If grad norm > clipping * param_norm, rescale.
    updates = jax.tree_multimap(my_clip, g_norm, max_norm, updates)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)
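For context, a minimal sketch of how this transform would typically be composed in an optax training step. The module path `agc_optax` and the toy parameter/gradient pytrees are assumptions for illustration, not part of this commit; 0.01 matches the default clipping factor used in this commit's config.

```
import jax.numpy as jnp
import optax
from nfnets import agc_optax  # assumed module path for the file above

# Toy parameter/gradient pytrees just to exercise the transform.
params = {'w': jnp.ones([3, 2]), 'b': jnp.zeros([2])}
grads = {'w': jnp.full([3, 2], 10.0), 'b': jnp.full([2], 10.0)}

# Chain unit-wise AGC with plain SGD.
optimizer = optax.chain(
    agc_optax.adaptive_grad_clip(clip=0.01, eps=1e-3),
    optax.sgd(learning_rate=0.1, momentum=0.9, nesterov=True),
)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```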
+45 -1
@@ -50,6 +50,50 @@ nf_regnet_params = {
        'drop_rate': 0.5},
}

nfnet_params = {}


# F-series models
nfnet_params.update(**{
    'F0': {
        'width': [256, 512, 1536, 1536], 'depth': [1, 2, 6, 3],
        'train_imsize': 192, 'test_imsize': 256,
        'RA_level': '405', 'drop_rate': 0.2},
    'F1': {
        'width': [256, 512, 1536, 1536], 'depth': [2, 4, 12, 6],
        'train_imsize': 224, 'test_imsize': 320,
        'RA_level': '410', 'drop_rate': 0.3},
    'F2': {
        'width': [256, 512, 1536, 1536], 'depth': [3, 6, 18, 9],
        'train_imsize': 256, 'test_imsize': 352,
        'RA_level': '410', 'drop_rate': 0.4},
    'F3': {
        'width': [256, 512, 1536, 1536], 'depth': [4, 8, 24, 12],
        'train_imsize': 320, 'test_imsize': 416,
        'RA_level': '415', 'drop_rate': 0.4},
    'F4': {
        'width': [256, 512, 1536, 1536], 'depth': [5, 10, 30, 15],
        'train_imsize': 384, 'test_imsize': 512,
        'RA_level': '415', 'drop_rate': 0.5},
    'F5': {
        'width': [256, 512, 1536, 1536], 'depth': [6, 12, 36, 18],
        'train_imsize': 416, 'test_imsize': 544,
        'RA_level': '415', 'drop_rate': 0.5},
    'F6': {
        'width': [256, 512, 1536, 1536], 'depth': [7, 14, 42, 21],
        'train_imsize': 448, 'test_imsize': 576,
        'RA_level': '415', 'drop_rate': 0.5},
    'F7': {
        'width': [256, 512, 1536, 1536], 'depth': [8, 16, 48, 24],
        'train_imsize': 480, 'test_imsize': 608,
        'RA_level': '415', 'drop_rate': 0.5},
})

# Minor variants FN+, slightly wider
nfnet_params.update(**{
    **{f'{key}+': {**nfnet_params[key], 'width': [384, 768, 2048, 2048],}
       for key in nfnet_params}
})

# Nonlinearities with magic constants (gamma) baked in.
# Note that not all nonlinearities will be stable, especially if they are
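To make the F-series "+" construction above concrete, a quick illustrative check (hypothetical snippet, not part of the commit):

```
from nfnets import base

# 'F0+' inherits everything from 'F0' except the widened channel pattern.
print(base.nfnet_params['F0']['width'])   # [256, 512, 1536, 1536]
print(base.nfnet_params['F0+']['width'])  # [384, 768, 2048, 2048]
print(base.nfnet_params['F0+']['depth'])  # [1, 2, 6, 3], inherited from 'F0'
```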
@@ -122,7 +166,7 @@ def signal_metrics(x, i):

def count_conv_flops(in_ch, conv, h, w):
  """For a conv layer with in_ch inputs, count the FLOPS."""
-  # How many output floats are we producing?
+  # How many outputs are we producing? Note this is wrong for VALID padding.
  output_shape = conv.output_channels * (h * w) / np.prod(conv.stride)
  # At each OHW location we do computation equal to (I//G) * kh * kw
  flop_per_loc = (in_ch / conv.feature_group_count)
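As a rough worked example of the counting scheme in `count_conv_flops` (illustrative numbers, not from the commit):

```
# 3x3 conv, 64 -> 128 channels, stride 1, no grouping, on a 56x56 feature map.
in_ch, out_ch, kh, kw, h, w, groups, stride = 64, 128, 3, 3, 56, 56, 1, 1
output_shape = out_ch * (h * w) / stride**2   # number of output floats
flop_per_loc = (in_ch / groups) * kh * kw     # multiply-adds per output
print(output_shape * flop_per_loc)            # ~2.31e8 multiply-adds
```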
@@ -0,0 +1,126 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""ImageNet experiment with NFNets."""

import haiku as hk
from ml_collections import config_dict
from nfnets import experiment
from nfnets import optim


def get_config():
  """Return config object for training."""
  config = experiment.get_config()

  # Experiment config.
  train_batch_size = 4096  # Global batch size.
  images_per_epoch = 1281167
  num_epochs = 360
  steps_per_epoch = images_per_epoch / train_batch_size
  config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
  config.random_seed = 0

  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              lr=0.1,
              num_epochs=num_epochs,
              label_smoothing=0.1,
              model='NFNet',
              image_size=224,
              use_ema=True,
              ema_decay=0.99999,
              ema_start=0,
              augment_name=None,
              augment_before_mix=False,
              eval_preproc='resize_crop_32',
              train_batch_size=train_batch_size,
              eval_batch_size=50,
              eval_subset='test',
              num_classes=1000,
              which_dataset='imagenet',
              which_loss='softmax_cross_entropy',  # One of softmax or sigmoid
              bfloat16=True,
              lr_schedule=dict(
                  name='WarmupCosineDecay',
                  kwargs=dict(num_steps=config.training_steps,
                              start_val=0,
                              min_val=0.0,
                              warmup_steps=5*steps_per_epoch),
              ),
              lr_scale_by_bs=True,
              optimizer=dict(
                  name='SGD_AGC',
                  kwargs={'momentum': 0.9, 'nesterov': True,
                          'weight_decay': 2e-5,
                          'clipping': 0.01, 'eps': 1e-3},
              ),
              model_kwargs=dict(
                  variant='F0',
                  width=1.0,
                  se_ratio=0.5,
                  alpha=0.2,
                  stochdepth_rate=0.25,
                  drop_rate=None,  # Use native drop-rate
                  activation='gelu',
                  final_conv_mult=2,
                  final_conv_ch=None,
                  use_two_convs=True,
              ),
          )))

  # Unlike NF-RegNets, use the same weight decay for all, but vary RA levels
  variant = config.experiment_kwargs.config.model_kwargs.variant
  # RandAugment levels (e.g. 405 = 4 layers, magnitude 5, 205 = 2 layers, mag 5)
  augment = {'F0': '405', 'F1': '410', 'F2': '410', 'F3': '415',
             'F4': '415', 'F5': '415', 'F6': '415', 'F7': '415'}[variant]
  aug_base_name = 'cutmix_mixup_randaugment'
  config.experiment_kwargs.config.augment_name = f'{aug_base_name}_{augment}'

  return config


class Experiment(experiment.Experiment):
  """Experiment with correct parameter filtering for applying AGC."""

  def _make_opt(self):
    # Separate conv params and gains/biases
    def pred_gb(mod, name, val):
      del mod, val
      return (name in ['scale', 'offset', 'b']
              or 'gain' in name or 'bias' in name)

    gains_biases, weights = hk.data_structures.partition(pred_gb, self._params)

    def pred_fc(mod, name, val):
      del name, val
      return 'linear' in mod and 'squeeze_excite' not in mod

    fc_weights, weights = hk.data_structures.partition(pred_fc, weights)
    # Lr schedule with batch-based LR scaling
    if self.config.lr_scale_by_bs:
      max_lr = (self.config.lr * self.config.train_batch_size) / 256
    else:
      max_lr = self.config.lr
    lr_sched_fn = getattr(optim, self.config.lr_schedule.name)
    lr_schedule = lr_sched_fn(max_val=max_lr, **self.config.lr_schedule.kwargs)
    # Optimizer; no need to broadcast!
    opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
    opt_kwargs['lr'] = lr_schedule
    opt_module = getattr(optim, self.config.optimizer.name)
    self.opt = opt_module([{'params': gains_biases, 'weight_decay': None,},
                           {'params': fc_weights, 'clipping': None},
                           {'params': weights}], **opt_kwargs)
    if self._opt_state is None:
      self._opt_state = self.opt.states()
    else:
      self.opt.plugin(self._opt_state)
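For the defaults above, the batch-size-scaled peak learning rate and warmup length work out as follows (a quick sanity check, not part of the commit):

```
lr, train_batch_size = 0.1, 4096
images_per_epoch = 1281167
max_lr = (lr * train_batch_size) / 256                     # 1.6, peak LR for WarmupCosineDecay
warmup_steps = 5 * (images_per_epoch / train_batch_size)   # ~1564 steps, i.e. 5 epochs
print(max_lr, warmup_steps)
```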
+272
@@ -0,0 +1,272 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Norm-Free Nets."""
# pylint: disable=unused-import
# pylint: disable=invalid-name

import functools
import haiku as hk
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np

from nfnets import base


class NFNet(hk.Module):
  """Normalizer-Free Networks with an improved architecture.

  References:
    [Brock, De, Smith, Simonyan 2021] High-Performance Large-Scale Image
    Recognition Without Normalization.
  """

  variant_dict = base.nfnet_params

  def __init__(self, num_classes, variant='F0',
               width=1.0, se_ratio=0.5,
               alpha=0.2, stochdepth_rate=0.1, drop_rate=None,
               activation='gelu', fc_init=None,
               final_conv_mult=2, final_conv_ch=None,
               use_two_convs=True,
               name='NFNet'):
    super().__init__(name=name)
    self.num_classes = num_classes
    self.variant = variant
    self.width = width
    self.se_ratio = se_ratio
    # Get variant info
    block_params = self.variant_dict[self.variant]
    self.train_imsize = block_params['train_imsize']
    self.test_imsize = block_params['test_imsize']
    self.width_pattern = block_params['width']
    self.depth_pattern = block_params['depth']
    self.bneck_pattern = block_params.get('expansion', [0.5] * 4)
    self.group_pattern = block_params.get('group_width', [128] * 4)
    self.big_pattern = block_params.get('big_width', [True] * 4)
    self.activation = base.nonlinearities[activation]
    if drop_rate is None:
      self.drop_rate = block_params['drop_rate']
    else:
      self.drop_rate = drop_rate
    self.which_conv = base.WSConv2D
    # Stem
    ch = self.width_pattern[0] // 2
    self.stem = hk.Sequential([
        self.which_conv(16, kernel_shape=3, stride=2,
                        padding='SAME', name='stem_conv0'),
        self.activation,
        self.which_conv(32, kernel_shape=3, stride=1,
                        padding='SAME', name='stem_conv1'),
        self.activation,
        self.which_conv(64, kernel_shape=3, stride=1,
                        padding='SAME', name='stem_conv2'),
        self.activation,
        self.which_conv(ch, kernel_shape=3, stride=2,
                        padding='SAME', name='stem_conv3'),
    ])

    # Body
    self.blocks = []
    expected_std = 1.0
    num_blocks = sum(self.depth_pattern)
    index = 0  # Overall block index
    stride_pattern = [1, 2, 2, 2]
    block_args = zip(self.width_pattern, self.depth_pattern, self.bneck_pattern,
                     self.group_pattern, self.big_pattern, stride_pattern)
    for (block_width, stage_depth, expand_ratio,
         group_size, big_width, stride) in block_args:
      for block_index in range(stage_depth):
        # Scalar pre-multiplier so each block sees an N(0,1) input at init
        beta = 1. / expected_std
        # Block stochastic depth drop-rate
        block_stochdepth_rate = stochdepth_rate * index / num_blocks
        out_ch = (int(block_width * self.width))
        self.blocks += [NFBlock(ch, out_ch,
                                expansion=expand_ratio, se_ratio=se_ratio,
                                group_size=group_size,
                                stride=stride if block_index == 0 else 1,
                                beta=beta, alpha=alpha,
                                activation=self.activation,
                                which_conv=self.which_conv,
                                stochdepth_rate=block_stochdepth_rate,
                                big_width=big_width,
                                use_two_convs=use_two_convs,
                                )]
        ch = out_ch
        index += 1
        # Reset expected std but still give it 1 block of growth
        if block_index == 0:
          expected_std = 1.0
        expected_std = (expected_std ** 2 + alpha ** 2) ** 0.5

    # Head
    if final_conv_mult is None:
      if final_conv_ch is None:
        raise ValueError('Must provide one of final_conv_mult or final_conv_ch')
      ch = final_conv_ch
    else:
      ch = int(final_conv_mult * ch)
    self.final_conv = self.which_conv(ch, kernel_shape=1,
                                      padding='SAME', name='final_conv')
    # By default, initialize with N(0, 0.01)
    if fc_init is None:
      fc_init = hk.initializers.RandomNormal(mean=0, stddev=0.01)
    self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)

  def __call__(self, x, is_training=True, return_metrics=False):
    """Return the output of the final layer without any [log-]softmax."""
    # Stem
    outputs = {}
    out = self.stem(x)
    if return_metrics:
      outputs.update(base.signal_metrics(out, 0))
    # Blocks
    for i, block in enumerate(self.blocks):
      out, res_avg_var = block(out, is_training=is_training)
      if return_metrics:
        outputs.update(base.signal_metrics(out, i + 1))
        outputs[f'res_avg_var_{i}'] = res_avg_var
    # Final-conv->activation, pool, dropout, classify
    out = self.activation(self.final_conv(out))
    pool = jnp.mean(out, [1, 2])
    outputs['pool'] = pool
    # Optionally apply dropout
    if self.drop_rate > 0.0 and is_training:
      pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
    outputs['logits'] = self.fc(pool)
    return outputs

  def count_flops(self, h, w):
    flops = []
    ch = 3
    for module in self.stem.layers:
      if isinstance(module, hk.Conv2D):
        flops += [base.count_conv_flops(ch, module, h, w)]
        if any([item > 1 for item in module.stride]):
          h, w = h / module.stride[0], w / module.stride[1]
        ch = module.output_channels
    # Body FLOPs
    for block in self.blocks:
      flops += [block.count_flops(h, w)]
      if block.stride > 1:
        h, w = h / block.stride, w / block.stride
    # Head module FLOPs
    out_ch = self.blocks[-1].out_ch
    flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
    # Count flops for classifier
    flops += [self.final_conv.output_channels * self.fc.output_size]
    return flops, sum(flops)
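The `expected_std` bookkeeping in `__init__` above drives the scalar `beta` pre-multiplier: each residual branch adds `alpha**2` to the signal variance, and `beta = 1 / expected_std` rescales the block input back to unit variance at initialization. Unrolling the update for a hypothetical stage of depth 3 that starts at unit variance, with the default `alpha = 0.2`:

```
alpha = 0.2
expected_std = 1.0
betas = []
for block_index in range(3):
  betas.append(1. / expected_std)
  if block_index == 0:  # transition block resets the running estimate
    expected_std = 1.0
  expected_std = (expected_std ** 2 + alpha ** 2) ** 0.5
print(betas)  # [1.0, ~0.981, ~0.962]
```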
class NFBlock(hk.Module):
  """Normalizer-Free Net Block."""

  def __init__(self, in_ch, out_ch, expansion=0.5, se_ratio=0.5,
               kernel_shape=3, group_size=128, stride=1,
               beta=1.0, alpha=0.2,
               which_conv=base.WSConv2D, activation=jax.nn.gelu,
               big_width=True, use_two_convs=True,
               stochdepth_rate=None, name=None):
    super().__init__(name=name)
    self.in_ch, self.out_ch = in_ch, out_ch
    self.expansion = expansion
    self.se_ratio = se_ratio
    self.kernel_shape = kernel_shape
    self.activation = activation
    self.beta, self.alpha = beta, alpha
    # Mimic resnet style bigwidth scaling?
    width = int((self.out_ch if big_width else self.in_ch) * expansion)
    # Round expanded width based on group count
    self.groups = width // group_size
    self.width = group_size * self.groups
    self.stride = stride
    self.use_two_convs = use_two_convs
    # Conv 0 (typically expansion conv)
    self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
                            name='conv0')
    # Grouped NxN conv
    self.conv1 = which_conv(self.width, kernel_shape=kernel_shape,
                            stride=stride, padding='SAME',
                            feature_group_count=self.groups, name='conv1')
    if self.use_two_convs:
      self.conv1b = which_conv(self.width, kernel_shape=kernel_shape,
                               stride=1, padding='SAME',
                               feature_group_count=self.groups, name='conv1b')
    # Conv 2, typically projection conv
    self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
                            name='conv2')
    # Use shortcut conv on channel change or downsample.
    self.use_projection = stride > 1 or self.in_ch != self.out_ch
    if self.use_projection:
      self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
                                      padding='SAME', name='conv_shortcut')
    # Squeeze + Excite Module
    self.se = base.SqueezeExcite(self.out_ch, self.out_ch, self.se_ratio)

    # Are we using stochastic depth?
    self._has_stochdepth = (stochdepth_rate is not None and
                            stochdepth_rate > 0. and stochdepth_rate < 1.0)
    if self._has_stochdepth:
      self.stoch_depth = base.StochDepth(stochdepth_rate)

  def __call__(self, x, is_training):
    out = self.activation(x) * self.beta
    if self.stride > 1:  # Average-pool downsample.
      shortcut = hk.avg_pool(out, window_shape=(1, 2, 2, 1),
                             strides=(1, 2, 2, 1), padding='SAME')
      if self.use_projection:
        shortcut = self.conv_shortcut(shortcut)
    elif self.use_projection:
      shortcut = self.conv_shortcut(out)
    else:
      shortcut = x
    out = self.conv0(out)
    out = self.conv1(self.activation(out))
    if self.use_two_convs:
      out = self.conv1b(self.activation(out))
    out = self.conv2(self.activation(out))
    out = (self.se(out) * 2) * out  # Multiply by 2 for rescaling
    # Get average residual variance for reporting metrics.
    res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
    # Apply stochdepth if applicable.
    if self._has_stochdepth:
      out = self.stoch_depth(out, is_training)
    # SkipInit Gain
    out = out * hk.get_parameter('skip_gain', (), out.dtype, init=jnp.zeros)
    return out * self.alpha + shortcut, res_avg_var

  def count_flops(self, h, w):
    # Count conv FLOPs based on input HW
    expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
    # If block is strided we decrease resolution here.
    dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
    if self.stride > 1:
      h, w = h / self.stride, w / self.stride
    if self.use_two_convs:
      dw_flops += base.count_conv_flops(self.width, self.conv1b, h, w)

    if self.use_projection:
      sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
    else:
      sc_flops = 0
    # SE flops happen on avg-pooled activations
    se_flops = self.se.fc0.output_size * self.out_ch
    se_flops += self.se.fc0.output_size * self.se.fc1.output_size
    contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
    return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops])
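A minimal forward-pass sketch for the model defined above (hypothetical usage; the module name `nfnet` and the toy batch are assumptions, not part of the commit):

```
import haiku as hk
import jax
import jax.numpy as jnp

from nfnets import nfnet  # assumed module name for the file above


def forward(images, is_training):
  model = nfnet.NFNet(num_classes=1000, variant='F0')
  return model(images, is_training=is_training)['logits']

net = hk.transform_with_state(forward)
images = jnp.zeros([8, 192, 192, 3])  # F0 trains at 192x192 resolution
params, state = net.init(jax.random.PRNGKey(0), images, is_training=True)
logits, state = net.apply(params, state, jax.random.PRNGKey(1), images,
                          is_training=True)
print(logits.shape)  # (8, 1000)
```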
@@ -429,6 +429,68 @@ class Fromage(Optimizer):
    return out, state


def compute_norm(x, axis, keepdims):
  """Returns norm over arbitrary axis."""
  norm = jnp.sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5
  return norm


def unitwise_norm(x):
  """Computes norms of each output unit separately, assuming (HW)IO weights."""
  if len(jnp.squeeze(x).shape) <= 1:  # Scalars and vectors
    axis = None
    keepdims = False
  elif len(x.shape) in [2, 3]:  # Linear layers of shape IO
    axis = 0
    keepdims = True
  elif len(x.shape) == 4:  # Conv kernels of shape HWIO
    axis = [0, 1, 2,]
    keepdims = True
  else:
    raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
  return compute_norm(x, axis, keepdims)


class SGD_AGC(Optimizer):  # pylint:disable=invalid-name
  """SGD with Unit-Adaptive Gradient-Clipping.

  References:
    [Brock, De, Smith, Simonyan 2021] High-Performance Large-Scale Image
    Recognition Without Normalization.
  """
  defaults = {'weight_decay': None, 'momentum': None, 'dampening': 0,
              'nesterov': None, 'clipping': 0.01, 'eps': 1e-3}

  def __init__(self, params, lr, weight_decay=None,
               momentum=None, dampening=0, nesterov=None,
               clipping=0.01, eps=1e-3):
    super().__init__(
        params, defaults={'lr': lr, 'weight_decay': weight_decay,
                          'momentum': momentum, 'dampening': dampening,
                          'clipping': clipping, 'nesterov': nesterov,
                          'eps': eps})

  def create_buffers(self, name, param):
    return SGD.create_buffers(self, name, param)

  def update_param(self, param, grad, state, opt_params):
    """Clips grads if necessary, then applies the optimizer update."""
    if param is None:
      return param, state
    if opt_params['clipping'] is not None:
      param_norm = jnp.maximum(unitwise_norm(param), opt_params['eps'])
      grad_norm = unitwise_norm(grad)
      max_norm = param_norm * opt_params['clipping']
      # If grad norm > clipping * param_norm, rescale
      trigger = grad_norm > max_norm
      # Note the max(||G||, 1e-6) is technically unnecessary here, as
      # the clipping shouldn't trigger if the grad norm is zero,
      # but we include it in practice as a "just-in-case".
      clipped_grad = grad * (max_norm / jnp.maximum(grad_norm, 1e-6))
      grad = jnp.where(trigger, clipped_grad, grad)
    return SGD.update_param(self, param, grad, state, opt_params)


class Hybrid(Optimizer):
  """Optimizer which permits passing param groups with different base opts.
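To make the clipping rule concrete, a small worked example of the unit-wise update (illustrative numbers only):

```
import jax.numpy as jnp

# One output unit of a linear layer: parameter column w and its gradient g.
w = jnp.array([0.3, -0.4])   # ||w|| = 0.5
g = jnp.array([3.0, 4.0])    # ||g|| = 5.0
clipping, eps = 0.01, 1e-3

max_norm = jnp.maximum(jnp.linalg.norm(w), eps) * clipping  # 0.5 * 0.01 = 0.005
g_norm = jnp.linalg.norm(g)
# ||g|| / ||w|| = 10 > clipping, so the gradient is rescaled down to max_norm.
clipped = jnp.where(g_norm > max_norm, g * (max_norm / g_norm), g)
print(clipped)  # ~[0.003, 0.004]
```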
+42 -15
@@ -16,20 +16,47 @@
 import jax
 import jax.numpy as jnp
 from nfnets import experiment
+from nfnets import experiment_nfnets
 
-config = experiment.get_config()
-exp_config = config.experiment_kwargs.config
-exp_config.train_batch_size = 2
-exp_config.eval_batch_size = 2
-exp_config.lr = 0.1
-exp_config.fake_data = True
-exp_config.model_kwargs.width = 2
-print(exp_config.model_kwargs)
-
-xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
-bcast = jax.pmap(lambda x: x)
-global_step = bcast(jnp.zeros(jax.local_device_count()))
-rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
-print('Taking a single experiment step for test purposes!')
-result = xp.step(global_step, rng)
-print(f'Step successfully taken, resulting metrics are {result}')
+
+def test_experiment():
+  """Tests the main experiment."""
+  config = experiment.get_config()
+  exp_config = config.experiment_kwargs.config
+  exp_config.train_batch_size = 2
+  exp_config.eval_batch_size = 2
+  exp_config.lr = 0.1
+  exp_config.fake_data = True
+  exp_config.model_kwargs.width = 2
+  print(exp_config.model_kwargs)
+
+  xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
+  bcast = jax.pmap(lambda x: x)
+  global_step = bcast(jnp.zeros(jax.local_device_count()))
+  rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
+  print('Taking a single experiment step for test purposes!')
+  result = xp.step(global_step, rng)
+  print(f'Step successfully taken, resulting metrics are {result}')
+
+
+def test_nfnet_experiment():
+  """Tests the NFNet experiment."""
+  config = experiment_nfnets.get_config()
+  exp_config = config.experiment_kwargs.config
+  exp_config.train_batch_size = 2
+  exp_config.eval_batch_size = 2
+  exp_config.lr = 0.1
+  exp_config.fake_data = True
+  exp_config.model_kwargs.width = 2
+  print(exp_config.model_kwargs)
+
+  xp = experiment_nfnets.Experiment('train', exp_config, jax.random.PRNGKey(0))
+  bcast = jax.pmap(lambda x: x)
+  global_step = bcast(jnp.zeros(jax.local_device_count()))
+  rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
+  print('Taking a single NFNet experiment step for test purposes!')
+  result = xp.step(global_step, rng)
+  print(f'NFNet Step successfully taken, resulting metrics are {result}')
+
+
+test_experiment()
+test_nfnet_experiment()