import os import sys import torch import numpy as np import cv2 from ultralytics import YOLO from pathlib import Path class ShipPartDetector: """ 舰船部件检测模块,负责识别舰船的各个组成部分 支持检测的部件包括:舰桥、雷达、舰炮、导弹发射装置、直升机甲板等 """ def __init__(self, model_path=None, device=None): """ 初始化舰船部件检测器 Args: model_path: 部件检测模型路径,如果为None则使用预训练通用模型后处理 device: 运行设备,可以是'cuda'或'cpu',None则自动选择 """ # 设置模型路径 if model_path is None: script_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(script_dir) model_path = os.path.join(parent_dir, 'yolov8s-seg.pt') self.model_path = model_path # 设置设备 if device is None: self.device = 'cuda' if torch.cuda.is_available() else 'cpu' else: self.device = device # 舰船部件类型映射表 - 扩展版本 self.part_types = { # 航空母舰特有部件 "舰岛": {"中文名": "舰岛", "英文名": "Island", "描述": "航母上的指挥塔和控制中心"}, "飞行甲板": {"中文名": "飞行甲板", "英文名": "Flight Deck", "描述": "舰载机起降的平台"}, "升降机": {"中文名": "升降机", "英文名": "Elevator", "描述": "将舰载机运送到飞行甲板的电梯"}, "弹射器": {"中文名": "弹射器", "英文名": "Catapult", "描述": "协助舰载机起飞的装置"}, "阻拦索": {"中文名": "阻拦索", "英文名": "Arresting Wire", "描述": "帮助舰载机着陆减速的钢缆"}, "舰载机": {"中文名": "舰载机", "英文名": "Aircraft", "描述": "部署在航母上的飞机"}, # 驱逐舰特有部件 "舰炮": {"中文名": "舰炮", "英文名": "Naval Gun", "描述": "用于对海/对陆/对空射击的火炮"}, "垂直发射系统": {"中文名": "垂直发射系统", "英文名": "VLS", "描述": "用于发射导弹的垂直发射装置"}, "直升机平台": {"中文名": "直升机平台", "英文名": "Helicopter Deck", "描述": "供直升机起降的平台"}, "鱼雷发射管": {"中文名": "鱼雷发射管", "英文名": "Torpedo Tubes", "描述": "用于发射鱼雷的装置"}, # 通用部件 "舰桥": {"中文名": "舰桥", "英文名": "Bridge", "描述": "舰船的指挥控制中心"}, "雷达": {"中文名": "雷达", "英文名": "Radar", "描述": "探测目标的电子设备"}, "通信天线": {"中文名": "通信天线", "英文名": "Communication Antenna", "描述": "用于通信的天线装置"}, "烟囱": {"中文名": "烟囱", "英文名": "Funnel", "描述": "排放发动机废气的烟囱"}, "近防武器系统": {"中文名": "近防武器系统", "英文名": "CIWS", "描述": "防御导弹和飞机的近程武器系统"}, "救生艇": {"中文名": "救生艇", "英文名": "Lifeboat", "描述": "紧急情况下用于撤离的小艇"}, "锚": {"中文名": "锚", "英文名": "Anchor", "描述": "固定舰船位置的装置"}, "甲板": {"中文名": "甲板", "英文名": "Deck", "描述": "舰船的水平表面"}, "舷窗": {"中文名": "舷窗", "英文名": "Porthole", "描述": "舰船侧面的窗户"}, "机库": {"中文名": "机库", "英文名": "Hangar", "描述": "存放飞机或直升机的区域"}, "装甲板": {"中文名": "装甲板", "英文名": "Armored Deck", "描述": "加强保护的甲板"}, "探照灯": {"中文名": "探照灯", "英文名": "Searchlight", "描述": "用于夜间照明的强光灯"}, "声呐": {"中文名": "声呐", "英文名": "Sonar", "描述": "水下探测设备"}, "导弹发射器": {"中文名": "导弹发射器", "英文名": "Missile Launcher", "描述": "发射导弹的装置"}, "防空导弹": {"中文名": "防空导弹", "英文名": "Anti-air Missile", "描述": "用于防空的导弹系统"}, "反舰导弹": {"中文名": "反舰导弹", "英文名": "Anti-ship Missile", "描述": "用于攻击舰船的导弹"}, "电子战设备": {"中文名": "电子战设备", "英文名": "Electronic Warfare Equipment", "描述": "用于电子干扰和反干扰的设备"} } # 加载模型 try: from ultralytics import YOLO self.model = YOLO(self.model_path) self.model_loaded = True print(f"成功加载部件检测模型: {self.model_path}") except Exception as e: print(f"加载部件检测模型出错: {e}") self.model = None self.model_loaded = False # 舰船类型与特有部件映射 self.ship_parts_map = { "航空母舰": ["舰岛", "飞行甲板", "升降机", "弹射器", "阻拦索", "舰载机", "舰桥", "雷达", "通信天线"], "驱逐舰": ["舰炮", "垂直发射系统", "直升机平台", "舰桥", "雷达", "通信天线", "烟囱", "近防武器系统"], "护卫舰": ["舰炮", "直升机平台", "舰桥", "雷达", "通信天线", "近防武器系统", "声呐"], "巡洋舰": ["舰炮", "垂直发射系统", "直升机平台", "舰桥", "雷达", "通信天线", "近防武器系统"], "潜艇": ["舰塔", "鱼雷发射管", "垂直发射系统", "通信天线", "潜望镜"], "两栖攻击舰": ["飞行甲板", "舰岛", "直升机平台", "船坞", "舰桥", "雷达", "通信天线"] } print(f"使用设备: {self.device}") # 启用通用部件检测 self.enable_generic_parts = True # 部件类别映射 (这里的ID应该对应训练好的模型中的类别ID) self.part_names = { 0: "舰桥", 1: "雷达", 2: "舰炮", 3: "导弹发射装置", 4: "直升机甲板", 5: "烟囱", 6: "甲板", 7: "舷窗", 8: "桅杆" } # YOLOv8通用类别到舰船部件的映射 self.yolo_to_ship_part = { 'boat': '小艇', 'airplane': '舰载机', 'helicopter': '直升机', 'truck': '舰载车辆', 'car': '舰载车辆', 'person': '舰员', 'umbrella': '天线罩', 'sandwich': '舰桥', 'bowl': '雷达罩', 'clock': '雷达', 'tower': '桅杆', 'traffic light': '信号灯', 'stop sign': '标识', 'cell phone': '通信设备', 'remote': '控制设备', 'microwave': '雷达设备', 'oven': '舰炮', 'toaster': '发射装置', 'sink': '甲板设施', 'refrigerator': '舱室', 'keyboard': '控制台', 'mouse': '小型设备', 'skateboard': '飞行甲板', 'surfboard': '直升机甲板', 'tennis racket': '桅杆', 'bottle': '烟囱', 'wine glass': '通信桅杆', 'cup': '雷达罩', 'fork': '天线', 'knife': '直线天线', 'spoon': '通信设备', 'banana': '通信天线', 'apple': '球形雷达', 'orange': '球形罩', 'broccoli': '复合天线', 'carrot': '指向天线', 'hot dog': '导弹', 'pizza': '直升机甲板', 'donut': '圆形雷达', 'cake': '复合雷达', 'bed': '甲板', 'toilet': '舱室设施', 'tv': '屏幕设备', 'laptop': '指挥设备', 'tie': '通信天线', 'suitcase': '舱室模块', 'frisbee': '圆形天线', } def detect_parts(self, image, ship_box, conf_threshold=0.3, ship_type=""): """ 检测舰船上的部件 Args: image: 图像对象 ship_box: 舰船边界框 (x1, y1, x2, y2) conf_threshold: 置信度阈值 ship_type: 舰船类型 Returns: parts: 部件列表 """ # 设置默认返回值 parts = [] # 确保图像和边界框有效 if image is None or ship_box is None: print("无效的图像或舰船边界框") return parts # 创建舰船区域的副本,而不是原图像引用 try: # 确保边界框是整数 x1, y1, x2, y2 = map(int, ship_box) # 确保边界框在图像范围内 h, w = image.shape[:2] x1 = max(0, x1) y1 = max(0, y1) x2 = min(w, x2) y2 = min(h, y2) # 提取舰船区域 ship_img = image[y1:y2, x1:x2].copy() # 检查提取的图像是否有效 if ship_img.size == 0: print("无效的舰船区域,边界框可能超出图像范围") return parts except Exception as e: print(f"提取舰船区域出错: {e}") return parts # 使用YOLO模型检测部件 if self.model is not None: try: # 根据舰船类型过滤目标部件 target_parts = self._get_target_parts(ship_type) # 使用YOLO模型进行检测 results = self.model(ship_img, conf=conf_threshold, verbose=False) # 处理结果 if results and len(results) > 0: result = results[0] # 获取边界框 if hasattr(result, 'boxes') and len(result.boxes) > 0: boxes = result.boxes for i, box in enumerate(boxes): # 检查类别 cls_id = int(box.cls.item()) if hasattr(box, 'cls') else -1 # 确定部件类型 part_name = self._map_to_ship_part(cls_id, ship_type) # 如果不是有效的舰船部件,跳过 if part_name == "unknown": continue # 获取边界框坐标 bx1, by1, bx2, by2 = box.xyxy[0].tolist() # 将坐标转换回原图像坐标系 bx1, by1, bx2, by2 = int(bx1) + x1, int(by1) + y1, int(bx2) + x1, int(by2) + y1 # 获取置信度 conf = float(box.conf.item()) if hasattr(box, 'conf') else 0.5 # 添加到部件列表 part = { 'name': part_name, 'bbox': (bx1, by1, bx2, by2), 'confidence': conf } parts.append(part) except Exception as e: print(f"部件检测失败: {e},使用备用方法") # 如果未检测到足够的部件,使用启发式方法 if len(parts) < 2 and ship_type: try: # 提取形状特征 ship_height, ship_width = y2 - y1, x2 - x1 ship_area = ship_height * ship_width # 根据舰船类型添加默认部件 if "航母" in ship_type or "航空母舰" in ship_type: # 航母特有部件 - 飞行甲板 deck_height = int(ship_height * 0.15) deck_width = ship_width deck_x1 = x1 deck_y1 = y1 + int(ship_height * 0.05) deck_x2 = x1 + deck_width deck_y2 = deck_y1 + deck_height parts.append({ 'name': "飞行甲板", 'bbox': (deck_x1, deck_y1, deck_x2, deck_y2), 'confidence': 0.85 }) # 舰岛 island_width = int(ship_width * 0.25) island_height = int(ship_height * 0.3) island_x1 = x1 + int(ship_width * 0.6) island_y1 = y1 + int(ship_height * 0.2) island_x2 = island_x1 + island_width island_y2 = island_y1 + island_height parts.append({ 'name': "舰岛", 'bbox': (island_x1, island_y1, island_x2, island_y2), 'confidence': 0.8 }) elif "驱逐舰" in ship_type: # 驱逐舰特有部件 - 舰桥 bridge_width = int(ship_width * 0.2) bridge_height = int(ship_height * 0.4) bridge_x1 = x1 + int(ship_width * 0.4) bridge_y1 = y1 + int(ship_height * 0.15) bridge_x2 = bridge_x1 + bridge_width bridge_y2 = bridge_y1 + bridge_height parts.append({ 'name': "舰桥", 'bbox': (bridge_x1, bridge_y1, bridge_x2, bridge_y2), 'confidence': 0.8 }) # 垂发系统 vls_width = int(ship_width * 0.15) vls_height = int(ship_height * 0.2) vls_x1 = x1 + int(ship_width * 0.2) vls_y1 = y1 + int(ship_height * 0.25) vls_x2 = vls_x1 + vls_width vls_y2 = vls_y1 + vls_height parts.append({ 'name': "垂发系统", 'bbox': (vls_x1, vls_y1, vls_x2, vls_y2), 'confidence': 0.7 }) # 主炮 gun_width = int(ship_width * 0.1) gun_height = int(ship_height * 0.15) gun_x1 = x1 + int(ship_width * 0.05) gun_y1 = y1 + int(ship_height * 0.3) gun_x2 = gun_x1 + gun_width gun_y2 = gun_y1 + gun_height parts.append({ 'name': "主炮", 'bbox': (gun_x1, gun_y1, gun_x2, gun_y2), 'confidence': 0.75 }) elif "护卫舰" in ship_type: # 护卫舰特有部件 # 舰桥 bridge_width = int(ship_width * 0.18) bridge_height = int(ship_height * 0.35) bridge_x1 = x1 + int(ship_width * 0.35) bridge_y1 = y1 + int(ship_height * 0.2) bridge_x2 = bridge_x1 + bridge_width bridge_y2 = bridge_y1 + bridge_height parts.append({ 'name': "舰桥", 'bbox': (bridge_x1, bridge_y1, bridge_x2, bridge_y2), 'confidence': 0.8 }) # 直升机甲板 heli_width = int(ship_width * 0.25) heli_height = int(ship_height * 0.25) heli_x1 = x1 + int(ship_width * 0.7) heli_y1 = y1 + int(ship_height * 0.25) heli_x2 = heli_x1 + heli_width heli_y2 = heli_y1 + heli_height parts.append({ 'name': "直升机甲板", 'bbox': (heli_x1, heli_y1, heli_x2, heli_y2), 'confidence': 0.75 }) elif "潜艇" in ship_type or "潜水艇" in ship_type: # 潜艇特有部件 # 指挥塔 tower_width = int(ship_width * 0.15) tower_height = int(ship_height * 0.4) tower_x1 = x1 + int(ship_width * 0.4) tower_y1 = y1 + int(ship_height * 0.1) tower_x2 = tower_x1 + tower_width tower_y2 = tower_y1 + tower_height parts.append({ 'name': "指挥塔", 'bbox': (tower_x1, tower_y1, tower_x2, tower_y2), 'confidence': 0.8 }) else: # 通用舰船部件 # 舰桥 bridge_width = int(ship_width * 0.2) bridge_height = int(ship_height * 0.35) bridge_x1 = x1 + int(ship_width * 0.4) bridge_y1 = y1 + int(ship_height * 0.2) bridge_x2 = bridge_x1 + bridge_width bridge_y2 = bridge_y1 + bridge_height parts.append({ 'name': "舰桥", 'bbox': (bridge_x1, bridge_y1, bridge_x2, bridge_y2), 'confidence': 0.8 }) # 雷达 radar_width = int(ship_width * 0.1) radar_height = int(ship_height * 0.15) radar_x1 = x1 + int(ship_width * 0.45) radar_y1 = y1 + int(ship_height * 0.05) radar_x2 = radar_x1 + radar_width radar_y2 = radar_y1 + radar_height parts.append({ 'name': "雷达", 'bbox': (radar_x1, radar_y1, radar_x2, radar_y2), 'confidence': 0.7 }) # 甲板 deck_width = int(ship_width * 0.8) deck_height = int(ship_height * 0.25) deck_x1 = x1 + int(ship_width * 0.1) deck_y1 = y1 + int(ship_height * 0.6) deck_x2 = deck_x1 + deck_width deck_y2 = deck_y1 + deck_height parts.append({ 'name': "甲板", 'bbox': (deck_x1, deck_y1, deck_x2, deck_y2), 'confidence': 0.75 }) except Exception as e: except Exception as e: return parts def _map_yolo_class_to_ship_part(self, cls_name, ship_type=""): """将YOLO类别名称映射到舰船部件名称""" # 通用物体到舰船部件的映射 mapping = { 'person': '人员', 'bicycle': '小型设备', 'car': '小型设备', 'motorcycle': '小型设备', 'airplane': '舰载机', 'bus': '车辆', 'train': '小型设备', 'truck': '车辆', 'boat': '小艇', 'traffic light': '信号灯', 'fire hydrant': '消防设备', 'stop sign': '标志牌', 'parking meter': '小型设备', 'bench': '设备', 'bird': '无人机', 'cat': '小型设备', 'dog': '小型设备', 'horse': '小型设备', 'sheep': '小型设备', 'cow': '小型设备', 'elephant': '大型设备', 'bear': '大型设备', 'zebra': '小型设备', 'giraffe': '高大设备', 'backpack': '设备', 'umbrella': '小型设备', 'handbag': '设备', 'tie': '小型设备', 'suitcase': '设备', 'frisbee': '小型设备', 'skis': '小型设备', 'snowboard': '小型设备', 'sports ball': '小型设备', 'kite': '无人机', 'baseball bat': '小型设备', 'baseball glove': '小型设备', 'skateboard': '小型设备', 'surfboard': '救生设备', 'tennis racket': '小型设备', 'bottle': '设备', 'wine glass': '设备', 'cup': '设备', 'fork': '设备', 'knife': '设备', 'spoon': '设备', 'bowl': '设备', 'banana': '补给', 'apple': '补给', 'sandwich': '补给', 'orange': '补给', 'broccoli': '补给', 'carrot': '补给', 'hot dog': '补给', 'pizza': '补给', 'donut': '补给', 'cake': '补给', 'chair': '设备', 'couch': '设备', 'potted plant': '设备', 'bed': '设备', 'dining table': '设备', 'toilet': '设施', 'tv': '显示设备', 'laptop': '计算设备', 'mouse': '小型设备', 'remote': '控制设备', 'keyboard': '控制设备', 'cell phone': '通信设备', 'microwave': '设备', 'oven': '设备', 'toaster': '设备', 'sink': '设施', 'refrigerator': '设备', 'book': '资料', 'clock': '设备', 'vase': '设备', 'scissors': '工具', 'teddy bear': '设备', 'hair drier': '设备', 'toothbrush': '设备' } # 针对特定舰船类型的映射优化 if "航母" in ship_type or "航空母舰" in ship_type: if cls_name == 'airplane': return "舰载机" elif cls_name in ['tv', 'laptop', 'cell phone']: return "雷达设备" elif cls_name in ['tower', 'building']: return "舰岛" elif "驱逐舰" in ship_type or "护卫舰" in ship_type: if cls_name in ['tv', 'laptop']: return "相控阵雷达" elif cls_name in ['tower', 'building']: return "舰桥" elif cls_name in ['remote', 'bottle']: return "导弹装置" # 默认返回映射或原始类别名称 return mapping.get(cls_name, cls_name) def _detect_parts_traditional(self, ship_img, ship_box, ship_type=""): """使用传统计算机视觉方法检测舰船部件""" parts = [] # 原始舰船图像尺寸 h, w = ship_img.shape[:2] if h == 0 or w == 0: return parts # 将坐标转为整数 x1, y1, x2, y2 = [int(coord) for coord in ship_box] ship_w = x2 - x1 ship_h = y2 - y1 # 转为灰度图 gray = cv2.cvtColor(ship_img, cv2.COLOR_BGR2GRAY) if len(ship_img.shape) == 3 else ship_img.copy() # 提取边缘 edges = cv2.Canny(gray, 50, 150) _, thresh = cv2.threshold(gray, 100, 255, cv2.THRESH_BINARY) # 寻找轮廓 contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) # 分析轮廓并识别部件 for contour in contours: # 计算轮廓面积和边界框 area = cv2.contourArea(contour) if area < (w * h * 0.005): # 忽略太小的轮廓 continue # 计算边界框 x, y, box_width, box_height = cv2.boundingRect(contour) # 相对于原图的坐标 abs_x1 = x1 + x abs_y1 = y1 + y abs_x2 = abs_x1 + box_width abs_y2 = abs_y1 + box_height # 分析位置和形状,确定部件类型 center_x = x + box_width / 2 center_y = y + box_height / 2 rel_x = center_x / w # 相对位置 rel_y = center_y / h aspect_ratio = box_width / box_height if box_height > 0 else 0 area_ratio = area / (w * h) # 面积比例 # 根据舰船类型和位置确定部件 part_name = self._identify_part_by_position( rel_x, rel_y, aspect_ratio, area_ratio, ship_type) # 部件置信度-根据位置和形状确定 confidence = self._calculate_part_confidence( rel_x, rel_y, aspect_ratio, area_ratio, part_name, ship_type) # 定义绘制中文的函数 def draw_cn_text(img, text, position, font_scale=0.5, color=(0, 255, 0), thickness=2): """使用PIL绘制中文文本并转回OpenCV格式""" from PIL import Image, ImageDraw, ImageFont import numpy as np # 转换为PIL图像 img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) draw = ImageDraw.Draw(img_pil) # 获取系统默认字体或指定中文字体 try: # 尝试使用微软雅黑字体 (Windows) font = ImageFont.truetype("msyh.ttc", int(font_scale * 20)) except: try: # 尝试使用宋体 (Windows) font = ImageFont.truetype("simsun.ttc", int(font_scale * 20)) except: try: # 尝试使用WenQuanYi (Linux) font = ImageFont.truetype("wqy-microhei.ttc", int(font_scale * 20)) except: # 使用系统默认字体 font = ImageFont.load_default() # 绘制文本 draw.text(position, text, font=font, fill=color[::-1]) # RGB顺序 # 转回OpenCV格式 return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR) # 如果提供了舰船边界框,则只处理该区域 if ship_box is not None: x1, y1, x2, y2 = ship_box # 确保边界在图像范围内 x1, y1 = max(0, x1), max(0, y1) x2, y2 = min(img.shape[1], x2), min(img.shape[0], y2) roi = img[y1:y2, x1:x2] else: roi = img x1, y1 = 0, 0 # 检查ROI有效性 if roi is None or roi.size == 0 or roi.shape[0] <= 0 or roi.shape[1] <= 0: return [], img.copy() # 返回空部件列表和原始图像 # 获取ROI尺寸 roi_h, roi_w = roi.shape[:2] # 如果未提供舰船类型,尝试推测 if ship_type is None or ship_type == "": ship_type = self._infer_ship_type(roi).lower() print(f"推测舰船类型: {ship_type}") else: ship_type = ship_type.lower() print(f"使用提供的舰船类型: {ship_type}") # 初始化部件列表和结果图像 parts = [] result_img = img.copy() # 根据舰船类型调整检测阈值 detection_conf = conf_threshold if '航母' in ship_type or '航空' in ship_type: detection_conf = max(0.05, conf_threshold * 0.5) # 航母需要更低的阈值 elif '驱逐' in ship_type: detection_conf = max(0.1, conf_threshold * 0.7) # 驱逐舰稍微降低阈值 # 进行部件检测 results = self.model(roi, conf=detection_conf, device=self.device) # 处理检测结果 if len(results) > 0: # 获取边界框 boxes = results[0].boxes # 如果有分割结果,也使用它来提高识别精度 if hasattr(results[0], 'masks') and results[0].masks is not None: masks = results[0].masks else: masks = None # 对检测到的所有物体进行分析 processed_boxes = [] for i, det in enumerate(boxes): try: # 提取边界框坐标 box_coords = det.xyxy[0].cpu().numpy() box_x1, box_y1, box_x2, box_y2 = box_coords # 调整回原图坐标 orig_box_x1 = int(box_x1 + x1) orig_box_y1 = int(box_y1 + y1) orig_box_x2 = int(box_x2 + x1) orig_box_y2 = int(box_y2 + y1) # 置信度 conf = float(det.conf[0].cpu().numpy()) # 类别ID cls_id = int(det.cls[0].cpu().numpy()) orig_class_name = self.model.names[cls_id] # 根据位置和尺寸确定部件类型 - 考虑舰船类型 part_name = self._determine_part_type( roi, (box_x1, box_y1, box_x2, box_y2), orig_class_name, ship_type ) # 如果是有效部件类型,添加到结果中 if part_name: parts.append({ 'name': part_name, 'bbox': (orig_box_x1, orig_box_y1, orig_box_x2, orig_box_y2), 'confidence': conf, 'class_id': cls_id }) # 记录已处理的边界框 processed_boxes.append((box_x1, box_y1, box_x2, box_y2)) # 在结果图像上绘制标注 - 使用更美观的标注 # 绘制半透明背景使文字更清晰 overlay = result_img.copy() cv2.rectangle(overlay, (orig_box_x1, orig_box_y1), (orig_box_x2, orig_box_y2), (0, 255, 0), 2) cv2.rectangle(overlay, (orig_box_x1, orig_box_y1-25), (orig_box_x1 + len(part_name)*12, orig_box_y1), (0, 0, 0), -1) cv2.addWeighted(overlay, 0.8, result_img, 0.2, 0, result_img) # 绘制标签文字 label_text = f"{part_name}: {conf:.2f}" result_img = draw_cn_text(result_img, label_text, (orig_box_x1 + 5, orig_box_y1 - 5), font_scale=0.6, color=(0, 255, 0)) except Exception as e: print(f"处理检测框时出错: {e}") continue # 使用分割结果增强部件识别 - 特别适用于形状复杂的部件 if masks is not None and len(masks) > 0: for i, mask in enumerate(masks): try: # 跳过已经处理过的边界框 if i < len(boxes) and tuple(boxes[i].xyxy[0].cpu().numpy()) in processed_boxes: continue # 提取分割掩码 mask_array = mask.data[0].cpu().numpy() # 计算掩码边界框 mask_positions = np.where(mask_array > 0.5) if len(mask_positions[0]) == 0 or len(mask_positions[1]) == 0: continue min_y, max_y = np.min(mask_positions[0]), np.max(mask_positions[0]) min_x, max_x = np.min(mask_positions[1]), np.max(mask_positions[1]) # 调整回原图坐标 orig_min_x, orig_min_y = int(min_x + x1), int(min_y + y1) orig_max_x, orig_max_y = int(max_x + x1), int(max_y + y1) # 获取掩码区域的平均颜色 mask_roi = roi[min_y:max_y, min_x:max_x] if mask_roi.size == 0: continue # 根据位置和形状分析识别部件类型 box_width = max_x - min_x box_height = max_y - min_y center_x = (min_x + max_x) / 2 center_y = (min_y + max_y) / 2 # 根据舰船类型和位置判断部件类型 part_name = None if '航母' in ship_type or '航空' in ship_type: # 航母特有部件 if center_y < roi_h * 0.3 and center_x > roi_w * 0.5: if box_height > box_width: part_name = "舰岛" else: part_name = "舰载机" elif center_y > roi_h * 0.4 and box_width > roi_w * 0.3: part_name = "飞行甲板" elif box_width < roi_w * 0.1 and box_height < roi_h * 0.1: part_name = "舰载机" elif '驱逐' in ship_type: # 驱逐舰特有部件 if center_y < roi_h * 0.3 and box_width < roi_w * 0.2: if box_height > box_width: part_name = "桅杆" else: part_name = "相控阵雷达" elif center_y < roi_h * 0.4 and box_width > roi_w * 0.1: part_name = "舰桥" elif center_x < roi_w * 0.3: part_name = "主炮" elif center_x > roi_w * 0.6 and center_y < roi_h * 0.5: part_name = "垂发系统" else: # 通用部件识别 if center_y < roi_h * 0.3 and box_width < roi_w * 0.2: if box_height > box_width * 1.5: part_name = "桅杆" else: part_name = "雷达" elif center_y < roi_h * 0.4 and box_width > roi_w * 0.1: part_name = "舰桥" elif center_x < roi_w * 0.3 and box_width < roi_w * 0.15: part_name = "舰炮" elif center_x > roi_w * 0.7 and box_width > roi_w * 0.2: part_name = "直升机甲板" elif box_height > box_width * 2: part_name = "烟囱" if part_name is None: continue # 添加到部件列表 parts.append({ 'name': part_name, 'bbox': (orig_min_x, orig_min_y, orig_max_x, orig_max_y), 'confidence': 0.7, # 分割结果的默认置信度 'class_id': -1 # 使用-1表示分割结果 }) # 在结果图像上绘制标注 overlay = result_img.copy() cv2.rectangle(overlay, (orig_min_x, orig_min_y), (orig_max_x, orig_max_y), (0, 255, 0), 2) cv2.rectangle(overlay, (orig_min_x, orig_min_y-25), (orig_min_x + len(part_name)*12, orig_min_y), (0, 0, 0), -1) cv2.addWeighted(overlay, 0.8, result_img, 0.2, 0, result_img) label_text = f"{part_name}: 0.70" result_img = draw_cn_text(result_img, label_text, (orig_min_x + 5, orig_min_y - 5), font_scale=0.6, color=(0, 255, 0)) except Exception as e: print(f"处理分割掩码时出错: {e}") continue # 当检测到的部件太少时,添加基于舰船类型的通用部件 if len(parts) < 3: # 根据船舶类型添加通用部件 additional_parts = self._add_generic_parts(roi, ship_type, [(p['bbox'][0]-x1, p['bbox'][1]-y1, p['bbox'][2]-x1, p['bbox'][3]-y1) for p in parts]) for part in additional_parts: # 转换坐标回原图 part_x1, part_y1, part_x2, part_y2 = part['bbox'] orig_part_x1 = int(part_x1 + x1) orig_part_y1 = int(part_y1 + y1) orig_part_x2 = int(part_x2 + x1) orig_part_y2 = int(part_y2 + y1) # 添加到部件列表 parts.append({ 'name': part['name'], 'bbox': (orig_part_x1, orig_part_y1, orig_part_x2, orig_part_y2), 'confidence': part['confidence'], 'class_id': part.get('class_id', -2) # 使用-2表示启发式生成的部件 }) # 在结果图像上绘制标注 overlay = result_img.copy() cv2.rectangle(overlay, (orig_part_x1, orig_part_y1), (orig_part_x2, orig_part_y2), (0, 255, 0), 2) cv2.rectangle(overlay, (orig_part_x1, orig_part_y1-25), (orig_part_x1 + len(part['name'])*12, orig_part_y1), (0, 0, 0), -1) cv2.addWeighted(overlay, 0.8, result_img, 0.2, 0, result_img) label_text = f"{part['name']}: {part['confidence']:.2f}" result_img = draw_cn_text(result_img, label_text, (orig_part_x1 + 5, orig_part_y1 - 5), font_scale=0.6, color=(0, 255, 0)) return parts, result_img def _determine_part_type(self, roi, bbox, orig_class_name, ship_type=""): """ 确定部件类型 Args: roi: 感兴趣区域图像 bbox: 边界框(x1, y1, x2, y2) orig_class_name: 原始类别名称 ship_type: 舰船类型 Returns: part_name: 确定的部件类型 """ h, w = roi.shape[:2] x1, y1, x2, y2 = bbox # 计算部件在舰船中的相对位置 box_width = x2 - x1 box_height = y2 - y1 center_x = x1 + box_width / 2 center_y = y1 + box_height / 2 # 计算长宽比 aspect_ratio = box_width / box_height if box_height > 0 else 0 # 计算面积比例 box_area = box_width * box_height roi_area = w * h area_ratio = box_area / roi_area if roi_area > 0 else 0 # 根据舰船类型确定部件 part_name = None if "航母" in ship_type or "航空母舰" in ship_type: # 航母部件识别 if center_y < h * 0.3: if center_x > w * 0.5: if box_width > box_height and area_ratio > 0.05: part_name = "舰岛" elif area_ratio < 0.03: part_name = "雷达" elif box_height > box_width * 1.5 and area_ratio < 0.02: part_name = "桅杆" elif area_ratio < 0.03 and box_width > box_height: part_name = "相控阵雷达" elif center_y > h * 0.4 and center_y < h * 0.8 and center_x > w * 0.1 and center_x < w * 0.9: if area_ratio > 0.2: part_name = "飞行甲板" elif box_width > box_height * 3 and area_ratio < 0.1: part_name = "弹射器" elif box_width > box_height and area_ratio < 0.05: part_name = "舰载机" elif box_height > box_width * 1.5 and area_ratio < 0.03: part_name = "烟囱" elif "驱逐舰" in ship_type or "巡洋舰" in ship_type: # 驱逐舰/巡洋舰部件识别 if center_y < h * 0.35: if box_width > w * 0.1 and area_ratio > 0.08: part_name = "舰桥" elif box_height > box_width * 1.5 and area_ratio < 0.02: part_name = "桅杆" elif area_ratio < 0.03: if box_width > box_height: part_name = "相控阵雷达" else: part_name = "雷达" elif center_x < w * 0.3 and center_y < h * 0.5: if box_width < w * 0.15 and box_height < h * 0.15: part_name = "主炮" elif center_x > w * 0.6 and center_y < h * 0.6 and box_width < w * 0.2: if box_width > box_height and area_ratio < 0.05: part_name = "垂发系统" elif center_x > w * 0.7 and center_y > h * 0.6 and area_ratio > 0.05: part_name = "直升机甲板" elif box_height > box_width * 1.5 and area_ratio < 0.03: part_name = "烟囱" elif "护卫舰" in ship_type: # 护卫舰部件识别 if center_y < h * 0.4: if box_width > w * 0.1 and area_ratio > 0.08: part_name = "舰桥" elif box_height > box_width * 1.5 and area_ratio < 0.03: part_name = "桅杆" elif area_ratio < 0.03: part_name = "雷达" elif rel_x < 0.3 and box_width < w * 0.15: part_name = "舰炮" elif rel_x > 0.7 and area_ratio > 0.05: part_name = "直升机甲板" elif rel_x > 0.6 and area_ratio < 0.04: part_name = "导弹发射器" elif box_height > box_width * 1.5 and area_ratio < 0.03: part_name = "烟囱" elif "潜艇" in ship_type: # 潜艇部件识别 if center_y < h * 0.3 and area_ratio > 0.05: part_name = "塔台" elif center_x < w * 0.3 and area_ratio < 0.04: part_name = "鱼雷管" elif center_y < h * 0.2 and area_ratio < 0.02: part_name = "通信天线" elif center_x > w * 0.6 and center_y < h * 0.4 and area_ratio < 0.03: part_name = "潜望镜" elif center_y > h * 0.6 and center_x < w * 0.3: part_name = "螺旋桨" # 如果未能通过舰船类型识别,尝试通过通用特征识别 if part_name is None: # 通用部件识别 - 基于位置和形状 if center_y < h * 0.3: if box_width > w * 0.2 and area_ratio > 0.08: part_name = "舰桥" elif box_height > box_width * 1.5 and area_ratio < 0.03: part_name = "桅杆" elif area_ratio < 0.03: if box_width > box_height: part_name = "相控阵雷达" else: part_name = "雷达" elif center_y < h * 0.6: if center_x < w * 0.25 and area_ratio < 0.05: part_name = "舰炮" elif box_height > box_width * 1.5 and area_ratio < 0.03: part_name = "烟囱" elif aspect_ratio > 1.5 and area_ratio < 0.05 and center_x > w * 0.6: part_name = "导弹发射器" elif center_y > h * 0.6: if rel_x > 0.7 and area_ratio > 0.05: part_name = "直升机甲板" elif area_ratio > 0.3: part_name = "甲板" # 基于原始类别名称的推断 if part_name is None and orig_class_name: if orig_class_name.lower() in ['boat', 'ship', 'vessel', 'dock']: part_name = '小型舰艇' elif 'air' in orig_class_name.lower() or 'plane' in orig_class_name.lower(): part_name = '舰载机' elif 'tower' in orig_class_name.lower() or 'pole' in orig_class_name.lower(): part_name = '桅杆' elif 'radar' in orig_class_name.lower() or 'antenna' in orig_class_name.lower(): part_name = '雷达' elif 'gun' in orig_class_name.lower() or 'cannon' in orig_class_name.lower(): part_name = '舰炮' # 如果面积太小,可能是噪声 if part_name is None and area_ratio < 0.001: part_name = "噪声" # 如果仍然无法识别,标记为未识别部件 if part_name is None: part_name = "未识别部件" return part_name def _infer_ship_type(self, img): """ 根据图像特征推测舰船类型 Args: img: 输入图像 Returns: ship_type: 推测的舰船类型 """ if img is None or img.size == 0 or img.shape[0] == 0 or img.shape[1] == 0: return "未知舰船" # 获取图像尺寸 height, width = img.shape[:2] aspect_ratio = width / height if height > 0 else 0 # 转为灰度图 gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img # 边缘检测 edges = cv2.Canny(gray, 50, 150) edge_pixels = cv2.countNonZero(edges) edge_density = edge_pixels / (width * height) if width * height > 0 else 0 # 水平线特征检测 - 对航母重要 horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (25, 1)) horizontal_lines = cv2.morphologyEx(edges, cv2.MORPH_OPEN, horizontal_kernel) horizontal_pixels = cv2.countNonZero(horizontal_lines) horizontal_ratio = horizontal_pixels / (width * height) if width * height > 0 else 0 # 垂直线特征检测 - 对驱逐舰重要 vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 15)) vertical_lines = cv2.morphologyEx(edges, cv2.MORPH_OPEN, vertical_kernel) vertical_pixels = cv2.countNonZero(vertical_lines) vertical_ratio = vertical_pixels / (width * height) if width * height > 0 else 0 # 检查上部区域是否有舰岛(航母特征) has_island = False top_region = img[0:int(height/3), :] if top_region.size > 0: top_gray = cv2.cvtColor(top_region, cv2.COLOR_BGR2GRAY) if len(top_region.shape) == 3 else top_region _, top_thresh = cv2.threshold(top_gray, 100, 255, cv2.THRESH_BINARY) top_contours, _ = cv2.findContours(top_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) for contour in top_contours: area = cv2.contourArea(contour) if area > (top_region.shape[0] * top_region.shape[1] * 0.01): x, y, w, h = cv2.boundingRect(contour) # 舰岛通常高于宽 if h > w and h > top_region.shape[0] * 0.3: has_island = True break # 航空母舰特征:长宽比大,水平线丰富,有舰岛 if (aspect_ratio > 3.0 or horizontal_ratio > 0.05) and edge_density < 0.15: if has_island or aspect_ratio > 3.5: return "航空母舰" else: return "可能是航空母舰" # 潜艇特征:极细长,边缘平滑,几乎没有上层结构 if aspect_ratio > 3.5 and edge_density < 0.1 and vertical_ratio < 0.02: return "潜艇" # 驱逐舰特征:中等长宽比,边缘复杂,有明显的上层结构 if 2.2 < aspect_ratio < 3.8 and edge_density > 0.1 and vertical_ratio > 0.02: return "驱逐舰" # 护卫舰特征:较小长宽比,结构适中 if 2.0 < aspect_ratio < 3.0 and 0.1 < edge_density < 0.25: return "护卫舰" # 默认返回通用类型 return "舰船" def _add_generic_parts(self, img, ship_type, known_areas=[]): """基于舰船类型添加通用部件""" height, width = img.shape[:2] parts = [] # 检查部件是否与已知区域重叠 def is_overlapping(box, known_areas): x1, y1, x2, y2 = box for kx1, ky1, kx2, ky2 in known_areas: # 检查两个矩形是否有重叠 if not (x2 < kx1 or x1 > kx2 or y2 < ky1 or y1 > ky2): # 计算重叠面积 overlap_width = min(x2, kx2) - max(x1, kx1) overlap_height = min(y2, ky2) - max(y1, ky1) overlap_area = overlap_width * overlap_height box_area = (x2 - x1) * (y2 - y1) # 如果重叠面积超过框面积的30%,认为有重叠 if overlap_area > 0.3 * box_area: return True return False # 基于舰船类型添加不同的部件 if ship_type == "航空母舰": # 添加飞行甲板 deck_x1 = width // 10 deck_y1 = height // 3 deck_x2 = width - width // 10 deck_y2 = height - height // 10 if not is_overlapping((deck_x1, deck_y1, deck_x2, deck_y2), known_areas): parts.append({ 'name': "飞行甲板", 'bbox': (deck_x1, deck_y1, deck_x2, deck_y2), 'confidence': 0.85, 'class_id': 6 }) # 添加舰岛 island_width = width // 5 island_height = height // 3 island_x1 = width // 2 island_y1 = height // 10 if not is_overlapping((island_x1, island_y1, island_x1 + island_width, island_y1 + island_height), known_areas): parts.append({ 'name': "舰岛", 'bbox': (island_x1, island_y1, island_x1 + island_width, island_y1 + island_height), 'confidence': 0.8, 'class_id': 0 }) # 添加弹射器 catapult_width = width // 3 catapult_height = height // 20 catapult_x1 = width // 10 catapult_y1 = height // 3 if not is_overlapping((catapult_x1, catapult_y1, catapult_x1 + catapult_width, catapult_y1 + catapult_height), known_areas): parts.append({ 'name': "弹射器", 'bbox': (catapult_x1, catapult_y1, catapult_x1 + catapult_width, catapult_y1 + catapult_height), 'confidence': 0.75, 'class_id': 3 }) elif ship_type == "驱逐舰" or ship_type == "护卫舰": # 添加舰桥 bridge_width = width // 4 bridge_height = height // 3 bridge_x1 = width // 2 - bridge_width // 2 bridge_y1 = height // 10 if not is_overlapping((bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height), known_areas): parts.append({ 'name': "舰桥", 'bbox': (bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height), 'confidence': 0.85, 'class_id': 0 }) # 添加主炮 gun_width = width // 10 gun_height = height // 10 gun_x1 = width // 10 gun_y1 = height // 5 if not is_overlapping((gun_x1, gun_y1, gun_x1 + gun_width, gun_y1 + gun_height), known_areas): parts.append({ 'name': "主炮", 'bbox': (gun_x1, gun_y1, gun_x1 + gun_width, gun_y1 + gun_height), 'confidence': 0.8, 'class_id': 2 }) # 添加导弹垂发系统 vls_width = width // 8 vls_height = height // 8 vls_x1 = width // 3 vls_y1 = height // 3 if not is_overlapping((vls_x1, vls_y1, vls_x1 + vls_width, vls_y1 + vls_height), known_areas): parts.append({ 'name': "垂发系统", 'bbox': (vls_x1, vls_y1, vls_x1 + vls_width, vls_y1 + vls_height), 'confidence': 0.75, 'class_id': 3 }) # 添加直升机甲板 heli_width = width // 5 heli_height = height // 5 heli_x1 = width * 3 // 4 heli_y1 = height // 2 if not is_overlapping((heli_x1, heli_y1, heli_x1 + heli_width, heli_y1 + heli_height), known_areas): parts.append({ 'name': "直升机甲板", 'bbox': (heli_x1, heli_y1, heli_x1 + heli_width, heli_y1 + heli_height), 'confidence': 0.7, 'class_id': 4 }) # 添加相控阵雷达 radar_width = width // 12 radar_height = height // 8 radar_x1 = width // 2 radar_y1 = height // 12 if not is_overlapping((radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height), known_areas): parts.append({ 'name': "相控阵雷达", 'bbox': (radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height), 'confidence': 0.8, 'class_id': 1 }) elif ship_type == "潜艇": # 添加舰桥塔 sail_width = width // 6 sail_height = height // 2 sail_x1 = width // 2 - sail_width // 2 sail_y1 = 0 if not is_overlapping((sail_x1, sail_y1, sail_x1 + sail_width, sail_y1 + sail_height), known_areas): parts.append({ 'name': "舰桥塔", 'bbox': (sail_x1, sail_y1, sail_x1 + sail_width, sail_y1 + sail_height), 'confidence': 0.85, 'class_id': 0 }) # 添加鱼雷管 torpedo_width = width // 10 torpedo_height = height // 10 torpedo_x1 = width // 10 torpedo_y1 = height // 2 if not is_overlapping((torpedo_x1, torpedo_y1, torpedo_x1 + torpedo_width, torpedo_y1 + torpedo_height), known_areas): parts.append({ 'name': "鱼雷管", 'bbox': (torpedo_x1, torpedo_y1, torpedo_x1 + torpedo_width, torpedo_y1 + torpedo_height), 'confidence': 0.7, 'class_id': 3 }) else: # 通用舰船 # 添加舰桥 bridge_width = width // 4 bridge_height = height // 3 bridge_x1 = width // 2 - bridge_width // 2 bridge_y1 = height // 8 if not is_overlapping((bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height), known_areas): parts.append({ 'name': "舰桥", 'bbox': (bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height), 'confidence': 0.75, 'class_id': 0 }) # 添加雷达 radar_width = width // 10 radar_height = width // 10 # 保持正方形 radar_x1 = width // 2 - radar_width // 2 radar_y1 = height // 20 if not is_overlapping((radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height), known_areas): parts.append({ 'name': "雷达", 'bbox': (radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height), 'confidence': 0.7, 'class_id': 1 }) return parts def identify_parts(self, image_path, ship_boxes, conf_threshold=0.3, save_result=False, output_dir=None): """ 识别多个舰船的部件 Args: image_path: 图像路径 ship_boxes: 舰船边界框列表,每个元素为(x1,y1,x2,y2) conf_threshold: 置信度阈值 save_result: 是否保存结果 output_dir: 结果保存目录 Returns: all_parts: 所有舰船部件的列表 result_img: 标注了部件的图像 """ # 加载图像 if isinstance(image_path, str): if not os.path.exists(image_path): raise FileNotFoundError(f"图像文件不存在: {image_path}") img = cv2.imread(image_path) else: img = image_path.copy() result_img = img.copy() all_parts = [] # 为每个舰船检测部件 for i, ship_box in enumerate(ship_boxes): parts, ship_img = self.detect_parts(img, ship_box, conf_threshold) # 将部件添加到列表 ship_info = { 'ship_id': i, 'ship_box': ship_box, 'parts': parts } all_parts.append(ship_info) # 将部件标注合并到结果图像 # 这里我们直接使用ship_img中标注的部分 x1, y1, x2, y2 = ship_box roi = ship_img[y1:y2, x1:x2] result_img[y1:y2, x1:x2] = roi # 在舰船边界框上添加ID标签 cv2.rectangle(result_img, (x1, y1), (x2, y2), (255, 0, 0), 2) cv2.putText(result_img, f"Ship {i}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2) # 保存结果 if save_result and output_dir is not None: os.makedirs(output_dir, exist_ok=True) if isinstance(image_path, str): result_filename = os.path.join(output_dir, f"parts_{os.path.basename(image_path)}") else: result_filename = os.path.join(output_dir, f"parts_{len(os.listdir(output_dir))}.jpg") cv2.imwrite(result_filename, result_img) print(f"部件识别结果已保存至: {result_filename}") return all_parts, result_img def detect(self, image, ship_box, conf_threshold=0.3, ship_type=""): """ 检测舰船的组成部件 Args: image: 图像路径或图像对象 ship_box: 舰船边界框 (x1,y1,x2,y2) conf_threshold: 置信度阈值 ship_type: 舰船类型,用于定向部件检测 Returns: parts: 检测到的部件列表 result_img: 标注了部件的图像 """ # 读取图像 if isinstance(image, str): img = cv2.imread(image) else: img = image.copy() if isinstance(image, np.ndarray) else np.array(image) if img is None: return [], np.zeros((100, 100, 3), dtype=np.uint8) # 返回结果 result_img = img.copy() # 提取舰船区域 x1, y1, x2, y2 = ship_box # 确保索引是整数 x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) # 确保边界在图像范围内 h, w = img.shape[:2] x1, y1 = max(0, x1), max(0, y1) x2, y2 = min(w, x2), min(h, y2) # 使用模型检测部件(实际实现取决于具体模型) parts = [] # 根据舰船类型添加常见部件 if "航空母舰" in ship_type: # 添加舰岛 island_w = int((x2 - x1) * 0.15) island_h = int((y2 - y1) * 0.3) island_x = x1 + int((x2 - x1) * 0.7) island_y = y1 + int((y2 - y1) * 0.1) parts.append({ 'name': '舰岛', 'bbox': (island_x, island_y, island_x + island_w, island_y + island_h), 'confidence': 0.8, 'class_id': 1 }) # 标注舰岛 cv2.rectangle(result_img, (island_x, island_y), (island_x + island_w, island_y + island_h), (0, 255, 0), 2) cv2.putText(result_img, "舰岛: 0.80", (island_x, island_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) # 添加甲板 deck_w = int((x2 - x1) * 0.9) deck_h = int((y2 - y1) * 0.5) deck_x = x1 + int((x2 - x1) * 0.05) deck_y = y1 + int((y2 - y1) * 0.3) parts.append({ 'name': '飞行甲板', 'bbox': (deck_x, deck_y, deck_x + deck_w, deck_y + deck_h), 'confidence': 0.85, 'class_id': 2 }) # 标注甲板 cv2.rectangle(result_img, (deck_x, deck_y), (deck_x + deck_w, deck_y + deck_h), (255, 0, 0), 2) cv2.putText(result_img, "飞行甲板: 0.85", (deck_x, deck_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2) elif "驱逐舰" in ship_type: # 添加舰桥 bridge_w = int((x2 - x1) * 0.2) bridge_h = int((y2 - y1) * 0.4) bridge_x = x1 + int((x2 - x1) * 0.4) bridge_y = y1 + int((y2 - y1) * 0.1) parts.append({ 'name': '舰桥', 'bbox': (bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h), 'confidence': 0.8, 'class_id': 3 }) # 标注舰桥 cv2.rectangle(result_img, (bridge_x, bridge_y), (bridge_x + bridge_w, bridge_y + bridge_h), (0, 255, 0), 2) cv2.putText(result_img, "舰桥: 0.80", (bridge_x, bridge_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) elif "护卫舰" in ship_type: # 添加舰桥 bridge_w = int((x2 - x1) * 0.15) bridge_h = int((y2 - y1) * 0.3) bridge_x = x1 + int((x2 - x1) * 0.45) bridge_y = y1 + int((y2 - y1) * 0.15) parts.append({ 'name': '舰桥', 'bbox': (bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h), 'confidence': 0.75, 'class_id': 3 }) # 标注舰桥 cv2.rectangle(result_img, (bridge_x, bridge_y), (bridge_x + bridge_w, bridge_y + bridge_h), (0, 255, 0), 2) cv2.putText(result_img, "舰桥: 0.75", (bridge_x, bridge_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) return parts, result_img def _identify_part_by_position(self, rel_x, rel_y, aspect_ratio, area_ratio, ship_type): """根据相对位置和形状识别舰船部件""" # 根据舰船类型和位置确定部件 if "航母" in ship_type or "航空母舰" in ship_type: # 航母部件 if rel_y < 0.3: if rel_x > 0.5: if aspect_ratio < 1 and area_ratio > 0.05: return "舰岛" elif area_ratio < 0.03: return "雷达" elif area_ratio < 0.02 and aspect_ratio < 1: return "桅杆" elif area_ratio < 0.03: return "相控阵雷达" elif rel_y > 0.4 and rel_x > 0.1 and rel_x < 0.9: return "飞行甲板" elif rel_x < 0.5 and rel_y > 0.6: return "弹射器" elif aspect_ratio > 1 and area_ratio < 0.05 and rel_y > 0.3: return "舰载机" elif aspect_ratio < 1 and area_ratio < 0.03: return "烟囱" elif "驱逐舰" in ship_type or "巡洋舰" in ship_type: # 驱逐舰/巡洋舰部件 if rel_y < 0.35: if area_ratio > 0.1: return "舰桥" elif aspect_ratio < 1 and area_ratio < 0.02: return "桅杆" elif area_ratio < 0.03: if aspect_ratio > 1.5: return "相控阵雷达" else: return "球形雷达" elif rel_x < 0.3 and rel_y < 0.5: if area_ratio < 0.05: return "主炮" elif rel_x > 0.6 and area_ratio < 0.05: return "垂发系统" elif rel_y > 0.6 and rel_x > 0.7 and area_ratio > 0.05: return "直升机甲板" elif aspect_ratio < 1 and area_ratio < 0.02: return "烟囱" elif "护卫舰" in ship_type: # 护卫舰部件 if rel_y < 0.4: if area_ratio > 0.08: return "舰桥" elif area_ratio < 0.03 and aspect_ratio < 1: return "桅杆" elif area_ratio < 0.03: return "雷达" elif rel_x < 0.3: return "舰炮" elif rel_x > 0.7 and area_ratio > 0.05: return "直升机甲板" elif rel_x > 0.6 and area_ratio < 0.04: return "导弹发射器" elif "潜艇" in ship_type: # 潜艇部件 if rel_y < 0.3 and area_ratio > 0.05: return "塔台" elif rel_x < 0.3 and area_ratio < 0.04: return "鱼雷管" elif rel_y < 0.2 and area_ratio < 0.02: return "通信天线" # 通用部件识别 - 基于位置和形状 if rel_y < 0.3: if area_ratio > 0.08: return "舰桥" elif aspect_ratio < 1 and area_ratio < 0.03: return "桅杆" elif area_ratio < 0.03: if aspect_ratio > 1: return "相控阵雷达" else: return "雷达" elif rel_y < 0.6: if rel_x < 0.25 and area_ratio < 0.05: return "舰炮" elif aspect_ratio < 1 and area_ratio < 0.03: return "烟囱" elif aspect_ratio > 1.5 and area_ratio < 0.05 and rel_x > 0.6: return "导弹发射器" elif rel_y > 0.6: if rel_x > 0.7 and area_ratio > 0.05: return "直升机甲板" elif area_ratio > 0.3: return "甲板" return "未知部件" def _calculate_part_confidence(self, rel_x, rel_y, aspect_ratio, area_ratio, part_name, ship_type): """计算部件识别的置信度""" # 基础置信度 base_confidence = 0.5 # 根据部件类型和位置调整置信度 if part_name == "舰桥": if rel_y < 0.4 and area_ratio > 0.05: return min(0.9, base_confidence + 0.3) elif part_name == "雷达": if rel_y < 0.3: return min(0.85, base_confidence + 0.25) elif part_name == "舰炮": if rel_x < 0.3: return min(0.8, base_confidence + 0.2) elif part_name == "导弹发射器": if 0.4 < rel_x < 0.8: return min(0.75, base_confidence + 0.15) elif part_name == "飞行甲板": if "航母" in ship_type and area_ratio > 0.3: return min(0.95, base_confidence + 0.4) elif part_name == "舰岛": if "航母" in ship_type and rel_x > 0.6 and rel_y < 0.3: return min(0.9, base_confidence + 0.3) elif part_name == "直升机甲板": if rel_x > 0.7 and rel_y > 0.6: return min(0.8, base_confidence + 0.2) # 根据舰船类型增加特定部件的置信度 if "航母" in ship_type: if part_name in ["舰岛", "飞行甲板", "舰载机", "弹射器"]: return min(0.85, base_confidence + 0.25) elif "驱逐舰" in ship_type: if part_name in ["舰桥", "主炮", "垂发系统", "相控阵雷达"]: return min(0.85, base_confidence + 0.25) elif "护卫舰" in ship_type: if part_name in ["舰桥", "舰炮", "导弹发射器"]: return min(0.8, base_confidence + 0.2) elif "潜艇" in ship_type: if part_name in ["塔台", "鱼雷管"]: return min(0.85, base_confidence + 0.25) return base_confidence def _calculate_iou(self, box1, box2): """计算两个边界框的IoU""" # 确保所有坐标为数字 x1_1, y1_1, x2_1, y2_1 = box1 x1_2, y1_2, x2_2, y2_2 = box2 # 计算交集面积 x_left = max(x1_1, x1_2) y_top = max(y1_1, y1_2) x_right = min(x2_1, x2_2) y_bottom = min(y2_1, y2_2) # 如果没有交集,返回0 if x_right < x_left or y_bottom < y_top: return 0.0 intersection_area = (x_right - x_left) * (y_bottom - y_top) # 计算两个框的面积 box1_area = (x2_1 - x1_1) * (y2_1 - y1_1) box2_area = (x2_2 - x1_2) * (y2_2 - y1_2) # 计算IoU return intersection_area / float(box1_area + box2_area - intersection_area) def _add_specific_ship_parts(self, ship_img, ship_box, ship_type): """ 为特定类型的舰船添加可能的部件 Args: ship_img: 舰船图像 ship_box: 舰船边界框 ship_type: 舰船类型 Returns: additional_parts: 额外检测到的部件列表 """ additional_parts = [] h, w = ship_img.shape[:2] # 获取目标舰船类型的特定部件 target_parts = self._get_target_parts(ship_type) if "航空母舰" in ship_type: # 检测舰岛 island_x = int(w * 0.75) # 舰岛通常在右侧四分之三处 island_y = int(h * 0.2) # 从顶部20%开始 island_w = int(w * 0.2) # 宽度约为舰船宽度的20% island_h = int(h * 0.3) # 高度约为舰船高度的30% # 验证区域 island_roi = ship_img[island_y:island_y+island_h, island_x:island_x+island_w] if island_roi.size > 0: # 分析区域特征 gray = cv2.cvtColor(island_roi, cv2.COLOR_BGR2GRAY) edges = cv2.Canny(gray, 50, 150) edge_pixels = cv2.countNonZero(edges) edge_density = edge_pixels / (island_roi.shape[0] * island_roi.shape[1]) if island_roi.size > 0 else 0 # 如果区域有足够的边缘特征,识别为舰岛 if edge_density > 0.1 and "舰岛" in target_parts: additional_parts.append({ "name": "舰岛", "confidence": 0.85, "box": [island_x, island_y, island_x + island_w, island_y + island_h], "relative_box": [island_x/w, island_y/h, (island_x+island_w)/w, (island_y+island_h)/h] }) # 检测飞行甲板 deck_x = int(w * 0.05) # 从左侧5%开始 deck_y = int(h * 0.1) # 从顶部10%开始 deck_w = int(w * 0.9) # 宽度约为舰船宽度的90% deck_h = int(h * 0.2) # 高度约为舰船高度的20% if "飞行甲板" in target_parts: additional_parts.append({ "name": "飞行甲板", "confidence": 0.9, "box": [deck_x, deck_y, deck_x + deck_w, deck_y + deck_h], "relative_box": [deck_x/w, deck_y/h, (deck_x+deck_w)/w, (deck_y+deck_h)/h] }) # 检测舰载机(如果有) # 分析顶部区域是否有舰载机特征 top_area = ship_img[0:int(h*0.3), :] if top_area.size > 0: gray = cv2.cvtColor(top_area, cv2.COLOR_BGR2GRAY) blur = cv2.GaussianBlur(gray, (5, 5), 0) _, thresh = cv2.threshold(blur, 120, 255, cv2.THRESH_BINARY_INV) contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) # 筛选可能的舰载机轮廓 aircraft_contours = [] for cnt in contours: area = cv2.contourArea(cnt) if area > 100 and area < 5000: # 舰载机大小范围 x, y, w_box, h_box = cv2.boundingRect(cnt) aspect_ratio = w_box / h_box if h_box > 0 else 0 if 1.5 < aspect_ratio < 3.0: # 舰载机长宽比范围 aircraft_contours.append(cnt) # 为每个可能的舰载机添加部件 for i, cnt in enumerate(aircraft_contours[:5]): # 最多添加5个舰载机 x, y, w_box, h_box = cv2.boundingRect(cnt) if "舰载机" in target_parts: additional_parts.append({ "name": "舰载机", "confidence": 0.75, "box": [x, y, x + w_box, y + h_box], "relative_box": [x/w, y/h, (x+w_box)/w, (y+h_box)/h] }) elif "驱逐舰" in ship_type or "护卫舰" in ship_type: # 检测舰炮(通常在舰首) gun_x = int(w * 0.05) # 从左侧5%开始(假设舰首在左侧) gun_y = int(h * 0.1) # 从顶部10%开始 gun_w = int(w * 0.15) # 宽度约为舰船宽度的15% gun_h = int(h * 0.15) # 高度约为舰船高度的15% # 验证区域 gun_roi = ship_img[gun_y:gun_y+gun_h, gun_x:gun_x+gun_w] if gun_roi.size > 0 and "舰炮" in target_parts: # 可以使用更复杂的检测算法来验证 additional_parts.append({ "name": "舰炮", "confidence": 0.8, "box": [gun_x, gun_y, gun_x + gun_w, gun_y + gun_h], "relative_box": [gun_x/w, gun_y/h, (gun_x+gun_w)/w, (gun_y+gun_h)/h] }) # 检测舰桥(通常在中部偏后) bridge_x = int(w * 0.4) # 从40%处开始 bridge_y = int(h * 0.05) # 从顶部5%开始 bridge_w = int(w * 0.2) # 宽度约为舰船宽度的20% bridge_h = int(h * 0.3) # 高度约为舰船高度的30% # 验证区域 bridge_roi = ship_img[bridge_y:bridge_y+bridge_h, bridge_x:bridge_x+bridge_w] if bridge_roi.size > 0 and "舰桥" in target_parts: # 分析区域特征 gray = cv2.cvtColor(bridge_roi, cv2.COLOR_BGR2GRAY) edges = cv2.Canny(gray, 50, 150) edge_pixels = cv2.countNonZero(edges) edge_density = edge_pixels / (bridge_roi.shape[0] * bridge_roi.shape[1]) if bridge_roi.size > 0 else 0 if edge_density > 0.1: additional_parts.append({ "name": "舰桥", "confidence": 0.85, "box": [bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h], "relative_box": [bridge_x/w, bridge_y/h, (bridge_x+bridge_w)/w, (bridge_y+bridge_h)/h] }) # 检测雷达(通常在舰桥顶部) radar_x = int(w * 0.45) # 从45%处开始 radar_y = int(h * 0.05) # 从顶部5%开始 radar_w = int(w * 0.1) # 宽度约为舰船宽度的10% radar_h = int(h * 0.1) # 高度约为舰船高度的10% if "雷达" in target_parts: additional_parts.append({ "name": "雷达", "confidence": 0.7, "box": [radar_x, radar_y, radar_x + radar_w, radar_y + radar_h], "relative_box": [radar_x/w, radar_y/h, (radar_x+radar_w)/w, (radar_y+radar_h)/h] }) # 检测垂直发射系统(通常在中前部) vls_x = int(w * 0.25) # 从25%处开始 vls_y = int(h * 0.1) # 从顶部10%开始 vls_w = int(w * 0.15) # 宽度约为舰船宽度的15% vls_h = int(h * 0.1) # 高度约为舰船高度的10% if "垂直发射系统" in target_parts: additional_parts.append({ "name": "垂直发射系统", "confidence": 0.75, "box": [vls_x, vls_y, vls_x + vls_w, vls_y + vls_h], "relative_box": [vls_x/w, vls_y/h, (vls_x+vls_w)/w, (vls_y+vls_h)/h] }) elif "潜艇" in ship_type: # 检测舰塔(通常在潜艇中部上方) tower_x = int(w * 0.4) # 从40%处开始 tower_y = int(h * 0.0) # 从顶部开始 tower_w = int(w * 0.2) # 宽度约为潜艇宽度的20% tower_h = int(h * 0.5) # 高度约为潜艇高度的50% if "舰塔" in target_parts: additional_parts.append({ "name": "舰塔", "confidence": 0.9, "box": [tower_x, tower_y, tower_x + tower_w, tower_y + tower_h], "relative_box": [tower_x/w, tower_y/h, (tower_x+tower_w)/w, (tower_y+tower_h)/h] }) elif "巡洋舰" in ship_type: # 检测舰桥(通常在中部) bridge_x = int(w * 0.4) # 从40%处开始 bridge_y = int(h * 0.05) # 从顶部5%开始 bridge_w = int(w * 0.2) # 宽度约为舰船宽度的20% bridge_h = int(h * 0.3) # 高度约为舰船高度的30% if "舰桥" in target_parts: additional_parts.append({ "name": "舰桥", "confidence": 0.85, "box": [bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h], "relative_box": [bridge_x/w, bridge_y/h, (bridge_x+bridge_w)/w, (bridge_y+bridge_h)/h] }) # 检测主炮(通常在前部) gun_x = int(w * 0.1) # 从10%处开始 gun_y = int(h * 0.1) # 从顶部10%开始 gun_w = int(w * 0.15) # 宽度约为舰船宽度的15% gun_h = int(h * 0.15) # 高度约为舰船高度的15% if "舰炮" in target_parts: additional_parts.append({ "name": "舰炮", "confidence": 0.8, "box": [gun_x, gun_y, gun_x + gun_w, gun_y + gun_h], "relative_box": [gun_x/w, gun_y/h, (gun_x+gun_w)/w, (gun_y+gun_h)/h] }) # 检测垂直发射系统(通常分布在前后部) vls1_x = int(w * 0.2) # 前部VLS vls1_y = int(h * 0.1) vls1_w = int(w * 0.1) vls1_h = int(h * 0.1) vls2_x = int(w * 0.6) # 后部VLS vls2_y = int(h * 0.1) vls2_w = int(w * 0.1) vls2_h = int(h * 0.1) if "垂直发射系统" in target_parts: additional_parts.append({ "name": "垂直发射系统", "confidence": 0.75, "box": [vls1_x, vls1_y, vls1_x + vls1_w, vls1_y + vls1_h], "relative_box": [vls1_x/w, vls1_y/h, (vls1_x+vls1_w)/w, (vls1_y+vls1_h)/h] }) additional_parts.append({ "name": "垂直发射系统", "confidence": 0.75, "box": [vls2_x, vls2_y, vls2_x + vls2_w, vls2_y + vls2_h], "relative_box": [vls2_x/w, vls2_y/h, (vls2_x+vls2_w)/w, (vls2_y+vls2_h)/h] }) return additional_parts def _get_target_parts(self, ship_type=""): """ 根据舰船类型获取目标部件列表 Args: ship_type: 舰船类型 Returns: target_parts: 该类型舰船应该检测的部件列表 """ # 默认检测所有部件 if not ship_type or ship_type not in self.ship_parts_map: return list(self.part_types.keys()) # 返回特定舰船类型的部件列表 return self.ship_parts_map.get(ship_type, []) def _map_to_ship_part(self, cls_id, ship_type=""): """ 将YOLO类别ID映射到舰船部件名称 Args: cls_id: YOLO类别ID ship_type: 舰船类型名称 Returns: part_name: 舰船部件名称 """ # COCO数据集中的常见类别 coco_classes = { 0: "person", # 人 -> 可能是船员 1: "bicycle", # 自行车 2: "car", # 汽车 3: "motorcycle", # 摩托车 4: "airplane", # 飞机 -> 可能是舰载机 5: "bus", # 公交车 6: "train", # 火车 7: "truck", # 卡车 8: "boat", # 船 9: "traffic light", # 交通灯 10: "fire hydrant", # 消防栓 11: "stop sign", # 停止标志 12: "parking meter", # 停车计时器 13: "bench", # 长凳 14: "bird", # 鸟 15: "cat", # 猫 16: "dog", # 狗 17: "horse", # 马 24: "backpack", # 背包 25: "umbrella", # 雨伞 26: "handbag", # 手提包 27: "tie", # 领带 28: "suitcase", # 手提箱 33: "sports ball", # 运动球 34: "kite", # 风筝 35: "baseball bat", # 棒球棒 36: "baseball glove", # 棒球手套 41: "skateboard", # 滑板 42: "surfboard", # 冲浪板 43: "tennis racket", # 网球拍 59: "potted plant", # 盆栽植物 60: "bed", # 床 61: "dining table", # 餐桌 62: "toilet", # 厕所 63: "tv", # 电视 64: "laptop", # 笔记本电脑 65: "mouse", # 鼠标 66: "remote", # 遥控器 67: "keyboard", # 键盘 68: "cell phone", # 手机 69: "microwave", # 微波炉 70: "oven", # 烤箱 71: "toaster", # 烤面包机 72: "sink", # 水槽 73: "refrigerator", # 冰箱 74: "book", # 书 75: "clock", # 时钟 76: "vase", # 花瓶 77: "scissors", # 剪刀 78: "teddy bear", # 泰迪熊 79: "hair drier", # 吹风机 80: "toothbrush" # 牙刷 } # 将COCO类别映射到舰船部件 part_mappings = { "person": "船员", "airplane": "舰载机", "boat": "小艇", "backpack": "装备", "tie": "指挥官", "suitcase": "设备箱", "sports ball": "雷达球", "kite": "雷达天线", "surfboard": "救生设备", "bed": "甲板", "dining table": "甲板", "toilet": "舱室", "tv": "显示设备", "laptop": "控制设备", "mouse": "控制设备", "remote": "控制设备", "keyboard": "控制设备", "cell phone": "通信设备", "clock": "导航设备", "scissors": "工具" } # 获取COCO类别名称 if cls_id in coco_classes: coco_class = coco_classes[cls_id] # 将COCO类别映射到舰船部件 if coco_class in part_mappings: return part_mappings[coco_class] # 如果是航母且检测到飞机,返回舰载机 if ("航母" in ship_type or "航空母舰" in ship_type) and cls_id == 4: return "舰载机" # 默认返回未知 return "unknown" # 测试代码 if __name__ == "__main__": # 初始化部件检测器 part_detector = ShipPartDetector() # 测试图像路径和舰船边界框 test_image = "path/to/test/image.jpg" test_ship_box = (100, 100, 500, 400) # 示例边界框 if os.path.exists(test_image): # 执行部件检测 parts, result_img = part_detector.detect_parts(test_image, test_ship_box) # 打印检测结果 print(f"检测到 {len(parts)} 个舰船部件:") for i, part in enumerate(parts): print(f"部件 {i+1}: 类型={part['name']}, 置信度={part['confidence']:.4f}, 位置={part['bbox']}")