Add checkpoints from the ablation study.

PiperOrigin-RevId: 328023346
This commit is contained in:
Florent Altché
2020-08-23 14:26:26 +01:00
committed by Diego de Las Casas
parent 22c3daff19
commit 8457046b2c
33 changed files with 397 additions and 363 deletions
+2 -2
View File
@@ -37,8 +37,8 @@ class ComponentDecoder(snt.AbstractModule):
pixel_params = self._pixel_decoder(z_flat).params
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
mask_params = pixel_params[Ellipsis, 0:1]
pixel_params = pixel_params[Ellipsis, 1:]
mask_params = pixel_params[..., 0:1]
pixel_params = pixel_params[..., 1:]
output = MixtureParameters(
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
+2 -2
View File
@@ -134,8 +134,8 @@ class LocScaleDistribution(DistributionModule):
n_channels = params.get_shape().as_list()[-1]
assert n_channels % 2 == 0
assert n_channels // 2 == self.output_shape[-1]
loc = params[Ellipsis, :n_channels // 2]
scale = params[Ellipsis, n_channels // 2:]
loc = params[..., :n_channels // 2]
scale = params[..., n_channels // 2:]
# apply activation functions
if self._scale != "fixed":
+3 -3
View File
@@ -84,7 +84,7 @@ class FactorRegressor(snt.AbstractModule):
for m in self._mapping:
with tf.name_scope(m.name):
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
pred = all_preds[Ellipsis, idx:idx + m.size]
pred = all_preds[..., idx:idx + m.size]
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
idx += m.size
@@ -165,7 +165,7 @@ class FactorRegressor(snt.AbstractModule):
@staticmethod
def one_hot(f, nr_categories):
return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories)
return tf.one_hot(tf.cast(f[..., 0], tf.int32), depth=nr_categories)
@staticmethod
def angle_to_vector(theta):
@@ -194,7 +194,7 @@ def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
correct = tf.cast(tf.equal(labels, pred), tf.float32)
return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis
return tf.reduce_sum(correct * assignment[..., 0]) / num_vis
def r2(labels, pred, assignment, mean_var_tot, num_vis):
+10 -10
View File
@@ -325,12 +325,12 @@ class IODINE(snt.AbstractModule):
[get_components(xd) for xd in iterations["x_dist"]])
# metrics
tm = tf.transpose(true_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
tm = tf.transpose(true_mask[..., 0], [0, 1, 3, 4, 2])
tm = tf.reshape(tf.tile(tm, sg["1, T, 1, 1, 1"]), sg["B * T, H * W, L"])
pm = tf.transpose(pred_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
pm = tf.transpose(pred_mask[..., 0], [0, 1, 3, 4, 2])
pm = tf.reshape(pm, sg["B * T, H * W, K"])
ari = tf.reshape(adjusted_rand_index(tm, pm), sg["B, T"])
ari_nobg = tf.reshape(adjusted_rand_index(tm[Ellipsis, 1:], pm), sg["B, T"])
ari_nobg = tf.reshape(adjusted_rand_index(tm[..., 1:], pm), sg["B, T"])
mse = tf.reduce_mean(tf.square(recons - image[:, None]), axis=[2, 3, 4, 5])
@@ -387,7 +387,7 @@ class IODINE(snt.AbstractModule):
factor_info["assignment"].append(fass)
for k in fpred:
factor_info["predictions"][k].append(
tf.reduce_sum(fpred[k] * fass[Ellipsis, None], axis=2))
tf.reduce_sum(fpred[k] * fass[..., None], axis=2))
factor_info["metrics"][k].append(fscalars[k])
info["losses"]["factor"] = sg.guard(tf.stack(factor_info["loss"]), "T")
@@ -496,7 +496,7 @@ class IODINE(snt.AbstractModule):
@staticmethod
def _get_mask_posterior(out_dist, img):
p_comp = out_dist.components_distribution.prob(img[Ellipsis, tf.newaxis, :])
p_comp = out_dist.components_distribution.prob(img[..., tf.newaxis, :])
posterior = p_comp / (tf.reduce_sum(p_comp, axis=-1, keepdims=True) + 1e-6)
return tf.transpose(posterior, [0, 4, 2, 3, 1])
@@ -506,7 +506,7 @@ class IODINE(snt.AbstractModule):
dzp, dxp, dmp = tf.gradients(loss, [zp, out_params.pixel, out_params.mask])
log_prob = sg.guard(
out_dist.log_prob(img)[Ellipsis, tf.newaxis], "B, 1, H, W, 1")
out_dist.log_prob(img)[..., tf.newaxis], "B, 1, H, W, 1")
counterfactual_log_probs = []
for k in range(0, self.num_components):
@@ -515,7 +515,7 @@ class IODINE(snt.AbstractModule):
pixel = tf.concat([out_params.pixel[:, :k], out_params.pixel[:, k + 1:]],
axis=1)
out_dist_k = self.output_dist(pixel, mask)
log_prob_k = out_dist_k.log_prob(img)[Ellipsis, tf.newaxis]
log_prob_k = out_dist_k.log_prob(img)[..., tf.newaxis]
counterfactual_log_probs.append(log_prob_k)
counterfactual = log_prob - tf.concat(counterfactual_log_probs, axis=1)
@@ -608,7 +608,7 @@ class IODINE(snt.AbstractModule):
x_basis = tf.cos(valx * freqs[None, None, None, None, :, None])
y_basis = tf.cos(valy * freqs[None, None, None, None, None, :])
xy_basis = tf.reshape(x_basis * y_basis, self._sg["1, 1, H, W, F*F"])
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[Ellipsis, 1:]
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[..., 1:]
return coords
else:
raise KeyError('Unknown coord_type: "{}"'.format(self.coord_type))
@@ -632,7 +632,7 @@ class IODINE(snt.AbstractModule):
# ########## Mask Monitoring #######
if "mask" in data:
true_mask = self._sg.guard(data["mask"], "B, T, L, H, W, 1")
true_mask = tf.transpose(true_mask[:, -1, Ellipsis, 0], [0, 2, 3, 1])
true_mask = tf.transpose(true_mask[:, -1, ..., 0], [0, 2, 3, 1])
true_mask = self._sg.reshape(true_mask, "B, H*W, L")
else:
true_mask = None
@@ -648,6 +648,6 @@ class IODINE(snt.AbstractModule):
adjusted_rand_index(true_mask, pred_mask))
scalars["loss/ari_nobg"] = tf.reduce_mean(
adjusted_rand_index(true_mask[Ellipsis, 1:], pred_mask))
adjusted_rand_index(true_mask[..., 1:], pred_mask))
return scalars
+1 -1
View File
@@ -258,7 +258,7 @@ class BroadcastConv(snt.AbstractModule):
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[..., 1:]
return tf.concat([output, coords], axis=-1)
else:
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
+4 -4
View File
@@ -83,7 +83,7 @@ def show_mask(m, ax):
@optional_clean_ax
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
return ax.matshow(
m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
m[..., 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
@optional_clean_ax
@@ -115,7 +115,7 @@ def example_plot(rinfo,
show_img(image, ax=axes[0], color="#000000")
show_img(recons, ax=axes[1], color="#000000")
show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000")
show_mask(pred_mask[..., 0], ax=axes[2], color="#000000")
for k in range(K):
mask = pred_mask[k] if mask_components else None
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
@@ -145,7 +145,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
for t in range(T):
show_img(recons[t, 0], ax=axes[t, 0])
show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1])
show_mask(pred_mask[t, ..., 0], ax=axes[t, 1])
axes[t, 0].set_ylabel("iter {}".format(t))
for k in range(K):
mask = pred_mask[t, k] if mask_components else None
@@ -154,7 +154,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
axes[0, 0].set_title("Reconstruction")
axes[0, 1].set_title("Mask")
show_img(image[0], ax=axes[T, 0])
show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])
show_mask(true_mask[0, ..., 0], ax=axes[T, 1])
vmin = np.min(pred_mask_logits[T - 1])
vmax = np.max(pred_mask_logits[T - 1])
+3 -3
View File
@@ -123,7 +123,7 @@ def construct_diagnostic_image(
components = tf.tile(components[:nr_images], [1, 1, 1, 1, 3])
if mask_components:
components *= masks[:nr_images, Ellipsis, tf.newaxis]
components *= masks[:nr_images, ..., tf.newaxis]
# Pad everything
no_pad, pad = (0, 0), (border_width, border_width)
@@ -415,7 +415,7 @@ def images_to_grid(
if max_grid_width is not None:
grid_width = min(max_grid_width, grid_width)
images = images[: grid_height * grid_width, Ellipsis]
images = images[: grid_height * grid_width, ...]
# Pad with extra blank frames if grid_height x grid_width is less than the
# number of frames provided.
@@ -460,7 +460,7 @@ def flatten_all_but_last(tensor, n_dims=1):
def ensure_3d(tensor):
if tensor.shape.ndims == 2:
return tensor[Ellipsis, None]
return tensor[..., None]
assert tensor.shape.ndims == 3
return tensor