diff --git a/perceiver/colabs/masked_language_modelling.ipynb b/perceiver/colabs/masked_language_modelling.ipynb new file mode 100644 index 0000000..0fbd27e --- /dev/null +++ b/perceiver/colabs/masked_language_modelling.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uFsBlyxwmRq2" + }, + "outputs": [], + "source": [ + "# Copyright 2021 DeepMind Technologies Limited\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IBVh06qUojjm" + }, + "outputs": [], + "source": [ + "#@title Import\n", + "from typing import Union\n", + "\n", + "import haiku as hk\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import pickle\n", + "\n", + "from perceiver import perceiver, position_encoding, io_processors, bytes_tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sFa-lRuVfKZt" + }, + "outputs": [], + "source": [ + "#@title Load parameters from checkpoint\n", + "!wget -O language_perceiver_io_bytes.pickle https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle\n", + "\n", + "with open(\"language_perceiver_io_bytes.pickle\", \"rb\") as f:\n", + " params = pickle.loads(f.read())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LBzQQ7t_VCBo" + }, + "outputs": [], + "source": [ + "#@title Model config\n", + "D_MODEL = 768\n", + "D_LATENTS = 1280\n", + "MAX_SEQ_LEN = 2048\n", + "\n", + "encoder_config = dict(\n", + " num_self_attends_per_block=26,\n", + " num_blocks=1,\n", + " z_index_dim=256,\n", + " num_z_channels=D_LATENTS,\n", + " num_self_attend_heads=8,\n", + " num_cross_attend_heads=8,\n", + " qk_channels=8 * 32,\n", + " v_channels=D_LATENTS,\n", + " use_query_residual=True,\n", + " cross_attend_widening_factor=1,\n", + " self_attend_widening_factor=1)\n", + "\n", + "decoder_config = dict(\n", + " output_num_channels=D_LATENTS,\n", + " position_encoding_type='trainable',\n", + " output_index_dims=MAX_SEQ_LEN,\n", + " num_z_channels=D_LATENTS,\n", + " qk_channels=8 * 32,\n", + " v_channels=D_MODEL,\n", + " num_heads=8,\n", + " final_project=False,\n", + " use_query_residual=False,\n", + " trainable_position_encoding_kwargs=dict(num_channels=D_MODEL))\n", + "\n", + "# The tokenizer is just UTF-8 encoding (with an offset)\n", + "tokenizer = bytes_tokenizer.BytesTokenizer()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EWOeFoF0aCaT" + }, + "outputs": [], + "source": [ + "#@title Decoding Perceiver Model\n", + "def apply_perceiver(\n", + " inputs: jnp.ndarray, input_mask: jnp.ndarray) -\u003e jnp.ndarray:\n", + " \"\"\"Runs a forward pass on the Perceiver.\n", + "\n", + " Args:\n", + " inputs: input bytes, an int array of shape [B, T]\n", + " input_mask: Array of shape indicating which entries are valid and which are\n", + " masked. A truthy value indicates that the entry is valid.\n", + "\n", + " Returns:\n", + " The output logits, an array of shape [B, T, vocab_size].\n", + " \"\"\"\n", + " assert inputs.shape[1] == MAX_SEQ_LEN\n", + "\n", + " embedding_layer = hk.Embed(\n", + " vocab_size=tokenizer.vocab_size,\n", + " embed_dim=D_MODEL)\n", + " embedded_inputs = embedding_layer(inputs)\n", + "\n", + " batch_size = embedded_inputs.shape[0]\n", + "\n", + " input_pos_encoding = perceiver.position_encoding.TrainablePositionEncoding(\n", + " index_dim=MAX_SEQ_LEN, num_channels=D_MODEL)\n", + " embedded_inputs = embedded_inputs + input_pos_encoding(batch_size)\n", + " perceiver_mod = perceiver.Perceiver(\n", + " encoder=perceiver.PerceiverEncoder(**encoder_config),\n", + " decoder=perceiver.BasicDecoder(**decoder_config))\n", + " output_embeddings = perceiver_mod(\n", + " embedded_inputs, is_training=False, input_mask=input_mask, query_mask=input_mask)\n", + "\n", + " logits = io_processors.EmbeddingDecoder(\n", + " embedding_matrix=embedding_layer.embeddings)(output_embeddings)\n", + " return logits\n", + "\n", + "apply_perceiver = hk.transform(apply_perceiver).apply" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Pna1ZXEyOJZb" + }, + "outputs": [], + "source": [ + "#@title Create input\n", + "input_str = \"This is an incomplete sentence where some words are missing.\"\n", + "input_tokens = tokenizer.to_int(input_str)\n", + "\n", + "# Mask \" missing.\". Note that the model performs much better if the masked chunk\n", + "# starts with a space.\n", + "input_tokens[51:60] = tokenizer.mask_token\n", + "print(\"Tokenized string without masked bytes:\")\n", + "print(tokenizer.to_string(input_tokens))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6TCMuVUabnTg" + }, + "outputs": [], + "source": [ + "#@title Pad and reshape inputs\n", + "inputs = input_tokens[None]\n", + "input_mask = np.ones_like(inputs)\n", + "\n", + "def pad(max_sequence_length: int, inputs, input_mask):\n", + " input_len = inputs.shape[1]\n", + " assert input_len \u003c= max_sequence_length\n", + " pad_len = max_sequence_length - input_len\n", + " padded_inputs = np.pad(\n", + " inputs,\n", + " pad_width=((0, 0), (0, pad_len)),\n", + " constant_values=tokenizer.pad_token)\n", + " padded_mask = np.pad(\n", + " input_mask,\n", + " pad_width=((0, 0), (0, pad_len)),\n", + " constant_values=0)\n", + " return padded_inputs, padded_mask\n", + "\n", + "inputs, input_mask = pad(MAX_SEQ_LEN, inputs, input_mask)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ipZs6p0Xk3lb" + }, + "outputs": [], + "source": [ + "#@title Run the model\n", + "\n", + "rng = jax.random.PRNGKey(1) # Unused\n", + "\n", + "out = apply_perceiver(params, rng=rng, inputs=inputs, input_mask=input_mask)\n", + "\n", + "masked_tokens_predictions = out[0, 51:60].argmax(axis=-1)\n", + "print(\"Greedy predictions:\")\n", + "print(masked_tokens_predictions)\n", + "print()\n", + "print(\"Predicted string:\")\n", + "print(tokenizer.to_string(masked_tokens_predictions))" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/deepmind/dm_python:dm_notebook3", + "kind": "private" + }, + "name": "Perceiver: Masked Language Modelling", + "private_outputs": true, + "provenance": [ + { + "file_id": "1N31dQM-SzjG-_acz405i3jCR_m6D4bq8", + "timestamp": 1627567455889 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}