mirror of
https://github.com/google-deepmind/deepmind-research.git
Initial release of "nfnets".
PiperOrigin-RevId: 356975781
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/

## Projects

* [Characterizing signal propagation to close the performance gap in unnormalized ResNets](nfnets), ICLR 2021
* [Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples](adversarial_robustness)
* [Functional Regularisation for Continual Learning](functional_regularisation_for_continual_learning), ICLR 2020
* [Self-Supervised MultiModal Versatile Networks](mmv), NeurIPS 2020

nfnets/README.md (new file, 29 lines)
@@ -0,0 +1,29 @@
# Code for Normalizer-Free ResNets (Brock, De, Smith, ICLR 2021)

This repository contains code for the ICLR 2021 paper
"Characterizing signal propagation to close the performance gap in unnormalized
ResNets," by Andrew Brock, Soham De, and Samuel L. Smith.


## Running this code

Install the dependencies with `pip install -r requirements.txt`, then use one of
the `experiment.py` files in combination with
[JAXline](https://github.com/deepmind/jaxline) to train models. Optionally, copy
`test.py` into a directory one level up and run it to check that you can take a
single experiment step with fake data.

Note that you will need a local copy of ImageNet compatible with the TFDS format
used in `dataset.py` in order to train on ImageNet.
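
For example, a typical launch (a sketch only, assuming JAXline's standard
`--config` flag and that you run from the directory containing `nfnets/`) might
look like:

```
python -m nfnets.experiment --config=nfnets/experiment.py
```

For the NF-RegNet configuration, point `--config` at
`nfnets/experiment_nf_regnets.py` instead.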

## Giving Credit

If you use this code in your work, we ask you to cite this paper:

```
@inproceedings{brock2021characterizing,
  author={Andrew Brock and Soham De and Samuel L. Smith},
  title={Characterizing signal propagation to close the performance gap in
         unnormalized ResNets},
  booktitle={9th International Conference on Learning Representations, {ICLR}},
  year={2021}
}
```
nfnets/autoaugment.py (new file, 714 lines): diff suppressed because it is too large
nfnets/base.py (new file, 176 lines)
@@ -0,0 +1,176 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Architecture definitions for different models."""
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


# Model settings for NF-RegNets
nf_regnet_params = {
    'B0': {'width': [48, 104, 208, 440], 'depth': [1, 3, 6, 6],
           'train_imsize': 192, 'test_imsize': 224,
           'drop_rate': 0.2},
    'B1': {'width': [48, 104, 208, 440], 'depth': [2, 4, 7, 7],
           'train_imsize': 224, 'test_imsize': 256,
           'drop_rate': 0.2},
    'B2': {'width': [56, 112, 232, 488], 'depth': [2, 4, 8, 8],
           'train_imsize': 240, 'test_imsize': 272,
           'drop_rate': 0.3},
    'B3': {'width': [56, 128, 248, 528], 'depth': [2, 5, 9, 9],
           'train_imsize': 288, 'test_imsize': 320,
           'drop_rate': 0.3},
    'B4': {'width': [64, 144, 288, 616], 'depth': [2, 6, 11, 11],
           'train_imsize': 320, 'test_imsize': 384,
           'drop_rate': 0.4},
    'B5': {'width': [80, 168, 336, 704], 'depth': [3, 7, 14, 14],
           'train_imsize': 384, 'test_imsize': 456,
           'drop_rate': 0.4},
    'B6': {'width': [88, 184, 376, 792], 'depth': [3, 8, 16, 16],
           'train_imsize': 448, 'test_imsize': 528,
           'drop_rate': 0.5},
    'B7': {'width': [96, 208, 416, 880], 'depth': [4, 10, 19, 19],
           'train_imsize': 512, 'test_imsize': 600,
           'drop_rate': 0.5},
    'B8': {'width': [104, 232, 456, 968], 'depth': [4, 11, 22, 22],
           'train_imsize': 600, 'test_imsize': 672,
           'drop_rate': 0.5},
}


# Nonlinearities with magic constants (gamma) baked in.
# Note that not all nonlinearities will be stable, especially if they are
# not perfectly monotonic. Good choices include relu, silu, and gelu.
nonlinearities = {
    'identity': lambda x: x,
    'celu': lambda x: jax.nn.celu(x) * 1.270926833152771,
    'elu': lambda x: jax.nn.elu(x) * 1.2716004848480225,
    'gelu': lambda x: jax.nn.gelu(x) * 1.7015043497085571,
    'glu': lambda x: jax.nn.glu(x) * 1.8484294414520264,
    'leaky_relu': lambda x: jax.nn.leaky_relu(x) * 1.70590341091156,
    'log_sigmoid': lambda x: jax.nn.log_sigmoid(x) * 1.9193484783172607,
    'log_softmax': lambda x: jax.nn.log_softmax(x) * 1.0002083778381348,
    'relu': lambda x: jax.nn.relu(x) * 1.7139588594436646,
    'relu6': lambda x: jax.nn.relu6(x) * 1.7131484746932983,
    'selu': lambda x: jax.nn.selu(x) * 1.0008515119552612,
    'sigmoid': lambda x: jax.nn.sigmoid(x) * 4.803835391998291,
    'silu': lambda x: jax.nn.silu(x) * 1.7881293296813965,
    'soft_sign': lambda x: jax.nn.soft_sign(x) * 2.338853120803833,
    'softplus': lambda x: jax.nn.softplus(x) * 1.9203323125839233,
    'tanh': lambda x: jnp.tanh(x) * 1.5939117670059204,
}
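# An illustrative sanity check of the constants above (a sketch, not used in
# training): each gamma appears to be approximately the reciprocal of the
# standard deviation of the nonlinearity's output under a unit Gaussian input,
# which can be estimated by sampling, e.g.:
#
#   key = jax.random.PRNGKey(0)
#   x = jax.random.normal(key, (1024, 1024))
#   gamma_relu = 1. / jnp.std(jax.nn.relu(x))  # roughly 1.71, cf. 'relu' above
#   gamma_tanh = 1. / jnp.std(jnp.tanh(x))     # roughly 1.59, cf. 'tanh' above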


class WSConv2D(hk.Conv2D):
  """2D Convolution with Scaled Weight Standardization and affine gain+bias."""

  @hk.transparent
  def standardize_weight(self, weight, eps=1e-4):
    """Apply scaled WS with affine gain."""
    mean = jnp.mean(weight, axis=(0, 1, 2), keepdims=True)
    var = jnp.var(weight, axis=(0, 1, 2), keepdims=True)
    fan_in = np.prod(weight.shape[:-1])
    # Get gain
    gain = hk.get_parameter('gain', shape=(weight.shape[-1],),
                            dtype=weight.dtype, init=jnp.ones)
    # Manually fused normalization, eq. to (w - mean) * gain / sqrt(N * var)
    scale = jax.lax.rsqrt(jnp.maximum(var * fan_in, eps)) * gain
    shift = mean * scale
    return weight * scale - shift

  def __call__(self, inputs: jnp.ndarray, eps: float = 1e-4) -> jnp.ndarray:
    w_shape = self.kernel_shape + (
        inputs.shape[self.channel_index] // self.feature_group_count,
        self.output_channels)
    # Use fan-in scaled init, but WS is largely insensitive to this choice.
    w_init = hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal')
    w = hk.get_parameter('w', w_shape, inputs.dtype, init=w_init)
    weight = self.standardize_weight(w, eps)
    out = jax.lax.conv_general_dilated(
        inputs, weight, window_strides=self.stride, padding=self.padding,
        lhs_dilation=self.lhs_dilation, rhs_dilation=self.kernel_dilation,
        dimension_numbers=self.dimension_numbers,
        feature_group_count=self.feature_group_count)
    # Always add bias
    bias_shape = (self.output_channels,)
    bias = hk.get_parameter('bias', bias_shape, inputs.dtype, init=jnp.zeros)
    return out + bias
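# Minimal usage sketch for WSConv2D (illustrative only; assumes an NHWC input
# batch). Like any Haiku module, it must be built inside hk.transform:
#
#   def forward(x):
#     return WSConv2D(output_channels=64, kernel_shape=3)(x)
#
#   net = hk.transform(forward)
#   x = jnp.zeros((8, 32, 32, 3))
#   params = net.init(jax.random.PRNGKey(0), x)
#   y = net.apply(params, None, x)  # shape (8, 32, 32, 64)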


def signal_metrics(x, i):
  """Things to measure about a NCHW tensor activation."""
  metrics = {}
  # Average channel-wise mean-squared
  metrics[f'avg_sq_mean_{i}'] = jnp.mean(jnp.mean(x, axis=[0, 1, 2])**2)
  # Average channel variance
  metrics[f'avg_var_{i}'] = jnp.mean(jnp.var(x, axis=[0, 1, 2]))
  return metrics


def count_conv_flops(in_ch, conv, h, w):
  """For a conv layer with in_ch inputs, count the FLOPS."""
  # How many output floats are we producing?
  output_shape = conv.output_channels * (h * w) / np.prod(conv.stride)
  # At each OHW location we do computation equal to (I//G) * kh * kw
  flop_per_loc = (in_ch / conv.feature_group_count)
  flop_per_loc *= np.prod(conv.kernel_shape)
  return output_shape * flop_per_loc


class SqueezeExcite(hk.Module):
  """Simple Squeeze+Excite module."""

  def __init__(self, in_ch, out_ch, se_ratio=0.5,
               hidden_ch=None, activation=jax.nn.relu,
               name=None):
    super().__init__(name=name)
    self.in_ch, self.out_ch = in_ch, out_ch
    if se_ratio is None:
      if hidden_ch is None:
        raise ValueError('Must provide one of se_ratio or hidden_ch')
      self.hidden_ch = hidden_ch
    else:
      self.hidden_ch = max(1, int(self.in_ch * se_ratio))
    self.activation = activation
    self.fc0 = hk.Linear(self.hidden_ch, with_bias=True)
    self.fc1 = hk.Linear(self.out_ch, with_bias=True)

  def __call__(self, x):
    h = jnp.mean(x, axis=[1, 2])  # Mean pool over HW extent
    h = self.fc1(self.activation(self.fc0(h)))
    h = jax.nn.sigmoid(h)[:, None, None]  # Broadcast along H, W
    return h


class StochDepth(hk.Module):
  """Batchwise Dropout used in EfficientNet, optionally sans rescaling."""

  def __init__(self, drop_rate, scale_by_keep=False, name=None):
    super().__init__(name=name)
    self.drop_rate = drop_rate
    self.scale_by_keep = scale_by_keep

  def __call__(self, x, is_training) -> jnp.ndarray:
    if not is_training:
      return x
    batch_size = x.shape[0]
    r = jax.random.uniform(hk.next_rng_key(), [batch_size, 1, 1, 1],
                           dtype=x.dtype)
    keep_prob = 1. - self.drop_rate
    binary_tensor = jnp.floor(keep_prob + r)
    if self.scale_by_keep:
      x = x / keep_prob
    return x * binary_tensor
nfnets/dataset.py (new file, 542 lines): diff suppressed because it is too large
nfnets/experiment.py (new file, 363 lines)
@@ -0,0 +1,363 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Basic Jaxline ImageNet experiment."""

import importlib
import sys
from absl import flags
from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
from jaxline import base_config
from jaxline import experiment
from jaxline import platform
from jaxline import utils as jl_utils
from ml_collections import config_dict
import numpy as np
from nfnets import dataset
from nfnets import optim
from nfnets import utils
# pylint: disable=logging-format-interpolation

FLAGS = flags.FLAGS


# We define the experiment launch config in the same file as the experiment to
# keep things self-contained in a single file, but one might consider moving the
# config and/or sweep functions to a separate file, if necessary.
def get_config():
  """Return config object for training."""
  config = base_config.get_base_config()

  # Experiment config.
  train_batch_size = 1024  # Global batch size.
  images_per_epoch = 1281167
  num_epochs = 90
  steps_per_epoch = images_per_epoch / train_batch_size
  config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
  config.random_seed = 0
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              lr=0.1,
              num_epochs=num_epochs,
              label_smoothing=0.1,
              model='ResNet',
              image_size=224,
              use_ema=False,
              ema_decay=0.9999,  # Four nines
              ema_start=0,
              which_ema='tf1_ema',
              augment_name=None,  # 'mixup_cutmix',
              augment_before_mix=True,
              eval_preproc='crop_resize',
              train_batch_size=train_batch_size,
              eval_batch_size=50,
              eval_subset='test',
              num_classes=1000,
              which_dataset='imagenet',
              fake_data=False,
              which_loss='softmax_cross_entropy',  # For now, must be softmax
              transpose=True,  # Use the double-transpose trick?
              bfloat16=False,
              lr_schedule=dict(
                  name='WarmupCosineDecay',
                  kwargs=dict(num_steps=config.training_steps,
                              start_val=0,
                              min_val=0,
                              warmup_steps=5*steps_per_epoch),
              ),
              lr_scale_by_bs=True,
              optimizer=dict(
                  name='SGD',
                  kwargs={'momentum': 0.9, 'nesterov': True,
                          'weight_decay': 1e-4,},
              ),
              model_kwargs=dict(
                  width=4,
                  which_norm='BatchNorm',
                  norm_kwargs=dict(create_scale=True,
                                   create_offset=True,
                                   decay_rate=0.9,
                                   ),  # cross_replica_axis='i'),
                  variant='ResNet50',
                  activation='relu',
                  drop_rate=0.0,
              ),
          ),))

  # Training loop config: log and checkpoint every minute
  config.log_train_data_interval = 60
  config.log_tensors_interval = 60
  config.save_checkpoint_interval = 60
  config.eval_specific_checkpoint_dir = ''

  return config


class Experiment(experiment.AbstractExperiment):
  """Imagenet experiment."""
  CHECKPOINT_ATTRS = {
      '_params': 'params',
      '_state': 'state',
      '_ema_params': 'ema_params',
      '_ema_state': 'ema_state',
      '_opt_state': 'opt_state',
  }

  def __init__(self, mode, config, init_rng):
    super().__init__(mode=mode)
    self.mode = mode
    self.config = config
    self.init_rng = init_rng

    # Checkpointed experiment state.
    self._params = None
    self._state = None
    self._ema_params = None
    self._ema_state = None
    self._opt_state = None

    # Input pipelines.
    self._train_input = None
    self._eval_input = None

    # Get model, loaded in from the zoo
    self.model_module = importlib.import_module(
        ('nfnets.' + self.config.model.lower()))
    self.net = hk.transform_with_state(self._forward_fn)

    # Assign image sizes
    if self.config.get('override_imsize', False):
      self.train_imsize = self.config.image_size
      self.test_imsize = self.config.get('eval_image_size', self.train_imsize)
    else:
      variant_dict = getattr(self.model_module, self.config.model).variant_dict
      variant_dict = variant_dict[self.config.model_kwargs.variant]
      self.train_imsize = variant_dict.get('train_imsize',
                                           self.config.image_size)
      # Test imsize defaults to model-specific value, then to config imsize
      test_imsize = self.config.get('eval_image_size', self.config.image_size)
      self.test_imsize = variant_dict.get('test_imsize', test_imsize)

    donate_argnums = (0, 1, 2, 6, 7) if self.config.use_ema else (0, 1, 2)
    self.train_fn = jax.pmap(self._train_fn, axis_name='i',
                             donate_argnums=donate_argnums)
    self.eval_fn = jax.pmap(self._eval_fn, axis_name='i')

  def _initialize_train(self):
    self._train_input = self._build_train_input()
    # Initialize net and EMA copy of net if no params available.
    if self._params is None:
      inputs = next(self._train_input)
      init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
                          axis_name='i')
      init_rng = jl_utils.bcast_local_devices(self.init_rng)
      self._params, self._state = init_net(init_rng, inputs)
      if self.config.use_ema:
        self._ema_params, self._ema_state = init_net(init_rng, inputs)
      num_params = hk.data_structures.tree_size(self._params)
      logging.info(f'Net parameters: {num_params / jax.local_device_count()}')
    self._make_opt()

  def _make_opt(self):
    # Separate conv params and gains/biases
    def pred(mod, name, val):  # pylint:disable=unused-argument
      return (name in ['scale', 'offset', 'b']
              or 'gain' in name or 'bias' in name)
    gains_biases, weights = hk.data_structures.partition(pred, self._params)
    # Lr schedule with batch-based LR scaling
    if self.config.lr_scale_by_bs:
      max_lr = (self.config.lr * self.config.train_batch_size) / 256
    else:
      max_lr = self.config.lr
    lr_sched_fn = getattr(optim, self.config.lr_schedule.name)
    lr_schedule = lr_sched_fn(max_val=max_lr, **self.config.lr_schedule.kwargs)
    # Optimizer; no need to broadcast!
    opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
    opt_kwargs['lr'] = lr_schedule
    opt_module = getattr(optim, self.config.optimizer.name)
    self.opt = opt_module([{'params': gains_biases, 'weight_decay': None},
                           {'params': weights}], **opt_kwargs)
    if self._opt_state is None:
      self._opt_state = self.opt.states()
    else:
      self.opt.plugin(self._opt_state)

  def _forward_fn(self, inputs, is_training):
    net_kwargs = {'num_classes': self.config.num_classes,
                  **self.config.model_kwargs}
    net = getattr(self.model_module, self.config.model)(**net_kwargs)
    if self.config.get('transpose', False):
      images = jnp.transpose(inputs['images'], (3, 0, 1, 2))  # HWCN -> NHWC
    else:
      images = inputs['images']
    if self.config.bfloat16 and self.mode == 'train':
      images = utils.to_bf16(images)
    return net(images, is_training=is_training)['logits']

  def _one_hot(self, value):
    """One-hot encoding potentially over a sequence of labels."""
    y = jax.nn.one_hot(value, self.config.num_classes)
    return y

  def _loss_fn(self, params, state, inputs, rng):
    logits, state = self.net.apply(params, state, rng, inputs, is_training=True)
    y = self._one_hot(inputs['labels'])
    if 'mix_labels' in inputs:  # Handle cutmix/mixup label mixing
      logging.info('Using mixup or cutmix!')
      y1 = self._one_hot(inputs['mix_labels'])
      y = inputs['ratio'][:, None] * y + (1. - inputs['ratio'][:, None]) * y1
    if self.config.label_smoothing > 0:  # Apply label smoothing
      spositives = 1. - self.config.label_smoothing
      snegatives = self.config.label_smoothing / self.config.num_classes
      y = spositives * y + snegatives
    if self.config.bfloat16:  # Cast logits to float32
      logits = logits.astype(jnp.float32)
    which_loss = getattr(utils, self.config.which_loss)
    loss = which_loss(logits, y, reduction='mean')
    metrics = utils.topk_correct(logits, inputs['labels'], prefix='train_')
    # Average top-1 and top-5 correct labels
    metrics = jax.tree_map(jnp.mean, metrics)
    metrics['train_loss'] = loss  # Metrics will be pmeaned so don't divide here
    scaled_loss = loss / jax.device_count()  # Grads get psummed so do divide
    return scaled_loss, (metrics, state)

  def _train_fn(self, params, states, opt_states,
                inputs, rng, global_step,
                ema_params, ema_states):
    """Runs one batch forward + backward and run a single opt step."""
    grad_fn = jax.grad(self._loss_fn, argnums=0, has_aux=True)
    if self.config.bfloat16:
      in_params, states = jax.tree_map(utils.to_bf16, (params, states))
    else:
      in_params = params
    grads, (metrics, states) = grad_fn(in_params, states, inputs, rng)
    if self.config.bfloat16:
      states, metrics, grads = jax.tree_map(utils.from_bf16,
                                            (states, metrics, grads))
    # Sum gradients and average losses for pmap
    grads = jax.lax.psum(grads, 'i')
    metrics = jax.lax.pmean(metrics, 'i')
    # Compute updates and update parameters
    metrics['learning_rate'] = self.opt._hyperparameters['lr'](global_step)  # pylint: disable=protected-access
    params, opt_states = self.opt.step(params, grads, opt_states, global_step)
    if ema_params is not None:
      ema_fn = getattr(utils, self.config.get('which_ema', 'tf1_ema'))
      ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
      ema_params = jax.tree_multimap(ema, ema_params, params)
      ema_states = jax.tree_multimap(ema, ema_states, states)
    return {'params': params, 'states': states, 'opt_states': opt_states,
            'ema_params': ema_params, 'ema_states': ema_states,
            'metrics': metrics}

  #  _             _
  # | |_ _ __ __ _(_)_ __
  # | __| '__/ _` | | '_ \
  # | |_| | | (_| | | | | |
  #  \__|_|  \__,_|_|_| |_|
  #

  def step(self, global_step, rng, *unused_args, **unused_kwargs):
    if self._train_input is None:
      self._initialize_train()
    inputs = next(self._train_input)
    out = self.train_fn(params=self._params, states=self._state,
                        opt_states=self._opt_state, inputs=inputs,
                        rng=rng, global_step=global_step,
                        ema_params=self._ema_params, ema_states=self._ema_state)
    self._params, self._state = out['params'], out['states']
    self._opt_state = out['opt_states']
    self._ema_params, self._ema_state = out['ema_params'], out['ema_states']
    self.opt.plugin(self._opt_state)
    return jl_utils.get_first(out['metrics'])

  def _build_train_input(self):
    num_devices = jax.device_count()
    global_batch_size = self.config.train_batch_size
    bs_per_device, ragged = divmod(global_batch_size, num_devices)
    if ragged:
      raise ValueError(
          f'Global batch size {global_batch_size} must be divisible by '
          f'num devices {num_devices}')
    return dataset.load(
        dataset.Split.TRAIN_AND_VALID, is_training=True,
        batch_dims=[jax.local_device_count(), bs_per_device],
        transpose=self.config.get('transpose', False),
        image_size=(self.train_imsize,) * 2,
        augment_name=self.config.augment_name,
        augment_before_mix=self.config.get('augment_before_mix', True),
        name=self.config.which_dataset,
        fake_data=self.config.get('fake_data', False))

  #                  _
  #   _____   ____ _| |
  #  / _ \ \ / / _` | |
  # |  __/\ V / (_| | |
  #  \___| \_/ \__,_|_|
  #

  def evaluate(self, global_step, **unused_args):
    metrics = self._eval_epoch(self._params, self._state)
    if self.config.use_ema:
      ema_metrics = self._eval_epoch(self._ema_params, self._ema_state)
      metrics.update({f'ema_{key}': val for key, val in ema_metrics.items()})
    logging.info(f'[Step {global_step}] Eval scalars: {metrics}')
    return metrics

  def _eval_epoch(self, params, state):
    """Evaluates an epoch."""
    num_samples = 0.
    summed_metrics = None

    for inputs in self._build_eval_input():
      num_samples += np.prod(inputs['labels'].shape[:2])  # Account for pmaps
      metrics = self.eval_fn(params, state, inputs)
      # Accumulate the sum of metrics for each step.
      metrics = jax.tree_map(lambda x: jnp.sum(x[0], axis=0), metrics)
      if summed_metrics is None:
        summed_metrics = metrics
      else:
        summed_metrics = jax.tree_multimap(jnp.add, summed_metrics, metrics)
    mean_metrics = jax.tree_map(lambda x: x / num_samples, summed_metrics)
    return jax.device_get(mean_metrics)

  def _eval_fn(self, params, state, inputs):
    """Evaluate a single batch and return loss and top-k acc."""
    logits, _ = self.net.apply(params, state, None, inputs, is_training=False)
    y = self._one_hot(inputs['labels'])
    which_loss = getattr(utils, self.config.which_loss)
    loss = which_loss(logits, y, reduction=None)
    metrics = utils.topk_correct(logits, inputs['labels'], prefix='eval_')
    metrics['eval_loss'] = loss
    return jax.lax.psum(metrics, 'i')

  def _build_eval_input(self):
    """Builds the evaluation input pipeline."""
    bs_per_device = (self.config.eval_batch_size // jax.local_device_count())
    split = dataset.Split.from_string(self.config.eval_subset)
    eval_preproc = self.config.get('eval_preproc', 'crop_resize')
    return dataset.load(split, is_training=False,
                        batch_dims=[jax.local_device_count(), bs_per_device],
                        transpose=self.config.get('transpose', False),
                        image_size=(self.test_imsize,) * 2,
                        name=self.config.which_dataset,
                        eval_preproc=eval_preproc,
                        fake_data=self.config.get('fake_data', False))


if __name__ == '__main__':
  flags.mark_flag_as_required('config')
  platform.main(Experiment, sys.argv[1:])
nfnets/experiment_nf_regnets.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""ImageNet experiment with NF-RegNets."""

from ml_collections import config_dict
from nfnets import experiment


def get_config():
  """Return config object for training."""
  config = experiment.get_config()

  # Experiment config.
  train_batch_size = 1024  # Global batch size.
  images_per_epoch = 1281167
  num_epochs = 360
  steps_per_epoch = images_per_epoch / train_batch_size
  config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
  config.random_seed = 0

  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              lr=0.4,
              num_epochs=num_epochs,
              label_smoothing=0.1,
              model='NF_RegNet',
              image_size=224,
              use_ema=True,
              ema_decay=0.99999,  # Five nines
              ema_start=0,
              augment_name='mixup_cutmix',
              train_batch_size=train_batch_size,
              eval_batch_size=50,
              eval_subset='test',
              num_classes=1000,
              which_dataset='imagenet',
              which_loss='softmax_cross_entropy',  # One of softmax or sigmoid
              bfloat16=False,
              lr_schedule=dict(
                  name='WarmupCosineDecay',
                  kwargs=dict(num_steps=config.training_steps,
                              start_val=0,
                              min_val=0.001,
                              warmup_steps=5*steps_per_epoch),
              ),
              lr_scale_by_bs=False,
              optimizer=dict(
                  name='SGD',
                  kwargs={'momentum': 0.9, 'nesterov': True,
                          'weight_decay': 5e-5,},
              ),
              model_kwargs=dict(
                  variant='B0',
                  width=0.75,
                  expansion=2.25,
                  se_ratio=0.5,
                  alpha=0.2,
                  stochdepth_rate=0.1,
                  drop_rate=None,
                  activation='silu',
              ),

          )))

  # Set weight decay based on variant (scaled as 5e-5 + 1e-5 * level)
  variant = config.experiment_kwargs.config.model_kwargs.variant
  weight_decay = {'B0': 5e-5, 'B1': 6e-5, 'B2': 7e-5,
                  'B3': 8e-5, 'B4': 9e-5, 'B5': 1e-4}[variant]
  config.experiment_kwargs.config.optimizer.kwargs.weight_decay = weight_decay

  return config


Experiment = experiment.Experiment
nfnets/fixup_resnet.py (new file, 216 lines)
@@ -0,0 +1,216 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet (post-activation) with FixUp."""
# pylint: disable=invalid-name

import functools
import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base


nonlinearities = {
    'swish': jax.nn.silu,
    'relu': jax.nn.relu,
    'identity': lambda x: x}


class FixUp_ResNet(hk.Module):
  """Fixup based ResNet."""

  variant_dict = {'ResNet50': {'depth': [3, 4, 6, 3]},
                  'ResNet101': {'depth': [3, 4, 23, 3]},
                  'ResNet152': {'depth': [3, 8, 36, 3]},
                  'ResNet200': {'depth': [3, 24, 36, 3]},
                  'ResNet288': {'depth': [24, 24, 24, 24]},
                  'ResNet600': {'depth': [50, 50, 50, 50]},
                  }

  def __init__(self, num_classes, variant='ResNet50', width=4,
               stochdepth_rate=0.1, drop_rate=None,
               activation='relu', fc_init=jnp.zeros,
               name='FixUp_ResNet'):
    super().__init__(name=name)
    self.num_classes = num_classes
    self.variant = variant
    self.width = width
    # Get variant info
    block_params = self.variant_dict[self.variant]
    self.width_pattern = [item * self.width for item in [64, 128, 256, 512]]
    self.depth_pattern = block_params['depth']
    self.activation = nonlinearities[activation]
    if drop_rate is None:
      self.drop_rate = block_params['drop_rate']
    else:
      self.drop_rate = drop_rate
    self.which_conv = functools.partial(hk.Conv2D,
                                        with_bias=False)
    # Stem
    ch = int(16 * self.width)
    self.initial_conv = self.which_conv(ch, kernel_shape=7, stride=2,
                                        padding='SAME',
                                        name='initial_conv')

    # Body
    self.blocks = []
    num_blocks = sum(self.depth_pattern)
    index = 0  # Overall block index
    block_args = (self.width_pattern, self.depth_pattern, [1, 2, 2, 2])
    for block_width, stage_depth, stride in zip(*block_args):
      for block_index in range(stage_depth):
        # Block stochastic depth drop-rate
        block_stochdepth_rate = stochdepth_rate * index / num_blocks
        self.blocks += [ResBlock(ch, block_width, num_blocks,
                                 stride=stride if block_index == 0 else 1,
                                 activation=self.activation,
                                 which_conv=self.which_conv,
                                 stochdepth_rate=block_stochdepth_rate,
                                 )]
        ch = block_width
        index += 1

    # Head
    self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)

  def __call__(self, x, is_training=True, return_metrics=False):
    """Return the output of the final layer without any [log-]softmax."""
    # Stem
    outputs = {}
    out = self.initial_conv(x)
    bias1 = hk.get_parameter('bias1', (), x.dtype, init=jnp.zeros)
    out = self.activation(out + bias1)
    out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding='SAME')
    if return_metrics:
      outputs.update(base.signal_metrics(out, 0))
    # Blocks
    for i, block in enumerate(self.blocks):
      out, res_avg_var = block(out, is_training=is_training)
      if return_metrics:
        outputs.update(base.signal_metrics(out, i + 1))
        outputs[f'res_avg_var_{i}'] = res_avg_var
    # Final activation, pool, dropout, classify
    pool = jnp.mean(out, [1, 2])
    outputs['pool'] = pool
    # Optionally apply dropout
    if self.drop_rate > 0.0 and is_training:
      pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
    bias2 = hk.get_parameter('bias2', (), pool.dtype, init=jnp.zeros)
    outputs['logits'] = self.fc(pool + bias2)
    return outputs

  def count_flops(self, h, w):
    flops = []
    flops += [base.count_conv_flops(3, self.initial_conv, h, w)]
    h, w = h / 2, w / 2
    # Body FLOPs
    for block in self.blocks:
      flops += [block.count_flops(h, w)]
      if block.stride > 1:
        h, w = h / block.stride, w / block.stride
    # Count flops for classifier (no final conv in this model)
    out_ch = self.blocks[-1].out_ch
    flops += [out_ch * self.fc.output_size]
    return flops, sum(flops)


class ResBlock(hk.Module):
  """Post-activation Fixup Block."""

  def __init__(self, in_ch, out_ch, num_blocks, bottleneck_ratio=0.25,
               kernel_size=3, stride=1,
               which_conv=hk.Conv2D, activation=jax.nn.relu,
               stochdepth_rate=None, name=None):
    super().__init__(name=name)
    self.in_ch, self.out_ch = in_ch, out_ch
    self.kernel_size = kernel_size
    self.activation = activation
    # Bottleneck width
    self.width = int(self.out_ch * bottleneck_ratio)
    self.stride = stride
    # Conv 0 (typically expansion conv)
    conv0_init = hk.initializers.RandomNormal(
        stddev=((2 / self.width)**0.5) * (num_blocks**(-0.25)))
    self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
                            name='conv0', w_init=conv0_init)
    # Grouped NxN conv
    conv1_init = hk.initializers.RandomNormal(
        stddev=((2 / (self.width * (kernel_size**2)))**0.5)
        * (num_blocks**(-0.25)))
    self.conv1 = which_conv(self.width, kernel_shape=kernel_size, stride=stride,
                            padding='SAME', name='conv1', w_init=conv1_init)
    # Conv 2, typically projection conv
    self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
                            name='conv2', w_init=hk.initializers.Constant(0))
    # Use shortcut conv on channel change or downsample.
    self.use_projection = stride > 1 or self.in_ch != self.out_ch
    if self.use_projection:
      shortcut_init = hk.initializers.RandomNormal(
          stddev=(2 / self.out_ch) ** 0.5)
      self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
                                      stride=stride, padding='SAME',
                                      name='conv_shortcut',
                                      w_init=shortcut_init)
    # Are we using stochastic depth?
    self._has_stochdepth = (stochdepth_rate is not None and
                            stochdepth_rate > 0. and stochdepth_rate < 1.0)
    if self._has_stochdepth:
      self.stoch_depth = base.StochDepth(stochdepth_rate)

  def __call__(self, x, is_training):
    bias1a = hk.get_parameter('bias1a', (), x.dtype, init=jnp.zeros)
    bias1b = hk.get_parameter('bias1b', (), x.dtype, init=jnp.zeros)
    bias2a = hk.get_parameter('bias2a', (), x.dtype, init=jnp.zeros)
    bias2b = hk.get_parameter('bias2b', (), x.dtype, init=jnp.zeros)
    bias3a = hk.get_parameter('bias3a', (), x.dtype, init=jnp.zeros)
    bias3b = hk.get_parameter('bias3b', (), x.dtype, init=jnp.zeros)
    scale = hk.get_parameter('scale', (), x.dtype, init=jnp.ones)

    out = x + bias1a
    shortcut = out
    if self.use_projection:  # Downsample with conv1x1
      shortcut = self.conv_shortcut(shortcut)
    out = self.conv0(out)
    out = self.activation(out + bias1b)
    out = self.conv1(out + bias2a)
    out = self.activation(out + bias2b)
    out = self.conv2(out + bias3a)
    out = out * scale + bias3b
    # Get average residual variance for reporting metrics.
    res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
    # Apply stochdepth if applicable.
    if self._has_stochdepth:
      out = self.stoch_depth(out, is_training)
    # Add the residual shortcut
    out = out + shortcut
    return self.activation(out), res_avg_var

  def count_flops(self, h, w):
    # Count conv FLOPs based on input HW
    expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
    # If block is strided we decrease resolution here.
    dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
    if self.stride > 1:
      h, w = h / self.stride, w / self.stride
    if self.use_projection:
      sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
    else:
      sc_flops = 0
    contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
    return sum([expand_flops, dw_flops, contract_flops, sc_flops])
nfnets/nf_regnet.py (new file, 217 lines)
@@ -0,0 +1,217 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normalizer-Free RegNets."""
# pylint: disable=invalid-name

import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base


class NF_RegNet(hk.Module):
  """Normalizer-Free RegNets."""

  variant_dict = base.nf_regnet_params

  def __init__(self, num_classes, variant='B0',
               width=0.75, expansion=2.25, group_size=8, se_ratio=0.5,
               alpha=0.2, stochdepth_rate=0.1, drop_rate=None,
               activation='swish', fc_init=jnp.zeros,
               name='NF_RegNet'):
    super().__init__(name=name)
    self.num_classes = num_classes
    self.variant = variant
    self.width = width
    self.expansion = expansion
    self.group_size = group_size
    self.se_ratio = se_ratio
    # Get variant info
    block_params = self.variant_dict[self.variant]
    self.train_imsize = block_params['train_imsize']
    self.test_imsize = block_params['test_imsize']
    self.width_pattern = block_params['width']
    self.depth_pattern = block_params['depth']
    self.activation = base.nonlinearities[activation]
    if drop_rate is None:
      self.drop_rate = block_params['drop_rate']
    else:
      self.drop_rate = drop_rate
    self.which_conv = base.WSConv2D
    # Stem
    ch = int(self.width_pattern[0] * self.width)
    self.initial_conv = self.which_conv(ch, kernel_shape=3, stride=2,
                                        padding='SAME', name='initial_conv')

    # Body
    self.blocks = []
    expected_std = 1.0
    num_blocks = sum(self.depth_pattern)
    index = 0  # Overall block index
    for block_width, stage_depth in zip(self.width_pattern, self.depth_pattern):
      for block_index in range(stage_depth):
        # Scalar pre-multiplier so each block sees an N(0,1) input at init
        beta = 1. / expected_std
        # Block stochastic depth drop-rate
        block_stochdepth_rate = stochdepth_rate * index / num_blocks
        # Use a bottleneck expansion ratio of 1 for first block following EffNet
        expand_ratio = 1 if index == 0 else expansion
        out_ch = (int(block_width * self.width))
        self.blocks += [NFBlock(ch, out_ch,
                                expansion=expand_ratio, se_ratio=se_ratio,
                                group_size=self.group_size,
                                stride=2 if block_index == 0 else 1,
                                beta=beta, alpha=alpha,
                                activation=self.activation,
                                which_conv=self.which_conv,
                                stochdepth_rate=block_stochdepth_rate,
                                )]
        ch = out_ch
        index += 1
        # Reset expected std but still give it 1 block of growth
        if block_index == 0:
          expected_std = 1.0
        expected_std = (expected_std**2 + alpha**2)**0.5

    # Head with final conv mimicking EffNets
    self.final_conv = self.which_conv(int(1280 * ch // 440), kernel_shape=1,
                                      padding='SAME', name='final_conv')
    self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)

  def __call__(self, x, is_training=True, return_metrics=False):
    """Return the output of the final layer without any [log-]softmax."""
    # Stem
    outputs = {}
    out = self.initial_conv(x)
    if return_metrics:
      outputs.update(base.signal_metrics(out, 0))
    # Blocks
    for i, block in enumerate(self.blocks):
      out, res_avg_var = block(out, is_training=is_training)
      if return_metrics:
        outputs.update(base.signal_metrics(out, i + 1))
        outputs[f'res_avg_var_{i}'] = res_avg_var
    # Final-conv->activation, pool, dropout, classify
    out = self.activation(self.final_conv(out))
    pool = jnp.mean(out, [1, 2])
    outputs['pool'] = pool
    # Optionally apply dropout
    if self.drop_rate > 0.0 and is_training:
      pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
    outputs['logits'] = self.fc(pool)
    return outputs

  def count_flops(self, h, w):
    flops = []
    flops += [base.count_conv_flops(3, self.initial_conv, h, w)]
    h, w = h / 2, w / 2
    # Body FLOPs
    for block in self.blocks:
      flops += [block.count_flops(h, w)]
      if block.stride > 1:
        h, w = h / block.stride, w / block.stride
    # Head module FLOPs
    out_ch = self.blocks[-1].out_ch
    flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
    # Count flops for classifier
    flops += [self.final_conv.output_channels * self.fc.output_size]
    return flops, sum(flops)


class NFBlock(hk.Module):
  """Normalizer-Free RegNet Block."""

  def __init__(self, in_ch, out_ch, expansion=2.25, se_ratio=0.5,
               kernel_size=3, group_size=8, stride=1,
               beta=1.0, alpha=0.2,
               which_conv=base.WSConv2D, activation=jax.nn.relu,
               stochdepth_rate=None, name=None):
    super().__init__(name=name)
    self.in_ch, self.out_ch = in_ch, out_ch
    self.expansion = expansion
    self.se_ratio = se_ratio
    self.kernel_size = kernel_size
    self.activation = activation
    self.beta, self.alpha = beta, alpha
    # Round expanded width based on group count
    width = int(self.in_ch * expansion)
    self.groups = width // group_size
    self.width = group_size * self.groups
    self.stride = stride
    # Conv 0 (typically expansion conv)
    self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
                            name='conv0')
    # Grouped NxN conv
    self.conv1 = which_conv(self.width, kernel_shape=kernel_size, stride=stride,
                            padding='SAME', feature_group_count=self.groups,
                            name='conv1')
    # Conv 2, typically projection conv
    self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
                            name='conv2')
    # Use shortcut conv on channel change or downsample.
    self.use_projection = stride > 1 or self.in_ch != self.out_ch
    if self.use_projection:
      self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
                                      padding='SAME', name='conv_shortcut')
    # Squeeze + Excite Module
    self.se = base.SqueezeExcite(self.width, self.width, self.se_ratio)

    # Are we using stochastic depth?
    self._has_stochdepth = (stochdepth_rate is not None and
                            stochdepth_rate > 0. and stochdepth_rate < 1.0)
    if self._has_stochdepth:
      self.stoch_depth = base.StochDepth(stochdepth_rate)

  def __call__(self, x, is_training):
    out = self.activation(x) * self.beta
    if self.stride > 1:  # Average-pool downsample.
      shortcut = hk.avg_pool(out, window_shape=(1, 2, 2, 1),
                             strides=(1, 2, 2, 1), padding='SAME')
      if self.use_projection:
        shortcut = self.conv_shortcut(shortcut)
    elif self.use_projection:
      shortcut = self.conv_shortcut(out)
    else:
      shortcut = x
    out = self.conv0(out)
    out = self.conv1(self.activation(out))
    out = 2 * self.se(out) * out  # Multiply by 2 for rescaling
    out = self.conv2(self.activation(out))
    # Get average residual variance for reporting metrics.
    res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
    # Apply stochdepth if applicable.
    if self._has_stochdepth:
      out = self.stoch_depth(out, is_training)
    # SkipInit Gain
    out = out * hk.get_parameter('skip_gain', (), out.dtype, init=jnp.zeros)
    return out * self.alpha + shortcut, res_avg_var

  def count_flops(self, h, w):
    # Count conv FLOPs based on input HW
    expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
    # If block is strided we decrease resolution here.
    dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
    if self.stride > 1:
      h, w = h / self.stride, w / self.stride
    if self.use_projection:
      sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
    else:
      sc_flops = 0
    # SE flops happen on avg-pooled activations
    se_flops = self.se.fc0.output_size * self.width
    se_flops += self.se.fc0.output_size * self.se.fc1.output_size
    contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
    return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops])
nfnets/nf_resnet.py (new file, 216 lines)
@@ -0,0 +1,216 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Norm-Free Residual Networks."""
# pylint: disable=invalid-name

import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base


class NF_ResNet(hk.Module):
  """Norm-Free preactivation ResNet."""

  variant_dict = {'ResNet50': {'depth': [3, 4, 6, 3]},
                  'ResNet101': {'depth': [3, 4, 23, 3]},
                  'ResNet152': {'depth': [3, 8, 36, 3]},
                  'ResNet200': {'depth': [3, 24, 36, 3]},
                  'ResNet288': {'depth': [24, 24, 24, 24]},
                  'ResNet600': {'depth': [50, 50, 50, 50]},
                  }

  def __init__(self, num_classes, variant='ResNet50', width=4,
               alpha=0.2, stochdepth_rate=0.1, drop_rate=None,
               activation='relu', fc_init=None, skipinit_gain=jnp.zeros,
               use_se=False, se_ratio=0.25,
               name='NF_ResNet'):
    super().__init__(name=name)
    self.num_classes = num_classes
    self.variant = variant
    self.width = width
    # Get variant info
    block_params = self.variant_dict[self.variant]
    self.width_pattern = [item * self.width for item in [64, 128, 256, 512]]
    self.depth_pattern = block_params['depth']
    self.activation = base.nonlinearities[activation]
    if drop_rate is None:
      self.drop_rate = block_params['drop_rate']
    else:
      self.drop_rate = drop_rate
    self.which_conv = base.WSConv2D
    # Stem
    ch = int(16 * self.width)
    self.initial_conv = self.which_conv(ch, kernel_shape=7, stride=2,
                                        padding='SAME', with_bias=False,
                                        name='initial_conv')

    # Body
    self.blocks = []
    expected_std = 1.0
    num_blocks = sum(self.depth_pattern)
    index = 0  # Overall block index
    block_args = (self.width_pattern, self.depth_pattern, [1, 2, 2, 2])
    for block_width, stage_depth, stride in zip(*block_args):
      for block_index in range(stage_depth):
        # Scalar pre-multiplier so each block sees an N(0,1) input at init
        beta = 1. / expected_std
        # Block stochastic depth drop-rate
        block_stochdepth_rate = stochdepth_rate * index / num_blocks
        self.blocks += [NFResBlock(ch, block_width,
                                   stride=stride if block_index == 0 else 1,
                                   beta=beta, alpha=alpha,
                                   activation=self.activation,
                                   which_conv=self.which_conv,
                                   stochdepth_rate=block_stochdepth_rate,
                                   skipinit_gain=skipinit_gain,
                                   use_se=use_se,
                                   se_ratio=se_ratio,
                                   )]
        ch = block_width
        index += 1
        # Reset expected std but still give it 1 block of growth
        if block_index == 0:
          expected_std = 1.0
        expected_std = (expected_std**2 + alpha**2)**0.5

    # Head. By default, initialize with N(0, 0.01)
    if fc_init is None:
      fc_init = hk.initializers.RandomNormal(0.01, 0)
    self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)

  def __call__(self, x, is_training=True, return_metrics=False):
    """Return the output of the final layer without any [log-]softmax."""
    # Stem
    outputs = {}
    out = self.initial_conv(x)
    out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding='SAME')
    if return_metrics:
      outputs.update(base.signal_metrics(out, 0))
    # Blocks
    for i, block in enumerate(self.blocks):
      out, res_avg_var = block(out, is_training=is_training)
      if return_metrics:
        outputs.update(base.signal_metrics(out, i + 1))
        outputs[f'res_avg_var_{i}'] = res_avg_var
    # Final activation, pool, dropout, classify
    pool = jnp.mean(self.activation(out), [1, 2])
    outputs['pool'] = pool
    # Optionally apply dropout
    if self.drop_rate > 0.0 and is_training:
      pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
    outputs['logits'] = self.fc(pool)
    return outputs

  def count_flops(self, h, w):
    flops = []
    flops += [base.count_conv_flops(3, self.initial_conv, h, w)]
    h, w = h / 2, w / 2
    # Body FLOPs
    for block in self.blocks:
      flops += [block.count_flops(h, w)]
      if block.stride > 1:
        h, w = h / block.stride, w / block.stride
    # Count flops for classifier (no final conv in this model)
    out_ch = self.blocks[-1].out_ch
    flops += [out_ch * self.fc.output_size]
    return flops, sum(flops)


class NFResBlock(hk.Module):
  """Normalizer-Free pre-activation ResNet Block."""

  def __init__(self, in_ch, out_ch, bottleneck_ratio=0.25,
               kernel_size=3, stride=1,
               beta=1.0, alpha=0.2,
               which_conv=base.WSConv2D, activation=jax.nn.relu,
               skipinit_gain=jnp.zeros,
               stochdepth_rate=None,
               use_se=False, se_ratio=0.25,
               name=None):
    super().__init__(name=name)
    self.in_ch, self.out_ch = in_ch, out_ch
    self.kernel_size = kernel_size
    self.activation = activation
    self.beta, self.alpha = beta, alpha
    self.skipinit_gain = skipinit_gain
    self.use_se, self.se_ratio = use_se, se_ratio
    # Bottleneck width
    self.width = int(self.out_ch * bottleneck_ratio)
    self.stride = stride
    # Conv 0 (typically expansion conv)
    self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
                            name='conv0')
    # Grouped NxN conv
    self.conv1 = which_conv(self.width, kernel_shape=kernel_size, stride=stride,
                            padding='SAME', name='conv1')
    # Conv 2, typically projection conv
    self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
                            name='conv2')
    # Use shortcut conv on channel change or downsample.
    self.use_projection = stride > 1 or self.in_ch != self.out_ch
    if self.use_projection:
      self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
                                      stride=stride, padding='SAME',
                                      name='conv_shortcut')
    # Are we using stochastic depth?
    self._has_stochdepth = (stochdepth_rate is not None and
                            stochdepth_rate > 0. and stochdepth_rate < 1.0)
    if self._has_stochdepth:
      self.stoch_depth = base.StochDepth(stochdepth_rate)

    if self.use_se:
      self.se = base.SqueezeExcite(self.out_ch, self.out_ch, self.se_ratio)

  def __call__(self, x, is_training):
    out = self.activation(x) * self.beta
    shortcut = x
    if self.use_projection:  # Downsample with conv1x1
      shortcut = self.conv_shortcut(out)
    out = self.conv0(out)
    out = self.conv1(self.activation(out))
    out = self.conv2(self.activation(out))
    if self.use_se:
      out = 2 * self.se(out) * out
    # Get average residual variance for reporting metrics.
    res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
    # Apply stochdepth if applicable.
    if self._has_stochdepth:
      out = self.stoch_depth(out, is_training)
    # SkipInit Gain
    out = out * hk.get_parameter('skip_gain', (), out.dtype,
                                 init=self.skipinit_gain)
    return out * self.alpha + shortcut, res_avg_var

  def count_flops(self, h, w):
    # Count conv FLOPs based on input HW
    expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
    # If block is strided we decrease resolution here.
    dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
    if self.stride > 1:
      h, w = h / self.stride, w / self.stride
    if self.use_projection:
      sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
    else:
      sc_flops = 0
    # SE flops happen on avg-pooled activations
    if self.use_se:
      se_flops = self.se.fc0.output_size * self.width
      se_flops += self.se.fc0.output_size * self.se.fc1.output_size
    else:
      se_flops = 0
    contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
    return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops])
nfnets/optim.py (new file, 455 lines; listing truncated below)
@@ -0,0 +1,455 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizers and Schedulers, inspired by the PyTorch API."""
from collections import ChainMap  # pylint:disable=g-importing-member
from typing import Callable
import haiku as hk
import jax
import jax.numpy as jnp
import tree
from nfnets import utils


class Optimizer(object):
  """Optimizer base class."""

  def __init__(self, params, defaults):
    # Flag indicating if parameters have been broadcasted
    self._broadcasted = False
    # Optimizer hyperparameters; this is a dict to support using param_groups
    self._hyperparameters = {}
    # Mapping from model parameters to optimizer hyperparameters
    self._params2hyperparams = {}
    # Assign defaults
    self._hyperparameters = dict(**defaults)
    # Prepare parameter groups and mappings
    self.create_param_groups(params, defaults)
    # Join params at top-level if params is a list of groups
    if isinstance(params, list):
      flatmap = type(hk.data_structures.to_immutable_dict({}))
      if any([isinstance(group['params'], flatmap) for group in params]):
        params = hk.data_structures.merge(*[group['params']
                                            for group in params])
      else:
        params = dict(ChainMap(*[group['params'] for group in params]))
    # Prepare states
    create_buffers = lambda k, v: self.create_buffers('/'.join(k), v)
    self._states = tree.map_structure_with_path(create_buffers, params)

  def add_hyperparam_group(self, group, suffix, defaults):
    """Adds new hyperparameters to the hyperparams dict."""
    # Use default hyperparams unless overridden by group hyperparams
    group_dict = {key: key for key in defaults if key not in group}
    for key in group:
      if key != 'params':  # Reserved keyword 'params'
        group_dict[key] = '%s_%s' % (key, suffix)
        self._hyperparameters[group_dict[key]] = group[key]
    # Set up params2hyperparams
    def set_p2h(k, _):
      self._params2hyperparams['/'.join(k)] = group_dict
    tree.map_structure_with_path(set_p2h, group['params'])

  def create_param_groups(self, params, defaults):
    """Creates param-hyperparam mappings."""
    if isinstance(params, list):
      for group_index, group in enumerate(params):
        # Add group to hyperparams and get this group's full hyperparameters
        self.add_hyperparam_group(group, group_index, defaults)
    else:
      mapping = {key: key for key in self._hyperparameters}
      def set_p2h(k, _):
        self._params2hyperparams['/'.join(k)] = mapping
      tree.map_structure_with_path(set_p2h, params)

  def create_buffers(self, name, params):
    """Method to be overridden by child classes."""
    pass

  def get_opt_params(self, param_name, itr):
    """Returns hyperparams corresponding to param_name."""
    mapping = self._params2hyperparams[param_name]
    output = {}
    for key in mapping:
      hyper = self._hyperparameters[mapping[key]]
      # Handle the case where a hyper is a class, for hybrids
      if isinstance(hyper, Callable) and not isinstance(hyper, type):
        output[key] = hyper(itr)
      else:
        output[key] = hyper
    return output

  def get_hyper(self, param_name, hyper_name):
    """Get an individual hyperparam for a given param."""
    mapping = self._params2hyperparams[param_name]
    return self._hyperparameters[mapping[hyper_name]]
|
||||
def plugin(self, states):
|
||||
self._states = states
|
||||
|
||||
def states(self):
|
||||
return self._states
|
||||
|
||||
def broadcast(self):
|
||||
"""Brodcasts all buffers and parameters."""
|
||||
self._broadcasted = True
|
||||
for name, state in self._states.items():
|
||||
self._states[name] = {key: utils.broadcast(state[key]) for key in state}
|
||||
|
||||
def gather(self):
|
||||
"""Gathers state (if broadcasted) for saving."""
|
||||
states = {}
|
||||
for name in self._states:
|
||||
state = self._states[name]
|
||||
states[name] = {key: state[key] if state[key] is None else state[key][0]
|
||||
for key in state}
|
||||
return states
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Overrides the object's set-attribute function to register states, etc."""
|
||||
if '_hyperparameters' in self.__dict__ and name in self._hyperparameters:
|
||||
self._hyperparameters[name] = value
|
||||
elif '_states' in self.__dict__ and name in self._states:
|
||||
self._states[name] = value
|
||||
else:
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Override the object's get-attribute function to return states, etc."""
|
||||
if '_hyperparameters' in self.__dict__ and name in self._hyperparameters:
|
||||
return self._hyperparameters[name]
|
||||
elif '_states' in self.__dict__ and name in self._states:
|
||||
return self._states[name]
|
||||
else:
|
||||
      return object.__getattribute__(self, name)

  def step(self, params, grads, states, itr=None):
    """Takes a single optimizer step.

    Args:
      params: a dict containing the parameters to be updated.
      grads: a dict containing the gradients for each parameter in params.
      states: a dict containing any optimizer buffers (momentum, etc) for
        each parameter in params.
      itr: an optional integer indicating the current step, for scheduling.

    Returns:
      The updated params and optimizer buffers.
    """
    get_hyper = lambda k, v: self.get_opt_params('/'.join(k), itr)
    hypers = tree.map_structure_with_path(get_hyper, params)
    outs = tree.map_structure_up_to(params, self.update_param,
                                    params, grads, states, hypers)
    return utils.split_tree(outs, params, 2)


class Schedule(object):
  """Hyperparameter scheduling objects."""


class CosineDecay(Schedule):
  """Cosine decay."""

  def __init__(self, min_val, max_val, num_steps):
    self.min_val = min_val
    self.max_val = max_val
    self.num_steps = num_steps

  def __call__(self, itr):
    cos = (1 + jnp.cos(jnp.pi * itr / self.num_steps))
    return 0.5 * (self.max_val - self.min_val) * cos + self.min_val


class WarmupCosineDecay(Schedule):
  """Cosine decay with linear warmup."""

  def __init__(self, start_val, min_val, max_val, num_steps, warmup_steps):
    self.start_val = start_val
    self.min_val = min_val
    self.max_val = max_val
    self.num_steps = num_steps
    self.warmup_steps = warmup_steps

  def __call__(self, itr):
    warmup_val = ((self.max_val - self.start_val) * (itr / self.warmup_steps)
                  + self.start_val)
    cos_itr = (itr - self.warmup_steps) / (self.num_steps - self.warmup_steps)
    cos = 1 + jnp.cos(jnp.pi * cos_itr)
    cos_val = 0.5 * (self.max_val - self.min_val) * cos + self.min_val
    # Select warmup_val if itr < warmup, else cosine val
    values = jnp.array([warmup_val, cos_val])
    index = jnp.sum(jnp.array(self.warmup_steps) < itr)
    return jnp.take(values, index)


class WarmupExpDecay(Schedule):
  """Exponential step decay with linear warmup."""

  def __init__(self, start_val, max_val, warmup_steps,
               decay_factor, decay_interval):
    self.start_val = start_val
    self.max_val = max_val
    self.warmup_steps = warmup_steps
    self.decay_factor = decay_factor
    self.decay_interval = decay_interval

  def __call__(self, itr):
    warmup_val = ((self.max_val - self.start_val) * (itr / self.warmup_steps)
                  + self.start_val)
    # How many decay steps have we taken?
    num_decays = jnp.floor((itr - self.warmup_steps) / self.decay_interval)
    exp_val = self.max_val * (self.decay_factor ** num_decays)
    # Select warmup_val if itr < warmup, else exp_val
    values = jnp.array([warmup_val, exp_val])
    index = jnp.sum(jnp.array(self.warmup_steps) < itr)
    return jnp.take(values, index)


class SGD(Optimizer):
  """Standard SGD with (nesterov) momentum and weight decay.

  Attributes:
    params: Either a dict mapping param names to JAX tensors, or a list where
      each member of the list is a dict containing parameters
      and hyperparameters, allowing one to specify param-specific hyperparams.
    lr: Learning rate.
    weight_decay: Weight decay parameter. Note that this is decay, not L2 reg.
    momentum: Momentum parameter.
    dampening: Dampening parameter.
    nesterov: Bool indicating this optimizer will use the NAG formulation.
  """
  defaults = {'weight_decay': None, 'momentum': None, 'dampening': 0,
              'nesterov': None}

  def __init__(self, params, lr, weight_decay=None,
               momentum=None, dampening=0, nesterov=None):
    super().__init__(
        params, defaults={'lr': lr, 'weight_decay': weight_decay,
                          'momentum': momentum, 'dampening': dampening,
                          'nesterov': nesterov})

  def create_buffers(self, name, param):
    """Prepares all momentum buffers for each parameter."""
    state = {'step': jnp.zeros(jax.local_device_count())}
    if self.get_hyper(name, 'momentum') is not None:
      state['momentum'] = jnp.zeros_like(param)
    return state

  def update_param(self, param, grad, state, opt_params):
    """The actual update step for this optimizer."""
    if param is None:
      return param, state
    # Apply weight decay
    if opt_params.get('weight_decay') is not None:
      grad = grad + param * opt_params['weight_decay']
    # Update momentum buffers if needed
    if 'momentum' in state:
      state['momentum'] = (opt_params['momentum'] * state['momentum']
                           + (1 - opt_params['dampening']) * grad)
      if opt_params['nesterov'] is not None:
        grad = grad + opt_params['momentum'] * state['momentum']
      else:
        grad = state['momentum']
    state['step'] += 1
    return param - opt_params['lr'] * grad, state


class Adam(Optimizer):
  """Adam optimizer, Kingma & Ba, arxiv.org/abs/1412.6980.

  Args:
    params (iterable): nested list of params to optimize
    lr (float, optional): learning rate (default: 1e-3)
    beta1, beta2 (float, optional): coefficients used for computing
      running averages of the gradient and its square (default: 0.9, 0.999)
    eps (float, optional): term added to the denominator to improve
      numerical stability (default: 1e-8)
    weight_decay (float, optional): weight decay (default: None)
    use_adamw (bool, optional): If not None, use decoupled weight decay
      as in arxiv.org/abs/1711.05101. The paper version adds an additional
      "schedule" hyperparameter eta, which we instead just replace with the
      learning rate following the PyTorch implementation.

  Note that this implementation will not instantiate a buffer if the
  beta term for that buffer is passed in as None, thus conserving memory.
  """
  defaults = {'beta1': 0.9, 'beta2': 0.999, 'weight_decay': None, 'eps': 1e-8,
              'use_adamw': None}

  def __init__(self, params, lr, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=None, use_adamw=None):
    super().__init__(params=params,
                     defaults={'lr': lr, 'beta1': beta1,
                               'beta2': beta2, 'eps': eps,
                               'weight_decay': weight_decay,
                               'use_adamw': use_adamw})

  def create_buffers(self, name, param):
    """Prepare exp_avg and exp_avg_sq buffers."""
    state = {'step': jnp.zeros(jax.local_device_count())}
    if self.get_hyper(name, 'beta1') is not None:
      state['exp_avg'] = jnp.zeros_like(param)
    if self.get_hyper(name, 'beta2') is not None:
      state['exp_avg_sq'] = jnp.zeros_like(param)
    return state

  def update_param(self, param, grad, state, opt_params):
    """The actual update step for this optimizer."""
    if param is None:
      return param, state
    state['step'] = state['step'] + 1
    # Apply weight decay
    if opt_params.get('weight_decay') is not None:
      if opt_params.get('use_adamw') is not None:
        param = param * (1 - opt_params['lr'] * opt_params['weight_decay'])
      else:
        grad = grad + param * opt_params['weight_decay']
    # First moment
    if 'exp_avg' in state:
      bias_correction1 = 1 - opt_params['beta1'] ** state['step']
      state['exp_avg'] = (state['exp_avg'] * opt_params['beta1']
                          + (1 - opt_params['beta1']) * grad)
      step_size = opt_params['lr'] * state['exp_avg'] / bias_correction1
    else:
      step_size = opt_params['lr'] * grad
    # Second moment
    if 'exp_avg_sq' in state:
      bias_correction2 = 1 - opt_params['beta2'] ** state['step']
      state['exp_avg_sq'] = (state['exp_avg_sq'] * opt_params['beta2']
                             + (1 - opt_params['beta2']) * grad * grad)
      denom = jnp.sqrt(state['exp_avg_sq']) * jax.lax.rsqrt(bias_correction2)
      denom = denom + opt_params['eps']
    else:
      denom = jnp.abs(grad) + opt_params['eps']  # Add eps to avoid divide-by-0

    return param - step_size / denom, state


class RMSProp(Optimizer):
  """RMSProp optimizer, Tieleman and Hinton, ref: powerpoint slides.

  Implements RMSProp as
  rms = decay * rms{t-1} + (1-decay) * gradient ** 2
  mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(rms + epsilon)
  param -= mom

  Note that the rms buffer is initialized with ones as in TF, as opposed to
  zeros as in all other implementations.

  Args:
    params (iterable): nested list of params to optimize
    lr (float): learning rate
    decay (float): EMA decay rate for running estimate of squared gradient.
    momentum (float or None): Use heavy ball momentum instead of instant grad.
    eps (float, optional): term added to the denominator to improve
      numerical stability (default: 1e-8)
    weight_decay (float, optional): weight decay (not AdamW-style;
      default: None)
  """
  defaults = {'weight_decay': None, 'eps': 1e-8}

  def __init__(self, params, lr, decay, momentum, weight_decay=None, eps=1e-8):
    super().__init__(params=params,
                     defaults={'lr': lr, 'decay': decay,
                               'momentum': momentum, 'eps': eps,
                               'weight_decay': weight_decay})

  def create_buffers(self, name, param):
    """Prepare rms and momentum buffers."""
    state = {'step': jnp.zeros(jax.local_device_count())}
    state['rms'] = jnp.ones_like(param)
    if self.get_hyper(name, 'momentum') is not None:
      state['momentum'] = jnp.zeros_like(param)
    return state

  def update_param(self, param, grad, state, opt_params):
    """The actual update step for this optimizer."""
    if param is None:
      return param, state
    state['step'] = state['step'] + 1
    # Apply weight decay
    if opt_params.get('weight_decay') is not None:
      grad = grad + param * opt_params['weight_decay']
    # EMA of the squared gradient
    state['rms'] = (state['rms'] * opt_params['decay']
                    + (1 - opt_params['decay']) * (grad ** 2))
    scaled_grad = (opt_params['lr'] * grad
                   / (state['rms'] + opt_params['eps']) ** 0.5)
    if 'momentum' in state:
      state['momentum'] = (state['momentum'] * opt_params['momentum']
                           + scaled_grad)
      step_size = state['momentum']
    else:
      step_size = scaled_grad

    return param - step_size, state


class Fromage(Optimizer):
  """Fromage optimizer, Bernstein et al. arXiv.org/abs/2002.03432.

  This version optionally includes weight decay.

  Attributes:
    params (iterable): nested list of params to optimize
    lr (float): learning rate.
    eps (float, optional): Minimum allowable norm. This term is required in
      case parameters are zero-initialized (default: 1e-5).
    weight_decay (float, optional): weight decay (default: None).
  """

  defaults = {'weight_decay': None, 'eps': 1e-5}

  def __init__(self, params, lr, weight_decay=None, eps=1e-5):
    super().__init__(
        params, defaults={'lr': lr, 'weight_decay': weight_decay, 'eps': eps})

  def create_buffers(self, name, param):  # pylint: disable=unused-argument
    """Prepares all momentum buffers for each parameter."""
    return {'step': jnp.zeros(1)}

  def update_param(self, param, grad, state, opt_params):
    """The actual update step for this optimizer."""
    if param is None:
      return param, state
    if opt_params['weight_decay'] is not None:
      grad = grad + param * opt_params['weight_decay']
    grad_norm = jnp.maximum(jnp.linalg.norm(grad), opt_params['eps'])
    param_norm = jnp.maximum(jnp.linalg.norm(param), opt_params['eps'])
    mult = jax.lax.rsqrt(1 + opt_params['lr'] ** 2)
    out = (param - opt_params['lr'] * grad * (param_norm / grad_norm)) * mult
    return out, state


class Hybrid(Optimizer):
  """Optimizer which permits passing param groups with different base opts.

  The API for this class follows the case for any other optimizer where one
  specifies a list of dicts with separate hyperparams, but in this case it
  requires the user to also specify an 'opt' key for each group, such as e.g.
  [{'params': params0, 'opt': optim.Adam, 'lr': 0.1}].

  The user must also provide values for any arg in the selected optimizers
  which does not have a default value associated with it.
  """

  def __init__(self, param_groups):
    if any(['opt' not in group for group in param_groups]):
      raise ValueError('All parameter groups must have an opt key!')
    self.defaults = ChainMap(*[group['opt'].defaults for group in param_groups])
    super().__init__(param_groups, defaults=dict(self.defaults))

  def create_buffers(self, name, param):
    return self.get_hyper(name, 'opt').create_buffers(self, name, param)

  def update_param(self, param, grad, state, opt_params):
    return opt_params['opt'].update_param(self, param, grad, state, opt_params)
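As a rough usage sketch of the optimizer and schedule API above (the toy `forward` function, shapes, and hyperparameter values are made up for illustration; the real training loop, parameter partitioning, and pmap/broadcast handling live in the experiment files):

```python
import haiku as hk
import jax
import jax.numpy as jnp

from nfnets import optim


# Hypothetical toy model purely for illustration.
def forward(x):
  return hk.Linear(10)(x)

net = hk.without_apply_rng(hk.transform(forward))
params = net.init(jax.random.PRNGKey(0), jnp.zeros((4, 8)))

# Schedules are callables of the iteration, so one can be passed anywhere a
# scalar hyperparameter is expected (here, in place of a fixed learning rate).
lr_schedule = optim.WarmupCosineDecay(start_val=0., min_val=0., max_val=0.1,
                                      num_steps=1000, warmup_steps=50)
opt = optim.SGD(params, lr=lr_schedule, momentum=0.9, nesterov=True,
                weight_decay=1e-4)


def loss_fn(p, x):
  return jnp.mean(net.apply(p, x) ** 2)

grads = jax.grad(loss_fn)(params, jnp.ones((4, 8)))
new_params, new_states = opt.step(params, grads, opt.states(), itr=0)
```

Per-parameter hyperparameters follow the list-of-dicts pattern described in the SGD and Hybrid docstrings, e.g. `[{'params': weights, 'weight_decay': 1e-4}, {'params': biases}]`.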
14
nfnets/requirements.txt
Normal file
@@ -0,0 +1,14 @@
absl-py==0.10.0
chex>=0.0.2
dm-haiku>=0.0.3
jax>=0.2.8
jaxlib>=0.1.58
jaxline>=0.0.1
ml_collections>=0.1
numpy>=1.18.0
tensorflow>=2.3.1
tensorflow-addons>=0.12.0
tensorflow-datasets>=4.1.0
tensorflow-probability>=0.12.1
typing_extensions>=3.7
wrapt>=1.11.2
183
nfnets/resnet.py
Normal file
@@ -0,0 +1,183 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet model family."""
import functools
import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base


class ResNet(hk.Module):
  """ResNetv2 Models."""

  variant_dict = {'ResNet50': {'depth': [3, 4, 6, 3]},
                  'ResNet101': {'depth': [3, 4, 23, 3]},
                  'ResNet152': {'depth': [3, 8, 36, 3]},
                  'ResNet200': {'depth': [3, 24, 36, 3]},
                  'ResNet288': {'depth': [24, 24, 24, 24]},
                  'ResNet600': {'depth': [50, 50, 50, 50]},
                  }

  def __init__(self, width, num_classes,
               variant='ResNet50',
               which_norm='BatchNorm', norm_kwargs=None,
               activation='relu', drop_rate=0.0,
               fc_init=jnp.zeros, conv_kwargs=None,
               preactivation=True, use_se=False, se_ratio=0.25,
               name='ResNet'):
    super().__init__(name=name)
    self.width = width
    self.num_classes = num_classes
    self.variant = variant
    self.depth_pattern = self.variant_dict[variant]['depth']
    self.activation = getattr(jax.nn, activation)
    self.drop_rate = drop_rate
    self.which_norm = getattr(hk, which_norm)
    if norm_kwargs is not None:
      self.which_norm = functools.partial(self.which_norm, **norm_kwargs)
    if conv_kwargs is not None:
      self.which_conv = functools.partial(hk.Conv2D, **conv_kwargs)
    else:
      self.which_conv = hk.Conv2D
    self.preactivation = preactivation

    # Stem
    self.initial_conv = self.which_conv(16 * self.width, kernel_shape=7,
                                        stride=2, padding='SAME',
                                        with_bias=False, name='initial_conv')
    if not self.preactivation:
      self.initial_bn = self.which_norm(name='initial_bn')
    which_block = ResBlockV2 if self.preactivation else ResBlockV1
    # Body
    self.blocks = []
    for multiplier, blocks_per_stage, stride in zip([64, 128, 256, 512],
                                                    self.depth_pattern,
                                                    [1, 2, 2, 2]):
      for block_index in range(blocks_per_stage):
        self.blocks += [which_block(multiplier * self.width,
                                    use_projection=block_index == 0,
                                    stride=stride if block_index == 0 else 1,
                                    activation=self.activation,
                                    which_norm=self.which_norm,
                                    which_conv=self.which_conv,
                                    use_se=use_se,
                                    se_ratio=se_ratio)]

    # Head
    self.final_bn = self.which_norm(name='final_bn')
    self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)

  def __call__(self, x, is_training, test_local_stats=False,
               return_metrics=False):
    """Return the output of the final layer without any [log-]softmax."""
    outputs = {}
    # Stem
    out = self.initial_conv(x)
    if not self.preactivation:
      out = self.activation(self.initial_bn(out, is_training, test_local_stats))
    out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding='SAME')
    if return_metrics:
      outputs.update(base.signal_metrics(out, 0))
    # Blocks
    for i, block in enumerate(self.blocks):
      out, res_var = block(out, is_training, test_local_stats)
      if return_metrics:
        outputs.update(base.signal_metrics(out, i + 1))
        outputs[f'res_avg_var_{i}'] = res_var
    if self.preactivation:
      out = self.activation(self.final_bn(out, is_training, test_local_stats))
    # Pool, dropout, classify
    pool = jnp.mean(out, axis=[1, 2])
    # Return pool before dropout in case we want to regularize it separately.
    outputs['pool'] = pool
    # Optionally apply dropout
    if self.drop_rate > 0.0 and is_training:
      pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
    outputs['logits'] = self.fc(pool)
    return outputs


class ResBlockV2(hk.Module):
  """ResNet preact block, 1x1->3x3->1x1 with strides and shortcut downsample."""

  def __init__(self, out_ch, stride=1, use_projection=False,
               activation=jax.nn.relu, which_norm=hk.BatchNorm,
               which_conv=hk.Conv2D, use_se=False, se_ratio=0.25,
               name=None):
    super().__init__(name=name)
    self.out_ch = out_ch
    self.stride = stride
    self.use_projection = use_projection
    self.activation = activation
    self.which_norm = which_norm
    self.which_conv = which_conv
    self.use_se = use_se
    self.se_ratio = se_ratio

    self.width = self.out_ch // 4

    self.bn0 = which_norm(name='bn0')
    self.conv0 = which_conv(self.width, kernel_shape=1, with_bias=False,
                            padding='SAME', name='conv0')
    self.bn1 = which_norm(name='bn1')
    self.conv1 = which_conv(self.width, stride=self.stride,
                            kernel_shape=3, with_bias=False,
                            padding='SAME', name='conv1')
    self.bn2 = which_norm(name='bn2')
    self.conv2 = which_conv(self.out_ch, kernel_shape=1, with_bias=False,
                            padding='SAME', name='conv2')
    if self.use_projection:
      self.conv_shortcut = which_conv(self.out_ch, stride=stride,
                                      kernel_shape=1, with_bias=False,
                                      padding='SAME', name='conv_shortcut')
    if self.use_se:
      self.se = base.SqueezeExcite(self.out_ch, self.out_ch, self.se_ratio)

  def __call__(self, x, is_training, test_local_stats):
    bn_args = (is_training, test_local_stats)
    out = self.activation(self.bn0(x, *bn_args))
    if self.use_projection:
      shortcut = self.conv_shortcut(out)
    else:
      shortcut = x
    out = self.conv0(out)
    out = self.conv1(self.activation(self.bn1(out, *bn_args)))
    out = self.conv2(self.activation(self.bn2(out, *bn_args)))
    if self.use_se:
      out = self.se(out) * out
    # Get average residual standard deviation for reporting metrics.
    res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
    return out + shortcut, res_avg_var


class ResBlockV1(ResBlockV2):
  """Post-activation Residual Block."""

  def __call__(self, x, is_training, test_local_stats):
    bn_args = (is_training, test_local_stats)
    if self.use_projection:
      shortcut = self.conv_shortcut(x)
      shortcut = self.which_norm(name='shortcut_bn')(shortcut, *bn_args)
    else:
      shortcut = x
    out = self.activation(self.bn0(self.conv0(x), *bn_args))
    out = self.activation(self.bn1(self.conv1(out), *bn_args))
    out = self.bn2(self.conv2(out), *bn_args)
    if self.use_se:
      out = self.se(out) * out
    res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
    return self.activation(out + shortcut), res_avg_var
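A hedged sketch of running the batch-normalized ResNet above on fake data: the width, class count, image size, and `norm_kwargs` values here are illustrative (the actual settings come from the experiment configs), and `transform_with_state` is needed because BatchNorm keeps moving statistics.

```python
import haiku as hk
import jax
import jax.numpy as jnp

from nfnets import resnet


def forward(images, is_training):
  model = resnet.ResNet(
      width=4, num_classes=10, variant='ResNet50',
      norm_kwargs={'create_scale': True, 'create_offset': True,
                   'decay_rate': 0.9})  # illustrative BatchNorm settings
  return model(images, is_training=is_training)

net = hk.transform_with_state(forward)
images = jnp.zeros((2, 64, 64, 3))  # any spatial size works; global pooling
params, state = net.init(jax.random.PRNGKey(0), images, is_training=True)
outputs, state = net.apply(params, state, None, images, is_training=True)
print(outputs['logits'].shape)  # (2, 10)
```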
22
nfnets/run.sh
Normal file
@@ -0,0 +1,22 @@
#!/bin/sh
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e

python3 -m venv nfnets_venv
source nfnets_venv/bin/activate
pip3 install --upgrade setuptools wheel
pip3 install -r nfnets/requirements.txt

python3 -m nfnets.test
194
nfnets/skipinit_resnet.py
Normal file
@@ -0,0 +1,194 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNetV2 (Pre-activation) with SkipInit."""
# pylint: disable=invalid-name

import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import base


# Nonlinearities
nonlinearities = {
    'swish': jax.nn.silu,
    'relu': jax.nn.relu,
    'identity': lambda x: x}


class SkipInit_ResNet(hk.Module):
  """Skip-Init based ResNet."""

  variant_dict = {'ResNet50': {'depth': [3, 4, 6, 3]},
                  'ResNet101': {'depth': [3, 4, 23, 3]},
                  'ResNet152': {'depth': [3, 8, 36, 3]},
                  'ResNet200': {'depth': [3, 24, 36, 3]},
                  'ResNet288': {'depth': [24, 24, 24, 24]},
                  'ResNet600': {'depth': [50, 50, 50, 50]},
                  }

  def __init__(self, num_classes, variant='ResNet50', width=4,
               stochdepth_rate=0.1, drop_rate=None,
               activation='relu', fc_init=jnp.zeros,
               name='SkipInit_ResNet'):
    super().__init__(name=name)
    self.num_classes = num_classes
    self.variant = variant
    self.width = width
    # Get variant info
    block_params = self.variant_dict[self.variant]
    self.width_pattern = [item * self.width for item in [64, 128, 256, 512]]
    self.depth_pattern = block_params['depth']
    self.activation = nonlinearities[activation]
    if drop_rate is None:
      self.drop_rate = block_params['drop_rate']
    else:
      self.drop_rate = drop_rate
    self.which_conv = hk.Conv2D
    # Stem
    ch = int(16 * self.width)
    self.initial_conv = self.which_conv(ch, kernel_shape=7, stride=2,
                                        padding='SAME', with_bias=False,
                                        name='initial_conv')

    # Body
    self.blocks = []
    num_blocks = sum(self.depth_pattern)
    index = 0  # Overall block index
    block_args = (self.width_pattern, self.depth_pattern, [1, 2, 2, 2])
    for block_width, stage_depth, stride in zip(*block_args):
      for block_index in range(stage_depth):
        # Block stochastic depth drop-rate
        block_stochdepth_rate = stochdepth_rate * index / num_blocks
        self.blocks += [NFResBlock(ch, block_width,
                                   stride=stride if block_index == 0 else 1,
                                   activation=self.activation,
                                   which_conv=self.which_conv,
                                   stochdepth_rate=block_stochdepth_rate,
                                   )]
        ch = block_width
        index += 1

    # Head
    self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)

  def __call__(self, x, is_training=True, return_metrics=False):
    """Return the output of the final layer without any [log-]softmax."""
    # Stem
    outputs = {}
    out = self.initial_conv(x)
    out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding='SAME')
    if return_metrics:
      outputs.update(base.signal_metrics(out, 0))
    # Blocks
    for i, block in enumerate(self.blocks):
      out, res_avg_var = block(out, is_training=is_training)
      if return_metrics:
        outputs.update(base.signal_metrics(out, i + 1))
        outputs[f'res_avg_var_{i}'] = res_avg_var
    # Final-conv->activation, pool, dropout, classify
    pool = jnp.mean(self.activation(out), [1, 2])
    outputs['pool'] = pool
    # Optionally apply dropout
    if self.drop_rate > 0.0 and is_training:
      pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
    outputs['logits'] = self.fc(pool)
    return outputs
  def count_flops(self, h, w):
    flops = []
    flops += [base.count_conv_flops(3, self.initial_conv, h, w)]
    h, w = h / 2, w / 2
    # Body FLOPs
    for block in self.blocks:
      flops += [block.count_flops(h, w)]
      if block.stride > 1:
        h, w = h / block.stride, w / block.stride
    # Count flops for the classifier (this model has no final conv).
    out_ch = self.blocks[-1].out_ch
    flops += [out_ch * self.fc.output_size]
    return flops, sum(flops)


class NFResBlock(hk.Module):
  """Normalizer-Free pre-activation ResNet Block."""

  def __init__(self, in_ch, out_ch, bottleneck_ratio=0.25,
               kernel_size=3, stride=1,
               which_conv=hk.Conv2D, activation=jax.nn.relu,
               stochdepth_rate=None, name=None):
    super().__init__(name=name)
    self.in_ch, self.out_ch = in_ch, out_ch
    self.kernel_size = kernel_size
    self.activation = activation
    # Bottleneck width
    self.width = int(self.out_ch * bottleneck_ratio)
    self.stride = stride
    # Conv 0 (typically expansion conv)
    self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
                            name='conv0')
    # Grouped NxN conv
    self.conv1 = which_conv(self.width, kernel_shape=kernel_size, stride=stride,
                            padding='SAME', name='conv1')
    # Conv 2, typically projection conv
    self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
                            name='conv2')
    # Use shortcut conv on channel change or downsample.
    self.use_projection = stride > 1 or self.in_ch != self.out_ch
    if self.use_projection:
      self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
                                      stride=stride, padding='SAME',
                                      name='conv_shortcut')
    # Are we using stochastic depth?
    self._has_stochdepth = (stochdepth_rate is not None and
                            stochdepth_rate > 0. and stochdepth_rate < 1.0)
    if self._has_stochdepth:
      self.stoch_depth = base.StochDepth(stochdepth_rate)

  def __call__(self, x, is_training):
    out = self.activation(x)
    shortcut = x
    if self.use_projection:  # Downsample with conv1x1
      shortcut = self.conv_shortcut(out)
    out = self.conv0(out)
    out = self.conv1(self.activation(out))
    out = self.conv2(self.activation(out))
    # Get average residual standard deviation for reporting metrics.
    res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
    # Apply stochdepth if applicable.
    if self._has_stochdepth:
      out = self.stoch_depth(out, is_training)
    # SkipInit Gain
    out = out * hk.get_parameter('skip_gain', (), out.dtype, init=jnp.zeros)
    return out + shortcut, res_avg_var

  def count_flops(self, h, w):
    # Count conv FLOPs based on input HW
    expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
    # If block is strided we decrease resolution here.
    dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
    if self.stride > 1:
      h, w = h / self.stride, w / self.stride
    if self.use_projection:
      sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
    else:
      sc_flops = 0
    contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
    return sum([expand_flops, dw_flops, contract_flops, sc_flops])
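The defining property of SkipInit is visible in the block above: `skip_gain` is a scalar initialized to zero, so every residual branch is multiplied by zero at initialization and each block reduces to its shortcut. A hedged sketch that checks this on a single block (channel counts and shapes are arbitrary):

```python
import haiku as hk
import jax
import jax.numpy as jnp

from nfnets import skipinit_resnet


def forward(x, is_training):
  # in_ch == out_ch and stride 1, so no shortcut projection is created.
  block = skipinit_resnet.NFResBlock(in_ch=64, out_ch=64)
  return block(x, is_training=is_training)

net = hk.without_apply_rng(hk.transform(forward))
x = jax.random.normal(jax.random.PRNGKey(0), (2, 32, 32, 64))
params = net.init(jax.random.PRNGKey(1), x, is_training=False)
out, res_avg_var = net.apply(params, x, is_training=False)
print(jnp.allclose(out, x))  # True: skip_gain starts at zero, block is identity
```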
35
nfnets/test.py
Normal file
@@ -0,0 +1,35 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Quick script to test that experiment can import and run."""
import jax
import jax.numpy as jnp
from nfnets import experiment

config = experiment.get_config()
exp_config = config.experiment_kwargs.config
exp_config.train_batch_size = 2
exp_config.eval_batch_size = 2
exp_config.lr = 0.1
exp_config.fake_data = True
exp_config.model_kwargs.width = 2
print(exp_config.model_kwargs)

xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
bcast = jax.pmap(lambda x: x)
global_step = bcast(jnp.zeros(jax.local_device_count()))
rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
print('Taking a single experiment step for test purposes!')
result = xp.step(global_step, rng)
print(f'Step successfully taken, resulting metrics are {result}')
107
nfnets/utils.py
Normal file
@@ -0,0 +1,107 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils."""
import jax
import jax.numpy as jnp
import tree


def reduce_fn(x, mode):
  """Reduce fn for various losses."""
  if mode == 'none' or mode is None:
    return jnp.asarray(x)
  elif mode == 'sum':
    return jnp.sum(x)
  elif mode == 'mean':
    return jnp.mean(x)
  else:
    raise ValueError('Unsupported reduction option.')


def softmax_cross_entropy(logits, labels, reduction='sum'):
  """Computes softmax cross entropy given logits and one-hot class labels.

  Args:
    logits: Logit output values.
    labels: Ground truth one-hot-encoded labels.
    reduction: Type of reduction to apply to loss.

  Returns:
    Loss value. If `reduction` is `none`, this has the same shape as `labels`;
    otherwise, it is scalar.

  Raises:
    ValueError: If the type of `reduction` is unsupported.
  """
  loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
  return reduce_fn(loss, reduction)


def topk_correct(logits, labels, mask=None, prefix='', topk=(1, 5)):
  """Calculate top-k error for multiple k values."""
  metrics = {}
  argsorted_logits = jnp.argsort(logits)
  for k in topk:
    pred_labels = argsorted_logits[..., -k:]
    # Get the number of examples where the label is in the top-k predictions
    correct = any_in(pred_labels, labels).any(axis=-1).astype(jnp.float32)
    if mask is not None:
      correct *= mask
    metrics[f'{prefix}top_{k}_acc'] = correct
  return metrics


@jax.vmap
def any_in(prediction, target):
  """For each row, checks if any element of `prediction` is in `target`."""
  return jnp.isin(prediction, target)


def tf1_ema(ema_value, current_value, decay, step):
  """Implements EMA with TF1-style decay warmup."""
  decay = jnp.minimum(decay, (1.0 + step) / (10.0 + step))
  return ema_value * decay + current_value * (1 - decay)


def ema(ema_value, current_value, decay, step):
  """Implements EMA without any warmup."""
  del step
  return ema_value * decay + current_value * (1 - decay)


to_bf16 = lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x
from_bf16 = lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x


def _replicate(x, devices=None):
  """Replicate an object on each device."""
  x = jax.numpy.array(x)
  if devices is None:
    devices = jax.local_devices()
  return jax.api.device_put_sharded(len(devices) * [x], devices)


def broadcast(obj):
  """Broadcasts an object to all devices."""
  if obj is not None and not isinstance(obj, bool):
    return _replicate(obj)
  else:
    return obj


def split_tree(tuple_tree, base_tree, n):
  """Splits tuple_tree with n-tuple leaves into n trees."""
  return [tree.map_structure_up_to(base_tree, lambda x: x[i], tuple_tree)  # pylint: disable=cell-var-from-loop
          for i in range(n)]
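A small sketch exercising the loss and metric helpers above on toy arrays (the logits, labels, and EMA values are made up for illustration):

```python
import jax
import jax.numpy as jnp

from nfnets import utils

logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.1, 0.2, 3.0]])
int_labels = jnp.array([0, 2])
one_hot = jax.nn.one_hot(int_labels, 3)

# softmax_cross_entropy takes one-hot labels and a reduction mode.
loss = utils.softmax_cross_entropy(logits, one_hot, reduction='mean')

# topk_correct takes integer class labels, one per example, and returns
# per-example correctness that can then be averaged into an accuracy.
metrics = utils.topk_correct(logits, int_labels, topk=(1,))
top1_acc = jnp.mean(metrics['top_1_acc'])  # 1.0 for these toy logits

# tf1_ema ramps the effective decay up over the first steps.
ema_val = utils.tf1_ema(ema_value=0.0, current_value=1.0, decay=0.999, step=0)
print(loss, top1_acc, ema_val)
```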