diff --git a/adversarial_robustness/README.md b/adversarial_robustness/README.md
index 0708ffd..f2c87d9 100644
--- a/adversarial_robustness/README.md
+++ b/adversarial_robustness/README.md
@@ -42,10 +42,13 @@ The following table contains the models from **Rebuffi et al., 2021**.
 | CIFAR-10 | ℓ∞ | 8 / 255 | WRN-106-16 | ✗ | 88.50% | 64.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn106-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn106-16_cutmix_ddpm_v2.pt)
 | CIFAR-10 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 88.54% | 64.25% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_cutmix_ddpm_v2.pt)
 | CIFAR-10 | ℓ∞ | 8 / 255 | WRN-28-10 | ✗ | 87.33% | 60.75% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_cutmix_ddpm_v2.pt)
+| CIFAR-10 | ℓ∞ | 8 / 255 | ResNet-18 | ✗ | 83.53% | 56.66% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_resnet18_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_resnet18_ddpm.pt)
 | CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | ✗ | 92.41% | 80.42% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_cutmix_ddpm_v2.pt)
 | CIFAR-10 | ℓ2 | 128 / 255 | WRN-28-10 | ✗ | 91.79% | 78.80% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn28-10_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn28-10_cutmix_ddpm_v2.pt)
+| CIFAR-10 | ℓ2 | 128 / 255 | ResNet-18 | ✗ | 90.33% | 75.86% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_resnet18_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_resnet18_cutmix_ddpm.pt)
 | CIFAR-100 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 63.56% | 34.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.pt)
 | CIFAR-100 | ℓ∞ | 8 / 255 | WRN-28-10 | ✗ | 62.41% | 32.06% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.pt)
+| CIFAR-100 | ℓ∞ | 8 / 255 | ResNet-18 | ✗ | 56.87% | 28.50% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_resnet18_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_resnet18_ddpm.pt)
 
 ### Installing
 
diff --git a/adversarial_robustness/pytorch/eval.py b/adversarial_robustness/pytorch/eval.py
index 069a83f..a8140db 100644
--- a/adversarial_robustness/pytorch/eval.py
+++ b/adversarial_robustness/pytorch/eval.py
@@ -30,9 +30,9 @@ _DATASET = flags.DEFINE_enum(
     'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
     'Dataset on which the checkpoint is evaluated.')
 _WIDTH = flags.DEFINE_integer(
-    'width', 16, 'Width of WideResNet.')
+    'width', 16, 'Width of WideResNet (if set to zero uses a PreActResNet).')
 _DEPTH = flags.DEFINE_integer(
-    'depth', 70, 'Depth of WideResNet.')
+    'depth', 70, 'Depth of WideResNet or PreActResNet.')
 _USE_CUDA = flags.DEFINE_boolean(
     'use_cuda', True, 'Whether to use CUDA.')
 _BATCH_SIZE = flags.DEFINE_integer(
@@ -44,25 +44,30 @@ _NUM_BATCHES = flags.DEFINE_integer(
 
 def main(unused_argv):
   print(f'Loading "{_CKPT.value}"')
-  print(f'Using a WideResNet with depth {_DEPTH.value} and width '
-        f'{_WIDTH.value}.')
 
   # Create model and dataset.
+  if _WIDTH.value == 0:
+    print(f'Using a PreActResNet with depth {_DEPTH.value}.')
+    model_ctor = model_zoo.PreActResNet
+  else:
+    print(f'Using a WideResNet with depth {_DEPTH.value} and width '
+          f'{_WIDTH.value}.')
+    model_ctor = model_zoo.WideResNet
   if _DATASET.value == 'mnist':
-    model = model_zoo.WideResNet(
+    model = model_ctor(
         num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
         activation_fn=model_zoo.Swish, mean=.5, std=.5, padding=2,
         num_input_channels=1)
     dataset_fn = datasets.MNIST
   elif _DATASET.value == 'cifar10':
-    model = model_zoo.WideResNet(
+    model = model_ctor(
         num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
         activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR10_MEAN,
         std=model_zoo.CIFAR10_STD)
     dataset_fn = datasets.CIFAR10
   else:
     assert _DATASET.value == 'cifar100'
-    model = model_zoo.WideResNet(
+    model = model_ctor(
         num_classes=100, depth=_DEPTH.value, width=_WIDTH.value,
         activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR100_MEAN,
         std=model_zoo.CIFAR100_STD)
diff --git a/adversarial_robustness/pytorch/model_zoo.py b/adversarial_robustness/pytorch/model_zoo.py
index fdd84c0..bf3d037 100644
--- a/adversarial_robustness/pytorch/model_zoo.py
+++ b/adversarial_robustness/pytorch/model_zoo.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""WideResNet implementation in PyTorch."""
+"""WideResNet and PreActResNet implementations in PyTorch."""
 
 from typing import Tuple, Union
 
@@ -162,3 +162,109 @@ class WideResNet(nn.Module):
     out = F.avg_pool2d(out, 8)
     out = out.view(-1, self.num_channels)
     return self.logits(out)
+
+
+class _PreActBlock(nn.Module):
+  """Pre-activation ResNet Block."""
+
+  def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU):
+    super().__init__()
+    self._stride = stride
+    self.batchnorm_0 = nn.BatchNorm2d(in_planes)
+    self.relu_0 = activation_fn()
+    # We manually pad to obtain the same effect as `SAME` (necessary when
+    # `stride` is different than 1).
+    self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                               stride=stride, padding=0, bias=False)
+    self.batchnorm_1 = nn.BatchNorm2d(out_planes)
+    self.relu_1 = activation_fn()
+    self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
+                               padding=1, bias=False)
+    self.has_shortcut = stride != 1 or in_planes != out_planes
+    if self.has_shortcut:
+      self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                                stride=stride, padding=0, bias=False)
+
+  def _pad(self, x):
+    if self._stride == 1:
+      x = F.pad(x, (1, 1, 1, 1))
+    elif self._stride == 2:
+      x = F.pad(x, (0, 1, 0, 1))
+    else:
+      raise ValueError('Unsupported `stride`.')
+    return x
+
+  def forward(self, x):
+    out = self.relu_0(self.batchnorm_0(x))
+    shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x
+    out = self.conv_2d_1(self._pad(out))
+    out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out)))
+    return out + shortcut
+
+
+class PreActResNet(nn.Module):
+  """Pre-activation ResNet."""
+
+  def __init__(self,
+               num_classes: int = 10,
+               depth: int = 18,
+               width: int = 0,  # Used to make the constructor consistent.
+               activation_fn: nn.Module = nn.ReLU,
+               mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
+               std: Union[Tuple[float, ...], float] = CIFAR10_STD,
+               padding: int = 0,
+               num_input_channels: int = 3):
+    super().__init__()
+    if width != 0:
+      raise ValueError('Unsupported `width`.')
+    self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
+    self.std = torch.tensor(std).view(num_input_channels, 1, 1)
+    self.mean_cuda = None
+    self.std_cuda = None
+    self.padding = padding
+    self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1,
+                             padding=1, bias=False)
+    if depth == 18:
+      num_blocks = (2, 2, 2, 2)
+    elif depth == 34:
+      num_blocks = (3, 4, 6, 3)
+    else:
+      raise ValueError('Unsupported `depth`.')
+    self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn)
+    self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn)
+    self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn)
+    self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn)
+    self.batchnorm = nn.BatchNorm2d(512)
+    self.relu = activation_fn()
+    self.logits = nn.Linear(512, num_classes)
+
+  def _make_layer(self, in_planes, out_planes, num_blocks, stride,
+                  activation_fn):
+    layers = []
+    for i, stride in enumerate([stride] + [1] * (num_blocks - 1)):
+      layers.append(
+          _PreActBlock(i == 0 and in_planes or out_planes,
+                       out_planes,
+                       stride,
+                       activation_fn))
+    return nn.Sequential(*layers)
+
+  def forward(self, x):
+    if self.padding > 0:
+      x = F.pad(x, (self.padding,) * 4)
+    if x.is_cuda:
+      if self.mean_cuda is None:
+        self.mean_cuda = self.mean.cuda()
+        self.std_cuda = self.std.cuda()
+      out = (x - self.mean_cuda) / self.std_cuda
+    else:
+      out = (x - self.mean) / self.std
+    out = self.conv_2d(out)
+    out = self.layer_0(out)
+    out = self.layer_1(out)
+    out = self.layer_2(out)
+    out = self.layer_3(out)
+    out = self.relu(self.batchnorm(out))
+    out = F.avg_pool2d(out, 4)
+    out = out.view(out.size(0), -1)
+    return self.logits(out)
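Usage note (not part of the diff): below is a minimal sketch of how the new `PreActResNet` could be instantiated and pointed at one of the ResNet-18 checkpoints added to the README table above. The constructor arguments mirror the CIFAR-10 branch of `eval.py`; the import path, the `--ckpt` flag name, and the assumption that the `.pt` file stores a plain `state_dict` are not confirmed by this diff, so adjust as needed.

```python
# Sketch only: load a CIFAR-10 ResNet-18 checkpoint into the PreActResNet above.
# Assumes the code is importable as `adversarial_robustness.pytorch` and that
# the downloaded .pt file contains a plain state_dict (both are assumptions).
import torch

from adversarial_robustness.pytorch import model_zoo

model = model_zoo.PreActResNet(
    num_classes=10, depth=18, width=0,  # `width` must be 0 for PreActResNet.
    activation_fn=model_zoo.Swish,      # eval.py passes Swish for all models.
    mean=model_zoo.CIFAR10_MEAN, std=model_zoo.CIFAR10_STD)
params = torch.load('cifar10_linf_resnet18_ddpm.pt', map_location='cpu')
model.load_state_dict(params)
model.eval()

# Equivalent evaluation through the updated script, where --width=0 selects the
# PreActResNet path (checkpoint flag name assumed to be --ckpt):
#   python3 -m adversarial_robustness.pytorch.eval \
#     --ckpt=cifar10_linf_resnet18_ddpm.pt --dataset=cifar10 --depth=18 --width=0
```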