diff --git a/scratchgan/README.md b/scratchgan/README.md
new file mode 100644
index 0000000..905b971
--- /dev/null
+++ b/scratchgan/README.md
@@ -0,0 +1,61 @@
+# ScratchGAN
+
+This is the example code for the following NeurIPS 2019 paper. If you use this
+code, please cite the paper:
+
+    @article{DBLP:journals/corr/abs-1905-09922,
+      author    = {Cyprien de Masson d'Autume and
+                   Mihaela Rosca and
+                   Jack W. Rae and
+                   Shakir Mohamed},
+      title     = {Training language GANs from Scratch},
+      journal   = {CoRR},
+      volume    = {abs/1905.09922},
+      year      = {2019},
+      url       = {http://arxiv.org/abs/1905.09922},
+      archivePrefix = {arXiv},
+      eprint    = {1905.09922},
+      timestamp = {Wed, 29 May 2019 11:27:50 +0200},
+      biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1905-09922},
+      bibsource = {dblp computer science bibliography, https://dblp.org}
+    }
+
+
+## Contents
+
+The code contains:
+
+ * `generators.py`: implementation of the generator.
+ * `discriminator_nets.py`: implementation of the discriminator.
+ * `eval_metrics.py`: implementation of the FED metric.
+ * `losses.py`: implementation of the RL loss for the generator.
+ * `reader.py`: data reader / tokenizer.
+ * `experiment.py`: main training script.
+
+The data contains:
+
+ * `{train,valid,test}.json`: the EMNLP2017 News dataset.
+ * `glove_emnlp2017.txt`: the relevant subset of GloVe embeddings.
+
+## Running
+
+Place the data files in the directory specified by the `data_dir` flag.
+
+Create and activate a virtual environment if needed:
+
+    virtualenv scratchgan-venv
+    source scratchgan-venv/bin/activate
+
+Install requirements:
+
+    pip install -r scratchgan/requirements.txt
+
+Run the training and evaluation jobs from the directory containing the
+`scratchgan` package:
+
+    python2 -m scratchgan.experiment --mode="train" &
+    python2 -m scratchgan.experiment --mode="evaluate_pair" &
+
+The evaluation code is designed to run in parallel with the training.
+
+The training code saves checkpoints periodically; the evaluation code
+watches for new checkpoints and evaluates them.
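+
+The evaluation job appends one row per evaluated checkpoint to
+`evaluated_checkpoints.csv` in `checkpoint_dir`, with columns
+`checkpoint_name, mean_train_prob, mean_valid_prob, mean_gen_prob, fid`,
+and writes the corresponding samples to `<checkpoint_name>_sentences.txt`
+in the same directory. For example, to inspect the results (a sketch;
+adjust the path to your `checkpoint_dir`):
+
+    import csv
+    with open("/tmp/emnlp2017/checkpoints/evaluated_checkpoints.csv") as f:
+        for name, p_train, p_valid, p_gen, fed in csv.reader(f):
+            print("%s %s" % (name, fed))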
diff --git a/scratchgan/__init__.py b/scratchgan/__init__.py
new file mode 100644
index 0000000..96235ae
--- /dev/null
+++ b/scratchgan/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2019 DeepMind Technologies Limited and Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/scratchgan/discriminator_nets.py b/scratchgan/discriminator_nets.py
new file mode 100644
index 0000000..37817ec
--- /dev/null
+++ b/scratchgan/discriminator_nets.py
@@ -0,0 +1,121 @@
+# Copyright 2019 DeepMind Technologies Limited and Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Discriminator networks for text data."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sonnet as snt
+import tensorflow as tf
+from scratchgan import utils
+
+
+class LSTMEmbedDiscNet(snt.AbstractModule):
+  """An LSTM discriminator that operates on word indexes."""
+
+  def __init__(self,
+               feature_sizes,
+               vocab_size,
+               use_layer_norm,
+               trainable_embedding_size,
+               dropout,
+               pad_token,
+               embedding_source=None,
+               vocab_file=None,
+               name='LSTMEmbedDiscNet'):
+    super(LSTMEmbedDiscNet, self).__init__(name=name)
+    self._feature_sizes = feature_sizes
+    self._vocab_size = vocab_size
+    self._use_layer_norm = use_layer_norm
+    self._trainable_embedding_size = trainable_embedding_size
+    self._embedding_source = embedding_source
+    self._vocab_file = vocab_file
+    self._dropout = dropout
+    self._pad_token = pad_token
+    if self._embedding_source:
+      assert vocab_file
+
+  def _build(self, sequence, sequence_length, is_training=True):
+    """Connect to the graph.
+
+    Args:
+      sequence: A [batch_size, max_sequence_length] tensor of int. For example
+        the indices of words as sampled by the generator.
+      sequence_length: A [batch_size] tensor of int. Length of the sequence.
+      is_training: Boolean, False to disable dropout.
+
+    Returns:
+      A [batch_size, max_sequence_length] tensor of floats, the per-timestep
+      logits. For each sequence in the batch, the logits should (hopefully)
+      allow the discriminator to distinguish whether the token at each
+      timestep is real or generated.
+    """
+    batch_size, max_sequence_length = sequence.shape.as_list()
+    keep_prob = (1.0 - self._dropout) if is_training else 1.0
+
+    if self._embedding_source:
+      all_embeddings = utils.make_partially_trainable_embeddings(
+          self._vocab_file, self._embedding_source, self._vocab_size,
+          self._trainable_embedding_size)
+    else:
+      all_embeddings = tf.get_variable(
+          'trainable_embedding',
+          shape=[self._vocab_size, self._trainable_embedding_size],
+          trainable=True)
+    _, self._embedding_size = all_embeddings.shape.as_list()
+    input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=keep_prob)
+    embeddings = tf.nn.embedding_lookup(input_embeddings, sequence)
+    embeddings.shape.assert_is_compatible_with(
+        [batch_size, max_sequence_length, self._embedding_size])
+    position_dim = 8
+    embeddings_pos = utils.append_position_signal(embeddings, position_dim)
+    embeddings_pos = tf.reshape(
+        embeddings_pos,
+        [batch_size * max_sequence_length,
+         self._embedding_size + position_dim])
+    lstm_inputs = snt.Linear(self._feature_sizes[0])(embeddings_pos)
+    lstm_inputs = tf.reshape(
+        lstm_inputs, [batch_size, max_sequence_length, self._feature_sizes[0]])
+    lstm_inputs.shape.assert_is_compatible_with(
+        [batch_size, max_sequence_length, self._feature_sizes[0]])
+
+    encoder_cells = []
+    for feature_size in self._feature_sizes:
+      encoder_cells += [
+          snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
+      ]
+    encoder_cell = snt.DeepRNN(encoder_cells)
+    initial_state = encoder_cell.initial_state(batch_size)
+
+    hidden_states, _ = tf.nn.dynamic_rnn(
+        cell=encoder_cell,
+        inputs=lstm_inputs,
+        sequence_length=sequence_length,
+        initial_state=initial_state,
+        swap_memory=True)
+
+    hidden_states.shape.assert_is_compatible_with(
+        [batch_size, max_sequence_length,
+         sum(self._feature_sizes)])
+    logits = snt.BatchApply(snt.Linear(1))(hidden_states)
+    logits.shape.assert_is_compatible_with(
+        [batch_size, max_sequence_length, 1])
+    logits_flat = tf.reshape(logits, [batch_size, max_sequence_length])
+
+    # Mask past first PAD symbol.
+    #
+    # Note that we still rely on tf.nn.dynamic_rnn taking sequence_length
+    # properly into account, because otherwise the logits at a given timestep
+    # could depend on inputs past the first PAD token, which should be masked.
+    mask = utils.get_mask_past_symbol(sequence, self._pad_token)
+    masked_logits_flat = logits_flat * tf.cast(mask, tf.float32)
+    return masked_logits_flat
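+
+
+# Example usage (a sketch; shapes follow the `_build` docstring above, and the
+# constructor arguments mirror the default flags in `experiment.py`):
+#
+#   disc = LSTMEmbedDiscNet(
+#       feature_sizes=[512], vocab_size=vocab_size, use_layer_norm=True,
+#       trainable_embedding_size=64, dropout=0.1, pad_token=0)
+#   logits = disc(sequence, sequence_length)  # [batch_size, max_seq_length]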
diff --git a/scratchgan/eval_metrics.py b/scratchgan/eval_metrics.py
new file mode 100644
index 0000000..7b6bbb4
--- /dev/null
+++ b/scratchgan/eval_metrics.py
@@ -0,0 +1,48 @@
+# Copyright 2019 DeepMind Technologies Limited and Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation metrics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import tensorflow_gan as tfgan
+import tensorflow_hub as hub
+
+
+def fid(generated_sentences, real_sentences):
+  """Compute FID on sentences using a pretrained universal sentence encoder.
+
+  This Frechet distance over universal sentence encoder embeddings is the
+  FED metric referenced in the README.
+
+  Args:
+    generated_sentences: list of N strings.
+    real_sentences: list of N strings.
+
+  Returns:
+    Frechet distance between the two sets of embedding activations.
+  """
+  embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
+  real_embed = embed(real_sentences)
+  generated_embed = embed(generated_sentences)
+  distance = tfgan.eval.frechet_classifier_distance_from_activations(
+      real_embed, generated_embed)
+
+  # Restrict the thread pool size to prevent excessive CPU usage.
+  config = tf.ConfigProto()
+  config.intra_op_parallelism_threads = 16
+  config.inter_op_parallelism_threads = 16
+  with tf.Session(config=config) as session:
+    session.run(tf.global_variables_initializer())
+    session.run(tf.tables_initializer())
+    distance_np = session.run(distance)
+  return distance_np
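+
+
+# Example usage (a sketch; the function builds its own graph nodes and
+# session, so call it once per pair of sentence lists):
+#
+#   fed_value = fid(
+#       generated_sentences=["a cat sat down", "the dog barked"],
+#       real_sentences=["a cat sat down", "the dog slept"])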
+"""Training script for ScratchGAN.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import time + +from absl import app +from absl import flags +from absl import logging +import numpy as np +import tensorflow as tf +from tensorflow.io import gfile +from scratchgan import discriminator_nets +from scratchgan import eval_metrics +from scratchgan import generators +from scratchgan import losses +from scratchgan import reader +from scratchgan import utils + +flags.DEFINE_string("dataset", "emnlp2017", "Dataset.") +flags.DEFINE_integer("batch_size", 512, "Batch size") +flags.DEFINE_string("gen_type", "lstm", "Generator type.") +flags.DEFINE_string("disc_type", "lstm", "Discriminator type.") +flags.DEFINE_string("disc_loss_type", "ce", "Loss type.") +flags.DEFINE_integer("gen_feature_size", 512, "Generator feature size.") +flags.DEFINE_integer("disc_feature_size", 512, "Discriminator feature size.") +flags.DEFINE_integer("num_layers_gen", 2, "Number of generator layers.") +flags.DEFINE_integer("num_layers_disc", 1, "Number of discriminator layers.") +flags.DEFINE_bool("layer_norm_gen", False, "Layer norm generator.") +flags.DEFINE_bool("layer_norm_disc", True, "Layer norm discriminator.") +flags.DEFINE_float("gen_input_dropout", 0.0, "Input dropout generator.") +flags.DEFINE_float("gen_output_dropout", 0.0, "Input dropout discriminator.") +flags.DEFINE_float("l2_gen", 0.0, "L2 regularization generator.") +flags.DEFINE_float("l2_disc", 1e-6, "L2 regularization discriminator.") +flags.DEFINE_float("disc_dropout", 0.1, "Dropout discriminator") +flags.DEFINE_integer("trainable_embedding_size", 64, + "Size of trainable embedding.") +flags.DEFINE_bool("use_pretrained_embedding", True, "Use pretrained embedding.") +flags.DEFINE_integer("num_steps", int(200 * 1000), "Number of training steps.") +flags.DEFINE_integer("num_disc_updates", 1, "Number of discriminator updates.") +flags.DEFINE_integer("num_gen_updates", 1, "Number of generator updates.") +flags.DEFINE_string("data_dir", "/tmp/emnlp2017", "Directory where data is.") +flags.DEFINE_float("gen_lr", 9.59e-5, "Learning rate generator.") +flags.DEFINE_float("disc_lr", 9.38e-3, "Learning rate discriminator.") +flags.DEFINE_float("gen_beta1", 0.5, "Beta1 for generator.") +flags.DEFINE_float("disc_beta1", 0.5, "Beta1 for discriminator.") +flags.DEFINE_float("gamma", 0.23, "Discount factor.") +flags.DEFINE_float("baseline_decay", 0.08, "Baseline decay rate.") +flags.DEFINE_string("mode", "train", "train or evaluate_pair.") +flags.DEFINE_string("checkpoint_dir", "/tmp/emnlp2017/checkpoints/", + "Directory for checkpoints.") +flags.DEFINE_integer("export_every", 1000, "Frequency of checkpoint exports.") +flags.DEFINE_integer("num_examples_for_eval", int(1e4), + "Number of examples for evaluation") + +EVALUATOR_SLEEP_PERIOD = 60 # Seconds evaluator sleeps if nothing to do. 
+ + +def main(_): + config = flags.FLAGS + + gfile.makedirs(config.checkpoint_dir) + if config.mode == "train": + train(config) + elif config.mode == "evaluate_pair": + while True: + checkpoint_path = utils.maybe_pick_models_to_evaluate( + checkpoint_dir=config.checkpoint_dir) + if checkpoint_path: + evaluate_pair( + config=config, + batch_size=config.batch_size, + checkpoint_path=checkpoint_path, + data_dir=config.data_dir, + dataset=config.dataset, + num_examples_for_eval=config.num_examples_for_eval) + else: + logging.info("No models to evaluate found, sleeping for %d seconds", + EVALUATOR_SLEEP_PERIOD) + time.sleep(EVALUATOR_SLEEP_PERIOD) + else: + raise Exception( + "Unexpected mode %s, supported modes are \"train\" or \"evaluate_pair\"" + % (config.mode)) + + +def train(config): + """Train.""" + logging.info("Training.") + + tf.reset_default_graph() + np.set_printoptions(precision=4) + + # Get data. + raw_data = reader.get_raw_data( + data_path=config.data_dir, dataset=config.dataset) + train_data, valid_data, word_to_id = raw_data + id_to_word = {v: k for k, v in word_to_id.iteritems()} + vocab_size = len(word_to_id) + max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset] + logging.info("Vocabulary size: %d", vocab_size) + + iterator = reader.iterator(raw_data=train_data, batch_size=config.batch_size) + iterator_valid = reader.iterator( + raw_data=valid_data, batch_size=config.batch_size) + + real_sequence = tf.placeholder( + dtype=tf.int32, + shape=[config.batch_size, max_length], + name="real_sequence") + real_sequence_length = tf.placeholder( + dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length") + first_batch_np = iterator.next() + valid_batch_np = iterator_valid.next() + + test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()} + test_fake_batch = { + "sequence": + tf.constant( + np.random.choice( + vocab_size, size=[config.batch_size, + max_length]).astype(np.int32)), + "sequence_length": + tf.constant( + np.random.choice(max_length, + size=[config.batch_size]).astype(np.int32)), + } + valid_batch = {k: tf.constant(v) for k, v in valid_batch_np.items()} + + # Create generator. + if config.use_pretrained_embedding: + embedding_source = utils.get_embedding_path(config.data_dir, config.dataset) + vocab_file = "/tmp/vocab.txt" + with gfile.GFile(vocab_file, "w") as f: + for i in xrange(len(id_to_word)): + f.write(id_to_word[i] + "\n") + logging.info("Temporary vocab file: %s", vocab_file) + else: + embedding_source = None + vocab_file = None + + gen = generators.LSTMGen( + vocab_size=vocab_size, + feature_sizes=[config.gen_feature_size] * config.num_layers_gen, + max_sequence_length=reader.MAX_TOKENS_SEQUENCE[config.dataset], + batch_size=config.batch_size, + use_layer_norm=config.layer_norm_gen, + trainable_embedding_size=config.trainable_embedding_size, + input_dropout=config.gen_input_dropout, + output_dropout=config.gen_output_dropout, + pad_token=reader.PAD_INT, + embedding_source=embedding_source, + vocab_file=vocab_file, + ) + gen_outputs = gen() + + # Create discriminator. 
+  disc = discriminator_nets.LSTMEmbedDiscNet(
+      vocab_size=vocab_size,
+      feature_sizes=[config.disc_feature_size] * config.num_layers_disc,
+      trainable_embedding_size=config.trainable_embedding_size,
+      embedding_source=embedding_source,
+      use_layer_norm=config.layer_norm_disc,
+      pad_token=reader.PAD_INT,
+      vocab_file=vocab_file,
+      dropout=config.disc_dropout,
+  )
+  disc_logits_real = disc(
+      sequence=real_sequence, sequence_length=real_sequence_length)
+  disc_logits_fake = disc(
+      sequence=gen_outputs["sequence"],
+      sequence_length=gen_outputs["sequence_length"])
+
+  # Loss of the discriminator.
+  if config.disc_loss_type == "ce":
+    targets_real = tf.ones(
+        [config.batch_size, reader.MAX_TOKENS_SEQUENCE[config.dataset]])
+    targets_fake = tf.zeros(
+        [config.batch_size, reader.MAX_TOKENS_SEQUENCE[config.dataset]])
+    loss_real = losses.sequential_cross_entropy_loss(disc_logits_real,
+                                                     targets_real)
+    loss_fake = losses.sequential_cross_entropy_loss(disc_logits_fake,
+                                                     targets_fake)
+    disc_loss = 0.5 * loss_real + 0.5 * loss_fake
+  else:
+    raise NotImplementedError(
+        "Unsupported disc_loss_type: %s" % config.disc_loss_type)
+
+  # Loss of the generator.
+  gen_loss, cumulative_rewards, baseline = losses.reinforce_loss(
+      disc_logits=disc_logits_fake,
+      gen_logprobs=gen_outputs["logprobs"],
+      gamma=config.gamma,
+      decay=config.baseline_decay)
+
+  # Optimizers.
+  disc_optimizer = tf.train.AdamOptimizer(
+      learning_rate=config.disc_lr, beta1=config.disc_beta1)
+  gen_optimizer = tf.train.AdamOptimizer(
+      learning_rate=config.gen_lr, beta1=config.gen_beta1)
+
+  # Get losses and variables.
+  disc_vars = disc.get_all_variables()
+  gen_vars = gen.get_all_variables()
+  l2_disc = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in disc_vars]))
+  l2_gen = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in gen_vars]))
+  scalar_disc_loss = tf.reduce_mean(disc_loss) + config.l2_disc * l2_disc
+  scalar_gen_loss = tf.reduce_mean(gen_loss) + config.l2_gen * l2_gen
+
+  # Update ops.
+  global_step = tf.train.get_or_create_global_step()
+  disc_update = disc_optimizer.minimize(
+      scalar_disc_loss, var_list=disc_vars, global_step=global_step)
+  gen_update = gen_optimizer.minimize(
+      scalar_gen_loss, var_list=gen_vars, global_step=global_step)
+
+  # Saver.
+  saver = tf.train.Saver()
+
+  # Metrics.
+  test_disc_logits_real = disc(**test_real_batch)
+  test_disc_logits_fake = disc(**test_fake_batch)
+  valid_disc_logits = disc(**valid_batch)
+  disc_predictions_real = tf.nn.sigmoid(disc_logits_real)
+  disc_predictions_fake = tf.nn.sigmoid(disc_logits_fake)
+  valid_disc_predictions = tf.reduce_mean(
+      tf.nn.sigmoid(valid_disc_logits), axis=0)
+  test_disc_predictions_real = tf.reduce_mean(
+      tf.nn.sigmoid(test_disc_logits_real), axis=0)
+  test_disc_predictions_fake = tf.reduce_mean(
+      tf.nn.sigmoid(test_disc_logits_fake), axis=0)
+
+  # All metrics are reduced to scalars, averaged over the batch.
+  metrics = {
+      "scalar_gen_loss": scalar_gen_loss,
+      "scalar_disc_loss": scalar_disc_loss,
+      "disc_predictions_real": tf.reduce_mean(disc_predictions_real),
+      "disc_predictions_fake": tf.reduce_mean(disc_predictions_fake),
+      "test_disc_predictions_real": tf.reduce_mean(test_disc_predictions_real),
+      "test_disc_predictions_fake": tf.reduce_mean(test_disc_predictions_fake),
+      "valid_disc_predictions": tf.reduce_mean(valid_disc_predictions),
+      "cumulative_rewards": tf.reduce_mean(cumulative_rewards),
+      "baseline": tf.reduce_mean(baseline),
+  }
+
+  # Training.
+ logging.info("Starting training") + with tf.Session() as sess: + + sess.run(tf.global_variables_initializer()) + latest_ckpt = tf.train.latest_checkpoint(config.checkpoint_dir) + if latest_ckpt: + saver.restore(sess, latest_ckpt) + + for step in xrange(config.num_steps): + real_data_np = iterator.next() + train_feed = { + real_sequence: real_data_np["sequence"], + real_sequence_length: real_data_np["sequence_length"], + } + + # Update generator and discriminator. + for _ in xrange(config.num_disc_updates): + sess.run(disc_update, feed_dict=train_feed) + for _ in xrange(config.num_gen_updates): + sess.run(gen_update, feed_dict=train_feed) + + # Reporting + if step % config.export_every == 0: + gen_sequence_np, metrics_np = sess.run( + [gen_outputs["sequence"], metrics], feed_dict=train_feed) + metrics_np["gen_sentence"] = utils.sequence_to_sentence( + gen_sequence_np[0, :], id_to_word) + saver.save( + sess, + save_path=config.checkpoint_dir + "scratchgan", + global_step=global_step) + metrics_np["model_path"] = tf.train.latest_checkpoint( + config.checkpoint_dir) + logging.info(metrics_np) + + # After training, export models. + saver.save( + sess, + save_path=config.checkpoint_dir + "scratchgan", + global_step=global_step) + logging.info("Saved final model at %s.", + tf.train.latest_checkpoint(config.checkpoint_dir)) + + +def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset, + num_examples_for_eval): + """Evaluates a pair generator discriminator. + + This function loads a discriminator from disk, a generator, and evaluates the + discriminator against the generator. + + It returns the mean probability of the discriminator against several batches, + and the FID of the generator against the validation data. + + It also writes evaluation samples to disk. + + Args: + config: dict, the config file. + batch_size: int, size of the batch. + checkpoint_path: string, full path to the TF checkpoint on disk. + data_dir: string, path to a directory containing the dataset. + dataset: string, "emnlp2017", to select the right dataset. + num_examples_for_eval: int, number of examples for evaluation. + """ + tf.reset_default_graph() + logging.info("Evaluating checkpoint %s.", checkpoint_path) + + # Build graph. 
+  train_data, valid_data, word_to_id = reader.get_raw_data(
+      data_dir, dataset=dataset)
+  id_to_word = {v: k for k, v in word_to_id.iteritems()}
+  vocab_size = len(word_to_id)
+  train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size)
+  valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size)
+  train_sequence = tf.placeholder(
+      dtype=tf.int32,
+      shape=[batch_size, reader.MAX_TOKENS_SEQUENCE[dataset]],
+      name="train_sequence")
+  train_sequence_length = tf.placeholder(
+      dtype=tf.int32, shape=[batch_size], name="train_sequence_length")
+  valid_sequence = tf.placeholder(
+      dtype=tf.int32,
+      shape=[batch_size, reader.MAX_TOKENS_SEQUENCE[dataset]],
+      name="valid_sequence")
+  valid_sequence_length = tf.placeholder(
+      dtype=tf.int32, shape=[batch_size], name="valid_sequence_length")
+  disc_inputs_train = {
+      "sequence": train_sequence,
+      "sequence_length": train_sequence_length,
+  }
+  disc_inputs_valid = {
+      "sequence": valid_sequence,
+      "sequence_length": valid_sequence_length,
+  }
+  if config.use_pretrained_embedding:
+    embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
+    vocab_file = "/tmp/vocab.txt"
+    with gfile.GFile(vocab_file, "w") as f:
+      for i in xrange(len(id_to_word)):
+        f.write(id_to_word[i] + "\n")
+    logging.info("Temporary vocab file: %s", vocab_file)
+  else:
+    embedding_source = None
+    vocab_file = None
+  gen = generators.LSTMGen(
+      vocab_size=vocab_size,
+      feature_sizes=[config.gen_feature_size] * config.num_layers_gen,
+      max_sequence_length=reader.MAX_TOKENS_SEQUENCE[config.dataset],
+      batch_size=config.batch_size,
+      use_layer_norm=config.layer_norm_gen,
+      trainable_embedding_size=config.trainable_embedding_size,
+      input_dropout=config.gen_input_dropout,
+      output_dropout=config.gen_output_dropout,
+      pad_token=reader.PAD_INT,
+      embedding_source=embedding_source,
+      vocab_file=vocab_file,
+  )
+  gen_outputs = gen()
+
+  disc = discriminator_nets.LSTMEmbedDiscNet(
+      vocab_size=vocab_size,
+      feature_sizes=[config.disc_feature_size] * config.num_layers_disc,
+      trainable_embedding_size=config.trainable_embedding_size,
+      embedding_source=embedding_source,
+      use_layer_norm=config.layer_norm_disc,
+      pad_token=reader.PAD_INT,
+      vocab_file=vocab_file,
+      dropout=config.disc_dropout,
+  )
+
+  disc_inputs = {
+      "sequence": gen_outputs["sequence"],
+      "sequence_length": gen_outputs["sequence_length"],
+  }
+  gen_logits = disc(**disc_inputs)
+  train_logits = disc(**disc_inputs_train)
+  valid_logits = disc(**disc_inputs_valid)
+
+  # Saver.
+  saver = tf.train.Saver()
+
+  # Reduce over time and batch.
+  train_probs = tf.reduce_mean(tf.nn.sigmoid(train_logits))
+  valid_probs = tf.reduce_mean(tf.nn.sigmoid(valid_logits))
+  gen_probs = tf.reduce_mean(tf.nn.sigmoid(gen_logits))
+
+  outputs = {
+      "train_probs": train_probs,
+      "valid_probs": valid_probs,
+      "gen_probs": gen_probs,
+      "gen_sequences": gen_outputs["sequence"],
+      "valid_sequences": valid_sequence
+  }
+
+  # Get average discriminator score and store generated sequences.
+  all_valid_sentences = []
+  all_gen_sentences = []
+  all_gen_sequences = []
+  mean_train_prob = 0.0
+  mean_valid_prob = 0.0
+  mean_gen_prob = 0.0
+
+  logging.info("Graph constructed, generating batches.")
+  num_batches = num_examples_for_eval // batch_size + 1
+
+  # Restrict the thread pool size to prevent excessive CPU usage.
+  tf_config = tf.ConfigProto()
+  tf_config.intra_op_parallelism_threads = 16
+  tf_config.inter_op_parallelism_threads = 16
+
+  with tf.Session(config=tf_config) as sess:
+
+    # Restore variables from checkpoints.
+    logging.info("Restoring variables.")
+    saver.restore(sess, checkpoint_path)
+
+    for i in xrange(num_batches):
+      logging.info("Batch %d / %d", i, num_batches)
+      train_data_np = train_iterator.next()
+      valid_data_np = valid_iterator.next()
+      feed_dict = {
+          train_sequence: train_data_np["sequence"],
+          train_sequence_length: train_data_np["sequence_length"],
+          valid_sequence: valid_data_np["sequence"],
+          valid_sequence_length: valid_data_np["sequence_length"],
+      }
+      outputs_np = sess.run(outputs, feed_dict=feed_dict)
+      all_gen_sequences.extend(outputs_np["gen_sequences"])
+      gen_sentences = utils.batch_sequences_to_sentences(
+          outputs_np["gen_sequences"], id_to_word)
+      valid_sentences = utils.batch_sequences_to_sentences(
+          outputs_np["valid_sequences"], id_to_word)
+      all_valid_sentences.extend(valid_sentences)
+      all_gen_sentences.extend(gen_sentences)
+      # The *_probs outputs are already averaged over the batch, so we
+      # average them over the number of batches.
+      mean_train_prob += outputs_np["train_probs"] / num_batches
+      mean_valid_prob += outputs_np["valid_probs"] / num_batches
+      mean_gen_prob += outputs_np["gen_probs"] / num_batches
+
+  logging.info("Evaluating FID.")
+
+  # Compute FID.
+  fid = eval_metrics.fid(
+      generated_sentences=all_gen_sentences[:num_examples_for_eval],
+      real_sentences=all_valid_sentences[:num_examples_for_eval])
+
+  utils.write_eval_results(config.checkpoint_dir, all_gen_sentences,
+                           os.path.basename(checkpoint_path), mean_train_prob,
+                           mean_valid_prob, mean_gen_prob, fid)
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/scratchgan/generators.py b/scratchgan/generators.py
new file mode 100644
index 0000000..2215734
--- /dev/null
+++ b/scratchgan/generators.py
@@ -0,0 +1,147 @@
+# Copyright 2019 DeepMind Technologies Limited and Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Generators for text data."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+import sonnet as snt
+import tensorflow as tf
+import tensorflow_probability as tfp
+from scratchgan import utils
+
+
+class LSTMGen(snt.AbstractModule):
+  """A multi-layer LSTM language model.
+
+  Uses tied input/output embedding weights.
+ """ + + def __init__(self, + vocab_size, + feature_sizes, + max_sequence_length, + batch_size, + use_layer_norm, + trainable_embedding_size, + input_dropout, + output_dropout, + pad_token, + embedding_source=None, + vocab_file=None, + name='lstm_gen'): + super(LSTMGen, self).__init__(name=name) + self._feature_sizes = feature_sizes + self._max_sequence_length = max_sequence_length + self._vocab_size = vocab_size + self._batch_size = batch_size + self._use_layer_norm = use_layer_norm + self._trainable_embedding_size = trainable_embedding_size + self._embedding_source = embedding_source + self._vocab_file = vocab_file + self._input_dropout = input_dropout + self._output_dropout = output_dropout + self._pad_token = pad_token + if self._embedding_source: + assert vocab_file + + def _build(self, is_training=True, temperature=1.0): + input_keep_prob = (1. - self._input_dropout) if is_training else 1.0 + output_keep_prob = (1. - self._output_dropout) if is_training else 1.0 + + batch_size = self._batch_size + max_sequence_length = self._max_sequence_length + if self._embedding_source: + all_embeddings = utils.make_partially_trainable_embeddings( + self._vocab_file, self._embedding_source, self._vocab_size, + self._trainable_embedding_size) + else: + all_embeddings = tf.get_variable( + 'trainable_embeddings', + shape=[self._vocab_size, self._trainable_embedding_size], + trainable=True) + _, self._embedding_size = all_embeddings.shape.as_list() + input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=input_keep_prob) + output_embeddings = tf.nn.dropout( + all_embeddings, keep_prob=output_keep_prob) + + out_bias = tf.get_variable( + 'out_bias', shape=[1, self._vocab_size], dtype=tf.float32) + in_proj = tf.get_variable( + 'in_proj', shape=[self._embedding_size, self._feature_sizes[0]]) + # If more than 1 layer, then output has dim sum(self._feature_sizes), + # which is different from input dim == self._feature_sizes[0] + # So we need a different projection matrix for input and output. + if len(self._feature_sizes) > 1: + out_proj = tf.get_variable( + 'out_proj', shape=[self._embedding_size, + sum(self._feature_sizes)]) + else: + out_proj = in_proj + + encoder_cells = [] + for feature_size in self._feature_sizes: + encoder_cells += [ + snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm) + ] + encoder_cell = snt.DeepRNN(encoder_cells) + state = encoder_cell.initial_state(batch_size) + + # Manual unrolling. + samples_list, logits_list, logprobs_list, embeddings_list = [], [], [], [] + sample = tf.tile( + tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size]) + logging.info('Unrolling over %d steps.', max_sequence_length) + for _ in xrange(max_sequence_length): + # Input is sampled word at t-1. 
+      embedding = tf.nn.embedding_lookup(input_embeddings, sample)
+      embedding.shape.assert_is_compatible_with(
+          [batch_size, self._embedding_size])
+      embedding_proj = tf.matmul(embedding, in_proj)
+      embedding_proj.shape.assert_is_compatible_with(
+          [batch_size, self._feature_sizes[0]])
+
+      outputs, state = encoder_cell(embedding_proj, state)
+      outputs_proj = tf.matmul(outputs, out_proj, transpose_b=True)
+      logits = tf.matmul(
+          outputs_proj, output_embeddings, transpose_b=True) + out_bias
+      categorical = tfp.distributions.Categorical(logits=logits / temperature)
+      sample = categorical.sample()
+      logprobs = categorical.log_prob(sample)
+
+      samples_list.append(sample)
+      logits_list.append(logits)
+      logprobs_list.append(logprobs)
+      embeddings_list.append(embedding)
+
+    # Create an op to retrieve embeddings for full sequence, useful for testing.
+    embeddings = tf.stack(  # pylint: disable=unused-variable
+        embeddings_list,
+        axis=1,
+        name='embeddings')
+    sequence = tf.stack(samples_list, axis=1)
+    logprobs = tf.stack(logprobs_list, axis=1)
+
+    # The sequence stops after the first occurrence of a PAD token.
+    sequence_length = utils.get_first_occurrence_indices(
+        sequence, self._pad_token)
+    mask = utils.get_mask_past_symbol(sequence, self._pad_token)
+    masked_sequence = sequence * tf.cast(mask, tf.int32)
+    masked_logprobs = logprobs * tf.cast(mask, tf.float32)
+    return {
+        'sequence': masked_sequence,
+        'sequence_length': sequence_length,
+        'logprobs': masked_logprobs
+    }
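+
+
+# Example usage (a sketch; constructor arguments mirror the default flags in
+# `experiment.py`):
+#
+#   gen = LSTMGen(vocab_size=vocab_size, feature_sizes=[512, 512],
+#                 max_sequence_length=52, batch_size=512,
+#                 use_layer_norm=False, trainable_embedding_size=64,
+#                 input_dropout=0.0, output_dropout=0.0, pad_token=0)
+#   outputs = gen()
+#   outputs['sequence']         # [batch_size, max_sequence_length] int32
+#   outputs['sequence_length']  # [batch_size] int32
+#   outputs['logprobs']         # [batch_size, max_sequence_length] float32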
diff --git a/scratchgan/losses.py b/scratchgan/losses.py
new file mode 100644
index 0000000..a3a7ed0
--- /dev/null
+++ b/scratchgan/losses.py
@@ -0,0 +1,93 @@
+# Copyright 2019 DeepMind Technologies Limited and Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Losses for sequential GANs."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+def sequential_cross_entropy_loss(logits, expected):
+  """The cross entropy loss for binary classification.
+
+  Used to train the discriminator when not using WGAN loss.
+  Assumes `logits` is the logit of classifying each timestep as 1 (real).
+
+  Args:
+    logits: a `tf.Tensor`, the model produced logits, shape [batch_size,
+      sequence_length].
+    expected: a `tf.Tensor`, the expected output, shape [batch_size,
+      sequence_length].
+
+  Returns:
+    A [batch_size, sequence_length] `tf.Tensor`, the per-timestep cross
+    entropy losses on the given inputs.
+  """
+  batch_size, sequence_length = logits.shape.as_list()
+  expected = tf.cast(expected, tf.float32)
+
+  ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=expected, logits=logits)
+  return tf.reshape(ce, [batch_size, sequence_length])
+
+
+def reinforce_loss(disc_logits, gen_logprobs, gamma, decay):
+  """The REINFORCE loss.
+
+  Args:
+    disc_logits: float tensor, shape [batch_size, sequence_length].
+    gen_logprobs: float32 tensor, shape [batch_size, sequence_length].
+    gamma: a float, discount factor for cumulative reward.
+    decay: a float, decay rate for the EWMA baseline of REINFORCE.
+
+  Returns:
+    A tuple (loss, cumulative_rewards, ewma_reward):
+      loss: float tensor, shape [batch_size, sequence_length], the REINFORCE
+        loss for each timestep.
+      cumulative_rewards: float tensor, shape [batch_size, sequence_length].
+      ewma_reward: scalar float tensor, the EWMA baseline.
+  """
+  # Assume 1 logit for each timestep.
+  batch_size, sequence_length = disc_logits.shape.as_list()
+  gen_logprobs.shape.assert_is_compatible_with([batch_size, sequence_length])
+
+  disc_predictions = tf.nn.sigmoid(disc_logits)
+
+  # MaskGAN uses log(D), but this is more stable empirically.
+  rewards = 2.0 * disc_predictions - 1
+
+  # Compute cumulative rewards.
+  rewards_list = tf.unstack(rewards, axis=1)
+  cumulative_rewards = []
+  for t in xrange(sequence_length):
+    cum_value = tf.zeros(shape=[batch_size])
+    for s in xrange(t, sequence_length):
+      cum_value += np.power(gamma, (s - t)) * rewards_list[s]
+    cumulative_rewards.append(cum_value)
+  cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
+
+  cumulative_rewards.shape.assert_is_compatible_with(
+      [batch_size, sequence_length])
+
+  with tf.variable_scope("reinforce", reuse=tf.AUTO_REUSE):
+    ewma_reward = tf.get_variable("ewma_reward", initializer=0.0)
+
+  mean_reward = tf.reduce_mean(cumulative_rewards)
+  new_ewma_reward = decay * ewma_reward + (1.0 - decay) * mean_reward
+  update_op = tf.assign(ewma_reward, new_ewma_reward)
+
+  # REINFORCE.
+  with tf.control_dependencies([update_op]):
+    advantage = cumulative_rewards - ewma_reward
+    loss = -tf.stop_gradient(advantage) * gen_logprobs
+
+  loss.shape.assert_is_compatible_with([batch_size, sequence_length])
+  return loss, cumulative_rewards, ewma_reward
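+
+
+# For reference, `cumulative_rewards` above implements
+#   R_t = sum_{s >= t} gamma^(s - t) * r_s,
+# which satisfies the recursion R_t = r_t + gamma * R_{t+1}. The O(T^2) loop
+# keeps the graph construction simple; an equivalent O(T) construction would
+# be a reverse scan (a sketch, not used by the training script):
+#
+#   rewards_t_major = tf.transpose(tf.reverse(rewards, axis=[1]))  # [T, B]
+#   cumulative_rev = tf.scan(lambda acc, r: r + gamma * acc,
+#                            rewards_t_major,
+#                            initializer=tf.zeros([batch_size]))
+#   cumulative_rewards = tf.reverse(tf.transpose(cumulative_rev), axis=[1])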
diff --git a/scratchgan/reader.py b/scratchgan/reader.py
new file mode 100644
index 0000000..f3fe1b5
--- /dev/null
+++ b/scratchgan/reader.py
@@ -0,0 +1,174 @@
+# Copyright 2019 DeepMind Technologies Limited and Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for parsing text files."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import json
+import os
+
+from absl import logging
+import numpy as np
+from tensorflow.io import gfile
+
+# sequences: [N, MAX_TOKENS_SEQUENCE] array of int32
+# sequence_lengths: [N] array of int32, such that
+# sequence_lengths[i] is the number of non-pad tokens in sequences[i, :]
+# TODO(cyprien): handle PTB
+FILENAMES = {
+    "emnlp2017": ("train.json", "valid.json", "test.json"),
+}
+
+# EMNLP2017 sentences have max length 50, add one for a PAD token so that all
+# sentences end with PAD.
+MAX_TOKENS_SEQUENCE = {"emnlp2017": 52}
+
+UNK = "<unk>"
+PAD = "<pad>"
+
+PAD_INT = 0
+
+
+def tokenize(sentence):
+  """Split a string into words."""
+  return sentence.split(" ") + [PAD]
+
+
+def _build_vocab(json_data):
+  """Builds full vocab from json data."""
+  vocab = collections.Counter()
+  for sentence in json_data:
+    tokens = tokenize(sentence["s"])
+    vocab.update(tokens)
+    for title in sentence["t"]:
+      title_tokens = tokenize(title)
+      vocab.update(title_tokens)
+  # Most common words first.
+  count_pairs = sorted(vocab.items(), key=lambda x: (-x[1], x[0]))
+  words, _ = zip(*count_pairs)
+  words = list(words)
+  if UNK not in words:
+    words = [UNK] + words
+  word_to_id = dict(zip(words, range(len(words))))
+
+  # Tokens are now sorted by frequency. There's no guarantee that `PAD` will
+  # end up at `PAD_INT` index. Enforce it by swapping whatever token is
+  # currently at the `PAD_INT` index with the `PAD` token.
+  word = word_to_id.keys()[word_to_id.values().index(PAD_INT)]
+  word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
+  assert word_to_id[PAD] == PAD_INT
+
+  return word_to_id
+
+
+def string_sequence_to_sequence(string_sequence, word_to_id):
+  result = []
+  for word in string_sequence:
+    if word in word_to_id:
+      result.append(word_to_id[word])
+    else:
+      result.append(word_to_id[UNK])
+  return result
+
+
+def _integerize(json_data, word_to_id, dataset):
+  """Transform words into integers."""
+  sequences = np.full((len(json_data), MAX_TOKENS_SEQUENCE[dataset]),
+                      word_to_id[PAD], np.int32)
+  sequence_lengths = np.zeros(shape=(len(json_data)), dtype=np.int32)
+  for i, sentence in enumerate(json_data):
+    sequence_i = string_sequence_to_sequence(
+        tokenize(sentence["s"]), word_to_id)
+    sequence_lengths[i] = len(sequence_i)
+    sequences[i, :sequence_lengths[i]] = np.array(sequence_i)
+  return {
+      "sequences": sequences,
+      "sequence_lengths": sequence_lengths,
+  }
+
+
+def get_raw_data(data_path, dataset, truncate_vocab=20000):
+  """Load raw data from data directory "data_path".
+
+  Reads text files, converts strings to integer ids,
+  and performs mini-batching of the inputs.
+
+  Args:
+    data_path: string path to the directory containing the dataset
+      .json files.
+    dataset: one of ["emnlp2017"]
+    truncate_vocab: int, number of words to keep in the vocabulary.
+
+  Returns:
+    tuple (train_data, valid_data, vocabulary) where each of the data
+    objects can be passed to iterator.
+
+  Raises:
+    ValueError: dataset not in ["emnlp2017"].
+  """
+  if dataset not in FILENAMES:
+    raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
+        dataset, FILENAMES.keys()))
+  train_file, valid_file, _ = FILENAMES[dataset]
+
+  train_path = os.path.join(data_path, train_file)
+  valid_path = os.path.join(data_path, valid_file)
+
+  with gfile.GFile(train_path, "r") as json_file:
+    json_data_train = json.load(json_file)
+  with gfile.GFile(valid_path, "r") as json_file:
+    json_data_valid = json.load(json_file)
+
+  word_to_id = _build_vocab(json_data_train)
+  logging.info("Full vocab length: %d", len(word_to_id))
+  # Assume the vocab is sorted by frequency.
+  word_to_id_truncated = {
+      k: v for k, v in word_to_id.items() if v < truncate_vocab
+  }
+  logging.info("Truncated vocab length: %d", len(word_to_id_truncated))
+
+  train_data = _integerize(json_data_train, word_to_id_truncated, dataset)
+  valid_data = _integerize(json_data_valid, word_to_id_truncated, dataset)
+  return train_data, valid_data, word_to_id_truncated
+
+
+def iterator(raw_data, batch_size, random=False):
+  """Looping iterators on the raw data."""
+  sequences = raw_data["sequences"]
+  sequence_lengths = raw_data["sequence_lengths"]
+
+  num_examples = sequences.shape[0]
+  indice_range = np.arange(num_examples)
+
+  if random:
+    while True:
+      indices = np.random.choice(indice_range, size=batch_size, replace=True)
+      yield {
+          "sequence": sequences[indices, :],
+          "sequence_length": sequence_lengths[indices],
+      }
+  else:
+    start = 0
+    while True:
+      sequence = sequences[start:(start + batch_size), :]
+      sequence_length = sequence_lengths[start:(start + batch_size)]
+      start += batch_size
+      if start + batch_size > num_examples:
+        start = (start + batch_size) % num_examples
+      yield {
+          "sequence": sequence,
+          "sequence_length": sequence_length,
+      }
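+
+
+# Example usage (a sketch; paths as in the README):
+#
+#   train_data, valid_data, word_to_id = get_raw_data(
+#       "/tmp/emnlp2017", dataset="emnlp2017")
+#   it = iterator(raw_data=train_data, batch_size=512)
+#   batch = it.next()
+#   batch["sequence"]         # [512, MAX_TOKENS_SEQUENCE] int32
+#   batch["sequence_length"]  # [512] int32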
diff --git a/scratchgan/requirements.txt b/scratchgan/requirements.txt
new file mode 100644
index 0000000..56ee034
--- /dev/null
+++ b/scratchgan/requirements.txt
@@ -0,0 +1,8 @@
+absl-py==0.7.1
+dm-sonnet==1.34
+numpy==1.16.4
+tensorflow==1.14
+tensorflow-probability==0.7.0
+tensorflow-gan==1.0.0.dev0
+tensorflow-hub==0.6.0
+tensorflow-io==0.8.0
diff --git a/scratchgan/utils.py b/scratchgan/utils.py
new file mode 100644
index 0000000..3e317df
--- /dev/null
+++ b/scratchgan/utils.py
@@ -0,0 +1,308 @@
+# Copyright 2019 DeepMind Technologies Limited and Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import math
+import os
+from absl import logging
+import numpy as np
+import tensorflow as tf
+from tensorflow.io import gfile
+from scratchgan import reader
+
+EVAL_FILENAME = "evaluated_checkpoints.csv"
+GLOVE_DIM = 300
+GLOVE_STD = 0.3836  # Standard dev. of GloVe embeddings.
+
+
+def _get_embedding_initializer(vocab_file, embedding_source, vocab_size):
+  """Loads pretrained embeddings from a file in GloVe format."""
+  with gfile.GFile(embedding_source, "r") as f:
+    embedding_lines = f.readlines()
+
+  # The first line contains the vocab size and the embedding dim.
+  _, embedding_dim = map(int, embedding_lines[0].split())
+  # Get the tokens as strings.
+  tokens = [line.split()[0] for line in embedding_lines[1:]]
+  # Get the actual embedding matrix.
+  unsorted_emb = np.array(
+      [[float(x) for x in line.split()[1:]] for line in embedding_lines[1:]])
+
+  # Get the expected vocab order.
+  with gfile.GFile(vocab_file, "r") as f:
+    tokens_order = [l.strip() for l in f.readlines()]
+  assert vocab_size == len(tokens_order)
+
+  # Put the embeddings in the expected vocab order.
+  sorted_emb = np.zeros((vocab_size, embedding_dim))
+  for i, token in enumerate(tokens_order):
+    if token in tokens:
+      sorted_emb[i, :] = unsorted_emb[tokens.index(token), :]
+    else:  # If we don't have a pretrained embedding, initialize randomly.
+      sorted_emb[i, :] = np.random.normal(
+          loc=0.0, scale=GLOVE_STD, size=(GLOVE_DIM,))
+
+  return sorted_emb.astype(np.float32)
+
+
+def append_position_signal(embeddings, position_dim=8):
+  """Append position signal. See get_position_signal."""
+  batch_size, sequence_length, embedding_dim = embeddings.get_shape().as_list()
+  positions = get_position_signal(sequence_length, position_dim)
+
+  # Append to embeddings.
+  position_inputs = tf.tile(positions[None, :, :], [batch_size, 1, 1])
+  embeddings_pos = tf.concat([embeddings, position_inputs], axis=2)
+  embeddings_pos.shape.assert_is_compatible_with(
+      [batch_size, sequence_length, embedding_dim + position_dim])
+  return embeddings_pos
+
+
+def get_position_signal(sequence_length, position_dim=8):
+  """Return fixed position signal as sine waves.
+
+  Wave periods are geometrically spaced from 2 steps up to the full sequence
+  length. That way the slowest wave spans half a cosine period over the
+  whole sequence and is therefore monotonic over the sequence length.
+  Sine waves are also shifted so that they don't all start with the same
+  value.
+  We don't use learned positional embeddings because these embeddings are
+  projected linearly along with the original embeddings, and the projection is
+  learned.
+
+  Args:
+    sequence_length: int, T, length of the sequence.
+    position_dim: int, P, number of sine waves.
+
+  Returns:
+    A [T, P] tensor, position embeddings.
+  """
+  # Compute the frequencies.
+  periods = tf.exp(
+      tf.lin_space(
+          tf.log(2.0), tf.log(tf.to_float(sequence_length)), position_dim))
+  frequencies = 1.0 / periods  # Shape [P].
+
+  # Compute the sine waves.
+  xs = frequencies[None, :] * tf.to_float(tf.range(sequence_length)[:, None])
+  shifts = tf.lin_space(0.0, 2.0, position_dim)[None, :]  # [1, P]
+  positions = tf.math.cos(math.pi * (xs + shifts))  # [T, P]
+  positions.shape.assert_is_compatible_with([sequence_length, position_dim])
+  return positions
+
+
+def get_mask_by_length(lengths, max_length):
+  """Returns a mask where x[i, j] = (j < lengths[i]).
+
+  Args:
+    lengths: [B] tensor of int32 such that 0 <= lengths[i] <= max_length.
+    max_length: scalar tensor of int32.
+
+  Returns:
+    [B, max_length] tensor of booleans such that x[i, j] is True
+    if and only if j < lengths[i].
+ """ + batch_size = lengths.get_shape().as_list()[0] + indices = tf.range(start=0, limit=max_length) + all_indices = tf.tile(indices[None, :], [batch_size, 1]) + all_lengths = tf.tile(lengths[:, None], [1, max_length]) + mask = (all_indices < all_lengths) + mask_boolean = tf.cast(mask, tf.bool) + return mask_boolean + + +def get_mask_past_symbol(reference, symbol, optimize_for_tpu=False): + """For each row, mask is True before and at the first occurrence of symbol.""" + batch_size, max_length = reference.get_shape().as_list() + symbol = tf.convert_to_tensor(symbol) + symbol.shape.assert_is_compatible_with([]) + + first_indices = get_first_occurrence_indices(reference, symbol, + optimize_for_tpu) + first_indices.shape.assert_is_compatible_with([batch_size]) + + keep_lengths = tf.minimum(first_indices, max_length) + mask = get_mask_by_length(keep_lengths, max_length) + mask.shape.assert_is_compatible_with([batch_size, max_length]) + mask.set_shape([batch_size, max_length]) + return mask + + +def get_first_occurrence_indices(reference, symbol, optimize_for_tpu=False): + """For each row in reference, get index after the first occurrence of symbol. + + If symbol is not present on a row, return reference.shape[1] instead. + + Args: + reference: [B, T] tensor of elements of the same type as symbol. + symbol: int or [] scalar tensor of the same dtype as symbol. + optimize_for_tpu: bool, whether to use a TPU-capable variant. + + Returns: + A [B] reference of tf.int32 where x[i] is such that + reference[i, x[i]-1] == symbol, and reference[i, j] != symbol + for j= batch size since there can be + # several `symbol` in one row of tensor. We need to take only the position + # of the first occurrence for each row. `segment_min` does that, taking the + # lowest column index for each row index. 
+  index_first_occurrences = tf.segment_min(index_all_occurrences[:, 1],
+                                           index_all_occurrences[:, 0])
+  index_first_occurrences.set_shape([batch_size])
+  index_first_occurrences = tf.minimum(index_first_occurrences + 1, max_length)
+  return index_first_occurrences
+
+
+def sequence_to_sentence(sequence, id_to_word):
+  """Turn a sequence into a sentence, inverse of string_sequence_to_sequence."""
+  words = []
+  for token_index in sequence:
+    if token_index in id_to_word:
+      words.append(id_to_word[token_index])
+    else:
+      words.append(reader.UNK)
+  return " ".join(words)
+
+
+def batch_sequences_to_sentences(sequences, id_to_word):
+  return [sequence_to_sentence(sequence, id_to_word) for sequence in sequences]
+
+
+def write_eval_results(checkpoint_dir, all_gen_sentences, checkpoint_name,
+                       mean_train_prob, mean_valid_prob, mean_gen_prob, fid):
+  """Write evaluation results to disk."""
+  to_write = ",".join(
+      map(str, [
+          checkpoint_name, mean_train_prob, mean_valid_prob, mean_gen_prob, fid
+      ]))
+  eval_filepath = os.path.join(checkpoint_dir, EVAL_FILENAME)
+  previous_eval_content = ""
+  if gfile.exists(eval_filepath):
+    with gfile.GFile(eval_filepath, "r") as f:
+      previous_eval_content = f.read()
+  with gfile.GFile(eval_filepath, "w") as f:
+    f.write(previous_eval_content + to_write + "\n")
+
+  with gfile.GFile(
+      os.path.join(checkpoint_dir, checkpoint_name + "_sentences.txt"),
+      "w") as f:
+    f.write("\n".join(all_gen_sentences))
+
+
+def maybe_pick_models_to_evaluate(checkpoint_dir):
+  """Pick a checkpoint to evaluate that has not been evaluated already."""
+  logging.info("Picking checkpoint to evaluate from %s.", checkpoint_dir)
+
+  filenames = gfile.listdir(checkpoint_dir)
+  filenames = [f[:-5] for f in filenames if f[-5:] == ".meta"]
+  logging.info("Found existing checkpoints: %s", filenames)
+
+  evaluated_filenames = []
+  if gfile.exists(os.path.join(checkpoint_dir, EVAL_FILENAME)):
+    with gfile.GFile(os.path.join(checkpoint_dir, EVAL_FILENAME), "r") as f:
+      evaluated_filenames = [l.strip().split(",")[0] for l in f.readlines()]
+    logging.info("Found already evaluated checkpoints: %s",
+                 evaluated_filenames)
+
+  checkpoints_to_evaluate = [
+      f for f in filenames if f not in evaluated_filenames
+  ]
+  logging.info("Remaining potential checkpoints: %s", checkpoints_to_evaluate)
+
+  if checkpoints_to_evaluate:
+    return os.path.join(checkpoint_dir, checkpoints_to_evaluate[0])
+  else:
+    return None
+
+
+def get_embedding_path(data_dir, dataset):
+  """By convention, this is where we store the embedding."""
+  return os.path.join(data_dir, "glove_%s.txt" % dataset)
+
+
+def make_partially_trainable_embeddings(vocab_file, embedding_source,
+                                        vocab_size, trainable_embedding_size):
+  """Makes embedding matrix with pretrained GloVe [1] part and trainable part.
+
+  [1] Pennington, J., Socher, R., & Manning, C. (2014, October). Glove: Global
+  vectors for word representation. In Proceedings of the 2014 conference on
+  empirical methods in natural language processing (EMNLP) (pp. 1532-1543).
+
+  Args:
+    vocab_file: vocabulary file, one token per line, in id order.
+    embedding_source: path to the GloVe embedding file.
+    vocab_size: int, number of tokens in the vocabulary.
+    trainable_embedding_size: int, size of the trainable part.
+
+  Returns:
+    A matrix of partially pretrained embeddings.
+  """
+
+  # Our embeddings have 2 parts: a pre-trained, frozen, GloVe part,
+  # and a trainable, randomly initialized part.
+  # The standard deviation of the GloVe part is used to initialize
+  # the trainable part, so that both parts have roughly the same distribution.
+  #
+  # Let g_ij be the j-th coordinate of the GloVe embedding of the i-th word.
+  # So that 0 < i < |vocab| and 0 < j < 300.
+  # Then sqrt(mean_ij (g_ij - mean_kl g_kl)^2) = 0.3836 = GLOVE_STD.
+  #
+  # In reality g_ij follows a truncated normal distribution
+  # min(max(N(0, s), -4.2), 4.2) but we approximate it by N(0, 0.3836).
+  embedding_initializer = _get_embedding_initializer(
+      vocab_file=vocab_file,
+      embedding_source=embedding_source,
+      vocab_size=vocab_size)
+  pretrained_embedding = tf.get_variable(
+      "pretrained_embedding",
+      initializer=embedding_initializer,
+      dtype=tf.float32,
+      trainable=False)  # Frozen, as described above.
+  trainable_embedding = tf.get_variable(
+      "trainable_embedding",
+      shape=[vocab_size, trainable_embedding_size],
+      initializer=tf.initializers.random_normal(mean=0.0, stddev=GLOVE_STD))
+  # We just concatenate embeddings, they will pass through a projection
+  # matrix afterwards.
+  embedding = tf.concat([pretrained_embedding, trainable_embedding], axis=1)
+  return embedding
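+
+
+# Example usage (a sketch; `/tmp/vocab.txt` is the temporary vocab file
+# written by `experiment.py`, one token per line in id order):
+#
+#   embeddings = make_partially_trainable_embeddings(
+#       vocab_file="/tmp/vocab.txt",
+#       embedding_source=get_embedding_path("/tmp/emnlp2017", "emnlp2017"),
+#       vocab_size=vocab_size,
+#       trainable_embedding_size=64)
+#   # embeddings: [vocab_size, 300 + 64]; the first 300 dims are frozen GloVe.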