import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.data_standardizer import standardization
from utils.data_loader import data_loader_new
from options.test_options import TestOptions
from utils.visualization import surf_plot


def test(X, opt):
    """Predict fine-mesh displacements for every coarse element (row) of ``X``.

    Each row is passed through the pretrained model. The model output is
    reshaped to ``(N, 8)`` and multiplied by the row's trailing entries, i.e.
    everything after the first ``opt.ms_ratio ** 2`` density values.

    Args:
        X: 2-D numpy array with one row per coarse element. Each row holds
            ``ms_ratio ** 2`` density values followed by the coarse
            displacement values (8 of them, as built in ``main``).
        opt: parsed test options; reads ``device``, ``pretrained_model_path``
            and ``ms_ratio``.

    Returns:
        CPU float32 tensor of shape ``(X.shape[0], N)`` with
        ``N = 2 * (opt.ms_ratio + 1) ** 2``, not attached to any autograd graph.
    """
    m = opt.ms_ratio
    # Number of leading density columns per row; the remainder are the coarse
    # displacements the predicted shape functions are applied to.
    n_density = m * m
    N = (m + 1) ** 2 * 2

    X_test = torch.from_numpy(X).type(torch.float32).to(opt.device)
    # NOTE(review): torch.load unpickles a whole model object, so only load
    # trusted files. On torch >= 2.6 the default weights_only=True may reject
    # full-model checkpoints; pass weights_only=False there if needed.
    # map_location lets a GPU-saved checkpoint load on the requested device.
    model = torch.load(opt.pretrained_model_path, map_location=opt.device)
    model.eval()  # inference mode for dropout / batch-norm layers, if any

    pred = torch.zeros(X_test.shape[0], N)  # allocated on CPU
    with torch.no_grad():  # pure inference: no autograd graph needed
        for row_idx, row in enumerate(X_test):
            shape_function = model(row)
            pred[row_idx, :] = shape_function.reshape(N, 8) @ row[n_density:]
    return pred


def main():
    """Evaluate the pretrained model on the loaded dataset and plot its losses."""
    # Load parameters
    opt = TestOptions().parse()

    # Load datasets.
    # NOTE(review): an earlier comment said "mod2 as default"; confirm the
    # default ms_ratio in TestOptions.
    m = opt.ms_ratio
    c_nelx = int(opt.nelx / m)
    c_nely = int(opt.nely / m)
    c_N = c_nelx * c_nely
    global_density, global_displace, coarse_density, coarse_displace, fine_displace = data_loader_new(opt)

    # Alternative layout kept for reference:
    # X = np.hstack((coarse_density.reshape(c_N,m*m), coarse_displace.reshape(c_N,2,2,2)[:,:,:,0].reshape(c_N,4), coarse_displace.reshape(c_N,2,2,2)[:,:,:,1].reshape(c_N,4)))
    X = np.hstack((coarse_density.reshape(c_N, m * m), coarse_displace.reshape(c_N, 8)))
    Y = fine_displace.reshape(c_N, (m + 1) ** 2 * 2)
    if opt.is_standard:
        X = standardization(X)
        Y = standardization(Y)

    # Predict (test() already returns a CPU tensor; .cpu() keeps that explicit)
    pred = test(X, opt).detach().cpu()
    Y_test = torch.from_numpy(Y).type(torch.float32)

    # Set loss function
    loss_function = nn.MSELoss()

    # Per-element MSE: mean squared error over each row, one value per coarse element
    pred_loss = F.mse_loss(pred, Y_test, reduction='none').mean(dim=1).tolist()
    print('Total loss: ' + str(loss_function(pred, Y_test).item()))

    # Make sure the output directory exists before saving figures
    if opt.results_dir:
        os.makedirs(opt.results_dir, exist_ok=True)

    # Plot
    plt.plot(range(pred.shape[0]), pred_loss)
    plt.ylabel('Loss')
    plt.xlabel('Coarse mesh id')
    plt.title("Linear graph")
    plt.savefig(os.path.join(opt.results_dir, 'test_loss.png'))
    plt.show()

    # Same per-element losses laid out on the coarse grid (rows = c_nely, cols = c_nelx)
    loss_metrix = np.asarray(pred_loss).reshape(c_nely, c_nelx)
    plt.matshow(loss_metrix)
    plt.title("Show loss value in grid")
    plt.savefig(os.path.join(opt.results_dir, 'test_loss_in_grid.png'))
    plt.show()

    # Plot every mesh displacement for comparison
    # for i in range(pred.shape[0]):
    #     surf_plot(pred[i].detach().numpy(), m + 1, m + 1, os.path.join(opt.results_dir, 'meshes', 'test_pred_mesh' + str(i) + '.png'), 'u')
    #     surf_plot(Y_test[i].detach().numpy(), m + 1, m + 1, os.path.join(opt.results_dir, 'meshes', 'test_GT_mesh' + str(i) + '.png'), 'u')


if __name__ == '__main__':
    main()