You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

93 lines
3.3 KiB

import os
import random
import json
import numpy as np
from PIL import Image
from 口罩检测.util import train_parameters
def _write_shuffled(lines, list_path):
    """Shuffle *lines* in place, then append them to the file at *list_path*."""
    random.shuffle(lines)
    with open(list_path, 'a') as list_file:
        list_file.writelines(lines)


def get_data_list(target_path, train_list_path, eval_list_path, params=None):
    """Split a one-folder-per-class image dataset into train/eval list files.

    Every sub-directory of ``target_path`` is one class. Its folder name is
    the class name, and its 0-based position in sorted order is the label.
    Within each class, images are taken in sorted order. Every 10th image
    (indices 0, 10, 20, ...) goes to the eval list and the rest go to the
    train list. Each list line is ``"<target_path>/<class>/<file>\\t<label>\\n"``.

    Args:
        target_path: Dataset root containing one sub-directory per class.
        train_list_path: File the shuffled train lines are appended to.
        eval_list_path: File the shuffled eval lines are appended to.
        params: Config dict to read and update. Defaults to the shared
            ``train_parameters``. It must already contain a ``'label_dict'``
            dict and a ``'readme_path'`` entry. One ``str(label) -> class
            name`` entry is added to ``'label_dict'`` per class, and
            ``'class_dim'`` is set to the number of classes.

    Side effects:
        - Appends to both list files; callers must empty them first if they
          want a fresh split.
        - Overwrites the JSON summary at ``params['readme_path']``.
        - Prints a completion message.
    """
    if params is None:
        params = train_parameters
    class_detail = []       # per-class summary written to the readme JSON
    all_class_images = 0    # total number of images across all classes
    train_list = []         # "<path>\t<label>\n" lines for training
    eval_list = []          # "<path>\t<label>\n" lines for evaluation

    # os.listdir() returns entries in arbitrary order. Sort so the label
    # assignment and the train/eval split are reproducible across machines.
    # Skip stray files in the root (readmes, .DS_Store, ...) that are not
    # class folders.
    class_dirs = sorted(
        name for name in os.listdir(target_path)
        if os.path.isdir(target_path + "/" + name)
    )
    for class_label, class_dir in enumerate(class_dirs):
        path = target_path + "/" + class_dir
        # Only regular files are samples; nested directories are ignored.
        img_paths = sorted(
            name for name in os.listdir(path)
            if os.path.isfile(path + "/" + name)
        )
        eval_sum = 0
        trainer_sum = 0
        for class_sum, img_path in enumerate(img_paths):
            entry = path + "/" + img_path + "\t%d" % class_label + "\n"
            # Hold out every 10th image of each class for evaluation.
            if class_sum % 10 == 0:
                eval_sum += 1
                eval_list.append(entry)
            else:
                trainer_sum += 1
                train_list.append(entry)
        all_class_images += len(img_paths)
        class_detail.append({
            'class_name': class_dir,
            'class_label': class_label,
            'class_eval_images': eval_sum,
            'class_train_images': trainer_sum,
        })
        params['label_dict'][str(class_label)] = class_dir
    params['class_dim'] = len(class_dirs)

    # Eval is shuffled before train, so a seeded `random` module yields
    # stable orderings for both files.
    _write_shuffled(eval_list, eval_list_path)
    _write_shuffled(train_list, train_list_path)

    # JSON summary of the generated lists. The 'all_class_name' key actually
    # stores the dataset root path; the key name is kept for compatibility.
    readJson = {
        'all_class_name': target_path,
        'all_class_images': all_class_images,
        'class_detail': class_detail,
    }
    jsons = json.dumps(readJson, sort_keys=True, indent=4, separators=(',', ': '))
    with open(params['readme_path'], 'w') as f:
        f.write(jsons)
    print('生成数据列表完成!')
def custom_reader(file_list, image_size=(244, 244)):
    """Build a sample generator over a ``"<image path>\\t<label>"`` list file.

    Args:
        file_list: Path of a list file such as the ones written by
            ``get_data_list``. Each line holds an image path and an integer
            label separated by a tab. Blank lines are ignored.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``.
            Defaults to the previously hard-coded (244, 244).
            NOTE(review): ImageNet-style backbones usually expect 224x224;
            confirm 244 against the model's input size.

    Returns:
        A zero-argument callable. Each call re-reads ``file_list`` and yields
        ``(img, label)``. ``img`` is a float32 array of shape
        ``(3, height, width)`` (channels first) scaled to [0, 1], and
        ``label`` is an int.

    Raises (when the generator is iterated):
        ValueError: a non-blank line has no tab or a non-integer label.
    """
    def reader():
        # Read the list up front so the list file is not held open while the
        # generator is suspended between samples.
        with open(file_list, 'r') as f:
            lines = [line.strip() for line in f]
        for line in lines:
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # failing on the tuple unpack below.
                continue
            # The label is always the last tab-separated field.
            img_path, label = line.rsplit('\t', 1)
            # Use a with-block so the image file handle is released as soon
            # as the pixels have been copied into a numpy array.
            with Image.open(img_path) as src:
                rgb = src if src.mode == 'RGB' else src.convert('RGB')
                resized = rgb.resize(image_size, Image.BILINEAR)
                pixels = np.array(resized).astype('float32')
            pixels = pixels.transpose((2, 0, 1))  # HWC -> CHW
            yield pixels / 255.0, int(label)
    return reader
# Paths and batch size from the shared training configuration. These stay
# module-level names because other modules may import them (e.g. batch_size).
target_path = train_parameters['target_path']
train_list_path = train_parameters['train_list_path']
eval_list_path = train_parameters['eval_list_path']
batch_size = train_parameters['train_batch_size']

# get_data_list() appends to the list files, so empty train.txt and eval.txt
# before generating them to avoid accumulating duplicate entries across runs.
# Opening in 'w' mode already truncates (and creates) the file, so no
# explicit seek/truncate is needed.
for list_path in (train_list_path, eval_list_path):
    with open(list_path, 'w'):
        pass

# Generate the data lists.
get_data_list(target_path, train_list_path, eval_list_path)