|
|
#!/usr/bin/env python3
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
import os
|
|
|
import sys
|
|
|
import cv2
|
|
|
import argparse
|
|
|
from pathlib import Path
|
|
|
import numpy as np
|
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
|
# 添加项目根目录到Python路径
|
|
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
sys.path.append(script_dir)
|
|
|
|
|
|
# 检查是否可以导入高级检测器
|
|
|
try:
|
|
|
# 导入分析器和高级检测器
|
|
|
from scripts.ship_analyzer import ShipAnalyzer
|
|
|
from utils.advanced_detector import AdvancedShipDetector
|
|
|
ADVANCED_DETECTOR_AVAILABLE = True
|
|
|
except ImportError as e:
|
|
|
print(f"警告:无法导入高级检测器: {e}")
|
|
|
print("将仅使用传统分析器")
|
|
|
from scripts.ship_analyzer import ShipAnalyzer
|
|
|
ADVANCED_DETECTOR_AVAILABLE = False
|
|
|
|
|
|
def analyze_image(image_path, output_dir=None, conf_threshold=0.25, part_conf_threshold=0.3, use_advanced=True):
|
|
|
"""
|
|
|
分析图像中的舰船和部件
|
|
|
|
|
|
Args:
|
|
|
image_path: 图像路径
|
|
|
output_dir: 输出目录
|
|
|
conf_threshold: 检测置信度阈值
|
|
|
part_conf_threshold: 部件置信度阈值
|
|
|
use_advanced: 是否使用高级检测器
|
|
|
"""
|
|
|
print(f"开始分析图像: {image_path}")
|
|
|
|
|
|
# 检查图像是否存在
|
|
|
if not os.path.exists(image_path):
|
|
|
print(f"错误: 图像文件不存在: {image_path}")
|
|
|
return None
|
|
|
|
|
|
# 创建输出目录
|
|
|
if output_dir is not None:
|
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
|
# 根据参数选择使用高级检测器或传统分析器
|
|
|
if use_advanced and ADVANCED_DETECTOR_AVAILABLE:
|
|
|
try:
|
|
|
print("使用高级图像分析器...")
|
|
|
result_img, results = analyze_with_advanced_detector(image_path, output_dir, conf_threshold, part_conf_threshold)
|
|
|
except Exception as e:
|
|
|
print(f"高级分析器出错: {str(e)}")
|
|
|
print("回退到传统分析器...")
|
|
|
# 如果高级分析失败,回退到传统分析器
|
|
|
analyzer = ShipAnalyzer()
|
|
|
results, result_img = analyzer.analyze_image(
|
|
|
image_path,
|
|
|
conf_threshold=conf_threshold,
|
|
|
part_conf_threshold=part_conf_threshold,
|
|
|
save_result=True,
|
|
|
output_dir=output_dir
|
|
|
)
|
|
|
else:
|
|
|
# 使用传统分析器
|
|
|
print("使用传统图像分析器...")
|
|
|
analyzer = ShipAnalyzer()
|
|
|
results, result_img = analyzer.analyze_image(
|
|
|
image_path,
|
|
|
conf_threshold=conf_threshold,
|
|
|
part_conf_threshold=part_conf_threshold,
|
|
|
save_result=True,
|
|
|
output_dir=output_dir
|
|
|
)
|
|
|
|
|
|
# 输出分析结果
|
|
|
if 'ships' in results:
|
|
|
ships = results['ships']
|
|
|
print(f"\n分析完成,检测到 {len(ships)} 个舰船:")
|
|
|
|
|
|
for i, ship in enumerate(ships):
|
|
|
print(f"\n舰船 #{i+1}:")
|
|
|
print(f" 类型: {ship['class_name']}")
|
|
|
print(f" 置信度: {ship['class_confidence']:.2f}")
|
|
|
parts = ship.get('parts', [])
|
|
|
print(f" 检测到 {len(parts)} 个部件:")
|
|
|
|
|
|
# 显示部件信息
|
|
|
for j, part in enumerate(parts):
|
|
|
print(f" 部件 #{j+1}: {part['name']} (置信度: {part['confidence']:.2f})")
|
|
|
else:
|
|
|
# 兼容旧格式
|
|
|
print(f"\n分析完成,检测到 {len(results)} 个舰船:")
|
|
|
for i, ship in enumerate(results):
|
|
|
print(f"\n舰船 #{i+1}:")
|
|
|
print(f" 类型: {ship['class_name']}")
|
|
|
confidence = ship.get('class_confidence', ship.get('confidence', 0.0))
|
|
|
print(f" 置信度: {confidence:.2f}")
|
|
|
parts = ship.get('parts', [])
|
|
|
print(f" 检测到 {len(parts)} 个部件:")
|
|
|
|
|
|
# 显示部件信息
|
|
|
for j, part in enumerate(parts):
|
|
|
part_conf = part.get('confidence', 0.0)
|
|
|
print(f" 部件 #{j+1}: {part['name']} (置信度: {part_conf:.2f})")
|
|
|
|
|
|
# 保存结果图像
|
|
|
if output_dir is not None:
|
|
|
result_path = os.path.join(output_dir, f"analysis_{os.path.basename(image_path)}")
|
|
|
cv2.imwrite(result_path, result_img)
|
|
|
print(f"\n结果图像已保存至: {result_path}")
|
|
|
|
|
|
return result_img
|
|
|
|
|
|
def analyze_with_advanced_detector(image_path, output_dir=None, conf_threshold=0.25, part_conf_threshold=0.3):
|
|
|
"""
|
|
|
使用高级检测器分析图像
|
|
|
|
|
|
Args:
|
|
|
image_path: 图像路径
|
|
|
output_dir: 输出目录
|
|
|
conf_threshold: 检测置信度阈值
|
|
|
part_conf_threshold: 部件置信度阈值
|
|
|
|
|
|
Returns:
|
|
|
result_img: 标注了检测结果的图像
|
|
|
results: 检测结果字典
|
|
|
"""
|
|
|
try:
|
|
|
print("正在加载高级图像分析模型...")
|
|
|
# 初始化高级检测器
|
|
|
detector = AdvancedShipDetector()
|
|
|
except Exception as e:
|
|
|
print(f"高级模型加载失败: {e}")
|
|
|
print("将使用传统计算机视觉方法进行舰船识别")
|
|
|
# 创建一个基本的检测器实例,但不加载模型
|
|
|
detector = AdvancedShipDetector(load_models=False)
|
|
|
|
|
|
# 读取图像
|
|
|
img = cv2.imread(image_path)
|
|
|
if img is None:
|
|
|
raise ValueError(f"无法读取图像: {image_path}")
|
|
|
|
|
|
result_img = img.copy()
|
|
|
h, w = img.shape[:2]
|
|
|
|
|
|
# 使用高级检测器进行对象检测
|
|
|
ships = []
|
|
|
try:
|
|
|
if hasattr(detector, 'detect_ships') and callable(detector.detect_ships):
|
|
|
detected_ships = detector.detect_ships(img, conf_threshold)
|
|
|
if detected_ships and len(detected_ships) > 0:
|
|
|
ships = detected_ships
|
|
|
# 使用检测器返回的图像
|
|
|
if len(detected_ships) > 1 and isinstance(detected_ships[1], np.ndarray):
|
|
|
result_img = detected_ships[1]
|
|
|
ships = detected_ships[0]
|
|
|
else:
|
|
|
print("高级检测器缺少detect_ships方法,使用基本识别")
|
|
|
except Exception as e:
|
|
|
print(f"高级舰船检测失败: {e}")
|
|
|
|
|
|
# 如果没有检测到舰船,使用传统方法尝试识别单个舰船
|
|
|
if not ships:
|
|
|
# 识别舰船类型
|
|
|
ship_type, confidence = detector.identify_ship_type(img)
|
|
|
print(f"高级检测器识别结果: {ship_type}, 置信度: {confidence:.2f}")
|
|
|
|
|
|
# 单个舰船的边界框 - 使用整个图像
|
|
|
padding = int(min(w, h) * 0.05) # 5%的边距
|
|
|
ship_box = (padding, padding, w-padding, h-padding)
|
|
|
|
|
|
# 创建单个舰船对象
|
|
|
ship = {
|
|
|
'id': 1,
|
|
|
'bbox': ship_box,
|
|
|
'class_name': ship_type,
|
|
|
'class_confidence': confidence
|
|
|
}
|
|
|
ships = [ship]
|
|
|
|
|
|
# 在图像上标注舰船信息
|
|
|
cv2.rectangle(result_img, (ship_box[0], ship_box[1]), (ship_box[2], ship_box[3]), (0, 0, 255), 2)
|
|
|
cv2.putText(result_img, f"{ship_type}: {confidence:.2f}",
|
|
|
(ship_box[0]+10, ship_box[1]+30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
|
|
|
|
|
|
# 为每艘舰船检测部件
|
|
|
processed_ships = []
|
|
|
for i, ship in enumerate(ships):
|
|
|
ship_id = i + 1
|
|
|
ship_box = ship.get('bbox', (0, 0, w, h))
|
|
|
ship_type = ship.get('class_name', '其他舰船')
|
|
|
ship_confidence = ship.get('class_confidence', ship.get('confidence', 0.7))
|
|
|
|
|
|
# 格式化为标准结构
|
|
|
ship_with_parts = {
|
|
|
'id': ship_id,
|
|
|
'bbox': ship_box,
|
|
|
'class_name': ship_type,
|
|
|
'class_confidence': ship_confidence,
|
|
|
'parts': []
|
|
|
}
|
|
|
|
|
|
# 检测舰船部件
|
|
|
try:
|
|
|
parts = detector.detect_ship_parts(img, ship_box, ship_type, part_conf_threshold)
|
|
|
print(f"舰船 #{ship_id} 检测到 {len(parts)} 个部件")
|
|
|
|
|
|
# 为每个部件添加所属舰船ID
|
|
|
for part in parts:
|
|
|
part['ship_id'] = ship_id
|
|
|
ship_with_parts['parts'].append(part)
|
|
|
|
|
|
# 标注部件
|
|
|
part_box = part.get('bbox', (0, 0, 0, 0))
|
|
|
name = part.get('name', '未知部件')
|
|
|
conf = part.get('confidence', 0.0)
|
|
|
|
|
|
# 绘制部件边界框
|
|
|
cv2.rectangle(result_img,
|
|
|
(int(part_box[0]), int(part_box[1])),
|
|
|
(int(part_box[2]), int(part_box[3])),
|
|
|
(0, 255, 0), 2)
|
|
|
|
|
|
# 添加部件标签
|
|
|
label = f"{name}: {conf:.2f}"
|
|
|
cv2.putText(result_img, label,
|
|
|
(int(part_box[0]), int(part_box[1])-5),
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
|
|
|
except Exception as e:
|
|
|
print(f"部件检测失败: {e}")
|
|
|
|
|
|
processed_ships.append(ship_with_parts)
|
|
|
|
|
|
# 构建结果数据结构
|
|
|
results = {
|
|
|
'ships': processed_ships
|
|
|
}
|
|
|
|
|
|
# 保存结果图像
|
|
|
if output_dir is not None:
|
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
result_path = os.path.join(output_dir, f"analysis_{os.path.basename(image_path)}")
|
|
|
cv2.imwrite(result_path, result_img)
|
|
|
print(f"结果图像已保存至: {result_path}")
|
|
|
|
|
|
return result_img, results
|
|
|
|
|
|
def main():
|
|
|
parser = argparse.ArgumentParser(description="舰船图像分析工具")
|
|
|
parser.add_argument("image_path", help="需要分析的舰船图像路径")
|
|
|
parser.add_argument("--output", "-o", help="分析结果输出目录", default="results")
|
|
|
parser.add_argument("--conf", "-c", type=float, default=0.25, help="检测置信度阈值")
|
|
|
parser.add_argument("--part-conf", "-pc", type=float, default=0.3, help="部件检测置信度阈值")
|
|
|
parser.add_argument("--show", action="store_true", help="显示分析结果图像")
|
|
|
parser.add_argument("--traditional", action="store_true", help="使用传统分析器而非高级分析器")
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
try:
|
|
|
# 分析图像
|
|
|
result_img = analyze_image(
|
|
|
args.image_path,
|
|
|
output_dir=args.output,
|
|
|
conf_threshold=args.conf,
|
|
|
part_conf_threshold=args.part_conf,
|
|
|
use_advanced=not args.traditional
|
|
|
)
|
|
|
|
|
|
# 显示结果图像
|
|
|
if args.show and result_img is not None:
|
|
|
cv2.imshow("分析结果", result_img)
|
|
|
cv2.waitKey(0)
|
|
|
cv2.destroyAllWindows()
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f"分析过程中出错: {str(e)}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main() |