Make TensorFlow imports compatible with TF 1.14.

PiperOrigin-RevId: 289424199
This commit is contained in:
Augustin Zidek
2020-01-13 13:45:38 +00:00
committed by Diego de Las Casas
parent 3c6b08d3ca
commit 25a0bc935f
12 changed files with 26 additions and 35 deletions
+3 -5
View File
@@ -14,8 +14,7 @@
# limitations under the License. # limitations under the License.
"""Class for predicting Accessible Surface Area.""" """Class for predicting Accessible Surface Area."""
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from tensorflow.contrib import layers as contrib_layers
class ASAOutputLayer(object): class ASAOutputLayer(object):
@@ -26,9 +25,8 @@ class ASAOutputLayer(object):
def compute_asa_output(self, activations): def compute_asa_output(self, activations):
"""Just compute the logits and outputs given activations.""" """Just compute the logits and outputs given activations."""
asa_logits = contrib_layers.linear( asa_logits = tf.contrib.layers.linear(
activations, activations, 1,
1,
weights_initializer=tf.random_uniform_initializer(-0.01, 0.01), weights_initializer=tf.random_uniform_initializer(-0.01, 0.01),
scope='ASALogits') scope='ASALogits')
self.asa_output = tf.nn.relu(asa_logits, name='ASA_output_relu') self.asa_output = tf.nn.relu(asa_logits, name='ASA_output_relu')
+1 -1
View File
@@ -24,7 +24,7 @@ from absl import logging
import numpy as np import numpy as np
import six import six
import sonnet as snt import sonnet as snt
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import config_dict from alphafold_casp13 import config_dict
from alphafold_casp13 import contacts_experiment from alphafold_casp13 import contacts_experiment
+1 -1
View File
@@ -18,7 +18,7 @@ import collections
import enum import enum
import json import json
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
_ProteinDescription = collections.namedtuple( _ProteinDescription = collections.namedtuple(
+1 -1
View File
@@ -15,7 +15,7 @@
"""Contact prediction convnet experiment example.""" """Contact prediction convnet experiment example."""
from absl import logging from absl import logging
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import contacts_dataset from alphafold_casp13 import contacts_dataset
from alphafold_casp13 import contacts_network from alphafold_casp13 import contacts_network
+11 -16
View File
@@ -16,13 +16,12 @@
from absl import logging from absl import logging
import sonnet import sonnet
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import asa_output from alphafold_casp13 import asa_output
from alphafold_casp13 import secstruct from alphafold_casp13 import secstruct
from alphafold_casp13 import two_dim_convnet from alphafold_casp13 import two_dim_convnet
from alphafold_casp13 import two_dim_resnet from alphafold_casp13 import two_dim_resnet
from tensorflow.contrib import layers as contrib_layers
def call_on_tuple(f): def call_on_tuple(f):
@@ -194,27 +193,23 @@ class ContactsNet(sonnet.AbstractModule):
axis=join_dim)], axis=join_dim)],
axis=collapse_dim) # Join the two crops together. axis=collapse_dim) # Join the two crops together.
if self._collapsed_batch_norm: if self._collapsed_batch_norm:
embedding_1d = contrib_layers.batch_norm( embedding_1d = tf.contrib.layers.batch_norm(
embedding_1d, embedding_1d, is_training=use_on_the_fly_stats,
is_training=use_on_the_fly_stats, fused=True, decay=0.999, scope='collapsed_batch_norm',
fused=True,
decay=0.999,
scope='collapsed_batch_norm',
data_format='NHWC') data_format='NHWC')
for i, nfil in enumerate(self._filters_1d): for i, nfil in enumerate(self._filters_1d):
embedding_1d = contrib_layers.fully_connected( embedding_1d = tf.contrib.layers.fully_connected(
embedding_1d, embedding_1d,
num_outputs=nfil, num_outputs=nfil,
normalizer_fn=(contrib_layers.batch_norm normalizer_fn=(
if self._collapsed_batch_norm else None), tf.contrib.layers.batch_norm if self._collapsed_batch_norm
normalizer_params={ else None),
'is_training': use_on_the_fly_stats, normalizer_params={'is_training': use_on_the_fly_stats,
'updates_collections': None 'updates_collections': None},
},
scope='collapsed_embed_%d' % i) scope='collapsed_embed_%d' % i)
if self.torsion_multiplier > 0: if self.torsion_multiplier > 0:
self.torsion_logits = contrib_layers.fully_connected( self.torsion_logits = tf.contrib.layers.fully_connected(
embedding_1d, embedding_1d,
num_outputs=self._torsion_bins * self._torsion_bins, num_outputs=self._torsion_bins * self._torsion_bins,
activation_fn=None, activation_fn=None,
+1 -1
View File
@@ -22,7 +22,7 @@ import os
import numpy as np import numpy as np
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
RR_FORMAT = """PFRMAT RR RR_FORMAT = """PFRMAT RR
+1 -1
View File
@@ -24,7 +24,7 @@ import os
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import distogram_io from alphafold_casp13 import distogram_io
from alphafold_casp13 import parsers from alphafold_casp13 import parsers
+1 -1
View File
@@ -16,7 +16,7 @@
import pickle import pickle
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
def distance_histogram_dict(f): def distance_histogram_dict(f):
+1 -1
View File
@@ -21,7 +21,7 @@ from absl import flags
from absl import logging from absl import logging
import numpy as np import numpy as np
import six import six
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import distogram_io from alphafold_casp13 import distogram_io
from alphafold_casp13 import parsers from alphafold_casp13 import parsers
+2 -3
View File
@@ -18,8 +18,7 @@ import os
from absl import logging from absl import logging
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from tensorflow.contrib import layers as contrib_layers
# 8-class classes (Q8) # 8-class classes (Q8)
@@ -57,7 +56,7 @@ class Secstruct(object):
"""Make the layer.""" """Make the layer."""
with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
logging.info('Creating secstruct %s', activations) logging.info('Creating secstruct %s', activations)
self.logits = contrib_layers.linear(activations, self._dimension) self.logits = tf.contrib.layers.linear(activations, self._dimension)
self.ss_q8_probs = tf.nn.softmax(self.logits) self.ss_q8_probs = tf.nn.softmax(self.logits)
self.ss_q3_probs = tf.matmul( self.ss_q3_probs = tf.matmul(
self.ss_q8_probs, tf.constant(self.q3_map_matrix, dtype=tf.float32)) self.ss_q8_probs, tf.constant(self.q3_map_matrix, dtype=tf.float32))
+2 -3
View File
@@ -15,8 +15,7 @@
"""Two dimensional convolutional neural net layers.""" """Two dimensional convolutional neural net layers."""
from absl import logging from absl import logging
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from tensorflow.contrib import layers as contrib_layers
def weight_variable(shape, stddev=0.01): def weight_variable(shape, stddev=0.01):
@@ -92,7 +91,7 @@ def make_conv_sep2d_layer(input_node,
def batch_norm_layer(h_conv, layer_name, is_training=True, data_format='NCHW'): def batch_norm_layer(h_conv, layer_name, is_training=True, data_format='NCHW'):
"""Batch norm layer.""" """Batch norm layer."""
logging.vlog(1, 'batch norm for layer %s', layer_name) logging.vlog(1, 'batch norm for layer %s', layer_name)
return contrib_layers.batch_norm( return tf.contrib.layers.batch_norm(
h_conv, h_conv,
is_training=is_training, is_training=is_training,
fused=True, fused=True,
+1 -1
View File
@@ -15,7 +15,7 @@
"""2D Resnet.""" """2D Resnet."""
from absl import logging from absl import logging
import tensorflow.compat.v1 as tf import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import two_dim_convnet from alphafold_casp13 import two_dim_convnet