mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-10 05:17:46 +08:00
a5efafff3a
PiperOrigin-RevId: 298822315
126 lines
4.6 KiB
Python
126 lines
4.6 KiB
Python
# Lint as: python3.
|
|
# Copyright 2019 DeepMind Technologies Limited
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Form a weighted average of several distograms.
|
|
|
|
Can also/instead form a weighted average of a set of distance histogram pickle
|
|
files, so long as they have identical hyperparameters.
|
|
"""
|
|
|
|
import os
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
from absl import logging
|
|
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
|
|
|
|
from alphafold_casp13 import distogram_io
|
|
from alphafold_casp13 import parsers
|
|
|
|
flags.DEFINE_list(
|
|
'pickle_dirs', [],
|
|
'Comma separated list of directories with pickle files to ensemble.')
|
|
flags.DEFINE_list(
|
|
'weights', [],
|
|
'Comma separated list of weights for the pickle files from different dirs.')
|
|
flags.DEFINE_string(
|
|
'output_dir', None, 'Directory where to save results of the evaluation.')
|
|
FLAGS = flags.FLAGS
|
|
|
|
|
|
def ensemble_distance_histograms(pickle_dirs, weights, output_dir):
|
|
"""Find all the contact maps in the first dir, then ensemble across dirs."""
|
|
if len(pickle_dirs) <= 1:
|
|
logging.warning('Pointless to ensemble %d pickle_dirs %s',
|
|
len(pickle_dirs), pickle_dirs)
|
|
# Carry on if there's one dir, otherwise do nothing.
|
|
if not pickle_dirs:
|
|
return
|
|
|
|
tf.io.gfile.makedirs(output_dir)
|
|
one_dir_pickle_files = tf.io.gfile.glob(
|
|
os.path.join(pickle_dirs[0], '*.pickle'))
|
|
assert one_dir_pickle_files, pickle_dirs[0]
|
|
original_files = len(one_dir_pickle_files)
|
|
logging.info('Found %d files %d in first of %d dirs',
|
|
original_files, len(one_dir_pickle_files), len(pickle_dirs))
|
|
targets = [os.path.splitext(os.path.basename(f))[0]
|
|
for f in one_dir_pickle_files]
|
|
skipped = 0
|
|
wrote = 0
|
|
for t in targets:
|
|
dump_file = os.path.join(output_dir, t + '.pickle')
|
|
pickle_files = [os.path.join(pickle_dir, t + '.pickle')
|
|
for pickle_dir in pickle_dirs]
|
|
_, new_dict = ensemble_one_distance_histogram(pickle_files, weights)
|
|
if new_dict is not None:
|
|
wrote += 1
|
|
distogram_io.save_distance_histogram_from_dict(dump_file, new_dict)
|
|
msg = 'Distograms Wrote %s %d / %d Skipped %d %s' % (
|
|
t, wrote, len(one_dir_pickle_files), skipped, dump_file)
|
|
logging.info(msg)
|
|
|
|
|
|
def ensemble_one_distance_histogram(pickle_files, weights):
|
|
"""Average the given pickle_files and dump."""
|
|
dicts = []
|
|
sequence = None
|
|
max_dim = None
|
|
for picklefile in pickle_files:
|
|
if not tf.io.gfile.exists(picklefile):
|
|
logging.warning('missing %s', picklefile)
|
|
break
|
|
logging.info('loading pickle file %s', picklefile)
|
|
distance_histogram_dict = parsers.parse_distance_histogram_dict(picklefile)
|
|
if sequence is None:
|
|
sequence = distance_histogram_dict['sequence']
|
|
else:
|
|
assert sequence == distance_histogram_dict['sequence'], '%s vs %s' % (
|
|
sequence, distance_histogram_dict['sequence'])
|
|
dicts.append(distance_histogram_dict)
|
|
assert dicts[-1]['probs'].shape[0] == dicts[-1]['probs'].shape[1], (
|
|
'%d vs %d' % (dicts[-1]['probs'].shape[0], dicts[-1]['probs'].shape[1]))
|
|
assert (dicts[0]['probs'].shape[0:2] == dicts[-1]['probs'].shape[0:2]
|
|
), ('%d vs %d' % (dicts[0]['probs'].shape, dicts[-1]['probs'].shape))
|
|
if max_dim is None or max_dim < dicts[-1]['probs'].shape[2]:
|
|
max_dim = dicts[-1]['probs'].shape[2]
|
|
if len(dicts) != len(pickle_files):
|
|
logging.warning('length mismatch\n%s\nVS\n%s', dicts, pickle_files)
|
|
return sequence, None
|
|
ensemble_hist = (
|
|
sum(w * c['probs'] for w, c in zip(weights, dicts)) / sum(weights))
|
|
new_dict = dict(dicts[0])
|
|
new_dict['probs'] = ensemble_hist
|
|
return sequence, new_dict
|
|
|
|
|
|
def main(argv):
|
|
del argv # Unused.
|
|
num_dirs = len(FLAGS.pickle_dirs)
|
|
if FLAGS.weights:
|
|
assert len(FLAGS.weights) == num_dirs, (
|
|
'Supply as many weights as pickle_dirs, or no weights')
|
|
weights = [float(w) for w in FLAGS.weights]
|
|
else:
|
|
weights = [1.0 for w in range(num_dirs)]
|
|
|
|
ensemble_distance_histograms(
|
|
pickle_dirs=FLAGS.pickle_dirs,
|
|
weights=weights,
|
|
output_dir=FLAGS.output_dir)
|
|
|
|
if __name__ == '__main__':
|
|
app.run(main)
|