#!/usr/bin/env python3 # -*- coding: utf-8 -*- import os import sys import cv2 import argparse from pathlib import Path import numpy as np from PIL import Image, ImageDraw, ImageFont # 添加项目根目录到Python路径 script_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.append(script_dir) # 检查是否可以导入高级检测器 try: # 导入分析器和高级检测器 from scripts.ship_analyzer import ShipAnalyzer from utils.advanced_detector import AdvancedShipDetector ADVANCED_DETECTOR_AVAILABLE = True except ImportError as e: print(f"警告:无法导入高级检测器: {e}") print("将仅使用传统分析器") from scripts.ship_analyzer import ShipAnalyzer ADVANCED_DETECTOR_AVAILABLE = False def analyze_image(image_path, output_dir=None, conf_threshold=0.25, part_conf_threshold=0.3, use_advanced=True): """ 分析图像中的舰船和部件 Args: image_path: 图像路径 output_dir: 输出目录 conf_threshold: 检测置信度阈值 part_conf_threshold: 部件置信度阈值 use_advanced: 是否使用高级检测器 """ print(f"开始分析图像: {image_path}") # 检查图像是否存在 if not os.path.exists(image_path): print(f"错误: 图像文件不存在: {image_path}") return None # 创建输出目录 if output_dir is not None: os.makedirs(output_dir, exist_ok=True) # 根据参数选择使用高级检测器或传统分析器 if use_advanced and ADVANCED_DETECTOR_AVAILABLE: try: print("使用高级图像分析器...") result_img, results = analyze_with_advanced_detector(image_path, output_dir, conf_threshold, part_conf_threshold) except Exception as e: print(f"高级分析器出错: {str(e)}") print("回退到传统分析器...") # 如果高级分析失败,回退到传统分析器 analyzer = ShipAnalyzer() results, result_img = analyzer.analyze_image( image_path, conf_threshold=conf_threshold, part_conf_threshold=part_conf_threshold, save_result=True, output_dir=output_dir ) else: # 使用传统分析器 print("使用传统图像分析器...") analyzer = ShipAnalyzer() results, result_img = analyzer.analyze_image( image_path, conf_threshold=conf_threshold, part_conf_threshold=part_conf_threshold, save_result=True, output_dir=output_dir ) # 输出分析结果 if 'ships' in results: ships = results['ships'] print(f"\n分析完成,检测到 {len(ships)} 个舰船:") for i, ship in enumerate(ships): print(f"\n舰船 #{i+1}:") print(f" 类型: {ship['class_name']}") print(f" 置信度: {ship['class_confidence']:.2f}") parts = ship.get('parts', []) print(f" 检测到 {len(parts)} 个部件:") # 显示部件信息 for j, part in enumerate(parts): print(f" 部件 #{j+1}: {part['name']} (置信度: {part['confidence']:.2f})") else: # 兼容旧格式 print(f"\n分析完成,检测到 {len(results)} 个舰船:") for i, ship in enumerate(results): print(f"\n舰船 #{i+1}:") print(f" 类型: {ship['class_name']}") confidence = ship.get('class_confidence', ship.get('confidence', 0.0)) print(f" 置信度: {confidence:.2f}") parts = ship.get('parts', []) print(f" 检测到 {len(parts)} 个部件:") # 显示部件信息 for j, part in enumerate(parts): part_conf = part.get('confidence', 0.0) print(f" 部件 #{j+1}: {part['name']} (置信度: {part_conf:.2f})") # 保存结果图像 if output_dir is not None: result_path = os.path.join(output_dir, f"analysis_{os.path.basename(image_path)}") cv2.imwrite(result_path, result_img) print(f"\n结果图像已保存至: {result_path}") return result_img def analyze_with_advanced_detector(image_path, output_dir=None, conf_threshold=0.25, part_conf_threshold=0.3): """ 使用高级检测器分析图像 Args: image_path: 图像路径 output_dir: 输出目录 conf_threshold: 检测置信度阈值 part_conf_threshold: 部件置信度阈值 Returns: result_img: 标注了检测结果的图像 results: 检测结果字典 """ try: print("正在加载高级图像分析模型...") # 初始化高级检测器 detector = AdvancedShipDetector() except Exception as e: print(f"高级模型加载失败: {e}") print("将使用传统计算机视觉方法进行舰船识别") # 创建一个基本的检测器实例,但不加载模型 detector = AdvancedShipDetector(load_models=False) # 读取图像 img = cv2.imread(image_path) if img is None: raise ValueError(f"无法读取图像: {image_path}") result_img = img.copy() h, w = img.shape[:2] # 使用高级检测器进行对象检测 ships = [] try: if hasattr(detector, 'detect_ships') and callable(detector.detect_ships): detected_ships = detector.detect_ships(img, conf_threshold) if detected_ships and len(detected_ships) > 0: ships = detected_ships # 使用检测器返回的图像 if len(detected_ships) > 1 and isinstance(detected_ships[1], np.ndarray): result_img = detected_ships[1] ships = detected_ships[0] else: print("高级检测器缺少detect_ships方法,使用基本识别") except Exception as e: print(f"高级舰船检测失败: {e}") # 如果没有检测到舰船,使用传统方法尝试识别单个舰船 if not ships: # 识别舰船类型 ship_type, confidence = detector.identify_ship_type(img) print(f"高级检测器识别结果: {ship_type}, 置信度: {confidence:.2f}") # 单个舰船的边界框 - 使用整个图像 padding = int(min(w, h) * 0.05) # 5%的边距 ship_box = (padding, padding, w-padding, h-padding) # 创建单个舰船对象 ship = { 'id': 1, 'bbox': ship_box, 'class_name': ship_type, 'class_confidence': confidence } ships = [ship] # 在图像上标注舰船信息 cv2.rectangle(result_img, (ship_box[0], ship_box[1]), (ship_box[2], ship_box[3]), (0, 0, 255), 2) cv2.putText(result_img, f"{ship_type}: {confidence:.2f}", (ship_box[0]+10, ship_box[1]+30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2) # 为每艘舰船检测部件 processed_ships = [] for i, ship in enumerate(ships): ship_id = i + 1 ship_box = ship.get('bbox', (0, 0, w, h)) ship_type = ship.get('class_name', '其他舰船') ship_confidence = ship.get('class_confidence', ship.get('confidence', 0.7)) # 格式化为标准结构 ship_with_parts = { 'id': ship_id, 'bbox': ship_box, 'class_name': ship_type, 'class_confidence': ship_confidence, 'parts': [] } # 检测舰船部件 try: parts = detector.detect_ship_parts(img, ship_box, ship_type, part_conf_threshold) print(f"舰船 #{ship_id} 检测到 {len(parts)} 个部件") # 为每个部件添加所属舰船ID for part in parts: part['ship_id'] = ship_id ship_with_parts['parts'].append(part) # 标注部件 part_box = part.get('bbox', (0, 0, 0, 0)) name = part.get('name', '未知部件') conf = part.get('confidence', 0.0) # 绘制部件边界框 cv2.rectangle(result_img, (int(part_box[0]), int(part_box[1])), (int(part_box[2]), int(part_box[3])), (0, 255, 0), 2) # 添加部件标签 label = f"{name}: {conf:.2f}" cv2.putText(result_img, label, (int(part_box[0]), int(part_box[1])-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) except Exception as e: print(f"部件检测失败: {e}") processed_ships.append(ship_with_parts) # 构建结果数据结构 results = { 'ships': processed_ships } # 保存结果图像 if output_dir is not None: os.makedirs(output_dir, exist_ok=True) result_path = os.path.join(output_dir, f"analysis_{os.path.basename(image_path)}") cv2.imwrite(result_path, result_img) print(f"结果图像已保存至: {result_path}") return result_img, results def main(): parser = argparse.ArgumentParser(description="舰船图像分析工具") parser.add_argument("image_path", help="需要分析的舰船图像路径") parser.add_argument("--output", "-o", help="分析结果输出目录", default="results") parser.add_argument("--conf", "-c", type=float, default=0.25, help="检测置信度阈值") parser.add_argument("--part-conf", "-pc", type=float, default=0.3, help="部件检测置信度阈值") parser.add_argument("--show", action="store_true", help="显示分析结果图像") parser.add_argument("--traditional", action="store_true", help="使用传统分析器而非高级分析器") args = parser.parse_args() try: # 分析图像 result_img = analyze_image( args.image_path, output_dir=args.output, conf_threshold=args.conf, part_conf_threshold=args.part_conf, use_advanced=not args.traditional ) # 显示结果图像 if args.show and result_img is not None: cv2.imshow("分析结果", result_img) cv2.waitKey(0) cv2.destroyAllWindows() except Exception as e: print(f"分析过程中出错: {str(e)}") if __name__ == "__main__": main()