diff --git a/adversarial_robustness/README.md b/adversarial_robustness/README.md
index 0708ffd..f2c87d9 100644
--- a/adversarial_robustness/README.md
+++ b/adversarial_robustness/README.md
@@ -42,10 +42,13 @@ The following table contains the models from **Rebuffi et al., 2021**.
| CIFAR-10 | ℓ∞ | 8 / 255 | WRN-106-16 | ✗ | 88.50% | 64.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn106-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn106-16_cutmix_ddpm_v2.pt)
| CIFAR-10 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 88.54% | 64.25% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_cutmix_ddpm_v2.pt)
| CIFAR-10 | ℓ∞ | 8 / 255 | WRN-28-10 | ✗ | 87.33% | 60.75% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_cutmix_ddpm_v2.pt)
+| CIFAR-10 | ℓ∞ | 8 / 255 | ResNet-18 | ✗ | 83.53% | 56.66% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_resnet18_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_resnet18_ddpm.pt)
| CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | ✗ | 92.41% | 80.42% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_cutmix_ddpm_v2.pt)
| CIFAR-10 | ℓ2 | 128 / 255 | WRN-28-10 | ✗ | 91.79% | 78.80% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn28-10_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn28-10_cutmix_ddpm_v2.pt)
+| CIFAR-10 | ℓ2 | 128 / 255 | ResNet-18 | ✗ | 90.33% | 75.86% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_resnet18_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_resnet18_cutmix_ddpm.pt)
| CIFAR-100 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 63.56% | 34.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.pt)
| CIFAR-100 | ℓ∞ | 8 / 255 | WRN-28-10 | ✗ | 62.41% | 32.06% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.pt)
+| CIFAR-100 | ℓ∞ | 8 / 255 | ResNet-18 | ✗ | 56.87% | 28.50% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_resnet18_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_resnet18_ddpm.pt)
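+
+To evaluate one of the ResNet-18 checkpoints above with the PyTorch `eval.py`
+script, pass `--width=0` (which selects the PreActResNet) together with
+`--depth=18`. A sketch (`--ckpt` follows the `_CKPT` flag in `eval.py`;
+`${PATH_TO_CHECKPOINT}` is a placeholder for a downloaded checkpoint):
+
+```bash
+python3 -m adversarial_robustness.pytorch.eval \
+  --ckpt=${PATH_TO_CHECKPOINT} --depth=18 --width=0 --dataset=cifar10
+```
+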
### Installing
diff --git a/adversarial_robustness/pytorch/eval.py b/adversarial_robustness/pytorch/eval.py
index 069a83f..a8140db 100644
--- a/adversarial_robustness/pytorch/eval.py
+++ b/adversarial_robustness/pytorch/eval.py
@@ -30,9 +30,9 @@ _DATASET = flags.DEFINE_enum(
     'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
     'Dataset on which the checkpoint is evaluated.')
 _WIDTH = flags.DEFINE_integer(
-    'width', 16, 'Width of WideResNet.')
+    'width', 16, 'Width of WideResNet (set to 0 to use a PreActResNet).')
 _DEPTH = flags.DEFINE_integer(
-    'depth', 70, 'Depth of WideResNet.')
+    'depth', 70, 'Depth of WideResNet or PreActResNet.')
 _USE_CUDA = flags.DEFINE_boolean(
     'use_cuda', True, 'Whether to use CUDA.')
 _BATCH_SIZE = flags.DEFINE_integer(
@@ -44,25 +44,30 @@ _NUM_BATCHES = flags.DEFINE_integer(
 def main(unused_argv):
   print(f'Loading "{_CKPT.value}"')
-  print(f'Using a WideResNet with depth {_DEPTH.value} and width '
-        f'{_WIDTH.value}.')
   # Create model and dataset.
+  if _WIDTH.value == 0:
+    print(f'Using a PreActResNet with depth {_DEPTH.value}.')
+    model_ctor = model_zoo.PreActResNet
+  else:
+    print(f'Using a WideResNet with depth {_DEPTH.value} and width '
+          f'{_WIDTH.value}.')
+    model_ctor = model_zoo.WideResNet
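+  # Note: `width` is still forwarded to the constructor calls below;
+  # PreActResNet accepts only `width=0` and raises a ValueError otherwise.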
   if _DATASET.value == 'mnist':
-    model = model_zoo.WideResNet(
+    model = model_ctor(
         num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
         activation_fn=model_zoo.Swish, mean=.5, std=.5, padding=2,
         num_input_channels=1)
     dataset_fn = datasets.MNIST
   elif _DATASET.value == 'cifar10':
-    model = model_zoo.WideResNet(
+    model = model_ctor(
         num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
         activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR10_MEAN,
         std=model_zoo.CIFAR10_STD)
     dataset_fn = datasets.CIFAR10
   else:
     assert _DATASET.value == 'cifar100'
-    model = model_zoo.WideResNet(
+    model = model_ctor(
         num_classes=100, depth=_DEPTH.value, width=_WIDTH.value,
         activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR100_MEAN,
         std=model_zoo.CIFAR100_STD)
     dataset_fn = datasets.CIFAR100
diff --git a/adversarial_robustness/pytorch/model_zoo.py b/adversarial_robustness/pytorch/model_zoo.py
index fdd84c0..bf3d037 100644
--- a/adversarial_robustness/pytorch/model_zoo.py
+++ b/adversarial_robustness/pytorch/model_zoo.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""WideResNet implementation in PyTorch."""
+"""WideResNet and PreActResNet implementations in PyTorch."""
from typing import Tuple, Union
@@ -162,3 +162,109 @@ class WideResNet(nn.Module):
     out = F.avg_pool2d(out, 8)
     out = out.view(-1, self.num_channels)
     return self.logits(out)
+
+
+class _PreActBlock(nn.Module):
+  """Pre-activation ResNet block."""
+
+  def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU):
+    super().__init__()
+    self._stride = stride
+    self.batchnorm_0 = nn.BatchNorm2d(in_planes)
+    self.relu_0 = activation_fn()
+    # We pad manually to obtain the same effect as `SAME` padding (necessary
+    # when `stride` is different from 1).
+    self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                               stride=stride, padding=0, bias=False)
+    self.batchnorm_1 = nn.BatchNorm2d(out_planes)
+    self.relu_1 = activation_fn()
+    self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
+                               padding=1, bias=False)
+    self.has_shortcut = stride != 1 or in_planes != out_planes
+    if self.has_shortcut:
+      self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                                stride=stride, padding=0, bias=False)
+
+  def _pad(self, x):
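+    # Mimics `SAME` padding for the 3x3 convolutions above: symmetric
+    # (1, 1, 1, 1) padding when `stride` is 1, and asymmetric (0, 1, 0, 1)
+    # padding (right/bottom only) when `stride` is 2, so that even input
+    # sizes halve exactly.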
+    if self._stride == 1:
+      x = F.pad(x, (1, 1, 1, 1))
+    elif self._stride == 2:
+      x = F.pad(x, (0, 1, 0, 1))
+    else:
+      raise ValueError('Unsupported `stride`.')
+    return x
+
+  def forward(self, x):
+    out = self.relu_0(self.batchnorm_0(x))
+    shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x
+    out = self.conv_2d_1(self._pad(out))
+    out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out)))
+    return out + shortcut
+
+
+class PreActResNet(nn.Module):
+  """Pre-activation ResNet."""
+
+  def __init__(self,
+               num_classes: int = 10,
+               depth: int = 18,
+               width: int = 0,  # Unused; kept for parity with WideResNet.
+               activation_fn: nn.Module = nn.ReLU,
+               mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
+               std: Union[Tuple[float, ...], float] = CIFAR10_STD,
+               padding: int = 0,
+               num_input_channels: int = 3):
+    super().__init__()
+    if width != 0:
+      raise ValueError('Unsupported `width`.')
+    self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
+    self.std = torch.tensor(std).view(num_input_channels, 1, 1)
+    self.mean_cuda = None
+    self.std_cuda = None
+    self.padding = padding
+    self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1,
+                             padding=1, bias=False)
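+    # Standard ResNet-18/34 stage sizes. Each basic block holds two 3x3
+    # convolutions, so depth = 2 * sum(num_blocks) + 2 (the initial
+    # convolution plus the final linear layer).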
+    if depth == 18:
+      num_blocks = (2, 2, 2, 2)
+    elif depth == 34:
+      num_blocks = (3, 4, 6, 3)
+    else:
+      raise ValueError('Unsupported `depth`.')
+    self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn)
+    self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn)
+    self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn)
+    self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn)
+    self.batchnorm = nn.BatchNorm2d(512)
+    self.relu = activation_fn()
+    self.logits = nn.Linear(512, num_classes)
+
+  def _make_layer(self, in_planes, out_planes, num_blocks, stride,
+                  activation_fn):
+    layers = []
+    for i, stride in enumerate([stride] + [1] * (num_blocks - 1)):
+      layers.append(
+          _PreActBlock(in_planes if i == 0 else out_planes,
+                       out_planes,
+                       stride,
+                       activation_fn))
+    return nn.Sequential(*layers)
+
+  def forward(self, x):
+    if self.padding > 0:
+      x = F.pad(x, (self.padding,) * 4)
+    if x.is_cuda:
+      if self.mean_cuda is None:
+        self.mean_cuda = self.mean.cuda()
+        self.std_cuda = self.std.cuda()
+      out = (x - self.mean_cuda) / self.std_cuda
+    else:
+      out = (x - self.mean) / self.std
+    out = self.conv_2d(out)
+    out = self.layer_0(out)
+    out = self.layer_1(out)
+    out = self.layer_2(out)
+    out = self.layer_3(out)
+    out = self.relu(self.batchnorm(out))
+    out = F.avg_pool2d(out, 4)
+    out = out.view(out.size(0), -1)
+    return self.logits(out)
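+
+
+if __name__ == '__main__':
+  # Minimal smoke test (a sketch, not part of the published checkpoints):
+  # builds the CIFAR-10 PreActResNet-18 as constructed in `eval.py` and
+  # checks the output shape on a random batch.
+  net = PreActResNet(num_classes=10, depth=18, activation_fn=Swish)
+  net.eval()
+  batch = torch.rand(2, 3, 32, 32)
+  assert net(batch).shape == (2, 10)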