mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-13 07:17:38 +08:00
Follow common naming convention for all notebooks.
PiperOrigin-RevId: 387819214
This commit is contained in:
committed by
Diego de Las Casas
parent
316297879f
commit
ab87c9cd72
@@ -4,7 +4,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "VHzUTH5KqNEt"
|
||||
"id": "532jUiWVFvuK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -20,8 +20,18 @@
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License.\n",
|
||||
"\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "VHzUTH5KqNEt"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Imports\n",
|
||||
"\n",
|
||||
"import functools\n",
|
||||
@@ -1265,6 +1275,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "Z7ZQJ2auy4Lt"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -1340,7 +1351,7 @@
|
||||
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
|
||||
"kind": "private"
|
||||
},
|
||||
"name": "ImageNet Classification.ipynb",
|
||||
"name": "Perceiver IO: ImageNet Classification.ipynb",
|
||||
"private_outputs": true,
|
||||
"provenance": [
|
||||
{
|
||||
@@ -1351,7 +1362,8 @@
|
||||
"file_id": "1bt4J3-jS7C-xZQrtx0AgAeKSBxOUoOPy",
|
||||
"timestamp": 1627592566036
|
||||
}
|
||||
]
|
||||
],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "IBVh06qUojjm"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -47,6 +48,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "sFa-lRuVfKZt"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -62,6 +64,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "LBzQQ7t_VCBo"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -104,6 +107,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "EWOeFoF0aCaT"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -154,7 +158,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Create input\n",
|
||||
"input_str = \"This is an incomplete sentence where some words are missing.\"\n",
|
||||
"input_tokens = tokenizer.to_int(input_str)\n",
|
||||
"\n",
|
||||
@@ -169,6 +172,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "6TCMuVUabnTg"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -202,8 +206,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Run the model\n",
|
||||
"\n",
|
||||
"rng = jax.random.PRNGKey(1) # Unused\n",
|
||||
"\n",
|
||||
"out = apply_perceiver(params, rng=rng, inputs=inputs, input_mask=input_mask)\n",
|
||||
@@ -224,7 +226,7 @@
|
||||
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
|
||||
"kind": "private"
|
||||
},
|
||||
"name": "Perceiver: Masked Language Modelling",
|
||||
"name": "Perceiver IO: Masked Language Modelling",
|
||||
"private_outputs": true,
|
||||
"provenance": [
|
||||
{
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "VHzUTH5KqNEt"
|
||||
"id": "AEigJ-mOGOk9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -20,8 +20,18 @@
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License.\n",
|
||||
"\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "VHzUTH5KqNEt"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Imports\n",
|
||||
"\n",
|
||||
"import functools\n",
|
||||
@@ -44,6 +54,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "uxeP5yit7hJg"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -125,6 +136,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "dmvRv3o-6ASw"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -134,7 +146,7 @@
|
||||
"# If you encounter GPU memory errors while running the function below,\n",
|
||||
"# you can run it on the CPU instead:\n",
|
||||
"# _apply_optical_flow_model = jax.jit(optical_flow.apply, backend=\"cpu\")\n",
|
||||
"_apply_optical_flow = jax.jit(optical_flow.apply)\n",
|
||||
"_apply_optical_flow_model = jax.jit(optical_flow.apply)\n",
|
||||
"\n",
|
||||
"def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):\n",
|
||||
" if min_overlap \u003e= TRAIN_SIZE[0] or min_overlap \u003e= TRAIN_SIZE[1]:\n",
|
||||
@@ -202,6 +214,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "EVRWatw4LXFx"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -255,6 +268,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "Z7ZQJ2auy4Lt"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -309,14 +323,15 @@
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "Optical_Flow_Visualization.ipynb",
|
||||
"name": "Perceiver IO: Optical Flow Visualization.ipynb",
|
||||
"private_outputs": true,
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1bt4J3-jS7C-xZQrtx0AgAeKSBxOUoOPy",
|
||||
"timestamp": 1627577366926
|
||||
}
|
||||
]
|
||||
],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "VHzUTH5KqNEt"
|
||||
"id": "6hVqbkgBFVKB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -20,8 +20,18 @@
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License.\n",
|
||||
"\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "VHzUTH5KqNEt"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Imports\n",
|
||||
"\n",
|
||||
"import base64\n",
|
||||
@@ -51,6 +61,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "Bn1jTwkv3gHf"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -134,6 +145,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "QqXUfdsF3iZ6"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -358,6 +370,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "uxeP5yit7hJg"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -551,6 +564,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "EVRWatw4LXFx"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -613,9 +627,9 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Auto-encode the entire video, one frame at a time\n",
|
||||
"# Auto-encode the entire video, one chunk at a time\n",
|
||||
"\n",
|
||||
"# Partial video and audio into 16 frame chunks\n",
|
||||
"# Partial video and audio into 16-frame chunks\n",
|
||||
"nframes = video.shape[0]\n",
|
||||
"# Truncate to be divisible by 16\n",
|
||||
"nframes = nframes - (nframes % 16)\n",
|
||||
@@ -623,9 +637,23 @@
|
||||
"audio_chunks = jnp.reshape(audio[:nframes * AUDIO_SAMPLES_PER_FRAME],\n",
|
||||
" [nframes // 16, 16 * AUDIO_SAMPLES_PER_FRAME, 2])\n",
|
||||
"\n",
|
||||
"reconstruction = jax.vmap(jax.jit(functools.partial(autoencode_video, params, state, rng), backend='gpu'),\n",
|
||||
" in_axes=1, out_axes=1)(video_chunks[None, :], audio_chunks[None, :, :, 0:1])\n",
|
||||
"reconstruction = jax.tree_map(lambda x: jnp.reshape(x, [-1] + list(x.shape[3:])), reconstruction)"
|
||||
"encode = jax.jit(functools.partial(autoencode_video, params, state, rng))\n",
|
||||
"\n",
|
||||
"# Logically, what we do is the following code. We write out the loop to allocate\n",
|
||||
"# GPU memory for only one chunk\n",
|
||||
"#\n",
|
||||
"# reconstruction = jax.vmap(encode, in_axes=1, out_axes=1)(\n",
|
||||
"# video_chunks[None, :], audio_chunks[None, :, :, 0:1])\n",
|
||||
"\n",
|
||||
"chunks = []\n",
|
||||
"for i in range(nframes // 16):\n",
|
||||
" reconstruction = encode(video_chunks[None, i], audio_chunks[None, i, :, 0:1])\n",
|
||||
" chunks.append(jax.tree_map(lambda x: np.array(x), reconstruction))\n",
|
||||
"\n",
|
||||
"reconstruction = jax.tree_multimap(lambda *args: np.stack(args, axis=1),\n",
|
||||
" *chunks)\n",
|
||||
"\n",
|
||||
"reconstruction = jax.tree_map(lambda x: np.reshape(x, [-1] + list(x.shape[3:])), reconstruction)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -648,14 +676,15 @@
|
||||
"build_target": "",
|
||||
"kind": "local"
|
||||
},
|
||||
"name": "Video_Autoencoding.ipynb",
|
||||
"name": "Perceiver IO: Video Autoencoding.ipynb",
|
||||
"private_outputs": true,
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1qD1lvE-5c4LVw9l7H9XjA3DNtiYIcTgj",
|
||||
"timestamp": 1626089023488
|
||||
}
|
||||
]
|
||||
],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
|
||||
Reference in New Issue
Block a user