Files
deepmind-research/iodine/modules/plotting.py
T
Florent Altché 8457046b2c Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
2020-08-26 16:54:56 +01:00

227 lines
7.2 KiB
Python

# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Plotting tools for IODINE."""
# pylint: disable=unused-import, missing-docstring, unused-variable
# pylint: disable=invalid-name, unexpected-keyword-arg
import functools
from iodine.modules.utils import get_mask_plot_colors
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import numpy as np
__all__ = ("get_mask_plot_colors", "example_plot", "iterations_plot",
"inputs_plot")
def clean_ax(ax, color=None, lw=4.0):
ax.set_xticks([])
ax.set_yticks([])
if color is not None:
for spine in ax.spines.values():
spine.set_linewidth(lw)
spine.set_color(color)
def optional_ax(fn):
def _wrapped(*args, **kwargs):
if kwargs.get("ax", None) is None:
figsize = kwargs.pop("figsize", (4, 4))
fig, ax = plt.subplots(figsize=figsize)
kwargs["ax"] = ax
return fn(*args, **kwargs)
return _wrapped
def optional_clean_ax(fn):
def _wrapped(*args, **kwargs):
if kwargs.get("ax", None) is None:
figsize = kwargs.pop("figsize", (4, 4))
fig, ax = plt.subplots(figsize=figsize)
kwargs["ax"] = ax
color = kwargs.pop("color", None)
lw = kwargs.pop("lw", 4.0)
res = fn(*args, **kwargs)
clean_ax(kwargs["ax"], color, lw)
return res
return _wrapped
@optional_clean_ax
def show_img(img, mask=None, ax=None, norm=False):
if norm:
vmin, vmax = np.min(img), np.max(img)
img = (img - vmin) / (vmax - vmin)
if mask is not None:
img = img * mask + np.ones_like(img) * (1.0 - mask)
return ax.imshow(img.clip(0.0, 1.0), interpolation="nearest")
@optional_clean_ax
def show_mask(m, ax):
color_conv = get_mask_plot_colors(m.shape[0])
color_mask = np.dot(np.transpose(m, [1, 2, 0]), color_conv)
return ax.imshow(color_mask.clip(0.0, 1.0), interpolation="nearest")
@optional_clean_ax
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
return ax.matshow(
m[..., 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
@optional_clean_ax
def show_coords(m, ax):
vmin, vmax = np.min(m), np.max(m)
m = (m - vmin) / (vmax - vmin)
color_conv = get_mask_plot_colors(m.shape[-1])
color_mask = np.dot(m, color_conv)
return ax.imshow(color_mask, interpolation="nearest")
def example_plot(rinfo,
b=0,
t=-1,
mask_components=False,
size=2,
column_titles=True):
image = rinfo["data"]["image"][b, 0]
recons = rinfo["outputs"]["recons"][b, t, 0]
pred_mask = rinfo["outputs"]["pred_mask"][b, t]
components = rinfo["outputs"]["components"][b, t]
K, H, W, C = components.shape
colors = get_mask_plot_colors(K)
nrows = 1
ncols = 3 + K
fig, axes = plt.subplots(ncols=ncols, figsize=(ncols * size, nrows * size))
show_img(image, ax=axes[0], color="#000000")
show_img(recons, ax=axes[1], color="#000000")
show_mask(pred_mask[..., 0], ax=axes[2], color="#000000")
for k in range(K):
mask = pred_mask[k] if mask_components else None
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
if column_titles:
labels = ["Image", "Recons.", "Mask"
] + ["Component {}".format(k + 1) for k in range(K)]
for ax, title in zip(axes, labels):
ax.set_title(title)
plt.subplots_adjust(hspace=0.03, wspace=0.035)
return fig
def iterations_plot(rinfo, b=0, mask_components=False, size=2):
image = rinfo["data"]["image"][b]
true_mask = rinfo["data"]["true_mask"][b]
recons = rinfo["outputs"]["recons"][b]
pred_mask = rinfo["outputs"]["pred_mask"][b]
pred_mask_logits = rinfo["outputs"]["pred_mask_logits"][b]
components = rinfo["outputs"]["components"][b]
T, K, H, W, C = components.shape
colors = get_mask_plot_colors(K)
nrows = T + 1
ncols = 2 + K
fig, axes = plt.subplots(
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
for t in range(T):
show_img(recons[t, 0], ax=axes[t, 0])
show_mask(pred_mask[t, ..., 0], ax=axes[t, 1])
axes[t, 0].set_ylabel("iter {}".format(t))
for k in range(K):
mask = pred_mask[t, k] if mask_components else None
show_img(components[t, k], ax=axes[t, k + 2], color=colors[k], mask=mask)
axes[0, 0].set_title("Reconstruction")
axes[0, 1].set_title("Mask")
show_img(image[0], ax=axes[T, 0])
show_mask(true_mask[0, ..., 0], ax=axes[T, 1])
vmin = np.min(pred_mask_logits[T - 1])
vmax = np.max(pred_mask_logits[T - 1])
for k in range(K):
axes[0, k + 2].set_title("Component {}".format(k + 1)) # , color=colors[k])
show_mat(
pred_mask_logits[T - 1, k], ax=axes[T, k + 2], vmin=vmin, vmax=vmax)
axes[T, k + 2].set_xlabel(
"Mask Logits for\nComponent {}".format(k + 1)) # , color=colors[k])
axes[T, 0].set_xlabel("Input Image")
axes[T, 1].set_xlabel("Ground Truth Mask")
plt.subplots_adjust(wspace=0.05, hspace=0.05)
return fig
def inputs_plot(rinfo, b=0, t=0, size=2):
B, T, K, H, W, C = rinfo["outputs"]["components"].shape
colors = get_mask_plot_colors(K)
inputs = rinfo["inputs"]["spatial"]
rows = [
("image", show_img, False),
("components", show_img, False),
("dcomponents", functools.partial(show_img, norm=True), False),
("mask", show_mat, True),
("pred_mask", show_mat, True),
("dmask", functools.partial(show_mat, cmap="coolwarm"), True),
("posterior", show_mat, True),
("log_prob", show_mat, True),
("counterfactual", show_mat, True),
("coordinates", show_coords, False),
]
rows = [(n, f, mcb) for n, f, mcb in rows if n in inputs]
nrows = len(rows)
ncols = K + 1
fig, axes = plt.subplots(
nrows=nrows,
ncols=ncols,
figsize=(ncols * size - size * 0.9, nrows * size),
gridspec_kw={"width_ratios": [1] * K + [0.1]},
)
for r, (name, plot_fn, make_cbar) in enumerate(rows):
axes[r, 0].set_ylabel(name)
if make_cbar:
vmin = np.min(inputs[name][b, t])
vmax = np.max(inputs[name][b, t])
if np.abs(vmin - vmax) < 1e-6:
vmin -= 0.1
vmax += 0.1
plot_fn = functools.partial(plot_fn, vmin=vmin, vmax=vmax)
# print("range of {:<16}: [{:0.2f}, {:0.2f}]".format(name, vmin, vmax))
for k in range(K):
if inputs[name].shape[2] == 1:
m = inputs[name][b, t, 0]
color = (0.0, 0.0, 0.0)
else:
m = inputs[name][b, t, k]
color = colors[k]
mappable = plot_fn(m, ax=axes[r, k], color=color)
if make_cbar:
fig.colorbar(mappable, cax=axes[r, K])
else:
axes[r, K].set_visible(False)
for k in range(K):
axes[0, k].set_title("Component {}".format(k + 1)) # , color=colors[k])
plt.subplots_adjust(hspace=0.05, wspace=0.05)
return fig