Formatting fixes and internal changes.

PiperOrigin-RevId: 414656804
Author:     Loren Maggiore
Date:       2021-12-07 09:58:30 +00:00
Committed:  Saran Tunyasuvunakool
Parent:     826ff89f21
Commit:     d55816beb3
6 changed files with 83 additions and 57 deletions
+66 -40
@@ -75,28 +75,33 @@ def get_config():
          bfloat16=False,
          lr_schedule=dict(
              name='WarmupCosineDecay',
-             kwargs=dict(num_steps=config.training_steps,
-                         start_val=0,
-                         min_val=0,
-                         warmup_steps=5*steps_per_epoch),
-             ),
+             kwargs=dict(
+                 num_steps=config.training_steps,
+                 start_val=0,
+                 min_val=0,
+                 warmup_steps=5 * steps_per_epoch),
+         ),
          lr_scale_by_bs=True,
          optimizer=dict(
              name='SGD',
-             kwargs={'momentum': 0.9, 'nesterov': True,
-                     'weight_decay': 1e-4,},
+             kwargs={
+                 'momentum': 0.9,
+                 'nesterov': True,
+                 'weight_decay': 1e-4,
+             },
          ),
          model_kwargs=dict(
              width=4,
              which_norm='BatchNorm',
-             norm_kwargs=dict(create_scale=True,
-                              create_offset=True,
-                              decay_rate=0.9,
-                              ),  # cross_replica_axis='i'),
+             norm_kwargs=dict(
+                 create_scale=True,
+                 create_offset=True,
+                 decay_rate=0.9,
+             ),  # cross_replica_axis='i'),
              variant='ResNet50',
              activation='relu',
              drop_rate=0.0,
          ),
      ),
  ),))
  # Training loop config: log and checkpoint every minute
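
For reference, a minimal sketch of the warmup-plus-cosine schedule these kwargs describe (linear warmup from start_val, then cosine decay to min_val over num_steps). This illustrates the standard scheme the config suggests; it is not the repo's optim.WarmupCosineDecay implementation:

import jax.numpy as jnp

def warmup_cosine_decay(step, max_val, num_steps, warmup_steps,
                        start_val=0., min_val=0.):
  # Linear warmup from start_val to max_val over warmup_steps...
  warmup = start_val + (max_val - start_val) * step / jnp.maximum(warmup_steps, 1)
  # ...then cosine decay from max_val down to min_val over the remaining steps.
  frac = (step - warmup_steps) / jnp.maximum(num_steps - warmup_steps, 1)
  cosine = min_val + 0.5 * (max_val - min_val) * (1 + jnp.cos(jnp.pi * frac))
  return jnp.where(step < warmup_steps, warmup, cosine)
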
@@ -136,8 +141,9 @@ class Experiment(experiment.AbstractExperiment):
     self._eval_input = None

     # Get model, loaded in from the zoo
+    module_prefix = 'nfnets.'
     self.model_module = importlib.import_module(
-        ('nfnets.'+ self.config.model.lower()))
+        (module_prefix + self.config.model.lower()))
     self.net = hk.transform_with_state(self._forward_fn)

     # Assign image sizes
@@ -154,8 +160,8 @@ class Experiment(experiment.AbstractExperiment):
     self.test_imsize = variant_dict.get('test_imsize', test_imsize)

     donate_argnums = (0, 1, 2, 6, 7) if self.config.use_ema else (0, 1, 2)
-    self.train_fn = jax.pmap(self._train_fn, axis_name='i',
-                             donate_argnums=donate_argnums)
+    self.train_fn = jax.pmap(
+        self._train_fn, axis_name='i', donate_argnums=donate_argnums)
     self.eval_fn = jax.pmap(self._eval_fn, axis_name='i')

   def _initialize_train(self):
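
A toy example (hypothetical names) of the two jax.pmap features used here: axis_name='i' names the device axis for collectives, and donate_argnums lets XLA reuse the donated input buffers for outputs, cutting peak device memory. The (0, 1, 2, 6, 7) set above likewise donates the EMA trees only when they exist:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='i', donate_argnums=(0,))
def sgd_step(params, grads):
  # Average grads across devices, then update; params buffers may be reused.
  grads = jax.lax.pmean(grads, axis_name='i')
  return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# When called, each argument's leading axis must equal jax.local_device_count().
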
@@ -163,8 +169,8 @@ class Experiment(experiment.AbstractExperiment):
     # Initialize net and EMA copy of net if no params available.
     if self._params is None:
       inputs = next(self._train_input)
-      init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
-                          axis_name='i')
+      init_net = jax.pmap(
+          lambda *a: self.net.init(*a, is_training=True), axis_name='i')
       init_rng = jl_utils.bcast_local_devices(self.init_rng)
       self._params, self._state = init_net(init_rng, inputs)
       if self.config.use_ema:
@@ -176,8 +182,9 @@ class Experiment(experiment.AbstractExperiment):
   def _make_opt(self):
     # Separate conv params and gains/biases
     def pred(mod, name, val):  # pylint:disable=unused-argument
-      return (name in ['scale', 'offset', 'b']
-              or 'gain' in name or 'bias' in name)
+      return (name in ['scale', 'offset', 'b'] or 'gain' in name or
+              'bias' in name)
+
     gains_biases, weights = hk.data_structures.partition(pred, self._params)
     # Lr schedule with batch-based LR scaling
     if self.config.lr_scale_by_bs:
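
For context, hk.data_structures.partition splits a parameter tree with a predicate over (module_name, param_name, value), returning the matching and non-matching halves. A self-contained sketch with hypothetical toy params, mirroring the no-weight-decay split above:

import haiku as hk
import jax.numpy as jnp

params = {
    'res_block/conv_0': {'w': jnp.ones([3, 3, 8, 8]), 'b': jnp.zeros([8])},
    'res_block/batch_norm': {'scale': jnp.ones([8]), 'offset': jnp.zeros([8])},
}

def no_decay(module_name, param_name, value):  # True -> first partition
  return param_name in ('scale', 'offset', 'b')

gains_biases, weights = hk.data_structures.partition(no_decay, params)
# gains_biases holds conv_0/b plus batch_norm scale/offset (no weight decay);
# weights holds conv_0/w (decayed).
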
@@ -190,16 +197,22 @@ class Experiment(experiment.AbstractExperiment):
     opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
     opt_kwargs['lr'] = lr_schedule
     opt_module = getattr(optim, self.config.optimizer.name)
-    self.opt = opt_module([{'params': gains_biases, 'weight_decay': None},
-                           {'params': weights}], **opt_kwargs)
+    self.opt = opt_module([{
+        'params': gains_biases,
+        'weight_decay': None
+    }, {
+        'params': weights
+    }], **opt_kwargs)
     if self._opt_state is None:
       self._opt_state = self.opt.states()
     else:
       self.opt.plugin(self._opt_state)

   def _forward_fn(self, inputs, is_training):
-    net_kwargs = {'num_classes': self.config.num_classes,
-                  **self.config.model_kwargs}
+    net_kwargs = {
+        'num_classes': self.config.num_classes,
+        **self.config.model_kwargs
+    }
     net = getattr(self.model_module, self.config.model)(**net_kwargs)
     if self.config.get('transpose', False):
       images = jnp.transpose(inputs['images'], (3, 0, 1, 2))  # HWCN -> NHWC
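
The transpose branch undoes the double-transpose trick: batches are shipped host-to-device in HWCN layout (commonly faster to transfer on TPU) and restored to NHWC on device. A quick shape check of the permutation used above:

import jax.numpy as jnp

hwcn = jnp.zeros((224, 224, 3, 32))       # H, W, C, N as transferred
nhwc = jnp.transpose(hwcn, (3, 0, 1, 2))  # back to N, H, W, C
assert nhwc.shape == (32, 224, 224, 3)
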
@@ -236,8 +249,7 @@ class Experiment(experiment.AbstractExperiment):
       scaled_loss = loss / jax.device_count()  # Grads get psummed so do divide
       return scaled_loss, (metrics, state)

-  def _train_fn(self, params, states, opt_states,
-                inputs, rng, global_step,
+  def _train_fn(self, params, states, opt_states, inputs, rng, global_step,
                 ema_params, ema_states):
     """Runs one batch forward + backward and run a single opt step."""
     grad_fn = jax.grad(self._loss_fn, argnums=0, has_aux=True)
@@ -260,9 +272,14 @@ class Experiment(experiment.AbstractExperiment):
       ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
       ema_params = jax.tree_multimap(ema, ema_params, params)
       ema_states = jax.tree_multimap(ema, ema_states, states)
-    return {'params': params, 'states': states, 'opt_states': opt_states,
-            'ema_params': ema_params, 'ema_states': ema_states,
-            'metrics': metrics}
+    return {
+        'params': params,
+        'states': states,
+        'opt_states': opt_states,
+        'ema_params': ema_params,
+        'ema_states': ema_states,
+        'metrics': metrics
+    }

   #  _             _
   # | |_ _ __ __ _(_)_ __
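
jax.tree_multimap (since deprecated in favour of jax.tree_util.tree_map) applies the EMA update leaf-wise across the parameter trees. A minimal sketch; the step-dependent decay ramp shown mirrors tf.train.ExponentialMovingAverage's num_updates correction and is an assumption about ema_fn, not a copy of it:

import jax
import jax.numpy as jnp

def ema_fn(ema_value, current_value, decay, step):
  # Ramp the decay early in training so the EMA is not dominated by the
  # initial values (tf1-style num_updates correction -- an assumption here).
  decay = jnp.minimum(decay, (1. + step) / (10. + step))
  return ema_value * decay + current_value * (1. - decay)

params = {'w': jnp.ones(3)}
ema_params = {'w': jnp.zeros(3)}
ema = lambda x, y: ema_fn(x, y, 0.99999, 100)
ema_params = jax.tree_util.tree_map(ema, ema_params, params)
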
@@ -275,10 +292,15 @@ class Experiment(experiment.AbstractExperiment):
     if self._train_input is None:
       self._initialize_train()
     inputs = next(self._train_input)
-    out = self.train_fn(params=self._params, states=self._state,
-                        opt_states=self._opt_state, inputs=inputs,
-                        rng=rng, global_step=global_step,
-                        ema_params=self._ema_params, ema_states=self._ema_state)
+    out = self.train_fn(
+        params=self._params,
+        states=self._state,
+        opt_states=self._opt_state,
+        inputs=inputs,
+        rng=rng,
+        global_step=global_step,
+        ema_params=self._ema_params,
+        ema_states=self._ema_state)
     self._params, self._state = out['params'], out['states']
     self._opt_state = out['opt_states']
     self._ema_params, self._ema_state = out['ema_params'], out['ema_states']
@@ -294,7 +316,8 @@ class Experiment(experiment.AbstractExperiment):
         f'Global batch size {global_batch_size} must be divisible by '
         f'num devices {num_devices}')
     return dataset.load(
-        dataset.Split.TRAIN_AND_VALID, is_training=True,
+        dataset.Split.TRAIN_AND_VALID,
+        is_training=True,
         batch_dims=[jax.local_device_count(), bs_per_device],
         transpose=self.config.get('transpose', False),
         image_size=(self.train_imsize,) * 2,
@@ -350,13 +373,16 @@ class Experiment(experiment.AbstractExperiment):
     bs_per_device = (self.config.eval_batch_size // jax.local_device_count())
     split = dataset.Split.from_string(self.config.eval_subset)
     eval_preproc = self.config.get('eval_preproc', 'crop_resize')
-    return dataset.load(split, is_training=False,
-                        batch_dims=[jax.local_device_count(), bs_per_device],
-                        transpose=self.config.get('transpose', False),
-                        image_size=(self.test_imsize,) * 2,
-                        name=self.config.which_dataset,
-                        eval_preproc=eval_preproc,
-                        fake_data=self.config.get('fake_data', False))
+    return dataset.load(
+        split,
+        is_training=False,
+        batch_dims=[jax.local_device_count(), bs_per_device],
+        transpose=self.config.get('transpose', False),
+        image_size=(self.test_imsize,) * 2,
+        name=self.config.which_dataset,
+        eval_preproc=eval_preproc,
+        fake_data=self.config.get('fake_data', False))

 if __name__ == '__main__':
   flags.mark_flag_as_required('config')
+12
@@ -14,11 +14,18 @@
 # ==============================================================================
 r"""ImageNet experiment with NFNets."""
+import sys
+
+from absl import flags
 import haiku as hk
+from jaxline import platform
 from ml_collections import config_dict

 from nfnets import experiment
 from nfnets import optim
+
+FLAGS = flags.FLAGS
+


 def get_config():
   """Return config object for training."""
@@ -124,3 +131,8 @@ class Experiment(experiment.Experiment):
       self._opt_state = self.opt.states()
     else:
       self.opt.plugin(self._opt_state)
+
+
+if __name__ == '__main__':
+  flags.mark_flag_as_required('config')
+  platform.main(Experiment, sys.argv[1:])
+1 -5
@@ -121,11 +121,8 @@ class FixUp_ResNet(hk.Module):
       flops += [block.count_flops(h, w)]
       if block.stride > 1:
         h, w = h / block.stride, w / block.stride
-    # Head module FLOPs
-    out_ch = self.blocks[-1].out_ch
-    flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
     # Count flops for classifier
-    flops += [self.final_conv.output_channels * self.fc.output_size]
+    flops += [self.blocks[-1].out_ch * self.fc.output_size]
     return flops, sum(flops)
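
The removed lines referenced self.final_conv, which this plain-ResNet variant does not define (unlike the NFNet F-models), so the head count is corrected to just the linear classifier's multiply-adds. For a ResNet50-style head on ImageNet, the arithmetic would be:

out_ch = 2048                    # self.blocks[-1].out_ch after the final stage
num_classes = 1000               # self.fc.output_size
fc_flops = out_ch * num_classes  # 2,048,000 multiply-adds
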
@@ -213,4 +210,3 @@ class ResBlock(hk.Module):
       sc_flops = 0
     contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
     return sum([expand_flops, dw_flops, contract_flops, sc_flops])
-
+2 -6
@@ -81,7 +81,7 @@ class NF_ResNet(hk.Module):
             )]
         ch = block_width
         index += 1
-         # Reset expected std but still give it 1 block of growth
+        # Reset expected std but still give it 1 block of growth
         if block_index == 0:
           expected_std = 1.0
         expected_std = (expected_std **2 + alpha**2)**0.5
@@ -124,11 +124,8 @@ class NF_ResNet(hk.Module):
       flops += [block.count_flops(h, w)]
       if block.stride > 1:
         h, w = h / block.stride, w / block.stride
-    # Head module FLOPs
-    out_ch = self.blocks[-1].out_ch
-    flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
     # Count flops for classifier
-    flops += [self.final_conv.output_channels * self.fc.output_size]
+    flops += [self.blocks[-1].out_ch * self.fc.output_size]
     return flops, sum(flops)
@@ -213,4 +210,3 @@ class NFResBlock(hk.Module):
       se_flops += self.se.fc0.output_size * self.se.fc1.output_size
     contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
     return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops])
-
+1 -1
@@ -132,7 +132,7 @@ class Optimizer(object):
     elif '_states' in self.__dict__ and name in self._states:
       return self._states[name]
     else:
-      object.__getattr__(self, name)
+      object.__getattribute__(self, name)

   def step(self, params, grads, states, itr=None):
     """Takes a single optimizer step.
+1 -5
@@ -117,11 +117,8 @@ class SkipInit_ResNet(hk.Module):
       flops += [block.count_flops(h, w)]
       if block.stride > 1:
         h, w = h / block.stride, w / block.stride
-    # Head module FLOPs
-    out_ch = self.blocks[-1].out_ch
-    flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
     # Count flops for classifier
-    flops += [self.final_conv.output_channels * self.fc.output_size]
+    flops += [self.blocks[-1].out_ch * self.fc.output_size]
     return flops, sum(flops)
@@ -191,4 +188,3 @@ class NFResBlock(hk.Module):
     # SE flops happen on avg-pooled activations
     contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
     return sum([expand_flops, dw_flops, contract_flops, sc_flops])
-