Explicitly replace "import tensorflow" with "tensorflow.compat.v1"

PiperOrigin-RevId: 287568660
Deepmind Team
2019-12-30 18:35:38 +00:00
committed by Diego de Las Casas
parent 5b22f16532
commit 8b47d2a576
25 changed files with 34028 additions and 9 deletions

alphafold_casp13/README.md Normal file

@@ -0,0 +1,151 @@
# AlphaFold
This package provides an implementation of the contact prediction network,
associated model weights and CASP13 dataset as published in Nature.
Any publication that discloses findings arising from using this source code must
cite *AlphaFold: Protein structure prediction using potentials from deep
learning* by Andrew W. Senior, Richard Evans, John Jumper, James Kirkpatrick,
Laurent Sifre, Tim Green, Chongli Qin, Augustin Žídek, Alexander W. R. Nelson,
Alex Bridgland, Hugo Penedones, Stig Petersen, Karen Simonyan, Steve Crossan,
Pushmeet Kohli, David T. Jones, David Silver, Koray Kavukcuoglu, Demis Hassabis.
## Setup
### Dependencies
* Python 3.6+.
* [Abseil 0.8.0+](https://github.com/abseil/abseil-py)
* [Numpy 1.16+](https://numpy.org)
* [Six 1.12+](https://pypi.org/project/six/)
* [Sonnet 1.35+](https://github.com/deepmind/sonnet)
* [TensorFlow 1.14](https://tensorflow.org). Not compatible with TensorFlow
2.0+.
* [TensorFlow Probability 0.7.0](https://www.tensorflow.org/probability)
You can set up a Python virtual environment with these dependencies inside the
forked `deepmind_research` repository using:
```shell
python3 -m venv alphafold_venv
source alphafold_venv/bin/activate
pip install -r alphafold_casp13/requirements.txt
```
### Input data
The dataset can be downloaded from
[Google Cloud Storage](https://console.cloud.google.com/storage/browser/alphafold_casp13_data).
Download it e.g. using `wget`:
```shell
wget https://storage.googleapis.com/alphafold_casp13_data/casp13_data.zip
```
The zip file contains 1 directory for each CASP13 target and a `LICENSE.md`
file. Each target directory contains the following files:
1. `TARGET.tfrec` file. This is a
[TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file
with serialized tf.train.Example protocol buffers that contain the features
needed to run the model.
1. `contacts/TARGET.pickle` file(s) with the predicted distogram.
1. `contacts/TARGET.rr` file(s) with the contact map derived from the predicted
distogram. The RR format is described on the
[CASP website](http://predictioncenter.org/casp13/index.cgi?page=format#RR).
Note that for **T0999** the target was manually split based on hits in HHSearch
into 5 sub-targets, hence there are 5 distograms
(`contacts/T0999s{1,2,3,4,5}.pickle`) and 5 RR files
(`contacts/T0999s{1,2,3,4,5}.rr`).
The `contacts/` folder is not needed to run the model; these files are included
only for convenience so that you don't have to run inference on the CASP13
targets yourself to get the contact maps.
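If you just want to inspect one of the predicted distograms, the minimal sketch
below loads a pickle file and prints its contents. The target name in the path
is only an example, and the snippet makes no assumption about the exact layout
of the pickle:
```python
import pickle

# Hypothetical path; substitute any of the downloaded contacts/TARGET.pickle files.
with open('T0955/contacts/T0955.pickle', 'rb') as f:
  distogram = pickle.load(f)

# Print whatever fields the file contains, without assuming a fixed layout.
if isinstance(distogram, dict):
  for key, value in distogram.items():
    print(key, getattr(value, 'shape', value))
else:
  print(type(distogram))
```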
### Model checkpoints
The model checkpoints can be downloaded from
[Google Cloud Storage](https://console.cloud.google.com/storage/browser/alphafold_casp13_data).
Download them e.g. using `wget`:
```shell
wget https://storage.googleapis.com/alphafold_casp13_data/alphafold_casp13_weights.zip
```
The zip file contains:
1. A directory `873731`. This contains the weights for the distogram model.
1. A directory `916425`. This contains the weights for the background distogram
model.
1. A directory `941521`. This contains the weights for the torsion model.
1. `LICENSE.md`. The model checkpoints have a non-commercial license which is
defined in this file.
Each directory with model weights contains a number of different model
configurations. Each model has a config file and associated weights. There is
only one torsion model. Each model directory also contains a stats file that is
used for feature normalization specific to that model.
## Distogram prediction
### Running the system
You can use the `run_eval.sh` script to run the entire Distogram prediction
system. There are a few steps you need to complete first:
1. Download the input data as described above. Unpack the data in the
directory with the code.
1. Download the model checkpoints as described above. Unpack the data.
1. In `run_eval.sh` set the following:
* `DISTOGRAM_MODEL` to the path to the directory with the distogram model.
* `BACKGROUND_MODEL` to the path to the directory with the background
model.
* `TORSION_MODEL` to the path to the directory with the torsion model.
* `TARGET` to the path to the directory with the target input data.
Then run `alphafold_casp13/run_eval.sh` from the `deepmind_research` parent
directory (you will get errors if you try running `run_eval.sh` directly from
the `alphafold_casp13` directory).
The contact prediction works in the following way:
1. 4 replicas (by *replica* we mean a configuration file describing the network
architecture and a snapshot with the network weights), each with a slightly
different model configuration, are launched to predict the distogram.
1. 4 replicas, each with a slightly different model configuration, are launched
to predict the background distogram.
1. 1 replica is launched to predict the torsions.
1. The predictions from the different replicas are averaged together using
`ensemble_contact_maps.py` (the averaging is sketched below, after this list).
1. The predictions for the 64 × 64 distogram crops are pasted together using
`paste_contact_maps.py`.
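The averaging in step 4 amounts to a per-bin mean of the replica distograms. A
minimal sketch of that idea (this is not the `ensemble_contact_maps.py`
command-line interface):
```python
import numpy as np

def ensemble_distograms(replica_probs):
  """Averages per-replica distogram probabilities.

  Each element of `replica_probs` is an array of shape
  [num_res, num_res, num_bins] holding one replica's softmax over distance bins.
  """
  return np.mean(np.stack(replica_probs, axis=0), axis=0)
```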
When running `run_eval.sh` the output has the following directory structure:
* **distogram/**: Contains 4 subfolders, one for each replica. Each of these
contains the predicted ASA, secondary structure and a pickle file with the
distogram for each crop. It also contains an `ensemble` directory with the
ensembled distograms.
* **background_distogram/**: Contains 4 subfolders, one for each replica. Each
of these contains a pickle file with the background distogram for each crop.
It also contains an `ensemble` directory with the ensembled background
distograms.
* **torsion/**: Contains 1 subfolder as there was only a single replica. This
folder contains the predicted ASA, secondary structure, backbone torsions and
a pickle file with the distogram for each crop. It also contains an
`ensemble` directory with the ensembled torsions.
* **pasted/**: Contains distograms obtained from the ensembled distograms by
pasting. An RR contact map file is computed from this pasted distogram.
**This is the final distogram that was used in the subsequent AlphaFold
folding pipeline in CASP13.**
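The contact probabilities in the RR file are derived from the pasted distogram
by summing the probability mass in the distance bins below the 8 Angstrom
contact threshold (compare `quant_threshold` in the network code). A minimal
sketch, assuming `probs` holds the pasted distogram of shape
`[num_res, num_res, num_bins]` with bins evenly dividing the `min_range` to
`min_range + max_range` Angstrom interval:
```python
import numpy as np

def contact_probabilities(probs, min_range, max_range, threshold=8.0):
  """Sums the distogram probability mass below the contact threshold."""
  num_bins = probs.shape[-1]
  threshold_bin = int((threshold - min_range) * num_bins / float(max_range))
  return np.sum(probs[:, :, :threshold_bin], axis=-1)
```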
## Data splits
We used a version of [PDB](https://www.rcsb.org/) downloaded on 2018-03-15. The
train/test split can be found in the `train_domains.txt` and `test_domains.txt`
files.
Disclaimer: This is not an official Google product.


@@ -0,0 +1,36 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for predicting Accessible Surface Area."""
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
class ASAOutputLayer(object):
"""An output layer to predict Accessible Surface Area."""
def __init__(self, name='asa'):
self.name = name
def compute_asa_output(self, activations):
"""Just compute the logits and outputs given activations."""
asa_logits = contrib_layers.linear(
activations,
1,
weights_initializer=tf.random_uniform_initializer(-0.01, 0.01),
scope='ASALogits')
self.asa_output = tf.nn.relu(asa_logits, name='ASA_output_relu')
return asa_logits


@@ -0,0 +1,63 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for storing configuration flags."""
import json
class ConfigDict(dict):
"""Configuration dictionary with convenient dot element access."""
def __init__(self, *args, **kwargs):
super(ConfigDict, self).__init__(*args, **kwargs)
for arg in args:
if isinstance(arg, dict):
for key, value in arg.items():
self._add(key, value)
for key, value in kwargs.items():
self._add(key, value)
def _add(self, key, value):
if isinstance(value, dict):
self[key] = ConfigDict(value)
else:
self[key] = value
def __getattr__(self, attr):
try:
return self[attr]
except KeyError as e:
raise AttributeError(e)
def __setattr__(self, key, value):
self.__setitem__(key, value)
def __setitem__(self, key, value):
super(ConfigDict, self).__setitem__(key, value)
self.__dict__.update({key: value})
def __delattr__(self, item):
self.__delitem__(item)
def __delitem__(self, key):
super(ConfigDict, self).__delitem__(key)
del self.__dict__[key]
def to_json(self):
return json.dumps(self)
@classmethod
def from_json(cls, json_string):
return cls(json.loads(json_string))
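# Illustrative usage (not part of the original file):
#   cfg = ConfigDict({'eval_config': {'max_num_examples': 0}})
#   cfg.eval_config.max_num_examples     # -> 0, via dot access
#   ConfigDict.from_json(cfg.to_json())  # JSON round-trip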


@@ -0,0 +1,394 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to run distogram inference."""
import collections
import os
import time
from absl import app
from absl import flags
from absl import logging
import numpy as np
import six
import sonnet as snt
import tensorflow.compat.v1 as tf
from alphafold_casp13 import config_dict
from alphafold_casp13 import contacts_experiment
from alphafold_casp13 import distogram_io
from alphafold_casp13 import secstruct
flags.DEFINE_string('config_path', None, 'Path of the JSON config file.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path for evaluation.')
flags.DEFINE_boolean('cpu', False, 'Force onto CPU.')
flags.DEFINE_string('output_path', None,
'Base path where all output files will be saved to.')
flags.DEFINE_string('eval_sstable', None,
'Path of the SSTable to read the input tf.Examples from.')
flags.DEFINE_string('stats_file', None,
'Path of the statistics file to use for normalization.')
FLAGS = flags.FLAGS
# A named tuple to store the outputs of a single prediction run.
Prediction = collections.namedtuple(
'Prediction', [
'single_message', # A debugging message.
'num_crops_local', # The number of crops used to make this prediction.
'sequence', # The amino acid sequence.
'filebase', # The chain name. All output files will use this name.
'softmax_probs', # Softmax of the distogram.
'ss', # Secondary structure prediction.
'asa', # ASA prediction.
'torsions', # Torsion prediction.
])
def evaluate(crop_size_x, crop_size_y, feature_normalization, checkpoint_path,
normalization_exclusion, eval_config, network_config):
"""Main evaluation loop."""
experiment = contacts_experiment.Contacts(
tfrecord=eval_config.eval_sstable,
stats_file=eval_config.stats_file,
network_config=network_config,
crop_size_x=crop_size_x,
crop_size_y=crop_size_y,
feature_normalization=feature_normalization,
normalization_exclusion=normalization_exclusion)
checkpoint = snt.get_saver(experiment.model, collections=[
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,
tf.compat.v1.GraphKeys.MOVING_AVERAGE_VARIABLES])
with tf.compat.v1.train.SingularMonitoredSession(hooks=[]) as sess:
logging.info('Restoring from checkpoint %s', checkpoint_path)
checkpoint.restore(sess, checkpoint_path)
logging.info('Writing output to %s', eval_config.output_path)
eval_begin_time = time.time()
_run_evaluation(sess=sess,
experiment=experiment,
eval_config=eval_config,
output_dir=eval_config.output_path,
min_range=network_config.min_range,
max_range=network_config.max_range,
num_bins=network_config.num_bins,
torsion_bins=network_config.torsion_bins)
logging.info('Finished eval %.1fs', (time.time() - eval_begin_time))
def _run_evaluation(
sess, experiment, eval_config, output_dir, min_range, max_range, num_bins,
torsion_bins):
"""Evaluate a contact map by aggregating crops.
Args:
sess: A tf.train.Session.
experiment: An experiment class.
eval_config: A config dict of eval parameters.
output_dir: Directory to save the predictions to.
min_range: The minimum range in Angstroms to consider in distograms.
max_range: The maximum range in Angstroms to consider in distograms, see
num_bins below for clarification.
num_bins: The number of bins in the distance histogram being predicted.
We divide the min_range--(min_range + max_range) Angstrom range into this
many bins.
torsion_bins: The number of bins the torsion angles are discretised into.
"""
tf.io.gfile.makedirs(os.path.join(output_dir, 'pickle_files'))
logging.info('Eval config is %s\nnum_bins: %d', eval_config, num_bins)
num_examples = 0
num_crops = 0
start_all_time = time.time()
# Either do the whole test set, or up to a specified limit.
max_examples = experiment.num_eval_examples
if eval_config.max_num_examples > 0:
max_examples = min(max_examples, eval_config.max_num_examples)
while num_examples < max_examples:
one_prediction = compute_one_prediction(
num_examples, experiment, sess, eval_config, num_bins, torsion_bins)
single_message = one_prediction.single_message
num_crops_local = one_prediction.num_crops_local
sequence = one_prediction.sequence
filebase = one_prediction.filebase
softmax_probs = one_prediction.softmax_probs
ss = one_prediction.ss
asa = one_prediction.asa
torsions = one_prediction.torsions
num_examples += 1
num_crops += num_crops_local
# Save the output files.
filename = os.path.join(output_dir,
'pickle_files', '%s.pickle' % filebase)
distogram_io.save_distance_histogram(
filename, softmax_probs, filebase, sequence,
min_range=min_range, max_range=max_range, num_bins=num_bins)
if experiment.model.torsion_multiplier > 0:
torsions_dir = os.path.join(output_dir, 'torsions')
tf.io.gfile.makedirs(torsions_dir)
distogram_io.save_torsions(torsions_dir, filebase, sequence, torsions)
if experiment.model.secstruct_multiplier > 0:
ss_dir = os.path.join(output_dir, 'secstruct')
tf.io.gfile.makedirs(ss_dir)
secstruct.save_secstructs(ss_dir, filebase, None, sequence, ss)
if experiment.model.asa_multiplier > 0:
asa_dir = os.path.join(output_dir, 'asa')
tf.io.gfile.makedirs(asa_dir)
secstruct.save_secstructs(asa_dir, filebase, None, sequence,
np.expand_dims(asa, 1), label='Deepmind 2D ASA')
time_spent = time.time() - start_all_time
logging.info(
'Evaluate %d examples, %d crops %.1f crops/ex. '
'Took %.1fs, %.3f s/example %.3f crops/s\n%s',
num_examples, num_crops, num_crops / float(num_examples), time_spent,
time_spent / num_examples, num_crops / time_spent, single_message)
logging.info('Tested on %d', num_examples)
def compute_one_prediction(
num_examples, experiment, sess, eval_config, num_bins, torsion_bins):
"""Find the contact map for a single domain."""
num_crops_local = 0
debug_steps = 0
start = time.time()
output_fetches = {'probs': experiment.eval_probs}
output_fetches['softmax_probs'] = experiment.eval_probs_softmax
# Add the auxiliary outputs if present.
experiment.model.update_crop_fetches(output_fetches)
# Get data.
batch = experiment.get_one_example(sess)
length = batch['sequence_lengths'][0]
batch_size = batch['sequence_lengths'].shape[0]
domain = batch['domain_name'][0][0].decode('utf-8')
chain = batch['chain_name'][0][0].decode('utf-8')
filebase = domain or chain
sequence = six.ensure_str(batch['sequences'][0][0])
logging.info('SepWorking on %d %s %s %d', num_examples, domain, chain, length)
inputs_1d = batch['inputs_1d']
if 'residue_index' in batch:
logging.info('Getting residue_index from features')
residue_index = np.squeeze(
batch['residue_index'], axis=2).astype(np.int32)
else:
logging.info('Generating residue_index')
residue_index = np.tile(np.expand_dims(
np.arange(length, dtype=np.int32), 0), [batch_size, 1])
assert batch_size == 1
num_examples += batch_size
# Crops.
prob_accum = np.zeros((length, length, 2))
ss_accum = np.zeros((length, 8))
torsions_accum = np.zeros((length, torsion_bins**2))
asa_accum = np.zeros((length,))
weights_1d_accum = np.zeros((length,))
softmax_prob_accum = np.zeros((length, length, num_bins), dtype=np.float32)
crop_size_x = experiment.crop_size_x
crop_step_x = crop_size_x // eval_config.crop_shingle_x
crop_size_y = experiment.crop_size_y
crop_step_y = crop_size_y // eval_config.crop_shingle_y
prob_weights = 1
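# Optional "pyramid" weighting: the weights ramp up linearly from the crop
# edges towards the centre (capped at eval_config.pyramid_weights) so that
# overlapping crops blend smoothly when they are accumulated below.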
if eval_config.pyramid_weights > 0:
sx = np.expand_dims(np.linspace(1.0 / crop_size_x, 1, crop_size_x), 1)
sy = np.expand_dims(np.linspace(1.0 / crop_size_y, 1, crop_size_y), 0)
prob_weights = np.minimum(np.minimum(sx, np.flipud(sx)),
np.minimum(sy, np.fliplr(sy)))
prob_weights /= np.max(prob_weights)
prob_weights = np.minimum(prob_weights, eval_config.pyramid_weights)
logging.log_first_n(logging.INFO, 'Crop: %dx%d step %d,%d pyr %.2f',
debug_steps,
crop_size_x, crop_size_y,
crop_step_x, crop_step_y, eval_config.pyramid_weights)
# Accumulate all crops, starting and ending half off the square.
for i in range(-crop_size_x // 2, length - crop_size_x // 2, crop_step_x):
for j in range(-crop_size_y // 2, length - crop_size_y // 2, crop_step_y):
# The ideal crop.
patch = compute_one_patch(
sess, experiment, output_fetches, inputs_1d, residue_index,
prob_weights, batch, length, i, j, crop_size_x, crop_size_y)
# Assemble the crops into a final complete prediction.
ic = max(0, i)
jc = max(0, j)
ic_to = ic + patch['prob'].shape[1]
jc_to = jc + patch['prob'].shape[0]
prob_accum[jc:jc_to, ic:ic_to, 0] += patch['prob'] * patch['weight']
prob_accum[jc:jc_to, ic:ic_to, 1] += patch['weight']
softmax_prob_accum[jc:jc_to, ic:ic_to, :] += (
patch['softmax'] * np.expand_dims(patch['weight'], 2))
weights_1d_accum[jc:jc_to] += 1
weights_1d_accum[ic:ic_to] += 1
if 'asa_x' in patch:
asa_accum[ic:ic + patch['asa_x'].shape[0]] += np.squeeze(
patch['asa_x'], axis=1)
asa_accum[jc:jc + patch['asa_y'].shape[0]] += np.squeeze(
patch['asa_y'], axis=1)
if 'ss_x' in patch:
ss_accum[ic:ic + patch['ss_x'].shape[0]] += patch['ss_x']
ss_accum[jc:jc + patch['ss_y'].shape[0]] += patch['ss_y']
if 'torsions_x' in patch:
torsions_accum[
ic:ic + patch['torsions_x'].shape[0]] += patch['torsions_x']
torsions_accum[
jc:jc + patch['torsions_y'].shape[0]] += patch['torsions_y']
num_crops_local += 1
single_message = (
'Constructed %s len %d from %d chunks [%d, %d x %d, %d] '
'in %5.1fs' % (
filebase, length, num_crops_local,
crop_size_x, crop_step_x, crop_size_y, crop_step_y,
time.time() - start))
logging.info(single_message)
logging.info('prob_accum[:, :, 1]: %s', prob_accum[:, :, 1])
assert (prob_accum[:, :, 1] > 0.0).all()
probs = prob_accum[:, :, 0] / prob_accum[:, :, 1]
softmax_probs = softmax_prob_accum[:, :, :] / prob_accum[:, :, 1:2]
asa_accum /= weights_1d_accum
ss_accum /= np.expand_dims(weights_1d_accum, 1)
torsions_accum /= np.expand_dims(weights_1d_accum, 1)
# The probs are symmetrical.
probs = (probs + probs.transpose()) / 2
if num_bins > 1:
softmax_probs = (softmax_probs + np.transpose(
softmax_probs, axes=[1, 0, 2])) / 2
return Prediction(
single_message=single_message,
num_crops_local=num_crops_local,
sequence=sequence,
filebase=filebase,
softmax_probs=softmax_probs,
ss=ss_accum,
asa=asa_accum,
torsions=torsions_accum)
def compute_one_patch(sess, experiment, output_fetches, inputs_1d,
residue_index, prob_weights, batch, length, i, j,
crop_size_x, crop_size_y):
"""Compute the output predictions for a single crop."""
# Note that these are allowed to go off the end of the protein.
end_x = i + crop_size_x
end_y = j + crop_size_y
crop_limits = np.array([[i, end_x, j, end_y]], dtype=np.int32)
ic = max(0, i)
jc = max(0, j)
end_x_cropped = min(length, end_x)
end_y_cropped = min(length, end_y)
prepad_x = max(0, -i)
prepad_y = max(0, -j)
postpad_x = end_x - end_x_cropped
postpad_y = end_y - end_y_cropped
# Precrop the 2D features:
inputs_2d = np.pad(batch['inputs_2d'][
:, jc:end_y, ic:end_x, :],
[[0, 0],
[prepad_y, postpad_y],
[prepad_x, postpad_x],
[0, 0]], mode='constant')
assert inputs_2d.shape[1] == crop_size_y
assert inputs_2d.shape[2] == crop_size_x
# Generate the corresponding crop, but it might be truncated.
cxx = batch['inputs_2d'][:, ic:end_x, ic:end_x, :]
cyy = batch['inputs_2d'][:, jc:end_y, jc:end_y, :]
if cxx.shape[1] < inputs_2d.shape[1]:
cxx = np.pad(cxx, [[0, 0],
[prepad_x, max(0, i + crop_size_y - length)],
[prepad_x, postpad_x],
[0, 0]], mode='constant')
assert cxx.shape[1] == crop_size_y
assert cxx.shape[2] == crop_size_x
if cyy.shape[2] < inputs_2d.shape[2]:
cyy = np.pad(cyy, [[0, 0],
[prepad_y, postpad_y],
[prepad_y, max(0, j + crop_size_x - length)],
[0, 0]], mode='constant')
assert cyy.shape[1] == crop_size_y
assert cyy.shape[2] == crop_size_x
inputs_2d = np.concatenate([inputs_2d, cxx, cyy], 3)
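# The x-x and y-y diagonal crops are stacked as extra channel blocks; this
# matches the `feature_dim_2d *= 3` assumption in the experiment graph.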
output_results = sess.run(output_fetches, feed_dict={
experiment.inputs_1d_placeholder: inputs_1d,
experiment.residue_index_placeholder: residue_index,
experiment.inputs_2d_placeholder: inputs_2d,
experiment.crop_placeholder: crop_limits,
})
# Crop out the "live" region of the probs.
prob_patch = output_results['probs'][
0, prepad_y:crop_size_y - postpad_y,
prepad_x:crop_size_x - postpad_x]
weight_patch = prob_weights[prepad_y:crop_size_y - postpad_y,
prepad_x:crop_size_x - postpad_x]
patch = {'prob': prob_patch, 'weight': weight_patch}
if 'softmax_probs' in output_results:
patch['softmax'] = output_results['softmax_probs'][
0, prepad_y:crop_size_y - postpad_y,
prepad_x:crop_size_x - postpad_x]
if 'secstruct_probs' in output_results:
patch['ss_x'] = output_results['secstruct_probs'][
0, prepad_x:crop_size_x - postpad_x]
patch['ss_y'] = output_results['secstruct_probs'][
0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y]
if 'torsion_probs' in output_results:
patch['torsions_x'] = output_results['torsion_probs'][
0, prepad_x:crop_size_x - postpad_x]
patch['torsions_y'] = output_results['torsion_probs'][
0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y]
if 'asa_output' in output_results:
patch['asa_x'] = output_results['asa_output'][
0, prepad_x:crop_size_x - postpad_x]
patch['asa_y'] = output_results['asa_output'][
0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y]
return patch
def main(argv):
del argv # Unused.
logging.info('Loading a JSON config from: %s', FLAGS.config_path)
with tf.io.gfile.GFile(FLAGS.config_path, 'r') as f:
config = config_dict.ConfigDict.from_json(f.read())
# Redefine the relevant output fields.
if FLAGS.eval_sstable:
config.eval_config.eval_sstable = FLAGS.eval_sstable
if FLAGS.stats_file:
config.eval_config.stats_file = FLAGS.stats_file
if FLAGS.output_path:
config.eval_config.output_path = FLAGS.output_path
with tf.device('/cpu:0' if FLAGS.cpu else None):
evaluate(checkpoint_path=FLAGS.checkpoint_path, **config)
if __name__ == '__main__':
app.run(main)


@@ -0,0 +1,364 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF wrapper for protein tf.Example datasets."""
import collections
import enum
import json
import tensorflow.compat.v1 as tf
_ProteinDescription = collections.namedtuple(
'_ProteinDescription', (
'sequence_lengths', 'key', 'sequences', 'inputs_1d', 'inputs_2d',
'inputs_2d_diagonal', 'crops', 'scalars', 'targets'))
class FeatureType(enum.Enum):
ZERO_DIM = 0 # Shape [x]
ONE_DIM = 1 # Shape [num_res, x]
TWO_DIM = 2 # Shape [num_res, num_res, x]
# Placeholder values that will be replaced with their true value at runtime.
NUM_RES = 'num residues placeholder'
# Sizes of the protein features. NUM_RES is allowed as a placeholder to be
# replaced with the number of residues.
FEATURES = {
'aatype': (tf.float32, [NUM_RES, 21]),
'alpha_mask': (tf.int64, [NUM_RES, 1]),
'alpha_positions': (tf.float32, [NUM_RES, 3]),
'beta_mask': (tf.int64, [NUM_RES, 1]),
'beta_positions': (tf.float32, [NUM_RES, 3]),
'between_segment_residues': (tf.int64, [NUM_RES, 1]),
'chain_name': (tf.string, [1]),
'deletion_probability': (tf.float32, [NUM_RES, 1]),
'domain_name': (tf.string, [1]),
'gap_matrix': (tf.float32, [NUM_RES, NUM_RES, 1]),
'hhblits_profile': (tf.float32, [NUM_RES, 22]),
'hmm_profile': (tf.float32, [NUM_RES, 30]),
'key': (tf.string, [1]),
'mutual_information': (tf.float32, [NUM_RES, NUM_RES, 1]),
'non_gapped_profile': (tf.float32, [NUM_RES, 21]),
'num_alignments': (tf.int64, [NUM_RES, 1]),
'num_effective_alignments': (tf.float32, [1]),
'phi_angles': (tf.float32, [NUM_RES, 1]),
'phi_mask': (tf.int64, [NUM_RES, 1]),
'profile': (tf.float32, [NUM_RES, 21]),
'profile_with_prior': (tf.float32, [NUM_RES, 22]),
'profile_with_prior_without_gaps': (tf.float32, [NUM_RES, 21]),
'pseudo_bias': (tf.float32, [NUM_RES, 22]),
'pseudo_frob': (tf.float32, [NUM_RES, NUM_RES, 1]),
'pseudolikelihood': (tf.float32, [NUM_RES, NUM_RES, 484]),
'psi_angles': (tf.float32, [NUM_RES, 1]),
'psi_mask': (tf.int64, [NUM_RES, 1]),
'residue_index': (tf.int64, [NUM_RES, 1]),
'resolution': (tf.float32, [1]),
'reweighted_profile': (tf.float32, [NUM_RES, 22]),
'sec_structure': (tf.int64, [NUM_RES, 8]),
'sec_structure_mask': (tf.int64, [NUM_RES, 1]),
'seq_length': (tf.int64, [NUM_RES, 1]),
'sequence': (tf.string, [1]),
'solv_surf': (tf.float32, [NUM_RES, 1]),
'solv_surf_mask': (tf.int64, [NUM_RES, 1]),
'superfamily': (tf.string, [1]),
}
FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()}
FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()}
def shape(feature_name, num_residues, features=None):
"""Get the shape for the given feature name.
Args:
feature_name: String identifier for the feature. If the feature name ends
with "_unnormalized", this suffix is stripped off.
num_residues: The number of residues in the current domain - some elements
of the shape can be dynamic and will be replaced by this value.
features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.
Returns:
List of ints representing the tensor size.
"""
features = features or FEATURES
if feature_name.endswith('_unnormalized'):
feature_name = feature_name[:-13]
unused_dtype, raw_sizes = features[feature_name]
replacements = {NUM_RES: num_residues}
sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
return sizes
def dim(feature_name):
"""Determine the type of feature.
Args:
feature_name: String identifier for the feature to lookup. If the feature
name ends with "_unnormalized", this suffix is stripped off.
Returns:
A FeatureType enum describing whether the feature is of size num_res or
num_res * num_res.
Raises:
ValueError: If the feature is of an unknown type.
"""
if feature_name.endswith('_unnormalized'):
feature_name = feature_name[:-13]
num_dims = len(FEATURE_SIZES[feature_name])
if num_dims == 1:
return FeatureType.ZERO_DIM
elif num_dims == 2 and FEATURE_SIZES[feature_name][0] == NUM_RES:
return FeatureType.ONE_DIM
elif num_dims == 3 and FEATURE_SIZES[feature_name][0] == NUM_RES:
return FeatureType.TWO_DIM
else:
raise ValueError('Expect feature sizes to have 1, 2 or 3 dimensions, got %i' %
len(FEATURE_SIZES[feature_name]))
def _concat_or_zeros(tensor_list, axis, tensor_shape, name):
"""Concatenates the tensors if given, otherwise returns a tensor of zeros."""
if tensor_list:
return tf.concat(tensor_list, axis=axis, name=name)
return tf.zeros(tensor_shape, name=name + '_zeros')
def parse_tfexample(raw_data, features):
"""Read a single TF Example proto and return a subset of its features.
Args:
raw_data: A serialized tf.Example proto.
features: A dictionary of features, mapping string feature names to a tuple
(dtype, shape). This dictionary should be a subset of
protein_features.FEATURES (or the dictionary itself for all features).
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
feature_map = {
k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True)
for k, v in features.items()
}
parsed_features = tf.io.parse_single_example(raw_data, feature_map)
# Find out what is the number of sequences and the number of alignments.
num_residues = tf.cast(parsed_features['seq_length'][0], dtype=tf.int32)
# Reshape the tensors according to the sequence length and num alignments.
for k, v in parsed_features.items():
new_shape = shape(feature_name=k, num_residues=num_residues)
# Make sure the feature we are reshaping is not empty.
assert_non_empty = tf.compat.v1.assert_greater(
tf.size(v), 0, name='assert_%s_non_empty' % k,
message='The feature %s is not set in the tf.Example. Either do not '
'request the feature or use a tf.Example that has the feature set.' % k)
with tf.control_dependencies([assert_non_empty]):
parsed_features[k] = tf.reshape(v, new_shape, name='reshape_%s' % k)
return parsed_features
def create_tf_dataset(tf_record_filename, features):
"""Creates an instance of tf.data.Dataset backed by a protein dataset SSTable.
Args:
tf_record_filename: A string with filename of the TFRecord file.
features: A list of strings of feature names to be returned in the dataset.
Returns:
A tf.data.Dataset object. Its items are dictionaries from feature names to
feature values.
"""
# Make sure these features are always read.
required_features = ['aatype', 'sequence', 'seq_length']
features = list(set(features) | set(required_features))
features = {name: FEATURES[name] for name in features}
tf_dataset = tf.data.TFRecordDataset(filenames=[tf_record_filename])
tf_dataset = tf_dataset.map(lambda raw: parse_tfexample(raw, features))
return tf_dataset
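# Illustrative usage (the TFRecord path below is hypothetical):
#   dataset = create_tf_dataset('T0955/T0955.tfrec',
#                               features=['aatype', 'hhblits_profile'])
#   example = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()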
def normalize_from_stats_file(
features, stats_file_path, feature_normalization, copy_unnormalized=None):
"""Normalizes the features set in the feature_normalization by the norm stats.
Args:
features: A dictionary mapping feature names to feature tensors.
stats_file_path: A string with the path of the statistics JSON file.
feature_normalization: A dictionary specifying the normalization type for
each input feature. Acceptable values are 'std' and 'none'. If not
specified default to 'none'. Any extra features that are not present in
features will be ignored.
copy_unnormalized: A list of features whose unnormalized copy should be
added. For any feature F in this list a feature F + "_unnormalized" will
be added in the output dictionary containing the unnormalized feature.
This is useful if you have a feature you want to have both in
desired_features (normalized) and also in desired_targets (unnormalized).
See convert_to_legacy_proteins_dataset_format for more details.
Returns:
A dictionary mapping features names to feature tensors. The ones that were
specified in feature_normalization will be normalized.
Raises:
ValueError: If an unknown normalization mode is used.
"""
with tf.io.gfile.GFile(stats_file_path, 'r') as f:
norm_stats = json.loads(f.read())
if not copy_unnormalized:
copy_unnormalized = []
# We need this unnormalized in convert_to_legacy_proteins_dataset_format.
copy_unnormalized.append('num_alignments')
for feature in copy_unnormalized:
if feature in features:
features[feature + '_unnormalized'] = features[feature]
range_epsilon = 1e-12
for key, value in features.items():
if key not in feature_normalization or feature_normalization[key] == 'none':
pass
elif feature_normalization[key] == 'std':
value = tf.cast(value, dtype=tf.float32)
train_mean = tf.cast(norm_stats['mean'][key], dtype=tf.float32)
train_range = tf.sqrt(tf.cast(norm_stats['var'][key], dtype=tf.float32))
value -= train_mean
value = tf.compat.v1.where(
train_range > range_epsilon, value / train_range, value)
features[key] = value
else:
raise ValueError('Unknown normalization mode %s for feature %s.'
% (feature_normalization[key], key))
return features
def convert_to_legacy_proteins_dataset_format(
features, desired_features, desired_scalars, desired_targets):
"""Converts the output of tf.Dataset to the legacy format.
Args:
features: A dictionary mapping feature names to feature tensors.
desired_features: A list with the names of the desired features. These will
be filtered out of features and returned in one of the inputs_1d or
inputs_2d. The features concatenated in `inputs_1d`, `inputs_2d` will be
concatenated in the same order as they were given in `desired_features`.
desired_scalars: A list naming the desired scalars. These will
be filtered out of features and returned in scalars. If features contain
an unnormalized version of a desired scalar, it will be used.
desired_targets: A list naming the desired targets. These will
be filtered out of features and returned in targets. If features contain
an unnormalized version of a desired target, it will be used.
Returns:
A _ProteinDescription namedtuple consisting of:
sequence_length: A scalar int32 tensor with the sequence length.
key: A string tensor with the sequence key, or empty if not set in features.
sequences: A string tensor with the protein sequence.
inputs_1d: All 1D features in a single tensor of shape
[num_res, 1d_channels].
inputs_2d: All 2D features in a single tensor of shape
[num_res, num_res, 2d_channels].
inputs_2d_diagonal: All 2D diagonal features in a single tensor of shape
[num_res, num_res, 2d_diagonal_channels]. If no diagonal features found
in features, the tensor will be set to inputs_2d.
crops: An int32 tensor with the crop positions. If not set in features,
it will be set to [0, num_res, 0, num_res].
scalars: All requested scalar tensors in a list.
targets: All requested target tensors in a list.
Raises:
ValueError: If the feature size is invalid.
"""
tensors_1d = []
tensors_2d = []
tensors_2d_diagonal = []
for key in desired_features:
# Determine if the feature is 1D or 2D.
feature_dim = dim(key)
if feature_dim == FeatureType.ONE_DIM:
tensors_1d.append(tf.cast(features[key], dtype=tf.float32))
elif feature_dim == FeatureType.TWO_DIM:
if key not in features:
if not(key + '_cropped' in features and key + '_diagonal' in features):
raise ValueError(
'The 2D feature %s is not in the features dictionary and neither '
'are its cropped and diagonal versions.' % key)
else:
tensors_2d.append(
tf.cast(features[key + '_cropped'], dtype=tf.float32))
tensors_2d_diagonal.append(
tf.cast(features[key + '_diagonal'], dtype=tf.float32))
else:
tensors_2d.append(tf.cast(features[key], dtype=tf.float32))
else:
raise ValueError('Unexpected FeatureType returned: %s' % str(feature_dim))
# Determine num_res from the sequence as seq_length was possibly normalized.
num_res = tf.strings.length(features['sequence'])[0]
# Concatenate feature tensors into a single tensor
inputs_1d = _concat_or_zeros(
tensors_1d, axis=1, tensor_shape=[num_res, 0],
name='inputs_1d_concat')
inputs_2d = _concat_or_zeros(
tensors_2d, axis=2, tensor_shape=[num_res, num_res, 0],
name='inputs_2d_concat')
if tensors_2d_diagonal:
# The legacy dataset outputs the two diagonal crops stacked
# A1, B1, C1, A2, B2, C2. So convert from the A1, A2, B1, B2, C1, C2 format.
diagonal_crops1 = [t[:, :, :(t.shape[2] // 2)] for t in tensors_2d_diagonal]
diagonal_crops2 = [t[:, :, (t.shape[2] // 2):] for t in tensors_2d_diagonal]
inputs_2d_diagonal = tf.concat(diagonal_crops1 + diagonal_crops2, axis=2)
else:
inputs_2d_diagonal = inputs_2d
sequence = features['sequence']
sequence_key = features.get('key', tf.constant(['']))[0]
if 'crops' in features:
crops = features['crops']
else:
crops = tf.stack([0, tf.shape(sequence)[0], 0, tf.shape(sequence)[0]])
scalar_tensors = []
for key in desired_scalars:
scalar_tensors.append(features.get(key + '_unnormalized', features[key]))
target_tensors = []
for key in desired_targets:
target_tensors.append(features.get(key + '_unnormalized', features[key]))
scalar_class = collections.namedtuple('_ScalarClass', desired_scalars)
target_class = collections.namedtuple('_TargetClass', desired_targets)
return _ProteinDescription(
sequence_lengths=num_res,
key=sequence_key,
sequences=sequence,
inputs_1d=inputs_1d,
inputs_2d=inputs_2d,
inputs_2d_diagonal=inputs_2d_diagonal,
crops=crops,
scalars=scalar_class(*scalar_tensors),
targets=target_class(*target_tensors))


@@ -0,0 +1,233 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contact prediction convnet experiment example."""
from absl import logging
import tensorflow.compat.v1 as tf
from alphafold_casp13 import contacts_dataset
from alphafold_casp13 import contacts_network
def _int_ph(shape, name):
return tf.compat.v1.placeholder(
dtype=tf.int32, shape=shape, name=('%s_placeholder' % name))
def _float_ph(shape, name):
return tf.compat.v1.placeholder(
dtype=tf.float32, shape=shape, name=('%s_placeholder' % name))
class Contacts(object):
"""Contact prediction experiment."""
def __init__(
self, tfrecord, stats_file, network_config, crop_size_x, crop_size_y,
feature_normalization, normalization_exclusion):
"""Builds the TensorFlow graph."""
self.network_config = network_config
self.crop_size_x = crop_size_x
self.crop_size_y = crop_size_y
self._feature_normalization = feature_normalization
self._normalization_exclusion = normalization_exclusion
self._model = contacts_network.ContactsNet(**network_config)
self._features = network_config.features
self._scalars = network_config.scalars
self._targets = network_config.targets
# Add extra targets we need.
required_targets = ['domain_name', 'resolution', 'chain_name']
if self.model.torsion_multiplier > 0:
required_targets.extend([
'phi_angles', 'phi_mask', 'psi_angles', 'psi_mask'])
if self.model.secstruct_multiplier > 0:
required_targets.extend(['sec_structure', 'sec_structure_mask'])
if self.model.asa_multiplier > 0:
required_targets.extend(['solv_surf', 'solv_surf_mask'])
extra_targets = [t for t in required_targets if t not in self._targets]
if extra_targets:
targets = list(self._targets)
targets.extend(extra_targets)
self._targets = tuple(targets)
logging.info('Targets %s %s extra %s',
type(self._targets), self._targets, extra_targets)
logging.info('Evaluating on %s, stats: %s', tfrecord, stats_file)
self._build_evaluation_graph(tfrecord=tfrecord, stats_file=stats_file)
@property
def model(self):
return self._model
def _get_feature_normalization(self, features):
return {key: self._feature_normalization
for key in features
if key not in list(self._normalization_exclusion)}
def _build_evaluation_graph(self, tfrecord, stats_file):
"""Constructs the graph in pieces so it can be fed."""
with tf.name_scope('competitionsep'):
# Construct the dataset and mapping ops.
dataset = contacts_dataset.create_tf_dataset(
tf_record_filename=tfrecord,
features=tuple(self._features) + tuple(
self._scalars) + tuple(self._targets))
def normalize(data):
return contacts_dataset.normalize_from_stats_file(
features=data,
stats_file_path=stats_file,
feature_normalization=self._get_feature_normalization(
self._features),
copy_unnormalized=list(set(self._features) & set(self._targets)))
def convert_to_legacy(features):
return contacts_dataset.convert_to_legacy_proteins_dataset_format(
features, self._features, self._scalars, self._targets)
dataset = dataset.map(normalize)
dataset = dataset.map(convert_to_legacy)
dataset = dataset.batch(1)
# Get a batch of tensors in the legacy ProteinsDataset format.
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
self._input_batch = iterator.get_next()
self.num_eval_examples = sum(
1 for _ in tf.python_io.tf_record_iterator(tfrecord))
logging.info('Eval batch:\n%s', self._input_batch)
feature_dim_1d = self._input_batch.inputs_1d.shape.as_list()[-1]
feature_dim_2d = self._input_batch.inputs_2d.shape.as_list()[-1]
feature_dim_2d *= 3 # The diagonals will be stacked before feeding.
# Now placeholders for the graph to compute the outputs for one crop.
self.inputs_1d_placeholder = _float_ph(
shape=[None, None, feature_dim_1d], name='inputs_1d')
self.residue_index_placeholder = _int_ph(
shape=[None, None], name='residue_index')
self.inputs_2d_placeholder = _float_ph(
shape=[None, None, None, feature_dim_2d], name='inputs_2d')
# 4 ints: x_start, x_end, y_start, y_end.
self.crop_placeholder = _int_ph(shape=[None, 4], name='crop')
# Finally placeholders for the graph to score the complete contact map.
self.probs_placeholder = _float_ph(shape=[None, None, None], name='probs')
self.softmax_probs_placeholder = _float_ph(
shape=[None, None, None, self.network_config.num_bins],
name='softmax_probs')
self.cb_placeholder = _float_ph(shape=[None, None, 3], name='cb')
self.cb_mask_placeholder = _float_ph(shape=[None, None], name='cb_mask')
self.lengths_placeholder = _int_ph(shape=[None], name='lengths')
if self.model.secstruct_multiplier > 0:
self.sec_structure_placeholder = _float_ph(
shape=[None, None, 8], name='sec_structure')
self.sec_structure_logits_placeholder = _float_ph(
shape=[None, None, 8], name='sec_structure_logits')
self.sec_structure_mask_placeholder = _float_ph(
shape=[None, None, 1], name='sec_structure_mask')
if self.model.asa_multiplier > 0:
self.solv_surf_placeholder = _float_ph(
shape=[None, None, 1], name='solv_surf')
self.solv_surf_logits_placeholder = _float_ph(
shape=[None, None, 1], name='solv_surf_logits')
self.solv_surf_mask_placeholder = _float_ph(
shape=[None, None, 1], name='solv_surf_mask')
if self.model.torsion_multiplier > 0:
self.torsions_truth_placeholder = _float_ph(
shape=[None, None, 2], name='torsions_truth')
self.torsions_mask_placeholder = _float_ph(
shape=[None, None, 1], name='torsions_mask')
self.torsion_logits_placeholder = _float_ph(
shape=[None, None, self.network_config.torsion_bins ** 2],
name='torsion_logits')
# Build a dict to pass all the placeholders into build.
placeholders = {
'inputs_1d_placeholder': self.inputs_1d_placeholder,
'residue_index_placeholder': self.residue_index_placeholder,
'inputs_2d_placeholder': self.inputs_2d_placeholder,
'crop_placeholder': self.crop_placeholder,
'probs_placeholder': self.probs_placeholder,
'softmax_probs_placeholder': self.softmax_probs_placeholder,
'cb_placeholder': self.cb_placeholder,
'cb_mask_placeholder': self.cb_mask_placeholder,
'lengths_placeholder': self.lengths_placeholder,
}
if self.model.secstruct_multiplier > 0:
placeholders.update({
'sec_structure': self.sec_structure_placeholder,
'sec_structure_logits_placeholder':
self.sec_structure_logits_placeholder,
'sec_structure_mask': self.sec_structure_mask_placeholder,})
if self.model.asa_multiplier > 0:
placeholders.update({
'solv_surf': self.solv_surf_placeholder,
'solv_surf_logits_placeholder': self.solv_surf_logits_placeholder,
'solv_surf_mask': self.solv_surf_mask_placeholder,})
if self.model.torsion_multiplier > 0:
placeholders.update({
'torsions_truth': self.torsions_truth_placeholder,
'torsion_logits_placeholder': self.torsion_logits_placeholder,
'torsions_truth_mask': self.torsions_mask_placeholder,})
activations = self._model(
crop_size_x=self.crop_size_x,
crop_size_y=self.crop_size_y,
placeholders=placeholders)
self.eval_probs_softmax = tf.nn.softmax(
activations[:, :, :, :self.network_config.num_bins])
self.eval_probs = tf.reduce_sum(
self.eval_probs_softmax[:, :, :, :self._model.quant_threshold()],
axis=3)
def get_one_example(self, sess):
"""Pull one example off the queue so we can feed it for evaluation."""
request_dict = {
'inputs_1d': self._input_batch.inputs_1d,
'inputs_2d': self._input_batch.inputs_2d,
'sequence_lengths': self._input_batch.sequence_lengths,
'beta_positions': self._input_batch.targets.beta_positions,
'beta_mask': self._input_batch.targets.beta_mask,
'domain_name': self._input_batch.targets.domain_name,
'chain_name': self._input_batch.targets.chain_name,
'sequences': self._input_batch.sequences,
}
if hasattr(self._input_batch.targets, 'residue_index'):
request_dict.update(
{'residue_index': self._input_batch.targets.residue_index})
if hasattr(self._input_batch.targets, 'phi_angles'):
request_dict.update(
{'phi_angles': self._input_batch.targets.phi_angles,
'psi_angles': self._input_batch.targets.psi_angles,
'phi_mask': self._input_batch.targets.phi_mask,
'psi_mask': self._input_batch.targets.psi_mask})
if hasattr(self._input_batch.targets, 'sec_structure'):
request_dict.update(
{'sec_structure': self._input_batch.targets.sec_structure,
'sec_structure_mask': self._input_batch.targets.sec_structure_mask,})
if hasattr(self._input_batch.targets, 'solv_surf'):
request_dict.update(
{'solv_surf': self._input_batch.targets.solv_surf,
'solv_surf_mask': self._input_batch.targets.solv_surf_mask,})
if hasattr(self._input_batch.targets, 'alpha_positions'):
request_dict.update(
{'alpha_positions': self._input_batch.targets.alpha_positions,
'alpha_mask': self._input_batch.targets.alpha_mask,})
batch = sess.run(request_dict)
return batch


@@ -0,0 +1,498 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Network for predicting C-beta contacts."""
from absl import logging
import sonnet
import tensorflow.compat.v1 as tf
from alphafold_casp13 import asa_output
from alphafold_casp13 import secstruct
from alphafold_casp13 import two_dim_convnet
from alphafold_casp13 import two_dim_resnet
from tensorflow.contrib import layers as contrib_layers
def call_on_tuple(f):
"""Unpacks a tuple input parameter into arguments for a function f.
Mimics tuple unpacking in lambdas, which existed in Python 2 but has been
removed in Python 3.
Args:
f: A function taking multiple arguments.
Returns:
A function equivalent to f accepting a tuple, which is then unpacked.
"""
return lambda args: f(*args)
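# For example, with tf.map_fn over a tuple of tensors, the mapped function
# receives the tuple elements as separate arguments:
#   tf.map_fn(call_on_tuple(lambda c, h: ...), elems=(crops, hidden), ...)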
class ContactsNet(sonnet.AbstractModule):
"""A network to go from sequence to secondary structure."""
def __init__(self,
binary_code_bits,
data_format,
distance_multiplier,
features,
features_forward,
max_range,
min_range,
num_bins,
reshape_layer,
resolution_noise_scale,
scalars,
targets,
network_2d,
network_2d_deep,
torsion_bins=None,
skip_connect=0,
position_specific_bias_size=0,
filters_1d=(),
collapsed_batch_norm=False,
is_ca_feature=False,
asa_multiplier=0.0,
secstruct_multiplier=0.0,
torsion_multiplier=0.0,
name='contacts_net'):
"""Construct position prediction network."""
super(ContactsNet, self).__init__(name=name)
self._filters_1d = filters_1d
self._collapsed_batch_norm = collapsed_batch_norm
self._is_ca_feature = is_ca_feature
self._binary_code_bits = binary_code_bits
self._data_format = data_format
self._distance_multiplier = distance_multiplier
self._features = features
self._features_forward = features_forward
self._max_range = max_range
self._min_range = min_range
self._num_bins = num_bins
self._position_specific_bias_size = position_specific_bias_size
self._reshape_layer = reshape_layer
self._resolution_noise_scale = resolution_noise_scale
self._scalars = scalars
self._torsion_bins = torsion_bins
self._skip_connect = skip_connect
self._targets = targets
self._network_2d = network_2d
self._network_2d_deep = network_2d_deep
self.asa_multiplier = asa_multiplier
self.secstruct_multiplier = secstruct_multiplier
self.torsion_multiplier = torsion_multiplier
with self._enter_variable_scope():
if self.secstruct_multiplier > 0:
self._secstruct = secstruct.Secstruct()
if self.asa_multiplier > 0:
self._asa = asa_output.ASAOutputLayer()
if self._position_specific_bias_size:
self._position_specific_bias = tf.compat.v1.get_variable(
'position_specific_bias',
[self._position_specific_bias_size, self._num_bins or 1],
initializer=tf.zeros_initializer())
def quant_threshold(self, threshold=8.0):
"""Finds the bin for the threshold: summing mass below it gives contact prob.
Args:
threshold: The distance threshold.
Returns:
Index of bin.
"""
# Note that this misuses the max_range as the range.
return int(
(threshold - self._min_range) * self._num_bins / float(self._max_range))
def _build(self, crop_size_x=0, crop_size_y=0, placeholders=None):
"""Puts the network into the graph.
Args:
crop_size_x: Crop a chunk out in one dimension. 0 means no cropping.
crop_size_y: Crop a chunk out in one dimension. 0 means no cropping.
placeholders: A dict containing the placeholders needed.
Returns:
A Tensor with logits of size [batch_size, num_residues, 3].
"""
crop_placeholder = placeholders['crop_placeholder']
inputs_1d = placeholders['inputs_1d_placeholder']
if self._is_ca_feature and 'aatype' in self._features:
logging.info('Collapsing aatype to is_ca_feature %s',
inputs_1d.shape.as_list()[-1])
assert inputs_1d.shape.as_list()[-1] <= 21 + (
1 if 'seq_length' in self._features else 0)
inputs_1d = inputs_1d[:, :, 7:8]
logits = self.compute_outputs(
inputs_1d=inputs_1d,
residue_index=placeholders['residue_index_placeholder'],
inputs_2d=placeholders['inputs_2d_placeholder'],
crop_x=crop_placeholder[:, 0:2],
crop_y=crop_placeholder[:, 2:4],
use_on_the_fly_stats=True,
crop_size_x=crop_size_x,
crop_size_y=crop_size_y,
data_format='NHWC', # Force NHWC for evals.
)
return logits
def compute_outputs(self, inputs_1d, residue_index, inputs_2d, crop_x, crop_y,
use_on_the_fly_stats, crop_size_x, crop_size_y,
data_format='NHWC'):
"""Given the inputs for a block, compute the network outputs."""
hidden_1d = inputs_1d
hidden_1d_list = [hidden_1d]
if len(hidden_1d_list) != 1:
hidden_1d = tf.concat(hidden_1d_list, 2)
output_dimension = self._num_bins or 1
if self._distance_multiplier > 0:
output_dimension += 1
logits, activations = self._build_2d_embedding(
hidden_1d=hidden_1d,
residue_index=residue_index,
inputs_2d=inputs_2d,
output_dimension=output_dimension,
use_on_the_fly_stats=use_on_the_fly_stats,
crop_x=crop_x,
crop_y=crop_y,
crop_size_x=crop_size_x, crop_size_y=crop_size_y,
data_format=data_format)
logits = tf.debugging.check_numerics(
logits, 'NaN in resnet activations', name='resnet_activations')
if (self.secstruct_multiplier > 0 or
self.asa_multiplier > 0 or
self.torsion_multiplier > 0):
# Make a 1d embedding by reducing the 2D activations.
# We do this in the x direction and the y direction separately.
collapse_dim = 1
join_dim = -1
embedding_1d = tf.concat(
# First targets are crop_x (axis 2) which we must reduce on axis 1
[tf.concat([tf.reduce_max(activations, axis=collapse_dim),
tf.reduce_mean(activations, axis=collapse_dim)],
axis=join_dim),
# Next targets are crop_y (axis 1) which we must reduce on axis 2
tf.concat([tf.reduce_max(activations, axis=collapse_dim+1),
tf.reduce_mean(activations, axis=collapse_dim+1)],
axis=join_dim)],
axis=collapse_dim) # Join the two crops together.
if self._collapsed_batch_norm:
embedding_1d = contrib_layers.batch_norm(
embedding_1d,
is_training=use_on_the_fly_stats,
fused=True,
decay=0.999,
scope='collapsed_batch_norm',
data_format='NHWC')
for i, nfil in enumerate(self._filters_1d):
embedding_1d = contrib_layers.fully_connected(
embedding_1d,
num_outputs=nfil,
normalizer_fn=(contrib_layers.batch_norm
if self._collapsed_batch_norm else None),
normalizer_params={
'is_training': use_on_the_fly_stats,
'updates_collections': None
},
scope='collapsed_embed_%d' % i)
if self.torsion_multiplier > 0:
self.torsion_logits = contrib_layers.fully_connected(
embedding_1d,
num_outputs=self._torsion_bins * self._torsion_bins,
activation_fn=None,
scope='torsion_logits')
self.torsion_output = tf.nn.softmax(self.torsion_logits)
if self.secstruct_multiplier > 0:
self._secstruct.make_layer_new(embedding_1d)
if self.asa_multiplier > 0:
self.asa_logits = self._asa.compute_asa_output(embedding_1d)
return logits
@staticmethod
def _concatenate_2d(hidden_1d, residue_index, hidden_2d, crop_x, crop_y,
binary_code_bits, crop_size_x, crop_size_y):
# Form the pairwise expansion of the 1D embedding
# And the residue offsets and (one) absolute position.
with tf.name_scope('Features2D'):
range_scale = 100.0 # Crude normalization factor.
n = tf.shape(hidden_1d)[1]
# pylint: disable=g-long-lambda
hidden_1d_cropped_y = tf.map_fn(
call_on_tuple(lambda c, h: tf.pad(
h[tf.maximum(0, c[0]):c[1]],
[[tf.maximum(0, -c[0]),
tf.maximum(0, crop_size_y -(n - c[0]))], [0, 0]])),
elems=(crop_y, hidden_1d), dtype=tf.float32,
back_prop=True)
range_n_y = tf.map_fn(
call_on_tuple(lambda ri, c: tf.pad(
ri[tf.maximum(0, c[0]):c[1]],
[[tf.maximum(0, -c[0]),
tf.maximum(0, crop_size_y -(n - c[0]))]])),
elems=(residue_index, crop_y), dtype=tf.int32,
back_prop=False)
hidden_1d_cropped_x = tf.map_fn(
call_on_tuple(lambda c, h: tf.pad(
h[tf.maximum(0, c[0]):c[1]],
[[tf.maximum(0, -c[0]),
tf.maximum(0, crop_size_x -(n - c[0]))], [0, 0]])),
elems=(crop_x, hidden_1d), dtype=tf.float32,
back_prop=True)
range_n_x = tf.map_fn(
call_on_tuple(lambda ri, c: tf.pad(
ri[tf.maximum(0, c[0]):c[1]],
[[tf.maximum(0, -c[0]),
tf.maximum(0, crop_size_x -(n - c[0]))]])),
elems=(residue_index, crop_x), dtype=tf.int32,
back_prop=False)
# pylint: enable=g-long-lambda
n_x = crop_size_x
n_y = crop_size_y
offset = (tf.expand_dims(tf.cast(range_n_x, tf.float32), 1) -
tf.expand_dims(tf.cast(range_n_y, tf.float32), 2)) / range_scale
position_features = [
tf.tile(
tf.reshape(
(tf.cast(range_n_y, tf.float32) - range_scale) / range_scale,
[-1, n_y, 1, 1]), [1, 1, n_x, 1],
name='TileRange'),
tf.tile(
tf.reshape(offset, [-1, n_y, n_x, 1]), [1, 1, 1, 1],
name='TileOffset')
]
channels = 2
if binary_code_bits:
# Binary coding of position.
exp_range_n_y = tf.expand_dims(range_n_y, 2)
bin_y = tf.stop_gradient(
tf.concat([tf.math.floormod(exp_range_n_y // (1 << i), 2)
for i in range(binary_code_bits)], 2))
exp_range_n_x = tf.expand_dims(range_n_x, 2)
bin_x = tf.stop_gradient(
tf.concat([tf.math.floormod(exp_range_n_x // (1 << i), 2)
for i in range(binary_code_bits)], 2))
position_features += [
tf.tile(
tf.expand_dims(tf.cast(bin_y, tf.float32), 2), [1, 1, n_x, 1],
name='TileBinRangey'),
tf.tile(
tf.expand_dims(tf.cast(bin_x, tf.float32), 1), [1, n_y, 1, 1],
name='TileBinRangex')
]
channels += 2 * binary_code_bits
augmentation_features = position_features + [
tf.tile(tf.expand_dims(hidden_1d_cropped_x, 1),
[1, n_y, 1, 1], name='Tile1Dx'),
tf.tile(tf.expand_dims(hidden_1d_cropped_y, 2),
[1, 1, n_x, 1], name='Tile1Dy')]
channels += 2 * hidden_1d.shape.as_list()[-1]
channels += hidden_2d.shape.as_list()[-1]
hidden_2d = tf.concat(
[hidden_2d] + augmentation_features, 3, name='Stack2Dfeatures')
logging.info('2d stacked features are depth %d %s', channels, hidden_2d)
hidden_2d.set_shape([None, None, None, channels])
return hidden_2d
def _build_2d_embedding(self, hidden_1d, residue_index, inputs_2d,
output_dimension, use_on_the_fly_stats, crop_x,
crop_y, crop_size_x, crop_size_y, data_format):
"""Returns NHWC logits and NHWC preactivations."""
logging.info('2d %s %s', inputs_2d, data_format)
# Stack with diagonal has already happened.
inputs_2d_cropped = inputs_2d
features_forward = None
hidden_2d = inputs_2d_cropped
hidden_2d = self._concatenate_2d(
hidden_1d, residue_index, hidden_2d, crop_x, crop_y,
self._binary_code_bits, crop_size_x, crop_size_y)
config_2d_deep = self._network_2d_deep
num_features = hidden_2d.shape.as_list()[3]
if data_format == 'NCHW':
logging.info('NCHW shape deep pre %s', hidden_2d)
hidden_2d = tf.transpose(hidden_2d, perm=[0, 3, 1, 2])
hidden_2d.set_shape([None, num_features, None, None])
logging.info('NCHW shape deep post %s', hidden_2d)
layers_forward = None
if config_2d_deep.extra_blocks:
# Optionally put some extra double-size blocks at the beginning.
with tf.compat.v1.variable_scope('Deep2DExtra'):
hidden_2d = two_dim_resnet.make_two_dim_resnet(
input_node=hidden_2d,
num_residues=None, # Unused
num_features=num_features,
num_predictions=2 * config_2d_deep.num_filters,
num_channels=2 * config_2d_deep.num_filters,
num_layers=config_2d_deep.extra_blocks *
config_2d_deep.num_layers_per_block,
filter_size=3,
batch_norm=config_2d_deep.use_batch_norm,
is_training=use_on_the_fly_stats,
fancy=True,
final_non_linearity=True,
atrou_rates=[1, 2, 4, 8],
data_format=data_format,
dropout_keep_prob=1.0
)
num_features = 2 * config_2d_deep.num_filters
if self._skip_connect:
layers_forward = hidden_2d
if features_forward is not None:
hidden_2d = tf.concat([hidden_2d, features_forward], 1
if data_format == 'NCHW' else 3)
with tf.compat.v1.variable_scope('Deep2D'):
logging.info('2d hidden shape is %s', str(hidden_2d.shape.as_list()))
contact_pre_logits = two_dim_resnet.make_two_dim_resnet(
input_node=hidden_2d,
num_residues=None, # Unused
num_features=num_features,
num_predictions=(config_2d_deep.num_filters
if self._reshape_layer else output_dimension),
num_channels=config_2d_deep.num_filters,
num_layers=config_2d_deep.num_blocks *
config_2d_deep.num_layers_per_block,
filter_size=3,
batch_norm=config_2d_deep.use_batch_norm,
is_training=use_on_the_fly_stats,
fancy=True,
final_non_linearity=self._reshape_layer,
atrou_rates=[1, 2, 4, 8],
data_format=data_format,
dropout_keep_prob=1.0
)
contact_logits = self._output_from_pre_logits(
contact_pre_logits, features_forward, layers_forward,
output_dimension, data_format, crop_x, crop_y, use_on_the_fly_stats)
if data_format == 'NCHW':
contact_pre_logits = tf.transpose(contact_pre_logits, perm=[0, 2, 3, 1])
# Both of these will be NHWC
return contact_logits, contact_pre_logits
def _output_from_pre_logits(self, contact_pre_logits, features_forward,
layers_forward, output_dimension, data_format,
crop_x, crop_y, use_on_the_fly_stats):
"""Given pre-logits, compute the final distogram/contact activations."""
config_2d_deep = self._network_2d_deep
if self._reshape_layer:
in_channels = config_2d_deep.num_filters
concat_features = [contact_pre_logits]
if features_forward is not None:
concat_features.append(features_forward)
in_channels += self._features_forward
if layers_forward is not None:
concat_features.append(layers_forward)
in_channels += 2 * config_2d_deep.num_filters
if len(concat_features) > 1:
contact_pre_logits = tf.concat(concat_features,
1 if data_format == 'NCHW' else 3)
contact_logits = two_dim_convnet.make_conv_layer(
contact_pre_logits,
in_channels=in_channels,
out_channels=output_dimension,
layer_name='output_reshape_1x1h',
filter_size=1,
filter_size_2=1,
non_linearity=False,
batch_norm=config_2d_deep.use_batch_norm,
is_training=use_on_the_fly_stats,
data_format=data_format)
else:
contact_logits = contact_pre_logits
if data_format == 'NCHW':
contact_logits = tf.transpose(contact_logits, perm=[0, 2, 3, 1])
if self._position_specific_bias_size:
# Make 2D pos-specific biases: NHWC.
biases = build_crops_biases(
self._position_specific_bias_size,
self._position_specific_bias, crop_x, crop_y, back_prop=True)
contact_logits += biases
# Will be NHWC.
return contact_logits
def update_crop_fetches(self, fetches):
"""Add auxiliary outputs for a crop to the fetches."""
if self.secstruct_multiplier > 0:
fetches['secstruct_probs'] = self._secstruct.get_q8_probs()
if self.asa_multiplier > 0:
fetches['asa_output'] = self._asa.asa_output
if self.torsion_multiplier > 0:
fetches['torsion_probs'] = self.torsion_output
def build_crops_biases(bias_size, raw_biases, crop_x, crop_y, back_prop):
"""Take the offset-specific biases and reshape them to match current crops.
Args:
bias_size: how many bias variables we're storing.
raw_biases: the bias variable
crop_x: B x 2 array of start/end for the batch
crop_y: B x 2 array of start/end for the batch
back_prop: whether to backprop through the map_fn.
Returns:
Reshaped biases.
"""
# First pad the biases with a copy of the final value to the maximum length.
max_off_diag = tf.reduce_max(
tf.maximum(tf.abs(crop_x[:, 1] - crop_y[:, 0]),
tf.abs(crop_y[:, 1] - crop_x[:, 0])))
padded_bias_size = tf.maximum(bias_size, max_off_diag)
biases = tf.concat(
[raw_biases,
tf.tile(raw_biases[-1:, :],
[padded_bias_size - bias_size, 1])], axis=0)
# Now prepend a mirror image (excluding 0th elt) for below-diagonal.
biases = tf.concat([tf.reverse(biases[1:, :], axis=[0]), biases], axis=0)
# Which diagonal of the full matrix each crop starts on (top left):
start_diag = crop_x[:, 0:1] - crop_y[:, 0:1] # B x 1
crop_size_x = tf.reduce_max(crop_x[:, 1] - crop_x[:, 0])
crop_size_y = tf.reduce_max(crop_y[:, 1] - crop_y[:, 0])
# Relative offset of each row within a crop:
# (off-diagonal decreases as y increases)
increment = tf.expand_dims(-tf.range(0, crop_size_y), 0) # 1 x crop_size_y
# Index of diagonal of first element of each row, flattened.
row_offsets = tf.reshape(start_diag + increment, [-1]) # B*crop_size_y
logging.info('row_offsets %s', row_offsets)
# Make it relative to the start of the biases array. (0-th diagonal is in
# the middle at position padded_bias_size - 1)
row_offsets += padded_bias_size - 1
# Map_fn to build the individual rows.
# B*cropsizey x cropsizex x num_bins
cropped_biases = tf.map_fn(lambda i: biases[i:i+crop_size_x, :],
elems=row_offsets, dtype=tf.float32,
back_prop=back_prop)
logging.info('cropped_biases %s', cropped_biases)
return tf.reshape(
cropped_biases, [-1, crop_size_y, crop_size_x, tf.shape(raw_biases)[-1]])

View File

@@ -0,0 +1,98 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Write contact map predictions to a tf.io.gfile.
Either write a binary contact map as an RR format text file, or a
histogram prediction as a pickle of a dict containing a numpy array.
"""
import os
import numpy as np
import six.moves.cPickle as pickle
import tensorflow.compat.v1 as tf
RR_FORMAT = """PFRMAT RR
TARGET {}
AUTHOR DM-ORIGAMI-TEAM
METHOD {}
MODEL 1
{}
"""
def save_rr_file(filename, probs, domain, sequence,
method='dm-contacts-resnet'):
"""Save a contact probability matrix as an RR file."""
assert len(sequence) == probs.shape[0]
assert len(sequence) == probs.shape[1]
with tf.io.gfile.GFile(filename, 'w') as f:
f.write(RR_FORMAT.format(domain, method, sequence))
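    # Each body line is "i j d_min d_max probability", with d_min/d_max fixed
    # to 0 and 8 Angstrom (the standard CASP contact threshold).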
for i in range(probs.shape[0]):
for j in range(i + 1, probs.shape[1]):
f.write('{:d} {:d} {:d} {:d} {:f}\n'.format(
i + 1, j + 1, 0, 8, probs[j, i]))
f.write('END\n')
def save_torsions(torsions_dir, filebase, sequence, torsions_probs):
"""Save Torsions to a file as pickle of a dict."""
filename = os.path.join(torsions_dir, filebase + '.torsions')
t_dict = dict(probs=torsions_probs, sequence=sequence)
  with tf.io.gfile.GFile(filename, 'wb') as fh:
pickle.dump(t_dict, fh, protocol=2)
def save_distance_histogram(
filename, probs, domain, sequence, min_range, max_range, num_bins):
"""Save a distance histogram prediction matrix as a pickle file."""
dh_dict = {
'min_range': min_range,
'max_range': max_range,
'num_bins': num_bins,
'domain': domain,
'sequence': sequence,
'probs': probs.astype(np.float32)}
save_distance_histogram_from_dict(filename, dh_dict)
def save_distance_histogram_from_dict(filename, dh_dict):
"""Save a distance histogram prediction matrix as a pickle file."""
fields = ['min_range', 'max_range', 'num_bins', 'domain', 'sequence', 'probs']
missing_fields = [f for f in fields if f not in dh_dict]
assert not missing_fields, 'Fields {} missing from dictionary'.format(
missing_fields)
assert len(dh_dict['sequence']) == dh_dict['probs'].shape[0]
assert len(dh_dict['sequence']) == dh_dict['probs'].shape[1]
assert dh_dict['num_bins'] == dh_dict['probs'].shape[2]
assert dh_dict['min_range'] >= 0.0
assert dh_dict['max_range'] > 0.0
with tf.io.gfile.GFile(filename, 'wb') as fw:
pickle.dump(dh_dict, fw, protocol=2)
def contact_map_from_distogram(distogram_dict):
"""Split the boundary bin."""
num_bins = distogram_dict['probs'].shape[-1]
bin_size_angstrom = distogram_dict['max_range'] / num_bins
threshold_cts = (8.0 - distogram_dict['min_range']) / bin_size_angstrom
threshold_bin = int(threshold_cts) # Round down
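  # Example: with min_range 2, max_range 20 and 64 bins, bin_size is 0.3125 A,
  # threshold_cts is 19.2, so bins 0..18 count fully and 0.2 of bin 19 is added.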
pred_contacts = np.sum(distogram_dict['probs'][:, :, :threshold_bin], axis=-1)
if threshold_bin < threshold_cts: # Add on the fraction of the boundary bin.
pred_contacts += distogram_dict['probs'][:, :, threshold_bin] * (
threshold_cts - threshold_bin)
return pred_contacts

View File

@@ -0,0 +1,125 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Form a weighted average of several distograms.
Can also/instead form a weighted average of a set of distance histogram pickle
files, so long as they have identical hyperparameters.
"""
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf
from alphafold_casp13 import distogram_io
from alphafold_casp13 import parsers
flags.DEFINE_list(
'pickle_dirs', [],
'Comma separated list of directories with pickle files to ensemble.')
flags.DEFINE_list(
'weights', [],
'Comma separated list of weights for the pickle files from different dirs.')
flags.DEFINE_string(
'output_dir', None, 'Directory where to save results of the evaluation.')
FLAGS = flags.FLAGS
def ensemble_distance_histograms(pickle_dirs, weights, output_dir):
"""Find all the contact maps in the first dir, then ensemble across dirs."""
if len(pickle_dirs) <= 1:
logging.warning('Pointless to ensemble %d pickle_dirs %s',
len(pickle_dirs), pickle_dirs)
# Carry on if there's one dir, otherwise do nothing.
if not pickle_dirs:
return
tf.io.gfile.makedirs(output_dir)
one_dir_pickle_files = tf.io.gfile.glob(
os.path.join(pickle_dirs[0], '*.pickle'))
assert one_dir_pickle_files, pickle_dirs[0]
  original_files = len(one_dir_pickle_files)
  logging.info('Found %d pickle files in the first of %d dirs',
               original_files, len(pickle_dirs))
targets = [os.path.splitext(os.path.basename(f))[0]
for f in one_dir_pickle_files]
skipped = 0
wrote = 0
for t in targets:
dump_file = os.path.join(output_dir, t + '.pickle')
pickle_files = [os.path.join(pickle_dir, t + '.pickle')
for pickle_dir in pickle_dirs]
_, new_dict = ensemble_one_distance_histogram(pickle_files, weights)
    if new_dict is not None:
      wrote += 1
      distogram_io.save_distance_histogram_from_dict(dump_file, new_dict)
      msg = 'Distograms Wrote %s %d / %d Skipped %d %s' % (
          t, wrote, len(one_dir_pickle_files), skipped, dump_file)
      logging.info(msg)
    else:
      skipped += 1
def ensemble_one_distance_histogram(pickle_files, weights):
"""Average the given pickle_files and dump."""
dicts = []
sequence = None
max_dim = None
for picklefile in pickle_files:
if not tf.io.gfile.exists(picklefile):
logging.warning('missing %s', picklefile)
break
logging.info('loading pickle file %s', picklefile)
distance_histogram_dict = parsers.parse_distance_histogram_dict(picklefile)
if sequence is None:
sequence = distance_histogram_dict['sequence']
else:
assert sequence == distance_histogram_dict['sequence'], '%s vs %s' % (
sequence, distance_histogram_dict['sequence'])
dicts.append(distance_histogram_dict)
assert dicts[-1]['probs'].shape[0] == dicts[-1]['probs'].shape[1], (
'%d vs %d' % (dicts[-1]['probs'].shape[0], dicts[-1]['probs'].shape[1]))
assert (dicts[0]['probs'].shape[0:2] == dicts[-1]['probs'].shape[0:2]
), ('%d vs %d' % (dicts[0]['probs'].shape, dicts[-1]['probs'].shape))
if max_dim is None or max_dim < dicts[-1]['probs'].shape[2]:
max_dim = dicts[-1]['probs'].shape[2]
if len(dicts) != len(pickle_files):
logging.warning('length mismatch\n%s\nVS\n%s', dicts, pickle_files)
return sequence, None
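  # Weighted arithmetic mean over replicas, taken independently for every
  # residue pair and distance bin.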
ensemble_hist = (
sum(w * c['probs'] for w, c in zip(weights, dicts)) / sum(weights))
new_dict = dict(dicts[0])
new_dict['probs'] = ensemble_hist
return sequence, new_dict
def main(argv):
del argv # Unused.
num_dirs = len(FLAGS.pickle_dirs)
if FLAGS.weights:
assert len(FLAGS.weights) == num_dirs, (
'Supply as many weights as pickle_dirs, or no weights')
weights = [float(w) for w in FLAGS.weights]
else:
    weights = [1.0] * num_dirs
ensemble_distance_histograms(
pickle_dirs=FLAGS.pickle_dirs,
weights=weights,
output_dir=FLAGS.output_dir)
if __name__ == '__main__':
app.run(main)

View File

@@ -0,0 +1,67 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parsers for various standard biology or AlphaFold-specific formats."""
import pickle
import tensorflow.compat.v1 as tf
def distance_histogram_dict(f):
"""Parses distance histogram dict pickle.
Distance histograms are stored as pickles of dicts.
  Write one of these with distogram_io.save_distance_histogram().
Args:
f: File-like handle to distance histogram dict pickle.
Returns:
Dict with fields:
probs: (an L x L x num_bins) histogram.
num_bins: number of bins for each residue pair
min_range: left hand edge of the distance histogram
max_range: the extent of the histogram NOT the right hand edge.
"""
contact_dict = pickle.load(f, encoding='latin1')
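  # encoding='latin1' keeps pickles of numpy arrays written under Python 2
  # readable under Python 3.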
num_res = len(contact_dict['sequence'])
if not all(key in contact_dict.keys()
for key in ['probs', 'num_bins', 'min_range', 'max_range']):
raise ValueError('The pickled contact dict doesn\'t contain all required '
'keys: probs, num_bins, min_range, max_range but %s.' %
contact_dict.keys())
if contact_dict['probs'].ndim != 3:
raise ValueError(
'Probs is not rank 3 but %d' % contact_dict['probs'].ndim)
if contact_dict['num_bins'] != contact_dict['probs'].shape[2]:
raise ValueError(
'The probs shape doesn\'t match num_bins in the third dimension. '
'Expected %d got %d.' % (contact_dict['num_bins'],
contact_dict['probs'].shape[2]))
if contact_dict['probs'].shape[:2] != (num_res, num_res):
raise ValueError(
'The first two probs dims (%i, %i) aren\'t equal to len(sequence) %i'
% (contact_dict['probs'].shape[0], contact_dict['probs'].shape[1],
num_res))
return contact_dict
def parse_distance_histogram_dict(filepath):
"""Parses distance histogram piclkle from filepath."""
with tf.io.gfile.GFile(filepath, 'rb') as f:
return distance_histogram_dict(f)

View File

@@ -0,0 +1,201 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Combines predictions by pasting."""
import os
from absl import app
from absl import flags
from absl import logging
import numpy as np
import six
import tensorflow.compat.v1 as tf
from alphafold_casp13 import distogram_io
from alphafold_casp13 import parsers
flags.DEFINE_string("pickle_input_dir", None,
"Directory to read pickle distance histogram files from.")
flags.DEFINE_string("output_dir", None, "Directory to write chain RR files to.")
flags.DEFINE_string("tfrecord_path", "",
"If provided, construct the average weighted by number of "
"alignments.")
flags.DEFINE_string("crop_sizes", "64,128,256", "The crop sizes to use.")
flags.DEFINE_integer("crop_step", 32, "The step size for cropping.")
FLAGS = flags.FLAGS
def generate_domains(target, sequence, crop_sizes, crop_step):
"""Take fasta files and generate a domain definition for data generation."""
logging.info("Generating crop domains for target %s", target)
windows = [int(x) for x in crop_sizes.split(",")]
num_residues = len(sequence)
domains = []
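  # Example: a 300-residue chain with window 256 and crop_step 32 gives starts
  # 0, 32 and 44, i.e. 1-based domains (1, 256), (33, 288) and (45, 300).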
domains.append({"name": target, "description": (1, num_residues)})
for window in windows:
starts = list(range(0, num_residues - window, crop_step))
# Append a last crop to ensure we get all the way to the end of the
# sequence, even when num_residues - window is not divisible by crop_step.
if num_residues >= window:
starts += [num_residues - window]
for start in starts:
name = "%s-l%i_s%i" % (target, window, start)
domains.append({"name": name, "description": (start + 1, start + window)})
return domains
def get_weights(path):
"""Fetch all the weights from a TFRecord."""
if not path:
return {}
logging.info("Getting weights from %s", path)
weights = {}
record_iterator = tf.python_io.tf_record_iterator(path=path)
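  # Each serialized tf.train.Example carries the domain name and its number of
  # alignments (MSA depth); deeper alignments get a larger pasting weight.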
for serialized_tfexample in record_iterator:
example = tf.train.Example()
example.ParseFromString(serialized_tfexample)
domain_name = six.ensure_str(
example.features.feature["domain_name"].bytes_list.value[0])
weights[domain_name] = float(
example.features.feature["num_alignments"].int64_list.value[0])
logging.info("Weight %s: %d", domain_name, weights[domain_name])
logging.info("Loaded %d weights", len(weights))
return weights
def paste_distance_histograms(
input_dir, output_dir, weights, crop_sizes, crop_step):
"""Paste together distograms for given domains of given targets and write.
  Domain distance histograms are 'pasted', meaning they are substituted
directly into the contact map. The order is determined by the order in the
domain definition file.
Args:
input_dir: String, path to directory containing chain and domain-level
distogram files.
    output_dir: String, path to directory to write out chain-level distogram
files.
weights: A dictionary with weights.
crop_sizes: The crop sizes.
crop_step: The step size for cropping.
Raises:
ValueError: if histogram parameters don't match.
"""
tf.io.gfile.makedirs(output_dir)
targets = tf.io.gfile.glob(os.path.join(input_dir, "*.pickle"))
targets = [os.path.splitext(os.path.basename(t))[0] for t in targets]
targets = set([t.split("-")[0] for t in targets])
logging.info("Pasting distance histograms for %d targets", len(targets))
for target in sorted(targets):
logging.info("%s as chain", target)
chain_pickle_path = os.path.join(input_dir, "%s.pickle" % target)
distance_histogram_dict = parsers.parse_distance_histogram_dict(
chain_pickle_path)
combined_cmap = np.array(distance_histogram_dict["probs"])
# Make the counter map 1-deep but still rank 3.
counter_map = np.ones_like(combined_cmap[:, :, 0:1])
sequence = distance_histogram_dict["sequence"]
target_domains = generate_domains(
target=target, sequence=sequence, crop_sizes=crop_sizes,
crop_step=crop_step)
# Paste in each domain.
for domain in sorted(target_domains, key=lambda x: x["name"]):
if domain["name"] == target:
logging.info("Skipping %s as domain", target)
continue
if "," in domain["description"]:
logging.info("Skipping multisegment domain %s",
domain["name"])
continue
crop_start, crop_end = domain["description"]
domain_pickle_path = os.path.join(input_dir, "%s.pickle" % domain["name"])
weight = weights.get(domain["name"], 1e9)
logging.info("Pasting %s: %d-%d. weight: %f", domain_pickle_path,
crop_start, crop_end, weight)
domain_distance_histogram_dict = parsers.parse_distance_histogram_dict(
domain_pickle_path)
for field in ["num_bins", "min_range", "max_range"]:
if domain_distance_histogram_dict[field] != distance_histogram_dict[
field]:
raise ValueError("Field {} does not match {} {}".format(
field,
domain_distance_histogram_dict[field],
distance_histogram_dict[field]))
weight_matrix_size = crop_end - crop_start + 1
weight_matrix = np.ones(
(weight_matrix_size, weight_matrix_size), dtype=np.float32) * weight
combined_cmap[crop_start - 1:crop_end, crop_start - 1:crop_end, :] += (
domain_distance_histogram_dict["probs"] *
np.expand_dims(weight_matrix, 2))
counter_map[crop_start - 1:crop_end,
crop_start - 1:crop_end, 0] += weight_matrix
# Broadcast across the histogram bins.
combined_cmap /= counter_map
# Write out full-chain cmap for folding.
output_chain_pickle_path = os.path.join(output_dir,
"{}.pickle".format(target))
logging.info("Writing to %s", output_chain_pickle_path)
distance_histogram_dict["probs"] = combined_cmap
distance_histogram_dict["target"] = target
# Save the distogram pickle file.
distogram_io.save_distance_histogram_from_dict(
output_chain_pickle_path, distance_histogram_dict)
# Compute the contact map and save it as an RR file.
contact_probs = distogram_io.contact_map_from_distogram(
distance_histogram_dict)
rr_path = os.path.join(output_dir, "%s.rr" % target)
distogram_io.save_rr_file(
filename=rr_path,
probs=contact_probs,
domain=target,
sequence=distance_histogram_dict["sequence"])
def main(argv):
del argv # Unused.
flags.mark_flag_as_required("pickle_input_dir")
weights = get_weights(FLAGS.tfrecord_path)
paste_distance_histograms(
FLAGS.pickle_input_dir, FLAGS.output_dir, weights, FLAGS.crop_sizes,
FLAGS.crop_step)
if __name__ == "__main__":
app.run(main)

View File

@@ -0,0 +1,6 @@
absl-py>=0.8.0
numpy>=1.16
six>=1.12
dm-sonnet>=1.35
tensorflow==1.14
tensorflow-probability==0.7.0

102
alphafold_casp13/run_eval.sh Executable file
View File

@@ -0,0 +1,102 @@
#!/bin/bash
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# We assume the script is being run from the deepmind_research parent directory.
DISTOGRAM_MODEL="alphafold_casp13/873731" # Path to the directory with the distogram model.
BACKGROUND_MODEL="alphafold_casp13/916425" # Path to the directory with the background model.
TORSION_MODEL="alphafold_casp13/941521" # Path to the directory with the torsion model.
TARGET="T1019s2" # The name of the target.
TARGET_PATH="alphafold_casp13/${TARGET}" # Path to the directory with the target input data.
# Set up the virtual environment and install dependencies.
python3 -m venv alphafold_venv
source alphafold_venv/bin/activate
pip install -r alphafold_casp13/requirements.txt
# Create the output directory.
OUTPUT_DIR="${HOME}/contacts_${TARGET}_$(date +%Y_%m_%d_%H_%M_%S)"
mkdir -p "${OUTPUT_DIR}"
echo "Saving output to ${OUTPUT_DIR}/"
# Run contact prediction over 4 replicas.
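# The distogram and background jobs below are started in the background (note
# the trailing '&') so all replicas run concurrently; their outputs are
# ensembled after the 'wait' further down.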
for replica in 0 1 2 3; do
echo "Launching all models for replica ${replica}"
# Run the distogram model.
python3 -m alphafold_casp13.contacts \
--logtostderr \
--cpu=true \
--config_path="${DISTOGRAM_MODEL}/${replica}/config.json" \
--checkpoint_path="${DISTOGRAM_MODEL}/${replica}/tf_graph_data/tf_graph_data.ckpt" \
--output_path="${OUTPUT_DIR}/distogram/${replica}" \
--eval_sstable="${TARGET_PATH}/${TARGET}.tfrec" \
--stats_file="${DISTOGRAM_MODEL}/stats_train_s35.json" &
# Run the background model.
python3 -m alphafold_casp13.contacts \
--logtostderr \
--cpu=true \
--config_path="${BACKGROUND_MODEL}/${replica}/config.json" \
--checkpoint_path="${BACKGROUND_MODEL}/${replica}/tf_graph_data/tf_graph_data.ckpt" \
--output_path="${OUTPUT_DIR}/background_distogram/${replica}" \
--eval_sstable="${TARGET_PATH}/${TARGET}.tfrec" \
--stats_file="${BACKGROUND_MODEL}/stats_train_s35.json" &
done
# Run the torsion model, but only 1 replica.
python3 -m alphafold_casp13.contacts \
--logtostderr \
--cpu=true \
--config_path="${TORSION_MODEL}/0/config.json" \
--checkpoint_path="${TORSION_MODEL}/0/tf_graph_data/tf_graph_data.ckpt" \
--output_path="${OUTPUT_DIR}/torsion/0" \
--eval_sstable="${TARGET_PATH}/${TARGET}.tfrec" \
--stats_file="${TORSION_MODEL}/stats_train_s35.json" &
echo "All models running, waiting for them to complete"
wait
echo "Ensembling all replica outputs"
# Run the ensembling jobs for distograms, background distograms.
for output_dir in "${OUTPUT_DIR}/distogram" "${OUTPUT_DIR}/background_distogram"; do
pickle_dirs="${output_dir}/0/pickle_files/,${output_dir}/1/pickle_files/,${output_dir}/2/pickle_files/,${output_dir}/3/pickle_files/"
# Ensemble distograms.
python3 -m alphafold_casp13.ensemble_contact_maps \
--logtostderr \
--pickle_dirs="${pickle_dirs}" \
--output_dir="${output_dir}/ensemble/"
done
# Only ensemble single replica distogram for torsions.
python3 -m alphafold_casp13.ensemble_contact_maps \
--logtostderr \
--pickle_dirs="${OUTPUT_DIR}/torsion/0/pickle_files/" \
--output_dir="${OUTPUT_DIR}/torsion/ensemble/"
echo "Pasting contact maps"
python3 -m alphafold_casp13.paste_contact_maps \
--logtostderr \
--pickle_input_dir="${OUTPUT_DIR}/distogram/ensemble/" \
--output_dir="${OUTPUT_DIR}/pasted/" \
--tfrecord_path="${TARGET_PATH}/${TARGET}.tfrec"
echo "Done"

View File

@@ -0,0 +1,94 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layer for modelling and scoring secondary structure."""
import os
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
# The 8 secondary structure classes (Q8).
SECONDARY_STRUCTURES = '-HETSGBI'
# Equivalence classes for 3-class (Q3) from Li & Yu 2016.
# See http://www.cmbi.ru.nl/dssp.html for letter explanations.
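# Q3 classes: 0 = coil-like ('-TSGIB'), 1 = helix ('H'), 2 = strand ('E').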
Q3_MAP = ['-TSGIB', 'H', 'E']
def make_q3_matrices():
"""Generate mapping matrices for secstruct Q8:Q3 equivalence classes."""
dimension = len(SECONDARY_STRUCTURES)
q3_map_matrix = np.zeros((dimension, len(Q3_MAP)))
q3_lookup = np.zeros((dimension,), dtype=np.int32)
for i, eclass in enumerate(Q3_MAP): # equivalence classes
for m in eclass: # Members of the class.
ss_type = SECONDARY_STRUCTURES.index(m)
q3_map_matrix[ss_type, i] = 1.0
q3_lookup[ss_type] = i
return q3_map_matrix, q3_lookup
class Secstruct(object):
"""Make a layer that computes hierarchical secstruct."""
# Build static, shared structures:
q3_map_matrix, q3_lookup = make_q3_matrices()
static_dimension = len(SECONDARY_STRUCTURES)
def __init__(self, name='secstruct'):
self.name = name
self._dimension = Secstruct.static_dimension
def make_layer_new(self, activations):
"""Make the layer."""
with tf.compat.v1.variable_scope(self.name, reuse=tf.AUTO_REUSE):
logging.info('Creating secstruct %s', activations)
self.logits = contrib_layers.linear(activations, self._dimension)
self.ss_q8_probs = tf.nn.softmax(self.logits)
self.ss_q3_probs = tf.matmul(
self.ss_q8_probs, tf.constant(self.q3_map_matrix, dtype=tf.float32))
def get_q8_probs(self):
return self.ss_q8_probs
def save_secstructs(dump_dir_path, name, index, sequence, probs,
label='Deepmind secstruct'):
"""Write secstruct prob distributions to an ss2 file.
Can be overloaded to write out asa values too.
Args:
dump_dir_path: directory where to write files.
name: name of domain
index: index number of multiple samples. (or None for no index)
sequence: string of L residue labels
probs: L x D matrix of probabilities. L is length of sequence,
D is probability dimension (usually 3).
label: A label for the file.
"""
filename = os.path.join(dump_dir_path, '%s.ss2' % name)
if index is not None:
filename = os.path.join(dump_dir_path, '%s_%04d.ss2' % (name, index))
with tf.io.gfile.GFile(filename, 'w') as gf:
logging.info('Saving secstruct to %s', filename)
gf.write('# %s CLASSES [%s] %s sample %s\n\n' % (
label, ''.join(SECONDARY_STRUCTURES[:probs.shape[1]]), name, index))
for l in range(probs.shape[0]):
ss = SECONDARY_STRUCTURES[np.argmax(probs[l, :])]
gf.write('%4d %1s %1s %s\n' % (l + 1, sequence[l], ss, ''.join(
[('%6.3f' % p) for p in probs[l, :]])))

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,139 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""One dim convnet for angles prediction."""
from absl import logging
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
def weight_variable(shape, stddev=0.01):
"""Returns the weight variable."""
logging.vlog(1, 'weight init for shape %s', str(shape))
return tf.compat.v1.get_variable(
'w', shape, initializer=tf.random_normal_initializer(stddev=stddev))
def bias_variable(shape):
return tf.compat.v1.get_variable(
'b', shape, initializer=tf.zeros_initializer())
def conv2d(x, w, atrou_rate=1, data_format='NHWC'):
if atrou_rate > 1:
return tf.nn.convolution(
x,
w,
dilation_rate=[atrou_rate] * 2,
padding='SAME',
data_format=data_format)
else:
return tf.nn.conv2d(
x, w, strides=[1, 1, 1, 1], padding='SAME', data_format=data_format)
def make_conv_sep2d_layer(input_node,
in_channels,
channel_multiplier,
out_channels,
layer_name,
filter_size,
filter_size_2=None,
batch_norm=False,
is_training=True,
atrou_rate=1,
data_format='NHWC',
stddev=0.01):
"""Use separable convolutions."""
if filter_size_2 is None:
filter_size_2 = filter_size
logging.vlog(1, 'layer %s in %d out %d chan mult %d', layer_name, in_channels,
out_channels, channel_multiplier)
with tf.compat.v1.variable_scope(layer_name):
with tf.compat.v1.variable_scope('depthwise'):
w_depthwise = weight_variable(
[filter_size, filter_size_2, in_channels, channel_multiplier],
stddev=stddev)
with tf.compat.v1.variable_scope('pointwise'):
w_pointwise = weight_variable(
[1, 1, in_channels * channel_multiplier, out_channels], stddev=stddev)
h_conv = tf.nn.separable_conv2d(
input_node,
w_depthwise,
w_pointwise,
padding='SAME',
strides=[1, 1, 1, 1],
rate=[atrou_rate, atrou_rate],
data_format=data_format)
if batch_norm:
h_conv = batch_norm_layer(
h_conv, layer_name=layer_name, is_training=is_training,
data_format=data_format)
else:
b_conv = bias_variable([out_channels])
h_conv = tf.nn.bias_add(h_conv, b_conv, data_format=data_format)
return h_conv
def batch_norm_layer(h_conv, layer_name, is_training=True, data_format='NCHW'):
"""Batch norm layer."""
logging.vlog(1, 'batch norm for layer %s', layer_name)
return contrib_layers.batch_norm(
h_conv,
is_training=is_training,
fused=True,
decay=0.999,
scope=layer_name,
data_format=data_format)
def make_conv_layer(input_node,
in_channels,
out_channels,
layer_name,
filter_size,
filter_size_2=None,
non_linearity=True,
batch_norm=False,
is_training=True,
atrou_rate=1,
data_format='NHWC',
stddev=0.01):
"""Creates a convolution layer."""
if filter_size_2 is None:
filter_size_2 = filter_size
logging.vlog(
1, 'layer %s in %d out %d', layer_name, in_channels, out_channels)
with tf.compat.v1.variable_scope(layer_name):
w_conv = weight_variable(
[filter_size, filter_size_2, in_channels, out_channels], stddev=stddev)
h_conv = conv2d(
input_node, w_conv, atrou_rate=atrou_rate, data_format=data_format)
if batch_norm:
h_conv = batch_norm_layer(
h_conv, layer_name=layer_name, is_training=is_training,
data_format=data_format)
else:
b_conv = bias_variable([out_channels])
h_conv = tf.nn.bias_add(h_conv, b_conv, data_format=data_format)
if non_linearity:
h_conv = tf.nn.elu(h_conv)
return h_conv

View File

@@ -0,0 +1,202 @@
# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""2D Resnet."""
from absl import logging
import tensorflow.compat.v1 as tf
from alphafold_casp13 import two_dim_convnet
def make_sep_res_layer(
input_node,
in_channels,
out_channels,
layer_name,
filter_size,
filter_size_2=None,
batch_norm=False,
is_training=True,
divide_channels_by=2,
atrou_rate=1,
channel_multiplier=0,
data_format='NHWC',
stddev=0.01,
dropout_keep_prob=1.0):
"""A separable resnet block."""
with tf.name_scope(layer_name):
input_times_almost_1 = input_node
h_conv = input_times_almost_1
if batch_norm:
h_conv = two_dim_convnet.batch_norm_layer(
h_conv, layer_name=layer_name, is_training=is_training,
data_format=data_format)
h_conv = tf.nn.elu(h_conv)
if filter_size_2 is None:
filter_size_2 = filter_size
# 1x1 with half size
h_conv = two_dim_convnet.make_conv_layer(
h_conv,
in_channels=in_channels,
out_channels=in_channels / divide_channels_by,
layer_name=layer_name + '_1x1h',
filter_size=1,
filter_size_2=1,
non_linearity=True,
batch_norm=batch_norm,
is_training=is_training,
data_format=data_format,
stddev=stddev)
# 3x3 with half size
if channel_multiplier == 0:
h_conv = two_dim_convnet.make_conv_layer(
h_conv,
in_channels=in_channels / divide_channels_by,
out_channels=in_channels / divide_channels_by,
layer_name=layer_name + '_%dx%dh' % (filter_size, filter_size_2),
filter_size=filter_size,
filter_size_2=filter_size_2,
non_linearity=True,
batch_norm=batch_norm,
is_training=is_training,
atrou_rate=atrou_rate,
data_format=data_format,
stddev=stddev)
else:
# We use separable convolution for 3x3
h_conv = two_dim_convnet.make_conv_sep2d_layer(
h_conv,
in_channels=in_channels / divide_channels_by,
channel_multiplier=channel_multiplier,
out_channels=in_channels / divide_channels_by,
layer_name=layer_name + '_sep%dx%dh' % (filter_size, filter_size_2),
filter_size=filter_size,
filter_size_2=filter_size_2,
batch_norm=batch_norm,
is_training=is_training,
atrou_rate=atrou_rate,
data_format=data_format,
stddev=stddev)
# 1x1 back to normal size without relu
h_conv = two_dim_convnet.make_conv_layer(
h_conv,
in_channels=in_channels / divide_channels_by,
out_channels=out_channels,
layer_name=layer_name + '_1x1',
filter_size=1,
filter_size_2=1,
non_linearity=False,
batch_norm=False,
is_training=is_training,
data_format=data_format,
stddev=stddev)
if dropout_keep_prob < 1.0:
logging.info('dropout keep prob %f', dropout_keep_prob)
h_conv = tf.nn.dropout(h_conv, keep_prob=dropout_keep_prob)
return h_conv + input_times_almost_1
def make_two_dim_resnet(
input_node,
num_residues=50,
num_features=40,
num_predictions=1,
num_channels=32,
num_layers=2,
filter_size=3,
filter_size_2=None,
final_non_linearity=False,
name_prefix='',
fancy=True,
batch_norm=False,
is_training=False,
atrou_rates=None,
channel_multiplier=0,
divide_channels_by=2,
resize_features_with_1x1=False,
data_format='NHWC',
stddev=0.01,
dropout_keep_prob=1.0):
"""Two dim resnet towers."""
del num_residues # Unused.
if atrou_rates is None:
atrou_rates = [1]
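  # Dilation rates are cycled through the layers, so with [1, 2, 4, 8] every
  # fourth layer uses the same atrou rate.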
if not fancy:
raise ValueError('non fancy deprecated')
logging.info('atrou rates %s', atrou_rates)
logging.info('name prefix %s', name_prefix)
x_image = input_node
previous_layer = x_image
non_linearity = True
for i_layer in range(num_layers):
in_channels = num_channels
out_channels = num_channels
curr_atrou_rate = atrou_rates[i_layer % len(atrou_rates)]
if i_layer == 0:
in_channels = num_features
if i_layer == num_layers - 1:
out_channels = num_predictions
non_linearity = final_non_linearity
if i_layer == 0 or i_layer == num_layers - 1:
layer_name = name_prefix + 'conv%d' % (i_layer + 1)
initial_filter_size = filter_size
if resize_features_with_1x1:
initial_filter_size = 1
previous_layer = two_dim_convnet.make_conv_layer(
input_node=previous_layer,
in_channels=in_channels,
out_channels=out_channels,
layer_name=layer_name,
filter_size=initial_filter_size,
filter_size_2=filter_size_2,
non_linearity=non_linearity,
atrou_rate=curr_atrou_rate,
data_format=data_format,
stddev=stddev)
else:
layer_name = name_prefix + 'res%d' % (i_layer + 1)
previous_layer = make_sep_res_layer(
input_node=previous_layer,
in_channels=in_channels,
out_channels=out_channels,
layer_name=layer_name,
filter_size=filter_size,
filter_size_2=filter_size_2,
batch_norm=batch_norm,
is_training=is_training,
atrou_rate=curr_atrou_rate,
channel_multiplier=channel_multiplier,
divide_channels_by=divide_channels_by,
data_format=data_format,
stddev=stddev,
dropout_keep_prob=dropout_keep_prob)
y = previous_layer
return y

View File

@@ -17,7 +17,7 @@ from __future__ import division
from __future__ import print_function
import sonnet as snt
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
from scratchgan import utils

View File

@@ -16,7 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub

View File

@@ -22,8 +22,8 @@ from absl import app
from absl import flags
from absl import logging
import numpy as np
-import tensorflow as tf
-from tensorflow.io import gfile
+import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1.io import gfile
from scratchgan import discriminator_nets
from scratchgan import eval_metrics
from scratchgan import generators

View File

@@ -18,7 +18,7 @@ from __future__ import print_function
from absl import logging
import sonnet as snt
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from scratchgan import utils

View File

@@ -17,7 +17,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
def sequential_cross_entropy_loss(logits, expected):

View File

@@ -22,7 +22,7 @@ import os
from absl import logging
import numpy as np
-from tensorflow.io import gfile
+from tensorflow.compat.v1.io import gfile
# sequences: [N, MAX_TOKENS_SEQUENCE] array of int32
# lengths: [N, 2] array of int32, such that

View File

@@ -19,8 +19,8 @@ import math
import os
from absl import logging
import numpy as np
-import tensorflow as tf
-from tensorflow.io import gfile
+import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1.io import gfile
from scratchgan import reader
EVAL_FILENAME = "evaluated_checkpoints.csv"