mirror of https://github.com/google-deepmind/deepmind-research.git
Added pre-activation ResNet-18 checkpoints.
PiperOrigin-RevId: 388199776
@@ -42,10 +42,13 @@ The following table contains the models from **Rebuffi et al., 2021**.
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-106-16 | ✗ | 88.50% | 64.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn106-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn106-16_cutmix_ddpm_v2.pt)
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✗ | 88.54% | 64.25% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_cutmix_ddpm_v2.pt)
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-28-10 | ✗ | 87.33% | 60.75% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_cutmix_ddpm_v2.pt)
+| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | ResNet-18 | ✗ | 83.53% | 56.66% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_resnet18_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_resnet18_ddpm.pt)
| CIFAR-10 | ℓ<sub>2</sub> | 128 / 255 | WRN-70-16 | ✗ | 92.41% | 80.42% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_cutmix_ddpm_v2.pt)
| CIFAR-10 | ℓ<sub>2</sub> | 128 / 255 | WRN-28-10 | ✗ | 91.79% | 78.80% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn28-10_cutmix_ddpm_v2.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn28-10_cutmix_ddpm_v2.pt)
+| CIFAR-10 | ℓ<sub>2</sub> | 128 / 255 | ResNet-18 | ✗ | 90.33% | 75.86% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_resnet18_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_resnet18_cutmix_ddpm.pt)
| CIFAR-100 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✗ | 63.56% | 34.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.pt)
| CIFAR-100 | ℓ<sub>∞</sub> | 8 / 255 | WRN-28-10 | ✗ | 62.41% | 32.06% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.pt)
+| CIFAR-100 | ℓ<sub>∞</sub> | 8 / 255 | ResNet-18 | ✗ | 56.87% | 28.50% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_resnet18_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_resnet18_ddpm.pt)
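A minimal loading sketch for one of the new PyTorch ResNet-18 checkpoints above. It assumes the `.pt` file holds a plain `state_dict` compatible with the `PreActResNet` added in this commit, and that `model_zoo` is importable as it is in the evaluation script; the local filename is taken from the link in the table.

```python
import torch

import model_zoo  # Import path is an assumption; adjust to your checkout.

# CIFAR-10, L-inf ResNet-18 checkpoint from the table (ddpm variant).
model = model_zoo.PreActResNet(
    num_classes=10, depth=18, width=0,
    activation_fn=model_zoo.Swish,
    mean=model_zoo.CIFAR10_MEAN, std=model_zoo.CIFAR10_STD)

# Assumes cifar10_linf_resnet18_ddpm.pt was downloaded from the link above
# and stores a state_dict matching the model's parameters.
state_dict = torch.load('cifar10_linf_resnet18_ddpm.pt', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```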
### Installing
@@ -30,9 +30,9 @@ _DATASET = flags.DEFINE_enum(
    'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
    'Dataset on which the checkpoint is evaluated.')
_WIDTH = flags.DEFINE_integer(
-    'width', 16, 'Width of WideResNet.')
+    'width', 16, 'Width of WideResNet (if set to zero uses a PreActResNet).')
_DEPTH = flags.DEFINE_integer(
-    'depth', 70, 'Depth of WideResNet.')
+    'depth', 70, 'Depth of WideResNet or PreActResNet.')
_USE_CUDA = flags.DEFINE_boolean(
    'use_cuda', True, 'Whether to use CUDA.')
_BATCH_SIZE = flags.DEFINE_integer(
@@ -44,25 +44,30 @@ _NUM_BATCHES = flags.DEFINE_integer(

def main(unused_argv):
  print(f'Loading "{_CKPT.value}"')
-  print(f'Using a WideResNet with depth {_DEPTH.value} and width '
-        f'{_WIDTH.value}.')

  # Create model and dataset.
+  if _WIDTH.value == 0:
+    print(f'Using a PreActResNet with depth {_DEPTH.value}.')
+    model_ctor = model_zoo.PreActResNet
+  else:
+    print(f'Using a WideResNet with depth {_DEPTH.value} and width '
+          f'{_WIDTH.value}.')
+    model_ctor = model_zoo.WideResNet
  if _DATASET.value == 'mnist':
-    model = model_zoo.WideResNet(
+    model = model_ctor(
        num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
        activation_fn=model_zoo.Swish, mean=.5, std=.5, padding=2,
        num_input_channels=1)
    dataset_fn = datasets.MNIST
  elif _DATASET.value == 'cifar10':
-    model = model_zoo.WideResNet(
+    model = model_ctor(
        num_classes=10, depth=_DEPTH.value, width=_WIDTH.value,
        activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR10_MEAN,
        std=model_zoo.CIFAR10_STD)
    dataset_fn = datasets.CIFAR10
  else:
    assert _DATASET.value == 'cifar100'
-    model = model_zoo.WideResNet(
+    model = model_ctor(
        num_classes=100, depth=_DEPTH.value, width=_WIDTH.value,
        activation_fn=model_zoo.Swish, mean=model_zoo.CIFAR100_MEAN,
        std=model_zoo.CIFAR100_STD)
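For reference, a hedged sketch of the PreActResNet path that `main` takes for MNIST, where 28x28 single-channel inputs are padded to 32x32 via `padding=2`. The constructor arguments mirror the branch above; the import path and the dummy batch are assumptions used only to check the output shape.

```python
import torch

import model_zoo  # Import path is an assumption; adjust to your checkout.

# Mirrors the MNIST branch of main() when --width=0 selects the PreActResNet.
model = model_zoo.PreActResNet(
    num_classes=10, depth=18, width=0,
    activation_fn=model_zoo.Swish, mean=.5, std=.5, padding=2,
    num_input_channels=1)
model.eval()

# Dummy MNIST-sized batch: 4 images of 1x28x28; forward() pads them to 32x32.
x = torch.rand(4, 1, 28, 28)
with torch.no_grad():
  logits = model(x)
print(logits.shape)  # Expected: torch.Size([4, 10])
```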
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""WideResNet implementation in PyTorch."""
+"""WideResNet and PreActResNet implementations in PyTorch."""

from typing import Tuple, Union
@@ -162,3 +162,109 @@ class WideResNet(nn.Module):
    out = F.avg_pool2d(out, 8)
    out = out.view(-1, self.num_channels)
    return self.logits(out)
+
+
+class _PreActBlock(nn.Module):
+  """Pre-activation ResNet Block."""
+
+  def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU):
+    super().__init__()
+    self._stride = stride
+    self.batchnorm_0 = nn.BatchNorm2d(in_planes)
+    self.relu_0 = activation_fn()
+    # We manually pad to obtain the same effect as `SAME` (necessary when
+    # `stride` is different than 1).
+    self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                               stride=stride, padding=0, bias=False)
+    self.batchnorm_1 = nn.BatchNorm2d(out_planes)
+    self.relu_1 = activation_fn()
+    self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
+                               padding=1, bias=False)
+    self.has_shortcut = stride != 1 or in_planes != out_planes
+    if self.has_shortcut:
+      self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                                stride=stride, padding=0, bias=False)
+
+  def _pad(self, x):
+    if self._stride == 1:
+      x = F.pad(x, (1, 1, 1, 1))
+    elif self._stride == 2:
+      x = F.pad(x, (0, 1, 0, 1))
+    else:
+      raise ValueError('Unsupported `stride`.')
+    return x
+
+  def forward(self, x):
+    out = self.relu_0(self.batchnorm_0(x))
+    shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x
+    out = self.conv_2d_1(self._pad(out))
+    out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out)))
+    return out + shortcut
+
+
+class PreActResNet(nn.Module):
+  """Pre-activation ResNet."""
+
+  def __init__(self,
+               num_classes: int = 10,
+               depth: int = 18,
+               width: int = 0,  # Used to make the constructor consistent.
+               activation_fn: nn.Module = nn.ReLU,
+               mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
+               std: Union[Tuple[float, ...], float] = CIFAR10_STD,
+               padding: int = 0,
+               num_input_channels: int = 3):
+    super().__init__()
+    if width != 0:
+      raise ValueError('Unsupported `width`.')
+    self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
+    self.std = torch.tensor(std).view(num_input_channels, 1, 1)
+    self.mean_cuda = None
+    self.std_cuda = None
+    self.padding = padding
+    self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1,
+                             padding=1, bias=False)
+    if depth == 18:
+      num_blocks = (2, 2, 2, 2)
+    elif depth == 34:
+      num_blocks = (3, 4, 6, 3)
+    else:
+      raise ValueError('Unsupported `depth`.')
+    self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn)
+    self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn)
+    self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn)
+    self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn)
+    self.batchnorm = nn.BatchNorm2d(512)
+    self.relu = activation_fn()
+    self.logits = nn.Linear(512, num_classes)
+
+  def _make_layer(self, in_planes, out_planes, num_blocks, stride,
+                  activation_fn):
+    layers = []
+    for i, stride in enumerate([stride] + [1] * (num_blocks - 1)):
+      layers.append(
+          _PreActBlock(i == 0 and in_planes or out_planes,
+                       out_planes,
+                       stride,
+                       activation_fn))
+    return nn.Sequential(*layers)
+
+  def forward(self, x):
+    if self.padding > 0:
+      x = F.pad(x, (self.padding,) * 4)
+    if x.is_cuda:
+      if self.mean_cuda is None:
+        self.mean_cuda = self.mean.cuda()
+        self.std_cuda = self.std.cuda()
+      out = (x - self.mean_cuda) / self.std_cuda
+    else:
+      out = (x - self.mean) / self.std
+    out = self.conv_2d(out)
+    out = self.layer_0(out)
+    out = self.layer_1(out)
+    out = self.layer_2(out)
+    out = self.layer_3(out)
+    out = self.relu(self.batchnorm(out))
+    out = F.avg_pool2d(out, 4)
+    out = out.view(out.size(0), -1)
+    return self.logits(out)
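The `(0, 1, 0, 1)` case in `_PreActBlock._pad` pads only on the right and bottom, which reproduces TensorFlow-style `SAME` padding for a 3x3 convolution with stride 2 on even-sized inputs. A standalone shape check of that behaviour (an illustrative sketch, independent of the model code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.zeros(1, 64, 32, 32)

# Stride 1: symmetric (1, 1, 1, 1) padding keeps the spatial size at 32x32.
conv_s1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0, bias=False)
print(conv_s1(F.pad(x, (1, 1, 1, 1))).shape)  # torch.Size([1, 128, 32, 32])

# Stride 2: padding only right/bottom, as in _PreActBlock._pad, halves the
# spatial size to 16x16, matching TensorFlow's `SAME` convention.
conv_s2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=False)
print(conv_s2(F.pad(x, (0, 1, 0, 1))).shape)  # torch.Size([1, 128, 16, 16])
```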