Mirror of https://github.com/google-deepmind/deepmind-research.git
Update README links and add enformer-training.ipynb.
PiperOrigin-RevId: 365241561
Committed by Louise Deason
parent c8fa651499, commit af3aa09cfe
enformer/README.md: +26 -23
@@ -9,15 +9,9 @@ cite the following publication:
 "Effective gene expression prediction from sequence by integrating long-range
 interactions"
 
-Žiga Avsec1, Vikram Agarwal2,4, Daniel Visentin1,4, Joseph R. Ledsam1,3,
-Agnieszka Grabska-Barwinska1, Kyle R. Taylor1, Yannis Assael1, John Jumper1,
-Pushmeet Kohli1, David R. Kelley2*
-
-1 DeepMind, London, UK
-2 Calico Life Sciences, South San Francisco, CA, USA
-3 Google, Tokyo, Japan
-4 These authors contributed equally.
-* correspondence: avsec@google.com, pushmeet@google.com, drk@calicolabs.com
+Žiga Avsec, Vikram Agarwal, Daniel Visentin, Joseph R. Ledsam,
+Agnieszka Grabska-Barwinska, Kyle R. Taylor, Yannis Assael, John Jumper,
+Pushmeet Kohli, David R. Kelley
 
 ## Setup
 
@@ -46,7 +40,7 @@ python -m enformer_test
 We precomputed variant effect scores for all frequent variants (MAF>0.5%, in any
 population) present in the 1000 genomes project. Variant scores in HDF5 file
 format per chromosome for HG19 reference genome can be found
-[here](TODO).
+[here](https://console.cloud.google.com/storage/browser/dm-enformer/variant-scores/1000-genomes/enformer).
 The HDF5 file has the same format as the output of
 [this](https://github.com/calico/basenji/blob/738321c85f8925ae6ac318a6cd4901a42ea6bc3f/bin/basenji_sad.py#L264)
 script and contains the following arrays:
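For a quick look at one of these per-chromosome HDF5 files, here is a minimal h5py sketch; the local file name is a placeholder, and the array names printed are whatever the file actually contains (the README lists them below).

```python
import h5py

# Placeholder path: download one per-chromosome file from the bucket linked above.
with h5py.File('enformer-1000G-chr1.h5', 'r') as f:
  for name, dataset in f.items():
    # Print each stored array with its shape and dtype.
    print(name, dataset.shape, dataset.dtype)
```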
@@ -66,8 +60,8 @@ script and contains the following arrays:
 model(reference_sequence))`
 
 Furthermore, we provide the top 20 principal components of variant-effect scores
-in the PC20
-folder stored as a tabix-indexed TSV file per chromosome (HG19 reference
+in the [PC20 folder](https://console.cloud.google.com/storage/browser/dm-enformer/variant-scores/1000-genomes/enformer/PC20)
+stored as a tabix-indexed TSV file per chromosome (HG19 reference
 genome). The format of these files has the following columns:
 
 * #CHROM - chromosome (chr1)
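Because these are bgzip-compressed, tabix-indexed TSVs, they can be read whole with pandas (or queried by region with tabix tooling). A hedged sketch, where the file name and any columns beyond `#CHROM` are assumptions:

```python
import pandas as pd

# Placeholder file name: one PC20 file per chromosome, bgzip-compressed TSV.
scores = pd.read_csv('enformer-pc20-chr1.tsv.gz', sep='\t')
print(scores.columns.tolist())  # first column is '#CHROM' per the README
print(scores.head())
```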
@@ -92,14 +86,14 @@ zeros.
 import tensorflow as tf
 import tensorflow_hub as hub
 
-enformer = hub.Module("https://tfhub.dev/deepmind/enformer/...")
+enformer = hub.Module('https://tfhub.dev/deepmind/enformer/1')
 
 SEQ_LENGTH = 393_216
 
 # Numpy array [batch_size, SEQ_LENGTH, 4] one hot encoded in order 'ACGT'. The
 # `one_hot_encode` function is available in `enformer.py` and outputs can be
 # stacked to form a batch.
-inputs = …
+inputs = tf.zeros((1, SEQ_LENGTH, 4), dtype=tf.float32)
 predictions = enformer.predict_on_batch(inputs)
 predictions['human'].shape  # [batch_size, 896, 5313]
 predictions['mouse'].shape  # [batch_size, 896, 1643]
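For reference, a minimal NumPy sketch of building a real one-hot encoded batch instead of the zero tensor above. This is an illustration only, not the `one_hot_encode` helper from `enformer.py`, whose exact signature may differ.

```python
import numpy as np

def one_hot_dna(sequence: str) -> np.ndarray:
  """Illustrative encoder: 'ACGT' order, unknown bases become all-zero rows."""
  mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
  encoded = np.zeros((len(sequence), 4), dtype=np.float32)
  for i, base in enumerate(sequence.upper()):
    if base in mapping:
      encoded[i, mapping[base]] = 1.0
  return encoded

# Stack per-sequence encodings into a batch of shape [batch_size, SEQ_LENGTH, 4].
inputs = np.stack([one_hot_dna('ACGT' * 98_304)])  # 4 * 98,304 = 393,216 bases
```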
@@ -121,25 +115,34 @@ df_targets = pd.read_csv(targets_txt, sep='\t')
 df_targets.shape  # (5313, 8) with rows matching the output shape above.
 ```
 
-## Modeling Code
+## Training Code
 
 The model is implemented using [Sonnet](https://github.com/deepmind/sonnet). The
-full sonnet module is defined in `enformer.py` called Enformer. See the Sonnet
-documentation on distributed training
-[here](https://github.com/deepmind/sonnet#distributed-training).
+full Sonnet module, called `Enformer`, is defined in `enformer.py`. See
+[enformer-training.ipynb](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-training.ipynb)
+for how to train the model on Basenji2 data.
 
 ## Colab
 
-Further examples are given in the notebook `enformer-example.ipynb`
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-example.ipynb).
+Further usage and training examples are given in the following colab notebooks:
+
+### `enformer-usage.ipynb` [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-usage.ipynb)
+
+This shows how to:
 
-* Make predictions with Enformer and reproduce Fig. 1d
-* Compute contribution scores and reproduce parts of Fig. 2a
-* Predict the effect of a genetic variant and reproduce parts of Fig. 3g
+* **Make predictions** with pre-trained Enformer and reproduce Fig. 1d
+* **Compute contribution scores** and reproduce parts of Fig. 2a
+* **Predict the effect of genetic variants** and reproduce parts of Fig. 3g
+* Score multiple variants in a VCF
+
+### `enformer-training.ipynb` [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/enformer/enformer-training.ipynb)
+
+This colab shows how to:
+
+* Setup training data by directly accessing the Basenji2 data on GCS
+* Train the model for a few steps on both human and mouse genomes
+* Evaluate the model on human and mouse genomes
 
 ## Disclaimer
 
 This is not an official Google product.
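The training notebook is the authoritative reference; purely to illustrate the Sonnet training pattern it points to, here is a hedged sketch of a single gradient step. The `is_training` flag, the Poisson loss, and the batch layout are assumptions based on the paper and this repository's code, and may not match the notebook exactly.

```python
import sonnet as snt
import tensorflow as tf

import enformer  # assumption: this repository's enformer/ directory is on the path

model = enformer.Enformer(channels=1536, num_transformer_layers=11, num_heads=8)
optimizer = snt.optimizers.Adam(learning_rate=5e-4)

@tf.function
def train_step(batch, head='human'):
  """One optimisation step on a single organism head."""
  with tf.GradientTape() as tape:
    predictions = model(batch['sequence'], is_training=True)[head]
    loss = tf.reduce_mean(tf.keras.losses.poisson(batch['target'], predictions))
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply(gradients, model.trainable_variables)
  return loss
```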
File diffs for three additional files were suppressed because one or more lines are too long.
enformer/enformer.py: +24 -6
@@ -45,24 +45,31 @@ class Enformer(snt.Module):
   def __init__(self,
                channels: int = 1536,
                num_transformer_layers: int = 11,
                num_heads: int = 8,
+               pooling_type: str = 'attention',
                name: str = 'enformer'):
     """Enformer model.
 
     Args:
-      channels: Number of convolutional filters.
+      channels: Number of convolutional filters and the overall 'width' of the
+        model.
       num_transformer_layers: Number of transformer layers.
       num_heads: Number of attention heads.
+      pooling_type: Which pooling function to use. Options: 'attention' or 'max'.
       name: Name of sonnet module.
     """
     super().__init__(name=name)
     # pylint: disable=g-complex-comprehension,g-long-lambda,cell-var-from-loop
     heads_channels = {'human': 5313, 'mouse': 1643}
     dropout_rate = 0.4
+    assert channels % num_heads == 0, ('channels needs to be divisible '
+                                       f'by {num_heads}')
     whole_attention_kwargs = {
         'attention_dropout_rate': 0.05,
         'initializer': None,
         'key_size': 64,
-        'num_heads': 8,
-        'num_relative_position_features': 192,
+        'num_heads': num_heads,
+        'num_relative_position_features': channels // num_heads,
         'positional_dropout_rate': 0.01,
         'relative_position_functions': [
             'positional_features_exponential',
@@ -71,7 +78,7 @@ class Enformer(snt.Module):
         ],
         'relative_positions': True,
         'scaling': True,
-        'value_size': 192,
+        'value_size': channels // num_heads,
         'zero_initialize': True
     }
 
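To make the new constructor arguments concrete, a hedged instantiation sketch; it assumes `enformer.py` from this directory is importable as `enformer`, with the defaults shown in the diff otherwise unchanged.

```python
import enformer  # assumption: this repository's enformer/ directory is on the path

# channels must be divisible by num_heads, otherwise the new assert fires.
model = enformer.Enformer(
    channels=1536,
    num_transformer_layers=11,
    num_heads=8,
    pooling_type='max')  # 'attention' (the default) uses SoftmaxPooling1D instead
# With these values, value_size and num_relative_position_features are both
# channels // num_heads = 192 per attention head.
```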
@@ -91,7 +98,7 @@ class Enformer(snt.Module):
     stem = Sequential(lambda: [
         snt.Conv1D(channels // 2, 15),
         Residual(conv_block(channels // 2, 1, name='pointwise_conv_block')),
-        SoftmaxPooling1D(pool_size=2, per_channel=True, w_init_scale=2.0)
+        pooling_module(pooling_type, pool_size=2),
     ], name='stem')
 
     filter_list = exponential_linspace_int(start=channels // 2, end=channels,
@@ -100,7 +107,7 @@ class Enformer(snt.Module):
         Sequential(lambda: [
             conv_block(num_filters, 5),
             Residual(conv_block(num_filters, 1, name='pointwise_conv_block')),
-            SoftmaxPooling1D(pool_size=2, per_channel=True, w_init_scale=2.0)
+            pooling_module(pooling_type, pool_size=2),
             ],
             name=f'conv_tower_block_{i}')
         for i, num_filters in enumerate(filter_list)], name='conv_tower')
@@ -216,6 +223,17 @@ class Sequential(snt.Module):
     return outputs
 
 
+def pooling_module(kind, pool_size):
+  """Pooling module wrapper."""
+  if kind == 'attention':
+    return SoftmaxPooling1D(pool_size=pool_size, per_channel=True,
+                            w_init_scale=2.0)
+  elif kind == 'max':
+    return tf.keras.layers.MaxPool1D(pool_size=pool_size, padding='same')
+  else:
+    raise ValueError(f'Invalid pooling kind: {kind}.')
+
+
 class SoftmaxPooling1D(snt.Module):
   """Pooling operation with optional weights."""
 
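A small usage sketch of the new `pooling_module` helper, using its self-contained 'max' branch; the import path is an assumption (the function is top-level in `enformer.py` per the diff above).

```python
import tensorflow as tf
from enformer import pooling_module  # assumption: importable from this module

pool = pooling_module('max', pool_size=2)  # returns tf.keras.layers.MaxPool1D
x = tf.random.normal([1, 128, 768])        # [batch, length, channels]
y = pool(x)                                # pooling halves the length dimension
print(y.shape)                             # (1, 64, 768)
```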