Compare commits

..

48 Commits

Author SHA1 Message Date
陆鑫宇 6568fcb518 ffff
1 year ago
陆鑫宇 25e25a96f9 ffff
1 year ago
陆鑫宇 1bac3afcc1 ffff
1 year ago
陆鑫宇 bc6da32041 ffff
1 year ago
陆鑫宇 9fe3894148 ffff
1 year ago
陆鑫宇 708057caa9 ffff
1 year ago
陆鑫宇 f8b87477a3 ffff
1 year ago
陆鑫宇 7791e429be ffff
1 year ago
陆鑫宇 25dc0197e6 ffff
1 year ago
陆鑫宇 19932cfa0c ffff
1 year ago
陆鑫宇 070cb63404 ffff
1 year ago
陆鑫宇 d2234de65e ffff
1 year ago
陆鑫宇 ea698ce184 ffff
1 year ago
陆鑫宇 a1c9c1c381 ffff
1 year ago
陆鑫宇 efcad48cc5 ffff
1 year ago
陆鑫宇 bbb11be855 ffff
1 year ago
陆鑫宇 1f1cb7317e ffff
1 year ago
陆鑫宇 f02d2c5bdf ffff
1 year ago
陆鑫宇 a6daed6c45 ffff
1 year ago
陆鑫宇 305c8d50a2 ffff
1 year ago
陆鑫宇 4fb4d7b1d5 UYHHU
1 year ago
陆鑫宇 994829f66a UYHHU
1 year ago
陆鑫宇 98baed14f0 UYHHU
1 year ago
陆鑫宇 f546ba5bca UYHHUfds
1 year ago
陆鑫宇 e765fe90bb UYHHUfds
1 year ago
陆鑫宇 9fb60312fa UYHHUfds
1 year ago
陆鑫宇 b3499cd5fd UYHHUfds
1 year ago
陆鑫宇 5ac0f469eb UYHHUfds
1 year ago
陆鑫宇 800a382420 UYHHUfds
1 year ago
陆鑫宇 5b10cb5db7 UYHHUfds
1 year ago
陆鑫宇 94782a07d7 UYHHUfds
1 year ago
陆鑫宇 2b3e77d5ec UYHHUfds
1 year ago
陆鑫宇 9da0fdeeba UYHHUfds
1 year ago
陆鑫宇 3af93b96e0 UYHHUfds
1 year ago
陆鑫宇 77000c29f2 Merge remote-tracking branch 'origin/feature/zuyuan3' into feature/zuyuan3
1 year ago
陆鑫宇 f2b32c4003 UYHHUfds
1 year ago
陆鑫宇 0d979b8a5e UYHHUfds
1 year ago
陆鑫宇 0b3f82b479 Merge remote-tracking branch 'origin/feature/zuyuan3' into feature/zuyuan3
1 year ago
陆鑫宇 688655e3ad UYHHU
1 year ago
陆鑫宇 d286d2e922 UYHHU
1 year ago
陆鑫宇 36fd69d945 UYHHU
1 year ago
陆鑫宇 830d6922af UYHHU
1 year ago
陆鑫宇 e8e899eab9 UYHHU
1 year ago
陆鑫宇 0844c8d295 UYHHU
1 year ago
陆鑫宇 0fe98c8397 Merge remote-tracking branch 'origin/feature/zuyuan3' into feature/zuyuan3
1 year ago
陆鑫宇 a0fa60bcd4 UYHHU
1 year ago
陆鑫宇 77f94dd69c add readme
1 year ago
p9tnomuae 3926f76968 accept pr
1 year ago

@ -0,0 +1,60 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AutoImportSettings">
<option name="autoReloadType" value="SELECTIVE" />
</component>
<component name="ChangeListManager">
<list default="true" id="ddcf00de-e299-4bae-956f-15823ef810e7" name="Changes" comment="" />
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="ProjectColorInfo"><![CDATA[{
"associatedIndex": 8
}]]></component>
<component name="ProjectId" id="2qNz6vVownCYeQB9hugijDiZyr1" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent"><![CDATA[{
"keyToString": {
"RunOnceActivity.ShowReadmeOnStart": "true",
"git-widget-placeholder": "master",
"last_opened_file_path": "C:/Users/86156/Downloads/yolov8-master/yolov8",
"node.js.detected.package.eslint": "true",
"node.js.detected.package.tslint": "true",
"node.js.selected.package.eslint": "(autodetect)",
"node.js.selected.package.tslint": "(autodetect)",
"nodejs_package_manager_path": "npm",
"vue.rearranger.settings.migration": "true"
}
}]]></component>
<component name="SharedIndexes">
<attachedChunks>
<set>
<option value="bundled-js-predefined-d6986cc7102b-e768b9ed790e-JavaScript-PY-243.21565.199" />
<option value="bundled-python-sdk-cab1f2013843-4ae2d6a61b08-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-243.21565.199" />
</set>
</attachedChunks>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="ddcf00de-e299-4bae-956f-15823ef810e7" name="Changes" comment="" />
<created>1734517648329</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1734517648329</updated>
<workItem from="1734517649406" duration="2881000" />
</task>
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="3" />
</component>
</project>

@ -188,7 +188,7 @@ def login():
'avatar':user_info['avatar'],
'username':user_info['username']})
return wrap_error_return_value('错误的用户名或密码!') # 登陆失败
# 登陆失败
# 登陆失败
except:
return wrap_error_return_value('系统繁忙,请稍后再试!')

@ -31,7 +31,7 @@ class WorkerThread(QThread):
self.is_continue = False
self.is_close = False
self.is_exec = False
#DFTFDFGY
# 添加样式 赛博朋克
plt.style.use("cyberpunk")
# plt显示中文

@ -1,74 +1,111 @@
# -*- coding: utf-8 -*-
# @Author : pan
# @Description : 集成配置(单例模式)
# @Description : 集成配置(采用单例模式实现,用于管理整个程序相关的配置信息,例如目标检测的参数、文件保存路径等配置内容,并支持从配置文件读取和保存配置信息的功能。
# @Date : 2023年7月26日12:49:37
import json
import os
# dddddd (这里看起来像是一个无意义的占位或者临时标记,暂时可以忽略其具体作用)
# MainConfig类用于管理程序的配置信息采用单例模式设计确保整个程序中只有一个该类的实例存在方便统一管理配置。
class MainConfig:
# 类属性用于存储单例的实例对象初始化为None表示尚未创建实例。
_instance = None
# __new__方法是在创建类实例时调用的特殊方法用于控制实例的创建过程。
# 这里通过判断_instance是否为None来确保只有第一次调用时才真正创建实例后续调用直接返回已创建的实例从而实现单例模式。
def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
# __init__方法是类的构造方法用于初始化实例的属性。
# 在这里会进行一些配置相关的初始化操作,同时通过判断是否已经初始化过(通过'initialized'属性判断)来避免重复初始化。
def __init__(self, config_path="main_config.json"):
if hasattr(self, 'initialized'):
return
self.initialized = True
self.config_path = config_path
# 默认配置
# 默认配置属性设置
# IOU交并比阈值用于目标检测等场景中判断检测框的重叠程度等情况默认值设为0.5。
self.iou = 0.5
# 置信度阈值用于过滤目标检测结果中置信度较低的检测框默认值设为0.5。
self.conf = 0.5
# 速率相关参数具体作用需根据程序上下文确定默认值设为1.0。
self.rate = 1.0
# 是否保存检测结果视频的标志默认设为False表示不保存。
self.save_res = False
# 是否保存检测结果文本标签信息的标志默认设为False表示不保存。
self.save_txt = False
# 保存检测结果视频的默认路径,设为"pre_result",后续可根据配置文件或实际情况修改。
self.save_res_path = "pre_result"
# 保存检测结果文本标签信息的默认路径,设为"pre_labels",后续可根据配置文件或实际情况修改。
self.save_txt_path = "pre_labels"
# 模型文件所在的默认路径,设为"models",用于存放或查找相关的模型文件。
self.models_path = "models"
# 是否显示目标轨迹的标志默认设为True表示显示。
self.show_trace = True
# 是否显示目标标签的标志默认设为True表示显示。
self.show_labels = True
self.open_fold = os.getcwd() # 设置为当前工作目录的路径
# 初始化为当前工作目录的路径,可作为程序后续操作的基础目录,比如打开文件等操作可能基于此目录进行。
self.open_fold = os.getcwd()
# RTSP实时流传输协议的IP地址及相关认证信息用于连接对应的视频流源默认设置了一个示例地址及简单的用户名和密码实际应用中可能需要修改
self.rtsp_ip = "rtsp://admin:admin@192.168.43.1:8554/live"
# 车辆相关的ID标识具体作用需根据程序对车辆的处理逻辑确定默认设为1。
self.car_id = 1
# 车辆相关的阈值可能用于判断车辆的某些条件是否满足等情况默认设为10。
self.car_threshold = 10
# 读取配置文件
# 调用load_config方法尝试从配置文件中读取配置信息来更新默认配置。
self.load_config()
# 从配置文件中加载配置信息的方法,如果配置文件不存在则先保存默认配置;如果存在则读取并更新相应的配置属性。
def load_config(self):
# 如果文件不存在(写入,然后使用默认值)
# 如果指定的配置文件不存在
if not os.path.exists(self.config_path):
# 先调用save_config方法保存默认配置信息到文件中后续可通过编辑该文件来修改配置
self.save_config()
return
# 如果文件存在(修改默认值)
# 如果配置文件存在,尝试打开文件并读取其中的配置数据
with open(self.config_path) as f:
try:
config_data = json.load(f)
# 读取配置
# 读取配置文件中的IOU值如果不存在则使用当前类属性的默认值self.iou
self.iou = config_data.get("iou", self.iou)
# 读取配置文件中的置信度值如果不存在则使用当前类属性的默认值self.conf
self.conf = config_data.get("conf", self.conf)
# 读取配置文件中的速率值如果不存在则使用当前类属性的默认值self.rate
self.rate = config_data.get("rate", self.rate)
# 读取配置文件中是否保存检测结果视频的标志如果不存在则使用当前类属性的默认值self.save_res
self.save_res = config_data.get("save_res", self.save_res)
# 读取配置文件中是否保存检测结果文本标签信息的标志如果不存在则使用当前类属性的默认值self.save_txt
self.save_txt = config_data.get("save_txt", self.save_txt)
# 读取配置文件中保存检测结果视频的路径如果不存在则使用当前类属性的默认值self.save_res_path
self.save_res_path = config_data.get("save_res_path", self.save_res_path)
# 读取配置文件中保存检测结果文本标签信息的路径如果不存在则使用当前类属性的默认值self.save_txt_path
self.save_txt_path = config_data.get("save_txt_path", self.save_txt_path)
# 读取配置文件中模型文件所在的路径如果不存在则使用当前类属性的默认值self.models_path
self.models_path = config_data.get("models_path", self.models_path)
# 读取配置文件中是否显示目标标签的标志如果不存在则使用当前类属性的默认值self.show_labels
self.show_labels = config_data.get("show_labels", self.show_labels)
# 读取配置文件中是否显示目标轨迹的标志如果不存在则使用当前类属性的默认值self.show_trace
self.show_trace = config_data.get("show_trace", self.show_trace)
# 读取配置文件中当前工作目录的路径如果不存在则使用当前类属性的默认值self.open_fold
self.open_fold = config_data.get("open_fold", self.open_fold)
# 读取配置文件中RTSP的IP地址如果不存在则使用当前类属性的默认值self.rtsp_ip
self.rtsp_ip = config_data.get("rtsp_ip", self.rtsp_ip)
# 读取配置文件中车辆相关的ID标识如果不存在则使用当前类属性的默认值self.car_id
self.car_id = config_data.get("car_id", self.car_id)
# 读取配置文件中车辆相关的阈值如果不存在则使用当前类属性的默认值self.car_threshold
self.car_threshold = config_data.get("car_threshold", self.car_threshold)
except (json.JSONDecodeError, KeyError, TypeError):
# 如果在读取配置文件或解析配置数据过程中出现错误比如JSON格式错误、缺少关键键值等则重新保存默认配置信息覆盖原配置文件
self.save_config()
# 保存配置
# 保存当前配置信息到配置文件的方法将类中的各个配置属性整理成字典格式转换为JSON字符串后写入到指定的配置文件中。
def save_config(self):
new_config = {"iou": self.iou,
"conf": self.conf,
@ -76,8 +113,7 @@ class MainConfig:
"save_res": self.save_res,
"save_txt": self.save_txt,
"save_res_path": self.save_res_path,
"save_txt_path": self.save_txt_path,
"models_path" : self.models_path,
"models_path": self.models_path,
"open_fold": self.open_fold,
"show_trace": self.show_trace,
"show_labels": self.show_labels,
@ -92,4 +128,5 @@ class MainConfig:
if __name__ == "__main__":
# 创建MainConfig类的实例由于采用单例模式这里获取到的是整个程序中唯一的配置实例对象后续可以通过这个对象来访问和操作配置信息。
config = MainConfig()

@ -3,6 +3,8 @@
# @Description : 废弃方案(雷达图)
# @Date : 2023年7月27日10:46:04
# 从PyQt5.QtWidgets模块中导入多个类这些类用于创建不同类型的窗口部件例如QApplication用于管理应用程序的生命周期和设置
# QWidget是所有用户界面对象的基类QLabel用于显示文本或图像标签QVBoxLayout用于创建垂直布局管理器QMainWindow是主窗口类提供了应用程序主窗口的框架。
from PyQt5.QtWidgets import (
QApplication,
QWidget,
@ -10,62 +12,119 @@ from PyQt5.QtWidgets import (
QVBoxLayout,
QMainWindow
)
# 从PyQt5.QtGui模块中导入多个类QPixmap用于处理图像像素数据QPaintEvent用于处理绘图事件相关操作QImage用于表示图像数据结构。
from PyQt5.QtGui import (
QPixmap,
QPaintEvent,
QImage
)
# 从PyQt5.QtCore模块中导入QThread类用于在多线程环境中创建线程实现多线程相关的功能不过在当前代码中未体现其具体使用
from PyQt5.QtCore import (
QThread
)
# 导入PyQt5.Qt模块下的所有内容通常不建议在大型项目中这样做因为可能会导致命名空间冲突但在简单示例中可以方便使用各种相关类和函数
from PyQt5.Qt import *
import sys
import numpy as np
# 从matplotlib.backends.backend_qtagg模块中导入FigureCanvas类它用于将matplotlib绘制的图形嵌入到PyQt应用程序的窗口部件中实现图形展示功能。
from matplotlib.backends.backend_qtagg import FigureCanvas
import matplotlib.pyplot as plt
#lxy
# 定义了一个名为drawCloudMain的类它继承自QMainWindow用于创建一个带有特定绘图功能绘制雷达图相关的主窗口。
class drawCloudMain(QMainWindow):
def __init__(self) -> None:
super().__init__()
self.r = 2 * np.random.rand(100) # 生成100个服从“0~1”均匀分布的随机样本值
self.angle = 2 * np.pi * np.random.rand(100) # 生成角度
# 使用numpy的random.rand函数生成100个服从“0~1”均匀分布的随机样本值作为雷达图中每个数据点到原点的距离半径
# 这里将生成的随机值乘以2使得半径的取值范围大致在0到2之间不过后续代码中设置了y轴范围限制实际展示效果会根据该限制调整
self.r = 2 * np.random.rand(100)
# 使用numpy的random.rand函数生成100个在“0到2 * pi”之间均匀分布的随机角度值用于确定雷达图中每个数据点的角度位置。
self.angle = 2 * np.pi * np.random.rand(100)
def paintEvent(self, painter: QPaintEvent) -> None:
plt.cla() # 清屏
# 获取绘图并绘制
"""
函数功能
重写QWidget类的paintEvent方法用于处理绘图事件当窗口需要重绘例如初次显示窗口大小改变被遮挡后重新显示等情况
该方法会被自动调用在这里实现绘制雷达图的具体逻辑
参数说明
painter (QPaintEvent)绘图事件对象包含了与绘图相关的一些信息不过在当前函数中未直接使用该参数进行操作
它主要是由系统传递过来触发绘图操作的一个标识
返回值
None
"""
# 清除当前matplotlib的绘图区域相当于清空之前可能存在的绘图内容为新的绘图做准备。
plt.cla()
# 获取matplotlib的Figure对象它是整个绘图的顶层容器类似于一张画布可以在上面添加各种坐标轴、图形等元素。
fig = plt.figure()
# 在Figure对象上添加一个极坐标投影的坐标轴参数[0, 0, 1, 1]表示坐标轴在Figure中的位置和大小这里是占满整个Figure
# 通过设置projection="polar"指定为极坐标系统,后续绘制的图形将基于极坐标进行展示,适用于绘制雷达图等需要极坐标表示的图形。
ax = fig.add_axes([0, 0, 1, 1], projection="polar")
# 设置极坐标中y轴半径方向的范围将其限制在0到10之间这样绘制的数据点的半径值会根据这个范围进行缩放展示。
ax.set_ylim(0, 10)
# 设置y轴半径方向的刻度这里指定刻度值为从0到10间隔为2用于在雷达图的半径方向上显示清晰的刻度标记方便查看数据的大小。
ax.set_yticks(np.arange(0, 10, 2))
# 在极坐标的坐标轴上绘制散点图使用之前生成的角度self.angle和半径self.r数据作为散点的坐标位置
# 以此展示出随机分布的散点构成的雷达图样式的图形。
ax.scatter(self.angle, self.r)
# 创建一个FigureCanvas对象它将matplotlib的Figure对象嵌入到PyQt的窗口部件体系中使得绘制的图形可以在PyQt窗口中显示出来。
cavans = FigureCanvas(fig)
# 将创建好的包含绘图内容的FigureCanvas设置为当前主窗口的中心部件这样绘图就会显示在主窗口的中央区域。
self.setCentralWidget(cavans)
# 定义了一个名为drawCloud的类它继承自QWidget同样用于创建一个可以绘制雷达图的窗口部件与drawCloudMain类功能类似但结构稍有不同。
class drawCloud(QWidget):
def __init__(self) -> None:
super().__init__()
self.r = 2 * np.random.rand(100) # 生成100个服从“0~1”均匀分布的随机样本值
self.angle = 2 * np.pi * np.random.rand(100) # 生成角度
# 与drawCloudMain类中的操作类似生成100个服从“0~1”均匀分布的随机样本值作为雷达图的半径数据乘以2后半径取值范围大致在0到2之间。
self.r = 2 * np.random.rand(100)
# 同样生成100个在“0到2 * pi”之间均匀分布的随机角度值用于确定雷达图中数据点的角度位置。
self.angle = 2 * np.pi * np.random.rand(100)
# 创建一个matplotlib的Figure对象作为绘图的顶层容器后续将在这个Figure上添加坐标轴和绘制图形。
self.figure = plt.figure()
# 创建一个FigureCanvas对象用于将matplotlib的Figure嵌入到PyQt窗口部件中以便在窗口中展示绘制的图形。
self.canvas = FigureCanvas(self.figure)
# 创建一个垂直布局管理器用于管理窗口中的部件布局将其与当前的QWidget对象self关联起来意味着该布局将应用于这个窗口部件内部。
layout = QVBoxLayout(self)
# 将包含绘图内容的FigureCanvas添加到垂直布局中这样它就会按照垂直布局的规则在窗口中显示通常会在窗口中从上到下排列部件。
layout.addWidget(self.canvas)
def paintEvent(self, painter: QPaintEvent) -> None:
"""
函数功能
重写QWidget类的paintEvent方法用于处理绘图事件在窗口需要重绘时被调用实现绘制雷达图的具体逻辑
与drawCloudMain类中的paintEvent方法类似但基于自身的成员变量和布局结构进行操作
参数说明
painter (QPaintEvent)绘图事件对象包含绘图相关信息在当前函数中未直接使用该参数进行操作主要用于触发绘图操作
返回值
None
"""
# 在之前创建的Figure对象上添加一个极坐标投影的坐标轴同样占满整个Figure用于后续在极坐标下绘制图形。
ax = self.figure.add_axes([0, 0, 1, 1], projection="polar")
# 设置极坐标中y轴半径方向的范围为0到10用于控制绘制的数据点在半径方向上的展示范围。
ax.set_ylim(0, 10)
# 设置y轴半径方向的刻度从0到10间隔为2方便在雷达图上显示清晰的半径刻度标记。
ax.set_yticks(np.arange(0, 10, 2))
# 在极坐标的坐标轴上绘制散点图使用已经生成的角度self.angle和半径self.r数据作为散点的坐标展示雷达图样式的图形。
ax.scatter(self.angle, self.r)
# 调用FigureCanvas的draw方法触发图形的重绘操作确保绘图内容能够及时更新并显示在窗口中例如在窗口大小改变等需要重绘的情况下生效。
self.canvas.draw()
if __name__ == '__main__':
# 创建一个QApplication对象它是整个PyQt应用程序的核心用于管理应用程序的生命周期处理各种事件循环等
# sys.argv参数用于传递命令行参数给应用程序在简单示例中可能不一定会用到这些参数但在实际复杂应用中可能会有相关配置通过命令行传入
app = QApplication(sys.argv)
# 创建一个drawCloud类的实例对象即创建一个可以绘制雷达图的窗口部件。
windows = drawCloud()
# 显示创建好的窗口部件,使其在屏幕上可见,调用这个方法后,窗口会根据其内部的布局和绘制逻辑展示相应的内容(这里就是雷达图)。
windows.show()
sys.exit(app.exec())
# 启动应用程序的事件循环,使得应用程序能够响应各种用户交互事件(如鼠标点击、键盘输入等)以及系统事件(如窗口重绘等),
# 程序会一直运行在这个循环中直到用户关闭窗口等操作触发退出事件然后返回应用程序的退出状态码传递给sys.exit用于正常退出程序。
sys.exit(app.exec())

@ -7,7 +7,7 @@ from collections import deque
import cv2
import numpy as np
#LXT
palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
COLORS_10 =[(144,238,144),(178, 34, 34),(221,160,221),( 0,255, 0),( 0,128, 0),(210,105, 30),(220, 20, 60),
(192,192,192),(255,228,196),( 50,205, 50),(139, 0,139),(100,149,237),(138, 43,226),(238,130,238),
@ -37,7 +37,7 @@ def compute_color_for_labels(label):
else:
color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
return tuple(color)
#LXY
#绘制轨迹
def draw_trail(img, bbox, names,object_id, identities=None, offset=(0, 0)):
try:

@ -104,7 +104,7 @@ if __name__ == "__main__":
show_data_db1 = db.get_list('select * from auto ')
pprint(show_data_db1)
#LXY
# 方法2
# 如果你的应用程序只偶尔需要进行数据库操作或者需要更加严谨的资源管理,可以选择使用 with 语句的方式。
with SQLManager() as sql_manager:

@ -1,7 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : pan
# @Description : 封装了一些函数
# @Date : 2023年8月8日09:45:23
import supervision as sv
from ultralytics import YOLO
@ -25,72 +21,139 @@ import cv2
from classes.paint_trail import draw_trail
from utils.main_utils import check_path
# 定义一个空列表用于存储时间轴相关的数据可能是后续绘制图表时的x轴数据比如对应不同时刻等情况
# 具体用途要结合整体代码功能来看,可能与记录检测过程在不同时间点的相关信息有关。
x_axis_time_graph = []
# 定义一个空列表用于存储数量相关的数据可能是检测到的目标数量等情况对应绘制图表时的y轴数据
# 同样要结合整体代码逻辑来明确其具体用途,大概率是用于展示随着时间变化检测目标数量等的变化趋势。
y_axis_count_graph = []
# 定义一个计数器变量用于记录视频相关的编号或者计数初始值设为0从变量名推测可能是针对多个视频处理时进行区分编号等操作。
video_id_count = 0
#xlylxy
# 定义YoloPredictor类它继承自BasePredictor和QObject意味着这个类既具备BasePredictor类的相关功能特性
# 又能利用QObject类在Qt框架下实现信号与槽等面向对象的事件机制相关功能可能用于目标检测以及与Qt界面交互相关的操作。
class YoloPredictor(BasePredictor, QObject):
yolo2main_trail_img = Signal(np.ndarray) # 轨迹图像信号
yolo2main_box_img = Signal(np.ndarray) # 绘制了标签与锚框的图像的信号
yolo2main_status_msg = Signal(str) # 检测/暂停/停止/测试完成等信号
yolo2main_fps = Signal(str) # fps
yolo2main_labels = Signal(dict) # 检测到的目标结果(每个类别的数量)
yolo2main_progress = Signal(int) # 进度条
yolo2main_class_num = Signal(int) # 当前帧类别数
yolo2main_target_num = Signal(int) # 当前帧目标数
# 定义一个信号用于发送轨迹图像数据以numpy数组形式可以将检测到的目标轨迹相关图像信息传递给连接了该信号的槽函数
# 以便在Qt界面或者其他相关模块中进行展示等操作信号传递的数据类型为numpy.ndarray。
yolo2main_trail_img = Signal(np.ndarray)
# 定义一个信号用于发送绘制了标签与锚框的图像数据以numpy数组形式方便将带有检测结果标注的图像传递出去供显示或者进一步处理。
yolo2main_box_img = Signal(np.ndarray)
# 定义一个信号,用于发送检测相关的状态消息(以字符串形式),像“正在加载模型...”“检测中...”“检测终止”等状态提示,
# 可以让接收该信号的部分(比如界面显示模块)根据消息内容更新显示给用户相应的状态信息。
yolo2main_status_msg = Signal(str)
# 定义一个信号用于发送每秒帧率FPS相关的字符串信息可能用于在界面上展示检测过程中的实时帧率情况让用户了解检测速度。
yolo2main_fps = Signal(str)
# 定义一个信号,用于发送检测到的目标结果信息(以字典形式),字典中可能包含每个类别的数量等统计信息,
# 便于其他模块根据这些结果进行数据展示或者进一步分析等操作。
yolo2main_labels = Signal(dict)
# 定义一个信号,用于发送检测进度相关的信息(以整数形式),通常可以用于更新界面上的进度条展示,让用户直观看到检测的进度情况。
yolo2main_progress = Signal(int)
# 定义一个信号,用于发送当前帧中类别数量相关的信息(以整数形式),用于告知其他部分当前帧图像中检测到了多少种类别的目标。
yolo2main_class_num = Signal(int)
# 定义一个信号,用于发送当前帧中目标总数量相关的信息(以整数形式),方便展示每帧图像中具体检测到的目标个数情况。
yolo2main_target_num = Signal(int)
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
"""
类的构造函数用于初始化YoloPredictor类的实例对象进行各种属性和参数的设置
参数说明
cfg可选配置信息默认值为DEFAULT_CFG具体值需看代码中别处对其的定义可能包含模型配置检测相关的各种参数等内容
用于初始化模型和相关检测设置
overrides可选用于覆盖默认配置的参数类型可能是字典或者其他可用于修改配置的结构通过它可以对cfg中的部分配置进行自定义修改
具体操作
首先调用父类BasePredictor和QObject的构造函数来初始化继承自它们的部分确保继承的属性和方法等能正常使用
然后进行一系列与检测任务界面显示结果保存等相关的属性初始化操作
"""
super(YoloPredictor, self).__init__()
QObject.__init__(self)
try:
# 尝试获取配置信息通过调用get_cfg函数具体功能需看其定义传入cfg和overrides参数获取最终用于模型和检测的配置参数
# 如果获取过程出现异常比如函数内部执行出错等情况则通过pass跳过异常处理可能导致使用默认配置或者部分未正确初始化的配置情况。
self.args = get_cfg(cfg, overrides)
except:
pass
# 根据配置参数中的项目名称project或者默认的运行目录SETTINGS['runs_dir']与任务名称self.args.task来确定项目路径
# 如果配置中没有指定项目名称,则使用默认的运行目录与任务名称组合作为项目路径。
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
# 根据配置参数中的模式mode来确定一个名称这里直接使用配置中的模式值作为名称具体用途可能与保存结果、区分不同运行模式等相关。
name = f'{self.args.mode}'
# 使用increment_path函数功能可能是根据给定路径进行路径递增操作比如避免同名路径冲突等情况来生成保存目录路径
# 传入项目路径与名称的组合以及是否允许已存在根据配置中的exist_ok参数的标识确保保存结果等操作有合适的目录可用。
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
# 初始化一个标识表示是否已经完成热身操作具体热身操作可能与模型加载、初始化相关的前置准备工作有关从变量名推测初始值设为False。
self.done_warmup = False
if self.args.show:
# 如果配置参数中要求显示show为True则调用check_imshow函数可能用于检查是否可以显示图像比如检查显示设备是否可用等情况
# 并且根据warn参数决定是否给出警告提示来处理显示相关的设置将处理后的结果重新赋值给self.args.show。
self.args.show = check_imshow(warn=True)
# GUI args
self.used_model_name = None # 使用过的检测模型名称
self.new_model_name = None # 新更改的模型
self.source = '' # 输入源str
self.progress_value = 0 # 进度条的值
self.stop_dtc = False # 终止bool
self.continue_dtc = True # 暂停bool
# config
self.iou_thres = 0.45 # iou
self.conf_thres = 0.25 # conf
self.speed_thres = 0.01 # delay, ms (缓冲)
self.save_res = False # 保存MP4
self.save_txt = False # 保存txt
# GUI相关参数初始化部分以下这些变量用于在图形用户界面GUI交互或者展示相关场景下使用记录一些状态、名称等信息。
# 用于记录使用过的检测模型名称初始值设为None表示尚未记录使用过的模型名称后续在使用模型的过程中可能会对其赋值更新。
self.used_model_name = None
# 用于记录新更改的模型相关信息具体更改的内容可能是模型参数、模型文件等情况初始值设为None
# 当有模型更新操作时可以对其进行赋值,以便后续根据新模型进行相应处理。
self.new_model_name = None
# 用于记录输入源的字符串表示,比如可以是视频文件路径、摄像头设备编号等信息,初始值为空字符串,后续会在获取数据源时进行赋值。
self.source = ''
# 用于记录进度条的值代表检测任务的进度情况初始值设为0随着检测过程推进可以更新该值来实时反映进度通过信号发送给界面更新显示。
self.progress_value = 0
# 用于标记是否终止检测操作的布尔变量初始值设为False表示默认不终止检测当满足某些终止条件时会将其设为True来停止检测循环等操作。
self.stop_dtc = False
# 用于标记是否暂停检测操作的布尔变量初始值设为True表示默认可以继续检测当需要暂停时会将其设为False暂停检测循环
# 之后可以再通过改变该值恢复检测操作,实现暂停/继续功能。
self.continue_dtc = True
# 配置相关参数初始化部分,以下这些变量用于配置检测过程中的一些关键参数,如阈值、保存设置、显示设置等。
# 用于设置交并比IoU阈值该阈值在目标检测中常用于判断预测框与真实框的重叠程度以确定是否为有效的检测结果这里设为0.45。
self.iou_thres = 0.45
# 用于设置置信度Confidence阈值即模型预测某个目标存在的可信度只有置信度高于该阈值的检测结果才会被考虑这里设为0.25。
self.conf_thres = 0.25
# 用于设置速度相关的阈值从变量名推测可能与处理延迟、缓冲相关单位是毫秒具体用途要结合代码中使用该变量的地方来确定这里设为0.01。
self.speed_thres = 0.01
# 用于标记是否保存检测结果视频MP4格式的布尔变量初始值设为False表示默认不保存视频结果若需要保存可将其设为True
# 并配合后续相关的保存代码进行视频保存操作。
self.save_res = False
# 用于标记是否保存检测结果文本信息可能是标注信息等格式为txt的布尔变量初始值设为False代表默认不保存文本结果
# 若需要记录详细的文本标注等情况可以将其设为True并设置相应的保存路径等进行保存操作。
self.save_txt = False
# 用于指定保存检测结果视频的路径,初始值设为"pre_result",后续可能会根据实际情况进行调整或者拼接完整的文件名等操作来正确保存视频。
self.save_res_path = "pre_result"
# 用于指定保存检测结果文本信息的路径,初始值设为"pre_labels",同样可能会根据实际情况进一步处理来确保文本信息能准确保存到相应位置。
self.save_txt_path = "pre_labels"
self.show_labels = True # 显示图像标签bool
self.show_trace = True # 显示图像轨迹bool
# 用于标记是否显示图像标签的布尔变量初始值设为True表示默认显示检测到的目标对应的标签信息若不想显示可将其设为False。
self.show_labels = True
# 用于标记是否显示图像轨迹的布尔变量初始值设为True意味着默认会显示检测目标的轨迹信息若不需要展示轨迹可将其设为False。
self.show_trace = True
# 运行时候的参数初始化部分,以下这些变量用于在检测运行过程中记录各种实时状态、计数等信息,方便后续计算和展示相关数据。
# 运行时候的参数放这里
self.start_time = None # 拿来算FPS的计数变量
# 用于记录检测开始时间的变量初始值设为None在检测开始时会获取当前时间赋值给它以便后续用于计算每秒帧率FPS等操作
# 作为计时的起始时间点。
self.start_time = None
# 用于计数的变量初始值设为None从代码中后续使用情况看可能参与到FPS计算或者其他计数相关的逻辑中具体含义要结合具体使用代码确定。
self.count = None
# 用于累计计数的变量初始值设为None可能用于对检测到的目标数量或者其他相关量进行累加统计具体用途要结合使用场景分析。
self.sum_of_count = None
# 用于记录当前帧中类别数量的变量初始值设为None在检测每帧图像时会更新该值来反映当前帧包含的目标类别个数情况。
self.class_num = None
# 用于记录总帧数的变量初始值设为None在处理视频等情况时会获取视频的总帧数赋值给它方便进行进度计算、循环控制等操作。
self.total_frames = None
# 用于记录锁相关的标识具体锁的用途可能与多线程、资源访问控制等相关从变量名推测初始值设为None后续根据实际需求使用。
self.lock_id = None
# 设置线条样式 厚度 & 缩放大小
# 创建一个BoxAnnotator对象具体功能需看其类定义从参数看可能用于绘制标注框相关的操作设置标注框的厚度、文本厚度以及文本缩放比例等属性
# 用于后续在图像上绘制检测结果的标注信息,使标注更清晰美观,符合展示要求。
self.box_annotator = sv.BoxAnnotator(
thickness=2,
text_thickness=1,
@ -100,66 +163,65 @@ class YoloPredictor(BasePredictor, QObject):
# 点击开始检测按钮后的检测事件
@smart_inference_mode() # 一个修饰器用来开启检测模式如果torch>=1.9.0则执行torch.inference_mode()否则执行torch.no_grad()
def run(self):
"""
该方法是在点击开始检测按钮后执行的检测逻辑主体负责整个检测过程的流程控制包括加载模型获取数据源处理检测结果以及控制检测循环等操作
具体操作
首先发送状态消息信号表示正在加载模型进行一些计数和计时变量的初始化然后检查保存结果的路径是否存在如果需要保存结果的话
接着加载模型获取数据源初始化用于绘制图表的数据折线图相关的x轴和y轴数据列表之后根据数据源类型进行不同的处理比如视频的总帧数获取等
再创建用于保存结果视频的对象如果需要保存视频最后进入一个无限循环在循环中根据暂停终止等条件来处理每帧图像的检测结果以及相应的保存状态更新等操作
"""
# 发送状态消息信号,告知其他接收该信号的部分(如界面显示模块)当前正在加载模型,以便更新显示相应的状态提示给用户。
self.yolo2main_status_msg.emit('正在加载模型...')
# 将LoadStreams类具体功能需看其定义中的capture属性设为空字符串可能用于重置或者初始化数据源相关的状态
# 具体作用要结合LoadStreams类的使用场景来确定。
LoadStreams.capture = ''
self.count = 0 # 拿来参与算FPS的计数变量
self.start_time = time.time() # 拿来算FPS的计数变量
# 初始化计数变量用于参与每秒帧率FPS等相关的计算将其初始值设为0。
self.count = 0
# 获取当前时间并赋值给start_time变量用于记录检测开始的时间点方便后续计算FPS等操作通过time.time()函数获取当前时间的时间戳(秒数)。
self.start_time = time.time()
global video_id_count
# 检查保存路径
# 检查保存文本结果的路径是否存在如果需要保存检测结果文本self.save_txt为True则调用check_path函数功能可能是创建路径等确保路径可用进行检查。
if self.save_txt:
check_path(self.save_txt_path)
# 检查保存视频结果的路径是否存在同理如果要保存检测结果视频self.save_res为True调用check_path函数确保保存视频的路径可用。
if self.save_res:
check_path(self.save_res_path)
# 调用load_yolo_model函数具体功能为加载YOLO模型内部应该涉及模型文件读取、初始化等操作加载YOLO模型将加载好的模型赋值给model变量
# 以便后续使用该模型进行目标检测操作。
model = self.load_yolo_model()
# 获取数据源 (不同的类型获取不同的数据源)
# 获取数据源通过调用model的track方法具体track方法的功能可能是根据输入源进行目标跟踪检测等操作传入相关参数来配置检测行为
# 传入输入源self.source、是否显示show参数设为False可能表示不在获取数据源阶段显示相关图像等情况、是否以流的形式处理stream设为True
# 适合处理视频等连续数据的情况以及交并比阈值iou和置信度阈值conf等参数然后使用iter函数将返回的结果转换为迭代器方便后续逐帧处理。
iter_model = iter(
model.track(source=self.source, show=False, stream=True, iou=self.iou_thres, conf=self.conf_thres))
# 折线图数据初始化
# 折线图数据初始化将全局的x轴时间数据列表和y轴数量数据列表清空重新开始记录本次检测过程中相关的数据
# 用于后续可能的绘制图表展示检测目标数量随时间变化趋势等操作。
global x_axis_time_graph, y_axis_count_graph
x_axis_time_graph = []
y_axis_count_graph = []
# 发送状态消息信号,告知其他部分当前进入检测阶段,相应的界面等可以显示“检测中...”这样的提示给用户。
self.yolo2main_status_msg.emit('检测中...')
# 使用OpenCV读取视频——获取进度条
# 通过判断输入源self.source的文件格式是否包含常见的视频格式后缀如'mp4'、'avi'等),如果是视频文件,
# 则使用cv2.VideoCapture函数OpenCV中用于打开视频文件或者摄像头设备获取视频流的函数打开视频获取视频的总帧数通过cv2.CAP_PROP_FRAME_COUNT属性
# 然后释放视频资源调用release方法这里获取总帧数可能是用于后续计算检测进度、循环控制等操作虽然释放了资源但已经获取到了需要的总帧数信息。
if 'mp4' in self.source or 'avi' in self.source or 'mkv' in self.source or 'flv' in self.source or 'mov' in self.source:
cap = cv2.VideoCapture(self.source)
self.total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
cap.release()
# 如果保存,则创建写入对象
img_res, result, height, width = self.recognize_res(iter_model)
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = None # 视频写出变量
if self.save_res:
out = cv2.VideoWriter(f'{self.save_res_path}/video_result_{video_id_count}.mp4', fourcc, 25,
(width, height), True) # 保存检测视频的路径
# 开始死循环检测
while True:
try:
# 暂停与开始
if self.continue_dtc:
img_res, result, height, width = self.recognize_res(iter_model)
self.res_address(img_res, result, height, width, model, out)
# 终止
if self.stop_dtc:
if self.save_res:
if out:
out.release()
video_id_count += 1
self.source = None
self.yolo2main_status_msg.emit('检测终止')
self.release_capture() # 这里是为了终止使用摄像头检测函数的线程改了yolo源码
break
# 检测截止(本地文件检测)
# 如果保存检测结果视频则创建用于写入视频的对象使用cv2.VideoWriter类配置视频编码格式这里使用'XVID'编码、帧率设为25帧/秒)、
# 检测截止(本地文件检测)
#lxylxylxy
except StopIteration:
if self.save_res:
out.release()
@ -331,18 +393,18 @@ class YoloPredictor(BasePredictor, QObject):
f"ID: {tracker_id} {model.model.names[class_id]}"
for _, _, confidence, class_id, tracker_id in detections
]
'''
如果Torch装的是cuda版本的话302行的代码需改成
'''
如果Torch装的是cuda版本的话302行的代码需改成
labels_draw = [
f"OBJECT-ID: {tracker_id} CLASS: {model.model.names[class_id]} CF: {confidence:0.2f}"
for _,confidence,class_id,tracker_id in detections
]
'''
# 存储labels里的信息
labels_write = [
labels_write = [
f"目标ID: {tracker_id} 目标类别: {class_id} 置信度: {confidence:0.2f}"
for _, _, confidence, class_id, tracker_id in detections
]
]
'''
如果Torch装的是cuda版本的话314行的代码需改成
labels_write = [
@ -352,19 +414,19 @@ class YoloPredictor(BasePredictor, QObject):
'''
# 如果显示标签 (要有才可以画呀!)---否则就是原图
if (self.show_labels == True) and (self.class_num != 0):
if (self.show_labels == True) and (self.class_num != 0):
img_box = self.box_annotator.annotate(scene=img_box, detections=detections, labels=labels_draw)
return labels_write, img_box
return labels_write, img_box
# 获取类别数
def get_class_number(self, detections):
def get_class_number(self, detections):
class_num_arr = []
for each in detections.class_id:
for each in detections.class_id:
if each not in class_num_arr:
class_num_arr.append(each)
return len(class_num_arr)
return len(class_num_arr)
# 释放摄像头
def release_capture(self):
def release_capture(self):
LoadStreams.capture = 'release' # 这里是为了终止使用摄像头检测函数的线程改了yolo源码

@ -2,7 +2,6 @@
# @Author : pan
# @Description : 废弃方案1
# @Date : 2023年7月27日10:28:50
import supervision as sv
from ultralytics import YOLO
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadStreams
@ -28,16 +27,23 @@ import os
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib
#include
#include
# 定义两个空列表用于存储时间轴相关数据可能后续用于绘制图表的x轴比如对应不同时刻等情况以及数量相关数据可能对应图表的y轴比如检测到的目标数量
# 它们的具体用途要结合整体代码功能来明确,大概率与展示检测过程中随时间变化的某些指标相关。
x_axis_time_graph = []
y_axis_count_graph = []
# 创建一个渐变色
# 创建一个渐变色的颜色映射colormap通过指定颜色列表和颜色数量N来定义渐变效果
# 这里从黑色 (0, 0, 0) 渐变到指定的颜色 (233, 156, 105)共包含256个渐变级别可用于后续图像绘制中颜色的渐变处理具体要看使用场景
gradient = LinearSegmentedColormap.from_list(
'gradient', [(0, 0, 0), (233, 156, 105)], N=256)
# 用于记录视频相关的编号或者计数初始值设为0从变量名推测可能是针对多个视频处理时进行区分编号等操作后续会根据视频处理情况进行递增等变化。
video_id_count = 0
# 定义一个颜色值元组,可能用于在图像绘制等操作中作为颜色的一种表示方式,这里的值看起来像是通过位运算等方式构造的特定颜色编码(具体要结合使用场景确定)。
palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
# 定义一个包含多种颜色的列表每个元素是一个RGB颜色值的元组可作为颜色板用于给不同的目标、类别等分配不同的颜色方便在图像绘制中进行区分显示。
COLORS_10 = [(144, 238, 144), (178, 34, 34), (221, 160, 221), (0, 255, 0), (0, 128, 0), (210, 105, 30), (220, 20, 60),
(192, 192, 192), (255, 228, 196), (50, 205, 50), (139, 0, 139), (100, 149, 237), (138, 43, 226),
(238, 130, 238),
@ -57,12 +63,22 @@ COLORS_10 = [(144, 238, 144), (178, 34, 34), (221, 160, 221), (0, 255, 0), (0, 1
(102, 205, 170), (60, 179, 113), (46, 139, 87), (165, 42, 42), (178, 34, 34), (175, 238, 238),
(255, 248, 220),
(218, 165, 32), (255, 250, 240), (253, 245, 230), (244, 164, 96), (210, 105, 30)] # 颜色板
# 定义一个字典用于存储绘制轨迹相关的信息可能以目标的标识如ID作为键对应的值用于记录轨迹相关的数据具体要看后续对其的操作
dic_for_drawing_trails = {}
def compute_color_for_labels(label):
"""
设置不同类别的固定颜色
设置不同类别的固定颜色
函数功能
根据传入的标签值label为不同的类别分配固定的颜色通过一系列条件判断来确定具体类别对应的颜色值
如果不在特定的类别判断范围内则通过一定的计算规则生成对应的颜色返回的颜色以元组形式表示RGB格式用于后续图像绘制中区分不同类别
参数说明
label表示目标的类别标签通常是一个整数对应不同的目标类别通过这个值来决定返回哪种颜色用于表示该类别
返回值
返回一个表示颜色的元组RGB格式例如 (85, 45, 255)用于在图像绘制等操作中作为对应类别目标的颜色
"""
if label == 0: # person
color = (85, 45, 255)
@ -77,9 +93,28 @@ def compute_color_for_labels(label):
return tuple(color)
# 绘制轨迹
# 绘制轨迹的函数,用于在给定的图像上根据目标的边界框、类别、身份标识等信息绘制目标的运动轨迹。
def draw_trail(img, bbox, names, object_id, identities=None, offset=(0, 0)):
"""
函数功能
根据传入的图像目标边界框类别名称目标标识以及可选的身份标识等信息绘制目标在图像上的运动轨迹
同时处理一些轨迹相关的逻辑比如更新轨迹数据根据距离调整轨迹线条的粗细等最终返回绘制好轨迹的图像
参数说明
img要绘制轨迹的原始图像通常是一个OpenCV格式的图像数据numpy数组表示在这个图像上进行轨迹绘制操作
bbox目标的边界框信息通常是一个包含多个边界框坐标的列表每个边界框坐标可能以 [x1, y1, x2, y2] 的形式表示左上角和右下角坐标
用于确定目标在图像中的位置进而确定轨迹的起始和经过位置等
names目标的类别名称列表与边界框等信息对应用于通过类别获取相应的颜色等信息来绘制轨迹不过在当前函数中未直接体现其使用可能在其他相关函数中有配合使用情况
object_id目标的标识列表用于区分不同的目标与边界框等信息对应可能在处理轨迹数据确定颜色等方面起到索引或区分作用
identities可选目标的身份标识列表如果提供了该参数则可以根据身份标识来更精准地处理不同个体目标的轨迹若为None则采用默认的处理方式默认值为None
offset可选坐标偏移量以元组形式表示 (x, y)用于在绘制轨迹时对边界框坐标等进行偏移调整例如在图像拼接等场景下使用默认值为 (0, 0)
返回值
返回绘制好轨迹的图像即传入的原始图像经过轨迹绘制操作后的结果同样是OpenCV格式的图像数据numpy数组表示可用于后续显示保存等操作
"""
try:
# 遍历用于存储绘制轨迹相关信息的字典dic_for_drawing_trails的键检查每个键对应的身份标识是否不在当前传入的身份标识列表identities
# 如果不在,则从字典中移除该键值对,意味着删除那些不再出现在当前帧中的目标的轨迹相关数据,避免数据冗余和错误绘制。
for key in list(dic_for_drawing_trails):
if key not in identities:
dic_for_drawing_trails.pop(key)
@ -94,73 +129,69 @@ def draw_trail(img, bbox, names, object_id, identities=None, offset=(0, 0)):
y2 += offset[1]
# 获取锚框boundingbox中心点
# 计算边界框的中心点坐标通过取边界框左上角和右下角坐标的平均值在x和y方向分别计算得到用于确定轨迹绘制的经过点位置。
center = (int((x2 + x1) / 2), int((y2 + y2) / 2))
# 获取目标ID
# 根据是否传入了身份标识列表identities来获取目标的唯一标识ID如果传入了则取对应位置的身份标识作为ID
# 如果未传入则默认将ID设为0这个ID用于在轨迹相关数据存储如dic_for_drawing_trails字典中作为索引或区分不同目标的依据。
id = int(identities[i]) if identities is not None else 0
# 创建新的缓冲区
# 如果当前目标的ID不在轨迹相关信息字典dic_for_drawing_trails说明是新出现的目标
# 则为其创建一个新的双端队列deque作为缓冲区用于存储该目标的轨迹点设置队列最大长度为64即最多保留最近的64个轨迹点。
if id not in dic_for_drawing_trails:
dic_for_drawing_trails[id] = deque(maxlen=64)
try:
# 根据目标的类别标识object_id调用compute_color_for_labels函数获取对应的颜色用于绘制该目标的轨迹
# 如果在获取颜色过程中出现异常(比如类别标识不符合预期等情况),则跳过当前目标,继续处理下一个目标。
color = compute_color_for_labels(object_id[i])
except:
continue
# 将当前目标的中心点坐标添加到对应的轨迹缓冲区(双端队列)的左侧,这样新的轨迹点会排在前面,方便后续按照顺序绘制轨迹。
dic_for_drawing_trails[id].appendleft(center)
# 绘制轨迹
for i in range(1, len(dic_for_drawing_trails[id])):
if dic_for_drawing_trails[id][i - 1] is None or dic_for_drawing_trails[id][i] is None:
continue
# 轨迹动态粗细
# 根据当前轨迹点的索引位置i计算轨迹线条的粗细通过一定的数学公式sqrt函数结合当前索引等计算得到一个粗细值
# 并乘以1.5进行适当缩放,使得轨迹线条的粗细随着轨迹长度变化呈现动态效果,离当前点越近的轨迹部分越粗,越远越细。
thickness = int(np.sqrt(64 / float(i + i)) * 1.5)
# 使用OpenCV的cv2.line函数在图像上绘制从当前轨迹点的前一个点dic_for_drawing_trails[id][i - 1]到当前点dic_for_drawing_trails[id][i])的线段,
# 使用获取到的颜色color进行绘制并根据计算出的粗细thickness设置线条粗细从而在图像上绘制出目标的运动轨迹。
img = cv2.line(img, dic_for_drawing_trails[id][i - 1], dic_for_drawing_trails[id][i], color, thickness)
return img
class YoloPredictor(BasePredictor, QObject):
yolo2main_trail_img = Signal(np.ndarray) # 轨迹图像信号
yolo2main_box_img = Signal(np.ndarray) # 绘制了标签与锚框的图像的信号
yolo2main_status_msg = Signal(str) # 检测/暂停/停止/测试完成等信号
yolo2main_fps = Signal(str) # fps
yolo2main_labels = Signal(dict) # 检测到的目标结果(每个类别的数量)
yolo2main_progress = Signal(int) # 进度条
yolo2main_class_num = Signal(int) # 当前帧类别数
yolo2main_target_num = Signal(int) # 当前帧目标数
# 定义一个信号用于发送轨迹图像数据以numpy数组形式可以将检测到的目标轨迹相关图像信息传递给连接了该信号的槽函数
# 以便在Qt界面或者其他相关模块中进行展示等操作信号传递的数据类型为numpy.ndarray。
yolo2main_trail_img = Signal(np.ndarray)
# 定义一个信号用于发送绘制了标签与锚框的图像数据以numpy数组形式方便将带有检测结果标注的图像传递出去供显示或者进一步处理。
yolo2main_box_img = Signal(np.ndarray)
# 定义一个信号,用于发送检测相关的状态消息(以字符串形式),像“正在加载模型...”“检测中...”“检测终止”等状态提示,
# 可以让接收该信号的部分(比如界面显示模块)根据消息内容更新显示给用户相应的状态信息。
yolo2main_status_msg = Signal(str)
# 定义一个信号用于发送每秒帧率FPS相关的字符串信息可能用于在界面上展示检测过程中的实时帧率情况让用户了解检测速度。
yolo2main_fps = Signal(str)
# 定义一个信号,用于发送检测到的目标结果信息(以字典形式),字典中可能包含每个类别的数量等统计信息,
# 便于其他模块根据这些结果进行数据展示或者进一步分析等操作。
yolo2main_labels = Signal(dict)
# 定义一个信号,用于发送检测进度相关的信息(以整数形式),通常可以用于更新界面上的进度条展示,让用户直观看到检测的进度情况。
yolo2main_progress = Signal(int)
# 定义一个信号,用于发送当前帧中类别数量相关的信息(以整数形式),用于告知其他部分当前帧图像中检测到了多少种类别的目标。
yolo2main_class_num = Signal(int)
# 定义一个信号,用于发送当前帧中目标总数量相关的信息(以整数形式),方便展示每帧图像中具体检测到的目标个数情况。
yolo2main_target_num = Signal(int)
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
super(YoloPredictor, self).__init__()
QObject.__init__(self)
try:
self.args = get_cfg(cfg, overrides)
except:
pass
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = f'{self.args.mode}'
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
self.done_warmup = False
if self.args.show:
self.args.show = check_imshow(warn=True)
# GUI args
self.used_model_name = None # 使用过的检测模型名称
self.new_model_name = None # 新更改的模型
self.source = '' # 输入源str
self.stop_dtc = False # 终止bool
self.continue_dtc = True # 暂停bool
self.save_res = False # 保存MP4
self.save_txt = False # 保存txt
self.show_labels = True # 显示图像标签bool
self.iou_thres = 0.45 # iou
self.conf_thres = 0.25 # conf
self.speed_thres = 10 # delay, ms
self.progress_value = 0 # 进度条的值
self.lock_id = None
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
callbacks.add_integration_callbacks(self)
"""
类的构造函数用于初始化YoloPredictor类的实例对象进行各种属性和参数的设置包括继承
# 单目标跟踪
def single_object_tracking(self, detections, img_box, org_2, store_xyxy_for_id):
@ -185,7 +216,7 @@ class YoloPredictor(BasePredictor, QObject):
pass
# 点击开始检测按钮后的检测事件
@smart_inference_mode() # 一个修饰器用来开启检测模式如果torch>=1.9.0则执行torch.inference_mode()否则执行torch.no_grad()
@smart_inference_mode() # 一个修饰器用来开启检测模式如果torch>=1.9.0则执行torch.inference_mode()否则执行torch.no_grad()
def run(self):
# try:
LoadStreams.capture = None
@ -208,24 +239,24 @@ class YoloPredictor(BasePredictor, QObject):
fourcc = cv2.VideoWriter_fourcc(*'XVID')
# 设置线条样式
# 设置线条样式
box_annotator = sv.BoxAnnotator(
thickness=2,
text_thickness=1,
text_scale=0.5
)
if self.continue_dtc: # 暂停与继续的切换
if self.continue_dtc: # 暂停与继续的切换
if self.used_model_name != self.new_model_name:
self.setup_model(self.new_model_name)
self.used_model_name = self.new_model_name
self.yolo2main_status_msg.emit('正在加载模型...')
# 加载模型
# 加载模型
model = YOLO(self.new_model_name)
# 获取数据源 (不同的类型获取不同的数据源)
# 获取数据源 (不同的类型获取不同的数据源)
iter_model = iter(
model.track(source=self.source, show=False, stream=True, iou=self.iou_thres, conf=self.conf_thres))
@ -233,11 +264,11 @@ class YoloPredictor(BasePredictor, QObject):
global x_axis_time_graph, y_axis_count_graph
x_axis_time_graph = []
y_axis_count_graph = []
flag_save_video = 1 # 拿来保存视频的flag免得在后面的循环里面重复执行cv2.VideoWriter()函数
flag_save_video = 1 # 拿来保存视频的flag免得在后面的循环里面重复执行cv2.VideoWriter()函数
self.yolo2main_status_msg.emit('检测中...')
# 使用OpenCV读取视频——获取进度条
# 使用OpenCV读取视频——获取进度条
if 'mp4' in self.source or 'avi' in self.source or 'mkv' in self.source or 'flv' in self.source or 'mov' in self.source:
cap = cv2.VideoCapture(self.source)
total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
@ -272,19 +303,19 @@ class YoloPredictor(BasePredictor, QObject):
labels2 = "暂未识别到目标!"
else:
# 当没有识别目标时,这里会报错!
# 当没有识别目标时,这里会报错!
detections.tracker_id = result.boxes.id.cpu().numpy().astype(int)
for each in detections.class_id:
if each not in class_num_arr:
class_num_arr.append(each)
# id 、位置、目标总数
# id 、位置、目标总数
class_num = len(class_num_arr) # 类别数
id = detections.tracker_id # id
xyxy = detections.xyxy # 位置
sum_of_count = len(id) # 目标总数
# 轨迹绘制部分@@@@@@@@@@@@@@@@@@@@@@@@@@@@
# 轨迹绘制部分@@@@@@@@@@@@@@@@@@@@@@@@@@@@
identities = id
height, width, _ = img_box.shape
img_trail = np.zeros((height, width, 3), dtype='uint8')
@ -298,7 +329,7 @@ class YoloPredictor(BasePredictor, QObject):
# 绘制轨迹!
draw_trail(img_trail, xyxy, model.model.names, id, identities)
# 要画出来的信息
# 要画出来的信息
labels = [
f"ID: {tracker_id} " \
f"CLASS: {model.model.names[class_id]} " \
@ -306,23 +337,23 @@ class YoloPredictor(BasePredictor, QObject):
for _, _, confidence, class_id, tracker_id in detections
]
'''
如果Torch装的是cuda版本的话302行的代码需改成
如果Torch装的是cuda版本的话302行的代码需改成
labels = [
f"OBJECT-ID: {tracker_id} CLASS: {model.model.names[class_id]} CF: {confidence:0.2f}"
for _,confidence,class_id,tracker_id in detections
f"OBJECT-ID: {tracker_id} CLASS: {model.model.names[class_id]} CF: {confidence:0.2f}"
for _,confidence,class_id,tracker_id in detections
]
'''
# 存储labels里的信息
labels2 = [
f"目标ID: {tracker_id} 目标类别: {model.model.names[class_id]} 置信度: {confidence:0.2f}"
f"目标ID: {tracker_id} 目标类别: {model.model.names[class_id]} 置信度: {confidence:0.2f}"
for _, _, confidence, class_id, tracker_id in detections
]
'''
如果Torch装的是cuda版本的话314行的代码需改成
如果Torch装的是cuda版本的话314行的代码需改成
labels2 = [
f"OBJECT-ID: {tracker_id} CLASS: {model.model.names[class_id]} CF: {confidence:0.2f}"
for _,confidence,class_id,tracker_id in detections
f"OBJECT-ID: {tracker_id} CLASS: {model.model.names[class_id]} CF: {confidence:0.2f}"
for _,confidence,class_id,tracker_id in detections
]
'''
@ -340,19 +371,19 @@ class YoloPredictor(BasePredictor, QObject):
elif self.show_labels == False:
img_box = org_2
# 画出多少出多少 进来多少 的横线
# 画出多少出多少 进来多少 的横线
# temp_sum = sum_of_count
# line_counter.trigger(detections=detections)
# line_annotator.annotate(frame=img_box, line_counter=line_counter)
# sum_of_count = line_counter.in_count + line_counter.out_count
# line_counter.trigger(detections=detections)
# line_annotator.annotate(frame=img_box, line_counter=line_counter)
# sum_of_count = line_counter.in_count + line_counter.out_count
# 如果保存,则创建写入对象
# 如果保存,则创建写入对象
if self.save_res and flag_save_video:
out = cv2.VideoWriter(f'pred_result/video_result_{video_id_count}.avi', fourcc, 25,
(width, height), True) # 保存检测视频的路径
flag_save_video = 0
# 如果停止
# 如果停止
if self.stop_dtc:
if self.save_res:
out.release()
@ -363,20 +394,20 @@ class YoloPredictor(BasePredictor, QObject):
break
try:
# 添加 折线图数据
# 添加 折线图数据
now = datetime.datetime.now()
# new_time = now.strftime("%M:%S")
# new_time = now.strftime("%M:%S")
new_time = now.strftime("%Y-%m-%d %H:%M:%S")
if new_time not in x_axis_time_graph: # 防止同一秒写入
x_axis_time_graph.append(new_time)
y_axis_count_graph.append(sum_of_count)
# 抠锚框里的图 (单目标追踪)
# 抠锚框里的图 (单目标追踪)
if self.lock_id is not None:
self.lock_id = int(self.lock_id)
try:
# 单目标追踪
# 单目标追踪
result_cropped = self.single_object_tracking(detections, img_box, org_2,
store_xyxy_for_id)
# print(result_cropped)
@ -394,20 +425,20 @@ class YoloPredictor(BasePredictor, QObject):
cv2.destroyAllWindows()
except:
pass
# 预测写入本地
# 预测写入本地
if self.save_res:
out.write(img_box)
# ----------------------------------------------信号发送区
time.sleep(0.0) # 缓冲
# 轨迹图像(左边)
# 轨迹图像(左边)
self.yolo2main_trail_img.emit(img_trail)
# 标签图(右边)
# 标签图(右边)
self.yolo2main_box_img.emit(img_box)
# 总类别数量 、 总目标数
# 总类别数量 、 总目标数
self.yolo2main_class_num.emit(class_num)
self.yolo2main_target_num.emit(sum_of_count)
# 进度条
# 进度条
self.progress_value = int(count / total_frames * 1000)
self.yolo2main_progress.emit(self.progress_value)
# 计算FPS

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# @Author : pan
# @Description : 废弃方案2
# @Date : 2023年7月27日10:28:50
# @Description : 废弃方案2
# @ Date : 2023年7月27日10:28:50
import supervision as sv
from ultralytics import YOLO
@ -50,11 +50,11 @@ COLORS_10 =[(144,238,144),(178, 34, 34),(221,160,221),( 0,255, 0),( 0,128, 0
(245,255,250),(240,230,140),(245,222,179),( 0,139,139),(143,188,143),(255, 0, 0),(240,128,128),
(102,205,170),( 60,179,113),( 46,139, 87),(165, 42, 42),(178, 34, 34),(175,238,238),(255,248,220),
(218,165, 32),(255,250,240),(253,245,230),(244,164, 96),(210,105, 30)]
#颜色板
#颜色板
dic_for_drawing_trails = {}
def compute_color_for_labels(label):
"""
设置不同类别的固定颜色
设置不同类别的固定颜色
"""
if label == 0: #person
color = (85,45,255)
@ -68,7 +68,7 @@ def compute_color_for_labels(label):
color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
return tuple(color)
#绘制轨迹
# 绘制轨迹
def draw_trail(img, bbox, names,object_id, identities=None, offset=(0, 0)):
try:
for key in list(dic_for_drawing_trails):
@ -84,12 +84,12 @@ def draw_trail(img, bbox, names,object_id, identities=None, offset=(0, 0)):
y1 += offset[1]
y2 += offset[1]
#获取锚框boundingbox中心点
#获取锚框boundingbox中心点
center = (int((x2+x1)/ 2), int((y2+y2)/2))
#获取目标ID
#获取目标ID
id = int(identities[i]) if identities is not None else 0
#创建新的缓冲区
#创建新的缓冲区
if id not in dic_for_drawing_trails:
dic_for_drawing_trails[id] = deque(maxlen= 64)
try:
@ -98,27 +98,27 @@ def draw_trail(img, bbox, names,object_id, identities=None, offset=(0, 0)):
continue
dic_for_drawing_trails[id].appendleft(center)
#绘制轨迹
#绘制轨迹
for i in range(1, len(dic_for_drawing_trails[id])):
if dic_for_drawing_trails[id][i - 1] is None or dic_for_drawing_trails[id][i] is None:
continue
#轨迹动态粗细
#轨迹动态粗细
thickness = int(np.sqrt(64 / float(i + i)) * 1.5)
img = cv2.line(img, dic_for_drawing_trails[id][i - 1], dic_for_drawing_trails[id][i], color, thickness)
return img
class YoloPredictor(BasePredictor, QObject):
yolo2main_trail_img = Signal(np.ndarray) # 轨迹图像信号
yolo2main_box_img = Signal(np.ndarray) # 绘制了标签与锚框的图像的信号
yolo2main_status_msg = Signal(str) # 检测/暂停/停止/测试完成等信号
yolo2main_fps = Signal(str) # fps
yolo2main_labels = Signal(dict) # 检测到的目标结果(每个类别的数量)
yolo2main_progress = Signal(int) # 进度条
yolo2main_class_num = Signal(int) # 当前帧类别数
yolo2main_target_num = Signal(int) # 当前帧目标数
yolo2main_trail_img = Signal(np.ndarray) # 轨迹图像信号
yolo2main_box_img = Signal(np.ndarray) # 绘制了标签与锚框的图像的信号
yolo2main_status_msg = Signal(str) # 检测/暂停/停止/测试完成等信号
yolo2main_fps = Signal(str) # fps
yolo2main_labels = Signal(dict) # 检测到的目标结果(每个类别的数量)
yolo2main_progress = Signal(int) # 进度条
yolo2main_class_num = Signal(int) # 当前帧类别数
yolo2main_target_num = Signal(int) # 当前帧目标数

@ -1,28 +1,9 @@
# -*- coding: utf-8 -*-
# @Author : pan
# @Description : 废弃方案3
# @Date : 2023年7月27日10:28:50
import supervision as sv
from ultralytics import YOLO
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadStreams
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.utils import DEFAULT_CFG, SETTINGS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.utils.checks import check_imshow
from PySide6.QtCore import Signal, QObject
from pathlib import Path
import datetime
import numpy as np
import time
import cv2
import os
from classes.paint_trail import draw_trail
# - *- coding: utf-8 -*-
# @ Author : pan
# @Description : 废弃方案3
# @Date : 2023年7月27日10:28:50
import
x_axis_time_graph = []
y_axis_count_graph = []
@ -30,15 +11,15 @@ video_id_count = 0
class YoloPredictor(BasePredictor, QObject):
yolo2main_trail_img = Signal(np.ndarray) # 轨迹图像信号
yolo2main_box_img = Signal(np.ndarray) # 绘制了标签与锚框的图像的信号
yolo2main_status_msg = Signal(str) # 检测/暂停/停止/测试完成等信号
yolo2main_fps = Signal(str) # fps
yolo2main_trail_img = Signal(np.ndarray) # 轨迹图像信号
yolo2main_box_img = Signal(np.ndarray) # 绘制了标签与锚框的图像的信号
yolo2main_status_msg = Signal(str) # 检测/暂停/停止/测试完成等信号
yolo2main_fps = Signal(str) # fps
yolo2main_labels = Signal(dict) # 检测到的目标结果(每个类别的数量)
yolo2main_progress = Signal(int) # 进度条
yolo2main_class_num = Signal(int) # 当前帧类别数
yolo2main_target_num = Signal(int) # 当前帧目标数
yolo2main_labels = Signal(dict) # 检测到的目标结果(每个类别的数量)
yolo2main_progress = Signal(int) # 进度条
yolo2main_class_num = Signal(int) # 当前帧类别数
yolo2main_target_num = Signal(int) # 当前帧目标数
@ -58,28 +39,28 @@ class YoloPredictor(BasePredictor, QObject):
self.args.show = check_imshow(warn=True)
# GUI args
self.used_model_name = None # 使用过的检测模型名称
self.new_model_name = None # 新更改的模型
self.used_model_name = None # 使用过的检测模型名称
self.new_model_name = None # 新更改的模型
self.source = '' # 输入源str
self.progress_value = 0 # 进度条的值
self.source = '' # 输入源str
self.progress_value = 0 # 进度条的值
self.stop_dtc = False # 终止bool
self.continue_dtc = True # 暂停bool
self.stop_dtc = False # 终止bool
self.continue_dtc = True # 暂停bool
# config
self.iou_thres = 0.45 # iou
self.conf_thres = 0.25 # conf
self.speed_thres = 0.01 # delay, ms (缓冲)
self.iou_thres = 0.45 # iou
self.conf_thres = 0.25 # conf
self.speed_thres = 0.01 # delay, ms (缓冲)
self.save_res = False # 保存MP4
self.save_txt = False # 保存txt
self.show_labels = True # 显示图像标签bool
self.save_res = False # 保存MP4
self.save_txt = False # 保存txt
self.show_labels = True # 显示图像标签bool
# 运行时候的参数放这里
self.start_time = None # 拿来算FPS的计数变量
# 行时候的参数放这里
self.start_time = None # 来算FPS的计数变量
self.count = None
self.sum_of_count = None
self.class_num = None
@ -93,8 +74,8 @@ class YoloPredictor(BasePredictor, QObject):
text_scale=0.5
)
# 点击开始检测按钮后的检测事件
@smart_inference_mode() # 一个修饰器用来开启检测模式如果torch>=1.9.0则执行torch.inference_mode()否则执行torch.no_grad()
# 点击开始检测按钮后的检测事件
@smart_inference_mode() # 一个修饰器用来开启检测模式如果torch>=1.9.0则执行torch.inference_mode()否则执行torch.no_grad()
def run(self):
# try:
LoadStreams.capture = ''
@ -103,7 +84,7 @@ class YoloPredictor(BasePredictor, QObject):
self.yolo2main_status_msg.emit('正在加载模型...')
# 检查保存路径
# 检查保存路径
if self.save_txt:
if not os.path.exists('labels'):
os.mkdir('labels')
@ -111,10 +92,10 @@ class YoloPredictor(BasePredictor, QObject):
if not os.path.exists('pred_result'):
os.mkdir('pred_result')
self.count = 0 # 拿来参与算FPS的计数变量
self.count = 0 # 拿来参与算FPS的计数变量
self.start_time = time.time() # 拿来算FPS的计数变量
#LXY
if self.continue_dtc: # 暂停与继续的切换
if self.used_model_name != self.new_model_name:
@ -129,7 +110,7 @@ class YoloPredictor(BasePredictor, QObject):
iter_model = iter(
model.track(source=self.source, show=False, stream=True, iou=self.iou_thres, conf=self.conf_thres))
#LXY
# 折线图数据初始化
global x_axis_time_graph, y_axis_count_graph
x_axis_time_graph = []
@ -142,7 +123,7 @@ class YoloPredictor(BasePredictor, QObject):
cap = cv2.VideoCapture(self.source)
self.total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
cap.release()
# LXY
# 如果保存,则创建写入对象
img_res, result, height, width = self.recognize_res(iter_model)
fourcc = cv2.VideoWriter_fourcc(*'XVID')
@ -150,7 +131,7 @@ class YoloPredictor(BasePredictor, QObject):
if self.save_res:
out = cv2.VideoWriter(f'pred_result/video_result_{video_id_count}.avi', fourcc, 25,
(width, height), True) # 保存检测视频的路径
# LXY
# 开始死循环检测
while True:
try:
@ -172,7 +153,7 @@ class YoloPredictor(BasePredictor, QObject):
LoadStreams.capture = 'release' # 这里是为了终止使用摄像头检测函数的线程改了yolo源码
break
# LXY
# 检测截止(本地文件检测)
except StopIteration:
if self.save_res:
@ -190,12 +171,13 @@ class YoloPredictor(BasePredictor, QObject):
except:
pass
# LXY
# 进行识别——并返回所有结果
def res_address(self, img_res, result, height, width, model, out):
# 复制一份
img_box = np.copy(img_res) # 右边的图
img_trail = np.zeros((height, width, 3), dtype='uint8') # 左边的轨迹
# LXY
# 如果没有识别的:
if result.boxes.id is None:
# 目标都是0
@ -217,7 +199,7 @@ class YoloPredictor(BasePredictor, QObject):
id = detections.tracker_id # id
xyxy = detections.xyxy # 位置
self.sum_of_count = len(id) # 目标总数
# LXY
# 轨迹绘制部分 @@@@@@@@@@@@@@@@@@@@@@@@@@@@
identities = id
grid_color = (255, 255, 255)
@ -254,7 +236,7 @@ class YoloPredictor(BasePredictor, QObject):
f"OBJECT-ID: {tracker_id} CLASS: {model.model.names[class_id]} CF: {confidence:0.2f}"
for _,confidence,class_id,tracker_id in detections
]
'''
'''#LXY
# 如果显示标签 (要有才可以画呀!)---否则就是原图
if (self.show_labels == True) and (self.class_num != 0):
img_box = self.box_annotator.annotate(scene=img_res, detections=detections, labels=labels_draw)
@ -271,23 +253,23 @@ class YoloPredictor(BasePredictor, QObject):
if self.save_res:
out.write(img_box)
# LXY
# 添加 折线图数据
now = datetime.datetime.now()
new_time = now.strftime("%Y-%m-%d %H:%M:%S")
if new_time not in x_axis_time_graph: # 防止同一秒写入
x_axis_time_graph.append(new_time)
y_axis_count_graph.append(self.sum_of_count)
# LXY
# 抠锚框里的图 (单目标追踪)
if self.lock_id is not None:
self.lock_id = int(self.lock_id)
self.open_target_tracking(detections=detections, img_res=img_res)
# LXY
# 传递信号给主窗口
self.emit_res(img_trail, img_box)
# 识别结果处理
@ -299,6 +281,7 @@ class YoloPredictor(BasePredictor, QObject):
return img_res, result, height, width
# LXY
# 单目标检测窗口开启
def open_target_tracking(self, detections, img_res):
try:
@ -315,6 +298,7 @@ class YoloPredictor(BasePredictor, QObject):
cv2.destroyAllWindows()
pass
# LXY
# 单目标跟踪
def single_object_tracking(self, detections, img_box):
store_xyxy_for_id = {}
@ -332,11 +316,12 @@ class YoloPredictor(BasePredictor, QObject):
result_cropped = result_mask[y1:y2, x1:x2]
result_cropped = cv2.resize(result_cropped, (256, 256))
return result_cropped
# LXY
except:
cv2.destroyAllWindows()
pass
# LXY
# 信号发送区
def emit_res(self, img_trail, img_box):

@ -35,6 +35,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
# 主窗口向yolo实例发送执行信号
main2yolo_begin_sgl = Signal()
# 主函数
def __init__(self, parent=None):
super(MainWindow, self).__init__()

@ -1,224 +1,214 @@
from PySide6.QtCore import *
from PySide6.QtGui import *
from PySide6.QtWidgets import *
# After the interface hides the window, use this file to achieve borderless stretching of the window
# 自定义的窗口拉伸控制部件类,用于实现无边框窗口的可拉伸功能,通过在窗口的不同边缘创建可操作的区域(如角落和边框部分),
# 让用户能够通过鼠标拖动来改变窗口的大小,并且可以根据传入的参数设置是否隐藏这些操作区域的默认背景颜色。
class CustomGrip(QWidget):
def __init__(self, parent, position, disable_color = False):
"""
类的构造函数用于初始化CustomGrip类的实例对象设置与父窗口的关联以及根据指定位置创建相应的拉伸控制部件和对应的操作逻辑
参数说明
parent父窗口对象即要为其添加拉伸功能的窗口CustomGrip实例会作为子部件添加到这个父窗口上与之关联并控制其大小变化
position表示拉伸控制部件在窗口上的位置是Qt中的枚举类型如Qt.TopEdgeQt.BottomEdge等用于确定在窗口的哪个边缘创建拉伸相关的交互区域
disable_color可选布尔类型参数用于决定是否隐藏拉伸控制部件的默认背景颜色默认值为False表示不隐藏显示默认的背景样式
"""
# SETUP UI
# 调用QWidget的构造函数来初始化该部件的基本属性这是继承自QWidget类所必需的操作确保部件能正常工作。
QWidget.__init__(self)
# 保存传入的父窗口对象,以便后续在操作中能获取和修改父窗口的属性(如大小、几何形状等)。
self.parent = parent
# 将当前CustomGrip部件设置为父窗口的子部件建立父子部件的层级关系使得其能在父窗口的布局中正确显示和响应事件。
self.setParent(parent)
# 创建一个Widgets类的实例Widgets类可能用于创建和管理窗口拉伸相关的各种子部件如各个边缘的操作区域等后续会通过它来构建具体的UI元素。
self.wi = Widgets()
# SHOW TOP GRIP
# 判断传入的位置参数是否为窗口顶部边缘Qt.TopEdge如果是则执行以下代码创建顶部边缘的拉伸控制相关的UI元素和操作逻辑。
if position == Qt.TopEdge:
# 调用Widgets实例的top方法传入当前CustomGrip部件自身self用于创建顶部边缘的相关布局和子部件比如包含左上角、顶部中间、右上角等可操作区域的框架。
self.wi.top(self)
# 设置当前CustomGrip部件的几何形状位置和大小使其在父窗口的顶部宽度与父窗口一致高度为10像素确定其在父窗口中的显示位置和尺寸。
self.setGeometry(0, 0, self.parent.width(), 10)
# 设置当前CustomGrip部件的最大高度为10像素限制用户通过拉伸等操作改变其高度时的最大值保持其在顶部边缘的固定高度样式。
self.setMaximumHeight(10)
# GRIPS
# 创建一个位于左上角的尺寸调整手柄QSizeGrip它关联到Widgets实例中创建的左上角框架self.wi.top_left
# 用户可以通过拖动这个手柄来调整窗口大小QSizeGrip会自动处理相关的鼠标交互和窗口大小改变逻辑。
top_left = QSizeGrip(self.wi.top_left)
# 同理,创建一个位于右上角的尺寸调整手柄,关联到右上角框架,用于实现右上角的窗口大小调整功能。
top_right = QSizeGrip(self.wi.top_right)
# RESIZE TOP
# 定义一个内部函数resize_top用于处理顶部边缘的鼠标拖动调整窗口大小的操作逻辑它接收鼠标事件对象event作为参数。
def resize_top(event):
"""
函数功能
当用户在顶部边缘区域拖动鼠标时根据鼠标移动的位移量来调整父窗口的高度实现窗口顶部拉伸的效果同时确保窗口高度不小于其最小高度限制
参数说明
event鼠标事件对象包含了鼠标的当前位置移动状态等信息通过它可以获取鼠标在窗口中的相对移动距离等关键数据来计算窗口大小的变化量
具体操作
首先获取鼠标移动的相对位移量delta然后根据父窗口当前高度和鼠标在y方向的位移量计算新的高度
通过设置父窗口的几何形状geometry来更新其大小最后接受该鼠标事件表示已处理完成该事件避免事件传播到其他地方产生意外行为
"""
# 获取鼠标相对于上一次鼠标位置的位移量相对坐标变化这里的pos方法返回的是在当前部件坐标系下的坐标变化值。
delta = event.pos()
# 根据父窗口当前高度和鼠标在y方向的位移量计算新的高度使用max函数确保新高度不小于父窗口的最小高度通过self.parent.minimumHeight()获取),
# 计算方式是用父窗口当前高度减去鼠标在y方向上的位移量因为是顶部拉伸向上拖动鼠标时y坐标减小高度应相应减小
height = max(self.parent.minimumHeight(), self.parent.height() - delta.y())
# 获取父窗口当前的几何形状信息(包括位置、大小等),以便后续基于这个基础进行修改。
geo = self.parent.geometry()
# 设置父窗口的顶部位置,通过用底部位置减去新计算的高度来确定,实现根据鼠标拖动改变窗口高度且保持底部位置相对稳定的效果(视觉上是从顶部拉伸窗口)。
geo.setTop(geo.bottom() - height)
# 将更新后的几何形状设置回父窗口,从而改变父窗口的大小和位置,实现窗口拉伸效果。
self.parent.setGeometry(geo)
# 接受该鼠标事件表示该事件已被成功处理告知Qt系统不需要再将此事件传递给其他潜在的事件处理函数避免重复处理或产生意外行为。
event.accept()
# 将定义好的resize_top函数赋值给Widgets实例中创建的顶部框架self.wi.top的mouseMoveEvent属性
# 这样当鼠标在顶部框架区域移动时即用户拖动鼠标进行拉伸操作时就会触发resize_top函数来处理窗口大小调整逻辑。
self.wi.top.mouseMoveEvent = resize_top
# ENABLE COLOR
# 根据disable_color参数的值判断是否隐藏顶部边缘相关部件的背景颜色若为True则设置相应部件的样式表styleSheet为透明背景
# 实现隐藏默认背景颜色的效果,使这些部件在视觉上更融入窗口或者符合特定的界面设计需求。
if disable_color:
self.wi.top_left.setStyleSheet("background: transparent")
self.wi.top_right.setStyleSheet("background: transparent")
self.wi.top.setStyleSheet("background: transparent")
# SHOW BOTTOM GRIP
# 判断传入的位置参数是否为窗口底部边缘Qt.BottomEdge若是则执行以下代码创建底部边缘的拉伸控制相关UI元素和操作逻辑。
elif position == Qt.BottomEdge:
# 调用Widgets实例的bottom方法传入当前CustomGrip部件自身用于创建底部边缘的相关布局和子部件包含左下角、底部中间、右下角等可操作区域的框架。
self.wi.bottom(self)
# 设置当前CustomGrip部件的几何形状使其位于父窗口底部宽度与父窗口一致高度为10像素确定其在父窗口中的显示位置和尺寸。
self.setGeometry(0, self.parent.height() - 10, self.parent.width(), 10)
# 设置当前CustomGrip部件的最大高度为10像素限制其在拉伸操作时高度的最大值保持底部边缘的固定高度样式。
self.setMaximumHeight(10)
# GRIPS
# 创建一个位于左下角的尺寸调整手柄关联到Widgets实例中创建的左下角框架self.wi.bottom_left用于实现左下角的窗口大小调整功能。
self.bottom_left = QSizeGrip(self.wi.bottom_left)
# 创建一个位于右下角的尺寸调整手柄,关联到右下角框架,用于实现右下角的窗口大小调整功能。
self.bottom_right = QSizeGrip(self.wi.bottom_right)
# RESIZE BOTTOM
# 定义一个内部函数resize_bottom用于处理底部边缘的鼠标拖动调整窗口大小的操作逻辑接收鼠标事件对象作为参数。
def resize_bottom(event):
"""
函数功能
当用户在底部边缘区域拖动鼠标时根据鼠标移动的位移量来调整父窗口的高度实现窗口底部拉伸的效果同时确保窗口高度不小于其最小高度限制
参数说明
event鼠标事件对象包含鼠标位置移动状态等信息通过它获取鼠标在窗口中的相对移动距离来计算窗口大小变化量
具体操作
首先获取鼠标移动的位移量然后根据父窗口当前高度和鼠标在y方向的位移量计算新的高度使用max函数确保新高度不小于最小高度
最后通过调整父窗口的大小改变高度来实现窗口底部拉伸的效果并接受该鼠标事件表示已处理完成
"""
# 获取鼠标相对于上一次鼠标位置的位移量,用于后续计算窗口高度的变化量。
delta = event.pos()
# 根据父窗口当前高度和鼠标在y方向的位移量计算新的高度这里是在父窗口当前高度基础上加上鼠标在y方向的位移量因为是底部拉伸向下拖动鼠标时y坐标增加高度应相应增加
# 同时通过max函数确保新高度不小于父窗口的最小高度。
height = max(self.parent.minimumHeight(), self.parent.height() + delta.y())
# 直接调用父窗口的resize方法传入父窗口的宽度保持不变和新计算的高度来改变父窗口的大小实现窗口底部拉伸效果。
self.parent.resize(self.parent.width(), height)
# 接受该鼠标事件,表示该事件已被正确处理,避免事件的重复处理或异常传播。
event.accept()
# 将resize_bottom函数赋值给Widgets实例中创建的底部框架self.wi.bottom的mouseMoveEvent属性
# 使得鼠标在底部框架区域移动时能触发该函数来处理窗口大小调整逻辑。
self.wi.bottom.mouseMoveEvent = resize_bottom
# ENABLE COLOR
# 根据disable_color参数判断是否隐藏底部边缘相关部件的背景颜色若为True则设置相应部件的样式表为透明背景隐藏默认颜色。
if disable_color:
self.wi.bottom_left.setStyleSheet("background: transparent")
self.wi.bottom_right.setStyleSheet("background: transparent")
self.wi.bottom.setStyleSheet("background: transparent")
# SHOW LEFT GRIP
# 判断传入的位置参数是否为窗口左侧边缘Qt.LeftEdge若是则执行以下代码创建左侧边缘的拉伸控制相关UI元素和操作逻辑。
elif position == Qt.LeftEdge:
# 调用Widgets实例的left方法传入当前CustomGrip部件自身用于创建左侧边缘的相关布局和子部件即创建一个可用于左侧拉伸操作的框架。
self.wi.left(self)
# 设置当前CustomGrip部件的几何形状使其位于父窗口左侧x坐标为0y坐标为10像素可能为了避开顶部的一些操作区域等宽度为10像素高度与父窗口一致
# 确定其在父窗口中的显示位置和尺寸。
self.setGeometry(0, 10, 10, self.parent.height())
# 设置当前CustomGrip部件的最大宽度为10像素限制其在拉伸操作时宽度的最大值保持左侧边缘的固定宽度样式。
self.setMaximumWidth(10)
# RESIZE LEFT
# 定义一个内部函数resize_left用于处理左侧边缘的鼠标拖动调整窗口大小的操作逻辑接收鼠标事件对象作为参数。
def resize_left(event):
"""
函数功能
当用户在左侧边缘区域拖动鼠标时根据鼠标移动的位移量来调整父窗口的宽度实现窗口左侧拉伸的效果同时确保窗口宽度不小于其最小宽度限制
参数说明
event鼠标事件对象包含鼠标位置移动状态等信息通过它获取鼠标在窗口中的相对移动距离来计算窗口宽度变化量
具体操作
首先获取鼠标移动的位移量然后根据父窗口当前宽度和鼠标在x方向的位移量计算新的宽度使用max函数确保新宽度不小于最小宽度
接着通过设置父窗口的几何形状来更新其宽度实现窗口左侧拉伸效果并接受该鼠标事件表示已处理完成
"""
# 获取鼠标相对于上一次鼠标位置的位移量,用于后续计算窗口宽度的变化量。
delta = event.pos()
# 根据父窗口当前宽度和鼠标在x方向的位移量计算新的宽度这里是用父窗口当前宽度减去鼠标在x方向的位移量因为是左侧拉伸向左拖动鼠标时x坐标减小宽度应相应减小
# 同时通过max函数确保新宽度不小于父窗口的最小宽度。
width = max(self.parent.minimumWidth(), self.parent.width() - delta.x())
# 获取父窗口当前的几何形状信息,以便基于此进行修改。
geo = self.parent.geometry()
# 设置父窗口的左侧位置,通过用右侧位置减去新计算的宽度来确定,实现根据鼠标拖动改变窗口宽度且保持右侧位置相对稳定的效果(视觉上是从左侧拉伸窗口)。
geo.setLeft(geo.right() - width)
# 将更新后的几何形状设置回父窗口,从而改变父窗口的大小和位置,实现窗口左侧拉伸效果。
self.parent.setGeometry(geo)
# 接受该鼠标事件,表示该事件已被成功处理,避免事件传播到其他地方产生意外行为。
event.accept()
# 将resize_left函数赋值给Widgets实例中创建的左侧框架self.wi.leftgrip的mouseMoveEvent属性
# 使得鼠标在左侧框架区域移动时能触发该函数来处理窗口大小调整逻辑。
self.wi.leftgrip.mouseMoveEvent = resize_left
# ENABLE COLOR
# 根据disable_color参数判断是否隐藏左侧边缘部件的背景颜色若为True则设置其样式表为透明背景隐藏默认颜色。
if disable_color:
self.wi.leftgrip.setStyleSheet("background: transparent")
# RESIZE RIGHT
# 判断传入的位置参数是否为窗口右侧边缘Qt.RightEdge若是则执行以下代码创建右侧边缘的拉伸控制相关UI元素和操作逻辑。
elif position == Qt.RightEdge:
# 调用Widgets实例的right方法传入当前CustomGrip部件自身用于创建右侧边缘的相关布局和子部件即创建一个可用于右侧拉伸操作的框架。
self.wi.right(self)
# 设置当前CustomGrip部件的几何形状使其位于父窗口右侧x坐标为父窗口宽度减去10像素因为右侧边缘宽度为10像素y坐标为10像素宽度为10像素高度与父窗口一致
# 确定其在父窗口中的显示位置和尺寸。
self.setGeometry(self.parent.width() - 10, 10, 10, self.parent.height())
# 设置当前CustomGrip部件的最大宽度为10像素限制其在拉伸操作时宽度的最大值保持右侧边缘的固定宽度样式。
self.setMaximumWidth(10)
def resize_right(event):
"""
函数功能
当用户在右侧边缘区域拖动鼠标时根据鼠标移动的位移量来调整父窗口的宽度实现窗口右侧拉伸的效果同时确保窗口宽度不小于其最小宽度限制
参数说明
event鼠标事件对象包含鼠标位置移动状态等信息通过它获取鼠标在窗口中的相对移动距离来计算窗口宽度变化量
具体操作
首先获取鼠标移动的位移量然后根据父窗口当前宽度和鼠标在x方向的位移量计算新的宽度使用max函数确保新宽度不小于最小宽度
最后通过调整父窗口的大小改变宽度来实现窗口右侧拉伸效果并接受该鼠标事件表示已处理完成
"""
# 获取鼠标相对于上一次鼠标位置的位移量,用于后续计算窗口宽度的变化量。
delta = event.pos()
# 根据父窗口当前宽度和鼠标在x方向的位移量计算新的宽度这里是在父窗口当前宽度基础上加上鼠标在x方向的位移量因为是右侧拉伸向右拖动鼠标时x坐标增加宽度应相应增加
# 同时通过max函数确保新宽度不小于父窗口的最小宽度。
width = max(self.parent.minimumWidth(), self.parent.width() + delta.x())
# 直接调用父窗口的resize方法传入新计算的宽度和父窗口的高度保持不变来改变父窗口的大小实现窗口右侧拉伸效果。
self.parent.resize(width, self.parent.height())
# 接受该鼠标事件,表示该事件已被正确处理,避免事件的重复处理或异常传播。
event.accept()
self.wi.rightgrip.mouseMoveEvent = resize_right
# ENABLE COLOR
if disable_color:
self.wi.rightgrip.setStyleSheet("background: transparent")
def mouseReleaseEvent(self, event):
self.mousePos = None
def resizeEvent(self, event):
if hasattr(self.wi, 'container_top'):
self.wi.container_top.setGeometry(0, 0, self.width(), 10)
elif hasattr(self.wi, 'container_bottom'):
self.wi.container_bottom.setGeometry(0, 0, self.width(), 10)
elif hasattr(self.wi, 'leftgrip'):
self.wi.leftgrip.setGeometry(0, 0, 10, self.height() - 20)
elif hasattr(self.wi, 'rightgrip'):
self.wi.rightgrip.setGeometry(0, 0, 10, self.height() - 20)
class Widgets(object):
def top(self, Form):
if not Form.objectName():
Form.setObjectName(u"Form")
self.container_top = QFrame(Form)
self.container_top.setObjectName(u"container_top")
self.container_top.setGeometry(QRect(0, 0, 500, 10))
self.container_top.setMinimumSize(QSize(0, 10))
self.container_top.setMaximumSize(QSize(16777215, 10))
self.container_top.setFrameShape(QFrame.NoFrame)
self.container_top.setFrameShadow(QFrame.Raised)
self.top_layout = QHBoxLayout(self.container_top)
self.top_layout.setSpacing(0)
self.top_layout.setObjectName(u"top_layout")
self.top_layout.setContentsMargins(0, 0, 0, 0)
self.top_left = QFrame(self.container_top)
self.top_left.setObjectName(u"top_left")
self.top_left.setMinimumSize(QSize(10, 10))
self.top_left.setMaximumSize(QSize(10, 10))
self.top_left.setCursor(QCursor(Qt.SizeFDiagCursor))
self.top_left.setStyleSheet(u"background-color: rgb(33, 37, 43);")
self.top_left.setFrameShape(QFrame.NoFrame)
self.top_left.setFrameShadow(QFrame.Raised)
self.top_layout.addWidget(self.top_left)
self.top = QFrame(self.container_top)
self.top.setObjectName(u"top")
self.top.setCursor(QCursor(Qt.SizeVerCursor))
self.top.setStyleSheet(u"background-color: rgb(85, 255, 255);")
self.top.setFrameShape(QFrame.NoFrame)
self.top.setFrameShadow(QFrame.Raised)
self.top_layout.addWidget(self.top)
self.top_right = QFrame(self.container_top)
self.top_right.setObjectName(u"top_right")
self.top_right.setMinimumSize(QSize(10, 10))
self.top_right.setMaximumSize(QSize(10, 10))
self.top_right.setCursor(QCursor(Qt.SizeBDiagCursor))
self.top_right.setStyleSheet(u"background-color: rgb(33, 37, 43);")
self.top_right.setFrameShape(QFrame.NoFrame)
self.top_right.setFrameShadow(QFrame.Raised)
self.top_layout.addWidget(self.top_right)
def bottom(self, Form):
if not Form.objectName():
Form.setObjectName(u"Form")
self.container_bottom = QFrame(Form)
self.container_bottom.setObjectName(u"container_bottom")
self.container_bottom.setGeometry(QRect(0, 0, 500, 10))
self.container_bottom.setMinimumSize(QSize(0, 10))
self.container_bottom.setMaximumSize(QSize(16777215, 10))
self.container_bottom.setFrameShape(QFrame.NoFrame)
self.container_bottom.setFrameShadow(QFrame.Raised)
self.bottom_layout = QHBoxLayout(self.container_bottom)
self.bottom_layout.setSpacing(0)
self.bottom_layout.setObjectName(u"bottom_layout")
self.bottom_layout.setContentsMargins(0, 0, 0, 0)
self.bottom_left = QFrame(self.container_bottom)
self.bottom_left.setObjectName(u"bottom_left")
self.bottom_left.setMinimumSize(QSize(10, 10))
self.bottom_left.setMaximumSize(QSize(10, 10))
self.bottom_left.setCursor(QCursor(Qt.SizeBDiagCursor))
self.bottom_left.setStyleSheet(u"background-color: rgb(33, 37, 43);")
self.bottom_left.setFrameShape(QFrame.NoFrame)
self.bottom_left.setFrameShadow(QFrame.Raised)
self.bottom_layout.addWidget(self.bottom_left)
self.bottom = QFrame(self.container_bottom)
self.bottom.setObjectName(u"bottom")
self.bottom.setCursor(QCursor(Qt.SizeVerCursor))
self.bottom.setStyleSheet(u"background-color: rgb(85, 170, 0);")
self.bottom.setFrameShape(QFrame.NoFrame)
self.bottom.setFrameShadow(QFrame.Raised)
self.bottom_layout.addWidget(self.bottom)
self.bottom_right = QFrame(self.container_bottom)
self.bottom_right.setObjectName(u"bottom_right")
self.bottom_right.setMinimumSize(QSize(10, 10))
self.bottom_right.setMaximumSize(QSize(10, 10))
self.bottom_right.setCursor(QCursor(Qt.SizeFDiagCursor))
self.bottom_right.setStyleSheet(u"background-color: rgb(33, 37, 43);")
self.bottom_right.setFrameShape(QFrame.NoFrame)
self.bottom_right.setFrameShadow(QFrame.Raised)
self.bottom_layout.addWidget(self.bottom_right)
def left(self, Form):
if not Form.objectName():
Form.setObjectName(u"Form")
self.leftgrip = QFrame(Form)
self.leftgrip.setObjectName(u"left")
self.leftgrip.setGeometry(QRect(0, 10, 10, 480))
self.leftgrip.setMinimumSize(QSize(10, 0))
self.leftgrip.setCursor(QCursor(Qt.SizeHorCursor))
self.leftgrip.setStyleSheet(u"background-color: rgb(255, 121, 198);")
self.leftgrip.setFrameShape(QFrame.NoFrame)
self.leftgrip.setFrameShadow(QFrame.Raised)
def right(self, Form):
if not Form.objectName():
Form.setObjectName(u"Form")
Form.resize(500, 500)
self.rightgrip = QFrame(Form)
self.rightgrip.setObjectName(u"right")
self.rightgrip.setGeometry(QRect(0, 0, 10, 500))
self.rightgrip.setMinimumSize(QSize(10, 0))
self.rightgrip.setCursor(QCursor(Qt.SizeHorCursor))
self.rightgrip.setStyleSheet(u"background-color: rgb(255, 0, 127);")
self.rightgrip.setFrameShape(QFrame.NoFrame)
self.rightgrip.setFrameShadow(QFrame.Raised)
# 将resize_right函数赋值给Widgets实例中创建的右侧框架self.wi.rightgrip的mouseMoveEvent属性
# 使得鼠标在右侧框架区域移动时能触发该函数来处理窗口大小调整逻辑。

@ -17,7 +17,7 @@ from PySide6.QtGui import (QBrush, QColor, QConicalGradient, QCursor,
QPalette, QPixmap, QRadialGradient, QTransform)
from PySide6.QtWidgets import (QApplication, QHBoxLayout, QLabel, QLineEdit,
QPushButton, QSizePolicy, QWidget)
#LXY
class id_form(object):
def setupUi(self, Form):
if not Form.objectName():
@ -153,8 +153,9 @@ class id_form(object):
self.retranslateUi(Form)
QMetaObject.connectSlotsByName(Form)
# setupUi
# setupUi
# LXY
def retranslateUi(self, Form):
Form.setWindowTitle(QCoreApplication.translate("Form", u"单目标追踪", None))
self.label.setText(QCoreApplication.translate("Form", u"请输入车辆ID:", None))

@ -1,21 +1,33 @@
# -*- coding: utf-8 -*-
# @Author : pan
# 导入系统模块,用于处理命令行参数以及退出程序等相关操作
import sys
# 从PySide6.QtWidgets模块中导入QApplication和QWidget类QApplication用于管理整个应用程序的生命周期和相关设置QWidget是所有可视化窗口部件的基类
from PySide6.QtWidgets import QApplication, QWidget
# 从ui.dialog.id_dialog模块中导入id_form从名称推测可能是用于定义窗口界面相关的表单或者布局等内容具体需看其实际代码
from ui.dialog.id_dialog import id_form
# LXY (这里看起来像是一个随意的标记或者注释占位符,暂时不清楚其具体用途,可以先忽略)
# 定义id_Window类它继承自QWidget和id_form意味着这个类的实例将既是一个Qt的窗口部件又具备id_form所定义的相关界面特性或功能
class id_Window(QWidget, id_form):
def __init__(self):
super(id_Window, self).__init__()
# 调用从id_form继承来的setupUi方法通常用于初始化和设置窗口的用户界面比如添加各种控件、布局等并将当前窗口实例自身作为参数传入以完成界面的初始化工作
self.setupUi(self)
# 将按钮从名称self.pushButton推测它是窗口中的一个按钮部件的点击信号与closeWindow方法进行连接这样当按钮被点击时就会执行closeWindow方法来关闭窗口
self.pushButton.clicked.connect(self.closeWindow)
def closeWindow(self):
# 定义关闭窗口的方法调用窗口自身的close方法来关闭当前窗口实例实现点击按钮关闭窗口的功能
self.close()
if __name__ == '__main__':
# 创建QApplication类的实例它是整个Qt应用程序的基础需要传入命令行参数sys.argv用于接收外部传入的参数以及初始化应用程序的相关设置
app = QApplication(sys.argv)
# 创建id_Window类的实例也就是创建了一个具有特定界面和关闭功能的窗口对象
window = id_Window()
# 显示创建好的窗口,使其在屏幕上可见
window.show()
# 启动应用程序的事件循环开始处理用户交互、窗口绘制等各种事件直到应用程序退出最后通过sys.exit来确保应用程序能正确退出并返回应用程序的退出状态码
sys.exit(app.exec())

@ -12,8 +12,8 @@ from typing import Dict, List, Union
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, SETTINGS, SETTINGS_YAML,
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
yaml_print)
#ffdkds
# Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
TASKS = 'detect', 'segment', 'classify', 'pose'
TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet100', 'pose': 'coco8-pose.yaml'}
@ -27,7 +27,7 @@ TASK2METRIC = {
'segment': 'metrics/mAP50-95(M)',
'classify': 'metrics/accuracy_top1',
'pose': 'metrics/mAP50-95(P)'}
#ffdkds
CLI_HELP_MSG = \
f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
@ -63,7 +63,7 @@ CLI_HELP_MSG = \
Community: https://community.ultralytics.com
GitHub: https://github.com/ultralytics/ultralytics
"""
#ffdkds
# Define keys for arg type checks
CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
@ -76,7 +76,7 @@ CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', '
'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
#ffdkds
def cfg2dict(cfg):
"""
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
@ -86,18 +86,18 @@ def cfg2dict(cfg):
Returns:
cfg (dict): Configuration object in dictionary format.
"""
"""#ffdkds
if isinstance(cfg, (str, Path)):
cfg = yaml_load(cfg) # load dict
elif isinstance(cfg, SimpleNamespace):
cfg = vars(cfg) # convert to dict
return cfg
#ffdkds
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
"""
Load and merge configuration data from a file or dictionary.
#ffdkds
Args:
cfg (str | Path | Dict | SimpleNamespace): Configuration data.
overrides (str | Dict | optional): Overrides in the form of a file name or a dictionary. Default is None.
@ -106,13 +106,13 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
(SimpleNamespace): Training arguments namespace.
"""
cfg = cfg2dict(cfg)
# ffdkds
# Merge overrides
if overrides:
overrides = cfg2dict(overrides)
check_dict_alignment(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
# ffdkds
# Special handling for numeric project/name
for k in 'project', 'name':
if k in cfg and isinstance(cfg[k], (int, float)):
@ -120,7 +120,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
if cfg.get('name') == 'model': # assign model to 'name' arg
cfg['name'] = cfg.get('model', '').split('.')[0]
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
# ffdkds
# Type and Value checks
for k, v in cfg.items():
if v is not None: # None values may be from optional args
@ -140,11 +140,11 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
# ffdkds
# Return instance
return IterableSimpleNamespace(**cfg)
#ffdkds
def _handle_deprecation(custom):
"""Hardcoded function to handle deprecated config keys"""
@ -158,15 +158,15 @@ def _handle_deprecation(custom):
if key == 'line_thickness':
deprecation_warn(key, 'line_width')
custom['line_width'] = custom.pop('line_thickness')
# ffdkds
return custom
#ffdkds
def check_dict_alignment(base: Dict, custom: Dict, e=None):
"""
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
#ffdkds
Args:
custom (dict): a dictionary of custom configuration options
base (dict): a dictionary of base configuration options
@ -183,16 +183,16 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
raise SyntaxError(string + CLI_HELP_MSG) from e
#ffdkds
def merge_equals_args(args: List[str]) -> List[str]:
"""
Merges arguments around isolated '=' args in a list of strings.
The function considers cases where the first argument ends with '=' or the second starts with '=',
as well as when the middle one is an equals sign.
#ffdkds
Args:
args (List[str]): A list of strings where each element is an argument.
#ffdkds
Returns:
List[str]: A list of strings where the arguments around isolated '=' are merged.
"""
@ -209,25 +209,25 @@ def merge_equals_args(args: List[str]) -> List[str]:
else:
new_args.append(arg)
return new_args
#ffdkds
#ffdkds
def handle_yolo_hub(args: List[str]) -> None:
"""
Handle Ultralytics HUB command-line interface (CLI) commands.
This function processes Ultralytics HUB CLI commands such as login and logout.
It should be called when executing a script with arguments related to HUB authentication.
#ffdkds
Args:
args (List[str]): A list of command line arguments
#ffdkds
Example:
```python
python my_script.py hub login your_api_key
```
#ffdkds```
"""
from ultralytics import hub
# ffdkds
if args[0] == 'login':
key = args[1] if len(args) > 1 else ''
# Log in to Ultralytics HUB using the provided API key
@ -235,7 +235,7 @@ def handle_yolo_hub(args: List[str]) -> None:
elif args[0] == 'logout':
# Log out from Ultralytics HUB
hub.logout()
#ffdkds
def handle_yolo_settings(args: List[str]) -> None:
"""
@ -243,14 +243,14 @@ def handle_yolo_settings(args: List[str]) -> None:
This function processes YOLO settings CLI commands such as reset.
It should be called when executing a script with arguments related to YOLO settings management.
#ffdkds
Args:
args (List[str]): A list of command line arguments for YOLO settings management.
#ffdkds
Example:
```python
python my_script.py yolo settings reset
```
```#ffdkds
"""
if any(args):
if args[0] == 'reset':
@ -264,7 +264,7 @@ def handle_yolo_settings(args: List[str]) -> None:
yaml_print(SETTINGS_YAML) # print the current settings
#ffdkds
def parse_key_value_pair(pair):
"""Parse one 'key=value' pair and return key and value."""
re.sub(r' *= *', '=', pair) # remove spaces around equals sign
@ -272,7 +272,7 @@ def parse_key_value_pair(pair):
assert v, f"missing '{k}' value"
return k, smart_value(v)
#ffdkds
def smart_value(v):
"""Convert a string to an underlying type such as int, float, bool, etc."""
if v.lower() == 'none':
@ -285,13 +285,13 @@ def smart_value(v):
with contextlib.suppress(Exception):
return eval(v)
return v
#ffdkds
def entrypoint(debug=''):
"""
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
to the package.
#ffdkds
This function allows for:
- passing mandatory YOLO args as a list of strings
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
@ -306,7 +306,7 @@ def entrypoint(debug=''):
if not args: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return
# ffdkds
special = {
'help': lambda: LOGGER.info(CLI_HELP_MSG),
'checks': checks.check_yolo,

@ -1,39 +1,78 @@
from pathlib import Path
from ultralytics import SAM, YOLO
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
"""
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
函数功能概述
该函数实现了利用YOLO目标检测模型和SAM分割模型自动对图像进行标注的功能
它先通过YOLO模型检测出图像中的目标物体的边界框然后使用SAM模型基于这些边界框生成对应的分割掩码
最后将标注结果类别ID和分割掩码信息保存到文本文件中
Args:
data (str): Path to a folder containing images to be annotated.
输入参数说明
- data表示包含待标注图像的文件夹路径字符串类型指向存放需要进行标注处理的图像所在的文件夹位置
det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'.
- det_model可选参数用于指定预训练的YOLO检测模型的路径或名称默认值为'yolov8x.pt'字符串类型
sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'.
- sam_model可选参数用于指定预训练的SAM分割模型的路径或名称默认值为'sam_b.pt'字符串类型
device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available).
- device可选参数指定运行模型的设备默认是空字符串意味着会根据系统实际情况自动选择使用CPU或者可用的GPU来运行模型
output_dir (str | None | optional): Directory to save the annotated results.
Defaults to a 'labels' folder in the same directory as 'data'.
- output_dir可选参数用于指定保存标注结果的目录路径如果不传入该参数为None
则默认会在与'data'参数所指文件夹同一目录下创建名为'labels'的文件夹来保存结果
"""
# 实例化YOLO目标检测模型传入指定的预训练模型名称或路径创建出可用于进行目标检测的模型对象
det_model = YOLO(det_model)
# 实例化SAM分割模型传入指定的预训练模型名称或路径创建出可用于进行图像分割的模型对象
sam_model = SAM(sam_model)
# 如果没有指定输出目录output_dir为None
if not output_dir:
# 则将输出目录设置为'data'所指文件夹的父目录下的'labels'文件夹,
# 先将'data'转换为Path对象获取其父目录再拼接上'labels'文件夹名称得到输出目录路径
output_dir = Path(str(data)).parent / 'labels'
# 创建输出目录如果目录已经存在则不会报错exist_ok=True同时会创建父目录如果父目录不存在的话parents=True
Path(output_dir).mkdir(exist_ok=True, parents=True)
# 使用YOLO目标检测模型对输入的图像数据进行检测传入图像数据所在路径data
# 设置以流的形式处理图像stream=True适用于处理大量图像时节省内存并指定运行设备device
# 检测结果会以可迭代对象的形式返回,后续可以逐次获取每张图像的检测结果
det_results = det_model(data, stream=True, device=device)
# 遍历YOLO目标检测模型返回的每张图像的检测结果
for result in det_results:
boxes = result.boxes.xyxy # Boxes object for bbox outputs
class_ids = result.boxes.cls.int().tolist() # noqa
# 获取检测结果中的边界框坐标信息(左上角和右下角坐标),以张量形式返回,每个元素对应一个目标物体的边界框坐标
boxes = result.boxes.xyxy
# 获取检测结果中每个目标物体对应的类别ID将其转换为整数列表方便后续处理
class_ids = result.boxes.cls.int().tolist()
# 如果检测到了目标物体即类别ID列表长度大于0
if len(class_ids):
# 使用SAM分割模型对原始图像result.orig_img进行分割传入由YOLO模型检测到的边界框信息bboxes=boxes
# 设置不输出详细信息verbose=False不保存中间结果save=False并指定运行设备device
# 返回的结果是一个包含分割信息的对象列表,这里取第一个元素(因为只处理了一张图像的情况)
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
segments = sam_results[0].masks.xyn # noqa
# 从SAM分割结果中获取分割掩码信息以归一化坐标形式表示每个元素对应一个目标物体的分割掩码坐标信息
segments = sam_results[0].masks.xyn
# 打开一个文本文件,文件名根据当前图像的文件名(去除扩展名)加上'.txt'后缀生成文件保存在输出目录output_dir
# 以写入模式('w')打开,如果文件不存在则创建,如果存在则覆盖原有内容
with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
# 遍历每个目标物体的分割掩码信息
for i in range(len(segments)):
s = segments[i]
# 如果分割掩码为空长度为0则跳过当前循环不进行保存操作
if len(s) == 0:
continue
# 将分割掩码的坐标信息转换为字符串形式,方便写入文件
segment = map(str, segments[i].reshape(-1).tolist())
f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
# 将类别ID和分割掩码坐标信息写入文本文件每个目标物体占一行格式为类别ID + 空格 + 分割掩码坐标(以空格分隔的字符串形式)
f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')

@ -57,16 +57,30 @@ class Compose:
data = t(data)
return data
from pathlib import Path
from ultralytics import SAM, YOLO
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
"""
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
函数功能概述
该函数实现了利用YOLO目标检测模型和SAM分割模型自动对图像进行标注的功能
它先通过YOLO模型检测出图像中的目标物体的边界框然后使用SAM模型基于这些边界框生成对应的分割掩码
最后将标注结果类别ID和分割掩码信息保存到文本文件中
Args:
def append(self, transform):
"""Appends a new transform to the existing list of transforms."""
"""Appends a new transform to the existing list of transforms."""
self.transforms.append(transform)
def tolist(self):
"""Converts list of transforms to a standard Python list."""
"""Converts list of transforms to a standard Python list."""
return self.transforms
def __repr__(self):
"""Return string representation of object."""
"""Return string representation of object."""
format_string = f'{self.__class__.__name__}('
for t in self.transforms:
format_string += '\n'

@ -18,12 +18,12 @@ from tqdm import tqdm
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
from .utils import HELP_URL, IMG_FORMATS
#lxy
class BaseDataset(Dataset):
"""
Base dataset class for loading and processing image data.
#lxy
Args:
img_path (str): Path to the folder containing images.
imgsz (int, optional): Image size. Defaults to 640.
@ -38,7 +38,7 @@ class BaseDataset(Dataset):
single_cls (bool, optional): If True, single class training is used. Defaults to False.
classes (list): List of included classes. Default is None.
fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
#lxy
Attributes:
im_files (list): List of image file paths.
labels (list): List of label data dictionaries.
@ -48,6 +48,9 @@ class BaseDataset(Dataset):
transforms (callable): Image transformation function.
"""
# lxy
# lxy
def __init__(self,
img_path,
imgsz=640,
@ -80,10 +83,12 @@ class BaseDataset(Dataset):
if self.rect:
assert self.batch_size is not None
self.set_rectangle()
# lxy
# Buffer thread for mosaic images
self.buffer = [] # buffer size = batch size
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
# lxy
# Cache stuff
if cache == 'ram' and not self.check_cache_ram():
@ -92,10 +97,13 @@ class BaseDataset(Dataset):
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
if cache:
self.cache_images(cache)
# lxy
# Transforms
self.transforms = self.build_transforms(hyp=hyp)
# lxy
def get_img_files(self, img_path):
"""Read image files."""
try:
@ -122,6 +130,8 @@ class BaseDataset(Dataset):
im_files = im_files[:round(len(im_files) * self.fraction)]
return im_files
# lxy
def update_labels(self, include_class: Optional[list]):
"""include_class, filter labels to include only these classes (optional)."""
include_class_array = np.array(include_class).reshape(1, -1)
@ -141,6 +151,8 @@ class BaseDataset(Dataset):
if self.single_cls:
self.labels[i]['cls'][:, 0] = 0
# lxy
def load_image(self, i):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
@ -170,6 +182,8 @@ class BaseDataset(Dataset):
return self.ims[i], self.im_hw0[i], self.im_hw[i]
# lxy
def cache_images(self, cache):
"""Cache images to memory or disk."""
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
@ -186,12 +200,16 @@ class BaseDataset(Dataset):
pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
pbar.close()
# lxy
def cache_images_to_disk(self, i):
"""Saves an image as an *.npy file for faster loading."""
f = self.npy_files[i]
if not f.exists():
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
# lxy
def check_cache_ram(self, safety_margin=0.5):
"""Check image caching requirements vs available memory."""
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
@ -210,10 +228,13 @@ class BaseDataset(Dataset):
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
return cache
# lxy
def set_rectangle(self):
"""Sets the shape of bounding boxes for YOLO detections as rectangles."""
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches
# lxy
s = np.array([x.pop('shape') for x in self.labels]) # hw
ar = s[:, 0] / s[:, 1] # aspect ratio
@ -221,6 +242,7 @@ class BaseDataset(Dataset):
self.im_files = [self.im_files[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
ar = ar[irect]
# lxy
# Set training image shapes
shapes = [[1, 1]] * nb
@ -235,10 +257,14 @@ class BaseDataset(Dataset):
self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
self.batch = bi # batch index of image
# lxy
def __getitem__(self, index):
"""Returns transformed label information for given index."""
return self.transforms(self.get_image_and_label(index))
# lxy
def get_image_and_label(self, index):
"""Get and return label information from the dataset."""
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
@ -250,14 +276,20 @@ class BaseDataset(Dataset):
label['rect_shape'] = self.batch_shapes[self.batch[index]]
return self.update_labels_info(label)
# lxy
def __len__(self):
"""Returns the length of the labels list for the dataset."""
return len(self.labels)
# lxy
def update_labels_info(self, label):
"""custom your label format here."""
return label
# lxy
def build_transforms(self, hyp=None):
"""Users can custom augmentations here
like:
@ -270,6 +302,8 @@ class BaseDataset(Dataset):
"""
raise NotImplementedError
# lxy
def get_labels(self):
"""Users can custom their own format here.
Make sure your output is a list with each element like below:

@ -20,57 +20,57 @@ from .utils import PIN_MEMORY
class InfiniteDataLoader(dataloader.DataLoader):
"""Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
"""Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
def __init__(self, *args, **kwargs):
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()
def __len__(self):
"""Returns the length of the batch sampler's sampler."""
"""Returns the length of the batch sampler's sampler."""
return len(self.batch_sampler.sampler)
def __iter__(self):
"""Creates a sampler that repeats indefinitely."""
"""Creates a sampler that repeats indefinitely."""
for _ in range(len(self)):
yield next(self.iterator)
def reset(self):
"""Reset iterator.
This is useful when we want to modify settings of dataset while training.
"""
"""Reset iterator.
This is useful when we want to modify settings of dataset while training.
"""
self.iterator = self._get_iterator()
class _RepeatSampler:
"""
Sampler that repeats forever.
"""
Sampler that repeats forever.
Args:
sampler (Dataset.sampler): The sampler to repeat.
"""
Args:
sampler (Dataset.sampler): The sampler to repeat.
"""
def __init__(self, sampler):
"""Initializes an object that repeats a given sampler indefinitely."""
"""Initializes an object that repeats a given sampler indefinitely."""
self.sampler = sampler
def __iter__(self):
"""Iterates over the 'sampler' and yields its contents."""
while True:
yield from iter(self.sampler)
def seed_worker(worker_id): # noqa
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
worker_seed = torch.initial_seed() % 2 ** 32
np.random.seed(worker_seed)
random.seed(worker_seed)
def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
"""Build YOLO Dataset"""
def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32
return YOLODataset(
img_path=img_path,
imgsz=cfg.imgsz,
@ -110,7 +110,7 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
def check_source(source):
"""Check source type and return corresponding flag values."""
webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
if isinstance(source, (str, int, Path)): # int for local usb camera
source = str(source)
@ -137,20 +137,20 @@ def check_source(source):
def load_inference_source(source=None, imgsz=640, vid_stride=1):
"""
Loads an inference source for object detection and applies necessary transformations.
Loads an inference source for object detection and applies necessary transformations.
Args:
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
imgsz (int, optional): The size of the image for inference. Default is 640.
vid_stride (int, optional): The frame interval for video sources. Default is 1.
Args:
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
imgsz (int, optional): The size of the image for inference. Default is 640.
vid_stride (int, optional): The frame interval for video sources. Default is 1.
Returns:
dataset (Dataset): A dataset object for the specified input source.
dataset (Dataset): A dataset object for the specified input source.
"""
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
# Dataloader
# Dataloader
if tensor:
dataset = LoadTensor(source)
elif in_memory:
@ -164,7 +164,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1):
else:
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
# Attach source types to the dataset
# Attach source types to the dataset
setattr(dataset, 'source_type', source_type)
return dataset

@ -13,8 +13,8 @@ from ultralytics.utils.files import make_dirs
def coco91_to_coco80_class():
"""Converts 91-index COCO class IDs to 80-index COCO class IDs.
Returns:
(list): A list of 91 class IDs where the index represents the 80-index class ID and the value is the
Returns:
(list): A list of 91 class IDs where the index represents the 80-index class ID and the value is the
corresponding 91-index class ID.
"""

@ -19,14 +19,14 @@ from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_lab
class YOLODataset(BaseDataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
data (dict, optional): A dataset YAML dictionary. Defaults to None.
use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
Returns:
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
@ -40,7 +40,7 @@ class YOLODataset(BaseDataset):
super().__init__(*args, **kwargs)
def cache_labels(self, path=Path('./labels.cache')):
"""Cache dataset labels, check images and read shapes.
"""Cache dataset labels, check images and read shapes.
Args:
path (Path): path where to save the cache file (default: Path('./labels.cache')).
Returns:

@ -29,10 +29,10 @@ class SourceTypes:
class LoadStreams:
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
capture = None
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
"""Initialize instance variables and check for consistent input stream shapes."""
"""Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.mode = 'stream'
self.imgsz = imgsz
@ -42,10 +42,10 @@ class LoadStreams:
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [None] * n
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
# Start thread to read frames from video stream
st = f'{i + 1}/{n}: {s}... '
if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
s = get_best_youtube_url(s)
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
if s == 0 and (is_colab() or is_kaggle()):
@ -70,14 +70,14 @@ class LoadStreams:
self.threads[i].start()
LOGGER.info('') # newline
# Check for common shapes
# Check for common shapes
self.bs = self.__len__()
def update(self, i, cap, stream):
"""Read stream `i` frames in daemon thread."""
"""Read stream `i` frames in daemon thread."""
n, f = 0, self.frames[i] # frame number, frame array
while cap.isOpened() and n < f:
# Only read a new frame if the buffer is empty
# Only read a new frame if the buffer is empty
if not self.imgs[i]:
n += 1
cap.grab() # .read() = .grab() followed by .retrieve()
@ -94,22 +94,22 @@ class LoadStreams:
if self.capture == 'release':
break
def __iter__(self):
"""Iterates through YOLO image feed and re-opens unresponsive streams."""
"""Iterates through YOLO image feed and re-opens unresponsive streams."""
self.count = -1
return self
def __next__(self):
"""Returns source paths, transformed and original images for processing."""
"""Returns source paths, transformed and original images for processing."""
self.count += 1
# Wait until a frame is available in each buffer
# Wait until a frame is available in each buffer
while not all(self.imgs):
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
cv2.destroyAllWindows()
raise StopIteration
time.sleep(1 / min(self.fps))
# Get and remove the next frame from imgs buffer
# Get and remove the next frame from imgs buffer
return self.sources, [x.pop(0) for x in self.imgs], None, ''
def __len__(self):
@ -118,10 +118,10 @@ class LoadStreams:
class LoadScreenshots:
"""YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
"""YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
def __init__(self, source, imgsz=640):
"""source = [screen_number left top width height] (pixels)."""
"""source = [screen_number left top width height] (pixels)."""
check_requirements('mss')
import mss # noqa
@ -139,7 +139,7 @@ class LoadScreenshots:
self.sct = mss.mss()
self.bs = 1
# Parse monitor shape
# Parse monitor shape
monitor = self.sct.monitors[self.screen]
self.top = monitor['top'] if top is None else (monitor['top'] + top)
self.left = monitor['left'] if left is None else (monitor['left'] + left)
@ -148,11 +148,11 @@ class LoadScreenshots:
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
def __iter__(self):
"""Returns an iterator of the object."""
"""Returns an iterator of the object."""
return self
def __next__(self):
"""mss screen capture: get raw pixels from the screen as np array."""
"""mss screen capture: get raw pixels from the screen as np array."""
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
@ -161,10 +161,10 @@ class LoadScreenshots:
class LoadImages:
"""YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
"""YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
def __init__(self, path, imgsz=640, vid_stride=1):
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
parent = None
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
parent = Path(path).parent
@ -230,11 +230,11 @@ class LoadImages:
success, im0 = self.cap.read()
self.frame += 1
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
else:
# Read image
# Read image
self.count += 1
im0 = cv2.imread(path) # BGR
if im0 is None:
@ -244,17 +244,17 @@ class LoadImages:
return [path], [im0], self.cap, s
def _new_video(self, path):
"""Create a new video capture object."""
"""Create a new video capture object."""
self.frame = 0
self.cap = cv2.VideoCapture(path)
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
if hasattr(cv2, 'CAP_PROP_ORIENTATION_META'): # cv2<4.6.0 compatibility
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
# Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
# Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)
def _cv2_rotate(self, im):
"""Rotate a cv2 video manually."""
"""Rotate a cv2 video manually."""
if self.orientation == 0:
return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
elif self.orientation == 180:
@ -271,7 +271,7 @@ class LoadImages:
class LoadPilAndNumpy:
def __init__(self, im0, imgsz=640):
"""Initialize PIL and Numpy Dataloader."""
"""Initialize PIL and Numpy Dataloader."""
if not isinstance(im0, list):
im0 = [im0]
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
@ -283,7 +283,7 @@ class LoadPilAndNumpy:
@staticmethod
def _single_check(im):
"""Validate and format an image to numpy array."""
"""Validate and format an image to numpy array."""
assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
if isinstance(im, Image.Image):
if im.mode != 'RGB':
@ -304,7 +304,7 @@ class LoadPilAndNumpy:
return self.paths, self.im0, None, ''
def __iter__(self):
"""Enables iteration for class LoadPilAndNumpy."""
"""Enables iteration for class LoadPilAndNumpy."""
self.count = 0
return self
@ -319,7 +319,7 @@ class LoadTensor:
@staticmethod
def _single_check(im, stride=32):
"""Validate and format an image to torch.Tensor."""
"""Validate and format an image to torch.Tensor."""
s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
if len(im.shape) != 4:

@ -31,20 +31,20 @@ PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
# Get orientation exif tag
# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
if ExifTags.TAGS[orientation] == 'Orientation':
break
def img2label_paths(img_paths):
"""Define label paths as a function of image paths."""
"""Define label paths as a function of image paths."""
sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
def get_hash(paths):
"""Returns a single hash value of a list of paths (files or dirs)."""
"""Returns a single hash value of a list of paths (files or dirs)."""
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
h = hashlib.sha256(str(size).encode()) # hash sizes
h.update(''.join(paths).encode()) # hash paths
@ -52,7 +52,7 @@ def get_hash(paths):
def exif_size(img):
"""Returns exif-corrected PIL size."""
"""Returns exif-corrected PIL size."""
s = img.size # (width, height)
with contextlib.suppress(Exception):
rotation = dict(img._getexif().items())[orientation]
@ -62,12 +62,12 @@ def exif_size(img):
def verify_image_label(args):
"""Verify one image-label pair."""
"""Verify one image-label pair."""
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
# Number (missing, found, empty, corrupt), message, segments, keypoints
# Number (missing, found, empty, corrupt), message, segments, keypoints
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
try:
# Verify images
# Verify images
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
@ -81,7 +81,7 @@ def verify_image_label(args):
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
# Verify labels
# Verify labels
if os.path.isfile(lb_file):
nf = 1 # label found
with open(lb_file) as f:
@ -102,7 +102,7 @@ def verify_image_label(args):
assert (lb[:, 1:] <= 1).all(), \
f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
# All labels
# All labels
max_cls = int(lb[:, 0].max()) # max label count
assert max_cls <= num_cls, \
f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
@ -137,11 +137,11 @@ def verify_image_label(args):
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
"""
Args:
imgsz (tuple): The image size.
polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2).
color (int): color
downsample_ratio (int): downsample ratio
Args:
imgsz (tuple): The image size.
polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2).
color (int): color
downsample_ratio (int): downsample ratio
"""
mask = np.zeros(imgsz, dtype=np.uint8)
polygons = np.asarray(polygons)
@ -150,18 +150,18 @@ def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
polygons = polygons.reshape(shape[0], -1, 2)
cv2.fillPoly(mask, polygons, color=color)
nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
# NOTE: fillPoly firstly then resize is trying the keep the same way
# NOTE: fillPoly firstly then resize is trying the keep the same way
# of loss calculation when mask-ratio=1.
mask = cv2.resize(mask, (nw, nh))
return mask
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
"""
Args:
"""
A rgs:
imgsz (tuple): The image size.
polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0)
color (int): color
polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0)
color (int): color
downsample_ratio (int): downsample ratio
"""
masks = []
@ -192,21 +192,21 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
def check_det_dataset(dataset, autodownload=True):
"""Download, check and/or unzip dataset if not found locally."""
""" Download, check and/or unzip dataset if not found locally."""
data = check_file(dataset)
# Download (optional)
# Download (optional)
extract_dir = ''
if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
extract_dir, autodownload = data.parent, False
# Read yaml (optional)
# Read yaml (optional)
if isinstance(data, (str, Path)):
data = yaml_load(data, append_filename=True) # dictionary
# Checks
# Checks
for k in 'train', 'val':
if k not in data:
raise SyntaxError(
@ -222,12 +222,12 @@ def check_det_dataset(dataset, autodownload=True):
data['names'] = check_class_names(data['names'])
# Resolve paths
# Resolve paths
path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
if not path.is_absolute():
path = (DATASETS_DIR / path).resolve()
data['path'] = path # download scripts
data['path'] = path # download scripts
for k in 'train', 'val', 'test':
if data.get(k): # prepend path
if isinstance(data[k], str):
@ -238,7 +238,7 @@ def check_det_dataset(dataset, autodownload=True):
else:
data[k] = [str((path / x).resolve()) for x in data[k]]
# Parse yaml
# Parse yaml
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
if val:
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
@ -271,15 +271,15 @@ def check_cls_dataset(dataset: str, split=''):
"""
Checks a classification dataset such as Imagenet.
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
Args:
dataset (str): The name of the dataset.
dataset (str): The name of the dataset.
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
Returns:
(dict): A dictionary containing the following keys:
(dict): A dictionary containing the following keys:
- 'train' (Path): The directory path containing the training set of the dataset.
- 'val' (Path): The directory path containing the validation set of the dataset.
- 'test' (Path): The directory path containing the test set of the dataset.
@ -287,7 +287,7 @@ def check_cls_dataset(dataset: str, split=''):
- 'names' (dict): A dictionary of class names in the dataset.
Raises:
FileNotFoundError: If the specified dataset is not found and cannot be downloaded.
FileNotFoundError: If the specified dataset is not found and cannot be downloaded.
"""
dataset = Path(dataset)
@ -318,19 +318,19 @@ def check_cls_dataset(dataset: str, split=''):
class HUBDatasetStats():
"""
A class for generating HUB dataset JSON and `-hub` dataset directory.
A class for generating HUB dataset JSON and `-hub` dataset directory.
Args:
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
Usage
from ultralytics.data.utils import HUBDatasetStats
from ultralytics.data.utils import HUBDatasetStats
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8.zip', task='detect') # detect dataset
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment') # segment dataset
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment') # segment dataset
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-pose.zip', task='pose') # pose dataset
stats.get_json(save=False)
stats.get_json(save=False)
stats.process_images()
"""

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
Ex port a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
Format | `format=argument` | Model
--- | --- | ---
@ -27,7 +27,7 @@ Python:
results = model.export(format='onnx')
CLI:
$ yolo mode=export model=yolov8n.pt format=onnx
$ yolo mode=export model=yolov8n.pt format=onnx
Inference:
$ yolo predict model=yolov8n.pt # PyTorch

@ -17,45 +17,7 @@ from ultralytics.utils.torch_utils import smart_inference_mode
class Model:
"""
A base model class to unify apis for all the models.
Args:
model (str, Path): Path to the model file to load or create.
task (Any, optional): Task type for the YOLO model. Defaults to None.
Attributes:
predictor (Any): The predictor object.
model (Any): The model object.
trainer (Any): The trainer object.
task (str): The type of model task.
ckpt (Any): The checkpoint object if the model loaded from *.pt file.
cfg (str): The model configuration if loaded from *.yaml file.
ckpt_path (str): The checkpoint file path.
overrides (dict): Overrides for the trainer object.
metrics (Any): The data for metrics.
Methods:
__call__(source=None, stream=False, **kwargs):
Alias for the predict method.
_new(cfg:str, verbose:bool=True) -> None:
Initializes a new model and infers the task type from the model definitions.
_load(weights:str, task:str='') -> None:
Initializes a new model and infers the task type from the model head.
_check_is_pytorch_model() -> None:
Raises TypeError if the model is not a PyTorch model.
reset() -> None:
Resets the model modules.
info(verbose:bool=False) -> None:
Logs the model info.
fuse() -> None:
Fuses the model for faster inference.
predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]:
Performs prediction using the YOLO model.
Returns:
list(ultralytics.engine.results.Results): The prediction results.
"""
""" """
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
"""

@ -15,7 +15,7 @@ Usage - sources:
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
Usage - formats:
$ yolo mode=predict model=yolov8n.pt # PyTorch
$ yolo mode=predict model=yolov8n.pt # PyTorch
yolov8n.torchscript # TorchScript
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
yolov8n_openvino_model # OpenVINO

@ -25,7 +25,7 @@ class BaseTensor(SimpleClass):
def __init__(self, data, orig_shape) -> None:
"""Initialize BaseTensor with data and original shape.
Args:
Args:
data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
orig_shape (tuple): Original shape of image.
"""

@ -38,7 +38,7 @@ class BaseTrainer:
"""
BaseTrainer
A base class for creating trainers.
A base class for creating trainers.
Attributes:
args (SimpleNamespace): Configuration for the trainer.

@ -6,7 +6,7 @@ Usage:
$ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
Usage - formats:
$ yolo mode=val model=yolov8n.pt # PyTorch
$ yolo mode=val model=yolov8n.pt # PyTorch
yolov8n.torchscript # TorchScript
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
yolov8n_openvino_model # OpenVINO

@ -10,7 +10,7 @@ from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
def login(api_key=''):
"""
Log in to the Ultralytics HUB API using the provided API key.
Log in to the Ultralytics HUB API using the provided API key.
Args:
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id

@ -13,7 +13,7 @@ class Auth:
def __init__(self, api_key='', verbose=False):
"""
Initialize the Auth class with an optional API key.
Initialize the Auth class with an optional API key.
Args:
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id

@ -14,24 +14,7 @@ AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__versio
class HUBTrainingSession:
"""
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
Args:
url (str): Model identifier used to initialize the HUB training session.
Attributes:
agent_id (str): Identifier for the instance communicating with the server.
model_id (str): Identifier for the YOLOv5 model being trained.
model_url (str): URL for the model in Ultralytics HUB.
api_url (str): API URL for the model in Ultralytics HUB.
auth_header (dict): Authentication header for the Ultralytics HUB API requests.
rate_limits (dict): Rate limits for different API calls (in seconds).
timers (dict): Timers for rate limiting.
metrics_queue (dict): Queue for the model's metrics.
model (dict): Model data fetched from Ultralytics HUB.
alive (bool): Indicates if the heartbeat loop is active.
"""
""" """
def __init__(self, url):
"""

@ -23,7 +23,7 @@ HUB_WEB_ROOT = os.environ.get('ULTRALYTICS_HUB_WEB', 'https://hub.ultralytics.co
def request_with_credentials(url: str) -> any:
"""
Make an AJAX request with cookies attached in a Google Colab environment.
Make an AJAX request with cookies attached in a Google Colab environment.
Args:
url (str): The URL to make the request to.

@ -10,12 +10,12 @@ from .val import FastSAMValidator
class FastSAM(Model):
"""
FastSAM model interface.
FastSAM model interface.
Usage - Predict:
from ultralytics import FastSAM
model = FastSAM('last.pt')
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
"""

@ -15,7 +15,7 @@ class FastSAMPredictor(DetectionPredictor):
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
"""TODO: filter by classes."""
"""TODO: filter by classes."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
@ -40,11 +40,13 @@ class FastSAMPredictor(DetectionPredictor):
if not len(pred): # save empty boxes
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
continue
#lxy
if self.args.retina_masks:
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
#lxy
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)

@ -12,13 +12,13 @@ from PIL import Image
class FastSAMPrompt:
def __init__(self, img_path, results, device='cuda') -> None:
# self.img_path = img_path
# self.img_path = img_path
self.device = device
self.results = results
self.img_path = img_path
self.ori_img = cv2.imread(img_path)
# Import and assign clip
# Import and assign clip
try:
import clip # for linear_assignment
except ImportError:
@ -35,7 +35,7 @@ class FastSAMPrompt:
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
segmented_image = Image.fromarray(segmented_image_array)
black_image = Image.new('RGB', image.size, (255, 255, 255))
# transparency_mask = np.zeros_like((), dtype=np.uint8)
# t ransparency_mask = np.zeros_like((), dtype=np.uint8)
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
transparency_mask[y1:y2, x1:x2] = 255
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
@ -83,7 +83,7 @@ class FastSAMPrompt:
if len(contours) > 1:
for b in contours:
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
# 将多个bbox合并成一个
# 将多个bbox合并成一个
x1 = min(x1, x_t)
y1 = min(y1, y_t)
x2 = max(x2, x_t + w_t)
@ -109,10 +109,10 @@ class FastSAMPrompt:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_h = image.shape[0]
original_w = image.shape[1]
# for macOS only
# for macOS only
# plt.switch_backend('TkAgg')
plt.figure(figsize=(original_w / 100, original_h / 100))
# Add subplot with no margin.
# Add subplot with no margin.
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())

@ -7,22 +7,33 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
"""
Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes (torch.Tensor): (n, 4)
image_shape (tuple): (height, width)
threshold (int): pixel threshold
Returns:
adjusted_boxes (torch.Tensor): adjusted bounding boxes
函数功能
该函数用于调整边界框bounding boxes的坐标使其在距离图像边界小于指定阈值时能够紧贴图像的边界
参数说明
boxes (torch.Tensor): (n, 4)
- 输入的边界框张量形状为 (n, 4)其中 n 表示边界框的数量每个边界框由 4 个值表示
通常分别对应左上角的 x 坐标左上角的 y 坐标右下角的 x 坐标右下角的 y 坐标
image_shape (tuple): (height, width)
- 图像的形状以元组形式表示包含图像的高度height和宽度width两个维度的尺寸信息
threshold (int): pixel threshold
- 像素阈值用于判断边界框是否需要调整紧贴图像边界单位为像素默认值为 20
返回值
adjusted_boxes (torch.Tensor): adjusted bounding boxes
- 返回调整后的边界框张量其格式和输入的 'boxes' 张量一致只是坐标值根据规则进行了相应调整
"""
# Image dimensions
# 获取图像的高度和宽度,分别赋值给变量 h 和 w用于后续边界框坐标调整的判断依据
h, w = image_shape
# Adjust boxes
# 调整边界框的坐标,如果边界框的左上角 x 坐标boxes[:, 0])小于阈值,将其设置为 0使其紧贴图像左边界
boxes[boxes[:, 0] < threshold, 0] = 0 # x1
# 如果边界框的左上角 y 坐标boxes[:, 1])小于阈值,将其设置为 0使其紧贴图像上边界
boxes[boxes[:, 1] < threshold, 1] = 0 # y1
# 如果边界框的右下角 x 坐标boxes[:, 2])大于图像宽度减去阈值,将其设置为图像宽度 w使其紧贴图像右边界
boxes[boxes[:, 2] > w - threshold, 2] = w # x2
# 如果边界框的右下角 y 坐标boxes[:, 3])大于图像高度减去阈值,将其设置为图像高度 h使其紧贴图像下边界
boxes[boxes[:, 3] > h - threshold, 3] = h # y2
return boxes
@ -31,34 +42,55 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals
"""
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1 (torch.Tensor): (4, )
boxes (torch.Tensor): (n, 4)
Returns:
high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
函数功能
该函数用于计算一个给定的边界框box1与一组其他边界框boxes之间的交并比Intersection-Over-Union简称 IoU
并可以根据设定的 IoU 阈值筛选出满足条件的边界框索引
参数说明
box1 (torch.Tensor): (4, )
- 单个边界框的坐标张量形状为 (4, )4 个值依次代表左上角的 x 坐标左上角的 y 坐标右下角的 x 坐标右下角的 y 坐标
boxes (torch.Tensor): (n, 4)
- 一组边界框的坐标张量形状为 (n, 4)n 表示边界框的数量每个边界框的坐标表示方式同 box1
iou_thres (int, optional): 交并比阈值默认值为 0.9用于筛选出与给定边界框 IoU 大于此阈值的其他边界框
image_shape (tuple, optional): (height, width)图像的形状默认值为 (640, 640)用于在计算 IoU 前调整边界框坐标调用了 adjust_bboxes_to_image_border 函数
raw_output (bool, optional): 是否返回原始的 IoU 默认值为 False如果为 True则直接返回计算得到的 IoU 0如果没有交集情况
如果为 False则返回 IoU 大于阈值的边界框索引
返回值
如果 raw_output True
返回计算得到的交并比IoU如果输入的 boxes 张量为空 iou.numel() == 0表示没有其他边界框与之计算 IoU则返回 0
如果 raw_output False
返回满足 IoU 大于阈值iou_thres条件的边界框在输入的 boxes 张量中的索引以扁平化后的张量形式返回
"""
# 先调用 adjust_bboxes_to_image_border 函数根据图像形状和阈值调整输入的边界框boxes坐标使其紧贴图像边界
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
# 以下是计算交并比IoU的步骤
# 对于每个边界框获取其与给定边界框box1在 x 方向上的交集的起始坐标(取两者左上角 x 坐标的最大值)
x1 = torch.max(box1[0], boxes[:, 0])
# 获取其与给定边界框在 y 方向上的交集的起始坐标(取两者左上角 y 坐标的最大值)
y1 = torch.max(box1[1], boxes[:, 1])
x2 = torch.min(box1[2], boxes[:, 2])
y2 = torch.min(box1[3], boxes[:, 3])
# 获取其与给定边界框在 x 方向上的交集的结束坐标(取两者右下角 x 坐标的最小值)
x2 = torch.max(box1[2], boxes[:, 2])
# 获取其与给定边界框在 y 方向上的交集的结束坐标(取两者右下角 y 坐标的最小值)
y2 = torch.max(box1[3], boxes[:, 3])
# compute the area of intersection
# 计算交集区域的面积,通过计算交集区域在 x 和 y 方向上的边长(需要使用 clamp(0) 确保边长非负),然后相乘得到面积
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
# compute the area of both individual boxes
# 计算给定边界框box1的面积通过右下角坐标与左上角坐标差值相乘得到
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
# 计算输入的一组边界框boxes中每个边界框的面积同样通过右下角坐标与左上角坐标差值相乘得到
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# compute the area of union
# 计算并集区域的面积,通过两个边界框面积之和减去交集区域面积得到
union = box1_area + box2_area - intersection
# compute the IoU
# 计算交并比IoU即交集区域面积除以并集区域面积得到的结果形状应该为 (n, )n 为输入的边界框数量boxes 的行数)
iou = intersection / union # Should be shape (n, )
if raw_output:
return 0 if iou.numel() == 0 else iou
# return indices of boxes with IoU > thres
return torch.nonzero(iou > iou_thres).flatten()
# 根据设定的 IoU 阈值iou_thres筛选出 IoU 大于该阈值的边界框索引,通过 torch.nonzero 找到满足条件的位置索引,
# 并使用 flatten 方法将其扁平化,最终返回这些满足条件的边界框在输入的 boxes 张量中的索引
return torch.nonzero(iou > iou_thres).flatten()

@ -1,244 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from multiprocessing.pool import ThreadPool
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class FastSAMValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
batch = super().preprocess(batch)
batch['masks'] = batch['masks'].to(self.device).float()
return batch
def init_metrics(self, model):
"""Initialize metrics and select mask processing function based on save_json flag."""
super().init_metrics(model)
self.plot_masks = []
if self.args.save_json:
check_requirements('pycocotools>=2.0.6')
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Postprocesses YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
return p, proto
def update_metrics(self, preds, batch):
"""Metrics."""
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
(2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Masks
midx = [si] if self.args.overlap_mask else idx
gt_masks = batch['masks'][midx]
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable
correct_masks = self._process_batch(predn,
labelsn,
pred_masks,
gt_masks,
overlap=self.args.overlap_mask,
masks=True)
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
# Append correct_masks, correct_boxes, pconf, pcls, tcls
self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# Save
if self.args.save_json:
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
shape,
ratio_pad=batch['ratio_pad'][si])
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs):
"""Sets speed and confusion matrix for evaluation metrics."""
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
Returns:
correct (array[N, 10]), for 10 IoU levels
"""
if masks:
if overlap:
nl = len(labels)
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
if gt_masks.shape[1:] != pred_masks.shape[1:]:
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
iou = box_iou(labels[:, 1:], detections[:, :4])
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
correct_class = labels[:, 0:1] == detections[:, 5]
for i in range(len(self.iouv)):
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
1).cpu().numpy() # [label, detect, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(
batch['img'],
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
"""Save one JSON result."""
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
from pycocotools.mask import encode # noqa
def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
rle['counts'] = rle['counts'].decode('utf-8')
return rle
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
pred_masks = np.transpose(pred_masks, (2, 0, 1))
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5),
'segmentation': rles[i]})
def eval_json(self, stats):
"""Return COCO-style object detection evaluation metrics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats

@ -1,13 +1,3 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
YOLO-NAS model interface.
Usage - Predict:
from ultralytics import NAS
model = NAS('yolo_nas_s')
results = model.predict('ultralytics/assets/bus.jpg')
"""
from pathlib import Path
@ -28,30 +18,23 @@ class NAS(Model):
@smart_inference_mode()
def _load(self, weights: str, task: str):
# Load or create new NAS model
import super_gradients
suffix = Path(weights).suffix
if suffix == '.pt':
self.model = torch.load(weights)
elif suffix == '':
self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = weights # for export()
self.model.task = 'detect' # for export()
self.model.yaml = {}
self.model.task = 'detect'
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
@property

@ -7,16 +7,31 @@ from ultralytics.engine.results import Results
from ultralytics.utils import ops
from ultralytics.utils.ops import xyxy2xywh
# NASPredictor类继承自BasePredictor用于对NAS模型预测结果进行后处理等相关操作以得到最终符合要求的检测结果
class NASPredictor(BasePredictor):
def postprocess(self, preds_in, img, orig_imgs):
"""Postprocesses predictions and returns a list of Results objects."""
"""
对模型的预测结果进行后处理包括坐标转换非极大值抑制以及结果的整理等操作最终返回处理后的检测结果列表
参数
preds_in模型的原始预测输入数据通常包含多个维度的信息这里是一个嵌套结构例如可能包含预测的边界框坐标类别概率等信息
img输入到模型中的图像数据可能是经过预处理后的形式用于在一些坐标转换操作中参考其尺寸信息比如尺寸缩放等操作会用到图像的高度和宽度维度信息
orig_imgs原始的图像数据可能是未经模型预处理的原始图像用于获取真实的图像尺寸以及与最终结果关联原始图像信息方便后续结果展示等使用其数据类型可能是列表或者张量等
返回值
results包含处理后检测结果的列表每个元素是一个Results类的实例包含了原始图像图像路径类别名称以及检测到的目标框等信息代表了对每一张原始图像的检测结果
"""
# Cat boxes and class scores
# 将预测的边界框坐标从左上角和右下角坐标形式xyxy转换为左上角坐标与宽高形式xywh便于后续处理和理解
boxes = xyxy2xywh(preds_in[0][0])
# 将转换后的边界框坐标信息与其他预测信息例如类别概率等这里是preds_in[0][1]所代表的内容)在最后一维进行拼接
# 然后进行维度变换,将最后一维调整到中间位置,以符合后续处理的格式要求
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
# 对拼接调整后的预测结果进行非极大值抑制NMS操作筛选出合适的预测框。
# 根据设定的置信度阈值self.args.conf、交并比阈值self.args.iou、是否进行类别无关的NMSself.args.agnostic_nms
# 最大检测目标数量self.args.max_det以及限定的类别self.args.classes等条件进行筛选。
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
@ -25,11 +40,18 @@ class NASPredictor(BasePredictor):
classes=self.args.classes)
results = []
# 遍历经过非极大值抑制后的每一组预测结果(针对每张图像的预测结果进行遍历)
for i, pred in enumerate(preds):
# 获取对应的原始图像如果orig_imgs是列表形式则取第i个元素作为原始图像如果不是列表可能是单张图像的张量形式等情况则直接使用orig_imgs作为原始图像
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
# 如果orig_imgs不是torch.Tensor类型可能是其他数据类型比如PIL图像等则需要根据输入图像img和原始图像orig_img的尺寸信息
# 对预测框的坐标进行缩放,使其坐标对应到原始图像的尺寸空间下,确保坐标的准确性。
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
# 获取图像的路径信息如果self.batch[0]是列表形式则取第i个元素作为图像路径如果不是列表可能是单一路径字符串等情况则直接使用self.batch[0]作为图像路径
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
# 创建一个Results类的实例将原始图像、图像路径、模型中的类别名称以及当前图像的预测框信息作为参数传入
# 以此来构建一个完整的检测结果对象并添加到results列表中最终得到所有图像的检测结果列表。
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
return results
return results

@ -6,15 +6,52 @@ from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
from ultralytics.utils.ops import xyxy2xywh
# 定义模块的公开接口,使得当使用 `from module import *` 时,只会导入 '__all__' 列表中指定的类或函数等,这里指定了 'NASValidator' 类可被外部导入
__all__ = ['NASValidator']
# NASValidator类继承自DetectionValidator类主要用于对NAS相关模型预测结果进行后处理操作比如进行非极大值抑制等处理来得到最终可用的预测结果
class NASValidator(DetectionValidator):
def postprocess(self, preds_in):
"""Apply Non-maximum suppression to prediction outputs."""
"""
Apply Non-maximum suppression to prediction outputs.
函数功能
对模型的预测输出结果进行后处理具体操作包括先将边界框坐标格式进行转换然后拼接相关预测信息
最后通过非极大值抑制Non-maximum suppressionNMS算法筛选出合适的预测框返回处理后的预测结果
参数说明
preds_in模型输出的原始预测结果其结构通常比较复杂可能包含多层嵌套的张量等数据结构这里是一个包含多个元素的列表形式
其中 preds_in[0][0] 等索引对应着具体的预测信息如边界框坐标信息等具体结构取决于模型的输出格式定义
返回值
通过非极大值抑制处理后的预测结果其格式经过筛选和调整符合后续使用如评估可视化等的要求是经过处理后的符合一定条件的预测框等信息的集合
"""
# 将预测的边界框坐标从左上角和右下角坐标形式xyxy转换为左上角坐标与宽高形式xywh
# 这种坐标转换后的格式在后续一些计算和处理中可能更加方便,例如在计算面积、与其他坐标进行比较等操作时更符合常规思维和计算逻辑。
boxes = xyxy2xywh(preds_in[0][0])
# 将转换后的边界框坐标信息boxes与其他预测信息例如类别概率等这里是 preds_in[0][1] 所代表的内容)在最后一维进行拼接,
# 使得相关的预测信息整合在一起,方便后续作为一个整体进行处理。然后使用 permute 函数对维度进行重新排列,
# 将原本处于最后一维的拼接后信息调整到中间维度(调整为符合后续非极大值抑制等操作要求的维度顺序)。
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
# 调用非极大值抑制函数ops.non_max_suppression对拼接调整后的预测结果preds进行处理
# 根据设定的一系列参数来筛选出合适的预测框。以下是各参数的具体作用:
# self.args.conf置信度阈值只有置信度大于该阈值的预测框才有可能被保留下来用于过滤掉那些置信度较低、不太可靠的预测结果。
# self.args.iou交并比Intersection over UnionIoU阈值用于判断两个预测框之间的重叠程度
# 在非极大值抑制过程中如果两个预测框的IoU大于该阈值会根据一定规则比如置信度高低等决定保留其中一个去除另一个避免过多重叠的预测框出现。
# labels=self.lb传入标签信息self.lb可能在非极大值抑制过程中用于区分不同类别的预测框确保在筛选时按照不同类别分别进行处理符合实际的检测逻辑。
# multi_label=False表示是否允许多个标签情况这里设置为 False意味着每个预测框可能只对应一个类别标签具体取决于模型原本的设计和应用场景需求
# agnostic=self.args.single_cls如果 self.args.single_cls 为 True则表示进行类别无关的非极大值抑制即不考虑类别信息只要预测框之间的IoU等条件满足就进行筛选
# 如果为 False则会按照类别分别进行非极大值抑制操作这里根据传入的配置参数来决定具体的处理方式。
# max_det=self.args.max_det设置最大检测目标数量即经过非极大值抑制后每个图像最多保留的预测框数量用于控制最终结果的数量规模避免过多的预测框输出。
# max_time_img=0.5:限定处理每张图像预测结果的最大时间(单位可能根据具体实现而定,这里是 0.5,具体含义需看函数内部实现),
# 用于控制处理速度和资源消耗等情况,确保在合理时间内完成后处理操作。
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
@ -22,4 +59,4 @@ class NASValidator(DetectionValidator):
multi_label=False,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
max_time_img=0.5)
max_time_img=0.5)

@ -1,30 +1,52 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
RT-DETR model interface
此部分为文档字符串用于简要描述该模块的功能这里说明这个模块是关于RT-DETR模型的接口相关内容
"""
from ultralytics.engine.model import Model
from ultralytics.nn.tasks import RTDETRDetectionModel
from.predict import RTDETRPredictor
from.train import RTDETRTrainer
from.val import RTDETRValidator
from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator
# 定义RTDETR类它继承自Model类代表了RT-DETR模型相关的接口用于整合和管理该模型在预测、验证、训练等方面的具体实现类。
class RTDETR(Model):
"""
RTDETR model interface.
这是类的文档字符串再次强调该类是RT-DETR模型的接口用于表明类的主要作用
"""
def __init__(self, model='rtdetr-l.pt') -> None:
"""
初始化RTDETR类的实例
参数
model (str, 可选)指定模型的路径或文件名默认值为'rtdetr-l.pt'其格式应该是.pt.yaml或者.yml文件
因为后续会根据文件后缀判断是否支持创建模型其他格式暂不支持
抛出异常
如果传入的model参数对应的文件后缀不是.pt.yaml或.yml则抛出NotImplementedError异常
提示RT-DETR模型仅支持从这些类型的文件创建模型
"""
if model and not model.split('.')[-1] in ('pt', 'yaml', 'yml'):
raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
super().__init__(model=model, task='detect')
@property
def task_map(self):
"""
定义一个属性property方法用于返回一个字典该字典建立了任务名称这里是'detect'代表检测任务
对应任务相关的具体类的映射关系方便根据任务名称获取相应的执行类例如预测任务对应的预测器类验证任务对应的验证器类
训练任务对应的训练器类以及模型本身对应的具体模型类
返回值
一个字典键为'task'这里固定为'detect'值为另一个字典其中包含了'predictor''validator''trainer''model'
这些键分别对应RTDETRPredictorRTDETRValidatorRTDETRTrainerRTDETRDetectionModel这些具体的类
用于明确不同任务所关联的具体实现类
"""
return {
'detect': {
'predictor': RTDETRPredictor,
'validator': RTDETRValidator,
'trainer': RTDETRTrainer,
'model': RTDETRDetectionModel}}
'model': RTDETRDetectionModel}}

@ -22,22 +22,37 @@ from ultralytics.utils.ops import xywh2xyxy
def check_class_names(names):
"""Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts."""
if isinstance(names, list): # names is a list
names = dict(enumerate(names)) # convert to dict
"""
Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts.
函数功能
对类别名称class names进行检查和处理如果输入的类别名称是列表形式则将其转换为字典形式
对于字典形式的类别名称进一步进行格式转换如将字符串键转换为整数键等并且会检查类别索引是否符合数据集要求
如果是 ImageNet 类别代码形式的名称还会将其映射为人类可读的名称
参数说明
names类别名称可以是列表或者字典形式列表形式下每个元素代表一个类别名称字典形式下键值对表示类别索引和对应的名称
返回值
经过处理后的类别名称字典键为整数类型的类别索引值为对应的类别名称字符串格式符合后续模型处理和使用的要求
"""
if isinstance(names, list): # names是列表形式
names = dict(enumerate(names)) # 将列表转换为字典,键为索引,值为对应名称
if isinstance(names, dict):
# Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
# 将字典中的 1) 字符串键转换为整数,例如 '0' 转换为 0以及将非字符串值转换为字符串例如 True 转换为 'True'
names = {int(k): str(v) for k, v in names.items()}
n = len(names)
if max(names.keys()) >= n:
raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # human-readable names
if isinstance(names[0], str) and names[0].startswith('n0'): # 如果是 ImageNet 类别代码形式,例如 'n01440764'
map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # 加载用于映射为可读名称的映射表
names = {k: map[v] for k, v in names.items()}
return names
# AutoBackend类继承自nn.Module是一个多后端模型类用于在不同平台上使用Ultralytics YOLO进行Python推理支持多种模型格式的加载和推理操作。
class AutoBackend(nn.Module):
def __init__(self,
@ -51,161 +66,169 @@ class AutoBackend(nn.Module):
"""
MultiBackend class for python inference on various platforms using Ultralytics YOLO.
Args:
weights (str): The path to the weights file. Default: 'yolov8n.pt'
device (torch.device): The device to run the model on.
dnn (bool): Use OpenCV DNN module for inference if True, defaults to False.
data (str | Path | optional): Additional data.yaml file for class names.
fp16 (bool): If True, use half precision. Default: False
fuse (bool): Whether to fuse the model or not. Default: True
verbose (bool): Whether to run in verbose mode or not. Default: True
Supported formats and their naming conventions:
| Format | Suffix |
|-----------------------|------------------|
| PyTorch | *.pt |
| TorchScript | *.torchscript |
| ONNX Runtime | *.onnx |
| ONNX OpenCV DNN | *.onnx dnn=True |
| OpenVINO | *.xml |
| CoreML | *.mlmodel |
| TensorRT | *.engine |
| TensorFlow SavedModel | *_saved_model |
| TensorFlow GraphDef | *.pb |
| TensorFlow Lite | *.tflite |
| TensorFlow Edge TPU | *_edgetpu.tflite |
| PaddlePaddle | *_paddle_model |
| ncnn | *_ncnn_model |
函数功能
初始化AutoBackend类的实例根据传入的参数配置加载相应格式的模型并进行一系列的初始化设置
包括确定模型类型下载模型如果需要加载模型权重处理模型元数据检查类别名称等操作
为后续的模型推理做好准备
参数说明
weights (str)模型权重文件的路径默认值为 'yolov8n.pt'指定了要加载的模型权重文件位置
device (torch.device)运行模型的设备例如可以指定为 'cpu' 或者 'cuda'如果有GPU可用
dnn (bool)是否使用OpenCV的DNN模块进行推理默认为False
data (str | Path | optional)额外的数据.yaml文件路径可用于获取类别名称等信息如果需要的话
fp16 (bool)是否使用半精度half precision进行计算默认为False若为True则在支持的情况下以半精度运行模型可加快推理速度并节省内存但可能会有一定精度损失
fuse (bool)是否对模型进行融合操作例如融合一些层来优化模型结构和性能默认值为True
verbose (bool)是否运行在详细模式下即是否打印更多的加载处理等过程中的相关信息默认值为True
支持的模型格式及其命名约定
| Format | Suffix |
|-----------------------|------------------|
| PyTorch | *.pt |
| TorchScript | *.torchscript |
| ONNX Runtime | *.onnx |
| ONNX OpenCV DNN | *.onnx dnn=True |
| OpenVINO | *.xml |
| CoreML | *.mlmodel |
| TensorRT | *.engine |
| TensorFlow SavedModel | *_saved_model |
| TensorFlow GraphDef | *.pb |
| TensorFlow Lite | *.tflite |
| TensorFlow Edge TPU | *_edgetpu.tflite |
| PaddlePaddle | *_paddle_model |
| ncnn | *_ncnn_model |
"""
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
nn_module = isinstance(weights, torch.nn.Module)
# 通过调用_model_type方法判断模型的类型分别对应不同的模型格式如PyTorch、TorchScript等后续根据这些类型进行相应的加载处理
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
self._model_type(w)
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
stride = 32 # default stride
# 根据模型类型和fp16参数确定是否使用半精度只有在是PyTorch、TorchScript、ONNX、OpenVINO、TensorRT或者是内存中的PyTorch模块、Triton格式时才可能使用半精度
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton
# 判断模型格式是否为CoreML、TensorFlow SavedModel、TensorFlow GraphDef、TensorFlow Lite或者TensorFlow Edge TPU这些采用BHWC格式与torch的BCWH格式相对的情况
nhwc = coreml or saved_model or pb or tflite or edgetpu
stride = 32 # 默认步长,后续可能会根据模型实际情况更新
model, metadata = None, None
# Set device
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
if cuda and not any([nn_module, pt, jit, engine]): # GPU dataloader formats
# 判断是否可以使用CUDA即GPU是否可用且设备类型不是'cpu'如果满足条件且模型格式不是内存中的PyTorch模块、PyTorch、TorchScript、TensorRT这些GPU数据加载器相关格式则将设备设置为'cpu'并禁用CUDA
cuda = torch.cuda.is_available() and device.type!= 'cpu' # 使用CUDA
if cuda and not any([nn_module, pt, jit, engine]): # GPU数据加载器格式
device = torch.device('cpu')
cuda = False
# Download if not local
# 如果模型权重文件不是本地文件(即不是以.pt、.triton格式或者不是内存中的PyTorch模块形式存在尝试下载模型权重文件
if not (pt or triton or nn_module):
w = attempt_download_asset(w)
# Load model
if nn_module: # in-memory PyTorch model
# 根据不同的模型类型加载模型权重及相关配置信息
if nn_module: # 内存中的PyTorch模型
model = weights.to(device)
model = model.fuse(verbose=verbose) if fuse else model
if hasattr(model, 'kpt_shape'):
kpt_shape = model.kpt_shape # pose-only
stride = max(int(model.stride.max()), 32) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
kpt_shape = model.kpt_shape # 仅在姿态相关任务时存在,获取关键点形状信息
stride = max(int(model.stride.max()), 32) # 获取模型的最大步长并确保不小于32
names = model.module.names if hasattr(model, 'module') else model.names # 获取类别名称
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
self.model = model # 显式地将模型赋值给self.model方便后续进行to()、cpu()、cuda()、half()等操作
pt = True
elif pt: # PyTorch
elif pt: # PyTorch格式模型
from ultralytics.nn.tasks import attempt_load_weights
model = attempt_load_weights(weights if isinstance(weights, list) else w,
device=device,
inplace=True,
fuse=fuse)
if hasattr(model, 'kpt_shape'):
kpt_shape = model.kpt_shape # pose-only
stride = max(int(model.stride.max()), 32) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
kpt_shape = model.kpt_shape # 姿态相关任务时获取关键点形状信息
stride = max(int(model.stride.max()), 32) # 获取模型最大步长确保不小于32
names = model.module.names if hasattr(model, 'module') else model.names # 获取类别名称
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
elif jit: # TorchScript
self.model = model # 显式赋值给self.model方便后续操作
elif jit: # TorchScript格式模型
LOGGER.info(f'Loading {w} for TorchScript inference...')
extra_files = {'config.txt': ''} # model metadata
extra_files = {'config.txt': ''} # 用于存储模型元数据的额外文件字典
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
model.half() if fp16 else model.float()
if extra_files['config.txt']: # load metadata dict
if extra_files['config.txt']: # 如果配置文件存在,加载其中的元数据字典信息
metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items()))
elif dnn: # ONNX OpenCV DNN
elif dnn: # ONNX OpenCV DNN格式模型
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
check_requirements('opencv-python>=4.5.4')
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime
elif onnx: # ONNX Runtime格式模型
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
import onnxruntime
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
session = onnxruntime.InferenceSession(w, providers=providers)
output_names = [x.name for x in session.get_outputs()]
metadata = session.get_modelmeta().custom_metadata_map # metadata
elif xml: # OpenVINO
metadata = session.get_modelmeta().custom_metadata_map # 获取模型元数据中的自定义元数据映射
elif xml: # OpenVINO格式模型
LOGGER.info(f'Loading {w} for OpenVINO inference...')
check_requirements('openvino>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/
from openvino.runtime import Core, Layout, get_batch # noqa
check_requirements('openvino>=2023.0') # 需要安装openvino-dev指定了其在PyPI上的项目链接
from openvino.runtime import Core, Layout, get_batch # 导入相关模块用于后续操作
core = Core()
w = Path(w)
if not w.is_file(): # if not *.xml
w = next(w.glob('*.xml')) # get *.xml file from *_openvino_model dir
if not w.is_file(): # 如果传入的不是.xml文件
w = next(w.glob('*.xml')) # 从 *_openvino_model 目录下获取.xml文件
ov_model = core.read_model(model=str(w), weights=w.with_suffix('.bin'))
if ov_model.get_parameters()[0].get_layout().empty:
ov_model.get_parameters()[0].set_layout(Layout('NCHW'))
batch_dim = get_batch(ov_model)
if batch_dim.is_static:
batch_size = batch_dim.get_length()
ov_compiled_model = core.compile_model(ov_model, device_name='AUTO') # AUTO selects best available device
ov_compiled_model = core.compile_model(ov_model, device_name='AUTO') # 根据可用设备自动选择最佳设备进行编译
metadata = w.parent / 'metadata.yaml'
elif engine: # TensorRT
elif engine: # TensorRT格式模型
LOGGER.info(f'Loading {w} for TensorRT inference...')
try:
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
import tensorrt as trt # 导入TensorRT模块若导入失败会根据系统情况尝试重新导入
except ImportError:
if LINUX:
check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
import tensorrt as trt # noqa
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
import tensorrt as trt
check_version(trt.__version__, '7.0.0', hard=True) # 要求TensorRT版本大于等于7.0.0
if device.type == 'cpu':
device = torch.device('cuda:0')
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
# Read file
# 读取模型文件,先读取文件中的元数据长度,再读取元数据内容,最后读取模型引擎内容,并创建执行上下文等相关对象,用于后续推理
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
meta_len = int.from_bytes(f.read(4), byteorder='little') # read metadata length
metadata = json.loads(f.read(meta_len).decode('utf-8')) # read metadata
model = runtime.deserialize_cuda_engine(f.read()) # read engine
meta_len = int.from_bytes(f.read(4), byteorder='little') # 读取元数据长度
metadata = json.loads(f.read(meta_len).decode('utf-8')) # 读取元数据内容
model = runtime.deserialize_cuda_engine(f.read()) # 读取模型引擎
context = model.create_execution_context()
bindings = OrderedDict()
output_names = []
fp16 = False # default updated below
fp16 = False # 默认值,后续可能根据模型数据类型更新
dynamic = False
for i in range(model.num_bindings):
name = model.get_binding_name(i)
dtype = trt.nptype(model.get_binding_dtype(i))
if model.binding_is_input(i):
if -1 in tuple(model.get_binding_shape(i)): # dynamic
if -1 in tuple(model.get_binding_shape(i)): # 判断是否是动态输入形状
dynamic = True
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
if dtype == np.float16:
fp16 = True
else: # output
else: # 输出相关操作
output_names.append(name)
shape = tuple(context.get_binding_shape(i))
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
elif coreml: # CoreML
batch_size = bindings['images'].shape[0] # 如果是动态形状,这里获取的是最大批次大小
elif coreml: # CoreML格式模型
LOGGER.info(f'Loading {w} for CoreML inference...')
import coremltools as ct
model = ct.models.MLModel(w)
metadata = dict(model.user_defined_metadata)
elif saved_model: # TF SavedModel
elif saved_model: # TensorFlow SavedModel格式模型
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
import tensorflow as tf
keras = False # assume TF1 saved_model
keras = False # 假设是TensorFlow 1的SavedModel格式非Keras模型
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
metadata = Path(w) / 'metadata.yaml'
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
elif pb: # TensorFlow GraphDef格式模型
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
import tensorflow as tf
@ -213,280 +236,30 @@ class AutoBackend(nn.Module):
def wrap_frozen_graph(gd, inputs, outputs):
"""Wrap frozen graphs for deployment."""
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # 对图进行包装
ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
gd = tf.Graph().as_graph_def() # TF GraphDef
gd = tf.Graph().as_graph_def() # 获取TensorFlow的GraphDef对象
with open(w, 'rb') as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
elif tflite or edgetpu: # TensorFlow Lite或者TensorFlow Edge TPU格式模型
try: # 尝试导入TensorFlow Lite运行时相关模块如果导入失败则导入TensorFlow相关模块来获取相应功能
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
if edgetpu: # TensorFlow Edge TPU格式
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
delegate = {
'Linux': 'libedgetpu.so.1',
'Darwin': 'libedgetpu.1.dylib',
'Windows': 'edgetpu.dll'}[platform.system()]
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
else: # TFLite
else: # TensorFlow Lite格式
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
interpreter = Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
# Load metadata
with contextlib.suppress(zipfile.BadZipFile):
with zipfile.ZipFile(w, 'r') as model:
meta_file = model.namelist()[0]
metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
elif tfjs: # TF.js
raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.')
elif paddle: # PaddlePaddle
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
import paddle.inference as pdi # noqa
w = Path(w)
if not w.is_file(): # if not *.pdmodel
w = next(w.rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))
if cuda:
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
predictor = pdi.create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
metadata = w.parents[1] / 'metadata.yaml'
elif ncnn: # ncnn
LOGGER.info(f'Loading {w} for ncnn inference...')
check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn
import ncnn as pyncnn
net = pyncnn.Net()
net.opt.use_vulkan_compute = cuda
w = Path(w)
if not w.is_file(): # if not *.param
w = next(w.glob('*.param')) # get *.param file from *_ncnn_model dir
net.load_param(str(w))
net.load_model(str(w.with_suffix('.bin')))
metadata = w.parent / 'metadata.yaml'
elif triton: # NVIDIA Triton Inference Server
"""TODO
check_requirements('tritonclient[all]')
from utils.triton import TritonRemoteModel
model = TritonRemoteModel(url=w)
nhwc = model.runtime.startswith("tensorflow")
"""
raise NotImplementedError('Triton Inference Server is not currently supported.')
else:
from ultralytics.engine.exporter import export_formats
raise TypeError(f"model='{w}' is not a supported model format. "
'See https://docs.ultralytics.com/modes/predict for help.'
f'\n\n{export_formats()}')
# Load external metadata YAML
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
metadata = yaml_load(metadata)
if metadata:
for k, v in metadata.items():
if k in ('stride', 'batch'):
metadata[k] = int(v)
elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
metadata[k] = eval(v)
stride = metadata['stride']
task = metadata['task']
batch = metadata['batch']
imgsz = metadata['imgsz']
names = metadata['names']
kpt_shape = metadata.get('kpt_shape')
elif not (pt or triton or nn_module):
LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
# Check names
if 'names' not in locals(): # names missing
names = self._apply_default_class_names(data)
names = check_class_names(names)
self.__dict__.update(locals()) # assign all variables to self
def forward(self, im, augment=False, visualize=False):
"""
Runs inference on the YOLOv8 MultiBackend model.
Args:
im (torch.Tensor): The image tensor to perform inference on.
augment (bool): whether to perform data augmentation during inference, defaults to False
visualize (bool): whether to visualize the output predictions, defaults to False
Returns:
(tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
"""
b, ch, h, w = im.shape # batch, channel, height, width
if self.fp16 and im.dtype != torch.float16:
im = im.half() # to FP16
if self.nhwc:
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.pt or self.nn_module: # PyTorch
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
elif self.jit: # TorchScript
y = self.model(im)
elif self.dnn: # ONNX OpenCV DNN
im = im.cpu().numpy() # torch to numpy
self.net.setInput(im)
y = self.net.forward()
elif self.onnx: # ONNX Runtime
im = im.cpu().numpy() # torch to numpy
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
elif self.xml: # OpenVINO
im = im.cpu().numpy() # FP32
y = list(self.ov_compiled_model(im).values())
elif self.engine: # TensorRT
if self.dynamic and im.shape != self.bindings['images'].shape:
i = self.model.get_binding_index('images')
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
for name in self.output_names:
i = self.model.get_binding_index(name)
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
s = self.bindings['images'].shape
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
self.binding_addrs['images'] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = [self.bindings[x].data for x in sorted(self.output_names)]
elif self.coreml: # CoreML
im = im[0].cpu().numpy()
im_pil = Image.fromarray((im * 255).astype('uint8'))
# im = im.resize((192, 320), Image.BILINEAR)
y = self.model.predict({'image': im_pil}) # coordinates are xywh normalized
if 'confidence' in y:
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
elif len(y) == 1: # classification model
y = list(y.values())
elif len(y) == 2: # segmentation model
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
elif self.paddle: # PaddlePaddle
im = im.cpu().numpy().astype(np.float32)
self.input_handle.copy_from_cpu(im)
self.predictor.run()
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
elif self.ncnn: # ncnn
mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
ex = self.net.create_extractor()
input_names, output_names = self.net.input_names(), self.net.output_names()
ex.input(input_names[0], mat_in)
y = []
for output_name in output_names:
mat_out = self.pyncnn.Mat()
ex.extract(output_name, mat_out)
y.append(np.array(mat_out)[None])
elif self.triton: # NVIDIA Triton Inference Server
y = self.model(im)
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.cpu().numpy()
if self.saved_model: # SavedModel
y = self.model(im, training=False) if self.keras else self.model(im)
if not isinstance(y, list):
y = [y]
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im))
if len(y) == 2 and len(self.names) == 999: # segments and names not defined
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
self.names = {i: f'class{i}' for i in range(nc)}
else: # Lite or Edge TPU
details = self.input_details[0]
integer = details['dtype'] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model
if integer:
scale, zero_point = details['quantization']
im = (im / scale + zero_point).astype(details['dtype']) # de-scale
self.interpreter.set_tensor(details['index'], im)
self.interpreter.invoke()
y = []
for output in self.output_details:
x = self.interpreter.get_tensor(output['index'])
if integer:
scale, zero_point = output['quantization']
x = (x.astype(np.float32) - zero_point) * scale # re-scale
if x.ndim > 2: # if task is not classification
# Denormalize xywh with input image size
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
x[:, 0] *= w
x[:, 1] *= h
x[:, 2] *= w
x[:, 3] *= h
y.append(x)
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
if len(y) == 2: # segment with (det, proto) output order reversed
if len(y[1].shape) != 4:
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
# for x in y:
# print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes
if isinstance(y, (list, tuple)):
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
else:
return self.from_numpy(y)
def from_numpy(self, x):
"""
Convert a numpy array to a tensor.
Args:
x (np.ndarray): The array to be converted.
Returns:
(torch.Tensor): The converted tensor
"""
return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
def warmup(self, imgsz=(1, 3, 640, 640)):
"""
Warm up the model by running one forward pass with a dummy input.
Args:
imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
Returns:
(None): This method runs the forward pass and don't return any value
"""
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
for _ in range(2 if self.jit else 1): #
self.forward(im) # warmup
@staticmethod
def _apply_default_class_names(data):
"""Applies default class names to an input YAML file or returns numerical class names."""
with contextlib.suppress(Exception):
return yaml_load(check_yaml(data))['names']
return {i: f'class{i}' for i in range(999)} # return default if above errors
@staticmethod
def _model_type(p='path/to/model.pt'):
"""
This function takes a path to a model file and returns the model type
Args:
p: path to the model file. Defaults to path/to/model.pt
"""
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
from ultralytics.engine.exporter import export_formats
sf = list(export_formats().Suffix) # export suffixes
if not is_url(p, check=False) and not isinstance(p, str):
check_suffix(p, sf) # checks
url = urlparse(p) # if url may be Triton inference server
types = [s in Path(p).name for s in sf]
types[8] &= not types[9] # tflite &= not edgetpu
triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
return types + [triton]
interpreter = Interpreter(model_path=w) # 加载TensorFlow Lite模型
interpreter.allocate_tensors() # 为模型分配张量内存
input_details = interpreter.get_input_details() # 获取输入张量的详细信息
output_details

@ -25,18 +25,17 @@ except ImportError:
class BaseModel(nn.Module):
"""
The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.
"""
#lxy
def forward(self, x, *args, **kwargs):
"""
Forward pass of the model on a single scale.
Wrapper for `_forward_once` method.
# lxy
Args:
x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
# lxy
Returns:
(torch.Tensor): The output of the network.
"""
@ -47,7 +46,7 @@ class BaseModel(nn.Module):
def predict(self, x, profile=False, visualize=False, augment=False):
"""
Perform a forward pass through the network.
# lxy
Args:
x (torch.Tensor): The input tensor to the model.
profile (bool): Print the computation time of each layer if True, defaults to False.
@ -64,12 +63,12 @@ class BaseModel(nn.Module):
def _predict_once(self, x, profile=False, visualize=False):
"""
Perform a forward pass through the network.
# lxy
Args:
x (torch.Tensor): The input tensor to the model.
profile (bool): Print the computation time of each layer if True, defaults to False.
visualize (bool): Save the feature maps of the model if True, defaults to False.
#lxy
Returns:
(torch.Tensor): The last output of the model.
"""
@ -85,25 +84,27 @@ class BaseModel(nn.Module):
feature_visualization(x, m.type, m.i, save_dir=visualize)
return x
# lxy
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference."""
LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. '
f'Reverting to single-scale inference instead.')
return self._predict_once(x)
# lxy
def _profile_one_layer(self, m, x, dt):
"""
Profile the computation time and FLOPs of a single layer of the model on a given input.
Appends the results to the provided list.
#lxy
Args:
m (nn.Module): The layer to be profiled.
x (torch.Tensor): The input data to the layer.
dt (list): A list to store the computation time of the layer.
#lxy
Returns:
None
"""
""
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
t = time_sync()
@ -115,7 +116,7 @@ class BaseModel(nn.Module):
LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}')
if c:
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
#lxy
def fuse(self, verbose=True):
"""
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
@ -149,7 +150,7 @@ class BaseModel(nn.Module):
Args:
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
#lxy
Returns:
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
"""
@ -165,7 +166,7 @@ class BaseModel(nn.Module):
imgsz (int): the size of the image that the model will be trained on. Defaults to 640
"""
return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
def _apply(self, fn):
"""
`_apply()` is a function that applies a function to all the tensors in the model that are not

@ -5,67 +5,131 @@ from collections import OrderedDict
import numpy as np
# TrackState类是一个枚举类型虽然Python中没有严格的枚举语法但通过类来模拟用于定义目标跟踪过程中可能出现的不同状态。
class TrackState:
"""Enumeration of possible object tracking states."""
"""
Enumeration of possible object tracking states.
这个类的作用是清晰地列举出目标跟踪时对象可能处于的各种状态方便在跟踪算法的不同阶段对目标状态进行判断和处理
"""
# 表示新出现的目标,即刚刚进入跟踪视野,还未经过后续稳定跟踪等处理的初始状态。
New = 0
# 表示目标正在被稳定跟踪,意味着当前时刻能够获取到该目标较为准确的位置、特征等信息,跟踪算法正常作用于该目标。
Tracked = 1
# 表示目标暂时失去了跟踪线索,可能是由于遮挡、目标运动过快等原因导致当前帧难以准确获取其相关信息,但仍有可能后续再次恢复跟踪。
Lost = 2
# 表示目标已经从跟踪场景中移除,例如目标离开了监控范围或者确定不再需要对其进行跟踪了,通常是一种最终状态。
Removed = 3
# BaseTrack类是目标跟踪的基类它定义了一系列与目标跟踪相关的基本属性以及通用的操作方法
# 具体的跟踪算法类可以继承自这个基类,并根据自身需求重写相应的方法来实现特定的跟踪逻辑。
class BaseTrack:
"""Base class for object tracking, handling basic track attributes and operations."""
"""
Base class for object tracking, handling basic track attributes and operations.
此类作为基础框架旨在提供统一的结构便于不同的跟踪实现共享一些通用的属性和行为模式提高代码的复用性和可维护性
"""
# 用于记录全局的跟踪ID计数初始化为0每创建一个新的跟踪目标时会自增为每个跟踪目标分配唯一的ID。
_count = 0
# 每个跟踪目标的唯一标识符用于区分不同的目标初始值为0后续通过特定方法来分配具体的唯一ID。
track_id = 0
# 表示该跟踪目标是否处于激活状态激活状态意味着当前目标正在被有效跟踪相关信息正在实时更新等初始为False。
is_activated = False
# 目标当前所处的跟踪状态初始化为TrackState.New表示刚出现的新目标状态后续会根据跟踪情况进行相应改变。
state = TrackState.New
# 一个有序字典,用于存储目标在不同帧中的历史信息,例如历史位置、特征等,方便后续回顾和分析目标的运动轨迹等情况。
history = OrderedDict()
# 存储目标的特征信息列表,这些特征可以是用于识别目标、与其他目标区分开来的各种描述符,比如外观特征等。
features = []
# 当前帧提取到的目标特征,用于和历史特征结合或者单独进行一些与当前帧相关的处理,比如匹配、更新等操作。
curr_feature = None
# 目标的置信度得分用于衡量目标被正确跟踪或者识别的可能性大小得分越高表示越可靠初始值为0。
score = 0
# 目标首次出现的帧编号,用于确定目标进入跟踪视野的起始时间点,便于后续统计目标跟踪的时长等信息。
start_frame = 0
# 当前帧的编号,随着跟踪过程中每一帧的处理而更新,用于同步目标在不同时间点的状态和信息。
frame_id = 0
# 记录自上次成功更新目标信息以来经过的帧数,若该值过大,可能意味着目标即将或已经处于丢失状态,可用于判断是否需要标记为丢失等操作。
time_since_update = 0
# Multi-camera
# 用于多相机场景下,表示目标所在的位置信息,初始化为两个无穷大值组成的元组,具体的坐标含义和赋值会根据多相机系统的设定来确定。
# 例如可能表示目标在某个全局坐标系下的坐标位置等情况。
location = (np.inf, np.inf)
@property
def end_frame(self):
"""Return the last frame ID of the track."""
"""
Return the last frame ID of the track.
此属性方法用于获取该跟踪目标最后出现的帧编号也就是跟踪结束时所在的帧方便统计跟踪目标的生命周期查看跟踪过程覆盖的帧数范围等
"""
return self.frame_id
@staticmethod
def next_id():
"""Increment and return the global track ID counter."""
"""
Increment and return the global track ID counter.
这是一个静态方法用于递增全局的跟踪ID计数器并返回递增后的ID值确保每个新创建的跟踪目标都能获取到唯一的ID方便在整个跟踪系统中对不同目标进行标识和区分
"""
BaseTrack._count += 1
return BaseTrack._count
def activate(self, *args):
"""Activate the track with the provided arguments."""
"""
Activate the track with the provided arguments.
此方法用于激活跟踪目标根据传入的参数来初始化或更新目标的相关信息使其进入激活状态开始正常的跟踪流程
具体的激活操作内容需要由继承自该基类的子类根据实际跟踪算法需求来重写实现这里只是定义了接口抛出未实现异常
"""
raise NotImplementedError
def predict(self):
"""Predict the next state of the track."""
"""
Predict the next state of the track.
用于预测跟踪目标在下一帧的状态例如预测目标的位置速度等信息不同的跟踪算法可能有不同的预测策略和依据
同样此方法需要子类重写来实现具体的预测逻辑这里仅定义了抽象的方法结构抛出未实现异常
"""
raise NotImplementedError
def update(self, *args, **kwargs):
"""Update the track with new observations."""
"""
Update the track with new observations.
依据新获取到的观测信息比如新的位置检测结果新的特征信息等来更新跟踪目标的各项属性使跟踪信息保持最新和准确
该方法的具体实现依赖于具体的跟踪算法所以在基类中只是定义了接口由子类根据实际情况重写具体的更新操作逻辑此处抛出未实现异常
"""
raise NotImplementedError
def mark_lost(self):
"""Mark the track as lost."""
"""
Mark the track as lost.
将跟踪目标标记为丢失状态意味着当前目标暂时失去了有效的跟踪线索通过修改目标的状态属性为TrackState.Lost来实现这一标记操作
后续可以根据此状态采取相应的处理策略比如尝试重新搜索目标等待一段时间后删除该目标等
"""
self.state = TrackState.Lost
def mark_removed(self):
"""Mark the track as removed."""
"""
Mark the track as removed.
把跟踪目标标记为已移除状态表明该目标已经彻底从跟踪场景中去除不再参与后续的跟踪流程通过将状态属性设置为TrackState.Removed来进行标记
通常在确定目标不会再出现或者已经完成对其跟踪需求后进行此操作
"""
self.state = TrackState.Removed
@staticmethod
def reset_id():
"""Reset the global track ID counter."""
BaseTrack._count = 0
"""
Reset the global track ID counter.
静态方法用于重置全局的跟踪ID计数器将其重新初始化为0一般在需要重新开始跟踪任务或者清空之前的跟踪记录重新分配ID等场景下使用
"""
BaseTrack._count = 0

@ -4,63 +4,139 @@ from collections import deque
import numpy as np
from .basetrack import TrackState
from .byte_tracker import BYTETracker, STrack
from .utils import matching
from .utils.gmc import GMC
from .utils.kalman_filter import KalmanFilterXYWH
# 从相关模块中导入必要的类和函数TrackState用于表示跟踪状态BYTETracker、STrack是与目标跟踪相关的基础类
# matching模块可能包含一些用于匹配的函数GMC可能是与全局运动补偿相关的类KalmanFilterXYWH用于卡尔曼滤波操作。
from.basetrack import TrackState
from.byte_tracker import BYTETracker, STrack
from.utils import matching
from.utils.gmc import GMC
from.utils.kalman_filter import KalmanFilterXYWH
# BOTrack类继承自STrack类主要用于在目标跟踪过程中对YOLOv8检测到的目标进行更具体的跟踪相关处理
# 例如特征更新、预测、坐标转换等操作,融入了一些特定于该算法的逻辑和属性。
class BOTrack(STrack):
# 定义一个共享的卡尔曼滤波器实例用于多个目标跟踪时的状态预测等操作所有BOTrack实例可以共用这个滤波器节省资源并保证一致性。
shared_kalman = KalmanFilterXYWH()
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
"""
Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features.
函数功能
初始化BOTrack类的实例用于创建一个可跟踪的目标对象除了调用父类的初始化方法外还初始化了一些与特征处理平滑相关的属性
比如特征历史队列特征平滑相关的参数等用于后续对目标特征的管理和更新操作
参数说明
tlwh目标的边界框坐标信息格式为 (top left x, top left y, width, height)表示目标在图像中的位置和大小范围
score目标的置信度得分用于衡量该目标被正确检测到的可能性大小取值范围通常在0到1之间得分越高表示越可靠
cls目标所属的类别用于标识目标的类别信息例如是行人车辆等不同类型的物体
feat可选目标的特征向量用于描述目标的外观纹理等特征信息可用于后续的目标匹配识别等操作如果没有提供则默认为None
feat_history可选特征历史队列的最大长度用于限制保存的历史特征数量避免内存占用过多默认值为50表示最多保存50个历史特征
"""
super().__init__(tlwh, score, cls)
# 用于保存平滑后的特征向量初始化为None后续会根据特征更新情况进行赋值和更新用于获取更稳定、抗噪性更好的目标特征表示。
self.smooth_feat = None
# 当前帧获取到的目标特征向量,用于和历史特征结合更新平滑特征,或者单独用于一些即时的处理,如与其他目标特征进行匹配等操作。
self.curr_feat = None
if feat is not None:
self.update_features(feat)
# 创建一个双端队列用于存储目标的历史特征向量设置了最大长度为feat_history新特征加入时如果队列已满会自动删除最早的特征。
self.features = deque([], maxlen=feat_history)
# 特征平滑的权重参数用于在更新平滑特征时确定当前特征和历史平滑特征的融合比例取值范围在0到1之间这里初始化为0.9。
self.alpha = 0.9
def update_features(self, feat):
"""Update features vector and smooth it using exponential moving average."""
"""
Update features vector and smooth it using exponential moving average.
函数功能
更新目标的特征向量并使用指数移动平均Exponential Moving AverageEMA的方法对特征进行平滑处理
使得特征在时间序列上更加稳定减少噪声和异常值的影响同时将更新后的特征添加到历史特征队列中
参数说明
feat新获取到的目标特征向量用于更新当前特征和平滑特征并添加到历史特征队列中
"""
# 对传入的特征向量进行归一化处理,使其具有单位长度,方便后续在特征空间中的比较、计算等操作,例如计算相似度等。
feat /= np.linalg.norm(feat)
self.curr_feat = feat
if self.smooth_feat is None:
self.smooth_feat = feat
else:
# 使用指数移动平均的方式更新平滑特征根据设定的权重参数alpha将当前特征和历史平滑特征进行加权融合
# 使得平滑特征能够逐渐适应目标特征的变化,同时又能保持一定的稳定性,减少突发噪声等因素的干扰。
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
self.features.append(feat)
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
def predict(self):
"""Predicts the mean and covariance using Kalman filter."""
"""
Predicts the mean and covariance using Kalman filter.
函数功能
使用卡尔曼滤波器Kalman filter对目标的状态这里主要是位置速度等相关信息通过均值和协方差来表示进行预测
得到目标在下一时刻的预期状态以便后续根据实际观测进行更新和修正
具体操作
先复制当前的均值状态mean然后根据目标当前的跟踪状态判断是否处于Tracked状态如果不是Tracked状态比如可能是丢失或新出现等情况
则将与速度相关的维度这里假设是第6和第7维度具体取决于卡尔曼滤波器的状态向量定义设置为0意味着速度为0
最后通过共享的卡尔曼滤波器实例self.kalman_filter实际上就是类属性shared_kalman进行预测更新目标的均值和协方差状态
"""
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
if self.state!= TrackState.Tracked:
mean_state[6] = 0
mean_state[7] = 0
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
def re_activate(self, new_track, frame_id, new_id=False):
"""Reactivates a track with updated features and optionally assigns a new ID."""
"""
Reactivates a track with updated features and optionally assigns a new ID.
函数功能
重新激活一个跟踪目标当目标曾经丢失后又重新出现等情况时调用此方法会根据新的跟踪信息这里是new_track所包含的信息更新目标的特征
并调用父类的re_activate方法完成其他与重新激活相关的通用操作同时可以选择是否为目标分配一个新的ID由new_id参数决定
参数说明
new_track包含了目标新的跟踪信息的对象例如新的位置特征等信息用于更新当前目标的相关属性
frame_id当前帧的编号用于记录目标重新激活所在的时间点方便后续统计跟踪的时间信息等
new_id可选布尔值用于决定是否为目标分配一个新的唯一标识符ID默认为False表示不分配新ID继续使用原来的ID
"""
if new_track.curr_feat is not None:
self.update_features(new_track.curr_feat)
super().re_activate(new_track, frame_id, new_id)
def update(self, new_track, frame_id):
"""Update the YOLOv8 instance with new track and frame ID."""
"""
Update the YOLOv8 instance with new track and frame ID.
函数功能
根据新的跟踪信息new_track和当前帧编号frame_id来更新目标的相关属性特别是利用新的特征信息更新目标的特征相关属性
并调用父类的update方法完成其他通用的更新操作确保目标的跟踪信息始终保持最新状态
参数说明
new_track包含目标最新跟踪信息的对象例如新的位置特征等用于更新当前目标的对应属性
frame_id当前帧的编号用于记录目标更新信息所在的时间点便于后续分析目标的跟踪轨迹等情况
"""
if new_track.curr_feat is not None:
self.update_features(new_track.curr_feat)
super().update(new_track, frame_id)
@property
def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y,
"""
Get current position in bounding box format `(top left x, top left y,
width, height)`.
函数功能
获取目标当前位置的边界框坐标信息格式为 (top left x, top left y, width, height)
如果均值状态self.mean为空可能在某些初始化或者尚未更新状态的情况下则返回初始的边界框坐标self._tlwh的副本
否则根据均值状态计算并返回对应的边界框坐标信息计算方式是将中心坐标均值的前两个维度减去宽度和高度的一半得到左上角坐标
与宽度高度一起组成边界框坐标信息返回
返回值
目标当前位置的边界框坐标信息格式为 (top left x, top left y, width, height)
"""
if self.mean is None:
return self._tlwh.copy()
@ -70,13 +146,28 @@ class BOTrack(STrack):
@staticmethod
def multi_predict(stracks):
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
"""
Predicts the mean and covariance of multiple object tracks using shared Kalman filter.
函数功能
对多个目标跟踪对象stracks的状态均值和协方差同时进行预测使用共享的卡尔曼滤波器BOTrack.shared_kalman来处理
方便在批量处理多个目标跟踪时统一进行状态预测操作提高效率并且保证所有目标在相同的滤波规则下进行预测
参数说明
stracks包含多个目标跟踪对象的列表每个对象都有自己的状态信息如均值协方差等需要进行下一时刻的状态预测操作
具体操作
首先判断传入的目标跟踪对象列表长度是否大于0如果为空则直接返回不进行预测操作然后获取所有目标跟踪对象的均值状态数组multi_mean
和协方差数组multi_covariance接着遍历每个目标跟踪对象根据其跟踪状态判断是否处于Tracked状态如果不是则将其均值状态中与速度相关的维度
同样假设是第6和第7维度具体取决于卡尔曼滤波器的状态向量定义设置为0表示速度为0最后使用共享的卡尔曼滤波器对所有目标的均值和协方差进行批量预测
并将更新后的均值和协方差分别赋值回对应的目标跟踪对象中完成多个目标的状态预测操作
"""
if len(stracks) <= 0:
return
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
for i, st in enumerate(stracks):
if st.state != TrackState.Tracked:
if st.state!= TrackState.Tracked:
multi_mean[i][6] = 0
multi_mean[i][7] = 0
multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
@ -85,25 +176,60 @@ class BOTrack(STrack):
stracks[i].covariance = cov
def convert_coords(self, tlwh):
"""Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
"""
Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format.
函数功能
将给定的边界框坐标格式从左上角宽高形式Top-Left-Width-Heighttlwh转换为中心坐标宽高形式X-Y-Width-Height
以便在不同的计算处理场景下使用合适的坐标表示形式例如某些算法可能更适合使用中心坐标形式进行操作
参数说明
tlwh边界框坐标信息格式为 (top left x, top left y, width, height)需要进行坐标格式转换
返回值
转换后的边界框坐标信息格式为 (center x, center y, width, height)即中心坐标宽高形式
"""
return self.tlwh_to_xywh(tlwh)
@staticmethod
def tlwh_to_xywh(tlwh):
"""Convert bounding box to format `(center x, center y, width,
"""
Convert bounding box to format `(center x, center y, width,
height)`.
函数功能
静态方法用于将边界框坐标从左上角宽高形式tlwh转换为中心坐标宽高形式xywh通过将左上角坐标加上宽高的一半得到中心坐标
与原来的宽高信息一起组成新的坐标表示形式返回方便在不同的目标处理计算场景下统一坐标格式
参数说明
tlwh边界框坐标信息格式为 (top left x, top left y, width, height)需要进行坐标格式转换
返回值
转换后的边界框坐标信息格式为 (center x, center y, width, height)即中心坐标宽高形式
"""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
return ret
# BOTSORT类继承自BYTETracker类主要用于实现基于特定参数和算法的目标跟踪功能集成了如ReID模块虽然目前部分功能尚未支持
# 全局运动补偿GMC算法等用于处理目标跟踪过程中的不同需求例如目标匹配、状态更新等操作。
class BOTSORT(BYTETracker):
def __init__(self, args, frame_rate=30):
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
"""
Initialize YOLOv8 object with ReID module and GMC algorithm.
函数功能
初始化BOTSORT类的实例除了调用父类的初始化方法外还根据传入的参数args初始化与目标跟踪相关的一些特定属性
比如距离阈值proximity_threshappearance_thresh以及创建全局运动补偿GMC算法的实例等用于后续目标跟踪过程中的相关处理操作
参数说明
args包含各种配置参数的对象这些参数会影响目标跟踪的具体行为和效果例如阈值设置是否启用某些功能等具体参数内容由外部传入决定
frame_rate可选帧率用于表示视频等数据的帧率情况默认值为30表示每秒30帧可用于一些与时间速度相关的计算或者算法调整等操作虽然此处代码中未明确体现具体使用情况
"""
super().__init__(args, frame_rate)
# ReID module
# ReID模块相关的距离阈值,用于判断目标之间在外观特征等方面的接近程度,在目标匹配等操作中会用到,具体含义和使用方式取决于具体算法逻辑。
self.proximity_thresh = args.proximity_thresh
self.appearance_thresh = args.appearance_thresh
@ -114,35 +240,37 @@ class BOTSORT(BYTETracker):
self.gmc = GMC(method=args.cmc_method)
def get_kalmanfilter(self):
"""Returns an instance of KalmanFilterXYWH for object tracking."""
"""
Returns an instance of KalmanFilterXYWH for object tracking.
函数功能
返回一个用于目标跟踪的卡尔曼滤波器KalmanFilterXYWH实例方便在目标跟踪过程中进行状态预测等操作
外部可以通过调用此方法获取滤波器实例来进行相关的跟踪计算
返回值
一个KalmanFilterXYWH类的实例用于目标跟踪中的状态预测等操作
"""
return KalmanFilterXYWH()
def init_track(self, dets, scores, cls, img=None):
"""Initialize track with detections, scores, and classes."""
if len(dets) == 0:
return []
if self.args.with_reid and self.encoder is not None:
features_keep = self.encoder.inference(img, dets)
return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections
else:
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
"""
Initialize track with detections, scores, and classes.
def get_dists(self, tracks, detections):
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
dists = matching.iou_distance(tracks, detections)
dists_mask = (dists > self.proximity_thresh)
函数功能
根据检测到的目标信息边界框坐标置信度得分类别初始化跟踪目标在启用ReID基于外观特征的目标重识别模块的情况下
还会利用编码器encoder提取目标的特征信息用于后续跟踪最终返回初始化后的跟踪目标对象列表
# TODO: mot20
# if not self.args.mot20:
dists = matching.fuse_score(dists, detections)
参数说明
dets检测到的目标边界框坐标信息列表每个元素表示一个目标的边界框坐标格式可能根据具体检测算法而定例如常见的xyxy或者tlwh等格式
scores对应检测目标的置信度得分列表与dets列表中的目标一一对应用于衡量每个检测结果的可靠性
cls检测目标所属的类别列表同样与dets列表中的目标一一对应用于标识每个目标的类别属性
img可选图像数据在启用ReID模块且编码器可用的情况下用于提取目标的外观特征信息如果不需要提取特征或者未启用相应功能则可以为None
返回值
包含初始化后的跟踪目标对象的列表每个对象都包含了目标的边界框坐标置信度得分类别以及如果有外观特征等跟踪相关信息用于后续的目标跟踪操作
"""
if len(dets) == 0:
return []
if self.args.with_reid and self.encoder is not None:
emb_dists = matching.embedding_distance(tracks, detections) / 2.0
emb_dists[emb_dists > self.appearance_thresh] = 1.0
emb_dists[dists_mask] = 1.0
dists = np.minimum(dists, emb_dists)
return dists
def multi_predict(self, tracks):
"""Predict and track multiple objects with YOLOv8 model."""
BOTrack.multi_predict(tracks)
features_keep = self.encoder.inference(img, dets)
return [BOTrack(xyxy, s, c, f) for

@ -4,12 +4,16 @@ from functools import partial
import torch
# 导入用于创建简单命名空间以及加载YAML文件的工具类和函数方便对配置参数等进行处理和管理
from ultralytics.utils import IterableSimpleNamespace, yaml_load
# 导入用于检查YAML文件相关情况的函数例如检查文件是否存在、格式是否正确等
from ultralytics.utils.checks import check_yaml
from .bot_sort import BOTSORT
from .byte_tracker import BYTETracker
# 从自定义的模块中导入目标跟踪相关的类BOTSORT和BYTETracker是不同的目标跟踪算法实现类用于后续根据配置选择具体的跟踪器
from.bot_sort import BOTSORT
from.byte_tracker import BYTETracker
# 创建一个字典,用于将跟踪器名称(字符串形式)映射到对应的跟踪器类,方便后续根据配置中的名称来实例化相应的跟踪器对象
TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
@ -17,39 +21,80 @@ def on_predict_start(predictor, persist=False):
"""
Initialize trackers for object tracking during prediction.
Args:
predictor (object): The predictor object to initialize trackers for.
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
函数功能
在预测开始阶段根据配置信息初始化目标跟踪器trackers如果跟踪器已经存在且persist为True则直接返回不进行重新初始化操作
否则会检查配置文件解析配置参数根据指定的跟踪器类型创建相应的跟踪器对象列表每个对象对应一批数据中的一个样本batch中的一个元素
用于后续在预测过程中对目标进行跟踪
Raises:
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
参数说明
predictor (object)预测器对象它包含了与预测相关的各种属性和方法例如配置参数args数据集信息dataset
这里主要从它里面获取跟踪器相关的配置以及数据集的批次大小等信息来初始化跟踪器
persist (bool, 可选)布尔值用于决定如果跟踪器已经存在时是否保留它们默认为False表示不保留会重新初始化跟踪器
如果为True则在跟踪器已存在的情况下直接返回不进行重新初始化操作
抛出异常
AssertionError如果配置中指定的跟踪器类型tracker_type不是'bytetrack'或者'botsort'则抛出此异常因为当前代码只支持这两种跟踪器类型
提示用户当前仅支持这两种类型的跟踪器而实际配置中给出了不支持的类型
"""
# 判断预测器对象是否已经有'trackers'属性即跟踪器是否已经存在并且persist为True如果满足条件则直接返回不进行后续初始化操作
if hasattr(predictor, 'trackers') and persist:
return
# 通过check_yaml函数检查预测器配置中指定的跟踪器配置文件tracker确保文件存在且格式正确等返回处理后的文件路径或相关信息
tracker = check_yaml(predictor.args.tracker)
# 使用yaml_load函数加载跟踪器配置文件内容并将其转换为可迭代的简单命名空间对象IterableSimpleNamespace方便通过属性访问的方式获取配置参数
cfg = IterableSimpleNamespace(**yaml_load(tracker))
# 断言配置中的跟踪器类型tracker_type必须是'bytetrack'或者'botsort',如果不是则抛出异常并提示支持的跟踪器类型以及实际得到的不支持的类型
assert cfg.tracker_type in ['bytetrack', 'botsort'], \
f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
trackers = []
# 根据预测器数据集中的批次大小bs进行循环为每个批次中的样本创建对应的跟踪器对象
for _ in range(predictor.dataset.bs):
# 根据配置中指定的跟踪器类型cfg.tracker_type从TRACKER_MAP字典中获取对应的跟踪器类然后使用配置参数args=cfg和帧率frame_rate=30实例化跟踪器对象
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
trackers.append(tracker)
# 将创建好的跟踪器对象列表赋值给预测器的'trackers'属性,以便后续在预测过程中使用这些跟踪器进行目标跟踪
predictor.trackers = trackers
def on_predict_postprocess_end(predictor):
"""Postprocess detected boxes and update with object tracking."""
"""
Postprocess detected boxes and update with object tracking.
函数功能
在预测后处理结束阶段对检测到的目标框boxes进行后处理操作并结合目标跟踪的结果更新相关信息
具体包括获取每个样本的检测框信息调用对应的跟踪器更新跟踪状态根据跟踪结果更新预测结果中的目标框等操作
使得最终的预测结果能够融合目标跟踪的信息提供更连续准确的目标状态描述
参数说明
predictor (object)预测器对象包含了预测过程中的各种中间结果以及相关属性例如数据集的批次大小bs当前批次的数据信息batch
预测结果results用于获取检测框信息图像数据以及更新预测结果等操作
"""
# 获取预测器数据集中的批次大小,即当前一次预测处理的样本数量,用于后续循环遍历每个样本进行相应操作
bs = predictor.dataset.bs
# 获取当前批次中的原始图像数据im0s这些图像数据可能是未经模型预处理的原始图像用于在目标跟踪过程中结合图像信息进行相关处理例如特征提取等具体取决于跟踪算法
im0s = predictor.batch[1]
# 遍历批次中的每个样本,对每个样本的检测框和跟踪信息进行处理
for i in range(bs):
# 获取当前样本索引为i的检测框信息将其从GPU张量转换为CPU上的numpy数组方便后续在Python环境下进行常规的数组操作和处理
# 这里的boxes属性通常包含了目标的位置、大小以及可能的其他相关信息例如置信度等具体取决于模型输出结构
det = predictor.results[i].boxes.cpu().numpy()
# 如果当前样本中没有检测到目标即检测框数量为0则跳过后续针对该样本的处理操作继续处理下一个样本
if len(det) == 0:
continue
# 调用当前样本对应的跟踪器predictor.trackers[i]的update方法传入检测框信息det和对应的原始图像im0s[i]
# 该方法会根据检测框和图像信息更新目标跟踪状态返回更新后的跟踪结果tracks跟踪结果可能包含目标的跟踪ID、更新后的位置等信息具体取决于跟踪器实现
tracks = predictor.trackers[i].update(det, im0s[i])
# 如果跟踪结果中没有有效的跟踪目标即跟踪目标数量为0则跳过后续针对该样本跟踪结果的处理操作继续处理下一个样本
if len(tracks) == 0:
continue
# 从跟踪结果tracks中提取目标的跟踪ID信息假设最后一维是跟踪ID并将其转换为整数类型的索引数组
# 用于后续根据跟踪ID筛选和更新预测结果中的目标框等信息确保预测结果与跟踪结果的一致性
idx = tracks[:, -1].astype(int)
# 根据跟踪ID索引数组idx筛选当前样本的预测结果predictor.results[i]),只保留跟踪到的目标对应的预测结果,
# 相当于将预测结果与跟踪结果进行关联和同步,去除那些没有被跟踪到的目标预测信息
predictor.results[i] = predictor.results[i][idx]
# 将更新后的跟踪结果中的目标框信息去除最后一维的跟踪ID转换为torch张量并更新到预测结果的boxes属性中
# 使得预测结果中的目标框信息能够反映最新的跟踪状态,完成预测结果与跟踪结果的融合更新操作
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
@ -57,10 +102,19 @@ def register_tracker(model, persist):
"""
Register tracking callbacks to the model for object tracking during prediction.
Args:
model (object): The model object to register tracking callbacks for.
persist (bool): Whether to persist the trackers if they already exist.
函数功能
将目标跟踪相关的回调函数callbacks注册到模型对象上使得在预测过程中的特定阶段如预测开始预测后处理结束能够自动触发相应的跟踪操作
通过使用partial函数对回调函数进行部分参数绑定方便在模型的回调机制中进行注册和调用从而实现目标跟踪功能与预测流程的集成
参数说明
model (object)模型对象通常包含了模型结构配置参数以及一些用于控制模型行为如注册回调函数的方法
这里将目标跟踪的回调函数注册到该模型上使其在预测过程中能够执行跟踪相关操作
persist (bool)布尔值用于决定在跟踪器已经存在的情况下是否保留它们传递给on_predict_start回调函数影响跟踪器的初始化行为
"""
# 使用model对象的add_callback方法将'on_predict_start'阶段的回调函数注册到模型上通过partial函数将on_predict_start函数的persist参数绑定为传入的persist值
# 这样在预测开始阶段模型会自动调用on_predict_start函数进行跟踪器的初始化等相关操作具体操作取决于on_predict_start函数的实现逻辑
model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
# 同样使用add_callback方法将'on_predict_postprocess_end'阶段的回调函数注册到模型上,
# 使得在预测后处理结束时模型会自动调用on_predict_postprocess_end函数进行检测框后处理和跟踪结果更新等操作具体操作见on_predict_postprocess_end函数的实现逻辑
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)

@ -8,26 +8,53 @@ import sys
import tempfile
from pathlib import Path
from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9
# 从当前包(模块)中导入用户配置目录相关的变量,该目录可能用于存放用户相关的配置文件等内容
from. import USER_CONFIG_DIR
# 从相关模块中导入与PyTorch版本相关的常量可能用于根据不同的PyTorch版本执行不同的逻辑比如分布式训练相关操作
from.torch_utils import TORCH_1_9
def find_free_network_port() -> int:
"""Finds a free port on localhost.
"""
Finds a free port on localhost.
函数功能
此函数的目的是在本地主机localhost 127.0.0.1上查找一个当前未被使用的网络端口
在单节点训练场景中当不想连接到真实的主节点但又需要设置 `MASTER_PORT` 环境变量时
这个函数就很有用因为它能提供一个可用的端口号供分布式训练相关配置使用
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.
返回值
返回找到的空闲端口号类型为整数可用于后续的网络通信相关配置例如分布式训练中主节点监听的端口等设置
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
# 将套接字绑定到本地主机的任意可用端口通过传入端口号为0来实现这样系统会自动分配一个当前未被使用的端口给该套接字
s.bind(('127.0.0.1', 0))
# 获取套接字绑定的地址信息包含IP地址和端口号并返回其中的端口号部分即找到了一个空闲的网络端口号
return s.getsockname()[1] # port
def generate_ddp_file(trainer):
"""Generates a DDP file and returns its file name."""
"""
Generates a DDP file and returns its file name.
函数功能
该函数用于生成一个与分布式数据并行Distributed Data ParallelDDP训练相关的Python文件
并返回这个文件的文件名文件内容主要是配置相关的代码逻辑用于后续启动分布式训练
里面会导入训练器trainer对应的类并基于一些默认配置和传入的覆盖配置overrides来实例化训练器并执行训练操作
参数说明
trainer训练器对象通常包含了训练相关的各种配置参数模型训练逻辑等信息从这个对象中可以获取到类相关信息用于在生成的文件中导入相应的类
以及其配置参数用于构建文件内的训练器实例化逻辑等
返回值
返回生成的DDP文件的文件名后续可以基于这个文件名去执行相应的分布式训练操作文件名是一个字符串类型的值
"""
# 获取训练器类的完整模块路径和类名然后通过rsplit方法从右侧按 '.' 分割一次得到模块名module和类名name
# 这样方便后续在生成的文件中准确导入对应的训练器类进行实例化操作
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
content = f'''overrides = {vars(trainer.args)}
if __name__ == "__main__":
from {module} import {name}
from ultralytics.utils import DEFAULT_CFG_DICT
@ -35,6 +62,8 @@ def generate_ddp_file(trainer):
cfg.update(save_dir='') # handle the extra key 'save_dir'
trainer = {name}(cfg=cfg, overrides=overrides)
trainer.train()'''
# 创建名为 'DDP' 的目录如果不存在的话用于存放生成的临时DDP文件这个目录位于用户配置目录USER_CONFIG_DIR
# 按照一定的目录结构来组织与分布式训练相关的临时文件等内容
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(prefix='_temp_',
suffix=f'{id(trainer)}.py',
@ -42,26 +71,60 @@ def generate_ddp_file(trainer):
encoding='utf-8',
dir=USER_CONFIG_DIR / 'DDP',
delete=False) as file:
# 将上面构建好的内容写入到临时文件中该临时文件就是后续分布式训练要执行的Python脚本文件
# 里面定义了训练器实例化和训练启动的相关逻辑,基于传入的训练器对象的配置参数等信息来构建
file.write(content)
return file.name
def generate_ddp_command(world_size, trainer):
"""Generates and returns command for distributed training."""
"""
Generates and returns command for distributed training.
函数功能
生成用于启动分布式训练的命令行命令并返回该命令以及对应的用于分布式训练的Python文件路径
根据不同的PyTorch版本选择不同的分布式训练启动工具 `torch.distributed.run` `torch.distributed.launch`
同时会查找一个空闲的网络端口结合其他相关参数如进程数 `world_size`训练文件路径等构建完整的命令行命令
参数说明
world_size表示分布式训练中总的进程数量用于指定启动多少个并行的训练进程来进行分布式训练通常根据计算资源训练规模等因素来确定
trainer训练器对象用于获取训练相关的文件路径如果不是常规的Python脚本文件启动方式则生成对应的临时文件
以及根据其是否是恢复训练resume属性来决定是否删除之前的保存目录等操作与分布式训练的一些前置准备工作相关
返回值
返回一个包含两个元素的元组第一个元素是构建好的用于启动分布式训练的命令行命令列表第二个元素是对应的用于分布式训练的Python文件路径
命令列表中的每个元素是命令行中的一个参数文件路径是一个字符串类型的值方便后续在系统中执行命令并基于相应文件进行分布式训练操作
"""
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
if not trainer.resume:
# 如果不是恢复训练resume为False则删除之前训练保存的目录trainer.save_dir可能是为了重新开始训练避免旧数据干扰等情况
shutil.rmtree(trainer.save_dir) # remove the save_dir
file = str(Path(sys.argv[0]).resolve())
# 定义一个正则表达式对象,用于匹配文件名的合法性,限制文件名只能由特定的字符(字母、数字、下划线、点、空格、斜杠、反斜杠、减号)组成,
# 并且长度最多为100个字符用于后续判断传入的训练文件是否符合要求
safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 100 characters
if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI
# 如果传入的文件不符合上述合法文件名的要求可能是通过命令行接口传入的其他形式的启动方式则调用generate_ddp_file函数生成一个临时的DDP训练文件
file = generate_ddp_file(trainer)
# 根据是否是PyTorch 1.9版本来选择不同的分布式训练启动工具TORCH_1_9是一个标识变量用于区分不同的PyTorch版本情况
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
port = find_free_network_port()
# 构建用于启动分布式训练的命令行命令列表包括Python解释器路径、使用的分布式训练启动模块、每个节点的进程数量、主节点端口号以及训练文件路径等参数
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
return cmd, file
def ddp_cleanup(trainer, file):
"""Delete temp file if created."""
"""
Delete temp file if created.
函数功能
根据传入的训练器对象和文件路径判断文件是否是由该训练器生成的临时文件通过文件名中是否包含训练器的唯一标识符来判断
如果是则删除这个临时文件用于在分布式训练结束等情况下清理临时生成的文件避免占用过多磁盘空间等情况
参数说明
trainer训练器对象用于获取其唯一标识符id与文件路径中的后缀进行对比判断文件是否是与之相关的临时文件
file文件路径字符串代表要检查和可能删除的文件通过判断其文件名中是否包含训练器的标识符来确定是否是临时文件并进行删除操作
"""
if f'{id(trainer)}.py' in file: # if temp_file suffix in file
os.remove(file)
os.remove(file)

@ -10,23 +10,56 @@ from datetime import datetime
from pathlib import Path
# 定义了一个名为WorkingDirectory的类它继承自contextlib.ContextDecorator
# 可以作为装饰器(@WorkingDirectory(dir)或者上下文管理器with WorkingDirectory(dir):)使用,
# 目的是方便地切换当前工作目录,并在操作结束后恢复到原来的工作目录。
class WorkingDirectory(contextlib.ContextDecorator):
"""Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
"""
Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.
此类的主要功能是在进入特定代码块通过装饰器或上下文管理器的方式将当前工作目录切换到指定的新目录
当代码块执行完毕退出时再自动将工作目录恢复为原来的目录这样可以确保在不同代码块内对文件系统操作的相对路径不产生混乱
保持代码在不同环境下操作文件的稳定性和可预测性
"""
def __init__(self, new_dir):
"""Sets the working directory to 'new_dir' upon instantiation."""
self.dir = new_dir # new dir
self.cwd = Path.cwd().resolve() # current dir
"""
Sets the working directory to 'new_dir' upon instantiation.
参数说明
new_dir要切换到的新的工作目录路径可以是字符串或者Path对象类型代表希望进入的目标工作目录
在实例化时会记录下这个新的目标目录路径self.dir同时获取并记录当前所在的工作目录路径self.cwd
以便后续在退出代码块时能够恢复到原来的工作目录
"""
self.dir = new_dir # 新的工作目录路径
self.cwd = Path.cwd().resolve() # 当前工作目录路径通过Path.cwd().resolve()获取并确保是绝对路径
def __enter__(self):
"""Changes the current directory to the specified directory."""
"""
Changes the current directory to the specified directory.
当进入由该类作为上下文管理器或者装饰器包裹的代码块时此方法会被调用
它会使用os.chdir函数将当前工作目录切换到之前在实例化时指定的新目录self.dir
从而使得后续代码在该代码块内执行文件相关操作时相对路径的基准目录变为新目录
"""
os.chdir(self.dir)
def __exit__(self, exc_type, exc_val, exc_tb):
"""Restore the current working directory on context exit."""
"""
Restore the current working directory on context exit.
当从由该类作为上下文管理器或者装饰器包裹的代码块退出时此方法会被调用
它会使用os.chdir函数将当前工作目录恢复到之前记录的原始工作目录self.cwd
保证不会因为在代码块内切换了工作目录而影响后续代码对文件系统操作的预期相对路径
同时可以处理代码块内可能出现的异常情况通过参数exc_typeexc_valexc_tb传递异常相关信息不过此处未对异常做特殊处理只是恢复目录
"""
os.chdir(self.cwd)
# 定义了一个名为spaces_in_path的上下文管理器函数用于处理路径中包含空格的情况
# 它会将包含空格的路径中的空格替换为下划线,复制对应的文件或目录到新路径下,
# 在执行完上下文代码块后,再将文件或目录复制回原来的位置,避免因路径空格导致的一些文件操作问题。
@contextmanager
def spaces_in_path(path):
"""
@ -34,122 +67,15 @@ def spaces_in_path(path):
If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path,
executes the context code block, then copies the file/directory back to its original location.
Args:
path (str | Path): The original path.
Yields:
(Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.
Example:
```python
with spaces_in_path('/path/with spaces') as new_path:
# your code here
```
"""
# If path has spaces, replace them with underscores
if ' ' in str(path):
string = isinstance(path, str) # input type
path = Path(path)
# Create a temporary directory and construct the new path
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir) / path.name.replace(' ', '_')
# Copy file/directory
if path.is_dir():
# tmp_path.mkdir(parents=True, exist_ok=True)
shutil.copytree(path, tmp_path)
elif path.is_file():
tmp_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(path, tmp_path)
try:
# Yield the temporary path
yield str(tmp_path) if string else tmp_path
finally:
# Copy file/directory back
if tmp_path.is_dir():
shutil.copytree(tmp_path, path, dirs_exist_ok=True)
elif tmp_path.is_file():
shutil.copy2(tmp_path, path) # Copy back the file
else:
# If there are no spaces, just yield the original path
yield path
def increment_path(path, exist_ok=False, sep='', mkdir=False):
"""
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
directory if it does not already exist.
Args:
path (str, pathlib.Path): Path to increment.
exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.
Returns:
(pathlib.Path): Incremented path.
"""
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
# Method 1
for n in range(2, 9999):
p = f'{path}{sep}{n}{suffix}' # increment path
if not os.path.exists(p): #
break
path = Path(p)
if mkdir:
path.mkdir(parents=True, exist_ok=True) # make directory
return path
def file_age(path=__file__):
"""Return days since last file update."""
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
return dt.days # + dt.seconds / 86400 # fractional days
def file_date(path=__file__):
"""Return human-readable file modification date, i.e. '2021-3-26'."""
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
return f'{t.year}-{t.month}-{t.day}'
def file_size(path):
"""Return file/dir size (MB)."""
if isinstance(path, (str, Path)):
mb = 1 << 20 # bytes to MiB (1024 ** 2)
path = Path(path)
if path.is_file():
return path.stat().st_size / mb
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
return 0.0
def get_latest_run(search_dir='.'):
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ''
参数说明
path原始路径可以是字符串或者Path对象类型代表要处理的文件或目录的路径
如果该路径中包含空格则会按照上述逻辑进行处理若不包含空格则直接返回原路径供代码块使用
返回值
生成一个临时路径如果原始路径包含空格则是替换空格后的路径否则是原始路径本身供在with语句块内使用
在with语句块执行完毕后会自动进行相应的文件或目录复制回原位置等清理操作
def make_dirs(dir='new_dir/'):
"""Create directories."""
dir = Path(dir)
if dir.exists():
shutil.rmtree(dir) # delete dir
for p in dir, dir / 'labels', dir / 'images':
p.mkdir(parents=True, exist_ok=True) # make dir
return dir
示例用法
```python
with spaces_in_path('/path/with spaces') as new_path:
# your code here

@ -13,14 +13,14 @@ from .tal import bbox2dist
class VarifocalLoss(nn.Module):
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
# """Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
def __init__(self):
"""Initialize the VarifocalLoss class."""
# """Initialize the VarifocalLoss class."""
super().__init__()
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
"""Computes varfocal loss."""
# """Computes varfocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
@ -30,13 +30,13 @@ class VarifocalLoss(nn.Module):
# Losses
class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
# """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, ):
super().__init__()
def forward(self, pred, label, gamma=1.5, alpha=0.25):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
# """Calculates and updates confusion matrix for object detection/classification tasks."""
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
@ -55,13 +55,13 @@ class FocalLoss(nn.Module):
class BboxLoss(nn.Module):
def __init__(self, reg_max, use_dfl=False):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
# """Initialize the BboxLoss module with regularization maximum and DFL settings."""
super().__init__()
self.reg_max = reg_max
self.use_dfl = use_dfl
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
# """IoU loss."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
@ -78,7 +78,7 @@ class BboxLoss(nn.Module):
@staticmethod
def _df_loss(pred_dist, target):
"""Return sum of left and right DFL losses."""
# """Return sum of left and right DFL losses."""
# Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
tl = target.long() # target left
tr = tl + 1 # target right
@ -95,7 +95,7 @@ class KeypointLoss(nn.Module):
self.sigmas = sigmas
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
# """Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2
kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9)
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
@ -114,8 +114,8 @@ class v8DetectionLoss:
m = model.model[-1] # Detect() module
self.bce = nn.BCEWithLogitsLoss(reduction='none')
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
self.no = m.no
self.reg_max = m.reg_max
self.device = device
@ -127,7 +127,7 @@ class v8DetectionLoss:
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
# """Preprocesses the target counts and matches with the input batch size to output a tensor."""
if targets.shape[0] == 0:
out = torch.zeros(batch_size, 0, 5, device=self.device)
else:
@ -144,7 +144,7 @@ class v8DetectionLoss:
return out
def bbox_decode(self, anchor_points, pred_dist):
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
# """Decode predicted object bounding box coordinates from anchor points and distribution."""
if self.use_dfl:
b, a, c = pred_dist.shape # batch, anchors, channels
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
@ -153,7 +153,7 @@ class v8DetectionLoss:
return dist2bbox(pred_dist, anchor_points, xywh=False)
def __call__(self, preds, batch):
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
#"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
@ -208,7 +208,7 @@ class v8SegmentationLoss(v8DetectionLoss):
self.overlap = model.args.overlap_mask
def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
# """Calculate and return the loss for the YOLO model."""
loss = torch.zeros(4, device=self.device) # box, cls, dfl
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Model validation metrics
#Model validation metrics
"""
import math
import warnings
@ -17,21 +17,21 @@ OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.0
# Boxes
def box_area(box):
"""Return box area, where box shape is xyxy(4,n)."""
# """Return box area, where box shape is xyxy(4,n)."""
return (box[2] - box[0]) * (box[3] - box[1])
def bbox_ioa(box1, box2, eps=1e-7):
"""
Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
# """
# Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
Args:
box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
# Args:
# box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
# box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
# eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(np.array): A numpy array of shape (n, m) representing the intersection over box2 area.
# Returns:
# (np.array): A numpy array of shape (n, m) representing the intersection over box2 area.
"""
# Get the coordinates of bounding boxes
@ -51,18 +51,18 @@ def bbox_ioa(box1, box2, eps=1e-7):
def box_iou(box1, box2, eps=1e-7):
"""
Calculate intersection-over-union (IoU) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
# Calculate intersection-over-union (IoU) of boxes.
#Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
# Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Args:
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
# Args:
# box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
# box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
# eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
"""
# Returns:
# (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
# """
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
@ -73,22 +73,22 @@ def box_iou(box1, box2, eps=1e-7):
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
"""
Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
# """
# Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
Args:
box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
# Args:
# box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
# box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
# xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
(x1, y1, x2, y2) format. Defaults to True.
GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
# GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
# DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
# CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
# eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
"""
# Returns:
# (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
# """
# Get the coordinates of bounding boxes
if xywh: # transform from xywh to xyxy
@ -129,38 +129,37 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
def mask_iou(mask1, mask2, eps=1e-7):
"""
Calculate masks IoU.
Args:
mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
product of image width and height.
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
product of image width and height.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing masks IoU.
"""
## Calculate masks IoU.
#Args:
# mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
# product of image width and height.
# mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
# product of image width and height.
# eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
#Returns:
# (torch.Tensor): A tensor of shape (N, M) representing masks IoU.
# """
intersection = torch.matmul(mask1, mask2.T).clamp_(0)
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
return intersection / (union + eps)
def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
"""
Calculate Object Keypoint Similarity (OKS).
Args:
kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
sigma (list): A list containing 17 values representing keypoint scales.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
"""
# """
# Calculate Object Keypoint Similarity (OKS).
# Args:
# kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
# kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
# area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
# sigma (list): A list containing 17 values representing keypoint scales.
# eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
# Returns:
# (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
# """
d = (kpt1[:, None, :, 0] - kpt2[..., 0]) ** 2 + (kpt1[:, None, :, 1] - kpt2[..., 1]) ** 2 # (N, M, 17)
sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )
kpt_mask = kpt1[..., 2] != 0 # (N, 17)
@ -253,11 +252,11 @@ class ConfusionMatrix:
self.matrix[dc, self.nc] += 1 # predicted background
def matrix(self):
"""Returns the confusion matrix."""
# """Returns the confusion matrix."""
return self.matrix
def tp_fp(self):
"""Returns true positives and false positives."""
# """Returns true positives and false positives."""
tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
@ -266,15 +265,15 @@ class ConfusionMatrix:
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
@plt_settings()
def plot(self, normalize=True, save_dir='', names=(), on_plot=None):
"""
Plot the confusion matrix using seaborn and save it to a file.
Args:
normalize (bool): Whether to normalize the confusion matrix.
save_dir (str): Directory where the plot will be saved.
names (tuple): Names of classes, used as labels on the plot.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
"""
# """
# Plot the confusion matrix using seaborn and save it to a file.
# Args:
# normalize (bool): Whether to normalize the confusion matrix.
# save_dir (str): Directory where the plot will be saved.
# names (tuple): Names of classes, used as labels on the plot.
# on_plot (func): An optional callback to pass plots path and data when they are rendered.
# """
import seaborn as sn
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
@ -309,15 +308,15 @@ class ConfusionMatrix:
on_plot(plot_fname)
def print(self):
"""
Print the confusion matrix to the console.
"""
# """
# Print the confusion matrix to the console.
# """
for i in range(self.nc + 1):
LOGGER.info(' '.join(map(str, self.matrix[i])))
def smooth(y, f=0.05):
"""Box filter of fraction f."""
# """Box filter of fraction f."""
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
p = np.ones(nf // 2) # ones padding
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
@ -326,7 +325,7 @@ def smooth(y, f=0.05):
@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None):
"""Plots a precision-recall curve."""
# """Plots a precision-recall curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)
@ -351,7 +350,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=N
@plt_settings()
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None):
"""Plots a metric-confidence curve."""
#"""Plots a metric-confidence curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
@ -375,18 +374,18 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
def compute_ap(recall, precision):
"""
Compute the average precision (AP) given the recall and precision curves.
Arguments:
recall (list): The recall curve.
precision (list): The precision curve.
Returns:
(float): Average precision.
(np.ndarray): Precision envelope curve.
(np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
"""
#"""
# Compute the average precision (AP) given the recall and precision curves.
# Arguments:
# recall (list): The recall curve.
# precision (list): The precision curve.
# Returns:
# (float): Average precision.
# (np.ndarray): Precision envelope curve.
# (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
#"""
# Append sentinel values to beginning and end
mrec = np.concatenate(([0.0], recall, [1.0]))

@ -1,24 +1,65 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Monkey patches to update/extend functionality of existing functions
这段代码的主要作用是通过猴子补丁Monkey patches的方式来更新或扩展现有函数的功能
所谓猴子补丁就是在运行时动态地修改类或模块中的函数方法等这里是对OpenCV和PyTorch中的一些函数进行了这样的处理
使其能更好地适应特定需求比如处理多语言相关的问题增强序列化功能等
"""
from pathlib import Path
import cv2
import numpy as np
import torch
# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
# 将原始的cv2.imshow函数赋值给_imshow变量这么做是为了避免后续重新定义imshow函数时出现递归调用错误。
# 因为新定义的imshow函数内部会调用显示图像的相关功能如果直接用cv2.imshow就容易陷入循环调用自身的情况
# 通过这种赋值的方式可以明确调用原始未修改的imshow逻辑保证代码正常运行。
_imshow = cv2.imshow # copy to avoid recursion errors
def imread(filename, flags=cv2.IMREAD_COLOR):
"""
函数功能
该函数是对OpenCV原生imread函数功能的一种扩展旨在解决一些特殊情况下图像读取的兼容性问题特别是在多语言环境下
文件名可能包含非标准字符或者不同编码格式的情况它通过先从文件读取字节数据再利用cv2.imdecode进行解码的方式来获取图像数据
参数说明
- filename表示要读取的图像文件的路径可以是字符串或者Path对象类型它指定了图像文件在文件系统中的具体位置
支持各种常见的图像文件格式不过最终能否成功读取还取决于文件的实际编码情况以及OpenCV对该格式的支持程度
- flags可选参数用于指定图像读取的模式默认值是cv2.IMREAD_COLOR表示以彩色模式读取图像
还可以传入其他符合OpenCV要求的标志值例如cv2.IMREAD_GRAYSCALE表示以灰度模式读取图像等
不同的标志会影响读取后图像数据的格式和内容如通道数等
返回值
返回读取到的图像数据通常以numpy数组的形式呈现对于彩色图像一般是一个三维数组格式类似(height, width, channels)
其中height表示图像的高度width表示宽度channels表示颜色通道数例如常见的RGB图像channels为3
如果读取过程中出现文件不存在文件格式不被识别等问题可能会返回相应的错误或者异常情况取决于cv2.imdecode的具体行为
"""
return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
def imwrite(filename, img):
"""
函数功能
这是对OpenCV原生imwrite函数的一种改进版本目的是在保存图像文件时更好地处理文件名包含特殊字符如多语言环境下的非ASCII字符的情况
确保图像能够正确地被编码并保存到指定文件中通过先对图像进行编码再将编码后的字节数据写入文件来实现图像保存功能
参数说明
- filename要保存图像的目标文件路径类型可以是字符串或者Path对象它决定了图像将被保存到文件系统中的哪个位置以及保存后的文件名
文件名的后缀例如.jpg.png等会指示OpenCV按照相应的图像格式对图像进行编码保存需要是OpenCV支持的格式
- img代表要保存的图像数据通常是一个多维的numpy数组其格式和维度要符合OpenCV对图像数据的要求例如彩色图像是(height, width, channels)形式
这里存储了实际的图像像素信息将被保存到指定的文件中
返回值
如果图像成功地被编码并保存到文件中函数返回True若在保存过程中出现诸如文件创建失败权限不足图像编码出错等异常情况
则返回False表示图像保存操作未成功完成
"""
try:
# 这里使用cv2.imencode函数对图像进行编码根据给定文件名的后缀通过Path(filename).suffix获取来确定编码的格式
# cv2.imencode会返回一个包含编码结果状态和编码后字节数据的元组我们取其中的第二个元素索引为1即字节数据
# 然后通过tofile方法将字节数据写入到指定的filename文件中以此实现图像的保存操作。
cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
return True
except Exception:
@ -26,15 +67,57 @@ def imwrite(filename, img):
def imshow(path, im):
"""
函数功能
此函数是对OpenCV原生imshow函数的改进主要用于解决在多语言环境下图像显示窗口标题通常使用路径作为标题中字符编码的问题
确保含有特殊字符如非ASCII字符的路径能够正确显示在窗口标题上同时正常展示图像内容通过对路径进行特定的编码和解码处理后再调用原始的显示函数来实现
参数说明
- path用于指定图像显示窗口的标题路径其类型可以是字符串或者Path对象该路径一般是图像文件的实际路径或者相关标识信息
在多语言场景下可能包含需要特殊处理的字符以保证能在窗口标题上正确显示
- im代表要在窗口中显示的图像数据同样是一个符合OpenCV要求的多维numpy数组格式如(height, width, channels)的彩色图像数据等
存储了实际要展示的图像像素信息
具体操作
首先对path进行编码转换先使用encode方法将其转换为'unicode_escape'编码格式的字节串然后再用decode方法将字节串转换回普通字符串
这样可以处理路径中可能存在的特殊字符问题接着调用之前保存的原始imshow函数_imshow来显示图像使得图像能在窗口中正常展示
并且窗口标题能正确显示路径相关信息
"""
_imshow(path.encode('unicode_escape').decode(), im)
# PyTorch functions ----------------------------------------------------------------------------------------------------
# 把原始的torch.save函数赋给_torch_save变量同样是为了避免后续重新定义torch_save函数时出现递归调用的错误
# 使得在新函数内部可以调用原本的保存逻辑,防止陷入无限循环调用自身的情况,保证代码逻辑的正确性。
_torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, **kwargs):
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
"""
函数功能
该函数对PyTorch原生的torch.save函数进行了功能扩展主要是针对序列化过程中处理lambda函数等特殊可调用对象的情况
因为Python标准的pickle模块在序列化时可能无法处理像lambda函数这样的对象而这里尝试导入dill模块如果存在的话
利用dill更强大的序列化能力它能处理更多类型的Python对象包括lambda函数来替换pickle模块进行序列化操作
从而让torch.save函数可以成功保存包含这类特殊对象的相关数据结构比如带有lambda函数的模型配置等
参数说明
- *args可变位置参数它的参数内容和顺序与原生torch.save函数所要求的一致主要用于传递要保存的对象以及保存的目标文件路径等关键信息
例如第一个参数通常是要序列化保存的Python对象后续参数可以是文件相关的路径等具体取决于torch.save的使用方式
- **kwargs可变关键字参数用于传递一些可选的额外参数给torch.save函数这里着重处理了'pickle_module'这个关键字参数
其他参数的含义和用法与原生torch.save函数中的相应参数保持一致
返回值
返回调用原始torch.save函数_torch_save的结果其返回值的含义与原生torch.save函数的返回行为相同
例如如果保存操作成功完成一般没有返回值在Python中相当于返回None若在保存过程中遇到诸如文件权限问题对象无法序列化等异常情况
则会抛出相应的异常具体的异常类型取决于出现问题的具体原因和torch.save函数内部的处理逻辑
具体操作
首先尝试导入dill模块并将其作为pickle模块使用通过as关键字起别名pickle如果导入过程中出现ImportError即dill模块不存在
则改为导入Python标准的pickle模块接着检查传入的关键字参数kwargs中是否已经包含了'pickle_module'这个参数
如果不存在就把刚才选择的序列化模块要么是dill要么是标准pickle模块添加到kwargs字典中对应键为'pickle_module'
最后调用原始的torch.save函数_torch_save并将处理后的*args和**kwargs参数传递进去返回其调用结果以此完成功能的扩展
"""
try:
import dill as pickle
except ImportError:
@ -42,4 +125,4 @@ def torch_save(*args, **kwargs):
if 'pickle_module' not in kwargs:
kwargs['pickle_module'] = pickle
return _torch_save(*args, **kwargs)
return _torch_save(*args, **kwargs)

@ -1,4 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# 从ultralytics.cfg模块中导入TASK2DATA和TASK2METRIC这两个变量可能是用于将任务类型与对应的数据集、评估指标等相关信息进行映射的字典
# 从ultralytics.utils模块中导入DEFAULT_CFG_DICT可能是默认配置字典、LOGGER用于记录日志、NUM_THREADS可能表示线程数量相关信息
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
@ -12,91 +14,134 @@ def run_ray_tune(model,
"""
Runs hyperparameter tuning using Ray Tune.
Args:
model (YOLO): Model to run the tuner on.
space (dict, optional): The hyperparameter search space. Defaults to None.
grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
函数功能
使用Ray Tune进行超参数调优针对给定的模型在指定的超参数搜索空间调优相关配置如优雅周期每个试验分配的GPU数量最大试验次数等
执行多次训练试验尝试找到最优的超参数组合并返回超参数搜索的结果
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
参数说明
model (YOLO)要进行超参数调优的YOLO模型对象包含了模型的结构训练逻辑等相关信息调优过程会基于这个模型进行多次训练尝试不同的超参数组合
space (dict, 可选)超参数搜索空间的字典用于定义要调整的超参数及其取值范围等信息如果为None则会使用默认的搜索空间默认值为None
grace_period (int, 可选)ASHA调度器的优雅周期以轮数为单位在这个周期内即使某些试验表现不佳也不会过早停止默认值为10
gpu_per_trial (int, 可选)每个试验分配的GPU数量如果为None则根据实际情况可能有默认的分配方式默认值为None
max_samples (int, 可选)要运行的最大试验次数即总共会尝试不同超参数组合进行训练的次数上限默认值为10
train_args (dict, 可选)传递给模型 `train()` 方法的其他额外参数用于在训练过程中配置训练相关的各种设置默认值为 {}
Raises:
ModuleNotFoundError: If Ray Tune is not installed.
返回值
(dict)一个包含超参数搜索结果的字典里面包含了各个试验的相关信息如超参数取值训练指标等用于后续分析和确定最优超参数组合
抛出异常
ModuleNotFoundError如果没有安装Ray Tune库会抛出此异常提示需要安装Ray Tune才能进行超参数调优操作并给出安装命令
"""
# 如果传入的train_args为None则将其初始化为空字典确保后续操作不会出现空引用的情况
if train_args is None:
train_args = {}
try:
# 尝试从ray库中导入tune模块用于进行超参数调优相关的核心操作例如定义搜索空间、创建调优器等。
# 导入RunConfig类用于配置Ray Tune运行相关的设置如回调函数等。
# 导入WandbLoggerCallback类用于在超参数调优过程中集成Weights & Biaseswandb日志记录功能方便可视化和跟踪训练过程。
# 导入ASHAScheduler类用于定义超参数搜索的调度策略这里采用的是ASHAAsynchronous Successive Halving Algorithm调度器。
from ray import tune
from ray.air import RunConfig
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.tune.schedulers import ASHAScheduler
except ImportError:
# 如果导入失败说明没有安装Ray Tune库抛出ModuleNotFoundError异常并提示需要安装Ray Tune才能进行超参数调优给出安装命令。
raise ModuleNotFoundError('Tuning hyperparameters requires Ray Tune. Install with: pip install "ray[tune]"')
try:
# 尝试导入wandb库用于可能的日志记录和可视化功能如果导入成功进一步检查是否有版本属性确保其正常可用
import wandb
assert hasattr(wandb, '__version__')
except (ImportError, AssertionError):
# 如果导入wandb库失败或者没有版本属性意味着可能不可用则将wandb设置为False表示不使用wandb进行日志记录等功能。
wandb = False
default_space = {
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
# 定义学习率初始值的搜索范围使用tune.uniform表示在给定的最小值1e-5和最大值1e-1之间均匀采样用于超参数调优时尝试不同的初始学习率。
'lr0': tune.uniform(1e-5, 1e-1),
# 定义最终OneCycleLR学习率的缩放因子lr0 * lrf得到最终学习率的搜索范围在0.01到1.0之间均匀采样,用于调整学习率的变化策略。
'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
# 定义SGD动量或Adam的beta1参数的搜索范围在0.6到0.98之间均匀采样,用于调整优化器在更新参数时的动量相关参数。
'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
# 定义优化器的权重衰减系数的搜索范围在0.0到0.001之间均匀采样,用于控制模型参数在训练过程中的正则化程度,防止过拟合。
'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
# 定义热身轮数的搜索范围允许使用小数表示分数轮数在0.0到5.0之间均匀采样,用于控制训练开始阶段学习率逐渐上升的周期。
'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
# 定义热身阶段初始动量的搜索范围在0.0到0.95之间均匀采样,用于调整热身阶段优化器的动量相关参数。
'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
# 定义目标检测中框损失的增益系数的搜索范围在0.02到0.2之间均匀采样,用于调整框损失在整体损失中的权重。
'box': tune.uniform(0.02, 0.2), # box loss gain
# 定义目标检测中分类损失的增益系数的搜索范围在0.2到4.0之间均匀采样,并且提示与像素相关(可能根据图像像素等情况调整权重),用于调整分类损失在整体损失中的权重。
'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
# 定义图像HSV颜色空间中Hue色调增强的比例范围在0.0到0.1之间均匀采样,用于图像增强操作,改变图像的色调。
'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
# 定义图像HSV颜色空间中Saturation饱和度增强的比例范围在0.0到0.9之间均匀采样,用于图像增强操作,改变图像的饱和度。
'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
# 定义图像HSV颜色空间中Value明度增强的比例范围在0.0到0.9之间均匀采样,用于图像增强操作,改变图像的明度。
'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
# 定义图像旋转角度的搜索范围在0.0到45.0度之间均匀采样,用于图像增强操作,对图像进行随机旋转。
'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
# 定义图像平移比例的搜索范围在0.0到0.9之间均匀采样,用于图像增强操作,对图像进行随机平移。
'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
# 定义图像缩放比例的搜索范围在0.0到0.9之间均匀采样,用于图像增强操作,对图像进行随机缩放。
'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
# 定义图像剪切角度的搜索范围在0.0到10.0度之间均匀采样,用于图像增强操作,对图像进行随机剪切变形。
'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
# 定义图像透视变换比例的搜索范围在0.0到0.001之间均匀采样,范围相对较小,用于图像增强操作,对图像进行随机透视变换。
'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
# 定义图像上下翻转的概率搜索范围在0.0到1.0之间均匀采样,用于图像增强操作,决定是否对图像进行上下翻转。
'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
# 定义图像左右翻转的概率搜索范围在0.0到1.0之间均匀采样,用于图像增强操作,决定是否对图像进行左右翻转。
'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
# 定义图像进行马赛克增强一种图像混合方式可能用于目标检测等任务的数据增强的概率搜索范围在0.0到1.0之间均匀采样。
'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
# 定义图像进行mixup增强一种图像混合方式常用于数据增强的概率搜索范围在0.0到1.0之间均匀采样。
'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
# 定义图像进行segment copy-paste增强可能用于分割任务的数据增强将部分区域复制粘贴的概率搜索范围在0.0到1.0之间均匀采样。
'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
def _tune(config):
"""
Trains the YOLO model with the specified hyperparameters and additional arguments.
Args:
config (dict): A dictionary of hyperparameters to use for training.
函数功能
使用给定的超参数配置和额外的训练参数来训练YOLO模型在每次超参数调优的试验中会调用此函数进行一次完整的模型训练过程
参数说明
config (dict)一个包含超参数的字典这些超参数是在当前试验中要使用的具体取值用于配置模型训练过程中的各种设置如学习率损失增益等
Returns:
None.
返回值
None.
"""
# 重置模型的回调函数,可能是为了确保每次试验的训练过程不受之前设置的回调函数影响,以干净的状态开始新的训练。
model._reset_callbacks()
# 将传入的训练参数train_args更新到当前的超参数配置config使得配置包含了所有需要用于训练的参数信息。
config.update(train_args)
# 使用更新后的配置参数调用模型的 `train()` 方法进行模型训练,开始一次完整的训练过程。
model.train(**config)
# Get search space
# 获取超参数搜索空间
if not space:
space = default_space
LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
# Get dataset
# 获取数据集相关信息
data = train_args.get('data', TASK2DATA[model.task])
space['data'] = data
if 'data' not in train_args:
LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
# Define the trainable function with allocated resources
# 定义带有资源分配的可训练函数将训练函数_tune与资源分配信息这里指定了CPU线程数量为NUM_THREADSGPU数量根据gpu_per_trial或默认设置关联起来
# 以便Ray Tune在调度试验时能够按照指定的资源分配来执行训练过程。
trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
# Define the ASHA scheduler for hyperparameter search
# 定义ASHA调度器用于超参数搜索设置时间属性time_attr为'epoch'表示以训练轮数作为时间参考,
# 根据模型任务类型从TASK2METRIC中获取对应的评估指标作为优化目标metric优化模式mode为'max'表示最大化该评估指标,
# 最大训练轮数max_t根据传入的训练参数train_args中的'epochs'或者默认配置DEFAULT_CFG_DICT['epochs']或者默认值100来确定
# 优雅周期grace_period使用传入的参数值缩减因子reduction_factor设置为3用于控制试验的逐步淘汰策略。
asha_scheduler = ASHAScheduler(time_attr='epoch',
metric=TASK2METRIC[model.task],
mode='max',
@ -104,17 +149,20 @@ def run_ray_tune(model,
grace_period=grace_period,
reduction_factor=3)
# Define the callbacks for the hyperparameter search
# 定义超参数搜索的回调函数列表如果wandb可用即wandb为True则添加WandbLoggerCallback用于记录日志到wandb平台项目名称为'YOLOv8-tune'
# 否则使用空列表,表示不添加额外的回调函数。
tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
# Create the Ray Tune hyperparameter search tuner
# 创建Ray Tune超参数搜索调谐器tuner传入可训练函数trainable_with_resources、超参数搜索空间param_space、调优配置tune_config以及运行配置run_config
# 调优配置中指定了调度器asha_scheduler和要运行的最大样本数即最大试验次数num_samples运行配置中指定了回调函数列表callbacks和存储路径storage_path
# 用于配置整个超参数调优过程的相关设置后续通过调用fit方法来启动调优过程。
tuner = tune.Tuner(trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune'))
# Run the hyperparameter search
# 运行超参数搜索过程启动调谐器开始执行多次试验按照设定的搜索空间、调度策略等进行超参数调优在这个过程中会多次调用_tune函数进行模型训练。
tuner.fit()
# Return the results of the hyperparameter search
return tuner.get_results()
# 返回超参数搜索的结果通过调谐器的get_results方法获取包含各个试验详细信息如超参数取值、训练指标等的结果字典用于后续分析和确定最优超参数组合。
return tuner.get_results()
Loading…
Cancel
Save