You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
2.6 KiB
77 lines
2.6 KiB
import os
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from utils.data_standardizer import standardization
|
|
from utils.data_loader import data_loader_new
|
|
from options.test_options import TestOptions
|
|
from utils.visualization import surf_plot
|
|
|
|
def test(X, opt):
|
|
X_test=torch.from_numpy(X).type(torch.float32).to(opt.device)
|
|
|
|
model = torch.load(opt.pretrained_model_path)
|
|
|
|
N=(opt.ms_ratio+1)**2 * 2
|
|
pred=torch.zeros(X_test.shape[0], N)
|
|
for batch_idx, data_batch in enumerate(X_test):
|
|
pred_ShapeFunction=model(data_batch)
|
|
pred[batch_idx,:]=pred_ShapeFunction.reshape(N,8) @ data_batch[25:]
|
|
|
|
return pred
|
|
|
|
|
|
if __name__ == '__main__':
    # Load parameters
    opt = TestOptions().parse()
    # savefig() below fails if the output directory does not exist yet.
    os.makedirs(opt.results_dir, exist_ok=True)

    # Load datasets, mod2 as default
    m = opt.ms_ratio
    c_nelx = int(opt.nelx / m)   # coarse elements along x
    c_nely = int(opt.nely / m)   # coarse elements along y
    c_N = c_nelx * c_nely        # total number of coarse elements
    global_density, global_displace, coarse_density, coarse_displace, fine_displace = data_loader_new(opt)

    # One row per coarse element: m*m density values, then 8 coarse displacement values.
    # Alternative layout that splits the last axis of coarse_displace into two groups of 4:
    # X = np.hstack((coarse_density.reshape(c_N,m*m), coarse_displace.reshape(c_N,2,2,2)[:,:,:,0].reshape(c_N,4), coarse_displace.reshape(c_N,2,2,2)[:,:,:,1].reshape(c_N,4)))
    X = np.hstack((coarse_density.reshape(c_N, m * m), coarse_displace.reshape(c_N, 8)))
    # One row per coarse element: 2 * (m + 1) ** 2 fine displacement values.
    Y = fine_displace.reshape(c_N, (m + 1) ** 2 * 2)
    if opt.is_standard:
        X = standardization(X)
        Y = standardization(Y)

    # Predict. Tensor.to() returns a new tensor, so the result must be rebound;
    # detach() keeps any autograd graph out of the numpy/plotting code below.
    pred = test(X, opt).detach().to('cpu')

    # Set loss function
    loss_function = nn.MSELoss()
    Y_test = torch.from_numpy(Y).type(torch.float32).to('cpu')

    # Per-coarse-element MSE (row-wise mean of squared errors), computed in one pass.
    pred_loss = F.mse_loss(pred, Y_test, reduction='none').mean(dim=1).tolist()

    print('Total loss: ' + str(loss_function(pred, Y_test).item()))

    # Plot loss per coarse mesh id
    plt.plot(range(pred.shape[0]), pred_loss)
    plt.ylabel('Loss')
    plt.xlabel('Coarse mesh id')
    plt.title("Linear graph")
    plt.savefig(os.path.join(opt.results_dir, 'test_loss.png'))
    plt.show()

    # Same losses laid out on the (c_nely, c_nelx) coarse grid.
    loss_matrix = np.asarray(pred_loss).reshape(c_nely, c_nelx)
    plt.matshow(loss_matrix)
    plt.title("Show loss value in grid")
    plt.savefig(os.path.join(opt.results_dir, 'test_loss_in_grid.png'))
    plt.show()

    # Plot every mesh displacement for comparison.
    # NOTE(review): the hard-coded 6, 6 presumably equal m + 1 for m = 5 - confirm
    # against surf_plot's signature before re-enabling for other ms_ratio values.
    # for i in range(pred.shape[0]):
    #     surf_plot(pred[i].detach().numpy(),6,6,os.path.join(opt.results_dir, 'meshes', 'test_pred_mesh'+str(i)+'.png'),'u')
    #     surf_plot(Y_test[i].detach().numpy(),6,6,os.path.join(opt.results_dir, 'meshes','test_GT_mesh'+str(i)+'.png'),'u')
|