mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-17 23:38:40 +08:00
Formatting fixes and internal changes.
PiperOrigin-RevId: 414656804
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
826ff89f21
commit
d55816beb3
+66
-40
@@ -75,28 +75,33 @@ def get_config():
|
||||
bfloat16=False,
|
||||
lr_schedule=dict(
|
||||
name='WarmupCosineDecay',
|
||||
kwargs=dict(num_steps=config.training_steps,
|
||||
start_val=0,
|
||||
min_val=0,
|
||||
warmup_steps=5*steps_per_epoch),
|
||||
),
|
||||
kwargs=dict(
|
||||
num_steps=config.training_steps,
|
||||
start_val=0,
|
||||
min_val=0,
|
||||
warmup_steps=5 * steps_per_epoch),
|
||||
),
|
||||
lr_scale_by_bs=True,
|
||||
optimizer=dict(
|
||||
name='SGD',
|
||||
kwargs={'momentum': 0.9, 'nesterov': True,
|
||||
'weight_decay': 1e-4,},
|
||||
kwargs={
|
||||
'momentum': 0.9,
|
||||
'nesterov': True,
|
||||
'weight_decay': 1e-4,
|
||||
},
|
||||
),
|
||||
model_kwargs=dict(
|
||||
width=4,
|
||||
which_norm='BatchNorm',
|
||||
norm_kwargs=dict(create_scale=True,
|
||||
create_offset=True,
|
||||
decay_rate=0.9,
|
||||
), # cross_replica_axis='i'),
|
||||
norm_kwargs=dict(
|
||||
create_scale=True,
|
||||
create_offset=True,
|
||||
decay_rate=0.9,
|
||||
), # cross_replica_axis='i'),
|
||||
variant='ResNet50',
|
||||
activation='relu',
|
||||
drop_rate=0.0,
|
||||
),
|
||||
),
|
||||
),))
|
||||
|
||||
# Training loop config: log and checkpoint every minute
|
||||
@@ -136,8 +141,9 @@ class Experiment(experiment.AbstractExperiment):
|
||||
self._eval_input = None
|
||||
|
||||
# Get model, loaded in from the zoo
|
||||
module_prefix = 'nfnets.'
|
||||
self.model_module = importlib.import_module(
|
||||
('nfnets.'+ self.config.model.lower()))
|
||||
(module_prefix + self.config.model.lower()))
|
||||
self.net = hk.transform_with_state(self._forward_fn)
|
||||
|
||||
# Assign image sizes
|
||||
@@ -154,8 +160,8 @@ class Experiment(experiment.AbstractExperiment):
|
||||
self.test_imsize = variant_dict.get('test_imsize', test_imsize)
|
||||
|
||||
donate_argnums = (0, 1, 2, 6, 7) if self.config.use_ema else (0, 1, 2)
|
||||
self.train_fn = jax.pmap(self._train_fn, axis_name='i',
|
||||
donate_argnums=donate_argnums)
|
||||
self.train_fn = jax.pmap(
|
||||
self._train_fn, axis_name='i', donate_argnums=donate_argnums)
|
||||
self.eval_fn = jax.pmap(self._eval_fn, axis_name='i')
|
||||
|
||||
def _initialize_train(self):
|
||||
@@ -163,8 +169,8 @@ class Experiment(experiment.AbstractExperiment):
|
||||
# Initialize net and EMA copy of net if no params available.
|
||||
if self._params is None:
|
||||
inputs = next(self._train_input)
|
||||
init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
|
||||
axis_name='i')
|
||||
init_net = jax.pmap(
|
||||
lambda *a: self.net.init(*a, is_training=True), axis_name='i')
|
||||
init_rng = jl_utils.bcast_local_devices(self.init_rng)
|
||||
self._params, self._state = init_net(init_rng, inputs)
|
||||
if self.config.use_ema:
|
||||
@@ -176,8 +182,9 @@ class Experiment(experiment.AbstractExperiment):
|
||||
def _make_opt(self):
|
||||
# Separate conv params and gains/biases
|
||||
def pred(mod, name, val): # pylint:disable=unused-argument
|
||||
return (name in ['scale', 'offset', 'b']
|
||||
or 'gain' in name or 'bias' in name)
|
||||
return (name in ['scale', 'offset', 'b'] or 'gain' in name or
|
||||
'bias' in name)
|
||||
|
||||
gains_biases, weights = hk.data_structures.partition(pred, self._params)
|
||||
# Lr schedule with batch-based LR scaling
|
||||
if self.config.lr_scale_by_bs:
|
||||
@@ -190,16 +197,22 @@ class Experiment(experiment.AbstractExperiment):
|
||||
opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
|
||||
opt_kwargs['lr'] = lr_schedule
|
||||
opt_module = getattr(optim, self.config.optimizer.name)
|
||||
self.opt = opt_module([{'params': gains_biases, 'weight_decay': None},
|
||||
{'params': weights}], **opt_kwargs)
|
||||
self.opt = opt_module([{
|
||||
'params': gains_biases,
|
||||
'weight_decay': None
|
||||
}, {
|
||||
'params': weights
|
||||
}], **opt_kwargs)
|
||||
if self._opt_state is None:
|
||||
self._opt_state = self.opt.states()
|
||||
else:
|
||||
self.opt.plugin(self._opt_state)
|
||||
|
||||
def _forward_fn(self, inputs, is_training):
|
||||
net_kwargs = {'num_classes': self.config.num_classes,
|
||||
**self.config.model_kwargs}
|
||||
net_kwargs = {
|
||||
'num_classes': self.config.num_classes,
|
||||
**self.config.model_kwargs
|
||||
}
|
||||
net = getattr(self.model_module, self.config.model)(**net_kwargs)
|
||||
if self.config.get('transpose', False):
|
||||
images = jnp.transpose(inputs['images'], (3, 0, 1, 2)) # HWCN -> NHWC
|
||||
@@ -236,8 +249,7 @@ class Experiment(experiment.AbstractExperiment):
|
||||
scaled_loss = loss / jax.device_count() # Grads get psummed so do divide
|
||||
return scaled_loss, (metrics, state)
|
||||
|
||||
def _train_fn(self, params, states, opt_states,
|
||||
inputs, rng, global_step,
|
||||
def _train_fn(self, params, states, opt_states, inputs, rng, global_step,
|
||||
ema_params, ema_states):
|
||||
"""Runs one batch forward + backward and run a single opt step."""
|
||||
grad_fn = jax.grad(self._loss_fn, argnums=0, has_aux=True)
|
||||
@@ -260,9 +272,14 @@ class Experiment(experiment.AbstractExperiment):
|
||||
ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
|
||||
ema_params = jax.tree_multimap(ema, ema_params, params)
|
||||
ema_states = jax.tree_multimap(ema, ema_states, states)
|
||||
return {'params': params, 'states': states, 'opt_states': opt_states,
|
||||
'ema_params': ema_params, 'ema_states': ema_states,
|
||||
'metrics': metrics}
|
||||
return {
|
||||
'params': params,
|
||||
'states': states,
|
||||
'opt_states': opt_states,
|
||||
'ema_params': ema_params,
|
||||
'ema_states': ema_states,
|
||||
'metrics': metrics
|
||||
}
|
||||
|
||||
# _ _
|
||||
# | |_ _ __ __ _(_)_ __
|
||||
@@ -275,10 +292,15 @@ class Experiment(experiment.AbstractExperiment):
|
||||
if self._train_input is None:
|
||||
self._initialize_train()
|
||||
inputs = next(self._train_input)
|
||||
out = self.train_fn(params=self._params, states=self._state,
|
||||
opt_states=self._opt_state, inputs=inputs,
|
||||
rng=rng, global_step=global_step,
|
||||
ema_params=self._ema_params, ema_states=self._ema_state)
|
||||
out = self.train_fn(
|
||||
params=self._params,
|
||||
states=self._state,
|
||||
opt_states=self._opt_state,
|
||||
inputs=inputs,
|
||||
rng=rng,
|
||||
global_step=global_step,
|
||||
ema_params=self._ema_params,
|
||||
ema_states=self._ema_state)
|
||||
self._params, self._state = out['params'], out['states']
|
||||
self._opt_state = out['opt_states']
|
||||
self._ema_params, self._ema_state = out['ema_params'], out['ema_states']
|
||||
@@ -294,7 +316,8 @@ class Experiment(experiment.AbstractExperiment):
|
||||
f'Global batch size {global_batch_size} must be divisible by '
|
||||
f'num devices {num_devices}')
|
||||
return dataset.load(
|
||||
dataset.Split.TRAIN_AND_VALID, is_training=True,
|
||||
dataset.Split.TRAIN_AND_VALID,
|
||||
is_training=True,
|
||||
batch_dims=[jax.local_device_count(), bs_per_device],
|
||||
transpose=self.config.get('transpose', False),
|
||||
image_size=(self.train_imsize,) * 2,
|
||||
@@ -350,13 +373,16 @@ class Experiment(experiment.AbstractExperiment):
|
||||
bs_per_device = (self.config.eval_batch_size // jax.local_device_count())
|
||||
split = dataset.Split.from_string(self.config.eval_subset)
|
||||
eval_preproc = self.config.get('eval_preproc', 'crop_resize')
|
||||
return dataset.load(split, is_training=False,
|
||||
batch_dims=[jax.local_device_count(), bs_per_device],
|
||||
transpose=self.config.get('transpose', False),
|
||||
image_size=(self.test_imsize,) * 2,
|
||||
name=self.config.which_dataset,
|
||||
eval_preproc=eval_preproc,
|
||||
fake_data=self.config.get('fake_data', False))
|
||||
return dataset.load(
|
||||
split,
|
||||
is_training=False,
|
||||
batch_dims=[jax.local_device_count(), bs_per_device],
|
||||
transpose=self.config.get('transpose', False),
|
||||
image_size=(self.test_imsize,) * 2,
|
||||
name=self.config.which_dataset,
|
||||
eval_preproc=eval_preproc,
|
||||
fake_data=self.config.get('fake_data', False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('config')
|
||||
|
||||
@@ -14,11 +14,18 @@
|
||||
# ==============================================================================
|
||||
r"""ImageNet experiment with NFNets."""
|
||||
|
||||
import sys
|
||||
|
||||
from absl import flags
|
||||
import haiku as hk
|
||||
from jaxline import platform
|
||||
from ml_collections import config_dict
|
||||
|
||||
from nfnets import experiment
|
||||
from nfnets import optim
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Return config object for training."""
|
||||
@@ -124,3 +131,8 @@ class Experiment(experiment.Experiment):
|
||||
self._opt_state = self.opt.states()
|
||||
else:
|
||||
self.opt.plugin(self._opt_state)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('config')
|
||||
platform.main(Experiment, sys.argv[1:])
|
||||
|
||||
@@ -121,11 +121,8 @@ class FixUp_ResNet(hk.Module):
|
||||
flops += [block.count_flops(h, w)]
|
||||
if block.stride > 1:
|
||||
h, w = h / block.stride, w / block.stride
|
||||
# Head module FLOPs
|
||||
out_ch = self.blocks[-1].out_ch
|
||||
flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
|
||||
# Count flops for classifier
|
||||
flops += [self.final_conv.output_channels * self.fc.output_size]
|
||||
flops += [self.blocks[-1].out_ch * self.fc.output_size]
|
||||
return flops, sum(flops)
|
||||
|
||||
|
||||
@@ -213,4 +210,3 @@ class ResBlock(hk.Module):
|
||||
sc_flops = 0
|
||||
contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
|
||||
return sum([expand_flops, dw_flops, contract_flops, sc_flops])
|
||||
|
||||
|
||||
+2
-6
@@ -81,7 +81,7 @@ class NF_ResNet(hk.Module):
|
||||
)]
|
||||
ch = block_width
|
||||
index += 1
|
||||
# Reset expected std but still give it 1 block of growth
|
||||
# Reset expected std but still give it 1 block of growth
|
||||
if block_index == 0:
|
||||
expected_std = 1.0
|
||||
expected_std = (expected_std **2 + alpha**2)**0.5
|
||||
@@ -124,11 +124,8 @@ class NF_ResNet(hk.Module):
|
||||
flops += [block.count_flops(h, w)]
|
||||
if block.stride > 1:
|
||||
h, w = h / block.stride, w / block.stride
|
||||
# Head module FLOPs
|
||||
out_ch = self.blocks[-1].out_ch
|
||||
flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
|
||||
# Count flops for classifier
|
||||
flops += [self.final_conv.output_channels * self.fc.output_size]
|
||||
flops += [self.blocks[-1].out_ch * self.fc.output_size]
|
||||
return flops, sum(flops)
|
||||
|
||||
|
||||
@@ -213,4 +210,3 @@ class NFResBlock(hk.Module):
|
||||
se_flops += self.se.fc0.output_size * self.se.fc1.output_size
|
||||
contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
|
||||
return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops])
|
||||
|
||||
|
||||
+1
-1
@@ -132,7 +132,7 @@ class Optimizer(object):
|
||||
elif '_states' in self.__dict__ and name in self._states:
|
||||
return self._states[name]
|
||||
else:
|
||||
object.__getattr__(self, name)
|
||||
object.__getattribute__(self, name)
|
||||
|
||||
def step(self, params, grads, states, itr=None):
|
||||
"""Takes a single optimizer step.
|
||||
|
||||
@@ -117,11 +117,8 @@ class SkipInit_ResNet(hk.Module):
|
||||
flops += [block.count_flops(h, w)]
|
||||
if block.stride > 1:
|
||||
h, w = h / block.stride, w / block.stride
|
||||
# Head module FLOPs
|
||||
out_ch = self.blocks[-1].out_ch
|
||||
flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
|
||||
# Count flops for classifier
|
||||
flops += [self.final_conv.output_channels * self.fc.output_size]
|
||||
flops += [self.blocks[-1].out_ch * self.fc.output_size]
|
||||
return flops, sum(flops)
|
||||
|
||||
|
||||
@@ -191,4 +188,3 @@ class NFResBlock(hk.Module):
|
||||
# SE flops happen on avg-pooled activations
|
||||
contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
|
||||
return sum([expand_flops, dw_flops, contract_flops, sc_flops])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user