mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-20 11:22:09 +08:00
Internal change.
PiperOrigin-RevId: 348193373
This commit is contained in:
committed by
Louise Deason
parent
17700a6f8f
commit
a6aeb2641f
@@ -0,0 +1,63 @@
|
||||
# Adversarial Robustness
|
||||
|
||||
This repository contains the code needed to evaluate models trained in
|
||||
[Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples](https://arxiv.org/abs/2010.03593)
|
||||
|
||||
|
||||
## Contents
|
||||
|
||||
We have released our top-performing models in two formats compatible with
|
||||
[JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/).
|
||||
This repository also contains our model definitions.
|
||||
|
||||
## Running the example code
|
||||
|
||||
### Downloading a model
|
||||
|
||||
Download a model from links listed in the following table.
|
||||
Clean and robust accuracies are measured on the full test set.
|
||||
The robust accuracy is measured using
|
||||
[AutoAttack](https://github.com/fra31/auto-attack).
|
||||
|
||||
| dataset | norm | radius | architecture | extra data | clean | robust | link |
|
||||
|---|:---:|:---:|:---:|:---:|---:|---:|:---:|
|
||||
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✓ | 91.10% | 65.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.pt)
|
||||
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-28-10 | ✓ | 89.48% | 62.80% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt)
|
||||
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✗ | 85.29% | 57.20% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.pt)
|
||||
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-34-20 | ✗ | 85.64% | 56.86% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.pt)
|
||||
| CIFAR-10 | ℓ<sub>2</sub> | 128 / 255 | WRN-70-16 | ✓ | 94.74% | 80.53% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.pt)
|
||||
| CIFAR-10 | ℓ<sub>2</sub> | 128 / 255 | WRN-70-16 | ✗ | 90.90% | 74.50% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.pt)
|
||||
| CIFAR-100 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✓ | 69.15% | 36.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.pt)
|
||||
| CIFAR-100 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✗ | 60.86% | 30.03% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.pt)
|
||||
| MNIST | ℓ<sub>∞</sub> | 0.3 | WRN-28-10 | ✗ | 99.26% | 96.34% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.pt)
|
||||
|
||||
### Using the model
|
||||
|
||||
Once downloaded, a model can be evaluated (clean accuracy) by running the
|
||||
`eval.py` script in either the `jax` or `pytorch` folders. E.g.:
|
||||
|
||||
```
|
||||
cd jax
|
||||
python3 eval.py \
|
||||
--ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
|
||||
```
|
||||
|
||||
|
||||
## Citing this work
|
||||
|
||||
If you use this code or these models in your work, please cite the accompanying
|
||||
paper:
|
||||
|
||||
```
|
||||
@article{gowal2020uncovering,
|
||||
title={Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples},
|
||||
author={Gowal, Sven and Qin, Chongli and Uesato, Jonathan and Mann, Timothy and Kohli, Pushmeet},
|
||||
journal={arXiv preprint arXiv:2010.03593},
|
||||
year={2020},
|
||||
url={https://arxiv.org/pdf/2010.03593}
|
||||
}
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This is not an official Google product.
|
||||
@@ -0,0 +1,104 @@
|
||||
# Copyright 2020 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Evaluates a JAX checkpoint on CIFAR-10/100 or MNIST."""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import haiku as hk
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import tqdm
|
||||
|
||||
from adversarial_robustness.jax import model_zoo
|
||||
|
||||
_CKPT = flags.DEFINE_string(
|
||||
'ckpt', None, 'Path to checkpoint.')
|
||||
_DATASET = flags.DEFINE_enum(
|
||||
'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
|
||||
'Dataset on which the checkpoint is evaluated.')
|
||||
_WIDTH = flags.DEFINE_integer(
|
||||
'width', 16, 'Width of WideResNet.')
|
||||
_DEPTH = flags.DEFINE_integer(
|
||||
'depth', 70, 'Depth of WideResNet.')
|
||||
_BATCH_SIZE = flags.DEFINE_integer(
|
||||
'batch_size', 100, 'Batch size.')
|
||||
_NUM_BATCHES = flags.DEFINE_integer(
|
||||
'num_batches', 0,
|
||||
'Number of batches to evaluate (zero means the whole dataset).')
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
print(f'Loading "{_CKPT.value}"')
|
||||
print(f'Using a WideResNet with depth {_DEPTH.value} and width '
|
||||
f'{_WIDTH.value}.')
|
||||
|
||||
# Create dataset.
|
||||
if _DATASET.value == 'mnist':
|
||||
_, data_test = tf.keras.datasets.mnist.load_data()
|
||||
normalize_fn = model_zoo.mnist_normalize
|
||||
elif _DATASET.value == 'cifar10':
|
||||
_, data_test = tf.keras.datasets.cifar10.load_data()
|
||||
normalize_fn = model_zoo.cifar10_normalize
|
||||
else:
|
||||
assert _DATASET.value == 'cifar100'
|
||||
_, data_test = tf.keras.datasets.cifar100.load_data()
|
||||
normalize_fn = model_zoo.cifar100_normalize
|
||||
|
||||
# Create model.
|
||||
@hk.transform_with_state
|
||||
def model_fn(x, is_training=False):
|
||||
model = model_zoo.WideResNet(
|
||||
num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
|
||||
activation='swish')
|
||||
return model(normalize_fn(x), is_training=is_training)
|
||||
|
||||
# Build dataset.
|
||||
images, labels = data_test
|
||||
samples = (images.astype(np.float32) / 255.,
|
||||
np.squeeze(labels, axis=-1).astype(np.int64))
|
||||
data = tf.data.Dataset.from_tensor_slices(samples).batch(_BATCH_SIZE.value)
|
||||
test_loader = tfds.as_numpy(data)
|
||||
|
||||
# Load model parameters.
|
||||
rng_seq = hk.PRNGSequence(0)
|
||||
if _CKPT.value == 'dummy':
|
||||
for images, _ in test_loader:
|
||||
break
|
||||
params, state = model_fn.init(next(rng_seq), images, is_training=True)
|
||||
# Reset iterator.
|
||||
test_loader = tfds.as_numpy(data)
|
||||
else:
|
||||
params, state = np.load(_CKPT.value, allow_pickle=True)
|
||||
|
||||
# Evaluation.
|
||||
correct = 0
|
||||
total = 0
|
||||
batch_count = 0
|
||||
total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1, _NUM_BATCHES.value)
|
||||
for images, labels in tqdm.tqdm(test_loader, total=total_batches):
|
||||
outputs = model_fn.apply(params, state, next(rng_seq), images)[0]
|
||||
predicted = np.argmax(outputs, 1)
|
||||
total += labels.shape[0]
|
||||
correct += (predicted == labels).sum().item()
|
||||
batch_count += 1
|
||||
if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
|
||||
break
|
||||
print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('ckpt')
|
||||
app.run(main)
|
||||
@@ -0,0 +1,165 @@
|
||||
# Copyright 2020 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""WideResNet implementation in JAX using Haiku."""
|
||||
|
||||
from typing import Any, Mapping, Optional, Text
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
|
||||
CIFAR10_STD = (0.2471, 0.2435, 0.2616)
|
||||
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
|
||||
CIFAR100_STD = (0.2673, 0.2564, 0.2762)
|
||||
|
||||
|
||||
class _WideResNetBlock(hk.Module):
|
||||
"""Block of a WideResNet."""
|
||||
|
||||
def __init__(self, num_filters, stride=1, projection_shortcut=False,
|
||||
activation=jax.nn.relu, norm_args=None, name=None):
|
||||
super().__init__(name=name)
|
||||
num_bottleneck_layers = 1
|
||||
self._activation = activation
|
||||
if norm_args is None:
|
||||
norm_args = {
|
||||
'create_offset': False,
|
||||
'create_scale': True,
|
||||
'decay_rate': .99,
|
||||
}
|
||||
self._bn_modules = []
|
||||
self._conv_modules = []
|
||||
for i in range(num_bottleneck_layers + 1):
|
||||
s = stride if i == 0 else 1
|
||||
self._bn_modules.append(hk.BatchNorm(
|
||||
name='batchnorm_{}'.format(i),
|
||||
**norm_args))
|
||||
self._conv_modules.append(hk.Conv2D(
|
||||
output_channels=num_filters,
|
||||
padding='SAME',
|
||||
kernel_shape=(3, 3),
|
||||
stride=s,
|
||||
with_bias=False,
|
||||
name='conv_{}'.format(i))) # pytype: disable=not-callable
|
||||
if projection_shortcut:
|
||||
self._shortcut = hk.Conv2D(
|
||||
output_channels=num_filters,
|
||||
kernel_shape=(1, 1),
|
||||
stride=stride,
|
||||
with_bias=False,
|
||||
name='shortcut') # pytype: disable=not-callable
|
||||
else:
|
||||
self._shortcut = None
|
||||
|
||||
def __call__(self, inputs, **norm_kwargs):
|
||||
x = inputs
|
||||
orig_x = inputs
|
||||
for i, (bn, conv) in enumerate(zip(self._bn_modules, self._conv_modules)):
|
||||
x = bn(x, **norm_kwargs)
|
||||
x = self._activation(x)
|
||||
if self._shortcut is not None and i == 0:
|
||||
orig_x = x
|
||||
x = conv(x)
|
||||
if self._shortcut is not None:
|
||||
shortcut_x = self._shortcut(orig_x)
|
||||
x += shortcut_x
|
||||
else:
|
||||
x += orig_x
|
||||
return x
|
||||
|
||||
|
||||
class WideResNet(hk.Module):
|
||||
"""WideResNet designed for CIFAR-10."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes: int = 10,
|
||||
depth: int = 28,
|
||||
width: int = 10,
|
||||
activation: Text = 'relu',
|
||||
norm_args: Optional[Mapping[Text, Any]] = None,
|
||||
name: Optional[Text] = None):
|
||||
super(WideResNet, self).__init__(name=name)
|
||||
if (depth - 4) % 6 != 0:
|
||||
raise ValueError('depth should be 6n+4.')
|
||||
self._activation = getattr(jax.nn, activation)
|
||||
if norm_args is None:
|
||||
norm_args = {
|
||||
'create_offset': False,
|
||||
'create_scale': True,
|
||||
'decay_rate': .99,
|
||||
}
|
||||
self._conv = hk.Conv2D(
|
||||
output_channels=16,
|
||||
kernel_shape=(3, 3),
|
||||
stride=1,
|
||||
with_bias=False,
|
||||
name='init_conv') # pytype: disable=not-callable
|
||||
self._bn = hk.BatchNorm(
|
||||
name='batchnorm',
|
||||
**norm_args)
|
||||
self._linear = hk.Linear(
|
||||
num_classes,
|
||||
name='logits')
|
||||
|
||||
blocks_per_layer = (depth - 4) // 6
|
||||
filter_sizes = [width * n for n in [16, 32, 64]]
|
||||
self._blocks = []
|
||||
for layer_num, filter_size in enumerate(filter_sizes):
|
||||
blocks_of_layer = []
|
||||
for i in range(blocks_per_layer):
|
||||
stride = 2 if (layer_num != 0 and i == 0) else 1
|
||||
projection_shortcut = (i == 0)
|
||||
blocks_of_layer.append(_WideResNetBlock(
|
||||
num_filters=filter_size,
|
||||
stride=stride,
|
||||
projection_shortcut=projection_shortcut,
|
||||
activation=self._activation,
|
||||
norm_args=norm_args,
|
||||
name='resnet_lay_{}_block_{}'.format(layer_num, i)))
|
||||
self._blocks.append(blocks_of_layer)
|
||||
|
||||
def __call__(self, inputs, **norm_kwargs):
|
||||
net = inputs
|
||||
net = self._conv(net)
|
||||
|
||||
# Blocks.
|
||||
for blocks_of_layer in self._blocks:
|
||||
for block in blocks_of_layer:
|
||||
net = block(net, **norm_kwargs)
|
||||
net = self._bn(net, **norm_kwargs)
|
||||
net = self._activation(net)
|
||||
|
||||
net = jnp.mean(net, axis=[1, 2])
|
||||
return self._linear(net)
|
||||
|
||||
|
||||
def mnist_normalize(image: jnp.array) -> jnp.array:
|
||||
image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant',
|
||||
constant_values=0)
|
||||
return (image - .5) * 2.
|
||||
|
||||
|
||||
def cifar10_normalize(image: jnp.array) -> jnp.array:
|
||||
means = jnp.array(CIFAR10_MEAN, dtype=image.dtype)
|
||||
stds = jnp.array(CIFAR10_STD, dtype=image.dtype)
|
||||
return (image - means) / stds
|
||||
|
||||
|
||||
def cifar100_normalize(image: jnp.array) -> jnp.array:
|
||||
means = jnp.array(CIFAR100_MEAN, dtype=image.dtype)
|
||||
stds = jnp.array(CIFAR100_STD, dtype=image.dtype)
|
||||
return (image - means) / stds
|
||||
@@ -0,0 +1,106 @@
|
||||
# Copyright 2020 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Evaluates a PyTorch checkpoint on CIFAR-10/100 or MNIST."""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import torch
|
||||
from torch.utils import data
|
||||
from torchvision import datasets
|
||||
from torchvision import transforms
|
||||
import tqdm
|
||||
|
||||
from adversarial_robustness.pytorch import model_zoo
|
||||
|
||||
_CKPT = flags.DEFINE_string(
|
||||
'ckpt', None, 'Path to checkpoint.')
|
||||
_DATASET = flags.DEFINE_enum(
|
||||
'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
|
||||
'Dataset on which the checkpoint is evaluated.')
|
||||
_WIDTH = flags.DEFINE_integer(
|
||||
'width', 16, 'Width of WideResNet.')
|
||||
_DEPTH = flags.DEFINE_integer(
|
||||
'depth', 70, 'Depth of WideResNet.')
|
||||
_USE_CUDA = flags.DEFINE_boolean(
|
||||
'use_cuda', True, 'Whether to use CUDA.')
|
||||
_BATCH_SIZE = flags.DEFINE_integer(
|
||||
'batch_size', 100, 'Batch size.')
|
||||
_NUM_BATCHES = flags.DEFINE_integer(
|
||||
'num_batches', 0,
|
||||
'Number of batches to evaluate (zero means the whole dataset).')
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
print(f'Loading "{_CKPT.value}"')
|
||||
print(f'Using a WideResNet with depth {_DEPTH.value} and width '
|
||||
f'{_WIDTH.value}.')
|
||||
|
||||
# Create model and dataset.
|
||||
if _DATASET.value == 'mnist':
|
||||
model = model_zoo.WideResNet(
|
||||
num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
|
||||
activation_fn=model_zoo.Swish, mean=.5, std=.5, padding=2,
|
||||
num_input_channels=1)
|
||||
dataset_fn = datasets.MNIST
|
||||
elif _DATASET.value == 'cifar10':
|
||||
model = model_zoo.WideResNet(
|
||||
num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
|
||||
activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR10_MEAN,
|
||||
std=model_zoo.CIFAR10_STD)
|
||||
dataset_fn = datasets.CIFAR10
|
||||
else:
|
||||
assert _DATASET.value == 'cifar100'
|
||||
model = model_zoo.WideResNet(
|
||||
num_classes=100, depth=_DEPTH.value, width=_WIDTH.value,
|
||||
activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR100_MEAN,
|
||||
std=model_zoo.CIFAR100_STD)
|
||||
dataset_fn = datasets.CIFAR100
|
||||
|
||||
# Load model.
|
||||
if _CKPT.value != 'dummy':
|
||||
params = torch.load(_CKPT.value)
|
||||
model.load_state_dict(params)
|
||||
if _USE_CUDA.value:
|
||||
model.cuda()
|
||||
model.eval()
|
||||
print('Successfully loaded.')
|
||||
|
||||
# Load dataset.
|
||||
transform_chain = transforms.Compose([transforms.ToTensor()])
|
||||
ds = dataset_fn(root='/tmp/data', train=False, transform=transform_chain,
|
||||
download=True)
|
||||
test_loader = data.DataLoader(ds, batch_size=_BATCH_SIZE.value, shuffle=False,
|
||||
num_workers=0)
|
||||
|
||||
# Evaluation.
|
||||
correct = 0
|
||||
total = 0
|
||||
batch_count = 0
|
||||
total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1, _NUM_BATCHES.value)
|
||||
with torch.no_grad():
|
||||
for images, labels in tqdm.tqdm(test_loader, total=total_batches):
|
||||
outputs = model(images)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
batch_count += 1
|
||||
if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
|
||||
break
|
||||
print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('ckpt')
|
||||
app.run(main)
|
||||
@@ -0,0 +1,164 @@
|
||||
# Copyright 2020 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""WideResNet implementation in PyTorch."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
|
||||
CIFAR10_STD = (0.2471, 0.2435, 0.2616)
|
||||
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
|
||||
CIFAR100_STD = (0.2673, 0.2564, 0.2762)
|
||||
|
||||
|
||||
class _Swish(torch.autograd.Function):
|
||||
"""Custom implementation of swish."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, i):
|
||||
result = i * torch.sigmoid(i)
|
||||
ctx.save_for_backward(i)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
i = ctx.saved_variables[0]
|
||||
sigmoid_i = torch.sigmoid(i)
|
||||
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
"""Module using custom implementation."""
|
||||
|
||||
def forward(self, input_tensor):
|
||||
return _Swish.apply(input_tensor)
|
||||
|
||||
|
||||
class _Block(nn.Module):
|
||||
"""WideResNet Block."""
|
||||
|
||||
def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU):
|
||||
super().__init__()
|
||||
self.batchnorm_0 = nn.BatchNorm2d(in_planes)
|
||||
self.relu_0 = activation_fn()
|
||||
# We manually pad to obtain the same effect as `SAME` (necessary when
|
||||
# `stride` is different than 1).
|
||||
self.conv_0 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=0, bias=False)
|
||||
self.batchnorm_1 = nn.BatchNorm2d(out_planes)
|
||||
self.relu_1 = activation_fn()
|
||||
self.conv_1 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
|
||||
padding=1, bias=False)
|
||||
self.has_shortcut = in_planes != out_planes
|
||||
if self.has_shortcut:
|
||||
self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
|
||||
stride=stride, padding=0, bias=False)
|
||||
else:
|
||||
self.shortcut = None
|
||||
self._stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
if self.has_shortcut:
|
||||
x = self.relu_0(self.batchnorm_0(x))
|
||||
else:
|
||||
out = self.relu_0(self.batchnorm_0(x))
|
||||
v = x if self.has_shortcut else out
|
||||
if self._stride == 1:
|
||||
v = F.pad(v, (1, 1, 1, 1))
|
||||
elif self._stride == 2:
|
||||
v = F.pad(v, (0, 1, 0, 1))
|
||||
else:
|
||||
raise ValueError('Unsupported `stride`.')
|
||||
out = self.conv_0(v)
|
||||
out = self.relu_1(self.batchnorm_1(out))
|
||||
out = self.conv_1(out)
|
||||
out = torch.add(self.shortcut(x) if self.has_shortcut else x, out)
|
||||
return out
|
||||
|
||||
|
||||
class _BlockGroup(nn.Module):
|
||||
"""WideResNet block group."""
|
||||
|
||||
def __init__(self, num_blocks, in_planes, out_planes, stride,
|
||||
activation_fn=nn.ReLU):
|
||||
super().__init__()
|
||||
block = []
|
||||
for i in range(num_blocks):
|
||||
block.append(
|
||||
_Block(i == 0 and in_planes or out_planes,
|
||||
out_planes,
|
||||
i == 0 and stride or 1,
|
||||
activation_fn=activation_fn))
|
||||
self.block = nn.Sequential(*block)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class WideResNet(nn.Module):
|
||||
"""WideResNet."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes: int = 10,
|
||||
depth: int = 28,
|
||||
width: int = 10,
|
||||
activation_fn: nn.Module = nn.ReLU,
|
||||
mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
|
||||
std: Union[Tuple[float, ...], float] = CIFAR10_STD,
|
||||
padding: int = 0,
|
||||
num_input_channels: int = 3):
|
||||
super().__init__()
|
||||
self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
|
||||
self.std = torch.tensor(std).view(num_input_channels, 1, 1)
|
||||
self.mean_cuda = None
|
||||
self.std_cuda = None
|
||||
self.padding = padding
|
||||
num_channels = [16, 16 * width, 32 * width, 64 * width]
|
||||
assert (depth - 4) % 6 == 0
|
||||
num_blocks = (depth - 4) // 6
|
||||
self.init_conv = nn.Conv2d(num_input_channels, num_channels[0],
|
||||
kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.layer = nn.Sequential(
|
||||
_BlockGroup(num_blocks, num_channels[0], num_channels[1], 1,
|
||||
activation_fn=activation_fn),
|
||||
_BlockGroup(num_blocks, num_channels[1], num_channels[2], 2,
|
||||
activation_fn=activation_fn),
|
||||
_BlockGroup(num_blocks, num_channels[2], num_channels[3], 2,
|
||||
activation_fn=activation_fn))
|
||||
self.batchnorm = nn.BatchNorm2d(num_channels[3])
|
||||
self.relu = activation_fn()
|
||||
self.logits = nn.Linear(num_channels[3], num_classes)
|
||||
self.num_channels = num_channels[3]
|
||||
|
||||
def forward(self, x):
|
||||
if self.padding > 0:
|
||||
x = F.pad(x, (self.padding,) * 4)
|
||||
if x.is_cuda:
|
||||
if self.mean_cuda is None:
|
||||
self.mean_cuda = self.mean.cuda()
|
||||
self.std_cuda = self.std.cuda()
|
||||
out = (x - self.mean_cuda) / self.std_cuda
|
||||
else:
|
||||
out = (x - self.mean) / self.std
|
||||
out = self.init_conv(out)
|
||||
out = self.layer(out)
|
||||
out = self.relu(self.batchnorm(out))
|
||||
out = F.avg_pool2d(out, 8)
|
||||
out = out.view(-1, self.num_channels)
|
||||
return self.logits(out)
|
||||
@@ -0,0 +1,51 @@
|
||||
absl-py==0.10.0
|
||||
astunparse==1.6.3
|
||||
attrs==20.3.0
|
||||
cachetools==4.1.1
|
||||
certifi==2020.11.8
|
||||
chardet==3.0.4
|
||||
dataclasses==0.6
|
||||
dill==0.3.3
|
||||
dm-haiku==0.0.3
|
||||
flatbuffers==1.12
|
||||
future==0.18.2
|
||||
gast==0.3.3
|
||||
google-auth==1.23.0
|
||||
google-auth-oauthlib==0.4.2
|
||||
google-pasta==0.2.0
|
||||
googleapis-common-protos==1.52.0
|
||||
grpcio==1.33.2
|
||||
h5py==2.10.0
|
||||
idna==2.10
|
||||
importlib-resources==3.3.0
|
||||
jax==0.2.6
|
||||
jaxlib==0.1.57
|
||||
Keras-Preprocessing==1.1.2
|
||||
Markdown==3.3.3
|
||||
numpy==1.18.5
|
||||
oauthlib==3.1.0
|
||||
opt-einsum==3.3.0
|
||||
Pillow==8.0.1
|
||||
promise==2.3
|
||||
protobuf==3.14.0
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
requests==2.25.0
|
||||
requests-oauthlib==1.3.0
|
||||
rsa==4.6
|
||||
scipy==1.5.4
|
||||
six==1.15.0
|
||||
tensorboard==2.4.0
|
||||
tensorboard-plugin-wit==1.7.0
|
||||
tensorflow==2.3.1
|
||||
tensorflow-datasets==4.1.0
|
||||
tensorflow-estimator==2.3.0
|
||||
tensorflow-metadata==0.25.0
|
||||
termcolor==1.1.0
|
||||
torch==1.7.0
|
||||
torchvision==0.8.1
|
||||
tqdm==4.53.0
|
||||
typing-extensions==3.7.4.3
|
||||
urllib3==1.26.2
|
||||
Werkzeug==1.0.1
|
||||
wrapt==1.12.1
|
||||
@@ -0,0 +1,33 @@
|
||||
#!/bin/sh
|
||||
# Copyright 2020 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
python3 -m venv adversarial_robustness_venv
|
||||
source adversarial_robustness_venv/bin/activate
|
||||
pip install -r adversarial_robustness/requirements.txt
|
||||
|
||||
python3 -m adversarial_robustness.jax.eval \
|
||||
--ckpt=dummy \
|
||||
--arch=wrn-10-1 \
|
||||
--dataset=cifar10 \
|
||||
--batch_size=1 \
|
||||
--num_batches=1
|
||||
|
||||
python3 -m adversarial_robustness.pytorch.eval \
|
||||
--ckpt=dummy \
|
||||
--arch=wrn-10-1 \
|
||||
--dataset=cifar10 \
|
||||
--batch_size=1 \
|
||||
--num_batches=1 \
|
||||
--nouse_cuda
|
||||
Reference in New Issue
Block a user