Move to PY3.

PiperOrigin-RevId: 313217111
This commit is contained in:
Cyprien de Masson d'Autume
2020-05-26 18:19:25 +01:00
committed by Diego de Las Casas
parent 77886cb179
commit b105a3646b
5 changed files with 33 additions and 22 deletions
+6 -5
View File
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -56,17 +57,17 @@ def _build_vocab(json_data):
title_tokens = tokenize(title)
vocab.update(title_tokens)
# Most common words first.
count_pairs = sorted(vocab.items(), key=lambda x: (-x[1], x[0]))
words, _ = zip(*count_pairs)
count_pairs = sorted(list(vocab.items()), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
words = list(words)
if UNK not in words:
words = [UNK] + words
word_to_id = dict(zip(words, range(len(words))))
word_to_id = dict(list(zip(words, list(range(len(words))))))
# Tokens are now sorted by frequency. There's no guarantee that `PAD` will
# end up at `PAD_INT` index. Enforce it by swapping whatever token is
# currently at the `PAD_INT` index with the `PAD` token.
word = word_to_id.keys()[word_to_id.values().index(PAD_INT)]
word = list(word_to_id.keys())[list(word_to_id.values()).index(PAD_INT)]
word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
assert word_to_id[PAD] == PAD_INT
@@ -120,7 +121,7 @@ def get_raw_data(data_path, dataset, truncate_vocab=20000):
"""
if dataset not in FILENAMES:
raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
dataset, FILENAMES.keys()))
dataset, list(FILENAMES.keys())))
train_file, valid_file, _ = FILENAMES[dataset]
train_path = os.path.join(data_path, train_file)