You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

92 lines
3.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#图片预处理
# 将原始图片转换成需要的大小,并将其保存
import os
import tensorflow as tf
from PIL import Image
# Directory holding the original images, one sub-directory per class name
# (e.g. <orig_picture>/one/, <orig_picture>/two/, ...).
orig_picture = '/Users/leixinhong/PycharmProjects/classification/teethimg/train-data'
# Directory where decoded sample images are written back out as JPEGs.
gen_picture = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/'
# Class names to process. A name's position in this list is the integer label
# stored in the TFRecord file, so each name must appear exactly once: a
# duplicate would store that class's images twice under two different labels.
classes = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
# Number of records to read back (and save as JPEGs) from the TFRecord file.
num_samples = 90
# Build the TFRecords data set.
def create_record(filename="dishes_train.tfrecords"):
    """Resize every class image to 64x64 RGB and write it to a TFRecord file.

    Images are read from ``<orig_picture>/<class name>/``. Each record stores
    the raw RGB bytes of the resized image under ``'img_raw'`` and the class's
    index in ``classes`` under ``'label'``.

    Args:
        filename: Path of the TFRecord file to (over)write. Defaults to the
            file name the rest of this script reads back.
    """
    with tf.compat.v1.python_io.TFRecordWriter(filename) as writer:
        for index, name in enumerate(classes):
            class_path = os.path.join(orig_picture, name)
            # Sort for a deterministic record order, and skip hidden entries
            # such as macOS '.DS_Store', which PIL cannot open as an image.
            for img_name in sorted(os.listdir(class_path)):
                if img_name.startswith('.'):
                    continue
                img_path = os.path.join(class_path, img_name)
                # Close the source file as soon as its pixels have been read.
                with Image.open(img_path) as img:
                    # Force 3-channel RGB so every record is exactly 64*64*3
                    # bytes, matching the [64, 64, 3] reshape in read_and_decode.
                    img_raw = img.convert('RGB').resize((64, 64)).tobytes()
                print(index, img_path)
                example = tf.train.Example(
                    features=tf.train.Features(feature={
                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                    }))
                writer.write(example.SerializeToString())
def read_and_decode(filename):
    """Build graph ops that yield one decoded (image, label) pair per run.

    Uses the TF1 queue-based input pipeline: ``filename`` is fed through a
    string input producer, and every ``sess.run`` on the returned tensors
    pulls the next serialized example out of the file.

    Args:
        filename: Path to a TFRecord file with ``'label'`` (int64) and
            ``'img_raw'`` (bytes) features.

    Returns:
        A tuple ``(img, label)``. ``img`` is a uint8 tensor reshaped to
        ``[64, 64, 3]`` and ``label`` is an int32 scalar tensor.
    """
    # Record schema; the keys must match those used when writing the file.
    feature_spec = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw': tf.io.FixedLenFeature([], tf.string),
    }

    # File-name queue; no num_epochs is given, so reading never runs out.
    file_queue = tf.compat.v1.train.string_input_producer([filename])
    record_reader = tf.compat.v1.TFRecordReader()
    # Each read dequeues one serialized example (the key is unused).
    _, serialized = record_reader.read(file_queue)

    parsed = tf.io.parse_single_example(serialized, features=feature_spec)

    # Raw bytes -> uint8 pixels; no float normalization is applied.
    pixels = tf.reshape(tf.io.decode_raw(parsed['img_raw'], tf.uint8), [64, 64, 3])
    label = tf.cast(parsed['label'], tf.int32)
    return pixels, label
def main():
    """Write the TFRecord file, then read ``num_samples`` records back as JPEGs.

    Each decoded image is saved into ``gen_picture`` as
    ``<i>samples<label>.jpg``.
    """
    # The queue-based reader only works in graph mode. TF2 enables eager
    # execution by default, and queue input pipelines raise an error under it.
    if tf.executing_eagerly():
        tf.compat.v1.disable_eager_execution()

    create_record()
    sample = read_and_decode('dishes_train.tfrecords')
    init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                       tf.compat.v1.local_variables_initializer())
    # Saving fails if the output directory does not exist yet.
    os.makedirs(gen_picture, exist_ok=True)

    with tf.compat.v1.Session() as sess:  # the with-block closes the session
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(num_samples):
                # Pull the next (image, label) pair out of the queue pipeline.
                example, lab = sess.run(sample)
                img = Image.fromarray(example, 'RGB')
                img.save(os.path.join(gen_picture, str(i) + 'samples' + str(lab) + '.jpg'))
                print(i, lab)
        finally:
            # Stop and join the queue-runner threads even if a step failed.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':  # program entry point
    main()