You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
714 B
30 lines
714 B
2 years ago
|
import torch
|
||
|
import os
|
||
|
import config
|
||
|
from utils import *
|
||
|
import tensorflow.keras as keras
|
||
|
from keras.models import load_model
|
||
|
from keras.preprocessing import image
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
# Evaluate the saved Keras model on the test split, print a tolerance-based
# accuracy, re-save the model under ./weights/, and print its summary.
#
# NOTE(review): `np`, `BatchSize` and `load_data` are not defined in this
# script; presumably they come from `from utils import *` — confirm.

# A prediction counts as correct when it is strictly within this absolute
# distance of its target value.
ACCURACY_TOLERANCE = 1000
MODEL_PATH = 'model.h5'
SAVE_PATH = './weights/model.h5'

# Only the test split is used below; the training split is kept as module
# globals for anything that imports this script.
train_data, train_ans, test_data, test_ans = load_data()

print('Using loaded model to predict...')
model = load_model(MODEL_PATH)
np.set_printoptions(precision=4)

# test
print("Test network")
pred = model.predict(test_data, batch_size=BatchSize)

# Flatten predictions and targets to 1-D so the evaluation works for any
# test-set size, and so an (N, 1) target array cannot silently broadcast
# against (N,) predictions into an (N, N) matrix.
pred = pred.reshape(-1)
answers = np.ravel(test_ans)
if pred.shape != answers.shape:
    raise ValueError(
        'Prediction count %d does not match target count %d'
        % (pred.size, answers.size))

errors = np.abs(pred - answers)
print(errors.shape)
print(test_ans.shape)
acc = np.mean(errors < ACCURACY_TOLERANCE)
print(pred.shape)
print('Prediction Accuracy: %.2f%%' % (acc * 100))

# Re-save the loaded model. Saving to an .h5 path does not create missing
# parent directories, so make sure ./weights/ exists first.
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
model.save(SAVE_PATH)

print("Model summary")
model.summary()
|