You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
58 lines
1.7 KiB
58 lines
1.7 KiB
import os
import tempfile

import numpy as np
import tensorflow as tf
from PIL import Image


# Class index (the per-row argmax of the model's logits) -> display label.
# The name "flower_dict" is kept because the prediction loop looks it up.
flower_dict = {
    0: '小熊猫',  # "red panda" (lit. "little panda")
    1: '滑稽',    # "funny" / the "huaji" smirking face
    2: '萌妹子',  # "cute girl"
    3: '小坏坏',  # "little rascal"
    4: '小黄鸡',  # "little yellow chick"
}

# Directory holding the images to classify. File paths are built elsewhere by
# plain string concatenation (path + name), so the trailing slash matters.
path = 'C:/Users/yuanshao/Desktop/test/'

# Target width and height passed to Image.resize, and the channel count
# (3, matching the convert('RGB') done when images are loaded).
w = h = 100
c = 3


def read_one_image(path):
    """Load one image file and return it as an RGB NumPy array.

    The image is converted to RGB (dropping any alpha channel or palette) and
    resized to the module-level ``w`` x ``h`` with Lanczos resampling.

    Args:
        path: Path of the image file to read.

    Returns:
        numpy.ndarray of shape ``(h, w, 3)`` and dtype ``uint8``.
    """
    # A context manager releases the file handle as soon as the pixels are
    # loaded, instead of whenever the Image object happens to be collected.
    with Image.open(path) as image:
        rgb = image.convert('RGB')
    # Image.ANTIALIAS was removed in Pillow 10. Image.LANCZOS is the same
    # filter and exists in both old and new Pillow releases.
    resized = rgb.resize((w, h), Image.LANCZOS)
    return np.asarray(resized)


def preprocess():
    """Rename every regular file in ``path`` to ``0.jpg``, ``1.jpg``, ``2.jpg``, ...

    Files are numbered in sorted-name order, so the mapping is reproducible
    rather than depending on the arbitrary order of ``os.listdir``.
    Sub-directories are left alone. Only names change and no file is
    re-encoded, so a PNG renamed to ``N.jpg`` still contains PNG data.

    Renaming goes through a private staging directory created inside ``path``.
    Every file is first moved into it and only then moved back under its
    numbered name. This way a target such as ``1.jpg`` can never overwrite a
    source file that has not been renamed yet, for example when the script is
    re-run on its own output. If a rename fails part-way, files already moved
    remain in that staging directory.

    Returns:
        list[str]: Paths of the renamed files, in index order.
    """
    names = sorted(
        name for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
    )
    staging = tempfile.mkdtemp(dir=path)
    # Phase 1: move every file out of the way. Names stay unique because they
    # all came from the same directory.
    for name in names:
        os.rename(os.path.join(path, name), os.path.join(staging, name))
    # Phase 2: move them back under their final numbered names.
    renamed = []
    for index, name in enumerate(names):
        target = os.path.join(path, str(index) + '.jpg')
        os.rename(os.path.join(staging, name), target)
        renamed.append(target)
    os.rmdir(staging)
    return renamed


with tf.Session() as sess:
    # Regular files currently in the image directory. Only their count is
    # used: preprocess() renames them in place to 0.jpg, 1.jpg, ...
    # Directories are skipped because they are not images.
    PATH = [path + x for x in os.listdir(path)
            if os.path.isfile(os.path.join(path, x))]
    data = []
    pic = []
    preprocess()
    for i in range(len(PATH)):
        picture = path + str(i) + '.jpg'
        print(picture)
        pic.append(picture)
        data.append(read_one_image(picture))

    # Rebuild the saved graph from its meta file, then load the newest weights.
    saver = tf.train.import_meta_graph('D:/tensorflow/saver/model.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('D:/tensorflow/saver/'))

    graph = tf.get_default_graph()
    # Input and output tensors, looked up by the names they were given when
    # the model was built (that code is not part of this file).
    x = graph.get_tensor_by_name("x:0")
    logits = graph.get_tensor_by_name("logits_eval:0")

    classification_result = sess.run(logits, feed_dict={x: data})

    # Print the prediction matrix.
    print(classification_result)
    # Print the index of the largest value in each row (the predicted class).
    # sess.run already returned a NumPy array, so np.argmax suffices. The
    # previous tf.argmax(...).eval() added a new op to the graph per call and
    # was evaluated twice.
    output = np.argmax(classification_result, axis=1)
    print(output)
    for i, label_index in enumerate(output):
        print("第", i, "张图片预测:" + flower_dict[label_index])