mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-24 00:05:19 +08:00
Added link to our paper.
PiperOrigin-RevId: 362909971
This commit is contained in:
committed by
Louise Deason
parent
5cf55efe1f
commit
db5c562251
@@ -0,0 +1,146 @@
|
||||
# Enformer
|
||||
|
||||
This package provides an implementation of the Enformer model and examples of
|
||||
running the model.
|
||||
|
||||
If this source code or accompanying files are helpful for your research please
|
||||
cite the following publication:
|
||||
|
||||
"Effective gene expression prediction from sequence by integrating long-range
|
||||
interactions"
|
||||
|
||||
Žiga Avsec1, Vikram Agarwal2,4, Daniel Visentin1,4, Joseph R. Ledsam1,3,
|
||||
Agnieszka Grabska-Barwinska1, Kyle R. Taylor1, Yannis Assael1, John Jumper1,
|
||||
Pushmeet Kohli1, David R. Kelley2*
|
||||
|
||||
1 DeepMind, London, UK
|
||||
2 Calico Life Sciences, South San Francisco, CA, USA
|
||||
3 Google, Tokyo, Japan
|
||||
4 These authors contributed equally.
|
||||
* correspondence: avsec@google.com, pushmeet@google.com, drk@calicolabs.com
|
||||
|
||||
## Setup
|
||||
|
||||
Requirements:
|
||||
|
||||
* dm-sonnet (2.0.0)
* kipoiseq (0.5.2)
* numpy (1.19.5)
* pandas (1.2.3)
* tensorflow (2.4.1)
* tensorflow-hub (0.11.0)
|
||||
|
||||
See `requirements.txt`.
|
||||
|
||||
To run the unit test:
|
||||
|
||||
```shell
python3.8 -m venv enformer_venv
source enformer_venv/bin/activate
pip install -r requirements.txt
python -m enformer_test
```
|
||||
|
||||
## Pre-computed variant effect predictions
|
||||
|
||||
We precomputed variant effect scores for all frequent variants (MAF > 0.5% in
any population) present in the 1000 Genomes Project. Variant scores for the HG19
reference genome, stored as one HDF5 file per chromosome, can be found
[here](TODO).
Each HDF5 file has the same format as the output of
[this](https://github.com/calico/basenji/blob/738321c85f8925ae6ac318a6cd4901a42ea6bc3f/bin/basenji_sad.py#L264)
script and contains the following arrays:
|
||||
|
||||
* snp \[num_snps](string) - SNP ID
* chr \[num_snps](string) - chromosome name
* pos \[num_snps](uint32) - position (1-based)
* ref \[num_snps](string) - reference allele
* alt \[num_snps](string) - alternative allele
* target_ids \[num_targets](string) - target IDs
* target_labels \[num_targets](string) - target names
* SAD \[num_snps, num_targets](float16) - SNP Activity Difference (SAD)
    scores: the main variant effect score, computed as
    `model(alternate_sequence) - model(reference_sequence)`.
* SAR \[num_snps, num_targets](float16) - Same as SAD, but computed as
    `np.log2(1 + model(alternate_sequence)) - np.log2(1 +
    model(reference_sequence))`.
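The two score definitions above can be sketched with NumPy as follows. This is a minimal illustration, not the pipeline used to produce the released files: it assumes the model outputs have already been reduced to one non-negative value per target, and the function name `variant_scores` and the example numbers are made up.

```python
import numpy as np


def variant_scores(ref_pred: np.ndarray, alt_pred: np.ndarray):
  """Returns (SAD, SAR) for per-target predictions of shape [num_targets].

  Hypothetical helper: `ref_pred` and `alt_pred` stand in for
  model(reference_sequence) and model(alternate_sequence).
  """
  sad = alt_pred - ref_pred
  sar = np.log2(1 + alt_pred) - np.log2(1 + ref_pred)
  return sad, sar


# Made-up predictions for three targets (illustration only).
ref_pred = np.array([0.0, 1.0, 3.0])
alt_pred = np.array([1.0, 1.0, 1.0])
sad, sar = variant_scores(ref_pred, alt_pred)
```

SAD measures the absolute change in predicted signal, whereas SAR compares the two predictions on a log scale, so it is less dominated by targets with large predicted values.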
|
||||
|
||||
Furthermore, we provide the top 20 principal components of variant-effect scores
in the PC20 folder, stored as one tabix-indexed TSV file per chromosome (HG19
reference genome). These files have the following columns:
|
||||
|
||||
* #CHROM - chromosome (e.g. chr1)
* POS - variant position (1-based)
* ID - dbSNP identifier
* REF - reference allele (e.g. A)
* ALT - alternate allele (e.g. T)
* PC{i} - i-th principal component of the variant effect prediction.
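As a rough sketch of how the column layout above could be parsed, the snippet below reads tab-separated text with Python's standard `csv` module. The two rows, the variant IDs, the exact `PC` column names and the helper `read_pc_rows` are all invented for illustration; a real file is compressed and tabix-indexed, so it would need to be decompressed or queried with a tabix-aware tool first.

```python
import csv
import io

# Two made-up rows in the column layout listed above (values are illustrative).
example_tsv = (
    "#CHROM\tPOS\tID\tREF\tALT\tPC0\tPC1\n"
    "chr1\t12345\trs0000001\tA\tT\t0.12\t-0.03\n"
    "chr1\t67890\trs0000002\tG\tC\t-1.50\t0.40\n"
)


def read_pc_rows(handle):
  """Parses rows into dicts, collecting every column starting with 'PC'."""
  reader = csv.DictReader(handle, delimiter="\t")
  rows = []
  for row in reader:
    rows.append({
        "chrom": row["#CHROM"],
        "pos": int(row["POS"]),  # 1-based position.
        "id": row["ID"],
        "ref": row["REF"],
        "alt": row["ALT"],
        "pcs": [float(v) for k, v in row.items() if k.startswith("PC")],
    })
  return rows


rows = read_pc_rows(io.StringIO(example_tsv))
```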
|
||||
|
||||
All model predictions are licensed under
|
||||
[CC-BY 4.0 license](https://creativecommons.org/licenses/by/4.0/).
|
||||
|
||||
## Running Inference
|
||||
|
||||
The simplest way to perform inference is to load the model via tfhub.dev (TODO:
LINK). The input sequence length is 393,216 bp, and predictions are made for
128 bp windows covering the central 114,688 bp. The input sequence is one-hot
encoded with the channel order 'ACGT'; N is encoded as all zeros.
|
||||
|
||||
```python
import tensorflow as tf
import tensorflow_hub as hub

enformer = hub.Module("https://tfhub.dev/deepmind/enformer/...")

SEQ_LENGTH = 393_216

# Numpy array [batch_size, SEQ_LENGTH, 4] one hot encoded in order 'ACGT'. The
# `one_hot_encode` function is available in `enformer.py` and outputs can be
# stacked to form a batch.
inputs = …
predictions = enformer.predict_on_batch(inputs)
predictions['human'].shape  # [batch_size, 896, 5313]
predictions['mouse'].shape  # [batch_size, 896, 1643]
```
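The input encoding and output binning described in this section can be sketched in plain NumPy. This is only an illustration of the stated conventions: `one_hot_acgt` and `PREDICTED_LENGTH` are hypothetical names, and the function is a simplified stand-in for the `one_hot_encode` function in `enformer.py` (here, any character other than A, C, G or T is encoded as zeros).

```python
import numpy as np

SEQ_LENGTH = 393_216        # Input length stated above.
PREDICTED_LENGTH = 114_688  # Central region covered by predictions.
BIN_SIZE = 128              # Base pairs per output window.

# Number of output windows along the sequence axis.
num_bins = PREDICTED_LENGTH // BIN_SIZE


def one_hot_acgt(sequence: str) -> np.ndarray:
  """One-hot encodes a DNA string in 'ACGT' order; other symbols -> zeros."""
  index = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
  encoded = np.zeros((len(sequence), 4), dtype=np.float32)
  for position, base in enumerate(sequence.upper()):
    if base in index:
      encoded[position, index[base]] = 1.0
  return encoded


example = one_hot_acgt('ACGTN')
```

The value of `num_bins` matches the 896 positions in the prediction shapes shown above, and the last row of `example` (the `N`) is all zeros.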
|
||||
|
||||
## Outputs
|
||||
|
||||
For each 128 bp window, predictions are made for every track. The mapping from
track index to track name is given in the `targets_{organism}.txt` file in the
corresponding basenji
[dataset](https://github.com/calico/basenji/tree/master/manuscripts/cross2020)
folder.
|
||||
|
||||
As an example, to load track annotations for the human targets:
|
||||
|
||||
```python
import pandas as pd
targets_txt = 'https://raw.githubusercontent.com/calico/basenji/0.5/manuscripts/cross2020/targets_human.txt'
df_targets = pd.read_csv(targets_txt, sep='\t')
df_targets.shape  # (5313, 8): the number of rows matches the human output shape above.
```
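To relate an output window back to the input sequence, the following sketch maps a window index to input positions. It assumes that the 896 windows tile the central 114,688 bp of the 393,216 bp input exactly symmetrically, as the description of the "center" region suggests; the helper name `bin_to_input_interval` is hypothetical.

```python
SEQ_LENGTH = 393_216  # Input length used with the hub model.
BIN_SIZE = 128        # Base pairs per output window.
NUM_BINS = 896        # Output windows per sequence.

# Unpredicted flank on each side of the central region (assumed symmetric).
OFFSET = (SEQ_LENGTH - NUM_BINS * BIN_SIZE) // 2


def bin_to_input_interval(bin_index: int):
  """Returns the half-open [start, end) input positions (0-based) of a bin."""
  if not 0 <= bin_index < NUM_BINS:
    raise IndexError(f'bin_index must be in [0, {NUM_BINS}), got {bin_index}')
  start = OFFSET + bin_index * BIN_SIZE
  return start, start + BIN_SIZE


first = bin_to_input_interval(0)
last = bin_to_input_interval(NUM_BINS - 1)
```

Adding the genomic start coordinate of the input window to these positions gives genomic coordinates for each prediction.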
|
||||
|
||||
## Modeling Code
|
||||
|
||||
The model is implemented using [Sonnet](https://github.com/deepmind/sonnet). The
|
||||
full Sonnet module, called `Enformer`, is defined in `enformer.py`. See the Sonnet
|
||||
documentation on distributed training
|
||||
[here](https://github.com/deepmind/sonnet#distributed-training).
|
||||
|
||||
## Colab
|
||||
|
||||
Further examples are given in the notebook `enformer-example.ipynb`
([open in Colab](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-example.ipynb)).
|
||||
|
||||
This shows how to:
|
||||
|
||||
* Make predictions with Enformer and reproduce Fig. 1d
|
||||
* Compute contribution scores and reproduce parts of Fig. 2a
|
||||
* Predict the effect of a genetic variant and reproduce parts of Fig. 3g
|
||||
* Score multiple variants in a VCF
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This is not an official Google product.
|
||||
|
||||
File diff suppressed because it is too large
File diff suppressed because one or more lines are too long
@@ -0,0 +1,313 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow implementation of Enformer model.

"Effective gene expression prediction from sequence by integrating long-range
interactions"

Žiga Avsec1, Vikram Agarwal2,4, Daniel Visentin1,4, Joseph R. Ledsam1,3,
Agnieszka Grabska-Barwinska1, Kyle R. Taylor1, Yannis Assael1, John Jumper1,
Pushmeet Kohli1, David R. Kelley2*

1 DeepMind, London, UK
2 Calico Life Sciences, South San Francisco, CA, USA
3 Google, Tokyo, Japan
4 These authors contributed equally.
* correspondence: avsec@google.com, pushmeet@google.com, drk@calicolabs.com
"""
import inspect
from typing import Any, Callable, Dict, Optional, Text, Union, Iterable

import attention_module
import numpy as np
import sonnet as snt
import tensorflow as tf

SEQUENCE_LENGTH = 196_608
BIN_SIZE = 128
TARGET_LENGTH = 896
|
||||
|
||||
|
||||
class Enformer(snt.Module):
  """Main model."""

  def __init__(self,
               channels: int = 1536,
               num_transformer_layers: int = 11,
               name: str = 'enformer'):
    """Enformer model.

    Args:
      channels: Number of convolutional filters.
      num_transformer_layers: Number of transformer layers.
      name: Name of sonnet module.
    """
    super().__init__(name=name)
    # pylint: disable=g-complex-comprehension,g-long-lambda,cell-var-from-loop
    heads_channels = {'human': 5313, 'mouse': 1643}
    dropout_rate = 0.4
    whole_attention_kwargs = {
        'attention_dropout_rate': 0.05,
        'initializer': None,
        'key_size': 64,
        'num_heads': 8,
        'num_relative_position_features': 192,
        'positional_dropout_rate': 0.01,
        'relative_position_functions': [
            'positional_features_exponential',
            'positional_features_central_mask',
            'positional_features_gamma'
        ],
        'relative_positions': True,
        'scaling': True,
        'value_size': 192,
        'zero_initialize': True
    }

    trunk_name_scope = tf.name_scope('trunk')
    trunk_name_scope.__enter__()
    # lambda is used in Sequential to construct the module under tf.name_scope.
    def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
      return Sequential(lambda: [
          snt.BatchNorm(create_scale=True,
                        create_offset=True,
                        decay_rate=0.9,
                        scale_init=snt.initializers.Ones()),
          gelu,
          snt.Conv1D(filters, width, w_init=w_init, **kwargs)
      ], name=name)

    stem = Sequential(lambda: [
        snt.Conv1D(channels // 2, 15),
        Residual(conv_block(channels // 2, 1, name='pointwise_conv_block')),
        SoftmaxPooling1D(pool_size=2, per_channel=True, w_init_scale=2.0)
    ], name='stem')

    filter_list = exponential_linspace_int(start=channels // 2, end=channels,
                                           num=6, divisible_by=128)
    conv_tower = Sequential(lambda: [
        Sequential(lambda: [
            conv_block(num_filters, 5),
            Residual(conv_block(num_filters, 1, name='pointwise_conv_block')),
            SoftmaxPooling1D(pool_size=2, per_channel=True, w_init_scale=2.0)
            ],
                   name=f'conv_tower_block_{i}')
        for i, num_filters in enumerate(filter_list)], name='conv_tower')

    # Transformer.
    def transformer_mlp():
      return Sequential(lambda: [
          snt.LayerNorm(axis=-1, create_scale=True, create_offset=True),
          snt.Linear(channels * 2),
          snt.Dropout(dropout_rate),
          tf.nn.relu,
          snt.Linear(channels),
          snt.Dropout(dropout_rate)], name='mlp')

    transformer = Sequential(lambda: [
        Sequential(lambda: [
            Residual(Sequential(lambda: [
                snt.LayerNorm(axis=-1,
                              create_scale=True, create_offset=True,
                              scale_init=snt.initializers.Ones()),
                attention_module.MultiheadAttention(**whole_attention_kwargs,
                                                    name=f'attention_{i}'),
                snt.Dropout(dropout_rate)], name='mha')),
            Residual(transformer_mlp())], name=f'transformer_block_{i}')
        for i in range(num_transformer_layers)], name='transformer')

    crop_final = TargetLengthCrop1D(TARGET_LENGTH, name='target_input')

    final_pointwise = Sequential(lambda: [
        conv_block(channels * 2, 1),
        snt.Dropout(dropout_rate / 8),
        gelu], name='final_pointwise')

    self._trunk = Sequential([stem,
                              conv_tower,
                              transformer,
                              crop_final,
                              final_pointwise],
                             name='trunk')
    trunk_name_scope.__exit__(None, None, None)

    with tf.name_scope('heads'):
      self._heads = {
          head: Sequential(
              lambda: [snt.Linear(num_channels), tf.nn.softplus],
              name=f'head_{head}')
          for head, num_channels in heads_channels.items()
      }
    # pylint: enable=g-complex-comprehension,g-long-lambda,cell-var-from-loop

  @property
  def trunk(self):
    return self._trunk

  @property
  def heads(self):
    return self._heads

  def __call__(self, inputs: tf.Tensor,
               is_training: bool) -> Dict[str, tf.Tensor]:
    trunk_embedding = self.trunk(inputs, is_training=is_training)
    return {
        head: head_module(trunk_embedding, is_training=is_training)
        for head, head_module in self.heads.items()
    }

  @tf.function(input_signature=[
      tf.TensorSpec([None, SEQUENCE_LENGTH, 4], tf.float32)])
  def predict_on_batch(self, x):
    """Method for SavedModel."""
    return self(x, is_training=False)
|
||||
|
||||
|
||||
class TargetLengthCrop1D(snt.Module):
  """Crop sequence to match the desired target length."""

  def __init__(self, target_length: int, name='target_length_crop'):
    super().__init__(name=name)
    self._target_length = target_length

  def __call__(self, inputs):
    trim = (inputs.shape[-2] - self._target_length) // 2
    if trim < 0:
      raise ValueError('inputs shorter than target length')
    if trim == 0:
      # `inputs[..., 0:-0, :]` would be empty, so return the inputs unchanged.
      return inputs

    return inputs[..., trim:-trim, :]
|
||||
|
||||
|
||||
class Sequential(snt.Module):
  """snt.Sequential automatically passing is_training where it exists."""

  def __init__(self,
               layers: Optional[Union[Callable[[], Iterable[snt.Module]],
                                      Iterable[Callable[..., Any]]]] = None,
               name: Optional[Text] = None):
    super().__init__(name=name)
    if layers is None:
      self._layers = []
    else:
      # layers wrapped in a lambda function to have a common namespace.
      if hasattr(layers, '__call__'):
        with tf.name_scope(name):
          layers = layers()
      self._layers = [layer for layer in layers if layer is not None]

  def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):
    outputs = inputs
    for mod in self._layers:
      if accepts_is_training(mod):
        outputs = mod(outputs, is_training=is_training, **kwargs)
      else:
        outputs = mod(outputs, **kwargs)
    return outputs
|
||||
|
||||
|
||||
class SoftmaxPooling1D(snt.Module):
  """Pooling operation with optional weights."""

  def __init__(self,
               pool_size: int = 2,
               per_channel: bool = False,
               w_init_scale: float = 0.0,
               name: str = 'softmax_pooling'):
    """Softmax pooling.

    Args:
      pool_size: Pooling size, same as in Max/AvgPooling.
      per_channel: If True, the logits/softmax weights will be computed for
        each channel separately. If False, same weights will be used across all
        channels.
      w_init_scale: When 0.0 is equivalent to avg pooling, and when
        ~2.0 and `per_channel=False` it's equivalent to max pooling.
      name: Module name.
    """
    super().__init__(name=name)
    self._pool_size = pool_size
    self._per_channel = per_channel
    self._w_init_scale = w_init_scale
    self._logit_linear = None

  @snt.once
  def _initialize(self, num_features):
    self._logit_linear = snt.Linear(
        output_size=num_features if self._per_channel else 1,
        with_bias=False,  # Softmax is agnostic to shifts.
        w_init=snt.initializers.Identity(self._w_init_scale))

  def __call__(self, inputs):
    _, length, num_features = inputs.shape
    self._initialize(num_features)
    inputs = tf.reshape(
        inputs,
        (-1, length // self._pool_size, self._pool_size, num_features))
    return tf.reduce_sum(
        inputs * tf.nn.softmax(self._logit_linear(inputs), axis=-2),
        axis=-2)
|
||||
|
||||
|
||||
class Residual(snt.Module):
  """Residual block."""

  def __init__(self, module: snt.Module, name='residual'):
    super().__init__(name=name)
    self._module = module

  def __call__(self, inputs: tf.Tensor, is_training: bool, *args,
               **kwargs) -> tf.Tensor:
    return inputs + self._module(inputs, is_training, *args, **kwargs)
|
||||
|
||||
|
||||
def gelu(x: tf.Tensor) -> tf.Tensor:
  """Applies the Gaussian error linear unit (GELU) activation function.

  Using approximation in section 2 of the original paper:
  https://arxiv.org/abs/1606.08415

  Args:
    x: Input tensor to apply gelu activation.

  Returns:
    Tensor with gelu activation applied to it.
  """
  return tf.nn.sigmoid(1.702 * x) * x
|
||||
|
||||
|
||||
def one_hot_encode(sequence: str,
                   alphabet: str = 'ACGT',
                   neutral_alphabet: str = 'N',
                   neutral_value: Any = 0,
                   dtype=np.float64) -> np.ndarray:
  """One-hot encode sequence."""
  def to_uint8(string):
    return np.frombuffer(string.encode('ascii'), dtype=np.uint8)
  hash_table = np.zeros((np.iinfo(np.uint8).max, len(alphabet)), dtype=dtype)
  hash_table[to_uint8(alphabet)] = np.eye(len(alphabet), dtype=dtype)
  hash_table[to_uint8(neutral_alphabet)] = neutral_value
  hash_table = hash_table.astype(dtype)
  return hash_table[to_uint8(sequence)]


def exponential_linspace_int(start, end, num, divisible_by=1):
  """Exponentially increasing values of integers."""
  def _round(x):
    return int(np.round(x / divisible_by) * divisible_by)

  base = np.exp(np.log(end / start) / (num - 1))
  return [_round(start * base**i) for i in range(num)]


def accepts_is_training(module):
  return 'is_training' in list(inspect.signature(module.__call__).parameters)
|
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test enformer model by applying random sequence as input.

Test:

$ python enformer_test.py
"""

import random
import unittest

import enformer
import numpy as np


class TestEnformer(unittest.TestCase):

  def test_enformer(self):
    model = enformer.Enformer(channels=1536, num_transformer_layers=11)
    inputs = _get_random_input()
    outputs = model(inputs, is_training=True)
    self.assertEqual(outputs['human'].shape, (1, enformer.TARGET_LENGTH, 5313))
    self.assertEqual(outputs['mouse'].shape, (1, enformer.TARGET_LENGTH, 1643))


def _get_random_input():
  seq = ''.join(
      [random.choice('ACGT') for _ in range(enformer.SEQUENCE_LENGTH)])
  return np.expand_dims(enformer.one_hot_encode(seq), 0)


if __name__ == '__main__':
  unittest.main()
|
||||
@@ -0,0 +1,6 @@
|
||||
dm-sonnet==2.0.0
kipoiseq==0.5.2
numpy==1.19.5
pandas==1.2.3
tensorflow==2.4.1
tensorflow-hub==0.11.0
|
||||
@@ -4,8 +4,7 @@
|
||||
|
||||
`dendritic_gated_network.ipynb` implements a Dendritic Gated Network (DGN) solving a regression (using quadratic loss) or a binary classification problem (using Bernoulli log loss).
|
||||
|
||||
See our paper titled "A rapid and efficient learning rule for biological neural circuits" for details of the DGN model.
|
||||
|
||||
See our paper titled ["A rapid and efficient learning rule for biological neural circuits"](https://www.biorxiv.org/content/10.1101/2021.03.10.434756v1) for details of the DGN model.
|
||||
|
||||
### Instructions for running the `dendritic_gated_network.ipynb` colab/notebook.
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"\n",
|
||||
"This colab implements a Dendritic Gated Network (DGN) solving a regression (using quadratic loss) or a binary classification problem (using Bernoulli log loss).\n",
|
||||
"\n",
|
||||
"See our paper titled \"A rapid and efficient learning rule for biological neural circuits\" for details of the DGN model.\n",
|
||||
"See our paper titled [\"A rapid and efficient learning rule for biological neural circuits\"](https://www.biorxiv.org/content/10.1101/2021.03.10.434756v1) for details of the DGN model.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Some implementation details:\n",
|
||||
|
||||