You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

29 lines
714 B

import torch
import os
import config
from utils import *
import tensorflow.keras as keras
from keras.models import load_model
from keras.preprocessing import image
import matplotlib.pyplot as plt
# Evaluate a previously trained Keras model ('model.h5') on the test split,
# report a tolerance-based accuracy, re-save the model under ./weights/ and
# print its summary.
# NOTE(review): `np` and `BatchSize` are not imported explicitly in this file;
# they presumably come from `from utils import *` — confirm.
train_data, train_ans, test_data, test_ans = load_data()
print('Using loaded model to predict...')
model = load_model('model.h5')
np.set_printoptions(precision=4)

# A prediction counts as "accurate" when its absolute error is strictly below
# this threshold.
# NOTE(review): the unit depends on the targets returned by load_data().
ERROR_TOLERANCE = 1000

# test
print("Test network")
pred = model.predict(test_data, batch_size=BatchSize)
# Reshape predictions to the target's shape instead of a hard-coded sample
# count. If pred kept an extra trailing axis (presumably (N, 1)), the
# subtraction below would broadcast to an (N, N) matrix instead of pairing
# each prediction with its own target.
pred = pred.reshape(test_ans.shape)
abs_err = np.abs(pred - test_ans)
print(abs_err.shape)
print(test_ans.shape)
# Fraction of test samples whose absolute error is below the tolerance.
acc = np.mean(abs_err < ERROR_TOLERANCE)
print(pred.shape)
print('Prediction Accuracy: %.2f%%' % (acc*100))

# Re-save the loaded model. Create the target directory first so the save
# does not fail on a fresh checkout where ./weights/ is missing.
os.makedirs('./weights', exist_ok=True)
model.save('./weights/model.h5')
print("Model summary")
model.summary()