You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

93 lines
3.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import random
import json
import numpy as np
from PIL import Image
from 口罩检测.util import train_parameters
def get_data_list(target_path,train_list_path,eval_list_path):
    """Scan a class-per-folder image dataset and write train/eval list files.

    Every sub-directory of ``target_path`` is treated as one class. Classes
    are labelled 0, 1, 2, ... in sorted directory-name order. Within a class,
    the images at positions 0, 10, 20, ... of the sorted file names go to the
    evaluation list and all others go to the training list. Each list entry
    has the form ``<image path><TAB><label><NEWLINE>``.

    Side effects:
        * The shuffled entries are *appended* (mode ``'a'``) to
          ``eval_list_path`` and ``train_list_path``; truncate those files
          beforehand if a fresh list is wanted.
        * ``train_parameters['label_dict']`` receives a
          ``str(label) -> class directory name`` entry per class, and
          ``train_parameters['class_dim']`` is set to the number of classes.
        * A JSON summary of the dataset is written to
          ``train_parameters['readme_path']``.

    Args:
        target_path: Directory whose sub-directories are the classes.
        train_list_path: File that receives the training entries.
        eval_list_path: File that receives the evaluation entries.
    """
    class_detail = []      # per-class statistics for the JSON summary
    train_list = []        # training entries
    eval_list = []         # evaluation entries
    all_class_images = 0   # total number of images in the dataset
    class_label = 0        # label for the next class; ends up as the class count

    # os.listdir returns names in arbitrary order; sorting keeps the
    # label <-> class mapping stable across runs and machines.
    for class_dir in sorted(os.listdir(target_path)):
        path = target_path + "/" + class_dir
        # Stray files beside the class folders (e.g. .DS_Store) are not
        # classes; listing them would raise NotADirectoryError.
        if not os.path.isdir(path):
            continue

        eval_sum = 0
        trainer_sum = 0
        # Sorted for the same reason as above: a reproducible split.
        for class_sum, img_name in enumerate(sorted(os.listdir(path))):
            entry = "%s/%s\t%d\n" % (path, img_name, class_label)
            if class_sum % 10 == 0:
                # every tenth image is held out for evaluation
                eval_sum += 1
                eval_list.append(entry)
            else:
                trainer_sum += 1
                train_list.append(entry)
        all_class_images += eval_sum + trainer_sum

        class_detail.append({
            'class_name': class_dir,
            'class_label': class_label,
            'class_eval_images': eval_sum,
            'class_train_images': trainer_sum,
        })
        train_parameters['label_dict'][str(class_label)] = class_dir
        class_label += 1

    train_parameters['class_dim'] = class_label

    random.shuffle(eval_list)
    with open(eval_list_path, 'a') as eval_file:
        eval_file.writelines(eval_list)

    random.shuffle(train_list)
    with open(train_list_path, 'a') as train_file:
        train_file.writelines(train_list)

    # JSON summary describing the dataset
    readJson = {
        'all_class_name': target_path,
        'all_class_images': all_class_images,
        'class_detail': class_detail,
    }
    jsons = json.dumps(readJson, sort_keys=True, indent=4, separators=(',', ': '))
    with open(train_parameters['readme_path'], 'w') as f:
        f.write(jsons)
    print ('生成数据列表完成!')
def custom_reader(file_list):
    """Create a sample reader over a tab-separated image list file.

    Every non-blank line of ``file_list`` must look like
    ``<image path><TAB><integer label>``.

    Args:
        file_list: Path of the list file to read.

    Returns:
        A zero-argument callable. Calling it returns a generator that yields
        ``(image, label)`` pairs, where ``image`` is a float32 array of shape
        (3, 244, 244) scaled by 1/255 and ``label`` is an ``int``.

    Raises (while iterating the generator):
        ValueError: If a non-blank line has no tab or a non-integer label.
        OSError: If the list file or an image cannot be opened.
    """
    def reader():
        # Load all entries first so the list file is closed before any image
        # is decoded.
        with open(file_list, 'r') as f:
            entries = [line.strip() for line in f]

        for entry in entries:
            if not entry:
                # blank or whitespace-only lines carry no sample
                continue
            # Split on the last tab only: the label is always the final
            # field, so a tab inside the path does not break parsing.
            img_path, label = entry.rsplit('\t', 1)

            # The context manager releases the image's file handle right
            # after the pixels have been resampled into a new image.
            with Image.open(img_path) as source:
                rgb = source if source.mode == 'RGB' else source.convert('RGB')
                # NOTE(review): 244 may be a typo for 224 - confirm against
                # the input size the model expects.
                resized = rgb.resize((244, 244), Image.BILINEAR)

            pixels = np.array(resized).astype('float32')
            # HWC -> CHW, then scale the 0..255 values by 1/255
            yield pixels.transpose((2, 0, 1)) / 255.0, int(label)
    return reader
# Dataset locations and batch size come from the shared configuration dict.
target_path = train_parameters['target_path']
train_list_path = train_parameters['train_list_path']
eval_list_path = train_parameters['eval_list_path']
batch_size = train_parameters['train_batch_size']

# get_data_list appends to both list files, so empty train.txt and eval.txt
# first; opening a file in 'w' mode is enough to truncate it.
with open(train_list_path, 'w'), open(eval_list_path, 'w'):
    pass

# Generate the data lists.
get_data_list(target_path, train_list_path, eval_list_path)