You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
53 lines
2.0 KiB
53 lines
2.0 KiB
1 year ago
|
import torch
|
||
|
import numpy as np
|
||
|
import scipy.io as scio
|
||
|
from config import *
|
||
|
|
||
|
def _load_padded_sample(stem):
    """Load one (matrix, answer) pair whose two files share the name ``stem``.

    The matrix is read from ``CurvedMatPath + stem + ".mat"`` (variable
    ``'mat'``). It is zero-padded with 6 extra rows at the bottom and 20 extra
    columns on the right. The answer is read from
    ``CurvedAnsPath + stem + ".mat"`` (variable ``'res'``).

    Returns:
        (padded_matrix, answer) where ``answer`` is a Python scalar.

    Raises:
        ValueError: if the ``'res'`` variable holds more than one value.
    """
    mat = scio.loadmat(CurvedMatPath + stem + ".mat")['mat']
    padded = np.pad(mat, ((0, 6), (0, 20)), 'constant', constant_values=(0, 0))
    res = scio.loadmat(CurvedAnsPath + stem + ".mat")['res']
    # loadmat returns MATLAB arrays as (at least) 2-D by default, so a scalar
    # answer arrives as a (1, 1) array; .item() unwraps it explicitly instead
    # of relying on NumPy's deprecated implicit array-to-scalar assignment.
    return padded, np.asarray(res).item()


def _fill_split(data, ans, first_file, amount):
    """Fill ``data`` / ``ans`` in place from ``amount`` consecutive samples.

    Sample ``first_file + k`` is stored in slot ``k``; its ``_neg`` counterpart
    is stored in slot ``k + amount``. Every slot ``0 .. 2*amount - 1`` is
    written, so no uninitialised ``np.empty`` entry survives.
    """
    for k in range(amount):
        stem = str(first_file + k)
        data[k], ans[k] = _load_padded_sample(stem)
        # negative area
        neg = k + amount
        data[neg], ans[neg] = _load_padded_sample(stem + "_neg")


def load_data():
    """Load the curved-matrix dataset into float64 NumPy arrays.

    Sample files are assumed to be numbered consecutively. Training samples
    are ``1 .. train_amount`` and test samples follow directly as
    ``train_amount + 1 .. train_amount + test_amount``. For ``train_amount = 400``
    and ``test_amount = 100`` that is files 1-400 and 401-500.
    NOTE(review): confirm this numbering against the data on disk.
    Each sample ``k`` has ``<k>.mat`` and ``<k>_neg.mat`` under both
    ``CurvedMatPath`` and ``CurvedAnsPath`` (all names come from ``config``).

    Returns:
        train_data: shape ``(2*train_amount, row_size, col_size)``; positive
            samples in the first half, negative samples in the second half.
        train_ans: shape ``(2*train_amount,)``, answers in the same order.
        test_data: shape ``(2*test_amount, row_size, col_size)``.
        test_ans: shape ``(2*test_amount,)``.

    If any returned array contains NaN, ``"EXIST NAN"`` is printed; the data
    is still returned.
    """
    train_data = np.empty((train_amount * 2, row_size, col_size), dtype="double")
    train_ans = np.empty((train_amount * 2,), dtype="double")
    test_data = np.empty((test_amount * 2, row_size, col_size), dtype="double")
    test_ans = np.empty((test_amount * 2,), dtype="double")

    print("Load training data")
    _fill_split(train_data, train_ans, 1, train_amount)

    print("Load test data")
    _fill_split(test_data, test_ans, train_amount + 1, test_amount)

    if any(np.isnan(arr).any() for arr in (train_data, train_ans, test_data, test_ans)):
        print("EXIST NAN")

    return train_data, train_ans, test_data, test_ans
|
||
|
|