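# Training script: builds inputs X from coarse-element densities and corner
# displacements, targets Y from the fine-scale displacements inside each
# coarse element, and fits the network selected by opt.model
# (ANN_Model or ANN_b_spline_Model). Options are parsed via TrainOptions.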
import time
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.data_standardizer import standardization
from utils.data_loader import data_loader_new
from options.train_options import TrainOptions
from utils.b_spline.surface_inter import surface_inter
from models.ANN import ANN_Model
from models.ANN_b_spline import ANN_b_spline_Model

def train(X, Y, opt):
    # Optionally standardize inputs and targets
    if opt.is_standard:
        X = standardization(X)
        Y = standardization(Y)

    X_train = torch.from_numpy(X).type(torch.float32).to(opt.device)
    Y_train = torch.from_numpy(Y).type(torch.float32).to(opt.device)

    # Load net model
    torch.manual_seed(20)
    model_name = opt.model + '_Model'
    model = eval(model_name)()  # instantiate by class name, ANN_Model() as default
    model = model.to(opt.device)
    print(model)
    # Set loss function
    loss_function = nn.MSELoss()

    # Set Adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
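    # Note: X_train is iterated row by row (one coarse element per step), so
    # the effective batch size is 1; `losses` stores one value per epoch (the
    # last sample's loss), which matches the range(opt.epochs) plot below.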
    # Train
    start_time = time.time()
    losses = []
    loss = 0
    N = (opt.ms_ratio + 1) ** 2 * 2
    for epoch in range(opt.epochs):
        for batch_idx, data_batch in enumerate(X_train):
            model.train()  # enable batch normalization and dropout

            # Linear interpolation
            # pred_ShapeFunction = model(data_batch)
            # pred_U = pred_ShapeFunction.reshape(N, 8) @ data_batch[25:]
            # loss = loss_function(pred_U, Y_train[batch_idx, :])

            # B-spline: the network predicts control points, surface_inter
            # evaluates the interpolated displacement field
            control_points = model(data_batch)
            pred_U = surface_inter(control_points)
            loss = loss_function(pred_U, Y_train[batch_idx, :])

            # print(loss.item())
            loss.requires_grad_(True)  # gradients do not update

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        losses.append(loss.cpu().detach().numpy())
        if epoch % (opt.epochs / 20) == 1:
            print("Epoch number: {} and the loss : {}".format(epoch, loss.item()))
    print(time.time() - start_time)
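    # Note: torch.save(model, ...) below pickles the entire nn.Module, so the
    # model class (models/ANN.py or models/ANN_b_spline.py) must be importable
    # when the checkpoint is loaded again.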
    # Save trained model; the mkdir step is already done in options/base_options.py
    save_path = os.path.join(opt.expr_dir, opt.model + '_' + opt.mod + '_opt.pt')
    torch.save(model, save_path)

    return losses

if __name__ == '__main__':
    # Load parameters
    opt = TrainOptions().parse()
    save_path = os.path.join(opt.expr_dir, opt.model + '_' + opt.mod + '_opt.pt')

    # Load datasets, mod1 as default
    m = opt.ms_ratio
    c_nelx = int(opt.nelx / m)
    c_nely = int(opt.nely / m)
    c_N = c_nelx * c_nely
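    # Coarse mesh: the fine nelx x nely grid is partitioned into c_nelx x c_nely
    # coarse elements, each covering an m x m patch of fine elements (m = ms_ratio).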
    global_density, global_displace, coarse_density, coarse_displace, fine_displace = data_loader_new(opt)
    # X = np.hstack((coarse_density.reshape(c_N,m*m), coarse_displace.reshape(c_N,2,2,2)[:,:,:,0].reshape(c_N,4), coarse_displace.reshape(c_N,2,2,2)[:,:,:,1].reshape(c_N,4)))
    X = np.hstack((coarse_density.reshape(c_N, m*m), coarse_displace.reshape(c_N, 8)))
    Y = fine_displace.reshape(c_N, (m+1)**2 * 2)
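    # Shapes, following the reshapes above: X is (c_N, m*m + 8) -- per-element
    # densities plus 8 coarse displacement components (2x2 corner nodes x 2 dofs,
    # per the commented-out reshape); Y is (c_N, (m+1)**2 * 2) -- x/y displacements
    # at the (m+1) x (m+1) fine nodes of each coarse element.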
# Train
|
|
losses = train(X, Y, opt)
|
|
|
|
# plot loss
|
|
plt.plot(range(opt.epochs),losses)
|
|
plt.ylabel('Loss')
|
|
plt.xlabel('Epoch')
|
|
plt.savefig(os.path.join(opt.results_dir, 'train_losses.png'))
|
|
plt.show()
|
|
|