Fixed issue in cutmix where split_batch_size would be undefined.

Allow users to easily specify that they do not wish to use extra data.

PiperOrigin-RevId: 396784433
This commit is contained in:
Sven Gowal
2021-09-15 10:01:03 +01:00
committed by Saran Tunyasuvunakool
parent cff83be778
commit 41e2d45ed8
2 changed files with 42 additions and 15 deletions
+22 -7
View File
@@ -62,12 +62,17 @@ def get_config():
# https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness.
# If the path is set to "cifar10_ddpm.npz" and is not found in the current
# directory, the corresponding data will be downloaded.
extra_npz = 'cifar10_ddpm.npz'
extra_npz = 'cifar10_ddpm.npz' # Can be `None`.
# Learning rate.
learning_rate = .1 * max(train_batch_size / 256, 1.)
learning_rate_warmup = steps_from_epochs(10)
learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps,
use_cosine_schedule = True
if use_cosine_schedule:
learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps,
learning_rate_warmup)
else:
learning_rate_fn = utils.get_step_schedule(learning_rate, num_steps,
learning_rate_warmup)
# Model definition.
@@ -120,7 +125,7 @@ def get_config():
weight_decay=5e-4,
swa_decay=.995,
use_cutmix=False,
supervised_batch_ratio=.3,
supervised_batch_ratio=.3 if extra_npz is not None else 1.,
extra_data_path=extra_npz,
extra_label_smoothing=.1,
attack=train_attack),
@@ -343,7 +348,13 @@ class Experiment(experiment.AbstractExperiment):
#
def evaluate(self, global_step, rng, *unused_args, **unused_kwargs):
return self.eval_epoch(self._avg_params or self._params, self._state, rng)
scalars = self.eval_epoch(self._params, self._state, rng)
if self._avg_params:
avg_scalars = self.eval_epoch(self._avg_params or self._params,
self._state, rng)
for k, v in avg_scalars.items():
scalars[k + '_swa'] = v
return scalars
def eval_epoch(self, params, state, rng):
host_id = jax.host_id()
@@ -408,8 +419,11 @@ class Experiment(experiment.AbstractExperiment):
self._repeat_batch = 1
self.supervised_train_input = jl_utils.py_prefetch(
self._supervised_train_dataset)
self.extra_train_input = jl_utils.py_prefetch(
self._extra_train_dataset)
if self.config.training.extra_data_path is None:
self.extra_train_input = None
else:
self.extra_train_input = jl_utils.py_prefetch(
self._extra_train_dataset)
self.normalize_fn = datasets.cifar10_normalize
# Optimizer.
@@ -423,7 +437,8 @@ class Experiment(experiment.AbstractExperiment):
# Create inputs to initialize the network state.
images, _, _ = jax.pmap(self.concatenate)(
next(self.supervised_train_input),
next(self.extra_train_input))
next(self.extra_train_input) if self.extra_train_input is not None
else None)
images = jax.pmap(self.normalize_fn)(images)
# Initialize weights and biases.
init_net = jax.pmap(
+20 -8
View File
@@ -41,6 +41,23 @@ def get_cosine_schedule(
], [warmup_steps])
def get_step_schedule(
max_learning_rate: float,
total_steps: int,
warmup_steps: int = 0) -> optax.Schedule:
"""Builds a step schedule with initial warm-up."""
if total_steps < warmup_steps:
return optax.linear_schedule(init_value=0., end_value=max_learning_rate,
transition_steps=warmup_steps)
return optax.join_schedules([
optax.linear_schedule(init_value=0., end_value=max_learning_rate,
transition_steps=warmup_steps),
optax.piecewise_constant_schedule(
init_value=max_learning_rate,
boundaries_and_scales={total_steps * 2 // 3: .1}),
], [warmup_steps])
def sgd_momentum(learning_rate_fn: optax.Schedule,
momentum: float = 0.,
nesterov: bool = False) -> optax.GradientTransformation:
@@ -118,8 +135,7 @@ def cutmix(rng: chex.PRNGKey,
split: int = 1) -> Tuple[chex.Array, chex.Array]:
"""Composing two images by inserting a patch into another image."""
batch_size, height, width, _ = images.shape
if split > 1:
split_batch_size = batch_size // split
split_batch_size = batch_size // split if split > 1 else batch_size
# Masking bounding box.
box_rng, lam_rng, rng = jax.random.split(rng, num=3)
@@ -158,14 +174,10 @@ def _random_box(rng: chex.PRNGKey,
cut_w: chex.Array) -> chex.Array:
"""Sample a random box of shape [cut_h, cut_w]."""
height_rng, width_rng = jax.random.split(rng)
minval_h = 0
minval_w = 0
maxval_h = height
maxval_w = width
i = jax.random.randint(
height_rng, shape=(), minval=minval_h, maxval=maxval_h, dtype=jnp.int32)
height_rng, shape=(), minval=0, maxval=height, dtype=jnp.int32)
j = jax.random.randint(
width_rng, shape=(), minval=minval_w, maxval=maxval_w, dtype=jnp.int32)
width_rng, shape=(), minval=0, maxval=width, dtype=jnp.int32)
bby1 = jnp.clip(i - cut_h // 2, 0, height)
bbx1 = jnp.clip(j - cut_w // 2, 0, width)
h = jnp.clip(i + cut_h // 2, 0, height) - bby1