Release of IODINE

PiperOrigin-RevId: 299101887
This commit is contained in:
Diego de Las Casas
2020-03-05 15:52:20 +00:00
parent a5efafff3a
commit afcdc77239
23 changed files with 7600 additions and 0 deletions
+13
View File
@@ -0,0 +1,13 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+264
View File
@@ -0,0 +1,264 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data loading functionality for IODINE."""
# pylint: disable=g-multiple-import, missing-docstring, unused-import
import os.path
from iodine.modules.utils import flatten_all_but_last, ensure_3d
from multi_object_datasets import (
clevr_with_masks,
multi_dsprites,
tetrominoes,
objects_room,
)
from shapeguard import ShapeGuard
import sonnet as snt
import tensorflow.compat.v1 as tf
class IODINEDataset(snt.AbstractModule):
  """Base class for the multi-object datasets used by IODINE.

  The constructor only stores configuration; subclasses assign the actual
  dataset object to `self.dataset`. Connecting the module (`_build`) filters,
  shuffles, batches and preprocesses that dataset and returns one batch.

  Class attributes:
    num_true_objects: number of ground-truth object slots ("L") used when
      creating placeholders.
    num_channels: number of image channels declared by the dataset.
    factors: dict mapping a factor name to a `(dtype, size)` pair.
  """

  num_true_objects = 1
  num_channels = 3
  factors = {}

  def __init__(
      self,
      path,
      batch_size,
      image_dim,
      crop_region=None,
      shuffle_buffer=1000,
      max_num_objects=None,
      min_num_objects=None,
      grayscale=False,
      name="dataset",
      **kwargs,
  ):
    """Stores the dataset configuration.

    Args:
      path: location of the dataset; `~` is expanded and the path is made
        absolute.
      batch_size: number of examples per batch.
      image_dim: `(height, width)` that images and masks are resized to.
      crop_region: optional `((h_start, h_end), (w_start, w_end))` crop that
        is applied before resizing.
      shuffle_buffer: shuffle buffer size for the "train" subset, or None to
        disable shuffling.
      max_num_objects: optional upper bound on the summed visibility of an
        example.
      min_num_objects: optional lower bound on the summed visibility of an
        example.
      grayscale: if True, the image channels are averaged into one channel.
      name: module name.
      **kwargs: accepted but not used by this constructor.
    """
    super().__init__(name=name)
    self.path = os.path.abspath(os.path.expanduser(path))
    self.batch_size = batch_size
    self.crop_region = crop_region
    self.image_dim = image_dim
    self.shuffle_buffer = shuffle_buffer
    self.max_num_objects = max_num_objects
    self.min_num_objects = min_num_objects
    self.grayscale = grayscale
    self.dataset = None  # assigned by the subclasses

  def _build(self, subset="train"):
    """Returns one preprocessed batch of the requested subset.

    Args:
      subset: "train" shuffles the data (when `shuffle_buffer` is set);
        "summary" keeps only the first `batch_size` examples so that the same
        batch is produced every time. Any other value leaves the dataset
        order untouched.

    Returns:
      The dict produced by `preprocess`.
    """
    dataset = self.dataset
    # filter by number of objects
    if self.max_num_objects is not None or self.min_num_objects is not None:
      dataset = self.dataset.filter(self.filter_by_num_objects)

    if subset == "train":
      # normal mode returns a shuffled dataset iterator
      if self.shuffle_buffer is not None:
        dataset = dataset.shuffle(self.shuffle_buffer)
    elif subset == "summary":
      # for generating summaries and overview images
      # returns a single fixed batch
      dataset = dataset.take(self.batch_size)

    # repeat and batch
    dataset = dataset.repeat().batch(self.batch_size, drop_remainder=True)
    iterator = dataset.make_one_shot_iterator()
    data = iterator.get_next()

    # preprocess the data to ensure correct format, scale images etc.
    data = self.preprocess(data)
    return data

  def filter_by_num_objects(self, d):
    """Predicate that keeps examples whose object count is within bounds.

    Args:
      d: a single, not yet batched example dict (the filter is applied before
        batching).

    Returns:
      A boolean tensor; always True when the example has no "visibility"
      entry.
    """
    if "visibility" not in d:
      return tf.constant(True)
    # Fix: the lower bound used to be read from `self.max_num_objects`, which
    # ignored `min_num_objects` and degraded the range check to an equality
    # check against the maximum.
    min_num_objects = self.min_num_objects or 0
    # Fallback upper bound of 6 is kept from the original implementation.
    max_num_objects = self.max_num_objects or 6

    num_visible = tf.reduce_sum(d["visibility"])
    # The small epsilon makes the float comparisons inclusive.
    min_predicate = tf.greater_equal(
        num_visible,
        tf.constant(min_num_objects - 1e-5, dtype=tf.float32),
    )
    max_predicate = tf.less_equal(
        num_visible,
        tf.constant(max_num_objects + 1e-5, dtype=tf.float32),
    )
    return tf.logical_and(min_predicate, max_predicate)

  def preprocess(self, data):
    """Converts a raw batch into the format expected by the model.

    Images and masks are cast to float32 and divided by 255, optionally
    cropped, resized to `image_dim` (bilinear for images, nearest neighbour
    for masks) and given an extra axis of size 1 after the batch axis.

    Args:
      data: batched dict with at least "image" and "mask", and optionally
        "visibility" plus the factor entries.

    Returns:
      Dict with "image", "mask", "factors" and "visibility".
    """
    sg = ShapeGuard(dims={
        "B": self.batch_size,
        "H": self.image_dim[0],
        "W": self.image_dim[1]
    })
    image = sg.guard(data["image"], "B, h, w, C")
    mask = sg.guard(data["mask"], "B, L, h, w, 1")

    # to float
    image = tf.cast(image, tf.float32) / 255.0
    mask = tf.cast(mask, tf.float32) / 255.0

    # crop
    if self.crop_region is not None:
      height_slice = slice(self.crop_region[0][0], self.crop_region[0][1])
      width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
      image = image[:, height_slice, width_slice, :]
      mask = mask[:, :, height_slice, width_slice, :]

    # The resize op works on 4D input, so the leading axes of the mask are
    # flattened first and restored afterwards via `unflatten`.
    flat_mask, unflatten = flatten_all_but_last(mask, n_dims=3)

    # rescale
    size = tf.constant(
        self.image_dim, dtype=tf.int32, shape=[2], verify_shape=True)
    image = tf.image.resize_images(
        image, size, method=tf.image.ResizeMethod.BILINEAR)
    mask = tf.image.resize_images(
        flat_mask, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    if self.grayscale:
      image = tf.reduce_mean(image, axis=-1, keepdims=True)

    output = {
        "image": sg.guard(image[:, None], "B, T, H, W, C"),
        "mask": sg.guard(unflatten(mask)[:, None], "B, T, L, H, W, 1"),
        "factors": self.preprocess_factors(data, sg),
    }

    if "visibility" in data:
      output["visibility"] = sg.guard(data["visibility"], "B, L")
    else:
      # Without visibility information every object slot counts as visible.
      output["visibility"] = tf.ones(sg["B, L"], dtype=tf.float32)

    return output

  def preprocess_factors(self, data, sg):
    """Returns the entries named in `self.factors`, each guarded as B, L, *."""
    return {
        name: sg.guard(ensure_3d(data[name]), "B, L, *")
        for name in self.factors
    }

  def get_placeholders(self, batch_size=None):
    """Creates placeholders with the same structure as a preprocessed batch.

    Args:
      batch_size: batch dimension of the placeholders; defaults to
        `self.batch_size`.

    Returns:
      Dict with "image", "mask", "visibility" and "factors" placeholders.
    """
    batch_size = batch_size or self.batch_size
    sg = ShapeGuard(
        dims={
            "B": batch_size,
            "H": self.image_dim[0],
            "W": self.image_dim[1],
            "L": self.num_true_objects,
            "C": 3,
            "T": 1,
        })
    return {
        "image": tf.placeholder(dtype=tf.float32, shape=sg["B, T, H, W, C"]),
        "mask": tf.placeholder(dtype=tf.float32, shape=sg["B, T, L, H, W, 1"]),
        "visibility": tf.placeholder(dtype=tf.float32, shape=sg["B, L"]),
        "factors": {
            name:
                tf.placeholder(dtype=dtype, shape=sg["B, L, {}".format(size)])
            # Fix: iterate over items(); iterating the dict itself yields
            # only the keys and the (dtype, size) unpacking fails.
            for name, (dtype, size) in self.factors.items()
        },
    }
class CLEVR(IODINEDataset):
  """CLEVR (with masks) dataset, cropped and resized to 128x128 by default."""

  num_true_objects = 11
  num_channels = 3
  factors = {
      "color": (tf.uint8, 1),
      "shape": (tf.uint8, 1),
      "size": (tf.uint8, 1),
      "position": (tf.float32, 3),
      "rotation": (tf.float32, 1),
  }

  def __init__(
      self,
      path,
      crop_region=((29, 221), (64, 256)),
      image_dim=(128, 128),
      name="clevr",
      **kwargs,
  ):
    """Loads the dataset found at `path`.

    Args:
      path: location of the dataset.
      crop_region: `((h_start, h_end), (w_start, w_end))` crop applied before
        resizing.
      image_dim: `(height, width)` of the returned images.
      name: module name.
      **kwargs: forwarded to `IODINEDataset.__init__` (e.g. `batch_size`).
    """
    super().__init__(
        path=path,
        crop_region=crop_region,
        image_dim=image_dim,
        name=name,
        **kwargs)
    self.dataset = clevr_with_masks.dataset(self.path)

  def preprocess_factors(self, data, sg):
    """Returns the CLEVR factors with their last dimension checked."""
    return {
        "color": sg.guard(ensure_3d(data["color"]), "B, L, 1"),
        "shape": sg.guard(ensure_3d(data["shape"]), "B, L, 1"),
        # Fix: "size" used to be read from data["color"], so the size factor
        # was silently a copy of the colour factor.
        "size": sg.guard(ensure_3d(data["size"]), "B, L, 1"),
        # The "position" factor is taken from the "pixel_coords" entry.
        "position": sg.guard(ensure_3d(data["pixel_coords"]), "B, L, 3"),
        "rotation": sg.guard(ensure_3d(data["rotation"]), "B, L, 1"),
    }
class MultiDSprites(IODINEDataset):
  """Multi-dSprites dataset in one of its variants (64x64 by default)."""

  num_true_objects = 6
  num_channels = 3
  factors = dict(
      color=(tf.float32, 3),
      shape=(tf.uint8, 1),
      scale=(tf.float32, 1),
      x=(tf.float32, 1),
      y=(tf.float32, 1),
      orientation=(tf.float32, 1),
  )

  def __init__(
      self,
      path,
      # variant from ['binarized', 'colored_on_grayscale', 'colored_on_colored']
      dataset_variant="colored_on_grayscale",
      image_dim=(64, 64),
      name="multi_dsprites",
      **kwargs,
  ):
    """Loads the requested variant of the dataset found at `path`.

    Args:
      path: location of the dataset.
      dataset_variant: name of the variant, passed on to
        `multi_dsprites.dataset`.
      image_dim: `(height, width)` of the returned images.
      name: module name.
      **kwargs: forwarded to `IODINEDataset.__init__`.
    """
    base_args = dict(path=path, name=name, image_dim=image_dim)
    super().__init__(**base_args, **kwargs)
    self.dataset_variant = dataset_variant
    self.dataset = multi_dsprites.dataset(self.path, dataset_variant)
class Tetrominoes(IODINEDataset):
  """Tetrominoes dataset (35x35 by default)."""

  num_true_objects = 6
  num_channels = 3
  factors = dict(
      color=(tf.uint8, 3),
      shape=(tf.uint8, 1),
      position=(tf.float32, 2),
  )

  def __init__(self, path, image_dim=(35, 35), name="tetrominoes", **kwargs):
    """Loads the dataset at `path`; `kwargs` go to `IODINEDataset.__init__`."""
    base_args = dict(path=path, name=name, image_dim=image_dim)
    super().__init__(**base_args, **kwargs)
    self.dataset = tetrominoes.dataset(self.path)

  def preprocess_factors(self, data, sg):
    """Returns colour and shape plus a position built from "x" and "y"."""
    # Join the separate coordinates along the last axis.
    xy = tf.concat([ensure_3d(data["x"]), ensure_3d(data["y"])], axis=2)
    factors = {}
    factors["color"] = sg.guard(ensure_3d(data["color"]), "B, L, 3")
    factors["shape"] = sg.guard(ensure_3d(data["shape"]), "B, L, 1")
    factors["position"] = sg.guard(ensure_3d(xy), "B, L, 2")
    return factors
+49
View File
@@ -0,0 +1,49 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoders for rendering images."""
# pylint: disable=missing-docstring
from iodine.modules.distributions import MixtureParameters
import shapeguard
import sonnet as snt
class ComponentDecoder(snt.AbstractModule):
  """Decodes per-component latents into pixel and mask parameters.

  The wrapped pixel decoder emits `1 + Cp` channels per pixel; the first
  channel is used as the mask parameter and the remaining `Cp` channels as
  the pixel parameters.
  """

  def __init__(self, pixel_decoder, name="component_decoder"):
    """Wraps `pixel_decoder`, which is called on flattened latents."""
    super().__init__(name=name)
    self._pixel_decoder = pixel_decoder
    self._sg = shapeguard.ShapeGuard()

  def set_output_shapes(self, pixel, mask):
    """Records the output shapes and configures the pixel decoder.

    Args:
      pixel: shape matching "K, H, W, Cp".
      mask: shape matching "K, H, W, 1".
    """
    guard = self._sg
    guard.guard(pixel, "K, H, W, Cp")
    guard.guard(mask, "K, H, W, 1")
    decoder_shape = guard["H, W, 1 + Cp"]
    self._pixel_decoder.set_output_shapes(decoder_shape)

  def _build(self, z):
    """Decodes `z` ("B, K, Z") into a `MixtureParameters` tuple."""
    guard = self._sg
    guard.guard(z, "B, K, Z")
    # Fold the components into the batch axis for the pixel decoder.
    decoded = self._pixel_decoder(guard.reshape(z, "B*K, Z")).params
    guard.guard(decoded, "B*K, H, W, 1 + Cp")

    mask_part = decoded[..., 0:1]
    pixel_part = decoded[..., 1:]

    result = MixtureParameters(
        pixel=guard.reshape(pixel_part, "B, K, H, W, Cp"),
        mask=guard.reshape(mask_part, "B, K, H, W, 1"),
    )
    # Drop the batch dimension from the persistent guard again.
    del guard.B
    return result
+223
View File
@@ -0,0 +1,223 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Collection of sonnet modules that wrap useful distributions."""
# pylint: disable=missing-docstring, g-doc-args, g-short-docstring-punctuation
# pylint: disable=g-space-before-docstring-summary
# pylint: disable=g-no-space-after-docstring-summary
import collections
from iodine.modules.utils import get_act_func
from iodine.modules.utils import get_distribution
import shapeguard
import sonnet as snt
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
# Short alias for the TensorFlow Probability distributions namespace.
tfd = tfp.distributions
# Container with a single `params` field. Note that the namedtuple's type
# name ("ParameterOut") differs from the module-level name it is bound to.
FlatParameters = collections.namedtuple("ParameterOut", ["params"])
# Container with separate `pixel` and `mask` parameter fields; its type name
# is "MixtureOut".
MixtureParameters = collections.namedtuple("MixtureOut", ["pixel", "mask"])
class DistributionModule(snt.AbstractModule):
  """Distribution Base class supporting shape inference & default priors."""

  def __init__(self, name="distribution"):
    super().__init__(name=name)
    self._output_shape = None

  def set_output_shape(self, shape):
    """Stores the shape of the distribution's output."""
    self._output_shape = shape

  @property
  def output_shape(self):
    """The shape set via `set_output_shape`, or None if unset."""
    return self._output_shape

  @property
  def input_shapes(self):
    """Parameter shapes expected by the module; defined by subclasses."""
    raise NotImplementedError()

  def get_default_prior(self, batch_dim=(1,)):
    """Returns the distribution obtained from all-zero parameters.

    Args:
      batch_dim: leading dimensions that are prepended to the parameter
        shape.
    """
    param_shape = list(batch_dim) + self.input_shapes.params
    zero_params = tf.zeros(param_shape, dtype=tf.float32)
    return self(zero_params)
class BernoulliOutput(DistributionModule):
  """Independent Bernoulli distribution parametrized by logits."""

  def __init__(self, name="bernoulli_output"):
    super().__init__(name=name)

  @property
  def input_shapes(self):
    """The logits have the same shape as the output."""
    return FlatParameters(self.output_shape)

  def _build(self, params):
    """Interprets `params` as logits; the last axis becomes the event."""
    bernoulli = tfd.Bernoulli(logits=params, dtype=tf.float32)
    return tfd.Independent(bernoulli, reinterpreted_batch_ndims=1)
class LocScaleDistribution(DistributionModule):
  """Generic IID location / scale distribution.

  Input parameters are the concatenation of location and scale (2*Z,), unless
  the scale mode is "fixed", in which case they only hold the location (Z,).
  Optional activation functions can be applied to scale and location.

  Args:
    dist: Distribution or str Kind of distribution used. Supports Normal,
      Logistic, Laplace, and StudentT distributions.
    dist_kwargs: dict custom keyword arguments for the distribution
    scale_act: function or str or None activation function to be applied to
      the scale input
    scale: str
      different modes for computing the scale:
        * stddev: scale is computed as scale_act(s)
        * var: scale is computed as sqrt(scale_act(s))
        * prec: scale is computed as 1./scale_act(s)
        * fixed: scale is a global value (same for all pixels). If
          scale_val == -1. it is a trainable variable initialized to 0.1,
          otherwise it is fixed to scale_val.
    scale_val: float determines the scale value (only used if scale=='fixed').
    loc_act: function or str or None activation function to be applied to the
      location input.
    name: str module name.
  """

  def __init__(
      self,
      dist=tfd.Normal,
      dist_kwargs=None,
      scale_act=tf.exp,
      scale="stddev",
      scale_val=1.0,
      loc_act=None,
      name="loc_scale_dist",
  ):
    super().__init__(name=name)
    self._scale_act = get_act_func(scale_act)
    self._loc_act = get_act_func(loc_act)
    # supports Normal, Logistic, Laplace, StudentT
    self._dist = get_distribution(dist)
    self._dist_kwargs = dist_kwargs or {}

    assert scale in ["stddev", "var", "prec", "fixed"], scale
    self._scale = scale
    self._scale_val = scale_val

  @property
  def input_shapes(self):
    """Parameter shape: last axis doubled unless the scale is fixed."""
    if self._scale == "fixed":
      return FlatParameters(self.output_shape)
    doubled_last = [self.output_shape[-1] * 2]
    return FlatParameters(self.output_shape[:-1] + doubled_last)

  def _build(self, params):
    """Builds the distribution from the flat parameter tensor `params`."""
    has_fixed_scale = self._scale == "fixed"

    if has_fixed_scale:
      loc, scale = params, None  # scale is set further below
    else:
      # First half of the channels is the location, second half the scale.
      n_channels = params.get_shape().as_list()[-1]
      assert n_channels % 2 == 0
      assert n_channels // 2 == self.output_shape[-1]
      half = n_channels // 2
      loc = params[..., :half]
      scale = params[..., half:]
      # apply the scale activation function
      scale = self._scale_act(scale)

    loc = self._loc_act(loc)

    # apply the correct parametrization ('stddev' needs no further change)
    if self._scale == "var":
      scale = tf.sqrt(scale)
    elif self._scale == "prec":
      scale = tf.reciprocal(scale)
    elif has_fixed_scale:
      if self._scale_val == -1.0:
        scale_val = tf.get_variable(
            "scale", initializer=tf.constant(0.1, dtype=tf.float32))
      else:
        scale_val = self._scale_val
      scale = tf.ones_like(loc) * scale_val

    base = self._dist(loc=loc, scale=scale, **self._dist_kwargs)
    return tfd.Independent(base, reinterpreted_batch_ndims=1)
class MaskedMixture(DistributionModule):
  """Spatial mixture of pixel-wise component distributions."""

  def __init__(
      self,
      num_components,
      component_dist,
      mask_activation=None,
      name="masked_mixture",
  ):
    """
    Spatial Mixture Model composed of a categorical masking distribution and
    a custom pixel-wise component distribution (usually logistic or
    gaussian).

    Args:
      num_components: int Number of mixture components >= 2
      component_dist: the distribution to use for the individual components
      mask_activation: str or function or None activation function that
        should be applied to the mask before the softmax.
      name: str
    """
    super().__init__(name=name)
    self._num_components = num_components
    self._dist = component_dist
    self._mask_activation = get_act_func(mask_activation)

  def set_output_shape(self, shape):
    """Sets the output shape on this module and on the component dist."""
    super().set_output_shape(shape)
    self._dist.set_output_shape(shape)

  def _build(self, pixel, mask):
    """Builds the mixture from `pixel` and `mask` parameters.

    Args:
      pixel: component parameters of shape "B, K, H, W, Cp".
      mask: mask logits of shape "B, K, H, W, 1".
    """
    sg = shapeguard.ShapeGuard()

    # MASKING: move K to the last axis so it indexes the categories.
    sg.guard(mask, "B, K, H, W, 1")
    mask_logits = tf.transpose(mask, perm=[0, 2, 3, 4, 1])
    mask_logits = sg.reshape(mask_logits, "B, H, W, K")
    mask_logits = self._mask_activation(mask_logits)
    # add K=1 axis since K is removed by mixture
    mask_logits = mask_logits[:, tf.newaxis]
    mix_dist = tfd.Categorical(logits=mask_logits)

    # COMPONENTS: move K next to the channel axis.
    sg.guard(pixel, "B, K, H, W, Cp")
    component_params = tf.transpose(pixel, perm=[0, 2, 3, 1, 4])
    # add K=1 axis since K is removed by mixture
    component_params = component_params[:, tf.newaxis]
    components = self._dist(component_params)

    return tfd.MixtureSameFamily(
        mixture_distribution=mix_dist, components_distribution=components)

  @property
  def input_shapes(self):
    """Shapes of the `pixel` and `mask` parameters (without batch axis)."""
    pixel_shape = [self._num_components] + self._dist.input_shapes.params
    # Same as the pixel shape but with a single channel.
    mask_shape = pixel_shape[:-1] + [1]
    return MixtureParameters(pixel_shape, mask_shape)

  def get_default_prior(self, batch_dim=(1,)):
    """Returns the mixture obtained from all-zero pixel and mask params."""
    lead = list(batch_dim)
    shapes = self.input_shapes
    pixel = tf.zeros(lead + shapes.pixel, dtype=tf.float32)
    mask = tf.zeros(lead + shapes.mask, dtype=tf.float32)
    return self(pixel, mask)
+206
View File
@@ -0,0 +1,206 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factor Evaluation Module."""
# pylint: disable=unused-variable
import collections
import functools
from iodine.modules import utils
import shapeguard
import sonnet as snt
import tensorflow.compat.v1 as tf
# Description of one ground-truth factor: its `name` (key into the latent
# dict), its `size` (width of the slice of the prediction vector reserved for
# it) and its `type`, one of "scalar", "categorical" or "angle".
Factor = collections.namedtuple("Factor", ["name", "size", "type"])
class FactorRegressor(snt.AbstractModule):
  """Assess representations by learning a linear mapping to latents."""

  def __init__(self, mapping=None, name="repres_content"):
    """Sets up the list of factors to regress.

    Args:
      mapping: optional iterable of `(name, size, type)` tuples. When None, a
        default list of six factors is used.
      name: module name.
    """
    super().__init__(name=name)
    if mapping is not None:
      self._mapping = [Factor(*entry) for entry in mapping]
    else:
      self._mapping = [
          Factor("color", 3, "scalar"),
          Factor("shape", 4, "categorical"),
          Factor("scale", 1, "scalar"),
          Factor("x", 1, "scalar"),
          Factor("y", 1, "scalar"),
          Factor("orientation", 2, "angle"),
      ]

  def _build(self, z, latent, visibility, pred_mask, true_mask):
    """Regresses the factors from `z` and scores the result.

    Args:
      z: representations of shape "B, K, Z".
      latent: dict holding a ground-truth tensor for every mapped factor.
      visibility: tensor of shape "B, L".
      pred_mask: predicted masks of shape "B, K, H, W, 1".
      true_mask: ground-truth masks of shape "B, L, H, W, 1".

    Returns:
      Tuple `(total_error, monitored_scalars, mean_var_tot, predictions,
      assignment)`.
    """
    sg = shapeguard.ShapeGuard()
    z = sg.guard(z, "B, K, Z")
    pred_mask = sg.guard(pred_mask, "B, K, H, W, 1")
    true_mask = sg.guard(true_mask, "B, L, H, W, 1")

    visibility = sg.guard(visibility, "B, L")
    num_visible_obj = tf.reduce_sum(visibility)

    # Map z to predictions for all latents
    sg.M = sum(factor.size for factor in self._mapping)
    self.predictor = snt.Linear(sg.M, name="predict_latents")
    flat_z = sg.reshape(z, "B*K, Z")
    all_preds = sg.guard(self.predictor(flat_z), "B*K, M")
    all_preds = sg.reshape(all_preds, "B, 1, K, M")
    all_preds = tf.tile(all_preds, sg["1, L, 1, 1"])

    # prepare latents
    latents = {}
    mean_var_tot = {}
    for factor in self._mapping:
      with tf.name_scope(factor.name):
        # preprocess, reshape, and tile
        preprocess = self.get_preprocessing(factor)
        target = sg.guard(
            preprocess(latent[factor.name]), "B, L, {}".format(factor.size))
        # running mean / variance of the factor (scalar and angle types only)
        if factor.type in {"scalar", "angle"}:
          estimator = utils.OnlineMeanVarEstimator(
              axis=[0, 1], ddof=1, name="{}_mean_var".format(factor.name))
          mean_var_tot[factor.name] = estimator(
              target, visibility[:, :, tf.newaxis])

        target = tf.reshape(target, sg["B, L, 1"] + [-1])
        target = tf.tile(target, sg["1, 1, K, 1"])
        latents[factor.name] = target

    # prepare predictions: every factor owns a contiguous slice of M
    offset = 0
    predictions = {}
    for factor in self._mapping:
      with tf.name_scope(factor.name):
        assert factor.name in latent, "{} not in {}".format(
            factor.name, latent.keys())
        stop = offset + factor.size
        factor_pred = all_preds[..., offset:stop]
        predictions[factor.name] = sg.guard(
            factor_pred, "B, L, K, {}".format(factor.size))
        offset = stop

    # compute error
    total_pairwise_errors = None
    for factor in self._mapping:
      with tf.name_scope(factor.name):
        error_fn = self.get_error_func(factor)
        pattern = "B, L, K, {}".format(factor.size)
        sg.guard(latents[factor.name], pattern)
        sg.guard(predictions[factor.name], pattern)
        err = error_fn(latents[factor.name], predictions[factor.name])
        sg.guard(err, "B, L, K")
        if total_pairwise_errors is None:
          total_pairwise_errors = err
        else:
          total_pairwise_errors += err

    # determine best assignment by comparing masks
    obj_mask = true_mask[:, :, tf.newaxis]
    pred_mask = pred_mask[:, tf.newaxis]
    pairwise_overlap = tf.reduce_sum(obj_mask * pred_mask, axis=[3, 4, 5])
    best_match = sg.guard(tf.argmax(pairwise_overlap, axis=2), "B, L")
    assignment = tf.one_hot(best_match, sg.K)
    assignment *= visibility[:, :, tf.newaxis]  # Mask non-visible objects

    # total error
    assigned_error = tf.reduce_sum(assignment * total_pairwise_errors)
    total_error = assigned_error / num_visible_obj

    # compute scalars
    monitored_scalars = {}
    for factor in self._mapping:
      with tf.name_scope(factor.name):
        metric = self.get_metric(factor)
        monitored_scalars[factor.name] = metric(
            latents[factor.name],
            predictions[factor.name],
            assignment[:, :, :, tf.newaxis],
            mean_var_tot.get(factor.name),
            num_visible_obj,
        )

    return total_error, monitored_scalars, mean_var_tot, predictions, assignment

  @snt.reuse_variables
  def predict(self, z):
    """Returns per-factor predictions for `z` of shape "B, Z".

    Uses `self.predictor`, which is created in `_build`.
    """
    sg = shapeguard.ShapeGuard()
    z = sg.guard(z, "B, Z")
    all_preds = sg.guard(self.predictor(z), "B, M")

    offset = 0
    predictions = {}
    for factor in self._mapping:
      with tf.name_scope(factor.name):
        stop = offset + factor.size
        factor_pred = all_preds[:, offset:stop]
        predictions[factor.name] = sg.guard(
            factor_pred, "B, {}".format(factor.size))
        offset = stop
    return predictions

  @staticmethod
  def get_error_func(factor):
    """Returns the pairwise error function for the factor's type."""
    if factor.type in {"scalar", "angle"}:
      return sse
    if factor.type == "categorical":
      return functools.partial(
          tf.losses.softmax_cross_entropy, reduction="none")
    raise KeyError(factor.type)

  @staticmethod
  def get_metric(factor):
    """Returns the monitoring metric for the factor's type."""
    if factor.type in {"scalar", "angle"}:
      return r2
    if factor.type == "categorical":
      return accuracy
    raise KeyError(factor.type)

  @staticmethod
  def one_hot(f, nr_categories):
    """One-hot encodes the first entry of the last axis of `f`."""
    category_ids = tf.cast(f[..., 0], tf.int32)
    return tf.one_hot(category_ids, depth=nr_categories)

  @staticmethod
  def angle_to_vector(theta):
    """Encodes `theta` as (cos, sin) concatenated along the last axis."""
    components = [tf.math.cos(theta), tf.math.sin(theta)]
    return tf.concat(components, axis=-1)

  @staticmethod
  def get_preprocessing(factor):
    """Returns the function that prepares the ground truth of a factor."""
    if factor.type == "scalar":
      return tf.identity
    if factor.type == "categorical":
      return functools.partial(
          FactorRegressor.one_hot, nr_categories=factor.size)
    if factor.type == "angle":
      return FactorRegressor.angle_to_vector
    raise KeyError(factor.type)
def sse(true, pred):
  """Sum of squared errors, reduced over the last dimension only."""
  # Hand-written because the reduction has to run over the last axis only.
  squared_diff = tf.square(true - pred)
  return tf.reduce_sum(squared_diff, axis=-1)
def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
  """Fraction of assigned predictions whose argmax matches the label.

  `mean_var_tot` is accepted for signature compatibility and ignored.
  """
  del mean_var_tot  # unused
  predicted_class = tf.argmax(logits, axis=-1, output_type=tf.int32)
  true_class = tf.argmax(labels, axis=-1, output_type=tf.int32)
  is_correct = tf.cast(tf.equal(true_class, predicted_class), tf.float32)
  # Only count the entries selected by the assignment.
  num_correct = tf.reduce_sum(is_correct * assignment[..., 0])
  return num_correct / num_vis
def r2(labels, pred, assignment, mean_var_tot, num_vis):
  """Mean of `1 - ss_res / var` for the assigned predictions.

  `num_vis` is accepted for signature compatibility and ignored; only the
  variance entry of `mean_var_tot` is used.
  """
  del num_vis  # unused
  _, var, _ = mean_var_tot
  # labels, pred: (B, L, K, n)
  squared_residuals = tf.square(labels - pred) * assignment
  ss_res = tf.reduce_sum(squared_residuals, axis=2)
  ss_tot = var[tf.newaxis, tf.newaxis, :]  # (1, 1, n)
  return tf.reduce_mean(1.0 - ss_res / ss_tot)
File diff suppressed because it is too large Load Diff
+300
View File
@@ -0,0 +1,300 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Network modules."""
# pylint: disable=g-multiple-import, g-doc-args, g-short-docstring-punctuation
# pylint: disable=g-no-space-after-docstring-summary
from iodine.modules.distributions import FlatParameters
from iodine.modules.utils import flatten_all_but_last, get_act_func
import numpy as np
import shapeguard
import sonnet as snt
import tensorflow.compat.v1 as tf
class CNN(snt.AbstractModule):
  """ConvNet2D followed by an MLP.

  This is a typical encoder architecture for VAEs, and has been found to work
  well. One small improvement is to append coordinate channels on the input,
  though for most datasets the improvement obtained is negligible.
  """

  def __init__(self, cnn_opt, mlp_opt, mode="flatten", name="cnn"):
    """Constructor.

    Args:
      cnn_opt: Dictionary. Kwargs for `snt.nets.ConvNet2D`. An "activation"
        entry is resolved through `get_act_func` (in place).
      mlp_opt: Dictionary. Kwargs for `snt.nets.MLP`. An "activation" entry
        is resolved through `get_act_func` (in place).
      mode: String. How the CNN output is reduced before the MLP: "flatten"
        reshapes the last three axes into one, "avg_pool" averages over the
        two spatial axes.
      name: String. Optional name.
    """
    super().__init__(name=name)
    for options in (cnn_opt, mlp_opt):
      if "activation" in options:
        options["activation"] = get_act_func(options["activation"])
    self._cnn_opt = cnn_opt
    self._mlp_opt = mlp_opt
    self._mode = mode

  def set_output_shapes(self, shape):
    """Sets the size of the last MLP layer from `shape` ("1, Y")."""
    # The last entry of output_sizes is overwritten unconditionally; it is
    # not required to be None beforehand.
    sg = shapeguard.ShapeGuard()
    sg.guard(shape, "1, Y")
    self._mlp_opt["output_sizes"][-1] = sg.Y

  def _build(self, image):
    """Connect model to TensorFlow graph."""
    assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
    sg = shapeguard.ShapeGuard()
    # Fold any leading axes into the batch axis; `unflatten` restores them.
    flat_image, unflatten = flatten_all_but_last(image, n_dims=3)
    sg.guard(flat_image, "B, H, W, C")

    cnn = snt.nets.ConvNet2D(
        activate_final=True,
        paddings=("SAME",),
        normalize_final=False,
        **self._cnn_opt)
    mlp = snt.nets.MLP(**self._mlp_opt)

    # run CNN
    features = cnn(flat_image)

    if self._mode == "flatten":
      # merge height, width and channels into a single axis
      feature_shape = features.get_shape().as_list()
      merged = [np.prod(feature_shape[-3:])]
      features = tf.reshape(features, feature_shape[:-3] + merged)
    elif self._mode == "avg_pool":
      features = tf.reduce_mean(features, axis=[1, 2])
    else:
      raise KeyError('Unknown mode "{}"'.format(self._mode))

    # run MLP
    output = sg.guard(mlp(features), "B, Y")
    return FlatParameters(unflatten(output))
class MLP(snt.AbstractModule):
  """MLP whose last layer size is filled in via `set_output_shapes`."""

  def __init__(self, name="mlp", **mlp_opt):
    """Constructor.

    Args:
      name: String. Optional name.
      **mlp_opt: Kwargs for `snt.nets.MLP`. The last entry of "output_sizes"
        must be None; an "activation" entry is resolved through
        `get_act_func`.
    """
    super().__init__(name=name)
    activation = mlp_opt.get("activation")
    if "activation" in mlp_opt:
      mlp_opt["activation"] = get_act_func(activation)
    self._mlp_opt = mlp_opt
    assert mlp_opt["output_sizes"][-1] is None, mlp_opt

  def set_output_shapes(self, shape):
    """Sets the size of the last layer from `shape` ("1, Y")."""
    guard = shapeguard.ShapeGuard()
    guard.guard(shape, "1, Y")
    self._mlp_opt["output_sizes"][-1] = guard.Y

  def _build(self, data):
    """Connect model to TensorFlow graph."""
    assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
    guard = shapeguard.ShapeGuard()
    # Fold all leading axes into one batch axis; `unflatten` restores them.
    flat_data, unflatten = flatten_all_but_last(data)
    guard.guard(flat_data, "B, N")

    network = snt.nets.MLP(**self._mlp_opt)

    # run MLP
    flat_output = guard.guard(network(flat_data), "B, Y")
    return FlatParameters(unflatten(flat_output))
class DeConv(snt.AbstractModule):
  """MLP followed by Deconv net.

  This decoder is commonly used by vanilla VAE models. However, in practice
  BroadcastConv (see below) seems to disentangle slightly better.
  """

  def __init__(self, mlp_opt, cnn_opt, name="deconv"):
    """Constructor.

    Args:
      mlp_opt: None or Dictionary. Kwargs for `snt.nets.MLP`. If None, the
        latent is passed on without an MLP.
      cnn_opt: Dictionary. Kwargs for `snt.nets.ConvNet2DTranspose`. The last
        entry of "output_channels" must be None; it is filled in by
        `set_output_shapes`.
      name: Optional name.
    """
    super().__init__(name=name)
    assert cnn_opt["output_channels"][-1] is None, cnn_opt
    if "activation" in cnn_opt:
      cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
    self._cnn_opt = cnn_opt

    if mlp_opt and "activation" in mlp_opt:
      mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
    self._mlp_opt = mlp_opt
    self._target_out_shape = None

  def set_output_shapes(self, shape):
    """Stores the target shape ("H, W, C") and sets the output channels."""
    self._target_out_shape = shape
    self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]

  def _build(self, z):
    """Connect model to TensorFlow graph."""
    # Fail early with a clear message, as BroadcastConv does.
    assert self._target_out_shape is not None, "Call set_output_shapes"
    sg = shapeguard.ShapeGuard()
    flat_z, unflatten = flatten_all_but_last(z)
    sg.guard(flat_z, "B, Z")
    sg.guard(self._target_out_shape, "H, W, C")

    # Fix: the constructor tolerates `mlp_opt=None`, but the MLP used to be
    # built unconditionally via `**self._mlp_opt`, which raises a TypeError
    # for None. Mirror BroadcastConv and skip the MLP in that case.
    if self._mlp_opt is None:
      mlp = tf.identity
    else:
      mlp = snt.nets.MLP(**self._mlp_opt)
    cnn = snt.nets.ConvNet2DTranspose(
        paddings=("SAME",), normalize_final=False, **self._cnn_opt)

    net = mlp(flat_z)
    # NOTE(review): `net` is handed to the transposed conv net exactly as the
    # MLP produced it; no reshape to a spatial layout is visible here.
    # Confirm that this is the intended input rank.
    output = sg.guard(cnn(net), "B, H, W, C")
    return FlatParameters(unflatten(output))
class BroadcastConv(snt.AbstractModule):
  """MLP followed by a broadcast convolution.

  This decoder takes a latent vector z, (optionally) applies an MLP to it,
  then tiles the resulting vector across space to have dimension [B, H, W, C]
  i.e. tiles across H and W. Then coordinate channels are appended and a
  convolutional layer is applied.
  """

  def __init__(
      self,
      cnn_opt,
      mlp_opt=None,
      coord_type="linear",
      coord_freqs=3,
      name="broadcast_conv",
  ):
    """Args:

      cnn_opt: dict Kwargs for `snt.nets.ConvNet2D`. The last entry of
        "output_channels" must be None; it is filled in by
        `set_output_shapes`.
      mlp_opt: None or dict If dictionary, then kwargs for snt.nets.MLP. If
        None, then the model will not process the latent vector by an mlp.
      coord_type: ["linear", "cos", None] type of coordinate channels to
        add.
        None: add no coordinate channels.
        linear: two channels with values linearly spaced from -1. to 1. in
          the H and W dimension respectively.
        cos: coord_freqs^2 - 1 channels containing cosine basis functions
          (the first of the coord_freqs^2 products is dropped).
      coord_freqs: int number of frequencies used to construct the cosine
        basis functions (only for coord_type=="cos")
      name: Optional name.
    """
    super().__init__(name=name)

    assert cnn_opt["output_channels"][-1] is None, cnn_opt
    if "activation" in cnn_opt:
      cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
    self._cnn_opt = cnn_opt

    if mlp_opt and "activation" in mlp_opt:
      mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
    self._mlp_opt = mlp_opt

    self._target_out_shape = None
    self._coord_type = coord_type
    self._coord_freqs = coord_freqs

  def set_output_shapes(self, shape):
    """Stores the target shape ("H, W, C") and sets the output channels."""
    self._target_out_shape = shape
    self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]

  def _build(self, z):
    """Connect model to TensorFlow graph."""
    # Fix: the message used to name `set_output_shape`, which does not exist
    # on this class; the method is called `set_output_shapes`.
    assert self._target_out_shape is not None, "Call set_output_shapes"

    # reshape components into batch dimension before processing them
    sg = shapeguard.ShapeGuard()
    flat_z, unflatten = flatten_all_but_last(z)
    sg.guard(flat_z, "B, Z")
    sg.guard(self._target_out_shape, "H, W, C")

    if self._mlp_opt is None:
      mlp = tf.identity
    else:
      mlp = snt.nets.MLP(activate_final=True, **self._mlp_opt)
    mlp_output = sg.guard(mlp(flat_z), "B, hidden")

    # tile MLP output spatially and append coordinate channels
    broadcast_mlp_output = tf.tile(
        mlp_output[:, tf.newaxis, tf.newaxis],
        multiples=tf.constant(sg["1, H, W, 1"]),
    )  # B, H, W, Z

    dec_cnn_inputs = self.append_coordinate_channels(broadcast_mlp_output)

    cnn = snt.nets.ConvNet2D(
        paddings=("SAME",), normalize_final=False, **self._cnn_opt)
    cnn_outputs = cnn(dec_cnn_inputs)
    sg.guard(cnn_outputs, "B, H, W, C")

    return FlatParameters(unflatten(cnn_outputs))

  def append_coordinate_channels(self, output):
    """Appends coordinate channels of the configured type to `output`.

    Args:
      output: tensor of shape "B, H, W, C".

    Returns:
      `output` itself when `coord_type` is None, otherwise `output` with the
      coordinate channels concatenated along the last axis.

    Raises:
      KeyError: for an unknown `coord_type`.
    """
    sg = shapeguard.ShapeGuard()
    sg.guard(output, "B, H, W, C")
    if self._coord_type is None:
      return output

    if self._coord_type == "linear":
      w_coords = tf.linspace(-1.0, 1.0, sg.W)[None, None, :, None]
      h_coords = tf.linspace(-1.0, 1.0, sg.H)[None, :, None, None]
      w_coords = tf.tile(w_coords, sg["B, H, 1, 1"])
      h_coords = tf.tile(h_coords, sg["B, 1, W, 1"])
      return tf.concat([output, h_coords, w_coords], axis=-1)
    elif self._coord_type == "cos":
      freqs = sg.guard(tf.range(0.0, self._coord_freqs), "F")
      valx = tf.linspace(0.0, np.pi, sg.W)[None, None, :, None, None]
      valy = tf.linspace(0.0, np.pi, sg.H)[None, :, None, None, None]
      x_basis = tf.cos(valx * freqs[None, None, None, :, None])
      y_basis = tf.cos(valy * freqs[None, None, None, None, :])
      xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
      # Channel 0 pairs the two zero frequencies and is constant, so it is
      # dropped.
      coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
      return tf.concat([output, coords], axis=-1)
    else:
      raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
class LSTM(snt.RNNCore):
  """Multi-layer wrapper around snt.LSTM that runs K components in parallel.

  Expects input data of shape (B, K, H) and outputs data of shape (B, K, Y)
  """

  def __init__(self, hidden_sizes, name="lstm"):
    """Creates one `snt.LSTM` per entry of `hidden_sizes`."""
    super().__init__(name=name)
    self._hidden_sizes = hidden_sizes
    with self._enter_variable_scope():
      layers = []
      for size in self._hidden_sizes:
        layers.append(snt.LSTM(hidden_size=size))
      self._lstm_layers = layers

  def initial_state(self, batch_size, **kwargs):
    """Returns a list with the initial state of every layer."""
    states = []
    for layer in self._lstm_layers:
      states.append(layer.initial_state(batch_size, **kwargs))
    return states

  def _build(self, data, prev_states):
    """Runs all layers on `data` and returns `(output, new_states)`."""
    assert not self._hidden_sizes or self._hidden_sizes[-1] is not None
    assert len(prev_states) == len(self._hidden_sizes)
    sg = shapeguard.ShapeGuard()
    sg.guard(data, "B, K, H")
    # Fold the components into the batch axis.
    hidden = sg.reshape(data, "B*K, H")

    new_states = []
    for layer, layer_state in zip(self._lstm_layers, prev_states):
      hidden, next_state = layer(hidden, layer_state)
      new_states.append(next_state)

    sg.guard(hidden, "B*K, Y")
    return sg.reshape(hidden, "B, K, Y"), new_states
+226
View File
@@ -0,0 +1,226 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Plotting tools for IODINE."""
# pylint: disable=unused-import, missing-docstring, unused-variable
# pylint: disable=invalid-name, unexpected-keyword-arg
import functools
from iodine.modules.utils import get_mask_plot_colors
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import numpy as np
# Public names of this module. get_mask_plot_colors is imported from
# iodine.modules.utils above and re-exported here.
__all__ = ("get_mask_plot_colors", "example_plot", "iterations_plot",
           "inputs_plot")
def clean_ax(ax, color=None, lw=4.0):
  """Removes the ticks of `ax` and optionally restyles its spines.

  Args:
    ax: Axes-like object to modify in place.
    color: If not None, every spine of `ax` is given this color.
    lw: Line width applied to the spines; only used when `color` is given.
  """
  for set_ticks in (ax.set_xticks, ax.set_yticks):
    set_ticks([])

  if color is None:
    return

  for spine in ax.spines.values():
    spine.set_linewidth(lw)
    spine.set_color(color)
def optional_ax(fn):
  """Decorator that creates a fresh axis when the caller supplies none.

  If the wrapped function is called without an `ax` keyword argument (or
  with `ax=None`), a new figure is created with `plt.subplots` and its axis
  is passed on as `ax`. In that case an optional `figsize` keyword argument
  (default (4, 4)) is consumed to size the figure.

  Args:
    fn: Plotting function accepting an `ax` keyword argument.

  Returns:
    The wrapped function; its name and docstring are those of `fn`.
  """

  # functools.wraps keeps fn.__name__ / fn.__doc__ visible on the wrapper.
  @functools.wraps(fn)
  def _wrapped(*args, **kwargs):
    if kwargs.get("ax", None) is None:
      figsize = kwargs.pop("figsize", (4, 4))
      _, ax = plt.subplots(figsize=figsize)
      kwargs["ax"] = ax
    return fn(*args, **kwargs)

  return _wrapped
def optional_clean_ax(fn):
  """Decorator that supplies an axis if needed and cleans it afterwards.

  If the wrapped function is called without an `ax` keyword argument (or
  with `ax=None`), a new figure is created with `plt.subplots`, consuming an
  optional `figsize` keyword argument (default (4, 4)). The keyword
  arguments `color` (default None) and `lw` (default 4.0) are always
  consumed here and handed to `clean_ax` once `fn` has returned.

  Args:
    fn: Plotting function accepting an `ax` keyword argument.

  Returns:
    The wrapped function, which returns whatever `fn` returns; its name and
    docstring are those of `fn`.
  """

  # functools.wraps keeps fn.__name__ / fn.__doc__ visible on the wrapper.
  @functools.wraps(fn)
  def _wrapped(*args, **kwargs):
    if kwargs.get("ax", None) is None:
      figsize = kwargs.pop("figsize", (4, 4))
      _, ax = plt.subplots(figsize=figsize)
      kwargs["ax"] = ax
    # These two are styling options for clean_ax, not arguments of fn.
    color = kwargs.pop("color", None)
    lw = kwargs.pop("lw", 4.0)
    res = fn(*args, **kwargs)
    clean_ax(kwargs["ax"], color, lw)
    return res

  return _wrapped
@optional_clean_ax
def show_img(img, mask=None, ax=None, norm=False):
  """Shows an image on `ax`, optionally normalized and/or masked.

  Args:
    img: Array holding the image; values are clipped to [0, 1] for display.
    mask: Optional array broadcastable against `img`. Where given, the image
      is blended towards 1.0 as `img * mask + 1 * (1 - mask)`.
    ax: Axis to draw on.
    norm: If True, `img` is rescaled to [0, 1] using its min and max before
      masking.

  Returns:
    The result of `ax.imshow`.
  """
  if norm:
    vmin, vmax = np.min(img), np.max(img)
    value_range = vmax - vmin
    if value_range == 0:
      # A constant image has no range to normalize by; dividing would fill
      # it with NaNs, so only shift it to zero.
      img = img - vmin
    else:
      img = (img - vmin) / value_range
  if mask is not None:
    img = img * mask + np.ones_like(img) * (1.0 - mask)

  return ax.imshow(img.clip(0.0, 1.0), interpolation="nearest")
@optional_clean_ax
def show_mask(m, ax):
  """Shows a stack of per-component masks as one color image.

  Args:
    m: Array whose first axis indexes the components; it is transposed with
      axes [1, 2, 0] so that the component axis comes last.
    ax: Axis to draw on.

  Returns:
    The result of `ax.imshow`.
  """
  num_components = m.shape[0]
  palette = get_mask_plot_colors(num_components)
  components_last = np.transpose(m, [1, 2, 0])
  # Mix the palette entries using the mask values as weights.
  blended = np.dot(components_last, palette).clip(0.0, 1.0)
  return ax.imshow(blended, interpolation="nearest")
@optional_clean_ax
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
  """Shows the first channel of `m` with `ax.matshow`.

  Args:
    m: Array whose last axis is indexed at 0 before plotting.
    ax: Axis to draw on.
    vmin: Lower bound of the color scale, forwarded to `matshow`.
    vmax: Upper bound of the color scale, forwarded to `matshow`.
    cmap: Colormap name, forwarded to `matshow`.

  Returns:
    The result of `ax.matshow`.
  """
  first_channel = m[Ellipsis, 0]
  return ax.matshow(
      first_channel,
      cmap=cmap,
      vmin=vmin,
      vmax=vmax,
      interpolation="nearest",
  )
@optional_clean_ax
def show_coords(m, ax):
  """Shows coordinate channels as a single color image.

  The values of `m` are rescaled to [0, 1] using the global min and max and
  then mixed with one color per channel from `get_mask_plot_colors`.

  Args:
    m: Array whose last axis indexes the coordinate channels.
    ax: Axis to draw on.

  Returns:
    The result of `ax.imshow`.
  """
  vmin, vmax = np.min(m), np.max(m)
  value_range = vmax - vmin
  if value_range == 0:
    # Constant input: dividing by the zero range would produce NaNs, so only
    # shift the values to zero.
    m = m - vmin
  else:
    m = (m - vmin) / value_range
  color_conv = get_mask_plot_colors(m.shape[-1])
  color_mask = np.dot(m, color_conv)
  return ax.imshow(color_mask, interpolation="nearest")
def example_plot(rinfo,
                 b=0,
                 t=-1,
                 mask_components=False,
                 size=2,
                 column_titles=True):
  """Plots image, reconstruction, mask and components of one example in a row.

  Args:
    rinfo: Nested mapping providing `rinfo["data"]["image"]` and, under
      `rinfo["outputs"]`, the entries "recons", "pred_mask" and "components".
    b: Index into the first axis of every entry (presumably the batch).
    t: Index into the second axis of the outputs (presumably the iteration).
    mask_components: If True, each component is blended with its own mask.
    size: Width and height of a single panel, as used for `figsize`.
    column_titles: If True, a title is set on every panel.

  Returns:
    The created matplotlib figure.
  """
  image = rinfo["data"]["image"][b, 0]
  outputs = rinfo["outputs"]
  recons = outputs["recons"][b, t, 0]
  pred_mask = outputs["pred_mask"][b, t]
  components = outputs["components"][b, t]

  K, _, _, _ = components.shape
  colors = get_mask_plot_colors(K)

  # Three fixed panels (image, reconstruction, mask) plus one per component.
  ncols = K + 3
  fig, axes = plt.subplots(ncols=ncols, figsize=(ncols * size, size))

  black = "#000000"
  show_img(image, ax=axes[0], color=black)
  show_img(recons, ax=axes[1], color=black)
  show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color=black)

  for k in range(K):
    if mask_components:
      mask = pred_mask[k]
    else:
      mask = None
    show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)

  if column_titles:
    titles = ["Image", "Recons.", "Mask"]
    titles.extend("Component {}".format(k + 1) for k in range(K))
    for ax, title in zip(axes, titles):
      ax.set_title(title)

  plt.subplots_adjust(hspace=0.03, wspace=0.035)
  return fig
def iterations_plot(rinfo, b=0, mask_components=False, size=2):
  """Plots reconstruction, mask and components for every iteration.

  The figure has one row per iteration plus a final row showing the input
  image, the ground-truth mask and the mask logits of the last iteration.

  Args:
    rinfo: Nested mapping providing "image" and "true_mask" under
      `rinfo["data"]`, and "recons", "pred_mask", "pred_mask_logits" and
      "components" under `rinfo["outputs"]`.
    b: Index into the first axis of every entry (presumably the batch).
    mask_components: If True, each component is blended with its own mask.
    size: Width and height of a single panel, as used for `figsize`.

  Returns:
    The created matplotlib figure.
  """
  data = rinfo["data"]
  image = data["image"][b]
  true_mask = data["true_mask"][b]
  outputs = rinfo["outputs"]
  recons = outputs["recons"][b]
  pred_mask = outputs["pred_mask"][b]
  pred_mask_logits = outputs["pred_mask_logits"][b]
  components = outputs["components"][b]

  T, K, _, _, _ = components.shape
  colors = get_mask_plot_colors(K)

  nrows, ncols = T + 1, K + 2
  fig, axes = plt.subplots(
      nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))

  # Rows 0..T-1: one row per iteration.
  for t in range(T):
    recons_ax, mask_ax = axes[t, 0], axes[t, 1]
    show_img(recons[t, 0], ax=recons_ax)
    show_mask(pred_mask[t, Ellipsis, 0], ax=mask_ax)
    recons_ax.set_ylabel("iter {}".format(t))
    for k in range(K):
      if mask_components:
        mask = pred_mask[t, k]
      else:
        mask = None
      show_img(components[t, k], ax=axes[t, k + 2], color=colors[k], mask=mask)

  axes[0, 0].set_title("Reconstruction")
  axes[0, 1].set_title("Mask")

  # Row T: input image, ground-truth mask and last-iteration mask logits.
  show_img(image[0], ax=axes[T, 0])
  show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])

  final_logits = pred_mask_logits[T - 1]
  # A common color scale makes the logit panels comparable across components.
  vmin = np.min(final_logits)
  vmax = np.max(final_logits)

  for k in range(K):
    col = k + 2
    axes[0, col].set_title("Component {}".format(k + 1))  # , color=colors[k])
    show_mat(final_logits[k], ax=axes[T, col], vmin=vmin, vmax=vmax)
    axes[T, col].set_xlabel(
        "Mask Logits for\nComponent {}".format(k + 1))  # , color=colors[k])

  axes[T, 0].set_xlabel("Input Image")
  axes[T, 1].set_xlabel("Ground Truth Mask")

  plt.subplots_adjust(wspace=0.05, hspace=0.05)
  return fig
def inputs_plot(rinfo, b=0, t=0, size=2):
  """Plots the spatial inputs of the refinement network, one row per input.

  Only the inputs present in `rinfo["inputs"]["spatial"]` are plotted. Each
  row has one panel per component plus a narrow last column that holds a
  colorbar for rows drawn with `show_mat` and is hidden for the others.

  Args:
    rinfo: Nested mapping providing `rinfo["outputs"]["components"]` (a 6-D
      array whose third axis is the number of components K) and
      `rinfo["inputs"]["spatial"]`, a mapping from input name to array.
    b: Index into the first axis of every input (presumably the batch).
    t: Index into the second axis of every input (presumably the iteration).
    size: Width and height of a single panel, as used for `figsize`.

  Returns:
    The created matplotlib figure.
  """
  _, _, K, _, _, _ = rinfo["outputs"]["components"].shape
  colors = get_mask_plot_colors(K)
  inputs = rinfo["inputs"]["spatial"]
  # (input name, plotting function, whether the row gets a colorbar)
  rows = [
      ("image", show_img, False),
      ("components", show_img, False),
      ("dcomponents", functools.partial(show_img, norm=True), False),
      ("mask", show_mat, True),
      ("pred_mask", show_mat, True),
      ("dmask", functools.partial(show_mat, cmap="coolwarm"), True),
      ("posterior", show_mat, True),
      ("log_prob", show_mat, True),
      ("counterfactual", show_mat, True),
      ("coordinates", show_coords, False),
  ]
  rows = [(n, f, mcb) for n, f, mcb in rows if n in inputs]
  nrows = len(rows)
  ncols = K + 1

  fig, axes = plt.subplots(
      nrows=nrows,
      ncols=ncols,
      figsize=(ncols * size - size * 0.9, nrows * size),
      gridspec_kw={"width_ratios": [1] * K + [0.1]},
      # Keep `axes` two-dimensional even for a single row; otherwise the
      # axes[r, k] indexing below fails when only one input is present.
      squeeze=False,
  )
  for r, (name, plot_fn, make_cbar) in enumerate(rows):
    axes[r, 0].set_ylabel(name)
    if make_cbar:
      # Use one color scale for the whole row so the colorbar applies to
      # every panel in it.
      vmin = np.min(inputs[name][b, t])
      vmax = np.max(inputs[name][b, t])
      if np.abs(vmin - vmax) < 1e-6:
        # Widen a degenerate range so the color scale stays well defined.
        vmin -= 0.1
        vmax += 0.1
      plot_fn = functools.partial(plot_fn, vmin=vmin, vmax=vmax)
      # print("range of {:<16}: [{:0.2f}, {:0.2f}]".format(name, vmin, vmax))
    for k in range(K):
      if inputs[name].shape[2] == 1:
        # Input with a single entry on the component axis: show the same
        # array in every column.
        m = inputs[name][b, t, 0]
        color = (0.0, 0.0, 0.0)
      else:
        m = inputs[name][b, t, k]
        color = colors[k]
      mappable = plot_fn(m, ax=axes[r, k], color=color)
    if make_cbar:
      fig.colorbar(mappable, cax=axes[r, K])
    else:
      axes[r, K].set_visible(False)
  for k in range(K):
    axes[0, k].set_title("Component {}".format(k + 1))  # , color=colors[k])

  plt.subplots_adjust(hspace=0.05, wspace=0.05)
  return fig
+163
View File
@@ -0,0 +1,163 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Iterative refinement modules."""
# pylint: disable=g-doc-bad-indent, unused-variable
from iodine.modules import utils
import shapeguard
import sonnet as snt
import tensorflow.compat.v1 as tf
class RefinementCore(snt.RNNCore):
  """Recurrent Refinement Module.

  A refinement module consumes:
    * the previous state (possibly an arbitrary nested structure)
    * the current inputs, made up of
      * image-space inputs such as pixel-based errors or mask-posteriors
      * latent-space inputs such as the previous z_dist, or dz

  and produces:
    * an output (usually a new z_dist)
    * the new state
  """

  def __init__(self,
               encoder_net,
               recurrent_net,
               refinement_head,
               name="refinement"):
    """Stores the sub-networks and creates the shared ShapeGuard.

    Args:
      encoder_net: Callable applied to the spatial inputs; the `.params`
        attribute of its result is used.
      recurrent_net: Recurrent module called as `(inputs, prev_state)` and
        providing `initial_state(batch_size)`.
      refinement_head: Callable taking `(zp, hidden)` and returning the
        outputs of this core.
      name: Name of the module.
    """
    super().__init__(name=name)
    self._encoder_net = encoder_net
    self._recurrent_net = recurrent_net
    self._refinement_head = refinement_head
    # One guard shared by _build and the prepare_* helpers.
    self._sg = shapeguard.ShapeGuard()

  def initial_state(self, batch_size, **unused_kwargs):
    """Returns the initial state of the recurrent network."""
    return self._recurrent_net.initial_state(batch_size)

  def _build(self, inputs, prev_state):
    """Runs one refinement step.

    Args:
      inputs: Mapping with the keys "spatial" and "flat"; `inputs["flat"]`
        must contain "zp".
      prev_state: Previous state of the recurrent network.

    Returns:
      A tuple (outputs, next_state) with outputs guarded as "B, K, Y".
    """
    sg = self._sg
    assert "spatial" in inputs, inputs.keys()
    assert "flat" in inputs, inputs.keys()
    flat = inputs["flat"]
    assert "zp" in flat, flat.keys()

    zp = sg.guard(flat["zp"], "B, K, Zp")

    spatial = sg.guard(
        self.prepare_spatial_inputs(inputs["spatial"]), "B*K, H, W, C")
    encoded = sg.guard(self._encoder_net(spatial).params, "B*K, H1")
    merged = sg.guard(self.prepare_flat_inputs(encoded, flat), "B*K, H2")
    recurrent_in = sg.reshape(merged, "B, K, H2")
    recurrent_out, next_state = self._recurrent_net(recurrent_in, prev_state)
    sg.guard(recurrent_out, "B, K, H3")
    outputs = sg.guard(self._refinement_head(zp, recurrent_out), "B, K, Y")

    # Drop the recorded B from the shared guard (presumably so that a later
    # call may bind a different batch size).
    del self._sg.B
    return outputs, next_state

  def prepare_spatial_inputs(self, inputs):
    """Concatenates all spatial inputs and merges the B and K axes.

    Inputs are processed in sorted key order. An input whose second axis has
    size 1 is tiled K times along that axis.

    Args:
      inputs: Mapping from name to tensor guarded as "B, 1, H, W, _C" or
        "B, K, H, W, _C".

    Returns:
      Tensor reshaped to "B*K, H, W, C".
    """
    sg = self._sg
    per_component = []
    for key in sorted(inputs):
      tensor = inputs[key]
      if tensor.shape.as_list()[1] == 1:
        sg.guard(tensor, "B, 1, H, W, _C")
        tensor = tf.tile(tensor, sg["1, K, 1, 1, 1"])
      else:
        sg.guard(tensor, "B, K, H, W, _C")
      per_component.append(tensor)

    stacked = sg.guard(tf.concat(per_component, axis=-1), "B, K, H, W, C")
    return sg.reshape(stacked, "B*K, H, W, C")

  def prepare_flat_inputs(self, hidden, inputs):
    """Concatenates `hidden` with all flat inputs along the last axis.

    Args:
      hidden: Tensor guarded as "B*K, H1".
      inputs: Mapping from name to tensor guarded as "B, K, _"; the entries
        are appended in sorted key order.

    Returns:
      The concatenated tensor with B*K as its first axis.
    """
    sg = self._sg
    pieces = [sg.guard(hidden, "B*K, H1")]
    for key in sorted(inputs):
      tensor = inputs[key]
      sg.guard(tensor, "B, K, _")
      pieces.append(tf.reshape(tensor, sg["B*K"] + [-1]))
    return tf.concat(pieces, axis=-1)
class ResHead(snt.AbstractModule):
  """Refinement head that adds a linear function of the inputs to Zp."""

  def __init__(self, name="residual_head"):
    super().__init__(name=name)

  def _build(self, zp_old, inputs):
    """Computes the residual update.

    Args:
      zp_old: Tensor guarded as "B, K, Zp".
      inputs: Tensor guarded as "B, K, H".

    Returns:
      Tensor reshaped to "B, K, Zp", equal to `zp_old` plus a linear
      projection of `inputs`.
    """
    sg = shapeguard.ShapeGuard()
    sg.guard(zp_old, "B, K, Zp")
    sg.guard(inputs, "B, K, H")

    delta_net = snt.Linear(sg.Zp)

    zp_2d = sg.reshape(zp_old, "B*K, Zp")
    inputs_2d = sg.reshape(inputs, "B*K, H")
    delta = delta_net(inputs_2d)
    return sg.reshape(zp_2d + delta, "B, K, Zp")
class PredictorCorrectorHead(snt.AbstractModule):
  """This refinement head is used for sequential data.

  At every step it computes a correction from the refinement network output
  of the current timestep and a prediction from the corrected λ.
  Both are applied as gated updates:

      λ_corr = (1 - g_c) * λ + g_c * u
      λ'     = (1 - g_p) * λ_corr + g_p * λ_pred

  where u and g_c are linear functions of the inputs (g_c passed through a
  sigmoid), and λ_pred and g_p come from an MLP applied to the old λ.
  """

  def __init__(
      self,
      hidden_sizes=(64,),
      pred_gate_bias=0.0,
      corrector_gate_bias=0.0,
      activation=tf.nn.elu,
      name="predcorr_head",
  ):
    """Stores the configuration of the head.

    Args:
      hidden_sizes: Hidden layer sizes of the prediction MLP.
      pred_gate_bias: Constant added to the prediction gate pre-activation.
      corrector_gate_bias: Constant added to the corrector gate
        pre-activation.
      activation: Activation of the prediction MLP, resolved through
        `utils.get_act_func`.
      name: Name of the module.
    """
    super().__init__(name=name)
    self._hidden_sizes = hidden_sizes
    self._activation = utils.get_act_func(activation)
    self._pred_gate_bias = pred_gate_bias
    self._corrector_gate_bias = corrector_gate_bias

  def _build(self, zp_old, inputs):
    """Applies the gated correction followed by the gated prediction.

    Args:
      zp_old: Tensor guarded as "B, K, Zp".
      inputs: Tensor guarded as "B, K, H".

    Returns:
      Tensor reshaped to "B, K, Zp".
    """
    sg = shapeguard.ShapeGuard()
    sg.guard(zp_old, "B, K, Zp")
    sg.guard(inputs, "B, K, H")
    zp_size = sg.Zp

    corrector = snt.Linear(zp_size)
    corrector_gate = snt.Linear(zp_size)
    # The MLP emits the predicted value and its gate logits side by side.
    predictor = snt.nets.MLP(
        output_sizes=list(self._hidden_sizes) + [zp_size * 2],
        activation=self._activation,
    )

    zp_2d = sg.reshape(zp_old, "B*K, Zp")
    inputs_2d = sg.reshape(inputs, "B*K, H")

    corr_gate = tf.nn.sigmoid(
        corrector_gate(inputs_2d) + self._corrector_gate_bias)
    corr_value = corrector(inputs_2d)

    # Same as (1 - corr_gate) * zp_2d + corr_gate * corr_value, written with
    # a single multiplication.
    zp_corrected = zp_2d + corr_gate * (corr_value - zp_2d)

    prediction = predictor(zp_2d)
    pred_value = prediction[:, :sg.Zp]
    pred_gate = tf.nn.sigmoid(prediction[:, sg.Zp:] + self._pred_gate_bias)

    zp_new = zp_corrected + pred_gate * (pred_value - zp_corrected)
    return sg.reshape(zp_new, "B, K, Zp")
File diff suppressed because it is too large Load Diff