You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

161 lines
7.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import tkinter as tk
from tkinter import filedialog
import numpy as np
import cv2
import os
import time
# Absolute path of the directory that contains this script file.
base_path = os.path.split(os.path.abspath(__file__))[0]
def select_file_and_run():
    """Ask the user for a video file and run detection/tracking on it.

    Opens a file-selection dialog restricted to ``.mp4`` and ``.avi`` files.
    If the user picks a file, it is passed to ``run_detection_and_tracking``;
    if the dialog is cancelled (empty result), nothing happens.
    """
    # Patterns are given as a tuple rather than one "*.mp4;*.avi" string:
    # the semicolon-joined form is only understood by the Windows native
    # dialog, while a sequence of patterns works on every Tk platform.
    file_path = filedialog.askopenfilename(
        filetypes=[("Video files", ("*.mp4", "*.avi"))]
    )
    if file_path:
        run_detection_and_tracking(file_path)
# COCO labels treated as animals; only detections with one of these labels
# are handed to a tracker.
_ANIMAL_LABELS = frozenset([
    "dog", "cat", "bird", "horse", "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe",
])


def _detect_objects(net, output_layers, frame, min_confidence=0.5):
    """Run one YOLO forward pass on *frame*.

    Returns three parallel lists ``(boxes, confidences, class_ids)`` where
    each box is ``[x, y, w, h]`` in pixel coordinates (top-left corner plus
    size) and only detections whose best class score exceeds
    *min_confidence* are kept.
    """
    height, width = frame.shape[:2]
    blob = cv2.dnn.blobFromImage(frame, 0.00392, (416, 416), (0, 0, 0), True, crop=False)
    net.setInput(blob)

    boxes = []
    confidences = []
    class_ids = []
    for out in net.forward(output_layers):
        for detection in out:
            scores = detection[5:]  # per-class scores follow the 5 box fields
            class_id = int(np.argmax(scores))
            confidence = float(scores[class_id])
            if confidence <= min_confidence:
                continue
            # The first four fields are scaled by the frame size to get the
            # box centre and size in pixels.
            center_x = int(detection[0] * width)
            center_y = int(detection[1] * height)
            w = int(detection[2] * width)
            h = int(detection[3] * height)
            x = int(center_x - w / 2)
            y = int(center_y - h / 2)
            boxes.append([x, y, w, h])
            confidences.append(confidence)
            class_ids.append(class_id)
    return boxes, confidences, class_ids


def run_detection_and_tracking(file_path):
    """Detect animals in a video with YOLOv3 and follow them with KCF trackers.

    Each frame is run through the YOLO network; detections that survive
    non-maximum suppression and whose label is an animal start a new KCF
    tracker unless an existing tracked box of the same class already overlaps
    them (IoU > 0.5). All trackers are then updated and their boxes drawn in a
    window named 'Tracking'. Playback stops at the end of the video, when ESC
    is pressed, or when the window is closed.

    Args:
        file_path: Path of the video file to process.

    Raises:
        OSError: If the video file cannot be opened.
    """
    # Model files live in a "yolo" directory next to this script.
    weights_path = os.path.join(base_path, 'yolo/yolov3.weights')
    config_path = os.path.join(base_path, 'yolo/yolov3.cfg')
    names_path = os.path.join(base_path, 'yolo/coco.names')

    net = cv2.dnn.readNet(weights_path, config_path)
    # Ask OpenCV for the output layer names directly; indexing
    # getLayerNames() by getUnconnectedOutLayers() depends on the return
    # shape of the latter, which differs between OpenCV versions.
    output_layers = net.getUnconnectedOutLayersNames()
    with open(names_path, "r") as f:
        classes = [line.strip() for line in f]

    cap = cv2.VideoCapture(file_path)
    try:
        if not cap.isOpened():
            raise OSError(f"Cannot open video file: {file_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        speed_up_factor = 2  # play back at twice the native frame rate
        if fps and fps > 0:
            # Never let the delay reach 0: cv2.waitKey(0) waits for a key
            # press indefinitely and would freeze playback.
            delay = max(1, int(1000 / fps / speed_up_factor))
        else:
            # Frame rate unavailable; fall back to the shortest wait.
            delay = 1

        trackers = []
        # One (x, y, w, h, class_id) entry per tracker, same order as
        # ``trackers``; holds the most recent known box of each tracker.
        tracked_objects = []

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            boxes, confidences, class_ids = _detect_objects(net, output_layers, frame)

            # NMSBoxes may return an empty tuple when nothing is kept, so
            # normalise the result to a flat integer array before iterating.
            if boxes:
                indices = cv2.dnn.NMSBoxes(boxes, confidences, 0.5, 0.4)
            else:
                indices = ()
            for i in np.asarray(indices, dtype=int).reshape(-1):
                i = int(i)
                label = str(classes[class_ids[i]])
                if label not in _ANIMAL_LABELS:
                    continue

                x, y, w, h = boxes[i]
                new_bbox = (x, y, x + w, y + h)
                # Skip detections that overlap an already tracked box of the
                # same class, so one animal does not get several trackers.
                already_tracked = any(
                    obj[4] == class_ids[i]
                    and calculate_overlap(
                        (obj[0], obj[1], obj[0] + obj[2], obj[1] + obj[3]),
                        new_bbox,
                    ) > 0.5
                    for obj in tracked_objects
                )
                if already_tracked:
                    continue

                tracker = cv2.TrackerKCF_create()
                tracker.init(frame, (x, y, w, h))
                trackers.append(tracker)
                tracked_objects.append((x, y, w, h, class_ids[i]))

            for idx, tracker in enumerate(trackers):
                success, bbox = tracker.update(frame)
                if success:
                    # Remember where the tracker is now; comparing new
                    # detections only against the initial box would start a
                    # duplicate tracker once the animal has moved away.
                    tracked_objects[idx] = (
                        int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]),
                        tracked_objects[idx][4],
                    )
                    p1 = (int(bbox[0]), int(bbox[1]))
                    p2 = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
                    cv2.rectangle(frame, p1, p2, (255, 0, 0), 2)
                else:
                    cv2.putText(frame, "Tracking failure detected", (100, 80),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)

            cv2.imshow('Tracking', frame)
            # Leave on ESC (key code 27) or when the window has been closed.
            if cv2.waitKey(delay) & 0xFF == 27:
                break
            if cv2.getWindowProperty('Tracking', cv2.WND_PROP_VISIBLE) < 1:
                break
    finally:
        # Release the capture and close windows even if a frame fails.
        cap.release()
        cv2.destroyAllWindows()
def calculate_overlap(bbox1, bbox2):
    """Return the intersection-over-union (IoU) of two bounding boxes.

    Both boxes are given as ``(x1, y1, x2, y2)`` with ``(x1, y1)`` the
    top-left and ``(x2, y2)`` the bottom-right corner. Extents are measured
    inclusively, i.e. a side from ``a`` to ``b`` has length ``b - a + 1``.

    Args:
        bbox1: First box as ``(x1, y1, x2, y2)``.
        bbox2: Second box as ``(x1, y1, x2, y2)``.

    Returns:
        The intersection area divided by the union area, as a float.
    """
    def extent(low, high):
        # Inclusive length of the span from ``low`` to ``high``.
        return high - low + 1

    # Width and height of the shared region, clamped at 0 when the boxes
    # do not overlap along that axis.
    shared_w = max(0, extent(max(bbox1[0], bbox2[0]), min(bbox1[2], bbox2[2])))
    shared_h = max(0, extent(max(bbox1[1], bbox2[1]), min(bbox1[3], bbox2[3])))
    shared_area = shared_w * shared_h

    first_area = extent(bbox1[0], bbox1[2]) * extent(bbox1[1], bbox1[3])
    second_area = extent(bbox2[0], bbox2[2]) * extent(bbox2[1], bbox2[3])

    # Union = sum of both areas minus the part counted twice.
    return shared_area / float(first_area + second_area - shared_area)