mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
a5efafff3a
PiperOrigin-RevId: 298822315
234 lines
10 KiB
Python
234 lines
10 KiB
Python
# Lint as: python3.
|
|
# Copyright 2019 DeepMind Technologies Limited
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Contact prediction convnet experiment example."""
|
|
|
|
from absl import logging
|
|
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
|
|
|
|
from alphafold_casp13 import contacts_dataset
|
|
from alphafold_casp13 import contacts_network
|
|
|
|
|
|
def _int_ph(shape, name):
  """Creates an int32 `tf.placeholder` named '<name>_placeholder'.

  Args:
    shape: Shape forwarded unchanged to `tf.placeholder`.
    name: Base name; the suffix '_placeholder' is appended to it.

  Returns:
    The result of `tf.placeholder` with dtype `tf.int32`.
  """
  placeholder_name = '%s_placeholder' % name
  return tf.placeholder(dtype=tf.int32, shape=shape, name=placeholder_name)
|
|
|
|
|
|
def _float_ph(shape, name):
  """Creates a float32 `tf.placeholder` named '<name>_placeholder'.

  Args:
    shape: Shape forwarded unchanged to `tf.placeholder`.
    name: Base name; the suffix '_placeholder' is appended to it.

  Returns:
    The result of `tf.placeholder` with dtype `tf.float32`.
  """
  placeholder_name = '%s_placeholder' % name
  return tf.placeholder(dtype=tf.float32, shape=shape, name=placeholder_name)
|
|
|
|
|
|
class Contacts(object):
  """Contact prediction experiment.

  Constructing an instance builds the whole evaluation graph: a dataset
  pipeline that yields one example at a time, a set of placeholders, and the
  network outputs `eval_probs_softmax` / `eval_probs` computed from those
  placeholders. `get_one_example` then pulls examples off the pipeline.
  """

  def __init__(
      self, tfrecord, stats_file, network_config, crop_size_x, crop_size_y,
      feature_normalization, normalization_exclusion):
    """Builds the TensorFlow graph.

    Args:
      tfrecord: TFRecord filename; read by the dataset pipeline and also
        iterated once to count the evaluation examples.
      stats_file: Path handed to
        `contacts_dataset.normalize_from_stats_file` as `stats_file_path`.
      network_config: Network configuration. It is unpacked as keyword
        arguments into `contacts_network.ContactsNet` and must also expose
        `features`, `scalars`, `targets` and `num_bins` as attributes (plus
        `torsion_bins` when the model's torsion multiplier is positive).
      crop_size_x: Forwarded to the model call as `crop_size_x`.
      crop_size_y: Forwarded to the model call as `crop_size_y`.
      feature_normalization: Normalization setting assigned to every feature
        that is not excluded.
      normalization_exclusion: Iterable of feature names that must not be
        normalized.
    """
    self.network_config = network_config
    self.crop_size_x = crop_size_x
    self.crop_size_y = crop_size_y

    self._feature_normalization = feature_normalization
    self._normalization_exclusion = normalization_exclusion
    self._model = contacts_network.ContactsNet(**network_config)
    self._features = network_config.features
    self._scalars = network_config.scalars
    self._targets = network_config.targets
    # Add extra targets we need: identification fields always, plus the
    # targets of every auxiliary head the model has switched on.
    required_targets = ['domain_name', 'resolution', 'chain_name']
    if self.model.torsion_multiplier > 0:
      required_targets.extend([
          'phi_angles', 'phi_mask', 'psi_angles', 'psi_mask'])
    if self.model.secstruct_multiplier > 0:
      required_targets.extend(['sec_structure', 'sec_structure_mask'])
    if self.model.asa_multiplier > 0:
      required_targets.extend(['solv_surf', 'solv_surf_mask'])
    extra_targets = [t for t in required_targets if t not in self._targets]
    if extra_targets:
      # Only rebuild (as a tuple) when something actually has to be appended;
      # otherwise the configured targets object is kept untouched.
      targets = list(self._targets)
      targets.extend(extra_targets)
      self._targets = tuple(targets)
    logging.info('Targets %s %s extra %s',
                 type(self._targets), self._targets, extra_targets)
    logging.info('Evaluating on %s, stats: %s', tfrecord, stats_file)
    self._build_evaluation_graph(tfrecord=tfrecord, stats_file=stats_file)

  @property
  def model(self):
    """The `ContactsNet` instance created in `__init__`."""
    return self._model

  def _get_feature_normalization(self, features):
    """Maps every non-excluded feature name to the normalization setting.

    Args:
      features: Iterable of feature names.

    Returns:
      A dict `{name: self._feature_normalization}` containing every name in
      `features` that does not appear in `self._normalization_exclusion`.
    """
    # Materialise the exclusions once rather than once per feature: this
    # avoids repeated copies, gives O(1) membership tests, and stays correct
    # if the exclusion is a one-shot iterable.
    excluded = frozenset(self._normalization_exclusion)
    return {key: self._feature_normalization
            for key in features
            if key not in excluded}

  def _build_evaluation_graph(self, tfrecord, stats_file):
    """Constructs the graph in pieces so it can be fed.

    Sets `_input_batch`, `num_eval_examples`, the `*_placeholder` attributes,
    `eval_probs_softmax` and `eval_probs` on `self`.

    Args:
      tfrecord: TFRecord filename to build the dataset from.
      stats_file: Path of the statistics file used for normalization.
    """
    with tf.name_scope('competitionsep'):
      # Construct the dataset and mapping ops.
      dataset = contacts_dataset.create_tf_dataset(
          tf_record_filename=tfrecord,
          features=tuple(self._features) + tuple(
              self._scalars) + tuple(self._targets))

      def normalize(data):
        return contacts_dataset.normalize_from_stats_file(
            features=data,
            stats_file_path=stats_file,
            feature_normalization=self._get_feature_normalization(
                self._features),
            # Names that are both an input feature and a target.
            copy_unnormalized=list(set(self._features) & set(self._targets)))

      def convert_to_legacy(features):
        return contacts_dataset.convert_to_legacy_proteins_dataset_format(
            features, self._features, self._scalars, self._targets)

      dataset = dataset.map(normalize)
      dataset = dataset.map(convert_to_legacy)
      dataset = dataset.batch(1)

      # Get a batch of tensors in the legacy ProteinsDataset format.
      iterator = tf.data.make_one_shot_iterator(dataset)
      self._input_batch = iterator.get_next()

      # Counting requires one full pass over the records in the file.
      self.num_eval_examples = sum(
          1 for _ in tf.python_io.tf_record_iterator(tfrecord))

      logging.info('Eval batch:\n%s', self._input_batch)
      feature_dim_1d = self._input_batch.inputs_1d.shape.as_list()[-1]
      feature_dim_2d = self._input_batch.inputs_2d.shape.as_list()[-1]
      feature_dim_2d *= 3  # The diagonals will be stacked before feeding.

      # Now placeholders for the graph to compute the outputs for one crop.
      self.inputs_1d_placeholder = _float_ph(
          shape=[None, None, feature_dim_1d], name='inputs_1d')
      self.residue_index_placeholder = _int_ph(
          shape=[None, None], name='residue_index')
      self.inputs_2d_placeholder = _float_ph(
          shape=[None, None, None, feature_dim_2d], name='inputs_2d')
      # 4 ints: x_start, x_end, y_start, y_end.
      self.crop_placeholder = _int_ph(shape=[None, 4], name='crop')

      # Finally placeholders for the graph to score the complete contact map.
      self.probs_placeholder = _float_ph(shape=[None, None, None], name='probs')
      self.softmax_probs_placeholder = _float_ph(
          shape=[None, None, None, self.network_config.num_bins],
          name='softmax_probs')
      self.cb_placeholder = _float_ph(shape=[None, None, 3], name='cb')
      self.cb_mask_placeholder = _float_ph(shape=[None, None], name='cb_mask')
      self.lengths_placeholder = _int_ph(shape=[None], name='lengths')

      # Placeholders for the auxiliary heads exist only when the matching
      # multiplier on the model is positive.
      if self.model.secstruct_multiplier > 0:
        self.sec_structure_placeholder = _float_ph(
            shape=[None, None, 8], name='sec_structure')
        self.sec_structure_logits_placeholder = _float_ph(
            shape=[None, None, 8], name='sec_structure_logits')
        self.sec_structure_mask_placeholder = _float_ph(
            shape=[None, None, 1], name='sec_structure_mask')

      if self.model.asa_multiplier > 0:
        self.solv_surf_placeholder = _float_ph(
            shape=[None, None, 1], name='solv_surf')
        self.solv_surf_logits_placeholder = _float_ph(
            shape=[None, None, 1], name='solv_surf_logits')
        self.solv_surf_mask_placeholder = _float_ph(
            shape=[None, None, 1], name='solv_surf_mask')

      if self.model.torsion_multiplier > 0:
        self.torsions_truth_placeholder = _float_ph(
            shape=[None, None, 2], name='torsions_truth')
        self.torsions_mask_placeholder = _float_ph(
            shape=[None, None, 1], name='torsions_mask')
        self.torsion_logits_placeholder = _float_ph(
            shape=[None, None, self.network_config.torsion_bins ** 2],
            name='torsion_logits')

      # Build a dict to pass all the placeholders into build.
      # NOTE: the keys are looked up by name by the model, so their exact
      # spelling (with or without the '_placeholder' suffix) must be kept.
      placeholders = {
          'inputs_1d_placeholder': self.inputs_1d_placeholder,
          'residue_index_placeholder': self.residue_index_placeholder,
          'inputs_2d_placeholder': self.inputs_2d_placeholder,
          'crop_placeholder': self.crop_placeholder,
          'probs_placeholder': self.probs_placeholder,
          'softmax_probs_placeholder': self.softmax_probs_placeholder,
          'cb_placeholder': self.cb_placeholder,
          'cb_mask_placeholder': self.cb_mask_placeholder,
          'lengths_placeholder': self.lengths_placeholder,
      }
      if self.model.secstruct_multiplier > 0:
        placeholders.update({
            'sec_structure': self.sec_structure_placeholder,
            'sec_structure_logits_placeholder':
                self.sec_structure_logits_placeholder,
            'sec_structure_mask': self.sec_structure_mask_placeholder,})
      if self.model.asa_multiplier > 0:
        placeholders.update({
            'solv_surf': self.solv_surf_placeholder,
            'solv_surf_logits_placeholder': self.solv_surf_logits_placeholder,
            'solv_surf_mask': self.solv_surf_mask_placeholder,})
      if self.model.torsion_multiplier > 0:
        placeholders.update({
            'torsions_truth': self.torsions_truth_placeholder,
            'torsion_logits_placeholder': self.torsion_logits_placeholder,
            'torsions_truth_mask': self.torsions_mask_placeholder,})

      activations = self._model(
          crop_size_x=self.crop_size_x,
          crop_size_y=self.crop_size_y,
          placeholders=placeholders)
      # Softmax over the first `num_bins` channels of the activations.
      self.eval_probs_softmax = tf.nn.softmax(
          activations[:, :, :, :self.network_config.num_bins])
      # Sum the softmax over all bins below the model's quantisation
      # threshold, collapsing the last axis.
      self.eval_probs = tf.reduce_sum(
          self.eval_probs_softmax[:, :, :, :self._model.quant_threshold()],
          axis=3)

  def get_one_example(self, sess):
    """Pull one example off the queue so we can feed it for evaluation.

    Args:
      sess: Object whose `run` method evaluates a dict of fetches (a
        `tf.Session` in normal use).

    Returns:
      Whatever `sess.run` returns for the request dict. The dict always has
      the keys 'inputs_1d', 'inputs_2d', 'sequence_lengths', 'beta_positions',
      'beta_mask', 'domain_name', 'chain_name' and 'sequences'; optional
      target groups are added when the input batch provides them.
    """
    batch = self._input_batch
    targets = batch.targets
    request_dict = {
        'inputs_1d': batch.inputs_1d,
        'inputs_2d': batch.inputs_2d,
        'sequence_lengths': batch.sequence_lengths,
        'beta_positions': targets.beta_positions,
        'beta_mask': targets.beta_mask,
        'domain_name': targets.domain_name,
        'chain_name': targets.chain_name,
        'sequences': batch.sequences,
    }
    # Optional target groups. The presence of the FIRST name in a group
    # decides whether the whole group is requested.
    optional_target_groups = (
        ('residue_index',),
        ('phi_angles', 'psi_angles', 'phi_mask', 'psi_mask'),
        ('sec_structure', 'sec_structure_mask'),
        ('solv_surf', 'solv_surf_mask'),
        ('alpha_positions', 'alpha_mask'),
    )
    for group in optional_target_groups:
      if hasattr(targets, group[0]):
        request_dict.update({name: getattr(targets, name) for name in group})
    return sess.run(request_dict)
|