Initial release of A Deep Learning Approach for Characterizing Major Galaxy Mergers

PiperOrigin-RevId: 369646863
This commit is contained in:
Louise Deason
2021-04-21 14:00:28 +00:00
parent fe4a129143
commit 3dc0baece1
12 changed files with 1538 additions and 0 deletions

View File

@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
## Projects
* [A Deep Learning Approach for Characterizing Major Galaxy Mergers](galaxy_mergers)
* [Better, Faster Fermionic Neural Networks](kfac_ferminet_alpha) (KFAC implementation)
* [Object-based attention for spatio-temporal reasoning](object_attention_for_reasoning)
* [Effective gene expression prediction from sequence by integrating long-range interactions](enformer)

41
galaxy_mergers/README.md Normal file
View File

@@ -0,0 +1,41 @@
# A Deep Learning Approach for Characterizing Major Galaxy Mergers
This repository contains evaluation code and checkpoints to reproduce
figures in https://arxiv.org/abs/2102.05182.
The main evaluation module is `main.py`. It uses the provided checkpoint path
and dataset path to run evaluation.
## Setup
To set up a Python virtual environment with the required dependencies, run:
```shell
python3 -m venv galaxy_mergers_env
source galaxy_mergers_env/bin/activate
pip install --upgrade pip setuptools wheel
pip install -r requirements.txt
```
### License
While the code is licensed under the Apache 2.0 License, the checkpoint weights
are made available for non-commercial use only under the terms of the
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
license. You can find details at:
https://creativecommons.org/licenses/by-nc/4.0/legalcode.
### Citing our work
If you use this work, consider citing our paper:
```bibtex
@article{koppula2021deep,
title={A Deep Learning Approach for Characterizing Major Galaxy Mergers},
author={Koppula, Skanda and Bapst, Victor and Huertas-Company, Marc and Blackwell, Sam and Grabska-Barwinska, Agnieszka and Dieleman, Sander and Huber, Andrea and Antropova, Natasha and Binkowski, Mikolaj and Openshaw, Hannah and others},
journal={Workshop for Machine Learning and the Physical Sciences @ NeurIPS 2020},
year={2021}
}
```

View File

@@ -0,0 +1,87 @@
# Copyright 2021 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers to pre-process Antennae galaxy images."""
import collections
import os
from astropy.io import fits
import numpy as np
from scipy import ndimage
import tensorflow.compat.v2 as tf
def norm_antennae_images(images, scale=1000):
  """Compress the dynamic range of the images with an inverse-sinh stretch."""
  scaled = images / scale
  return tf.math.asinh(scaled)
def renorm_antennae(images):
  """Center the images on their median and scale by half the value range."""
  flat_values = images.numpy().flatten()
  median = np.percentile(flat_values, 50)
  half_range = np.ptp(flat_values) / 2
  return (images - median) / half_range
def get_antennae_images(antennae_fits_dir):
  """Load the raw Antennae galaxy images.

  Groups the *.fits files in antennae_fits_dir by redshift, orders each
  group's files by frequency band, and stacks the bands as channels.

  Args:
    antennae_fits_dir: directory containing the raw *.fits files. Assumes
      filenames encode the redshift in characters [-8:-5] and start with a
      'red'/'blue' band prefix -- TODO confirm the naming convention.

  Returns:
    Tensor of shape [num_redshifts, 60, 60, num_bands].
  """
  all_fits_files = [
      os.path.join(antennae_fits_dir, f)
      for f in os.listdir(antennae_fits_dir)
  ]
  # Band prefix -> wavelength used only for a deterministic channel order.
  freq_mapping = {'red': 160, 'blue': 850}
  paired_fits_files = collections.defaultdict(list)
  for f in all_fits_files:
    # e.g. '..._z0.5.fits' -> 0.5; presumably 3 chars of redshift -- verify.
    redshift = float(f[-8:-5])
    paired_fits_files[redshift].append(f)
  for redshift, files in paired_fits_files.items():
    # Sort each redshift's views so channels are in a consistent band order.
    paired_fits_files[redshift] = sorted(
        files, key=lambda f: freq_mapping[f.split('/')[-1].split('_')[0]])
  print('Reading files:', paired_fits_files)
  print('Redshifts:', sorted(paired_fits_files.keys()))
  galaxy_views = collections.defaultdict(list)
  for redshift in paired_fits_files:
    for view_path in paired_fits_files[redshift]:
      with open(view_path, 'rb') as f:
        fits_data = fits.open(f)
        # Primary HDU holds the image plane.
        galaxy_views[redshift].append(np.array(fits_data[0].data))
  batched_images = []
  for redshift in paired_fits_files:
    # [bands, H, W] -> [H, W, bands], then resize to the model input size.
    img = tf.constant(np.array(galaxy_views[redshift]))
    img = tf.transpose(img, (1, 2, 0))
    img = tf.image.resize(img, size=(60, 60))
    batched_images.append(img)
  return tf.stack(batched_images)
def preprocess_antennae_images(antennae_images):
  """Pre-process the Antennae galaxy images into a reasonable range.

  Pipeline: small rotation (crop away the rotation border), clip negatives,
  inverse-sinh normalize, clip to [1, 4.5], then re-center/re-scale.
  The clip bounds are empirically chosen constants -- TODO confirm.
  """
  rotated_antennae_images = [
      # cval=-1 marks out-of-frame pixels; the [10:-10] crop removes most of
      # the rotation border.
      ndimage.rotate(img, 10, reshape=True, cval=-1)[10:-10, 10:-10]
      for img in antennae_images
  ]
  rotated_antennae_images = [
      # Remove the -1 fill values (and any negative flux).
      np.clip(img, 0, 1e9) for img in rotated_antennae_images
  ]
  rotated_antennae_images = tf.stack(rotated_antennae_images)
  normed_antennae_images = norm_antennae_images(rotated_antennae_images)
  normed_antennae_images = tf.clip_by_value(normed_antennae_images, 1, 4.5)
  renormed_antennae_images = renorm_antennae(normed_antennae_images)
  return renormed_antennae_images

98
galaxy_mergers/config.py Normal file
View File

@@ -0,0 +1,98 @@
# Copyright 2021 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Default config, focused on model evaluation."""
from ml_collections import config_dict
def get_config(filter_time_intervals=None):
  """Return config object for training.

  Args:
    filter_time_intervals: optional list of [lo, hi] merger-time intervals
      used to filter examples; None keeps the default full range [-1, 1].

  Returns:
    ml_collections.ConfigDict with the eval strategy and experiment kwargs.
  """
  config = config_dict.ConfigDict()
  # Evaluation runs on a single device.
  config.eval_strategy = config_dict.ConfigDict()
  config.eval_strategy.class_name = 'OneDeviceConfig'
  config.eval_strategy.kwargs = config_dict.ConfigDict(
      dict(device_type='v100'))
  ## Experiment config.
  config.experiment_kwargs = config_dict.ConfigDict(dict(
      resnet_kwargs=dict(
          blocks_per_group_list=[3, 4, 6, 3],  # This choice is ResNet50.
          bn_config=dict(
              decay_rate=0.9,
              eps=1e-5),
          resnet_v2=False,
          additional_features_mode='mlp',
      ),
      optimizer_config=dict(
          class_name='Momentum',
          kwargs={'momentum': 0.9},
          # Set up the learning rate schedule.
          lr_init=0.025,
          lr_factor=0.1,
          lr_schedule=(50e3, 100e3, 150e3),
          gradient_clip=5.,
      ),
      l2_regularization=1e-4,
      total_train_batch_size=128,
      train_net_args={'is_training': True},
      eval_batch_size=128,
      # NOTE(review): is_training=True at eval time keeps batch-norm in
      # training mode -- confirm this is intentional.
      eval_net_args={'is_training': True},
      data_config=dict(
          # dataset loading
          dataset_path=None,
          num_val_splits=10,
          val_split=0,
          # image cropping
          image_size=(80, 80, 7),
          train_crop_type='crop_fixed',
          test_crop_type='crop_fixed',
          n_crop_repeat=1,
          train_augmentations=dict(
              rotation_and_flip=True,
              rescaling=True,
              translation=True,
          ),
          test_augmentations=dict(
              rotation_and_flip=False,
              rescaling=False,
              translation=False,
          ),
          test_time_ensembling='sum',
          num_eval_buckets=5,
          eval_confidence_interval=95,
          task='grounded_unnormalized_regression',
          loss_config=dict(
              loss='mse',
              mse_normalize=False,
          ),
          model_uncertainty=True,
          additional_features='',
          time_filter_intervals=filter_time_intervals,
          # Class intervals over normalized merger time: '0' = pre-merger,
          # '1' = post-merger -- presumably; verify against losses.py usage.
          class_boundaries={
              '0': [[-1., 0]],
              '1': [[0, 1.]]
          },
          frequencies_to_use='all',
      ),
      n_train_epochs=100
  ))
  return config

262
galaxy_mergers/evaluator.py Normal file
View File

@@ -0,0 +1,262 @@
# Copyright 2021 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation runner."""
import collections
from absl import logging
import tensorflow.compat.v2 as tf
from galaxy_mergers import config as tp_config
from galaxy_mergers import helpers
from galaxy_mergers import losses
from galaxy_mergers import model
from galaxy_mergers import preprocessing
class GalaxyMergeClassifierEvaluator():
  """Galaxy Merge Rate Prediction Evaluation Runner.

  Wraps a ResNet model with a distributed evaluation dataset and test-time
  augmentation/ensembling. Accepts the full experiment_kwargs config (train
  and eval) so the same config dict can be splatted in, but uses only the
  evaluation-related fields.
  """

  def __init__(self, strategy, optimizer_config, total_train_batch_size,
               train_net_args, eval_batch_size, eval_net_args,
               l2_regularization, data_config, resnet_kwargs, n_train_epochs):
    """Initializes evaluator/experiment.

    Args:
      strategy: tf.distribute strategy used to build the eval dataset.
      optimizer_config: unused (training-only).
      total_train_batch_size: unused (training-only).
      train_net_args: unused (training-only).
      eval_batch_size: per-evaluation batch size.
      eval_net_args: kwargs forwarded to the model call (e.g. is_training).
      l2_regularization: unused (training-only).
      data_config: dict of dataset/task options (see config.py).
      resnet_kwargs: kwargs forwarded to model.ResNet.
      n_train_epochs: unused (training-only).
    """
    logging.info('Initializing evaluator...')
    self._strategy = strategy
    self._data_config = data_config
    self._use_additional_features = bool(data_config['additional_features'])
    self._eval_batch_size = eval_batch_size
    self._eval_net_args = eval_net_args
    self._num_buckets = data_config['num_eval_buckets']
    self._n_repeats = data_config['n_crop_repeat']
    self._image_size = data_config['image_size']
    self._task_type = data_config['task']
    self._loss_config = data_config['loss_config']
    self._model_uncertainty = data_config['model_uncertainty']
    # Training-only arguments are accepted so the shared config can be
    # splatted unchanged, then discarded.
    del l2_regularization, optimizer_config, train_net_args
    del total_train_batch_size, n_train_epochs
    logging.info('Creating model...')
    # Regression head: 1 output (mu) or 2 (mu, log sigma^2) with uncertainty;
    # classification head: one logit per class interval.
    num_classes = 2 if self._model_uncertainty else 1
    if self._task_type == losses.TASK_CLASSIFICATION:
      num_classes = len(self._data_config['class_boundaries'])
    self.model = model.ResNet(
        n_repeats=self._data_config['n_crop_repeat'], num_classes=num_classes,
        use_additional_features=self._use_additional_features, **resnet_kwargs)
    self._eval_input = None

  def build_eval_input(self, additional_lambdas=None):
    """Create the galaxy merger evaluation dataset.

    Args:
      additional_lambdas: optional list of functions mapping the parsed
        example dict to extra tensors surfaced in the dataset output.

    Returns:
      A distributed dataset built under the evaluator's strategy.
    """

    def decode_fn(record_bytes):
      # Parse one serialized tf.Example; 'image' is variable-length and is
      # densified below, everything else is fixed-length metadata.
      parsed_example = tf.io.parse_single_example(
          record_bytes,
          {
              'image':
                  tf.io.VarLenFeature(tf.float32),
              'image_shape':
                  tf.io.FixedLenFeature([3], dtype=tf.int64),
              'axis':
                  tf.io.FixedLenFeature([], dtype=tf.int64),
              'proposed_crop':
                  tf.io.FixedLenFeature([2, 2], dtype=tf.int64),
              'normalized_time':
                  tf.io.FixedLenFeature([], dtype=tf.float32),
              'unnormalized_time':
                  tf.io.FixedLenFeature([], dtype=tf.float32),
              'grounded_normalized_time':
                  tf.io.FixedLenFeature([], dtype=tf.float32),
              'redshift':
                  tf.io.FixedLenFeature([], dtype=tf.float32),
              'sequence_average_redshift':
                  tf.io.FixedLenFeature([], dtype=tf.float32),
              'mass':
                  tf.io.FixedLenFeature([], dtype=tf.float32),
              'time_index':
                  tf.io.FixedLenFeature([], dtype=tf.int64),
              'sequence_id':
                  tf.io.FixedLenFeature([], dtype=tf.string),
          })
      parsed_example['image'] = tf.sparse.to_dense(
          parsed_example['image'], default_value=0)
      dataset_row = parsed_example
      return dataset_row

    def build_eval_pipeline(_):
      """Generate the processed input evaluation data."""
      logging.info('Building evaluation input pipeline...')
      ds_path = self._data_config['dataset_path']
      ds = tf.data.TFRecordDataset([ds_path]).map(decode_fn)
      # No stochastic augmentations in the pipeline; test-time augmentation
      # is applied later in run_test_model_ensemble.
      augmentations = dict(
          rotation_and_flip=False,
          rescaling=False,
          translation=False
      )
      ds = preprocessing.prepare_dataset(
          ds=ds, target_size=self._image_size,
          crop_type=self._data_config['test_crop_type'],
          n_repeats=self._n_repeats,
          augmentations=augmentations,
          task_type=self._task_type,
          additional_features=self._data_config['additional_features'],
          class_boundaries=self._data_config['class_boundaries'],
          time_intervals=self._data_config['time_filter_intervals'],
          frequencies_to_use=self._data_config['frequencies_to_use'],
          additional_lambdas=additional_lambdas)
      batched_ds = ds.cache().batch(self._eval_batch_size).prefetch(128)
      logging.info('Finished building input pipeline...')
      return batched_ds

    return self._strategy.experimental_distribute_datasets_from_function(
        build_eval_pipeline)

  def run_test_model_ensemble(self, images, physical_features, augmentations):
    """Run evaluation on input images.

    Applies the configured test-time augmentations, folds the augmented
    copies into the batch dimension, runs the model once, and aggregates the
    per-copy outputs back into one prediction per example.

    Returns:
      Tuple (mu, log_sigma_sq); log_sigma_sq is None without uncertainty.
    """
    image_variations = [images]
    image_shape = images.shape.as_list()
    if augmentations['rotation_and_flip']:
      image_variations = preprocessing.get_all_rotations_and_flips(
          image_variations)
    if augmentations['rescaling']:
      image_variations = preprocessing.get_all_rescalings(
          image_variations, image_shape[1], augmentations['translation'])
    # Put all augmented images into the batch: batch * num_augmented
    augmented_images = tf.stack(image_variations, axis=0)
    augmented_images = tf.reshape(augmented_images, [-1] + image_shape[1:])
    if self._use_additional_features:
      # Replicate the features so they stay aligned with the enlarged batch.
      physical_features = tf.concat(
          [physical_features] * len(image_variations), axis=0)
    n_reps = self._data_config['n_crop_repeat']
    augmented_images = preprocessing.move_repeats_to_batch(augmented_images,
                                                           n_reps)
    logits_or_times = self.model(augmented_images, physical_features,
                                 **self._eval_net_args)
    if self._task_type == losses.TASK_CLASSIFICATION:
      mu, log_sigma_sq = helpers.aggregate_classification_ensemble(
          logits_or_times, len(image_variations),
          self._data_config['test_time_ensembling'])
    else:
      assert self._task_type in losses.REGRESSION_TASKS
      mu, log_sigma_sq = helpers.aggregate_regression_ensemble(
          logits_or_times, len(image_variations),
          self._model_uncertainty,
          self._data_config['test_time_ensembling'])
    return mu, log_sigma_sq

  @property
  def checkpoint_items(self):
    # Only the model weights are checkpointed/restored for evaluation.
    return {'model': self.model}
def run_model_on_dataset(evaluator, dataset, config, n_batches=16):
  """Runs the model against a dataset, aggregates model output.

  Args:
    evaluator: GalaxyMergeClassifierEvaluator with restored weights.
    dataset: dataset from evaluator.build_eval_input.
    config: full experiment config (reads experiment_kwargs.data_config).
    n_batches: max number of batches to process; falsy means the whole
      dataset.

  Returns:
    Tuple of three dicts mapping name -> list of per-batch values:
    (scalar metrics, model output vectors, extra dataset features keyed by
    their positional index).
  """
  scalar_metrics_to_log = collections.defaultdict(list)
  model_outputs_to_log = collections.defaultdict(list)
  dataset_features_to_log = collections.defaultdict(list)
  batch_count = 1
  for all_inputs in dataset:
    # The input tuple layout shifts by one when physical features are used.
    if config.experiment_kwargs.data_config['additional_features']:
      images = all_inputs[0]
      physical_features = all_inputs[1]
      labels, regression_targets, _ = all_inputs[2:5]
      other_dataset_features = all_inputs[5:]
    else:
      images, physical_features = all_inputs[0], None
      labels, regression_targets, _ = all_inputs[1:4]
      other_dataset_features = all_inputs[4:]
    mu, log_sigma_sq = evaluator.run_test_model_ensemble(
        images, physical_features,
        config.experiment_kwargs.data_config['test_augmentations'])
    loss_config = config.experiment_kwargs.data_config['loss_config']
    task_type = config.experiment_kwargs.data_config['task']
    uncertainty = config.experiment_kwargs.data_config['model_uncertainty']
    conf = config.experiment_kwargs.data_config['eval_confidence_interval']
    scalar_metrics, vector_metrics = losses.compute_loss_and_metrics(
        mu, log_sigma_sq, regression_targets, labels,
        task_type, uncertainty, loss_config, 0, conf, mode='eval')
    for i, dataset_feature in enumerate(other_dataset_features):
      dataset_features_to_log[i].append(dataset_feature.numpy())
    for scalar_metric in scalar_metrics:
      # Some metrics are plain Python numbers, others are tensors.
      v = scalar_metrics[scalar_metric]
      val = v if isinstance(v, int) or isinstance(v, float) else v.numpy()
      scalar_metrics_to_log[scalar_metric].append(val)
    for vector_metric in vector_metrics:
      val = vector_metrics[vector_metric].numpy()
      model_outputs_to_log[vector_metric].append(val)
    regression_targets_np = regression_targets.numpy()
    labels_np = labels.numpy()
    model_outputs_to_log['regression_targets'].append(regression_targets_np)
    model_outputs_to_log['labels'].append(labels_np)
    model_outputs_to_log['model_input_images'].append(images.numpy())
    # batch_count starts at 1 and the check precedes the increment, so
    # exactly n_batches batches are processed (fewer if the dataset ends).
    if n_batches and batch_count >= n_batches:
      break
    batch_count += 1
  return scalar_metrics_to_log, model_outputs_to_log, dataset_features_to_log
def get_config_dataset_evaluator(filter_time_intervals,
                                 ckpt_path,
                                 config_override=None,
                                 setup_dataset=True):
  """Set-up a default config, evaluation dataset, and evaluator.

  Args:
    filter_time_intervals: merger-time intervals forwarded to the config.
    ckpt_path: directory containing the TF2 checkpoint to restore.
    config_override: optional flat {dotted.key: value} dict applied on top
      of the default config.
    setup_dataset: if False, skip building the dataset and return ds=None.

  Returns:
    Tuple (config, dataset or None, evaluator with restored weights).
  """
  config = tp_config.get_config(filter_time_intervals=filter_time_intervals)
  if config_override:
    with config.ignore_type():
      config.update_from_flattened_dict(config_override)
  strategy = tf.distribute.OneDeviceStrategy(device='/gpu:0')
  experiment = GalaxyMergeClassifierEvaluator(
      strategy=strategy, **config.experiment_kwargs)
  helpers.restore_checkpoint(ckpt_path, experiment)
  if setup_dataset:
    # Surface per-example metadata alongside the model inputs; downstream
    # code (e.g. helpers.unpack_evaluator_output) indexes these by position.
    additional_lambdas = [
        lambda ds: ds['sequence_id'],
        lambda ds: ds['time_index'],
        lambda ds: ds['axis'],
        lambda ds: ds['normalized_time'],
        lambda ds: ds['grounded_normalized_time'],
        lambda ds: ds['unnormalized_time'],
        lambda ds: ds['redshift'],
        lambda ds: ds['mass']
    ]
    ds = experiment.build_eval_input(additional_lambdas=additional_lambdas)
  else:
    ds = None
  return config, ds, experiment

228
galaxy_mergers/helpers.py Normal file
View File

@@ -0,0 +1,228 @@
# Copyright 2021 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for a galaxy merger model evaluation."""
import glob
import os
from astropy import cosmology
from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow.compat.v2 as tf
def restore_checkpoint(checkpoint_dir, experiment):
  """Restore the newest checkpoint in checkpoint_dir into the experiment.

  NOTE(review): restore(...) is not followed by assert_consumed or
  expect_partial, so a missing/incompatible checkpoint fails silently --
  confirm callers always pass a valid checkpoint directory.
  """
  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
  # The step counter must exist in the checkpoint object graph to match the
  # training-time checkpoint layout.
  global_step = tf.Variable(
      0, dtype=tf.int32, trainable=False, name='global_step')
  checkpoint = tf.train.Checkpoint(
      _global_step_=global_step, **experiment.checkpoint_items)
  checkpoint.restore(checkpoint_path)
def sum_average_transformed_mu_and_sigma(mu, log_sigma_sq):
  """Computes <mu>, var(mu) + <var> in transformed representation.

  This corresponds to assuming that the output distribution is a sum of
  Gaussian and computing the mean and variance of the resulting (non-Gaussian)
  distribution.

  Args:
    mu: Tensor of shape [B, ...] representing the means of the input
      distributions.
    log_sigma_sq: Tensor of shape [B, ...] representing log(sigma**2) of the
      input distributions. Can be None, in which case the variance is assumed
      to be zero.

  Returns:
    mu: Tensor of shape [...] representing the means of the output
      distributions.
    log_sigma_sq: Tensor of shape [...] representing log(sigma**2) of the
      output distributions.
  """
  av_mu = tf.reduce_mean(mu, axis=0)
  var_mu = tf.math.reduce_std(mu, axis=0)**2
  if log_sigma_sq is None:
    return av_mu, tf.math.log(var_mu)
  # Log-sum-exp style stabilization: subtract the per-position max before
  # exponentiating so tf.math.exp cannot overflow; added back at the end.
  max_log_sigma_sq = tf.reduce_max(log_sigma_sq, axis=0)
  log_sigma_sq -= max_log_sigma_sq
  # (sigma/sigma_0)**2
  sigma_sq = tf.math.exp(log_sigma_sq)
  # (<sigma**2>)/sigma_0**2 (<1)
  av_sigma_sq = tf.reduce_mean(sigma_sq, axis=0)
  # (<sigma**2> + var(mu))/sigma_0**2
  av_sigma_sq += var_mu * tf.math.exp(-max_log_sigma_sq)
  # log(<sigma**2> + var(mu))
  log_av_sigma_sq = tf.math.log(av_sigma_sq) + max_log_sigma_sq
  return av_mu, log_av_sigma_sq
def aggregate_regression_ensemble(logits_or_times, ensemble_size,
                                  use_uncertainty, test_time_ensembling):
  """Aggregate output of model ensemble.

  Args:
    logits_or_times: model outputs with the ensemble copies folded into the
      batch dimension (shape [ensemble_size * B, ..., C]).
    ensemble_size: number of test-time-augmented copies per example.
    use_uncertainty: whether channel -1 holds log(sigma**2).
    test_time_ensembling: 'sum' to combine the copies as a Gaussian mixture,
      'none' to keep only the first copy.

  Returns:
    Tuple (mu, log_sigma_sq); log_sigma_sq is None when use_uncertainty is
    False.

  Raises:
    ValueError: on an unexpected test_time_ensembling value.
  """
  out_shape = logits_or_times.shape.as_list()[1:]
  # Unfold the augmented copies: [E * B, ...] -> [E, B, ...].
  logits_or_times = tf.reshape(logits_or_times, [ensemble_size, -1] + out_shape)
  # Channel 0 is mu; the last channel is log(sigma**2) when predicted.
  mus = logits_or_times[..., 0]
  log_sigma_sqs = logits_or_times[..., -1] if use_uncertainty else None
  if test_time_ensembling == 'sum':
    mu, log_sigma_sq = sum_average_transformed_mu_and_sigma(mus, log_sigma_sqs)
  elif test_time_ensembling == 'none':
    mu = mus[0]
    log_sigma_sq = log_sigma_sqs[0] if use_uncertainty else None
  else:
    raise ValueError('Unexpected test_time_ensembling')
  return mu, log_sigma_sq
def aggregate_classification_ensemble(logits_or_times, ensemble_size,
                                      test_time_ensembling):
  """Averages the output logits across models in the ensemble.

  Returns a (logits, None) pair so the signature mirrors the regression
  aggregator; classification carries no predicted variance.
  """
  per_copy_shape = logits_or_times.shape.as_list()[1:]
  # Unfold the augmented copies: [E * B, ...] -> [E, B, ...].
  logits = tf.reshape(
      logits_or_times, [ensemble_size, -1] + per_copy_shape)
  if test_time_ensembling == 'none':
    return logits, None
  if test_time_ensembling == 'sum':
    return tf.reduce_mean(logits, axis=0), None
  raise ValueError('Unexpected test_time_ensembling')
def unpack_evaluator_output(data, return_seq_info=False, return_redshift=False):
  """Unpack evaluator.run_model_on_dataset output into flat numpy arrays.

  Args:
    data: (scalar metrics, model outputs, dataset features) tuple as returned
      by evaluator.run_model_on_dataset.
    return_seq_info: also return sequence ids, axes and time indices.
    return_redshift: also return per-example redshifts.

  Returns:
    [mu, sigma, regression_targets] plus optionally
    [sequence_ids, axes, time_indices] and [redshifts].
  """

  def flat(values):
    return np.array(values).flatten()

  model_outputs, dataset_features = data[1], data[2]
  outputs = [
      flat(model_outputs['mu']),
      flat(model_outputs['sigma']),
      flat(model_outputs['regression_targets']),
  ]
  if return_seq_info:
    decoded_ids = np.array(
        [seq_id.decode('UTF-8') for seq_id in flat(dataset_features[0])])
    # Feature index 2 holds the axis, index 1 the time index.
    outputs.extend([decoded_ids, flat(dataset_features[2]),
                    flat(dataset_features[1])])
  if return_redshift:
    outputs.append(flat(dataset_features[6]))
  return outputs
def process_data_into_myrs(redshifts, *data_lists):
  """Converts normalized time to virial time using Planck cosmology.

  Args:
    redshifts: array of redshifts, one per example.
    *data_lists: arrays of times expressed in units of the virial dynamical
      time at the corresponding redshift.

  Returns:
    List of arrays (one per input list) converted to gigayears.
  """
  # small hack to avoid build tools not recognizing non-standard trickery
  # done in the astropy library:
  # https://github.com/astropy/astropy/blob/master/astropy/cosmology/core.py#L3290
  # that dynamically generates and imports new classes.
  # Bug fix: the realization is spelled 'Planck13'; the previous 'Plank13'
  # made this getattr raise AttributeError at runtime.
  planck13 = getattr(cosmology, 'Planck13')
  hubble_constants = planck13.H(redshifts)  # (km/s)/megaparsec
  inv_hubble_constants = 1/hubble_constants  # (megaparsec*s) / km
  megaparsec_to_km = 1e19*3.1
  seconds_to_gigayears = 1e-15/31.556
  conversion_factor = megaparsec_to_km * seconds_to_gigayears
  # Hubble time 1/H(z), expressed in gigayears.
  hubble_time_gigayears = conversion_factor * inv_hubble_constants
  hubble_to_virial_time = 0.14  # approximate simulation-based conversion factor
  virial_dyn_time = hubble_to_virial_time*hubble_time_gigayears.value
  return [data_list*virial_dyn_time for data_list in data_lists]
def print_rmse_and_class_accuracy(mus, regression_targets, redshifts):
  """Convert to virial dynamical time and print stats."""
  time_pred, time_gt = process_data_into_myrs(
      redshifts, mus, regression_targets)
  squared_errors = (time_pred-time_gt)**2
  root_mean_squared_error = np.sqrt(np.mean(squared_errors))
  # Sign of the (normalized) time distinguishes pre- from post-merger.
  true_labels = regression_targets > 0
  predicted_labels = mus > 0
  n_correct = sum((true_labels == predicted_labels).astype(np.int8))
  accuracy = n_correct / len(predicted_labels)
  print(f'95% Error: {np.percentile(np.sqrt(squared_errors), 95)}')
  print(f'RMSE: {root_mean_squared_error}')
  print(f'Classification Accuracy: {accuracy}')
def print_stats(vec, do_print=True):
  """Return (count, min, mean, median, max) of vec, optionally printing it."""
  flat = vec.flatten()
  stats = (len(flat), min(flat), np.mean(flat), np.median(flat), max(flat))
  if do_print:
    print(*stats)
  return stats
def get_image_from_fits(base_dir, seq='475_31271', time='497', axis=2):
  """Read *.fits galaxy image from directory.

  Args:
    base_dir: root directory of the merger sequences.
    seq: sequence identifier subdirectory name.
    time: time-slice subdirectory name.
    axis: projection axis index (0/1/2 -> 'x'/'y'/'z' filename tag).

  Returns:
    numpy float32 array of shape (H, W, 7), channels ordered by increasing
    frequency parsed from the filename.
  """
  axis_map = {0: 'x', 1: 'y', 2: 'z'}
  fits_glob = f'{base_dir}/{seq}/fits_of_flux_psf/{time}/*_{axis_map[axis]}_*.fits'

  def get_freq_from_path(p):
    # Third '_'-separated filename token is presumably e.g. 'f160'; the
    # leading letter is stripped -- TODO confirm the naming convention.
    return int(p.split('/')[-1].split('_')[2][1:])

  fits_image_paths = sorted(glob.glob(fits_glob), key=get_freq_from_path)
  # Expect exactly one file per frequency band.
  assert len(fits_image_paths) == 7
  combined_frequencies = []
  for fit_path in fits_image_paths:
    with open(fit_path, 'rb') as f:
      # Primary HDU holds the image plane.
      fits_data = np.array(fits.open(f)[0].data.astype(np.float32))
      combined_frequencies.append(fits_data)
  # [bands, H, W] -> [H, W, bands].
  fits_image = np.transpose(np.array(combined_frequencies), (1, 2, 0))
  return fits_image
def stack_desired_galaxy_images(base_dir, seq, n_time_slices):
  """Search through galaxy image directory gathering images.

  Picks n_time_slices entries spread evenly across the sequence's time
  directory listing and loads the z-axis projection for each.

  NOTE(review): os.listdir order is arbitrary, so the evenly-spaced indices
  are only meaningful if the names happen to come back sorted -- confirm
  whether a sorted() is intended here.

  Returns:
    Tuple (list of (H, W, 7) images, smallest image height among them).
  """
  fits_sequence_dir = os.path.join(base_dir, seq, 'fits_of_flux_psf')
  all_times_for_seq = os.listdir(fits_sequence_dir)
  # Stride chosen so the first and (approximately) last slices are included.
  hop = (len(all_times_for_seq)-1)//(n_time_slices-1)
  desired_time_idxs = [k*hop for k in range(n_time_slices)]
  all_imgs = []
  for j in desired_time_idxs:
    time = all_times_for_seq[j]
    img = get_image_from_fits(base_dir=base_dir, seq=seq, time=time, axis=2)
    all_imgs.append(img)
  min_img_size = min([img.shape[0] for img in all_imgs])
  return all_imgs, min_img_size
def draw_galaxy_image(image, target_size=None, color_map='viridis'):
  """Render a single-channel galaxy image to an RGB PIL image.

  Args:
    image: 2D numpy array of intensities; assumed to have a positive maximum
      (an all-zero image would divide by zero -- TODO confirm upstream).
    target_size: optional (width, height) to resize the rendered image to.
    color_map: name of the matplotlib colormap to apply.

  Returns:
    PIL.Image in RGB mode.
  """
  # image.max() instead of max(image.flatten()): same value, no Python loop.
  normalized_image = image / image.max()
  color_map = plt.get_cmap(color_map)
  # Drop the colormap's alpha channel.
  colored_image = color_map(normalized_image)[:, :, :3]
  colored_image = (colored_image * 255).astype(np.uint8)
  colored_image = Image.fromarray(colored_image, mode='RGB')
  if target_size:
    # Image.ANTIALIAS was an alias of Image.LANCZOS and was removed in
    # Pillow 10; use the canonical name (available since Pillow 2.7).
    colored_image = colored_image.resize(target_size, Image.LANCZOS)
  return colored_image
def collect_merger_sequence(ds, seq=b'370_11071', n_examples_to_sift=5000):
  """Gather images/targets/redshifts for one merger sequence from a dataset.

  Args:
    ds: dataset yielding tuples as built by evaluator.build_eval_input with
      the standard additional_lambdas; positions 4 and 10 are presumably
      sequence_id and redshift -- TODO confirm against the caller.
    seq: byte-string sequence id to match.
    n_examples_to_sift: maximum number of dataset examples to scan.

  Returns:
    Tuple of squeezed numpy arrays (images, targets, redshifts).
  """
  images, targets, redshifts = [], [], []
  for i, all_inputs in enumerate(ds):
    # Only the first element of each batch is inspected/collected.
    if all_inputs[4][0].numpy() == seq:
      images.append(all_inputs[0][0].numpy())
      targets.append(all_inputs[2][0].numpy())
      redshifts.append(all_inputs[10][0].numpy())
    if i > n_examples_to_sift: break
  return np.squeeze(images), np.squeeze(targets), np.squeeze(redshifts)
def take_samples(sample_idxs, *data_lists):
  """Select the rows at sample_idxs from each array in data_lists."""
  selected = []
  for data in data_lists:
    selected.append(np.take(data, sample_idxs, axis=0))
  return selected

View File

@@ -0,0 +1,68 @@
# Copyright 2021 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers to visualize gradients and other interpretability analysis."""
import numpy as np
import tensorflow.compat.v2 as tf
def rotate_by_right_angle_multiple(image, rot=90):
  """Rotate an (H, W, C) image by a multiple of 90 degrees.

  Args:
    image: numpy array of shape (H, W, C).
    rot: rotation angle in degrees; any (positive or negative) multiple
      of 90.

  Returns:
    The rotated image; the input array is not modified.

  Raises:
    ValueError: if rot is not a multiple of 90.
  """
  if rot % 90 != 0:
    raise ValueError(f"Cannot rotate by non-90 degree angle {rot}")
  # Normalize so that e.g. -90 behaves like 270. The previous membership
  # check rejected the negative angles that compute_grads_for_rotations
  # passes when inverting a rotation, raising ValueError.
  rot = rot % 360
  if rot == 90:
    image = np.transpose(image, (1, 0, 2))
    image = image[::-1]
  elif rot == 180:
    image = image[::-1, ::-1]
  elif rot == 270:
    image = np.transpose(image, (1, 0, 2))
    image = image[:, ::-1]
  return image
def compute_gradient(images, evaluator, is_training=False):
  """Gradient of the predicted time w.r.t. the input image.

  Args:
    images: a single image; a leading batch dimension is added before the
      model call.
    evaluator: object whose .model maps (inputs, features, is_training) to an
      output whose column 0 is the predicted time.
    is_training: forwarded to the model call.

  Returns:
    Tuple (gradient of output[:, 0] w.r.t. the input, full model output).
  """
  inputs = tf.Variable(images[None], dtype=tf.float32)
  with tf.GradientTape() as tape:
    tape.watch(inputs)
    time_sigma = evaluator.model(inputs, None, is_training)
  # Differentiate only the time prediction (column 0), not the uncertainty.
  grad_time = tape.gradient(time_sigma[:, 0], inputs)
  return grad_time, time_sigma
def compute_grads_for_rotations(images, evaluator, is_training=False):
  """Input gradients for the four 90-degree rotations of the image.

  Each rotation's gradient is rotated back into the original frame so the
  four gradient maps are spatially aligned.

  NOTE(review): the inverse call passes negative angles (-90, -180, -270) to
  rotate_by_right_angle_multiple; make sure that function accepts them.
  """
  test_gradients, test_outputs = [], []
  for rotation in np.arange(0, 360, 90):
    images_rot = rotate_by_right_angle_multiple(images, rotation)
    grads, time_sigma = compute_gradient(images_rot, evaluator, is_training)
    grads = np.squeeze(grads.numpy())
    # Undo the rotation so the gradient aligns with the unrotated image.
    inv_grads = rotate_by_right_angle_multiple(grads, -rotation)
    test_gradients.append(inv_grads)
    test_outputs.append(time_sigma.numpy())
  return np.squeeze(test_gradients), np.squeeze(test_outputs)
def compute_grads_for_rotations_and_flips(images, evaluator):
  """Input gradients for all 8 symmetries (4 rotations x optional flip)."""
  grads, time_sigma = compute_grads_for_rotations(images, evaluator)
  # Repeat with the vertically flipped image, then un-flip the gradients;
  # grads_f is stacked over rotations, so axis 1 is the image height.
  grads_f, time_sigma_f = compute_grads_for_rotations(images[::-1], evaluator)
  grads_f = grads_f[:, ::-1]
  all_grads = np.concatenate([grads, grads_f], 0)
  model_outputs = np.concatenate((time_sigma, time_sigma_f), 0)
  return all_grads, model_outputs

169
galaxy_mergers/losses.py Normal file
View File

@@ -0,0 +1,169 @@
# Copyright 2021 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers to compute loss metrics."""
import scipy.stats
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
# Supported prediction tasks.
TASK_CLASSIFICATION = 'classification'
TASK_NORMALIZED_REGRESSION = 'normalized_regression'
TASK_UNNORMALIZED_REGRESSION = 'unnormalized_regression'
TASK_GROUNDED_UNNORMALIZED_REGRESSION = 'grounded_unnormalized_regression'
REGRESSION_TASKS = [TASK_NORMALIZED_REGRESSION, TASK_UNNORMALIZED_REGRESSION,
                    TASK_GROUNDED_UNNORMALIZED_REGRESSION]
ALL_TASKS = [TASK_CLASSIFICATION] + REGRESSION_TASKS
# Supported loss functions.
LOSS_MSE = 'mse'
LOSS_SOFTMAX_CROSS_ENTROPY = 'softmax_cross_entropy'
ALL_LOSSES = [LOSS_SOFTMAX_CROSS_ENTROPY, LOSS_MSE]
def normalize_regression_loss(regression_loss, predictions):
  """Rescale an MSE loss so its expectation is prediction-independent.

  Normalize loss such that:
  1) E_{x uniform}[loss(x, prediction)] does not depend on prediction
  2) E_{x uniform, prediction uniform}[loss(x, prediction)] is as before.
  Divides by E_x[(prediction - x)^2] for x uniform on [-1, 1], which equals
  1/3 + prediction**2, then multiplies by 2/3 to keep the overall scale.
  """
  normalization = 2./3.
  expected_sq_error = 1./3 + predictions**2
  return regression_loss * normalization / expected_sq_error
def equal32(x, y):
  """Elementwise x == y as float32 (1.0 where equal, 0.0 elsewhere)."""
  return tf.cast(tf.math.equal(x, y), tf.float32)
def mse_loss(predicted, targets):
  """Elementwise squared error between predictions and targets."""
  return (predicted - targets) ** 2
def get_std_factor_from_confidence_percent(percent):
  """Return the z multiplier for a symmetric two-sided confidence interval.

  E.g. 95 -> ~1.96 standard deviations.
  """
  fraction = percent / 100.
  # Upper-tail quantile of the two-sided interval: p + (1 - p) / 2.
  upper_tail_quantile = fraction + (1 - fraction) / 2
  return scipy.stats.norm.ppf(upper_tail_quantile)
def get_all_metric_names(task_type, model_uncertainty, loss_config,  # pylint: disable=unused-argument
                         mode='eval', return_dict=True):
  """Get all the scalar fields produced by compute_loss_and_metrics.

  Returns either the ordered list of metric names or (default) a dict
  mapping each name to 0.
  """
  names = ['regularization_loss', 'prediction_accuracy', str(mode)+'_loss']
  if task_type == TASK_CLASSIFICATION:
    names.append('classification_loss')
  else:
    names.extend(['regression_loss', 'avg_mu', 'var_mu'])
    if model_uncertainty:
      names.extend(['uncertainty_loss', 'scaled_regression_loss',
                    'uncertainty_plus_scaled_regression',
                    'avg_sigma', 'var_sigma',
                    'percent_in_conf_interval', 'error_sigma_correlation',
                    'avg_prob'])
  return dict.fromkeys(names, 0.) if return_dict else names
def compute_loss_and_metrics(mu, log_sigma_sq,
                             regression_targets, labels,
                             task_type, model_uncertainty, loss_config,
                             regularization_loss=0., confidence_interval=95,
                             mode='train'):
  """Computes loss statistics and other metrics.

  Args:
    mu: model means (regression) or logits (classification).
    log_sigma_sq: per-example log(sigma**2); used when model_uncertainty.
    regression_targets: regression ground truth.
    labels: integer class labels (also used for sign accuracy in regression).
    task_type: one of ALL_TASKS.
    model_uncertainty: whether the model predicts its own variance.
    loss_config: dict of loss options; 'mse_normalize' enables the
      prediction-independent normalization of the MSE loss.
    regularization_loss: scalar added to the total loss.
    confidence_interval: percent for the calibration metric.
    mode: 'train' or 'eval'; names the total-loss scalar.

  Returns:
    Tuple (scalars_to_log, vectors_to_log); the scalar keys exactly match
    get_all_metric_names(...).
  """
  scalars_to_log = dict()
  vectors_to_log = dict()
  scalars_to_log['regularization_loss'] = regularization_loss
  vectors_to_log['mu'] = mu
  if task_type == TASK_CLASSIFICATION:
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=mu, labels=labels, name='cross_entropy')
    classification_loss = tf.reduce_mean(cross_entropy, name='class_loss')
    total_loss = classification_loss
    sigma = None
    scalars_to_log['classification_loss'] = classification_loss
    predicted_labels = tf.argmax(mu, axis=1)
    correct_predictions = equal32(predicted_labels, labels)
  else:
    regression_loss = mse_loss(mu, regression_targets)
    if 'mse_normalize' in loss_config and loss_config['mse_normalize']:
      # Normalization assumes targets live in [-1, 1].
      assert task_type in [TASK_GROUNDED_UNNORMALIZED_REGRESSION,
                           TASK_NORMALIZED_REGRESSION]
      regression_loss = normalize_regression_loss(regression_loss, mu)
    avg_regression_loss = tf.reduce_mean(regression_loss)
    vectors_to_log['regression_loss'] = regression_loss
    scalars_to_log['regression_loss'] = avg_regression_loss
    scalars_to_log['avg_mu'] = tf.reduce_mean(mu)
    scalars_to_log['var_mu'] = tf.reduce_mean(mse_loss(mu, tf.reduce_mean(mu)))
    # Sign of the predicted time distinguishes pre/post merger.
    predicted_labels = tf.cast(mu > 0, tf.int64)
    correct_predictions = equal32(predicted_labels, labels)
    if model_uncertainty:
      # This implements Eq. (1) in https://arxiv.org/pdf/1612.01474.pdf
      inv_sigma_sq = tf.math.exp(-log_sigma_sq)
      scaled_regression_loss = regression_loss * inv_sigma_sq
      scaled_regression_loss = tf.reduce_mean(scaled_regression_loss)
      uncertainty_loss = tf.reduce_mean(log_sigma_sq)
      total_loss = uncertainty_loss + scaled_regression_loss
      scalars_to_log['uncertainty_loss'] = uncertainty_loss
      scalars_to_log['scaled_regression_loss'] = scaled_regression_loss
      scalars_to_log['uncertainty_plus_scaled_regression'] = total_loss
      sigma = tf.math.exp(log_sigma_sq / 2.)
      vectors_to_log['sigma'] = sigma
      scalars_to_log['avg_sigma'] = tf.reduce_mean(sigma)
      var_sigma = tf.reduce_mean(mse_loss(sigma, tf.reduce_mean(sigma)))
      scalars_to_log['var_sigma'] = var_sigma
      # Compute # of labels that fall into the confidence interval.
      std_factor = get_std_factor_from_confidence_percent(confidence_interval)
      lower_bound = mu - std_factor * sigma
      upper_bound = mu + std_factor * sigma
      preds = tf.logical_and(tf.greater(regression_targets, lower_bound),
                             tf.less(regression_targets, upper_bound))
      percent_in_conf_interval = tf.reduce_mean(tf.cast(preds, tf.float32))
      scalars_to_log['percent_in_conf_interval'] = percent_in_conf_interval*100
      error_sigma_corr = tfp.stats.correlation(x=regression_loss,
                                               y=sigma, event_axis=None)
      scalars_to_log['error_sigma_correlation'] = error_sigma_corr
      # Likelihood of the targets under the predicted Gaussians.
      dists = tfp.distributions.Normal(mu, sigma)
      probs = dists.prob(regression_targets)
      scalars_to_log['avg_prob'] = tf.reduce_mean(probs)
    else:
      total_loss = avg_regression_loss
  loss_name = str(mode)+'_loss'
  total_loss = tf.add(total_loss, regularization_loss, name=loss_name)
  scalars_to_log[loss_name] = total_loss
  vectors_to_log['correct_predictions'] = correct_predictions
  scalars_to_log['prediction_accuracy'] = tf.reduce_mean(correct_predictions)
  # Validate that metrics outputted are exactly what is expected
  expected = get_all_metric_names(task_type, model_uncertainty,
                                  loss_config, mode, False)
  assert set(expected) == set(scalars_to_log.keys())
  return scalars_to_log, vectors_to_log

51
galaxy_mergers/main.py Normal file
View File

@@ -0,0 +1,51 @@
# Copyright 2021 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple script to model evaluation on a checkpoint and dataset."""
import ast
from absl import app
from absl import flags
from absl import logging
from galaxy_mergers import evaluator
# Command-line flags for the evaluation script.
flags.DEFINE_string('checkpoint_path', '', 'Path to TF2 checkpoint to eval.')
flags.DEFINE_string('data_path', '', 'Path to TFRecord(s) with data.')
# NOTE: the help text below is built via implicit string concatenation; a
# trailing space after "regression." keeps the sentences from running
# together in --help output.
flags.DEFINE_string('filter_time_intervals', None,
                    'Merger time intervals on which to perform regression. '
                    'Specify None for the default time interval [-1,1], or'
                    ' a custom list of intervals, e.g. [[-0.2,0], [0.5,1]].')
FLAGS = flags.FLAGS
def main(_) -> None:
  """Runs checkpoint evaluation with the flag-provided dataset and intervals."""
  intervals = None
  if FLAGS.filter_time_intervals is not None:
    # The flag is a Python-literal string, e.g. "[[-0.2,0], [0.5,1]]".
    intervals = ast.literal_eval(FLAGS.filter_time_intervals)

  # Point the experiment's data config at the dataset supplied on the CLI.
  overrides = {
      'experiment_kwargs.data_config.dataset_path': FLAGS.data_path,
  }
  config, dataset, experiment = evaluator.get_config_dataset_evaluator(
      intervals, FLAGS.checkpoint_path, config_override=overrides)

  metrics, _, _ = evaluator.run_model_on_dataset(experiment, dataset, config)
  logging.info('Evaluation complete. Metrics: %s', metrics)
if __name__ == '__main__':
app.run(main)

201
galaxy_mergers/model.py Normal file
View File

@@ -0,0 +1,201 @@
# Copyright 2021 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fork of a generic ResNet to incorporate additional cosmological features."""
from typing import Mapping, Optional, Sequence, Text
import sonnet.v2 as snt
import tensorflow.compat.v2 as tf
class ResNet(snt.Module):
  """ResNet model, optionally conditioned on extra per-example feature vectors."""

  def __init__(self,
               n_repeats: int,
               blocks_per_group_list: Sequence[int],
               num_classes: int,
               bn_config: Optional[Mapping[Text, float]] = None,
               resnet_v2: bool = False,
               channels_per_group_list: Sequence[int] = (256, 512, 1024, 2048),
               use_additional_features: bool = False,
               additional_features_mode: Optional[Text] = "per_block",
               name: Optional[Text] = None):
    """Constructs a ResNet model.

    Args:
      n_repeats: The batch dimension for the input is expected to have the form
        `B = b * n_repeats`. After the conv stack, the logits for the
        `n_repeats` replicas are reduced, leading to an output batch dimension
        of `b`.
      blocks_per_group_list: A sequence of length 4 that indicates the number of
        blocks created in each group.
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers. By default the `decay_rate` is
        `0.9` and `eps` is `1e-5`.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
        False.
      channels_per_group_list: A sequence of length 4 that indicates the number
        of channels used for each block in each group.
      use_additional_features: If true, additional vector features will be
        concatenated to the residual stack before logits are computed.
      additional_features_mode: Mode for processing additional features.
        Supported modes: 'mlp' and 'per_block'.
      name: Name of the module.

    Raises:
      ValueError: If `blocks_per_group_list` or `channels_per_group_list` is
        not of length 4, or if `use_additional_features` is True and
        `additional_features_mode` is not one of 'mlp'/'per_block'.
    """
    super(ResNet, self).__init__(name=name)
    self._n_repeats = n_repeats
    if bn_config is None:
      bn_config = {"decay_rate": 0.9, "eps": 1e-5}
    self._bn_config = bn_config
    self._resnet_v2 = resnet_v2

    # Number of blocks in each group for ResNet.
    if len(blocks_per_group_list) != 4:
      raise ValueError(
          "`blocks_per_group_list` must be of length 4 not {}".format(
              len(blocks_per_group_list)))
    self._blocks_per_group_list = blocks_per_group_list

    # Number of channels in each group for ResNet.
    if len(channels_per_group_list) != 4:
      raise ValueError(
          "`channels_per_group_list` must be of length 4 not {}".format(
              len(channels_per_group_list)))
    self._channels_per_group_list = channels_per_group_list

    self._use_additional_features = use_additional_features
    self._additional_features_mode = additional_features_mode

    self._initial_conv = snt.Conv2D(
        output_channels=64,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding="SAME",
        name="initial_conv")
    if not self._resnet_v2:
      # v1 applies batch norm right after the stem convolution.
      self._initial_batchnorm = snt.BatchNorm(
          create_scale=True,
          create_offset=True,
          name="initial_batchnorm",
          **bn_config)

    self._block_groups = []
    strides = [1, 2, 2, 2]
    for i in range(4):
      self._block_groups.append(
          snt.nets.resnet.BlockGroup(
              channels=self._channels_per_group_list[i],
              num_blocks=self._blocks_per_group_list[i],
              stride=strides[i],
              bn_config=bn_config,
              resnet_v2=resnet_v2,
              name="block_group_%d" % (i)))

    if self._resnet_v2:
      # v2 defers normalization until after the residual stack.
      self._final_batchnorm = snt.BatchNorm(
          create_scale=True,
          create_offset=True,
          name="final_batchnorm",
          **bn_config)

    self._logits = snt.Linear(
        output_size=num_classes,
        w_init=snt.initializers.VarianceScaling(scale=2.0), name="logits")

    if self._use_additional_features:
      # Shared embedding of the raw feature vector, re-projected either once
      # ('mlp') or per residual block group ('per_block').
      self._embedding = LinearBNReLU(output_size=16, name="embedding",
                                     **bn_config)
      if self._additional_features_mode == "mlp":
        self._feature_repr = LinearBNReLU(
            output_size=self._channels_per_group_list[-1], name="features_repr",
            **bn_config)
      elif self._additional_features_mode == "per_block":
        self._feature_repr = []
        for i, ch in enumerate(self._channels_per_group_list):
          self._feature_repr.append(
              LinearBNReLU(output_size=ch, name=f"features_{i}", **bn_config))
      else:
        raise ValueError(f"Unsupported additional features mode: "
                         f"{additional_features_mode}")

  def __call__(self, inputs, features, is_training):
    """Applies the network to a batch of images.

    Args:
      inputs: Image batch whose leading dimension is `b * n_repeats` (see
        constructor).
      features: Additional per-example feature vectors with leading dimension
        `b`; must be non-None when `use_additional_features` is True, ignored
        otherwise.
      is_training: Python bool controlling batch-norm behavior.

    Returns:
      Logits with leading dimension `b`, averaged over the `n_repeats`
      replicas of each example.
    """
    net = inputs
    net = self._initial_conv(net)
    if not self._resnet_v2:
      net = self._initial_batchnorm(net, is_training=is_training)
      net = tf.nn.relu(net)

    net = tf.nn.max_pool2d(
        net, ksize=3, strides=2, padding="SAME", name="initial_max_pool")

    if self._use_additional_features:
      assert features is not None
      features = self._embedding(features, is_training=is_training)

    for i, block_group in enumerate(self._block_groups):
      net = block_group(net, is_training)
      if (self._use_additional_features and
          self._additional_features_mode == "per_block"):
        features_i = self._feature_repr[i](features, is_training=is_training)
        # support for n_repeats > 1
        features_i = tf.repeat(features_i, self._n_repeats, axis=0)
        net += features_i[:, None, None, :]  # expand to spatial resolution

    if self._resnet_v2:
      net = self._final_batchnorm(net, is_training=is_training)
      net = tf.nn.relu(net)
    net = tf.reduce_mean(net, axis=[1, 2], name="final_avg_pool")

    # Re-split the batch dimension
    net = tf.reshape(net, [-1, self._n_repeats] + net.shape.as_list()[1:])
    # Average over the various repeats of the input (e.g. those could have
    # corresponded to different crops).
    net = tf.reduce_mean(net, axis=1)

    if (self._use_additional_features and
        self._additional_features_mode == "mlp"):
      net += self._feature_repr(features, is_training=is_training)
    return self._logits(net)
class LinearBNReLU(snt.Module):
  """Linear projection followed by batch normalization and a ReLU."""

  def __init__(self, output_size=64,
               w_init=snt.initializers.VarianceScaling(scale=2.0),
               name="linear", **bn_config):
    """Builds the Linear and BatchNorm sub-modules.

    Args:
      output_size: Output dimension of the linear layer.
      w_init: Weight initializer for the linear layer.
      name: Name of the module; also prefixes the sub-module names.
      **bn_config: Optional keyword arguments forwarded to `snt.BatchNorm`.
    """
    super().__init__(name=name)
    self._linear = snt.Linear(output_size=output_size, w_init=w_init,
                              name=f"{name}_linear")
    self._bn = snt.BatchNorm(create_scale=True, create_offset=True,
                             name=f"{name}_bn", **bn_config)

  def __call__(self, x, is_training):
    """Returns relu(batch_norm(linear(x)))."""
    normalized = self._bn(self._linear(x), is_training=is_training)
    return tf.nn.relu(normalized)

View File

@@ -0,0 +1,279 @@
# Copyright 2021 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre-processing functions for input data."""
import functools
from absl import logging
import tensorflow.compat.v2 as tf
from galaxy_mergers import losses
# Crop strategies understood by `prepare_dataset`: no cropping, the crop
# proposed in the dataset row, or a random crop.
CROP_TYPE_NONE = 'crop_none'
CROP_TYPE_FIXED = 'crop_fixed'
CROP_TYPE_RANDOM = 'crop_random'
# Empirical pixel-value statistics used to map image values from [0, 8]
# to approximately [-1, 1] (see `_prepare_image`).
DATASET_FREQUENCY_MEAN = 4.0
DATASET_FREQUENCY_RANGE = 8.0
# Empirical (min, max) per physical feature, used by
# `normalize_physical_feature` to rescale values into [-1, 1].
PHYSICAL_FEATURES_MIN_MAX = {
    'redshift': (0.572788, 2.112304),
    'mass': (9.823963, 10.951282)
}
# Channel ordering of the input images; `_prepare_image` selects a subset of
# these channels. Presumably observation frequencies/bands — TODO confirm
# units against the dataset documentation.
ALL_FREQUENCIES = [105, 125, 160, 435, 606, 775, 850]
# Feature names accepted by `prepare_dataset`'s `additional_features` arg.
VALID_ADDITIONAL_FEATURES = ['redshift', 'sequence_average_redshift', 'mass']
def _make_padding_sizes(pad_size, random_centering):
  """Splits `pad_size` into (left, right) amounts, optionally at random.

  With `random_centering` the left share is drawn uniformly from
  [0, pad_size]; otherwise the split is as even as possible.
  """
  if not random_centering:
    left = pad_size // 2
  else:
    left = tf.random.uniform(
        shape=[], minval=0, maxval=pad_size + 1, dtype=tf.int32)
  right = pad_size - left
  return left, right
def resize_and_pad(image, target_size, random_centering):
  """Resize image to target_size (<= image.size) and pad to original size."""
  input_shape = image.shape
  # Square [height, width] vector from the scalar target size.
  side = tf.reshape(target_size, [1])
  resized = tf.image.resize(image, size=tf.concat([side, side], axis=0))
  # Pad the shrunken image back out to the original spatial extent.
  left, right = _make_padding_sizes(
      input_shape[1] - target_size, random_centering)
  paddings = [[left, right], [left, right], [0, 0]]
  if len(input_shape) == 4:
    paddings = [[0, 0]] + paddings
  padded = tf.pad(resized, paddings)
  padded.set_shape(input_shape)
  return padded
def resize_and_extract(image, target_size, random_centering):
  """Upscale image to target_size (>image.size), extract original size crop.

  Args:
    image: Rank-3 ([H, W, C]) or rank-4 ([B, H, W, C]) tensor; assumes
      height == width (callers enforce square images) — TODO confirm.
    target_size: Scalar int (tensor) side length to upscale to.
    random_centering: If True, the crop offset is random; otherwise the crop
      is (approximately) centered.

  Returns:
    A tensor with the same static shape as `image`.
  """
  original_shape = image.shape
  # Build a [height, width] size vector from the scalar target size.
  size = tf.reshape(target_size, [1])
  size = tf.concat([size, size], axis=0)
  image = tf.image.resize(image, size=size)
  # Amount by which the upscaled image exceeds the original side length.
  pad_size = target_size - original_shape[1]
  pad_size_left, pad_size_right = _make_padding_sizes(
      pad_size, random_centering)
  # Temporarily work on a rank-4 tensor so the slicing below is uniform.
  if len(original_shape) == 3:
    image = tf.expand_dims(image, 0)
  # Slice off the left/right margins, rows first then columns. tf.cond guards
  # the `-pad_size_right` end index: with pad_size_right == 0 it would read
  # as `-0` and produce an empty slice.
  image = tf.cond(pad_size_right > 0,
                  lambda: image[:, pad_size_left:-pad_size_right, :, :],
                  lambda: image[:, pad_size_left:, :, :])
  image = tf.cond(pad_size_right > 0,
                  lambda: image[:, :, pad_size_left:-pad_size_right, :],
                  lambda: image[:, :, pad_size_left:, :])
  if len(original_shape) == 3:
    image = tf.squeeze(image, 0)
  image.set_shape(original_shape)
  return image
def resize_and_center(image, target_size, random_centering):
  """Rescales to `target_size` while keeping the original spatial shape.

  Downscaling pads the result back to size; upscaling crops it back down.
  """
  shrink = lambda: resize_and_pad(image, target_size, random_centering)
  enlarge = lambda: resize_and_extract(image, target_size, random_centering)
  return tf.cond(
      tf.math.less_equal(target_size, image.shape[1]), shrink, enlarge)
def random_rotation_and_flip(image):
  """Applies a random multiple-of-90-degree rotation, then a random flip."""
  quarter_turns = tf.random.uniform(
      shape=[], minval=0, maxval=4, dtype=tf.int32)
  rotated = tf.image.rot90(image, quarter_turns)
  return tf.image.random_flip_left_right(rotated)
def get_all_rotations_and_flips(images):
  """Returns all 8 dihedral transforms (4 rotations x {identity, mirror})."""
  assert isinstance(images, list)
  augmented = []
  for image in images:
    mirrored = tf.image.flip_left_right(image)
    for quarter_turns in range(4):
      augmented.append(tf.image.rot90(image, quarter_turns))
      augmented.append(tf.image.rot90(mirrored, quarter_turns))
  return augmented
def random_rescaling(image, random_centering):
  """Rescales a square image to a random even side length, preserving shape."""
  shape = image.shape.as_list()
  assert shape[0] == shape[1]
  original_size = shape[1]
  min_size = 2 * (original_size // 4)
  max_size = original_size * 2
  # Sample half the target side so the result is always even.
  half_target = tf.random.uniform(
      shape=[], minval=min_size, maxval=max_size // 2, dtype=tf.int32)
  return resize_and_center(image, half_target * 2, random_centering)
def get_all_rescalings(images, image_width, random_centering):
  """Returns each image rescaled to five evenly spaced target sizes."""
  assert isinstance(images, list)
  smallest = 2 * (image_width // 4)
  largest = image_width * 2
  step = (largest + 2 - smallest) // 5
  sizes = range(smallest, largest + 2, step)
  return [resize_and_center(image, size, random_centering)
          for image in images for size in sizes]
def move_repeats_to_batch(image, n_repeats):
  """Folds a trailing repeat factor into the batch dimension.

  Reshapes so that `n_repeats` consecutive slices become a new axis next to
  batch, then flattens to [B * n_repeats, width, height, n_channels].

  NOTE(review): `width, height, n_channels` are read from static shape[1:],
  and the reshape requires `n_repeats` extra elements beyond those axes —
  confirm the expected input layout against the producing pipeline
  (`prepare_dataset` concatenates crop repeats along channels).
  """
  width, height, n_channels = image.shape.as_list()[1:]
  # Split out the repeat factor, move it next to batch, and flatten.
  image = tf.reshape(image, [-1, width, height, n_channels, n_repeats])
  image = tf.transpose(image, [0, 4, 1, 2, 3])  # [B, repeats, x, y, c]
  return tf.reshape(image, [-1, width, height, n_channels])
def get_classification_label(dataset_row, class_boundaries):
  """Maps the row's merger time to an int64 class label.

  Args:
    dataset_row: Mapping with a scalar `grounded_normalized_time` entry.
    class_boundaries: Dict of class id (string convertible to int) to a list
      of (lower, upper) time intervals belonging to that class.

  Returns:
    int64 label of a class whose interval strictly contains the merger time,
    or 0 if no interval matches.
  """
  merge_time = dataset_row['grounded_normalized_time']
  # Default label when the time falls outside every interval.
  label = tf.dtypes.cast(0, tf.int64)
  for category, intervals in class_boundaries.items():
    for interval in intervals:
      # Strict inequalities: exact boundary values match no class. If
      # intervals overlap, the last matching category in iteration order wins.
      if merge_time > interval[0] and merge_time < interval[1]:
        label = tf.dtypes.cast(int(category), tf.int64)
  return label
def get_regression_label(dataset_row, task_type):
  """Returns time-until-merger regression target given desired modeling task.

  Args:
    dataset_row: Mapping with the per-example time fields
      (`normalized_time`, `grounded_normalized_time`, `unnormalized_time`).
    task_type: One of the `losses.TASK_*` constants.

  Returns:
    float32 scalar tensor with the regression target. Note the
    classification task also returns the grounded normalized time.

  Raises:
    ValueError: If `task_type` is not a supported task.
  """
  if task_type == losses.TASK_NORMALIZED_REGRESSION:
    return tf.dtypes.cast(dataset_row['normalized_time'], tf.float32)
  elif task_type == losses.TASK_GROUNDED_UNNORMALIZED_REGRESSION:
    return tf.dtypes.cast(dataset_row['grounded_normalized_time'], tf.float32)
  elif task_type == losses.TASK_UNNORMALIZED_REGRESSION:
    return tf.dtypes.cast(dataset_row['unnormalized_time'], tf.float32)
  elif task_type == losses.TASK_CLASSIFICATION:
    return tf.dtypes.cast(dataset_row['grounded_normalized_time'], tf.float32)
  else:
    # Previously a bare `raise ValueError`; include the offending value so
    # misconfigurations are diagnosable.
    raise ValueError(f'Unsupported task type: {task_type}')
def get_normalized_time_target(dataset_row):
  """Returns the row's `normalized_time` as a float32 scalar."""
  normalized_time = dataset_row['normalized_time']
  return tf.dtypes.cast(normalized_time, tf.float32)
def apply_time_filter(dataset_row, time_interval):
  """Returns True iff the row's merger time is strictly inside the interval."""
  start, end = time_interval
  merger_time = dataset_row['grounded_normalized_time']
  return (merger_time > start) and (merger_time < end)
def normalize_physical_feature(name, dataset_row):
  """Linearly maps feature `name` from its empirical min/max range to [-1, 1].

  Args:
    name: Feature key; must be present in `PHYSICAL_FEATURES_MIN_MAX`.
    dataset_row: The example, either a mapping (as used by every other
      accessor in this module, e.g. `dataset_row['image']`) or an object
      exposing the feature as an attribute (the original code path).

  Returns:
    The feature value rescaled so that min maps to -1 and max to +1.
  """
  min_feat, max_feat = PHYSICAL_FEATURES_MIN_MAX[name]
  try:
    # Mapping-style rows, consistent with the rest of this module.
    value = dataset_row[name]
  except TypeError:
    # Attribute-style rows (original behavior, kept for compatibility).
    value = getattr(dataset_row, name)
  return 2 * (value - min_feat) / (max_feat - min_feat) - 1
def prepare_dataset(ds, target_size, crop_type, n_repeats, augmentations,
                    task_type, additional_features, class_boundaries,
                    time_intervals=None, frequencies_to_use='all',
                    additional_lambdas=None):
  """Prepare a zipped dataset of image, classification/regression labels.

  Args:
    ds: Source `tf.data.Dataset` of row mappings with at least `image`,
      `image_shape` and the time fields used by the label functions.
    target_size: (height, width, channels) of the cropped images.
    crop_type: One of `CROP_TYPE_NONE`, `CROP_TYPE_FIXED`, `CROP_TYPE_RANDOM`.
    n_repeats: Number of crops per example, concatenated along channels.
    augmentations: Mapping with boolean entries `rotation_and_flip`,
      `rescaling` and `translation`.
    task_type: One of the `losses.TASK_*` constants (regression target).
    additional_features: Comma-separated string of features from
      `VALID_ADDITIONAL_FEATURES`, or falsy to disable.
    class_boundaries: Passed to `get_classification_label`.
    time_intervals: Optional list of (lower, upper) merger-time intervals
      used to filter rows.
    frequencies_to_use: 'all' or a list of frequencies from `ALL_FREQUENCIES`
      selecting which image channels to keep.
    additional_lambdas: Optional extra per-row mapping functions appended to
      the zipped outputs.

  Returns:
    A zipped dataset of (image[, features], class label, regression label,
    normalized time[, extras...]) tuples, in that order.
  """
  def _prepare_image(dataset_row):
    """Transpose, crop and cast an image."""
    image = tf.dtypes.cast(dataset_row['image'], tf.float32)
    image = tf.reshape(image, tf.cast(dataset_row['image_shape'], tf.int32))
    image = tf.transpose(image, perm=[1, 2, 0])  # Convert to NHWC
    # Keep only the requested frequency channels.
    freqs = ALL_FREQUENCIES if frequencies_to_use == 'all' else frequencies_to_use
    idxs_to_keep = [ALL_FREQUENCIES.index(f) for f in freqs]
    image = tf.gather(params=image, indices=idxs_to_keep, axis=-1)
    # Based on offline computation on the empirical frequency range:
    # Converts [0, 8.] ~~> [-1, 1]
    image = (image - DATASET_FREQUENCY_MEAN)/(DATASET_FREQUENCY_RANGE/2.0)
    def crop(image):
      # Crop to `target_size` according to `crop_type` (closure).
      if crop_type == CROP_TYPE_FIXED:
        crop_loc = tf.cast(dataset_row['proposed_crop'][0], tf.int32)
        crop_size = tf.cast(dataset_row['proposed_crop'][1], tf.int32)
        image = image[
            crop_loc[0]:crop_loc[0] + crop_size[0],
            crop_loc[1]:crop_loc[1] + crop_size[1], :]
        image = tf.image.resize(image, target_size[0:2])
        image.set_shape([target_size[0], target_size[1], target_size[2]])
      elif crop_type == CROP_TYPE_RANDOM:
        image = tf.image.random_crop(image, target_size)
        image.set_shape([target_size[0], target_size[1], target_size[2]])
      elif crop_type != CROP_TYPE_NONE:
        raise NotImplementedError
      return image
    # Draw `n_repeats` (possibly random) crops and stack them on channels.
    repeated_images = []
    for _ in range(n_repeats):
      repeated_images.append(crop(image))
    image = tf.concat(repeated_images, axis=-1)
    if augmentations['rotation_and_flip']:
      image = random_rotation_and_flip(image)
    if augmentations['rescaling']:
      image = random_rescaling(image, augmentations['translation'])
    return image
  def get_regression_label_wrapper(dataset_row):
    # Binds `task_type` for use with `ds.map`.
    return get_regression_label(dataset_row, task_type=task_type)
  def get_classification_label_wrapper(dataset_row):
    # Binds `class_boundaries` for use with `ds.map`.
    return get_classification_label(dataset_row,
                                    class_boundaries=class_boundaries)
  # NOTE(review): chained `filter` calls keep only rows inside EVERY interval
  # (intersection); for disjoint intervals this yields an empty dataset.
  # Confirm whether a union was intended (the CLI help elsewhere suggests a
  # list of alternative intervals).
  if time_intervals:
    for time_interval in time_intervals:
      filter_fn = functools.partial(apply_time_filter,
                                    time_interval=time_interval)
      ds = ds.filter(filter_fn)
  datasets = [ds.map(_prepare_image)]
  if additional_features:
    additional_features = additional_features.split(',')
    assert all([f in VALID_ADDITIONAL_FEATURES for f in additional_features])
    logging.info('Running with additional features: %s.',
                 ', '.join(additional_features))
    def _prepare_additional_features(dataset_row):
      # Normalized feature vector with one entry per requested feature.
      features = []
      for f in additional_features:
        features.append(normalize_physical_feature(f, dataset_row))
      features = tf.convert_to_tensor(features, dtype=tf.float32)
      features.set_shape([len(additional_features)])
      return features
    datasets += [ds.map(_prepare_additional_features)]
  # Output tuple order: image[, features], class label, regression label,
  # normalized time[, extras from additional_lambdas].
  datasets += [
      ds.map(get_classification_label_wrapper),
      ds.map(get_regression_label_wrapper),
      ds.map(get_normalized_time_target)]
  if additional_lambdas:
    for process_fn in additional_lambdas:
      datasets += [ds.map(process_fn)]
  return tf.data.Dataset.zip(tuple(datasets))

View File

@@ -0,0 +1,53 @@
absl-py==0.11.0
astropy==4.2
astunparse==1.6.3
cachetools==4.2.1
certifi==2020.12.5
chardet==4.0.0
cloudpickle==1.6.0
contextlib2==0.6.0.post1
cycler==0.10.0
decorator==4.4.2
dm-sonnet==2.0.0
dm-tree==0.1.5
flatbuffers==1.12
gast==0.3.3
google-auth==1.27.0
google-auth-oauthlib==0.4.2
google-pasta==0.2.0
grpcio==1.32.0
h5py==2.10.0
idna==2.10
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.3.4
matplotlib==3.3.4
ml-collections==0.1.0
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
Pillow==8.1.0
# NOTE: "pkg-resources==0.0.0" removed — it is a spurious entry emitted by
# `pip freeze` on Debian/Ubuntu virtualenvs, not a real PyPI package, and it
# breaks `pip install -r requirements.txt` on other systems.
protobuf==3.15.3
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyerfa==1.7.2
pyparsing==2.4.7
python-dateutil==2.8.1
PyYAML==5.4.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
scipy==1.6.1
six==1.15.0
tabulate==0.8.9
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.1
tensorflow-estimator==2.4.0
tensorflow-probability==0.12.1
termcolor==1.1.0
typing-extensions==3.7.4.3
urllib3==1.26.3
Werkzeug==1.0.1
wrapt==1.12.1