You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
2.6 KiB
85 lines
2.6 KiB
import numpy as np
|
|
import time
|
|
import matplotlib.pyplot as plt
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
from utils.data_standardizer import standardization
|
|
from utils.data_loader import data_loader
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from models.ANN import ANN_Model
|
|
|
|
|
|
def train(X, Y, epochs=10000, mod='mod1', standard = False, device = 0):
    """Fit an ``ANN_Model`` on (X, Y) with Adam + MSE loss and save it to disk.

    The data is split 80/20 with a fixed random state; only the training
    part is used here. The whole training split is fed to the model as a
    single batch on every epoch.

    Args:
        X: Input samples as a NumPy array (converted with
            ``torch.from_numpy``).
        Y: Target samples as a NumPy array, same number of rows as ``X``.
        epochs: Number of optimisation steps to run.
        mod: Dataset tag; it is embedded in the checkpoint file name.
        standard: When true, both ``X`` and ``Y`` are passed through
            ``standardization`` before splitting.
        device: CUDA device index used when CUDA is available; otherwise
            training runs on the CPU.

    Returns:
        A list with one loss value per epoch (each a NumPy scalar array
        taken from the loss tensor before the optimiser step).
    """
    # Local import: only needed to make sure the checkpoint folder exists.
    import os

    if standard:
        X = standardization(X)
        Y = standardization(Y)

    # The held-out 20% is not used in this function, so it is discarded
    # instead of being converted and copied to the device.
    X_train, _, Y_train, _ = train_test_split(X, Y, test_size=0.2, random_state=0)

    device = f'cuda:{device}' if torch.cuda.is_available() else 'cpu'

    X_train = torch.from_numpy(X_train).type(torch.float32).to(device)
    Y_train = torch.from_numpy(Y_train).type(torch.float32).to(device)

    # Load net model (seed fixed so weight initialisation is reproducible).
    # The original file also listed CNN_Model / ENCODER_Model as alternatives.
    torch.manual_seed(20)
    model = ANN_Model()
    model = model.to(device)
    print(model)

    # Set loss function
    loss_function = nn.MSELoss()

    # Set adam optimizer.
    # For the ANN a learning rate of about 0.001 works best (without
    # normalization).
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Report progress roughly ten times per run. Integer arithmetic avoids
    # the float-modulo comparison that never matched for small epoch counts.
    log_every = max(1, epochs // 10)

    # Train
    start_time = time.time()
    losses = []
    for i in range(epochs):
        # Call the module itself (not .forward) so registered hooks run.
        pred = model(X_train)
        loss = loss_function(pred, Y_train)
        losses.append(loss.cpu().detach().numpy())
        # "== 1" keeps the original reporting epochs (1, 1001, 2001, ...);
        # with a step of 1 every epoch is reported.
        if log_every == 1 or i % log_every == 1:
            print("Epoch number: {} and the loss : {}".format(i, loss.item()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(time.time() - start_time)

    # Create the output folder up front so a missing directory cannot
    # throw away a finished training run.
    os.makedirs('checkpoints', exist_ok=True)
    # str(model) starts with the class name; the part before the first
    # underscore is used as the file-name prefix.
    checkpoint_path = 'checkpoints/' + str(model).split('_')[0] + '_' + mod + '_' + 'opt.pt'
    torch.save(model, checkpoint_path)

    return losses
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: load one dataset variant, train on it and plot
    # the per-epoch training loss.

    # Load datasets
    # train data select:
    data_mod = 'mod1'  # opt: mod1 mod2 mod3
    # Single source of truth for the number of training epochs.
    epochs = 10000

    dst_path = 'datasets/top88_' + data_mod + '_xPhys_180_60.npy'
    U_path = 'datasets/top88_' + data_mod + '_u_180_60.npy'
    global_density, global_displace, coarse_density, coarse_displace, fine_displace = data_loader(dst_path, U_path)

    # Inputs: coarse density concatenated column-wise with the two slices
    # (index 0 and 1 on the last axis) of the coarse displacement.
    X = np.hstack((coarse_density[:, :], coarse_displace[:, :, 0], coarse_displace[:, :, 1]))
    # Targets: the fine displacement.
    Y = fine_displace[:, :]

    # Train
    losses = train(X, Y, epochs=epochs, mod=data_mod)

    # plot loss
    # The x-axis length follows the returned history instead of a
    # hard-coded 10000, so changing `epochs` cannot break the plot.
    plt.plot(range(len(losses)), losses)
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.show()
|
|
|