You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Software_Architecture/distance-judgement/src/drone/utils/part_detector_final.py

1996 lines
88 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import sys
import torch
import numpy as np
import cv2
from ultralytics import YOLO
from pathlib import Path
class ShipPartDetector:
"""
舰船部件检测模块,负责识别舰船的各个组成部分
支持检测的部件包括:舰桥、雷达、舰炮、导弹发射装置、直升机甲板等
"""
def __init__(self, model_path=None, device=None):
"""
初始化舰船部件检测器
Args:
model_path: 部件检测模型路径如果为None则使用预训练通用模型后处理
device: 运行设备,可以是'cuda''cpu'None则自动选择
"""
# 设置模型路径
if model_path is None:
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
model_path = os.path.join(parent_dir, 'yolov8s-seg.pt')
self.model_path = model_path
# 设置设备
if device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self.device = device
# 舰船部件类型映射表 - 扩展版本
self.part_types = {
# 航空母舰特有部件
"舰岛": {"中文名": "舰岛", "英文名": "Island", "描述": "航母上的指挥塔和控制中心"},
"飞行甲板": {"中文名": "飞行甲板", "英文名": "Flight Deck", "描述": "舰载机起降的平台"},
"升降机": {"中文名": "升降机", "英文名": "Elevator", "描述": "将舰载机运送到飞行甲板的电梯"},
"弹射器": {"中文名": "弹射器", "英文名": "Catapult", "描述": "协助舰载机起飞的装置"},
"阻拦索": {"中文名": "阻拦索", "英文名": "Arresting Wire", "描述": "帮助舰载机着陆减速的钢缆"},
"舰载机": {"中文名": "舰载机", "英文名": "Aircraft", "描述": "部署在航母上的飞机"},
# 驱逐舰特有部件
"舰炮": {"中文名": "舰炮", "英文名": "Naval Gun", "描述": "用于对海/对陆/对空射击的火炮"},
"垂直发射系统": {"中文名": "垂直发射系统", "英文名": "VLS", "描述": "用于发射导弹的垂直发射装置"},
"直升机平台": {"中文名": "直升机平台", "英文名": "Helicopter Deck", "描述": "供直升机起降的平台"},
"鱼雷发射管": {"中文名": "鱼雷发射管", "英文名": "Torpedo Tubes", "描述": "用于发射鱼雷的装置"},
# 通用部件
"舰桥": {"中文名": "舰桥", "英文名": "Bridge", "描述": "舰船的指挥控制中心"},
"雷达": {"中文名": "雷达", "英文名": "Radar", "描述": "探测目标的电子设备"},
"通信天线": {"中文名": "通信天线", "英文名": "Communication Antenna", "描述": "用于通信的天线装置"},
"烟囱": {"中文名": "烟囱", "英文名": "Funnel", "描述": "排放发动机废气的烟囱"},
"近防武器系统": {"中文名": "近防武器系统", "英文名": "CIWS", "描述": "防御导弹和飞机的近程武器系统"},
"救生艇": {"中文名": "救生艇", "英文名": "Lifeboat", "描述": "紧急情况下用于撤离的小艇"},
"": {"中文名": "", "英文名": "Anchor", "描述": "固定舰船位置的装置"},
"甲板": {"中文名": "甲板", "英文名": "Deck", "描述": "舰船的水平表面"},
"舷窗": {"中文名": "舷窗", "英文名": "Porthole", "描述": "舰船侧面的窗户"},
"机库": {"中文名": "机库", "英文名": "Hangar", "描述": "存放飞机或直升机的区域"},
"装甲板": {"中文名": "装甲板", "英文名": "Armored Deck", "描述": "加强保护的甲板"},
"探照灯": {"中文名": "探照灯", "英文名": "Searchlight", "描述": "用于夜间照明的强光灯"},
"声呐": {"中文名": "声呐", "英文名": "Sonar", "描述": "水下探测设备"},
"导弹发射器": {"中文名": "导弹发射器", "英文名": "Missile Launcher", "描述": "发射导弹的装置"},
"防空导弹": {"中文名": "防空导弹", "英文名": "Anti-air Missile", "描述": "用于防空的导弹系统"},
"反舰导弹": {"中文名": "反舰导弹", "英文名": "Anti-ship Missile", "描述": "用于攻击舰船的导弹"},
"电子战设备": {"中文名": "电子战设备", "英文名": "Electronic Warfare Equipment", "描述": "用于电子干扰和反干扰的设备"}
}
# 加载模型
try:
from ultralytics import YOLO
self.model = YOLO(self.model_path)
self.model_loaded = True
print(f"成功加载部件检测模型: {self.model_path}")
except Exception as e:
print(f"加载部件检测模型出错: {e}")
self.model = None
self.model_loaded = False
# 舰船类型与特有部件映射
self.ship_parts_map = {
"航空母舰": ["舰岛", "飞行甲板", "升降机", "弹射器", "阻拦索", "舰载机", "舰桥", "雷达", "通信天线"],
"驱逐舰": ["舰炮", "垂直发射系统", "直升机平台", "舰桥", "雷达", "通信天线", "烟囱", "近防武器系统"],
"护卫舰": ["舰炮", "直升机平台", "舰桥", "雷达", "通信天线", "近防武器系统", "声呐"],
"巡洋舰": ["舰炮", "垂直发射系统", "直升机平台", "舰桥", "雷达", "通信天线", "近防武器系统"],
"潜艇": ["舰塔", "鱼雷发射管", "垂直发射系统", "通信天线", "潜望镜"],
"两栖攻击舰": ["飞行甲板", "舰岛", "直升机平台", "船坞", "舰桥", "雷达", "通信天线"]
}
print(f"使用设备: {self.device}")
# 启用通用部件检测
self.enable_generic_parts = True
# 部件类别映射 (这里的ID应该对应训练好的模型中的类别ID)
self.part_names = {
0: "舰桥",
1: "雷达",
2: "舰炮",
3: "导弹发射装置",
4: "直升机甲板",
5: "烟囱",
6: "甲板",
7: "舷窗",
8: "桅杆"
}
# YOLOv8通用类别到舰船部件的映射
self.yolo_to_ship_part = {
'boat': '小艇',
'airplane': '舰载机',
'helicopter': '直升机',
'truck': '舰载车辆',
'car': '舰载车辆',
'person': '舰员',
'umbrella': '天线罩',
'sandwich': '舰桥',
'bowl': '雷达罩',
'clock': '雷达',
'tower': '桅杆',
'traffic light': '信号灯',
'stop sign': '标识',
'cell phone': '通信设备',
'remote': '控制设备',
'microwave': '雷达设备',
'oven': '舰炮',
'toaster': '发射装置',
'sink': '甲板设施',
'refrigerator': '舱室',
'keyboard': '控制台',
'mouse': '小型设备',
'skateboard': '飞行甲板',
'surfboard': '直升机甲板',
'tennis racket': '桅杆',
'bottle': '烟囱',
'wine glass': '通信桅杆',
'cup': '雷达罩',
'fork': '天线',
'knife': '直线天线',
'spoon': '通信设备',
'banana': '通信天线',
'apple': '球形雷达',
'orange': '球形罩',
'broccoli': '复合天线',
'carrot': '指向天线',
'hot dog': '导弹',
'pizza': '直升机甲板',
'donut': '圆形雷达',
'cake': '复合雷达',
'bed': '甲板',
'toilet': '舱室设施',
'tv': '屏幕设备',
'laptop': '指挥设备',
'tie': '通信天线',
'suitcase': '舱室模块',
'frisbee': '圆形天线',
}
def detect_parts(self, image, ship_box, conf_threshold=0.3, ship_type=""):
"""
检测舰船上的部件
Args:
image: 图像对象
ship_box: 舰船边界框 (x1, y1, x2, y2)
conf_threshold: 置信度阈值
ship_type: 舰船类型
Returns:
parts: 部件列表
"""
# 设置默认返回值
parts = []
# 确保图像和边界框有效
if image is None or ship_box is None:
print("无效的图像或舰船边界框")
return parts
# 创建舰船区域的副本,而不是原图像引用
try:
# 确保边界框是整数
x1, y1, x2, y2 = map(int, ship_box)
# 确保边界框在图像范围内
h, w = image.shape[:2]
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(w, x2)
y2 = min(h, y2)
# 提取舰船区域
ship_img = image[y1:y2, x1:x2].copy()
# 检查提取的图像是否有效
if ship_img.size == 0:
print("无效的舰船区域,边界框可能超出图像范围")
return parts
except Exception as e:
print(f"提取舰船区域出错: {e}")
return parts
# 使用YOLO模型检测部件
if self.model is not None:
try:
# 根据舰船类型过滤目标部件
target_parts = self._get_target_parts(ship_type)
# 使用YOLO模型进行检测
results = self.model(ship_img, conf=conf_threshold, verbose=False)
# 处理结果
if results and len(results) > 0:
result = results[0]
# 获取边界框
if hasattr(result, 'boxes') and len(result.boxes) > 0:
boxes = result.boxes
for i, box in enumerate(boxes):
# 检查类别
cls_id = int(box.cls.item()) if hasattr(box, 'cls') else -1
# 确定部件类型
part_name = self._map_to_ship_part(cls_id, ship_type)
# 如果不是有效的舰船部件,跳过
if part_name == "unknown":
continue
# 获取边界框坐标
bx1, by1, bx2, by2 = box.xyxy[0].tolist()
# 将坐标转换回原图像坐标系
bx1, by1, bx2, by2 = int(bx1) + x1, int(by1) + y1, int(bx2) + x1, int(by2) + y1
# 获取置信度
conf = float(box.conf.item()) if hasattr(box, 'conf') else 0.5
# 添加到部件列表
part = {
'name': part_name,
'bbox': (bx1, by1, bx2, by2),
'confidence': conf
}
parts.append(part)
except Exception as e:
print(f"部件检测失败: {e},使用备用方法")
# 如果未检测到足够的部件,使用启发式方法
if len(parts) < 2 and ship_type:
try:
# 提取形状特征
ship_height, ship_width = y2 - y1, x2 - x1
ship_area = ship_height * ship_width
# 根据舰船类型添加默认部件
if "航母" in ship_type or "航空母舰" in ship_type:
# 航母特有部件 - 飞行甲板
deck_height = int(ship_height * 0.15)
deck_width = ship_width
deck_x1 = x1
deck_y1 = y1 + int(ship_height * 0.05)
deck_x2 = x1 + deck_width
deck_y2 = deck_y1 + deck_height
parts.append({
'name': "飞行甲板",
'bbox': (deck_x1, deck_y1, deck_x2, deck_y2),
'confidence': 0.85
})
# 舰岛
island_width = int(ship_width * 0.25)
island_height = int(ship_height * 0.3)
island_x1 = x1 + int(ship_width * 0.6)
island_y1 = y1 + int(ship_height * 0.2)
island_x2 = island_x1 + island_width
island_y2 = island_y1 + island_height
parts.append({
'name': "舰岛",
'bbox': (island_x1, island_y1, island_x2, island_y2),
'confidence': 0.8
})
elif "驱逐舰" in ship_type:
# 驱逐舰特有部件 - 舰桥
bridge_width = int(ship_width * 0.2)
bridge_height = int(ship_height * 0.4)
bridge_x1 = x1 + int(ship_width * 0.4)
bridge_y1 = y1 + int(ship_height * 0.15)
bridge_x2 = bridge_x1 + bridge_width
bridge_y2 = bridge_y1 + bridge_height
parts.append({
'name': "舰桥",
'bbox': (bridge_x1, bridge_y1, bridge_x2, bridge_y2),
'confidence': 0.8
})
# 垂发系统
vls_width = int(ship_width * 0.15)
vls_height = int(ship_height * 0.2)
vls_x1 = x1 + int(ship_width * 0.2)
vls_y1 = y1 + int(ship_height * 0.25)
vls_x2 = vls_x1 + vls_width
vls_y2 = vls_y1 + vls_height
parts.append({
'name': "垂发系统",
'bbox': (vls_x1, vls_y1, vls_x2, vls_y2),
'confidence': 0.7
})
# 主炮
gun_width = int(ship_width * 0.1)
gun_height = int(ship_height * 0.15)
gun_x1 = x1 + int(ship_width * 0.05)
gun_y1 = y1 + int(ship_height * 0.3)
gun_x2 = gun_x1 + gun_width
gun_y2 = gun_y1 + gun_height
parts.append({
'name': "主炮",
'bbox': (gun_x1, gun_y1, gun_x2, gun_y2),
'confidence': 0.75
})
elif "护卫舰" in ship_type:
# 护卫舰特有部件
# 舰桥
bridge_width = int(ship_width * 0.18)
bridge_height = int(ship_height * 0.35)
bridge_x1 = x1 + int(ship_width * 0.35)
bridge_y1 = y1 + int(ship_height * 0.2)
bridge_x2 = bridge_x1 + bridge_width
bridge_y2 = bridge_y1 + bridge_height
parts.append({
'name': "舰桥",
'bbox': (bridge_x1, bridge_y1, bridge_x2, bridge_y2),
'confidence': 0.8
})
# 直升机甲板
heli_width = int(ship_width * 0.25)
heli_height = int(ship_height * 0.25)
heli_x1 = x1 + int(ship_width * 0.7)
heli_y1 = y1 + int(ship_height * 0.25)
heli_x2 = heli_x1 + heli_width
heli_y2 = heli_y1 + heli_height
parts.append({
'name': "直升机甲板",
'bbox': (heli_x1, heli_y1, heli_x2, heli_y2),
'confidence': 0.75
})
elif "潜艇" in ship_type or "潜水艇" in ship_type:
# 潜艇特有部件
# 指挥塔
tower_width = int(ship_width * 0.15)
tower_height = int(ship_height * 0.4)
tower_x1 = x1 + int(ship_width * 0.4)
tower_y1 = y1 + int(ship_height * 0.1)
tower_x2 = tower_x1 + tower_width
tower_y2 = tower_y1 + tower_height
parts.append({
'name': "指挥塔",
'bbox': (tower_x1, tower_y1, tower_x2, tower_y2),
'confidence': 0.8
})
else:
# 通用舰船部件
# 舰桥
bridge_width = int(ship_width * 0.2)
bridge_height = int(ship_height * 0.35)
bridge_x1 = x1 + int(ship_width * 0.4)
bridge_y1 = y1 + int(ship_height * 0.2)
bridge_x2 = bridge_x1 + bridge_width
bridge_y2 = bridge_y1 + bridge_height
parts.append({
'name': "舰桥",
'bbox': (bridge_x1, bridge_y1, bridge_x2, bridge_y2),
'confidence': 0.8
})
# 雷达
radar_width = int(ship_width * 0.1)
radar_height = int(ship_height * 0.15)
radar_x1 = x1 + int(ship_width * 0.45)
radar_y1 = y1 + int(ship_height * 0.05)
radar_x2 = radar_x1 + radar_width
radar_y2 = radar_y1 + radar_height
parts.append({
'name': "雷达",
'bbox': (radar_x1, radar_y1, radar_x2, radar_y2),
'confidence': 0.7
})
# 甲板
deck_width = int(ship_width * 0.8)
deck_height = int(ship_height * 0.25)
deck_x1 = x1 + int(ship_width * 0.1)
deck_y1 = y1 + int(ship_height * 0.6)
deck_x2 = deck_x1 + deck_width
deck_y2 = deck_y1 + deck_height
parts.append({
'name': "甲板",
'bbox': (deck_x1, deck_y1, deck_x2, deck_y2),
'confidence': 0.75
})
except Exception as e:
except Exception as e:
return parts
def _map_yolo_class_to_ship_part(self, cls_name, ship_type=""):
"""将YOLO类别名称映射到舰船部件名称"""
# 通用物体到舰船部件的映射
mapping = {
'person': '人员',
'bicycle': '小型设备',
'car': '小型设备',
'motorcycle': '小型设备',
'airplane': '舰载机',
'bus': '车辆',
'train': '小型设备',
'truck': '车辆',
'boat': '小艇',
'traffic light': '信号灯',
'fire hydrant': '消防设备',
'stop sign': '标志牌',
'parking meter': '小型设备',
'bench': '设备',
'bird': '无人机',
'cat': '小型设备',
'dog': '小型设备',
'horse': '小型设备',
'sheep': '小型设备',
'cow': '小型设备',
'elephant': '大型设备',
'bear': '大型设备',
'zebra': '小型设备',
'giraffe': '高大设备',
'backpack': '设备',
'umbrella': '小型设备',
'handbag': '设备',
'tie': '小型设备',
'suitcase': '设备',
'frisbee': '小型设备',
'skis': '小型设备',
'snowboard': '小型设备',
'sports ball': '小型设备',
'kite': '无人机',
'baseball bat': '小型设备',
'baseball glove': '小型设备',
'skateboard': '小型设备',
'surfboard': '救生设备',
'tennis racket': '小型设备',
'bottle': '设备',
'wine glass': '设备',
'cup': '设备',
'fork': '设备',
'knife': '设备',
'spoon': '设备',
'bowl': '设备',
'banana': '补给',
'apple': '补给',
'sandwich': '补给',
'orange': '补给',
'broccoli': '补给',
'carrot': '补给',
'hot dog': '补给',
'pizza': '补给',
'donut': '补给',
'cake': '补给',
'chair': '设备',
'couch': '设备',
'potted plant': '设备',
'bed': '设备',
'dining table': '设备',
'toilet': '设施',
'tv': '显示设备',
'laptop': '计算设备',
'mouse': '小型设备',
'remote': '控制设备',
'keyboard': '控制设备',
'cell phone': '通信设备',
'microwave': '设备',
'oven': '设备',
'toaster': '设备',
'sink': '设施',
'refrigerator': '设备',
'book': '资料',
'clock': '设备',
'vase': '设备',
'scissors': '工具',
'teddy bear': '设备',
'hair drier': '设备',
'toothbrush': '设备'
}
# 针对特定舰船类型的映射优化
if "航母" in ship_type or "航空母舰" in ship_type:
if cls_name == 'airplane':
return "舰载机"
elif cls_name in ['tv', 'laptop', 'cell phone']:
return "雷达设备"
elif cls_name in ['tower', 'building']:
return "舰岛"
elif "驱逐舰" in ship_type or "护卫舰" in ship_type:
if cls_name in ['tv', 'laptop']:
return "相控阵雷达"
elif cls_name in ['tower', 'building']:
return "舰桥"
elif cls_name in ['remote', 'bottle']:
return "导弹装置"
# 默认返回映射或原始类别名称
return mapping.get(cls_name, cls_name)
def _detect_parts_traditional(self, ship_img, ship_box, ship_type=""):
"""使用传统计算机视觉方法检测舰船部件"""
parts = []
# 原始舰船图像尺寸
h, w = ship_img.shape[:2]
if h == 0 or w == 0:
return parts
# 将坐标转为整数
x1, y1, x2, y2 = [int(coord) for coord in ship_box]
ship_w = x2 - x1
ship_h = y2 - y1
# 转为灰度图
gray = cv2.cvtColor(ship_img, cv2.COLOR_BGR2GRAY) if len(ship_img.shape) == 3 else ship_img.copy()
# 提取边缘
edges = cv2.Canny(gray, 50, 150)
_, thresh = cv2.threshold(gray, 100, 255, cv2.THRESH_BINARY)
# 寻找轮廓
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 分析轮廓并识别部件
for contour in contours:
# 计算轮廓面积和边界框
area = cv2.contourArea(contour)
if area < (w * h * 0.005): # 忽略太小的轮廓
continue
# 计算边界框
x, y, box_width, box_height = cv2.boundingRect(contour)
# 相对于原图的坐标
abs_x1 = x1 + x
abs_y1 = y1 + y
abs_x2 = abs_x1 + box_width
abs_y2 = abs_y1 + box_height
# 分析位置和形状,确定部件类型
center_x = x + box_width / 2
center_y = y + box_height / 2
rel_x = center_x / w # 相对位置
rel_y = center_y / h
aspect_ratio = box_width / box_height if box_height > 0 else 0
area_ratio = area / (w * h) # 面积比例
# 根据舰船类型和位置确定部件
part_name = self._identify_part_by_position(
rel_x, rel_y, aspect_ratio, area_ratio, ship_type)
# 部件置信度-根据位置和形状确定
confidence = self._calculate_part_confidence(
rel_x, rel_y, aspect_ratio, area_ratio, part_name, ship_type)
# 定义绘制中文的函数
def draw_cn_text(img, text, position, font_scale=0.5, color=(0, 255, 0), thickness=2):
"""使用PIL绘制中文文本并转回OpenCV格式"""
from PIL import Image, ImageDraw, ImageFont
import numpy as np
# 转换为PIL图像
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(img_pil)
# 获取系统默认字体或指定中文字体
try:
# 尝试使用微软雅黑字体 (Windows)
font = ImageFont.truetype("msyh.ttc", int(font_scale * 20))
except:
try:
# 尝试使用宋体 (Windows)
font = ImageFont.truetype("simsun.ttc", int(font_scale * 20))
except:
try:
# 尝试使用WenQuanYi (Linux)
font = ImageFont.truetype("wqy-microhei.ttc", int(font_scale * 20))
except:
# 使用系统默认字体
font = ImageFont.load_default()
# 绘制文本
draw.text(position, text, font=font, fill=color[::-1]) # RGB顺序
# 转回OpenCV格式
return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
# 如果提供了舰船边界框,则只处理该区域
if ship_box is not None:
x1, y1, x2, y2 = ship_box
# 确保边界在图像范围内
x1, y1 = max(0, x1), max(0, y1)
x2, y2 = min(img.shape[1], x2), min(img.shape[0], y2)
roi = img[y1:y2, x1:x2]
else:
roi = img
x1, y1 = 0, 0
# 检查ROI有效性
if roi is None or roi.size == 0 or roi.shape[0] <= 0 or roi.shape[1] <= 0:
return [], img.copy() # 返回空部件列表和原始图像
# 获取ROI尺寸
roi_h, roi_w = roi.shape[:2]
# 如果未提供舰船类型,尝试推测
if ship_type is None or ship_type == "":
ship_type = self._infer_ship_type(roi).lower()
print(f"推测舰船类型: {ship_type}")
else:
ship_type = ship_type.lower()
print(f"使用提供的舰船类型: {ship_type}")
# 初始化部件列表和结果图像
parts = []
result_img = img.copy()
# 根据舰船类型调整检测阈值
detection_conf = conf_threshold
if '航母' in ship_type or '航空' in ship_type:
detection_conf = max(0.05, conf_threshold * 0.5) # 航母需要更低的阈值
elif '驱逐' in ship_type:
detection_conf = max(0.1, conf_threshold * 0.7) # 驱逐舰稍微降低阈值
# 进行部件检测
results = self.model(roi, conf=detection_conf, device=self.device)
# 处理检测结果
if len(results) > 0:
# 获取边界框
boxes = results[0].boxes
# 如果有分割结果,也使用它来提高识别精度
if hasattr(results[0], 'masks') and results[0].masks is not None:
masks = results[0].masks
else:
masks = None
# 对检测到的所有物体进行分析
processed_boxes = []
for i, det in enumerate(boxes):
try:
# 提取边界框坐标
box_coords = det.xyxy[0].cpu().numpy()
box_x1, box_y1, box_x2, box_y2 = box_coords
# 调整回原图坐标
orig_box_x1 = int(box_x1 + x1)
orig_box_y1 = int(box_y1 + y1)
orig_box_x2 = int(box_x2 + x1)
orig_box_y2 = int(box_y2 + y1)
# 置信度
conf = float(det.conf[0].cpu().numpy())
# 类别ID
cls_id = int(det.cls[0].cpu().numpy())
orig_class_name = self.model.names[cls_id]
# 根据位置和尺寸确定部件类型 - 考虑舰船类型
part_name = self._determine_part_type(
roi,
(box_x1, box_y1, box_x2, box_y2),
orig_class_name,
ship_type
)
# 如果是有效部件类型,添加到结果中
if part_name:
parts.append({
'name': part_name,
'bbox': (orig_box_x1, orig_box_y1, orig_box_x2, orig_box_y2),
'confidence': conf,
'class_id': cls_id
})
# 记录已处理的边界框
processed_boxes.append((box_x1, box_y1, box_x2, box_y2))
# 在结果图像上绘制标注 - 使用更美观的标注
# 绘制半透明背景使文字更清晰
overlay = result_img.copy()
cv2.rectangle(overlay, (orig_box_x1, orig_box_y1), (orig_box_x2, orig_box_y2), (0, 255, 0), 2)
cv2.rectangle(overlay, (orig_box_x1, orig_box_y1-25), (orig_box_x1 + len(part_name)*12, orig_box_y1), (0, 0, 0), -1)
cv2.addWeighted(overlay, 0.8, result_img, 0.2, 0, result_img)
# 绘制标签文字
label_text = f"{part_name}: {conf:.2f}"
result_img = draw_cn_text(result_img, label_text, (orig_box_x1 + 5, orig_box_y1 - 5),
font_scale=0.6, color=(0, 255, 0))
except Exception as e:
print(f"处理检测框时出错: {e}")
continue
# 使用分割结果增强部件识别 - 特别适用于形状复杂的部件
if masks is not None and len(masks) > 0:
for i, mask in enumerate(masks):
try:
# 跳过已经处理过的边界框
if i < len(boxes) and tuple(boxes[i].xyxy[0].cpu().numpy()) in processed_boxes:
continue
# 提取分割掩码
mask_array = mask.data[0].cpu().numpy()
# 计算掩码边界框
mask_positions = np.where(mask_array > 0.5)
if len(mask_positions[0]) == 0 or len(mask_positions[1]) == 0:
continue
min_y, max_y = np.min(mask_positions[0]), np.max(mask_positions[0])
min_x, max_x = np.min(mask_positions[1]), np.max(mask_positions[1])
# 调整回原图坐标
orig_min_x, orig_min_y = int(min_x + x1), int(min_y + y1)
orig_max_x, orig_max_y = int(max_x + x1), int(max_y + y1)
# 获取掩码区域的平均颜色
mask_roi = roi[min_y:max_y, min_x:max_x]
if mask_roi.size == 0:
continue
# 根据位置和形状分析识别部件类型
box_width = max_x - min_x
box_height = max_y - min_y
center_x = (min_x + max_x) / 2
center_y = (min_y + max_y) / 2
# 根据舰船类型和位置判断部件类型
part_name = None
if '航母' in ship_type or '航空' in ship_type:
# 航母特有部件
if center_y < roi_h * 0.3 and center_x > roi_w * 0.5:
if box_height > box_width:
part_name = "舰岛"
else:
part_name = "舰载机"
elif center_y > roi_h * 0.4 and box_width > roi_w * 0.3:
part_name = "飞行甲板"
elif box_width < roi_w * 0.1 and box_height < roi_h * 0.1:
part_name = "舰载机"
elif '驱逐' in ship_type:
# 驱逐舰特有部件
if center_y < roi_h * 0.3 and box_width < roi_w * 0.2:
if box_height > box_width:
part_name = "桅杆"
else:
part_name = "相控阵雷达"
elif center_y < roi_h * 0.4 and box_width > roi_w * 0.1:
part_name = "舰桥"
elif center_x < roi_w * 0.3:
part_name = "主炮"
elif center_x > roi_w * 0.6 and center_y < roi_h * 0.5:
part_name = "垂发系统"
else:
# 通用部件识别
if center_y < roi_h * 0.3 and box_width < roi_w * 0.2:
if box_height > box_width * 1.5:
part_name = "桅杆"
else:
part_name = "雷达"
elif center_y < roi_h * 0.4 and box_width > roi_w * 0.1:
part_name = "舰桥"
elif center_x < roi_w * 0.3 and box_width < roi_w * 0.15:
part_name = "舰炮"
elif center_x > roi_w * 0.7 and box_width > roi_w * 0.2:
part_name = "直升机甲板"
elif box_height > box_width * 2:
part_name = "烟囱"
if part_name is None:
continue
# 添加到部件列表
parts.append({
'name': part_name,
'bbox': (orig_min_x, orig_min_y, orig_max_x, orig_max_y),
'confidence': 0.7, # 分割结果的默认置信度
'class_id': -1 # 使用-1表示分割结果
})
# 在结果图像上绘制标注
overlay = result_img.copy()
cv2.rectangle(overlay, (orig_min_x, orig_min_y), (orig_max_x, orig_max_y), (0, 255, 0), 2)
cv2.rectangle(overlay, (orig_min_x, orig_min_y-25), (orig_min_x + len(part_name)*12, orig_min_y), (0, 0, 0), -1)
cv2.addWeighted(overlay, 0.8, result_img, 0.2, 0, result_img)
label_text = f"{part_name}: 0.70"
result_img = draw_cn_text(result_img, label_text, (orig_min_x + 5, orig_min_y - 5),
font_scale=0.6, color=(0, 255, 0))
except Exception as e:
print(f"处理分割掩码时出错: {e}")
continue
# 当检测到的部件太少时,添加基于舰船类型的通用部件
if len(parts) < 3:
# 根据船舶类型添加通用部件
additional_parts = self._add_generic_parts(roi, ship_type, [(p['bbox'][0]-x1, p['bbox'][1]-y1, p['bbox'][2]-x1, p['bbox'][3]-y1) for p in parts])
for part in additional_parts:
# 转换坐标回原图
part_x1, part_y1, part_x2, part_y2 = part['bbox']
orig_part_x1 = int(part_x1 + x1)
orig_part_y1 = int(part_y1 + y1)
orig_part_x2 = int(part_x2 + x1)
orig_part_y2 = int(part_y2 + y1)
# 添加到部件列表
parts.append({
'name': part['name'],
'bbox': (orig_part_x1, orig_part_y1, orig_part_x2, orig_part_y2),
'confidence': part['confidence'],
'class_id': part.get('class_id', -2) # 使用-2表示启发式生成的部件
})
# 在结果图像上绘制标注
overlay = result_img.copy()
cv2.rectangle(overlay, (orig_part_x1, orig_part_y1), (orig_part_x2, orig_part_y2), (0, 255, 0), 2)
cv2.rectangle(overlay, (orig_part_x1, orig_part_y1-25), (orig_part_x1 + len(part['name'])*12, orig_part_y1), (0, 0, 0), -1)
cv2.addWeighted(overlay, 0.8, result_img, 0.2, 0, result_img)
label_text = f"{part['name']}: {part['confidence']:.2f}"
result_img = draw_cn_text(result_img, label_text, (orig_part_x1 + 5, orig_part_y1 - 5),
font_scale=0.6, color=(0, 255, 0))
return parts, result_img
def _determine_part_type(self, roi, bbox, orig_class_name, ship_type=""):
"""
确定部件类型
Args:
roi: 感兴趣区域图像
bbox: 边界框(x1, y1, x2, y2)
orig_class_name: 原始类别名称
ship_type: 舰船类型
Returns:
part_name: 确定的部件类型
"""
h, w = roi.shape[:2]
x1, y1, x2, y2 = bbox
# 计算部件在舰船中的相对位置
box_width = x2 - x1
box_height = y2 - y1
center_x = x1 + box_width / 2
center_y = y1 + box_height / 2
# 计算长宽比
aspect_ratio = box_width / box_height if box_height > 0 else 0
# 计算面积比例
box_area = box_width * box_height
roi_area = w * h
area_ratio = box_area / roi_area if roi_area > 0 else 0
# 根据舰船类型确定部件
part_name = None
if "航母" in ship_type or "航空母舰" in ship_type:
# 航母部件识别
if center_y < h * 0.3:
if center_x > w * 0.5:
if box_width > box_height and area_ratio > 0.05:
part_name = "舰岛"
elif area_ratio < 0.03:
part_name = "雷达"
elif box_height > box_width * 1.5 and area_ratio < 0.02:
part_name = "桅杆"
elif area_ratio < 0.03 and box_width > box_height:
part_name = "相控阵雷达"
elif center_y > h * 0.4 and center_y < h * 0.8 and center_x > w * 0.1 and center_x < w * 0.9:
if area_ratio > 0.2:
part_name = "飞行甲板"
elif box_width > box_height * 3 and area_ratio < 0.1:
part_name = "弹射器"
elif box_width > box_height and area_ratio < 0.05:
part_name = "舰载机"
elif box_height > box_width * 1.5 and area_ratio < 0.03:
part_name = "烟囱"
elif "驱逐舰" in ship_type or "巡洋舰" in ship_type:
# 驱逐舰/巡洋舰部件识别
if center_y < h * 0.35:
if box_width > w * 0.1 and area_ratio > 0.08:
part_name = "舰桥"
elif box_height > box_width * 1.5 and area_ratio < 0.02:
part_name = "桅杆"
elif area_ratio < 0.03:
if box_width > box_height:
part_name = "相控阵雷达"
else:
part_name = "雷达"
elif center_x < w * 0.3 and center_y < h * 0.5:
if box_width < w * 0.15 and box_height < h * 0.15:
part_name = "主炮"
elif center_x > w * 0.6 and center_y < h * 0.6 and box_width < w * 0.2:
if box_width > box_height and area_ratio < 0.05:
part_name = "垂发系统"
elif center_x > w * 0.7 and center_y > h * 0.6 and area_ratio > 0.05:
part_name = "直升机甲板"
elif box_height > box_width * 1.5 and area_ratio < 0.03:
part_name = "烟囱"
elif "护卫舰" in ship_type:
# 护卫舰部件识别
if center_y < h * 0.4:
if box_width > w * 0.1 and area_ratio > 0.08:
part_name = "舰桥"
elif box_height > box_width * 1.5 and area_ratio < 0.03:
part_name = "桅杆"
elif area_ratio < 0.03:
part_name = "雷达"
elif rel_x < 0.3 and box_width < w * 0.15:
part_name = "舰炮"
elif rel_x > 0.7 and area_ratio > 0.05:
part_name = "直升机甲板"
elif rel_x > 0.6 and area_ratio < 0.04:
part_name = "导弹发射器"
elif box_height > box_width * 1.5 and area_ratio < 0.03:
part_name = "烟囱"
elif "潜艇" in ship_type:
# 潜艇部件识别
if center_y < h * 0.3 and area_ratio > 0.05:
part_name = "塔台"
elif center_x < w * 0.3 and area_ratio < 0.04:
part_name = "鱼雷管"
elif center_y < h * 0.2 and area_ratio < 0.02:
part_name = "通信天线"
elif center_x > w * 0.6 and center_y < h * 0.4 and area_ratio < 0.03:
part_name = "潜望镜"
elif center_y > h * 0.6 and center_x < w * 0.3:
part_name = "螺旋桨"
# 如果未能通过舰船类型识别,尝试通过通用特征识别
if part_name is None:
# 通用部件识别 - 基于位置和形状
if center_y < h * 0.3:
if box_width > w * 0.2 and area_ratio > 0.08:
part_name = "舰桥"
elif box_height > box_width * 1.5 and area_ratio < 0.03:
part_name = "桅杆"
elif area_ratio < 0.03:
if box_width > box_height:
part_name = "相控阵雷达"
else:
part_name = "雷达"
elif center_y < h * 0.6:
if center_x < w * 0.25 and area_ratio < 0.05:
part_name = "舰炮"
elif box_height > box_width * 1.5 and area_ratio < 0.03:
part_name = "烟囱"
elif aspect_ratio > 1.5 and area_ratio < 0.05 and center_x > w * 0.6:
part_name = "导弹发射器"
elif center_y > h * 0.6:
if rel_x > 0.7 and area_ratio > 0.05:
part_name = "直升机甲板"
elif area_ratio > 0.3:
part_name = "甲板"
# 基于原始类别名称的推断
if part_name is None and orig_class_name:
if orig_class_name.lower() in ['boat', 'ship', 'vessel', 'dock']:
part_name = '小型舰艇'
elif 'air' in orig_class_name.lower() or 'plane' in orig_class_name.lower():
part_name = '舰载机'
elif 'tower' in orig_class_name.lower() or 'pole' in orig_class_name.lower():
part_name = '桅杆'
elif 'radar' in orig_class_name.lower() or 'antenna' in orig_class_name.lower():
part_name = '雷达'
elif 'gun' in orig_class_name.lower() or 'cannon' in orig_class_name.lower():
part_name = '舰炮'
# 如果面积太小,可能是噪声
if part_name is None and area_ratio < 0.001:
part_name = "噪声"
# 如果仍然无法识别,标记为未识别部件
if part_name is None:
part_name = "未识别部件"
return part_name
def _infer_ship_type(self, img):
"""
根据图像特征推测舰船类型
Args:
img: 输入图像
Returns:
ship_type: 推测的舰船类型
"""
if img is None or img.size == 0 or img.shape[0] == 0 or img.shape[1] == 0:
return "未知舰船"
# 获取图像尺寸
height, width = img.shape[:2]
aspect_ratio = width / height if height > 0 else 0
# 转为灰度图
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img
# 边缘检测
edges = cv2.Canny(gray, 50, 150)
edge_pixels = cv2.countNonZero(edges)
edge_density = edge_pixels / (width * height) if width * height > 0 else 0
# 水平线特征检测 - 对航母重要
horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (25, 1))
horizontal_lines = cv2.morphologyEx(edges, cv2.MORPH_OPEN, horizontal_kernel)
horizontal_pixels = cv2.countNonZero(horizontal_lines)
horizontal_ratio = horizontal_pixels / (width * height) if width * height > 0 else 0
# 垂直线特征检测 - 对驱逐舰重要
vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 15))
vertical_lines = cv2.morphologyEx(edges, cv2.MORPH_OPEN, vertical_kernel)
vertical_pixels = cv2.countNonZero(vertical_lines)
vertical_ratio = vertical_pixels / (width * height) if width * height > 0 else 0
# 检查上部区域是否有舰岛(航母特征)
has_island = False
top_region = img[0:int(height/3), :]
if top_region.size > 0:
top_gray = cv2.cvtColor(top_region, cv2.COLOR_BGR2GRAY) if len(top_region.shape) == 3 else top_region
_, top_thresh = cv2.threshold(top_gray, 100, 255, cv2.THRESH_BINARY)
top_contours, _ = cv2.findContours(top_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in top_contours:
area = cv2.contourArea(contour)
if area > (top_region.shape[0] * top_region.shape[1] * 0.01):
x, y, w, h = cv2.boundingRect(contour)
# 舰岛通常高于宽
if h > w and h > top_region.shape[0] * 0.3:
has_island = True
break
# 航空母舰特征:长宽比大,水平线丰富,有舰岛
if (aspect_ratio > 3.0 or horizontal_ratio > 0.05) and edge_density < 0.15:
if has_island or aspect_ratio > 3.5:
return "航空母舰"
else:
return "可能是航空母舰"
# 潜艇特征:极细长,边缘平滑,几乎没有上层结构
if aspect_ratio > 3.5 and edge_density < 0.1 and vertical_ratio < 0.02:
return "潜艇"
# 驱逐舰特征:中等长宽比,边缘复杂,有明显的上层结构
if 2.2 < aspect_ratio < 3.8 and edge_density > 0.1 and vertical_ratio > 0.02:
return "驱逐舰"
# 护卫舰特征:较小长宽比,结构适中
if 2.0 < aspect_ratio < 3.0 and 0.1 < edge_density < 0.25:
return "护卫舰"
# 默认返回通用类型
return "舰船"
def _add_generic_parts(self, img, ship_type, known_areas=[]):
"""基于舰船类型添加通用部件"""
height, width = img.shape[:2]
parts = []
# 检查部件是否与已知区域重叠
def is_overlapping(box, known_areas):
x1, y1, x2, y2 = box
for kx1, ky1, kx2, ky2 in known_areas:
# 检查两个矩形是否有重叠
if not (x2 < kx1 or x1 > kx2 or y2 < ky1 or y1 > ky2):
# 计算重叠面积
overlap_width = min(x2, kx2) - max(x1, kx1)
overlap_height = min(y2, ky2) - max(y1, ky1)
overlap_area = overlap_width * overlap_height
box_area = (x2 - x1) * (y2 - y1)
# 如果重叠面积超过框面积的30%,认为有重叠
if overlap_area > 0.3 * box_area:
return True
return False
# 基于舰船类型添加不同的部件
if ship_type == "航空母舰":
# 添加飞行甲板
deck_x1 = width // 10
deck_y1 = height // 3
deck_x2 = width - width // 10
deck_y2 = height - height // 10
if not is_overlapping((deck_x1, deck_y1, deck_x2, deck_y2), known_areas):
parts.append({
'name': "飞行甲板",
'bbox': (deck_x1, deck_y1, deck_x2, deck_y2),
'confidence': 0.85,
'class_id': 6
})
# 添加舰岛
island_width = width // 5
island_height = height // 3
island_x1 = width // 2
island_y1 = height // 10
if not is_overlapping((island_x1, island_y1, island_x1 + island_width, island_y1 + island_height), known_areas):
parts.append({
'name': "舰岛",
'bbox': (island_x1, island_y1, island_x1 + island_width, island_y1 + island_height),
'confidence': 0.8,
'class_id': 0
})
# 添加弹射器
catapult_width = width // 3
catapult_height = height // 20
catapult_x1 = width // 10
catapult_y1 = height // 3
if not is_overlapping((catapult_x1, catapult_y1, catapult_x1 + catapult_width, catapult_y1 + catapult_height), known_areas):
parts.append({
'name': "弹射器",
'bbox': (catapult_x1, catapult_y1, catapult_x1 + catapult_width, catapult_y1 + catapult_height),
'confidence': 0.75,
'class_id': 3
})
elif ship_type == "驱逐舰" or ship_type == "护卫舰":
# 添加舰桥
bridge_width = width // 4
bridge_height = height // 3
bridge_x1 = width // 2 - bridge_width // 2
bridge_y1 = height // 10
if not is_overlapping((bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height), known_areas):
parts.append({
'name': "舰桥",
'bbox': (bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height),
'confidence': 0.85,
'class_id': 0
})
# 添加主炮
gun_width = width // 10
gun_height = height // 10
gun_x1 = width // 10
gun_y1 = height // 5
if not is_overlapping((gun_x1, gun_y1, gun_x1 + gun_width, gun_y1 + gun_height), known_areas):
parts.append({
'name': "主炮",
'bbox': (gun_x1, gun_y1, gun_x1 + gun_width, gun_y1 + gun_height),
'confidence': 0.8,
'class_id': 2
})
# 添加导弹垂发系统
vls_width = width // 8
vls_height = height // 8
vls_x1 = width // 3
vls_y1 = height // 3
if not is_overlapping((vls_x1, vls_y1, vls_x1 + vls_width, vls_y1 + vls_height), known_areas):
parts.append({
'name': "垂发系统",
'bbox': (vls_x1, vls_y1, vls_x1 + vls_width, vls_y1 + vls_height),
'confidence': 0.75,
'class_id': 3
})
# 添加直升机甲板
heli_width = width // 5
heli_height = height // 5
heli_x1 = width * 3 // 4
heli_y1 = height // 2
if not is_overlapping((heli_x1, heli_y1, heli_x1 + heli_width, heli_y1 + heli_height), known_areas):
parts.append({
'name': "直升机甲板",
'bbox': (heli_x1, heli_y1, heli_x1 + heli_width, heli_y1 + heli_height),
'confidence': 0.7,
'class_id': 4
})
# 添加相控阵雷达
radar_width = width // 12
radar_height = height // 8
radar_x1 = width // 2
radar_y1 = height // 12
if not is_overlapping((radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height), known_areas):
parts.append({
'name': "相控阵雷达",
'bbox': (radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height),
'confidence': 0.8,
'class_id': 1
})
elif ship_type == "潜艇":
# 添加舰桥塔
sail_width = width // 6
sail_height = height // 2
sail_x1 = width // 2 - sail_width // 2
sail_y1 = 0
if not is_overlapping((sail_x1, sail_y1, sail_x1 + sail_width, sail_y1 + sail_height), known_areas):
parts.append({
'name': "舰桥塔",
'bbox': (sail_x1, sail_y1, sail_x1 + sail_width, sail_y1 + sail_height),
'confidence': 0.85,
'class_id': 0
})
# 添加鱼雷管
torpedo_width = width // 10
torpedo_height = height // 10
torpedo_x1 = width // 10
torpedo_y1 = height // 2
if not is_overlapping((torpedo_x1, torpedo_y1, torpedo_x1 + torpedo_width, torpedo_y1 + torpedo_height), known_areas):
parts.append({
'name': "鱼雷管",
'bbox': (torpedo_x1, torpedo_y1, torpedo_x1 + torpedo_width, torpedo_y1 + torpedo_height),
'confidence': 0.7,
'class_id': 3
})
else: # 通用舰船
# 添加舰桥
bridge_width = width // 4
bridge_height = height // 3
bridge_x1 = width // 2 - bridge_width // 2
bridge_y1 = height // 8
if not is_overlapping((bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height), known_areas):
parts.append({
'name': "舰桥",
'bbox': (bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height),
'confidence': 0.75,
'class_id': 0
})
# 添加雷达
radar_width = width // 10
radar_height = width // 10 # 保持正方形
radar_x1 = width // 2 - radar_width // 2
radar_y1 = height // 20
if not is_overlapping((radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height), known_areas):
parts.append({
'name': "雷达",
'bbox': (radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height),
'confidence': 0.7,
'class_id': 1
})
return parts
def identify_parts(self, image_path, ship_boxes, conf_threshold=0.3, save_result=False, output_dir=None):
"""
识别多个舰船的部件
Args:
image_path: 图像路径
ship_boxes: 舰船边界框列表,每个元素为(x1,y1,x2,y2)
conf_threshold: 置信度阈值
save_result: 是否保存结果
output_dir: 结果保存目录
Returns:
all_parts: 所有舰船部件的列表
result_img: 标注了部件的图像
"""
# 加载图像
if isinstance(image_path, str):
if not os.path.exists(image_path):
raise FileNotFoundError(f"图像文件不存在: {image_path}")
img = cv2.imread(image_path)
else:
img = image_path.copy()
result_img = img.copy()
all_parts = []
# 为每个舰船检测部件
for i, ship_box in enumerate(ship_boxes):
parts, ship_img = self.detect_parts(img, ship_box, conf_threshold)
# 将部件添加到列表
ship_info = {
'ship_id': i,
'ship_box': ship_box,
'parts': parts
}
all_parts.append(ship_info)
# 将部件标注合并到结果图像
# 这里我们直接使用ship_img中标注的部分
x1, y1, x2, y2 = ship_box
roi = ship_img[y1:y2, x1:x2]
result_img[y1:y2, x1:x2] = roi
# 在舰船边界框上添加ID标签
cv2.rectangle(result_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
cv2.putText(result_img, f"Ship {i}", (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)
# 保存结果
if save_result and output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
if isinstance(image_path, str):
result_filename = os.path.join(output_dir, f"parts_{os.path.basename(image_path)}")
else:
result_filename = os.path.join(output_dir, f"parts_{len(os.listdir(output_dir))}.jpg")
cv2.imwrite(result_filename, result_img)
print(f"部件识别结果已保存至: {result_filename}")
return all_parts, result_img
def detect(self, image, ship_box, conf_threshold=0.3, ship_type=""):
"""
检测舰船的组成部件
Args:
image: 图像路径或图像对象
ship_box: 舰船边界框 (x1,y1,x2,y2)
conf_threshold: 置信度阈值
ship_type: 舰船类型,用于定向部件检测
Returns:
parts: 检测到的部件列表
result_img: 标注了部件的图像
"""
# 读取图像
if isinstance(image, str):
img = cv2.imread(image)
else:
img = image.copy() if isinstance(image, np.ndarray) else np.array(image)
if img is None:
return [], np.zeros((100, 100, 3), dtype=np.uint8)
# 返回结果
result_img = img.copy()
# 提取舰船区域
x1, y1, x2, y2 = ship_box
# 确保索引是整数
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
# 确保边界在图像范围内
h, w = img.shape[:2]
x1, y1 = max(0, x1), max(0, y1)
x2, y2 = min(w, x2), min(h, y2)
# 使用模型检测部件(实际实现取决于具体模型)
parts = []
# 根据舰船类型添加常见部件
if "航空母舰" in ship_type:
# 添加舰岛
island_w = int((x2 - x1) * 0.15)
island_h = int((y2 - y1) * 0.3)
island_x = x1 + int((x2 - x1) * 0.7)
island_y = y1 + int((y2 - y1) * 0.1)
parts.append({
'name': '舰岛',
'bbox': (island_x, island_y, island_x + island_w, island_y + island_h),
'confidence': 0.8,
'class_id': 1
})
# 标注舰岛
cv2.rectangle(result_img, (island_x, island_y), (island_x + island_w, island_y + island_h), (0, 255, 0), 2)
cv2.putText(result_img, "舰岛: 0.80", (island_x, island_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
# 添加甲板
deck_w = int((x2 - x1) * 0.9)
deck_h = int((y2 - y1) * 0.5)
deck_x = x1 + int((x2 - x1) * 0.05)
deck_y = y1 + int((y2 - y1) * 0.3)
parts.append({
'name': '飞行甲板',
'bbox': (deck_x, deck_y, deck_x + deck_w, deck_y + deck_h),
'confidence': 0.85,
'class_id': 2
})
# 标注甲板
cv2.rectangle(result_img, (deck_x, deck_y), (deck_x + deck_w, deck_y + deck_h), (255, 0, 0), 2)
cv2.putText(result_img, "飞行甲板: 0.85", (deck_x, deck_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
elif "驱逐舰" in ship_type:
# 添加舰桥
bridge_w = int((x2 - x1) * 0.2)
bridge_h = int((y2 - y1) * 0.4)
bridge_x = x1 + int((x2 - x1) * 0.4)
bridge_y = y1 + int((y2 - y1) * 0.1)
parts.append({
'name': '舰桥',
'bbox': (bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h),
'confidence': 0.8,
'class_id': 3
})
# 标注舰桥
cv2.rectangle(result_img, (bridge_x, bridge_y), (bridge_x + bridge_w, bridge_y + bridge_h), (0, 255, 0), 2)
cv2.putText(result_img, "舰桥: 0.80", (bridge_x, bridge_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
elif "护卫舰" in ship_type:
# 添加舰桥
bridge_w = int((x2 - x1) * 0.15)
bridge_h = int((y2 - y1) * 0.3)
bridge_x = x1 + int((x2 - x1) * 0.45)
bridge_y = y1 + int((y2 - y1) * 0.15)
parts.append({
'name': '舰桥',
'bbox': (bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h),
'confidence': 0.75,
'class_id': 3
})
# 标注舰桥
cv2.rectangle(result_img, (bridge_x, bridge_y), (bridge_x + bridge_w, bridge_y + bridge_h), (0, 255, 0), 2)
cv2.putText(result_img, "舰桥: 0.75", (bridge_x, bridge_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
return parts, result_img
def _identify_part_by_position(self, rel_x, rel_y, aspect_ratio, area_ratio, ship_type):
"""根据相对位置和形状识别舰船部件"""
# 根据舰船类型和位置确定部件
if "航母" in ship_type or "航空母舰" in ship_type:
# 航母部件
if rel_y < 0.3:
if rel_x > 0.5:
if aspect_ratio < 1 and area_ratio > 0.05:
return "舰岛"
elif area_ratio < 0.03:
return "雷达"
elif area_ratio < 0.02 and aspect_ratio < 1:
return "桅杆"
elif area_ratio < 0.03:
return "相控阵雷达"
elif rel_y > 0.4 and rel_x > 0.1 and rel_x < 0.9:
return "飞行甲板"
elif rel_x < 0.5 and rel_y > 0.6:
return "弹射器"
elif aspect_ratio > 1 and area_ratio < 0.05 and rel_y > 0.3:
return "舰载机"
elif aspect_ratio < 1 and area_ratio < 0.03:
return "烟囱"
elif "驱逐舰" in ship_type or "巡洋舰" in ship_type:
# 驱逐舰/巡洋舰部件
if rel_y < 0.35:
if area_ratio > 0.1:
return "舰桥"
elif aspect_ratio < 1 and area_ratio < 0.02:
return "桅杆"
elif area_ratio < 0.03:
if aspect_ratio > 1.5:
return "相控阵雷达"
else:
return "球形雷达"
elif rel_x < 0.3 and rel_y < 0.5:
if area_ratio < 0.05:
return "主炮"
elif rel_x > 0.6 and area_ratio < 0.05:
return "垂发系统"
elif rel_y > 0.6 and rel_x > 0.7 and area_ratio > 0.05:
return "直升机甲板"
elif aspect_ratio < 1 and area_ratio < 0.02:
return "烟囱"
elif "护卫舰" in ship_type:
# 护卫舰部件
if rel_y < 0.4:
if area_ratio > 0.08:
return "舰桥"
elif area_ratio < 0.03 and aspect_ratio < 1:
return "桅杆"
elif area_ratio < 0.03:
return "雷达"
elif rel_x < 0.3:
return "舰炮"
elif rel_x > 0.7 and area_ratio > 0.05:
return "直升机甲板"
elif rel_x > 0.6 and area_ratio < 0.04:
return "导弹发射器"
elif "潜艇" in ship_type:
# 潜艇部件
if rel_y < 0.3 and area_ratio > 0.05:
return "塔台"
elif rel_x < 0.3 and area_ratio < 0.04:
return "鱼雷管"
elif rel_y < 0.2 and area_ratio < 0.02:
return "通信天线"
# 通用部件识别 - 基于位置和形状
if rel_y < 0.3:
if area_ratio > 0.08:
return "舰桥"
elif aspect_ratio < 1 and area_ratio < 0.03:
return "桅杆"
elif area_ratio < 0.03:
if aspect_ratio > 1:
return "相控阵雷达"
else:
return "雷达"
elif rel_y < 0.6:
if rel_x < 0.25 and area_ratio < 0.05:
return "舰炮"
elif aspect_ratio < 1 and area_ratio < 0.03:
return "烟囱"
elif aspect_ratio > 1.5 and area_ratio < 0.05 and rel_x > 0.6:
return "导弹发射器"
elif rel_y > 0.6:
if rel_x > 0.7 and area_ratio > 0.05:
return "直升机甲板"
elif area_ratio > 0.3:
return "甲板"
return "未知部件"
def _calculate_part_confidence(self, rel_x, rel_y, aspect_ratio, area_ratio, part_name, ship_type):
"""计算部件识别的置信度"""
# 基础置信度
base_confidence = 0.5
# 根据部件类型和位置调整置信度
if part_name == "舰桥":
if rel_y < 0.4 and area_ratio > 0.05:
return min(0.9, base_confidence + 0.3)
elif part_name == "雷达":
if rel_y < 0.3:
return min(0.85, base_confidence + 0.25)
elif part_name == "舰炮":
if rel_x < 0.3:
return min(0.8, base_confidence + 0.2)
elif part_name == "导弹发射器":
if 0.4 < rel_x < 0.8:
return min(0.75, base_confidence + 0.15)
elif part_name == "飞行甲板":
if "航母" in ship_type and area_ratio > 0.3:
return min(0.95, base_confidence + 0.4)
elif part_name == "舰岛":
if "航母" in ship_type and rel_x > 0.6 and rel_y < 0.3:
return min(0.9, base_confidence + 0.3)
elif part_name == "直升机甲板":
if rel_x > 0.7 and rel_y > 0.6:
return min(0.8, base_confidence + 0.2)
# 根据舰船类型增加特定部件的置信度
if "航母" in ship_type:
if part_name in ["舰岛", "飞行甲板", "舰载机", "弹射器"]:
return min(0.85, base_confidence + 0.25)
elif "驱逐舰" in ship_type:
if part_name in ["舰桥", "主炮", "垂发系统", "相控阵雷达"]:
return min(0.85, base_confidence + 0.25)
elif "护卫舰" in ship_type:
if part_name in ["舰桥", "舰炮", "导弹发射器"]:
return min(0.8, base_confidence + 0.2)
elif "潜艇" in ship_type:
if part_name in ["塔台", "鱼雷管"]:
return min(0.85, base_confidence + 0.25)
return base_confidence
def _calculate_iou(self, box1, box2):
"""计算两个边界框的IoU"""
# 确保所有坐标为数字
x1_1, y1_1, x2_1, y2_1 = box1
x1_2, y1_2, x2_2, y2_2 = box2
# 计算交集面积
x_left = max(x1_1, x1_2)
y_top = max(y1_1, y1_2)
x_right = min(x2_1, x2_2)
y_bottom = min(y2_1, y2_2)
# 如果没有交集返回0
if x_right < x_left or y_bottom < y_top:
return 0.0
intersection_area = (x_right - x_left) * (y_bottom - y_top)
# 计算两个框的面积
box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
# 计算IoU
return intersection_area / float(box1_area + box2_area - intersection_area)
def _add_specific_ship_parts(self, ship_img, ship_box, ship_type):
"""
为特定类型的舰船添加可能的部件
Args:
ship_img: 舰船图像
ship_box: 舰船边界框
ship_type: 舰船类型
Returns:
additional_parts: 额外检测到的部件列表
"""
additional_parts = []
h, w = ship_img.shape[:2]
# 获取目标舰船类型的特定部件
target_parts = self._get_target_parts(ship_type)
if "航空母舰" in ship_type:
# 检测舰岛
island_x = int(w * 0.75) # 舰岛通常在右侧四分之三处
island_y = int(h * 0.2) # 从顶部20%开始
island_w = int(w * 0.2) # 宽度约为舰船宽度的20%
island_h = int(h * 0.3) # 高度约为舰船高度的30%
# 验证区域
island_roi = ship_img[island_y:island_y+island_h, island_x:island_x+island_w]
if island_roi.size > 0:
# 分析区域特征
gray = cv2.cvtColor(island_roi, cv2.COLOR_BGR2GRAY)
edges = cv2.Canny(gray, 50, 150)
edge_pixels = cv2.countNonZero(edges)
edge_density = edge_pixels / (island_roi.shape[0] * island_roi.shape[1]) if island_roi.size > 0 else 0
# 如果区域有足够的边缘特征,识别为舰岛
if edge_density > 0.1 and "舰岛" in target_parts:
additional_parts.append({
"name": "舰岛",
"confidence": 0.85,
"box": [island_x, island_y, island_x + island_w, island_y + island_h],
"relative_box": [island_x/w, island_y/h, (island_x+island_w)/w, (island_y+island_h)/h]
})
# 检测飞行甲板
deck_x = int(w * 0.05) # 从左侧5%开始
deck_y = int(h * 0.1) # 从顶部10%开始
deck_w = int(w * 0.9) # 宽度约为舰船宽度的90%
deck_h = int(h * 0.2) # 高度约为舰船高度的20%
if "飞行甲板" in target_parts:
additional_parts.append({
"name": "飞行甲板",
"confidence": 0.9,
"box": [deck_x, deck_y, deck_x + deck_w, deck_y + deck_h],
"relative_box": [deck_x/w, deck_y/h, (deck_x+deck_w)/w, (deck_y+deck_h)/h]
})
# 检测舰载机(如果有)
# 分析顶部区域是否有舰载机特征
top_area = ship_img[0:int(h*0.3), :]
if top_area.size > 0:
gray = cv2.cvtColor(top_area, cv2.COLOR_BGR2GRAY)
blur = cv2.GaussianBlur(gray, (5, 5), 0)
_, thresh = cv2.threshold(blur, 120, 255, cv2.THRESH_BINARY_INV)
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 筛选可能的舰载机轮廓
aircraft_contours = []
for cnt in contours:
area = cv2.contourArea(cnt)
if area > 100 and area < 5000: # 舰载机大小范围
x, y, w_box, h_box = cv2.boundingRect(cnt)
aspect_ratio = w_box / h_box if h_box > 0 else 0
if 1.5 < aspect_ratio < 3.0: # 舰载机长宽比范围
aircraft_contours.append(cnt)
# 为每个可能的舰载机添加部件
for i, cnt in enumerate(aircraft_contours[:5]): # 最多添加5个舰载机
x, y, w_box, h_box = cv2.boundingRect(cnt)
if "舰载机" in target_parts:
additional_parts.append({
"name": "舰载机",
"confidence": 0.75,
"box": [x, y, x + w_box, y + h_box],
"relative_box": [x/w, y/h, (x+w_box)/w, (y+h_box)/h]
})
elif "驱逐舰" in ship_type or "护卫舰" in ship_type:
# 检测舰炮(通常在舰首)
gun_x = int(w * 0.05) # 从左侧5%开始(假设舰首在左侧)
gun_y = int(h * 0.1) # 从顶部10%开始
gun_w = int(w * 0.15) # 宽度约为舰船宽度的15%
gun_h = int(h * 0.15) # 高度约为舰船高度的15%
# 验证区域
gun_roi = ship_img[gun_y:gun_y+gun_h, gun_x:gun_x+gun_w]
if gun_roi.size > 0 and "舰炮" in target_parts:
# 可以使用更复杂的检测算法来验证
additional_parts.append({
"name": "舰炮",
"confidence": 0.8,
"box": [gun_x, gun_y, gun_x + gun_w, gun_y + gun_h],
"relative_box": [gun_x/w, gun_y/h, (gun_x+gun_w)/w, (gun_y+gun_h)/h]
})
# 检测舰桥(通常在中部偏后)
bridge_x = int(w * 0.4) # 从40%处开始
bridge_y = int(h * 0.05) # 从顶部5%开始
bridge_w = int(w * 0.2) # 宽度约为舰船宽度的20%
bridge_h = int(h * 0.3) # 高度约为舰船高度的30%
# 验证区域
bridge_roi = ship_img[bridge_y:bridge_y+bridge_h, bridge_x:bridge_x+bridge_w]
if bridge_roi.size > 0 and "舰桥" in target_parts:
# 分析区域特征
gray = cv2.cvtColor(bridge_roi, cv2.COLOR_BGR2GRAY)
edges = cv2.Canny(gray, 50, 150)
edge_pixels = cv2.countNonZero(edges)
edge_density = edge_pixels / (bridge_roi.shape[0] * bridge_roi.shape[1]) if bridge_roi.size > 0 else 0
if edge_density > 0.1:
additional_parts.append({
"name": "舰桥",
"confidence": 0.85,
"box": [bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h],
"relative_box": [bridge_x/w, bridge_y/h, (bridge_x+bridge_w)/w, (bridge_y+bridge_h)/h]
})
# 检测雷达(通常在舰桥顶部)
radar_x = int(w * 0.45) # 从45%处开始
radar_y = int(h * 0.05) # 从顶部5%开始
radar_w = int(w * 0.1) # 宽度约为舰船宽度的10%
radar_h = int(h * 0.1) # 高度约为舰船高度的10%
if "雷达" in target_parts:
additional_parts.append({
"name": "雷达",
"confidence": 0.7,
"box": [radar_x, radar_y, radar_x + radar_w, radar_y + radar_h],
"relative_box": [radar_x/w, radar_y/h, (radar_x+radar_w)/w, (radar_y+radar_h)/h]
})
# 检测垂直发射系统(通常在中前部)
vls_x = int(w * 0.25) # 从25%处开始
vls_y = int(h * 0.1) # 从顶部10%开始
vls_w = int(w * 0.15) # 宽度约为舰船宽度的15%
vls_h = int(h * 0.1) # 高度约为舰船高度的10%
if "垂直发射系统" in target_parts:
additional_parts.append({
"name": "垂直发射系统",
"confidence": 0.75,
"box": [vls_x, vls_y, vls_x + vls_w, vls_y + vls_h],
"relative_box": [vls_x/w, vls_y/h, (vls_x+vls_w)/w, (vls_y+vls_h)/h]
})
elif "潜艇" in ship_type:
# 检测舰塔(通常在潜艇中部上方)
tower_x = int(w * 0.4) # 从40%处开始
tower_y = int(h * 0.0) # 从顶部开始
tower_w = int(w * 0.2) # 宽度约为潜艇宽度的20%
tower_h = int(h * 0.5) # 高度约为潜艇高度的50%
if "舰塔" in target_parts:
additional_parts.append({
"name": "舰塔",
"confidence": 0.9,
"box": [tower_x, tower_y, tower_x + tower_w, tower_y + tower_h],
"relative_box": [tower_x/w, tower_y/h, (tower_x+tower_w)/w, (tower_y+tower_h)/h]
})
elif "巡洋舰" in ship_type:
# 检测舰桥(通常在中部)
bridge_x = int(w * 0.4) # 从40%处开始
bridge_y = int(h * 0.05) # 从顶部5%开始
bridge_w = int(w * 0.2) # 宽度约为舰船宽度的20%
bridge_h = int(h * 0.3) # 高度约为舰船高度的30%
if "舰桥" in target_parts:
additional_parts.append({
"name": "舰桥",
"confidence": 0.85,
"box": [bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h],
"relative_box": [bridge_x/w, bridge_y/h, (bridge_x+bridge_w)/w, (bridge_y+bridge_h)/h]
})
# 检测主炮(通常在前部)
gun_x = int(w * 0.1) # 从10%处开始
gun_y = int(h * 0.1) # 从顶部10%开始
gun_w = int(w * 0.15) # 宽度约为舰船宽度的15%
gun_h = int(h * 0.15) # 高度约为舰船高度的15%
if "舰炮" in target_parts:
additional_parts.append({
"name": "舰炮",
"confidence": 0.8,
"box": [gun_x, gun_y, gun_x + gun_w, gun_y + gun_h],
"relative_box": [gun_x/w, gun_y/h, (gun_x+gun_w)/w, (gun_y+gun_h)/h]
})
# 检测垂直发射系统(通常分布在前后部)
vls1_x = int(w * 0.2) # 前部VLS
vls1_y = int(h * 0.1)
vls1_w = int(w * 0.1)
vls1_h = int(h * 0.1)
vls2_x = int(w * 0.6) # 后部VLS
vls2_y = int(h * 0.1)
vls2_w = int(w * 0.1)
vls2_h = int(h * 0.1)
if "垂直发射系统" in target_parts:
additional_parts.append({
"name": "垂直发射系统",
"confidence": 0.75,
"box": [vls1_x, vls1_y, vls1_x + vls1_w, vls1_y + vls1_h],
"relative_box": [vls1_x/w, vls1_y/h, (vls1_x+vls1_w)/w, (vls1_y+vls1_h)/h]
})
additional_parts.append({
"name": "垂直发射系统",
"confidence": 0.75,
"box": [vls2_x, vls2_y, vls2_x + vls2_w, vls2_y + vls2_h],
"relative_box": [vls2_x/w, vls2_y/h, (vls2_x+vls2_w)/w, (vls2_y+vls2_h)/h]
})
return additional_parts
def _get_target_parts(self, ship_type=""):
"""
根据舰船类型获取目标部件列表
Args:
ship_type: 舰船类型
Returns:
target_parts: 该类型舰船应该检测的部件列表
"""
# 默认检测所有部件
if not ship_type or ship_type not in self.ship_parts_map:
return list(self.part_types.keys())
# 返回特定舰船类型的部件列表
return self.ship_parts_map.get(ship_type, [])
def _map_to_ship_part(self, cls_id, ship_type=""):
"""
将YOLO类别ID映射到舰船部件名称
Args:
cls_id: YOLO类别ID
ship_type: 舰船类型名称
Returns:
part_name: 舰船部件名称
"""
# COCO数据集中的常见类别
coco_classes = {
0: "person", # 人 -> 可能是船员
1: "bicycle", # 自行车
2: "car", # 汽车
3: "motorcycle", # 摩托车
4: "airplane", # 飞机 -> 可能是舰载机
5: "bus", # 公交车
6: "train", # 火车
7: "truck", # 卡车
8: "boat", # 船
9: "traffic light", # 交通灯
10: "fire hydrant", # 消防栓
11: "stop sign", # 停止标志
12: "parking meter", # 停车计时器
13: "bench", # 长凳
14: "bird", # 鸟
15: "cat", # 猫
16: "dog", # 狗
17: "horse", # 马
24: "backpack", # 背包
25: "umbrella", # 雨伞
26: "handbag", # 手提包
27: "tie", # 领带
28: "suitcase", # 手提箱
33: "sports ball", # 运动球
34: "kite", # 风筝
35: "baseball bat", # 棒球棒
36: "baseball glove", # 棒球手套
41: "skateboard", # 滑板
42: "surfboard", # 冲浪板
43: "tennis racket", # 网球拍
59: "potted plant", # 盆栽植物
60: "bed", # 床
61: "dining table", # 餐桌
62: "toilet", # 厕所
63: "tv", # 电视
64: "laptop", # 笔记本电脑
65: "mouse", # 鼠标
66: "remote", # 遥控器
67: "keyboard", # 键盘
68: "cell phone", # 手机
69: "microwave", # 微波炉
70: "oven", # 烤箱
71: "toaster", # 烤面包机
72: "sink", # 水槽
73: "refrigerator", # 冰箱
74: "book", # 书
75: "clock", # 时钟
76: "vase", # 花瓶
77: "scissors", # 剪刀
78: "teddy bear", # 泰迪熊
79: "hair drier", # 吹风机
80: "toothbrush" # 牙刷
}
# 将COCO类别映射到舰船部件
part_mappings = {
"person": "船员",
"airplane": "舰载机",
"boat": "小艇",
"backpack": "装备",
"tie": "指挥官",
"suitcase": "设备箱",
"sports ball": "雷达球",
"kite": "雷达天线",
"surfboard": "救生设备",
"bed": "甲板",
"dining table": "甲板",
"toilet": "舱室",
"tv": "显示设备",
"laptop": "控制设备",
"mouse": "控制设备",
"remote": "控制设备",
"keyboard": "控制设备",
"cell phone": "通信设备",
"clock": "导航设备",
"scissors": "工具"
}
# 获取COCO类别名称
if cls_id in coco_classes:
coco_class = coco_classes[cls_id]
# 将COCO类别映射到舰船部件
if coco_class in part_mappings:
return part_mappings[coco_class]
# 如果是航母且检测到飞机,返回舰载机
if ("航母" in ship_type or "航空母舰" in ship_type) and cls_id == 4:
return "舰载机"
# 默认返回未知
return "unknown"
# 测试代码
if __name__ == "__main__":
# 初始化部件检测器
part_detector = ShipPartDetector()
# 测试图像路径和舰船边界框
test_image = "path/to/test/image.jpg"
test_ship_box = (100, 100, 500, 400) # 示例边界框
if os.path.exists(test_image):
# 执行部件检测
parts, result_img = part_detector.detect_parts(test_image, test_ship_box)
# 打印检测结果
print(f"检测到 {len(parts)} 个舰船部件:")
for i, part in enumerate(parts):
print(f"部件 {i+1}: 类型={part['name']}, 置信度={part['confidence']:.4f}, 位置={part['bbox']}")