diff --git a/generate_data.py b/generate_data.py new file mode 100644 index 0000000..97797d0 --- /dev/null +++ b/generate_data.py @@ -0,0 +1,44 @@ +import os +import paddle as paddle +from multiprocessing import cpu_count +import numpy as np +from PIL import Image + +data_path='./data_set' +train_data='./train_data.list' +test_data='./test_data.list' +characters_folders=os.listdir(data_path) +if(os.path.exists('./train_data.list')): + os.remove('./train_data.list') +if(os.path.exists('./test_data.list')): + os.remove('./test_data.list') + +for characters_folder in characters_folders: + with open(train_data,'a') as f_train: + with open(test_data,'a') as f_test: + character_imgs = os.listdir(os.path.join(data_path,characters_folder)) + count = 0 + for img in character_imgs: + filePath = data_path+"/"+characters_folder+"/"+img + if count%10==0: + f_test.write(filePath+"\t"+characters_folder+"\n") + else: + f_train.write(filePath+"\t"+characters_folder+"\n") + count+=1 +def data_mapper(sample): + img, label = sample + img = Image.open(img) + img = img.resize((100, 100), Image.ANTIALIAS) + img = np.array(img).astype('float32') + img = img.transpose((2, 0, 1)) + img = img/255.0 + return img, label + +def data_reader(data_list_path): + def reader(): + with open(data_list_path, 'r') as f: + lines = f.readlines() + for line in lines: + img, label = line.split('\t') + yield img, int(label) + return paddle.reader.xmap_readers(data_mapper, reader, cpu_count(), 512) \ No newline at end of file