You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

87 lines
2.7 KiB

4 years ago
import os.path as osp
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import utils.transforms as tf
from .registry import DATASETS
@DATASETS.register_module
class BaseDataset(Dataset):
    """Base dataset that loads images (and, for training, label maps) from disk.

    Subclasses must implement :meth:`init`, which fills the index-aligned
    lists read by :meth:`__getitem__` (``img_name_list``,
    ``full_img_path_list`` and, for training splits, ``label_list`` and
    ``exist_list``), and :meth:`transform_train`.

    Args:
        img_path: Root directory of the dataset.
        data_list: Name of the split list. The split is treated as a
            training split whenever this string contains ``'train'``.
        list_path: Sub-directory of ``img_path`` holding the split lists.
        cfg: Config object. This class reads ``img_width``, ``img_height``,
            ``img_norm['mean']``, ``img_norm['std']`` and ``cut_height``.
    """

    def __init__(self, img_path, data_list, list_path='list', cfg=None):
        self.cfg = cfg
        self.img_path = img_path
        self.list_path = osp.join(img_path, list_path)
        self.data_list = data_list
        # Training vs. evaluation is inferred purely from the list name.
        self.is_training = ('train' in data_list)
        # Populated by the subclass's init(); entries are aligned by index.
        self.img_name_list = []
        self.full_img_path_list = []
        self.label_list = []
        self.exist_list = []
        # Built before init(), so transforms may only depend on cfg.
        self.transform = self.transform_train() if self.is_training else self.transform_val()
        self.init()

    def transform_train(self):
        """Return the training-time transform. Must be overridden."""
        raise NotImplementedError()

    def transform_val(self):
        """Return the evaluation transform: resize, then normalize.

        Images are resized to ``(cfg.img_width, cfg.img_height)`` and
        normalized with ``cfg.img_norm``. The second mean/std pair
        (``(0,)`` / ``(1,)``) presumably leaves a second group element
        (e.g. a label map) unchanged -- confirm in ``utils.transforms``.
        """
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return val_transform

    def view(self, img, coords, file_path=None):
        """Draw points onto ``img`` (in place) and optionally save it.

        Args:
            img: Image array; ``cv2.circle`` draws on it in place.
            coords: Iterable of point sequences, each point an ``(x, y)``
                pair. Points with a non-positive coordinate are skipped.
            file_path: If given, the image is written to this path, creating
                any missing parent directories first.
        """
        for coord in coords:
            for x, y in coord:
                # Presumably non-positive values mark absent points.
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                cv2.circle(img, (x, y), 4, (255, 0, 0), 2)
        if file_path is not None:
            out_dir = osp.dirname(file_path)
            # A bare filename has an empty dirname, and os.makedirs('') raises.
            # exist_ok=True also avoids a race between a check and the mkdir.
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            cv2.imwrite(file_path, img)

    def init(self):
        """Fill the per-sample lists for this split. Must be overridden."""
        raise NotImplementedError()

    def __len__(self):
        return len(self.full_img_path_list)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with ``'img'`` (float tensor, channel axis first) and
            ``'meta'`` (``'full_img_path'``, ``'img_name'``). Training
            samples additionally carry ``'label'`` (long tensor) and
            ``'exist'`` (the matching ``exist_list`` entry).

        Raises:
            FileNotFoundError: if the image or label file cannot be read.
        """
        img_file = self.full_img_path_list[idx]
        img = cv2.imread(img_file)
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError('Could not read image: {}'.format(img_file))
        img = img.astype(np.float32)
        # Drop the top cut_height rows; the label map is cropped the same way.
        img = img[self.cfg.cut_height:, :, :]

        if self.is_training:
            label_file = self.label_list[idx]
            label = cv2.imread(label_file, cv2.IMREAD_UNCHANGED)
            if label is None:
                raise FileNotFoundError('Could not read label: {}'.format(label_file))
            # Multi-channel label images: keep only the first channel.
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cfg.cut_height:, :]
            exist = self.exist_list[idx]
            if self.transform:
                img, label = self.transform((img, label))
            label = torch.from_numpy(label).contiguous().long()
        elif self.transform:
            # Same guard as the training branch: skip when no transform is set.
            img, = self.transform((img,))

        # Move the channel axis (last, as loaded by cv2) to the front.
        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
        meta = {'full_img_path': self.full_img_path_list[idx],
                'img_name': self.img_name_list[idx]}
        data = {'img': img, 'meta': meta}
        if self.is_training:
            data.update({'label': label, 'exist': exist})
        return data