You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

148 lines
5.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# 输入图片处理
import os
import math
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# -----------------生成图片路径和标签的List------------------------------------
# 生成图片的存储位置 /Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/
train_dir = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/'
one = []
label_one = []
two = []
label_two = []
three = []
label_three = []
four = []
label_four = []
five = []
label_five = []
six = []
label_six = []
seven = []
label_seven = []
eight = []
label_eight = []
nine = []
label_nine = []
# step1获取'/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train'下所有的图片路径名,存放到
# 对应的列表中同时贴上标签存放到label列表中。
# ratio是测试集的比例
def get_files(file_dir, ratio):
for file in os.listdir(file_dir + '/one'):
one.append(file_dir + '/one' + '/' + file)
label_one.append(0)
for file in os.listdir(file_dir + '/two'):
two.append(file_dir + '/two' + '/' + file)
label_two.append(1)
for file in os.listdir(file_dir + '/three'):
three.append(file_dir + '/three' + '/' + file)
label_three.append(2)
for file in os.listdir(file_dir + '/four'):
four.append(file_dir + '/four' + '/' + file)
label_four.append(3)
for file in os.listdir(file_dir + '/five'):
five.append(file_dir + '/five' + '/' + file)
label_five.append(5)
for file in os.listdir(file_dir + '/six'):
six.append(file_dir + '/six' + '/' + file)
label_six.append(6)
for file in os.listdir(file_dir + '/seven'):
seven.append(file_dir + '/seven' + '/' + file)
label_seven.append(7)
for file in os.listdir(file_dir + '/eight'):
eight.append(file_dir + '/eight' + '/' + file)
label_eight.append(8)
for file in os.listdir(file_dir + '/nine'):
nine.append(file_dir + '/nine' + '/' + file)
label_nine.append(9)
# step2对生成的图片路径和标签List做打乱处理把所有的数据合起来组成一个listimg和lab
# np.hstack水平按列按顺序堆叠数组。
# >>> a = np.array((1,2,3))
# >>> b = np.array((2,3,4))
# >>> np.hstack((a,b))
# array([1, 2, 3, 2, 3, 4])
image_list = np.hstack((one, two, three, four, five, six, seven, eight, nine))
label_list = np.hstack((label_one, label_two, label_three, label_four, label_five, label_six, label_seven, label_eight, label_nine))
# 利用shuffle打乱顺序
temp = np.array([image_list, label_list])
temp = temp.transpose()
np.random.shuffle(temp)
# 从打乱的temp中再取出listimg和lab
# image_list = list(temp[:, 0])
# label_list = list(temp[:, 1])
# label_list = [int(i) for i in label_list]
# return image_list, label_list
# 将所有的img和lab转换成list
all_image_list = list(temp[:, 0])
all_label_list = list(temp[:, 1])
# 将所得List分为两部分一部分用来训练tra一部分用来测试val
# ratio是测试集的比例
# n_sample全部样本数
n_sample = len(all_label_list)
n_val = int(math.ceil(n_sample * ratio)) # 测试样本数
n_train = n_sample - n_val # 训练样本数
# 训练的图片和标签
tra_images = all_image_list[0:n_train]
tra_labels = all_label_list[0:n_train]
tra_labels = [int(float(i)) for i in tra_labels]
# 测试图片和标签
val_images = all_image_list[n_train:-1]
val_labels = all_label_list[n_train:-1]
val_labels = [int(float(i)) for i in val_labels]
return tra_images, tra_labels, val_images, val_labels
# --------------------生成Batch----------------------------------------------
# step1将上面生成的List传入get_batch() 转换类型产生一个输入队列queue因为img和lab
# 是分开的所以使用tf.train.slice_input_producer()然后用tf.read_file()从队列中读取图像
# image_W, image_H, :设置好固定的图像高度和宽度
# 设置batch_size每个batch要放多少张图片
# capacity一个队列最大多少
def get_batch(image, label, image_W, image_H, batch_size, capacity):
# 转换类型
image = tf.cast(image, tf.string)
label = tf.cast(label, tf.int32)
# make an input queue
# tf.train.slice_input_producer是一个tensor生成器作用是按照设定
# 每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。
input_queue = tf.compat.v1.train.slice_input_producer([image, label])
label = input_queue[1]
image_contents = tf.io.read_file(input_queue[0]) # read img from a queue
# step2将图像解码不同类型的图像不能混在一起要么只用jpeg要么只用png等。
image = tf.image.decode_jpeg(image_contents, channels=3)
# step3数据预处理对图像进行旋转、缩放、裁剪、归一化等操作让计算出的模型更健壮。
image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
image = tf.image.per_image_standardization(image)
# step4生成batch
# image_batch: 4D tensor [batch_size, width, height, 3],dtype=tf.float32
# label_batch: 1D tensor [batch_size], dtype=tf.int32
image_batch, label_batch = tf.compat.v1.train.batch([image, label],
batch_size=batch_size,
num_threads=32,
capacity=capacity)
# 重新排列label行数为[batch_size]
label_batch = tf.reshape(label_batch, [batch_size])
image_batch = tf.cast(image_batch, tf.float32)
return image_batch, label_batch