diff --git a/enformer/README.md b/enformer/README.md
index 6fabf39..88fea20 100644
--- a/enformer/README.md
+++ b/enformer/README.md
@@ -120,8 +120,9 @@ df_targets.shape # (5313, 8)
 With rows match output shape above.
 
 The model is implemented using [Sonnet](https://github.com/deepmind/sonnet).
 The full sonnet module is defined in `enformer.py` called Enformer. See
-[enformer-training.ipynb](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-training.ipynb).
-on how to train the model on Basenji2 data.
+[enformer-training.ipynb](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-training.ipynb)
+on how to train the model on Basenji2 data and how to load the pre-trained
+weights into the Enformer module for fine-tuning.
 
 ## Colab
 
@@ -143,6 +144,7 @@ This colab shows how to:
 * Setup training data by directly accessing the Basenji2 data on GCS
 * Train the model for a few steps on both human and mouse genomes
 * Evaluate the model on human and mouse genomes
+* Restore the model from a checkpoint for fine-tuning
 
 ## Disclaimer
 
diff --git a/enformer/enformer-training.ipynb b/enformer/enformer-training.ipynb
index 0681e33..b459c87 100644
--- a/enformer/enformer-training.ipynb
+++ b/enformer/enformer-training.ipynb
@@ -1,22 +1,4 @@
 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "name": "enformer-training.ipynb",
-      "provenance": [],
-      "collapsed_sections": [],
-      "toc_visible": true
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "language_info": {
-      "name": "python"
-    },
-    "accelerator": "GPU"
-  },
   "cells": [
     {
       "cell_type": "markdown",
@@ -80,7 +62,7 @@
         "id": "NqR7ol3rxrtM"
       },
       "source": [
-        "**Start the colab kernel with GPU**: Runtime -> Change runtime type -> GPU"
+        "**Start the colab kernel with GPU**: Runtime -\u003e Change runtime type -\u003e GPU"
       ]
     },
     {
@@ -94,6 +76,7 @@
     },
     {
      "cell_type": "code",
+      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
@@ -101,43 +84,42 @@
        "id": "WiDFm-a41tKW",
        "outputId": "8b889c6e-f113-4664-f2c9-91110808ad92"
      },
-      "source": [
-        "!pip install dm-sonnet tqdm"
-      ],
-      "execution_count": 1,
      "outputs": [
        {
+          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting dm-sonnet\n",
            "\u001b[?25l Downloading https://files.pythonhosted.org/packages/13/28/9185afffefb655ef1a29f4b84aa9f656826408ca2d1b9ffeba81fbfd40ec/dm_sonnet-2.0.0-py3-none-any.whl (254kB)\n",
            "\r\u001b[K |█▎ | 10kB 13.3MB/s eta 0:00:01\r\u001b[K |██▋ | 20kB 11.7MB/s eta 0:00:01\r\u001b[K |███▉ | 30kB 8.7MB/s eta 0:00:01\r\u001b[K |█████▏ | 40kB 7.7MB/s eta 0:00:01\r\u001b[K |██████▍ | 51kB 4.5MB/s eta 0:00:01\r\u001b[K |███████▊ | 61kB 5.1MB/s eta 0:00:01\r\u001b[K |█████████ | 71kB 5.1MB/s eta 0:00:01\r\u001b[K |██████████▎ | 81kB 5.5MB/s eta 0:00:01\r\u001b[K |███████████▋ | 92kB 5.5MB/s eta 0:00:01\r\u001b[K |████████████▉ | 102kB 5.7MB/s eta 0:00:01\r\u001b[K |██████████████▏ | 112kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████▌ | 122kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████▊ | 133kB 5.7MB/s eta 0:00:01\r\u001b[K |██████████████████ | 143kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████████▎ | 153kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████████▋ | 163kB 5.7MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 174kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████████████▏ | 184kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████████████▌ | 194kB 5.7MB/s eta 0:00:01\r\u001b[K |█████████████████████████▊ 
| 204kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████████████████ | 215kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████████████████▎ | 225kB 5.7MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▋ | 235kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████████████████████ | 245kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 256kB 5.7MB/s \n", "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n", - "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.15.0)\n", - "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.12.0)\n", - "Requirement already satisfied: numpy>=1.16.3 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.19.5)\n", - "Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.12.1)\n", - "Requirement already satisfied: tabulate>=0.7.5 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.8.9)\n", - "Requirement already satisfied: dm-tree>=0.1.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.1.5)\n", + "Requirement already satisfied: six\u003e=1.12.0 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.15.0)\n", + "Requirement already satisfied: absl-py\u003e=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.12.0)\n", + "Requirement already satisfied: numpy\u003e=1.16.3 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.19.5)\n", + "Requirement already satisfied: wrapt\u003e=1.11.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.12.1)\n", + "Requirement already satisfied: tabulate\u003e=0.7.5 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.8.9)\n", + "Requirement already satisfied: dm-tree\u003e=0.1.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.1.5)\n", "Installing collected packages: dm-sonnet\n", "Successfully installed dm-sonnet-2.0.0\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "!pip install dm-sonnet tqdm" ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "CokqDsb-fxme" }, + "outputs": [], "source": [ "# Get enformer source code\n", "!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/attention_module.py\n", "!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/enformer.py" - ], - "execution_count": 15, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -150,37 +132,39 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "hTGOLrbZxNHK", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "hTGOLrbZxNHK", "outputId": "f58b5c21-0764-4003-c794-aa89e5d336cc" }, - "source": [ - "import tensorflow as tf\n", - "# Make sure the GPU is enabled \n", - "assert tf.config.list_physical_devices('GPU'), 'Start the colab kernel with GPU: Runtime -> Change runtime type -> GPU'\n", - "\n", - "# Easier debugging of OOM\n", - "%env TF_ENABLE_GPU_GARBAGE_COLLECTION=false" - ], - "execution_count": 4, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "env: TF_ENABLE_GPU_GARBAGE_COLLECTION=false\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "import tensorflow as tf\n", + "# Make sure the GPU is enabled \n", + "assert tf.config.list_physical_devices('GPU'), 'Start the colab kernel with GPU: Runtime -\u003e Change runtime type -\u003e GPU'\n", + "\n", + "# Easier debugging of OOM\n", + "%env 
TF_ENABLE_GPU_GARBAGE_COLLECTION=false" ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "S9ywsUmT05C1" }, + "outputs": [], "source": [ "import sonnet as snt\n", "from tqdm import tqdm\n", @@ -189,23 +173,22 @@ "import pandas as pd\n", "import time\n", "import os" - ], - "execution_count": 5, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "YUIbu0Xu1BnA" }, + "outputs": [], "source": [ "assert snt.__version__.startswith('2.0')" - ], - "execution_count": 6, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -214,13 +197,8 @@ "id": "PWzsyJddILcx", "outputId": "3f1cac0f-6bce-430e-b3c0-9848d43e654c" }, - "source": [ - "tf.__version__" - ], - "execution_count": 7, "outputs": [ { - "output_type": "execute_result", "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" @@ -229,15 +207,20 @@ "'2.4.1'" ] }, + "execution_count": 7, "metadata": { "tags": [] }, - "execution_count": 7 + "output_type": "execute_result" } + ], + "source": [ + "tf.__version__" ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -245,13 +228,9 @@ "id": "xOhdaXG95eOl", "outputId": "1e57ef49-254a-4050-89af-61bc0f8ea577" }, - "source": [ - "# GPU colab has T4 with 16 GiB of memory\n", - "!nvidia-smi" - ], - "execution_count": 8, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "Fri Mar 26 12:28:00 2021 \n", @@ -274,9 +253,12 @@ "|=============================================================================|\n", "| No running processes found |\n", "+-----------------------------------------------------------------------------+\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "# GPU colab has T4 with 16 GiB of memory\n", + "!nvidia-smi" ] }, { @@ -290,34 +272,36 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "BbXyDdoShFzX" }, + "outputs": [], "source": [ "import enformer" - ], - "execution_count": 37, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "MEb8OZli2Nbu" }, + "outputs": [], "source": [ "# @title `get_targets(organism)`\n", "def get_targets(organism):\n", " targets_txt = f'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_{organism}.txt'\n", " return pd.read_csv(targets_txt, sep='\\t')" - ], - "execution_count": 41, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "2BuZ2gmUbpXZ" }, + "outputs": [], "source": [ "# @title `get_dataset(organism, subset, num_threads=8)`\n", "import glob\n", @@ -373,9 +357,7 @@ "\n", " return {'sequence': sequence,\n", " 'target': target}\n" - ], - "execution_count": 42, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -388,6 +370,7 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -396,18 +379,12 @@ "id": "M_vr1mbl3jbD", "outputId": "2de351ed-f43e-4469-a681-2a437d97c946" }, - "source": [ - "df_targets_human = get_targets('human')\n", - "df_targets_human.head()" - ], - "execution_count": 43, "outputs": [ { - "output_type": "execute_result", "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
indexgenomeidentifierfileclipscalesum_statdescription
000ENCFF833POA/home/drk/tillage/datasets/human/dnase/encode/...322meanDNASE:cerebellum male adult (27 years) and mal...
110ENCFF110QGM/home/drk/tillage/datasets/human/dnase/encode/...322meanDNASE:frontal cortex male adult (27 years) and...
220ENCFF880MKD/home/drk/tillage/datasets/human/dnase/encode/...322meanDNASE:chorion
330ENCFF463ZLQ/home/drk/tillage/datasets/human/dnase/encode/...322meanDNASE:Ishikawa treated with 0.02% dimethyl sul...
440ENCFF890OGQ/home/drk/tillage/datasets/human/dnase/encode/...322meanDNASE:GM03348
\n", - "
" + "\u003c/style\u003e\n", + "\u003ctable border=\"1\" class=\"dataframe\"\u003e\n", + " \u003cthead\u003e\n", + " \u003ctr style=\"text-align: right;\"\u003e\n", + " \u003cth\u003e\u003c/th\u003e\n", + " \u003cth\u003eindex\u003c/th\u003e\n", + " \u003cth\u003egenome\u003c/th\u003e\n", + " \u003cth\u003eidentifier\u003c/th\u003e\n", + " \u003cth\u003efile\u003c/th\u003e\n", + " \u003cth\u003eclip\u003c/th\u003e\n", + " \u003cth\u003escale\u003c/th\u003e\n", + " \u003cth\u003esum_stat\u003c/th\u003e\n", + " \u003cth\u003edescription\u003c/th\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/thead\u003e\n", + " \u003ctbody\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e0\u003c/th\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003eENCFF833POA\u003c/td\u003e\n", + " \u003ctd\u003e/home/drk/tillage/datasets/human/dnase/encode/...\u003c/td\u003e\n", + " \u003ctd\u003e32\u003c/td\u003e\n", + " \u003ctd\u003e2\u003c/td\u003e\n", + " \u003ctd\u003emean\u003c/td\u003e\n", + " \u003ctd\u003eDNASE:cerebellum male adult (27 years) and mal...\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e1\u003c/th\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003eENCFF110QGM\u003c/td\u003e\n", + " \u003ctd\u003e/home/drk/tillage/datasets/human/dnase/encode/...\u003c/td\u003e\n", + " \u003ctd\u003e32\u003c/td\u003e\n", + " \u003ctd\u003e2\u003c/td\u003e\n", + " \u003ctd\u003emean\u003c/td\u003e\n", + " \u003ctd\u003eDNASE:frontal cortex male adult (27 years) and...\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e2\u003c/th\u003e\n", + " \u003ctd\u003e2\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003eENCFF880MKD\u003c/td\u003e\n", + " \u003ctd\u003e/home/drk/tillage/datasets/human/dnase/encode/...\u003c/td\u003e\n", + " \u003ctd\u003e32\u003c/td\u003e\n", + " \u003ctd\u003e2\u003c/td\u003e\n", + " \u003ctd\u003emean\u003c/td\u003e\n", + " \u003ctd\u003eDNASE:chorion\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e3\u003c/th\u003e\n", + " \u003ctd\u003e3\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003eENCFF463ZLQ\u003c/td\u003e\n", + " \u003ctd\u003e/home/drk/tillage/datasets/human/dnase/encode/...\u003c/td\u003e\n", + " \u003ctd\u003e32\u003c/td\u003e\n", + " \u003ctd\u003e2\u003c/td\u003e\n", + " \u003ctd\u003emean\u003c/td\u003e\n", + " \u003ctd\u003eDNASE:Ishikawa treated with 0.02% dimethyl sul...\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e4\u003c/th\u003e\n", + " \u003ctd\u003e4\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003eENCFF890OGQ\u003c/td\u003e\n", + " \u003ctd\u003e/home/drk/tillage/datasets/human/dnase/encode/...\u003c/td\u003e\n", + " \u003ctd\u003e32\u003c/td\u003e\n", + " \u003ctd\u003e2\u003c/td\u003e\n", + " \u003ctd\u003emean\u003c/td\u003e\n", + " \u003ctd\u003eDNASE:GM03348\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/tbody\u003e\n", + "\u003c/table\u003e\n", + "\u003c/div\u003e" ], "text/plain": [ " index genome ... 
sum_stat description\n", @@ -505,40 +482,46 @@ "[5 rows x 8 columns]" ] }, + "execution_count": 43, "metadata": { "tags": [] }, - "execution_count": 43 + "output_type": "execute_result" } + ], + "source": [ + "df_targets_human = get_targets('human')\n", + "df_targets_human.head()" ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "YDSKttXI4hMT" }, + "outputs": [], "source": [ "human_dataset = get_dataset('human', 'train').batch(1).repeat()\n", "mouse_dataset = get_dataset('mouse', 'train').batch(1).repeat()\n", "human_mouse_dataset = tf.data.Dataset.zip((human_dataset, mouse_dataset)).prefetch(2)" - ], - "execution_count": 44, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "2vx3116C7oFW" }, + "outputs": [], "source": [ "it = iter(mouse_dataset)\n", "example = next(it)" - ], - "execution_count": 45, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -546,6 +529,18 @@ "id": "XeztqJZ74ixT", "outputId": "39dc4051-5a19-4443-b6b0-bf6869faf5ec" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "human\n", + "{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 5313]), tf.float32)}\n", + "mouse\n", + "{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 1643]), tf.float32)}\n" + ] + } + ], "source": [ "# Example input\n", "it = iter(human_mouse_dataset)\n", @@ -553,19 +548,6 @@ "for i in range(len(example)):\n", " print(['human', 'mouse'][i])\n", " print({k: (v.shape, v.dtype) for k,v in example[i].items()})" - ], - "execution_count": 46, - "outputs": [ - { - "output_type": "stream", - "text": [ - "human\n", - "{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 5313]), tf.float32)}\n", - "mouse\n", - "{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 1643]), tf.float32)}\n" - ], - "name": "stdout" - } ] }, { @@ -579,9 +561,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "0U3hLJaUdZkG" }, + "outputs": [], "source": [ "def create_step_function(model, optimizer):\n", "\n", @@ -597,15 +581,15 @@ "\n", " return loss\n", " return train_step" - ], - "execution_count": 51, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "ZXv5HU_242Ut" }, + "outputs": [], "source": [ "learning_rate = tf.Variable(0., trainable=False, name='learning_rate')\n", "optimizer = snt.optimizers.Adam(learning_rate=learning_rate)\n", @@ -618,20 +602,107 @@ " pooling_type='max')\n", "\n", "train_step = create_step_function(model, optimizer)" - ], - "execution_count": 52, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "FrbDaOMWcFUl", + "cellView": "code", "colab": { "base_uri": "https://localhost:8080/" }, - "cellView": "code", + "id": "FrbDaOMWcFUl", "outputId": "6a42f69c-3003-47f2-a8d2-1b94c52eb57e" }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 20/20 [00:24\u003c00:00, 1.25s/it]\n", + " 0%| | 0/20 [00:00\u003c?, ?it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "loss_human 1.774059 loss_mouse 0.94303024 learning_rate 2.0000002e-06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 20/20 [00:17\u003c00:00, 1.13it/s]\n", + " 0%| | 0/20 
[00:00\u003c?, ?it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "loss_human 1.0067647 loss_mouse 0.8752468 learning_rate 4.0000004e-06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 20/20 [00:17\u003c00:00, 1.13it/s]\n", + " 0%| | 0/20 [00:00\u003c?, ?it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "loss_human 1.0471998 loss_mouse 0.89318746 learning_rate 6e-06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 20/20 [00:17\u003c00:00, 1.14it/s]\n", + " 0%| | 0/20 [00:00\u003c?, ?it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "loss_human 1.010262 loss_mouse 1.02991 learning_rate 8.000001e-06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 20/20 [00:17\u003c00:00, 1.14it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "loss_human 1.111991 loss_mouse 0.84773445 learning_rate 1.0000001e-05\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# Train the model\n", "steps_per_epoch = 20\n", @@ -643,7 +714,7 @@ " for i in tqdm(range(steps_per_epoch)):\n", " global_step += 1\n", "\n", - " if global_step > 1:\n", + " if global_step \u003e 1:\n", " learning_rate_frac = tf.math.minimum(\n", " 1.0, global_step / tf.math.maximum(1.0, num_warmup_steps)) \n", " learning_rate.assign(target_learning_rate * learning_rate_frac)\n", @@ -659,95 +730,6 @@ " 'loss_mouse', loss_mouse.numpy(),\n", " 'learning_rate', optimizer.learning_rate.numpy()\n", " )" - ], - "execution_count": 59, - "outputs": [ - { - "output_type": "stream", - "text": [ - "100%|██████████| 20/20 [00:24<00:00, 1.25s/it]\n", - " 0%| | 0/20 [00:00 max_steps:\n", + " if max_steps is not None and i \u003e max_steps:\n", " break\n", " metric.update_state(batch['target'], predict(batch['sequence']))\n", "\n", " return metric.result()" - ], - "execution_count": 61, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "57fNitK9hzwd", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "57fNitK9hzwd", "outputId": "947aaadb-dad2-4a00-ddac-d765f65d782f" }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "101it [00:23, 6.27it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "{'PearsonR': 0.0028573992}\n" + ] + } + ], "source": [ "metrics_human = evaluate_model(model,\n", " dataset=get_dataset('human', 'valid').batch(1).prefetch(2),\n", @@ -957,35 +957,35 @@ " max_steps=100)\n", "print('')\n", "print({k: v.numpy().mean() for k, v in metrics_human.items()})" - ], - "execution_count": 66, - "outputs": [ - { - "output_type": "stream", - "text": [ - "101it [00:23, 6.27it/s]" - ], - "name": "stderr" - }, - { - "output_type": "stream", - "text": [ - "\n", - "{'PearsonR': 0.0028573992}\n" - ], - "name": "stdout" - } ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "HY_wj95xiDtE", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "HY_wj95xiDtE", "outputId": "fea839f7-b6c9-46ed-aece-c56b02e9ea16" }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "101it [00:21, 6.54it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "{'PearsonR': 0.005183698}\n" + ] + } + ], "source": [ "metrics_mouse = 
evaluate_model(model,\n", " dataset=get_dataset('mouse', 'valid').batch(1).prefetch(2),\n", @@ -993,25 +993,329 @@ " max_steps=100)\n", "print('')\n", "print({k: v.numpy().mean() for k, v in metrics_mouse.items()})" - ], - "execution_count": 63, + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5k1yaJrNCgvw" + }, + "source": [ + "# Restore Checkpoint\n", + "\n", + "Note: For the TF-Hub Enformer model, the required input sequence length is 393,216 which actually gets cropped within the model to 196,608. The open source module does not internally crop the sequence. Therefore, the code below crops the central `196,608 bp` of the longer sequence to reproduce the output of the TF hub from the reloaded checkpoint." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DB2cGdH8EGfn" + }, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "EXTENDED_SEQ_LENGTH = 393_216\n", + "SEQ_LENGTH = 196_608\n", + "inputs = np.array(np.random.random((1, EXTENDED_SEQ_LENGTH, 4)), dtype=np.float32)\n", + "inputs_cropped = enformer.TargetLengthCrop1D(SEQ_LENGTH)(inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mdf35itsCjEY" + }, + "outputs": [], + "source": [ + "checkpoint_gs_path = 'gs://dm-enformer/models/enformer/sonnet_weights/*'\n", + "checkpoint_path = '/tmp/enformer_checkpoint'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 280, + "status": "ok", + "timestamp": 1653476327690, + "user": { + "displayName": "Kyle Taylor", + "userId": "14169907681771397124" + }, + "user_tz": -60 + }, + "id": "G2P4IHqswLul", + "outputId": "180abe21-ba00-4031-d9d7-2326f1f742f2" + }, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ - "101it [00:21, 6.54it/s]" - ], - "name": "stderr" - }, - { - "output_type": "stream", - "text": [ - "\n", - "{'PearsonR': 0.005183698}\n" - ], - "name": "stdout" + "mkdir: cannot create directory ‘/tmp/enformer_checkpoint’: File exists\n" + ] } + ], + "source": [ + "!mkdir /tmp/enformer_checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 19279, + "status": "ok", + "timestamp": 1653476347208, + "user": { + "displayName": "Kyle Taylor", + "userId": "14169907681771397124" + }, + "user_tz": -60 + }, + "id": "LTL8EISGCujC", + "outputId": "2b743b9b-480d-44dc-b08e-82d2bc089a47" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "gs://dm-enformer/models/enformer/sonnet_weights/checkpoint\n", + "gs://dm-enformer/models/enformer/sonnet_weights/enformer-fine-tuned-human-1.data-00000-of-00001\n", + "gs://dm-enformer/models/enformer/sonnet_weights/enformer-fine-tuned-human-1.index\n" + ] + } + ], + "source": [ + "# Copy checkpoints from GCS to temporary directory.\n", + "# This will take a while as the checkpoint is ~ 1GB.\n", + "for file_path in tf.io.gfile.glob(checkpoint_gs_path):\n", + " print(file_path)\n", + " file_name = os.path.basename(file_path)\n", + " tf.io.gfile.copy(file_path, f'{checkpoint_path}/{file_name}', overwrite=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 322, + "status": "ok", + "timestamp": 1653476347527, + "user": { + "displayName": "Kyle Taylor", 
+            "userId": "14169907681771397124"
+          },
+          "user_tz": -60
+        },
+        "id": "9VSeTx0sCvcw",
+        "outputId": "b52d7570-c355-4068-b932-3796b56e5586"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "total 959M\n",
+            "-rw-r--r-- 1 root root  111 May 25 10:58 checkpoint\n",
+            "-rw-r--r-- 1 root root 959M May 25 10:59 enformer-fine-tuned-human-1.data-00000-of-00001\n",
+            "-rw-r--r-- 1 root root 5.7K May 25 10:59 enformer-fine-tuned-human-1.index\n"
+          ]
+        }
+      ],
+      "source": [
+        "!ls -lh /tmp/enformer_checkpoint"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "00Y2GgRED3aI"
+      },
+      "outputs": [],
+      "source": [
+        "enformer_model = enformer.Enformer()"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "mFyIiGyiD5yh"
+      },
+      "outputs": [],
+      "source": [
+        "checkpoint = tf.train.Checkpoint(module=enformer_model)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "executionInfo": {
+          "elapsed": 6,
+          "status": "ok",
+          "timestamp": 1653476347529,
+          "user": {
+            "displayName": "Kyle Taylor",
+            "userId": "14169907681771397124"
+          },
+          "user_tz": -60
+        },
+        "id": "VuyspnpOD9kA",
+        "outputId": "3b495138-be27-4459-af82-c5da2af5bd2d"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "/tmp/enformer_checkpoint/enformer-fine-tuned-human-1\n"
+          ]
+        }
+      ],
+      "source": [
+        "latest = tf.train.latest_checkpoint(checkpoint_path)\n",
+        "print(latest)\n",
+        "status = checkpoint.restore(latest)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "MKkVOTyKEABJ"
+      },
+      "outputs": [],
+      "source": [
+        "# Using `is_training=False` to match TF-hub predict_on_batch function.\n",
+        "restored_predictions = enformer_model(inputs_cropped, is_training=False)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "OX650jqCEQdv"
+      },
+      "outputs": [],
+      "source": [
+        "import tensorflow_hub as hub\n",
+        "enformer_tf_hub_model = hub.load(\"https://tfhub.dev/deepmind/enformer/1\").model"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "yiOTFTSdE5H1"
+      },
+      "outputs": [],
+      "source": [
+        "hub_predictions = enformer_tf_hub_model.predict_on_batch(inputs)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "executionInfo": {
+          "elapsed": 15,
+          "status": "ok",
+          "timestamp": 1653476357260,
+          "user": {
+            "displayName": "Kyle Taylor",
+            "userId": "14169907681771397124"
+          },
+          "user_tz": -60
+        },
+        "id": "uYrWgfaGFbpL",
+        "outputId": "cb9d3ad0-ee14-46ac-b188-a4b6a697a159"
+      },
+      "outputs": [
+        {
+          "data": {
+            "text/plain": [
+              "True"
+            ]
+          },
+          "execution_count": 41,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "np.allclose(hub_predictions['human'], restored_predictions['human'], atol=1e-5)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "4wEOUMeNzK8q"
+      },
+      "outputs": [],
+      "source": [
+        "# Can run with 'is_training=True' but note that this will\n",
+        "# change the predictions as the batch statistics will be updated\n",
+        "# and the outputs will likely not match the TF-hub model.\n",
+        "# enformer_model(inputs_cropped, is_training=True)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "jyVHRPAN5w6J"
+      },
+      "outputs": [],
+      
"source": [ + "" ] } - ] + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/deepmind/research/genomics/dna_to_rna:colab", + "kind": "private" + }, + "name": "enformer-training.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 }