Added generated datasets.

PiperOrigin-RevId: 368669515
This commit is contained in:
Sven Gowal
2021-04-15 18:37:55 +01:00
committed by Diego de Las Casas
parent cfbcb1600f
commit e3de1fd90f
3 changed files with 48 additions and 2 deletions
+23 -2
View File
@@ -58,11 +58,32 @@ python3 eval.py \
--ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
```
## Generated datasets
Rebuffi et al. (2021) use samples generated by a Denoising Diffusion
Probabilistic Model [(DDPM; Ho et al., 2020)](https://arxiv.org/abs/2006.11239)
to improve robustness. The DDPM is solely trained on the original training data
and does not use additional external data. The following table links to datasets
of 1M **generated** samples for CIFAR-10, CIFAR-100 and SVHN.
| dataset | model | size | link |
|---|---|:---:|:---:|
| CIFAR-10 | DDPM | 1M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) |
| CIFAR-100 | DDPM | 1M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_ddpm.npz) |
| SVHN | DDPM | 1M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/svhn_ddpm.npz) |
To load each dataset, use NumPy. E.g.:
```
npzfile = np.load('cifar10_ddpm.npz')
images = npzfile['image']
labels = npzfile['label']
```
## Citing this work
If you use this code or these models in your work, please cite the relevant
accompanying paper:
If you use this code, data or these models in your work, please cite the
relevant accompanying paper:
```
@article{gowal2020uncovering,
@@ -41,6 +41,27 @@ python3 eval.py \
--ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
```
## Generated datasets
This work uses samples generated by a Denoising Diffusion
Probabilistic Model [(DDPM; Ho et al., 2020)](https://arxiv.org/abs/2006.11239)
to improve robustness. The DDPM is solely trained on the original training data
and does not use additional external data. The following table links to datasets
of 1M **generated** samples for CIFAR-10, CIFAR-100 and SVHN.
| dataset | model | size | link |
|---|---|:---:|:---:|
| CIFAR-10 | DDPM | 1M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) |
| CIFAR-100 | DDPM | 1M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_ddpm.npz) |
| SVHN | DDPM | 1M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/svhn_ddpm.npz) |
To load each dataset, use NumPy. E.g.:
```
npzfile = np.load('cifar10_ddpm.npz')
images = npzfile['image']
labels = npzfile['label']
```
## Citing this work
+4
View File
@@ -21,12 +21,16 @@ pip install -r adversarial_robustness/requirements.txt
python3 -m adversarial_robustness.jax.eval \
--ckpt=dummy \
--dataset=cifar10 \
--width=1 \
--depth=10 \
--batch_size=1 \
--num_batches=1
python3 -m adversarial_robustness.pytorch.eval \
--ckpt=dummy \
--dataset=cifar10 \
--width=1 \
--depth=10 \
--batch_size=1 \
--num_batches=1 \
--nouse_cuda