Files
deepmind-research/alphafold_casp13/ensemble_contact_maps.py
T
Augustin Zidek a5efafff3a Import TensorFlow without compat mode as our code targets TF 1.14.
PiperOrigin-RevId: 298822315
2020-03-04 13:25:32 +00:00

126 lines
4.6 KiB
Python

# Lint as: python3.
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Form a weighted average of several distograms.
Can also/instead form a weighted average of a set of distance histogram pickle
files, so long as they have identical hyperparameters.
"""
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import distogram_io
from alphafold_casp13 import parsers
# Command-line flags. Both list flags arrive as lists of strings; main()
# converts the weights to floats.

# Directories whose same-named pickle files get ensembled together.
flags.DEFINE_list(
    name='pickle_dirs',
    default=[],
    help='Comma separated list of directories with pickle files to ensemble.')

# Optional per-directory weights; main() uses a weight of 1.0 per directory
# when this is left empty.
flags.DEFINE_list(
    name='weights',
    default=[],
    help='Comma separated list of weights for the pickle files from different dirs.')

# Destination directory for the ensembled pickle files.
flags.DEFINE_string(
    name='output_dir',
    default=None,
    help='Directory where to save results of the evaluation.')

FLAGS = flags.FLAGS
def ensemble_distance_histograms(pickle_dirs, weights, output_dir):
  """Find all the contact maps in the first dir, then ensemble across dirs.

  The targets are taken from the `*.pickle` files in `pickle_dirs[0]`. For each
  target, the same-named pickle is looked up in every directory, the files are
  combined with `ensemble_one_distance_histogram`, and the result is saved as
  `<output_dir>/<target>.pickle`. Targets for which the helper returns no dict
  are counted as skipped and nothing is written for them.

  Args:
    pickle_dirs: List of directories containing distance histogram pickles. If
      empty, a warning is logged and nothing else happens.
    weights: List of weights, passed through unchanged to
      `ensemble_one_distance_histogram`.
    output_dir: Directory the ensembled pickles are written to.

  Raises:
    AssertionError: If the first directory contains no `*.pickle` files.
  """
  if len(pickle_dirs) <= 1:
    logging.warning('Pointless to ensemble %d pickle_dirs %s',
                    len(pickle_dirs), pickle_dirs)
    # Carry on if there's one dir, otherwise do nothing.
    if not pickle_dirs:
      return

  tf.io.gfile.makedirs(output_dir)

  one_dir_pickle_files = tf.io.gfile.glob(
      os.path.join(pickle_dirs[0], '*.pickle'))
  assert one_dir_pickle_files, pickle_dirs[0]
  num_targets = len(one_dir_pickle_files)
  logging.info('Found %d files in first of %d dirs',
               num_targets, len(pickle_dirs))

  # A target is the pickle's base name without directory or extension.
  targets = [os.path.splitext(os.path.basename(f))[0]
             for f in one_dir_pickle_files]

  skipped = 0
  wrote = 0
  for t in targets:
    pickle_files = [os.path.join(pickle_dir, t + '.pickle')
                    for pickle_dir in pickle_dirs]
    _, new_dict = ensemble_one_distance_histogram(pickle_files, weights)
    if new_dict is None:
      # The helper has already logged why this target could not be ensembled;
      # just keep the count so the progress messages are truthful.
      skipped += 1
      continue
    dump_file = os.path.join(output_dir, t + '.pickle')
    distogram_io.save_distance_histogram_from_dict(dump_file, new_dict)
    wrote += 1
    logging.info('Distograms Wrote %s %d / %d Skipped %d %s',
                 t, wrote, num_targets, skipped, dump_file)

  # Final tally, so skips that happen after the last write are reported too.
  logging.info('Distograms done: wrote %d, skipped %d of %d',
               wrote, skipped, num_targets)
def ensemble_one_distance_histogram(pickle_files, weights):
  """Form the weighted average of the 'probs' of the given pickle_files.

  Nothing is written to disk here; the caller decides what to do with the
  returned dict.

  Args:
    pickle_files: Paths of the distance histogram pickles to combine.
    weights: One numeric weight per entry of `pickle_files`, in the same order.
      The weighted sum is divided by `sum(weights)`.

  Returns:
    A `(sequence, new_dict)` tuple. `sequence` is the 'sequence' entry of the
    loaded files (None if nothing was loaded). `new_dict` is a shallow copy of
    the first loaded dict with 'probs' replaced by the weighted average, or
    None if `pickle_files` is empty or one of the files does not exist.

  Raises:
    AssertionError: If the files disagree on 'sequence', or if a 'probs' array
      is not square in its first two dimensions or differs in those two
      dimensions from the first file's array.
  """
  if not pickle_files:
    # Nothing to average; report it the same way as a missing file instead of
    # failing later on an empty list.
    logging.warning('no pickle files to ensemble')
    return None, None

  dicts = []
  sequence = None
  for picklefile in pickle_files:
    if not tf.io.gfile.exists(picklefile):
      logging.warning('missing %s', picklefile)
      # One missing member invalidates the whole ensemble for this target.
      break
    logging.info('loading pickle file %s', picklefile)
    distance_histogram_dict = parsers.parse_distance_histogram_dict(picklefile)
    if sequence is None:
      sequence = distance_histogram_dict['sequence']
    else:
      assert sequence == distance_histogram_dict['sequence'], '%s vs %s' % (
          sequence, distance_histogram_dict['sequence'])
    dicts.append(distance_histogram_dict)

    first_shape = dicts[0]['probs'].shape
    latest_shape = dicts[-1]['probs'].shape
    assert latest_shape[0] == latest_shape[1], (
        '%d vs %d' % (latest_shape[0], latest_shape[1]))
    # Shapes are tuples, so they must be formatted with %s; %d would raise a
    # TypeError while building the assertion message.
    # NOTE(review): only the first two dimensions are compared; any trailing
    # dimension is assumed to match across files -- confirm.
    assert first_shape[0:2] == latest_shape[0:2], (
        '%s vs %s' % (first_shape, latest_shape))

  if len(dicts) != len(pickle_files):
    logging.warning('length mismatch\n%s\nVS\n%s', dicts, pickle_files)
    return sequence, None

  ensemble_hist = (
      sum(w * c['probs'] for w, c in zip(weights, dicts)) / sum(weights))
  # Shallow copy: every entry other than 'probs' is taken from the first file.
  new_dict = dict(dicts[0])
  new_dict['probs'] = ensemble_hist
  return sequence, new_dict
def main(argv):
  """Read the ensembling flags and run the ensemble over all targets.

  Args:
    argv: Positional command-line arguments passed in by absl; unused.

  Raises:
    app.UsageError: If --weights is given but does not contain exactly one
      numeric entry per --pickle_dirs entry.
  """
  del argv  # Unused.
  num_dirs = len(FLAGS.pickle_dirs)
  if FLAGS.weights:
    # Raise instead of assert: asserts are stripped under `python -O`, and a
    # mismatched weight list would then be silently truncated by zip() when
    # the histograms are averaged.
    if len(FLAGS.weights) != num_dirs:
      raise app.UsageError(
          'Supply as many weights as pickle_dirs, or no weights')
    try:
      weights = [float(w) for w in FLAGS.weights]
    except ValueError as e:
      raise app.UsageError(
          'Could not parse --weights as numbers: %s' % e) from e
  else:
    # No weights supplied: every directory counts equally.
    weights = [1.0] * num_dirs

  ensemble_distance_histograms(
      pickle_dirs=FLAGS.pickle_dirs,
      weights=weights,
      output_dir=FLAGS.output_dir)


# Script entry point: absl parses the flags and then calls main.
if __name__ == '__main__':
  app.run(main)