Mirror of https://github.com/google-deepmind/deepmind-research.git (synced 2026-02-06 03:32:18 +08:00)
Explicitly replace "import tensorflow" with "tensorflow.compat.v1"
PiperOrigin-RevId: 287568660
Committed by: Diego de Las Casas
parent 5b22f16532
commit 8b47d2a576
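The substance of the change, sketched on a minimal example (the surrounding code is illustrative, not a file from this commit):

```python
# Before: `import tensorflow as tf` resolves to the TF 2.x API surface once
# TensorFlow 2 is installed.
# After: the explicit compat import pins the code to the TF 1.x API, matching
# the TensorFlow 1.14 requirement stated in the README below.
import tensorflow.compat.v1 as tf

sess = tf.Session()  # available under compat.v1, but not in the plain TF 2.x namespace
```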
151  alphafold_casp13/README.md  Normal file
@@ -0,0 +1,151 @@
# AlphaFold

This package provides an implementation of the contact prediction network,
associated model weights and CASP13 dataset as published in Nature.

Any publication that discloses findings arising from using this source code must
cite *AlphaFold: Protein structure prediction using potentials from deep
learning* by Andrew W. Senior, Richard Evans, John Jumper, James Kirkpatrick,
Laurent Sifre, Tim Green, Chongli Qin, Augustin Žídek, Alexander W. R. Nelson,
Alex Bridgland, Hugo Penedones, Stig Petersen, Karen Simonyan, Steve Crossan,
Pushmeet Kohli, David T. Jones, David Silver, Koray Kavukcuoglu, Demis Hassabis.

## Setup

### Dependencies

* Python 3.6+.
* [Abseil 0.8.0+](https://github.com/abseil/abseil-py)
* [Numpy 1.16+](https://numpy.org)
* [Six 1.12+](https://pypi.org/project/six/)
* [Sonnet 1.35+](https://github.com/deepmind/sonnet)
* [TensorFlow 1.14](https://tensorflow.org). Not compatible with TensorFlow
  2.0+.
* [TensorFlow Probability 0.7.0](https://www.tensorflow.org/probability)

You can set up a Python virtual environment with these dependencies inside the
forked `deepmind_research` repository using:

```shell
python3 -m venv alphafold_venv
source alphafold_venv/bin/activate
pip install -r alphafold_casp13/requirements.txt
```
### Input data

The dataset can be downloaded from
[Google Cloud Storage](https://console.cloud.google.com/storage/browser/alphafold_casp13_data).

Download it e.g. using `wget`:

```shell
wget https://storage.googleapis.com/alphafold_casp13_data/casp13_data.zip
```

The zip file contains 1 directory for each CASP13 target and a `LICENSE.md`
file. Each target directory contains the following files:

1. `TARGET.tfrec` file. This is a
   [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file
   with serialized tf.train.Example protocol buffers that contain the features
   needed to run the model (see the reading sketch after this list).
1. `contacts/TARGET.pickle` file(s) with the predicted distogram.
1. `contacts/TARGET.rr` file(s) with the contact map derived from the predicted
   distogram. The RR format is described on the
   [CASP website](http://predictioncenter.org/casp13/index.cgi?page=format#RR).
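As a quick check of what one of these files contains, the following sketch
reads a `.tfrec` file and prints the feature keys of its first serialized
example. The target name `T0953` is purely illustrative, and the snippet
assumes the TensorFlow 1.14 environment from the Setup section:

```python
import tensorflow.compat.v1 as tf

path = 'T0953/T0953.tfrec'  # hypothetical target directory from the unpacked zip

for raw_record in tf.python_io.tf_record_iterator(path):
  example = tf.train.Example.FromString(raw_record)
  print(sorted(example.features.feature.keys()))
  break
```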
Note that for **T0999** the target was manually split based on hits in HHSearch
into 5 sub-targets, hence there are 5 distograms
(`contacts/T0999s{1,2,3,4,5}.pickle`) and 5 RR files
(`contacts/T0999s{1,2,3,4,5}.rr`).

The `contacts/` folder is not needed to run the model; these files are included
only for convenience so that you don't need to run the inference for CASP13
targets to get the contact map.
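If you want to look at a released distogram directly, the pickle files can be
inspected without assuming anything about their internal layout (the exact
fields are defined by this package's `distogram_io` module, and the path below
is again only an example):

```python
import pickle

with open('T0953/contacts/T0953.pickle', 'rb') as f:  # hypothetical path
  distogram = pickle.load(f)

print(type(distogram))
if hasattr(distogram, 'keys'):
  print(sorted(distogram.keys()))
```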
### Model checkpoints

The model checkpoints can be downloaded from
[Google Cloud Storage](https://console.cloud.google.com/storage/browser/alphafold_casp13_data).

Download them e.g. using `wget`:

```shell
wget https://storage.googleapis.com/alphafold_casp13_data/alphafold_casp13_weights.zip
```

The zip file contains:

1. A directory `873731`. This contains the weights for the distogram model.
1. A directory `916425`. This contains the weights for the background distogram
   model.
1. A directory `941521`. This contains the weights for the torsion model.
1. `LICENSE.md`. The model checkpoints have a non-commercial license which is
   defined in this file.

Each directory with model weights contains a number of different model
configurations. Each model has a config file and associated weights. There is
only one torsion model. Each model directory also contains a stats file that is
used for feature normalization specific to that model.
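For example, assuming both archives were downloaded into the repository root as
shown above, unpacking them in place looks like this (the resulting directories
are the ones described in this README):

```shell
unzip casp13_data.zip               # one directory per CASP13 target + LICENSE.md
unzip alphafold_casp13_weights.zip  # 873731/, 916425/, 941521/ + LICENSE.md
```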
## Distogram prediction

### Running the system

You can use the `run_eval.sh` script to run the entire Distogram prediction
system. There are a few steps you need to complete first:

1. Download the input data as described above. Unpack the data in the
   directory with the code.
1. Download the model checkpoints as described above. Unpack the data.
1. In `run_eval.sh` set the following (see the sketch below):
   * `DISTOGRAM_MODEL` to the path to the directory with the distogram model.
   * `BACKGROUND_MODEL` to the path to the directory with the background
     model.
   * `TORSION_MODEL` to the path to the directory with the torsion model.
   * `TARGET` to the path to the directory with the target input data.

Then run `alphafold_casp13/run_eval.sh` from the `deepmind_research` parent
directory (you will get errors if you try running `run_eval.sh` directly from
the `alphafold_casp13` directory).
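Concretely, the variable assignments in `run_eval.sh` and the launch command
could look like the following; the bracketed values are placeholders, not real
paths, so substitute the locations of your own unpacked data and checkpoints:

```shell
# Inside alphafold_casp13/run_eval.sh -- illustrative placeholder values.
DISTOGRAM_MODEL="<path to a model configuration inside 873731/>"
BACKGROUND_MODEL="<path to a model configuration inside 916425/>"
TORSION_MODEL="<path to the model configuration inside 941521/>"
TARGET="<path to one unpacked CASP13 target directory>"

# Then, from the deepmind_research parent directory:
alphafold_casp13/run_eval.sh
```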
The contact prediction works in the following way:

1. 4 replicas (by *replica* we mean a configuration file describing the network
   architecture and a snapshot with the network weights), each with a slightly
   different model configuration, are launched to predict the distogram.
1. 4 replicas, each with a slightly different model configuration, are launched
   to predict the background distogram.
1. 1 replica is launched to predict the torsions.
1. The predictions from the different replicas are averaged together using
   `ensemble_contact_maps.py`.
1. The predictions for the 64 × 64 distogram crops are pasted together using
   `paste_contact_maps.py`.

When running `run_eval.sh` the output has the following directory structure:

* **distogram/**: Contains 4 subfolders, one for each replica. Each of these
  contains the predicted ASA, secondary structure and a pickle file with the
  distogram for each crop. It also contains an `ensemble` directory with the
  ensembled distograms.
* **background_distogram/**: Contains 4 subfolders, one for each replica. Each
  of these contains a pickle file with the background distogram for each crop.
  It also contains an `ensemble` directory with the ensembled background
  distograms.
* **torsion/**: Contains 1 subfolder as there was only a single replica. This
  folder contains the predicted ASA, secondary structure, backbone torsions
  and a pickle file with the distogram for each crop. It also contains an
  `ensemble` directory with the ensembled torsions.
* **pasted/**: Contains distograms obtained from the ensembled distograms by
  pasting. An RR contact map file is computed from this pasted distogram.
  **This is the final distogram that was used in the subsequent AlphaFold
  folding pipeline in CASP13.**
## Data splits

We used a version of [PDB](https://www.rcsb.org/) downloaded on 2018-03-15. The
train/test split can be found in the `train_domains.txt` and `test_domains.txt`
files.

Disclaimer: This is not an official Google product.
36  alphafold_casp13/asa_output.py  Normal file
@@ -0,0 +1,36 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for predicting Accessible Surface Area."""

import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers


class ASAOutputLayer(object):
  """An output layer to predict Accessible Surface Area."""

  def __init__(self, name='asa'):
    self.name = name

  def compute_asa_output(self, activations):
    """Just compute the logits and outputs given activations."""
    asa_logits = contrib_layers.linear(
        activations,
        1,
        weights_initializer=tf.random_uniform_initializer(-0.01, 0.01),
        scope='ASALogits')
    self.asa_output = tf.nn.relu(asa_logits, name='ASA_output_relu')

    return asa_logits
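A minimal usage sketch for this layer; the shapes and channel count are
illustrative, and it assumes the TF 1.14 / `tf.contrib` environment from the
README:

```python
import numpy as np
import tensorflow.compat.v1 as tf

from alphafold_casp13 import asa_output

# [batch, num_res, channels] activations; 128 channels is an arbitrary choice.
activations = tf.placeholder(tf.float32, shape=[1, None, 128])
asa_logits = asa_output.ASAOutputLayer().compute_asa_output(activations)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  logits = sess.run(asa_logits,
                    {activations: np.zeros((1, 50, 128), np.float32)})
  print(logits.shape)  # expected: (1, 50, 1)
```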
63  alphafold_casp13/config_dict.py  Normal file
@@ -0,0 +1,63 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for storing configuration flags."""

import json


class ConfigDict(dict):
  """Configuration dictionary with convenient dot element access."""

  def __init__(self, *args, **kwargs):
    super(ConfigDict, self).__init__(*args, **kwargs)
    for arg in args:
      if isinstance(arg, dict):
        for key, value in arg.items():
          self._add(key, value)
    for key, value in kwargs.items():
      self._add(key, value)

  def _add(self, key, value):
    if isinstance(value, dict):
      self[key] = ConfigDict(value)
    else:
      self[key] = value

  def __getattr__(self, attr):
    try:
      return self[attr]
    except KeyError as e:
      raise AttributeError(e)

  def __setattr__(self, key, value):
    self.__setitem__(key, value)

  def __setitem__(self, key, value):
    super(ConfigDict, self).__setitem__(key, value)
    self.__dict__.update({key: value})

  def __delattr__(self, item):
    self.__delitem__(item)

  def __delitem__(self, key):
    super(ConfigDict, self).__delitem__(key)
    del self.__dict__[key]

  def to_json(self):
    return json.dumps(self)

  @classmethod
  def from_json(cls, json_string):
    return cls(json.loads(json_string))
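A small usage sketch of `ConfigDict`; the keys and values here are
illustrative rather than a real AlphaFold configuration:

```python
from alphafold_casp13 import config_dict

cfg = config_dict.ConfigDict({'eval_config': {'max_num_examples': 0},
                              'num_bins': 64})
print(cfg.num_bins)                      # 64, via dot access
print(cfg.eval_config.max_num_examples)  # nested dicts become ConfigDicts
cfg.num_bins = 32                        # attribute assignment writes through to the dict
restored = config_dict.ConfigDict.from_json(cfg.to_json())
print(restored['num_bins'])              # 32
```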
394  alphafold_casp13/contacts.py  Normal file
@@ -0,0 +1,394 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Code to run distogram inference."""
|
||||
|
||||
import collections
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import six
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from alphafold_casp13 import config_dict
|
||||
from alphafold_casp13 import contacts_experiment
|
||||
from alphafold_casp13 import distogram_io
|
||||
from alphafold_casp13 import secstruct
|
||||
|
||||
flags.DEFINE_string('config_path', None, 'Path of the JSON config file.')
|
||||
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path for evaluation.')
|
||||
flags.DEFINE_boolean('cpu', False, 'Force onto CPU.')
|
||||
flags.DEFINE_string('output_path', None,
|
||||
'Base path where all output files will be saved to.')
|
||||
flags.DEFINE_string('eval_sstable', None,
|
||||
'Path of the SSTable to read the input tf.Examples from.')
|
||||
flags.DEFINE_string('stats_file', None,
|
||||
'Path of the statistics file to use for normalization.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
# A named tuple to store the outputs of a single prediction run.
|
||||
Prediction = collections.namedtuple(
|
||||
'Prediction', [
|
||||
'single_message', # A debugging message.
|
||||
'num_crops_local', # The number of crops used to make this prediction.
|
||||
'sequence', # The amino acid sequence.
|
||||
'filebase', # The chain name. All output files will use this name.
|
||||
'softmax_probs', # Softmax of the distogram.
|
||||
'ss', # Secondary structure prediction.
|
||||
'asa', # ASA prediction.
|
||||
'torsions', # Torsion prediction.
|
||||
])
|
||||
|
||||
|
||||
def evaluate(crop_size_x, crop_size_y, feature_normalization, checkpoint_path,
|
||||
normalization_exclusion, eval_config, network_config):
|
||||
"""Main evaluation loop."""
|
||||
experiment = contacts_experiment.Contacts(
|
||||
tfrecord=eval_config.eval_sstable,
|
||||
stats_file=eval_config.stats_file,
|
||||
network_config=network_config,
|
||||
crop_size_x=crop_size_x,
|
||||
crop_size_y=crop_size_y,
|
||||
feature_normalization=feature_normalization,
|
||||
normalization_exclusion=normalization_exclusion)
|
||||
|
||||
checkpoint = snt.get_saver(experiment.model, collections=[
|
||||
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,
|
||||
tf.compat.v1.GraphKeys.MOVING_AVERAGE_VARIABLES])
|
||||
|
||||
with tf.compat.v1.train.SingularMonitoredSession(hooks=[]) as sess:
|
||||
logging.info('Restoring from checkpoint %s', checkpoint_path)
|
||||
checkpoint.restore(sess, checkpoint_path)
|
||||
|
||||
logging.info('Writing output to %s', eval_config.output_path)
|
||||
eval_begin_time = time.time()
|
||||
_run_evaluation(sess=sess,
|
||||
experiment=experiment,
|
||||
eval_config=eval_config,
|
||||
output_dir=eval_config.output_path,
|
||||
min_range=network_config.min_range,
|
||||
max_range=network_config.max_range,
|
||||
num_bins=network_config.num_bins,
|
||||
torsion_bins=network_config.torsion_bins)
|
||||
logging.info('Finished eval %.1fs', (time.time() - eval_begin_time))
|
||||
|
||||
|
||||
def _run_evaluation(
|
||||
sess, experiment, eval_config, output_dir, min_range, max_range, num_bins,
|
||||
torsion_bins):
|
||||
"""Evaluate a contact map by aggregating crops.
|
||||
|
||||
Args:
|
||||
sess: A tf.train.Session.
|
||||
experiment: An experiment class.
|
||||
eval_config: A config dict of eval parameters.
|
||||
output_dir: Directory to save the predictions to.
|
||||
min_range: The minimum range in Angstroms to consider in distograms.
|
||||
max_range: The maximum range in Angstroms to consider in distograms, see
|
||||
num_bins below for clarification.
|
||||
num_bins: The number of bins in the distance histogram being predicted.
|
||||
We divide the min_range--(min_range + max_range) Angstrom range into this
|
||||
many bins.
|
||||
torsion_bins: The number of bins the torsion angles are discretised into.
|
||||
"""
|
||||
tf.io.gfile.makedirs(os.path.join(output_dir, 'pickle_files'))
|
||||
|
||||
logging.info('Eval config is %s\nnum_bins: %d', eval_config, num_bins)
|
||||
num_examples = 0
|
||||
num_crops = 0
|
||||
start_all_time = time.time()
|
||||
|
||||
# Either do the whole test set, or up to a specified limit.
|
||||
max_examples = experiment.num_eval_examples
|
||||
if eval_config.max_num_examples > 0:
|
||||
max_examples = min(max_examples, eval_config.max_num_examples)
|
||||
|
||||
while num_examples < max_examples:
|
||||
one_prediction = compute_one_prediction(
|
||||
num_examples, experiment, sess, eval_config, num_bins, torsion_bins)
|
||||
|
||||
single_message = one_prediction.single_message
|
||||
num_crops_local = one_prediction.num_crops_local
|
||||
sequence = one_prediction.sequence
|
||||
filebase = one_prediction.filebase
|
||||
softmax_probs = one_prediction.softmax_probs
|
||||
ss = one_prediction.ss
|
||||
asa = one_prediction.asa
|
||||
torsions = one_prediction.torsions
|
||||
|
||||
num_examples += 1
|
||||
num_crops += num_crops_local
|
||||
|
||||
# Save the output files.
|
||||
filename = os.path.join(output_dir,
|
||||
'pickle_files', '%s.pickle' % filebase)
|
||||
distogram_io.save_distance_histogram(
|
||||
filename, softmax_probs, filebase, sequence,
|
||||
min_range=min_range, max_range=max_range, num_bins=num_bins)
|
||||
|
||||
if experiment.model.torsion_multiplier > 0:
|
||||
torsions_dir = os.path.join(output_dir, 'torsions')
|
||||
tf.io.gfile.makedirs(torsions_dir)
|
||||
distogram_io.save_torsions(torsions_dir, filebase, sequence, torsions)
|
||||
|
||||
if experiment.model.secstruct_multiplier > 0:
|
||||
ss_dir = os.path.join(output_dir, 'secstruct')
|
||||
tf.io.gfile.makedirs(ss_dir)
|
||||
secstruct.save_secstructs(ss_dir, filebase, None, sequence, ss)
|
||||
|
||||
if experiment.model.asa_multiplier > 0:
|
||||
asa_dir = os.path.join(output_dir, 'asa')
|
||||
tf.io.gfile.makedirs(asa_dir)
|
||||
secstruct.save_secstructs(asa_dir, filebase, None, sequence,
|
||||
np.expand_dims(asa, 1), label='Deepmind 2D ASA')
|
||||
|
||||
time_spent = time.time() - start_all_time
|
||||
logging.info(
|
||||
'Evaluate %d examples, %d crops %.1f crops/ex. '
|
||||
'Took %.1fs, %.3f s/example %.3f crops/s\n%s',
|
||||
num_examples, num_crops, num_crops / float(num_examples), time_spent,
|
||||
time_spent / num_examples, num_crops / time_spent, single_message)
|
||||
|
||||
logging.info('Tested on %d', num_examples)
|
||||
|
||||
|
||||
def compute_one_prediction(
|
||||
num_examples, experiment, sess, eval_config, num_bins, torsion_bins):
|
||||
"""Find the contact map for a single domain."""
|
||||
num_crops_local = 0
|
||||
debug_steps = 0
|
||||
start = time.time()
|
||||
output_fetches = {'probs': experiment.eval_probs}
|
||||
output_fetches['softmax_probs'] = experiment.eval_probs_softmax
|
||||
# Add the auxiliary outputs if present.
|
||||
experiment.model.update_crop_fetches(output_fetches)
|
||||
# Get data.
|
||||
batch = experiment.get_one_example(sess)
|
||||
length = batch['sequence_lengths'][0]
|
||||
batch_size = batch['sequence_lengths'].shape[0]
|
||||
domain = batch['domain_name'][0][0].decode('utf-8')
|
||||
chain = batch['chain_name'][0][0].decode('utf-8')
|
||||
filebase = domain or chain
|
||||
sequence = six.ensure_str(batch['sequences'][0][0])
|
||||
logging.info('SepWorking on %d %s %s %d', num_examples, domain, chain, length)
|
||||
inputs_1d = batch['inputs_1d']
|
||||
if 'residue_index' in batch:
|
||||
logging.info('Getting residue_index from features')
|
||||
residue_index = np.squeeze(
|
||||
batch['residue_index'], axis=2).astype(np.int32)
|
||||
else:
|
||||
logging.info('Generating residue_index')
|
||||
residue_index = np.tile(np.expand_dims(
|
||||
np.arange(length, dtype=np.int32), 0), [batch_size, 1])
|
||||
assert batch_size == 1
|
||||
num_examples += batch_size
|
||||
# Crops.
|
||||
prob_accum = np.zeros((length, length, 2))
|
||||
ss_accum = np.zeros((length, 8))
|
||||
torsions_accum = np.zeros((length, torsion_bins**2))
|
||||
asa_accum = np.zeros((length,))
|
||||
weights_1d_accum = np.zeros((length,))
|
||||
softmax_prob_accum = np.zeros((length, length, num_bins), dtype=np.float32)
|
||||
|
||||
crop_size_x = experiment.crop_size_x
|
||||
crop_step_x = crop_size_x // eval_config.crop_shingle_x
|
||||
crop_size_y = experiment.crop_size_y
|
||||
crop_step_y = crop_size_y // eval_config.crop_shingle_y
|
||||
|
||||
prob_weights = 1
|
||||
if eval_config.pyramid_weights > 0:
|
||||
sx = np.expand_dims(np.linspace(1.0 / crop_size_x, 1, crop_size_x), 1)
|
||||
sy = np.expand_dims(np.linspace(1.0 / crop_size_y, 1, crop_size_y), 0)
|
||||
prob_weights = np.minimum(np.minimum(sx, np.flipud(sx)),
|
||||
np.minimum(sy, np.fliplr(sy)))
|
||||
prob_weights /= np.max(prob_weights)
|
||||
prob_weights = np.minimum(prob_weights, eval_config.pyramid_weights)
|
||||
logging.log_first_n(logging.INFO, 'Crop: %dx%d step %d,%d pyr %.2f',
|
||||
debug_steps,
|
||||
crop_size_x, crop_size_y,
|
||||
crop_step_x, crop_step_y, eval_config.pyramid_weights)
|
||||
# Accumulate all crops, starting and ending half off the square.
|
||||
for i in range(-crop_size_x // 2, length - crop_size_x // 2, crop_step_x):
|
||||
for j in range(-crop_size_y // 2, length - crop_size_y // 2, crop_step_y):
|
||||
# The ideal crop.
|
||||
patch = compute_one_patch(
|
||||
sess, experiment, output_fetches, inputs_1d, residue_index,
|
||||
prob_weights, batch, length, i, j, crop_size_x, crop_size_y)
|
||||
# Assemble the crops into a final complete prediction.
|
||||
ic = max(0, i)
|
||||
jc = max(0, j)
|
||||
ic_to = ic + patch['prob'].shape[1]
|
||||
jc_to = jc + patch['prob'].shape[0]
|
||||
prob_accum[jc:jc_to, ic:ic_to, 0] += patch['prob'] * patch['weight']
|
||||
prob_accum[jc:jc_to, ic:ic_to, 1] += patch['weight']
|
||||
softmax_prob_accum[jc:jc_to, ic:ic_to, :] += (
|
||||
patch['softmax'] * np.expand_dims(patch['weight'], 2))
|
||||
weights_1d_accum[jc:jc_to] += 1
|
||||
weights_1d_accum[ic:ic_to] += 1
|
||||
if 'asa_x' in patch:
|
||||
asa_accum[ic:ic + patch['asa_x'].shape[0]] += np.squeeze(
|
||||
patch['asa_x'], axis=1)
|
||||
asa_accum[jc:jc + patch['asa_y'].shape[0]] += np.squeeze(
|
||||
patch['asa_y'], axis=1)
|
||||
if 'ss_x' in patch:
|
||||
ss_accum[ic:ic + patch['ss_x'].shape[0]] += patch['ss_x']
|
||||
ss_accum[jc:jc + patch['ss_y'].shape[0]] += patch['ss_y']
|
||||
if 'torsions_x' in patch:
|
||||
torsions_accum[
|
||||
ic:ic + patch['torsions_x'].shape[0]] += patch['torsions_x']
|
||||
torsions_accum[
|
||||
jc:jc + patch['torsions_y'].shape[0]] += patch['torsions_y']
|
||||
num_crops_local += 1
|
||||
single_message = (
|
||||
'Constructed %s len %d from %d chunks [%d, %d x %d, %d] '
|
||||
'in %5.1fs' % (
|
||||
filebase, length, num_crops_local,
|
||||
crop_size_x, crop_step_x, crop_size_y, crop_step_y,
|
||||
time.time() - start))
|
||||
logging.info(single_message)
|
||||
logging.info('prob_accum[:, :, 1]: %s', prob_accum[:, :, 1])
|
||||
assert (prob_accum[:, :, 1] > 0.0).all()
|
||||
probs = prob_accum[:, :, 0] / prob_accum[:, :, 1]
|
||||
softmax_probs = softmax_prob_accum[:, :, :] / prob_accum[:, :, 1:2]
|
||||
|
||||
asa_accum /= weights_1d_accum
|
||||
ss_accum /= np.expand_dims(weights_1d_accum, 1)
|
||||
torsions_accum /= np.expand_dims(weights_1d_accum, 1)
|
||||
|
||||
# The probs are symmetrical.
|
||||
probs = (probs + probs.transpose()) / 2
|
||||
if num_bins > 1:
|
||||
softmax_probs = (softmax_probs + np.transpose(
|
||||
softmax_probs, axes=[1, 0, 2])) / 2
|
||||
return Prediction(
|
||||
single_message=single_message,
|
||||
num_crops_local=num_crops_local,
|
||||
sequence=sequence,
|
||||
filebase=filebase,
|
||||
softmax_probs=softmax_probs,
|
||||
ss=ss_accum,
|
||||
asa=asa_accum,
|
||||
torsions=torsions_accum)
|
||||
|
||||
|
||||
def compute_one_patch(sess, experiment, output_fetches, inputs_1d,
|
||||
residue_index, prob_weights, batch, length, i, j,
|
||||
crop_size_x, crop_size_y):
|
||||
"""Compute the output predictions for a single crop."""
|
||||
# Note that these are allowed to go off the end of the protein.
|
||||
end_x = i + crop_size_x
|
||||
end_y = j + crop_size_y
|
||||
crop_limits = np.array([[i, end_x, j, end_y]], dtype=np.int32)
|
||||
ic = max(0, i)
|
||||
jc = max(0, j)
|
||||
end_x_cropped = min(length, end_x)
|
||||
end_y_cropped = min(length, end_y)
|
||||
prepad_x = max(0, -i)
|
||||
prepad_y = max(0, -j)
|
||||
postpad_x = end_x - end_x_cropped
|
||||
postpad_y = end_y - end_y_cropped
|
||||
|
||||
# Precrop the 2D features:
|
||||
inputs_2d = np.pad(batch['inputs_2d'][
|
||||
:, jc:end_y, ic:end_x, :],
|
||||
[[0, 0],
|
||||
[prepad_y, postpad_y],
|
||||
[prepad_x, postpad_x],
|
||||
[0, 0]], mode='constant')
|
||||
assert inputs_2d.shape[1] == crop_size_y
|
||||
assert inputs_2d.shape[2] == crop_size_x
|
||||
|
||||
# Generate the corresponding crop, but it might be truncated.
|
||||
cxx = batch['inputs_2d'][:, ic:end_x, ic:end_x, :]
|
||||
cyy = batch['inputs_2d'][:, jc:end_y, jc:end_y, :]
|
||||
if cxx.shape[1] < inputs_2d.shape[1]:
|
||||
cxx = np.pad(cxx, [[0, 0],
|
||||
[prepad_x, max(0, i + crop_size_y - length)],
|
||||
[prepad_x, postpad_x],
|
||||
[0, 0]], mode='constant')
|
||||
assert cxx.shape[1] == crop_size_y
|
||||
assert cxx.shape[2] == crop_size_x
|
||||
if cyy.shape[2] < inputs_2d.shape[2]:
|
||||
cyy = np.pad(cyy, [[0, 0],
|
||||
[prepad_y, postpad_y],
|
||||
[prepad_y, max(0, j + crop_size_x - length)],
|
||||
[0, 0]], mode='constant')
|
||||
assert cyy.shape[1] == crop_size_y
|
||||
assert cyy.shape[2] == crop_size_x
|
||||
inputs_2d = np.concatenate([inputs_2d, cxx, cyy], 3)
|
||||
|
||||
output_results = sess.run(output_fetches, feed_dict={
|
||||
experiment.inputs_1d_placeholder: inputs_1d,
|
||||
experiment.residue_index_placeholder: residue_index,
|
||||
experiment.inputs_2d_placeholder: inputs_2d,
|
||||
experiment.crop_placeholder: crop_limits,
|
||||
})
|
||||
# Crop out the "live" region of the probs.
|
||||
prob_patch = output_results['probs'][
|
||||
0, prepad_y:crop_size_y - postpad_y,
|
||||
prepad_x:crop_size_x - postpad_x]
|
||||
weight_patch = prob_weights[prepad_y:crop_size_y - postpad_y,
|
||||
prepad_x:crop_size_x - postpad_x]
|
||||
patch = {'prob': prob_patch, 'weight': weight_patch}
|
||||
|
||||
if 'softmax_probs' in output_results:
|
||||
patch['softmax'] = output_results['softmax_probs'][
|
||||
0, prepad_y:crop_size_y - postpad_y,
|
||||
prepad_x:crop_size_x - postpad_x]
|
||||
if 'secstruct_probs' in output_results:
|
||||
patch['ss_x'] = output_results['secstruct_probs'][
|
||||
0, prepad_x:crop_size_x - postpad_x]
|
||||
patch['ss_y'] = output_results['secstruct_probs'][
|
||||
0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y]
|
||||
if 'torsion_probs' in output_results:
|
||||
patch['torsions_x'] = output_results['torsion_probs'][
|
||||
0, prepad_x:crop_size_x - postpad_x]
|
||||
patch['torsions_y'] = output_results['torsion_probs'][
|
||||
0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y]
|
||||
if 'asa_output' in output_results:
|
||||
patch['asa_x'] = output_results['asa_output'][
|
||||
0, prepad_x:crop_size_x - postpad_x]
|
||||
patch['asa_y'] = output_results['asa_output'][
|
||||
0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y]
|
||||
return patch
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv # Unused.
|
||||
|
||||
logging.info('Loading a JSON config from: %s', FLAGS.config_path)
|
||||
with tf.io.gfile.GFile(FLAGS.config_path, 'r') as f:
|
||||
config = config_dict.ConfigDict.from_json(f.read())
|
||||
|
||||
# Redefine the relevant output fields.
|
||||
if FLAGS.eval_sstable:
|
||||
config.eval_config.eval_sstable = FLAGS.eval_sstable
|
||||
if FLAGS.stats_file:
|
||||
config.eval_config.stats_file = FLAGS.stats_file
|
||||
if FLAGS.output_path:
|
||||
config.eval_config.output_path = FLAGS.output_path
|
||||
|
||||
with tf.device('/cpu:0' if FLAGS.cpu else None):
|
||||
evaluate(checkpoint_path=FLAGS.checkpoint_path, **config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
364  alphafold_casp13/contacts_dataset.py  Normal file
@@ -0,0 +1,364 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TF wrapper for protein tf.Example datasets."""
|
||||
|
||||
import collections
|
||||
import enum
|
||||
import json
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
_ProteinDescription = collections.namedtuple(
|
||||
'_ProteinDescription', (
|
||||
'sequence_lengths', 'key', 'sequences', 'inputs_1d', 'inputs_2d',
|
||||
'inputs_2d_diagonal', 'crops', 'scalars', 'targets'))
|
||||
|
||||
|
||||
class FeatureType(enum.Enum):
|
||||
ZERO_DIM = 0 # Shape [x]
|
||||
ONE_DIM = 1 # Shape [num_res, x]
|
||||
TWO_DIM = 2 # Shape [num_res, num_res, x]
|
||||
|
||||
# Placeholder values that will be replaced with their true value at runtime.
|
||||
NUM_RES = 'num residues placeholder'
|
||||
|
||||
# Sizes of the protein features. NUM_RES is allowed as a placeholder to be
|
||||
# replaced with the number of residues.
|
||||
FEATURES = {
|
||||
'aatype': (tf.float32, [NUM_RES, 21]),
|
||||
'alpha_mask': (tf.int64, [NUM_RES, 1]),
|
||||
'alpha_positions': (tf.float32, [NUM_RES, 3]),
|
||||
'beta_mask': (tf.int64, [NUM_RES, 1]),
|
||||
'beta_positions': (tf.float32, [NUM_RES, 3]),
|
||||
'between_segment_residues': (tf.int64, [NUM_RES, 1]),
|
||||
'chain_name': (tf.string, [1]),
|
||||
'deletion_probability': (tf.float32, [NUM_RES, 1]),
|
||||
'domain_name': (tf.string, [1]),
|
||||
'gap_matrix': (tf.float32, [NUM_RES, NUM_RES, 1]),
|
||||
'hhblits_profile': (tf.float32, [NUM_RES, 22]),
|
||||
'hmm_profile': (tf.float32, [NUM_RES, 30]),
|
||||
'key': (tf.string, [1]),
|
||||
'mutual_information': (tf.float32, [NUM_RES, NUM_RES, 1]),
|
||||
'non_gapped_profile': (tf.float32, [NUM_RES, 21]),
|
||||
'num_alignments': (tf.int64, [NUM_RES, 1]),
|
||||
'num_effective_alignments': (tf.float32, [1]),
|
||||
'phi_angles': (tf.float32, [NUM_RES, 1]),
|
||||
'phi_mask': (tf.int64, [NUM_RES, 1]),
|
||||
'profile': (tf.float32, [NUM_RES, 21]),
|
||||
'profile_with_prior': (tf.float32, [NUM_RES, 22]),
|
||||
'profile_with_prior_without_gaps': (tf.float32, [NUM_RES, 21]),
|
||||
'pseudo_bias': (tf.float32, [NUM_RES, 22]),
|
||||
'pseudo_frob': (tf.float32, [NUM_RES, NUM_RES, 1]),
|
||||
'pseudolikelihood': (tf.float32, [NUM_RES, NUM_RES, 484]),
|
||||
'psi_angles': (tf.float32, [NUM_RES, 1]),
|
||||
'psi_mask': (tf.int64, [NUM_RES, 1]),
|
||||
'residue_index': (tf.int64, [NUM_RES, 1]),
|
||||
'resolution': (tf.float32, [1]),
|
||||
'reweighted_profile': (tf.float32, [NUM_RES, 22]),
|
||||
'sec_structure': (tf.int64, [NUM_RES, 8]),
|
||||
'sec_structure_mask': (tf.int64, [NUM_RES, 1]),
|
||||
'seq_length': (tf.int64, [NUM_RES, 1]),
|
||||
'sequence': (tf.string, [1]),
|
||||
'solv_surf': (tf.float32, [NUM_RES, 1]),
|
||||
'solv_surf_mask': (tf.int64, [NUM_RES, 1]),
|
||||
'superfamily': (tf.string, [1]),
|
||||
}
|
||||
|
||||
FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()}
|
||||
FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()}
|
||||
|
||||
|
||||
def shape(feature_name, num_residues, features=None):
|
||||
"""Get the shape for the given feature name.
|
||||
|
||||
Args:
|
||||
feature_name: String identifier for the feature. If the feature name ends
|
||||
with "_unnormalized", theis suffix is stripped off.
|
||||
num_residues: The number of residues in the current domain - some elements
|
||||
of the shape can be dynamic and will be replaced by this value.
|
||||
features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.
|
||||
|
||||
Returns:
|
||||
    List of ints representing the tensor size.
|
||||
"""
|
||||
features = features or FEATURES
|
||||
if feature_name.endswith('_unnormalized'):
|
||||
feature_name = feature_name[:-13]
|
||||
|
||||
unused_dtype, raw_sizes = features[feature_name]
|
||||
replacements = {NUM_RES: num_residues}
|
||||
|
||||
sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
|
||||
return sizes
|
||||
|
||||
|
||||
def dim(feature_name):
|
||||
"""Determine the type of feature.
|
||||
|
||||
Args:
|
||||
feature_name: String identifier for the feature to lookup. If the feature
|
||||
      name ends with "_unnormalized", this suffix is stripped off.
|
||||
|
||||
Returns:
|
||||
A FeatureType enum describing whether the feature is of size num_res or
|
||||
num_res * num_res.
|
||||
|
||||
Raises:
|
||||
ValueError: If the feature is of an unknown type.
|
||||
"""
|
||||
if feature_name.endswith('_unnormalized'):
|
||||
feature_name = feature_name[:-13]
|
||||
|
||||
num_dims = len(FEATURE_SIZES[feature_name])
|
||||
if num_dims == 1:
|
||||
return FeatureType.ZERO_DIM
|
||||
elif num_dims == 2 and FEATURE_SIZES[feature_name][0] == NUM_RES:
|
||||
return FeatureType.ONE_DIM
|
||||
elif num_dims == 3 and FEATURE_SIZES[feature_name][0] == NUM_RES:
|
||||
return FeatureType.TWO_DIM
|
||||
else:
|
||||
raise ValueError('Expect feature sizes to be 2 or 3, got %i' %
|
||||
len(FEATURE_SIZES[feature_name]))
|
||||
|
||||
|
||||
def _concat_or_zeros(tensor_list, axis, tensor_shape, name):
|
||||
"""Concatenates the tensors if given, otherwise returns a tensor of zeros."""
|
||||
if tensor_list:
|
||||
return tf.concat(tensor_list, axis=axis, name=name)
|
||||
return tf.zeros(tensor_shape, name=name + '_zeros')
|
||||
|
||||
|
||||
def parse_tfexample(raw_data, features):
|
||||
"""Read a single TF Example proto and return a subset of its features.
|
||||
|
||||
Args:
|
||||
raw_data: A serialized tf.Example proto.
|
||||
features: A dictionary of features, mapping string feature names to a tuple
|
||||
(dtype, shape). This dictionary should be a subset of
|
||||
protein_features.FEATURES (or the dictionary itself for all features).
|
||||
|
||||
Returns:
|
||||
A dictionary of features mapping feature names to features. Only the given
|
||||
features are returned, all other ones are filtered out.
|
||||
"""
|
||||
feature_map = {
|
||||
k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True)
|
||||
for k, v in features.items()
|
||||
}
|
||||
parsed_features = tf.io.parse_single_example(raw_data, feature_map)
|
||||
|
||||
# Find out what is the number of sequences and the number of alignments.
|
||||
num_residues = tf.cast(parsed_features['seq_length'][0], dtype=tf.int32)
|
||||
|
||||
# Reshape the tensors according to the sequence length and num alignments.
|
||||
for k, v in parsed_features.items():
|
||||
new_shape = shape(feature_name=k, num_residues=num_residues)
|
||||
# Make sure the feature we are reshaping is not empty.
|
||||
assert_non_empty = tf.compat.v1.assert_greater(
|
||||
tf.size(v), 0, name='assert_%s_non_empty' % k,
|
||||
message='The feature %s is not set in the tf.Example. Either do not '
|
||||
'request the feature or use a tf.Example that has the feature set.' % k)
|
||||
with tf.control_dependencies([assert_non_empty]):
|
||||
parsed_features[k] = tf.reshape(v, new_shape, name='reshape_%s' % k)
|
||||
|
||||
return parsed_features
|
||||
|
||||
|
||||
def create_tf_dataset(tf_record_filename, features):
|
||||
"""Creates an instance of tf.data.Dataset backed by a protein dataset SSTable.
|
||||
|
||||
Args:
|
||||
    tf_record_filename: A string with the filename of the TFRecord file.
|
||||
features: A list of strings of feature names to be returned in the dataset.
|
||||
|
||||
Returns:
|
||||
A tf.data.Dataset object. Its items are dictionaries from feature names to
|
||||
feature values.
|
||||
"""
|
||||
# Make sure these features are always read.
|
||||
required_features = ['aatype', 'sequence', 'seq_length']
|
||||
features = list(set(features) | set(required_features))
|
||||
features = {name: FEATURES[name] for name in features}
|
||||
|
||||
tf_dataset = tf.data.TFRecordDataset(filenames=[tf_record_filename])
|
||||
tf_dataset = tf_dataset.map(lambda raw: parse_tfexample(raw, features))
|
||||
|
||||
return tf_dataset
|
||||
|
||||
|
||||
def normalize_from_stats_file(
|
||||
features, stats_file_path, feature_normalization, copy_unnormalized=None):
|
||||
"""Normalizes the features set in the feature_normalization by the norm stats.
|
||||
|
||||
Args:
|
||||
features: A dictionary mapping feature names to feature tensors.
|
||||
stats_file_path: A string with the path of the statistics JSON file.
|
||||
feature_normalization: A dictionary specifying the normalization type for
|
||||
each input feature. Acceptable values are 'std' and 'none'. If not
|
||||
specified default to 'none'. Any extra features that are not present in
|
||||
features will be ignored.
|
||||
copy_unnormalized: A list of features whose unnormalized copy should be
|
||||
added. For any feature F in this list a feature F + "_unnormalized" will
|
||||
be added in the output dictionary containing the unnormalized feature.
|
||||
This is useful if you have a feature you want to have both in
|
||||
desired_features (normalized) and also in desired_targets (unnormalized).
|
||||
See convert_to_legacy_proteins_dataset_format for more details.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping features names to feature tensors. The ones that were
|
||||
specified in feature_normalization will be normalized.
|
||||
|
||||
Raises:
|
||||
ValueError: If an unknown normalization mode is used.
|
||||
"""
|
||||
with tf.io.gfile.GFile(stats_file_path, 'r') as f:
|
||||
norm_stats = json.loads(f.read())
|
||||
|
||||
if not copy_unnormalized:
|
||||
copy_unnormalized = []
|
||||
# We need this unnormalized in convert_to_legacy_proteins_dataset_format.
|
||||
copy_unnormalized.append('num_alignments')
|
||||
|
||||
for feature in copy_unnormalized:
|
||||
if feature in features:
|
||||
features[feature + '_unnormalized'] = features[feature]
|
||||
|
||||
range_epsilon = 1e-12
|
||||
for key, value in features.items():
|
||||
if key not in feature_normalization or feature_normalization[key] == 'none':
|
||||
pass
|
||||
elif feature_normalization[key] == 'std':
|
||||
value = tf.cast(value, dtype=tf.float32)
|
||||
train_mean = tf.cast(norm_stats['mean'][key], dtype=tf.float32)
|
||||
train_range = tf.sqrt(tf.cast(norm_stats['var'][key], dtype=tf.float32))
|
||||
value -= train_mean
|
||||
value = tf.compat.v1.where(
|
||||
train_range > range_epsilon, value / train_range, value)
|
||||
features[key] = value
|
||||
else:
|
||||
raise ValueError('Unknown normalization mode %s for feature %s.'
|
||||
% (feature_normalization[key], key))
|
||||
return features
|
||||
|
||||
|
||||
def convert_to_legacy_proteins_dataset_format(
|
||||
features, desired_features, desired_scalars, desired_targets):
|
||||
"""Converts the output of tf.Dataset to the legacy format.
|
||||
|
||||
Args:
|
||||
features: A dictionary mapping feature names to feature tensors.
|
||||
desired_features: A list with the names of the desired features. These will
|
||||
be filtered out of features and returned in one of the inputs_1d or
|
||||
inputs_2d. The features concatenated in `inputs_1d`, `inputs_2d` will be
|
||||
concatenated in the same order as they were given in `desired_features`.
|
||||
desired_scalars: A list naming the desired scalars. These will
|
||||
be filtered out of features and returned in scalars. If features contain
|
||||
an unnormalized version of a desired scalar, it will be used.
|
||||
desired_targets: A list naming the desired targets. These will
|
||||
be filtered out of features and returned in targets. If features contain
|
||||
an unnormalized version of a desired target, it will be used.
|
||||
|
||||
Returns:
|
||||
A _ProteinDescription namedtuple consisting of:
|
||||
sequence_length: A scalar int32 tensor with the sequence length.
|
||||
      key: A string tensor with the sequence key or empty if not set in features.
|
||||
sequences: A string tensor with the protein sequence.
|
||||
inputs_1d: All 1D features in a single tensor of shape
|
||||
[num_res, 1d_channels].
|
||||
inputs_2d: All 2D features in a single tensor of shape
|
||||
[num_res, num_res, 2d_channels].
|
||||
inputs_2d_diagonal: All 2D diagonal features in a single tensor of shape
|
||||
[num_res, num_res, 2d_diagonal_channels]. If no diagonal features found
|
||||
in features, the tensor will be set to inputs_2d.
|
||||
      crops: An int32 tensor with the crop positions. If not set in features,
|
||||
it will be set to [0, num_res, 0, num_res].
|
||||
scalars: All requested scalar tensors in a list.
|
||||
targets: All requested target tensors in a list.
|
||||
|
||||
Raises:
|
||||
ValueError: If the feature size is invalid.
|
||||
"""
|
||||
tensors_1d = []
|
||||
tensors_2d = []
|
||||
tensors_2d_diagonal = []
|
||||
for key in desired_features:
|
||||
# Determine if the feature is 1D or 2D.
|
||||
feature_dim = dim(key)
|
||||
if feature_dim == FeatureType.ONE_DIM:
|
||||
tensors_1d.append(tf.cast(features[key], dtype=tf.float32))
|
||||
elif feature_dim == FeatureType.TWO_DIM:
|
||||
if key not in features:
|
||||
if not(key + '_cropped' in features and key + '_diagonal' in features):
|
||||
raise ValueError(
|
||||
'The 2D feature %s is not in the features dictionary and neither '
|
||||
'are its cropped and diagonal versions.' % key)
|
||||
else:
|
||||
tensors_2d.append(
|
||||
tf.cast(features[key + '_cropped'], dtype=tf.float32))
|
||||
tensors_2d_diagonal.append(
|
||||
tf.cast(features[key + '_diagonal'], dtype=tf.float32))
|
||||
else:
|
||||
tensors_2d.append(tf.cast(features[key], dtype=tf.float32))
|
||||
else:
|
||||
raise ValueError('Unexpected FeatureType returned: %s' % str(feature_dim))
|
||||
|
||||
# Determine num_res from the sequence as seq_length was possibly normalized.
|
||||
num_res = tf.strings.length(features['sequence'])[0]
|
||||
|
||||
# Concatenate feature tensors into a single tensor
|
||||
inputs_1d = _concat_or_zeros(
|
||||
tensors_1d, axis=1, tensor_shape=[num_res, 0],
|
||||
name='inputs_1d_concat')
|
||||
inputs_2d = _concat_or_zeros(
|
||||
tensors_2d, axis=2, tensor_shape=[num_res, num_res, 0],
|
||||
name='inputs_2d_concat')
|
||||
if tensors_2d_diagonal:
|
||||
# The legacy dataset outputs the two diagonal crops stacked
|
||||
    # A1, B1, C1, A2, B2, C2. So convert from the A1, A2, B1, B2, C1, C2 format.
|
||||
diagonal_crops1 = [t[:, :, :(t.shape[2] // 2)] for t in tensors_2d_diagonal]
|
||||
diagonal_crops2 = [t[:, :, (t.shape[2] // 2):] for t in tensors_2d_diagonal]
|
||||
inputs_2d_diagonal = tf.concat(diagonal_crops1 + diagonal_crops2, axis=2)
|
||||
else:
|
||||
inputs_2d_diagonal = inputs_2d
|
||||
|
||||
sequence = features['sequence']
|
||||
sequence_key = features.get('key', tf.constant(['']))[0]
|
||||
if 'crops' in features:
|
||||
crops = features['crops']
|
||||
else:
|
||||
crops = tf.stack([0, tf.shape(sequence)[0], 0, tf.shape(sequence)[0]])
|
||||
|
||||
scalar_tensors = []
|
||||
for key in desired_scalars:
|
||||
scalar_tensors.append(features.get(key + '_unnormalized', features[key]))
|
||||
|
||||
target_tensors = []
|
||||
for key in desired_targets:
|
||||
target_tensors.append(features.get(key + '_unnormalized', features[key]))
|
||||
|
||||
scalar_class = collections.namedtuple('_ScalarClass', desired_scalars)
|
||||
target_class = collections.namedtuple('_TargetClass', desired_targets)
|
||||
|
||||
return _ProteinDescription(
|
||||
sequence_lengths=num_res,
|
||||
key=sequence_key,
|
||||
sequences=sequence,
|
||||
inputs_1d=inputs_1d,
|
||||
inputs_2d=inputs_2d,
|
||||
inputs_2d_diagonal=inputs_2d_diagonal,
|
||||
crops=crops,
|
||||
scalars=scalar_class(*scalar_tensors),
|
||||
targets=target_class(*target_tensors))
|
||||
233  alphafold_casp13/contacts_experiment.py  Normal file
@@ -0,0 +1,233 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contact prediction convnet experiment example."""
|
||||
|
||||
from absl import logging
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from alphafold_casp13 import contacts_dataset
|
||||
from alphafold_casp13 import contacts_network
|
||||
|
||||
|
||||
def _int_ph(shape, name):
|
||||
return tf.compat.v1.placeholder(
|
||||
dtype=tf.int32, shape=shape, name=('%s_placeholder' % name))
|
||||
|
||||
|
||||
def _float_ph(shape, name):
|
||||
return tf.compat.v1.placeholder(
|
||||
dtype=tf.float32, shape=shape, name=('%s_placeholder' % name))
|
||||
|
||||
|
||||
class Contacts(object):
|
||||
"""Contact prediction experiment."""
|
||||
|
||||
def __init__(
|
||||
self, tfrecord, stats_file, network_config, crop_size_x, crop_size_y,
|
||||
feature_normalization, normalization_exclusion):
|
||||
"""Builds the TensorFlow graph."""
|
||||
self.network_config = network_config
|
||||
self.crop_size_x = crop_size_x
|
||||
self.crop_size_y = crop_size_y
|
||||
|
||||
self._feature_normalization = feature_normalization
|
||||
self._normalization_exclusion = normalization_exclusion
|
||||
self._model = contacts_network.ContactsNet(**network_config)
|
||||
self._features = network_config.features
|
||||
self._scalars = network_config.scalars
|
||||
self._targets = network_config.targets
|
||||
# Add extra targets we need.
|
||||
required_targets = ['domain_name', 'resolution', 'chain_name']
|
||||
if self.model.torsion_multiplier > 0:
|
||||
required_targets.extend([
|
||||
'phi_angles', 'phi_mask', 'psi_angles', 'psi_mask'])
|
||||
if self.model.secstruct_multiplier > 0:
|
||||
required_targets.extend(['sec_structure', 'sec_structure_mask'])
|
||||
if self.model.asa_multiplier > 0:
|
||||
required_targets.extend(['solv_surf', 'solv_surf_mask'])
|
||||
extra_targets = [t for t in required_targets if t not in self._targets]
|
||||
if extra_targets:
|
||||
targets = list(self._targets)
|
||||
targets.extend(extra_targets)
|
||||
self._targets = tuple(targets)
|
||||
logging.info('Targets %s %s extra %s',
|
||||
type(self._targets), self._targets, extra_targets)
|
||||
logging.info('Evaluating on %s, stats: %s', tfrecord, stats_file)
|
||||
self._build_evaluation_graph(tfrecord=tfrecord, stats_file=stats_file)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._model
|
||||
|
||||
def _get_feature_normalization(self, features):
|
||||
return {key: self._feature_normalization
|
||||
for key in features
|
||||
if key not in list(self._normalization_exclusion)}
|
||||
|
||||
def _build_evaluation_graph(self, tfrecord, stats_file):
|
||||
"""Constructs the graph in pieces so it can be fed."""
|
||||
with tf.name_scope('competitionsep'):
|
||||
# Construct the dataset and mapping ops.
|
||||
dataset = contacts_dataset.create_tf_dataset(
|
||||
tf_record_filename=tfrecord,
|
||||
features=tuple(self._features) + tuple(
|
||||
self._scalars) + tuple(self._targets))
|
||||
|
||||
def normalize(data):
|
||||
return contacts_dataset.normalize_from_stats_file(
|
||||
features=data,
|
||||
stats_file_path=stats_file,
|
||||
feature_normalization=self._get_feature_normalization(
|
||||
self._features),
|
||||
copy_unnormalized=list(set(self._features) & set(self._targets)))
|
||||
|
||||
def convert_to_legacy(features):
|
||||
return contacts_dataset.convert_to_legacy_proteins_dataset_format(
|
||||
features, self._features, self._scalars, self._targets)
|
||||
|
||||
dataset = dataset.map(normalize)
|
||||
dataset = dataset.map(convert_to_legacy)
|
||||
dataset = dataset.batch(1)
|
||||
|
||||
# Get a batch of tensors in the legacy ProteinsDataset format.
|
||||
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
|
||||
self._input_batch = iterator.get_next()
|
||||
|
||||
self.num_eval_examples = sum(
|
||||
1 for _ in tf.python_io.tf_record_iterator(tfrecord))
|
||||
|
||||
logging.info('Eval batch:\n%s', self._input_batch)
|
||||
feature_dim_1d = self._input_batch.inputs_1d.shape.as_list()[-1]
|
||||
feature_dim_2d = self._input_batch.inputs_2d.shape.as_list()[-1]
|
||||
feature_dim_2d *= 3 # The diagonals will be stacked before feeding.
|
||||
|
||||
# Now placeholders for the graph to compute the outputs for one crop.
|
||||
self.inputs_1d_placeholder = _float_ph(
|
||||
shape=[None, None, feature_dim_1d], name='inputs_1d')
|
||||
self.residue_index_placeholder = _int_ph(
|
||||
shape=[None, None], name='residue_index')
|
||||
self.inputs_2d_placeholder = _float_ph(
|
||||
shape=[None, None, None, feature_dim_2d], name='inputs_2d')
|
||||
# 4 ints: x_start, x_end, y_start, y_end.
|
||||
self.crop_placeholder = _int_ph(shape=[None, 4], name='crop')
|
||||
|
||||
# Finally placeholders for the graph to score the complete contact map.
|
||||
self.probs_placeholder = _float_ph(shape=[None, None, None], name='probs')
|
||||
self.softmax_probs_placeholder = _float_ph(
|
||||
shape=[None, None, None, self.network_config.num_bins],
|
||||
name='softmax_probs')
|
||||
self.cb_placeholder = _float_ph(shape=[None, None, 3], name='cb')
|
||||
self.cb_mask_placeholder = _float_ph(shape=[None, None], name='cb_mask')
|
||||
self.lengths_placeholder = _int_ph(shape=[None], name='lengths')
|
||||
|
||||
if self.model.secstruct_multiplier > 0:
|
||||
self.sec_structure_placeholder = _float_ph(
|
||||
shape=[None, None, 8], name='sec_structure')
|
||||
self.sec_structure_logits_placeholder = _float_ph(
|
||||
shape=[None, None, 8], name='sec_structure_logits')
|
||||
self.sec_structure_mask_placeholder = _float_ph(
|
||||
shape=[None, None, 1], name='sec_structure_mask')
|
||||
|
||||
if self.model.asa_multiplier > 0:
|
||||
self.solv_surf_placeholder = _float_ph(
|
||||
shape=[None, None, 1], name='solv_surf')
|
||||
self.solv_surf_logits_placeholder = _float_ph(
|
||||
shape=[None, None, 1], name='solv_surf_logits')
|
||||
self.solv_surf_mask_placeholder = _float_ph(
|
||||
shape=[None, None, 1], name='solv_surf_mask')
|
||||
|
||||
if self.model.torsion_multiplier > 0:
|
||||
self.torsions_truth_placeholder = _float_ph(
|
||||
shape=[None, None, 2], name='torsions_truth')
|
||||
self.torsions_mask_placeholder = _float_ph(
|
||||
shape=[None, None, 1], name='torsions_mask')
|
||||
self.torsion_logits_placeholder = _float_ph(
|
||||
shape=[None, None, self.network_config.torsion_bins ** 2],
|
||||
name='torsion_logits')
|
||||
|
||||
# Build a dict to pass all the placeholders into build.
|
||||
placeholders = {
|
||||
'inputs_1d_placeholder': self.inputs_1d_placeholder,
|
||||
'residue_index_placeholder': self.residue_index_placeholder,
|
||||
'inputs_2d_placeholder': self.inputs_2d_placeholder,
|
||||
'crop_placeholder': self.crop_placeholder,
|
||||
'probs_placeholder': self.probs_placeholder,
|
||||
'softmax_probs_placeholder': self.softmax_probs_placeholder,
|
||||
'cb_placeholder': self.cb_placeholder,
|
||||
'cb_mask_placeholder': self.cb_mask_placeholder,
|
||||
'lengths_placeholder': self.lengths_placeholder,
|
||||
}
|
||||
if self.model.secstruct_multiplier > 0:
|
||||
placeholders.update({
|
||||
'sec_structure': self.sec_structure_placeholder,
|
||||
'sec_structure_logits_placeholder':
|
||||
self.sec_structure_logits_placeholder,
|
||||
'sec_structure_mask': self.sec_structure_mask_placeholder,})
|
||||
if self.model.asa_multiplier > 0:
|
||||
placeholders.update({
|
||||
'solv_surf': self.solv_surf_placeholder,
|
||||
'solv_surf_logits_placeholder': self.solv_surf_logits_placeholder,
|
||||
'solv_surf_mask': self.solv_surf_mask_placeholder,})
|
||||
if self.model.torsion_multiplier > 0:
|
||||
placeholders.update({
|
||||
'torsions_truth': self.torsions_truth_placeholder,
|
||||
'torsion_logits_placeholder': self.torsion_logits_placeholder,
|
||||
'torsions_truth_mask': self.torsions_mask_placeholder,})
|
||||
|
||||
activations = self._model(
|
||||
crop_size_x=self.crop_size_x,
|
||||
crop_size_y=self.crop_size_y,
|
||||
placeholders=placeholders)
|
||||
self.eval_probs_softmax = tf.nn.softmax(
|
||||
activations[:, :, :, :self.network_config.num_bins])
|
||||
self.eval_probs = tf.reduce_sum(
|
||||
self.eval_probs_softmax[:, :, :, :self._model.quant_threshold()],
|
||||
axis=3)
|
||||
|
||||
def get_one_example(self, sess):
|
||||
"""Pull one example off the queue so we can feed it for evaluation."""
|
||||
request_dict = {
|
||||
'inputs_1d': self._input_batch.inputs_1d,
|
||||
'inputs_2d': self._input_batch.inputs_2d,
|
||||
'sequence_lengths': self._input_batch.sequence_lengths,
|
||||
'beta_positions': self._input_batch.targets.beta_positions,
|
||||
'beta_mask': self._input_batch.targets.beta_mask,
|
||||
'domain_name': self._input_batch.targets.domain_name,
|
||||
'chain_name': self._input_batch.targets.chain_name,
|
||||
'sequences': self._input_batch.sequences,
|
||||
}
|
||||
if hasattr(self._input_batch.targets, 'residue_index'):
|
||||
request_dict.update(
|
||||
{'residue_index': self._input_batch.targets.residue_index})
|
||||
if hasattr(self._input_batch.targets, 'phi_angles'):
|
||||
request_dict.update(
|
||||
{'phi_angles': self._input_batch.targets.phi_angles,
|
||||
'psi_angles': self._input_batch.targets.psi_angles,
|
||||
'phi_mask': self._input_batch.targets.phi_mask,
|
||||
'psi_mask': self._input_batch.targets.psi_mask})
|
||||
if hasattr(self._input_batch.targets, 'sec_structure'):
|
||||
request_dict.update(
|
||||
{'sec_structure': self._input_batch.targets.sec_structure,
|
||||
'sec_structure_mask': self._input_batch.targets.sec_structure_mask,})
|
||||
if hasattr(self._input_batch.targets, 'solv_surf'):
|
||||
request_dict.update(
|
||||
{'solv_surf': self._input_batch.targets.solv_surf,
|
||||
'solv_surf_mask': self._input_batch.targets.solv_surf_mask,})
|
||||
if hasattr(self._input_batch.targets, 'alpha_positions'):
|
||||
request_dict.update(
|
||||
{'alpha_positions': self._input_batch.targets.alpha_positions,
|
||||
'alpha_mask': self._input_batch.targets.alpha_mask,})
|
||||
batch = sess.run(request_dict)
|
||||
return batch
|
||||
498  alphafold_casp13/contacts_network.py  Normal file
@@ -0,0 +1,498 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Network for predicting C-beta contacts."""
|
||||
|
||||
from absl import logging
|
||||
import sonnet
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from alphafold_casp13 import asa_output
|
||||
from alphafold_casp13 import secstruct
|
||||
from alphafold_casp13 import two_dim_convnet
|
||||
from alphafold_casp13 import two_dim_resnet
|
||||
from tensorflow.contrib import layers as contrib_layers
|
||||
|
||||
|
||||
def call_on_tuple(f):
|
||||
"""Unpacks a tuple input parameter into arguments for a function f.
|
||||
|
||||
Mimics tuple unpacking in lambdas, which existed in Python 2 but has been
|
||||
removed in Python 3.
|
||||
|
||||
Args:
|
||||
f: A function taking multiple arguments.
|
||||
|
||||
Returns:
|
||||
A function equivalent to f accepting a tuple, which is then unpacked.
|
||||
"""
|
||||
return lambda args: f(*args)
|
||||
|
||||
|
||||
class ContactsNet(sonnet.AbstractModule):
|
||||
"""A network to go from sequence to secondary structure."""
|
||||
|
||||
def __init__(self,
|
||||
binary_code_bits,
|
||||
data_format,
|
||||
distance_multiplier,
|
||||
features,
|
||||
features_forward,
|
||||
max_range,
|
||||
min_range,
|
||||
num_bins,
|
||||
reshape_layer,
|
||||
resolution_noise_scale,
|
||||
scalars,
|
||||
targets,
|
||||
network_2d,
|
||||
network_2d_deep,
|
||||
torsion_bins=None,
|
||||
skip_connect=0,
|
||||
position_specific_bias_size=0,
|
||||
filters_1d=(),
|
||||
collapsed_batch_norm=False,
|
||||
is_ca_feature=False,
|
||||
asa_multiplier=0.0,
|
||||
secstruct_multiplier=0.0,
|
||||
torsion_multiplier=0.0,
|
||||
name='contacts_net'):
|
||||
"""Construct position prediction network."""
|
||||
super(ContactsNet, self).__init__(name=name)
|
||||
|
||||
self._filters_1d = filters_1d
|
||||
self._collapsed_batch_norm = collapsed_batch_norm
|
||||
self._is_ca_feature = is_ca_feature
|
||||
self._binary_code_bits = binary_code_bits
|
||||
self._data_format = data_format
|
||||
self._distance_multiplier = distance_multiplier
|
||||
self._features = features
|
||||
self._features_forward = features_forward
|
||||
self._max_range = max_range
|
||||
self._min_range = min_range
|
||||
self._num_bins = num_bins
|
||||
self._position_specific_bias_size = position_specific_bias_size
|
||||
self._reshape_layer = reshape_layer
|
||||
self._resolution_noise_scale = resolution_noise_scale
|
||||
self._scalars = scalars
|
||||
self._torsion_bins = torsion_bins
|
||||
self._skip_connect = skip_connect
|
||||
self._targets = targets
|
||||
self._network_2d = network_2d
|
||||
self._network_2d_deep = network_2d_deep
|
||||
|
||||
self.asa_multiplier = asa_multiplier
|
||||
self.secstruct_multiplier = secstruct_multiplier
|
||||
self.torsion_multiplier = torsion_multiplier
|
||||
|
||||
with self._enter_variable_scope():
|
||||
if self.secstruct_multiplier > 0:
|
||||
self._secstruct = secstruct.Secstruct()
|
||||
if self.asa_multiplier > 0:
|
||||
self._asa = asa_output.ASAOutputLayer()
|
||||
if self._position_specific_bias_size:
|
||||
self._position_specific_bias = tf.compat.v1.get_variable(
|
||||
'position_specific_bias',
|
||||
[self._position_specific_bias_size, self._num_bins or 1],
|
||||
initializer=tf.zeros_initializer())
|
||||
|
||||
def quant_threshold(self, threshold=8.0):
|
||||
"""Find the bin that is 8A+: we sum mass below this bin gives contact prob.
|
||||
|
||||
Args:
|
||||
threshold: The distance threshold.
|
||||
Returns:
|
||||
Index of bin.
|
||||
"""
|
||||
# Note that this misuses the max_range as the range.
|
||||
return int(
|
||||
(threshold - self._min_range) * self._num_bins / float(self._max_range))
|
||||
|
||||
def _build(self, crop_size_x=0, crop_size_y=0, placeholders=None):
|
||||
"""Puts the network into the graph.
|
||||
|
||||
Args:
|
||||
crop_size_x: Crop a chunk out in one dimension. 0 means no cropping.
|
||||
crop_size_y: Crop a chunk out in one dimension. 0 means no cropping.
|
||||
placeholders: A dict containing the placeholders needed.
|
||||
|
||||
Returns:
|
||||
A Tensor of contact/distogram logits with shape [batch_size, crop_size_y, crop_size_x, output_dimension].
|
||||
"""
|
||||
crop_placeholder = placeholders['crop_placeholder']
|
||||
inputs_1d = placeholders['inputs_1d_placeholder']
|
||||
if self._is_ca_feature and 'aatype' in self._features:
|
||||
logging.info('Collapsing aatype to is_ca_feature %s',
|
||||
inputs_1d.shape.as_list()[-1])
|
||||
assert inputs_1d.shape.as_list()[-1] <= 21 + (
|
||||
1 if 'seq_length' in self._features else 0)
|
||||
inputs_1d = inputs_1d[:, :, 7:8]
|
||||
logits = self.compute_outputs(
|
||||
inputs_1d=inputs_1d,
|
||||
residue_index=placeholders['residue_index_placeholder'],
|
||||
inputs_2d=placeholders['inputs_2d_placeholder'],
|
||||
crop_x=crop_placeholder[:, 0:2],
|
||||
crop_y=crop_placeholder[:, 2:4],
|
||||
use_on_the_fly_stats=True,
|
||||
crop_size_x=crop_size_x,
|
||||
crop_size_y=crop_size_y,
|
||||
data_format='NHWC', # Force NHWC for evals.
|
||||
)
|
||||
return logits
|
||||
|
||||
def compute_outputs(self, inputs_1d, residue_index, inputs_2d, crop_x, crop_y,
|
||||
use_on_the_fly_stats, crop_size_x, crop_size_y,
|
||||
data_format='NHWC'):
|
||||
"""Given the inputs for a block, compute the network outputs."""
|
||||
hidden_1d = inputs_1d
|
||||
hidden_1d_list = [hidden_1d]
|
||||
if len(hidden_1d_list) != 1:
|
||||
hidden_1d = tf.concat(hidden_1d_list, 2)
|
||||
|
||||
output_dimension = self._num_bins or 1
|
||||
if self._distance_multiplier > 0:
|
||||
output_dimension += 1
|
||||
logits, activations = self._build_2d_embedding(
|
||||
hidden_1d=hidden_1d,
|
||||
residue_index=residue_index,
|
||||
inputs_2d=inputs_2d,
|
||||
output_dimension=output_dimension,
|
||||
use_on_the_fly_stats=use_on_the_fly_stats,
|
||||
crop_x=crop_x,
|
||||
crop_y=crop_y,
|
||||
crop_size_x=crop_size_x, crop_size_y=crop_size_y,
|
||||
data_format=data_format)
|
||||
logits = tf.debugging.check_numerics(
|
||||
logits, 'NaN in resnet activations', name='resnet_activations')
|
||||
if (self.secstruct_multiplier > 0 or
|
||||
self.asa_multiplier > 0 or
|
||||
self.torsion_multiplier > 0):
|
||||
# Make a 1d embedding by reducing the 2D activations.
|
||||
# We do this in the x direction and the y direction separately.
|
||||
|
||||
collapse_dim = 1
|
||||
join_dim = -1
|
||||
embedding_1d = tf.concat(
|
||||
# First targets are crop_x (axis 2) which we must reduce on axis 1
|
||||
[tf.concat([tf.reduce_max(activations, axis=collapse_dim),
|
||||
tf.reduce_mean(activations, axis=collapse_dim)],
|
||||
axis=join_dim),
|
||||
# Next targets are crop_y (axis 1) which we must reduce on axis 2
|
||||
tf.concat([tf.reduce_max(activations, axis=collapse_dim+1),
|
||||
tf.reduce_mean(activations, axis=collapse_dim+1)],
|
||||
axis=join_dim)],
|
||||
axis=collapse_dim) # Join the two crops together.
|
||||
if self._collapsed_batch_norm:
|
||||
embedding_1d = contrib_layers.batch_norm(
|
||||
embedding_1d,
|
||||
is_training=use_on_the_fly_stats,
|
||||
fused=True,
|
||||
decay=0.999,
|
||||
scope='collapsed_batch_norm',
|
||||
data_format='NHWC')
|
||||
for i, nfil in enumerate(self._filters_1d):
|
||||
embedding_1d = contrib_layers.fully_connected(
|
||||
embedding_1d,
|
||||
num_outputs=nfil,
|
||||
normalizer_fn=(contrib_layers.batch_norm
|
||||
if self._collapsed_batch_norm else None),
|
||||
normalizer_params={
|
||||
'is_training': use_on_the_fly_stats,
|
||||
'updates_collections': None
|
||||
},
|
||||
scope='collapsed_embed_%d' % i)
|
||||
|
||||
if self.torsion_multiplier > 0:
|
||||
self.torsion_logits = contrib_layers.fully_connected(
|
||||
embedding_1d,
|
||||
num_outputs=self._torsion_bins * self._torsion_bins,
|
||||
activation_fn=None,
|
||||
scope='torsion_logits')
|
||||
self.torsion_output = tf.nn.softmax(self.torsion_logits)
|
||||
if self.secstruct_multiplier > 0:
|
||||
self._secstruct.make_layer_new(embedding_1d)
|
||||
if self.asa_multiplier > 0:
|
||||
self.asa_logits = self._asa.compute_asa_output(embedding_1d)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _concatenate_2d(hidden_1d, residue_index, hidden_2d, crop_x, crop_y,
|
||||
binary_code_bits, crop_size_x, crop_size_y):
|
||||
# Form the pairwise expansion of the 1D embedding
|
||||
# And the residue offsets and (one) absolute position.
|
||||
with tf.name_scope('Features2D'):
|
||||
range_scale = 100.0 # Crude normalization factor.
|
||||
n = tf.shape(hidden_1d)[1]
|
||||
# pylint: disable=g-long-lambda
|
||||
hidden_1d_cropped_y = tf.map_fn(
|
||||
call_on_tuple(lambda c, h: tf.pad(
|
||||
h[tf.maximum(0, c[0]):c[1]],
|
||||
[[tf.maximum(0, -c[0]),
|
||||
tf.maximum(0, crop_size_y -(n - c[0]))], [0, 0]])),
|
||||
elems=(crop_y, hidden_1d), dtype=tf.float32,
|
||||
back_prop=True)
|
||||
range_n_y = tf.map_fn(
|
||||
call_on_tuple(lambda ri, c: tf.pad(
|
||||
ri[tf.maximum(0, c[0]):c[1]],
|
||||
[[tf.maximum(0, -c[0]),
|
||||
tf.maximum(0, crop_size_y -(n - c[0]))]])),
|
||||
elems=(residue_index, crop_y), dtype=tf.int32,
|
||||
back_prop=False)
|
||||
hidden_1d_cropped_x = tf.map_fn(
|
||||
call_on_tuple(lambda c, h: tf.pad(
|
||||
h[tf.maximum(0, c[0]):c[1]],
|
||||
[[tf.maximum(0, -c[0]),
|
||||
tf.maximum(0, crop_size_x -(n - c[0]))], [0, 0]])),
|
||||
elems=(crop_x, hidden_1d), dtype=tf.float32,
|
||||
back_prop=True)
|
||||
range_n_x = tf.map_fn(
|
||||
call_on_tuple(lambda ri, c: tf.pad(
|
||||
ri[tf.maximum(0, c[0]):c[1]],
|
||||
[[tf.maximum(0, -c[0]),
|
||||
tf.maximum(0, crop_size_x -(n - c[0]))]])),
|
||||
elems=(residue_index, crop_x), dtype=tf.int32,
|
||||
back_prop=False)
|
||||
# pylint: enable=g-long-lambda
|
||||
n_x = crop_size_x
|
||||
n_y = crop_size_y
|
||||
|
||||
offset = (tf.expand_dims(tf.cast(range_n_x, tf.float32), 1) -
|
||||
tf.expand_dims(tf.cast(range_n_y, tf.float32), 2)) / range_scale
|
||||
position_features = [
|
||||
tf.tile(
|
||||
tf.reshape(
|
||||
(tf.cast(range_n_y, tf.float32) - range_scale) / range_scale,
|
||||
[-1, n_y, 1, 1]), [1, 1, n_x, 1],
|
||||
name='TileRange'),
|
||||
tf.tile(
|
||||
tf.reshape(offset, [-1, n_y, n_x, 1]), [1, 1, 1, 1],
|
||||
name='TileOffset')
|
||||
]
|
||||
channels = 2
|
||||
if binary_code_bits:
|
||||
# Binary coding of position.
|
||||
exp_range_n_y = tf.expand_dims(range_n_y, 2)
|
||||
bin_y = tf.stop_gradient(
|
||||
tf.concat([tf.math.floormod(exp_range_n_y // (1 << i), 2)
|
||||
for i in range(binary_code_bits)], 2))
|
||||
exp_range_n_x = tf.expand_dims(range_n_x, 2)
|
||||
bin_x = tf.stop_gradient(
|
||||
tf.concat([tf.math.floormod(exp_range_n_x // (1 << i), 2)
|
||||
for i in range(binary_code_bits)], 2))
|
||||
position_features += [
|
||||
tf.tile(
|
||||
tf.expand_dims(tf.cast(bin_y, tf.float32), 2), [1, 1, n_x, 1],
|
||||
name='TileBinRangey'),
|
||||
tf.tile(
|
||||
tf.expand_dims(tf.cast(bin_x, tf.float32), 1), [1, n_y, 1, 1],
|
||||
name='TileBinRangex')
|
||||
]
|
||||
channels += 2 * binary_code_bits
|
||||
|
||||
augmentation_features = position_features + [
|
||||
tf.tile(tf.expand_dims(hidden_1d_cropped_x, 1),
|
||||
[1, n_y, 1, 1], name='Tile1Dx'),
|
||||
tf.tile(tf.expand_dims(hidden_1d_cropped_y, 2),
|
||||
[1, 1, n_x, 1], name='Tile1Dy')]
|
||||
channels += 2 * hidden_1d.shape.as_list()[-1]
|
||||
channels += hidden_2d.shape.as_list()[-1]
|
||||
hidden_2d = tf.concat(
|
||||
[hidden_2d] + augmentation_features, 3, name='Stack2Dfeatures')
|
||||
logging.info('2d stacked features are depth %d %s', channels, hidden_2d)
|
||||
hidden_2d.set_shape([None, None, None, channels])
|
||||
return hidden_2d
|
||||
|
||||
def _build_2d_embedding(self, hidden_1d, residue_index, inputs_2d,
|
||||
output_dimension, use_on_the_fly_stats, crop_x,
|
||||
crop_y, crop_size_x, crop_size_y, data_format):
|
||||
"""Returns NHWC logits and NHWC preactivations."""
|
||||
logging.info('2d %s %s', inputs_2d, data_format)
|
||||
|
||||
# Stack with diagonal has already happened.
|
||||
inputs_2d_cropped = inputs_2d
|
||||
|
||||
features_forward = None
|
||||
hidden_2d = inputs_2d_cropped
|
||||
hidden_2d = self._concatenate_2d(
|
||||
hidden_1d, residue_index, hidden_2d, crop_x, crop_y,
|
||||
self._binary_code_bits, crop_size_x, crop_size_y)
|
||||
|
||||
config_2d_deep = self._network_2d_deep
|
||||
num_features = hidden_2d.shape.as_list()[3]
|
||||
if data_format == 'NCHW':
|
||||
logging.info('NCHW shape deep pre %s', hidden_2d)
|
||||
hidden_2d = tf.transpose(hidden_2d, perm=[0, 3, 1, 2])
|
||||
hidden_2d.set_shape([None, num_features, None, None])
|
||||
logging.info('NCHW shape deep post %s', hidden_2d)
|
||||
layers_forward = None
|
||||
if config_2d_deep.extra_blocks:
|
||||
# Optionally put some extra double-size blocks at the beginning.
|
||||
with tf.compat.v1.variable_scope('Deep2DExtra'):
|
||||
hidden_2d = two_dim_resnet.make_two_dim_resnet(
|
||||
input_node=hidden_2d,
|
||||
num_residues=None, # Unused
|
||||
num_features=num_features,
|
||||
num_predictions=2 * config_2d_deep.num_filters,
|
||||
num_channels=2 * config_2d_deep.num_filters,
|
||||
num_layers=config_2d_deep.extra_blocks *
|
||||
config_2d_deep.num_layers_per_block,
|
||||
filter_size=3,
|
||||
batch_norm=config_2d_deep.use_batch_norm,
|
||||
is_training=use_on_the_fly_stats,
|
||||
fancy=True,
|
||||
final_non_linearity=True,
|
||||
atrou_rates=[1, 2, 4, 8],
|
||||
data_format=data_format,
|
||||
dropout_keep_prob=1.0
|
||||
)
|
||||
num_features = 2 * config_2d_deep.num_filters
|
||||
if self._skip_connect:
|
||||
layers_forward = hidden_2d
|
||||
if features_forward is not None:
|
||||
hidden_2d = tf.concat([hidden_2d, features_forward], 1
|
||||
if data_format == 'NCHW' else 3)
|
||||
with tf.compat.v1.variable_scope('Deep2D'):
|
||||
logging.info('2d hidden shape is %s', str(hidden_2d.shape.as_list()))
|
||||
contact_pre_logits = two_dim_resnet.make_two_dim_resnet(
|
||||
input_node=hidden_2d,
|
||||
num_residues=None, # Unused
|
||||
num_features=num_features,
|
||||
num_predictions=(config_2d_deep.num_filters
|
||||
if self._reshape_layer else output_dimension),
|
||||
num_channels=config_2d_deep.num_filters,
|
||||
num_layers=config_2d_deep.num_blocks *
|
||||
config_2d_deep.num_layers_per_block,
|
||||
filter_size=3,
|
||||
batch_norm=config_2d_deep.use_batch_norm,
|
||||
is_training=use_on_the_fly_stats,
|
||||
fancy=True,
|
||||
final_non_linearity=self._reshape_layer,
|
||||
atrou_rates=[1, 2, 4, 8],
|
||||
data_format=data_format,
|
||||
dropout_keep_prob=1.0
|
||||
)
|
||||
|
||||
contact_logits = self._output_from_pre_logits(
|
||||
contact_pre_logits, features_forward, layers_forward,
|
||||
output_dimension, data_format, crop_x, crop_y, use_on_the_fly_stats)
|
||||
if data_format == 'NCHW':
|
||||
contact_pre_logits = tf.transpose(contact_pre_logits, perm=[0, 2, 3, 1])
|
||||
# Both of these will be NHWC
|
||||
return contact_logits, contact_pre_logits
|
||||
|
||||
def _output_from_pre_logits(self, contact_pre_logits, features_forward,
|
||||
layers_forward, output_dimension, data_format,
|
||||
crop_x, crop_y, use_on_the_fly_stats):
|
||||
"""Given pre-logits, compute the final distogram/contact activations."""
|
||||
config_2d_deep = self._network_2d_deep
|
||||
if self._reshape_layer:
|
||||
in_channels = config_2d_deep.num_filters
|
||||
concat_features = [contact_pre_logits]
|
||||
if features_forward is not None:
|
||||
concat_features.append(features_forward)
|
||||
in_channels += self._features_forward
|
||||
if layers_forward is not None:
|
||||
concat_features.append(layers_forward)
|
||||
in_channels += 2 * config_2d_deep.num_filters
|
||||
if len(concat_features) > 1:
|
||||
contact_pre_logits = tf.concat(concat_features,
|
||||
1 if data_format == 'NCHW' else 3)
|
||||
|
||||
contact_logits = two_dim_convnet.make_conv_layer(
|
||||
contact_pre_logits,
|
||||
in_channels=in_channels,
|
||||
out_channels=output_dimension,
|
||||
layer_name='output_reshape_1x1h',
|
||||
filter_size=1,
|
||||
filter_size_2=1,
|
||||
non_linearity=False,
|
||||
batch_norm=config_2d_deep.use_batch_norm,
|
||||
is_training=use_on_the_fly_stats,
|
||||
data_format=data_format)
|
||||
else:
|
||||
contact_logits = contact_pre_logits
|
||||
|
||||
if data_format == 'NCHW':
|
||||
contact_logits = tf.transpose(contact_logits, perm=[0, 2, 3, 1])
|
||||
|
||||
if self._position_specific_bias_size:
|
||||
# Make 2D pos-specific biases: NHWC.
|
||||
biases = build_crops_biases(
|
||||
self._position_specific_bias_size,
|
||||
self._position_specific_bias, crop_x, crop_y, back_prop=True)
|
||||
contact_logits += biases
|
||||
|
||||
# Will be NHWC.
|
||||
return contact_logits
|
||||
|
||||
def update_crop_fetches(self, fetches):
|
||||
"""Add auxiliary outputs for a crop to the fetches."""
|
||||
if self.secstruct_multiplier > 0:
|
||||
fetches['secstruct_probs'] = self._secstruct.get_q8_probs()
|
||||
if self.asa_multiplier > 0:
|
||||
fetches['asa_output'] = self._asa.asa_output
|
||||
if self.torsion_multiplier > 0:
|
||||
fetches['torsion_probs'] = self.torsion_output
|
||||
|
||||
|
||||
def build_crops_biases(bias_size, raw_biases, crop_x, crop_y, back_prop):
|
||||
"""Take the offset-specific biases and reshape them to match current crops.
|
||||
|
||||
Args:
|
||||
bias_size: how many bias variables we're storing.
|
||||
raw_biases: the bias variable
|
||||
crop_x: B x 2 array of start/end for the batch
|
||||
crop_y: B x 2 array of start/end for the batch
|
||||
back_prop: whether to backprop through the map_fn.
|
||||
|
||||
Returns:
|
||||
Reshaped biases.
|
||||
"""
|
||||
# First pad the biases with a copy of the final value to the maximum length.
|
||||
max_off_diag = tf.reduce_max(
|
||||
tf.maximum(tf.abs(crop_x[:, 1] - crop_y[:, 0]),
|
||||
tf.abs(crop_y[:, 1] - crop_x[:, 0])))
|
||||
padded_bias_size = tf.maximum(bias_size, max_off_diag)
|
||||
biases = tf.concat(
|
||||
[raw_biases,
|
||||
tf.tile(raw_biases[-1:, :],
|
||||
[padded_bias_size - bias_size, 1])], axis=0)
|
||||
# Now prepend a mirror image (excluding 0th elt) for below-diagonal.
|
||||
biases = tf.concat([tf.reverse(biases[1:, :], axis=[0]), biases], axis=0)
|
||||
|
||||
# Which diagonal of the full matrix each crop starts on (top left):
|
||||
start_diag = crop_x[:, 0:1] - crop_y[:, 0:1] # B x 1
|
||||
crop_size_x = tf.reduce_max(crop_x[:, 1] - crop_x[:, 0])
|
||||
crop_size_y = tf.reduce_max(crop_y[:, 1] - crop_y[:, 0])
|
||||
|
||||
# Relative offset of each row within a crop:
|
||||
# (off-diagonal decreases as y increases)
|
||||
increment = tf.expand_dims(-tf.range(0, crop_size_y), 0) # 1 x crop_size_y
|
||||
|
||||
# Index of diagonal of first element of each row, flattened.
|
||||
row_offsets = tf.reshape(start_diag + increment, [-1]) # B*crop_size_y
|
||||
logging.info('row_offsets %s', row_offsets)
|
||||
|
||||
# Make it relative to the start of the biases array. (0-th diagonal is in
|
||||
# the middle at position padded_bias_size - 1)
|
||||
row_offsets += padded_bias_size - 1
|
||||
|
||||
# Map_fn to build the individual rows.
|
||||
# B*cropsizey x cropsizex x num_bins
|
||||
cropped_biases = tf.map_fn(lambda i: biases[i:i+crop_size_x, :],
|
||||
elems=row_offsets, dtype=tf.float32,
|
||||
back_prop=back_prop)
|
||||
logging.info('cropped_biases %s', cropped_biases)
|
||||
return tf.reshape(
|
||||
cropped_biases, [-1, crop_size_y, crop_size_x, tf.shape(raw_biases)[-1]])
|
||||
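As an aside on the evaluation head above: `eval_probs` sums distogram probabilities up to the bin index returned by `quant_threshold`. A minimal sketch of that bin arithmetic, assuming illustrative hyperparameters (a 2 to 22 Angstrom histogram with 64 bins; these values are assumptions for illustration, not taken from this diff):

```python
# Sketch of ContactsNet.quant_threshold's arithmetic with assumed hyperparameters.
min_range = 2.0    # left edge of the first distance bin (assumed)
max_range = 22.0   # extent of the histogram; the code deliberately uses this as the range (assumed)
num_bins = 64      # number of distance bins (assumed)
threshold = 8.0    # contact definition: C-beta to C-beta distance below 8 Angstroms

threshold_bin = int((threshold - min_range) * num_bins / float(max_range))
# threshold_bin == 17 here; summing softmax probabilities over bins [0, threshold_bin)
# reproduces the contact probability computed in self.eval_probs.
```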
alphafold_casp13/distogram_io.py (new file, 98 lines)
@@ -0,0 +1,98 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Write contact map predictions to a tf.io.gfile.
|
||||
|
||||
Either write a binary contact map as an RR format text file, or a
|
||||
histogram prediction as a pickle of a dict containing a numpy array.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import six.moves.cPickle as pickle
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
RR_FORMAT = """PFRMAT RR
|
||||
TARGET {}
|
||||
AUTHOR DM-ORIGAMI-TEAM
|
||||
METHOD {}
|
||||
MODEL 1
|
||||
{}
|
||||
"""
|
||||
|
||||
|
||||
def save_rr_file(filename, probs, domain, sequence,
|
||||
method='dm-contacts-resnet'):
|
||||
"""Save a contact probability matrix as an RR file."""
|
||||
assert len(sequence) == probs.shape[0]
|
||||
assert len(sequence) == probs.shape[1]
|
||||
with tf.io.gfile.GFile(filename, 'w') as f:
|
||||
f.write(RR_FORMAT.format(domain, method, sequence))
|
||||
for i in range(probs.shape[0]):
|
||||
for j in range(i + 1, probs.shape[1]):
|
||||
f.write('{:d} {:d} {:d} {:d} {:f}\n'.format(
|
||||
i + 1, j + 1, 0, 8, probs[j, i]))
|
||||
f.write('END\n')
|
||||
|
||||
|
||||
def save_torsions(torsions_dir, filebase, sequence, torsions_probs):
|
||||
"""Save Torsions to a file as pickle of a dict."""
|
||||
filename = os.path.join(torsions_dir, filebase + '.torsions')
|
||||
t_dict = dict(probs=torsions_probs, sequence=sequence)
|
||||
with tf.io.gfile.GFile(filename, 'w') as fh:
|
||||
pickle.dump(t_dict, fh, protocol=2)
|
||||
|
||||
|
||||
def save_distance_histogram(
|
||||
filename, probs, domain, sequence, min_range, max_range, num_bins):
|
||||
"""Save a distance histogram prediction matrix as a pickle file."""
|
||||
dh_dict = {
|
||||
'min_range': min_range,
|
||||
'max_range': max_range,
|
||||
'num_bins': num_bins,
|
||||
'domain': domain,
|
||||
'sequence': sequence,
|
||||
'probs': probs.astype(np.float32)}
|
||||
save_distance_histogram_from_dict(filename, dh_dict)
|
||||
|
||||
|
||||
def save_distance_histogram_from_dict(filename, dh_dict):
|
||||
"""Save a distance histogram prediction matrix as a pickle file."""
|
||||
fields = ['min_range', 'max_range', 'num_bins', 'domain', 'sequence', 'probs']
|
||||
missing_fields = [f for f in fields if f not in dh_dict]
|
||||
assert not missing_fields, 'Fields {} missing from dictionary'.format(
|
||||
missing_fields)
|
||||
assert len(dh_dict['sequence']) == dh_dict['probs'].shape[0]
|
||||
assert len(dh_dict['sequence']) == dh_dict['probs'].shape[1]
|
||||
assert dh_dict['num_bins'] == dh_dict['probs'].shape[2]
|
||||
assert dh_dict['min_range'] >= 0.0
|
||||
assert dh_dict['max_range'] > 0.0
|
||||
with tf.io.gfile.GFile(filename, 'wb') as fw:
|
||||
pickle.dump(dh_dict, fw, protocol=2)
|
||||
|
||||
|
||||
def contact_map_from_distogram(distogram_dict):
|
||||
"""Split the boundary bin."""
|
||||
num_bins = distogram_dict['probs'].shape[-1]
|
||||
bin_size_angstrom = distogram_dict['max_range'] / num_bins
|
||||
threshold_cts = (8.0 - distogram_dict['min_range']) / bin_size_angstrom
|
||||
threshold_bin = int(threshold_cts) # Round down
|
||||
pred_contacts = np.sum(distogram_dict['probs'][:, :, :threshold_bin], axis=-1)
|
||||
if threshold_bin < threshold_cts: # Add on the fraction of the boundary bin.
|
||||
pred_contacts += distogram_dict['probs'][:, :, threshold_bin] * (
|
||||
threshold_cts - threshold_bin)
|
||||
return pred_contacts
|
||||
|
||||
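A hedged usage sketch of `contact_map_from_distogram` above, applied to a made-up uniform distogram; the dict values are illustrative assumptions rather than data shipped with this commit:

```python
import numpy as np
from alphafold_casp13 import distogram_io

# Hypothetical distogram dict with the fields the function reads (toy values).
L, num_bins = 4, 10
distogram = {
    'min_range': 2.0,
    'max_range': 22.0,
    'num_bins': num_bins,
    'probs': np.full((L, L, num_bins), 1.0 / num_bins, dtype=np.float32),
}
contact_probs = distogram_io.contact_map_from_distogram(distogram)
# Bin width is 22 / 10 = 2.2 A, so threshold_cts = (8 - 2) / 2.2 ~ 2.73: bins 0 and 1
# count in full and bin 2 contributes ~0.73 of its mass, giving ~0.273 everywhere here.
```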
alphafold_casp13/ensemble_contact_maps.py (new file, 125 lines)
@@ -0,0 +1,125 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Form a weighted average of several distograms.
|
||||
|
||||
Can also/instead form a weighted average of a set of distance histogram pickle
|
||||
files, so long as they have identical hyperparameters.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from alphafold_casp13 import distogram_io
|
||||
from alphafold_casp13 import parsers
|
||||
|
||||
flags.DEFINE_list(
|
||||
'pickle_dirs', [],
|
||||
'Comma separated list of directories with pickle files to ensemble.')
|
||||
flags.DEFINE_list(
|
||||
'weights', [],
|
||||
'Comma separated list of weights for the pickle files from different dirs.')
|
||||
flags.DEFINE_string(
|
||||
'output_dir', None, 'Directory where to save results of the evaluation.')
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def ensemble_distance_histograms(pickle_dirs, weights, output_dir):
|
||||
"""Find all the contact maps in the first dir, then ensemble across dirs."""
|
||||
if len(pickle_dirs) <= 1:
|
||||
logging.warning('Pointless to ensemble %d pickle_dirs %s',
|
||||
len(pickle_dirs), pickle_dirs)
|
||||
# Carry on if there's one dir, otherwise do nothing.
|
||||
if not pickle_dirs:
|
||||
return
|
||||
|
||||
tf.io.gfile.makedirs(output_dir)
|
||||
one_dir_pickle_files = tf.io.gfile.glob(
|
||||
os.path.join(pickle_dirs[0], '*.pickle'))
|
||||
assert one_dir_pickle_files, pickle_dirs[0]
|
||||
original_files = len(one_dir_pickle_files)
|
||||
logging.info('Found %d files %d in first of %d dirs',
|
||||
original_files, len(one_dir_pickle_files), len(pickle_dirs))
|
||||
targets = [os.path.splitext(os.path.basename(f))[0]
|
||||
for f in one_dir_pickle_files]
|
||||
skipped = 0
|
||||
wrote = 0
|
||||
for t in targets:
|
||||
dump_file = os.path.join(output_dir, t + '.pickle')
|
||||
pickle_files = [os.path.join(pickle_dir, t + '.pickle')
|
||||
for pickle_dir in pickle_dirs]
|
||||
_, new_dict = ensemble_one_distance_histogram(pickle_files, weights)
|
||||
if new_dict is not None:
|
||||
wrote += 1
|
||||
distogram_io.save_distance_histogram_from_dict(dump_file, new_dict)
|
||||
msg = 'Distograms Wrote %s %d / %d Skipped %d %s' % (
|
||||
t, wrote, len(one_dir_pickle_files), skipped, dump_file)
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
def ensemble_one_distance_histogram(pickle_files, weights):
|
||||
"""Average the given pickle_files and dump."""
|
||||
dicts = []
|
||||
sequence = None
|
||||
max_dim = None
|
||||
for picklefile in pickle_files:
|
||||
if not tf.io.gfile.exists(picklefile):
|
||||
logging.warning('missing %s', picklefile)
|
||||
break
|
||||
logging.info('loading pickle file %s', picklefile)
|
||||
distance_histogram_dict = parsers.parse_distance_histogram_dict(picklefile)
|
||||
if sequence is None:
|
||||
sequence = distance_histogram_dict['sequence']
|
||||
else:
|
||||
assert sequence == distance_histogram_dict['sequence'], '%s vs %s' % (
|
||||
sequence, distance_histogram_dict['sequence'])
|
||||
dicts.append(distance_histogram_dict)
|
||||
assert dicts[-1]['probs'].shape[0] == dicts[-1]['probs'].shape[1], (
|
||||
'%d vs %d' % (dicts[-1]['probs'].shape[0], dicts[-1]['probs'].shape[1]))
|
||||
assert (dicts[0]['probs'].shape[0:2] == dicts[-1]['probs'].shape[0:2]
|
||||
), ('%d vs %d' % (dicts[0]['probs'].shape, dicts[-1]['probs'].shape))
|
||||
if max_dim is None or max_dim < dicts[-1]['probs'].shape[2]:
|
||||
max_dim = dicts[-1]['probs'].shape[2]
|
||||
if len(dicts) != len(pickle_files):
|
||||
logging.warning('length mismatch\n%s\nVS\n%s', dicts, pickle_files)
|
||||
return sequence, None
|
||||
ensemble_hist = (
|
||||
sum(w * c['probs'] for w, c in zip(weights, dicts)) / sum(weights))
|
||||
new_dict = dict(dicts[0])
|
||||
new_dict['probs'] = ensemble_hist
|
||||
return sequence, new_dict
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv # Unused.
|
||||
num_dirs = len(FLAGS.pickle_dirs)
|
||||
if FLAGS.weights:
|
||||
assert len(FLAGS.weights) == num_dirs, (
|
||||
'Supply as many weights as pickle_dirs, or no weights')
|
||||
weights = [float(w) for w in FLAGS.weights]
|
||||
else:
|
||||
weights = [1.0 for w in range(num_dirs)]
|
||||
|
||||
ensemble_distance_histograms(
|
||||
pickle_dirs=FLAGS.pickle_dirs,
|
||||
weights=weights,
|
||||
output_dir=FLAGS.output_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
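The ensembling above is a plain weighted average of replica distograms that share hyperparameters; a toy sketch with made-up probabilities and weights:

```python
import numpy as np

# Two replica distograms of identical shape (values invented for illustration).
probs_a = np.array([[[0.2, 0.8]]])
probs_b = np.array([[[0.6, 0.4]]])
weights = [1.0, 3.0]

# Same formula as ensemble_one_distance_histogram: weighted sum divided by total weight.
ensemble = (weights[0] * probs_a + weights[1] * probs_b) / sum(weights)
# ensemble == [[[0.5, 0.5]]]; with equal weights this reduces to a plain mean.
```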
alphafold_casp13/parsers.py (new file, 67 lines)
@@ -0,0 +1,67 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Parsers for various standard biology or AlphaFold-specific formats."""
|
||||
|
||||
import pickle
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
def distance_histogram_dict(f):
|
||||
"""Parses distance histogram dict pickle.
|
||||
|
||||
Distance histograms are stored as pickles of dicts.
|
||||
|
||||
Write one of these with distogram_io.save_distance_histogram().
|
||||
|
||||
Args:
|
||||
f: File-like handle to distance histogram dict pickle.
|
||||
|
||||
Returns:
|
||||
Dict with fields:
|
||||
probs: (an L x L x num_bins) histogram.
|
||||
num_bins: number of bins for each residue pair
|
||||
min_range: left hand edge of the distance histogram
|
||||
max_range: the extent of the histogram NOT the right hand edge.
|
||||
"""
|
||||
contact_dict = pickle.load(f, encoding='latin1')
|
||||
|
||||
num_res = len(contact_dict['sequence'])
|
||||
|
||||
if not all(key in contact_dict.keys()
|
||||
for key in ['probs', 'num_bins', 'min_range', 'max_range']):
|
||||
raise ValueError('The pickled contact dict doesn\'t contain all required '
|
||||
'keys: probs, num_bins, min_range, max_range but %s.' %
|
||||
contact_dict.keys())
|
||||
if contact_dict['probs'].ndim != 3:
|
||||
raise ValueError(
|
||||
'Probs is not rank 3 but %d' % contact_dict['probs'].ndim)
|
||||
if contact_dict['num_bins'] != contact_dict['probs'].shape[2]:
|
||||
raise ValueError(
|
||||
'The probs shape doesn\'t match num_bins in the third dimension. '
|
||||
'Expected %d got %d.' % (contact_dict['num_bins'],
|
||||
contact_dict['probs'].shape[2]))
|
||||
if contact_dict['probs'].shape[:2] != (num_res, num_res):
|
||||
raise ValueError(
|
||||
'The first two probs dims (%i, %i) aren\'t equal to len(sequence) %i'
|
||||
% (contact_dict['probs'].shape[0], contact_dict['probs'].shape[1],
|
||||
num_res))
|
||||
return contact_dict
|
||||
|
||||
|
||||
def parse_distance_histogram_dict(filepath):
|
||||
"""Parses distance histogram piclkle from filepath."""
|
||||
with tf.io.gfile.GFile(filepath, 'rb') as f:
|
||||
return distance_histogram_dict(f)
|
||||
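A hedged example of reading one of the released CASP13 distogram pickles with the parser above; the path is an assumption based on the dataset layout in the README (T1019s2 is the target used by run_eval.sh), not a file added in this commit:

```python
from alphafold_casp13 import parsers

# Hypothetical path; adjust to wherever the CASP13 data zip was extracted.
dh = parsers.parse_distance_histogram_dict('T1019s2/contacts/T1019s2.pickle')
print(dh['probs'].shape)                                 # (L, L, num_bins)
print(dh['min_range'], dh['max_range'], dh['num_bins'])  # histogram hyperparameters
```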
alphafold_casp13/paste_contact_maps.py (new file, 201 lines)
@@ -0,0 +1,201 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Combines predictions by pasting."""
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import six
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from alphafold_casp13 import distogram_io
|
||||
from alphafold_casp13 import parsers
|
||||
|
||||
flags.DEFINE_string("pickle_input_dir", None,
|
||||
"Directory to read pickle distance histogram files from.")
|
||||
flags.DEFINE_string("output_dir", None, "Directory to write chain RR files to.")
|
||||
flags.DEFINE_string("tfrecord_path", "",
|
||||
"If provided, construct the average weighted by number of "
|
||||
"alignments.")
|
||||
flags.DEFINE_string("crop_sizes", "64,128,256", "The crop sizes to use.")
|
||||
flags.DEFINE_integer("crop_step", 32, "The step size for cropping.")
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def generate_domains(target, sequence, crop_sizes, crop_step):
|
||||
"""Take fasta files and generate a domain definition for data generation."""
|
||||
logging.info("Generating crop domains for target %s", target)
|
||||
|
||||
windows = [int(x) for x in crop_sizes.split(",")]
|
||||
num_residues = len(sequence)
|
||||
domains = []
|
||||
domains.append({"name": target, "description": (1, num_residues)})
|
||||
|
||||
for window in windows:
|
||||
starts = list(range(0, num_residues - window, crop_step))
|
||||
# Append a last crop to ensure we get all the way to the end of the
|
||||
# sequence, even when num_residues - window is not divisible by crop_step.
|
||||
if num_residues >= window:
|
||||
starts += [num_residues - window]
|
||||
for start in starts:
|
||||
name = "%s-l%i_s%i" % (target, window, start)
|
||||
domains.append({"name": name, "description": (start + 1, start + window)})
|
||||
return domains
|
||||
|
||||
|
||||
def get_weights(path):
|
||||
"""Fetch all the weights from a TFRecord."""
|
||||
if not path:
|
||||
return {}
|
||||
logging.info("Getting weights from %s", path)
|
||||
weights = {}
|
||||
record_iterator = tf.python_io.tf_record_iterator(path=path)
|
||||
for serialized_tfexample in record_iterator:
|
||||
example = tf.train.Example()
|
||||
example.ParseFromString(serialized_tfexample)
|
||||
|
||||
domain_name = six.ensure_str(
|
||||
example.features.feature["domain_name"].bytes_list.value[0])
|
||||
weights[domain_name] = float(
|
||||
example.features.feature["num_alignments"].int64_list.value[0])
|
||||
logging.info("Weight %s: %d", domain_name, weights[domain_name])
|
||||
logging.info("Loaded %d weights", len(weights))
|
||||
return weights
|
||||
|
||||
|
||||
def paste_distance_histograms(
|
||||
input_dir, output_dir, weights, crop_sizes, crop_step):
|
||||
"""Paste together distograms for given domains of given targets and write.
|
||||
|
||||
Domain distance histograms are 'pasted', meaning they are substituted
|
||||
directly into the contact map. The order is determined by the order in the
|
||||
domain definition file.
|
||||
|
||||
Args:
|
||||
input_dir: String, path to directory containing chain and domain-level
|
||||
distogram files.
|
||||
output_dir: String, path to directory to write out chain-level distogram
|
||||
files.
|
||||
weights: A dictionary with weights.
|
||||
crop_sizes: The crop sizes.
|
||||
crop_step: The step size for cropping.
|
||||
|
||||
Raises:
|
||||
ValueError: if histogram parameters don't match.
|
||||
"""
|
||||
tf.io.gfile.makedirs(output_dir)
|
||||
|
||||
targets = tf.io.gfile.glob(os.path.join(input_dir, "*.pickle"))
|
||||
targets = [os.path.splitext(os.path.basename(t))[0] for t in targets]
|
||||
targets = set([t.split("-")[0] for t in targets])
|
||||
logging.info("Pasting distance histograms for %d targets", len(targets))
|
||||
|
||||
for target in sorted(targets):
|
||||
logging.info("%s as chain", target)
|
||||
|
||||
chain_pickle_path = os.path.join(input_dir, "%s.pickle" % target)
|
||||
distance_histogram_dict = parsers.parse_distance_histogram_dict(
|
||||
chain_pickle_path)
|
||||
|
||||
combined_cmap = np.array(distance_histogram_dict["probs"])
|
||||
# Make the counter map 1-deep but still rank 3.
|
||||
counter_map = np.ones_like(combined_cmap[:, :, 0:1])
|
||||
|
||||
sequence = distance_histogram_dict["sequence"]
|
||||
|
||||
target_domains = generate_domains(
|
||||
target=target, sequence=sequence, crop_sizes=crop_sizes,
|
||||
crop_step=crop_step)
|
||||
|
||||
# Paste in each domain.
|
||||
for domain in sorted(target_domains, key=lambda x: x["name"]):
|
||||
if domain["name"] == target:
|
||||
logging.info("Skipping %s as domain", target)
|
||||
continue
|
||||
|
||||
if "," in domain["description"]:
|
||||
logging.info("Skipping multisegment domain %s",
|
||||
domain["name"])
|
||||
continue
|
||||
|
||||
crop_start, crop_end = domain["description"]
|
||||
|
||||
domain_pickle_path = os.path.join(input_dir, "%s.pickle" % domain["name"])
|
||||
|
||||
weight = weights.get(domain["name"], 1e9)
|
||||
|
||||
logging.info("Pasting %s: %d-%d. weight: %f", domain_pickle_path,
|
||||
crop_start, crop_end, weight)
|
||||
|
||||
domain_distance_histogram_dict = parsers.parse_distance_histogram_dict(
|
||||
domain_pickle_path)
|
||||
for field in ["num_bins", "min_range", "max_range"]:
|
||||
if domain_distance_histogram_dict[field] != distance_histogram_dict[
|
||||
field]:
|
||||
raise ValueError("Field {} does not match {} {}".format(
|
||||
field,
|
||||
domain_distance_histogram_dict[field],
|
||||
distance_histogram_dict[field]))
|
||||
weight_matrix_size = crop_end - crop_start + 1
|
||||
weight_matrix = np.ones(
|
||||
(weight_matrix_size, weight_matrix_size), dtype=np.float32) * weight
|
||||
combined_cmap[crop_start - 1:crop_end, crop_start - 1:crop_end, :] += (
|
||||
domain_distance_histogram_dict["probs"] *
|
||||
np.expand_dims(weight_matrix, 2))
|
||||
counter_map[crop_start - 1:crop_end,
|
||||
crop_start - 1:crop_end, 0] += weight_matrix
|
||||
|
||||
# Broadcast across the histogram bins.
|
||||
combined_cmap /= counter_map
|
||||
|
||||
# Write out full-chain cmap for folding.
|
||||
output_chain_pickle_path = os.path.join(output_dir,
|
||||
"{}.pickle".format(target))
|
||||
|
||||
logging.info("Writing to %s", output_chain_pickle_path)
|
||||
|
||||
distance_histogram_dict["probs"] = combined_cmap
|
||||
distance_histogram_dict["target"] = target
|
||||
|
||||
# Save the distogram pickle file.
|
||||
distogram_io.save_distance_histogram_from_dict(
|
||||
output_chain_pickle_path, distance_histogram_dict)
|
||||
|
||||
# Compute the contact map and save it as an RR file.
|
||||
contact_probs = distogram_io.contact_map_from_distogram(
|
||||
distance_histogram_dict)
|
||||
rr_path = os.path.join(output_dir, "%s.rr" % target)
|
||||
distogram_io.save_rr_file(
|
||||
filename=rr_path,
|
||||
probs=contact_probs,
|
||||
domain=target,
|
||||
sequence=distance_histogram_dict["sequence"])
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv # Unused.
|
||||
flags.mark_flag_as_required("pickle_input_dir")
|
||||
|
||||
weights = get_weights(FLAGS.tfrecord_path)
|
||||
|
||||
paste_distance_histograms(
|
||||
FLAGS.pickle_input_dir, FLAGS.output_dir, weights, FLAGS.crop_sizes,
|
||||
FLAGS.crop_step)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
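For intuition, a sketch of the crop domains that `generate_domains` above produces for a toy 100-residue sequence; the target name and sequence are made up for this example:

```python
from alphafold_casp13 import paste_contact_maps

# Toy arguments; 'T1000' is a hypothetical target name, not a real CASP13 target here.
domains = paste_contact_maps.generate_domains(
    target='T1000', sequence='A' * 100, crop_sizes='64', crop_step=32)
# -> [{'name': 'T1000',         'description': (1, 100)},
#     {'name': 'T1000-l64_s0',  'description': (1, 64)},
#     {'name': 'T1000-l64_s32', 'description': (33, 96)},
#     {'name': 'T1000-l64_s36', 'description': (37, 100)}]
```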
alphafold_casp13/requirements.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
absl-py>=0.8.0
numpy>=1.16
six>=1.12
dm-sonnet>=1.35
tensorflow==1.14
tensorflow-probability==0.7.0
alphafold_casp13/run_eval.sh (new executable file, 102 lines)
@@ -0,0 +1,102 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set -e
|
||||
|
||||
# We assume the script is being run from the deepmind_research parent directory.
|
||||
|
||||
DISTOGRAM_MODEL="alphafold_casp13/873731" # Path to the directory with the distogram model.
|
||||
BACKGROUND_MODEL="alphafold_casp13/916425" # Path to the directory with the background model.
|
||||
TORSION_MODEL="alphafold_casp13/941521" # Path to the directory with the torsion model.
|
||||
|
||||
TARGET="T1019s2" # The name of the target.
|
||||
TARGET_PATH="alphafold_casp13/${TARGET}" # Path to the directory with the target input data.
|
||||
|
||||
# Set up the virtual environment and install dependencies.
|
||||
python3 -m venv alphafold_venv
|
||||
source alphafold_venv/bin/activate
|
||||
pip install -r alphafold_casp13/requirements.txt
|
||||
|
||||
# Create the output directory.
|
||||
OUTPUT_DIR="${HOME}/contacts_${TARGET}_$(date +%Y_%m_%d_%H_%M_%S)"
|
||||
mkdir -p "${OUTPUT_DIR}"
|
||||
echo "Saving output to ${OUTPUT_DIR}/"
|
||||
|
||||
# Run contact prediction over 4 replicas.
|
||||
for replica in 0 1 2 3; do
|
||||
echo "Launching all models for replica ${replica}"
|
||||
|
||||
# Run the distogram model.
|
||||
python3 -m alphafold_casp13.contacts \
|
||||
--logtostderr \
|
||||
--cpu=true \
|
||||
--config_path="${DISTOGRAM_MODEL}/${replica}/config.json" \
|
||||
--checkpoint_path="${DISTOGRAM_MODEL}/${replica}/tf_graph_data/tf_graph_data.ckpt" \
|
||||
--output_path="${OUTPUT_DIR}/distogram/${replica}" \
|
||||
--eval_sstable="${TARGET_PATH}/${TARGET}.tfrec" \
|
||||
--stats_file="${DISTOGRAM_MODEL}/stats_train_s35.json" &
|
||||
|
||||
# Run the background model.
|
||||
python3 -m alphafold_casp13.contacts \
|
||||
--logtostderr \
|
||||
--cpu=true \
|
||||
--config_path="${BACKGROUND_MODEL}/${replica}/config.json" \
|
||||
--checkpoint_path="${BACKGROUND_MODEL}/${replica}/tf_graph_data/tf_graph_data.ckpt" \
|
||||
--output_path="${OUTPUT_DIR}/background_distogram/${replica}" \
|
||||
--eval_sstable="${TARGET_PATH}/${TARGET}.tfrec" \
|
||||
--stats_file="${BACKGROUND_MODEL}/stats_train_s35.json" &
|
||||
done
|
||||
|
||||
# Run the torsion model, but only 1 replica.
|
||||
python3 -m alphafold_casp13.contacts \
|
||||
--logtostderr \
|
||||
--cpu=true \
|
||||
--config_path="${TORSION_MODEL}/0/config.json" \
|
||||
--checkpoint_path="${TORSION_MODEL}/0/tf_graph_data/tf_graph_data.ckpt" \
|
||||
--output_path="${OUTPUT_DIR}/torsion/0" \
|
||||
--eval_sstable="${TARGET_PATH}/${TARGET}.tfrec" \
|
||||
--stats_file="${TORSION_MODEL}/stats_train_s35.json" &
|
||||
|
||||
echo "All models running, waiting for them to complete"
|
||||
wait
|
||||
|
||||
echo "Ensembling all replica outputs"
|
||||
|
||||
# Run the ensembling jobs for distograms, background distograms.
|
||||
for output_dir in "${OUTPUT_DIR}/distogram" "${OUTPUT_DIR}/background_distogram"; do
|
||||
pickle_dirs="${output_dir}/0/pickle_files/,${output_dir}/1/pickle_files/,${output_dir}/2/pickle_files/,${output_dir}/3/pickle_files/"
|
||||
|
||||
# Ensemble distograms.
|
||||
python3 -m alphafold_casp13.ensemble_contact_maps \
|
||||
--logtostderr \
|
||||
--pickle_dirs="${pickle_dirs}" \
|
||||
--output_dir="${output_dir}/ensemble/"
|
||||
done
|
||||
|
||||
# Only ensemble single replica distogram for torsions.
|
||||
python3 -m alphafold_casp13.ensemble_contact_maps \
|
||||
--logtostderr \
|
||||
--pickle_dirs="${OUTPUT_DIR}/torsion/0/pickle_files/" \
|
||||
--output_dir="${OUTPUT_DIR}/torsion/ensemble/"
|
||||
|
||||
echo "Pasting contact maps"
|
||||
|
||||
python3 -m alphafold_casp13.paste_contact_maps \
|
||||
--logtostderr \
|
||||
--pickle_input_dir="${OUTPUT_DIR}/distogram/ensemble/" \
|
||||
--output_dir="${OUTPUT_DIR}/pasted/" \
|
||||
--tfrecord_path="${TARGET_PATH}/${TARGET}.tfrec"
|
||||
|
||||
echo "Done"
|
||||
alphafold_casp13/secstruct.py (new file, 94 lines)
@@ -0,0 +1,94 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Layer for modelling and scoring secondary structure."""
|
||||
|
||||
import os
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.contrib import layers as contrib_layers
|
||||
|
||||
|
||||
# 8-state secondary structure classes (Q8).
SECONDARY_STRUCTURES = '-HETSGBI'

# Equivalence classes for 3-class (Q3) from Li & Yu 2016.
# See http://www.cmbi.ru.nl/dssp.html for letter explanations.
Q3_MAP = ['-TSGIB', 'H', 'E']
|
||||
|
||||
|
||||
def make_q3_matrices():
|
||||
"""Generate mapping matrices for secstruct Q8:Q3 equivalence classes."""
|
||||
dimension = len(SECONDARY_STRUCTURES)
|
||||
q3_map_matrix = np.zeros((dimension, len(Q3_MAP)))
|
||||
q3_lookup = np.zeros((dimension,), dtype=np.int32)
|
||||
for i, eclass in enumerate(Q3_MAP): # equivalence classes
|
||||
for m in eclass: # Members of the class.
|
||||
ss_type = SECONDARY_STRUCTURES.index(m)
|
||||
q3_map_matrix[ss_type, i] = 1.0
|
||||
q3_lookup[ss_type] = i
|
||||
return q3_map_matrix, q3_lookup
|
||||
|
||||
|
||||
class Secstruct(object):
|
||||
"""Make a layer that computes hierarchical secstruct."""
|
||||
# Build static, shared structures:
|
||||
q3_map_matrix, q3_lookup = make_q3_matrices()
|
||||
static_dimension = len(SECONDARY_STRUCTURES)
|
||||
|
||||
def __init__(self, name='secstruct'):
|
||||
self.name = name
|
||||
self._dimension = Secstruct.static_dimension
|
||||
|
||||
def make_layer_new(self, activations):
|
||||
"""Make the layer."""
|
||||
with tf.compat.v1.variable_scope(self.name, reuse=tf.AUTO_REUSE):
|
||||
logging.info('Creating secstruct %s', activations)
|
||||
self.logits = contrib_layers.linear(activations, self._dimension)
|
||||
self.ss_q8_probs = tf.nn.softmax(self.logits)
|
||||
self.ss_q3_probs = tf.matmul(
|
||||
self.ss_q8_probs, tf.constant(self.q3_map_matrix, dtype=tf.float32))
|
||||
|
||||
def get_q8_probs(self):
|
||||
return self.ss_q8_probs
|
||||
|
||||
|
||||
def save_secstructs(dump_dir_path, name, index, sequence, probs,
|
||||
label='Deepmind secstruct'):
|
||||
"""Write secstruct prob distributions to an ss2 file.
|
||||
|
||||
Can be overloaded to write out asa values too.
|
||||
|
||||
Args:
|
||||
dump_dir_path: directory where to write files.
|
||||
name: name of domain
|
||||
index: index number of multiple samples. (or None for no index)
|
||||
sequence: string of L residue labels
|
||||
probs: L x D matrix of probabilities. L is length of sequence,
|
||||
D is probability dimension (usually 3).
|
||||
label: A label for the file.
|
||||
"""
|
||||
filename = os.path.join(dump_dir_path, '%s.ss2' % name)
|
||||
if index is not None:
|
||||
filename = os.path.join(dump_dir_path, '%s_%04d.ss2' % (name, index))
|
||||
with tf.io.gfile.GFile(filename, 'w') as gf:
|
||||
logging.info('Saving secstruct to %s', filename)
|
||||
gf.write('# %s CLASSES [%s] %s sample %s\n\n' % (
|
||||
label, ''.join(SECONDARY_STRUCTURES[:probs.shape[1]]), name, index))
|
||||
for l in range(probs.shape[0]):
|
||||
ss = SECONDARY_STRUCTURES[np.argmax(probs[l, :])]
|
||||
gf.write('%4d %1s %1s %s\n' % (l + 1, sequence[l], ss, ''.join(
|
||||
[('%6.3f' % p) for p in probs[l, :]])))
|
||||
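A short sketch of how the Q8 to Q3 mapping built by `make_q3_matrices` above collapses the eight DSSP states into the three classes `['-TSGIB', 'H', 'E']`; the probabilities are made up:

```python
import numpy as np
from alphafold_casp13 import secstruct

q3_map_matrix, q3_lookup = secstruct.make_q3_matrices()

# A one-residue Q8 distribution with all mass on helix ('H'), purely for illustration.
q8 = np.zeros((1, len(secstruct.SECONDARY_STRUCTURES)), dtype=np.float32)
q8[0, secstruct.SECONDARY_STRUCTURES.index('H')] = 1.0

q3 = q8 @ q3_map_matrix
# q3 == [[0., 1., 0.]]: all of the mass lands on the helix class, mirroring the
# tf.matmul in Secstruct.make_layer_new.
```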
alphafold_casp13/test_domains.txt (new file, 1820 lines; diff not shown because the file is too large)
alphafold_casp13/train_domains.txt (new file, 29426 lines; diff not shown because the file is too large)
alphafold_casp13/two_dim_convnet.py (new file, 139 lines)
@@ -0,0 +1,139 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""One dim convnet for angles prediction."""
|
||||
|
||||
from absl import logging
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.contrib import layers as contrib_layers
|
||||
|
||||
|
||||
def weight_variable(shape, stddev=0.01):
|
||||
"""Returns the weight variable."""
|
||||
logging.vlog(1, 'weight init for shape %s', str(shape))
|
||||
return tf.compat.v1.get_variable(
|
||||
'w', shape, initializer=tf.random_normal_initializer(stddev=stddev))
|
||||
|
||||
|
||||
def bias_variable(shape):
|
||||
return tf.compat.v1.get_variable(
|
||||
'b', shape, initializer=tf.zeros_initializer())
|
||||
|
||||
|
||||
def conv2d(x, w, atrou_rate=1, data_format='NHWC'):
|
||||
if atrou_rate > 1:
|
||||
return tf.nn.convolution(
|
||||
x,
|
||||
w,
|
||||
dilation_rate=[atrou_rate] * 2,
|
||||
padding='SAME',
|
||||
data_format=data_format)
|
||||
else:
|
||||
return tf.nn.conv2d(
|
||||
x, w, strides=[1, 1, 1, 1], padding='SAME', data_format=data_format)
|
||||
|
||||
|
||||
def make_conv_sep2d_layer(input_node,
|
||||
in_channels,
|
||||
channel_multiplier,
|
||||
out_channels,
|
||||
layer_name,
|
||||
filter_size,
|
||||
filter_size_2=None,
|
||||
batch_norm=False,
|
||||
is_training=True,
|
||||
atrou_rate=1,
|
||||
data_format='NHWC',
|
||||
stddev=0.01):
|
||||
"""Use separable convolutions."""
|
||||
if filter_size_2 is None:
|
||||
filter_size_2 = filter_size
|
||||
logging.vlog(1, 'layer %s in %d out %d chan mult %d', layer_name, in_channels,
|
||||
out_channels, channel_multiplier)
|
||||
with tf.compat.v1.variable_scope(layer_name):
|
||||
with tf.compat.v1.variable_scope('depthwise'):
|
||||
w_depthwise = weight_variable(
|
||||
[filter_size, filter_size_2, in_channels, channel_multiplier],
|
||||
stddev=stddev)
|
||||
with tf.compat.v1.variable_scope('pointwise'):
|
||||
w_pointwise = weight_variable(
|
||||
[1, 1, in_channels * channel_multiplier, out_channels], stddev=stddev)
|
||||
h_conv = tf.nn.separable_conv2d(
|
||||
input_node,
|
||||
w_depthwise,
|
||||
w_pointwise,
|
||||
padding='SAME',
|
||||
strides=[1, 1, 1, 1],
|
||||
rate=[atrou_rate, atrou_rate],
|
||||
data_format=data_format)
|
||||
|
||||
if batch_norm:
|
||||
h_conv = batch_norm_layer(
|
||||
h_conv, layer_name=layer_name, is_training=is_training,
|
||||
data_format=data_format)
|
||||
else:
|
||||
b_conv = bias_variable([out_channels])
|
||||
h_conv = tf.nn.bias_add(h_conv, b_conv, data_format=data_format)
|
||||
|
||||
return h_conv
|
||||
|
||||
|
||||
def batch_norm_layer(h_conv, layer_name, is_training=True, data_format='NCHW'):
|
||||
"""Batch norm layer."""
|
||||
logging.vlog(1, 'batch norm for layer %s', layer_name)
|
||||
return contrib_layers.batch_norm(
|
||||
h_conv,
|
||||
is_training=is_training,
|
||||
fused=True,
|
||||
decay=0.999,
|
||||
scope=layer_name,
|
||||
data_format=data_format)
|
||||
|
||||
|
||||
def make_conv_layer(input_node,
|
||||
in_channels,
|
||||
out_channels,
|
||||
layer_name,
|
||||
filter_size,
|
||||
filter_size_2=None,
|
||||
non_linearity=True,
|
||||
batch_norm=False,
|
||||
is_training=True,
|
||||
atrou_rate=1,
|
||||
data_format='NHWC',
|
||||
stddev=0.01):
|
||||
"""Creates a convolution layer."""
|
||||
|
||||
if filter_size_2 is None:
|
||||
filter_size_2 = filter_size
|
||||
logging.vlog(
|
||||
1, 'layer %s in %d out %d', layer_name, in_channels, out_channels)
|
||||
with tf.compat.v1.variable_scope(layer_name):
|
||||
w_conv = weight_variable(
|
||||
[filter_size, filter_size_2, in_channels, out_channels], stddev=stddev)
|
||||
h_conv = conv2d(
|
||||
input_node, w_conv, atrou_rate=atrou_rate, data_format=data_format)
|
||||
|
||||
if batch_norm:
|
||||
h_conv = batch_norm_layer(
|
||||
h_conv, layer_name=layer_name, is_training=is_training,
|
||||
data_format=data_format)
|
||||
else:
|
||||
b_conv = bias_variable([out_channels])
|
||||
h_conv = tf.nn.bias_add(h_conv, b_conv, data_format=data_format)
|
||||
|
||||
if non_linearity:
|
||||
h_conv = tf.nn.elu(h_conv)
|
||||
|
||||
return h_conv
|
||||
alphafold_casp13/two_dim_resnet.py (new file, 202 lines)
@@ -0,0 +1,202 @@
|
||||
# Lint as: python3.
|
||||
# Copyright 2019 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""2D Resnet."""
|
||||
|
||||
from absl import logging
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from alphafold_casp13 import two_dim_convnet
|
||||
|
||||
|
||||
def make_sep_res_layer(
    input_node,
    in_channels,
    out_channels,
    layer_name,
    filter_size,
    filter_size_2=None,
    batch_norm=False,
    is_training=True,
    divide_channels_by=2,
    atrou_rate=1,
    channel_multiplier=0,
    data_format='NHWC',
    stddev=0.01,
    dropout_keep_prob=1.0):
  """A separable resnet block."""

  with tf.name_scope(layer_name):
    input_times_almost_1 = input_node
    h_conv = input_times_almost_1

    if batch_norm:
      h_conv = two_dim_convnet.batch_norm_layer(
          h_conv, layer_name=layer_name, is_training=is_training,
          data_format=data_format)

    h_conv = tf.nn.elu(h_conv)

    if filter_size_2 is None:
      filter_size_2 = filter_size

    # 1x1 with half size
    h_conv = two_dim_convnet.make_conv_layer(
        h_conv,
        in_channels=in_channels,
        out_channels=in_channels / divide_channels_by,
        layer_name=layer_name + '_1x1h',
        filter_size=1,
        filter_size_2=1,
        non_linearity=True,
        batch_norm=batch_norm,
        is_training=is_training,
        data_format=data_format,
        stddev=stddev)

    # 3x3 with half size
    if channel_multiplier == 0:
      h_conv = two_dim_convnet.make_conv_layer(
          h_conv,
          in_channels=in_channels / divide_channels_by,
          out_channels=in_channels / divide_channels_by,
          layer_name=layer_name + '_%dx%dh' % (filter_size, filter_size_2),
          filter_size=filter_size,
          filter_size_2=filter_size_2,
          non_linearity=True,
          batch_norm=batch_norm,
          is_training=is_training,
          atrou_rate=atrou_rate,
          data_format=data_format,
          stddev=stddev)
    else:
      # We use separable convolution for 3x3
      h_conv = two_dim_convnet.make_conv_sep2d_layer(
          h_conv,
          in_channels=in_channels / divide_channels_by,
          channel_multiplier=channel_multiplier,
          out_channels=in_channels / divide_channels_by,
          layer_name=layer_name + '_sep%dx%dh' % (filter_size, filter_size_2),
          filter_size=filter_size,
          filter_size_2=filter_size_2,
          batch_norm=batch_norm,
          is_training=is_training,
          atrou_rate=atrou_rate,
          data_format=data_format,
          stddev=stddev)

    # 1x1 back to normal size without relu
    h_conv = two_dim_convnet.make_conv_layer(
        h_conv,
        in_channels=in_channels / divide_channels_by,
        out_channels=out_channels,
        layer_name=layer_name + '_1x1',
        filter_size=1,
        filter_size_2=1,
        non_linearity=False,
        batch_norm=False,
        is_training=is_training,
        data_format=data_format,
        stddev=stddev)

    if dropout_keep_prob < 1.0:
      logging.info('dropout keep prob %f', dropout_keep_prob)
      h_conv = tf.nn.dropout(h_conv, keep_prob=dropout_keep_prob)

    return h_conv + input_times_almost_1


def make_two_dim_resnet(
    input_node,
    num_residues=50,
    num_features=40,
    num_predictions=1,
    num_channels=32,
    num_layers=2,
    filter_size=3,
    filter_size_2=None,
    final_non_linearity=False,
    name_prefix='',
    fancy=True,
    batch_norm=False,
    is_training=False,
    atrou_rates=None,
    channel_multiplier=0,
    divide_channels_by=2,
    resize_features_with_1x1=False,
    data_format='NHWC',
    stddev=0.01,
    dropout_keep_prob=1.0):
  """Two dim resnet towers."""
  del num_residues  # Unused.

  if atrou_rates is None:
    atrou_rates = [1]
  if not fancy:
    raise ValueError('non fancy deprecated')

  logging.info('atrou rates %s', atrou_rates)

  logging.info('name prefix %s', name_prefix)
  x_image = input_node
  previous_layer = x_image
  non_linearity = True
  for i_layer in range(num_layers):
    in_channels = num_channels
    out_channels = num_channels

    curr_atrou_rate = atrou_rates[i_layer % len(atrou_rates)]

    if i_layer == 0:
      in_channels = num_features
    if i_layer == num_layers - 1:
      out_channels = num_predictions
      non_linearity = final_non_linearity
    if i_layer == 0 or i_layer == num_layers - 1:
      layer_name = name_prefix + 'conv%d' % (i_layer + 1)
      initial_filter_size = filter_size
      if resize_features_with_1x1:
        initial_filter_size = 1
      previous_layer = two_dim_convnet.make_conv_layer(
          input_node=previous_layer,
          in_channels=in_channels,
          out_channels=out_channels,
          layer_name=layer_name,
          filter_size=initial_filter_size,
          filter_size_2=filter_size_2,
          non_linearity=non_linearity,
          atrou_rate=curr_atrou_rate,
          data_format=data_format,
          stddev=stddev)
    else:
      layer_name = name_prefix + 'res%d' % (i_layer + 1)
      previous_layer = make_sep_res_layer(
          input_node=previous_layer,
          in_channels=in_channels,
          out_channels=out_channels,
          layer_name=layer_name,
          filter_size=filter_size,
          filter_size_2=filter_size_2,
          batch_norm=batch_norm,
          is_training=is_training,
          atrou_rate=curr_atrou_rate,
          channel_multiplier=channel_multiplier,
          divide_channels_by=divide_channels_by,
          data_format=data_format,
          stddev=stddev,
          dropout_keep_prob=dropout_keep_prob)

  y = previous_layer

  return y
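For orientation, here is a minimal sketch of how this tower might be driven. The crop size, channel counts and layer count below are illustrative assumptions, not values taken from the AlphaFold configuration:

```python
import tensorflow.compat.v1 as tf

from alphafold_casp13 import two_dim_resnet

# Illustrative sizes only: one 64x64 residue-pair crop with 40 input feature
# channels, laid out NHWC as the functions above expect.
crop_size = 64
num_features = 40

features = tf.placeholder(
    tf.float32, shape=[1, crop_size, crop_size, num_features])

# With num_layers=2 the tower is just an input and an output convolution;
# larger values insert make_sep_res_layer blocks between them.
logits = two_dim_resnet.make_two_dim_resnet(
    input_node=features,
    num_features=num_features,
    num_predictions=1,
    num_channels=32,
    num_layers=2,
    atrou_rates=[1, 2, 4],
    is_training=False)
```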
@@ -17,7 +17,7 @@ from __future__ import division
 from __future__ import print_function

 import sonnet as snt
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 from scratchgan import utils

@@ -16,7 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 import tensorflow_gan as tfgan
 import tensorflow_hub as hub

@@ -22,8 +22,8 @@ from absl import app
 from absl import flags
 from absl import logging
 import numpy as np
-import tensorflow as tf
-from tensorflow.io import gfile
+import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1.io import gfile
 from scratchgan import discriminator_nets
 from scratchgan import eval_metrics
 from scratchgan import generators

@@ -18,7 +18,7 @@ from __future__ import print_function

 from absl import logging
 import sonnet as snt
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 import tensorflow_probability as tfp
 from scratchgan import utils

@@ -17,7 +17,7 @@ from __future__ import division
 from __future__ import print_function

 import numpy as np
-import tensorflow as tf
+import tensorflow.compat.v1 as tf


 def sequential_cross_entropy_loss(logits, expected):

@@ -22,7 +22,7 @@ import os

 from absl import logging
 import numpy as np
-from tensorflow.io import gfile
+from tensorflow.compat.v1.io import gfile

 # sequences: [N, MAX_TOKENS_SEQUENCE] array of int32
 # lengths: [N, 2] array of int32, such that

@@ -19,8 +19,8 @@ import math
 import os
 from absl import logging
 import numpy as np
-import tensorflow as tf
-from tensorflow.io import gfile
+import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1.io import gfile
 from scratchgan import reader

 EVAL_FILENAME = "evaluated_checkpoints.csv"
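These hunks all make the same substitution: the bare `import tensorflow as tf` (and `tensorflow.io.gfile`) becomes the explicit `tensorflow.compat.v1` spelling, so each module states up front that it targets the v1 API. A minimal sketch of the idiom under a recent TensorFlow 1.x release; the toy graph at the end is illustrative only:

```python
# Before (implicitly v1):
#   import tensorflow as tf
#   from tensorflow.io import gfile
# After (explicitly v1 via the compat shim):
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.io import gfile  # imported only to mirror the hunks above

# Existing v1-style code continues to work unchanged, e.g. graph mode + Session:
logits = tf.zeros([2, 3])
with tf.Session() as sess:
  print(sess.run(logits))
```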