Add the initial language colab for inference only

PiperOrigin-RevId: 387625457
This commit is contained in:
Sebastian Borgeaud
2021-07-29 19:09:31 +01:00
committed by Diego de Las Casas
parent 310620d045
commit 8c431d40ea
@@ -0,0 +1,247 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uFsBlyxwmRq2"
},
"outputs": [],
"source": [
"# Copyright 2021 DeepMind Technologies Limited\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IBVh06qUojjm"
},
"outputs": [],
"source": [
"#@title Import\n",
"from typing import Union\n",
"\n",
"import haiku as hk\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import pickle\n",
"\n",
"from perceiver import perceiver, position_encoding, io_processors, bytes_tokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sFa-lRuVfKZt"
},
"outputs": [],
"source": [
"#@title Load parameters from checkpoint\n",
"!wget -O language_perceiver_io_bytes.pickle https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle\n",
"\n",
"# SECURITY: unpickling can execute arbitrary code contained in the file, so\n",
"# only load checkpoints obtained from a source you trust.\n",
"with open(\"language_perceiver_io_bytes.pickle\", \"rb\") as f:\n",
"  # Deserialize straight from the file object instead of first reading the\n",
"  # whole file into an intermediate bytes object.\n",
"  params = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LBzQQ7t_VCBo"
},
"outputs": [],
"source": [
"#@title Model config\n",
"# Channel size of the byte embeddings and of the trainable position encodings\n",
"# (also used as the decoder's v_channels).\n",
"D_MODEL = 768\n",
"# Channel size passed as num_z_channels to both the encoder and the decoder.\n",
"D_LATENTS = 1280\n",
"# Sequence length the model expects: apply_perceiver asserts on it and shorter\n",
"# inputs are padded up to it.\n",
"MAX_SEQ_LEN = 2048\n",
"\n",
"# Keyword arguments forwarded to perceiver.PerceiverEncoder.\n",
"encoder_config = {\n",
"    'num_self_attends_per_block': 26,\n",
"    'num_blocks': 1,\n",
"    'z_index_dim': 256,\n",
"    'num_z_channels': D_LATENTS,\n",
"    'num_self_attend_heads': 8,\n",
"    'num_cross_attend_heads': 8,\n",
"    'qk_channels': 8 * 32,\n",
"    'v_channels': D_LATENTS,\n",
"    'use_query_residual': True,\n",
"    'cross_attend_widening_factor': 1,\n",
"    'self_attend_widening_factor': 1,\n",
"}\n",
"\n",
"# Keyword arguments forwarded to perceiver.BasicDecoder.\n",
"decoder_config = {\n",
"    'output_num_channels': D_LATENTS,\n",
"    'position_encoding_type': 'trainable',\n",
"    'output_index_dims': MAX_SEQ_LEN,\n",
"    'num_z_channels': D_LATENTS,\n",
"    'qk_channels': 8 * 32,\n",
"    'v_channels': D_MODEL,\n",
"    'num_heads': 8,\n",
"    'final_project': False,\n",
"    'use_query_residual': False,\n",
"    'trainable_position_encoding_kwargs': {'num_channels': D_MODEL},\n",
"}\n",
"\n",
"# The tokenizer is just UTF-8 encoding (with an offset)\n",
"tokenizer = bytes_tokenizer.BytesTokenizer()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EWOeFoF0aCaT"
},
"outputs": [],
"source": [
"#@title Decoding Perceiver Model\n",
"def apply_perceiver(\n",
"    inputs: jnp.ndarray, input_mask: jnp.ndarray) -\u003e jnp.ndarray:\n",
"  \"\"\"Runs a forward pass on the Perceiver.\n",
"\n",
"  Args:\n",
"    inputs: input bytes, an int array of shape [B, T]. T must be equal to\n",
"      MAX_SEQ_LEN.\n",
"    input_mask: Array of shape [B, T] indicating which entries are valid and\n",
"      which are masked. A truthy value indicates that the entry is valid. The\n",
"      same mask is also passed to the model as the query mask.\n",
"\n",
"  Returns:\n",
"    The output logits, an array of shape [B, T, vocab_size].\n",
"\n",
"  Raises:\n",
"    AssertionError: if the second dimension of `inputs` is not MAX_SEQ_LEN.\n",
"  \"\"\"\n",
"  assert inputs.shape[1] == MAX_SEQ_LEN, (\n",
"      f'Expected inputs with sequence length {MAX_SEQ_LEN}, '\n",
"      f'got {inputs.shape[1]}.')\n",
"\n",
"  # NOTE(review): the order in which the modules below are constructed is\n",
"  # kept exactly as in the original; Haiku parameter names presumably have to\n",
"  # line up with the loaded checkpoint - confirm before reordering.\n",
"  embedding_layer = hk.Embed(\n",
"      vocab_size=tokenizer.vocab_size,\n",
"      embed_dim=D_MODEL)\n",
"  embedded_inputs = embedding_layer(inputs)\n",
"\n",
"  batch_size = embedded_inputs.shape[0]\n",
"\n",
"  # Add a trainable position encoding to the byte embeddings.\n",
"  input_pos_encoding = perceiver.position_encoding.TrainablePositionEncoding(\n",
"      index_dim=MAX_SEQ_LEN, num_channels=D_MODEL)\n",
"  embedded_inputs = embedded_inputs + input_pos_encoding(batch_size)\n",
"  perceiver_mod = perceiver.Perceiver(\n",
"      encoder=perceiver.PerceiverEncoder(**encoder_config),\n",
"      decoder=perceiver.BasicDecoder(**decoder_config))\n",
"  output_embeddings = perceiver_mod(\n",
"      embedded_inputs,\n",
"      is_training=False,\n",
"      input_mask=input_mask,\n",
"      query_mask=input_mask)\n",
"\n",
"  # The output decoder is given the input embedding layer's matrix.\n",
"  logits = io_processors.EmbeddingDecoder(\n",
"      embedding_matrix=embedding_layer.embeddings)(output_embeddings)\n",
"  return logits\n",
"\n",
"# Rebind the name to the transformed apply function; it is later called as\n",
"# apply_perceiver(params, rng=..., inputs=..., input_mask=...).\n",
"apply_perceiver = hk.transform(apply_perceiver).apply"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Pna1ZXEyOJZb"
},
"outputs": [],
"source": [
"#@title Create input\n",
"input_str = \"This is an incomplete sentence where some words are missing.\"\n",
"input_tokens = tokenizer.to_int(input_str)\n",
"\n",
"# Overwrite the trailing \" missing.\" (positions 51 up to, but excluding, 60)\n",
"# with the mask token. Note that the model performs much better if the masked\n",
"# chunk starts with a space. The same positions are read back when the\n",
"# predictions are decoded, so keep the two slices in sync.\n",
"input_tokens[slice(51, 60)] = tokenizer.mask_token\n",
"print(\n",
"    \"Tokenized string without masked bytes:\",\n",
"    tokenizer.to_string(input_tokens),\n",
"    sep=\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6TCMuVUabnTg"
},
"outputs": [],
"source": [
"#@title Pad and reshape inputs\n",
"# Add a leading batch axis; every token present so far is marked as valid.\n",
"inputs = input_tokens[None]\n",
"input_mask = np.ones_like(inputs)\n",
"\n",
"def pad(max_sequence_length: int, inputs, input_mask):\n",
"  \"\"\"Right-pads `inputs` and `input_mask` along axis 1.\n",
"\n",
"  Args:\n",
"    max_sequence_length: Size of axis 1 after padding.\n",
"    inputs: Two-dimensional array of tokens; new positions are filled with\n",
"      `tokenizer.pad_token`.\n",
"    input_mask: Two-dimensional mask array; new positions are filled with 0.\n",
"\n",
"  Returns:\n",
"    A `(padded_inputs, padded_mask)` tuple.\n",
"\n",
"  Raises:\n",
"    AssertionError: if `inputs` is longer than `max_sequence_length`.\n",
"  \"\"\"\n",
"  num_tokens = inputs.shape[1]\n",
"  assert num_tokens \u003c= max_sequence_length\n",
"  # Leave axis 0 untouched and append the missing positions to axis 1.\n",
"  right_padding = ((0, 0), (0, max_sequence_length - num_tokens))\n",
"  padded_inputs = np.pad(\n",
"      inputs, pad_width=right_padding, constant_values=tokenizer.pad_token)\n",
"  padded_mask = np.pad(\n",
"      input_mask, pad_width=right_padding, constant_values=0)\n",
"  return padded_inputs, padded_mask\n",
"\n",
"inputs, input_mask = pad(MAX_SEQ_LEN, inputs, input_mask)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ipZs6p0Xk3lb"
},
"outputs": [],
"source": [
"#@title Run the model\n",
"\n",
"# A key is passed because apply_perceiver is called with an rng argument.\n",
"rng = jax.random.PRNGKey(1)  # Unused\n",
"\n",
"out = apply_perceiver(params, rng=rng, inputs=inputs, input_mask=input_mask)\n",
"\n",
"# Greedy decoding: take the highest-scoring token at each of the positions\n",
"# that were replaced by the mask token (51 up to, but excluding, 60).\n",
"masked_tokens_predictions = out[0, slice(51, 60)].argmax(axis=-1)\n",
"print(\"Greedy predictions:\", masked_tokens_predictions, sep=\"\\n\")\n",
"print()\n",
"print(\n",
"    \"Predicted string:\",\n",
"    tokenizer.to_string(masked_tokens_predictions),\n",
"    sep=\"\\n\")"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
"kind": "private"
},
"name": "Perceiver: Masked Language Modelling",
"private_outputs": true,
"provenance": [
{
"file_id": "1N31dQM-SzjG-_acz405i3jCR_m6D4bq8",
"timestamp": 1627567455889
}
],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}