Initial release of "nfnets".

PiperOrigin-RevId: 356975781
Louise Deason
2021-02-11 16:04:31 +00:00
parent 0909ded4c7
commit abc9ba0635
17 changed files with 3570 additions and 0 deletions


@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
## Projects
* [Characterizing signal propagation to close the performance gap in unnormalized ResNets](nfnets), ICLR 2021
* [Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples](adversarial_robustness)
* [Functional Regularisation for Continual Learning](functional_regularisation_for_continual_learning), ICLR 2020
* [Self-Supervised MultiModal Versatile Networks](mmv), NeurIPS 2020

29
nfnets/README.md Normal file

@@ -0,0 +1,29 @@
# Code for Normalizer-Free ResNets (Brock, De, Smith, ICLR 2021)
This repository contains code for the ICLR 2021 paper
"Characterizing signal propagation to close the performance gap in unnormalized
ResNets," by Andrew Brock, Soham De, and Samuel L. Smith.
## Running this code
Install the dependencies with `pip install -r requirements.txt`, then use one of
the `experiment.py` files in combination with
[JAXline](https://github.com/deepmind/jaxline) to train models. Optionally, copy
`test.py` into a directory one level up and run it to check that you can take a
single experiment step with fake data.
Note that you will need a local copy of ImageNet, compatible with the TFDS format
used in `dataset.py`, in order to train on ImageNet.
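As a quick sanity check without real data, the following snippet (a minimal
sketch, not part of the release scripts; it assumes the repository root is on
your `PYTHONPATH`) builds an NF-RegNet-B0 with Haiku and runs one forward pass
on a fake NHWC batch:
```python
import haiku as hk
import jax
import jax.numpy as jnp

from nfnets import nf_regnet


def forward(images, is_training):
  model = nf_regnet.NF_RegNet(num_classes=1000, variant='B0')
  return model(images, is_training=is_training)['logits']


net = hk.transform_with_state(forward)
images = jnp.zeros((2, 192, 192, 3))  # Fake batch at the B0 train resolution.
params, state = net.init(jax.random.PRNGKey(0), images, is_training=True)
logits, _ = net.apply(params, state, jax.random.PRNGKey(1), images,
                      is_training=True)
print(logits.shape)  # (2, 1000)
```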
## Giving Credit
If you use this code in your work, we ask you to cite this paper:
```
@inproceedings{brock2021characterizing,
author={Andrew Brock and Soham De and Samuel L. Smith},
title={Characterizing signal propagation to close the performance gap in
unnormalized ResNets},
booktitle={9th International Conference on Learning Representations, {ICLR}},
year={2021}
}
```

714
nfnets/autoaugment.py Normal file

File diff suppressed because it is too large

176
nfnets/base.py Normal file

@@ -0,0 +1,176 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Architecture definitions for different models."""
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
# Model settings for NF-RegNets
nf_regnet_params = {
'B0': {'width': [48, 104, 208, 440], 'depth': [1, 3, 6, 6],
'train_imsize': 192, 'test_imsize': 224,
'drop_rate': 0.2},
'B1': {'width': [48, 104, 208, 440], 'depth': [2, 4, 7, 7],
'train_imsize': 224, 'test_imsize': 256,
'drop_rate': 0.2},
'B2': {'width': [56, 112, 232, 488], 'depth': [2, 4, 8, 8],
'train_imsize': 240, 'test_imsize': 272,
'drop_rate': 0.3},
'B3': {'width': [56, 128, 248, 528], 'depth': [2, 5, 9, 9],
'train_imsize': 288, 'test_imsize': 320,
'drop_rate': 0.3},
'B4': {'width': [64, 144, 288, 616], 'depth': [2, 6, 11, 11],
'train_imsize': 320, 'test_imsize': 384,
'drop_rate': 0.4},
'B5': {'width': [80, 168, 336, 704], 'depth': [3, 7, 14, 14],
'train_imsize': 384, 'test_imsize': 456,
'drop_rate': 0.4},
'B6': {'width': [88, 184, 376, 792], 'depth': [3, 8, 16, 16],
'train_imsize': 448, 'test_imsize': 528,
'drop_rate': 0.5},
'B7': {'width': [96, 208, 416, 880], 'depth': [4, 10, 19, 19],
'train_imsize': 512, 'test_imsize': 600,
'drop_rate': 0.5},
'B8': {'width': [104, 232, 456, 968], 'depth': [4, 11, 22, 22],
'train_imsize': 600, 'test_imsize': 672,
'drop_rate': 0.5},
}
# Nonlinearities with magic constants (gamma) baked in.
# Note that not all nonlinearities will be stable, especially if they are
# not perfectly monotonic. Good choices include relu, silu, and gelu.
nonlinearities = {
'identity': lambda x: x,
'celu': lambda x: jax.nn.celu(x) * 1.270926833152771,
'elu': lambda x: jax.nn.elu(x) * 1.2716004848480225,
'gelu': lambda x: jax.nn.gelu(x) * 1.7015043497085571,
'glu': lambda x: jax.nn.glu(x) * 1.8484294414520264,
'leaky_relu': lambda x: jax.nn.leaky_relu(x) * 1.70590341091156,
'log_sigmoid': lambda x: jax.nn.log_sigmoid(x) * 1.9193484783172607,
'log_softmax': lambda x: jax.nn.log_softmax(x) * 1.0002083778381348,
'relu': lambda x: jax.nn.relu(x) * 1.7139588594436646,
'relu6': lambda x: jax.nn.relu6(x) * 1.7131484746932983,
'selu': lambda x: jax.nn.selu(x) * 1.0008515119552612,
'sigmoid': lambda x: jax.nn.sigmoid(x) * 4.803835391998291,
'silu': lambda x: jax.nn.silu(x) * 1.7881293296813965,
'soft_sign': lambda x: jax.nn.soft_sign(x) * 2.338853120803833,
'softplus': lambda x: jax.nn.softplus(x) * 1.9203323125839233,
'tanh': lambda x: jnp.tanh(x) * 1.5939117670059204,
}
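# Illustrative note (not part of the original file): each constant above is
# approximately 1 / std(phi(x)) for x ~ N(0, 1), so that gamma * phi(x) has
# roughly unit variance at initialization. One way to estimate such a constant
# empirically, e.g. for relu:
#   x = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024))
#   gamma_relu = 1.0 / jnp.std(jax.nn.relu(x))  # ~1.71, close to the value above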
class WSConv2D(hk.Conv2D):
"""2D Convolution with Scaled Weight Standardization and affine gain+bias."""
@hk.transparent
def standardize_weight(self, weight, eps=1e-4):
"""Apply scaled WS with affine gain."""
mean = jnp.mean(weight, axis=(0, 1, 2), keepdims=True)
var = jnp.var(weight, axis=(0, 1, 2), keepdims=True)
fan_in = np.prod(weight.shape[:-1])
# Get gain
gain = hk.get_parameter('gain', shape=(weight.shape[-1],),
dtype=weight.dtype, init=jnp.ones)
# Manually fused normalization, eq. to (w - mean) * gain / sqrt(N * var)
scale = jax.lax.rsqrt(jnp.maximum(var * fan_in, eps)) * gain
shift = mean * scale
return weight * scale - shift
def __call__(self, inputs: jnp.ndarray, eps: float = 1e-4) -> jnp.ndarray:
w_shape = self.kernel_shape + (
inputs.shape[self.channel_index] // self.feature_group_count,
self.output_channels)
# Use fan-in scaled init, but WS is largely insensitive to this choice.
w_init = hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal')
w = hk.get_parameter('w', w_shape, inputs.dtype, init=w_init)
weight = self.standardize_weight(w, eps)
out = jax.lax.conv_general_dilated(
inputs, weight, window_strides=self.stride, padding=self.padding,
lhs_dilation=self.lhs_dilation, rhs_dilation=self.kernel_dilation,
dimension_numbers=self.dimension_numbers,
feature_group_count=self.feature_group_count)
# Always add bias
bias_shape = (self.output_channels,)
bias = hk.get_parameter('bias', bias_shape, inputs.dtype, init=jnp.zeros)
return out + bias
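# Usage sketch (illustrative, not part of the original file): WSConv2D is a
# drop-in replacement for hk.Conv2D inside a transformed function, e.g.
#   def f(x):
#     return WSConv2D(64, kernel_shape=3)(x)
#   params = hk.transform(f).init(jax.random.PRNGKey(0), jnp.zeros((1, 8, 8, 3)))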
def signal_metrics(x, i):
"""Things to measure about a NCHW tensor activation."""
metrics = {}
# Average channel-wise mean-squared
metrics[f'avg_sq_mean_{i}'] = jnp.mean(jnp.mean(x, axis=[0, 1, 2])**2)
# Average channel variance
metrics[f'avg_var_{i}'] = jnp.mean(jnp.var(x, axis=[0, 1, 2]))
return metrics
def count_conv_flops(in_ch, conv, h, w):
"""For a conv layer with in_ch inputs, count the FLOPS."""
# How many output floats are we producing?
output_shape = conv.output_channels * (h * w) / np.prod(conv.stride)
# At each OHW location we do computation equal to (I//G) * kh * kw
flop_per_loc = (in_ch / conv.feature_group_count)
flop_per_loc *= np.prod(conv.kernel_shape)
return output_shape * flop_per_loc
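# Worked example (illustrative): for a 3x3, stride-1 conv with 64 input
# channels, 128 output channels and no grouping on a 32x32 feature map,
# count_conv_flops(64, conv, 32, 32) returns (128 * 32 * 32) * (64 * 9)
# = 75,497,472.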
class SqueezeExcite(hk.Module):
"""Simple Squeeze+Excite module."""
def __init__(self, in_ch, out_ch, se_ratio=0.5,
hidden_ch=None, activation=jax.nn.relu,
name=None):
super().__init__(name=name)
self.in_ch, self.out_ch = in_ch, out_ch
if se_ratio is None:
if hidden_ch is None:
raise ValueError('Must provide one of se_ratio or hidden_ch')
self.hidden_ch = hidden_ch
else:
self.hidden_ch = max(1, int(self.in_ch * se_ratio))
self.activation = activation
self.fc0 = hk.Linear(self.hidden_ch, with_bias=True)
self.fc1 = hk.Linear(self.out_ch, with_bias=True)
def __call__(self, x):
h = jnp.mean(x, axis=[1, 2]) # Mean pool over HW extent
h = self.fc1(self.activation(self.fc0(h)))
h = jax.nn.sigmoid(h)[:, None, None] # Broadcast along H, W
return h
class StochDepth(hk.Module):
"""Batchwise Dropout used in EfficientNet, optionally sans rescaling."""
def __init__(self, drop_rate, scale_by_keep=False, name=None):
super().__init__(name=name)
self.drop_rate = drop_rate
self.scale_by_keep = scale_by_keep
def __call__(self, x, is_training) -> jnp.ndarray:
if not is_training:
return x
batch_size = x.shape[0]
r = jax.random.uniform(hk.next_rng_key(), [batch_size, 1, 1, 1],
dtype=x.dtype)
keep_prob = 1. - self.drop_rate
binary_tensor = jnp.floor(keep_prob + r)
if self.scale_by_keep:
x = x / keep_prob
return x * binary_tensor
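# Example (illustrative): StochDepth(0.1) zeroes each sample's branch output
# with probability 0.1 during training (and is the identity at eval time); with
# scale_by_keep=True, surviving samples are rescaled by 1 / 0.9.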

542
nfnets/dataset.py Normal file

File diff suppressed because it is too large

363
nfnets/experiment.py Normal file

@@ -0,0 +1,363 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Basic Jaxline ImageNet experiment."""
import importlib
import sys
from absl import flags
from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
from jaxline import base_config
from jaxline import experiment
from jaxline import platform
from jaxline import utils as jl_utils
from ml_collections import config_dict
import numpy as np
from nfnets import dataset
from nfnets import optim
from nfnets import utils
# pylint: disable=logging-format-interpolation
FLAGS = flags.FLAGS
# We define the experiment launch config in the same file as the experiment to
# keep things self-contained, but one might consider moving the config and/or
# sweep functions to a separate file if necessary.
def get_config():
"""Return config object for training."""
config = base_config.get_base_config()
# Experiment config.
train_batch_size = 1024 # Global batch size.
images_per_epoch = 1281167
num_epochs = 90
steps_per_epoch = images_per_epoch / train_batch_size
config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
config.random_seed = 0
config.experiment_kwargs = config_dict.ConfigDict(
dict(
config=dict(
lr=0.1,
num_epochs=num_epochs,
label_smoothing=0.1,
model='ResNet',
image_size=224,
use_ema=False,
ema_decay=0.9999, # Four nines, friends
ema_start=0,
which_ema='tf1_ema',
augment_name=None, # 'mixup_cutmix',
augment_before_mix=True,
eval_preproc='crop_resize',
train_batch_size=train_batch_size,
eval_batch_size=50,
eval_subset='test',
num_classes=1000,
which_dataset='imagenet',
fake_data=False,
which_loss='softmax_cross_entropy', # For now, must be softmax
transpose=True, # Use the double-transpose trick?
bfloat16=False,
lr_schedule=dict(
name='WarmupCosineDecay',
kwargs=dict(num_steps=config.training_steps,
start_val=0,
min_val=0,
warmup_steps=5*steps_per_epoch),
),
lr_scale_by_bs=True,
optimizer=dict(
name='SGD',
kwargs={'momentum': 0.9, 'nesterov': True,
'weight_decay': 1e-4,},
),
model_kwargs=dict(
width=4,
which_norm='BatchNorm',
norm_kwargs=dict(create_scale=True,
create_offset=True,
decay_rate=0.9,
), # cross_replica_axis='i'),
variant='ResNet50',
activation='relu',
drop_rate=0.0,
),
),))
# Training loop config: log and checkpoint every minute
config.log_train_data_interval = 60
config.log_tensors_interval = 60
config.save_checkpoint_interval = 60
config.eval_specific_checkpoint_dir = ''
return config
class Experiment(experiment.AbstractExperiment):
"""Imagenet experiment."""
CHECKPOINT_ATTRS = {
'_params': 'params',
'_state': 'state',
'_ema_params': 'ema_params',
'_ema_state': 'ema_state',
'_opt_state': 'opt_state',
}
def __init__(self, mode, config, init_rng):
super().__init__(mode=mode)
self.mode = mode
self.config = config
self.init_rng = init_rng
# Checkpointed experiment state.
self._params = None
self._state = None
self._ema_params = None
self._ema_state = None
self._opt_state = None
# Input pipelines.
self._train_input = None
self._eval_input = None
# Get model, loaded in from the zoo
self.model_module = importlib.import_module(
('nfnets.'+ self.config.model.lower()))
self.net = hk.transform_with_state(self._forward_fn)
# Assign image sizes
if self.config.get('override_imsize', False):
self.train_imsize = self.config.image_size
self.test_imsize = self.config.get('eval_image_size', self.train_imsize)
else:
variant_dict = getattr(self.model_module, self.config.model).variant_dict
variant_dict = variant_dict[self.config.model_kwargs.variant]
self.train_imsize = variant_dict.get('train_imsize',
self.config.image_size)
# Test imsize defaults to model-specific value, then to config imsize
test_imsize = self.config.get('eval_image_size', self.config.image_size)
self.test_imsize = variant_dict.get('test_imsize', test_imsize)
donate_argnums = (0, 1, 2, 6, 7) if self.config.use_ema else (0, 1, 2)
self.train_fn = jax.pmap(self._train_fn, axis_name='i',
donate_argnums=donate_argnums)
self.eval_fn = jax.pmap(self._eval_fn, axis_name='i')
def _initialize_train(self):
self._train_input = self._build_train_input()
# Initialize net and EMA copy of net if no params available.
if self._params is None:
inputs = next(self._train_input)
init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
axis_name='i')
init_rng = jl_utils.bcast_local_devices(self.init_rng)
self._params, self._state = init_net(init_rng, inputs)
if self.config.use_ema:
self._ema_params, self._ema_state = init_net(init_rng, inputs)
num_params = hk.data_structures.tree_size(self._params)
logging.info(f'Net parameters: {num_params / jax.local_device_count()}')
self._make_opt()
def _make_opt(self):
# Separate conv params and gains/biases
def pred(mod, name, val): # pylint:disable=unused-argument
return (name in ['scale', 'offset', 'b']
or 'gain' in name or 'bias' in name)
gains_biases, weights = hk.data_structures.partition(pred, self._params)
# Lr schedule with batch-based LR scaling
if self.config.lr_scale_by_bs:
max_lr = (self.config.lr * self.config.train_batch_size) / 256
else:
max_lr = self.config.lr
lr_sched_fn = getattr(optim, self.config.lr_schedule.name)
lr_schedule = lr_sched_fn(max_val=max_lr, **self.config.lr_schedule.kwargs)
# Optimizer; no need to broadcast!
opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
opt_kwargs['lr'] = lr_schedule
opt_module = getattr(optim, self.config.optimizer.name)
self.opt = opt_module([{'params': gains_biases, 'weight_decay': None},
{'params': weights}], **opt_kwargs)
if self._opt_state is None:
self._opt_state = self.opt.states()
else:
self.opt.plugin(self._opt_state)
def _forward_fn(self, inputs, is_training):
net_kwargs = {'num_classes': self.config.num_classes,
**self.config.model_kwargs}
net = getattr(self.model_module, self.config.model)(**net_kwargs)
if self.config.get('transpose', False):
images = jnp.transpose(inputs['images'], (3, 0, 1, 2)) # HWCN -> NHWC
else:
images = inputs['images']
if self.config.bfloat16 and self.mode == 'train':
images = utils.to_bf16(images)
return net(images, is_training=is_training)['logits']
def _one_hot(self, value):
"""One-hot encoding potentially over a sequence of labels."""
y = jax.nn.one_hot(value, self.config.num_classes)
return y
def _loss_fn(self, params, state, inputs, rng):
logits, state = self.net.apply(params, state, rng, inputs, is_training=True)
y = self._one_hot(inputs['labels'])
if 'mix_labels' in inputs: # Handle cutmix/mixup label mixing
logging.info('Using mixup or cutmix!')
y1 = self._one_hot(inputs['mix_labels'])
y = inputs['ratio'][:, None] * y + (1. - inputs['ratio'][:, None]) * y1
if self.config.label_smoothing > 0: # Apply label smoothing
spositives = 1. - self.config.label_smoothing
snegatives = self.config.label_smoothing / self.config.num_classes
y = spositives * y + snegatives
if self.config.bfloat16: # Cast logits to float32
logits = logits.astype(jnp.float32)
which_loss = getattr(utils, self.config.which_loss)
loss = which_loss(logits, y, reduction='mean')
metrics = utils.topk_correct(logits, inputs['labels'], prefix='train_')
# Average top-1 and top-5 correct labels
metrics = jax.tree_map(jnp.mean, metrics)
metrics['train_loss'] = loss # Metrics will be pmeaned so don't divide here
scaled_loss = loss / jax.device_count() # Grads get psummed so do divide
return scaled_loss, (metrics, state)
def _train_fn(self, params, states, opt_states,
inputs, rng, global_step,
ema_params, ema_states):
"""Runs one batch forward + backward and run a single opt step."""
grad_fn = jax.grad(self._loss_fn, argnums=0, has_aux=True)
if self.config.bfloat16:
in_params, states = jax.tree_map(utils.to_bf16, (params, states))
else:
in_params = params
grads, (metrics, states) = grad_fn(in_params, states, inputs, rng)
if self.config.bfloat16:
states, metrics, grads = jax.tree_map(utils.from_bf16,
(states, metrics, grads))
# Sum gradients and average losses for pmap
grads = jax.lax.psum(grads, 'i')
metrics = jax.lax.pmean(metrics, 'i')
# Compute updates and update parameters
metrics['learning_rate'] = self.opt._hyperparameters['lr'](global_step) # pylint: disable=protected-access
params, opt_states = self.opt.step(params, grads, opt_states, global_step)
if ema_params is not None:
ema_fn = getattr(utils, self.config.get('which_ema', 'tf1_ema'))
ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
ema_params = jax.tree_multimap(ema, ema_params, params)
ema_states = jax.tree_multimap(ema, ema_states, states)
return {'params': params, 'states': states, 'opt_states': opt_states,
'ema_params': ema_params, 'ema_states': ema_states,
'metrics': metrics}
# _ _
# | |_ _ __ __ _(_)_ __
# | __| '__/ _` | | '_ \
# | |_| | | (_| | | | | |
# \__|_| \__,_|_|_| |_|
#
def step(self, global_step, rng, *unused_args, **unused_kwargs):
if self._train_input is None:
self._initialize_train()
inputs = next(self._train_input)
out = self.train_fn(params=self._params, states=self._state,
opt_states=self._opt_state, inputs=inputs,
rng=rng, global_step=global_step,
ema_params=self._ema_params, ema_states=self._ema_state)
self._params, self._state = out['params'], out['states']
self._opt_state = out['opt_states']
self._ema_params, self._ema_state = out['ema_params'], out['ema_states']
self.opt.plugin(self._opt_state)
return jl_utils.get_first(out['metrics'])
def _build_train_input(self):
num_devices = jax.device_count()
global_batch_size = self.config.train_batch_size
bs_per_device, ragged = divmod(global_batch_size, num_devices)
if ragged:
raise ValueError(
f'Global batch size {global_batch_size} must be divisible by '
f'num devices {num_devices}')
return dataset.load(
dataset.Split.TRAIN_AND_VALID, is_training=True,
batch_dims=[jax.local_device_count(), bs_per_device],
transpose=self.config.get('transpose', False),
image_size=(self.train_imsize,) * 2,
augment_name=self.config.augment_name,
augment_before_mix=self.config.get('augment_before_mix', True),
name=self.config.which_dataset,
fake_data=self.config.get('fake_data', False))
# _
# _____ ____ _| |
# / _ \ \ / / _` | |
# | __/\ V / (_| | |
# \___| \_/ \__,_|_|
#
def evaluate(self, global_step, **unused_args):
metrics = self._eval_epoch(self._params, self._state)
if self.config.use_ema:
ema_metrics = self._eval_epoch(self._ema_params, self._ema_state)
metrics.update({f'ema_{key}': val for key, val in ema_metrics.items()})
logging.info(f'[Step {global_step}] Eval scalars: {metrics}')
return metrics
def _eval_epoch(self, params, state):
"""Evaluates an epoch."""
num_samples = 0.
summed_metrics = None
for inputs in self._build_eval_input():
num_samples += np.prod(inputs['labels'].shape[:2]) # Account for pmaps
metrics = self.eval_fn(params, state, inputs)
# Accumulate the sum of metrics for each step.
metrics = jax.tree_map(lambda x: jnp.sum(x[0], axis=0), metrics)
if summed_metrics is None:
summed_metrics = metrics
else:
summed_metrics = jax.tree_multimap(jnp.add, summed_metrics, metrics)
mean_metrics = jax.tree_map(lambda x: x / num_samples, summed_metrics)
return jax.device_get(mean_metrics)
def _eval_fn(self, params, state, inputs):
"""Evaluate a single batch and return loss and top-k acc."""
logits, _ = self.net.apply(params, state, None, inputs, is_training=False)
y = self._one_hot(inputs['labels'])
which_loss = getattr(utils, self.config.which_loss)
loss = which_loss(logits, y, reduction=None)
metrics = utils.topk_correct(logits, inputs['labels'], prefix='eval_')
metrics['eval_loss'] = loss
return jax.lax.psum(metrics, 'i')
def _build_eval_input(self):
"""Builds the evaluation input pipeline."""
bs_per_device = (self.config.eval_batch_size // jax.local_device_count())
split = dataset.Split.from_string(self.config.eval_subset)
eval_preproc = self.config.get('eval_preproc', 'crop_resize')
return dataset.load(split, is_training=False,
batch_dims=[jax.local_device_count(), bs_per_device],
transpose=self.config.get('transpose', False),
image_size=(self.test_imsize,) * 2,
name=self.config.which_dataset,
eval_preproc=eval_preproc,
fake_data=self.config.get('fake_data', False))
if __name__ == '__main__':
flags.mark_flag_as_required('config')
platform.main(Experiment, sys.argv[1:])
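# Example launch (an assumption about the typical JAXline invocation, not part
# of the original file; the required --config flag is defined above):
#   python -m nfnets.experiment --config=nfnets/experiment.py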


@@ -0,0 +1,86 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""ImageNet experiment with NF-RegNets."""
from ml_collections import config_dict
from nfnets import experiment
def get_config():
"""Return config object for training."""
config = experiment.get_config()
# Experiment config.
train_batch_size = 1024 # Global batch size.
images_per_epoch = 1281167
num_epochs = 360
steps_per_epoch = images_per_epoch / train_batch_size
config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
config.random_seed = 0
config.experiment_kwargs = config_dict.ConfigDict(
dict(
config=dict(
lr=0.4,
num_epochs=num_epochs,
label_smoothing=0.1,
model='NF_RegNet',
image_size=224,
use_ema=True,
ema_decay=0.99999, # Five nines, friends
ema_start=0,
augment_name='mixup_cutmix',
train_batch_size=train_batch_size,
eval_batch_size=50,
eval_subset='test',
num_classes=1000,
which_dataset='imagenet',
which_loss='softmax_cross_entropy', # One of softmax or sigmoid
bfloat16=False,
lr_schedule=dict(
name='WarmupCosineDecay',
kwargs=dict(num_steps=config.training_steps,
start_val=0,
min_val=0.001,
warmup_steps=5*steps_per_epoch),
),
lr_scale_by_bs=False,
optimizer=dict(
name='SGD',
kwargs={'momentum': 0.9, 'nesterov': True,
'weight_decay': 5e-5,},
),
model_kwargs=dict(
variant='B0',
width=0.75,
expansion=2.25,
se_ratio=0.5,
alpha=0.2,
stochdepth_rate=0.1,
drop_rate=None,
activation='silu',
),
)))
# Set weight decay based on variant (scaled as 5e-5 + 1e-5 * level)
variant = config.experiment_kwargs.config.model_kwargs.variant
weight_decay = {'B0': 5e-5, 'B1': 6e-5, 'B2': 7e-5,
'B3': 8e-5, 'B4': 9e-5, 'B5': 1e-4}[variant]
config.experiment_kwargs.config.optimizer.kwargs.weight_decay = weight_decay
return config
Experiment = experiment.Experiment

216
nfnets/fixup_resnet.py Normal file

@@ -0,0 +1,216 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet (post-activation) with FixUp."""
# pylint: disable=invalid-name
import functools
import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base
nonlinearities = {
'swish': jax.nn.silu,
'relu': jax.nn.relu,
'identity': lambda x: x}
class FixUp_ResNet(hk.Module):
"""Fixup based ResNet."""
variant_dict = {'ResNet50': {'depth': [3, 4, 6, 3]},
'ResNet101': {'depth': [3, 4, 23, 3]},
'ResNet152': {'depth': [3, 8, 36, 3]},
'ResNet200': {'depth': [3, 24, 36, 3]},
'ResNet288': {'depth': [24, 24, 24, 24]},
'ResNet600': {'depth': [50, 50, 50, 50]},
}
def __init__(self, num_classes, variant='ResNet50', width=4,
stochdepth_rate=0.1, drop_rate=None,
activation='relu', fc_init=jnp.zeros,
name='FixUp_ResNet'):
super().__init__(name=name)
self.num_classes = num_classes
self.variant = variant
self.width = width
# Get variant info
block_params = self.variant_dict[self.variant]
self.width_pattern = [item * self.width for item in [64, 128, 256, 512]]
self.depth_pattern = block_params['depth']
self.activation = nonlinearities[activation]
if drop_rate is None:
self.drop_rate = block_params['drop_rate']
else:
self.drop_rate = drop_rate
self.which_conv = functools.partial(hk.Conv2D,
with_bias=False)
# Stem
ch = int(16 * self.width)
self.initial_conv = self.which_conv(ch, kernel_shape=7, stride=2,
padding='SAME',
name='initial_conv')
# Body
self.blocks = []
num_blocks = sum(self.depth_pattern)
index = 0 # Overall block index
block_args = (self.width_pattern, self.depth_pattern, [1, 2, 2, 2])
for block_width, stage_depth, stride in zip(*block_args):
for block_index in range(stage_depth):
# Block stochastic depth drop-rate
block_stochdepth_rate = stochdepth_rate * index / num_blocks
self.blocks += [ResBlock(ch, block_width, num_blocks,
stride=stride if block_index == 0 else 1,
activation=self.activation,
which_conv=self.which_conv,
stochdepth_rate=block_stochdepth_rate,
)]
ch = block_width
index += 1
# Head
self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)
def __call__(self, x, is_training=True, return_metrics=False):
"""Return the output of the final layer without any [log-]softmax."""
# Stem
outputs = {}
out = self.initial_conv(x)
bias1 = hk.get_parameter('bias1', (), x.dtype, init=jnp.zeros)
out = self.activation(out + bias1)
out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
strides=(1, 2, 2, 1), padding='SAME')
if return_metrics:
outputs.update(base.signal_metrics(out, 0))
# Blocks
for i, block in enumerate(self.blocks):
out, res_avg_var = block(out, is_training=is_training)
if return_metrics:
outputs.update(base.signal_metrics(out, i + 1))
outputs[f'res_avg_var_{i}'] = res_avg_var
# Pool, dropout, classify (no final conv in this model)
pool = jnp.mean(out, [1, 2])
outputs['pool'] = pool
# Optionally apply dropout
if self.drop_rate > 0.0 and is_training:
pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
bias2 = hk.get_parameter('bias2', (), pool.dtype, init=jnp.zeros)
outputs['logits'] = self.fc(pool + bias2)
return outputs
def count_flops(self, h, w):
flops = []
flops += [base.count_conv_flops(3, self.initial_conv, h, w)]
h, w = h / 2, w / 2
# Body FLOPs
for block in self.blocks:
flops += [block.count_flops(h, w)]
if block.stride > 1:
h, w = h / block.stride, w / block.stride
# Head FLOPs: just the classifier, since this model has no final conv
flops += [self.blocks[-1].out_ch * self.fc.output_size]
return flops, sum(flops)
class ResBlock(hk.Module):
"""Post-activation Fixup Block."""
def __init__(self, in_ch, out_ch, num_blocks, bottleneck_ratio=0.25,
kernel_size=3, stride=1,
which_conv=hk.Conv2D, activation=jax.nn.relu,
stochdepth_rate=None, name=None):
super().__init__(name=name)
self.in_ch, self.out_ch = in_ch, out_ch
self.kernel_size = kernel_size
self.activation = activation
# Bottleneck width
self.width = int(self.out_ch * bottleneck_ratio)
self.stride = stride
# Conv 0 (typically expansion conv)
conv0_init = hk.initializers.RandomNormal(
stddev=((2 / self.width)**0.5) * (num_blocks**(-0.25)))
self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
name='conv0', w_init=conv0_init)
# Grouped NxN conv
conv1_init = hk.initializers.RandomNormal(
stddev=((2 / (self.width * (kernel_size**2)))**0.5)
* (num_blocks**(-0.25)))
self.conv1 = which_conv(self.width, kernel_shape=kernel_size, stride=stride,
padding='SAME', name='conv1', w_init=conv1_init)
# Conv 2, typically projection conv
self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
name='conv2', w_init=hk.initializers.Constant(0))
# Use shortcut conv on channel change or downsample.
self.use_projection = stride > 1 or self.in_ch != self.out_ch
if self.use_projection:
shortcut_init = hk.initializers.RandomNormal(
stddev=(2 / self.out_ch) ** 0.5)
self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
stride=stride, padding='SAME',
name='conv_shortcut',
w_init=shortcut_init)
# Are we using stochastic depth?
self._has_stochdepth = (stochdepth_rate is not None and
stochdepth_rate > 0. and stochdepth_rate < 1.0)
if self._has_stochdepth:
self.stoch_depth = base.StochDepth(stochdepth_rate)
def __call__(self, x, is_training):
bias1a = hk.get_parameter('bias1a', (), x.dtype, init=jnp.zeros)
bias1b = hk.get_parameter('bias1b', (), x.dtype, init=jnp.zeros)
bias2a = hk.get_parameter('bias2a', (), x.dtype, init=jnp.zeros)
bias2b = hk.get_parameter('bias2b', (), x.dtype, init=jnp.zeros)
bias3a = hk.get_parameter('bias3a', (), x.dtype, init=jnp.zeros)
bias3b = hk.get_parameter('bias3b', (), x.dtype, init=jnp.zeros)
scale = hk.get_parameter('scale', (), x.dtype, init=jnp.ones)
out = x + bias1a
shortcut = out
if self.use_projection: # Downsample with conv1x1
shortcut = self.conv_shortcut(shortcut)
out = self.conv0(out)
out = self.activation(out + bias1b)
out = self.conv1(out + bias2a)
out = self.activation(out + bias2b)
out = self.conv2(out + bias3a)
out = out * scale + bias3b
# Get average residual variance for reporting metrics.
res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
# Apply stochdepth if applicable.
if self._has_stochdepth:
out = self.stoch_depth(out, is_training)
# SkipInit Gain
out = out + shortcut
return self.activation(out), res_avg_var
def count_flops(self, h, w):
# Count conv FLOPs based on input HW
expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
# If block is strided we decrease resolution here.
dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
if self.stride > 1:
h, w = h / self.stride, w / self.stride
if self.use_projection:
sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
else:
sc_flops = 0
contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
return sum([expand_flops, dw_flops, contract_flops, sc_flops])
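# Note (illustrative): the Fixup initialization above scales the conv0/conv1
# weight inits by num_blocks ** -0.25 on top of the usual fan-in scaling and
# zero-initializes conv2, so deeper networks start with smaller residual
# branches. For ResNet50 (3 + 4 + 6 + 3 = 16 blocks) that factor is
# 16 ** -0.25 = 0.5.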

217
nfnets/nf_regnet.py Normal file

@@ -0,0 +1,217 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normalizer-Free RegNets."""
# pylint: disable=invalid-name
import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base
class NF_RegNet(hk.Module):
"""Normalizer-Free RegNets."""
variant_dict = base.nf_regnet_params
def __init__(self, num_classes, variant='B0',
width=0.75, expansion=2.25, group_size=8, se_ratio=0.5,
alpha=0.2, stochdepth_rate=0.1, drop_rate=None,
activation='swish', fc_init=jnp.zeros,
name='NF_RegNet'):
super().__init__(name=name)
self.num_classes = num_classes
self.variant = variant
self.width = width
self.expansion = expansion
self.group_size = group_size
self.se_ratio = se_ratio
# Get variant info
block_params = self.variant_dict[self.variant]
self.train_imsize = block_params['train_imsize']
self.test_imsize = block_params['test_imsize']
self.width_pattern = block_params['width']
self.depth_pattern = block_params['depth']
self.activation = base.nonlinearities[activation]
if drop_rate is None:
self.drop_rate = block_params['drop_rate']
else:
self.drop_rate = drop_rate
self.which_conv = base.WSConv2D
# Stem
ch = int(self.width_pattern[0] * self.width)
self.initial_conv = self.which_conv(ch, kernel_shape=3, stride=2,
padding='SAME', name='initial_conv')
# Body
self.blocks = []
expected_std = 1.0
num_blocks = sum(self.depth_pattern)
index = 0 # Overall block index
for block_width, stage_depth in zip(self.width_pattern, self.depth_pattern):
for block_index in range(stage_depth):
# Scalar pre-multiplier so each block sees an N(0,1) input at init
beta = 1./ expected_std
# Block stochastic depth drop-rate
block_stochdepth_rate = stochdepth_rate * index / num_blocks
# Use a bottleneck expansion ratio of 1 for first block following EffNet
expand_ratio = 1 if index == 0 else expansion
out_ch = (int(block_width * self.width))
self.blocks += [NFBlock(ch, out_ch,
expansion=expand_ratio, se_ratio=se_ratio,
group_size=self.group_size,
stride=2 if block_index == 0 else 1,
beta=beta, alpha=alpha,
activation=self.activation,
which_conv=self.which_conv,
stochdepth_rate=block_stochdepth_rate,
)]
ch = out_ch
index += 1
# Reset expected std but still give it 1 block of growth
if block_index == 0:
expected_std = 1.0
expected_std = (expected_std **2 + alpha**2)**0.5
# Head with final conv mimicking EffNets
self.final_conv = self.which_conv(int(1280 * ch // 440), kernel_shape=1,
padding='SAME', name='final_conv')
self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)
def __call__(self, x, is_training=True, return_metrics=False):
"""Return the output of the final layer without any [log-]softmax."""
# Stem
outputs = {}
out = self.initial_conv(x)
if return_metrics:
outputs.update(base.signal_metrics(out, 0))
# Blocks
for i, block in enumerate(self.blocks):
out, res_avg_var = block(out, is_training=is_training)
if return_metrics:
outputs.update(base.signal_metrics(out, i + 1))
outputs[f'res_avg_var_{i}'] = res_avg_var
# Final-conv->activation, pool, dropout, classify
out = self.activation(self.final_conv(out))
pool = jnp.mean(out, [1, 2])
outputs['pool'] = pool
# Optionally apply dropout
if self.drop_rate > 0.0 and is_training:
pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
outputs['logits'] = self.fc(pool)
return outputs
def count_flops(self, h, w):
flops = []
flops += [base.count_conv_flops(3, self.initial_conv, h, w)]
h, w = h / 2, w / 2
# Body FLOPs
for block in self.blocks:
flops += [block.count_flops(h, w)]
if block.stride > 1:
h, w = h / block.stride, w / block.stride
# Head module FLOPs
out_ch = self.blocks[-1].out_ch
flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
# Count flops for classifier
flops += [self.final_conv.output_channels * self.fc.output_size]
return flops, sum(flops)
class NFBlock(hk.Module):
"""Normalizer-Free RegNet Block."""
def __init__(self, in_ch, out_ch, expansion=2.25, se_ratio=0.5,
kernel_size=3, group_size=8, stride=1,
beta=1.0, alpha=0.2,
which_conv=base.WSConv2D, activation=jax.nn.relu,
stochdepth_rate=None, name=None):
super().__init__(name=name)
self.in_ch, self.out_ch = in_ch, out_ch
self.expansion = expansion
self.se_ratio = se_ratio
self.kernel_size = kernel_size
self.activation = activation
self.beta, self.alpha = beta, alpha
# Round expanded width based on group count
width = int(self.in_ch * expansion)
self.groups = width // group_size
self.width = group_size * self.groups
self.stride = stride
# Conv 0 (typically expansion conv)
self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
name='conv0')
# Grouped NxN conv
self.conv1 = which_conv(self.width, kernel_shape=kernel_size, stride=stride,
padding='SAME', feature_group_count=self.groups,
name='conv1')
# Conv 2, typically projection conv
self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
name='conv2')
# Use shortcut conv on channel change or downsample.
self.use_projection = stride > 1 or self.in_ch != self.out_ch
if self.use_projection:
self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
padding='SAME', name='conv_shortcut')
# Squeeze + Excite Module
self.se = base.SqueezeExcite(self.width, self.width, self.se_ratio)
# Are we using stochastic depth?
self._has_stochdepth = (stochdepth_rate is not None and
stochdepth_rate > 0. and stochdepth_rate < 1.0)
if self._has_stochdepth:
self.stoch_depth = base.StochDepth(stochdepth_rate)
def __call__(self, x, is_training):
out = self.activation(x) * self.beta
if self.stride > 1: # Average-pool downsample.
shortcut = hk.avg_pool(out, window_shape=(1, 2, 2, 1),
strides=(1, 2, 2, 1), padding='SAME')
if self.use_projection:
shortcut = self.conv_shortcut(shortcut)
elif self.use_projection:
shortcut = self.conv_shortcut(out)
else:
shortcut = x
out = self.conv0(out)
out = self.conv1(self.activation(out))
out = 2 * self.se(out) * out # Multiply by 2 for rescaling
out = self.conv2(self.activation(out))
# Get average residual variance for reporting metrics.
res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
# Apply stochdepth if applicable.
if self._has_stochdepth:
out = self.stoch_depth(out, is_training)
# SkipInit Gain
out = out * hk.get_parameter('skip_gain', (), out.dtype, init=jnp.zeros)
return out * self.alpha + shortcut, res_avg_var
def count_flops(self, h, w):
# Count conv FLOPs based on input HW
expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
# If block is strided we decrease resolution here.
dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
if self.stride > 1:
h, w = h / self.stride, w / self.stride
if self.use_projection:
sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
else:
sc_flops = 0
# SE flops happen on avg-pooled activations
se_flops = self.se.fc0.output_size * self.width
se_flops += self.se.fc0.output_size * self.se.fc1.output_size
contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops])
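# Note (illustrative): the beta / expected_std bookkeeping in NF_RegNet.__init__
# tracks the analytic growth of activation variance at init: after each block,
# expected_std <- sqrt(expected_std**2 + alpha**2), resetting to 1.0 at the
# first block of each stage, and every block pre-multiplies its input by
# beta = 1 / expected_std so its residual branch sees roughly unit variance.
# With alpha = 0.2, expected_std grows as 1.0 -> 1.0198 -> 1.0392 -> ... after
# each reset.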

216
nfnets/nf_resnet.py Normal file

@@ -0,0 +1,216 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Norm-Free Residual Networks."""
# pylint: disable=invalid-name
import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base
class NF_ResNet(hk.Module):
"""Norm-Free preactivation ResNet."""
variant_dict = {'ResNet50': {'depth': [3, 4, 6, 3]},
'ResNet101': {'depth': [3, 4, 23, 3]},
'ResNet152': {'depth': [3, 8, 36, 3]},
'ResNet200': {'depth': [3, 24, 36, 3]},
'ResNet288': {'depth': [24, 24, 24, 24]},
'ResNet600': {'depth': [50, 50, 50, 50]},
}
def __init__(self, num_classes, variant='ResNet50', width=4,
alpha=0.2, stochdepth_rate=0.1, drop_rate=None,
activation='relu', fc_init=None, skipinit_gain=jnp.zeros,
use_se=False, se_ratio=0.25,
name='NF_ResNet'):
super().__init__(name=name)
self.num_classes = num_classes
self.variant = variant
self.width = width
# Get variant info
block_params = self.variant_dict[self.variant]
self.width_pattern = [item * self.width for item in [64, 128, 256, 512]]
self.depth_pattern = block_params['depth']
self.activation = base.nonlinearities[activation]
if drop_rate is None:
self.drop_rate = block_params['drop_rate']
else:
self.drop_rate = drop_rate
self.which_conv = base.WSConv2D
# Stem
ch = int(16 * self.width)
self.initial_conv = self.which_conv(ch, kernel_shape=7, stride=2,
padding='SAME', with_bias=False,
name='initial_conv')
# Body
self.blocks = []
expected_std = 1.0
num_blocks = sum(self.depth_pattern)
index = 0 # Overall block index
block_args = (self.width_pattern, self.depth_pattern, [1, 2, 2, 2])
for block_width, stage_depth, stride in zip(*block_args):
for block_index in range(stage_depth):
# Scalar pre-multiplier so each block sees an N(0,1) input at init
beta = 1./ expected_std
# Block stochastic depth drop-rate
block_stochdepth_rate = stochdepth_rate * index / num_blocks
self.blocks += [NFResBlock(ch, block_width,
stride=stride if block_index == 0 else 1,
beta=beta, alpha=alpha,
activation=self.activation,
which_conv=self.which_conv,
stochdepth_rate=block_stochdepth_rate,
skipinit_gain=skipinit_gain,
use_se=use_se,
se_ratio=se_ratio,
)]
ch = block_width
index += 1
# Reset expected std but still give it 1 block of growth
if block_index == 0:
expected_std = 1.0
expected_std = (expected_std **2 + alpha**2)**0.5
# Head. By default, initialize with N(0, 0.01)
if fc_init is None:
fc_init = hk.initializers.RandomNormal(0.01, 0)
self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)
def __call__(self, x, is_training=True, return_metrics=False):
"""Return the output of the final layer without any [log-]softmax."""
# Stem
outputs = {}
out = self.initial_conv(x)
out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
strides=(1, 2, 2, 1), padding='SAME')
if return_metrics:
outputs.update(base.signal_metrics(out, 0))
# Blocks
for i, block in enumerate(self.blocks):
out, res_avg_var = block(out, is_training=is_training)
if return_metrics:
outputs.update(base.signal_metrics(out, i + 1))
outputs[f'res_avg_var_{i}'] = res_avg_var
# Activation, pool, dropout, classify (no final conv in this model)
pool = jnp.mean(self.activation(out), [1, 2])
outputs['pool'] = pool
# Optionally apply dropout
if self.drop_rate > 0.0 and is_training:
pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
outputs['logits'] = self.fc(pool)
return outputs
def count_flops(self, h, w):
flops = []
flops += [base.count_conv_flops(3, self.initial_conv, h, w)]
h, w = h / 2, w / 2
# Body FLOPs
for block in self.blocks:
flops += [block.count_flops(h, w)]
if block.stride > 1:
h, w = h / block.stride, w / block.stride
# Head FLOPs: just the classifier, since this model has no final conv
flops += [self.blocks[-1].out_ch * self.fc.output_size]
return flops, sum(flops)
class NFResBlock(hk.Module):
"""Normalizer-Free pre-activation ResNet Block."""
def __init__(self, in_ch, out_ch, bottleneck_ratio=0.25,
kernel_size=3, stride=1,
beta=1.0, alpha=0.2,
which_conv=base.WSConv2D, activation=jax.nn.relu,
skipinit_gain=jnp.zeros,
stochdepth_rate=None,
use_se=False, se_ratio=0.25,
name=None):
super().__init__(name=name)
self.in_ch, self.out_ch = in_ch, out_ch
self.kernel_size = kernel_size
self.activation = activation
self.beta, self.alpha = beta, alpha
self.skipinit_gain = skipinit_gain
self.use_se, self.se_ratio = use_se, se_ratio
# Bottleneck width
self.width = int(self.out_ch * bottleneck_ratio)
self.stride = stride
# Conv 0 (typically expansion conv)
self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
name='conv0')
# Grouped NxN conv
self.conv1 = which_conv(self.width, kernel_shape=kernel_size, stride=stride,
padding='SAME', name='conv1')
# Conv 2, typically projection conv
self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
name='conv2')
# Use shortcut conv on channel change or downsample.
self.use_projection = stride > 1 or self.in_ch != self.out_ch
if self.use_projection:
self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
stride=stride, padding='SAME',
name='conv_shortcut')
# Are we using stochastic depth?
self._has_stochdepth = (stochdepth_rate is not None and
stochdepth_rate > 0. and stochdepth_rate < 1.0)
if self._has_stochdepth:
self.stoch_depth = base.StochDepth(stochdepth_rate)
if self.use_se:
self.se = base.SqueezeExcite(self.out_ch, self.out_ch, self.se_ratio)
def __call__(self, x, is_training):
out = self.activation(x) * self.beta
shortcut = x
if self.use_projection: # Downsample with conv1x1
shortcut = self.conv_shortcut(out)
out = self.conv0(out)
out = self.conv1(self.activation(out))
out = self.conv2(self.activation(out))
if self.use_se:
out = 2 * self.se(out) * out
# Get average residual variance for reporting metrics.
res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
# Apply stochdepth if applicable.
if self._has_stochdepth:
out = self.stoch_depth(out, is_training)
# SkipInit Gain
out = out * hk.get_parameter('skip_gain', (), out.dtype,
init=self.skipinit_gain)
return out * self.alpha + shortcut, res_avg_var
def count_flops(self, h, w):
# Count conv FLOPs based on input HW
expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
# If block is strided we decrease resolution here.
dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
if self.stride > 1:
h, w = h / self.stride, w / self.stride
if self.use_projection:
sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
else:
sc_flops = 0
# SE flops happen on avg-pooled activations (zero if SE is disabled)
if self.use_se:
se_flops = self.se.fc0.output_size * self.width
se_flops += self.se.fc0.output_size * self.se.fc1.output_size
else:
se_flops = 0
contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops])

455
nfnets/optim.py Normal file

@@ -0,0 +1,455 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizers and Schedulers, inspired by the PyTorch API."""
from collections import ChainMap # pylint:disable=g-importing-member
from typing import Callable
import haiku as hk
import jax
import jax.numpy as jnp
import tree
from nfnets import utils
class Optimizer(object):
"""Optimizer base class."""
def __init__(self, params, defaults):
# Flag indicating if parameters have been broadcasted
self._broadcasted = False
# Optimizer hyperparameters; this is a dict to support using param_groups
self._hyperparameters = {}
# Mapping from model parameters to optimizer hyperparameters
self._params2hyperparams = {}
# Assign defaults
self._hyperparameters = dict(**defaults)
# Prepare parameter groups and mappings
self.create_param_groups(params, defaults)
# Join params at top-level if params is a list of groups
if isinstance(params, list):
flatmap = type(hk.data_structures.to_immutable_dict({}))
if any([isinstance(group['params'], flatmap) for group in params]):
params = hk.data_structures.merge(*[group['params']
for group in params])
else:
params = dict(ChainMap(*[group['params'] for group in params]))
# Prepare states
create_buffers = lambda k, v: self.create_buffers('/'.join(k), v)
self._states = tree.map_structure_with_path(create_buffers, params)
def add_hyperparam_group(self, group, suffix, defaults):
"""Adds new hyperparameters to the hyperparams dict."""
# Use default hyperparams unless overridden by group hyperparams
group_dict = {key: key for key in defaults if key not in group}
for key in group:
if key != 'params': # Reserved keyword 'params'
group_dict[key] = '%s_%s' % (key, suffix)
self._hyperparameters[group_dict[key]] = group[key]
# Set up params2hyperparams
def set_p2h(k, _):
self._params2hyperparams['/'.join(k)] = group_dict
tree.map_structure_with_path(set_p2h, group['params'])
def create_param_groups(self, params, defaults):
"""Creates param-hyperparam mappings."""
if isinstance(params, list):
for group_index, group in enumerate(params):
# Add group to hyperparams and get this group's full hyperparameters
self.add_hyperparam_group(group, group_index, defaults)
else:
mapping = {key: key for key in self._hyperparameters}
def set_p2h(k, _):
self._params2hyperparams['/'.join(k)] = mapping
tree.map_structure_with_path(set_p2h, params)
def create_buffers(self, name, params):
"""Method to be overridden by child classes."""
pass
def get_opt_params(self, param_name, itr):
"""Returns hyperparams corresponding to param_name."""
mapping = self._params2hyperparams[param_name]
output = {}
for key in mapping:
hyper = self._hyperparameters[mapping[key]]
# Handle the case where a hyper is a class, for hybrids
if isinstance(hyper, Callable) and not isinstance(hyper, type):
output[key] = hyper(itr)
else:
output[key] = hyper
return output
def get_hyper(self, param_name, hyper_name):
"""Get an individual hyperparam for a given param."""
mapping = self._params2hyperparams[param_name]
return self._hyperparameters[mapping[hyper_name]]
def plugin(self, states):
self._states = states
def states(self):
return self._states
def broadcast(self):
"""Brodcasts all buffers and parameters."""
self._broadcasted = True
for name, state in self._states.items():
self._states[name] = {key: utils.broadcast(state[key]) for key in state}
def gather(self):
"""Gathers state (if broadcasted) for saving."""
states = {}
for name in self._states:
state = self._states[name]
states[name] = {key: state[key] if state[key] is None else state[key][0]
for key in state}
return states
def __setattr__(self, name, value):
"""Overrides the object's set-attribute function to register states, etc."""
if '_hyperparameters' in self.__dict__ and name in self._hyperparameters:
self._hyperparameters[name] = value
elif '_states' in self.__dict__ and name in self._states:
self._states[name] = value
else:
object.__setattr__(self, name, value)
def __getattr__(self, name):
"""Override the object's get-attribute function to return states, etc."""
if '_hyperparameters' in self.__dict__ and name in self._hyperparameters:
return self._hyperparameters[name]
elif '_states' in self.__dict__ and name in self._states:
return self._states[name]
else:
raise AttributeError(name)
def step(self, params, grads, states, itr=None):
"""Takes a single optimizer step.
Args:
params: a dict containing the parameters to be updated.
grads: a dict containing the gradients for each parameter in params.
states: a dict containing any optimizer buffers (momentum, etc) for
each parameter in params.
itr: an optional integer indicating the current step, for scheduling.
Returns:
The updated params and optimizer buffers.
"""
get_hyper = lambda k, v: self.get_opt_params('/'.join(k), itr)
hypers = tree.map_structure_with_path(get_hyper, params)
outs = tree.map_structure_up_to(params, self.update_param,
params, grads, states, hypers)
return utils.split_tree(outs, params, 2)
class Schedule(object):
"""Hyperparameter scheduling objects."""
class CosineDecay(Schedule):
"""Cosine decay."""
def __init__(self, min_val, max_val, num_steps):
self.min_val = min_val
self.max_val = max_val
self.num_steps = num_steps
def __call__(self, itr):
cos = (1 + jnp.cos(jnp.pi * itr / self.num_steps))
return 0.5 * (self.max_val - self.min_val) * cos + self.min_val
class WarmupCosineDecay(Schedule):
"""Cosine decay with linear warmup."""
def __init__(self, start_val, min_val, max_val, num_steps, warmup_steps):
self.start_val = start_val
self.min_val = min_val
self.max_val = max_val
self.num_steps = num_steps
self.warmup_steps = warmup_steps
def __call__(self, itr):
warmup_val = ((self.max_val - self.start_val) * (itr / self.warmup_steps)
+ self.start_val)
cos_itr = (itr - self.warmup_steps) / (self.num_steps - self.warmup_steps)
cos = 1 + jnp.cos(jnp.pi * cos_itr)
cos_val = 0.5 * (self.max_val - self.min_val) * cos + self.min_val
# Select warmup_val if itr < warmup, else cosine val
values = jnp.array([warmup_val, cos_val])
index = jnp.sum(jnp.array(self.warmup_steps) < itr)
return jnp.take(values, index)
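# Worked example (illustrative): WarmupCosineDecay(start_val=0, min_val=0,
# max_val=0.4, num_steps=1000, warmup_steps=100) returns 0.2 at itr=50 (halfway
# through the linear warmup), 0.4 at itr=100, and then follows a half-cosine
# down to min_val=0 at itr=1000.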
class WarmupExpDecay(Schedule):
"""Exponential step decay with linear warmup."""
def __init__(self, start_val, max_val, warmup_steps,
decay_factor, decay_interval):
self.start_val = start_val
self.max_val = max_val
self.warmup_steps = warmup_steps
self.decay_factor = decay_factor
self.decay_interval = decay_interval
def __call__(self, itr):
warmup_val = ((self.max_val - self.start_val) * (itr / self.warmup_steps)
+ self.start_val)
# How many decay steps have we taken?
num_decays = jnp.floor((itr - self.warmup_steps) / self.decay_interval)
exp_val = self.max_val * (self.decay_factor ** num_decays)
# Select warmup_val if itr < warmup, else exp_val
values = jnp.array([warmup_val, exp_val])
index = jnp.sum(jnp.array(self.warmup_steps) < itr)
return jnp.take(values, index)
class SGD(Optimizer):
"""Standard SGD with (nesterov) momentum and weight decay.
Attributes:
params: Either a dict mapping param names to JAX tensors, or a list where
each member of the list is a dict containing parameters
and hyperparameters, allowing one to specify param-specific hyperparams.
lr: Learning rate.
weight_decay: Weight decay parameter. Note that this is decay, not L2 reg.
momentum: Momentum parameter
dampening: Dampening parameter
nesterov: Bool indicating this optimizer will use the NAG formulation.
"""
defaults = {'weight_decay': None, 'momentum': None, 'dampening': 0,
'nesterov': None}
def __init__(self, params, lr, weight_decay=None,
momentum=None, dampening=0, nesterov=None):
super().__init__(
params, defaults={'lr': lr, 'weight_decay': weight_decay,
'momentum': momentum, 'dampening': dampening,
'nesterov': nesterov})
def create_buffers(self, name, param):
"""Prepares all momentum buffers for each parameter."""
state = {'step': jnp.zeros(jax.local_device_count())}
if self.get_hyper(name, 'momentum') is not None:
state['momentum'] = jnp.zeros_like(param)
return state
def update_param(self, param, grad, state, opt_params):
"""The actual update step for this optimizer."""
if param is None:
return param, state
# Apply weight decay
if opt_params.get('weight_decay') is not None:
grad = grad + param * opt_params['weight_decay']
# Update momentum buffers if needed
if 'momentum' in state:
state['momentum'] = (opt_params['momentum'] * state['momentum']
+ (1 - opt_params['dampening']) * grad)
if opt_params['nesterov'] is not None:
grad = grad + opt_params['momentum'] * state['momentum']
else:
grad = state['momentum']
state['step'] += 1
return param - opt_params['lr'] * grad, state
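# Usage sketch (mirrors how experiment.py builds its optimizer): parameter
# groups let gains/biases skip weight decay, e.g.
#   opt = SGD([{'params': gains_biases, 'weight_decay': None},
#              {'params': weights}],
#             lr=0.4, momentum=0.9, nesterov=True, weight_decay=5e-5)
#   opt_state = opt.states()
#   new_params, new_states = opt.step(params, grads, opt_state, itr)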
class Adam(Optimizer):
"""Adam optimizer, Kingma & Ba, arxiv.org/abs/1412.6980.
Args:
params (iterable): nested list of params to optimize
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (default: 0)
use_adamw (bool, optional): If not None, use decoupled weight decay
as in arxiv.org/abs/1711.05101. The paper version adds an additional
"schedule" hyperparameter eta, which we instead just replace with the
learning rate following the PyTorch implementation.
Note that this implementation will not instantiate a buffer if the
beta term for that buffer is passed in as None, thus conserving memory.
"""
defaults = {'beta1': 0.9, 'beta2': 0.999, 'weight_decay': None, 'eps': 1e-8,
'use_adamw': None}
def __init__(self, params, lr, beta1=0.9, beta2=0.999,
eps=1e-8, weight_decay=None, use_adamw=None):
super().__init__(params=params,
defaults={'lr': lr, 'beta1': beta1,
'beta2': beta2, 'eps': eps,
'weight_decay': weight_decay,
'use_adamw': use_adamw})
def create_buffers(self, name, param):
"""Prepare exp_avg and exp_avg_sq buffers."""
state = {'step': jnp.zeros(jax.local_device_count())}
if self.get_hyper(name, 'beta1') is not None:
state['exp_avg'] = jnp.zeros_like(param)
if self.get_hyper(name, 'beta2') is not None:
state['exp_avg_sq'] = jnp.zeros_like(param)
return state
def update_param(self, param, grad, state, opt_params):
"""The actual update step for this optimizer."""
if param is None:
return param, state
state['step'] = state['step'] + 1
# Apply weight decay
if opt_params.get('weight_decay') is not None:
if opt_params.get('use_adamw') is not None:
param = param * (1 - opt_params['lr'] * opt_params['weight_decay'])
else:
grad = grad + param * opt_params['weight_decay']
# First moment
if 'exp_avg' in state:
bias_correction1 = 1 - opt_params['beta1'] ** state['step']
state['exp_avg'] = (state['exp_avg'] * opt_params['beta1']
+ (1 - opt_params['beta1']) * grad)
step_size = opt_params['lr'] * state['exp_avg'] / bias_correction1
else:
step_size = opt_params['lr'] * grad
# Second moment
if 'exp_avg_sq' in state:
bias_correction2 = 1 - opt_params['beta2'] ** state['step']
state['exp_avg_sq'] = (state['exp_avg_sq'] * opt_params['beta2']
+ (1 - opt_params['beta2']) * grad * grad)
denom = jnp.sqrt(state['exp_avg_sq']) * jax.lax.rsqrt(bias_correction2)
denom = denom + opt_params['eps']
else:
denom = jnp.abs(grad) + opt_params['eps'] # Add eps to avoid divide-by-0
return param - step_size / denom, state
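# Illustrative sketch (not part of the original file): passing any non-None
# value for `use_adamw` switches on decoupled (AdamW-style) weight decay;
# `params` and the values below are hypothetical.
#   opt = Adam(params, lr=1e-3, weight_decay=1e-2, use_adamw=True)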
class RMSProp(Optimizer):
"""RMSProp optimizer, Tieleman and Hinton, ref: powerpoint slides.
Implements RMSProp as
rms = decay * rms{t-1} + (1-decay) * gradient ** 2
mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(rms + epsilon)
param -= mom
  Note that the rms buffer is initialized with ones as in TF, as opposed to
  zeros as in most other implementations.
Args:
params (iterable): nested list of params to optimize
    lr (float): learning rate.
decay (float): EMA decay rate for running estimate of squared gradient.
momentum (float or None): Use heavy ball momentum instead of instant grad.
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
    weight_decay (float, optional): standard coupled weight decay, not
      AdamW-style (default: None)
"""
defaults = {'weight_decay': None, 'eps': 1e-8}
def __init__(self, params, lr, decay, momentum, weight_decay=None, eps=1e-8):
super().__init__(params=params,
defaults={'lr': lr, 'decay': decay,
'momentum': momentum, 'eps': eps,
'weight_decay': weight_decay})
def create_buffers(self, name, param):
"""Prepare exp_avg and exp_avg_sq buffers."""
state = {'step': jnp.zeros(jax.local_device_count())}
state['rms'] = jnp.ones_like(param)
if self.get_hyper(name, 'momentum') is not None:
state['momentum'] = jnp.zeros_like(param)
return state
def update_param(self, param, grad, state, opt_params):
"""The actual update step for this optimizer."""
if param is None:
return param, state
state['step'] = state['step'] + 1
# Apply weight decay
if opt_params.get('weight_decay') is not None:
grad = grad + param * opt_params['weight_decay']
# EMA of the squared gradient
state['rms'] = (state['rms'] * opt_params['decay']
+ (1 - opt_params['decay']) * (grad ** 2))
scaled_grad = (opt_params['lr'] * grad
/ (state['rms'] + opt_params['eps']) ** 0.5)
    # The momentum buffer only exists if a momentum value was provided.
    if 'momentum' in state:
state['momentum'] = (state['momentum'] * opt_params['momentum']
+ scaled_grad)
step_size = state['momentum']
else:
step_size = scaled_grad
return param - step_size, state
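# Illustrative sketch (not part of the original file): `decay` and `momentum`
# are required positional args here; pass momentum=None to skip the heavy-ball
# buffer. `params` and the values below are hypothetical.
#   opt = RMSProp(params, lr=0.01, decay=0.9, momentum=0.9, eps=1e-8)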
class Fromage(Optimizer):
"""Fromage optimizer, Bernstein et al. arXiv.org/abs/2002.03432.
This version optionally includes weight decay.
Attributes:
params (iterable): nested list of params to optimize
lr (float): learning rate.
    eps (float, optional): Minimum allowable norm; required in case
      parameters are zero-initialized (default: 1e-5).
    weight_decay (float, optional): weight decay (default: None).
"""
defaults = {'weight_decay': None, 'eps': 1e-5}
def __init__(self, params, lr, weight_decay=None, eps=1e-5):
super().__init__(
params, defaults={'lr': lr, 'weight_decay': weight_decay, 'eps': eps})
def create_buffers(self, name, param): # pylint: disable=unused-argument
"""Prepares all momentum buffers for each parameter."""
return {'step': jnp.zeros(1)}
def update_param(self, param, grad, state, opt_params):
"""The actual update step for this optimizer."""
if param is None:
return param, state
if opt_params['weight_decay'] is not None:
grad = grad + param * opt_params['weight_decay']
grad_norm = jnp.maximum(jnp.linalg.norm(grad), opt_params['eps'])
param_norm = jnp.maximum(jnp.linalg.norm(param), opt_params['eps'])
mult = jax.lax.rsqrt(1 + opt_params['lr'] ** 2)
out = (param - opt_params['lr'] * grad * (param_norm / grad_norm)) * mult
return out, state
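# Illustrative sketch (not part of the original file): Fromage only needs a
# learning rate; `params` is hypothetical.
#   opt = Fromage(params, lr=0.01)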
class Hybrid(Optimizer):
"""Optimizer which permits passing param groups with different base opts.
  The API for this class follows that of any other optimizer where one
  specifies a list of dicts with separate hyperparams, but in this case it
  requires the user to also specify an 'opt' key for each group, e.g.
  [{'params': params0, 'opt': optim.Adam, 'lr': 0.1}].
  The user must also provide values for any arg in the selected optimizers
  which does not have an associated default value.
"""
def __init__(self, param_groups):
if any(['opt' not in group for group in param_groups]):
raise ValueError('All parameter groups must have an opt key!')
self.defaults = ChainMap(*[group['opt'].defaults for group in param_groups])
super().__init__(param_groups, defaults=dict(self.defaults))
def create_buffers(self, name, param):
return self.get_hyper(name, 'opt').create_buffers(self, name, param)
def update_param(self, param, grad, state, opt_params):
return opt_params['opt'].update_param(self, param, grad, state, opt_params)
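# Illustrative sketch (not part of the original file): a Hybrid optimizer with
# different base optimizers per parameter group; the split into `conv_params`
# and `fc_params` is hypothetical.
#   opt = Hybrid([
#       {'params': conv_params, 'opt': SGD, 'lr': 0.1, 'momentum': 0.9,
#        'nesterov': True},
#       {'params': fc_params, 'opt': Adam, 'lr': 1e-3}])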

14
nfnets/requirements.txt Normal file
View File

@@ -0,0 +1,14 @@
absl-py==0.10.0
chex>=0.0.2
dm-haiku>=0.0.3
jax>=0.2.8
jaxlib>=0.1.58
jaxline>=0.0.1
ml_collections>=0.1
numpy>=1.18.0
tensorflow>=2.3.1
tensorflow-addons>=0.12.0
tensorflow-datasets>=4.1.0
tensorflow-probability>=0.12.1
typing_extensions>=3.7
wrapt>=1.11.2

183
nfnets/resnet.py Normal file
View File

@@ -0,0 +1,183 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet model family."""
import functools
import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base
class ResNet(hk.Module):
"""ResNetv2 Models."""
variant_dict = {'ResNet50': {'depth': [3, 4, 6, 3]},
'ResNet101': {'depth': [3, 4, 23, 3]},
'ResNet152': {'depth': [3, 8, 36, 3]},
'ResNet200': {'depth': [3, 24, 36, 3]},
'ResNet288': {'depth': [24, 24, 24, 24]},
'ResNet600': {'depth': [50, 50, 50, 50]},
}
def __init__(self, width, num_classes,
variant='ResNet50',
which_norm='BatchNorm', norm_kwargs=None,
activation='relu', drop_rate=0.0,
fc_init=jnp.zeros, conv_kwargs=None,
preactivation=True, use_se=False, se_ratio=0.25,
name='ResNet'):
super().__init__(name=name)
self.width = width
self.num_classes = num_classes
self.variant = variant
self.depth_pattern = self.variant_dict[variant]['depth']
self.activation = getattr(jax.nn, activation)
self.drop_rate = drop_rate
self.which_norm = getattr(hk, which_norm)
if norm_kwargs is not None:
self.which_norm = functools.partial(self.which_norm, **norm_kwargs)
if conv_kwargs is not None:
self.which_conv = functools.partial(hk.Conv2D, **conv_kwargs)
else:
self.which_conv = hk.Conv2D
self.preactivation = preactivation
# Stem
self.initial_conv = self.which_conv(16 * self.width, kernel_shape=7,
stride=2, padding='SAME',
with_bias=False, name='initial_conv')
if not self.preactivation:
self.initial_bn = self.which_norm(name='initial_bn')
which_block = ResBlockV2 if self.preactivation else ResBlockV1
# Body
self.blocks = []
for multiplier, blocks_per_stage, stride in zip([64, 128, 256, 512],
self.depth_pattern,
[1, 2, 2, 2]):
for block_index in range(blocks_per_stage):
self.blocks += [which_block(multiplier * self.width,
use_projection=block_index == 0,
stride=stride if block_index == 0 else 1,
activation=self.activation,
which_norm=self.which_norm,
which_conv=self.which_conv,
use_se=use_se,
se_ratio=se_ratio)]
# Head
self.final_bn = self.which_norm(name='final_bn')
self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)
def __call__(self, x, is_training, test_local_stats=False,
return_metrics=False):
"""Return the output of the final layer without any [log-]softmax."""
outputs = {}
# Stem
out = self.initial_conv(x)
if not self.preactivation:
out = self.activation(self.initial_bn(out, is_training, test_local_stats))
out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
strides=(1, 2, 2, 1), padding='SAME')
if return_metrics:
outputs.update(base.signal_metrics(out, 0))
# Blocks
for i, block in enumerate(self.blocks):
out, res_var = block(out, is_training, test_local_stats)
if return_metrics:
outputs.update(base.signal_metrics(out, i + 1))
outputs[f'res_avg_var_{i}'] = res_var
if self.preactivation:
out = self.activation(self.final_bn(out, is_training, test_local_stats))
# Pool, dropout, classify
pool = jnp.mean(out, axis=[1, 2])
# Return pool before dropout in case we want to regularize it separately.
outputs['pool'] = pool
# Optionally apply dropout
if self.drop_rate > 0.0 and is_training:
pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
outputs['logits'] = self.fc(pool)
return outputs
class ResBlockV2(hk.Module):
"""ResNet preac block, 1x1->3x3->1x1 with strides and shortcut downsample."""
def __init__(self, out_ch, stride=1, use_projection=False,
activation=jax.nn.relu, which_norm=hk.BatchNorm,
which_conv=hk.Conv2D, use_se=False, se_ratio=0.25,
name=None):
super().__init__(name=name)
self.out_ch = out_ch
self.stride = stride
self.use_projection = use_projection
self.activation = activation
self.which_norm = which_norm
self.which_conv = which_conv
self.use_se = use_se
self.se_ratio = se_ratio
self.width = self.out_ch // 4
self.bn0 = which_norm(name='bn0')
self.conv0 = which_conv(self.width, kernel_shape=1, with_bias=False,
padding='SAME', name='conv0')
self.bn1 = which_norm(name='bn1')
self.conv1 = which_conv(self.width, stride=self.stride,
kernel_shape=3, with_bias=False,
padding='SAME', name='conv1')
self.bn2 = which_norm(name='bn2')
self.conv2 = which_conv(self.out_ch, kernel_shape=1, with_bias=False,
padding='SAME', name='conv2')
if self.use_projection:
self.conv_shortcut = which_conv(self.out_ch, stride=stride,
kernel_shape=1, with_bias=False,
padding='SAME', name='conv_shortcut')
if self.use_se:
self.se = base.SqueezeExcite(self.out_ch, self.out_ch, self.se_ratio)
def __call__(self, x, is_training, test_local_stats):
bn_args = (is_training, test_local_stats)
out = self.activation(self.bn0(x, *bn_args))
if self.use_projection:
shortcut = self.conv_shortcut(out)
else:
shortcut = x
out = self.conv0(out)
out = self.conv1(self.activation(self.bn1(out, *bn_args)))
out = self.conv2(self.activation(self.bn2(out, *bn_args)))
if self.use_se:
out = self.se(out) * out
    # Get average residual variance for reporting metrics.
res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
return out + shortcut, res_avg_var
class ResBlockV1(ResBlockV2):
"""Post-Ac Residual Block."""
def __call__(self, x, is_training, test_local_stats):
bn_args = (is_training, test_local_stats)
if self.use_projection:
shortcut = self.conv_shortcut(x)
shortcut = self.which_norm(name='shortcut_bn')(shortcut, *bn_args)
else:
shortcut = x
out = self.activation(self.bn0(self.conv0(x), *bn_args))
out = self.activation(self.bn1(self.conv1(out), *bn_args))
out = self.bn2(self.conv2(out), *bn_args)
if self.use_se:
out = self.se(out) * out
res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
return self.activation(out + shortcut), res_avg_var
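# Illustrative sketch (not part of the original file): ResNet is a haiku
# module with BatchNorm state, so it would typically be built inside
# hk.transform_with_state. The width, norm_kwargs and shapes below are
# hypothetical.
#   def forward(images, is_training):
#     model = ResNet(width=4, num_classes=1000, variant='ResNet50',
#                    norm_kwargs={'create_scale': True, 'create_offset': True,
#                                 'decay_rate': 0.9})
#     return model(images, is_training)['logits']
#   forward_fn = hk.transform_with_state(forward)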

22
nfnets/run.sh Normal file
View File

@@ -0,0 +1,22 @@
#!/bin/sh
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
python3 -m venv nfnets_venv
. nfnets_venv/bin/activate  # POSIX-compatible equivalent of 'source'
pip3 install --upgrade setuptools wheel
pip3 install -r nfnets/requirements.txt
python3 -m nfnets.test

194
nfnets/skipinit_resnet.py Normal file
View File

@@ -0,0 +1,194 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNetV2 (Pre-activation) with SkipInit."""
# pylint: disable=invalid-name
import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base
# Nonlinearities
nonlinearities = {
'swish': jax.nn.silu,
'relu': jax.nn.relu,
'identity': lambda x: x}
class SkipInit_ResNet(hk.Module):
"""Skip-Init based ResNet."""
variant_dict = {'ResNet50': {'depth': [3, 4, 6, 3]},
'ResNet101': {'depth': [3, 4, 23, 3]},
'ResNet152': {'depth': [3, 8, 36, 3]},
'ResNet200': {'depth': [3, 24, 36, 3]},
'ResNet288': {'depth': [24, 24, 24, 24]},
'ResNet600': {'depth': [50, 50, 50, 50]},
}
def __init__(self, num_classes, variant='ResNet50', width=4,
stochdepth_rate=0.1, drop_rate=None,
activation='relu', fc_init=jnp.zeros,
name='SkipInit_ResNet'):
super().__init__(name=name)
self.num_classes = num_classes
self.variant = variant
self.width = width
# Get variant info
block_params = self.variant_dict[self.variant]
self.width_pattern = [item * self.width for item in [64, 128, 256, 512]]
self.depth_pattern = block_params['depth']
self.activation = nonlinearities[activation]
    if drop_rate is None:
      # This model's variant_dict does not define drop_rate, so fall back to
      # 0.0 (no dropout) instead of raising a KeyError.
      self.drop_rate = block_params.get('drop_rate', 0.0)
    else:
      self.drop_rate = drop_rate
self.which_conv = hk.Conv2D
# Stem
ch = int(16 * self.width)
self.initial_conv = self.which_conv(ch, kernel_shape=7, stride=2,
padding='SAME', with_bias=False,
name='initial_conv')
# Body
self.blocks = []
num_blocks = sum(self.depth_pattern)
index = 0 # Overall block index
block_args = (self.width_pattern, self.depth_pattern, [1, 2, 2, 2])
for block_width, stage_depth, stride in zip(*block_args):
for block_index in range(stage_depth):
# Block stochastic depth drop-rate
block_stochdepth_rate = stochdepth_rate * index / num_blocks
self.blocks += [NFResBlock(ch, block_width,
stride=stride if block_index == 0 else 1,
activation=self.activation,
which_conv=self.which_conv,
stochdepth_rate=block_stochdepth_rate,
)]
ch = block_width
index += 1
# Head
self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)
def __call__(self, x, is_training=True, return_metrics=False):
"""Return the output of the final layer without any [log-]softmax."""
# Stem
outputs = {}
out = self.initial_conv(x)
out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
strides=(1, 2, 2, 1), padding='SAME')
if return_metrics:
outputs.update(base.signal_metrics(out, 0))
# Blocks
for i, block in enumerate(self.blocks):
out, res_avg_var = block(out, is_training=is_training)
if return_metrics:
outputs.update(base.signal_metrics(out, i + 1))
outputs[f'res_avg_var_{i}'] = res_avg_var
    # Final activation, pool, dropout, classify (this model has no final conv)
pool = jnp.mean(self.activation(out), [1, 2])
outputs['pool'] = pool
# Optionally apply dropout
if self.drop_rate > 0.0 and is_training:
pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
outputs['logits'] = self.fc(pool)
return outputs
def count_flops(self, h, w):
flops = []
flops += [base.count_conv_flops(3, self.initial_conv, h, w)]
h, w = h / 2, w / 2
# Body FLOPs
for block in self.blocks:
flops += [block.count_flops(h, w)]
if block.stride > 1:
h, w = h / block.stride, w / block.stride
    # Head FLOPs: this model has no final conv, so only count the classifier.
    out_ch = self.blocks[-1].out_ch
    flops += [out_ch * self.fc.output_size]
    return flops, sum(flops)
class NFResBlock(hk.Module):
"""Normalizer-Free pre-activation ResNet Block."""
def __init__(self, in_ch, out_ch, bottleneck_ratio=0.25,
kernel_size=3, stride=1,
which_conv=hk.Conv2D, activation=jax.nn.relu,
stochdepth_rate=None, name=None):
super().__init__(name=name)
self.in_ch, self.out_ch = in_ch, out_ch
self.kernel_size = kernel_size
self.activation = activation
# Bottleneck width
self.width = int(self.out_ch * bottleneck_ratio)
self.stride = stride
# Conv 0 (typically expansion conv)
self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
name='conv0')
    # NxN conv (not grouped in this variant)
self.conv1 = which_conv(self.width, kernel_shape=kernel_size, stride=stride,
padding='SAME', name='conv1')
# Conv 2, typically projection conv
self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
name='conv2')
# Use shortcut conv on channel change or downsample.
self.use_projection = stride > 1 or self.in_ch != self.out_ch
if self.use_projection:
self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
stride=stride, padding='SAME',
name='conv_shortcut')
# Are we using stochastic depth?
self._has_stochdepth = (stochdepth_rate is not None and
stochdepth_rate > 0. and stochdepth_rate < 1.0)
if self._has_stochdepth:
self.stoch_depth = base.StochDepth(stochdepth_rate)
def __call__(self, x, is_training):
out = self.activation(x)
shortcut = x
if self.use_projection: # Downsample with conv1x1
shortcut = self.conv_shortcut(out)
out = self.conv0(out)
out = self.conv1(self.activation(out))
out = self.conv2(self.activation(out))
    # Get average residual variance for reporting metrics.
res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
# Apply stochdepth if applicable.
if self._has_stochdepth:
out = self.stoch_depth(out, is_training)
# SkipInit Gain
out = out * hk.get_parameter('skip_gain', (), out.dtype, init=jnp.zeros)
return out + shortcut, res_avg_var
def count_flops(self, h, w):
# Count conv FLOPs based on input HW
expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
# If block is strided we decrease resolution here.
dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
if self.stride > 1:
h, w = h / self.stride, w / self.stride
if self.use_projection:
sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
else:
sc_flops = 0
    # Contraction conv FLOPs
contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
return sum([expand_flops, dw_flops, contract_flops, sc_flops])
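# Illustrative sketch (not part of the original file): SkipInit_ResNet is
# normalizer-free (no BatchNorm state), so plain hk.transform suffices; the
# hyperparameters below are hypothetical.
#   def forward(images, is_training):
#     model = SkipInit_ResNet(num_classes=1000, variant='ResNet50', width=4,
#                             drop_rate=0.0)
#     return model(images, is_training=is_training)['logits']
#   forward_fn = hk.transform(forward)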

35
nfnets/test.py Normal file
View File

@@ -0,0 +1,35 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Quick script to test that experiment can import and run."""
import jax
import jax.numpy as jnp
from nfnets import experiment
config = experiment.get_config()
exp_config = config.experiment_kwargs.config
exp_config.train_batch_size = 2
exp_config.eval_batch_size = 2
exp_config.lr = 0.1
exp_config.fake_data = True
exp_config.model_kwargs.width = 2
print(exp_config.model_kwargs)
xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
bcast = jax.pmap(lambda x: x)
global_step = bcast(jnp.zeros(jax.local_device_count()))
rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
print('Taking a single experiment step for test purposes!')
result = xp.step(global_step, rng)
print(f'Step successfully taken, resulting metrics are {result}')

107
nfnets/utils.py Normal file
View File

@@ -0,0 +1,107 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils."""
import jax
import jax.numpy as jnp
import tree
def reduce_fn(x, mode):
"""Reduce fn for various losses."""
if mode == 'none' or mode is None:
return jnp.asarray(x)
elif mode == 'sum':
return jnp.sum(x)
elif mode == 'mean':
return jnp.mean(x)
else:
raise ValueError('Unsupported reduction option.')
def softmax_cross_entropy(logits, labels, reduction='sum'):
"""Computes softmax cross entropy given logits and one-hot class labels.
Args:
logits: Logit output values.
labels: Ground truth one-hot-encoded labels.
reduction: Type of reduction to apply to loss.
Returns:
Loss value. If `reduction` is `none`, this has the same shape as `labels`;
otherwise, it is scalar.
Raises:
ValueError: If the type of `reduction` is unsupported.
"""
loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
return reduce_fn(loss, reduction)
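# Illustrative sketch (not part of the original file): typical use with one-hot
# labels; the batch size and class count are hypothetical.
#   logits = jnp.zeros([8, 1000])
#   labels = jax.nn.one_hot(jnp.arange(8) % 1000, 1000)
#   loss = softmax_cross_entropy(logits, labels, reduction='mean')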
def topk_correct(logits, labels, mask=None, prefix='', topk=(1, 5)):
"""Calculate top-k error for multiple k values."""
metrics = {}
argsorted_logits = jnp.argsort(logits)
for k in topk:
pred_labels = argsorted_logits[..., -k:]
# Get the number of examples where the label is in the top-k predictions
correct = any_in(pred_labels, labels).any(axis=-1).astype(jnp.float32)
if mask is not None:
correct *= mask
metrics[f'{prefix}top_{k}_acc'] = correct
return metrics
@jax.vmap
def any_in(prediction, target):
"""For each row in a and b, checks if any element of a is in b."""
return jnp.isin(prediction, target)
def tf1_ema(ema_value, current_value, decay, step):
"""Implements EMA with TF1-style decay warmup."""
decay = jnp.minimum(decay, (1.0 + step) / (10.0 + step))
return ema_value * decay + current_value * (1 - decay)
def ema(ema_value, current_value, decay, step):
"""Implements EMA without any warmup."""
del step
return ema_value * decay + current_value * (1 - decay)
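# Illustrative sketch (not part of the original file): tf1_ema ramps the
# effective decay from ~0.1 at step 0 up towards `decay`, so early in training
# the EMA tracks the current parameters rather than their arbitrary initial
# values; the trees and values below are hypothetical.
#   ema_params = jax.tree_map(
#       lambda e, p: tf1_ema(e, p, decay=0.9999, step=step), ema_params, params)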
to_bf16 = lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x
from_bf16 = lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x
def _replicate(x, devices=None):
"""Replicate an object on each device."""
x = jax.numpy.array(x)
if devices is None:
devices = jax.local_devices()
return jax.api.device_put_sharded(len(devices) * [x], devices)
def broadcast(obj):
"""Broadcasts an object to all devices."""
if obj is not None and not isinstance(obj, bool):
return _replicate(obj)
else:
return obj
def split_tree(tuple_tree, base_tree, n):
"""Splits tuple_tree with n-tuple leaves into n trees."""
return [tree.map_structure_up_to(base_tree, lambda x: x[i], tuple_tree) # pylint: disable=cell-var-from-loop
for i in range(n)]
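# Illustrative sketch (not part of the original file): split_tree separates a
# tree whose leaves are n-tuples into n trees, e.g. unpacking per-leaf
# (update, new_state) pairs; `packed_tree` and `base_tree` are hypothetical.
#   updates, new_states = split_tree(packed_tree, base_tree, n=2)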