import os
import random
import json
import numpy as np
from PIL import Image
from 口罩检测.util import train_parameters


def get_data_list(target_path, train_list_path, eval_list_path):
    """Generate the train/eval list files and a JSON summary of the dataset.

    Every sub-directory of ``target_path`` is treated as one class. Classes
    receive consecutive integer labels starting at 0. Within a class, the
    images at positions 0, 10, 20, ... go to the eval list and all others go
    to the train list. Each list entry has the form ``"<path>\\t<label>\\n"``.

    Side effects:
        * appends the shuffled entries to ``eval_list_path`` and
          ``train_list_path`` (files are opened in append mode, so the caller
          decides whether to clear them first);
        * sets ``train_parameters['label_dict'][str(label)]`` for every class
          and ``train_parameters['class_dim']`` to the number of classes;
        * overwrites ``train_parameters['readme_path']`` with a JSON summary.

    Args:
        target_path: Directory that contains one sub-directory per class.
        train_list_path: File that receives the training entries.
        eval_list_path: File that receives the evaluation entries.
    """
    class_detail = []       # per-class summary records for the JSON readme
    all_class_images = 0    # total number of images in the dataset
    class_label = 0         # label id assigned to the class being processed
    class_dim = 0           # number of classes found
    train_list = []         # training entries
    eval_list = []          # evaluation entries

    # os.listdir returns entries in arbitrary order; sorting keeps the
    # label-id <-> class-name mapping and the eval split reproducible.
    for class_dir in sorted(os.listdir(target_path)):
        path = target_path + "/" + class_dir
        # Stray files next to the class folders (e.g. .DS_Store) are not
        # classes; listing them would raise NotADirectoryError.
        if not os.path.isdir(path):
            continue

        class_dim += 1
        eval_sum = 0
        trainer_sum = 0
        class_sum = 0       # number of images seen so far in this class

        for img_path in sorted(os.listdir(path)):
            entry = "%s\t%d\n" % (path + "/" + img_path, class_label)
            # Every 10th image (starting with the first) is held out for eval.
            if class_sum % 10 == 0:
                eval_sum += 1
                eval_list.append(entry)
            else:
                trainer_sum += 1
                train_list.append(entry)
            class_sum += 1
            all_class_images += 1

        class_detail.append({
            'class_name': class_dir,
            'class_label': class_label,
            'class_eval_images': eval_sum,
            'class_train_images': trainer_sum,
        })
        train_parameters['label_dict'][str(class_label)] = class_dir
        class_label += 1

    train_parameters['class_dim'] = class_dim

    random.shuffle(eval_list)
    with open(eval_list_path, 'a') as eval_file:
        eval_file.writelines(eval_list)

    random.shuffle(train_list)
    with open(train_list_path, 'a') as train_file:
        train_file.writelines(train_list)

    # JSON description of the generated dataset
    readJson = {
        'all_class_name': target_path,
        'all_class_images': all_class_images,
        'class_detail': class_detail,
    }
    jsons = json.dumps(readJson, sort_keys=True, indent=4, separators=(',', ': '))
    with open(train_parameters['readme_path'], 'w') as f:
        f.write(jsons)
    print('生成数据列表完成!')


def custom_reader(file_list):
    """Return a generator function that yields ``(image, label)`` samples.

    Args:
        file_list: Path of a list file whose lines are ``"<path>\\t<label>"``.

    Returns:
        A zero-argument callable. Each call returns a generator that yields,
        for every non-blank line of ``file_list``, a tuple of a float32 array
        in channel-first layout ``(3, 244, 244)`` scaled by 1/255, and the
        label as ``int``.
    """
    def reader():
        # Read all entries up front so the list file is not held open while
        # the generator is suspended between samples.
        with open(file_list, 'r') as f:
            lines = [line.strip() for line in f]

        for line in lines:
            if not line:
                # Tolerate blank lines instead of failing on unpacking.
                continue
            # Split on the last tab only, so a path containing a tab survives.
            img_path, label = line.rsplit('\t', 1)

            # Close the underlying image file as soon as the pixels are read.
            with Image.open(img_path) as src:
                rgb = src if src.mode == 'RGB' else src.convert('RGB')
                # NOTE(review): 244x244 may be a typo for 224x224 - confirm
                # against the network's expected input size before changing.
                resized = rgb.resize((244, 244), Image.BILINEAR)
                img = np.array(resized).astype('float32')

            img = img.transpose((2, 0, 1))  # HWC -> CHW
            img = img / 255.0
            yield img, int(label)

    return reader


target_path = train_parameters['target_path']
train_list_path = train_parameters['train_list_path']
eval_list_path = train_parameters['eval_list_path']
# Not used in this file; kept as a module-level name.
batch_size = train_parameters['train_batch_size']

# Clear train.txt and eval.txt before each generation, because
# get_data_list appends. Opening in 'w' mode already truncates the file.
with open(train_list_path, 'w'):
    pass
with open(eval_list_path, 'w'):
    pass

# Generate the data lists
get_data_list(target_path, train_list_path, eval_list_path)