diff --git a/adversarial_robustness/README.md b/adversarial_robustness/README.md new file mode 100644 index 0000000..95558c5 --- /dev/null +++ b/adversarial_robustness/README.md @@ -0,0 +1,63 @@ +# Adversarial Robustness + +This repository contains the code needed to evaluate models trained in +[Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples](https://arxiv.org/abs/2010.03593) + + +## Contents + +We have released our top-performing models in two formats compatible with +[JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/). +This repository also contains our model definitions. + +## Running the example code + +### Downloading a model + +Download a model from links listed in the following table. +Clean and robust accuracies are measured on the full test set. +The robust accuracy is measured using +[AutoAttack](https://github.com/fra31/auto-attack). + +| dataset | norm | radius | architecture | extra data | clean | robust | link | +|---|:---:|:---:|:---:|:---:|---:|---:|:---:| +| CIFAR-10 | ℓ | 8 / 255 | WRN-70-16 | ✓ | 91.10% | 65.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.pt) +| CIFAR-10 | ℓ | 8 / 255 | WRN-28-10 | ✓ | 89.48% | 62.80% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt) +| CIFAR-10 | ℓ | 8 / 255 | WRN-70-16 | ✗ | 85.29% | 57.20% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.pt) +| CIFAR-10 | ℓ | 8 / 255 | WRN-34-20 | ✗ | 85.64% | 56.86% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.pt) +| CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | ✓ | 94.74% | 80.53% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.pt) +| CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | ✗ | 90.90% | 74.50% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.pt) +| CIFAR-100 | ℓ | 8 / 255 | WRN-70-16 | ✓ | 69.15% | 36.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.pt) +| CIFAR-100 | ℓ | 8 / 255 | WRN-70-16 | ✗ | 60.86% | 30.03% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.pt) +| MNIST | ℓ | 0.3 | WRN-28-10 | ✗ | 99.26% | 96.34% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.pt) + +### Using the model + +Once downloaded, a model can be evaluated (clean accuracy) by running the +`eval.py` script in either the `jax` or `pytorch` folders. E.g.: + +``` +cd jax +python3 eval.py \ + --ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10 +``` + + +## Citing this work + +If you use this code or these models in your work, please cite the accompanying +paper: + +``` +@article{gowal2020uncovering, + title={Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples}, + author={Gowal, Sven and Qin, Chongli and Uesato, Jonathan and Mann, Timothy and Kohli, Pushmeet}, + journal={arXiv preprint arXiv:2010.03593}, + year={2020}, + url={https://arxiv.org/pdf/2010.03593} +} +``` + +## Disclaimer + +This is not an official Google product. diff --git a/adversarial_robustness/jax/eval.py b/adversarial_robustness/jax/eval.py new file mode 100644 index 0000000..3ef8c5c --- /dev/null +++ b/adversarial_robustness/jax/eval.py @@ -0,0 +1,104 @@ +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluates a JAX checkpoint on CIFAR-10/100 or MNIST.""" + +from absl import app +from absl import flags +import haiku as hk +import numpy as np +import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds +import tqdm + +from adversarial_robustness.jax import model_zoo + +_CKPT = flags.DEFINE_string( + 'ckpt', None, 'Path to checkpoint.') +_DATASET = flags.DEFINE_enum( + 'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'], + 'Dataset on which the checkpoint is evaluated.') +_WIDTH = flags.DEFINE_integer( + 'width', 16, 'Width of WideResNet.') +_DEPTH = flags.DEFINE_integer( + 'depth', 70, 'Depth of WideResNet.') +_BATCH_SIZE = flags.DEFINE_integer( + 'batch_size', 100, 'Batch size.') +_NUM_BATCHES = flags.DEFINE_integer( + 'num_batches', 0, + 'Number of batches to evaluate (zero means the whole dataset).') + + +def main(unused_argv): + print(f'Loading "{_CKPT.value}"') + print(f'Using a WideResNet with depth {_DEPTH.value} and width ' + f'{_WIDTH.value}.') + + # Create dataset. + if _DATASET.value == 'mnist': + _, data_test = tf.keras.datasets.mnist.load_data() + normalize_fn = model_zoo.mnist_normalize + elif _DATASET.value == 'cifar10': + _, data_test = tf.keras.datasets.cifar10.load_data() + normalize_fn = model_zoo.cifar10_normalize + else: + assert _DATASET.value == 'cifar100' + _, data_test = tf.keras.datasets.cifar100.load_data() + normalize_fn = model_zoo.cifar100_normalize + + # Create model. + @hk.transform_with_state + def model_fn(x, is_training=False): + model = model_zoo.WideResNet( + num_classes=10, depth=_DEPTH.value, width=_WIDTH.value, + activation='swish') + return model(normalize_fn(x), is_training=is_training) + + # Build dataset. + images, labels = data_test + samples = (images.astype(np.float32) / 255., + np.squeeze(labels, axis=-1).astype(np.int64)) + data = tf.data.Dataset.from_tensor_slices(samples).batch(_BATCH_SIZE.value) + test_loader = tfds.as_numpy(data) + + # Load model parameters. + rng_seq = hk.PRNGSequence(0) + if _CKPT.value == 'dummy': + for images, _ in test_loader: + break + params, state = model_fn.init(next(rng_seq), images, is_training=True) + # Reset iterator. + test_loader = tfds.as_numpy(data) + else: + params, state = np.load(_CKPT.value, allow_pickle=True) + + # Evaluation. + correct = 0 + total = 0 + batch_count = 0 + total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1, _NUM_BATCHES.value) + for images, labels in tqdm.tqdm(test_loader, total=total_batches): + outputs = model_fn.apply(params, state, next(rng_seq), images)[0] + predicted = np.argmax(outputs, 1) + total += labels.shape[0] + correct += (predicted == labels).sum().item() + batch_count += 1 + if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value: + break + print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%') + + +if __name__ == '__main__': + flags.mark_flag_as_required('ckpt') + app.run(main) diff --git a/adversarial_robustness/jax/model_zoo.py b/adversarial_robustness/jax/model_zoo.py new file mode 100644 index 0000000..af54646 --- /dev/null +++ b/adversarial_robustness/jax/model_zoo.py @@ -0,0 +1,165 @@ +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WideResNet implementation in JAX using Haiku.""" + +from typing import Any, Mapping, Optional, Text + +import haiku as hk +import jax +import jax.numpy as jnp + + +CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) +CIFAR10_STD = (0.2471, 0.2435, 0.2616) +CIFAR100_MEAN = (0.5071, 0.4865, 0.4409) +CIFAR100_STD = (0.2673, 0.2564, 0.2762) + + +class _WideResNetBlock(hk.Module): + """Block of a WideResNet.""" + + def __init__(self, num_filters, stride=1, projection_shortcut=False, + activation=jax.nn.relu, norm_args=None, name=None): + super().__init__(name=name) + num_bottleneck_layers = 1 + self._activation = activation + if norm_args is None: + norm_args = { + 'create_offset': False, + 'create_scale': True, + 'decay_rate': .99, + } + self._bn_modules = [] + self._conv_modules = [] + for i in range(num_bottleneck_layers + 1): + s = stride if i == 0 else 1 + self._bn_modules.append(hk.BatchNorm( + name='batchnorm_{}'.format(i), + **norm_args)) + self._conv_modules.append(hk.Conv2D( + output_channels=num_filters, + padding='SAME', + kernel_shape=(3, 3), + stride=s, + with_bias=False, + name='conv_{}'.format(i))) # pytype: disable=not-callable + if projection_shortcut: + self._shortcut = hk.Conv2D( + output_channels=num_filters, + kernel_shape=(1, 1), + stride=stride, + with_bias=False, + name='shortcut') # pytype: disable=not-callable + else: + self._shortcut = None + + def __call__(self, inputs, **norm_kwargs): + x = inputs + orig_x = inputs + for i, (bn, conv) in enumerate(zip(self._bn_modules, self._conv_modules)): + x = bn(x, **norm_kwargs) + x = self._activation(x) + if self._shortcut is not None and i == 0: + orig_x = x + x = conv(x) + if self._shortcut is not None: + shortcut_x = self._shortcut(orig_x) + x += shortcut_x + else: + x += orig_x + return x + + +class WideResNet(hk.Module): + """WideResNet designed for CIFAR-10.""" + + def __init__(self, + num_classes: int = 10, + depth: int = 28, + width: int = 10, + activation: Text = 'relu', + norm_args: Optional[Mapping[Text, Any]] = None, + name: Optional[Text] = None): + super(WideResNet, self).__init__(name=name) + if (depth - 4) % 6 != 0: + raise ValueError('depth should be 6n+4.') + self._activation = getattr(jax.nn, activation) + if norm_args is None: + norm_args = { + 'create_offset': False, + 'create_scale': True, + 'decay_rate': .99, + } + self._conv = hk.Conv2D( + output_channels=16, + kernel_shape=(3, 3), + stride=1, + with_bias=False, + name='init_conv') # pytype: disable=not-callable + self._bn = hk.BatchNorm( + name='batchnorm', + **norm_args) + self._linear = hk.Linear( + num_classes, + name='logits') + + blocks_per_layer = (depth - 4) // 6 + filter_sizes = [width * n for n in [16, 32, 64]] + self._blocks = [] + for layer_num, filter_size in enumerate(filter_sizes): + blocks_of_layer = [] + for i in range(blocks_per_layer): + stride = 2 if (layer_num != 0 and i == 0) else 1 + projection_shortcut = (i == 0) + blocks_of_layer.append(_WideResNetBlock( + num_filters=filter_size, + stride=stride, + projection_shortcut=projection_shortcut, + activation=self._activation, + norm_args=norm_args, + name='resnet_lay_{}_block_{}'.format(layer_num, i))) + self._blocks.append(blocks_of_layer) + + def __call__(self, inputs, **norm_kwargs): + net = inputs + net = self._conv(net) + + # Blocks. + for blocks_of_layer in self._blocks: + for block in blocks_of_layer: + net = block(net, **norm_kwargs) + net = self._bn(net, **norm_kwargs) + net = self._activation(net) + + net = jnp.mean(net, axis=[1, 2]) + return self._linear(net) + + +def mnist_normalize(image: jnp.array) -> jnp.array: + image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant', + constant_values=0) + return (image - .5) * 2. + + +def cifar10_normalize(image: jnp.array) -> jnp.array: + means = jnp.array(CIFAR10_MEAN, dtype=image.dtype) + stds = jnp.array(CIFAR10_STD, dtype=image.dtype) + return (image - means) / stds + + +def cifar100_normalize(image: jnp.array) -> jnp.array: + means = jnp.array(CIFAR100_MEAN, dtype=image.dtype) + stds = jnp.array(CIFAR100_STD, dtype=image.dtype) + return (image - means) / stds diff --git a/adversarial_robustness/pytorch/eval.py b/adversarial_robustness/pytorch/eval.py new file mode 100644 index 0000000..069a83f --- /dev/null +++ b/adversarial_robustness/pytorch/eval.py @@ -0,0 +1,106 @@ +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluates a PyTorch checkpoint on CIFAR-10/100 or MNIST.""" + +from absl import app +from absl import flags +import torch +from torch.utils import data +from torchvision import datasets +from torchvision import transforms +import tqdm + +from adversarial_robustness.pytorch import model_zoo + +_CKPT = flags.DEFINE_string( + 'ckpt', None, 'Path to checkpoint.') +_DATASET = flags.DEFINE_enum( + 'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'], + 'Dataset on which the checkpoint is evaluated.') +_WIDTH = flags.DEFINE_integer( + 'width', 16, 'Width of WideResNet.') +_DEPTH = flags.DEFINE_integer( + 'depth', 70, 'Depth of WideResNet.') +_USE_CUDA = flags.DEFINE_boolean( + 'use_cuda', True, 'Whether to use CUDA.') +_BATCH_SIZE = flags.DEFINE_integer( + 'batch_size', 100, 'Batch size.') +_NUM_BATCHES = flags.DEFINE_integer( + 'num_batches', 0, + 'Number of batches to evaluate (zero means the whole dataset).') + + +def main(unused_argv): + print(f'Loading "{_CKPT.value}"') + print(f'Using a WideResNet with depth {_DEPTH.value} and width ' + f'{_WIDTH.value}.') + + # Create model and dataset. + if _DATASET.value == 'mnist': + model = model_zoo.WideResNet( + num_classes=10, depth=_DEPTH.value, width=_WIDTH.value, + activation_fn=model_zoo.Swish, mean=.5, std=.5, padding=2, + num_input_channels=1) + dataset_fn = datasets.MNIST + elif _DATASET.value == 'cifar10': + model = model_zoo.WideResNet( + num_classes=10, depth=_DEPTH.value, width=_WIDTH.value, + activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR10_MEAN, + std=model_zoo.CIFAR10_STD) + dataset_fn = datasets.CIFAR10 + else: + assert _DATASET.value == 'cifar100' + model = model_zoo.WideResNet( + num_classes=100, depth=_DEPTH.value, width=_WIDTH.value, + activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR100_MEAN, + std=model_zoo.CIFAR100_STD) + dataset_fn = datasets.CIFAR100 + + # Load model. + if _CKPT.value != 'dummy': + params = torch.load(_CKPT.value) + model.load_state_dict(params) + if _USE_CUDA.value: + model.cuda() + model.eval() + print('Successfully loaded.') + + # Load dataset. + transform_chain = transforms.Compose([transforms.ToTensor()]) + ds = dataset_fn(root='/tmp/data', train=False, transform=transform_chain, + download=True) + test_loader = data.DataLoader(ds, batch_size=_BATCH_SIZE.value, shuffle=False, + num_workers=0) + + # Evaluation. + correct = 0 + total = 0 + batch_count = 0 + total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1, _NUM_BATCHES.value) + with torch.no_grad(): + for images, labels in tqdm.tqdm(test_loader, total=total_batches): + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + batch_count += 1 + if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value: + break + print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%') + + +if __name__ == '__main__': + flags.mark_flag_as_required('ckpt') + app.run(main) diff --git a/adversarial_robustness/pytorch/model_zoo.py b/adversarial_robustness/pytorch/model_zoo.py new file mode 100644 index 0000000..fdd84c0 --- /dev/null +++ b/adversarial_robustness/pytorch/model_zoo.py @@ -0,0 +1,164 @@ +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WideResNet implementation in PyTorch.""" + +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) +CIFAR10_STD = (0.2471, 0.2435, 0.2616) +CIFAR100_MEAN = (0.5071, 0.4865, 0.4409) +CIFAR100_STD = (0.2673, 0.2564, 0.2762) + + +class _Swish(torch.autograd.Function): + """Custom implementation of swish.""" + + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class Swish(nn.Module): + """Module using custom implementation.""" + + def forward(self, input_tensor): + return _Swish.apply(input_tensor) + + +class _Block(nn.Module): + """WideResNet Block.""" + + def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): + super().__init__() + self.batchnorm_0 = nn.BatchNorm2d(in_planes) + self.relu_0 = activation_fn() + # We manually pad to obtain the same effect as `SAME` (necessary when + # `stride` is different than 1). + self.conv_0 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=0, bias=False) + self.batchnorm_1 = nn.BatchNorm2d(out_planes) + self.relu_1 = activation_fn() + self.conv_1 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.has_shortcut = in_planes != out_planes + if self.has_shortcut: + self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, + stride=stride, padding=0, bias=False) + else: + self.shortcut = None + self._stride = stride + + def forward(self, x): + if self.has_shortcut: + x = self.relu_0(self.batchnorm_0(x)) + else: + out = self.relu_0(self.batchnorm_0(x)) + v = x if self.has_shortcut else out + if self._stride == 1: + v = F.pad(v, (1, 1, 1, 1)) + elif self._stride == 2: + v = F.pad(v, (0, 1, 0, 1)) + else: + raise ValueError('Unsupported `stride`.') + out = self.conv_0(v) + out = self.relu_1(self.batchnorm_1(out)) + out = self.conv_1(out) + out = torch.add(self.shortcut(x) if self.has_shortcut else x, out) + return out + + +class _BlockGroup(nn.Module): + """WideResNet block group.""" + + def __init__(self, num_blocks, in_planes, out_planes, stride, + activation_fn=nn.ReLU): + super().__init__() + block = [] + for i in range(num_blocks): + block.append( + _Block(i == 0 and in_planes or out_planes, + out_planes, + i == 0 and stride or 1, + activation_fn=activation_fn)) + self.block = nn.Sequential(*block) + + def forward(self, x): + return self.block(x) + + +class WideResNet(nn.Module): + """WideResNet.""" + + def __init__(self, + num_classes: int = 10, + depth: int = 28, + width: int = 10, + activation_fn: nn.Module = nn.ReLU, + mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, + std: Union[Tuple[float, ...], float] = CIFAR10_STD, + padding: int = 0, + num_input_channels: int = 3): + super().__init__() + self.mean = torch.tensor(mean).view(num_input_channels, 1, 1) + self.std = torch.tensor(std).view(num_input_channels, 1, 1) + self.mean_cuda = None + self.std_cuda = None + self.padding = padding + num_channels = [16, 16 * width, 32 * width, 64 * width] + assert (depth - 4) % 6 == 0 + num_blocks = (depth - 4) // 6 + self.init_conv = nn.Conv2d(num_input_channels, num_channels[0], + kernel_size=3, stride=1, padding=1, bias=False) + self.layer = nn.Sequential( + _BlockGroup(num_blocks, num_channels[0], num_channels[1], 1, + activation_fn=activation_fn), + _BlockGroup(num_blocks, num_channels[1], num_channels[2], 2, + activation_fn=activation_fn), + _BlockGroup(num_blocks, num_channels[2], num_channels[3], 2, + activation_fn=activation_fn)) + self.batchnorm = nn.BatchNorm2d(num_channels[3]) + self.relu = activation_fn() + self.logits = nn.Linear(num_channels[3], num_classes) + self.num_channels = num_channels[3] + + def forward(self, x): + if self.padding > 0: + x = F.pad(x, (self.padding,) * 4) + if x.is_cuda: + if self.mean_cuda is None: + self.mean_cuda = self.mean.cuda() + self.std_cuda = self.std.cuda() + out = (x - self.mean_cuda) / self.std_cuda + else: + out = (x - self.mean) / self.std + out = self.init_conv(out) + out = self.layer(out) + out = self.relu(self.batchnorm(out)) + out = F.avg_pool2d(out, 8) + out = out.view(-1, self.num_channels) + return self.logits(out) diff --git a/adversarial_robustness/requirements.txt b/adversarial_robustness/requirements.txt new file mode 100644 index 0000000..160c6fc --- /dev/null +++ b/adversarial_robustness/requirements.txt @@ -0,0 +1,51 @@ +absl-py==0.10.0 +astunparse==1.6.3 +attrs==20.3.0 +cachetools==4.1.1 +certifi==2020.11.8 +chardet==3.0.4 +dataclasses==0.6 +dill==0.3.3 +dm-haiku==0.0.3 +flatbuffers==1.12 +future==0.18.2 +gast==0.3.3 +google-auth==1.23.0 +google-auth-oauthlib==0.4.2 +google-pasta==0.2.0 +googleapis-common-protos==1.52.0 +grpcio==1.33.2 +h5py==2.10.0 +idna==2.10 +importlib-resources==3.3.0 +jax==0.2.6 +jaxlib==0.1.57 +Keras-Preprocessing==1.1.2 +Markdown==3.3.3 +numpy==1.18.5 +oauthlib==3.1.0 +opt-einsum==3.3.0 +Pillow==8.0.1 +promise==2.3 +protobuf==3.14.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +requests==2.25.0 +requests-oauthlib==1.3.0 +rsa==4.6 +scipy==1.5.4 +six==1.15.0 +tensorboard==2.4.0 +tensorboard-plugin-wit==1.7.0 +tensorflow==2.3.1 +tensorflow-datasets==4.1.0 +tensorflow-estimator==2.3.0 +tensorflow-metadata==0.25.0 +termcolor==1.1.0 +torch==1.7.0 +torchvision==0.8.1 +tqdm==4.53.0 +typing-extensions==3.7.4.3 +urllib3==1.26.2 +Werkzeug==1.0.1 +wrapt==1.12.1 diff --git a/adversarial_robustness/run.sh b/adversarial_robustness/run.sh new file mode 100644 index 0000000..e901ed5 --- /dev/null +++ b/adversarial_robustness/run.sh @@ -0,0 +1,33 @@ +#!/bin/sh +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +python3 -m venv adversarial_robustness_venv +source adversarial_robustness_venv/bin/activate +pip install -r adversarial_robustness/requirements.txt + +python3 -m adversarial_robustness.jax.eval \ + --ckpt=dummy \ + --arch=wrn-10-1 \ + --dataset=cifar10 \ + --batch_size=1 \ + --num_batches=1 + +python3 -m adversarial_robustness.pytorch.eval \ + --ckpt=dummy \ + --arch=wrn-10-1 \ + --dataset=cifar10 \ + --batch_size=1 \ + --num_batches=1 \ + --nouse_cuda