You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
yolov8/MTSP-main/app.py

397 lines
16 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
# @Author : pan
# @Description : Flask后端
# @Date : 2023年7月27日10:46:25
import base64
import cv2
import os
import json
import string
import random
import jwt
import numpy as np
import supervision as sv
import time
import datetime
from typing import Any
from dotenv import load_dotenv
from PIL import Image
from flask import Flask, request, abort, send_from_directory, jsonify, session
from pprint import pprint
from apscheduler.schedulers.background import BackgroundScheduler
from ultralytics import YOLO
from utils.flask_utils import *
# --------------------------------配置加载
load_dotenv(override=True, dotenv_path='config/end-back.env')
# 服务器配置
HOST_NAME = os.environ['HOST_NAME']
PORT = int(os.environ['PORT'])
TOLERANT_TIME_ERROR = int(os.environ['TOLERANT_TIME_ERROR']) # 可以容忍的时间戳误差(s)
current_dir = os.getcwd() # 获取当前文件夹的路径
BEFORE_IMG_PATH = os.path.join(current_dir, 'static', os.environ['BEFORE_IMG_PATH']) # 拼接目标文件夹路径
AFTER_IMG_PATH = os.path.join(current_dir, 'static', os.environ['AFTER_IMG_PATH'])
# 数据库配置
MYSQL_HOST = os.environ['MYSQL_HOST'] # SQL主机
MYSQL_PORT = os.environ['MYSQL_PORT'] # 连接端口
MYSQL_user = os.environ['MYSQL_user'] # 用户名
MYSQL_password = os.environ['MYSQL_password'] # 密码
MYSQL_db = os.environ['MYSQL_db'] # 数据库名
MYSQL_charset = os.environ['MYSQL_charset'] # utf8
# 实例化数据库
db = SQLManager(host=MYSQL_HOST, port=eval(MYSQL_PORT), user=MYSQL_user,
passwd=MYSQL_password, db=MYSQL_db, charset=MYSQL_charset)
# result = db.get_one("SELECT * FROM user WHERE username=%s", ('dzp'))
# pprint(result)
# pprint(result['age'])
# Load a model yolo的全局变量
model = YOLO("./models/car.pt") # load a pretrained model (recommended for training)
box_annotator = sv.BoxAnnotator(
thickness=2,
text_thickness=1,
text_scale=0.5
)
app = Flask(__name__, static_folder='static')
# 在执行定时任务时可能会出现CPU耗尽的情况。
# 这个问题可能出现在任务本身需要大量CPU资源或任务设置了过长的时间间隔导致进程变得不稳定并且占用了整个CPU。
# 为了避免这个问题可以使用“APScheduler”库提供的schedulres定时器和executors执行器可以根据具体需求设置
# 链接https://www.python100.com/html/85441.html
# 关于Execution of job "scheduled_function (trigger: interval[0:00:10], next run at: 2023-08-01 12:24:15 CST)" skipped: maximum number of running instances reached (1)
# 1、调整定时任务的频率增加任务的执行间隔时间以确保一个任务实例能够在下一个实例开始之前完成。
# 例如将定时任务的执行间隔从10秒增加到20秒或更长时间。
# 2、增加最大运行实例数量根据你的需求和系统资源情况可能可以增加允许同时运行的任务实例的最大数量。
# 这通常需要查看你使用的任务调度器或框架的文档,并根据指导进行配置。
scheduler = BackgroundScheduler()
# 这里写要定时执行的代码
def scheduled_function():
print('定时任务启动!')
select_sql = "SELECT id, threshold, url, is_alarm, mode, location " \
"FROM monitor WHERE is_alarm='开启'"
monitor_list = db.get_list(select_sql)
# 循环执行
for item in monitor_list:
pid = int(item['id'])
threshold = int(item['threshold'])
mode = item['mode']
location = item['location']
source = item['url']
# 检测流是否存在
if not check_stream_availability(source):
print(f'该流拉取失败:{source}')
return False
# 根据模式选择不同的参数
if mode == "快速模式":
iter_model = iter(
model.track(source=source, show=False, stream=True, iou=0.3, conf=0.3))
elif mode == "准确模式":
iter_model = iter(
model.track(source=source, show=False, stream=True, iou=0.7, conf=0.7))
for i in range(2):
result = next(iter_model) # 这里是检测的核心
detections = sv.Detections.from_yolov8(result)
if result.boxes.id is None:
continue
if len(detections) > threshold:
# 获取当前时间
current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
# 保存处理后的图片
res_url = save_res_img(result.orig_img, detections, f'alarm.jpg')
# 写入id
detections.tracker_id = result.boxes.id.cpu().numpy().astype(int)
# 构造警报信息
alarm_description = f'车流量:{len(detections)}'
# 构建插入语句
insert_sql = "INSERT INTO alarm (location, description, threshold, photo, pid, create_time, remark) " \
"VALUES (%s, %s, %s, %s, %s, %s, %s)"
db.modify(insert_sql, (location, alarm_description, threshold, res_url, pid, current_time, ''))
print('警报已记录!')
# 每5分钟执行一次
scheduler.add_job(scheduled_function, 'interval', seconds=20 * 1)
scheduler.start()
# 未登录——请求:
# 1、当用户没有登录时会话中session的username这一个key没有值为none
# 2、当为空的时候对于管理后端操作会被拦截器拦截并返回一个401
# 登录——请求
# 1、当用户进行登录后会话中的session为空然后我们为其设置一个session就好了session为他的username
# 2、当用户请求时带着他自己的username——服务器判断1session有没有 2session中的value 是否等于 username
# 注销——请求(逻辑)
# 1、当他注销时前端pinia就清空他的数据并且发生请求给后端让session清空
# 2、网站跳转到首页
# 拦截器
# 1、当他再次访问的不是白名单时 判断session中有没有username
# session设置
app.config['SECRET_KEY'] = 'my-secret-key' # 设置密钥
app.config['PERMANENT_SESSION_LIFETIME'] = 15 * 60 # session时间: 5分钟
# 拦截器白名单
whitelist = ['/', '/login', '/photo', '/recognize']
# 拦截器 (测试前就注释掉!)
@app.before_request
def interceptor():
if request.path.startswith('/static/'): # 如果请求路径以 /static/ 开头,则放行
return
if request.path in whitelist: # 白名单放行
return
if not session.get('username'): # 检查是否已登录
return wrap_unauthorized_return_value('Unauthorized') # 返回 401 未授权状态码
# 添加header解决跨域
@app.after_request
def after_request(response):
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Credentials'] = 'true'
response.headers['Access-Control-Allow-Methods'] = 'POST'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
return response
@app.route("/")
def start_server():
return "欢迎使用交通路况分析系统!后端启动成功!(*^▽^*)"
# JWT即 JSON Web Token ——这里准备使用TWJ的但是无奈不想用 Redis就算了 java就是sa-token
@app.route('/login', methods=["POST"])
def login():
try:
data = request.json # 获取 JSON 格式的数据
username = data.get('username').strip()
password = data.get('password').strip()
user_info = db.get_one("SELECT * FROM user WHERE username=%s", (username))
if user_info and user_info['password'] == password:
session['username'] = username # 存储session
return wrap_ok_return_value({'id':user_info['id'],
'avatar':user_info['avatar'],
'username':user_info['username']})
return wrap_error_return_value('错误的用户名或密码!') # 登陆失败
# 登陆失败
except:
return wrap_error_return_value('系统繁忙,请稍后再试!')
@app.route('/logOut', methods=["get"])
def log_out():
session.clear()
return wrap_ok_return_value('账号已退出!')
@app.route('/submitMonitorForm', methods=["POST"])
def submit_monitor_form():
try:
data = request.json # 获取 JSON 格式的数据
threshold = int(data.get('threshold'))
person = data.get('person')
video = data.get('video')
url = data.get('url')
if(data.get('is_alarm')):
is_alarm = '开启'
else:
is_alarm = '关闭'
mode = data.get('mode')
location = data.get('location')
create_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
remark = data.get('remark')
# 插入
insert_sql = "INSERT INTO monitor " \
"(threshold, person, video, url, is_alarm, mode, location, create_time, create_by, remark) " \
"VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
values = (threshold, person, video, url, is_alarm, mode, location, create_time, "", remark)
# pprint(values)
db.modify(insert_sql, values)
return wrap_ok_return_value('配置提交成功!')
# 处理异常情况
except Exception as e:
pprint(e)
return wrap_error_return_value('系统繁忙,请稍后再试!')
@app.route('/updateMonitorForm', methods=["POST"])
def update_monitor_form():
try:
data = request.json # 获取 JSON 格式的数据
id = data.get('id')
threshold = int(data.get('threshold'))
person = data.get('person')
video = data.get('video')
url = data.get('url')
if(data.get('is_alarm')):
is_alarm = '开启'
else:
is_alarm = '关闭'
mode = data.get('mode')
location = data.get('location')
remark = data.get('remark')
# 更新
update_sql = "UPDATE monitor SET " \
"threshold = %s, person = %s, video = %s, url = %s, " \
"is_alarm = %s, mode = %s, location = %s, remark = %s " \
"WHERE id = %s"
values = (threshold, person, video, url, is_alarm, mode, location, remark, id)
db.modify(update_sql, values)
return wrap_ok_return_value('配置更新成功!')
except Exception as e:
return wrap_error_return_value(str(e))
# 查询用户信息(分页查询)
@app.route('/usersList/<int:page>', methods=['GET'])
def get_user_list(page):
page_from = int((page - 1) * 10)
page_to = int(page)*10
select_sql = f"select id, username, avatar, email, grade from user limit {page_from}, {page_to}"
user_list = db.get_list(select_sql)
# pprint(user_list)
return wrap_ok_return_value(user_list)
# 查询监控信息(分页查询)
@app.route('/monitorList/<int:page>', methods=['GET'])
def get_monitor_list(page):
page_from = int((page - 1) * 10)
page_to = int(page)*10
select_sql = f"SELECT id, threshold, person, video, url, is_alarm, mode, " \
f"location, create_time, create_by, remark FROM monitor" \
f" limit {page_from}, {page_to}"
monitor_list = db.get_list(select_sql)
# 将datetime对象格式化为字符串
for item in monitor_list:
item['create_time'] = item['create_time'].strftime('%Y-%m-%d %H:%M:%S')
# pprint(monitor_list)
return wrap_ok_return_value(monitor_list)
# 查询警报信息(分页查询)
@app.route('/alarmList/<int:page>', methods=['GET'])
def get_alarm_list(page):
page_from = int((page - 1) * 10)
page_to = int(page)*10
select_sql = f"SELECT id, location, description, threshold, photo, pid, create_time, remark " \
f"FROM alarm LIMIT {page_from}, {page_to}"
alarm_list = db.get_list(select_sql)
# 将datetime对象格式化为字符串
for item in alarm_list:
item['create_time'] = item['create_time'].strftime('%Y-%m-%d %H:%M:%S')
return wrap_ok_return_value(alarm_list)
@app.route("/photo", methods=["POST"])
def recognize_base64():
photo_data = request.form.get('photo') # 获取前端传递的 base64 图片数据
photo_data = photo_data.replace('data:image/png;base64,', '') # 去掉 base64 编码中的前缀
# 解码 base64 数据为二进制数据
image_data = base64.b64decode(photo_data)
# 保存为文件
before_img_path = save_img_base64(image_data, path=BEFORE_IMG_PATH)
# 处理完成后,返回响应
name = f"{''.join(random.choice(string.ascii_lowercase) for i in range(5))}.png"
# 返回结果
return yolo_res(before_img_path=before_img_path, name=name)
@app.route("/recognize", methods=["POST"])
def recognize_photo():
photo = request.files['file']
name = photo.filename
# return "ok"
# img = Image.open(photo)
# img = Image.open(photo).convert("RGB")
img = cv2.imdecode(np.fromstring(photo.read(), np.uint8), cv2.IMREAD_UNCHANGED)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
# 保存未处理的图片
before_img_path = save_img(name,img,BEFORE_IMG_PATH)
# 返回结果
return yolo_res(before_img_path=before_img_path, name=name)
# yolo 处理图片
def yolo_res(before_img_path, name):
try:
# 获取数据源
img = Image.open(before_img_path)
iter_model = iter(
model.track(source=img, show=False))
result = next(iter_model) # 这里是检测的核心,每次循环都会检测一帧图像,可以自行打印result看看里面有哪些key可以用
# xyxy 是表示边界框坐标的一种常见格式。它代表了一个边界框的左上角和右下角的坐标值。
# xyxy 数组的第一个元素 [423.33, 453.05, 552.67, 588.98] 表示了一个边界框的四个坐标点:
# 左上角点的 x 和 y 坐标423.33, 453.05
# 右下角点的 x 和 y 坐标552.67, 588.98
detections = sv.Detections.from_yolov8(result)
if result.boxes.id is None:
return wrap_ok_return_value('照片中没有目标物体哟!')
# 写入id
detections.tracker_id = result.boxes.id.cpu().numpy().astype(int)
# 保存处理后的图片
res_img = result.orig_img
res_url = save_res_img(res_img, detections)
labels = [
f"OBJECT-ID: {tracker_id} CLASS: {model.model.names[class_id]} CF: {confidence:0.2f} x:{x} y:{y}"
for x, y, confidence, class_id, tracker_id in detections
]
return wrap_ok_return_value({
'labels': labels,
'after_img_path': res_url
})
except Exception as e:
pprint(str(e))
return wrap_error_return_value('服务器繁忙,请稍后再试!')
def save_res_img(res_img, detections, name = 'default.jpg'):
labels = [
f"ID: {tracker_id}"
for x, y, confidence, class_id, tracker_id in detections
]
img_box = box_annotator.annotate(scene=res_img, detections=detections, labels=labels)
# 将 BGR 格式的 frame 转换为 RGB 格式
rgb_frame = cv2.cvtColor(img_box, cv2.COLOR_BGR2RGB)
# 把 rgb_frame 转换为 numpy格式 就行了
numpy_frame = np.array(rgb_frame)
after_img_path = save_img(name, numpy_frame, AFTER_IMG_PATH)
# 使用字符串替换方法将文件路径转换为指定的 URL 格式
return after_img_path.replace(current_dir, "http://127.0.0.1:5500/").replace('\\', '/')
# 在 Flask 中,警告消息"WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead."
# 表示您正在使用开发服务器,这不适合在生产环境中使用。如果您要切换到生产模式,您需要使用一个生产级的 WSGI 服务器。
# 通常使用类似于 Gunicorn、uWSGI 或 Nginx + uWSGI 的组合来部署和运行 Flask 应用。
if __name__ == "__main__":
app.run(host=HOST_NAME, port=PORT, debug=False)