You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
150 lines
5.9 KiB
150 lines
5.9 KiB
5 months ago
|
import tkinter as tk
|
||
|
from tkinter import filedialog
|
||
|
import numpy as np
|
||
|
import cv2
|
||
|
import os
|
||
|
import time
|
||
|
|
||
|
# Directory that contains this script file
|
||
|
base_path = os.path.dirname(os.path.abspath(__file__))
|
||
|
|
||
|
def select_file_and_run():
    """Ask the user to pick a video file and run detection/tracking on it.

    Nothing happens when the dialog returns a falsy value (for example
    when the user cancels it).
    """
    # The patterns are passed as a tuple. A single "*.mp4;*.avi" string is a
    # Windows-style spelling; on other platforms Tk treats it as one literal
    # glob pattern, so no video file would be listed.
    file_path = filedialog.askopenfilename(
        filetypes=[("Video files", ("*.mp4", "*.avi"))]
    )
    if file_path:
        run_detection_and_tracking(file_path)
|
||
|
|
||
|
def _detect_boxes(net, output_layers, frame, min_confidence=0.5):
    """Run one forward pass of ``net`` on ``frame`` and collect confident boxes.

    Returns three parallel lists: boxes as ``[x, y, w, h]`` in pixels,
    confidences as floats, and the class id of each box.
    """
    height, width = frame.shape[:2]
    # 0.00392 ~= 1/255; the positional True is blobFromImage's swapRB flag.
    blob = cv2.dnn.blobFromImage(frame, 0.00392, (416, 416), (0, 0, 0), True, crop=False)
    net.setInput(blob)
    outs = net.forward(output_layers)

    boxes = []
    confidences = []
    class_ids = []
    for out in outs:
        for detection in out:
            # Values from index 5 onward are used as per-class scores.
            scores = detection[5:]
            class_id = np.argmax(scores)
            confidence = scores[class_id]
            if confidence > min_confidence:
                # The first four values are the box centre and size as
                # fractions of the frame; convert to a top-left pixel box.
                center_x = int(detection[0] * width)
                center_y = int(detection[1] * height)
                w = int(detection[2] * width)
                h = int(detection[3] * height)
                x = int(center_x - w / 2)
                y = int(center_y - h / 2)
                boxes.append([x, y, w, h])
                confidences.append(float(confidence))
                class_ids.append(class_id)
    return boxes, confidences, class_ids


def _is_already_tracked(tracked_objects, box, class_id, min_iou=0.5):
    """Return True if a tracked object of the same class overlaps ``box``.

    ``tracked_objects`` holds ``(x, y, w, h, class_id)`` tuples and ``box`` is
    ``(x, y, w, h)``; overlap is measured with ``calculate_overlap``.
    """
    x, y, w, h = box
    candidate = (x, y, x + w, y + h)
    for tx, ty, tw, th, tracked_class in tracked_objects:
        if tracked_class != class_id:
            continue
        if calculate_overlap((tx, ty, tx + tw, ty + th), candidate) > min_iou:
            return True
    return False


def run_detection_and_tracking(file_path):
    """Detect animals in a video with YOLO and follow them with KCF trackers.

    Every frame of ``file_path`` is passed through the YOLO network loaded
    from the ``yolo`` directory next to this script. Detections labelled
    "dog", "cat" or "bird" that do not overlap an already tracked object of
    the same class get a new KCF tracker. All tracker boxes are drawn on the
    frame, which is shown in a window named "Tracking". Playback stops at the
    end of the video, when ESC is pressed, or when the window is closed.

    Args:
        file_path: Path of the video file to open with ``cv2.VideoCapture``.
    """
    target_labels = {"dog", "cat", "bird"}  # only animal classes are tracked
    speed_up_factor = 2  # playback speed multiplier
    fallback_fps = 30.0  # used when the container reports no usable FPS

    # Load the YOLO model and its class names.
    weights_path = os.path.join(base_path, 'yolo/yolov3.weights')
    config_path = os.path.join(base_path, 'yolo/yolov3.cfg')
    names_path = os.path.join(base_path, 'yolo/coco.names')
    net = cv2.dnn.readNet(weights_path, config_path)
    # Ask for the output layer names directly: indexing the result of
    # getUnconnectedOutLayers() depends on its shape, which differs between
    # OpenCV versions.
    output_layers = net.getUnconnectedOutLayersNames()
    with open(names_path, "r") as f:
        classes = [line.strip() for line in f]

    # Open the video.
    cap = cv2.VideoCapture(file_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # A missing FPS would otherwise cause a division by zero.
            fps = fallback_fps
        # Delay between frames in milliseconds. It must be at least 1 because
        # waitKey(0) waits for a key press indefinitely.
        delay = max(1, int(1000 / fps / speed_up_factor))

        # Parallel lists: trackers[k] follows the object in tracked_objects[k].
        trackers = []
        tracked_objects = []

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            boxes, confidences, class_ids = _detect_boxes(net, output_layers, frame)

            # Non-maximum suppression. The result may be an empty tuple when
            # nothing is kept, so normalise it to a flat integer array.
            indices = cv2.dnn.NMSBoxes(boxes, confidences, 0.5, 0.4) if boxes else ()
            for i in np.asarray(indices, dtype=int).reshape(-1):
                class_id = class_ids[i]
                if str(classes[class_id]) not in target_labels:
                    continue
                x, y, w, h = boxes[i]
                # Skip detections that match an object we already follow.
                if _is_already_tracked(tracked_objects, (x, y, w, h), class_id):
                    continue
                tracker = cv2.TrackerKCF_create()
                tracker.init(frame, (x, y, w, h))
                trackers.append(tracker)
                tracked_objects.append((x, y, w, h, class_id))

            # Update every tracker and draw its box.
            for idx, tracker in enumerate(trackers):
                success, bbox = tracker.update(frame)
                if success:
                    p1 = (int(bbox[0]), int(bbox[1]))
                    p2 = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
                    cv2.rectangle(frame, p1, p2, (255, 0, 0), 2)
                    # Remember the current position so that later detections
                    # of the same, moved object are still recognised as
                    # duplicates instead of spawning another tracker.
                    tracked_objects[idx] = (
                        int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]),
                        tracked_objects[idx][4],
                    )
                else:
                    # Tracking failed for this tracker on this frame.
                    cv2.putText(frame, "Tracking failure detected", (100, 80),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)

            # Show the annotated frame.
            cv2.imshow('Tracking', frame)

            # Exit on ESC or when the window has been closed.
            if cv2.waitKey(delay) & 0xFF == 27:
                break
            if cv2.getWindowProperty('Tracking', cv2.WND_PROP_VISIBLE) < 1:
                break
    finally:
        # Release the capture and windows even if processing raised.
        cap.release()
        cv2.destroyAllWindows()
|
||
|
|
||
|
|
||
|
def calculate_overlap(bbox1, bbox2):
    """Return the intersection-over-union (IoU) of two bounding boxes.

    Each box is given as ``(x1, y1, x2, y2)``, where ``(x1, y1)`` is the
    top-left corner and ``(x2, y2)`` the bottom-right corner. Side lengths
    are counted inclusively, i.e. as ``x2 - x1 + 1``.
    """

    def _extent(low, high):
        # Inclusive length of the interval [low, high].
        return high - low + 1

    # Width and height of the intersection, clamped at zero when disjoint.
    shared_w = max(0, _extent(max(bbox1[0], bbox2[0]), min(bbox1[2], bbox2[2])))
    shared_h = max(0, _extent(max(bbox1[1], bbox2[1]), min(bbox1[3], bbox2[3])))
    shared_area = shared_w * shared_h

    # Area of each box on its own.
    first_area = _extent(bbox1[0], bbox1[2]) * _extent(bbox1[1], bbox1[3])
    second_area = _extent(bbox2[0], bbox2[2]) * _extent(bbox2[1], bbox2[3])

    # IoU = intersection / union.
    union_area = first_area + second_area - shared_area
    return shared_area / float(union_area)
|