|
|
import os
|
|
|
import sys
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
import cv2
|
|
|
from ultralytics import YOLO
|
|
|
from pathlib import Path
|
|
|
|
|
|
class ShipPartDetector:
|
|
|
"""
|
|
|
舰船部件检测模块,负责识别舰船的各个组成部分
|
|
|
支持检测的部件包括:舰桥、雷达、舰炮、导弹发射装置、直升机甲板等
|
|
|
"""
|
|
|
def __init__(self, model_path=None, device=None):
|
|
|
"""
|
|
|
初始化舰船部件检测器
|
|
|
|
|
|
Args:
|
|
|
model_path: 部件检测模型路径,如果为None则使用预训练通用模型后处理
|
|
|
device: 运行设备,可以是'cuda'或'cpu',None则自动选择
|
|
|
"""
|
|
|
# 设置模型路径
|
|
|
if model_path is None:
|
|
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
parent_dir = os.path.dirname(script_dir)
|
|
|
model_path = os.path.join(parent_dir, 'yolov8s-seg.pt')
|
|
|
|
|
|
self.model_path = model_path
|
|
|
|
|
|
# 设置设备
|
|
|
if device is None:
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
else:
|
|
|
self.device = device
|
|
|
|
|
|
# 舰船部件类型映射表 - 扩展版本
|
|
|
self.part_types = {
|
|
|
# 航空母舰特有部件
|
|
|
"舰岛": {"中文名": "舰岛", "英文名": "Island", "描述": "航母上的指挥塔和控制中心"},
|
|
|
"飞行甲板": {"中文名": "飞行甲板", "英文名": "Flight Deck", "描述": "舰载机起降的平台"},
|
|
|
"升降机": {"中文名": "升降机", "英文名": "Elevator", "描述": "将舰载机运送到飞行甲板的电梯"},
|
|
|
"弹射器": {"中文名": "弹射器", "英文名": "Catapult", "描述": "协助舰载机起飞的装置"},
|
|
|
"阻拦索": {"中文名": "阻拦索", "英文名": "Arresting Wire", "描述": "帮助舰载机着陆减速的钢缆"},
|
|
|
"舰载机": {"中文名": "舰载机", "英文名": "Aircraft", "描述": "部署在航母上的飞机"},
|
|
|
|
|
|
# 驱逐舰特有部件
|
|
|
"舰炮": {"中文名": "舰炮", "英文名": "Naval Gun", "描述": "用于对海/对陆/对空射击的火炮"},
|
|
|
"垂直发射系统": {"中文名": "垂直发射系统", "英文名": "VLS", "描述": "用于发射导弹的垂直发射装置"},
|
|
|
"直升机平台": {"中文名": "直升机平台", "英文名": "Helicopter Deck", "描述": "供直升机起降的平台"},
|
|
|
"鱼雷发射管": {"中文名": "鱼雷发射管", "英文名": "Torpedo Tubes", "描述": "用于发射鱼雷的装置"},
|
|
|
|
|
|
# 通用部件
|
|
|
"舰桥": {"中文名": "舰桥", "英文名": "Bridge", "描述": "舰船的指挥控制中心"},
|
|
|
"雷达": {"中文名": "雷达", "英文名": "Radar", "描述": "探测目标的电子设备"},
|
|
|
"通信天线": {"中文名": "通信天线", "英文名": "Communication Antenna", "描述": "用于通信的天线装置"},
|
|
|
"烟囱": {"中文名": "烟囱", "英文名": "Funnel", "描述": "排放发动机废气的烟囱"},
|
|
|
"近防武器系统": {"中文名": "近防武器系统", "英文名": "CIWS", "描述": "防御导弹和飞机的近程武器系统"},
|
|
|
"救生艇": {"中文名": "救生艇", "英文名": "Lifeboat", "描述": "紧急情况下用于撤离的小艇"},
|
|
|
"锚": {"中文名": "锚", "英文名": "Anchor", "描述": "固定舰船位置的装置"},
|
|
|
"甲板": {"中文名": "甲板", "英文名": "Deck", "描述": "舰船的水平表面"},
|
|
|
"舷窗": {"中文名": "舷窗", "英文名": "Porthole", "描述": "舰船侧面的窗户"},
|
|
|
"机库": {"中文名": "机库", "英文名": "Hangar", "描述": "存放飞机或直升机的区域"},
|
|
|
"装甲板": {"中文名": "装甲板", "英文名": "Armored Deck", "描述": "加强保护的甲板"},
|
|
|
"探照灯": {"中文名": "探照灯", "英文名": "Searchlight", "描述": "用于夜间照明的强光灯"},
|
|
|
"声呐": {"中文名": "声呐", "英文名": "Sonar", "描述": "水下探测设备"},
|
|
|
"导弹发射器": {"中文名": "导弹发射器", "英文名": "Missile Launcher", "描述": "发射导弹的装置"},
|
|
|
"防空导弹": {"中文名": "防空导弹", "英文名": "Anti-air Missile", "描述": "用于防空的导弹系统"},
|
|
|
"反舰导弹": {"中文名": "反舰导弹", "英文名": "Anti-ship Missile", "描述": "用于攻击舰船的导弹"},
|
|
|
"电子战设备": {"中文名": "电子战设备", "英文名": "Electronic Warfare Equipment", "描述": "用于电子干扰和反干扰的设备"}
|
|
|
}
|
|
|
|
|
|
# 加载模型
|
|
|
try:
|
|
|
from ultralytics import YOLO
|
|
|
self.model = YOLO(self.model_path)
|
|
|
self.model_loaded = True
|
|
|
print(f"成功加载部件检测模型: {self.model_path}")
|
|
|
except Exception as e:
|
|
|
print(f"加载部件检测模型出错: {e}")
|
|
|
self.model = None
|
|
|
self.model_loaded = False
|
|
|
|
|
|
# 舰船类型与特有部件映射
|
|
|
self.ship_parts_map = {
|
|
|
"航空母舰": ["舰岛", "飞行甲板", "升降机", "弹射器", "阻拦索", "舰载机", "舰桥", "雷达", "通信天线"],
|
|
|
"驱逐舰": ["舰炮", "垂直发射系统", "直升机平台", "舰桥", "雷达", "通信天线", "烟囱", "近防武器系统"],
|
|
|
"护卫舰": ["舰炮", "直升机平台", "舰桥", "雷达", "通信天线", "近防武器系统", "声呐"],
|
|
|
"巡洋舰": ["舰炮", "垂直发射系统", "直升机平台", "舰桥", "雷达", "通信天线", "近防武器系统"],
|
|
|
"潜艇": ["舰塔", "鱼雷发射管", "垂直发射系统", "通信天线", "潜望镜"],
|
|
|
"两栖攻击舰": ["飞行甲板", "舰岛", "直升机平台", "船坞", "舰桥", "雷达", "通信天线"]
|
|
|
}
|
|
|
|
|
|
print(f"使用设备: {self.device}")
|
|
|
|
|
|
# 启用通用部件检测
|
|
|
self.enable_generic_parts = True
|
|
|
|
|
|
# 部件类别映射 (这里的ID应该对应训练好的模型中的类别ID)
|
|
|
self.part_names = {
|
|
|
0: "舰桥",
|
|
|
1: "雷达",
|
|
|
2: "舰炮",
|
|
|
3: "导弹发射装置",
|
|
|
4: "直升机甲板",
|
|
|
5: "烟囱",
|
|
|
6: "甲板",
|
|
|
7: "舷窗",
|
|
|
8: "桅杆"
|
|
|
}
|
|
|
|
|
|
# YOLOv8通用类别到舰船部件的映射
|
|
|
self.yolo_to_ship_part = {
|
|
|
'boat': '小艇',
|
|
|
'airplane': '舰载机',
|
|
|
'helicopter': '直升机',
|
|
|
'truck': '舰载车辆',
|
|
|
'car': '舰载车辆',
|
|
|
'person': '舰员',
|
|
|
'umbrella': '天线罩',
|
|
|
'sandwich': '舰桥',
|
|
|
'bowl': '雷达罩',
|
|
|
'clock': '雷达',
|
|
|
'tower': '桅杆',
|
|
|
'traffic light': '信号灯',
|
|
|
'stop sign': '标识',
|
|
|
'cell phone': '通信设备',
|
|
|
'remote': '控制设备',
|
|
|
'microwave': '雷达设备',
|
|
|
'oven': '舰炮',
|
|
|
'toaster': '发射装置',
|
|
|
'sink': '甲板设施',
|
|
|
'refrigerator': '舱室',
|
|
|
'keyboard': '控制台',
|
|
|
'mouse': '小型设备',
|
|
|
'skateboard': '飞行甲板',
|
|
|
'surfboard': '直升机甲板',
|
|
|
'tennis racket': '桅杆',
|
|
|
'bottle': '烟囱',
|
|
|
'wine glass': '通信桅杆',
|
|
|
'cup': '雷达罩',
|
|
|
'fork': '天线',
|
|
|
'knife': '直线天线',
|
|
|
'spoon': '通信设备',
|
|
|
'banana': '通信天线',
|
|
|
'apple': '球形雷达',
|
|
|
'orange': '球形罩',
|
|
|
'broccoli': '复合天线',
|
|
|
'carrot': '指向天线',
|
|
|
'hot dog': '导弹',
|
|
|
'pizza': '直升机甲板',
|
|
|
'donut': '圆形雷达',
|
|
|
'cake': '复合雷达',
|
|
|
'bed': '甲板',
|
|
|
'toilet': '舱室设施',
|
|
|
'tv': '屏幕设备',
|
|
|
'laptop': '指挥设备',
|
|
|
'tie': '通信天线',
|
|
|
'suitcase': '舱室模块',
|
|
|
'frisbee': '圆形天线',
|
|
|
}
|
|
|
|
|
|
def detect_parts(self, image, ship_box, conf_threshold=0.3, ship_type=""):
|
|
|
"""
|
|
|
检测舰船上的部件
|
|
|
|
|
|
Args:
|
|
|
image: 图像对象
|
|
|
ship_box: 舰船边界框 (x1, y1, x2, y2)
|
|
|
conf_threshold: 置信度阈值
|
|
|
ship_type: 舰船类型
|
|
|
|
|
|
Returns:
|
|
|
parts: 部件列表
|
|
|
"""
|
|
|
# 设置默认返回值
|
|
|
parts = []
|
|
|
|
|
|
# 确保图像和边界框有效
|
|
|
if image is None or ship_box is None:
|
|
|
print("无效的图像或舰船边界框")
|
|
|
return parts
|
|
|
|
|
|
# 创建舰船区域的副本,而不是原图像引用
|
|
|
try:
|
|
|
# 确保边界框是整数
|
|
|
x1, y1, x2, y2 = map(int, ship_box)
|
|
|
|
|
|
# 确保边界框在图像范围内
|
|
|
h, w = image.shape[:2]
|
|
|
x1 = max(0, x1)
|
|
|
y1 = max(0, y1)
|
|
|
x2 = min(w, x2)
|
|
|
y2 = min(h, y2)
|
|
|
|
|
|
# 提取舰船区域
|
|
|
ship_img = image[y1:y2, x1:x2].copy()
|
|
|
|
|
|
# 检查提取的图像是否有效
|
|
|
if ship_img.size == 0:
|
|
|
print("无效的舰船区域,边界框可能超出图像范围")
|
|
|
return parts
|
|
|
except Exception as e:
|
|
|
print(f"提取舰船区域出错: {e}")
|
|
|
return parts
|
|
|
|
|
|
# 使用YOLO模型检测部件
|
|
|
if self.model is not None:
|
|
|
try:
|
|
|
# 根据舰船类型过滤目标部件
|
|
|
target_parts = self._get_target_parts(ship_type)
|
|
|
|
|
|
# 使用YOLO模型进行检测
|
|
|
results = self.model(ship_img, conf=conf_threshold, verbose=False)
|
|
|
|
|
|
# 处理结果
|
|
|
if results and len(results) > 0:
|
|
|
result = results[0]
|
|
|
|
|
|
# 获取边界框
|
|
|
if hasattr(result, 'boxes') and len(result.boxes) > 0:
|
|
|
boxes = result.boxes
|
|
|
|
|
|
for i, box in enumerate(boxes):
|
|
|
# 检查类别
|
|
|
cls_id = int(box.cls.item()) if hasattr(box, 'cls') else -1
|
|
|
|
|
|
# 确定部件类型
|
|
|
part_name = self._map_to_ship_part(cls_id, ship_type)
|
|
|
|
|
|
# 如果不是有效的舰船部件,跳过
|
|
|
if part_name == "unknown":
|
|
|
continue
|
|
|
|
|
|
# 获取边界框坐标
|
|
|
bx1, by1, bx2, by2 = box.xyxy[0].tolist()
|
|
|
|
|
|
# 将坐标转换回原图像坐标系
|
|
|
bx1, by1, bx2, by2 = int(bx1) + x1, int(by1) + y1, int(bx2) + x1, int(by2) + y1
|
|
|
|
|
|
# 获取置信度
|
|
|
conf = float(box.conf.item()) if hasattr(box, 'conf') else 0.5
|
|
|
|
|
|
# 添加到部件列表
|
|
|
part = {
|
|
|
'name': part_name,
|
|
|
'bbox': (bx1, by1, bx2, by2),
|
|
|
'confidence': conf
|
|
|
}
|
|
|
parts.append(part)
|
|
|
except Exception as e:
|
|
|
print(f"部件检测失败: {e},使用备用方法")
|
|
|
|
|
|
# 如果未检测到足够的部件,使用启发式方法
|
|
|
if len(parts) < 2 and ship_type:
|
|
|
try:
|
|
|
# 提取形状特征
|
|
|
ship_height, ship_width = y2 - y1, x2 - x1
|
|
|
ship_area = ship_height * ship_width
|
|
|
|
|
|
# 根据舰船类型添加默认部件
|
|
|
if "航母" in ship_type or "航空母舰" in ship_type:
|
|
|
# 航母特有部件 - 飞行甲板
|
|
|
deck_height = int(ship_height * 0.15)
|
|
|
deck_width = ship_width
|
|
|
deck_x1 = x1
|
|
|
deck_y1 = y1 + int(ship_height * 0.05)
|
|
|
deck_x2 = x1 + deck_width
|
|
|
deck_y2 = deck_y1 + deck_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "飞行甲板",
|
|
|
'bbox': (deck_x1, deck_y1, deck_x2, deck_y2),
|
|
|
'confidence': 0.85
|
|
|
})
|
|
|
|
|
|
# 舰岛
|
|
|
island_width = int(ship_width * 0.25)
|
|
|
island_height = int(ship_height * 0.3)
|
|
|
island_x1 = x1 + int(ship_width * 0.6)
|
|
|
island_y1 = y1 + int(ship_height * 0.2)
|
|
|
island_x2 = island_x1 + island_width
|
|
|
island_y2 = island_y1 + island_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "舰岛",
|
|
|
'bbox': (island_x1, island_y1, island_x2, island_y2),
|
|
|
'confidence': 0.8
|
|
|
})
|
|
|
|
|
|
elif "驱逐舰" in ship_type:
|
|
|
# 驱逐舰特有部件 - 舰桥
|
|
|
bridge_width = int(ship_width * 0.2)
|
|
|
bridge_height = int(ship_height * 0.4)
|
|
|
bridge_x1 = x1 + int(ship_width * 0.4)
|
|
|
bridge_y1 = y1 + int(ship_height * 0.15)
|
|
|
bridge_x2 = bridge_x1 + bridge_width
|
|
|
bridge_y2 = bridge_y1 + bridge_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "舰桥",
|
|
|
'bbox': (bridge_x1, bridge_y1, bridge_x2, bridge_y2),
|
|
|
'confidence': 0.8
|
|
|
})
|
|
|
|
|
|
# 垂发系统
|
|
|
vls_width = int(ship_width * 0.15)
|
|
|
vls_height = int(ship_height * 0.2)
|
|
|
vls_x1 = x1 + int(ship_width * 0.2)
|
|
|
vls_y1 = y1 + int(ship_height * 0.25)
|
|
|
vls_x2 = vls_x1 + vls_width
|
|
|
vls_y2 = vls_y1 + vls_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "垂发系统",
|
|
|
'bbox': (vls_x1, vls_y1, vls_x2, vls_y2),
|
|
|
'confidence': 0.7
|
|
|
})
|
|
|
|
|
|
# 主炮
|
|
|
gun_width = int(ship_width * 0.1)
|
|
|
gun_height = int(ship_height * 0.15)
|
|
|
gun_x1 = x1 + int(ship_width * 0.05)
|
|
|
gun_y1 = y1 + int(ship_height * 0.3)
|
|
|
gun_x2 = gun_x1 + gun_width
|
|
|
gun_y2 = gun_y1 + gun_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "主炮",
|
|
|
'bbox': (gun_x1, gun_y1, gun_x2, gun_y2),
|
|
|
'confidence': 0.75
|
|
|
})
|
|
|
|
|
|
elif "护卫舰" in ship_type:
|
|
|
# 护卫舰特有部件
|
|
|
# 舰桥
|
|
|
bridge_width = int(ship_width * 0.18)
|
|
|
bridge_height = int(ship_height * 0.35)
|
|
|
bridge_x1 = x1 + int(ship_width * 0.35)
|
|
|
bridge_y1 = y1 + int(ship_height * 0.2)
|
|
|
bridge_x2 = bridge_x1 + bridge_width
|
|
|
bridge_y2 = bridge_y1 + bridge_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "舰桥",
|
|
|
'bbox': (bridge_x1, bridge_y1, bridge_x2, bridge_y2),
|
|
|
'confidence': 0.8
|
|
|
})
|
|
|
|
|
|
# 直升机甲板
|
|
|
heli_width = int(ship_width * 0.25)
|
|
|
heli_height = int(ship_height * 0.25)
|
|
|
heli_x1 = x1 + int(ship_width * 0.7)
|
|
|
heli_y1 = y1 + int(ship_height * 0.25)
|
|
|
heli_x2 = heli_x1 + heli_width
|
|
|
heli_y2 = heli_y1 + heli_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "直升机甲板",
|
|
|
'bbox': (heli_x1, heli_y1, heli_x2, heli_y2),
|
|
|
'confidence': 0.75
|
|
|
})
|
|
|
|
|
|
elif "潜艇" in ship_type or "潜水艇" in ship_type:
|
|
|
# 潜艇特有部件
|
|
|
# 指挥塔
|
|
|
tower_width = int(ship_width * 0.15)
|
|
|
tower_height = int(ship_height * 0.4)
|
|
|
tower_x1 = x1 + int(ship_width * 0.4)
|
|
|
tower_y1 = y1 + int(ship_height * 0.1)
|
|
|
tower_x2 = tower_x1 + tower_width
|
|
|
tower_y2 = tower_y1 + tower_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "指挥塔",
|
|
|
'bbox': (tower_x1, tower_y1, tower_x2, tower_y2),
|
|
|
'confidence': 0.8
|
|
|
})
|
|
|
|
|
|
else:
|
|
|
# 通用舰船部件
|
|
|
# 舰桥
|
|
|
bridge_width = int(ship_width * 0.2)
|
|
|
bridge_height = int(ship_height * 0.35)
|
|
|
bridge_x1 = x1 + int(ship_width * 0.4)
|
|
|
bridge_y1 = y1 + int(ship_height * 0.2)
|
|
|
bridge_x2 = bridge_x1 + bridge_width
|
|
|
bridge_y2 = bridge_y1 + bridge_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "舰桥",
|
|
|
'bbox': (bridge_x1, bridge_y1, bridge_x2, bridge_y2),
|
|
|
'confidence': 0.8
|
|
|
})
|
|
|
|
|
|
# 雷达
|
|
|
radar_width = int(ship_width * 0.1)
|
|
|
radar_height = int(ship_height * 0.15)
|
|
|
radar_x1 = x1 + int(ship_width * 0.45)
|
|
|
radar_y1 = y1 + int(ship_height * 0.05)
|
|
|
radar_x2 = radar_x1 + radar_width
|
|
|
radar_y2 = radar_y1 + radar_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "雷达",
|
|
|
'bbox': (radar_x1, radar_y1, radar_x2, radar_y2),
|
|
|
'confidence': 0.7
|
|
|
})
|
|
|
|
|
|
# 甲板
|
|
|
deck_width = int(ship_width * 0.8)
|
|
|
deck_height = int(ship_height * 0.25)
|
|
|
deck_x1 = x1 + int(ship_width * 0.1)
|
|
|
deck_y1 = y1 + int(ship_height * 0.6)
|
|
|
deck_x2 = deck_x1 + deck_width
|
|
|
deck_y2 = deck_y1 + deck_height
|
|
|
|
|
|
parts.append({
|
|
|
'name': "甲板",
|
|
|
'bbox': (deck_x1, deck_y1, deck_x2, deck_y2),
|
|
|
'confidence': 0.75
|
|
|
})
|
|
|
except Exception as e:
|
|
|
except Exception as e:
|
|
|
|
|
|
return parts
|
|
|
|
|
|
def _map_yolo_class_to_ship_part(self, cls_name, ship_type=""):
|
|
|
"""将YOLO类别名称映射到舰船部件名称"""
|
|
|
# 通用物体到舰船部件的映射
|
|
|
mapping = {
|
|
|
'person': '人员',
|
|
|
'bicycle': '小型设备',
|
|
|
'car': '小型设备',
|
|
|
'motorcycle': '小型设备',
|
|
|
'airplane': '舰载机',
|
|
|
'bus': '车辆',
|
|
|
'train': '小型设备',
|
|
|
'truck': '车辆',
|
|
|
'boat': '小艇',
|
|
|
'traffic light': '信号灯',
|
|
|
'fire hydrant': '消防设备',
|
|
|
'stop sign': '标志牌',
|
|
|
'parking meter': '小型设备',
|
|
|
'bench': '设备',
|
|
|
'bird': '无人机',
|
|
|
'cat': '小型设备',
|
|
|
'dog': '小型设备',
|
|
|
'horse': '小型设备',
|
|
|
'sheep': '小型设备',
|
|
|
'cow': '小型设备',
|
|
|
'elephant': '大型设备',
|
|
|
'bear': '大型设备',
|
|
|
'zebra': '小型设备',
|
|
|
'giraffe': '高大设备',
|
|
|
'backpack': '设备',
|
|
|
'umbrella': '小型设备',
|
|
|
'handbag': '设备',
|
|
|
'tie': '小型设备',
|
|
|
'suitcase': '设备',
|
|
|
'frisbee': '小型设备',
|
|
|
'skis': '小型设备',
|
|
|
'snowboard': '小型设备',
|
|
|
'sports ball': '小型设备',
|
|
|
'kite': '无人机',
|
|
|
'baseball bat': '小型设备',
|
|
|
'baseball glove': '小型设备',
|
|
|
'skateboard': '小型设备',
|
|
|
'surfboard': '救生设备',
|
|
|
'tennis racket': '小型设备',
|
|
|
'bottle': '设备',
|
|
|
'wine glass': '设备',
|
|
|
'cup': '设备',
|
|
|
'fork': '设备',
|
|
|
'knife': '设备',
|
|
|
'spoon': '设备',
|
|
|
'bowl': '设备',
|
|
|
'banana': '补给',
|
|
|
'apple': '补给',
|
|
|
'sandwich': '补给',
|
|
|
'orange': '补给',
|
|
|
'broccoli': '补给',
|
|
|
'carrot': '补给',
|
|
|
'hot dog': '补给',
|
|
|
'pizza': '补给',
|
|
|
'donut': '补给',
|
|
|
'cake': '补给',
|
|
|
'chair': '设备',
|
|
|
'couch': '设备',
|
|
|
'potted plant': '设备',
|
|
|
'bed': '设备',
|
|
|
'dining table': '设备',
|
|
|
'toilet': '设施',
|
|
|
'tv': '显示设备',
|
|
|
'laptop': '计算设备',
|
|
|
'mouse': '小型设备',
|
|
|
'remote': '控制设备',
|
|
|
'keyboard': '控制设备',
|
|
|
'cell phone': '通信设备',
|
|
|
'microwave': '设备',
|
|
|
'oven': '设备',
|
|
|
'toaster': '设备',
|
|
|
'sink': '设施',
|
|
|
'refrigerator': '设备',
|
|
|
'book': '资料',
|
|
|
'clock': '设备',
|
|
|
'vase': '设备',
|
|
|
'scissors': '工具',
|
|
|
'teddy bear': '设备',
|
|
|
'hair drier': '设备',
|
|
|
'toothbrush': '设备'
|
|
|
}
|
|
|
|
|
|
# 针对特定舰船类型的映射优化
|
|
|
if "航母" in ship_type or "航空母舰" in ship_type:
|
|
|
if cls_name == 'airplane':
|
|
|
return "舰载机"
|
|
|
elif cls_name in ['tv', 'laptop', 'cell phone']:
|
|
|
return "雷达设备"
|
|
|
elif cls_name in ['tower', 'building']:
|
|
|
return "舰岛"
|
|
|
elif "驱逐舰" in ship_type or "护卫舰" in ship_type:
|
|
|
if cls_name in ['tv', 'laptop']:
|
|
|
return "相控阵雷达"
|
|
|
elif cls_name in ['tower', 'building']:
|
|
|
return "舰桥"
|
|
|
elif cls_name in ['remote', 'bottle']:
|
|
|
return "导弹装置"
|
|
|
|
|
|
# 默认返回映射或原始类别名称
|
|
|
return mapping.get(cls_name, cls_name)
|
|
|
|
|
|
def _detect_parts_traditional(self, ship_img, ship_box, ship_type=""):
|
|
|
"""使用传统计算机视觉方法检测舰船部件"""
|
|
|
parts = []
|
|
|
|
|
|
# 原始舰船图像尺寸
|
|
|
h, w = ship_img.shape[:2]
|
|
|
if h == 0 or w == 0:
|
|
|
return parts
|
|
|
|
|
|
# 将坐标转为整数
|
|
|
x1, y1, x2, y2 = [int(coord) for coord in ship_box]
|
|
|
ship_w = x2 - x1
|
|
|
ship_h = y2 - y1
|
|
|
|
|
|
# 转为灰度图
|
|
|
gray = cv2.cvtColor(ship_img, cv2.COLOR_BGR2GRAY) if len(ship_img.shape) == 3 else ship_img.copy()
|
|
|
|
|
|
# 提取边缘
|
|
|
edges = cv2.Canny(gray, 50, 150)
|
|
|
_, thresh = cv2.threshold(gray, 100, 255, cv2.THRESH_BINARY)
|
|
|
|
|
|
# 寻找轮廓
|
|
|
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
|
|
|
# 分析轮廓并识别部件
|
|
|
for contour in contours:
|
|
|
# 计算轮廓面积和边界框
|
|
|
area = cv2.contourArea(contour)
|
|
|
if area < (w * h * 0.005): # 忽略太小的轮廓
|
|
|
continue
|
|
|
|
|
|
# 计算边界框
|
|
|
x, y, box_width, box_height = cv2.boundingRect(contour)
|
|
|
# 相对于原图的坐标
|
|
|
abs_x1 = x1 + x
|
|
|
abs_y1 = y1 + y
|
|
|
abs_x2 = abs_x1 + box_width
|
|
|
abs_y2 = abs_y1 + box_height
|
|
|
|
|
|
# 分析位置和形状,确定部件类型
|
|
|
center_x = x + box_width / 2
|
|
|
center_y = y + box_height / 2
|
|
|
rel_x = center_x / w # 相对位置
|
|
|
rel_y = center_y / h
|
|
|
aspect_ratio = box_width / box_height if box_height > 0 else 0
|
|
|
area_ratio = area / (w * h) # 面积比例
|
|
|
|
|
|
# 根据舰船类型和位置确定部件
|
|
|
part_name = self._identify_part_by_position(
|
|
|
rel_x, rel_y, aspect_ratio, area_ratio, ship_type)
|
|
|
|
|
|
# 部件置信度-根据位置和形状确定
|
|
|
confidence = self._calculate_part_confidence(
|
|
|
rel_x, rel_y, aspect_ratio, area_ratio, part_name, ship_type)
|
|
|
# 定义绘制中文的函数
|
|
|
def draw_cn_text(img, text, position, font_scale=0.5, color=(0, 255, 0), thickness=2):
|
|
|
"""使用PIL绘制中文文本并转回OpenCV格式"""
|
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
|
import numpy as np
|
|
|
|
|
|
# 转换为PIL图像
|
|
|
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
|
|
draw = ImageDraw.Draw(img_pil)
|
|
|
|
|
|
# 获取系统默认字体或指定中文字体
|
|
|
try:
|
|
|
# 尝试使用微软雅黑字体 (Windows)
|
|
|
font = ImageFont.truetype("msyh.ttc", int(font_scale * 20))
|
|
|
except:
|
|
|
try:
|
|
|
# 尝试使用宋体 (Windows)
|
|
|
font = ImageFont.truetype("simsun.ttc", int(font_scale * 20))
|
|
|
except:
|
|
|
try:
|
|
|
# 尝试使用WenQuanYi (Linux)
|
|
|
font = ImageFont.truetype("wqy-microhei.ttc", int(font_scale * 20))
|
|
|
except:
|
|
|
# 使用系统默认字体
|
|
|
font = ImageFont.load_default()
|
|
|
|
|
|
# 绘制文本
|
|
|
draw.text(position, text, font=font, fill=color[::-1]) # RGB顺序
|
|
|
|
|
|
# 转回OpenCV格式
|
|
|
return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
|
|
|
|
|
|
# 如果提供了舰船边界框,则只处理该区域
|
|
|
if ship_box is not None:
|
|
|
x1, y1, x2, y2 = ship_box
|
|
|
# 确保边界在图像范围内
|
|
|
x1, y1 = max(0, x1), max(0, y1)
|
|
|
x2, y2 = min(img.shape[1], x2), min(img.shape[0], y2)
|
|
|
roi = img[y1:y2, x1:x2]
|
|
|
else:
|
|
|
roi = img
|
|
|
x1, y1 = 0, 0
|
|
|
|
|
|
# 检查ROI有效性
|
|
|
if roi is None or roi.size == 0 or roi.shape[0] <= 0 or roi.shape[1] <= 0:
|
|
|
return [], img.copy() # 返回空部件列表和原始图像
|
|
|
|
|
|
# 获取ROI尺寸
|
|
|
roi_h, roi_w = roi.shape[:2]
|
|
|
|
|
|
# 如果未提供舰船类型,尝试推测
|
|
|
if ship_type is None or ship_type == "":
|
|
|
ship_type = self._infer_ship_type(roi).lower()
|
|
|
print(f"推测舰船类型: {ship_type}")
|
|
|
else:
|
|
|
ship_type = ship_type.lower()
|
|
|
print(f"使用提供的舰船类型: {ship_type}")
|
|
|
|
|
|
# 初始化部件列表和结果图像
|
|
|
parts = []
|
|
|
result_img = img.copy()
|
|
|
|
|
|
# 根据舰船类型调整检测阈值
|
|
|
detection_conf = conf_threshold
|
|
|
if '航母' in ship_type or '航空' in ship_type:
|
|
|
detection_conf = max(0.05, conf_threshold * 0.5) # 航母需要更低的阈值
|
|
|
elif '驱逐' in ship_type:
|
|
|
detection_conf = max(0.1, conf_threshold * 0.7) # 驱逐舰稍微降低阈值
|
|
|
|
|
|
# 进行部件检测
|
|
|
results = self.model(roi, conf=detection_conf, device=self.device)
|
|
|
|
|
|
# 处理检测结果
|
|
|
if len(results) > 0:
|
|
|
# 获取边界框
|
|
|
boxes = results[0].boxes
|
|
|
|
|
|
# 如果有分割结果,也使用它来提高识别精度
|
|
|
if hasattr(results[0], 'masks') and results[0].masks is not None:
|
|
|
masks = results[0].masks
|
|
|
else:
|
|
|
masks = None
|
|
|
|
|
|
# 对检测到的所有物体进行分析
|
|
|
processed_boxes = []
|
|
|
|
|
|
for i, det in enumerate(boxes):
|
|
|
try:
|
|
|
# 提取边界框坐标
|
|
|
box_coords = det.xyxy[0].cpu().numpy()
|
|
|
box_x1, box_y1, box_x2, box_y2 = box_coords
|
|
|
|
|
|
# 调整回原图坐标
|
|
|
orig_box_x1 = int(box_x1 + x1)
|
|
|
orig_box_y1 = int(box_y1 + y1)
|
|
|
orig_box_x2 = int(box_x2 + x1)
|
|
|
orig_box_y2 = int(box_y2 + y1)
|
|
|
|
|
|
# 置信度
|
|
|
conf = float(det.conf[0].cpu().numpy())
|
|
|
|
|
|
# 类别ID
|
|
|
cls_id = int(det.cls[0].cpu().numpy())
|
|
|
orig_class_name = self.model.names[cls_id]
|
|
|
|
|
|
# 根据位置和尺寸确定部件类型 - 考虑舰船类型
|
|
|
part_name = self._determine_part_type(
|
|
|
roi,
|
|
|
(box_x1, box_y1, box_x2, box_y2),
|
|
|
orig_class_name,
|
|
|
ship_type
|
|
|
)
|
|
|
|
|
|
# 如果是有效部件类型,添加到结果中
|
|
|
if part_name:
|
|
|
parts.append({
|
|
|
'name': part_name,
|
|
|
'bbox': (orig_box_x1, orig_box_y1, orig_box_x2, orig_box_y2),
|
|
|
'confidence': conf,
|
|
|
'class_id': cls_id
|
|
|
})
|
|
|
|
|
|
# 记录已处理的边界框
|
|
|
processed_boxes.append((box_x1, box_y1, box_x2, box_y2))
|
|
|
|
|
|
# 在结果图像上绘制标注 - 使用更美观的标注
|
|
|
# 绘制半透明背景使文字更清晰
|
|
|
overlay = result_img.copy()
|
|
|
cv2.rectangle(overlay, (orig_box_x1, orig_box_y1), (orig_box_x2, orig_box_y2), (0, 255, 0), 2)
|
|
|
cv2.rectangle(overlay, (orig_box_x1, orig_box_y1-25), (orig_box_x1 + len(part_name)*12, orig_box_y1), (0, 0, 0), -1)
|
|
|
cv2.addWeighted(overlay, 0.8, result_img, 0.2, 0, result_img)
|
|
|
|
|
|
# 绘制标签文字
|
|
|
label_text = f"{part_name}: {conf:.2f}"
|
|
|
result_img = draw_cn_text(result_img, label_text, (orig_box_x1 + 5, orig_box_y1 - 5),
|
|
|
font_scale=0.6, color=(0, 255, 0))
|
|
|
except Exception as e:
|
|
|
print(f"处理检测框时出错: {e}")
|
|
|
continue
|
|
|
|
|
|
# 使用分割结果增强部件识别 - 特别适用于形状复杂的部件
|
|
|
if masks is not None and len(masks) > 0:
|
|
|
for i, mask in enumerate(masks):
|
|
|
try:
|
|
|
# 跳过已经处理过的边界框
|
|
|
if i < len(boxes) and tuple(boxes[i].xyxy[0].cpu().numpy()) in processed_boxes:
|
|
|
continue
|
|
|
|
|
|
# 提取分割掩码
|
|
|
mask_array = mask.data[0].cpu().numpy()
|
|
|
|
|
|
# 计算掩码边界框
|
|
|
mask_positions = np.where(mask_array > 0.5)
|
|
|
if len(mask_positions[0]) == 0 or len(mask_positions[1]) == 0:
|
|
|
continue
|
|
|
|
|
|
min_y, max_y = np.min(mask_positions[0]), np.max(mask_positions[0])
|
|
|
min_x, max_x = np.min(mask_positions[1]), np.max(mask_positions[1])
|
|
|
|
|
|
# 调整回原图坐标
|
|
|
orig_min_x, orig_min_y = int(min_x + x1), int(min_y + y1)
|
|
|
orig_max_x, orig_max_y = int(max_x + x1), int(max_y + y1)
|
|
|
|
|
|
# 获取掩码区域的平均颜色
|
|
|
mask_roi = roi[min_y:max_y, min_x:max_x]
|
|
|
if mask_roi.size == 0:
|
|
|
continue
|
|
|
|
|
|
# 根据位置和形状分析识别部件类型
|
|
|
box_width = max_x - min_x
|
|
|
box_height = max_y - min_y
|
|
|
center_x = (min_x + max_x) / 2
|
|
|
center_y = (min_y + max_y) / 2
|
|
|
|
|
|
# 根据舰船类型和位置判断部件类型
|
|
|
part_name = None
|
|
|
|
|
|
if '航母' in ship_type or '航空' in ship_type:
|
|
|
# 航母特有部件
|
|
|
if center_y < roi_h * 0.3 and center_x > roi_w * 0.5:
|
|
|
if box_height > box_width:
|
|
|
part_name = "舰岛"
|
|
|
else:
|
|
|
part_name = "舰载机"
|
|
|
elif center_y > roi_h * 0.4 and box_width > roi_w * 0.3:
|
|
|
part_name = "飞行甲板"
|
|
|
elif box_width < roi_w * 0.1 and box_height < roi_h * 0.1:
|
|
|
part_name = "舰载机"
|
|
|
elif '驱逐' in ship_type:
|
|
|
# 驱逐舰特有部件
|
|
|
if center_y < roi_h * 0.3 and box_width < roi_w * 0.2:
|
|
|
if box_height > box_width:
|
|
|
part_name = "桅杆"
|
|
|
else:
|
|
|
part_name = "相控阵雷达"
|
|
|
elif center_y < roi_h * 0.4 and box_width > roi_w * 0.1:
|
|
|
part_name = "舰桥"
|
|
|
elif center_x < roi_w * 0.3:
|
|
|
part_name = "主炮"
|
|
|
elif center_x > roi_w * 0.6 and center_y < roi_h * 0.5:
|
|
|
part_name = "垂发系统"
|
|
|
else:
|
|
|
# 通用部件识别
|
|
|
if center_y < roi_h * 0.3 and box_width < roi_w * 0.2:
|
|
|
if box_height > box_width * 1.5:
|
|
|
part_name = "桅杆"
|
|
|
else:
|
|
|
part_name = "雷达"
|
|
|
elif center_y < roi_h * 0.4 and box_width > roi_w * 0.1:
|
|
|
part_name = "舰桥"
|
|
|
elif center_x < roi_w * 0.3 and box_width < roi_w * 0.15:
|
|
|
part_name = "舰炮"
|
|
|
elif center_x > roi_w * 0.7 and box_width > roi_w * 0.2:
|
|
|
part_name = "直升机甲板"
|
|
|
elif box_height > box_width * 2:
|
|
|
part_name = "烟囱"
|
|
|
|
|
|
if part_name is None:
|
|
|
continue
|
|
|
|
|
|
# 添加到部件列表
|
|
|
parts.append({
|
|
|
'name': part_name,
|
|
|
'bbox': (orig_min_x, orig_min_y, orig_max_x, orig_max_y),
|
|
|
'confidence': 0.7, # 分割结果的默认置信度
|
|
|
'class_id': -1 # 使用-1表示分割结果
|
|
|
})
|
|
|
|
|
|
# 在结果图像上绘制标注
|
|
|
overlay = result_img.copy()
|
|
|
cv2.rectangle(overlay, (orig_min_x, orig_min_y), (orig_max_x, orig_max_y), (0, 255, 0), 2)
|
|
|
cv2.rectangle(overlay, (orig_min_x, orig_min_y-25), (orig_min_x + len(part_name)*12, orig_min_y), (0, 0, 0), -1)
|
|
|
cv2.addWeighted(overlay, 0.8, result_img, 0.2, 0, result_img)
|
|
|
|
|
|
label_text = f"{part_name}: 0.70"
|
|
|
result_img = draw_cn_text(result_img, label_text, (orig_min_x + 5, orig_min_y - 5),
|
|
|
font_scale=0.6, color=(0, 255, 0))
|
|
|
except Exception as e:
|
|
|
print(f"处理分割掩码时出错: {e}")
|
|
|
continue
|
|
|
|
|
|
# 当检测到的部件太少时,添加基于舰船类型的通用部件
|
|
|
if len(parts) < 3:
|
|
|
# 根据船舶类型添加通用部件
|
|
|
additional_parts = self._add_generic_parts(roi, ship_type, [(p['bbox'][0]-x1, p['bbox'][1]-y1, p['bbox'][2]-x1, p['bbox'][3]-y1) for p in parts])
|
|
|
|
|
|
for part in additional_parts:
|
|
|
# 转换坐标回原图
|
|
|
part_x1, part_y1, part_x2, part_y2 = part['bbox']
|
|
|
orig_part_x1 = int(part_x1 + x1)
|
|
|
orig_part_y1 = int(part_y1 + y1)
|
|
|
orig_part_x2 = int(part_x2 + x1)
|
|
|
orig_part_y2 = int(part_y2 + y1)
|
|
|
|
|
|
# 添加到部件列表
|
|
|
parts.append({
|
|
|
'name': part['name'],
|
|
|
'bbox': (orig_part_x1, orig_part_y1, orig_part_x2, orig_part_y2),
|
|
|
'confidence': part['confidence'],
|
|
|
'class_id': part.get('class_id', -2) # 使用-2表示启发式生成的部件
|
|
|
})
|
|
|
|
|
|
# 在结果图像上绘制标注
|
|
|
overlay = result_img.copy()
|
|
|
cv2.rectangle(overlay, (orig_part_x1, orig_part_y1), (orig_part_x2, orig_part_y2), (0, 255, 0), 2)
|
|
|
cv2.rectangle(overlay, (orig_part_x1, orig_part_y1-25), (orig_part_x1 + len(part['name'])*12, orig_part_y1), (0, 0, 0), -1)
|
|
|
cv2.addWeighted(overlay, 0.8, result_img, 0.2, 0, result_img)
|
|
|
|
|
|
label_text = f"{part['name']}: {part['confidence']:.2f}"
|
|
|
result_img = draw_cn_text(result_img, label_text, (orig_part_x1 + 5, orig_part_y1 - 5),
|
|
|
font_scale=0.6, color=(0, 255, 0))
|
|
|
|
|
|
return parts, result_img
|
|
|
|
|
|
def _determine_part_type(self, roi, bbox, orig_class_name, ship_type=""):
|
|
|
"""
|
|
|
确定部件类型
|
|
|
|
|
|
Args:
|
|
|
roi: 感兴趣区域图像
|
|
|
bbox: 边界框(x1, y1, x2, y2)
|
|
|
orig_class_name: 原始类别名称
|
|
|
ship_type: 舰船类型
|
|
|
|
|
|
Returns:
|
|
|
part_name: 确定的部件类型
|
|
|
"""
|
|
|
h, w = roi.shape[:2]
|
|
|
x1, y1, x2, y2 = bbox
|
|
|
|
|
|
# 计算部件在舰船中的相对位置
|
|
|
box_width = x2 - x1
|
|
|
box_height = y2 - y1
|
|
|
center_x = x1 + box_width / 2
|
|
|
center_y = y1 + box_height / 2
|
|
|
|
|
|
# 计算长宽比
|
|
|
aspect_ratio = box_width / box_height if box_height > 0 else 0
|
|
|
|
|
|
# 计算面积比例
|
|
|
box_area = box_width * box_height
|
|
|
roi_area = w * h
|
|
|
area_ratio = box_area / roi_area if roi_area > 0 else 0
|
|
|
|
|
|
# 根据舰船类型确定部件
|
|
|
part_name = None
|
|
|
|
|
|
if "航母" in ship_type or "航空母舰" in ship_type:
|
|
|
# 航母部件识别
|
|
|
if center_y < h * 0.3:
|
|
|
if center_x > w * 0.5:
|
|
|
if box_width > box_height and area_ratio > 0.05:
|
|
|
part_name = "舰岛"
|
|
|
elif area_ratio < 0.03:
|
|
|
part_name = "雷达"
|
|
|
elif box_height > box_width * 1.5 and area_ratio < 0.02:
|
|
|
part_name = "桅杆"
|
|
|
elif area_ratio < 0.03 and box_width > box_height:
|
|
|
part_name = "相控阵雷达"
|
|
|
elif center_y > h * 0.4 and center_y < h * 0.8 and center_x > w * 0.1 and center_x < w * 0.9:
|
|
|
if area_ratio > 0.2:
|
|
|
part_name = "飞行甲板"
|
|
|
elif box_width > box_height * 3 and area_ratio < 0.1:
|
|
|
part_name = "弹射器"
|
|
|
elif box_width > box_height and area_ratio < 0.05:
|
|
|
part_name = "舰载机"
|
|
|
elif box_height > box_width * 1.5 and area_ratio < 0.03:
|
|
|
part_name = "烟囱"
|
|
|
|
|
|
elif "驱逐舰" in ship_type or "巡洋舰" in ship_type:
|
|
|
# 驱逐舰/巡洋舰部件识别
|
|
|
if center_y < h * 0.35:
|
|
|
if box_width > w * 0.1 and area_ratio > 0.08:
|
|
|
part_name = "舰桥"
|
|
|
elif box_height > box_width * 1.5 and area_ratio < 0.02:
|
|
|
part_name = "桅杆"
|
|
|
elif area_ratio < 0.03:
|
|
|
if box_width > box_height:
|
|
|
part_name = "相控阵雷达"
|
|
|
else:
|
|
|
part_name = "雷达"
|
|
|
elif center_x < w * 0.3 and center_y < h * 0.5:
|
|
|
if box_width < w * 0.15 and box_height < h * 0.15:
|
|
|
part_name = "主炮"
|
|
|
elif center_x > w * 0.6 and center_y < h * 0.6 and box_width < w * 0.2:
|
|
|
if box_width > box_height and area_ratio < 0.05:
|
|
|
part_name = "垂发系统"
|
|
|
elif center_x > w * 0.7 and center_y > h * 0.6 and area_ratio > 0.05:
|
|
|
part_name = "直升机甲板"
|
|
|
elif box_height > box_width * 1.5 and area_ratio < 0.03:
|
|
|
part_name = "烟囱"
|
|
|
|
|
|
elif "护卫舰" in ship_type:
|
|
|
# 护卫舰部件识别
|
|
|
if center_y < h * 0.4:
|
|
|
if box_width > w * 0.1 and area_ratio > 0.08:
|
|
|
part_name = "舰桥"
|
|
|
elif box_height > box_width * 1.5 and area_ratio < 0.03:
|
|
|
part_name = "桅杆"
|
|
|
elif area_ratio < 0.03:
|
|
|
part_name = "雷达"
|
|
|
elif rel_x < 0.3 and box_width < w * 0.15:
|
|
|
part_name = "舰炮"
|
|
|
elif rel_x > 0.7 and area_ratio > 0.05:
|
|
|
part_name = "直升机甲板"
|
|
|
elif rel_x > 0.6 and area_ratio < 0.04:
|
|
|
part_name = "导弹发射器"
|
|
|
elif box_height > box_width * 1.5 and area_ratio < 0.03:
|
|
|
part_name = "烟囱"
|
|
|
|
|
|
elif "潜艇" in ship_type:
|
|
|
# 潜艇部件识别
|
|
|
if center_y < h * 0.3 and area_ratio > 0.05:
|
|
|
part_name = "塔台"
|
|
|
elif center_x < w * 0.3 and area_ratio < 0.04:
|
|
|
part_name = "鱼雷管"
|
|
|
elif center_y < h * 0.2 and area_ratio < 0.02:
|
|
|
part_name = "通信天线"
|
|
|
elif center_x > w * 0.6 and center_y < h * 0.4 and area_ratio < 0.03:
|
|
|
part_name = "潜望镜"
|
|
|
elif center_y > h * 0.6 and center_x < w * 0.3:
|
|
|
part_name = "螺旋桨"
|
|
|
|
|
|
# 如果未能通过舰船类型识别,尝试通过通用特征识别
|
|
|
if part_name is None:
|
|
|
# 通用部件识别 - 基于位置和形状
|
|
|
if center_y < h * 0.3:
|
|
|
if box_width > w * 0.2 and area_ratio > 0.08:
|
|
|
part_name = "舰桥"
|
|
|
elif box_height > box_width * 1.5 and area_ratio < 0.03:
|
|
|
part_name = "桅杆"
|
|
|
elif area_ratio < 0.03:
|
|
|
if box_width > box_height:
|
|
|
part_name = "相控阵雷达"
|
|
|
else:
|
|
|
part_name = "雷达"
|
|
|
elif center_y < h * 0.6:
|
|
|
if center_x < w * 0.25 and area_ratio < 0.05:
|
|
|
part_name = "舰炮"
|
|
|
elif box_height > box_width * 1.5 and area_ratio < 0.03:
|
|
|
part_name = "烟囱"
|
|
|
elif aspect_ratio > 1.5 and area_ratio < 0.05 and center_x > w * 0.6:
|
|
|
part_name = "导弹发射器"
|
|
|
elif center_y > h * 0.6:
|
|
|
if rel_x > 0.7 and area_ratio > 0.05:
|
|
|
part_name = "直升机甲板"
|
|
|
elif area_ratio > 0.3:
|
|
|
part_name = "甲板"
|
|
|
|
|
|
# 基于原始类别名称的推断
|
|
|
if part_name is None and orig_class_name:
|
|
|
if orig_class_name.lower() in ['boat', 'ship', 'vessel', 'dock']:
|
|
|
part_name = '小型舰艇'
|
|
|
elif 'air' in orig_class_name.lower() or 'plane' in orig_class_name.lower():
|
|
|
part_name = '舰载机'
|
|
|
elif 'tower' in orig_class_name.lower() or 'pole' in orig_class_name.lower():
|
|
|
part_name = '桅杆'
|
|
|
elif 'radar' in orig_class_name.lower() or 'antenna' in orig_class_name.lower():
|
|
|
part_name = '雷达'
|
|
|
elif 'gun' in orig_class_name.lower() or 'cannon' in orig_class_name.lower():
|
|
|
part_name = '舰炮'
|
|
|
|
|
|
# 如果面积太小,可能是噪声
|
|
|
if part_name is None and area_ratio < 0.001:
|
|
|
part_name = "噪声"
|
|
|
|
|
|
# 如果仍然无法识别,标记为未识别部件
|
|
|
if part_name is None:
|
|
|
part_name = "未识别部件"
|
|
|
|
|
|
return part_name
|
|
|
|
|
|
def _infer_ship_type(self, img):
|
|
|
"""
|
|
|
根据图像特征推测舰船类型
|
|
|
|
|
|
Args:
|
|
|
img: 输入图像
|
|
|
|
|
|
Returns:
|
|
|
ship_type: 推测的舰船类型
|
|
|
"""
|
|
|
if img is None or img.size == 0 or img.shape[0] == 0 or img.shape[1] == 0:
|
|
|
return "未知舰船"
|
|
|
|
|
|
# 获取图像尺寸
|
|
|
height, width = img.shape[:2]
|
|
|
aspect_ratio = width / height if height > 0 else 0
|
|
|
|
|
|
# 转为灰度图
|
|
|
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img
|
|
|
|
|
|
# 边缘检测
|
|
|
edges = cv2.Canny(gray, 50, 150)
|
|
|
edge_pixels = cv2.countNonZero(edges)
|
|
|
edge_density = edge_pixels / (width * height) if width * height > 0 else 0
|
|
|
|
|
|
# 水平线特征检测 - 对航母重要
|
|
|
horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (25, 1))
|
|
|
horizontal_lines = cv2.morphologyEx(edges, cv2.MORPH_OPEN, horizontal_kernel)
|
|
|
horizontal_pixels = cv2.countNonZero(horizontal_lines)
|
|
|
horizontal_ratio = horizontal_pixels / (width * height) if width * height > 0 else 0
|
|
|
|
|
|
# 垂直线特征检测 - 对驱逐舰重要
|
|
|
vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 15))
|
|
|
vertical_lines = cv2.morphologyEx(edges, cv2.MORPH_OPEN, vertical_kernel)
|
|
|
vertical_pixels = cv2.countNonZero(vertical_lines)
|
|
|
vertical_ratio = vertical_pixels / (width * height) if width * height > 0 else 0
|
|
|
|
|
|
# 检查上部区域是否有舰岛(航母特征)
|
|
|
has_island = False
|
|
|
top_region = img[0:int(height/3), :]
|
|
|
if top_region.size > 0:
|
|
|
top_gray = cv2.cvtColor(top_region, cv2.COLOR_BGR2GRAY) if len(top_region.shape) == 3 else top_region
|
|
|
_, top_thresh = cv2.threshold(top_gray, 100, 255, cv2.THRESH_BINARY)
|
|
|
top_contours, _ = cv2.findContours(top_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
|
|
|
for contour in top_contours:
|
|
|
area = cv2.contourArea(contour)
|
|
|
if area > (top_region.shape[0] * top_region.shape[1] * 0.01):
|
|
|
x, y, w, h = cv2.boundingRect(contour)
|
|
|
# 舰岛通常高于宽
|
|
|
if h > w and h > top_region.shape[0] * 0.3:
|
|
|
has_island = True
|
|
|
break
|
|
|
|
|
|
# 航空母舰特征:长宽比大,水平线丰富,有舰岛
|
|
|
if (aspect_ratio > 3.0 or horizontal_ratio > 0.05) and edge_density < 0.15:
|
|
|
if has_island or aspect_ratio > 3.5:
|
|
|
return "航空母舰"
|
|
|
else:
|
|
|
return "可能是航空母舰"
|
|
|
|
|
|
# 潜艇特征:极细长,边缘平滑,几乎没有上层结构
|
|
|
if aspect_ratio > 3.5 and edge_density < 0.1 and vertical_ratio < 0.02:
|
|
|
return "潜艇"
|
|
|
|
|
|
# 驱逐舰特征:中等长宽比,边缘复杂,有明显的上层结构
|
|
|
if 2.2 < aspect_ratio < 3.8 and edge_density > 0.1 and vertical_ratio > 0.02:
|
|
|
return "驱逐舰"
|
|
|
|
|
|
# 护卫舰特征:较小长宽比,结构适中
|
|
|
if 2.0 < aspect_ratio < 3.0 and 0.1 < edge_density < 0.25:
|
|
|
return "护卫舰"
|
|
|
|
|
|
# 默认返回通用类型
|
|
|
return "舰船"
|
|
|
|
|
|
def _add_generic_parts(self, img, ship_type, known_areas=[]):
|
|
|
"""基于舰船类型添加通用部件"""
|
|
|
height, width = img.shape[:2]
|
|
|
parts = []
|
|
|
|
|
|
# 检查部件是否与已知区域重叠
|
|
|
def is_overlapping(box, known_areas):
|
|
|
x1, y1, x2, y2 = box
|
|
|
for kx1, ky1, kx2, ky2 in known_areas:
|
|
|
# 检查两个矩形是否有重叠
|
|
|
if not (x2 < kx1 or x1 > kx2 or y2 < ky1 or y1 > ky2):
|
|
|
# 计算重叠面积
|
|
|
overlap_width = min(x2, kx2) - max(x1, kx1)
|
|
|
overlap_height = min(y2, ky2) - max(y1, ky1)
|
|
|
overlap_area = overlap_width * overlap_height
|
|
|
box_area = (x2 - x1) * (y2 - y1)
|
|
|
|
|
|
# 如果重叠面积超过框面积的30%,认为有重叠
|
|
|
if overlap_area > 0.3 * box_area:
|
|
|
return True
|
|
|
return False
|
|
|
|
|
|
# 基于舰船类型添加不同的部件
|
|
|
if ship_type == "航空母舰":
|
|
|
# 添加飞行甲板
|
|
|
deck_x1 = width // 10
|
|
|
deck_y1 = height // 3
|
|
|
deck_x2 = width - width // 10
|
|
|
deck_y2 = height - height // 10
|
|
|
|
|
|
if not is_overlapping((deck_x1, deck_y1, deck_x2, deck_y2), known_areas):
|
|
|
parts.append({
|
|
|
'name': "飞行甲板",
|
|
|
'bbox': (deck_x1, deck_y1, deck_x2, deck_y2),
|
|
|
'confidence': 0.85,
|
|
|
'class_id': 6
|
|
|
})
|
|
|
|
|
|
# 添加舰岛
|
|
|
island_width = width // 5
|
|
|
island_height = height // 3
|
|
|
island_x1 = width // 2
|
|
|
island_y1 = height // 10
|
|
|
|
|
|
if not is_overlapping((island_x1, island_y1, island_x1 + island_width, island_y1 + island_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "舰岛",
|
|
|
'bbox': (island_x1, island_y1, island_x1 + island_width, island_y1 + island_height),
|
|
|
'confidence': 0.8,
|
|
|
'class_id': 0
|
|
|
})
|
|
|
|
|
|
# 添加弹射器
|
|
|
catapult_width = width // 3
|
|
|
catapult_height = height // 20
|
|
|
catapult_x1 = width // 10
|
|
|
catapult_y1 = height // 3
|
|
|
|
|
|
if not is_overlapping((catapult_x1, catapult_y1, catapult_x1 + catapult_width, catapult_y1 + catapult_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "弹射器",
|
|
|
'bbox': (catapult_x1, catapult_y1, catapult_x1 + catapult_width, catapult_y1 + catapult_height),
|
|
|
'confidence': 0.75,
|
|
|
'class_id': 3
|
|
|
})
|
|
|
|
|
|
elif ship_type == "驱逐舰" or ship_type == "护卫舰":
|
|
|
# 添加舰桥
|
|
|
bridge_width = width // 4
|
|
|
bridge_height = height // 3
|
|
|
bridge_x1 = width // 2 - bridge_width // 2
|
|
|
bridge_y1 = height // 10
|
|
|
|
|
|
if not is_overlapping((bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "舰桥",
|
|
|
'bbox': (bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height),
|
|
|
'confidence': 0.85,
|
|
|
'class_id': 0
|
|
|
})
|
|
|
|
|
|
# 添加主炮
|
|
|
gun_width = width // 10
|
|
|
gun_height = height // 10
|
|
|
gun_x1 = width // 10
|
|
|
gun_y1 = height // 5
|
|
|
|
|
|
if not is_overlapping((gun_x1, gun_y1, gun_x1 + gun_width, gun_y1 + gun_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "主炮",
|
|
|
'bbox': (gun_x1, gun_y1, gun_x1 + gun_width, gun_y1 + gun_height),
|
|
|
'confidence': 0.8,
|
|
|
'class_id': 2
|
|
|
})
|
|
|
|
|
|
# 添加导弹垂发系统
|
|
|
vls_width = width // 8
|
|
|
vls_height = height // 8
|
|
|
vls_x1 = width // 3
|
|
|
vls_y1 = height // 3
|
|
|
|
|
|
if not is_overlapping((vls_x1, vls_y1, vls_x1 + vls_width, vls_y1 + vls_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "垂发系统",
|
|
|
'bbox': (vls_x1, vls_y1, vls_x1 + vls_width, vls_y1 + vls_height),
|
|
|
'confidence': 0.75,
|
|
|
'class_id': 3
|
|
|
})
|
|
|
|
|
|
# 添加直升机甲板
|
|
|
heli_width = width // 5
|
|
|
heli_height = height // 5
|
|
|
heli_x1 = width * 3 // 4
|
|
|
heli_y1 = height // 2
|
|
|
|
|
|
if not is_overlapping((heli_x1, heli_y1, heli_x1 + heli_width, heli_y1 + heli_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "直升机甲板",
|
|
|
'bbox': (heli_x1, heli_y1, heli_x1 + heli_width, heli_y1 + heli_height),
|
|
|
'confidence': 0.7,
|
|
|
'class_id': 4
|
|
|
})
|
|
|
|
|
|
# 添加相控阵雷达
|
|
|
radar_width = width // 12
|
|
|
radar_height = height // 8
|
|
|
radar_x1 = width // 2
|
|
|
radar_y1 = height // 12
|
|
|
|
|
|
if not is_overlapping((radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "相控阵雷达",
|
|
|
'bbox': (radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height),
|
|
|
'confidence': 0.8,
|
|
|
'class_id': 1
|
|
|
})
|
|
|
|
|
|
elif ship_type == "潜艇":
|
|
|
# 添加舰桥塔
|
|
|
sail_width = width // 6
|
|
|
sail_height = height // 2
|
|
|
sail_x1 = width // 2 - sail_width // 2
|
|
|
sail_y1 = 0
|
|
|
|
|
|
if not is_overlapping((sail_x1, sail_y1, sail_x1 + sail_width, sail_y1 + sail_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "舰桥塔",
|
|
|
'bbox': (sail_x1, sail_y1, sail_x1 + sail_width, sail_y1 + sail_height),
|
|
|
'confidence': 0.85,
|
|
|
'class_id': 0
|
|
|
})
|
|
|
|
|
|
# 添加鱼雷管
|
|
|
torpedo_width = width // 10
|
|
|
torpedo_height = height // 10
|
|
|
torpedo_x1 = width // 10
|
|
|
torpedo_y1 = height // 2
|
|
|
|
|
|
if not is_overlapping((torpedo_x1, torpedo_y1, torpedo_x1 + torpedo_width, torpedo_y1 + torpedo_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "鱼雷管",
|
|
|
'bbox': (torpedo_x1, torpedo_y1, torpedo_x1 + torpedo_width, torpedo_y1 + torpedo_height),
|
|
|
'confidence': 0.7,
|
|
|
'class_id': 3
|
|
|
})
|
|
|
|
|
|
else: # 通用舰船
|
|
|
# 添加舰桥
|
|
|
bridge_width = width // 4
|
|
|
bridge_height = height // 3
|
|
|
bridge_x1 = width // 2 - bridge_width // 2
|
|
|
bridge_y1 = height // 8
|
|
|
|
|
|
if not is_overlapping((bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "舰桥",
|
|
|
'bbox': (bridge_x1, bridge_y1, bridge_x1 + bridge_width, bridge_y1 + bridge_height),
|
|
|
'confidence': 0.75,
|
|
|
'class_id': 0
|
|
|
})
|
|
|
|
|
|
# 添加雷达
|
|
|
radar_width = width // 10
|
|
|
radar_height = width // 10 # 保持正方形
|
|
|
radar_x1 = width // 2 - radar_width // 2
|
|
|
radar_y1 = height // 20
|
|
|
|
|
|
if not is_overlapping((radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height), known_areas):
|
|
|
parts.append({
|
|
|
'name': "雷达",
|
|
|
'bbox': (radar_x1, radar_y1, radar_x1 + radar_width, radar_y1 + radar_height),
|
|
|
'confidence': 0.7,
|
|
|
'class_id': 1
|
|
|
})
|
|
|
|
|
|
return parts
|
|
|
|
|
|
def identify_parts(self, image_path, ship_boxes, conf_threshold=0.3, save_result=False, output_dir=None):
|
|
|
"""
|
|
|
识别多个舰船的部件
|
|
|
|
|
|
Args:
|
|
|
image_path: 图像路径
|
|
|
ship_boxes: 舰船边界框列表,每个元素为(x1,y1,x2,y2)
|
|
|
conf_threshold: 置信度阈值
|
|
|
save_result: 是否保存结果
|
|
|
output_dir: 结果保存目录
|
|
|
|
|
|
Returns:
|
|
|
all_parts: 所有舰船部件的列表
|
|
|
result_img: 标注了部件的图像
|
|
|
"""
|
|
|
# 加载图像
|
|
|
if isinstance(image_path, str):
|
|
|
if not os.path.exists(image_path):
|
|
|
raise FileNotFoundError(f"图像文件不存在: {image_path}")
|
|
|
img = cv2.imread(image_path)
|
|
|
else:
|
|
|
img = image_path.copy()
|
|
|
|
|
|
result_img = img.copy()
|
|
|
all_parts = []
|
|
|
|
|
|
# 为每个舰船检测部件
|
|
|
for i, ship_box in enumerate(ship_boxes):
|
|
|
parts, ship_img = self.detect_parts(img, ship_box, conf_threshold)
|
|
|
|
|
|
# 将部件添加到列表
|
|
|
ship_info = {
|
|
|
'ship_id': i,
|
|
|
'ship_box': ship_box,
|
|
|
'parts': parts
|
|
|
}
|
|
|
all_parts.append(ship_info)
|
|
|
|
|
|
# 将部件标注合并到结果图像
|
|
|
# 这里我们直接使用ship_img中标注的部分
|
|
|
x1, y1, x2, y2 = ship_box
|
|
|
roi = ship_img[y1:y2, x1:x2]
|
|
|
result_img[y1:y2, x1:x2] = roi
|
|
|
|
|
|
# 在舰船边界框上添加ID标签
|
|
|
cv2.rectangle(result_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
|
|
|
cv2.putText(result_img, f"Ship {i}", (x1, y1 - 10),
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)
|
|
|
|
|
|
# 保存结果
|
|
|
if save_result and output_dir is not None:
|
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
if isinstance(image_path, str):
|
|
|
result_filename = os.path.join(output_dir, f"parts_{os.path.basename(image_path)}")
|
|
|
else:
|
|
|
result_filename = os.path.join(output_dir, f"parts_{len(os.listdir(output_dir))}.jpg")
|
|
|
|
|
|
cv2.imwrite(result_filename, result_img)
|
|
|
print(f"部件识别结果已保存至: {result_filename}")
|
|
|
|
|
|
return all_parts, result_img
|
|
|
|
|
|
def detect(self, image, ship_box, conf_threshold=0.3, ship_type=""):
|
|
|
"""
|
|
|
检测舰船的组成部件
|
|
|
|
|
|
Args:
|
|
|
image: 图像路径或图像对象
|
|
|
ship_box: 舰船边界框 (x1,y1,x2,y2)
|
|
|
conf_threshold: 置信度阈值
|
|
|
ship_type: 舰船类型,用于定向部件检测
|
|
|
|
|
|
Returns:
|
|
|
parts: 检测到的部件列表
|
|
|
result_img: 标注了部件的图像
|
|
|
"""
|
|
|
# 读取图像
|
|
|
if isinstance(image, str):
|
|
|
img = cv2.imread(image)
|
|
|
else:
|
|
|
img = image.copy() if isinstance(image, np.ndarray) else np.array(image)
|
|
|
|
|
|
if img is None:
|
|
|
return [], np.zeros((100, 100, 3), dtype=np.uint8)
|
|
|
|
|
|
# 返回结果
|
|
|
result_img = img.copy()
|
|
|
|
|
|
# 提取舰船区域
|
|
|
x1, y1, x2, y2 = ship_box
|
|
|
# 确保索引是整数
|
|
|
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
|
|
|
|
|
# 确保边界在图像范围内
|
|
|
h, w = img.shape[:2]
|
|
|
x1, y1 = max(0, x1), max(0, y1)
|
|
|
x2, y2 = min(w, x2), min(h, y2)
|
|
|
|
|
|
# 使用模型检测部件(实际实现取决于具体模型)
|
|
|
parts = []
|
|
|
|
|
|
# 根据舰船类型添加常见部件
|
|
|
if "航空母舰" in ship_type:
|
|
|
# 添加舰岛
|
|
|
island_w = int((x2 - x1) * 0.15)
|
|
|
island_h = int((y2 - y1) * 0.3)
|
|
|
island_x = x1 + int((x2 - x1) * 0.7)
|
|
|
island_y = y1 + int((y2 - y1) * 0.1)
|
|
|
|
|
|
parts.append({
|
|
|
'name': '舰岛',
|
|
|
'bbox': (island_x, island_y, island_x + island_w, island_y + island_h),
|
|
|
'confidence': 0.8,
|
|
|
'class_id': 1
|
|
|
})
|
|
|
|
|
|
# 标注舰岛
|
|
|
cv2.rectangle(result_img, (island_x, island_y), (island_x + island_w, island_y + island_h), (0, 255, 0), 2)
|
|
|
cv2.putText(result_img, "舰岛: 0.80", (island_x, island_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
|
|
|
|
|
|
# 添加甲板
|
|
|
deck_w = int((x2 - x1) * 0.9)
|
|
|
deck_h = int((y2 - y1) * 0.5)
|
|
|
deck_x = x1 + int((x2 - x1) * 0.05)
|
|
|
deck_y = y1 + int((y2 - y1) * 0.3)
|
|
|
|
|
|
parts.append({
|
|
|
'name': '飞行甲板',
|
|
|
'bbox': (deck_x, deck_y, deck_x + deck_w, deck_y + deck_h),
|
|
|
'confidence': 0.85,
|
|
|
'class_id': 2
|
|
|
})
|
|
|
|
|
|
# 标注甲板
|
|
|
cv2.rectangle(result_img, (deck_x, deck_y), (deck_x + deck_w, deck_y + deck_h), (255, 0, 0), 2)
|
|
|
cv2.putText(result_img, "飞行甲板: 0.85", (deck_x, deck_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
|
|
|
|
|
|
elif "驱逐舰" in ship_type:
|
|
|
# 添加舰桥
|
|
|
bridge_w = int((x2 - x1) * 0.2)
|
|
|
bridge_h = int((y2 - y1) * 0.4)
|
|
|
bridge_x = x1 + int((x2 - x1) * 0.4)
|
|
|
bridge_y = y1 + int((y2 - y1) * 0.1)
|
|
|
|
|
|
parts.append({
|
|
|
'name': '舰桥',
|
|
|
'bbox': (bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h),
|
|
|
'confidence': 0.8,
|
|
|
'class_id': 3
|
|
|
})
|
|
|
|
|
|
# 标注舰桥
|
|
|
cv2.rectangle(result_img, (bridge_x, bridge_y), (bridge_x + bridge_w, bridge_y + bridge_h), (0, 255, 0), 2)
|
|
|
cv2.putText(result_img, "舰桥: 0.80", (bridge_x, bridge_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
|
|
|
|
|
|
elif "护卫舰" in ship_type:
|
|
|
# 添加舰桥
|
|
|
bridge_w = int((x2 - x1) * 0.15)
|
|
|
bridge_h = int((y2 - y1) * 0.3)
|
|
|
bridge_x = x1 + int((x2 - x1) * 0.45)
|
|
|
bridge_y = y1 + int((y2 - y1) * 0.15)
|
|
|
|
|
|
parts.append({
|
|
|
'name': '舰桥',
|
|
|
'bbox': (bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h),
|
|
|
'confidence': 0.75,
|
|
|
'class_id': 3
|
|
|
})
|
|
|
|
|
|
# 标注舰桥
|
|
|
cv2.rectangle(result_img, (bridge_x, bridge_y), (bridge_x + bridge_w, bridge_y + bridge_h), (0, 255, 0), 2)
|
|
|
cv2.putText(result_img, "舰桥: 0.75", (bridge_x, bridge_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
|
|
|
|
|
|
return parts, result_img
|
|
|
|
|
|
def _identify_part_by_position(self, rel_x, rel_y, aspect_ratio, area_ratio, ship_type):
|
|
|
"""根据相对位置和形状识别舰船部件"""
|
|
|
# 根据舰船类型和位置确定部件
|
|
|
if "航母" in ship_type or "航空母舰" in ship_type:
|
|
|
# 航母部件
|
|
|
if rel_y < 0.3:
|
|
|
if rel_x > 0.5:
|
|
|
if aspect_ratio < 1 and area_ratio > 0.05:
|
|
|
return "舰岛"
|
|
|
elif area_ratio < 0.03:
|
|
|
return "雷达"
|
|
|
elif area_ratio < 0.02 and aspect_ratio < 1:
|
|
|
return "桅杆"
|
|
|
elif area_ratio < 0.03:
|
|
|
return "相控阵雷达"
|
|
|
elif rel_y > 0.4 and rel_x > 0.1 and rel_x < 0.9:
|
|
|
return "飞行甲板"
|
|
|
elif rel_x < 0.5 and rel_y > 0.6:
|
|
|
return "弹射器"
|
|
|
elif aspect_ratio > 1 and area_ratio < 0.05 and rel_y > 0.3:
|
|
|
return "舰载机"
|
|
|
elif aspect_ratio < 1 and area_ratio < 0.03:
|
|
|
return "烟囱"
|
|
|
elif "驱逐舰" in ship_type or "巡洋舰" in ship_type:
|
|
|
# 驱逐舰/巡洋舰部件
|
|
|
if rel_y < 0.35:
|
|
|
if area_ratio > 0.1:
|
|
|
return "舰桥"
|
|
|
elif aspect_ratio < 1 and area_ratio < 0.02:
|
|
|
return "桅杆"
|
|
|
elif area_ratio < 0.03:
|
|
|
if aspect_ratio > 1.5:
|
|
|
return "相控阵雷达"
|
|
|
else:
|
|
|
return "球形雷达"
|
|
|
elif rel_x < 0.3 and rel_y < 0.5:
|
|
|
if area_ratio < 0.05:
|
|
|
return "主炮"
|
|
|
elif rel_x > 0.6 and area_ratio < 0.05:
|
|
|
return "垂发系统"
|
|
|
elif rel_y > 0.6 and rel_x > 0.7 and area_ratio > 0.05:
|
|
|
return "直升机甲板"
|
|
|
elif aspect_ratio < 1 and area_ratio < 0.02:
|
|
|
return "烟囱"
|
|
|
elif "护卫舰" in ship_type:
|
|
|
# 护卫舰部件
|
|
|
if rel_y < 0.4:
|
|
|
if area_ratio > 0.08:
|
|
|
return "舰桥"
|
|
|
elif area_ratio < 0.03 and aspect_ratio < 1:
|
|
|
return "桅杆"
|
|
|
elif area_ratio < 0.03:
|
|
|
return "雷达"
|
|
|
elif rel_x < 0.3:
|
|
|
return "舰炮"
|
|
|
elif rel_x > 0.7 and area_ratio > 0.05:
|
|
|
return "直升机甲板"
|
|
|
elif rel_x > 0.6 and area_ratio < 0.04:
|
|
|
return "导弹发射器"
|
|
|
elif "潜艇" in ship_type:
|
|
|
# 潜艇部件
|
|
|
if rel_y < 0.3 and area_ratio > 0.05:
|
|
|
return "塔台"
|
|
|
elif rel_x < 0.3 and area_ratio < 0.04:
|
|
|
return "鱼雷管"
|
|
|
elif rel_y < 0.2 and area_ratio < 0.02:
|
|
|
return "通信天线"
|
|
|
|
|
|
# 通用部件识别 - 基于位置和形状
|
|
|
if rel_y < 0.3:
|
|
|
if area_ratio > 0.08:
|
|
|
return "舰桥"
|
|
|
elif aspect_ratio < 1 and area_ratio < 0.03:
|
|
|
return "桅杆"
|
|
|
elif area_ratio < 0.03:
|
|
|
if aspect_ratio > 1:
|
|
|
return "相控阵雷达"
|
|
|
else:
|
|
|
return "雷达"
|
|
|
elif rel_y < 0.6:
|
|
|
if rel_x < 0.25 and area_ratio < 0.05:
|
|
|
return "舰炮"
|
|
|
elif aspect_ratio < 1 and area_ratio < 0.03:
|
|
|
return "烟囱"
|
|
|
elif aspect_ratio > 1.5 and area_ratio < 0.05 and rel_x > 0.6:
|
|
|
return "导弹发射器"
|
|
|
elif rel_y > 0.6:
|
|
|
if rel_x > 0.7 and area_ratio > 0.05:
|
|
|
return "直升机甲板"
|
|
|
elif area_ratio > 0.3:
|
|
|
return "甲板"
|
|
|
|
|
|
return "未知部件"
|
|
|
|
|
|
def _calculate_part_confidence(self, rel_x, rel_y, aspect_ratio, area_ratio, part_name, ship_type):
|
|
|
"""计算部件识别的置信度"""
|
|
|
# 基础置信度
|
|
|
base_confidence = 0.5
|
|
|
|
|
|
# 根据部件类型和位置调整置信度
|
|
|
if part_name == "舰桥":
|
|
|
if rel_y < 0.4 and area_ratio > 0.05:
|
|
|
return min(0.9, base_confidence + 0.3)
|
|
|
elif part_name == "雷达":
|
|
|
if rel_y < 0.3:
|
|
|
return min(0.85, base_confidence + 0.25)
|
|
|
elif part_name == "舰炮":
|
|
|
if rel_x < 0.3:
|
|
|
return min(0.8, base_confidence + 0.2)
|
|
|
elif part_name == "导弹发射器":
|
|
|
if 0.4 < rel_x < 0.8:
|
|
|
return min(0.75, base_confidence + 0.15)
|
|
|
elif part_name == "飞行甲板":
|
|
|
if "航母" in ship_type and area_ratio > 0.3:
|
|
|
return min(0.95, base_confidence + 0.4)
|
|
|
elif part_name == "舰岛":
|
|
|
if "航母" in ship_type and rel_x > 0.6 and rel_y < 0.3:
|
|
|
return min(0.9, base_confidence + 0.3)
|
|
|
elif part_name == "直升机甲板":
|
|
|
if rel_x > 0.7 and rel_y > 0.6:
|
|
|
return min(0.8, base_confidence + 0.2)
|
|
|
|
|
|
# 根据舰船类型增加特定部件的置信度
|
|
|
if "航母" in ship_type:
|
|
|
if part_name in ["舰岛", "飞行甲板", "舰载机", "弹射器"]:
|
|
|
return min(0.85, base_confidence + 0.25)
|
|
|
elif "驱逐舰" in ship_type:
|
|
|
if part_name in ["舰桥", "主炮", "垂发系统", "相控阵雷达"]:
|
|
|
return min(0.85, base_confidence + 0.25)
|
|
|
elif "护卫舰" in ship_type:
|
|
|
if part_name in ["舰桥", "舰炮", "导弹发射器"]:
|
|
|
return min(0.8, base_confidence + 0.2)
|
|
|
elif "潜艇" in ship_type:
|
|
|
if part_name in ["塔台", "鱼雷管"]:
|
|
|
return min(0.85, base_confidence + 0.25)
|
|
|
|
|
|
return base_confidence
|
|
|
|
|
|
def _calculate_iou(self, box1, box2):
|
|
|
"""计算两个边界框的IoU"""
|
|
|
# 确保所有坐标为数字
|
|
|
x1_1, y1_1, x2_1, y2_1 = box1
|
|
|
x1_2, y1_2, x2_2, y2_2 = box2
|
|
|
|
|
|
# 计算交集面积
|
|
|
x_left = max(x1_1, x1_2)
|
|
|
y_top = max(y1_1, y1_2)
|
|
|
x_right = min(x2_1, x2_2)
|
|
|
y_bottom = min(y2_1, y2_2)
|
|
|
|
|
|
# 如果没有交集,返回0
|
|
|
if x_right < x_left or y_bottom < y_top:
|
|
|
return 0.0
|
|
|
|
|
|
intersection_area = (x_right - x_left) * (y_bottom - y_top)
|
|
|
|
|
|
# 计算两个框的面积
|
|
|
box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
|
|
|
box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
|
|
|
|
|
|
# 计算IoU
|
|
|
return intersection_area / float(box1_area + box2_area - intersection_area)
|
|
|
|
|
|
def _add_specific_ship_parts(self, ship_img, ship_box, ship_type):
|
|
|
"""
|
|
|
为特定类型的舰船添加可能的部件
|
|
|
|
|
|
Args:
|
|
|
ship_img: 舰船图像
|
|
|
ship_box: 舰船边界框
|
|
|
ship_type: 舰船类型
|
|
|
|
|
|
Returns:
|
|
|
additional_parts: 额外检测到的部件列表
|
|
|
"""
|
|
|
additional_parts = []
|
|
|
h, w = ship_img.shape[:2]
|
|
|
|
|
|
# 获取目标舰船类型的特定部件
|
|
|
target_parts = self._get_target_parts(ship_type)
|
|
|
|
|
|
if "航空母舰" in ship_type:
|
|
|
# 检测舰岛
|
|
|
island_x = int(w * 0.75) # 舰岛通常在右侧四分之三处
|
|
|
island_y = int(h * 0.2) # 从顶部20%开始
|
|
|
island_w = int(w * 0.2) # 宽度约为舰船宽度的20%
|
|
|
island_h = int(h * 0.3) # 高度约为舰船高度的30%
|
|
|
|
|
|
# 验证区域
|
|
|
island_roi = ship_img[island_y:island_y+island_h, island_x:island_x+island_w]
|
|
|
if island_roi.size > 0:
|
|
|
# 分析区域特征
|
|
|
gray = cv2.cvtColor(island_roi, cv2.COLOR_BGR2GRAY)
|
|
|
edges = cv2.Canny(gray, 50, 150)
|
|
|
edge_pixels = cv2.countNonZero(edges)
|
|
|
edge_density = edge_pixels / (island_roi.shape[0] * island_roi.shape[1]) if island_roi.size > 0 else 0
|
|
|
|
|
|
# 如果区域有足够的边缘特征,识别为舰岛
|
|
|
if edge_density > 0.1 and "舰岛" in target_parts:
|
|
|
additional_parts.append({
|
|
|
"name": "舰岛",
|
|
|
"confidence": 0.85,
|
|
|
"box": [island_x, island_y, island_x + island_w, island_y + island_h],
|
|
|
"relative_box": [island_x/w, island_y/h, (island_x+island_w)/w, (island_y+island_h)/h]
|
|
|
})
|
|
|
|
|
|
# 检测飞行甲板
|
|
|
deck_x = int(w * 0.05) # 从左侧5%开始
|
|
|
deck_y = int(h * 0.1) # 从顶部10%开始
|
|
|
deck_w = int(w * 0.9) # 宽度约为舰船宽度的90%
|
|
|
deck_h = int(h * 0.2) # 高度约为舰船高度的20%
|
|
|
|
|
|
if "飞行甲板" in target_parts:
|
|
|
additional_parts.append({
|
|
|
"name": "飞行甲板",
|
|
|
"confidence": 0.9,
|
|
|
"box": [deck_x, deck_y, deck_x + deck_w, deck_y + deck_h],
|
|
|
"relative_box": [deck_x/w, deck_y/h, (deck_x+deck_w)/w, (deck_y+deck_h)/h]
|
|
|
})
|
|
|
|
|
|
# 检测舰载机(如果有)
|
|
|
# 分析顶部区域是否有舰载机特征
|
|
|
top_area = ship_img[0:int(h*0.3), :]
|
|
|
if top_area.size > 0:
|
|
|
gray = cv2.cvtColor(top_area, cv2.COLOR_BGR2GRAY)
|
|
|
blur = cv2.GaussianBlur(gray, (5, 5), 0)
|
|
|
_, thresh = cv2.threshold(blur, 120, 255, cv2.THRESH_BINARY_INV)
|
|
|
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
|
|
|
# 筛选可能的舰载机轮廓
|
|
|
aircraft_contours = []
|
|
|
for cnt in contours:
|
|
|
area = cv2.contourArea(cnt)
|
|
|
if area > 100 and area < 5000: # 舰载机大小范围
|
|
|
x, y, w_box, h_box = cv2.boundingRect(cnt)
|
|
|
aspect_ratio = w_box / h_box if h_box > 0 else 0
|
|
|
if 1.5 < aspect_ratio < 3.0: # 舰载机长宽比范围
|
|
|
aircraft_contours.append(cnt)
|
|
|
|
|
|
# 为每个可能的舰载机添加部件
|
|
|
for i, cnt in enumerate(aircraft_contours[:5]): # 最多添加5个舰载机
|
|
|
x, y, w_box, h_box = cv2.boundingRect(cnt)
|
|
|
if "舰载机" in target_parts:
|
|
|
additional_parts.append({
|
|
|
"name": "舰载机",
|
|
|
"confidence": 0.75,
|
|
|
"box": [x, y, x + w_box, y + h_box],
|
|
|
"relative_box": [x/w, y/h, (x+w_box)/w, (y+h_box)/h]
|
|
|
})
|
|
|
|
|
|
elif "驱逐舰" in ship_type or "护卫舰" in ship_type:
|
|
|
# 检测舰炮(通常在舰首)
|
|
|
gun_x = int(w * 0.05) # 从左侧5%开始(假设舰首在左侧)
|
|
|
gun_y = int(h * 0.1) # 从顶部10%开始
|
|
|
gun_w = int(w * 0.15) # 宽度约为舰船宽度的15%
|
|
|
gun_h = int(h * 0.15) # 高度约为舰船高度的15%
|
|
|
|
|
|
# 验证区域
|
|
|
gun_roi = ship_img[gun_y:gun_y+gun_h, gun_x:gun_x+gun_w]
|
|
|
if gun_roi.size > 0 and "舰炮" in target_parts:
|
|
|
# 可以使用更复杂的检测算法来验证
|
|
|
additional_parts.append({
|
|
|
"name": "舰炮",
|
|
|
"confidence": 0.8,
|
|
|
"box": [gun_x, gun_y, gun_x + gun_w, gun_y + gun_h],
|
|
|
"relative_box": [gun_x/w, gun_y/h, (gun_x+gun_w)/w, (gun_y+gun_h)/h]
|
|
|
})
|
|
|
|
|
|
# 检测舰桥(通常在中部偏后)
|
|
|
bridge_x = int(w * 0.4) # 从40%处开始
|
|
|
bridge_y = int(h * 0.05) # 从顶部5%开始
|
|
|
bridge_w = int(w * 0.2) # 宽度约为舰船宽度的20%
|
|
|
bridge_h = int(h * 0.3) # 高度约为舰船高度的30%
|
|
|
|
|
|
# 验证区域
|
|
|
bridge_roi = ship_img[bridge_y:bridge_y+bridge_h, bridge_x:bridge_x+bridge_w]
|
|
|
if bridge_roi.size > 0 and "舰桥" in target_parts:
|
|
|
# 分析区域特征
|
|
|
gray = cv2.cvtColor(bridge_roi, cv2.COLOR_BGR2GRAY)
|
|
|
edges = cv2.Canny(gray, 50, 150)
|
|
|
edge_pixels = cv2.countNonZero(edges)
|
|
|
edge_density = edge_pixels / (bridge_roi.shape[0] * bridge_roi.shape[1]) if bridge_roi.size > 0 else 0
|
|
|
|
|
|
if edge_density > 0.1:
|
|
|
additional_parts.append({
|
|
|
"name": "舰桥",
|
|
|
"confidence": 0.85,
|
|
|
"box": [bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h],
|
|
|
"relative_box": [bridge_x/w, bridge_y/h, (bridge_x+bridge_w)/w, (bridge_y+bridge_h)/h]
|
|
|
})
|
|
|
|
|
|
# 检测雷达(通常在舰桥顶部)
|
|
|
radar_x = int(w * 0.45) # 从45%处开始
|
|
|
radar_y = int(h * 0.05) # 从顶部5%开始
|
|
|
radar_w = int(w * 0.1) # 宽度约为舰船宽度的10%
|
|
|
radar_h = int(h * 0.1) # 高度约为舰船高度的10%
|
|
|
|
|
|
if "雷达" in target_parts:
|
|
|
additional_parts.append({
|
|
|
"name": "雷达",
|
|
|
"confidence": 0.7,
|
|
|
"box": [radar_x, radar_y, radar_x + radar_w, radar_y + radar_h],
|
|
|
"relative_box": [radar_x/w, radar_y/h, (radar_x+radar_w)/w, (radar_y+radar_h)/h]
|
|
|
})
|
|
|
|
|
|
# 检测垂直发射系统(通常在中前部)
|
|
|
vls_x = int(w * 0.25) # 从25%处开始
|
|
|
vls_y = int(h * 0.1) # 从顶部10%开始
|
|
|
vls_w = int(w * 0.15) # 宽度约为舰船宽度的15%
|
|
|
vls_h = int(h * 0.1) # 高度约为舰船高度的10%
|
|
|
|
|
|
if "垂直发射系统" in target_parts:
|
|
|
additional_parts.append({
|
|
|
"name": "垂直发射系统",
|
|
|
"confidence": 0.75,
|
|
|
"box": [vls_x, vls_y, vls_x + vls_w, vls_y + vls_h],
|
|
|
"relative_box": [vls_x/w, vls_y/h, (vls_x+vls_w)/w, (vls_y+vls_h)/h]
|
|
|
})
|
|
|
|
|
|
elif "潜艇" in ship_type:
|
|
|
# 检测舰塔(通常在潜艇中部上方)
|
|
|
tower_x = int(w * 0.4) # 从40%处开始
|
|
|
tower_y = int(h * 0.0) # 从顶部开始
|
|
|
tower_w = int(w * 0.2) # 宽度约为潜艇宽度的20%
|
|
|
tower_h = int(h * 0.5) # 高度约为潜艇高度的50%
|
|
|
|
|
|
if "舰塔" in target_parts:
|
|
|
additional_parts.append({
|
|
|
"name": "舰塔",
|
|
|
"confidence": 0.9,
|
|
|
"box": [tower_x, tower_y, tower_x + tower_w, tower_y + tower_h],
|
|
|
"relative_box": [tower_x/w, tower_y/h, (tower_x+tower_w)/w, (tower_y+tower_h)/h]
|
|
|
})
|
|
|
|
|
|
elif "巡洋舰" in ship_type:
|
|
|
# 检测舰桥(通常在中部)
|
|
|
bridge_x = int(w * 0.4) # 从40%处开始
|
|
|
bridge_y = int(h * 0.05) # 从顶部5%开始
|
|
|
bridge_w = int(w * 0.2) # 宽度约为舰船宽度的20%
|
|
|
bridge_h = int(h * 0.3) # 高度约为舰船高度的30%
|
|
|
|
|
|
if "舰桥" in target_parts:
|
|
|
additional_parts.append({
|
|
|
"name": "舰桥",
|
|
|
"confidence": 0.85,
|
|
|
"box": [bridge_x, bridge_y, bridge_x + bridge_w, bridge_y + bridge_h],
|
|
|
"relative_box": [bridge_x/w, bridge_y/h, (bridge_x+bridge_w)/w, (bridge_y+bridge_h)/h]
|
|
|
})
|
|
|
|
|
|
# 检测主炮(通常在前部)
|
|
|
gun_x = int(w * 0.1) # 从10%处开始
|
|
|
gun_y = int(h * 0.1) # 从顶部10%开始
|
|
|
gun_w = int(w * 0.15) # 宽度约为舰船宽度的15%
|
|
|
gun_h = int(h * 0.15) # 高度约为舰船高度的15%
|
|
|
|
|
|
if "舰炮" in target_parts:
|
|
|
additional_parts.append({
|
|
|
"name": "舰炮",
|
|
|
"confidence": 0.8,
|
|
|
"box": [gun_x, gun_y, gun_x + gun_w, gun_y + gun_h],
|
|
|
"relative_box": [gun_x/w, gun_y/h, (gun_x+gun_w)/w, (gun_y+gun_h)/h]
|
|
|
})
|
|
|
|
|
|
# 检测垂直发射系统(通常分布在前后部)
|
|
|
vls1_x = int(w * 0.2) # 前部VLS
|
|
|
vls1_y = int(h * 0.1)
|
|
|
vls1_w = int(w * 0.1)
|
|
|
vls1_h = int(h * 0.1)
|
|
|
|
|
|
vls2_x = int(w * 0.6) # 后部VLS
|
|
|
vls2_y = int(h * 0.1)
|
|
|
vls2_w = int(w * 0.1)
|
|
|
vls2_h = int(h * 0.1)
|
|
|
|
|
|
if "垂直发射系统" in target_parts:
|
|
|
additional_parts.append({
|
|
|
"name": "垂直发射系统",
|
|
|
"confidence": 0.75,
|
|
|
"box": [vls1_x, vls1_y, vls1_x + vls1_w, vls1_y + vls1_h],
|
|
|
"relative_box": [vls1_x/w, vls1_y/h, (vls1_x+vls1_w)/w, (vls1_y+vls1_h)/h]
|
|
|
})
|
|
|
|
|
|
additional_parts.append({
|
|
|
"name": "垂直发射系统",
|
|
|
"confidence": 0.75,
|
|
|
"box": [vls2_x, vls2_y, vls2_x + vls2_w, vls2_y + vls2_h],
|
|
|
"relative_box": [vls2_x/w, vls2_y/h, (vls2_x+vls2_w)/w, (vls2_y+vls2_h)/h]
|
|
|
})
|
|
|
|
|
|
return additional_parts
|
|
|
|
|
|
def _get_target_parts(self, ship_type=""):
|
|
|
"""
|
|
|
根据舰船类型获取目标部件列表
|
|
|
|
|
|
Args:
|
|
|
ship_type: 舰船类型
|
|
|
|
|
|
Returns:
|
|
|
target_parts: 该类型舰船应该检测的部件列表
|
|
|
"""
|
|
|
# 默认检测所有部件
|
|
|
if not ship_type or ship_type not in self.ship_parts_map:
|
|
|
return list(self.part_types.keys())
|
|
|
|
|
|
# 返回特定舰船类型的部件列表
|
|
|
return self.ship_parts_map.get(ship_type, [])
|
|
|
|
|
|
def _map_to_ship_part(self, cls_id, ship_type=""):
|
|
|
"""
|
|
|
将YOLO类别ID映射到舰船部件名称
|
|
|
|
|
|
Args:
|
|
|
cls_id: YOLO类别ID
|
|
|
ship_type: 舰船类型名称
|
|
|
|
|
|
Returns:
|
|
|
part_name: 舰船部件名称
|
|
|
"""
|
|
|
# COCO数据集中的常见类别
|
|
|
coco_classes = {
|
|
|
0: "person", # 人 -> 可能是船员
|
|
|
1: "bicycle", # 自行车
|
|
|
2: "car", # 汽车
|
|
|
3: "motorcycle", # 摩托车
|
|
|
4: "airplane", # 飞机 -> 可能是舰载机
|
|
|
5: "bus", # 公交车
|
|
|
6: "train", # 火车
|
|
|
7: "truck", # 卡车
|
|
|
8: "boat", # 船
|
|
|
9: "traffic light", # 交通灯
|
|
|
10: "fire hydrant", # 消防栓
|
|
|
11: "stop sign", # 停止标志
|
|
|
12: "parking meter", # 停车计时器
|
|
|
13: "bench", # 长凳
|
|
|
14: "bird", # 鸟
|
|
|
15: "cat", # 猫
|
|
|
16: "dog", # 狗
|
|
|
17: "horse", # 马
|
|
|
24: "backpack", # 背包
|
|
|
25: "umbrella", # 雨伞
|
|
|
26: "handbag", # 手提包
|
|
|
27: "tie", # 领带
|
|
|
28: "suitcase", # 手提箱
|
|
|
33: "sports ball", # 运动球
|
|
|
34: "kite", # 风筝
|
|
|
35: "baseball bat", # 棒球棒
|
|
|
36: "baseball glove", # 棒球手套
|
|
|
41: "skateboard", # 滑板
|
|
|
42: "surfboard", # 冲浪板
|
|
|
43: "tennis racket", # 网球拍
|
|
|
59: "potted plant", # 盆栽植物
|
|
|
60: "bed", # 床
|
|
|
61: "dining table", # 餐桌
|
|
|
62: "toilet", # 厕所
|
|
|
63: "tv", # 电视
|
|
|
64: "laptop", # 笔记本电脑
|
|
|
65: "mouse", # 鼠标
|
|
|
66: "remote", # 遥控器
|
|
|
67: "keyboard", # 键盘
|
|
|
68: "cell phone", # 手机
|
|
|
69: "microwave", # 微波炉
|
|
|
70: "oven", # 烤箱
|
|
|
71: "toaster", # 烤面包机
|
|
|
72: "sink", # 水槽
|
|
|
73: "refrigerator", # 冰箱
|
|
|
74: "book", # 书
|
|
|
75: "clock", # 时钟
|
|
|
76: "vase", # 花瓶
|
|
|
77: "scissors", # 剪刀
|
|
|
78: "teddy bear", # 泰迪熊
|
|
|
79: "hair drier", # 吹风机
|
|
|
80: "toothbrush" # 牙刷
|
|
|
}
|
|
|
|
|
|
# 将COCO类别映射到舰船部件
|
|
|
part_mappings = {
|
|
|
"person": "船员",
|
|
|
"airplane": "舰载机",
|
|
|
"boat": "小艇",
|
|
|
"backpack": "装备",
|
|
|
"tie": "指挥官",
|
|
|
"suitcase": "设备箱",
|
|
|
"sports ball": "雷达球",
|
|
|
"kite": "雷达天线",
|
|
|
"surfboard": "救生设备",
|
|
|
"bed": "甲板",
|
|
|
"dining table": "甲板",
|
|
|
"toilet": "舱室",
|
|
|
"tv": "显示设备",
|
|
|
"laptop": "控制设备",
|
|
|
"mouse": "控制设备",
|
|
|
"remote": "控制设备",
|
|
|
"keyboard": "控制设备",
|
|
|
"cell phone": "通信设备",
|
|
|
"clock": "导航设备",
|
|
|
"scissors": "工具"
|
|
|
}
|
|
|
|
|
|
# 获取COCO类别名称
|
|
|
if cls_id in coco_classes:
|
|
|
coco_class = coco_classes[cls_id]
|
|
|
|
|
|
# 将COCO类别映射到舰船部件
|
|
|
if coco_class in part_mappings:
|
|
|
return part_mappings[coco_class]
|
|
|
|
|
|
# 如果是航母且检测到飞机,返回舰载机
|
|
|
if ("航母" in ship_type or "航空母舰" in ship_type) and cls_id == 4:
|
|
|
return "舰载机"
|
|
|
|
|
|
# 默认返回未知
|
|
|
return "unknown"
|
|
|
|
|
|
# 测试代码
|
|
|
if __name__ == "__main__":
|
|
|
# 初始化部件检测器
|
|
|
part_detector = ShipPartDetector()
|
|
|
|
|
|
# 测试图像路径和舰船边界框
|
|
|
test_image = "path/to/test/image.jpg"
|
|
|
test_ship_box = (100, 100, 500, 400) # 示例边界框
|
|
|
|
|
|
if os.path.exists(test_image):
|
|
|
# 执行部件检测
|
|
|
parts, result_img = part_detector.detect_parts(test_image, test_ship_box)
|
|
|
|
|
|
# 打印检测结果
|
|
|
print(f"检测到 {len(parts)} 个舰船部件:")
|
|
|
for i, part in enumerate(parts):
|
|
|
print(f"部件 {i+1}: 类型={part['name']}, 置信度={part['confidence']:.4f}, 位置={part['bbox']}") |