Fixed issue in cutmix where split_batch_size would be undefined.

Allow users to easily specify that they do not wish to use extra data.

PiperOrigin-RevId: 396784433
This commit is contained in:
Sven Gowal
2021-09-15 10:01:03 +01:00
committed by Saran Tunyasuvunakool
parent cff83be778
commit 41e2d45ed8
2 changed files with 42 additions and 15 deletions
+22 -7
View File
@@ -62,12 +62,17 @@ def get_config():
# https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness. # https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness.
# If the path is set to "cifar10_ddpm.npz" and is not found in the current # If the path is set to "cifar10_ddpm.npz" and is not found in the current
# directory, the corresponding data will be downloaded. # directory, the corresponding data will be downloaded.
extra_npz = 'cifar10_ddpm.npz' extra_npz = 'cifar10_ddpm.npz' # Can be `None`.
# Learning rate. # Learning rate.
learning_rate = .1 * max(train_batch_size / 256, 1.) learning_rate = .1 * max(train_batch_size / 256, 1.)
learning_rate_warmup = steps_from_epochs(10) learning_rate_warmup = steps_from_epochs(10)
learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps, use_cosine_schedule = True
if use_cosine_schedule:
learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps,
learning_rate_warmup)
else:
learning_rate_fn = utils.get_step_schedule(learning_rate, num_steps,
learning_rate_warmup) learning_rate_warmup)
# Model definition. # Model definition.
@@ -120,7 +125,7 @@ def get_config():
weight_decay=5e-4, weight_decay=5e-4,
swa_decay=.995, swa_decay=.995,
use_cutmix=False, use_cutmix=False,
supervised_batch_ratio=.3, supervised_batch_ratio=.3 if extra_npz is not None else 1.,
extra_data_path=extra_npz, extra_data_path=extra_npz,
extra_label_smoothing=.1, extra_label_smoothing=.1,
attack=train_attack), attack=train_attack),
@@ -343,7 +348,13 @@ class Experiment(experiment.AbstractExperiment):
# #
def evaluate(self, global_step, rng, *unused_args, **unused_kwargs): def evaluate(self, global_step, rng, *unused_args, **unused_kwargs):
return self.eval_epoch(self._avg_params or self._params, self._state, rng) scalars = self.eval_epoch(self._params, self._state, rng)
if self._avg_params:
avg_scalars = self.eval_epoch(self._avg_params or self._params,
self._state, rng)
for k, v in avg_scalars.items():
scalars[k + '_swa'] = v
return scalars
def eval_epoch(self, params, state, rng): def eval_epoch(self, params, state, rng):
host_id = jax.host_id() host_id = jax.host_id()
@@ -408,8 +419,11 @@ class Experiment(experiment.AbstractExperiment):
self._repeat_batch = 1 self._repeat_batch = 1
self.supervised_train_input = jl_utils.py_prefetch( self.supervised_train_input = jl_utils.py_prefetch(
self._supervised_train_dataset) self._supervised_train_dataset)
self.extra_train_input = jl_utils.py_prefetch( if self.config.training.extra_data_path is None:
self._extra_train_dataset) self.extra_train_input = None
else:
self.extra_train_input = jl_utils.py_prefetch(
self._extra_train_dataset)
self.normalize_fn = datasets.cifar10_normalize self.normalize_fn = datasets.cifar10_normalize
# Optimizer. # Optimizer.
@@ -423,7 +437,8 @@ class Experiment(experiment.AbstractExperiment):
# Create inputs to initialize the network state. # Create inputs to initialize the network state.
images, _, _ = jax.pmap(self.concatenate)( images, _, _ = jax.pmap(self.concatenate)(
next(self.supervised_train_input), next(self.supervised_train_input),
next(self.extra_train_input)) next(self.extra_train_input) if self.extra_train_input is not None
else None)
images = jax.pmap(self.normalize_fn)(images) images = jax.pmap(self.normalize_fn)(images)
# Initialize weights and biases. # Initialize weights and biases.
init_net = jax.pmap( init_net = jax.pmap(
+20 -8
View File
@@ -41,6 +41,23 @@ def get_cosine_schedule(
], [warmup_steps]) ], [warmup_steps])
def get_step_schedule(
max_learning_rate: float,
total_steps: int,
warmup_steps: int = 0) -> optax.Schedule:
"""Builds a step schedule with initial warm-up."""
if total_steps < warmup_steps:
return optax.linear_schedule(init_value=0., end_value=max_learning_rate,
transition_steps=warmup_steps)
return optax.join_schedules([
optax.linear_schedule(init_value=0., end_value=max_learning_rate,
transition_steps=warmup_steps),
optax.piecewise_constant_schedule(
init_value=max_learning_rate,
boundaries_and_scales={total_steps * 2 // 3: .1}),
], [warmup_steps])
def sgd_momentum(learning_rate_fn: optax.Schedule, def sgd_momentum(learning_rate_fn: optax.Schedule,
momentum: float = 0., momentum: float = 0.,
nesterov: bool = False) -> optax.GradientTransformation: nesterov: bool = False) -> optax.GradientTransformation:
@@ -118,8 +135,7 @@ def cutmix(rng: chex.PRNGKey,
split: int = 1) -> Tuple[chex.Array, chex.Array]: split: int = 1) -> Tuple[chex.Array, chex.Array]:
"""Composing two images by inserting a patch into another image.""" """Composing two images by inserting a patch into another image."""
batch_size, height, width, _ = images.shape batch_size, height, width, _ = images.shape
if split > 1: split_batch_size = batch_size // split if split > 1 else batch_size
split_batch_size = batch_size // split
# Masking bounding box. # Masking bounding box.
box_rng, lam_rng, rng = jax.random.split(rng, num=3) box_rng, lam_rng, rng = jax.random.split(rng, num=3)
@@ -158,14 +174,10 @@ def _random_box(rng: chex.PRNGKey,
cut_w: chex.Array) -> chex.Array: cut_w: chex.Array) -> chex.Array:
"""Sample a random box of shape [cut_h, cut_w].""" """Sample a random box of shape [cut_h, cut_w]."""
height_rng, width_rng = jax.random.split(rng) height_rng, width_rng = jax.random.split(rng)
minval_h = 0
minval_w = 0
maxval_h = height
maxval_w = width
i = jax.random.randint( i = jax.random.randint(
height_rng, shape=(), minval=minval_h, maxval=maxval_h, dtype=jnp.int32) height_rng, shape=(), minval=0, maxval=height, dtype=jnp.int32)
j = jax.random.randint( j = jax.random.randint(
width_rng, shape=(), minval=minval_w, maxval=maxval_w, dtype=jnp.int32) width_rng, shape=(), minval=0, maxval=width, dtype=jnp.int32)
bby1 = jnp.clip(i - cut_h // 2, 0, height) bby1 = jnp.clip(i - cut_h // 2, 0, height)
bbx1 = jnp.clip(j - cut_w // 2, 0, width) bbx1 = jnp.clip(j - cut_w // 2, 0, width)
h = jnp.clip(i + cut_h // 2, 0, height) - bby1 h = jnp.clip(i + cut_h // 2, 0, height) - bby1