Added pre-activation ResNet-18 checkpoints.

PiperOrigin-RevId: 388199776
Author: Sven Gowal
Date: 2021-08-02 13:57:35 +01:00
Committed by: alimuldal
parent 414683cc12
commit 48290008d1
3 changed files with 122 additions and 8 deletions

@@ -42,10 +42,13 @@ The following table contains the models from **Rebuffi et al., 2021**.
 | CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-106-16 | &#x2717; | 88.50% | 64.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn106-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn106-16_cutmix_ddpm_v2.pt)
 | CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | &#x2717; | 88.54% | 64.25% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_cutmix_ddpm_v2.pt)
 | CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-28-10 | &#x2717; | 87.33% | 60.75% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_cutmix_ddpm_v2.pt)
+| CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | ResNet-18 | &#x2717; | 83.53% | 56.66% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_resnet18_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_resnet18_ddpm.pt)
 | CIFAR-10 | &#8467;<sub>2</sub> | 128 / 255 | WRN-70-16 | &#x2717; | 92.41% | 80.42% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_cutmix_ddpm_v2.pt)
 | CIFAR-10 | &#8467;<sub>2</sub> | 128 / 255 | WRN-28-10 | &#x2717; | 91.79% | 78.80% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn28-10_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn28-10_cutmix_ddpm_v2.pt)
+| CIFAR-10 | &#8467;<sub>2</sub> | 128 / 255 | ResNet-18 | &#x2717; | 90.33% | 75.86% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_resnet18_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_resnet18_cutmix_ddpm.pt)
 | CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | &#x2717; | 63.56% | 34.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.pt)
 | CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-28-10 | &#x2717; | 62.41% | 32.06% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.pt)
+| CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | ResNet-18 | &#x2717; | 56.87% | 28.50% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_resnet18_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_resnet18_ddpm.pt)

 ### Installing
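The three `+` rows above are the new pre-activation ResNet-18 checkpoints; each links the JAX (`.npy`) and PyTorch (`.pt`) weights next to the clean and robust accuracy columns. A minimal sketch of restoring one of the `.pt` files into the `PreActResNet` class this commit adds, assuming the repository's `adversarial_robustness.pytorch.model_zoo` import path and that the file stores a plain `state_dict`:

import torch
from adversarial_robustness.pytorch import model_zoo  # assumed import path

# Pre-activation ResNet-18 as the commit's eval script builds it: depth=18
# selects the (2, 2, 2, 2) block layout, Swish matches the checkpoints.
model = model_zoo.PreActResNet(num_classes=10, depth=18,
                               activation_fn=model_zoo.Swish)
# Assumption: the downloaded .pt holds a state_dict compatible with the class.
params = torch.load('cifar10_linf_resnet18_ddpm.pt', map_location='cpu')
model.load_state_dict(params)
model.eval()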

@@ -30,9 +30,9 @@ _DATASET = flags.DEFINE_enum(
     'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
     'Dataset on which the checkpoint is evaluated.')
 _WIDTH = flags.DEFINE_integer(
-    'width', 16, 'Width of WideResNet.')
+    'width', 16, 'Width of WideResNet (if set to zero uses a PreActResNet).')
 _DEPTH = flags.DEFINE_integer(
-    'depth', 70, 'Depth of WideResNet.')
+    'depth', 70, 'Depth of WideResNet or PreActResNet.')
 _USE_CUDA = flags.DEFINE_boolean(
     'use_cuda', True, 'Whether to use CUDA.')
 _BATCH_SIZE = flags.DEFINE_integer(
@@ -44,25 +44,30 @@ _NUM_BATCHES = flags.DEFINE_integer(
 def main(unused_argv):
   print(f'Loading "{_CKPT.value}"')
-  print(f'Using a WideResNet with depth {_DEPTH.value} and width '
-        f'{_WIDTH.value}.')
   # Create model and dataset.
+  if _WIDTH.value == 0:
+    print(f'Using a PreActResNet with depth {_DEPTH.value}.')
+    model_ctor = model_zoo.PreActResNet
+  else:
+    print(f'Using a WideResNet with depth {_DEPTH.value} and width '
+          f'{_WIDTH.value}.')
+    model_ctor = model_zoo.WideResNet
   if _DATASET.value == 'mnist':
-    model = model_zoo.WideResNet(
+    model = model_ctor(
         num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
         activation_fn=model_zoo.Swish, mean=.5, std=.5, padding=2,
         num_input_channels=1)
     dataset_fn = datasets.MNIST
   elif _DATASET.value == 'cifar10':
-    model = model_zoo.WideResNet(
+    model = model_ctor(
         num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
         activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR10_MEAN,
         std=model_zoo.CIFAR10_STD)
     dataset_fn = datasets.CIFAR10
   else:
     assert _DATASET.value == 'cifar100'
-    model = model_zoo.WideResNet(
+    model = model_ctor(
         num_classes=100, depth=_DEPTH.value, width=_WIDTH.value,
         activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR100_MEAN,
         std=model_zoo.CIFAR100_STD)
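With this hunk, `--width=0` reroutes construction to `model_zoo.PreActResNet` while every other flag keeps its meaning, e.g. `python eval.py --ckpt=cifar10_linf_resnet18_ddpm.pt --dataset=cifar10 --width=0 --depth=18` (the script name `eval.py` is an assumption from the repository layout). A compact sketch of the dispatch the hunk introduces, where `build_model` is a hypothetical helper and both constructors accept the same keyword arguments:

import torch.nn as nn
from adversarial_robustness.pytorch import model_zoo  # assumed import path

def build_model(width: int, depth: int, num_classes: int = 10) -> nn.Module:
  # width == 0 selects the new PreActResNet; any other width means WideResNet.
  ctor = model_zoo.PreActResNet if width == 0 else model_zoo.WideResNet
  return ctor(num_classes=num_classes, depth=depth, width=width,
              activation_fn=model_zoo.Swish)

resnet18 = build_model(width=0, depth=18)    # pre-activation ResNet-18
wrn_70_16 = build_model(width=16, depth=70)  # WideResNet-70-16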

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""WideResNet implementation in PyTorch."""
+"""WideResNet and PreActResNet implementations in PyTorch."""

 from typing import Tuple, Union
@@ -162,3 +162,109 @@ class WideResNet(nn.Module):
     out = F.avg_pool2d(out, 8)
     out = out.view(-1, self.num_channels)
     return self.logits(out)
+
+
+class _PreActBlock(nn.Module):
+  """Pre-activation ResNet Block."""
+
+  def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU):
+    super().__init__()
+    self._stride = stride
+    self.batchnorm_0 = nn.BatchNorm2d(in_planes)
+    self.relu_0 = activation_fn()
+    # We manually pad to obtain the same effect as `SAME` (necessary when
+    # `stride` is different than 1).
+    self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                               stride=stride, padding=0, bias=False)
+    self.batchnorm_1 = nn.BatchNorm2d(out_planes)
+    self.relu_1 = activation_fn()
+    self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
+                               padding=1, bias=False)
+    self.has_shortcut = stride != 1 or in_planes != out_planes
+    if self.has_shortcut:
+      self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                                stride=stride, padding=0, bias=False)
+
+  def _pad(self, x):
+    if self._stride == 1:
+      x = F.pad(x, (1, 1, 1, 1))
+    elif self._stride == 2:
+      x = F.pad(x, (0, 1, 0, 1))
+    else:
+      raise ValueError('Unsupported `stride`.')
+    return x
+
+  def forward(self, x):
+    out = self.relu_0(self.batchnorm_0(x))
+    shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x
+    out = self.conv_2d_1(self._pad(out))
+    out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out)))
+    return out + shortcut
+
+
+class PreActResNet(nn.Module):
+  """Pre-activation ResNet."""
+
+  def __init__(self,
+               num_classes: int = 10,
+               depth: int = 18,
+               width: int = 0,  # Used to make the constructor consistent.
+               activation_fn: nn.Module = nn.ReLU,
+               mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
+               std: Union[Tuple[float, ...], float] = CIFAR10_STD,
+               padding: int = 0,
+               num_input_channels: int = 3):
+    super().__init__()
+    if width != 0:
+      raise ValueError('Unsupported `width`.')
+    self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
+    self.std = torch.tensor(std).view(num_input_channels, 1, 1)
+    self.mean_cuda = None
+    self.std_cuda = None
+    self.padding = padding
+    self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1,
+                             padding=1, bias=False)
+    if depth == 18:
+      num_blocks = (2, 2, 2, 2)
+    elif depth == 34:
+      num_blocks = (3, 4, 6, 3)
+    else:
+      raise ValueError('Unsupported `depth`.')
+    self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn)
+    self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn)
+    self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn)
+    self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn)
+    self.batchnorm = nn.BatchNorm2d(512)
+    self.relu = activation_fn()
+    self.logits = nn.Linear(512, num_classes)
+
+  def _make_layer(self, in_planes, out_planes, num_blocks, stride,
+                  activation_fn):
+    layers = []
+    for i, stride in enumerate([stride] + [1] * (num_blocks - 1)):
+      layers.append(
+          _PreActBlock(i == 0 and in_planes or out_planes,
+                       out_planes,
+                       stride,
+                       activation_fn))
+    return nn.Sequential(*layers)
+
+  def forward(self, x):
+    if self.padding > 0:
+      x = F.pad(x, (self.padding,) * 4)
+    if x.is_cuda:
+      if self.mean_cuda is None:
+        self.mean_cuda = self.mean.cuda()
+        self.std_cuda = self.std.cuda()
+      out = (x - self.mean_cuda) / self.std_cuda
+    else:
+      out = (x - self.mean) / self.std
+    out = self.conv_2d(out)
+    out = self.layer_0(out)
+    out = self.layer_1(out)
+    out = self.layer_2(out)
+    out = self.layer_3(out)
+    out = self.relu(self.batchnorm(out))
+    out = F.avg_pool2d(out, 4)
+    out = out.view(out.size(0), -1)
+    return self.logits(out)
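A quick smoke test for the class added above (a sketch; the import path is an assumption). With CIFAR-sized inputs the three stride-2 stages reduce 32x32 feature maps to 4x4, the 4x4 average pool leaves a 512-dimensional vector, and the linear head emits one row of logits per example; normalization by `mean`/`std` happens inside `forward`, so inputs stay in [0, 1]:

import torch
from adversarial_robustness.pytorch import model_zoo  # assumed import path

model = model_zoo.PreActResNet(num_classes=10, depth=18)  # ResNet-18 layout
model.eval()
with torch.no_grad():
  logits = model(torch.rand(4, 3, 32, 32))  # 4 CIFAR-shaped images in [0, 1]
print(logits.shape)  # torch.Size([4, 10])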