diff --git a/adversarial_robustness/jax/experiment.py b/adversarial_robustness/jax/experiment.py index d604921..20c037f 100644 --- a/adversarial_robustness/jax/experiment.py +++ b/adversarial_robustness/jax/experiment.py @@ -62,12 +62,17 @@ def get_config(): # https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness. # If the path is set to "cifar10_ddpm.npz" and is not found in the current # directory, the corresponding data will be downloaded. - extra_npz = 'cifar10_ddpm.npz' + extra_npz = 'cifar10_ddpm.npz' # Can be `None`. # Learning rate. learning_rate = .1 * max(train_batch_size / 256, 1.) learning_rate_warmup = steps_from_epochs(10) - learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps, + use_cosine_schedule = True + if use_cosine_schedule: + learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps, + learning_rate_warmup) + else: + learning_rate_fn = utils.get_step_schedule(learning_rate, num_steps, learning_rate_warmup) # Model definition. @@ -120,7 +125,7 @@ def get_config(): weight_decay=5e-4, swa_decay=.995, use_cutmix=False, - supervised_batch_ratio=.3, + supervised_batch_ratio=.3 if extra_npz is not None else 1., extra_data_path=extra_npz, extra_label_smoothing=.1, attack=train_attack), @@ -343,7 +348,13 @@ class Experiment(experiment.AbstractExperiment): # def evaluate(self, global_step, rng, *unused_args, **unused_kwargs): - return self.eval_epoch(self._avg_params or self._params, self._state, rng) + scalars = self.eval_epoch(self._params, self._state, rng) + if self._avg_params: + avg_scalars = self.eval_epoch(self._avg_params or self._params, + self._state, rng) + for k, v in avg_scalars.items(): + scalars[k + '_swa'] = v + return scalars def eval_epoch(self, params, state, rng): host_id = jax.host_id() @@ -408,8 +419,11 @@ class Experiment(experiment.AbstractExperiment): self._repeat_batch = 1 self.supervised_train_input = jl_utils.py_prefetch( self._supervised_train_dataset) - self.extra_train_input = jl_utils.py_prefetch( - self._extra_train_dataset) + if self.config.training.extra_data_path is None: + self.extra_train_input = None + else: + self.extra_train_input = jl_utils.py_prefetch( + self._extra_train_dataset) self.normalize_fn = datasets.cifar10_normalize # Optimizer. @@ -423,7 +437,8 @@ class Experiment(experiment.AbstractExperiment): # Create inputs to initialize the network state. images, _, _ = jax.pmap(self.concatenate)( next(self.supervised_train_input), - next(self.extra_train_input)) + next(self.extra_train_input) if self.extra_train_input is not None + else None) images = jax.pmap(self.normalize_fn)(images) # Initialize weights and biases. init_net = jax.pmap( diff --git a/adversarial_robustness/jax/utils.py b/adversarial_robustness/jax/utils.py index 8f0ba13..00728b0 100644 --- a/adversarial_robustness/jax/utils.py +++ b/adversarial_robustness/jax/utils.py @@ -41,6 +41,23 @@ def get_cosine_schedule( ], [warmup_steps]) +def get_step_schedule( + max_learning_rate: float, + total_steps: int, + warmup_steps: int = 0) -> optax.Schedule: + """Builds a step schedule with initial warm-up.""" + if total_steps < warmup_steps: + return optax.linear_schedule(init_value=0., end_value=max_learning_rate, + transition_steps=warmup_steps) + return optax.join_schedules([ + optax.linear_schedule(init_value=0., end_value=max_learning_rate, + transition_steps=warmup_steps), + optax.piecewise_constant_schedule( + init_value=max_learning_rate, + boundaries_and_scales={total_steps * 2 // 3: .1}), + ], [warmup_steps]) + + def sgd_momentum(learning_rate_fn: optax.Schedule, momentum: float = 0., nesterov: bool = False) -> optax.GradientTransformation: @@ -118,8 +135,7 @@ def cutmix(rng: chex.PRNGKey, split: int = 1) -> Tuple[chex.Array, chex.Array]: """Composing two images by inserting a patch into another image.""" batch_size, height, width, _ = images.shape - if split > 1: - split_batch_size = batch_size // split + split_batch_size = batch_size // split if split > 1 else batch_size # Masking bounding box. box_rng, lam_rng, rng = jax.random.split(rng, num=3) @@ -158,14 +174,10 @@ def _random_box(rng: chex.PRNGKey, cut_w: chex.Array) -> chex.Array: """Sample a random box of shape [cut_h, cut_w].""" height_rng, width_rng = jax.random.split(rng) - minval_h = 0 - minval_w = 0 - maxval_h = height - maxval_w = width i = jax.random.randint( - height_rng, shape=(), minval=minval_h, maxval=maxval_h, dtype=jnp.int32) + height_rng, shape=(), minval=0, maxval=height, dtype=jnp.int32) j = jax.random.randint( - width_rng, shape=(), minval=minval_w, maxval=maxval_w, dtype=jnp.int32) + width_rng, shape=(), minval=0, maxval=width, dtype=jnp.int32) bby1 = jnp.clip(i - cut_h // 2, 0, height) bbx1 = jnp.clip(j - cut_w // 2, 0, width) h = jnp.clip(i + cut_h // 2, 0, height) - bby1