|
|
import os
|
|
|
import random
|
|
|
import json
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
from 口罩检测.util import train_parameters
|
|
|
def get_data_list(target_path, train_list_path, eval_list_path):
    """Scan *target_path* (one sub-directory per class) and write the
    train/eval list files plus a JSON readme describing the dataset.

    Every 10th image of each class (the 0th, 10th, 20th, ...) goes to the
    eval list; the rest go to the train list. As a side effect this also
    fills ``train_parameters['label_dict']`` (label -> class name) and
    ``train_parameters['class_dim']`` (number of classes).

    Args:
        target_path: directory containing one sub-directory per class.
        train_list_path: output text file; one "path\\tlabel" line per train image.
        eval_list_path: output text file; one "path\\tlabel" line per eval image.
    """
    class_detail = []            # per-class metadata for the readme JSON
    class_dirs = os.listdir(target_path)
    all_class_images = 0         # total number of images in the dataset
    class_label = 0              # numeric label assigned to the current class
    class_dim = 0                # number of classes
    train_list = []              # training samples
    eval_list = []               # evaluation samples

    for class_dir in class_dirs:
        class_dim += 1
        class_detail_list = {}
        eval_sum = 0             # images of this class sent to eval
        trainer_sum = 0          # images of this class sent to train
        class_sum = 0            # images seen so far in this class
        path = target_path + "/" + class_dir
        img_paths = os.listdir(path)
        for img_path in img_paths:
            name_path = path + "/" + img_path
            # Every 10th image goes to the eval split (~10% of each class).
            if class_sum % 10 == 0:
                eval_sum += 1
                eval_list.append(name_path + "\t%d" % class_label + "\n")
            else:
                trainer_sum += 1
                train_list.append(name_path + "\t%d" % class_label + "\n")
            class_sum += 1
            all_class_images += 1

        class_detail_list['class_name'] = class_dir
        class_detail_list['class_label'] = class_label
        class_detail_list['class_eval_images'] = eval_sum
        class_detail_list['class_train_images'] = trainer_sum
        class_detail.append(class_detail_list)
        train_parameters['label_dict'][str(class_label)] = class_dir
        class_label += 1

    train_parameters['class_dim'] = class_dim

    # Shuffle and write the list files. Mode 'w' (was 'a') makes repeated
    # calls idempotent: the old append mode relied on the caller truncating
    # the files beforehand, otherwise entries accumulated across runs.
    random.shuffle(eval_list)
    with open(eval_list_path, 'w') as eval_file:
        for eval_image in eval_list:
            eval_file.write(eval_image)

    random.shuffle(train_list)
    with open(train_list_path, 'w') as train_file:
        for train_item in train_list:
            train_file.write(train_item)

    # Readme JSON summarising the dataset.
    readJson = {}
    readJson['all_class_name'] = target_path
    readJson['all_class_images'] = all_class_images
    readJson['class_detail'] = class_detail
    jsons = json.dumps(readJson, sort_keys=True, indent=4, separators=(',', ': '))
    with open(train_parameters['readme_path'], 'w') as f:
        f.write(jsons)
    print('生成数据列表完成!')
|
|
|
|
|
|
def custom_reader(file_list):
    """Build a data-reader generator for the sample list in *file_list*.

    Each line of the list file has the form ``<image_path>\\t<label>``. The
    returned callable is a generator that yields ``(image, label)`` pairs,
    where ``image`` is a float32 CHW array scaled to [0, 1] and resized to
    244x244.

    NOTE(review): 244x244 looks like a typo for the conventional 224x224 —
    confirm against the model's expected input size before changing it.
    """
    def reader():
        with open(file_list, 'r') as f:
            entries = [row.strip() for row in f]
        for entry in entries:
            img_path, label = entry.strip().split('\t')
            image = Image.open(img_path)
            # Normalise palette/greyscale images to 3-channel RGB.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = image.resize((244, 244), Image.BILINEAR)
            arr = np.array(image).astype('float32')
            arr = arr.transpose((2, 0, 1))  # HWC -> CHW
            arr = arr / 255.0
            yield arr, int(label)
    return reader
|
|
|
|
|
|
# --- Script entry: regenerate the train/eval data lists -------------------
target_path = train_parameters['target_path']
train_list_path = train_parameters['train_list_path']
eval_list_path = train_parameters['eval_list_path']
batch_size = train_parameters['train_batch_size']

# Clear train.txt and eval.txt before regenerating them. Opening in 'w'
# mode already truncates the file, so the explicit f.seek(0)/f.truncate()
# calls in the original were redundant and have been removed.
with open(train_list_path, 'w'):
    pass
with open(eval_list_path, 'w'):
    pass

# Generate the data lists.
get_data_list(target_path, train_list_path, eval_list_path)