feature/zuyuan3
陆鑫宇 1 year ago
parent f8b87477a3
commit 708057caa9

@ -1,4 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# 从ultralytics.cfg模块中导入TASK2DATA和TASK2METRIC这两个变量可能是用于将任务类型与对应的数据集、评估指标等相关信息进行映射的字典
# 从ultralytics.utils模块中导入DEFAULT_CFG_DICT可能是默认配置字典、LOGGER用于记录日志、NUM_THREADS可能表示线程数量相关信息
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
@ -12,91 +14,134 @@ def run_ray_tune(model,
"""
Runs hyperparameter tuning using Ray Tune.
Args:
model (YOLO): Model to run the tuner on.
space (dict, optional): The hyperparameter search space. Defaults to None.
grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
函数功能
使用Ray Tune进行超参数调优针对给定的模型在指定的超参数搜索空间调优相关配置如优雅周期每个试验分配的GPU数量最大试验次数等
执行多次训练试验尝试找到最优的超参数组合并返回超参数搜索的结果
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
参数说明
model (YOLO)要进行超参数调优的YOLO模型对象包含了模型的结构训练逻辑等相关信息调优过程会基于这个模型进行多次训练尝试不同的超参数组合
space (dict, 可选)超参数搜索空间的字典用于定义要调整的超参数及其取值范围等信息如果为None则会使用默认的搜索空间默认值为None
grace_period (int, 可选)ASHA调度器的优雅周期以轮数为单位在这个周期内即使某些试验表现不佳也不会过早停止默认值为10
gpu_per_trial (int, 可选)每个试验分配的GPU数量如果为None则根据实际情况可能有默认的分配方式默认值为None
max_samples (int, 可选)要运行的最大试验次数即总共会尝试不同超参数组合进行训练的次数上限默认值为10
train_args (dict, 可选)传递给模型 `train()` 方法的其他额外参数用于在训练过程中配置训练相关的各种设置默认值为 {}
Raises:
ModuleNotFoundError: If Ray Tune is not installed.
返回值
(dict)一个包含超参数搜索结果的字典里面包含了各个试验的相关信息如超参数取值训练指标等用于后续分析和确定最优超参数组合
抛出异常
ModuleNotFoundError如果没有安装Ray Tune库会抛出此异常提示需要安装Ray Tune才能进行超参数调优操作并给出安装命令
"""
# 如果传入的train_args为None则将其初始化为空字典确保后续操作不会出现空引用的情况
if train_args is None:
train_args = {}
try:
# 尝试从ray库中导入tune模块用于进行超参数调优相关的核心操作例如定义搜索空间、创建调优器等。
# 导入RunConfig类用于配置Ray Tune运行相关的设置如回调函数等。
# 导入WandbLoggerCallback类用于在超参数调优过程中集成Weights & Biaseswandb日志记录功能方便可视化和跟踪训练过程。
# 导入ASHAScheduler类用于定义超参数搜索的调度策略这里采用的是ASHAAsynchronous Successive Halving Algorithm调度器。
from ray import tune
from ray.air import RunConfig
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.tune.schedulers import ASHAScheduler
except ImportError:
# 如果导入失败说明没有安装Ray Tune库抛出ModuleNotFoundError异常并提示需要安装Ray Tune才能进行超参数调优给出安装命令。
raise ModuleNotFoundError('Tuning hyperparameters requires Ray Tune. Install with: pip install "ray[tune]"')
try:
# 尝试导入wandb库用于可能的日志记录和可视化功能如果导入成功进一步检查是否有版本属性确保其正常可用
import wandb
assert hasattr(wandb, '__version__')
except (ImportError, AssertionError):
# 如果导入wandb库失败或者没有版本属性意味着可能不可用则将wandb设置为False表示不使用wandb进行日志记录等功能。
wandb = False
default_space = {
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
# 定义学习率初始值的搜索范围使用tune.uniform表示在给定的最小值1e-5和最大值1e-1之间均匀采样用于超参数调优时尝试不同的初始学习率。
'lr0': tune.uniform(1e-5, 1e-1),
# 定义最终OneCycleLR学习率的缩放因子lr0 * lrf得到最终学习率的搜索范围在0.01到1.0之间均匀采样,用于调整学习率的变化策略。
'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
# 定义SGD动量或Adam的beta1参数的搜索范围在0.6到0.98之间均匀采样,用于调整优化器在更新参数时的动量相关参数。
'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
# 定义优化器的权重衰减系数的搜索范围在0.0到0.001之间均匀采样,用于控制模型参数在训练过程中的正则化程度,防止过拟合。
'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
# 定义热身轮数的搜索范围允许使用小数表示分数轮数在0.0到5.0之间均匀采样,用于控制训练开始阶段学习率逐渐上升的周期。
'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
# 定义热身阶段初始动量的搜索范围在0.0到0.95之间均匀采样,用于调整热身阶段优化器的动量相关参数。
'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
# 定义目标检测中框损失的增益系数的搜索范围在0.02到0.2之间均匀采样,用于调整框损失在整体损失中的权重。
'box': tune.uniform(0.02, 0.2), # box loss gain
# 定义目标检测中分类损失的增益系数的搜索范围在0.2到4.0之间均匀采样,并且提示与像素相关(可能根据图像像素等情况调整权重),用于调整分类损失在整体损失中的权重。
'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
# 定义图像HSV颜色空间中Hue色调增强的比例范围在0.0到0.1之间均匀采样,用于图像增强操作,改变图像的色调。
'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
# 定义图像HSV颜色空间中Saturation饱和度增强的比例范围在0.0到0.9之间均匀采样,用于图像增强操作,改变图像的饱和度。
'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
# 定义图像HSV颜色空间中Value明度增强的比例范围在0.0到0.9之间均匀采样,用于图像增强操作,改变图像的明度。
'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
# 定义图像旋转角度的搜索范围在0.0到45.0度之间均匀采样,用于图像增强操作,对图像进行随机旋转。
'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
# 定义图像平移比例的搜索范围在0.0到0.9之间均匀采样,用于图像增强操作,对图像进行随机平移。
'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
# 定义图像缩放比例的搜索范围在0.0到0.9之间均匀采样,用于图像增强操作,对图像进行随机缩放。
'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
# 定义图像剪切角度的搜索范围在0.0到10.0度之间均匀采样,用于图像增强操作,对图像进行随机剪切变形。
'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
# 定义图像透视变换比例的搜索范围在0.0到0.001之间均匀采样,范围相对较小,用于图像增强操作,对图像进行随机透视变换。
'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
# 定义图像上下翻转的概率搜索范围在0.0到1.0之间均匀采样,用于图像增强操作,决定是否对图像进行上下翻转。
'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
# 定义图像左右翻转的概率搜索范围在0.0到1.0之间均匀采样,用于图像增强操作,决定是否对图像进行左右翻转。
'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
# 定义图像进行马赛克增强一种图像混合方式可能用于目标检测等任务的数据增强的概率搜索范围在0.0到1.0之间均匀采样。
'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
# 定义图像进行mixup增强一种图像混合方式常用于数据增强的概率搜索范围在0.0到1.0之间均匀采样。
'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
# 定义图像进行segment copy-paste增强可能用于分割任务的数据增强将部分区域复制粘贴的概率搜索范围在0.0到1.0之间均匀采样。
'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
def _tune(config):
"""
Trains the YOLO model with the specified hyperparameters and additional arguments.
Args:
config (dict): A dictionary of hyperparameters to use for training.
函数功能
使用给定的超参数配置和额外的训练参数来训练YOLO模型在每次超参数调优的试验中会调用此函数进行一次完整的模型训练过程
参数说明
config (dict)一个包含超参数的字典这些超参数是在当前试验中要使用的具体取值用于配置模型训练过程中的各种设置如学习率损失增益等
Returns:
None.
返回值
None.
"""
# 重置模型的回调函数,可能是为了确保每次试验的训练过程不受之前设置的回调函数影响,以干净的状态开始新的训练。
model._reset_callbacks()
# 将传入的训练参数train_args更新到当前的超参数配置config使得配置包含了所有需要用于训练的参数信息。
config.update(train_args)
# 使用更新后的配置参数调用模型的 `train()` 方法进行模型训练,开始一次完整的训练过程。
model.train(**config)
# Get search space
# 获取超参数搜索空间
if not space:
space = default_space
LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
# Get dataset
# 获取数据集相关信息
data = train_args.get('data', TASK2DATA[model.task])
space['data'] = data
if 'data' not in train_args:
LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
# Define the trainable function with allocated resources
# 定义带有资源分配的可训练函数将训练函数_tune与资源分配信息这里指定了CPU线程数量为NUM_THREADSGPU数量根据gpu_per_trial或默认设置关联起来
# 以便Ray Tune在调度试验时能够按照指定的资源分配来执行训练过程。
trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
# Define the ASHA scheduler for hyperparameter search
# 定义ASHA调度器用于超参数搜索设置时间属性time_attr为'epoch'表示以训练轮数作为时间参考,
# 根据模型任务类型从TASK2METRIC中获取对应的评估指标作为优化目标metric优化模式mode为'max'表示最大化该评估指标,
# 最大训练轮数max_t根据传入的训练参数train_args中的'epochs'或者默认配置DEFAULT_CFG_DICT['epochs']或者默认值100来确定
# 优雅周期grace_period使用传入的参数值缩减因子reduction_factor设置为3用于控制试验的逐步淘汰策略。
asha_scheduler = ASHAScheduler(time_attr='epoch',
metric=TASK2METRIC[model.task],
mode='max',
@ -104,17 +149,20 @@ def run_ray_tune(model,
grace_period=grace_period,
reduction_factor=3)
# Define the callbacks for the hyperparameter search
# 定义超参数搜索的回调函数列表如果wandb可用即wandb为True则添加WandbLoggerCallback用于记录日志到wandb平台项目名称为'YOLOv8-tune'
# 否则使用空列表,表示不添加额外的回调函数。
tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
# Create the Ray Tune hyperparameter search tuner
# 创建Ray Tune超参数搜索调谐器tuner传入可训练函数trainable_with_resources、超参数搜索空间param_space、调优配置tune_config以及运行配置run_config
# 调优配置中指定了调度器asha_scheduler和要运行的最大样本数即最大试验次数num_samples运行配置中指定了回调函数列表callbacks和存储路径storage_path
# 用于配置整个超参数调优过程的相关设置后续通过调用fit方法来启动调优过程。
tuner = tune.Tuner(trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune'))
# Run the hyperparameter search
# 运行超参数搜索过程启动调谐器开始执行多次试验按照设定的搜索空间、调度策略等进行超参数调优在这个过程中会多次调用_tune函数进行模型训练。
tuner.fit()
# Return the results of the hyperparameter search
return tuner.get_results()
# 返回超参数搜索的结果通过调谐器的get_results方法获取包含各个试验详细信息如超参数取值、训练指标等的结果字典用于后续分析和确定最优超参数组合。
return tuner.get_results()
Loading…
Cancel
Save