Files
deepmind-research/alphafold_casp13/contacts_network.py
T
Augustin Zidek d8d6d1e75a Add more information on data access and remove dead argument.
PiperOrigin-RevId: 289429259
2020-01-15 18:22:17 +00:00

492 lines
20 KiB
Python

# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Network for predicting C-beta contacts."""
from absl import logging
import sonnet
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import asa_output
from alphafold_casp13 import secstruct
from alphafold_casp13 import two_dim_convnet
from alphafold_casp13 import two_dim_resnet
def call_on_tuple(f):
  """Adapts `f` so it accepts one tuple and spreads it over f's arguments.

  Python 3 removed tuple-parameter unpacking in lambdas (PEP 3113). This
  helper restores that convenience, e.g. for `tf.map_fn` callbacks, which
  receive their `elems` tuple as a single argument.

  Args:
    f: A function taking multiple positional arguments.

  Returns:
    A one-argument function that calls `f(*args)` on the sequence it receives.
  """
  def _unpacked(args):
    return f(*args)

  return _unpacked
class ContactsNet(sonnet.AbstractModule):
  """A network to go from sequence to distance histograms.

  The 1D inputs are cropped and tiled into pairwise features. These are
  concatenated with the 2D inputs and passed through a 2D residual network
  (`two_dim_resnet.make_two_dim_resnet`, called with atrou_rates=[1, 2, 4, 8])
  to produce per-crop logits.

  When the corresponding multipliers are positive, auxiliary 1D heads are
  computed from the 2D activations collapsed along each crop axis:
  `secstruct.Secstruct`, `asa_output.ASAOutputLayer` and a torsion softmax.
  """

  def __init__(self,
               binary_code_bits,
               data_format,
               distance_multiplier,
               features,
               features_forward,
               max_range,
               min_range,
               num_bins,
               reshape_layer,
               resolution_noise_scale,
               scalars,
               targets,
               network_2d_deep,
               torsion_bins=None,
               skip_connect=0,
               position_specific_bias_size=0,
               filters_1d=(),
               collapsed_batch_norm=False,
               is_ca_feature=False,
               asa_multiplier=0.0,
               secstruct_multiplier=0.0,
               torsion_multiplier=0.0,
               name='contacts_net'):
    """Construct position prediction network.

    Args:
      binary_code_bits: If non-zero, the number of low-order bits of each
        residue index appended as binary 2D position features.
      data_format: Stored on the module but not read elsewhere in this class;
        `_build` always passes 'NHWC' to `compute_outputs`.
      distance_multiplier: If > 0, one extra output channel is added to the
        logits.
      features: Collection of input feature names. `_build` checks it for
        'aatype' and 'seq_length'.
      features_forward: Number of forwarded feature channels. It is added to
        the output layer's input channel count when forwarded features are
        present.
      max_range: Used as the bin range by `quant_threshold`.
      min_range: Offset subtracted from the threshold in `quant_threshold`
        (presumably the lower edge of the first bin).
      num_bins: Number of output bins; a falsy value yields one output channel.
      reshape_layer: If truthy, the deep network emits
        `network_2d_deep.num_filters` channels with a final non-linearity, and
        a 1x1 convolution maps them to the output dimension. Otherwise the
        deep network emits the output dimension directly.
      resolution_noise_scale: Stored; not read in this class.
      scalars: Stored; not read in this class.
      targets: Stored; not read in this class.
      network_2d_deep: Config for the 2D network. Its `extra_blocks`,
        `num_filters`, `num_blocks`, `num_layers_per_block` and
        `use_batch_norm` attributes are read.
      torsion_bins: The torsion head emits `torsion_bins * torsion_bins`
        logits, so this must be set when `torsion_multiplier > 0`.
      skip_connect: If truthy and `network_2d_deep.extra_blocks` is set, the
        output of the extra blocks is also fed to the output layer.
      position_specific_bias_size: If non-zero, the number of rows of a
        learned, zero-initialised bias that is indexed by diagonal offset and
        added to the logits.
      filters_1d: Widths of fully connected layers applied to the collapsed
        1D embedding before the auxiliary heads.
      collapsed_batch_norm: Whether to batch-normalise the collapsed 1D
        embedding and its fully connected layers.
      is_ca_feature: If set and 'aatype' is a feature, `_build` keeps only
        channel 7 of the 1D inputs.
      asa_multiplier: If > 0, build the `asa_output.ASAOutputLayer` head.
      secstruct_multiplier: If > 0, build the `secstruct.Secstruct` head.
      torsion_multiplier: If > 0, build the torsion head.
      name: Name of the Sonnet module.
    """
    super(ContactsNet, self).__init__(name=name)
    self._filters_1d = filters_1d
    self._collapsed_batch_norm = collapsed_batch_norm
    self._is_ca_feature = is_ca_feature
    self._binary_code_bits = binary_code_bits
    self._data_format = data_format
    self._distance_multiplier = distance_multiplier
    self._features = features
    self._features_forward = features_forward
    self._max_range = max_range
    self._min_range = min_range
    self._num_bins = num_bins
    self._position_specific_bias_size = position_specific_bias_size
    self._reshape_layer = reshape_layer
    self._resolution_noise_scale = resolution_noise_scale
    self._scalars = scalars
    self._torsion_bins = torsion_bins
    self._skip_connect = skip_connect
    self._targets = targets
    self._network_2d_deep = network_2d_deep

    self.asa_multiplier = asa_multiplier
    self.secstruct_multiplier = secstruct_multiplier
    self.torsion_multiplier = torsion_multiplier

    # Sub-modules and variables are created under this module's variable
    # scope, so their names are prefixed by the module name.
    with self._enter_variable_scope():
      if self.secstruct_multiplier > 0:
        self._secstruct = secstruct.Secstruct()
      if self.asa_multiplier > 0:
        self._asa = asa_output.ASAOutputLayer()
      if self._position_specific_bias_size:
        # One row of biases per diagonal offset, one column per output bin.
        self._position_specific_bias = tf.get_variable(
            'position_specific_bias',
            [self._position_specific_bias_size, self._num_bins or 1],
            initializer=tf.zeros_initializer())

  def quant_threshold(self, threshold=8.0):
    """Find the bin that is 8A+: we sum mass below this bin gives contact prob.

    Computes `int((threshold - min_range) * num_bins / max_range)`. This is
    the index of the bin containing `threshold` when bins are taken to be
    `max_range / num_bins` wide. `int()` truncates toward zero.

    Args:
      threshold: The distance threshold.

    Returns:
      Index of bin.
    """
    # Note that this misuses the max_range as the range.
    return int(
        (threshold - self._min_range) * self._num_bins / float(self._max_range))

  def _build(self, crop_size_x=0, crop_size_y=0, placeholders=None):
    """Puts the network into the graph.

    Args:
      crop_size_x: Crop a chunk out in one dimension. 0 means no cropping.
      crop_size_y: Crop a chunk out in one dimension. 0 means no cropping.
      placeholders: A dict containing the placeholders needed. Read keys are
        'crop_placeholder' (columns 0:2 are crop_x, 2:4 are crop_y),
        'inputs_1d_placeholder', 'residue_index_placeholder' and
        'inputs_2d_placeholder'. It must be provided despite the None default.

    Returns:
      The NHWC logits from `compute_outputs`, which always runs here with
      data_format='NHWC'. They have `num_bins or 1` channels, plus one when
      `distance_multiplier > 0`. The spatial dims presumably are
      [crop_size_y, crop_size_x]; confirm against the inputs_2d crop size.
    """
    crop_placeholder = placeholders['crop_placeholder']
    inputs_1d = placeholders['inputs_1d_placeholder']
    if self._is_ca_feature and 'aatype' in self._features:
      logging.info('Collapsing aatype to is_ca_feature %s',
                   inputs_1d.shape.as_list()[-1])
      # Expect at most 21 aatype channels (+1 if seq_length is a feature).
      assert inputs_1d.shape.as_list()[-1] <= 21 + (
          1 if 'seq_length' in self._features else 0)
      # Keep only channel 7 as the single 1D input channel.
      inputs_1d = inputs_1d[:, :, 7:8]
    logits = self.compute_outputs(
        inputs_1d=inputs_1d,
        residue_index=placeholders['residue_index_placeholder'],
        inputs_2d=placeholders['inputs_2d_placeholder'],
        crop_x=crop_placeholder[:, 0:2],
        crop_y=crop_placeholder[:, 2:4],
        use_on_the_fly_stats=True,
        crop_size_x=crop_size_x,
        crop_size_y=crop_size_y,
        data_format='NHWC',  # Force NHWC for evals.
        )
    return logits

  def compute_outputs(self, inputs_1d, residue_index, inputs_2d, crop_x, crop_y,
                      use_on_the_fly_stats, crop_size_x, crop_size_y,
                      data_format='NHWC'):
    """Given the inputs for a block, compute the network outputs.

    When the auxiliary multipliers are positive, this also has side effects:
    it sets `self.torsion_logits`/`self.torsion_output` and `self.asa_logits`,
    and builds the secondary-structure layer via `make_layer_new`.

    Args:
      inputs_1d: 1D features, indexed as [batch, residue, channel] by the 2D
        feature construction.
      residue_index: Per-residue indices, cropped alongside `inputs_1d`.
      inputs_2d: 2D input features, channels last.
      crop_x: [batch, 2] start/end of each crop along x.
      crop_y: [batch, 2] start/end of each crop along y.
      use_on_the_fly_stats: Passed as `is_training` to the batch norm layers.
      crop_size_x: Size of each crop along x.
      crop_size_y: Size of each crop along y.
      data_format: 'NHWC' or 'NCHW' layout used inside the deep 2D network.

    Returns:
      NHWC logits with `num_bins or 1` channels, plus one extra channel when
      `distance_multiplier > 0`.
    """
    hidden_1d = inputs_1d
    hidden_1d_list = [hidden_1d]
    # hidden_1d_list only ever has one element here, so this concat never runs.
    if len(hidden_1d_list) != 1:
      hidden_1d = tf.concat(hidden_1d_list, 2)

    output_dimension = self._num_bins or 1
    if self._distance_multiplier > 0:
      output_dimension += 1
    logits, activations = self._build_2d_embedding(
        hidden_1d=hidden_1d,
        residue_index=residue_index,
        inputs_2d=inputs_2d,
        output_dimension=output_dimension,
        use_on_the_fly_stats=use_on_the_fly_stats,
        crop_x=crop_x,
        crop_y=crop_y,
        crop_size_x=crop_size_x, crop_size_y=crop_size_y,
        data_format=data_format)
    logits = tf.debugging.check_numerics(
        logits, 'NaN in resnet activations', name='resnet_activations')

    if (self.secstruct_multiplier > 0 or
        self.asa_multiplier > 0 or
        self.torsion_multiplier > 0):
      # Make a 1d embedding by reducing the 2D activations.
      # We do this in the x direction and the y direction separately.
      # activations are NHWC, so the result stacks the x-crop positions and
      # then the y-crop positions along axis 1. Each position has 2 * C
      # channels: the max and the mean.
      collapse_dim = 1
      join_dim = -1
      embedding_1d = tf.concat(
          # First targets are crop_x (axis 2) which we must reduce on axis 1
          [tf.concat([tf.reduce_max(activations, axis=collapse_dim),
                      tf.reduce_mean(activations, axis=collapse_dim)],
                     axis=join_dim),
           # Next targets are crop_y (axis 1) which we must reduce on axis 2
           tf.concat([tf.reduce_max(activations, axis=collapse_dim+1),
                      tf.reduce_mean(activations, axis=collapse_dim+1)],
                     axis=join_dim)],
          axis=collapse_dim)  # Join the two crops together.
      if self._collapsed_batch_norm:
        embedding_1d = tf.contrib.layers.batch_norm(
            embedding_1d, is_training=use_on_the_fly_stats,
            fused=True, decay=0.999, scope='collapsed_batch_norm',
            data_format='NHWC')
      for i, nfil in enumerate(self._filters_1d):
        # normalizer_params is ignored when normalizer_fn is None.
        embedding_1d = tf.contrib.layers.fully_connected(
            embedding_1d,
            num_outputs=nfil,
            normalizer_fn=(
                tf.contrib.layers.batch_norm if self._collapsed_batch_norm
                else None),
            normalizer_params={'is_training': use_on_the_fly_stats,
                               'updates_collections': None},
            scope='collapsed_embed_%d' % i)

      if self.torsion_multiplier > 0:
        # Requires self._torsion_bins to be set (None * None would raise).
        self.torsion_logits = tf.contrib.layers.fully_connected(
            embedding_1d,
            num_outputs=self._torsion_bins * self._torsion_bins,
            activation_fn=None,
            scope='torsion_logits')
        self.torsion_output = tf.nn.softmax(self.torsion_logits)
      if self.secstruct_multiplier > 0:
        self._secstruct.make_layer_new(embedding_1d)
      if self.asa_multiplier > 0:
        self.asa_logits = self._asa.compute_asa_output(embedding_1d)
    return logits

  @staticmethod
  def _concatenate_2d(hidden_1d, residue_index, hidden_2d, crop_x, crop_y,
                      binary_code_bits, crop_size_x, crop_size_y):
    """Builds the stacked 2D feature tensor for each crop.

    `hidden_1d` and `residue_index` are cropped to the x and y windows, with
    zero-padding where a window extends past either end of the sequence.
    The following are then concatenated after `hidden_2d` along axis 3:
      * the y residue index as (index - 100) / 100, tiled along x;
      * the (x - y) residue index offset divided by 100;
      * optionally, `binary_code_bits` binary digits of the y and x indices;
      * the x-cropped 1D features tiled along y, and the y-cropped 1D
        features tiled along x.

    Args:
      hidden_1d: [batch, residue, channel] 1D features.
      residue_index: Per-residue integer indices (cropped with tf.int32).
      hidden_2d: Channels-last 2D features for the crop.
      crop_x: [batch, 2] start/end of each crop along x.
      crop_y: [batch, 2] start/end of each crop along y.
      binary_code_bits: Number of binary position bits; 0 disables them.
      crop_size_x: Crop size along x (used for padding and tiling).
      crop_size_y: Crop size along y (used for padding and tiling).

    Returns:
      The concatenated tensor, with its static channel count set to the
      total depth.
    """
    # Form the pairwise expansion of the 1D embedding
    # And the residue offsets and (one) absolute position.
    with tf.name_scope('Features2D'):
      range_scale = 100.0  # Crude normalization factor.
      n = tf.shape(hidden_1d)[1]

      # Each map_fn slices [max(0, start), end) of one example and pads
      # max(0, -start) rows before and enough rows after to reach the crop
      # size.
      # pylint: disable=g-long-lambda
      hidden_1d_cropped_y = tf.map_fn(
          call_on_tuple(lambda c, h: tf.pad(
              h[tf.maximum(0, c[0]):c[1]],
              [[tf.maximum(0, -c[0]),
                tf.maximum(0, crop_size_y -(n - c[0]))], [0, 0]])),
          elems=(crop_y, hidden_1d), dtype=tf.float32,
          back_prop=True)
      range_n_y = tf.map_fn(
          call_on_tuple(lambda ri, c: tf.pad(
              ri[tf.maximum(0, c[0]):c[1]],
              [[tf.maximum(0, -c[0]),
                tf.maximum(0, crop_size_y -(n - c[0]))]])),
          elems=(residue_index, crop_y), dtype=tf.int32,
          back_prop=False)
      hidden_1d_cropped_x = tf.map_fn(
          call_on_tuple(lambda c, h: tf.pad(
              h[tf.maximum(0, c[0]):c[1]],
              [[tf.maximum(0, -c[0]),
                tf.maximum(0, crop_size_x -(n - c[0]))], [0, 0]])),
          elems=(crop_x, hidden_1d), dtype=tf.float32,
          back_prop=True)
      range_n_x = tf.map_fn(
          call_on_tuple(lambda ri, c: tf.pad(
              ri[tf.maximum(0, c[0]):c[1]],
              [[tf.maximum(0, -c[0]),
                tf.maximum(0, crop_size_x -(n - c[0]))]])),
          elems=(residue_index, crop_x), dtype=tf.int32,
          back_prop=False)
      # pylint: enable=g-long-lambda
      n_x = crop_size_x
      n_y = crop_size_y

      # Broadcast [B, 1, n_x] - [B, n_y, 1] -> [B, n_y, n_x] offsets (x - y).
      offset = (tf.expand_dims(tf.cast(range_n_x, tf.float32), 1) -
                tf.expand_dims(tf.cast(range_n_y, tf.float32), 2)) / range_scale
      position_features = [
          tf.tile(
              tf.reshape(
                  (tf.cast(range_n_y, tf.float32) - range_scale) / range_scale,
                  [-1, n_y, 1, 1]), [1, 1, n_x, 1],
              name='TileRange'),
          # Multiples of all ones: this tile only reshapes/names the offset.
          tf.tile(
              tf.reshape(offset, [-1, n_y, n_x, 1]), [1, 1, 1, 1],
              name='TileOffset')
      ]
      channels = 2
      if binary_code_bits:
        # Binary coding of position.
        # Bit i of each index is (index // 2**i) mod 2.
        exp_range_n_y = tf.expand_dims(range_n_y, 2)
        bin_y = tf.stop_gradient(
            tf.concat([tf.math.floormod(exp_range_n_y // (1 << i), 2)
                       for i in range(binary_code_bits)], 2))
        exp_range_n_x = tf.expand_dims(range_n_x, 2)
        bin_x = tf.stop_gradient(
            tf.concat([tf.math.floormod(exp_range_n_x // (1 << i), 2)
                       for i in range(binary_code_bits)], 2))
        position_features += [
            tf.tile(
                tf.expand_dims(tf.cast(bin_y, tf.float32), 2), [1, 1, n_x, 1],
                name='TileBinRangey'),
            tf.tile(
                tf.expand_dims(tf.cast(bin_x, tf.float32), 1), [1, n_y, 1, 1],
                name='TileBinRangex')
        ]
        channels += 2 * binary_code_bits

      augmentation_features = position_features + [
          tf.tile(tf.expand_dims(hidden_1d_cropped_x, 1),
                  [1, n_y, 1, 1], name='Tile1Dx'),
          tf.tile(tf.expand_dims(hidden_1d_cropped_y, 2),
                  [1, 1, n_x, 1], name='Tile1Dy')]
      channels += 2 * hidden_1d.shape.as_list()[-1]
      channels += hidden_2d.shape.as_list()[-1]
      hidden_2d = tf.concat(
          [hidden_2d] + augmentation_features, 3, name='Stack2Dfeatures')
      logging.info('2d stacked features are depth %d %s', channels, hidden_2d)
      # Give downstream layers a static channel count.
      hidden_2d.set_shape([None, None, None, channels])
      return hidden_2d

  def _build_2d_embedding(self, hidden_1d, residue_index, inputs_2d,
                          output_dimension, use_on_the_fly_stats, crop_x,
                          crop_y, crop_size_x, crop_size_y, data_format):
    """Returns NHWC logits and NHWC preactivations.

    The 1D features are stacked onto `inputs_2d` via `_concatenate_2d`. The
    result is optionally passed through `extra_blocks` double-width residual
    blocks, then through the main 'Deep2D' residual network and the output
    layer (`_output_from_pre_logits`).

    Returns:
      (contact_logits, contact_pre_logits), both NHWC. The pre-logits are
      transposed back from NCHW when `data_format == 'NCHW'`.
    """
    logging.info('2d %s %s', inputs_2d, data_format)

    # Stack with diagonal has already happened.
    inputs_2d_cropped = inputs_2d

    # Always None in this code path, so the features_forward branches below
    # and in _output_from_pre_logits are inactive.
    features_forward = None
    hidden_2d = inputs_2d_cropped

    hidden_2d = self._concatenate_2d(
        hidden_1d, residue_index, hidden_2d, crop_x, crop_y,
        self._binary_code_bits, crop_size_x, crop_size_y)

    config_2d_deep = self._network_2d_deep
    num_features = hidden_2d.shape.as_list()[3]
    if data_format == 'NCHW':
      logging.info('NCHW shape deep pre %s', hidden_2d)
      hidden_2d = tf.transpose(hidden_2d, perm=[0, 3, 1, 2])
      hidden_2d.set_shape([None, num_features, None, None])
      logging.info('NCHW shape deep post %s', hidden_2d)

    layers_forward = None
    if config_2d_deep.extra_blocks:
      # Optionally put some extra double-size blocks at the beginning.
      with tf.variable_scope('Deep2DExtra'):
        hidden_2d = two_dim_resnet.make_two_dim_resnet(
            input_node=hidden_2d,
            num_residues=None,  # Unused
            num_features=num_features,
            num_predictions=2 * config_2d_deep.num_filters,
            num_channels=2 * config_2d_deep.num_filters,
            num_layers=config_2d_deep.extra_blocks *
            config_2d_deep.num_layers_per_block,
            filter_size=3,
            batch_norm=config_2d_deep.use_batch_norm,
            is_training=use_on_the_fly_stats,
            fancy=True,
            final_non_linearity=True,
            atrou_rates=[1, 2, 4, 8],
            data_format=data_format,
            dropout_keep_prob=1.0
            )
      num_features = 2 * config_2d_deep.num_filters
      if self._skip_connect:
        # Forwarded to the output layer, which accounts for its
        # 2 * num_filters channels.
        layers_forward = hidden_2d
      if features_forward is not None:
        hidden_2d = tf.concat([hidden_2d, features_forward], 1
                              if data_format == 'NCHW' else 3)
    with tf.variable_scope('Deep2D'):
      logging.info('2d hidden shape is %s', str(hidden_2d.shape.as_list()))
      contact_pre_logits = two_dim_resnet.make_two_dim_resnet(
          input_node=hidden_2d,
          num_residues=None,  # Unused
          num_features=num_features,
          num_predictions=(config_2d_deep.num_filters
                           if self._reshape_layer else output_dimension),
          num_channels=config_2d_deep.num_filters,
          num_layers=config_2d_deep.num_blocks *
          config_2d_deep.num_layers_per_block,
          filter_size=3,
          batch_norm=config_2d_deep.use_batch_norm,
          is_training=use_on_the_fly_stats,
          fancy=True,
          final_non_linearity=self._reshape_layer,
          atrou_rates=[1, 2, 4, 8],
          data_format=data_format,
          dropout_keep_prob=1.0
          )

      # NOTE(review): the nesting of this call inside the 'Deep2D' variable
      # scope determines the output layer's variable names; confirm against
      # the checkpoint if restructuring.
      contact_logits = self._output_from_pre_logits(
          contact_pre_logits, features_forward, layers_forward,
          output_dimension, data_format, crop_x, crop_y, use_on_the_fly_stats)

      if data_format == 'NCHW':
        contact_pre_logits = tf.transpose(contact_pre_logits, perm=[0, 2, 3, 1])
    # Both of these will be NHWC
    return contact_logits, contact_pre_logits

  def _output_from_pre_logits(self, contact_pre_logits, features_forward,
                              layers_forward, output_dimension, data_format,
                              crop_x, crop_y, use_on_the_fly_stats):
    """Given pre-logits, compute the final distogram/contact activations.

    With `reshape_layer`, the pre-logits are concatenated with any forwarded
    tensors and mapped to `output_dimension` channels by a 1x1 convolution.
    Otherwise they are used as the logits directly. The result is transposed
    to NHWC if needed. When `position_specific_bias_size` is set, the
    per-crop offset biases from `build_crops_biases` are added.

    Returns:
      NHWC logits.
    """
    config_2d_deep = self._network_2d_deep
    if self._reshape_layer:
      in_channels = config_2d_deep.num_filters
      concat_features = [contact_pre_logits]
      if features_forward is not None:
        concat_features.append(features_forward)
        in_channels += self._features_forward
      if layers_forward is not None:
        concat_features.append(layers_forward)
        in_channels += 2 * config_2d_deep.num_filters
      if len(concat_features) > 1:
        contact_pre_logits = tf.concat(concat_features,
                                       1 if data_format == 'NCHW' else 3)

      contact_logits = two_dim_convnet.make_conv_layer(
          contact_pre_logits,
          in_channels=in_channels,
          out_channels=output_dimension,
          layer_name='output_reshape_1x1h',
          filter_size=1,
          filter_size_2=1,
          non_linearity=False,
          batch_norm=config_2d_deep.use_batch_norm,
          is_training=use_on_the_fly_stats,
          data_format=data_format)
    else:
      contact_logits = contact_pre_logits

    if data_format == 'NCHW':
      contact_logits = tf.transpose(contact_logits, perm=[0, 2, 3, 1])

    if self._position_specific_bias_size:
      # Make 2D pos-specific biases: NHWC.
      biases = build_crops_biases(
          self._position_specific_bias_size,
          self._position_specific_bias, crop_x, crop_y, back_prop=True)
      contact_logits += biases

    # Will be NHWC.
    return contact_logits

  def update_crop_fetches(self, fetches):
    """Add auxiliary outputs for a crop to the fetches.

    Mutates `fetches` in place. For each enabled head it adds one of the
    keys 'secstruct_probs', 'asa_output' and 'torsion_probs'. It reads
    attributes set when the heads are built (e.g. `self.torsion_output` from
    `compute_outputs`), so call it after the graph has been built.
    """
    if self.secstruct_multiplier > 0:
      fetches['secstruct_probs'] = self._secstruct.get_q8_probs()
    if self.asa_multiplier > 0:
      fetches['asa_output'] = self._asa.asa_output
    if self.torsion_multiplier > 0:
      fetches['torsion_probs'] = self.torsion_output
def build_crops_biases(bias_size, raw_biases, crop_x, crop_y, back_prop):
  """Expands diagonal-offset biases into a 2D bias block for every crop.

  `raw_biases` holds one row per diagonal offset |x - y|. Construction runs
  in three steps:
    1. Extend the rows by repeating the last one, until every offset reached
       by the current crops has an entry.
    2. Mirror the rows, excluding offset 0, so that negative offsets reuse
       the same values. Offset 0 then sits at row `num_rows - 1`.
    3. Slice one window of rows for every row of every crop.

  Args:
    bias_size: Number of stored bias rows.
    raw_biases: The [bias_size, C] bias variable.
    crop_x: B x 2 array of start/end for the batch.
    crop_y: B x 2 array of start/end for the batch.
    back_prop: Whether to backprop through the map_fn.

  Returns:
    Biases reshaped to [B, crop_size_y, crop_size_x, C].
  """
  # Largest |diagonal offset| touched by any crop in the batch.
  furthest_offset = tf.reduce_max(
      tf.maximum(tf.abs(crop_x[:, 1] - crop_y[:, 0]),
                 tf.abs(crop_y[:, 1] - crop_x[:, 0])))
  num_rows = tf.maximum(bias_size, furthest_offset)

  # Step 1: pad with copies of the final row up to num_rows rows.
  last_row = raw_biases[-1:, :]
  filler = tf.tile(last_row, [num_rows - bias_size, 1])
  one_sided = tf.concat([raw_biases, filler], axis=0)

  # Step 2: prepend the mirror image (without the 0th row) for below-diagonal.
  mirrored = tf.reverse(one_sided[1:, :], axis=[0])
  two_sided = tf.concat([mirrored, one_sided], axis=0)

  # Diagonal on which each crop's top-left cell lies: B x 1.
  top_left_diag = crop_x[:, 0:1] - crop_y[:, 0:1]
  width = tf.reduce_max(crop_x[:, 1] - crop_x[:, 0])
  height = tf.reduce_max(crop_y[:, 1] - crop_y[:, 0])

  # Stepping down one row moves one diagonal to the left: 1 x height.
  row_steps = tf.expand_dims(-tf.range(0, height), 0)
  # Starting diagonal of every row of every crop, flattened: B*height.
  row_offsets = tf.reshape(top_left_diag + row_steps, [-1])
  logging.info('row_offsets %s', row_offsets)
  # Shift so offsets index into two_sided, where diagonal 0 is num_rows - 1.
  row_offsets += num_rows - 1

  def _slice_row(start):
    # One crop row: width consecutive diagonals beginning at `start`.
    return two_sided[start:start + width, :]

  # Step 3: B*height x width x C.
  cropped_biases = tf.map_fn(_slice_row,
                             elems=row_offsets,
                             dtype=tf.float32,
                             back_prop=back_prop)
  logging.info('cropped_biases %s', cropped_biases)
  num_channels = tf.shape(raw_biases)[-1]
  return tf.reshape(cropped_biases, [-1, height, width, num_channels])