mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-27 18:25:49 +08:00
Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
This commit is contained in:
committed by
Diego de Las Casas
parent
22c3daff19
commit
8457046b2c
@@ -37,8 +37,8 @@ class ComponentDecoder(snt.AbstractModule):
|
||||
pixel_params = self._pixel_decoder(z_flat).params
|
||||
|
||||
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
|
||||
mask_params = pixel_params[Ellipsis, 0:1]
|
||||
pixel_params = pixel_params[Ellipsis, 1:]
|
||||
mask_params = pixel_params[..., 0:1]
|
||||
pixel_params = pixel_params[..., 1:]
|
||||
|
||||
output = MixtureParameters(
|
||||
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
|
||||
|
||||
@@ -134,8 +134,8 @@ class LocScaleDistribution(DistributionModule):
|
||||
n_channels = params.get_shape().as_list()[-1]
|
||||
assert n_channels % 2 == 0
|
||||
assert n_channels // 2 == self.output_shape[-1]
|
||||
loc = params[Ellipsis, :n_channels // 2]
|
||||
scale = params[Ellipsis, n_channels // 2:]
|
||||
loc = params[..., :n_channels // 2]
|
||||
scale = params[..., n_channels // 2:]
|
||||
|
||||
# apply activation functions
|
||||
if self._scale != "fixed":
|
||||
|
||||
@@ -84,7 +84,7 @@ class FactorRegressor(snt.AbstractModule):
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
|
||||
pred = all_preds[Ellipsis, idx:idx + m.size]
|
||||
pred = all_preds[..., idx:idx + m.size]
|
||||
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
|
||||
idx += m.size
|
||||
|
||||
@@ -165,7 +165,7 @@ class FactorRegressor(snt.AbstractModule):
|
||||
|
||||
@staticmethod
|
||||
def one_hot(f, nr_categories):
|
||||
return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories)
|
||||
return tf.one_hot(tf.cast(f[..., 0], tf.int32), depth=nr_categories)
|
||||
|
||||
@staticmethod
|
||||
def angle_to_vector(theta):
|
||||
@@ -194,7 +194,7 @@ def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
|
||||
pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
||||
labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
|
||||
correct = tf.cast(tf.equal(labels, pred), tf.float32)
|
||||
return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis
|
||||
return tf.reduce_sum(correct * assignment[..., 0]) / num_vis
|
||||
|
||||
|
||||
def r2(labels, pred, assignment, mean_var_tot, num_vis):
|
||||
|
||||
+10
-10
@@ -325,12 +325,12 @@ class IODINE(snt.AbstractModule):
|
||||
[get_components(xd) for xd in iterations["x_dist"]])
|
||||
|
||||
# metrics
|
||||
tm = tf.transpose(true_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
|
||||
tm = tf.transpose(true_mask[..., 0], [0, 1, 3, 4, 2])
|
||||
tm = tf.reshape(tf.tile(tm, sg["1, T, 1, 1, 1"]), sg["B * T, H * W, L"])
|
||||
pm = tf.transpose(pred_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
|
||||
pm = tf.transpose(pred_mask[..., 0], [0, 1, 3, 4, 2])
|
||||
pm = tf.reshape(pm, sg["B * T, H * W, K"])
|
||||
ari = tf.reshape(adjusted_rand_index(tm, pm), sg["B, T"])
|
||||
ari_nobg = tf.reshape(adjusted_rand_index(tm[Ellipsis, 1:], pm), sg["B, T"])
|
||||
ari_nobg = tf.reshape(adjusted_rand_index(tm[..., 1:], pm), sg["B, T"])
|
||||
|
||||
mse = tf.reduce_mean(tf.square(recons - image[:, None]), axis=[2, 3, 4, 5])
|
||||
|
||||
@@ -387,7 +387,7 @@ class IODINE(snt.AbstractModule):
|
||||
factor_info["assignment"].append(fass)
|
||||
for k in fpred:
|
||||
factor_info["predictions"][k].append(
|
||||
tf.reduce_sum(fpred[k] * fass[Ellipsis, None], axis=2))
|
||||
tf.reduce_sum(fpred[k] * fass[..., None], axis=2))
|
||||
factor_info["metrics"][k].append(fscalars[k])
|
||||
|
||||
info["losses"]["factor"] = sg.guard(tf.stack(factor_info["loss"]), "T")
|
||||
@@ -496,7 +496,7 @@ class IODINE(snt.AbstractModule):
|
||||
|
||||
@staticmethod
|
||||
def _get_mask_posterior(out_dist, img):
|
||||
p_comp = out_dist.components_distribution.prob(img[Ellipsis, tf.newaxis, :])
|
||||
p_comp = out_dist.components_distribution.prob(img[..., tf.newaxis, :])
|
||||
posterior = p_comp / (tf.reduce_sum(p_comp, axis=-1, keepdims=True) + 1e-6)
|
||||
return tf.transpose(posterior, [0, 4, 2, 3, 1])
|
||||
|
||||
@@ -506,7 +506,7 @@ class IODINE(snt.AbstractModule):
|
||||
dzp, dxp, dmp = tf.gradients(loss, [zp, out_params.pixel, out_params.mask])
|
||||
|
||||
log_prob = sg.guard(
|
||||
out_dist.log_prob(img)[Ellipsis, tf.newaxis], "B, 1, H, W, 1")
|
||||
out_dist.log_prob(img)[..., tf.newaxis], "B, 1, H, W, 1")
|
||||
|
||||
counterfactual_log_probs = []
|
||||
for k in range(0, self.num_components):
|
||||
@@ -515,7 +515,7 @@ class IODINE(snt.AbstractModule):
|
||||
pixel = tf.concat([out_params.pixel[:, :k], out_params.pixel[:, k + 1:]],
|
||||
axis=1)
|
||||
out_dist_k = self.output_dist(pixel, mask)
|
||||
log_prob_k = out_dist_k.log_prob(img)[Ellipsis, tf.newaxis]
|
||||
log_prob_k = out_dist_k.log_prob(img)[..., tf.newaxis]
|
||||
counterfactual_log_probs.append(log_prob_k)
|
||||
counterfactual = log_prob - tf.concat(counterfactual_log_probs, axis=1)
|
||||
|
||||
@@ -608,7 +608,7 @@ class IODINE(snt.AbstractModule):
|
||||
x_basis = tf.cos(valx * freqs[None, None, None, None, :, None])
|
||||
y_basis = tf.cos(valy * freqs[None, None, None, None, None, :])
|
||||
xy_basis = tf.reshape(x_basis * y_basis, self._sg["1, 1, H, W, F*F"])
|
||||
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[Ellipsis, 1:]
|
||||
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[..., 1:]
|
||||
return coords
|
||||
else:
|
||||
raise KeyError('Unknown coord_type: "{}"'.format(self.coord_type))
|
||||
@@ -632,7 +632,7 @@ class IODINE(snt.AbstractModule):
|
||||
# ########## Mask Monitoring #######
|
||||
if "mask" in data:
|
||||
true_mask = self._sg.guard(data["mask"], "B, T, L, H, W, 1")
|
||||
true_mask = tf.transpose(true_mask[:, -1, Ellipsis, 0], [0, 2, 3, 1])
|
||||
true_mask = tf.transpose(true_mask[:, -1, ..., 0], [0, 2, 3, 1])
|
||||
true_mask = self._sg.reshape(true_mask, "B, H*W, L")
|
||||
else:
|
||||
true_mask = None
|
||||
@@ -648,6 +648,6 @@ class IODINE(snt.AbstractModule):
|
||||
adjusted_rand_index(true_mask, pred_mask))
|
||||
|
||||
scalars["loss/ari_nobg"] = tf.reduce_mean(
|
||||
adjusted_rand_index(true_mask[Ellipsis, 1:], pred_mask))
|
||||
adjusted_rand_index(true_mask[..., 1:], pred_mask))
|
||||
|
||||
return scalars
|
||||
|
||||
@@ -258,7 +258,7 @@ class BroadcastConv(snt.AbstractModule):
|
||||
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
|
||||
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
|
||||
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
|
||||
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
|
||||
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[..., 1:]
|
||||
return tf.concat([output, coords], axis=-1)
|
||||
else:
|
||||
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
|
||||
|
||||
@@ -83,7 +83,7 @@ def show_mask(m, ax):
|
||||
@optional_clean_ax
|
||||
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
|
||||
return ax.matshow(
|
||||
m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
||||
m[..., 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
@@ -115,7 +115,7 @@ def example_plot(rinfo,
|
||||
|
||||
show_img(image, ax=axes[0], color="#000000")
|
||||
show_img(recons, ax=axes[1], color="#000000")
|
||||
show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000")
|
||||
show_mask(pred_mask[..., 0], ax=axes[2], color="#000000")
|
||||
for k in range(K):
|
||||
mask = pred_mask[k] if mask_components else None
|
||||
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
|
||||
@@ -145,7 +145,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
|
||||
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
|
||||
for t in range(T):
|
||||
show_img(recons[t, 0], ax=axes[t, 0])
|
||||
show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1])
|
||||
show_mask(pred_mask[t, ..., 0], ax=axes[t, 1])
|
||||
axes[t, 0].set_ylabel("iter {}".format(t))
|
||||
for k in range(K):
|
||||
mask = pred_mask[t, k] if mask_components else None
|
||||
@@ -154,7 +154,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
|
||||
axes[0, 0].set_title("Reconstruction")
|
||||
axes[0, 1].set_title("Mask")
|
||||
show_img(image[0], ax=axes[T, 0])
|
||||
show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])
|
||||
show_mask(true_mask[0, ..., 0], ax=axes[T, 1])
|
||||
vmin = np.min(pred_mask_logits[T - 1])
|
||||
vmax = np.max(pred_mask_logits[T - 1])
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ def construct_diagnostic_image(
|
||||
components = tf.tile(components[:nr_images], [1, 1, 1, 1, 3])
|
||||
|
||||
if mask_components:
|
||||
components *= masks[:nr_images, Ellipsis, tf.newaxis]
|
||||
components *= masks[:nr_images, ..., tf.newaxis]
|
||||
|
||||
# Pad everything
|
||||
no_pad, pad = (0, 0), (border_width, border_width)
|
||||
@@ -415,7 +415,7 @@ def images_to_grid(
|
||||
if max_grid_width is not None:
|
||||
grid_width = min(max_grid_width, grid_width)
|
||||
|
||||
images = images[: grid_height * grid_width, Ellipsis]
|
||||
images = images[: grid_height * grid_width, ...]
|
||||
|
||||
# Pad with extra blank frames if grid_height x grid_width is less than the
|
||||
# number of frames provided.
|
||||
@@ -460,7 +460,7 @@ def flatten_all_but_last(tensor, n_dims=1):
|
||||
|
||||
def ensure_3d(tensor):
|
||||
if tensor.shape.ndims == 2:
|
||||
return tensor[Ellipsis, None]
|
||||
return tensor[..., None]
|
||||
|
||||
assert tensor.shape.ndims == 3
|
||||
return tensor
|
||||
|
||||
Reference in New Issue
Block a user