该项目是论文《Problem-independent machine learning (PIML)-based topology optimization—A universal approach》的 Python 复现。
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

77 lines
2.6 KiB

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.data_standardizer import standardization
from utils.data_loader import data_loader_new
from options.test_options import TestOptions
from utils.visualization import surf_plot
def test(X, opt):
X_test=torch.from_numpy(X).type(torch.float32).to(opt.device)
model = torch.load(opt.pretrained_model_path)
N=(opt.ms_ratio+1)**2 * 2
pred=torch.zeros(X_test.shape[0], N)
for batch_idx, data_batch in enumerate(X_test):
pred_ShapeFunction=model(data_batch)
pred[batch_idx,:]=pred_ShapeFunction.reshape(N,8) @ data_batch[25:]
return pred
if __name__ == '__main__':
    # Load parameters
    opt = TestOptions().parse()

    # Load datasets (mod2 as default)
    m = opt.ms_ratio
    c_nelx = int(opt.nelx / m)   # number of coarse elements along x
    c_nely = int(opt.nely / m)   # number of coarse elements along y
    c_N = c_nelx * c_nely
    global_density, global_displace, coarse_density, coarse_displace, fine_displace = data_loader_new(opt)

    # One row per coarse element: m*m density values followed by 8 coarse displacement values
    X = np.hstack((coarse_density.reshape(c_N, m * m), coarse_displace.reshape(c_N, 8)))
    # Target: (m+1)**2 * 2 fine displacement values per coarse element
    # (presumably 2 DOFs per node of the (m+1)x(m+1) fine grid - confirm)
    Y = fine_displace.reshape(c_N, (m + 1) ** 2 * 2)
    if opt.is_standard:
        X = standardization(X)
        Y = standardization(Y)

    # Predict. Tensor.to returns a new tensor, so rebind it.
    # detach() guards against any autograd history in the prediction.
    pred = test(X, opt)
    pred = pred.detach().to('cpu')

    # Set loss function and calculate loss
    loss_function = nn.MSELoss()
    Y_test = torch.from_numpy(Y).type(torch.float32).to('cpu')
    # Per-coarse-element MSE (mean over each row) in one vectorized call
    pred_loss = F.mse_loss(pred, Y_test, reduction='none').mean(dim=1).tolist()
    print('Total loss: ' + str(loss_function(pred, Y_test).item()))

    # savefig fails if the output directory does not exist yet
    os.makedirs(opt.results_dir, exist_ok=True)

    # Plot the loss of each coarse element
    plt.plot(range(pred.shape[0]), pred_loss)
    plt.ylabel('Loss')
    plt.xlabel('Coarse mesh id')
    plt.title("Linear graph")
    plt.savefig(os.path.join(opt.results_dir, 'test_loss.png'))
    plt.show()

    # Same losses laid out on the coarse grid (rows = y, columns = x)
    loss_metrix = np.asarray(pred_loss).reshape(c_nely, c_nelx)
    plt.matshow(loss_metrix)
    plt.title("Show loss value in grid")
    plt.savefig(os.path.join(opt.results_dir, 'test_loss_in_grid.png'))
    plt.show()

    # Optional: plot every mesh displacement for comparison.
    # Requires results_dir/meshes to exist. The grid size argument was 6, 6
    # for ms_ratio=5, i.e. presumably m + 1 - confirm against surf_plot.
    # for i in range(pred.shape[0]):
    #     surf_plot(pred[i].numpy(), m + 1, m + 1, os.path.join(opt.results_dir, 'meshes', 'test_pred_mesh' + str(i) + '.png'), 'u')
    #     surf_plot(Y_test[i].numpy(), m + 1, m + 1, os.path.join(opt.results_dir, 'meshes', 'test_GT_mesh' + str(i) + '.png'), 'u')