You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
multiObjTracking/multi_object_tracking_slow.py

176 lines
5.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import tkinter.filedialog as tkinter
import numpy as np
import argparse
import cv2
from utils import FPS
def enhance_image(frame):
    """Enhance a BGR frame: equalize luma, sharpen, then boost contrast and brightness."""
    # Histogram equalization on the luma (Y) channel only, via a YUV round-trip,
    # so colors are preserved while brightness distribution is spread out.
    yuv = cv2.cvtColor(frame, cv2.COLOR_BGR2YUV)
    yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])
    equalized = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR)

    # Sharpen with a 3x3 high-boost kernel (center 9, neighbors -1 → sums to 1).
    kernel = np.array([[-1, -1, -1],
                       [-1, 9, -1],
                       [-1, -1, -1]])
    sharpened = cv2.filter2D(equalized, -1, kernel)

    # Contrast boost: alpha > 1.0 increases contrast (1.0 would leave it unchanged).
    contrasted = cv2.convertScaleAbs(sharpened, alpha=1.5, beta=0)

    # Brightness lift: beta is added to every pixel, with saturation to [0, 255].
    return cv2.convertScaleAbs(contrasted, alpha=1.0, beta=30)
# Command-line arguments (all optional; sensible defaults for the bundled model).
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--prototxt", default="mobilenet_ssd/MobileNetSSD_deploy.prototxt",
                    help="path to Caffe 'deploy' prototxt file")
parser.add_argument("-m", "--model", default="mobilenet_ssd/MobileNetSSD_deploy.caffemodel",
                    help="path to Caffe pre-trained model")
parser.add_argument("-v", "--video", default=None,
                    help="path to input video file")
parser.add_argument("-o", "--output", type=str,
                    help="path to optional output video file")
parser.add_argument("-c", "--confidence", type=float, default=0.3,
                    help="minimum probability to filter weak detections")
args = vars(parser.parse_args())

# Class labels the MobileNet-SSD network was trained on (list index == class id).
CLASSES = ["background", "aeroplane", "bicycle", "bird", "boat",
           "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
           "dog", "horse", "motorbike", "person", "pottedplant", "sheep",
           "sofa", "train", "tvmonitor"]

# Load the serialized Caffe detection model.
print("[INFO] loading model...")
net = cv2.dnn.readNetFromCaffe(args["prototxt"], args["model"])
# Pick the video source: --video path if given, otherwise ask the user via a dialog.
if args["video"] is None:
    source = tkinter.askopenfilename(filetypes=[("视频文件", "*.mp4")])
else:
    source = args["video"]
print("[INFO] starting video stream...")
vs = cv2.VideoCapture(source)

# Output writer is created lazily on the first frame, only when --output was given.
writer = None

# One CSRT tracker (and its class label) per detected person, filled on first detection.
trackers = []
labels = []

fps = FPS().start()
# Detect-then-track main loop: run the SSD only until people are found,
# then hand each detection to a CSRT tracker for the remaining frames.
while True:
    # Grab the next frame; a None frame means the stream ended.
    (grabbed, frame) = vs.read()
    if frame is None:
        break

    # Optional image enhancement (disabled by default).
    # frame = enhance_image(frame)

    # Resize to a fixed 600-px width, preserving aspect ratio, then refresh (h, w)
    # so every later computation uses the resized dimensions.
    (h, w) = frame.shape[:2]
    width = 600
    r = width / float(w)
    frame = cv2.resize(frame, (width, int(h * r)), interpolation=cv2.INTER_AREA)
    (h, w) = frame.shape[:2]

    # Lazily create the output writer once the frame size is known.
    if args["output"] is not None and writer is None:
        fourcc = cv2.VideoWriter_fourcc(*"MJPG")
        writer = cv2.VideoWriter(args["output"], fourcc, 30, (w, h), True)

    if len(trackers) == 0:
        # No trackers yet: run the SSD detector on this frame.
        blob = cv2.dnn.blobFromImage(frame, 0.007843, (w, h), 127.5)
        net.setInput(blob)
        detections = net.forward()

        for i in np.arange(0, detections.shape[2]):
            # Discard weak detections below the confidence threshold.
            confidence = detections[0, 0, i, 2]
            if confidence <= args["confidence"]:
                continue
            idx = int(detections[0, 0, i, 1])
            label = CLASSES[idx]
            # Track people only.
            if label != "person":
                continue

            # Box is normalized [0, 1]; scale to pixel coords of the resized frame.
            box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
            (startX, startY, endX, endY) = box.astype("int")

            # Start a CSRT tracker on this detection (expects x, y, width, height).
            tracker = cv2.TrackerCSRT_create()
            tracker.init(frame, (startX, startY, endX - startX, endY - startY))
            labels.append(label)
            trackers.append(tracker)

            cv2.rectangle(frame, (startX, startY), (endX, endY), (0, 255, 0), 2)
            cv2.putText(frame, label, (startX, startY - 15),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)
    else:
        # Trackers exist: just update each one on the new frame.
        for (tracker, label) in zip(trackers, labels):
            success, box = tracker.update(frame)
            if not success:
                continue
            (x, y, bw, bh) = [int(v) for v in box]
            cv2.rectangle(frame, (x, y), (x + bw, y + bh), (0, 255, 0), 2)
            cv2.putText(frame, label, (x, y - 15),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)

    # Persist the annotated frame when an output path was requested.
    if writer is not None:
        writer.write(frame)

    cv2.imshow("Frame", frame)
    # ESC (27) quits.
    if cv2.waitKey(1) & 0xFF == 27:
        break

    fps.update()

fps.stop()
print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))

if writer is not None:
    writer.release()
cv2.destroyAllWindows()
vs.release()