Mirror of https://github.com/google-deepmind/deepmind-research.git
Move to PY3.
PiperOrigin-RevId: 313217111
committed by Diego de Las Casas
parent 77886cb179
commit b105a3646b

+14 -13
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -109,7 +110,7 @@ def train(config):
   raw_data = reader.get_raw_data(
       data_path=config.data_dir, dataset=config.dataset)
   train_data, valid_data, word_to_id = raw_data
-  id_to_word = {v: k for k, v in word_to_id.iteritems()}
+  id_to_word = {v: k for k, v in word_to_id.items()}
   vocab_size = len(word_to_id)
   max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset]
   logging.info("Vocabulary size: %d", vocab_size)
@@ -124,8 +125,8 @@ def train(config):
       name="real_sequence")
   real_sequence_length = tf.placeholder(
       dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length")
-  first_batch_np = iterator.next()
-  valid_batch_np = iterator_valid.next()
+  first_batch_np = next(iterator)
+  valid_batch_np = next(iterator_valid)

   test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()}
   test_fake_batch = {
@@ -146,7 +147,7 @@ def train(config):
     embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
     vocab_file = "/tmp/vocab.txt"
     with gfile.GFile(vocab_file, "w") as f:
-      for i in xrange(len(id_to_word)):
+      for i in range(len(id_to_word)):
         f.write(id_to_word[i] + "\n")
     logging.info("Temporary vocab file: %s", vocab_file)
   else:
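
The same three substitutions recur throughout this file: `dict.iteritems()` becomes `dict.items()`, `iterator.next()` becomes `next(iterator)`, and `xrange` becomes `range`. A minimal standalone sketch of the Python 3 spellings (toy data, not code from the repository):

word_to_id = {"the": 0, "cat": 1}

# Python 2: word_to_id.iteritems()   ->  Python 3: word_to_id.items()
id_to_word = {v: k for k, v in word_to_id.items()}

# Python 2: iterator.next()          ->  Python 3: next(iterator)
iterator = iter([{"sequence": [0, 1]}, {"sequence": [1, 0]}])
first_batch_np = next(iterator)

# Python 2: xrange(n)                ->  Python 3: range(n), which is also lazy
for i in range(len(id_to_word)):
  print(id_to_word[i])
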
@@ -263,17 +264,17 @@ def train(config):
     if latest_ckpt:
       saver.restore(sess, latest_ckpt)

-    for step in xrange(config.num_steps):
-      real_data_np = iterator.next()
+    for step in range(config.num_steps):
+      real_data_np = next(iterator)
       train_feed = {
           real_sequence: real_data_np["sequence"],
           real_sequence_length: real_data_np["sequence_length"],
       }

       # Update generator and discriminator.
-      for _ in xrange(config.num_disc_updates):
+      for _ in range(config.num_disc_updates):
         sess.run(disc_update, feed_dict=train_feed)
-      for _ in xrange(config.num_gen_updates):
+      for _ in range(config.num_gen_updates):
         sess.run(gen_update, feed_dict=train_feed)

       # Reporting
@@ -325,7 +326,7 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
   # Build graph.
   train_data, valid_data, word_to_id = reader.get_raw_data(
       data_dir, dataset=dataset)
-  id_to_word = {v: k for k, v in word_to_id.iteritems()}
+  id_to_word = {v: k for k, v in word_to_id.items()}
   vocab_size = len(word_to_id)
   train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size)
   valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size)
@@ -353,7 +354,7 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
     embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
     vocab_file = "/tmp/vocab.txt"
     with gfile.GFile(vocab_file, "w") as f:
-      for i in xrange(len(id_to_word)):
+      for i in range(len(id_to_word)):
         f.write(id_to_word[i] + "\n")
     logging.info("Temporary vocab file: %s", vocab_file)
   else:
@@ -431,10 +432,10 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
     logging.info("Restoring variables.")
     saver.restore(sess, checkpoint_path)

-    for i in xrange(num_batches):
+    for i in range(num_batches):
       logging.info("Batch %d / %d", i, num_batches)
-      train_data_np = train_iterator.next()
-      valid_data_np = valid_iterator.next()
+      train_data_np = next(train_iterator)
+      valid_data_np = next(valid_iterator)
       feed_dict = {
           train_sequence: train_data_np["sequence"],
           train_sequence_length: train_data_np["sequence_length"],

@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -104,7 +105,7 @@ class LSTMGen(snt.AbstractModule):
     sample = tf.tile(
         tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size])
     logging.info('Unrolling over %d steps.', max_sequence_length)
-    for _ in xrange(max_sequence_length):
+    for _ in range(max_sequence_length):
       # Input is sampled word at t-1.
       embedding = tf.nn.embedding_lookup(input_embeddings, sample)
       embedding.shape.assert_is_compatible_with(

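For context, the loop above unrolls the generator one token at a time, feeding the previous sample back in as input, starting from the pad token. A generic, framework-free sketch of that unroll pattern; the `step_fn` callback standing in for the LSTM core plus softmax is hypothetical, not the ScratchGAN model:

import numpy as np

def unroll_sampler(step_fn, batch_size, max_sequence_length, pad_token=0, seed=0):
  """Feeds the previous sample back in at each step, starting from the pad token.

  step_fn(prev_tokens) -> [batch_size, vocab_size] probabilities (hypothetical).
  """
  rng = np.random.default_rng(seed)
  sample = np.full([batch_size], pad_token, dtype=np.int32)
  samples = []
  for _ in range(max_sequence_length):  # Python 3: range, not xrange.
    probs = step_fn(sample)
    # Sample the next token independently for each batch element.
    sample = np.array([rng.choice(len(p), p=p) for p in probs], dtype=np.int32)
    samples.append(sample)
  return np.stack(samples, axis=1)  # [batch_size, max_sequence_length]

# Toy usage with a uniform distribution over a 5-token vocabulary.
uniform = lambda prev: np.full([prev.shape[0], 5], 0.2)
print(unroll_sampler(uniform, batch_size=2, max_sequence_length=4).shape)  # (2, 4)
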
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -67,9 +68,9 @@ def reinforce_loss(disc_logits, gen_logprobs, gamma, decay):
   # Compute cumulative rewards.
   rewards_list = tf.unstack(rewards, axis=1)
   cumulative_rewards = []
-  for t in xrange(sequence_length):
+  for t in range(sequence_length):
     cum_value = tf.zeros(shape=[batch_size])
-    for s in xrange(t, sequence_length):
+    for s in range(t, sequence_length):
       cum_value += np.power(gamma, (s - t)) * rewards_list[s]
     cumulative_rewards.append(cum_value)
   cumulative_rewards = tf.stack(cumulative_rewards, axis=1)

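The double loop above computes the discounted cumulative reward R_t = sum_{s >= t} gamma^(s - t) * r_s for every timestep t. A standalone NumPy sketch (not part of this commit) of the same quantity via the linear-time backward recursion, checked against the quadratic form used in the diff:

import numpy as np

def cumulative_discounted_rewards(rewards, gamma):
  """Computes R_t = sum_{s >= t} gamma^(s - t) * r_s for each timestep t.

  `rewards` has shape [batch_size, sequence_length]; the result has the same
  shape. Runs in O(T) via a backward recursion instead of the O(T^2) loop.
  """
  batch_size, sequence_length = rewards.shape
  cumulative = np.zeros_like(rewards)
  running = np.zeros([batch_size])
  for t in reversed(range(sequence_length)):
    running = rewards[:, t] + gamma * running
    cumulative[:, t] = running
  return cumulative

# Sanity check against the quadratic formulation above.
rng = np.random.default_rng(0)
r = rng.normal(size=(4, 7))
gamma = 0.9
quadratic = np.stack(
    [sum(gamma ** (s - t) * r[:, s] for s in range(t, 7)) for t in range(7)],
    axis=1)
assert np.allclose(cumulative_discounted_rewards(r, gamma), quadratic)
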
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -56,17 +57,17 @@ def _build_vocab(json_data):
     title_tokens = tokenize(title)
     vocab.update(title_tokens)
   # Most common words first.
-  count_pairs = sorted(vocab.items(), key=lambda x: (-x[1], x[0]))
-  words, _ = zip(*count_pairs)
+  count_pairs = sorted(list(vocab.items()), key=lambda x: (-x[1], x[0]))
+  words, _ = list(zip(*count_pairs))
   words = list(words)
   if UNK not in words:
     words = [UNK] + words
-  word_to_id = dict(zip(words, range(len(words))))
+  word_to_id = dict(list(zip(words, list(range(len(words))))))

   # Tokens are now sorted by frequency. There's no guarantee that `PAD` will
   # end up at `PAD_INT` index. Enforce it by swapping whatever token is
   # currently at the `PAD_INT` index with the `PAD` token.
-  word = word_to_id.keys()[word_to_id.values().index(PAD_INT)]
+  word = list(word_to_id.keys())[list(word_to_id.values()).index(PAD_INT)]
   word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
   assert word_to_id[PAD] == PAD_INT

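The last change in this hunk is the one Python 3 genuinely requires: `dict.keys()` and `dict.values()` now return view objects, so indexing them without `list(...)` raises a TypeError. A toy illustration of the PAD swap; the token names and indices are made up, and the real `PAD`/`PAD_INT` constants live in reader.py:

PAD, PAD_INT = "<pad>", 0                       # assumed sentinel token and target index
word_to_id = {"the": 0, "<pad>": 3, "a": 1, "cat": 2}

# keys() and values() iterate the dict in the same order, so keys[i] pairs with values[i].
keys, values = list(word_to_id.keys()), list(word_to_id.values())
word = keys[values.index(PAD_INT)]              # token currently sitting at index PAD_INT
word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
assert word_to_id[PAD] == PAD_INT
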
@@ -120,7 +121,7 @@ def get_raw_data(data_path, dataset, truncate_vocab=20000):
   """
   if dataset not in FILENAMES:
     raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
-        dataset, FILENAMES.keys()))
+        dataset, list(FILENAMES.keys())))
   train_file, valid_file, _ = FILENAMES[dataset]

   train_path = os.path.join(data_path, train_file)
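
Wrapping `FILENAMES.keys()` in `list()` here is cosmetic: the value is only interpolated into the error message, and a Python 3 dict view formats as `dict_keys([...])`. A quick illustration; the dataset name is a made-up example, not necessarily a real key of `FILENAMES`:

FILENAMES = {"emnlp2017": ("train.json", "valid.json", "test.json")}  # hypothetical entry
print("Valid datasets: {}".format(FILENAMES.keys()))        # Valid datasets: dict_keys(['emnlp2017'])
print("Valid datasets: {}".format(list(FILENAMES.keys())))  # Valid datasets: ['emnlp2017']
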
+8 -1
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,7 +35,7 @@ def _get_embedding_initializer(vocab_file, embedding_source, vocab_size):
     embedding_lines = f.readlines()

   # First line contains embedding dim.
-  _, embedding_dim = map(int, embedding_lines[0].split())
+  _, embedding_dim = list(map(int, embedding_lines[0].split()))
   # Get the tokens as strings.
   tokens = [line.split()[0] for line in embedding_lines[1:]]
   # Get the actual embedding matrix.
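
In Python 3, `map()` returns a lazy iterator rather than a list. Tuple unpacking still consumes it directly, so the `list()` wrapper above is the mechanical conversion-style spelling rather than a strict requirement. A small check; the header format is assumed from the surrounding comment:

first_line = "10000 300"   # hypothetical "<num_tokens> <embedding_dim>" header line
_, embedding_dim = map(int, first_line.split())          # works in Python 3 as-is
_, embedding_dim = list(map(int, first_line.split()))    # equivalent, as written in the diff
assert embedding_dim == 300
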
@@ -275,6 +276,12 @@ def make_partially_trainable_embeddings(vocab_file, embedding_source,
   vectors for word representation. In Proceedings of the 2014 conference on
   empirical methods in natural language processing (EMNLP) (pp. 1532-1543).

+  Args:
+    vocab_file: vocabulary file.
+    embedding_source: path to the actual embeddings.
+    vocab_size: number of words in vocabulary.
+    trainable_embedding_size: size of the trainable part of the embeddings.
+
   Returns:
     A matrix of partially pretrained embeddings.
   """