You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

53 lines
2.0 KiB

1 year ago
import torch
import numpy as np
import scipy.io as scio
from config import *
def _pad_to_frame(mat):
    """Zero-pad a 2-D matrix at the bottom/right to shape (row_size, col_size).

    Equivalent to the previous fixed ``((0, 6), (0, 20))`` padding whenever the
    matrix is exactly 6 rows / 20 columns smaller than the frame, which is the
    only case in which the result could be stored into a
    ``(row_size, col_size)`` slot. A matrix larger than the frame yields
    negative pad widths and ``np.pad`` raises ``ValueError``.
    """
    mat = np.asarray(mat)
    pad_rows = row_size - mat.shape[0]
    pad_cols = col_size - mat.shape[1]
    return np.pad(mat, ((0, pad_rows), (0, pad_cols)), 'constant', constant_values=(0, 0))


def _load_pair(sample_id, suffix=""):
    """Load the padded input matrix and scalar answer for one sample id.

    Reads variable ``mat`` from ``<CurvedMatPath><sample_id><suffix>.mat`` and
    variable ``res`` from ``<CurvedAnsPath><sample_id><suffix>.mat``.
    """
    mat = scio.loadmat(f"{CurvedMatPath}{sample_id}{suffix}.mat")['mat']
    res = scio.loadmat(f"{CurvedAnsPath}{sample_id}{suffix}.mat")['res']
    # loadmat wraps even scalars in a >=2-D array; .item() extracts the single
    # value and raises ValueError if 'res' unexpectedly holds more than one.
    return _pad_to_frame(mat), np.asarray(res).item()


def _fill_split(data, ans, first_id, count):
    """Fill *data*/*ans* with sample ids ``first_id .. first_id + count - 1``.

    Positive samples go to slots ``0..count-1``; the matching ``_neg``
    (negative-area) samples go to slots ``count..2*count-1``.
    """
    for slot in range(count):
        sample_id = first_id + slot
        data[slot], ans[slot] = _load_pair(sample_id)
        # negative area
        data[slot + count], ans[slot + count] = _load_pair(sample_id, "_neg")


def load_data():
    """Load the curved training and test sets from numbered ``.mat`` files.

    Sample files are numbered from 1. Training uses ids ``1..train_amount``,
    and test uses the following ``test_amount`` ids. Each id has a positive
    file ``<id>.mat`` and a negative-area file ``<id>_neg.mat`` under both
    ``CurvedMatPath`` (variable ``mat``) and ``CurvedAnsPath`` (variable
    ``res``). Sizes and paths are read from the module-level names
    ``train_amount``, ``test_amount``, ``row_size``, ``col_size``,
    ``CurvedMatPath`` and ``CurvedAnsPath``, which this file imports with
    ``from config import *``.

    Returns:
        Tuple ``(train_data, train_ans, test_data, test_ans)`` of float64
        arrays:

        * ``train_data`` has shape ``(2*train_amount, row_size, col_size)``.
        * ``train_ans`` has shape ``(2*train_amount,)``.
        * Positive samples fill the first half of each array and ``_neg``
          samples fill the second half.
        * The test arrays use the same layout, sized by ``test_amount``.

        Prints "EXIST NAN" if any array contains NaN.
    """
    # Every slot is written by _fill_split. Zeros rather than np.empty mean no
    # uninitialised memory can ever reach training if that coverage changes.
    # (Indexing slots by 1-based file id once left slot 0 and slot
    # train_amount/test_amount holding garbage.)
    train_data = np.zeros((train_amount * 2, row_size, col_size), dtype="double")
    train_ans = np.zeros((train_amount * 2,), dtype="double")
    test_data = np.zeros((test_amount * 2, row_size, col_size), dtype="double")
    test_ans = np.zeros((test_amount * 2,), dtype="double")

    print("Load training data")
    _fill_split(train_data, train_ans, 1, train_amount)
    print("Load test data")
    _fill_split(test_data, test_ans, train_amount + 1, test_amount)

    arrays = (train_data, train_ans, test_data, test_ans)
    if any(np.isnan(a).any() for a in arrays):
        print("EXIST NAN")
    return arrays