You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.3 KiB
74 lines
2.3 KiB
import torch
|
|
import os
|
|
from config import *
|
|
from utils import *
|
|
import numpy as np
|
|
import scipy.io as scio
|
|
import tensorflow as tf
|
|
import tensorflow.keras as keras
|
|
from keras.models import Sequential
|
|
from keras.layers import Dense,Dropout,Activation,Flatten
|
|
from keras.layers import Conv2D,GlobalAveragePooling2D,MaxPooling2D,ZeroPadding2D,BatchNormalization
|
|
from tensorflow.keras.optimizers import Adam
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Ask TensorFlow's native logger to drop INFO and WARNING messages.
# NOTE(review): TensorFlow is already imported above this line, so the
# setting may arrive too late for import-time logging; confirm and, if so,
# move it ahead of the `import tensorflow` statement.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Print NumPy arrays in full rather than eliding the middle with "...".
np.set_printoptions(threshold=np.inf)

# load_data is not defined in this file; presumably it arrives through one of
# the star imports (config / utils) -- verify against those modules.
train_data, train_ans, test_data, test_ans = load_data()

# The network's input_shape is (10, 32, 1), so each sample needs a trailing
# channel axis. expand_dims replaces the earlier hard-coded
# reshape(800, 10, 32, 1) and does not bake in the sample count.
train_data = tf.expand_dims(train_data, axis=-1)
# The test set goes through the same network (validation and prediction), so
# it needs the same channel axis. Only add it when it is actually missing, in
# case load_data already returns the test set with channels.
if np.ndim(test_data) == len(train_data.shape) - 1:
    test_data = tf.expand_dims(test_data, axis=-1)
print(train_data.shape)

# Single-layer regressor. One convolution filter whose kernel spans the whole
# 10x32 input ('valid' padding, stride 1) yields exactly one ReLU-activated
# value per sample, i.e. a weighted sum of every input cell plus a bias.
model = Sequential()
model.add(Conv2D(
    filters=1,
    kernel_size=(10, 32),
    strides=(1, 1),
    padding='valid',
    input_shape=(10, 32, 1),
    activation='relu',
))
# The convolution emits (batch, 1, 1, 1). Flatten it to (batch, 1) so the loss
# compares each prediction with its own target instead of broadcasting
# predictions against the whole target vector.
# NOTE(review): assumes the targets hold one scalar per sample -- confirm
# against load_data.
model.add(Flatten())

# Compile with the Adam optimizer and fit, scoring the test split after every
# epoch. Epochs and BatchSize are not defined in this file; presumably they
# come from the `config` star import -- verify.
print("Training network")
model.compile(
    optimizer='adam',
    loss='mean_squared_logarithmic_error',
    metrics=['accuracy'],
)
H = model.fit(
    train_data,
    train_ans,
    batch_size=BatchSize,
    epochs=Epochs,
    validation_data=(test_data, test_ans),
)

# save_weights stores the layer weights only, not the architecture.
model.save_weights("model.h5")
print("Save model to disk")

# Score the fitted network on the training split. evaluate() returns the loss
# followed by the single compiled metric.
print("Evaluating network")
loss, accuracy = model.evaluate(x=train_data, y=train_ans)
train_report = "\nLoss: %.2f, Accuracy: %.2f%%" % (loss, accuracy * 100)
print(train_report)

# Predict on the test split and report the share of predictions that land
# within a fixed absolute tolerance of the target.
print("Test network")

# A prediction counts as correct when |prediction - target| is below this.
PRED_TOLERANCE = 1000

pred = model.predict(test_data, batch_size=BatchSize)
# Flatten to one value per sample without hard-coding the test-set size.
pred = pred.reshape(-1)

# Compute the absolute error once; it is used for both the shape printout and
# the accuracy figure.
abs_err = np.abs(pred - test_ans)
print(abs_err.shape)
print(test_ans.shape)
acc = np.mean(abs_err < PRED_TOLERANCE)
print(pred.shape)
print('Prediction Accuracy: %.2f%%' % (acc*100))

# Disabled: plot the per-epoch training/validation loss and accuracy recorded
# in H.history and save the figure as LOSS.jpg. Uncomment to re-enable.
# plt.style.use("ggplot")
# plt.figure()
# plt.plot(np.arange(Epochs), H.history["loss"], label="train_loss")
# plt.plot(np.arange(Epochs), H.history["val_loss"], label="val_loss")
# plt.plot(np.arange(Epochs), H.history["accuracy"], label="train_acc")
# plt.plot(np.arange(Epochs), H.history["val_accuracy"], label="val_acc")
# plt.title("Training Loss and Accuracy")
# plt.xlabel("Epoch #")
# plt.ylabel("Loss/Accuracy")
# plt.legend()
# plt.savefig("LOSS.jpg")

# Finish by printing the layer-by-layer description of the network.
print("Model summary")
model.summary()