"""Person detection and tracking in a video.

Detects people with a MobileNet-SSD Caffe model, then follows each
detection with a CSRT tracker on subsequent frames.  Press ESC to quit.
Optionally writes the annotated frames to an output video file.
"""

import argparse
import tkinter as tk
from tkinter import filedialog

import cv2
import numpy as np

from utils import FPS

# Class labels the MobileNet-SSD model was trained on; the class index
# returned by the network indexes into this list.
CLASSES = ["background", "aeroplane", "bicycle", "bird", "boat",
           "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
           "dog", "horse", "motorbike", "person", "pottedplant", "sheep",
           "sofa", "train", "tvmonitor"]


def enhance_image(frame):
    """Return a sharpened, contrast/brightness-boosted copy of *frame*.

    Pipeline: histogram-equalize the luma channel, apply a 3x3 sharpen
    kernel, then stretch contrast (alpha=1.5) and raise brightness
    (beta=30).  NOTE: currently unused — the call site in the main loop
    is commented out.
    """
    # Equalize only the Y (luma) channel so colors keep their hue.
    frame_yuv = cv2.cvtColor(frame, cv2.COLOR_BGR2YUV)
    frame_yuv[:, :, 0] = cv2.equalizeHist(frame_yuv[:, :, 0])
    frame = cv2.cvtColor(frame_yuv, cv2.COLOR_YUV2BGR)

    # Sharpen with a standard 3x3 kernel (center 9, neighbors -1).
    sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
    sharpened = cv2.filter2D(frame, -1, sharpen_kernel)

    # Contrast enhancement: alpha > 1 stretches intensities (1.0 = unchanged).
    enhanced = cv2.convertScaleAbs(sharpened, alpha=1.5, beta=0)
    # Brightness adjustment: additive offset per pixel.
    enhanced = cv2.convertScaleAbs(enhanced, alpha=1.0, beta=30)
    return enhanced


def _parse_args():
    """Build the CLI parser and return the parsed arguments as a dict."""
    ap = argparse.ArgumentParser()
    ap.add_argument("-p", "--prototxt",
                    default="mobilenet_ssd/MobileNetSSD_deploy.prototxt",
                    help="path to Caffe 'deploy' prototxt file")
    ap.add_argument("-m", "--model",
                    default="mobilenet_ssd/MobileNetSSD_deploy.caffemodel",
                    help="path to Caffe pre-trained model")
    ap.add_argument("-v", "--video", default=None,
                    help="path to input video file")
    ap.add_argument("-o", "--output", type=str,
                    help="path to optional output video file")
    ap.add_argument("-c", "--confidence", type=float, default=0.3,
                    help="minimum probability to filter weak detections")
    return vars(ap.parse_args())


def _open_video(args):
    """Open the input video, prompting with a file dialog when --video is unset.

    Returns a cv2.VideoCapture, or None when no usable path was chosen.
    """
    if args["video"] is None:
        # Hide the blank root window that tkinter would otherwise show
        # next to the file dialog.
        root = tk.Tk()
        root.withdraw()
        video_path = filedialog.askopenfilename(filetypes=[("视频文件", "*.mp4")])
        root.destroy()
        if not video_path:
            # User cancelled the dialog — nothing to play.
            return None
    else:
        video_path = args["video"]
    print("[INFO] starting video stream...")
    vs = cv2.VideoCapture(video_path)
    if not vs.isOpened():
        return None
    return vs


def _detect_people(net, frame, min_confidence, trackers, labels):
    """Run SSD detection on *frame*, start one CSRT tracker per person.

    Appends to *trackers*/*labels* in place and draws the detection
    boxes on *frame*.  Non-person classes and weak detections are skipped.
    """
    (h, w) = frame.shape[:2]
    blob = cv2.dnn.blobFromImage(frame, 0.007843, (w, h), 127.5)
    net.setInput(blob)
    detections = net.forward()

    for i in np.arange(0, detections.shape[2]):
        # Each detection row carries [_, class_id, confidence, x1, y1, x2, y2].
        confidence = detections[0, 0, i, 2]
        if confidence <= min_confidence:
            continue
        idx = int(detections[0, 0, i, 1])
        label = CLASSES[idx]
        # Track only people.
        if label != "person":
            continue

        # SSD outputs normalized coordinates; scale back to pixels.
        box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
        (startX, startY, endX, endY) = box.astype("int")

        # CSRT: accurate (if slower) single-object tracker; init takes
        # (x, y, width, height).
        tracker = cv2.TrackerCSRT_create()
        tracker.init(frame, (startX, startY, endX - startX, endY - startY))
        labels.append(label)
        trackers.append(tracker)

        cv2.rectangle(frame, (startX, startY), (endX, endY), (0, 255, 0), 2)
        cv2.putText(frame, label, (startX, startY - 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)


def _update_trackers(frame, trackers, labels):
    """Advance every tracker one frame and draw the boxes that succeeded."""
    for (tracker, label) in zip(trackers, labels):
        success, box = tracker.update(frame)
        if not success:
            continue
        (startX, startY, bw, bh) = [int(v) for v in box]
        endX = startX + bw
        endY = startY + bh
        cv2.rectangle(frame, (startX, startY), (endX, endY), (0, 255, 0), 2)
        cv2.putText(frame, label, (startX, startY - 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)


def main():
    """Script entry point: detect on the first frame, track afterwards."""
    args = _parse_args()

    print("[INFO] loading model...")
    net = cv2.dnn.readNetFromCaffe(args["prototxt"], args["model"])

    vs = _open_video(args)
    if vs is None:
        print("[INFO] no video selected")
        return

    writer = None
    trackers = []   # one CSRT tracker per detected person
    labels = []     # class label parallel to `trackers`
    fps = FPS().start()

    while True:
        (grabbed, frame) = vs.read()
        # End of stream.
        if frame is None:
            break

        # Optional pre-processing (disabled).
        # frame = enhance_image(frame)

        # Resize to a fixed width of 600, preserving aspect ratio.
        (h, w) = frame.shape[:2]
        width = 600
        r = width / float(w)
        dim = (width, int(h * r))
        frame = cv2.resize(frame, dim, interpolation=cv2.INTER_AREA)
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Lazily create the writer once the output frame size is known.
        if args["output"] is not None and writer is None:
            fourcc = cv2.VideoWriter_fourcc(*"MJPG")
            writer = cv2.VideoWriter(args["output"], fourcc, 30,
                                     (frame.shape[1], frame.shape[0]), True)

        # Detect once, then track on every later frame.
        if len(trackers) == 0:
            _detect_people(net, frame, args["confidence"], trackers, labels)
        else:
            _update_trackers(frame, trackers, labels)

        if writer is not None:
            writer.write(frame)

        cv2.imshow("Frame", frame)
        key = cv2.waitKey(1) & 0xFF
        # ESC quits.
        if key == 27:
            break

        fps.update()

    fps.stop()
    print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
    print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))

    if writer is not None:
        writer.release()
    cv2.destroyAllWindows()
    vs.release()


if __name__ == "__main__":
    main()