diff --git a/README.md b/README.md
index 0a13cc4..d84aa90 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
 
 ## Projects
 
+* [A Deep Learning Approach for Characterizing Major Galaxy Mergers](galaxy_mergers)
 * [Better, Faster Fermionic Neural Networks](kfac_ferminet_alpha) (KFAC implementation)
 * [Object-based attention for spatio-temporal reasoning](object_attention_for_reasoning)
 * [Effective gene expression prediction from sequence by integrating long-range interactions](enformer)
diff --git a/galaxy_mergers/README.md b/galaxy_mergers/README.md
new file mode 100644
index 0000000..1b531e4
--- /dev/null
+++ b/galaxy_mergers/README.md
@@ -0,0 +1,41 @@
+# A Deep Learning Approach for Characterizing Major Galaxy Mergers
+
+This repository contains evaluation code and checkpoints to reproduce the
+figures in https://arxiv.org/abs/2102.05182.
+
+The main evaluation module is `main.py`. It uses the provided checkpoint path
+and dataset path to run evaluation.
+
+
+## Setup
+
+To set up a Python virtual environment with the required dependencies, run:
+
+```shell
+python3 -m venv galaxy_mergers_env
+source galaxy_mergers_env/bin/activate
+pip install --upgrade pip setuptools wheel
+pip install -r requirements.txt
+```
+
+### License
+
+While the code is licensed under the Apache 2.0 License, the checkpoint weights
+are made available for non-commercial use only under the terms of the
+Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
+license. You can find details at:
+https://creativecommons.org/licenses/by-nc/4.0/legalcode.
+
+
+### Citing our work
+
+If you use this work, please consider citing our paper:
+
+```bibtex
+@article{koppula2021deep,
+  title={A Deep Learning Approach for Characterizing Major Galaxy Mergers},
+  author={Koppula, Skanda and Bapst, Victor and Huertas-Company, Marc and Blackwell, Sam and Grabska-Barwinska, Agnieszka and Dieleman, Sander and Huber, Andrea and Antropova, Natasha and Binkowski, Mikolaj and Openshaw, Hannah and others},
+  journal={Workshop for Machine Learning and the Physical Sciences @ NeurIPS 2020},
+  year={2021}
+}
+```
diff --git a/galaxy_mergers/antennae_helpers.py b/galaxy_mergers/antennae_helpers.py
new file mode 100644
index 0000000..b8de4a6
--- /dev/null
+++ b/galaxy_mergers/antennae_helpers.py
@@ -0,0 +1,87 @@
+# Copyright 2021 DeepMind Technologies Limited.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
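For readers who prefer to drive the evaluation from Python rather than through the `main.py` flags, here is a minimal sketch using `evaluator.get_config_dataset_evaluator` and `evaluator.run_model_on_dataset` from this diff. The checkpoint directory and TFRecord path are placeholders.

```python
# Minimal sketch of a programmatic evaluation run, mirroring main.py.
# The checkpoint directory and TFRecord path below are placeholders.
from galaxy_mergers import evaluator

config, ds, experiment = evaluator.get_config_dataset_evaluator(
    filter_time_intervals=None,           # default merger-time interval [-1, 1]
    ckpt_path='/path/to/checkpoint_dir',  # placeholder
    config_override={
        'experiment_kwargs.data_config.dataset_path': '/path/to/data.tfrecords',
    })
metrics, model_outputs, _ = evaluator.run_model_on_dataset(
    experiment, ds, config)
# Print the last recorded value of each scalar metric.
print({name: values[-1] for name, values in metrics.items()})
```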
+ +"""Helpers to pre-process Antennae galaxy images.""" + +import collections +import os + +from astropy.io import fits +import numpy as np +from scipy import ndimage +import tensorflow.compat.v2 as tf + + +def norm_antennae_images(images, scale=1000): + return tf.math.asinh(images/scale) + + +def renorm_antennae(images): + median = np.percentile(images.numpy().flatten(), 50) + img_range = np.ptp(images.numpy().flatten()) + return (images - median) / (img_range / 2) + + +def get_antennae_images(antennae_fits_dir): + """Load the raw Antennae galaxy images.""" + all_fits_files = [ + os.path.join(antennae_fits_dir, f) + for f in os.listdir(antennae_fits_dir) + ] + freq_mapping = {'red': 160, 'blue': 850} + + paired_fits_files = collections.defaultdict(list) + for f in all_fits_files: + redshift = float(f[-8:-5]) + paired_fits_files[redshift].append(f) + + for redshift, files in paired_fits_files.items(): + paired_fits_files[redshift] = sorted( + files, key=lambda f: freq_mapping[f.split('/')[-1].split('_')[0]]) + + print('Reading files:', paired_fits_files) + print('Redshifts:', sorted(paired_fits_files.keys())) + + galaxy_views = collections.defaultdict(list) + for redshift in paired_fits_files: + for view_path in paired_fits_files[redshift]: + with open(view_path, 'rb') as f: + fits_data = fits.open(f) + galaxy_views[redshift].append(np.array(fits_data[0].data)) + + batched_images = [] + for redshift in paired_fits_files: + img = tf.constant(np.array(galaxy_views[redshift])) + img = tf.transpose(img, (1, 2, 0)) + img = tf.image.resize(img, size=(60, 60)) + batched_images.append(img) + + return tf.stack(batched_images) + + +def preprocess_antennae_images(antennae_images): + """Pre-process the Antennae galaxy images into a reasonable range.""" + rotated_antennae_images = [ + ndimage.rotate(img, 10, reshape=True, cval=-1)[10:-10, 10:-10] + for img in antennae_images + ] + rotated_antennae_images = [ + np.clip(img, 0, 1e9) for img in rotated_antennae_images + ] + rotated_antennae_images = tf.stack(rotated_antennae_images) + normed_antennae_images = norm_antennae_images(rotated_antennae_images) + normed_antennae_images = tf.clip_by_value(normed_antennae_images, 1, 4.5) + renormed_antennae_images = renorm_antennae(normed_antennae_images) + return renormed_antennae_images diff --git a/galaxy_mergers/config.py b/galaxy_mergers/config.py new file mode 100644 index 0000000..34a0ae4 --- /dev/null +++ b/galaxy_mergers/config.py @@ -0,0 +1,98 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default config, focused on model evaluation.""" + +from ml_collections import config_dict + + +def get_config(filter_time_intervals=None): + """Return config object for training.""" + config = config_dict.ConfigDict() + config.eval_strategy = config_dict.ConfigDict() + config.eval_strategy.class_name = 'OneDeviceConfig' + config.eval_strategy.kwargs = config_dict.ConfigDict( + dict(device_type='v100')) + + ## Experiment config. 
+ config.experiment_kwargs = config_dict.ConfigDict(dict( + resnet_kwargs=dict( + blocks_per_group_list=[3, 4, 6, 3], # This choice is ResNet50. + bn_config=dict( + decay_rate=0.9, + eps=1e-5), + resnet_v2=False, + additional_features_mode='mlp', + ), + optimizer_config=dict( + class_name='Momentum', + kwargs={'momentum': 0.9}, + # Set up the learning rate schedule. + lr_init=0.025, + lr_factor=0.1, + lr_schedule=(50e3, 100e3, 150e3), + gradient_clip=5., + ), + l2_regularization=1e-4, + total_train_batch_size=128, + train_net_args={'is_training': True}, + eval_batch_size=128, + eval_net_args={'is_training': True}, + data_config=dict( + # dataset loading + dataset_path=None, + num_val_splits=10, + val_split=0, + + # image cropping + image_size=(80, 80, 7), + train_crop_type='crop_fixed', + test_crop_type='crop_fixed', + n_crop_repeat=1, + + train_augmentations=dict( + rotation_and_flip=True, + rescaling=True, + translation=True, + ), + + test_augmentations=dict( + rotation_and_flip=False, + rescaling=False, + translation=False, + ), + test_time_ensembling='sum', + + num_eval_buckets=5, + eval_confidence_interval=95, + + task='grounded_unnormalized_regression', + loss_config=dict( + loss='mse', + mse_normalize=False, + ), + model_uncertainty=True, + additional_features='', + time_filter_intervals=filter_time_intervals, + class_boundaries={ + '0': [[-1., 0]], + '1': [[0, 1.]] + }, + frequencies_to_use='all', + ), + n_train_epochs=100 + )) + + return config diff --git a/galaxy_mergers/evaluator.py b/galaxy_mergers/evaluator.py new file mode 100644 index 0000000..05a7a10 --- /dev/null +++ b/galaxy_mergers/evaluator.py @@ -0,0 +1,262 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
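The defaults defined in `config.py` above form a standard `ml_collections.ConfigDict`, so nested fields can be overridden with flattened keys; this is exactly how `evaluator.get_config_dataset_evaluator` injects the dataset path. A small sketch follows (the TFRecord path is a placeholder and the time intervals are only an example):

```python
# Illustrative override of the default evaluation config.
from galaxy_mergers import config as tp_config

config = tp_config.get_config(filter_time_intervals=[[-0.2, 0.0], [0.5, 1.0]])
with config.ignore_type():  # dataset_path defaults to None, so ignore the type
  config.update_from_flattened_dict({
      'experiment_kwargs.data_config.dataset_path': '/path/to/eval.tfrecords',
  })
print(config.experiment_kwargs.data_config.task)  # grounded_unnormalized_regression
```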
+ + +"""Evaluation runner.""" + +import collections +from absl import logging +import tensorflow.compat.v2 as tf + +from galaxy_mergers import config as tp_config +from galaxy_mergers import helpers +from galaxy_mergers import losses +from galaxy_mergers import model +from galaxy_mergers import preprocessing + + +class GalaxyMergeClassifierEvaluator(): + """Galaxy Merge Rate Prediction Evaluation Runner.""" + + def __init__(self, strategy, optimizer_config, total_train_batch_size, + train_net_args, eval_batch_size, eval_net_args, + l2_regularization, data_config, resnet_kwargs, n_train_epochs): + """Initializes evaluator/experiment.""" + logging.info('Initializing evaluator...') + self._strategy = strategy + self._data_config = data_config + self._use_additional_features = bool(data_config['additional_features']) + self._eval_batch_size = eval_batch_size + self._eval_net_args = eval_net_args + self._num_buckets = data_config['num_eval_buckets'] + self._n_repeats = data_config['n_crop_repeat'] + self._image_size = data_config['image_size'] + self._task_type = data_config['task'] + self._loss_config = data_config['loss_config'] + self._model_uncertainty = data_config['model_uncertainty'] + del l2_regularization, optimizer_config, train_net_args + del total_train_batch_size, n_train_epochs + + logging.info('Creating model...') + num_classes = 2 if self._model_uncertainty else 1 + if self._task_type == losses.TASK_CLASSIFICATION: + num_classes = len(self._data_config['class_boundaries']) + self.model = model.ResNet( + n_repeats=self._data_config['n_crop_repeat'], num_classes=num_classes, + use_additional_features=self._use_additional_features, **resnet_kwargs) + + self._eval_input = None + + def build_eval_input(self, additional_lambdas=None): + """Create the galaxy merger evaluation dataset.""" + + def decode_fn(record_bytes): + parsed_example = tf.io.parse_single_example( + record_bytes, + { + 'image': + tf.io.VarLenFeature(tf.float32), + 'image_shape': + tf.io.FixedLenFeature([3], dtype=tf.int64), + 'axis': + tf.io.FixedLenFeature([], dtype=tf.int64), + 'proposed_crop': + tf.io.FixedLenFeature([2, 2], dtype=tf.int64), + 'normalized_time': + tf.io.FixedLenFeature([], dtype=tf.float32), + 'unnormalized_time': + tf.io.FixedLenFeature([], dtype=tf.float32), + 'grounded_normalized_time': + tf.io.FixedLenFeature([], dtype=tf.float32), + 'redshift': + tf.io.FixedLenFeature([], dtype=tf.float32), + 'sequence_average_redshift': + tf.io.FixedLenFeature([], dtype=tf.float32), + 'mass': + tf.io.FixedLenFeature([], dtype=tf.float32), + 'time_index': + tf.io.FixedLenFeature([], dtype=tf.int64), + 'sequence_id': + tf.io.FixedLenFeature([], dtype=tf.string), + }) + parsed_example['image'] = tf.sparse.to_dense( + parsed_example['image'], default_value=0) + dataset_row = parsed_example + return dataset_row + + def build_eval_pipeline(_): + """Generate the processed input evaluation data.""" + + logging.info('Building evaluation input pipeline...') + ds_path = self._data_config['dataset_path'] + ds = tf.data.TFRecordDataset([ds_path]).map(decode_fn) + + augmentations = dict( + rotation_and_flip=False, + rescaling=False, + translation=False + ) + ds = preprocessing.prepare_dataset( + ds=ds, target_size=self._image_size, + crop_type=self._data_config['test_crop_type'], + n_repeats=self._n_repeats, + augmentations=augmentations, + task_type=self._task_type, + additional_features=self._data_config['additional_features'], + class_boundaries=self._data_config['class_boundaries'], + 
time_intervals=self._data_config['time_filter_intervals'], + frequencies_to_use=self._data_config['frequencies_to_use'], + additional_lambdas=additional_lambdas) + + batched_ds = ds.cache().batch(self._eval_batch_size).prefetch(128) + logging.info('Finished building input pipeline...') + return batched_ds + + return self._strategy.experimental_distribute_datasets_from_function( + build_eval_pipeline) + + def run_test_model_ensemble(self, images, physical_features, augmentations): + """Run evaluation on input images.""" + image_variations = [images] + image_shape = images.shape.as_list() + + if augmentations['rotation_and_flip']: + image_variations = preprocessing.get_all_rotations_and_flips( + image_variations) + + if augmentations['rescaling']: + image_variations = preprocessing.get_all_rescalings( + image_variations, image_shape[1], augmentations['translation']) + + # Put all augmented images into the batch: batch * num_augmented + augmented_images = tf.stack(image_variations, axis=0) + augmented_images = tf.reshape(augmented_images, [-1] + image_shape[1:]) + if self._use_additional_features: + physical_features = tf.concat( + [physical_features] * len(image_variations), axis=0) + + n_reps = self._data_config['n_crop_repeat'] + augmented_images = preprocessing.move_repeats_to_batch(augmented_images, + n_reps) + + logits_or_times = self.model(augmented_images, physical_features, + **self._eval_net_args) + if self._task_type == losses.TASK_CLASSIFICATION: + mu, log_sigma_sq = helpers.aggregate_classification_ensemble( + logits_or_times, len(image_variations), + self._data_config['test_time_ensembling']) + else: + assert self._task_type in losses.REGRESSION_TASKS + mu, log_sigma_sq = helpers.aggregate_regression_ensemble( + logits_or_times, len(image_variations), + self._model_uncertainty, + self._data_config['test_time_ensembling']) + + return mu, log_sigma_sq + + @property + def checkpoint_items(self): + return {'model': self.model} + + +def run_model_on_dataset(evaluator, dataset, config, n_batches=16): + """Runs the model against a dataset, aggregates model output.""" + + scalar_metrics_to_log = collections.defaultdict(list) + model_outputs_to_log = collections.defaultdict(list) + dataset_features_to_log = collections.defaultdict(list) + + batch_count = 1 + for all_inputs in dataset: + if config.experiment_kwargs.data_config['additional_features']: + images = all_inputs[0] + physical_features = all_inputs[1] + labels, regression_targets, _ = all_inputs[2:5] + other_dataset_features = all_inputs[5:] + else: + images, physical_features = all_inputs[0], None + labels, regression_targets, _ = all_inputs[1:4] + other_dataset_features = all_inputs[4:] + + mu, log_sigma_sq = evaluator.run_test_model_ensemble( + images, physical_features, + config.experiment_kwargs.data_config['test_augmentations']) + + loss_config = config.experiment_kwargs.data_config['loss_config'] + task_type = config.experiment_kwargs.data_config['task'] + uncertainty = config.experiment_kwargs.data_config['model_uncertainty'] + conf = config.experiment_kwargs.data_config['eval_confidence_interval'] + scalar_metrics, vector_metrics = losses.compute_loss_and_metrics( + mu, log_sigma_sq, regression_targets, labels, + task_type, uncertainty, loss_config, 0, conf, mode='eval') + + for i, dataset_feature in enumerate(other_dataset_features): + dataset_features_to_log[i].append(dataset_feature.numpy()) + + for scalar_metric in scalar_metrics: + v = scalar_metrics[scalar_metric] + val = v if isinstance(v, int) or isinstance(v, 
float) else v.numpy()
+      scalar_metrics_to_log[scalar_metric].append(val)
+
+    for vector_metric in vector_metrics:
+      val = vector_metrics[vector_metric].numpy()
+      model_outputs_to_log[vector_metric].append(val)
+
+    regression_targets_np = regression_targets.numpy()
+    labels_np = labels.numpy()
+    model_outputs_to_log['regression_targets'].append(regression_targets_np)
+    model_outputs_to_log['labels'].append(labels_np)
+    model_outputs_to_log['model_input_images'].append(images.numpy())
+
+    if n_batches and batch_count >= n_batches:
+      break
+    batch_count += 1
+
+  return scalar_metrics_to_log, model_outputs_to_log, dataset_features_to_log
+
+
+def get_config_dataset_evaluator(filter_time_intervals,
+                                 ckpt_path,
+                                 config_override=None,
+                                 setup_dataset=True):
+  """Set-up a default config, evaluation dataset, and evaluator."""
+  config = tp_config.get_config(filter_time_intervals=filter_time_intervals)
+
+  if config_override:
+    with config.ignore_type():
+      config.update_from_flattened_dict(config_override)
+
+  strategy = tf.distribute.OneDeviceStrategy(device='/gpu:0')
+  experiment = GalaxyMergeClassifierEvaluator(
+      strategy=strategy, **config.experiment_kwargs)
+
+  helpers.restore_checkpoint(ckpt_path, experiment)
+
+  if setup_dataset:
+    additional_lambdas = [
+        lambda ds: ds['sequence_id'],
+        lambda ds: ds['time_index'],
+        lambda ds: ds['axis'],
+        lambda ds: ds['normalized_time'],
+        lambda ds: ds['grounded_normalized_time'],
+        lambda ds: ds['unnormalized_time'],
+        lambda ds: ds['redshift'],
+        lambda ds: ds['mass']
+    ]
+
+    ds = experiment.build_eval_input(additional_lambdas=additional_lambdas)
+  else:
+    ds = None
+  return config, ds, experiment
diff --git a/galaxy_mergers/helpers.py b/galaxy_mergers/helpers.py
new file mode 100644
index 0000000..8fc9a86
--- /dev/null
+++ b/galaxy_mergers/helpers.py
@@ -0,0 +1,228 @@
+# Copyright 2021 DeepMind Technologies Limited.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helpers for a galaxy merger model evaluation."""
+
+import glob
+import os
+from astropy import cosmology
+from astropy.io import fits
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+import tensorflow.compat.v2 as tf
+
+
+def restore_checkpoint(checkpoint_dir, experiment):
+  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
+  global_step = tf.Variable(
+      0, dtype=tf.int32, trainable=False, name='global_step')
+  checkpoint = tf.train.Checkpoint(
+      _global_step_=global_step, **experiment.checkpoint_items)
+  checkpoint.restore(checkpoint_path)
+
+
+def sum_average_transformed_mu_and_sigma(mu, log_sigma_sq):
+  """Computes <mu>, <sigma**2> + var(mu) in transformed representation.
+
+  This corresponds to assuming that the output distribution is a sum of
+  Gaussians and computing the mean and variance of the resulting
+  (non-Gaussian) distribution.
+
+  Args:
+    mu: Tensor of shape [B, ...] representing the means of the input
+      distributions.
+    log_sigma_sq: Tensor of shape [B, ...] representing log(sigma**2) of the
+      input distributions. Can be None, in which case the variance is assumed
+      to be zero.
+
+  Returns:
+    mu: Tensor of shape [...] representing the means of the output
+      distributions.
+    log_sigma_sq: Tensor of shape [...] representing log(sigma**2) of the
+      output distributions.
+  """
+  av_mu = tf.reduce_mean(mu, axis=0)
+  var_mu = tf.math.reduce_std(mu, axis=0)**2
+  if log_sigma_sq is None:
+    return av_mu, tf.math.log(var_mu)
+  max_log_sigma_sq = tf.reduce_max(log_sigma_sq, axis=0)
+  log_sigma_sq -= max_log_sigma_sq
+  # (sigma/sigma_0)**2
+  sigma_sq = tf.math.exp(log_sigma_sq)
+  # (<sigma**2>)/sigma_0**2 (<1)
+  av_sigma_sq = tf.reduce_mean(sigma_sq, axis=0)
+  # (<sigma**2> + var(mu))/sigma_0**2
+  av_sigma_sq += var_mu * tf.math.exp(-max_log_sigma_sq)
+  # log(<sigma**2> + var(mu))
+  log_av_sigma_sq = tf.math.log(av_sigma_sq) + max_log_sigma_sq
+  return av_mu, log_av_sigma_sq
+
+
+def aggregate_regression_ensemble(logits_or_times, ensemble_size,
+                                  use_uncertainty, test_time_ensembling):
+  """Aggregate output of model ensemble."""
+  out_shape = logits_or_times.shape.as_list()[1:]
+  logits_or_times = tf.reshape(logits_or_times, [ensemble_size, -1] + out_shape)
+  mus = logits_or_times[..., 0]
+  log_sigma_sqs = logits_or_times[..., -1] if use_uncertainty else None
+
+  if test_time_ensembling == 'sum':
+    mu, log_sigma_sq = sum_average_transformed_mu_and_sigma(mus, log_sigma_sqs)
+  elif test_time_ensembling == 'none':
+    mu = mus[0]
+    log_sigma_sq = log_sigma_sqs[0] if use_uncertainty else None
+  else:
+    raise ValueError('Unexpected test_time_ensembling')
+  return mu, log_sigma_sq
+
+
+def aggregate_classification_ensemble(logits_or_times, ensemble_size,
+                                      test_time_ensembling):
+  """Averages the output logits across models in the ensemble."""
+  out_shape = logits_or_times.shape.as_list()[1:]
+  logits = tf.reshape(logits_or_times, [ensemble_size, -1] + out_shape)
+
+  if test_time_ensembling == 'sum':
+    logits = tf.reduce_mean(logits, axis=0)
+    return logits, None
+  elif test_time_ensembling == 'none':
+    return logits, None
+  else:
+    raise ValueError('Unexpected test_time_ensembling')
+
+
+def unpack_evaluator_output(data, return_seq_info=False, return_redshift=False):
+  """Unpack evaluator.run_model_on_dataset output."""
+  mus = np.array(data[1]['mu']).flatten()
+  sigmas = np.array(data[1]['sigma']).flatten()
+  regression_targets = np.array(data[1]['regression_targets']).flatten()
+  outputs = [mus, sigmas, regression_targets]
+
+  if return_seq_info:
+    seq_ids = np.array(data[2][0]).flatten()
+    seq_ids = np.array([seq_id.decode('UTF-8') for seq_id in seq_ids])
+    time_idxs = np.array(data[2][1]).flatten()
+    axes = np.array(data[2][2]).flatten()
+    outputs += [seq_ids, axes, time_idxs]
+
+  if return_redshift:
+    redshifts = np.array(data[2][6]).flatten()
+    outputs += [redshifts]
+
+  return outputs
+
+
+def process_data_into_myrs(redshifts, *data_lists):
+  """Converts normalized time to virial time using Planck cosmology."""
+  # small hack to avoid build tools not recognizing non-standard trickery
+  # done in the astropy library:
+  # https://github.com/astropy/astropy/blob/master/astropy/cosmology/core.py#L3290
+  # that dynamically generates and imports new classes.
+  planck13 = getattr(cosmology, 'Planck13')
+  hubble_constants = planck13.H(redshifts)  # (km/s)/megaparsec
+  inv_hubble_constants = 1/hubble_constants  # (megaparsec*s) / km
+  megaparsec_to_km = 1e19*3.1
+  seconds_to_gigayears = 1e-15/31.556
+  conversion_factor = megaparsec_to_km * seconds_to_gigayears
+  hubble_time_gigayears = conversion_factor * inv_hubble_constants
+
+  hubble_to_virial_time = 0.14  # approximate simulation-based conversion factor
+  virial_dyn_time = hubble_to_virial_time*hubble_time_gigayears.value
+  return [data_list*virial_dyn_time for data_list in data_lists]
+
+
+def print_rmse_and_class_accuracy(mus, regression_targets, redshifts):
+  """Convert to virial dynamical time and print stats."""
+  time_pred, time_gt = process_data_into_myrs(
+      redshifts, mus, regression_targets)
+  time_sq_errors = (time_pred-time_gt)**2
+  rmse = np.sqrt(np.mean(time_sq_errors))
+  labels = regression_targets > 0
+  class_preds = mus > 0
+  accuracy = sum((labels == class_preds).astype(np.int8)) / len(class_preds)
+
+  print(f'95% Error: {np.percentile(np.sqrt(time_sq_errors), 95)}')
+  print(f'RMSE: {rmse}')
+  print(f'Classification Accuracy: {accuracy}')
+
+
+def print_stats(vec, do_print=True):
+  fvec = vec.flatten()
+  if do_print:
+    print(len(fvec), min(fvec), np.mean(fvec), np.median(fvec), max(fvec))
+  return (len(fvec), min(fvec), np.mean(fvec), np.median(fvec), max(fvec))
+
+
+def get_image_from_fits(base_dir, seq='475_31271', time='497', axis=2):
+  """Read *.fits galaxy image from directory."""
+  axis_map = {0: 'x', 1: 'y', 2: 'z'}
+  fits_glob = f'{base_dir}/{seq}/fits_of_flux_psf/{time}/*_{axis_map[axis]}_*.fits'
+
+  def get_freq_from_path(p):
+    return int(p.split('/')[-1].split('_')[2][1:])
+
+  fits_image_paths = sorted(glob.glob(fits_glob), key=get_freq_from_path)
+  assert len(fits_image_paths) == 7
+  combined_frequencies = []
+  for fit_path in fits_image_paths:
+    with open(fit_path, 'rb') as f:
+      fits_data = np.array(fits.open(f)[0].data.astype(np.float32))
+      combined_frequencies.append(fits_data)
+  fits_image = np.transpose(np.array(combined_frequencies), (1, 2, 0))
+  return fits_image
+
+
+def stack_desired_galaxy_images(base_dir, seq, n_time_slices):
+  """Search through galaxy image directory gathering images."""
+  fits_sequence_dir = os.path.join(base_dir, seq, 'fits_of_flux_psf')
+  all_times_for_seq = os.listdir(fits_sequence_dir)
+  hop = (len(all_times_for_seq)-1)//(n_time_slices-1)
+  desired_time_idxs = [k*hop for k in range(n_time_slices)]
+
+  all_imgs = []
+  for j in desired_time_idxs:
+    time = all_times_for_seq[j]
+    img = get_image_from_fits(base_dir=base_dir, seq=seq, time=time, axis=2)
+    all_imgs.append(img)
+
+  min_img_size = min([img.shape[0] for img in all_imgs])
+  return all_imgs, min_img_size
+
+
+def draw_galaxy_image(image, target_size=None, color_map='viridis'):
+  normalized_image = image / max(image.flatten())
+  color_map = plt.get_cmap(color_map)
+  colored_image = color_map(normalized_image)[:, :, :3]
+  colored_image = (colored_image * 255).astype(np.uint8)
+  colored_image = Image.fromarray(colored_image, mode='RGB')
+  if target_size:
+    colored_image = colored_image.resize(target_size, Image.ANTIALIAS)
+  return colored_image
+
+
+def collect_merger_sequence(ds, seq=b'370_11071', n_examples_to_sift=5000):
+  images, targets, redshifts = [], [], []
+  for i, all_inputs in enumerate(ds):
+    if all_inputs[4][0].numpy() == seq:
+      images.append(all_inputs[0][0].numpy())
+      targets.append(all_inputs[2][0].numpy())
+      redshifts.append(all_inputs[10][0].numpy())
+    if i > 
n_examples_to_sift: break + return np.squeeze(images), np.squeeze(targets), np.squeeze(redshifts) + + +def take_samples(sample_idxs, *data_lists): + return [np.take(l, sample_idxs, axis=0) for l in data_lists] diff --git a/galaxy_mergers/interpretability_helpers.py b/galaxy_mergers/interpretability_helpers.py new file mode 100644 index 0000000..4330e61 --- /dev/null +++ b/galaxy_mergers/interpretability_helpers.py @@ -0,0 +1,68 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers to visualize gradients and other interpretability analysis.""" + +import numpy as np +import tensorflow.compat.v2 as tf + + +def rotate_by_right_angle_multiple(image, rot=90): + """Rotate an image by right angles.""" + if rot not in [0, 90, 180, 270]: + raise ValueError(f"Cannot rotate by non-90 degree angle {rot}") + + if rot in [90, -270]: + image = np.transpose(image, (1, 0, 2)) + image = image[::-1] + elif rot in [180, -180]: + image = image[::-1, ::-1] + elif rot in [270, -90]: + image = np.transpose(image, (1, 0, 2)) + image = image[:, ::-1] + + return image + + +def compute_gradient(images, evaluator, is_training=False): + inputs = tf.Variable(images[None], dtype=tf.float32) + with tf.GradientTape() as tape: + tape.watch(inputs) + time_sigma = evaluator.model(inputs, None, is_training) + grad_time = tape.gradient(time_sigma[:, 0], inputs) + return grad_time, time_sigma + + +def compute_grads_for_rotations(images, evaluator, is_training=False): + test_gradients, test_outputs = [], [] + for rotation in np.arange(0, 360, 90): + images_rot = rotate_by_right_angle_multiple(images, rotation) + grads, time_sigma = compute_gradient(images_rot, evaluator, is_training) + grads = np.squeeze(grads.numpy()) + inv_grads = rotate_by_right_angle_multiple(grads, -rotation) + test_gradients.append(inv_grads) + test_outputs.append(time_sigma.numpy()) + return np.squeeze(test_gradients), np.squeeze(test_outputs) + + +def compute_grads_for_rotations_and_flips(images, evaluator): + grads, time_sigma = compute_grads_for_rotations(images, evaluator) + grads_f, time_sigma_f = compute_grads_for_rotations(images[::-1], evaluator) + grads_f = grads_f[:, ::-1] + all_grads = np.concatenate([grads, grads_f], 0) + model_outputs = np.concatenate((time_sigma, time_sigma_f), 0) + return all_grads, model_outputs + + diff --git a/galaxy_mergers/losses.py b/galaxy_mergers/losses.py new file mode 100644 index 0000000..88f8658 --- /dev/null +++ b/galaxy_mergers/losses.py @@ -0,0 +1,169 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers to compute loss metrics.""" + +import scipy.stats +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp + + +TASK_CLASSIFICATION = 'classification' +TASK_NORMALIZED_REGRESSION = 'normalized_regression' +TASK_UNNORMALIZED_REGRESSION = 'unnormalized_regression' +TASK_GROUNDED_UNNORMALIZED_REGRESSION = 'grounded_unnormalized_regression' +REGRESSION_TASKS = [TASK_NORMALIZED_REGRESSION, TASK_UNNORMALIZED_REGRESSION, + TASK_GROUNDED_UNNORMALIZED_REGRESSION] +ALL_TASKS = [TASK_CLASSIFICATION] + REGRESSION_TASKS + +LOSS_MSE = 'mse' +LOSS_SOFTMAX_CROSS_ENTROPY = 'softmax_cross_entropy' +ALL_LOSSES = [LOSS_SOFTMAX_CROSS_ENTROPY, LOSS_MSE] + + +def normalize_regression_loss(regression_loss, predictions): + # Normalize loss such that: + # 1) E_{x uniform}[loss(x, prediction)] does not depend on prediction + # 2) E_{x uniform, prediction uniform}[loss(x, prediction)] is as before. + # Divides MSE regression loss by E[(prediction-x)^2]; assumes x=[-1,1] + normalization = 2./3. + normalized_loss = regression_loss / ((1./3 + predictions**2) / normalization) + return normalized_loss + + +def equal32(x, y): + return tf.cast(tf.equal(x, y), tf.float32) + + +def mse_loss(predicted, targets): + return (predicted - targets) ** 2 + + +def get_std_factor_from_confidence_percent(percent): + dec = percent/100. + inv_dec = 1 - dec + return scipy.stats.norm.ppf(dec+inv_dec/2) + + +def get_all_metric_names(task_type, model_uncertainty, loss_config, # pylint: disable=unused-argument + mode='eval', return_dict=True): + """Get all the scalar fields produced by compute_loss_and_metrics.""" + names = ['regularization_loss', 'prediction_accuracy', str(mode)+'_loss'] + if task_type == TASK_CLASSIFICATION: + names += ['classification_loss'] + else: + names += ['regression_loss', 'avg_mu', 'var_mu'] + if model_uncertainty: + names += ['uncertainty_loss', 'scaled_regression_loss', + 'uncertainty_plus_scaled_regression', + 'avg_sigma', 'var_sigma', + 'percent_in_conf_interval', 'error_sigma_correlation', + 'avg_prob'] + if return_dict: + return {name: 0. 
for name in names} + else: + return names + + +def compute_loss_and_metrics(mu, log_sigma_sq, + regression_targets, labels, + task_type, model_uncertainty, loss_config, + regularization_loss=0., confidence_interval=95, + mode='train'): + """Computes loss statistics and other metrics.""" + + scalars_to_log = dict() + vectors_to_log = dict() + scalars_to_log['regularization_loss'] = regularization_loss + vectors_to_log['mu'] = mu + + if task_type == TASK_CLASSIFICATION: + cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=mu, labels=labels, name='cross_entropy') + classification_loss = tf.reduce_mean(cross_entropy, name='class_loss') + total_loss = classification_loss + sigma = None + scalars_to_log['classification_loss'] = classification_loss + + predicted_labels = tf.argmax(mu, axis=1) + correct_predictions = equal32(predicted_labels, labels) + + else: + regression_loss = mse_loss(mu, regression_targets) + if 'mse_normalize' in loss_config and loss_config['mse_normalize']: + assert task_type in [TASK_GROUNDED_UNNORMALIZED_REGRESSION, + TASK_NORMALIZED_REGRESSION] + regression_loss = normalize_regression_loss(regression_loss, mu) + + avg_regression_loss = tf.reduce_mean(regression_loss) + vectors_to_log['regression_loss'] = regression_loss + scalars_to_log['regression_loss'] = avg_regression_loss + + scalars_to_log['avg_mu'] = tf.reduce_mean(mu) + scalars_to_log['var_mu'] = tf.reduce_mean(mse_loss(mu, tf.reduce_mean(mu))) + + predicted_labels = tf.cast(mu > 0, tf.int64) + correct_predictions = equal32(predicted_labels, labels) + + if model_uncertainty: + # This implements Eq. (1) in https://arxiv.org/pdf/1612.01474.pdf + inv_sigma_sq = tf.math.exp(-log_sigma_sq) + scaled_regression_loss = regression_loss * inv_sigma_sq + scaled_regression_loss = tf.reduce_mean(scaled_regression_loss) + uncertainty_loss = tf.reduce_mean(log_sigma_sq) + total_loss = uncertainty_loss + scaled_regression_loss + + scalars_to_log['uncertainty_loss'] = uncertainty_loss + scalars_to_log['scaled_regression_loss'] = scaled_regression_loss + scalars_to_log['uncertainty_plus_scaled_regression'] = total_loss + + sigma = tf.math.exp(log_sigma_sq / 2.) + vectors_to_log['sigma'] = sigma + scalars_to_log['avg_sigma'] = tf.reduce_mean(sigma) + var_sigma = tf.reduce_mean(mse_loss(sigma, tf.reduce_mean(sigma))) + scalars_to_log['var_sigma'] = var_sigma + + # Compute # of labels that fall into the confidence interval. 
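As a worked example of the interval check that follows: with the default `eval_confidence_interval` of 95, the standard-deviation factor is the two-sided normal quantile, roughly 1.96.

```python
# Worked example (not part of the library): the 95% two-sided interval factor.
import scipy.stats

std_factor = scipy.stats.norm.ppf(0.95 + 0.05 / 2)   # ppf(0.975) ~= 1.96
# A target then counts as "in interval" when
#   mu - 1.96 * sigma < target < mu + 1.96 * sigma.
```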
+      std_factor = get_std_factor_from_confidence_percent(confidence_interval)
+      lower_bound = mu - std_factor * sigma
+      upper_bound = mu + std_factor * sigma
+      preds = tf.logical_and(tf.greater(regression_targets, lower_bound),
+                             tf.less(regression_targets, upper_bound))
+      percent_in_conf_interval = tf.reduce_mean(tf.cast(preds, tf.float32))
+      scalars_to_log['percent_in_conf_interval'] = percent_in_conf_interval*100
+
+      error_sigma_corr = tfp.stats.correlation(x=regression_loss,
+                                               y=sigma, event_axis=None)
+      scalars_to_log['error_sigma_correlation'] = error_sigma_corr
+
+      dists = tfp.distributions.Normal(mu, sigma)
+      probs = dists.prob(regression_targets)
+      scalars_to_log['avg_prob'] = tf.reduce_mean(probs)
+
+    else:
+      total_loss = avg_regression_loss
+
+  loss_name = str(mode)+'_loss'
+  total_loss = tf.add(total_loss, regularization_loss, name=loss_name)
+  scalars_to_log[loss_name] = total_loss
+  vectors_to_log['correct_predictions'] = correct_predictions
+  scalars_to_log['prediction_accuracy'] = tf.reduce_mean(correct_predictions)
+
+  # Validate that the metrics produced are exactly what is expected.
+  expected = get_all_metric_names(task_type, model_uncertainty,
+                                  loss_config, mode, False)
+  assert set(expected) == set(scalars_to_log.keys())
+
+  return scalars_to_log, vectors_to_log
diff --git a/galaxy_mergers/main.py b/galaxy_mergers/main.py
new file mode 100644
index 0000000..9df5ab4
--- /dev/null
+++ b/galaxy_mergers/main.py
@@ -0,0 +1,51 @@
+# Copyright 2021 DeepMind Technologies Limited.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple script to run model evaluation on a checkpoint and dataset."""
+
+import ast
+from absl import app
+from absl import flags
+from absl import logging
+
+from galaxy_mergers import evaluator
+
+flags.DEFINE_string('checkpoint_path', '', 'Path to TF2 checkpoint to eval.')
+flags.DEFINE_string('data_path', '', 'Path to TFRecord(s) with data.')
+flags.DEFINE_string('filter_time_intervals', None,
+                    'Merger time intervals on which to perform regression. '
+                    'Specify None for the default time interval [-1,1], or'
+                    ' a custom list of intervals, e.g. [[-0.2,0], [0.5,1]].')
+
+FLAGS = flags.FLAGS
+
+
+def main(_) -> None:
+  if FLAGS.filter_time_intervals is not None:
+    filter_time_intervals = ast.literal_eval(FLAGS.filter_time_intervals)
+  else:
+    filter_time_intervals = None
+  config, ds, experiment = evaluator.get_config_dataset_evaluator(
+      filter_time_intervals,
+      FLAGS.checkpoint_path,
+      config_override={
+          'experiment_kwargs.data_config.dataset_path': FLAGS.data_path,
+      })
+  metrics, _, _ = evaluator.run_model_on_dataset(experiment, ds, config)
+  logging.info('Evaluation complete. Metrics: %s', metrics)
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/galaxy_mergers/model.py b/galaxy_mergers/model.py
new file mode 100644
index 0000000..c2ba011
--- /dev/null
+++ b/galaxy_mergers/model.py
@@ -0,0 +1,201 @@
+# Copyright 2021 DeepMind Technologies Limited.
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fork of a generic ResNet to incorporate additional cosmological features.""" + +from typing import Mapping, Optional, Sequence, Text + +import sonnet.v2 as snt +import tensorflow.compat.v2 as tf + + +class ResNet(snt.Module): + """ResNet model.""" + + def __init__(self, + n_repeats: int, + blocks_per_group_list: Sequence[int], + num_classes: int, + bn_config: Optional[Mapping[Text, float]] = None, + resnet_v2: bool = False, + channels_per_group_list: Sequence[int] = (256, 512, 1024, 2048), + use_additional_features: bool = False, + additional_features_mode: Optional[Text] = "per_block", + name: Optional[Text] = None): + """Constructs a ResNet model. + + Args: + n_repeats: The batch dimension for the input is expected to have the form + `B = b * n_repeats`. After the conv stack, the logits for the + `n_repeats` replicas are reduced, leading to an output batch dimension + of `b`. + blocks_per_group_list: A sequence of length 4 that indicates the number of + blocks created in each group. + num_classes: The number of classes to classify the inputs into. + bn_config: A dictionary of two elements, `decay_rate` and `eps` to be + passed on to the `BatchNorm` layers. By default the `decay_rate` is + `0.9` and `eps` is `1e-5`. + resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to + False. + channels_per_group_list: A sequence of length 4 that indicates the number + of channels used for each block in each group. + use_additional_features: If true, additional vector features will be + concatenated to the residual stack before logits are computed. + additional_features_mode: Mode for processing additional features. + Supported modes: 'mlp' and 'per_block'. + name: Name of the module. + """ + super(ResNet, self).__init__(name=name) + self._n_repeats = n_repeats + if bn_config is None: + bn_config = {"decay_rate": 0.9, "eps": 1e-5} + self._bn_config = bn_config + self._resnet_v2 = resnet_v2 + + # Number of blocks in each group for ResNet. + if len(blocks_per_group_list) != 4: + raise ValueError( + "`blocks_per_group_list` must be of length 4 not {}".format( + len(blocks_per_group_list))) + self._blocks_per_group_list = blocks_per_group_list + + # Number of channels in each group for ResNet. 
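Before the rest of the constructor, a brief usage sketch may be helpful. It mirrors how the evaluator instantiates the network with the default config (ResNet-50 layout, two regression outputs for `mu` and `log(sigma**2)` when model uncertainty is enabled); the zero-filled tensor is only a stand-in for a batch of 80x80x7 galaxy images.

```python
# Illustrative construction and call; the zero tensor stands in for real data.
import tensorflow.compat.v2 as tf
from galaxy_mergers import model

net = model.ResNet(
    n_repeats=1,
    blocks_per_group_list=[3, 4, 6, 3],        # ResNet-50 layout
    num_classes=2,                             # [mu, log_sigma_sq]
    bn_config=dict(decay_rate=0.9, eps=1e-5),
    resnet_v2=False,
    additional_features_mode='mlp',
    use_additional_features=False)

images = tf.zeros([4, 80, 80, 7])              # [batch * n_repeats, H, W, C]
outputs = net(images, None, is_training=False)
print(outputs.shape)                           # (4, 2)
```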
+    if len(channels_per_group_list) != 4:
+      raise ValueError(
+          "`channels_per_group_list` must be of length 4 not {}".format(
+              len(channels_per_group_list)))
+    self._channels_per_group_list = channels_per_group_list
+    self._use_additional_features = use_additional_features
+    self._additional_features_mode = additional_features_mode
+
+    self._initial_conv = snt.Conv2D(
+        output_channels=64,
+        kernel_shape=7,
+        stride=2,
+        with_bias=False,
+        padding="SAME",
+        name="initial_conv")
+    if not self._resnet_v2:
+      self._initial_batchnorm = snt.BatchNorm(
+          create_scale=True,
+          create_offset=True,
+          name="initial_batchnorm",
+          **bn_config)
+
+    self._block_groups = []
+    strides = [1, 2, 2, 2]
+    for i in range(4):
+      self._block_groups.append(
+          snt.nets.resnet.BlockGroup(
+              channels=self._channels_per_group_list[i],
+              num_blocks=self._blocks_per_group_list[i],
+              stride=strides[i],
+              bn_config=bn_config,
+              resnet_v2=resnet_v2,
+              name="block_group_%d" % (i)))
+
+    if self._resnet_v2:
+      self._final_batchnorm = snt.BatchNorm(
+          create_scale=True,
+          create_offset=True,
+          name="final_batchnorm",
+          **bn_config)
+
+    self._logits = snt.Linear(
+        output_size=num_classes,
+        w_init=snt.initializers.VarianceScaling(scale=2.0), name="logits")
+
+    if self._use_additional_features:
+      self._embedding = LinearBNReLU(output_size=16, name="embedding",
+                                     **bn_config)
+
+      if self._additional_features_mode == "mlp":
+        self._feature_repr = LinearBNReLU(
+            output_size=self._channels_per_group_list[-1], name="features_repr",
+            **bn_config)
+      elif self._additional_features_mode == "per_block":
+        self._feature_repr = []
+        for i, ch in enumerate(self._channels_per_group_list):
+          self._feature_repr.append(
+              LinearBNReLU(output_size=ch, name=f"features_{i}", **bn_config))
+      else:
+        raise ValueError(f"Unsupported additional features mode: "
+                         f"{additional_features_mode}")
+
+  def __call__(self, inputs, features, is_training):
+    net = inputs
+    net = self._initial_conv(net)
+    if not self._resnet_v2:
+      net = self._initial_batchnorm(net, is_training=is_training)
+      net = tf.nn.relu(net)
+
+    net = tf.nn.max_pool2d(
+        net, ksize=3, strides=2, padding="SAME", name="initial_max_pool")
+
+    if self._use_additional_features:
+      assert features is not None
+      features = self._embedding(features, is_training=is_training)
+
+    for i, block_group in enumerate(self._block_groups):
+      net = block_group(net, is_training)
+
+      if (self._use_additional_features and
+          self._additional_features_mode == "per_block"):
+        features_i = self._feature_repr[i](features, is_training=is_training)
+        # support for n_repeats > 1
+        features_i = tf.repeat(features_i, self._n_repeats, axis=0)
+        net += features_i[:, None, None, :]  # expand to spatial resolution
+
+    if self._resnet_v2:
+      net = self._final_batchnorm(net, is_training=is_training)
+      net = tf.nn.relu(net)
+    net = tf.reduce_mean(net, axis=[1, 2], name="final_avg_pool")
+    # Re-split the batch dimension
+    net = tf.reshape(net, [-1, self._n_repeats] + net.shape.as_list()[1:])
+    # Average over the various repeats of the input (e.g. those could have
+    # corresponded to different crops).
+ net = tf.reduce_mean(net, axis=1) + + if (self._use_additional_features and + self._additional_features_mode == "mlp"): + net += self._feature_repr(features, is_training=is_training) + + return self._logits(net) + + +class LinearBNReLU(snt.Module): + """Wrapper class for Linear layer with Batch Norm and ReLU activation.""" + + def __init__(self, output_size=64, + w_init=snt.initializers.VarianceScaling(scale=2.0), + name="linear", **bn_config): + """Constructs a LinearBNReLU module. + + Args: + output_size: Output dimension. + w_init: weight Initializer for snt.Linear. + name: Name of the module. + **bn_config: Optional parameters to be passed to snt.BatchNorm. + """ + super(LinearBNReLU, self).__init__(name=name) + self._linear = snt.Linear(output_size=output_size, w_init=w_init, + name=f"{name}_linear") + self._bn = snt.BatchNorm(create_scale=True, create_offset=True, + name=f"{name}_bn", **bn_config) + + def __call__(self, x, is_training): + x = self._linear(x) + x = self._bn(x, is_training=is_training) + return tf.nn.relu(x) diff --git a/galaxy_mergers/preprocessing.py b/galaxy_mergers/preprocessing.py new file mode 100644 index 0000000..0a7a5d6 --- /dev/null +++ b/galaxy_mergers/preprocessing.py @@ -0,0 +1,279 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
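For orientation, the sketch below shows the element structure of the zipped dataset that `prepare_dataset` (defined in this file) produces and that `evaluator.run_model_on_dataset` unpacks, assuming no additional physical features are configured; `ds` is the distributed dataset returned by `get_config_dataset_evaluator`.

```python
# Illustrative unpacking of one evaluation batch (no additional features).
for batch in ds:                # `ds` from evaluator.get_config_dataset_evaluator
  images = batch[0]             # [B, 80, 80, 7] float32, normalized fluxes
  class_label = batch[1]        # int64 label derived from `class_boundaries`
  regression_target = batch[2]  # float32 merger-time target for the chosen task
  normalized_time = batch[3]    # float32 normalized merger time
  sequence_id = batch[4]        # first of the additional per-example lambdas
  break
```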
+ +"""Pre-processing functions for input data.""" + +import functools +from absl import logging +import tensorflow.compat.v2 as tf +from galaxy_mergers import losses + + +CROP_TYPE_NONE = 'crop_none' +CROP_TYPE_FIXED = 'crop_fixed' +CROP_TYPE_RANDOM = 'crop_random' + +DATASET_FREQUENCY_MEAN = 4.0 +DATASET_FREQUENCY_RANGE = 8.0 + +PHYSICAL_FEATURES_MIN_MAX = { + 'redshift': (0.572788, 2.112304), + 'mass': (9.823963, 10.951282) +} + +ALL_FREQUENCIES = [105, 125, 160, 435, 606, 775, 850] + +VALID_ADDITIONAL_FEATURES = ['redshift', 'sequence_average_redshift', 'mass'] + + +def _make_padding_sizes(pad_size, random_centering): + if random_centering: + pad_size_left = tf.random.uniform( + shape=[], minval=0, maxval=pad_size+1, dtype=tf.int32) + else: + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + +def resize_and_pad(image, target_size, random_centering): + """Resize image to target_size (<= image.size) and pad to original size.""" + original_shape = image.shape + size = tf.reshape(target_size, [1]) + size = tf.concat([size, size], axis=0) + image = tf.image.resize(image, size=size) + pad_size = original_shape[1] - target_size + pad_size_left, pad_size_right = _make_padding_sizes( + pad_size, random_centering) + padding = [[pad_size_left, pad_size_right], + [pad_size_left, pad_size_right], [0, 0]] + if len(original_shape) == 4: + padding = [[0, 0]] + padding + image = tf.pad(image, padding) + image.set_shape(original_shape) + return image + + +def resize_and_extract(image, target_size, random_centering): + """Upscale image to target_size (>image.size), extract original size crop.""" + original_shape = image.shape + size = tf.reshape(target_size, [1]) + size = tf.concat([size, size], axis=0) + image = tf.image.resize(image, size=size) + pad_size = target_size - original_shape[1] + pad_size_left, pad_size_right = _make_padding_sizes( + pad_size, random_centering) + if len(original_shape) == 3: + image = tf.expand_dims(image, 0) + image = tf.cond(pad_size_right > 0, + lambda: image[:, pad_size_left:-pad_size_right, :, :], + lambda: image[:, pad_size_left:, :, :]) + image = tf.cond(pad_size_right > 0, + lambda: image[:, :, pad_size_left:-pad_size_right, :], + lambda: image[:, :, pad_size_left:, :]) + if len(original_shape) == 3: + image = tf.squeeze(image, 0) + image.set_shape(original_shape) + return image + + +def resize_and_center(image, target_size, random_centering): + return tf.cond( + tf.math.less_equal(target_size, image.shape[1]), + lambda: resize_and_pad(image, target_size, random_centering), + lambda: resize_and_extract(image, target_size, random_centering)) + + +def random_rotation_and_flip(image): + angle = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32) + return tf.image.random_flip_left_right(tf.image.rot90(image, angle)) + + +def get_all_rotations_and_flips(images): + assert isinstance(images, list) + new_images = [] + for image in images: + for rotation in range(4): + new_images.append(tf.image.rot90(image, rotation)) + flipped_image = tf.image.flip_left_right(image) + new_images.append(tf.image.rot90(flipped_image, rotation)) + return new_images + + +def random_rescaling(image, random_centering): + assert image.shape.as_list()[0] == image.shape.as_list()[1] + original_size = image.shape.as_list()[1] + min_size = 2 * (original_size // 4) + max_size = original_size * 2 + target_size = tf.random.uniform( + shape=[], minval=min_size, maxval=max_size // 2, + dtype=tf.int32) * 2 + return 
resize_and_center(image, target_size, random_centering) + + +def get_all_rescalings(images, image_width, random_centering): + """Get a uniform sample of rescalings of all images in input.""" + assert isinstance(images, list) + min_size = 2 * (image_width // 4) + max_size = image_width * 2 + delta_size = (max_size + 2 - min_size) // 5 + sizes = range(min_size, max_size + 2, delta_size) + new_images = [] + for image in images: + for size in sizes: + new_images.append(resize_and_center(image, size, random_centering)) + return new_images + + +def move_repeats_to_batch(image, n_repeats): + width, height, n_channels = image.shape.as_list()[1:] + image = tf.reshape(image, [-1, width, height, n_channels, n_repeats]) + image = tf.transpose(image, [0, 4, 1, 2, 3]) # [B, repeats, x, y, c] + return tf.reshape(image, [-1, width, height, n_channels]) + + +def get_classification_label(dataset_row, class_boundaries): + merge_time = dataset_row['grounded_normalized_time'] + label = tf.dtypes.cast(0, tf.int64) + for category, intervals in class_boundaries.items(): + for interval in intervals: + if merge_time > interval[0] and merge_time < interval[1]: + label = tf.dtypes.cast(int(category), tf.int64) + return label + + +def get_regression_label(dataset_row, task_type): + """Returns time-until-merger regression target given desired modeling task.""" + if task_type == losses.TASK_NORMALIZED_REGRESSION: + return tf.dtypes.cast(dataset_row['normalized_time'], tf.float32) + elif task_type == losses.TASK_GROUNDED_UNNORMALIZED_REGRESSION: + return tf.dtypes.cast(dataset_row['grounded_normalized_time'], tf.float32) + elif task_type == losses.TASK_UNNORMALIZED_REGRESSION: + return tf.dtypes.cast(dataset_row['unnormalized_time'], tf.float32) + elif task_type == losses.TASK_CLASSIFICATION: + return tf.dtypes.cast(dataset_row['grounded_normalized_time'], tf.float32) + else: + raise ValueError + + +def get_normalized_time_target(dataset_row): + return tf.dtypes.cast(dataset_row['normalized_time'], tf.float32) + + +def apply_time_filter(dataset_row, time_interval): + """Returns True if data is within the given time intervals.""" + merge_time = dataset_row['grounded_normalized_time'] + lower_time, upper_time = time_interval + return merge_time > lower_time and merge_time < upper_time + + +def normalize_physical_feature(name, dataset_row): + min_feat, max_feat = PHYSICAL_FEATURES_MIN_MAX[name] + value = getattr(dataset_row, name) + return 2 * (value - min_feat) / (max_feat - min_feat) - 1 + + +def prepare_dataset(ds, target_size, crop_type, n_repeats, augmentations, + task_type, additional_features, class_boundaries, + time_intervals=None, frequencies_to_use='all', + additional_lambdas=None): + """Prepare a zipped dataset of image, classification/regression labels.""" + def _prepare_image(dataset_row): + """Transpose, crop and cast an image.""" + image = tf.dtypes.cast(dataset_row['image'], tf.float32) + image = tf.reshape(image, tf.cast(dataset_row['image_shape'], tf.int32)) + image = tf.transpose(image, perm=[1, 2, 0]) # Convert to NHWC + + freqs = ALL_FREQUENCIES if frequencies_to_use == 'all' else frequencies_to_use + idxs_to_keep = [ALL_FREQUENCIES.index(f) for f in freqs] + image = tf.gather(params=image, indices=idxs_to_keep, axis=-1) + + # Based on offline computation on the empirical frequency range: + # Converts [0, 8.] 
~~> [-1, 1] + image = (image - DATASET_FREQUENCY_MEAN)/(DATASET_FREQUENCY_RANGE/2.0) + + def crop(image): + if crop_type == CROP_TYPE_FIXED: + crop_loc = tf.cast(dataset_row['proposed_crop'][0], tf.int32) + crop_size = tf.cast(dataset_row['proposed_crop'][1], tf.int32) + image = image[ + crop_loc[0]:crop_loc[0] + crop_size[0], + crop_loc[1]:crop_loc[1] + crop_size[1], :] + image = tf.image.resize(image, target_size[0:2]) + image.set_shape([target_size[0], target_size[1], target_size[2]]) + + elif crop_type == CROP_TYPE_RANDOM: + image = tf.image.random_crop(image, target_size) + image.set_shape([target_size[0], target_size[1], target_size[2]]) + + elif crop_type != CROP_TYPE_NONE: + raise NotImplementedError + + return image + + repeated_images = [] + for _ in range(n_repeats): + repeated_images.append(crop(image)) + image = tf.concat(repeated_images, axis=-1) + + if augmentations['rotation_and_flip']: + image = random_rotation_and_flip(image) + if augmentations['rescaling']: + image = random_rescaling(image, augmentations['translation']) + + return image + + def get_regression_label_wrapper(dataset_row): + return get_regression_label(dataset_row, task_type=task_type) + + def get_classification_label_wrapper(dataset_row): + return get_classification_label(dataset_row, + class_boundaries=class_boundaries) + + if time_intervals: + for time_interval in time_intervals: + filter_fn = functools.partial(apply_time_filter, + time_interval=time_interval) + ds = ds.filter(filter_fn) + + datasets = [ds.map(_prepare_image)] + + if additional_features: + additional_features = additional_features.split(',') + assert all([f in VALID_ADDITIONAL_FEATURES for f in additional_features]) + logging.info('Running with additional features: %s.', + ', '.join(additional_features)) + + def _prepare_additional_features(dataset_row): + features = [] + for f in additional_features: + features.append(normalize_physical_feature(f, dataset_row)) + features = tf.convert_to_tensor(features, dtype=tf.float32) + features.set_shape([len(additional_features)]) + return features + + datasets += [ds.map(_prepare_additional_features)] + + datasets += [ + ds.map(get_classification_label_wrapper), + ds.map(get_regression_label_wrapper), + ds.map(get_normalized_time_target)] + + if additional_lambdas: + for process_fn in additional_lambdas: + datasets += [ds.map(process_fn)] + + return tf.data.Dataset.zip(tuple(datasets)) + diff --git a/galaxy_mergers/requirements.txt b/galaxy_mergers/requirements.txt new file mode 100644 index 0000000..5b1d3f2 --- /dev/null +++ b/galaxy_mergers/requirements.txt @@ -0,0 +1,53 @@ +absl-py==0.11.0 +astropy==4.2 +astunparse==1.6.3 +cachetools==4.2.1 +certifi==2020.12.5 +chardet==4.0.0 +cloudpickle==1.6.0 +contextlib2==0.6.0.post1 +cycler==0.10.0 +decorator==4.4.2 +dm-sonnet==2.0.0 +dm-tree==0.1.5 +flatbuffers==1.12 +gast==0.3.3 +google-auth==1.27.0 +google-auth-oauthlib==0.4.2 +google-pasta==0.2.0 +grpcio==1.32.0 +h5py==2.10.0 +idna==2.10 +Keras-Preprocessing==1.1.2 +kiwisolver==1.3.1 +Markdown==3.3.4 +matplotlib==3.3.4 +ml-collections==0.1.0 +numpy==1.19.5 +oauthlib==3.1.0 +opt-einsum==3.3.0 +Pillow==8.1.0 +pkg-resources==0.0.0 +protobuf==3.15.3 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pyerfa==1.7.2 +pyparsing==2.4.7 +python-dateutil==2.8.1 +PyYAML==5.4.1 +requests==2.25.1 +requests-oauthlib==1.3.0 +rsa==4.7.2 +scipy==1.6.1 +six==1.15.0 +tabulate==0.8.9 +tensorboard==2.4.1 +tensorboard-plugin-wit==1.8.0 +tensorflow==2.4.1 +tensorflow-estimator==2.4.0 +tensorflow-probability==0.12.1 +termcolor==1.1.0 
+typing-extensions==3.7.4.3 +urllib3==1.26.3 +Werkzeug==1.0.1 +wrapt==1.12.1
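As a closing sanity check on the test-time ensembling used above: `helpers.sum_average_transformed_mu_and_sigma` should return the moments of a uniform mixture of Gaussians, i.e. the average of the member means and the average member variance plus the variance of the means. The snippet below is illustrative only, with arbitrary example values.

```python
# Numerical check of the mixture moments; values are arbitrary examples.
import numpy as np
import tensorflow.compat.v2 as tf
from galaxy_mergers import helpers

mus = tf.constant([[0.1], [0.3], [0.2]])                    # 3 ensemble members
sigma_sqs = np.array([0.04, 0.09, 0.01])
log_sigma_sqs = tf.math.log(tf.constant(sigma_sqs[:, None], dtype=tf.float32))

av_mu, log_av_sigma_sq = helpers.sum_average_transformed_mu_and_sigma(
    mus, log_sigma_sqs)
expected_var = sigma_sqs.mean() + np.var([0.1, 0.3, 0.2])   # ~0.0533
np.testing.assert_allclose(av_mu.numpy(), [0.2], atol=1e-6)
np.testing.assert_allclose(np.exp(log_av_sigma_sq.numpy()), [expected_var],
                           rtol=1e-5)
```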