diff --git a/wikigraphs/README.md b/wikigraphs/README.md index 7e63a57..4cdbde8 100644 --- a/wikigraphs/README.md +++ b/wikigraphs/README.md @@ -167,7 +167,74 @@ it elsewhere. ## Run baseline models -Note: baseline models will be available soon. +Note: our code supports training with multiple GPUs. + +To run the default baseline GNN-based TransformerXL on Wikigraphs with 8 +GPUs: + +```base +python main.py --model_type=graph2text \ + --dataset=freebase2wikitext \ + --checkpoint_dir=/tmp/graph2text \ + --job_mode=train \ + --train_batch_size=64 \ + --gnn_num_layers=1 \ + --num_gpus=8 +``` + +We ran our experiments in the paper using 8 Nvidia V100 GPUs. To allow for +batch parallization for the GNN-based (graph2text) model, we pad graphs to +the largest graph in the batch. The full run takes almost 4 days. BoW- and +nodes-based models can be trained within 14 hours because there is no +additional padding. + +Or to quickly test-run a small model: + +```base +python main.py --model_type=graph2text \ + --dataset=freebase2wikitext \ + --checkpoint_dir=/tmp/graph2text \ + --job_mode=train \ + --train_batch_size=2 \ + --gnn_num_layers=1 +``` + +To evaluate the model on the validation set (this only uses 1 GPU): + +```base +python main.py --model_type=graph2text \ + --dataset=freebase2wikitext \ + --checkpoint_dir=/tmp/graph2text \ + --job_mode=eval \ + --eval_subset=valid +``` + +To generate 960 samples from the model using the graphs in the validation set (using 8 GPUs): + +```base +python main.py --model_type=graph2text \ + --dataset=freebase2wikitext \ + --checkpoint_dir=/tmp/graph2text \ + --job_mode=sample \ + --eval_subset=valid \ + --num_gpus=8 \ + --num_samples=960 +``` + +To compute the rBLEU score of the generated samples: + +```base +python scripts/compute_bleu_score.py --dataset=freebase2wikitext \ + --checkpoint_dir=/tmp/graph2text +``` + +To compute the retrieval scores: + +```base +python main.py --dataset=freebase2wikitext \ + --job_mode=retrieve \ + --checkpoint_dir=/tmp/graph2text +``` ## Citing WikiGraphs diff --git a/wikigraphs/main.py b/wikigraphs/main.py new file mode 100644 index 0000000..6954172 --- /dev/null +++ b/wikigraphs/main.py @@ -0,0 +1,416 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Train a transformer for language modeling on Wikitext-103.""" + +import concurrent +import functools +import os +import pickle +import time + +from absl import app +from absl import flags +from absl import logging + +import jax +import jraph +import numpy as np +import optax + +from updaters import CheckpointingUpdater +from updaters import Updater +import utils + + +# Train +flags.DEFINE_integer('train_batch_size', 4, '(Per-Device) batch size for' + ' training.') +flags.DEFINE_integer('train_timesteps', 150, 'Sequence length to learn on') +flags.DEFINE_integer('train_memory_size', 150, 'Memory size for transformer XL') +flags.DEFINE_bool('debug', False, 'Whether to turn on debugging mode') +flags.DEFINE_string('job_mode', 'train', + 'One of `train`, `eval`, `sample`, `retrieve`.') +flags.DEFINE_integer('random_seed', 42, 'Random seed id.') +flags.DEFINE_integer('num_gpus', 8, 'Number of GPUs for training.') + +# Eval +flags.DEFINE_integer('eval_batch_size', 1, 'Evaluation batch size') +flags.DEFINE_string('eval_subset', 'valid', 'Which subset to evaluate on,' + ' one of `valid`, `test`.') +flags.DEFINE_integer('eval_every', 10, 'Evaluation frequency.') +flags.DEFINE_integer('eval_timesteps', 64, 'Sequence length to learn on') +flags.DEFINE_integer('eval_memory_size', 640, 'Memory size for transformer XL') +flags.DEFINE_integer('max_eval_samples', -1, 'Max number of eval samples. Set' + ' as -1 to use the entire eval set.') + +# Model +flags.DEFINE_integer('emb_dim', 410, 'model width') +flags.DEFINE_integer('num_heads', 10, 'Number of attention heads') +flags.DEFINE_integer('num_layers', 16, 'Number of transformer layers') +flags.DEFINE_integer('dense_dim', 2100, 'Size of dense hidden layer.') +flags.DEFINE_integer('tail_shrink_factor', 4, + 'Low-frequency vocabulary shrinkage factor in adaptive' + ' softmax.') +flags.DEFINE_string('emb_type', 'adaptive_softmax', 'Type of the word embedding' + ' layer.') +flags.DEFINE_integer('clamp_len', 400, 'Clamp length for transformer XL.') +flags.DEFINE_float('dropout', 0.1, 'Dropout rate for the transformer layers.') +flags.DEFINE_float('dropout_attn', 0.0, 'Dropout rate for the attention' + ' weights.') +flags.DEFINE_float('self_att_init_scale', 0.02, + 'Self attention module initilization scale.') +flags.DEFINE_float('dense_init_scale', 0.02, + 'Dense module initilization scale.') + +# Graph neural net configs +flags.DEFINE_string('gnn_embed_type', 'adaptive', 'Token embedding type for the' + ' graph.') +flags.DEFINE_integer('gnn_embed_dim', 128, 'Graph node embedding size.') +flags.DEFINE_integer('gnn_num_layers', 1, 'Number of layers in the GNN.') +flags.DEFINE_bool('gnn_layer_norm', True, 'Whether to use layer norm in GNN.') + +# Bag-of-words to text configs +flags.DEFINE_integer('bow_embedding_dim', 256, 'Size of the bow embeddings.') +flags.DEFINE_integer('bow_n_tokens', 1, 'Number of tokens to use for the' + ' bow2text model.') + +# Sampling +flags.DEFINE_float('sampling_temperature', 0.8, 'Temperature used for' + ' sampling. Sampling becomes more deterministic with a' + ' lower temperature. Setting temperature to 1.0 samples' + ' from the model distribution.') +flags.DEFINE_bool('prompt_title', False, 'Whether to prompt title when sample') +flags.DEFINE_integer('sample_length', 512, 'Length of samples.') +flags.DEFINE_integer('sample_memory_size', 640, 'Memory size for sampling.') +flags.DEFINE_integer('num_samples', 1000, 'Maximum number of samples to' + ' generate.') + +# Optimization +flags.DEFINE_float('init_lr', 0.00025, 'Initial learning rate.') +flags.DEFINE_float('min_lr_ratio', 0.0, 'Minimum learning rate as a ratio of' + ' `init_lr`.') +flags.DEFINE_string('lr_schedule', 'cosine', 'One of `default`, `cosine`.') +flags.DEFINE_float('grad_clip', 0.25, 'Maximum gradient norm allowed for' + ' clipping, set to a very large number to disable clipping.') +flags.DEFINE_integer('max_steps', 200_000, 'Number of training steps.') +flags.DEFINE_string('checkpoint_dir', '/tmp/graph2text', + 'Directory to store checkpoints.') + +# Data +flags.DEFINE_string('dataset', 'freebase2wikitext', 'Which dataset to train on,' + ' one of "wikitext", "freebase2wikitext".') +flags.DEFINE_string('model_type', 'graph2text', 'One of "text", "graph2text",' + ' "bow2text".') +flags.DEFINE_string('graph_data_version', 'max256', 'One of "max256", "max512",' + ' "max1024".') + +flags.DEFINE_integer('log_every', 50, 'Log every this many steps.') +flags.DEFINE_integer('ckpt_every', 1000, 'Checkpoint every this many steps.') + +FLAGS = flags.FLAGS + + +def _preprocess(batch, num_devices=1): + return utils.preprocess(batch, FLAGS.model_type, num_devices) + + +def _train(updater, train_dataset, num_devices): + """Train the transformer model.""" + # Initialize parameters. + logging.info('Initializing parameters...') + rng = jax.random.PRNGKey(FLAGS.random_seed) + state = updater.init( + rng, _preprocess(train_dataset.return_faux_batch(), num_devices)) + + logging.info('Starting train loop...') + prev_time = time.time() + while True: + data = next(train_dataset) + state, metrics = updater.update(state, _preprocess(data, num_devices)) + # We use JAX runahead to mask data preprocessing and JAX dispatch overheads. + # Using values from state/metrics too often will block the runahead and can + # cause these overheads to become more prominent. + step = np.array(metrics['step']) + if step % FLAGS.log_every == 0: + steps_per_sec = FLAGS.log_every / (time.time() - prev_time) + prev_time = time.time() + metrics.update({'steps_per_sec': steps_per_sec}) + logging.info({k: float(v) for k, v in metrics.items()}) + if step % FLAGS.ckpt_every == 0: + updater.save_checkpoint(state) + if step > FLAGS.max_steps: + break + + +def _eval(updater, eval_dataset): + """Evaluate the transformer model.""" + checkpoint_state = updater.load_checkpoint() + rng = jax.random.PRNGKey(FLAGS.random_seed) + state = updater.init_from_checkpoint( + rng, _preprocess(eval_dataset.return_faux_batch()), checkpoint_state) + eval_out, state = utils.evaluate( + eval_dataset, state, updater, FLAGS.eval_batch_size, _preprocess, + FLAGS.max_eval_samples, print_progress_every=20) + logging.info('Eval output: %s', eval_out) + + +def _retrieve(updater, eval_dataset): + """Graph and text retrieval using the transformer model.""" + checkpoint_state = updater.load_checkpoint() + rng = jax.random.PRNGKey(FLAGS.random_seed) + state = updater.init_from_checkpoint( + rng, _preprocess(eval_dataset.return_faux_batch()), checkpoint_state) + retrieval_out, _ = utils.compute_text_graph_relevance( + eval_dataset, state, updater, preprocess_fn=_preprocess, + print_progress_every=20) + logging.info('Retrieval output: %s', retrieval_out) + + +def _sample(eval_dataset, tokenizer, devices, batch_size=1): + """Evaluate the graph2text transformer.""" + checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, 'checkpoint.pkl') + logging.info('Loading checkpoint from %s', checkpoint_dir) + with open(checkpoint_dir, 'rb') as f: + state = pickle.load(f) + + if FLAGS.model_type == 'graph2text': + # process list of graphs into a batch + eval_dataset = map(lambda x: dict( # pylint: disable=g-long-lambda + obs=x['obs'], + target=x['target'], + should_reset=x['should_reset'], + mask=x['mask'], + graphs=jraph.batch(x['graphs']), + ), eval_dataset) + eval_dataset = utils.take_unique_graphs(eval_dataset, FLAGS.model_type) + + samplers = [] + for device in devices: + sampler = utils.build_sampler(tokenizer, device=device) + samplers.append(sampler) + + step = state['step'] + params = state['params'] + sample_logger = [] + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(samplers)) as executor: + futures = dict() + for sampler in samplers: + batch = next(eval_dataset) + prompts = utils.construct_prompts( + batch['obs'], batch_size, FLAGS.sample_length, tokenizer, + prompt_title=FLAGS.prompt_title) + if FLAGS.model_type in ['graph2text', 'bow2text']: + future = executor.submit( + utils.generate_samples, params, tokenizer, sampler, + model_type=FLAGS.model_type, prompts=prompts, + graphs=batch['graphs']) + futures[future] = (sampler, batch['graphs'], batch['obs']) + else: + future = executor.submit( + utils.generate_samples, params, tokenizer, sampler, + model_type=FLAGS.model_type, prompts=prompts, graphs=None) + futures[future] = (sampler, batch['obs']) + + n_samples = 0 + + while n_samples < FLAGS.num_samples: + for future, future_items in list(futures.items()): + if not future.done(): + continue + samples, tokens = future.result() + if FLAGS.model_type == 'graph2text': + sampler, graphs, text = future_items + graphs = jraph.unbatch(graphs) + elif FLAGS.model_type == 'bow2text': + sampler, graphs, text = future_items + else: + sampler, text = future_items + + if FLAGS.model_type in ['graph2text', 'bow2text']: + for s, g, tk, txt in zip(samples, graphs, tokens, text): + # Only log a small fraction of the generated samples, if we are + # generating non-stop. Otherwise log every sample. + logging.info('[step %d]', step) + logging.info('graph=\n%r', g) + logging.info('sample=\n%s', s) + if FLAGS.model_type == 'graph2text': + sample_logger.append({ + 'step': step, + 'sample': s, + 'sample_tokens': tk, + 'ground_truth_text': txt, + }) + elif FLAGS.model_type == 'bow2text': + sample_logger.append({ + 'step': step, + 'bow': g, + 'sample': s, + 'sample_tokens': tk, + 'ground_truth_text': txt, + }) + else: + for s, tk, txt in zip(samples, tokens, text): + # Only log a small fraction of the generated samples, if we are + # generating non-stop. Otherwise log every sample. + logging.info('[step %d]', step) + logging.info('sample=\n%s', s) + sample_logger.append({ + 'step': step, + 'sample': s, + 'sample_tokens': tk, + 'ground_truth_text': txt, + }) + n_samples += len(samples) + logging.info('Finished generating %d samples', n_samples) + + del futures[future] + + if n_samples < FLAGS.num_samples: + batch = next(eval_dataset) + prompts = utils.construct_prompts( + batch['obs'], batch_size, FLAGS.sample_length, tokenizer, + prompt_title=FLAGS.prompt_title) + if FLAGS.model_type in ['graph2text', 'bow2text']: + future = executor.submit( + utils.generate_samples, params, tokenizer, sampler, + model_type=FLAGS.model_type, prompts=prompts, + graphs=batch['graphs']) + futures[future] = (sampler, batch['graphs'], batch['obs']) + else: + future = executor.submit( + utils.generate_samples, params, tokenizer, sampler, + model_type=FLAGS.model_type, prompts=prompts, graphs=None) + futures[future] = (sampler, batch['obs']) + + logging.info('Finished') + path = os.path.join(FLAGS.checkpoint_dir, 'samples.pkl') + with open(path, 'wb') as f: + pickle.dump(dict(samples=sample_logger), f) + logging.info('Samples saved to %s', path) + + +def main(_): + # Create the dataset. + tokenizer = utils.init_tokenizer(FLAGS.dataset) + graph_tokenizer = utils.init_graph_tokenizer() + dataset_class = utils.get_dataset_class(FLAGS.dataset, FLAGS.model_type) + has_graph = True if FLAGS.model_type == 'graph2text' else False + local_devices = jax.local_devices() + num_gpus = min(FLAGS.num_gpus, len(local_devices)) + + if FLAGS.job_mode == 'train': + train_dataset = dataset_class( + tokenizer=tokenizer, + graph_tokenizer=graph_tokenizer, + batch_size=FLAGS.train_batch_size, + subset='train', + timesteps=FLAGS.train_timesteps, + version=FLAGS.graph_data_version, + shuffle_data=True, + repeat=True, + debug=FLAGS.debug) + train_iter = iter(train_dataset) + loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size, + cache_steps=FLAGS.train_memory_size) + optimizer = optax.chain( + optax.clip_by_global_norm(FLAGS.grad_clip), + optax.scale_by_adam(), + optax.scale_by_schedule(functools.partial( + utils.schedule, + lr_schedule=FLAGS.lr_schedule, + init_lr=FLAGS.init_lr, + min_lr_ratio=FLAGS.min_lr_ratio, + max_steps=FLAGS.max_steps)), + optax.scale(-1)) + optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5) + updater = Updater(loss_fn, optimizer, + devices=local_devices[:num_gpus], + has_graph=has_graph) + updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir) + _train(updater, train_iter, num_gpus) + elif FLAGS.job_mode == 'eval': + eval_dataset = dataset_class( + tokenizer=tokenizer, + graph_tokenizer=graph_tokenizer, + batch_size=FLAGS.eval_batch_size, + subset=FLAGS.eval_subset, + timesteps=FLAGS.eval_timesteps, + version=FLAGS.graph_data_version, + shuffle_data=False, + repeat=False, + debug=FLAGS.debug) + eval_iter = iter(eval_dataset) + loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size, + cache_steps=FLAGS.eval_memory_size) + # only use one device for evaluation + devices = local_devices[:1] + updater = Updater(loss_fn, optimizer=None, devices=devices, + has_graph=has_graph) + updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir) + _eval(updater, eval_iter) + elif FLAGS.job_mode == 'sample': + eval_dataset = dataset_class( + tokenizer=tokenizer, + graph_tokenizer=graph_tokenizer, + batch_size=1, + subset=FLAGS.eval_subset, + timesteps=FLAGS.sample_length, + version=FLAGS.graph_data_version, + shuffle_data=False, + repeat=True, + debug=FLAGS.debug) + eval_iter = iter(eval_dataset) + _sample(eval_iter, tokenizer, local_devices[:num_gpus]) + elif FLAGS.job_mode == 'retrieve': + eval_dataset = dataset_class( + tokenizer=tokenizer, + graph_tokenizer=graph_tokenizer, + batch_size=1, + subset=FLAGS.eval_subset, + timesteps=FLAGS.eval_timesteps, + version=FLAGS.graph_data_version, + shuffle_data=False, + repeat=False, + graph_retrieval_dataset=True, + debug=FLAGS.debug) + eval_iter = iter(eval_dataset) + loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size, + cache_steps=FLAGS.eval_memory_size) + # only use one device for evaluation + devices = local_devices[:1] + updater = Updater(loss_fn, optimizer=None, devices=devices, + has_graph=has_graph) + updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir) + _retrieve(updater, eval_iter) + +if __name__ == '__main__': + app.run(main) diff --git a/wikigraphs/scripts/compute_blue_score.py b/wikigraphs/scripts/compute_blue_score.py new file mode 100644 index 0000000..6d9b9e4 --- /dev/null +++ b/wikigraphs/scripts/compute_blue_score.py @@ -0,0 +1,109 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Compute the bleu score on generated text and the ground truth.""" + +import math +import os +import pickle + +from absl import app +from absl import flags +from absl import logging + +import numpy as np + +import utils + + +flags.DEFINE_string('checkpoint_dir', '/tmp/transformerXL', + 'Checkpoint directory to load saved samples.') +flags.DEFINE_string('dataset', 'freebase2wikitext', 'Which dataset to the model' + ' is trained on, one of "wikitext", "freebase2wikitext".') + +FLAGS = flags.FLAGS + + +def group_samples(samples, tokenizer): + """Groups generated and ground truth texts.""" + groups = {} + for i, row in enumerate(samples): + gt = tokenizer.decode(row['ground_truth_text']) + sample = tokenizer.decode(row['sample_tokens']) + if gt not in groups: + groups[gt] = (gt.split(), [sample.split()]) + else: + groups[gt][-1].append(sample.split()) + if (i + 1) % 100 == 0: + logging.info('Processed %d samples', i + 1) + return groups + + +def eval_samples(raw_samples, tokenizer): + """Evaluates generated samples.""" + gt_refs = [] + samples = [] + + groups = group_samples(raw_samples, tokenizer) + groups = list(groups.values()) + avg_group_size = np.mean([len(g[-1]) for g in groups]) + logging.info('Average samples per example: %.2f', avg_group_size) + avg_group_size = int(math.ceil(avg_group_size)) + for i, (gt, s) in enumerate(groups): + gt_refs.append(gt) + idx = i % len(groups) + samples.append(groups[idx][-1]) + + gt_bleu, gt_n_grams = utils.compute_bleu(samples, gt_refs) + + logging.info('Processed %d samples in total.', sum([len(s) for s in samples])) + flat_samples = [] + for s in samples: + flat_samples.extend(s) + logging.info('Average sample len: %.2f', + np.mean([len(s) for s in flat_samples])) + logging.info('Average ground-truth len: %.2f', + np.mean([len(gt) for gt in gt_refs])) + + logging.info('Ground-truth BLEU: %6.2f, n-gram precision: (%s)', + gt_bleu * 100, + ', '.join(['%6.2f%%' % (s * 100) for s in gt_n_grams])) + + +def main(_): + tokenizer = utils.init_tokenizer(FLAGS.dataset) + checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, 'samples.pkl') + logging.info('Loading samples from %s', checkpoint_dir) + with open(checkpoint_dir, 'rb') as f: + samples = pickle.load(f)['samples'] + eval_samples(samples, tokenizer) + + +if __name__ == '__main__': + app.run(main) diff --git a/wikigraphs/updaters.py b/wikigraphs/updaters.py new file mode 100644 index 0000000..8bb258f --- /dev/null +++ b/wikigraphs/updaters.py @@ -0,0 +1,311 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Data Parallel Updater for Graph2text data.""" + +import functools +import os +import pickle + +from absl import logging + +import haiku as hk +import jax +from jax.tree_util import tree_multimap +import numpy as np +import optax + + +def call_fn_with_state_keys(jit_fn, state, other_inputs, keys): + """Executes `jit_fn`, filtering out all keys except some subset.""" + state = state.copy() + extra_state = {} + for k in list(state.keys()): + if k not in keys: + extra_state[k] = state.pop(k) + return jit_fn(state, *other_inputs), extra_state + + +class Updater: + """Graph2text model updater with multi-GPU support.""" + + def __init__(self, loss_fn, optimizer, devices=None, has_graph=False): + self._net_init_fn, self._apply_fn = hk.transform_with_state( + functools.partial(loss_fn, is_training=True)) + _, self._eval_apply_fn = hk.transform_with_state( + functools.partial(loss_fn, is_training=False)) + + if optimizer is None: + optimizer = optax.identity() + self._optimizer = optimizer + + self._num_devices = jax.local_device_count() + if devices is None: + devices = [] + for host_id in range(jax.process_count()): + for device_id in jax.local_devices(host_id): + devices.append(device_id) + else: + self._num_devices = min(self._num_devices, len(devices)) + + def _pmap(f, static_broadcasted_argnums=()): + return jax.pmap(f, axis_name='i', devices=devices, + static_broadcasted_argnums=static_broadcasted_argnums) + + def handle_graph_size(fn): + def _fn(*args): + batch = args[-1].copy() + max_graph_size = batch['max_graph_size'] + del batch['max_graph_size'] + args = args[:-1] + (batch, max_graph_size) + return fn(*args) + return _fn + + # Try to jit. + if has_graph: + # If the model contains full graphs, we need to set the max_graph_size + # as a statically broadcasted argument. + self._init_fn = handle_graph_size(_pmap(self._init, 4)) + self._update_fn = handle_graph_size(_pmap(self._update, 2)) + self._eval_fn = handle_graph_size(_pmap(self._eval, 2)) + else: + self._init_fn = _pmap(self._init) + self._update_fn = _pmap(self._update) + self._eval_fn = _pmap(self._eval) + + def _init(self, master_rng, params, network_state, data, max_graph_size=None): + """Initializes state of the updater.""" + out_rng, init_rng = jax.random.split(master_rng) + if max_graph_size is not None: + new_params, new_network_state = self._net_init_fn( + init_rng, data, max_graph_size) + else: + new_params, new_network_state = self._net_init_fn(init_rng, data) + if params is None: + params = new_params + if network_state is None: + network_state = new_network_state + opt_state = self._optimizer.init(params) + return dict( + replicated_step=0, + rng=out_rng, + state=network_state, + opt_state=opt_state, + params=params, + ) + + def init(self, master_rng, data, params=None, network_state=None, + replicated_params=False): + """Initializes state of the updater.""" + data = self._preprocess(data) + rngs = np.array([master_rng] * self._num_devices) + if not replicated_params and params is not None: + params = jax.tree_map( + lambda x: np.array([x] * self._num_devices), params) + state = self._init_fn(rngs, params, network_state, data) + state['step'] = np.array(0, dtype=np.int64) + # Wait for initialization to finish before starting training to keep + # memory usage low. + flat_params = jax.tree_leaves(state['params']) + if flat_params: + jax.tree_leaves(state['params'])[0].block_until_ready() + return state + + def _update(self, state, data, max_graph_size=None): + """Updates parameters.""" + replicated_step = state['replicated_step'] + rng = state['rng'] + opt_state = state['opt_state'] + params = state['params'] + net_state = state['state'] + + rng, new_rng = jax.random.split(rng) + rng = jax.random.fold_in(rng, jax.lax.axis_index('i')) + + def _loss(params, state, batch, rng): + if max_graph_size is not None: + (loss, metrics), state = self._apply_fn(params, state, rng, batch, + max_graph_size) + else: + (loss, metrics), state = self._apply_fn(params, state, rng, batch) + return loss, (metrics, state) + + (loss, (metrics, new_net_state)), g = jax.value_and_grad( + _loss, has_aux=True)(params, net_state, data, rng) + g = jax.lax.pmean(g, axis_name='i') + loss = jax.lax.pmean(loss, axis_name='i') + metrics = jax.lax.pmean(metrics, axis_name='i') + + updates, new_opt_state = self._optimizer.update(g, opt_state, params) + new_params = optax.apply_updates(params, updates) + + new_state = dict( + replicated_step=replicated_step + 1, + rng=new_rng, + state=new_net_state, + opt_state=new_opt_state, + params=new_params, + ) + + metrics['loss'] = loss + metrics['step'] = replicated_step + return new_state, metrics + + def update(self, state, data): + """Updates the state using some data and returns metrics.""" + data = self._preprocess(data) + (state, out), extra_state = call_fn_with_state_keys( + self._update_fn, state, [data], keys=set([ + 'state', 'params', 'rng', 'replicated_step', 'opt_state'])) + state.update(extra_state) + state['step'] += 1 + return state, tree_multimap(lambda x: x[0], out) + + def _eval(self, state, data, max_graph_size=None): + """Evaluates the current state on the given data.""" + if max_graph_size is not None: + (loss, metrics), new_state = self._eval_apply_fn( + state['params'], state['state'], state['rng'], data, max_graph_size) + else: + (loss, metrics), new_state = self._eval_apply_fn( + state['params'], state['state'], state['rng'], data) + state['state'] = new_state + loss = jax.lax.pmean(loss, axis_name='i') + metrics = jax.lax.pmean(metrics, axis_name='i') + metrics['loss'] = loss + metrics['step'] = state['replicated_step'] + return state, metrics + + def eval_return_state(self, state, data): + """Returns metrics without updating the model.""" + data = self._preprocess(data) + (state, out), extra_state = call_fn_with_state_keys( + self._eval_fn, state, [data], keys=set([ + 'state', 'params', 'rng', 'replicated_step'])) + state.update(extra_state) + return state, tree_multimap(lambda x: x[0], out) + + def eval(self, state, data): + """Returns metrics without updating the model.""" + _, out = self.eval_return_state(state, data) + return out + + def _preprocess(self, data): + """Reshapes input so that it can be distributed across multiple cores.""" + multi_inputs = data.copy() + + def add_core_dimension(x): + if np.isscalar(x): + return x + if x.shape[0] % self._num_devices != 0: + raise ValueError(f'The batch size must be a multiple of the number of' + f' devices. Got batch size = {x.shape[0]} and number' + f' of devices = {self._num_devices}.') + prefix = (self._num_devices, x.shape[0] // self._num_devices) + return np.reshape(x, prefix + x.shape[1:]) + + multi_inputs = tree_multimap(add_core_dimension, multi_inputs) + return multi_inputs + + def params(self, state): + """Returns model parameters.""" + return tree_multimap(lambda x: x[0], state['params']) + + def opt_state(self, state): + """Returns the state of the optimiser.""" + return tree_multimap(lambda x: x[0], state['opt_state']) + + def network_state(self, state): + """Returns the model's state.""" + return tree_multimap(lambda x: x[0], state['state']) + + def to_checkpoint_state(self, state): + """Transforms the updater state into a checkpointable state.""" + checkpoint_state = state.copy() + # Wrapper around checkpoint_state['step'] so we can get [0]. + checkpoint_state['step'] = checkpoint_state['step'][np.newaxis] + # Unstack the replicated contents. + checkpoint_state = tree_multimap(lambda x: x[0], checkpoint_state) + return checkpoint_state + + def from_checkpoint_state(self, checkpoint_state): + """Initializes the updater state from the checkpointed state.""" + # Expand the checkpoint so we have a copy for each device. + state = tree_multimap(lambda x: np.stack(jax.local_device_count() * [x]), + checkpoint_state) + state['step'] = state['step'][0] # Undo stacking for step. + return state + + +class CheckpointingUpdater: + """A checkpointing wrapper around an Updater.""" + + def __init__(self, + inner: Updater, + checkpoint_dir: str): + self._inner = inner + self._checkpoint_dir = checkpoint_dir + + def _checkpoint_paths(self): + return [p for p in os.listdir(self._checkpoint_dir) if 'checkpoint' in p] + + def init(self, rng, data, params=None, network_state=None): + """Initialize experiment state.""" + if not os.path.exists(self._checkpoint_dir) or not self._checkpoint_paths(): + os.makedirs(self._checkpoint_dir, exist_ok=True) + return self._inner.init(rng, data, params, network_state) + return self.load_checkpoint() + + def init_from_checkpoint(self, rng, data, checkpoint_state): + params = self._inner.params(checkpoint_state) + network_state = None + return self._inner.init(rng, data, params, network_state) + + def eval_return_state(self, state, data): + return self._inner.eval_return_state(state, data) + + def save_checkpoint(self, state): + path = os.path.join(self._checkpoint_dir, 'checkpoint.pkl') + logging.info('Serializing experiment state to %s', path) + checkpoint_state = self._inner.to_checkpoint_state(jax.device_get(state)) + with open(path, 'wb') as f: + pickle.dump(checkpoint_state, f) + + def load_checkpoint(self): + checkpoint = os.path.join(self._checkpoint_dir, + self._checkpoint_paths()[-1]) + logging.info('Loading checkpoint from %s', checkpoint) + with open(checkpoint, 'rb') as f: + state = pickle.load(f) + return self._inner.from_checkpoint_state(state) + + def update(self, state, data): + """Update experiment state.""" + state, out = self._inner.update(state, data) + return state, out diff --git a/wikigraphs/utils.py b/wikigraphs/utils.py new file mode 100644 index 0000000..09af2c5 --- /dev/null +++ b/wikigraphs/utils.py @@ -0,0 +1,522 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Utility functions for the training script.""" + +import collections +import math +import random + +from absl import flags +from absl import logging + +import jax.numpy as jnp +import jraph +import numpy as np +import sklearn + +from wikigraphs.data import paired_dataset as pd +from wikigraphs.data import tokenizers +from wikigraphs.data import wikitext as wt +from wikigraphs.model import graph_net as gn +from wikigraphs.model import sampler as transformer_sampler +from wikigraphs.model import transformer + + +FLAGS = flags.FLAGS + +VOCAB_FILES_MAP = { + 'wikitext': '/tmp/data/wikitext-vocab.csv', + 'freebase2wikitext': '/tmp/data/text-vocab.csv', +} + +GRAPH_VOCAB_FILE = '/tmp/data/graph-vocab.csv' + + +def init_tokenizer(dataset_name): + """Initialie the tokenizer.""" + logging.info('Loading tokenizer...') + tokenizer = tokenizers.WordTokenizer(VOCAB_FILES_MAP[dataset_name]) + logging.info('Vocab size: %d', tokenizer.vocab_size) + return tokenizer + + +def init_graph_tokenizer(): + """Initialie the tokenizer.""" + logging.info('Loading graph tokenizer...') + tokenizer = tokenizers.GraphTokenizer(GRAPH_VOCAB_FILE) + logging.info('Vocab size: %d', tokenizer.vocab_size) + return tokenizer + + +def get_dataset_class(dataset_name, model_type, job_mode='train'): + """Get the dataset class used for all jobs.""" + if dataset_name == 'freebase2wikitext': + if model_type == 'bow2text': + return pd.Bow2TextDataset + elif FLAGS.model_type == 'graph2text': + return pd.Graph2TextDataset + elif FLAGS.model_type == 'text': + if job_mode in ['train', 'eval']: + return pd.TextOnlyDataset + else: + # for sampling: taking the unique graphs for a fair comparison + return pd.Bow2TextDataset + else: + # Add other graph2text data here. + raise NotImplementedError() + else: + def dataset(graph_tokenizer, *args, **kwargs): + del graph_tokenizer + return wt.Dataset(*args, **kwargs) + return dataset + + +def preprocess(batch, model_type, num_devices=1): + """Preprocess the batch before sending to the model.""" + if model_type == 'text': + if 'graphs' in batch: + del batch['graphs'] + elif model_type == 'bow2text': + # Do nothing, bow2text data is already in a good form. + pass + else: # graph2text + if num_devices == 1: + graphs = gn.pad_graphs(jraph.batch(batch['graphs'])) + else: + # We need to first batch graphs into num_devices batchs. + graphs = gn.batch_graphs_by_device(batch['graphs'], num_devices) + # Then we pad them to the maximum graph size in the batch and concat. + # This way graphs can be distributed to each device through pmap. + graphs = gn.pad_graphs_by_device(graphs) + max_graph_size = gn.pad_size(graphs.n_node.max()) + batch.update({ + 'graphs': graphs, + 'max_graph_size': max_graph_size}) + return batch + + +def text_model_fn(vocab_size): + return transformer.TransformerXL( + vocab_size=vocab_size, + emb_dim=FLAGS.emb_dim, + num_layers=FLAGS.num_layers, + num_heads=FLAGS.num_heads, + dropout_prob=FLAGS.dropout, + dropout_attn_prob=FLAGS.dropout_attn, + self_att_init_scale=FLAGS.self_att_init_scale, + dense_init_scale=FLAGS.dense_init_scale, + dense_dim=FLAGS.dense_dim, + tail_shrink_factor=FLAGS.tail_shrink_factor, + relative_pos_clamp_len=FLAGS.clamp_len or None) + + +def graph2text_model_fn(vocab_size): + """Get graph2text transformer model.""" + return transformer.Graph2TextTransformer( + vocab_size=vocab_size, + emb_dim=FLAGS.emb_dim, + num_layers=FLAGS.num_layers, + num_heads=FLAGS.num_heads, + dropout_prob=FLAGS.dropout, + dropout_attn_prob=FLAGS.dropout_attn, + self_att_init_scale=FLAGS.self_att_init_scale, + dense_init_scale=FLAGS.dense_init_scale, + dense_dim=FLAGS.dense_dim, + tail_shrink_factor=FLAGS.tail_shrink_factor, + relative_pos_clamp_len=FLAGS.clamp_len or None, + gnn_embed_dim=FLAGS.gnn_embed_dim, + gnn_num_layers=FLAGS.gnn_num_layers, + gnn_layer_norm=FLAGS.gnn_layer_norm) + + +def bow2text_model_fn(vocab_size): + """Get the bow2text model.""" + return transformer.Bow2TextTransformer( + vocab_size=vocab_size, + emb_dim=FLAGS.emb_dim, + num_layers=FLAGS.num_layers, + num_heads=FLAGS.num_heads, + dropout_prob=FLAGS.dropout, + dropout_attn_prob=FLAGS.dropout_attn, + self_att_init_scale=FLAGS.self_att_init_scale, + dense_init_scale=FLAGS.dense_init_scale, + dense_dim=FLAGS.dense_dim, + tail_shrink_factor=FLAGS.tail_shrink_factor, + relative_pos_clamp_len=FLAGS.clamp_len or None, + bow_embedding_dim=FLAGS.bow_embedding_dim, + bow_n_tokens=FLAGS.bow_n_tokens) + + +def build_loss_fn(vocab_size, cache_steps): + """Build the appropriate loss function according to the configs.""" + if FLAGS.model_type == 'text': + def loss_fn(data, is_training=True): + return text_model_fn(vocab_size=vocab_size).loss( + data['obs'], data['target'], data['mask'], + is_training=is_training, + should_reset=data['should_reset'], + cache_steps=cache_steps) + elif FLAGS.model_type == 'graph2text': + def loss_fn(data, max_graph_size, is_training=True): + return graph2text_model_fn(vocab_size=vocab_size).loss( + data['graphs'], max_graph_size, True, + data['obs'], data['target'], data['mask'], + is_training=is_training, + should_reset=data['should_reset'], + cache_steps=cache_steps) + elif FLAGS.model_type == 'bow2text': + def loss_fn(data, is_training=True): + return bow2text_model_fn(vocab_size=vocab_size).loss( + data['graphs'], data['obs'], data['target'], data['mask'], + is_training=is_training, + should_reset=data['should_reset'], + cache_steps=cache_steps) + else: + raise ValueError(f'Unknown model type "{FLAGS.model_type}".') + return loss_fn + + +def build_sampler(tokenizer, device=None): + """Build the appropriate sampler according to the configs.""" + if FLAGS.model_type == 'text': + model_fn = lambda prompts: text_model_fn(tokenizer.vocab_size)( # pylint: disable=g-long-lambda + prompts, is_training=False, cache_steps=FLAGS.sample_memory_size) + sampler_class = transformer_sampler.TransformerXLSampler + elif FLAGS.model_type == 'graph2text': + def model_fn(graphs, max_graph_size, prompts): + return graph2text_model_fn(tokenizer.vocab_size)( + graphs, max_graph_size, True, prompts, is_training=False, + cache_steps=FLAGS.sample_memory_size) + sampler_class = transformer_sampler.Graph2TextTransformerSampler + elif FLAGS.model_type == 'bow2text': + def model_fn(graphs, prompts): + return bow2text_model_fn(tokenizer.vocab_size)( + graphs, prompts, is_training=False, + cache_steps=FLAGS.sample_memory_size) + sampler_class = transformer_sampler.Bow2TextTransformerSampler + sampler = sampler_class(model_fn, FLAGS.sampling_temperature, device) + return sampler + + +def schedule(i, lr_schedule, init_lr, min_lr_ratio, max_steps): + if lr_schedule == 'cosine': + cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * i / max_steps)) + decayed = (1 - min_lr_ratio) * cosine_decay + min_lr_ratio + return init_lr * decayed + else: + return jnp.where( + i > 350000, init_lr / 3**3, + jnp.where(i > 250000, init_lr / 3**2, + jnp.where(i > 150000, init_lr / 3, init_lr))) + + +def evaluate(eval_set, initial_state, updater, eval_batch_size=1, + preprocess_fn=None, max_eval_samples=-1, + print_progress_every=None): + """Evaluate a model on given dataset.""" + total_losses = [] + total_counts = [] + token_accuracy = [] + seq_accuracy = [] + state = initial_state + step = state['step'] + for i, batch in enumerate(eval_set): + state, eval_out = updater.eval_return_state(state, preprocess_fn(batch)) + total_losses.append(eval_out['total_loss']) + total_counts.append(eval_out['total_count']) + token_accuracy.append( + eval_out['token_accuracy'] * eval_out['total_count']) + seq_accuracy.append(eval_out['seq_accuracy']) + if print_progress_every and (i + 1) % print_progress_every == 0: + total_loss = float(jnp.array(total_losses).sum()) + total_count = float(jnp.array(total_counts).sum()) + avg_loss = total_loss / total_count + bpc = avg_loss * np.log2(np.e) + perplexity = np.exp(avg_loss) + logging.info( + 'Evaluated %d batches, total tokens %d, average loss %g,' + ' bpc %g, perplexity %g.', + i + 1, total_count, avg_loss, bpc, perplexity) + if 0 < max_eval_samples <= (i + 1) * eval_batch_size: + break + + total_loss = jnp.array(total_losses).sum() + total_count = jnp.array(total_counts).sum() + avg_loss = total_loss / total_count + eval_out = dict(total_loss=float(total_loss), + total_count=float(total_count), + loss=float(avg_loss), + token_accuracy=float( + jnp.array(token_accuracy).sum() / total_count), + seq_accuracy=float( + jnp.array(seq_accuracy).sum() / len(seq_accuracy)), + step=float(step), + bits_per_token=float(avg_loss) * np.log2(np.e), + perplexity=np.exp(float(avg_loss))) + return eval_out, state + + +def extract_title(text, tokenizer): + r"""Extract the title in the text. + + The wikitext articles is in the format of `\n = TITLE = \n \n...`. We extract + the title as the tokens from the start to when the `\n \n` first appears. + + Args: + text: tokenized input text using `tokenizer`. + tokenizer: text tokenizer. + + Returns: + title_end_idx: a numpy.array of shape (batch_size,), it indicates the index + in `text` that marks the end of the title. + """ + batch_size, text_length = text.shape + title_end_idx = np.ones(batch_size, dtype=np.int32) + newline_token = tokenizer.encode('\n')[0] + for b in range(batch_size): + prev_token = 1 # start tokens + for i in range(1, text_length): # skip start token + # when we first see '\n \n', that is the title + if prev_token == newline_token and text[b, i] == newline_token: + title_end_idx[b] = i + break + else: + prev_token = text[b, i] + return title_end_idx + + +def construct_prompts(text, batch_size, sample_length, tokenizer, prompt_title): + """Construct prompts for text generation. + + Args: + text: tokenized input text using `tokenizer`. + batch_size: the size of the batch. + sample_length: the length of the sample to be generated. + tokenizer: text tokenizer. + prompt_title: whether to return a prompt with the title of the `text`. + + Returns: + prompts: a numpy.array of shape [batch_size, sample_length], in which -1 + indicates tokens that need to be generated using the sampler. + + """ + prompts = -np.ones((batch_size, sample_length), dtype=np.int32) + prompts[:, 0] = tokenizer.bos_token() + if prompt_title and text is not None: + title_end_idx = extract_title(text, tokenizer) + for i in range(batch_size): + prompts[i, 1:title_end_idx[i]+1] = text[i, 1:title_end_idx[i]+1] + return prompts + + +def generate_samples(params, tokenizer, sampler, model_type, prompts, graphs): + """Generate a batch of samples using a sampler.""" + if model_type == 'text': + samples = sampler.sample(params, prompts) + elif model_type == 'graph2text': + samples = sampler.sample(params, prompts, graphs, pad=True) + elif model_type == 'bow2text': + samples = sampler.sample(params, prompts, graphs) + else: + raise ValueError(f'Unknown model_type {model_type}') + return [tokenizer.decode(s) for s in samples], samples + + +def take_unique_graphs(data_iter, model_type): + """Filter data such that it only returns batches with unique graphs.""" + prev_graphs = None + for batch in data_iter: + graphs = batch.get('graphs', None) + # If there's no graph in batch, don't do any filtering + if graphs is None: + yield batch + else: + if prev_graphs is None: + prev_graphs = graphs + yield batch + else: + if model_type == 'graph2text': + not_same_graph = (prev_graphs.nodes.shape != graphs.nodes.shape or + not (prev_graphs.nodes == graphs.nodes).all()) + else: + not_same_graph = (prev_graphs.shape != graphs.shape or + not (prev_graphs == graphs).all()) + if not_same_graph: + prev_graphs = graphs + yield batch + + +def compute_map_sklearn(pred, gt): + """Computes mAP using scikit-learn.""" + assert len(gt.shape) == len(pred.shape) == 2, ( + 'gt should be a one-hot encoding with the same shape as pred') + ap = [ + sklearn.metrics.average_precision_score( + gt[c, :], pred[c, :], average=None) + for c in range(gt.shape[0]) + ] + return sum(ap) / len(ap) + + +def compute_recall_at_k(pred, k=1): + """Computes recall@1 score.""" + num_articles = pred.shape[1] + return sklearn.metrics.top_k_accuracy_score( + np.arange(num_articles), pred, k=k) + + +def compute_text_graph_relevance( + eval_set, initial_state, updater, eval_batch_size=1, preprocess_fn=None, + print_progress_every=None): + """Compute the text and graph relevance a model on given dataset.""" + assert eval_batch_size == 1 + num_articles = eval_set.num_articles + tokens_count = np.zeros((num_articles, num_articles)) + log_probs = np.zeros((num_articles, num_articles)) # [graphs, texts] + state = initial_state + for i, batch in enumerate(eval_set): + state, eval_out = updater.eval_return_state(state, preprocess_fn(batch)) + graph_id = batch['graph_id'][0] + seq_id = batch['seq_id'][0] + tokens_count[graph_id, seq_id] += eval_out['total_count'] + log_probs[graph_id, seq_id] += eval_out['log_probs'] + if print_progress_every is not None and (i + 1) % print_progress_every == 0: + logging.info('Evaluated %d samples', i + 1) + + log_probs_per_token = log_probs / tokens_count + labels = np.eye(num_articles) + eval_out = dict( + log_probs=log_probs, + tokens_count=tokens_count, + log_probs_per_token=log_probs_per_token, + text2graph_recall_at_1=compute_recall_at_k(log_probs_per_token.T, k=1), + text2graph_recall_at_5=compute_recall_at_k(log_probs_per_token.T, k=5), + text2graph_map=compute_map_sklearn(log_probs_per_token.T, labels), + graph2text_recall_at_1=compute_recall_at_k(log_probs_per_token, k=1), + graph2text_recall_at_5=compute_recall_at_k(log_probs_per_token, k=5), + graph2text_map=compute_map_sklearn(log_probs_per_token, labels)) + return eval_out, state + + +def _get_ngrams(segment, max_order): + """Extracts all n-grams upto a given maximum order from an input segment. + + Args: + segment: text segment from which n-grams will be extracted. + max_order: maximum length in tokens of the n-grams returned by this + methods. + + Returns: + The Counter containing all n-grams upto max_order in segment + with a count of how many times each n-gram occurred. + """ + ngram_counts = collections.Counter() + for order in range(1, max_order + 1): + for i in range(0, len(segment) - order + 1): + ngram = tuple(segment[i:i+order]) + ngram_counts[ngram] += 1 + return ngram_counts + + +def compute_bleu(reference_corpus, translation_corpus, max_order=4, + smooth=False): + """Computes BLEU score of translated segments against one or more references. + + Originally from tensor2tensor/tensor2tensor/utils/bleu_hook.py + + Args: + reference_corpus: list of lists of references for each translation. Each + reference should be tokenized into a list of tokens. + translation_corpus: list of translations to score. Each translation + should be tokenized into a list of tokens. + max_order: Maximum n-gram order to use when computing BLEU score. + smooth: Whether or not to apply Lin et al. 2004 smoothing. + + Returns: + BLEU score and n-gram precisions. + """ + matches_by_order = [0] * max_order + possible_matches_by_order = [0] * max_order + reference_length = 0 + translation_length = 0 + for (references, translation) in zip(reference_corpus, + translation_corpus): + reference_length += min(len(r) for r in references) + translation_length += len(translation) + + merged_ref_ngram_counts = collections.Counter() + for reference in references: + merged_ref_ngram_counts |= _get_ngrams(reference, max_order) + translation_ngram_counts = _get_ngrams(translation, max_order) + overlap = translation_ngram_counts & merged_ref_ngram_counts + for ngram in overlap: + matches_by_order[len(ngram)-1] += overlap[ngram] + for order in range(1, max_order+1): + possible_matches = len(translation) - order + 1 + if possible_matches > 0: + possible_matches_by_order[order-1] += possible_matches + + if random.random() < 0.01: + print('==========') + for k, v in overlap.items(): + if len(k) >= 3: + print('%s : %d' % (str(k), v)) + + # print(matches_by_order) + # print(possible_matches_by_order) + + precisions = [0] * max_order + for i in range(0, max_order): + if smooth: + precisions[i] = ((matches_by_order[i] + 1.) / + (possible_matches_by_order[i] + 1.)) + else: + if possible_matches_by_order[i] > 0: + precisions[i] = (float(matches_by_order[i]) / + possible_matches_by_order[i]) + else: + precisions[i] = 0.0 + + if min(precisions) > 0: + p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) + geo_mean = math.exp(p_log_sum) + else: + geo_mean = 0 + + ratio = float(translation_length) / reference_length + + if ratio > 1.0: + bp = 1. + else: + bp = math.exp(1 - 1. / ratio) + + bleu = geo_mean * bp + + return bleu, precisions diff --git a/wikigraphs/wikigraphs/model/__init__.py b/wikigraphs/wikigraphs/model/__init__.py index 991e3c1..dccb10f 100644 --- a/wikigraphs/wikigraphs/model/__init__.py +++ b/wikigraphs/wikigraphs/model/__init__.py @@ -30,5 +30,6 @@ """WikiGraphs model modules.""" from . import embedding from . import graph_net +from . import sampler from . import transformer from . import transformer_block diff --git a/wikigraphs/wikigraphs/model/graph_net.py b/wikigraphs/wikigraphs/model/graph_net.py index cbc627b..fedbd4a 100644 --- a/wikigraphs/wikigraphs/model/graph_net.py +++ b/wikigraphs/wikigraphs/model/graph_net.py @@ -41,7 +41,7 @@ import numpy as np ArrayType = Union[np.ndarray, jnp.ndarray] -def pad_size(in_size): +def pad_size(in_size: int): out_size = 1 while out_size < in_size: out_size *= 2 diff --git a/wikigraphs/wikigraphs/model/sampler.py b/wikigraphs/wikigraphs/model/sampler.py new file mode 100644 index 0000000..62d79af --- /dev/null +++ b/wikigraphs/wikigraphs/model/sampler.py @@ -0,0 +1,361 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Samplers for the graph2text transformers.""" + +import abc +from typing import Any, Optional, Mapping + +import haiku as hk +import jax +import jax.numpy as jnp +import jraph +import numpy as np + +from wikigraphs.model import graph_net as gn + + +class BaseSampler: + """Base class for transformer samplers.""" + + def __init__(self, + model_fn, + temperature: float = 1.0, + device: Optional[Any] = None, + rng: Optional[np.ndarray] = None): + """Constructor. + + Args: + model_fn: a transformer language model defined in model.transformer. + temperature: sampling temperature. + device: the sampler will run on this device if provided. + rng: random number generator. + """ + self._temperature = temperature + self._device = device or jax.local_devices()[0] + init_fn, apply_fn = hk.transform_with_state(model_fn) + + if rng is None: + rng = jax.random.PRNGKey(np.random.randint(2**32)) + rng = jax.random.fold_in(rng, jax.host_id()) + self._rng = rng + self._init_state = None + self._jit_model(init_fn, apply_fn) + + def _jit_model(self, init_fn, apply_fn): + """Jit the `init_fn` and `apply_fn`.""" + pass + + @abc.abstractmethod + def _sample(self, + params: Mapping[str, Any], + state: Mapping[str, Any], + rng: jnp.ndarray, + x: jnp.ndarray, + **kwargs) -> np.ndarray: + """Generate samples. + + Args: + params: parameters of the transformer. + state: state of the transformer. + rng: random number generator. + x: a prompt of shape [batch_size, sample_len], in which an entry of -1 + indicates it will be generate at that place. Otherwise it acts as the + prompt. + **kwargs: additional inputs. + + Returns: + output: [batch_size, sample_len] tensor, the generated sequence. + """ + + @abc.abstractmethod + def sample(self, + params: Mapping[str, Any], + x: jnp.ndarray, + **kwargs) -> jnp.ndarray: + """Generate samples based on the given parameters and prompts. + + Args: + params: parameters of the transformer. + x: a prompt of shape [batch_size, sample_len], in which an entry of -1 + indicates it will be generate at that place. Otherwise it acts as the + prompt. + **kwargs: additional inputs. + + Returns: + output: the generated sequence. + """ + + +class TransformerXLSampler(BaseSampler): + """Sampling from the TransformerXL model.""" + + def _jit_model(self, init_fn, apply_fn): + """Jit `init_fn` and `apply_fn`, the latter is used in `self._sample`.""" + self._init_fn = jax.jit(init_fn, device=self._device) + self._apply_fn = apply_fn + self._sample_fn = jax.jit(self._sample, device=self._device) + + def _sample(self, + params: Mapping[str, Any], + state: Mapping[str, Any], + rng: jnp.ndarray, + x: jnp.ndarray) -> np.ndarray: + """Generate unconditional samples. + + Args: + params: parameters of the transformer. + state: state of the transformer. + rng: random number generator. + x: a prompt of shape [batch_size, sample_len], in which an entry of -1 + indicates it will be generate at that place. Otherwise it acts as the + prompt. + + Returns: + output: [batch_size, sample_len] tensor, the generated sequence. + """ + batch_size, sample_len = x.shape + + def one_step(params, state, rng, i, x): + step_sample = jax.lax.dynamic_slice(x, [0, i], [batch_size, 1]) + rng, rng_ = jax.random.split(rng) + # step_sample shape is [batch_size, 1]. + logits, state = self._apply_fn(params, state, rng_, step_sample) + rng, rng_ = jax.random.split(rng) + step_sample = jax.random.categorical(rng_, logits / self._temperature) + update = jnp.where(x[:, i + 1] < 0, step_sample[:, 0], x[:, i + 1])[:, + None] + x = jax.lax.dynamic_update_slice(x, update, [0, i + 1]) + return state, rng, x + + def loop_body(i, data): + state, rng, x = data + return one_step(params, state, rng, i, x) + + _, _, x = jax.lax.fori_loop(0, sample_len - 1, loop_body, + (state, rng, x)) + + return x + + def sample(self, + params: Mapping[str, Any], + x: jnp.ndarray) -> jnp.ndarray: + """Generate samples based on the given graphs and parameters. + + Args: + params: parameters of the transformer. + x: a prompt of shape [batch_size, sample_len], in which an entry of -1 + indicates it will be generate at that place. Otherwise it acts as the + prompt. + + Returns: + output: the generated sequence. + """ + if self._init_state is None: + self._rng, rng = jax.random.split(self._rng) + self._init_params, self._init_state = self._init_fn(rng, x[:, :1]) + if params is None: + params = self._init_params + + self._rng, rng = jax.random.split(self._rng) + sample = self._sample_fn(params, self._init_state, rng, x) + return sample + + +class Bow2TextTransformerSampler(BaseSampler): + """Sampling from the TransformerXL model.""" + + def _jit_model(self, init_fn, apply_fn): + """Jit `init_fn` and `apply_fn`, the latter is used in `self._sample`.""" + self._init_fn = jax.jit(init_fn, device=self._device) + self._apply_fn = apply_fn + self._sample_fn = jax.jit(self._sample, device=self._device) + + def _sample(self, + params: Mapping[str, Any], + state: Mapping[str, Any], + rng: jnp.ndarray, + bow: jnp.ndarray, + x: jnp.ndarray) -> np.ndarray: + """Generate samples conditioned on the bag-of-words of the graph. + + Args: + params: parameters of the transformer. + state: state of the transformer. + rng: random number generator. + bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector. + x: a prompt of shape [batch_size, sample_len], in which an entry of -1 + indicates it will be generate at that place. Otherwise it acts as the + prompt. + + Returns: + output: [batch_size, sample_len] tensor, the generated sequence. + """ + batch_size, sample_len = x.shape + + def one_step(params, state, rng, i, x): + step_sample = jax.lax.dynamic_slice(x, [0, i], [batch_size, 1]) + rng, rng_ = jax.random.split(rng) + # step_sample shape is [batch_size, 1]. + logits, state = self._apply_fn(params, state, rng_, bow, step_sample) + rng, rng_ = jax.random.split(rng) + step_sample = jax.random.categorical(rng_, logits / self._temperature) + update = jnp.where(x[:, i + 1] < 0, step_sample[:, 0], x[:, i + 1])[:, + None] + x = jax.lax.dynamic_update_slice(x, update, [0, i + 1]) + return state, rng, x + + def loop_body(i, data): + state, rng, x = data + return one_step(params, state, rng, i, x) + + _, _, x = jax.lax.fori_loop(0, sample_len - 1, loop_body, + (state, rng, x)) + + return x + + def sample(self, + params: Mapping[str, Any], + x: jnp.ndarray, + bow: jnp.ndarray) -> jnp.ndarray: + """Generate samples based on the given graphs and parameters. + + Args: + params: parameters of the transformer. + x: a prompt of shape [batch_size, sample_len], in which an entry of -1 + indicates it will be generate at that place. Otherwise it acts as the + prompt. + bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector. + + Returns: + output: the generated sequence. + """ + if self._init_state is None: + self._rng, rng = jax.random.split(self._rng) + self._init_params, self._init_state = self._init_fn(rng, bow, x[:, :1]) + if params is None: + params = self._init_params + + self._rng, rng = jax.random.split(self._rng) + sample = self._sample_fn(params, self._init_state, rng, bow, x) + return sample + + +class Graph2TextTransformerSampler(BaseSampler): + """Sampling from the Graph2Text TransformerXL model.""" + + def _jit_model(self, init_fn, apply_fn): + """Jit `init_fn` and `apply_fn`, the latter is used in `self._sample`.""" + # `pad_n_nodes` is set as a static argument. + self._init_fn = jax.jit(init_fn, device=self._device, static_argnums=2) + self._apply_fn = apply_fn + self._sample_fn = jax.jit(self._sample, device=self._device, + static_argnums=4) + + def _sample(self, + params: Mapping[str, Any], + state: Mapping[str, Any], + rng: jnp.ndarray, + graphs: jraph.GraphsTuple, + pad_n_nodes: int, + x: jnp.ndarray) -> np.ndarray: + """Generate samples conditioned on the bag-of-words reprensation of graph. + + Args: + params: parameters of the transformer. + state: state of the transformer. + rng: random number generator. + graphs: a graph structured using graph_net.Graph. + pad_n_nodes: size for each node to pad to. + x: a prompt of shape [batch_size, sample_len], in which an entry of -1 + indicates it will be generate at that place. Otherwise it acts as the + prompt. + + Returns: + output: [batch_size, sample_len] tensor, the generated sequence. + """ + batch_size, sample_len = x.shape + + def one_step(params, state, rng, i, x): + step_sample = jax.lax.dynamic_slice(x, [0, i], [batch_size, 1]) + rng, rng_ = jax.random.split(rng) + # step_sample shape is [batch_size, 1]. + logits, state = self._apply_fn( + params, state, rng_, graphs, pad_n_nodes, step_sample) + rng, rng_ = jax.random.split(rng) + step_sample = jax.random.categorical(rng_, logits / self._temperature) + update = jnp.where(x[:, i + 1] < 0, step_sample[:, 0], x[:, i + 1])[:, + None] + x = jax.lax.dynamic_update_slice(x, update, [0, i + 1]) + return state, rng, x + + def loop_body(i, data): + state, rng, x = data + return one_step(params, state, rng, i, x) + + _, _, x = jax.lax.fori_loop(0, sample_len - 1, loop_body, + (state, rng, x)) + + return x + + def sample(self, + params: Mapping[str, Any], + x: jnp.ndarray, + graphs: jraph.GraphsTuple, + pad: bool = True) -> jnp.ndarray: + """Generate samples based on the given graphs and parameters. + + Args: + params: parameters of the transformer. + x: a prompt of shape [batch_size, sample_len], in which an entry of -1 + indicates it will be generate at that place. Otherwise it acts as the + prompt. + graphs: a graph structured using graph_net.Graph. + pad: whether to pad the graph nodes and edges or not. + + Returns: + output: the generated sequence. + """ + if pad: + graphs = gn.pad_graphs(graphs) + max_graph_size = gn.pad_size(graphs.n_node.max()) + else: + max_graph_size = graphs.n_node.max() + + if self._init_state is None: + self._rng, rng = jax.random.split(self._rng) + self._init_params, self._init_state = self._init_fn( + rng, graphs, max_graph_size, x[:, :1]) + if params is None: + params = self._init_params + + self._rng, rng = jax.random.split(self._rng) + sample = self._sample_fn( + params, self._init_state, rng, graphs, max_graph_size, x) + return sample diff --git a/wikigraphs/wikigraphs/model/sampler_test.py b/wikigraphs/wikigraphs/model/sampler_test.py new file mode 100644 index 0000000..2d17a66 --- /dev/null +++ b/wikigraphs/wikigraphs/model/sampler_test.py @@ -0,0 +1,140 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Tests for wikigraphs.model.sampler.""" + +from absl.testing import absltest +import jraph +import numpy as np + +from wikigraphs.model import sampler +from wikigraphs.model import transformer as models + + +class SamplerTest(absltest.TestCase): + + def test_uncond_sampler_runs(self): + prompt = np.array([[0, 1, 2, -1, -1], + [0, 1, 2, -1, -1]], dtype=np.int32) + vocab_size = prompt.max() + 1 + bos_token = 0 + memory_size = 2 + params = None + + def model_fn(x): + return models.TransformerXL( + vocab_size=vocab_size, + emb_dim=8, + num_layers=2, + num_heads=4, + cutoffs=[])(x, is_training=False, cache_steps=memory_size) + + uncond_sampler = sampler.TransformerXLSampler(model_fn) + sample = uncond_sampler.sample(params, prompt) + self.assertTrue((sample[:, 0] == bos_token).all()) + self.assertTrue((sample != -1).all()) + self.assertEqual(sample.shape, prompt.shape) + sample2 = uncond_sampler.sample(params, prompt) + self.assertTrue((sample2[:, 0] == bos_token).all()) + self.assertTrue((sample2 != -1).all()) + self.assertEqual(sample2.shape, prompt.shape) + self.assertTrue((sample != sample2).any()) + + def test_bow2text_sampler_runs(self): + bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1], + [0, 1, 0, 0, 1, 0, 1, 0]], dtype=np.int32) + prompt = np.array([[0, 1, 2, -1, -1, -1], + [0, 1, 2, -1, -1, -1]], dtype=np.int32) + vocab_size = prompt.max() + 1 + bos_token = 0 + memory_size = 2 + params = None + + def model_fn(bow, x): + return models.Bow2TextTransformer( + vocab_size=vocab_size, + emb_dim=16, + num_layers=2, + num_heads=4, + cutoffs=[])(bow, x, is_training=False, cache_steps=memory_size) + + bow_sampler = sampler.Bow2TextTransformerSampler(model_fn) + sample = bow_sampler.sample(params, prompt, bow) + self.assertTrue((sample[:, 0] == bos_token).all()) + self.assertTrue((sample != -1).all()) + self.assertEqual(sample.shape, prompt.shape) + sample2 = bow_sampler.sample(params, prompt, bow) + self.assertTrue((sample2[:, 0] == bos_token).all()) + self.assertTrue((sample2 != -1).all()) + self.assertEqual(sample2.shape, prompt.shape) + self.assertTrue((sample != sample2).any()) + + def test_graph2text_sampler_runs(self): + graphs = jraph.GraphsTuple( + nodes=np.ones((4, 3), dtype=np.float32), + edges=np.ones((3, 1), dtype=np.float32), + senders=np.array([0, 2, 3], dtype=np.int32), + receivers=np.array([1, 3, 2], dtype=np.int32), + n_node=np.array([2, 2], dtype=np.int32), + n_edge=np.array([1, 2], dtype=np.int32), + globals=None, + ) + prompt = np.array([[0, 1, 2, -1, -1, -1], + [0, 1, 2, -1, -1, -1]], dtype=np.int32) + vocab_size = prompt.max() + 1 + bos_token = 0 + memory_size = 2 + params = None + + def model_fn(graphs, max_graph_size, x): + return models.Graph2TextTransformer( + vocab_size=vocab_size, + emb_dim=8, + num_layers=2, + num_heads=4, + cutoffs=[], + gnn_embed_dim=8, + gnn_num_layers=2)( + graphs, max_graph_size, True, x, + is_training=False, cache_steps=memory_size) + + graph_sampler = sampler.Graph2TextTransformerSampler(model_fn) + sample = graph_sampler.sample(params, prompt, graphs) + self.assertTrue((sample[:, 0] == bos_token).all()) + self.assertTrue((sample != -1).all()) + self.assertEqual(sample.shape, prompt.shape) + sample2 = graph_sampler.sample(params, prompt, graphs) + self.assertTrue((sample2[:, 0] == bos_token).all()) + self.assertTrue((sample2 != -1).all()) + self.assertEqual(sample2.shape, prompt.shape) + self.assertTrue((sample != sample2).any()) + + +if __name__ == '__main__': + absltest.main()