diff --git a/generate_data.py b/generate_data.py new file mode 100644 index 0000000..5dac6a1 --- /dev/null +++ b/generate_data.py @@ -0,0 +1,93 @@ +import os +import random +import json +import numpy as np +from PIL import Image +from 口罩检测.util import train_parameters +def get_data_list(target_path,train_list_path,eval_list_path): + #存放的类别信息 + class_detail=[] + class_dirs = os.listdir(target_path) + all_class_images=0 #数据集中总的图像数量 + class_label=0 #存放类别标签 + class_dim=0 #存放类别数目 + train_list=[] #训练集 + eval_list=[] #测试集 + for class_dir in class_dirs: + class_dim+=1 + class_detail_list={} + eval_sum=0 + trainer_sum=0 + class_sum=0 #每个类别有多少张图片 + path = target_path+"/"+class_dir + img_paths = os.listdir(path) + for img_path in img_paths: + name_path = path+"/"+img_path + if class_sum%10==0: + eval_sum+=1 + eval_list.append(name_path+"\t%d"%class_label+"\n") + + else: + trainer_sum+=1 + train_list.append(name_path+"\t%d"%class_label+"\n") + class_sum+=1 + all_class_images+=1 + class_detail_list['class_name']=class_dir + class_detail_list['class_label']=class_label + class_detail_list['class_eval_images']=eval_sum + class_detail_list['class_train_images']=trainer_sum + class_detail.append(class_detail_list) + train_parameters['label_dict'][str(class_label)]= class_dir + class_label+=1 + + train_parameters['class_dim']=class_dim + + random.shuffle(eval_list) + with open(eval_list_path,'a') as eval_file: + for eval_image in eval_list: + eval_file.write(eval_image) + + random.shuffle(train_list) + with open(train_list_path,'a') as train_file: + for train_item in train_list: + train_file.write(train_item) + + #说明json的文件信息 + readJson={} + readJson['all_class_name']=target_path + readJson['all_class_images']= all_class_images + readJson['class_detail']=class_detail + jsons = json.dumps(readJson, sort_keys=True, indent=4, separators=(',', ': ')) + with open(train_parameters['readme_path'],'w') as f: + f.write(jsons) + print ('生成数据列表完成!') + +def custom_reader(file_list): + def reader(): + with open(file_list,'r') as f: + lines=[line.strip() for line in f] + for line in lines: + img_path,label=line.strip().split('\t') + img =Image.open(img_path) + if img.mode!='RGB': + img = img.convert('RGB') + img = img.resize((244,244),Image.BILINEAR) + img = np.array(img).astype('float32') + img = img.transpose((2,0,1)) + img = img/255.0 + yield img,int(label) + return reader + +target_path=train_parameters['target_path'] +train_list_path=train_parameters['train_list_path'] +eval_list_path=train_parameters['eval_list_path'] +batch_size=train_parameters['train_batch_size'] +#每次生成数据列表前,首先清空train.txt和eval.txt +with open(train_list_path, 'w') as f: + f.seek(0) + f.truncate() +with open(eval_list_path, 'w') as f: + f.seek(0) + f.truncate() +#生成数据列表 +get_data_list(target_path,train_list_path,eval_list_path) \ No newline at end of file