@ -1,6 +1,3 @@
#!/usr/bin/env python
# coding=utf-8
import argparse
import contextlib
import copy
@ -44,20 +41,14 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_
from diffusers . utils . import_utils import is_xformers_available
from diffusers . utils . torch_utils import is_compiled_module
# 可选启用 wandb 记录,未安装时不影响训练主流程
if is_wandb_available ( ) :
import wandb
logger = get_logger ( __name__ )
# -------------------------------------------------------------------------
# 功能模块:模型卡保存
# 1) 该模块用于生成/更新 README.md, 记录训练来源与关键配置
# 2) 支持将训练后验证生成的示例图片写入输出目录并写入引用
# 3) 便于后续将模型上传到 Hub 时展示效果与实验信息
# 4) 不参与训练与梯度计算,不影响参数更新与收敛行为
# 5) 既可服务于 Hub 发布,也可用于本地实验的结果归档
# -------------------------------------------------------------------------
# 将训练信息与样例图写入模型卡,便于本地归档与推送到 HuggingFace Hub
def save_model_card (
repo_id : str ,
images : list | None = None ,
@ -68,6 +59,7 @@ def save_model_card(
pipeline : DiffusionPipeline | None = None ,
) :
img_str = " "
# 将推理样例落盘到输出目录,并在 README.md 中插入相对路径引用
if images is not None :
for i , image in enumerate ( images ) :
image . save ( os . path . join ( repo_folder , f " image_ { i } .png " ) )
@ -99,14 +91,7 @@ def save_model_card(
model_card . save ( os . path . join ( repo_folder , " README.md " ) )
# -------------------------------------------------------------------------
# 功能模块: 训练后纯文本推理( validation)
# 1) 该模块仅在训练完全结束后执行,不参与训练过程与优化器状态
# 2) 该模块从 output_dir 重新加载微调后的 pipeline, 避免与训练对象耦合
# 3) 推理只接受文本提示词,不输入任何图像,不走 img2img 相关路径
# 4) 可设置推理步数与随机种子,方便提高细节并保证可复现
# 5) 输出 PIL 图片列表,可保存到目录并写入日志系统便于对比分析
# -------------------------------------------------------------------------
# 训练结束后的纯文本推理:从输出目录重新加载 pipeline, 保证推理与训练对象解耦
def run_validation_txt2img (
finetuned_model_dir : str ,
prompt : str ,
@ -123,6 +108,7 @@ def run_validation_txt2img(
f " 开始 validation 文生图:数量= { num_images } ,步数= { num_inference_steps } , guidance={ guidance_scale } ,提示词= { prompt } "
)
# 只加载 txt2img 所需组件,并禁用 safety_checker 以避免额外开销与拦截
pipe = StableDiffusionPipeline . from_pretrained (
finetuned_model_dir ,
torch_dtype = weight_dtype ,
@ -130,6 +116,7 @@ def run_validation_txt2img(
local_files_only = True ,
)
# 保证是 StableDiffusionPipeline, 避免加载到不兼容的管线导致参数不一致
if not isinstance ( pipe , StableDiffusionPipeline ) :
raise TypeError ( f " 加载的 pipeline 类型异常: { type ( pipe ) } ,需要 StableDiffusionPipeline 才能保证纯文本生图。 " )
@ -137,9 +124,11 @@ def run_validation_txt2img(
pipe . set_progress_bar_config ( disable = True )
pipe . safety_checker = lambda images , clip_input : ( images , [ False for _ in range ( len ( images ) ) ] )
# 使用 slicing 降低推理时显存占用,便于在训练机上额外运行验证
pipe . enable_attention_slicing ( )
pipe . enable_vae_slicing ( )
# 根据 accelerate 的混精配置选择 autocast, 上下文外不改变全局 dtype 行为
if accelerator . device . type == " cuda " :
if accelerator . mixed_precision == " bf16 " :
infer_ctx = torch . autocast ( device_type = " cuda " , dtype = torch . bfloat16 )
@ -150,6 +139,7 @@ def run_validation_txt2img(
else :
infer_ctx = contextlib . nullcontext ( )
# 为每张图单独设置种子偏移,保证同一次验证多图可复现且互不相同
images = [ ]
with infer_ctx :
for i in range ( num_images ) :
@ -166,6 +156,7 @@ def run_validation_txt2img(
)
images . append ( out . images [ 0 ] )
# 将验证图片写入 tracker, 便于对比不同 step 的训练效果
for tracker in accelerator . trackers :
if tracker . name == " tensorboard " :
np_images = np . stack ( [ np . asarray ( img ) for img in images ] )
@ -179,6 +170,7 @@ def run_validation_txt2img(
}
)
# 显式释放管线与缓存,避免与训练过程竞争显存
del pipe
if torch . cuda . is_available ( ) :
torch . cuda . empty_cache ( )
@ -186,14 +178,7 @@ def run_validation_txt2img(
return images
# -------------------------------------------------------------------------
# 功能模块:从模型目录推断 TextEncoder 类型
# 1) 不同扩散模型对应不同文本编码器架构,需动态识别加载类
# 2) 通过读取 text_encoder/config 来获取 architectures 字段
# 3) 该模块返回类对象,用于后续 from_pretrained 加载权重
# 4) 便于同一训练脚本兼容多模型,而不写死具体实现
# 5) 若架构不支持会直接报错,避免训练过程走到一半才失败
# -------------------------------------------------------------------------
# 动态识别文本编码器架构,保证脚本可用于不同系列的扩散模型
def import_model_class_from_model_name_or_path ( pretrained_model_name_or_path : str , revision : str | None ) :
text_encoder_config = PretrainedConfig . from_pretrained (
pretrained_model_name_or_path ,
@ -204,68 +189,63 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
if model_class == " CLIPTextModel " :
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == " RobertaSeriesModelWithTransformation " :
from diffusers . pipelines . alt_diffusion . modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
elif model_class == " T5EncoderModel " :
from transformers import T5EncoderModel
return T5EncoderModel
else :
raise ValueError ( f " { model_class } 不受支持。 " )
# -------------------------------------------------------------------------
# 功能模块:命令行参数解析
# 1) 本模块定义 DreamBooth 训练参数与训练后 validation 参数
# 2) 训练负责微调权重与记录坐标, validation 只负责训练后文生图输出
# 3) 不提供训练中间验证参数,避免任何中途采样影响训练流程
# 4) 对关键参数组合做合法性检查,减少运行中途异常
# 5) 支持通过 shell 脚本传参实现批量实验、对比与复现
# -------------------------------------------------------------------------
# 参数解析:包含训练参数、先验保持参数、训练后验证参数,以及坐标记录参数
def parse_args ( input_args = None ) :
parser = argparse . ArgumentParser ( description = " DreamBooth 训练脚本(训练后纯文字生图 validation) " )
# 预训练模型与 tokenizer 相关配置
parser . add_argument ( " --pretrained_model_name_or_path " , type = str , default = None , required = True )
parser . add_argument ( " --revision " , type = str , default = None , required = False )
parser . add_argument ( " --variant " , type = str , default = None )
parser . add_argument ( " --tokenizer_name " , type = str , default = None )
# 数据路径与提示词配置
parser . add_argument ( " --instance_data_dir " , type = str , default = None , required = True )
parser . add_argument ( " --class_data_dir " , type = str , default = None , required = False )
parser . add_argument ( " --instance_prompt " , type = str , default = None , required = True )
parser . add_argument ( " --class_prompt " , type = str , default = None )
# 先验保持相关开关与权重
parser . add_argument ( " --with_prior_preservation " , default = False , action = " store_true " )
parser . add_argument ( " --prior_loss_weight " , type = float , default = 1.0 )
parser . add_argument ( " --num_class_images " , type = int , default = 100 )
# 输出与可复现配置
parser . add_argument ( " --output_dir " , type = str , default = " dreambooth-model " )
parser . add_argument ( " --seed " , type = int , default = None )
# 图像预处理配置
parser . add_argument ( " --resolution " , type = int , default = 512 )
parser . add_argument ( " --center_crop " , default = False , action = " store_true " )
# 是否同时训练 text encoder
parser . add_argument ( " --train_text_encoder " , action = " store_true " )
# 训练批次与 epoch/step 配置
parser . add_argument ( " --train_batch_size " , type = int , default = 4 )
parser . add_argument ( " --sample_batch_size " , type = int , default = 4 )
parser . add_argument ( " --num_train_epochs " , type = int , default = 1 )
parser . add_argument ( " --max_train_steps " , type = int , default = None )
parser . add_argument ( " --checkpointing_steps " , type = int , default = 500 )
# 梯度累积与显存优化相关开关
parser . add_argument ( " --gradient_accumulation_steps " , type = int , default = 1 )
parser . add_argument ( " --gradient_checkpointing " , action = " store_true " )
# 学习率与 scheduler 配置
parser . add_argument ( " --learning_rate " , type = float , default = 5e-6 )
parser . add_argument ( " --scale_lr " , action = " store_true " , default = False )
parser . add_argument (
" --lr_scheduler " ,
type = str ,
@ -276,37 +256,42 @@ def parse_args(input_args=None):
parser . add_argument ( " --lr_num_cycles " , type = int , default = 1 )
parser . add_argument ( " --lr_power " , type = float , default = 1.0 )
# 优化器相关配置
parser . add_argument ( " --use_8bit_adam " , action = " store_true " )
parser . add_argument ( " --dataloader_num_workers " , type = int , default = 0 )
parser . add_argument ( " --adam_beta1 " , type = float , default = 0.9 )
parser . add_argument ( " --adam_beta2 " , type = float , default = 0.999 )
parser . add_argument ( " --adam_weight_decay " , type = float , default = 1e-2 )
parser . add_argument ( " --adam_epsilon " , type = float , default = 1e-08 )
parser . add_argument ( " --max_grad_norm " , default = 1.0 , type = float )
# Hub 上传相关配置
parser . add_argument ( " --push_to_hub " , action = " store_true " )
parser . add_argument ( " --hub_token " , type = str , default = None )
parser . add_argument ( " --hub_model_id " , type = str , default = None )
# 日志与混精配置
parser . add_argument ( " --logging_dir " , type = str , default = " logs " )
parser . add_argument ( " --allow_tf32 " , action = " store_true " )
parser . add_argument ( " --report_to " , type = str , default = " tensorboard " )
parser . add_argument ( " --mixed_precision " , type = str , default = None , choices = [ " no " , " fp16 " , " bf16 " ] )
parser . add_argument ( " --prior_generation_precision " , type = str , default = None , choices = [ " no " , " fp32 " , " fp16 " , " bf16 " ] )
parser . add_argument ( " --local_rank " , type = int , default = - 1 )
# 注意力与梯度相关的显存优化开关
parser . add_argument ( " --enable_xformers_memory_efficient_attention " , action = " store_true " )
parser . add_argument ( " --set_grads_to_none " , action = " store_true " )
# 噪声与损失加权相关参数
parser . add_argument ( " --offset_noise " , action = " store_true " , default = False )
parser . add_argument ( " --snr_gamma " , type = float , default = None )
# tokenizer 与 text encoder 行为相关参数
parser . add_argument ( " --tokenizer_max_length " , type = int , default = None , required = False )
parser . add_argument ( " --text_encoder_use_attention_mask " , action = " store_true " , required = False )
parser . add_argument ( " --skip_save_text_encoder " , action = " store_true " , required = False )
# 训练后验证参数(本脚本不做中途验证,仅训练结束后跑一次)
parser . add_argument ( " --validation_prompt " , type = str , required = True )
parser . add_argument ( " --validation_negative_prompt " , type = str , default = " " )
parser . add_argument ( " --num_validation_images " , type = int , default = 10 )
@ -314,6 +299,7 @@ def parse_args(input_args=None):
parser . add_argument ( " --validation_guidance_scale " , type = float , default = 7.5 )
parser . add_argument ( " --validation_image_output_dir " , type = str , required = True )
# 训练过程坐标记录(用于可视化与轨迹分析)
parser . add_argument ( " --coords_save_path " , type = str , default = None )
parser . add_argument ( " --coords_log_interval " , type = int , default = 10 )
@ -322,10 +308,12 @@ def parse_args(input_args=None):
else :
args = parser . parse_args ( )
# 兼容 accelerate 启动时写入的 LOCAL_RANK 环境变量
env_local_rank = int ( os . environ . get ( " LOCAL_RANK " , - 1 ) )
if env_local_rank != - 1 and env_local_rank != args . local_rank :
args . local_rank = env_local_rank
# 先验保持开启时必须提供 class 数据与 class prompt
if args . with_prior_preservation :
if args . class_data_dir is None :
raise ValueError ( " 启用先验保持时必须提供 class_data_dir。 " )
@ -340,14 +328,7 @@ def parse_args(input_args=None):
return args
# -------------------------------------------------------------------------
# 功能模块: DreamBooth 训练数据集
# 1) 从 instance 与 class 目录读取图像,并统一做尺寸、裁剪与归一化
# 2) 同时提供实例提示词与类别提示词的 token id 作为文本输入
# 3) 先验保持模式下会返回两套图像与文本信息用于拼接训练
# 4) 数据集长度按 instance 与 class 的最大值取,便于循环采样
# 5) 数据集只负责准备输入,模型推理、损失计算与优化在主循环中完成
# -------------------------------------------------------------------------
# DreamBooth 数据集:负责读取图片、做裁剪归一化,并产出 prompt 的 input_ids 与 attention_mask
class DreamBoothDataset ( Dataset ) :
def __init__ (
self ,
@ -370,11 +351,13 @@ class DreamBoothDataset(Dataset):
if not self . instance_data_root . exists ( ) :
raise ValueError ( f " 实例图像目录不存在: { self . instance_data_root } " )
# instance 图片路径列表会循环采样,长度以 instance 数为基础
self . instance_images_path = [ p for p in Path ( instance_data_root ) . iterdir ( ) if p . is_file ( ) ]
self . num_instance_images = len ( self . instance_images_path )
self . instance_prompt = instance_prompt
self . _length = self . num_instance_images
# 在先验保持模式下同时读取 class 图片,并将长度设为两者最大值以便循环匹配
if class_data_root is not None :
self . class_data_root = Path ( class_data_root )
self . class_data_root . mkdir ( parents = True , exist_ok = True )
@ -388,6 +371,7 @@ class DreamBoothDataset(Dataset):
else :
self . class_data_root = None
# 训练用图像预处理:先 resize, 再 crop, 然后归一化到 [-1, 1]
self . image_transforms = transforms . Compose (
[
transforms . Resize ( size , interpolation = transforms . InterpolationMode . BILINEAR ) ,
@ -400,6 +384,7 @@ class DreamBoothDataset(Dataset):
def __len__ ( self ) :
# Dataset length precomputed in __init__ as self._length (per the module
# comments: instance count, or max(instance, class) under prior preservation
# — the class branch is outside this view; TODO confirm).
return self . _length
# 将 prompt 统一分词为固定长度,避免动态长度导致批处理不稳定
def _tokenize ( self , prompt : str ) :
max_length = self . tokenizer_max_length if self . tokenizer_max_length is not None else self . tokenizer . model_max_length
return self . tokenizer (
@ -413,16 +398,19 @@ class DreamBoothDataset(Dataset):
def __getitem__ ( self , index ) :
example = { }
# 读取 instance 图片并处理 EXIF 方向,保证训练输入方向一致
instance_image = Image . open ( self . instance_images_path [ index % self . num_instance_images ] )
instance_image = exif_transpose ( instance_image )
if instance_image . mode != " RGB " :
instance_image = instance_image . convert ( " RGB " )
example [ " instance_images " ] = self . image_transforms ( instance_image )
# instance prompt 的 input_ids 与 attention_mask
text_inputs = self . _tokenize ( self . instance_prompt )
example [ " instance_prompt_ids " ] = text_inputs . input_ids
example [ " instance_attention_mask " ] = text_inputs . attention_mask
# 先验保持时额外返回 class 图片与 class prompt 的 token
if self . class_data_root :
class_image = Image . open ( self . class_images_path [ index % self . num_class_images ] )
class_image = exif_transpose ( class_image )
@ -437,14 +425,7 @@ class DreamBoothDataset(Dataset):
return example
# -------------------------------------------------------------------------
# 功能模块:批处理拼接与张量规整
# 1) 将单条样本组成的列表拼接为 batch 字典,供训练循环直接使用
# 2) 将图像张量 stack 成 (B,C,H,W) 并转换为 float, 提高后续 VAE 兼容性
# 3) 将 input_ids 与 attention_mask 在 batch 维度 cat, 便于文本编码器计算
# 4) 先验保持模式下将 instance 与 class 在 batch 维度拼接,减少前向次数
# 5) 该模块不做任何损失与梯度计算,只负责打包输入数据结构
# -------------------------------------------------------------------------
# 批处理拼接:将样本列表组装为 batch, 并在先验保持时拼接 instance 与 class 以减少前向次数
def collate_fn ( examples , with_prior_preservation = False ) :
input_ids = [ example [ " instance_prompt_ids " ] for example in examples ]
pixel_values = [ example [ " instance_images " ] for example in examples ]
@ -455,6 +436,7 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values + = [ example [ " class_images " ] for example in examples ]
attention_mask + = [ example [ " class_attention_mask " ] for example in examples ]
# 图像张量 stack 为 (B, C, H, W),并确保是连续内存与 float 类型
pixel_values = torch . stack ( pixel_values ) . to ( memory_format = torch . contiguous_format ) . float ( )
input_ids = torch . cat ( input_ids , dim = 0 )
attention_mask = torch . cat ( attention_mask , dim = 0 )
@ -462,14 +444,7 @@ def collate_fn(examples, with_prior_preservation=False):
return { " input_ids " : input_ids , " pixel_values " : pixel_values , " attention_mask " : attention_mask }
# -------------------------------------------------------------------------
# 功能模块:生成 class 图像的提示词数据集
# 1) 该数据集用于先验保持时批量生成类别图像,提供固定 prompt
# 2) 每条样本返回 prompt 与索引,索引用于生成稳定的文件名
# 3) 与训练数据集分离,避免采样逻辑影响训练数据读取与增强
# 4) 支持多进程环境下由 accelerate 分配采样 batch, 提高生成效率
# 5) 该模块只在 with_prior_preservation 启用且 class 数据不足时使用
# -------------------------------------------------------------------------
# class 图像生成专用数据集:仅提供 prompt 与 index, 用于加速生成与落盘命名
class PromptDataset ( Dataset ) :
def __init__ ( self , prompt , num_samples ) :
self . prompt = prompt
@ -482,14 +457,7 @@ class PromptDataset(Dataset):
return { " prompt " : self . prompt , " index " : index }
# -------------------------------------------------------------------------
# 功能模块:判断预训练模型是否包含 VAE
# 1) 通过检查 vae/config.json 是否存在来决定是否加载 VAE
# 2) 同时支持本地目录与 Hub 结构,便于离线缓存模式运行
# 3) 若不存在 VAE 子目录,将跳过加载并在训练中使用像素空间输入
# 4) 该判断只发生在初始化阶段,不影响训练过程与日志记录
# 5) 对 Stable Diffusion 类模型通常都会包含 VAE, 属于常规路径
# -------------------------------------------------------------------------
# 判断模型是否包含 VAE: 用于兼容可能没有 vae 子目录的模型结构
def model_has_vae ( args ) :
config_file_name = Path ( " vae " , AutoencoderKL . config_name ) . as_posix ( )
if os . path . isdir ( args . pretrained_model_name_or_path ) :
@ -500,14 +468,7 @@ def model_has_vae(args):
return any ( file . rfilename == config_file_name for file in files_in_repo )
# -------------------------------------------------------------------------
# 功能模块:文本编码器前向
# 1) 将 input_ids 与 attention_mask 输入 text encoder 得到条件嵌入
# 2) 可选择是否启用 attention_mask, 以适配不同文本编码器行为
# 3) 输出的 prompt_embeds 作为 UNet 条件输入,影响生成语义与身份绑定
# 4) 该函数在训练循环中频繁调用,需要保持设备与 dtype 的一致性
# 5) 返回张量为 (B, T, D),后续会与 timestep 一起输入 UNet
# -------------------------------------------------------------------------
# 文本编码:得到 prompt_embeds, 作为 UNet 的条件输入,控制语义与身份绑定
def encode_prompt ( text_encoder , input_ids , attention_mask , text_encoder_use_attention_mask : bool ) :
text_input_ids = input_ids . to ( text_encoder . device )
if text_encoder_use_attention_mask :
@ -517,18 +478,13 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
return text_encoder ( text_input_ids , attention_mask = attention_mask , return_dict = False ) [ 0 ]
# -------------------------------------------------------------------------
# 功能模块:主训练流程
# 1) 负责构建 accelerate 环境、加载模型组件、准备数据与优化器
# 2) 支持先验保持:自动补足 class 图像并将 instance/class 合并训练
# 3) 训练循环中记录 loss、学习率与坐标指标, 输出 CSV 便于可视化分析
# 4) 训练结束后保存微调后的 pipeline 到 output_dir, 作为独立可用模型
# 5) 在保存完成后运行 validation, 仅用提示词进行文生图并将结果写入输出目录
# -------------------------------------------------------------------------
# 训练主流程:包含 class 数据补全、组件加载、训练循环、坐标记录、模型保存与训练后验证
def main ( args ) :
# 避免将 hub token 暴露到 wandb 等第三方日志系统中
if args . report_to == " wandb " and args . hub_token is not None :
raise ValueError ( " 不要同时使用 wandb 与 hub_token, 避免凭证泄露风险。 " )
# accelerate 项目配置:统一 output_dir 与 logging_dir, 便于多卡与断点保存
logging_dir = Path ( args . output_dir , args . logging_dir )
accelerator_project_config = ProjectConfiguration ( project_dir = args . output_dir , logging_dir = logging_dir )
@ -539,9 +495,11 @@ def main(args):
project_config = accelerator_project_config ,
)
# MPS 下关闭 AMP, 避免混精行为不一致导致训练异常
if torch . backends . mps . is_available ( ) :
accelerator . native_amp = False
# 初始化日志格式并打印 accelerate 状态,便于排查分布式配置问题
logging . basicConfig (
format = " %(asctime)s - %(levelname)s - %(name)s - %(message)s " ,
datefmt = " % m/ %d / % Y % H: % M: % S " ,
@ -549,21 +507,25 @@ def main(args):
)
logger . info ( accelerator . state , main_process_only = False )
# 主进程输出更多 warning, 非主进程尽量保持安静以减少干扰
if accelerator . is_local_main_process :
transformers . utils . logging . set_verbosity_warning ( )
warnings . filterwarnings ( " ignore " , category = UserWarning )
else :
transformers . utils . logging . set_verbosity_error ( )
# 设置随机种子,保证数据增强、噪声采样与验证结果可复现
if args . seed is not None :
set_seed ( args . seed )
# 先验保持:当 class 图片不足时,使用 base model 生成补齐并保存到 class_data_dir
if args . with_prior_preservation :
class_images_dir = Path ( args . class_data_dir )
class_images_dir . mkdir ( parents = True , exist_ok = True )
cur_class_images = len ( list ( class_images_dir . iterdir ( ) ) )
if cur_class_images < args . num_class_images :
# 生成 class 图片时可单独指定 dtype, 减少生成时的显存占用
torch_dtype = torch . float16 if accelerator . device . type == " cuda " else torch . float32
if args . prior_generation_precision == " fp32 " :
torch_dtype = torch . float32
@ -592,6 +554,7 @@ def main(args):
for example in tqdm ( sample_dataloader , desc = " 生成 class 图像 " , disable = not accelerator . is_local_main_process ) :
images = pipe ( example [ " prompt " ] ) . images
for i , image in enumerate ( images ) :
# 用图像内容 hash 防止同名冲突,并方便追溯生成来源
hash_image = insecure_hashlib . sha1 ( image . tobytes ( ) ) . hexdigest ( )
image_filename = class_images_dir / f " { example [ ' index ' ] [ i ] + cur_class_images } - { hash_image } .jpg "
image . save ( image_filename )
@ -600,6 +563,7 @@ def main(args):
if torch . cuda . is_available ( ) :
torch . cuda . empty_cache ( )
# 准备输出目录与 Hub 仓库,仅主进程执行以避免竞争写入
if accelerator . is_main_process :
os . makedirs ( args . output_dir , exist_ok = True )
if args . push_to_hub :
@ -611,6 +575,7 @@ def main(args):
else :
repo_id = None
# tokenizer 加载:优先使用显式指定的 tokenizer_name, 否则从模型目录的 tokenizer 子目录读取
if args . tokenizer_name :
tokenizer = AutoTokenizer . from_pretrained ( args . tokenizer_name , revision = args . revision , use_fast = False )
else :
@ -621,9 +586,10 @@ def main(args):
use_fast = False ,
)
# 组件加载: scheduler、text_encoder、可选 VAE、以及 UNet
text_encoder_cls = import_model_class_from_model_name_or_path ( args . pretrained_model_name_or_path , args . revision )
noise_scheduler = DDPMScheduler . from_pretrained ( args . pretrained_model_name_or_path , subfolder = " scheduler " )
text_encoder = text_encoder_cls . from_pretrained (
args . pretrained_model_name_or_path , subfolder = " text_encoder " , revision = args . revision , variant = args . variant
)
@ -638,11 +604,13 @@ def main(args):
args . pretrained_model_name_or_path , subfolder = " unet " , revision = args . revision , variant = args . variant
)
# unwrap: 用于从 accelerator 包装对象中拿到可保存的原始模型
def unwrap_model(model):
    """Return the raw, saveable module behind the training wrappers.

    First strips the accelerate wrapper via ``accelerator.unwrap_model``;
    then, if the result is a torch.compile-wrapped module, returns its
    ``_orig_mod`` so ``save_pretrained`` sees the original implementation.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        unwrapped = unwrapped._orig_mod
    return unwrapped
# 自定义 hook: 让 accelerator.save_state 按 diffusers 的子目录结构保存 unet/text_encoder
def save_model_hook ( models , weights , output_dir ) :
if accelerator . is_main_process :
for model in models :
@ -650,6 +618,7 @@ def main(args):
model . save_pretrained ( os . path . join ( output_dir , sub_dir ) )
weights . pop ( )
# 自定义 hook: 断点恢复时从 output_dir 读取 unet/text_encoder 并覆盖当前实例参数
def load_model_hook ( models , input_dir ) :
while len ( models ) > 0 :
model = models . pop ( )
@ -665,12 +634,15 @@ def main(args):
accelerator . register_save_state_pre_hook ( save_model_hook )
accelerator . register_load_state_pre_hook ( load_model_hook )
# VAE 不参与训练,仅用于编码到 latent 空间
if vae is not None :
vae . requires_grad_ ( False )
# 默认只训练 UNet, 若开启 train_text_encoder 则同时训练文本编码器
if not args . train_text_encoder :
text_encoder . requires_grad_ ( False )
# xformers: 开启后可显著降低注意力显存占用, 但需要正确安装依赖
if args . enable_xformers_memory_efficient_attention :
if not is_xformers_available ( ) :
raise ValueError ( " xformers 不可用,请确认安装成功。 " )
@ -680,17 +652,21 @@ def main(args):
logger . warning ( " xformers 0.0.16 在部分 GPU 上训练不稳定,建议升级。 " )
unet . enable_xformers_memory_efficient_attention ( )
# gradient checkpointing: 以计算换显存, 适合大模型与大分辨率训练
if args . gradient_checkpointing :
unet . enable_gradient_checkpointing ( )
if args . train_text_encoder :
text_encoder . gradient_checkpointing_enable ( )
# TF32: 在 Ampere 上可加速 matmul, 通常对训练稳定性影响较小
if args . allow_tf32 :
torch . backends . cuda . matmul . allow_tf32 = True
# scale_lr: 按总 batch 规模放大学习率,便于多卡/大 batch 配置保持等效训练
if args . scale_lr :
args . learning_rate = args . learning_rate * args . gradient_accumulation_steps * args . train_batch_size * accelerator . num_processes
# 优化器:可选 8-bit Adam 降低显存占用
optimizer_class = torch . optim . AdamW
if args . use_8bit_adam :
try :
@ -710,6 +686,7 @@ def main(args):
eps = args . adam_epsilon ,
)
# 数据集与 dataloader: 根据 with_prior_preservation 决定是否加载 class 数据
train_dataset = DreamBoothDataset (
instance_data_root = args . instance_data_dir ,
instance_prompt = args . instance_prompt ,
@ -730,12 +707,14 @@ def main(args):
num_workers = args . dataloader_num_workers ,
)
# 训练步数:若未指定 max_train_steps, 则由 epoch 与 dataloader 推导
overrode_max_train_steps = False
num_update_steps_per_epoch = math . ceil ( len ( train_dataloader ) / args . gradient_accumulation_steps )
if args . max_train_steps is None :
args . max_train_steps = args . num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
# scheduler: 基于总训练步数与 warmup 设置学习率曲线
lr_scheduler = get_scheduler (
args . lr_scheduler ,
optimizer = optimizer ,
@ -745,6 +724,7 @@ def main(args):
power = args . lr_power ,
)
# accelerate.prepare: 把模型、优化器、数据加载器与 scheduler 放入分布式与混精管理
if args . train_text_encoder :
unet , text_encoder , optimizer , train_dataloader , lr_scheduler = accelerator . prepare (
unet , text_encoder , optimizer , train_dataloader , lr_scheduler
@ -754,27 +734,33 @@ def main(args):
unet , optimizer , train_dataloader , lr_scheduler
)
# weight_dtype: 训练时模型权重与输入的 dtype, 用于混精与显存控制
weight_dtype = torch . float32
if accelerator . mixed_precision == " fp16 " :
weight_dtype = torch . float16
elif accelerator . mixed_precision == " bf16 " :
weight_dtype = torch . bfloat16
# VAE 始终与训练设备一致,并与 weight_dtype 对齐
if vae is not None :
vae . to ( accelerator . device , dtype = weight_dtype )
# 若不训练 text encoder, 则把它当作推理组件统一 cast 到混精 dtype
if not args . train_text_encoder :
text_encoder . to ( accelerator . device , dtype = weight_dtype )
# 重新计算 epoch: 在 prepare 之后 dataloader 规模可能变化,因此再推导一次更稳妥
num_update_steps_per_epoch = math . ceil ( len ( train_dataloader ) / args . gradient_accumulation_steps )
if overrode_max_train_steps :
args . max_train_steps = args . num_train_epochs * num_update_steps_per_epoch
args . num_train_epochs = math . ceil ( args . max_train_steps / num_update_steps_per_epoch )
# 初始化 tracker: 主进程写入配置, 便于实验复现与对比
if accelerator . is_main_process :
tracker_config = vars ( copy . deepcopy ( args ) )
accelerator . init_trackers ( " dreambooth " , config = tracker_config )
# coords_list: 用于记录训练过程的三维指标轨迹并写入 CSV
coords_list = [ ]
total_batch_size = args . train_batch_size * accelerator . num_processes * args . gradient_accumulation_steps
@ -795,6 +781,7 @@ def main(args):
disable = not accelerator . is_local_main_process ,
)
# 训练循环:每步完成 latent 构造、噪声添加、UNet 预测、loss 计算与反传更新
for epoch in range ( 0 , args . num_train_epochs ) :
unet . train ( )
if args . train_text_encoder :
@ -802,14 +789,17 @@ def main(args):
for step , batch in enumerate ( train_dataloader ) :
with accelerator . accumulate ( unet ) :
# 输入图像对齐 dtype, 避免混精下出现不必要的类型转换
pixel_values = batch [ " pixel_values " ] . to ( dtype = weight_dtype )
# 使用 VAE 将图像编码到 latent 空间,若无 VAE 则直接在像素空间训练
if vae is not None :
model_input = vae . encode ( batch [ " pixel_values " ] . to ( dtype = weight_dtype ) ) . latent_dist . sample ( )
model_input = model_input * vae . config . scaling_factor
else :
model_input = pixel_values
# 采样噪声,并可选叠加 offset noise 以改变噪声分布形态
if args . offset_noise :
noise = torch . randn_like ( model_input ) + 0.1 * torch . randn (
model_input . shape [ 0 ] , model_input . shape [ 1 ] , 1 , 1 , device = model_input . device
@ -817,13 +807,16 @@ def main(args):
else :
noise = torch . randn_like ( model_input )
# 为每个样本随机选择一个扩散时间步
bsz = model_input . shape [ 0 ]
timesteps = torch . randint (
0 , noise_scheduler . config . num_train_timesteps , ( bsz , ) , device = model_input . device
) . long ( )
# 前向扩散:给输入加噪声,形成 UNet 的输入
noisy_model_input = noise_scheduler . add_noise ( model_input , noise , timesteps )
# 文本编码:得到条件嵌入,用于指导 UNet 的去噪方向
encoder_hidden_states = encode_prompt (
text_encoder ,
batch [ " input_ids " ] ,
@ -831,11 +824,14 @@ def main(args):
text_encoder_use_attention_mask = args . text_encoder_use_attention_mask ,
)
# UNet 输出噪声预测(或速度预测),返回的第一个元素为预测张量
model_pred = unet ( noisy_model_input , timesteps , encoder_hidden_states , return_dict = False ) [ 0 ]
# 某些模型会同时预测方差,将通道拆分后仅保留噪声相关部分参与训练
if model_pred . shape [ 1 ] == 6 :
model_pred , _ = torch . chunk ( model_pred , 2 , dim = 1 )
# 根据 scheduler 的 prediction_type 构造监督目标
if noise_scheduler . config . prediction_type == " epsilon " :
target = noise
elif noise_scheduler . config . prediction_type == " v_prediction " :
@ -843,11 +839,13 @@ def main(args):
else :
raise ValueError ( f " 未知 prediction_type: { noise_scheduler . config . prediction_type } " )
# 先验保持: batch 被拼接为 instance+class, 因此这里按 batch 维拆开分别计算 prior loss
if args . with_prior_preservation :
model_pred , model_pred_prior = torch . chunk ( model_pred , 2 , dim = 0 )
target , target_prior = torch . chunk ( target , 2 , dim = 0 )
prior_loss = F . mse_loss ( model_pred_prior . float ( ) , target_prior . float ( ) , reduction = " mean " )
# loss: 默认 MSE; 若提供 snr_gamma 则用 SNR 加权以平衡不同时间步的贡献
if args . snr_gamma is None :
loss = F . mse_loss ( model_pred . float ( ) , target . float ( ) , reduction = " mean " )
else :
@ -863,6 +861,7 @@ def main(args):
if args . with_prior_preservation :
loss = loss + args . prior_loss_weight * prior_loss
# 训练轨迹记录:用模型输出统计量作为特征指标,配合 loss 形成三维轨迹
if args . coords_save_path is not None :
X_i_feature_norm = torch . norm ( model_pred . detach ( ) . float ( ) , p = 2 , dim = [ 1 , 2 , 3 ] ) . mean ( ) . item ( )
Y_i_feature_var = torch . var ( model_pred . detach ( ) . float ( ) ) . item ( )
@ -882,8 +881,10 @@ def main(args):
df . to_csv ( save_file_path , index = False )
logger . info ( f " 坐标已写入: { save_file_path } " )
# 反向传播: accelerate 负责混精与分布式同步
accelerator . backward ( loss )
# 梯度裁剪:仅在同步梯度时执行,避免对未同步的局部梯度产生偏差
if accelerator . sync_gradients :
params_to_clip = (
itertools . chain ( unet . parameters ( ) , text_encoder . parameters ( ) )
@ -892,18 +893,22 @@ def main(args):
)
accelerator . clip_grad_norm_ ( params_to_clip , args . max_grad_norm )
# 参数更新: optimizer 与 scheduler 逐步推进
optimizer . step ( )
lr_scheduler . step ( )
optimizer . zero_grad ( set_to_none = args . set_grads_to_none )
# 只有在完成一次“真实更新”后才推进 global_step 与进度条
if accelerator . sync_gradients :
progress_bar . update ( 1 )
global_step + = 1
# 按固定步数保存训练状态,用于断点恢复或中途回滚
if accelerator . is_main_process and global_step % args . checkpointing_steps == 0 :
accelerator . save_state ( args . output_dir )
logger . info ( f " 已保存训练状态到: { args . output_dir } " )
# 每步记录 loss 与 lr, 便于在 dashboard 中观察收敛曲线
logs = { " loss " : loss . detach ( ) . item ( ) , " lr " : lr_scheduler . get_last_lr ( ) [ 0 ] }
progress_bar . set_postfix ( * * logs )
accelerator . log ( logs , step = global_step )
@ -915,6 +920,7 @@ def main(args):
images = [ ]
if accelerator . is_main_process :
# 将训练好的组件写成独立 pipeline, 确保 output_dir 可直接用于推理
pipeline_args = { }
if not args . skip_save_text_encoder :
pipeline_args [ " text_encoder " ] = unwrap_model ( text_encoder )
@ -930,6 +936,7 @@ def main(args):
)
pipeline . save_pretrained ( args . output_dir )
# 释放训练对象,减少后续 validation 的显存压力
del unet
del optimizer
del lr_scheduler
@ -940,6 +947,7 @@ def main(args):
gc . collect ( )
torch . cuda . empty_cache ( )
# 训练结束后运行一次 txt2img 验证,并将结果保存到指定目录
images = run_validation_txt2img (
finetuned_model_dir = args . output_dir ,
prompt = args . validation_prompt ,
@ -959,6 +967,7 @@ def main(args):
image . save ( out_dir / f " validation_image_ { i } .png " )
logger . info ( f " validation 图像已保存到: { out_dir } " )
# 推送到 Hub: 写模型卡并上传 output_dir( 忽略 step/epoch 目录)
if args . push_to_hub :
save_model_card (
repo_id ,
@ -976,6 +985,7 @@ def main(args):
ignore_patterns = [ " step_* " , " epoch_* " ] ,
)
# 训练结束后再落一次坐标,保证最后一段数据不会因日志频率而遗漏
if args . coords_save_path is not None and coords_list :
df = pd . DataFrame ( coords_list , columns = [ " step " , " X_Feature_L2_Norm " , " Y_Feature_Variance " , " Z_LDM_Loss " ] )
save_file_path = Path ( args . coords_save_path )