parent
1e25775c7d
commit
a814720403
@ -0,0 +1,90 @@
|
||||
"""File for accessing YOLOv5 via PyTorch Hub https://pytorch.org/hub/
|
||||
|
||||
Usage:
|
||||
import torch
|
||||
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
|
||||
"""
|
||||
|
||||
dependencies = ['torch', 'pyyaml']
|
||||
import torch
|
||||
|
||||
from models.yolo import Model
|
||||
from utils import google_utils
|
||||
|
||||
|
||||
def create(name, pretrained, channels, classes):
|
||||
"""Creates a specified YOLOv5 model
|
||||
|
||||
Arguments:
|
||||
name (str): name of model, i.e. 'yolov5s'
|
||||
pretrained (bool): load pretrained weights into the model
|
||||
channels (int): number of input channels
|
||||
classes (int): number of model classes
|
||||
|
||||
Returns:
|
||||
pytorch model
|
||||
"""
|
||||
model = Model('models/%s.yaml' % name, channels, classes)
|
||||
if pretrained:
|
||||
ckpt = '%s.pt' % name # checkpoint filename
|
||||
google_utils.attempt_download(ckpt) # download if not found locally
|
||||
state_dict = torch.load(ckpt)['model'].state_dict()
|
||||
state_dict = {k: v for k, v in state_dict if model.state_dict()[k].numel() == v.numel()} # filter
|
||||
model.load_state_dict(state_dict, strict=False) # load
|
||||
return model
|
||||
|
||||
|
||||
def yolov5s(pretrained=False, channels=3, classes=80):
|
||||
"""YOLOv5-small model from https://github.com/ultralytics/yolov5
|
||||
|
||||
Arguments:
|
||||
pretrained (bool): load pretrained weights into the model, default=False
|
||||
channels (int): number of input channels, default=3
|
||||
classes (int): number of model classes, default=80
|
||||
|
||||
Returns:
|
||||
pytorch model
|
||||
"""
|
||||
return create('yolov5s', pretrained, channels, classes)
|
||||
|
||||
|
||||
def yolov5m(pretrained=False, channels=3, classes=80):
|
||||
"""YOLOv5-medium model from https://github.com/ultralytics/yolov5
|
||||
|
||||
Arguments:
|
||||
pretrained (bool): load pretrained weights into the model, default=False
|
||||
channels (int): number of input channels, default=3
|
||||
classes (int): number of model classes, default=80
|
||||
|
||||
Returns:
|
||||
pytorch model
|
||||
"""
|
||||
return create('yolov5m', pretrained, channels, classes)
|
||||
|
||||
|
||||
def yolov5l(pretrained=False, channels=3, classes=80):
|
||||
"""YOLOv5-large model from https://github.com/ultralytics/yolov5
|
||||
|
||||
Arguments:
|
||||
pretrained (bool): load pretrained weights into the model, default=False
|
||||
channels (int): number of input channels, default=3
|
||||
classes (int): number of model classes, default=80
|
||||
|
||||
Returns:
|
||||
pytorch model
|
||||
"""
|
||||
return create('yolov5l', pretrained, channels, classes)
|
||||
|
||||
|
||||
def yolov5x(pretrained=False, channels=3, classes=80):
|
||||
"""YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
|
||||
|
||||
Arguments:
|
||||
pretrained (bool): load pretrained weights into the model, default=False
|
||||
channels (int): number of input channels, default=3
|
||||
classes (int): number of model classes, default=80
|
||||
|
||||
Returns:
|
||||
pytorch model
|
||||
"""
|
||||
return create('yolov5x', pretrained, channels, classes)
|
@ -1,55 +0,0 @@
|
||||
# parameters
|
||||
nc: 80 # number of classes
|
||||
depth_multiple: 1.0 # expand model depth
|
||||
width_multiple: 1.0 # expand layer channels
|
||||
|
||||
# anchors
|
||||
anchors:
|
||||
- [10,13, 16,30, 33,23] # P3/8
|
||||
- [30,61, 62,45, 59,119] # P4/16
|
||||
- [116,90, 156,198, 373,326] # P5/32
|
||||
|
||||
# darknet53 backbone
|
||||
backbone:
|
||||
# [from, number, module, args]
|
||||
[[-1, 1, Conv, [32, 3, 1]], # 0
|
||||
[-1, 1, Conv, [64, 3, 2]], # 1-P1/2
|
||||
[-1, 1, BottleneckCSP, [64]],
|
||||
[-1, 1, Conv, [128, 3, 2]], # 3-P2/4
|
||||
[-1, 2, BottleneckCSP, [128]],
|
||||
[-1, 1, Conv, [256, 3, 2]], # 5-P3/8
|
||||
[-1, 8, BottleneckCSP, [256]],
|
||||
[-1, 1, Conv, [512, 3, 2]], # 7-P4/16
|
||||
[-1, 8, BottleneckCSP, [512]],
|
||||
[-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
|
||||
[-1, 4, BottleneckCSP, [1024]], # 10
|
||||
]
|
||||
|
||||
# yolov3-spp head
|
||||
# na = len(anchors[0])
|
||||
head:
|
||||
[[-1, 1, Bottleneck, [1024, False]], # 11
|
||||
[-1, 1, SPP, [512, [5, 9, 13]]],
|
||||
[-1, 1, Conv, [1024, 3, 1]],
|
||||
[-1, 1, Conv, [512, 1, 1]],
|
||||
[-1, 1, Conv, [1024, 3, 1]],
|
||||
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 16 (P5/32-large)
|
||||
|
||||
[-3, 1, Conv, [256, 1, 1]],
|
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||
[[-1, 8], 1, Concat, [1]], # cat backbone P4
|
||||
[-1, 1, Bottleneck, [512, False]],
|
||||
[-1, 1, Bottleneck, [512, False]],
|
||||
[-1, 1, Conv, [256, 1, 1]],
|
||||
[-1, 1, Conv, [512, 3, 1]],
|
||||
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 24 (P4/16-medium)
|
||||
|
||||
[-3, 1, Conv, [128, 1, 1]],
|
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||
[[-1, 6], 1, Concat, [1]], # cat backbone P3
|
||||
[-1, 1, Bottleneck, [256, False]],
|
||||
[-1, 2, Bottleneck, [256, False]],
|
||||
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 30 (P3/8-small)
|
||||
|
||||
[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
|
||||
]
|
Loading…
Reference in new issue