You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
1.8 KiB
60 lines
1.8 KiB
import os
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from utils.data_standardizer import standardization
|
|
from utils.data_loader import data_loader
|
|
from options.test_options import TestOptions
|
|
|
|
|
|
def test(X, opt):
|
|
# OLD: model_load_path, X, standard = False, device = 0
|
|
if opt.is_standard:
|
|
X = standardization(X)
|
|
|
|
X_test=torch.from_numpy(X).type(torch.float32).to(opt.device)
|
|
|
|
model = torch.load(opt.pretrained_model_path)
|
|
return model(X_test)
|
|
|
|
|
|
if __name__=='__main__':
|
|
# Load parmetaers
|
|
opt = TestOptions().parse()
|
|
|
|
# Load datasets, mod2 as default
|
|
global_density, global_displace, coarse_density, coarse_displace, fine_displace = data_loader(opt)
|
|
X = np.hstack((coarse_density[:,:] , coarse_displace[:,:,0] , coarse_displace[:,:,1]))
|
|
Y = fine_displace[:,:]
|
|
|
|
# Predict
|
|
pred = test(X, opt)
|
|
|
|
# Set loss function
|
|
loss_function = nn.MSELoss()
|
|
# Calculate loss
|
|
pred_loss=[]
|
|
Y_test = torch.from_numpy(Y).type(torch.float32).to(opt.device)
|
|
for i in range(pred.shape[0]):
|
|
pred_loss.append(loss_function(pred[i,:],Y_test[i,:]).item())
|
|
|
|
print('Total loss: '+ str(loss_function(pred,Y_test).item()))
|
|
|
|
# Plot
|
|
plt.plot(range(pred.shape[0]),pred_loss)
|
|
plt.ylabel('Loss')
|
|
plt.xlabel('Coarse mesh id')
|
|
plt.title("Linear graph")
|
|
plt.savefig(os.path.join(opt.results_dir, 'test_loss.png'))
|
|
plt.show()
|
|
|
|
loss_metrix = np.asarray(pred_loss)
|
|
loss_metrix = loss_metrix.reshape(int(opt.nely/opt.ms_ratio), int(opt.nelx/opt.ms_ratio))
|
|
plt.matshow(loss_metrix)
|
|
plt.title("Show loss value in grid")
|
|
plt.savefig(os.path.join(opt.results_dir, 'test_loss_in_grid.png'))
|
|
plt.show()
|
|
|