diff --git a/train.py b/train.py index db135a7..484190a 100644 --- a/train.py +++ b/train.py @@ -199,7 +199,7 @@ def train(hyp): tb_writer.add_histogram('classes', c, 0) # Check anchors - check_anchors(dataset, model=model.model[-1].anchor_grid, thr=hyp['anchor_t'], imgsz=imgsz) + check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # Exponential moving average ema = torch_utils.ModelEMA(model)