diff --git a/models/yolo.py b/models/yolo.py index ae50c85..69d2f15 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -45,13 +45,15 @@ class Detect(nn.Module): class Model(nn.Module): - def __init__(self, model_yaml='yolov5s.yaml'): # cfg, number of classes, depth-width gains + def __init__(self, model_yaml='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes super(Model, self).__init__() with open(model_yaml) as f: self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict + if nc: + self.md['nc'] = nc # override yaml value # Define model - self.model, self.save, ch = parse_model(self.md, ch=[3]) # model, savelist, ch_out + self.model, self.save, ch = parse_model(self.md, ch=[ch]) # model, savelist, ch_out # print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))]) # Build strides, anchors