# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discriminator networks for text data."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import sonnet as snt
|
|
import tensorflow.compat.v1 as tf
|
|
from scratchgan import utils
|
|
|
|
|
|
class LSTMEmbedDiscNet(snt.AbstractModule):
|
|
"""An LSTM discriminator that operates on word indexes."""
|
|
|
|
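
  # A minimal construction sketch (the hyperparameter values below are
  # hypothetical, not values prescribed by this module):
  #
  #   disc = LSTMEmbedDiscNet(
  #       feature_sizes=[512, 512], vocab_size=5000, use_layer_norm=True,
  #       trainable_embedding_size=64, dropout=0.1, pad_token=0)
  #   logits = disc(sequence, sequence_length)  # [batch_size, max_seq_length]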

  def __init__(self,
               feature_sizes,
               vocab_size,
               use_layer_norm,
               trainable_embedding_size,
               dropout,
               pad_token,
               embedding_source=None,
               vocab_file=None,
               name='LSTMEmbedDiscNet'):
    super(LSTMEmbedDiscNet, self).__init__(name=name)
    self._feature_sizes = feature_sizes
    self._vocab_size = vocab_size
    self._use_layer_norm = use_layer_norm
    self._trainable_embedding_size = trainable_embedding_size
    self._embedding_source = embedding_source
    self._vocab_file = vocab_file
    self._dropout = dropout
    self._pad_token = pad_token
    if self._embedding_source:
      assert vocab_file

  def _build(self, sequence, sequence_length, is_training=True):
    """Connects the module to the graph.

    Args:
      sequence: A [batch_size, max_sequence_length] tensor of int, for
        example the indices of words as sampled by the generator.
      sequence_length: A [batch_size] tensor of int. Length of each sequence.
      is_training: Boolean, False to disable dropout.

    Returns:
      A [batch_size, max_sequence_length] tensor of floats. For each sequence
      in the batch, a logit per timestep that should (hopefully) allow one to
      distinguish whether the value at that timestep is real or generated;
      logits past the first PAD token are masked to zero.
    """
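    # Note: both dimensions of `sequence` must be statically known, since
    # batch_size and max_sequence_length feed the reshape shapes below.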
    batch_size, max_sequence_length = sequence.shape.as_list()
    keep_prob = (1.0 - self._dropout) if is_training else 1.0

    if self._embedding_source:
      all_embeddings = utils.make_partially_trainable_embeddings(
          self._vocab_file, self._embedding_source, self._vocab_size,
          self._trainable_embedding_size)
    else:
      all_embeddings = tf.get_variable(
          'trainable_embedding',
          shape=[self._vocab_size, self._trainable_embedding_size],
          trainable=True)
    _, self._embedding_size = all_embeddings.shape.as_list()
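    # Dropout is applied to the embedding matrix itself, so a dropped entry
    # is dropped for every occurrence of the corresponding word in the batch.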
    input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=keep_prob)
    embeddings = tf.nn.embedding_lookup(input_embeddings, sequence)
    embeddings.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length, self._embedding_size])
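    # Append a position signal to each timestep's embedding, then project the
    # result down to the first LSTM's feature size with a single linear layer
    # shared across all timesteps.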
    position_dim = 8
    embeddings_pos = utils.append_position_signal(embeddings, position_dim)
    embeddings_pos = tf.reshape(
        embeddings_pos,
        [batch_size * max_sequence_length,
         self._embedding_size + position_dim])
    lstm_inputs = snt.Linear(self._feature_sizes[0])(embeddings_pos)
    lstm_inputs = tf.reshape(
        lstm_inputs, [batch_size, max_sequence_length, self._feature_sizes[0]])
    lstm_inputs.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length, self._feature_sizes[0]])
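
    # Stack one LSTM per entry in feature_sizes. The assert further below
    # relies on snt.DeepRNN's default skip connections, which concatenate the
    # outputs of every layer, giving sum(feature_sizes) features per timestep.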
    encoder_cells = []
    for feature_size in self._feature_sizes:
      encoder_cells += [
          snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
      ]
    encoder_cell = snt.DeepRNN(encoder_cells)
    initial_state = encoder_cell.initial_state(batch_size)

    hidden_states, _ = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        inputs=lstm_inputs,
        sequence_length=sequence_length,
        initial_state=initial_state,
        swap_memory=True)

    hidden_states.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length,
         sum(self._feature_sizes)])
    logits = snt.BatchApply(snt.Linear(1))(hidden_states)
    logits.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length, 1])
    logits_flat = tf.reshape(logits, [batch_size, max_sequence_length])

    # Mask out the logits past the first PAD symbol.
    #
    # Note that we still rely on tf.nn.dynamic_rnn taking the sequence_length
    # into account properly, because otherwise the logits at a given timestep
    # could depend on the inputs at other timesteps, including the ones that
    # should be masked.
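    # For example, with pad_token=2, a sequence [4, 7, 2, 2, 5] would give
    # mask [1, 1, 1, 0, 0], assuming utils.get_mask_past_symbol keeps
    # everything up to and including the first PAD.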
    mask = utils.get_mask_past_symbol(sequence, self._pad_token)
    masked_logits_flat = logits_flat * tf.cast(mask, tf.float32)
    return masked_logits_flat