mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-06 03:32:18 +08:00
Initial release of A Deep Learning Approach for Characterizing Major Galaxy Mergers
PiperOrigin-RevId: 369646863
This commit is contained in:
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
|
||||
|
||||
## Projects
|
||||
|
||||
* [A Deep Learning Approach for Characterizing Major Galaxy Mergers](galaxy_mergers)
|
||||
* [Better, Faster Fermionic Neural Networks](kfac_ferminet_alpha) (KFAC implementation)
|
||||
* [Object-based attention for spatio-temporal reasoning](object_attention_for_reasoning)
|
||||
* [Effective gene expression prediction from sequence by integrating long-range interactions](enformer)
|
||||
|
||||
41
galaxy_mergers/README.md
Normal file
41
galaxy_mergers/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# A Deep Learning Approach for Characterizing Major Galaxy Mergers
|
||||
|
||||
This repository contains evaluation code and checkpoints to reproduce
|
||||
figures in https://arxiv.org/abs/2102.05182.
|
||||
|
||||
The main evaluation module is `main.py`. It uses the provided checkpoint path
|
||||
and dataset path to run evaluation.
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
To set up a Python virtual environment with the required dependencies, run:
|
||||
|
||||
```shell
|
||||
python3 -m venv galaxy_mergers_env
|
||||
source galaxy_mergers_env/bin/activate
|
||||
pip install --upgrade pip setuptools wheel
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### License
|
||||
|
||||
While the code is licensed under the Apache 2.0 License, the checkpoints weights
|
||||
are made available for non-commercial use only under the terms of the
|
||||
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
|
||||
license. You can find details at:
|
||||
https://creativecommons.org/licenses/by-nc/4.0/legalcode.
|
||||
|
||||
|
||||
### Citing our work
|
||||
|
||||
If you use this work, consider citing our paper:
|
||||
|
||||
```bibtex
|
||||
@article{koppula2021deep,
|
||||
title={A Deep Learning Approach for Characterizing Major Galaxy Mergers},
|
||||
author={Koppula, Skanda and Bapst, Victor and Huertas-Company, Marc and Blackwell, Sam and Grabska-Barwinska, Agnieszka and Dieleman, Sander and Huber, Andrea and Antropova, Natasha and Binkowski, Mikolaj and Openshaw, Hannah and others},
|
||||
journal={Workshop for Machine Learning and the Physical Sciences @ NeurIPS 2020},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
87
galaxy_mergers/antennae_helpers.py
Normal file
87
galaxy_mergers/antennae_helpers.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Helpers to pre-process Antennae galaxy images."""
|
||||
|
||||
import collections
|
||||
import os
|
||||
|
||||
from astropy.io import fits
|
||||
import numpy as np
|
||||
from scipy import ndimage
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
|
||||
def norm_antennae_images(images, scale=1000):
|
||||
return tf.math.asinh(images/scale)
|
||||
|
||||
|
||||
def renorm_antennae(images):
|
||||
median = np.percentile(images.numpy().flatten(), 50)
|
||||
img_range = np.ptp(images.numpy().flatten())
|
||||
return (images - median) / (img_range / 2)
|
||||
|
||||
|
||||
def get_antennae_images(antennae_fits_dir):
|
||||
"""Load the raw Antennae galaxy images."""
|
||||
all_fits_files = [
|
||||
os.path.join(antennae_fits_dir, f)
|
||||
for f in os.listdir(antennae_fits_dir)
|
||||
]
|
||||
freq_mapping = {'red': 160, 'blue': 850}
|
||||
|
||||
paired_fits_files = collections.defaultdict(list)
|
||||
for f in all_fits_files:
|
||||
redshift = float(f[-8:-5])
|
||||
paired_fits_files[redshift].append(f)
|
||||
|
||||
for redshift, files in paired_fits_files.items():
|
||||
paired_fits_files[redshift] = sorted(
|
||||
files, key=lambda f: freq_mapping[f.split('/')[-1].split('_')[0]])
|
||||
|
||||
print('Reading files:', paired_fits_files)
|
||||
print('Redshifts:', sorted(paired_fits_files.keys()))
|
||||
|
||||
galaxy_views = collections.defaultdict(list)
|
||||
for redshift in paired_fits_files:
|
||||
for view_path in paired_fits_files[redshift]:
|
||||
with open(view_path, 'rb') as f:
|
||||
fits_data = fits.open(f)
|
||||
galaxy_views[redshift].append(np.array(fits_data[0].data))
|
||||
|
||||
batched_images = []
|
||||
for redshift in paired_fits_files:
|
||||
img = tf.constant(np.array(galaxy_views[redshift]))
|
||||
img = tf.transpose(img, (1, 2, 0))
|
||||
img = tf.image.resize(img, size=(60, 60))
|
||||
batched_images.append(img)
|
||||
|
||||
return tf.stack(batched_images)
|
||||
|
||||
|
||||
def preprocess_antennae_images(antennae_images):
|
||||
"""Pre-process the Antennae galaxy images into a reasonable range."""
|
||||
rotated_antennae_images = [
|
||||
ndimage.rotate(img, 10, reshape=True, cval=-1)[10:-10, 10:-10]
|
||||
for img in antennae_images
|
||||
]
|
||||
rotated_antennae_images = [
|
||||
np.clip(img, 0, 1e9) for img in rotated_antennae_images
|
||||
]
|
||||
rotated_antennae_images = tf.stack(rotated_antennae_images)
|
||||
normed_antennae_images = norm_antennae_images(rotated_antennae_images)
|
||||
normed_antennae_images = tf.clip_by_value(normed_antennae_images, 1, 4.5)
|
||||
renormed_antennae_images = renorm_antennae(normed_antennae_images)
|
||||
return renormed_antennae_images
|
||||
98
galaxy_mergers/config.py
Normal file
98
galaxy_mergers/config.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Default config, focused on model evaluation."""
|
||||
|
||||
from ml_collections import config_dict
|
||||
|
||||
|
||||
def get_config(filter_time_intervals=None):
|
||||
"""Return config object for training."""
|
||||
config = config_dict.ConfigDict()
|
||||
config.eval_strategy = config_dict.ConfigDict()
|
||||
config.eval_strategy.class_name = 'OneDeviceConfig'
|
||||
config.eval_strategy.kwargs = config_dict.ConfigDict(
|
||||
dict(device_type='v100'))
|
||||
|
||||
## Experiment config.
|
||||
config.experiment_kwargs = config_dict.ConfigDict(dict(
|
||||
resnet_kwargs=dict(
|
||||
blocks_per_group_list=[3, 4, 6, 3], # This choice is ResNet50.
|
||||
bn_config=dict(
|
||||
decay_rate=0.9,
|
||||
eps=1e-5),
|
||||
resnet_v2=False,
|
||||
additional_features_mode='mlp',
|
||||
),
|
||||
optimizer_config=dict(
|
||||
class_name='Momentum',
|
||||
kwargs={'momentum': 0.9},
|
||||
# Set up the learning rate schedule.
|
||||
lr_init=0.025,
|
||||
lr_factor=0.1,
|
||||
lr_schedule=(50e3, 100e3, 150e3),
|
||||
gradient_clip=5.,
|
||||
),
|
||||
l2_regularization=1e-4,
|
||||
total_train_batch_size=128,
|
||||
train_net_args={'is_training': True},
|
||||
eval_batch_size=128,
|
||||
eval_net_args={'is_training': True},
|
||||
data_config=dict(
|
||||
# dataset loading
|
||||
dataset_path=None,
|
||||
num_val_splits=10,
|
||||
val_split=0,
|
||||
|
||||
# image cropping
|
||||
image_size=(80, 80, 7),
|
||||
train_crop_type='crop_fixed',
|
||||
test_crop_type='crop_fixed',
|
||||
n_crop_repeat=1,
|
||||
|
||||
train_augmentations=dict(
|
||||
rotation_and_flip=True,
|
||||
rescaling=True,
|
||||
translation=True,
|
||||
),
|
||||
|
||||
test_augmentations=dict(
|
||||
rotation_and_flip=False,
|
||||
rescaling=False,
|
||||
translation=False,
|
||||
),
|
||||
test_time_ensembling='sum',
|
||||
|
||||
num_eval_buckets=5,
|
||||
eval_confidence_interval=95,
|
||||
|
||||
task='grounded_unnormalized_regression',
|
||||
loss_config=dict(
|
||||
loss='mse',
|
||||
mse_normalize=False,
|
||||
),
|
||||
model_uncertainty=True,
|
||||
additional_features='',
|
||||
time_filter_intervals=filter_time_intervals,
|
||||
class_boundaries={
|
||||
'0': [[-1., 0]],
|
||||
'1': [[0, 1.]]
|
||||
},
|
||||
frequencies_to_use='all',
|
||||
),
|
||||
n_train_epochs=100
|
||||
))
|
||||
|
||||
return config
|
||||
262
galaxy_mergers/evaluator.py
Normal file
262
galaxy_mergers/evaluator.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Evaluation runner."""
|
||||
|
||||
import collections
|
||||
from absl import logging
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from galaxy_mergers import config as tp_config
|
||||
from galaxy_mergers import helpers
|
||||
from galaxy_mergers import losses
|
||||
from galaxy_mergers import model
|
||||
from galaxy_mergers import preprocessing
|
||||
|
||||
|
||||
class GalaxyMergeClassifierEvaluator():
|
||||
"""Galaxy Merge Rate Prediction Evaluation Runner."""
|
||||
|
||||
def __init__(self, strategy, optimizer_config, total_train_batch_size,
|
||||
train_net_args, eval_batch_size, eval_net_args,
|
||||
l2_regularization, data_config, resnet_kwargs, n_train_epochs):
|
||||
"""Initializes evaluator/experiment."""
|
||||
logging.info('Initializing evaluator...')
|
||||
self._strategy = strategy
|
||||
self._data_config = data_config
|
||||
self._use_additional_features = bool(data_config['additional_features'])
|
||||
self._eval_batch_size = eval_batch_size
|
||||
self._eval_net_args = eval_net_args
|
||||
self._num_buckets = data_config['num_eval_buckets']
|
||||
self._n_repeats = data_config['n_crop_repeat']
|
||||
self._image_size = data_config['image_size']
|
||||
self._task_type = data_config['task']
|
||||
self._loss_config = data_config['loss_config']
|
||||
self._model_uncertainty = data_config['model_uncertainty']
|
||||
del l2_regularization, optimizer_config, train_net_args
|
||||
del total_train_batch_size, n_train_epochs
|
||||
|
||||
logging.info('Creating model...')
|
||||
num_classes = 2 if self._model_uncertainty else 1
|
||||
if self._task_type == losses.TASK_CLASSIFICATION:
|
||||
num_classes = len(self._data_config['class_boundaries'])
|
||||
self.model = model.ResNet(
|
||||
n_repeats=self._data_config['n_crop_repeat'], num_classes=num_classes,
|
||||
use_additional_features=self._use_additional_features, **resnet_kwargs)
|
||||
|
||||
self._eval_input = None
|
||||
|
||||
def build_eval_input(self, additional_lambdas=None):
|
||||
"""Create the galaxy merger evaluation dataset."""
|
||||
|
||||
def decode_fn(record_bytes):
|
||||
parsed_example = tf.io.parse_single_example(
|
||||
record_bytes,
|
||||
{
|
||||
'image':
|
||||
tf.io.VarLenFeature(tf.float32),
|
||||
'image_shape':
|
||||
tf.io.FixedLenFeature([3], dtype=tf.int64),
|
||||
'axis':
|
||||
tf.io.FixedLenFeature([], dtype=tf.int64),
|
||||
'proposed_crop':
|
||||
tf.io.FixedLenFeature([2, 2], dtype=tf.int64),
|
||||
'normalized_time':
|
||||
tf.io.FixedLenFeature([], dtype=tf.float32),
|
||||
'unnormalized_time':
|
||||
tf.io.FixedLenFeature([], dtype=tf.float32),
|
||||
'grounded_normalized_time':
|
||||
tf.io.FixedLenFeature([], dtype=tf.float32),
|
||||
'redshift':
|
||||
tf.io.FixedLenFeature([], dtype=tf.float32),
|
||||
'sequence_average_redshift':
|
||||
tf.io.FixedLenFeature([], dtype=tf.float32),
|
||||
'mass':
|
||||
tf.io.FixedLenFeature([], dtype=tf.float32),
|
||||
'time_index':
|
||||
tf.io.FixedLenFeature([], dtype=tf.int64),
|
||||
'sequence_id':
|
||||
tf.io.FixedLenFeature([], dtype=tf.string),
|
||||
})
|
||||
parsed_example['image'] = tf.sparse.to_dense(
|
||||
parsed_example['image'], default_value=0)
|
||||
dataset_row = parsed_example
|
||||
return dataset_row
|
||||
|
||||
def build_eval_pipeline(_):
|
||||
"""Generate the processed input evaluation data."""
|
||||
|
||||
logging.info('Building evaluation input pipeline...')
|
||||
ds_path = self._data_config['dataset_path']
|
||||
ds = tf.data.TFRecordDataset([ds_path]).map(decode_fn)
|
||||
|
||||
augmentations = dict(
|
||||
rotation_and_flip=False,
|
||||
rescaling=False,
|
||||
translation=False
|
||||
)
|
||||
ds = preprocessing.prepare_dataset(
|
||||
ds=ds, target_size=self._image_size,
|
||||
crop_type=self._data_config['test_crop_type'],
|
||||
n_repeats=self._n_repeats,
|
||||
augmentations=augmentations,
|
||||
task_type=self._task_type,
|
||||
additional_features=self._data_config['additional_features'],
|
||||
class_boundaries=self._data_config['class_boundaries'],
|
||||
time_intervals=self._data_config['time_filter_intervals'],
|
||||
frequencies_to_use=self._data_config['frequencies_to_use'],
|
||||
additional_lambdas=additional_lambdas)
|
||||
|
||||
batched_ds = ds.cache().batch(self._eval_batch_size).prefetch(128)
|
||||
logging.info('Finished building input pipeline...')
|
||||
return batched_ds
|
||||
|
||||
return self._strategy.experimental_distribute_datasets_from_function(
|
||||
build_eval_pipeline)
|
||||
|
||||
def run_test_model_ensemble(self, images, physical_features, augmentations):
|
||||
"""Run evaluation on input images."""
|
||||
image_variations = [images]
|
||||
image_shape = images.shape.as_list()
|
||||
|
||||
if augmentations['rotation_and_flip']:
|
||||
image_variations = preprocessing.get_all_rotations_and_flips(
|
||||
image_variations)
|
||||
|
||||
if augmentations['rescaling']:
|
||||
image_variations = preprocessing.get_all_rescalings(
|
||||
image_variations, image_shape[1], augmentations['translation'])
|
||||
|
||||
# Put all augmented images into the batch: batch * num_augmented
|
||||
augmented_images = tf.stack(image_variations, axis=0)
|
||||
augmented_images = tf.reshape(augmented_images, [-1] + image_shape[1:])
|
||||
if self._use_additional_features:
|
||||
physical_features = tf.concat(
|
||||
[physical_features] * len(image_variations), axis=0)
|
||||
|
||||
n_reps = self._data_config['n_crop_repeat']
|
||||
augmented_images = preprocessing.move_repeats_to_batch(augmented_images,
|
||||
n_reps)
|
||||
|
||||
logits_or_times = self.model(augmented_images, physical_features,
|
||||
**self._eval_net_args)
|
||||
if self._task_type == losses.TASK_CLASSIFICATION:
|
||||
mu, log_sigma_sq = helpers.aggregate_classification_ensemble(
|
||||
logits_or_times, len(image_variations),
|
||||
self._data_config['test_time_ensembling'])
|
||||
else:
|
||||
assert self._task_type in losses.REGRESSION_TASKS
|
||||
mu, log_sigma_sq = helpers.aggregate_regression_ensemble(
|
||||
logits_or_times, len(image_variations),
|
||||
self._model_uncertainty,
|
||||
self._data_config['test_time_ensembling'])
|
||||
|
||||
return mu, log_sigma_sq
|
||||
|
||||
@property
|
||||
def checkpoint_items(self):
|
||||
return {'model': self.model}
|
||||
|
||||
|
||||
def run_model_on_dataset(evaluator, dataset, config, n_batches=16):
|
||||
"""Runs the model against a dataset, aggregates model output."""
|
||||
|
||||
scalar_metrics_to_log = collections.defaultdict(list)
|
||||
model_outputs_to_log = collections.defaultdict(list)
|
||||
dataset_features_to_log = collections.defaultdict(list)
|
||||
|
||||
batch_count = 1
|
||||
for all_inputs in dataset:
|
||||
if config.experiment_kwargs.data_config['additional_features']:
|
||||
images = all_inputs[0]
|
||||
physical_features = all_inputs[1]
|
||||
labels, regression_targets, _ = all_inputs[2:5]
|
||||
other_dataset_features = all_inputs[5:]
|
||||
else:
|
||||
images, physical_features = all_inputs[0], None
|
||||
labels, regression_targets, _ = all_inputs[1:4]
|
||||
other_dataset_features = all_inputs[4:]
|
||||
|
||||
mu, log_sigma_sq = evaluator.run_test_model_ensemble(
|
||||
images, physical_features,
|
||||
config.experiment_kwargs.data_config['test_augmentations'])
|
||||
|
||||
loss_config = config.experiment_kwargs.data_config['loss_config']
|
||||
task_type = config.experiment_kwargs.data_config['task']
|
||||
uncertainty = config.experiment_kwargs.data_config['model_uncertainty']
|
||||
conf = config.experiment_kwargs.data_config['eval_confidence_interval']
|
||||
scalar_metrics, vector_metrics = losses.compute_loss_and_metrics(
|
||||
mu, log_sigma_sq, regression_targets, labels,
|
||||
task_type, uncertainty, loss_config, 0, conf, mode='eval')
|
||||
|
||||
for i, dataset_feature in enumerate(other_dataset_features):
|
||||
dataset_features_to_log[i].append(dataset_feature.numpy())
|
||||
|
||||
for scalar_metric in scalar_metrics:
|
||||
v = scalar_metrics[scalar_metric]
|
||||
val = v if isinstance(v, int) or isinstance(v, float) else v.numpy()
|
||||
scalar_metrics_to_log[scalar_metric].append(val)
|
||||
|
||||
for vector_metric in vector_metrics:
|
||||
val = vector_metrics[vector_metric].numpy()
|
||||
model_outputs_to_log[vector_metric].append(val)
|
||||
|
||||
regression_targets_np = regression_targets.numpy()
|
||||
labels_np = labels.numpy()
|
||||
model_outputs_to_log['regression_targets'].append(regression_targets_np)
|
||||
model_outputs_to_log['labels'].append(labels_np)
|
||||
model_outputs_to_log['model_input_images'].append(images.numpy())
|
||||
|
||||
if n_batches and batch_count >= n_batches:
|
||||
break
|
||||
batch_count += 1
|
||||
|
||||
return scalar_metrics_to_log, model_outputs_to_log, dataset_features_to_log
|
||||
|
||||
|
||||
def get_config_dataset_evaluator(filter_time_intervals,
|
||||
ckpt_path,
|
||||
config_override=None,
|
||||
setup_dataset=True):
|
||||
"""Set-up a default config, evaluation dataset, and evaluator."""
|
||||
config = tp_config.get_config(filter_time_intervals=filter_time_intervals)
|
||||
|
||||
if config_override:
|
||||
with config.ignore_type():
|
||||
config.update_from_flattened_dict(config_override)
|
||||
|
||||
strategy = tf.distribute.OneDeviceStrategy(device='/gpu:0')
|
||||
experiment = GalaxyMergeClassifierEvaluator(
|
||||
strategy=strategy, **config.experiment_kwargs)
|
||||
|
||||
helpers.restore_checkpoint(ckpt_path, experiment)
|
||||
|
||||
if setup_dataset:
|
||||
additional_lambdas = [
|
||||
lambda ds: ds['sequence_id'],
|
||||
lambda ds: ds['time_index'],
|
||||
lambda ds: ds['axis'],
|
||||
lambda ds: ds['normalized_time'],
|
||||
lambda ds: ds['grounded_normalized_time'],
|
||||
lambda ds: ds['unnormalized_time'],
|
||||
lambda ds: ds['redshift'],
|
||||
lambda ds: ds['mass']
|
||||
]
|
||||
|
||||
ds = experiment.build_eval_input(additional_lambdas=additional_lambdas)
|
||||
else:
|
||||
ds = None
|
||||
return config, ds, experiment
|
||||
228
galaxy_mergers/helpers.py
Normal file
228
galaxy_mergers/helpers.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Helpers for a galaxy merger model evaluation."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from astropy import cosmology
|
||||
from astropy.io import fits
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
|
||||
def restore_checkpoint(checkpoint_dir, experiment):
|
||||
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
|
||||
global_step = tf.Variable(
|
||||
0, dtype=tf.int32, trainable=False, name='global_step')
|
||||
checkpoint = tf.train.Checkpoint(
|
||||
_global_step_=global_step, **experiment.checkpoint_items)
|
||||
checkpoint.restore(checkpoint_path)
|
||||
|
||||
|
||||
def sum_average_transformed_mu_and_sigma(mu, log_sigma_sq):
|
||||
"""Computes <mu>, var(mu) + <var> in transformed representation.
|
||||
|
||||
This corresponds to assuming that the output distribution is a sum of
|
||||
Gaussian and computing the mean and variance of the resulting (non-Gaussian)
|
||||
distribution.
|
||||
|
||||
Args:
|
||||
mu: Tensor of shape [B, ...] representing the means of the input
|
||||
distributions.
|
||||
log_sigma_sq: Tensor of shape [B, ...] representing log(sigma**2) of the
|
||||
input distributions. Can be None, in which case the variance is assumed
|
||||
to be zero.
|
||||
|
||||
Returns:
|
||||
mu: Tensor of shape [...] representing the means of the output
|
||||
distributions.
|
||||
log_sigma_sq: Tensor of shape [...] representing log(sigma**2) of the
|
||||
output distributions.
|
||||
"""
|
||||
av_mu = tf.reduce_mean(mu, axis=0)
|
||||
var_mu = tf.math.reduce_std(mu, axis=0)**2
|
||||
if log_sigma_sq is None:
|
||||
return av_mu, tf.math.log(var_mu)
|
||||
max_log_sigma_sq = tf.reduce_max(log_sigma_sq, axis=0)
|
||||
log_sigma_sq -= max_log_sigma_sq
|
||||
# (sigma/sigma_0)**2
|
||||
sigma_sq = tf.math.exp(log_sigma_sq)
|
||||
# (<sigma**2>)/sigma_0**2 (<1)
|
||||
av_sigma_sq = tf.reduce_mean(sigma_sq, axis=0)
|
||||
# (<sigma**2> + var(mu))/sigma_0**2
|
||||
av_sigma_sq += var_mu * tf.math.exp(-max_log_sigma_sq)
|
||||
# log(<sigma**2> + var(mu))
|
||||
log_av_sigma_sq = tf.math.log(av_sigma_sq) + max_log_sigma_sq
|
||||
return av_mu, log_av_sigma_sq
|
||||
|
||||
|
||||
def aggregate_regression_ensemble(logits_or_times, ensemble_size,
|
||||
use_uncertainty, test_time_ensembling):
|
||||
"""Aggregate output of model ensemble."""
|
||||
out_shape = logits_or_times.shape.as_list()[1:]
|
||||
logits_or_times = tf.reshape(logits_or_times, [ensemble_size, -1] + out_shape)
|
||||
mus = logits_or_times[..., 0]
|
||||
log_sigma_sqs = logits_or_times[..., -1] if use_uncertainty else None
|
||||
|
||||
if test_time_ensembling == 'sum':
|
||||
mu, log_sigma_sq = sum_average_transformed_mu_and_sigma(mus, log_sigma_sqs)
|
||||
elif test_time_ensembling == 'none':
|
||||
mu = mus[0]
|
||||
log_sigma_sq = log_sigma_sqs[0] if use_uncertainty else None
|
||||
else:
|
||||
raise ValueError('Unexpected test_time_ensembling')
|
||||
return mu, log_sigma_sq
|
||||
|
||||
|
||||
def aggregate_classification_ensemble(logits_or_times, ensemble_size,
|
||||
test_time_ensembling):
|
||||
"""Averages the output logits across models in the ensemble."""
|
||||
out_shape = logits_or_times.shape.as_list()[1:]
|
||||
logits = tf.reshape(logits_or_times, [ensemble_size, -1] + out_shape)
|
||||
|
||||
if test_time_ensembling == 'sum':
|
||||
logits = tf.reduce_mean(logits, axis=0)
|
||||
return logits, None
|
||||
elif test_time_ensembling == 'none':
|
||||
return logits, None
|
||||
else:
|
||||
raise ValueError('Unexpected test_time_ensembling')
|
||||
|
||||
|
||||
def unpack_evaluator_output(data, return_seq_info=False, return_redshift=False):
|
||||
"""Unpack evaluator.run_model_on_dataset output."""
|
||||
mus = np.array(data[1]['mu']).flatten()
|
||||
sigmas = np.array(data[1]['sigma']).flatten()
|
||||
regression_targets = np.array(data[1]['regression_targets']).flatten()
|
||||
outputs = [mus, sigmas, regression_targets]
|
||||
|
||||
if return_seq_info:
|
||||
seq_ids = np.array(data[2][0]).flatten()
|
||||
seq_ids = np.array([seq_id.decode('UTF-8') for seq_id in seq_ids])
|
||||
time_idxs = np.array(data[2][1]).flatten()
|
||||
axes = np.array(data[2][2]).flatten()
|
||||
outputs += [seq_ids, axes, time_idxs]
|
||||
|
||||
if return_redshift:
|
||||
redshifts = np.array(data[2][6]).flatten()
|
||||
outputs += [redshifts]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def process_data_into_myrs(redshifts, *data_lists):
|
||||
"""Converts normalized time to virial time using Planck cosmology."""
|
||||
# small hack to avoid build tools not recognizing non-standard trickery
|
||||
# done in the astropy library:
|
||||
# https://github.com/astropy/astropy/blob/master/astropy/cosmology/core.py#L3290
|
||||
# that dynamically generates and imports new classes.
|
||||
planck13 = getattr(cosmology, 'Plank13')
|
||||
hubble_constants = planck13.H(redshifts) # (km/s)/megaparsec
|
||||
inv_hubble_constants = 1/hubble_constants # (megaparsec*s) / km
|
||||
megaparsec_to_km = 1e19*3.1
|
||||
seconds_to_gigayears = 1e-15/31.556
|
||||
conversion_factor = megaparsec_to_km * seconds_to_gigayears
|
||||
hubble_time_gigayears = conversion_factor * inv_hubble_constants
|
||||
|
||||
hubble_to_virial_time = 0.14 # approximate simulation-based conversion factor
|
||||
virial_dyn_time = hubble_to_virial_time*hubble_time_gigayears.value
|
||||
return [data_list*virial_dyn_time for data_list in data_lists]
|
||||
|
||||
|
||||
def print_rmse_and_class_accuracy(mus, regression_targets, redshifts):
|
||||
"""Convert to virial dynamical time and print stats."""
|
||||
time_pred, time_gt = process_data_into_myrs(
|
||||
redshifts, mus, regression_targets)
|
||||
time_sq_errors = (time_pred-time_gt)**2
|
||||
rmse = np.sqrt(np.mean(time_sq_errors))
|
||||
labels = regression_targets > 0
|
||||
class_preds = mus > 0
|
||||
accuracy = sum((labels == class_preds).astype(np.int8)) / len(class_preds)
|
||||
|
||||
print(f'95% Error: {np.percentile(np.sqrt(time_sq_errors), 95)}')
|
||||
print(f'RMSE: {rmse}')
|
||||
print(f'Classification Accuracy: {accuracy}')
|
||||
|
||||
|
||||
def print_stats(vec, do_print=True):
|
||||
fvec = vec.flatten()
|
||||
if do_print:
|
||||
print(len(fvec), min(fvec), np.mean(fvec), np.median(fvec), max(fvec))
|
||||
return (len(fvec), min(fvec), np.mean(fvec), np.median(fvec), max(fvec))
|
||||
|
||||
|
||||
def get_image_from_fits(base_dir, seq='475_31271', time='497', axis=2):
|
||||
"""Read *.fits galaxy image from directory."""
|
||||
axis_map = {0: 'x', 1: 'y', 2: 'z'}
|
||||
fits_glob = f'{base_dir}/{seq}/fits_of_flux_psf/{time}/*_{axis_map[axis]}_*.fits'
|
||||
|
||||
def get_freq_from_path(p):
|
||||
return int(p.split('/')[-1].split('_')[2][1:])
|
||||
|
||||
fits_image_paths = sorted(glob.glob(fits_glob), key=get_freq_from_path)
|
||||
assert len(fits_image_paths) == 7
|
||||
combined_frequencies = []
|
||||
for fit_path in fits_image_paths:
|
||||
with open(fit_path, 'rb') as f:
|
||||
fits_data = np.array(fits.open(f)[0].data.astype(np.float32))
|
||||
combined_frequencies.append(fits_data)
|
||||
fits_image = np.transpose(np.array(combined_frequencies), (1, 2, 0))
|
||||
return fits_image
|
||||
|
||||
|
||||
def stack_desired_galaxy_images(base_dir, seq, n_time_slices):
|
||||
"""Searth through galaxy image directory gathering images."""
|
||||
fits_sequence_dir = os.path.join(base_dir, seq, 'fits_of_flux_psf')
|
||||
all_times_for_seq = os.listdir(fits_sequence_dir)
|
||||
hop = (len(all_times_for_seq)-1)//(n_time_slices-1)
|
||||
desired_time_idxs = [k*hop for k in range(n_time_slices)]
|
||||
|
||||
all_imgs = []
|
||||
for j in desired_time_idxs:
|
||||
time = all_times_for_seq[j]
|
||||
img = get_image_from_fits(base_dir=base_dir, seq=seq, time=time, axis=2)
|
||||
all_imgs.append(img)
|
||||
|
||||
min_img_size = min([img.shape[0] for img in all_imgs])
|
||||
return all_imgs, min_img_size
|
||||
|
||||
|
||||
def draw_galaxy_image(image, target_size=None, color_map='viridis'):
|
||||
normalized_image = image / max(image.flatten())
|
||||
color_map = plt.get_cmap(color_map)
|
||||
colored_image = color_map(normalized_image)[:, :, :3]
|
||||
colored_image = (colored_image * 255).astype(np.uint8)
|
||||
colored_image = Image.fromarray(colored_image, mode='RGB')
|
||||
if target_size:
|
||||
colored_image = colored_image.resize(target_size, Image.ANTIALIAS)
|
||||
return colored_image
|
||||
|
||||
|
||||
def collect_merger_sequence(ds, seq=b'370_11071', n_examples_to_sift=5000):
|
||||
images, targets, redshifts = [], [], []
|
||||
for i, all_inputs in enumerate(ds):
|
||||
if all_inputs[4][0].numpy() == seq:
|
||||
images.append(all_inputs[0][0].numpy())
|
||||
targets.append(all_inputs[2][0].numpy())
|
||||
redshifts.append(all_inputs[10][0].numpy())
|
||||
if i > n_examples_to_sift: break
|
||||
return np.squeeze(images), np.squeeze(targets), np.squeeze(redshifts)
|
||||
|
||||
|
||||
def take_samples(sample_idxs, *data_lists):
|
||||
return [np.take(l, sample_idxs, axis=0) for l in data_lists]
|
||||
68
galaxy_mergers/interpretability_helpers.py
Normal file
68
galaxy_mergers/interpretability_helpers.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Helpers to visualize gradients and other interpretability analysis."""
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
|
||||
def rotate_by_right_angle_multiple(image, rot=90):
|
||||
"""Rotate an image by right angles."""
|
||||
if rot not in [0, 90, 180, 270]:
|
||||
raise ValueError(f"Cannot rotate by non-90 degree angle {rot}")
|
||||
|
||||
if rot in [90, -270]:
|
||||
image = np.transpose(image, (1, 0, 2))
|
||||
image = image[::-1]
|
||||
elif rot in [180, -180]:
|
||||
image = image[::-1, ::-1]
|
||||
elif rot in [270, -90]:
|
||||
image = np.transpose(image, (1, 0, 2))
|
||||
image = image[:, ::-1]
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def compute_gradient(images, evaluator, is_training=False):
|
||||
inputs = tf.Variable(images[None], dtype=tf.float32)
|
||||
with tf.GradientTape() as tape:
|
||||
tape.watch(inputs)
|
||||
time_sigma = evaluator.model(inputs, None, is_training)
|
||||
grad_time = tape.gradient(time_sigma[:, 0], inputs)
|
||||
return grad_time, time_sigma
|
||||
|
||||
|
||||
def compute_grads_for_rotations(images, evaluator, is_training=False):
|
||||
test_gradients, test_outputs = [], []
|
||||
for rotation in np.arange(0, 360, 90):
|
||||
images_rot = rotate_by_right_angle_multiple(images, rotation)
|
||||
grads, time_sigma = compute_gradient(images_rot, evaluator, is_training)
|
||||
grads = np.squeeze(grads.numpy())
|
||||
inv_grads = rotate_by_right_angle_multiple(grads, -rotation)
|
||||
test_gradients.append(inv_grads)
|
||||
test_outputs.append(time_sigma.numpy())
|
||||
return np.squeeze(test_gradients), np.squeeze(test_outputs)
|
||||
|
||||
|
||||
def compute_grads_for_rotations_and_flips(images, evaluator):
|
||||
grads, time_sigma = compute_grads_for_rotations(images, evaluator)
|
||||
grads_f, time_sigma_f = compute_grads_for_rotations(images[::-1], evaluator)
|
||||
grads_f = grads_f[:, ::-1]
|
||||
all_grads = np.concatenate([grads, grads_f], 0)
|
||||
model_outputs = np.concatenate((time_sigma, time_sigma_f), 0)
|
||||
return all_grads, model_outputs
|
||||
|
||||
|
||||
169
galaxy_mergers/losses.py
Normal file
169
galaxy_mergers/losses.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Helpers to compute loss metrics."""
|
||||
|
||||
import scipy.stats
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
|
||||
TASK_CLASSIFICATION = 'classification'
|
||||
TASK_NORMALIZED_REGRESSION = 'normalized_regression'
|
||||
TASK_UNNORMALIZED_REGRESSION = 'unnormalized_regression'
|
||||
TASK_GROUNDED_UNNORMALIZED_REGRESSION = 'grounded_unnormalized_regression'
|
||||
REGRESSION_TASKS = [TASK_NORMALIZED_REGRESSION, TASK_UNNORMALIZED_REGRESSION,
|
||||
TASK_GROUNDED_UNNORMALIZED_REGRESSION]
|
||||
ALL_TASKS = [TASK_CLASSIFICATION] + REGRESSION_TASKS
|
||||
|
||||
LOSS_MSE = 'mse'
|
||||
LOSS_SOFTMAX_CROSS_ENTROPY = 'softmax_cross_entropy'
|
||||
ALL_LOSSES = [LOSS_SOFTMAX_CROSS_ENTROPY, LOSS_MSE]
|
||||
|
||||
|
||||
def normalize_regression_loss(regression_loss, predictions):
|
||||
# Normalize loss such that:
|
||||
# 1) E_{x uniform}[loss(x, prediction)] does not depend on prediction
|
||||
# 2) E_{x uniform, prediction uniform}[loss(x, prediction)] is as before.
|
||||
# Divides MSE regression loss by E[(prediction-x)^2]; assumes x=[-1,1]
|
||||
normalization = 2./3.
|
||||
normalized_loss = regression_loss / ((1./3 + predictions**2) / normalization)
|
||||
return normalized_loss
|
||||
|
||||
|
||||
def equal32(x, y):
|
||||
return tf.cast(tf.equal(x, y), tf.float32)
|
||||
|
||||
|
||||
def mse_loss(predicted, targets):
|
||||
return (predicted - targets) ** 2
|
||||
|
||||
|
||||
def get_std_factor_from_confidence_percent(percent):
|
||||
dec = percent/100.
|
||||
inv_dec = 1 - dec
|
||||
return scipy.stats.norm.ppf(dec+inv_dec/2)
|
||||
|
||||
|
||||
def get_all_metric_names(task_type, model_uncertainty, loss_config, # pylint: disable=unused-argument
|
||||
mode='eval', return_dict=True):
|
||||
"""Get all the scalar fields produced by compute_loss_and_metrics."""
|
||||
names = ['regularization_loss', 'prediction_accuracy', str(mode)+'_loss']
|
||||
if task_type == TASK_CLASSIFICATION:
|
||||
names += ['classification_loss']
|
||||
else:
|
||||
names += ['regression_loss', 'avg_mu', 'var_mu']
|
||||
if model_uncertainty:
|
||||
names += ['uncertainty_loss', 'scaled_regression_loss',
|
||||
'uncertainty_plus_scaled_regression',
|
||||
'avg_sigma', 'var_sigma',
|
||||
'percent_in_conf_interval', 'error_sigma_correlation',
|
||||
'avg_prob']
|
||||
if return_dict:
|
||||
return {name: 0. for name in names}
|
||||
else:
|
||||
return names
|
||||
|
||||
|
||||
def compute_loss_and_metrics(mu, log_sigma_sq,
|
||||
regression_targets, labels,
|
||||
task_type, model_uncertainty, loss_config,
|
||||
regularization_loss=0., confidence_interval=95,
|
||||
mode='train'):
|
||||
"""Computes loss statistics and other metrics."""
|
||||
|
||||
scalars_to_log = dict()
|
||||
vectors_to_log = dict()
|
||||
scalars_to_log['regularization_loss'] = regularization_loss
|
||||
vectors_to_log['mu'] = mu
|
||||
|
||||
if task_type == TASK_CLASSIFICATION:
|
||||
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits=mu, labels=labels, name='cross_entropy')
|
||||
classification_loss = tf.reduce_mean(cross_entropy, name='class_loss')
|
||||
total_loss = classification_loss
|
||||
sigma = None
|
||||
scalars_to_log['classification_loss'] = classification_loss
|
||||
|
||||
predicted_labels = tf.argmax(mu, axis=1)
|
||||
correct_predictions = equal32(predicted_labels, labels)
|
||||
|
||||
else:
|
||||
regression_loss = mse_loss(mu, regression_targets)
|
||||
if 'mse_normalize' in loss_config and loss_config['mse_normalize']:
|
||||
assert task_type in [TASK_GROUNDED_UNNORMALIZED_REGRESSION,
|
||||
TASK_NORMALIZED_REGRESSION]
|
||||
regression_loss = normalize_regression_loss(regression_loss, mu)
|
||||
|
||||
avg_regression_loss = tf.reduce_mean(regression_loss)
|
||||
vectors_to_log['regression_loss'] = regression_loss
|
||||
scalars_to_log['regression_loss'] = avg_regression_loss
|
||||
|
||||
scalars_to_log['avg_mu'] = tf.reduce_mean(mu)
|
||||
scalars_to_log['var_mu'] = tf.reduce_mean(mse_loss(mu, tf.reduce_mean(mu)))
|
||||
|
||||
predicted_labels = tf.cast(mu > 0, tf.int64)
|
||||
correct_predictions = equal32(predicted_labels, labels)
|
||||
|
||||
if model_uncertainty:
|
||||
# This implements Eq. (1) in https://arxiv.org/pdf/1612.01474.pdf
|
||||
inv_sigma_sq = tf.math.exp(-log_sigma_sq)
|
||||
scaled_regression_loss = regression_loss * inv_sigma_sq
|
||||
scaled_regression_loss = tf.reduce_mean(scaled_regression_loss)
|
||||
uncertainty_loss = tf.reduce_mean(log_sigma_sq)
|
||||
total_loss = uncertainty_loss + scaled_regression_loss
|
||||
|
||||
scalars_to_log['uncertainty_loss'] = uncertainty_loss
|
||||
scalars_to_log['scaled_regression_loss'] = scaled_regression_loss
|
||||
scalars_to_log['uncertainty_plus_scaled_regression'] = total_loss
|
||||
|
||||
sigma = tf.math.exp(log_sigma_sq / 2.)
|
||||
vectors_to_log['sigma'] = sigma
|
||||
scalars_to_log['avg_sigma'] = tf.reduce_mean(sigma)
|
||||
var_sigma = tf.reduce_mean(mse_loss(sigma, tf.reduce_mean(sigma)))
|
||||
scalars_to_log['var_sigma'] = var_sigma
|
||||
|
||||
# Compute # of labels that fall into the confidence interval.
|
||||
std_factor = get_std_factor_from_confidence_percent(confidence_interval)
|
||||
lower_bound = mu - std_factor * sigma
|
||||
upper_bound = mu + std_factor * sigma
|
||||
preds = tf.logical_and(tf.greater(regression_targets, lower_bound),
|
||||
tf.less(regression_targets, upper_bound))
|
||||
percent_in_conf_interval = tf.reduce_mean(tf.cast(preds, tf.float32))
|
||||
scalars_to_log['percent_in_conf_interval'] = percent_in_conf_interval*100
|
||||
|
||||
error_sigma_corr = tfp.stats.correlation(x=regression_loss,
|
||||
y=sigma, event_axis=None)
|
||||
scalars_to_log['error_sigma_correlation'] = error_sigma_corr
|
||||
|
||||
dists = tfp.distributions.Normal(mu, sigma)
|
||||
probs = dists.prob(regression_targets)
|
||||
scalars_to_log['avg_prob'] = tf.reduce_mean(probs)
|
||||
|
||||
else:
|
||||
total_loss = avg_regression_loss
|
||||
|
||||
loss_name = str(mode)+'_loss'
|
||||
total_loss = tf.add(total_loss, regularization_loss, name=loss_name)
|
||||
scalars_to_log[loss_name] = total_loss
|
||||
vectors_to_log['correct_predictions'] = correct_predictions
|
||||
scalars_to_log['prediction_accuracy'] = tf.reduce_mean(correct_predictions)
|
||||
|
||||
# Validate that metrics outputted are exactly what is expected
|
||||
expected = get_all_metric_names(task_type, model_uncertainty,
|
||||
loss_config, mode, False)
|
||||
assert set(expected) == set(scalars_to_log.keys())
|
||||
|
||||
return scalars_to_log, vectors_to_log
|
||||
51
galaxy_mergers/main.py
Normal file
51
galaxy_mergers/main.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Simple script to model evaluation on a checkpoint and dataset."""
|
||||
|
||||
import ast
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
from galaxy_mergers import evaluator
|
||||
|
||||
flags.DEFINE_string('checkpoint_path', '', 'Path to TF2 checkpoint to eval.')
|
||||
flags.DEFINE_string('data_path', '', 'Path to TFRecord(s) with data.')
|
||||
flags.DEFINE_string('filter_time_intervals', None,
|
||||
'Merger time intervals on which to perform regression.'
|
||||
'Specify None for the default time interval [-1,1], or'
|
||||
' a custom list of intervals, e.g. [[-0.2,0], [0.5,1]].')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(_) -> None:
|
||||
if FLAGS.filter_time_intervals is not None:
|
||||
filter_time_intervals = ast.literal_eval(FLAGS.filter_time_intervals)
|
||||
else:
|
||||
filter_time_intervals = None
|
||||
config, ds, experiment = evaluator.get_config_dataset_evaluator(
|
||||
filter_time_intervals,
|
||||
FLAGS.checkpoint_path,
|
||||
config_override={
|
||||
'experiment_kwargs.data_config.dataset_path': FLAGS.data_path,
|
||||
})
|
||||
metrics, _, _ = evaluator.run_model_on_dataset(experiment, ds, config)
|
||||
logging.info('Evaluation complete. Metrics: %s', metrics)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
201
galaxy_mergers/model.py
Normal file
201
galaxy_mergers/model.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Fork of a generic ResNet to incorporate additional cosmological features."""
|
||||
|
||||
from typing import Mapping, Optional, Sequence, Text
|
||||
|
||||
import sonnet.v2 as snt
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
|
||||
class ResNet(snt.Module):
|
||||
"""ResNet model."""
|
||||
|
||||
def __init__(self,
|
||||
n_repeats: int,
|
||||
blocks_per_group_list: Sequence[int],
|
||||
num_classes: int,
|
||||
bn_config: Optional[Mapping[Text, float]] = None,
|
||||
resnet_v2: bool = False,
|
||||
channels_per_group_list: Sequence[int] = (256, 512, 1024, 2048),
|
||||
use_additional_features: bool = False,
|
||||
additional_features_mode: Optional[Text] = "per_block",
|
||||
name: Optional[Text] = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
n_repeats: The batch dimension for the input is expected to have the form
|
||||
`B = b * n_repeats`. After the conv stack, the logits for the
|
||||
`n_repeats` replicas are reduced, leading to an output batch dimension
|
||||
of `b`.
|
||||
blocks_per_group_list: A sequence of length 4 that indicates the number of
|
||||
blocks created in each group.
|
||||
num_classes: The number of classes to classify the inputs into.
|
||||
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
|
||||
passed on to the `BatchNorm` layers. By default the `decay_rate` is
|
||||
`0.9` and `eps` is `1e-5`.
|
||||
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
|
||||
False.
|
||||
channels_per_group_list: A sequence of length 4 that indicates the number
|
||||
of channels used for each block in each group.
|
||||
use_additional_features: If true, additional vector features will be
|
||||
concatenated to the residual stack before logits are computed.
|
||||
additional_features_mode: Mode for processing additional features.
|
||||
Supported modes: 'mlp' and 'per_block'.
|
||||
name: Name of the module.
|
||||
"""
|
||||
super(ResNet, self).__init__(name=name)
|
||||
self._n_repeats = n_repeats
|
||||
if bn_config is None:
|
||||
bn_config = {"decay_rate": 0.9, "eps": 1e-5}
|
||||
self._bn_config = bn_config
|
||||
self._resnet_v2 = resnet_v2
|
||||
|
||||
# Number of blocks in each group for ResNet.
|
||||
if len(blocks_per_group_list) != 4:
|
||||
raise ValueError(
|
||||
"`blocks_per_group_list` must be of length 4 not {}".format(
|
||||
len(blocks_per_group_list)))
|
||||
self._blocks_per_group_list = blocks_per_group_list
|
||||
|
||||
# Number of channels in each group for ResNet.
|
||||
if len(channels_per_group_list) != 4:
|
||||
raise ValueError(
|
||||
"`channels_per_group_list` must be of length 4 not {}".format(
|
||||
len(channels_per_group_list)))
|
||||
self._channels_per_group_list = channels_per_group_list
|
||||
self._use_additional_features = use_additional_features
|
||||
self._additional_features_mode = additional_features_mode
|
||||
|
||||
self._initial_conv = snt.Conv2D(
|
||||
output_channels=64,
|
||||
kernel_shape=7,
|
||||
stride=2,
|
||||
with_bias=False,
|
||||
padding="SAME",
|
||||
name="initial_conv")
|
||||
if not self._resnet_v2:
|
||||
self._initial_batchnorm = snt.BatchNorm(
|
||||
create_scale=True,
|
||||
create_offset=True,
|
||||
name="initial_batchnorm",
|
||||
**bn_config)
|
||||
|
||||
self._block_groups = []
|
||||
strides = [1, 2, 2, 2]
|
||||
for i in range(4):
|
||||
self._block_groups.append(
|
||||
snt.nets.resnet.BlockGroup(
|
||||
channels=self._channels_per_group_list[i],
|
||||
num_blocks=self._blocks_per_group_list[i],
|
||||
stride=strides[i],
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
name="block_group_%d" % (i)))
|
||||
|
||||
if self._resnet_v2:
|
||||
self._final_batchnorm = snt.BatchNorm(
|
||||
create_scale=True,
|
||||
create_offset=True,
|
||||
name="final_batchnorm",
|
||||
**bn_config)
|
||||
|
||||
self._logits = snt.Linear(
|
||||
output_size=num_classes,
|
||||
w_init=snt.initializers.VarianceScaling(scale=2.0), name="logits")
|
||||
|
||||
if self._use_additional_features:
|
||||
self._embedding = LinearBNReLU(output_size=16, name="embedding",
|
||||
**bn_config)
|
||||
|
||||
if self._additional_features_mode == "mlp":
|
||||
self._feature_repr = LinearBNReLU(
|
||||
output_size=self._channels_per_group_list[-1], name="features_repr",
|
||||
**bn_config)
|
||||
elif self._additional_features_mode == "per_block":
|
||||
self._feature_repr = []
|
||||
for i, ch in enumerate(self._channels_per_group_list):
|
||||
self._feature_repr.append(
|
||||
LinearBNReLU(output_size=ch, name=f"features_{i}", **bn_config))
|
||||
else:
|
||||
raise ValueError(f"Unsupported addiitonal features mode: "
|
||||
f"{additional_features_mode}")
|
||||
|
||||
def __call__(self, inputs, features, is_training):
|
||||
net = inputs
|
||||
net = self._initial_conv(net)
|
||||
if not self._resnet_v2:
|
||||
net = self._initial_batchnorm(net, is_training=is_training)
|
||||
net = tf.nn.relu(net)
|
||||
|
||||
net = tf.nn.max_pool2d(
|
||||
net, ksize=3, strides=2, padding="SAME", name="initial_max_pool")
|
||||
|
||||
if self._use_additional_features:
|
||||
assert features is not None
|
||||
features = self._embedding(features, is_training=is_training)
|
||||
|
||||
for i, block_group in enumerate(self._block_groups):
|
||||
net = block_group(net, is_training)
|
||||
|
||||
if (self._use_additional_features and
|
||||
self._additional_features_mode == "per_block"):
|
||||
features_i = self._feature_repr[i](features, is_training=is_training)
|
||||
# support for n_repeats > 1
|
||||
features_i = tf.repeat(features_i, self._n_repeats, axis=0)
|
||||
net += features_i[:, None, None, :] # expand to spacial resolution
|
||||
|
||||
if self._resnet_v2:
|
||||
net = self._final_batchnorm(net, is_training=is_training)
|
||||
net = tf.nn.relu(net)
|
||||
net = tf.reduce_mean(net, axis=[1, 2], name="final_avg_pool")
|
||||
# Re-split the batch dimension
|
||||
net = tf.reshape(net, [-1, self._n_repeats] + net.shape.as_list()[1:])
|
||||
# Average over the various repeats of the input (e.g. those could have
|
||||
# corresponded to different crops).
|
||||
net = tf.reduce_mean(net, axis=1)
|
||||
|
||||
if (self._use_additional_features and
|
||||
self._additional_features_mode == "mlp"):
|
||||
net += self._feature_repr(features, is_training=is_training)
|
||||
|
||||
return self._logits(net)
|
||||
|
||||
|
||||
class LinearBNReLU(snt.Module):
|
||||
"""Wrapper class for Linear layer with Batch Norm and ReLU activation."""
|
||||
|
||||
def __init__(self, output_size=64,
|
||||
w_init=snt.initializers.VarianceScaling(scale=2.0),
|
||||
name="linear", **bn_config):
|
||||
"""Constructs a LinearBNReLU module.
|
||||
|
||||
Args:
|
||||
output_size: Output dimension.
|
||||
w_init: weight Initializer for snt.Linear.
|
||||
name: Name of the module.
|
||||
**bn_config: Optional parameters to be passed to snt.BatchNorm.
|
||||
"""
|
||||
super(LinearBNReLU, self).__init__(name=name)
|
||||
self._linear = snt.Linear(output_size=output_size, w_init=w_init,
|
||||
name=f"{name}_linear")
|
||||
self._bn = snt.BatchNorm(create_scale=True, create_offset=True,
|
||||
name=f"{name}_bn", **bn_config)
|
||||
|
||||
def __call__(self, x, is_training):
|
||||
x = self._linear(x)
|
||||
x = self._bn(x, is_training=is_training)
|
||||
return tf.nn.relu(x)
|
||||
279
galaxy_mergers/preprocessing.py
Normal file
279
galaxy_mergers/preprocessing.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pre-processing functions for input data."""
|
||||
|
||||
import functools
|
||||
from absl import logging
|
||||
import tensorflow.compat.v2 as tf
|
||||
from galaxy_mergers import losses
|
||||
|
||||
|
||||
CROP_TYPE_NONE = 'crop_none'
|
||||
CROP_TYPE_FIXED = 'crop_fixed'
|
||||
CROP_TYPE_RANDOM = 'crop_random'
|
||||
|
||||
DATASET_FREQUENCY_MEAN = 4.0
|
||||
DATASET_FREQUENCY_RANGE = 8.0
|
||||
|
||||
PHYSICAL_FEATURES_MIN_MAX = {
|
||||
'redshift': (0.572788, 2.112304),
|
||||
'mass': (9.823963, 10.951282)
|
||||
}
|
||||
|
||||
ALL_FREQUENCIES = [105, 125, 160, 435, 606, 775, 850]
|
||||
|
||||
VALID_ADDITIONAL_FEATURES = ['redshift', 'sequence_average_redshift', 'mass']
|
||||
|
||||
|
||||
def _make_padding_sizes(pad_size, random_centering):
|
||||
if random_centering:
|
||||
pad_size_left = tf.random.uniform(
|
||||
shape=[], minval=0, maxval=pad_size+1, dtype=tf.int32)
|
||||
else:
|
||||
pad_size_left = pad_size // 2
|
||||
pad_size_right = pad_size - pad_size_left
|
||||
return pad_size_left, pad_size_right
|
||||
|
||||
|
||||
def resize_and_pad(image, target_size, random_centering):
|
||||
"""Resize image to target_size (<= image.size) and pad to original size."""
|
||||
original_shape = image.shape
|
||||
size = tf.reshape(target_size, [1])
|
||||
size = tf.concat([size, size], axis=0)
|
||||
image = tf.image.resize(image, size=size)
|
||||
pad_size = original_shape[1] - target_size
|
||||
pad_size_left, pad_size_right = _make_padding_sizes(
|
||||
pad_size, random_centering)
|
||||
padding = [[pad_size_left, pad_size_right],
|
||||
[pad_size_left, pad_size_right], [0, 0]]
|
||||
if len(original_shape) == 4:
|
||||
padding = [[0, 0]] + padding
|
||||
image = tf.pad(image, padding)
|
||||
image.set_shape(original_shape)
|
||||
return image
|
||||
|
||||
|
||||
def resize_and_extract(image, target_size, random_centering):
|
||||
"""Upscale image to target_size (>image.size), extract original size crop."""
|
||||
original_shape = image.shape
|
||||
size = tf.reshape(target_size, [1])
|
||||
size = tf.concat([size, size], axis=0)
|
||||
image = tf.image.resize(image, size=size)
|
||||
pad_size = target_size - original_shape[1]
|
||||
pad_size_left, pad_size_right = _make_padding_sizes(
|
||||
pad_size, random_centering)
|
||||
if len(original_shape) == 3:
|
||||
image = tf.expand_dims(image, 0)
|
||||
image = tf.cond(pad_size_right > 0,
|
||||
lambda: image[:, pad_size_left:-pad_size_right, :, :],
|
||||
lambda: image[:, pad_size_left:, :, :])
|
||||
image = tf.cond(pad_size_right > 0,
|
||||
lambda: image[:, :, pad_size_left:-pad_size_right, :],
|
||||
lambda: image[:, :, pad_size_left:, :])
|
||||
if len(original_shape) == 3:
|
||||
image = tf.squeeze(image, 0)
|
||||
image.set_shape(original_shape)
|
||||
return image
|
||||
|
||||
|
||||
def resize_and_center(image, target_size, random_centering):
|
||||
return tf.cond(
|
||||
tf.math.less_equal(target_size, image.shape[1]),
|
||||
lambda: resize_and_pad(image, target_size, random_centering),
|
||||
lambda: resize_and_extract(image, target_size, random_centering))
|
||||
|
||||
|
||||
def random_rotation_and_flip(image):
|
||||
angle = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
|
||||
return tf.image.random_flip_left_right(tf.image.rot90(image, angle))
|
||||
|
||||
|
||||
def get_all_rotations_and_flips(images):
|
||||
assert isinstance(images, list)
|
||||
new_images = []
|
||||
for image in images:
|
||||
for rotation in range(4):
|
||||
new_images.append(tf.image.rot90(image, rotation))
|
||||
flipped_image = tf.image.flip_left_right(image)
|
||||
new_images.append(tf.image.rot90(flipped_image, rotation))
|
||||
return new_images
|
||||
|
||||
|
||||
def random_rescaling(image, random_centering):
|
||||
assert image.shape.as_list()[0] == image.shape.as_list()[1]
|
||||
original_size = image.shape.as_list()[1]
|
||||
min_size = 2 * (original_size // 4)
|
||||
max_size = original_size * 2
|
||||
target_size = tf.random.uniform(
|
||||
shape=[], minval=min_size, maxval=max_size // 2,
|
||||
dtype=tf.int32) * 2
|
||||
return resize_and_center(image, target_size, random_centering)
|
||||
|
||||
|
||||
def get_all_rescalings(images, image_width, random_centering):
|
||||
"""Get a uniform sample of rescalings of all images in input."""
|
||||
assert isinstance(images, list)
|
||||
min_size = 2 * (image_width // 4)
|
||||
max_size = image_width * 2
|
||||
delta_size = (max_size + 2 - min_size) // 5
|
||||
sizes = range(min_size, max_size + 2, delta_size)
|
||||
new_images = []
|
||||
for image in images:
|
||||
for size in sizes:
|
||||
new_images.append(resize_and_center(image, size, random_centering))
|
||||
return new_images
|
||||
|
||||
|
||||
def move_repeats_to_batch(image, n_repeats):
|
||||
width, height, n_channels = image.shape.as_list()[1:]
|
||||
image = tf.reshape(image, [-1, width, height, n_channels, n_repeats])
|
||||
image = tf.transpose(image, [0, 4, 1, 2, 3]) # [B, repeats, x, y, c]
|
||||
return tf.reshape(image, [-1, width, height, n_channels])
|
||||
|
||||
|
||||
def get_classification_label(dataset_row, class_boundaries):
|
||||
merge_time = dataset_row['grounded_normalized_time']
|
||||
label = tf.dtypes.cast(0, tf.int64)
|
||||
for category, intervals in class_boundaries.items():
|
||||
for interval in intervals:
|
||||
if merge_time > interval[0] and merge_time < interval[1]:
|
||||
label = tf.dtypes.cast(int(category), tf.int64)
|
||||
return label
|
||||
|
||||
|
||||
def get_regression_label(dataset_row, task_type):
|
||||
"""Returns time-until-merger regression target given desired modeling task."""
|
||||
if task_type == losses.TASK_NORMALIZED_REGRESSION:
|
||||
return tf.dtypes.cast(dataset_row['normalized_time'], tf.float32)
|
||||
elif task_type == losses.TASK_GROUNDED_UNNORMALIZED_REGRESSION:
|
||||
return tf.dtypes.cast(dataset_row['grounded_normalized_time'], tf.float32)
|
||||
elif task_type == losses.TASK_UNNORMALIZED_REGRESSION:
|
||||
return tf.dtypes.cast(dataset_row['unnormalized_time'], tf.float32)
|
||||
elif task_type == losses.TASK_CLASSIFICATION:
|
||||
return tf.dtypes.cast(dataset_row['grounded_normalized_time'], tf.float32)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def get_normalized_time_target(dataset_row):
|
||||
return tf.dtypes.cast(dataset_row['normalized_time'], tf.float32)
|
||||
|
||||
|
||||
def apply_time_filter(dataset_row, time_interval):
|
||||
"""Returns True if data is within the given time intervals."""
|
||||
merge_time = dataset_row['grounded_normalized_time']
|
||||
lower_time, upper_time = time_interval
|
||||
return merge_time > lower_time and merge_time < upper_time
|
||||
|
||||
|
||||
def normalize_physical_feature(name, dataset_row):
|
||||
min_feat, max_feat = PHYSICAL_FEATURES_MIN_MAX[name]
|
||||
value = getattr(dataset_row, name)
|
||||
return 2 * (value - min_feat) / (max_feat - min_feat) - 1
|
||||
|
||||
|
||||
def prepare_dataset(ds, target_size, crop_type, n_repeats, augmentations,
|
||||
task_type, additional_features, class_boundaries,
|
||||
time_intervals=None, frequencies_to_use='all',
|
||||
additional_lambdas=None):
|
||||
"""Prepare a zipped dataset of image, classification/regression labels."""
|
||||
def _prepare_image(dataset_row):
|
||||
"""Transpose, crop and cast an image."""
|
||||
image = tf.dtypes.cast(dataset_row['image'], tf.float32)
|
||||
image = tf.reshape(image, tf.cast(dataset_row['image_shape'], tf.int32))
|
||||
image = tf.transpose(image, perm=[1, 2, 0]) # Convert to NHWC
|
||||
|
||||
freqs = ALL_FREQUENCIES if frequencies_to_use == 'all' else frequencies_to_use
|
||||
idxs_to_keep = [ALL_FREQUENCIES.index(f) for f in freqs]
|
||||
image = tf.gather(params=image, indices=idxs_to_keep, axis=-1)
|
||||
|
||||
# Based on offline computation on the empirical frequency range:
|
||||
# Converts [0, 8.] ~~> [-1, 1]
|
||||
image = (image - DATASET_FREQUENCY_MEAN)/(DATASET_FREQUENCY_RANGE/2.0)
|
||||
|
||||
def crop(image):
|
||||
if crop_type == CROP_TYPE_FIXED:
|
||||
crop_loc = tf.cast(dataset_row['proposed_crop'][0], tf.int32)
|
||||
crop_size = tf.cast(dataset_row['proposed_crop'][1], tf.int32)
|
||||
image = image[
|
||||
crop_loc[0]:crop_loc[0] + crop_size[0],
|
||||
crop_loc[1]:crop_loc[1] + crop_size[1], :]
|
||||
image = tf.image.resize(image, target_size[0:2])
|
||||
image.set_shape([target_size[0], target_size[1], target_size[2]])
|
||||
|
||||
elif crop_type == CROP_TYPE_RANDOM:
|
||||
image = tf.image.random_crop(image, target_size)
|
||||
image.set_shape([target_size[0], target_size[1], target_size[2]])
|
||||
|
||||
elif crop_type != CROP_TYPE_NONE:
|
||||
raise NotImplementedError
|
||||
|
||||
return image
|
||||
|
||||
repeated_images = []
|
||||
for _ in range(n_repeats):
|
||||
repeated_images.append(crop(image))
|
||||
image = tf.concat(repeated_images, axis=-1)
|
||||
|
||||
if augmentations['rotation_and_flip']:
|
||||
image = random_rotation_and_flip(image)
|
||||
if augmentations['rescaling']:
|
||||
image = random_rescaling(image, augmentations['translation'])
|
||||
|
||||
return image
|
||||
|
||||
def get_regression_label_wrapper(dataset_row):
|
||||
return get_regression_label(dataset_row, task_type=task_type)
|
||||
|
||||
def get_classification_label_wrapper(dataset_row):
|
||||
return get_classification_label(dataset_row,
|
||||
class_boundaries=class_boundaries)
|
||||
|
||||
if time_intervals:
|
||||
for time_interval in time_intervals:
|
||||
filter_fn = functools.partial(apply_time_filter,
|
||||
time_interval=time_interval)
|
||||
ds = ds.filter(filter_fn)
|
||||
|
||||
datasets = [ds.map(_prepare_image)]
|
||||
|
||||
if additional_features:
|
||||
additional_features = additional_features.split(',')
|
||||
assert all([f in VALID_ADDITIONAL_FEATURES for f in additional_features])
|
||||
logging.info('Running with additional features: %s.',
|
||||
', '.join(additional_features))
|
||||
|
||||
def _prepare_additional_features(dataset_row):
|
||||
features = []
|
||||
for f in additional_features:
|
||||
features.append(normalize_physical_feature(f, dataset_row))
|
||||
features = tf.convert_to_tensor(features, dtype=tf.float32)
|
||||
features.set_shape([len(additional_features)])
|
||||
return features
|
||||
|
||||
datasets += [ds.map(_prepare_additional_features)]
|
||||
|
||||
datasets += [
|
||||
ds.map(get_classification_label_wrapper),
|
||||
ds.map(get_regression_label_wrapper),
|
||||
ds.map(get_normalized_time_target)]
|
||||
|
||||
if additional_lambdas:
|
||||
for process_fn in additional_lambdas:
|
||||
datasets += [ds.map(process_fn)]
|
||||
|
||||
return tf.data.Dataset.zip(tuple(datasets))
|
||||
|
||||
53
galaxy_mergers/requirements.txt
Normal file
53
galaxy_mergers/requirements.txt
Normal file
@@ -0,0 +1,53 @@
|
||||
absl-py==0.11.0
|
||||
astropy==4.2
|
||||
astunparse==1.6.3
|
||||
cachetools==4.2.1
|
||||
certifi==2020.12.5
|
||||
chardet==4.0.0
|
||||
cloudpickle==1.6.0
|
||||
contextlib2==0.6.0.post1
|
||||
cycler==0.10.0
|
||||
decorator==4.4.2
|
||||
dm-sonnet==2.0.0
|
||||
dm-tree==0.1.5
|
||||
flatbuffers==1.12
|
||||
gast==0.3.3
|
||||
google-auth==1.27.0
|
||||
google-auth-oauthlib==0.4.2
|
||||
google-pasta==0.2.0
|
||||
grpcio==1.32.0
|
||||
h5py==2.10.0
|
||||
idna==2.10
|
||||
Keras-Preprocessing==1.1.2
|
||||
kiwisolver==1.3.1
|
||||
Markdown==3.3.4
|
||||
matplotlib==3.3.4
|
||||
ml-collections==0.1.0
|
||||
numpy==1.19.5
|
||||
oauthlib==3.1.0
|
||||
opt-einsum==3.3.0
|
||||
Pillow==8.1.0
|
||||
pkg-resources==0.0.0
|
||||
protobuf==3.15.3
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
pyerfa==1.7.2
|
||||
pyparsing==2.4.7
|
||||
python-dateutil==2.8.1
|
||||
PyYAML==5.4.1
|
||||
requests==2.25.1
|
||||
requests-oauthlib==1.3.0
|
||||
rsa==4.7.2
|
||||
scipy==1.6.1
|
||||
six==1.15.0
|
||||
tabulate==0.8.9
|
||||
tensorboard==2.4.1
|
||||
tensorboard-plugin-wit==1.8.0
|
||||
tensorflow==2.4.1
|
||||
tensorflow-estimator==2.4.0
|
||||
tensorflow-probability==0.12.1
|
||||
termcolor==1.1.0
|
||||
typing-extensions==3.7.4.3
|
||||
urllib3==1.26.3
|
||||
Werkzeug==1.0.1
|
||||
wrapt==1.12.1
|
||||
Reference in New Issue
Block a user