parent
							
								
									8669425c6e
								
							
						
					
					
						commit
						8b3332bea7
					
				| @ -0,0 +1,149 @@ | |||||||
|  | import tkinter as tk | ||||||
|  | from tkinter import filedialog | ||||||
|  | import numpy as np | ||||||
|  | import cv2 | ||||||
|  | import os | ||||||
|  | import time | ||||||
|  | 
 | ||||||
|  | # 获取当前脚本文件的目录 | ||||||
|  | base_path = os.path.dirname(os.path.abspath(__file__)) | ||||||
|  | 
 | ||||||
|  | def select_file_and_run(): | ||||||
|  |     file_path = filedialog.askopenfilename(filetypes=[("Video files", "*.mp4;*.avi")]) | ||||||
|  |     if file_path: | ||||||
|  |         run_detection_and_tracking(file_path) | ||||||
|  | 
 | ||||||
|  | def run_detection_and_tracking(file_path): | ||||||
|  |     # 加载YOLO模型 | ||||||
|  |     model_path1 = os.path.join(base_path, 'yolo/yolov3.weights') | ||||||
|  |     model_path2 = os.path.join(base_path, 'yolo/yolov3.cfg') | ||||||
|  |     model_path3 = os.path.join(base_path, 'yolo/coco.names') | ||||||
|  |     net = cv2.dnn.readNet(model_path1, model_path2) | ||||||
|  |     layer_names = net.getLayerNames() | ||||||
|  |     output_layers = [layer_names[i - 1] for i in net.getUnconnectedOutLayers()] | ||||||
|  |     classes = [] | ||||||
|  |     with open(model_path3, "r") as f: | ||||||
|  |         classes = [line.strip() for line in f.readlines()] | ||||||
|  | 
 | ||||||
|  |     # 初始化视频捕捉 | ||||||
|  |     cap = cv2.VideoCapture(file_path) | ||||||
|  |     fps = cap.get(cv2.CAP_PROP_FPS) | ||||||
|  |     speed_up_factor = 2  # 加速倍数 | ||||||
|  |     delay = int(1000 / fps / speed_up_factor)  # 计算加速后每帧之间的时间间隔,以毫秒为单位 | ||||||
|  | 
 | ||||||
|  |     # 创建多目标追踪器 | ||||||
|  |     trackers = [] | ||||||
|  | 
 | ||||||
|  |     # 创建用于存储已追踪对象信息的列表 | ||||||
|  |     tracked_objects = [] | ||||||
|  | 
 | ||||||
|  |     while True: | ||||||
|  |         ret, frame = cap.read() | ||||||
|  |         if not ret: | ||||||
|  |             break | ||||||
|  | 
 | ||||||
|  |         start_time = time.time() | ||||||
|  | 
 | ||||||
|  |         # 检测目标 | ||||||
|  |         height, width, channels = frame.shape | ||||||
|  |         blob = cv2.dnn.blobFromImage(frame, 0.00392, (416, 416), (0, 0, 0), True, crop=False) | ||||||
|  |         net.setInput(blob) | ||||||
|  |         outs = net.forward(output_layers) | ||||||
|  | 
 | ||||||
|  |         # 获取检测结果 | ||||||
|  |         new_objects = []  # 用于存储当前帧检测到的新对象 | ||||||
|  | 
 | ||||||
|  |         class_ids = [] | ||||||
|  |         confidences = [] | ||||||
|  |         boxes = [] | ||||||
|  |         for out in outs: | ||||||
|  |             for detection in out: | ||||||
|  |                 scores = detection[5:] | ||||||
|  |                 class_id = np.argmax(scores) | ||||||
|  |                 confidence = scores[class_id] | ||||||
|  |                 if confidence > 0.5: | ||||||
|  |                     # 目标检测 | ||||||
|  |                     center_x = int(detection[0] * width) | ||||||
|  |                     center_y = int(detection[1] * height) | ||||||
|  |                     w = int(detection[2] * width) | ||||||
|  |                     h = int(detection[3] * height) | ||||||
|  |                     x = int(center_x - w / 2) | ||||||
|  |                     y = int(center_y - h / 2) | ||||||
|  |                     boxes.append([x, y, w, h]) | ||||||
|  |                     confidences.append(float(confidence)) | ||||||
|  |                     class_ids.append(class_id) | ||||||
|  |                     new_objects.append((x, y, w, h, class_id))  # 存储新检测到的对象信息 | ||||||
|  | 
 | ||||||
|  |         # 非最大值抑制 | ||||||
|  |         indices = cv2.dnn.NMSBoxes(boxes, confidences, 0.5, 0.4) | ||||||
|  | 
 | ||||||
|  |         # 更新追踪器或添加新的追踪器 | ||||||
|  |         for i in indices.flatten(): | ||||||
|  |             x, y, w, h = boxes[i] | ||||||
|  |             label = str(classes[class_ids[i]]) | ||||||
|  |             if label in ["dog", "cat", "bird"]:  # 只选择动物目标 | ||||||
|  |                 # 检查是否已经有相似的追踪器在追踪相同类型的对象 | ||||||
|  |                 found_similar = False | ||||||
|  |                 for tracked_object in tracked_objects: | ||||||
|  |                     if tracked_object[4] == class_ids[i]:  # 检查类别是否相同 | ||||||
|  |                         # 计算当前检测到的对象与已有追踪器的距离或重叠度 | ||||||
|  |                         existing_bbox = (tracked_object[0], tracked_object[1], tracked_object[0] + tracked_object[2], | ||||||
|  |                                          tracked_object[1] + tracked_object[3]) | ||||||
|  |                         new_bbox = (x, y, x + w, y + h) | ||||||
|  |                         overlap_area = calculate_overlap(existing_bbox, new_bbox) | ||||||
|  |                         if overlap_area > 0.5:  # 如果重叠度超过阈值,认为是同一个对象,不再重复追踪 | ||||||
|  |                             found_similar = True | ||||||
|  |                             break | ||||||
|  |                 if not found_similar: | ||||||
|  |                     tracker = cv2.TrackerKCF_create() | ||||||
|  |                     trackers.append(tracker) | ||||||
|  |                     trackers[-1].init(frame, (x, y, w, h)) | ||||||
|  |                     tracked_objects.append((x, y, w, h, class_ids[i]))  # 添加到已追踪对象列表 | ||||||
|  | 
 | ||||||
|  |         # 绘制追踪框 | ||||||
|  |         for tracker in trackers: | ||||||
|  |             success, bbox = tracker.update(frame) | ||||||
|  |             if success: | ||||||
|  |                 # 画出追踪框 | ||||||
|  |                 p1 = (int(bbox[0]), int(bbox[1])) | ||||||
|  |                 p2 = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])) | ||||||
|  |                 cv2.rectangle(frame, p1, p2, (255, 0, 0), 2) | ||||||
|  |             else: | ||||||
|  |                 # 追踪失败 | ||||||
|  |                 cv2.putText(frame, "Tracking failure detected", (100, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), | ||||||
|  |                             2) | ||||||
|  | 
 | ||||||
|  |         # 显示结果 | ||||||
|  |         cv2.imshow('Tracking', frame) | ||||||
|  | 
 | ||||||
|  |         # 按下ESC键退出或关闭窗口退出 | ||||||
|  |         if cv2.waitKey(delay) & 0xFF == 27: | ||||||
|  |             break | ||||||
|  |         if cv2.getWindowProperty('Tracking', cv2.WND_PROP_VISIBLE) < 1: | ||||||
|  |             break | ||||||
|  | 
 | ||||||
|  |     cap.release() | ||||||
|  |     cv2.destroyAllWindows() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def calculate_overlap(bbox1, bbox2): | ||||||
|  |     # bbox1 和 bbox2 分别是 (x1, y1, x2, y2) 格式的边界框坐标 | ||||||
|  |     # 其中 (x1, y1) 是左上角坐标,(x2, y2) 是右下角坐标 | ||||||
|  | 
 | ||||||
|  |     # 计算交集部分的坐标 | ||||||
|  |     inter_x1 = max(bbox1[0], bbox2[0]) | ||||||
|  |     inter_y1 = max(bbox1[1], bbox2[1]) | ||||||
|  |     inter_x2 = min(bbox1[2], bbox2[2]) | ||||||
|  |     inter_y2 = min(bbox1[3], bbox2[3]) | ||||||
|  | 
 | ||||||
|  |     # 计算交集区域的面积 | ||||||
|  |     inter_area = max(0, inter_x2 - inter_x1 + 1) * max(0, inter_y2 - inter_y1 + 1) | ||||||
|  | 
 | ||||||
|  |     # 计算各自的区域面积 | ||||||
|  |     area_bbox1 = (bbox1[2] - bbox1[0] + 1) * (bbox1[3] - bbox1[1] + 1) | ||||||
|  |     area_bbox2 = (bbox2[2] - bbox2[0] + 1) * (bbox2[3] - bbox2[1] + 1) | ||||||
|  | 
 | ||||||
|  |     # 计算并返回重叠区域的IoU | ||||||
|  |     iou = inter_area / float(area_bbox1 + area_bbox2 - inter_area) | ||||||
|  | 
 | ||||||
|  |     return iou | ||||||
					Loading…
					
					
				
		Reference in new issue