You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

108 lines
4.0 KiB

import os
import random
from PIL import Image
def image2matrix(image_filename, size=(20, 20)):
    """Load an image, resize it, and return its pixels as a row-major matrix.

    Args:
        image_filename: Path of the image file to read.
        size: ``(width, height)`` to resize to before sampling pixels.
            Defaults to ``(20, 20)``, the size this function always used.

    Returns:
        A list of ``height`` rows, each a list of ``width`` pixel values as
        returned by ``Image.getpixel`` (an int or a tuple, depending on the
        image mode).
    """
    # Open the image as a context manager so the underlying file handle is
    # closed promptly; resize() returns a new in-memory image, so it stays
    # usable after the source file is closed.
    with Image.open(image_filename) as image:
        resized = image.resize(size)
    width, height = resized.size
    return [[resized.getpixel((x, y)) for x in range(width)] for y in range(height)]
class DataProcessor:
    """Load labelled images from disk and serve them in mini-batches.

    Images are read from ``<data_root>/train/<aspect>/`` and
    ``<data_root>/val/<aspect>/``. Every directory without subdirectories is
    treated as one class whose integer label is the directory's name. Each
    image is converted with ``image2matrix`` and kept in memory in
    ``vectors``; its label goes in ``labels`` at the same position (the
    sample id). ``train_set``, ``validate_set`` and ``test_set`` hold sample
    ids. Roughly ``separate_ratio`` of the images from *both* directories are
    diverted into ``test_set``.

    Labels are returned one-hot encoded as lists of length ``num_classes``.
    """

    # Default width of the one-hot label vectors.
    num_classes = 34

    def __init__(self, aspect='area', separate_ratio=0.1, data_root='./dataset', num_classes=34):
        """Create the processor and immediately load both data directories.

        Args:
            aspect: Sub-directory name selecting which dataset to load.
            separate_ratio: Probability that any loaded image is moved into
                the test split instead of its own split.
            data_root: Directory containing the ``train`` and ``val`` trees.
            num_classes: Length of the one-hot label vectors.
        """
        self.train_dir = '{}/train/{}/'.format(data_root, aspect)
        self.validate_dir = '{}/val/{}/'.format(data_root, aspect)
        self.separate_ratio = separate_ratio
        self.num_classes = num_classes
        self.vectors = []  # image matrices, indexed by sample id
        self.labels = []  # integer label of each sample id
        self.train_set = []
        self.train_batch_index = 0  # cursor into train_set
        self.train_epoch = 0
        self.validate_set = []
        self.validate_batch_index = 0
        self.test_set = []
        self.test_batch_index = 0  # cursor into test_set
        self.test_epoch = 0  # passes over test_set, kept apart from train_epoch
        self.classes = 0  # most subdirectories seen in any single directory
        self.data_set_count = 0  # next sample id to assign
        self.load_train()
        self.load_valid()

    def load_train(self):
        """Load the training tree; a random share of it goes to the test split."""
        self._load_directory(self.train_dir, self.train_set)

    def load_valid(self):
        """Load the validation tree; a random share of it goes to the test split."""
        self._load_directory(self.validate_dir, self.validate_set)

    def _load_directory(self, directory, target_set):
        """Register every image found in the leaf directories under `directory`.

        Each new sample id is appended to ``test_set`` with probability
        ``separate_ratio`` and to `target_set` otherwise.
        """
        for root, dirs, files in os.walk(directory):
            # Sorting in place fixes the order os.walk descends in; the raw
            # order is filesystem-dependent, which made runs irreproducible.
            dirs.sort()
            self.classes = max(self.classes, len(dirs))
            if dirs:
                continue  # files in non-leaf directories are ignored
            # basename/normpath instead of split('/') so native separators work.
            label = int(os.path.basename(os.path.normpath(root)))
            for name in sorted(files):
                self.vectors.append(image2matrix(os.path.join(root, name)))
                self.labels.append(label)
                if random.random() < self.separate_ratio:
                    self.test_set.append(self.data_set_count)
                else:
                    target_set.append(self.data_set_count)
                self.data_set_count += 1

    def _one_hot(self, label):
        """Return a list of ``num_classes`` zeros with position `label` set to 1."""
        encoded = [0] * self.num_classes
        encoded[label] = 1
        return encoded

    def _gather(self, sample_ids):
        """Collect image matrices and one-hot labels for the given sample ids."""
        input_x = []
        input_y = []
        for sample in sample_ids:
            input_x.append(self.vectors[sample])
            input_y.append(self._one_hot(self.labels[sample]))
        return input_x, input_y

    @staticmethod
    def _require_samples(subset, name):
        """Raise ValueError when `subset` has no sample ids to draw from."""
        if not subset:
            raise ValueError('{} is empty; cannot draw a batch from it'.format(name))

    def _sequential_batch(self, subset, start, batch):
        """Take `batch` consecutive sample ids of `subset` from `start`, wrapping.

        Returns:
            ``(input_x, input_y, epochs, next_start)`` where `epochs` is how
            many times the batch crossed the end of `subset`.
        """
        size = len(subset)
        input_x, input_y = self._gather(subset[(start + offset) % size] for offset in range(batch))
        # divmod counts every wrap, including one that lands exactly on the
        # end of the set (the old `>` comparison missed that case).
        epochs, next_start = divmod(start + batch, size)
        return input_x, input_y, epochs, next_start

    def next_train_batch(self, batch=100):
        """Return the next `batch` training samples in order, cycling forever.

        Returns:
            ``(input_x, input_y, train_epoch)``; ``train_epoch`` counts the
            passes over ``train_set`` completed so far.

        Raises:
            ValueError: if ``train_set`` is empty.
        """
        self._require_samples(self.train_set, 'train_set')
        input_x, input_y, epochs, self.train_batch_index = self._sequential_batch(
            self.train_set, self.train_batch_index, batch)
        self.train_epoch += epochs
        return input_x, input_y, self.train_epoch

    def next_valid_batch(self, batch=100):
        """Return `batch` validation samples drawn uniformly with replacement.

        Returns:
            ``(input_x, input_y, train_epoch)``; the epoch is not changed.

        Raises:
            ValueError: if ``validate_set`` is empty.
        """
        self._require_samples(self.validate_set, 'validate_set')
        # Pick sample ids out of validate_set. A raw position into
        # validate_set is not a sample id, because vectors also holds the
        # training and test images.
        input_x, input_y = self._gather(random.choice(self.validate_set) for _ in range(batch))
        self.validate_batch_index = (self.validate_batch_index + batch) % len(self.validate_set)
        return input_x, input_y, self.train_epoch

    def next_test_batch(self, batch=100):
        """Return the next `batch` test samples in order, cycling forever.

        Passes over ``test_set`` are counted in ``test_epoch``. They no longer
        advance ``train_epoch``, so evaluating does not disturb the training
        epoch count.

        Returns:
            ``(input_x, input_y, train_epoch)``, like ``next_valid_batch``.

        Raises:
            ValueError: if ``test_set`` is empty.
        """
        self._require_samples(self.test_set, 'test_set')
        input_x, input_y, epochs, self.test_batch_index = self._sequential_batch(
            self.test_set, self.test_batch_index, batch)
        self.test_epoch += epochs
        return input_x, input_y, self.train_epoch