diff --git a/train.py b/train.py index e9be7e4..7df99c0 100644 --- a/train.py +++ b/train.py @@ -195,8 +195,9 @@ def train(hyp): c = torch.tensor(labels[:, 0]) # classes # cf = torch.bincount(c.long(), minlength=nc) + 1. # model._initialize_biases(cf.to(device)) - plot_labels(labels) - tb_writer.add_histogram('classes', c, 0) + if tb_writer: + plot_labels(labels) + tb_writer.add_histogram('classes', c, 0) # Check anchors check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)