# NOTE: removed web-page scraping artifacts that preceded this file
# (a Gitea "topics" UI notice and "279 lines / 12 KiB / 3 years ago"
# page metadata) — they were never part of the source code.
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import scipy.io
from scipy.stats import norm as scipy_norm
import seaborn as sns
from utils.misc import mkdir, to_numpy
from utils.plot import plot_prediction_bayes2, plot_MC2, save_samples
from utils.lhs import lhs
plt.switch_backend('agg')
class UQ_CondGlow(object):
    """Uncertainty quantification with a pre-trained conditional Glow surrogate.

    Tasks covered:
    - prediction at one input realization (``plot_prediction_at_x``)
    - uncertainty propagation (``propagate_uncertainty``)
    - distribution estimate at certain locations (``plot_dist``)
    - reliability diagram to assess uncertainty quality
      (``plot_reliability_diagram``)
    - test metrics: relative L2 error and R^2 (``test_metric``)

    Args:
        model: Pre-trained probabilistic surrogate; must provide
            ``predict``, ``sample`` and ``propagate`` methods.
        args: training arguments (reads ``ntrain``, ``plot_fn``, ``epochs``,
            ``device``, ``post_dir``, ``imsize``)
        mc_loader (utils.data.DataLoader): Dataloader for Monte Carlo data
        test_loader: Dataloader for the test dataset
        y_test_variation: total variation of the test targets, used as the
            denominator of the R^2 score in ``test_metric``
        n_samples (int): number of surrogate samples per input, forwarded to
            ``model.predict`` / ``model.sample``
        temperature (float): sampling temperature forwarded to the model
    """
    def __init__(self, model, args, mc_loader, test_loader, y_test_variation,
                 n_samples=20, temperature=1.0):
        self.model = model
        self.mc_loader = mc_loader
        self.test_loader = test_loader
        self.y_test_variation = y_test_variation
        # training/plotting configuration copied out of args
        self.ntrain = args.ntrain
        self.plot_fn = args.plot_fn
        self.epochs = args.epochs
        self.device = args.device
        self.post_dir = args.post_dir
        self.imsize = args.imsize
        # sampling configuration for the probabilistic surrogate
        self.n_samples = n_samples
        self.temperature = temperature
        print(f'mc loader size: {len(self.mc_loader.dataset)}')
        print(f'test loader size: {len(self.test_loader.dataset)}')
def plot_prediction_at_x(self, n_pred, plot_samples=False):
    r"""Plot `n_pred` predictions for randomly selected inputs of the test set.

    For each selected input the following are plotted:
    - target
    - predictive mean
    - standard deviation of the predictive output distribution
    - error of the above two

    Args:
        n_pred: number of candidate predictions
        plot_samples (bool): additionally plot 15 output samples drawn
            from p(y|x) for each selected x
    """
    save_dir = self.post_dir + '/predict_at_x'
    mkdir(save_dir)
    print('Plotting predictions at x from test dataset..................')
    # fixed seed so the same test inputs are selected across runs
    np.random.seed(1)
    selected = np.random.permutation(len(self.test_loader.dataset))[:n_pred]
    for index in selected:
        print('input index: {}'.format(index))
        x, y = self.test_loader.dataset[index]
        x_batch = x.unsqueeze(0).to(self.device)
        mean_pred, var_pred = self.model.predict(
            x_batch, n_samples=self.n_samples, temperature=self.temperature)
        plot_prediction_bayes2(save_dir, y, mean_pred.squeeze(0),
                               var_pred.squeeze(0), self.epochs, index,
                               plot_fn=self.plot_fn)
        if plot_samples:
            # draw 15 samples from p(y|x) and prepend the target for reference
            drawn = self.model.sample(x_batch, n_samples=15)[:, 0]
            grid = torch.cat((y.unsqueeze(0), drawn.detach().cpu()), 0)
            save_samples(save_dir, grid, self.epochs, index,
                         'samples', nrow=4, heatmap=True, cmap='jet')
def propagate_uncertainty(self, manual_scale=False, var_samples=10):
    """Propagate input uncertainty through the surrogate via Monte Carlo.

    Computes the MC sample mean/variance of inputs and outputs in
    mini-batches, plots the input statistics, then compares the surrogate's
    predictive output statistics against the output MC statistics and saves
    all output statistics for MATLAB post-processing.

    Args:
        manual_scale (bool): forwarded to ``plot_MC2`` for the mean plot
        var_samples (int): forwarded to ``model.propagate``
    """
    print("Propagate Uncertainty using pre-trained surrogate ...........")
    # compute MC sample mean and variance in mini-batch
    # NOTE(review): averaging per-batch means is exact only if every batch
    # has the same size — confirm the MC loader has no smaller last batch.
    sample_mean_x = torch.zeros_like(self.mc_loader.dataset[0][0])
    sample_var_x = torch.zeros_like(sample_mean_x)
    sample_mean_y = torch.zeros_like(self.mc_loader.dataset[0][1])
    sample_var_y = torch.zeros_like(sample_mean_y)
    # first pass: accumulate per-batch means
    for x_test_mc, y_test_mc in self.mc_loader:
        sample_mean_x += x_test_mc.mean(0)
        sample_mean_y += y_test_mc.mean(0)
    sample_mean_x /= len(self.mc_loader)
    sample_mean_y /= len(self.mc_loader)
    # second pass: accumulate per-batch (biased) variances around the mean
    for x_test_mc, y_test_mc in self.mc_loader:
        sample_var_x += ((x_test_mc - sample_mean_x) ** 2).mean(0)
        sample_var_y += ((y_test_mc - sample_mean_y) ** 2).mean(0)
    sample_var_x /= len(self.mc_loader)
    sample_var_y /= len(self.mc_loader)
    # plot input MC statistics: mean and variance side by side
    stats_x = torch.stack((sample_mean_x, sample_var_x)).cpu().numpy()
    fig, _ = plt.subplots(1, 2)
    for i, ax in enumerate(fig.axes):
        ax.set_aspect('equal')
        ax.set_axis_off()
        im = ax.contourf(stats_x[i].squeeze(0), 50, cmap='jet')
        # paint level edges to avoid white seams in vector (PDF) output
        for c in im.collections:
            c.set_edgecolor("face")
            c.set_linewidth(0.000000000001)
        cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04,
                            format=ticker.ScalarFormatter(useMathText=True))
        cbar.formatter.set_powerlimits((0, 0))
        cbar.ax.yaxis.set_offset_position('left')
        cbar.update_ticks()
    plt.tight_layout(pad=0.5, w_pad=0.5, h_pad=0.5)
    out_stats_dir = self.post_dir + '/out_stats'
    mkdir(out_stats_dir)
    # BUG FIX: was `di=300`, a typo for the `dpi` keyword of savefig
    plt.savefig(out_stats_dir + '/input_MC.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)
    print("Done plotting input MC, num of training: {}".format(self.ntrain))
    # MC surrogate predictions: E/V over inputs of predictive mean and var
    y_pred_EE, y_pred_VE, y_pred_EV, y_pred_VV = self.model.propagate(
        self.mc_loader, n_samples=self.n_samples,
        temperature=self.temperature, var_samples=var_samples)
    print('Done MC predictions')
    # plot the predictive mean against the MC output mean
    plot_MC2(out_stats_dir, sample_mean_y, y_pred_EE, y_pred_VE, True,
             self.ntrain, manual_scale=manual_scale)
    # plot the predictive variance against the MC output variance
    plot_MC2(out_stats_dir, sample_var_y, y_pred_EV, y_pred_VV, False,
             self.ntrain)
    # save for MATLAB plotting
    scipy.io.savemat(out_stats_dir + '/out_stats.mat',
                     {'sample_mean': sample_mean_y.cpu().numpy(),
                      'sample_var': sample_var_y.cpu().numpy(),
                      'y_pred_EE': y_pred_EE.cpu().numpy(),
                      'y_pred_VE': y_pred_VE.cpu().numpy(),
                      'y_pred_EV': y_pred_EV.cpu().numpy(),
                      'y_pred_VV': y_pred_VV.cpu().numpy()})
    print('saved output stats to .mat file')
def plot_dist(self, num_loc):
    """Plot distribution estimates at `num_loc` locations in the domain.

    Locations are chosen by Latin Hypercube Sampling on [0, 1] x [0, 1];
    at each location a KDE of the MC targets is compared against a KDE of
    the surrogate's (sample-averaged) predictions, one subplot per channel.

    Args:
        num_loc (int): number of locations where distribution is estimated
    """
    print('Plotting distribution estimate.................................')
    assert num_loc > 0, 'num_loc must be greater than zero'
    locations = lhs(2, num_loc, criterion='c')
    print('Locations selected by LHS: \n{}'.format(locations))
    # locations (ndarray): [0, 1] x [0, 1], shape N x 2 -> pixel indices
    idx = (locations * self.imsize).astype(int)
    rows, cols = idx[:, 0], idx[:, 1]
    print('Propagating...')
    pred_chunks, target_chunks = [], []
    for x_mc, t_mc in self.mc_loader:
        x_mc = x_mc.to(self.device)
        # sampled outputs: S x B x C x H x W
        samples = self.model.sample(x_mc, n_samples=self.n_samples,
                                    temperature=self.temperature)
        # values at the selected pixels: S x B x C x n_points
        pred_chunks.append(samples[:, :, :, rows, cols])
        # target values at the selected pixels: B x C x n_points
        target_chunks.append(t_mc[:, :, rows, cols])
    # S x M x C x n_points --> M x C x n_points (average over samples)
    pred = torch.cat(pred_chunks, dim=1).mean(0).cpu().numpy()
    print('pred size: {}'.format(pred.shape))
    # M x C x n_points
    target = torch.cat(target_chunks, dim=0).cpu().numpy()
    print('target shape: {}'.format(target.shape))
    dist_dir = self.post_dir + '/dist_estimate'
    mkdir(dist_dir)
    for point in range(locations.shape[0]):
        print(point)
        fig, _ = plt.subplots(1, 3, figsize=(12, 4))
        for c, ax in enumerate(fig.axes):
            sns.kdeplot(target[:, c, point], color='b', ls='--',
                        label='Monte Carlo', ax=ax)
            sns.kdeplot(pred[:, c, point], color='r',
                        label='Surrogate', ax=ax)
            ax.legend()
        plt.savefig(dist_dir + '/loc_({:.5f}, {:.5f}).pdf'
                    .format(locations[point][0], locations[point][1]), dpi=300)
        plt.close(fig)
def plot_reliability_diagram(self, label='Conditional Glow', save_time=True):
    """Plot reliability diagrams (interval probability vs observed frequency).

    For each probability ``p``, counts per output channel how often the MC
    targets fall inside the central Gaussian interval built from the
    surrogate's predictive mean/variance, then plots frequency against p
    alongside the ideal diagonal and saves the raw numbers to text.

    Args:
        label (str): curve label used in the plot legend
        save_time (bool): if True, only the first 5 MC batches are evaluated
    """
    print("Plotting reliability diagram..................................")
    # percentage levels p
    p_list = np.linspace(0.01, 0.99, 10)
    freq = []
    n_channels = self.mc_loader.dataset[0][1].shape[0]
    # NOTE(review): the same batches are re-predicted for every p; caching
    # the predictions once outside this loop would save |p_list|-1 model
    # passes (left as-is to keep sampling behavior identical).
    for p in p_list:
        count = 0
        numels = 0
        for batch_idx, (input, target) in enumerate(self.mc_loader):
            # only evaluate 2000 of the MC data to save time
            # BUG FIX: was `continue`, which kept loading (and discarding)
            # every remaining batch; `break` is equivalent and cheaper
            # since batch_idx only increases.
            if save_time and batch_idx > 4:
                break
            pred_mean, pred_var = self.model.predict(
                input.to(self.device),
                n_samples=self.n_samples, temperature=self.temperature)
            interval = scipy_norm.interval(p, loc=pred_mean.cpu().numpy(),
                                           scale=pred_var.sqrt().cpu().numpy())
            # per-channel count of targets inside the central interval
            count += ((target.numpy() >= interval[0])
                      & (target.numpy() <= interval[1])).sum(axis=(0, 2, 3))
            numels += target.numel() / n_channels
        print('p: {}, {} / {} = {}'.format(p, count, numels,
                                           np.true_divide(count, numels)))
        freq.append(np.true_divide(count, numels))
    reliability_dir = self.post_dir + '/uncertainty_quality'
    mkdir(reliability_dir)
    freq = np.stack(freq, 0)
    # one diagram per output channel
    for i in range(freq.shape[-1]):
        plt.figure()
        plt.plot(p_list, freq[:, i], 'r', label=label)
        plt.xlabel('Probability')
        plt.ylabel('Frequency')
        x = np.linspace(0, 1, 100)
        plt.plot(x, x, 'k--', label='Ideal')
        plt.legend(loc='upper left')
        plt.savefig(reliability_dir + f"/reliability_diagram_{i}.pdf", dpi=300)
        plt.close()
    # also save [p, freq_per_channel...] rows as plain text
    reliability = np.zeros((p_list.shape[0], 1 + n_channels))
    reliability[:, 0] = p_list
    reliability[:, 1:] = freq
    np.savetxt(reliability_dir + "/reliability_diagram.txt", reliability)
    plt.close()
def test_metric(self, handle_nan=True):
    """Compute test-set relative L2 error and R^2 of the predictive mean.

    Saves per-channel `nrmse_test.txt` and `r2_test.txt` under `post_dir`.
    When `handle_nan` is True, batch items whose prediction contains
    NaN/Inf are dropped from the metrics and their count is logged to
    `log_stats.txt`.

    Args:
        handle_nan (bool): drop predictions containing NaN/Inf values
    """
    relative_l2, err2 = [], []
    num_nan_inf = 0
    for batch_idx, (input, target) in enumerate(self.test_loader):
        input, target = input.to(self.device), target.to(self.device)
        pred_mean, pred_var = self.model.predict(
            input, n_samples=self.n_samples, temperature=self.temperature)
        # handling nan, inf
        if handle_nan:
            # BUG FIX: `+` on bool tensors and `1 - bool_tensor` are
            # unsupported in modern PyTorch; use logical `|` and `~`.
            exception = torch.isnan(pred_mean) | torch.isinf(pred_mean)
            # one flag per batch item: any NaN/Inf anywhere in C x H x W
            exception = exception.sum((1, 2, 3)).gt(0)
            normal = ~exception
            normal_idx = torch.arange(len(normal)).to(
                self.device).masked_select(normal).to(torch.long)
            pred_mean = pred_mean.index_select(0, normal_idx)
            target = target.index_select(0, normal_idx)
            # accumulate as a plain Python int for printing/saving
            num_nan_inf += exception.sum().item()
        # squared error summed over spatial dims: shape B x C
        err2_sum = torch.sum((pred_mean - target) ** 2, [-1, -2])
        relative_l2.append(torch.sqrt(err2_sum / (target ** 2).sum([-1, -2])))
        err2.append(err2_sum)
    # per-channel mean relative L2 over the whole test set
    relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
    # R^2 against the provided total variation of the test targets
    r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / self.y_test_variation
    print(relative_l2)
    print(r2_score)
    np.savetxt(self.post_dir + '/nrmse_test.txt', relative_l2)
    np.savetxt(self.post_dir + '/r2_test.txt', r2_score)
    if handle_nan:
        abnormal_rate = num_nan_inf / len(self.test_loader.dataset)
        print(f'num_nan_inf: {num_nan_inf}')
        print(f'abnormal rate: {abnormal_rate:.6f}')
        np.savetxt(self.post_dir + '/log_stats.txt',
                   [num_nan_inf, len(self.test_loader.dataset), abnormal_rate])