3380 changed files with 247586 additions and 8 deletions
Binary file not shown.
Binary file not shown.
@ -0,0 +1,14 @@ |
|||
GenWeightPath = 'GenCurvedModel.h5' |
|||
DisWeightPath = 'DisCurvedModel.h5' |
|||
CurvedMatPath = 'CNN-quad2D/data/curved_mat/' |
|||
CurvedAnsPath = 'CNN-quad2D/data/curved_ans/' |
|||
|
|||
net_input_size = 32 #256 |
|||
row_size = 10 |
|||
col_size = 32 |
|||
|
|||
train_amount = 400 |
|||
test_amount = 100 |
|||
|
|||
Epochs = 10 #6000 |
|||
BatchSize = 200 #20 |
@ -0,0 +1,74 @@ |
|||
import torch |
|||
import os |
|||
from config import * |
|||
from utils import * |
|||
import numpy as np |
|||
import scipy.io as scio |
|||
import tensorflow as tf |
|||
import tensorflow.keras as keras |
|||
from keras.models import Sequential |
|||
from keras.layers import Dense,Dropout,Activation,Flatten |
|||
from keras.layers import Conv2D,GlobalAveragePooling2D,MaxPooling2D,ZeroPadding2D,BatchNormalization |
|||
from tensorflow.keras.optimizers import Adam |
|||
import matplotlib.pyplot as plt |
|||
|
|||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|||
np.set_printoptions(threshold=np.inf) |
|||
|
|||
train_data, train_ans, test_data, test_ans = load_data() |
|||
# train_data = train_data.reshape(800,10,32,1) |
|||
train_data = tf.expand_dims(train_data, axis=-1) |
|||
print(train_data.shape) |
|||
|
|||
model=Sequential() |
|||
# model.add(Dense(20)) |
|||
# model.add(Activation("relu")) |
|||
# model.add(Dense(1)) |
|||
# model.add(Activation("relu")) |
|||
model.add(Conv2D(filters=1,kernel_size=(10,32), |
|||
strides=(1,1),padding='valid', |
|||
input_shape=(10,32,1), |
|||
activation='relu')) |
|||
|
|||
# train the model using Adam |
|||
print("Training network") |
|||
model.compile(loss='mean_squared_logarithmic_error', |
|||
optimizer='adam', |
|||
metrics=['accuracy']) |
|||
H = model.fit(train_data,train_ans,validation_data=(test_data,test_ans), |
|||
epochs=Epochs,batch_size=BatchSize) |
|||
|
|||
model.save_weights("model.h5") |
|||
print("Save model to disk") |
|||
|
|||
# evaluate the network |
|||
print("Evaluating network") |
|||
loss, accuracy = model.evaluate(train_data,train_ans) |
|||
print("\nLoss: %.2f, Accuracy: %.2f%%" % (loss, accuracy*100)) |
|||
|
|||
# test |
|||
print("Test network") |
|||
pred = model.predict(test_data,batch_size=BatchSize) |
|||
pred = pred.reshape((200,)) |
|||
print(np.abs(pred-test_ans).shape) |
|||
print(test_ans.shape) |
|||
acc = np.mean(np.abs(pred-test_ans) < 1000) |
|||
print(pred.shape) |
|||
print('Prediction Accuracy: %.2f%%' % (acc*100)) |
|||
|
|||
|
|||
# # plot the training loss and accuracy |
|||
# plt.style.use("ggplot") |
|||
# plt.figure() |
|||
# plt.plot(np.arange(Epochs), H.history["loss"], label="train_loss") |
|||
# plt.plot(np.arange(Epochs), H.history["val_loss"], label="val_loss") |
|||
# plt.plot(np.arange(Epochs), H.history["accuracy"], label="train_acc") |
|||
# plt.plot(np.arange(Epochs), H.history["val_accuracy"], label="val_acc") |
|||
# plt.title("Training Loss and Accuracy") |
|||
# plt.xlabel("Epoch #") |
|||
# plt.ylabel("Loss/Accuracy") |
|||
# plt.legend() |
|||
# plt.savefig("LOSS.jpg") |
|||
print("Model summary") |
|||
model.summary() |
|||
|
@ -0,0 +1,29 @@ |
|||
import torch |
|||
import os |
|||
import config |
|||
from utils import * |
|||
import tensorflow.keras as keras |
|||
from keras.models import load_model |
|||
from keras.preprocessing import image |
|||
import matplotlib.pyplot as plt |
|||
|
|||
train_data, train_ans, test_data, test_ans = load_data() |
|||
|
|||
print('Using loaded model to predict...') |
|||
model = load_model('model.h5') |
|||
np.set_printoptions(precision=4) |
|||
|
|||
# test |
|||
print("Test network") |
|||
pred = model.predict(test_data,batch_size=BatchSize) |
|||
pred = pred.reshape((200,)) |
|||
print(np.abs(pred-test_ans).shape) |
|||
print(test_ans.shape) |
|||
acc = np.mean(np.abs(pred-test_ans) < 1000) |
|||
print(pred.shape) |
|||
print('Prediction Accuracy: %.2f%%' % (acc*100)) |
|||
|
|||
model.save('./weights/model.h5') |
|||
|
|||
print("Model summary") |
|||
model.summary() |
@ -0,0 +1,53 @@ |
|||
import torch |
|||
import numpy as np |
|||
import scipy.io as scio |
|||
from config import * |
|||
|
|||
def load_data(): |
|||
train_data = np.empty((train_amount*2,row_size,col_size),dtype="double") |
|||
train_ans = np.empty((train_amount*2,),dtype="double") |
|||
test_data = np.empty((test_amount*2,row_size,col_size),dtype="double") |
|||
test_ans = np.empty((test_amount*2,),dtype="double") |
|||
|
|||
print("Load training data") |
|||
for i in range(1,400): |
|||
dir = CurvedMatPath + str(i) + ".mat" |
|||
data = scio.loadmat(dir) |
|||
train_data[i] = np.pad(data['mat'],((0,6),(0,20)),'constant',constant_values=(0,0)) |
|||
|
|||
dir = CurvedAnsPath + str(i) + ".mat" |
|||
data = scio.loadmat(dir) |
|||
train_ans[i] = data['res'] |
|||
|
|||
# negtive area |
|||
dir = CurvedMatPath + str(i) + "_neg.mat" |
|||
data = scio.loadmat(dir) |
|||
train_data[i + train_amount] = np.pad(data['mat'],((0,6),(0,20)),'constant',constant_values=(0,0)) |
|||
|
|||
dir = CurvedAnsPath + str(i) + "_neg.mat" |
|||
data = scio.loadmat(dir) |
|||
train_ans[i + train_amount] = data['res'] |
|||
|
|||
print("Load test data") |
|||
for i in range(401,500): |
|||
dir = CurvedMatPath + str(i) + ".mat" |
|||
data = scio.loadmat(dir) |
|||
test_data[i-400] = np.pad(data['mat'],((0,6),(0,20)),'constant',constant_values=(0,0)) |
|||
|
|||
dir = CurvedAnsPath + str(i) + ".mat" |
|||
data = scio.loadmat(dir) |
|||
test_ans[i-400] = data['res'] |
|||
|
|||
# negtive area |
|||
dir = CurvedMatPath + str(i) + "_neg.mat" |
|||
data = scio.loadmat(dir) |
|||
test_data[i-400+test_amount] = np.pad(data['mat'],((0,6),(0,20)),'constant',constant_values=(0,0)) |
|||
|
|||
dir = CurvedAnsPath + str(i) + "_neg.mat" |
|||
data = scio.loadmat(dir) |
|||
test_ans[i-400+test_amount] = data['res'] |
|||
if np.any(np.isnan(train_data)) or np.any(np.isnan(train_ans)) or np.any(np.isnan(test_data)) or np.any(np.isnan(test_ans)): |
|||
print("EXIST NAN") |
|||
|
|||
return train_data, train_ans, test_data, test_ans |
|||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files changed in this diff
Loading…
Reference in new issue