@@ -0,0 +1,61 @@
# ScratchGAN

This is the example code for the following NeurIPS 2019 paper. If you use the
code here please cite this paper:

    @article{DBLP:journals/corr/abs-1905-09922,
      author        = {Cyprien de Masson d'Autume and
                       Mihaela Rosca and
                       Jack W. Rae and
                       Shakir Mohamed},
      title         = {Training language GANs from Scratch},
      journal       = {CoRR},
      volume        = {abs/1905.09922},
      year          = {2019},
      url           = {http://arxiv.org/abs/1905.09922},
      archivePrefix = {arXiv},
      eprint        = {1905.09922},
      timestamp     = {Wed, 29 May 2019 11:27:50 +0200},
      biburl        = {https://dblp.org/rec/bib/journals/corr/abs-1905-09922},
      bibsource     = {dblp computer science bibliography, https://dblp.org}
    }


## Contents

The code contains:

* `generators.py`: implementation of the generator.
* `discriminator_nets.py`: implementation of the discriminator.
* `eval_metrics.py`: implementation of the FED metric.
* `losses.py`: implementation of the RL loss for the generator.
* `reader.py`: data reader / tokenizer.
* `experiment.py`: main training script.

The data contains:

* `{train,valid,test}.json`: the EMNLP2017 News dataset.
* `glove_emnlp2017.txt`: the relevant subset of GloVe embeddings.

## Running

Place the data files in the directory specified by the `data_dir` flag.
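
The expected layout of `data_dir` (illustrative, assuming the default
`/tmp/emnlp2017` from `experiment.py`) would then be:

    /tmp/emnlp2017/
      train.json
      valid.json
      test.json
      glove_emnlp2017.txt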

Create and activate a virtual environment if needed:

    virtualenv scratchgan-venv
    source scratchgan-venv/bin/activate

Install requirements:

    pip install -r scratchgan/requirements.txt

Run the training and evaluation jobs:

    python2 -m scratchgan.experiment --mode="train" &
    python2 -m scratchgan.experiment --mode="evaluate_pair" &

The evaluation code is designed to run in parallel with the training: the
training code saves checkpoints periodically, and the evaluation code watches
for new checkpoints and evaluates them as they appear.
@@ -0,0 +1,13 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,121 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Discriminator networks for text data."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sonnet as snt
|
||||
import tensorflow as tf
|
||||
from scratchgan import utils
|
||||
|
||||
|
||||
class LSTMEmbedDiscNet(snt.AbstractModule):
|
||||
"""An LSTM discriminator that operates on word indexes."""
|
||||
|
||||
def __init__(self,
|
||||
feature_sizes,
|
||||
vocab_size,
|
||||
use_layer_norm,
|
||||
trainable_embedding_size,
|
||||
dropout,
|
||||
pad_token,
|
||||
embedding_source=None,
|
||||
vocab_file=None,
|
||||
name='LSTMEmbedDiscNet'):
|
||||
super(LSTMEmbedDiscNet, self).__init__(name=name)
|
||||
self._feature_sizes = feature_sizes
|
||||
self._vocab_size = vocab_size
|
||||
self._use_layer_norm = use_layer_norm
|
||||
self._trainable_embedding_size = trainable_embedding_size
|
||||
self._embedding_source = embedding_source
|
||||
self._vocab_file = vocab_file
|
||||
self._dropout = dropout
|
||||
self._pad_token = pad_token
|
||||
if self._embedding_source:
|
||||
assert vocab_file
|
||||
|
||||
def _build(self, sequence, sequence_length, is_training=True):
|
||||
"""Connect to the graph.
|
||||
|
||||
Args:
|
||||
sequence: A [batch_size, max_sequence_length] tensor of int. For example
|
||||
the indices of words as sampled by the generator.
|
||||
sequence_length: A [batch_size] tensor of int. Length of the sequence.
|
||||
is_training: Boolean, False to disable dropout.
|
||||
|
||||
Returns:
|
||||
A [batch_size, max_sequence_length, feature_size] tensor of floats. For
|
||||
each sequence in the batch, the features should (hopefully) allow to
|
||||
distinguish if the value at each timestep is real or generated.
|
||||
"""
|
||||
batch_size, max_sequence_length = sequence.shape.as_list()
|
||||
keep_prob = (1.0 - self._dropout) if is_training else 1.0
|
||||
|
||||
if self._embedding_source:
|
||||
all_embeddings = utils.make_partially_trainable_embeddings(
|
||||
self._vocab_file, self._embedding_source, self._vocab_size,
|
||||
self._trainable_embedding_size)
|
||||
else:
|
||||
all_embeddings = tf.get_variable(
|
||||
'trainable_embedding',
|
||||
shape=[self._vocab_size, self._trainable_embedding_size],
|
||||
trainable=True)
|
||||
_, self._embedding_size = all_embeddings.shape.as_list()
|
||||
input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=keep_prob)
|
||||
embeddings = tf.nn.embedding_lookup(input_embeddings, sequence)
|
||||
embeddings.shape.assert_is_compatible_with(
|
||||
[batch_size, max_sequence_length, self._embedding_size])
|
||||
position_dim = 8
|
||||
embeddings_pos = utils.append_position_signal(embeddings, position_dim)
|
||||
embeddings_pos = tf.reshape(
|
||||
embeddings_pos,
|
||||
[batch_size * max_sequence_length, self._embedding_size + position_dim])
|
||||
lstm_inputs = snt.Linear(self._feature_sizes[0])(embeddings_pos)
|
||||
lstm_inputs = tf.reshape(
|
||||
lstm_inputs, [batch_size, max_sequence_length, self._feature_sizes[0]])
|
||||
lstm_inputs.shape.assert_is_compatible_with(
|
||||
[batch_size, max_sequence_length, self._feature_sizes[0]])
|
||||
|
||||
encoder_cells = []
|
||||
for feature_size in self._feature_sizes:
|
||||
encoder_cells += [
|
||||
snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
|
||||
]
|
||||
encoder_cell = snt.DeepRNN(encoder_cells)
|
||||
initial_state = encoder_cell.initial_state(batch_size)
|
||||
|
||||
hidden_states, _ = tf.nn.dynamic_rnn(
|
||||
cell=encoder_cell,
|
||||
inputs=lstm_inputs,
|
||||
sequence_length=sequence_length,
|
||||
initial_state=initial_state,
|
||||
swap_memory=True)
|
||||
|
||||
hidden_states.shape.assert_is_compatible_with(
|
||||
[batch_size, max_sequence_length,
|
||||
sum(self._feature_sizes)])
|
||||
logits = snt.BatchApply(snt.Linear(1))(hidden_states)
|
||||
logits.shape.assert_is_compatible_with([batch_size, max_sequence_length, 1])
|
||||
logits_flat = tf.reshape(logits, [batch_size, max_sequence_length])
|
||||
|
||||
# Mask past first PAD symbol
|
||||
#
|
||||
# Note that we still rely on tf.nn.bidirectional_dynamic_rnn taking
|
||||
# into account the sequence_length properly, because otherwise
|
||||
# the logits at a given timestep will depend on the inputs for all other
|
||||
# timesteps, including the ones that should be masked.
|
||||
mask = utils.get_mask_past_symbol(sequence, self._pad_token)
|
||||
masked_logits_flat = logits_flat * tf.cast(mask, tf.float32)
|
||||
return masked_logits_flat
|
||||
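
# Hypothetical usage sketch (shapes and values illustrative, not part of the
# original module):
#
#   disc = LSTMEmbedDiscNet(
#       feature_sizes=[512], vocab_size=5000, use_layer_norm=True,
#       trainable_embedding_size=64, dropout=0.1, pad_token=0)
#   logits = disc(sequence=tf.zeros([8, 52], dtype=tf.int32),
#                 sequence_length=tf.fill([8], 52))
#   # logits: [8, 52] floats, one logit per timestep, masked past first PAD.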
@@ -0,0 +1,48 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub


def fid(generated_sentences, real_sentences):
  """Compute FID on sentences using a pretrained universal sentence encoder.

  Args:
    generated_sentences: list of N strings.
    real_sentences: list of N strings.

  Returns:
    Frechet distance between activations.
  """
  embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
  real_embed = embed(real_sentences)
  generated_embed = embed(generated_sentences)
  distance = tfgan.eval.frechet_classifier_distance_from_activations(
      real_embed, generated_embed)

  # Restrict the thread pool size to prevent excessive CPU usage.
  config = tf.ConfigProto()
  config.intra_op_parallelism_threads = 16
  config.inter_op_parallelism_threads = 16
  with tf.Session(config=config) as session:
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    distance_np = session.run(distance)
  return distance_np
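
# For reference: the Frechet distance above fits Gaussians N(mu_r, C_r) and
# N(mu_g, C_g) to the real and generated embeddings and returns
#   ||mu_r - mu_g||^2 + Tr(C_r + C_g - 2 * (C_r C_g)^(1/2)).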
@@ -0,0 +1,469 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training script for ScratchGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time

from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
from tensorflow.io import gfile
from scratchgan import discriminator_nets
from scratchgan import eval_metrics
from scratchgan import generators
from scratchgan import losses
from scratchgan import reader
from scratchgan import utils

flags.DEFINE_string("dataset", "emnlp2017", "Dataset.")
flags.DEFINE_integer("batch_size", 512, "Batch size.")
flags.DEFINE_string("gen_type", "lstm", "Generator type.")
flags.DEFINE_string("disc_type", "lstm", "Discriminator type.")
flags.DEFINE_string("disc_loss_type", "ce", "Loss type.")
flags.DEFINE_integer("gen_feature_size", 512, "Generator feature size.")
flags.DEFINE_integer("disc_feature_size", 512, "Discriminator feature size.")
flags.DEFINE_integer("num_layers_gen", 2, "Number of generator layers.")
flags.DEFINE_integer("num_layers_disc", 1, "Number of discriminator layers.")
flags.DEFINE_bool("layer_norm_gen", False, "Layer norm generator.")
flags.DEFINE_bool("layer_norm_disc", True, "Layer norm discriminator.")
flags.DEFINE_float("gen_input_dropout", 0.0, "Input dropout generator.")
flags.DEFINE_float("gen_output_dropout", 0.0, "Output dropout generator.")
flags.DEFINE_float("l2_gen", 0.0, "L2 regularization generator.")
flags.DEFINE_float("l2_disc", 1e-6, "L2 regularization discriminator.")
flags.DEFINE_float("disc_dropout", 0.1, "Dropout discriminator.")
flags.DEFINE_integer("trainable_embedding_size", 64,
                     "Size of trainable embedding.")
flags.DEFINE_bool("use_pretrained_embedding", True, "Use pretrained embedding.")
flags.DEFINE_integer("num_steps", int(200 * 1000), "Number of training steps.")
flags.DEFINE_integer("num_disc_updates", 1, "Number of discriminator updates.")
flags.DEFINE_integer("num_gen_updates", 1, "Number of generator updates.")
flags.DEFINE_string("data_dir", "/tmp/emnlp2017", "Directory where data is.")
flags.DEFINE_float("gen_lr", 9.59e-5, "Learning rate generator.")
flags.DEFINE_float("disc_lr", 9.38e-3, "Learning rate discriminator.")
flags.DEFINE_float("gen_beta1", 0.5, "Beta1 for generator.")
flags.DEFINE_float("disc_beta1", 0.5, "Beta1 for discriminator.")
flags.DEFINE_float("gamma", 0.23, "Discount factor.")
flags.DEFINE_float("baseline_decay", 0.08, "Baseline decay rate.")
flags.DEFINE_string("mode", "train", "train or evaluate_pair.")
flags.DEFINE_string("checkpoint_dir", "/tmp/emnlp2017/checkpoints/",
                    "Directory for checkpoints.")
flags.DEFINE_integer("export_every", 1000, "Frequency of checkpoint exports.")
flags.DEFINE_integer("num_examples_for_eval", int(1e4),
                     "Number of examples for evaluation.")
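
# Flags can be overridden on the command line, e.g. (illustrative):
#
#   python2 -m scratchgan.experiment --mode=train --batch_size=256 \
#     --data_dir=/tmp/emnlp2017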

EVALUATOR_SLEEP_PERIOD = 60  # Seconds evaluator sleeps if nothing to do.


def main(_):
  config = flags.FLAGS

  gfile.makedirs(config.checkpoint_dir)
  if config.mode == "train":
    train(config)
  elif config.mode == "evaluate_pair":
    while True:
      checkpoint_path = utils.maybe_pick_models_to_evaluate(
          checkpoint_dir=config.checkpoint_dir)
      if checkpoint_path:
        evaluate_pair(
            config=config,
            batch_size=config.batch_size,
            checkpoint_path=checkpoint_path,
            data_dir=config.data_dir,
            dataset=config.dataset,
            num_examples_for_eval=config.num_examples_for_eval)
      else:
        logging.info("No models to evaluate found, sleeping for %d seconds",
                     EVALUATOR_SLEEP_PERIOD)
        time.sleep(EVALUATOR_SLEEP_PERIOD)
  else:
    raise Exception(
        "Unexpected mode %s, supported modes are \"train\" or \"evaluate_pair\""
        % (config.mode))


def train(config):
  """Train."""
  logging.info("Training.")

  tf.reset_default_graph()
  np.set_printoptions(precision=4)

  # Get data.
  raw_data = reader.get_raw_data(
      data_path=config.data_dir, dataset=config.dataset)
  train_data, valid_data, word_to_id = raw_data
  id_to_word = {v: k for k, v in word_to_id.iteritems()}
  vocab_size = len(word_to_id)
  max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset]
  logging.info("Vocabulary size: %d", vocab_size)

  iterator = reader.iterator(raw_data=train_data, batch_size=config.batch_size)
  iterator_valid = reader.iterator(
      raw_data=valid_data, batch_size=config.batch_size)

  real_sequence = tf.placeholder(
      dtype=tf.int32,
      shape=[config.batch_size, max_length],
      name="real_sequence")
  real_sequence_length = tf.placeholder(
      dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length")
  first_batch_np = iterator.next()
  valid_batch_np = iterator_valid.next()

  test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()}
  test_fake_batch = {
      "sequence":
          tf.constant(
              np.random.choice(
                  vocab_size, size=[config.batch_size,
                                    max_length]).astype(np.int32)),
      "sequence_length":
          tf.constant(
              np.random.choice(max_length,
                               size=[config.batch_size]).astype(np.int32)),
  }
  valid_batch = {k: tf.constant(v) for k, v in valid_batch_np.items()}

  # Create generator.
  if config.use_pretrained_embedding:
    embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
    vocab_file = "/tmp/vocab.txt"
    with gfile.GFile(vocab_file, "w") as f:
      for i in xrange(len(id_to_word)):
        f.write(id_to_word[i] + "\n")
    logging.info("Temporary vocab file: %s", vocab_file)
  else:
    embedding_source = None
    vocab_file = None

  gen = generators.LSTMGen(
      vocab_size=vocab_size,
      feature_sizes=[config.gen_feature_size] * config.num_layers_gen,
      max_sequence_length=reader.MAX_TOKENS_SEQUENCE[config.dataset],
      batch_size=config.batch_size,
      use_layer_norm=config.layer_norm_gen,
      trainable_embedding_size=config.trainable_embedding_size,
      input_dropout=config.gen_input_dropout,
      output_dropout=config.gen_output_dropout,
      pad_token=reader.PAD_INT,
      embedding_source=embedding_source,
      vocab_file=vocab_file,
  )
  gen_outputs = gen()

  # Create discriminator.
  disc = discriminator_nets.LSTMEmbedDiscNet(
      vocab_size=vocab_size,
      feature_sizes=[config.disc_feature_size] * config.num_layers_disc,
      trainable_embedding_size=config.trainable_embedding_size,
      embedding_source=embedding_source,
      use_layer_norm=config.layer_norm_disc,
      pad_token=reader.PAD_INT,
      vocab_file=vocab_file,
      dropout=config.disc_dropout,
  )
  disc_logits_real = disc(
      sequence=real_sequence, sequence_length=real_sequence_length)
  disc_logits_fake = disc(
      sequence=gen_outputs["sequence"],
      sequence_length=gen_outputs["sequence_length"])

  # Loss of the discriminator.
  if config.disc_loss_type == "ce":
    targets_real = tf.ones(
        [config.batch_size, reader.MAX_TOKENS_SEQUENCE[config.dataset]])
    targets_fake = tf.zeros(
        [config.batch_size, reader.MAX_TOKENS_SEQUENCE[config.dataset]])
    loss_real = losses.sequential_cross_entropy_loss(disc_logits_real,
                                                     targets_real)
    loss_fake = losses.sequential_cross_entropy_loss(disc_logits_fake,
                                                     targets_fake)
    disc_loss = 0.5 * loss_real + 0.5 * loss_fake

  # Loss of the generator.
  gen_loss, cumulative_rewards, baseline = losses.reinforce_loss(
      disc_logits=disc_logits_fake,
      gen_logprobs=gen_outputs["logprobs"],
      gamma=config.gamma,
      decay=config.baseline_decay)

  # Optimizers.
  disc_optimizer = tf.train.AdamOptimizer(
      learning_rate=config.disc_lr, beta1=config.disc_beta1)
  gen_optimizer = tf.train.AdamOptimizer(
      learning_rate=config.gen_lr, beta1=config.gen_beta1)

  # Get losses and variables.
  disc_vars = disc.get_all_variables()
  gen_vars = gen.get_all_variables()
  l2_disc = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in disc_vars]))
  l2_gen = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in gen_vars]))
  scalar_disc_loss = tf.reduce_mean(disc_loss) + config.l2_disc * l2_disc
  scalar_gen_loss = tf.reduce_mean(gen_loss) + config.l2_gen * l2_gen

  # Update ops.
  global_step = tf.train.get_or_create_global_step()
  disc_update = disc_optimizer.minimize(
      scalar_disc_loss, var_list=disc_vars, global_step=global_step)
  gen_update = gen_optimizer.minimize(
      scalar_gen_loss, var_list=gen_vars, global_step=global_step)

  # Saver.
  saver = tf.train.Saver()

  # Metrics.
  test_disc_logits_real = disc(**test_real_batch)
  test_disc_logits_fake = disc(**test_fake_batch)
  valid_disc_logits = disc(**valid_batch)
  disc_predictions_real = tf.nn.sigmoid(disc_logits_real)
  disc_predictions_fake = tf.nn.sigmoid(disc_logits_fake)
  valid_disc_predictions = tf.reduce_mean(
      tf.nn.sigmoid(valid_disc_logits), axis=0)
  test_disc_predictions_real = tf.reduce_mean(
      tf.nn.sigmoid(test_disc_logits_real), axis=0)
  test_disc_predictions_fake = tf.reduce_mean(
      tf.nn.sigmoid(test_disc_logits_fake), axis=0)

  # Scalar metrics, averaged over the batch; the sample sentence logged
  # below uses only the first element of the batch.
  metrics = {
      "scalar_gen_loss": scalar_gen_loss,
      "scalar_disc_loss": scalar_disc_loss,
      "disc_predictions_real": tf.reduce_mean(disc_predictions_real),
      "disc_predictions_fake": tf.reduce_mean(disc_predictions_fake),
      "test_disc_predictions_real": tf.reduce_mean(test_disc_predictions_real),
      "test_disc_predictions_fake": tf.reduce_mean(test_disc_predictions_fake),
      "valid_disc_predictions": tf.reduce_mean(valid_disc_predictions),
      "cumulative_rewards": tf.reduce_mean(cumulative_rewards),
      "baseline": tf.reduce_mean(baseline),
  }

  # Training.
  logging.info("Starting training")
  with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())
    latest_ckpt = tf.train.latest_checkpoint(config.checkpoint_dir)
    if latest_ckpt:
      saver.restore(sess, latest_ckpt)

    for step in xrange(config.num_steps):
      real_data_np = iterator.next()
      train_feed = {
          real_sequence: real_data_np["sequence"],
          real_sequence_length: real_data_np["sequence_length"],
      }

      # Update the discriminator and the generator.
      for _ in xrange(config.num_disc_updates):
        sess.run(disc_update, feed_dict=train_feed)
      for _ in xrange(config.num_gen_updates):
        sess.run(gen_update, feed_dict=train_feed)

      # Reporting.
      if step % config.export_every == 0:
        gen_sequence_np, metrics_np = sess.run(
            [gen_outputs["sequence"], metrics], feed_dict=train_feed)
        metrics_np["gen_sentence"] = utils.sequence_to_sentence(
            gen_sequence_np[0, :], id_to_word)
        saver.save(
            sess,
            save_path=config.checkpoint_dir + "scratchgan",
            global_step=global_step)
        metrics_np["model_path"] = tf.train.latest_checkpoint(
            config.checkpoint_dir)
        logging.info(metrics_np)

    # After training, export models.
    saver.save(
        sess,
        save_path=config.checkpoint_dir + "scratchgan",
        global_step=global_step)
    logging.info("Saved final model at %s.",
                 tf.train.latest_checkpoint(config.checkpoint_dir))


def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
                  num_examples_for_eval):
  """Evaluates a generator / discriminator pair.

  This function loads a generator and a discriminator from a checkpoint on
  disk and evaluates the discriminator against the generator.

  It returns the mean probability of the discriminator against several batches,
  and the FID of the generator against the validation data.

  It also writes evaluation samples to disk.

  Args:
    config: the config object (FLAGS).
    batch_size: int, size of the batch.
    checkpoint_path: string, full path to the TF checkpoint on disk.
    data_dir: string, path to a directory containing the dataset.
    dataset: string, "emnlp2017", to select the right dataset.
    num_examples_for_eval: int, number of examples for evaluation.
  """
  tf.reset_default_graph()
  logging.info("Evaluating checkpoint %s.", checkpoint_path)

  # Build graph.
  train_data, valid_data, word_to_id = reader.get_raw_data(
      data_dir, dataset=dataset)
  id_to_word = {v: k for k, v in word_to_id.iteritems()}
  vocab_size = len(word_to_id)
  train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size)
  valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size)
  train_sequence = tf.placeholder(
      dtype=tf.int32,
      shape=[batch_size, reader.MAX_TOKENS_SEQUENCE[dataset]],
      name="train_sequence")
  train_sequence_length = tf.placeholder(
      dtype=tf.int32, shape=[batch_size], name="train_sequence_length")
  valid_sequence = tf.placeholder(
      dtype=tf.int32,
      shape=[batch_size, reader.MAX_TOKENS_SEQUENCE[dataset]],
      name="valid_sequence")
  valid_sequence_length = tf.placeholder(
      dtype=tf.int32, shape=[batch_size], name="valid_sequence_length")
  disc_inputs_train = {
      "sequence": train_sequence,
      "sequence_length": train_sequence_length,
  }
  disc_inputs_valid = {
      "sequence": valid_sequence,
      "sequence_length": valid_sequence_length,
  }
  if config.use_pretrained_embedding:
    embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
    vocab_file = "/tmp/vocab.txt"
    with gfile.GFile(vocab_file, "w") as f:
      for i in xrange(len(id_to_word)):
        f.write(id_to_word[i] + "\n")
    logging.info("Temporary vocab file: %s", vocab_file)
  else:
    embedding_source = None
    vocab_file = None
  gen = generators.LSTMGen(
      vocab_size=vocab_size,
      feature_sizes=[config.gen_feature_size] * config.num_layers_gen,
      max_sequence_length=reader.MAX_TOKENS_SEQUENCE[config.dataset],
      batch_size=config.batch_size,
      use_layer_norm=config.layer_norm_gen,
      trainable_embedding_size=config.trainable_embedding_size,
      input_dropout=config.gen_input_dropout,
      output_dropout=config.gen_output_dropout,
      pad_token=reader.PAD_INT,
      embedding_source=embedding_source,
      vocab_file=vocab_file,
  )
  gen_outputs = gen()

  disc = discriminator_nets.LSTMEmbedDiscNet(
      vocab_size=vocab_size,
      feature_sizes=[config.disc_feature_size] * config.num_layers_disc,
      trainable_embedding_size=config.trainable_embedding_size,
      embedding_source=embedding_source,
      use_layer_norm=config.layer_norm_disc,
      pad_token=reader.PAD_INT,
      vocab_file=vocab_file,
      dropout=config.disc_dropout,
  )

  disc_inputs = {
      "sequence": gen_outputs["sequence"],
      "sequence_length": gen_outputs["sequence_length"],
  }
  gen_logits = disc(**disc_inputs)
  train_logits = disc(**disc_inputs_train)
  valid_logits = disc(**disc_inputs_valid)

  # Saver.
  saver = tf.train.Saver()

  # Reduce over time and batch.
  train_probs = tf.reduce_mean(tf.nn.sigmoid(train_logits))
  valid_probs = tf.reduce_mean(tf.nn.sigmoid(valid_logits))
  gen_probs = tf.reduce_mean(tf.nn.sigmoid(gen_logits))

  outputs = {
      "train_probs": train_probs,
      "valid_probs": valid_probs,
      "gen_probs": gen_probs,
      "gen_sequences": gen_outputs["sequence"],
      "valid_sequences": valid_sequence
  }

  # Get average discriminator score and store generated sequences.
  all_valid_sentences = []
  all_gen_sentences = []
  all_gen_sequences = []
  mean_train_prob = 0.0
  mean_valid_prob = 0.0
  mean_gen_prob = 0.0

  logging.info("Graph constructed, generating batches.")
  num_batches = num_examples_for_eval // batch_size + 1

  # Restrict the thread pool size to prevent excessive CPU usage.
  tf_config = tf.ConfigProto()
  tf_config.intra_op_parallelism_threads = 16
  tf_config.inter_op_parallelism_threads = 16

  with tf.Session(config=tf_config) as sess:

    # Restore variables from checkpoints.
    logging.info("Restoring variables.")
    saver.restore(sess, checkpoint_path)

    for i in xrange(num_batches):
      logging.info("Batch %d / %d", i, num_batches)
      train_data_np = train_iterator.next()
      valid_data_np = valid_iterator.next()
      feed_dict = {
          train_sequence: train_data_np["sequence"],
          train_sequence_length: train_data_np["sequence_length"],
          valid_sequence: valid_data_np["sequence"],
          valid_sequence_length: valid_data_np["sequence_length"],
      }
      outputs_np = sess.run(outputs, feed_dict=feed_dict)
      all_gen_sequences.extend(outputs_np["gen_sequences"])
      gen_sentences = utils.batch_sequences_to_sentences(
          outputs_np["gen_sequences"], id_to_word)
      valid_sentences = utils.batch_sequences_to_sentences(
          outputs_np["valid_sequences"], id_to_word)
      all_valid_sentences.extend(valid_sentences)
      all_gen_sentences.extend(gen_sentences)
      # Average the per-batch probabilities over the number of batches.
      mean_train_prob += outputs_np["train_probs"] / num_batches
      mean_valid_prob += outputs_np["valid_probs"] / num_batches
      mean_gen_prob += outputs_np["gen_probs"] / num_batches

  logging.info("Evaluating FID.")

  # Compute FID.
  fid = eval_metrics.fid(
      generated_sentences=all_gen_sentences[:num_examples_for_eval],
      real_sentences=all_valid_sentences[:num_examples_for_eval])

  utils.write_eval_results(config.checkpoint_dir, all_gen_sentences,
                           os.path.basename(checkpoint_path), mean_train_prob,
                           mean_valid_prob, mean_gen_prob, fid)


if __name__ == "__main__":
  app.run(main)
@@ -0,0 +1,147 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generators for text data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp
from scratchgan import utils


class LSTMGen(snt.AbstractModule):
  """A multi-layer LSTM language model.

  Uses tied input/output embedding weights.
  """

  def __init__(self,
               vocab_size,
               feature_sizes,
               max_sequence_length,
               batch_size,
               use_layer_norm,
               trainable_embedding_size,
               input_dropout,
               output_dropout,
               pad_token,
               embedding_source=None,
               vocab_file=None,
               name='lstm_gen'):
    super(LSTMGen, self).__init__(name=name)
    self._feature_sizes = feature_sizes
    self._max_sequence_length = max_sequence_length
    self._vocab_size = vocab_size
    self._batch_size = batch_size
    self._use_layer_norm = use_layer_norm
    self._trainable_embedding_size = trainable_embedding_size
    self._embedding_source = embedding_source
    self._vocab_file = vocab_file
    self._input_dropout = input_dropout
    self._output_dropout = output_dropout
    self._pad_token = pad_token
    if self._embedding_source:
      assert vocab_file

  def _build(self, is_training=True, temperature=1.0):
    input_keep_prob = (1. - self._input_dropout) if is_training else 1.0
    output_keep_prob = (1. - self._output_dropout) if is_training else 1.0

    batch_size = self._batch_size
    max_sequence_length = self._max_sequence_length
    if self._embedding_source:
      all_embeddings = utils.make_partially_trainable_embeddings(
          self._vocab_file, self._embedding_source, self._vocab_size,
          self._trainable_embedding_size)
    else:
      all_embeddings = tf.get_variable(
          'trainable_embeddings',
          shape=[self._vocab_size, self._trainable_embedding_size],
          trainable=True)
    _, self._embedding_size = all_embeddings.shape.as_list()
    input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=input_keep_prob)
    output_embeddings = tf.nn.dropout(
        all_embeddings, keep_prob=output_keep_prob)

    out_bias = tf.get_variable(
        'out_bias', shape=[1, self._vocab_size], dtype=tf.float32)
    in_proj = tf.get_variable(
        'in_proj', shape=[self._embedding_size, self._feature_sizes[0]])
    # If there is more than one layer, the output has dim
    # sum(self._feature_sizes), which differs from the input
    # dim == self._feature_sizes[0], so we need different projection
    # matrices for input and output.
    if len(self._feature_sizes) > 1:
      out_proj = tf.get_variable(
          'out_proj', shape=[self._embedding_size,
                             sum(self._feature_sizes)])
    else:
      out_proj = in_proj

    encoder_cells = []
    for feature_size in self._feature_sizes:
      encoder_cells += [
          snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
      ]
    encoder_cell = snt.DeepRNN(encoder_cells)
    state = encoder_cell.initial_state(batch_size)

    # Manual unrolling.
    samples_list, logits_list, logprobs_list, embeddings_list = [], [], [], []
    sample = tf.tile(
        tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size])
    logging.info('Unrolling over %d steps.', max_sequence_length)
    for _ in xrange(max_sequence_length):
      # Input is the word sampled at t-1.
      embedding = tf.nn.embedding_lookup(input_embeddings, sample)
      embedding.shape.assert_is_compatible_with(
          [batch_size, self._embedding_size])
      embedding_proj = tf.matmul(embedding, in_proj)
      embedding_proj.shape.assert_is_compatible_with(
          [batch_size, self._feature_sizes[0]])

      outputs, state = encoder_cell(embedding_proj, state)
      outputs_proj = tf.matmul(outputs, out_proj, transpose_b=True)
      logits = tf.matmul(
          outputs_proj, output_embeddings, transpose_b=True) + out_bias
      categorical = tfp.distributions.Categorical(logits=logits / temperature)
      sample = categorical.sample()
      logprobs = categorical.log_prob(sample)

      samples_list.append(sample)
      logits_list.append(logits)
      logprobs_list.append(logprobs)
      embeddings_list.append(embedding)

    # Create an op to retrieve embeddings for full sequence, useful for testing.
    embeddings = tf.stack(  # pylint: disable=unused-variable
        embeddings_list,
        axis=1,
        name='embeddings')
    sequence = tf.stack(samples_list, axis=1)
    logprobs = tf.stack(logprobs_list, axis=1)

    # The sequence stops after the first occurrence of a PAD token.
    sequence_length = utils.get_first_occurrence_indices(
        sequence, self._pad_token)
    mask = utils.get_mask_past_symbol(sequence, self._pad_token)
    masked_sequence = sequence * tf.cast(mask, tf.int32)
    masked_logprobs = logprobs * tf.cast(mask, tf.float32)
    return {
        'sequence': masked_sequence,
        'sequence_length': sequence_length,
        'logprobs': masked_logprobs
    }
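
# Hypothetical usage sketch (shapes and values illustrative, not part of the
# original module):
#
#   gen = LSTMGen(vocab_size=5000, feature_sizes=[512, 512],
#                 max_sequence_length=52, batch_size=8, use_layer_norm=False,
#                 trainable_embedding_size=64, input_dropout=0.0,
#                 output_dropout=0.0, pad_token=0)
#   outputs = gen()
#   # outputs["sequence"]: [8, 52] int32 samples, masked past the first PAD.
#   # outputs["sequence_length"]: [8] int32 lengths.
#   # outputs["logprobs"]: [8, 52] float32 log-probabilities of the samples.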
@@ -0,0 +1,93 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses for sequential GANs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def sequential_cross_entropy_loss(logits, expected):
  """The cross entropy loss for binary classification.

  Used to train the discriminator when not using the WGAN loss.
  Assumes `logits` are the logits of classifying each timestep as 1 (real).

  Args:
    logits: a `tf.Tensor`, the model produced logits, shape [batch_size,
      sequence_length].
    expected: a `tf.Tensor`, the expected output, shape [batch_size,
      sequence_length].

  Returns:
    A [batch_size, sequence_length] `tf.Tensor`, the per-timestep loss
    obtained on the given inputs.
  """
  batch_size, sequence_length = logits.shape.as_list()
  expected = tf.cast(expected, tf.float32)

  ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=expected, logits=logits)
  return tf.reshape(ce, [batch_size, sequence_length])


def reinforce_loss(disc_logits, gen_logprobs, gamma, decay):
  """The REINFORCE loss.

  Args:
    disc_logits: float tensor, shape [batch_size, sequence_length].
    gen_logprobs: float32 tensor, shape [batch_size, sequence_length].
    gamma: a float, discount factor for cumulative reward.
    decay: a float, decay rate for the EWMA baseline of REINFORCE.

  Returns:
    A tuple (loss, cumulative_rewards, ewma_reward):
      loss: float tensor, shape [batch_size, sequence_length], the REINFORCE
        loss for each timestep.
      cumulative_rewards: float tensor, shape [batch_size, sequence_length].
      ewma_reward: scalar variable, the EWMA baseline.
  """
  # Assume 1 logit for each timestep.
  batch_size, sequence_length = disc_logits.shape.as_list()
  gen_logprobs.shape.assert_is_compatible_with([batch_size, sequence_length])

  disc_predictions = tf.nn.sigmoid(disc_logits)

  # MaskGAN uses log(D), but this is more stable empirically.
  rewards = 2.0 * disc_predictions - 1

  # Compute cumulative rewards.
  rewards_list = tf.unstack(rewards, axis=1)
  cumulative_rewards = []
  for t in xrange(sequence_length):
    cum_value = tf.zeros(shape=[batch_size])
    for s in xrange(t, sequence_length):
      cum_value += np.power(gamma, (s - t)) * rewards_list[s]
    cumulative_rewards.append(cum_value)
  cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
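
  # For reference: the double loop above computes the discounted return
  #   R_t = sum_{s >= t} gamma^(s - t) * r_s
  # in O(T^2) graph ops; the same values satisfy the O(T) backward recursion
  #   R_{T-1} = r_{T-1},  R_t = r_t + gamma * R_{t+1}.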

  cumulative_rewards.shape.assert_is_compatible_with(
      [batch_size, sequence_length])

  with tf.variable_scope("reinforce", reuse=tf.AUTO_REUSE):
    ewma_reward = tf.get_variable("ewma_reward", initializer=0.0)

  mean_reward = tf.reduce_mean(cumulative_rewards)
  new_ewma_reward = decay * ewma_reward + (1.0 - decay) * mean_reward
  update_op = tf.assign(ewma_reward, new_ewma_reward)

  # REINFORCE.
  with tf.control_dependencies([update_op]):
    advantage = cumulative_rewards - ewma_reward
    loss = -tf.stop_gradient(advantage) * gen_logprobs

  loss.shape.assert_is_compatible_with([batch_size, sequence_length])
  return loss, cumulative_rewards, ewma_reward
@@ -0,0 +1,174 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for parsing text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os

from absl import logging
import numpy as np
from tensorflow.io import gfile

# Each data dict holds:
#   sequences: [N, MAX_TOKENS_SEQUENCE] array of int32.
#   lengths: [N] array of int32, such that
#     lengths[i] is the number of non-pad tokens in sequences[i, :].
# TODO(cyprien): handle PTB
FILENAMES = {
    "emnlp2017": ("train.json", "valid.json", "test.json"),
}

# EMNLP2017 sentences have max length 50, add one for a PAD token so that all
# sentences end with PAD.
MAX_TOKENS_SEQUENCE = {"emnlp2017": 52}

UNK = "<unk>"
PAD = " "

PAD_INT = 0


def tokenize(sentence):
  """Split a string into words."""
  return sentence.split(" ") + [PAD]


def _build_vocab(json_data):
  """Builds full vocab from json data."""
  vocab = collections.Counter()
  for sentence in json_data:
    tokens = tokenize(sentence["s"])
    vocab.update(tokens)
    for title in sentence["t"]:
      title_tokens = tokenize(title)
      vocab.update(title_tokens)
  # Most common words first.
  count_pairs = sorted(vocab.items(), key=lambda x: (-x[1], x[0]))
  words, _ = zip(*count_pairs)
  words = list(words)
  if UNK not in words:
    words = [UNK] + words
  word_to_id = dict(zip(words, range(len(words))))

  # Tokens are now sorted by frequency. There's no guarantee that `PAD` will
  # end up at `PAD_INT` index. Enforce it by swapping whatever token is
  # currently at the `PAD_INT` index with the `PAD` token.
  word = word_to_id.keys()[word_to_id.values().index(PAD_INT)]
  word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
  assert word_to_id[PAD] == PAD_INT

  return word_to_id


def string_sequence_to_sequence(string_sequence, word_to_id):
  result = []
  for word in string_sequence:
    if word in word_to_id:
      result.append(word_to_id[word])
    else:
      result.append(word_to_id[UNK])
  return result


def _integerize(json_data, word_to_id, dataset):
  """Transform words into integers."""
  sequences = np.full((len(json_data), MAX_TOKENS_SEQUENCE[dataset]),
                      word_to_id[PAD], np.int32)
  sequence_lengths = np.zeros(shape=(len(json_data)), dtype=np.int32)
  for i, sentence in enumerate(json_data):
    sequence_i = string_sequence_to_sequence(
        tokenize(sentence["s"]), word_to_id)
    sequence_lengths[i] = len(sequence_i)
    sequences[i, :sequence_lengths[i]] = np.array(sequence_i)
  return {
      "sequences": sequences,
      "sequence_lengths": sequence_lengths,
  }


def get_raw_data(data_path, dataset, truncate_vocab=20000):
  """Load raw data from data directory "data_path".

  Reads text files, converts strings to integer ids,
  and performs mini-batching of the inputs.

  Args:
    data_path: string path to the directory containing the dataset files.
    dataset: one of ["emnlp2017"].
    truncate_vocab: int, number of words to keep in the vocabulary.

  Returns:
    tuple (train_data, valid_data, vocabulary) where each of the data
    objects can be passed to iterator.

  Raises:
    ValueError: dataset not in ["emnlp2017"].
  """
  if dataset not in FILENAMES:
    raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
        dataset, FILENAMES.keys()))
  train_file, valid_file, _ = FILENAMES[dataset]

  train_path = os.path.join(data_path, train_file)
  valid_path = os.path.join(data_path, valid_file)

  with gfile.GFile(train_path, "r") as json_file:
    json_data_train = json.load(json_file)
  with gfile.GFile(valid_path, "r") as json_file:
    json_data_valid = json.load(json_file)

  word_to_id = _build_vocab(json_data_train)
  logging.info("Full vocab length: %d", len(word_to_id))
  # Assume the vocab is sorted by frequency.
  word_to_id_truncated = {
      k: v for k, v in word_to_id.items() if v < truncate_vocab
  }
  logging.info("Truncated vocab length: %d", len(word_to_id_truncated))

  train_data = _integerize(json_data_train, word_to_id_truncated, dataset)
  valid_data = _integerize(json_data_valid, word_to_id_truncated, dataset)
  return train_data, valid_data, word_to_id_truncated


def iterator(raw_data, batch_size, random=False):
  """Looping iterators on the raw data."""
  sequences = raw_data["sequences"]
  sequence_lengths = raw_data["sequence_lengths"]

  num_examples = sequences.shape[0]
  indice_range = np.arange(num_examples)

  if random:
    while True:
      indices = np.random.choice(indice_range, size=batch_size, replace=True)
      yield {
          "sequence": sequences[indices, :],
          "sequence_length": sequence_lengths[indices],
      }
  else:
    start = 0
    while True:
      sequence = sequences[start:(start + batch_size), :]
      sequence_length = sequence_lengths[start:(start + batch_size)]
      start += batch_size
      if start + batch_size > num_examples:
        start = (start + batch_size) % num_examples
      yield {
          "sequence": sequence,
          "sequence_length": sequence_length,
      }
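
# Hypothetical usage sketch (paths illustrative, not part of the original
# module):
#
#   train_data, valid_data, word_to_id = get_raw_data(
#       "/tmp/emnlp2017", dataset="emnlp2017")
#   batches = iterator(raw_data=train_data, batch_size=512)
#   batch = batches.next()
#   # batch["sequence"]: [512, 52] int32; batch["sequence_length"]: [512] int32.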
@@ -0,0 +1,8 @@
absl-py==0.7.1
dm-sonnet==1.34
numpy==1.16.4
tensorflow==1.14
tensorflow-probability==0.7.0
tensorflow-gan==1.0.0.dev0
tensorflow-hub==0.6.0
tensorflow-io==0.8.0
@@ -0,0 +1,308 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
from absl import logging
import numpy as np
import tensorflow as tf
from tensorflow.io import gfile
from scratchgan import reader

EVAL_FILENAME = "evaluated_checkpoints.csv"
GLOVE_DIM = 300
GLOVE_STD = 0.3836  # Standard dev. of GloVe embeddings.


def _get_embedding_initializer(vocab_file, embedding_source, vocab_size):
  """Loads pretrained embeddings from a file in GloVe format."""
  with gfile.GFile(embedding_source, "r") as f:
    embedding_lines = f.readlines()

  # First line contains embedding dim.
  _, embedding_dim = map(int, embedding_lines[0].split())
  # Get the tokens as strings.
  tokens = [line.split()[0] for line in embedding_lines[1:]]
  # Get the actual embedding matrix.
  unsorted_emb = np.array(
      [[float(x) for x in line.split()[1:]] for line in embedding_lines[1:]])

  # Get the expected vocab order.
  with gfile.GFile(vocab_file, "r") as f:
    tokens_order = [l.strip() for l in f.readlines()]
  assert vocab_size == len(tokens_order)

  # Put the embeddings in the right order.
  sorted_emb = np.zeros((vocab_size, embedding_dim))
  for i, token in enumerate(tokens_order):
    if token in tokens:
      sorted_emb[i, :] = unsorted_emb[tokens.index(token), :]
    else:  # If we don't have a pretrained embedding, initialize randomly.
      sorted_emb[i, :] = np.random.normal(
          loc=0.0, scale=GLOVE_STD, size=(GLOVE_DIM,))

  return sorted_emb.astype(np.float32)


def append_position_signal(embeddings, position_dim=8):
  """Append position signal. See get_position_signal."""
  batch_size, sequence_length, embedding_dim = embeddings.get_shape().as_list()
  positions = get_position_signal(sequence_length, position_dim)

  # Append to embeddings.
  position_inputs = tf.tile(positions[None, :, :], [batch_size, 1, 1])
  embeddings_pos = tf.concat([embeddings, position_inputs], axis=2)
  embeddings_pos.shape.assert_is_compatible_with(
      [batch_size, sequence_length, embedding_dim + position_dim])
  return embeddings_pos


def get_position_signal(sequence_length, position_dim=8):
  """Return fixed position signal as sine waves.

  Sine wave periods are geometrically spaced so that the shortest is 2 and
  the longest is the full sequence length. That way the longest wave
  completes only half a cycle, so it is monotonic over the whole sequence.
  Sine waves are also shifted so that they don't all start with the same
  value.
  We don't use learned positional embeddings because these embeddings are
  projected linearly along with the original embeddings, and the projection is
  learned.

  Args:
    sequence_length: int, T, length of the sequence.
    position_dim: int, P, number of sine waves.

  Returns:
    A [T, P] tensor, position embeddings.
  """
  # Compute the periods and frequencies.
  periods = tf.exp(
      tf.lin_space(
          tf.log(2.0), tf.log(tf.to_float(sequence_length)), position_dim))
  frequencies = 1.0 / periods  # Shape [P].

  # Compute the sine waves.
  xs = frequencies[None, :] * tf.to_float(tf.range(sequence_length)[:, None])
  shifts = tf.lin_space(0.0, 2.0, position_dim)[None, :]  # [1, P]
  positions = tf.math.cos(math.pi * (xs + shifts))  # [T, P]
  positions.shape.assert_is_compatible_with([sequence_length, position_dim])
  return positions
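
# Illustrative check (not part of the original module):
#
#   positions = get_position_signal(sequence_length=52, position_dim=8)
#   # positions: [52, 8] float32; column p is a cosine wave whose period
#   # grows geometrically from 2 to 52 timesteps as p goes from 0 to 7.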


def get_mask_by_length(lengths, max_length):
  """Returns a mask where x[i, j] = (j < lengths[i]).

  Args:
    lengths: [B] tensor of int32 such that 0 <= lengths[i] <= max_length.
    max_length: scalar tensor of int32.

  Returns:
    [B, max_length] tensor of booleans such that x[i, j] is True
    if and only if j < lengths[i].
  """
  batch_size = lengths.get_shape().as_list()[0]
  indices = tf.range(start=0, limit=max_length)
  all_indices = tf.tile(indices[None, :], [batch_size, 1])
  all_lengths = tf.tile(lengths[:, None], [1, max_length])
  mask = (all_indices < all_lengths)
  mask_boolean = tf.cast(mask, tf.bool)
  return mask_boolean


def get_mask_past_symbol(reference, symbol, optimize_for_tpu=False):
  """For each row, mask is True before and at the first occurrence of symbol."""
  batch_size, max_length = reference.get_shape().as_list()
  symbol = tf.convert_to_tensor(symbol)
  symbol.shape.assert_is_compatible_with([])

  first_indices = get_first_occurrence_indices(reference, symbol,
                                               optimize_for_tpu)
  first_indices.shape.assert_is_compatible_with([batch_size])

  keep_lengths = tf.minimum(first_indices, max_length)
  mask = get_mask_by_length(keep_lengths, max_length)
  mask.shape.assert_is_compatible_with([batch_size, max_length])
  mask.set_shape([batch_size, max_length])
  return mask
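
# Worked example (illustrative): with reference = [[3, 0, 5, 0]] and
# symbol = 0, get_first_occurrence_indices returns [2] (the index just past
# the first 0), so get_mask_past_symbol returns [[True, True, False, False]].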


def get_first_occurrence_indices(reference, symbol, optimize_for_tpu=False):
  """For each row in reference, get index after the first occurrence of symbol.

  If symbol is not present on a row, return reference.shape[1] instead.

  Args:
    reference: [B, T] tensor of elements of the same type as symbol.
    symbol: int or [] scalar tensor of the same dtype as reference.
    optimize_for_tpu: bool, whether to use a TPU-capable variant.

  Returns:
    A [B] tensor of tf.int32 where x[i] is such that
    reference[i, x[i]-1] == symbol, and reference[i, j] != symbol
    for j < x[i]-1. If symbol is not present on row i then x[i] = T.
  """
  if optimize_for_tpu:
    # Run code which can be compiled on TPU.
    # Transpose reference to [T, B].
    reference = tf.transpose(reference, [1, 0])
    range_tensor = tf.range(reference.shape.as_list()[0])
    indexes = tf.stack([range_tensor] * reference.shape.as_list()[1], 1)
    symbol = tf.stack([symbol] * reference.shape.as_list()[1], 0)

    initial_indices = tf.constant(
        reference.shape.as_list()[0],
        shape=[reference.shape.as_list()[1]],
        dtype=tf.int32)

    # We want a function which moves backwards.
    def fn(current_index, elems):
      ref, ind = elems
      return tf.where(tf.equal(ref, symbol), ind + 1, current_index)

    min_indexes = tf.scan(
        fn, (reference, indexes),
        initializer=initial_indices,
        parallel_iterations=1,
        reverse=True)
    return min_indexes[0]

  batch_size, max_length = reference.get_shape().as_list()
  symbol = tf.convert_to_tensor(symbol)
  symbol.shape.assert_is_compatible_with([])
  # Add symbol at the end of each row, to make sure tf.where works.
  tensor = tf.concat(
      [reference, tf.tile(symbol[None, None], [batch_size, 1])], axis=1)
  index_all_occurrences = tf.where(tf.equal(tensor, symbol))
  index_all_occurrences = tf.cast(index_all_occurrences, tf.int32)
  # `index_all_occurrences` is a [N, 2] tensor with coordinates of all positions
  # of `symbol` in `tensor`. So N will be >= batch size since there can be
  # several `symbol` in one row of tensor. We need to take only the position
  # of the first occurrence for each row. `segment_min` does that, taking the
  # lowest column index for each row index.
  index_first_occurrences = tf.segment_min(index_all_occurrences[:, 1],
                                           index_all_occurrences[:, 0])
  index_first_occurrences.set_shape([batch_size])
  index_first_occurrences = tf.minimum(index_first_occurrences + 1, max_length)
  return index_first_occurrences


def sequence_to_sentence(sequence, id_to_word):
  """Turn a sequence into a sentence, inverse of string_sequence_to_sequence."""
  words = []
  for token_index in sequence:
    if token_index in id_to_word:
      words.append(id_to_word[token_index])
    else:
      words.append(reader.UNK)
  return " ".join(words)


def batch_sequences_to_sentences(sequences, id_to_word):
  return [sequence_to_sentence(sequence, id_to_word) for sequence in sequences]


def write_eval_results(checkpoint_dir, all_gen_sentences, checkpoint_name,
                       mean_train_prob, mean_valid_prob, mean_gen_prob, fid):
  """Write evaluation results to disk."""
  to_write = ",".join(
      map(str, [
          checkpoint_name, mean_train_prob, mean_valid_prob, mean_gen_prob, fid
      ]))
  eval_filepath = os.path.join(checkpoint_dir, EVAL_FILENAME)
  previous_eval_content = ""
  if gfile.exists(eval_filepath):
    with gfile.GFile(eval_filepath, "r") as f:
      previous_eval_content = f.read()
  with gfile.GFile(eval_filepath, "w") as f:
    f.write(previous_eval_content + to_write + "\n")

  with gfile.GFile(
      os.path.join(checkpoint_dir, checkpoint_name + "_sentences.txt"),
      "w") as f:
    f.write("\n".join(all_gen_sentences))


def maybe_pick_models_to_evaluate(checkpoint_dir):
  """Pick a checkpoint to evaluate that has not been evaluated already."""
  logging.info("Picking checkpoint to evaluate from %s.", checkpoint_dir)

  filenames = gfile.listdir(checkpoint_dir)
  filenames = [f[:-5] for f in filenames if f[-5:] == ".meta"]
  logging.info("Found existing checkpoints: %s", filenames)

  evaluated_filenames = []
  if gfile.exists(os.path.join(checkpoint_dir, EVAL_FILENAME)):
    with gfile.GFile(os.path.join(checkpoint_dir, EVAL_FILENAME), "r") as f:
      evaluated_filenames = [l.strip().split(",")[0] for l in f.readlines()]
  logging.info("Found already evaluated checkpoints: %s", evaluated_filenames)

  checkpoints_to_evaluate = [
      f for f in filenames if f not in evaluated_filenames
  ]
  logging.info("Remaining potential checkpoints: %s", checkpoints_to_evaluate)

  if checkpoints_to_evaluate:
    return os.path.join(checkpoint_dir, checkpoints_to_evaluate[0])
  else:
    return None


def get_embedding_path(data_dir, dataset):
  """By convention, this is where we store the embedding."""
  return os.path.join(data_dir, "glove_%s.txt" % dataset)


def make_partially_trainable_embeddings(vocab_file, embedding_source,
                                        vocab_size, trainable_embedding_size):
  """Makes embedding matrix with pretrained GloVe [1] part and trainable part.

  [1] Pennington, J., Socher, R., & Manning, C. (2014, October). GloVe: Global
  vectors for word representation. In Proceedings of the 2014 conference on
  empirical methods in natural language processing (EMNLP) (pp. 1532-1543).

  Returns:
    A matrix of partially pretrained embeddings.
  """

  # Our embeddings have 2 parts: a pre-trained, frozen, GloVe part,
  # and a trainable, randomly initialized part.
  # The standard deviation of the GloVe part is used to initialize
  # the trainable part, so that both parts have roughly the same distribution.
  #
  # Let g_ij be the j-th coordinate of the GloVe embedding of the i-th word,
  # with 0 <= i < |vocab| and 0 <= j < 300. Then
  #   std(g) = sqrt(mean_ij (g_ij - mean_kl g_kl)^2) = 0.3836.
  #
  # In reality g_ij follows a truncated normal distribution
  # min(max(N(0, s), -4.2), 4.2) but we approximate it by N(0, 0.3836).
  embedding_initializer = _get_embedding_initializer(
      vocab_file=vocab_file,
      embedding_source=embedding_source,
      vocab_size=vocab_size)
  pretrained_embedding = tf.get_variable(
      "pretrained_embedding",
      initializer=embedding_initializer,
      dtype=tf.float32)
  trainable_embedding = tf.get_variable(
      "trainable_embedding",
      shape=[vocab_size, trainable_embedding_size],
      initializer=tf.initializers.random_normal(mean=0.0, stddev=GLOVE_STD))
  # We just concatenate embeddings, they will pass through a projection
  # matrix afterwards.
  embedding = tf.concat([pretrained_embedding, trainable_embedding], axis=1)
  return embedding
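
# Note (illustrative): the returned matrix has shape
# [vocab_size, GLOVE_DIM + trainable_embedding_size], i.e. 300 pretrained
# columns concatenated with the trainable columns.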