deepmind-research/iodine/modules/utils.py

# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for IODINE."""
# pylint: disable=g-doc-bad-indent, g-doc-return-or-yield, g-doc-args
# pylint: disable=missing-docstring
import importlib
import math
from absl import logging
from matplotlib.colors import hsv_to_rgb
import numpy as np
import shapeguard
import sonnet as snt
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
tfd = tfp.distributions


ACT_FUNCS = {
    "identity": tf.identity,
    "sigmoid": tf.nn.sigmoid,
    "tanh": tf.nn.tanh,
    "relu": tf.nn.relu,
    "elu": tf.nn.elu,
    "selu": tf.nn.selu,
    "softplus": tf.nn.softplus,
    "exp": tf.exp,
    "softmax": tf.nn.softmax,
}


def get_act_func(name_or_func):
  """Resolve an activation function from a name, a callable, or None."""
  if name_or_func is None:
    return tf.identity
  if callable(name_or_func):
    return name_or_func
  elif isinstance(name_or_func, str):
    return ACT_FUNCS[name_or_func.lower()]
  else:
    raise KeyError('Unknown activation function "{}" of type {}'.format(
        name_or_func, type(name_or_func)))
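
# Illustrative usage of `get_act_func` (a sketch, not exercised by this
# module; assumes a TF1 graph context as used throughout this file):
#   act = get_act_func("relu")
#   y = act(tf.constant([-1.0, 2.0]))   # evaluates to [0.0, 2.0]
#   get_act_func(None) is tf.identity   # -> True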


DISTS = {
    "normal": tfd.Normal,
    "log_normal": tfd.LogNormal,
    "laplace": tfd.Laplace,
    "logistic": tfd.Logistic,
}


def get_distribution(name_or_dist):
  """Resolve a tfp distribution class from a name, or pass a class through."""
  if isinstance(name_or_dist, type(tfd.Normal)):
    return name_or_dist
  elif isinstance(name_or_dist, str):
    return DISTS[name_or_dist.lower()]
  raise KeyError('Unknown distribution "{}" of type {}'.format(
      name_or_dist, type(name_or_dist)))
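
# Illustrative usage of `get_distribution` (sketch):
#   dist_cls = get_distribution("laplace")      # -> tfd.Laplace
#   dist = dist_cls(loc=0.0, scale=1.0)
#   get_distribution(tfd.Normal) is tfd.Normal  # classes pass through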


def get_mask_plot_colors(nr_colors):
  """Get nr_colors uniformly spaced hues to plot mask values."""
  hsv_colors = np.ones((nr_colors, 3), dtype=np.float32)
  hsv_colors[:, 0] = np.linspace(0, 1, nr_colors, endpoint=False)
  color_conv = hsv_to_rgb(hsv_colors)
  return color_conv


def color_transform(masks):
  """Map masks with a trailing component axis K to RGB, one hue per component."""
  with tf.name_scope("color_transform"):
    n_components = masks.shape.as_list()[-1]
    colors = tf.constant(
        get_mask_plot_colors(n_components), name="mask_colors")
    return tf.tensordot(masks, colors, axes=1)
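
# Shape sketch for `color_transform` (hypothetical shapes): the trailing
# component axis K is contracted against K RGB colors:
#   masks = tf.random.uniform([4, 64, 64, 5])  # (B, H, W, K)
#   rgb = color_transform(masks)               # (B, H, W, 3)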


def construct_diagnostic_image(
    images,
    recons,
    masks,
    components,
    border_width=2,
    nr_images=6,
    clip=True,
    mask_components=False,
):
  """Construct a single image containing image, recons., mask, and components.

  Args:
    images: (B, H, W, C)
    recons: (B, H, W, C)
    masks: (B, H, W, K)
    components: (B, H, W, K, C)
    border_width: int. Width of the border in pixels. (default=2)
    nr_images: int. Number of images to include. (default=6)
    clip: bool. Whether to clip the final image to the range [0, 1].
    mask_components: bool. Whether to multiply the components by their masks
      before plotting. (default=False)

  Returns:
    diag_images: (1, nr_images * (H + 2 * border_width),
      (W + 2 * border_width) * (K + 3), 3)
  """
  with tf.name_scope("diagnostic_image"):
    # Transform the masks into RGB images.
    recolored_masks = color_transform(masks[:nr_images])
    if images.get_shape().as_list()[-1] == 1:
      # Deal with grayscale images by tiling them into three channels.
      images = tf.tile(images[:nr_images], [1, 1, 1, 3])
      recons = tf.tile(recons[:nr_images], [1, 1, 1, 3])
      components = tf.tile(components[:nr_images], [1, 1, 1, 1, 3])
    if mask_components:
      # Slice before multiplying so the mask broadcast also works when the
      # grayscale branch above did not already truncate the batch.
      components = components[:nr_images] * masks[:nr_images, ..., tf.newaxis]
    # Pad everything.
    no_pad, pad = (0, 0), (border_width, border_width)
    paddings = tf.constant([no_pad, pad, pad, no_pad])
    paddings_components = tf.constant([no_pad, pad, pad, no_pad, no_pad])
    pad_images = tf.pad(images[:nr_images], paddings, constant_values=0.5)
    pad_recons = tf.pad(recons[:nr_images], paddings, constant_values=0.5)
    pad_masks = tf.pad(recolored_masks, paddings, constant_values=1.0)
    pad_components = tf.pad(
        components[:nr_images], paddings_components, constant_values=0.5)
    # Reshape the K component panels into a single wide image.
    pad_components = tf.transpose(pad_components, [0, 1, 3, 2, 4])
    pc_shape = pad_components.shape.as_list()
    pc_shape[2] = pc_shape[2] * pc_shape.pop(3)
    pad_components = tf.reshape(pad_components, pc_shape)
    # Concatenate all parts along the width.
    diag_imgs = tf.concat(
        [pad_images, pad_recons, pad_masks, pad_components], axis=2)
    # Concatenate all images along the height.
    diag_shape = diag_imgs.shape.as_list()
    final_img = tf.reshape(diag_imgs, [1, -1, diag_shape[2], diag_shape[3]])
    if clip:
      final_img = tf.clip_by_value(final_img, 0.0, 1.0)
    return final_img
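
# Smoke-test sketch for `construct_diagnostic_image` with hypothetical
# shapes (B=8, H=W=32, K=5, C=3); the defaults keep nr_images=6 rows:
#   images = tf.zeros([8, 32, 32, 3])
#   recons = tf.zeros([8, 32, 32, 3])
#   masks = tf.nn.softmax(tf.zeros([8, 32, 32, 5]), axis=-1)
#   components = tf.zeros([8, 32, 32, 5, 3])
#   diag = construct_diagnostic_image(images, recons, masks, components)
#   # diag: (1, 6 * 36, 36 * (5 + 3), 3)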


def construct_reconstr_image(images, recons, border_width=2, nr_images=6,
                             clip=True):
  """Construct a single image containing each image next to its recons.

  Args:
    images: (B, H, W, C)
    recons: (B, H, W, C)
    border_width: int. Width of the border in pixels. (default=2)
    nr_images: int. Number of images to include. (default=6)
    clip: bool. Whether to clip the final image to the range [0, 1].

  Returns:
    rec_images: (1, nr_images * (H + 2 * border_width),
      (W + 2 * border_width) * 2, C)
  """
  with tf.name_scope("diagnostic_image"):
    # Pad everything.
    no_pad, pad = (0, 0), (border_width, border_width)
    paddings = tf.constant([no_pad, pad, pad, no_pad])
    pad_images = tf.pad(images[:nr_images], paddings, constant_values=0.5)
    pad_recons = tf.pad(recons[:nr_images], paddings, constant_values=0.5)
    # Concatenate images and reconstructions along the width.
    diag_imgs = tf.concat([pad_images, pad_recons], axis=2)
    # Concatenate all images along the height.
    diag_shape = diag_imgs.shape.as_list()
    final_img = tf.reshape(diag_imgs, [1, -1, diag_shape[2], diag_shape[3]])
    if clip:
      final_img = tf.clip_by_value(final_img, 0.0, 1.0)
    return final_img


def construct_iterations_image(
    images, recons, masks, border_width=2, nr_seqs=2, clip=True
):
  """Construct a single image laying out inputs, recons, and masks over time.

  Args:
    images: (B, T, 1, H, W, C)
    recons: (B, T, 1, H, W, C)
    masks: (B, T, K, H, W, 1)
    border_width: int. Width of the border in pixels. (default=2)
    nr_seqs: int. Number of sequences to include. (default=2)
    clip: bool. Whether to clip the final image to the range [0, 1].

  Returns:
    rec_images: (1, S * 3 * (H + 2 * border_width),
      T * (W + 2 * border_width), 3) where S = min(nr_seqs, B).
  """
  sg = shapeguard.ShapeGuard()
  sg.guard(recons, "B, T, 1, H, W, C")
  if images.get_shape().as_list()[1] == 1:
    images = tf.tile(images, sg["1, T, 1, 1, 1, 1"])
  sg.guard(images, "B, T, 1, H, W, C")
  sg.guard(masks, "B, T, K, H, W, 1")
  if sg.C == 1:  # Deal with grayscale images by tiling them to RGB.
    images = tf.tile(images, [1, 1, 1, 1, 1, 3])
    recons = tf.tile(recons, [1, 1, 1, 1, 1, 3])
  sg.S = min(nr_seqs, sg.B)
  with tf.name_scope("diagnostic_image"):
    # Convert the masks to RGB, one hue per component.
    masks_trans = tf.transpose(masks[:nr_seqs], [0, 1, 5, 3, 4, 2])
    recolored_masks = color_transform(masks_trans)
    # Pad everything.
    no_pad, pad = (0, 0), (border_width, border_width)
    paddings = tf.constant([no_pad, no_pad, no_pad, pad, pad, no_pad])
    pad_images = tf.pad(images[:nr_seqs], paddings, constant_values=0.5)
    pad_recons = tf.pad(recons[:nr_seqs], paddings, constant_values=0.5)
    pad_masks = tf.pad(recolored_masks, paddings, constant_values=0.5)
    # Stack image, reconstruction, and mask along the height.
    triples = tf.concat([pad_images, pad_recons, pad_masks], axis=3)
    triples = sg.guard(triples[:, :, 0], "S, T, 3*Hp, Wp, 3")
    # Concatenate iterations along the width and sequences along the height.
    final = tf.reshape(
        tf.transpose(triples, [0, 2, 1, 3, 4]), sg["1, S*3*Hp, Wp*T, 3"])
    if clip:
      final = tf.clip_by_value(final, 0.0, 1.0)
    return final


def get_overview_image(image, output_dist, mask_components=False):
  """Build a diagnostic overview image from a model output distribution."""
  recons = output_dist.mean()[:, 0]
  image = image[:, 0]
  if hasattr(output_dist, "mixture_distribution") and hasattr(
      output_dist, "components_distribution"):
    mask = output_dist.mixture_distribution.probs[:, 0]
    components = output_dist.components_distribution.mean()[:, 0]
    return construct_diagnostic_image(
        image, recons, mask, components, mask_components=mask_components)
  else:
    return construct_reconstr_image(image, recons)


class OnlineMeanVarEstimator(snt.AbstractModule):
  """Online estimator for mean and variance using Welford's algorithm."""

  def __init__(self, axis=None, ddof=0.0, name="online_mean_var"):
    super().__init__(name=name)
    self._axis = axis
    self._ddof = ddof

  def _build(self, x, weights=None):
    if weights is None:
      weights = tf.ones_like(x)
    if weights.get_shape().as_list() != x.get_shape().as_list():
      weights = tf.broadcast_to(weights, x.get_shape().as_list())
    sum_weights = tf.reduce_sum(weights, axis=self._axis)
    shape = sum_weights.get_shape().as_list()
    total = tf.get_variable(
        "total",
        shape=shape,
        dtype=weights.dtype,
        initializer=tf.zeros_initializer(),
        trainable=False)
    mean = tf.get_variable(
        "mean",
        shape=shape,
        dtype=x.dtype,
        initializer=tf.zeros_initializer(),
        trainable=False)
    m2 = tf.get_variable(
        "M2",
        shape=shape,
        dtype=x.dtype,
        initializer=tf.zeros_initializer(),
        trainable=False)
    # Welford update: accumulate the total weight first ...
    total_update = tf.assign_add(total, sum_weights)
    with tf.control_dependencies([total_update]):
      # ... then move the running mean towards x by delta / total ...
      delta = (x - mean) * weights
      mean_update = tf.assign_add(
          mean, tf.reduce_sum(delta, axis=self._axis) / total)
    with tf.control_dependencies([mean_update]):
      # ... and accumulate the second moment against the updated mean.
      delta2 = x - mean
      m2_update = tf.assign_add(
          m2, tf.reduce_sum(delta * delta2, axis=self._axis))
    with tf.control_dependencies([m2_update]):
      return tf.identity(mean), m2 / (total - self._ddof), tf.identity(total)
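
# Illustrative usage of `OnlineMeanVarEstimator` (sketch): streaming
# statistics over the batch axis of a hypothetical (None, 10) input:
#   estimator = OnlineMeanVarEstimator(axis=0)
#   x_ph = tf.placeholder(tf.float32, [None, 10])
#   mean, var, count = estimator(x_ph)
#   # Each sess.run([mean, var, count], {x_ph: batch}) folds another batch
#   # into the running estimates.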
def print_shapes(name, value, indent=""):
if isinstance(value, dict):
print("{}{}:".format(indent, name))
for k, v in sorted(value.items(),
key=lambda x: (isinstance(x[1], dict), x[0])):
print_shapes(k, v, indent + " ")
elif isinstance(value, list):
print(
"{}{}[{}]: {} @ {}".format(
indent, name, len(value), value[0].shape, value[0].dtype
)
)
elif isinstance(value, np.ndarray):
print("{}{}: {} @ {}".format(indent, name, value.shape, value.dtype))
elif isinstance(value, tf.Tensor):
print(
"{}{}: {} @ {}".format(
indent, name, value.get_shape().as_list(), value.dtype
)
)
elif np.isscalar(value):
print("{}{}: {}".format(indent, name, value))
else:
print("{}{}.type: {}".format(indent, name, type(value)))


def _pad_images(images, image_border_value=0.5, border_width=2):
  """Pad images to create gray borders.

  Args:
    images: Tensor of shape [B, H], [B, H, W], or [B, H, W, C].
    image_border_value: Scalar value of the greyscale border for images.
    border_width: Int. Border width in pixels.

  Raises:
    ValueError: if the image provided is not {2,3,4}-dimensional.

  Returns:
    Tensor of the same shape as images, except with H and W enlarged to
    H + 2 * border_width and W + 2 * border_width.
  """
  image_rank = len(images.get_shape())
  border_paddings = (border_width, border_width)
  if image_rank == 2:  # [B, H]
    paddings = [(0, 0), border_paddings]
  elif image_rank == 3:  # [B, H, W]
    paddings = [(0, 0), border_paddings, border_paddings]
  elif image_rank == 4:  # [B, H, W, C]
    paddings = [(0, 0), border_paddings, border_paddings, (0, 0)]
  else:
    raise ValueError("expected image to be 2D, 3D or 4D, got %d" % image_rank)
  paddings = tf.constant(paddings)
  return tf.pad(images, paddings, "CONSTANT",
                constant_values=image_border_value)
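
# Shape sketch for `_pad_images` (hypothetical input):
#   x = tf.zeros([8, 28, 28, 1])
#   _pad_images(x)  # -> (8, 32, 32, 1), with a grey 0.5 border of width 2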


def images_to_grid(
    images,
    grid_height=None,
    grid_width=4,
    max_grid_height=4,
    max_grid_width=4,
    image_border_value=0.5,
):
  """Combine images and arrange them in a grid.

  Args:
    images: Tensor of shape [B, H], [B, H, W], or [B, H, W, C].
    grid_height: Height of the grid of images to output, or None. Either
      `grid_width` or `grid_height` must be set to an integer value.
      If None, `grid_height` is set to ceil(B / `grid_width`), and capped at
      `max_grid_height` when provided.
    grid_width: Width of the grid of images to output, or None. Either
      `grid_width` or `grid_height` must be set to an integer value.
      If None, `grid_width` is set to ceil(B / `grid_height`), and capped at
      `max_grid_width` when provided.
    max_grid_height: Maximum allowable height of the grid of images to
      output, or None. Only used when `grid_height` is None.
    max_grid_width: Maximum allowable width of the grid of images to output,
      or None. Only used when `grid_width` is None.
    image_border_value: None or scalar value of the greyscale border for
      images. If None, then no border is rendered.

  Raises:
    ValueError: if neither grid_width nor grid_height is set to a positive
      integer.

  Returns:
    images: Tensor of shape [grid_height * H, grid_width * W, C].
      C will be set to 1 if the input was provided with no channels.
      Contains all input images in a grid.
  """
  # If only one dimension is set, infer how big the other one should be.
  if grid_height is None:
    if not isinstance(grid_width, int) or grid_width <= 0:
      raise ValueError(
          "if `grid_height` is None, `grid_width` must be a positive integer")
    grid_height = int(math.ceil(images.get_shape()[0].value / grid_width))
    if max_grid_height is not None:
      grid_height = min(max_grid_height, grid_height)
  if grid_width is None:
    if not isinstance(grid_height, int) or grid_height <= 0:
      raise ValueError(
          "if `grid_width` is None, `grid_height` must be a positive integer")
    grid_width = int(math.ceil(images.get_shape()[0].value / grid_height))
    if max_grid_width is not None:
      grid_width = min(max_grid_width, grid_width)
  images = images[:grid_height * grid_width, ...]
  # Pad with extra blank frames if fewer frames were provided than
  # grid_height x grid_width.
  pre_images_shape = images.get_shape().as_list()
  if pre_images_shape[0] < grid_height * grid_width:
    pre_images_shape[0] = grid_height * grid_width - pre_images_shape[0]
    if image_border_value is not None:
      dummy_frames = image_border_value * tf.ones(
          shape=pre_images_shape, dtype=images.dtype)
    else:
      dummy_frames = tf.zeros(shape=pre_images_shape, dtype=images.dtype)
    images = tf.concat([images, dummy_frames], axis=0)
  if image_border_value:
    images = _pad_images(images, image_border_value=image_border_value)
  images_shape = images.get_shape().as_list()
  images = tf.reshape(images, [grid_height, grid_width] + images_shape[1:])
  if len(images_shape) == 2:
    images = tf.expand_dims(images, -1)
  if len(images_shape) <= 3:
    images = tf.expand_dims(images, -1)
  image_height, image_width, channels = images.get_shape().as_list()[2:]
  images = tf.transpose(images, perm=[0, 2, 1, 3, 4])
  images = tf.reshape(
      images, [grid_height * image_height, grid_width * image_width, channels])
  return images
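
# Illustrative usage of `images_to_grid` (sketch): 16 images tile into a
# 4x4 grid with the default grey borders:
#   batch = tf.zeros([16, 28, 28, 3])
#   grid = images_to_grid(batch)  # -> (4 * 32, 4 * 32, 3)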


def flatten_all_but_last(tensor, n_dims=1):
  """Flatten all leading axes of `tensor`, keeping the last `n_dims` intact.

  Returns the flattened tensor and an `unflatten` function that restores the
  original leading axes on a compatible tensor.
  """
  shape = tensor.shape.as_list()
  batch_dims = shape[:-n_dims]
  flat_tensor = tf.reshape(tensor, [np.prod(batch_dims)] + shape[-n_dims:])

  def unflatten(other_tensor):
    other_shape = other_tensor.shape.as_list()
    return tf.reshape(other_tensor, batch_dims + other_shape[1:])

  return flat_tensor, unflatten
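
# Illustrative usage of `flatten_all_but_last` (sketch): merge the leading
# axes before a per-feature op, then restore them:
#   x = tf.zeros([4, 7, 3])
#   flat, unflatten = flatten_all_but_last(x)  # flat: (28, 3)
#   y = unflatten(flat * 2.0)                  # y: (4, 7, 3)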


def ensure_3d(tensor):
  """Append a trailing axis to rank-2 tensors; pass rank-3 tensors through."""
  if tensor.shape.ndims == 2:
    return tensor[..., None]
  assert tensor.shape.ndims == 3
  return tensor


built_element_cache = {
    "none": None,
    "global_step": tf.train.get_or_create_global_step(),
}


def build(plan, identifier):
  """Recursively build a plan dict by resolving and calling its constructors."""
  logging.debug("building %s", identifier)
  if identifier in built_element_cache:
    logging.debug("%s is already built, returning", identifier)
    return built_element_cache[identifier]
  elif not isinstance(plan, dict):
    return plan
  elif "constructor" in plan:
    ctor = _resolve_constructor(plan)
    kwargs = {
        k: build(v, identifier=k) for k, v in plan.items()
        if k != "constructor"
    }
    with tf.variable_scope(identifier):
      built_element_cache[identifier] = ctor(**kwargs)
      return built_element_cache[identifier]
  else:
    return {k: build(v, identifier=k) for k, v in plan.items()}


def _resolve_constructor(plan_subsection):
  """Return the constructor, importing it from a dotted path if needed."""
  assert "constructor" in plan_subsection, plan_subsection
  if isinstance(plan_subsection["constructor"], str):
    module, _, ctor = plan_subsection["constructor"].rpartition(".")
    mod = importlib.import_module(module)
    return getattr(mod, ctor)
  else:
    return plan_subsection["constructor"]