"""Multi-object tracking (slow path): detect people once with a
MobileNet-SSD Caffe model, then follow each detection with its own
CSRT tracker frame-by-frame.

Usage:
    python multi_object_tracking_slow.py [-p prototxt] [-m caffemodel]
                                         [-v video] [-o output] [-c conf]

If --video is not given, a Tk file dialog asks for an .mp4 file.
Press ESC in the display window to quit.
"""

import argparse
from tkinter import filedialog

import cv2
import numpy as np

from utils import FPS


def enhance_image(frame):
    """Return an enhanced copy of *frame* (equalize, sharpen, contrast, brightness).

    Kept for experimentation; the call site in the main loop is commented out.

    Args:
        frame: BGR image as a numpy array.

    Returns:
        BGR image of the same shape with the enhancement chain applied.
    """
    # Histogram-equalize the luma channel only, so colors are preserved.
    frame_yuv = cv2.cvtColor(frame, cv2.COLOR_BGR2YUV)
    frame_yuv[:, :, 0] = cv2.equalizeHist(frame_yuv[:, :, 0])
    frame = cv2.cvtColor(frame_yuv, cv2.COLOR_YUV2BGR)

    # 3x3 center-heavy sharpening kernel.
    sharpen_kernel = np.array([[-1, -1, -1],
                               [-1, 9, -1],
                               [-1, -1, -1]])
    sharpened_frame = cv2.filter2D(frame, -1, sharpen_kernel)

    # Contrast boost (alpha > 1.0 stretches intensities).
    alpha = 1.5
    enhanced_frame = cv2.convertScaleAbs(sharpened_frame, alpha=alpha, beta=0)

    # Brightness boost (beta adds a constant offset).
    beta = 30
    enhanced_frame = cv2.convertScaleAbs(enhanced_frame, alpha=1.0, beta=beta)

    return enhanced_frame


def _create_csrt_tracker():
    """Create a CSRT tracker, handling the OpenCV >= 4.5.1 move to cv2.legacy."""
    if hasattr(cv2, "TrackerCSRT_create"):
        return cv2.TrackerCSRT_create()
    return cv2.legacy.TrackerCSRT_create()


# Command-line arguments (defaults match the bundled MobileNet-SSD files).
ap = argparse.ArgumentParser()
ap.add_argument("-p", "--prototxt", default="mobilenet_ssd/MobileNetSSD_deploy.prototxt",
                help="path to Caffe 'deploy' prototxt file")
ap.add_argument("-m", "--model", default="mobilenet_ssd/MobileNetSSD_deploy.caffemodel",
                help="path to Caffe pre-trained model")
ap.add_argument("-v", "--video", default=None,
                help="path to input video file")
ap.add_argument("-o", "--output", type=str,
                help="path to optional output video file")
ap.add_argument("-c", "--confidence", type=float, default=0.3,
                help="minimum probability to filter weak detections")
args = vars(ap.parse_args())

# Class labels in MobileNet-SSD output order; index 0 is background.
CLASSES = ["background", "aeroplane", "bicycle", "bird", "boat",
           "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
           "dog", "horse", "motorbike", "person", "pottedplant", "sheep",
           "sofa", "train", "tvmonitor"]

# Load the detector network.
print("[INFO] loading model...")
net = cv2.dnn.readNetFromCaffe(args["prototxt"], args["model"])

# Open the video source: explicit path, or ask via a file dialog.
if args["video"] is None:
    video_path = filedialog.askopenfilename(filetypes=[("视频文件", "*.mp4")])
    # An empty string means the user cancelled the dialog; bail out
    # instead of handing cv2.VideoCapture a path that silently fails.
    if not video_path:
        raise SystemExit("[INFO] no video selected, exiting")
    print("[INFO] starting video stream...")
    vs = cv2.VideoCapture(video_path)
else:
    print("[INFO] starting video stream...")
    vs = cv2.VideoCapture(args["video"])
writer = None

# One tracker (and its label) per detected person.
trackers = []
labels = []
fps = FPS().start()

while True:
    # Grab the next frame; None signals end of stream.
    (_, frame) = vs.read()
    if frame is None:
        break

    # Optional enhancement pass (disabled by default).
    # frame = enhance_image(frame)

    # Resize to a fixed width, keeping the aspect ratio.
    (h, w) = frame.shape[:2]
    width = 600
    r = width / float(w)
    dim = (width, int(h * r))
    frame = cv2.resize(frame, dim, interpolation=cv2.INTER_AREA)
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Lazily create the writer once the (resized) frame size is known.
    if args["output"] is not None and writer is None:
        fourcc = cv2.VideoWriter_fourcc(*"MJPG")
        writer = cv2.VideoWriter(args["output"], fourcc, 30,
                                 (frame.shape[1], frame.shape[0]), True)

    if len(trackers) == 0:
        # No trackers yet: run the detector once and seed a tracker per person.
        (h, w) = frame.shape[:2]
        blob = cv2.dnn.blobFromImage(frame, 0.007843, (w, h), 127.5)

        net.setInput(blob)
        detections = net.forward()

        for i in np.arange(0, detections.shape[2]):
            # Keep only detections above the confidence threshold.
            confidence = detections[0, 0, i, 2]
            if confidence > args["confidence"]:
                idx = int(detections[0, 0, i, 1])
                label = CLASSES[idx]

                # Only track people.
                if CLASSES[idx] != "person":
                    continue

                # Detection coordinates are normalized; scale to pixels.
                box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
                (startX, startY, endX, endY) = box.astype("int")

                # Seed a CSRT tracker with the (x, y, w, h) bounding box.
                tracker = _create_csrt_tracker()
                tracker.init(frame, (startX, startY, endX - startX, endY - startY))

                labels.append(label)
                trackers.append(tracker)

                # Draw the initial detection.
                cv2.rectangle(frame, (startX, startY), (endX, endY), (0, 255, 0), 2)
                cv2.putText(frame, label, (startX, startY - 15),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)
    else:
        # Trackers exist: just update each one on the new frame.
        for (tracker, label) in zip(trackers, labels):
            success, box = tracker.update(frame)
            if success:
                (startX, startY, w, h) = [int(v) for v in box]
                endX = startX + w
                endY = startY + h

                cv2.rectangle(frame, (startX, startY), (endX, endY), (0, 255, 0), 2)
                cv2.putText(frame, label, (startX, startY - 15),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)

    # Optionally persist the annotated frame.
    if writer is not None:
        writer.write(frame)

    cv2.imshow("Frame", frame)
    key = cv2.waitKey(1) & 0xFF

    # ESC quits.
    if key == 27:
        break

    fps.update()

fps.stop()
print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))

if writer is not None:
    writer.release()

cv2.destroyAllWindows()
vs.release()