mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
af3aa09cfe
PiperOrigin-RevId: 365241561
1018 lines
38 KiB
Plaintext
1018 lines
38 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"name": "enformer-training.ipynb",
|
|
"provenance": [],
|
|
"collapsed_sections": [],
|
|
"toc_visible": true
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
},
|
|
"accelerator": "GPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "rb_ShvB9E8yM"
|
|
},
|
|
"source": [
|
|
"Copyright 2021 DeepMind Technologies Limited\n",
|
|
"\n",
|
|
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
|
"you may not use this file except in compliance with the License.\n",
|
|
"You may obtain a copy of the License at\n",
|
|
"\n",
|
|
" https://www.apache.org/licenses/LICENSE-2.0\n",
|
|
"\n",
|
|
"Unless required by applicable law or agreed to in writing, software\n",
|
|
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
|
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
|
"See the License for the specific language governing permissions and\n",
|
|
"limitations under the License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "kXQjDxgdwUmW"
|
|
},
|
|
"source": [
|
|
"This colab showcases training of the Enformer model published in\n",
|
|
"\n",
|
|
"**\"Effective gene expression prediction from sequence by integrating long-range interactions\"**\n",
|
|
"\n",
|
|
"Žiga Avsec, Vikram Agarwal, Daniel Visentin, Joseph R. Ledsam, Agnieszka Grabska-Barwinska, Kyle R. Taylor, Yannis Assael, John Jumper, Pushmeet Kohli, David R. Kelley\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "2AVkKjy3bh_A"
|
|
},
|
|
"source": [
|
|
"## Steps\n",
|
|
"\n",
|
|
"- Set up a tf.data.Dataset that reads the Basenji2 data directly from GCS: `gs://basenji_barnyard/data`\n",
|
|
"- Train the model for a few steps, alternating training on human and mouse data batches\n",
|
|
"- Evaluate the model on human and mouse genomes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "sM_PMOT-2Xhi"
|
|
},
|
|
"source": [
|
|
"## Setup"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "NqR7ol3rxrtM"
|
|
},
|
|
"source": [
|
|
"**Start the colab kernel with GPU**: Runtime -> Change runtime type -> GPU"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "vhjR7StI1tZn"
|
|
},
|
|
"source": [
|
|
"### Install dependencies"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "WiDFm-a41tKW",
|
|
"outputId": "8b889c6e-f113-4664-f2c9-91110808ad92"
|
|
},
|
|
"source": [
|
|
"!pip install dm-sonnet tqdm"
|
|
],
|
|
"execution_count": 1,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Collecting dm-sonnet\n",
|
|
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/13/28/9185afffefb655ef1a29f4b84aa9f656826408ca2d1b9ffeba81fbfd40ec/dm_sonnet-2.0.0-py3-none-any.whl (254kB)\n",
|
|
"\r\u001b[K |█▎ | 10kB 13.3MB/s eta 0:00:01\r\u001b[K |██▋ | 20kB 11.7MB/s eta 0:00:01\r\u001b[K |███▉ | 30kB 8.7MB/s eta 0:00:01\r\u001b[K |█████▏ | 40kB 7.7MB/s eta 0:00:01\r\u001b[K |██████▍ | 51kB 4.5MB/s eta 0:00:01\r\u001b[K |███████▊ | 61kB 5.1MB/s eta 0:00:01\r\u001b[K |█████████ | 71kB 5.1MB/s eta 0:00:01\r\u001b[K |██████████▎ | 81kB 5.5MB/s eta 0:00:01\r\u001b[K |███████████▋ | 92kB 5.5MB/s eta 0:00:01\r\u001b[K |████████████▉ | 102kB 5.7MB/s eta 0:00:01\r\u001b[K |██████████████▏ | 112kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████▌ | 122kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████▊ | 133kB 5.7MB/s eta 0:00:01\r\u001b[K |██████████████████ | 143kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████████▎ | 153kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████████▋ | 163kB 5.7MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 174kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████████████▏ | 184kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████████████▌ | 194kB 5.7MB/s eta 0:00:01\r\u001b[K |█████████████████████████▊ | 204kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████████████████ | 215kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████████████████▎ | 225kB 5.7MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▋ | 235kB 5.7MB/s eta 0:00:01\r\u001b[K |███████████████████████████████ | 245kB 5.7MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 256kB 5.7MB/s \n",
|
|
"\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n",
|
|
"Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.15.0)\n",
|
|
"Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.12.0)\n",
|
|
"Requirement already satisfied: numpy>=1.16.3 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.19.5)\n",
|
|
"Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.12.1)\n",
|
|
"Requirement already satisfied: tabulate>=0.7.5 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.8.9)\n",
|
|
"Requirement already satisfied: dm-tree>=0.1.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.1.5)\n",
|
|
"Installing collected packages: dm-sonnet\n",
|
|
"Successfully installed dm-sonnet-2.0.0\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "CokqDsb-fxme"
|
|
},
|
|
"source": [
|
|
"# Get enformer source code\n",
|
|
"!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/attention_module.py\n",
|
|
"!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/enformer.py"
|
|
],
|
|
"execution_count": 15,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "xmffZS_306eb"
|
|
},
|
|
"source": [
|
|
"### Import"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "hTGOLrbZxNHK",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "f58b5c21-0764-4003-c794-aa89e5d336cc"
|
|
},
|
|
"source": [
|
|
"import tensorflow as tf\n",
|
|
"# Make sure the GPU is enabled \n",
|
|
"assert tf.config.list_physical_devices('GPU'), 'Start the colab kernel with GPU: Runtime -> Change runtime type -> GPU'\n",
|
|
"\n",
|
|
"# Easier debugging of OOM\n",
|
|
"%env TF_ENABLE_GPU_GARBAGE_COLLECTION=false"
|
|
],
|
|
"execution_count": 4,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"env: TF_ENABLE_GPU_GARBAGE_COLLECTION=false\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "S9ywsUmT05C1"
|
|
},
|
|
"source": [
|
|
"import sonnet as snt\n",
|
|
"from tqdm import tqdm\n",
|
|
"from IPython.display import clear_output\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"import time\n",
|
|
"import os"
|
|
],
|
|
"execution_count": 5,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "YUIbu0Xu1BnA"
|
|
},
|
|
"source": [
|
|
"assert snt.__version__.startswith('2.0')"
|
|
],
|
|
"execution_count": 6,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 35
|
|
},
|
|
"id": "PWzsyJddILcx",
|
|
"outputId": "3f1cac0f-6bce-430e-b3c0-9848d43e654c"
|
|
},
|
|
"source": [
|
|
"tf.__version__"
|
|
],
|
|
"execution_count": 7,
|
|
"outputs": [
|
|
{
|
|
"output_type": "execute_result",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'2.4.1'"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"execution_count": 7
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "xOhdaXG95eOl",
|
|
"outputId": "1e57ef49-254a-4050-89af-61bc0f8ea577"
|
|
},
|
|
"source": [
|
|
"# GPU colab has T4 with 16 GiB of memory\n",
|
|
"!nvidia-smi"
|
|
],
|
|
"execution_count": 8,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Fri Mar 26 12:28:00 2021 \n",
|
|
"+-----------------------------------------------------------------------------+\n",
|
|
"| NVIDIA-SMI 460.67 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
|
|
"|-------------------------------+----------------------+----------------------+\n",
|
|
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
|
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
|
|
"| | | MIG M. |\n",
|
|
"|===============================+======================+======================|\n",
|
|
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
|
|
"| N/A 42C P8 10W / 70W | 3MiB / 15109MiB | 0% Default |\n",
|
|
"| | | N/A |\n",
|
|
"+-------------------------------+----------------------+----------------------+\n",
|
|
" \n",
|
|
"+-----------------------------------------------------------------------------+\n",
|
|
"| Processes: |\n",
|
|
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
|
"| ID ID Usage |\n",
|
|
"|=============================================================================|\n",
|
|
"| No running processes found |\n",
|
|
"+-----------------------------------------------------------------------------+\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "0Xx--Nco09fN"
|
|
},
|
|
"source": [
|
|
"### Code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "BbXyDdoShFzX"
|
|
},
|
|
"source": [
|
|
"import enformer"
|
|
],
|
|
"execution_count": 37,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "MEb8OZli2Nbu"
|
|
},
|
|
"source": [
|
|
"# @title `get_targets(organism)`\n",
|
def get_targets(organism):
  """Load the Basenji2 target description table for one organism.

  Args:
    organism: Name substituted into the `targets_{organism}.txt` file name
      (the notebook calls this with 'human').

  Returns:
    A pandas DataFrame parsed from the tab-separated targets file.
  """
  targets_url = ('https://raw.githubusercontent.com/calico/basenji/master/'
                 f'manuscripts/cross2020/targets_{organism}.txt')
  targets = pd.read_csv(targets_url, sep='\t')
  return targets
|
|
],
|
|
"execution_count": 41,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "2BuZ2gmUbpXZ"
|
|
},
|
|
"source": [
|
|
"# @title `get_dataset(organism, subset, num_threads=8)`\n",
|
|
"import glob\n",
|
|
"import json\n",
|
|
"import functools\n",
|
|
"\n",
|
|
"\n",
|
def organism_path(organism):
  """Return the GCS directory that holds the Basenji2 data for `organism`."""
  data_root = 'gs://basenji_barnyard/data'
  return os.path.join(data_root, organism)
|
|
"\n",
|
|
"\n",
|
def get_dataset(organism, subset, num_threads=8):
  """Build a tf.data.Dataset of deserialized examples for one organism.

  Args:
    organism: Organism name; selects the data directory and its metadata.
    subset: Prefix of the TFRecord shard names to read (the notebook uses
      'train' and 'valid').
    num_threads: Parallelism used both for reading the shards and for the
      deserialization map.

  Returns:
    An unbatched dataset whose elements are the dicts produced by
    `deserialize`.
  """
  parse_fn = functools.partial(deserialize, metadata=get_metadata(organism))
  shard_paths = tfrecord_files(organism, subset)

  records = tf.data.TFRecordDataset(shard_paths,
                                    compression_type='ZLIB',
                                    num_parallel_reads=num_threads)
  return records.map(parse_fn, num_parallel_calls=num_threads)
|
|
"\n",
|
|
"\n",
|
def get_metadata(organism):
  """Read the `statistics.json` metadata file stored next to the data.

  Per the original notebook comment the file carries the keys:
    num_targets, train_seqs, valid_seqs, test_seqs, seq_length,
    pool_width, crop_bp, target_length

  Args:
    organism: Organism name; selects the data directory.

  Returns:
    The parsed JSON content.
  """
  stats_path = os.path.join(organism_path(organism), 'statistics.json')
  with tf.io.gfile.GFile(stats_path, 'r') as stats_file:
    metadata = json.load(stats_file)
  return metadata
|
|
"\n",
|
|
"\n",
|
def tfrecord_files(organism, subset):
  """List the `{subset}-*.tfr` shards of `organism`, ordered by shard number.

  The sort key is the integer found between the last '-' and the following
  '.' of each path, so shards are ordered numerically rather than
  lexicographically (shard 10 comes after shard 2).
  """
  def shard_number(path):
    return int(path.split('-')[-1].split('.')[0])

  pattern = os.path.join(
      organism_path(organism), 'tfrecords', f'{subset}-*.tfr')
  return sorted(tf.io.gfile.glob(pattern), key=shard_number)
|
|
"\n",
|
|
"\n",
|
def deserialize(serialized_example, metadata):
  """Deserialize bytes stored in TFRecordFile.

  Args:
    serialized_example: Serialized example holding the byte-string features
      'sequence' and 'target'.
    metadata: Dict providing 'seq_length', 'target_length' and 'num_targets',
      used to reshape the flat decoded buffers.

  Returns:
    Dict with a float32 'sequence' of shape (seq_length, 4) and a float32
    'target' of shape (target_length, num_targets).
  """
  parsed = tf.io.parse_example(serialized_example, {
      'sequence': tf.io.FixedLenFeature([], tf.string),
      'target': tf.io.FixedLenFeature([], tf.string),
  })

  def decode(key, stored_dtype, shape):
    # Raw bytes -> stored dtype -> 2D layout -> float32 for the model.
    flat = tf.io.decode_raw(parsed[key], stored_dtype)
    return tf.cast(tf.reshape(flat, shape), tf.float32)

  sequence = decode('sequence', tf.bool, (metadata['seq_length'], 4))
  target = decode('target', tf.float16,
                  (metadata['target_length'], metadata['num_targets']))
  return {'sequence': sequence,
          'target': target}
|
|
],
|
|
"execution_count": 42,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "VzGRXfwV4tYH"
|
|
},
|
|
"source": [
|
|
"## Load dataset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 203
|
|
},
|
|
"id": "M_vr1mbl3jbD",
|
|
"outputId": "2de351ed-f43e-4469-a681-2a437d97c946"
|
|
},
|
|
"source": [
|
|
"df_targets_human = get_targets('human')\n",
|
|
"df_targets_human.head()"
|
|
],
|
|
"execution_count": 43,
|
|
"outputs": [
|
|
{
|
|
"output_type": "execute_result",
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>index</th>\n",
|
|
" <th>genome</th>\n",
|
|
" <th>identifier</th>\n",
|
|
" <th>file</th>\n",
|
|
" <th>clip</th>\n",
|
|
" <th>scale</th>\n",
|
|
" <th>sum_stat</th>\n",
|
|
" <th>description</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>ENCFF833POA</td>\n",
|
|
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
|
|
" <td>32</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>mean</td>\n",
|
|
" <td>DNASE:cerebellum male adult (27 years) and mal...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>ENCFF110QGM</td>\n",
|
|
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
|
|
" <td>32</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>mean</td>\n",
|
|
" <td>DNASE:frontal cortex male adult (27 years) and...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>2</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>ENCFF880MKD</td>\n",
|
|
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
|
|
" <td>32</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>mean</td>\n",
|
|
" <td>DNASE:chorion</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>3</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>ENCFF463ZLQ</td>\n",
|
|
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
|
|
" <td>32</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>mean</td>\n",
|
|
" <td>DNASE:Ishikawa treated with 0.02% dimethyl sul...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>4</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>ENCFF890OGQ</td>\n",
|
|
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
|
|
" <td>32</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>mean</td>\n",
|
|
" <td>DNASE:GM03348</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" index genome ... sum_stat description\n",
|
|
"0 0 0 ... mean DNASE:cerebellum male adult (27 years) and mal...\n",
|
|
"1 1 0 ... mean DNASE:frontal cortex male adult (27 years) and...\n",
|
|
"2 2 0 ... mean DNASE:chorion\n",
|
|
"3 3 0 ... mean DNASE:Ishikawa treated with 0.02% dimethyl sul...\n",
|
|
"4 4 0 ... mean DNASE:GM03348\n",
|
|
"\n",
|
|
"[5 rows x 8 columns]"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"execution_count": 43
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "YDSKttXI4hMT"
|
|
},
|
|
"source": [
|
|
"human_dataset = get_dataset('human', 'train').batch(1).repeat()\n",
|
|
"mouse_dataset = get_dataset('mouse', 'train').batch(1).repeat()\n",
|
|
"human_mouse_dataset = tf.data.Dataset.zip((human_dataset, mouse_dataset)).prefetch(2)"
|
|
],
|
|
"execution_count": 44,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "2vx3116C7oFW"
|
|
},
|
|
"source": [
|
|
"it = iter(mouse_dataset)\n",
|
|
"example = next(it)"
|
|
],
|
|
"execution_count": 45,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "XeztqJZ74ixT",
|
|
"outputId": "39dc4051-5a19-4443-b6b0-bf6869faf5ec"
|
|
},
|
|
"source": [
|
|
"# Example input\n",
|
|
"it = iter(human_mouse_dataset)\n",
|
|
"example = next(it)\n",
|
|
"for i in range(len(example)):\n",
|
|
" print(['human', 'mouse'][i])\n",
|
|
" print({k: (v.shape, v.dtype) for k,v in example[i].items()})"
|
|
],
|
|
"execution_count": 46,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"human\n",
|
|
"{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 5313]), tf.float32)}\n",
|
|
"mouse\n",
|
|
"{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 1643]), tf.float32)}\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "SHHNHzFVbvTk"
|
|
},
|
|
"source": [
|
|
"## Model training"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "0U3hLJaUdZkG"
|
|
},
|
|
"source": [
|
def create_step_function(model, optimizer):
  """Create a compiled training step bound to `model` and `optimizer`.

  Args:
    model: Callable as `model(sequence, is_training=True)` returning a dict of
      per-head predictions, and exposing `trainable_variables`.
    optimizer: Object exposing `apply(gradients, variables)`.

  Returns:
    A `tf.function` with signature
    `train_step(batch, head, optimizer_clip_norm_global=0.2)`.
  """

  @tf.function
  def train_step(batch, head, optimizer_clip_norm_global=0.2):
    """Run one optimization step on `batch` for the output head `head`.

    Args:
      batch: Dict with 'sequence' (model input) and 'target' (labels).
      head: Key selecting which entry of the model output dict to train.
      optimizer_clip_norm_global: Maximum global norm the gradients are
        clipped to before being applied. Pass None to disable clipping.

    Returns:
      The mean Poisson loss of the batch.
    """
    with tf.GradientTape() as tape:
      outputs = model(batch['sequence'], is_training=True)[head]
      loss = tf.reduce_mean(
          tf.keras.losses.poisson(batch['target'], outputs))

    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    if optimizer_clip_norm_global is not None:
      # This argument used to be accepted but silently ignored, so no clipping
      # ever happened. `None` gradients (variables that do not contribute to
      # the selected head) are passed through untouched by
      # tf.clip_by_global_norm.
      gradients, _ = tf.clip_by_global_norm(gradients,
                                            optimizer_clip_norm_global)
    optimizer.apply(gradients, variables)

    return loss
  return train_step
|
|
],
|
|
"execution_count": 51,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "ZXv5HU_242Ut"
|
|
},
|
|
"source": [
|
# The learning rate lives in a non-trainable variable so the training loop can
# update it in place (warm-up) without rebuilding the optimizer.
learning_rate = tf.Variable(0., trainable=False, name='learning_rate')
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

# Warm-up schedule parameters read by the training loop.
num_warmup_steps = 5000
target_learning_rate = 0.0005

model = enformer.Enformer(
    channels=1536 // 4,  # Use 4x fewer channels to train faster.
    num_heads=8,
    num_transformer_layers=11,
    pooling_type='max')

train_step = create_step_function(model, optimizer)
|
|
],
|
|
"execution_count": 52,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "FrbDaOMWcFUl",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"cellView": "code",
|
|
"outputId": "6a42f69c-3003-47f2-a8d2-1b94c52eb57e"
|
|
},
|
|
"source": [
|
# Train the model
steps_per_epoch = 20
num_epochs = 5

data_it = iter(human_mouse_dataset)
global_step = 0
for epoch_i in range(num_epochs):
  for step_i in tqdm(range(steps_per_epoch)):
    global_step += 1

    # Linear warm-up towards target_learning_rate. The very first step keeps
    # the initial value of the learning-rate variable.
    if global_step > 1:
      warmup_steps = tf.math.maximum(1.0, num_warmup_steps)
      learning_rate_frac = tf.math.minimum(1.0, global_step / warmup_steps)
      learning_rate.assign(target_learning_rate * learning_rate_frac)

    # Each element pairs one human batch with one mouse batch; train on both,
    # human first.
    batch_human, batch_mouse = next(data_it)
    loss_human = train_step(batch=batch_human, head='human')
    loss_mouse = train_step(batch=batch_mouse, head='mouse')

  # End of epoch: report the losses of the last step only.
  print('')
  print('loss_human', loss_human.numpy(),
        'loss_mouse', loss_mouse.numpy(),
        'learning_rate', optimizer.learning_rate.numpy())
|
|
],
|
|
"execution_count": 59,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 20/20 [00:24<00:00, 1.25s/it]\n",
|
|
" 0%| | 0/20 [00:00<?, ?it/s]"
|
|
],
|
|
"name": "stderr"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"loss_human 1.774059 loss_mouse 0.94303024 learning_rate 2.0000002e-06\n"
|
|
],
|
|
"name": "stdout"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 20/20 [00:17<00:00, 1.13it/s]\n",
|
|
" 0%| | 0/20 [00:00<?, ?it/s]"
|
|
],
|
|
"name": "stderr"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"loss_human 1.0067647 loss_mouse 0.8752468 learning_rate 4.0000004e-06\n"
|
|
],
|
|
"name": "stdout"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 20/20 [00:17<00:00, 1.13it/s]\n",
|
|
" 0%| | 0/20 [00:00<?, ?it/s]"
|
|
],
|
|
"name": "stderr"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"loss_human 1.0471998 loss_mouse 0.89318746 learning_rate 6e-06\n"
|
|
],
|
|
"name": "stdout"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 20/20 [00:17<00:00, 1.14it/s]\n",
|
|
" 0%| | 0/20 [00:00<?, ?it/s]"
|
|
],
|
|
"name": "stderr"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"loss_human 1.010262 loss_mouse 1.02991 learning_rate 8.000001e-06\n"
|
|
],
|
|
"name": "stdout"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 20/20 [00:17<00:00, 1.14it/s]"
|
|
],
|
|
"name": "stderr"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"loss_human 1.111991 loss_mouse 0.84773445 learning_rate 1.0000001e-05\n"
|
|
],
|
|
"name": "stdout"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
],
|
|
"name": "stderr"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Cs0f0z0RcCfz"
|
|
},
|
|
"source": [
|
|
"## Evaluate"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "8c4lNQrHkXSC",
|
|
"cellView": "form"
|
|
},
|
|
"source": [
|
|
"# @title `PearsonR` and `R2` metrics\n",
|
|
"\n",
|
def _reduced_shape(shape, axis):
  """Return `shape` with the dimensions listed in `axis` removed.

  Args:
    shape: Iterable of dimension sizes.
    axis: Collection of dimension indices to drop, or None to drop all of
      them. Indices are matched against positions 0..rank-1, so negative
      indices never match.

  Returns:
    A tf.TensorShape with the remaining dimensions (scalar shape when `axis`
    is None).
  """
  if axis is None:
    kept_dims = []
  else:
    kept_dims = [dim for index, dim in enumerate(shape) if index not in axis]
  return tf.TensorShape(kept_dims)
|
|
"\n",
|
|
"\n",
|
class CorrelationStats(tf.keras.metrics.Metric):
  """Contains shared code for PearsonR and R2.

  Keeps running sums over all batches seen so far: element count, sum of
  products, and the sums and squared sums of the true and predicted values.
  Subclasses turn these into a statistic in `result()`.
  """

  def __init__(self, reduce_axis=None, name='pearsonr'):
    """Creates the metric; accumulators are allocated on first update.

    Args:
      reduce_axis: Specifies over which axes to compute the correlation, e.g.
        `(0, 1)`. If not specified, it will compute the correlation across
        the whole tensor.
      name: Metric name.
    """
    super(CorrelationStats, self).__init__(name=name)
    self._reduce_axis = reduce_axis
    self._shape = None  # Specified in _initialize.

  def _initialize(self, input_shape):
    """Allocates zero-initialized accumulators for inputs of `input_shape`."""
    # Remaining dimensions after reducing over self._reduce_axis.
    self._shape = _reduced_shape(input_shape, self._reduce_axis)

    def new_accumulator(weight_name):
      return self.add_weight(
          name=weight_name, shape=self._shape, initializer='zeros')

    self._count = new_accumulator('count')
    self._product_sum = new_accumulator('product_sum')
    self._true_sum = new_accumulator('true_sum')
    self._true_squared_sum = new_accumulator('true_squared_sum')
    self._pred_sum = new_accumulator('pred_sum')
    self._pred_squared_sum = new_accumulator('pred_squared_sum')

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Update the metric state.

    Args:
      y_true: Multi-dimensional float tensor [batch, ...] containing the ground
        truth values.
      y_pred: float tensor with the same shape as y_true containing predicted
        values.
      sample_weight: Accepted for compatibility with the Keras metric
        signature; this implementation does not use it.
    """
    if self._shape is None:
      # Explicit initialization check.
      self._initialize(y_true.shape)
    y_true.shape.assert_is_compatible_with(y_pred.shape)
    y_true = tf.cast(y_true, 'float32')
    y_pred = tf.cast(y_pred, 'float32')

    def reduced_sum(values):
      return tf.reduce_sum(values, axis=self._reduce_axis)

    self._product_sum.assign_add(reduced_sum(y_true * y_pred))
    self._true_sum.assign_add(reduced_sum(y_true))
    self._true_squared_sum.assign_add(reduced_sum(tf.math.square(y_true)))
    self._pred_sum.assign_add(reduced_sum(y_pred))
    self._pred_squared_sum.assign_add(reduced_sum(tf.math.square(y_pred)))
    # Number of elements folded into each accumulator entry.
    self._count.assign_add(reduced_sum(tf.ones_like(y_true)))

  def result(self):
    raise NotImplementedError('Must be implemented in subclasses.')

  def reset_states(self):
    """Zeroes every accumulator; a no-op before the first update."""
    if self._shape is None:
      return
    zeroed = [(variable, np.zeros(self._shape)) for variable in self.variables]
    tf.keras.backend.batch_set_value(zeroed)
|
|
"\n",
|
|
"\n",
|
class PearsonR(CorrelationStats):
  """Pearson correlation coefficient.

  Computed as:
    sum((x - x_avg) * (y - y_avg))
      / (sqrt(sum((x - x_avg)^2)) * sqrt(sum((y - y_avg)^2)))
  using the running sums accumulated by `CorrelationStats`.
  """

  def __init__(self, reduce_axis=(0,), name='pearsonr'):
    """Pearson correlation coefficient.

    Args:
      reduce_axis: Specifies over which axis to compute the correlation.
      name: Metric name.
    """
    super(PearsonR, self).__init__(reduce_axis=reduce_axis, name=name)

  def result(self):
    """Returns the correlation for every entry left after the reduction."""
    count = self._count
    true_mean = self._true_sum / count
    pred_mean = self._pred_sum / count

    # Expanded form of sum((x - x_avg) * (y - y_avg)).
    covariance = (self._product_sum
                  - true_mean * self._pred_sum
                  - pred_mean * self._true_sum
                  + count * true_mean * pred_mean)

    # Sums of squared deviations (not divided by count; the factor cancels).
    true_var = self._true_squared_sum - count * tf.math.square(true_mean)
    pred_var = self._pred_squared_sum - count * tf.math.square(pred_mean)
    denominator = tf.math.sqrt(true_var) * tf.math.sqrt(pred_var)

    return covariance / denominator
|
|
"\n",
|
|
"\n",
|
class R2(CorrelationStats):
  """R-squared (fraction of explained variance)."""

  def __init__(self, reduce_axis=None, name='R2'):
    """R-squared metric.

    Args:
      reduce_axis: Specifies over which axis to compute the statistic.
      name: Metric name.
    """
    super(R2, self).__init__(reduce_axis=reduce_axis, name=name)

  def result(self):
    """Returns 1 - (residual sum of squares) / (total sum of squares)."""
    true_mean = self._true_sum / self._count
    # Total sum of squares of the ground truth around its mean.
    total = self._true_squared_sum - self._count * tf.math.square(true_mean)
    # Expanded form of sum((y_pred - y_true)^2).
    residuals = (self._pred_squared_sum - 2 * self._product_sum
                 + self._true_squared_sum)

    explained = tf.ones_like(residuals) - residuals / total
    return explained
|
|
"\n",
|
|
"\n",
|
class MetricDict:
  """Updates and reads a dict of named metrics as one unit."""

  def __init__(self, metrics):
    """Stores `metrics`, a dict mapping a name to a metric object."""
    self._metrics = metrics

  def update_state(self, y_true, y_pred):
    """Forwards one (y_true, y_pred) pair to every metric."""
    for metric in self._metrics.values():
      metric.update_state(y_true, y_pred)

  def result(self):
    """Returns a dict mapping each metric name to its current result."""
    results = {}
    for name, metric in self._metrics.items():
      results[name] = metric.result()
    return results
|
|
],
|
|
"execution_count": 60,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "x80gX9LrhBU-"
|
|
},
|
|
"source": [
|
def evaluate_model(model, dataset, head, max_steps=None):
  """Evaluate `model` on `dataset` and return the accumulated metrics.

  Args:
    model: Callable as `model(sequence, is_training=False)` returning a dict
      of per-head predictions.
    dataset: Iterable of batches, each a dict with 'sequence' and 'target'.
    head: Key selecting which entry of the model output dict to evaluate.
    max_steps: Maximum number of batches to evaluate. None evaluates the
      whole dataset.

  Returns:
    Dict with the key 'PearsonR' holding the correlation computed over
    axes (0, 1) of the targets.
  """
  metric = MetricDict({'PearsonR': PearsonR(reduce_axis=(0, 1))})

  @tf.function
  def predict(x):
    return model(x, is_training=False)[head]

  for i, batch in tqdm(enumerate(dataset)):
    # `i` equals the number of batches already evaluated, so stop as soon as
    # it reaches max_steps. The earlier `i > max_steps` check evaluated one
    # batch too many (101 batches for max_steps=100).
    if max_steps is not None and i >= max_steps:
      break
    metric.update_state(batch['target'], predict(batch['sequence']))

  return metric.result()
|
|
],
|
|
"execution_count": 61,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "57fNitK9hzwd",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "947aaadb-dad2-4a00-ddac-d765f65d782f"
|
|
},
|
|
"source": [
|
|
"metrics_human = evaluate_model(model,\n",
|
|
" dataset=get_dataset('human', 'valid').batch(1).prefetch(2),\n",
|
|
" head='human',\n",
|
|
" max_steps=100)\n",
|
|
"print('')\n",
|
|
"print({k: v.numpy().mean() for k, v in metrics_human.items()})"
|
|
],
|
|
"execution_count": 66,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"101it [00:23, 6.27it/s]"
|
|
],
|
|
"name": "stderr"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"{'PearsonR': 0.0028573992}\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "HY_wj95xiDtE",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "fea839f7-b6c9-46ed-aece-c56b02e9ea16"
|
|
},
|
|
"source": [
|
|
"metrics_mouse = evaluate_model(model,\n",
|
|
" dataset=get_dataset('mouse', 'valid').batch(1).prefetch(2),\n",
|
|
" head='mouse',\n",
|
|
" max_steps=100)\n",
|
|
"print('')\n",
|
|
"print({k: v.numpy().mean() for k, v in metrics_mouse.items()})"
|
|
],
|
|
"execution_count": 63,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"101it [00:21, 6.54it/s]"
|
|
],
|
|
"name": "stderr"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"{'PearsonR': 0.005183698}\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
}
|
|
]
|
|
}
|