Mirror of https://github.com/google-deepmind/deepmind-research.git
(synced 2026-02-06 03:32:18 +08:00)

Added jaxline pipeline to train adversarially robust models.

PiperOrigin-RevId: 383399487
Committed by Louise Deason
Parent: d8df4155dc
Commit: 5909da5388
@@ -13,7 +13,7 @@ We have released our top-performing models in two formats compatible with
[JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/).
This repository also contains our model definitions.

## Running the example code
## Running the code

### Downloading a model

@@ -47,10 +47,32 @@ The following table contains the models from **Rebuffi et al., 2021**.

| CIFAR-100 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✗ | 63.56% | 34.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.pt)
| CIFAR-100 | ℓ<sub>∞</sub> | 8 / 255 | WRN-28-10 | ✗ | 62.41% | 32.06% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.pt)

### Using the model
### Installing

Once downloaded, a model can be evaluated (clean accuracy) by running the
`eval.py` script in either the `jax` or `pytorch` folders. E.g.:
The following has been tested using Python 3.9.2.
Using `run.sh` will create and activate a virtualenv, install all necessary
dependencies and run a test program to ensure that you can import all the
modules.

```
# Run from the parent directory.
sh adversarial_robustness/run.sh
```

To run the provided code, use this virtualenv:

```
source /tmp/adversarial_robustness_venv/bin/activate
```

You may want to edit `requirements.txt` before running `run.sh` if GPU support
is needed (e.g., use `jaxlib==0.1.67+cuda111`). See JAX's installation
[instructions](https://github.com/google/jax#installation) for more details.

### Using pre-trained models

Once downloaded, a model can be evaluated by running the `eval.py` script in
either the `jax` or `pytorch` folders. E.g.:

```
cd jax
@@ -58,7 +80,47 @@ python3 eval.py \
  --ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
```
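
For reference, a minimal sketch (not part of the original README) of loading a
downloaded JAX checkpoint programmatically, mirroring what `jax/eval.py` does
internally; the file name below is the CIFAR-100 WRN-70-16 checkpoint from the
table above:

```
import numpy as np

# The .npy checkpoint stores a (params, state) pair, as unpacked by eval.py.
params, state = np.load('cifar100_linf_wrn70-16_cutmix_ddpm.npy',
                        allow_pickle=True)
```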

## Generated datasets
These models are also directly available within
[RobustBench](https://github.com/RobustBench/robustbench#model-zoo-quick-tour)'s
model zoo.

### Training your own model

We also provide a training pipeline that reproduces results from both
publications. This pipeline uses [Jaxline](https://github.com/deepmind/jaxline)
and is written using [JAX](https://github.com/google/jax) and
[Haiku](https://github.com/deepmind/dm-haiku). To train a model, modify the
configuration in the `get_config()` function of `jax/experiment.py` and issue
the following command from within the virtualenv created above:

```
cd jax
python3 train.py --config=experiment.py
```
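
As an illustration, the configuration exposes fields like the ones below (the
field names are taken from `jax/experiment_test.py`); treat this as a sketch
rather than the definitive schema of `get_config()`:

```
from adversarial_robustness.jax import experiment

config = experiment.get_config()
exp_config = config.experiment_kwargs.config
exp_config.model.kwargs.depth = 70     # e.g. a WRN-70-16
exp_config.model.kwargs.width = 16
exp_config.training.batch_size = 1024  # illustrative value only
```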

The training pipeline can run with multiple worker machines and multiple devices
(either GPU or TPU). See [Jaxline](https://github.com/deepmind/jaxline) for more
details.

We do not provide a PyTorch implementation of our training pipeline. However,
you may find one on GitHub, e.g.,
[adversarial_robustness_pytorch](https://github.com/imrahulr/adversarial_robustness_pytorch)
(by Rahul Rade).

## Datasets

### Extracted dataset

Gowal et al. (2020) use samples extracted from
[TinyImages-80M](https://groups.csail.mit.edu/vision/TinyImages/).
Unfortunately, since then, the official TinyImages-80M dataset has been
withdrawn (due to the presence of offensive images). As such, we cannot provide
a download link to our extracted data until we have manually verified that all
extracted images are not offensive. If you want to reproduce our setup, consider
the generated datasets below. We are also happy to help, so feel free to reach
out to Sven Gowal directly.

### Generated datasets

Rebuffi et al. (2021) use samples generated by a Denoising Diffusion
Probabilistic Model [(DDPM; Ho et al., 2020)](https://arxiv.org/abs/2006.11239)
@@ -82,8 +144,8 @@ labels = npzfile['label']
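
The hunk above references code that loads the released `.npz` samples; a
minimal sketch of what such loading looks like (the file name `cifar10_ddpm.npz`
and the `image`/`label` keys are taken from `jax/datasets.py`):

```
import numpy as np

npzfile = np.load('cifar10_ddpm.npz')
images = npzfile['image']  # array of generated images
labels = npzfile['label']  # corresponding labels
```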

## Citing this work

If you use this code, data or these models in your work, please cite the
relevant accompanying paper:
If you use this code (or any derived code), data or these models in your work,
please cite the relevant accompanying paper:

```
@article{gowal2020uncovering,
@@ -95,7 +157,7 @@ relevant accompanying paper:
}
```

or
and/or

```
@article{rebuffi2021fixing,
adversarial_robustness/jax/attacks.py (new file, 310 lines)
@@ -0,0 +1,310 @@
|
||||
# Copyright 2021 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Adversarial attacks.
|
||||
|
||||
This file contains all the code necessary to create untargeted adversarial
|
||||
attacks in JAX (within an l-infinity ball). For example, to create an untargeted
|
||||
FGSM attack (with a single step), one can do the following:
|
||||
|
||||
```
|
||||
import attacks
|
||||
|
||||
epsilon = 8/255 # Perturbation radius for inputs between 0 and 1.
|
||||
fgsm_attack = attacks.UntargetedAttack(
|
||||
attacks.PGD(
|
||||
attacks.IteratedFGSM(epsilon),
|
||||
num_steps=1,
|
||||
initialize_fn=attacks.linf_initialize_fn(epsilon),
|
||||
project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
|
||||
loss_fn=attacks.untargeted_cross_entropy)
|
||||
```
|
||||
|
||||
Just as elegantly, one can specify an adversarial attack on KL-divergence
|
||||
to a target distribution (using 10 steps with Adam and a piecewise constant step
|
||||
schedule):
|
||||
|
||||
```
|
||||
kl_attack_with_adam = attacks.UntargetedAttack(
|
||||
attacks.PGD(
|
||||
attacks.Adam(optax.piecewise_constant_schedule(
|
||||
init_value=.1,
|
||||
boundaries_and_scales={5: .1})),
|
||||
num_steps=10,
|
||||
initialize_fn=attacks.linf_initialize_fn(epsilon),
|
||||
project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
|
||||
loss_fn=attacks.untargeted_kl_divergence)
|
||||
```
|
||||
|
||||
The attack instances can be used later on to build adversarial examples:
|
||||
|
||||
```
|
||||
my_model = ... # Model. We assume that 'my_model(.)' returns logits.
|
||||
clean_images, image_labels = ... # Batch of images and associated labels.
|
||||
rng = jax.random.PRNGKey(0) # A random generator state.
|
||||
|
||||
adversarial_images = fgsm_attack(my_model, rng, clean_images, image_labels)
|
||||
```
|
||||
|
||||
See `experiment.py` or `eval.py` for more examples.
|
||||
|
||||
This file contains the following components:
|
||||
* Losses:
|
||||
* untargeted_cross_entropy: minimizes the likelihood of the label class.
|
||||
* untargeted_kl_divergence: maximizes the KL-divergence of the predictions with
|
||||
a target distribution.
|
||||
* untargeted_margin: maximizes the margin loss (distance from the highest
|
||||
non-true logits to the label class logit)
|
||||
* Step optimizers:
|
||||
* SGD: Stochastic Gradient Descent.
|
||||
* IteratedFGSM: Also called BIM (see https://arxiv.org/pdf/1607.02533).
|
||||
* Adam: See https://arxiv.org/pdf/1412.6980.
|
||||
* Initialization and projection functions:
|
||||
* linf_initialize_fn: Initialize function for l-infinity attacks.
|
||||
* linf_project_fn: Projection function for l-infinity attacks.
|
||||
* Projected Gradient Descent (PGD):
|
||||
* PGD: Runs Projected Gradient Descent using the specified optimizer,
|
||||
initialization and projection functions for a given number of steps.
|
||||
* Untargeted attack:
|
||||
* UntargetedAttack: Combines PGD and a specific loss function to find
|
||||
adversarial examples.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import chex
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
|
||||
|
||||
ModelFn = Callable[[chex.Array], chex.Array]
|
||||
LossFn = Callable[[chex.Array], chex.Array]
|
||||
ClassificationLossFn = Callable[[chex.Array, chex.Array], chex.Array]
|
||||
OptimizeFn = Callable[[LossFn, chex.PRNGKey, chex.Array], chex.Array]
|
||||
NormalizeFn = Callable[[chex.Array], chex.Array]
|
||||
InitializeFn = Callable[[chex.PRNGKey, chex.Array], chex.Array]
|
||||
ProjectFn = Callable[[chex.Array, chex.Array], chex.Array]
|
||||
|
||||
|
||||
def untargeted_cross_entropy(logits: chex.Array,
|
||||
labels: chex.Array) -> chex.Array:
|
||||
"""Maximize the cross-entropy of the true class (make it less likely)."""
|
||||
num_classes = logits.shape[-1]
|
||||
log_probs = jax.nn.log_softmax(logits)
|
||||
return jnp.sum(
|
||||
hk.one_hot(labels, num_classes).astype(logits.dtype) * log_probs, axis=-1)
|
||||
|
||||
|
||||
def untargeted_kl_divergence(logits: chex.Array,
|
||||
label_probs: chex.Array) -> chex.Array:
|
||||
"""Maximize the KL divergence between logits and label distribution."""
|
||||
# We are explicitly maximizing the cross-entropy, as this is equivalent to
|
||||
# maximizing the KL divergence (when `label_probs` does not depend
|
||||
# on the values that produce `logits`).
|
||||
log_probs = jax.nn.log_softmax(logits)
|
||||
return jnp.sum(label_probs * log_probs, axis=-1)
|
||||
|
||||
|
||||
def untargeted_margin(logits: chex.Array,
|
||||
labels: chex.Array) -> chex.Array:
|
||||
"""Make the highest non-correct logits higher than the true class logits."""
|
||||
batch_size = logits.shape[0]
|
||||
num_classes = logits.shape[-1]
|
||||
label_logits = logits[jnp.arange(batch_size), labels]
|
||||
logit_mask = hk.one_hot(labels, num_classes).astype(logits.dtype)
|
||||
highest_logits = jnp.max(logits - 1e8 * logit_mask, axis=-1)
|
||||
return label_logits - highest_logits
|
||||
|
||||
|
||||
class UntargetedAttack:
|
||||
"""Performs an untargeted attack."""
|
||||
|
||||
def __init__(self,
|
||||
optimize_fn: OptimizeFn,
|
||||
loss_fn: ClassificationLossFn = untargeted_cross_entropy):
|
||||
"""Creates an untargeted attack.
|
||||
|
||||
Args:
|
||||
optimize_fn: An `Optimizer` instance or any callable that takes
|
||||
a loss function and an initial input and outputs a new input that
|
||||
minimizes the loss function.
|
||||
loss_fn: `loss_fn` is a surrogate loss. Its goal should be make the true
|
||||
class less likely than any other class. Typical options for `loss_fn`
|
||||
are `untargeted_cross_entropy` or `untargeted_margin`.
|
||||
"""
|
||||
self._optimize_fn = optimize_fn
|
||||
self._loss_fn = loss_fn
|
||||
|
||||
def __call__(self,
|
||||
logits_fn: ModelFn,
|
||||
rng: chex.PRNGKey,
|
||||
inputs: chex.Array,
|
||||
labels: chex.Array) -> chex.Array:
|
||||
"""Returns adversarial inputs."""
|
||||
def _loss_fn(x):
|
||||
return self._loss_fn(logits_fn(x), labels)
|
||||
return self._optimize_fn(_loss_fn, rng, inputs)
|
||||
|
||||
# Convenience functions to detect the type of inputs required by the loss.
|
||||
def expects_labels(self):
|
||||
return 'labels' in inspect.getfullargspec(self._loss_fn).args
|
||||
|
||||
def expects_probabilities(self):
|
||||
return 'label_probs' in inspect.getfullargspec(self._loss_fn).args
|
||||
|
||||
|
||||
class StepOptimizer:
|
||||
"""Makes a single gradient step that minimizes a loss function."""
|
||||
|
||||
def __init__(self,
|
||||
gradient_transformation: optax.GradientTransformation):
|
||||
self._gradient_transformation = gradient_transformation
|
||||
|
||||
def init(self,
|
||||
loss_fn: LossFn,
|
||||
x: chex.Array) -> optax.OptState:
|
||||
self._loss_fn = loss_fn
|
||||
return self._gradient_transformation.init(x)
|
||||
|
||||
def minimize(
|
||||
self,
|
||||
x: chex.Array,
|
||||
state: optax.OptState) -> Tuple[chex.Array, chex.Array, optax.OptState]:
|
||||
"""Performs a single minimization step."""
|
||||
g, loss = gradients_fn(self._loss_fn, x)
|
||||
if g is None:
|
||||
raise ValueError('loss_fn does not depend on input.')
|
||||
updates, state = self._gradient_transformation.update(g, state, x)
|
||||
return optax.apply_updates(x, updates), loss, state
|
||||
|
||||
|
||||
class SGD(StepOptimizer):
|
||||
"""Vanilla gradient descent optimizer."""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate_fn: Union[float, int, optax.Schedule],
|
||||
normalize_fn: Optional[NormalizeFn] = None):
|
||||
# Accept schedules, as well as scalar values.
|
||||
if isinstance(learning_rate_fn, (float, int)):
|
||||
lr = float(learning_rate_fn)
|
||||
learning_rate_fn = lambda _: lr
|
||||
# Normalization.
|
||||
def update_fn(updates, state, params=None):
|
||||
del params
|
||||
updates = jax.tree_map(normalize_fn or (lambda x: x), updates)
|
||||
return updates, state
|
||||
gradient_transformation = optax.chain(
|
||||
optax.GradientTransformation(lambda _: optax.EmptyState(), update_fn),
|
||||
optax.scale_by_schedule(learning_rate_fn),
|
||||
optax.scale(-1.))
|
||||
super(SGD, self).__init__(gradient_transformation)
|
||||
|
||||
|
||||
class IteratedFGSM(SGD):
|
||||
"""L-infinity normalized steps."""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate_fn: Union[float, int, optax.Schedule]):
|
||||
super(IteratedFGSM, self).__init__(learning_rate_fn, jnp.sign)
|
||||
|
||||
|
||||
class Adam(StepOptimizer):
|
||||
"""The Adam optimizer defined in https://arxiv.org/abs/1412.6980."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate_fn: Union[float, int, optax.Schedule],
|
||||
normalize_fn: Optional[NormalizeFn] = None,
|
||||
beta1: float = .9,
|
||||
beta2: float = .999,
|
||||
epsilon: float = 1e-9):
|
||||
# Accept schedules, as well as scalar values.
|
||||
if isinstance(learning_rate_fn, (float, int)):
|
||||
lr = float(learning_rate_fn)
|
||||
learning_rate_fn = lambda _: lr
|
||||
# Normalization.
|
||||
def update_fn(updates, state, params=None):
|
||||
del params
|
||||
updates = jax.tree_map(normalize_fn or (lambda x: x), updates)
|
||||
return updates, state
|
||||
gradient_transformation = optax.chain(
|
||||
optax.GradientTransformation(lambda _: optax.EmptyState(), update_fn),
|
||||
optax.scale_by_adam(b1=beta1, b2=beta2, eps=epsilon),
|
||||
optax.scale_by_schedule(learning_rate_fn),
|
||||
optax.scale(-1.))
|
||||
super(Adam, self).__init__(gradient_transformation)
|
||||
|
||||
|
||||
class PGD:
|
||||
"""Runs Project Gradient Descent (see https://arxiv.org/pdf/1706.06083)."""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: StepOptimizer,
|
||||
num_steps: int,
|
||||
initialize_fn: Optional[InitializeFn] = None,
|
||||
project_fn: Optional[ProjectFn] = None):
|
||||
self._optimizer = optimizer
|
||||
if initialize_fn is None:
|
||||
initialize_fn = lambda rng, x: x
|
||||
self._initialize_fn = initialize_fn
|
||||
if project_fn is None:
|
||||
project_fn = lambda x, origin_x: x
|
||||
self._project_fn = project_fn
|
||||
self._num_steps = num_steps
|
||||
|
||||
def __call__(self,
|
||||
loss_fn: LossFn,
|
||||
rng: chex.PRNGKey,
|
||||
x: chex.Array) -> chex.Array:
|
||||
def _optimize(rng, x):
|
||||
"""Optimizes loss_fn when keep_best is False."""
|
||||
def body_fn(_, inputs):
|
||||
opt_state, current_x = inputs
|
||||
current_x, _, opt_state = self._optimizer.minimize(current_x, opt_state)
|
||||
current_x = self._project_fn(current_x, x)
|
||||
return opt_state, current_x
|
||||
opt_state = self._optimizer.init(loss_fn, x)
|
||||
current_x = self._project_fn(self._initialize_fn(rng, x), x)
|
||||
_, current_x = jax.lax.fori_loop(0, self._num_steps, body_fn,
|
||||
(opt_state, current_x))
|
||||
return current_x
|
||||
return jax.lax.stop_gradient(_optimize(rng, x))
|
||||
|
||||
|
||||
def linf_project_fn(epsilon: float, bounds: Tuple[float, float]) -> ProjectFn:
|
||||
def project_fn(x, origin_x):
|
||||
dx = jnp.clip(x - origin_x, -epsilon, epsilon)
|
||||
return jnp.clip(origin_x + dx, bounds[0], bounds[1])
|
||||
return project_fn
|
||||
|
||||
|
||||
def linf_initialize_fn(epsilon: float) -> InitializeFn:
|
||||
def initialize_fn(rng, x):
|
||||
return x + jax.random.uniform(rng, x.shape, minval=-epsilon,
|
||||
maxval=epsilon).astype(x.dtype)
|
||||
return initialize_fn
|
||||
|
||||
|
||||
def gradients_fn(loss_fn: LossFn,
|
||||
x: chex.Array) -> Tuple[chex.Array, chex.Array]:
|
||||
"""Returns the analytical gradient as computed by `jax.grad`."""
|
||||
@functools.partial(jax.grad, has_aux=True)
|
||||
def grad_reduced_loss_fn(x):
|
||||
loss = loss_fn(x)
|
||||
return jnp.sum(loss), loss
|
||||
return grad_reduced_loss_fn(x)
|
||||
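To make the file above concrete, here is a minimal, hedged usage sketch (not
part of the commit): it builds an l-infinity PGD attack with sign steps and runs
it against a toy linear `logits_fn`; the step size and the toy model are
illustrative assumptions only.

```
import jax
import jax.numpy as jnp

from adversarial_robustness.jax import attacks

epsilon = 8 / 255
attack = attacks.UntargetedAttack(
    attacks.PGD(
        attacks.IteratedFGSM(epsilon / 4),  # step size is an assumption
        num_steps=10,
        initialize_fn=attacks.linf_initialize_fn(epsilon),
        project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
    loss_fn=attacks.untargeted_margin)

# Toy "model": flatten the image and multiply by fixed random weights.
weights = jax.random.normal(jax.random.PRNGKey(1), (32 * 32 * 3, 10))
logits_fn = lambda x: x.reshape((x.shape[0], -1)) @ weights

images = jnp.full((4, 32, 32, 3), 0.5)
labels = jnp.zeros((4,), dtype=jnp.int32)
adv_images = attack(logits_fn, jax.random.PRNGKey(0), images, labels)
```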
adversarial_robustness/jax/datasets.py (new file, 180 lines)
@@ -0,0 +1,180 @@
|
||||
# Copyright 2021 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Datasets."""
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
import chex
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
|
||||
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
|
||||
_CIFAR10_STD = (0.2471, 0.2435, 0.2616)
|
||||
_CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
|
||||
_CIFAR100_STD = (0.2673, 0.2564, 0.2762)
|
||||
|
||||
_DATA_URL = 'https://storage.googleapis.com/dm-adversarial-robustness/'
|
||||
_ALLOWED_FILES = ('cifar10_ddpm.npz',)
|
||||
_WEBPAGE = ('https://github.com/deepmind/deepmind-research/tree/master/'
|
||||
'adversarial_robustness')
|
||||
|
||||
|
||||
def cifar10_preprocess(mode: str = 'train'):
|
||||
"""Preprocessing functions for CIFAR-10."""
|
||||
def _preprocess_fn_train(example):
|
||||
"""Preprocessing of CIFAR-10 images for training."""
|
||||
image = example['image']
|
||||
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
||||
image = _random_jitter(image, pad=4, crop=32)
|
||||
image = tf.image.random_flip_left_right(image)
|
||||
label = tf.cast(example['label'], tf.int32)
|
||||
return {'image': image, 'label': label}
|
||||
|
||||
def _preprocess_fn_test(example):
|
||||
"""Preprocessing of CIFAR-10 images for testing."""
|
||||
image = example['image']
|
||||
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
||||
label = tf.cast(example['label'], tf.int32)
|
||||
return {'image': image, 'label': label}
|
||||
|
||||
return _preprocess_fn_train if mode == 'train' else _preprocess_fn_test
|
||||
|
||||
|
||||
def cifar10_normalize(image: chex.Array) -> chex.Array:
|
||||
means = jnp.array(_CIFAR10_MEAN, dtype=image.dtype)
|
||||
stds = jnp.array(_CIFAR10_STD, dtype=image.dtype)
|
||||
return (image - means) / stds
|
||||
|
||||
|
||||
def mnist_normalize(image: chex.Array) -> chex.Array:
|
||||
image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant',
|
||||
constant_values=0)
|
||||
return (image - .5) * 2.
|
||||
|
||||
|
||||
def cifar100_normalize(image: chex.Array) -> chex.Array:
|
||||
means = jnp.array(_CIFAR100_MEAN, dtype=image.dtype)
|
||||
stds = jnp.array(_CIFAR100_STD, dtype=image.dtype)
|
||||
return (image - means) / stds
|
||||
|
||||
|
||||
def load_cifar10(batch_sizes: Sequence[int],
|
||||
subset: str = 'train',
|
||||
is_training: bool = True,
|
||||
drop_remainder: bool = True,
|
||||
repeat: int = 1) -> tf.data.Dataset:
|
||||
"""Loads CIFAR-10."""
|
||||
if subset == 'train':
|
||||
ds = tfds.load(name='cifar10', split=tfds.Split.TRAIN)
|
||||
# In Gowal et al. (https://arxiv.org/abs/2010.03593) and Rebuffi et al.
|
||||
# (https://arxiv.org/abs/2103.01946), we also keep a separate validation
|
||||
# subset for early stopping and would run: ds = ds.skip(1_024).
|
||||
elif subset == 'test':
|
||||
ds = tfds.load(name='cifar10', split=tfds.Split.TEST)
|
||||
else:
|
||||
raise ValueError('Unknown subset: "{}"'.format(subset))
|
||||
|
||||
ds = ds.cache()
|
||||
if is_training:
|
||||
ds = ds.repeat()
|
||||
ds = ds.shuffle(buffer_size=50_000, seed=0)
|
||||
ds = _repeat_batch(batch_sizes, ds, repeat=repeat)
|
||||
ds = ds.map(cifar10_preprocess('train' if is_training else 'test'),
|
||||
num_parallel_calls=tf.data.AUTOTUNE)
|
||||
for batch_size in reversed(batch_sizes):
|
||||
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
|
||||
return ds.prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
|
||||
def load_extra(batch_sizes: Sequence[int],
|
||||
path_npz: str,
|
||||
is_training: bool = True,
|
||||
drop_remainder: bool = True) -> tf.data.Dataset:
|
||||
"""Loads extra data from a given path."""
|
||||
if not tf.io.gfile.exists(path_npz):
|
||||
if path_npz in _ALLOWED_FILES:
|
||||
path_npz = tf.keras.utils.get_file(path_npz, _DATA_URL + path_npz)
|
||||
else:
|
||||
raise ValueError(f'Extra data not found ({path_npz}). See {_WEBPAGE} for '
|
||||
'more details.')
|
||||
with tf.io.gfile.GFile(path_npz, 'rb') as fp:
|
||||
npzfile = np.load(fp)
|
||||
data = {'image': npzfile['image'], 'label': npzfile['label']}
|
||||
with tf.device('/device:cpu:0'): # Prevent allocation to happen on GPU.
|
||||
ds = tf.data.Dataset.from_tensor_slices(data)
|
||||
ds = ds.cache()
|
||||
if is_training:
|
||||
ds = ds.repeat()
|
||||
ds = ds.shuffle(buffer_size=50_000, seed=jax.host_id())
|
||||
ds = ds.map(cifar10_preprocess('train' if is_training else 'test'),
|
||||
num_parallel_calls=tf.data.AUTOTUNE)
|
||||
for batch_size in reversed(batch_sizes):
|
||||
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
|
||||
return ds.prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
|
||||
def load_dummy_data(batch_sizes: Sequence[int],
|
||||
is_training: bool = True,
|
||||
**unused_kwargs) -> tf.data.Dataset:
|
||||
"""Loads fictive data (use this function when testing)."""
|
||||
ds = tf.data.Dataset.from_tensor_slices({
|
||||
'image': np.zeros((1, 32, 32, 3), np.float32),
|
||||
'label': np.zeros((1,), np.int32),
|
||||
})
|
||||
ds = ds.repeat()
|
||||
if not is_training:
|
||||
total_batch_size = np.prod(batch_sizes)
|
||||
ds = ds.take(total_batch_size)
|
||||
ds = ds.map(cifar10_preprocess('train' if is_training else 'test'),
|
||||
num_parallel_calls=tf.data.AUTOTUNE)
|
||||
for batch_size in reversed(batch_sizes):
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds.prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
|
||||
def _random_jitter(image: tf.Tensor, pad: int, crop: int) -> tf.Tensor:
|
||||
shape = image.shape.as_list()
|
||||
image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]])
|
||||
image = tf.image.random_crop(image, size=[crop, crop, shape[2]])
|
||||
return image
|
||||
|
||||
|
||||
def _repeat_batch(batch_sizes: Sequence[int],
|
||||
ds: tf.data.Dataset,
|
||||
repeat: int = 1) -> tf.data.Dataset:
|
||||
"""Tiles the inner most batch dimension."""
|
||||
if repeat <= 1:
|
||||
return ds
|
||||
if batch_sizes[-1] % repeat != 0:
|
||||
raise ValueError(f'The last element of `batch_sizes` ({batch_sizes}) must '
|
||||
f'be divisible by `repeat` ({repeat}).')
|
||||
# Perform regular batching with reduced number of elements.
|
||||
for i, batch_size in enumerate(reversed(batch_sizes)):
|
||||
ds = ds.batch(batch_size // repeat if i == 0 else batch_size,
|
||||
drop_remainder=True)
|
||||
# Repeat batch.
|
||||
fn = lambda x: tf.repeat(x, repeats=repeat, axis=len(batch_sizes) - 1)
|
||||
def repeat_inner_batch(example):
|
||||
return jax.tree_map(fn, example)
|
||||
ds = ds.map(repeat_inner_batch,
|
||||
num_parallel_calls=tf.data.AUTOTUNE)
|
||||
# Unbatch.
|
||||
for _ in batch_sizes:
|
||||
ds = ds.unbatch()
|
||||
return ds
|
||||
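A hedged usage sketch for the loader above (not part of the commit): the nested
`batch_sizes` argument produces nested batch dimensions, which in this codebase
plausibly correspond to (local devices, per-device batch) for `jax.pmap`:

```
import jax
import tensorflow_datasets as tfds

from adversarial_robustness.jax import datasets

ds = datasets.load_cifar10(
    batch_sizes=[jax.local_device_count(), 32],  # [outer, inner] batching
    subset='train',
    is_training=True)
batch = next(iter(tfds.as_numpy(ds)))
# batch['image'] has shape (jax.local_device_count(), 32, 32, 32, 3).
```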
@@ -14,14 +14,19 @@
|
||||
|
||||
"""Evaluates a JAX checkpoint on CIFAR-10/100 or MNIST."""
|
||||
|
||||
import functools
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import haiku as hk
|
||||
import numpy as np
|
||||
import optax
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import tqdm
|
||||
|
||||
from adversarial_robustness.jax import attacks
|
||||
from adversarial_robustness.jax import datasets
|
||||
from adversarial_robustness.jax import model_zoo
|
||||
|
||||
_CKPT = flags.DEFINE_string(
|
||||
@@ -48,14 +53,14 @@ def main(unused_argv):
|
||||
# Create dataset.
|
||||
if _DATASET.value == 'mnist':
|
||||
_, data_test = tf.keras.datasets.mnist.load_data()
|
||||
normalize_fn = model_zoo.mnist_normalize
|
||||
normalize_fn = datasets.mnist_normalize
|
||||
elif _DATASET.value == 'cifar10':
|
||||
_, data_test = tf.keras.datasets.cifar10.load_data()
|
||||
normalize_fn = model_zoo.cifar10_normalize
|
||||
normalize_fn = datasets.cifar10_normalize
|
||||
else:
|
||||
assert _DATASET.value == 'cifar100'
|
||||
_, data_test = tf.keras.datasets.cifar100.load_data()
|
||||
normalize_fn = model_zoo.cifar100_normalize
|
||||
normalize_fn = datasets.cifar100_normalize
|
||||
|
||||
# Create model.
|
||||
@hk.transform_with_state
|
||||
@@ -83,22 +88,53 @@ def main(unused_argv):
|
||||
else:
|
||||
params, state = np.load(_CKPT.value, allow_pickle=True)
|
||||
|
||||
# Create adversarial attack. We run a PGD-40 attack with margin loss.
|
||||
epsilon = 8 / 255
|
||||
eval_attack = attacks.UntargetedAttack(
|
||||
attacks.PGD(
|
||||
attacks.Adam(learning_rate_fn=optax.piecewise_constant_schedule(
|
||||
init_value=.1,
|
||||
boundaries_and_scales={20: .1, 30: .01})),
|
||||
num_steps=40,
|
||||
initialize_fn=attacks.linf_initialize_fn(epsilon),
|
||||
project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
|
||||
loss_fn=attacks.untargeted_margin)
|
||||
|
||||
def logits_fn(x, rng):
|
||||
return model_fn.apply(params, state, rng, x)[0]
|
||||
|
||||
# Evaluation.
|
||||
correct = 0
|
||||
adv_correct = 0
|
||||
total = 0
|
||||
batch_count = 0
|
||||
total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1, _NUM_BATCHES.value)
|
||||
for images, labels in tqdm.tqdm(test_loader, total=total_batches):
|
||||
outputs = model_fn.apply(params, state, next(rng_seq), images)[0]
|
||||
rng = next(rng_seq)
|
||||
loop_logits_fn = functools.partial(logits_fn, rng=rng)
|
||||
|
||||
# Clean examples.
|
||||
outputs = loop_logits_fn(images)
|
||||
correct += (np.argmax(outputs, 1) == labels).sum().item()
|
||||
|
||||
# Adversarial examples.
|
||||
adv_images = eval_attack(loop_logits_fn, next(rng_seq), images, labels)
|
||||
outputs = loop_logits_fn(adv_images)
|
||||
predicted = np.argmax(outputs, 1)
|
||||
adv_correct += (predicted == labels).sum().item()
|
||||
|
||||
total += labels.shape[0]
|
||||
correct += (predicted == labels).sum().item()
|
||||
batch_count += 1
|
||||
if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
|
||||
break
|
||||
print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')
|
||||
print(f'Robust accuracy: {100 * adv_correct / total:.2f}%')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('ckpt')
|
||||
try:
|
||||
tf.config.set_visible_devices([], 'GPU') # Prevent TF from using the GPU.
|
||||
except tf.errors.NotFoundError:
|
||||
pass
|
||||
app.run(main)
|
||||
|
||||
adversarial_robustness/jax/experiment.py (new file, 565 lines; diff suppressed because it is too large)

adversarial_robustness/jax/experiment_test.py (new file, 46 lines)
@@ -0,0 +1,46 @@
|
||||
# Copyright 2021 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Quick script to test that experiment can import and run."""
|
||||
|
||||
from absl import app
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jaxline import utils as jl_utils
|
||||
from adversarial_robustness.jax import experiment
|
||||
|
||||
|
||||
@jl_utils.disable_pmap_jit
|
||||
def test_experiment(unused_argv):
|
||||
"""Tests the main experiment."""
|
||||
config = experiment.get_config()
|
||||
exp_config = config.experiment_kwargs.config
|
||||
exp_config.dry_run = True
|
||||
exp_config.emulated_workers = 0
|
||||
exp_config.training.batch_size = 2
|
||||
exp_config.evaluation.batch_size = 2
|
||||
exp_config.model.kwargs.depth = 10
|
||||
exp_config.model.kwargs.width = 1
|
||||
|
||||
xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
|
||||
bcast = jax.pmap(lambda x: x)
|
||||
global_step = bcast(jnp.zeros(jax.local_device_count()))
|
||||
rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
|
||||
print('Taking a single experiment step for test purposes!')
|
||||
result = xp.step(global_step, rng)
|
||||
print(f'Step successfully taken, resulting metrics are {result}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(test_experiment)
|
||||
@@ -14,19 +14,14 @@
|
||||
|
||||
"""WideResNet implementation in JAX using Haiku."""
|
||||
|
||||
from typing import Any, Mapping, Optional, Text
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import chex
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
|
||||
CIFAR10_STD = (0.2471, 0.2435, 0.2616)
|
||||
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
|
||||
CIFAR100_STD = (0.2673, 0.2564, 0.2762)
|
||||
|
||||
|
||||
class _WideResNetBlock(hk.Module):
|
||||
"""Block of a WideResNet."""
|
||||
|
||||
@@ -89,16 +84,16 @@ class WideResNet(hk.Module):
|
||||
num_classes: int = 10,
|
||||
depth: int = 28,
|
||||
width: int = 10,
|
||||
activation: Text = 'relu',
|
||||
norm_args: Optional[Mapping[Text, Any]] = None,
|
||||
name: Optional[Text] = None):
|
||||
activation: str = 'relu',
|
||||
norm_args: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None):
|
||||
super(WideResNet, self).__init__(name=name)
|
||||
if (depth - 4) % 6 != 0:
|
||||
raise ValueError('depth should be 6n+4.')
|
||||
self._activation = getattr(jax.nn, activation)
|
||||
if norm_args is None:
|
||||
norm_args = {
|
||||
'create_offset': False,
|
||||
'create_offset': True,
|
||||
'create_scale': True,
|
||||
'decay_rate': .99,
|
||||
}
|
||||
@@ -113,6 +108,7 @@ class WideResNet(hk.Module):
|
||||
**norm_args)
|
||||
self._linear = hk.Linear(
|
||||
num_classes,
|
||||
w_init=jnp.zeros,
|
||||
name='logits')
|
||||
|
||||
blocks_per_layer = (depth - 4) // 6
|
||||
@@ -132,7 +128,7 @@ class WideResNet(hk.Module):
|
||||
name='resnet_lay_{}_block_{}'.format(layer_num, i)))
|
||||
self._blocks.append(blocks_of_layer)
|
||||
|
||||
def __call__(self, inputs, **norm_kwargs):
|
||||
def __call__(self, inputs: chex.Array, **norm_kwargs) -> chex.Array:
|
||||
net = inputs
|
||||
net = self._conv(net)
|
||||
|
||||
@@ -145,21 +141,3 @@ class WideResNet(hk.Module):
|
||||
|
||||
net = jnp.mean(net, axis=[1, 2])
|
||||
return self._linear(net)
|
||||
|
||||
|
||||
def mnist_normalize(image: jnp.array) -> jnp.array:
|
||||
image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant',
|
||||
constant_values=0)
|
||||
return (image - .5) * 2.
|
||||
|
||||
|
||||
def cifar10_normalize(image: jnp.array) -> jnp.array:
|
||||
means = jnp.array(CIFAR10_MEAN, dtype=image.dtype)
|
||||
stds = jnp.array(CIFAR10_STD, dtype=image.dtype)
|
||||
return (image - means) / stds
|
||||
|
||||
|
||||
def cifar100_normalize(image: jnp.array) -> jnp.array:
|
||||
means = jnp.array(CIFAR100_MEAN, dtype=image.dtype)
|
||||
stds = jnp.array(CIFAR100_STD, dtype=image.dtype)
|
||||
return (image - means) / stds
|
||||
|
||||
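A minimal, hedged sketch (not part of the commit) of building this `WideResNet`
with Haiku, along the lines of what `jax/eval.py` does; the `is_training`
keyword is assumed to be forwarded to the batch-norm layers via `**norm_kwargs`:

```
import haiku as hk
import jax
import jax.numpy as jnp

from adversarial_robustness.jax import datasets, model_zoo


@hk.transform_with_state
def model_fn(x, is_training=False):
  model = model_zoo.WideResNet(num_classes=10, depth=28, width=10)
  return model(datasets.cifar10_normalize(x), is_training=is_training)


images = jnp.zeros((2, 32, 32, 3))
params, state = model_fn.init(jax.random.PRNGKey(0), images, is_training=True)
logits, _ = model_fn.apply(params, state, None, images)  # no rng needed here
```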
adversarial_robustness/jax/train.py (new file, 32 lines)
@@ -0,0 +1,32 @@
|
||||
# Copyright 2021 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Runs a JAXline experiment to perform robust adversarial training."""
|
||||
|
||||
import functools
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from jaxline import platform
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from adversarial_robustness.jax import experiment
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('config')
|
||||
try:
|
||||
tf.config.set_visible_devices([], 'GPU') # Prevent TF from using the GPU.
|
||||
except tf.errors.NotFoundError:
|
||||
pass
|
||||
app.run(functools.partial(platform.main, experiment.Experiment))
|
||||
adversarial_robustness/jax/utils.py (new file, 197 lines)
@@ -0,0 +1,197 @@
|
||||
# Copyright 2021 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Helper functions."""
|
||||
|
||||
import re
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import chex
|
||||
import einops
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
|
||||
|
||||
def get_cosine_schedule(
|
||||
max_learning_rate: float,
|
||||
total_steps: int,
|
||||
warmup_steps: int = 0) -> optax.Schedule:
|
||||
"""Builds a cosine decay schedule with initial warm-up."""
|
||||
if total_steps < warmup_steps:
|
||||
return optax.linear_schedule(init_value=0., end_value=max_learning_rate,
|
||||
transition_steps=warmup_steps)
|
||||
return optax.join_schedules([
|
||||
optax.linear_schedule(init_value=0., end_value=max_learning_rate,
|
||||
transition_steps=warmup_steps),
|
||||
optax.cosine_decay_schedule(init_value=max_learning_rate,
|
||||
decay_steps=total_steps - warmup_steps),
|
||||
], [warmup_steps])
|
||||
|
||||
|
||||
def sgd_momentum(learning_rate_fn: optax.Schedule,
|
||||
momentum: float = 0.,
|
||||
nesterov: bool = False) -> optax.GradientTransformation:
|
||||
return optax.chain(
|
||||
optax.trace(decay=momentum, nesterov=nesterov),
|
||||
optax.scale_by_schedule(learning_rate_fn),
|
||||
optax.scale(-1.))
|
||||
|
||||
|
||||
def cross_entropy(logits: chex.Array, labels: chex.Array) -> chex.Array:
|
||||
return -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
|
||||
|
||||
|
||||
def kl_divergence(q_logits: chex.Array,
|
||||
p_logits: chex.Array) -> chex.Array:
|
||||
"""Compute the KL divergence."""
|
||||
p_probs = jax.nn.softmax(p_logits)
|
||||
return cross_entropy(q_logits, p_probs) - cross_entropy(p_logits, p_probs)
|
||||
|
||||
|
||||
def accuracy(logits: chex.Array, labels: chex.Array) -> chex.Array:
|
||||
predicted_label = jnp.argmax(logits, axis=-1)
|
||||
correct = jnp.equal(predicted_label, labels).astype(jnp.float32)
|
||||
return jnp.sum(correct, axis=0) / logits.shape[0]
|
||||
|
||||
|
||||
def weight_decay(params: hk.Params,
|
||||
regex_match: Optional[Sequence[str]] = None,
|
||||
regex_ignore: Optional[Sequence[str]] = None) -> chex.Array:
|
||||
"""Computes the L2 regularization loss."""
|
||||
if regex_match is None:
|
||||
regex_match = ('.*w$', '.*b$')
|
||||
if regex_ignore is None:
|
||||
regex_ignore = ('.*batchnorm.*',)
|
||||
l2_norm = 0.
|
||||
for mod_name, mod_params in params.items():
|
||||
for param_name, param in mod_params.items():
|
||||
name = '/'.join([mod_name, param_name])
|
||||
if (regex_match and
|
||||
all(not re.match(regex, name) for regex in regex_match)):
|
||||
continue
|
||||
if (regex_ignore and
|
||||
any(re.match(regex, name) for regex in regex_ignore)):
|
||||
continue
|
||||
l2_norm += jnp.sum(jnp.square(param))
|
||||
return .5 * l2_norm
|
||||
|
||||
|
||||
def ema_update(step: chex.Array,
|
||||
avg_params: chex.ArrayTree,
|
||||
new_params: chex.ArrayTree,
|
||||
decay_rate: float = 0.99,
|
||||
warmup_steps: int = 0,
|
||||
dynamic_decay: bool = True) -> chex.ArrayTree:
|
||||
"""Applies an exponential moving average."""
|
||||
factor = (step >= warmup_steps).astype(jnp.float32)
|
||||
if dynamic_decay:
|
||||
# Uses TF-style EMA.
|
||||
delta = step - warmup_steps
|
||||
decay = jnp.minimum(decay_rate, (1. + delta) / (10. + delta))
|
||||
else:
|
||||
decay = decay_rate
|
||||
decay *= factor
|
||||
def _weighted_average(p1, p2):
|
||||
d = decay.astype(p1.dtype)
|
||||
return (1 - d) * p1 + d * p2
|
||||
return jax.tree_multimap(_weighted_average, new_params, avg_params)
|
||||
|
||||
|
||||
def cutmix(rng: chex.PRNGKey,
|
||||
images: chex.Array,
|
||||
labels: chex.Array,
|
||||
alpha: float = 1.,
|
||||
beta: float = 1.,
|
||||
split: int = 1) -> Tuple[chex.Array, chex.Array]:
|
||||
"""Composing two images by inserting a patch into another image."""
|
||||
batch_size, height, width, _ = images.shape
|
||||
  # Fall back to the full batch size when no split is requested, so the
  # permutation below is always defined.
  split_batch_size = batch_size // split if split > 1 else batch_size
|
||||
|
||||
# Masking bounding box.
|
||||
box_rng, lam_rng, rng = jax.random.split(rng, num=3)
|
||||
lam = jax.random.beta(lam_rng, a=alpha, b=beta, shape=())
|
||||
cut_rat = jnp.sqrt(1. - lam)
|
||||
cut_w = jnp.array(width * cut_rat, dtype=jnp.int32)
|
||||
cut_h = jnp.array(height * cut_rat, dtype=jnp.int32)
|
||||
box_coords = _random_box(box_rng, height, width, cut_h, cut_w)
|
||||
# Adjust lambda.
|
||||
lam = 1. - (box_coords[2] * box_coords[3] / (height * width))
|
||||
idx = jax.random.permutation(rng, split_batch_size)
|
||||
def _cutmix(x, y):
|
||||
images_a = x
|
||||
images_b = x[idx, :, :, :]
|
||||
y = lam * y + (1. - lam) * y[idx, :]
|
||||
x = _compose_two_images(images_a, images_b, box_coords)
|
||||
return x, y
|
||||
|
||||
if split <= 1:
|
||||
return _cutmix(images, labels)
|
||||
|
||||
# Apply CutMix separately on each sub-batch. This reverses the effect of
|
||||
# `repeat` in datasets.
|
||||
images = einops.rearrange(images, '(b1 b2) ... -> b1 b2 ...', b2=split)
|
||||
labels = einops.rearrange(labels, '(b1 b2) ... -> b1 b2 ...', b2=split)
|
||||
images, labels = jax.vmap(_cutmix, in_axes=1, out_axes=1)(images, labels)
|
||||
images = einops.rearrange(images, 'b1 b2 ... -> (b1 b2) ...', b2=split)
|
||||
labels = einops.rearrange(labels, 'b1 b2 ... -> (b1 b2) ...', b2=split)
|
||||
return images, labels
|
||||
|
||||
|
||||
def _random_box(rng: chex.PRNGKey,
|
||||
height: chex.Numeric,
|
||||
width: chex.Numeric,
|
||||
cut_h: chex.Array,
|
||||
cut_w: chex.Array) -> chex.Array:
|
||||
"""Sample a random box of shape [cut_h, cut_w]."""
|
||||
height_rng, width_rng = jax.random.split(rng)
|
||||
minval_h = 0
|
||||
minval_w = 0
|
||||
maxval_h = height
|
||||
maxval_w = width
|
||||
i = jax.random.randint(
|
||||
height_rng, shape=(), minval=minval_h, maxval=maxval_h, dtype=jnp.int32)
|
||||
j = jax.random.randint(
|
||||
width_rng, shape=(), minval=minval_w, maxval=maxval_w, dtype=jnp.int32)
|
||||
bby1 = jnp.clip(i - cut_h // 2, 0, height)
|
||||
bbx1 = jnp.clip(j - cut_w // 2, 0, width)
|
||||
h = jnp.clip(i + cut_h // 2, 0, height) - bby1
|
||||
w = jnp.clip(j + cut_w // 2, 0, width) - bbx1
|
||||
return jnp.array([bby1, bbx1, h, w])
|
||||
|
||||
|
||||
def _compose_two_images(images: chex.Array,
|
||||
image_permutation: chex.Array,
|
||||
bbox: chex.Array) -> chex.Array:
|
||||
"""Inserting the second minibatch into the first at the target locations."""
|
||||
def _single_compose_two_images(image1, image2):
|
||||
height, width, _ = image1.shape
|
||||
mask = _window_mask(bbox, (height, width))
|
||||
return image1 * (1. - mask) + image2 * mask
|
||||
return jax.vmap(_single_compose_two_images)(images, image_permutation)
|
||||
|
||||
|
||||
def _window_mask(destination_box: chex.Array,
|
||||
size: Tuple[int, int]) -> jnp.ndarray:
|
||||
"""Mask a part of the image."""
|
||||
height_offset, width_offset, h, w = destination_box
|
||||
h_range = jnp.reshape(jnp.arange(size[0]), [size[0], 1, 1])
|
||||
w_range = jnp.reshape(jnp.arange(size[1]), [1, size[1], 1])
|
||||
return jnp.logical_and(
|
||||
jnp.logical_and(height_offset <= h_range,
|
||||
h_range < height_offset + h),
|
||||
jnp.logical_and(width_offset <= w_range,
|
||||
w_range < width_offset + w)).astype(jnp.float32)
|
||||
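A hedged usage sketch for `cutmix` above (not part of the commit): labels must
already be one-hot so they can be mixed linearly, and `split=2` here mirrors
undoing a `repeat=2` from `datasets.py`:

```
import jax
import jax.numpy as jnp

from adversarial_robustness.jax import utils

rng = jax.random.PRNGKey(0)
images = jnp.zeros((8, 32, 32, 3))
labels = jax.nn.one_hot(jnp.arange(8) % 10, 10)
mixed_images, mixed_labels = utils.cutmix(rng, images, labels, split=2)
```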
adversarial_robustness/pytorch/README.md (new file, 28 lines)
@@ -0,0 +1,28 @@

# PyTorch evaluation

We provide PyTorch evaluation code for convenience. If you developed a version
of our training pipeline for PyTorch, please let us know as we will link it from
here.

Here are known PyTorch implementations of our training pipeline:

* https://github.com/imrahulr/adversarial_robustness_pytorch (by Rahul Rade)

Here are a few considerations when reproducing our training pipeline in PyTorch.
As opposed to the [RST](https://github.com/yaircarmon/semisup-adv) code
(provided by Carmon et al.):

* We set the batch normalization decay to 0.99 (instead of 0.9).
* We do not apply weight decay (l2 regularization) to the batch normalization
  scale and offset (see the sketch after this list).
* We use Haiku's default initialization for all layers (except the last, which
  is initialized with zeros).
* The PGD attack used during training samples its starting point uniformly from
  the l-p norm ball.
* We run the attack over the local batch statistics (rather than the evaluation
  statistics).
* We update batch normalization statistics from adversarial examples only
  (rather than both clean and adversarial examples).
* We use a 10-epoch warm-up in our learning rate schedule.
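
As a hedged illustration of the two batch-norm points above (not part of the
commit), one way to express them in PyTorch is a smaller BN momentum and
separate optimizer parameter groups; the weight-decay value and the tiny model
are placeholders:

```
import torch
from torch import nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16, momentum=0.01),  # running-stats decay of 0.99
    nn.ReLU())

decay, no_decay = [], []
for module in model.modules():
  params = list(module.parameters(recurse=False))
  if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
    no_decay += params  # no l2 on batch-norm scale and offset
  else:
    decay += params

optimizer = torch.optim.SGD(
    [{'params': decay, 'weight_decay': 5e-4},   # placeholder value
     {'params': no_decay, 'weight_decay': 0.}],
    lr=0.1, momentum=0.9)
```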
@@ -1,51 +1,64 @@
|
||||
absl-py==0.10.0
|
||||
# Direct dependencies.
|
||||
absl-py==0.12.0
|
||||
chex==0.0.7
|
||||
dm-haiku==0.0.4
|
||||
einops==0.3.0
|
||||
jax==0.2.16
|
||||
jaxlib==0.1.68
|
||||
jaxline==0.0.3
|
||||
ml-collections==0.1.0
|
||||
numpy==1.19.5
|
||||
optax==0.0.8
|
||||
tensorflow==2.5.0
|
||||
tensorflow-datasets==4.3.0
|
||||
torch==1.9.0
|
||||
torchvision==0.10.0
|
||||
tqdm==4.61.1
|
||||
# Transitive dependencies.
|
||||
astunparse==1.6.3
|
||||
attrs==20.3.0
|
||||
cachetools==4.1.1
|
||||
certifi==2020.11.8
|
||||
chardet==3.0.4
|
||||
dataclasses==0.6
|
||||
dill==0.3.3
|
||||
dm-haiku==0.0.3
|
||||
attrs==21.2.0
|
||||
cachetools==4.2.2
|
||||
certifi==2021.5.30
|
||||
chardet==4.0.0
|
||||
contextlib2==21.6.0
|
||||
dill==0.3.4
|
||||
dm-tree==0.1.6
|
||||
flatbuffers==1.12
|
||||
future==0.18.2
|
||||
gast==0.3.3
|
||||
google-auth==1.23.0
|
||||
google-auth-oauthlib==0.4.2
|
||||
gast==0.4.0
|
||||
google-auth==1.32.0
|
||||
google-auth-oauthlib==0.4.4
|
||||
google-pasta==0.2.0
|
||||
googleapis-common-protos==1.52.0
|
||||
grpcio==1.33.2
|
||||
h5py==2.10.0
|
||||
googleapis-common-protos==1.53.0
|
||||
grpcio==1.34.1
|
||||
h5py==3.1.0
|
||||
idna==2.10
|
||||
importlib-resources==3.3.0
|
||||
jax==0.2.6
|
||||
jaxlib==0.1.57
|
||||
keras-nightly==2.5.0.dev2021032900
|
||||
Keras-Preprocessing==1.1.2
|
||||
Markdown==3.3.3
|
||||
numpy==1.18.5
|
||||
oauthlib==3.1.0
|
||||
Markdown==3.3.4
|
||||
oauthlib==3.1.1
|
||||
opt-einsum==3.3.0
|
||||
Pillow==8.0.1
|
||||
Pillow==8.2.0
|
||||
pkg-resources==0.0.0
|
||||
promise==2.3
|
||||
protobuf==3.14.0
|
||||
protobuf==3.17.3
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
requests==2.25.0
|
||||
PyYAML==5.4.1
|
||||
requests==2.25.1
|
||||
requests-oauthlib==1.3.0
|
||||
rsa==4.6
|
||||
scipy==1.5.4
|
||||
rsa==4.7.2
|
||||
scipy==1.7.0
|
||||
six==1.15.0
|
||||
tensorboard==2.4.0
|
||||
tensorboard-plugin-wit==1.7.0
|
||||
tensorflow==2.3.1
|
||||
tensorflow-datasets==4.1.0
|
||||
tensorflow-estimator==2.3.0
|
||||
tensorflow-metadata==0.25.0
|
||||
tabulate==0.8.9
|
||||
tensorboard==2.5.0
|
||||
tensorboard-data-server==0.6.1
|
||||
tensorboard-plugin-wit==1.8.0
|
||||
tensorflow-estimator==2.5.0
|
||||
tensorflow-metadata==1.1.0
|
||||
termcolor==1.1.0
|
||||
torch==1.7.0
|
||||
torchvision==0.8.1
|
||||
tqdm==4.53.0
|
||||
toolz==0.11.1
|
||||
typing-extensions==3.7.4.3
|
||||
urllib3==1.26.2
|
||||
Werkzeug==1.0.1
|
||||
urllib3==1.26.6
|
||||
Werkzeug==2.0.1
|
||||
wrapt==1.12.1
|
||||
|
||||
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

set -euf -o pipefail # Stop at failure.

python3 -m venv /tmp/adversarial_robustness_venv
source /tmp/adversarial_robustness_venv/bin/activate
pip install -U pip
@@ -34,3 +36,9 @@ python3 -m adversarial_robustness.pytorch.eval \
  --batch_size=1 \
  --num_batches=1 \
  --nouse_cuda

# We disable pmap/jit to avoid compilation during testing. Since the
# test only runs a single step, it would not benefit from such a compilation
# anyways.
python3 -m adversarial_robustness.jax.experiment_test \
  --jaxline_disable_pmap_jit=True
rapid_task_solving/README.md (new file, 123 lines)
@@ -0,0 +1,123 @@

# One-shot StreetLearn and Memory & Planning Game

This repository contains code for the environments used in the paper
["Rapid Task-Solving in Novel Environments"](https://arxiv.org/abs/2006.03662)
by Sam Ritter, Ryan Faulkner, Laurent Sartran, Adam Santoro, Matt Botvinick and
David Raposo. It was published as a conference paper at ICLR 2021.

To cite this work:

```
@inproceedings{
ritter2021rapid,
title={Rapid Task-Solving in Novel Environments},
author={Samuel Ritter and Ryan Faulkner and Laurent Sartran and Adam Santoro
and Matthew Botvinick and David Raposo},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=F-mvpFpn_0q}
}
```

### Memory&Planning Game

The _Memory&Planning Game_ is a simple variation of the well-known _Memory
Game_, wherein players must remember the locations of cards in a grid. This
variation extends the challenge to require planning as well as remembering.

In the _Memory&Planning Game_, the agent occupies an environment consisting
of a grid of symbols (e.g. 4x4). The observation consists of two symbols — one
which corresponds to the agent's current location, and another that corresponds
to the "goal" the agent is tasked with navigating to. The agent cannot see its
relative location with respect to other symbols in the grid. At each step the
agent selects one of 5 possible actions: _move left_, _move right_, _move up_,
_move down_, and _collect_. If the agent chooses the "collect" action when its
current location symbol matches the goal symbol, a reward of +1 is received.
Otherwise, the agent receives a reward of 0. At the beginning of each episode, a
new set of symbols is sampled, effectively inducing a new transition function.
The agent is allowed a fixed number of steps (e.g. 100) per episode to "collect"
as many goals as possible. Each time the agent collects a goal — which
corresponds to completing a task — a new goal is sampled and the transition
function stays fixed.

#### Example

The following code snippet shows an example of how to load the environment,
start a new episode and take a few random steps. A plot representing the current
state of the environment is displayed for each step.

```
import memory_planning_game

env = memory_planning_game.MemoryPlanningGame(4, seed=123)
_ = env.reset()
for _ in range(5):
  timestep = env.take_random_action()
  fig, ax = env.draw_maze()
```

![](images/example_mpg.png)

### One-Shot StreetLearn

The _One-Shot StreetLearn_ is a domain wherein environments are sampled as
neighborhoods from the _StreetLearn_ dataset (Mirowski et al., 2019). Tasks are
then sampled by selecting a position and orientation that the agent must
navigate to from its current location.

In _One-Shot StreetLearn_, the agent's observations consist of a representation
of the current state and a representation of the goal state. The agent receives
no other information from the environment. The available actions are _turn
right_, which orients the agent clockwise toward the next available direction of
motion from its current location; _turn left_, which does the same in the other
direction; and _move forward_, which moves the agent along the direction it is
facing to the next available location. In each episode, we sample a new
neighborhood with 5 intersections from one of 12 cities. To reduce the
exploration problem while keeping the planning difficulty constant, we removed
all the locations that were not intersections (i.e. that corresponded to
degree-2 nodes in the connectivity graph of the sampled neighborhood).

Note: In the paper we used images to represent each location, taken from the
_StreetLearn_ dataset. In this codebase, to simplify the public release process,
we replace the images with one-hot vectors to represent each location.

Every time the agent reaches a goal, a new starting- and goal-state pair is
sampled, initiating a new task, until the fixed episode step limit is reached
(e.g. 200 steps). A step that takes the agent to the goal state results in a
reward of +1. Any other step results in a reward of 0.

#### City graph datasets

In this
[link](https://console.cloud.google.com/storage/browser/one_shot_streetlearn_graphs)
you can download the datasets containing the connectivity graphs of different
cities. You will need at least one of these datasets in order to build the
environment.

#### Example

The following code snippet shows an example of how to load the environment for
an example city (in this case, Madrid), start a new episode and take a few
random steps. A plot representing the current state of the environment is
displayed for each step.

```
import numpy as np
import one_shot_streetlearn

region = 'madrid'
sl_config = {
    'dataset_path': f'./datasets/{region}_full_graph.gexf',
    'max_episode_steps': 200,
    'num_junctions': 5,
}
env = one_shot_streetlearn.OneShotStreetLearn(**sl_config)

_ = env.reset()
for _ in range(4):
  rnd_action = np.random.randint(env.NUM_ACTIONS)
  timestep = env.step(rnd_action)
  fig, ax = env.draw_subgraph()
```

![](images/example_osl.png)
rapid_task_solving/images/example_mpg.png (new binary file, 139 KiB; not shown)
rapid_task_solving/images/example_osl.png (new binary file, 78 KiB; not shown)
rapid_task_solving/memory_planning_game.py (new file, 184 lines)
@@ -0,0 +1,184 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Memory & Planning Game environment."""
|
||||
import string
|
||||
|
||||
import dm_env
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MemoryPlanningGame(dm_env.Environment):
|
||||
"""Memory & Planning Game environment."""
|
||||
|
||||
ACTION_NAMES = ['Up', 'Down', 'Left', 'Right', 'Collect']
|
||||
NUM_ACTIONS = len(ACTION_NAMES)
|
||||
DIRECTIONS = [
|
||||
(0, 1), # Up
|
||||
(0, -1), # Down
|
||||
(-1, 0), # Left
|
||||
(1, 0), # Right
|
||||
(0, 0), # Collect
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
maze_size=4,
|
||||
max_episode_steps=100,
|
||||
target_reward=1.,
|
||||
per_step_reward=0.,
|
||||
random_respawn=False,
|
||||
seed=None):
|
||||
"""The Memory & Planning Game environment.
|
||||
|
||||
Args:
|
||||
maze_size: (int) size of the maze dimension.
|
||||
max_episode_steps: (int) number of steps per episode.
|
||||
target_reward: (float) reward value of the target.
|
||||
per_step_reward: (float) reward/cost of taking a step.
|
||||
random_respawn: (bool) whether the agent respawns in a random location
|
||||
upon collecting the goal.
|
||||
seed: (int or None) seed for random number generator.
|
||||
"""
|
||||
self._maze_size = maze_size
|
||||
self._num_labels = maze_size * maze_size
|
||||
# The graph itself is the same across episodes, but the node labels will be
|
||||
# randomly sampled in each episode.
|
||||
self._graph = nx.grid_2d_graph(
|
||||
self._maze_size, self._maze_size, periodic=True)
|
||||
self._max_episode_steps = max_episode_steps
|
||||
self._target_reward = target_reward
|
||||
self._per_step_reward = per_step_reward
|
||||
self._random_respawn = random_respawn
|
||||
self._rng = np.random.RandomState(seed)
|
||||
|
||||
def _one_hot(self, node):
|
||||
one_hot_vector = np.zeros([self._num_labels], dtype=np.int32)
|
||||
one_hot_vector[self._labels[node]] = 1
|
||||
return one_hot_vector
|
||||
|
||||
def step(self, action):
|
||||
# If previous step was the last step of an episode, reset.
|
||||
if self._needs_reset:
|
||||
return self.reset()
|
||||
|
||||
# Increment step count and check if it's the last step of the episode.
|
||||
self._episode_steps += 1
|
||||
if self._episode_steps >= self._max_episode_steps:
|
||||
self._needs_reset = True
|
||||
transition = dm_env.termination
|
||||
else:
|
||||
transition = dm_env.transition
|
||||
|
||||
# Recompute agent's position given the selected action.
|
||||
direction = self.DIRECTIONS[action]
|
||||
self._position = tuple(
|
||||
(np.array(self._position) + np.array(direction)) % self._maze_size)
|
||||
self._previous_action = self.ACTION_NAMES[action]
|
||||
|
||||
# Get reward if agent is over the goal location and the selected action is
|
||||
# `collect`.
|
||||
if self._position == self._goal and self.ACTION_NAMES[action] == 'Collect':
|
||||
reward = self._target_reward
|
||||
self._set_new_goal()
|
||||
else:
|
||||
reward = self._per_step_reward
|
||||
self._episode_reward += reward
|
||||
|
||||
return transition(reward, self._observation())
|
||||
|
||||
def _observation(self):
|
||||
return {
|
||||
'position': np.array(self._one_hot(self.position), dtype=np.int32),
|
||||
'goal': np.array(self._one_hot(self.goal), dtype=np.int32),
|
||||
}
|
||||
|
||||
def observation_spec(self):
|
||||
return {
|
||||
'position': dm_env.specs.Array(
|
||||
shape=(self._num_labels,), dtype=np.int32, name='position'),
|
||||
'goal': dm_env.specs.Array(
|
||||
shape=(self._num_labels,), dtype=np.int32, name='goal'),
|
||||
}
|
||||
|
||||
def action_spec(self):
|
||||
return dm_env.specs.DiscreteArray(self.NUM_ACTIONS)
|
||||
|
||||
def take_random_action(self):
|
||||
return self.step(self._rng.randint(self.NUM_ACTIONS))
|
||||
|
||||
def reset(self):
|
||||
self._previous_action = ''
|
||||
self._episode_reward = 0.
|
||||
self._episode_steps = 0
|
||||
self._needs_reset = False
|
||||
random_labels = self._rng.permutation(self._num_labels)
|
||||
self._labels = {n: random_labels[i]
|
||||
for i, n in enumerate(self._graph.nodes())}
|
||||
self._respawn()
|
||||
self._set_new_goal()
|
||||
return dm_env.restart(self._observation())
|
||||
|
||||
def _respawn(self):
|
||||
random_idx = self._rng.randint(self._num_labels)
|
||||
self._position = list(self._graph.nodes())[random_idx]
|
||||
|
||||
def _set_new_goal(self):
|
||||
if self._random_respawn:
|
||||
self._respawn()
|
||||
goal = self._position
|
||||
while goal == self._position:
|
||||
random_idx = self._rng.randint(self._num_labels)
|
||||
goal = list(self._graph.nodes())[random_idx]
|
||||
self._goal = goal
|
||||
|
||||
@property
|
||||
def position(self):
|
||||
return self._position
|
||||
|
||||
@property
|
||||
def goal(self):
|
||||
return self._goal
|
||||
|
||||
@property
|
||||
def previous_action(self):
|
||||
return self._previous_action
|
||||
|
||||
@property
|
||||
def episode_reward(self):
|
||||
return self._episode_reward
|
||||
|
||||
def draw_maze(self, ax=None):
|
||||
if ax is None:
|
||||
plt.figure()
|
||||
ax = plt.gca()
|
||||
node_positions = {(x, y): (x, y) for x, y in self._graph.nodes()}
|
||||
letters = string.ascii_uppercase + string.ascii_lowercase
|
||||
labels = {n: letters[self._labels[n]] for n in self._graph.nodes()}
|
||||
node_list = list(self._graph.nodes())
|
||||
colors = []
|
||||
for n in node_list:
|
||||
if n == self.position:
|
||||
colors.append('lightblue')
|
||||
elif n == self.goal:
|
||||
colors.append('lightgreen')
|
||||
else:
|
||||
colors.append('pink')
|
||||
nx.draw(self._graph, pos=node_positions, nodelist=node_list, ax=ax,
|
||||
node_color=colors, with_labels=True, node_size=200, labels=labels)
|
||||
ax.set_title('{}\nEpisode reward={:.1f}'.format(
|
||||
self.previous_action, self.episode_reward))
|
||||
ax.margins(.1)
|
||||
return plt.gcf(), ax
|
||||
rapid_task_solving/one_shot_streetlearn.py  (new file, 265 lines)
@@ -0,0 +1,265 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""One-shot StreetLearn environment."""
|
||||
|
||||
import dm_env
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def deg_to_rad(x):
|
||||
"""Convert degrees to radians."""
|
||||
return x / 180. * np.pi
|
||||
|
||||
|
||||
def rad_to_deg(x):
|
||||
"""Convert radians to degrees."""
|
||||
return x * 180. / np.pi
|
||||
|
||||
|
||||
class OneShotStreetLearn(dm_env.Environment):
|
||||
"""One-shot Streetlearn environment."""
|
||||
|
||||
ACTION_NAMES = [
|
||||
'Forward',
|
||||
'Left',
|
||||
'Right',
|
||||
'Collect',
|
||||
]
|
||||
NUM_ACTIONS = len(ACTION_NAMES)
|
||||
|
||||
def __init__(self, dataset_path, max_episode_steps, num_junctions=8,
|
||||
target_reward=1., per_step_reward=0., observation_length=60,
|
||||
seed=None):
|
||||
self._graph = nx.read_gexf(dataset_path)
|
||||
self._node_attrs = self._graph.nodes(data=True)
|
||||
self._num_junctions = num_junctions
|
||||
self._observation_length = observation_length
|
||||
self._max_episode_steps = max_episode_steps
|
||||
self._target_reward = target_reward
|
||||
self._per_step_reward = per_step_reward
|
||||
self._rng = np.random.RandomState(seed)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._previous_action = ''
|
||||
self._episode_reward = 0.
|
||||
self._episode_steps = 0
|
||||
self._needs_reset = False
|
||||
self._subgraph = self.get_random_subgraph()
|
||||
self._observation_map = self.randomize_observations(self._subgraph)
|
||||
self._position = self._rng.choice(list(self._subgraph.nodes()))
|
||||
neighbours = self._neighbors_bearings(self._subgraph, self._position)
|
||||
self._neighbour = neighbours[self._rng.randint(len(neighbours))]
|
||||
self._set_new_goal()
|
||||
return dm_env.restart(self._observation())
|
||||
|
||||
@property
|
||||
def _current_edge(self):
|
||||
return (self._position, self._neighbour['neighbour'])
|
||||
|
||||
def _set_new_goal(self):
|
||||
goal = None
|
||||
edges = list(self._observation_map.keys())
|
||||
while goal is None or goal == self._current_edge:
|
||||
goal = edges[self._rng.randint(len(edges))]
|
||||
self._goal = goal
|
||||
|
||||
def _one_hot(self, edge):
|
||||
one_hot_vector = np.zeros([self._observation_length], dtype=np.int32)
|
||||
one_hot_vector[self._observation_map[edge]] = 1
|
||||
return one_hot_vector
|
||||
|
||||
def _observation(self):
|
||||
return {
|
||||
'position': np.array(self._one_hot(self._current_edge), dtype=np.int32),
|
||||
'goal': np.array(self._one_hot(self._goal), dtype=np.int32),
|
||||
}
|
||||
|
||||
def observation_spec(self):
|
||||
return {
|
||||
'position': dm_env.specs.Array(
|
||||
shape=(self._observation_length,), dtype=np.int32, name='position'),
|
||||
'goal': dm_env.specs.Array(
|
||||
shape=(self._observation_length,), dtype=np.int32, name='goal'),
|
||||
}
|
||||
|
||||
def action_spec(self):
|
||||
return dm_env.specs.DiscreteArray(self.NUM_ACTIONS)
|
||||
|
||||
def step(self, action):
|
||||
# If previous step was the last step of an episode, reset.
|
||||
if self._needs_reset:
|
||||
return self.reset()
|
||||
|
||||
# Increment step count and check if it's the last step of the episode.
|
||||
self._episode_steps += 1
|
||||
if self._episode_steps >= self._max_episode_steps:
|
||||
self._needs_reset = True
|
||||
transition = dm_env.termination
|
||||
else:
|
||||
transition = dm_env.transition
|
||||
|
||||
# Recompute agent's position
|
||||
self._move(action)
|
||||
self._previous_action = self.ACTION_NAMES[action]
|
||||
|
||||
# Get reward if agent is at the goal location and the selected action is
|
||||
# `collect`.
|
||||
if (self._current_edge == self._goal and
|
||||
self.ACTION_NAMES[action] == 'Collect'):
|
||||
reward = self._target_reward
|
||||
self._set_new_goal()
|
||||
else:
|
||||
reward = self._per_step_reward
|
||||
self._episode_reward += reward
|
||||
|
||||
return transition(reward, self._observation())
|
||||
|
||||
def randomize_observations(self, subgraph):
|
||||
edges = list(subgraph.edges())
|
||||
edges.extend([(y, x) for (x, y) in edges])
|
||||
obs_permutation = self._rng.permutation(self._observation_length)
|
||||
return {e: obs_permutation[i] for i, e in enumerate(edges)}
|
||||
|
||||
def _calculate_bearing(self, node, neighbor):
|
||||
lat1 = deg_to_rad(self._node_attrs[node]['lat'])
|
||||
lng1 = deg_to_rad(self._node_attrs[node]['lng'])
|
||||
lat2 = deg_to_rad(self._node_attrs[neighbor]['lat'])
|
||||
lng2 = deg_to_rad(self._node_attrs[neighbor]['lng'])
|
||||
delta_lng = lng2 - lng1
|
||||
theta = np.arctan2(
|
||||
np.sin(delta_lng) * np.cos(lat2),
|
||||
np.cos(lat1) * np.sin(lat2) -
|
||||
np.sin(lat1) * np.cos(lat2) * np.cos(delta_lng))
|
||||
return theta
|
||||
|
||||
def _neighbors_bearings(self, subgraph, node):
|
||||
bearings = []
|
||||
for neighbor in list(subgraph[node]):
|
||||
orientation = self._calculate_bearing(node, neighbor)
|
||||
bearings.append({'neighbour': neighbor, 'orientation': orientation})
|
||||
bearings.sort(key=lambda x: x['orientation'])
|
||||
return bearings
|
||||
|
||||
def _sort_neighbors(self, node, neighbour):
|
||||
bearings = self._neighbors_bearings(self._subgraph, node)
|
||||
bs = [x['orientation'] for x in bearings]
|
||||
idx = np.argmin(np.abs(bs - neighbour['orientation']))
|
||||
return {
|
||||
'forward': bearings[idx],
|
||||
'right': bearings[idx-1],
|
||||
'left': bearings[(idx+1) % len(bearings)],
|
||||
}
|
||||
|
||||
def _move(self, action):
|
||||
neighbours = self._sort_neighbors(self._position, self._neighbour)
|
||||
if action == 0:
|
||||
new_node = self._neighbour['neighbour']
|
||||
neighbours = self._sort_neighbors(new_node, neighbours['forward'])
|
||||
new_neighbour = neighbours['forward']
|
||||
else:
|
||||
new_node = self._position
|
||||
if action == 1:
|
||||
new_neighbour = neighbours['left']
|
||||
elif action == 2:
|
||||
new_neighbour = neighbours['right']
|
||||
else:
|
||||
new_neighbour = self._neighbour
|
||||
self._position = new_node
|
||||
self._neighbour = new_neighbour
|
||||
|
||||
def _all_next_junctions(self, subgraph, node):
|
||||
neighbors = list(subgraph[node])
|
||||
edges = [self._get_next_junction(subgraph, node, nb) for nb in neighbors]
|
||||
nodes = [y for (_, y) in edges]
|
||||
return nodes, edges
|
||||
|
||||
def _get_next_junction(self, subgraph, initial_node, next_node):
|
||||
node = initial_node
|
||||
while subgraph.degree(next_node) == 2:
|
||||
neighbours = list(subgraph.neighbors(next_node))
|
||||
neighbours.remove(node)
|
||||
node = next_node
|
||||
next_node = neighbours.pop()
|
||||
return (initial_node, next_node)
|
||||
|
||||
def get_random_subgraph(self):
|
||||
graph = self._graph
|
||||
num_nodes = len(graph)
|
||||
rnd_index = self._rng.randint(num_nodes)
|
||||
center_node = list(graph.nodes())[rnd_index]
|
||||
while graph.degree(center_node) <= 2:
|
||||
rnd_index = self._rng.randint(num_nodes)
|
||||
center_node = list(graph.nodes())[rnd_index]
|
||||
to_visit = [center_node]
|
||||
visited = []
|
||||
subgraph = nx.Graph()
|
||||
while to_visit:
|
||||
node = to_visit.pop(0)
|
||||
visited.append(node)
|
||||
new_nodes, new_edges = self._all_next_junctions(graph, node)
|
||||
subgraph.add_edges_from(new_edges)
|
||||
node_degrees = [subgraph.degree(n) for n in subgraph.nodes()]
|
||||
count_junctions = len(list(filter(lambda x: x > 2, node_degrees)))
|
||||
if count_junctions >= self._num_junctions:
|
||||
break
|
||||
new_nodes = filter(lambda x: x not in visited + to_visit, new_nodes)
|
||||
to_visit.extend(new_nodes)
|
||||
return subgraph
|
||||
|
||||
def draw_subgraph(self, ax=None):
|
||||
if ax is None:
|
||||
_ = plt.figure(figsize=(3, 3))
|
||||
ax = plt.gca()
|
||||
node_ids = list(self._subgraph.nodes())
|
||||
pos = {
|
||||
x: (self._node_attrs[x]['lat'], self._node_attrs[x]['lng'])
|
||||
for x in node_ids
|
||||
}
|
||||
labels = {}
|
||||
nc = 'pink'
|
||||
ec = 'black'
|
||||
ns = 50
|
||||
nshape = 'o'
|
||||
# Draw the current subgraph
|
||||
nx.draw(self._subgraph, pos=pos, node_color=nc, with_labels=False,
|
||||
node_size=ns, labels=labels, edgecolors=ec, node_shape=nshape,
|
||||
ax=ax)
|
||||
max_xy = np.array([np.array(x) for x in pos.values()]).max(0)
|
||||
min_xy = np.array([np.array(x) for x in pos.values()]).min(0)
|
||||
delta_xy = (max_xy - min_xy) / 6.
|
||||
ax.set_xlim([min_xy[0] - delta_xy[0], max_xy[0] + delta_xy[0]])
|
||||
ax.set_ylim([min_xy[1] - delta_xy[1], max_xy[1] + delta_xy[1]])
|
||||
# Draw goal position and orientation
|
||||
x = self._node_attrs[self._goal[0]]['lat']
|
||||
y = self._node_attrs[self._goal[0]]['lng']
|
||||
rotation = rad_to_deg(self._calculate_bearing(*self._goal))
|
||||
_ = ax.plot(x, y, marker=(3, 0, rotation - 90), color=(0, 0, 0),
|
||||
markersize=14, markerfacecolor='white')
|
||||
_ = ax.plot(x, y, marker=(2, 0, rotation - 90), color=(0, 0, 0),
|
||||
markersize=12, markerfacecolor='None')
|
||||
# Draw current position and orientation
|
||||
x = self._node_attrs[self._position]['lat']
|
||||
y = self._node_attrs[self._position]['lng']
|
||||
rotation = rad_to_deg(self._neighbour['orientation'])
|
||||
_ = ax.plot(x, y, marker=(3, 0, rotation - 90), color=(0, 0, 0),
|
||||
markersize=14, markerfacecolor='lightgreen')
|
||||
_ = ax.plot(x, y, marker=(2, 0, rotation - 90), color=(0, 0, 0),
|
||||
markersize=12, markerfacecolor='None')
|
||||
ax.set_title('{}\nEpisode reward = {}'.format(
|
||||
self._previous_action, self._episode_reward))
|
||||
return plt.gcf(), ax
|
||||
rapid_task_solving/requirements.txt  (new file, 6 lines)
@@ -0,0 +1,6 @@
|
||||
dm-env>=1.2
|
||||
dm-haiku>=0.0.3
|
||||
jax>=0.2.8
|
||||
matplotlib>=3.1.2
|
||||
networkx>=2.3
|
||||
numpy>=1.18.0
|
||||
wikigraphs/README.md  (new file, 230 lines)
@@ -0,0 +1,230 @@
|
||||
# WikiGraphs
|
||||
|
||||
This package provides tools to download the [WikiGraphs dataset](https://www.aclweb.org/anthology/2021.textgraphs-1.7.pdf)
|
||||
[1], collected by pairing each Wikipedia article from [WikiText-103](https://arxiv.org/pdf/1609.07843.pdf)
|
||||
[2] with a knowledge graph (a subgraph from [Freebase knowledge graph](https://dl.acm.org/doi/pdf/10.1145/1376616.1376746?casa_token=H2ggPTDMoZUAAAAA:7wBhO9hnOzNKoJyMH0PcpVQZ6Vg6Ud6hObiDJTzLCGRiBwmYFjOFSXrG5PcKLStu5-n4_OfkPJtbisQ)
|
||||
[3]). The baseline code to reproduce results in [1] is included as well. We hope
|
||||
this can spur more interest in developing models that can generate long text
|
||||
conditioned on graphs and generate graphs given text.
|
||||
|
||||
## Setup Jax environment
|
||||
|
||||
[Jax](https://github.com/google/jax#installation),
|
||||
[Haiku](https://github.com/deepmind/dm-haiku#installation), and
|
||||
[Optax](https://github.com/deepmind/optax) are needed for this
|
||||
package. It has been developed and tested on python 3 with the following
|
||||
packages:
|
||||
|
||||
* Jax==0.2.13
|
||||
* Haiku==0.0.5
|
||||
* Optax==0.0.6
|
||||
|
||||
Other packages required can be installed via:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Note: you may need to use `pip3` to select pip for python 3 and the `--user` option
|
||||
to install the packages to avoid permission issues.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Preparing the data
|
||||
|
||||
### Download the data
|
||||
|
||||
You can download and unzip the data by running the following command:
|
||||
|
||||
```bash
|
||||
bash scripts/download.sh
|
||||
```
|
||||
|
||||
This will put the downloaded WikiText-103 data in a temporary directory
|
||||
`/tmp/data` with the tokenized WikiText-103 data in `/tmp/data/wikitext-103` and
|
||||
the raw data in `/tmp/data/wikitext-103-raw`.
|
||||
|
||||
This script will also download our processed Freebase knowledge graph data in a
|
||||
temporary directory `/tmp/data/freebase`.
|
||||
|
||||
### Build vocabularies
|
||||
|
||||
For WikiText-103, run the following command to generate a vocabulary file:
|
||||
|
||||
```bash
|
||||
python scripts/build_vocab.py \
|
||||
--vocab_file_path=/tmp/data/wikitext-vocab.csv \
|
||||
--data_dir=/tmp/data/wikitext-103
|
||||
```
|
||||
|
||||
You can change the default file paths but make sure you make them consistent.
|
||||
|
||||
### Pair Freebase graphs with WikiText
|
||||
|
||||
You can run the following command to pair the Freebase graphs with WikiText-103
|
||||
articles:
|
||||
|
||||
```bash
|
||||
python scripts/freebase_preprocess.py \
|
||||
--freebase_dir=/tmp/data/freebase/max256 \
|
||||
--output_dir=/tmp/data/wikigraphs/max256
|
||||
```
|
||||
|
||||
where the `freebase_dir` `/tmp/data/freebase/max256` is the directory that
|
||||
contains the Freebase graphs, which should have files `train.gz`, `valid.gz`
|
||||
and `test.gz` in it; and `output_dir` is the directory that will contain the
|
||||
generated paired Freebase-WikiText data.
|
||||
|
||||
Note: you may need to use `python3` to select python 3 if you have both python 2
|
||||
and 3 on your system.
|
||||
|
||||
Given that there are the following numbers of articles in WikiText-103:
|
||||
|
||||
Subset | #articles
|
||||
------ | ---------
|
||||
Train | 28472*
|
||||
Valid | 60
|
||||
Test | 60
|
||||
|
||||
*Official number is 28475 but we were only able to find 28472 articles in
|
||||
training set.
|
||||
|
||||
Our dataset covers around 80% of the WikiText articles:
|
||||
|
||||
Max graph size | 256 | 512 | 1024
|
||||
---------------------------- | ----- | ----- | -----
|
||||
\#articles in training set | 23431 | 23718 | 23760
|
||||
Training set coverage | 82.3% | 83.3% | 83.5%
|
||||
\#articles in validation set | 48 | 48 | 48
|
||||
Validation set coverage | 80% | 80% | 80%
|
||||
\#articles in test set | 43 | 43 | 43
|
||||
Test set coverage | 71.7% | 71.7% | 71.7%
|
||||
|
||||
### Build vocabulary for WikiGraphs
|
||||
|
||||
You can build the vocabulary for the graph data (the max256 version) by running
|
||||
the following command:
|
||||
|
||||
```bash
|
||||
python scripts/build_vocab.py \
|
||||
--vocab_file_path=/tmp/data/graph-vocab.csv \
|
||||
--data_dir=/tmp/data/wikigraphs \
|
||||
--version=max256 \
|
||||
--data_type=graph \
|
||||
--threshold=15
|
||||
```
|
||||
|
||||
This gives us a vocabulary of size 31,087, with each token included in the
|
||||
vocabulary appearing at least 15 times.
|
||||
|
||||
You also need to build a separate text vocabulary for the WikiGraphs data, as
|
||||
our training set does not cover 100% of WikiText-103.
|
||||
|
||||
```bash
|
||||
python scripts/build_vocab.py \
|
||||
--vocab_file_path=/tmp/data/text-vocab.csv \
|
||||
--data_dir=/tmp/data/wikigraphs \
|
||||
--version=max256 \
|
||||
--data_type=text \
|
||||
--threshold=3
|
||||
```
|
||||
|
||||
Here we choose a threshold of 3, which is also used by the original WikiText-103 data;
|
||||
this gives us a vocabulary size of 238,068, only slightly smaller than the
|
||||
original vocabulary size.
|
||||
|
||||
Note that when loading these vocabularies to build tokenizers, our tokenizers
|
||||
will add a few extra tokens, like `<bos>`, `<pad>`, so the final vocab size
|
||||
might be slightly different from the numbers above, depending on which tokenizer
|
||||
you choose to use.
|
||||
|
||||
We only showcase how to build the vocabulary for the max256 version. The above
|
||||
steps can be easily adapted for the max512 and max1024 versions.
|
||||
|
||||
## Loading the dataset
|
||||
|
||||
We provide JAX modules to load the WikiGraphs dataset. There are three classes
|
||||
in `wikigraphs/data/paired_dataset.py`:
|
||||
|
||||
* `TextOnlyDataset`: loads only the text part of the WikiGraphs data
|
||||
* `Bow2TextDataset`: loads text and the paired graph represented as one big
|
||||
bag-of-words (BoW) on all nodes and edges from the graph
|
||||
* `Graph2TextDataset`: returns text and the paired graph in which each node or
|
||||
edge is represented by a BoW
|
||||
|
||||
Different versions of the dataset can be accessed by changing the `version`
|
||||
argument in each class. For more detailed usage please refer to
|
||||
`wikigraphs/data/paired_dataset_test.py`. In addition, the original WikiText dataset
|
||||
can be loaded via the `Dataset` class in `wikigraphs/data/wikitext.py`.
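
As a minimal sketch of how one of these classes might be instantiated (the
`subset`, `version` and `data_dir` arguments are assumptions based on the other
dataset classes in this package, not a verified signature):

```
from wikigraphs.data import paired_dataset

# Rough sketch only: argument names mirror `ParsedDataset` / `RawDataset`;
# adjust `data_dir` if you placed the paired data somewhere other than
# /tmp/data/wikigraphs.
dataset = paired_dataset.TextOnlyDataset(
    subset='valid', version='max256', data_dir='/tmp/data/wikigraphs')

# Dataset objects are iterable (see `wikigraphs/data/dataset.py`).
first_example = next(iter(dataset))
```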
|
||||
|
||||
Note: you may want to change the default data directory if you prefer to place
|
||||
it elsewhere.
|
||||
|
||||
## Run baseline models
|
||||
|
||||
Note: baseline models will be available soon.
|
||||
|
||||
## Citing WikiGraphs
|
||||
|
||||
To cite this work:
|
||||
|
||||
```
|
||||
@inproceedings{wang2021wikigraphs,
|
||||
title={WikiGraphs: A Wikipedia Text-Knowledge Graph Paired Dataset},
|
||||
author={Wang, Luyu and Li, Yujia and Aslan, Ozlem and Vinyals, Oriol},
|
||||
booktitle={Proceedings of the Graph-Based Methods for Natural Language Processing (TextGraphs)},
|
||||
pages={67--82},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
All code copyright 2021 DeepMind Technologies Limited
|
||||
|
||||
Code is licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
not use this file except in compliance with the License. You may obtain a copy
|
||||
of the License at:
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed
|
||||
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
[WikiGraphs](https://www.aclweb.org/anthology/2021.textgraphs-1.7.pdf)
|
||||
[1] is licensed under the terms of the Creative Commons
|
||||
Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
|
||||
[WikiText-103 data](https://arxiv.org/pdf/1609.07843.pdf) [2] (unchanged) is
|
||||
licensed by Salesforce.com, Inc. under the terms of the Creative Commons
|
||||
Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. You can find
|
||||
details about CC BY-SA 4.0 at:
|
||||
|
||||
https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
|
||||
[Freebase data](https://dl.acm.org/doi/pdf/10.1145/1376616.1376746?casa_token=H2ggPTDMoZUAAAAA:7wBhO9hnOzNKoJyMH0PcpVQZ6Vg6Ud6hObiDJTzLCGRiBwmYFjOFSXrG5PcKLStu5-n4_OfkPJtbisQ)
|
||||
[3] is licensed by Google LLC under the terms of the Creative
|
||||
Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
|
||||
https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
|
||||
## References
|
||||
|
||||
1. L. Wang, Y. Li, O. Aslan, and O. Vinyals, "[WikiGraphs: A Wikipedia Text-Knowledge Graph Paired Dataset](https://www.aclweb.org/anthology/2021.textgraphs-1.7.pdf)",
|
||||
in Proceedings of the Graph-based Methods for Natural Language Processing
|
||||
(TextGraphs), pages 67-82, 2021.
|
||||
2. S. Merity, C. Xiong, J. Bradbury, and R. Socher, "[Pointer sentinel mixture
|
||||
models](https://arxiv.org/pdf/1609.07843.pdf)",
|
||||
arXiv: 1609.07843, 2016.
|
||||
3. K. Bollacker, C. Evans, P. Paritosh, T. Sturge, and J. Taylor,
|
||||
"[Freebase: a collaboratively created graph database for structuring human
|
||||
knowledge](https://dl.acm.org/doi/pdf/10.1145/1376616.1376746?casa_token=H2ggPTDMoZUAAAAA:7wBhO9hnOzNKoJyMH0PcpVQZ6Vg6Ud6hObiDJTzLCGRiBwmYFjOFSXrG5PcKLStu5-n4_OfkPJtbisQ)",
|
||||
in Proceedings of ACM SIGMOD international conference on Management of data,
|
||||
pages 1247–1250, 2008.
|
||||
wikigraphs/requirements.txt  (new file, 7 lines)
@@ -0,0 +1,7 @@
|
||||
absl-py==0.10.0
|
||||
dm-haiku
|
||||
jax>=0.2.13
|
||||
nltk>=3.6.2
|
||||
numpy>=1.19.5
|
||||
optax>=0.0.6
|
||||
scikit-learn>=0.24.2
|
||||
wikigraphs/scripts/build_vocab.py  (new file, 166 lines)
@@ -0,0 +1,166 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Script for building vocabulary files for datasets."""
|
||||
|
||||
import collections
|
||||
import csv
|
||||
import enum
|
||||
import io
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
from wikigraphs.data import io_tools
|
||||
from wikigraphs.data import paired_dataset
|
||||
from wikigraphs.data import tokenizers
|
||||
from wikigraphs.data import wikitext
|
||||
|
||||
|
||||
class DatasetType(enum.Enum):
|
||||
text = 1
|
||||
graph = 2
|
||||
wikitext = 3
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string('data_dir', '', 'Path to the directory that contains the'
|
||||
' unzipped wikitext-103 data.')
|
||||
flags.DEFINE_string('vocab_file_path', '', 'Path to the output vocab file.')
|
||||
flags.DEFINE_enum_class('data_type', DatasetType.wikitext, DatasetType,
|
||||
'One of {`wikitext`, `graph`, `text`}.')
|
||||
flags.DEFINE_integer('threshold', 1, 'Frequency threshold for a word to be'
|
||||
' included in the vocabulary.')
|
||||
flags.DEFINE_string('version', 'max256', 'Which version of paired data to use.')
|
||||
|
||||
|
||||
def get_vocab(dataset: wikitext.RawDataset) -> List[Tuple[str, int]]:
|
||||
"""Build vocabulary, return (word, count) tuples sorted by count."""
|
||||
vocab = collections.defaultdict(int)
|
||||
|
||||
for pair in dataset:
|
||||
for t in pair.text.split(' '):
|
||||
if t:
|
||||
vocab[t] += 1
|
||||
|
||||
return sorted(vocab.items(), key=lambda t: -t[1])
|
||||
|
||||
|
||||
def write_vocab(vocab: List[Tuple[str, int]], output_path: str):
|
||||
"""Write a vocab list to a file."""
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
with open(output_path, mode='wb') as f_:
|
||||
with io.TextIOWrapper(f_, encoding='utf-8') as f:
|
||||
w = csv.writer(f)
|
||||
w.writerows(vocab)
|
||||
|
||||
|
||||
def build_wikitext_vocab():
|
||||
logging.info('Loading the dataset.')
|
||||
dataset = wikitext.RawDataset(subset='train', data_dir=FLAGS.data_dir)
|
||||
logging.info('Building the vocab.')
|
||||
vocab = get_vocab(dataset)
|
||||
logging.info('Finished, vocab size %d, total number of tokens %d',
|
||||
len(vocab), sum([c for _, c in vocab]))
|
||||
logging.info('Writing the vocab to %s', FLAGS.vocab_file_path)
|
||||
write_vocab(vocab, FLAGS.vocab_file_path)
|
||||
|
||||
|
||||
def build_graph_vocab():
|
||||
"""Build vocabulary for graph data."""
|
||||
logging.info('Loading the dataset.')
|
||||
dataset = paired_dataset.ParsedDataset(
|
||||
subset='train', data_dir=FLAGS.data_dir, version=FLAGS.version)
|
||||
logging.info('Building graph vocab.')
|
||||
|
||||
vocab = collections.defaultdict(int)
|
||||
for pair in dataset:
|
||||
graph = pair.graph
|
||||
for n in graph.nodes():
|
||||
for t in tokenizers.GraphTokenizer.split_node(n):
|
||||
if t:
|
||||
vocab[t] += 1
|
||||
for _, _, e in graph.edges():
|
||||
for t in tokenizers.GraphTokenizer.split_edge(e):
|
||||
if t:
|
||||
vocab[t] += 1
|
||||
|
||||
vocab = sorted(vocab.items(), key=lambda t: -t[1])
|
||||
vocab = [k for k, v in vocab if v >= FLAGS.threshold]
|
||||
|
||||
logging.info('Finished, vocab size %d.', len(vocab))
|
||||
logging.info('Writing the vocab to %s.', FLAGS.vocab_file_path)
|
||||
|
||||
io_tools.write_txt_file(FLAGS.vocab_file_path, '\n'.join(vocab),
|
||||
# Some unicode characters requires utf-16 to encode.
|
||||
encoding='utf-16')
|
||||
|
||||
|
||||
def build_text_vocab():
|
||||
"""Build vocabulary for the text part of the graph-to-text data."""
|
||||
logging.info('Loading the dataset.')
|
||||
dataset = paired_dataset.ParsedDataset(
|
||||
subset='train', data_dir=FLAGS.data_dir, version=FLAGS.version)
|
||||
logging.info('Building text vocab.')
|
||||
|
||||
vocab = collections.defaultdict(int)
|
||||
for pair in dataset:
|
||||
for t in pair.text.split(' '):
|
||||
if t:
|
||||
vocab[t] += 1
|
||||
|
||||
vocab = sorted(vocab.items(), key=lambda t: -t[1])
|
||||
logging.info('Finished, vocab size %d, total number of tokens %d.',
|
||||
len(vocab), sum([v for _, v in vocab]))
|
||||
vocab = [(k, v) for k, v in vocab if v >= FLAGS.threshold]
|
||||
logging.info('After filtering, vocab size %d.', len(vocab))
|
||||
logging.info('Writing the vocab to %s.', FLAGS.vocab_file_path)
|
||||
|
||||
write_vocab(vocab, FLAGS.vocab_file_path)
|
||||
|
||||
|
||||
def main(_):
|
||||
if FLAGS.data_type == DatasetType.wikitext:
|
||||
build_wikitext_vocab()
|
||||
elif FLAGS.data_type == DatasetType.text:
|
||||
build_text_vocab()
|
||||
elif FLAGS.data_type == DatasetType.graph:
|
||||
build_graph_vocab()
|
||||
else:
|
||||
raise ValueError(f'Unknown data type {FLAGS.data_type}.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
wikigraphs/scripts/download.sh  (new file, 63 lines)
@@ -0,0 +1,63 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
BASE_DIR=/tmp/data
|
||||
|
||||
# wikitext-103
|
||||
TARGET_DIR=${BASE_DIR}/wikitext-103
|
||||
mkdir -p ${TARGET_DIR}
|
||||
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip -P ${TARGET_DIR}
|
||||
unzip ${TARGET_DIR}/wikitext-103-v1.zip -d ${TARGET_DIR}
|
||||
mv ${TARGET_DIR}/wikitext-103/* ${TARGET_DIR}
|
||||
rm -rf ${TARGET_DIR}/wikitext-103 ${TARGET_DIR}/wikitext-103-v1.zip
|
||||
|
||||
# wikitext-103-raw
|
||||
TARGET_DIR=${BASE_DIR}/wikitext-103-raw
|
||||
mkdir -p ${TARGET_DIR}
|
||||
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip -P ${TARGET_DIR}
|
||||
unzip ${TARGET_DIR}/wikitext-103-raw-v1.zip -d ${TARGET_DIR}
|
||||
mv ${TARGET_DIR}/wikitext-103-raw/* ${TARGET_DIR}
|
||||
rm -rf ${TARGET_DIR}/wikitext-103-raw ${TARGET_DIR}/wikitext-103-raw-v1.zip
|
||||
|
||||
|
||||
# processed freebase graphs
|
||||
FREEBASE_TARGET_DIR=/tmp/data
|
||||
mkdir -p ${FREEBASE_TARGET_DIR}/packaged/
|
||||
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1uuSS2o72dUCJrcLff6NBiLJuTgSU-uRo' -O ${FREEBASE_TARGET_DIR}/packaged/max256.tar
|
||||
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1nOfUq3RUoPEWNZa2QHXl2q-1gA5F6kYh' -O ${FREEBASE_TARGET_DIR}/packaged/max512.tar
|
||||
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1uuJwkocJXG1UcQ-RCH3JU96VsDvi7UD2' -O ${FREEBASE_TARGET_DIR}/packaged/max1024.tar
|
||||
|
||||
for version in max1024 max512 max256
|
||||
do
|
||||
output_dir=${FREEBASE_TARGET_DIR}/freebase/${version}/
|
||||
mkdir -p ${output_dir}
|
||||
tar -xvf ${FREEBASE_TARGET_DIR}/packaged/${version}.tar -C ${output_dir}
|
||||
done
|
||||
rm -rf ${FREEBASE_TARGET_DIR}/packaged
|
||||
wikigraphs/scripts/freebase_preprocess.py  (new file, 106 lines)
@@ -0,0 +1,106 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Preprocess freebase data and pair with wikitext."""
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
from wikigraphs.data import io_tools
|
||||
from wikigraphs.data import wikitext
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string('freebase_dir', '', 'Directory that contains Freebase'
|
||||
' graphs.')
|
||||
flags.DEFINE_string('output_dir', '', 'Path to output directory to store the'
|
||||
' paired dataset.')
|
||||
|
||||
|
||||
def pair_graphs_with_wikitext(subset: str, graph_dir: str, output_dir: str):
|
||||
"""Pair graphs with wikitext articles, and write to output directory."""
|
||||
logging.info('Pairing graphs from the %s set from %s with wikitext.',
|
||||
subset, graph_dir)
|
||||
graphs = list(io_tools.graphs_from_file(
|
||||
os.path.join(graph_dir, f'{subset}.gz')))
|
||||
title2graph = {
|
||||
io_tools.normalize_freebase_string(g.title).replace(' ', ''): g
|
||||
for g in graphs}
|
||||
n_graphs = len(graphs)
|
||||
|
||||
# Use raw version of the wikitext data as the tokenized version has <unk> in
|
||||
# titles which is bad for matching. We will handle the <unk>s through the
|
||||
# tokenizer to make sure our data are equivalent to that of the tokenized
|
||||
# version of wikitext-103.
|
||||
wikitext_articles = list(wikitext.RawDataset(subset=subset, version='raw'))
|
||||
n_wiki = len(wikitext_articles)
|
||||
logging.info('Loaded %d graphs and %d wikitext articles in total.',
|
||||
n_graphs, n_wiki)
|
||||
|
||||
# Keep track of the article titles in the dataset. Unfortunately wikitext-103
|
||||
# has about 1% of duplicated articles, we want to take care of that.
|
||||
retrieved_titles = set()
|
||||
pairs = []
|
||||
n_duplicates = 0
|
||||
for a in wikitext_articles:
|
||||
title = wikitext.normalize_title(a.title).replace(' ', '')
|
||||
g = title2graph.get(title, None)
|
||||
if g is not None:
|
||||
if title not in retrieved_titles:
|
||||
retrieved_titles.add(title)
|
||||
pairs.append(io_tools.GraphTextPair(
|
||||
center_node=g.center,
|
||||
title=g.title,
|
||||
edges=g.edges,
|
||||
text=a.text))
|
||||
else:
|
||||
n_duplicates += 1
|
||||
|
||||
n_pairs = len(pairs)
|
||||
logging.info('Matched %d/%d = %.1f%% of wikitext articles,'
|
||||
' and %d/%d = %.1f%% of graphs.',
|
||||
n_pairs, n_wiki, float(n_pairs) / n_wiki * 100,
|
||||
n_pairs, n_graphs, float(n_pairs) / n_graphs * 100)
|
||||
logging.info('Detected %d/%d = %.1f%% of duplicated wikitext articles.',
|
||||
n_duplicates, n_wiki, float(n_duplicates) / n_wiki * 100)
|
||||
|
||||
io_tools.write_pairs_to_gzip_txt_file(
|
||||
os.path.join(output_dir, f'{subset}.gz'), pairs)
|
||||
|
||||
|
||||
def main(_):
|
||||
for subset in ['train', 'valid', 'test']:
|
||||
pair_graphs_with_wikitext(subset, FLAGS.freebase_dir, FLAGS.output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
wikigraphs/scripts/visualize_graph.py  (new file, 143 lines)
@@ -0,0 +1,143 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
r"""Tool to visualize graphs.
|
||||
|
||||
You need to have the command line tool `dot` installed locally, for example by
|
||||
`sudo apt-get install graphviz`.
|
||||
|
||||
Example usage:
|
||||
python visualize_graph.py \
|
||||
--logtostderr --graph_ids=0:48 --truncate_limit=500 --layout=fdp
|
||||
"""
|
||||
|
||||
import html
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
from wikigraphs.data import io_tools
|
||||
from wikigraphs.data import paired_dataset as pd
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string('subset', 'valid', 'Which subset to choose graphs from.')
|
||||
flags.DEFINE_string('graph_ids', '', 'A comma-separated string of graph IDs'
|
||||
' (0-based), for example `1,2,3`. Or alternatively a'
|
||||
' range, e.g. `0:10` which is equivalent to'
|
||||
' `0,1,2,3,...,9`.')
|
||||
flags.DEFINE_string('version', 'max256', 'Which version of data to load.')
|
||||
flags.DEFINE_string('data_dir', '', 'Path to a directory that contains the raw'
|
||||
' paired data, if provided.')
|
||||
flags.DEFINE_string('output_dir', '/tmp/graph_vis', 'Output directory to save'
|
||||
' the visualized graphs.')
|
||||
flags.DEFINE_integer('truncate_limit', -1, 'Maximum length for graph nodes in'
|
||||
' visualization.')
|
||||
flags.DEFINE_string('layout', 'fdp', 'Which one of the dot layout to use.')
|
||||
|
||||
|
||||
def truncate(s: str) -> str:
|
||||
if FLAGS.truncate_limit > 0 and len(s) > FLAGS.truncate_limit:
|
||||
s = s[:FLAGS.truncate_limit] + '...'
|
||||
return s
|
||||
|
||||
|
||||
def format_label(s: str, width: int = 40) -> str:
|
||||
"""Format a node / edge label."""
|
||||
s = io_tools.normalize_freebase_string(s)
|
||||
s = truncate(s)
|
||||
lines = s.split('\\n')
|
||||
output_lines = []
|
||||
for line in lines:
|
||||
line = html.escape(line)
|
||||
if width > 0:
|
||||
output_lines += textwrap.wrap(line, width)
|
||||
else:
|
||||
output_lines.append(line)
|
||||
return '<' + '<br/>'.join(output_lines) + '>'
|
||||
|
||||
|
||||
def graph_to_dot(graph_text_pair: io_tools.GraphTextPair) -> str:
|
||||
"""Convert a graph to a dot file."""
|
||||
dot = ['digraph {', 'node [shape=rect];']
|
||||
graph = pd.Graph.from_edges(graph_text_pair.edges)
|
||||
center_node_id = graph.node2id(graph_text_pair.center_node)
|
||||
|
||||
for i, n in enumerate(graph.nodes()):
|
||||
color = '#f5dc98' if i == center_node_id else (
|
||||
'#b0ffad' if not(n[0] == '"' and n[-1] == '"') else '#ffffff')
|
||||
label = format_label(n)
|
||||
dot.append(f'{i} [ label = {label}, fillcolor="{color}", style="filled"];')
|
||||
|
||||
for i, j, e in graph.edges():
|
||||
dot.append(f'{i} -> {j} [ label = {format_label(e, width=0)} ];')
|
||||
dot.append('}')
|
||||
return '\n'.join(dot)
|
||||
|
||||
|
||||
def visualize_graph(graph_text_pair: io_tools.GraphTextPair,
|
||||
graph_id: int,
|
||||
output_dir: str):
|
||||
"""Visualize a graph and save the visualization to the specified directory."""
|
||||
dot = graph_to_dot(graph_text_pair)
|
||||
output_file = os.path.join(output_dir, f'{graph_id}.dot')
|
||||
logging.info('Writing output to %s', output_file)
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(dot)
|
||||
pdf_output = os.path.join(output_dir, f'{graph_id}.pdf')
|
||||
os.system(f'dot -K{FLAGS.layout} -Tpdf -o {pdf_output} {output_file}')
|
||||
|
||||
|
||||
def main(_):
|
||||
logging.info('Loading the %s set of data.', FLAGS.subset)
|
||||
pairs = list(pd.RawDataset(subset=FLAGS.subset,
|
||||
data_dir=FLAGS.data_dir or None,
|
||||
shuffle_data=False,
|
||||
version=FLAGS.version))
|
||||
logging.info('Loaded %d graph-text pairs.', len(pairs))
|
||||
|
||||
if ':' in FLAGS.graph_ids:
|
||||
start, end = [int(i) for i in FLAGS.graph_ids.split(':')]
|
||||
graph_ids = list(range(start, end))
|
||||
else:
|
||||
graph_ids = [int(i) for i in FLAGS.graph_ids.split(',')]
|
||||
logging.info('Visualizing graphs with ID %r', graph_ids)
|
||||
|
||||
if not os.path.exists(FLAGS.output_dir):
|
||||
os.makedirs(FLAGS.output_dir)
|
||||
|
||||
for gid in graph_ids:
|
||||
visualize_graph(pairs[gid], gid, FLAGS.output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
wikigraphs/setup.py  (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Setup for pip package."""
|
||||
from setuptools import find_packages
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name='wikigraphs',
|
||||
version='0.0.1',
|
||||
description='A Wikipedia - knowledge graph paired dataset.',
|
||||
url='https://github.com/deepmind/deepmind-research/tree/master/wikigraphs',
|
||||
author='DeepMind',
|
||||
author_email='luyuwang@google.com',
|
||||
packages=find_packages(),
|
||||
license='Apache 2.0',
|
||||
)
|
||||
wikigraphs/wikigraphs/data/__init__.py  (new file, 36 lines)
@@ -0,0 +1,36 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""WikiGraphs data modules."""
|
||||
from . import dataset
|
||||
from . import io_tools
|
||||
from . import paired_dataset
|
||||
from . import tokenizers
|
||||
from . import tools
|
||||
from . import wikitext
|
||||
wikigraphs/wikigraphs/data/dataset.py  (new file, 59 lines)
@@ -0,0 +1,59 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Base class of the datasets."""
|
||||
|
||||
import abc
|
||||
from typing import Any, Iterator
|
||||
|
||||
|
||||
class Dataset(abc.ABC):
|
||||
"""Base class for all datasets.
|
||||
|
||||
All sub-classes should define `_load_data()` where an iterator
|
||||
`self._data_iter` should be instantiated that iterates over the dataset.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Constructor."""
|
||||
self._data_iter = None # An iterator produced by `self._load_data`.
|
||||
|
||||
@abc.abstractmethod
|
||||
def _load_data(self) -> Iterator[Any]:
|
||||
"""Prepare data for another pass through the dataset.
|
||||
|
||||
This method should return a generator in a child class.
|
||||
"""
|
||||
|
||||
def __next__(self):
|
||||
return next(self._data_iter)
|
||||
|
||||
def __iter__(self):
|
||||
self._data_iter = self._load_data()
|
||||
return self
|
||||
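
To make the contract of the base class above concrete, here is a minimal,
hypothetical subclass (an illustration, not a file in this package):
`_load_data` returns a fresh iterator and the base class supplies `__iter__`
and `__next__`.

```
class RangeDataset(Dataset):
  """Hypothetical example: yields the integers 0 .. n-1 on each pass."""

  def __init__(self, n: int):
    super().__init__()
    self._n = n

  def _load_data(self):
    # Return a new iterator so the dataset can be iterated multiple times.
    return iter(range(self._n))


assert list(RangeDataset(3)) == [0, 1, 2]
```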
wikigraphs/wikigraphs/data/io_tools.py  (new file, 179 lines)
@@ -0,0 +1,179 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Some tools for I/O."""
|
||||
|
||||
import gzip
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import NamedTuple, List, Iterator
|
||||
|
||||
from absl import logging
|
||||
|
||||
|
||||
def read_txt_file(file_path: str, encoding: str = 'utf-8') -> str:
|
||||
"""Read a plain txt file."""
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
return content.decode(encoding)
|
||||
|
||||
|
||||
def write_txt_file(file_path: str, txt: str, encoding: str = 'utf-8'):
|
||||
"""Write the given txt string to file."""
|
||||
make_dir_if_necessary(file_path)
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(txt.encode(encoding, 'surrogatepass'))
|
||||
|
||||
|
||||
def read_gzip_txt_file(file_path: str, encoding: str = 'utf-8') -> str:
|
||||
"""Read gzipped txt file."""
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
|
||||
with gzip.GzipFile(fileobj=io.BytesIO(content), mode='rb') as f:
|
||||
content = f.read()
|
||||
return content.decode(encoding)
|
||||
|
||||
|
||||
def make_dir_if_necessary(output_path):
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
|
||||
def write_lines_to_gzipped_file(file_path, lines):
|
||||
make_dir_if_necessary(file_path)
|
||||
with open(file_path, 'wb') as f_zip:
|
||||
with gzip.GzipFile(fileobj=f_zip, mode='wb') as f:
|
||||
f.write('\n'.join(lines).encode('utf-8'))
|
||||
|
||||
|
||||
class Graph(NamedTuple):
|
||||
title: str
|
||||
center: str
|
||||
edges: List[str]
|
||||
|
||||
|
||||
def graphs_from_file(file_path: str) -> Iterator[Graph]:
|
||||
"""Read freebase graphs from file.
|
||||
|
||||
Args:
|
||||
file_path: path to the input `.gz` file that contains a list of graphs.
|
||||
|
||||
Yields:
|
||||
graphs: graphs read from the file, one at a time.
|
||||
"""
|
||||
content = read_gzip_txt_file(file_path)
|
||||
|
||||
graph_header_sep_re = re.compile(
|
||||
r'(<graph center=[^ ]+ title="[^"]+">\n)')
|
||||
graph_header_re = re.compile(
|
||||
r'<graph center=([^ ]+) title="([^"]+)">\n')
|
||||
parts = graph_header_sep_re.split(content)
|
||||
|
||||
# Skip the first part which is empty
|
||||
for i in range(1, len(parts), 2):
|
||||
header, body = parts[i], parts[i + 1]
|
||||
m = graph_header_re.match(header)
|
||||
yield Graph(title=m.group(2),
|
||||
center=m.group(1),
|
||||
edges=body.strip().split('\n'))
|
||||
|
||||
|
||||
_UNICODE_RE = re.compile(r'(\$[0-9A-Fa-f]{4})')
|
||||
|
||||
|
||||
def normalize_freebase_string(s: str) -> str:
|
||||
"""Expand the `$xxxx` escaped unicode characters in the input string."""
|
||||
# '"' is escaped as '``', convert it back.
|
||||
s = s.replace('``', '"')
|
||||
parts = _UNICODE_RE.split(s)
|
||||
parts = [p if not _UNICODE_RE.match(p) else chr(int(p[1:], base=16))
|
||||
for p in parts]
|
||||
return ''.join(parts).replace('_', ' ')
|
||||
|
||||
|
||||
class GraphTextPair(NamedTuple):
|
||||
"""Text paired with raw graph represented as in `edges`."""
|
||||
center_node: str
|
||||
title: str
|
||||
edges: List[str]
|
||||
text: str
|
||||
|
||||
|
||||
def pair2lines(pair):
|
||||
lines = [f'<graph center={pair.center_node} title="{pair.title}">']
|
||||
lines.append('<section id="text">')
|
||||
lines.append(pair.text)
|
||||
lines.append('<section id="edges">')
|
||||
lines.extend(pair.edges)
|
||||
return lines
|
||||
|
||||
|
||||
def write_pairs_to_gzip_txt_file(file_path, pairs):
|
||||
logging.info('Writing %d pairs to %s.', len(pairs), file_path)
|
||||
lines = []
|
||||
for p in pairs:
|
||||
lines.extend(pair2lines(p))
|
||||
write_lines_to_gzipped_file(file_path, lines)
|
||||
|
||||
|
||||
def read_pairs_from_gzip_txt_file(file_path: str) -> Iterator[GraphTextPair]:
|
||||
"""Read graph-text pairs from gzip txt files.
|
||||
|
||||
Args:
|
||||
file_path: a `.gz` file of graph-text pairs written in the same format as
|
||||
the one produced by the `write_pairs_to_gzip_txt_file` function.
|
||||
|
||||
Yields:
|
||||
Graph-text pairs from this file.
|
||||
"""
|
||||
content = read_gzip_txt_file(file_path)
|
||||
|
||||
graph_header_sep_re = re.compile(
|
||||
r'(<graph center=[^ ]+ title="[^"]+">)')
|
||||
graph_header_re = re.compile(
|
||||
r'<graph center=([^ ]+) title="([^"]+)">$')
|
||||
section_sep_re = re.compile(r'\n(<section id="[^"]+">\n)')
|
||||
parts = graph_header_sep_re.split(content)
|
||||
|
||||
# Skip the first part which is empty
|
||||
for i in range(1, len(parts), 2):
|
||||
header, body = parts[i], parts[i + 1]
|
||||
m = graph_header_re.match(header)
|
||||
|
||||
# `section_parts` has 5 elements: an empty first part, the "text" section
# header, the text body, the "edges" section header, and the edges body.
|
||||
section_parts = section_sep_re.split(body)
|
||||
|
||||
yield GraphTextPair(center_node=m.group(1),
|
||||
title=m.group(2),
|
||||
text=section_parts[2],
|
||||
edges=section_parts[-1].strip().split('\n'))
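A minimal round-trip sketch of the pair format defined above (a `<graph center=... title="...">` header followed by `<section id="text">` and `<section id="edges">` blocks). This is illustrative only and not part of the commit; the field values and output path are made up:

```
from wikigraphs.data import io_tools

# A made-up pair; real pairs hold a Freebase subgraph and a WikiText-103 article.
pair = io_tools.GraphTextPair(
    center_node='ns/m.0h7h6',
    title='Homarus_gammarus',
    edges=['ns/m.0h7h6\tns/type.object.name\t"Homarus gammarus"'],
    text='Homarus gammarus , known as the European lobster , is a species ...')

path = '/tmp/data/example_pairs.gz'  # hypothetical location
io_tools.write_pairs_to_gzip_txt_file(path, [pair])
for p in io_tools.read_pairs_from_gzip_txt_file(path):
  print(io_tools.normalize_freebase_string(p.title))  # -> 'Homarus gammarus'
```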
767   wikigraphs/wikigraphs/data/paired_dataset.py (new file; diff not shown: too large)
271   wikigraphs/wikigraphs/data/paired_dataset_test.py (new file)
@@ -0,0 +1,271 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Tests for wikigraphs.data.paired_dataset."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from wikigraphs.data import io_tools
|
||||
from wikigraphs.data import paired_dataset
|
||||
from wikigraphs.data import tokenizers
|
||||
from wikigraphs.data import wikitext
|
||||
|
||||
|
||||
WIKITEXT_ROOT = '/tmp/data/wikitext-103'
|
||||
WIKIGRAPHS_ROOT = '/tmp/data/wikigraphs'
|
||||
WIKITEXT_VOCAB_FILE = '/tmp/data/wikitext-vocab.csv'
|
||||
GRAPH_VOCAB_FILE = '/tmp/data/graph-vocab.csv'
|
||||
|
||||
|
||||
class PairedDatasetTest(absltest.TestCase):
|
||||
|
||||
def test_raw_paired_dataset_size(self):
|
||||
dataset = paired_dataset.RawDataset(
|
||||
subset='valid', shuffle_data=False, data_dir=WIKIGRAPHS_ROOT)
|
||||
pairs = list(dataset)
|
||||
self.assertLen(pairs, 48)
|
||||
|
||||
self.assertEqual(pairs[0].title, 'Homarus_gammarus')
|
||||
self.assertEqual(pairs[-1].title, 'Rakie_Ayola')
|
||||
|
||||
# Make sure the content of the articles match the original
|
||||
wikitext_set = wikitext.RawDataset(
|
||||
subset='valid', shuffle_data=False, version='raw',
|
||||
data_dir=WIKITEXT_ROOT)
|
||||
title2article = {wikitext.normalize_title(a.title).replace(' ', ''): a.text
|
||||
for a in wikitext_set}
|
||||
for p in pairs:
|
||||
title = io_tools.normalize_freebase_string(p.title).replace(' ', '')
|
||||
article = title2article.get(title, None)
|
||||
self.assertIsNotNone(article)
|
||||
self.assertEqual(article, p.text)
|
||||
|
||||
def test_graph_from_edges(self):
|
||||
edges = ['A\tE1\tB',
|
||||
'A\tE2\tC',
|
||||
'B\tE1\tC',
|
||||
'C\tE3\tD',
|
||||
'C\tE2\tE']
|
||||
graph = paired_dataset.Graph.from_edges(edges)
|
||||
self.assertEqual(graph.nodes(), ['A', 'B', 'C', 'D', 'E'])
|
||||
self.assertEqual(graph.edges(), [(0, 1, 'E1'),
|
||||
(0, 2, 'E2'),
|
||||
(1, 2, 'E1'),
|
||||
(2, 3, 'E3'),
|
||||
(2, 4, 'E2')])
|
||||
|
||||
def test_graph_to_edges(self):
|
||||
edges = ['A\tE1\tB',
|
||||
'A\tE2\tC',
|
||||
'B\tE1\tC',
|
||||
'C\tE3\tD',
|
||||
'C\tE2\tE']
|
||||
graph = paired_dataset.Graph.from_edges(edges)
|
||||
self.assertEqual(graph.to_edges(), edges)
|
||||
|
||||
def test_bow2text_dataset(self):
|
||||
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
|
||||
graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
|
||||
|
||||
batch_size = 4
|
||||
seq_len = 256
|
||||
dataset = paired_dataset.Bow2TextDataset(
|
||||
tokenizer,
|
||||
graph_tokenizer,
|
||||
batch_size=batch_size,
|
||||
timesteps=seq_len,
|
||||
subset='valid',
|
||||
subsample_nodes=0.7,
|
||||
repeat=False,
|
||||
data_dir=WIKIGRAPHS_ROOT)
|
||||
|
||||
num_tokens = 0
|
||||
for batch in dataset:
|
||||
num_tokens += batch['mask'].sum()
|
||||
self.assertEqual(batch['graphs'].shape,
|
||||
(batch_size, graph_tokenizer.vocab_size))
|
||||
|
||||
raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
|
||||
raw_num_tokens = 0
|
||||
n_pairs = 0
|
||||
for pair in raw_dataset:
|
||||
raw_num_tokens += len(tokenizer.encode(
|
||||
pair.text, prepend_bos=True, append_eos=True))
|
||||
n_pairs += 1
|
||||
|
||||
# The first token of each example is not counted by `mask` as it masks the
|
||||
# targets, and the first token of each example never appears in the targets.
|
||||
self.assertEqual(raw_num_tokens, num_tokens + n_pairs)
|
||||
|
||||
def test_graph2text_dataset(self):
|
||||
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
|
||||
graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
|
||||
|
||||
batch_size = 4
|
||||
seq_len = 256
|
||||
dataset = paired_dataset.Graph2TextDataset(
|
||||
tokenizer,
|
||||
graph_tokenizer,
|
||||
batch_size=batch_size,
|
||||
timesteps=seq_len,
|
||||
subsample_nodes=0.8,
|
||||
subset='valid',
|
||||
data_dir=WIKIGRAPHS_ROOT)
|
||||
data_iter = iter(dataset)
|
||||
batch = next(data_iter)
|
||||
self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['target'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
|
||||
self.assertIsInstance(batch['graphs'], list)
|
||||
self.assertLen(batch['graphs'], batch_size)
|
||||
for i in range(batch_size):
|
||||
self.assertIsInstance(batch['graphs'][i], paired_dataset.ParsedGraph)
|
||||
|
||||
# +1 for the center_node mask
|
||||
self.assertEqual(
|
||||
batch['graphs'][i].nodes.shape[-1], graph_tokenizer.vocab_size + 1)
|
||||
self.assertEqual(
|
||||
batch['graphs'][i].edges.shape[-1], graph_tokenizer.vocab_size)
|
||||
n_edges = batch['graphs'][i].edges.shape[0]
|
||||
self.assertEqual(batch['graphs'][i].sender.shape, (n_edges,))
|
||||
self.assertEqual(batch['graphs'][i].receiver.shape, (n_edges,))
|
||||
|
||||
# Make sure the token count matches across the tokenized data and the raw
|
||||
# data set.
|
||||
num_tokens = 0
|
||||
for batch in dataset:
|
||||
num_tokens += batch['mask'].sum()
|
||||
|
||||
raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
|
||||
raw_num_tokens = 0
|
||||
n_pairs = 0
|
||||
for pair in raw_dataset:
|
||||
raw_num_tokens += len(tokenizer.encode(
|
||||
pair.text, prepend_bos=True, append_eos=True))
|
||||
n_pairs += 1
|
||||
|
||||
# The first token of each example is not counted by `mask` as it masks the
|
||||
# targets, and the first token of each example never appears in the targets.
|
||||
self.assertEqual(raw_num_tokens, num_tokens + n_pairs)
|
||||
|
||||
def test_text_only_dataset(self):
|
||||
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
|
||||
|
||||
batch_size = 4
|
||||
seq_len = 256
|
||||
dataset = paired_dataset.TextOnlyDataset(
|
||||
tokenizer,
|
||||
batch_size=batch_size,
|
||||
timesteps=seq_len,
|
||||
subset='valid',
|
||||
data_dir=WIKIGRAPHS_ROOT)
|
||||
data_iter = iter(dataset)
|
||||
batch = next(data_iter)
|
||||
faux_batch = dataset.return_faux_batch()
|
||||
|
||||
self.assertCountEqual(list(batch.keys()),
|
||||
['obs', 'target', 'should_reset', 'mask'])
|
||||
self.assertCountEqual(list(faux_batch.keys()),
|
||||
['obs', 'target', 'should_reset', 'mask'])
|
||||
for k, v in batch.items():
|
||||
faux_v = faux_batch[k]
|
||||
self.assertEqual(v.shape, faux_v.shape)
|
||||
self.assertEqual(v.dtype, faux_v.dtype)
|
||||
|
||||
self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['target'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
|
||||
|
||||
num_tokens = 0
|
||||
for batch in dataset:
|
||||
num_tokens += batch['mask'].sum()
|
||||
|
||||
raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
|
||||
raw_num_tokens = 0
|
||||
n_pairs = 0
|
||||
for pair in raw_dataset:
|
||||
raw_num_tokens += len(tokenizer.encode(
|
||||
pair.text, prepend_bos=True, append_eos=True))
|
||||
n_pairs += 1
|
||||
self.assertEqual(num_tokens + n_pairs, raw_num_tokens)
|
||||
|
||||
def test_bow_retrieval_dataset(self):
|
||||
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
|
||||
graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
|
||||
|
||||
batch_size = 4
|
||||
seq_len = 256
|
||||
dataset = paired_dataset.Bow2TextDataset(
|
||||
tokenizer,
|
||||
graph_tokenizer,
|
||||
batch_size=batch_size,
|
||||
timesteps=seq_len,
|
||||
subsample_nodes=0.8,
|
||||
graph_retrieval_dataset=True,
|
||||
subset='valid',
|
||||
data_dir=WIKIGRAPHS_ROOT)
|
||||
data_iter = iter(dataset)
|
||||
batch = next(data_iter)
|
||||
|
||||
self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['target'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['graph_id'].shape, (batch_size,))
|
||||
self.assertEqual(batch['seq_id'].shape, (batch_size,))
|
||||
|
||||
def test_graph_retrieval_dataset(self):
|
||||
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
|
||||
graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
|
||||
|
||||
batch_size = 4
|
||||
seq_len = 256
|
||||
dataset = paired_dataset.Graph2TextDataset(
|
||||
tokenizer,
|
||||
graph_tokenizer,
|
||||
batch_size=batch_size,
|
||||
timesteps=seq_len,
|
||||
subsample_nodes=0.8,
|
||||
graph_retrieval_dataset=True,
|
||||
subset='valid',
|
||||
data_dir=WIKIGRAPHS_ROOT)
|
||||
data_iter = iter(dataset)
|
||||
batch = next(data_iter)
|
||||
|
||||
self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['target'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
|
||||
self.assertEqual(batch['graph_id'].shape, (batch_size,))
|
||||
self.assertEqual(batch['seq_id'].shape, (batch_size,))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
230   wikigraphs/wikigraphs/data/tokenizers.py (new file)
@@ -0,0 +1,230 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Tokenizers for text data."""
|
||||
|
||||
import abc
|
||||
import csv
|
||||
import io
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import nltk
|
||||
import numpy as np
|
||||
|
||||
from wikigraphs.data import io_tools
|
||||
|
||||
|
||||
class Tokenizer(abc.ABC):
|
||||
"""Base class for tokenizers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def encode(self,
|
||||
inputs: str,
|
||||
prepend_bos: bool = False,
|
||||
append_eos: bool = False) -> np.ndarray:
|
||||
"""Encode input string into an array of token IDs.
|
||||
|
||||
Args:
|
||||
inputs: a string.
|
||||
prepend_bos: set to True to add <bos> token at the beginning of the token
|
||||
sequence.
|
||||
append_eos: set to True to add <eos> token at the end of the token
|
||||
sequence.
|
||||
|
||||
Returns:
|
||||
tokens: [n_tokens] int array.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def decode(self, inputs) -> str:
|
||||
"""Decode a sequence of tokens back into a string.
|
||||
|
||||
Args:
|
||||
inputs: array or list of ints.
|
||||
|
||||
Returns:
|
||||
s: the decoded string using this tokenizer.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def vocab_size(self) -> int:
|
||||
"""Size of the vocabulary."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def pad_token(self) -> int:
|
||||
"""ID of the <pad> token."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def bos_token(self) -> int:
|
||||
"""ID of the <bos> token."""
|
||||
|
||||
|
||||
class WordTokenizer(Tokenizer):
|
||||
"""Word-level tokenizer for white-space separated text data."""
|
||||
|
||||
def __init__(self, vocab_file: str):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
vocab_file: a csv vocab file.
|
||||
"""
|
||||
content = io_tools.read_txt_file(vocab_file, encoding='utf-8')
|
||||
|
||||
with io.StringIO(content) as f:
|
||||
r = csv.reader(f)
|
||||
vocab = [w for w, _ in r]
|
||||
|
||||
# Add pad and bos tokens to the vocab
|
||||
to_add = ['<pad>', '<bos>']
|
||||
if '<unk>' not in vocab:
|
||||
to_add.append('<unk>')
|
||||
vocab = to_add + vocab
|
||||
|
||||
# token-index mappings
|
||||
self._t2i = {t: i for i, t in enumerate(vocab)}
|
||||
self._i2t = {i: t for t, i in self._t2i.items()}
|
||||
|
||||
self._unk_token = self._t2i['<unk>']
|
||||
self._bos_token = self._t2i['<bos>']
|
||||
self._pad_token = self._t2i['<pad>']
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self._t2i)
|
||||
|
||||
def encode(self, inputs, prepend_bos=False, append_eos=False):
|
||||
tokens = [self._t2i.get(t, self._unk_token) for t in inputs.split(' ') if t]
|
||||
if prepend_bos:
|
||||
tokens = [self._bos_token] + tokens
|
||||
if append_eos:
|
||||
# Reuse <bos> as <eos>.
|
||||
tokens.append(self._bos_token)
|
||||
return np.array(tokens, dtype=np.int32)
|
||||
|
||||
def decode(self, inputs):
|
||||
"""Decode a sequence of token IDs back into a string."""
|
||||
# Remove the first <bos> token if there is any.
|
||||
if inputs[0] == self._bos_token:
|
||||
inputs = inputs[1:]
|
||||
tokens = []
|
||||
for i in inputs:
|
||||
# Use <bos> also as <eos> and stop there.
|
||||
if i == self._bos_token:
|
||||
break
|
||||
tokens.append(self._i2t[i])
|
||||
return ' '.join(tokens)
|
||||
|
||||
def pad_token(self):
|
||||
return self._pad_token
|
||||
|
||||
def bos_token(self):
|
||||
return self._bos_token
|
||||
|
||||
|
||||
class GraphTokenizer:
|
||||
"""Tokenizer for the content on the graphs."""
|
||||
|
||||
def __init__(self, vocab_file: str):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
vocab_file: path to a vocab file.
|
||||
"""
|
||||
content = io_tools.read_txt_file(vocab_file, encoding='utf-16')
|
||||
|
||||
vocab = content.split('\n')
|
||||
vocab = ['<pad>', '<bos>', '<unk>'] + vocab
|
||||
|
||||
# token-index mappings
|
||||
self._t2i = {t: i for i, t in enumerate(vocab)}
|
||||
self._i2t = {i: t for t, i in self._t2i.items()}
|
||||
|
||||
self._unk_token = self._t2i['<unk>']
|
||||
self._bos_token = self._t2i['<bos>']
|
||||
self._pad_token = self._t2i['<pad>']
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self._t2i)
|
||||
|
||||
def encode_node(self, txt: str) -> np.ndarray:
|
||||
return np.array([self._t2i.get(t, self._unk_token)
|
||||
for t in self.split_node(txt)])
|
||||
|
||||
def encode_edge(self, txt: str) -> np.ndarray:
|
||||
return np.array([self._t2i.get(t, self._unk_token)
|
||||
for t in self.split_edge(txt)])
|
||||
|
||||
def encode(self, inputs, prepend_bos=False, append_eos=False):
|
||||
tokens = [self._t2i.get(t, self._unk_token) for t in inputs.split(' ') if t]
|
||||
if prepend_bos:
|
||||
tokens = [self._bos_token] + tokens
|
||||
if append_eos:
|
||||
# Reuse <bos> as <eos>.
|
||||
tokens.append(self._bos_token)
|
||||
return np.array(tokens, dtype=np.int32)
|
||||
|
||||
def decode(self, inputs):
|
||||
"""Decode a sequence of token IDs back into a string."""
|
||||
# Remove the first <bos> token if there is any.
|
||||
if inputs[0] == self._bos_token:
|
||||
inputs = inputs[1:]
|
||||
tokens = []
|
||||
for i in inputs:
|
||||
# Use <bos> also as <eos> and stop there.
|
||||
if i == self._bos_token:
|
||||
break
|
||||
tokens.append(self._i2t[i])
|
||||
return ' '.join(tokens)
|
||||
|
||||
@classmethod
|
||||
def split_node(cls, txt: str) -> List[str]:
|
||||
"""Split a node string into a sequence of tokens."""
|
||||
if txt[0] == '"' and txt[-1] == '"': # Node is a string literal.
|
||||
tokens = nltk.wordpunct_tokenize(io_tools.normalize_freebase_string(
|
||||
txt[1:-1].lower()))
|
||||
for i, t in enumerate(tokens):
|
||||
if t.isnumeric():
|
||||
tokens[i] = '<number>'
|
||||
return tokens
|
||||
else: # If node is not a string literal it is always an entity.
|
||||
return ['<entity>']
|
||||
|
||||
@classmethod
|
||||
def split_edge(cls, txt: str) -> List[str]:
|
||||
"""Split an edge string into a sequence of tokens."""
|
||||
return re.split('[._ ]+', txt.lower().split('/')[1])
|
||||
|
||||
def pad_token(self):
|
||||
return self._pad_token
|
||||
|
||||
def bos_token(self):
|
||||
return self._bos_token
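A small usage sketch (not part of the commit) showing the `<bos>`-as-`<eos>` convention shared by the two tokenizers; the vocab path below is the same one assumed by the tests that follow:

```
from wikigraphs.data import tokenizers

tokenizer = tokenizers.WordTokenizer(vocab_file='/tmp/data/wikitext-vocab.csv')
ids = tokenizer.encode('the cat sat on the mat',
                       prepend_bos=True, append_eos=True)
# decode() skips the leading <bos> and stops at the trailing <bos>-as-<eos>, so
# the round trip recovers the input as long as every word is in the vocabulary.
print(tokenizer.decode(ids))  # -> 'the cat sat on the mat'
```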
78    wikigraphs/wikigraphs/data/tokenizers_test.py (new file)
@@ -0,0 +1,78 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Tests for wikigraphs.data.tokenizers."""
|
||||
from absl.testing import absltest
|
||||
from wikigraphs.data import tokenizers
|
||||
|
||||
|
||||
WIKITEXT_VOCAB_FILE = '/tmp/data/wikitext-vocab.csv'
|
||||
GRAPH_VOCAB_FILE = '/tmp/data/graph-vocab.csv'
|
||||
|
||||
|
||||
class TokenizerTest(absltest.TestCase):
|
||||
|
||||
def test_tokenizer(self):
|
||||
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
|
||||
# Vocab size must match published number.
|
||||
self.assertEqual(tokenizer.vocab_size, 267735 + 2)
|
||||
|
||||
s = 'Hello world ! \n How are you ?'
|
||||
encoded = tokenizer.encode(s, prepend_bos=True)
|
||||
self.assertEqual(encoded.shape, (9,))
|
||||
decoded = tokenizer.decode(encoded)
|
||||
self.assertEqual(s, decoded)
|
||||
|
||||
def test_graph_tokenizer_tokenize_nodes_edges(self):
|
||||
self.assertEqual(
|
||||
tokenizers.GraphTokenizer.split_node(
|
||||
'"Hello, how are you?"'),
|
||||
['hello', ',', 'how', 'are', 'you', '?'])
|
||||
self.assertEqual(
|
||||
tokenizers.GraphTokenizer.split_node(
|
||||
'"This building was built in 1998."'),
|
||||
['this', 'building', 'was', 'built', 'in', '<number>', '.'])
|
||||
self.assertEqual(
|
||||
tokenizers.GraphTokenizer.split_node('ns/m.030ssw'),
|
||||
['<entity>'])
|
||||
|
||||
self.assertEqual(
|
||||
tokenizers.GraphTokenizer.split_edge('ns/common.topic.description'),
|
||||
['common', 'topic', 'description'])
|
||||
self.assertEqual(
|
||||
tokenizers.GraphTokenizer.split_edge('ns/type.object.name'),
|
||||
['type', 'object', 'name'])
|
||||
|
||||
def test_graph_tokenizer_vocab(self):
|
||||
tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
|
||||
self.assertEqual(tokenizer.vocab_size, 31087 + 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
242   wikigraphs/wikigraphs/data/tools.py (new file)
@@ -0,0 +1,242 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Some tools for processing data."""
|
||||
|
||||
from typing import Any, Iterator
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pad_to(x: np.ndarray, size: int, axis: int = -1, pad_value: float = 0.):
|
||||
"""Pad an array to the specified size along a specified axis."""
|
||||
if x.shape[axis] > size:
|
||||
raise ValueError(f'Data item has size {x.shape[axis]} larger than {size}'
|
||||
f' in axis {axis} already.')
|
||||
elif x.shape[axis] == size:
|
||||
return x
|
||||
else:
|
||||
pad_amount = [(0, 0)] * x.ndim
|
||||
pad_amount[axis] = (0, size - x.shape[axis])
|
||||
return np.pad(x, pad_amount, mode='constant', constant_values=pad_value)
|
||||
|
||||
|
||||
def dynamic_batch(
|
||||
iterable: Iterator[Any],
|
||||
batch_size: int,
|
||||
timesteps: int,
|
||||
return_incomplete_batch: bool = False,
|
||||
pad: bool = False,
|
||||
pad_value: float = 0.) -> Iterator[Any]:
|
||||
"""Batches up values in iterable to [batch_size, timesteps].
|
||||
|
||||
This function takes items from the iterable and packs them into batches.
|
||||
Sequence #i in the batch is a continuation from the sequence #i in the
|
||||
previous batch, i.e. it will start from where the previous one left off.
|
||||
When an item is finished, a new item is taken from the iterable to append to
|
||||
the sequence and fill the batch.
|
||||
|
||||
This function is designed for language modeling, where the input and the
|
||||
target sequences are offset by one. We take that into account by making sure
|
||||
neighboring batches have one token overlap.
|
||||
|
||||
Example:
|
||||
If the iterable contains [[0, 1, 2], [10, 11, 12, 13, 14], [20, 21, 22]] and
|
||||
batch size is 2, timesteps is 3, then the first batch would be:
|
||||
[[0, 1, 2],
|
||||
[10, 11, 12]]
|
||||
then the second batch:
|
||||
[[2, 20, 21], # seq 0 finished; this row continues with seq 2
|
||||
[12, 13, 14]]
|
||||
Note the overlap of 1 token between these two batches, and the continuation
|
||||
of sequences across batches.
|
||||
|
||||
Args:
|
||||
iterable: the iterable that yields sequences of integer token IDs.
|
||||
batch_size: number of examples in a batch.
|
||||
timesteps: length of each sequence in a batch.
|
||||
return_incomplete_batch: if True return the incomplete batches, which
|
||||
typically appear at the end of the dataset.
|
||||
pad: set to True to pad the incomplete batches.
|
||||
pad_value: the value to use for padding.
|
||||
|
||||
Yields:
|
||||
batches: where batches['obs'] are the observations of size
|
||||
[batch_size, timesteps], and batches['should_reset'] is a 0/1 mask of
|
||||
the same size that marks sequence boundaries, e.g. the entries in this
|
||||
mask are all 0 except at locations where a new sequence is starting.
|
||||
"""
|
||||
if return_incomplete_batch and not pad:
|
||||
raise ValueError(
|
||||
f'If return_incomplete_batch, then pad must be True, currently {pad}.')
|
||||
|
||||
iterator = iter(iterable)
|
||||
elems = []
|
||||
for _ in range(batch_size):
|
||||
item = next(iterator)
|
||||
elems.append(item)
|
||||
start_batch = [True] * batch_size
|
||||
|
||||
iter_finished = False
|
||||
loaded_finished = False
|
||||
while not (iter_finished and loaded_finished):
|
||||
batch = []
|
||||
for i in range(batch_size):
|
||||
# should_reset value is 1 when a new sequence begins.
|
||||
# [old[-3], old[-2], old[-1], new[0], new[1], new[2]]
|
||||
# [0, 0, 0, 1, 0, 0]
|
||||
should_reset = np.zeros(timesteps, np.float32)
|
||||
if start_batch[i]:
|
||||
should_reset[0] = 1
|
||||
|
||||
# Pack new examples in the sequence until they go beyond the required
|
||||
# timesteps.
|
||||
while len(elems[i]) < timesteps:
|
||||
should_reset[len(elems[i])] = 1
|
||||
try:
|
||||
item = next(iterator)
|
||||
except StopIteration:
|
||||
iter_finished = True
|
||||
break
|
||||
elems[i] = np.concatenate([elems[i], item])
|
||||
|
||||
batch.append(dict(obs=elems[i][:timesteps], should_reset=should_reset))
|
||||
# Shift and make sure we have a 1 token overlap.
|
||||
elems[i] = elems[i][timesteps - 1:]
|
||||
# Since the last token is shifted to become the first token of the next
# batch, we need to make sure reset is handled properly as well.
|
||||
start_batch[i] = (should_reset[-1] == 1)
|
||||
|
||||
# If any loaded data is not yet consumed in the output we should keep
|
||||
# generating.
|
||||
loaded_finished = all(e.size == 0 for e in elems)
|
||||
|
||||
if not return_incomplete_batch:
|
||||
elem_len = len(batch[0]['obs'])
|
||||
if (elem_len != timesteps or
|
||||
not all(len(x['obs']) == elem_len for x in batch[1:])):
|
||||
logging.info('Dropping the (last?) incomplete batch.')
|
||||
break
|
||||
|
||||
if pad:
|
||||
for x in batch:
|
||||
x['obs'] = pad_to(x['obs'], timesteps, axis=0, pad_value=pad_value)
|
||||
|
||||
yield dict(
|
||||
obs=np.stack([x['obs'] for x in batch], axis=0),
|
||||
should_reset=np.stack([x['should_reset'] for x in batch], axis=0))
|
||||
|
||||
|
||||
def batch_graph_text_pairs(
|
||||
iterable: Iterator[Any],
|
||||
batch_size: int,
|
||||
timesteps: int,
|
||||
pad_value: float = 0.,
|
||||
seq_and_graph_id: bool = False) -> Iterator[Any]:
|
||||
"""Batch graph and text pairs.
|
||||
|
||||
This method pairs text with graphs: each text sequence is split into chunks
of size `timesteps` (with an overlap of 1), and the graph associated with the
text is attached to each of its chunks. The last incomplete chunk of each
text sequence is padded with `pad_value`.
|
||||
|
||||
Args:
|
||||
iterable: Iterable that returns (graph, sequence) pairs, graph can be
|
||||
anything, and sequence is a list of tokenized token IDs.
|
||||
batch_size: Number of examples in a batch.
|
||||
timesteps: Window size for the sequences.
|
||||
pad_value: Value to use for padding.
|
||||
seq_and_graph_id: whether the `iterable` contains `seq_id` and `graph_id`.
|
||||
|
||||
Yields:
|
||||
batch: a batch of text sequence paired with graphs.
|
||||
"""
|
||||
iterator = iter(iterable)
|
||||
seqs = [None] * batch_size
|
||||
graphs = [None] * batch_size
|
||||
graph_ids = [None] * batch_size
|
||||
seq_ids = [None] * batch_size
|
||||
|
||||
iter_finished = False
|
||||
loaded_finished = False
|
||||
while not (iter_finished and loaded_finished):
|
||||
batch = []
|
||||
for idx in range(batch_size):
|
||||
should_reset = np.zeros(timesteps, np.float32)
|
||||
# pylint: disable=g-explicit-length-test
|
||||
if seqs[idx] is None or len(seqs[idx]) == 0:
|
||||
should_reset[0] = 1
|
||||
# One sequence exhausted, get the next example.
|
||||
try:
|
||||
if seq_and_graph_id:
|
||||
(graph, seq), (graph_id, seq_id) = next(iterator)
|
||||
graph_ids[idx] = graph_id
|
||||
seq_ids[idx] = seq_id
|
||||
else:
|
||||
graph, seq = next(iterator)
|
||||
seqs[idx] = seq
|
||||
graphs[idx] = graph
|
||||
except StopIteration:
|
||||
iter_finished = True
|
||||
seqs[idx] = np.array([pad_value], dtype=np.int32)
|
||||
graphs[idx] = None
|
||||
|
||||
example = dict(obs=seqs[idx][:timesteps], graph=graphs[idx],
|
||||
should_reset=should_reset)
|
||||
if seq_and_graph_id:
|
||||
example['seq_id'] = seq_ids[idx]
|
||||
example['graph_id'] = graph_ids[idx]
|
||||
|
||||
batch.append(example)
|
||||
# Make sure that there is an overlap, as we generate targets by shifting
|
||||
# the tensor by 1 timestep. So the next element should be shifted by
|
||||
# `timesteps - 1` timesteps.
|
||||
seqs[idx] = seqs[idx][timesteps - 1:]
|
||||
|
||||
# Make sure all loaded data are consumed in the output
|
||||
loaded_finished = all(s.size == 0 for s in seqs)
|
||||
|
||||
# Also check for the last batch to avoid returning a fully empty batch
|
||||
if iter_finished and all([np.all(b['obs'] == pad_value) for b in batch]):
|
||||
break
|
||||
|
||||
# pad sequences to specified length
|
||||
for e in batch:
|
||||
e['obs'] = pad_to(e['obs'], timesteps, axis=0, pad_value=pad_value)
|
||||
stacked_batch = dict(
|
||||
obs=np.stack([e['obs'] for e in batch], axis=0),
|
||||
graphs=[e['graph'] for e in batch],
|
||||
should_reset=np.stack([e['should_reset'] for e in batch], axis=0))
|
||||
if seq_and_graph_id:
|
||||
stacked_batch['seq_id'] = np.stack(
|
||||
[e['seq_id'] for e in batch], axis=0)
|
||||
stacked_batch['graph_id'] = np.stack(
|
||||
[e['graph_id'] for e in batch], axis=0)
|
||||
yield stacked_batch
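To see `dynamic_batch` at work, the sketch below (not part of the commit) runs the example from its docstring; the first two batches it prints match the ones shown there:

```
import numpy as np
from wikigraphs.data import tools

seqs = [np.array(s, dtype=np.int32)
        for s in ([0, 1, 2], [10, 11, 12, 13, 14], [20, 21, 22])]
batches = list(tools.dynamic_batch(
    iter(seqs), batch_size=2, timesteps=3,
    return_incomplete_batch=True, pad=True))
print(batches[0]['obs'])  # [[0 1 2] [10 11 12]]
print(batches[1]['obs'])  # [[2 20 21] [12 13 14]]
```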
195   wikigraphs/wikigraphs/data/tools_test.py (new file)
@@ -0,0 +1,195 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Tests for wikigraphs.data.tools."""
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
from wikigraphs.data import tools
|
||||
|
||||
|
||||
class ToolsTest(absltest.TestCase):
|
||||
|
||||
def test_padding(self):
|
||||
np.testing.assert_array_equal(
|
||||
tools.pad_to(np.arange(3), 5),
|
||||
[0, 1, 2, 0, 0])
|
||||
np.testing.assert_array_equal(
|
||||
tools.pad_to(np.arange(3), 5, pad_value=-1),
|
||||
[0, 1, 2, -1, -1])
|
||||
np.testing.assert_array_equal(
|
||||
tools.pad_to(np.arange(6).reshape(2, 3), 4, axis=0, pad_value=-1),
|
||||
[[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[-1, -1, -1],
|
||||
[-1, -1, -1]])
|
||||
np.testing.assert_array_equal(
|
||||
tools.pad_to(np.arange(6).reshape(2, 3), 4, axis=-1, pad_value=-1),
|
||||
[[0, 1, 2, -1],
|
||||
[3, 4, 5, -1]])
|
||||
|
||||
def test_dynamic_batch(self):
|
||||
def dataset():
|
||||
data = [[1, 2, 2, 2],
|
||||
[1, 3, 3],
|
||||
[1, 4]]
|
||||
for d in data:
|
||||
yield np.array(d, dtype=np.int32)
|
||||
batches = list(tools.dynamic_batch(
|
||||
dataset(), batch_size=2, timesteps=3, return_incomplete_batch=False))
|
||||
self.assertLen(batches, 1)
|
||||
np.testing.assert_array_equal(
|
||||
batches[0]['obs'],
|
||||
[[1, 2, 2], [1, 3, 3]])
|
||||
np.testing.assert_array_equal(
|
||||
batches[0]['should_reset'],
|
||||
[[1, 0, 0], [1, 0, 0]])
|
||||
|
||||
batches = list(tools.dynamic_batch(
|
||||
dataset(), batch_size=2, timesteps=3, return_incomplete_batch=True,
|
||||
pad=True, pad_value=0))
|
||||
# Note `return_incomplete_batch=False` drops all the incomplete batches,
|
||||
# and this can be more than just the last batch.
|
||||
self.assertLen(batches, 3)
|
||||
np.testing.assert_array_equal(
|
||||
batches[0]['obs'],
|
||||
[[1, 2, 2], [1, 3, 3]])
|
||||
np.testing.assert_array_equal(
|
||||
batches[0]['should_reset'],
|
||||
[[1, 0, 0], [1, 0, 0]])
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
batches[1]['obs'],
|
||||
[[2, 2, 1], [3, 0, 0]])
|
||||
np.testing.assert_array_equal(
|
||||
batches[1]['should_reset'],
|
||||
[[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
batches[2]['obs'],
|
||||
[[1, 4, 0], [0, 0, 0]])
|
||||
np.testing.assert_array_equal(
|
||||
batches[2]['should_reset'],
|
||||
[[1, 0, 1], [1, 0, 0]])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
batches = list(tools.dynamic_batch(
|
||||
dataset(), batch_size=2, timesteps=3, return_incomplete_batch=True,
|
||||
pad=False))
|
||||
|
||||
def test_batch_graph_text_pairs(self):
|
||||
def source():
|
||||
yield (1, np.array([1, 1, 1, 1, 1], dtype=np.int32))
|
||||
yield (2, np.array([2, 2], dtype=np.int32))
|
||||
yield (3, np.array([3, 3, 3, 3, 3, 3], dtype=np.int32))
|
||||
|
||||
data_iter = tools.batch_graph_text_pairs(
|
||||
source(), batch_size=2, timesteps=3, pad_value=0)
|
||||
|
||||
batches = list(data_iter)
|
||||
self.assertLen(batches, 4)
|
||||
|
||||
batch = batches[0]
|
||||
np.testing.assert_array_equal(
|
||||
batch['obs'],
|
||||
[[1, 1, 1],
|
||||
[2, 2, 0]])
|
||||
self.assertEqual(batch['graphs'], [1, 2])
|
||||
np.testing.assert_array_equal(
|
||||
batch['should_reset'],
|
||||
[[1, 0, 0],
|
||||
[1, 0, 0]])
|
||||
|
||||
batch = batches[1]
|
||||
np.testing.assert_array_equal(
|
||||
batch['obs'],
|
||||
[[1, 1, 1],
|
||||
[3, 3, 3]])
|
||||
self.assertEqual(batch['graphs'], [1, 3])
|
||||
np.testing.assert_array_equal(
|
||||
batch['should_reset'],
|
||||
[[0, 0, 0],
|
||||
[1, 0, 0]])
|
||||
|
||||
batch = batches[2]
|
||||
np.testing.assert_array_equal(
|
||||
batch['obs'],
|
||||
[[1, 0, 0],
|
||||
[3, 3, 3]])
|
||||
self.assertEqual(batch['graphs'], [1, 3])
|
||||
np.testing.assert_array_equal(
|
||||
batch['should_reset'],
|
||||
[[0, 0, 0],
|
||||
[0, 0, 0]])
|
||||
|
||||
batch = batches[3]
|
||||
np.testing.assert_array_equal(
|
||||
batch['obs'],
|
||||
[[0, 0, 0],
|
||||
[3, 3, 0]])
|
||||
self.assertEqual(batch['graphs'], [None, 3])
|
||||
np.testing.assert_array_equal(
|
||||
batch['should_reset'],
|
||||
[[1, 0, 0],
|
||||
[0, 0, 0]])
|
||||
|
||||
def test_batch_graph_text_pairs_batch_size1(self):
|
||||
def source():
|
||||
yield (0, np.array([1, 2], dtype=np.int32))
|
||||
yield (1, np.array([1, 2, 3, 4, 5, 6], dtype=np.int32))
|
||||
|
||||
data_iter = tools.batch_graph_text_pairs(
|
||||
source(), batch_size=1, timesteps=3, pad_value=0)
|
||||
|
||||
batches = list(data_iter)
|
||||
|
||||
batch = batches[0]
|
||||
np.testing.assert_array_equal(batch['obs'], [[1, 2, 0]])
|
||||
self.assertEqual(batch['graphs'], [0])
|
||||
np.testing.assert_array_equal(batch['should_reset'], [[1, 0, 0]])
|
||||
|
||||
batch = batches[1]
|
||||
np.testing.assert_array_equal(batch['obs'], [[1, 2, 3]])
|
||||
self.assertEqual(batch['graphs'], [1])
|
||||
np.testing.assert_array_equal(batch['should_reset'], [[1, 0, 0]])
|
||||
|
||||
batch = batches[2]
|
||||
np.testing.assert_array_equal(batch['obs'], [[3, 4, 5]])
|
||||
self.assertEqual(batch['graphs'], [1])
|
||||
np.testing.assert_array_equal(batch['should_reset'], [[0, 0, 0]])
|
||||
|
||||
batch = batches[3]
|
||||
np.testing.assert_array_equal(batch['obs'], [[5, 6, 0]])
|
||||
self.assertEqual(batch['graphs'], [1])
|
||||
np.testing.assert_array_equal(batch['should_reset'], [[0, 0, 0]])
|
||||
|
||||
self.assertLen(batches, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
218   wikigraphs/wikigraphs/data/wikitext.py (new file)
@@ -0,0 +1,218 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Wikitext-103 datasets."""
|
||||
|
||||
import re
|
||||
from typing import NamedTuple, List
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
from wikigraphs.data import dataset
|
||||
from wikigraphs.data import tokenizers
|
||||
from wikigraphs.data import tools
|
||||
|
||||
|
||||
# Path to the tokenized WikiText-103 data; the raw version is expected in a
# sibling directory with a `-raw` suffix (see `_load_data` below).
|
||||
DATA_ROOT = '/tmp/data/wikitext-103'
|
||||
|
||||
|
||||
class WikitextArticle(NamedTuple):
|
||||
title: str
|
||||
text: str
|
||||
|
||||
|
||||
def articles_from_file(file_path: str) -> List[WikitextArticle]:
|
||||
"""Read wikitext articles from file.
|
||||
|
||||
Args:
|
||||
file_path: path to the input `.tokens` file.
|
||||
|
||||
Returns:
|
||||
A list of `WikitextArticle` tuples.
|
||||
"""
|
||||
with open(file_path, mode='rb') as f:
|
||||
content = f.read()
|
||||
content = content.decode('utf-8')
|
||||
|
||||
title_re = re.compile(r'(\n = ([^=].*) = \n \n)')
|
||||
parts = title_re.split(content)
|
||||
|
||||
# Skip the first part which is empty
|
||||
return [WikitextArticle(title=parts[i+1], text=parts[i] + parts[i+2])
|
||||
for i in range(1, len(parts), 3)]
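# Added note (not in the original file): after the empty first element,
# `title_re.split` yields (full ' = Title = ' header, bare title, article body)
# triplets; the header is folded back into `text` so the article keeps all of
# its original tokens.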
|
||||
|
||||
|
||||
class RawDataset(dataset.Dataset):
|
||||
"""Raw text dataset for wikitext-103."""
|
||||
|
||||
def __init__(self,
|
||||
subset: str = 'train',
|
||||
shuffle_data: bool = False,
|
||||
data_dir: str = None,
|
||||
version: str = 'tokens'):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
subset: which subset to load, one of {"train", "valid", "test"}.
|
||||
shuffle_data: if set to True the data will be randomly shuffled.
|
||||
data_dir: if provided will be used instead of the default `DATA_ROOT` as
|
||||
the directory that contains the data.
|
||||
version: one of {'tokens', 'raw'}
|
||||
"""
|
||||
super().__init__()
|
||||
self._subset = subset
|
||||
self._shuffle_data = shuffle_data
|
||||
self._data_dir = data_dir or DATA_ROOT
|
||||
self._dataset = None
|
||||
|
||||
allowed_versions = ('tokens', 'raw')
|
||||
if version not in allowed_versions:
|
||||
raise ValueError(f'Version must be one of {allowed_versions}.')
|
||||
self._version = version
|
||||
|
||||
def _load_data(self):
|
||||
"""Prepare data for another pass through the dataset."""
|
||||
if self._dataset is None:
|
||||
data_root = self._data_dir + ('-raw' if self._version == 'raw' else '')
|
||||
self._dataset = articles_from_file(
|
||||
f'{data_root}/wiki.{self._subset}.{self._version}')
|
||||
|
||||
def source():
|
||||
n_articles = len(self._dataset)
|
||||
if self._shuffle_data:
|
||||
idx = np.random.permutation(n_articles)
|
||||
else:
|
||||
idx = np.arange(n_articles)
|
||||
for i in range(n_articles):
|
||||
yield self._dataset[idx[i]]
|
||||
|
||||
return source()
|
||||
|
||||
|
||||
def normalize_title(title: str) -> str:
|
||||
"""Normalize the wikitext article title by handling special characters."""
|
||||
return title.replace(
|
||||
'@-@', '-').replace('@,@', ',').replace('@.@', '.').replace(' ', '')
|
||||
|
||||
|
||||
class WikitextDataset(dataset.Dataset):
|
||||
"""Tokenized dataset for wikitext-103."""
|
||||
|
||||
def __init__(self,
|
||||
tokenizer: tokenizers.Tokenizer,
|
||||
batch_size: int = 1,
|
||||
timesteps: int = 128,
|
||||
subset: str = 'train',
|
||||
shuffle_data: bool = True,
|
||||
data_dir: str = None,
|
||||
repeat: bool = False,
|
||||
debug: bool = False,
|
||||
**kwargs):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
tokenizer: a tokenizer for text data.
|
||||
batch_size: number of sequences to put into a batch.
|
||||
timesteps: length of the sequences.
|
||||
subset: which subset to load, one of {"train", "valid", "test"}.
|
||||
shuffle_data: if set to True the data will be randomly shuffled.
|
||||
data_dir: if provided will be used instead of the default `DATA_ROOT` as
|
||||
the directory that contains the data.
|
||||
repeat: set to False to go through the data only once, otherwise go
|
||||
through the data indefinitely.
|
||||
debug: set to True to only load a small amount of data for fast debugging.
|
||||
**kwargs: other arguments (for interface compatibility).
|
||||
"""
|
||||
super().__init__()
|
||||
self._tokenizer = tokenizer
|
||||
self._batch_size = batch_size
|
||||
self._timesteps = timesteps
|
||||
self._subset = subset
|
||||
self._shuffle_data = shuffle_data
|
||||
self._data_dir = data_dir
|
||||
self._repeat = repeat
|
||||
self._debug = debug
|
||||
self._dataset = None
|
||||
|
||||
def _load_data(self):
|
||||
"""Prepare data for one pass through the dataset."""
|
||||
# Pre-tokenize everything in our dataset so we don't have to when going
|
||||
# through the data more than once.
|
||||
if not self._dataset:
|
||||
raw_dataset = RawDataset(
|
||||
subset=self._subset, shuffle_data=False, data_dir=self._data_dir)
|
||||
if self._debug:
|
||||
# Load a small number of examples for debugging.
|
||||
self._dataset = [
|
||||
self._tokenizer.encode(next(raw_dataset).text, prepend_bos=True)
|
||||
for _ in range(5)]
|
||||
else:
|
||||
self._dataset = [self._tokenizer.encode(item.text, prepend_bos=True)
|
||||
for item in raw_dataset]
|
||||
logging.info('%s set loaded, total %d examples.',
|
||||
self._subset, len(self._dataset))
|
||||
|
||||
def source():
|
||||
idx = np.random.permutation(len(self._dataset))
|
||||
for i in idx:
|
||||
yield self._dataset[i]
|
||||
|
||||
def repeated_source():
|
||||
if self._repeat:
|
||||
while True:
|
||||
yield from source()
|
||||
else:
|
||||
yield from source()
|
||||
|
||||
data_iter = tools.dynamic_batch(
|
||||
repeated_source(),
|
||||
self._batch_size,
|
||||
self._timesteps + 1, # Extra token to account for the overlap.
|
||||
return_incomplete_batch=True,
|
||||
pad=True,
|
||||
pad_value=self._tokenizer.pad_token())
|
||||
data_iter = map(lambda x: dict( # pylint: disable=g-long-lambda
|
||||
obs=x['obs'][:, :-1],
|
||||
target=x['obs'][:, 1:],
|
||||
should_reset=x['should_reset'][:, :-1],
|
||||
mask=(x['obs'][:, 1:] != self._tokenizer.pad_token()).astype(
|
||||
np.float32),
|
||||
), data_iter)
|
||||
return data_iter
|
||||
|
||||
def return_faux_batch(self):
|
||||
"""Return a fake batch with the right shapes and dtypes."""
|
||||
obs = np.zeros((self._batch_size, self._timesteps), dtype=np.int32)
|
||||
target = np.zeros_like(obs, dtype=np.int32)
|
||||
should_reset = np.zeros_like(obs, dtype=np.float32)
|
||||
mask = np.zeros_like(obs, dtype=np.float32)
|
||||
return dict(obs=obs, target=target, should_reset=should_reset, mask=mask)
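Like the paired datasets, `WikitextDataset` is consumed as an iterator of batches. A minimal sketch (not part of the commit), assuming the data and vocab sit under the same `/tmp/data` paths used by the tests:

```
from wikigraphs.data import tokenizers
from wikigraphs.data import wikitext

tokenizer = tokenizers.WordTokenizer(vocab_file='/tmp/data/wikitext-vocab.csv')
dataset = wikitext.WikitextDataset(
    tokenizer=tokenizer, batch_size=4, timesteps=128,
    subset='valid', shuffle_data=False, repeat=False,
    data_dir='/tmp/data/wikitext-103')

batch = next(iter(dataset))
# `target` is `obs` shifted by one token; `mask` is 0 on padded target positions.
print(batch['obs'].shape, batch['target'].shape, batch['mask'].sum())
```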
84    wikigraphs/wikigraphs/data/wikitext_test.py (new file)
@@ -0,0 +1,84 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Tests for wikigraphs.data.wikitext."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from wikigraphs.data import tokenizers
|
||||
from wikigraphs.data import wikitext
|
||||
|
||||
|
||||
WIKITEXT_ROOT = '/tmp/data/wikitext-103'
|
||||
WIKITEXT_VOCAB_FILE = '/tmp/data/wikitext-vocab.csv'
|
||||
|
||||
|
||||
class WikitextTest(absltest.TestCase):
|
||||
|
||||
def test_wikitext_size(self):
|
||||
valid_set = wikitext.RawDataset(
|
||||
subset='valid', shuffle_data=False, data_dir=WIKITEXT_ROOT)
|
||||
n_tokens = 0
|
||||
n_articles = 0
|
||||
for article in valid_set:
|
||||
n_tokens += len([t for t in article.text.split(' ') if t])
|
||||
n_articles += 1
|
||||
|
||||
# Dataset size must match published values.
|
||||
self.assertEqual(n_tokens, 217646)
|
||||
self.assertEqual(n_articles, 60)
|
||||
|
||||
def test_wikitext_dataset_size(self):
|
||||
tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
|
||||
batch_size = 4
|
||||
timesteps = 256
|
||||
valid_set = wikitext.WikitextDataset(
|
||||
tokenizer=tokenizer, batch_size=batch_size, timesteps=timesteps,
|
||||
subset='valid', shuffle_data=False, repeat=False,
|
||||
data_dir=WIKITEXT_ROOT)
|
||||
n_tokens = 0
|
||||
n_bos = 0
|
||||
for batch in valid_set:
|
||||
n_tokens += (batch['obs'] != tokenizer.pad_token()).sum()
|
||||
n_bos += (batch['obs'] == tokenizer.bos_token()).sum()
|
||||
self.assertEqual(
|
||||
batch['obs'].shape, (batch_size, timesteps))
|
||||
self.assertEqual(
|
||||
batch['target'].shape, (batch_size, timesteps))
|
||||
self.assertEqual(
|
||||
batch['should_reset'].shape, (batch_size, timesteps))
|
||||
self.assertEqual(
|
||||
batch['mask'].shape, (batch_size, timesteps))
|
||||
|
||||
n_tokens -= n_bos
|
||||
self.assertEqual(n_tokens, 217646)
|
||||
self.assertEqual(n_bos, 60)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()