You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

732 lines
30 KiB

3 years ago
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.ticker as ticker
import numpy as np
from .misc import to_numpy
# Render with the non-interactive Agg backend; every function in this module
# only writes figures to disk.
plt.switch_backend('agg')

# Output settings shared by the plotting functions below. Set `pub = True` to
# write PDF files at 300 dpi; otherwise PNG files are written and `dpi=None`
# is passed through to `savefig`.
pub = False
if pub:
    ext = 'pdf'
    dpi = 300
else:
    ext = 'png'
    dpi = None
def plot_prediction_det(save_dir, target, prediction, epoch, index,
                        plot_fn='contourf', cmap='jet', same_scale=False,
                        row_labels=None, col_labels=None):
    """Plot target, prediction and their difference for one input.

    The figure has three rows (target, prediction, target - prediction) and
    one column per field. It is saved to
    ``<save_dir>/pred_epoch<epoch>_<index>.<ext>``.

    Args:
        save_dir (str): directory to save predictions
        target: array-like accepted by `to_numpy`, e.g. shape (3, 65, 65)
        prediction: array-like accepted by `to_numpy`, same shape as `target`
        epoch (int): which epoch (used in the file name)
        index (int): i-th prediction (used in the file name)
        plot_fn (str): choices=['contourf', 'imshow']
        cmap (str): colormap name
        same_scale (bool): if True, the difference row uses the same color
            limits as the target/prediction rows of that field
        row_labels (sequence of str, optional): y-labels of the three rows
        col_labels (sequence of str, optional): titles of the columns

    Raises:
        ValueError: if `plot_fn` is not 'contourf' or 'imshow'.
    """
    # Fail before a figure is created instead of hitting an unbound `cax`.
    if plot_fn not in ('contourf', 'imshow'):
        raise ValueError(
            f"plot_fn must be 'contourf' or 'imshow', got {plot_fn!r}")

    target, prediction = to_numpy(target), to_numpy(prediction)

    if row_labels is not None:
        rows = row_labels
    else:
        rows = ['Simulation', 'Prediction', r'Simulation $-$ Prediction']
    if col_labels is not None:
        cols = col_labels
    else:
        cols = ['Pressure', 'Horizontal Flux', 'Vertical Flux']

    n_fields = target.shape[0]
    # Rows stacked along axis 0: [target fields, prediction fields, diffs].
    samples = np.concatenate((target, prediction, target - prediction), axis=0)
    interp = None

    # Per-field color limits shared by the target and prediction rows.
    vmin = [np.amin(samples[[i, i + n_fields]]) for i in range(n_fields)]
    vmax = [np.amax(samples[[i, i + n_fields]]) for i in range(n_fields)]

    # squeeze=False keeps `axes` 2-D even when there is a single field.
    fig, axes = plt.subplots(3, n_fields, figsize=(3.75 * n_fields, 9),
                             squeeze=False)
    for j, ax in enumerate(axes.flat):
        ax.set_aspect('equal')
        ax.set_xticks([])
        ax.set_yticks([])

        field = j % n_fields
        if j < 2 * n_fields or same_scale:
            lo, hi = vmin[field], vmax[field]
        else:
            # Difference row: let Matplotlib autoscale.
            lo, hi = None, None

        if plot_fn == 'contourf':
            # The limits are passed on the difference row too, so that
            # `same_scale` is honored for contour plots as well.
            cax = ax.contourf(samples[j], 50, cmap=cmap, vmin=lo, vmax=hi)
            # Hide the seams between contour levels. Newer Matplotlib makes
            # the ContourSet itself a Collection and drops `.collections`.
            if hasattr(cax, 'set_edgecolor'):
                level_artists = [cax]
            else:
                level_artists = cax.collections
            for c in level_artists:
                c.set_edgecolor("face")
                c.set_linewidth(0.000000000001)
        else:
            cax = ax.imshow(samples[j], cmap=cmap, origin='upper',
                            interpolation=interp, vmin=lo, vmax=hi)

        cbar = plt.colorbar(cax, ax=ax, fraction=0.046, pad=0.04,
                            format=ticker.ScalarFormatter(useMathText=True))
        cbar.formatter.set_powerlimits((-2, 2))
        cbar.ax.yaxis.set_offset_position('left')
        cbar.update_ticks()

    for ax, col in zip(axes[0], cols):
        ax.set_title(col, size='large')
    for ax, row in zip(axes[:, 0], rows):
        ax.set_ylabel(row, rotation=90, size='large')

    plt.tight_layout(pad=0.05, w_pad=0.05, h_pad=0.05)
    plt.savefig(save_dir + '/pred_epoch{}_{}.{}'.format(epoch, index, ext),
                dpi=dpi, bbox_inches='tight')
    plt.close(fig)
def plot_prediction_det_animate2(save_dir, target, prediction, epoch, index, i_plot,
                                 plot_fn='imshow', cmap='jet', same_scale=False,
                                 vmax=None, vmin=None, vmax_err=None, vmin_err=None):
    """Plot one frame of target, prediction and absolute error.

    Same layout as a 3-row prediction plot, but the color limits can be
    supplied by the caller so that successive frames share one scale. The
    figure is saved to ``<save_dir>/pred_<index>_<i_plot>.<ext>``.

    Args:
        save_dir (str): directory to save predictions
        target: array-like accepted by `to_numpy`, e.g. shape (3, 65, 65)
        prediction: array-like accepted by `to_numpy`, same shape as `target`
        epoch (int): which epoch (currently unused; kept for the interface)
        index (int): i-th prediction (used in the file name)
        i_plot (int): frame counter (used in the file name)
        plot_fn (str): choices=['contourf', 'imshow']
        cmap (str): colormap name
        same_scale (bool): if True, the error row defaults to the color
            limits of the target/prediction rows
        vmax, vmin (sequence, optional): per-field limits for the target and
            prediction rows; whichever is None is computed from the data
        vmax_err, vmin_err (sequence, optional): per-field limits for the
            error row; each one overrides the `same_scale` choice when given

    Raises:
        ValueError: if `plot_fn` is not 'contourf' or 'imshow'.
    """
    # Fail before a figure is created instead of hitting an unbound `cax`.
    if plot_fn not in ('contourf', 'imshow'):
        raise ValueError(
            f"plot_fn must be 'contourf' or 'imshow', got {plot_fn!r}")

    target, prediction = to_numpy(target), to_numpy(prediction)

    rows = ['Simulation', 'Prediction', 'Abs Error']
    cols = ['Pressure', 'Horizontal Flux', 'Vertical Flux']

    n_fields = target.shape[0]
    # Rows stacked along axis 0: [target fields, prediction fields, |diff|].
    samples = np.concatenate((target, prediction, abs(target - prediction)), axis=0)
    interp = None

    # Fill in each missing limit independently, so that passing only one of
    # vmin / vmax neither discards it nor indexes into None.
    if vmin is None:
        vmin = [np.amin(samples[[i, i + n_fields]]) for i in range(n_fields)]
    if vmax is None:
        vmax = [np.amax(samples[[i, i + n_fields]]) for i in range(n_fields)]

    # squeeze=False keeps `axes` 2-D even when there is a single field.
    fig, axes = plt.subplots(3, n_fields, figsize=(3.5 * n_fields, 9),
                             squeeze=False)
    for j, ax in enumerate(axes.flat):
        ax.set_aspect('equal')
        ax.set_xticks([])
        ax.set_yticks([])

        field = j % n_fields
        if j < 2 * n_fields:
            lo, hi = vmin[field], vmax[field]
        else:
            if same_scale:
                lo, hi = vmin[field], vmax[field]
            else:
                lo, hi = None, None
            # Explicit error limits take precedence, one bound at a time.
            if vmin_err is not None:
                lo = vmin_err[field]
            if vmax_err is not None:
                hi = vmax_err[field]

        if plot_fn == 'contourf':
            # The limits are passed on the error row too, so that
            # `same_scale` / `v*_err` are honored for contour plots as well.
            cax = ax.contourf(samples[j], 50, cmap=cmap, vmin=lo, vmax=hi)
            # Hide the seams between contour levels. Newer Matplotlib makes
            # the ContourSet itself a Collection and drops `.collections`.
            if hasattr(cax, 'set_edgecolor'):
                level_artists = [cax]
            else:
                level_artists = cax.collections
            for c in level_artists:
                c.set_edgecolor("face")
                c.set_linewidth(0.000000000001)
        else:
            cax = ax.imshow(samples[j], cmap=cmap, origin='upper',
                            interpolation=interp, vmin=lo, vmax=hi)

        cbar = plt.colorbar(cax, ax=ax, fraction=0.046, pad=0.04,
                            format=ticker.ScalarFormatter(useMathText=True))
        cbar.formatter.set_powerlimits((-2, 2))
        cbar.ax.yaxis.set_offset_position('left')
        cbar.update_ticks()

    for ax, col in zip(axes[0], cols):
        ax.set_title(col)
    for ax, row in zip(axes[:, 0], rows):
        ax.set_ylabel(row, rotation=90)

    plt.tight_layout(pad=0.05, w_pad=0.05, h_pad=0.05)
    plt.subplots_adjust(top=0.93)
    plt.savefig(save_dir + '/pred_{}_{}.{}'.format(index, i_plot, ext),
                dpi=dpi, bbox_inches='tight')
    plt.close(fig)
def plot_prediction_bayes2(save_dir, target, pred_mean, pred_var, epoch, index,
                           plot_fn='imshow', cmap='jet', same_scale=False):
    """Plot a Bayesian prediction for one input as a 4-row grid.

    Rows are: target, predictive mean, predictive standard deviation
    (``sqrt(pred_var)``) and target - predictive mean; one column per field.
    The figure is saved to ``<save_dir>/pred_epoch<epoch>_<index>.<ext>``.

    Args:
        save_dir (str): directory to save predictions
        target: array-like accepted by `to_numpy`, e.g. shape (3, 65, 65)
        pred_mean: array-like accepted by `to_numpy`, same shape as `target`
        pred_var: array-like accepted by `to_numpy`, same shape as `target`
        epoch (int): which epoch (used in the file name)
        index (int): i-th prediction (used in the file name)
        plot_fn (str): choices=['contourf', 'imshow']
        cmap (str): colormap name
        same_scale (bool): if True, the std and error rows use the same color
            limits as the target/mean rows of that field

    Raises:
        ValueError: if `plot_fn` is not 'contourf' or 'imshow'.
    """
    # Fail before a figure is created instead of hitting an unbound `cax`.
    if plot_fn not in ('contourf', 'imshow'):
        raise ValueError(
            f"plot_fn must be 'contourf' or 'imshow', got {plot_fn!r}")

    target, pred_mean = to_numpy(target), to_numpy(pred_mean)
    pred_std = np.sqrt(to_numpy(pred_var))

    rows = ['Simulation', 'Pred Mean', 'Pred Std', r'Sim $-$ Pred Mean']
    cols = ['Pressure', 'Horizontal Flux', 'Vertical Flux']

    n_fields = target.shape[0]
    # Stacked on a new leading axis: (row, field, H, W).
    samples = np.stack((target, pred_mean, pred_std, target - pred_mean), axis=0)
    nrows = samples.shape[0]
    interp = None

    # Per-field color limits shared by the target and mean rows.
    vmin = [np.amin(samples[[0, 1], j]) for j in range(n_fields)]
    vmax = [np.amax(samples[[0, 1], j]) for j in range(n_fields)]

    # squeeze=False keeps `axes` 2-D even when there is a single field.
    fig, axes = plt.subplots(nrows, n_fields,
                             figsize=(3.75 * n_fields, 3 * nrows),
                             squeeze=False)
    for i in range(nrows):
        for j in range(n_fields):
            ax = axes[i, j]
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

            if i < 2 or same_scale:
                lo, hi = vmin[j], vmax[j]
            else:
                # Std / error rows: let Matplotlib autoscale.
                lo, hi = None, None

            if plot_fn == 'contourf':
                # The limits are passed on the std / error rows too, so that
                # `same_scale` is honored for contour plots as well.
                cax = ax.contourf(samples[i, j], 50, cmap=cmap,
                                  vmin=lo, vmax=hi)
                # Hide the seams between contour levels. Newer Matplotlib
                # makes the ContourSet itself a Collection and drops
                # `.collections`.
                if hasattr(cax, 'set_edgecolor'):
                    level_artists = [cax]
                else:
                    level_artists = cax.collections
                for c in level_artists:
                    c.set_edgecolor("face")
                    c.set_linewidth(0.000000000001)
            else:
                cax = ax.imshow(samples[i, j], cmap=cmap, origin='upper',
                                interpolation=interp, vmin=lo, vmax=hi)

            cbar = plt.colorbar(cax, ax=ax, fraction=0.046, pad=0.04,
                                format=ticker.ScalarFormatter(useMathText=True))
            cbar.formatter.set_powerlimits((-2, 2))
            cbar.ax.yaxis.set_offset_position('left')
            cbar.update_ticks()

    for ax, col in zip(axes[0], cols):
        ax.set_title(col, size='large')
    for ax, row in zip(axes[:, 0], rows):
        ax.set_ylabel(row, rotation=90, size='large')

    plt.tight_layout(pad=0.05, w_pad=0.05, h_pad=0.05)
    plt.savefig(save_dir + '/pred_epoch{}_{}.{}'.format(epoch, index, ext),
                dpi=dpi, bbox_inches='tight')
    plt.close(fig)
def save_stats(save_dir, logger, *metrics):
    """Write each requested metric history to ``<save_dir>/<metric>.txt``.

    Args:
        save_dir (str): directory the text files are written to
        logger (mapping): maps a metric name to its recorded values; each
            entry is passed to `np.savetxt` as-is
        *metrics (str): names of the metrics to save

    Raises:
        KeyError: if a requested metric is missing from `logger` (when
            `logger` is a dict).
    """
    for metric in metrics:
        metric_list = logger[metric]
        np.savetxt(save_dir + f'/{metric}.txt', metric_list)
def plot_prediction_bayes(save_dir, target, pred_mean, pred_var, epoch, index,
                          plot_fn='contourf'):
    """Plot predictions at *one* test input on a fixed 2 x 3 layout.

    Top row of grids (one per channel): target above predictive mean, with
    one shared colorbar. Bottom row of grids: prediction error
    (target - mean) above two predictive standard deviations, each with its
    own colorbar. The layout is hard-coded for three channels (indices 0-2).
    The figure is saved to
    ``<save_dir>/pred_at_x_epoch<epoch>_<index>.<ext>``.

    Args:
        save_dir (str): directory to save predictions
        target (np.ndarray or torch.Tensor): (3, 65, 65)
        pred_mean (np.ndarray or torch.Tensor): (3, 65, 65)
        pred_var (np.ndarray or torch.Tensor): (3, 65, 65)
        epoch (int): which epoch (used in the file name)
        index (int): i-th prediction (used in the file name)
        plot_fn (str): choices=['contourf', 'imshow']

    Raises:
        ValueError: if `plot_fn` is not 'contourf' or 'imshow'.
    """
    # Fail before a figure is created instead of hitting an unbound `im`.
    if plot_fn not in ('contourf', 'imshow'):
        raise ValueError(
            f"plot_fn must be 'contourf' or 'imshow', got {plot_fn!r}")

    target, pred_mean, pred_var = to_numpy(target), to_numpy(pred_mean), to_numpy(pred_var)
    pred_error = target - pred_mean
    two_sigma = np.sqrt(pred_var) * 2

    sfmt = ticker.ScalarFormatter(useMathText=True)
    sfmt.set_powerlimits((-2, 2))
    cmap = 'jet'
    interpolation = None
    axes_pad = 0.25
    cbar_pad = 0.1
    label_size = 6

    def draw(ax, field, **limits):
        """Draw one 2-D field on `ax` with `plot_fn`; return the artist."""
        if plot_fn == 'contourf':
            im = ax.contourf(field, 50, cmap=cmap, **limits)
            # Hide the seams between contour levels. Newer Matplotlib makes
            # the ContourSet itself a Collection and drops `.collections`.
            if hasattr(im, 'set_edgecolor'):
                level_artists = [im]
            else:
                level_artists = im.collections
            for c in level_artists:
                c.set_edgecolor("face")
                c.set_linewidth(0.000000000001)
        else:
            im = ax.imshow(field, interpolation=interpolation, cmap=cmap,
                           **limits)
        ax.set_axis_off()
        return im

    fig = plt.figure(1, (11, 12))
    # Integer subplot codes 231..236 (2 rows x 3 columns of grids).
    for i, subplot_i in enumerate(range(231, 237)):
        if i < 3:
            # Target over mean for channel i, sharing one colorbar.
            grid = ImageGrid(fig, subplot_i,
                             nrows_ncols=(2, 1),
                             axes_pad=axes_pad,
                             share_all=False,
                             cbar_location="right",
                             cbar_mode="single",
                             cbar_size="3%",
                             cbar_pad=cbar_pad,
                             )
            data = (target[i], pred_mean[i])
            channel = np.concatenate(data)
            vmin, vmax = np.amin(channel), np.amax(channel)
            for j, ax in enumerate(grid):
                im = draw(ax, data[j], vmin=vmin, vmax=vmax)
            # Both panels share vmin/vmax, so the last artist drives the
            # single colorbar.
            cbar = grid.cbar_axes[0].colorbar(im, format=sfmt)
            cbar.ax.yaxis.set_offset_position('left')
            cbar.ax.tick_params(labelsize=label_size)
            # Older axes_grid1 releases need toggle_label to show the
            # colorbar labels; newer releases dropped the method.
            if hasattr(cbar.ax, 'toggle_label'):
                cbar.ax.toggle_label(True)
        else:
            # Error over two-sigma for channel i - 3, one colorbar each.
            grid = ImageGrid(fig, subplot_i,
                             nrows_ncols=(2, 1),
                             axes_pad=axes_pad,
                             share_all=False,
                             cbar_location="right",
                             cbar_mode="each",
                             cbar_size="6%",
                             cbar_pad=cbar_pad,
                             )
            data = (pred_error[i - 3], two_sigma[i - 3])
            for j, ax in enumerate(grid):
                im = draw(ax, data[j])
                cbar_ax = grid.cbar_axes[j]
                cbar = cbar_ax.colorbar(im, format=sfmt)
                cbar_ax.tick_params(labelsize=label_size)
                if hasattr(cbar_ax, 'toggle_label'):
                    cbar_ax.toggle_label(True)
                cbar.ax.yaxis.set_offset_position('left')

    fig.subplots_adjust(wspace=0.075, hspace=0.075)
    plt.savefig(save_dir + '/pred_at_x_epoch{}_{}.{}'.format(epoch, index, ext),
                dpi=dpi, bbox_inches='tight')
    plt.close(fig)
def plot_MC(save_dir, monte_carlo, pred_mean, pred_var, mean, n_train):
    """Plot Monte Carlo output against the surrogate estimate (contour plots).

    Fixed 2 x 3 layout of image grids. Top row (one grid per channel): Monte
    Carlo result above surrogate mean, sharing one colorbar. Bottom row:
    error (Monte Carlo - surrogate mean) above two standard deviations
    (``2 * sqrt(pred_var)``), each with its own colorbar. The layout is
    hard-coded for three channels (indices 0-2). The figure is always saved
    as a 300-dpi PDF named ``pred_mean_vs_MC.pdf`` or ``pred_var_vs_MC.pdf``.

    Args:
        save_dir (str): directory to save the figure
        monte_carlo (np.ndarray or torch.Tensor): simulation output
        pred_mean (np.ndarray or torch.Tensor): from surrogate
        pred_var (np.ndarray or torch.Tensor): predictive var using surrogate
        mean (bool): selects the file name and printed message only. True for
            plotting mean, False for var
        n_train (int): number of training data; only printed
    """
    monte_carlo, pred_mean, pred_var = to_numpy(monte_carlo), \
        to_numpy(pred_mean), to_numpy(pred_var)
    two_sigma = 2 * np.sqrt(pred_var)
    pred_error = monte_carlo - pred_mean

    sfmt = ticker.ScalarFormatter(useMathText=True)
    sfmt.set_powerlimits((0, 0))
    cmap = 'jet'
    axes_pad = 0.25
    cbar_pad = 0.1
    label_size = 6

    def draw(ax, field, **limits):
        """Draw one 2-D field on `ax` as filled contours; return the artist."""
        im = ax.contourf(field, 50, cmap=cmap, **limits)
        # Hide the seams between contour levels. Newer Matplotlib makes the
        # ContourSet itself a Collection and drops `.collections`.
        if hasattr(im, 'set_edgecolor'):
            level_artists = [im]
        else:
            level_artists = im.collections
        for c in level_artists:
            c.set_edgecolor("face")
            c.set_linewidth(0.000000000001)
        ax.set_axis_off()
        return im

    fig = plt.figure(1, (10, 10))
    # Integer subplot codes 231..236 (2 rows x 3 columns of grids).
    for i, subplot_i in enumerate(range(231, 237)):
        if i < 3:
            # Monte Carlo over surrogate mean for channel i, one colorbar.
            grid = ImageGrid(fig, subplot_i,
                             nrows_ncols=(2, 1),
                             axes_pad=axes_pad,
                             share_all=False,
                             cbar_location="right",
                             cbar_mode="single",
                             cbar_size="3%",
                             cbar_pad=cbar_pad,
                             )
            data = (monte_carlo[i], pred_mean[i])
            channel = np.concatenate(data)
            vmin, vmax = np.amin(channel), np.amax(channel)
            for j, ax in enumerate(grid):
                im = draw(ax, data[j], vmin=vmin, vmax=vmax)
            # Both panels share vmin/vmax, so the last artist drives the
            # single colorbar.
            cbar = grid.cbar_axes[0].colorbar(im, format=sfmt)
            cbar.ax.tick_params(labelsize=label_size)
            cbar.ax.yaxis.set_offset_position('left')
            # Older axes_grid1 releases need toggle_label to show the
            # colorbar labels; newer releases dropped the method.
            if hasattr(cbar.ax, 'toggle_label'):
                cbar.ax.toggle_label(True)
        else:
            # Error over two-sigma for channel i - 3, one colorbar each.
            grid = ImageGrid(fig, subplot_i,
                             nrows_ncols=(2, 1),
                             axes_pad=axes_pad,
                             share_all=False,
                             cbar_location="right",
                             cbar_mode="each",
                             cbar_size="6%",
                             cbar_pad=cbar_pad,
                             )
            data = (pred_error[i - 3], two_sigma[i - 3])
            for j, ax in enumerate(grid):
                im = draw(ax, data[j])
                cbar_ax = grid.cbar_axes[j]
                cbar = cbar_ax.colorbar(im, format=sfmt)
                cbar_ax.tick_params(labelsize=label_size)
                if hasattr(cbar_ax, 'toggle_label'):
                    cbar_ax.toggle_label(True)
                cbar.ax.yaxis.set_offset_position('left')

    fig.subplots_adjust(wspace=0.075, hspace=0.075)
    plt.savefig(save_dir + '/pred_{}_vs_MC.pdf'.format('mean' if mean else 'var'),
                dpi=300, bbox_inches='tight')
    plt.close(fig)
    print("Done plotting Pred_{}_vs_MC, num of training: {}"
          .format('mean' if mean else 'var', n_train))
def plot_MC2(save_dir, monte_carlo, pred_mean, pred_var, mean, ntrain,
             plot_fn='imshow', cmap='jet', manual_scale=False, same_scale=False):
    """Plot Monte Carlo output against the surrogate estimate as a 4-row grid.

    Rows are: Monte Carlo result, surrogate mean, two standard deviations
    (``2 * sqrt(pred_var)``) and Monte Carlo - surrogate mean; one column per
    field. The figure is always saved as a 300-dpi PDF named
    ``pred_mean_vs_MC.pdf`` or ``pred_var_vs_MC.pdf``.

    Args:
        save_dir (str): directory to save the figure
        monte_carlo (np.ndarray or torch.Tensor): simulation output
        pred_mean (np.ndarray or torch.Tensor): from surrogate
        pred_var (np.ndarray or torch.Tensor): predictive var using surrogate
        mean (bool): True for plotting mean, False for var; selects the row
            labels, the file name and the printed message
        ntrain (int): number of training data; only printed
        plot_fn (str): choices=['contourf', 'imshow']
        cmap (str): colormap name
        manual_scale (bool): if True and `mean` is True, the color limits of
            the second field are fixed to [1.0, 1.1]
        same_scale (bool): if True, the last two rows use the same color
            limits as the first two rows of that field

    Raises:
        ValueError: if `plot_fn` is not 'contourf' or 'imshow'.
    """
    # Fail before a figure is created instead of hitting an unbound `cax`.
    if plot_fn not in ('contourf', 'imshow'):
        raise ValueError(
            f"plot_fn must be 'contourf' or 'imshow', got {plot_fn!r}")

    target, pred_mean = to_numpy(monte_carlo), to_numpy(pred_mean)
    pred_std = np.sqrt(to_numpy(pred_var))

    if mean:
        rows = ['Monte Carlo', 'Mean of Est. Mean', '2 Std of Est. Mean', 'Row1 - Row2']
    else:
        rows = ['Monte Carlo', 'Mean of Est. Variance', '2 Std of Est. Variance', 'Row1 - Row2']
    cols = ['Pressure', 'Horizontal Flux', 'Vertical Flux']

    n_fields = target.shape[0]
    # Stacked on a new leading axis: (row, field, H, W).
    samples = np.stack((target, pred_mean, pred_std * 2, target - pred_mean), axis=0)
    nrows = samples.shape[0]
    interp = None

    # Per-field color limits shared by the first two rows.
    vmin = [np.amin(samples[[0, 1], j]) for j in range(n_fields)]
    vmax = [np.amax(samples[[0, 1], j]) for j in range(n_fields)]
    # Manually set the vmin and vmax of the second field.
    if manual_scale and mean:
        vmin[1], vmax[1] = 1.0, 1.1

    # squeeze=False keeps `axes` 2-D even when there is a single field.
    fig, axes = plt.subplots(nrows, n_fields,
                             figsize=(3.75 * n_fields, 3 * nrows),
                             squeeze=False)
    for i in range(nrows):
        for j in range(n_fields):
            ax = axes[i, j]
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

            if i < 2 or same_scale:
                lo, hi = vmin[j], vmax[j]
            else:
                # Std / difference rows: let Matplotlib autoscale.
                lo, hi = None, None

            if plot_fn == 'contourf':
                # The limits are passed on the last two rows too, so that
                # `same_scale` is honored for contour plots as well.
                cax = ax.contourf(samples[i, j], 50, cmap=cmap,
                                  vmin=lo, vmax=hi)
                # Hide the seams between contour levels. Newer Matplotlib
                # makes the ContourSet itself a Collection and drops
                # `.collections`.
                if hasattr(cax, 'set_edgecolor'):
                    level_artists = [cax]
                else:
                    level_artists = cax.collections
                for c in level_artists:
                    c.set_edgecolor("face")
                    c.set_linewidth(0.000000000001)
            else:
                cax = ax.imshow(samples[i, j], cmap=cmap, origin='upper',
                                interpolation=interp, vmin=lo, vmax=hi)

            cbar = plt.colorbar(cax, ax=ax, fraction=0.046, pad=0.04,
                                format=ticker.ScalarFormatter(useMathText=True))
            cbar.formatter.set_powerlimits((-2, 2))
            cbar.ax.yaxis.set_offset_position('left')
            cbar.update_ticks()

    for ax, col in zip(axes[0], cols):
        ax.set_title(col, size='large')
    for ax, row in zip(axes[:, 0], rows):
        ax.set_ylabel(row, rotation=90, size='large')

    plt.tight_layout(pad=0.05, w_pad=0.05, h_pad=0.05)
    plt.savefig(save_dir + '/pred_{}_vs_MC.pdf'.format('mean' if mean else 'var'),
                dpi=300, bbox_inches='tight')
    plt.close(fig)
    print("Done plotting Pred_{}_vs_MC, num of training: {}"
          .format('mean' if mean else 'var', ntrain))
def plot_UP(save_dir, monte_carlo, surr_mean, is_mean,
            plot_fn='imshow', cmap='jet', same_scale=False):
    """Plot uncertainty propagation, for deep ensembles. Only mean estimate,
    no variance for each estimate.

    The figure has three rows (simulator, surrogate, simulator - surrogate)
    and one column per field. It is always saved as a 300-dpi PDF named
    ``pred_mean_vs_MC.pdf`` or ``pred_var_vs_MC.pdf``.

    Args:
        save_dir (str): directory to save the figure
        monte_carlo: array-like accepted by `to_numpy`, e.g. shape (3, 65, 65)
        surr_mean: array-like accepted by `to_numpy`, same shape as
            `monte_carlo`
        is_mean (bool): selects the file name and printed message only
            ('mean' if True, 'var' otherwise)
        plot_fn (str): choices=['contourf', 'imshow']
        cmap (str): colormap name
        same_scale (bool): if True, the difference row uses the same color
            limits as the first two rows of that field

    Raises:
        ValueError: if `plot_fn` is not 'contourf' or 'imshow'.
    """
    # Fail before a figure is created instead of hitting an unbound `cax`.
    if plot_fn not in ('contourf', 'imshow'):
        raise ValueError(
            f"plot_fn must be 'contourf' or 'imshow', got {plot_fn!r}")

    target, prediction = to_numpy(monte_carlo), to_numpy(surr_mean)

    rows = ['Simulator', 'Surrogate', r'Row1 $-$ Row2']
    cols = ['Pressure', 'Horizontal Flux', 'Vertical Flux']

    n_fields = target.shape[0]
    # Rows stacked along axis 0: [simulator fields, surrogate fields, diffs].
    samples = np.concatenate((target, prediction, target - prediction), axis=0)
    interp = None

    # Per-field color limits shared by the simulator and surrogate rows.
    vmin = [np.amin(samples[[i, i + n_fields]]) for i in range(n_fields)]
    vmax = [np.amax(samples[[i, i + n_fields]]) for i in range(n_fields)]

    # squeeze=False keeps `axes` 2-D even when there is a single field.
    fig, axes = plt.subplots(3, n_fields, figsize=(3.75 * n_fields, 9),
                             squeeze=False)
    for j, ax in enumerate(axes.flat):
        ax.set_aspect('equal')
        ax.set_xticks([])
        ax.set_yticks([])

        field = j % n_fields
        if j < 2 * n_fields or same_scale:
            lo, hi = vmin[field], vmax[field]
        else:
            # Difference row: let Matplotlib autoscale.
            lo, hi = None, None

        if plot_fn == 'contourf':
            # The limits are passed on the difference row too, so that
            # `same_scale` is honored for contour plots as well.
            cax = ax.contourf(samples[j], 50, cmap=cmap, vmin=lo, vmax=hi)
            # Hide the seams between contour levels. Newer Matplotlib makes
            # the ContourSet itself a Collection and drops `.collections`.
            if hasattr(cax, 'set_edgecolor'):
                level_artists = [cax]
            else:
                level_artists = cax.collections
            for c in level_artists:
                c.set_edgecolor("face")
                c.set_linewidth(0.000000000001)
        else:
            cax = ax.imshow(samples[j], cmap=cmap, origin='upper',
                            interpolation=interp, vmin=lo, vmax=hi)

        cbar = plt.colorbar(cax, ax=ax, fraction=0.046, pad=0.04,
                            format=ticker.ScalarFormatter(useMathText=True))
        cbar.formatter.set_powerlimits((-2, 2))
        cbar.ax.yaxis.set_offset_position('left')
        cbar.update_ticks()

    for ax, col in zip(axes[0], cols):
        ax.set_title(col, size='large')
    for ax, row in zip(axes[:, 0], rows):
        ax.set_ylabel(row, rotation=90, size='large')

    plt.tight_layout(pad=0.05, w_pad=0.05, h_pad=0.05)
    plt.savefig(save_dir + '/pred_{}_vs_MC.pdf'.format('mean' if is_mean else 'var'),
                dpi=300, bbox_inches='tight')
    plt.close(fig)
    print("Done plotting Pred_{}_vs_MC".format('mean' if is_mean else 'var'))
def save_samples(save_dir, images, epoch, index, name, nrow=4, heatmap=True, cmap='jet', title=False):
    """Save samples in grid as images or plots

    With `heatmap=True`, one PNG per channel is written to
    ``<save_dir>/epoch<epoch>_<name>_c<channel>_index<index>.png``, showing
    the first ``nrow * (B // nrow)`` samples on an image grid with a single
    colorbar. Otherwise the batch is written with
    `torchvision.utils.save_image` to
    ``<save_dir>/fake_samples_epoch_<epoch>.png``.

    Args:
        save_dir (str): directory to save the images
        images (Tensor): B x C x H x W
        epoch (int): which epoch (used in the file name and optional title)
        index (int): used in the heatmap file name
        name (str): used in the heatmap file name
        nrow (int): number of grid rows (heatmap) / `nrow` argument passed
            to `save_image`
        heatmap (bool): choose between the two output modes described above
        cmap (str): colormap name for the heatmaps
        title (bool): if True, add an ``Epoch <epoch>`` title to heatmaps
    """
    images = to_numpy(images)
    ncol = images.shape[0] // nrow

    if heatmap:
        for c in range(images.shape[1]):
            fig = plt.figure(1, (12, 12))
            grid = ImageGrid(fig, 111,
                             nrows_ncols=(nrow, ncol),
                             axes_pad=0.1,
                             share_all=False,
                             cbar_location="top",
                             cbar_mode="single",
                             cbar_size="3%",
                             cbar_pad=0.1
                             )
            for j, ax in enumerate(grid):
                im = ax.imshow(images[j][c], cmap=cmap)
                ax.set_axis_off()
                ax.set_aspect('equal')
            # NOTE(review): each panel is autoscaled on its own, so the
            # single colorbar reflects the last panel only.
            cbar = grid.cbar_axes[0].colorbar(im)
            cbar.ax.tick_params(labelsize=10)
            # Older axes_grid1 releases need toggle_label to show the
            # colorbar labels; newer releases dropped the method.
            if hasattr(cbar.ax, 'toggle_label'):
                cbar.ax.toggle_label(True)
            if title:
                plt.suptitle(f'Epoch {epoch}')
                plt.subplots_adjust(top=0.95)
            plt.savefig(save_dir + '/epoch{}_{}_c{}_index{}.png'.format(epoch, name, c, index),
                        bbox_inches='tight')
            plt.close(fig)
    else:
        # Imported here because the module never imports torchvision at the
        # top, and so the heatmap path keeps working without it installed.
        import torch
        import torchvision
        # save_image expects tensors; `images` is a NumPy array at this point.
        torchvision.utils.save_image(torch.as_tensor(images),
                                     save_dir + '/fake_samples_epoch_{}.png'.format(epoch),
                                     nrow=nrow,
                                     normalize=True)
def plot_row(arrs, save_dir, filename, same_range=False, plot_fn='imshow',
             cmap='viridis'):
    """Plot a sequence of 2-D arrays side by side, each with its own colorbar.

    The figure is saved to ``<save_dir>/<filename>.<ext>``.

    Args:
        arrs (sequence of 2D Tensor or Numpy): seq of arrs to be plotted
        save_dir (str): directory to save the figure
        filename (str): file name without extension
        same_range (bool): if True, subplots have the same range (colorbar)
        plot_fn (str): choices=['imshow', 'contourf']
        cmap (str): colormap name

    Raises:
        ValueError: if `plot_fn` is not 'imshow' or 'contourf'.
    """
    # Fail before a figure is created instead of hitting an unbound `cax`.
    if plot_fn not in ('imshow', 'contourf'):
        raise ValueError(
            f"plot_fn must be 'imshow' or 'contourf', got {plot_fn!r}")

    interpolation = None
    arrs = [to_numpy(arr) for arr in arrs]

    if same_range:
        vmax = max([np.amax(arr) for arr in arrs])
        vmin = min([np.amin(arr) for arr in arrs])
    else:
        vmax, vmin = None, None

    # squeeze=False gives a 2-D axes array even for a single panel.
    fig, axes = plt.subplots(1, len(arrs), figsize=(4.4 * len(arrs), 4),
                             squeeze=False)
    for ax, arr in zip(axes[0], arrs):
        if plot_fn == 'imshow':
            cax = ax.imshow(arr, cmap=cmap, interpolation=interpolation,
                            vmin=vmin, vmax=vmax)
        else:
            cax = ax.contourf(arr, 50, cmap=cmap, vmin=vmin, vmax=vmax)
            # Hide the seams between contour levels. Newer Matplotlib makes
            # the ContourSet itself a Collection and drops `.collections`.
            if hasattr(cax, 'set_edgecolor'):
                level_artists = [cax]
            else:
                level_artists = cax.collections
            for c in level_artists:
                c.set_edgecolor("face")
                c.set_linewidth(0.000000000001)
        ax.set_axis_off()

        cbar = plt.colorbar(cax, ax=ax, fraction=0.046, pad=0.04,
                            format=ticker.ScalarFormatter(useMathText=True))
        cbar.formatter.set_powerlimits((-2, 2))
        cbar.ax.yaxis.set_offset_position('left')
        cbar.update_ticks()

    plt.tight_layout(pad=0.05, w_pad=0.05, h_pad=0.05)
    plt.savefig(save_dir + f'/{filename}.{ext}', dpi=dpi, bbox_inches='tight')
    plt.close(fig)