mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 04:45:31 +08:00
Move to PY3.
PiperOrigin-RevId: 313217111
This commit is contained in:
committed by
Diego de Las Casas
parent
77886cb179
commit
b105a3646b
+14
-13
@@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python3
|
||||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -109,7 +110,7 @@ def train(config):
|
|||||||
raw_data = reader.get_raw_data(
|
raw_data = reader.get_raw_data(
|
||||||
data_path=config.data_dir, dataset=config.dataset)
|
data_path=config.data_dir, dataset=config.dataset)
|
||||||
train_data, valid_data, word_to_id = raw_data
|
train_data, valid_data, word_to_id = raw_data
|
||||||
id_to_word = {v: k for k, v in word_to_id.iteritems()}
|
id_to_word = {v: k for k, v in word_to_id.items()}
|
||||||
vocab_size = len(word_to_id)
|
vocab_size = len(word_to_id)
|
||||||
max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset]
|
max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset]
|
||||||
logging.info("Vocabulary size: %d", vocab_size)
|
logging.info("Vocabulary size: %d", vocab_size)
|
||||||
@@ -124,8 +125,8 @@ def train(config):
|
|||||||
name="real_sequence")
|
name="real_sequence")
|
||||||
real_sequence_length = tf.placeholder(
|
real_sequence_length = tf.placeholder(
|
||||||
dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length")
|
dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length")
|
||||||
first_batch_np = iterator.next()
|
first_batch_np = next(iterator)
|
||||||
valid_batch_np = iterator_valid.next()
|
valid_batch_np = next(iterator_valid)
|
||||||
|
|
||||||
test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()}
|
test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()}
|
||||||
test_fake_batch = {
|
test_fake_batch = {
|
||||||
@@ -146,7 +147,7 @@ def train(config):
|
|||||||
embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
|
embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
|
||||||
vocab_file = "/tmp/vocab.txt"
|
vocab_file = "/tmp/vocab.txt"
|
||||||
with gfile.GFile(vocab_file, "w") as f:
|
with gfile.GFile(vocab_file, "w") as f:
|
||||||
for i in xrange(len(id_to_word)):
|
for i in range(len(id_to_word)):
|
||||||
f.write(id_to_word[i] + "\n")
|
f.write(id_to_word[i] + "\n")
|
||||||
logging.info("Temporary vocab file: %s", vocab_file)
|
logging.info("Temporary vocab file: %s", vocab_file)
|
||||||
else:
|
else:
|
||||||
@@ -263,17 +264,17 @@ def train(config):
|
|||||||
if latest_ckpt:
|
if latest_ckpt:
|
||||||
saver.restore(sess, latest_ckpt)
|
saver.restore(sess, latest_ckpt)
|
||||||
|
|
||||||
for step in xrange(config.num_steps):
|
for step in range(config.num_steps):
|
||||||
real_data_np = iterator.next()
|
real_data_np = next(iterator)
|
||||||
train_feed = {
|
train_feed = {
|
||||||
real_sequence: real_data_np["sequence"],
|
real_sequence: real_data_np["sequence"],
|
||||||
real_sequence_length: real_data_np["sequence_length"],
|
real_sequence_length: real_data_np["sequence_length"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Update generator and discriminator.
|
# Update generator and discriminator.
|
||||||
for _ in xrange(config.num_disc_updates):
|
for _ in range(config.num_disc_updates):
|
||||||
sess.run(disc_update, feed_dict=train_feed)
|
sess.run(disc_update, feed_dict=train_feed)
|
||||||
for _ in xrange(config.num_gen_updates):
|
for _ in range(config.num_gen_updates):
|
||||||
sess.run(gen_update, feed_dict=train_feed)
|
sess.run(gen_update, feed_dict=train_feed)
|
||||||
|
|
||||||
# Reporting
|
# Reporting
|
||||||
@@ -325,7 +326,7 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
|
|||||||
# Build graph.
|
# Build graph.
|
||||||
train_data, valid_data, word_to_id = reader.get_raw_data(
|
train_data, valid_data, word_to_id = reader.get_raw_data(
|
||||||
data_dir, dataset=dataset)
|
data_dir, dataset=dataset)
|
||||||
id_to_word = {v: k for k, v in word_to_id.iteritems()}
|
id_to_word = {v: k for k, v in word_to_id.items()}
|
||||||
vocab_size = len(word_to_id)
|
vocab_size = len(word_to_id)
|
||||||
train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size)
|
train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size)
|
||||||
valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size)
|
valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size)
|
||||||
@@ -353,7 +354,7 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
|
|||||||
embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
|
embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
|
||||||
vocab_file = "/tmp/vocab.txt"
|
vocab_file = "/tmp/vocab.txt"
|
||||||
with gfile.GFile(vocab_file, "w") as f:
|
with gfile.GFile(vocab_file, "w") as f:
|
||||||
for i in xrange(len(id_to_word)):
|
for i in range(len(id_to_word)):
|
||||||
f.write(id_to_word[i] + "\n")
|
f.write(id_to_word[i] + "\n")
|
||||||
logging.info("Temporary vocab file: %s", vocab_file)
|
logging.info("Temporary vocab file: %s", vocab_file)
|
||||||
else:
|
else:
|
||||||
@@ -431,10 +432,10 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
|
|||||||
logging.info("Restoring variables.")
|
logging.info("Restoring variables.")
|
||||||
saver.restore(sess, checkpoint_path)
|
saver.restore(sess, checkpoint_path)
|
||||||
|
|
||||||
for i in xrange(num_batches):
|
for i in range(num_batches):
|
||||||
logging.info("Batch %d / %d", i, num_batches)
|
logging.info("Batch %d / %d", i, num_batches)
|
||||||
train_data_np = train_iterator.next()
|
train_data_np = next(train_iterator)
|
||||||
valid_data_np = valid_iterator.next()
|
valid_data_np = next(valid_iterator)
|
||||||
feed_dict = {
|
feed_dict = {
|
||||||
train_sequence: train_data_np["sequence"],
|
train_sequence: train_data_np["sequence"],
|
||||||
train_sequence_length: train_data_np["sequence_length"],
|
train_sequence_length: train_data_np["sequence_length"],
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python3
|
||||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -104,7 +105,7 @@ class LSTMGen(snt.AbstractModule):
|
|||||||
sample = tf.tile(
|
sample = tf.tile(
|
||||||
tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size])
|
tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size])
|
||||||
logging.info('Unrolling over %d steps.', max_sequence_length)
|
logging.info('Unrolling over %d steps.', max_sequence_length)
|
||||||
for _ in xrange(max_sequence_length):
|
for _ in range(max_sequence_length):
|
||||||
# Input is sampled word at t-1.
|
# Input is sampled word at t-1.
|
||||||
embedding = tf.nn.embedding_lookup(input_embeddings, sample)
|
embedding = tf.nn.embedding_lookup(input_embeddings, sample)
|
||||||
embedding.shape.assert_is_compatible_with(
|
embedding.shape.assert_is_compatible_with(
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python3
|
||||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -67,9 +68,9 @@ def reinforce_loss(disc_logits, gen_logprobs, gamma, decay):
|
|||||||
# Compute cumulative rewards.
|
# Compute cumulative rewards.
|
||||||
rewards_list = tf.unstack(rewards, axis=1)
|
rewards_list = tf.unstack(rewards, axis=1)
|
||||||
cumulative_rewards = []
|
cumulative_rewards = []
|
||||||
for t in xrange(sequence_length):
|
for t in range(sequence_length):
|
||||||
cum_value = tf.zeros(shape=[batch_size])
|
cum_value = tf.zeros(shape=[batch_size])
|
||||||
for s in xrange(t, sequence_length):
|
for s in range(t, sequence_length):
|
||||||
cum_value += np.power(gamma, (s - t)) * rewards_list[s]
|
cum_value += np.power(gamma, (s - t)) * rewards_list[s]
|
||||||
cumulative_rewards.append(cum_value)
|
cumulative_rewards.append(cum_value)
|
||||||
cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
|
cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python3
|
||||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -56,17 +57,17 @@ def _build_vocab(json_data):
|
|||||||
title_tokens = tokenize(title)
|
title_tokens = tokenize(title)
|
||||||
vocab.update(title_tokens)
|
vocab.update(title_tokens)
|
||||||
# Most common words first.
|
# Most common words first.
|
||||||
count_pairs = sorted(vocab.items(), key=lambda x: (-x[1], x[0]))
|
count_pairs = sorted(list(vocab.items()), key=lambda x: (-x[1], x[0]))
|
||||||
words, _ = zip(*count_pairs)
|
words, _ = list(zip(*count_pairs))
|
||||||
words = list(words)
|
words = list(words)
|
||||||
if UNK not in words:
|
if UNK not in words:
|
||||||
words = [UNK] + words
|
words = [UNK] + words
|
||||||
word_to_id = dict(zip(words, range(len(words))))
|
word_to_id = dict(list(zip(words, list(range(len(words))))))
|
||||||
|
|
||||||
# Tokens are now sorted by frequency. There's no guarantee that `PAD` will
|
# Tokens are now sorted by frequency. There's no guarantee that `PAD` will
|
||||||
# end up at `PAD_INT` index. Enforce it by swapping whatever token is
|
# end up at `PAD_INT` index. Enforce it by swapping whatever token is
|
||||||
# currently at the `PAD_INT` index with the `PAD` token.
|
# currently at the `PAD_INT` index with the `PAD` token.
|
||||||
word = word_to_id.keys()[word_to_id.values().index(PAD_INT)]
|
word = list(word_to_id.keys())[list(word_to_id.values()).index(PAD_INT)]
|
||||||
word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
|
word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
|
||||||
assert word_to_id[PAD] == PAD_INT
|
assert word_to_id[PAD] == PAD_INT
|
||||||
|
|
||||||
@@ -120,7 +121,7 @@ def get_raw_data(data_path, dataset, truncate_vocab=20000):
|
|||||||
"""
|
"""
|
||||||
if dataset not in FILENAMES:
|
if dataset not in FILENAMES:
|
||||||
raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
|
raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
|
||||||
dataset, FILENAMES.keys()))
|
dataset, list(FILENAMES.keys())))
|
||||||
train_file, valid_file, _ = FILENAMES[dataset]
|
train_file, valid_file, _ = FILENAMES[dataset]
|
||||||
|
|
||||||
train_path = os.path.join(data_path, train_file)
|
train_path = os.path.join(data_path, train_file)
|
||||||
|
|||||||
+8
-1
@@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python3
|
||||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -34,7 +35,7 @@ def _get_embedding_initializer(vocab_file, embedding_source, vocab_size):
|
|||||||
embedding_lines = f.readlines()
|
embedding_lines = f.readlines()
|
||||||
|
|
||||||
# First line contains embedding dim.
|
# First line contains embedding dim.
|
||||||
_, embedding_dim = map(int, embedding_lines[0].split())
|
_, embedding_dim = list(map(int, embedding_lines[0].split()))
|
||||||
# Get the tokens as strings.
|
# Get the tokens as strings.
|
||||||
tokens = [line.split()[0] for line in embedding_lines[1:]]
|
tokens = [line.split()[0] for line in embedding_lines[1:]]
|
||||||
# Get the actual embedding matrix.
|
# Get the actual embedding matrix.
|
||||||
@@ -275,6 +276,12 @@ def make_partially_trainable_embeddings(vocab_file, embedding_source,
|
|||||||
vectors for word representation. In Proceedings of the 2014 conference on
|
vectors for word representation. In Proceedings of the 2014 conference on
|
||||||
empirical methods in natural language processing (EMNLP) (pp. 1532-1543).
|
empirical methods in natural language processing (EMNLP) (pp. 1532-1543).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file: vocabulary file.
|
||||||
|
embedding_source: path to the actual embeddings.
|
||||||
|
vocab_size: number of words in vocabulary.
|
||||||
|
trainable_embedding_size: size of the trainable part of the embeddings.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A matrix of partially pretrained embeddings.
|
A matrix of partially pretrained embeddings.
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user