mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-02 14:45:25 +08:00
Correct package installation in the provided colab.
PiperOrigin-RevId: 282344936
Committed by Diego de Las Casas
Parent: cdabc4cb84
Commit: 2cb746a4b0
@@ -0,0 +1,142 @@
# Hierarchical Probabilistic U-Net

This package provides an implementation of the Hierarchical Probabilistic U-Net
(HPU-Net) as published in [A Hierarchical Probabilistic U-Net for Modeling
Multi-Scale Ambiguities (2019)](https://arxiv.org/abs/1905.13077).

The HPU-Net combines a hierarchical VAE with a U-Net and learns an
image-conditional distribution over plausible outputs, here segmentation maps.
The hierarchical latent space decomposition makes it possible to model
independent variations across locations and scales and increases the
granularity of predicted segmentations compared to prior work ([the
Probabilistic U-Net](https://arxiv.org/abs/1806.05034)).
The architecture, depicted above, interleaves the U-Net's decoder with a
prior that is used when sampling at test time (see Subfigure a) above). Training
proceeds as is typical for VAEs, i.e. a separate posterior network is employed
whose latents are injected into the decoder (see Subfigure b) above).
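
To make the "typical for VAEs" training setup more concrete, the following is a
minimal sketch of the KL term that compares a posterior with a prior at every
scale of the hierarchy. It assumes diagonal Gaussian distributions and plain
Python lists; the function names `kl_diag_gaussians` and `hierarchical_kl` are
hypothetical and are not part of this package, whose own KL computation is
exposed through the model's `kl` method.

```python
import math


def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
  """KL(q || p) for two diagonal Gaussians given as equal-length lists."""
  kl = 0.0
  for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p):
    # Closed-form KL between two univariate Gaussians, summed over dimensions.
    kl += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2) - 0.5
  return kl


def hierarchical_kl(posteriors, priors):
  """Sums the per-scale KL terms; each entry is a (mu, sigma) pair of lists."""
  # Assumption: one (posterior, prior) pair per latent scale.
  return sum(kl_diag_gaussians(mu_q, s_q, mu_p, s_p)
             for (mu_q, s_q), (mu_p, s_p) in zip(posteriors, priors))
```

In a VAE-style objective this summed KL is added to the reconstruction loss, so
the posterior is pulled towards the prior that is used alone at test time.
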

The animated gif below shows 16 segmentation samples obtained by a) sampling
from the full hierarchy, b) fixing all but the most local latents to the
prior's mean, or c) fixing all but the most global latent to the prior's mean.
The first row depicts CT scans showing potential lung abnormalities; the rows
below show individual samples and the standard deviation across them.

In addition to the model code we provide the preprocessed version of the
[LIDC-IDRI
dataset](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI)
that we employ as well as pretrained model weights. Both can be loaded in a
public colab, see below.

## Colab [Open in Colab](https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/hierarchical_probabilistic_unet/HPU_Net.ipynb)
To quickly tinker with the pretrained model and the dataset without installing
anything locally, click the `Open in Colab` button above to follow the link to
the colab.


## Installation
To install the package locally run:
```bash
git clone https://github.com/deepmind/hierarchical_prob_unet
cd hierarchical_prob_unet
pip install -e .
```

## LIDC 2D crops
We provide a preprocessed version of the Lung Image Database Consortium image
collection dataset
([LIDC-IDRI](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI))
as used and described in [A Hierarchical Probabilistic U-Net for Modeling
Multi-Scale Ambiguities (2019)](https://arxiv.org/abs/1905.13077).

The original dataset consists of 3D lung CT scans with semantic segmentations of
possible lung abnormalities, as graded by four expert readers. We have cleaned
up the data, which slightly changed the number of images in each data split
(see below) but leaves the results the same. The data is
hosted in Google Cloud Storage ([this bucket](https://console.cloud.google.com/storage/browser/hpunet-data/lidc_crops/)).

##### Preprocessing:

We resampled the CT scans to 0.5mm x 0.5mm in-plane resolution and then cropped
2D images of size 180 x 180 pixels, centered on the abnormality position. The
abnormality positions are those where at least one of the experts segmented an
abnormality, and we assumed that two masks from different graders correspond to
the same abnormality if their tightest bounding boxes overlap. We only used
those abnormalities that were specified as a polygon (outline) in the XML files
of the LIDC dataset, disregarding the ones that only have a center of shape.
That is, according to the LIDC paper, we use the ones that are larger than 3mm
and filter out the others, which are clinically less relevant ([2], see below).
We also filtered out each DICOM file whose absolute value in the XML element
`SliceLocation` differs from the absolute value of the last element in
`ImagePositionPatient`.
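
The bounding-box criterion for grouping masks from different graders can be
sketched as follows. This is an illustrative sketch rather than the
preprocessing code that was actually used: it assumes 2D binary masks given as
nested lists, treats two boxes as overlapping when they share at least one
pixel, and the helper names `tight_bbox` and `bboxes_overlap` are hypothetical.

```python
def tight_bbox(mask):
  """Returns (row_min, row_max, col_min, col_max) of nonzero pixels, or None."""
  rows = [r for r, row in enumerate(mask) if any(row)]
  if not rows:
    return None  # Empty mask: there is no bounding box.
  cols = [c for row in mask for c, value in enumerate(row) if value]
  return min(rows), max(rows), min(cols), max(cols)


def bboxes_overlap(box_a, box_b):
  """True if two inclusive boxes share at least one pixel."""
  if box_a is None or box_b is None:
    return False
  r0a, r1a, c0a, c1a = box_a
  r0b, r1b, c0b, c1b = box_b
  # Intervals must intersect along both the row and the column axis.
  return r0a <= r1b and r0b <= r1a and c0a <= c1b and c0b <= c1a
```

Under these assumptions, two graders' masks would be assigned to the same
abnormality whenever `bboxes_overlap(tight_bbox(a), tight_bbox(b))` is true.
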

This preprocessing results in 8843 images for training, 1993 for validation and
1980 for testing (corresponding to 530, 111 and 103 patients respectively).

##### Directory Structure

The GCS bucket contains tar.gz-files for the training, validation and test data.
Each tar.gz-file contains a zipped directory with two subdirectories, named
`images` and `gt`. Each of these comprises a directory for each patient
that is part of that data split. Each patient's directory holds the
corresponding cropped 2D images in .png-format. The naming scheme follows
`z-<imageZposition>_c<crop number of the slice>.png` for CT scans and
`z-<imageZposition>_c<crop number of the slice>_l<labeller id in [1,4]>.png` for
the binary segmentation maps, which makes it possible to match the images with
their corresponding four annotations.
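
The naming scheme above can be used to pair each image with its annotations.
The following is a minimal sketch under a few assumptions: the z position is
treated as an opaque string (its exact format is not specified here), all file
names come from the same patient directory, and the helper names `parse_name`
and `match_annotations` are hypothetical and not part of this package.

```python
import collections
import re

# z position, crop number and an optional labeller id in [1, 4].
_NAME_PATTERN = re.compile(
    r'^z-(?P<z>.+?)_c(?P<crop>\d+)(?:_l(?P<labeller>[1-4]))?\.png$')


def parse_name(file_name):
  """Splits a file name into (z position, crop number, labeller id or None)."""
  match = _NAME_PATTERN.match(file_name)
  if match is None:
    raise ValueError('Unexpected file name: {}'.format(file_name))
  labeller = match.group('labeller')
  return (match.group('z'), int(match.group('crop')),
          int(labeller) if labeller else None)


def match_annotations(image_names, gt_names):
  """Maps each image name to its annotation names, sorted by labeller id."""
  by_key = collections.defaultdict(list)
  for name in gt_names:
    z, crop, labeller = parse_name(name)
    by_key[(z, crop)].append((labeller, name))
  return {name: [gt for _, gt in sorted(by_key[parse_name(name)[:2]])]
          for name in image_names}
```

For a complete patient directory, every image name would be expected to map to
four annotation names, one per labeller.
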


##### Citations & Data Usage Policy:

The [LIDC-IDRI
dataset](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI)
was published in [1, 2, 3] and is made available under the [CC BY 3.0
license](https://creativecommons.org/licenses/by/3.0/).

[1] Armato III, Samuel G., McLennan, Geoffrey, Bidaut, Luc, McNitt-Gray, Michael
F., Meyer, Charles R., Reeves, Anthony P., … Clarke, Laurence P. (2015). Data
From LIDC-IDRI. The Cancer Imaging Archive.
([Link](http://doi.org/10.7937/K9/TCIA.2015.LO9QL9SX))

[2] Armato SG III, McLennan G, Bidaut L, McNitt-Gray MF, Meyer CR, Reeves AP,
Zhao B, Aberle DR, Henschke CI, Hoffman EA, Kazerooni EA, MacMahon H, van Beek
EJR, Yankelevitz D, et al.: The Lung Image Database Consortium (LIDC) and Image
Database Resource Initiative (IDRI): A completed reference database of lung
nodules on CT scans. Medical Physics, 38: 915--931, 2011.
([Paper](https://www.ncbi.nlm.nih.gov/pubmed/21452728))

[3] Clark K, Vendt B, Smith K, Freymann J, Kirby J, Koppel P, Moore S, Phillips
S, Maffitt D, Pringle M, Tarbox L, Prior F. The Cancer Imaging Archive (TCIA):
Maintaining and Operating a Public Information Repository, Journal of Digital
Imaging, Volume 26, Number 6, December, 2013, pp 1045-1057.
([Paper](https://link.springer.com/article/10.1007%2Fs10278-013-9622-7))

We make the `LIDC 2D crops` data, accessible from the GCS bucket above,
available under the [CC BY 3.0 license](https://creativecommons.org/licenses/by/3.0/).


## Pretrained Model
We provide a pretrained model checkpoint ([Google Cloud Storage
bucket](https://console.cloud.google.com/storage/browser/hpunet-data/model_checkpoint/))
that can be loaded as exemplified in our colab.


## Giving Credit

If you use this code in your work, we ask you to cite this paper:

```
@article{kohl2019hierarchical,
  title={A Hierarchical Probabilistic U-Net for Modeling Multi-Scale Ambiguities},
  author={Kohl, Simon AA and Romera-Paredes, Bernardino and Maier-Hein, Klaus H and Rezende, Danilo Jimenez and Eslami, SM and Kohli, Pushmeet and Zisserman, Andrew and Ronneberger, Olaf},
  journal={arXiv preprint arXiv:1905.13077},
  year={2019}
}
```

## Disclaimer

This is not an official Google product.
@@ -0,0 +1,163 @@
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility Functions for the GECO-objective.

(GECO is described in `Taming VAEs`, see https://arxiv.org/abs/1810.00597).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import sonnet as snt
import tensorflow as tf

class MovingAverage(snt.AbstractModule):
  """A thin wrapper around snt.MovingAverage.

  The module adds the option not to differentiate through the last element that
  is added to the moving average, specified by means of the kwarg
  `differentiable`.
  """

  def __init__(self, decay, local=True, differentiable=False,
               name='snt_moving_average'):
    super(MovingAverage, self).__init__(name=name)
    self._differentiable = differentiable
    self._moving_average = snt.MovingAverage(
        decay=decay, local=local, name=name)

  def _build(self, inputs):
    if not self._differentiable:
      inputs = tf.stop_gradient(inputs)
    return self._moving_average(inputs)

class LagrangeMultiplier(snt.AbstractModule):
  """A Lagrange multiplier sonnet module."""

  def __init__(self,
               rate=1e-2,
               name='snt_lagrange_multiplier'):
    """Initializer for the sonnet module.

    Args:
      rate: Scalar used to scale the magnitude of gradients of the Lagrange
          multipliers, defaulting to 1e-2.
      name: Name of the Lagrange multiplier sonnet module.
    """
    super(LagrangeMultiplier, self).__init__(name=name)
    self._rate = rate

  def _build(self, ma_constraint):
    """Connects the module to the graph.

    Args:
      ma_constraint: A loss minus a target value, denoting a constraint that
          shall be less than or equal to zero.

    Returns:
      An op, which when added to a loss and calling minimize on the loss
      results in the optimizer minimizing w.r.t. the model's parameters and
      maximizing w.r.t. the Lagrange multipliers, hence enforcing the
      constraints.
    """
    lagmul = snt.get_lagrange_multiplier(
        shape=ma_constraint.shape, rate=self._rate,
        initializer=np.ones(ma_constraint.shape))
    return lagmul

def _sample_gumbel(shape, eps=1e-20):
  """Transforms a uniform random variable to be standard Gumbel distributed."""

  return -tf.log(
      -tf.log(tf.random_uniform(shape, minval=0, maxval=1) + eps) + eps)


def _topk_mask(score, k):
  """Returns a mask for the top-k elements in score."""

  _, indices = tf.nn.top_k(score, k=k)
  return tf.scatter_nd(tf.expand_dims(indices, -1), tf.ones(k),
                       tf.squeeze(score).shape.as_list())

def ce_loss(logits, labels, mask=None, top_k_percentage=None,
            deterministic=False):
  """Computes the cross-entropy loss.

  Optionally a mask and a top-k percentage for the used pixels can be specified.
  The top-k mask can be produced deterministically or sampled.

  Args:
    logits: A tensor of shape (b,h,w,num_classes)
    labels: A tensor of shape (b,h,w,num_classes)
    mask: None or a tensor of shape (b,h,w).
    top_k_percentage: None or a float in (0.,1.]. If None, a standard
      cross-entropy loss is calculated.
    deterministic: A Boolean indicating whether or not to produce the
      prospective top-k mask deterministically.

  Returns:
    A dictionary holding the mean and the pixelwise sum of the loss for the
    batch as well as the employed loss mask.
  """
  num_classes = logits.shape.as_list()[-1]
  y_flat = tf.reshape(logits, (-1, num_classes), name='reshape_y')
  t_flat = tf.reshape(labels, (-1, num_classes), name='reshape_t')
  if mask is None:
    mask = tf.ones(shape=(t_flat.shape.as_list()[0],))
  else:
    assert mask.shape.as_list()[:3] == labels.shape.as_list()[:3],\
        'The loss mask shape differs from the target shape: {} vs. {}.'.format(
            mask.shape.as_list(), labels.shape.as_list()[:3])
    mask = tf.reshape(mask, (-1,), name='reshape_mask')

  n_pixels_in_batch = y_flat.shape.as_list()[0]
  xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=t_flat, logits=y_flat)

  if top_k_percentage is not None:
    assert 0.0 < top_k_percentage <= 1.0
    k_pixels = tf.cast(tf.floor(n_pixels_in_batch * top_k_percentage), tf.int32)

    stopgrad_xe = tf.stop_gradient(xe)
    norm_xe = stopgrad_xe / tf.reduce_sum(stopgrad_xe)

    if deterministic:
      score = tf.log(norm_xe)
    else:
      # Use the Gumbel trick to sample the top-k pixels, equivalent to sampling
      # from a categorical distribution over pixels whose probabilities are
      # given by the normalized cross-entropy loss values. This is done by
      # adding Gumbel noise to the logarithmic normalized cross-entropy loss
      # (followed by choosing the top-k pixels).
      score = tf.log(norm_xe) + _sample_gumbel(norm_xe.shape.as_list())

    score = score + tf.log(mask)
    top_k_mask = _topk_mask(score, k_pixels)
    mask = mask * top_k_mask

  # Calculate batch-averages for the sum and mean of the loss
  batch_size = labels.shape.as_list()[0]
  xe = tf.reshape(xe, shape=(batch_size, -1))
  mask = tf.reshape(mask, shape=(batch_size, -1))
  ce_sum_per_instance = tf.reduce_sum(mask * xe, axis=1)
  ce_sum = tf.reduce_mean(ce_sum_per_instance, axis=0)
  ce_mean = tf.reduce_sum(mask * xe) / tf.reduce_sum(mask)

  return {'mean': ce_mean, 'sum': ce_sum, 'mask': mask}
Binary file not shown. Size: 11 MiB
Binary file not shown. Size: 707 KiB
File diff suppressed because it is too large
@@ -0,0 +1,113 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Hierarchical Probabilistic U-Net open-source version."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from model import HierarchicalProbUNet
import tensorflow as tf

_NUM_CLASSES = 2
_BATCH_SIZE = 2
_SPATIAL_SHAPE = [32, 32]
_CHANNELS_PER_BLOCK = [5, 7, 9, 11, 13]
_IMAGE_SHAPE = [_BATCH_SIZE] + _SPATIAL_SHAPE + [1]
_BOTTLENECK_SIZE = _SPATIAL_SHAPE[0] // 2 ** (len(_CHANNELS_PER_BLOCK) - 1)
_SEGMENTATION_SHAPE = [_BATCH_SIZE] + _SPATIAL_SHAPE + [_NUM_CLASSES]
_LATENT_DIMS = [3, 2, 1]
_INITIALIZERS = {'w': tf.orthogonal_initializer(gain=1.0, seed=None),
                 'b': tf.truncated_normal_initializer(stddev=0.001)}


def _get_placeholders():
  """Returns placeholders for the image and segmentation."""
  img = tf.placeholder(dtype=tf.float32, shape=_IMAGE_SHAPE)
  seg = tf.placeholder(dtype=tf.float32, shape=_SEGMENTATION_SHAPE)
  return img, seg

class HierarchicalProbUNetTest(tf.test.TestCase):

  def test_shape_of_sample(self):
    hpu_net = HierarchicalProbUNet(latent_dims=_LATENT_DIMS,
                                   channels_per_block=_CHANNELS_PER_BLOCK,
                                   num_classes=_NUM_CLASSES,
                                   initializers=_INITIALIZERS)
    img, _ = _get_placeholders()
    sample = hpu_net.sample(img)
    self.assertEqual(sample.shape.as_list(), _SEGMENTATION_SHAPE)

  def test_shape_of_reconstruction(self):
    hpu_net = HierarchicalProbUNet(latent_dims=_LATENT_DIMS,
                                   channels_per_block=_CHANNELS_PER_BLOCK,
                                   num_classes=_NUM_CLASSES,
                                   initializers=_INITIALIZERS)
    img, seg = _get_placeholders()
    reconstruction = hpu_net.reconstruct(img, seg)
    self.assertEqual(reconstruction.shape.as_list(), _SEGMENTATION_SHAPE)

  def test_shapes_in_prior(self):
    hpu_net = HierarchicalProbUNet(latent_dims=_LATENT_DIMS,
                                   channels_per_block=_CHANNELS_PER_BLOCK,
                                   num_classes=_NUM_CLASSES,
                                   initializers=_INITIALIZERS)
    img, _ = _get_placeholders()
    prior_out = hpu_net._prior(img)
    distributions = prior_out['distributions']
    latents = prior_out['used_latents']
    encoder_features = prior_out['encoder_features']
    decoder_features = prior_out['decoder_features']

    # Test number of latent distributions.
    self.assertEqual(len(distributions), len(_LATENT_DIMS))

    # Test shapes of latent scales.
    for level in range(len(_LATENT_DIMS)):
      latent_spatial_shape = _BOTTLENECK_SIZE * 2 ** level
      latent_shape = [_BATCH_SIZE, latent_spatial_shape, latent_spatial_shape,
                      _LATENT_DIMS[level]]
      self.assertEqual(latents[level].shape.as_list(), latent_shape)

    # Test encoder shapes.
    for level in range(len(_CHANNELS_PER_BLOCK)):
      spatial_shape = _SPATIAL_SHAPE[0] // 2 ** level
      feature_shape = [_BATCH_SIZE, spatial_shape, spatial_shape,
                       _CHANNELS_PER_BLOCK[level]]
      self.assertEqual(encoder_features[level].shape.as_list(), feature_shape)

    # Test decoder shape.
    start_level = len(_LATENT_DIMS)
    latent_spatial_shape = _BOTTLENECK_SIZE * 2 ** start_level
    latent_shape = [_BATCH_SIZE, latent_spatial_shape, latent_spatial_shape,
                    _CHANNELS_PER_BLOCK[::-1][start_level]]
    self.assertEqual(decoder_features.shape.as_list(), latent_shape)

  def test_shape_of_kl(self):
    hpu_net = HierarchicalProbUNet(latent_dims=_LATENT_DIMS,
                                   channels_per_block=_CHANNELS_PER_BLOCK,
                                   num_classes=_NUM_CLASSES,
                                   initializers=_INITIALIZERS)
    img, seg = _get_placeholders()
    kl_dict = hpu_net.kl(img, seg)
    self.assertEqual(len(kl_dict), len(_LATENT_DIMS))

if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,20 @@
#!/bin/sh
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

python3 -m venv hpu-net-venv
# Use the POSIX `.` builtin, since `source` is not available in every /bin/sh.
. hpu-net-venv/bin/activate
pip3 install .
python3 model_test.py
deactivate
@@ -0,0 +1,40 @@
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Setup for pip package."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from setuptools import find_packages
from setuptools import setup


REQUIRED_PACKAGES = ['numpy', 'dm-sonnet==1.35', 'tensorflow==1.14',
                     'tensorflow-probability==0.7.0']

setup(
    name='hpu_net',
    version='0.1',
    description='A library for the Hierarchical Probabilistic U-Net model.',
    url='https://github.com/deepmind/deepmind-research/hierarchical_probabilistic_unet',
    author='DeepMind',
    author_email='no-reply@google.com',
    # Contained modules and scripts.
    packages=find_packages(),
    install_requires=REQUIRED_PACKAGES,
    platforms=['any'],
    license='Apache 2.0',
)
@@ -0,0 +1,106 @@
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Architectural blocks and utility functions of the U-Net."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sonnet as snt
import tensorflow as tf

def res_block(input_features, n_channels, n_down_channels=None,
              activation_fn=tf.nn.relu, initializers=None, regularizers=None,
              convs_per_block=3):
  """A pre-activated residual block.

  Args:
    input_features: A tensor of shape (b, h, w, c).
    n_channels: An integer specifying the number of output channels.
    n_down_channels: An integer specifying the number of intermediate channels.
    activation_fn: A callable activation function.
    initializers: Initializers for the weights and biases.
    regularizers: Regularizers for the weights and biases.
    convs_per_block: An Integer specifying the number of convolutional layers.

  Returns:
    A tensor of shape (b, h, w, c).
  """
  # Pre-activate the inputs.
  skip = input_features
  residual = activation_fn(input_features)

  # Set the number of intermediate channels that we compress to.
  if n_down_channels is None:
    n_down_channels = n_channels

  for c in range(convs_per_block):
    residual = snt.Conv2D(n_down_channels,
                          (3, 3),
                          padding='SAME',
                          initializers=initializers,
                          regularizers=regularizers)(residual)
    if c < convs_per_block - 1:
      residual = activation_fn(residual)

  incoming_channels = input_features.shape[-1]
  if incoming_channels != n_channels:
    skip = snt.Conv2D(n_channels,
                      (1, 1),
                      padding='SAME',
                      initializers=initializers,
                      regularizers=regularizers)(skip)
  if n_down_channels != n_channels:
    residual = snt.Conv2D(n_channels,
                          (1, 1),
                          padding='SAME',
                          initializers=initializers,
                          regularizers=regularizers)(residual)
  return skip + residual

def resize_up(input_features, scale=2):
  """Nearest neighbor rescaling-operation for the input features.

  Args:
    input_features: A tensor of shape (b, h, w, c).
    scale: An integer specifying the scaling factor.
  Returns: A tensor of shape (b, scale * h, scale * w, c).
  """
  assert scale >= 1
  _, size_x, size_y, _ = input_features.shape.as_list()
  new_size_x = int(round(size_x * scale))
  new_size_y = int(round(size_y * scale))
  return tf.image.resize(
      input_features,
      [new_size_x, new_size_y],
      align_corners=True,
      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

def resize_down(input_features, scale=2):
  """Average pooling rescaling-operation for the input features.

  Args:
    input_features: A tensor of shape (b, h, w, c).
    scale: An integer specifying the scaling factor.
  Returns: A tensor of shape (b, h / scale, w / scale, c).
  """
  assert scale >= 1
  return tf.nn.avg_pool2d(
      input_features,
      ksize=(1, scale, scale, 1),
      strides=(1, scale, scale, 1),
      padding='VALID')