import torch
import numpy as np
import scipy.io as scio
from config import *


def _load_padded_mat(path):
    """Load the 'mat' variable from a .mat file, zero-padded to (row_size, col_size).

    Padding is added only at the bottom and right edges, matching the
    original hard-coded pad of ((0, 6), (0, 20)) when the stored matrix is
    (row_size - 6, col_size - 20).

    Raises:
        ValueError: if the stored matrix is larger than (row_size, col_size).
    """
    mat = scio.loadmat(path)['mat']
    pad_rows = row_size - mat.shape[0]
    pad_cols = col_size - mat.shape[1]
    if pad_rows < 0 or pad_cols < 0:
        raise ValueError(
            f"{path}: matrix shape {mat.shape} exceeds target ({row_size}, {col_size})"
        )
    return np.pad(mat, ((0, pad_rows), (0, pad_cols)), 'constant', constant_values=(0, 0))


def _load_res(path):
    """Load the single 'res' value from a .mat file as a Python float.

    loadmat returns variables as (at least) 2-D arrays; .item() extracts the
    lone element explicitly and raises if there is more than one.
    """
    return np.asarray(scio.loadmat(path)['res'], dtype="double").item()


def _load_sample(file_no, data_arr, ans_arr, slot, neg_offset):
    """Fill one positive sample at `slot` and its "_neg" counterpart at `slot + neg_offset`."""
    stem = str(file_no)
    data_arr[slot] = _load_padded_mat(CurvedMatPath + stem + ".mat")
    ans_arr[slot] = _load_res(CurvedAnsPath + stem + ".mat")
    # negative area: stored alongside as "<n>_neg.mat"
    data_arr[slot + neg_offset] = _load_padded_mat(CurvedMatPath + stem + "_neg.mat")
    ans_arr[slot + neg_offset] = _load_res(CurvedAnsPath + stem + "_neg.mat")


def load_data():
    """Load the curved-area training and test sets from numbered .mat files.

    Files are numbered from 1: files 1..train_amount form the training set,
    and files train_amount+1..train_amount+test_amount form the test set.
    For each number n, "<n>.mat" is stored in the first half of the arrays
    and "<n>_neg.mat" at the same position in the second half.

    NOTE(review): the original loops were range(1, 400) and range(401, 500)
    with 1-based array indexing. They left slot 0 and the first negative slot
    uninitialised (np.empty garbage) and skipped files 400 and 500. This
    version assumes contiguous numbering that matches train_amount and
    test_amount (e.g. 400/100); confirm against the data directory.

    Returns:
        (train_data, train_ans, test_data, test_ans). The data arrays have
        shape (2*amount, row_size, col_size) and the answer arrays have
        shape (2*amount,), all float64.
    """
    train_data = np.empty((train_amount * 2, row_size, col_size), dtype="double")
    train_ans = np.empty((train_amount * 2,), dtype="double")
    test_data = np.empty((test_amount * 2, row_size, col_size), dtype="double")
    test_ans = np.empty((test_amount * 2,), dtype="double")

    print("Load training data")
    for slot in range(train_amount):
        _load_sample(slot + 1, train_data, train_ans, slot, train_amount)

    print("Load test data")
    for slot in range(test_amount):
        _load_sample(train_amount + slot + 1, test_data, test_ans, slot, test_amount)

    # Only a warning: callers still receive the arrays.
    if any(np.isnan(arr).any() for arr in (train_data, train_ans, test_data, test_ans)):
        print("EXIST NAN")
    return train_data, train_ans, test_data, test_ans