mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-01 05:24:03 +08:00
Fixed issue in cutmix where split_batch_size would be undefined.
Allow users to easily specify that they do not wish to use extra data. PiperOrigin-RevId: 396784433
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
cff83be778
commit
41e2d45ed8
@@ -62,12 +62,17 @@ def get_config():
|
|||||||
# https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness.
|
# https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness.
|
||||||
# If the path is set to "cifar10_ddpm.npz" and is not found in the current
|
# If the path is set to "cifar10_ddpm.npz" and is not found in the current
|
||||||
# directory, the corresponding data will be downloaded.
|
# directory, the corresponding data will be downloaded.
|
||||||
extra_npz = 'cifar10_ddpm.npz'
|
extra_npz = 'cifar10_ddpm.npz' # Can be `None`.
|
||||||
|
|
||||||
# Learning rate.
|
# Learning rate.
|
||||||
learning_rate = .1 * max(train_batch_size / 256, 1.)
|
learning_rate = .1 * max(train_batch_size / 256, 1.)
|
||||||
learning_rate_warmup = steps_from_epochs(10)
|
learning_rate_warmup = steps_from_epochs(10)
|
||||||
learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps,
|
use_cosine_schedule = True
|
||||||
|
if use_cosine_schedule:
|
||||||
|
learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps,
|
||||||
|
learning_rate_warmup)
|
||||||
|
else:
|
||||||
|
learning_rate_fn = utils.get_step_schedule(learning_rate, num_steps,
|
||||||
learning_rate_warmup)
|
learning_rate_warmup)
|
||||||
|
|
||||||
# Model definition.
|
# Model definition.
|
||||||
@@ -120,7 +125,7 @@ def get_config():
|
|||||||
weight_decay=5e-4,
|
weight_decay=5e-4,
|
||||||
swa_decay=.995,
|
swa_decay=.995,
|
||||||
use_cutmix=False,
|
use_cutmix=False,
|
||||||
supervised_batch_ratio=.3,
|
supervised_batch_ratio=.3 if extra_npz is not None else 1.,
|
||||||
extra_data_path=extra_npz,
|
extra_data_path=extra_npz,
|
||||||
extra_label_smoothing=.1,
|
extra_label_smoothing=.1,
|
||||||
attack=train_attack),
|
attack=train_attack),
|
||||||
@@ -343,7 +348,13 @@ class Experiment(experiment.AbstractExperiment):
|
|||||||
#
|
#
|
||||||
|
|
||||||
def evaluate(self, global_step, rng, *unused_args, **unused_kwargs):
|
def evaluate(self, global_step, rng, *unused_args, **unused_kwargs):
|
||||||
return self.eval_epoch(self._avg_params or self._params, self._state, rng)
|
scalars = self.eval_epoch(self._params, self._state, rng)
|
||||||
|
if self._avg_params:
|
||||||
|
avg_scalars = self.eval_epoch(self._avg_params or self._params,
|
||||||
|
self._state, rng)
|
||||||
|
for k, v in avg_scalars.items():
|
||||||
|
scalars[k + '_swa'] = v
|
||||||
|
return scalars
|
||||||
|
|
||||||
def eval_epoch(self, params, state, rng):
|
def eval_epoch(self, params, state, rng):
|
||||||
host_id = jax.host_id()
|
host_id = jax.host_id()
|
||||||
@@ -408,8 +419,11 @@ class Experiment(experiment.AbstractExperiment):
|
|||||||
self._repeat_batch = 1
|
self._repeat_batch = 1
|
||||||
self.supervised_train_input = jl_utils.py_prefetch(
|
self.supervised_train_input = jl_utils.py_prefetch(
|
||||||
self._supervised_train_dataset)
|
self._supervised_train_dataset)
|
||||||
self.extra_train_input = jl_utils.py_prefetch(
|
if self.config.training.extra_data_path is None:
|
||||||
self._extra_train_dataset)
|
self.extra_train_input = None
|
||||||
|
else:
|
||||||
|
self.extra_train_input = jl_utils.py_prefetch(
|
||||||
|
self._extra_train_dataset)
|
||||||
self.normalize_fn = datasets.cifar10_normalize
|
self.normalize_fn = datasets.cifar10_normalize
|
||||||
|
|
||||||
# Optimizer.
|
# Optimizer.
|
||||||
@@ -423,7 +437,8 @@ class Experiment(experiment.AbstractExperiment):
|
|||||||
# Create inputs to initialize the network state.
|
# Create inputs to initialize the network state.
|
||||||
images, _, _ = jax.pmap(self.concatenate)(
|
images, _, _ = jax.pmap(self.concatenate)(
|
||||||
next(self.supervised_train_input),
|
next(self.supervised_train_input),
|
||||||
next(self.extra_train_input))
|
next(self.extra_train_input) if self.extra_train_input is not None
|
||||||
|
else None)
|
||||||
images = jax.pmap(self.normalize_fn)(images)
|
images = jax.pmap(self.normalize_fn)(images)
|
||||||
# Initialize weights and biases.
|
# Initialize weights and biases.
|
||||||
init_net = jax.pmap(
|
init_net = jax.pmap(
|
||||||
|
|||||||
@@ -41,6 +41,23 @@ def get_cosine_schedule(
|
|||||||
], [warmup_steps])
|
], [warmup_steps])
|
||||||
|
|
||||||
|
|
||||||
|
def get_step_schedule(
|
||||||
|
max_learning_rate: float,
|
||||||
|
total_steps: int,
|
||||||
|
warmup_steps: int = 0) -> optax.Schedule:
|
||||||
|
"""Builds a step schedule with initial warm-up."""
|
||||||
|
if total_steps < warmup_steps:
|
||||||
|
return optax.linear_schedule(init_value=0., end_value=max_learning_rate,
|
||||||
|
transition_steps=warmup_steps)
|
||||||
|
return optax.join_schedules([
|
||||||
|
optax.linear_schedule(init_value=0., end_value=max_learning_rate,
|
||||||
|
transition_steps=warmup_steps),
|
||||||
|
optax.piecewise_constant_schedule(
|
||||||
|
init_value=max_learning_rate,
|
||||||
|
boundaries_and_scales={total_steps * 2 // 3: .1}),
|
||||||
|
], [warmup_steps])
|
||||||
|
|
||||||
|
|
||||||
def sgd_momentum(learning_rate_fn: optax.Schedule,
|
def sgd_momentum(learning_rate_fn: optax.Schedule,
|
||||||
momentum: float = 0.,
|
momentum: float = 0.,
|
||||||
nesterov: bool = False) -> optax.GradientTransformation:
|
nesterov: bool = False) -> optax.GradientTransformation:
|
||||||
@@ -118,8 +135,7 @@ def cutmix(rng: chex.PRNGKey,
|
|||||||
split: int = 1) -> Tuple[chex.Array, chex.Array]:
|
split: int = 1) -> Tuple[chex.Array, chex.Array]:
|
||||||
"""Composing two images by inserting a patch into another image."""
|
"""Composing two images by inserting a patch into another image."""
|
||||||
batch_size, height, width, _ = images.shape
|
batch_size, height, width, _ = images.shape
|
||||||
if split > 1:
|
split_batch_size = batch_size // split if split > 1 else batch_size
|
||||||
split_batch_size = batch_size // split
|
|
||||||
|
|
||||||
# Masking bounding box.
|
# Masking bounding box.
|
||||||
box_rng, lam_rng, rng = jax.random.split(rng, num=3)
|
box_rng, lam_rng, rng = jax.random.split(rng, num=3)
|
||||||
@@ -158,14 +174,10 @@ def _random_box(rng: chex.PRNGKey,
|
|||||||
cut_w: chex.Array) -> chex.Array:
|
cut_w: chex.Array) -> chex.Array:
|
||||||
"""Sample a random box of shape [cut_h, cut_w]."""
|
"""Sample a random box of shape [cut_h, cut_w]."""
|
||||||
height_rng, width_rng = jax.random.split(rng)
|
height_rng, width_rng = jax.random.split(rng)
|
||||||
minval_h = 0
|
|
||||||
minval_w = 0
|
|
||||||
maxval_h = height
|
|
||||||
maxval_w = width
|
|
||||||
i = jax.random.randint(
|
i = jax.random.randint(
|
||||||
height_rng, shape=(), minval=minval_h, maxval=maxval_h, dtype=jnp.int32)
|
height_rng, shape=(), minval=0, maxval=height, dtype=jnp.int32)
|
||||||
j = jax.random.randint(
|
j = jax.random.randint(
|
||||||
width_rng, shape=(), minval=minval_w, maxval=maxval_w, dtype=jnp.int32)
|
width_rng, shape=(), minval=0, maxval=width, dtype=jnp.int32)
|
||||||
bby1 = jnp.clip(i - cut_h // 2, 0, height)
|
bby1 = jnp.clip(i - cut_h // 2, 0, height)
|
||||||
bbx1 = jnp.clip(j - cut_w // 2, 0, width)
|
bbx1 = jnp.clip(j - cut_w // 2, 0, width)
|
||||||
h = jnp.clip(i + cut_h // 2, 0, height) - bby1
|
h = jnp.clip(i + cut_h // 2, 0, height) - bby1
|
||||||
|
|||||||
Reference in New Issue
Block a user