You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.3 KiB
75 lines
2.3 KiB
1 year ago
|
import torch
|
||
|
import os
|
||
|
from config import *
|
||
|
from utils import *
|
||
|
import numpy as np
|
||
|
import scipy.io as scio
|
||
|
import tensorflow as tf
|
||
|
import tensorflow.keras as keras
|
||
|
from keras.models import Sequential
|
||
|
from keras.layers import Dense,Dropout,Activation,Flatten
|
||
|
from keras.layers import Conv2D,GlobalAveragePooling2D,MaxPooling2D,ZeroPadding2D,BatchNormalization
|
||
|
from tensorflow.keras.optimizers import Adam
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
# Quiet TensorFlow's native (C++) logging. The value '2' is commonly documented
# as hiding INFO and WARNING messages — confirm against the TF version in use.
# NOTE(review): `import tensorflow` already ran earlier in this file, and TF
# presumably reads this variable when its native library loads, so messages
# emitted during that import are not filtered. Move this line above the
# imports if that matters.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')

# Never summarise printed NumPy arrays with "..." — always print every element.
np.set_printoptions(threshold=np.inf)
|
||
|
|
||
|
# Load both splits through the project's load_data() helper (star-imported
# from utils): inputs and targets for training and for testing.
train_data, train_ans, test_data, test_ans = load_data()

# The Conv2D layer is declared with input_shape=(10, 32, 1), so every sample
# needs a trailing channel axis. Assumes each split arrives as (N, 10, 32);
# an earlier reshape used (800, 10, 32, 1) for the training split — TODO confirm.
# The test split feeds the same model (fit's validation_data and predict), so it
# must get the same axis; otherwise Keras rejects its rank.
# Only 3-D inputs are expanded, so data that already carries the channel axis
# is not expanded a second time.
if np.ndim(train_data) == 3:
    train_data = tf.expand_dims(train_data, axis=-1)
if np.ndim(test_data) == 3:
    test_data = tf.expand_dims(test_data, axis=-1)

print(train_data.shape)
|
||
|
|
||
|
# Single-layer model: one Conv2D filter whose kernel spans the whole
# (10, 32) input with 'valid' padding, so each sample is reduced to one value.
model = Sequential()
model.add(Conv2D(filters=1, kernel_size=(10, 32),
                 strides=(1, 1), padding='valid',
                 input_shape=(10, 32, 1),
                 activation='relu'))
# The convolution yields (batch, 1, 1, 1). Flatten it to (batch, 1) so Keras
# can align predictions with 1-D targets. A rank-4 prediction against (batch,)
# targets broadcasts inside the loss instead of pairing sample to sample.
# NOTE(review): assumes train_ans is 1-D; the test section flattens predictions
# and compares them with test_ans directly — confirm.
model.add(Flatten())
|
||
|
|
||
|
# Configure and train the network: Adam optimizer (selected by its string name),
# mean squared logarithmic error loss, and the test split as validation data.
# NOTE(review): 'accuracy' is tracked even though the loss is a regression loss;
# its value is probably not meaningful for these targets — review.
print("Training network")
model.compile(
    optimizer='adam',
    loss='mean_squared_logarithmic_error',
    metrics=['accuracy'],
)
H = model.fit(
    x=train_data,
    y=train_ans,
    batch_size=BatchSize,
    epochs=Epochs,
    validation_data=(test_data, test_ans),
)

# Persist the learned weights only (save_weights does not store the architecture).
# NOTE(review): Keras 3 requires save_weights() paths ending in ".weights.h5";
# confirm the installed Keras version accepts "model.h5".
model.save_weights("model.h5")
print("Save model to disk")
|
||
|
|
||
|
# Score the trained model on the *training* split. evaluate() returns the loss
# followed by the single compiled metric ('accuracy').
print("Evaluating network")
loss, accuracy = model.evaluate(x=train_data, y=train_ans)
print("\nLoss: %.2f, Accuracy: %.2f%%" % (loss, 100 * accuracy))
|
||
|
|
||
|
# test: predict on the held-out split and report the share of predictions
# within a fixed absolute tolerance of the target.
print("Test network")
pred = model.predict(test_data, batch_size=BatchSize)

# One value per sample: flatten with -1 instead of a hard-coded sample count
# so any test-set size works.
pred = pred.reshape(-1)

# Flatten the targets as well. If test_ans were (N, 1), `pred - test_ans` would
# broadcast to (N, N) and silently corrupt the accuracy below.
abs_err = np.abs(pred - np.ravel(test_ans))
print(abs_err.shape)
print(test_ans.shape)

# A prediction counts as correct when it is within this absolute distance
# of its target.
PRED_TOLERANCE = 1000
acc = np.mean(abs_err < PRED_TOLERANCE)
print(pred.shape)
print('Prediction Accuracy: %.2f%%' % (acc*100))
|
||
|
|
||
|
|
||
|
# Set to True to save the training/validation loss and accuracy curves recorded
# in H.history to LOSS.jpg. Disabled by default, so the script's output is unchanged.
PLOT_HISTORY = False

if PLOT_HISTORY:
    # Size the x-axis from the recorded history, so it always matches the
    # number of plotted points.
    epoch_axis = np.arange(len(H.history["loss"]))
    plt.style.use("ggplot")
    plt.figure()
    plt.plot(epoch_axis, H.history["loss"], label="train_loss")
    plt.plot(epoch_axis, H.history["val_loss"], label="val_loss")
    plt.plot(epoch_axis, H.history["accuracy"], label="train_acc")
    plt.plot(epoch_axis, H.history["val_accuracy"], label="val_acc")
    plt.title("Training Loss and Accuracy")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend()
    plt.savefig("LOSS.jpg")
    # Release the figure once it has been written to disk.
    plt.close()

# Print the layer-by-layer architecture and parameter counts.
print("Model summary")
model.summary()
|
||
|
|