mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-22 23:35:19 +08:00
Fixed issue in cutmix where split_batch_size would be undefined.
Allow users to easily specify that they do not wish to use extra data. PiperOrigin-RevId: 396784433
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
cff83be778
commit
41e2d45ed8
@@ -62,12 +62,17 @@ def get_config():
|
||||
# https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness.
|
||||
# If the path is set to "cifar10_ddpm.npz" and is not found in the current
|
||||
# directory, the corresponding data will be downloaded.
|
||||
extra_npz = 'cifar10_ddpm.npz'
|
||||
extra_npz = 'cifar10_ddpm.npz' # Can be `None`.
|
||||
|
||||
# Learning rate.
|
||||
learning_rate = .1 * max(train_batch_size / 256, 1.)
|
||||
learning_rate_warmup = steps_from_epochs(10)
|
||||
learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps,
|
||||
use_cosine_schedule = True
|
||||
if use_cosine_schedule:
|
||||
learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps,
|
||||
learning_rate_warmup)
|
||||
else:
|
||||
learning_rate_fn = utils.get_step_schedule(learning_rate, num_steps,
|
||||
learning_rate_warmup)
|
||||
|
||||
# Model definition.
|
||||
@@ -120,7 +125,7 @@ def get_config():
|
||||
weight_decay=5e-4,
|
||||
swa_decay=.995,
|
||||
use_cutmix=False,
|
||||
supervised_batch_ratio=.3,
|
||||
supervised_batch_ratio=.3 if extra_npz is not None else 1.,
|
||||
extra_data_path=extra_npz,
|
||||
extra_label_smoothing=.1,
|
||||
attack=train_attack),
|
||||
@@ -343,7 +348,13 @@ class Experiment(experiment.AbstractExperiment):
|
||||
#
|
||||
|
||||
def evaluate(self, global_step, rng, *unused_args, **unused_kwargs):
|
||||
return self.eval_epoch(self._avg_params or self._params, self._state, rng)
|
||||
scalars = self.eval_epoch(self._params, self._state, rng)
|
||||
if self._avg_params:
|
||||
avg_scalars = self.eval_epoch(self._avg_params or self._params,
|
||||
self._state, rng)
|
||||
for k, v in avg_scalars.items():
|
||||
scalars[k + '_swa'] = v
|
||||
return scalars
|
||||
|
||||
def eval_epoch(self, params, state, rng):
|
||||
host_id = jax.host_id()
|
||||
@@ -408,8 +419,11 @@ class Experiment(experiment.AbstractExperiment):
|
||||
self._repeat_batch = 1
|
||||
self.supervised_train_input = jl_utils.py_prefetch(
|
||||
self._supervised_train_dataset)
|
||||
self.extra_train_input = jl_utils.py_prefetch(
|
||||
self._extra_train_dataset)
|
||||
if self.config.training.extra_data_path is None:
|
||||
self.extra_train_input = None
|
||||
else:
|
||||
self.extra_train_input = jl_utils.py_prefetch(
|
||||
self._extra_train_dataset)
|
||||
self.normalize_fn = datasets.cifar10_normalize
|
||||
|
||||
# Optimizer.
|
||||
@@ -423,7 +437,8 @@ class Experiment(experiment.AbstractExperiment):
|
||||
# Create inputs to initialize the network state.
|
||||
images, _, _ = jax.pmap(self.concatenate)(
|
||||
next(self.supervised_train_input),
|
||||
next(self.extra_train_input))
|
||||
next(self.extra_train_input) if self.extra_train_input is not None
|
||||
else None)
|
||||
images = jax.pmap(self.normalize_fn)(images)
|
||||
# Initialize weights and biases.
|
||||
init_net = jax.pmap(
|
||||
|
||||
@@ -41,6 +41,23 @@ def get_cosine_schedule(
|
||||
], [warmup_steps])
|
||||
|
||||
|
||||
def get_step_schedule(
|
||||
max_learning_rate: float,
|
||||
total_steps: int,
|
||||
warmup_steps: int = 0) -> optax.Schedule:
|
||||
"""Builds a step schedule with initial warm-up."""
|
||||
if total_steps < warmup_steps:
|
||||
return optax.linear_schedule(init_value=0., end_value=max_learning_rate,
|
||||
transition_steps=warmup_steps)
|
||||
return optax.join_schedules([
|
||||
optax.linear_schedule(init_value=0., end_value=max_learning_rate,
|
||||
transition_steps=warmup_steps),
|
||||
optax.piecewise_constant_schedule(
|
||||
init_value=max_learning_rate,
|
||||
boundaries_and_scales={total_steps * 2 // 3: .1}),
|
||||
], [warmup_steps])
|
||||
|
||||
|
||||
def sgd_momentum(learning_rate_fn: optax.Schedule,
|
||||
momentum: float = 0.,
|
||||
nesterov: bool = False) -> optax.GradientTransformation:
|
||||
@@ -118,8 +135,7 @@ def cutmix(rng: chex.PRNGKey,
|
||||
split: int = 1) -> Tuple[chex.Array, chex.Array]:
|
||||
"""Composing two images by inserting a patch into another image."""
|
||||
batch_size, height, width, _ = images.shape
|
||||
if split > 1:
|
||||
split_batch_size = batch_size // split
|
||||
split_batch_size = batch_size // split if split > 1 else batch_size
|
||||
|
||||
# Masking bounding box.
|
||||
box_rng, lam_rng, rng = jax.random.split(rng, num=3)
|
||||
@@ -158,14 +174,10 @@ def _random_box(rng: chex.PRNGKey,
|
||||
cut_w: chex.Array) -> chex.Array:
|
||||
"""Sample a random box of shape [cut_h, cut_w]."""
|
||||
height_rng, width_rng = jax.random.split(rng)
|
||||
minval_h = 0
|
||||
minval_w = 0
|
||||
maxval_h = height
|
||||
maxval_w = width
|
||||
i = jax.random.randint(
|
||||
height_rng, shape=(), minval=minval_h, maxval=maxval_h, dtype=jnp.int32)
|
||||
height_rng, shape=(), minval=0, maxval=height, dtype=jnp.int32)
|
||||
j = jax.random.randint(
|
||||
width_rng, shape=(), minval=minval_w, maxval=maxval_w, dtype=jnp.int32)
|
||||
width_rng, shape=(), minval=0, maxval=width, dtype=jnp.int32)
|
||||
bby1 = jnp.clip(i - cut_h // 2, 0, height)
|
||||
bbx1 = jnp.clip(j - cut_w // 2, 0, width)
|
||||
h = jnp.clip(i + cut_h // 2, 0, height) - bby1
|
||||
|
||||
Reference in New Issue
Block a user