You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

75 lines
2.3 KiB

1 year ago
import torch
import os
from config import *
from utils import *
import numpy as np
import scipy.io as scio
import tensorflow as tf
import tensorflow.keras as keras
from keras.models import Sequential
from keras.layers import Dense,Dropout,Activation,Flatten
from keras.layers import Conv2D,GlobalAveragePooling2D,MaxPooling2D,ZeroPadding2D,BatchNormalization
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
# Raise the minimum level of TensorFlow's native (C++) log output; by
# TensorFlow's convention, level "2" hides INFO and WARNING messages.
# NOTE(review): tensorflow has already been imported above this line, so
# messages emitted during that import are not filtered -- set this before
# the imports if that matters.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')

# Print NumPy arrays in full rather than abbreviating them with "...".
np.set_printoptions(threshold=np.inf)
# Load the dataset via the project's utils.load_data().
# NOTE(review): presumably each sample is a 10x32 frame (samples, 10, 32) and
# the targets are 1-D -- confirm against utils.load_data.
train_data, train_ans, test_data, test_ans = load_data()


def _with_channel_axis(x):
    """Return *x* with a trailing channel axis appended when it is 3-D.

    Conv2D expects input laid out as (samples, height, width, channels).
    Inputs that are not 3-D are returned unchanged.
    """
    return tf.expand_dims(x, axis=-1) if np.ndim(x) == 3 else x


# Both splits need the channel axis: test_data is fed to the model as
# validation data during fit() and to predict() afterwards, so leaving it
# 3-D makes it incompatible with input_shape=(10, 32, 1).
train_data = _with_channel_axis(train_data)
test_data = _with_channel_axis(test_data)
print(train_data.shape)

model = Sequential()
# One filter whose kernel spans the whole 10x32 frame with 'valid' padding,
# so each sample maps to a single ReLU-activated weighted sum: the layer
# output is (batch, 1, 1, 1).
model.add(Conv2D(filters=1, kernel_size=(10, 32),
                 strides=(1, 1), padding='valid',
                 input_shape=(10, 32, 1),
                 activation='relu'))
# Collapse (batch, 1, 1, 1) to (batch, 1). Keras aligns a (batch, 1)
# prediction with 1-D targets; a rank-4 prediction is instead broadcast
# against the targets, pairing every prediction with every target in the
# batch inside the loss. Flatten has no weights, so saved weight files are
# unaffected.
model.add(Flatten())

# train the model using Adam
print("Training network")
# NOTE(review): 'accuracy' is a classification metric and is not meaningful
# for this regression target; consider a metric such as 'mae'.
model.compile(loss='mean_squared_logarithmic_error',
              optimizer='adam',
              metrics=['accuracy'])
H = model.fit(train_data, train_ans,
              validation_data=(test_data, test_ans),
              epochs=Epochs, batch_size=BatchSize)

# Only the weights are written; the architecture above must be rebuilt
# before they can be loaded again.
model.save_weights("model.h5")
print("Save model to disk")
# evaluate the network
print("Evaluating network")
# NOTE: this evaluates on the training split (train_data/train_ans), so the
# figures describe the fit, not generalisation; use test_data/test_ans for that.
loss, accuracy = model.evaluate(train_data, train_ans)
print("\nLoss: %.2f, Accuracy: %.2f%%" % (loss, accuracy * 100))

# test
print("Test network")
# A prediction counts as correct when its absolute error (in target units)
# is below this threshold.
ERROR_TOLERANCE = 1000
pred = model.predict(test_data, batch_size=BatchSize)
# Reduce the model output and the targets to one value per sample. Deriving
# the length from the data (instead of a hard-coded 200) keeps this working
# for any test-set size, and flattening both sides prevents pred - truth from
# silently broadcasting into an N x N matrix.
pred = np.ravel(pred)
truth = np.ravel(test_ans)
if pred.shape != truth.shape:
    raise ValueError(
        f"prediction count {pred.shape[0]} does not match "
        f"target count {truth.shape[0]}"
    )
abs_err = np.abs(pred - truth)
print(abs_err.shape)
print(test_ans.shape)
acc = np.mean(abs_err < ERROR_TOLERANCE)
print(pred.shape)
print('Prediction Accuracy: %.2f%%' % (acc * 100))
# Set to True to save the training curves from the fit history H to LOSS.jpg.
# Off by default, matching the previously commented-out plotting code.
PLOT_HISTORY = False
if PLOT_HISTORY:
    # History key -> legend label; keys absent from the history (e.g. when a
    # different metric was compiled) are skipped instead of raising KeyError.
    curves = {
        "loss": "train_loss",
        "val_loss": "val_loss",
        "accuracy": "train_acc",
        "val_accuracy": "val_acc",
    }
    plt.style.use("ggplot")
    plt.figure()
    for key, label in curves.items():
        if key in H.history:
            values = H.history[key]
            # Use the recorded length so the x-axis always matches the number
            # of epochs that actually ran.
            plt.plot(np.arange(len(values)), values, label=label)
    plt.title("Training Loss and Accuracy")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend()
    plt.savefig("LOSS.jpg")
    plt.close()

print("Model summary")
model.summary()