deepmind-research/scratchgan/generators.py

# Lint as: python3
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generators for text data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import sonnet as snt
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from scratchgan import utils
class LSTMGen(snt.AbstractModule):
"""A multi-layer LSTM language model.
Uses tied input/output embedding weights.
"""
  def __init__(self,
               vocab_size,
               feature_sizes,
               max_sequence_length,
               batch_size,
               use_layer_norm,
               trainable_embedding_size,
               input_dropout,
               output_dropout,
               pad_token,
               embedding_source=None,
               vocab_file=None,
               name='lstm_gen'):
    super(LSTMGen, self).__init__(name=name)
    self._feature_sizes = feature_sizes
    self._max_sequence_length = max_sequence_length
    self._vocab_size = vocab_size
    self._batch_size = batch_size
    self._use_layer_norm = use_layer_norm
    self._trainable_embedding_size = trainable_embedding_size
    self._embedding_source = embedding_source
    self._vocab_file = vocab_file
    self._input_dropout = input_dropout
    self._output_dropout = output_dropout
    self._pad_token = pad_token

    if self._embedding_source:
      assert vocab_file

  def _build(self, is_training=True, temperature=1.0):
    input_keep_prob = (1. - self._input_dropout) if is_training else 1.0
    output_keep_prob = (1. - self._output_dropout) if is_training else 1.0
    batch_size = self._batch_size
    max_sequence_length = self._max_sequence_length
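    # Token embeddings: either partially trainable pretrained embeddings
    # loaded from embedding_source, or a fully trainable matrix.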
    if self._embedding_source:
      all_embeddings = utils.make_partially_trainable_embeddings(
          self._vocab_file, self._embedding_source, self._vocab_size,
          self._trainable_embedding_size)
    else:
      all_embeddings = tf.get_variable(
          'trainable_embeddings',
          shape=[self._vocab_size, self._trainable_embedding_size],
          trainable=True)
    _, self._embedding_size = all_embeddings.shape.as_list()
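    # The input and output embeddings are tied to the same matrix; each copy
    # gets its own dropout mask.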
    input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=input_keep_prob)
    output_embeddings = tf.nn.dropout(
        all_embeddings, keep_prob=output_keep_prob)
    out_bias = tf.get_variable(
        'out_bias', shape=[1, self._vocab_size], dtype=tf.float32)
    in_proj = tf.get_variable(
        'in_proj', shape=[self._embedding_size, self._feature_sizes[0]])
    # With more than one layer, the LSTM output has dimension
    # sum(self._feature_sizes), which differs from the input dimension
    # self._feature_sizes[0], so input and output need separate projection
    # matrices.
    if len(self._feature_sizes) > 1:
      out_proj = tf.get_variable(
          'out_proj', shape=[self._embedding_size,
                             sum(self._feature_sizes)])
    else:
      out_proj = in_proj
    encoder_cells = []
    for feature_size in self._feature_sizes:
      encoder_cells += [
          snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
      ]
    encoder_cell = snt.DeepRNN(encoder_cells)
    state = encoder_cell.initial_state(batch_size)
    # Manual unrolling: each step's input is the previous step's sample, so
    # the loop cannot be expressed with tf.nn.dynamic_rnn.
    samples_list, logits_list, logprobs_list, embeddings_list = [], [], [], []
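    # The first input is the PAD token, which effectively acts as a
    # start-of-sequence marker.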
    sample = tf.tile(
        tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size])
    logging.info('Unrolling over %d steps.', max_sequence_length)
    for _ in range(max_sequence_length):
      # Input is sampled word at t-1.
      embedding = tf.nn.embedding_lookup(input_embeddings, sample)
      embedding.shape.assert_is_compatible_with(
          [batch_size, self._embedding_size])
      embedding_proj = tf.matmul(embedding, in_proj)
      embedding_proj.shape.assert_is_compatible_with(
          [batch_size, self._feature_sizes[0]])
      outputs, state = encoder_cell(embedding_proj, state)
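      # Project the LSTM output back to embedding size and score it against
      # the (tied) output embeddings to obtain vocabulary logits.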
      outputs_proj = tf.matmul(outputs, out_proj, transpose_b=True)
      logits = tf.matmul(
          outputs_proj, output_embeddings, transpose_b=True) + out_bias
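      # Sample the next token from the temperature-scaled softmax, keeping its
      # log-probability for the policy-gradient (REINFORCE) generator updates.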
      categorical = tfp.distributions.Categorical(logits=logits / temperature)
      sample = categorical.sample()
      logprobs = categorical.log_prob(sample)
      samples_list.append(sample)
      logits_list.append(logits)
      logprobs_list.append(logprobs)
      embeddings_list.append(embedding)
    # Create an op to retrieve embeddings for the full sequence, useful for
    # testing.
    embeddings = tf.stack(  # pylint: disable=unused-variable
        embeddings_list,
        axis=1,
        name='embeddings')

    sequence = tf.stack(samples_list, axis=1)
    logprobs = tf.stack(logprobs_list, axis=1)
    # The sequence stops after the first occurrence of a PAD token; tokens
    # and log-probs past that point are masked out.
    sequence_length = utils.get_first_occurrence_indices(
        sequence, self._pad_token)
    mask = utils.get_mask_past_symbol(sequence, self._pad_token)
    masked_sequence = sequence * tf.cast(mask, tf.int32)
    masked_logprobs = logprobs * tf.cast(mask, tf.float32)

    return {
        'sequence': masked_sequence,
        'sequence_length': sequence_length,
        'logprobs': masked_logprobs
    }
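

# Example usage: a minimal sketch of building the sampling graph. The
# hyperparameter values below are illustrative assumptions, not settings
# from the ScratchGAN experiments.
#
#   gen = LSTMGen(
#       vocab_size=20000,
#       feature_sizes=[512, 512],  # two LSTM layers
#       max_sequence_length=50,
#       batch_size=64,
#       use_layer_norm=True,
#       trainable_embedding_size=64,
#       input_dropout=0.5,
#       output_dropout=0.5,
#       pad_token=0)  # assumes PAD has id 0 in the vocabulary
#   gen_outputs = gen(is_training=True, temperature=1.0)
#   # gen_outputs['sequence']:        [64, 50] int32 samples, zeroed past PAD.
#   # gen_outputs['sequence_length']: [64] int32 sequence lengths.
#   # gen_outputs['logprobs']:        [64, 50] float32 per-token log-probs.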