Use bibtex format.

PiperOrigin-RevId: 281754451
Author: Cyprien de Masson d'Autume
Date: 2019-11-21 16:27:49 +00:00
Committed by: Diego de Las Casas
Parent: db2393d5a9
Commit: d705acd6a7
10 changed files with 1442 additions and 0 deletions
+61
@@ -0,0 +1,61 @@
# ScratchGAN
This is the example code for the following NeurIPS 2019 paper. If you use the
code here, please cite this paper:
```
@article{DBLP:journals/corr/abs-1905-09922,
  author    = {Cyprien de Masson d'Autume and
               Mihaela Rosca and
               Jack W. Rae and
               Shakir Mohamed},
  title     = {Training language GANs from Scratch},
  journal   = {CoRR},
  volume    = {abs/1905.09922},
  year      = {2019},
  url       = {http://arxiv.org/abs/1905.09922},
  archivePrefix = {arXiv},
  eprint    = {1905.09922},
  timestamp = {Wed, 29 May 2019 11:27:50 +0200},
  biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1905-09922},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
## Contents
The code contains:
* `generators.py`: implementation of the generator.
* `discriminator_nets.py`: implementation of the discriminator.
* `eval_metrics.py`: implementation of the FED metric.
* `losses.py`: implementation of the RL loss for the generator.
* `reader.py`: data reader / tokenizer.
* `experiment.py`: main training script.
The data contains:
* `{train,valid,test}.json`: the EMNLP2017 News dataset.
* `glove_emnlp2017.txt`: the relevant subset of GloVe embeddings.
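
For reference, `reader.py` expects each of these JSON files to contain a list of entries with a sentence field `"s"` and a list of associated titles `"t"`. A minimal illustrative entry (the field values below are placeholders, not real data):

```python
# Hypothetical entry format, inferred from the parsing code in reader.py.
example_entry = {"s": "a space separated sentence .", "t": ["an associated title"]}
```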
## Running
Place the data files in the directory specified by the `data_dir` flag.
Create and activate a virtual environment if needed:
```
virtualenv scratchgan-venv
source scratchgan-venv/bin/activate
```
Install requirements:
```
pip install -r scratchgan/requirements.txt
```
Run training and evaluation jobs:
```
python2 -m scratchgan.experiment --mode="train" &
python2 -m scratchgan.experiment --mode="evaluate_pair" &
```
The evaluation code is designed to run in parallel with training: the training
job saves checkpoints periodically, and the evaluation job watches for new
checkpoints and evaluates them.
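
The FED metric in `eval_metrics.py` can also be computed directly on two lists of sentences. Below is a minimal sketch, assuming the `scratchgan` package is on the Python path and TensorFlow Hub can download the universal sentence encoder; the sentences are placeholders:

```python
from scratchgan import eval_metrics

# Two lists of N plain-text sentences each (placeholders shown here).
real_sentences = ["a sentence from the validation data .", "another real sentence ."]
generated_sentences = ["a sentence sampled from the generator .", "another sample ."]

# Returns the Frechet distance between universal-sentence-encoder embeddings.
distance = eval_metrics.fid(
    generated_sentences=generated_sentences, real_sentences=real_sentences)
print("Frechet embedding distance:", distance)
```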
+13
@@ -0,0 +1,13 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+121
@@ -0,0 +1,121 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discriminator networks for text data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sonnet as snt
import tensorflow as tf
from scratchgan import utils
class LSTMEmbedDiscNet(snt.AbstractModule):
"""An LSTM discriminator that operates on word indexes."""
def __init__(self,
feature_sizes,
vocab_size,
use_layer_norm,
trainable_embedding_size,
dropout,
pad_token,
embedding_source=None,
vocab_file=None,
name='LSTMEmbedDiscNet'):
super(LSTMEmbedDiscNet, self).__init__(name=name)
self._feature_sizes = feature_sizes
self._vocab_size = vocab_size
self._use_layer_norm = use_layer_norm
self._trainable_embedding_size = trainable_embedding_size
self._embedding_source = embedding_source
self._vocab_file = vocab_file
self._dropout = dropout
self._pad_token = pad_token
if self._embedding_source:
assert vocab_file
def _build(self, sequence, sequence_length, is_training=True):
"""Connect to the graph.
Args:
sequence: A [batch_size, max_sequence_length] tensor of int. For example
the indices of words as sampled by the generator.
sequence_length: A [batch_size] tensor of int. Length of the sequence.
is_training: Boolean, False to disable dropout.
    Returns:
      A [batch_size, max_sequence_length] tensor of floats. For each sequence
      in the batch, these per-timestep logits should (hopefully) allow to
      distinguish whether the value at each timestep is real or generated.
"""
batch_size, max_sequence_length = sequence.shape.as_list()
keep_prob = (1.0 - self._dropout) if is_training else 1.0
if self._embedding_source:
all_embeddings = utils.make_partially_trainable_embeddings(
self._vocab_file, self._embedding_source, self._vocab_size,
self._trainable_embedding_size)
else:
all_embeddings = tf.get_variable(
'trainable_embedding',
shape=[self._vocab_size, self._trainable_embedding_size],
trainable=True)
_, self._embedding_size = all_embeddings.shape.as_list()
input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=keep_prob)
embeddings = tf.nn.embedding_lookup(input_embeddings, sequence)
embeddings.shape.assert_is_compatible_with(
[batch_size, max_sequence_length, self._embedding_size])
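    # Append a fixed sinusoidal position signal so that the per-timestep
    # features also carry position information (see utils.get_position_signal).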
position_dim = 8
embeddings_pos = utils.append_position_signal(embeddings, position_dim)
embeddings_pos = tf.reshape(
embeddings_pos,
[batch_size * max_sequence_length, self._embedding_size + position_dim])
lstm_inputs = snt.Linear(self._feature_sizes[0])(embeddings_pos)
lstm_inputs = tf.reshape(
lstm_inputs, [batch_size, max_sequence_length, self._feature_sizes[0]])
lstm_inputs.shape.assert_is_compatible_with(
[batch_size, max_sequence_length, self._feature_sizes[0]])
encoder_cells = []
for feature_size in self._feature_sizes:
encoder_cells += [
snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
]
encoder_cell = snt.DeepRNN(encoder_cells)
initial_state = encoder_cell.initial_state(batch_size)
hidden_states, _ = tf.nn.dynamic_rnn(
cell=encoder_cell,
inputs=lstm_inputs,
sequence_length=sequence_length,
initial_state=initial_state,
swap_memory=True)
hidden_states.shape.assert_is_compatible_with(
[batch_size, max_sequence_length,
sum(self._feature_sizes)])
logits = snt.BatchApply(snt.Linear(1))(hidden_states)
logits.shape.assert_is_compatible_with([batch_size, max_sequence_length, 1])
logits_flat = tf.reshape(logits, [batch_size, max_sequence_length])
# Mask past first PAD symbol
#
    # Note that we still rely on tf.nn.dynamic_rnn taking into account the
    # sequence_length properly, because otherwise the logits at a given
    # timestep will depend on the inputs for all other timesteps, including
    # the ones that should be masked.
mask = utils.get_mask_past_symbol(sequence, self._pad_token)
masked_logits_flat = logits_flat * tf.cast(mask, tf.float32)
return masked_logits_flat
+48
@@ -0,0 +1,48 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub
def fid(generated_sentences, real_sentences):
"""Compute FID rn sentences using pretrained universal sentence encoder.
Args:
generated_sentences: list of N strings.
real_sentences: list of N strings.
Returns:
Frechet distance between activations.
"""
embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
real_embed = embed(real_sentences)
generated_embed = embed(generated_sentences)
distance = tfgan.eval.frechet_classifier_distance_from_activations(
real_embed, generated_embed)
# Restrict the thread pool size to prevent excessive CPU usage.
config = tf.ConfigProto()
config.intra_op_parallelism_threads = 16
config.inter_op_parallelism_threads = 16
with tf.Session(config=config) as session:
session.run(tf.global_variables_initializer())
session.run(tf.tables_initializer())
distance_np = session.run(distance)
return distance_np
+469
@@ -0,0 +1,469 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training script for ScratchGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
from tensorflow.io import gfile
from scratchgan import discriminator_nets
from scratchgan import eval_metrics
from scratchgan import generators
from scratchgan import losses
from scratchgan import reader
from scratchgan import utils
flags.DEFINE_string("dataset", "emnlp2017", "Dataset.")
flags.DEFINE_integer("batch_size", 512, "Batch size")
flags.DEFINE_string("gen_type", "lstm", "Generator type.")
flags.DEFINE_string("disc_type", "lstm", "Discriminator type.")
flags.DEFINE_string("disc_loss_type", "ce", "Loss type.")
flags.DEFINE_integer("gen_feature_size", 512, "Generator feature size.")
flags.DEFINE_integer("disc_feature_size", 512, "Discriminator feature size.")
flags.DEFINE_integer("num_layers_gen", 2, "Number of generator layers.")
flags.DEFINE_integer("num_layers_disc", 1, "Number of discriminator layers.")
flags.DEFINE_bool("layer_norm_gen", False, "Layer norm generator.")
flags.DEFINE_bool("layer_norm_disc", True, "Layer norm discriminator.")
flags.DEFINE_float("gen_input_dropout", 0.0, "Input dropout generator.")
flags.DEFINE_float("gen_output_dropout", 0.0, "Input dropout discriminator.")
flags.DEFINE_float("l2_gen", 0.0, "L2 regularization generator.")
flags.DEFINE_float("l2_disc", 1e-6, "L2 regularization discriminator.")
flags.DEFINE_float("disc_dropout", 0.1, "Dropout discriminator")
flags.DEFINE_integer("trainable_embedding_size", 64,
"Size of trainable embedding.")
flags.DEFINE_bool("use_pretrained_embedding", True, "Use pretrained embedding.")
flags.DEFINE_integer("num_steps", int(200 * 1000), "Number of training steps.")
flags.DEFINE_integer("num_disc_updates", 1, "Number of discriminator updates.")
flags.DEFINE_integer("num_gen_updates", 1, "Number of generator updates.")
flags.DEFINE_string("data_dir", "/tmp/emnlp2017", "Directory where data is.")
flags.DEFINE_float("gen_lr", 9.59e-5, "Learning rate generator.")
flags.DEFINE_float("disc_lr", 9.38e-3, "Learning rate discriminator.")
flags.DEFINE_float("gen_beta1", 0.5, "Beta1 for generator.")
flags.DEFINE_float("disc_beta1", 0.5, "Beta1 for discriminator.")
flags.DEFINE_float("gamma", 0.23, "Discount factor.")
flags.DEFINE_float("baseline_decay", 0.08, "Baseline decay rate.")
flags.DEFINE_string("mode", "train", "train or evaluate_pair.")
flags.DEFINE_string("checkpoint_dir", "/tmp/emnlp2017/checkpoints/",
"Directory for checkpoints.")
flags.DEFINE_integer("export_every", 1000, "Frequency of checkpoint exports.")
flags.DEFINE_integer("num_examples_for_eval", int(1e4),
"Number of examples for evaluation")
EVALUATOR_SLEEP_PERIOD = 60 # Seconds evaluator sleeps if nothing to do.
def main(_):
config = flags.FLAGS
gfile.makedirs(config.checkpoint_dir)
if config.mode == "train":
train(config)
elif config.mode == "evaluate_pair":
while True:
checkpoint_path = utils.maybe_pick_models_to_evaluate(
checkpoint_dir=config.checkpoint_dir)
if checkpoint_path:
evaluate_pair(
config=config,
batch_size=config.batch_size,
checkpoint_path=checkpoint_path,
data_dir=config.data_dir,
dataset=config.dataset,
num_examples_for_eval=config.num_examples_for_eval)
else:
logging.info("No models to evaluate found, sleeping for %d seconds",
EVALUATOR_SLEEP_PERIOD)
time.sleep(EVALUATOR_SLEEP_PERIOD)
else:
raise Exception(
"Unexpected mode %s, supported modes are \"train\" or \"evaluate_pair\""
% (config.mode))
def train(config):
"""Train."""
logging.info("Training.")
tf.reset_default_graph()
np.set_printoptions(precision=4)
# Get data.
raw_data = reader.get_raw_data(
data_path=config.data_dir, dataset=config.dataset)
train_data, valid_data, word_to_id = raw_data
id_to_word = {v: k for k, v in word_to_id.iteritems()}
vocab_size = len(word_to_id)
max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset]
logging.info("Vocabulary size: %d", vocab_size)
iterator = reader.iterator(raw_data=train_data, batch_size=config.batch_size)
iterator_valid = reader.iterator(
raw_data=valid_data, batch_size=config.batch_size)
real_sequence = tf.placeholder(
dtype=tf.int32,
shape=[config.batch_size, max_length],
name="real_sequence")
real_sequence_length = tf.placeholder(
dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length")
first_batch_np = iterator.next()
valid_batch_np = iterator_valid.next()
test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()}
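  # A batch of uniformly random tokens and lengths, used only to monitor the
  # discriminator's predictions on clearly fake data.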
test_fake_batch = {
"sequence":
tf.constant(
np.random.choice(
vocab_size, size=[config.batch_size,
max_length]).astype(np.int32)),
"sequence_length":
tf.constant(
np.random.choice(max_length,
size=[config.batch_size]).astype(np.int32)),
}
valid_batch = {k: tf.constant(v) for k, v in valid_batch_np.items()}
# Create generator.
if config.use_pretrained_embedding:
embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
vocab_file = "/tmp/vocab.txt"
with gfile.GFile(vocab_file, "w") as f:
for i in xrange(len(id_to_word)):
f.write(id_to_word[i] + "\n")
logging.info("Temporary vocab file: %s", vocab_file)
else:
embedding_source = None
vocab_file = None
gen = generators.LSTMGen(
vocab_size=vocab_size,
feature_sizes=[config.gen_feature_size] * config.num_layers_gen,
max_sequence_length=reader.MAX_TOKENS_SEQUENCE[config.dataset],
batch_size=config.batch_size,
use_layer_norm=config.layer_norm_gen,
trainable_embedding_size=config.trainable_embedding_size,
input_dropout=config.gen_input_dropout,
output_dropout=config.gen_output_dropout,
pad_token=reader.PAD_INT,
embedding_source=embedding_source,
vocab_file=vocab_file,
)
gen_outputs = gen()
# Create discriminator.
disc = discriminator_nets.LSTMEmbedDiscNet(
vocab_size=vocab_size,
feature_sizes=[config.disc_feature_size] * config.num_layers_disc,
trainable_embedding_size=config.trainable_embedding_size,
embedding_source=embedding_source,
use_layer_norm=config.layer_norm_disc,
pad_token=reader.PAD_INT,
vocab_file=vocab_file,
dropout=config.disc_dropout,
)
disc_logits_real = disc(
sequence=real_sequence, sequence_length=real_sequence_length)
disc_logits_fake = disc(
sequence=gen_outputs["sequence"],
sequence_length=gen_outputs["sequence_length"])
# Loss of the discriminator.
if config.disc_loss_type == "ce":
targets_real = tf.ones(
[config.batch_size, reader.MAX_TOKENS_SEQUENCE[config.dataset]])
targets_fake = tf.zeros(
[config.batch_size, reader.MAX_TOKENS_SEQUENCE[config.dataset]])
loss_real = losses.sequential_cross_entropy_loss(disc_logits_real,
targets_real)
loss_fake = losses.sequential_cross_entropy_loss(disc_logits_fake,
targets_fake)
disc_loss = 0.5 * loss_real + 0.5 * loss_fake
# Loss of the generator.
gen_loss, cumulative_rewards, baseline = losses.reinforce_loss(
disc_logits=disc_logits_fake,
gen_logprobs=gen_outputs["logprobs"],
gamma=config.gamma,
decay=config.baseline_decay)
# Optimizers
disc_optimizer = tf.train.AdamOptimizer(
learning_rate=config.disc_lr, beta1=config.disc_beta1)
gen_optimizer = tf.train.AdamOptimizer(
learning_rate=config.gen_lr, beta1=config.gen_beta1)
# Get losses and variables.
disc_vars = disc.get_all_variables()
gen_vars = gen.get_all_variables()
l2_disc = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in disc_vars]))
l2_gen = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in gen_vars]))
scalar_disc_loss = tf.reduce_mean(disc_loss) + config.l2_disc * l2_disc
scalar_gen_loss = tf.reduce_mean(gen_loss) + config.l2_gen * l2_gen
# Update ops.
global_step = tf.train.get_or_create_global_step()
disc_update = disc_optimizer.minimize(
scalar_disc_loss, var_list=disc_vars, global_step=global_step)
gen_update = gen_optimizer.minimize(
scalar_gen_loss, var_list=gen_vars, global_step=global_step)
# Saver.
saver = tf.train.Saver()
# Metrics
test_disc_logits_real = disc(**test_real_batch)
test_disc_logits_fake = disc(**test_fake_batch)
valid_disc_logits = disc(**valid_batch)
disc_predictions_real = tf.nn.sigmoid(disc_logits_real)
disc_predictions_fake = tf.nn.sigmoid(disc_logits_fake)
valid_disc_predictions = tf.reduce_mean(
tf.nn.sigmoid(valid_disc_logits), axis=0)
test_disc_predictions_real = tf.reduce_mean(
tf.nn.sigmoid(test_disc_logits_real), axis=0)
test_disc_predictions_fake = tf.reduce_mean(
tf.nn.sigmoid(test_disc_logits_fake), axis=0)
# Only log results for the first element of the batch.
metrics = {
"scalar_gen_loss": scalar_gen_loss,
"scalar_disc_loss": scalar_disc_loss,
"disc_predictions_real": tf.reduce_mean(disc_predictions_real),
"disc_predictions_fake": tf.reduce_mean(disc_predictions_fake),
"test_disc_predictions_real": tf.reduce_mean(test_disc_predictions_real),
"test_disc_predictions_fake": tf.reduce_mean(test_disc_predictions_fake),
"valid_disc_predictions": tf.reduce_mean(valid_disc_predictions),
"cumulative_rewards": tf.reduce_mean(cumulative_rewards),
"baseline": tf.reduce_mean(baseline),
}
# Training.
logging.info("Starting training")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
latest_ckpt = tf.train.latest_checkpoint(config.checkpoint_dir)
if latest_ckpt:
saver.restore(sess, latest_ckpt)
for step in xrange(config.num_steps):
real_data_np = iterator.next()
train_feed = {
real_sequence: real_data_np["sequence"],
real_sequence_length: real_data_np["sequence_length"],
}
# Update generator and discriminator.
for _ in xrange(config.num_disc_updates):
sess.run(disc_update, feed_dict=train_feed)
for _ in xrange(config.num_gen_updates):
sess.run(gen_update, feed_dict=train_feed)
# Reporting
if step % config.export_every == 0:
gen_sequence_np, metrics_np = sess.run(
[gen_outputs["sequence"], metrics], feed_dict=train_feed)
metrics_np["gen_sentence"] = utils.sequence_to_sentence(
gen_sequence_np[0, :], id_to_word)
saver.save(
sess,
save_path=config.checkpoint_dir + "scratchgan",
global_step=global_step)
metrics_np["model_path"] = tf.train.latest_checkpoint(
config.checkpoint_dir)
logging.info(metrics_np)
# After training, export models.
saver.save(
sess,
save_path=config.checkpoint_dir + "scratchgan",
global_step=global_step)
logging.info("Saved final model at %s.",
tf.train.latest_checkpoint(config.checkpoint_dir))
def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
num_examples_for_eval):
"""Evaluates a pair generator discriminator.
This function loads a discriminator from disk, a generator, and evaluates the
discriminator against the generator.
It returns the mean probability of the discriminator against several batches,
and the FID of the generator against the validation data.
It also writes evaluation samples to disk.
Args:
config: dict, the config file.
batch_size: int, size of the batch.
checkpoint_path: string, full path to the TF checkpoint on disk.
data_dir: string, path to a directory containing the dataset.
dataset: string, "emnlp2017", to select the right dataset.
num_examples_for_eval: int, number of examples for evaluation.
"""
tf.reset_default_graph()
logging.info("Evaluating checkpoint %s.", checkpoint_path)
# Build graph.
train_data, valid_data, word_to_id = reader.get_raw_data(
data_dir, dataset=dataset)
id_to_word = {v: k for k, v in word_to_id.iteritems()}
vocab_size = len(word_to_id)
train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size)
valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size)
train_sequence = tf.placeholder(
dtype=tf.int32,
shape=[batch_size, reader.MAX_TOKENS_SEQUENCE[dataset]],
name="train_sequence")
train_sequence_length = tf.placeholder(
dtype=tf.int32, shape=[batch_size], name="train_sequence_length")
valid_sequence = tf.placeholder(
dtype=tf.int32,
shape=[batch_size, reader.MAX_TOKENS_SEQUENCE[dataset]],
name="valid_sequence")
valid_sequence_length = tf.placeholder(
dtype=tf.int32, shape=[batch_size], name="valid_sequence_length")
disc_inputs_train = {
"sequence": train_sequence,
"sequence_length": train_sequence_length,
}
disc_inputs_valid = {
"sequence": valid_sequence,
"sequence_length": valid_sequence_length,
}
if config.use_pretrained_embedding:
embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
vocab_file = "/tmp/vocab.txt"
with gfile.GFile(vocab_file, "w") as f:
for i in xrange(len(id_to_word)):
f.write(id_to_word[i] + "\n")
logging.info("Temporary vocab file: %s", vocab_file)
else:
embedding_source = None
vocab_file = None
gen = generators.LSTMGen(
vocab_size=vocab_size,
feature_sizes=[config.gen_feature_size] * config.num_layers_gen,
max_sequence_length=reader.MAX_TOKENS_SEQUENCE[config.dataset],
batch_size=config.batch_size,
use_layer_norm=config.layer_norm_gen,
trainable_embedding_size=config.trainable_embedding_size,
input_dropout=config.gen_input_dropout,
output_dropout=config.gen_output_dropout,
pad_token=reader.PAD_INT,
embedding_source=embedding_source,
vocab_file=vocab_file,
)
gen_outputs = gen()
disc = discriminator_nets.LSTMEmbedDiscNet(
vocab_size=vocab_size,
feature_sizes=[config.disc_feature_size] * config.num_layers_disc,
trainable_embedding_size=config.trainable_embedding_size,
embedding_source=embedding_source,
use_layer_norm=config.layer_norm_disc,
pad_token=reader.PAD_INT,
vocab_file=vocab_file,
dropout=config.disc_dropout,
)
disc_inputs = {
"sequence": gen_outputs["sequence"],
"sequence_length": gen_outputs["sequence_length"],
}
gen_logits = disc(**disc_inputs)
train_logits = disc(**disc_inputs_train)
valid_logits = disc(**disc_inputs_valid)
# Saver.
saver = tf.train.Saver()
# Reduce over time and batch.
train_probs = tf.reduce_mean(tf.nn.sigmoid(train_logits))
valid_probs = tf.reduce_mean(tf.nn.sigmoid(valid_logits))
gen_probs = tf.reduce_mean(tf.nn.sigmoid(gen_logits))
outputs = {
"train_probs": train_probs,
"valid_probs": valid_probs,
"gen_probs": gen_probs,
"gen_sequences": gen_outputs["sequence"],
"valid_sequences": valid_sequence
}
# Get average discriminator score and store generated sequences.
all_valid_sentences = []
all_gen_sentences = []
all_gen_sequences = []
mean_train_prob = 0.0
mean_valid_prob = 0.0
mean_gen_prob = 0.0
logging.info("Graph constructed, generating batches.")
num_batches = num_examples_for_eval // batch_size + 1
  # Restrict the thread pool size to prevent excessive CPU usage.
tf_config = tf.ConfigProto()
tf_config.intra_op_parallelism_threads = 16
tf_config.inter_op_parallelism_threads = 16
with tf.Session(config=tf_config) as sess:
# Restore variables from checkpoints.
logging.info("Restoring variables.")
saver.restore(sess, checkpoint_path)
for i in xrange(num_batches):
logging.info("Batch %d / %d", i, num_batches)
train_data_np = train_iterator.next()
valid_data_np = valid_iterator.next()
feed_dict = {
train_sequence: train_data_np["sequence"],
train_sequence_length: train_data_np["sequence_length"],
valid_sequence: valid_data_np["sequence"],
valid_sequence_length: valid_data_np["sequence_length"],
}
outputs_np = sess.run(outputs, feed_dict=feed_dict)
all_gen_sequences.extend(outputs_np["gen_sequences"])
gen_sentences = utils.batch_sequences_to_sentences(
outputs_np["gen_sequences"], id_to_word)
valid_sentences = utils.batch_sequences_to_sentences(
outputs_np["valid_sequences"], id_to_word)
all_valid_sentences.extend(valid_sentences)
all_gen_sentences.extend(gen_sentences)
mean_train_prob += outputs_np["train_probs"] / batch_size
mean_valid_prob += outputs_np["valid_probs"] / batch_size
mean_gen_prob += outputs_np["gen_probs"] / batch_size
logging.info("Evaluating FID.")
# Compute FID
fid = eval_metrics.fid(
generated_sentences=all_gen_sentences[:num_examples_for_eval],
real_sentences=all_valid_sentences[:num_examples_for_eval])
utils.write_eval_results(config.checkpoint_dir, all_gen_sentences,
os.path.basename(checkpoint_path), mean_train_prob,
mean_valid_prob, mean_gen_prob, fid)
if __name__ == "__main__":
app.run(main)
+147
@@ -0,0 +1,147 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generators for text data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp
from scratchgan import utils
class LSTMGen(snt.AbstractModule):
"""A multi-layer LSTM language model.
Uses tied input/output embedding weights.
"""
def __init__(self,
vocab_size,
feature_sizes,
max_sequence_length,
batch_size,
use_layer_norm,
trainable_embedding_size,
input_dropout,
output_dropout,
pad_token,
embedding_source=None,
vocab_file=None,
name='lstm_gen'):
super(LSTMGen, self).__init__(name=name)
self._feature_sizes = feature_sizes
self._max_sequence_length = max_sequence_length
self._vocab_size = vocab_size
self._batch_size = batch_size
self._use_layer_norm = use_layer_norm
self._trainable_embedding_size = trainable_embedding_size
self._embedding_source = embedding_source
self._vocab_file = vocab_file
self._input_dropout = input_dropout
self._output_dropout = output_dropout
self._pad_token = pad_token
if self._embedding_source:
assert vocab_file
def _build(self, is_training=True, temperature=1.0):
input_keep_prob = (1. - self._input_dropout) if is_training else 1.0
output_keep_prob = (1. - self._output_dropout) if is_training else 1.0
batch_size = self._batch_size
max_sequence_length = self._max_sequence_length
if self._embedding_source:
all_embeddings = utils.make_partially_trainable_embeddings(
self._vocab_file, self._embedding_source, self._vocab_size,
self._trainable_embedding_size)
else:
all_embeddings = tf.get_variable(
'trainable_embeddings',
shape=[self._vocab_size, self._trainable_embedding_size],
trainable=True)
_, self._embedding_size = all_embeddings.shape.as_list()
input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=input_keep_prob)
output_embeddings = tf.nn.dropout(
all_embeddings, keep_prob=output_keep_prob)
out_bias = tf.get_variable(
'out_bias', shape=[1, self._vocab_size], dtype=tf.float32)
in_proj = tf.get_variable(
'in_proj', shape=[self._embedding_size, self._feature_sizes[0]])
# If more than 1 layer, then output has dim sum(self._feature_sizes),
# which is different from input dim == self._feature_sizes[0]
# So we need a different projection matrix for input and output.
if len(self._feature_sizes) > 1:
out_proj = tf.get_variable(
'out_proj', shape=[self._embedding_size,
sum(self._feature_sizes)])
else:
out_proj = in_proj
encoder_cells = []
for feature_size in self._feature_sizes:
encoder_cells += [
snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
]
encoder_cell = snt.DeepRNN(encoder_cells)
state = encoder_cell.initial_state(batch_size)
# Manual unrolling.
samples_list, logits_list, logprobs_list, embeddings_list = [], [], [], []
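    # Generation is unconditional: the first input token fed to the LSTM is PAD.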
sample = tf.tile(
tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size])
logging.info('Unrolling over %d steps.', max_sequence_length)
for _ in xrange(max_sequence_length):
# Input is sampled word at t-1.
embedding = tf.nn.embedding_lookup(input_embeddings, sample)
embedding.shape.assert_is_compatible_with(
[batch_size, self._embedding_size])
embedding_proj = tf.matmul(embedding, in_proj)
embedding_proj.shape.assert_is_compatible_with(
[batch_size, self._feature_sizes[0]])
outputs, state = encoder_cell(embedding_proj, state)
outputs_proj = tf.matmul(outputs, out_proj, transpose_b=True)
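      # Tied output embeddings: logits are dot products between the projected
      # LSTM output and the (dropout-masked) embedding matrix, plus a bias.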
logits = tf.matmul(
outputs_proj, output_embeddings, transpose_b=True) + out_bias
categorical = tfp.distributions.Categorical(logits=logits/temperature)
sample = categorical.sample()
logprobs = categorical.log_prob(sample)
samples_list.append(sample)
logits_list.append(logits)
logprobs_list.append(logprobs)
embeddings_list.append(embedding)
# Create an op to retrieve embeddings for full sequence, useful for testing.
embeddings = tf.stack( # pylint: disable=unused-variable
embeddings_list,
axis=1,
name='embeddings')
sequence = tf.stack(samples_list, axis=1)
logprobs = tf.stack(logprobs_list, axis=1)
# The sequence stops after the first occurrence of a PAD token.
sequence_length = utils.get_first_occurrence_indices(
sequence, self._pad_token)
mask = utils.get_mask_past_symbol(sequence, self._pad_token)
masked_sequence = sequence * tf.cast(mask, tf.int32)
masked_logprobs = logprobs * tf.cast(mask, tf.float32)
return {
'sequence': masked_sequence,
'sequence_length': sequence_length,
'logprobs': masked_logprobs
}
+93
@@ -0,0 +1,93 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses for sequential GANs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def sequential_cross_entropy_loss(logits, expected):
"""The cross entropy loss for binary classification.
Used to train the discriminator when not using WGAN loss.
  Assumes `logits` are the logits for classifying each timestep as 1 (real).
Args:
logits: a `tf.Tensor`, the model produced logits, shape [batch_size,
sequence_length].
expected: a `tf.Tensor`, the expected output, shape [batch_size,
sequence_length].
Returns:
    A [batch_size, sequence_length] `tf.Tensor`, the per-timestep cross-entropy
    loss on the given inputs.
"""
batch_size, sequence_length = logits.shape.as_list()
expected = tf.cast(expected, tf.float32)
ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=expected, logits=logits)
return tf.reshape(ce, [batch_size, sequence_length])
def reinforce_loss(disc_logits, gen_logprobs, gamma, decay):
"""The REINFORCE loss.
Args:
disc_logits: float tensor, shape [batch_size, sequence_length].
gen_logprobs: float32 tensor, shape [batch_size, sequence_length]
gamma: a float, discount factor for cumulative reward.
decay: a float, decay rate for the EWMA baseline of REINFORCE.
Returns:
Float tensor, shape [batch_size, sequence_length], the REINFORCE loss for
each timestep.
"""
# Assume 1 logit for each timestep.
batch_size, sequence_length = disc_logits.shape.as_list()
gen_logprobs.shape.assert_is_compatible_with([batch_size, sequence_length])
disc_predictions = tf.nn.sigmoid(disc_logits)
# MaskGAN uses log(D), but this is more stable empirically.
rewards = 2.0 * disc_predictions - 1
# Compute cumulative rewards.
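  # cumulative_rewards[:, t] = sum_{s >= t} gamma^(s - t) * rewards[:, s]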
rewards_list = tf.unstack(rewards, axis=1)
cumulative_rewards = []
for t in xrange(sequence_length):
cum_value = tf.zeros(shape=[batch_size])
for s in xrange(t, sequence_length):
cum_value += np.power(gamma, (s - t)) * rewards_list[s]
cumulative_rewards.append(cum_value)
cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
cumulative_rewards.shape.assert_is_compatible_with(
[batch_size, sequence_length])
with tf.variable_scope("reinforce", reuse=tf.AUTO_REUSE):
ewma_reward = tf.get_variable("ewma_reward", initializer=0.0)
mean_reward = tf.reduce_mean(cumulative_rewards)
new_ewma_reward = decay * ewma_reward + (1.0 - decay) * mean_reward
update_op = tf.assign(ewma_reward, new_ewma_reward)
# REINFORCE
with tf.control_dependencies([update_op]):
advantage = cumulative_rewards - ewma_reward
loss = -tf.stop_gradient(advantage) * gen_logprobs
loss.shape.assert_is_compatible_with([batch_size, sequence_length])
return loss, cumulative_rewards, ewma_reward
+174
@@ -0,0 +1,174 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for parsing text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import os
from absl import logging
import numpy as np
from tensorflow.io import gfile
# sequences: [N, MAX_TOKENS_SEQUENCE] array of int32
# lengths: [N] array of int32, such that lengths[i] is the number of tokens
# in sequences[i, :], including the trailing PAD
# TODO(cyprien): handle PTB
FILENAMES = {
"emnlp2017": ("train.json", "valid.json", "test.json"),
}
# EMNLP2017 sentences have max length 50, add one for a PAD token so that all
# sentences end with PAD.
MAX_TOKENS_SEQUENCE = {"emnlp2017": 52}
UNK = "<unk>"
PAD = " "
PAD_INT = 0
def tokenize(sentence):
"""Split a string into words."""
return sentence.split(" ") + [PAD]
def _build_vocab(json_data):
"""Builds full vocab from json data."""
vocab = collections.Counter()
for sentence in json_data:
tokens = tokenize(sentence["s"])
vocab.update(tokens)
for title in sentence["t"]:
title_tokens = tokenize(title)
vocab.update(title_tokens)
# Most common words first.
count_pairs = sorted(vocab.items(), key=lambda x: (-x[1], x[0]))
words, _ = zip(*count_pairs)
words = list(words)
if UNK not in words:
words = [UNK] + words
word_to_id = dict(zip(words, range(len(words))))
# Tokens are now sorted by frequency. There's no guarantee that `PAD` will
# end up at `PAD_INT` index. Enforce it by swapping whatever token is
# currently at the `PAD_INT` index with the `PAD` token.
word = word_to_id.keys()[word_to_id.values().index(PAD_INT)]
word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
assert word_to_id[PAD] == PAD_INT
return word_to_id
def string_sequence_to_sequence(string_sequence, word_to_id):
result = []
for word in string_sequence:
if word in word_to_id:
result.append(word_to_id[word])
else:
result.append(word_to_id[UNK])
return result
def _integerize(json_data, word_to_id, dataset):
"""Transform words into integers."""
sequences = np.full((len(json_data), MAX_TOKENS_SEQUENCE[dataset]),
word_to_id[PAD], np.int32)
sequence_lengths = np.zeros(shape=(len(json_data)), dtype=np.int32)
for i, sentence in enumerate(json_data):
sequence_i = string_sequence_to_sequence(
tokenize(sentence["s"]), word_to_id)
sequence_lengths[i] = len(sequence_i)
sequences[i, :sequence_lengths[i]] = np.array(sequence_i)
return {
"sequences": sequences,
"sequence_lengths": sequence_lengths,
}
def get_raw_data(data_path, dataset, truncate_vocab=20000):
"""Load raw data from data directory "data_path".
Reads text files, converts strings to integer ids,
and performs mini-batching of the inputs.
Args:
    data_path: string path to the directory containing the dataset files
      (e.g. train.json and valid.json).
dataset: one of ["emnlp2017"]
truncate_vocab: int, number of words to keep in the vocabulary.
Returns:
tuple (train_data, valid_data, vocabulary) where each of the data
objects can be passed to iterator.
Raises:
ValueError: dataset not in ["emnlp2017"].
"""
if dataset not in FILENAMES:
raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
dataset, FILENAMES.keys()))
train_file, valid_file, _ = FILENAMES[dataset]
train_path = os.path.join(data_path, train_file)
valid_path = os.path.join(data_path, valid_file)
with gfile.GFile(train_path, "r") as json_file:
json_data_train = json.load(json_file)
with gfile.GFile(valid_path, "r") as json_file:
json_data_valid = json.load(json_file)
word_to_id = _build_vocab(json_data_train)
logging.info("Full vocab length: %d", len(word_to_id))
# Assume the vocab is sorted by frequency.
word_to_id_truncated = {
k: v for k, v in word_to_id.items() if v < truncate_vocab
}
logging.info("Truncated vocab length: %d", len(word_to_id_truncated))
train_data = _integerize(json_data_train, word_to_id_truncated, dataset)
valid_data = _integerize(json_data_valid, word_to_id_truncated, dataset)
return train_data, valid_data, word_to_id_truncated
def iterator(raw_data, batch_size, random=False):
"""Looping iterators on the raw data."""
sequences = raw_data["sequences"]
sequence_lengths = raw_data["sequence_lengths"]
num_examples = sequences.shape[0]
indice_range = np.arange(num_examples)
if random:
while True:
indices = np.random.choice(indice_range, size=batch_size, replace=True)
yield {
"sequence": sequences[indices, :],
"sequence_length": sequence_lengths[indices],
}
else:
start = 0
while True:
sequence = sequences[start:(start + batch_size), :]
sequence_length = sequence_lengths[start:(start + batch_size)]
start += batch_size
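      # Wrap around when there are not enough examples left for a full batch.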
if start + batch_size > num_examples:
start = (start + batch_size) % num_examples
yield {
"sequence": sequence,
"sequence_length": sequence_length,
}
+8
@@ -0,0 +1,8 @@
absl-py==0.7.1
dm-sonnet==1.34
numpy==1.16.4
tensorflow==1.14
tensorflow-probability==0.7.0
tensorflow-gan==1.0.0.dev0
tensorflow-hub==0.6.0
tensorflow-io==0.8.0
+308
@@ -0,0 +1,308 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
from absl import logging
import numpy as np
import tensorflow as tf
from tensorflow.io import gfile
from scratchgan import reader
EVAL_FILENAME = "evaluated_checkpoints.csv"
GLOVE_DIM = 300
GLOVE_STD = 0.3836 # Standard dev. of GloVe embeddings.
def _get_embedding_initializer(vocab_file, embedding_source, vocab_size):
"""Loads pretrained embeddings from a file in GloVe format."""
with gfile.GFile(embedding_source, "r") as f:
embedding_lines = f.readlines()
  # First line contains the number of embeddings and the embedding dim.
_, embedding_dim = map(int, embedding_lines[0].split())
# Get the tokens as strings.
tokens = [line.split()[0] for line in embedding_lines[1:]]
# Get the actual embedding matrix.
unsorted_emb = np.array(
[[float(x) for x in line.split()[1:]] for line in embedding_lines[1:]])
# Get the expected vocab order.
with gfile.GFile(vocab_file, "r") as f:
tokens_order = [l.strip() for l in f.readlines()]
assert vocab_size == len(tokens_order)
# Put the embeddings in the order.
sorted_emb = np.zeros((vocab_size, embedding_dim))
for i, token in enumerate(tokens_order):
if token in tokens:
sorted_emb[i, :] = unsorted_emb[tokens.index(token), :]
else: # If we don't have a pretrained embedding, initialize randomly.
sorted_emb[i, :] = np.random.normal(
loc=0.0, scale=GLOVE_STD, size=(GLOVE_DIM,))
return sorted_emb.astype(np.float32)
def append_position_signal(embeddings, position_dim=8):
"""Append position signal. See get_position_signal."""
batch_size, sequence_length, embedding_dim = embeddings.get_shape().as_list()
positions = get_position_signal(sequence_length, position_dim)
# Append to embeddings.
position_inputs = tf.tile(positions[None, :, :], [batch_size, 1, 1])
embeddings_pos = tf.concat([embeddings, position_inputs], axis=2)
embeddings_pos.shape.assert_is_compatible_with(
[batch_size, sequence_length, embedding_dim + position_dim])
return embeddings_pos
def get_position_signal(sequence_length, position_dim=8):
"""Return fixed position signal as sine waves.
  Sine wave periods are log-linearly spaced so that the shortest is 2 and the
  longest is the sequence length. That way the longest wave varies slowly
  enough to be close to monotonous over the whole sequence length.
Sine waves are also shifted so that they don't all start with the same
value.
We don't use learned positional embeddings because these embeddings are
projected linearly along with the original embeddings, and the projection is
learned.
Args:
    sequence_length: int, T, length of the sequence.
position_dim: int, P, number of sine waves.
Returns:
A [T, P] tensor, position embeddings.
"""
# Compute the frequencies.
periods = tf.exp(
tf.lin_space(
tf.log(2.0), tf.log(tf.to_float(sequence_length)), position_dim))
frequencies = 1.0 / periods # Shape [T, P].
# Compute the sine waves.
xs = frequencies[None, :] * tf.to_float(tf.range(sequence_length)[:, None])
shifts = tf.lin_space(0.0, 2.0, position_dim)[None, :] # [1, P]
positions = tf.math.cos(math.pi * (xs + shifts)) # [T, P]
positions.shape.assert_is_compatible_with([sequence_length, position_dim])
return positions
def get_mask_by_length(lengths, max_length):
"""Returns a mask where x[i , j] = (j < lengths[i]).
Args:
lengths: [B] tensor of int32 such that 0 <= lengths[i] <= max_length.
max_length: scalar tensor of int32.
Returns:
[B, max_length] tensor of booleans such that x[i, j] is True
if and only if j < lengths[i].
"""
batch_size = lengths.get_shape().as_list()[0]
indices = tf.range(start=0, limit=max_length)
all_indices = tf.tile(indices[None, :], [batch_size, 1])
all_lengths = tf.tile(lengths[:, None], [1, max_length])
mask = (all_indices < all_lengths)
mask_boolean = tf.cast(mask, tf.bool)
return mask_boolean
def get_mask_past_symbol(reference, symbol, optimize_for_tpu=False):
"""For each row, mask is True before and at the first occurrence of symbol."""
batch_size, max_length = reference.get_shape().as_list()
symbol = tf.convert_to_tensor(symbol)
symbol.shape.assert_is_compatible_with([])
first_indices = get_first_occurrence_indices(reference, symbol,
optimize_for_tpu)
first_indices.shape.assert_is_compatible_with([batch_size])
keep_lengths = tf.minimum(first_indices, max_length)
mask = get_mask_by_length(keep_lengths, max_length)
mask.shape.assert_is_compatible_with([batch_size, max_length])
mask.set_shape([batch_size, max_length])
return mask
def get_first_occurrence_indices(reference, symbol, optimize_for_tpu=False):
"""For each row in reference, get index after the first occurrence of symbol.
If symbol is not present on a row, return reference.shape[1] instead.
Args:
reference: [B, T] tensor of elements of the same type as symbol.
    symbol: int or [] scalar tensor of the same dtype as reference.
optimize_for_tpu: bool, whether to use a TPU-capable variant.
Returns:
    A [B] tensor of tf.int32 where x[i] is such that
    reference[i, x[i]-1] == symbol, and reference[i, j] != symbol
    for j < x[i]-1. If symbol is not present on row i then x[i] = T.
"""
if optimize_for_tpu:
# Run code which can be compiled on TPU.
    # Transpose reference to [T, B].
reference = tf.transpose(reference, [1, 0])
range_tensor = tf.range(reference.shape.as_list()[0])
indexes = tf.stack([range_tensor] * reference.shape.as_list()[1], 1)
symbol = tf.stack([symbol] * reference.shape.as_list()[1], 0)
initial_indices = tf.constant(
reference.shape.as_list()[0],
shape=[reference.shape.as_list()[1]],
dtype=tf.int32)
# We want a function which moves backwards.
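    # Scanning from the last timestep to the first, the running value becomes
    # (index + 1) whenever `symbol` is seen, so at t=0 it holds the index just
    # after the first occurrence, or T if `symbol` never occurs on that row.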
def fn(current_index, elems):
ref, ind = elems
return tf.where(tf.equal(ref, symbol), ind + 1, current_index)
min_indexes = tf.scan(
fn, (reference, indexes),
initializer=initial_indices,
parallel_iterations=1,
reverse=True)
return min_indexes[0]
batch_size, max_length = reference.get_shape().as_list()
symbol = tf.convert_to_tensor(symbol)
symbol.shape.assert_is_compatible_with([])
# Add symbol at the end of each row, to make sure tf.where works.
tensor = tf.concat(
[reference, tf.tile(symbol[None, None], [batch_size, 1])], axis=1)
index_all_occurrences = tf.where(tf.equal(tensor, symbol))
index_all_occurrences = tf.cast(index_all_occurrences, tf.int32)
# `index_all_occurrences` is a [N, 2] tensor with coordinates of all positions
# of `symbol` in `tensor`. So N will be >= batch size since there can be
# several `symbol` in one row of tensor. We need to take only the position
# of the first occurrence for each row. `segment_min` does that, taking the
# lowest column index for each row index.
index_first_occurrences = tf.segment_min(index_all_occurrences[:, 1],
index_all_occurrences[:, 0])
index_first_occurrences.set_shape([batch_size])
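  # Shift by one so the result is the index just after the first occurrence,
  # capped at max_length (rows without `symbol` only match the appended one).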
index_first_occurrences = tf.minimum(index_first_occurrences + 1, max_length)
return index_first_occurrences
def sequence_to_sentence(sequence, id_to_word):
"""Turn a sequence into a sentence , inverse of sentence_to_sequence."""
words = []
for token_index in sequence:
if token_index in id_to_word:
words.append(id_to_word[token_index])
else:
words.append(reader.UNK)
return " ".join(words)
def batch_sequences_to_sentences(sequences, id_to_word):
return [sequence_to_sentence(sequence, id_to_word) for sequence in sequences]
def write_eval_results(checkpoint_dir, all_gen_sentences, checkpoint_name,
mean_train_prob, mean_valid_prob, mean_gen_prob, fid):
"""Write evaluation results to disk."""
to_write = ",".join(
map(str, [
checkpoint_name, mean_train_prob, mean_valid_prob, mean_gen_prob, fid
]))
eval_filepath = os.path.join(checkpoint_dir, EVAL_FILENAME)
previous_eval_content = ""
if gfile.exists(eval_filepath):
with gfile.GFile(eval_filepath, "r") as f:
previous_eval_content = f.read()
with gfile.GFile(eval_filepath, "w") as f:
f.write(previous_eval_content + to_write + "\n")
with gfile.GFile(
os.path.join(checkpoint_dir, checkpoint_name + "_sentences.txt"),
"w") as f:
f.write("\n".join(all_gen_sentences))
def maybe_pick_models_to_evaluate(checkpoint_dir):
"""Pick a checkpoint to evaluate that has not been evaluated already."""
logging.info("Picking checkpoint to evaluate from %s.", checkpoint_dir)
filenames = gfile.listdir(checkpoint_dir)
filenames = [f[:-5] for f in filenames if f[-5:] == ".meta"]
logging.info("Found existing checkpoints: %s", filenames)
evaluated_filenames = []
if gfile.exists(os.path.join(checkpoint_dir, EVAL_FILENAME)):
with gfile.GFile(os.path.join(checkpoint_dir, EVAL_FILENAME), "r") as f:
evaluated_filenames = [l.strip().split(",")[0] for l in f.readlines()]
logging.info("Found already evaluated checkpoints: %s", evaluated_filenames)
checkpoints_to_evaluate = [
f for f in filenames if f not in evaluated_filenames
]
logging.info("Remaining potential checkpoints: %s", checkpoints_to_evaluate)
if checkpoints_to_evaluate:
return os.path.join(checkpoint_dir, checkpoints_to_evaluate[0])
else:
return None
def get_embedding_path(data_dir, dataset):
"""By convention, this is where we store the embedding."""
return os.path.join(data_dir, "glove_%s.txt" % dataset)
def make_partially_trainable_embeddings(vocab_file, embedding_source,
vocab_size, trainable_embedding_size):
"""Makes embedding matrix with pretrained GloVe [1] part and trainable part.
[1] Pennington, J., Socher, R., & Manning, C. (2014, October). Glove: Global
vectors for word representation. In Proceedings of the 2014 conference on
empirical methods in natural language processing (EMNLP) (pp. 1532-1543).
Returns:
A matrix of partially pretrained embeddings.
"""
# Our embeddings have 2 parts: a pre-trained, frozen, GloVe part,
# and a trainable, randomly initialized part.
# The standard deviation of the GloVe part is used to initialize
# the trainable part, so that both part have roughly the same distribution.
#
  # Let g_ij be the j-th coordinate of the GloVe embedding of the i-th word,
  # with 0 <= i < |vocab| and 0 <= j < 300.
  # Then sqrt(mean_ij (g_ij - mean_kl g_kl)^2) ~= 0.3836.
#
# In reality g_ij follows a truncated normal distribution
# min(max(N(0, s), -4.2), 4.2) but we approximate it by N(0, 0.3836).
embedding_initializer = _get_embedding_initializer(
vocab_file=vocab_file,
embedding_source=embedding_source,
vocab_size=vocab_size)
pretrained_embedding = tf.get_variable(
"pretrained_embedding",
initializer=embedding_initializer,
dtype=tf.float32)
trainable_embedding = tf.get_variable(
"trainable_embedding",
shape=[vocab_size, trainable_embedding_size],
initializer=tf.initializers.random_normal(mean=0.0, stddev=GLOVE_STD))
# We just concatenate embeddings, they will pass through a projection
# matrix afterwards.
embedding = tf.concat([pretrained_embedding, trainable_embedding], axis=1)
return embedding