import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import scipy.io
from scipy.stats import norm as scipy_norm
import seaborn as sns

from utils.misc import mkdir, to_numpy
from utils.plot import plot_prediction_bayes2, plot_MC2, save_samples
from utils.lhs import lhs

plt.switch_backend('agg')


class UQ_CondGlow(object):
    """Class for uncertainty quantification tasks, including:

    - prediction at one input realization
    - uncertainty propagation
    - distribution estimation at selected locations
    - reliability diagram (to assess the quality of the predictive uncertainty)

    Args:
        model: pre-trained probabilistic surrogate
        args: training arguments
        mc_loader (utils.data.DataLoader): dataloader for the Monte Carlo data
        test_loader (utils.data.DataLoader): dataloader for the test data
        y_test_variation: total variation of the test outputs, used as the
            denominator of the R^2 score in `test_metric`
        n_samples (int): number of output samples drawn per input
        temperature (float): sampling temperature of the flow model
    """

    def __init__(self, model, args, mc_loader, test_loader, y_test_variation,
                 n_samples=20, temperature=1.0):
        self.model = model
        self.mc_loader = mc_loader
        self.test_loader = test_loader
        self.y_test_variation = y_test_variation
        self.ntrain = args.ntrain
        self.plot_fn = args.plot_fn
        self.epochs = args.epochs
        self.device = args.device
        self.post_dir = args.post_dir
        self.imsize = args.imsize
        self.n_samples = n_samples
        self.temperature = temperature

        print(f'mc loader size: {len(self.mc_loader.dataset)}')
        print(f'test loader size: {len(self.test_loader.dataset)}')

    def plot_prediction_at_x(self, n_pred, plot_samples=False):
        r"""Plot predictions for `n_pred` randomly selected inputs from the test dataset:

        - target
        - predictive mean
        - standard deviation of the predictive output distribution
        - error between the target and the predictive mean

        Args:
            n_pred: number of test inputs to predict at
            plot_samples (bool): also plot 15 output samples from p(y|x) for each input
        """
        save_dir = self.post_dir + '/predict_at_x'
        mkdir(save_dir)
        print('Plotting predictions at x from test dataset..................')
        np.random.seed(1)
        idx = np.random.permutation(len(self.test_loader.dataset))[:n_pred]
        for i in idx:
            print('input index: {}'.format(i))
            input, target = self.test_loader.dataset[i]
            pred_mean, pred_var = self.model.predict(input.unsqueeze(0).to(self.device),
                n_samples=self.n_samples, temperature=self.temperature)

            plot_prediction_bayes2(save_dir, target, pred_mean.squeeze(0),
                pred_var.squeeze(0), self.epochs, i, plot_fn=self.plot_fn)
            if plot_samples:
                # samples from the conditional distribution p(y|x): S x C x H x W
                samples_pred = self.model.sample(input.unsqueeze(0).to(self.device),
                    n_samples=15)[:, 0]
                # prepend the target so it appears in the same image grid
                samples = torch.cat((target.unsqueeze(0), samples_pred.detach().cpu()), 0)
                save_samples(save_dir, samples, self.epochs, i,
                    'samples', nrow=4, heatmap=True, cmap='jet')

    def propagate_uncertainty(self, manual_scale=False, var_samples=10):
        """Propagate uncertainty through the pre-trained surrogate and compare the
        predicted output statistics against the Monte Carlo estimates."""
        print("Propagate uncertainty using pre-trained surrogate ...........")
        # compute MC sample mean and variance in mini-batches
        # (note: averaging per-batch statistics assumes equal batch sizes)
        sample_mean_x = torch.zeros_like(self.mc_loader.dataset[0][0])
        sample_var_x = torch.zeros_like(sample_mean_x)
        sample_mean_y = torch.zeros_like(self.mc_loader.dataset[0][1])
        sample_var_y = torch.zeros_like(sample_mean_y)

        for x_test_mc, y_test_mc in self.mc_loader:
            sample_mean_x += x_test_mc.mean(0)
            sample_mean_y += y_test_mc.mean(0)
        sample_mean_x /= len(self.mc_loader)
        sample_mean_y /= len(self.mc_loader)

        for x_test_mc, y_test_mc in self.mc_loader:
            sample_var_x += ((x_test_mc - sample_mean_x) ** 2).mean(0)
            sample_var_y += ((y_test_mc - sample_mean_y) ** 2).mean(0)
        sample_var_x /= len(self.mc_loader)
        sample_var_y /= len(self.mc_loader)

        # plot the input MC statistics (mean and variance)
        stats_x = torch.stack((sample_mean_x, sample_var_x)).cpu().numpy()
        fig, _ = plt.subplots(1, 2)
        for i, ax in enumerate(fig.axes):
            ax.set_aspect('equal')
            ax.set_axis_off()
            im = ax.contourf(stats_x[i].squeeze(0), 50, cmap='jet')
            # draw contour edges in the face color to avoid white seams in the PDF
            for c in im.collections:
                c.set_edgecolor("face")
                c.set_linewidth(0.000000000001)
            cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04,
                                format=ticker.ScalarFormatter(useMathText=True))
            cbar.formatter.set_powerlimits((0, 0))
            cbar.ax.yaxis.set_offset_position('left')
            cbar.update_ticks()
        plt.tight_layout(pad=0.5, w_pad=0.5, h_pad=0.5)
        out_stats_dir = self.post_dir + '/out_stats'
        mkdir(out_stats_dir)
        plt.savefig(out_stats_dir + '/input_MC.pdf', dpi=300, bbox_inches='tight')
        plt.close(fig)
        print("Done plotting input MC, num of training: {}".format(self.ntrain))

        # MC surrogate predictions of the output statistics:
        # y_pred_EE / y_pred_VE: estimate of the output mean and its variability
        #   (over `var_samples` repeated estimates),
        # y_pred_EV / y_pred_VV: the same for the output variance
        y_pred_EE, y_pred_VE, y_pred_EV, y_pred_VV = self.model.propagate(
            self.mc_loader, n_samples=self.n_samples,
            temperature=self.temperature, var_samples=var_samples)
        print('Done MC predictions')

        # plot the predictive mean against the MC sample mean
        plot_MC2(out_stats_dir, sample_mean_y, y_pred_EE, y_pred_VE, True,
                 self.ntrain, manual_scale=manual_scale)
        # plot the predictive variance against the MC sample variance
        plot_MC2(out_stats_dir, sample_var_y, y_pred_EV, y_pred_VV, False,
                 self.ntrain)

        # save for MATLAB plotting
        scipy.io.savemat(out_stats_dir + '/out_stats.mat',
                         {'sample_mean': sample_mean_y.cpu().numpy(),
                          'sample_var': sample_var_y.cpu().numpy(),
                          'y_pred_EE': y_pred_EE.cpu().numpy(),
                          'y_pred_VE': y_pred_VE.cpu().numpy(),
                          'y_pred_EV': y_pred_EV.cpu().numpy(),
                          'y_pred_VV': y_pred_VV.cpu().numpy()})
        print('saved output stats to .mat file')

    def plot_dist(self, num_loc):
        """Plot the distribution estimate at `num_loc` locations in the domain,
        chosen by Latin hypercube sampling.

        Args:
            num_loc (int): number of locations at which the distribution is estimated
        """
        print('Plotting distribution estimate.................................')

        assert num_loc > 0, 'num_loc must be greater than zero'
        # locations (ndarray): num_loc points in [0, 1] x [0, 1], shape N x 2
        locations = lhs(2, num_loc, criterion='c')
        print('Locations selected by LHS: \n{}'.format(locations))
        # map unit-square coordinates to pixel indices
        idx = (locations * self.imsize).astype(int)

        print('Propagating...')
        pred, target = [], []
        for x_mc, t_mc in self.mc_loader:
            x_mc = x_mc.to(self.device)
            # S x B x C x H x W
            y_mc = self.model.sample(x_mc, n_samples=self.n_samples,
                                     temperature=self.temperature)
            # S x B x C x n_points
            pred.append(y_mc[:, :, :, idx[:, 0], idx[:, 1]])
            # B x C x n_points
            target.append(t_mc[:, :, idx[:, 0], idx[:, 1]])
        # S x M x C x n_points --> M x C x n_points (average over the S samples)
        pred = torch.cat(pred, dim=1).mean(0).cpu().numpy()

        print('pred size: {}'.format(pred.shape))
        # M x C x n_points
        target = torch.cat(target, dim=0).cpu().numpy()
        print('target shape: {}'.format(target.shape))
        dist_dir = self.post_dir + '/dist_estimate'
        mkdir(dist_dir)
        for loc in range(locations.shape[0]):
            print('location index: {}'.format(loc))
            fig, _ = plt.subplots(1, 3, figsize=(12, 4))
            for c, ax in enumerate(fig.axes):
                sns.kdeplot(target[:, c, loc], color='b', ls='--', label='Monte Carlo', ax=ax)
                sns.kdeplot(pred[:, c, loc], color='r', label='Surrogate', ax=ax)
                ax.legend()
            plt.savefig(dist_dir + '/loc_({:.5f}, {:.5f}).pdf'
                        .format(locations[loc][0], locations[loc][1]), dpi=300)
            plt.close(fig)

    def plot_reliability_diagram(self, label='Conditional Glow', save_time=True):
        """Plot reliability diagrams: for a range of coverage probabilities p, compare
        p against the observed frequency with which the target falls inside the central
        p-interval of the predictive Gaussian N(pred_mean, pred_var)."""
        print("Plotting reliability diagram..................................")
        # coverage probabilities p
        p_list = np.linspace(0.01, 0.99, 10)
        freq = []
        n_channels = self.mc_loader.dataset[0][1].shape[0]

        for p in p_list:
            count = 0
            numels = 0
            for batch_idx, (input, target) in enumerate(self.mc_loader):
                # only evaluate the first few batches (about 2000 MC samples) to save time
                if save_time and batch_idx > 4:
                    continue
                pred_mean, pred_var = self.model.predict(input.to(self.device),
                    n_samples=self.n_samples, temperature=self.temperature)

                # central interval containing probability p under N(pred_mean, pred_var)
                interval = scipy_norm.interval(p, loc=pred_mean.cpu().numpy(),
                                               scale=pred_var.sqrt().cpu().numpy())

                # per-channel count of targets falling inside the interval
                count += ((target.numpy() >= interval[0])
                          & (target.numpy() <= interval[1])).sum(axis=(0, 2, 3))
                numels += target.numel() / n_channels
            print('p: {}, {} / {} = {}'.format(p, count, numels,
                                               np.true_divide(count, numels)))
            freq.append(np.true_divide(count, numels))
        reliability_dir = self.post_dir + '/uncertainty_quality'
        mkdir(reliability_dir)

        freq = np.stack(freq, 0)
        for i in range(freq.shape[-1]):
            plt.figure()
            plt.plot(p_list, freq[:, i], 'r', label=label)
            plt.xlabel('Probability')
            plt.ylabel('Frequency')
            x = np.linspace(0, 1, 100)
            plt.plot(x, x, 'k--', label='Ideal')
            plt.legend(loc='upper left')
            plt.savefig(reliability_dir + f"/reliability_diagram_{i}.pdf", dpi=300)
            plt.close()

        # save (p, per-channel observed frequency) pairs for later plotting
        reliability = np.zeros((p_list.shape[0], 1 + n_channels))
        reliability[:, 0] = p_list
        reliability[:, 1:] = freq
        np.savetxt(reliability_dir + "/reliability_diagram.txt", reliability)

    def test_metric(self, handle_nan=True):
        """Compute test metrics for the predictive mean: per-channel relative L2 error
        and R^2 score.

        Args:
            handle_nan (bool): drop test samples whose prediction contains NaN or Inf
        """
        relative_l2, err2 = [], []
        num_nan_inf = 0
        for batch_idx, (input, target) in enumerate(self.test_loader):
            input, target = input.to(self.device), target.to(self.device)
            pred_mean, pred_var = self.model.predict(input, n_samples=self.n_samples,
                temperature=self.temperature)
            # drop predictions containing NaN or Inf
            if handle_nan:
                exception = (torch.isnan(pred_mean) | torch.isinf(pred_mean)).sum((1, 2, 3)).gt(0)
                normal = ~exception
                normal_idx = torch.arange(len(normal), device=self.device).masked_select(normal)
                pred_mean = pred_mean.index_select(0, normal_idx)
                target = target.index_select(0, normal_idx)
                num_nan_inf += exception.sum().item()

            err2_sum = torch.sum((pred_mean - target) ** 2, [-1, -2])
            relative_l2.append(torch.sqrt(err2_sum / (target ** 2).sum([-1, -2])))
            err2.append(err2_sum)

        # per-channel relative L2 error, averaged over the test set
        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        # R^2 score: 1 - (sum of squared errors) / (total variation of the test outputs)
        r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / self.y_test_variation
        print(relative_l2)
        print(r2_score)
        np.savetxt(self.post_dir + '/nrmse_test.txt', relative_l2)
        np.savetxt(self.post_dir + '/r2_test.txt', r2_score)
        if handle_nan:
            abnormal_rate = num_nan_inf / len(self.test_loader.dataset)
            print(f'num_nan_inf: {num_nan_inf}')
            print(f'abnormal rate: {abnormal_rate:.6f}')
            np.savetxt(self.post_dir + '/log_stats.txt',
                       [num_nan_inf, len(self.test_loader.dataset), abnormal_rate])
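

# Typical driver (sketch only; the `model`, `args`, data loaders and
# `y_test_variation` are assumed to come from the training script and are not
# defined in this module):
#     uq = UQ_CondGlow(model, args, mc_loader, test_loader, y_test_variation)
#     uq.plot_prediction_at_x(n_pred=4, plot_samples=True)
#     uq.propagate_uncertainty()
#     uq.plot_dist(num_loc=4)
#     uq.plot_reliability_diagram()
#     uq.test_metric()
#
# The block below is an illustrative, self-contained sketch (not part of the
# original class): it reproduces the coverage check behind
# `plot_reliability_diagram` on synthetic data, where the "predictive" Gaussian
# has the exact mean and standard deviation of the data generator, so the
# observed frequency should closely track the nominal probability.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    true_mean, true_std = 0.0, 1.0
    draws = rng.normal(true_mean, true_std, size=10000)
    for p in np.linspace(0.1, 0.9, 5):
        # central interval containing probability p under N(true_mean, true_std^2)
        lo, hi = scipy_norm.interval(p, loc=true_mean, scale=true_std)
        observed = np.mean((draws >= lo) & (draws <= hi))
        print(f'nominal coverage {p:.2f} -> observed frequency {observed:.3f}')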