#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import glob
import os
import os.path as osp
import random
import json
import time
import hashlib

from multiprocessing.pool import Pool

import cv2
import numpy as np
import torch
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import Dataset
from tqdm import tqdm

from .data_augment import (
    augment_hsv,
    letterbox,
    mixup,
    random_affine,
    mosaic_augmentation,
)
from yolov6.utils.events import LOGGER

# Parameters
IMG_FORMATS = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", "webp", "mpo"]
# Get orientation exif tag
for k, v in ExifTags.TAGS.items():
    if v == "Orientation":
        ORIENTATION = k
        break


class TrainValDataset(Dataset):
    # YOLOv6 train_loader/val_loader, loads images and labels for training and validation
    def __init__(
        self,
        img_dir,
        img_size=640,
        batch_size=16,
        augment=False,
        hyp=None,
        rect=False,
        check_images=False,
        check_labels=False,
        stride=32,
        pad=0.0,
        rank=-1,
        data_dict=None,
        task="train",
    ):
        assert task.lower() in ("train", "val", "speed"), f"Not supported task: {task}"
        t1 = time.time()
        self.__dict__.update(locals())
        self.main_process = self.rank in (-1, 0)
        self.task = self.task.capitalize()
        self.class_names = data_dict["names"]
        self.img_paths, self.labels = self.get_imgs_labels(self.img_dir)
        if self.rect:
            shapes = [self.img_info[p]["shape"] for p in self.img_paths]
            self.shapes = np.array(shapes, dtype=np.float64)
            self.batch_indices = np.floor(
                np.arange(len(shapes)) / self.batch_size
            ).astype(int)  # batch indices of each image
            self.sort_files_shapes()
        t2 = time.time()
        if self.main_process:
            LOGGER.info("%.1fs for dataset initialization." % (t2 - t1))

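    # A minimal sketch of the augmentation hyperparameters this class reads when augment=True
    # (the values below are illustrative placeholders, not the project's defaults):
    #
    #   hyp = {
    #       "mosaic": 1.0, "mixup": 0.0,
    #       "degrees": 0.0, "translate": 0.1, "scale": 0.5, "shear": 0.0,
    #       "hsv_h": 0.015, "hsv_s": 0.7, "hsv_v": 0.4,
    #       "flipud": 0.0, "fliplr": 0.5,
    #   }
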
    def __len__(self):
        """Get the length of dataset"""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Fetch a data sample for a given index.
        This function applies mosaic and mixup augments during training.
        During validation, only the letterbox transform is applied.
        """
        # Mosaic Augmentation
        if self.augment and random.random() < self.hyp["mosaic"]:
            img, labels = self.get_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < self.hyp["mixup"]:
                img_other, labels_other = self.get_mosaic(
                    random.randint(0, len(self.img_paths) - 1)
                )
                img, labels = mixup(img, labels, img_other, labels_other)

        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = (
                self.batch_shapes[self.batch_indices[index]]
                if self.rect
                else self.img_size
            )  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:
                w *= ratio
                h *= ratio
                # new boxes
                boxes = np.copy(labels[:, 1:])
                boxes[:, 0] = (
                    w * (labels[:, 1] - labels[:, 3] / 2) + pad[0]
                )  # top left x
                boxes[:, 1] = (
                    h * (labels[:, 2] - labels[:, 4] / 2) + pad[1]
                )  # top left y
                boxes[:, 2] = (
                    w * (labels[:, 1] + labels[:, 3] / 2) + pad[0]
                )  # bottom right x
                boxes[:, 3] = (
                    h * (labels[:, 2] + labels[:, 4] / 2) + pad[1]
                )  # bottom right y
                labels[:, 1:] = boxes

            if self.augment:
                img, labels = random_affine(
                    img,
                    labels,
                    degrees=self.hyp["degrees"],
                    translate=self.hyp["translate"],
                    scale=self.hyp["scale"],
                    shear=self.hyp["shear"],
                    new_shape=(self.img_size, self.img_size),
                )

        if len(labels):
            h, w = img.shape[:2]

            labels[:, [1, 3]] = labels[:, [1, 3]].clip(0, w - 1e-3)  # x1, x2
            labels[:, [2, 4]] = labels[:, [2, 4]].clip(0, h - 1e-3)  # y1, y2

            boxes = np.copy(labels[:, 1:])
            boxes[:, 0] = ((labels[:, 1] + labels[:, 3]) / 2) / w  # x center
            boxes[:, 1] = ((labels[:, 2] + labels[:, 4]) / 2) / h  # y center
            boxes[:, 2] = (labels[:, 3] - labels[:, 1]) / w  # width
            boxes[:, 3] = (labels[:, 4] - labels[:, 2]) / h  # height
            labels[:, 1:] = boxes

        if self.augment:
            img, labels = self.general_augment(img, labels)

        labels_out = torch.zeros((len(labels), 6))
        if len(labels):
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_paths[index], shapes

    def load_image(self, index):
        """Load an image.
        This function loads the image with cv2 and resizes it to the target size (img_size)
        while keeping the aspect ratio.

        Returns:
            Image, original shape of image, resized image shape
        """
        path = self.img_paths[index]
        im = cv2.imread(path)
        assert im is not None, f"Image Not Found {path}, workdir: {os.getcwd()}"

        h0, w0 = im.shape[:2]  # origin shape
        r = self.img_size / max(h0, w0)
        if r != 1:
            im = cv2.resize(
                im,
                (int(w0 * r), int(h0 * r)),
                interpolation=cv2.INTER_AREA
                if r < 1 and not self.augment
                else cv2.INTER_LINEAR,
            )
        return im, (h0, w0), im.shape[:2]

    @staticmethod
    def collate_fn(batch):
        """Merges a list of samples to form a mini-batch of Tensor(s)"""
        img, label, path, shapes = zip(*batch)
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

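    # collate_fn is intended to be passed to torch.utils.data.DataLoader(collate_fn=...); see the
    # hedged usage sketch at the end of this file. Column 0 of every label row carries the image
    # index within the batch, so targets can be matched back to their image after concatenation.
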
    def get_imgs_labels(self, img_dir):

        assert osp.exists(img_dir), f"{img_dir} is an invalid directory path!"
        valid_img_record = osp.join(
            osp.dirname(img_dir), "." + osp.basename(img_dir) + ".json"
        )
        NUM_THREADS = min(8, os.cpu_count())

        img_paths = glob.glob(osp.join(img_dir, "*"), recursive=True)
        img_paths = sorted(
            p for p in img_paths if p.split(".")[-1].lower() in IMG_FORMATS
        )
        assert img_paths, f"No images found in {img_dir}."

        img_hash = self.get_hash(img_paths)
        if osp.exists(valid_img_record):
            with open(valid_img_record, "r") as f:
                cache_info = json.load(f)
                if "image_hash" in cache_info and cache_info["image_hash"] == img_hash:
                    img_info = cache_info["information"]
                else:
                    self.check_images = True
        else:
            self.check_images = True

        # check images
        if self.check_images and self.main_process:
            img_info = {}
            nc, msgs = 0, []  # number of corrupted images, messages
            LOGGER.info(
                f"{self.task}: Checking formats of images with {NUM_THREADS} process(es): "
            )
            with Pool(NUM_THREADS) as pool:
                pbar = tqdm(
                    pool.imap(TrainValDataset.check_image, img_paths),
                    total=len(img_paths),
                )
                for img_path, shape_per_img, nc_per_img, msg in pbar:
                    if nc_per_img == 0:  # not corrupted
                        img_info[img_path] = {"shape": shape_per_img}
                    nc += nc_per_img
                    if msg:
                        msgs.append(msg)
                    pbar.desc = f"{nc} image(s) corrupted"
            pbar.close()
            if msgs:
                LOGGER.info("\n".join(msgs))

            cache_info = {"information": img_info, "image_hash": img_hash}
            # save valid image paths.
            with open(valid_img_record, "w") as f:
                json.dump(cache_info, f)

        # check and load anns
        label_dir = osp.join(
            osp.dirname(osp.dirname(img_dir)), "labels", osp.basename(img_dir)
        )
        assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"

        img_paths = list(img_info.keys())
        label_paths = sorted(
            osp.join(label_dir, osp.splitext(osp.basename(p))[0] + ".txt")
            for p in img_paths
        )
        label_hash = self.get_hash(label_paths)
        if "label_hash" not in cache_info or cache_info["label_hash"] != label_hash:
            self.check_labels = True

        if self.check_labels:
            cache_info["label_hash"] = label_hash
            nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupted; messages
            LOGGER.info(
                f"{self.task}: Checking formats of labels with {NUM_THREADS} process(es): "
            )
            with Pool(NUM_THREADS) as pool:
                pbar = pool.imap(
                    TrainValDataset.check_label_files, zip(img_paths, label_paths)
                )
                pbar = tqdm(pbar, total=len(label_paths)) if self.main_process else pbar
                for (
                    img_path,
                    labels_per_file,
                    nc_per_file,
                    nm_per_file,
                    nf_per_file,
                    ne_per_file,
                    msg,
                ) in pbar:
                    if nc_per_file == 0:
                        img_info[img_path]["labels"] = labels_per_file
                    else:
                        img_info.pop(img_path)
                    nc += nc_per_file
                    nm += nm_per_file
                    nf += nf_per_file
                    ne += ne_per_file
                    if msg:
                        msgs.append(msg)
                    if self.main_process:
                        pbar.desc = f"{nf} label(s) found, {nm} label(s) missing, {ne} label(s) empty, {nc} invalid label files"
            if self.main_process:
                pbar.close()
                with open(valid_img_record, "w") as f:
                    json.dump(cache_info, f)
            if msgs:
                LOGGER.info("\n".join(msgs))
            if nf == 0:
                LOGGER.warning(
                    f"WARNING: No labels found in {osp.dirname(self.img_paths[0])}. "
                )

        if self.task.lower() == "val":
            if self.data_dict.get("is_coco", False):  # use the original json file when evaluating on the coco dataset.
                assert osp.exists(self.data_dict["anno_path"]), "Eval on coco dataset must provide valid path of the annotation file in config file: data/coco.yaml"
            else:
                assert (
                    self.class_names
                ), "Class names are required when converting labels to coco format for evaluation."
                save_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "annotations")
                if not osp.exists(save_dir):
                    os.mkdir(save_dir)
                save_path = osp.join(
                    save_dir, "instances_" + osp.basename(img_dir) + ".json"
                )
                TrainValDataset.generate_coco_format_labels(
                    img_info, self.class_names, save_path
                )

        img_paths, labels = list(
            zip(
                *[
                    (
                        img_path,
                        np.array(info["labels"], dtype=np.float32)
                        if info["labels"]
                        else np.zeros((0, 5), dtype=np.float32),
                    )
                    for img_path, info in img_info.items()
                ]
            )
        )
        self.img_info = img_info
        LOGGER.info(
            f"{self.task}: Final numbers of valid images: {len(img_paths)} / labels: {len(labels)}. "
        )
        return img_paths, labels

    def get_mosaic(self, index):
        """Gets images and labels after mosaic augments"""
        indices = [index] + random.choices(
            range(0, len(self.img_paths)), k=3
        )  # 3 additional image indices
        random.shuffle(indices)
        imgs, hs, ws, labels = [], [], [], []
        for index in indices:
            img, _, (h, w) = self.load_image(index)
            labels_per_img = self.labels[index]
            imgs.append(img)
            hs.append(h)
            ws.append(w)
            labels.append(labels_per_img)
        img, labels = mosaic_augmentation(self.img_size, imgs, hs, ws, labels, self.hyp)
        return img, labels

    def general_augment(self, img, labels):
        """Gets images and labels after general augments.
        This function applies HSV jitter, random up-down flip and random left-right flip.
        """
        nl = len(labels)

        # HSV color-space
        augment_hsv(
            img,
            hgain=self.hyp["hsv_h"],
            sgain=self.hyp["hsv_s"],
            vgain=self.hyp["hsv_v"],
        )

        # Flip up-down
        if random.random() < self.hyp["flipud"]:
            img = np.flipud(img)
            if nl:
                labels[:, 2] = 1 - labels[:, 2]

        # Flip left-right
        if random.random() < self.hyp["fliplr"]:
            img = np.fliplr(img)
            if nl:
                labels[:, 1] = 1 - labels[:, 1]

        return img, labels

    def sort_files_shapes(self):
        # Sort by aspect ratio
        batch_num = self.batch_indices[-1] + 1
        s = self.shapes  # wh
        ar = s[:, 1] / s[:, 0]  # aspect ratio
        irect = ar.argsort()
        self.img_paths = [self.img_paths[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        self.shapes = s[irect]  # wh
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * batch_num
        for i in range(batch_num):
            ari = ar[self.batch_indices == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]
        self.batch_shapes = (
            np.ceil(np.array(shapes) * self.img_size / self.stride + self.pad).astype(int)
            * self.stride
        )

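    # Worked example (illustrative numbers): with img_size=640, stride=32 and pad=0.0, a batch
    # whose largest h/w aspect ratio is 0.5 (all images twice as wide as tall) gets
    # shapes[i] = [0.5, 1] and batch_shapes[i] = ceil([0.5, 1] * 640 / 32) * 32 = [320, 640],
    # so that batch is letterboxed to 320x640 instead of 640x640, reducing wasted padding.
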
    @staticmethod
    def check_image(im_file):
        # verify an image.
        nc, msg = 0, ""
        try:
            im = Image.open(im_file)
            im.verify()  # PIL verify
            shape = im.size  # (width, height)
            im_exif = im._getexif()
            if im_exif and ORIENTATION in im_exif:
                rotation = im_exif[ORIENTATION]
                if rotation in (6, 8):
                    shape = (shape[1], shape[0])

            assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
            assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
            if im.format.lower() in ("jpg", "jpeg"):
                with open(im_file, "rb") as f:
                    f.seek(-2, 2)
                    if f.read() != b"\xff\xd9":  # corrupt JPEG
                        ImageOps.exif_transpose(Image.open(im_file)).save(
                            im_file, "JPEG", subsampling=0, quality=100
                        )
                        msg += f"WARNING: {im_file}: corrupt JPEG restored and saved"
            return im_file, shape, nc, msg
        except Exception as e:
            nc = 1
            msg = f"WARNING: {im_file}: ignoring corrupt image: {e}"
            return im_file, None, nc, msg

    @staticmethod
    def check_label_files(args):
        img_path, lb_path = args
        nm, nf, ne, nc, msg = 0, 0, 0, 0, ""  # number (missing, found, empty, corrupted), message
        try:
            if osp.exists(lb_path):
                nf = 1  # label found
                with open(lb_path, "r") as f:
                    labels = [
                        x.split() for x in f.read().strip().splitlines() if len(x)
                    ]
                    labels = np.array(labels, dtype=np.float32)
                if len(labels):
                    assert all(
                        len(l) == 5 for l in labels
                    ), f"{lb_path}: wrong label format."
                    assert (
                        labels >= 0
                    ).all(), f"{lb_path}: Label values error: all values in label file must be non-negative"
                    assert (
                        labels[:, 1:] <= 1
                    ).all(), f"{lb_path}: Label values error: all coordinates must be normalized"

                    _, indices = np.unique(labels, axis=0, return_index=True)
                    if len(indices) < len(labels):  # duplicate row check
                        msg += f"WARNING: {lb_path}: {len(labels) - len(indices)} duplicate labels removed"
                        labels = labels[indices]  # remove duplicates
                    labels = labels.tolist()
                else:
                    ne = 1  # label empty
                    labels = []
            else:
                nm = 1  # label missing
                labels = []

            return img_path, labels, nc, nm, nf, ne, msg
        except Exception as e:
            nc = 1
            msg = f"WARNING: {lb_path}: ignoring invalid labels: {e}"
            return img_path, None, nc, nm, nf, ne, msg

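    # Label files checked above follow the usual YOLO .txt layout: one object per line with five
    # whitespace-separated, normalized values "class_id x_center y_center width height", e.g.
    #
    #   0 0.481 0.634 0.690 0.713
    #   2 0.741 0.522 0.314 0.933
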
    @staticmethod
    def generate_coco_format_labels(img_info, class_names, save_path):
        # for evaluation with pycocotools
        dataset = {"categories": [], "annotations": [], "images": []}
        for i, class_name in enumerate(class_names):
            dataset["categories"].append(
                {"id": i, "name": class_name, "supercategory": ""}
            )

        ann_id = 0
        LOGGER.info("Convert to COCO format")
        for i, (img_path, info) in enumerate(tqdm(img_info.items())):
            labels = info["labels"] if info["labels"] else []
            img_id = osp.splitext(osp.basename(img_path))[0]
            img_id = int(img_id) if img_id.isnumeric() else img_id
            img_w, img_h = info["shape"]
            dataset["images"].append(
                {
                    "file_name": os.path.basename(img_path),
                    "id": img_id,
                    "width": img_w,
                    "height": img_h,
                }
            )
            if labels:
                for label in labels:
                    c, x, y, w, h = label[:5]
                    # convert x,y,w,h to x1,y1,x2,y2
                    x1 = (x - w / 2) * img_w
                    y1 = (y - h / 2) * img_h
                    x2 = (x + w / 2) * img_w
                    y2 = (y + h / 2) * img_h
                    # cls_id starts from 0
                    cls_id = int(c)
                    w = max(0, x2 - x1)
                    h = max(0, y2 - y1)
                    dataset["annotations"].append(
                        {
                            "area": h * w,
                            "bbox": [x1, y1, w, h],
                            "category_id": cls_id,
                            "id": ann_id,
                            "image_id": img_id,
                            "iscrowd": 0,
                            # mask
                            "segmentation": [],
                        }
                    )
                    ann_id += 1

        with open(save_path, "w") as f:
            json.dump(dataset, f)
            LOGGER.info(
                f"Convert to COCO format finished. Results saved in {save_path}"
            )

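    # A minimal sketch of how the JSON written above is typically consumed with pycocotools
    # during evaluation ("predictions.json" is a hypothetical detection-results file in COCO
    # format, assumed to exist):
    #
    #   from pycocotools.coco import COCO
    #   from pycocotools.cocoeval import COCOeval
    #   anno = COCO(save_path)                   # ground truth generated by this method
    #   pred = anno.loadRes("predictions.json")  # model detections
    #   coco_eval = COCOeval(anno, pred, "bbox")
    #   coco_eval.evaluate()
    #   coco_eval.accumulate()
    #   coco_eval.summarize()
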
    @staticmethod
    def get_hash(paths):
        """Get the hash value of paths"""
        assert isinstance(paths, list), "Only lists of paths are supported currently."
        h = hashlib.md5("".join(paths).encode())
        return h.hexdigest()
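

# A minimal, hypothetical usage sketch (not part of the original module): it shows how the
# dataset is typically wired into a torch DataLoader together with TrainValDataset.collate_fn.
# The directory layout assumed below mirrors get_imgs_labels(): images under <root>/images/train
# and YOLO-format labels under <root>/labels/train. All paths and class names are placeholders,
# and because this file uses a relative import it must be run as a module, e.g.
# `python -m yolov6.data.datasets`.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    data_dict = {"names": ["class_0", "class_1"]}  # hypothetical class list
    dataset = TrainValDataset(
        img_dir="path/to/dataset/images/train",  # hypothetical path
        img_size=640,
        batch_size=16,
        augment=False,  # hyp is only needed when augmentation is enabled
        data_dict=data_dict,
        task="train",
    )
    loader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        collate_fn=TrainValDataset.collate_fn,
    )
    for imgs, targets, paths, shapes in loader:
        # imgs: (16, 3, 640, 640) uint8 tensor; targets: (num_boxes, 6) rows of
        # (image_index, class, x_center, y_center, width, height), normalized.
        print(imgs.shape, targets.shape)
        break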