From 994829f66a51599ae10a1ebebf2cf767a8c01b29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=86=E9=91=AB=E5=AE=87?= <1324004302@qq.com> Date: Wed, 18 Dec 2024 20:22:12 +0800 Subject: [PATCH] UYHHU --- MTSP-main/ultralytics/models/fastsam/val.py | 165 ++++++++++++++------ 1 file changed, 115 insertions(+), 50 deletions(-) diff --git a/MTSP-main/ultralytics/models/fastsam/val.py b/MTSP-main/ultralytics/models/fastsam/val.py index 9bbae57..61c8847 100644 --- a/MTSP-main/ultralytics/models/fastsam/val.py +++ b/MTSP-main/ultralytics/models/fastsam/val.py @@ -13,23 +13,56 @@ from ultralytics.utils.checks import check_requirements from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou from ultralytics.utils.plotting import output_to_target, plot_images - +# FastSAMValidator类继承自DetectionValidator,主要用于对FastSAM模型在分割任务中的验证相关操作,例如处理预测结果、计算评估指标、绘制验证图像等 class FastSAMValidator(DetectionValidator): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): - """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.""" + """ + Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics. + + 函数功能: + 初始化FastSAMValidator类,调用父类(DetectionValidator)的初始化方法,同时设置任务类型为'segment', + 并初始化用于评估分割任务的指标对象(SegmentMetrics)。 + + 参数说明: + dataloader (可选):数据加载器,用于加载验证数据。 + save_dir (可选):保存验证结果的目录路径。 + pbar (可选):进度条对象,用于显示验证进度(可能在可视化进度相关的功能中使用)。 + args (可选):包含各种配置参数的对象,例如模型相关的参数、验证相关的设置等。 + _callbacks (可选):回调函数相关对象,用于在特定事件发生时执行自定义的操作(比如在验证过程的某些阶段触发额外的处理逻辑)。 + """ super().__init__(dataloader, save_dir, pbar, args, _callbacks) self.args.task = 'segment' self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) def preprocess(self, batch): - """Preprocesses batch by converting masks to float and sending to device.""" + """ + Preprocesses batch by converting masks to float and sending to device. + + 函数功能: + 对输入的批次数据(batch)进行预处理,先调用父类的预处理方法,然后将批次数据中的掩码(masks)数据转换为浮点数类型,并发送到指定的设备(self.device)上。 + + 参数说明: + batch:包含了图像、标签、掩码等多种数据的批次数据,格式通常是按照数据加载器的定义组织的。 + + 返回值: + 处理后的批次数据,其中掩码数据已转换为浮点数并放置在相应设备上。 + """ batch = super().preprocess(batch) batch['masks'] = batch['masks'].to(self.device).float() return batch def init_metrics(self, model): - """Initialize metrics and select mask processing function based on save_json flag.""" + """ + Initialize metrics and select mask processing function based on save_json flag. + + 函数功能: + 初始化评估指标相关的设置,先调用父类的初始化指标方法,然后根据是否保存JSON格式结果(self.args.save_json)的配置, + 选择合适的掩码处理函数,用于后续对预测掩码的处理操作。 + + 参数说明: + model:正在验证的模型对象,可能在某些与模型相关的指标初始化操作中会用到(虽然此处代码中未体现具体使用情况)。 + """ super().init_metrics(model) self.plot_masks = [] if self.args.save_json: @@ -39,12 +72,32 @@ class FastSAMValidator(DetectionValidator): self.process = ops.process_mask # faster def get_desc(self): - """Return a formatted description of evaluation metrics.""" + """ + Return a formatted description of evaluation metrics. + + 函数功能: + 返回一个格式化的评估指标描述字符串,用于展示不同评估指标在输出结果中的格式和排列顺序,方便查看和理解验证结果中的各项指标含义。 + + 返回值: + 格式化后的字符串,包含了如类别(Class)、图像数量(Images)、实例数量(Instances)以及不同IoU阈值下的框相关指标(Box相关)和掩码相关指标(Mask相关)等信息的格式描述。 + """ return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P', 'R', 'mAP50', 'mAP50-95)') def postprocess(self, preds): - """Postprocesses YOLO predictions and returns output detections with proto.""" + """ + Postprocesses YOLO predictions and returns output detections with proto. + + 函数功能: + 对YOLO模型的预测结果(preds)进行后处理操作,通过非极大值抑制(non_max_suppression)筛选出合适的预测框, + 并提取出与预测相关的其他输出(如proto,具体含义根据模型而定,可能与掩码生成等相关),最终返回处理后的预测结果和相关输出。 + + 参数说明: + preds:YOLO模型输出的原始预测结果,通常包含了多个维度的信息,例如预测的边界框、类别概率、掩码相关信息等(具体结构取决于模型设计)。 + + 返回值: + 包含处理后的预测检测框(p)和相关输出(proto)的元组,处理后的预测检测框经过了非极大值抑制等筛选操作,符合一定的置信度、IOU等条件。 + """ p = ops.non_max_suppression(preds[0], self.args.conf, self.args.iou, @@ -124,7 +177,13 @@ class FastSAMValidator(DetectionValidator): # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') def finalize_metrics(self, *args, **kwargs): - """Sets speed and confusion matrix for evaluation metrics.""" + """ + Sets speed and confusion matrix for evaluation metrics. + + 函数功能: + 将验证过程中的速度信息(self.speed)和混淆矩阵信息(self.confusion_matrix)设置到评估指标对象(self.metrics)中, + 用于最终的评估指标统计和展示等操作。 + """ self.metrics.speed = self.speed self.metrics.confusion_matrix = self.confusion_matrix @@ -143,7 +202,7 @@ class FastSAMValidator(DetectionValidator): index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1 gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640) gt_masks = torch.where(gt_masks == index, 1.0, 0.0) - if gt_masks.shape[1:] != pred_masks.shape[1:]: + if gt_masks.shape[1:]!= pred_masks.shape[1:]: gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0] gt_masks = gt_masks.gt_(0.5) iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) @@ -166,7 +225,16 @@ class FastSAMValidator(DetectionValidator): return torch.tensor(correct, dtype=torch.bool, device=detections.device) def plot_val_samples(self, batch, ni): - """Plots validation samples with bounding box labels.""" + """ + Plots validation samples with bounding box labels. + + 函数功能: + 使用给定的批次数据(batch)绘制带有边界框标签的验证样本图像,用于可视化验证数据的真实标注情况,方便查看数据和评估模型表现。 + + 参数说明: + batch:包含图像、标签、掩码等数据的批次数据,用于获取绘制图像所需的信息。 + ni:可能是用于标识批次序号或者图像序号等的索引信息,用于生成唯一的文件名等(从函数调用处推测)。 + """ plot_images(batch['img'], batch['batch_idx'], batch['cls'].squeeze(-1), @@ -178,7 +246,18 @@ class FastSAMValidator(DetectionValidator): on_plot=self.on_plot) def plot_predictions(self, batch, preds, ni): - """Plots batch predictions with masks and bounding boxes.""" + """ + Plots batch predictions with masks and bounding boxes. + + 函数功能: + 使用给定的批次数据(batch)和模型预测结果(preds)绘制带有掩码和边界框的预测图像,用于可视化模型的预测情况, + 方便对比真实标注和模型预测之间的差异,评估模型性能。 + + 参数说明: + batch:包含图像、标签、掩码等数据的批次数据,用于获取绘制图像所需的部分信息,如原始图像等。 + preds:模型对该批次数据的预测结果,包含了预测的边界框、掩码等信息,用于在图像上绘制相应的预测可视化内容。 + ni:类似前面函数中的索引信息,用于生成唯一的文件名等(从函数调用处推测)。 + """ plot_images( batch['img'], *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed @@ -190,12 +269,35 @@ class FastSAMValidator(DetectionValidator): self.plot_masks.clear() def pred_to_json(self, predn, filename, pred_masks): - """Save one JSON result.""" + """ + Save one JSON result. + + 函数功能: + 将模型的预测结果(包括边界框信息、类别信息、掩码信息等)转换为符合JSON格式的数据结构,并保存相关信息, + 以便后续可以使用相关工具(如pycocotools)进行评估指标的计算等操作。 + + 参数说明: + predn:处理后的预测结果数据,包含了边界框坐标、类别、置信度等信息,格式可能是经过特定处理后的张量形式。 + filename:原始图像的文件名,用于获取图像的标识信息(可能在JSON结果中作为图像的唯一标识等)。 + pred_masks:预测的掩码数据,用于将掩码信息编码为合适的格式后添加到JSON结果中。 + """ # Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} from pycocotools.mask import encode # noqa def single_encode(x): - """Encode predicted masks as RLE and append results to jdict.""" + """ + Encode predicted masks as RLE and append results to jdict. + + 函数功能: + 对单个预测掩码进行编码,将其转换为行程长度编码(Run-Length Encoding,RLE)格式,方便在JSON结果中存储掩码信息, + 并返回编码后的RLE数据结构。 + + 参数说明: + x:单个预测掩码数据,通常是二维数组形式表示的图像掩码。 + + 返回值: + 编码后的RLE数据结构,包含了掩码的编码信息(经过一定处理,如将字节串解码等操作),可用于JSON格式存储。 + """ rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0] rle['counts'] = rle['counts'].decode('utf-8') return rle @@ -204,41 +306,4 @@ class FastSAMValidator(DetectionValidator): image_id = int(stem) if stem.isnumeric() else stem box = ops.xyxy2xywh(predn[:, :4]) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner - pred_masks = np.transpose(pred_masks, (2, 0, 1)) - with ThreadPool(NUM_THREADS) as pool: - rles = pool.map(single_encode, pred_masks) - for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): - self.jdict.append({ - 'image_id': image_id, - 'category_id': self.class_map[int(p[5])], - 'bbox': [round(x, 3) for x in b], - 'score': round(p[4], 5), - 'segmentation': rles[i]}) - - def eval_json(self, stats): - """Return COCO-style object detection evaluation metrics.""" - if self.args.save_json and self.is_coco and len(self.jdict): - anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations - pred_json = self.save_dir / 'predictions.json' # predictions - LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') - try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb - check_requirements('pycocotools>=2.0.6') - from pycocotools.coco import COCO # noqa - from pycocotools.cocoeval import COCOeval # noqa - - for x in anno_json, pred_json: - assert x.is_file(), f'{x} file not found' - anno = COCO(str(anno_json)) # init annotations api - pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) - for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]): - if self.is_coco: - eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval - eval.evaluate() - eval.accumulate() - eval.summarize() - idx = i * 4 + 2 - stats[self.metrics.keys[idx + 1]], stats[ - self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50 - except Exception as e: - LOGGER.warning(f'pycocotools unable to run: {e}') - return stats + pred_masks = np.transpose(pred_masks, (2, \ No newline at end of file