mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-02 06:06:39 +08:00
Release training scripts and samplers.
PiperOrigin-RevId: 390063136
This commit is contained in:
committed by
Diego de Las Casas
parent
5e35a17cf4
commit
1243baa0b0
+68
-1
@@ -167,7 +167,74 @@ it elsewhere.
|
||||
|
||||
## Run baseline models
|
||||
|
||||
Note: baseline models will be available soon.
|
||||
Note: our code supports training with multiple GPUs.
|
||||
|
||||
To run the default baseline GNN-based TransformerXL on Wikigraphs with 8
|
||||
GPUs:
|
||||
|
||||
```base
|
||||
python main.py --model_type=graph2text \
|
||||
--dataset=freebase2wikitext \
|
||||
--checkpoint_dir=/tmp/graph2text \
|
||||
--job_mode=train \
|
||||
--train_batch_size=64 \
|
||||
--gnn_num_layers=1 \
|
||||
--num_gpus=8
|
||||
```
|
||||
|
||||
We ran our experiments in the paper using 8 Nvidia V100 GPUs. To allow for
|
||||
batch parallization for the GNN-based (graph2text) model, we pad graphs to
|
||||
the largest graph in the batch. The full run takes almost 4 days. BoW- and
|
||||
nodes-based models can be trained within 14 hours because there is no
|
||||
additional padding.
|
||||
|
||||
Or to quickly test-run a small model:
|
||||
|
||||
```base
|
||||
python main.py --model_type=graph2text \
|
||||
--dataset=freebase2wikitext \
|
||||
--checkpoint_dir=/tmp/graph2text \
|
||||
--job_mode=train \
|
||||
--train_batch_size=2 \
|
||||
--gnn_num_layers=1
|
||||
```
|
||||
|
||||
To evaluate the model on the validation set (this only uses 1 GPU):
|
||||
|
||||
```base
|
||||
python main.py --model_type=graph2text \
|
||||
--dataset=freebase2wikitext \
|
||||
--checkpoint_dir=/tmp/graph2text \
|
||||
--job_mode=eval \
|
||||
--eval_subset=valid
|
||||
```
|
||||
|
||||
To generate 960 samples from the model using the graphs in the validation set (using 8 GPUs):
|
||||
|
||||
```base
|
||||
python main.py --model_type=graph2text \
|
||||
--dataset=freebase2wikitext \
|
||||
--checkpoint_dir=/tmp/graph2text \
|
||||
--job_mode=sample \
|
||||
--eval_subset=valid \
|
||||
--num_gpus=8 \
|
||||
--num_samples=960
|
||||
```
|
||||
|
||||
To compute the rBLEU score of the generated samples:
|
||||
|
||||
```base
|
||||
python scripts/compute_bleu_score.py --dataset=freebase2wikitext \
|
||||
--checkpoint_dir=/tmp/graph2text
|
||||
```
|
||||
|
||||
To compute the retrieval scores:
|
||||
|
||||
```base
|
||||
python main.py --dataset=freebase2wikitext \
|
||||
--job_mode=retrieve \
|
||||
--checkpoint_dir=/tmp/graph2text
|
||||
```
|
||||
|
||||
## Citing WikiGraphs
|
||||
|
||||
|
||||
@@ -0,0 +1,416 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Train a transformer for language modeling on Wikitext-103."""
|
||||
|
||||
import concurrent
|
||||
import functools
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
import jraph
|
||||
import numpy as np
|
||||
import optax
|
||||
|
||||
from updaters import CheckpointingUpdater
|
||||
from updaters import Updater
|
||||
import utils
|
||||
|
||||
|
||||
# Train
|
||||
flags.DEFINE_integer('train_batch_size', 4, '(Per-Device) batch size for'
|
||||
' training.')
|
||||
flags.DEFINE_integer('train_timesteps', 150, 'Sequence length to learn on')
|
||||
flags.DEFINE_integer('train_memory_size', 150, 'Memory size for transformer XL')
|
||||
flags.DEFINE_bool('debug', False, 'Whether to turn on debugging mode')
|
||||
flags.DEFINE_string('job_mode', 'train',
|
||||
'One of `train`, `eval`, `sample`, `retrieve`.')
|
||||
flags.DEFINE_integer('random_seed', 42, 'Random seed id.')
|
||||
flags.DEFINE_integer('num_gpus', 8, 'Number of GPUs for training.')
|
||||
|
||||
# Eval
|
||||
flags.DEFINE_integer('eval_batch_size', 1, 'Evaluation batch size')
|
||||
flags.DEFINE_string('eval_subset', 'valid', 'Which subset to evaluate on,'
|
||||
' one of `valid`, `test`.')
|
||||
flags.DEFINE_integer('eval_every', 10, 'Evaluation frequency.')
|
||||
flags.DEFINE_integer('eval_timesteps', 64, 'Sequence length to learn on')
|
||||
flags.DEFINE_integer('eval_memory_size', 640, 'Memory size for transformer XL')
|
||||
flags.DEFINE_integer('max_eval_samples', -1, 'Max number of eval samples. Set'
|
||||
' as -1 to use the entire eval set.')
|
||||
|
||||
# Model
|
||||
flags.DEFINE_integer('emb_dim', 410, 'model width')
|
||||
flags.DEFINE_integer('num_heads', 10, 'Number of attention heads')
|
||||
flags.DEFINE_integer('num_layers', 16, 'Number of transformer layers')
|
||||
flags.DEFINE_integer('dense_dim', 2100, 'Size of dense hidden layer.')
|
||||
flags.DEFINE_integer('tail_shrink_factor', 4,
|
||||
'Low-frequency vocabulary shrinkage factor in adaptive'
|
||||
' softmax.')
|
||||
flags.DEFINE_string('emb_type', 'adaptive_softmax', 'Type of the word embedding'
|
||||
' layer.')
|
||||
flags.DEFINE_integer('clamp_len', 400, 'Clamp length for transformer XL.')
|
||||
flags.DEFINE_float('dropout', 0.1, 'Dropout rate for the transformer layers.')
|
||||
flags.DEFINE_float('dropout_attn', 0.0, 'Dropout rate for the attention'
|
||||
' weights.')
|
||||
flags.DEFINE_float('self_att_init_scale', 0.02,
|
||||
'Self attention module initilization scale.')
|
||||
flags.DEFINE_float('dense_init_scale', 0.02,
|
||||
'Dense module initilization scale.')
|
||||
|
||||
# Graph neural net configs
|
||||
flags.DEFINE_string('gnn_embed_type', 'adaptive', 'Token embedding type for the'
|
||||
' graph.')
|
||||
flags.DEFINE_integer('gnn_embed_dim', 128, 'Graph node embedding size.')
|
||||
flags.DEFINE_integer('gnn_num_layers', 1, 'Number of layers in the GNN.')
|
||||
flags.DEFINE_bool('gnn_layer_norm', True, 'Whether to use layer norm in GNN.')
|
||||
|
||||
# Bag-of-words to text configs
|
||||
flags.DEFINE_integer('bow_embedding_dim', 256, 'Size of the bow embeddings.')
|
||||
flags.DEFINE_integer('bow_n_tokens', 1, 'Number of tokens to use for the'
|
||||
' bow2text model.')
|
||||
|
||||
# Sampling
|
||||
flags.DEFINE_float('sampling_temperature', 0.8, 'Temperature used for'
|
||||
' sampling. Sampling becomes more deterministic with a'
|
||||
' lower temperature. Setting temperature to 1.0 samples'
|
||||
' from the model distribution.')
|
||||
flags.DEFINE_bool('prompt_title', False, 'Whether to prompt title when sample')
|
||||
flags.DEFINE_integer('sample_length', 512, 'Length of samples.')
|
||||
flags.DEFINE_integer('sample_memory_size', 640, 'Memory size for sampling.')
|
||||
flags.DEFINE_integer('num_samples', 1000, 'Maximum number of samples to'
|
||||
' generate.')
|
||||
|
||||
# Optimization
|
||||
flags.DEFINE_float('init_lr', 0.00025, 'Initial learning rate.')
|
||||
flags.DEFINE_float('min_lr_ratio', 0.0, 'Minimum learning rate as a ratio of'
|
||||
' `init_lr`.')
|
||||
flags.DEFINE_string('lr_schedule', 'cosine', 'One of `default`, `cosine`.')
|
||||
flags.DEFINE_float('grad_clip', 0.25, 'Maximum gradient norm allowed for'
|
||||
' clipping, set to a very large number to disable clipping.')
|
||||
flags.DEFINE_integer('max_steps', 200_000, 'Number of training steps.')
|
||||
flags.DEFINE_string('checkpoint_dir', '/tmp/graph2text',
|
||||
'Directory to store checkpoints.')
|
||||
|
||||
# Data
|
||||
flags.DEFINE_string('dataset', 'freebase2wikitext', 'Which dataset to train on,'
|
||||
' one of "wikitext", "freebase2wikitext".')
|
||||
flags.DEFINE_string('model_type', 'graph2text', 'One of "text", "graph2text",'
|
||||
' "bow2text".')
|
||||
flags.DEFINE_string('graph_data_version', 'max256', 'One of "max256", "max512",'
|
||||
' "max1024".')
|
||||
|
||||
flags.DEFINE_integer('log_every', 50, 'Log every this many steps.')
|
||||
flags.DEFINE_integer('ckpt_every', 1000, 'Checkpoint every this many steps.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def _preprocess(batch, num_devices=1):
|
||||
return utils.preprocess(batch, FLAGS.model_type, num_devices)
|
||||
|
||||
|
||||
def _train(updater, train_dataset, num_devices):
|
||||
"""Train the transformer model."""
|
||||
# Initialize parameters.
|
||||
logging.info('Initializing parameters...')
|
||||
rng = jax.random.PRNGKey(FLAGS.random_seed)
|
||||
state = updater.init(
|
||||
rng, _preprocess(train_dataset.return_faux_batch(), num_devices))
|
||||
|
||||
logging.info('Starting train loop...')
|
||||
prev_time = time.time()
|
||||
while True:
|
||||
data = next(train_dataset)
|
||||
state, metrics = updater.update(state, _preprocess(data, num_devices))
|
||||
# We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
|
||||
# Using values from state/metrics too often will block the runahead and can
|
||||
# cause these overheads to become more prominent.
|
||||
step = np.array(metrics['step'])
|
||||
if step % FLAGS.log_every == 0:
|
||||
steps_per_sec = FLAGS.log_every / (time.time() - prev_time)
|
||||
prev_time = time.time()
|
||||
metrics.update({'steps_per_sec': steps_per_sec})
|
||||
logging.info({k: float(v) for k, v in metrics.items()})
|
||||
if step % FLAGS.ckpt_every == 0:
|
||||
updater.save_checkpoint(state)
|
||||
if step > FLAGS.max_steps:
|
||||
break
|
||||
|
||||
|
||||
def _eval(updater, eval_dataset):
|
||||
"""Evaluate the transformer model."""
|
||||
checkpoint_state = updater.load_checkpoint()
|
||||
rng = jax.random.PRNGKey(FLAGS.random_seed)
|
||||
state = updater.init_from_checkpoint(
|
||||
rng, _preprocess(eval_dataset.return_faux_batch()), checkpoint_state)
|
||||
eval_out, state = utils.evaluate(
|
||||
eval_dataset, state, updater, FLAGS.eval_batch_size, _preprocess,
|
||||
FLAGS.max_eval_samples, print_progress_every=20)
|
||||
logging.info('Eval output: %s', eval_out)
|
||||
|
||||
|
||||
def _retrieve(updater, eval_dataset):
|
||||
"""Graph and text retrieval using the transformer model."""
|
||||
checkpoint_state = updater.load_checkpoint()
|
||||
rng = jax.random.PRNGKey(FLAGS.random_seed)
|
||||
state = updater.init_from_checkpoint(
|
||||
rng, _preprocess(eval_dataset.return_faux_batch()), checkpoint_state)
|
||||
retrieval_out, _ = utils.compute_text_graph_relevance(
|
||||
eval_dataset, state, updater, preprocess_fn=_preprocess,
|
||||
print_progress_every=20)
|
||||
logging.info('Retrieval output: %s', retrieval_out)
|
||||
|
||||
|
||||
def _sample(eval_dataset, tokenizer, devices, batch_size=1):
|
||||
"""Evaluate the graph2text transformer."""
|
||||
checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, 'checkpoint.pkl')
|
||||
logging.info('Loading checkpoint from %s', checkpoint_dir)
|
||||
with open(checkpoint_dir, 'rb') as f:
|
||||
state = pickle.load(f)
|
||||
|
||||
if FLAGS.model_type == 'graph2text':
|
||||
# process list of graphs into a batch
|
||||
eval_dataset = map(lambda x: dict( # pylint: disable=g-long-lambda
|
||||
obs=x['obs'],
|
||||
target=x['target'],
|
||||
should_reset=x['should_reset'],
|
||||
mask=x['mask'],
|
||||
graphs=jraph.batch(x['graphs']),
|
||||
), eval_dataset)
|
||||
eval_dataset = utils.take_unique_graphs(eval_dataset, FLAGS.model_type)
|
||||
|
||||
samplers = []
|
||||
for device in devices:
|
||||
sampler = utils.build_sampler(tokenizer, device=device)
|
||||
samplers.append(sampler)
|
||||
|
||||
step = state['step']
|
||||
params = state['params']
|
||||
sample_logger = []
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(samplers)) as executor:
|
||||
futures = dict()
|
||||
for sampler in samplers:
|
||||
batch = next(eval_dataset)
|
||||
prompts = utils.construct_prompts(
|
||||
batch['obs'], batch_size, FLAGS.sample_length, tokenizer,
|
||||
prompt_title=FLAGS.prompt_title)
|
||||
if FLAGS.model_type in ['graph2text', 'bow2text']:
|
||||
future = executor.submit(
|
||||
utils.generate_samples, params, tokenizer, sampler,
|
||||
model_type=FLAGS.model_type, prompts=prompts,
|
||||
graphs=batch['graphs'])
|
||||
futures[future] = (sampler, batch['graphs'], batch['obs'])
|
||||
else:
|
||||
future = executor.submit(
|
||||
utils.generate_samples, params, tokenizer, sampler,
|
||||
model_type=FLAGS.model_type, prompts=prompts, graphs=None)
|
||||
futures[future] = (sampler, batch['obs'])
|
||||
|
||||
n_samples = 0
|
||||
|
||||
while n_samples < FLAGS.num_samples:
|
||||
for future, future_items in list(futures.items()):
|
||||
if not future.done():
|
||||
continue
|
||||
samples, tokens = future.result()
|
||||
if FLAGS.model_type == 'graph2text':
|
||||
sampler, graphs, text = future_items
|
||||
graphs = jraph.unbatch(graphs)
|
||||
elif FLAGS.model_type == 'bow2text':
|
||||
sampler, graphs, text = future_items
|
||||
else:
|
||||
sampler, text = future_items
|
||||
|
||||
if FLAGS.model_type in ['graph2text', 'bow2text']:
|
||||
for s, g, tk, txt in zip(samples, graphs, tokens, text):
|
||||
# Only log a small fraction of the generated samples, if we are
|
||||
# generating non-stop. Otherwise log every sample.
|
||||
logging.info('[step %d]', step)
|
||||
logging.info('graph=\n%r', g)
|
||||
logging.info('sample=\n%s', s)
|
||||
if FLAGS.model_type == 'graph2text':
|
||||
sample_logger.append({
|
||||
'step': step,
|
||||
'sample': s,
|
||||
'sample_tokens': tk,
|
||||
'ground_truth_text': txt,
|
||||
})
|
||||
elif FLAGS.model_type == 'bow2text':
|
||||
sample_logger.append({
|
||||
'step': step,
|
||||
'bow': g,
|
||||
'sample': s,
|
||||
'sample_tokens': tk,
|
||||
'ground_truth_text': txt,
|
||||
})
|
||||
else:
|
||||
for s, tk, txt in zip(samples, tokens, text):
|
||||
# Only log a small fraction of the generated samples, if we are
|
||||
# generating non-stop. Otherwise log every sample.
|
||||
logging.info('[step %d]', step)
|
||||
logging.info('sample=\n%s', s)
|
||||
sample_logger.append({
|
||||
'step': step,
|
||||
'sample': s,
|
||||
'sample_tokens': tk,
|
||||
'ground_truth_text': txt,
|
||||
})
|
||||
n_samples += len(samples)
|
||||
logging.info('Finished generating %d samples', n_samples)
|
||||
|
||||
del futures[future]
|
||||
|
||||
if n_samples < FLAGS.num_samples:
|
||||
batch = next(eval_dataset)
|
||||
prompts = utils.construct_prompts(
|
||||
batch['obs'], batch_size, FLAGS.sample_length, tokenizer,
|
||||
prompt_title=FLAGS.prompt_title)
|
||||
if FLAGS.model_type in ['graph2text', 'bow2text']:
|
||||
future = executor.submit(
|
||||
utils.generate_samples, params, tokenizer, sampler,
|
||||
model_type=FLAGS.model_type, prompts=prompts,
|
||||
graphs=batch['graphs'])
|
||||
futures[future] = (sampler, batch['graphs'], batch['obs'])
|
||||
else:
|
||||
future = executor.submit(
|
||||
utils.generate_samples, params, tokenizer, sampler,
|
||||
model_type=FLAGS.model_type, prompts=prompts, graphs=None)
|
||||
futures[future] = (sampler, batch['obs'])
|
||||
|
||||
logging.info('Finished')
|
||||
path = os.path.join(FLAGS.checkpoint_dir, 'samples.pkl')
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(dict(samples=sample_logger), f)
|
||||
logging.info('Samples saved to %s', path)
|
||||
|
||||
|
||||
def main(_):
|
||||
# Create the dataset.
|
||||
tokenizer = utils.init_tokenizer(FLAGS.dataset)
|
||||
graph_tokenizer = utils.init_graph_tokenizer()
|
||||
dataset_class = utils.get_dataset_class(FLAGS.dataset, FLAGS.model_type)
|
||||
has_graph = True if FLAGS.model_type == 'graph2text' else False
|
||||
local_devices = jax.local_devices()
|
||||
num_gpus = min(FLAGS.num_gpus, len(local_devices))
|
||||
|
||||
if FLAGS.job_mode == 'train':
|
||||
train_dataset = dataset_class(
|
||||
tokenizer=tokenizer,
|
||||
graph_tokenizer=graph_tokenizer,
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
subset='train',
|
||||
timesteps=FLAGS.train_timesteps,
|
||||
version=FLAGS.graph_data_version,
|
||||
shuffle_data=True,
|
||||
repeat=True,
|
||||
debug=FLAGS.debug)
|
||||
train_iter = iter(train_dataset)
|
||||
loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
|
||||
cache_steps=FLAGS.train_memory_size)
|
||||
optimizer = optax.chain(
|
||||
optax.clip_by_global_norm(FLAGS.grad_clip),
|
||||
optax.scale_by_adam(),
|
||||
optax.scale_by_schedule(functools.partial(
|
||||
utils.schedule,
|
||||
lr_schedule=FLAGS.lr_schedule,
|
||||
init_lr=FLAGS.init_lr,
|
||||
min_lr_ratio=FLAGS.min_lr_ratio,
|
||||
max_steps=FLAGS.max_steps)),
|
||||
optax.scale(-1))
|
||||
optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5)
|
||||
updater = Updater(loss_fn, optimizer,
|
||||
devices=local_devices[:num_gpus],
|
||||
has_graph=has_graph)
|
||||
updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
|
||||
_train(updater, train_iter, num_gpus)
|
||||
elif FLAGS.job_mode == 'eval':
|
||||
eval_dataset = dataset_class(
|
||||
tokenizer=tokenizer,
|
||||
graph_tokenizer=graph_tokenizer,
|
||||
batch_size=FLAGS.eval_batch_size,
|
||||
subset=FLAGS.eval_subset,
|
||||
timesteps=FLAGS.eval_timesteps,
|
||||
version=FLAGS.graph_data_version,
|
||||
shuffle_data=False,
|
||||
repeat=False,
|
||||
debug=FLAGS.debug)
|
||||
eval_iter = iter(eval_dataset)
|
||||
loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
|
||||
cache_steps=FLAGS.eval_memory_size)
|
||||
# only use one device for evaluation
|
||||
devices = local_devices[:1]
|
||||
updater = Updater(loss_fn, optimizer=None, devices=devices,
|
||||
has_graph=has_graph)
|
||||
updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
|
||||
_eval(updater, eval_iter)
|
||||
elif FLAGS.job_mode == 'sample':
|
||||
eval_dataset = dataset_class(
|
||||
tokenizer=tokenizer,
|
||||
graph_tokenizer=graph_tokenizer,
|
||||
batch_size=1,
|
||||
subset=FLAGS.eval_subset,
|
||||
timesteps=FLAGS.sample_length,
|
||||
version=FLAGS.graph_data_version,
|
||||
shuffle_data=False,
|
||||
repeat=True,
|
||||
debug=FLAGS.debug)
|
||||
eval_iter = iter(eval_dataset)
|
||||
_sample(eval_iter, tokenizer, local_devices[:num_gpus])
|
||||
elif FLAGS.job_mode == 'retrieve':
|
||||
eval_dataset = dataset_class(
|
||||
tokenizer=tokenizer,
|
||||
graph_tokenizer=graph_tokenizer,
|
||||
batch_size=1,
|
||||
subset=FLAGS.eval_subset,
|
||||
timesteps=FLAGS.eval_timesteps,
|
||||
version=FLAGS.graph_data_version,
|
||||
shuffle_data=False,
|
||||
repeat=False,
|
||||
graph_retrieval_dataset=True,
|
||||
debug=FLAGS.debug)
|
||||
eval_iter = iter(eval_dataset)
|
||||
loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
|
||||
cache_steps=FLAGS.eval_memory_size)
|
||||
# only use one device for evaluation
|
||||
devices = local_devices[:1]
|
||||
updater = Updater(loss_fn, optimizer=None, devices=devices,
|
||||
has_graph=has_graph)
|
||||
updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
|
||||
_retrieve(updater, eval_iter)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,109 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Compute the bleu score on generated text and the ground truth."""
|
||||
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
flags.DEFINE_string('checkpoint_dir', '/tmp/transformerXL',
|
||||
'Checkpoint directory to load saved samples.')
|
||||
flags.DEFINE_string('dataset', 'freebase2wikitext', 'Which dataset to the model'
|
||||
' is trained on, one of "wikitext", "freebase2wikitext".')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def group_samples(samples, tokenizer):
|
||||
"""Groups generated and ground truth texts."""
|
||||
groups = {}
|
||||
for i, row in enumerate(samples):
|
||||
gt = tokenizer.decode(row['ground_truth_text'])
|
||||
sample = tokenizer.decode(row['sample_tokens'])
|
||||
if gt not in groups:
|
||||
groups[gt] = (gt.split(), [sample.split()])
|
||||
else:
|
||||
groups[gt][-1].append(sample.split())
|
||||
if (i + 1) % 100 == 0:
|
||||
logging.info('Processed %d samples', i + 1)
|
||||
return groups
|
||||
|
||||
|
||||
def eval_samples(raw_samples, tokenizer):
|
||||
"""Evaluates generated samples."""
|
||||
gt_refs = []
|
||||
samples = []
|
||||
|
||||
groups = group_samples(raw_samples, tokenizer)
|
||||
groups = list(groups.values())
|
||||
avg_group_size = np.mean([len(g[-1]) for g in groups])
|
||||
logging.info('Average samples per example: %.2f', avg_group_size)
|
||||
avg_group_size = int(math.ceil(avg_group_size))
|
||||
for i, (gt, s) in enumerate(groups):
|
||||
gt_refs.append(gt)
|
||||
idx = i % len(groups)
|
||||
samples.append(groups[idx][-1])
|
||||
|
||||
gt_bleu, gt_n_grams = utils.compute_bleu(samples, gt_refs)
|
||||
|
||||
logging.info('Processed %d samples in total.', sum([len(s) for s in samples]))
|
||||
flat_samples = []
|
||||
for s in samples:
|
||||
flat_samples.extend(s)
|
||||
logging.info('Average sample len: %.2f',
|
||||
np.mean([len(s) for s in flat_samples]))
|
||||
logging.info('Average ground-truth len: %.2f',
|
||||
np.mean([len(gt) for gt in gt_refs]))
|
||||
|
||||
logging.info('Ground-truth BLEU: %6.2f, n-gram precision: (%s)',
|
||||
gt_bleu * 100,
|
||||
', '.join(['%6.2f%%' % (s * 100) for s in gt_n_grams]))
|
||||
|
||||
|
||||
def main(_):
|
||||
tokenizer = utils.init_tokenizer(FLAGS.dataset)
|
||||
checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, 'samples.pkl')
|
||||
logging.info('Loading samples from %s', checkpoint_dir)
|
||||
with open(checkpoint_dir, 'rb') as f:
|
||||
samples = pickle.load(f)['samples']
|
||||
eval_samples(samples, tokenizer)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,311 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Data Parallel Updater for Graph2text data."""
|
||||
|
||||
import functools
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from absl import logging
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
from jax.tree_util import tree_multimap
|
||||
import numpy as np
|
||||
import optax
|
||||
|
||||
|
||||
def call_fn_with_state_keys(jit_fn, state, other_inputs, keys):
|
||||
"""Executes `jit_fn`, filtering out all keys except some subset."""
|
||||
state = state.copy()
|
||||
extra_state = {}
|
||||
for k in list(state.keys()):
|
||||
if k not in keys:
|
||||
extra_state[k] = state.pop(k)
|
||||
return jit_fn(state, *other_inputs), extra_state
|
||||
|
||||
|
||||
class Updater:
|
||||
"""Graph2text model updater with multi-GPU support."""
|
||||
|
||||
def __init__(self, loss_fn, optimizer, devices=None, has_graph=False):
|
||||
self._net_init_fn, self._apply_fn = hk.transform_with_state(
|
||||
functools.partial(loss_fn, is_training=True))
|
||||
_, self._eval_apply_fn = hk.transform_with_state(
|
||||
functools.partial(loss_fn, is_training=False))
|
||||
|
||||
if optimizer is None:
|
||||
optimizer = optax.identity()
|
||||
self._optimizer = optimizer
|
||||
|
||||
self._num_devices = jax.local_device_count()
|
||||
if devices is None:
|
||||
devices = []
|
||||
for host_id in range(jax.process_count()):
|
||||
for device_id in jax.local_devices(host_id):
|
||||
devices.append(device_id)
|
||||
else:
|
||||
self._num_devices = min(self._num_devices, len(devices))
|
||||
|
||||
def _pmap(f, static_broadcasted_argnums=()):
|
||||
return jax.pmap(f, axis_name='i', devices=devices,
|
||||
static_broadcasted_argnums=static_broadcasted_argnums)
|
||||
|
||||
def handle_graph_size(fn):
|
||||
def _fn(*args):
|
||||
batch = args[-1].copy()
|
||||
max_graph_size = batch['max_graph_size']
|
||||
del batch['max_graph_size']
|
||||
args = args[:-1] + (batch, max_graph_size)
|
||||
return fn(*args)
|
||||
return _fn
|
||||
|
||||
# Try to jit.
|
||||
if has_graph:
|
||||
# If the model contains full graphs, we need to set the max_graph_size
|
||||
# as a statically broadcasted argument.
|
||||
self._init_fn = handle_graph_size(_pmap(self._init, 4))
|
||||
self._update_fn = handle_graph_size(_pmap(self._update, 2))
|
||||
self._eval_fn = handle_graph_size(_pmap(self._eval, 2))
|
||||
else:
|
||||
self._init_fn = _pmap(self._init)
|
||||
self._update_fn = _pmap(self._update)
|
||||
self._eval_fn = _pmap(self._eval)
|
||||
|
||||
def _init(self, master_rng, params, network_state, data, max_graph_size=None):
|
||||
"""Initializes state of the updater."""
|
||||
out_rng, init_rng = jax.random.split(master_rng)
|
||||
if max_graph_size is not None:
|
||||
new_params, new_network_state = self._net_init_fn(
|
||||
init_rng, data, max_graph_size)
|
||||
else:
|
||||
new_params, new_network_state = self._net_init_fn(init_rng, data)
|
||||
if params is None:
|
||||
params = new_params
|
||||
if network_state is None:
|
||||
network_state = new_network_state
|
||||
opt_state = self._optimizer.init(params)
|
||||
return dict(
|
||||
replicated_step=0,
|
||||
rng=out_rng,
|
||||
state=network_state,
|
||||
opt_state=opt_state,
|
||||
params=params,
|
||||
)
|
||||
|
||||
def init(self, master_rng, data, params=None, network_state=None,
|
||||
replicated_params=False):
|
||||
"""Initializes state of the updater."""
|
||||
data = self._preprocess(data)
|
||||
rngs = np.array([master_rng] * self._num_devices)
|
||||
if not replicated_params and params is not None:
|
||||
params = jax.tree_map(
|
||||
lambda x: np.array([x] * self._num_devices), params)
|
||||
state = self._init_fn(rngs, params, network_state, data)
|
||||
state['step'] = np.array(0, dtype=np.int64)
|
||||
# Wait for initialization to finish before starting training to keep
|
||||
# memory usage low.
|
||||
flat_params = jax.tree_leaves(state['params'])
|
||||
if flat_params:
|
||||
jax.tree_leaves(state['params'])[0].block_until_ready()
|
||||
return state
|
||||
|
||||
def _update(self, state, data, max_graph_size=None):
|
||||
"""Updates parameters."""
|
||||
replicated_step = state['replicated_step']
|
||||
rng = state['rng']
|
||||
opt_state = state['opt_state']
|
||||
params = state['params']
|
||||
net_state = state['state']
|
||||
|
||||
rng, new_rng = jax.random.split(rng)
|
||||
rng = jax.random.fold_in(rng, jax.lax.axis_index('i'))
|
||||
|
||||
def _loss(params, state, batch, rng):
|
||||
if max_graph_size is not None:
|
||||
(loss, metrics), state = self._apply_fn(params, state, rng, batch,
|
||||
max_graph_size)
|
||||
else:
|
||||
(loss, metrics), state = self._apply_fn(params, state, rng, batch)
|
||||
return loss, (metrics, state)
|
||||
|
||||
(loss, (metrics, new_net_state)), g = jax.value_and_grad(
|
||||
_loss, has_aux=True)(params, net_state, data, rng)
|
||||
g = jax.lax.pmean(g, axis_name='i')
|
||||
loss = jax.lax.pmean(loss, axis_name='i')
|
||||
metrics = jax.lax.pmean(metrics, axis_name='i')
|
||||
|
||||
updates, new_opt_state = self._optimizer.update(g, opt_state, params)
|
||||
new_params = optax.apply_updates(params, updates)
|
||||
|
||||
new_state = dict(
|
||||
replicated_step=replicated_step + 1,
|
||||
rng=new_rng,
|
||||
state=new_net_state,
|
||||
opt_state=new_opt_state,
|
||||
params=new_params,
|
||||
)
|
||||
|
||||
metrics['loss'] = loss
|
||||
metrics['step'] = replicated_step
|
||||
return new_state, metrics
|
||||
|
||||
def update(self, state, data):
|
||||
"""Updates the state using some data and returns metrics."""
|
||||
data = self._preprocess(data)
|
||||
(state, out), extra_state = call_fn_with_state_keys(
|
||||
self._update_fn, state, [data], keys=set([
|
||||
'state', 'params', 'rng', 'replicated_step', 'opt_state']))
|
||||
state.update(extra_state)
|
||||
state['step'] += 1
|
||||
return state, tree_multimap(lambda x: x[0], out)
|
||||
|
||||
def _eval(self, state, data, max_graph_size=None):
|
||||
"""Evaluates the current state on the given data."""
|
||||
if max_graph_size is not None:
|
||||
(loss, metrics), new_state = self._eval_apply_fn(
|
||||
state['params'], state['state'], state['rng'], data, max_graph_size)
|
||||
else:
|
||||
(loss, metrics), new_state = self._eval_apply_fn(
|
||||
state['params'], state['state'], state['rng'], data)
|
||||
state['state'] = new_state
|
||||
loss = jax.lax.pmean(loss, axis_name='i')
|
||||
metrics = jax.lax.pmean(metrics, axis_name='i')
|
||||
metrics['loss'] = loss
|
||||
metrics['step'] = state['replicated_step']
|
||||
return state, metrics
|
||||
|
||||
def eval_return_state(self, state, data):
|
||||
"""Returns metrics without updating the model."""
|
||||
data = self._preprocess(data)
|
||||
(state, out), extra_state = call_fn_with_state_keys(
|
||||
self._eval_fn, state, [data], keys=set([
|
||||
'state', 'params', 'rng', 'replicated_step']))
|
||||
state.update(extra_state)
|
||||
return state, tree_multimap(lambda x: x[0], out)
|
||||
|
||||
def eval(self, state, data):
|
||||
"""Returns metrics without updating the model."""
|
||||
_, out = self.eval_return_state(state, data)
|
||||
return out
|
||||
|
||||
def _preprocess(self, data):
|
||||
"""Reshapes input so that it can be distributed across multiple cores."""
|
||||
multi_inputs = data.copy()
|
||||
|
||||
def add_core_dimension(x):
|
||||
if np.isscalar(x):
|
||||
return x
|
||||
if x.shape[0] % self._num_devices != 0:
|
||||
raise ValueError(f'The batch size must be a multiple of the number of'
|
||||
f' devices. Got batch size = {x.shape[0]} and number'
|
||||
f' of devices = {self._num_devices}.')
|
||||
prefix = (self._num_devices, x.shape[0] // self._num_devices)
|
||||
return np.reshape(x, prefix + x.shape[1:])
|
||||
|
||||
multi_inputs = tree_multimap(add_core_dimension, multi_inputs)
|
||||
return multi_inputs
|
||||
|
||||
def params(self, state):
|
||||
"""Returns model parameters."""
|
||||
return tree_multimap(lambda x: x[0], state['params'])
|
||||
|
||||
def opt_state(self, state):
|
||||
"""Returns the state of the optimiser."""
|
||||
return tree_multimap(lambda x: x[0], state['opt_state'])
|
||||
|
||||
def network_state(self, state):
|
||||
"""Returns the model's state."""
|
||||
return tree_multimap(lambda x: x[0], state['state'])
|
||||
|
||||
def to_checkpoint_state(self, state):
|
||||
"""Transforms the updater state into a checkpointable state."""
|
||||
checkpoint_state = state.copy()
|
||||
# Wrapper around checkpoint_state['step'] so we can get [0].
|
||||
checkpoint_state['step'] = checkpoint_state['step'][np.newaxis]
|
||||
# Unstack the replicated contents.
|
||||
checkpoint_state = tree_multimap(lambda x: x[0], checkpoint_state)
|
||||
return checkpoint_state
|
||||
|
||||
def from_checkpoint_state(self, checkpoint_state):
|
||||
"""Initializes the updater state from the checkpointed state."""
|
||||
# Expand the checkpoint so we have a copy for each device.
|
||||
state = tree_multimap(lambda x: np.stack(jax.local_device_count() * [x]),
|
||||
checkpoint_state)
|
||||
state['step'] = state['step'][0] # Undo stacking for step.
|
||||
return state
|
||||
|
||||
|
||||
class CheckpointingUpdater:
|
||||
"""A checkpointing wrapper around an Updater."""
|
||||
|
||||
def __init__(self,
|
||||
inner: Updater,
|
||||
checkpoint_dir: str):
|
||||
self._inner = inner
|
||||
self._checkpoint_dir = checkpoint_dir
|
||||
|
||||
def _checkpoint_paths(self):
|
||||
return [p for p in os.listdir(self._checkpoint_dir) if 'checkpoint' in p]
|
||||
|
||||
def init(self, rng, data, params=None, network_state=None):
|
||||
"""Initialize experiment state."""
|
||||
if not os.path.exists(self._checkpoint_dir) or not self._checkpoint_paths():
|
||||
os.makedirs(self._checkpoint_dir, exist_ok=True)
|
||||
return self._inner.init(rng, data, params, network_state)
|
||||
return self.load_checkpoint()
|
||||
|
||||
def init_from_checkpoint(self, rng, data, checkpoint_state):
|
||||
params = self._inner.params(checkpoint_state)
|
||||
network_state = None
|
||||
return self._inner.init(rng, data, params, network_state)
|
||||
|
||||
def eval_return_state(self, state, data):
|
||||
return self._inner.eval_return_state(state, data)
|
||||
|
||||
def save_checkpoint(self, state):
|
||||
path = os.path.join(self._checkpoint_dir, 'checkpoint.pkl')
|
||||
logging.info('Serializing experiment state to %s', path)
|
||||
checkpoint_state = self._inner.to_checkpoint_state(jax.device_get(state))
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(checkpoint_state, f)
|
||||
|
||||
def load_checkpoint(self):
|
||||
checkpoint = os.path.join(self._checkpoint_dir,
|
||||
self._checkpoint_paths()[-1])
|
||||
logging.info('Loading checkpoint from %s', checkpoint)
|
||||
with open(checkpoint, 'rb') as f:
|
||||
state = pickle.load(f)
|
||||
return self._inner.from_checkpoint_state(state)
|
||||
|
||||
def update(self, state, data):
|
||||
"""Update experiment state."""
|
||||
state, out = self._inner.update(state, data)
|
||||
return state, out
|
||||
File diff suppressed because it is too large
Load Diff
@@ -30,5 +30,6 @@
|
||||
"""WikiGraphs model modules."""
|
||||
from . import embedding
|
||||
from . import graph_net
|
||||
from . import sampler
|
||||
from . import transformer
|
||||
from . import transformer_block
|
||||
|
||||
@@ -41,7 +41,7 @@ import numpy as np
|
||||
ArrayType = Union[np.ndarray, jnp.ndarray]
|
||||
|
||||
|
||||
def pad_size(in_size):
|
||||
def pad_size(in_size: int):
|
||||
out_size = 1
|
||||
while out_size < in_size:
|
||||
out_size *= 2
|
||||
|
||||
@@ -0,0 +1,361 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Samplers for the graph2text transformers."""
|
||||
|
||||
import abc
|
||||
from typing import Any, Optional, Mapping
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jraph
|
||||
import numpy as np
|
||||
|
||||
from wikigraphs.model import graph_net as gn
|
||||
|
||||
|
||||
class BaseSampler:
|
||||
"""Base class for transformer samplers."""
|
||||
|
||||
def __init__(self,
|
||||
model_fn,
|
||||
temperature: float = 1.0,
|
||||
device: Optional[Any] = None,
|
||||
rng: Optional[np.ndarray] = None):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
model_fn: a transformer language model defined in model.transformer.
|
||||
temperature: sampling temperature.
|
||||
device: the sampler will run on this device if provided.
|
||||
rng: random number generator.
|
||||
"""
|
||||
self._temperature = temperature
|
||||
self._device = device or jax.local_devices()[0]
|
||||
init_fn, apply_fn = hk.transform_with_state(model_fn)
|
||||
|
||||
if rng is None:
|
||||
rng = jax.random.PRNGKey(np.random.randint(2**32))
|
||||
rng = jax.random.fold_in(rng, jax.host_id())
|
||||
self._rng = rng
|
||||
self._init_state = None
|
||||
self._jit_model(init_fn, apply_fn)
|
||||
|
||||
def _jit_model(self, init_fn, apply_fn):
|
||||
"""Jit the `init_fn` and `apply_fn`."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _sample(self,
|
||||
params: Mapping[str, Any],
|
||||
state: Mapping[str, Any],
|
||||
rng: jnp.ndarray,
|
||||
x: jnp.ndarray,
|
||||
**kwargs) -> np.ndarray:
|
||||
"""Generate samples.
|
||||
|
||||
Args:
|
||||
params: parameters of the transformer.
|
||||
state: state of the transformer.
|
||||
rng: random number generator.
|
||||
x: a prompt of shape [batch_size, sample_len], in which an entry of -1
|
||||
indicates it will be generate at that place. Otherwise it acts as the
|
||||
prompt.
|
||||
**kwargs: additional inputs.
|
||||
|
||||
Returns:
|
||||
output: [batch_size, sample_len] tensor, the generated sequence.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample(self,
|
||||
params: Mapping[str, Any],
|
||||
x: jnp.ndarray,
|
||||
**kwargs) -> jnp.ndarray:
|
||||
"""Generate samples based on the given parameters and prompts.
|
||||
|
||||
Args:
|
||||
params: parameters of the transformer.
|
||||
x: a prompt of shape [batch_size, sample_len], in which an entry of -1
|
||||
indicates it will be generate at that place. Otherwise it acts as the
|
||||
prompt.
|
||||
**kwargs: additional inputs.
|
||||
|
||||
Returns:
|
||||
output: the generated sequence.
|
||||
"""
|
||||
|
||||
|
||||
class TransformerXLSampler(BaseSampler):
|
||||
"""Sampling from the TransformerXL model."""
|
||||
|
||||
def _jit_model(self, init_fn, apply_fn):
|
||||
"""Jit `init_fn` and `apply_fn`, the latter is used in `self._sample`."""
|
||||
self._init_fn = jax.jit(init_fn, device=self._device)
|
||||
self._apply_fn = apply_fn
|
||||
self._sample_fn = jax.jit(self._sample, device=self._device)
|
||||
|
||||
def _sample(self,
|
||||
params: Mapping[str, Any],
|
||||
state: Mapping[str, Any],
|
||||
rng: jnp.ndarray,
|
||||
x: jnp.ndarray) -> np.ndarray:
|
||||
"""Generate unconditional samples.
|
||||
|
||||
Args:
|
||||
params: parameters of the transformer.
|
||||
state: state of the transformer.
|
||||
rng: random number generator.
|
||||
x: a prompt of shape [batch_size, sample_len], in which an entry of -1
|
||||
indicates it will be generate at that place. Otherwise it acts as the
|
||||
prompt.
|
||||
|
||||
Returns:
|
||||
output: [batch_size, sample_len] tensor, the generated sequence.
|
||||
"""
|
||||
batch_size, sample_len = x.shape
|
||||
|
||||
def one_step(params, state, rng, i, x):
|
||||
step_sample = jax.lax.dynamic_slice(x, [0, i], [batch_size, 1])
|
||||
rng, rng_ = jax.random.split(rng)
|
||||
# step_sample shape is [batch_size, 1].
|
||||
logits, state = self._apply_fn(params, state, rng_, step_sample)
|
||||
rng, rng_ = jax.random.split(rng)
|
||||
step_sample = jax.random.categorical(rng_, logits / self._temperature)
|
||||
update = jnp.where(x[:, i + 1] < 0, step_sample[:, 0], x[:, i + 1])[:,
|
||||
None]
|
||||
x = jax.lax.dynamic_update_slice(x, update, [0, i + 1])
|
||||
return state, rng, x
|
||||
|
||||
def loop_body(i, data):
|
||||
state, rng, x = data
|
||||
return one_step(params, state, rng, i, x)
|
||||
|
||||
_, _, x = jax.lax.fori_loop(0, sample_len - 1, loop_body,
|
||||
(state, rng, x))
|
||||
|
||||
return x
|
||||
|
||||
def sample(self,
|
||||
params: Mapping[str, Any],
|
||||
x: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Generate samples based on the given graphs and parameters.
|
||||
|
||||
Args:
|
||||
params: parameters of the transformer.
|
||||
x: a prompt of shape [batch_size, sample_len], in which an entry of -1
|
||||
indicates it will be generate at that place. Otherwise it acts as the
|
||||
prompt.
|
||||
|
||||
Returns:
|
||||
output: the generated sequence.
|
||||
"""
|
||||
if self._init_state is None:
|
||||
self._rng, rng = jax.random.split(self._rng)
|
||||
self._init_params, self._init_state = self._init_fn(rng, x[:, :1])
|
||||
if params is None:
|
||||
params = self._init_params
|
||||
|
||||
self._rng, rng = jax.random.split(self._rng)
|
||||
sample = self._sample_fn(params, self._init_state, rng, x)
|
||||
return sample
|
||||
|
||||
|
||||
class Bow2TextTransformerSampler(BaseSampler):
|
||||
"""Sampling from the TransformerXL model."""
|
||||
|
||||
def _jit_model(self, init_fn, apply_fn):
|
||||
"""Jit `init_fn` and `apply_fn`, the latter is used in `self._sample`."""
|
||||
self._init_fn = jax.jit(init_fn, device=self._device)
|
||||
self._apply_fn = apply_fn
|
||||
self._sample_fn = jax.jit(self._sample, device=self._device)
|
||||
|
||||
def _sample(self,
|
||||
params: Mapping[str, Any],
|
||||
state: Mapping[str, Any],
|
||||
rng: jnp.ndarray,
|
||||
bow: jnp.ndarray,
|
||||
x: jnp.ndarray) -> np.ndarray:
|
||||
"""Generate samples conditioned on the bag-of-words of the graph.
|
||||
|
||||
Args:
|
||||
params: parameters of the transformer.
|
||||
state: state of the transformer.
|
||||
rng: random number generator.
|
||||
bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector.
|
||||
x: a prompt of shape [batch_size, sample_len], in which an entry of -1
|
||||
indicates it will be generate at that place. Otherwise it acts as the
|
||||
prompt.
|
||||
|
||||
Returns:
|
||||
output: [batch_size, sample_len] tensor, the generated sequence.
|
||||
"""
|
||||
batch_size, sample_len = x.shape
|
||||
|
||||
def one_step(params, state, rng, i, x):
|
||||
step_sample = jax.lax.dynamic_slice(x, [0, i], [batch_size, 1])
|
||||
rng, rng_ = jax.random.split(rng)
|
||||
# step_sample shape is [batch_size, 1].
|
||||
logits, state = self._apply_fn(params, state, rng_, bow, step_sample)
|
||||
rng, rng_ = jax.random.split(rng)
|
||||
step_sample = jax.random.categorical(rng_, logits / self._temperature)
|
||||
update = jnp.where(x[:, i + 1] < 0, step_sample[:, 0], x[:, i + 1])[:,
|
||||
None]
|
||||
x = jax.lax.dynamic_update_slice(x, update, [0, i + 1])
|
||||
return state, rng, x
|
||||
|
||||
def loop_body(i, data):
|
||||
state, rng, x = data
|
||||
return one_step(params, state, rng, i, x)
|
||||
|
||||
_, _, x = jax.lax.fori_loop(0, sample_len - 1, loop_body,
|
||||
(state, rng, x))
|
||||
|
||||
return x
|
||||
|
||||
def sample(self,
|
||||
params: Mapping[str, Any],
|
||||
x: jnp.ndarray,
|
||||
bow: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Generate samples based on the given graphs and parameters.
|
||||
|
||||
Args:
|
||||
params: parameters of the transformer.
|
||||
x: a prompt of shape [batch_size, sample_len], in which an entry of -1
|
||||
indicates it will be generate at that place. Otherwise it acts as the
|
||||
prompt.
|
||||
bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector.
|
||||
|
||||
Returns:
|
||||
output: the generated sequence.
|
||||
"""
|
||||
if self._init_state is None:
|
||||
self._rng, rng = jax.random.split(self._rng)
|
||||
self._init_params, self._init_state = self._init_fn(rng, bow, x[:, :1])
|
||||
if params is None:
|
||||
params = self._init_params
|
||||
|
||||
self._rng, rng = jax.random.split(self._rng)
|
||||
sample = self._sample_fn(params, self._init_state, rng, bow, x)
|
||||
return sample
|
||||
|
||||
|
||||
class Graph2TextTransformerSampler(BaseSampler):
|
||||
"""Sampling from the Graph2Text TransformerXL model."""
|
||||
|
||||
def _jit_model(self, init_fn, apply_fn):
|
||||
"""Jit `init_fn` and `apply_fn`, the latter is used in `self._sample`."""
|
||||
# `pad_n_nodes` is set as a static argument.
|
||||
self._init_fn = jax.jit(init_fn, device=self._device, static_argnums=2)
|
||||
self._apply_fn = apply_fn
|
||||
self._sample_fn = jax.jit(self._sample, device=self._device,
|
||||
static_argnums=4)
|
||||
|
||||
def _sample(self,
|
||||
params: Mapping[str, Any],
|
||||
state: Mapping[str, Any],
|
||||
rng: jnp.ndarray,
|
||||
graphs: jraph.GraphsTuple,
|
||||
pad_n_nodes: int,
|
||||
x: jnp.ndarray) -> np.ndarray:
|
||||
"""Generate samples conditioned on the bag-of-words reprensation of graph.
|
||||
|
||||
Args:
|
||||
params: parameters of the transformer.
|
||||
state: state of the transformer.
|
||||
rng: random number generator.
|
||||
graphs: a graph structured using graph_net.Graph.
|
||||
pad_n_nodes: size for each node to pad to.
|
||||
x: a prompt of shape [batch_size, sample_len], in which an entry of -1
|
||||
indicates it will be generate at that place. Otherwise it acts as the
|
||||
prompt.
|
||||
|
||||
Returns:
|
||||
output: [batch_size, sample_len] tensor, the generated sequence.
|
||||
"""
|
||||
batch_size, sample_len = x.shape
|
||||
|
||||
def one_step(params, state, rng, i, x):
|
||||
step_sample = jax.lax.dynamic_slice(x, [0, i], [batch_size, 1])
|
||||
rng, rng_ = jax.random.split(rng)
|
||||
# step_sample shape is [batch_size, 1].
|
||||
logits, state = self._apply_fn(
|
||||
params, state, rng_, graphs, pad_n_nodes, step_sample)
|
||||
rng, rng_ = jax.random.split(rng)
|
||||
step_sample = jax.random.categorical(rng_, logits / self._temperature)
|
||||
update = jnp.where(x[:, i + 1] < 0, step_sample[:, 0], x[:, i + 1])[:,
|
||||
None]
|
||||
x = jax.lax.dynamic_update_slice(x, update, [0, i + 1])
|
||||
return state, rng, x
|
||||
|
||||
def loop_body(i, data):
|
||||
state, rng, x = data
|
||||
return one_step(params, state, rng, i, x)
|
||||
|
||||
_, _, x = jax.lax.fori_loop(0, sample_len - 1, loop_body,
|
||||
(state, rng, x))
|
||||
|
||||
return x
|
||||
|
||||
def sample(self,
|
||||
params: Mapping[str, Any],
|
||||
x: jnp.ndarray,
|
||||
graphs: jraph.GraphsTuple,
|
||||
pad: bool = True) -> jnp.ndarray:
|
||||
"""Generate samples based on the given graphs and parameters.
|
||||
|
||||
Args:
|
||||
params: parameters of the transformer.
|
||||
x: a prompt of shape [batch_size, sample_len], in which an entry of -1
|
||||
indicates it will be generate at that place. Otherwise it acts as the
|
||||
prompt.
|
||||
graphs: a graph structured using graph_net.Graph.
|
||||
pad: whether to pad the graph nodes and edges or not.
|
||||
|
||||
Returns:
|
||||
output: the generated sequence.
|
||||
"""
|
||||
if pad:
|
||||
graphs = gn.pad_graphs(graphs)
|
||||
max_graph_size = gn.pad_size(graphs.n_node.max())
|
||||
else:
|
||||
max_graph_size = graphs.n_node.max()
|
||||
|
||||
if self._init_state is None:
|
||||
self._rng, rng = jax.random.split(self._rng)
|
||||
self._init_params, self._init_state = self._init_fn(
|
||||
rng, graphs, max_graph_size, x[:, :1])
|
||||
if params is None:
|
||||
params = self._init_params
|
||||
|
||||
self._rng, rng = jax.random.split(self._rng)
|
||||
sample = self._sample_fn(
|
||||
params, self._init_state, rng, graphs, max_graph_size, x)
|
||||
return sample
|
||||
@@ -0,0 +1,140 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# WikiGraphs is licensed under the terms of the Creative Commons
|
||||
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
|
||||
#
|
||||
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
|
||||
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
|
||||
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
|
||||
#
|
||||
# Freebase data is licensed by Google LLC under the terms of the Creative
|
||||
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
|
||||
#
|
||||
# https://creativecommons.org/licenses/by/4.0/legalcode
|
||||
#
|
||||
# ==============================================================================
|
||||
"""Tests for wikigraphs.model.sampler."""
|
||||
|
||||
from absl.testing import absltest
|
||||
import jraph
|
||||
import numpy as np
|
||||
|
||||
from wikigraphs.model import sampler
|
||||
from wikigraphs.model import transformer as models
|
||||
|
||||
|
||||
class SamplerTest(absltest.TestCase):
|
||||
|
||||
def test_uncond_sampler_runs(self):
|
||||
prompt = np.array([[0, 1, 2, -1, -1],
|
||||
[0, 1, 2, -1, -1]], dtype=np.int32)
|
||||
vocab_size = prompt.max() + 1
|
||||
bos_token = 0
|
||||
memory_size = 2
|
||||
params = None
|
||||
|
||||
def model_fn(x):
|
||||
return models.TransformerXL(
|
||||
vocab_size=vocab_size,
|
||||
emb_dim=8,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
cutoffs=[])(x, is_training=False, cache_steps=memory_size)
|
||||
|
||||
uncond_sampler = sampler.TransformerXLSampler(model_fn)
|
||||
sample = uncond_sampler.sample(params, prompt)
|
||||
self.assertTrue((sample[:, 0] == bos_token).all())
|
||||
self.assertTrue((sample != -1).all())
|
||||
self.assertEqual(sample.shape, prompt.shape)
|
||||
sample2 = uncond_sampler.sample(params, prompt)
|
||||
self.assertTrue((sample2[:, 0] == bos_token).all())
|
||||
self.assertTrue((sample2 != -1).all())
|
||||
self.assertEqual(sample2.shape, prompt.shape)
|
||||
self.assertTrue((sample != sample2).any())
|
||||
|
||||
def test_bow2text_sampler_runs(self):
|
||||
bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1],
|
||||
[0, 1, 0, 0, 1, 0, 1, 0]], dtype=np.int32)
|
||||
prompt = np.array([[0, 1, 2, -1, -1, -1],
|
||||
[0, 1, 2, -1, -1, -1]], dtype=np.int32)
|
||||
vocab_size = prompt.max() + 1
|
||||
bos_token = 0
|
||||
memory_size = 2
|
||||
params = None
|
||||
|
||||
def model_fn(bow, x):
|
||||
return models.Bow2TextTransformer(
|
||||
vocab_size=vocab_size,
|
||||
emb_dim=16,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
cutoffs=[])(bow, x, is_training=False, cache_steps=memory_size)
|
||||
|
||||
bow_sampler = sampler.Bow2TextTransformerSampler(model_fn)
|
||||
sample = bow_sampler.sample(params, prompt, bow)
|
||||
self.assertTrue((sample[:, 0] == bos_token).all())
|
||||
self.assertTrue((sample != -1).all())
|
||||
self.assertEqual(sample.shape, prompt.shape)
|
||||
sample2 = bow_sampler.sample(params, prompt, bow)
|
||||
self.assertTrue((sample2[:, 0] == bos_token).all())
|
||||
self.assertTrue((sample2 != -1).all())
|
||||
self.assertEqual(sample2.shape, prompt.shape)
|
||||
self.assertTrue((sample != sample2).any())
|
||||
|
||||
def test_graph2text_sampler_runs(self):
|
||||
graphs = jraph.GraphsTuple(
|
||||
nodes=np.ones((4, 3), dtype=np.float32),
|
||||
edges=np.ones((3, 1), dtype=np.float32),
|
||||
senders=np.array([0, 2, 3], dtype=np.int32),
|
||||
receivers=np.array([1, 3, 2], dtype=np.int32),
|
||||
n_node=np.array([2, 2], dtype=np.int32),
|
||||
n_edge=np.array([1, 2], dtype=np.int32),
|
||||
globals=None,
|
||||
)
|
||||
prompt = np.array([[0, 1, 2, -1, -1, -1],
|
||||
[0, 1, 2, -1, -1, -1]], dtype=np.int32)
|
||||
vocab_size = prompt.max() + 1
|
||||
bos_token = 0
|
||||
memory_size = 2
|
||||
params = None
|
||||
|
||||
def model_fn(graphs, max_graph_size, x):
|
||||
return models.Graph2TextTransformer(
|
||||
vocab_size=vocab_size,
|
||||
emb_dim=8,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
cutoffs=[],
|
||||
gnn_embed_dim=8,
|
||||
gnn_num_layers=2)(
|
||||
graphs, max_graph_size, True, x,
|
||||
is_training=False, cache_steps=memory_size)
|
||||
|
||||
graph_sampler = sampler.Graph2TextTransformerSampler(model_fn)
|
||||
sample = graph_sampler.sample(params, prompt, graphs)
|
||||
self.assertTrue((sample[:, 0] == bos_token).all())
|
||||
self.assertTrue((sample != -1).all())
|
||||
self.assertEqual(sample.shape, prompt.shape)
|
||||
sample2 = graph_sampler.sample(params, prompt, graphs)
|
||||
self.assertTrue((sample2[:, 0] == bos_token).all())
|
||||
self.assertTrue((sample2 != -1).all())
|
||||
self.assertEqual(sample2.shape, prompt.shape)
|
||||
self.assertTrue((sample != sample2).any())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
Reference in New Issue
Block a user