Update README links and add enformer-training.ipynb.

PiperOrigin-RevId: 365241561
commit af3aa09cfe
parent c8fa651499
Author: Ziga Avsec
Date:   2021-03-26 15:37:43 +00:00
Committed by: Louise Deason
5 changed files with 2781 additions and 1847 deletions
+26 -23
@@ -9,15 +9,9 @@ cite the following publication:
 "Effective gene expression prediction from sequence by integrating long-range
 interactions"
-Žiga Avsec1, Vikram Agarwal2,4, Daniel Visentin1,4, Joseph R. Ledsam1,3,
-Agnieszka Grabska-Barwinska1, Kyle R. Taylor1, Yannis Assael1, John Jumper1,
-Pushmeet Kohli1, David R. Kelley2*
-
-1 DeepMind, London, UK
-2 Calico Life Sciences, South San Francisco, CA, USA
-3 Google, Tokyo, Japan
-4 These authors contributed equally.
-* correspondence: avsec@google.com, pushmeet@google.com, drk@calicolabs.com
+Žiga Avsec, Vikram Agarwal, Daniel Visentin, Joseph R. Ledsam,
+Agnieszka Grabska-Barwinska, Kyle R. Taylor, Yannis Assael, John Jumper,
+Pushmeet Kohli, David R. Kelley

 ## Setup
@@ -46,7 +40,7 @@ python -m enformer_test
 We precomputed variant effect scores for all frequent variants (MAF>0.5%, in any
 population) present in the 1000 genomes project. Variant scores in HDF5 file
 format per chromosome for HG19 reference genome can be found
-[here](TODO).
+[here](https://console.cloud.google.com/storage/browser/dm-enformer/variant-scores/1000-genomes/enformer).
 The HDF5 file has the same format as the output of
 [this](https://github.com/calico/basenji/blob/738321c85f8925ae6ac318a6cd4901a42ea6bc3f/bin/basenji_sad.py#L264)
 script and contains the following arrays:
@@ -66,8 +60,8 @@ script and contains the following arrays:
     model(reference_sequence))`

 Furthermore, we provide the top 20 principal components of variant-effect scores
-in the PC20
-folder stored as a tabix-indexed TSV file per chromosome (HG19 reference
+in the [PC20 folder](https://console.cloud.google.com/storage/browser/dm-enformer/variant-scores/1000-genomes/enformer/PC20)
+stored as a tabix-indexed TSV file per chromosome (HG19 reference
 genome). The format of these files has the following columns:

 * #CHROM - chromosome (chr1)
@@ -92,14 +86,14 @@ zeros.
 import tensorflow as tf
 import tensorflow_hub as hub

-enformer = hub.Module("https://tfhub.dev/deepmind/enformer/...")
+enformer = hub.Module('https://tfhub.dev/deepmind/enformer/1')
 SEQ_LENGTH = 393_216

 # Numpy array [batch_size, SEQ_LENGTH, 4] one hot encoded in order 'ACGT'. The
 # `one_hot_encode` function is available in `enformer.py` and outputs can be
 # stacked to form a batch.
-inputs =
+inputs = tf.zeros((1, SEQ_LENGTH, 4), dtype=tf.float32)
 predictions = enformer.predict_on_batch(inputs)
 predictions['human'].shape  # [batch_size, 896, 5313]
 predictions['mouse'].shape  # [batch_size, 896, 1643]
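The comment above refers to a `one_hot_encode` helper in `enformer.py`. As a rough sketch of what such an encoder does (the actual implementation in the repo may differ), 'A', 'C', 'G', 'T' map to the rows of a 4x4 identity matrix, and any other base (e.g. 'N') encodes as all zeros:

```python
import numpy as np

def one_hot_encode(sequence: str) -> np.ndarray:
  """One-hot encodes a DNA string in 'ACGT' order; unknown bases -> zeros."""
  # Lookup table indexed by ASCII code; rows default to all-zero.
  lookup = np.zeros((256, 4), dtype=np.float32)
  for i, base in enumerate('ACGT'):
    lookup[ord(base), i] = 1.0
  indices = np.frombuffer(sequence.encode('ascii'), dtype=np.uint8)
  return lookup[indices]  # shape [len(sequence), 4]

encoded = one_hot_encode('ACGTN')
# Batching, as the comment suggests: stack per-sequence encodings.
batch = np.stack([one_hot_encode('ACGT'), one_hot_encode('TTTT')])
```

Stacking only works when all sequences share the same length, which holds here since Enformer expects fixed `SEQ_LENGTH` inputs.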
@@ -121,25 +115,34 @@ df_targets = pd.read_csv(targets_txt, sep='\t')
 df_targets.shape  # (5313, 8), with rows matching the output shape above.
 ```

-## Modeling Code
+## Training Code

 The model is implemented using [Sonnet](https://github.com/deepmind/sonnet). The
-full sonnet module is defined in `enformer.py` called Enformer. See the Sonnet
-documentation on distributed training
-[here](https://github.com/deepmind/sonnet#distributed-training).
+full Sonnet module, called `Enformer`, is defined in `enformer.py`. See
+[enformer-training.ipynb](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-training.ipynb)
+for how to train the model on the Basenji2 data.

 ## Colab

-Further examples are given in the notebooks `enformer-example.ipynb`
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-example.ipynb).
+Further usage and training examples are given in the following colab notebooks:
+
+### `enformer-usage.ipynb` [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-usage.ipynb)
 This shows how to:
-* Make predictions with Enformer and reproduce Fig. 1d
-* Compute contribution scores and reproduce parts of Fig. 2a
-* Predict the effect of a genetic variant and reproduce parts of Fig. 3g
+* **Make predictions** with pre-trained Enformer and reproduce Fig. 1d
+* **Compute contribution scores** and reproduce parts of Fig. 2a
+* **Predict the effect of genetic variants** and reproduce parts of Fig. 3g
 * Score multiple variants in a VCF
+
+### `enformer-training.ipynb` [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-training.ipynb)
+This colab shows how to:
+* Set up training data by directly accessing the Basenji2 data on GCS
+* Train the model for a few steps on both human and mouse genomes
+* Evaluate the model on human and mouse genomes

 ## Disclaimer

 This is not an official Google product.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+24 -6
@@ -45,24 +45,31 @@ class Enformer(snt.Module):

   def __init__(self,
                channels: int = 1536,
                num_transformer_layers: int = 11,
+               num_heads: int = 8,
+               pooling_type: str = 'attention',
                name: str = 'enformer'):
     """Enformer model.

     Args:
-      channels: Number of convolutional filters.
+      channels: Number of convolutional filters and the overall 'width' of the
+        model.
       num_transformer_layers: Number of transformer layers.
+      num_heads: Number of attention heads.
+      pooling_type: Which pooling function to use. Options: 'attention' or 'max'.
       name: Name of sonnet module.
     """
     super().__init__(name=name)
     # pylint: disable=g-complex-comprehension,g-long-lambda,cell-var-from-loop
     heads_channels = {'human': 5313, 'mouse': 1643}
     dropout_rate = 0.4
+    assert channels % num_heads == 0, ('channels needs to be divisible '
+                                       f'by {num_heads}')
     whole_attention_kwargs = {
         'attention_dropout_rate': 0.05,
         'initializer': None,
         'key_size': 64,
-        'num_heads': 8,
+        'num_heads': num_heads,
-        'num_relative_position_features': 192,
+        'num_relative_position_features': channels // num_heads,
         'positional_dropout_rate': 0.01,
         'relative_position_functions': [
             'positional_features_exponential',
@@ -71,7 +78,7 @@ class Enformer(snt.Module):
         ],
         'relative_positions': True,
         'scaling': True,
-        'value_size': 192,
+        'value_size': channels // num_heads,
         'zero_initialize': True
     }
@@ -91,7 +98,7 @@ class Enformer(snt.Module):
     stem = Sequential(lambda: [
         snt.Conv1D(channels // 2, 15),
         Residual(conv_block(channels // 2, 1, name='pointwise_conv_block')),
-        SoftmaxPooling1D(pool_size=2, per_channel=True, w_init_scale=2.0)
+        pooling_module(pooling_type, pool_size=2),
     ], name='stem')

     filter_list = exponential_linspace_int(start=channels // 2, end=channels,
@@ -100,7 +107,7 @@ class Enformer(snt.Module):
         Sequential(lambda: [
             conv_block(num_filters, 5),
             Residual(conv_block(num_filters, 1, name='pointwise_conv_block')),
-            SoftmaxPooling1D(pool_size=2, per_channel=True, w_init_scale=2.0)
+            pooling_module(pooling_type, pool_size=2),
         ],
         name=f'conv_tower_block_{i}')
         for i, num_filters in enumerate(filter_list)], name='conv_tower')
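`exponential_linspace_int` is defined elsewhere in `enformer.py` and is not shown in this diff; it yields integer filter counts growing geometrically from `start` to `end` for the conv tower. A plausible sketch, assuming `num` steps and rounding to a `divisible_by` granularity (both parameter names are guesses from typical usage, not taken from this diff):

```python
import numpy as np

def exponential_linspace_int(start, end, num, divisible_by=1):
  """Integers spaced exponentially between start and end (inclusive)."""
  def _round(x):
    # Snap to the nearest multiple of divisible_by.
    return int(np.round(x / divisible_by) * divisible_by)
  # Constant ratio between consecutive values.
  base = np.exp(np.log(end / start) / (num - 1))
  return [_round(start * base**i) for i in range(num)]

# E.g. a conv tower widening from channels // 2 = 768 to channels = 1536:
filters = exponential_linspace_int(start=768, end=1536, num=6, divisible_by=128)
```

Rounding to a multiple such as 128 keeps every layer's width friendly to hardware tiling while preserving the roughly geometric growth.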
@@ -216,6 +223,17 @@ class Sequential(snt.Module):
     return outputs


+def pooling_module(kind, pool_size):
+  """Pooling module wrapper."""
+  if kind == 'attention':
+    return SoftmaxPooling1D(pool_size=pool_size, per_channel=True,
+                            w_init_scale=2.0)
+  elif kind == 'max':
+    return tf.keras.layers.MaxPool1D(pool_size=pool_size, padding='same')
+  else:
+    raise ValueError(f'Invalid pooling kind: {kind}.')
+
+
 class SoftmaxPooling1D(snt.Module):
   """Pooling operation with optional weights."""