From 9da56b62ddee550c0db0662c7ca00bcc6bfcf99a Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Thu, 23 Jul 2020 15:34:23 -0700
Subject: [PATCH] v2.0 Release (#491)

Signed-off-by: Glenn Jocher
---
 README.md            | 16 +++++++++-------
 models/yolo.py       | 20 ++++++++++++--------
 models/yolov5l.yaml  | 26 +++++++++++---------------
 models/yolov5m.yaml  | 26 +++++++++++---------------
 models/yolov5s.yaml  | 26 +++++++++++---------------
 models/yolov5x.yaml  | 26 +++++++++++---------------
 train.py             | 12 ++++++------
 utils/torch_utils.py |  2 +-
 utils/utils.py       |  9 +++++----
 9 files changed, 77 insertions(+), 86 deletions(-)

diff --git a/README.md b/README.md
index b4f9d2b..b79683d 100755
--- a/README.md
+++ b/README.md
@@ -8,26 +8,28 @@ This repository represents Ultralytics open-source research into future object d
 ** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 8, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS.
 
+- **July 23, 2020**: [v2.0 release](https://github.com/ultralytics/yolov5/releases/tag/v2.0): improved model definition, training and mAP []().
 - **June 22, 2020**: [PANet](https://arxiv.org/abs/1803.01534) updates: new heads, reduced parameters, faster inference and improved mAP [364fcfd](https://github.com/ultralytics/yolov5/commit/364fcfd7dba53f46edd4f04c037a039c0a287972).
 - **June 19, 2020**: [FP16](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.half) as new default for smaller checkpoints and faster inference [d4c6674](https://github.com/ultralytics/yolov5/commit/d4c6674c98e19df4c40e33a777610a18d1961145).
 - **June 9, 2020**: [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) updates: improved speed, size, and accuracy (credit to @WongKinYiu for CSP).
-- **May 27, 2020**: Public release of repo. YOLOv5 models are SOTA among all known YOLO implementations.
-- **April 1, 2020**: Start development of future [YOLOv3](https://github.com/ultralytics/yolov3)/[YOLOv4](https://github.com/AlexeyAB/darknet)-based PyTorch models in a range of compound-scaled sizes.
+- **May 27, 2020**: Public release. YOLOv5 models are SOTA among all known YOLO implementations.
+- **April 1, 2020**: Start development of future compound-scaled [YOLOv3](https://github.com/ultralytics/yolov3)/[YOLOv4](https://github.com/AlexeyAB/darknet)-based PyTorch models.
 
 ## Pretrained Checkpoints
 
 | Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sup>50</sup> | Speed<sub>GPU</sub> | FPS<sub>GPU</sub> || params | FLOPS |
 |---------- |------ |------ |------ | -------- | ------| ------ |------ | :------: |
-| [YOLOv5s](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 36.6 | 36.6 | 55.8 | **2.1ms** | **476** || 7.5M | 13.2B
-| [YOLOv5m](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 43.4 | 43.4 | 62.4 | 3.0ms | 333 || 21.8M | 39.4B
-| [YOLOv5l](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 46.6 | 46.7 | 65.4 | 3.9ms | 256 || 47.8M | 88.1B
-| [YOLOv5x](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | **48.4** | **48.4** | **66.9** | 6.1ms | 164 || 89.0M | 166.4B
+| [YOLOv5s](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 36.1 | 36.1 | 55.3 | **2.1ms** | **476** || 7.5M | 13.2B
+| [YOLOv5m](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 43.5 | 43.5 | 62.5 | 3.0ms | 333 || 21.8M | 39.4B
+| [YOLOv5l](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 47.0 | 47.1 | 65.6 | 3.9ms | 256 || 47.8M | 88.1B
+| [YOLOv5x](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | **49.0** | **49.0** | **67.4** | 6.1ms | 164 || 89.0M | 166.4B
+| | | | | | || | |
 | [YOLOv3-SPP](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 45.6 | 45.5 | 65.2 | 4.5ms | 222 || 63.0M | 118.0B
 
 ** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy.
-** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --data coco.yaml --img 736 --conf 0.001`
+** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --data coco.yaml --img 672 --conf 0.001`
 ** Speed<sub>GPU</sub> measures end-to-end time per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP16 image inference at --batch-size 32 --img-size 640, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --data coco.yaml --img 640 --conf 0.1`
 ** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation).
 
diff --git a/models/yolo.py b/models/yolo.py
index da96a31..16638ed 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -5,7 +5,7 @@ from models.experimental import *
 
 
 class Detect(nn.Module):
-    def __init__(self, nc=80, anchors=()):  # detection layer
+    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
         super(Detect, self).__init__()
         self.stride = None  # strides computed during build
         self.nc = nc  # number of classes
@@ -16,6 +16,7 @@ class Detect(nn.Module):
         a = torch.tensor(anchors).float().view(self.nl, -1, 2)
         self.register_buffer('anchors', a)  # shape(nl,na,2)
         self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
+        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
         self.export = False  # onnx export
 
     def forward(self, x):
@@ -23,6 +24,7 @@ class Detect(nn.Module):
         z = []  # inference output
         self.training |= self.export
         for i in range(self.nl):
+            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
 
@@ -124,8 +126,7 @@ class Model(nn.Module):
     def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
         m = self.model[-1]  # Detect() module
-        for f, s in zip(m.f, m.stride):  #  from
-            mi = self.model[f % m.i]
+        for mi, s in zip(m.m, m.stride):  #  from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
@@ -133,9 +134,9 @@ class Model(nn.Module):
 
     def _print_biases(self):
         m = self.model[-1]  # Detect() module
-        for f in sorted([x % m.i for x in m.f]):  #  from
-            b = self.model[f].bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
-            print(('%g Conv2d.bias:' + '%10.3g' * 6) % (f, *b[:5].mean(1).tolist(), b[5:].mean()))
+        for mi in m.m:  #  from
+            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
+            print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
 
     # def _print_weights(self):
     #     for m in self.model.modules():
@@ -159,7 +160,7 @@ class Model(nn.Module):
 def parse_model(d, ch):  # model_dict, input_channels(3)
     print('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
     anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
-    na = (len(anchors[0]) // 2)  # number of anchors
+    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
     no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
 
     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
@@ -181,6 +182,7 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
             #     e = math.log(c2 / ch[1]) / math.log(2)
             #     c2 = int(ch[1] * ex ** e)
             # if m != Focus:
+            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
 
             # Experimental
 
@@ -201,7 +203,9 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
         elif m is Concat:
             c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
         elif m is Detect:
-            f = f or list(reversed([(-1 if j == i else j - 1) for j, x in enumerate(ch) if x == no]))
+            args.append([ch[x + 1] for x in f])
+            if isinstance(args[1], int):  # number of anchors
+                args[1] = [list(range(args[1] * 2))] * len(f)
         else:
             c2 = ch[f]
 
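Note: the change above is the heart of the v2.0 refactor: the per-level 1x1 output convolutions move out of the YAML layer list and into `Detect` itself, which is now constructed from the incoming channel counts `ch`. Below is a minimal standalone sketch of that pattern, distilled from the diff; the channel counts and grid sizes in the usage example are illustrative, not from the repo.

```python
import torch
import torch.nn as nn


class Detect(nn.Module):
    # Sketch of the v2.0 Detect pattern: one 1x1 output conv per feature
    # level, owned by the detection layer and sized from `ch`, instead of
    # separate nn.Conv2d entries in the model YAML.
    def __init__(self, nc=80, anchors=(), ch=()):
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # outputs per anchor: box(4) + obj(1) + cls(nc)
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # anchors per layer
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output convs

    def forward(self, x):  # x: list of feature maps, finest first (P3, P4, P5)
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,ny,nx) to x(bs,na,ny,nx,no)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
        return x


# Illustrative P3/P4/P5 channel counts and grid sizes for a 640x640 input
anchors = ([10, 13, 16, 30, 33, 23],
           [30, 61, 62, 45, 59, 119],
           [116, 90, 156, 198, 373, 326])
d = Detect(nc=80, anchors=anchors, ch=(128, 256, 512))
feats = [torch.zeros(1, c, s, s) for c, s in ((128, 80), (256, 40), (512, 20))]
print([tuple(y.shape) for y in d(feats)])
# [(1, 3, 80, 80, 85), (1, 3, 40, 40, 85), (1, 3, 20, 20, 85)]
```

This is also why the YAML heads below drop their `nn.Conv2d` output layers and simply pass the feature layers `[17, 20, 23]` to `Detect`.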
diff --git a/models/yolov5l.yaml b/models/yolov5l.yaml
index 959d4bd..a1c5547 100644
--- a/models/yolov5l.yaml
+++ b/models/yolov5l.yaml
@@ -5,9 +5,9 @@ width_multiple: 1.0  # layer channel multiple
 
 # anchors
 anchors:
-  - [116,90, 156,198, 373,326]  # P5/32
-  - [30,61, 62,45, 59,119]  # P4/16
   - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
 
 # YOLOv5 backbone
 backbone:
@@ -19,15 +19,14 @@ backbone:
    [-1, 9, BottleneckCSP, [256]],
    [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
    [-1, 9, BottleneckCSP, [512]],
-   [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
    [-1, 1, SPP, [1024, [5, 9, 13]]],
+   [-1, 3, BottleneckCSP, [1024, False]],  # 9
   ]
 
 # YOLOv5 head
 head:
-  [[-1, 3, BottleneckCSP, [1024, False]],  # 9
-
-   [-1, 1, Conv, [512, 1, 1]],
+  [[-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    [[-1, 6], 1, Concat, [1]],  # cat backbone P4
    [-1, 3, BottleneckCSP, [512, False]],  # 13
@@ -35,18 +34,15 @@ head:
    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    [[-1, 4], 1, Concat, [1]],  # cat backbone P3
-   [-1, 3, BottleneckCSP, [256, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 18 (P3/8-small)
+   [-1, 3, BottleneckCSP, [256, False]],  # 17
 
-   [-2, 1, Conv, [256, 3, 2]],
+   [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]],  # cat head P4
-   [-1, 3, BottleneckCSP, [512, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 22 (P4/16-medium)
+   [-1, 3, BottleneckCSP, [512, False]],  # 20
 
-   [-2, 1, Conv, [512, 3, 2]],
+   [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]],  # cat head P5
-   [-1, 3, BottleneckCSP, [1024, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 26 (P5/32-large)
+   [-1, 3, BottleneckCSP, [1024, False]],  # 23
 
-   [[], 1, Detect, [nc, anchors]],  # Detect(P5, P4, P3)
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
   ]
diff --git a/models/yolov5m.yaml b/models/yolov5m.yaml
index 60037c2..24a7193 100644
--- a/models/yolov5m.yaml
+++ b/models/yolov5m.yaml
@@ -5,9 +5,9 @@ width_multiple: 0.75  # layer channel multiple
 
 # anchors
 anchors:
-  - [116,90, 156,198, 373,326]  # P5/32
-  - [30,61, 62,45, 59,119]  # P4/16
   - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
 
 # YOLOv5 backbone
 backbone:
@@ -19,15 +19,14 @@ backbone:
    [-1, 9, BottleneckCSP, [256]],
    [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
    [-1, 9, BottleneckCSP, [512]],
-   [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
    [-1, 1, SPP, [1024, [5, 9, 13]]],
+   [-1, 3, BottleneckCSP, [1024, False]],  # 9
   ]
 
 # YOLOv5 head
 head:
-  [[-1, 3, BottleneckCSP, [1024, False]],  # 9
-
-   [-1, 1, Conv, [512, 1, 1]],
+  [[-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    [[-1, 6], 1, Concat, [1]],  # cat backbone P4
    [-1, 3, BottleneckCSP, [512, False]],  # 13
@@ -35,18 +34,15 @@ head:
    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    [[-1, 4], 1, Concat, [1]],  # cat backbone P3
-   [-1, 3, BottleneckCSP, [256, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 18 (P3/8-small)
+   [-1, 3, BottleneckCSP, [256, False]],  # 17
 
-   [-2, 1, Conv, [256, 3, 2]],
+   [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]],  # cat head P4
-   [-1, 3, BottleneckCSP, [512, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 22 (P4/16-medium)
+   [-1, 3, BottleneckCSP, [512, False]],  # 20
 
-   [-2, 1, Conv, [512, 3, 2]],
+   [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]],  # cat head P5
-   [-1, 3, BottleneckCSP, [1024, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 26 (P5/32-large)
+   [-1, 3, BottleneckCSP, [1024, False]],  # 23
 
-   [[], 1, Detect, [nc, anchors]],  # Detect(P5, P4, P3)
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
   ]
diff --git a/models/yolov5s.yaml b/models/yolov5s.yaml
index 1eaef97..ff07628 100644
--- a/models/yolov5s.yaml
+++ b/models/yolov5s.yaml
@@ -5,9 +5,9 @@ width_multiple: 0.50  # layer channel multiple
 
 # anchors
 anchors:
-  - [116,90, 156,198, 373,326]  # P5/32
-  - [30,61, 62,45, 59,119]  # P4/16
   - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
 
 # YOLOv5 backbone
 backbone:
@@ -19,15 +19,14 @@ backbone:
    [-1, 9, BottleneckCSP, [256]],
    [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
    [-1, 9, BottleneckCSP, [512]],
-   [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
    [-1, 1, SPP, [1024, [5, 9, 13]]],
+   [-1, 3, BottleneckCSP, [1024, False]],  # 9
   ]
 
 # YOLOv5 head
 head:
-  [[-1, 3, BottleneckCSP, [1024, False]],  # 9
-
-   [-1, 1, Conv, [512, 1, 1]],
+  [[-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    [[-1, 6], 1, Concat, [1]],  # cat backbone P4
    [-1, 3, BottleneckCSP, [512, False]],  # 13
@@ -35,18 +34,15 @@ head:
    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    [[-1, 4], 1, Concat, [1]],  # cat backbone P3
-   [-1, 3, BottleneckCSP, [256, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 18 (P3/8-small)
+   [-1, 3, BottleneckCSP, [256, False]],  # 17
 
-   [-2, 1, Conv, [256, 3, 2]],
+   [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]],  # cat head P4
-   [-1, 3, BottleneckCSP, [512, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 22 (P4/16-medium)
+   [-1, 3, BottleneckCSP, [512, False]],  # 20
 
-   [-2, 1, Conv, [512, 3, 2]],
+   [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]],  # cat head P5
-   [-1, 3, BottleneckCSP, [1024, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 26 (P5/32-large)
+   [-1, 3, BottleneckCSP, [1024, False]],  # 23
 
-   [[], 1, Detect, [nc, anchors]],  # Detect(P5, P4, P3)
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
   ]
diff --git a/models/yolov5x.yaml b/models/yolov5x.yaml
index dcd6fbc..8bd1837 100644
--- a/models/yolov5x.yaml
+++ b/models/yolov5x.yaml
@@ -5,9 +5,9 @@ width_multiple: 1.25  # layer channel multiple
 
 # anchors
 anchors:
-  - [116,90, 156,198, 373,326]  # P5/32
-  - [30,61, 62,45, 59,119]  # P4/16
   - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
 
 # YOLOv5 backbone
 backbone:
@@ -19,15 +19,14 @@ backbone:
    [-1, 9, BottleneckCSP, [256]],
    [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
    [-1, 9, BottleneckCSP, [512]],
-   [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
    [-1, 1, SPP, [1024, [5, 9, 13]]],
+   [-1, 3, BottleneckCSP, [1024, False]],  # 9
   ]
 
 # YOLOv5 head
 head:
-  [[-1, 3, BottleneckCSP, [1024, False]],  # 9
-
-   [-1, 1, Conv, [512, 1, 1]],
+  [[-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    [[-1, 6], 1, Concat, [1]],  # cat backbone P4
    [-1, 3, BottleneckCSP, [512, False]],  # 13
@@ -35,18 +34,15 @@ head:
    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    [[-1, 4], 1, Concat, [1]],  # cat backbone P3
-   [-1, 3, BottleneckCSP, [256, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 18 (P3/8-small)
+   [-1, 3, BottleneckCSP, [256, False]],  # 17
 
-   [-2, 1, Conv, [256, 3, 2]],
+   [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]],  # cat head P4
-   [-1, 3, BottleneckCSP, [512, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 22 (P4/16-medium)
+   [-1, 3, BottleneckCSP, [512, False]],  # 20
 
-   [-2, 1, Conv, [512, 3, 2]],
+   [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]],  # cat head P5
-   [-1, 3, BottleneckCSP, [1024, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 26 (P5/32-large)
+   [-1, 3, BottleneckCSP, [1024, False]],  # 23
 
-   [[], 1, Detect, [nc, anchors]],  # Detect(P5, P4, P3)
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
   ]
diff --git a/train.py b/train.py
index 86adee5..bfed25d 100644
--- a/train.py
+++ b/train.py
@@ -27,16 +27,16 @@ hyp = {'optimizer': 'SGD',  # ['adam', 'SGD', None] if none, default is SGD
        'momentum': 0.937,  # SGD momentum/Adam beta1
        'weight_decay': 5e-4,  # optimizer weight decay
        'giou': 0.05,  # giou loss gain
-       'cls': 0.58,  # cls loss gain
+       'cls': 0.5,  # cls loss gain
        'cls_pw': 1.0,  # cls BCELoss positive_weight
        'obj': 1.0,  # obj loss gain (*=img_size/320 if img_size != 320)
        'obj_pw': 1.0,  # obj BCELoss positive_weight
        'iou_t': 0.20,  # iou training threshold
        'anchor_t': 4.0,  # anchor-multiple threshold
        'fl_gamma': 0.0,  # focal loss gamma (efficientDet default is gamma=1.5)
-       'hsv_h': 0.014,  # image HSV-Hue augmentation (fraction)
-       'hsv_s': 0.68,  # image HSV-Saturation augmentation (fraction)
-       'hsv_v': 0.36,  # image HSV-Value augmentation (fraction)
+       'hsv_h': 0.015,  # image HSV-Hue augmentation (fraction)
+       'hsv_s': 0.7,  # image HSV-Saturation augmentation (fraction)
+       'hsv_v': 0.4,  # image HSV-Value augmentation (fraction)
        'degrees': 0.0,  # image rotation (+/- deg)
        'translate': 0.0,  # image translation (+/- fraction)
        'scale': 0.5,  # image scale (+/- gain)
@@ -159,7 +159,7 @@ def train(hyp, tb_writer, opt, device):
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
 
     # Scheduler https://arxiv.org/pdf/1812.01187.pdf
-    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
+    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2  # cosine
     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
     # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
     # plot_lr_scheduler(optimizer, scheduler, epochs)
@@ -334,7 +334,7 @@ def train(hyp, tb_writer, opt, device):
         if rank in [-1, 0]:
             # mAP
             if ema is not None:
-                ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
+                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
                 results, maps, times = test.test(opt.data,
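For reference, the scheduler tweak above raises the final floor of the one-cycle cosine schedule from 0.1x to 0.2x of the initial learning rate. A quick sketch of the two multipliers side by side (the `epochs` value here is illustrative; `train.py` takes it from its options):

```python
import math

epochs = 300  # illustrative; train.py reads this from opt.epochs
old_lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # v1.0 cosine
new_lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2  # v2.0 cosine

for e in (0, 150, 300):  # start, midpoint, end of training
    print(e, round(old_lf(e), 3), round(new_lf(e), 3))
# 0   1.0  1.0
# 150 0.55 0.6
# 300 0.1  0.2
```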
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 06d0447..ce584f8 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -65,7 +65,7 @@ def initialize_weights(model):
         if t is nn.Conv2d:
             pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
         elif t is nn.BatchNorm2d:
-            m.eps = 1e-4
+            m.eps = 1e-3
             m.momentum = 0.03
         elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
             m.inplace = True
diff --git a/utils/utils.py b/utils/utils.py
index c6f1352..34416f9 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -5,10 +5,10 @@ import random
 import shutil
 import subprocess
 import time
+from contextlib import contextmanager
 from copy import copy
 from pathlib import Path
 from sys import platform
-from contextlib import contextmanager
 
 import cv2
 import matplotlib
@@ -110,6 +110,7 @@ def check_anchor_order(m):
     da = a[-1] - a[0]  # delta a
     ds = m.stride[-1] - m.stride[0]  # delta s
     if da.sign() != ds.sign():  # same order
+        print('Reversing anchor order')
         m.anchors[:] = m.anchors.flip(0)
         m.anchor_grid[:] = m.anchor_grid.flip(0)
 
@@ -459,7 +460,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
     # per output
     nt = 0  # number of targets
     np = len(p)  # number of outputs
-    balance = [1.0, 1.0, 1.0]
+    balance = [4.0, 1.0, 0.4] if np == 3 else [4.0, 1.0, 0.4, 0.1]  # P3-5 or P3-6
     for i, pi in enumerate(p):  # layer index, layer predictions
         b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
         tobj = torch.zeros_like(pi[..., 0]).to(device)  # target obj
@@ -493,7 +494,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
 
     s = 3 / np  # output count scaling
     lbox *= h['giou'] * s
-    lobj *= h['obj'] * s
+    lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
     lcls *= h['cls'] * s
     bs = tobj.shape[0]  # batch size
     if red == 'sum':
@@ -1119,7 +1120,7 @@ def plot_study_txt(f='study.txt', x=None):  # from utils.utils import *; plot_st
         ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
                  label=Path(f).stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
 
-    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [33.5, 39.1, 42.5, 45.9, 49., 50.5],
+    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [33.8, 39.6, 43.0, 47.5, 49.4, 50.7],
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
 
     ax2.grid()
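The `compute_loss` changes above replace the uniform objectness weighting with a per-level `balance` vector that up-weights the high-resolution P3 output and down-weights the coarser levels. A toy sketch of how those weights enter the objectness term; the per-level loss values below are made up for illustration, and `h` stands in for the hyp dict from train.py:

```python
import torch

# Stand-in per-level BCE objectness losses for P3, P4, P5 (illustrative values)
obj_loss_per_level = [torch.tensor(0.9), torch.tensor(0.6), torch.tensor(0.3)]
h = {'obj': 1.0}  # obj loss gain, as in the hyp dict above

np_ = len(obj_loss_per_level)  # number of output levels
balance = [4.0, 1.0, 0.4] if np_ == 3 else [4.0, 1.0, 0.4, 0.1]  # P3-5 or P3-6

# Weighted sum across levels, then the output-count scaling from the patch
lobj = sum(l * b for l, b in zip(obj_loss_per_level, balance))
s = 3 / np_  # output count scaling
lobj *= h['obj'] * s * (1.4 if np_ == 4 else 1.)  # extra gain only for 4-output (P6) models
print(float(lobj))  # 4.0*0.9 + 1.0*0.6 + 0.4*0.3 = 4.32
```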