Added link to our paper.

PiperOrigin-RevId: 362909971
Agnieszka Grabska-Barwińska
2021-03-15 11:27:17 +00:00
committed by Louise Deason
parent 5cf55efe1f
commit db5c562251
8 changed files with 2845 additions and 3 deletions
+146
@@ -0,0 +1,146 @@
# Enformer
This package provides an implementation of the Enformer model and examples of
running the model.

If this source code or the accompanying files are helpful for your research,
please cite the following publication:
"Effective gene expression prediction from sequence by integrating long-range
interactions"
Žiga Avsec1, Vikram Agarwal2,4, Daniel Visentin1,4, Joseph R. Ledsam1,3,
Agnieszka Grabska-Barwinska1, Kyle R. Taylor1, Yannis Assael1, John Jumper1,
Pushmeet Kohli1, David R. Kelley2*
1 DeepMind, London, UK
2 Calico Life Sciences, South San Francisco, CA, USA
3 Google, Tokyo, Japan
4 These authors contributed equally.
* correspondence: avsec@google.com, pushmeet@google.com, drk@calicolabs.com
## Setup
Requirements:
* dm-sonnet (2.0.0)
* kipoiseq (0.5.2)
* numpy (1.19.5)
* pandas (1.2.3)
* tensorflow (2.4.1)
* tensorflow-hub (0.11.0)
See `requirements.txt`.
To run the unit test:
```shell
python3.8 -m venv enformer_venv
source enformer_venv/bin/activate
pip install -r requirements.txt
python -m enformer_test
```
## Pre-computed variant effect predictions
We precomputed variant effect scores for all frequent variants (MAF>0.5%, in any
population) present in the 1000 Genomes Project. The scores are stored as one
HDF5 file per chromosome, using the HG19 reference genome, and can be found
[here](TODO).
The HDF5 file has the same format as the output of
[this](https://github.com/calico/basenji/blob/738321c85f8925ae6ac318a6cd4901a42ea6bc3f/bin/basenji_sad.py#L264)
script and contains the following arrays:
* snp \[num_snps](string) - snp id
* chr \[num_snps](string) - chromosome name
* pos \[num_snps](uint32) - position (1-based)
* ref \[num_snps](string) - reference allele
* alt \[num_snps](string) - alternative allele
* target_ids \[num_targets](string) - target ids
* target_labels \[num_targets](string) - target names
* SAD \[num_snps, num_targets](float16) - SNP Activity Difference (SAD)
scores - main variant effect score computed as `model(alternate_sequence) -
model(reference_sequence)`.
* SAR \[num_snps, num_targets](float16) - Same as SAD, but computed on a log
scale as `np.log2(1 + model(alternate_sequence)) - np.log2(1 +
model(reference_sequence))`
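
The two scores above can be sketched in a few lines of NumPy. This is a minimal illustration, not the script that produced the files: `sad_and_sar` is a hypothetical helper, the toy values are not real model output, and it assumes the reference and alternate predictions have already been reduced to one value per variant and target.

```python
import numpy as np


def sad_and_sar(ref_pred, alt_pred):
  """Return (SAD, SAR) for predictions of shape [num_snps, num_targets]."""
  ref_pred = np.asarray(ref_pred, dtype=np.float32)
  alt_pred = np.asarray(alt_pred, dtype=np.float32)
  # SAD: plain difference between alternate and reference predictions.
  sad = alt_pred - ref_pred
  # SAR: the same difference taken after a log2(1 + x) transform.
  sar = np.log2(1 + alt_pred) - np.log2(1 + ref_pred)
  # The files store both arrays as float16.
  return sad.astype(np.float16), sar.astype(np.float16)


# Toy values (2 variants x 3 targets), for illustration only.
ref = np.array([[1.0, 3.0, 0.0], [7.0, 1.0, 1.0]])
alt = np.array([[3.0, 3.0, 1.0], [3.0, 0.0, 1.0]])
sad, sar = sad_and_sar(ref, alt)
```

Because of the log transform, SAR compresses differences between large predictions, so it is less dominated by highly expressed tracks than SAD.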

Furthermore, we provide the top 20 principal components of the variant-effect
scores in the `PC20` folder, stored as one tabix-indexed TSV file per chromosome
(HG19 reference genome). These files have the following columns:
* #CHROM - chromosome (e.g. chr1)
* POS - variant position (1-based)
* ID - dbSNP identifier
* REF - reference allele (e.g. A)
* ALT - alternate allele (e.g. T)
* PC{i} - i-th principal component of the variant effect prediction.
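
As a rough sketch of how such a file could be parsed with the standard library: `read_pc_rows` is a hypothetical helper, the toy rows are made up, and the exact spelling of the component columns (for example whether numbering starts at 0 or 1) is an assumption, so the code only relies on the `PC` prefix.

```python
import csv
import io


def read_pc_rows(handle):
  """Yield (chrom, pos, snp_id, ref, alt, pcs) from a PC20-style TSV handle."""
  reader = csv.reader(handle, delimiter='\t')
  header = next(reader)
  # Every column whose name starts with 'PC' holds one principal component.
  pc_columns = [i for i, name in enumerate(header) if name.startswith('PC')]
  for row in reader:
    chrom, pos, snp_id, ref, alt = row[:5]
    yield chrom, int(pos), snp_id, ref, alt, [float(row[i]) for i in pc_columns]


# Made-up example with two components; the real files have 20.
toy_tsv = (
    '#CHROM\tPOS\tID\tREF\tALT\tPC0\tPC1\n'
    'chr1\t12345\trs0000001\tA\tT\t0.5\t-1.25\n'
)
rows = list(read_pc_rows(io.StringIO(toy_tsv)))
```

Tabix-indexed files are bgzip-compressed, which is gzip-compatible, so `gzip.open(path, 'rt')` should be able to supply the text handle for a downloaded file.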

All model predictions are licensed under the
[CC-BY 4.0 license](https://creativecommons.org/licenses/by/4.0/).
## Running Inference
The simplest way to perform inference is to load the model via tfhub.dev (TODO:
LINK). The input sequence length is 393,216 bp, and the predictions cover the
central 114,688 bp in 128 bp windows (896 windows in total). The input sequence
is one-hot encoded with the channel order 'ACGT'; N is encoded as all zeros.
```python
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

# Load the SavedModel with the TF2 hub API (the model URL is still to be
# completed).
enformer = hub.load("https://tfhub.dev/deepmind/enformer/...")

SEQ_LENGTH = 393_216

# Numpy array [batch_size, SEQ_LENGTH, 4] one hot encoded in order 'ACGT'. The
# `one_hot_encode` function is available in `enformer.py` and outputs can be
# stacked to form a batch.
# Placeholder batch: a single all-N (all-zero) sequence.
inputs = np.zeros((1, SEQ_LENGTH, 4), dtype=np.float32)
predictions = enformer.predict_on_batch(inputs)
predictions['human'].shape  # [batch_size, 896, 5313]
predictions['mouse'].shape  # [batch_size, 896, 1643]
```
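
The relationship between the input window and the prediction bins can be worked out with plain arithmetic. The sketch below assumes the predicted region is centered exactly in the input, as the description above suggests; `bin_interval` and `position_to_bin` are hypothetical helper names, and positions are 0-based offsets relative to the start of the input sequence.

```python
INPUT_LENGTH = 393_216  # Input sequence length.
BIN_SIZE = 128          # Base pairs per prediction window.
NUM_BINS = 896          # Second dimension of the prediction tensors.

OUTPUT_SPAN = NUM_BINS * BIN_SIZE           # 114,688 bp covered by predictions.
OFFSET = (INPUT_LENGTH - OUTPUT_SPAN) // 2  # Unpredicted flank on each side.


def bin_interval(bin_index, input_start=0):
  """Half-open [start, end) interval covered by one prediction bin."""
  if not 0 <= bin_index < NUM_BINS:
    raise IndexError(f'bin_index must be in [0, {NUM_BINS})')
  start = input_start + OFFSET + bin_index * BIN_SIZE
  return start, start + BIN_SIZE


def position_to_bin(position, input_start=0):
  """Bin index containing `position`, or None if it is outside the bins."""
  relative = position - input_start - OFFSET
  if 0 <= relative < OUTPUT_SPAN:
    return relative // BIN_SIZE
  return None


first_bin = bin_interval(0)
last_bin = bin_interval(NUM_BINS - 1)
center_bin = position_to_bin(INPUT_LENGTH // 2)
```

Passing the genomic start coordinate of the input window as `input_start` turns the same functions into a mapping between genomic positions and prediction bins.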
## Outputs
For each 128 bp window, predictions are made for every track. The mapping from
track index to track name is found in the `targets_{organism}.txt` file in the
basenji
[dataset](https://github.com/calico/basenji/tree/master/manuscripts/cross2020)
folder.
As an example, to load track annotations for the human targets:
```python
import pandas as pd
targets_txt = 'https://raw.githubusercontent.com/calico/basenji/0.5/manuscripts/cross2020/targets_human.txt'
df_targets = pd.read_csv(targets_txt, sep='\t')
df_targets.shape  # (5313, 8); the row count matches the human output above.
```
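
Once the annotations are loaded, the row positions of a group of tracks can be used to select the matching prediction channels. The sketch below is illustrative only: `track_indices` is a hypothetical helper, the toy table stands in for `df_targets`, and it assumes the targets file has a `description` column whose values start with the assay name.

```python
import pandas as pd


def track_indices(df_targets, prefix):
  """Row positions of tracks whose description starts with `prefix`."""
  mask = df_targets['description'].str.startswith(prefix)
  return [i for i, keep in enumerate(mask) if keep]


# Hypothetical stand-in for the real targets table.
toy_targets = pd.DataFrame({
    'description': [
        'DNASE:cell A', 'CAGE:tissue B', 'CHIP:factor C', 'CAGE:tissue D'
    ]
})
cage = track_indices(toy_targets, 'CAGE')
```

With real predictions, the selected channels could then be extracted with something like `tf.gather(predictions['human'], cage, axis=-1)`.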
## Modeling Code
The model is implemented using [Sonnet](https://github.com/deepmind/sonnet). The
full Sonnet module, `Enformer`, is defined in `enformer.py`. For distributed
training, see the Sonnet documentation
[here](https://github.com/deepmind/sonnet#distributed-training).
## Colab
Further examples are given in the notebook `enformer-example.ipynb`
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-example.ipynb).
This shows how to:
* Make predictions with Enformer and reproduce Fig. 1d
* Compute contribution scores and reproduce parts of Fig. 2a
* Predict the effect of a genetic variant and reproduce parts of Fig. 3g
* Score multiple variants in a VCF
## Disclaimer
This is not an official Google product.
+313
@@ -0,0 +1,313 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow implementation of Enformer model.
"Effective gene expression prediction from sequence by integrating long-range
interactions"
Žiga Avsec1, Vikram Agarwal2,4, Daniel Visentin1,4, Joseph R. Ledsam1,3,
Agnieszka Grabska-Barwinska1, Kyle R. Taylor1, Yannis Assael1, John Jumper1,
Pushmeet Kohli1, David R. Kelley2*
1 DeepMind, London, UK
2 Calico Life Sciences, South San Francisco, CA, USA
3 Google, Tokyo, Japan
4 These authors contributed equally.
* correspondence: avsec@google.com, pushmeet@google.com, drk@calicolabs.com
"""
import inspect
from typing import Any, Callable, Dict, Optional, Text, Union, Iterable
import attention_module
import numpy as np
import sonnet as snt
import tensorflow as tf
SEQUENCE_LENGTH = 196_608
BIN_SIZE = 128
TARGET_LENGTH = 896


class Enformer(snt.Module):
  """Main model."""

  def __init__(self,
               channels: int = 1536,
               num_transformer_layers: int = 11,
               name: str = 'enformer'):
    """Enformer model.

    Args:
      channels: Number of convolutional filters.
      num_transformer_layers: Number of transformer layers.
      name: Name of sonnet module.
    """
    super().__init__(name=name)
    # pylint: disable=g-complex-comprehension,g-long-lambda,cell-var-from-loop
    heads_channels = {'human': 5313, 'mouse': 1643}
    dropout_rate = 0.4
    whole_attention_kwargs = {
        'attention_dropout_rate': 0.05,
        'initializer': None,
        'key_size': 64,
        'num_heads': 8,
        'num_relative_position_features': 192,
        'positional_dropout_rate': 0.01,
        'relative_position_functions': [
            'positional_features_exponential',
            'positional_features_central_mask',
            'positional_features_gamma'
        ],
        'relative_positions': True,
        'scaling': True,
        'value_size': 192,
        'zero_initialize': True
    }

    trunk_name_scope = tf.name_scope('trunk')
    trunk_name_scope.__enter__()
    # lambda is used in Sequential to construct the module under tf.name_scope.
    def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
      return Sequential(lambda: [
          snt.BatchNorm(create_scale=True,
                        create_offset=True,
                        decay_rate=0.9,
                        scale_init=snt.initializers.Ones()),
          gelu,
          snt.Conv1D(filters, width, w_init=w_init, **kwargs)
      ], name=name)

    stem = Sequential(lambda: [
        snt.Conv1D(channels // 2, 15),
        Residual(conv_block(channels // 2, 1, name='pointwise_conv_block')),
        SoftmaxPooling1D(pool_size=2, per_channel=True, w_init_scale=2.0)
    ], name='stem')

    filter_list = exponential_linspace_int(start=channels // 2, end=channels,
                                           num=6, divisible_by=128)
    conv_tower = Sequential(lambda: [
        Sequential(lambda: [
            conv_block(num_filters, 5),
            Residual(conv_block(num_filters, 1, name='pointwise_conv_block')),
            SoftmaxPooling1D(pool_size=2, per_channel=True, w_init_scale=2.0)
        ],
                   name=f'conv_tower_block_{i}')
        for i, num_filters in enumerate(filter_list)], name='conv_tower')

    # Transformer.
    def transformer_mlp():
      return Sequential(lambda: [
          snt.LayerNorm(axis=-1, create_scale=True, create_offset=True),
          snt.Linear(channels * 2),
          snt.Dropout(dropout_rate),
          tf.nn.relu,
          snt.Linear(channels),
          snt.Dropout(dropout_rate)], name='mlp')

    transformer = Sequential(lambda: [
        Sequential(lambda: [
            Residual(Sequential(lambda: [
                snt.LayerNorm(axis=-1,
                              create_scale=True, create_offset=True,
                              scale_init=snt.initializers.Ones()),
                attention_module.MultiheadAttention(**whole_attention_kwargs,
                                                    name=f'attention_{i}'),
                snt.Dropout(dropout_rate)], name='mha')),
            Residual(transformer_mlp())], name=f'transformer_block_{i}')
        for i in range(num_transformer_layers)], name='transformer')

    crop_final = TargetLengthCrop1D(TARGET_LENGTH, name='target_input')

    final_pointwise = Sequential(lambda: [
        conv_block(channels * 2, 1),
        snt.Dropout(dropout_rate / 8),
        gelu], name='final_pointwise')

    self._trunk = Sequential([stem,
                              conv_tower,
                              transformer,
                              crop_final,
                              final_pointwise],
                             name='trunk')
    trunk_name_scope.__exit__(None, None, None)

    with tf.name_scope('heads'):
      self._heads = {
          head: Sequential(
              lambda: [snt.Linear(num_channels), tf.nn.softplus],
              name=f'head_{head}')
          for head, num_channels in heads_channels.items()
      }
    # pylint: enable=g-complex-comprehension,g-long-lambda,cell-var-from-loop

  @property
  def trunk(self):
    return self._trunk

  @property
  def heads(self):
    return self._heads

  def __call__(self, inputs: tf.Tensor,
               is_training: bool) -> Dict[str, tf.Tensor]:
    trunk_embedding = self.trunk(inputs, is_training=is_training)
    return {
        head: head_module(trunk_embedding, is_training=is_training)
        for head, head_module in self.heads.items()
    }

  @tf.function(input_signature=[
      tf.TensorSpec([None, SEQUENCE_LENGTH, 4], tf.float32)])
  def predict_on_batch(self, x):
    """Method for SavedModel."""
    return self(x, is_training=False)


class TargetLengthCrop1D(snt.Module):
  """Crop sequence to match the desired target length."""

  def __init__(self, target_length: int, name='target_length_crop'):
    super().__init__(name=name)
    self._target_length = target_length

  def __call__(self, inputs):
    trim = (inputs.shape[-2] - self._target_length) // 2
    if trim < 0:
      raise ValueError('inputs shorter than target length')
    if trim == 0:
      # A slice of 0:-0 would be empty, so return the inputs unchanged.
      return inputs
    return inputs[..., trim:-trim, :]


class Sequential(snt.Module):
  """snt.Sequential automatically passing is_training where it exists."""

  def __init__(self,
               layers: Optional[Union[Callable[[], Iterable[snt.Module]],
                                      Iterable[Callable[..., Any]]]] = None,
               name: Optional[Text] = None):
    super().__init__(name=name)
    if layers is None:
      self._layers = []
    else:
      # layers wrapped in a lambda function to have a common namespace.
      if hasattr(layers, '__call__'):
        with tf.name_scope(name):
          layers = layers()
      self._layers = [layer for layer in layers if layer is not None]

  def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):
    outputs = inputs
    for mod in self._layers:
      if accepts_is_training(mod):
        outputs = mod(outputs, is_training=is_training, **kwargs)
      else:
        outputs = mod(outputs, **kwargs)
    return outputs


class SoftmaxPooling1D(snt.Module):
  """Pooling operation with optional weights."""

  def __init__(self,
               pool_size: int = 2,
               per_channel: bool = False,
               w_init_scale: float = 0.0,
               name: str = 'softmax_pooling'):
    """Softmax pooling.

    Args:
      pool_size: Pooling size, same as in Max/AvgPooling.
      per_channel: If True, the logits/softmax weights will be computed for
        each channel separately. If False, same weights will be used across all
        channels.
      w_init_scale: When 0.0 is equivalent to avg pooling, and when
        ~2.0 and `per_channel=False` it's equivalent to max pooling.
      name: Module name.
    """
    super().__init__(name=name)
    self._pool_size = pool_size
    self._per_channel = per_channel
    self._w_init_scale = w_init_scale
    self._logit_linear = None

  @snt.once
  def _initialize(self, num_features):
    self._logit_linear = snt.Linear(
        output_size=num_features if self._per_channel else 1,
        with_bias=False,  # Softmax is agnostic to shifts.
        w_init=snt.initializers.Identity(self._w_init_scale))

  def __call__(self, inputs):
    _, length, num_features = inputs.shape
    self._initialize(num_features)
    inputs = tf.reshape(
        inputs,
        (-1, length // self._pool_size, self._pool_size, num_features))
    return tf.reduce_sum(
        inputs * tf.nn.softmax(self._logit_linear(inputs), axis=-2),
        axis=-2)


class Residual(snt.Module):
  """Residual block."""

  def __init__(self, module: snt.Module, name='residual'):
    super().__init__(name=name)
    self._module = module

  def __call__(self, inputs: tf.Tensor, is_training: bool, *args,
               **kwargs) -> tf.Tensor:
    return inputs + self._module(inputs, is_training, *args, **kwargs)


def gelu(x: tf.Tensor) -> tf.Tensor:
  """Applies the Gaussian error linear unit (GELU) activation function.

  Using approximation in section 2 of the original paper:
  https://arxiv.org/abs/1606.08415

  Args:
    x: Input tensor to apply gelu activation.

  Returns:
    Tensor with gelu activation applied to it.
  """
  return tf.nn.sigmoid(1.702 * x) * x


def one_hot_encode(sequence: str,
                   alphabet: str = 'ACGT',
                   neutral_alphabet: str = 'N',
                   neutral_value: Any = 0,
                   dtype=np.float64) -> np.ndarray:
  """One-hot encode sequence."""
  def to_uint8(string):
    return np.frombuffer(string.encode('ascii'), dtype=np.uint8)
  hash_table = np.zeros((np.iinfo(np.uint8).max, len(alphabet)), dtype=dtype)
  hash_table[to_uint8(alphabet)] = np.eye(len(alphabet), dtype=dtype)
  hash_table[to_uint8(neutral_alphabet)] = neutral_value
  hash_table = hash_table.astype(dtype)
  return hash_table[to_uint8(sequence)]


def exponential_linspace_int(start, end, num, divisible_by=1):
  """Exponentially increasing values of integers."""
  def _round(x):
    return int(np.round(x / divisible_by) * divisible_by)

  base = np.exp(np.log(end / start) / (num - 1))
  return [_round(start * base**i) for i in range(num)]


def accepts_is_training(module):
  return 'is_training' in list(inspect.signature(module.__call__).parameters)
+45
@@ -0,0 +1,45 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test enformer model by applying random sequence as input.
Test:
$ python enformer_test.py
"""
import random
import unittest
import enformer
import numpy as np


class TestEnformer(unittest.TestCase):

  def test_enformer(self):
    model = enformer.Enformer(channels=1536, num_transformer_layers=11)
    inputs = _get_random_input()
    outputs = model(inputs, is_training=True)
    self.assertEqual(outputs['human'].shape, (1, enformer.TARGET_LENGTH, 5313))
    self.assertEqual(outputs['mouse'].shape, (1, enformer.TARGET_LENGTH, 1643))


def _get_random_input():
  seq = ''.join(
      [random.choice('ACGT') for _ in range(enformer.SEQUENCE_LENGTH)])
  return np.expand_dims(enformer.one_hot_encode(seq), 0)


if __name__ == '__main__':
  unittest.main()
+6
@@ -0,0 +1,6 @@
dm-sonnet==2.0.0
kipoiseq==0.5.2
numpy==1.19.5
pandas==1.2.3
tensorflow==2.4.1
tensorflow-hub==0.11.0
+1 -2
@@ -4,8 +4,7 @@
`dendritic_gated_network.ipynb` implements a Dendritic Gated Network (DGN) solving a regression (using quadratic loss) or a binary classification problem (using Bernoulli log loss).
See our paper titled "A rapid and efficient learning rule for biological neural circuits" for details of the DGN model.
See our paper titled ["A rapid and efficient learning rule for biological neural circuits"](https://www.biorxiv.org/content/10.1101/2021.03.10.434756v1) for details of the DGN model.
### Instructions for running the `dendritic_gated_network.ipynb` colab/notebook.
@@ -10,7 +10,7 @@
"\n",
"This colab implements a Dendritic Gated Network (DGN) solving a regression (using quadratic loss) or a binary classification problem (using Bernoulli log loss).\n",
"\n",
"See our paper titled \"A rapid and efficient learning rule for biological neural circuits\" for details of the DGN model.\n",
"See our paper titled [\"A rapid and efficient learning rule for biological neural circuits\"](https://www.biorxiv.org/content/10.1101/2021.03.10.434756v1) for details of the DGN model.\n",
"\n",
"\n",
"Some implementation details:\n",