mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 21:15:21 +08:00
Use bibtex format.
PiperOrigin-RevId: 281754451
This commit is contained in:
committed by
Diego de Las Casas
parent
db2393d5a9
commit
d705acd6a7
@@ -0,0 +1,174 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for parsing text files."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
from tensorflow.io import gfile
|
||||
|
||||
# sequences: [N, MAX_TOKENS_SEQUENCE] array of int32
|
||||
# lengths: [N, 2] array of int32, such that
|
||||
# lengths[i, 0] is the number of non-pad tokens in sequences[i, :]
|
||||
# TODO(cyprien): handle PTB
|
||||
FILENAMES = {
|
||||
"emnlp2017": ("train.json", "valid.json", "test.json"),
|
||||
}
|
||||
|
||||
# EMNLP2017 sentences have max length 50, add one for a PAD token so that all
|
||||
# sentences end with PAD.
|
||||
MAX_TOKENS_SEQUENCE = {"emnlp2017": 52}
|
||||
|
||||
UNK = "<unk>"
|
||||
PAD = " "
|
||||
|
||||
PAD_INT = 0
|
||||
|
||||
|
||||
def tokenize(sentence):
|
||||
"""Split a string into words."""
|
||||
return sentence.split(" ") + [PAD]
|
||||
|
||||
|
||||
def _build_vocab(json_data):
|
||||
"""Builds full vocab from json data."""
|
||||
vocab = collections.Counter()
|
||||
for sentence in json_data:
|
||||
tokens = tokenize(sentence["s"])
|
||||
vocab.update(tokens)
|
||||
for title in sentence["t"]:
|
||||
title_tokens = tokenize(title)
|
||||
vocab.update(title_tokens)
|
||||
# Most common words first.
|
||||
count_pairs = sorted(vocab.items(), key=lambda x: (-x[1], x[0]))
|
||||
words, _ = zip(*count_pairs)
|
||||
words = list(words)
|
||||
if UNK not in words:
|
||||
words = [UNK] + words
|
||||
word_to_id = dict(zip(words, range(len(words))))
|
||||
|
||||
# Tokens are now sorted by frequency. There's no guarantee that `PAD` will
|
||||
# end up at `PAD_INT` index. Enforce it by swapping whatever token is
|
||||
# currently at the `PAD_INT` index with the `PAD` token.
|
||||
word = word_to_id.keys()[word_to_id.values().index(PAD_INT)]
|
||||
word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
|
||||
assert word_to_id[PAD] == PAD_INT
|
||||
|
||||
return word_to_id
|
||||
|
||||
|
||||
def string_sequence_to_sequence(string_sequence, word_to_id):
|
||||
result = []
|
||||
for word in string_sequence:
|
||||
if word in word_to_id:
|
||||
result.append(word_to_id[word])
|
||||
else:
|
||||
result.append(word_to_id[UNK])
|
||||
return result
|
||||
|
||||
|
||||
def _integerize(json_data, word_to_id, dataset):
|
||||
"""Transform words into integers."""
|
||||
sequences = np.full((len(json_data), MAX_TOKENS_SEQUENCE[dataset]),
|
||||
word_to_id[PAD], np.int32)
|
||||
sequence_lengths = np.zeros(shape=(len(json_data)), dtype=np.int32)
|
||||
for i, sentence in enumerate(json_data):
|
||||
sequence_i = string_sequence_to_sequence(
|
||||
tokenize(sentence["s"]), word_to_id)
|
||||
sequence_lengths[i] = len(sequence_i)
|
||||
sequences[i, :sequence_lengths[i]] = np.array(sequence_i)
|
||||
return {
|
||||
"sequences": sequences,
|
||||
"sequence_lengths": sequence_lengths,
|
||||
}
|
||||
|
||||
|
||||
def get_raw_data(data_path, dataset, truncate_vocab=20000):
|
||||
"""Load raw data from data directory "data_path".
|
||||
|
||||
Reads text files, converts strings to integer ids,
|
||||
and performs mini-batching of the inputs.
|
||||
|
||||
Args:
|
||||
data_path: string path to the directory where simple-examples.tgz has been
|
||||
extracted.
|
||||
dataset: one of ["emnlp2017"]
|
||||
truncate_vocab: int, number of words to keep in the vocabulary.
|
||||
|
||||
Returns:
|
||||
tuple (train_data, valid_data, vocabulary) where each of the data
|
||||
objects can be passed to iterator.
|
||||
|
||||
Raises:
|
||||
ValueError: dataset not in ["emnlp2017"].
|
||||
"""
|
||||
if dataset not in FILENAMES:
|
||||
raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
|
||||
dataset, FILENAMES.keys()))
|
||||
train_file, valid_file, _ = FILENAMES[dataset]
|
||||
|
||||
train_path = os.path.join(data_path, train_file)
|
||||
valid_path = os.path.join(data_path, valid_file)
|
||||
|
||||
with gfile.GFile(train_path, "r") as json_file:
|
||||
json_data_train = json.load(json_file)
|
||||
with gfile.GFile(valid_path, "r") as json_file:
|
||||
json_data_valid = json.load(json_file)
|
||||
|
||||
word_to_id = _build_vocab(json_data_train)
|
||||
logging.info("Full vocab length: %d", len(word_to_id))
|
||||
# Assume the vocab is sorted by frequency.
|
||||
word_to_id_truncated = {
|
||||
k: v for k, v in word_to_id.items() if v < truncate_vocab
|
||||
}
|
||||
logging.info("Truncated vocab length: %d", len(word_to_id_truncated))
|
||||
|
||||
train_data = _integerize(json_data_train, word_to_id_truncated, dataset)
|
||||
valid_data = _integerize(json_data_valid, word_to_id_truncated, dataset)
|
||||
return train_data, valid_data, word_to_id_truncated
|
||||
|
||||
|
||||
def iterator(raw_data, batch_size, random=False):
|
||||
"""Looping iterators on the raw data."""
|
||||
sequences = raw_data["sequences"]
|
||||
sequence_lengths = raw_data["sequence_lengths"]
|
||||
|
||||
num_examples = sequences.shape[0]
|
||||
indice_range = np.arange(num_examples)
|
||||
|
||||
if random:
|
||||
while True:
|
||||
indices = np.random.choice(indice_range, size=batch_size, replace=True)
|
||||
yield {
|
||||
"sequence": sequences[indices, :],
|
||||
"sequence_length": sequence_lengths[indices],
|
||||
}
|
||||
else:
|
||||
start = 0
|
||||
while True:
|
||||
sequence = sequences[start:(start + batch_size), :]
|
||||
sequence_length = sequence_lengths[start:(start + batch_size)]
|
||||
start += batch_size
|
||||
if start + batch_size > num_examples:
|
||||
start = (start + batch_size) % num_examples
|
||||
yield {
|
||||
"sequence": sequence,
|
||||
"sequence_length": sequence_length,
|
||||
}
|
||||
Reference in New Issue
Block a user