diff --git a/adversarial_robustness/README.md b/adversarial_robustness/README.md
new file mode 100644
index 0000000..95558c5
--- /dev/null
+++ b/adversarial_robustness/README.md
@@ -0,0 +1,63 @@
+# Adversarial Robustness
+
+This repository contains the code needed to evaluate models trained in
+[Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples](https://arxiv.org/abs/2010.03593)
+
+
+## Contents
+
+We have released our top-performing models in two formats compatible with
+[JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/).
+This repository also contains our model definitions.
+
+## Running the example code
+
+### Downloading a model
+
+Download a model from links listed in the following table.
+Clean and robust accuracies are measured on the full test set.
+The robust accuracy is measured using
+[AutoAttack](https://github.com/fra31/auto-attack).
+
+| dataset | norm | radius | architecture | extra data | clean | robust | link |
+|---|:---:|:---:|:---:|:---:|---:|---:|:---:|
+| CIFAR-10 | ℓ∞ | 8 / 255 | WRN-70-16 | ✓ | 91.10% | 65.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.pt)
+| CIFAR-10 | ℓ∞ | 8 / 255 | WRN-28-10 | ✓ | 89.48% | 62.80% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt)
+| CIFAR-10 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 85.29% | 57.20% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.pt)
+| CIFAR-10 | ℓ∞ | 8 / 255 | WRN-34-20 | ✗ | 85.64% | 56.86% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.pt)
+| CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | ✓ | 94.74% | 80.53% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.pt)
+| CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | ✗ | 90.90% | 74.50% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.pt)
+| CIFAR-100 | ℓ∞ | 8 / 255 | WRN-70-16 | ✓ | 69.15% | 36.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.pt)
+| CIFAR-100 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 60.86% | 30.03% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.pt)
+| MNIST | ℓ∞ | 0.3 | WRN-28-10 | ✗ | 99.26% | 96.34% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.pt)
+
+### Using the model
+
+Once downloaded, a model can be evaluated (clean accuracy) by running the
+`eval.py` script in either the `jax` or `pytorch` folders. E.g.:
+
+```
+cd jax
+python3 eval.py \
+ --ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
+```
+
+
+## Citing this work
+
+If you use this code or these models in your work, please cite the accompanying
+paper:
+
+```
+@article{gowal2020uncovering,
+ title={Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples},
+ author={Gowal, Sven and Qin, Chongli and Uesato, Jonathan and Mann, Timothy and Kohli, Pushmeet},
+ journal={arXiv preprint arXiv:2010.03593},
+ year={2020},
+ url={https://arxiv.org/pdf/2010.03593}
+}
+```
+
+## Disclaimer
+
+This is not an official Google product.
diff --git a/adversarial_robustness/jax/eval.py b/adversarial_robustness/jax/eval.py
new file mode 100644
index 0000000..3ef8c5c
--- /dev/null
+++ b/adversarial_robustness/jax/eval.py
@@ -0,0 +1,104 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluates a JAX checkpoint on CIFAR-10/100 or MNIST."""
+
+from absl import app
+from absl import flags
+import haiku as hk
+import numpy as np
+import tensorflow.compat.v2 as tf
+import tensorflow_datasets as tfds
+import tqdm
+
+from adversarial_robustness.jax import model_zoo
+
+_CKPT = flags.DEFINE_string(
+ 'ckpt', None, 'Path to checkpoint.')
+_DATASET = flags.DEFINE_enum(
+ 'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
+ 'Dataset on which the checkpoint is evaluated.')
+_WIDTH = flags.DEFINE_integer(
+ 'width', 16, 'Width of WideResNet.')
+_DEPTH = flags.DEFINE_integer(
+ 'depth', 70, 'Depth of WideResNet.')
+_BATCH_SIZE = flags.DEFINE_integer(
+ 'batch_size', 100, 'Batch size.')
+_NUM_BATCHES = flags.DEFINE_integer(
+ 'num_batches', 0,
+ 'Number of batches to evaluate (zero means the whole dataset).')
+
+
+def main(unused_argv):
+ print(f'Loading "{_CKPT.value}"')
+ print(f'Using a WideResNet with depth {_DEPTH.value} and width '
+ f'{_WIDTH.value}.')
+
+ # Create dataset.
+ if _DATASET.value == 'mnist':
+ _, data_test = tf.keras.datasets.mnist.load_data()
+ normalize_fn = model_zoo.mnist_normalize
+ elif _DATASET.value == 'cifar10':
+ _, data_test = tf.keras.datasets.cifar10.load_data()
+ normalize_fn = model_zoo.cifar10_normalize
+ else:
+ assert _DATASET.value == 'cifar100'
+ _, data_test = tf.keras.datasets.cifar100.load_data()
+ normalize_fn = model_zoo.cifar100_normalize
+
+ # Create model.
+ @hk.transform_with_state
+ def model_fn(x, is_training=False):
+ model = model_zoo.WideResNet(
+ num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
+ activation='swish')
+ return model(normalize_fn(x), is_training=is_training)
+
+ # Build dataset.
+ images, labels = data_test
+ samples = (images.astype(np.float32) / 255.,
+ np.squeeze(labels, axis=-1).astype(np.int64))
+ data = tf.data.Dataset.from_tensor_slices(samples).batch(_BATCH_SIZE.value)
+ test_loader = tfds.as_numpy(data)
+
+ # Load model parameters.
+ rng_seq = hk.PRNGSequence(0)
+ if _CKPT.value == 'dummy':
+ for images, _ in test_loader:
+ break
+ params, state = model_fn.init(next(rng_seq), images, is_training=True)
+ # Reset iterator.
+ test_loader = tfds.as_numpy(data)
+ else:
+ params, state = np.load(_CKPT.value, allow_pickle=True)
+
+ # Evaluation.
+ correct = 0
+ total = 0
+ batch_count = 0
+ total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1, _NUM_BATCHES.value)
+ for images, labels in tqdm.tqdm(test_loader, total=total_batches):
+ outputs = model_fn.apply(params, state, next(rng_seq), images)[0]
+ predicted = np.argmax(outputs, 1)
+ total += labels.shape[0]
+ correct += (predicted == labels).sum().item()
+ batch_count += 1
+ if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
+ break
+ print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('ckpt')
+ app.run(main)
diff --git a/adversarial_robustness/jax/model_zoo.py b/adversarial_robustness/jax/model_zoo.py
new file mode 100644
index 0000000..af54646
--- /dev/null
+++ b/adversarial_robustness/jax/model_zoo.py
@@ -0,0 +1,165 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""WideResNet implementation in JAX using Haiku."""
+
+from typing import Any, Mapping, Optional, Text
+
+import haiku as hk
+import jax
+import jax.numpy as jnp
+
+
+CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
+CIFAR10_STD = (0.2471, 0.2435, 0.2616)
+CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
+CIFAR100_STD = (0.2673, 0.2564, 0.2762)
+
+
+class _WideResNetBlock(hk.Module):
+ """Block of a WideResNet."""
+
+ def __init__(self, num_filters, stride=1, projection_shortcut=False,
+ activation=jax.nn.relu, norm_args=None, name=None):
+ super().__init__(name=name)
+ num_bottleneck_layers = 1
+ self._activation = activation
+ if norm_args is None:
+ norm_args = {
+ 'create_offset': False,
+ 'create_scale': True,
+ 'decay_rate': .99,
+ }
+ self._bn_modules = []
+ self._conv_modules = []
+ for i in range(num_bottleneck_layers + 1):
+ s = stride if i == 0 else 1
+ self._bn_modules.append(hk.BatchNorm(
+ name='batchnorm_{}'.format(i),
+ **norm_args))
+ self._conv_modules.append(hk.Conv2D(
+ output_channels=num_filters,
+ padding='SAME',
+ kernel_shape=(3, 3),
+ stride=s,
+ with_bias=False,
+ name='conv_{}'.format(i))) # pytype: disable=not-callable
+ if projection_shortcut:
+ self._shortcut = hk.Conv2D(
+ output_channels=num_filters,
+ kernel_shape=(1, 1),
+ stride=stride,
+ with_bias=False,
+ name='shortcut') # pytype: disable=not-callable
+ else:
+ self._shortcut = None
+
+ def __call__(self, inputs, **norm_kwargs):
+ x = inputs
+ orig_x = inputs
+ for i, (bn, conv) in enumerate(zip(self._bn_modules, self._conv_modules)):
+ x = bn(x, **norm_kwargs)
+ x = self._activation(x)
+ if self._shortcut is not None and i == 0:
+ orig_x = x
+ x = conv(x)
+ if self._shortcut is not None:
+ shortcut_x = self._shortcut(orig_x)
+ x += shortcut_x
+ else:
+ x += orig_x
+ return x
+
+
+class WideResNet(hk.Module):
+ """WideResNet designed for CIFAR-10."""
+
+ def __init__(self,
+ num_classes: int = 10,
+ depth: int = 28,
+ width: int = 10,
+ activation: Text = 'relu',
+ norm_args: Optional[Mapping[Text, Any]] = None,
+ name: Optional[Text] = None):
+ super(WideResNet, self).__init__(name=name)
+ if (depth - 4) % 6 != 0:
+ raise ValueError('depth should be 6n+4.')
+ self._activation = getattr(jax.nn, activation)
+ if norm_args is None:
+ norm_args = {
+ 'create_offset': False,
+ 'create_scale': True,
+ 'decay_rate': .99,
+ }
+ self._conv = hk.Conv2D(
+ output_channels=16,
+ kernel_shape=(3, 3),
+ stride=1,
+ with_bias=False,
+ name='init_conv') # pytype: disable=not-callable
+ self._bn = hk.BatchNorm(
+ name='batchnorm',
+ **norm_args)
+ self._linear = hk.Linear(
+ num_classes,
+ name='logits')
+
+ blocks_per_layer = (depth - 4) // 6
+ filter_sizes = [width * n for n in [16, 32, 64]]
+ self._blocks = []
+ for layer_num, filter_size in enumerate(filter_sizes):
+ blocks_of_layer = []
+ for i in range(blocks_per_layer):
+ stride = 2 if (layer_num != 0 and i == 0) else 1
+ projection_shortcut = (i == 0)
+ blocks_of_layer.append(_WideResNetBlock(
+ num_filters=filter_size,
+ stride=stride,
+ projection_shortcut=projection_shortcut,
+ activation=self._activation,
+ norm_args=norm_args,
+ name='resnet_lay_{}_block_{}'.format(layer_num, i)))
+ self._blocks.append(blocks_of_layer)
+
+ def __call__(self, inputs, **norm_kwargs):
+ net = inputs
+ net = self._conv(net)
+
+ # Blocks.
+ for blocks_of_layer in self._blocks:
+ for block in blocks_of_layer:
+ net = block(net, **norm_kwargs)
+ net = self._bn(net, **norm_kwargs)
+ net = self._activation(net)
+
+ net = jnp.mean(net, axis=[1, 2])
+ return self._linear(net)
+
+
+def mnist_normalize(image: jnp.array) -> jnp.array:
+ image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant',
+ constant_values=0)
+ return (image - .5) * 2.
+
+
+def cifar10_normalize(image: jnp.array) -> jnp.array:
+ means = jnp.array(CIFAR10_MEAN, dtype=image.dtype)
+ stds = jnp.array(CIFAR10_STD, dtype=image.dtype)
+ return (image - means) / stds
+
+
+def cifar100_normalize(image: jnp.array) -> jnp.array:
+ means = jnp.array(CIFAR100_MEAN, dtype=image.dtype)
+ stds = jnp.array(CIFAR100_STD, dtype=image.dtype)
+ return (image - means) / stds
diff --git a/adversarial_robustness/pytorch/eval.py b/adversarial_robustness/pytorch/eval.py
new file mode 100644
index 0000000..069a83f
--- /dev/null
+++ b/adversarial_robustness/pytorch/eval.py
@@ -0,0 +1,106 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluates a PyTorch checkpoint on CIFAR-10/100 or MNIST."""
+
+from absl import app
+from absl import flags
+import torch
+from torch.utils import data
+from torchvision import datasets
+from torchvision import transforms
+import tqdm
+
+from adversarial_robustness.pytorch import model_zoo
+
+_CKPT = flags.DEFINE_string(
+ 'ckpt', None, 'Path to checkpoint.')
+_DATASET = flags.DEFINE_enum(
+ 'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
+ 'Dataset on which the checkpoint is evaluated.')
+_WIDTH = flags.DEFINE_integer(
+ 'width', 16, 'Width of WideResNet.')
+_DEPTH = flags.DEFINE_integer(
+ 'depth', 70, 'Depth of WideResNet.')
+_USE_CUDA = flags.DEFINE_boolean(
+ 'use_cuda', True, 'Whether to use CUDA.')
+_BATCH_SIZE = flags.DEFINE_integer(
+ 'batch_size', 100, 'Batch size.')
+_NUM_BATCHES = flags.DEFINE_integer(
+ 'num_batches', 0,
+ 'Number of batches to evaluate (zero means the whole dataset).')
+
+
+def main(unused_argv):
+ print(f'Loading "{_CKPT.value}"')
+ print(f'Using a WideResNet with depth {_DEPTH.value} and width '
+ f'{_WIDTH.value}.')
+
+ # Create model and dataset.
+ if _DATASET.value == 'mnist':
+ model = model_zoo.WideResNet(
+ num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
+ activation_fn=model_zoo.Swish, mean=.5, std=.5, padding=2,
+ num_input_channels=1)
+ dataset_fn = datasets.MNIST
+ elif _DATASET.value == 'cifar10':
+ model = model_zoo.WideResNet(
+ num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
+ activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR10_MEAN,
+ std=model_zoo.CIFAR10_STD)
+ dataset_fn = datasets.CIFAR10
+ else:
+ assert _DATASET.value == 'cifar100'
+ model = model_zoo.WideResNet(
+ num_classes=100, depth=_DEPTH.value, width=_WIDTH.value,
+ activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR100_MEAN,
+ std=model_zoo.CIFAR100_STD)
+ dataset_fn = datasets.CIFAR100
+
+ # Load model.
+ if _CKPT.value != 'dummy':
+ params = torch.load(_CKPT.value)
+ model.load_state_dict(params)
+ if _USE_CUDA.value:
+ model.cuda()
+ model.eval()
+ print('Successfully loaded.')
+
+ # Load dataset.
+ transform_chain = transforms.Compose([transforms.ToTensor()])
+ ds = dataset_fn(root='/tmp/data', train=False, transform=transform_chain,
+ download=True)
+ test_loader = data.DataLoader(ds, batch_size=_BATCH_SIZE.value, shuffle=False,
+ num_workers=0)
+
+ # Evaluation.
+ correct = 0
+ total = 0
+ batch_count = 0
+ total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1, _NUM_BATCHES.value)
+ with torch.no_grad():
+ for images, labels in tqdm.tqdm(test_loader, total=total_batches):
+ outputs = model(images)
+ _, predicted = torch.max(outputs.data, 1)
+ total += labels.size(0)
+ correct += (predicted == labels).sum().item()
+ batch_count += 1
+ if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
+ break
+ print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('ckpt')
+ app.run(main)
diff --git a/adversarial_robustness/pytorch/model_zoo.py b/adversarial_robustness/pytorch/model_zoo.py
new file mode 100644
index 0000000..fdd84c0
--- /dev/null
+++ b/adversarial_robustness/pytorch/model_zoo.py
@@ -0,0 +1,164 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""WideResNet implementation in PyTorch."""
+
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
+CIFAR10_STD = (0.2471, 0.2435, 0.2616)
+CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
+CIFAR100_STD = (0.2673, 0.2564, 0.2762)
+
+
+class _Swish(torch.autograd.Function):
+ """Custom implementation of swish."""
+
+ @staticmethod
+ def forward(ctx, i):
+ result = i * torch.sigmoid(i)
+ ctx.save_for_backward(i)
+ return result
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ i = ctx.saved_variables[0]
+ sigmoid_i = torch.sigmoid(i)
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class Swish(nn.Module):
+ """Module using custom implementation."""
+
+ def forward(self, input_tensor):
+ return _Swish.apply(input_tensor)
+
+
+class _Block(nn.Module):
+ """WideResNet Block."""
+
+ def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU):
+ super().__init__()
+ self.batchnorm_0 = nn.BatchNorm2d(in_planes)
+ self.relu_0 = activation_fn()
+ # We manually pad to obtain the same effect as `SAME` (necessary when
+ # `stride` is different than 1).
+ self.conv_0 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=0, bias=False)
+ self.batchnorm_1 = nn.BatchNorm2d(out_planes)
+ self.relu_1 = activation_fn()
+ self.conv_1 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
+ padding=1, bias=False)
+ self.has_shortcut = in_planes != out_planes
+ if self.has_shortcut:
+ self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
+ stride=stride, padding=0, bias=False)
+ else:
+ self.shortcut = None
+ self._stride = stride
+
+ def forward(self, x):
+ if self.has_shortcut:
+ x = self.relu_0(self.batchnorm_0(x))
+ else:
+ out = self.relu_0(self.batchnorm_0(x))
+ v = x if self.has_shortcut else out
+ if self._stride == 1:
+ v = F.pad(v, (1, 1, 1, 1))
+ elif self._stride == 2:
+ v = F.pad(v, (0, 1, 0, 1))
+ else:
+ raise ValueError('Unsupported `stride`.')
+ out = self.conv_0(v)
+ out = self.relu_1(self.batchnorm_1(out))
+ out = self.conv_1(out)
+ out = torch.add(self.shortcut(x) if self.has_shortcut else x, out)
+ return out
+
+
+class _BlockGroup(nn.Module):
+ """WideResNet block group."""
+
+ def __init__(self, num_blocks, in_planes, out_planes, stride,
+ activation_fn=nn.ReLU):
+ super().__init__()
+ block = []
+ for i in range(num_blocks):
+ block.append(
+ _Block(i == 0 and in_planes or out_planes,
+ out_planes,
+ i == 0 and stride or 1,
+ activation_fn=activation_fn))
+ self.block = nn.Sequential(*block)
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class WideResNet(nn.Module):
+ """WideResNet."""
+
+ def __init__(self,
+ num_classes: int = 10,
+ depth: int = 28,
+ width: int = 10,
+ activation_fn: nn.Module = nn.ReLU,
+ mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
+ std: Union[Tuple[float, ...], float] = CIFAR10_STD,
+ padding: int = 0,
+ num_input_channels: int = 3):
+ super().__init__()
+ self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
+ self.std = torch.tensor(std).view(num_input_channels, 1, 1)
+ self.mean_cuda = None
+ self.std_cuda = None
+ self.padding = padding
+ num_channels = [16, 16 * width, 32 * width, 64 * width]
+ assert (depth - 4) % 6 == 0
+ num_blocks = (depth - 4) // 6
+ self.init_conv = nn.Conv2d(num_input_channels, num_channels[0],
+ kernel_size=3, stride=1, padding=1, bias=False)
+ self.layer = nn.Sequential(
+ _BlockGroup(num_blocks, num_channels[0], num_channels[1], 1,
+ activation_fn=activation_fn),
+ _BlockGroup(num_blocks, num_channels[1], num_channels[2], 2,
+ activation_fn=activation_fn),
+ _BlockGroup(num_blocks, num_channels[2], num_channels[3], 2,
+ activation_fn=activation_fn))
+ self.batchnorm = nn.BatchNorm2d(num_channels[3])
+ self.relu = activation_fn()
+ self.logits = nn.Linear(num_channels[3], num_classes)
+ self.num_channels = num_channels[3]
+
+ def forward(self, x):
+ if self.padding > 0:
+ x = F.pad(x, (self.padding,) * 4)
+ if x.is_cuda:
+ if self.mean_cuda is None:
+ self.mean_cuda = self.mean.cuda()
+ self.std_cuda = self.std.cuda()
+ out = (x - self.mean_cuda) / self.std_cuda
+ else:
+ out = (x - self.mean) / self.std
+ out = self.init_conv(out)
+ out = self.layer(out)
+ out = self.relu(self.batchnorm(out))
+ out = F.avg_pool2d(out, 8)
+ out = out.view(-1, self.num_channels)
+ return self.logits(out)
diff --git a/adversarial_robustness/requirements.txt b/adversarial_robustness/requirements.txt
new file mode 100644
index 0000000..160c6fc
--- /dev/null
+++ b/adversarial_robustness/requirements.txt
@@ -0,0 +1,51 @@
+absl-py==0.10.0
+astunparse==1.6.3
+attrs==20.3.0
+cachetools==4.1.1
+certifi==2020.11.8
+chardet==3.0.4
+dataclasses==0.6
+dill==0.3.3
+dm-haiku==0.0.3
+flatbuffers==1.12
+future==0.18.2
+gast==0.3.3
+google-auth==1.23.0
+google-auth-oauthlib==0.4.2
+google-pasta==0.2.0
+googleapis-common-protos==1.52.0
+grpcio==1.33.2
+h5py==2.10.0
+idna==2.10
+importlib-resources==3.3.0
+jax==0.2.6
+jaxlib==0.1.57
+Keras-Preprocessing==1.1.2
+Markdown==3.3.3
+numpy==1.18.5
+oauthlib==3.1.0
+opt-einsum==3.3.0
+Pillow==8.0.1
+promise==2.3
+protobuf==3.14.0
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+requests==2.25.0
+requests-oauthlib==1.3.0
+rsa==4.6
+scipy==1.5.4
+six==1.15.0
+tensorboard==2.4.0
+tensorboard-plugin-wit==1.7.0
+tensorflow==2.3.1
+tensorflow-datasets==4.1.0
+tensorflow-estimator==2.3.0
+tensorflow-metadata==0.25.0
+termcolor==1.1.0
+torch==1.7.0
+torchvision==0.8.1
+tqdm==4.53.0
+typing-extensions==3.7.4.3
+urllib3==1.26.2
+Werkzeug==1.0.1
+wrapt==1.12.1
diff --git a/adversarial_robustness/run.sh b/adversarial_robustness/run.sh
new file mode 100644
index 0000000..e901ed5
--- /dev/null
+++ b/adversarial_robustness/run.sh
@@ -0,0 +1,33 @@
+#!/bin/sh
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+python3 -m venv adversarial_robustness_venv
+source adversarial_robustness_venv/bin/activate
+pip install -r adversarial_robustness/requirements.txt
+
+python3 -m adversarial_robustness.jax.eval \
+ --ckpt=dummy \
+ --arch=wrn-10-1 \
+ --dataset=cifar10 \
+ --batch_size=1 \
+ --num_batches=1
+
+python3 -m adversarial_robustness.pytorch.eval \
+ --ckpt=dummy \
+ --arch=wrn-10-1 \
+ --dataset=cifar10 \
+ --batch_size=1 \
+ --num_batches=1 \
+ --nouse_cuda