deepmind-research/enformer/enformer-training.ipynb
Ziga Avsec af3aa09cfe Update README links and add enformer-training.ipynb.
PiperOrigin-RevId: 365241561
2021-03-26 15:48:49 +00:00


{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "enformer-training.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "rb_ShvB9E8yM"
},
"source": [
"Copyright 2021 DeepMind Technologies Limited\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kXQjDxgdwUmW"
},
"source": [
"This colab showcases training of the Enformer model published in\n",
"\n",
"**\"Effective gene expression prediction from sequence by integrating long-range interactions\"**\n",
"\n",
"Žiga Avsec, Vikram Agarwal, Daniel Visentin, Joseph R. Ledsam, Agnieszka Grabska-Barwinska, Kyle R. Taylor, Yannis Assael, John Jumper, Pushmeet Kohli, David R. Kelley\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2AVkKjy3bh_A"
},
"source": [
"## Steps\n",
"\n",
"- Set up a `tf.data.Dataset` that reads the Basenji2 data directly from GCS: `gs://basenji_barnyard/data`\n",
"- Train the model for a few steps, alternating training on human and mouse data batches\n",
"- Evaluate the model on human and mouse genomes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sM_PMOT-2Xhi"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NqR7ol3rxrtM"
},
"source": [
"**Start the colab kernel with GPU**: Runtime -> Change runtime type -> GPU"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vhjR7StI1tZn"
},
"source": [
"### Install dependencies"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WiDFm-a41tKW",
"outputId": "8b889c6e-f113-4664-f2c9-91110808ad92"
},
"source": [
"!pip install dm-sonnet tqdm"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting dm-sonnet\n",
"  Downloading https://files.pythonhosted.org/packages/13/28/9185afffefb655ef1a29f4b84aa9f656826408ca2d1b9ffeba81fbfd40ec/dm_sonnet-2.0.0-py3-none-any.whl (254kB)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n",
"Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.15.0)\n",
"Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.12.0)\n",
"Requirement already satisfied: numpy>=1.16.3 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.19.5)\n",
"Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (1.12.1)\n",
"Requirement already satisfied: tabulate>=0.7.5 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.8.9)\n",
"Requirement already satisfied: dm-tree>=0.1.1 in /usr/local/lib/python3.7/dist-packages (from dm-sonnet) (0.1.5)\n",
"Installing collected packages: dm-sonnet\n",
"Successfully installed dm-sonnet-2.0.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "CokqDsb-fxme"
},
"source": [
"# Get enformer source code\n",
"!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/attention_module.py\n",
"!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/enformer.py"
],
"execution_count": 15,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "xmffZS_306eb"
},
"source": [
"### Import"
]
},
{
"cell_type": "code",
"metadata": {
"id": "hTGOLrbZxNHK",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f58b5c21-0764-4003-c794-aa89e5d336cc"
},
"source": [
"import tensorflow as tf\n",
"# Make sure the GPU is enabled \n",
"assert tf.config.list_physical_devices('GPU'), 'Start the colab kernel with GPU: Runtime -> Change runtime type -> GPU'\n",
"\n",
"# Easier debugging of OOM\n",
"%env TF_ENABLE_GPU_GARBAGE_COLLECTION=false"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"env: TF_ENABLE_GPU_GARBAGE_COLLECTION=false\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "S9ywsUmT05C1"
},
"source": [
"import sonnet as snt\n",
"from tqdm import tqdm\n",
"from IPython.display import clear_output\n",
"import numpy as np\n",
"import pandas as pd\n",
"import time\n",
"import os"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "YUIbu0Xu1BnA"
},
"source": [
"assert snt.__version__.startswith('2.0')"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "PWzsyJddILcx",
"outputId": "3f1cac0f-6bce-430e-b3c0-9848d43e654c"
},
"source": [
"tf.__version__"
],
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'2.4.1'"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xOhdaXG95eOl",
"outputId": "1e57ef49-254a-4050-89af-61bc0f8ea577"
},
"source": [
"# GPU colab has T4 with 16 GiB of memory\n",
"!nvidia-smi"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"Fri Mar 26 12:28:00 2021 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.67       Driver Version: 460.32.03    CUDA Version: 11.2     |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
"| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
"|                               |                      |               MIG M. |\n",
"|===============================+======================+======================|\n",
"|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
"| N/A   42C    P8    10W /  70W |      3MiB / 15109MiB |      0%      Default |\n",
"|                               |                      |                  N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes:                                                                  |\n",
"|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
"|        ID   ID                                                   Usage      |\n",
"|=============================================================================|\n",
"|  No running processes found                                                 |\n",
"+-----------------------------------------------------------------------------+\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Xx--Nco09fN"
},
"source": [
"### Code"
]
},
{
"cell_type": "code",
"metadata": {
"id": "BbXyDdoShFzX"
},
"source": [
"import enformer"
],
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "MEb8OZli2Nbu"
},
"source": [
"# @title `get_targets(organism)`\n",
"def get_targets(organism):\n",
"  targets_txt = f'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_{organism}.txt'\n",
"  return pd.read_csv(targets_txt, sep='\\t')"
],
"execution_count": 41,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2BuZ2gmUbpXZ"
},
"source": [
"# @title `get_dataset(organism, subset, num_threads=8)`\n",
"import glob\n",
"import json\n",
"import functools\n",
"\n",
"\n",
"def organism_path(organism):\n",
"  return os.path.join('gs://basenji_barnyard/data', organism)\n",
"\n",
"\n",
"def get_dataset(organism, subset, num_threads=8):\n",
"  metadata = get_metadata(organism)\n",
"  dataset = tf.data.TFRecordDataset(tfrecord_files(organism, subset),\n",
"                                    compression_type='ZLIB',\n",
"                                    num_parallel_reads=num_threads)\n",
"  dataset = dataset.map(functools.partial(deserialize, metadata=metadata),\n",
"                        num_parallel_calls=num_threads)\n",
"  return dataset\n",
"\n",
"\n",
"def get_metadata(organism):\n",
"  # Keys:\n",
"  # num_targets, train_seqs, valid_seqs, test_seqs, seq_length,\n",
"  # pool_width, crop_bp, target_length\n",
"  path = os.path.join(organism_path(organism), 'statistics.json')\n",
"  with tf.io.gfile.GFile(path, 'r') as f:\n",
"    return json.load(f)\n",
"\n",
"\n",
"def tfrecord_files(organism, subset):\n",
"  # Sort the shard files numerically by the index in their file name.\n",
"  return sorted(tf.io.gfile.glob(os.path.join(\n",
"      organism_path(organism), 'tfrecords', f'{subset}-*.tfr'\n",
"  )), key=lambda x: int(x.split('-')[-1].split('.')[0]))\n",
"\n",
"\n",
"def deserialize(serialized_example, metadata):\n",
"  \"\"\"Deserialize bytes stored in TFRecordFile.\"\"\"\n",
"  feature_map = {\n",
"      'sequence': tf.io.FixedLenFeature([], tf.string),\n",
"      'target': tf.io.FixedLenFeature([], tf.string),\n",
"  }\n",
"  example = tf.io.parse_example(serialized_example, feature_map)\n",
"  sequence = tf.io.decode_raw(example['sequence'], tf.bool)\n",
"  sequence = tf.reshape(sequence, (metadata['seq_length'], 4))\n",
"  sequence = tf.cast(sequence, tf.float32)\n",
"\n",
"  target = tf.io.decode_raw(example['target'], tf.float16)\n",
"  target = tf.reshape(target,\n",
"                      (metadata['target_length'], metadata['num_targets']))\n",
"  target = tf.cast(target, tf.float32)\n",
"\n",
"  return {'sequence': sequence,\n",
"          'target': target}\n"
],
"execution_count": 42,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "VzGRXfwV4tYH"
},
"source": [
"## Load dataset"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 203
},
"id": "M_vr1mbl3jbD",
"outputId": "2de351ed-f43e-4469-a681-2a437d97c946"
},
"source": [
"df_targets_human = get_targets('human')\n",
"df_targets_human.head()"
],
"execution_count": 43,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>index</th>\n",
" <th>genome</th>\n",
" <th>identifier</th>\n",
" <th>file</th>\n",
" <th>clip</th>\n",
" <th>scale</th>\n",
" <th>sum_stat</th>\n",
" <th>description</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>ENCFF833POA</td>\n",
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
" <td>32</td>\n",
" <td>2</td>\n",
" <td>mean</td>\n",
" <td>DNASE:cerebellum male adult (27 years) and mal...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>ENCFF110QGM</td>\n",
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
" <td>32</td>\n",
" <td>2</td>\n",
" <td>mean</td>\n",
" <td>DNASE:frontal cortex male adult (27 years) and...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>ENCFF880MKD</td>\n",
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
" <td>32</td>\n",
" <td>2</td>\n",
" <td>mean</td>\n",
" <td>DNASE:chorion</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>ENCFF463ZLQ</td>\n",
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
" <td>32</td>\n",
" <td>2</td>\n",
" <td>mean</td>\n",
" <td>DNASE:Ishikawa treated with 0.02% dimethyl sul...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>ENCFF890OGQ</td>\n",
" <td>/home/drk/tillage/datasets/human/dnase/encode/...</td>\n",
" <td>32</td>\n",
" <td>2</td>\n",
" <td>mean</td>\n",
" <td>DNASE:GM03348</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" index genome ... sum_stat description\n",
"0 0 0 ... mean DNASE:cerebellum male adult (27 years) and mal...\n",
"1 1 0 ... mean DNASE:frontal cortex male adult (27 years) and...\n",
"2 2 0 ... mean DNASE:chorion\n",
"3 3 0 ... mean DNASE:Ishikawa treated with 0.02% dimethyl sul...\n",
"4 4 0 ... mean DNASE:GM03348\n",
"\n",
"[5 rows x 8 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 43
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "YDSKttXI4hMT"
},
"source": [
"human_dataset = get_dataset('human', 'train').batch(1).repeat()\n",
"mouse_dataset = get_dataset('mouse', 'train').batch(1).repeat()\n",
"human_mouse_dataset = tf.data.Dataset.zip((human_dataset, mouse_dataset)).prefetch(2)"
],
"execution_count": 44,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2vx3116C7oFW"
},
"source": [
"it = iter(mouse_dataset)\n",
"example = next(it)"
],
"execution_count": 45,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XeztqJZ74ixT",
"outputId": "39dc4051-5a19-4443-b6b0-bf6869faf5ec"
},
"source": [
"# Example input\n",
"it = iter(human_mouse_dataset)\n",
"example = next(it)\n",
"for i in range(len(example)):\n",
"  print(['human', 'mouse'][i])\n",
"  print({k: (v.shape, v.dtype) for k, v in example[i].items()})"
],
"execution_count": 46,
"outputs": [
{
"output_type": "stream",
"text": [
"human\n",
"{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 5313]), tf.float32)}\n",
"mouse\n",
"{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 1643]), tf.float32)}\n"
],
"name": "stdout"
}
]
},
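{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to sanity-check a batch is to turn the one-hot `sequence` tensor back into letters. The helper below is only a sketch: `one_hot_to_dna` is a hypothetical name, and it assumes that the four channels are ordered `A, C, G, T` and that all-zero rows stand for `N`. Neither assumption is stated in this notebook, so verify them against the data before relying on the output."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"\n",
"def one_hot_to_dna(one_hot, alphabet='ACGT'):\n",
"  # Assumption: channel order is A, C, G, T; all-zero rows decode to 'N'.\n",
"  one_hot = np.asarray(one_hot)\n",
"  letters = np.array(list(alphabet))[one_hot.argmax(axis=-1)]\n",
"  letters[one_hot.sum(axis=-1) == 0] = 'N'\n",
"  return ''.join(letters)\n",
"\n",
"\n",
"# Toy input: rows 0-3 are the four bases, the last row is all zeros.\n",
"toy = np.concatenate([np.eye(4), np.zeros((1, 4))])\n",
"print(one_hot_to_dna(toy))\n",
"\n",
"# With the batch loaded above, the first 50 bases of the human example:\n",
"# one_hot_to_dna(example[0]['sequence'][0, :50].numpy())"
],
"execution_count": null,
"outputs": []
},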
{
"cell_type": "markdown",
"metadata": {
"id": "SHHNHzFVbvTk"
},
"source": [
"## Model training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0U3hLJaUdZkG"
},
"source": [
"def create_step_function(model, optimizer):\n",
"\n",
"  @tf.function\n",
"  def train_step(batch, head, optimizer_clip_norm_global=0.2):\n",
"    with tf.GradientTape() as tape:\n",
"      outputs = model(batch['sequence'], is_training=True)[head]\n",
"      loss = tf.reduce_mean(\n",
"          tf.keras.losses.poisson(batch['target'], outputs))\n",
"\n",
"    gradients = tape.gradient(loss, model.trainable_variables)\n",
"    optimizer.apply(gradients, model.trainable_variables)\n",
"\n",
"    return loss\n",
"  return train_step"
],
"execution_count": 51,
"outputs": []
},
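{
"cell_type": "markdown",
"metadata": {},
"source": [
"`train_step` above accepts `optimizer_clip_norm_global` but never uses it. If gradient clipping is wanted, one possible variant is sketched below. `create_clipped_step_function` and `clip_norm` are hypothetical names, and clipping by global norm with `tf.clip_by_global_norm` is an assumption about the intent rather than something this notebook confirms."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"\n",
"def create_clipped_step_function(model, optimizer, clip_norm=0.2):\n",
"  # Sketch: same training step as above, plus global-norm gradient clipping.\n",
"\n",
"  @tf.function\n",
"  def train_step(batch, head):\n",
"    with tf.GradientTape() as tape:\n",
"      outputs = model(batch['sequence'], is_training=True)[head]\n",
"      loss = tf.reduce_mean(\n",
"          tf.keras.losses.poisson(batch['target'], outputs))\n",
"\n",
"    gradients = tape.gradient(loss, model.trainable_variables)\n",
"    # Variables of the head that was not used get `None` gradients;\n",
"    # tf.clip_by_global_norm ignores `None` entries and passes them through.\n",
"    gradients, _ = tf.clip_by_global_norm(gradients, clip_norm)\n",
"    optimizer.apply(gradients, model.trainable_variables)\n",
"\n",
"    return loss\n",
"  return train_step"
],
"execution_count": null,
"outputs": []
},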
{
"cell_type": "code",
"metadata": {
"id": "ZXv5HU_242Ut"
},
"source": [
"learning_rate = tf.Variable(0., trainable=False, name='learning_rate')\n",
"optimizer = snt.optimizers.Adam(learning_rate=learning_rate)\n",
"num_warmup_steps = 5000\n",
"target_learning_rate = 0.0005\n",
"\n",
"model = enformer.Enformer(channels=1536 // 4,  # Use 4x fewer channels to train faster.\n",
"                          num_heads=8,\n",
"                          num_transformer_layers=11,\n",
"                          pooling_type='max')\n",
"\n",
"train_step = create_step_function(model, optimizer)"
],
"execution_count": 52,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FrbDaOMWcFUl",
"colab": {
"base_uri": "https://localhost:8080/"
},
"cellView": "code",
"outputId": "6a42f69c-3003-47f2-a8d2-1b94c52eb57e"
},
"source": [
"# Train the model\n",
"steps_per_epoch = 20\n",
"num_epochs = 5\n",
"\n",
"data_it = iter(human_mouse_dataset)\n",
"global_step = 0\n",
"for epoch_i in range(num_epochs):\n",
"  for i in tqdm(range(steps_per_epoch)):\n",
"    global_step += 1\n",
"\n",
"    if global_step > 1:\n",
"      learning_rate_frac = tf.math.minimum(\n",
"          1.0, global_step / tf.math.maximum(1.0, num_warmup_steps))\n",
"      learning_rate.assign(target_learning_rate * learning_rate_frac)\n",
"\n",
"    batch_human, batch_mouse = next(data_it)\n",
"\n",
"    loss_human = train_step(batch=batch_human, head='human')\n",
"    loss_mouse = train_step(batch=batch_mouse, head='mouse')\n",
"\n",
"  # End of epoch.\n",
"  print('')\n",
"  print('loss_human', loss_human.numpy(),\n",
"        'loss_mouse', loss_mouse.numpy(),\n",
"        'learning_rate', optimizer.learning_rate.numpy()\n",
"        )"
],
"execution_count": 59,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 20/20 [00:24<00:00, 1.25s/it]\n",
" 0%| | 0/20 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"loss_human 1.774059 loss_mouse 0.94303024 learning_rate 2.0000002e-06\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 20/20 [00:17<00:00, 1.13it/s]\n",
" 0%| | 0/20 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"loss_human 1.0067647 loss_mouse 0.8752468 learning_rate 4.0000004e-06\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 20/20 [00:17<00:00, 1.13it/s]\n",
" 0%| | 0/20 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"loss_human 1.0471998 loss_mouse 0.89318746 learning_rate 6e-06\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 20/20 [00:17<00:00, 1.14it/s]\n",
" 0%| | 0/20 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"loss_human 1.010262 loss_mouse 1.02991 learning_rate 8.000001e-06\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 20/20 [00:17<00:00, 1.14it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"loss_human 1.111991 loss_mouse 0.84773445 learning_rate 1.0000001e-05\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stderr"
}
]
},
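{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learning-rate update inside the training loop is a linear warmup: the rate grows in proportion to the step count until `num_warmup_steps` is reached and then stays at `target_learning_rate`. The same arithmetic is written out below as a plain-Python function so that it can be inspected without running the model. `warmup_learning_rate` is a hypothetical helper name; the default values are copied from the cells above."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def warmup_learning_rate(step, target_learning_rate=0.0005,\n",
"                         num_warmup_steps=5000):\n",
"  # Linear ramp from 0 to target_learning_rate, then constant.\n",
"  frac = min(1.0, step / max(1.0, num_warmup_steps))\n",
"  return target_learning_rate * frac\n",
"\n",
"\n",
"# Step 20 gives roughly 2e-06, in line with the first logged epoch.\n",
"for step in (20, 100, 5000, 10000):\n",
"  print(step, warmup_learning_rate(step))"
],
"execution_count": null,
"outputs": []
},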
{
"cell_type": "markdown",
"metadata": {
"id": "Cs0f0z0RcCfz"
},
"source": [
"## Evaluate"
]
},
{
"cell_type": "code",
"metadata": {
"id": "8c4lNQrHkXSC",
"cellView": "form"
},
"source": [
"# @title `PearsonR` and `R2` metrics\n",
"\n",
"def _reduced_shape(shape, axis):\n",
"  if axis is None:\n",
"    return tf.TensorShape([])\n",
"  return tf.TensorShape([d for i, d in enumerate(shape) if i not in axis])\n",
"\n",
"\n",
"class CorrelationStats(tf.keras.metrics.Metric):\n",
"  \"\"\"Contains shared code for PearsonR and R2.\"\"\"\n",
"\n",
"  def __init__(self, reduce_axis=None, name='pearsonr'):\n",
"    \"\"\"Shared running sums for correlation-style metrics.\n",
"\n",
"    Args:\n",
"      reduce_axis: Specifies over which axis to compute the correlation (say\n",
"        (0, 1)). If not specified, it will compute the correlation across the\n",
"        whole tensor.\n",
"      name: Metric name.\n",
"    \"\"\"\n",
"    super(CorrelationStats, self).__init__(name=name)\n",
"    self._reduce_axis = reduce_axis\n",
"    self._shape = None  # Specified in _initialize.\n",
"\n",
"  def _initialize(self, input_shape):\n",
"    # Remaining dimensions after reducing over self._reduce_axis.\n",
"    self._shape = _reduced_shape(input_shape, self._reduce_axis)\n",
"\n",
"    weight_kwargs = dict(shape=self._shape, initializer='zeros')\n",
"    self._count = self.add_weight(name='count', **weight_kwargs)\n",
"    self._product_sum = self.add_weight(name='product_sum', **weight_kwargs)\n",
"    self._true_sum = self.add_weight(name='true_sum', **weight_kwargs)\n",
"    self._true_squared_sum = self.add_weight(name='true_squared_sum',\n",
"                                             **weight_kwargs)\n",
"    self._pred_sum = self.add_weight(name='pred_sum', **weight_kwargs)\n",
"    self._pred_squared_sum = self.add_weight(name='pred_squared_sum',\n",
"                                             **weight_kwargs)\n",
"\n",
"  def update_state(self, y_true, y_pred, sample_weight=None):\n",
"    \"\"\"Update the metric state.\n",
"\n",
"    Args:\n",
"      y_true: Multi-dimensional float tensor [batch, ...] containing the ground\n",
"        truth values.\n",
"      y_pred: float tensor with the same shape as y_true containing predicted\n",
"        values.\n",
"      sample_weight: 1D tensor aligned with y_true batch dimension specifying\n",
"        the weight of individual observations.\n",
"    \"\"\"\n",
"    if self._shape is None:\n",
"      # Explicit initialization check.\n",
"      self._initialize(y_true.shape)\n",
"    y_true.shape.assert_is_compatible_with(y_pred.shape)\n",
"    y_true = tf.cast(y_true, 'float32')\n",
"    y_pred = tf.cast(y_pred, 'float32')\n",
"\n",
"    self._product_sum.assign_add(\n",
"        tf.reduce_sum(y_true * y_pred, axis=self._reduce_axis))\n",
"\n",
"    self._true_sum.assign_add(\n",
"        tf.reduce_sum(y_true, axis=self._reduce_axis))\n",
"\n",
"    self._true_squared_sum.assign_add(\n",
"        tf.reduce_sum(tf.math.square(y_true), axis=self._reduce_axis))\n",
"\n",
"    self._pred_sum.assign_add(\n",
"        tf.reduce_sum(y_pred, axis=self._reduce_axis))\n",
"\n",
"    self._pred_squared_sum.assign_add(\n",
"        tf.reduce_sum(tf.math.square(y_pred), axis=self._reduce_axis))\n",
"\n",
"    self._count.assign_add(\n",
"        tf.reduce_sum(tf.ones_like(y_true), axis=self._reduce_axis))\n",
"\n",
"  def result(self):\n",
"    raise NotImplementedError('Must be implemented in subclasses.')\n",
"\n",
"  def reset_states(self):\n",
"    if self._shape is not None:\n",
"      tf.keras.backend.batch_set_value([(v, np.zeros(self._shape))\n",
"                                        for v in self.variables])\n",
"\n",
"\n",
"class PearsonR(CorrelationStats):\n",
"  \"\"\"Pearson correlation coefficient.\n",
"\n",
"  Computed as:\n",
"  mean((x - x_avg) * (y - y_avg)) / sqrt(Var[x] * Var[y])\n",
"  \"\"\"\n",
"\n",
"  def __init__(self, reduce_axis=(0,), name='pearsonr'):\n",
"    \"\"\"Pearson correlation coefficient.\n",
"\n",
"    Args:\n",
"      reduce_axis: Specifies over which axis to compute the correlation.\n",
"      name: Metric name.\n",
"    \"\"\"\n",
"    super(PearsonR, self).__init__(reduce_axis=reduce_axis,\n",
"                                   name=name)\n",
"\n",
"  def result(self):\n",
"    true_mean = self._true_sum / self._count\n",
"    pred_mean = self._pred_sum / self._count\n",
"\n",
"    covariance = (self._product_sum\n",
"                  - true_mean * self._pred_sum\n",
"                  - pred_mean * self._true_sum\n",
"                  + self._count * true_mean * pred_mean)\n",
"\n",
"    true_var = self._true_squared_sum - self._count * tf.math.square(true_mean)\n",
"    pred_var = self._pred_squared_sum - self._count * tf.math.square(pred_mean)\n",
"    tp_var = tf.math.sqrt(true_var) * tf.math.sqrt(pred_var)\n",
"    correlation = covariance / tp_var\n",
"\n",
"    return correlation\n",
"\n",
"\n",
"class R2(CorrelationStats):\n",
"  \"\"\"R-squared (fraction of explained variance).\"\"\"\n",
"\n",
"  def __init__(self, reduce_axis=None, name='R2'):\n",
"    \"\"\"R-squared metric.\n",
"\n",
"    Args:\n",
"      reduce_axis: Specifies over which axis to compute the correlation.\n",
"      name: Metric name.\n",
"    \"\"\"\n",
"    super(R2, self).__init__(reduce_axis=reduce_axis,\n",
"                             name=name)\n",
"\n",
"  def result(self):\n",
"    true_mean = self._true_sum / self._count\n",
"    total = self._true_squared_sum - self._count * tf.math.square(true_mean)\n",
"    residuals = (self._pred_squared_sum - 2 * self._product_sum\n",
"                 + self._true_squared_sum)\n",
"\n",
"    return tf.ones_like(residuals) - residuals / total\n",
"\n",
"\n",
"class MetricDict:\n",
"  def __init__(self, metrics):\n",
"    self._metrics = metrics\n",
"\n",
"  def update_state(self, y_true, y_pred):\n",
"    for k, metric in self._metrics.items():\n",
"      metric.update_state(y_true, y_pred)\n",
"\n",
"  def result(self):\n",
"    return {k: metric.result() for k, metric in self._metrics.items()}"
],
"execution_count": 60,
"outputs": []
},
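{
"cell_type": "markdown",
"metadata": {},
"source": [
"`PearsonR` never stores the individual predictions. It only accumulates six running sums and recovers the correlation from them in `result()`. The NumPy-only sketch below repeats that algebra on synthetic data and compares the outcome with `np.corrcoef`. The data, the batch split and the variable names are illustrative assumptions, not part of the notebook's pipeline."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.normal(size=1000)            # stands in for y_true\n",
"y = 0.5 * x + rng.normal(size=1000)  # stands in for y_pred\n",
"\n",
"# Accumulate the running sums batch by batch, as update_state does.\n",
"count = xy = sx = sxx = sy = syy = 0.0\n",
"for xb, yb in zip(np.split(x, 4), np.split(y, 4)):\n",
"  count += xb.size\n",
"  xy += np.sum(xb * yb)\n",
"  sx += np.sum(xb)\n",
"  sxx += np.sum(xb ** 2)\n",
"  sy += np.sum(yb)\n",
"  syy += np.sum(yb ** 2)\n",
"\n",
"# Same algebra as PearsonR.result().\n",
"x_mean, y_mean = sx / count, sy / count\n",
"covariance = xy - x_mean * sy - y_mean * sx + count * x_mean * y_mean\n",
"x_var = sxx - count * x_mean ** 2\n",
"y_var = syy - count * y_mean ** 2\n",
"r_streaming = covariance / (np.sqrt(x_var) * np.sqrt(y_var))\n",
"\n",
"# The streaming estimate should agree with the direct computation.\n",
"np.testing.assert_allclose(r_streaming, np.corrcoef(x, y)[0, 1], rtol=1e-6)\n",
"print(r_streaming)"
],
"execution_count": null,
"outputs": []
},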
{
"cell_type": "code",
"metadata": {
"id": "x80gX9LrhBU-"
},
"source": [
"def evaluate_model(model, dataset, head, max_steps=None):\n",
"  metric = MetricDict({'PearsonR': PearsonR(reduce_axis=(0,1))})\n",
"  @tf.function\n",
"  def predict(x):\n",
"    return model(x, is_training=False)[head]\n",
"\n",
"  for i, batch in tqdm(enumerate(dataset)):\n",
"    if max_steps is not None and i > max_steps:\n",
"      break\n",
"    metric.update_state(batch['target'], predict(batch['sequence']))\n",
"\n",
"  return metric.result()"
],
"execution_count": 61,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "57fNitK9hzwd",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "947aaadb-dad2-4a00-ddac-d765f65d782f"
},
"source": [
"metrics_human = evaluate_model(model,\n",
"                               dataset=get_dataset('human', 'valid').batch(1).prefetch(2),\n",
"                               head='human',\n",
"                               max_steps=100)\n",
"print('')\n",
"print({k: v.numpy().mean() for k, v in metrics_human.items()})"
],
"execution_count": 66,
"outputs": [
{
"output_type": "stream",
"text": [
"101it [00:23, 6.27it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"{'PearsonR': 0.0028573992}\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "HY_wj95xiDtE",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "fea839f7-b6c9-46ed-aece-c56b02e9ea16"
},
"source": [
"metrics_mouse = evaluate_model(model,\n",
"                               dataset=get_dataset('mouse', 'valid').batch(1).prefetch(2),\n",
"                               head='mouse',\n",
"                               max_steps=100)\n",
"print('')\n",
"print({k: v.numpy().mean() for k, v in metrics_mouse.items()})"
],
"execution_count": 63,
"outputs": [
{
"output_type": "stream",
"text": [
"101it [00:21, 6.54it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"{'PearsonR': 0.005183698}\n"
],
"name": "stdout"
}
]
}
]
}