You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
hunjianghu/gzy/tesorflow/预测模型.py

58 lines
1.7 KiB

import tensorflow as tf
import numpy as np
from PIL import Image
import os
# Maps a predicted class index to the label printed for that class.
flower_dict = {
    0: '小熊猫',
    1: '滑稽',
    2: '萌妹子',
    3: '小坏坏',
    4: '小黄鸡',
}

# Directory whose entries are renamed and then classified.
path = 'C:/Users/yuanshao/Desktop/test/'

# Target size handed to Image.resize as (w, h) when an image is loaded.
w, h = 100, 100

# Presumably the channel count (3 for RGB); not referenced in the visible code.
c = 3
def read_one_image(path):
    """Load one image file as an RGB array resized to the module's (w, h).

    Args:
        path: Location of the image file to open.

    Returns:
        A NumPy array built from the resized RGB image (for an RGB image of
        size (w, h) this is an (h, w, 3) array).

    Raises:
        OSError: If the file cannot be opened or decoded by Pillow.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it as soon as the pixel data has been copied by convert().
    with Image.open(path) as image:
        rgb = image.convert('RGB')
    # Image.ANTIALIAS was an alias of Image.LANCZOS and has been removed from
    # recent Pillow releases; LANCZOS selects the same resampling filter.
    resized = rgb.resize((w, h), Image.LANCZOS)
    return np.asarray(resized)
def preprocess(directory=None):
    """Rename every entry of a directory to ``0.jpg``, ``1.jpg``, ...

    Entries are numbered in sorted name order so the result is deterministic
    (``os.listdir`` returns names in arbitrary order). Only the names change;
    file contents are not converted, so every entry receives a ``.jpg``
    suffix regardless of its real format.

    Args:
        directory: Directory to process. Defaults to the module-level
            ``path`` when omitted.

    Raises:
        OSError: If the directory cannot be listed or an entry cannot be
            renamed. Entries already moved keep their staging name.
    """
    if directory is None:
        directory = path
    entries = sorted(os.listdir(directory))

    # Pick a staging prefix that no existing entry uses, so the first pass
    # can never overwrite (or fail on) an entry that has not been moved yet.
    staging_prefix = '.preprocess_tmp_'
    while any(name.startswith(staging_prefix) for name in entries):
        staging_prefix += '_'

    # Pass 1: move every entry out of the way under a unique staging name.
    staged = []
    for index, name in enumerate(entries):
        staged_path = os.path.join(directory, staging_prefix + str(index))
        os.rename(os.path.join(directory, name), staged_path)
        staged.append(staged_path)

    # Pass 2: all final names are free now, so assign them.
    for index, staged_path in enumerate(staged):
        os.rename(staged_path, os.path.join(directory, str(index) + '.jpg'))
# Restore the trained model and classify every image found in `path`.
with tf.Session() as sess:
    # Count the entries before renaming: preprocess() renames them to
    # 0.jpg .. (n-1).jpg, so the count determines which names exist afterwards.
    image_count = len(os.listdir(path))
    preprocess()

    pic = [os.path.join(path, str(i) + '.jpg') for i in range(image_count)]
    data = []
    for picture in pic:
        print(picture)
        data.append(read_one_image(picture))

    # Rebuild the graph from the saved meta file and load the latest weights.
    saver = tf.train.import_meta_graph('D:/tensorflow/saver/model.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('D:/tensorflow/saver/'))
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    logits = graph.get_tensor_by_name("logits_eval:0")
    classification_result = sess.run(logits, feed_dict={x: data})

    # Print the prediction matrix.
    print(classification_result)

    # Print the index of the largest value in each row of the matrix.
    # The result is already a NumPy array, so np.argmax is used instead of
    # tf.argmax(...).eval(), which would add new ops to the restored graph.
    output = np.argmax(classification_result, axis=1)
    print(output)

    for i, class_index in enumerate(output):
        print("", i, "张图片预测:" + flower_dict[class_index])