diff --git a/perceiver/colabs/imagenet_classification.ipynb b/perceiver/colabs/imagenet_classification.ipynb index eb93ae7..608481a 100644 --- a/perceiver/colabs/imagenet_classification.ipynb +++ b/perceiver/colabs/imagenet_classification.ipynb @@ -4,7 +4,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "VHzUTH5KqNEt" + "id": "532jUiWVFvuK" }, "outputs": [], "source": [ @@ -20,8 +20,18 @@ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "VHzUTH5KqNEt" + }, + "outputs": [], + "source": [ "#@title Imports\n", "\n", "import functools\n", @@ -1265,6 +1275,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "Z7ZQJ2auy4Lt" }, "outputs": [], @@ -1340,7 +1351,7 @@ "build_target": "//learning/deepmind/dm_python:dm_notebook3", "kind": "private" }, - "name": "ImageNet Classification.ipynb", + "name": "Perceiver IO: ImageNet Classification.ipynb", "private_outputs": true, "provenance": [ { @@ -1351,7 +1362,8 @@ "file_id": "1bt4J3-jS7C-xZQrtx0AgAeKSBxOUoOPy", "timestamp": 1627592566036 } - ] + ], + "toc_visible": true }, "kernelspec": { "display_name": "Python 3", diff --git a/perceiver/colabs/masked_language_modelling.ipynb b/perceiver/colabs/masked_language_modelling.ipynb index 0fbd27e..0d430e8 100644 --- a/perceiver/colabs/masked_language_modelling.ipynb +++ b/perceiver/colabs/masked_language_modelling.ipynb @@ -27,6 +27,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "IBVh06qUojjm" }, "outputs": [], @@ -47,6 +48,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "sFa-lRuVfKZt" }, "outputs": [], @@ -62,6 +64,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "LBzQQ7t_VCBo" }, "outputs": [], @@ -104,6 +107,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "EWOeFoF0aCaT" }, "outputs": [], @@ -154,7 +158,6 @@ }, "outputs": [], "source": [ - "#@title Create input\n", "input_str = \"This is an incomplete sentence where some words are missing.\"\n", "input_tokens = tokenizer.to_int(input_str)\n", "\n", @@ -169,6 +172,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "6TCMuVUabnTg" }, "outputs": [], @@ -202,8 +206,6 @@ }, "outputs": [], "source": [ - "#@title Run the model\n", - "\n", "rng = jax.random.PRNGKey(1) # Unused\n", "\n", "out = apply_perceiver(params, rng=rng, inputs=inputs, input_mask=input_mask)\n", @@ -224,7 +226,7 @@ "build_target": "//learning/deepmind/dm_python:dm_notebook3", "kind": "private" }, - "name": "Perceiver: Masked Language Modelling", + "name": "Perceiver IO: Masked Language Modelling", "private_outputs": true, "provenance": [ { diff --git a/perceiver/colabs/optical_flow.ipynb b/perceiver/colabs/optical_flow.ipynb index 91103d9..db32cd5 100644 --- a/perceiver/colabs/optical_flow.ipynb +++ b/perceiver/colabs/optical_flow.ipynb @@ -4,7 +4,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "VHzUTH5KqNEt" + "id": "AEigJ-mOGOk9" }, "outputs": [], "source": [ @@ -20,8 +20,18 @@ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "VHzUTH5KqNEt" + }, + "outputs": [], + "source": [ "#@title Imports\n", "\n", "import functools\n", @@ -44,6 +54,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "uxeP5yit7hJg" }, "outputs": [], @@ -125,6 +136,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "dmvRv3o-6ASw" }, "outputs": [], @@ -134,7 +146,7 @@ "# If you encounter GPU memory errors while running the function below,\n", "# you can run it on the CPU instead:\n", "# _apply_optical_flow_model = jax.jit(optical_flow.apply, backend=\"cpu\")\n", - "_apply_optical_flow = jax.jit(optical_flow.apply)\n", + "_apply_optical_flow_model = jax.jit(optical_flow.apply)\n", "\n", "def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):\n", " if min_overlap \u003e= TRAIN_SIZE[0] or min_overlap \u003e= TRAIN_SIZE[1]:\n", @@ -202,6 +214,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "EVRWatw4LXFx" }, "outputs": [], @@ -255,6 +268,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "Z7ZQJ2auy4Lt" }, "outputs": [], @@ -309,14 +323,15 @@ "metadata": { "colab": { "collapsed_sections": [], - "name": "Optical_Flow_Visualization.ipynb", + "name": "Perceiver IO: Optical Flow Visualization.ipynb", "private_outputs": true, "provenance": [ { "file_id": "1bt4J3-jS7C-xZQrtx0AgAeKSBxOUoOPy", "timestamp": 1627577366926 } - ] + ], + "toc_visible": true }, "kernelspec": { "display_name": "Python 3", diff --git a/perceiver/colabs/video_autoencoding.ipynb b/perceiver/colabs/video_autoencoding.ipynb index 8c64dc7..8e08f1c 100644 --- a/perceiver/colabs/video_autoencoding.ipynb +++ b/perceiver/colabs/video_autoencoding.ipynb @@ -4,7 +4,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "VHzUTH5KqNEt" + "id": "6hVqbkgBFVKB" }, "outputs": [], "source": [ @@ -20,8 +20,18 @@ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "VHzUTH5KqNEt" + }, + "outputs": [], + "source": [ "#@title Imports\n", "\n", "import base64\n", @@ -51,6 +61,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "Bn1jTwkv3gHf" }, "outputs": [], @@ -134,6 +145,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "QqXUfdsF3iZ6" }, "outputs": [], @@ -358,6 +370,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "uxeP5yit7hJg" }, "outputs": [], @@ -551,6 +564,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "EVRWatw4LXFx" }, "outputs": [], @@ -613,9 +627,9 @@ }, "outputs": [], "source": [ - "# Auto-encode the entire video, one frame at a time\n", + "# Auto-encode the entire video, one chunk at a time\n", "\n", - "# Partial video and audio into 16 frame chunks\n", + "# Partial video and audio into 16-frame chunks\n", "nframes = video.shape[0]\n", "# Truncate to be divisible by 16\n", "nframes = nframes - (nframes % 16)\n", @@ -623,9 +637,23 @@ "audio_chunks = jnp.reshape(audio[:nframes * AUDIO_SAMPLES_PER_FRAME],\n", " [nframes // 16, 16 * AUDIO_SAMPLES_PER_FRAME, 2])\n", "\n", - "reconstruction = jax.vmap(jax.jit(functools.partial(autoencode_video, params, state, rng), backend='gpu'),\n", - " in_axes=1, out_axes=1)(video_chunks[None, :], audio_chunks[None, :, :, 0:1])\n", - "reconstruction = jax.tree_map(lambda x: jnp.reshape(x, [-1] + list(x.shape[3:])), reconstruction)" + "encode = jax.jit(functools.partial(autoencode_video, params, state, rng))\n", + "\n", + "# Logically, what we do is the following code. We write out the loop to allocate\n", + "# GPU memory for only one chunk\n", + "#\n", + "# reconstruction = jax.vmap(encode, in_axes=1, out_axes=1)(\n", + "# video_chunks[None, :], audio_chunks[None, :, :, 0:1])\n", + "\n", + "chunks = []\n", + "for i in range(nframes // 16):\n", + " reconstruction = encode(video_chunks[None, i], audio_chunks[None, i, :, 0:1])\n", + " chunks.append(jax.tree_map(lambda x: np.array(x), reconstruction))\n", + "\n", + "reconstruction = jax.tree_multimap(lambda *args: np.stack(args, axis=1),\n", + " *chunks)\n", + "\n", + "reconstruction = jax.tree_map(lambda x: np.reshape(x, [-1] + list(x.shape[3:])), reconstruction)" ] }, { @@ -648,14 +676,15 @@ "build_target": "", "kind": "local" }, - "name": "Video_Autoencoding.ipynb", + "name": "Perceiver IO: Video Autoencoding.ipynb", "private_outputs": true, "provenance": [ { "file_id": "1qD1lvE-5c4LVw9l7H9XjA3DNtiYIcTgj", "timestamp": 1626089023488 } - ] + ], + "toc_visible": true }, "kernelspec": { "display_name": "Python 3",