You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
1.9 KiB
64 lines
1.9 KiB
1 year ago
|
import numpy as np
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
from utils.data_standardizer import standardization
|
||
|
from utils.data_loader import data_loader
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
def test(model_load_path, X, standard = False, device = 0):
|
||
|
model = torch.load(model_load_path)
|
||
|
|
||
|
if standard:
|
||
|
X = standardization(X)
|
||
|
device = f'cuda:{device}' if torch.cuda.is_available() else 'cpu'
|
||
|
X = torch.from_numpy(X).type(torch.float32).to(device)
|
||
|
|
||
|
with torch.no_grad():
|
||
|
return model(X)
|
||
|
|
||
|
|
||
|
if __name__=='__main__':
    # Load datasets
    # test data select:
    dataload_mod = 'mod1'  # opt: mod1 mod2 mod3
    # pretrained model select:
    pretrained_mod = 'mod1'  # opt: mod1 mod2 mod3

    # Mesh size encoded in the dataset file names, and the coarsening factor
    # used to lay the per-coarse-element losses back out as a grid
    # (60/5 x 180/5 = 12 x 36).
    nelx, nely, coarse_factor = 180, 60, 5

    dst_path = f'datasets/top88_{dataload_mod}_xPhys_{nelx}_{nely}.npy'
    U_path = f'datasets/top88_{dataload_mod}_u_{nelx}_{nely}.npy'
    global_density, global_displace, coarse_density, coarse_displace, fine_displace = data_loader(dst_path, U_path)
    X = np.hstack((coarse_density[:, :], coarse_displace[:, :, 0], coarse_displace[:, :, 1]))
    Y = fine_displace[:, :]

    # Set loss function
    loss_function = nn.MSELoss()

    # Predict
    pred = test('checkpoints/ANN_' + pretrained_mod + '_opt.pt', X)

    # Calculate loss. Put the targets on whatever device test() produced
    # pred on, instead of re-deriving (and possibly mismatching) it here.
    Y = torch.from_numpy(Y).type(torch.float32).to(pred.device)
    # Per-row MSE (equivalent to MSELoss(pred[i], Y[i]) for every i) in one
    # vectorized op with a single device->host copy, rather than one .item()
    # synchronization per row.
    pred_loss = (
        F.mse_loss(pred, Y, reduction='none').flatten(1).mean(dim=1).cpu().numpy()
    )

    print('Total loss: ' + str(loss_function(pred, Y).item()))

    # Plot
    plt.plot(range(pred.shape[0]), pred_loss)
    plt.ylabel('Loss')
    plt.xlabel('Coarse mesh id')
    plt.title("Linear graph")
    plt.show()

    # NOTE(review): this assumes coarse mesh ids run row-major over a
    # (nely/5, nelx/5) grid. top88-style meshes usually number elements
    # column-major; confirm the ordering produced by data_loader.
    loss_matrix = pred_loss.reshape(nely // coarse_factor, nelx // coarse_factor)
    plt.matshow(loss_matrix)
    plt.title("Show loss value in grid")
    plt.show()
|