mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-21 05:54:41 +08:00
Add the initial language colab for inference only
PiperOrigin-RevId: 387625457
This commit is contained in:
committed by
Diego de Las Casas
parent
310620d045
commit
8c431d40ea
@@ -0,0 +1,247 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "uFsBlyxwmRq2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright 2021 DeepMind Technologies Limited\n",
|
||||
"#\n",
|
||||
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "IBVh06qUojjm"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Import\n",
|
||||
"from typing import Union\n",
|
||||
"\n",
|
||||
"import haiku as hk\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"import numpy as np\n",
|
||||
"import pickle\n",
|
||||
"\n",
|
||||
"from perceiver import perceiver, position_encoding, io_processors, bytes_tokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "sFa-lRuVfKZt"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Load parameters from checkpoint\n",
|
||||
"!wget -O language_perceiver_io_bytes.pickle https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle\n",
|
||||
"\n",
|
||||
"with open(\"language_perceiver_io_bytes.pickle\", \"rb\") as f:\n",
|
||||
" params = pickle.loads(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "LBzQQ7t_VCBo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Model config\n",
|
||||
"D_MODEL = 768\n",
|
||||
"D_LATENTS = 1280\n",
|
||||
"MAX_SEQ_LEN = 2048\n",
|
||||
"\n",
|
||||
"encoder_config = dict(\n",
|
||||
" num_self_attends_per_block=26,\n",
|
||||
" num_blocks=1,\n",
|
||||
" z_index_dim=256,\n",
|
||||
" num_z_channels=D_LATENTS,\n",
|
||||
" num_self_attend_heads=8,\n",
|
||||
" num_cross_attend_heads=8,\n",
|
||||
" qk_channels=8 * 32,\n",
|
||||
" v_channels=D_LATENTS,\n",
|
||||
" use_query_residual=True,\n",
|
||||
" cross_attend_widening_factor=1,\n",
|
||||
" self_attend_widening_factor=1)\n",
|
||||
"\n",
|
||||
"decoder_config = dict(\n",
|
||||
" output_num_channels=D_LATENTS,\n",
|
||||
" position_encoding_type='trainable',\n",
|
||||
" output_index_dims=MAX_SEQ_LEN,\n",
|
||||
" num_z_channels=D_LATENTS,\n",
|
||||
" qk_channels=8 * 32,\n",
|
||||
" v_channels=D_MODEL,\n",
|
||||
" num_heads=8,\n",
|
||||
" final_project=False,\n",
|
||||
" use_query_residual=False,\n",
|
||||
" trainable_position_encoding_kwargs=dict(num_channels=D_MODEL))\n",
|
||||
"\n",
|
||||
"# The tokenizer is just UTF-8 encoding (with an offset)\n",
|
||||
"tokenizer = bytes_tokenizer.BytesTokenizer()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "EWOeFoF0aCaT"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Decoding Perceiver Model\n",
|
||||
"def apply_perceiver(\n",
|
||||
" inputs: jnp.ndarray, input_mask: jnp.ndarray) -\u003e jnp.ndarray:\n",
|
||||
" \"\"\"Runs a forward pass on the Perceiver.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" inputs: input bytes, an int array of shape [B, T]\n",
|
||||
" input_mask: Array of shape indicating which entries are valid and which are\n",
|
||||
" masked. A truthy value indicates that the entry is valid.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" The output logits, an array of shape [B, T, vocab_size].\n",
|
||||
" \"\"\"\n",
|
||||
" assert inputs.shape[1] == MAX_SEQ_LEN\n",
|
||||
"\n",
|
||||
" embedding_layer = hk.Embed(\n",
|
||||
" vocab_size=tokenizer.vocab_size,\n",
|
||||
" embed_dim=D_MODEL)\n",
|
||||
" embedded_inputs = embedding_layer(inputs)\n",
|
||||
"\n",
|
||||
" batch_size = embedded_inputs.shape[0]\n",
|
||||
"\n",
|
||||
" input_pos_encoding = perceiver.position_encoding.TrainablePositionEncoding(\n",
|
||||
" index_dim=MAX_SEQ_LEN, num_channels=D_MODEL)\n",
|
||||
" embedded_inputs = embedded_inputs + input_pos_encoding(batch_size)\n",
|
||||
" perceiver_mod = perceiver.Perceiver(\n",
|
||||
" encoder=perceiver.PerceiverEncoder(**encoder_config),\n",
|
||||
" decoder=perceiver.BasicDecoder(**decoder_config))\n",
|
||||
" output_embeddings = perceiver_mod(\n",
|
||||
" embedded_inputs, is_training=False, input_mask=input_mask, query_mask=input_mask)\n",
|
||||
"\n",
|
||||
" logits = io_processors.EmbeddingDecoder(\n",
|
||||
" embedding_matrix=embedding_layer.embeddings)(output_embeddings)\n",
|
||||
" return logits\n",
|
||||
"\n",
|
||||
"apply_perceiver = hk.transform(apply_perceiver).apply"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Pna1ZXEyOJZb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Create input\n",
|
||||
"input_str = \"This is an incomplete sentence where some words are missing.\"\n",
|
||||
"input_tokens = tokenizer.to_int(input_str)\n",
|
||||
"\n",
|
||||
"# Mask \" missing.\". Note that the model performs much better if the masked chunk\n",
|
||||
"# starts with a space.\n",
|
||||
"input_tokens[51:60] = tokenizer.mask_token\n",
|
||||
"print(\"Tokenized string without masked bytes:\")\n",
|
||||
"print(tokenizer.to_string(input_tokens))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "6TCMuVUabnTg"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Pad and reshape inputs\n",
|
||||
"inputs = input_tokens[None]\n",
|
||||
"input_mask = np.ones_like(inputs)\n",
|
||||
"\n",
|
||||
"def pad(max_sequence_length: int, inputs, input_mask):\n",
|
||||
" input_len = inputs.shape[1]\n",
|
||||
" assert input_len \u003c= max_sequence_length\n",
|
||||
" pad_len = max_sequence_length - input_len\n",
|
||||
" padded_inputs = np.pad(\n",
|
||||
" inputs,\n",
|
||||
" pad_width=((0, 0), (0, pad_len)),\n",
|
||||
" constant_values=tokenizer.pad_token)\n",
|
||||
" padded_mask = np.pad(\n",
|
||||
" input_mask,\n",
|
||||
" pad_width=((0, 0), (0, pad_len)),\n",
|
||||
" constant_values=0)\n",
|
||||
" return padded_inputs, padded_mask\n",
|
||||
"\n",
|
||||
"inputs, input_mask = pad(MAX_SEQ_LEN, inputs, input_mask)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ipZs6p0Xk3lb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Run the model\n",
|
||||
"\n",
|
||||
"rng = jax.random.PRNGKey(1) # Unused\n",
|
||||
"\n",
|
||||
"out = apply_perceiver(params, rng=rng, inputs=inputs, input_mask=input_mask)\n",
|
||||
"\n",
|
||||
"masked_tokens_predictions = out[0, 51:60].argmax(axis=-1)\n",
|
||||
"print(\"Greedy predictions:\")\n",
|
||||
"print(masked_tokens_predictions)\n",
|
||||
"print()\n",
|
||||
"print(\"Predicted string:\")\n",
|
||||
"print(tokenizer.to_string(masked_tokens_predictions))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"last_runtime": {
|
||||
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
|
||||
"kind": "private"
|
||||
},
|
||||
"name": "Perceiver: Masked Language Modelling",
|
||||
"private_outputs": true,
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1N31dQM-SzjG-_acz405i3jCR_m6D4bq8",
|
||||
"timestamp": 1627567455889
|
||||
}
|
||||
],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
Reference in New Issue
Block a user