Initial commit for NFNets and AGC

PiperOrigin-RevId: 357022514
Andy Brock
2021-02-11 19:38:33 +00:00
committed by Louise Deason
parent abc9ba0635
commit 68a1754d29
7 changed files with 646 additions and 20 deletions
+21 -4
@@ -1,7 +1,9 @@
# Code for Normalizer-Free ResNets (Brock, De, Smith, ICLR2021)
# Code for Normalizer-Free Networks
This repository contains code for the ICLR 2021 paper
"Characterizing signal propagation to close the performance gap in unnormalized
ResNets," by Andrew Brock, Soham De, and Samuel L. Smith.
ResNets," by Andrew Brock, Soham De, and Samuel L. Smith, and the arXiv preprint
"High-Performance Large-Scale Image Recognition Without Normalization" by
Andrew Brock, Soham De, Samuel L. Smith, and Karen Simonyan.
## Running this code
@@ -16,14 +18,29 @@ used in dataset.py in order to train on ImageNet.
## Giving Credit
If you use this code in your work, we ask you to cite this paper:
If you use this code in your work, we ask you to please cite one or both of the
following papers.
The reference for the Normalizer-Free structure and NF-ResNets or NF-Regnets:
```
@inproceedings{brock2021characterizing,
author={Andrew Brock and Soham De and Samuel L. SMith},
author={Andrew Brock and Soham De and Samuel L. Smith},
title={Characterizing signal propagation to close the performance gap in
unnormalized ResNets},
booktitle={9th International Conference on Learning Representations, {ICLR}},
year={2021}
}
```
The reference for Adaptive Gradient Clipping (AGC) and the NFNets models:
```
@article{brock2021high,
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
title={High-Performance Large-Scale Image Recognition Without Normalization},
journal={arXiv preprint arXiv:},
year={2021}
}
```
+78
@@ -0,0 +1,78 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adaptive gradient clipping transform for Optax."""
import jax
import jax.numpy as jnp
import optax
def compute_norm(x, axis, keepdims):
"""Axis-wise euclidean norm."""
return jnp.sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5
def unitwise_norm(x):
"""Compute norms of each output unit separately, also for linear layers."""
if len(jnp.squeeze(x).shape) <= 1: # Scalars and vectors
axis = None
keepdims = False
elif len(x.shape) in [2, 3]: # Linear layers of shape IO or multihead linear
axis = 0
keepdims = True
elif len(x.shape) == 4: # Conv kernels of shape HWIO
axis = [0, 1, 2,]
keepdims = True
else:
raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
return compute_norm(x, axis, keepdims)
def my_clip(g_norm, max_norm, grad):
"""Applies my gradient clipping unit-wise."""
trigger = g_norm < max_norm
# This little max(., 1e-6) is distinct from the normal eps and just prevents
# division by zero. It technically should be impossible to engage.
clipped_grad = grad * (max_norm / jnp.maximum(g_norm, 1e-6))
return jnp.where(trigger, grad, clipped_grad)
def adaptive_grad_clip(clip, eps=1e-3) -> optax.GradientTransformation:
"""Clip updates to be at most clipping * parameter_norm.
References:
[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
Recognition Without Normalization.
Args:
clip: Maximum allowed ratio of update norm to parameter norm.
eps: epsilon term to prevent clipping of zero-initialized params.
Returns:
An (init_fn, update_fn) tuple.
"""
def init_fn(_):
return optax.ClipByGlobalNormState()
def update_fn(updates, state, params):
g_norm = jax.tree_map(unitwise_norm, updates)
p_norm = jax.tree_map(unitwise_norm, params)
# Maximum allowable norm
max_norm = jax.tree_map(lambda x: clip * jnp.maximum(x, eps), p_norm)
# If grad norm > clipping * param_norm, rescale
updates = jax.tree_multimap(my_clip, g_norm, max_norm, updates)
return updates, state
return optax.GradientTransformation(init_fn, update_fn)
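As a usage sketch (not part of this commit), the transform above composes with other Optax transforms in the usual way. `unitwise_norm` keeps one norm per output unit, e.g. shape `(1, 1, 1, O)` for an HWIO conv kernel, so the clipping threshold broadcasts per output channel. The import path is an assumption, since the diff does not show file names:
```python
import jax.numpy as jnp
import optax
from nfnets import agc_optax  # assumed import path for the file above

# AGC followed by plain SGD with momentum.
tx = optax.chain(agc_optax.adaptive_grad_clip(clip=0.01, eps=1e-3),
                 optax.sgd(learning_rate=0.1, momentum=0.9))

params = {'w': jnp.ones((3, 3, 16, 32))}      # HWIO conv kernel
grads = {'w': jnp.full((3, 3, 16, 32), 0.5)}
state = tx.init(params)
# params must be passed explicitly: AGC scales each unit's update by its weight norm.
updates, state = tx.update(grads, state, params)
```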
+45 -1
@@ -50,6 +50,50 @@ nf_regnet_params = {
'drop_rate': 0.5},
}
nfnet_params = {}
# F-series models
nfnet_params.update(**{
'F0': {
'width': [256, 512, 1536, 1536], 'depth': [1, 2, 6, 3],
'train_imsize': 192, 'test_imsize': 256,
'RA_level': '405', 'drop_rate': 0.2},
'F1': {
'width': [256, 512, 1536, 1536], 'depth': [2, 4, 12, 6],
'train_imsize': 224, 'test_imsize': 320,
'RA_level': '410', 'drop_rate': 0.3},
'F2': {
'width': [256, 512, 1536, 1536], 'depth': [3, 6, 18, 9],
'train_imsize': 256, 'test_imsize': 352,
'RA_level': '410', 'drop_rate': 0.4},
'F3': {
'width': [256, 512, 1536, 1536], 'depth': [4, 8, 24, 12],
'train_imsize': 320, 'test_imsize': 416,
'RA_level': '415', 'drop_rate': 0.4},
'F4': {
'width': [256, 512, 1536, 1536], 'depth': [5, 10, 30, 15],
'train_imsize': 384, 'test_imsize': 512,
'RA_level': '415', 'drop_rate': 0.5},
'F5': {
'width': [256, 512, 1536, 1536], 'depth': [6, 12, 36, 18],
'train_imsize': 416, 'test_imsize': 544,
'RA_level': '415', 'drop_rate': 0.5},
'F6': {
'width': [256, 512, 1536, 1536], 'depth': [7, 14, 42, 21],
'train_imsize': 448, 'test_imsize': 576,
'RA_level': '415', 'drop_rate': 0.5},
'F7': {
'width': [256, 512, 1536, 1536], 'depth': [8, 16, 48, 24],
'train_imsize': 480, 'test_imsize': 608,
'RA_level': '415', 'drop_rate': 0.5},
})
# Minor variants FN+, slightly wider
nfnet_params.update(**{
**{f'{key}+': {**nfnet_params[key], 'width': [384, 768, 2048, 2048],}
for key in nfnet_params}
})
# Nonlinearities with magic constants (gamma) baked in.
# Note that not all nonlinearities will be stable, especially if they are
@@ -122,7 +166,7 @@ def signal_metrics(x, i):
def count_conv_flops(in_ch, conv, h, w):
"""For a conv layer with in_ch inputs, count the FLOPS."""
# How many output floats are we producing?
# How many outputs are we producing? Note this is wrong for VALID padding.
output_shape = conv.output_channels * (h * w) / np.prod(conv.stride)
# At each OHW location we do computation equal to (I//G) * kh * kw
flop_per_loc = (in_ch / conv.feature_group_count)
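The diff truncates this helper, but the visible lines fix the counting convention: output locations are `output_channels * h * w / prod(stride)`, and the per-location cost starts from `in_ch / feature_group_count`. Assuming the elided remainder multiplies in the kernel area and the number of output locations, a rough worked number for a dense 3x3 conv (64 to 128 channels, stride 1, on a 56x56 feature map) would be:
```python
out_locations = 128 * 56 * 56             # 401,408 output activations
macs_per_location = (64 / 1) * 3 * 3      # 576 multiply-adds for each one
total_flops = out_locations * macs_per_location   # ~2.3e8
```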
+126
@@ -0,0 +1,126 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""ImageNet experiment with NFNets."""
import haiku as hk
from ml_collections import config_dict
from nfnets import experiment
from nfnets import optim
def get_config():
"""Return config object for training."""
config = experiment.get_config()
# Experiment config.
train_batch_size = 4096 # Global batch size.
images_per_epoch = 1281167
num_epochs = 360
steps_per_epoch = images_per_epoch / train_batch_size
config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
config.random_seed = 0
config.experiment_kwargs = config_dict.ConfigDict(
dict(
config=dict(
lr=0.1,
num_epochs=num_epochs,
label_smoothing=0.1,
model='NFNet',
image_size=224,
use_ema=True,
ema_decay=0.99999,
ema_start=0,
augment_name=None,
augment_before_mix=False,
eval_preproc='resize_crop_32',
train_batch_size=train_batch_size,
eval_batch_size=50,
eval_subset='test',
num_classes=1000,
which_dataset='imagenet',
which_loss='softmax_cross_entropy', # One of softmax or sigmoid
bfloat16=True,
lr_schedule=dict(
name='WarmupCosineDecay',
kwargs=dict(num_steps=config.training_steps,
start_val=0,
min_val=0.0,
warmup_steps=5*steps_per_epoch),
),
lr_scale_by_bs=True,
optimizer=dict(
name='SGD_AGC',
kwargs={'momentum': 0.9, 'nesterov': True,
'weight_decay': 2e-5,
'clipping': 0.01, 'eps': 1e-3},
),
model_kwargs=dict(
variant='F0',
width=1.0,
se_ratio=0.5,
alpha=0.2,
stochdepth_rate=0.25,
drop_rate=None, # Use native drop-rate
activation='gelu',
final_conv_mult=2,
final_conv_ch=None,
use_two_convs=True,
),
)))
# Unlike NF-RegNets, use the same weight decay for all, but vary RA levels
variant = config.experiment_kwargs.config.model_kwargs.variant
# RandAugment levels (e.g. 405 = 4 layers, magnitude 5, 205 = 2 layers, mag 5)
augment = {'F0': '405', 'F1': '410', 'F2': '410', 'F3': '415',
'F4': '415', 'F5': '415', 'F6': '415', 'F7': '415'}[variant]
aug_base_name = 'cutmix_mixup_randaugment'
config.experiment_kwargs.config.augment_name = f'{aug_base_name}_{augment}'
return config
class Experiment(experiment.Experiment):
"""Experiment with correct parameter filtering for applying AGC."""
def _make_opt(self):
# Separate conv params and gains/biases
def pred_gb(mod, name, val):
del mod, val
return (name in ['scale', 'offset', 'b']
or 'gain' in name or 'bias' in name)
gains_biases, weights = hk.data_structures.partition(pred_gb, self._params)
def pred_fc(mod, name, val):
del name, val
return 'linear' in mod and 'squeeze_excite' not in mod
fc_weights, weights = hk.data_structures.partition(pred_fc, weights)
# Lr schedule with batch-based LR scaling
if self.config.lr_scale_by_bs:
max_lr = (self.config.lr * self.config.train_batch_size) / 256
else:
max_lr = self.config.lr
lr_sched_fn = getattr(optim, self.config.lr_schedule.name)
lr_schedule = lr_sched_fn(max_val=max_lr, **self.config.lr_schedule.kwargs)
# Optimizer; no need to broadcast!
opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
opt_kwargs['lr'] = lr_schedule
opt_module = getattr(optim, self.config.optimizer.name)
self.opt = opt_module([{'params': gains_biases, 'weight_decay': None,},
{'params': fc_weights, 'clipping': None},
{'params': weights}], **opt_kwargs)
if self._opt_state is None:
self._opt_state = self.opt.states()
else:
self.opt.plugin(self._opt_state)
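For reference, the schedule arithmetic implied by the defaults above (all values taken from this config; no new settings introduced):
```python
# Worked numbers for the default F0 config: batch 4096, base lr 0.1.
images_per_epoch, num_epochs, train_batch_size = 1281167, 360, 4096
steps_per_epoch = images_per_epoch / train_batch_size                  # ~312.8
training_steps = (images_per_epoch * num_epochs) // train_batch_size   # 112602
max_lr = 0.1 * train_batch_size / 256      # 1.6, since lr_scale_by_bs=True
warmup_steps = 5 * steps_per_epoch         # ~1564 warmup steps before cosine decay
```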
+272
@@ -0,0 +1,272 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Norm-Free Nets."""
# pylint: disable=unused-import
# pylint: disable=invalid-name
import functools
import haiku as hk
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
from nfnets import base
class NFNet(hk.Module):
"""Normalizer-Free Networks with an improved architecture.
References:
[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
Recognition Without Normalization.
"""
variant_dict = base.nfnet_params
def __init__(self, num_classes, variant='F0',
width=1.0, se_ratio=0.5,
alpha=0.2, stochdepth_rate=0.1, drop_rate=None,
activation='gelu', fc_init=None,
final_conv_mult=2, final_conv_ch=None,
use_two_convs=True,
name='NFNet'):
super().__init__(name=name)
self.num_classes = num_classes
self.variant = variant
self.width = width
self.se_ratio = se_ratio
# Get variant info
block_params = self.variant_dict[self.variant]
self.train_imsize = block_params['train_imsize']
self.test_imsize = block_params['test_imsize']
self.width_pattern = block_params['width']
self.depth_pattern = block_params['depth']
self.bneck_pattern = block_params.get('expansion', [0.5] * 4)
self.group_pattern = block_params.get('group_width', [128] * 4)
self.big_pattern = block_params.get('big_width', [True] * 4)
self.activation = base.nonlinearities[activation]
if drop_rate is None:
self.drop_rate = block_params['drop_rate']
else:
self.drop_rate = drop_rate
self.which_conv = base.WSConv2D
# Stem
ch = self.width_pattern[0] // 2
self.stem = hk.Sequential([
self.which_conv(16, kernel_shape=3, stride=2,
padding='SAME', name='stem_conv0'),
self.activation,
self.which_conv(32, kernel_shape=3, stride=1,
padding='SAME', name='stem_conv1'),
self.activation,
self.which_conv(64, kernel_shape=3, stride=1,
padding='SAME', name='stem_conv2'),
self.activation,
self.which_conv(ch, kernel_shape=3, stride=2,
padding='SAME', name='stem_conv3'),
])
# Body
self.blocks = []
expected_std = 1.0
num_blocks = sum(self.depth_pattern)
index = 0 # Overall block index
stride_pattern = [1, 2, 2, 2]
block_args = zip(self.width_pattern, self.depth_pattern, self.bneck_pattern,
self.group_pattern, self.big_pattern, stride_pattern)
for (block_width, stage_depth, expand_ratio,
group_size, big_width, stride) in block_args:
for block_index in range(stage_depth):
# Scalar pre-multiplier so each block sees an N(0,1) input at init
beta = 1./ expected_std
# Block stochastic depth drop-rate
block_stochdepth_rate = stochdepth_rate * index / num_blocks
out_ch = (int(block_width * self.width))
self.blocks += [NFBlock(ch, out_ch,
expansion=expand_ratio, se_ratio=se_ratio,
group_size=group_size,
stride=stride if block_index == 0 else 1,
beta=beta, alpha=alpha,
activation=self.activation,
which_conv=self.which_conv,
stochdepth_rate=block_stochdepth_rate,
big_width=big_width,
use_two_convs=use_two_convs,
)]
ch = out_ch
index += 1
# Reset expected std but still give it 1 block of growth
if block_index == 0:
expected_std = 1.0
expected_std = (expected_std **2 + alpha**2)**0.5
# Head
if final_conv_mult is None:
if final_conv_ch is None:
raise ValueError('Must provide one of final_conv_mult or final_conv_ch')
ch = final_conv_ch
else:
ch = int(final_conv_mult * ch)
self.final_conv = self.which_conv(ch, kernel_shape=1,
padding='SAME', name='final_conv')
# By default, initialize with N(0, 0.01)
if fc_init is None:
fc_init = hk.initializers.RandomNormal(mean=0, stddev=0.01)
self.fc = hk.Linear(self.num_classes, w_init=fc_init, with_bias=True)
def __call__(self, x, is_training=True, return_metrics=False):
"""Return the output of the final layer without any [log-]softmax."""
# Stem
outputs = {}
out = self.stem(x)
if return_metrics:
outputs.update(base.signal_metrics(out, 0))
# Blocks
for i, block in enumerate(self.blocks):
out, res_avg_var = block(out, is_training=is_training)
if return_metrics:
outputs.update(base.signal_metrics(out, i + 1))
outputs[f'res_avg_var_{i}'] = res_avg_var
# Final-conv->activation, pool, dropout, classify
out = self.activation(self.final_conv(out))
pool = jnp.mean(out, [1, 2])
outputs['pool'] = pool
# Optionally apply dropout
if self.drop_rate > 0.0 and is_training:
pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
outputs['logits'] = self.fc(pool)
return outputs
def count_flops(self, h, w):
flops = []
ch = 3
for module in self.stem.layers:
if isinstance(module, hk.Conv2D):
flops += [base.count_conv_flops(ch, module, h, w)]
if any([item > 1 for item in module.stride]):
h, w = h / module.stride[0], w / module.stride[1]
ch = module.output_channels
# Body FLOPs
for block in self.blocks:
flops += [block.count_flops(h, w)]
if block.stride > 1:
h, w = h / block.stride, w / block.stride
# Head module FLOPs
out_ch = self.blocks[-1].out_ch
flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
# Count flops for classifier
flops += [self.final_conv.output_channels * self.fc.output_size]
return flops, sum(flops)
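The `expected_std` / `beta` bookkeeping in the constructor above is the paper's variance-tracking trick: each residual branch input is pre-scaled by `beta = 1 / expected_std` so it sees roughly unit variance at initialization, and the expected variance grows by `alpha**2` per block until it is reset at each stage's transition block. A standalone sketch of the resulting betas for the first stage, assuming the default `alpha=0.2`:
```python
alpha, expected_std, betas = 0.2, 1.0, []
for block_index in range(6):                    # one hypothetical 6-block stage
    betas.append(1. / expected_std)             # block k is scaled by 1/sqrt(1 + k*alpha**2)
    if block_index == 0:
        expected_std = 1.0                      # reset at the transition block
    expected_std = (expected_std ** 2 + alpha ** 2) ** 0.5
# betas ~= [1.0, 0.981, 0.962, 0.945, 0.928, 0.913]
```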
class NFBlock(hk.Module):
"""Normalizer-Free Net Block."""
def __init__(self, in_ch, out_ch, expansion=0.5, se_ratio=0.5,
kernel_shape=3, group_size=128, stride=1,
beta=1.0, alpha=0.2,
which_conv=base.WSConv2D, activation=jax.nn.gelu,
big_width=True, use_two_convs=True,
stochdepth_rate=None, name=None):
super().__init__(name=name)
self.in_ch, self.out_ch = in_ch, out_ch
self.expansion = expansion
self.se_ratio = se_ratio
self.kernel_shape = kernel_shape
self.activation = activation
self.beta, self.alpha = beta, alpha
# Mimic resnet style bigwidth scaling?
width = int((self.out_ch if big_width else self.in_ch) * expansion)
# Round expanded width based on group count
self.groups = width // group_size
self.width = group_size * self.groups
self.stride = stride
self.use_two_convs = use_two_convs
# Conv 0 (typically expansion conv)
self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
name='conv0')
# Grouped NxN conv
self.conv1 = which_conv(self.width, kernel_shape=kernel_shape,
stride=stride, padding='SAME',
feature_group_count=self.groups, name='conv1')
if self.use_two_convs:
self.conv1b = which_conv(self.width, kernel_shape=kernel_shape,
stride=1, padding='SAME',
feature_group_count=self.groups, name='conv1b')
# Conv 2, typically projection conv
self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
name='conv2')
# Use shortcut conv on channel change or downsample.
self.use_projection = stride > 1 or self.in_ch != self.out_ch
if self.use_projection:
self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
padding='SAME', name='conv_shortcut')
# Squeeze + Excite Module
self.se = base.SqueezeExcite(self.out_ch, self.out_ch, self.se_ratio)
# Are we using stochastic depth?
self._has_stochdepth = (stochdepth_rate is not None and
stochdepth_rate > 0. and stochdepth_rate < 1.0)
if self._has_stochdepth:
self.stoch_depth = base.StochDepth(stochdepth_rate)
def __call__(self, x, is_training):
out = self.activation(x) * self.beta
if self.stride > 1: # Average-pool downsample.
shortcut = hk.avg_pool(out, window_shape=(1, 2, 2, 1),
strides=(1, 2, 2, 1), padding='SAME')
if self.use_projection:
shortcut = self.conv_shortcut(shortcut)
elif self.use_projection:
shortcut = self.conv_shortcut(out)
else:
shortcut = x
out = self.conv0(out)
out = self.conv1(self.activation(out))
if self.use_two_convs:
out = self.conv1b(self.activation(out))
out = self.conv2(self.activation(out))
out = (self.se(out) * 2) * out # Multiply by 2 for rescaling
# Get average residual standard deviation for reporting metrics.
res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
# Apply stochdepth if applicable.
if self._has_stochdepth:
out = self.stoch_depth(out, is_training)
# SkipInit Gain
out = out * hk.get_parameter('skip_gain', (), out.dtype, init=jnp.zeros)
return out * self.alpha + shortcut, res_avg_var
def count_flops(self, h, w):
# Count conv FLOPs based on input HW
expand_flops = base.count_conv_flops(self.in_ch, self.conv0, h, w)
# If block is strided we decrease resolution here.
dw_flops = base.count_conv_flops(self.width, self.conv1, h, w)
if self.stride > 1:
h, w = h / self.stride, w / self.stride
if self.use_two_convs:
dw_flops += base.count_conv_flops(self.width, self.conv1b, h, w)
if self.use_projection:
sc_flops = base.count_conv_flops(self.in_ch, self.conv_shortcut, h, w)
else:
sc_flops = 0
# SE flops happen on avg-pooled activations
se_flops = self.se.fc0.output_size * self.out_ch
se_flops += self.se.fc0.output_size * self.se.fc1.output_size
contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops])
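A minimal sketch of instantiating the model with Haiku (the import path is an assumption, and `hk.transform_with_state` would be needed instead if any submodule keeps Haiku state):
```python
import haiku as hk
import jax
import jax.numpy as jnp
from nfnets import nfnet  # assumed import path for the file above

def forward(images, is_training):
  model = nfnet.NFNet(num_classes=1000, variant='F0')
  return model(images, is_training=is_training)['logits']

net = hk.transform(forward)
rng = jax.random.PRNGKey(0)
images = jnp.zeros((2, 192, 192, 3))   # F0 trains at 192x192, NHWC layout
params = net.init(rng, images, is_training=True)
logits = net.apply(params, rng, images, is_training=True)  # rng drives dropout/stochdepth
```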
+62
@@ -429,6 +429,68 @@ class Fromage(Optimizer):
return out, state
def compute_norm(x, axis, keepdims):
"""Returns norm over arbitrary axis."""
norm = jnp.sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5
return norm
def unitwise_norm(x):
"""Computes norms of each output unit separately, assuming (HW)IO weights."""
if len(jnp.squeeze(x).shape) <= 1: # Scalars and vectors
axis = None
keepdims = False
elif len(x.shape) in [2, 3]: # Linear layers of shape IO
axis = 0
keepdims = True
elif len(x.shape) == 4: # Conv kernels of shape HWIO
axis = [0, 1, 2,]
keepdims = True
else:
raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
return compute_norm(x, axis, keepdims)
class SGD_AGC(Optimizer): # pylint:disable=invalid-name
"""SGD with Unit-Adaptive Gradient-Clipping.
References:
[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
Recognition Without Normalization.
"""
defaults = {'weight_decay': None, 'momentum': None, 'dampening': 0,
'nesterov': None, 'clipping': 0.01, 'eps': 1e-3}
def __init__(self, params, lr, weight_decay=None,
momentum=None, dampening=0, nesterov=None,
clipping=0.01, eps=1e-3):
super().__init__(
params, defaults={'lr': lr, 'weight_decay': weight_decay,
'momentum': momentum, 'dampening': dampening,
'clipping': clipping, 'nesterov': nesterov,
'eps': eps})
def create_buffers(self, name, param):
return SGD.create_buffers(self, name, param)
def update_param(self, param, grad, state, opt_params):
"""Clips grads if necessary, then applies the optimizer update."""
if param is None:
return param, state
if opt_params['clipping'] is not None:
param_norm = jnp.maximum(unitwise_norm(param), opt_params['eps'])
grad_norm = unitwise_norm(grad)
max_norm = param_norm * opt_params['clipping']
# If grad norm > clipping * param_norm, rescale
trigger = grad_norm > max_norm
# Note the max(||G||, 1e-6) is technically unnecessary here, as
# the clipping shouldn't trigger if the grad norm is zero,
# but we include it in practice as a "just-in-case".
clipped_grad = grad * (max_norm / jnp.maximum(grad_norm, 1e-6))
grad = jnp.where(trigger, clipped_grad, grad)
return SGD.update_param(self, param, grad, state, opt_params)
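To make the clipping rule concrete: with the defaults `clipping=0.01` and `eps=1e-3`, a unit whose weight norm is 1.0 may receive a gradient of norm at most 0.01; anything larger is rescaled rather than zeroed. A plain-number sketch of the branch above:
```python
clipping, eps = 0.01, 1e-3
param_norm = max(1.0, eps)               # unit-wise weight norm, floored by eps
grad_norm = 0.05                         # unit-wise gradient norm
max_norm = clipping * param_norm         # 0.01
if grad_norm > max_norm:
    grad_scale = max_norm / max(grad_norm, 1e-6)   # 0.2
# this unit's gradient is multiplied by 0.2, capping its norm at 0.01
```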
class Hybrid(Optimizer):
"""Optimizer which permits passing param groups with different base opts.
+42 -15
@@ -16,20 +16,47 @@
import jax
import jax.numpy as jnp
from nfnets import experiment
from nfnets import experiment_nfnets
config = experiment.get_config()
exp_config = config.experiment_kwargs.config
exp_config.train_batch_size = 2
exp_config.eval_batch_size = 2
exp_config.lr = 0.1
exp_config.fake_data = True
exp_config.model_kwargs.width = 2
print(exp_config.model_kwargs)
xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
bcast = jax.pmap(lambda x: x)
global_step = bcast(jnp.zeros(jax.local_device_count()))
rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
print('Taking a single experiment step for test purposes!')
result = xp.step(global_step, rng)
print(f'Step successfully taken, resulting metrics are {result}')
def test_experiment():
"""Tests the main experiment."""
config = experiment.get_config()
exp_config = config.experiment_kwargs.config
exp_config.train_batch_size = 2
exp_config.eval_batch_size = 2
exp_config.lr = 0.1
exp_config.fake_data = True
exp_config.model_kwargs.width = 2
print(exp_config.model_kwargs)
xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
bcast = jax.pmap(lambda x: x)
global_step = bcast(jnp.zeros(jax.local_device_count()))
rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
print('Taking a single experiment step for test purposes!')
result = xp.step(global_step, rng)
print(f'Step successfully taken, resulting metrics are {result}')
def test_nfnet_experiment():
"""Tests the NFNet experiment."""
config = experiment_nfnets.get_config()
exp_config = config.experiment_kwargs.config
exp_config.train_batch_size = 2
exp_config.eval_batch_size = 2
exp_config.lr = 0.1
exp_config.fake_data = True
exp_config.model_kwargs.width = 2
print(exp_config.model_kwargs)
xp = experiment_nfnets.Experiment('train', exp_config, jax.random.PRNGKey(0))
bcast = jax.pmap(lambda x: x)
global_step = bcast(jnp.zeros(jax.local_device_count()))
rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
print('Taking a single NFNet experiment step for test purposes!')
result = xp.step(global_step, rng)
print(f'NFNet Step successfully taken, resulting metrics are {result}')
test_experiment()
test_nfnet_experiment()