Added jaxline pipeline to train adversarially robust models.

PiperOrigin-RevId: 383399487
Sven Gowal
2021-07-07 13:36:13 +00:00
committed by Louise Deason
parent d8df4155dc
commit 5909da5388
36 changed files with 5229 additions and 79 deletions

View File

@@ -13,7 +13,7 @@ We have released our top-performing models in two formats compatible with
[JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/).
This repository also contains our model definitions.
## Running the example code
## Running the code
### Downloading a model
@@ -47,10 +47,32 @@ The following table contains the models from **Rebuffi et al., 2021**.
| CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | &#x2717; | 63.56% | 34.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.pt)
| CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-28-10 | &#x2717; | 62.41% | 32.06% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.pt)
### Using the model
### Installing
Once downloaded, a model can be evaluated (clean accuracy) by running the
`eval.py` script in either the `jax` or `pytorch` folders. E.g.:
The following has been tested using Python 3.9.2.
Running `run.sh` will create and activate a virtualenv, install all necessary
dependencies, and run a test program to ensure that you can import all the
modules.
```
# Run from the parent directory.
sh adversarial_robustness/run.sh
```
To run the provided code, use this virtualenv:
```
source /tmp/adversarial_robustness_venv/bin/activate
```
You may want to edit `requirements.txt` before running `run.sh` if GPU support
is needed (e.g., use `jaxlib==0.1.67+cuda111`). See JAX's installation
[instructions](https://github.com/google/jax#installation) for more details.
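For example, a minimal sketch of such an edit (the exact `jaxlib` build shown
here is an assumption; pick whichever CUDA build matches your machine from the
JAX releases page):
```
# Hypothetical requirements.txt excerpt for a CUDA 11.1 machine.
--find-links https://storage.googleapis.com/jax-releases/jax_releases.html
jaxlib==0.1.67+cuda111
```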
### Using pre-trained models
Once downloaded, a model can be evaluated by running the `eval.py` script in
either the `jax` or `pytorch` folders. E.g.:
```
cd jax
@@ -58,7 +80,47 @@ python3 eval.py \
--ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
```
## Generated datasets
These models are also directly available within
[RobustBench](https://github.com/RobustBench/robustbench#model-zoo-quick-tour)'s
model zoo.
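For instance, the pre-trained weights can be fetched through RobustBench's
`load_model` helper. This is a sketch of RobustBench's documented interface; the
exact model key and argument names may differ across RobustBench versions, so
check the model zoo listing:
```
from robustbench.utils import load_model

# Downloads and builds the CIFAR-10 Linf WRN-70-16 CutMix+DDPM model (PyTorch).
model = load_model(model_name='Rebuffi2021Fixing_70_16_cutmix_ddpm',
                   dataset='cifar10', threat_model='Linf')
```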
### Training your own model
We also provide a training pipeline that reproduces results from both
publications. This pipeline uses [Jaxline](https://github.com/deepmind/jaxline)
and is written using [JAX](https://github.com/google/jax) and
[Haiku](https://github.com/deepmind/dm-haiku). To train a model, modify the
configuration in the `get_config()` function of `jax/experiment.py` and issue
the following command from within the virtualenv created above:
```
cd jax
python3 train.py --config=experiment.py
```
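Since Jaxline parses the configuration with `ml_collections` config flags,
individual fields can typically also be overridden on the command line instead
of editing `get_config()`. A sketch (the field path mirrors the one used in
`experiment_test.py`; the value is a placeholder):
```
cd jax
python3 train.py --config=experiment.py \
  --config.experiment_kwargs.config.training.batch_size=256
```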
The training pipeline can run with multiple worker machines and multiple devices
(either GPU or TPU). See [Jaxline](https://github.com/deepmind/jaxline) for more
details.
We do not provide a PyTorch implementation of our training pipeline. However,
you may find one on GitHub, e.g.,
[adversarial_robustness_pytorch](https://github.com/imrahulr/adversarial_robustness_pytorch)
(by Rahul Rade).
## Datasets
### Extracted dataset
Gowal et al. (2020) use samples extracted from
[TinyImages-80M](https://groups.csail.mit.edu/vision/TinyImages/).
Unfortunately, since then, the official TinyImages-80M dataset has been
withdrawn (due to the presence of offensive images). As such, we cannot provide
a download link to our extracted data until we have manually verified that all
extracted images are not offensive. If you want to reproduce our setup, consider
the generated datasets below. We are also happy to help, so feel free to reach
out to Sven Gowal directly.
### Generated datasets
Rebuffi et al. (2021) use samples generated by a Denoising Diffusion
Probabilistic Model [(DDPM; Ho et al., 2020)](https://arxiv.org/abs/2006.11239)
@@ -82,8 +144,8 @@ labels = npzfile['label']
## Citing this work
If you use this code, data or these models in your work, please cite the
relevant accompanying paper:
If you use this code (or any derived code), data or these models in your work,
please cite the relevant accompanying paper:
```
@article{gowal2020uncovering,
@@ -95,7 +157,7 @@ relevant accompanying paper:
}
```
or
and/or
```
@article{rebuffi2021fixing,

View File

@@ -0,0 +1,310 @@
# Copyright 2021 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adversarial attacks.
This file contains all the code necessary to create untargeted adversarial
attacks in JAX (within an l-infinity ball). For example, to create an untargeted
FGSM attack (with a single step), one can do the following:
```
import attacks
epsilon = 8/255 # Perturbation radius for inputs between 0 and 1.
fgsm_attack = attacks.UntargetedAttack(
attacks.PGD(
attacks.IteratedFGSM(epsilon),
num_steps=1,
initialize_fn=attacks.linf_initialize_fn(epsilon),
project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
loss_fn=attacks.untargeted_cross_entropy)
```
Just as elegantly, one can specify an adversarial attack on KL-divergence
to a target distribution (using 10 steps with Adam and a piecewise constant step
schedule):
```
kl_attack_with_adam = attacks.UntargetedAttack(
attacks.PGD(
attacks.Adam(optax.piecewise_constant_schedule(
init_value=.1,
boundaries_and_scales={5: .1})),
num_steps=10,
initialize_fn=attacks.linf_initialize_fn(epsilon),
project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
loss_fn=attacks.untargeted_kl_divergence)
```
The attack instances can be used later on to build adversarial examples:
```
my_model = ... # Model. We assume that 'my_model(.)' returns logits.
clean_images, image_labels = ... # Batch of images and associated labels.
rng = jax.random.PRNGKey(0) # A random generator state.
adversarial_images = fgsm_attack(my_model, rng, clean_images, image_labels)
```
See `experiment.py` or `eval.py` for more examples.
This file contains the following components:
* Losses:
* untargeted_cross_entropy: minimizes the likelihood of the label class.
* untargeted_kl_divergence: maximizes the KL-divergence of the predictions with
a target distribution.
* untargeted_margin: maximizes the margin loss (distance from the highest
non-true logits to the label class logit)
* Step optimizers:
* SGD: Stochastic Gradient Descent.
* IteratedFGSM: Also called BIM (see https://arxiv.org/pdf/1607.02533).
* Adam: See https://arxiv.org/pdf/1412.6980.
* Initialization and projection functions:
* linf_initialize_fn: Initialize function for l-infinity attacks.
* linf_project_fn: Projection function for l-infinity attacks.
* Projected Gradient Descent (PGD):
* PGD: Runs Projected Gradient Descent using the specified optimizer,
initialization and projection functions for a given number of steps.
* Untargeted attack:
* UntargetedAttack: Combines PGD and a specific loss function to find
adversarial examples.
"""
import functools
import inspect
from typing import Callable, Optional, Tuple, Union
import chex
import haiku as hk
import jax
import jax.numpy as jnp
import optax
ModelFn = Callable[[chex.Array], chex.Array]
LossFn = Callable[[chex.Array], chex.Array]
ClassificationLossFn = Callable[[chex.Array, chex.Array], chex.Array]
OptimizeFn = Callable[[LossFn, chex.PRNGKey, chex.Array], chex.Array]
NormalizeFn = Callable[[chex.Array], chex.Array]
InitializeFn = Callable[[chex.PRNGKey, chex.Array], chex.Array]
ProjectFn = Callable[[chex.Array, chex.Array], chex.Array]
def untargeted_cross_entropy(logits: chex.Array,
labels: chex.Array) -> chex.Array:
"""Maximize the cross-entropy of the true class (make it less likely)."""
num_classes = logits.shape[-1]
log_probs = jax.nn.log_softmax(logits)
return jnp.sum(
hk.one_hot(labels, num_classes).astype(logits.dtype) * log_probs, axis=-1)
def untargeted_kl_divergence(logits: chex.Array,
label_probs: chex.Array) -> chex.Array:
"""Maximize the KL divergence between logits and label distribution."""
# We are explicitly maximizing the cross-entropy, as this is equivalent to
# maximizing the KL divergence (when `label_probs` does not depend
# on the values that produce `logits`).
log_probs = jax.nn.log_softmax(logits)
return jnp.sum(label_probs * log_probs, axis=-1)
def untargeted_margin(logits: chex.Array,
labels: chex.Array) -> chex.Array:
"""Make the highest non-correct logits higher than the true class logits."""
batch_size = logits.shape[0]
num_classes = logits.shape[-1]
label_logits = logits[jnp.arange(batch_size), labels]
logit_mask = hk.one_hot(labels, num_classes).astype(logits.dtype)
highest_logits = jnp.max(logits - 1e8 * logit_mask, axis=-1)
return label_logits - highest_logits
class UntargetedAttack:
"""Performs an untargeted attack."""
def __init__(self,
optimize_fn: OptimizeFn,
loss_fn: ClassificationLossFn = untargeted_cross_entropy):
"""Creates an untargeted attack.
Args:
optimize_fn: An `Optimizer` instance or any callable that takes
a loss function and an initial input and outputs a new input that
minimizes the loss function.
loss_fn: `loss_fn` is a surrogate loss. Its goal should be to make the true
class less likely than any other class. Typical options for `loss_fn`
are `untargeted_cross_entropy` or `untargeted_margin`.
"""
self._optimize_fn = optimize_fn
self._loss_fn = loss_fn
def __call__(self,
logits_fn: ModelFn,
rng: chex.PRNGKey,
inputs: chex.Array,
labels: chex.Array) -> chex.Array:
"""Returns adversarial inputs."""
def _loss_fn(x):
return self._loss_fn(logits_fn(x), labels)
return self._optimize_fn(_loss_fn, rng, inputs)
# Convenience functions to detect the type of inputs required by the loss.
def expects_labels(self):
return 'labels' in inspect.getfullargspec(self._loss_fn).args
def expects_probabilities(self):
return 'label_probs' in inspect.getfullargspec(self._loss_fn).args
class StepOptimizer:
"""Makes a single gradient step that minimizes a loss function."""
def __init__(self,
gradient_transformation: optax.GradientTransformation):
self._gradient_transformation = gradient_transformation
def init(self,
loss_fn: LossFn,
x: chex.Array) -> optax.OptState:
self._loss_fn = loss_fn
return self._gradient_transformation.init(x)
def minimize(
self,
x: chex.Array,
state: optax.OptState) -> Tuple[chex.Array, chex.Array, optax.OptState]:
"""Performs a single minimization step."""
g, loss = gradients_fn(self._loss_fn, x)
if g is None:
raise ValueError('loss_fn does not depend on input.')
updates, state = self._gradient_transformation.update(g, state, x)
return optax.apply_updates(x, updates), loss, state
class SGD(StepOptimizer):
"""Vanilla gradient descent optimizer."""
def __init__(self,
learning_rate_fn: Union[float, int, optax.Schedule],
normalize_fn: Optional[NormalizeFn] = None):
# Accept schedules, as well as scalar values.
if isinstance(learning_rate_fn, (float, int)):
lr = float(learning_rate_fn)
learning_rate_fn = lambda _: lr
# Normalization.
def update_fn(updates, state, params=None):
del params
updates = jax.tree_map(normalize_fn or (lambda x: x), updates)
return updates, state
gradient_transformation = optax.chain(
optax.GradientTransformation(lambda _: optax.EmptyState(), update_fn),
optax.scale_by_schedule(learning_rate_fn),
optax.scale(-1.))
super(SGD, self).__init__(gradient_transformation)
class IteratedFGSM(SGD):
"""L-infinity normalized steps."""
def __init__(self,
learning_rate_fn: Union[float, int, optax.Schedule]):
super(IteratedFGSM, self).__init__(learning_rate_fn, jnp.sign)
class Adam(StepOptimizer):
"""The Adam optimizer defined in https://arxiv.org/abs/1412.6980."""
def __init__(
self,
learning_rate_fn: Union[float, int, optax.Schedule],
normalize_fn: Optional[NormalizeFn] = None,
beta1: float = .9,
beta2: float = .999,
epsilon: float = 1e-9):
# Accept schedules, as well as scalar values.
if isinstance(learning_rate_fn, (float, int)):
lr = float(learning_rate_fn)
learning_rate_fn = lambda _: lr
# Normalization.
def update_fn(updates, state, params=None):
del params
updates = jax.tree_map(normalize_fn or (lambda x: x), updates)
return updates, state
gradient_transformation = optax.chain(
optax.GradientTransformation(lambda _: optax.EmptyState(), update_fn),
optax.scale_by_adam(b1=beta1, b2=beta2, eps=epsilon),
optax.scale_by_schedule(learning_rate_fn),
optax.scale(-1.))
super(Adam, self).__init__(gradient_transformation)
class PGD:
"""Runs Project Gradient Descent (see https://arxiv.org/pdf/1706.06083)."""
def __init__(self,
optimizer: StepOptimizer,
num_steps: int,
initialize_fn: Optional[InitializeFn] = None,
project_fn: Optional[ProjectFn] = None):
self._optimizer = optimizer
if initialize_fn is None:
initialize_fn = lambda rng, x: x
self._initialize_fn = initialize_fn
if project_fn is None:
project_fn = lambda x, origin_x: x
self._project_fn = project_fn
self._num_steps = num_steps
def __call__(self,
loss_fn: LossFn,
rng: chex.PRNGKey,
x: chex.Array) -> chex.Array:
def _optimize(rng, x):
"""Optimizes loss_fn when keep_best is False."""
def body_fn(_, inputs):
opt_state, current_x = inputs
current_x, _, opt_state = self._optimizer.minimize(current_x, opt_state)
current_x = self._project_fn(current_x, x)
return opt_state, current_x
opt_state = self._optimizer.init(loss_fn, x)
current_x = self._project_fn(self._initialize_fn(rng, x), x)
_, current_x = jax.lax.fori_loop(0, self._num_steps, body_fn,
(opt_state, current_x))
return current_x
return jax.lax.stop_gradient(_optimize(rng, x))
def linf_project_fn(epsilon: float, bounds: Tuple[float, float]) -> ProjectFn:
def project_fn(x, origin_x):
dx = jnp.clip(x - origin_x, -epsilon, epsilon)
return jnp.clip(origin_x + dx, bounds[0], bounds[1])
return project_fn
def linf_initialize_fn(epsilon: float) -> InitializeFn:
def initialize_fn(rng, x):
return x + jax.random.uniform(rng, x.shape, minval=-epsilon,
maxval=epsilon).astype(x.dtype)
return initialize_fn
def gradients_fn(loss_fn: LossFn,
x: chex.Array) -> Tuple[chex.Array, chex.Array]:
"""Returns the analytical gradient as computed by `jax.grad`."""
@functools.partial(jax.grad, has_aux=True)
def grad_reduced_loss_fn(x):
loss = loss_fn(x)
return jnp.sum(loss), loss
return grad_reduced_loss_fn(x)

View File

@@ -0,0 +1,180 @@
# Copyright 2021 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datasets."""
from typing import Sequence
import chex
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2471, 0.2435, 0.2616)
_CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
_CIFAR100_STD = (0.2673, 0.2564, 0.2762)
_DATA_URL = 'https://storage.googleapis.com/dm-adversarial-robustness/'
_ALLOWED_FILES = ('cifar10_ddpm.npz',)
_WEBPAGE = ('https://github.com/deepmind/deepmind-research/tree/master/'
'adversarial_robustness')
def cifar10_preprocess(mode: str = 'train'):
"""Preprocessing functions for CIFAR-10."""
def _preprocess_fn_train(example):
"""Preprocessing of CIFAR-10 images for training."""
image = example['image']
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = _random_jitter(image, pad=4, crop=32)
image = tf.image.random_flip_left_right(image)
label = tf.cast(example['label'], tf.int32)
return {'image': image, 'label': label}
def _preprocess_fn_test(example):
"""Preprocessing of CIFAR-10 images for testing."""
image = example['image']
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
label = tf.cast(example['label'], tf.int32)
return {'image': image, 'label': label}
return _preprocess_fn_train if mode == 'train' else _preprocess_fn_test
def cifar10_normalize(image: chex.Array) -> chex.Array:
means = jnp.array(_CIFAR10_MEAN, dtype=image.dtype)
stds = jnp.array(_CIFAR10_STD, dtype=image.dtype)
return (image - means) / stds
def mnist_normalize(image: chex.Array) -> chex.Array:
image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant',
constant_values=0)
return (image - .5) * 2.
def cifar100_normalize(image: chex.Array) -> chex.Array:
means = jnp.array(_CIFAR100_MEAN, dtype=image.dtype)
stds = jnp.array(_CIFAR100_STD, dtype=image.dtype)
return (image - means) / stds
def load_cifar10(batch_sizes: Sequence[int],
subset: str = 'train',
is_training: bool = True,
drop_remainder: bool = True,
repeat: int = 1) -> tf.data.Dataset:
"""Loads CIFAR-10."""
if subset == 'train':
ds = tfds.load(name='cifar10', split=tfds.Split.TRAIN)
# In Gowal et al. (https://arxiv.org/abs/2010.03593) and Rebuffi et al.
# (https://arxiv.org/abs/2103.01946), we also keep a separate validation
# subset for early stopping and would run: ds = ds.skip(1_024).
elif subset == 'test':
ds = tfds.load(name='cifar10', split=tfds.Split.TEST)
else:
raise ValueError('Unknown subset: "{}"'.format(subset))
ds = ds.cache()
if is_training:
ds = ds.repeat()
ds = ds.shuffle(buffer_size=50_000, seed=0)
ds = _repeat_batch(batch_sizes, ds, repeat=repeat)
ds = ds.map(cifar10_preprocess('train' if is_training else 'test'),
num_parallel_calls=tf.data.AUTOTUNE)
for batch_size in reversed(batch_sizes):
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
return ds.prefetch(tf.data.AUTOTUNE)
def load_extra(batch_sizes: Sequence[int],
path_npz: str,
is_training: bool = True,
drop_remainder: bool = True) -> tf.data.Dataset:
"""Loads extra data from a given path."""
if not tf.io.gfile.exists(path_npz):
if path_npz in _ALLOWED_FILES:
path_npz = tf.keras.utils.get_file(path_npz, _DATA_URL + path_npz)
else:
raise ValueError(f'Extra data not found ({path_npz}). See {_WEBPAGE} for '
'more details.')
with tf.io.gfile.GFile(path_npz, 'rb') as fp:
npzfile = np.load(fp)
data = {'image': npzfile['image'], 'label': npzfile['label']}
with tf.device('/device:cpu:0'):  # Prevent allocation from happening on the GPU.
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.cache()
if is_training:
ds = ds.repeat()
ds = ds.shuffle(buffer_size=50_000, seed=jax.host_id())
ds = ds.map(cifar10_preprocess('train' if is_training else 'test'),
num_parallel_calls=tf.data.AUTOTUNE)
for batch_size in reversed(batch_sizes):
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
return ds.prefetch(tf.data.AUTOTUNE)
def load_dummy_data(batch_sizes: Sequence[int],
is_training: bool = True,
**unused_kwargs) -> tf.data.Dataset:
"""Loads fictive data (use this function when testing)."""
ds = tf.data.Dataset.from_tensor_slices({
'image': np.zeros((1, 32, 32, 3), np.float32),
'label': np.zeros((1,), np.int32),
})
ds = ds.repeat()
if not is_training:
total_batch_size = np.prod(batch_sizes)
ds = ds.take(total_batch_size)
ds = ds.map(cifar10_preprocess('train' if is_training else 'test'),
num_parallel_calls=tf.data.AUTOTUNE)
for batch_size in reversed(batch_sizes):
ds = ds.batch(batch_size, drop_remainder=True)
return ds.prefetch(tf.data.AUTOTUNE)
def _random_jitter(image: tf.Tensor, pad: int, crop: int) -> tf.Tensor:
shape = image.shape.as_list()
image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]])
image = tf.image.random_crop(image, size=[crop, crop, shape[2]])
return image
def _repeat_batch(batch_sizes: Sequence[int],
ds: tf.data.Dataset,
repeat: int = 1) -> tf.data.Dataset:
"""Tiles the inner most batch dimension."""
if repeat <= 1:
return ds
if batch_sizes[-1] % repeat != 0:
raise ValueError(f'The last element of `batch_sizes` ({batch_sizes}) must '
f'be divisible by `repeat` ({repeat}).')
# Perform regular batching with reduced number of elements.
for i, batch_size in enumerate(reversed(batch_sizes)):
ds = ds.batch(batch_size // repeat if i == 0 else batch_size,
drop_remainder=True)
# Repeat batch.
fn = lambda x: tf.repeat(x, repeats=repeat, axis=len(batch_sizes) - 1)
def repeat_inner_batch(example):
return jax.tree_map(fn, example)
ds = ds.map(repeat_inner_batch,
num_parallel_calls=tf.data.AUTOTUNE)
# Unbatch.
for _ in batch_sizes:
ds = ds.unbatch()
return ds

View File

@@ -14,14 +14,19 @@
"""Evaluates a JAX checkpoint on CIFAR-10/100 or MNIST."""
import functools
from absl import app
from absl import flags
import haiku as hk
import numpy as np
import optax
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tqdm
from adversarial_robustness.jax import attacks
from adversarial_robustness.jax import datasets
from adversarial_robustness.jax import model_zoo
_CKPT = flags.DEFINE_string(
@@ -48,14 +53,14 @@ def main(unused_argv):
# Create dataset.
if _DATASET.value == 'mnist':
_, data_test = tf.keras.datasets.mnist.load_data()
normalize_fn = model_zoo.mnist_normalize
normalize_fn = datasets.mnist_normalize
elif _DATASET.value == 'cifar10':
_, data_test = tf.keras.datasets.cifar10.load_data()
normalize_fn = model_zoo.cifar10_normalize
normalize_fn = datasets.cifar10_normalize
else:
assert _DATASET.value == 'cifar100'
_, data_test = tf.keras.datasets.cifar100.load_data()
normalize_fn = model_zoo.cifar100_normalize
normalize_fn = datasets.cifar100_normalize
# Create model.
@hk.transform_with_state
@@ -83,22 +88,53 @@ def main(unused_argv):
else:
params, state = np.load(_CKPT.value, allow_pickle=True)
# Create adversarial attack. We run a PGD-40 attack with margin loss.
epsilon = 8 / 255
eval_attack = attacks.UntargetedAttack(
attacks.PGD(
attacks.Adam(learning_rate_fn=optax.piecewise_constant_schedule(
init_value=.1,
boundaries_and_scales={20: .1, 30: .01})),
num_steps=40,
initialize_fn=attacks.linf_initialize_fn(epsilon),
project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
loss_fn=attacks.untargeted_margin)
def logits_fn(x, rng):
return model_fn.apply(params, state, rng, x)[0]
# Evaluation.
correct = 0
adv_correct = 0
total = 0
batch_count = 0
total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1, _NUM_BATCHES.value)
for images, labels in tqdm.tqdm(test_loader, total=total_batches):
outputs = model_fn.apply(params, state, next(rng_seq), images)[0]
rng = next(rng_seq)
loop_logits_fn = functools.partial(logits_fn, rng=rng)
# Clean examples.
outputs = loop_logits_fn(images)
correct += (np.argmax(outputs, 1) == labels).sum().item()
# Adversarial examples.
adv_images = eval_attack(loop_logits_fn, next(rng_seq), images, labels)
outputs = loop_logits_fn(adv_images)
predicted = np.argmax(outputs, 1)
adv_correct += (predicted == labels).sum().item()
total += labels.shape[0]
correct += (predicted == labels).sum().item()
batch_count += 1
if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
break
print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')
print(f'Robust accuracy: {100 * adv_correct / total:.2f}%')
if __name__ == '__main__':
flags.mark_flag_as_required('ckpt')
try:
tf.config.set_visible_devices([], 'GPU') # Prevent TF from using the GPU.
except tf.errors.NotFoundError:
pass
app.run(main)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quick script to test that experiment can import and run."""
from absl import app
import jax
import jax.numpy as jnp
from jaxline import utils as jl_utils
from adversarial_robustness.jax import experiment
@jl_utils.disable_pmap_jit
def test_experiment(unused_argv):
"""Tests the main experiment."""
config = experiment.get_config()
exp_config = config.experiment_kwargs.config
exp_config.dry_run = True
exp_config.emulated_workers = 0
exp_config.training.batch_size = 2
exp_config.evaluation.batch_size = 2
exp_config.model.kwargs.depth = 10
exp_config.model.kwargs.width = 1
xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
bcast = jax.pmap(lambda x: x)
global_step = bcast(jnp.zeros(jax.local_device_count()))
rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
print('Taking a single experiment step for test purposes!')
result = xp.step(global_step, rng)
print(f'Step successfully taken, resulting metrics are {result}')
if __name__ == '__main__':
app.run(test_experiment)

View File

@@ -14,19 +14,14 @@
"""WideResNet implementation in JAX using Haiku."""
from typing import Any, Mapping, Optional, Text
from typing import Any, Dict, Optional
import chex
import haiku as hk
import jax
import jax.numpy as jnp
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2471, 0.2435, 0.2616)
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)
class _WideResNetBlock(hk.Module):
"""Block of a WideResNet."""
@@ -89,16 +84,16 @@ class WideResNet(hk.Module):
num_classes: int = 10,
depth: int = 28,
width: int = 10,
activation: Text = 'relu',
norm_args: Optional[Mapping[Text, Any]] = None,
name: Optional[Text] = None):
activation: str = 'relu',
norm_args: Optional[Dict[str, Any]] = None,
name: Optional[str] = None):
super(WideResNet, self).__init__(name=name)
if (depth - 4) % 6 != 0:
raise ValueError('depth should be 6n+4.')
self._activation = getattr(jax.nn, activation)
if norm_args is None:
norm_args = {
'create_offset': False,
'create_offset': True,
'create_scale': True,
'decay_rate': .99,
}
@@ -113,6 +108,7 @@ class WideResNet(hk.Module):
**norm_args)
self._linear = hk.Linear(
num_classes,
w_init=jnp.zeros,
name='logits')
blocks_per_layer = (depth - 4) // 6
@@ -132,7 +128,7 @@ class WideResNet(hk.Module):
name='resnet_lay_{}_block_{}'.format(layer_num, i)))
self._blocks.append(blocks_of_layer)
def __call__(self, inputs, **norm_kwargs):
def __call__(self, inputs: chex.Array, **norm_kwargs) -> chex.Array:
net = inputs
net = self._conv(net)
@@ -145,21 +141,3 @@ class WideResNet(hk.Module):
net = jnp.mean(net, axis=[1, 2])
return self._linear(net)
def mnist_normalize(image: jnp.array) -> jnp.array:
image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant',
constant_values=0)
return (image - .5) * 2.
def cifar10_normalize(image: jnp.array) -> jnp.array:
means = jnp.array(CIFAR10_MEAN, dtype=image.dtype)
stds = jnp.array(CIFAR10_STD, dtype=image.dtype)
return (image - means) / stds
def cifar100_normalize(image: jnp.array) -> jnp.array:
means = jnp.array(CIFAR100_MEAN, dtype=image.dtype)
stds = jnp.array(CIFAR100_STD, dtype=image.dtype)
return (image - means) / stds

View File

@@ -0,0 +1,32 @@
# Copyright 2021 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs a JAXline experiment to perform robust adversarial training."""
import functools
from absl import app
from absl import flags
from jaxline import platform
import tensorflow.compat.v2 as tf
from adversarial_robustness.jax import experiment
if __name__ == '__main__':
flags.mark_flag_as_required('config')
try:
tf.config.set_visible_devices([], 'GPU') # Prevent TF from using the GPU.
except tf.errors.NotFoundError:
pass
app.run(functools.partial(platform.main, experiment.Experiment))

View File

@@ -0,0 +1,197 @@
# Copyright 2021 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions."""
import re
from typing import Optional, Sequence, Tuple
import chex
import einops
import haiku as hk
import jax
import jax.numpy as jnp
import optax
def get_cosine_schedule(
max_learning_rate: float,
total_steps: int,
warmup_steps: int = 0) -> optax.Schedule:
"""Builds a cosine decay schedule with initial warm-up."""
if total_steps < warmup_steps:
return optax.linear_schedule(init_value=0., end_value=max_learning_rate,
transition_steps=warmup_steps)
return optax.join_schedules([
optax.linear_schedule(init_value=0., end_value=max_learning_rate,
transition_steps=warmup_steps),
optax.cosine_decay_schedule(init_value=max_learning_rate,
decay_steps=total_steps - warmup_steps),
], [warmup_steps])
def sgd_momentum(learning_rate_fn: optax.Schedule,
momentum: float = 0.,
nesterov: bool = False) -> optax.GradientTransformation:
return optax.chain(
optax.trace(decay=momentum, nesterov=nesterov),
optax.scale_by_schedule(learning_rate_fn),
optax.scale(-1.))
def cross_entropy(logits: chex.Array, labels: chex.Array) -> chex.Array:
return -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
def kl_divergence(q_logits: chex.Array,
p_logits: chex.Array) -> chex.Array:
"""Compute the KL divergence."""
p_probs = jax.nn.softmax(p_logits)
return cross_entropy(q_logits, p_probs) - cross_entropy(p_logits, p_probs)
def accuracy(logits: chex.Array, labels: chex.Array) -> chex.Array:
predicted_label = jnp.argmax(logits, axis=-1)
correct = jnp.equal(predicted_label, labels).astype(jnp.float32)
return jnp.sum(correct, axis=0) / logits.shape[0]
def weight_decay(params: hk.Params,
regex_match: Optional[Sequence[str]] = None,
regex_ignore: Optional[Sequence[str]] = None) -> chex.Array:
"""Computes the L2 regularization loss."""
if regex_match is None:
regex_match = ('.*w$', '.*b$')
if regex_ignore is None:
regex_ignore = ('.*batchnorm.*',)
l2_norm = 0.
for mod_name, mod_params in params.items():
for param_name, param in mod_params.items():
name = '/'.join([mod_name, param_name])
if (regex_match and
all(not re.match(regex, name) for regex in regex_match)):
continue
if (regex_ignore and
any(re.match(regex, name) for regex in regex_ignore)):
continue
l2_norm += jnp.sum(jnp.square(param))
return .5 * l2_norm
def ema_update(step: chex.Array,
avg_params: chex.ArrayTree,
new_params: chex.ArrayTree,
decay_rate: float = 0.99,
warmup_steps: int = 0,
dynamic_decay: bool = True) -> chex.ArrayTree:
"""Applies an exponential moving average."""
factor = (step >= warmup_steps).astype(jnp.float32)
if dynamic_decay:
# Uses TF-style EMA.
delta = step - warmup_steps
decay = jnp.minimum(decay_rate, (1. + delta) / (10. + delta))
else:
decay = decay_rate
decay *= factor
def _weighted_average(p1, p2):
d = decay.astype(p1.dtype)
return (1 - d) * p1 + d * p2
return jax.tree_multimap(_weighted_average, new_params, avg_params)
def cutmix(rng: chex.PRNGKey,
images: chex.Array,
labels: chex.Array,
alpha: float = 1.,
beta: float = 1.,
split: int = 1) -> Tuple[chex.Array, chex.Array]:
"""Composing two images by inserting a patch into another image."""
batch_size, height, width, _ = images.shape
if split > 1:
split_batch_size = batch_size // split
# Masking bounding box.
box_rng, lam_rng, rng = jax.random.split(rng, num=3)
lam = jax.random.beta(lam_rng, a=alpha, b=beta, shape=())
cut_rat = jnp.sqrt(1. - lam)
cut_w = jnp.array(width * cut_rat, dtype=jnp.int32)
cut_h = jnp.array(height * cut_rat, dtype=jnp.int32)
box_coords = _random_box(box_rng, height, width, cut_h, cut_w)
# Adjust lambda.
lam = 1. - (box_coords[2] * box_coords[3] / (height * width))
idx = jax.random.permutation(rng, split_batch_size)
def _cutmix(x, y):
images_a = x
images_b = x[idx, :, :, :]
y = lam * y + (1. - lam) * y[idx, :]
x = _compose_two_images(images_a, images_b, box_coords)
return x, y
if split <= 1:
return _cutmix(images, labels)
# Apply CutMix separately on each sub-batch. This reverses the effect of
# `repeat` in datasets.
images = einops.rearrange(images, '(b1 b2) ... -> b1 b2 ...', b2=split)
labels = einops.rearrange(labels, '(b1 b2) ... -> b1 b2 ...', b2=split)
images, labels = jax.vmap(_cutmix, in_axes=1, out_axes=1)(images, labels)
images = einops.rearrange(images, 'b1 b2 ... -> (b1 b2) ...', b2=split)
labels = einops.rearrange(labels, 'b1 b2 ... -> (b1 b2) ...', b2=split)
return images, labels
def _random_box(rng: chex.PRNGKey,
height: chex.Numeric,
width: chex.Numeric,
cut_h: chex.Array,
cut_w: chex.Array) -> chex.Array:
"""Sample a random box of shape [cut_h, cut_w]."""
height_rng, width_rng = jax.random.split(rng)
minval_h = 0
minval_w = 0
maxval_h = height
maxval_w = width
i = jax.random.randint(
height_rng, shape=(), minval=minval_h, maxval=maxval_h, dtype=jnp.int32)
j = jax.random.randint(
width_rng, shape=(), minval=minval_w, maxval=maxval_w, dtype=jnp.int32)
bby1 = jnp.clip(i - cut_h // 2, 0, height)
bbx1 = jnp.clip(j - cut_w // 2, 0, width)
h = jnp.clip(i + cut_h // 2, 0, height) - bby1
w = jnp.clip(j + cut_w // 2, 0, width) - bbx1
return jnp.array([bby1, bbx1, h, w])
def _compose_two_images(images: chex.Array,
image_permutation: chex.Array,
bbox: chex.Array) -> chex.Array:
"""Inserting the second minibatch into the first at the target locations."""
def _single_compose_two_images(image1, image2):
height, width, _ = image1.shape
mask = _window_mask(bbox, (height, width))
return image1 * (1. - mask) + image2 * mask
return jax.vmap(_single_compose_two_images)(images, image_permutation)
def _window_mask(destination_box: chex.Array,
size: Tuple[int, int]) -> jnp.ndarray:
"""Mask a part of the image."""
height_offset, width_offset, h, w = destination_box
h_range = jnp.reshape(jnp.arange(size[0]), [size[0], 1, 1])
w_range = jnp.reshape(jnp.arange(size[1]), [1, size[1], 1])
return jnp.logical_and(
jnp.logical_and(height_offset <= h_range,
h_range < height_offset + h),
jnp.logical_and(width_offset <= w_range,
w_range < width_offset + w)).astype(jnp.float32)

View File

@@ -0,0 +1,28 @@
# PyTorch evaluation
We provide PyTorch evaluation code for convenience. If you developed a version
of our training pipeline for PyTorch, please let us know so that we can link it
from here.
Here are known PyTorch implementations of our training pipeline:
* https://github.com/imrahulr/adversarial_robustness_pytorch (by Rahul Rade)
Here are a few considerations when reproducing our training pipeline in PyTorch
(a short PyTorch sketch of the batch-normalization settings follows this list).
As opposed to the [RST](https://github.com/yaircarmon/semisup-adv) code
(provided by Carmon et al.):
* We set the batch normalization decay to 0.99 (instead of 0.9).
* We do not apply weight decay (l2 regularization) to the batch normalization
scale and offset.
* We use Haiku's default initialization for all layers (except the last, which
is initialized with zeros).
* The PGD attack used during training initializes its starting point uniformly
at random within the l-p norm ball.
* We run the attack over the local batch statistics (rather than the evaluation
statistics).
* We update batch normalization statistics from adversarial examples only
(rather than from both clean and adversarial examples).
* We use a 10-epoch warm-up for our learning rate schedule.
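As a rough illustration of the batch-normalization points above, here is a
minimal PyTorch sketch (not the authors' code; the `BatchNorm2d` check, learning
rate, and weight-decay value are placeholders):
```
import torch

def configure_batch_norm(model: torch.nn.Module) -> None:
  # A batch-norm decay of 0.99 corresponds to momentum = 1 - 0.99 = 0.01.
  for module in model.modules():
    if isinstance(module, torch.nn.BatchNorm2d):
      module.momentum = 0.01

def make_optimizer(model: torch.nn.Module) -> torch.optim.Optimizer:
  # Exclude batch-norm scale/offset from weight decay; keep it elsewhere.
  bn_params, other_params = [], []
  for module in model.modules():
    params = list(module.parameters(recurse=False))
    if isinstance(module, torch.nn.BatchNorm2d):
      bn_params.extend(params)
    else:
      other_params.extend(params)
  return torch.optim.SGD(
      [{'params': other_params, 'weight_decay': 5e-4},
       {'params': bn_params, 'weight_decay': 0.}],
      lr=0.1, momentum=0.9, nesterov=True)
```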

View File

@@ -1,51 +1,64 @@
absl-py==0.10.0
# Direct dependencies.
absl-py==0.12.0
chex==0.0.7
dm-haiku==0.0.4
einops==0.3.0
jax==0.2.16
jaxlib==0.1.68
jaxline==0.0.3
ml-collections==0.1.0
numpy==1.19.5
optax==0.0.8
tensorflow==2.5.0
tensorflow-datasets==4.3.0
torch==1.9.0
torchvision==0.10.0
tqdm==4.61.1
# Transitive dependencies.
astunparse==1.6.3
attrs==20.3.0
cachetools==4.1.1
certifi==2020.11.8
chardet==3.0.4
dataclasses==0.6
dill==0.3.3
dm-haiku==0.0.3
attrs==21.2.0
cachetools==4.2.2
certifi==2021.5.30
chardet==4.0.0
contextlib2==21.6.0
dill==0.3.4
dm-tree==0.1.6
flatbuffers==1.12
future==0.18.2
gast==0.3.3
google-auth==1.23.0
google-auth-oauthlib==0.4.2
gast==0.4.0
google-auth==1.32.0
google-auth-oauthlib==0.4.4
google-pasta==0.2.0
googleapis-common-protos==1.52.0
grpcio==1.33.2
h5py==2.10.0
googleapis-common-protos==1.53.0
grpcio==1.34.1
h5py==3.1.0
idna==2.10
importlib-resources==3.3.0
jax==0.2.6
jaxlib==0.1.57
keras-nightly==2.5.0.dev2021032900
Keras-Preprocessing==1.1.2
Markdown==3.3.3
numpy==1.18.5
oauthlib==3.1.0
Markdown==3.3.4
oauthlib==3.1.1
opt-einsum==3.3.0
Pillow==8.0.1
Pillow==8.2.0
pkg-resources==0.0.0
promise==2.3
protobuf==3.14.0
protobuf==3.17.3
pyasn1==0.4.8
pyasn1-modules==0.2.8
requests==2.25.0
PyYAML==5.4.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.6
scipy==1.5.4
rsa==4.7.2
scipy==1.7.0
six==1.15.0
tensorboard==2.4.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-datasets==4.1.0
tensorflow-estimator==2.3.0
tensorflow-metadata==0.25.0
tabulate==0.8.9
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow-estimator==2.5.0
tensorflow-metadata==1.1.0
termcolor==1.1.0
torch==1.7.0
torchvision==0.8.1
tqdm==4.53.0
toolz==0.11.1
typing-extensions==3.7.4.3
urllib3==1.26.2
Werkzeug==1.0.1
urllib3==1.26.6
Werkzeug==2.0.1
wrapt==1.12.1

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
set -euf -o pipefail # Stop at failure.
python3 -m venv /tmp/adversarial_robustness_venv
source /tmp/adversarial_robustness_venv/bin/activate
pip install -U pip
@@ -34,3 +36,9 @@ python3 -m adversarial_robustness.pytorch.eval \
--batch_size=1 \
--num_batches=1 \
--nouse_cuda
# We disable pmap/jit to avoid compilation during testing. Since the
# test only runs a single step, it would not benefit from such a compilation
# anyway.
python3 -m adversarial_robustness.jax.experiment_test \
--jaxline_disable_pmap_jit=True

View File

@@ -0,0 +1,123 @@
# One-shot StreetLearn and Memory & Planning Game
This repository contains code for the environments used in the paper
["Rapid Task-Solving in Novel Environments"](https://arxiv.org/abs/2006.03662)
by Sam Ritter, Ryan Faulkner, Laurent Sartran, Adam Santoro, Matt Botvinick and
David Raposo. It was published as a conference paper at ICLR 2021.
To cite this work:
```
@inproceedings{
ritter2021rapid,
title={Rapid Task-Solving in Novel Environments},
author={Samuel Ritter and Ryan Faulkner and Laurent Sartran and Adam Santoro
and Matthew Botvinick and David Raposo},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=F-mvpFpn_0q}
}
```
### Memory&Planning Game
The _Memory&Planning Game_ is a simple variation of the well-known _Memory
Game_, wherein players must remember the locations of cards in a grid. This
variation extends the challenge to require planning as well as remembering.
In the _Memory&Planning Game_, the agent occupies an environment consisting
of a grid of symbols (e.g. 4x4). The observation consists of two symbols — one
which corresponds to the agent's current location, and another that corresponds
to the "goal" the agent is tasked with navigating to. The agent can not see its
relative location with respect to other symbols in the grid. At each step the
agent selects one of 5 possible actions: _move left_, _move right_, _move up_,
_move down_, and _collect_. If the agent chooses the "collect" action when its
current location symbol matches the goal symbol, a reward of +1 is received.
Otherwise, the agent receives a reward of 0. At the beginning of each episode, a
new set of symbols is sampled, effectively inducing a new transition function.
The agent is allowed a fixed number of steps (e.g. 100) per episode to "collect"
as many goals as possible. Each time the agent collects a goal — which
corresponds to completing a task — a new goal is sampled and the transition
function stays fixed.
#### Example
The following code snippet shows an example of how to load the environment,
start a new episode and take a few random steps. A plot representing the current
state of the environment is displayed for each step.
```
import memory_planning_game
env = memory_planning_game.MemoryPlanningGame(4, seed=123)
_ = env.reset()
for _ in range(5):
timestep = env.take_random_action()
fig, ax = env.draw_maze()
```
![Memory & Planning Game environment](images/example_mpg.png)
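The environment follows the standard `dm_env` API, so actions can also be
selected explicitly rather than randomly; a minimal sketch (action indices
follow `MemoryPlanningGame.ACTION_NAMES`):
```
import numpy as np
import memory_planning_game

env = memory_planning_game.MemoryPlanningGame(4, seed=123)
timestep = env.reset()
collect = env.ACTION_NAMES.index('Collect')
timestep = env.step(collect)  # Reward is +1 only if position matches the goal.
print(timestep.reward, np.argmax(timestep.observation['position']))
```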
### One-Shot StreetLearn
The _One-Shot StreetLearn_ is a domain wherein environments are sampled as
neighborhoods from the _StreetLearn_ dataset (Mirowski et al., 2019). Tasks are
then sampled by selecting a position and orientation that the agent must
navigate to from its current location.
In _One-Shot StreetLearn_, the agent's observations consist of a representation
of the current state and a representation of the goal state. The agent receives
no other information from the environment. The available actions are _turn
right_, which orients the agent clockwise toward the next available direction of
motion from its current location; _turn left_, which does the same in the other
direction; and _move forward_, which moves the agent along the direction it is
facing to the next available location. In each episode, we sample a new
neighborhood with 5 intersections from one of 12 cities. To reduce the
exploration problem while keeping the planning difficulty constant, we removed
all the locations that were not intersections (i.e. that corresponded to
degree-2 nodes in the connectivity graph of the sampled neighbourhood).
Note: In the paper we used images to represent each location, taken from the
_StreetLearn_ dataset. In this codebase, to simplify the public release process,
we replace the images with one-hot vectors, to represent each location.
Every time the agent reaches a goal, a new starting- and goal-state pair is
sampled, initiating a new task, until the fixed episode step limit is reached
(e.g. 200 steps). A step that takes the agent to the goal state results in a
reward of +1. Any other step results in a reward of 0.
#### City graph datasets
From [this link](https://console.cloud.google.com/storage/browser/one_shot_streetlearn_graphs)
you can download the datasets containing the connectivity graphs of different
cities. You will need at least one of these datasets in order to build the
environment.
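Once downloaded, a city graph can also be inspected directly with `networkx`
(this is how `one_shot_streetlearn.py` reads it); a minimal sketch, assuming the
file was saved under `./datasets/` as in the example below:
```
import networkx as nx

graph = nx.read_gexf('./datasets/madrid_full_graph.gexf')
print(graph.number_of_nodes(), 'nodes,', graph.number_of_edges(), 'edges')
```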
#### Example
The following code snippet shows an example of how to load the environment for
an example city (in this case, Madrid), start a new episode and take a few
random steps. A plot representing the current state of the environment is
displayed for each step.
```
import numpy as np
import one_shot_streetlearn
region = 'madrid'
sl_config = {
'dataset_path': f'./datasets/{region}_full_graph.gexf',
'max_episode_steps': 200,
'num_junctions': 5,
}
env = one_shot_streetlearn.OneShotStreetLearn(**sl_config)
_ = env.reset()
for _ in range(4):
rnd_action = np.random.randint(env.NUM_ACTIONS)
timestep = env.step(rnd_action)
fig, ax = env.draw_subgraph()
```
![One-shot StreetLearn environment](images/example_osl.png)

Binary image file not shown (139 KiB).

Binary image file not shown (78 KiB).

View File

@@ -0,0 +1,184 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Memory & Planning Game environment."""
import string
import dm_env
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
class MemoryPlanningGame(dm_env.Environment):
"""Memory & Planning Game environment."""
ACTION_NAMES = ['Up', 'Down', 'Left', 'Right', 'Collect']
NUM_ACTIONS = len(ACTION_NAMES)
DIRECTIONS = [
(0, 1), # Up
(0, -1), # Down
(-1, 0), # Left
(1, 0), # Right
(0, 0), # Collect
]
def __init__(self,
maze_size=4,
max_episode_steps=100,
target_reward=1.,
per_step_reward=0.,
random_respawn=False,
seed=None):
"""The Memory & Planning Game environment.
Args:
maze_size: (int) size of the maze dimension.
max_episode_steps: (int) number of steps per episode.
target_reward: (float) reward value of the target.
per_step_reward: (float) reward/cost of taking a step.
random_respawn: (bool) whether the agent respawns in a random location
upon collecting the goal.
seed: (int or None) seed for random number generator.
"""
self._maze_size = maze_size
self._num_labels = maze_size * maze_size
# The graph itself is the same across episodes, but the node labels will be
# randomly sampled in each episode.
self._graph = nx.grid_2d_graph(
self._maze_size, self._maze_size, periodic=True)
self._max_episode_steps = max_episode_steps
self._target_reward = target_reward
self._per_step_reward = per_step_reward
self._random_respawn = random_respawn
self._rng = np.random.RandomState(seed)
def _one_hot(self, node):
one_hot_vector = np.zeros([self._num_labels], dtype=np.int32)
one_hot_vector[self._labels[node]] = 1
return one_hot_vector
def step(self, action):
# If previous step was the last step of an episode, reset.
if self._needs_reset:
return self.reset()
# Increment step count and check if it's the last step of the episode.
self._episode_steps += 1
if self._episode_steps >= self._max_episode_steps:
self._needs_reset = True
transition = dm_env.termination
else:
transition = dm_env.transition
# Recompute agent's position given the selected action.
direction = self.DIRECTIONS[action]
self._position = tuple(
(np.array(self._position) + np.array(direction)) % self._maze_size)
self._previous_action = self.ACTION_NAMES[action]
# Get reward if agent is over the goal location and the selected action is
# `collect`.
if self._position == self._goal and self.ACTION_NAMES[action] == 'Collect':
reward = self._target_reward
self._set_new_goal()
else:
reward = self._per_step_reward
self._episode_reward += reward
return transition(reward, self._observation())
def _observation(self):
return {
'position': np.array(self._one_hot(self.position), dtype=np.int32),
'goal': np.array(self._one_hot(self.goal), dtype=np.int32),
}
def observation_spec(self):
return {
'position': dm_env.specs.Array(
shape=(self._num_labels,), dtype=np.int32, name='position'),
'goal': dm_env.specs.Array(
shape=(self._num_labels,), dtype=np.int32, name='goal'),
}
def action_spec(self):
return dm_env.specs.DiscreteArray(self.NUM_ACTIONS)
def take_random_action(self):
return self.step(self._rng.randint(self.NUM_ACTIONS))
def reset(self):
self._previous_action = ''
self._episode_reward = 0.
self._episode_steps = 0
self._needs_reset = False
random_labels = self._rng.permutation(self._num_labels)
self._labels = {n: random_labels[i]
for i, n in enumerate(self._graph.nodes())}
self._respawn()
self._set_new_goal()
return dm_env.restart(self._observation())
def _respawn(self):
random_idx = self._rng.randint(self._num_labels)
self._position = list(self._graph.nodes())[random_idx]
def _set_new_goal(self):
if self._random_respawn:
self._respawn()
goal = self._position
while goal == self._position:
random_idx = self._rng.randint(self._num_labels)
goal = list(self._graph.nodes())[random_idx]
self._goal = goal
@property
def position(self):
return self._position
@property
def goal(self):
return self._goal
@property
def previous_action(self):
return self._previous_action
@property
def episode_reward(self):
return self._episode_reward
def draw_maze(self, ax=None):
if ax is None:
plt.figure()
ax = plt.gca()
node_positions = {(x, y): (x, y) for x, y in self._graph.nodes()}
letters = string.ascii_uppercase + string.ascii_lowercase
labels = {n: letters[self._labels[n]] for n in self._graph.nodes()}
node_list = list(self._graph.nodes())
colors = []
for n in node_list:
if n == self.position:
colors.append('lightblue')
elif n == self.goal:
colors.append('lightgreen')
else:
colors.append('pink')
nx.draw(self._graph, pos=node_positions, nodelist=node_list, ax=ax,
node_color=colors, with_labels=True, node_size=200, labels=labels)
ax.set_title('{}\nEpisode reward={:.1f}'.format(
self.previous_action, self.episode_reward))
ax.margins(.1)
return plt.gcf(), ax

View File

@@ -0,0 +1,265 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""One-shot StreetLearn environment."""
import dm_env
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
def deg_to_rad(x):
"""Convert degrees to radians."""
return x / 180. * np.pi
def rad_to_deg(x):
"""Convert radians to degrees."""
return x * 180. / np.pi
class OneShotStreetLearn(dm_env.Environment):
"""One-shot Streetlearn environment."""
ACTION_NAMES = [
'Forward',
'Left',
'Right',
'Collect',
]
NUM_ACTIONS = len(ACTION_NAMES)
def __init__(self, dataset_path, max_episode_steps, num_junctions=8,
target_reward=1., per_step_reward=0., observation_length=60,
seed=None):
self._graph = nx.read_gexf(dataset_path)
self._node_attrs = self._graph.nodes(data=True)
self._num_junctions = num_junctions
self._observation_length = observation_length
self._max_episode_steps = max_episode_steps
self._target_reward = target_reward
self._per_step_reward = per_step_reward
self._rng = np.random.RandomState(seed)
self.reset()
def reset(self):
self._previous_action = ''
self._episode_reward = 0.
self._episode_steps = 0
self._needs_reset = False
self._subgraph = self.get_random_subgraph()
self._observation_map = self.randomize_observations(self._subgraph)
self._position = self._rng.choice(list(self._subgraph.nodes()))
neighbours = self._neighbors_bearings(self._subgraph, self._position)
self._neighbour = neighbours[self._rng.randint(len(neighbours))]
self._set_new_goal()
return dm_env.restart(self._observation())
@property
def _current_edge(self):
return (self._position, self._neighbour['neighbour'])
def _set_new_goal(self):
goal = None
edges = list(self._observation_map.keys())
while goal is None or goal == self._current_edge:
goal = edges[self._rng.randint(len(edges))]
self._goal = goal
def _one_hot(self, edge):
one_hot_vector = np.zeros([self._observation_length], dtype=np.int32)
one_hot_vector[self._observation_map[edge]] = 1
return one_hot_vector
def _observation(self):
return {
'position': np.array(self._one_hot(self._current_edge), dtype=np.int32),
'goal': np.array(self._one_hot(self._goal), dtype=np.int32),
}
def observation_spec(self):
return {
'position': dm_env.specs.Array(
shape=(self._observation_length,), dtype=np.int32, name='position'),
'goal': dm_env.specs.Array(
shape=(self._observation_length,), dtype=np.int32, name='goal'),
}
def action_spec(self):
return dm_env.specs.DiscreteArray(self.NUM_ACTIONS)
def step(self, action):
# If previous step was the last step of an episode, reset.
if self._needs_reset:
return self.reset()
# Increment step count and check if it's the last step of the episode.
self._episode_steps += 1
if self._episode_steps >= self._max_episode_steps:
self._needs_reset = True
transition = dm_env.termination
else:
transition = dm_env.transition
# Recompute agent's position
self._move(action)
self._previous_action = self.ACTION_NAMES[action]
# Get reward if agent is at the goal location and the selected action is
# `collect`.
if (self._current_edge == self._goal and
self.ACTION_NAMES[action] == 'Collect'):
reward = self._target_reward
self._set_new_goal()
else:
reward = self._per_step_reward
self._episode_reward += reward
return transition(reward, self._observation())
def randomize_observations(self, subgraph):
edges = list(subgraph.edges())
edges.extend([(y, x) for (x, y) in edges])
obs_permutation = self._rng.permutation(self._observation_length)
return {e: obs_permutation[i] for i, e in enumerate(edges)}
def _calculate_bearing(self, node, neighbor):
lat1 = deg_to_rad(self._node_attrs[node]['lat'])
lng1 = deg_to_rad(self._node_attrs[node]['lng'])
lat2 = deg_to_rad(self._node_attrs[neighbor]['lat'])
lng2 = deg_to_rad(self._node_attrs[neighbor]['lng'])
delta_lng = lng2 - lng1
theta = np.arctan2(
np.sin(delta_lng) * np.cos(lat2),
np.cos(lat1) * np.sin(lat2) -
np.sin(lat1) * np.cos(lat2) * np.cos(delta_lng))
return theta
def _neighbors_bearings(self, subgraph, node):
bearings = []
for neighbor in list(subgraph[node]):
orientation = self._calculate_bearing(node, neighbor)
bearings.append({'neighbour': neighbor, 'orientation': orientation})
bearings.sort(key=lambda x: x['orientation'])
return bearings
def _sort_neighbors(self, node, neighbour):
bearings = self._neighbors_bearings(self._subgraph, node)
bs = [x['orientation'] for x in bearings]
idx = np.argmin(np.abs(bs - neighbour['orientation']))
return {
'forward': bearings[idx],
'right': bearings[idx-1],
'left': bearings[(idx+1) % len(bearings)],
}
def _move(self, action):
neighbours = self._sort_neighbors(self._position, self._neighbour)
if action == 0:
new_node = self._neighbour['neighbour']
neighbours = self._sort_neighbors(new_node, neighbours['forward'])
new_neighbour = neighbours['forward']
else:
new_node = self._position
if action == 1:
new_neighbour = neighbours['left']
elif action == 2:
new_neighbour = neighbours['right']
else:
new_neighbour = self._neighbour
self._position = new_node
self._neighbour = new_neighbour
def _all_next_junctions(self, subgraph, node):
neighbors = list(subgraph[node])
edges = [self._get_next_junction(subgraph, node, nb) for nb in neighbors]
nodes = [y for (_, y) in edges]
return nodes, edges
def _get_next_junction(self, subgraph, initial_node, next_node):
node = initial_node
while subgraph.degree(next_node) == 2:
neighbours = list(subgraph.neighbors(next_node))
neighbours.remove(node)
node = next_node
next_node = neighbours.pop()
return (initial_node, next_node)
def get_random_subgraph(self):
graph = self._graph
num_nodes = len(graph)
rnd_index = self._rng.randint(num_nodes)
center_node = list(graph.nodes())[rnd_index]
while graph.degree(center_node) <= 2:
rnd_index = self._rng.randint(num_nodes)
center_node = list(graph.nodes())[rnd_index]
to_visit = [center_node]
visited = []
subgraph = nx.Graph()
while to_visit:
node = to_visit.pop(0)
visited.append(node)
new_nodes, new_edges = self._all_next_junctions(graph, node)
subgraph.add_edges_from(new_edges)
node_degrees = [subgraph.degree(n) for n in subgraph.nodes()]
count_junctions = len(list(filter(lambda x: x > 2, node_degrees)))
if count_junctions >= self._num_junctions:
break
new_nodes = filter(lambda x: x not in visited + to_visit, new_nodes)
to_visit.extend(new_nodes)
return subgraph
def draw_subgraph(self, ax=None):
if ax is None:
_ = plt.figure(figsize=(3, 3))
ax = plt.gca()
node_ids = list(self._subgraph.nodes())
pos = {
x: (self._node_attrs[x]['lat'], self._node_attrs[x]['lng'])
for x in node_ids
}
labels = {}
nc = 'pink'
ec = 'black'
ns = 50
nshape = 'o'
# Draw the current subgraph
nx.draw(self._subgraph, pos=pos, node_color=nc, with_labels=False,
node_size=ns, labels=labels, edgecolors=ec, node_shape=nshape,
ax=ax)
max_xy = np.array([np.array(x) for x in pos.values()]).max(0)
min_xy = np.array([np.array(x) for x in pos.values()]).min(0)
delta_xy = (max_xy - min_xy) / 6.
ax.set_xlim([min_xy[0] - delta_xy[0], max_xy[0] + delta_xy[0]])
ax.set_ylim([min_xy[1] - delta_xy[1], max_xy[1] + delta_xy[1]])
# Draw goal position and orientation
x = self._node_attrs[self._goal[0]]['lat']
y = self._node_attrs[self._goal[0]]['lng']
rotation = rad_to_deg(self._calculate_bearing(*self._goal))
_ = ax.plot(x, y, marker=(3, 0, rotation - 90), color=(0, 0, 0),
markersize=14, markerfacecolor='white')
_ = ax.plot(x, y, marker=(2, 0, rotation - 90), color=(0, 0, 0),
markersize=12, markerfacecolor='None')
# Draw current position and orientation
x = self._node_attrs[self._position]['lat']
y = self._node_attrs[self._position]['lng']
rotation = rad_to_deg(self._neighbour['orientation'])
_ = ax.plot(x, y, marker=(3, 0, rotation - 90), color=(0, 0, 0),
markersize=14, markerfacecolor='lightgreen')
_ = ax.plot(x, y, marker=(2, 0, rotation - 90), color=(0, 0, 0),
markersize=12, markerfacecolor='None')
ax.set_title('{}\nEpisode reward = {}'.format(
self._previous_action, self._episode_reward))
return plt.gcf(), ax

View File

@@ -0,0 +1,6 @@
dm-env>=1.2
dm-haiku>=0.0.3
jax>=0.2.8
matplotlib>=3.1.2
networkx>=2.3
numpy>=1.18.0

230
wikigraphs/README.md Normal file
View File

@@ -0,0 +1,230 @@
# WikiGraphs
This package provides tools to download the [WikiGraphs dataset](https://www.aclweb.org/anthology/2021.textgraphs-1.7.pdf)
[1], collected by pairing each Wikipedia article from [WikiText-103](https://arxiv.org/pdf/1609.07843.pdf)
[2] with a knowledge graph (a subgraph from [Freebase knowledge graph](https://dl.acm.org/doi/pdf/10.1145/1376616.1376746?casa_token=H2ggPTDMoZUAAAAA:7wBhO9hnOzNKoJyMH0PcpVQZ6Vg6Ud6hObiDJTzLCGRiBwmYFjOFSXrG5PcKLStu5-n4_OfkPJtbisQ)
[3]). The baseline code to reproduce results in [1] is included as well. We hope
this can spur more interest in developing models that can generate long text
conditioned on graphs, as well as models that generate graphs given text.
## Setup Jax environment
[Jax](https://github.com/google/jax#installation),
[Haiku](https://github.com/deepmind/dm-haiku#installation), and
[Optax](https://github.com/deepmind/optax#installation) are needed for this
package. It has been developed and tested on python 3 with the following
packages:
* Jax==0.2.13
* Haiku==0.0.5
* Optax==0.0.6
Other packages required can be installed via:
```bash
pip install -r requirements.txt
```
Note: you may need to use `pip3` to select pip for python 3, and the `--user`
option to install the packages to avoid permission issues.
## Installation
```bash
pip install -e .
```
## Preparing the data
### Download the data
You can download and unzip the data by running the following command:
```bash
bash scripts/download.sh
```
This will put the downloaded WikiText-103 data in a temporary directory
`/tmp/data` with the tokenized WikiText-103 data in `/tmp/data/wikitext-103` and
the raw data in `/tmp/data/wikitext-103-raw`.
This script will also download our processed Freebase knowledge graph data in a
temporary directory `/tmp/data/freebase`.
### Build vocabularies
For WikiText-103, run the following command to generate a vocabulary file:
```bash
python scripts/build_vocab.py \
--vocab_file_path=/tmp/data/wikitext-vocab.csv \
--data_dir=/tmp/data/wikitext-103
```
You can change the default file paths but make sure you make them consistent.
### Pair Freebase graphs with WikiText
You can run the following command to pair the Freebase graphs with WikiText-103
articles:
```bash
python scripts/freebase_preprocess.py \
--freebase_dir=/tmp/data/freebase/max256 \
--output_dir=/tmp/data/wikigraphs/max256
```
where `freebase_dir` (`/tmp/data/freebase/max256`) is the directory that
contains the Freebase graphs and should have the files `train.gz`, `valid.gz`
and `test.gz` in it, and `output_dir` is the directory that will contain the
generated paired Freebase-WikiText data.
Note: you may need to use `python3` to select python 3 if you have both python 2
and 3 on your system.
For reference, WikiText-103 contains the following number of articles:
Subset | #articles
------ | ---------
Train | 28472*
Valid | 60
Test | 60
*The official number is 28475, but we were only able to find 28472 articles in
the training set.
Our dataset covers around 80% of the WikiText articles:
Max graph size | 256 | 512 | 1024
---------------------------- | ----- | ----- | -----
\#articles in training set | 23431 | 23718 | 23760
Training set coverage | 82.3% | 83.3% | 83.5%
\#articles in validation set | 48 | 48 | 48
Validation set coverage | 80% | 80% | 80%
\#articles in test set | 43 | 43 | 43
Test set coverage | 71.7% | 71.7% | 71.7%
### Build vocabulary for WikiGraphs
You can build the vocabulary for the graph data (the max256 version) by running
the following command:
```bash
python scripts/build_vocab.py \
--vocab_file_path=/tmp/data/graph-vocab.csv \
--data_dir=/tmp/data/wikigraphs \
--version=max256 \
--data_type=graph \
--threshold=15
```
This gives us a vocabulary of size 31,087, with each token included in the
vocabulary appearing at least 15 times.
You also need to build a separate text vocabulary for the WikiGraphs data, as
our training set does not cover 100% of WikiText-103.
```bash
python scripts/build_vocab.py \
--vocab_file_path=/tmp/data/text-vocab.csv \
--data_dir=/tmp/data/wikigraphs \
--version=max256 \
--data_type=text \
--threshold=3
```
Here we choose a threshold of 3, which is also used by the original
WikiText-103 data. This gives us a vocabulary of size 238,068, only slightly
smaller than the original vocabulary.
Note that when loading these vocabularies to build tokenizers, our tokenizers
add a few extra tokens, such as `<bos>` and `<pad>`, so the final vocab size
may differ slightly from the numbers above, depending on which tokenizer you
choose to use.
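For example (a minimal sketch, assuming the vocab files built in the previous
steps live under `/tmp/data`), you can check the effective vocabulary sizes
after these extra tokens are added:
```python
from wikigraphs.data import tokenizers

# Word-level tokenizer over the WikiGraphs text vocab built above; it adds
# <pad> and <bos> (and <unk> if missing) on top of the words in the csv file.
text_tokenizer = tokenizers.WordTokenizer(vocab_file='/tmp/data/text-vocab.csv')
print(text_tokenizer.vocab_size)

# Graph tokenizer over the graph vocab built above; it adds <pad>, <bos>, <unk>.
graph_tokenizer = tokenizers.GraphTokenizer(
    vocab_file='/tmp/data/graph-vocab.csv')
print(graph_tokenizer.vocab_size)
```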
We only showcase how to build the vocabulary for the max256 version. The above
steps can be easily adapted for the max512 and max1024 versions.
## Loading the dataset
We provide JAX modules to load the WikiGraphs dataset. There are three classes
in `wikigraphs/data/paired_dataset.py`:
* `TextOnlyDataset`: loads only the text part of the WikiGraphs data
* `Bow2TextDataset`: loads text and the paired graph represented as one big
bag-of-words (BoW) over all nodes and edges of the graph
* `Graph2TextDataset`: returns text and the paired graph in which each node or
edge is represented by a BoW
Different versions of the dataset can be accessed by changing the `version`
argument in each class. For more detailed usage, please refer to
`wikigraphs/data/paired_dataset_test.py`. In addition, the original WikiText
dataset can be loaded via the `Dataset` class in `wikigraphs/data/wikitext.py`.
Note: you may want to change the default data directory if you prefer to place
it elsewhere.
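As a rough sketch of how these pieces fit together (mirroring
`wikigraphs/data/paired_dataset_test.py`; the vocab files and data directory
are the ones assumed in the steps above, so adjust them to your setup):
```python
from wikigraphs.data import paired_dataset, tokenizers

tokenizer = tokenizers.WordTokenizer(vocab_file='/tmp/data/text-vocab.csv')
graph_tokenizer = tokenizers.GraphTokenizer(
    vocab_file='/tmp/data/graph-vocab.csv')

# Batches of graph-text pairs from the max256 validation set.
dataset = paired_dataset.Graph2TextDataset(
    tokenizer,
    graph_tokenizer,
    batch_size=4,
    timesteps=256,
    subset='valid',
    version='max256',
    data_dir='/tmp/data/wikigraphs')

batch = next(iter(dataset))
# `batch` has 'obs', 'target', 'should_reset' and 'mask' arrays of shape
# [batch_size, timesteps], plus a list of per-example graphs under 'graphs'.
```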
## Run baseline models
Note: baseline models will be available soon.
## Citing WikiGraphs
To cite this work:
```
@inproceedings{wang2021wikigraphs,
title={WikiGraphs: A Wikipedia Text-Knowledge Graph Paired Dataset},
author={Wang, Luyu and Li, Yujia and Aslan, Ozlem and Vinyals, Oriol},
booktitle={Proceedings of the Graph-Based Methods for Natural Language Processing (TextGraphs)},
pages={67--82},
year={2021}
}
```
## License
All code copyright 2021 DeepMind Technologies Limited
Code is licensed under the Apache License, Version 2.0 (the "License"); you may
not use this file except in compliance with the License. You may obtain a copy
of the License at:
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
[WikiGraphs](https://www.aclweb.org/anthology/2021.textgraphs-1.7.pdf)
[1] is licensed under the terms of the Creative Commons
Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
[WikiText-103 data](https://arxiv.org/pdf/1609.07843.pdf) [2] (unchanged) is
licensed by Salesforce.com, Inc. under the terms of the Creative Commons
Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. You can find
details about CC BY-SA 4.0 at:
https://creativecommons.org/licenses/by-sa/4.0/legalcode
[Freebase data](https://dl.acm.org/doi/pdf/10.1145/1376616.1376746?casa_token=H2ggPTDMoZUAAAAA:7wBhO9hnOzNKoJyMH0PcpVQZ6Vg6Ud6hObiDJTzLCGRiBwmYFjOFSXrG5PcKLStu5-n4_OfkPJtbisQ)
[3] is licensed by Google LLC under the terms of the Creative
Commons CC BY 4.0 license. You may obtain a copy of the License at:
https://creativecommons.org/licenses/by/4.0/legalcode
## References
1. L. Wang, Y. Li, O. Aslan, and O. Vinyals, "[WikiGraphs: a wikipedia -
knowledge graph paired dataset](https://www.aclweb.org/anthology/2021.textgraphs-1.7.pdf)",
in Proceedings of the Graph-based Methods for Natural Language Processing
(TextGraphs), pages 67-82, 2021.
2. S. Merity, C. Xiong, J. Bradbury, and R. Socher, "[Pointer sentinel mixture
models](https://arxiv.org/pdf/1609.07843.pdf)",
arXiv: 1609.07843, 2016.
3. K. Bollacker, C. Evans, P. Paritosh, T. Sturge, and J. Taylor,
"[Freebase: a collaboratively created graph database for structuring human
knowledge](https://dl.acm.org/doi/pdf/10.1145/1376616.1376746?casa_token=H2ggPTDMoZUAAAAA:7wBhO9hnOzNKoJyMH0PcpVQZ6Vg6Ud6hObiDJTzLCGRiBwmYFjOFSXrG5PcKLStu5-n4_OfkPJtbisQ)",
in Proceedings of the ACM SIGMOD International Conference on Management of
Data, pages 1247-1250, 2008.

View File

@@ -0,0 +1,7 @@
absl-py==0.10.0
dm-haiku
jax>=0.2.13
nltk>=3.6.2
numpy>=1.19.5
optax>=0.0.6
scikit-learn>=0.24.2

View File

@@ -0,0 +1,166 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Script for building vocabulary files for datasets."""
import collections
import csv
import enum
import io
import os
from typing import List, Tuple
from absl import app
from absl import flags
from absl import logging
from wikigraphs.data import io_tools
from wikigraphs.data import paired_dataset
from wikigraphs.data import tokenizers
from wikigraphs.data import wikitext
class DatasetType(enum.Enum):
text = 1
graph = 2
wikitext = 3
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', '', 'Path to the directory that contains the'
' unzipped wikitext-103 data.')
flags.DEFINE_string('vocab_file_path', '', 'Path to the output vocab file.')
flags.DEFINE_enum_class('data_type', DatasetType.wikitext, DatasetType,
'One of {`wikitext`, `graph`, `text`}.')
flags.DEFINE_integer('threshold', 1, 'Frequency threshold for a word to be'
' included in the vocabulary.')
flags.DEFINE_string('version', 'max256', 'Which version of paired data to use.')
def get_vocab(dataset: wikitext.RawDataset) -> List[Tuple[str, int]]:
"""Build vocabulary, return (word, count) tuples sorted by count."""
vocab = collections.defaultdict(int)
for pair in dataset:
for t in pair.text.split(' '):
if t:
vocab[t] += 1
return sorted(vocab.items(), key=lambda t: -t[1])
def write_vocab(vocab: List[Tuple[str, int]], output_path: str):
"""Write a vocab list to a file."""
output_dir = os.path.dirname(output_path)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
with open(output_path, mode='wb') as f_:
with io.TextIOWrapper(f_, encoding='utf-8') as f:
w = csv.writer(f)
w.writerows(vocab)
def build_wikitext_vocab():
logging.info('Loading the dataset.')
dataset = wikitext.RawDataset(subset='train', data_dir=FLAGS.data_dir)
logging.info('Building the vocab.')
vocab = get_vocab(dataset)
logging.info('Finished, vocab size %d, total number of tokens %d',
len(vocab), sum([c for _, c in vocab]))
logging.info('Writing the vocab to %s', FLAGS.vocab_file_path)
write_vocab(vocab, FLAGS.vocab_file_path)
def build_graph_vocab():
"""Build vocabulary for graph data."""
logging.info('Loading the dataset.')
dataset = paired_dataset.ParsedDataset(
subset='train', data_dir=FLAGS.data_dir, version=FLAGS.version)
logging.info('Building graph vocab.')
vocab = collections.defaultdict(int)
for pair in dataset:
graph = pair.graph
for n in graph.nodes():
for t in tokenizers.GraphTokenizer.split_node(n):
if t:
vocab[t] += 1
for _, _, e in graph.edges():
for t in tokenizers.GraphTokenizer.split_edge(e):
if t:
vocab[t] += 1
vocab = sorted(vocab.items(), key=lambda t: -t[1])
vocab = [k for k, v in vocab if v >= FLAGS.threshold]
logging.info('Finished, vocab size %d.', len(vocab))
logging.info('Writing the vocab to %s.', FLAGS.vocab_file_path)
io_tools.write_txt_file(FLAGS.vocab_file_path, '\n'.join(vocab),
# Some unicode characters require utf-16 to encode.
encoding='utf-16')
def build_text_vocab():
"""Build vocabulary for the text part of the graph-to-text data."""
logging.info('Loading the dataset.')
dataset = paired_dataset.ParsedDataset(
subset='train', data_dir=FLAGS.data_dir, version=FLAGS.version)
logging.info('Building text vocab.')
vocab = collections.defaultdict(int)
for pair in dataset:
for t in pair.text.split(' '):
if t:
vocab[t] += 1
vocab = sorted(vocab.items(), key=lambda t: -t[1])
logging.info('Finished, vocab size %d, total number of tokens %d.',
len(vocab), sum([v for _, v in vocab]))
vocab = [(k, v) for k, v in vocab if v >= FLAGS.threshold]
logging.info('After filtering, vocab size %d.', len(vocab))
logging.info('Writing the vocab to %s.', FLAGS.vocab_file_path)
write_vocab(vocab, FLAGS.vocab_file_path)
def main(_):
if FLAGS.data_type == DatasetType.wikitext:
build_wikitext_vocab()
elif FLAGS.data_type == DatasetType.text:
build_text_vocab()
elif FLAGS.data_type == DatasetType.graph:
build_graph_vocab()
else:
raise ValueError(f'Unknown data type {FLAGS.data_type}.')
if __name__ == '__main__':
app.run(main)

View File

@@ -0,0 +1,63 @@
#!/bin/bash
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
BASE_DIR=/tmp/data
# wikitext-103
TARGET_DIR=${BASE_DIR}/wikitext-103
mkdir -p ${TARGET_DIR}
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip -P ${TARGET_DIR}
unzip ${TARGET_DIR}/wikitext-103-v1.zip -d ${TARGET_DIR}
mv ${TARGET_DIR}/wikitext-103/* ${TARGET_DIR}
rm -rf ${TARGET_DIR}/wikitext-103 ${TARGET_DIR}/wikitext-103-v1.zip
# wikitext-103-raw
TARGET_DIR=${BASE_DIR}/wikitext-103-raw
mkdir -p ${TARGET_DIR}
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip -P ${TARGET_DIR}
unzip ${TARGET_DIR}/wikitext-103-raw-v1.zip -d ${TARGET_DIR}
mv ${TARGET_DIR}/wikitext-103-raw/* ${TARGET_DIR}
rm -rf ${TARGET_DIR}/wikitext-103-raw ${TARGET_DIR}/wikitext-103-raw-v1.zip
# processed freebase graphs
FREEBASE_TARGET_DIR=/tmp/data
mkdir -p ${FREEBASE_TARGET_DIR}/packaged/
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1uuSS2o72dUCJrcLff6NBiLJuTgSU-uRo' -O ${FREEBASE_TARGET_DIR}/packaged/max256.tar
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1nOfUq3RUoPEWNZa2QHXl2q-1gA5F6kYh' -O ${FREEBASE_TARGET_DIR}/packaged/max512.tar
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1uuJwkocJXG1UcQ-RCH3JU96VsDvi7UD2' -O ${FREEBASE_TARGET_DIR}/packaged/max1024.tar
for version in max1024 max512 max256
do
output_dir=${FREEBASE_TARGET_DIR}/freebase/${version}/
mkdir -p ${output_dir}
tar -xvf ${FREEBASE_TARGET_DIR}/packaged/${version}.tar -C ${output_dir}
done
rm -rf ${FREEBASE_TARGET_DIR}/packaged

View File

@@ -0,0 +1,106 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Preprocess freebase data and pair with wikitext."""
import os
from absl import app
from absl import flags
from absl import logging
from wikigraphs.data import io_tools
from wikigraphs.data import wikitext
FLAGS = flags.FLAGS
flags.DEFINE_string('freebase_dir', '', 'Directory that contains Freebase'
' graphs.')
flags.DEFINE_string('output_dir', '', 'Path to output directory to store the'
' paired dataset.')
def pair_graphs_with_wikitext(subset: str, graph_dir: str, output_dir: str):
"""Pair graphs with wikitext articles, and write to output directory."""
logging.info('Pairing graphs from the %s set from %s with wikitext.',
subset, graph_dir)
graphs = list(io_tools.graphs_from_file(
os.path.join(graph_dir, f'{subset}.gz')))
title2graph = {
io_tools.normalize_freebase_string(g.title).replace(' ', ''): g
for g in graphs}
n_graphs = len(graphs)
# Use raw version of the wikitext data as the tokenized version has <unk> in
# titles which is bad for matching. We will handle the <unk>s through the
# tokenizer to make sure our data are equivalent to that of the tokenized
# version of wikitext-103.
wikitext_articles = list(wikitext.RawDataset(subset=subset, version='raw'))
n_wiki = len(wikitext_articles)
logging.info('Loaded %d graphs and %d wikitext articles in total.',
n_graphs, n_wiki)
# Keep track of the article titles in the dataset. Unfortunately wikitext-103
# has about 1% duplicated articles, so we need to take care of that.
retrieved_titles = set()
pairs = []
n_duplicates = 0
for a in wikitext_articles:
title = wikitext.normalize_title(a.title).replace(' ', '')
g = title2graph.get(title, None)
if g is not None:
if title not in retrieved_titles:
retrieved_titles.add(title)
pairs.append(io_tools.GraphTextPair(
center_node=g.center,
title=g.title,
edges=g.edges,
text=a.text))
else:
n_duplicates += 1
n_pairs = len(pairs)
logging.info('Matched %d/%d = %.1f%% of wikitext articles,'
' and %d/%d = %.1f%% of graphs.',
n_pairs, n_wiki, float(n_pairs) / n_wiki * 100,
n_pairs, n_graphs, float(n_pairs) / n_graphs * 100)
logging.info('Detected %d/%d = %.1f%% of duplicated wikitext articles.',
n_duplicates, n_wiki, float(n_duplicates) / n_wiki * 100)
io_tools.write_pairs_to_gzip_txt_file(
os.path.join(output_dir, f'{subset}.gz'), pairs)
def main(_):
for subset in ['train', 'valid', 'test']:
pair_graphs_with_wikitext(subset, FLAGS.freebase_dir, FLAGS.output_dir)
if __name__ == '__main__':
app.run(main)

View File

@@ -0,0 +1,143 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
r"""Tool to visualize graphs.
You need to have the command line tool `dot` installed locally, for example by
`sudo apt-get install graphviz`.
Example usage:
python visualize_graph.py \
--logtostderr --graph_ids=0:48 --truncate_limit=500 --layout=fdp
"""
import html
import os
import textwrap
from absl import app
from absl import flags
from absl import logging
from wikigraphs.data import io_tools
from wikigraphs.data import paired_dataset as pd
FLAGS = flags.FLAGS
flags.DEFINE_string('subset', 'valid', 'Which subset to choose graphs from.')
flags.DEFINE_string('graph_ids', '', 'A comma-separated string of graph IDs'
' (0-based), for example `1,2,3`. Or alternatively a'
' range, e.g. `0:10` which is equivalent to'
' `0,1,2,3,...,9`.')
flags.DEFINE_string('version', 'max256', 'Which version of data to load.')
flags.DEFINE_string('data_dir', '', 'Path to a directory that contains the raw'
' paired data, if provided.')
flags.DEFINE_string('output_dir', '/tmp/graph_vis', 'Output directory to save'
' the visualized graphs.')
flags.DEFINE_integer('truncate_limit', -1, 'Maximum length for graph nodes in'
' visualization.')
flags.DEFINE_string('layout', 'fdp', 'Which one of the dot layout to use.')
def truncate(s: str) -> str:
if FLAGS.truncate_limit > 0 and len(s) > FLAGS.truncate_limit:
s = s[:FLAGS.truncate_limit] + '...'
return s
def format_label(s: str, width: int = 40) -> str:
"""Format a node / edge label."""
s = io_tools.normalize_freebase_string(s)
s = truncate(s)
lines = s.split('\\n')
output_lines = []
for line in lines:
line = html.escape(line)
if width > 0:
output_lines += textwrap.wrap(line, width)
else:
output_lines.append(line)
return '<' + '<br/>'.join(output_lines) + '>'
def graph_to_dot(graph_text_pair: io_tools.GraphTextPair) -> str:
"""Convert a graph to a dot file."""
dot = ['digraph {', 'node [shape=rect];']
graph = pd.Graph.from_edges(graph_text_pair.edges)
center_node_id = graph.node2id(graph_text_pair.center_node)
for i, n in enumerate(graph.nodes()):
color = '#f5dc98' if i == center_node_id else (
'#b0ffad' if not(n[0] == '"' and n[-1] == '"') else '#ffffff')
label = format_label(n)
dot.append(f'{i} [ label = {label}, fillcolor="{color}", style="filled"];')
for i, j, e in graph.edges():
dot.append(f'{i} -> {j} [ label = {format_label(e, width=0)} ];')
dot.append('}')
return '\n'.join(dot)
def visualize_graph(graph_text_pair: io_tools.GraphTextPair,
graph_id: int,
output_dir: str):
"""Visualize a graph and save the visualization to the specified directory."""
dot = graph_to_dot(graph_text_pair)
output_file = os.path.join(output_dir, f'{graph_id}.dot')
logging.info('Writing output to %s', output_file)
with open(output_file, 'w') as f:
f.write(dot)
pdf_output = os.path.join(output_dir, f'{graph_id}.pdf')
os.system(f'dot -K{FLAGS.layout} -Tpdf -o {pdf_output} {output_file}')
def main(_):
logging.info('Loading the %s set of data.', FLAGS.subset)
pairs = list(pd.RawDataset(subset=FLAGS.subset,
data_dir=FLAGS.data_dir or None,
shuffle_data=False,
version=FLAGS.version))
logging.info('Loaded %d graph-text pairs.', len(pairs))
if ':' in FLAGS.graph_ids:
start, end = [int(i) for i in FLAGS.graph_ids.split(':')]
graph_ids = list(range(start, end))
else:
graph_ids = [int(i) for i in FLAGS.graph_ids.split(',')]
logging.info('Visualizing graphs with ID %r', graph_ids)
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
for gid in graph_ids:
visualize_graph(pairs[gid], gid, FLAGS.output_dir)
if __name__ == '__main__':
app.run(main)

43
wikigraphs/setup.py Normal file
View File

@@ -0,0 +1,43 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Setup for pip package."""
from setuptools import find_packages
from setuptools import setup
setup(
name='wikigraphs',
version='0.0.1',
description='A Wikipedia - knowledge graph paired dataset.',
url='https://github.com/deepmind/deepmind-research/tree/master/wikigraphs',
author='DeepMind',
author_email='luyuwang@google.com',
packages=find_packages(),
license='Apache 2.0',
)

View File

@@ -0,0 +1,36 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""WikiGraphs data modules."""
from . import dataset
from . import io_tools
from . import paired_dataset
from . import tokenizers
from . import tools
from . import wikitext

View File

@@ -0,0 +1,59 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Base class of the datasets."""
import abc
from typing import Any, Iterator
class Dataset(abc.ABC):
"""Base class for all datasets.
All sub-classes should define `_load_data()`, which returns an iterator over
the dataset; `__iter__` uses it to instantiate `self._data_iter`.
"""
def __init__(self):
"""Constructor."""
self._data_iter = None # An iterator produced by `self._load_data`.
@abc.abstractmethod
def _load_data(self) -> Iterator[Any]:
"""Prepare data for another pass through the dataset.
This method should return a generator in a child class.
"""
def __next__(self):
return next(self._data_iter)
def __iter__(self):
self._data_iter = self._load_data()
return self
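# Illustrative sketch only (not part of the original file): a minimal subclass
# wrapping an in-memory list might look like the following.
#
#   class ListDataset(Dataset):
#     """Iterates over a plain Python list."""
#
#     def __init__(self, items):
#       super().__init__()
#       self._items = items
#
#     def _load_data(self):
#       # Return a fresh iterator so the dataset can be iterated multiple times.
#       return iter(self._items)
#
#   for x in ListDataset([1, 2, 3]):
#     print(x)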

View File

@@ -0,0 +1,179 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Some tools for I/O."""
import gzip
import io
import os
import re
from typing import NamedTuple, List, Iterator
from absl import logging
def read_txt_file(file_path: str, encoding: str = 'utf-8') -> str:
"""Read a plain txt file."""
with open(file_path, 'rb') as f:
content = f.read()
return content.decode(encoding)
def write_txt_file(file_path: str, txt: str, encoding: str = 'utf-8'):
"""Write the given txt string to file."""
make_dir_if_necessary(file_path)
with open(file_path, 'wb') as f:
f.write(txt.encode(encoding, 'surrogatepass'))
def read_gzip_txt_file(file_path: str, encoding: str = 'utf-8') -> str:
"""Read gzipped txt file."""
with open(file_path, 'rb') as f:
content = f.read()
with gzip.GzipFile(fileobj=io.BytesIO(content), mode='rb') as f:
content = f.read()
return content.decode(encoding)
def make_dir_if_necessary(output_path):
output_dir = os.path.dirname(output_path)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
def write_lines_to_gzipped_file(file_path, lines):
make_dir_if_necessary(file_path)
with open(file_path, 'wb') as f_zip:
with gzip.GzipFile(fileobj=f_zip, mode='wb') as f:
f.write('\n'.join(lines).encode('utf-8'))
class Graph(NamedTuple):
title: str
center: str
edges: List[str]
def graphs_from_file(file_path: str) -> Iterator[Graph]:
"""Read freebase graphs from file.
Args:
file_path: path to the input `.gz` file that contains a list of graphs.
Yields:
graphs: `Graph` objects read from the file, one at a time.
"""
content = read_gzip_txt_file(file_path)
graph_header_sep_re = re.compile(
r'(<graph center=[^ ]+ title="[^"]+">\n)')
graph_header_re = re.compile(
r'<graph center=([^ ]+) title="([^"]+)">\n')
parts = graph_header_sep_re.split(content)
# Skip the first part which is empty
for i in range(1, len(parts), 2):
header, body = parts[i], parts[i + 1]
m = graph_header_re.match(header)
yield Graph(title=m.group(2),
center=m.group(1),
edges=body.strip().split('\n'))
_UNICODE_RE = re.compile(r'(\$[0-9A-Fa-f]{4})')
def normalize_freebase_string(s: str) -> str:
"""Expand the `$xxxx` escaped unicode characters in the input string."""
# '"' is escaped as '``', convert it back.
s = s.replace('``', '"')
parts = _UNICODE_RE.split(s)
parts = [p if not _UNICODE_RE.match(p) else chr(int(p[1:], base=16))
for p in parts]
return ''.join(parts).replace('_', ' ')
class GraphTextPair(NamedTuple):
"""Text paired with raw graph represented as in `edges`."""
center_node: str
title: str
edges: List[str]
text: str
def pair2lines(pair):
lines = [f'<graph center={pair.center_node} title="{pair.title}">']
lines.append('<section id="text">')
lines.append(pair.text)
lines.append('<section id="edges">')
lines.extend(pair.edges)
return lines
def write_pairs_to_gzip_txt_file(file_path, pairs):
logging.info('Writing %d pairs to %s.', len(pairs), file_path)
lines = []
for p in pairs:
lines.extend(pair2lines(p))
write_lines_to_gzipped_file(file_path, lines)
def read_pairs_from_gzip_txt_file(file_path: str) -> Iterator[GraphTextPair]:
"""Read graph-text pairs from gzip txt files.
Args:
file_path: a `.gz` file of graph-text pairs written in the same format as
using the `write_pairs_to_gzip_txt_file` function.
Yields:
Graph-text pairs from this file.
"""
content = read_gzip_txt_file(file_path)
graph_header_sep_re = re.compile(
r'(<graph center=[^ ]+ title="[^"]+">)')
graph_header_re = re.compile(
r'<graph center=([^ ]+) title="([^"]+)">$')
section_sep_re = re.compile(r'\n(<section id="[^"]+">\n)')
parts = graph_header_sep_re.split(content)
# Skip the first part which is empty
for i in range(1, len(parts), 2):
header, body = parts[i], parts[i + 1]
m = graph_header_re.match(header)
# 5 parts total, empty first part, "text", text section, "edges", edges
# section.
section_parts = section_sep_re.split(body)
yield GraphTextPair(center_node=m.group(1),
title=m.group(2),
text=section_parts[2],
edges=section_parts[-1].strip().split('\n'))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,271 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Tests for wikigraphs.data.paired_dataset."""
from absl.testing import absltest
from wikigraphs.data import io_tools
from wikigraphs.data import paired_dataset
from wikigraphs.data import tokenizers
from wikigraphs.data import wikitext
WIKITEXT_ROOT = '/tmp/data/wikitext-103'
WIKIGRAPHS_ROOT = '/tmp/data/wikigraphs'
WIKITEXT_VOCAB_FILE = '/tmp/data/wikitext-vocab.csv'
GRAPH_VOCAB_FILE = '/tmp/data/graph-vocab.csv'
class PairedDatasetTest(absltest.TestCase):
def test_raw_paired_dataset_size(self):
dataset = paired_dataset.RawDataset(
subset='valid', shuffle_data=False, data_dir=WIKIGRAPHS_ROOT)
pairs = list(dataset)
self.assertLen(pairs, 48)
self.assertEqual(pairs[0].title, 'Homarus_gammarus')
self.assertEqual(pairs[-1].title, 'Rakie_Ayola')
# Make sure the content of the articles match the original
wikitext_set = wikitext.RawDataset(
subset='valid', shuffle_data=False, version='raw',
data_dir=WIKITEXT_ROOT)
title2article = {wikitext.normalize_title(a.title).replace(' ', ''): a.text
for a in wikitext_set}
for p in pairs:
title = io_tools.normalize_freebase_string(p.title).replace(' ', '')
article = title2article.get(title, None)
self.assertIsNotNone(article)
self.assertEqual(article, p.text)
def test_graph_from_edges(self):
edges = ['A\tE1\tB',
'A\tE2\tC',
'B\tE1\tC',
'C\tE3\tD',
'C\tE2\tE']
graph = paired_dataset.Graph.from_edges(edges)
self.assertEqual(graph.nodes(), ['A', 'B', 'C', 'D', 'E'])
self.assertEqual(graph.edges(), [(0, 1, 'E1'),
(0, 2, 'E2'),
(1, 2, 'E1'),
(2, 3, 'E3'),
(2, 4, 'E2')])
def test_graph_to_edges(self):
edges = ['A\tE1\tB',
'A\tE2\tC',
'B\tE1\tC',
'C\tE3\tD',
'C\tE2\tE']
graph = paired_dataset.Graph.from_edges(edges)
self.assertEqual(graph.to_edges(), edges)
def test_bow2text_dataset(self):
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
batch_size = 4
seq_len = 256
dataset = paired_dataset.Bow2TextDataset(
tokenizer,
graph_tokenizer,
batch_size=batch_size,
timesteps=seq_len,
subset='valid',
subsample_nodes=0.7,
repeat=False,
data_dir=WIKIGRAPHS_ROOT)
num_tokens = 0
for batch in dataset:
num_tokens += batch['mask'].sum()
self.assertEqual(batch['graphs'].shape,
(batch_size, graph_tokenizer.vocab_size))
raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
raw_num_tokens = 0
n_pairs = 0
for pair in raw_dataset:
raw_num_tokens += len(tokenizer.encode(
pair.text, prepend_bos=True, append_eos=True))
n_pairs += 1
# The first token of each example is not counted by `mask` as it masks the
# targets, and the first token of each example never appears in the targets.
self.assertEqual(raw_num_tokens, num_tokens + n_pairs)
def test_graph2text_dataset(self):
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
batch_size = 4
seq_len = 256
dataset = paired_dataset.Graph2TextDataset(
tokenizer,
graph_tokenizer,
batch_size=batch_size,
timesteps=seq_len,
subsample_nodes=0.8,
subset='valid',
data_dir=WIKIGRAPHS_ROOT)
data_iter = iter(dataset)
batch = next(data_iter)
self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
self.assertEqual(batch['target'].shape, (batch_size, seq_len))
self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
self.assertIsInstance(batch['graphs'], list)
self.assertLen(batch['graphs'], batch_size)
for i in range(batch_size):
self.assertIsInstance(batch['graphs'][i], paired_dataset.ParsedGraph)
# +1 for the center_node mask
self.assertEqual(
batch['graphs'][i].nodes.shape[-1], graph_tokenizer.vocab_size + 1)
self.assertEqual(
batch['graphs'][i].edges.shape[-1], graph_tokenizer.vocab_size)
n_edges = batch['graphs'][i].edges.shape[0]
self.assertEqual(batch['graphs'][i].sender.shape, (n_edges,))
self.assertEqual(batch['graphs'][i].receiver.shape, (n_edges,))
# Make sure the token count matches across the tokenized data and the raw
# data set.
num_tokens = 0
for batch in dataset:
num_tokens += batch['mask'].sum()
raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
raw_num_tokens = 0
n_pairs = 0
for pair in raw_dataset:
raw_num_tokens += len(tokenizer.encode(
pair.text, prepend_bos=True, append_eos=True))
n_pairs += 1
# The first token of each example is not counted by `mask` as it masks the
# targets, and the first token of each example never appears in the targets.
self.assertEqual(raw_num_tokens, num_tokens + n_pairs)
def test_text_only_dataset(self):
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
batch_size = 4
seq_len = 256
dataset = paired_dataset.TextOnlyDataset(
tokenizer,
batch_size=batch_size,
timesteps=seq_len,
subset='valid',
data_dir=WIKIGRAPHS_ROOT)
data_iter = iter(dataset)
batch = next(data_iter)
faux_batch = dataset.return_faux_batch()
self.assertCountEqual(list(batch.keys()),
['obs', 'target', 'should_reset', 'mask'])
self.assertCountEqual(list(faux_batch.keys()),
['obs', 'target', 'should_reset', 'mask'])
for k, v in batch.items():
faux_v = faux_batch[k]
self.assertEqual(v.shape, faux_v.shape)
self.assertEqual(v.dtype, faux_v.dtype)
self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
self.assertEqual(batch['target'].shape, (batch_size, seq_len))
self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
num_tokens = 0
for batch in dataset:
num_tokens += batch['mask'].sum()
raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
raw_num_tokens = 0
n_pairs = 0
for pair in raw_dataset:
raw_num_tokens += len(tokenizer.encode(
pair.text, prepend_bos=True, append_eos=True))
n_pairs += 1
self.assertEqual(num_tokens + n_pairs, raw_num_tokens)
def test_bow_retrieval_dataset(self):
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
batch_size = 4
seq_len = 256
dataset = paired_dataset.Bow2TextDataset(
tokenizer,
graph_tokenizer,
batch_size=batch_size,
timesteps=seq_len,
subsample_nodes=0.8,
graph_retrieval_dataset=True,
subset='valid',
data_dir=WIKIGRAPHS_ROOT)
data_iter = iter(dataset)
batch = next(data_iter)
self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
self.assertEqual(batch['target'].shape, (batch_size, seq_len))
self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
self.assertEqual(batch['graph_id'].shape, (batch_size,))
self.assertEqual(batch['seq_id'].shape, (batch_size,))
def test_graph_retrieval_dataset(self):
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
batch_size = 4
seq_len = 256
dataset = paired_dataset.Graph2TextDataset(
tokenizer,
graph_tokenizer,
batch_size=batch_size,
timesteps=seq_len,
subsample_nodes=0.8,
graph_retrieval_dataset=True,
subset='valid',
data_dir=WIKIGRAPHS_ROOT)
data_iter = iter(dataset)
batch = next(data_iter)
self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
self.assertEqual(batch['target'].shape, (batch_size, seq_len))
self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
self.assertEqual(batch['graph_id'].shape, (batch_size,))
self.assertEqual(batch['seq_id'].shape, (batch_size,))
if __name__ == '__main__':
absltest.main()

View File

@@ -0,0 +1,230 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Tokenizers for text data."""
import abc
import csv
import io
import re
from typing import List
import nltk
import numpy as np
from wikigraphs.data import io_tools
class Tokenizer(abc.ABC):
"""Base class for tokenizers."""
@abc.abstractmethod
def encode(self,
inputs: str,
prepend_bos: bool = False,
append_eos: bool = False) -> np.ndarray:
"""Encode input string into an array of token IDs.
Args:
inputs: a string.
prepend_bos: set to True to add <bos> token at the beginning of the token
sequence.
append_eos: set to True to add <eos> token at the end of the token
sequence.
Returns:
tokens: [n_tokens] int array.
"""
@abc.abstractmethod
def decode(self, inputs) -> str:
"""Decode a sequence of tokens back into a string.
Args:
inputs: array or list of ints.
Returns:
s: the decoded string using this tokenizer.
"""
@property
@abc.abstractmethod
def vocab_size(self) -> int:
"""Size of the vocabulary."""
@abc.abstractmethod
def pad_token(self) -> int:
"""ID of the <pad> token."""
@abc.abstractmethod
def bos_token(self) -> int:
"""ID of the <bos> token."""
class WordTokenizer(Tokenizer):
"""Word-level tokenizer for white-space separated text data."""
def __init__(self, vocab_file: str):
"""Constructor.
Args:
vocab_file: a csv vocab file.
"""
content = io_tools.read_txt_file(vocab_file, encoding='utf-8')
with io.StringIO(content) as f:
r = csv.reader(f)
vocab = [w for w, _ in r]
# Add pad and bos tokens to the vocab
to_add = ['<pad>', '<bos>']
if '<unk>' not in vocab:
to_add.append('<unk>')
vocab = to_add + vocab
# token-index mappings
self._t2i = {t: i for i, t in enumerate(vocab)}
self._i2t = {i: t for t, i in self._t2i.items()}
self._unk_token = self._t2i['<unk>']
self._bos_token = self._t2i['<bos>']
self._pad_token = self._t2i['<pad>']
@property
def vocab_size(self):
return len(self._t2i)
def encode(self, inputs, prepend_bos=False, append_eos=False):
tokens = [self._t2i.get(t, self._unk_token) for t in inputs.split(' ') if t]
if prepend_bos:
tokens = [self._bos_token] + tokens
if append_eos:
# Reuse <bos> as <eos>.
tokens.append(self._bos_token)
return np.array(tokens, dtype=np.int32)
def decode(self, inputs):
"""Decode a sequence of token IDs back into a string."""
# Remove the first <bos> token if there is any.
if inputs[0] == self._bos_token:
inputs = inputs[1:]
tokens = []
for i in inputs:
# Use <bos> also as <eos> and stop there.
if i == self._bos_token:
break
tokens.append(self._i2t[i])
return ' '.join(tokens)
def pad_token(self):
return self._pad_token
def bos_token(self):
return self._bos_token
class GraphTokenizer:
"""Tokenizer for the content on the graphs."""
def __init__(self, vocab_file: str):
"""Constructor.
Args:
vocab_file: path to a vocab file.
"""
content = io_tools.read_txt_file(vocab_file, encoding='utf-16')
vocab = content.split('\n')
vocab = ['<pad>', '<bos>', '<unk>'] + vocab
# token-index mappings
self._t2i = {t: i for i, t in enumerate(vocab)}
self._i2t = {i: t for t, i in self._t2i.items()}
self._unk_token = self._t2i['<unk>']
self._bos_token = self._t2i['<bos>']
self._pad_token = self._t2i['<pad>']
@property
def vocab_size(self):
return len(self._t2i)
def encode_node(self, txt: str) -> np.ndarray:
return np.array([self._t2i.get(t, self._unk_token)
for t in self.split_node(txt)])
def encode_edge(self, txt: str) -> np.ndarray:
return np.array([self._t2i.get(t, self._unk_token)
for t in self.split_edge(txt)])
def encode(self, inputs, prepend_bos=False, append_eos=False):
tokens = [self._t2i.get(t, self._unk_token) for t in inputs.split(' ') if t]
if prepend_bos:
tokens = [self._bos_token] + tokens
if append_eos:
# Reuse <bos> as <eos>.
tokens.append(self._bos_token)
return np.array(tokens, dtype=np.int32)
def decode(self, inputs):
"""Decode a sequence of token IDs back into a string."""
# Remove the first <bos> token if there is any.
if inputs[0] == self._bos_token:
inputs = inputs[1:]
tokens = []
for i in inputs:
# Use <bos> also as <eos> and stop there.
if i == self._bos_token:
break
tokens.append(self._i2t[i])
return ' '.join(tokens)
@classmethod
def split_node(cls, txt: str) -> List[str]:
"""Split a node string into a sequence of tokens."""
if txt[0] == '"' and txt[-1] == '"': # Node is a string literal.
tokens = nltk.wordpunct_tokenize(io_tools.normalize_freebase_string(
txt[1:-1].lower()))
for i, t in enumerate(tokens):
if t.isnumeric():
tokens[i] = '<number>'
return tokens
else: # If node is not a string literal it is always an entity.
return ['<entity>']
@classmethod
def split_edge(cls, txt: str) -> List[str]:
"""Split an edge string into a sequence of tokens."""
return re.split('[._ ]+', txt.lower().split('/')[1])
def pad_token(self):
return self._pad_token
def bos_token(self):
return self._bos_token

View File

@@ -0,0 +1,78 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Tests for wikigraphs.data.tokenizers."""
from absl.testing import absltest
from wikigraphs.data import tokenizers
WIKITEXT_VOCAB_FILE = '/tmp/data/wikitext-vocab.csv'
GRAPH_VOCAB_FILE = '/tmp/data/graph-vocab.csv'
class TokenizerTest(absltest.TestCase):
def test_tokenizer(self):
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
# Vocab size must match published number.
self.assertEqual(tokenizer.vocab_size, 267735 + 2)
s = 'Hello world ! \n How are you ?'
encoded = tokenizer.encode(s, prepend_bos=True)
self.assertEqual(encoded.shape, (9,))
decoded = tokenizer.decode(encoded)
self.assertEqual(s, decoded)
def test_graph_tokenizer_tokenize_nodes_edges(self):
self.assertEqual(
tokenizers.GraphTokenizer.split_node(
'"Hello, how are you?"'),
['hello', ',', 'how', 'are', 'you', '?'])
self.assertEqual(
tokenizers.GraphTokenizer.split_node(
'"This building was built in 1998."'),
['this', 'building', 'was', 'built', 'in', '<number>', '.'])
self.assertEqual(
tokenizers.GraphTokenizer.split_node('ns/m.030ssw'),
['<entity>'])
self.assertEqual(
tokenizers.GraphTokenizer.split_edge('ns/common.topic.description'),
['common', 'topic', 'description'])
self.assertEqual(
tokenizers.GraphTokenizer.split_edge('ns/type.object.name'),
['type', 'object', 'name'])
def test_graph_tokenizer_vocab(self):
tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
self.assertEqual(tokenizer.vocab_size, 31087 + 3)
if __name__ == '__main__':
absltest.main()

View File

@@ -0,0 +1,242 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Some tools for processing data."""
from typing import Any, Iterator
from absl import logging
import numpy as np
def pad_to(x: np.array, size: int, axis: int = -1, pad_value: float = 0.):
"""Pad an array to the specified size along a specified axis."""
if x.shape[axis] > size:
raise ValueError(f'Data item has size {x.shape[axis]} larger than {size}'
f' in axis {axis} already.')
elif x.shape[axis] == size:
return x
else:
pad_amount = [(0, 0)] * x.ndim
pad_amount[axis] = (0, size - x.shape[axis])
return np.pad(x, pad_amount, mode='constant', constant_values=pad_value)
def dynamic_batch(
iterable: Iterator[Any],
batch_size: int,
timesteps: int,
return_incomplete_batch: bool = False,
pad: bool = False,
pad_value: float = 0.) -> Iterator[Any]:
"""Batches up values in iterable to [batch_size, timesteps].

  This function takes items from the iterable and packs them into batches.
  Sequence #i in a batch is a continuation of sequence #i in the previous
  batch, i.e. it starts from where the previous sequence left off. When an
  item is exhausted, a new item is taken from the iterable and appended to the
  sequence to fill up the batch.

  This function is designed for language modeling, where the input and the
  target sequences are offset by one token. We take that into account by
  making sure neighboring batches overlap by one token.

Example:
If the iterable contains [[0, 1, 2], [10, 11, 12, 13, 14], [20, 21, 22]] and
batch size is 2, timesteps is 3, then the first batch would be:
[[0, 1, 2],
[10, 11, 12]]
then the second batch:
[[2, 20, 21], # seq 0 finished, continuing from seq 2
[12, 13, 14]]
Note the overlap of 1 token between these two batches, and the continuation
of sequences across batches.
Args:
iterable: the iterable that yields sequences of integer token IDs.
batch_size: number of examples in a batch.
timesteps: length of each sequence in a batch.
    return_incomplete_batch: if True, return incomplete batches, which
      typically appear at the end of the dataset.
pad: set to True to pad the incomplete batches.
pad_value: the value to use for padding.
Yields:
    batches: dicts where batch['obs'] holds the observations of shape
      [batch_size, timesteps], and batch['should_reset'] is a 0/1 mask of the
      same shape that marks sequence boundaries, i.e. all entries are 0 except
      at locations where a new sequence starts.
"""
if return_incomplete_batch and not pad:
raise ValueError(
f'If return_incomplete_batch, then pad must be True, currently {pad}.')
iterator = iter(iterable)
elems = []
for _ in range(batch_size):
item = next(iterator)
elems.append(item)
start_batch = [True] * batch_size
iter_finished = False
loaded_finished = False
while not (iter_finished and loaded_finished):
batch = []
for i in range(batch_size):
# should_reset value is 1 when a new sequence begins.
# [old[-3], old[-2], old[-1], new[0], new[1], new[2]]
# [0, 0, 0, 1, 0, 0]
should_reset = np.zeros(timesteps, np.float32)
if start_batch[i]:
should_reset[0] = 1
      # Pack new examples into the sequence until it covers the required
      # number of timesteps.
while len(elems[i]) < timesteps:
should_reset[len(elems[i])] = 1
try:
item = next(iterator)
except StopIteration:
iter_finished = True
break
elems[i] = np.concatenate([elems[i], item])
batch.append(dict(obs=elems[i][:timesteps], should_reset=should_reset))
# Shift and make sure we have a 1 token overlap.
elems[i] = elems[i][timesteps - 1:]
      # Since the last token is shifted to be the first token of the next
      # batch, we need to make sure the reset is handled properly as well.
start_batch[i] = (should_reset[-1] == 1)
# If any loaded data is not yet consumed in the output we should keep
# generating.
loaded_finished = all(e.size == 0 for e in elems)
if not return_incomplete_batch:
elem_len = len(batch[0]['obs'])
if (elem_len != timesteps or
not all(len(x['obs']) == elem_len for x in batch[1:])):
logging.info('Dropping the (last?) incomplete batch.')
break
if pad:
for x in batch:
x['obs'] = pad_to(x['obs'], timesteps, axis=0, pad_value=pad_value)
yield dict(
obs=np.stack([x['obs'] for x in batch], axis=0),
should_reset=np.stack([x['should_reset'] for x in batch], axis=0))
def batch_graph_text_pairs(
iterable: Iterator[Any],
batch_size: int,
timesteps: int,
pad_value: float = 0.,
seq_and_graph_id: bool = False) -> Iterator[Any]:
"""Batch graph and text pairs.

  This function pairs text with graphs: each text sequence is split into
  chunks of size `timesteps` (with an overlap of 1 token), and the graph
  associated with the text is attached to each of its chunks. The last,
  incomplete chunk of each text sequence is padded with `pad_value`.

Args:
    iterable: Iterable that yields (graph, sequence) pairs, where graph can be
      anything and sequence is an array of token IDs.
batch_size: Number of examples in a batch.
timesteps: Window size for the sequences.
pad_value: Value to use for padding.
    seq_and_graph_id: whether each item from `iterable` also carries a
      (`graph_id`, `seq_id`) pair.
Yields:
    batch: a batch of text sequences paired with graphs.
"""
iterator = iter(iterable)
seqs = [None] * batch_size
graphs = [None] * batch_size
graph_ids = [None] * batch_size
seq_ids = [None] * batch_size
iter_finished = False
loaded_finished = False
while not (iter_finished and loaded_finished):
batch = []
for idx in range(batch_size):
should_reset = np.zeros(timesteps, np.float32)
# pylint: disable=g-explicit-length-test
if seqs[idx] is None or len(seqs[idx]) == 0:
should_reset[0] = 1
# One sequence exhausted, get the next example.
try:
if seq_and_graph_id:
(graph, seq), (graph_id, seq_id) = next(iterator)
graph_ids[idx] = graph_id
seq_ids[idx] = seq_id
else:
graph, seq = next(iterator)
seqs[idx] = seq
graphs[idx] = graph
except StopIteration:
iter_finished = True
seqs[idx] = np.array([pad_value], dtype=np.int32)
graphs[idx] = None
example = dict(obs=seqs[idx][:timesteps], graph=graphs[idx],
should_reset=should_reset)
if seq_and_graph_id:
example['seq_id'] = seq_ids[idx]
example['graph_id'] = graph_ids[idx]
batch.append(example)
      # Make sure that there is a 1-token overlap, as we generate targets by
      # shifting the tensor by 1 timestep, so the next chunk starts
      # `timesteps - 1` tokens later.
seqs[idx] = seqs[idx][timesteps - 1:]
# Make sure all loaded data are consumed in the output
loaded_finished = all(s.size == 0 for s in seqs)
# Also check for the last batch to avoid returning a fully empty batch
if iter_finished and all([np.all(b['obs'] == pad_value) for b in batch]):
break
# pad sequences to specified length
for e in batch:
e['obs'] = pad_to(e['obs'], timesteps, axis=0, pad_value=pad_value)
stacked_batch = dict(
obs=np.stack([e['obs'] for e in batch], axis=0),
graphs=[e['graph'] for e in batch],
should_reset=np.stack([e['should_reset'] for e in batch], axis=0))
if seq_and_graph_id:
stacked_batch['seq_id'] = np.stack(
[e['seq_id'] for e in batch], axis=0)
stacked_batch['graph_id'] = np.stack(
[e['graph_id'] for e in batch], axis=0)
yield stacked_batch
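
Since `dynamic_batch` is the workhorse of the data pipeline above, a short sketch of its behaviour on the toy sequences from its docstring may help; `toy_source` is a purely illustrative helper, and the commented outputs restate the docstring example (with the default `return_incomplete_batch=False`, trailing incomplete windows are dropped).

```
import numpy as np
from wikigraphs.data import tools

def toy_source():
  for seq in ([0, 1, 2], [10, 11, 12, 13, 14], [20, 21, 22]):
    yield np.array(seq, dtype=np.int32)

batches = list(tools.dynamic_batch(toy_source(), batch_size=2, timesteps=3))
# batches[0]['obs']: [[0, 1, 2], [10, 11, 12]]
# batches[1]['obs']: [[2, 20, 21], [12, 13, 14]]  (one-token overlap with the
#                                                  previous batch)
```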

View File

@@ -0,0 +1,195 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Tests for wikigraphs.data.tools."""
from absl.testing import absltest
import numpy as np
from wikigraphs.data import tools
class ToolsTest(absltest.TestCase):
def test_padding(self):
np.testing.assert_array_equal(
tools.pad_to(np.arange(3), 5),
[0, 1, 2, 0, 0])
np.testing.assert_array_equal(
tools.pad_to(np.arange(3), 5, pad_value=-1),
[0, 1, 2, -1, -1])
np.testing.assert_array_equal(
tools.pad_to(np.arange(6).reshape(2, 3), 4, axis=0, pad_value=-1),
[[0, 1, 2],
[3, 4, 5],
[-1, -1, -1],
[-1, -1, -1]])
np.testing.assert_array_equal(
tools.pad_to(np.arange(6).reshape(2, 3), 4, axis=-1, pad_value=-1),
[[0, 1, 2, -1],
[3, 4, 5, -1]])
def test_dynamic_batch(self):
def dataset():
data = [[1, 2, 2, 2],
[1, 3, 3],
[1, 4]]
for d in data:
yield np.array(d, dtype=np.int32)
batches = list(tools.dynamic_batch(
dataset(), batch_size=2, timesteps=3, return_incomplete_batch=False))
self.assertLen(batches, 1)
np.testing.assert_array_equal(
batches[0]['obs'],
[[1, 2, 2], [1, 3, 3]])
np.testing.assert_array_equal(
batches[0]['should_reset'],
[[1, 0, 0], [1, 0, 0]])
batches = list(tools.dynamic_batch(
dataset(), batch_size=2, timesteps=3, return_incomplete_batch=True,
pad=True, pad_value=0))
    # Note: `return_incomplete_batch=False` (used above) drops every
    # incomplete batch, which can be more than just the last one; with
    # padding enabled here we get all 3 batches.
self.assertLen(batches, 3)
np.testing.assert_array_equal(
batches[0]['obs'],
[[1, 2, 2], [1, 3, 3]])
np.testing.assert_array_equal(
batches[0]['should_reset'],
[[1, 0, 0], [1, 0, 0]])
np.testing.assert_array_equal(
batches[1]['obs'],
[[2, 2, 1], [3, 0, 0]])
np.testing.assert_array_equal(
batches[1]['should_reset'],
[[0, 0, 1], [0, 1, 0]])
np.testing.assert_array_equal(
batches[2]['obs'],
[[1, 4, 0], [0, 0, 0]])
np.testing.assert_array_equal(
batches[2]['should_reset'],
[[1, 0, 1], [1, 0, 0]])
with self.assertRaises(ValueError):
batches = list(tools.dynamic_batch(
dataset(), batch_size=2, timesteps=3, return_incomplete_batch=True,
pad=False))
def test_batch_graph_text_pairs(self):
def source():
yield (1, np.array([1, 1, 1, 1, 1], dtype=np.int32))
yield (2, np.array([2, 2], dtype=np.int32))
yield (3, np.array([3, 3, 3, 3, 3, 3], dtype=np.int32))
data_iter = tools.batch_graph_text_pairs(
source(), batch_size=2, timesteps=3, pad_value=0)
batches = list(data_iter)
self.assertLen(batches, 4)
batch = batches[0]
np.testing.assert_array_equal(
batch['obs'],
[[1, 1, 1],
[2, 2, 0]])
self.assertEqual(batch['graphs'], [1, 2])
np.testing.assert_array_equal(
batch['should_reset'],
[[1, 0, 0],
[1, 0, 0]])
batch = batches[1]
np.testing.assert_array_equal(
batch['obs'],
[[1, 1, 1],
[3, 3, 3]])
self.assertEqual(batch['graphs'], [1, 3])
np.testing.assert_array_equal(
batch['should_reset'],
[[0, 0, 0],
[1, 0, 0]])
batch = batches[2]
np.testing.assert_array_equal(
batch['obs'],
[[1, 0, 0],
[3, 3, 3]])
self.assertEqual(batch['graphs'], [1, 3])
np.testing.assert_array_equal(
batch['should_reset'],
[[0, 0, 0],
[0, 0, 0]])
batch = batches[3]
np.testing.assert_array_equal(
batch['obs'],
[[0, 0, 0],
[3, 3, 0]])
self.assertEqual(batch['graphs'], [None, 3])
np.testing.assert_array_equal(
batch['should_reset'],
[[1, 0, 0],
[0, 0, 0]])
def test_batch_graph_text_pairs_batch_size1(self):
def source():
yield (0, np.array([1, 2], dtype=np.int32))
yield (1, np.array([1, 2, 3, 4, 5, 6], dtype=np.int32))
data_iter = tools.batch_graph_text_pairs(
source(), batch_size=1, timesteps=3, pad_value=0)
batches = list(data_iter)
batch = batches[0]
np.testing.assert_array_equal(batch['obs'], [[1, 2, 0]])
self.assertEqual(batch['graphs'], [0])
np.testing.assert_array_equal(batch['should_reset'], [[1, 0, 0]])
batch = batches[1]
np.testing.assert_array_equal(batch['obs'], [[1, 2, 3]])
self.assertEqual(batch['graphs'], [1])
np.testing.assert_array_equal(batch['should_reset'], [[1, 0, 0]])
batch = batches[2]
np.testing.assert_array_equal(batch['obs'], [[3, 4, 5]])
self.assertEqual(batch['graphs'], [1])
np.testing.assert_array_equal(batch['should_reset'], [[0, 0, 0]])
batch = batches[3]
np.testing.assert_array_equal(batch['obs'], [[5, 6, 0]])
self.assertEqual(batch['graphs'], [1])
np.testing.assert_array_equal(batch['should_reset'], [[0, 0, 0]])
self.assertLen(batches, 4)
if __name__ == '__main__':
absltest.main()

View File

@@ -0,0 +1,218 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Wikitext-103 datasets."""
import re
from typing import NamedTuple, List
from absl import logging
import numpy as np
from wikigraphs.data import dataset
from wikigraphs.data import tokenizers
from wikigraphs.data import tools
# The data directory that contains subdirectories `wikitext-103` and
# `wikitext-103-raw`.
DATA_ROOT = '/tmp/data/wikitext-103'
class WikitextArticle(NamedTuple):
title: str
text: str
def articles_from_file(file_path: str) -> List[WikitextArticle]:
"""Read wikitext articles from file.
Args:
file_path: path to the input `.tokens` file.
Returns:
A list of `WikitextArticle` tuples.
"""
with open(file_path, mode='rb') as f:
content = f.read()
content = content.decode('utf-8')
title_re = re.compile(r'(\n = ([^=].*) = \n \n)')
parts = title_re.split(content)
# Skip the first part which is empty
return [WikitextArticle(title=parts[i+1], text=parts[i] + parts[i+2])
for i in range(1, len(parts), 3)]
class RawDataset(dataset.Dataset):
"""Raw text dataset for wikitext-103."""
def __init__(self,
subset: str = 'train',
shuffle_data: bool = False,
data_dir: str = None,
version: str = 'tokens'):
"""Constructor.
Args:
subset: which subset to load, one of {"train", "valid", "test"}.
shuffle_data: if set to True the data will be randomly shuffled.
data_dir: if provided will be used instead of the default `DATA_ROOT` as
the directory that contains the data.
version: one of {'tokens', 'raw'}
"""
super().__init__()
self._subset = subset
self._shuffle_data = shuffle_data
self._data_dir = data_dir or DATA_ROOT
self._dataset = None
allowed_versions = ('tokens', 'raw')
if version not in allowed_versions:
raise ValueError(f'Version must be one of {allowed_versions}.')
self._version = version
def _load_data(self):
"""Prepare data for another pass through the dataset."""
if self._dataset is None:
data_root = self._data_dir + ('-raw' if self._version == 'raw' else '')
self._dataset = articles_from_file(
f'{data_root}/wiki.{self._subset}.{self._version}')
def source():
n_articles = len(self._dataset)
if self._shuffle_data:
idx = np.random.permutation(n_articles)
else:
idx = np.arange(n_articles)
for i in range(n_articles):
yield self._dataset[idx[i]]
return source()
def normalize_title(title: str) -> str:
"""Normalize the wikitext article title by handling special characters."""
return title.replace(
'@-@', '-').replace('@,@', ',').replace('@.@', '.').replace(' ', '')
class WikitextDataset(dataset.Dataset):
"""Tokenized dataset for wikitext-103."""
def __init__(self,
tokenizer: tokenizers.Tokenizer,
batch_size: int = 1,
timesteps: int = 128,
subset: str = 'train',
shuffle_data: bool = True,
data_dir: str = None,
repeat: bool = False,
debug: bool = False,
**kwargs):
"""Constructor.
Args:
tokenizer: a tokenizer for text data.
batch_size: number of sequences to put into a batch.
timesteps: length of the sequences.
subset: which subset to load, one of {"train", "valid", "test"}.
shuffle_data: if set to True the data will be randomly shuffled.
data_dir: if provided will be used instead of the default `DATA_ROOT` as
the directory that contains the data.
repeat: set to False to go through the data only once, otherwise go
through the data indefinitely.
debug: set to True to only load a small amount of data for fast debugging.
**kwargs: other arguments (for interface compatibility).
"""
super().__init__()
self._tokenizer = tokenizer
self._batch_size = batch_size
self._timesteps = timesteps
self._subset = subset
self._shuffle_data = shuffle_data
self._data_dir = data_dir
self._repeat = repeat
self._debug = debug
self._dataset = None
def _load_data(self):
"""Prepare data for one pass through the dataset."""
    # Pre-tokenize everything in our dataset so we don't have to tokenize it
    # again when going through the data more than once.
if not self._dataset:
raw_dataset = RawDataset(
subset=self._subset, shuffle_data=False, data_dir=self._data_dir)
if self._debug:
# Load a small number of examples for debugging.
self._dataset = [
self._tokenizer.encode(next(raw_dataset).text, prepend_bos=True)
for _ in range(5)]
else:
self._dataset = [self._tokenizer.encode(item.text, prepend_bos=True)
for item in raw_dataset]
logging.info('%s set loaded, total %d examples.',
self._subset, len(self._dataset))
def source():
idx = np.random.permutation(len(self._dataset))
for i in idx:
yield self._dataset[i]
def repeated_source():
if self._repeat:
while True:
yield from source()
else:
yield from source()
data_iter = tools.dynamic_batch(
repeated_source(),
self._batch_size,
        self._timesteps + 1,  # Extra token to account for the overlap.
return_incomplete_batch=True,
pad=True,
pad_value=self._tokenizer.pad_token())
data_iter = map(lambda x: dict( # pylint: disable=g-long-lambda
obs=x['obs'][:, :-1],
target=x['obs'][:, 1:],
should_reset=x['should_reset'][:, :-1],
mask=(x['obs'][:, 1:] != self._tokenizer.pad_token()).astype(
np.float32),
), data_iter)
return data_iter
def return_faux_batch(self):
"""Return a fake batch with the right shapes and dtypes."""
obs = np.zeros((self._batch_size, self._timesteps), dtype=np.int32)
target = np.zeros_like(obs, dtype=np.int32)
should_reset = np.zeros_like(obs, dtype=np.float32)
mask = np.zeros_like(obs, dtype=np.float32)
return dict(obs=obs, target=target, should_reset=should_reset, mask=mask)
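
To make the shape conventions of `WikitextDataset` concrete, here is a hedged sketch of pulling a single batch. It assumes the WikiText-103 files and the vocab file are already in place under `/tmp/data` (as in the tests below); the commented shapes follow from the `batch_size` and `timesteps` arguments and from `return_faux_batch` above.

```
from wikigraphs.data import tokenizers, wikitext

tokenizer = tokenizers.WordTokenizer(vocab_file='/tmp/data/wikitext-vocab.csv')
data = wikitext.WikitextDataset(
    tokenizer=tokenizer, batch_size=4, timesteps=256,
    subset='valid', shuffle_data=False, repeat=False)
batch = next(iter(data))
# batch['obs'] and batch['target'] both have shape [4, 256] and are offset by
# one token; batch['mask'] is 1.0 wherever the target is a real (non-padding)
# token and 0.0 elsewhere.
```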

View File

@@ -0,0 +1,84 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Tests for wikigraphs.data.wikitext."""
from absl.testing import absltest
from wikigraphs.data import tokenizers
from wikigraphs.data import wikitext
WIKITEXT_ROOT = '/tmp/data/wikitext-103'
WIKITEXT_VOCAB_FILE = '/tmp/data/wikitext-vocab.csv'
class WikitextTest(absltest.TestCase):
def test_wikitext_size(self):
valid_set = wikitext.RawDataset(
subset='valid', shuffle_data=False, data_dir=WIKITEXT_ROOT)
n_tokens = 0
n_articles = 0
for article in valid_set:
n_tokens += len([t for t in article.text.split(' ') if t])
n_articles += 1
# Dataset size must match published values.
self.assertEqual(n_tokens, 217646)
self.assertEqual(n_articles, 60)
def test_wikitext_dataset_size(self):
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
batch_size = 4
timesteps = 256
valid_set = wikitext.WikitextDataset(
tokenizer=tokenizer, batch_size=batch_size, timesteps=timesteps,
subset='valid', shuffle_data=False, repeat=False,
data_dir=WIKITEXT_ROOT)
n_tokens = 0
n_bos = 0
for batch in valid_set:
n_tokens += (batch['obs'] != tokenizer.pad_token()).sum()
n_bos += (batch['obs'] == tokenizer.bos_token()).sum()
self.assertEqual(
batch['obs'].shape, (batch_size, timesteps))
self.assertEqual(
batch['target'].shape, (batch_size, timesteps))
self.assertEqual(
batch['should_reset'].shape, (batch_size, timesteps))
self.assertEqual(
batch['mask'].shape, (batch_size, timesteps))
n_tokens -= n_bos
self.assertEqual(n_tokens, 217646)
self.assertEqual(n_bos, 60)
if __name__ == '__main__':
absltest.main()
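
All of the data tests above read their fixtures from `/tmp/data`, so they only pass once the WikiText-103 files and the two vocab files (`wikitext-vocab.csv`, `graph-vocab.csv`) have been placed there. Assuming that setup and an environment with the `wikigraphs` package on the Python path, each test module can presumably be run directly via absltest, e.g. (assuming the module above is saved as `wikitext_test.py`):

```
python3 wikitext_test.py
```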