You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Software_Architecture/distance-judgement/src/drone/utils/analyze_image.py

283 lines
11 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import cv2
import argparse
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw, ImageFont
# 添加项目根目录到Python路径
script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(script_dir)
# 检查是否可以导入高级检测器
try:
# 导入分析器和高级检测器
from scripts.ship_analyzer import ShipAnalyzer
from utils.advanced_detector import AdvancedShipDetector
ADVANCED_DETECTOR_AVAILABLE = True
except ImportError as e:
print(f"警告:无法导入高级检测器: {e}")
print("将仅使用传统分析器")
from scripts.ship_analyzer import ShipAnalyzer
ADVANCED_DETECTOR_AVAILABLE = False
def analyze_image(image_path, output_dir=None, conf_threshold=0.25, part_conf_threshold=0.3, use_advanced=True):
"""
分析图像中的舰船和部件
Args:
image_path: 图像路径
output_dir: 输出目录
conf_threshold: 检测置信度阈值
part_conf_threshold: 部件置信度阈值
use_advanced: 是否使用高级检测器
"""
print(f"开始分析图像: {image_path}")
# 检查图像是否存在
if not os.path.exists(image_path):
print(f"错误: 图像文件不存在: {image_path}")
return None
# 创建输出目录
if output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
# 根据参数选择使用高级检测器或传统分析器
if use_advanced and ADVANCED_DETECTOR_AVAILABLE:
try:
print("使用高级图像分析器...")
result_img, results = analyze_with_advanced_detector(image_path, output_dir, conf_threshold, part_conf_threshold)
except Exception as e:
print(f"高级分析器出错: {str(e)}")
print("回退到传统分析器...")
# 如果高级分析失败,回退到传统分析器
analyzer = ShipAnalyzer()
results, result_img = analyzer.analyze_image(
image_path,
conf_threshold=conf_threshold,
part_conf_threshold=part_conf_threshold,
save_result=True,
output_dir=output_dir
)
else:
# 使用传统分析器
print("使用传统图像分析器...")
analyzer = ShipAnalyzer()
results, result_img = analyzer.analyze_image(
image_path,
conf_threshold=conf_threshold,
part_conf_threshold=part_conf_threshold,
save_result=True,
output_dir=output_dir
)
# 输出分析结果
if 'ships' in results:
ships = results['ships']
print(f"\n分析完成,检测到 {len(ships)} 个舰船:")
for i, ship in enumerate(ships):
print(f"\n舰船 #{i+1}:")
print(f" 类型: {ship['class_name']}")
print(f" 置信度: {ship['class_confidence']:.2f}")
parts = ship.get('parts', [])
print(f" 检测到 {len(parts)} 个部件:")
# 显示部件信息
for j, part in enumerate(parts):
print(f" 部件 #{j+1}: {part['name']} (置信度: {part['confidence']:.2f})")
else:
# 兼容旧格式
print(f"\n分析完成,检测到 {len(results)} 个舰船:")
for i, ship in enumerate(results):
print(f"\n舰船 #{i+1}:")
print(f" 类型: {ship['class_name']}")
confidence = ship.get('class_confidence', ship.get('confidence', 0.0))
print(f" 置信度: {confidence:.2f}")
parts = ship.get('parts', [])
print(f" 检测到 {len(parts)} 个部件:")
# 显示部件信息
for j, part in enumerate(parts):
part_conf = part.get('confidence', 0.0)
print(f" 部件 #{j+1}: {part['name']} (置信度: {part_conf:.2f})")
# 保存结果图像
if output_dir is not None:
result_path = os.path.join(output_dir, f"analysis_{os.path.basename(image_path)}")
cv2.imwrite(result_path, result_img)
print(f"\n结果图像已保存至: {result_path}")
return result_img
def analyze_with_advanced_detector(image_path, output_dir=None, conf_threshold=0.25, part_conf_threshold=0.3):
"""
使用高级检测器分析图像
Args:
image_path: 图像路径
output_dir: 输出目录
conf_threshold: 检测置信度阈值
part_conf_threshold: 部件置信度阈值
Returns:
result_img: 标注了检测结果的图像
results: 检测结果字典
"""
try:
print("正在加载高级图像分析模型...")
# 初始化高级检测器
detector = AdvancedShipDetector()
except Exception as e:
print(f"高级模型加载失败: {e}")
print("将使用传统计算机视觉方法进行舰船识别")
# 创建一个基本的检测器实例,但不加载模型
detector = AdvancedShipDetector(load_models=False)
# 读取图像
img = cv2.imread(image_path)
if img is None:
raise ValueError(f"无法读取图像: {image_path}")
result_img = img.copy()
h, w = img.shape[:2]
# 使用高级检测器进行对象检测
ships = []
try:
if hasattr(detector, 'detect_ships') and callable(detector.detect_ships):
detected_ships = detector.detect_ships(img, conf_threshold)
if detected_ships and len(detected_ships) > 0:
ships = detected_ships
# 使用检测器返回的图像
if len(detected_ships) > 1 and isinstance(detected_ships[1], np.ndarray):
result_img = detected_ships[1]
ships = detected_ships[0]
else:
print("高级检测器缺少detect_ships方法使用基本识别")
except Exception as e:
print(f"高级舰船检测失败: {e}")
# 如果没有检测到舰船,使用传统方法尝试识别单个舰船
if not ships:
# 识别舰船类型
ship_type, confidence = detector.identify_ship_type(img)
print(f"高级检测器识别结果: {ship_type}, 置信度: {confidence:.2f}")
# 单个舰船的边界框 - 使用整个图像
padding = int(min(w, h) * 0.05) # 5%的边距
ship_box = (padding, padding, w-padding, h-padding)
# 创建单个舰船对象
ship = {
'id': 1,
'bbox': ship_box,
'class_name': ship_type,
'class_confidence': confidence
}
ships = [ship]
# 在图像上标注舰船信息
cv2.rectangle(result_img, (ship_box[0], ship_box[1]), (ship_box[2], ship_box[3]), (0, 0, 255), 2)
cv2.putText(result_img, f"{ship_type}: {confidence:.2f}",
(ship_box[0]+10, ship_box[1]+30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
# 为每艘舰船检测部件
processed_ships = []
for i, ship in enumerate(ships):
ship_id = i + 1
ship_box = ship.get('bbox', (0, 0, w, h))
ship_type = ship.get('class_name', '其他舰船')
ship_confidence = ship.get('class_confidence', ship.get('confidence', 0.7))
# 格式化为标准结构
ship_with_parts = {
'id': ship_id,
'bbox': ship_box,
'class_name': ship_type,
'class_confidence': ship_confidence,
'parts': []
}
# 检测舰船部件
try:
parts = detector.detect_ship_parts(img, ship_box, ship_type, part_conf_threshold)
print(f"舰船 #{ship_id} 检测到 {len(parts)} 个部件")
# 为每个部件添加所属舰船ID
for part in parts:
part['ship_id'] = ship_id
ship_with_parts['parts'].append(part)
# 标注部件
part_box = part.get('bbox', (0, 0, 0, 0))
name = part.get('name', '未知部件')
conf = part.get('confidence', 0.0)
# 绘制部件边界框
cv2.rectangle(result_img,
(int(part_box[0]), int(part_box[1])),
(int(part_box[2]), int(part_box[3])),
(0, 255, 0), 2)
# 添加部件标签
label = f"{name}: {conf:.2f}"
cv2.putText(result_img, label,
(int(part_box[0]), int(part_box[1])-5),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
except Exception as e:
print(f"部件检测失败: {e}")
processed_ships.append(ship_with_parts)
# 构建结果数据结构
results = {
'ships': processed_ships
}
# 保存结果图像
if output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
result_path = os.path.join(output_dir, f"analysis_{os.path.basename(image_path)}")
cv2.imwrite(result_path, result_img)
print(f"结果图像已保存至: {result_path}")
return result_img, results
def main():
parser = argparse.ArgumentParser(description="舰船图像分析工具")
parser.add_argument("image_path", help="需要分析的舰船图像路径")
parser.add_argument("--output", "-o", help="分析结果输出目录", default="results")
parser.add_argument("--conf", "-c", type=float, default=0.25, help="检测置信度阈值")
parser.add_argument("--part-conf", "-pc", type=float, default=0.3, help="部件检测置信度阈值")
parser.add_argument("--show", action="store_true", help="显示分析结果图像")
parser.add_argument("--traditional", action="store_true", help="使用传统分析器而非高级分析器")
args = parser.parse_args()
try:
# 分析图像
result_img = analyze_image(
args.image_path,
output_dir=args.output,
conf_threshold=args.conf,
part_conf_threshold=args.part_conf,
use_advanced=not args.traditional
)
# 显示结果图像
if args.show and result_img is not None:
cv2.imshow("分析结果", result_img)
cv2.waitKey(0)
cv2.destroyAllWindows()
except Exception as e:
print(f"分析过程中出错: {str(e)}")
if __name__ == "__main__":
main()