You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
4.0 KiB
108 lines
4.0 KiB
import os
|
|
import random
|
|
from PIL import Image
|
|
|
|
|
|
def image2matrix(image_filename):
    """Load an image and return it as a nested list of pixel values.

    The image is resized to 20x20 and then read pixel by pixel with
    ``getpixel``.  The result is row-major: ``matrix[y][x]`` holds the pixel
    at column ``x`` of row ``y``.

    :param image_filename: value passed straight to ``Image.open``.
    :return: list of rows, each a list of pixel values as returned by
        ``getpixel`` (their exact type depends on the image mode).
    """
    # Release the source file as soon as the resized copy exists, so that
    # loading a whole dataset does not accumulate open file handles.
    with Image.open(image_filename) as image:
        resized = image.resize((20, 20))

    width, height = resized.size
    return [[resized.getpixel((x, y)) for x in range(width)] for y in range(height)]
|
|
|
|
|
|
class DataProcessor:
    """Load labelled images from disk and serve them as one-hot batches.

    Images are read from ``./dataset/train/<aspect>/`` and
    ``./dataset/val/<aspect>/``.  Every leaf directory (a directory without
    sub-directories) must be named with the integer class label of the
    images it contains.  All samples share the ``vectors`` / ``labels``
    lists; ``train_set``, ``validate_set`` and ``test_set`` hold indices
    into those lists.  While loading, each sample is moved to ``test_set``
    with probability ``separate_ratio`` instead of its home set.
    """

    # Default width of the one-hot label vectors (the value this class
    # historically hard-coded); __init__ may override it per instance.
    num_classes = 34

    def __init__(self, aspect='area', separate_ratio=0.1, num_classes=34):
        """Set up empty state and immediately load both dataset directories.

        :param aspect: sub-directory name under ``./dataset/train/`` and
            ``./dataset/val/``.
        :param separate_ratio: probability that a loaded sample is held out
            into the test set.
        :param num_classes: length of the one-hot label vectors produced by
            the ``next_*_batch`` methods.
        """
        self.train_dir = './dataset/train/{}/'.format(aspect)
        self.validate_dir = './dataset/val/{}/'.format(aspect)
        self.separate_ratio = separate_ratio
        self.num_classes = num_classes

        # Sample storage shared by every split.
        self.vectors = []
        self.labels = []

        # Each split is a list of indices into vectors/labels, plus a cursor.
        self.train_set = []
        self.train_batch_index = 0
        self.train_epoch = 0
        self.validate_set = []
        self.validate_batch_index = 0
        self.test_set = []
        self.test_batch_index = 0

        # Largest number of sub-directories seen in any walked directory.
        self.classes = 0
        # Total number of samples loaded so far (next free index).
        self.data_set_count = 0

        self.load_train()
        self.load_valid()

    def load_train(self):
        """Load the training directory into ``train_set`` (and ``test_set``)."""
        self._load_dir(self.train_dir, self.train_set)

    def load_valid(self):
        """Load the validation directory into ``validate_set`` (and ``test_set``)."""
        self._load_dir(self.validate_dir, self.validate_set)

    def _load_dir(self, directory, target_set):
        """Walk ``directory`` and register every image found in a leaf folder.

        :param directory: root directory to walk.
        :param target_set: index list receiving the samples that are not
            randomly held out into ``self.test_set``.
        :raises ValueError: if a leaf directory name is not an integer.
        """
        for root, dirs, files in os.walk(directory):
            self.classes = max(self.classes, len(dirs))
            if dirs:
                # Only leaf directories carry images.
                continue
            # basename/normpath instead of split('/') so that platform
            # separators and a trailing slash are handled.
            label = int(os.path.basename(os.path.normpath(root)))
            for name in files:
                vector = image2matrix(os.path.join(root, name))
                self.vectors.append(vector)
                self.labels.append(label)
                if random.random() < self.separate_ratio:
                    self.test_set.append(self.data_set_count)
                else:
                    target_set.append(self.data_set_count)
                self.data_set_count += 1

    def _one_hot(self, label):
        """Return a list of ``num_classes`` zeros with a 1 at position ``label``."""
        encoded = [0] * self.num_classes
        encoded[label] = 1
        return encoded

    def _require_samples(self, index_set, name):
        """Raise a clear error instead of an opaque one when a split is empty."""
        if not index_set:
            raise ValueError('{} set is empty; cannot build a batch'.format(name))

    def _sequential_batch(self, index_set, start, batch):
        """Collect ``batch`` consecutive samples from ``index_set``, wrapping around.

        :return: ``(input_x, input_y)`` lists of vectors and one-hot labels.
        """
        size = len(index_set)
        input_x = []
        input_y = []
        for offset in range(batch):
            sample = index_set[(start + offset) % size]
            input_x.append(self.vectors[sample])
            input_y.append(self._one_hot(self.labels[sample]))
        return input_x, input_y

    def next_train_batch(self, batch=100):
        """Return the next ``batch`` training samples in order, wrapping around.

        :return: ``(input_x, input_y, train_epoch)``.
        :raises ValueError: if the training set is empty.
        """
        self._require_samples(self.train_set, 'train')
        input_x, input_y = self._sequential_batch(
            self.train_set, self.train_batch_index, batch)
        self.train_batch_index += batch
        # ``>=`` (not ``>``): landing exactly on the end of the set also
        # completes a pass; otherwise the cursor wraps to 0 below and the
        # epoch would never advance when the set size is a multiple of
        # the batch size.
        if self.train_batch_index >= len(self.train_set):
            self.train_epoch += 1
            self.train_batch_index %= len(self.train_set)
        return input_x, input_y, self.train_epoch

    def next_valid_batch(self, batch=100):
        """Return ``batch`` validation samples drawn at random with replacement.

        :return: ``(input_x, input_y, train_epoch)``.
        :raises ValueError: if the validation set is empty.
        """
        self._require_samples(self.validate_set, 'validate')
        input_x = []
        input_y = []
        for _ in range(batch):
            position = random.randint(0, len(self.validate_set) - 1)
            # Map the random position through validate_set: indexing
            # vectors/labels with the raw position would return whatever
            # samples were loaded first rather than validation samples.
            sample = self.validate_set[position]
            input_x.append(self.vectors[sample])
            input_y.append(self._one_hot(self.labels[sample]))
        self.validate_batch_index += batch
        self.validate_batch_index %= len(self.validate_set)
        return input_x, input_y, self.train_epoch

    def next_test_batch(self, batch=100):
        """Return the next ``batch`` held-out test samples in order, wrapping around.

        :return: ``(input_x, input_y, train_epoch)``.
        :raises ValueError: if the test set is empty.
        """
        self._require_samples(self.test_set, 'test')
        input_x, input_y = self._sequential_batch(
            self.test_set, self.test_batch_index, batch)
        self.test_batch_index += batch
        if self.test_batch_index >= len(self.test_set):
            # NOTE(review): a completed pass over the test set bumps
            # ``train_epoch``; kept as-is because callers may watch the
            # returned value to detect the end of a test pass -- confirm
            # whether a separate test counter was intended.
            self.train_epoch += 1
            self.test_batch_index %= len(self.test_set)
        return input_x, input_y, self.train_epoch
|