From d55816beb3d55be2417dba8de2e7cf79b30ab5fc Mon Sep 17 00:00:00 2001 From: Loren Maggiore Date: Tue, 7 Dec 2021 09:58:30 +0000 Subject: [PATCH] Formatting fixes and internal changes. PiperOrigin-RevId: 414656804 --- nfnets/experiment.py | 106 ++++++++++++++++++++++-------------- nfnets/experiment_nfnets.py | 12 ++++ nfnets/fixup_resnet.py | 6 +- nfnets/nf_resnet.py | 8 +-- nfnets/optim.py | 2 +- nfnets/skipinit_resnet.py | 6 +- 6 files changed, 83 insertions(+), 57 deletions(-) diff --git a/nfnets/experiment.py b/nfnets/experiment.py index 724e5ea..0cf137e 100644 --- a/nfnets/experiment.py +++ b/nfnets/experiment.py @@ -75,28 +75,33 @@ def get_config(): bfloat16=False, lr_schedule=dict( name='WarmupCosineDecay', - kwargs=dict(num_steps=config.training_steps, - start_val=0, - min_val=0, - warmup_steps=5*steps_per_epoch), - ), + kwargs=dict( + num_steps=config.training_steps, + start_val=0, + min_val=0, + warmup_steps=5 * steps_per_epoch), + ), lr_scale_by_bs=True, optimizer=dict( name='SGD', - kwargs={'momentum': 0.9, 'nesterov': True, - 'weight_decay': 1e-4,}, + kwargs={ + 'momentum': 0.9, + 'nesterov': True, + 'weight_decay': 1e-4, + }, ), model_kwargs=dict( width=4, which_norm='BatchNorm', - norm_kwargs=dict(create_scale=True, - create_offset=True, - decay_rate=0.9, - ), # cross_replica_axis='i'), + norm_kwargs=dict( + create_scale=True, + create_offset=True, + decay_rate=0.9, + ), # cross_replica_axis='i'), variant='ResNet50', activation='relu', drop_rate=0.0, - ), + ), ),)) # Training loop config: log and checkpoint every minute @@ -136,8 +141,9 @@ class Experiment(experiment.AbstractExperiment): self._eval_input = None # Get model, loaded in from the zoo + module_prefix = 'nfnets.' self.model_module = importlib.import_module( - ('nfnets.'+ self.config.model.lower())) + (module_prefix + self.config.model.lower())) self.net = hk.transform_with_state(self._forward_fn) # Assign image sizes @@ -154,8 +160,8 @@ class Experiment(experiment.AbstractExperiment): self.test_imsize = variant_dict.get('test_imsize', test_imsize) donate_argnums = (0, 1, 2, 6, 7) if self.config.use_ema else (0, 1, 2) - self.train_fn = jax.pmap(self._train_fn, axis_name='i', - donate_argnums=donate_argnums) + self.train_fn = jax.pmap( + self._train_fn, axis_name='i', donate_argnums=donate_argnums) self.eval_fn = jax.pmap(self._eval_fn, axis_name='i') def _initialize_train(self): @@ -163,8 +169,8 @@ class Experiment(experiment.AbstractExperiment): # Initialize net and EMA copy of net if no params available. if self._params is None: inputs = next(self._train_input) - init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True), - axis_name='i') + init_net = jax.pmap( + lambda *a: self.net.init(*a, is_training=True), axis_name='i') init_rng = jl_utils.bcast_local_devices(self.init_rng) self._params, self._state = init_net(init_rng, inputs) if self.config.use_ema: @@ -176,8 +182,9 @@ class Experiment(experiment.AbstractExperiment): def _make_opt(self): # Separate conv params and gains/biases def pred(mod, name, val): # pylint:disable=unused-argument - return (name in ['scale', 'offset', 'b'] - or 'gain' in name or 'bias' in name) + return (name in ['scale', 'offset', 'b'] or 'gain' in name or + 'bias' in name) + gains_biases, weights = hk.data_structures.partition(pred, self._params) # Lr schedule with batch-based LR scaling if self.config.lr_scale_by_bs: @@ -190,16 +197,22 @@ class Experiment(experiment.AbstractExperiment): opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()} opt_kwargs['lr'] = lr_schedule opt_module = getattr(optim, self.config.optimizer.name) - self.opt = opt_module([{'params': gains_biases, 'weight_decay': None}, - {'params': weights}], **opt_kwargs) + self.opt = opt_module([{ + 'params': gains_biases, + 'weight_decay': None + }, { + 'params': weights + }], **opt_kwargs) if self._opt_state is None: self._opt_state = self.opt.states() else: self.opt.plugin(self._opt_state) def _forward_fn(self, inputs, is_training): - net_kwargs = {'num_classes': self.config.num_classes, - **self.config.model_kwargs} + net_kwargs = { + 'num_classes': self.config.num_classes, + **self.config.model_kwargs + } net = getattr(self.model_module, self.config.model)(**net_kwargs) if self.config.get('transpose', False): images = jnp.transpose(inputs['images'], (3, 0, 1, 2)) # HWCN -> NHWC @@ -236,8 +249,7 @@ class Experiment(experiment.AbstractExperiment): scaled_loss = loss / jax.device_count() # Grads get psummed so do divide return scaled_loss, (metrics, state) - def _train_fn(self, params, states, opt_states, - inputs, rng, global_step, + def _train_fn(self, params, states, opt_states, inputs, rng, global_step, ema_params, ema_states): """Runs one batch forward + backward and run a single opt step.""" grad_fn = jax.grad(self._loss_fn, argnums=0, has_aux=True) @@ -260,9 +272,14 @@ class Experiment(experiment.AbstractExperiment): ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step) ema_params = jax.tree_multimap(ema, ema_params, params) ema_states = jax.tree_multimap(ema, ema_states, states) - return {'params': params, 'states': states, 'opt_states': opt_states, - 'ema_params': ema_params, 'ema_states': ema_states, - 'metrics': metrics} + return { + 'params': params, + 'states': states, + 'opt_states': opt_states, + 'ema_params': ema_params, + 'ema_states': ema_states, + 'metrics': metrics + } # _ _ # | |_ _ __ __ _(_)_ __ @@ -275,10 +292,15 @@ class Experiment(experiment.AbstractExperiment): if self._train_input is None: self._initialize_train() inputs = next(self._train_input) - out = self.train_fn(params=self._params, states=self._state, - opt_states=self._opt_state, inputs=inputs, - rng=rng, global_step=global_step, - ema_params=self._ema_params, ema_states=self._ema_state) + out = self.train_fn( + params=self._params, + states=self._state, + opt_states=self._opt_state, + inputs=inputs, + rng=rng, + global_step=global_step, + ema_params=self._ema_params, + ema_states=self._ema_state) self._params, self._state = out['params'], out['states'] self._opt_state = out['opt_states'] self._ema_params, self._ema_state = out['ema_params'], out['ema_states'] @@ -294,7 +316,8 @@ class Experiment(experiment.AbstractExperiment): f'Global batch size {global_batch_size} must be divisible by ' f'num devices {num_devices}') return dataset.load( - dataset.Split.TRAIN_AND_VALID, is_training=True, + dataset.Split.TRAIN_AND_VALID, + is_training=True, batch_dims=[jax.local_device_count(), bs_per_device], transpose=self.config.get('transpose', False), image_size=(self.train_imsize,) * 2, @@ -350,13 +373,16 @@ class Experiment(experiment.AbstractExperiment): bs_per_device = (self.config.eval_batch_size // jax.local_device_count()) split = dataset.Split.from_string(self.config.eval_subset) eval_preproc = self.config.get('eval_preproc', 'crop_resize') - return dataset.load(split, is_training=False, - batch_dims=[jax.local_device_count(), bs_per_device], - transpose=self.config.get('transpose', False), - image_size=(self.test_imsize,) * 2, - name=self.config.which_dataset, - eval_preproc=eval_preproc, - fake_data=self.config.get('fake_data', False)) + return dataset.load( + split, + is_training=False, + batch_dims=[jax.local_device_count(), bs_per_device], + transpose=self.config.get('transpose', False), + image_size=(self.test_imsize,) * 2, + name=self.config.which_dataset, + eval_preproc=eval_preproc, + fake_data=self.config.get('fake_data', False)) + if __name__ == '__main__': flags.mark_flag_as_required('config') diff --git a/nfnets/experiment_nfnets.py b/nfnets/experiment_nfnets.py index 8de7759..fa07a0b 100644 --- a/nfnets/experiment_nfnets.py +++ b/nfnets/experiment_nfnets.py @@ -14,11 +14,18 @@ # ============================================================================== r"""ImageNet experiment with NFNets.""" +import sys + +from absl import flags import haiku as hk +from jaxline import platform from ml_collections import config_dict + from nfnets import experiment from nfnets import optim +FLAGS = flags.FLAGS + def get_config(): """Return config object for training.""" @@ -124,3 +131,8 @@ class Experiment(experiment.Experiment): self._opt_state = self.opt.states() else: self.opt.plugin(self._opt_state) + + +if __name__ == '__main__': + flags.mark_flag_as_required('config') + platform.main(Experiment, sys.argv[1:]) diff --git a/nfnets/fixup_resnet.py b/nfnets/fixup_resnet.py index 291933c..745e4f2 100644 --- a/nfnets/fixup_resnet.py +++ b/nfnets/fixup_resnet.py @@ -121,11 +121,8 @@ class FixUp_ResNet(hk.Module): flops += [block.count_flops(h, w)] if block.stride > 1: h, w = h / block.stride, w / block.stride - # Head module FLOPs - out_ch = self.blocks[-1].out_ch - flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)] # Count flops for classifier - flops += [self.final_conv.output_channels * self.fc.output_size] + flops += [self.blocks[-1].out_ch * self.fc.output_size] return flops, sum(flops) @@ -213,4 +210,3 @@ class ResBlock(hk.Module): sc_flops = 0 contract_flops = base.count_conv_flops(self.width, self.conv2, h, w) return sum([expand_flops, dw_flops, contract_flops, sc_flops]) - diff --git a/nfnets/nf_resnet.py b/nfnets/nf_resnet.py index c179b1e..83872a1 100644 --- a/nfnets/nf_resnet.py +++ b/nfnets/nf_resnet.py @@ -81,7 +81,7 @@ class NF_ResNet(hk.Module): )] ch = block_width index += 1 - # Reset expected std but still give it 1 block of growth + # Reset expected std but still give it 1 block of growth if block_index == 0: expected_std = 1.0 expected_std = (expected_std **2 + alpha**2)**0.5 @@ -124,11 +124,8 @@ class NF_ResNet(hk.Module): flops += [block.count_flops(h, w)] if block.stride > 1: h, w = h / block.stride, w / block.stride - # Head module FLOPs - out_ch = self.blocks[-1].out_ch - flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)] # Count flops for classifier - flops += [self.final_conv.output_channels * self.fc.output_size] + flops += [self.blocks[-1].out_ch * self.fc.output_size] return flops, sum(flops) @@ -213,4 +210,3 @@ class NFResBlock(hk.Module): se_flops += self.se.fc0.output_size * self.se.fc1.output_size contract_flops = base.count_conv_flops(self.width, self.conv2, h, w) return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops]) - diff --git a/nfnets/optim.py b/nfnets/optim.py index 84f8129..fa7447a 100644 --- a/nfnets/optim.py +++ b/nfnets/optim.py @@ -132,7 +132,7 @@ class Optimizer(object): elif '_states' in self.__dict__ and name in self._states: return self._states[name] else: - object.__getattr__(self, name) + object.__getattribute__(self, name) def step(self, params, grads, states, itr=None): """Takes a single optimizer step. diff --git a/nfnets/skipinit_resnet.py b/nfnets/skipinit_resnet.py index 06d6ba7..1bb586a 100644 --- a/nfnets/skipinit_resnet.py +++ b/nfnets/skipinit_resnet.py @@ -117,11 +117,8 @@ class SkipInit_ResNet(hk.Module): flops += [block.count_flops(h, w)] if block.stride > 1: h, w = h / block.stride, w / block.stride - # Head module FLOPs - out_ch = self.blocks[-1].out_ch - flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)] # Count flops for classifier - flops += [self.final_conv.output_channels * self.fc.output_size] + flops += [self.blocks[-1].out_ch * self.fc.output_size] return flops, sum(flops) @@ -191,4 +188,3 @@ class NFResBlock(hk.Module): # SE flops happen on avg-pooled activations contract_flops = base.count_conv_flops(self.width, self.conv2, h, w) return sum([expand_flops, dw_flops, contract_flops, sc_flops]) -