Move to PY3.

PiperOrigin-RevId: 313217111
This commit is contained in:
Cyprien de Masson d'Autume
2020-05-26 18:19:25 +01:00
committed by Diego de Las Casas
parent 77886cb179
commit b105a3646b
5 changed files with 33 additions and 22 deletions
+14 -13
View File
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2019 DeepMind Technologies Limited and Google LLC # Copyright 2019 DeepMind Technologies Limited and Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -109,7 +110,7 @@ def train(config):
raw_data = reader.get_raw_data( raw_data = reader.get_raw_data(
data_path=config.data_dir, dataset=config.dataset) data_path=config.data_dir, dataset=config.dataset)
train_data, valid_data, word_to_id = raw_data train_data, valid_data, word_to_id = raw_data
id_to_word = {v: k for k, v in word_to_id.iteritems()} id_to_word = {v: k for k, v in word_to_id.items()}
vocab_size = len(word_to_id) vocab_size = len(word_to_id)
max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset] max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset]
logging.info("Vocabulary size: %d", vocab_size) logging.info("Vocabulary size: %d", vocab_size)
@@ -124,8 +125,8 @@ def train(config):
name="real_sequence") name="real_sequence")
real_sequence_length = tf.placeholder( real_sequence_length = tf.placeholder(
dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length") dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length")
first_batch_np = iterator.next() first_batch_np = next(iterator)
valid_batch_np = iterator_valid.next() valid_batch_np = next(iterator_valid)
test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()} test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()}
test_fake_batch = { test_fake_batch = {
@@ -146,7 +147,7 @@ def train(config):
embedding_source = utils.get_embedding_path(config.data_dir, config.dataset) embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
vocab_file = "/tmp/vocab.txt" vocab_file = "/tmp/vocab.txt"
with gfile.GFile(vocab_file, "w") as f: with gfile.GFile(vocab_file, "w") as f:
for i in xrange(len(id_to_word)): for i in range(len(id_to_word)):
f.write(id_to_word[i] + "\n") f.write(id_to_word[i] + "\n")
logging.info("Temporary vocab file: %s", vocab_file) logging.info("Temporary vocab file: %s", vocab_file)
else: else:
@@ -263,17 +264,17 @@ def train(config):
if latest_ckpt: if latest_ckpt:
saver.restore(sess, latest_ckpt) saver.restore(sess, latest_ckpt)
for step in xrange(config.num_steps): for step in range(config.num_steps):
real_data_np = iterator.next() real_data_np = next(iterator)
train_feed = { train_feed = {
real_sequence: real_data_np["sequence"], real_sequence: real_data_np["sequence"],
real_sequence_length: real_data_np["sequence_length"], real_sequence_length: real_data_np["sequence_length"],
} }
# Update generator and discriminator. # Update generator and discriminator.
for _ in xrange(config.num_disc_updates): for _ in range(config.num_disc_updates):
sess.run(disc_update, feed_dict=train_feed) sess.run(disc_update, feed_dict=train_feed)
for _ in xrange(config.num_gen_updates): for _ in range(config.num_gen_updates):
sess.run(gen_update, feed_dict=train_feed) sess.run(gen_update, feed_dict=train_feed)
# Reporting # Reporting
@@ -325,7 +326,7 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
# Build graph. # Build graph.
train_data, valid_data, word_to_id = reader.get_raw_data( train_data, valid_data, word_to_id = reader.get_raw_data(
data_dir, dataset=dataset) data_dir, dataset=dataset)
id_to_word = {v: k for k, v in word_to_id.iteritems()} id_to_word = {v: k for k, v in word_to_id.items()}
vocab_size = len(word_to_id) vocab_size = len(word_to_id)
train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size) train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size)
valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size) valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size)
@@ -353,7 +354,7 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
embedding_source = utils.get_embedding_path(config.data_dir, config.dataset) embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
vocab_file = "/tmp/vocab.txt" vocab_file = "/tmp/vocab.txt"
with gfile.GFile(vocab_file, "w") as f: with gfile.GFile(vocab_file, "w") as f:
for i in xrange(len(id_to_word)): for i in range(len(id_to_word)):
f.write(id_to_word[i] + "\n") f.write(id_to_word[i] + "\n")
logging.info("Temporary vocab file: %s", vocab_file) logging.info("Temporary vocab file: %s", vocab_file)
else: else:
@@ -431,10 +432,10 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
logging.info("Restoring variables.") logging.info("Restoring variables.")
saver.restore(sess, checkpoint_path) saver.restore(sess, checkpoint_path)
for i in xrange(num_batches): for i in range(num_batches):
logging.info("Batch %d / %d", i, num_batches) logging.info("Batch %d / %d", i, num_batches)
train_data_np = train_iterator.next() train_data_np = next(train_iterator)
valid_data_np = valid_iterator.next() valid_data_np = next(valid_iterator)
feed_dict = { feed_dict = {
train_sequence: train_data_np["sequence"], train_sequence: train_data_np["sequence"],
train_sequence_length: train_data_np["sequence_length"], train_sequence_length: train_data_np["sequence_length"],
+2 -1
View File
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2019 DeepMind Technologies Limited and Google LLC # Copyright 2019 DeepMind Technologies Limited and Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -104,7 +105,7 @@ class LSTMGen(snt.AbstractModule):
sample = tf.tile( sample = tf.tile(
tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size]) tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size])
logging.info('Unrolling over %d steps.', max_sequence_length) logging.info('Unrolling over %d steps.', max_sequence_length)
for _ in xrange(max_sequence_length): for _ in range(max_sequence_length):
# Input is sampled word at t-1. # Input is sampled word at t-1.
embedding = tf.nn.embedding_lookup(input_embeddings, sample) embedding = tf.nn.embedding_lookup(input_embeddings, sample)
embedding.shape.assert_is_compatible_with( embedding.shape.assert_is_compatible_with(
+3 -2
View File
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2019 DeepMind Technologies Limited and Google LLC # Copyright 2019 DeepMind Technologies Limited and Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -67,9 +68,9 @@ def reinforce_loss(disc_logits, gen_logprobs, gamma, decay):
# Compute cumulative rewards. # Compute cumulative rewards.
rewards_list = tf.unstack(rewards, axis=1) rewards_list = tf.unstack(rewards, axis=1)
cumulative_rewards = [] cumulative_rewards = []
for t in xrange(sequence_length): for t in range(sequence_length):
cum_value = tf.zeros(shape=[batch_size]) cum_value = tf.zeros(shape=[batch_size])
for s in xrange(t, sequence_length): for s in range(t, sequence_length):
cum_value += np.power(gamma, (s - t)) * rewards_list[s] cum_value += np.power(gamma, (s - t)) * rewards_list[s]
cumulative_rewards.append(cum_value) cumulative_rewards.append(cum_value)
cumulative_rewards = tf.stack(cumulative_rewards, axis=1) cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
+6 -5
View File
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2019 DeepMind Technologies Limited and Google LLC # Copyright 2019 DeepMind Technologies Limited and Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -56,17 +57,17 @@ def _build_vocab(json_data):
title_tokens = tokenize(title) title_tokens = tokenize(title)
vocab.update(title_tokens) vocab.update(title_tokens)
# Most common words first. # Most common words first.
count_pairs = sorted(vocab.items(), key=lambda x: (-x[1], x[0])) count_pairs = sorted(list(vocab.items()), key=lambda x: (-x[1], x[0]))
words, _ = zip(*count_pairs) words, _ = list(zip(*count_pairs))
words = list(words) words = list(words)
if UNK not in words: if UNK not in words:
words = [UNK] + words words = [UNK] + words
word_to_id = dict(zip(words, range(len(words)))) word_to_id = dict(list(zip(words, list(range(len(words))))))
# Tokens are now sorted by frequency. There's no guarantee that `PAD` will # Tokens are now sorted by frequency. There's no guarantee that `PAD` will
# end up at `PAD_INT` index. Enforce it by swapping whatever token is # end up at `PAD_INT` index. Enforce it by swapping whatever token is
# currently at the `PAD_INT` index with the `PAD` token. # currently at the `PAD_INT` index with the `PAD` token.
word = word_to_id.keys()[word_to_id.values().index(PAD_INT)] word = list(word_to_id.keys())[list(word_to_id.values()).index(PAD_INT)]
word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD] word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
assert word_to_id[PAD] == PAD_INT assert word_to_id[PAD] == PAD_INT
@@ -120,7 +121,7 @@ def get_raw_data(data_path, dataset, truncate_vocab=20000):
""" """
if dataset not in FILENAMES: if dataset not in FILENAMES:
raise ValueError("Invalid dataset {}. Valid datasets: {}".format( raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
dataset, FILENAMES.keys())) dataset, list(FILENAMES.keys())))
train_file, valid_file, _ = FILENAMES[dataset] train_file, valid_file, _ = FILENAMES[dataset]
train_path = os.path.join(data_path, train_file) train_path = os.path.join(data_path, train_file)
+8 -1
View File
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2019 DeepMind Technologies Limited and Google LLC # Copyright 2019 DeepMind Technologies Limited and Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,7 +35,7 @@ def _get_embedding_initializer(vocab_file, embedding_source, vocab_size):
embedding_lines = f.readlines() embedding_lines = f.readlines()
# First line contains embedding dim. # First line contains embedding dim.
_, embedding_dim = map(int, embedding_lines[0].split()) _, embedding_dim = list(map(int, embedding_lines[0].split()))
# Get the tokens as strings. # Get the tokens as strings.
tokens = [line.split()[0] for line in embedding_lines[1:]] tokens = [line.split()[0] for line in embedding_lines[1:]]
# Get the actual embedding matrix. # Get the actual embedding matrix.
@@ -275,6 +276,12 @@ def make_partially_trainable_embeddings(vocab_file, embedding_source,
vectors for word representation. In Proceedings of the 2014 conference on vectors for word representation. In Proceedings of the 2014 conference on
empirical methods in natural language processing (EMNLP) (pp. 1532-1543). empirical methods in natural language processing (EMNLP) (pp. 1532-1543).
Args:
vocab_file: vocabulary file.
embedding_source: path to the actual embeddings.
vocab_size: number of words in vocabulary.
trainable_embedding_size: size of the trainable part of the embeddings.
Returns: Returns:
A matrix of partially pretrained embeddings. A matrix of partially pretrained embeddings.
""" """