mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-10 05:17:46 +08:00
8457046b2c
PiperOrigin-RevId: 328023346
227 lines
7.2 KiB
Python
227 lines
7.2 KiB
Python
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Plotting tools for IODINE."""

# pylint: disable=unused-import, missing-docstring, unused-variable
# pylint: disable=invalid-name, unexpected-keyword-arg
|
|
import functools
|
|
from iodine.modules.utils import get_mask_plot_colors
|
|
from matplotlib.colors import hsv_to_rgb
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
# Public API of this module. `get_mask_plot_colors` is re-exported from
# iodine.modules.utils (imported above).
__all__ = (
    "get_mask_plot_colors",
    "example_plot",
    "iterations_plot",
    "inputs_plot",
)
|
|
|
|
|
|
def clean_ax(ax, color=None, lw=4.0):
  """Remove the tick marks of `ax` and optionally restyle its frame.

  Args:
    ax: matplotlib Axes to tidy up.
    color: if given, every spine of `ax` is recolored with it and its
      linewidth set to `lw`; if None the spines are left untouched.
    lw: spine linewidth, only used when `color` is not None.
  """
  ax.set_xticks([])
  ax.set_yticks([])
  if color is None:
    return
  for axis_spine in ax.spines.values():
    axis_spine.set_linewidth(lw)
    axis_spine.set_color(color)
|
|
|
|
|
|
def optional_ax(fn):
  """Decorator that supplies a fresh Axes when the caller does not pass one.

  If the wrapped call has no `ax` keyword (or `ax=None`), a new figure of size
  `figsize` is created with `plt.subplots`, and its Axes is forwarded to `fn`
  as the `ax` keyword.

  The `figsize` keyword (default (4, 4)) is always consumed by the wrapper and
  never forwarded to `fn`. This holds even when an explicit `ax` is given;
  previously it leaked through and made `fn` raise a TypeError.

  Args:
    fn: function accepting an `ax` keyword argument.

  Returns:
    The wrapped function. Its name and docstring are preserved via
    `functools.wraps`.
  """

  @functools.wraps(fn)
  def _wrapped(*args, **kwargs):
    # Pop unconditionally so `figsize` never reaches `fn`.
    figsize = kwargs.pop("figsize", (4, 4))
    if kwargs.get("ax") is None:
      _, kwargs["ax"] = plt.subplots(figsize=figsize)
    return fn(*args, **kwargs)

  return _wrapped
|
|
|
|
|
|
def optional_clean_ax(fn):
  """Decorator: provide a default Axes, call `fn`, then tidy the Axes.

  Keyword arguments consumed by the wrapper are never forwarded to `fn`:
    figsize: size of the figure created when no `ax` is given
      (default (4, 4)). It is also consumed when an explicit `ax` is passed;
      previously it leaked into `fn` and raised a TypeError.
    color: spine color handed to `clean_ax` (default None: spines untouched).
    lw: spine linewidth handed to `clean_ax` (default 4.0).

  If the call has no `ax` keyword (or `ax=None`), a new figure is created with
  `plt.subplots` and its Axes is passed to `fn` as `ax`. After `fn` returns,
  `clean_ax` is applied to that Axes and `fn`'s return value is passed through.

  Args:
    fn: function accepting an `ax` keyword argument.

  Returns:
    The wrapped function. Its name and docstring are preserved via
    `functools.wraps`.
  """

  @functools.wraps(fn)
  def _wrapped(*args, **kwargs):
    figsize = kwargs.pop("figsize", (4, 4))
    color = kwargs.pop("color", None)
    lw = kwargs.pop("lw", 4.0)
    if kwargs.get("ax") is None:
      _, kwargs["ax"] = plt.subplots(figsize=figsize)
    result = fn(*args, **kwargs)
    clean_ax(kwargs["ax"], color, lw)
    return result

  return _wrapped
|
|
|
|
|
|
@optional_clean_ax
def show_img(img, mask=None, ax=None, norm=False):
  """Draw `img` on `ax` with its values clipped to [0, 1].

  Args:
    img: numpy array handed to `ax.imshow`. It is presumably an (H, W, C) RGB
      image; TODO confirm.
    mask: optional array broadcastable against `img`. The result is
      `img * mask + (1 - mask)`, so pixels with mask 0 become 1.0 (white) and
      pixels with mask 1 keep their value.
    ax: matplotlib Axes to draw on. The decorator creates one if omitted.
    norm: if True, min-max rescale `img` to [0, 1] before masking. A
      constant image (zero value range) is drawn as all zeros instead of
      producing NaNs from a 0/0 division.

  Returns:
    The AxesImage returned by `ax.imshow`.
  """
  if norm:
    vmin, vmax = np.min(img), np.max(img)
    if vmax > vmin:
      img = (img - vmin) / (vmax - vmin)
    else:
      # Zero range: (img - vmin) is all zeros; avoid dividing by zero.
      img = np.zeros_like(img, dtype=float)
  if mask is not None:
    img = img * mask + np.ones_like(img) * (1.0 - mask)
  return ax.imshow(img.clip(0.0, 1.0), interpolation="nearest")
|
|
|
|
|
|
@optional_clean_ax
def show_mask(m, ax):
  """Render a stack of component masks as one color-coded image.

  Args:
    m: 3-D numpy array with the component axis first; it is transposed so
      that axis becomes last. Presumably (K, H, W) soft masks; TODO confirm.
    ax: matplotlib Axes to draw on.

  Returns:
    The AxesImage returned by `ax.imshow`. The image is the per-pixel matrix
    product of the masks with the color table from `get_mask_plot_colors`,
    clipped to [0, 1].
  """
  num_components = m.shape[0]
  palette = get_mask_plot_colors(num_components)
  components_last = np.transpose(m, (1, 2, 0))
  blended = np.dot(components_last, palette).clip(0.0, 1.0)
  return ax.imshow(blended, interpolation="nearest")
|
|
|
|
|
|
@optional_clean_ax
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
  """Show the first channel (last axis index 0) of `m` as a color-mapped matrix.

  Args:
    m: numpy array whose last axis is a channel axis; only channel 0 is shown.
    ax: matplotlib Axes to draw on.
    vmin, vmax: color scale limits forwarded to `ax.matshow` (None = auto).
    cmap: colormap name forwarded to `ax.matshow`.

  Returns:
    The AxesImage returned by `ax.matshow`.
  """
  first_channel = m[..., 0]
  return ax.matshow(first_channel, interpolation="nearest", cmap=cmap,
                    vmin=vmin, vmax=vmax)
|
|
|
|
|
|
@optional_clean_ax
def show_coords(m, ax):
  """Visualise a multi-channel coordinate map as a color image.

  The whole array is min-max rescaled to [0, 1]. Each channel (last axis) is
  then mapped to a color via `get_mask_plot_colors`, and the channels are
  combined by a matrix product.

  Args:
    m: numpy array whose last axis holds the coordinate channels.
    ax: matplotlib Axes to draw on.

  Returns:
    The AxesImage returned by `ax.imshow`.
  """
  vmin, vmax = np.min(m), np.max(m)
  if vmax > vmin:
    m = (m - vmin) / (vmax - vmin)
  else:
    # Constant input: avoid the 0/0 division that produced NaNs.
    m = np.zeros_like(m, dtype=float)
  color_conv = get_mask_plot_colors(m.shape[-1])
  color_mask = np.dot(m, color_conv)
  return ax.imshow(color_mask, interpolation="nearest")
|
|
|
|
|
|
def example_plot(rinfo,
                 b=0,
                 t=-1,
                 mask_components=False,
                 size=2,
                 column_titles=True):
  """Plot one example as a single row of panels.

  Columns: input image, reconstruction, predicted mask, then one column per
  component (framed in that component's color).

  Args:
    rinfo: nested dict of results. Reads rinfo["data"]["image"] and
      rinfo["outputs"]["recons"], ["pred_mask"] and ["components"].
    b: batch index.
    t: index into the iteration axis of the outputs (default -1, the last).
    mask_components: if True, each component is drawn masked by its own
      predicted mask.
    size: edge length (inches) of each panel.
    column_titles: whether to put a title above each panel.

  Returns:
    The created matplotlib Figure.
  """
  outputs = rinfo["outputs"]
  image = rinfo["data"]["image"][b, 0]
  recons = outputs["recons"][b, t, 0]
  pred_mask = outputs["pred_mask"][b, t]
  components = outputs["components"][b, t]

  # Unpacking also checks that `components` is 4-D.
  num_components, _, _, _ = components.shape
  colors = get_mask_plot_colors(num_components)

  ncols = num_components + 3
  fig, axes = plt.subplots(ncols=ncols, figsize=(ncols * size, size))

  black = "#000000"
  show_img(image, ax=axes[0], color=black)
  show_img(recons, ax=axes[1], color=black)
  show_mask(pred_mask[..., 0], ax=axes[2], color=black)
  for k in range(num_components):
    comp_mask = pred_mask[k] if mask_components else None
    show_img(components[k], ax=axes[3 + k], color=colors[k], mask=comp_mask)

  if column_titles:
    titles = ["Image", "Recons.", "Mask"]
    titles.extend("Component {}".format(k + 1) for k in range(num_components))
    for panel, title in zip(axes, titles):
      panel.set_title(title)
  fig.subplots_adjust(hspace=0.03, wspace=0.035)
  return fig
|
|
|
|
|
|
def iterations_plot(rinfo, b=0, mask_components=False, size=2):
  """Plot the model outputs after every iteration for one example.

  Layout:
    - Rows 0..T-1 show, for iteration t, the reconstruction, the predicted
      mask and each of the K components.
    - The final row shows the input image, the ground-truth mask, and the
      mask logits of the last iteration. All logit panels share one color
      range.

  Args:
    rinfo: nested dict of results. Reads rinfo["data"]["image" | "true_mask"]
      and rinfo["outputs"]["recons" | "pred_mask" | "pred_mask_logits" |
      "components"].
    b: batch index.
    mask_components: if True, each component is drawn masked by its own
      predicted mask.
    size: edge length (inches) of each panel.

  Returns:
    The created matplotlib Figure.
  """
  data, outputs = rinfo["data"], rinfo["outputs"]
  image = data["image"][b]
  true_mask = data["true_mask"][b]
  recons = outputs["recons"][b]
  pred_mask = outputs["pred_mask"][b]
  pred_mask_logits = outputs["pred_mask_logits"][b]
  components = outputs["components"][b]

  T, K, _, _, _ = components.shape
  colors = get_mask_plot_colors(K)
  nrows, ncols = T + 1, K + 2
  fig, axes = plt.subplots(
      nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))

  for it in range(T):
    row = axes[it]
    show_img(recons[it, 0], ax=row[0])
    show_mask(pred_mask[it, ..., 0], ax=row[1])
    row[0].set_ylabel("iter {}".format(it))
    for k in range(K):
      comp_mask = pred_mask[it, k] if mask_components else None
      show_img(components[it, k], ax=row[k + 2], color=colors[k],
               mask=comp_mask)

  axes[0, 0].set_title("Reconstruction")
  axes[0, 1].set_title("Mask")
  show_img(image[0], ax=axes[T, 0])
  show_mask(true_mask[0, ..., 0], ax=axes[T, 1])

  final_logits = pred_mask_logits[T - 1]
  lo, hi = np.min(final_logits), np.max(final_logits)
  for k in range(K):
    label = "Component {}".format(k + 1)
    axes[0, k + 2].set_title(label)
    show_mat(final_logits[k], ax=axes[T, k + 2], vmin=lo, vmax=hi)
    axes[T, k + 2].set_xlabel("Mask Logits for\n" + label)
  axes[T, 0].set_xlabel("Input Image")
  axes[T, 1].set_xlabel("Ground Truth Mask")

  fig.subplots_adjust(wspace=0.05, hspace=0.05)
  return fig
|
|
|
|
|
|
def inputs_plot(rinfo, b=0, t=0, size=2):
  """Plot the spatial inputs stored in rinfo["inputs"]["spatial"].

  Each known input key present in the dict gets one row, with one column per
  component. Rows flagged for a colorbar share a single color range, and the
  colorbar is drawn in an extra narrow last column; in other rows that column
  is hidden.

  Args:
    rinfo: nested dict of results. Reads rinfo["outputs"]["components"] (its
      6-D shape only, for K) and rinfo["inputs"]["spatial"].
    b: batch index.
    t: iteration index.
    size: edge length (inches) of each panel.

  Returns:
    The created matplotlib Figure.

  Raises:
    ValueError: if none of the known input keys is present.
  """
  _, _, K, _, _, _ = rinfo["outputs"]["components"].shape
  colors = get_mask_plot_colors(K)
  inputs = rinfo["inputs"]["spatial"]

  # (input key, plotting function, whether the row gets a colorbar)
  rows = [
      ("image", show_img, False),
      ("components", show_img, False),
      ("dcomponents", functools.partial(show_img, norm=True), False),
      ("mask", show_mat, True),
      ("pred_mask", show_mat, True),
      ("dmask", functools.partial(show_mat, cmap="coolwarm"), True),
      ("posterior", show_mat, True),
      ("log_prob", show_mat, True),
      ("counterfactual", show_mat, True),
      ("coordinates", show_coords, False),
  ]
  rows = [(n, f, mcb) for n, f, mcb in rows if n in inputs]
  if not rows:
    raise ValueError(
        "rinfo['inputs']['spatial'] contains none of the plottable keys.")

  nrows = len(rows)
  ncols = K + 1
  # squeeze=False keeps `axes` 2-D even when only one row is present;
  # otherwise the `axes[r, k]` indexing below raises an IndexError.
  fig, axes = plt.subplots(
      nrows=nrows,
      ncols=ncols,
      figsize=(ncols * size - size * 0.9, nrows * size),
      gridspec_kw={"width_ratios": [1] * K + [0.1]},
      squeeze=False,
  )
  for r, (name, plot_fn, make_cbar) in enumerate(rows):
    values = inputs[name]
    axes[r, 0].set_ylabel(name)
    if make_cbar:
      vmin = np.min(values[b, t])
      vmax = np.max(values[b, t])
      # Widen a (near-)constant range so the colorbar is not degenerate.
      if np.abs(vmin - vmax) < 1e-6:
        vmin -= 0.1
        vmax += 0.1
      plot_fn = functools.partial(plot_fn, vmin=vmin, vmax=vmax)
    # An input with size 1 along axis 2 is shared by all components: draw that
    # single slice in every column with a black frame.
    shared = values.shape[2] == 1
    for k in range(K):
      if shared:
        m, color = values[b, t, 0], (0.0, 0.0, 0.0)
      else:
        m, color = values[b, t, k], colors[k]
      mappable = plot_fn(m, ax=axes[r, k], color=color)
    if make_cbar:
      fig.colorbar(mappable, cax=axes[r, K])
    else:
      axes[r, K].set_visible(False)
  for k in range(K):
    axes[0, k].set_title("Component {}".format(k + 1))

  fig.subplots_adjust(hspace=0.05, wspace=0.05)
  return fig
|