mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-27 02:05:41 +08:00
Release of IODINE
PiperOrigin-RevId: 299101887
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,264 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Data loading functionality for IODINE."""
|
||||
# pylint: disable=g-multiple-import, missing-docstring, unused-import
|
||||
import os.path
|
||||
|
||||
from iodine.modules.utils import flatten_all_but_last, ensure_3d
|
||||
from multi_object_datasets import (
|
||||
clevr_with_masks,
|
||||
multi_dsprites,
|
||||
tetrominoes,
|
||||
objects_room,
|
||||
)
|
||||
from shapeguard import ShapeGuard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class IODINEDataset(snt.AbstractModule):
|
||||
num_true_objects = 1
|
||||
num_channels = 3
|
||||
|
||||
factors = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path,
|
||||
batch_size,
|
||||
image_dim,
|
||||
crop_region=None,
|
||||
shuffle_buffer=1000,
|
||||
max_num_objects=None,
|
||||
min_num_objects=None,
|
||||
grayscale=False,
|
||||
name="dataset",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self.path = os.path.abspath(os.path.expanduser(path))
|
||||
self.batch_size = batch_size
|
||||
self.crop_region = crop_region
|
||||
self.image_dim = image_dim
|
||||
self.shuffle_buffer = shuffle_buffer
|
||||
self.max_num_objects = max_num_objects
|
||||
self.min_num_objects = min_num_objects
|
||||
self.grayscale = grayscale
|
||||
self.dataset = None
|
||||
|
||||
def _build(self, subset="train"):
|
||||
dataset = self.dataset
|
||||
|
||||
# filter by number of objects
|
||||
if self.max_num_objects is not None or self.min_num_objects is not None:
|
||||
dataset = self.dataset.filter(self.filter_by_num_objects)
|
||||
|
||||
if subset == "train":
|
||||
# normal mode returns a shuffled dataset iterator
|
||||
if self.shuffle_buffer is not None:
|
||||
dataset = dataset.shuffle(self.shuffle_buffer)
|
||||
elif subset == "summary":
|
||||
# for generating summaries and overview images
|
||||
# returns a single fixed batch
|
||||
dataset = dataset.take(self.batch_size)
|
||||
|
||||
# repeat and batch
|
||||
dataset = dataset.repeat().batch(self.batch_size, drop_remainder=True)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
data = iterator.get_next()
|
||||
|
||||
# preprocess the data to ensure correct format, scale images etc.
|
||||
data = self.preprocess(data)
|
||||
return data
|
||||
|
||||
def filter_by_num_objects(self, d):
|
||||
if "visibility" not in d:
|
||||
return tf.constant(True)
|
||||
min_num_objects = self.max_num_objects or 0
|
||||
max_num_objects = self.max_num_objects or 6
|
||||
|
||||
min_predicate = tf.greater_equal(
|
||||
tf.reduce_sum(d["visibility"]),
|
||||
tf.constant(min_num_objects - 1e-5, dtype=tf.float32),
|
||||
)
|
||||
max_predicate = tf.less_equal(
|
||||
tf.reduce_sum(d["visibility"]),
|
||||
tf.constant(max_num_objects + 1e-5, dtype=tf.float32),
|
||||
)
|
||||
return tf.logical_and(min_predicate, max_predicate)
|
||||
|
||||
def preprocess(self, data):
|
||||
sg = ShapeGuard(dims={
|
||||
"B": self.batch_size,
|
||||
"H": self.image_dim[0],
|
||||
"W": self.image_dim[1]
|
||||
})
|
||||
image = sg.guard(data["image"], "B, h, w, C")
|
||||
mask = sg.guard(data["mask"], "B, L, h, w, 1")
|
||||
|
||||
# to float
|
||||
image = tf.cast(image, tf.float32) / 255.0
|
||||
mask = tf.cast(mask, tf.float32) / 255.0
|
||||
|
||||
# crop
|
||||
if self.crop_region is not None:
|
||||
height_slice = slice(self.crop_region[0][0], self.crop_region[0][1])
|
||||
width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
|
||||
image = image[:, height_slice, width_slice, :]
|
||||
|
||||
mask = mask[:, :, height_slice, width_slice, :]
|
||||
|
||||
flat_mask, unflatten = flatten_all_but_last(mask, n_dims=3)
|
||||
|
||||
# rescale
|
||||
size = tf.constant(
|
||||
self.image_dim, dtype=tf.int32, shape=[2], verify_shape=True)
|
||||
image = tf.image.resize_images(
|
||||
image, size, method=tf.image.ResizeMethod.BILINEAR)
|
||||
mask = tf.image.resize_images(
|
||||
flat_mask, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||
|
||||
if self.grayscale:
|
||||
image = tf.reduce_mean(image, axis=-1, keepdims=True)
|
||||
|
||||
output = {
|
||||
"image": sg.guard(image[:, None], "B, T, H, W, C"),
|
||||
"mask": sg.guard(unflatten(mask)[:, None], "B, T, L, H, W, 1"),
|
||||
"factors": self.preprocess_factors(data, sg),
|
||||
}
|
||||
|
||||
if "visibility" in data:
|
||||
output["visibility"] = sg.guard(data["visibility"], "B, L")
|
||||
else:
|
||||
output["visibility"] = tf.ones(sg["B, L"], dtype=tf.float32)
|
||||
|
||||
return output
|
||||
|
||||
def preprocess_factors(self, data, sg):
|
||||
return {
|
||||
name: sg.guard(ensure_3d(data[name]), "B, L, *")
|
||||
for name in self.factors
|
||||
}
|
||||
|
||||
def get_placeholders(self, batch_size=None):
|
||||
batch_size = batch_size or self.batch_size
|
||||
sg = ShapeGuard(
|
||||
dims={
|
||||
"B": batch_size,
|
||||
"H": self.image_dim[0],
|
||||
"W": self.image_dim[1],
|
||||
"L": self.num_true_objects,
|
||||
"C": 3,
|
||||
"T": 1,
|
||||
})
|
||||
return {
|
||||
"image": tf.placeholder(dtype=tf.float32, shape=sg["B, T, H, W, C"]),
|
||||
"mask": tf.placeholder(dtype=tf.float32, shape=sg["B, T, L, H, W, 1"]),
|
||||
"visibility": tf.placeholder(dtype=tf.float32, shape=sg["B, L"]),
|
||||
"factors": {
|
||||
name:
|
||||
tf.placeholder(dtype=dtype, shape=sg["B, L, {}".format(size)])
|
||||
for name, (dtype, size) in self.factors
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CLEVR(IODINEDataset):
|
||||
num_true_objects = 11
|
||||
num_channels = 3
|
||||
factors = {
|
||||
"color": (tf.uint8, 1),
|
||||
"shape": (tf.uint8, 1),
|
||||
"size": (tf.uint8, 1),
|
||||
"position": (tf.float32, 3),
|
||||
"rotation": (tf.float32, 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path,
|
||||
crop_region=((29, 221), (64, 256)),
|
||||
image_dim=(128, 128),
|
||||
name="clevr",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
path=path,
|
||||
crop_region=crop_region,
|
||||
image_dim=image_dim,
|
||||
name=name,
|
||||
**kwargs)
|
||||
self.dataset = clevr_with_masks.dataset(self.path)
|
||||
|
||||
def preprocess_factors(self, data, sg):
|
||||
|
||||
return {
|
||||
"color": sg.guard(ensure_3d(data["color"]), "B, L, 1"),
|
||||
"shape": sg.guard(ensure_3d(data["shape"]), "B, L, 1"),
|
||||
"size": sg.guard(ensure_3d(data["color"]), "B, L, 1"),
|
||||
"position": sg.guard(ensure_3d(data["pixel_coords"]), "B, L, 3"),
|
||||
"rotation": sg.guard(ensure_3d(data["rotation"]), "B, L, 1"),
|
||||
}
|
||||
|
||||
|
||||
class MultiDSprites(IODINEDataset):
|
||||
num_true_objects = 6
|
||||
num_channels = 3
|
||||
factors = {
|
||||
"color": (tf.float32, 3),
|
||||
"shape": (tf.uint8, 1),
|
||||
"scale": (tf.float32, 1),
|
||||
"x": (tf.float32, 1),
|
||||
"y": (tf.float32, 1),
|
||||
"orientation": (tf.float32, 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path,
|
||||
# variant from ['binarized', 'colored_on_grayscale', 'colored_on_colored']
|
||||
dataset_variant="colored_on_grayscale",
|
||||
image_dim=(64, 64),
|
||||
name="multi_dsprites",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(path=path, name=name, image_dim=image_dim, **kwargs)
|
||||
self.dataset_variant = dataset_variant
|
||||
self.dataset = multi_dsprites.dataset(self.path, self.dataset_variant)
|
||||
|
||||
|
||||
class Tetrominoes(IODINEDataset):
|
||||
num_true_objects = 6
|
||||
num_channels = 3
|
||||
factors = {
|
||||
"color": (tf.uint8, 3),
|
||||
"shape": (tf.uint8, 1),
|
||||
"position": (tf.float32, 2),
|
||||
}
|
||||
|
||||
def __init__(self, path, image_dim=(35, 35), name="tetrominoes", **kwargs):
|
||||
super().__init__(path=path, name=name, image_dim=image_dim, **kwargs)
|
||||
self.dataset = tetrominoes.dataset(self.path)
|
||||
|
||||
def preprocess_factors(self, data, sg):
|
||||
pos_x = ensure_3d(data["x"])
|
||||
pos_y = ensure_3d(data["y"])
|
||||
position = tf.concat([pos_x, pos_y], axis=2)
|
||||
|
||||
return {
|
||||
"color": sg.guard(ensure_3d(data["color"]), "B, L, 3"),
|
||||
"shape": sg.guard(ensure_3d(data["shape"]), "B, L, 1"),
|
||||
"position": sg.guard(ensure_3d(position), "B, L, 2"),
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Decoders for rendering images."""
|
||||
# pylint: disable=missing-docstring
|
||||
from iodine.modules.distributions import MixtureParameters
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
|
||||
|
||||
class ComponentDecoder(snt.AbstractModule):
|
||||
|
||||
def __init__(self, pixel_decoder, name="component_decoder"):
|
||||
super().__init__(name=name)
|
||||
self._pixel_decoder = pixel_decoder
|
||||
self._sg = shapeguard.ShapeGuard()
|
||||
|
||||
def set_output_shapes(self, pixel, mask):
|
||||
self._sg.guard(pixel, "K, H, W, Cp")
|
||||
self._sg.guard(mask, "K, H, W, 1")
|
||||
self._pixel_decoder.set_output_shapes(self._sg["H, W, 1 + Cp"])
|
||||
|
||||
def _build(self, z):
|
||||
self._sg.guard(z, "B, K, Z")
|
||||
z_flat = self._sg.reshape(z, "B*K, Z")
|
||||
pixel_params = self._pixel_decoder(z_flat).params
|
||||
|
||||
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
|
||||
mask_params = pixel_params[Ellipsis, 0:1]
|
||||
pixel_params = pixel_params[Ellipsis, 1:]
|
||||
|
||||
output = MixtureParameters(
|
||||
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
|
||||
mask=self._sg.reshape(mask_params, "B, K, H, W, 1"),
|
||||
)
|
||||
|
||||
del self._sg.B
|
||||
return output
|
||||
@@ -0,0 +1,223 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Collection of sonnet modules that wrap useful distributions."""
|
||||
# pylint: disable=missing-docstring, g-doc-args, g-short-docstring-punctuation
|
||||
# pylint: disable=g-space-before-docstring-summary
|
||||
# pylint: disable=g-no-space-after-docstring-summary
|
||||
import collections
|
||||
from iodine.modules.utils import get_act_func
|
||||
from iodine.modules.utils import get_distribution
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
|
||||
tfd = tfp.distributions
|
||||
|
||||
FlatParameters = collections.namedtuple("ParameterOut", ["params"])
|
||||
MixtureParameters = collections.namedtuple("MixtureOut", ["pixel", "mask"])
|
||||
|
||||
|
||||
class DistributionModule(snt.AbstractModule):
|
||||
"""Distribution Base class supporting shape inference & default priors."""
|
||||
|
||||
def __init__(self, name="distribution"):
|
||||
super().__init__(name=name)
|
||||
self._output_shape = None
|
||||
|
||||
def set_output_shape(self, shape):
|
||||
self._output_shape = shape
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return self._output_shape
|
||||
|
||||
@property
|
||||
def input_shapes(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_default_prior(self, batch_dim=(1,)):
|
||||
return self(
|
||||
tf.zeros(list(batch_dim) + self.input_shapes.params, dtype=tf.float32))
|
||||
|
||||
|
||||
class BernoulliOutput(DistributionModule):
|
||||
|
||||
def __init__(self, name="bernoulli_output"):
|
||||
super().__init__(name=name)
|
||||
|
||||
@property
|
||||
def input_shapes(self):
|
||||
return FlatParameters(self.output_shape)
|
||||
|
||||
def _build(self, params):
|
||||
return tfd.Independent(
|
||||
tfd.Bernoulli(logits=params, dtype=tf.float32),
|
||||
reinterpreted_batch_ndims=1)
|
||||
|
||||
|
||||
class LocScaleDistribution(DistributionModule):
|
||||
"""Generic IID location / scale distribution.
|
||||
|
||||
Input parameters are concatenation of location and scale (2*Z,)
|
||||
|
||||
Args:
|
||||
dist: Distribution or str Kind of distribution used. Supports Normal,
|
||||
Logistic, Laplace, and StudentT distributions.
|
||||
dist_kwargs: dict custom keyword arguments for the distribution
|
||||
scale_act: function or str or None activation function to be applied to
|
||||
the scale input
|
||||
scale: str
|
||||
different modes for computing the scale:
|
||||
* stddev: scale is computed as scale_act(s)
|
||||
* var: scale is computed as sqrt(scale_act(s))
|
||||
* prec: scale is computed as 1./scale_act(s)
|
||||
* fixed: scale is a global variable (same for all pixels) if
|
||||
scale_val==-1. then it is a trainable variable initialized to 0.1
|
||||
else it is fixed to scale_val (input shape is only (Z,) in this
|
||||
case)
|
||||
scale_val: float determines the scale value (only used if scale=='fixed').
|
||||
loc_act: function or str or None activation function to be applied to the
|
||||
location input. Supports optional activation functions for scale and
|
||||
location.
|
||||
Supports different "modes" for scaling:
|
||||
* stddev:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dist=tfd.Normal,
|
||||
dist_kwargs=None,
|
||||
scale_act=tf.exp,
|
||||
scale="stddev",
|
||||
scale_val=1.0,
|
||||
loc_act=None,
|
||||
name="loc_scale_dist",
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self._scale_act = get_act_func(scale_act)
|
||||
self._loc_act = get_act_func(loc_act)
|
||||
# supports Normal, Logstic, Laplace, StudentT
|
||||
self._dist = get_distribution(dist)
|
||||
self._dist_kwargs = dist_kwargs or {}
|
||||
|
||||
assert scale in ["stddev", "var", "prec", "fixed"], scale
|
||||
self._scale = scale
|
||||
self._scale_val = scale_val
|
||||
|
||||
@property
|
||||
def input_shapes(self):
|
||||
if self._scale == "fixed":
|
||||
param_shape = self.output_shape
|
||||
else:
|
||||
param_shape = self.output_shape[:-1] + [self.output_shape[-1] * 2]
|
||||
return FlatParameters(param_shape)
|
||||
|
||||
def _build(self, params):
|
||||
if self._scale == "fixed":
|
||||
loc = params
|
||||
scale = None # set later
|
||||
else:
|
||||
n_channels = params.get_shape().as_list()[-1]
|
||||
assert n_channels % 2 == 0
|
||||
assert n_channels // 2 == self.output_shape[-1]
|
||||
loc = params[Ellipsis, :n_channels // 2]
|
||||
scale = params[Ellipsis, n_channels // 2:]
|
||||
|
||||
# apply activation functions
|
||||
if self._scale != "fixed":
|
||||
scale = self._scale_act(scale)
|
||||
loc = self._loc_act(loc)
|
||||
|
||||
# apply the correct parametrization
|
||||
if self._scale == "var":
|
||||
scale = tf.sqrt(scale)
|
||||
elif self._scale == "prec":
|
||||
scale = tf.reciprocal(scale)
|
||||
elif self._scale == "fixed":
|
||||
if self._scale_val == -1.0:
|
||||
scale_val = tf.get_variable(
|
||||
"scale", initializer=tf.constant(0.1, dtype=tf.float32))
|
||||
else:
|
||||
scale_val = self._scale_val
|
||||
scale = tf.ones_like(loc) * scale_val
|
||||
# else 'stddev'
|
||||
|
||||
dist = self._dist(loc=loc, scale=scale, **self._dist_kwargs)
|
||||
|
||||
return tfd.Independent(dist, reinterpreted_batch_ndims=1)
|
||||
|
||||
|
||||
class MaskedMixture(DistributionModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_components,
|
||||
component_dist,
|
||||
mask_activation=None,
|
||||
name="masked_mixture",
|
||||
):
|
||||
"""
|
||||
Spatial Mixture Model composed of a categorical masking distribution and
|
||||
a custom pixel-wise component distribution (usually logistic or
|
||||
gaussian).
|
||||
|
||||
Args:
|
||||
num_components: int Number of mixture components >= 2
|
||||
component_dist: the distribution to use for the individual components
|
||||
mask_activation: str or function or None activation function that
|
||||
should be applied to the mask before the softmax.
|
||||
name: str
|
||||
"""
|
||||
|
||||
super().__init__(name=name)
|
||||
self._num_components = num_components
|
||||
self._dist = component_dist
|
||||
self._mask_activation = get_act_func(mask_activation)
|
||||
|
||||
def set_output_shape(self, shape):
|
||||
super().set_output_shape(shape)
|
||||
self._dist.set_output_shape(shape)
|
||||
|
||||
def _build(self, pixel, mask):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
# MASKING
|
||||
sg.guard(mask, "B, K, H, W, 1")
|
||||
mask = tf.transpose(mask, perm=[0, 2, 3, 4, 1])
|
||||
mask = sg.reshape(mask, "B, H, W, K")
|
||||
mask = self._mask_activation(mask)
|
||||
mask = mask[:, tf.newaxis] # add K=1 axis since K is removed by mixture
|
||||
mix_dist = tfd.Categorical(logits=mask)
|
||||
|
||||
# COMPONENTS
|
||||
sg.guard(pixel, "B, K, H, W, Cp")
|
||||
params = tf.transpose(pixel, perm=[0, 2, 3, 1, 4])
|
||||
params = params[:, tf.newaxis] # add K=1 axis since K is removed by mixture
|
||||
dist = self._dist(params)
|
||||
return tfd.MixtureSameFamily(
|
||||
mixture_distribution=mix_dist, components_distribution=dist)
|
||||
|
||||
@property
|
||||
def input_shapes(self):
|
||||
pixel = [self._num_components] + self._dist.input_shapes.params
|
||||
mask = pixel[:-1] + [1]
|
||||
return MixtureParameters(pixel, mask)
|
||||
|
||||
def get_default_prior(self, batch_dim=(1,)):
|
||||
pixel = tf.zeros(
|
||||
list(batch_dim) + self.input_shapes.pixel, dtype=tf.float32)
|
||||
mask = tf.zeros(list(batch_dim) + self.input_shapes.mask, dtype=tf.float32)
|
||||
return self(pixel, mask)
|
||||
@@ -0,0 +1,206 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Factor Evaluation Module."""
|
||||
# pylint: disable=unused-variable
|
||||
|
||||
import collections
|
||||
import functools
|
||||
from iodine.modules import utils
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
Factor = collections.namedtuple("Factor", ["name", "size", "type"])
|
||||
|
||||
|
||||
class FactorRegressor(snt.AbstractModule):
|
||||
"""Assess representations by learning a linear mapping to latents."""
|
||||
|
||||
def __init__(self, mapping=None, name="repres_content"):
|
||||
super().__init__(name=name)
|
||||
if mapping is None:
|
||||
self._mapping = [
|
||||
Factor("color", 3, "scalar"),
|
||||
Factor("shape", 4, "categorical"),
|
||||
Factor("scale", 1, "scalar"),
|
||||
Factor("x", 1, "scalar"),
|
||||
Factor("y", 1, "scalar"),
|
||||
Factor("orientation", 2, "angle"),
|
||||
]
|
||||
else:
|
||||
self._mapping = [Factor(*m) for m in mapping]
|
||||
|
||||
def _build(self, z, latent, visibility, pred_mask, true_mask):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
z = sg.guard(z, "B, K, Z")
|
||||
pred_mask = sg.guard(pred_mask, "B, K, H, W, 1")
|
||||
true_mask = sg.guard(true_mask, "B, L, H, W, 1")
|
||||
|
||||
visibility = sg.guard(visibility, "B, L")
|
||||
num_visible_obj = tf.reduce_sum(visibility)
|
||||
|
||||
# Map z to predictions for all latents
|
||||
sg.M = sum([m.size for m in self._mapping])
|
||||
self.predictor = snt.Linear(sg.M, name="predict_latents")
|
||||
z_flat = sg.reshape(z, "B*K, Z")
|
||||
all_preds = sg.guard(self.predictor(z_flat), "B*K, M")
|
||||
all_preds = sg.reshape(all_preds, "B, 1, K, M")
|
||||
all_preds = tf.tile(all_preds, sg["1, L, 1, 1"])
|
||||
|
||||
# prepare latents
|
||||
latents = {}
|
||||
mean_var_tot = {}
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
# preprocess, reshape, and tile
|
||||
lat_preprocess = self.get_preprocessing(m)
|
||||
lat = sg.guard(
|
||||
lat_preprocess(latent[m.name]), "B, L, {}".format(m.size))
|
||||
# compute mean over latent by training a variable using mse
|
||||
if m.type in {"scalar", "angle"}:
|
||||
mvt = utils.OnlineMeanVarEstimator(
|
||||
axis=[0, 1], ddof=1, name="{}_mean_var".format(m.name))
|
||||
mean_var_tot[m.name] = mvt(lat, visibility[:, :, tf.newaxis])
|
||||
|
||||
lat = tf.reshape(lat, sg["B, L, 1"] + [-1])
|
||||
lat = tf.tile(lat, sg["1, 1, K, 1"])
|
||||
latents[m.name] = lat
|
||||
|
||||
# prepare predictions
|
||||
idx = 0
|
||||
predictions = {}
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
|
||||
pred = all_preds[Ellipsis, idx:idx + m.size]
|
||||
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
|
||||
idx += m.size
|
||||
|
||||
# compute error
|
||||
total_pairwise_errors = None
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
error_fn = self.get_error_func(m)
|
||||
sg.guard(latents[m.name], "B, L, K, {}".format(m.size))
|
||||
sg.guard(predictions[m.name], "B, L, K, {}".format(m.size))
|
||||
err = error_fn(latents[m.name], predictions[m.name])
|
||||
sg.guard(err, "B, L, K")
|
||||
if total_pairwise_errors is None:
|
||||
total_pairwise_errors = err
|
||||
else:
|
||||
total_pairwise_errors += err
|
||||
|
||||
# determine best assignment by comparing masks
|
||||
obj_mask = true_mask[:, :, tf.newaxis]
|
||||
pred_mask = pred_mask[:, tf.newaxis]
|
||||
pairwise_overlap = tf.reduce_sum(obj_mask * pred_mask, axis=[3, 4, 5])
|
||||
best_match = sg.guard(tf.argmax(pairwise_overlap, axis=2), "B, L")
|
||||
assignment = tf.one_hot(best_match, sg.K)
|
||||
assignment *= visibility[:, :, tf.newaxis] # Mask non-visible objects
|
||||
|
||||
# total error
|
||||
total_error = (
|
||||
tf.reduce_sum(assignment * total_pairwise_errors) / num_visible_obj)
|
||||
|
||||
# compute scalars
|
||||
monitored_scalars = {}
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
metric = self.get_metric(m)
|
||||
scalar = metric(
|
||||
latents[m.name],
|
||||
predictions[m.name],
|
||||
assignment[:, :, :, tf.newaxis],
|
||||
mean_var_tot.get(m.name),
|
||||
num_visible_obj,
|
||||
)
|
||||
monitored_scalars[m.name] = scalar
|
||||
return total_error, monitored_scalars, mean_var_tot, predictions, assignment
|
||||
|
||||
@snt.reuse_variables
|
||||
def predict(self, z):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
z = sg.guard(z, "B, Z")
|
||||
all_preds = sg.guard(self.predictor(z), "B, M")
|
||||
|
||||
idx = 0
|
||||
predictions = {}
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
pred = all_preds[:, idx:idx + m.size]
|
||||
predictions[m.name] = sg.guard(pred, "B, {}".format(m.size))
|
||||
idx += m.size
|
||||
return predictions
|
||||
|
||||
@staticmethod
|
||||
def get_error_func(factor):
|
||||
if factor.type in {"scalar", "angle"}:
|
||||
return sse
|
||||
elif factor.type == "categorical":
|
||||
return functools.partial(
|
||||
tf.losses.softmax_cross_entropy, reduction="none")
|
||||
else:
|
||||
raise KeyError(factor.type)
|
||||
|
||||
@staticmethod
|
||||
def get_metric(factor):
|
||||
if factor.type in {"scalar", "angle"}:
|
||||
return r2
|
||||
elif factor.type == "categorical":
|
||||
return accuracy
|
||||
else:
|
||||
raise KeyError(factor.type)
|
||||
|
||||
@staticmethod
|
||||
def one_hot(f, nr_categories):
|
||||
return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories)
|
||||
|
||||
@staticmethod
|
||||
def angle_to_vector(theta):
|
||||
return tf.concat([tf.math.cos(theta), tf.math.sin(theta)], axis=-1)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessing(factor):
|
||||
if factor.type == "scalar":
|
||||
return tf.identity
|
||||
elif factor.type == "categorical":
|
||||
return functools.partial(
|
||||
FactorRegressor.one_hot, nr_categories=factor.size)
|
||||
elif factor.type == "angle":
|
||||
return FactorRegressor.angle_to_vector
|
||||
else:
|
||||
raise KeyError(factor.type)
|
||||
|
||||
|
||||
def sse(true, pred):
|
||||
# run our own sum squared error because we want to reduce sum over last dim
|
||||
return tf.reduce_sum(tf.square(true - pred), axis=-1)
|
||||
|
||||
|
||||
def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
|
||||
del mean_var_tot # unused
|
||||
pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
||||
labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
|
||||
correct = tf.cast(tf.equal(labels, pred), tf.float32)
|
||||
return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis
|
||||
|
||||
|
||||
def r2(labels, pred, assignment, mean_var_tot, num_vis):
|
||||
del num_vis # unused
|
||||
mean, var, _ = mean_var_tot
|
||||
# labels, pred: (B, L, K, n)
|
||||
ss_res = tf.reduce_sum(tf.square(labels - pred) * assignment, axis=2)
|
||||
ss_tot = var[tf.newaxis, tf.newaxis, :] # (1, 1, n)
|
||||
return tf.reduce_mean(1.0 - ss_res / ss_tot)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,300 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Network modules."""
|
||||
# pylint: disable=g-multiple-import, g-doc-args, g-short-docstring-punctuation
|
||||
# pylint: disable=g-no-space-after-docstring-summary
|
||||
from iodine.modules.distributions import FlatParameters
|
||||
from iodine.modules.utils import flatten_all_but_last, get_act_func
|
||||
import numpy as np
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class CNN(snt.AbstractModule):
|
||||
"""ConvNet2D followed by an MLP.
|
||||
|
||||
This is a typical encoder architecture for VAEs, and has been found to work
|
||||
well. One small improvement is to append coordinate channels on the input,
|
||||
though for most datasets the improvement obtained is negligible.
|
||||
"""
|
||||
|
||||
def __init__(self, cnn_opt, mlp_opt, mode="flatten", name="cnn"):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
cnn_opt: Dictionary. Kwargs for the cnn. See vae_lib.ConvNet2D for
|
||||
details.
|
||||
mlp_opt: Dictionary. Kwargs for the mlp. See vae_lib.MLP for details.
|
||||
name: String. Optional name.
|
||||
"""
|
||||
super().__init__(name=name)
|
||||
if "activation" in cnn_opt:
|
||||
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
|
||||
self._cnn_opt = cnn_opt
|
||||
|
||||
if "activation" in mlp_opt:
|
||||
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||
self._mlp_opt = mlp_opt
|
||||
|
||||
self._mode = mode
|
||||
|
||||
def set_output_shapes(self, shape):
|
||||
# assert self._mlp_opt['output_sizes'][-1] is None, self._mlp_opt
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(shape, "1, Y")
|
||||
self._mlp_opt["output_sizes"][-1] = sg.Y
|
||||
|
||||
def _build(self, image):
|
||||
"""Connect model to TensorFlow graph."""
|
||||
assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
|
||||
sg = shapeguard.ShapeGuard()
|
||||
flat_image, unflatten = flatten_all_but_last(image, n_dims=3)
|
||||
sg.guard(flat_image, "B, H, W, C")
|
||||
|
||||
cnn = snt.nets.ConvNet2D(
|
||||
activate_final=True,
|
||||
paddings=("SAME",),
|
||||
normalize_final=False,
|
||||
**self._cnn_opt)
|
||||
mlp = snt.nets.MLP(**self._mlp_opt)
|
||||
|
||||
# run CNN
|
||||
net = cnn(flat_image)
|
||||
|
||||
if self._mode == "flatten":
|
||||
# flatten
|
||||
net_shape = net.get_shape().as_list()
|
||||
flat_shape = net_shape[:-3] + [np.prod(net_shape[-3:])]
|
||||
net = tf.reshape(net, flat_shape)
|
||||
elif self._mode == "avg_pool":
|
||||
net = tf.reduce_mean(net, axis=[1, 2])
|
||||
else:
|
||||
raise KeyError('Unknown mode "{}"'.format(self._mode))
|
||||
# run MLP
|
||||
output = sg.guard(mlp(net), "B, Y")
|
||||
return FlatParameters(unflatten(output))
|
||||
|
||||
|
||||
class MLP(snt.AbstractModule):
|
||||
"""MLP."""
|
||||
|
||||
def __init__(self, name="mlp", **mlp_opt):
|
||||
super().__init__(name=name)
|
||||
if "activation" in mlp_opt:
|
||||
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||
self._mlp_opt = mlp_opt
|
||||
assert mlp_opt["output_sizes"][-1] is None, mlp_opt
|
||||
|
||||
def set_output_shapes(self, shape):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(shape, "1, Y")
|
||||
self._mlp_opt["output_sizes"][-1] = sg.Y
|
||||
|
||||
def _build(self, data):
|
||||
"""Connect model to TensorFlow graph."""
|
||||
assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
|
||||
sg = shapeguard.ShapeGuard()
|
||||
flat_data, unflatten = flatten_all_but_last(data)
|
||||
sg.guard(flat_data, "B, N")
|
||||
|
||||
mlp = snt.nets.MLP(**self._mlp_opt)
|
||||
# run MLP
|
||||
output = sg.guard(mlp(flat_data), "B, Y")
|
||||
return FlatParameters(unflatten(output))
|
||||
|
||||
|
||||
class DeConv(snt.AbstractModule):
|
||||
"""MLP followed by Deconv net.
|
||||
|
||||
This decoder is commonly used by vanilla VAE models. However, in practice
|
||||
BroadcastConv (see below) seems to disentangle slightly better.
|
||||
"""
|
||||
|
||||
def __init__(self, mlp_opt, cnn_opt, name="deconv"):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
mlp_opt: Dictionary. Kwargs for vae_lib.MLP.
|
||||
cnn_opt: Dictionary. Kwargs for vae_lib.ConvNet2D for the CNN.
|
||||
name: Optional name.
|
||||
"""
|
||||
super().__init__(name=name)
|
||||
assert cnn_opt["output_channels"][-1] is None, cnn_opt
|
||||
if "activation" in cnn_opt:
|
||||
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
|
||||
self._cnn_opt = cnn_opt
|
||||
|
||||
if mlp_opt and "activation" in mlp_opt:
|
||||
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||
self._mlp_opt = mlp_opt
|
||||
self._target_out_shape = None
|
||||
|
||||
def set_output_shapes(self, shape):
|
||||
self._target_out_shape = shape
|
||||
self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]
|
||||
|
||||
def _build(self, z):
|
||||
"""Connect model to TensorFlow graph."""
|
||||
sg = shapeguard.ShapeGuard()
|
||||
flat_z, unflatten = flatten_all_but_last(z)
|
||||
sg.guard(flat_z, "B, Z")
|
||||
sg.guard(self._target_out_shape, "H, W, C")
|
||||
mlp = snt.nets.MLP(**self._mlp_opt)
|
||||
cnn = snt.nets.ConvNet2DTranspose(
|
||||
paddings=("SAME",), normalize_final=False, **self._cnn_opt)
|
||||
net = mlp(flat_z)
|
||||
output = sg.guard(cnn(net), "B, H, W, C")
|
||||
return FlatParameters(unflatten(output))
|
||||
|
||||
|
||||
class BroadcastConv(snt.AbstractModule):
|
||||
"""MLP followed by a broadcast convolution.
|
||||
|
||||
This decoder takes a latent vector z, (optionally) applies an MLP to it,
|
||||
then tiles the resulting vector across space to have dimension [B, H, W, C]
|
||||
i.e. tiles across H and W. Then coordinate channels are appended and a
|
||||
convolutional layer is applied.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cnn_opt,
|
||||
mlp_opt=None,
|
||||
coord_type="linear",
|
||||
coord_freqs=3,
|
||||
name="broadcast_conv",
|
||||
):
|
||||
"""Args:
|
||||
cnn_opt: dict Kwargs for vae_lib.ConvNet2D for the CNN.
|
||||
mlp_opt: None or dict If dictionary, then kwargs for snt.nets.MLP. If
|
||||
None, then the model will not process the latent vector by an mlp.
|
||||
coord_type: ["linear", "cos", None] type of coordinate channels to
|
||||
add.
|
||||
None: add no coordinate channels.
|
||||
linear: two channels with values linearly spaced from -1. to 1. in
|
||||
the H and W dimension respectively.
|
||||
cos: coord_freqs^2 many channels containing cosine basis functions.
|
||||
coord_freqs: int number of frequencies used to construct the cosine
|
||||
basis functions (only for coord_type=="cos")
|
||||
name: Optional name.
|
||||
"""
|
||||
super().__init__(name=name)
|
||||
|
||||
assert cnn_opt["output_channels"][-1] is None, cnn_opt
|
||||
if "activation" in cnn_opt:
|
||||
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
|
||||
self._cnn_opt = cnn_opt
|
||||
|
||||
if mlp_opt and "activation" in mlp_opt:
|
||||
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||
self._mlp_opt = mlp_opt
|
||||
|
||||
self._target_out_shape = None
|
||||
self._coord_type = coord_type
|
||||
self._coord_freqs = coord_freqs
|
||||
|
||||
def set_output_shapes(self, shape):
|
||||
self._target_out_shape = shape
|
||||
self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]
|
||||
|
||||
def _build(self, z):
|
||||
"""Connect model to TensorFlow graph."""
|
||||
assert self._target_out_shape is not None, "Call set_output_shape"
|
||||
# reshape components into batch dimension before processing them
|
||||
sg = shapeguard.ShapeGuard()
|
||||
flat_z, unflatten = flatten_all_but_last(z)
|
||||
sg.guard(flat_z, "B, Z")
|
||||
sg.guard(self._target_out_shape, "H, W, C")
|
||||
|
||||
if self._mlp_opt is None:
|
||||
mlp = tf.identity
|
||||
else:
|
||||
mlp = snt.nets.MLP(activate_final=True, **self._mlp_opt)
|
||||
mlp_output = sg.guard(mlp(flat_z), "B, hidden")
|
||||
|
||||
# tile MLP output spatially and append coordinate channels
|
||||
broadcast_mlp_output = tf.tile(
|
||||
mlp_output[:, tf.newaxis, tf.newaxis],
|
||||
multiples=tf.constant(sg["1, H, W, 1"]),
|
||||
) # B, H, W, Z
|
||||
|
||||
dec_cnn_inputs = self.append_coordinate_channels(broadcast_mlp_output)
|
||||
|
||||
cnn = snt.nets.ConvNet2D(
|
||||
paddings=("SAME",), normalize_final=False, **self._cnn_opt)
|
||||
cnn_outputs = cnn(dec_cnn_inputs)
|
||||
sg.guard(cnn_outputs, "B, H, W, C")
|
||||
|
||||
return FlatParameters(unflatten(cnn_outputs))
|
||||
|
||||
def append_coordinate_channels(self, output):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(output, "B, H, W, C")
|
||||
if self._coord_type is None:
|
||||
return output
|
||||
if self._coord_type == "linear":
|
||||
w_coords = tf.linspace(-1.0, 1.0, sg.W)[None, None, :, None]
|
||||
h_coords = tf.linspace(-1.0, 1.0, sg.H)[None, :, None, None]
|
||||
w_coords = tf.tile(w_coords, sg["B, H, 1, 1"])
|
||||
h_coords = tf.tile(h_coords, sg["B, 1, W, 1"])
|
||||
return tf.concat([output, h_coords, w_coords], axis=-1)
|
||||
elif self._coord_type == "cos":
|
||||
freqs = sg.guard(tf.range(0.0, self._coord_freqs), "F")
|
||||
valx = tf.linspace(0.0, np.pi, sg.W)[None, None, :, None, None]
|
||||
valy = tf.linspace(0.0, np.pi, sg.H)[None, :, None, None, None]
|
||||
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
|
||||
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
|
||||
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
|
||||
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
|
||||
return tf.concat([output, coords], axis=-1)
|
||||
else:
|
||||
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
|
||||
|
||||
|
||||
class LSTM(snt.RNNCore):
|
||||
"""Wrapper around snt.LSTM that supports multi-layers and runs K components in
|
||||
parallel.
|
||||
|
||||
Expects input data of shape (B, K, H) and outputs data of shape (B, K, Y)
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_sizes, name="lstm"):
|
||||
super().__init__(name=name)
|
||||
self._hidden_sizes = hidden_sizes
|
||||
with self._enter_variable_scope():
|
||||
self._lstm_layers = [snt.LSTM(hidden_size=h) for h in self._hidden_sizes]
|
||||
|
||||
def initial_state(self, batch_size, **kwargs):
|
||||
return [
|
||||
lstm.initial_state(batch_size, **kwargs) for lstm in self._lstm_layers
|
||||
]
|
||||
|
||||
def _build(self, data, prev_states):
|
||||
assert not self._hidden_sizes or self._hidden_sizes[-1] is not None
|
||||
assert len(prev_states) == len(self._hidden_sizes)
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(data, "B, K, H")
|
||||
data = sg.reshape(data, "B*K, H")
|
||||
|
||||
out = data
|
||||
new_states = []
|
||||
for lstm, pstate in zip(self._lstm_layers, prev_states):
|
||||
out, nstate = lstm(out, pstate)
|
||||
new_states.append(nstate)
|
||||
|
||||
sg.guard(out, "B*K, Y")
|
||||
out = sg.reshape(out, "B, K, Y")
|
||||
return out, new_states
|
||||
@@ -0,0 +1,226 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Plotting tools for IODINE."""
|
||||
# pylint: disable=unused-import, missing-docstring, unused-variable
|
||||
# pylint: disable=invalid-name, unexpected-keyword-arg
|
||||
import functools
|
||||
from iodine.modules.utils import get_mask_plot_colors
|
||||
from matplotlib.colors import hsv_to_rgb
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
__all__ = ("get_mask_plot_colors", "example_plot", "iterations_plot",
|
||||
"inputs_plot")
|
||||
|
||||
|
||||
def clean_ax(ax, color=None, lw=4.0):
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
if color is not None:
|
||||
for spine in ax.spines.values():
|
||||
spine.set_linewidth(lw)
|
||||
spine.set_color(color)
|
||||
|
||||
|
||||
def optional_ax(fn):
|
||||
|
||||
def _wrapped(*args, **kwargs):
|
||||
if kwargs.get("ax", None) is None:
|
||||
figsize = kwargs.pop("figsize", (4, 4))
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
kwargs["ax"] = ax
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
def optional_clean_ax(fn):
|
||||
|
||||
def _wrapped(*args, **kwargs):
|
||||
if kwargs.get("ax", None) is None:
|
||||
figsize = kwargs.pop("figsize", (4, 4))
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
kwargs["ax"] = ax
|
||||
color = kwargs.pop("color", None)
|
||||
lw = kwargs.pop("lw", 4.0)
|
||||
res = fn(*args, **kwargs)
|
||||
clean_ax(kwargs["ax"], color, lw)
|
||||
return res
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
def show_img(img, mask=None, ax=None, norm=False):
|
||||
if norm:
|
||||
vmin, vmax = np.min(img), np.max(img)
|
||||
img = (img - vmin) / (vmax - vmin)
|
||||
if mask is not None:
|
||||
img = img * mask + np.ones_like(img) * (1.0 - mask)
|
||||
|
||||
return ax.imshow(img.clip(0.0, 1.0), interpolation="nearest")
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
def show_mask(m, ax):
|
||||
color_conv = get_mask_plot_colors(m.shape[0])
|
||||
color_mask = np.dot(np.transpose(m, [1, 2, 0]), color_conv)
|
||||
return ax.imshow(color_mask.clip(0.0, 1.0), interpolation="nearest")
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
|
||||
return ax.matshow(
|
||||
m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
def show_coords(m, ax):
|
||||
vmin, vmax = np.min(m), np.max(m)
|
||||
m = (m - vmin) / (vmax - vmin)
|
||||
color_conv = get_mask_plot_colors(m.shape[-1])
|
||||
color_mask = np.dot(m, color_conv)
|
||||
return ax.imshow(color_mask, interpolation="nearest")
|
||||
|
||||
|
||||
def example_plot(rinfo,
|
||||
b=0,
|
||||
t=-1,
|
||||
mask_components=False,
|
||||
size=2,
|
||||
column_titles=True):
|
||||
image = rinfo["data"]["image"][b, 0]
|
||||
recons = rinfo["outputs"]["recons"][b, t, 0]
|
||||
pred_mask = rinfo["outputs"]["pred_mask"][b, t]
|
||||
components = rinfo["outputs"]["components"][b, t]
|
||||
|
||||
K, H, W, C = components.shape
|
||||
colors = get_mask_plot_colors(K)
|
||||
|
||||
nrows = 1
|
||||
ncols = 3 + K
|
||||
fig, axes = plt.subplots(ncols=ncols, figsize=(ncols * size, nrows * size))
|
||||
|
||||
show_img(image, ax=axes[0], color="#000000")
|
||||
show_img(recons, ax=axes[1], color="#000000")
|
||||
show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000")
|
||||
for k in range(K):
|
||||
mask = pred_mask[k] if mask_components else None
|
||||
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
|
||||
|
||||
if column_titles:
|
||||
labels = ["Image", "Recons.", "Mask"
|
||||
] + ["Component {}".format(k + 1) for k in range(K)]
|
||||
for ax, title in zip(axes, labels):
|
||||
ax.set_title(title)
|
||||
plt.subplots_adjust(hspace=0.03, wspace=0.035)
|
||||
return fig
|
||||
|
||||
|
||||
def iterations_plot(rinfo, b=0, mask_components=False, size=2):
|
||||
image = rinfo["data"]["image"][b]
|
||||
true_mask = rinfo["data"]["true_mask"][b]
|
||||
recons = rinfo["outputs"]["recons"][b]
|
||||
pred_mask = rinfo["outputs"]["pred_mask"][b]
|
||||
pred_mask_logits = rinfo["outputs"]["pred_mask_logits"][b]
|
||||
components = rinfo["outputs"]["components"][b]
|
||||
|
||||
T, K, H, W, C = components.shape
|
||||
colors = get_mask_plot_colors(K)
|
||||
nrows = T + 1
|
||||
ncols = 2 + K
|
||||
fig, axes = plt.subplots(
|
||||
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
|
||||
for t in range(T):
|
||||
show_img(recons[t, 0], ax=axes[t, 0])
|
||||
show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1])
|
||||
axes[t, 0].set_ylabel("iter {}".format(t))
|
||||
for k in range(K):
|
||||
mask = pred_mask[t, k] if mask_components else None
|
||||
show_img(components[t, k], ax=axes[t, k + 2], color=colors[k], mask=mask)
|
||||
|
||||
axes[0, 0].set_title("Reconstruction")
|
||||
axes[0, 1].set_title("Mask")
|
||||
show_img(image[0], ax=axes[T, 0])
|
||||
show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])
|
||||
vmin = np.min(pred_mask_logits[T - 1])
|
||||
vmax = np.max(pred_mask_logits[T - 1])
|
||||
|
||||
for k in range(K):
|
||||
axes[0, k + 2].set_title("Component {}".format(k + 1)) # , color=colors[k])
|
||||
show_mat(
|
||||
pred_mask_logits[T - 1, k], ax=axes[T, k + 2], vmin=vmin, vmax=vmax)
|
||||
axes[T, k + 2].set_xlabel(
|
||||
"Mask Logits for\nComponent {}".format(k + 1)) # , color=colors[k])
|
||||
axes[T, 0].set_xlabel("Input Image")
|
||||
axes[T, 1].set_xlabel("Ground Truth Mask")
|
||||
|
||||
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
||||
return fig
|
||||
|
||||
|
||||
def inputs_plot(rinfo, b=0, t=0, size=2):
|
||||
B, T, K, H, W, C = rinfo["outputs"]["components"].shape
|
||||
colors = get_mask_plot_colors(K)
|
||||
inputs = rinfo["inputs"]["spatial"]
|
||||
rows = [
|
||||
("image", show_img, False),
|
||||
("components", show_img, False),
|
||||
("dcomponents", functools.partial(show_img, norm=True), False),
|
||||
("mask", show_mat, True),
|
||||
("pred_mask", show_mat, True),
|
||||
("dmask", functools.partial(show_mat, cmap="coolwarm"), True),
|
||||
("posterior", show_mat, True),
|
||||
("log_prob", show_mat, True),
|
||||
("counterfactual", show_mat, True),
|
||||
("coordinates", show_coords, False),
|
||||
]
|
||||
rows = [(n, f, mcb) for n, f, mcb in rows if n in inputs]
|
||||
nrows = len(rows)
|
||||
ncols = K + 1
|
||||
|
||||
fig, axes = plt.subplots(
|
||||
nrows=nrows,
|
||||
ncols=ncols,
|
||||
figsize=(ncols * size - size * 0.9, nrows * size),
|
||||
gridspec_kw={"width_ratios": [1] * K + [0.1]},
|
||||
)
|
||||
for r, (name, plot_fn, make_cbar) in enumerate(rows):
|
||||
axes[r, 0].set_ylabel(name)
|
||||
if make_cbar:
|
||||
vmin = np.min(inputs[name][b, t])
|
||||
vmax = np.max(inputs[name][b, t])
|
||||
if np.abs(vmin - vmax) < 1e-6:
|
||||
vmin -= 0.1
|
||||
vmax += 0.1
|
||||
plot_fn = functools.partial(plot_fn, vmin=vmin, vmax=vmax)
|
||||
# print("range of {:<16}: [{:0.2f}, {:0.2f}]".format(name, vmin, vmax))
|
||||
for k in range(K):
|
||||
if inputs[name].shape[2] == 1:
|
||||
m = inputs[name][b, t, 0]
|
||||
color = (0.0, 0.0, 0.0)
|
||||
else:
|
||||
m = inputs[name][b, t, k]
|
||||
color = colors[k]
|
||||
mappable = plot_fn(m, ax=axes[r, k], color=color)
|
||||
if make_cbar:
|
||||
fig.colorbar(mappable, cax=axes[r, K])
|
||||
else:
|
||||
axes[r, K].set_visible(False)
|
||||
for k in range(K):
|
||||
axes[0, k].set_title("Component {}".format(k + 1)) # , color=colors[k])
|
||||
|
||||
plt.subplots_adjust(hspace=0.05, wspace=0.05)
|
||||
return fig
|
||||
@@ -0,0 +1,163 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Iterative refinement modules."""
|
||||
# pylint: disable=g-doc-bad-indent, unused-variable
|
||||
from iodine.modules import utils
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class RefinementCore(snt.RNNCore):
|
||||
"""Recurrent Refinement Module.
|
||||
|
||||
Refinement modules take as inputs:
|
||||
* previous state (which could be an arbitrary nested structure)
|
||||
* current inputs which include
|
||||
* image-space inputs like pixel-based errors, or mask-posteriors
|
||||
* latent-space inputs like the previous z_dist, or dz
|
||||
|
||||
They use these inputs to produce:
|
||||
* output (usually a new z_dist)
|
||||
* new_state
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
encoder_net,
|
||||
recurrent_net,
|
||||
refinement_head,
|
||||
name="refinement"):
|
||||
super().__init__(name=name)
|
||||
self._encoder_net = encoder_net
|
||||
self._recurrent_net = recurrent_net
|
||||
self._refinement_head = refinement_head
|
||||
self._sg = shapeguard.ShapeGuard()
|
||||
|
||||
def initial_state(self, batch_size, **unused_kwargs):
|
||||
return self._recurrent_net.initial_state(batch_size)
|
||||
|
||||
def _build(self, inputs, prev_state):
|
||||
sg = self._sg
|
||||
assert "spatial" in inputs, inputs.keys()
|
||||
assert "flat" in inputs, inputs.keys()
|
||||
assert "zp" in inputs["flat"], inputs["flat"].keys()
|
||||
zp = sg.guard(inputs["flat"]["zp"], "B, K, Zp")
|
||||
|
||||
x = sg.guard(self.prepare_spatial_inputs(inputs["spatial"]), "B*K, H, W, C")
|
||||
h1 = sg.guard(self._encoder_net(x).params, "B*K, H1")
|
||||
h2 = sg.guard(self.prepare_flat_inputs(h1, inputs["flat"]), "B*K, H2")
|
||||
h2_unflattened = sg.reshape(h2, "B, K, H2")
|
||||
h3, next_state = self._recurrent_net(h2_unflattened, prev_state)
|
||||
sg.guard(h3, "B, K, H3")
|
||||
outputs = sg.guard(self._refinement_head(zp, h3), "B, K, Y")
|
||||
|
||||
del self._sg.B
|
||||
return outputs, next_state
|
||||
|
||||
def prepare_spatial_inputs(self, inputs):
|
||||
values = []
|
||||
for name, val in sorted(inputs.items(), key=lambda it: it[0]):
|
||||
if val.shape.as_list()[1] == 1:
|
||||
self._sg.guard(val, "B, 1, H, W, _C")
|
||||
val = tf.tile(val, self._sg["1, K, 1, 1, 1"])
|
||||
else:
|
||||
self._sg.guard(val, "B, K, H, W, _C")
|
||||
values.append(val)
|
||||
concat_inputs = self._sg.guard(tf.concat(values, axis=-1), "B, K, H, W, C")
|
||||
return self._sg.reshape(concat_inputs, "B*K, H, W, C")
|
||||
|
||||
def prepare_flat_inputs(self, hidden, inputs):
|
||||
values = [self._sg.guard(hidden, "B*K, H1")]
|
||||
|
||||
for name, val in sorted(inputs.items(), key=lambda it: it[0]):
|
||||
self._sg.guard(val, "B, K, _")
|
||||
val_flat = tf.reshape(val, self._sg["B*K"] + [-1])
|
||||
values.append(val_flat)
|
||||
return tf.concat(values, axis=-1)
|
||||
|
||||
|
||||
class ResHead(snt.AbstractModule):
|
||||
"""Updates Zp using a residual mechanism."""
|
||||
|
||||
def __init__(self, name="residual_head"):
|
||||
super().__init__(name=name)
|
||||
|
||||
def _build(self, zp_old, inputs):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(zp_old, "B, K, Zp")
|
||||
sg.guard(inputs, "B, K, H")
|
||||
update = snt.Linear(sg.Zp)
|
||||
|
||||
flat_zp = sg.reshape(zp_old, "B*K, Zp")
|
||||
flat_inputs = sg.reshape(inputs, "B*K, H")
|
||||
|
||||
zp = flat_zp + update(flat_inputs)
|
||||
|
||||
return sg.reshape(zp, "B, K, Zp")
|
||||
|
||||
|
||||
class PredictorCorrectorHead(snt.AbstractModule):
|
||||
"""This refinement head is used for sequential data.
|
||||
|
||||
At every step it computes a prediction from the λ of the previous timestep
|
||||
and an update from the refinement network of the current timestep.
|
||||
|
||||
The next step λ' is computed as a gated combination of both:
|
||||
λ' = g * λ_corr + (1-g) * λ_pred
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_sizes=(64,),
|
||||
pred_gate_bias=0.0,
|
||||
corrector_gate_bias=0.0,
|
||||
activation=tf.nn.elu,
|
||||
name="predcorr_head",
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self._hidden_sizes = hidden_sizes
|
||||
self._activation = utils.get_act_func(activation)
|
||||
self._pred_gate_bias = pred_gate_bias
|
||||
self._corrector_gate_bias = corrector_gate_bias
|
||||
|
||||
def _build(self, zp_old, inputs):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(zp_old, "B, K, Zp")
|
||||
sg.guard(inputs, "B, K, H")
|
||||
update = snt.Linear(sg.Zp)
|
||||
update_gate = snt.Linear(sg.Zp)
|
||||
predict = snt.nets.MLP(
|
||||
output_sizes=list(self._hidden_sizes) + [sg.Zp * 2],
|
||||
activation=self._activation,
|
||||
)
|
||||
|
||||
flat_zp = sg.reshape(zp_old, "B*K, Zp")
|
||||
flat_inputs = sg.reshape(inputs, "B*K, H")
|
||||
|
||||
g = tf.nn.sigmoid(update_gate(flat_inputs) + self._corrector_gate_bias)
|
||||
u = update(flat_inputs)
|
||||
|
||||
# a slightly more efficient way of computing the gated update
|
||||
# (1-g) * flat_zp + g * u
|
||||
zp_corrected = flat_zp + g * (u - flat_zp)
|
||||
|
||||
predicted = predict(flat_zp)
|
||||
pred_up = predicted[:, :sg.Zp]
|
||||
pred_gate = tf.nn.sigmoid(predicted[:, sg.Zp:] + self._pred_gate_bias)
|
||||
|
||||
zp = zp_corrected + pred_gate * (pred_up - zp_corrected)
|
||||
|
||||
return sg.reshape(zp, "B, K, Zp")
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user