From 67455ed367af15440d88ff963eadde92c267f996 Mon Sep 17 00:00:00 2001 From: Luyu Wang Date: Fri, 6 Aug 2021 17:39:22 +0100 Subject: [PATCH] Release graph2text transformer modules. PiperOrigin-RevId: 389194661 --- wikigraphs/wikigraphs/model/__init__.py | 3 + wikigraphs/wikigraphs/model/embedding.py | 381 +++++++++++++ wikigraphs/wikigraphs/model/transformer.py | 522 +++++++++++++++++ .../wikigraphs/model/transformer_block.py | 529 ++++++++++++++++++ .../wikigraphs/model/transformer_test.py | 373 ++++++++++++ 5 files changed, 1808 insertions(+) create mode 100644 wikigraphs/wikigraphs/model/embedding.py create mode 100644 wikigraphs/wikigraphs/model/transformer.py create mode 100644 wikigraphs/wikigraphs/model/transformer_block.py create mode 100644 wikigraphs/wikigraphs/model/transformer_test.py diff --git a/wikigraphs/wikigraphs/model/__init__.py b/wikigraphs/wikigraphs/model/__init__.py index 6bc8744..991e3c1 100644 --- a/wikigraphs/wikigraphs/model/__init__.py +++ b/wikigraphs/wikigraphs/model/__init__.py @@ -28,4 +28,7 @@ # # ============================================================================== """WikiGraphs model modules.""" +from . import embedding from . import graph_net +from . import transformer +from . import transformer_block diff --git a/wikigraphs/wikigraphs/model/embedding.py b/wikigraphs/wikigraphs/model/embedding.py new file mode 100644 index 0000000..bc663ce --- /dev/null +++ b/wikigraphs/wikigraphs/model/embedding.py @@ -0,0 +1,381 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Transformer embedding modules.""" + +from typing import List, Optional + +import haiku as hk +from haiku import initializers as init +import jax +import jax.numpy as jnp +import jraph + +from wikigraphs.model import graph_net as gn + + +def get_pos_start(timesteps: int, batch_size: int) -> jnp.ndarray: + """Find the right slice of positional embeddings for incremental sampling.""" + pos_start = hk.get_state( + 'cache_progress_idx', [batch_size], dtype=jnp.int32, init=jnp.zeros) + hk.set_state('cache_progress_idx', pos_start + timesteps) + return pos_start + + +class SinusoidalPositionEmbedding(hk.Module): + """Position encoding, using mixture of sinusoidal signals.""" + + def __init__(self, + dim: int, + cache_steps: int = 0, + reverse_order: bool = False, + clamp_len: Optional[int] = None, + name: Optional[str] = None): + """Initialize a SinusoidalPositionEmbedding. + + Args: + dim: Embedding dimension. + cache_steps: The length of the memory. + reverse_order: If set to True, position index is reversed. + clamp_len: position beyond clamp_len will be reset to clamp_len, default + to not clamping. + name: Optional name for this Haiku module. + """ + super(SinusoidalPositionEmbedding, self).__init__(name=name) + self._dim = dim + self._cache_steps = cache_steps + self._reverse_order = reverse_order + self._clamp_len = clamp_len + self._inv_freq = 1.0 / ( + 10000 ** (jnp.arange(0, dim, 2).astype(jnp.float32) / dim)) + + def __call__(self, timesteps: int, batch_size: int) -> jnp.ndarray: + """Computes the sinusoidal position embedding. + + Args: + timesteps: The length of the sequence. + batch_size: The size of the batch. + + Returns: + Sinusoidal position embedding. + """ + full_length = timesteps + self._cache_steps + + if self._reverse_order: + positions = jnp.arange(full_length - 1, -1, -1) + positions = jnp.repeat(positions[None, :], batch_size, axis=0) + else: + if self._cache_steps > 0: + positions = (get_pos_start(timesteps, batch_size)[:, None] + + jnp.arange(timesteps)[None, :]) + else: + positions = jnp.arange(0, full_length) + positions = jnp.repeat(positions[None, :], batch_size, axis=0) + + if self._clamp_len is not None: + positions = jnp.minimum(positions, self._clamp_len) + + scaled_time = positions[:, :, None] * self._inv_freq[None, None, :] + return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=2) + + +def relative_shift(x: jnp.ndarray) -> jnp.ndarray: + """Shift the relative logits.""" + x_shape = list(x.shape) + x = jnp.pad(x, [[0, 0], [0, 0], [0, 0], [1, 0]]) + x = jnp.reshape( + x, [x_shape[0], x_shape[1], x_shape[3] + 1, x_shape[2]])[:, :, 1:, :] + x = jnp.reshape(x, x_shape) + return x + + +class RelativePositionEmbedding(hk.Module): + """Position encoding, using relative positions than absolute positions.""" + + def __init__(self, + dim: int, + dropout_rate: float, + r_w_bias: jnp.ndarray, + r_r_bias: jnp.ndarray, + init_scale: float = 0.02, + clamp_len: Optional[int] = None, + name: Optional[str] = None): + """Initialize a RelativePositionEmbedding. + + Args: + dim: Embedding dimension. + dropout_rate: dropout rate. + r_w_bias: global content bias. + r_r_bias: global positional bias. + init_scale: the initialization scale of the RandomNormal used for the + linear layer. + clamp_len: position beyond clamp_len will be reset to clamp_len, default + to not clamping. + name: Optional name for this Haiku module. + """ + super(RelativePositionEmbedding, self).__init__(name=name) + self._dim = dim + self._dropout_rate = dropout_rate + self._r_w_bias = r_w_bias + self._r_r_bias = r_r_bias + self._init_scale = init_scale + self._sinusoidal_pos_emb = SinusoidalPositionEmbedding( + dim=dim, + reverse_order=True, + clamp_len=clamp_len, + name=name) + + def __call__(self, q: jnp.ndarray, k: jnp.ndarray) -> jnp.ndarray: + """Computes the relative position embedding. + + Args: + q: The query. + k: The key. + + Returns: + Relative position embedding. + """ + # Use key instead of query to obtain the length. + batch_size, key_length, num_heads, head_dim = list(k.shape) + # Content based addressing and global content bias + content_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_w_bias, k) + + # Relative position encoding + positional_encodings = self._sinusoidal_pos_emb(key_length, batch_size) + positional_encodings = hk.dropout(hk.next_rng_key(), self._dropout_rate, + positional_encodings) + rel_pos_emb = hk.Conv1D( + output_channels=self._dim, kernel_shape=1, with_bias=False, + w_init=init.RandomNormal(stddev=self._init_scale))(positional_encodings) + rel_pos_emb = jnp.reshape(rel_pos_emb, [ + batch_size, key_length, num_heads, head_dim]) + + # Content dependent positional bias and global positional bias + rel_pos_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_r_bias, + rel_pos_emb) + rel_pos_score = relative_shift(rel_pos_score) + assert content_score.shape == rel_pos_score.shape + return content_score + rel_pos_score + + +def hierarchical_logprobs( + logits: jnp.ndarray, + class_logits: jnp.ndarray, + cutoffs: List[int]) -> jnp.ndarray: + """Hierarchical log-probs for adaptive softmax.""" + sizes = [y - x for x, y in zip(cutoffs[:-1], cutoffs[1:])] + num_tails = len(sizes) - 1 + split_logits = jnp.split(logits, cutoffs[1:-1], axis=-1) + all_head_logits = jnp.concatenate([split_logits[0], class_logits], -1) + # Mask out item 0, the NULL token + all_head_logits += jnp.concatenate( + [jnp.ones([1], dtype=logits.dtype) * -10, + jnp.zeros([sizes[0] + num_tails - 1], dtype=logits.dtype)], 0) + all_head_logprobs = jax.nn.log_softmax(all_head_logits) + head_logprobs, class_logprobs = jnp.split(all_head_logprobs, + [sizes[0]], axis=-1) + tail_logprobs = [] + for i, tail_size in enumerate(sizes[1:]): # pylint: disable=unused-variable + tail_logprobs += [jax.nn.log_softmax(split_logits[i + 1]) + + class_logprobs[..., [i]]] + return jnp.concatenate([head_logprobs] + tail_logprobs, -1) + + +class AdaptiveSoftmaxEmbedding(hk.Module): + """Adaptive inputs and softmax (https://arxiv.org/abs/1809.10853).""" + + def __init__(self, + dim: int, + vocab_size: int, + cutoffs: List[int], + tail_shrink_factor: int = 4, + hierarchical: bool = True, + init_std: float = 0.02, + init_proj_std: float = 0.01, + dtype: jnp.dtype = jnp.float32, + name: Optional[str] = None): + """Initialize a AdaptiveSoftmaxEmbedding. + + Args: + dim: dimensionality of the hidden space. + vocab_size: the size of the vocabulary. + cutoffs: the cutoff indices of the vocabulary used for the adaptive + softmax embedding. + tail_shrink_factor: how many times to shrink the hidden dimensionality + for low-frequency vocabulary after each cutoff. + hierarchical: whether to use hierarchical softmax. + init_std: standard deviation of the Normal distribution used to initialize + the embedding weights. + init_proj_std: standard deviation of the Normal distribution used to + initialize the projection weights. + dtype: Optional data type default to jnp.float32. + name: Optional name for this Haiku module. + """ + super(AdaptiveSoftmaxEmbedding, self).__init__(name=name) + self._hidden_size = dim + self._vocab_size = vocab_size + self._cutoffs = [0] + list(cutoffs) + [self._vocab_size] + self._tail_shrink_factor = tail_shrink_factor + self._hierarchical = hierarchical + self._dtype = dtype + self._embeddings = [] + self._projections = [] + + self._bias = hk.get_parameter( + 'bias', [self._vocab_size], dtype=self._dtype, init=jnp.zeros) + + l_cutoffs = self._cutoffs[:-1] + r_cutoffs = self._cutoffs[1:] + for i, (l_cutoff, r_cutoff) in enumerate(zip(l_cutoffs, r_cutoffs)): + hidden_size = self._hidden_size // (self._tail_shrink_factor ** i) + embedding = hk.get_parameter( + f'embeddings_{l_cutoff}_{r_cutoff}', + [r_cutoff - l_cutoff, hidden_size], + dtype=self._dtype, + init=hk.initializers.RandomNormal(stddev=init_std)) + self._embeddings += [embedding] + if self._tail_shrink_factor != 1: + projection = hk.get_parameter( + f'projection_{l_cutoff}_{r_cutoff}', + [hidden_size, self._hidden_size], + dtype=self._dtype, + init=hk.initializers.RandomNormal(stddev=init_proj_std)) + self._projections += [projection] + + if self._tail_shrink_factor != 1: + self._output_projection = hk.get_parameter( + 'output_head_projection', + [self._hidden_size, self._hidden_size], + dtype=self._dtype, + init=hk.initializers.RandomNormal(stddev=init_proj_std)) + + if self._hierarchical: + self._class_weights = hk.get_parameter( + 'tail_class_weights', + [self._hidden_size, len(cutoffs)], + init=hk.initializers.RandomNormal(stddev=init_std)) + self._class_bias = hk.get_parameter( + 'tail_class_bias', + [len(cutoffs)], + dtype=self._dtype, + init=jnp.zeros) + + @hk.transparent + def build_embeddings(self): + """Builds input embeddings.""" + if self._projections: + embedding_mat = [ + jnp.dot(emb, proj) for emb, proj in zip(self._embeddings, + self._projections)] + else: + embedding_mat = self._embeddings + input_embeddings = jnp.concatenate(embedding_mat, 0) + return input_embeddings + + @hk.transparent + def build_output_embeddings(self): + """Builds separate output embeddings.""" + if self._projections: + projections = [self._output_projection] + self._projections[1:] + embedding_mat = [jnp.dot(emb, proj) + for emb, proj in zip(self._embeddings, projections)] + else: + embedding_mat = self._embeddings + output_embeddings = jnp.concatenate(embedding_mat, 0) + return jnp.transpose(output_embeddings) + + def embed_input(self, input_tokens: jnp.ndarray) -> jnp.ndarray: + """Embeds the input.""" + assert jnp.issubdtype(input_tokens.dtype, jnp.integer) + input_embeddings = self.build_embeddings() + embedded_inputs = input_embeddings[input_tokens] + return embedded_inputs * self._hidden_size ** 0.5 + + def embed_output(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Outputs logits.""" + output_embs = self.build_output_embeddings() + logits = jnp.einsum('btd,dv->btv', inputs, output_embs) + self._bias + if self._hierarchical: + class_logits = jnp.dot(inputs, self._class_weights) + self._class_bias + logprobs = hierarchical_logprobs(logits, class_logits, self._cutoffs) + return logprobs + else: + return logits + + +class GraphEmbeddingModel(hk.Module): + """A single graph network for embedding graph data.""" + + def __init__(self, + embed_dim: int, + num_layers: int, + msg_hidden_size_factor: int = 2, + use_layer_norm: bool = False, + name: Optional[str] = None): + """Constructor. + + Args: + embed_dim: node embedding size. + num_layers: number of message passing layers to use. + msg_hidden_size_factor: size of the message network hiddens as a factor + of embed_dim. + use_layer_norm: whether to apply layer norm on node updates. + name: optional name for this module. + """ + super().__init__(name=name) + self._embed_dim = embed_dim + self._num_layers = num_layers + self._msg_hidden_size_factor = msg_hidden_size_factor + self._use_layer_norm = use_layer_norm + + def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: + """Compute embeddings for each node in the graphs. + + Args: + graphs: a set of graphs batched into a single graph. The nodes and edges + are represented as feature tensors. + + Returns: + graphs: new graph with node embeddings updated (shape [n_nodes, + embed_dim]). + """ + nodes = hk.Linear(self._embed_dim)(graphs.nodes) + edges = hk.Linear(self._embed_dim)(graphs.edges) + + nodes = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)( + jax.nn.gelu(nodes)) + edges = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)( + jax.nn.gelu(edges)) + + graphs = graphs._replace(nodes=nodes, edges=edges) + graphs = gn.SimpleGraphNet( + num_layers=self._num_layers, + msg_hidden_size_factor=self._msg_hidden_size_factor, + layer_norm=self._use_layer_norm)(graphs) + return graphs diff --git a/wikigraphs/wikigraphs/model/transformer.py b/wikigraphs/wikigraphs/model/transformer.py new file mode 100644 index 0000000..7362ff7 --- /dev/null +++ b/wikigraphs/wikigraphs/model/transformer.py @@ -0,0 +1,522 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Jax implementation of the Transformer-XL model.""" + +from typing import Dict, List, Optional, Tuple + +import haiku as hk +from haiku import initializers as init +import jax +import jax.numpy as jnp +import jraph +import numpy as np + +from wikigraphs.model import transformer_block +from wikigraphs.model.embedding import AdaptiveSoftmaxEmbedding +from wikigraphs.model.embedding import GraphEmbeddingModel + + +# For WikiText-103 +DEFAULT_CUTOFFS = (20000 + 1, 40000 + 1, 200000 + 1) + + +def sequence_prediction_metrics( + logits: jnp.ndarray, + labels: jnp.ndarray, + mask: Optional[jnp.ndarray] = None + ) -> Dict[str, float]: + """Compute the metrics for sequence prediction. + + Args: + logits: [B, T, V] array of logits. + labels: [B, T] array of labels. + mask: [B, T] array of binary masks, if provided. + + Returns: + metrics: a dictionary of metrics. + """ + vocab_size = logits.shape[-1] + logps = jax.nn.log_softmax(logits) + labels_one_hot = hk.one_hot(labels, vocab_size) + class_logps = jnp.sum(logps * labels_one_hot, axis=-1) + prediction_correct = jnp.argmax(logits, axis=-1) == labels + if mask is not None: + masked_logps = mask * class_logps + total_count = jnp.sum(mask) + tokens_correct = jnp.sum(prediction_correct * mask) + seq_correct = jnp.all( + jnp.logical_or(prediction_correct, jnp.logical_not(mask)), axis=-1) + else: + masked_logps = class_logps + total_count = np.prod(class_logps.shape) + tokens_correct = jnp.sum(prediction_correct) + seq_correct = jnp.all(prediction_correct, axis=-1) + + token_accuracy = tokens_correct.astype(jnp.float32) / total_count + seq_accuracy = jnp.mean(seq_correct) + log_probs = jnp.mean(jnp.sum(masked_logps, axis=-1)) + total_loss = -jnp.sum(masked_logps) + loss = total_loss / total_count + return dict( + loss=loss, + total_loss=total_loss, + total_count=total_count, + token_accuracy=token_accuracy, + seq_accuracy=seq_accuracy, + log_probs=log_probs, + ) + + +class TransformerXL(hk.Module): + """TransformerXL language model with memory using GPT2 blocks. + + TransformerXL: https://arxiv.org/abs/1901.02860 + GPT-2: http://www.persagen.com/files/misc/radford2019language.pdf + """ + + def __init__(self, + vocab_size: int = 256, + emb_dim: int = 256, + num_layers: int = 10, + num_heads: int = 8, + dropout_prob: float = 0.1, + dropout_attn_prob: float = 0.0, + self_att_init_scale: float = 0.02, + dense_init_scale: float = 0.02, + dense_dim: int = 2100, + cutoffs: List[int] = DEFAULT_CUTOFFS, + tail_shrink_factor: int = 1, + relative_pos_clamp_len: Optional[int] = None, + name: Optional[str] = None): + """Initialize a TransformerXL. + + Args: + vocab_size: the size of the vocabulary. + emb_dim: the dimensionality of the embeddings. + num_layers: number of transformer blocks. + num_heads: number of attention heads. + dropout_prob: dropout probability. + dropout_attn_prob: dropout probability of the attention module. + self_att_init_scale: the initialization scale of the VarianceScaling + used for the linear layer in the attention module. + dense_init_scale: the initialization scale of the VarianceScaling + used for the linear layer in the feedforward module. + dense_dim: feature size of the feedforward block. + cutoffs: the cutoff indices of the vocabulary used for the adaptive + softmax embedding. + tail_shrink_factor: how many times to shrink the hidden dimensionality + for low-frequency vocabulary after each cutoff in the adaptive softmax + embedding. + relative_pos_clamp_len: clamp length of the relative position embeddings. + name: Optional name for this Haiku module. + """ + super().__init__(name=name) + self._vocab_size = vocab_size + self._emb_dim = emb_dim + self._num_layers = num_layers + self._num_heads = num_heads + self._dropout_prob = dropout_prob + self._dropout_attn_prob = dropout_attn_prob + self._self_att_init_scale = self_att_init_scale + self._dense_init_scale = dense_init_scale + self._dense_dim = dense_dim + self._relative_pos_clamp_len = relative_pos_clamp_len + self._io_emb = AdaptiveSoftmaxEmbedding( + emb_dim, vocab_size, cutoffs=cutoffs, + tail_shrink_factor=tail_shrink_factor) + + def __call__(self, + x: jnp.ndarray, + mask: Optional[jnp.ndarray] = None, + is_training: bool = True, + should_reset: Optional[jnp.ndarray] = None, + cache_steps: int = 0, + extra: Optional[jnp.ndarray] = None, + extra_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray: + """Computes the outputs of the TransformerXL. + + Args: + x: [batch, timesteps]. Inputs at time step t. + mask: [batch, timesteps]. It indicates what tokens to be predicted. In + other words it corresponds to non-pad tokens in x_{t+1}. + is_training: whether the current stage is training or not. + should_reset: reset marker [batch, timesteps]. + cache_steps: number of timesteps in the cache. + extra: if provided should be extra key-value input + [batch, extra_timesteps, in_dim]. + extra_mask: if provided should be the mask for extra key-value input, + [batch, extra_timesteps]. + + Returns: + output: transformer output [batch, timesteps]. + """ + if cache_steps == 0: + cache_steps = x.shape[1] + if should_reset is None: + should_reset = jnp.where(x == 1, 1, 0) + h = self._io_emb.embed_input(x) + + if mask is not None: + attention_mask = mask[:, None, None, :] + else: + attention_mask = None + + head_dim = self._emb_dim // self._num_heads + assert self._emb_dim % self._num_heads == 0, 'Head dim should be an int.' + + # Biases for relative position embedding shared across all layers + r_w_bias = hk.get_parameter( + 'r_w_bias', [1, 1, self._num_heads, head_dim], + init=init.RandomNormal(stddev=self._self_att_init_scale)) + r_r_bias = hk.get_parameter( + 'r_r_bias', [1, 1, self._num_heads, head_dim], + init=init.RandomNormal(stddev=self._self_att_init_scale)) + + for i in range(self._num_layers): + if mask is not None: + h *= mask[:, :, None] + h = transformer_block.GPT2Block( + r_w_bias=r_w_bias, + r_r_bias=r_r_bias, + causal=True, + dense_dim=self._dense_dim, + dropout_prob=self._dropout_prob, + dropout_attn_prob=self._dropout_attn_prob, + num_heads=self._num_heads, + self_att_init_scale=self._self_att_init_scale, + dense_init_scale=self._dense_init_scale, + relative_pos_clamp_len=self._relative_pos_clamp_len, + name='transformer_block_{}'.format(i), + )( + h, mask=attention_mask, is_training=is_training, + should_reset=should_reset, cache_steps=cache_steps, + extra=extra, extra_mask=extra_mask) + + if mask is not None: + h *= mask[:, :, None] + return self._io_emb.embed_output(h) + + def loss(self, + inputs: jnp.ndarray, + labels: jnp.ndarray, + mask: Optional[jnp.ndarray] = None, + is_training: bool = True, + should_reset: Optional[jnp.ndarray] = None, + cache_steps: int = 0, + extra: Optional[jnp.ndarray] = None, + extra_mask: Optional[jnp.ndarray] = None + ) -> Tuple[float, Dict[str, float]]: + """Computes the loss of the TransformerXL. + + Args: + inputs: [batch, timesteps]. + labels: [batch, timesteps]. + mask: [batch, timesteps]. It indicates what tokens to be predicted. In + other words it corresponds to non-pad tokens in the `labels`. + is_training: whether the current stage is training or not. + should_reset: reset marker [batch, timesteps]. + cache_steps: number of timesteps in the cache. + extra: if provided should be extra key-value input + [batch, extra_timesteps, in_dim]. + extra_mask: if provided should be the mask for extra key-value input, + [batch, extra_timesteps]. + + Returns: + output: loss and a dict containing metrics. + """ + # [B, T, V] + logits = self(inputs, mask=mask, is_training=is_training, + should_reset=should_reset, cache_steps=cache_steps, + extra=extra, extra_mask=extra_mask) + + metrics = sequence_prediction_metrics(logits, labels, mask) + return metrics['loss'], metrics + + +def repeat_rows(a: jnp.ndarray, repeats: int, out_length: int) -> jnp.ndarray: + """Repeat rows of input tensor a. + + Output is + [a[0], + a[0], + ... + a[0], # A total of repeats[0] copies of a[0]. + a[1], + a[1], + ..., + a[1], # A total of repeats[1] copies of a[1]. + ... + a[n-1]], # A total of repeats[n-1] copies of a[n-1]. + + Args: + a: [n_rows, ...] input tensor. + repeats: [n_rows] int tensor, the number of repeats for each row. + out_length: number of rows in the output, it should be the same as + sum(repeats), provided to be static for jit. + + Returns: + out: [out_length, ...] output tensor. + """ + a = jnp.asarray(a) + n = a.shape[0] + assert n == repeats.size + chunk_start = jnp.cumsum(repeats) + idx = jnp.sum(jnp.arange(out_length)[:, None] >= chunk_start[None, :], + axis=-1) + return a[idx] + + +def unpack_and_pad( + packed: jnp.ndarray, + split_sizes: jnp.ndarray, + pad_size: int, + pad_value: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Unpack and pad tensors to a standard size. + + Args: + packed: a [total_size, ...] tensor, which contains n individual tensors + concatenated along the 0-th axis. + split_sizes: size [n] int tensor, size of each individual tensor. + pad_size: size for each split to pad to. + pad_value: the value to use for padding. + + Returns: + tensors: [n, pad_size, ...] tensor, tensors[i] is the i-th individual tensor + padded to pad_size length. + mask: [n, pad_size] mask tensor indicating which value is padded. + """ + in_shape = list(packed.shape) + total_size = in_shape[0] + n_splits = split_sizes.shape[0] + idx = jnp.arange(pad_size) + masks = split_sizes[:, None] > idx[None, :] + + out_shape = in_shape[:] + out_shape[0] = n_splits * pad_size + out = jnp.full(out_shape, pad_value, dtype=packed.dtype) + # Index for the rows of `packed`: + # Define split_start[k] = sum_{i=0}^{k-1} split_sizes[i], which is the + # starting index of split k. So if split_start[k] <= i < split_start[k+1] + # then index belongs to split k. We therefore have: + # idx[i] = k * pad_size + i - split_start[k] + cumsum = jnp.concatenate([jnp.array([0], dtype=split_sizes.dtype), + jnp.cumsum(split_sizes)[:-1]]) + idx = jnp.arange(total_size) + idx += repeat_rows(jnp.arange(n_splits), split_sizes, total_size) * pad_size + idx -= repeat_rows(cumsum, split_sizes, total_size) + out = jax.ops.index_update(out, idx, packed) + out = out.reshape([n_splits, pad_size] + out_shape[1:]) + return out, masks + + +class Graph2TextTransformer(hk.Module): + """A graph2text TransformerXL model. + + It embeds the graph with a simple graph neural network model, and passes the + graph embeddings to the TransformerXL model, which are presented as the extra + inputs to attend to in addition to the text embeddings inputs. + """ + + def __init__(self, + *transformer_args, + gnn_embed_dim: int = 128, + gnn_num_layers: int = 5, + gnn_layer_norm: bool = False, + name: Optional[str] = None, + **transformer_kwargs): + """Constructor. + + Args: + *transformer_args: args for the transformer module. + gnn_embed_dim: node embedding size. + gnn_num_layers: number of message passing layers to use. + gnn_layer_norm: whether to use layer norm in the GNN. + name: optional name for this module. + **transformer_kwargs: kwargs for the transformer module. + """ + super().__init__(name=name) + self._transformer = TransformerXL(*transformer_args, **transformer_kwargs) + self._gnn = GraphEmbeddingModel( + embed_dim=gnn_embed_dim, + num_layers=gnn_num_layers, + use_layer_norm=gnn_layer_norm) + + def _encode_graphs(self, + graphs: jraph.GraphsTuple, + pad_n_nodes: Optional[int] = None, + padded: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Encode graphs so that it can be used in the transformer. + + Args: + graphs: a graph structured using jraph.GraphsTuple. + pad_n_nodes: size for each node to pad to. + padded: Whether to pad each graph to the same number of nodes. + + Returns: + tensors: unpacked and padded graph nodes. + mask: mask tensor indicating which value is padded. + """ + graphs = self._gnn(graphs) + if pad_n_nodes is None: + pad_n_nodes = graphs.n_node.max() + out, mask = unpack_and_pad(graphs.nodes, graphs.n_node, pad_n_nodes) + if padded: + # Remove the padding graph from the batch + return out[:-1], mask[:-1] + else: + return out, mask + + def __call__(self, + graphs: jraph.GraphsTuple, + pad_n_nodes: int, + batch_padded: bool, + *args, **kwargs): + """Computes the outputs of the graph2text TransformerXL. + + Args: + graphs: a graph structured using graph_net.Graph. + pad_n_nodes: size for each node to pad to. + batch_padded: whether the graph batch is padded or not. + *args: args to the TransformerXL model. + **kwargs: kwargs to the TransformerXL model. + + Returns: + output: transformer output [batch, timesteps]. + """ + extra, extra_mask = self._encode_graphs(graphs, pad_n_nodes, batch_padded) + return self._transformer( + *args, extra=extra, extra_mask=extra_mask, **kwargs) + + def loss(self, + graphs: jraph.GraphsTuple, + pad_n_nodes: int, + batch_padded: bool, + inputs: jnp.ndarray, + labels: jnp.ndarray, + mask: jnp.ndarray, + **kwargs): + """Computes the loss of the graph2text TransformerXL. + + Args: + graphs: a graph structured using graph_net.Graph. + pad_n_nodes: size for each node to pad to. + batch_padded: whether the graph batch is padded or not. + inputs: [batch, timesteps]. + labels: [batch, timesteps]. + mask: [batch, timesteps]. + **kwargs: kwargs to the TransformerXL model. + + Returns: + output: loss and a dict containing metrics. + """ + extra, extra_mask = self._encode_graphs(graphs, pad_n_nodes, batch_padded) + return self._transformer.loss( + inputs, labels, mask, extra=extra, extra_mask=extra_mask, **kwargs) + + +class Bow2TextTransformer(hk.Module): + """A bag-of-words to text TransformerXL model. + + This model embeds bag-of-words into vectors and the text transformer can then + condition on these vectors to generate text. + + More specifically, the bow embedded vectors will be treated as extra tokens + that the transformer can attend to, in addition to the text data it is already + modelling. + + To make the model more expressive, we allow each bag-of-words to be embedded + into potentially more than 1 vectors, and the transformer will treat them as + more than 1 extra tokens correspondingly. + """ + + def __init__(self, + *transformer_args, + bow_embedding_dim: int = 256, + bow_n_tokens: int = 1, + name: Optional[str] = None, + **transformer_kwargs): + """Constructor. + + Args: + *transformer_args: the TransformerXL constructor arguments. + bow_embedding_dim: dimensionality for the bag-of-words embeddings. + bow_n_tokens: number of extra tokens to create for the bag-of-words + representations. + name: optional name for this module. + **transformer_kwargs: kwargs for the transformer module. + """ + super().__init__(name=name) + self._transformer = TransformerXL(*transformer_args, **transformer_kwargs) + self._bow_embedding_dim = bow_embedding_dim + self._bow_n_tokens = bow_n_tokens + + def _encode_bow(self, bow: jnp.ndarray) -> jnp.ndarray: + """Encode the bag-of-words into tensors that can be used by the transormer. + + Args: + bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector. + + Returns: + embeddings: [batch_size, bow_n_tokens, bow_embedding_dim] tensor. + """ + batch_size = bow.shape[0] + bow = bow.astype(jnp.float32) + + # [B, D * n] + embeddings = hk.Linear(self._bow_embedding_dim * self._bow_n_tokens)(bow) + embeddings = transformer_block.layer_norm(jax.nn.gelu(embeddings)) + return jnp.reshape( + embeddings, [batch_size, self._bow_n_tokens, self._bow_embedding_dim]) + + def __call__(self, bow: jnp.ndarray, *args, **kwargs): + """Compute the output of this bag-of-words-to-text transformer model. + + Args: + bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector. + *args: args to the TransformerXL model. + **kwargs: kwargs to the TransformerXL model. + + Returns: + output: transformer output [batch, timesteps]. + """ + return self._transformer(*args, extra=self._encode_bow(bow), **kwargs) + + def loss(self, bow: jnp.ndarray, *args, **kwargs): + """Computes the loss of the graph2text TransformerXL. + + Args: + bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector. + *args: args to the TransformerXL model. + **kwargs: kwargs to the TransformerXL model. + + Returns: + output: loss and a dict containing metrics. + """ + return self._transformer.loss(*args, extra=self._encode_bow(bow), **kwargs) diff --git a/wikigraphs/wikigraphs/model/transformer_block.py b/wikigraphs/wikigraphs/model/transformer_block.py new file mode 100644 index 0000000..1113e82 --- /dev/null +++ b/wikigraphs/wikigraphs/model/transformer_block.py @@ -0,0 +1,529 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Transformer blocks.""" + +import math +from typing import Callable, Optional + +import haiku as hk +from haiku import initializers as init +import jax +import jax.numpy as jnp + +from wikigraphs.model.embedding import RelativePositionEmbedding + + +def conv1d(x, num_units, init_scale=0.02, with_bias=True): + return hk.Conv1D( + output_channels=num_units, kernel_shape=1, with_bias=with_bias, + w_init=init.RandomNormal(stddev=init_scale))(x) + + +def layer_norm(x): + return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x) + + +class FeedForwardBlock(hk.Module): + """Feed forward block.""" + + def __init__(self, + dense_dim: int = 2100, + dropout_prob: float = 0.1, + init_scale: float = 1., + name: Optional[str] = None): + """Initializes a FeedForwardBlock. + + Args: + dense_dim: feature size of the feedforward block. + dropout_prob: dropout probability. + init_scale: the initialization scale of the VarianceScaling used for the + feedforward layer. + name: Optional name for this Haiku module. + """ + super(FeedForwardBlock, self).__init__(name=name) + self._dense_dim = dense_dim + self._dropout_prob = dropout_prob + self._init_scale = init_scale + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + hiddens = x.shape[-1] + x = conv1d(x, num_units=self._dense_dim, init_scale=self._init_scale) + x = jax.nn.relu(x) + x = hk.dropout(hk.next_rng_key(), self._dropout_prob, x) + x = conv1d(x, num_units=hiddens, init_scale=self._init_scale) + return hk.dropout(hk.next_rng_key(), self._dropout_prob, x) + + +def get_reset_attention_mask(should_reset: jnp.ndarray) -> jnp.ndarray: + """Maps a reset token vector into an attention mask that consists of blocks. + + A sequence of should reset tokens such as: + [0, 1, 0, 1, 0, 0] + transforms into an attention mask such as: + [[1, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]] + Args: + should_reset: Reset tokens with shape [batch, timesteps]. + Returns: + attention_mask: Attention mask with shape [batch, timesteps, timesteps]. + """ + should_reset = jnp.cumsum(should_reset, axis=-1) + attention_mask = should_reset[:, :, None] == should_reset[:, None, :] + return attention_mask.astype(jnp.float32) + + +def attend(q: jnp.ndarray, + k: jnp.ndarray, + v: jnp.ndarray, + mask: Optional[jnp.ndarray] = None, + attend_fn: + Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, + dropout_prob: float = 0.0, + extra_k: Optional[jnp.ndarray] = None, + extra_v: Optional[jnp.ndarray] = None, + extra_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray: + """Computes multi-head attention using the given query, key and value. + + Args: + q: Query with shape [batch, q_timesteps, num_heads, head_dim]. + k: Key with shape [batch, timesteps, num_heads, head_dim]. + v: Value with shape [batch, timesteps, num_heads, head_dim]. + mask: Attention mask to apply [batch, 1, q_timesteps, timesteps]. + attend_fn: An optionally defined attend function. The default attend_fn is + is jnp.einsum('bthd,bThd->bhtT', q, k). + dropout_prob: dropout probability on the attention weights. + extra_k: Extra keys to attend to, if provided. Note the extra keys and + values do not apply the specified attention_fn, but instead use the + default dot-product attention. [batch, timesteps_extra, num_heads, + head_dim]. + extra_v: Extra values to attend to, if provided. [batch, timesteps_extra, + num_heads, head_dim]. + extra_mask: Extra attention mask to apply on the extra inputs [batch, 1, + q_timesteps, timesteps_extra]. + + Returns: + Output of the attention with shape [batch, timesteps, hiddens] + """ + infinity_proxy = 1e9 + batch, q_time, num_heads, head_dim = q.shape + hiddens = num_heads * head_dim + + _, kv_time, _, _ = k.shape + expected_kv_shape = (batch, kv_time, num_heads, head_dim) + + if k.shape != expected_kv_shape: + raise ValueError( + f'Expected key shape {expected_kv_shape} but got shape {k.shape}') + if v.shape != expected_kv_shape: + raise ValueError( + f'Expected value shape {expected_kv_shape} but got shape {v.shape}') + + if attend_fn is not None: + attention = attend_fn(q, k) + else: + attention = jnp.einsum('bthd,bThd->bhtT', q, k) + + if mask is not None: + attention = attention * mask - infinity_proxy * (1 - mask) + + if extra_k is not None and extra_v is not None: + extra_time = extra_k.shape[1] + expected_extra_shape = (batch, extra_time, num_heads, head_dim) + if extra_k.shape != expected_extra_shape: + raise ValueError( + f'Expected extra key shape {expected_extra_shape} but got' + f' {extra_k.shape}') + if extra_v.shape != expected_extra_shape: + raise ValueError( + f'Expected extra value shape {expected_extra_shape} but got' + f' {extra_v.shape}') + + # [B, H, t, T'] + extra_attention = jnp.einsum('bthd,bThd->bhtT', q, extra_k) + if extra_mask is not None: + extra_attention = extra_attention * extra_mask - infinity_proxy * ( + 1 - extra_mask) + + # [B, H, t, T+T'] + attention = jnp.concatenate([attention, extra_attention], axis=-1) + # [B, T+T', H, D] + v = jnp.concatenate([v, extra_v], axis=1) + + scale = 1. / math.sqrt(head_dim) + attention *= scale + normalized = jax.nn.softmax(attention) + if dropout_prob > 0: + normalized = hk.dropout(hk.next_rng_key(), dropout_prob, normalized) + summed = jnp.einsum('bhtT,bThd->bthd', normalized, v) + return jnp.reshape(summed, [batch, q_time, hiddens]) + + +class Attention(hk.Module): + """Attention with memory (https://arxiv.org/abs/1901.02860). + + This implementation leverages the `state` in Haiku, in which the inputs are + stored as `states`. At each step, these states in memory are updated with a + rolling window. + """ + + def __init__(self, + r_w_bias: Optional[jnp.ndarray] = None, + r_r_bias: Optional[jnp.ndarray] = None, + num_heads: int = 8, + init_scale: float = 1.0, + with_final_bias: bool = False, + final_init_scale_multiplier: float = 1., + relative_pos_clamp_len: Optional[int] = None, + dropout_prob: float = 0.0, + name: Optional[str] = None): + """Initializes a Attention module. + + Args: + r_w_bias: global content bias. + r_r_bias: global positional bias. + num_heads: number of attention heads. + init_scale: the initialization scale of the VarianceScaling used for the + linear layer. + with_final_bias: whether to let final layer have biases. + final_init_scale_multiplier: how much to scale the initialization scale of + the output layer. + relative_pos_clamp_len: clamp length of the relative position embeddings. + dropout_prob: dropout probability. + name: Optional name for this Haiku module. + """ + super(Attention, self).__init__(name=name) + self._r_w_bias = r_w_bias + self._r_r_bias = r_r_bias + self._num_heads = num_heads + self._init_scale = init_scale + self._with_final_bias = with_final_bias + self._final_init_scale = final_init_scale_multiplier * init_scale + self._relative_pos_clamp_len = relative_pos_clamp_len + self._dropout_prob = dropout_prob + + def _update_cache(self, + key: jnp.ndarray, + value: jnp.ndarray, + cache_steps: Optional[int] = None, + axis: int = 1) -> jnp.ndarray: + """Update the cache stored in hk.state.""" + cache_shape = list(value.shape) + value_steps = cache_shape[axis] + if cache_steps is not None: + cache_shape[axis] += cache_steps + cache = hk.get_state( + key, shape=cache_shape, dtype=value.dtype, init=jnp.zeros) + + # Overwrite at index 0, then rotate timesteps left so what was just + # inserted is first. + value = jax.lax.dynamic_update_slice( + cache, value, jnp.zeros(len(cache_shape), dtype=jnp.int32)) + value = jnp.roll(value, -value_steps, axis) + hk.set_state(key, value) + return value + + def _update_memory(self, + mem: jnp.ndarray, + mask: jnp.ndarray, + input_length: int, + cache_steps: int, + should_reset: jnp.ndarray) -> jnp.ndarray: + """Logic for using and updating cached activations.""" + batch_size = mem.shape[0] + if cache_steps > 0: + # Tells us how much of the cache should be used. + cache_progress_idx = hk.get_state( + 'cache_progress_idx', [batch_size], dtype=jnp.int32, init=jnp.zeros) + hk.set_state('cache_progress_idx', cache_progress_idx + input_length) + mem = self._update_cache('mem', mem, cache_steps=cache_steps) + if mask is None: + mask = jnp.ones((batch_size, 1, input_length, input_length)) + cache_mask = (jnp.arange(cache_steps - 1, -1, -1)[None, None, None, :] + < cache_progress_idx[:, None, None, None]) + cache_mask = jnp.broadcast_to( + cache_mask, (batch_size, 1, input_length, cache_steps)) + mask = jnp.concatenate([cache_mask, mask], axis=-1) + if should_reset is not None: + if cache_steps > 0: + should_reset = self._update_cache('should_reset', should_reset, + cache_steps=cache_steps) + reset_mask = get_reset_attention_mask(should_reset)[:, None, :, :] + mask *= reset_mask[:, :, cache_steps:, :] + return mem, mask + + def __call__(self, + x: jnp.ndarray, + mask: Optional[jnp.ndarray] = None, + should_reset: Optional[jnp.ndarray] = None, + cache_steps: int = 0, + extra: Optional[jnp.ndarray] = None, + extra_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray: + """Compute the multi-head attention. + + Args: + x: input [batch, x_timesteps, in_dim]. + mask: attention mask [batch, 1, x_timesteps, y_timesteps]. + should_reset: reset marker [batch, timesteps]. + cache_steps: number of timesteps in the cache. + extra: if provided should be extra key-value input + [batch, extra_timesteps, in_dim']. + extra_mask: if provided should be the mask for extra key-value input, + [batch, extra_timesteps]. + + Returns: + output: attention output [batch, x_timesteps, in_dim]. + """ + hiddens_in = x.shape[-1] + steps = x.shape[1] + qkv_hiddens = hiddens_in + + y, mask = self._update_memory(x, mask, steps, cache_steps, should_reset) + + q = conv1d(x, qkv_hiddens, init_scale=self._init_scale, with_bias=False) + k = conv1d(y, qkv_hiddens, init_scale=self._init_scale, with_bias=False) + v = conv1d(y, qkv_hiddens, init_scale=self._init_scale, with_bias=False) + + batch, q_time, _ = q.shape + _, kv_time, _ = k.shape + head_dim = qkv_hiddens // self._num_heads + assert qkv_hiddens % self._num_heads == 0, 'Head dim should be an integer.' + q = jnp.reshape(q, [batch, q_time, self._num_heads, head_dim]) + k = jnp.reshape(k, [batch, kv_time, self._num_heads, head_dim]) + v = jnp.reshape(v, [batch, kv_time, self._num_heads, head_dim]) + + attend_fn = RelativePositionEmbedding( + dim=qkv_hiddens, dropout_rate=self._dropout_prob, + r_w_bias=self._r_w_bias, r_r_bias=self._r_r_bias, + init_scale=self._init_scale, clamp_len=self._relative_pos_clamp_len) + + if extra is not None: + extra_k = conv1d(extra, qkv_hiddens, init_scale=self._init_scale, + with_bias=False) + extra_v = conv1d(extra, qkv_hiddens, init_scale=self._init_scale, + with_bias=False) + extra_time = extra.shape[1] + extra_k = jnp.reshape( + extra_k, [batch, extra_time, self._num_heads, head_dim]) + extra_v = jnp.reshape( + extra_v, [batch, extra_time, self._num_heads, head_dim]) + if extra_mask is not None: + extra_mask = extra_mask[:, None, None, :] + attn_vec = attend(q, k, v, mask=mask, attend_fn=attend_fn, + dropout_prob=self._dropout_prob, + extra_k=extra_k, extra_v=extra_v, extra_mask=extra_mask) + else: + attn_vec = attend(q, k, v, mask=mask, attend_fn=attend_fn, + dropout_prob=self._dropout_prob) + attn_out = conv1d(attn_vec, hiddens_in, with_bias=self._with_final_bias, + init_scale=self._final_init_scale) + return hk.dropout(hk.next_rng_key(), self._dropout_prob, attn_out) + + +class SelfAttentionBlock(hk.Module): + """Self attention block.""" + + def __init__(self, + r_w_bias: Optional[jnp.ndarray] = None, + r_r_bias: Optional[jnp.ndarray] = None, + causal: bool = False, + num_heads: int = 8, + dropout_prob: float = 0.1, + dropout_attn_prob: float = 0.0, + init_scale: float = 1.0, + relative_pos_clamp_len: Optional[int] = None, + name: Optional[str] = None): + """Initializes a SelfAttentionBlock. + + Args: + r_w_bias: global content bias. + r_r_bias: global positional bias. + causal: whether to apply a causal mask to the input. + num_heads: number of attention heads. + dropout_prob: dropout probability. + dropout_attn_prob: dropout probability of the attention module. + init_scale: the initialization scale of the VarianceScaling used for the + linear layer. + relative_pos_clamp_len: clamp length of the relative position embeddings. + name: Optional name for this Haiku module. + """ + super(SelfAttentionBlock, self).__init__(name=name) + self._r_w_bias = r_w_bias + self._r_r_bias = r_r_bias + self._causal = causal + self._num_heads = num_heads + self._dropout_prob = dropout_prob + self._dropout_attn_prob = dropout_attn_prob + self._init_scale = init_scale + + self._relative_pos_clamp_len = relative_pos_clamp_len + + def __call__(self, + x: jnp.ndarray, + mask: Optional[jnp.ndarray] = None, + should_reset: Optional[jnp.ndarray] = None, + cache_steps: int = 0, + extra: Optional[jnp.ndarray] = None, + extra_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray: + """Computes the outputs of the self attention block. + + Args: + x: query input [batch, x_timesteps, in_dim]. + mask: attention mask [batch, 1, 1, x_timesteps]. + should_reset: reset marker [batch, timesteps]. + cache_steps: number of timesteps in the cache. + extra: if provided should be extra key-value input + [batch, extra_timesteps, in_dim']. + extra_mask: if provided should be the mask for extra key-value input, + [batch, extra_timesteps]. + + Returns: + output: block output [batch, x_timesteps, in_dim]. + """ + if self._causal: + timesteps = x.shape[1] + batch_size = x.shape[0] + t = jnp.arange(timesteps, dtype=jnp.int32) + causal_mask = (t[:, None] >= t[None, :])[None, None, :, :] + causal_mask = causal_mask.astype(x.dtype) + if mask is None: + mask = jnp.broadcast_to( + causal_mask, (batch_size, 1, timesteps, timesteps)) + else: + mask *= causal_mask + x = Attention( + self._r_w_bias, + self._r_r_bias, + num_heads=self._num_heads, + init_scale=self._init_scale, + relative_pos_clamp_len=self._relative_pos_clamp_len, + dropout_prob=self._dropout_attn_prob)( + x, mask=mask, should_reset=should_reset, + cache_steps=cache_steps, extra=extra, extra_mask=extra_mask) + else: + x = Attention( + self._r_w_bias, + self._r_r_bias, + num_heads=self._num_heads, + init_scale=self._init_scale, + dropout_prob=self._dropout_attn_prob)( + x, mask=mask, extra=extra, extra_mask=extra_mask) + return hk.dropout(hk.next_rng_key(), self._dropout_prob, x) + + +class GPT2Block(hk.Module): + """GPT-2 style transformer block with memory.""" + + def __init__(self, + r_w_bias: Optional[jnp.ndarray] = None, + r_r_bias: Optional[jnp.ndarray] = None, + causal: bool = True, + dense_dim: int = 2100, + dropout_prob: float = 0.1, + dropout_attn_prob: float = 0.0, + num_heads: int = 8, + self_att_init_scale: float = 0.02, + dense_init_scale: float = 0.02, + relative_pos_clamp_len: Optional[int] = None, + name: Optional[str] = None): + """Initializes a GPT2Block. + + Args: + r_w_bias: global content bias. + r_r_bias: global positional bias. + causal: whether to apply a causal mask to the input. + dense_dim: feature size of the feedforward block. + dropout_prob: dropout probability. + dropout_attn_prob: dropout probability of the attention module. + num_heads: number of attention heads. + self_att_init_scale: the initialization scale of the VarianceScaling + used for the linear layer in the attention module. + dense_init_scale: the initialization scale of the VarianceScaling + used for the linear layer in the feedforward module. + relative_pos_clamp_len: clamp length of the relative position embeddings. + name: Optional name for this Haiku module. + """ + super(GPT2Block, self).__init__(name=name) + self._r_w_bias = r_w_bias + self._r_r_bias = r_r_bias + self._causal = causal + self._dense_dim = dense_dim + self._dropout_prob = dropout_prob + self._dropout_attn_prob = dropout_attn_prob + self._num_heads = num_heads + self._self_att_init_scale = self_att_init_scale + self._dense_init_scale = dense_init_scale + self._relative_pos_clamp_len = relative_pos_clamp_len + + def __call__(self, + x: jnp.ndarray, + mask: Optional[jnp.ndarray] = None, + is_training: bool = True, + should_reset: Optional[jnp.ndarray] = None, + cache_steps: int = 0, + extra: Optional[jnp.ndarray] = None, + extra_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray: + """Computes the outputs of the GPT-2 block. + + Args: + x: query input [batch, x_timesteps, in_dim]. + mask: attention mask [batch, 1, 1, x_timesteps]. + is_training: whether the current stage is training or not. + should_reset: reset marker [batch, timesteps]. + cache_steps: number of timesteps in the cache. + extra: if provided should be extra key-value input + [batch, extra_timesteps, in_dim']. + extra_mask: if provided should be the mask for extra key-value input, + [batch, extra_timesteps]. + + Returns: + output: block output [batch, x_timesteps, in_dim]. + """ + dropout_prob = self._dropout_prob if is_training else 0.0 + dropout_attn_prob = self._dropout_attn_prob if is_training else 0.0 + x = layer_norm(x + SelfAttentionBlock( + self._r_w_bias, + self._r_r_bias, + causal=self._causal, + num_heads=self._num_heads, + dropout_prob=dropout_prob, + dropout_attn_prob=dropout_attn_prob, + init_scale=self._self_att_init_scale, + relative_pos_clamp_len=self._relative_pos_clamp_len)( + x, mask=mask, should_reset=should_reset, + cache_steps=cache_steps, extra=extra, extra_mask=extra_mask)) + x = layer_norm(x + FeedForwardBlock( + dense_dim=self._dense_dim, + dropout_prob=dropout_prob, + init_scale=self._dense_init_scale)(x)) + return x diff --git a/wikigraphs/wikigraphs/model/transformer_test.py b/wikigraphs/wikigraphs/model/transformer_test.py new file mode 100644 index 0000000..887ea55 --- /dev/null +++ b/wikigraphs/wikigraphs/model/transformer_test.py @@ -0,0 +1,373 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Tests for wikigraphs.model.transformer.""" + +from absl import logging +from absl.testing import absltest + +import haiku as hk +import jax +import jax.numpy as jnp +import jraph +import numpy as np +import optax + +from wikigraphs.model import embedding +from wikigraphs.model import transformer as models + + +def tree_size(nest): + return sum(x.size for x in jax.tree_util.tree_leaves(nest)) + + +class TransformerXlTest(absltest.TestCase): + + def test_transformer_param_count(self): + seqs = np.array([[1, 2, 3, 0, 0], + [3, 3, 5, 1, 2]], dtype=np.int32) + x = seqs[:, :-1] + y = seqs[:, 1:] + vocab_size = 267_735 + + def forward(inputs, labels): + input_mask = (labels != 0).astype(jnp.float32) + model = models.TransformerXL( + vocab_size=vocab_size, + emb_dim=210, + num_layers=2, + num_heads=10, + dropout_prob=0.0, + dropout_attn_prob=0.0, + self_att_init_scale=0.02, + dense_init_scale=0.02, + dense_dim=2100, + cutoffs=(20000, 40000, 200000), # WikiText-103 + relative_pos_clamp_len=None, + ) + return model.loss(inputs, labels, mask=input_mask, cache_steps=2) + + init_fn, apply_fn = hk.transform_with_state(forward) + key = hk.PRNGSequence(8) + params, state = init_fn(next(key), x, y) + out, _ = apply_fn(params, state, next(key), x, y) + loss, metrics = out + + logging.info('loss: %g', loss) + logging.info('metrics: %r', metrics) + + param_count = tree_size(params) + self.assertEqual(param_count, 58_704_438) + + def test_transformer_with_extra_runs(self): + extra = np.array([[1, 1, 0, 0], + [2, 2, 2, 2], + [3, 3, 3, 0]], dtype=np.int32) + seqs = np.array([[1, 2, 3, 0, 0], + [2, 4, 5, 6, 0], + [3, 3, 5, 1, 2]], dtype=np.int32) + x = seqs[:, :-1] + y = seqs[:, 1:] + vocab_size = seqs.max() + 1 + extra_vocab_size = extra.max() + 1 + + def forward(inputs, labels, extra): + input_mask = (labels != 0).astype(jnp.float32) + extra_mask = (extra != 0).astype(jnp.float32) + extra = hk.Embed(vocab_size=extra_vocab_size, embed_dim=16)(extra) + model = models.TransformerXL( + vocab_size=vocab_size, + emb_dim=16, + num_layers=2, + num_heads=4, + cutoffs=[], + ) + return model.loss(inputs, labels, mask=input_mask, + extra=extra, extra_mask=extra_mask) + + init_fn, apply_fn = hk.transform_with_state(forward) + key = hk.PRNGSequence(8) + params, state = init_fn(next(key), x, y, extra) + out, _ = apply_fn(params, state, next(key), x, y, extra) + loss, metrics = out + + logging.info('loss: %g', loss) + logging.info('metrics: %r', metrics) + + def test_graph_embedding_model_runs(self): + graph = jraph.GraphsTuple( + nodes=np.array([[0, 1, 1], + [1, 2, 0], + [0, 3, 0], + [0, 4, 4]], dtype=np.float32), + edges=np.array([[1, 1], + [2, 2], + [3, 3]], dtype=np.float32), + senders=np.array([0, 1, 2], dtype=np.int32), + receivers=np.array([1, 2, 3], dtype=np.int32), + n_node=np.array([4], dtype=np.int32), + n_edge=np.array([3], dtype=np.int32), + globals=None) + embed_dim = 3 + + def forward(graph): + return embedding.GraphEmbeddingModel(embed_dim=3, num_layers=2)(graph) + + init_fn, apply_fn = hk.without_apply_rng(hk.transform(forward)) + key = hk.PRNGSequence(8) + params = init_fn(next(key), graph) + out = apply_fn(params, graph) + + self.assertEqual(out.nodes.shape, (graph.nodes.shape[0], embed_dim)) + self.assertEqual(out.edges.shape, (graph.edges.shape[0], embed_dim)) + np.testing.assert_array_equal(out.senders, graph.senders) + np.testing.assert_array_equal(out.receivers, graph.receivers) + np.testing.assert_array_equal(out.n_node, graph.n_node) + + def test_unpack_and_pad(self): + x = np.array([1, 1, 2, 2, 2, 3, 4, 4], dtype=np.float32) + s = np.array([2, 3, 1, 2], dtype=np.int32) + + tensors, mask = models.unpack_and_pad(x, s, pad_size=s.max(), pad_value=0) + + np.testing.assert_array_equal( + tensors, + [[1, 1, 0], + [2, 2, 2], + [3, 0, 0], + [4, 4, 0]]) + np.testing.assert_array_equal( + mask, + [[1, 1, 0], + [1, 1, 1], + [1, 0, 0], + [1, 1, 0]]) + + # [n, 1] tensor + x = np.array([1, 1, 2, 2, 2, 3, 4, 4], dtype=np.float32)[:, None] + s = np.array([2, 3, 1, 2], dtype=np.int32) + + tensors, mask = models.unpack_and_pad(x, s, pad_size=s.max(), pad_value=0) + + np.testing.assert_array_equal( + tensors, + np.array([[1, 1, 0], + [2, 2, 2], + [3, 0, 0], + [4, 4, 0]])[:, :, None]) + np.testing.assert_array_equal( + mask, + [[1, 1, 0], + [1, 1, 1], + [1, 0, 0], + [1, 1, 0]]) + + def test_graph_conditioned_transformer_runs(self): + graphs = jraph.GraphsTuple( + nodes=np.ones((4, 3), dtype=np.float32), + edges=np.ones((3, 1), dtype=np.float32), + senders=np.array([0, 2, 3], dtype=np.int32), + receivers=np.array([1, 3, 2], dtype=np.int32), + n_node=np.array([2, 2], dtype=np.int32), + n_edge=np.array([1, 2], dtype=np.int32), + globals=None, + ) + seqs = np.array([[1, 1, 0], + [2, 2, 2]], dtype=np.int32) + vocab_size = seqs.max() + 1 + embed_dim = 8 + + x = seqs[:, :-1] + y = seqs[:, 1:] + + def forward(graphs, inputs, labels): + graphs = models.GraphEmbeddingModel(embed_dim=embed_dim, + num_layers=2)(graphs) + extra, extra_mask = models.unpack_and_pad(graphs.nodes, + graphs.n_node, + graphs.n_node.max()) + input_mask = (labels != 0).astype(jnp.float32) + transformer = models.TransformerXL(vocab_size=vocab_size, + emb_dim=embed_dim, + num_layers=2, + num_heads=4, + cutoffs=[]) + return transformer.loss(inputs, labels, mask=input_mask, extra=extra, + extra_mask=extra_mask) + + init_fn, apply_fn = hk.transform_with_state(forward) + key = hk.PRNGSequence(8) + params, state = init_fn(next(key), graphs, x, y) + out, _ = apply_fn(params, state, next(key), graphs, x, y) + loss, metrics = out + + logging.info('loss: %g', loss) + logging.info('metrics: %r', metrics) + + def test_graph_conditioned_transformer_learns(self): + graphs = jraph.GraphsTuple( + nodes=np.ones((4, 3), dtype=np.float32), + edges=np.ones((3, 1), dtype=np.float32), + senders=np.array([0, 2, 3], dtype=np.int32), + receivers=np.array([1, 3, 2], dtype=np.int32), + n_node=np.array([2, 2], dtype=np.int32), + n_edge=np.array([1, 2], dtype=np.int32), + globals=None, + ) + seqs = np.array([[1, 2, 2, 0], + [1, 3, 3, 3]], dtype=np.int32) + vocab_size = seqs.max() + 1 + embed_dim = 8 + max_graph_size = graphs.n_node.max() + + logging.info('Training seqs: %r', seqs) + + x = seqs[:, :-1] + y = seqs[:, 1:] + + def model_fn(vocab_size, embed_dim): + return models.Graph2TextTransformer( + vocab_size=vocab_size, + emb_dim=embed_dim, + num_layers=2, + num_heads=4, + cutoffs=[], + gnn_embed_dim=embed_dim, + gnn_num_layers=2) + + def forward(graphs, inputs, labels, max_graph_size): + input_mask = (labels != 0).astype(jnp.float32) + return model_fn(vocab_size, embed_dim).loss( + graphs, max_graph_size, False, inputs, labels, mask=input_mask) + + init_fn, apply_fn = hk.transform_with_state(forward) + rng = hk.PRNGSequence(8) + params, state = init_fn(next(rng), graphs, x, y, max_graph_size) + + def apply(*args, **kwargs): + out, state = apply_fn(*args, **kwargs) + return out[0], (out[1], state) + apply = jax.jit(apply, static_argnums=6) + + optimizer = optax.chain( + optax.scale_by_adam(), + optax.scale(-1e-3)) + opt_state = optimizer.init(params) + for i in range(500): + (loss, model_state), grad = jax.value_and_grad(apply, has_aux=True)( + params, state, next(rng), graphs, x, y, max_graph_size) + metrics, state = model_state + updates, opt_state = optimizer.update(grad, opt_state, params) + params = optax.apply_updates(params, updates) + if (i + 1) % 100 == 0: + logging.info( + 'Step %d, %r', i + 1, {k: float(v) for k, v in metrics.items()}) + logging.info('Loss: %.8f', loss) + self.assertLess(loss, 1.0) + + def test_bow_transformer_runs(self): + bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1], + [0, 1, 0, 0, 1, 0, 1, 0], + [1, 0, 0, 0, 1, 0, 0, 1]], dtype=np.int32) + seqs = np.array([[1, 2, 3, 0, 0], + [2, 4, 5, 6, 0], + [3, 3, 5, 1, 2]], dtype=np.int32) + x = seqs[:, :-1] + y = seqs[:, 1:] + vocab_size = seqs.max() + 1 + + def forward(bow, inputs, labels): + model = models.Bow2TextTransformer( + vocab_size=vocab_size, + emb_dim=16, + num_layers=2, + num_heads=4, + cutoffs=[]) + return model.loss(bow, inputs, labels) + + init_fn, apply_fn = hk.transform_with_state(forward) + key = hk.PRNGSequence(8) + params, state = init_fn(next(key), bow, x, y) + out, _ = apply_fn(params, state, next(key), bow, x, y) + loss, metrics = out + + logging.info('loss: %g', loss) + logging.info('metrics: %r', metrics) + + def test_bow_transformer_learns(self): + bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1], + [0, 1, 0, 0, 1, 0, 1, 0], + [1, 0, 0, 0, 1, 0, 0, 1]], dtype=np.int32) + seqs = np.array([[1, 2, 2, 3, 0, 0], + [1, 2, 4, 5, 6, 0], + [1, 3, 3, 5, 4, 2]], dtype=np.int32) + x = seqs[:, :-1] + y = seqs[:, 1:] + vocab_size = seqs.max() + 1 + + def model_fn(): + return models.Bow2TextTransformer( + vocab_size=vocab_size, + emb_dim=16, + num_layers=2, + num_heads=4, + cutoffs=[]) + + def loss_fn(bow, inputs, labels): + mask = (labels != 0).astype(jnp.float32) + return model_fn().loss(bow, inputs, labels, mask=mask) + + init_fn, apply_fn = hk.transform_with_state(loss_fn) + key = hk.PRNGSequence(8) + params, state = init_fn(next(key), bow, x, y) + + def apply(*args, **kwargs): + out, state = apply_fn(*args, **kwargs) + return out[0], (out[1], state) + value_and_grad = jax.jit(jax.value_and_grad(apply, has_aux=True)) + + optimizer = optax.chain( + optax.scale_by_adam(), + optax.scale(-1e-3)) + opt_state = optimizer.init(params) + for i in range(800): + (loss, model_state), grad = value_and_grad( + params, state, next(key), bow, x, y) + metrics, state = model_state + updates, opt_state = optimizer.update(grad, opt_state, params) + params = optax.apply_updates(params, updates) + if (i + 1) % 100 == 0: + logging.info('Step %d, %r', i + 1, + {k: float(v) for k, v in metrics.items()}) + logging.info('Loss: %.8f', loss) + self.assertLess(loss, 0.1) + + +if __name__ == '__main__': + absltest.main()