diff --git a/src/backend/app/algorithms/evaluate/eva_gen_heatmap.py b/src/backend/app/algorithms/evaluate/eva_gen_heatmap.py index d461235..7506319 100644 --- a/src/backend/app/algorithms/evaluate/eva_gen_heatmap.py +++ b/src/backend/app/algorithms/evaluate/eva_gen_heatmap.py @@ -1,43 +1,38 @@ -"""Stable Diffusion 双模态注意力热力图差异可视化工具。 +""" +Stable Diffusion 双模态注意力热力图差异可视化工具。 """ -# 通用参数解析与文件路径管理 import argparse import os from pathlib import Path from typing import Dict, Any, List, Tuple -# 数值计算与深度学习依赖 import torch import torch.nn.functional as F import numpy as np import warnings -# 可视化依赖 import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec from matplotlib.colors import TwoSlopeNorm from mpl_toolkits.axes_grid1 import make_axes_locatable -# Diffusers 与 Transformers 依赖 from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler from diffusers.models.attention_processor import Attention from transformers import CLIPTokenizer -# 图像处理与数据读取 from PIL import Image from torchvision import transforms -# 抑制非必要的警告输出 +# 关闭与本脚本输出无关的常见警告,减少控制台干扰 warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) -# ============== 核心模块:双模态注意力捕获 ============== +# 双模态注意力捕获模块:在 U-Net 前向过程中同时收集交叉注意力与自注意力 class AttentionMapProcessor: - """自定义注意力处理器,用于同时捕获 U-Net 的交叉注意力和自注意力权重。""" - + # 自定义注意力处理器,用于拦截注意力计算并缓存注意力概率图 def __init__(self, pipeline: StableDiffusionPipeline): self.cross_attention_maps: Dict[str, List[torch.Tensor]] = {} self.self_attention_maps: Dict[str, List[torch.Tensor]] = {} @@ -53,21 +48,23 @@ class AttentionMapProcessor: encoder_hidden_states: torch.Tensor = None, attention_mask: torch.Tensor = None ) -> torch.Tensor: - """重载执行注意力计算并捕获权重 (支持 Self 和 Cross)。""" - + # 同时支持 Cross-Attention 与 Self-Attention,区别在于 Key/Value 的来源 is_cross = encoder_hidden_states is not None sequence_input = encoder_hidden_states if is_cross else hidden_states - + + # 按 diffusers 的注意力实现方式构造 Q/K/V query = attn.to_q(hidden_states) key = attn.to_k(sequence_input) value = attn.to_v(sequence_input) + # 将多头维度展开到 batch 维,便于矩阵乘 query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) - + + # 计算缩放点积注意力分数 attention_scores = torch.baddbmm( torch.empty( - query.shape[0], query.shape[1], key.shape[1], + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device ), query, @@ -76,44 +73,52 @@ class AttentionMapProcessor: alpha=attn.scale, ) + # softmax 得到注意力概率,并缓存到 CPU 侧用于后续聚合 attention_probs = attention_scores.softmax(dim=-1) layer_name = self.current_layer_name map_to_store = attention_probs.detach().cpu() + # 按层名分别记录交叉注意力与自注意力,便于之后按层聚合 if is_cross: if layer_name not in self.cross_attention_maps: self.cross_attention_maps[layer_name] = [] self.cross_attention_maps[layer_name].append(map_to_store) else: - # 内存保护:仅捕获中低分辨率层的自注意力 (防止 4096*4096 矩阵爆内存) - spatial_size = map_to_store.shape[-2] - if spatial_size <= 1024: + # 自注意力矩阵在高分辨率层会非常大,这里仅保留较小规模层以避免内存问题 + spatial_size = map_to_store.shape[-2] + if spatial_size <= 1024: if layer_name not in self.self_attention_maps: self.self_attention_maps[layer_name] = [] self.self_attention_maps[layer_name].append(map_to_store) + # 按注意力权重加权求和并回到原始维度,继续 U-Net 的后续计算 value = attn.head_to_batch_dim(value) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) + # 线性层与 dropout 等输出映射,与原 Attention 模块保持一致 hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) return hidden_states def _set_processors(self): + # 
遍历 U-Net 中所有 Attention 模块,将其 processor 替换为可记录层名的包装调用 for name, module in self.pipeline.unet.named_modules(): if isinstance(module, Attention): if 'attn1' in name or 'attn2' in name: self.original_processors[name] = module.processor + def set_layer_name(current_name): def new_call(*args, **kwargs): self.current_layer_name = current_name return self.__call__(*args, **kwargs) return new_call + module.processor = set_layer_name(name) def remove(self): + # 还原所有 Attention 模块的原始 processor,并清空缓存数据 for name, original_processor in self.original_processors.items(): module = self.pipeline.unet.get_submodule(name) module.processor = original_processor @@ -121,7 +126,7 @@ class AttentionMapProcessor: self.self_attention_maps = {} -# ============== 聚合逻辑 ============== +# 注意力图聚合模块:将多层、多步的注意力数据统一聚合到固定大小的 2D 热力图 def aggregate_cross_attention( attention_maps: Dict[str, List[torch.Tensor]], @@ -129,7 +134,7 @@ def aggregate_cross_attention( target_word: str, input_ids: torch.Tensor ) -> np.ndarray: - """聚合交叉注意力:关注 Prompt 中的特定 Target Word。""" + # 将 Prompt 的 token 切分结果与目标词进行匹配,找到目标词对应的 token 索引 prompt_tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze().cpu().tolist()) target_lower = target_word.lower() target_indices = [] @@ -140,20 +145,25 @@ def aggregate_cross_attention( (target_lower in cleaned_token or cleaned_token.startswith(target_lower))): target_indices.append(i) + # 未命中目标词时返回全零图,避免后续流程崩溃 if not target_indices: print(f"[WARN] Cross-Attn: 目标词汇 '{target_word}' 未识别。") return np.zeros((64, 64)) all_attention_data = [] - TARGET_SPATIAL_SIZE = 4096 + TARGET_SPATIAL_SIZE = 4096 TARGET_MAP_SIZE = 64 + # 逐层将注意力概率进行时间步平均,再对目标 token 通道求和得到空间关注强度 for layer_name, step_maps in attention_maps.items(): - if not step_maps: continue + if not step_maps: + continue avg_map = torch.stack(step_maps).mean(dim=0) - if avg_map.dim() == 4: avg_map = avg_map.squeeze(0) + if avg_map.dim() == 4: + avg_map = avg_map.squeeze(0) target_map = avg_map[:, :, target_indices].sum(dim=-1).mean(dim=0).float() + # 不同层的空间分辨率不同,统一插值到固定尺寸以便跨层融合 if target_map.shape[0] != TARGET_SPATIAL_SIZE: map_size = int(np.sqrt(target_map.shape[0])) map_2d = target_map.reshape(map_size, map_size).unsqueeze(0).unsqueeze(0) @@ -162,8 +172,10 @@ def aggregate_cross_attention( else: all_attention_data.append(target_map) - if not all_attention_data: return np.zeros((64, 64)) + if not all_attention_data: + return np.zeros((64, 64)) + # 跨层求和并归一化到 0-1,便于可视化对比 final_map_flat = torch.stack(all_attention_data).sum(dim=0).cpu().numpy() final_map_flat = final_map_flat / (final_map_flat.max() + 1e-6) return final_map_flat.reshape(TARGET_MAP_SIZE, TARGET_MAP_SIZE) @@ -172,82 +184,56 @@ def aggregate_cross_attention( def aggregate_self_attention( attention_maps: Dict[str, List[torch.Tensor]] ) -> np.ndarray: - """聚合自注意力:计算高频空间能量 (Laplacian High-Frequency Energy)。 - - 原理: - 风格和纹理通常体现为注意力图中的高频变化。 - 通过对每个 Query 的 Attention Map 应用拉普拉斯算子(Laplacian Kernel), - 我们可以提取出那些变化剧烈的区域(边缘、纹理接缝)。 - 最后聚合这些高频能量,得到的图在空间结构上与原图对齐,但亮度代表了“纹理/风格复杂度”。 - """ + # 将自注意力矩阵转为与空间对齐的强度图,这里使用拉普拉斯算子提取高频能量作为纹理强度代理 all_attention_data = [] TARGET_MAP_SIZE = 64 - - # 定义拉普拉斯卷积核用于提取高频信息 + laplacian_kernel = torch.tensor([ - [0, 1, 0], - [1, -4, 1], + [0, 1, 0], + [1, -4, 1], [0, 1, 0] ], dtype=torch.float32).view(1, 1, 3, 3) + # 逐层对自注意力矩阵进行时间步与多头平均,再对每个 query 的注意力图做高频响应统计 for layer_name, step_maps in attention_maps.items(): - if not step_maps: continue - - # [Heads, H*W, H*W] -> [H*W, H*W] 取平均 + if not step_maps: + continue + avg_matrix = torch.stack(step_maps).mean(dim=0).mean(dim=0).float() - - # 
获取当前层尺寸 + current_pixels = avg_matrix.shape[0] map_size = int(np.sqrt(current_pixels)) - - # 如果尺寸太小,高频信息没有意义,跳过极小层 + + # 极小尺度的注意力图通常缺少有效纹理结构信息,这里直接跳过 if map_size < 16: continue - # 重塑为图像形式: [Batch(Pixels), Channels(1), H, W] - # 这里我们将 avg_matrix 视为:对于每一个 query pixel (行),它关注的 spatial map (列) - # 我们想知道每个 pixel 关注的区域是不是包含很多高频纹理 - attn_maps = avg_matrix.reshape(current_pixels, 1, map_size, map_size) # [N, 1, H, W] - - # 将 Kernel 移到同一设备 + attn_maps = avg_matrix.reshape(current_pixels, 1, map_size, map_size) + kernel = laplacian_kernel.to(avg_matrix.device) - - # 批量卷积计算高频响应 (High-Pass Filter) - # padding=1 保持尺寸不变 + + # 对每个 query 的空间注意力图做拉普拉斯卷积,得到高频响应 high_freq_response = F.conv2d(attn_maps, kernel, padding=1) - - # 计算能量 (取绝对值或平方),这里取绝对值代表梯度的强度 + + # 用绝对值表示高频强度,并对每个 query 累计其响应作为空间分数 high_freq_energy = torch.abs(high_freq_response) - - # 现在我们得到了 [N, 1, H, W] 的高频能量图。 - # 我们需要将其聚合回一张 [H, W] 的图。 - # 含义:对于图像上的位置 (i, j),其作为 Query 时,所关注的区域包含了多少高频信息? - # 或者:作为 Key 时,它贡献了多少高频信息? - - # 这里采用 "Query-based Aggregation": - # 计算每个 Query pixel 对高频信息的总响应 - # shape: [N, 1, H, W] -> sum(dim=(2,3)) -> [N] - # 这表示:位置 N 的像素,其注意力主要集中在高频纹理区域的程度。 - spatial_score_flat = high_freq_energy.sum(dim=(2, 3)).squeeze() # [H*W] - - # 归一化这一层的分数,防止数值爆炸 + spatial_score_flat = high_freq_energy.sum(dim=(2, 3)).squeeze() + + # 层内归一化避免不同层的数值尺度影响跨层融合 spatial_score_flat = spatial_score_flat / (spatial_score_flat.max() + 1e-6) - - # 重塑为 2D 空间图 + map_2d = spatial_score_flat.reshape(map_size, map_size).unsqueeze(0).unsqueeze(0) - - # 插值统一到目标尺寸 + resized = F.interpolate(map_2d, size=(TARGET_MAP_SIZE, TARGET_MAP_SIZE), mode='bilinear', align_corners=False) all_attention_data.append(resized.squeeze().flatten()) - if not all_attention_data: return np.zeros((64, 64)) + if not all_attention_data: + return np.zeros((64, 64)) - # 聚合所有层 + # 跨层求和并做 0-1 归一化,得到最终纹理强度热力图 final_map_flat = torch.stack(all_attention_data).sum(dim=0).cpu().numpy() - - # 最终归一化,保持 0-1 范围,方便可视化 final_map_flat = (final_map_flat - final_map_flat.min()) / (final_map_flat.max() - final_map_flat.min() + 1e-6) - + return final_map_flat.reshape(TARGET_MAP_SIZE, TARGET_MAP_SIZE) @@ -257,7 +243,7 @@ def get_dual_attention_maps( prompt_text: str, target_word: str ) -> Tuple[Image.Image, np.ndarray, np.ndarray]: - """同时获取 Cross-Attention 和 Self-Attention 热力图。""" + # 对输入图像进行编码,并在少量时间步上运行 U-Net 来提取注意力分布 print(f"\n-> 正在处理图片: {Path(image_path).name}") image = Image.open(image_path).convert("RGB").resize((512, 512)) image_tensor = transforms.Compose([ @@ -267,35 +253,40 @@ def get_dual_attention_maps( with torch.no_grad(): latent = (pipeline.vae.encode(image_tensor).latent_dist.sample() * pipeline.vae.config.scaling_factor) - text_input = pipeline.tokenizer(prompt_text, padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt") + text_input = pipeline.tokenizer( + prompt_text, + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt" + ) with torch.no_grad(): prompt_embeds = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] scheduler = pipeline.scheduler scheduler.set_timesteps(50, device=pipeline.device) - semantic_steps = scheduler.timesteps[:10] + semantic_steps = scheduler.timesteps[:10] processor = AttentionMapProcessor(pipeline) try: with torch.no_grad(): for t in semantic_steps: - pipeline.unet(latent, t, prompt_embeds, return_dict=False) + pipeline.unet(latent, t, prompt_embeds, return_dict=False) cross_map_raw = aggregate_cross_attention( 
processor.cross_attention_maps, pipeline.tokenizer, target_word, text_input.input_ids ) self_map_raw = aggregate_self_attention(processor.self_attention_maps) - + except Exception as e: print(f"[ERROR] 注意力聚合失败: {e}") - # import traceback - # traceback.print_exc() cross_map_raw = np.zeros((64, 64)) self_map_raw = np.zeros((64, 64)) finally: processor.remove() + # 将 64x64 热力图上采样到与原图一致的空间大小,便于直接叠加或对比展示 def upsample(map_np): pil_img = Image.fromarray((map_np * 255).astype(np.uint8)) return np.array(pil_img.resize(image.size, resample=Image.Resampling.LANCZOS)) / 255.0 @@ -304,6 +295,7 @@ def get_dual_attention_maps( def main(): + # 解析运行参数并生成对比报告图像 parser = argparse.ArgumentParser(description="SD 双模态注意力差异分析报告") parser.add_argument("--model_path", type=str, required=True, help="Stable Diffusion 模型路径") parser.add_argument("--image_path_a", type=str, required=True, help="Clean Image") @@ -314,90 +306,100 @@ def main(): args = parser.parse_args() print(f"--- 正在生成 Museguard 双模态分析报告 (High-Freq Energy Mode) ---") - + device = 'cuda' if torch.cuda.is_available() else 'cpu' dtype = torch.float16 if device == 'cuda' else torch.float32 try: pipe = StableDiffusionPipeline.from_pretrained( - args.model_path, torch_dtype=dtype, local_files_only=True, safety_checker=None, + args.model_path, + torch_dtype=dtype, + local_files_only=True, + safety_checker=None, scheduler=DPMSolverMultistepScheduler.from_pretrained(args.model_path, subfolder="scheduler") ).to(device) except Exception as e: - print(f"[ERROR] 模型加载失败: {e}"); return + print(f"[ERROR] 模型加载失败: {e}") + return img_A, cross_A, self_A = get_dual_attention_maps(pipe, args.image_path_a, args.prompt_text, args.target_word) img_B, cross_B, self_B = get_dual_attention_maps(pipe, args.image_path_b, args.prompt_text, args.target_word) diff_cross = cross_A - cross_B l2_cross = np.linalg.norm(diff_cross) - + diff_self = self_A - self_B l2_self = np.linalg.norm(diff_self) - + print(f"\nCross-Attn L2 Diff: {l2_cross:.4f}") print(f"Self-Attn L2 Diff: {l2_self:.4f}") - # ---------------- 绘制增强版报告 ---------------- + # 使用统一布局展示原图、两类注意力图及其差分图,并输出为单张报告图片 plt.rcParams.update({'font.family': 'serif', 'mathtext.fontset': 'cm'}) - + fig = plt.figure(figsize=(14, 22), dpi=100) gs = gridspec.GridSpec(4, 4, figure=fig, height_ratios=[1, 1, 1, 1.2], hspace=0.3, wspace=0.1) - # Row 1: Images ax_img_a = fig.add_subplot(gs[0, 0:2]) ax_img_b = fig.add_subplot(gs[0, 2:4]) - ax_img_a.imshow(img_A); ax_img_a.set_title("Clean Image ($X$)", fontsize=14); ax_img_a.axis('off') - ax_img_b.imshow(img_B); ax_img_b.set_title("Noisy Image ($X'$)", fontsize=14); ax_img_b.axis('off') + ax_img_a.imshow(img_A) + ax_img_a.set_title("Clean Image ($X$)", fontsize=14) + ax_img_a.axis('off') + ax_img_b.imshow(img_B) + ax_img_b.set_title("Noisy Image ($X'$)", fontsize=14) + ax_img_b.axis('off') - # Row 2: Cross Attention ax_cA = fig.add_subplot(gs[1, 0:2]) ax_cB = fig.add_subplot(gs[1, 2:4]) ax_cA.imshow(cross_A, cmap='jet', vmin=0, vmax=1) - ax_cA.set_title(f"Cross-Attn ($M^{{cross}}_X$)\nTarget: \"{args.target_word}\"", fontsize=14); ax_cA.axis('off') + ax_cA.set_title(f"Cross-Attn ($M^{{cross}}_X$)\nTarget: \"{args.target_word}\"", fontsize=14) + ax_cA.axis('off') im_cB = ax_cB.imshow(cross_B, cmap='jet', vmin=0, vmax=1) - ax_cB.set_title(f"Cross-Attn ($M^{{cross}}_{{X'}}$)", fontsize=14); ax_cB.axis('off') - + ax_cB.set_title(f"Cross-Attn ($M^{{cross}}_{{X'}}$)", fontsize=14) + ax_cB.axis('off') + divider = make_axes_locatable(ax_cB) cax = divider.append_axes("right", size="5%", pad=0.05) 
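(Aside, not part of the patch.) The Row-4 difference panels plot a signed map, so they use a zero-centered diverging normalization: zero change stays neutral while gains and losses fall on opposite ends of the colormap. A minimal, self-contained sketch of that technique, with random arrays standing in for the real 64x64 heatmaps:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm

# Stand-ins for the aggregated cross-/self-attention heatmaps (values in [0, 1]).
heat_clean = np.random.rand(64, 64)
heat_noisy = np.random.rand(64, 64)

diff = heat_clean - heat_noisy            # signed: positive = attention gained, negative = lost
l2 = np.linalg.norm(diff)                 # same Frobenius-norm summary the report prints

vmax = max(np.max(np.abs(diff)), 0.1)     # floor keeps the color scale readable for tiny diffs
norm = TwoSlopeNorm(vmin=-vmax, vcenter=0.0, vmax=vmax)

fig, ax = plt.subplots()
im = ax.imshow(diff, cmap="coolwarm", norm=norm)
ax.set_title(f"L2 diff: {l2:.4f}")
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
fig.savefig("diff_demo.png", bbox_inches="tight")
```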
fig.colorbar(im_cB, cax=cax, label='Semantic Alignment') - # Row 3: Self Attention (High-Frequency Energy Mode) ax_sA = fig.add_subplot(gs[2, 0:2]) ax_sB = fig.add_subplot(gs[2, 2:4]) - - # 恢复使用与 Cross Attention 一致的 'jet' colormap + ax_sA.imshow(self_A, cmap='jet', vmin=0, vmax=1) - ax_sA.set_title(f"Self-Attn ($M^{{self}}_X$)\nHigh-Freq Energy (Texture)", fontsize=14); ax_sA.axis('off') - + ax_sA.set_title(f"Self-Attn ($M^{{self}}_X$)\nHigh-Freq Energy (Texture)", fontsize=14) + ax_sA.axis('off') + im_sB = ax_sB.imshow(self_B, cmap='jet', vmin=0, vmax=1) - ax_sB.set_title(f"Self-Attn ($M^{{self}}_{{X'}}$)", fontsize=14); ax_sB.axis('off') - + ax_sB.set_title(f"Self-Attn ($M^{{self}}_{{X'}}$)", fontsize=14) + ax_sB.axis('off') + divider = make_axes_locatable(ax_sB) cax = divider.append_axes("right", size="5%", pad=0.05) fig.colorbar(im_sB, cax=cax, label='Texture Intensity') - # Row 4: Differences ax_diff_c = fig.add_subplot(gs[3, 0:2]) ax_diff_s = fig.add_subplot(gs[3, 2:4]) vmax_c = max(np.max(np.abs(diff_cross)), 0.1) norm_c = TwoSlopeNorm(vmin=-vmax_c, vcenter=0., vmax=vmax_c) im_dc = ax_diff_c.imshow(diff_cross, cmap='coolwarm', norm=norm_c) - ax_diff_c.set_title(f"Cross Diff ($\Delta_{{cross}}$)\n$L_2$: {l2_cross:.4f}", fontsize=14); ax_diff_c.axis('off') + ax_diff_c.set_title(f"Cross Diff ($\\Delta_{{cross}}$)\n$L_2$: {l2_cross:.4f}", fontsize=14) + ax_diff_c.axis('off') plt.colorbar(im_dc, ax=ax_diff_c, fraction=0.046, pad=0.04) vmax_s = max(np.max(np.abs(diff_self)), 0.1) norm_s = TwoSlopeNorm(vmin=-vmax_s, vcenter=0., vmax=vmax_s) im_ds = ax_diff_s.imshow(diff_self, cmap='coolwarm', norm=norm_s) - ax_diff_s.set_title(f"Self Diff ($\Delta_{{self}}$)\n$L_2$: {l2_self:.4f}", fontsize=14); ax_diff_s.axis('off') + ax_diff_s.set_title(f"Self Diff ($\\Delta_{{self}}$)\n$L_2$: {l2_self:.4f}", fontsize=14) + ax_diff_s.axis('off') plt.colorbar(im_ds, ax=ax_diff_s, fraction=0.046, pad=0.04) fig.suptitle(f"Museguard: Dual-Mode Analysis (High-Freq Energy)", fontsize=20, fontweight='bold', y=0.92) - + out_path = Path(args.output_dir) / "dual_heatmap_report.png" out_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(out_path, bbox_inches='tight', facecolor='white') print(f"\n报告已保存至: {out_path}") + if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/backend/app/algorithms/evaluate/eva_gen_nums.py b/src/backend/app/algorithms/evaluate/eva_gen_nums.py index 9e69f09..ddf38f8 100644 --- a/src/backend/app/algorithms/evaluate/eva_gen_nums.py +++ b/src/backend/app/algorithms/evaluate/eva_gen_nums.py @@ -1,9 +1,6 @@ -"""图像生成质量多维度评估工具 (专业重构版)。 - -本脚本用于对比评估两组图像(Clean vs Perturbed)的生成质量。 +""" +用于对比评估两组图像(Clean vs Perturbed)的生成质量。 支持生成包含指标对比表和深度差异分析的 PNG 报告。 - -Style Guide: Google Python Style Guide """ import os @@ -27,15 +24,11 @@ from facenet_pytorch import MTCNN, InceptionResnetV1 from piq import ssim, psnr import torch_fidelity as fid -# 抑制非必要的警告输出 +# 关闭与评估过程无关的常见警告,避免影响关键信息阅读 warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) -# ----------------------------------------------------------------------------- -# 全局配置与样式 -# ----------------------------------------------------------------------------- - -# Matplotlib LaTeX 风格配置 +# 全局样式配置:统一图表字体、数学公式风格与负号显示效果 plt.rcParams.update({ 'font.family': 'serif', 'font.serif': ['DejaVu Serif', 'Times New Roman', 'serif'], @@ -43,7 +36,7 @@ plt.rcParams.update({ 'axes.unicode_minus': False }) -# 指标元数据配置:定义指标目标方向和分析阈值 +# 指标配置:给出每个指标的优劣方向以及用于分级判断的阈值 METRIC_ANALYSIS_META = { 
'FID': {'higher_is_better': False, 'th': [2.0, 10.0, 30.0]}, 'SSIM': {'higher_is_better': True, 'th': [0.01, 0.05, 0.15]}, @@ -52,13 +45,12 @@ METRIC_ANALYSIS_META = { 'CLIP_IQS': {'higher_is_better': True, 'th': [0.01, 0.03, 0.08]}, 'BRISQUE': {'higher_is_better': False, 'th': [2.0, 5.0, 10.0]}, } -# 用于综合分析的降级权重 + +# 综合结论中用于累加的权重,用于把分级差异映射成总体降级强度 ANALYSIS_WEIGHTS = {'Severe': 3, 'Significant': 2, 'Slight': 1, 'Negligible': 0} -# ----------------------------------------------------------------------------- -# 模型加载 (惰性加载或全局预加载) -# ----------------------------------------------------------------------------- +# 模型加载模块:在脚本启动时尝试预加载 CLIP,失败时自动降级为不计算该项指标 try: CLIP_MODEL, CLIP_PREPROCESS = clip.load('ViT-B/32', 'cuda') @@ -67,8 +59,9 @@ except Exception as e: print(f"[Warning] CLIP 模型加载失败: {e}") CLIP_MODEL, CLIP_PREPROCESS = None, None + def _get_clip_text_features(text: str) -> torch.Tensor: - """辅助函数:获取文本的 CLIP 特征。""" + # 将文本编码为 CLIP 特征并归一化,用于后续与图像特征计算相似度 if CLIP_MODEL is None: return None tokens = clip.tokenize(text).to('cuda') @@ -77,31 +70,19 @@ def _get_clip_text_features(text: str) -> torch.Tensor: features /= features.norm(dim=-1, keepdim=True) return features -# ----------------------------------------------------------------------------- -# 核心计算逻辑 -# ----------------------------------------------------------------------------- + +# 指标计算模块:对两个图像集合计算多项指标,用于后续报告展示与差异分析 def calculate_metrics( ref_dir: str, gen_dir: str, image_size: int = 512 ) -> Dict[str, float]: - """计算图像集之间的多项质量评估指标。 - - 包括 FDS, SSIM, PSNR, CLIP_IQS, FID。 - - Args: - ref_dir: 参考图片目录路径。 - gen_dir: 生成图片目录路径。 - image_size: 图像处理尺寸。 - - Returns: - 包含各项指标名称和数值的字典。若目录无效返回空字典。 - """ + # 从目录读取图像并在同一设备上计算 FDS、SSIM、PSNR、CLIP_IQS 与 FID metrics = {} - - # 1. 数据加载 + def load_images(directory): + # 读取目录下常见格式图像并转换为 RGB,忽略无法打开的文件 imgs = [] if os.path.exists(directory): for f in os.listdir(directory): @@ -116,6 +97,7 @@ def calculate_metrics( ref_imgs = load_images(ref_dir) gen_imgs = load_images(gen_dir) + # 若任一集合为空则直接返回,避免后续指标计算出错 if not ref_imgs or not gen_imgs: print(f"[Error] 图片加载失败或目录为空: \nRef: {ref_dir}\nGen: {gen_dir}") return {} @@ -123,12 +105,13 @@ def calculate_metrics( device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') with torch.no_grad(): - # --- FDS (Face Detection Similarity) --- + # FDS:使用人脸检测与人脸特征模型度量身份相似度 print(">>> 计算 FDS...") mtcnn = MTCNN(image_size=image_size, margin=0, device=device) resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device) - + def get_face_embeds(img_list): + # 对每张图做检测与对齐,成功则提取人脸特征并收集为张量 embeds = [] for img in img_list: face = mtcnn(img) @@ -140,7 +123,6 @@ def calculate_metrics( gen_embeds = get_face_embeds(gen_imgs) if ref_embeds is not None and gen_embeds is not None: - # 计算生成集每张脸与参考集所有脸的余弦相似度均值 sims = [] for g_emb in gen_embeds: sim = torch.cosine_similarity(g_emb, ref_embeds).mean() @@ -149,39 +131,35 @@ def calculate_metrics( else: metrics['FDS'] = 0.0 - # 清理显存 + # 释放中间模型并回收显存,避免后续指标计算显存不足 del mtcnn, resnet if torch.cuda.is_available(): torch.cuda.empty_cache() - # --- SSIM & PSNR --- + # SSIM 与 PSNR:以参考集合为基准,对每张生成图计算与参考集合的平均相似度 print(">>> 计算 SSIM & PSNR...") tfm = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor() ]) - - # 将参考集堆叠为 [N, C, H, W] + ref_tensor = torch.stack([tfm(img) for img in ref_imgs]).to(device) - + ssim_accum, psnr_accum = 0.0, 0.0 for img in gen_imgs: - gen_tensor = tfm(img).unsqueeze(0).to(device) # [1, C, H, W] - - # 扩展维度以匹配参考集 + gen_tensor = tfm(img).unsqueeze(0).to(device) gen_expanded = 
gen_tensor.expand_as(ref_tensor) - - # 计算单张生成图相对于整个参考集的平均结构相似度 + val_ssim = ssim(gen_expanded, ref_tensor, data_range=1.0) val_psnr = psnr(gen_expanded, ref_tensor, data_range=1.0) - + ssim_accum += val_ssim.item() psnr_accum += val_psnr.item() metrics['SSIM'] = ssim_accum / len(gen_imgs) metrics['PSNR'] = psnr_accum / len(gen_imgs) - # --- CLIP IQS --- + # CLIP_IQS:用“good image”作为文本锚点,计算生成图与该文本概念的相似度均值 print(">>> 计算 CLIP IQS...") if CLIP_MODEL: iqs_accum = 0.0 @@ -195,7 +173,7 @@ def calculate_metrics( else: metrics['CLIP_IQS'] = np.nan - # --- FID --- + # FID:使用 torch_fidelity 计算两个目录的分布距离 print(">>> 计算 FID...") try: fid_res = fid.calculate_metrics( @@ -214,23 +192,16 @@ def calculate_metrics( def run_brisque_cleanly(img_dir: str) -> float: - """使用 subprocess 和临时目录优雅地执行外部 BRISQUE 脚本。 - - Args: - img_dir: 图像目录路径。 - - Returns: - BRISQUE 分数,若失败返回 NaN。 - """ + # 通过子进程调用外部 BRISQUE 脚本,并用临时目录承载其输出文件 print(f">>> 计算 BRISQUE (External)...") - + script_path = Path(__file__).parent / 'libsvm' / 'python' / 'brisquequality.py' if not script_path.exists(): print(f"[Error] 找不到 BRISQUE 脚本: {script_path}") return np.nan abs_img_dir = os.path.abspath(img_dir) - + with tempfile.TemporaryDirectory() as temp_dir: try: cmd = [ @@ -238,17 +209,16 @@ def run_brisque_cleanly(img_dir: str) -> float: abs_img_dir, temp_dir ] - - # 在脚本所在目录执行 + subprocess.run( - cmd, - cwd=script_path.parent, - check=True, - stdout=subprocess.PIPE, + cmd, + cwd=script_path.parent, + check=True, + stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - - # 读取临时生成的日志文件 + + # 从临时目录读取脚本写出的 log.txt,并解析其中的最终分数 log_file = Path(temp_dir) / 'log.txt' if log_file.exists(): content = log_file.read_text(encoding='utf-8').strip() @@ -258,47 +228,33 @@ def run_brisque_cleanly(img_dir: str) -> float: return float(content) else: return np.nan - + except Exception as e: print(f"[Error] BRISQUE 执行出错: {e}") return np.nan -# ----------------------------------------------------------------------------- -# 报告可视化与分析逻辑 -# ----------------------------------------------------------------------------- +# 报告生成模块:对指标差异进行分级解释,并渲染成包含样例图与表格的 PNG 报告 def analyze_metric_diff( - metric_name: str, - clean_val: float, + metric_name: str, + clean_val: float, pert_val: float ) -> Tuple[str, str, str]: - """生成科学的分级差异分析文本。 - - Args: - metric_name: 指标名称。 - clean_val: 干净图得分。 - pert_val: 扰动图得分。 - - Returns: - (表头箭头符号, 差异描述文本, 状态等级) - """ - + # 根据指标配置计算差异,并输出用于表格与文本解释的分析结果 cfg = METRIC_ANALYSIS_META.get(metric_name) if not cfg: return "-", "Configuration not found.", "Negligible" diff = pert_val - clean_val abs_diff = abs(diff) - - # 判定好坏: + is_better = (cfg['higher_is_better'] and diff > 0) or (not cfg['higher_is_better'] and diff < 0) is_worse = not is_better - - # 确定程度 + th = cfg['th'] if abs_diff < th[0]: - degree = "Negligible" + degree = "Negligible" elif abs_diff < th[1]: degree = "Slight" elif abs_diff < th[2]: @@ -306,9 +262,8 @@ def analyze_metric_diff( else: degree = "Severe" - # 组装文案 header_arrow = r"$\uparrow$" if cfg['higher_is_better'] else r"$\downarrow$" - + if degree == "Negligible": analysis_text = f"Negligible change (diff < {th[0]:.4f})." 
elif is_worse: @@ -320,31 +275,29 @@ def analyze_metric_diff( def generate_visual_report( - ref_dir: str, - clean_dir: str, - pert_dir: str, - clean_metrics: Dict, - pert_metrics: Dict, + ref_dir: str, + clean_dir: str, + pert_dir: str, + clean_metrics: Dict, + pert_metrics: Dict, output_path: str ): - """渲染并保存专业对比分析报告 (PNG)。""" - + # 从三个目录各取一张样例图,并将指标对比表与差异解释一起绘制到同一张图中 def get_sample(d): - if not os.path.exists(d): return None, "N/A" - files = [f for f in os.listdir(d) if f.lower().endswith(('.png','.jpg'))] - if not files: return None, "Empty" + if not os.path.exists(d): + return None, "N/A" + files = [f for f in os.listdir(d) if f.lower().endswith(('.png', '.jpg'))] + if not files: + return None, "Empty" return Image.open(os.path.join(d, files[0])).convert("RGB"), files[0] img_ref, name_ref = get_sample(ref_dir) img_clean, name_clean = get_sample(clean_dir) img_pert, name_pert = get_sample(pert_dir) - # 布局设置 - # 增加高度以容纳文本 - fig = plt.figure(figsize=(12, 16.5), dpi=120) + fig = plt.figure(figsize=(12, 16.5), dpi=120) gs = gridspec.GridSpec(3, 2, height_ratios=[1, 1, 1.5], hspace=0.25, wspace=0.1) - # 1. 图像展示区 ax_ref = fig.add_subplot(gs[0, :]) if img_ref: ax_ref.imshow(img_ref) @@ -363,80 +316,73 @@ def generate_visual_report( ax_p.set_title(f"Perturbed Output ($Y'$)\n{name_pert}", fontsize=12, fontweight='bold', pad=10) ax_p.axis('off') - # 2. 数据表格与分析区 ax_data = fig.add_subplot(gs[2, :]) ax_data.axis('off') metrics_list = ['FID', 'SSIM', 'PSNR', 'FDS', 'CLIP_IQS', 'BRISQUE'] table_data = [] analysis_lines = [] - + degradation_score = 0 - - # 遍历指标生成数据和分析 + + # 为每个指标生成表格行,并收集对应的差异解释文本 for m in metrics_list: c_val = clean_metrics.get(m, np.nan) p_val = pert_metrics.get(m, np.nan) - + c_str = f"{c_val:.4f}" if not np.isnan(c_val) else "N/A" p_str = f"{p_val:.4f}" if not np.isnan(p_val) else "N/A" diff_str = "-" - - header_arrow = "" + + header_arrow = "" if not np.isnan(c_val) and not np.isnan(p_val): - # 获取深度分析 header_arrow, text_desc, degree = analyze_metric_diff(m, c_val, p_val) - - # 计算差异值 + diff = p_val - c_val - # 差异值本身的符号 (Diff > 0 或 Diff < 0) diff_arrow = r"$\nearrow$" if diff > 0 else r"$\searrow$" - if abs(diff) < 1e-4: diff_arrow = r"$\rightarrow$" - + if abs(diff) < 1e-4: + diff_arrow = r"$\rightarrow$" + diff_str = f"{diff:+.4f} {diff_arrow}" - + analysis_lines.append(f"• {m}: Change {diff:+.4f}. Analysis: {text_desc}") - - # 累计降级分数 + cfg = METRIC_ANALYSIS_META.get(m) is_worse = (cfg['higher_is_better'] and diff < 0) or (not cfg['higher_is_better'] and diff > 0) if is_worse: degradation_score += ANALYSIS_WEIGHTS.get(degree, 0) - - # 表格第一列:名称 + 期望方向箭头 + name_with_arrow = f"{m} ({header_arrow})" if header_arrow else m table_data.append([name_with_arrow, c_str, p_str, diff_str]) - # 绘制表格 table = ax_data.table( cellText=table_data, - colLabels=["Metric (Goal)", "Clean ($Y$)", "Perturbed ($Y'$)", "Diff ($\Delta$)"], + colLabels=["Metric (Goal)", "Clean ($Y$)", "Perturbed ($Y'$)", "Diff ($\\Delta$)"], loc='upper center', cellLoc='center', colWidths=[0.25, 0.25, 0.25, 0.25] ) table.scale(1, 2.0) table.set_fontsize(11) - - # 美化表头 + + # 对表头与第一列做基础样式区分,提升可读性 for (row, col), cell in table.get_celld().items(): if row == 0: cell.set_text_props(weight='bold', color='white') - cell.set_facecolor('#404040') + cell.set_facecolor('#404040') elif col == 0: cell.set_text_props(weight='bold') cell.set_facecolor('#f5f5f5') - # 3. 
底部综合分析文本框 + # 汇总差异分析文本并给出基于权重的总体结论 if not analysis_lines: analysis_lines.append("• All metrics are missing or invalid.") - + full_text = "Quantitative Difference Analysis:\n" + "\n".join(analysis_lines) - - # 总体结论判断 (基于 holistic degradation score) + conclusion = "\n\n>>> EXECUTIVE SUMMARY (Holistic Judgment):\n" - + if degradation_score >= 8: conclusion += "CRITICAL DEGRADATION. Significant quality loss observed. Attack highly effective." elif degradation_score >= 4: @@ -448,9 +394,7 @@ def generate_visual_report( full_text += conclusion - # --------------------------------------------------------------------- - # 4. Metric definitions (ASCII-only / English-only to avoid font issues) - # --------------------------------------------------------------------- + # 在报告底部补充指标含义说明,便于非专业读者理解各项指标的侧重点 metric_definitions = [ "", "", @@ -485,52 +429,45 @@ def generate_visual_report( ax_data.text( 0.05, 0.30, - full_text, - ha='left', - va='top', + full_text, + ha='left', + va='top', fontsize=12, family='monospace', wrap=True, transform=ax_data.transAxes ) fig.suptitle("Museguard: Quality Assurance Report", fontsize=18, fontweight='bold', y=0.95) - + plt.savefig(output_path, bbox_inches='tight', facecolor='white') print(f"\n[Success] 报告已生成: {output_path}") -# ----------------------------------------------------------------------------- -# 主入口 -# ----------------------------------------------------------------------------- - def main(): + # 解析参数,分别评估 Clean 与 Perturbed 两组输出,并生成汇总报告 parser = ArgumentParser() parser.add_argument('--clean_output_dir', type=str, required=True) parser.add_argument('--perturbed_output_dir', type=str, required=True) parser.add_argument('--clean_ref_dir', type=str, required=True) - parser.add_argument('--png_output_path', type=str, required=True) + parser.add_argument('--png_output_path', type=str, required=True) parser.add_argument('--size', type=int, default=512) args = parser.parse_args() - Path(args.png_output_path).parent.mkdir(parents=True, exist_ok=True) print("========================================") print(" Image Quality Evaluation Toolkit") print("========================================") - # 1. 计算 Clean 组 print(f"\n[1/2] Evaluating Clean Set: {os.path.basename(args.clean_output_dir)}") c_metrics = calculate_metrics(args.clean_ref_dir, args.clean_output_dir, args.size) if c_metrics: c_metrics['BRISQUE'] = run_brisque_cleanly(args.clean_output_dir) - # 2. 计算 Perturbed 组 print(f"\n[2/2] Evaluating Perturbed Set: {os.path.basename(args.perturbed_output_dir)}") p_metrics = calculate_metrics(args.clean_ref_dir, args.perturbed_output_dir, args.size) if p_metrics: p_metrics['BRISQUE'] = run_brisque_cleanly(args.perturbed_output_dir) - # 3. 
生成报告 if c_metrics and p_metrics: generate_visual_report( args.clean_ref_dir, @@ -543,5 +480,6 @@ def main(): else: print("\n[Fatal] 评估数据不完整,中止报告生成。") + if __name__ == '__main__': main() \ No newline at end of file diff --git a/src/backend/app/algorithms/finetune/train_db_gen_trace.py b/src/backend/app/algorithms/finetune/train_db_gen_trace.py index 34efebc..41b15dc 100644 --- a/src/backend/app/algorithms/finetune/train_db_gen_trace.py +++ b/src/backend/app/algorithms/finetune/train_db_gen_trace.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding=utf-8 - import argparse import contextlib import copy @@ -44,20 +41,14 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module +# 可选启用 wandb 记录,未安装时不影响训练主流程 if is_wandb_available(): import wandb logger = get_logger(__name__) -# ------------------------------------------------------------------------- -# 功能模块:模型卡保存 -# 1) 该模块用于生成/更新 README.md,记录训练来源与关键配置 -# 2) 支持将训练后验证生成的示例图片写入输出目录并写入引用 -# 3) 便于后续将模型上传到 Hub 时展示效果与实验信息 -# 4) 不参与训练与梯度计算,不影响参数更新与收敛行为 -# 5) 既可服务于 Hub 发布,也可用于本地实验的结果归档 -# ------------------------------------------------------------------------- +# 将训练信息与样例图写入模型卡,便于本地归档与推送到 HuggingFace Hub def save_model_card( repo_id: str, images: list | None = None, @@ -68,6 +59,7 @@ def save_model_card( pipeline: DiffusionPipeline | None = None, ): img_str = "" + # 将推理样例落盘到输出目录,并在 README.md 中插入相对路径引用 if images is not None: for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -99,14 +91,7 @@ def save_model_card( model_card.save(os.path.join(repo_folder, "README.md")) -# ------------------------------------------------------------------------- -# 功能模块:训练后纯文本推理(validation) -# 1) 该模块仅在训练完全结束后执行,不参与训练过程与优化器状态 -# 2) 该模块从 output_dir 重新加载微调后的 pipeline,避免与训练对象耦合 -# 3) 推理只接受文本提示词,不输入任何图像,不走 img2img 相关路径 -# 4) 可设置推理步数与随机种子,方便提高细节并保证可复现 -# 5) 输出 PIL 图片列表,可保存到目录并写入日志系统便于对比分析 -# ------------------------------------------------------------------------- +# 训练结束后的纯文本推理:从输出目录重新加载 pipeline,保证推理与训练对象解耦 def run_validation_txt2img( finetuned_model_dir: str, prompt: str, @@ -123,6 +108,7 @@ def run_validation_txt2img( f"开始 validation 文生图:数量={num_images},步数={num_inference_steps},guidance={guidance_scale},提示词={prompt}" ) + # 只加载 txt2img 所需组件,并禁用 safety_checker 以避免额外开销与拦截 pipe = StableDiffusionPipeline.from_pretrained( finetuned_model_dir, torch_dtype=weight_dtype, @@ -130,6 +116,7 @@ def run_validation_txt2img( local_files_only=True, ) + # 保证是 StableDiffusionPipeline,避免加载到不兼容的管线导致参数不一致 if not isinstance(pipe, StableDiffusionPipeline): raise TypeError(f"加载的 pipeline 类型异常:{type(pipe)},需要 StableDiffusionPipeline 才能保证纯文本生图。") @@ -137,9 +124,11 @@ def run_validation_txt2img( pipe.set_progress_bar_config(disable=True) pipe.safety_checker = lambda images, clip_input: (images, [False for _ in range(len(images))]) + # 使用 slicing 降低推理时显存占用,便于在训练机上额外运行验证 pipe.enable_attention_slicing() pipe.enable_vae_slicing() + # 根据 accelerate 的混精配置选择 autocast,上下文外不改变全局 dtype 行为 if accelerator.device.type == "cuda": if accelerator.mixed_precision == "bf16": infer_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16) @@ -150,6 +139,7 @@ def run_validation_txt2img( else: infer_ctx = contextlib.nullcontext() + # 为每张图单独设置种子偏移,保证同一次验证多图可复现且互不相同 images = [] with infer_ctx: for i in range(num_images): @@ -166,6 +156,7 @@ def run_validation_txt2img( ) images.append(out.images[0]) + # 将验证图片写入 tracker,便于对比不同 step 
的训练效果 for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) @@ -179,6 +170,7 @@ def run_validation_txt2img( } ) + # 显式释放管线与缓存,避免与训练过程竞争显存 del pipe if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -186,14 +178,7 @@ def run_validation_txt2img( return images -# ------------------------------------------------------------------------- -# 功能模块:从模型目录推断 TextEncoder 类型 -# 1) 不同扩散模型对应不同文本编码器架构,需动态识别加载类 -# 2) 通过读取 text_encoder/config 来获取 architectures 字段 -# 3) 该模块返回类对象,用于后续 from_pretrained 加载权重 -# 4) 便于同一训练脚本兼容多模型,而不写死具体实现 -# 5) 若架构不支持会直接报错,避免训练过程走到一半才失败 -# ------------------------------------------------------------------------- +# 动态识别文本编码器架构,保证脚本可用于不同系列的扩散模型 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str | None): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -204,68 +189,63 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st if model_class == "CLIPTextModel": from transformers import CLIPTextModel - return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation - return RobertaSeriesModelWithTransformation elif model_class == "T5EncoderModel": from transformers import T5EncoderModel - return T5EncoderModel else: raise ValueError(f"{model_class} 不受支持。") -# ------------------------------------------------------------------------- -# 功能模块:命令行参数解析 -# 1) 本模块定义 DreamBooth 训练参数与训练后 validation 参数 -# 2) 训练负责微调权重与记录坐标,validation 只负责训练后文生图输出 -# 3) 不提供训练中间验证参数,避免任何中途采样影响训练流程 -# 4) 对关键参数组合做合法性检查,减少运行中途异常 -# 5) 支持通过 shell 脚本传参实现批量实验、对比与复现 -# ------------------------------------------------------------------------- +# 参数解析:包含训练参数、先验保持参数、训练后验证参数,以及坐标记录参数 def parse_args(input_args=None): parser = argparse.ArgumentParser(description="DreamBooth 训练脚本(训练后纯文字生图 validation)") + # 预训练模型与 tokenizer 相关配置 parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True) parser.add_argument("--revision", type=str, default=None, required=False) parser.add_argument("--variant", type=str, default=None) parser.add_argument("--tokenizer_name", type=str, default=None) + # 数据路径与提示词配置 parser.add_argument("--instance_data_dir", type=str, default=None, required=True) parser.add_argument("--class_data_dir", type=str, default=None, required=False) - parser.add_argument("--instance_prompt", type=str, default=None, required=True) parser.add_argument("--class_prompt", type=str, default=None) + # 先验保持相关开关与权重 parser.add_argument("--with_prior_preservation", default=False, action="store_true") parser.add_argument("--prior_loss_weight", type=float, default=1.0) parser.add_argument("--num_class_images", type=int, default=100) + # 输出与可复现配置 parser.add_argument("--output_dir", type=str, default="dreambooth-model") parser.add_argument("--seed", type=int, default=None) + # 图像预处理配置 parser.add_argument("--resolution", type=int, default=512) parser.add_argument("--center_crop", default=False, action="store_true") + # 是否同时训练 text encoder parser.add_argument("--train_text_encoder", action="store_true") + # 训练批次与 epoch/step 配置 parser.add_argument("--train_batch_size", type=int, default=4) parser.add_argument("--sample_batch_size", type=int, default=4) - parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument("--max_train_steps", type=int, default=None) - 
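(Aside, not part of the patch.) Because parse_args accepts an explicit input_args list, the backend can drive this trainer in-process instead of shelling out. A hedged sketch of such a call, assuming the file is importable as train_db_gen_trace and using hypothetical placeholder paths and prompts:

```python
# All paths, prompts and step counts below are illustrative placeholders.
from train_db_gen_trace import parse_args, main

args = parse_args([
    "--pretrained_model_name_or_path", "/models/stable-diffusion-v1-5",
    "--instance_data_dir", "/data/instance_images",
    "--instance_prompt", "a photo of sks person",
    "--output_dir", "/outputs/dreambooth-run",
    "--max_train_steps", "400",
    "--validation_prompt", "a photo of sks person",
    "--validation_image_output_dir", "/outputs/dreambooth-run/validation",
    "--coords_save_path", "/outputs/dreambooth-run/coords.csv",
])
main(args)
```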
parser.add_argument("--checkpointing_steps", type=int, default=500) + # 梯度累积与显存优化相关开关 parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--gradient_checkpointing", action="store_true") + # 学习率与 scheduler 配置 parser.add_argument("--learning_rate", type=float, default=5e-6) parser.add_argument("--scale_lr", action="store_true", default=False) - parser.add_argument( "--lr_scheduler", type=str, @@ -276,37 +256,42 @@ def parse_args(input_args=None): parser.add_argument("--lr_num_cycles", type=int, default=1) parser.add_argument("--lr_power", type=float, default=1.0) + # 优化器相关配置 parser.add_argument("--use_8bit_adam", action="store_true") parser.add_argument("--dataloader_num_workers", type=int, default=0) - parser.add_argument("--adam_beta1", type=float, default=0.9) parser.add_argument("--adam_beta2", type=float, default=0.999) parser.add_argument("--adam_weight_decay", type=float, default=1e-2) parser.add_argument("--adam_epsilon", type=float, default=1e-08) parser.add_argument("--max_grad_norm", default=1.0, type=float) + # Hub 上传相关配置 parser.add_argument("--push_to_hub", action="store_true") parser.add_argument("--hub_token", type=str, default=None) parser.add_argument("--hub_model_id", type=str, default=None) + # 日志与混精配置 parser.add_argument("--logging_dir", type=str, default="logs") parser.add_argument("--allow_tf32", action="store_true") parser.add_argument("--report_to", type=str, default="tensorboard") - parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"]) parser.add_argument("--prior_generation_precision", type=str, default=None, choices=["no", "fp32", "fp16", "bf16"]) parser.add_argument("--local_rank", type=int, default=-1) + # 注意力与梯度相关的显存优化开关 parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true") parser.add_argument("--set_grads_to_none", action="store_true") + # 噪声与损失加权相关参数 parser.add_argument("--offset_noise", action="store_true", default=False) parser.add_argument("--snr_gamma", type=float, default=None) + # tokenizer 与 text encoder 行为相关参数 parser.add_argument("--tokenizer_max_length", type=int, default=None, required=False) parser.add_argument("--text_encoder_use_attention_mask", action="store_true", required=False) parser.add_argument("--skip_save_text_encoder", action="store_true", required=False) + # 训练后验证参数(本脚本不做中途验证,仅训练结束后跑一次) parser.add_argument("--validation_prompt", type=str, required=True) parser.add_argument("--validation_negative_prompt", type=str, default="") parser.add_argument("--num_validation_images", type=int, default=10) @@ -314,6 +299,7 @@ def parse_args(input_args=None): parser.add_argument("--validation_guidance_scale", type=float, default=7.5) parser.add_argument("--validation_image_output_dir", type=str, required=True) + # 训练过程坐标记录(用于可视化与轨迹分析) parser.add_argument("--coords_save_path", type=str, default=None) parser.add_argument("--coords_log_interval", type=int, default=10) @@ -322,10 +308,12 @@ def parse_args(input_args=None): else: args = parser.parse_args() + # 兼容 accelerate 启动时写入的 LOCAL_RANK 环境变量 env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank + # 先验保持开启时必须提供 class 数据与 class prompt if args.with_prior_preservation: if args.class_data_dir is None: raise ValueError("启用先验保持时必须提供 class_data_dir。") @@ -340,14 +328,7 @@ def parse_args(input_args=None): return args -# ------------------------------------------------------------------------- -# 
功能模块:DreamBooth 训练数据集 -# 1) 从 instance 与 class 目录读取图像,并统一做尺寸、裁剪与归一化 -# 2) 同时提供实例提示词与类别提示词的 token id 作为文本输入 -# 3) 先验保持模式下会返回两套图像与文本信息用于拼接训练 -# 4) 数据集长度按 instance 与 class 的最大值取,便于循环采样 -# 5) 数据集只负责准备输入,模型推理、损失计算与优化在主循环中完成 -# ------------------------------------------------------------------------- +# DreamBooth 数据集:负责读取图片、做裁剪归一化,并产出 prompt 的 input_ids 与 attention_mask class DreamBoothDataset(Dataset): def __init__( self, @@ -370,11 +351,13 @@ class DreamBoothDataset(Dataset): if not self.instance_data_root.exists(): raise ValueError(f"实例图像目录不存在:{self.instance_data_root}") + # instance 图片路径列表会循环采样,长度以 instance 数为基础 self.instance_images_path = [p for p in Path(instance_data_root).iterdir() if p.is_file()] self.num_instance_images = len(self.instance_images_path) self.instance_prompt = instance_prompt self._length = self.num_instance_images + # 在先验保持模式下同时读取 class 图片,并将长度设为两者最大值以便循环匹配 if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) @@ -388,6 +371,7 @@ class DreamBoothDataset(Dataset): else: self.class_data_root = None + # 训练用图像预处理:先 resize,再 crop,然后归一化到 [-1, 1] self.image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), @@ -400,6 +384,7 @@ class DreamBoothDataset(Dataset): def __len__(self): return self._length + # 将 prompt 统一分词为固定长度,避免动态长度导致批处理不稳定 def _tokenize(self, prompt: str): max_length = self.tokenizer_max_length if self.tokenizer_max_length is not None else self.tokenizer.model_max_length return self.tokenizer( @@ -413,16 +398,19 @@ class DreamBoothDataset(Dataset): def __getitem__(self, index): example = {} + # 读取 instance 图片并处理 EXIF 方向,保证训练输入方向一致 instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) instance_image = exif_transpose(instance_image) if instance_image.mode != "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) + # instance prompt 的 input_ids 与 attention_mask text_inputs = self._tokenize(self.instance_prompt) example["instance_prompt_ids"] = text_inputs.input_ids example["instance_attention_mask"] = text_inputs.attention_mask + # 先验保持时额外返回 class 图片与 class prompt 的 token if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) @@ -437,14 +425,7 @@ class DreamBoothDataset(Dataset): return example -# ------------------------------------------------------------------------- -# 功能模块:批处理拼接与张量规整 -# 1) 将单条样本组成的列表拼接为 batch 字典,供训练循环直接使用 -# 2) 将图像张量 stack 成 (B,C,H,W) 并转换为 float,提高后续 VAE 兼容性 -# 3) 将 input_ids 与 attention_mask 在 batch 维度 cat,便于文本编码器计算 -# 4) 先验保持模式下将 instance 与 class 在 batch 维度拼接,减少前向次数 -# 5) 该模块不做任何损失与梯度计算,只负责打包输入数据结构 -# ------------------------------------------------------------------------- +# 批处理拼接:将样本列表组装为 batch,并在先验保持时拼接 instance 与 class 以减少前向次数 def collate_fn(examples, with_prior_preservation=False): input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -455,6 +436,7 @@ def collate_fn(examples, with_prior_preservation=False): pixel_values += [example["class_images"] for example in examples] attention_mask += [example["class_attention_mask"] for example in examples] + # 图像张量 stack 为 (B, C, H, W),并确保是连续内存与 float 类型 pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float() input_ids = torch.cat(input_ids, dim=0) 
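(Aside, not part of the patch.) The prior-preservation path relies on a simple batching contract: collate_fn concatenates instance and class examples along the batch dimension, and the training loop later undoes that with torch.chunk so the class half can be weighted separately as the prior loss. A minimal sketch with random stand-in tensors (shapes are illustrative only):

```python
import torch
import torch.nn.functional as F

# Stand-ins for UNet predictions/targets on instance vs. class latents.
instance_pred = torch.randn(2, 4, 64, 64)
class_pred = torch.randn(2, 4, 64, 64)
instance_target = torch.randn_like(instance_pred)
class_target = torch.randn_like(class_pred)

# What the collate_fn-style packing yields: [instance | class] along dim 0.
model_pred = torch.cat([instance_pred, class_pred], dim=0)
target = torch.cat([instance_target, class_target], dim=0)

# The training loop splits the halves back apart and weights the class half.
pred_inst, pred_prior = torch.chunk(model_pred, 2, dim=0)
tgt_inst, tgt_prior = torch.chunk(target, 2, dim=0)

prior_loss_weight = 1.0
loss = F.mse_loss(pred_inst.float(), tgt_inst.float(), reduction="mean")
loss = loss + prior_loss_weight * F.mse_loss(pred_prior.float(), tgt_prior.float(), reduction="mean")
```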
attention_mask = torch.cat(attention_mask, dim=0) @@ -462,14 +444,7 @@ def collate_fn(examples, with_prior_preservation=False): return {"input_ids": input_ids, "pixel_values": pixel_values, "attention_mask": attention_mask} -# ------------------------------------------------------------------------- -# 功能模块:生成 class 图像的提示词数据集 -# 1) 该数据集用于先验保持时批量生成类别图像,提供固定 prompt -# 2) 每条样本返回 prompt 与索引,索引用于生成稳定的文件名 -# 3) 与训练数据集分离,避免采样逻辑影响训练数据读取与增强 -# 4) 支持多进程环境下由 accelerate 分配采样 batch,提高生成效率 -# 5) 该模块只在 with_prior_preservation 启用且 class 数据不足时使用 -# ------------------------------------------------------------------------- +# class 图像生成专用数据集:仅提供 prompt 与 index,用于加速生成与落盘命名 class PromptDataset(Dataset): def __init__(self, prompt, num_samples): self.prompt = prompt @@ -482,14 +457,7 @@ class PromptDataset(Dataset): return {"prompt": self.prompt, "index": index} -# ------------------------------------------------------------------------- -# 功能模块:判断预训练模型是否包含 VAE -# 1) 通过检查 vae/config.json 是否存在来决定是否加载 VAE -# 2) 同时支持本地目录与 Hub 结构,便于离线缓存模式运行 -# 3) 若不存在 VAE 子目录,将跳过加载并在训练中使用像素空间输入 -# 4) 该判断只发生在初始化阶段,不影响训练过程与日志记录 -# 5) 对 Stable Diffusion 类模型通常都会包含 VAE,属于常规路径 -# ------------------------------------------------------------------------- +# 判断模型是否包含 VAE:用于兼容可能没有 vae 子目录的模型结构 def model_has_vae(args): config_file_name = Path("vae", AutoencoderKL.config_name).as_posix() if os.path.isdir(args.pretrained_model_name_or_path): @@ -500,14 +468,7 @@ def model_has_vae(args): return any(file.rfilename == config_file_name for file in files_in_repo) -# ------------------------------------------------------------------------- -# 功能模块:文本编码器前向 -# 1) 将 input_ids 与 attention_mask 输入 text encoder 得到条件嵌入 -# 2) 可选择是否启用 attention_mask,以适配不同文本编码器行为 -# 3) 输出的 prompt_embeds 作为 UNet 条件输入,影响生成语义与身份绑定 -# 4) 该函数在训练循环中频繁调用,需要保持设备与 dtype 的一致性 -# 5) 返回张量为 (B, T, D),后续会与 timestep 一起输入 UNet -# ------------------------------------------------------------------------- +# 文本编码:得到 prompt_embeds,作为 UNet 的条件输入,控制语义与身份绑定 def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask: bool): text_input_ids = input_ids.to(text_encoder.device) if text_encoder_use_attention_mask: @@ -517,18 +478,13 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte return text_encoder(text_input_ids, attention_mask=attention_mask, return_dict=False)[0] -# ------------------------------------------------------------------------- -# 功能模块:主训练流程 -# 1) 负责构建 accelerate 环境、加载模型组件、准备数据与优化器 -# 2) 支持先验保持:自动补足 class 图像并将 instance/class 合并训练 -# 3) 训练循环中记录 loss、学习率与坐标指标,输出 CSV 便于可视化分析 -# 4) 训练结束后保存微调后的 pipeline 到 output_dir,作为独立可用模型 -# 5) 在保存完成后运行 validation,仅用提示词进行文生图并将结果写入输出目录 -# ------------------------------------------------------------------------- +# 训练主流程:包含 class 数据补全、组件加载、训练循环、坐标记录、模型保存与训练后验证 def main(args): + # 避免将 hub token 暴露到 wandb 等第三方日志系统中 if args.report_to == "wandb" and args.hub_token is not None: raise ValueError("不要同时使用 wandb 与 hub_token,避免凭证泄露风险。") + # accelerate 项目配置:统一 output_dir 与 logging_dir,便于多卡与断点保存 logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) @@ -539,9 +495,11 @@ def main(args): project_config=accelerator_project_config, ) + # MPS 下关闭 AMP,避免混精行为不一致导致训练异常 if torch.backends.mps.is_available(): accelerator.native_amp = False + # 初始化日志格式并打印 accelerate 状态,便于排查分布式配置问题 logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", @@ -549,21 +507,25 @@ def 
main(args): ) logger.info(accelerator.state, main_process_only=False) + # 主进程输出更多 warning,非主进程尽量保持安静以减少干扰 if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() warnings.filterwarnings("ignore", category=UserWarning) else: transformers.utils.logging.set_verbosity_error() + # 设置随机种子,保证数据增强、噪声采样与验证结果可复现 if args.seed is not None: set_seed(args.seed) + # 先验保持:当 class 图片不足时,使用 base model 生成补齐并保存到 class_data_dir if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) class_images_dir.mkdir(parents=True, exist_ok=True) cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: + # 生成 class 图片时可单独指定 dtype,减少生成时的显存占用 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 @@ -592,6 +554,7 @@ def main(args): for example in tqdm(sample_dataloader, desc="生成 class 图像", disable=not accelerator.is_local_main_process): images = pipe(example["prompt"]).images for i, image in enumerate(images): + # 用图像内容 hash 防止同名冲突,并方便追溯生成来源 hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) @@ -600,6 +563,7 @@ def main(args): if torch.cuda.is_available(): torch.cuda.empty_cache() + # 准备输出目录与 Hub 仓库,仅主进程执行以避免竞争写入 if accelerator.is_main_process: os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: @@ -611,6 +575,7 @@ def main(args): else: repo_id = None + # tokenizer 加载:优先使用显式指定的 tokenizer_name,否则从模型目录的 tokenizer 子目录读取 if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) else: @@ -621,9 +586,10 @@ def main(args): use_fast=False, ) + # 组件加载:scheduler、text_encoder、可选 VAE、以及 UNet text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) @@ -638,11 +604,13 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) + # unwrap:用于从 accelerator 包装对象中拿到可保存的原始模型 def unwrap_model(model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model return model + # 自定义 hook:让 accelerator.save_state 按 diffusers 的子目录结构保存 unet/text_encoder def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for model in models: @@ -650,6 +618,7 @@ def main(args): model.save_pretrained(os.path.join(output_dir, sub_dir)) weights.pop() + # 自定义 hook:断点恢复时从 output_dir 读取 unet/text_encoder 并覆盖当前实例参数 def load_model_hook(models, input_dir): while len(models) > 0: model = models.pop() @@ -665,12 +634,15 @@ def main(args): accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) + # VAE 不参与训练,仅用于编码到 latent 空间 if vae is not None: vae.requires_grad_(False) + # 默认只训练 UNet,若开启 train_text_encoder 则同时训练文本编码器 if not args.train_text_encoder: text_encoder.requires_grad_(False) + # xformers:开启后可显著降低注意力显存占用,但需要正确安装依赖 if args.enable_xformers_memory_efficient_attention: if not is_xformers_available(): raise ValueError("xformers 不可用,请确认安装成功。") @@ -680,17 
+652,21 @@ def main(args): logger.warning("xformers 0.0.16 在部分 GPU 上训练不稳定,建议升级。") unet.enable_xformers_memory_efficient_attention() + # gradient checkpointing:以计算换显存,适合大模型与大分辨率训练 if args.gradient_checkpointing: unet.enable_gradient_checkpointing() if args.train_text_encoder: text_encoder.gradient_checkpointing_enable() + # TF32:在 Ampere 上可加速 matmul,通常对训练稳定性影响较小 if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True + # scale_lr:按总 batch 规模放大学习率,便于多卡/大 batch 配置保持等效训练 if args.scale_lr: args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + # 优化器:可选 8-bit Adam 降低显存占用 optimizer_class = torch.optim.AdamW if args.use_8bit_adam: try: @@ -710,6 +686,7 @@ def main(args): eps=args.adam_epsilon, ) + # 数据集与 dataloader:根据 with_prior_preservation 决定是否加载 class 数据 train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, @@ -730,12 +707,14 @@ def main(args): num_workers=args.dataloader_num_workers, ) + # 训练步数:若未指定 max_train_steps,则由 epoch 与 dataloader 推导 overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True + # scheduler:基于总训练步数与 warmup 设置学习率曲线 lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, @@ -745,6 +724,7 @@ def main(args): power=args.lr_power, ) + # accelerate.prepare:把模型、优化器、数据加载器与 scheduler 放入分布式与混精管理 if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler @@ -754,27 +734,33 @@ def main(args): unet, optimizer, train_dataloader, lr_scheduler ) + # weight_dtype:训练时模型权重与输入的 dtype,用于混精与显存控制 weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + # VAE 始终与训练设备一致,并与 weight_dtype 对齐 if vae is not None: vae.to(accelerator.device, dtype=weight_dtype) + # 若不训练 text encoder,则把它当作推理组件统一 cast 到混精 dtype if not args.train_text_encoder: text_encoder.to(accelerator.device, dtype=weight_dtype) + # 重新计算 epoch:在 prepare 之后 dataloader 规模可能变化,因此再推导一次更稳妥 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + # 初始化 tracker:主进程写入配置,便于实验复现与对比 if accelerator.is_main_process: tracker_config = vars(copy.deepcopy(args)) accelerator.init_trackers("dreambooth", config=tracker_config) + # coords_list:用于记录训练过程的三维指标轨迹并写入 CSV coords_list = [] total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -795,6 +781,7 @@ def main(args): disable=not accelerator.is_local_main_process, ) + # 训练循环:每步完成 latent 构造、噪声添加、UNet 预测、loss 计算与反传更新 for epoch in range(0, args.num_train_epochs): unet.train() if args.train_text_encoder: @@ -802,14 +789,17 @@ def main(args): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 输入图像对齐 dtype,避免混精下出现不必要的类型转换 pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + # 使用 VAE 将图像编码到 latent 空间,若无 VAE 则直接在像素空间训练 if vae is not None: model_input = 
vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() model_input = model_input * vae.config.scaling_factor else: model_input = pixel_values + # 采样噪声,并可选叠加 offset noise 以改变噪声分布形态 if args.offset_noise: noise = torch.randn_like(model_input) + 0.1 * torch.randn( model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device @@ -817,13 +807,16 @@ def main(args): else: noise = torch.randn_like(model_input) + # 为每个样本随机选择一个扩散时间步 bsz = model_input.shape[0] timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device ).long() + # 前向扩散:给输入加噪声,形成 UNet 的输入 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + # 文本编码:得到条件嵌入,用于指导 UNet 的去噪方向 encoder_hidden_states = encode_prompt( text_encoder, batch["input_ids"], @@ -831,11 +824,14 @@ def main(args): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) + # UNet 输出噪声预测(或速度预测),返回的第一个元素为预测张量 model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states, return_dict=False)[0] + # 某些模型会同时预测方差,将通道拆分后仅保留噪声相关部分参与训练 if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1) + # 根据 scheduler 的 prediction_type 构造监督目标 if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -843,11 +839,13 @@ def main(args): else: raise ValueError(f"未知 prediction_type:{noise_scheduler.config.prediction_type}") + # 先验保持:batch 被拼接为 instance+class,因此这里按 batch 维拆开分别计算 prior loss if args.with_prior_preservation: model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + # loss:默认 MSE;若提供 snr_gamma 则用 SNR 加权以平衡不同时间步的贡献 if args.snr_gamma is None: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") else: @@ -863,6 +861,7 @@ def main(args): if args.with_prior_preservation: loss = loss + args.prior_loss_weight * prior_loss + # 训练轨迹记录:用模型输出统计量作为特征指标,配合 loss 形成三维轨迹 if args.coords_save_path is not None: X_i_feature_norm = torch.norm(model_pred.detach().float(), p=2, dim=[1, 2, 3]).mean().item() Y_i_feature_var = torch.var(model_pred.detach().float()).item() @@ -882,8 +881,10 @@ def main(args): df.to_csv(save_file_path, index=False) logger.info(f"坐标已写入:{save_file_path}") + # 反向传播:accelerate 负责混精与分布式同步 accelerator.backward(loss) + # 梯度裁剪:仅在同步梯度时执行,避免对未同步的局部梯度产生偏差 if accelerator.sync_gradients: params_to_clip = ( itertools.chain(unet.parameters(), text_encoder.parameters()) @@ -892,18 +893,22 @@ def main(args): ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + # 参数更新:optimizer 与 scheduler 逐步推进 optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=args.set_grads_to_none) + # 只有在完成一次“真实更新”后才推进 global_step 与进度条 if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 + # 按固定步数保存训练状态,用于断点恢复或中途回滚 if accelerator.is_main_process and global_step % args.checkpointing_steps == 0: accelerator.save_state(args.output_dir) logger.info(f"已保存训练状态到:{args.output_dir}") + # 每步记录 loss 与 lr,便于在 dashboard 中观察收敛曲线 logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) @@ -915,6 +920,7 @@ def main(args): images = [] if accelerator.is_main_process: + # 将训练好的组件写成独立 pipeline,确保 output_dir 可直接用于推理 pipeline_args = {} if not args.skip_save_text_encoder: pipeline_args["text_encoder"] = 
unwrap_model(text_encoder) @@ -930,6 +936,7 @@ def main(args): ) pipeline.save_pretrained(args.output_dir) + # 释放训练对象,减少后续 validation 的显存压力 del unet del optimizer del lr_scheduler @@ -940,6 +947,7 @@ def main(args): gc.collect() torch.cuda.empty_cache() + # 训练结束后运行一次 txt2img 验证,并将结果保存到指定目录 images = run_validation_txt2img( finetuned_model_dir=args.output_dir, prompt=args.validation_prompt, @@ -959,6 +967,7 @@ def main(args): image.save(out_dir / f"validation_image_{i}.png") logger.info(f"validation 图像已保存到:{out_dir}") + # 推送到 Hub:写模型卡并上传 output_dir(忽略 step/epoch 目录) if args.push_to_hub: save_model_card( repo_id, @@ -976,6 +985,7 @@ def main(args): ignore_patterns=["step_*", "epoch_*"], ) + # 训练结束后再落一次坐标,保证最后一段数据不会因日志频率而遗漏 if args.coords_save_path is not None and coords_list: df = pd.DataFrame(coords_list, columns=["step", "X_Feature_L2_Norm", "Y_Feature_Variance", "Z_LDM_Loss"]) save_file_path = Path(args.coords_save_path) diff --git a/src/backend/app/algorithms/finetune/train_lora_gen_trace.py b/src/backend/app/algorithms/finetune/train_lora_gen_trace.py index e474846..a2cfce6 100644 --- a/src/backend/app/algorithms/finetune/train_lora_gen_trace.py +++ b/src/backend/app/algorithms/finetune/train_lora_gen_trace.py @@ -1,18 +1,3 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - import argparse import copy import gc @@ -65,16 +50,14 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module - +# 可选启用 wandb 记录,未安装时不影响训练主流程 if is_wandb_available(): import wandb -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. -# check_min_version("0.30.0.dev0") - logger = get_logger(__name__) +# 保存 LoRA 权重的模型卡与样例图,便于在 Hub 页面展示训练效果 def save_model_card( repo_id: str, images=None, @@ -85,6 +68,7 @@ def save_model_card( pipeline: DiffusionPipeline = None, ): img_str = "" + # 将采样图写入输出目录,并在 README 中引用,便于浏览器直接查看 for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) img_str += f"![img_{i}](./image_{i}.png)\n" @@ -116,6 +100,7 @@ LoRA for the text encoder was enabled: {train_text_encoder}. model_card.save(os.path.join(repo_folder, "README.md")) +# 训练中/训练后验证:重建 scheduler,并按指定 prompt 生成图片,用于观察训练进度 def log_validation( pipeline, args, @@ -128,26 +113,26 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} + # 若 scheduler 预测方差为 learned 类型,验证时改为 fixed_small,避免推理阶段要求额外输出 + scheduler_args = {} if "variance_type" in pipeline.scheduler.config: variance_type = pipeline.scheduler.config.variance_type - if variance_type in ["learned", "learned_range"]: variance_type = "fixed_small" - scheduler_args["variance_type"] = variance_type + # 使用 DPM-Solver 加速推理,同时保持与 pipeline 现有配置兼容 pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) - pipeline.safety_checker = lambda images, clip_input: (images, [False for i in range(0, len(images))]) # disable safety checker + pipeline.safety_checker = lambda images, clip_input: (images, [False for i in range(0, len(images))]) - # run inference + # generator 用于控制采样随机性,设置 seed 可保证同一轮验证复现 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + # 若提供 validation_images 则走 img2img 类路径,否则走纯 txt2img if args.validation_images is None: images = [] for _ in range(args.num_validation_images): @@ -162,6 +147,7 @@ def log_validation( image = pipeline(**pipeline_args, image=image, generator=generator).images[0] images.append(image) + # 将验证结果写入 tracker,便于对比不同 epoch 的生成质量变化 for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" if tracker.name == "tensorboard": @@ -176,12 +162,14 @@ def log_validation( } ) + # 推理结束后释放 pipeline,避免与训练过程竞争显存 del pipeline torch.cuda.empty_cache() return images +# 动态识别文本编码器类型,避免脚本仅能适配固定架构 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -192,22 +180,22 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st if model_class == "CLIPTextModel": from transformers import CLIPTextModel - return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation - return RobertaSeriesModelWithTransformation elif model_class == "T5EncoderModel": from transformers import T5EncoderModel - return T5EncoderModel else: raise ValueError(f"{model_class} is not supported.") +# 参数解析:包含训练、先验保持、验证、断点恢复以及可视化轨迹记录的全部参数 def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a training script.") + + # 预训练模型与 tokenizer 配置 parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -234,6 +222,8 @@ def parse_args(input_args=None): default=None, help="Pretrained tokenizer name or path if not the same as model_name", ) + + # 数据路径与 prompt 配置 parser.add_argument( "--instance_data_dir", type=str, @@ -261,6 +251,8 @@ def parse_args(input_args=None): default=None, help="The prompt to specify images in the same class as provided instance images.", ) + + # 验证相关配置(训练中按 epoch 触发) parser.add_argument( "--validation_prompt", type=str, @@ -282,6 +274,8 @@ def parse_args(input_args=None): " `args.validation_prompt` multiple times: `args.num_validation_images`." ), ) + + # 先验保持配置:用于抑制过拟合并保持类分布 parser.add_argument( "--with_prior_preservation", default=False, @@ -298,13 +292,14 @@ def parse_args(input_args=None): " class_data_dir, additional images will be sampled with class_prompt." 
), ) + + # 输出目录与验证图片保存目录 parser.add_argument( "--output_dir", type=str, default="lora-dreambooth-model", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument( "--validation_image_output_dir", type=str, @@ -312,6 +307,7 @@ def parse_args(input_args=None): help="The directory where validation images will be saved. If None, images will be saved inside a subdirectory of `output_dir`.", ) + # 随机种子与输入分辨率预处理 parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", @@ -331,17 +327,17 @@ def parse_args(input_args=None): " cropped. The images will be resized to the resolution first before cropping." ), ) + + # 是否训练 text encoder 的 LoRA parser.add_argument( "--train_text_encoder", action="store_true", help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) - parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." - ) - parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." - ) + + # batch 与训练时长配置 + parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( "--max_train_steps", @@ -349,6 +345,8 @@ def parse_args(input_args=None): default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) + + # 断点保存与恢复配置 parser.add_argument( "--checkpointing_steps", type=int, @@ -359,12 +357,7 @@ def parse_args(input_args=None): " training using `--resume_from_checkpoint`." ), ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) + parser.add_argument("--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store.")) parser.add_argument( "--resume_from_checkpoint", type=str, @@ -374,208 +367,92 @@ def parse_args(input_args=None): ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
), ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) + + # 分布式与显存优化配置 + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument("--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.") + + # 学习率与 scheduler 配置 + parser.add_argument("--learning_rate", type=float, default=5e-4, help="Initial learning rate (after the potential warmup period) to use.") + parser.add_argument("--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.") parser.add_argument( "--lr_scheduler", type=str, default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument( - "--lr_num_cycles", - type=int, - default=1, - help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]'), ) + parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.") parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." - ) + + # dataloader 与优化器配置 + parser.add_argument("--dataloader_num_workers", type=int, default=0, help=("Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.")) + parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + + # Hub 上传配置 parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--prior_generation_precision", - type=str, - default=None, - choices=["no", "fp32", "fp16", "bf16"], - help=( - "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." - ), - ) + parser.add_argument("--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.") + + # 日志与混精配置 + parser.add_argument("--logging_dir", type=str, default="logs", help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.")) + parser.add_argument("--allow_tf32", action="store_true", help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices")) + parser.add_argument("--report_to", type=str, default="tensorboard", help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"` (default), `"wandb"` and `"comet_ml"`. 
Use `"all"` to report to all integrations.')) + parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], help=("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16).")) + parser.add_argument("--prior_generation_precision", type=str, default=None, choices=["no", "fp32", "fp16", "bf16"], help=("Choose prior generation precision between fp32, fp16 and bf16 (bfloat16).")) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." - ) - parser.add_argument( - "--pre_compute_text_embeddings", - action="store_true", - help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", - ) - parser.add_argument( - "--tokenizer_max_length", - type=int, - default=None, - required=False, - help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", - ) - parser.add_argument( - "--text_encoder_use_attention_mask", - action="store_true", - required=False, - help="Whether to use attention mask for the text encoder", - ) - parser.add_argument( - "--validation_images", - required=False, - default=None, - nargs="+", - help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", - ) - parser.add_argument( - "--class_labels_conditioning", - required=False, - default=None, - help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", - ) - parser.add_argument( - "--rank", - type=int, - default=4, - help=("The dimension of the LoRA update matrices."), - ) - # [START] 为可视化方案增加的参数定义 - parser.add_argument( - "--positions_save_path", - type=str, - default=None, - help="保存3D可视化坐标数据的路径 (X: LoRA权重L2范数, Y: 总梯度L2范数, Z: LDM损失)。", - ) - parser.add_argument( - "--coords_log_interval", - type=int, - default=25, - help="保存坐标数据的步数间隔。", - ) - # [END] 为可视化方案增加的参数定义 + parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") + + # 预计算文本嵌入:节省训练时 text encoder 显存,但不能与 train_text_encoder 同时开启 + parser.add_argument("--pre_compute_text_embeddings", action="store_true", help="Whether or not to pre-compute text embeddings. This is not compatible with `--train_text_encoder`.") + parser.add_argument("--tokenizer_max_length", type=int, default=None, required=False, help="The maximum length of the tokenizer. 
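Since `--rank` directly controls how many parameters LoRA adds, a quick back-of-the-envelope sketch of how rank translates into trainable parameter count for a single adapted linear layer may help; the 320-dimensional projection below is only an example, not a measurement of this repo's UNet:

```python
# LoRA adds two low-rank matrices, A (r x d_in) and B (d_out x r), per adapted linear layer.
def lora_params_per_linear(d_in: int, d_out: int, rank: int) -> int:
    return rank * (d_in + d_out)

# Example: a 320 -> 320 attention projection at rank 4 versus the frozen full layer.
rank = 4
full = 320 * 320                                  # 102,400 frozen weights
lora = lora_params_per_linear(320, 320, rank)     # 4 * (320 + 320) = 2,560 trainable weights
print(full, lora, f"{lora / full:.1%}")           # LoRA trains ~2.5% of that layer at rank 4
```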
If not set, will default to the tokenizer's max length.") + parser.add_argument("--text_encoder_use_attention_mask", action="store_true", required=False, help="Whether to use attention mask for the text encoder") + parser.add_argument("--validation_images", required=False, default=None, nargs="+", help="Optional set of images to use for validation.") + parser.add_argument("--class_labels_conditioning", required=False, default=None, help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.") + + # LoRA rank:决定 LoRA 低秩矩阵的维度与可训练参数量 + parser.add_argument("--rank", type=int, default=4, help=("The dimension of the LoRA update matrices.")) + + # 训练轨迹可视化:记录 (X: LoRA 权重范数, Y: 梯度范数, Z: loss) 并实时保存为 CSV + parser.add_argument("--positions_save_path", type=str, default=None, help="保存3D可视化坐标数据的路径 (X: LoRA权重L2范数, Y: 总梯度L2范数, Z: LDM损失)。") + parser.add_argument("--coords_log_interval", type=int, default=25, help="保存坐标数据的步数间隔。") if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() + # 兼容 accelerate 启动时写入的 LOCAL_RANK 环境变量 env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank + # 先验保持开启时必须提供 class 数据与 prompt if args.with_prior_preservation: if args.class_data_dir is None: raise ValueError("You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: - # logger is not available yet if args.class_data_dir is not None: warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + # 预计算文本嵌入时 text encoder 不参与训练,因此与 train_text_encoder 冲突 if args.train_text_encoder and args.pre_compute_text_embeddings: raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") return args +# 数据集:准备 instance/class 图片与 prompt token,并支持直接使用预计算的 encoder_hidden_states class DreamBoothDataset(Dataset): - """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. 
- """ - def __init__( self, instance_data_root, @@ -601,11 +478,13 @@ class DreamBoothDataset(Dataset): if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") + # 路径列表用于循环采样,长度与 instance 数量对齐 self.instance_images_path = list(Path(instance_data_root).iterdir()) self.num_instance_images = len(self.instance_images_path) self.instance_prompt = instance_prompt self._length = self.num_instance_images + # 先验保持时加载 class 图片,并把长度设为最大值以便 instance/class 都能覆盖采样 if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) @@ -619,6 +498,7 @@ class DreamBoothDataset(Dataset): else: self.class_data_root = None + # 图像预处理:resize + crop + 归一化到 [-1, 1] self.image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), @@ -633,13 +513,15 @@ class DreamBoothDataset(Dataset): def __getitem__(self, index): example = {} + + # 读取 instance 图片并处理方向信息 instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) instance_image = exif_transpose(instance_image) - if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) + # 若已预计算 text embeddings,则直接使用 embedding 作为条件输入 if self.encoder_hidden_states is not None: example["instance_prompt_ids"] = self.encoder_hidden_states else: @@ -649,10 +531,10 @@ class DreamBoothDataset(Dataset): example["instance_prompt_ids"] = text_inputs.input_ids example["instance_attention_mask"] = text_inputs.attention_mask + # 先验保持时额外返回 class 图像与 class prompt token/embedding if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) - if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) @@ -669,6 +551,7 @@ class DreamBoothDataset(Dataset): return example +# batch 拼接:将样本列表合并为批量输入,并在先验保持时拼接 instance/class 以减少前向次数 def collate_fn(examples, with_prior_preservation=False): has_attention_mask = "instance_attention_mask" in examples[0] @@ -678,8 +561,6 @@ def collate_fn(examples, with_prior_preservation=False): if has_attention_mask: attention_mask = [example["instance_attention_mask"] for example in examples] - # Concat class and instance examples for prior preservation. - # We do this to avoid doing two forward passes. 
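To make the instance/class concatenation in `collate_fn` concrete: with prior preservation enabled, each batch carries both halves so that a single UNet forward pass covers them, and the loss code later splits the prediction back with `torch.chunk`. A minimal, self-contained sketch of that pattern (tensor shapes are placeholders):

```python
import torch

# Two toy examples, each carrying an instance and a class sample (prior preservation on).
examples = [
    {"instance_prompt_ids": torch.ones(1, 77, dtype=torch.long),
     "instance_images": torch.randn(3, 512, 512),
     "class_prompt_ids": torch.zeros(1, 77, dtype=torch.long),
     "class_images": torch.randn(3, 512, 512)}
    for _ in range(2)
]

input_ids = [e["instance_prompt_ids"] for e in examples] + [e["class_prompt_ids"] for e in examples]
pixel_values = [e["instance_images"] for e in examples] + [e["class_images"] for e in examples]

batch = {
    # (2B, 77): instance rows first, class rows second; one forward pass handles both halves.
    "input_ids": torch.cat(input_ids, dim=0),
    # (2B, 3, 512, 512)
    "pixel_values": torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float(),
}

# After the UNet forward, the prediction is split back into the two halves for the two losses.
model_pred = torch.randn(4, 4, 64, 64)                  # stand-in for the UNet output
instance_pred, class_pred = torch.chunk(model_pred, 2, dim=0)
```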
if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] @@ -688,23 +569,19 @@ def collate_fn(examples, with_prior_preservation=False): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = torch.cat(input_ids, dim=0) - batch = { - "input_ids": input_ids, - "pixel_values": pixel_values, - } + batch = {"input_ids": input_ids, "pixel_values": pixel_values} + # 注意:原始逻辑保持不动,这里 attention_mask 仍按原实现方式返回 if has_attention_mask: batch["attention_mask"] = attention_mask return batch +# class 采样数据集:用于在先验保持下生成 class 图像时分配 prompt 与 index class PromptDataset(Dataset): - """A simple dataset to prepare the prompts to generate class images on multiple GPUs.""" - def __init__(self, prompt, num_samples): self.prompt = prompt self.num_samples = num_samples @@ -719,6 +596,7 @@ class PromptDataset(Dataset): return example +# 分词:将 prompt 转为固定长度 input_ids 与 attention_mask def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): if tokenizer_max_length is not None: max_length = tokenizer_max_length @@ -736,6 +614,7 @@ def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): return text_inputs +# 文本编码:输出 prompt_embeds,用作 UNet 条件输入 def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): text_input_ids = input_ids.to(text_encoder.device) @@ -754,15 +633,17 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte return prompt_embeds +# 主训练流程:加载 base 模型,注入 LoRA,训练时仅更新 LoRA 参数,并可定期保存与验证 def main(args): + # 避免 token 通过 wandb 等外部日志系统泄露 if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." " Please use `huggingface-cli login` to authenticate with the Hub." ) + # accelerate 配置:统一 output_dir 与 logging_dir,便于分布式与断点恢复 logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( @@ -772,24 +653,23 @@ def main(args): project_config=accelerator_project_config, ) - # Disable AMP for MPS. + # MPS 下关闭 AMP,避免混精行为导致的兼容性问题 if torch.backends.mps.is_available(): accelerator.native_amp = False + # wandb 作为可选依赖,仅在用户显式启用时要求安装 if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate - # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. + # 分布式 + 训练 text encoder 时暂不支持梯度累积,避免 accumulate 逻辑与同步不一致 if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." ) - # Make one log on every process with the configuration for debugging. 
+ # 初始化日志系统,并根据主进程/非主进程设置不同 verbosity logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", @@ -803,11 +683,11 @@ def main(args): transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() - # If passed along, set the training seed now. + # 设置随机种子,保证噪声采样与验证输出可复现 if args.seed is not None: set_seed(args.seed) - # Generate class images if prior preservation is enabled. + # 先验保持:若 class 图片不足则用 base pipeline 批量采样补齐,并保存到 class_data_dir if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) if not class_images_dir.exists(): @@ -822,6 +702,7 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -836,15 +717,11 @@ def main(args): sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process - ): + for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process): images = pipeline(example["prompt"]).images - for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" @@ -854,7 +731,7 @@ def main(args): if torch.cuda.is_available(): torch.cuda.empty_cache() - # Handle the repository creation + # 创建输出目录与 Hub 仓库,只在主进程执行避免冲突 if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) @@ -864,7 +741,7 @@ def main(args): repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id - # Load the tokenizer + # tokenizer 加载:从 tokenizer_name 或模型目录的 tokenizer 子目录读取 if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: @@ -875,10 +752,8 @@ def main(args): use_fast=False, ) - # import correct text encoder class + # 组件加载:scheduler、text encoder、可选 VAE、UNet text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - - # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant @@ -888,38 +763,34 @@ def main(args): args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) except OSError: - # IF does not have a VAE so let's just set it to None - # We don't have to error out here vae = None unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) - # We only train the additional adapter LoRA layers + # LoRA 训练:冻结 base 权重,只让 LoRA adapter 参数参与训练 if vae is not None: vae.requires_grad_(False) text_encoder.requires_grad_(False) unet.requires_grad_(False) - # For mixed precision training we cast all non-trainable weights 
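The `target_modules` lists of the two `LoraConfig` calls are elided in this hunk. In the upstream diffusers DreamBooth-LoRA example they name the attention projections of the UNet and, when `--train_text_encoder` is set, of the CLIP text model; the sketch below assumes those module names, which may differ for other backbones:

```python
from peft import LoraConfig

rank = 4

# UNet: adapt the attention projections (names as used by diffusers' Attention blocks).
unet_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# Text encoder (CLIP): adapt its self-attention projections as well.
text_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    init_lora_weights="gaussian",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

# Usage, mirroring the script: unet.add_adapter(unet_lora_config); text_encoder.add_adapter(text_lora_config)
```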
(vae, non-lora text_encoder and non-lora unet) to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. + # weight_dtype:非可训练参数通常 cast 到混精 dtype,节省显存 weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move unet, vae and text_encoder to device and cast to weight_dtype unet.to(accelerator.device, dtype=weight_dtype) if vae is not None: vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) + # xformers:可选开启以降低注意力显存占用 if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers - xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warning( @@ -929,12 +800,13 @@ def main(args): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + # gradient checkpointing:减少激活保存,适合大模型训练 if args.gradient_checkpointing: unet.enable_gradient_checkpointing() if args.train_text_encoder: text_encoder.gradient_checkpointing_enable() - # now we will add new LoRA weights to the attention layers + # 向 UNet 注意力层添加 LoRA adapter,target_modules 指定注入的线性层名称 unet_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, @@ -943,7 +815,7 @@ def main(args): ) unet.add_adapter(unet_lora_config) - # The text encoder comes from 🤗 transformers, we will also attach adapters to it. + # 如需训练 text encoder 的 LoRA,则对其注意力投影层也注入 adapter if args.train_text_encoder: text_lora_config = LoraConfig( r=args.rank, @@ -953,16 +825,15 @@ def main(args): ) text_encoder.add_adapter(text_lora_config) + # unwrap:用于保存与推理时拿到未包装的真实模型对象 def unwrap_model(model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model return model - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + # 自定义 hook:让 accelerator.save_state 用 LoRA 规范格式保存权重,便于后续加载 def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: - # there are only two options here. Either are just the unet attn processor layers - # or there are the unet and text encoder atten layers unet_lora_layers_to_save = None text_encoder_lora_layers_to_save = None @@ -975,8 +846,6 @@ def main(args): ) else: raise ValueError(f"unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again weights.pop() LoraLoaderMixin.save_lora_weights( @@ -985,13 +854,13 @@ def main(args): text_encoder_lora_layers=text_encoder_lora_layers_to_save, ) + # 自定义 hook:断点恢复时读取 lora_state_dict 并写入 UNet/TextEncoder 的 adapter 参数 def load_model_hook(models, input_dir): unet_ = None text_encoder_ = None while len(models) > 0: model = models.pop() - if isinstance(model, type(unwrap_model(unet))): unet_ = model elif isinstance(model, type(unwrap_model(text_encoder))): @@ -1006,7 +875,6 @@ def main(args): incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: - # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) if unexpected_keys: logger.warning( @@ -1017,53 +885,44 @@ def main(args): if args.train_text_encoder: _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_) - # Make sure the trainable params are in float32. 
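Because everything except the LoRA adapters is frozen, the `requires_grad` filter above is what keeps the optimizer small. A hypothetical helper (not part of this script) for sanity-checking that only adapter weights remain trainable:

```python
import torch

def count_trainable(module: torch.nn.Module) -> tuple:
    """Return (trainable, total) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total

# After unet.requires_grad_(False) and unet.add_adapter(...), only LoRA weights should count:
# trainable, total = count_trainable(unet)
# print(f"trainable: {trainable:,} / {total:,} ({trainable / total:.2%})")
```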
This is again needed since the base models - # are in `weight_dtype`. More details: - # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + # fp16 训练时,确保可训练的 LoRA 参数 upcast 到 fp32,提升数值稳定性 if args.mixed_precision == "fp16": models = [unet_] if args.train_text_encoder: models.append(text_encoder_) - - # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + # TF32 可选开启以提升矩阵乘速度 if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True + # 按总 batch 规模放大学习率,保持等效训练强度 if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) - # Make sure the trainable params are in float32. + # fp16 场景下显式将 LoRA 可训练参数 cast 为 fp32,减少训练发散风险 if args.mixed_precision == "fp16": models = [unet] if args.train_text_encoder: models.append(text_encoder) - - # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + # 优化器:可选 8-bit Adam 以降低显存占用 if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." - ) - + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW - # Optimizer creation + # 只收集 requires_grad=True 的参数,确保训练仅更新 LoRA adapter params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters())) @@ -1076,6 +935,7 @@ def main(args): eps=args.adam_epsilon, ) + # 预计算文本嵌入:把 prompt_embeds 预先算好,训练时不再保留 text encoder,释放显存 if args.pre_compute_text_embeddings: def compute_text_embeddings(prompt): @@ -1087,7 +947,6 @@ def main(args): text_inputs.attention_mask, text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - return prompt_embeds pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) @@ -1114,7 +973,7 @@ def main(args): validation_prompt_negative_prompt_embeds = None pre_computed_class_prompt_encoder_hidden_states = None - # Dataset and DataLoaders creation: + # 数据集与 dataloader:按是否先验保持决定是否加载 class 数据 train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, @@ -1137,13 +996,14 @@ def main(args): num_workers=args.dataloader_num_workers, ) - # Scheduler and math around the number of training steps. + # 推导训练步数:若未显式指定 max_train_steps,则由 epoch 与 dataloader 决定 overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True + # scheduler:由总训练步数与 warmup 配置决定学习率变化曲线 lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, @@ -1153,7 +1013,7 @@ def main(args): power=args.lr_power, ) - # Prepare everything with our `accelerator`. 
+ # accelerate.prepare:把训练对象放入分布式与混精管理 if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler @@ -1163,23 +1023,19 @@ def main(args): unet, optimizer, train_dataloader, lr_scheduler ) - # We need to recalculate our total training steps as the size of the training dataloader may have changed. + # 重新计算 epoch 数,保证在 prepare 后的 dataloader 规模变化下仍然一致 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. + # 初始化 tracker:记录训练配置,便于回溯参数与对比实验 if accelerator.is_main_process: tracker_config = vars(copy.deepcopy(args)) tracker_config.pop("validation_images") accelerator.init_trackers("dreambooth-lora", config=tracker_config) - # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") @@ -1191,28 +1047,23 @@ def main(args): global_step = 0 first_epoch = 0 - # Potentially load in the weights and states from a previous save + # 断点恢复:本实现以 output_dir 作为 resume_path,加载 accelerator 保存的训练状态 if args.resume_from_checkpoint: resume_path = args.output_dir - try: accelerator.print(f"Resuming from checkpoint at {resume_path}") accelerator.load_state(resume_path) - - # After loading state, `accelerator` updates its internal state including `step` and `epoch` + initial_global_step = accelerator.state.global_step global_step = initial_global_step - - # Recalculate first_epoch based on the loaded global_step + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) first_epoch = global_step // num_update_steps_per_epoch accelerator.print(f"Resumed at global step {global_step} and epoch {first_epoch}") - + except Exception as e: - accelerator.print( - f"Could not load state from '{resume_path}'. Starting a new training run. Error: {e}" - ) + accelerator.print(f"Could not load state from '{resume_path}'. Starting a new training run. Error: {e}") args.resume_from_checkpoint = None initial_global_step = 0 first_epoch = 0 @@ -1220,52 +1071,45 @@ def main(args): initial_global_step = 0 first_epoch = 0 - # [START] 为可视化方案增加的初始化和导入 + # 坐标记录:用于绘制训练轨迹 (X: LoRA 权重范数, Y: 梯度范数, Z: loss) coords_list = [] if args.positions_save_path is not None: import pandas as pd - logger.info( - f"可视化指标采集已启用。数据将每 {args.coords_log_interval} 步保存一次到 {args.positions_save_path}" - ) - # [END] 为可视化方案增加的初始化和导入 + logger.info(f"可视化指标采集已启用。数据将每 {args.coords_log_interval} 步保存一次到 {args.positions_save_path}") progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", - # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process, ) + # 训练循环:latent 编码、加噪、条件编码、UNet 预测、loss、反传与更新 for epoch in range(first_epoch, args.num_train_epochs): unet.train() if args.train_text_encoder: text_encoder.train() + for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + # 训练目标在 latent 空间时由 VAE 编码,否则在像素空间上直接训练 if vae is not None: - # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor else: model_input = pixel_values - # Sample noise that we'll add to the latents + # 采样噪声与时间步,并做前向扩散构造 noisy 输入 noise = torch.randn_like(model_input) bsz, channels, height, width = model_input.shape - # Sample a random timestep for each image timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device - ) - timesteps = timesteps.long() - - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) + ).long() noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) - # Get the text embedding for conditioning + # 条件编码:可选使用预计算 embedding,否则实时编码 token if args.pre_compute_text_embeddings: encoder_hidden_states = batch["input_ids"] else: @@ -1276,15 +1120,17 @@ def main(args): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) + # 某些 UNet 结构需要把输入通道翻倍,这里按原逻辑拼接输入 if unwrap_model(unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + # 可选 class label 条件,这里仅支持将 timesteps 作为 class_labels 输入 if args.class_labels_conditioning == "timesteps": class_labels = timesteps else: class_labels = None - # Predict the noise residual + # UNet 预测噪声残差(或速度),作为训练监督的主输出 model_pred = unet( noisy_model_input, timesteps, @@ -1293,13 +1139,11 @@ def main(args): return_dict=False, )[0] - # if model predicts variance, throw away the prediction. we will only train on the - # simplified training objective. This means that all schedulers using the fine tuned - # model must be configured to use one of the fixed variance variance types. + # 若模型同时预测方差,则丢弃方差通道以匹配训练目标 if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1) - # Get the target for loss depending on the prediction type + # 根据 prediction_type 构造监督目标张量 if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -1307,105 +1151,82 @@ def main(args): else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + # 先验保持:将拼接 batch 拆分为 instance 与 class 两半并分别计算 loss if args.with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) - # Compute instance loss loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. 
loss = loss + args.prior_loss_weight * prior_loss else: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + # 反向传播:只会对 LoRA 参数产生梯度 accelerator.backward(loss) - # [START] 为可视化方案增加的指标采集 - Y轴 (梯度范数) + # 轨迹记录的 Y:反向传播后统计可训练参数的梯度 L2 范数 Y_i = float("nan") if args.positions_save_path is not None: - # Y轴: 总梯度L2范数 (在反向传播之后,优化器更新之前计算) grad_norm_sq = 0.0 for name, p in unet.named_parameters(): - # 只关注需要梯度更新的参数 if p.grad is not None and p.requires_grad: - # 使用 float() 避免 torch.amp 带来的精度问题,确保准确计算L2范数 grad_norm_sq += (p.grad.data.float() ** 2).sum().item() Y_i = math.sqrt(grad_norm_sq) - # [END] 为可视化方案增加的指标采集 - Y轴 (梯度范数) if accelerator.sync_gradients: accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) + + # 参数更新:推进 optimizer 与 scheduler optimizer.step() lr_scheduler.step() optimizer.zero_grad() - # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - - # [START] 为可视化方案增加的指标采集 - X, Z轴和保存逻辑 + # 轨迹记录的 X/Z:更新后计算 LoRA 权重范数,并记录当前 loss if args.positions_save_path is not None and ( global_step % args.coords_log_interval == 0 or global_step == 1 or global_step == initial_global_step + 1 ): - - # Z轴: LDM 损失 Z_i = loss.detach().item() - - # X轴: 总LoRA权重L2范数 (在优化器更新之后计算) + lora_weight_norm_sq = 0.0 for name, p in unet.named_parameters(): - # 只关注 LoRA 权重参数 ("lora" in name) if "lora" in name and p.requires_grad: lora_weight_norm_sq += (p.data.float() ** 2).sum().item() X_i = math.sqrt(lora_weight_norm_sq) - - # 记录坐标数据 + coords_list.append([global_step, X_i, Y_i, Z_i]) - - # 实时保存到文件 (可选,但为了防止训练中断丢失数据,建议实时保存) - # 每次记录时都覆盖保存,确保文件始终是最新的 + df = pd.DataFrame( coords_list, - columns=['step', 'X_LoRA_Weight_Norm', 'Y_Grad_Norm', 'Z_LDM_Loss'] + columns=["step", "X_LoRA_Weight_Norm", "Y_Grad_Norm", "Z_LDM_Loss"], ) - # 假设 args.positions_save_path 是目标文件路径 (如 ./data/coords.csv) save_path = Path(args.positions_save_path) if not save_path.suffix: save_path = save_path / "coords.csv" save_path.parent.mkdir(parents=True, exist_ok=True) df.to_csv(save_path, index=False) - + if global_step % (args.coords_log_interval * 10) == 0: logger.info( f"Step {global_step}: 已记录并保存可视化坐标 (X={X_i:.4f}, Y={Y_i:.4f}, Z={Z_i:.4f}) 到 {save_path}" ) - # [END] 为可视化方案增加的指标采集 - X, Z轴和保存逻辑 - + # 中途 checkpoint 与验证:保存当前 state 后,重新构建 pipeline 并生成验证图像 if accelerator.is_main_process: if (global_step + 1) % args.checkpointing_steps == 0: - # 1. 保存模型参数:直接保存到 args.output_dir,覆盖上一轮 output_dir = args.output_dir - # accelerator.save_state handles saving the models using the registered hooks accelerator.save_state(output_dir) logger.info(f"Saving state to {output_dir} at step {global_step+1}") - # 2. 推理调用模型:从 args.output_dir 加载最新的模型权重 - # The base pipeline is re-loaded, and the Lora weights are saved *to* args.output_dir - # in the accelerator hook. Here, we must ensure we use the saved unet/text_encoder. pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - # Use the unwrapped models which contain the latest trained LoRA weights unet=unwrap_model(unet), text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder), revision=args.revision, @@ -1429,16 +1250,15 @@ def main(args): epoch, ) - # 3. 
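The CSV written above exists to visualize the (weight norm, gradient norm, loss) trajectory. A minimal downstream sketch for reading and plotting it is given below; the plotting code is not part of this repo, and the path is whatever `--positions_save_path` resolved to:

```python
import pandas as pd
import matplotlib.pyplot as plt

# Path produced by --positions_save_path (e.g. ./data/coords.csv); adjust as needed.
df = pd.read_csv("coords.csv")

fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(projection="3d")
ax.plot(df["X_LoRA_Weight_Norm"], df["Y_Grad_Norm"], df["Z_LDM_Loss"], marker="o", markersize=2)
ax.set_xlabel("LoRA weight L2 norm")
ax.set_ylabel("gradient L2 norm")
ax.set_zlabel("LDM loss")
ax.set_title("LoRA fine-tuning trajectory")
plt.tight_layout()
plt.savefig("lora_trajectory.png", dpi=150)
```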
推理生成结果保存:直接保存到指定目录/output_dir,不创建子文件夹 + # 验证图片保存:直接写入指定目录并覆盖同名文件,方便对比最新结果 base_save_path = Path(args.validation_image_output_dir or args.output_dir) base_save_path.mkdir(parents=True, exist_ok=True) logger.info(f"Saving validation images to {base_save_path}") - # 图片直接保存在 base_save_path,会覆盖上一轮的同名图片 for i, image in enumerate(images): image.save(base_save_path / f"image_{i}.png") - + # 记录当前 step 的 loss 与 lr,便于观察训练曲线 logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) @@ -1446,13 +1266,11 @@ def main(args): if global_step >= args.max_train_steps: break - - # Save the lora layers accelerator.wait_for_everyone() - if accelerator.is_main_process: - unet = unwrap_model(unet) - unet = unet.to(torch.float32) + # 训练结束保存 LoRA 权重,并可选执行最终验证与上传 + if accelerator.is_main_process: + unet = unwrap_model(unet).to(torch.float32) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: @@ -1467,16 +1285,12 @@ def main(args): text_encoder_lora_layers=text_encoder_state_dict, ) - # Final inference - # Load previous pipeline + # 最终推理:加载 base pipeline 并加载训练好的 LoRA 权重进行验证 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype ) - - # load attention processors pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors") - # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25} @@ -1489,6 +1303,7 @@ def main(args): is_final_validation=True, ) + # 上传到 Hub:写模型卡并上传输出目录 if args.push_to_hub: save_model_card( repo_id, @@ -1506,26 +1321,22 @@ def main(args): ignore_patterns=["step_*", "epoch_*"], ) - # [START] 为可视化方案增加的最终保存 (防止最后一步数据没有被保存) + # 训练结束补写一次坐标文件,保证最后一次记录不会遗漏 if args.positions_save_path is not None and coords_list: df = pd.DataFrame( coords_list, - columns=['step', 'X_LoRA_Weight_Norm', 'Y_Grad_Norm', 'Z_LDM_Loss'] + columns=["step", "X_LoRA_Weight_Norm", "Y_Grad_Norm", "Z_LDM_Loss"], ) - - # 假设 args.positions_save_path 是目标文件路径 (如 ./data/coords.csv) save_path = Path(args.positions_save_path) if not save_path.suffix: save_path = save_path / "coords.csv" - save_path.parent.mkdir(parents=True, exist_ok=True) df.to_csv(save_path, index=False) logger.info(f"训练结束:已将所有 {len(coords_list)} 步可视化坐标数据保存到 {save_path}") - # [END] 为可视化方案增加的最终保存 accelerator.end_training() if __name__ == "__main__": args = parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/src/backend/app/algorithms/finetune/train_ti_gen_trace.py b/src/backend/app/algorithms/finetune/train_ti_gen_trace.py index 76a5209..1621e31 100644 --- a/src/backend/app/algorithms/finetune/train_ti_gen_trace.py +++ b/src/backend/app/algorithms/finetune/train_ti_gen_trace.py @@ -1,18 +1,3 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and - import argparse import copy import gc @@ -51,13 +36,12 @@ from diffusers import ( StableDiffusionPipeline, UNet2DConditionModel, ) -# Removed LoRA import: from diffusers.loaders import LoraLoaderMixin +# 本脚本只训练 Textual Inversion 的 token embedding,不涉及 LoRA 权重 from diffusers.optimization import get_scheduler from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params from diffusers.utils import ( check_min_version, convert_state_dict_to_diffusers, - # Removed LoRA import: convert_unet_state_dict_to_peft, is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card @@ -65,28 +49,23 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module +# wandb 为可选依赖,仅在环境可用时启用 if is_wandb_available(): import wandb -# 说明: -# 1) 本文件用于训练 Textual Inversion(仅训练一个新 token 的向量)。 -# 2) 训练过程冻结 UNet/VAE/TextEncoder 的主体权重,仅更新新 token 对应的 embedding 行。 -# 3) 训练过程会按步保存 embedding,并进行验证推理,用于观察训练效果。 -# 4) 文件还包含可视化坐标采集逻辑:X=特征范数,Y=特征方差,Z=loss,并写入 CSV。 -# 5) 为了保证推理阶段的一致性,验证推理会从基础模型加载,并再加载 learned_embeds.bin 作为增量能力。 +# 训练目标为 Textual Inversion:只学习一个新 token 的 embedding 行 +# 训练过程中冻结 UNet/VAE/TextEncoder 主体参数,只允许 placeholder token 对应的 embedding 更新 +# 训练会周期性保存 learned_embeds.bin 与 tokenizer,并在保存点执行验证推理以观察学习效果 +# 可选记录训练轨迹坐标:(X=UNet 预测特征范数, Y=UNet 预测特征方差, Z=loss) 并写入 CSV logger = get_logger(__name__) def _load_textual_inversion_compat(pipeline: DiffusionPipeline, emb_dir: str, token: str): - """ - 说明: - 1) 不同 diffusers 版本对 load_textual_inversion 的参数命名不一致。 - 2) 有些版本支持 token=...,有些版本支持 tokens=[...],还有些只支持路径。 - 3) 本函数用于在不同版本之间提供兼容调用,优先传入 token 名提高确定性。 - 4) 若当前版本不接受这些参数,会自动降级为仅传路径的调用方式。 - 5) 该函数不会保存或覆盖基础模型文件,只在运行时向 pipeline 注入增量 embedding。 - """ + # 兼容不同 diffusers 版本的 Textual Inversion 加载接口 + # 优先显式指定 token 名,确保加载的 embedding 与 placeholder 对应 + # 若接口参数不兼容则自动降级为只传路径的调用方式 + # 该操作仅在运行时向 pipeline 注入 embedding,不会修改基础模型目录 try: pipeline.load_textual_inversion(emb_dir, token=token) return @@ -112,12 +91,10 @@ def save_model_card( pipeline: DiffusionPipeline = None, placeholder_token: str = None, ): - # 说明: - # 1) 该函数用于生成并保存 README 模型卡片与示例图片,便于上传 Hub 或本地记录。 - # 2) 对于 Textual Inversion,模型文件主要是 learned_embeds.bin 与 tokenizer。 - # 3) 该模型卡会说明训练所用的 placeholder token 与训练 prompt。 - # 4) 生成的图片会被保存在 repo_folder 下,方便查看训练效果。 - # 5) 本函数不会修改模型参数,只做文档与示例资产的持久化。 + # 生成并保存模型卡 README,同时保存示例图片到输出目录 + # Textual Inversion 的核心产物是 learned_embeds.bin 与 tokenizer 增量词表 + # 模型卡用于说明基础模型、训练 prompt 与 placeholder token,便于复现与展示 + # 本函数只写文档与图片文件,不改变任何训练参数或模型权重 img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -156,12 +133,10 @@ def log_validation( epoch, is_final_validation=False, ): - # 说明: - # 1) 该函数用于在训练过程中做验证推理,观察当前 embedding 学到了什么。 - # 2) 会将 scheduler 替换为更适合推理的 DPMSolverMultistepScheduler。 - # 3) 会关闭安全检查器,避免被过滤导致无法看到结果。 - # 4) 既支持纯文生图,也支持某些管线的传图推理(依赖 args.validation_images)。 - # 5) 会把结果写入 tracker(tensorboard/wandb),并释放 GPU 显存。 + # 验证推理:在训练过程中生成样例图,用于观察 embedding 的学习方向 + # 推理阶段使用 DPM-Solver 调度器提升速度,并禁用安全检查器避免结果被过滤 + # 支持纯文本推理与带初始图像的推理形式(由 validation_images 控制) + # 推理结果会写入 tracker(tensorboard/wandb),并在结束后释放显存 logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
@@ -219,12 +194,9 @@ def log_validation( def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): - # 说明: - # 1) Stable Diffusion 不同变体可能使用不同的 text encoder 架构。 - # 2) 该函数读取 text_encoder 的配置,判断其 architectures 字段来确定具体类。 - # 3) 常见情况是 CLIPTextModel,也可能是 Roberta 或 T5 系列。 - # 4) 返回的类用于 from_pretrained 加载 text_encoder,保证结构匹配。 - # 5) 如果遇到未知架构会直接报错,避免后续 silent bug。 + # 通过模型配置自动识别 text encoder 的具体架构,并返回对应的实现类 + # 该识别逻辑用于兼容不同的 Stable Diffusion 系列与其他扩散管线变体 + # 若架构不在支持列表中则直接报错,避免训练中途出现不匹配问题 text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", @@ -249,12 +221,11 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st def parse_args(input_args=None): - # 说明: - # 1) 该函数定义所有可配置参数,支持命令行调用与被后端服务传参调用。 - # 2) 训练相关参数包含学习率、步数、批大小、混合精度、保存间隔等。 - # 3) Textual Inversion 需要 placeholder_token 与 initializer_token,并且 prompt 必须包含 placeholder。 - # 4) 验证推理参数用于在训练中生成图片,输出到指定目录用于可视化或服务返回。 - # 5) coords_* 参数用于记录 3D 可视化坐标数据,不影响训练但会增加少量开销。 + # 参数解析:定义训练、保存、验证推理、断点恢复与坐标记录所需的全部参数 + # Textual Inversion 需要 placeholder_token 与 initializer_token,并且 instance_prompt 必须包含 placeholder_token + # 训练步数由 max_train_steps 或 num_train_epochs 推导,保存间隔由 checkpointing_steps 控制 + # 验证推理参数决定生成样例图的 prompt、数量与保存目录 + # coords_* 参数用于训练轨迹输出 CSV,不改变训练逻辑,仅增加统计与写盘开销 parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", @@ -559,12 +530,11 @@ def parse_args(input_args=None): class DreamBoothDataset(Dataset): - # 说明: - # 1) 该数据集负责读取实例图片,并把图片变换到训练所需的张量格式。 - # 2) 同时会对 instance_prompt 做 tokenizer 编码,生成 input_ids 与 attention_mask。 - # 3) Textual Inversion 不做 prior preservation,因此长度等于实例图片数量。 - # 4) 图像会先 resize 再 crop,并归一化到 [-1,1](Normalize([0.5],[0.5]))。 - # 5) 返回的字典字段会在 collate_fn 中被组装成 batch,供 UNet 前向与损失计算使用。 + # 数据集:负责读取实例图片并做预处理,同时对 instance_prompt 做分词编码 + # 图像会先 resize 再裁剪,并归一化到 [-1, 1],以匹配 Stable Diffusion 的训练输入 + # 每个样本输出 image 张量与 token 张量,字段名与训练循环一致 + # Textual Inversion 不需要 class 数据或先验保持,因此长度等于实例图片数量 + # 本类只准备数据,不参与任何梯度计算与模型更新 def __init__( self, instance_data_root, @@ -610,6 +580,7 @@ class DreamBoothDataset(Dataset): def __getitem__(self, index): example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) instance_image = exif_transpose(instance_image) @@ -625,12 +596,11 @@ class DreamBoothDataset(Dataset): def collate_fn(examples): - # 说明: - # 1) 该函数负责将 Dataset 返回的若干条样本组装成一个 batch。 - # 2) 对图像张量做 stack,得到 (B,C,H,W) 的 pixel_values。 - # 3) 对 token 的 input_ids 做 cat,得到 (B,seq_len) 的输入矩阵。 - # 4) attention_mask 保持与 input_ids 对齐,用于 text encoder 的有效 token 标记。 - # 5) 输出 batch 会被训练循环直接使用,字段命名与后续代码保持一致。 + # batch 拼接:把多条样本打包成训练循环可直接使用的 batch 字典 + # 图像张量 stack 为 (B, C, H, W),并转换为连续内存以提高算子效率 + # input_ids 与 attention_mask 沿 batch 维拼接,保持与 text encoder 的输入格式一致 + # 输出字段命名与训练主循环一致,避免额外适配 + # 本函数不做任何增强或损失计算,只负责规整数据结构 has_attention_mask = "instance_attention_mask" in examples[0] input_ids = [example["instance_prompt_ids"] for example in examples] @@ -656,12 +626,11 @@ def collate_fn(examples): def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): - # 说明: - # 1) 对文本 prompt 做 tokenizer 编码,生成 input_ids 与 attention_mask。 - # 2) 使用固定长度 padding="max_length" 保证 batch 拼接简单一致。 - # 3) truncation=True 防止超过最大长度导致报错。 - # 4) tokenizer_max_length 允许外部指定最大长度;否则使用 tokenizer.model_max_length。 - # 5) 返回 transformers 的 BatchEncoding,后续直接取 input_ids 与 attention_mask 使用即可。 + # 分词编码:把 prompt 转为固定长度 input_ids 与 attention_mask + 
# truncation=True 防止超长输入报错,padding="max_length" 保持 batch 形状稳定 + # tokenizer_max_length 若提供则覆盖默认长度,否则使用 tokenizer.model_max_length + # 返回 BatchEncoding,训练代码从中读取 input_ids 与 attention_mask + # 此处不引入任何额外逻辑,保证 token 化行为可复现 if tokenizer_max_length is not None: max_length = tokenizer_max_length else: @@ -678,12 +647,11 @@ def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): - # 说明: - # 1) 将 token id 输入 Text Encoder,得到用于 UNet 条件输入的 prompt_embeds。 - # 2) 如果启用 attention_mask,会把 mask 一并传入,以减少 padding token 的影响。 - # 3) 输出的 prompt_embeds 通常形状为 (B, seq_len, hidden_dim)。 - # 4) UNet 会把该 embedding 作为 cross-attention 的条件,实现文本引导。 - # 5) 该函数不涉及梯度以外的副作用,embedding 的更新由上层训练流程控制。 + # 文本编码:将 token id 输入 text encoder 得到 prompt_embeds,用作 UNet 条件输入 + # 可选启用 attention_mask,以减少 padding token 对编码结果的影响 + # 输出为 (B, seq_len, hidden_dim),与 UNet cross-attention 的条件维度对齐 + # 本函数不做保存与缓存,训练时是否更新 embedding 由上层控制 + # 为避免设备不一致,输入会被移动到 text_encoder.device text_input_ids = input_ids.to(text_encoder.device) if text_encoder_use_attention_mask: @@ -701,12 +669,11 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte def main(args): - # 说明: - # 1) 主函数负责训练用的全部初始化:accelerate、模型加载、数据集/优化器/调度器。 - # 2) Textual Inversion 的关键是新增一个 placeholder token,并只训练该 token 的 embedding。 - # 3) 训练过程中会定期保存 learned_embeds.bin 与 tokenizer,并执行验证推理输出图片。 - # 4) 验证推理从基础模型加载,再加载 learned_embeds.bin,避免对基础模型权重产生写回影响。 - # 5) 若开启 coords_save_path,会按你原有逻辑采集并保存可视化坐标数据,不改变其行为。 + # 主流程:构建 accelerate 环境,加载基础模型组件,并创建可训练的 placeholder token + # placeholder token 会被加入 tokenizer,并用 initializer token 的 embedding 进行初始化 + # 训练时冻结 UNet/VAE/TextEncoder 的主体权重,仅更新 placeholder token 的 embedding 行 + # 训练循环包含 latent 编码、加噪、条件编码、UNet 预测与 MSE 损失,并进行反向传播更新 + # 训练期间按 checkpointing_steps 保存状态,并用注入 embedding 的 pipeline 做验证推理输出样例图 if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 
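(补充示例,编辑添加:以下为示意性最小片段,并非本 diff 中的实际实现,函数名与变量名均为说明用途的假设。它演示上文所述"仅更新 placeholder token 的 embedding 行"的常见做法:训练前克隆原始 embedding 权重,每次 optimizer.step() 之后把除 placeholder 行以外的权重恢复为原值。)

import torch

def restore_non_placeholder_rows(text_encoder, placeholder_token_id, original_embeds):
    # original_embeds:训练开始前 get_input_embeddings().weight 的克隆副本(示例假设)
    with torch.no_grad():
        embeds = text_encoder.get_input_embeddings().weight
        keep_mask = torch.ones(embeds.shape[0], dtype=torch.bool, device=embeds.device)
        keep_mask[placeholder_token_id] = False  # 仅 placeholder 这一行保留本步的更新结果
        embeds[keep_mask] = original_embeds[keep_mask].to(dtype=embeds.dtype)

(调用时机为每次参数更新之后:即使整张 embedding 矩阵整体开放了梯度,最终也只有目标 token 对应的一行发生变化。)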
@@ -810,6 +777,7 @@ def main(args): text_encoder.requires_grad_(False) unet.requires_grad_(False) + # 只对 embedding 层开放梯度,并配合 mask 将可训练范围限制为 placeholder_token 一行 embedding_layer = text_encoder.get_input_embeddings() embedding_layer.weight.requires_grad = True trainable_token_embeds = embedding_layer.weight @@ -849,23 +817,18 @@ def main(args): unet.enable_gradient_checkpointing() def unwrap_model(model): - # 说明: - # 1) accelerate 在分布式或混合精度下会包装模型,保存/取权重时需要先 unwrap。 - # 2) 如果启用 torch.compile,模型会被再次包装,需取 _orig_mod 获取真实模块。 - # 3) 该函数用于在保存 embedding、验证推理、访问模型权重时统一处理。 - # 4) 返回的模型对象是“原始模型”,便于直接访问 embedding 权重与 config。 - # 5) 该函数自身不做任何训练逻辑修改,只是一个安全的模型访问入口。 + # 从 accelerate 包装中取出原始模型,便于保存与访问真实 embedding 权重 + # 对 torch.compile 场景,进一步解包以获得真实模块 + # 本函数不会改变模型状态,只提供统一的访问方式 model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model return model def save_model_hook(models, weights, output_dir): - # 说明: - # 1) 该 Hook 用于让 accelerate.save_state 保存为 Textual Inversion 需要的最小产物。 - # 2) 主要保存 learned_embeds.bin(仅包含 placeholder_token 对应的 embedding 行)。 - # 3) 同时保存 tokenizer,以便后续复现训练 token 的 id 映射与 tokenizer 配置。 - # 4) 不保存 UNet/VAE/TextEncoder 的完整权重,避免体积巨大且不符合“增量”设计。 - # 5) 保存行为只发生在主进程,避免分布式重复写盘导致文件冲突。 + # accelerate 保存钩子:只保存 Textual Inversion 的最小产物 + # learned_embeds.bin 只包含 placeholder_token 对应的 embedding 行,体积小且易于分发 + # tokenizer 一并保存,用于恢复 token_id 映射与 placeholder_token 的存在性 + # 不保存 UNet/VAE/TextEncoder 全量权重,保持增量训练的设计目标 if accelerator.is_main_process: text_encoder_unwrapped = unwrap_model(text_encoder) trained_embeddings = text_encoder_unwrapped.get_input_embeddings().weight[ @@ -880,12 +843,10 @@ def main(args): tokenizer.save_pretrained(output_dir) def load_model_hook(models, input_dir): - # 说明: - # 1) 该 Hook 用于从 checkpoint 恢复训练时,将 learned_embeds.bin 写回到 text_encoder embedding。 - # 2) 对于 Textual Inversion,恢复的关键是 placeholder_token 对应 embedding 行,而非整个模型。 - # 3) 同时通过 checkpoint 内的 tokenizer 获取 placeholder_token 的 token_id,以保证写入位置一致。 - # 4) 若 checkpoint 缺失 learned_embeds.bin,会打印警告并跳过,允许从头开始训练。 - # 5) 该逻辑只改变当前训练进程内的权重状态,不会修改基础模型目录的文件。 + # accelerate 加载钩子:从 learned_embeds.bin 恢复 placeholder_token 的 embedding 行 + # 通过 checkpoint 内 tokenizer 获取 placeholder_token_id,确保写回位置正确 + # 若文件缺失则跳过恢复,允许从头训练或使用外部初始化 + # 该操作只影响当前训练进程内存中的 embedding,不会修改基础模型目录 text_encoder_ = None while len(models) > 0: @@ -1060,6 +1021,7 @@ def main(args): ) for epoch in range(first_epoch, args.num_train_epochs): + # 训练时 UNet/TextEncoder 保持 train(),但实际只有 embedding 行可更新 unet.train() text_encoder.train() @@ -1067,12 +1029,14 @@ def main(args): with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + # 训练输入优先在 latent 空间,提升计算效率并匹配扩散模型训练范式 if vae is not None: model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor else: model_input = pixel_values + # 为每个样本采样时间步与噪声,构造前向扩散后的 noisy 输入 noise = torch.randn_like(model_input) bsz, channels, height, width = model_input.shape @@ -1082,6 +1046,7 @@ def main(args): noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + # 文本条件编码:使用当前 text encoder 生成 prompt_embeds,引导 UNet 去噪 encoder_hidden_states = encode_prompt( text_encoder, batch["input_ids"], @@ -1097,6 +1062,7 @@ def main(args): else: class_labels = None + # UNet 预测噪声残差,训练目标是最小化预测与真实噪声(或速度)的均方误差 model_pred = unet( noisy_model_input, timesteps, @@ -1108,6 +1074,7 @@ def main(args): if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1) + # 轨迹记录:用 UNet 输出统计量作为 X/Y,配合 loss 
形成训练动态曲线 if args.coords_save_path is not None: X_i_feature_norm = torch.norm(model_pred.detach().float(), p=2, dim=[1, 2, 3]).mean().item() Y_i_feature_var = model_pred.detach().float().var(dim=[1, 2, 3]).mean().item() @@ -1130,6 +1097,7 @@ def main(args): lr_scheduler.step() optimizer.zero_grad() + # 每次更新后强制把非 placeholder 的 embedding 恢复为固定值,保证只学习目标 token if accelerator.num_processes > 1: unwrapped_text_encoder = unwrap_model(text_encoder) trainable_embeds = unwrapped_text_encoder.get_input_embeddings().weight @@ -1143,6 +1111,7 @@ def main(args): progress_bar.update(1) global_step += 1 + # 坐标保存:按固定步数把 (X,Y,Z) 追加到列表,并覆盖写入 CSV 以防训练中断丢失 if args.coords_save_path is not None and ( global_step % args.coords_log_interval == 0 or global_step == 1 @@ -1167,6 +1136,7 @@ def main(args): f"Step {global_step}: 已记录并保存可视化坐标 (X={X_i_feature_norm:.4f}, Y={Y_i_feature_var:.4f}, Z={Z_i:.4f}) 到 {save_file_path}" ) + # checkpoint:保存训练状态,并用基础模型 + 注入 embedding 的方式生成验证图像 if accelerator.is_main_process: if (global_step + 1) % args.checkpointing_steps == 0: output_dir = args.output_dir @@ -1213,7 +1183,10 @@ def main(args): break accelerator.wait_for_everyone() + if accelerator.is_main_process: + # 训练结束保存最终产物:learned_embeds.bin 与 tokenizer + # learned_embeds.bin 只包含 placeholder token 的 embedding 行,用于后续推理时注入到基础模型 text_encoder = unwrap_model(text_encoder) trained_embeddings = text_encoder.get_input_embeddings().weight[ @@ -1226,6 +1199,7 @@ def main(args): torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin")) tokenizer.save_pretrained(args.output_dir) + # 最终验证:重新加载基础模型并注入 embedding,生成样例图用于输出与模型卡展示 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, @@ -1265,6 +1239,7 @@ def main(args): ignore_patterns=["step_*", "epoch_*"], ) + # 训练结束补写一次坐标文件,确保最后阶段数据不会遗漏 if args.coords_save_path is not None and coords_list: df = pd.DataFrame( coords_list, diff --git a/src/backend/app/algorithms/perturbation/aspl.py b/src/backend/app/algorithms/perturbation/aspl.py index 6f26194..8ee7943 100644 --- a/src/backend/app/algorithms/perturbation/aspl.py +++ b/src/backend/app/algorithms/perturbation/aspl.py @@ -29,7 +29,7 @@ logger = get_logger(__name__) class DreamBoothDatasetFromTensor(Dataset): - """Just like DreamBoothDataset, but take instance_images_tensor instead of path""" + """基于内存张量的 DreamBooth 数据集:直接使用张量输入,返回图像与对应 prompt token。""" def __init__( self, @@ -41,15 +41,18 @@ class DreamBoothDatasetFromTensor(Dataset): size=512, center_crop=False, ): + # 保存图像处理参数与 tokenizer self.size = size self.center_crop = center_crop self.tokenizer = tokenizer + # 实例数据:直接来自传入的张量列表 self.instance_images_tensor = instance_images_tensor self.num_instance_images = len(self.instance_images_tensor) self.instance_prompt = instance_prompt self._length = self.num_instance_images + # 可选类数据:用于先验保持,长度取实例与类数据的最大值 if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) @@ -60,6 +63,7 @@ class DreamBoothDatasetFromTensor(Dataset): else: self.class_data_root = None + # 统一的图像预处理 self.image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), @@ -73,6 +77,7 @@ class DreamBoothDatasetFromTensor(Dataset): return self._length def __getitem__(self, index): + # 取出实例图像张量与对应 prompt token example = {} instance_image = self.instance_images_tensor[index % self.num_instance_images] example["instance_images"] = instance_image @@ -84,6 +89,7 @@ class 
DreamBoothDatasetFromTensor(Dataset): return_tensors="pt", ).input_ids + # 若有类数据,则同时返回类图像与类 prompt token if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) if not class_image.mode == "RGB": @@ -101,6 +107,7 @@ class DreamBoothDatasetFromTensor(Dataset): def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + # 根据 text_encoder 配置识别其架构,选择正确的模型类 text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", @@ -121,6 +128,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st def parse_args(input_args=None): + # 解析命令行参数:模型路径、数据路径、对抗参数、先验保持、训练与日志配置 parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", @@ -341,7 +349,7 @@ def parse_args(input_args=None): class PromptDataset(Dataset): - "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + """用于批量生成 class 图像的提示词数据集,可在多 GPU 环境下并行采样。""" def __init__(self, prompt, num_samples): self.prompt = prompt @@ -358,6 +366,7 @@ class PromptDataset(Dataset): def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: + # 读取目录下所有图片,按训练要求 resize/crop/normalize,返回堆叠后的张量 image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), @@ -381,8 +390,7 @@ def train_one_epoch( data_tensor: torch.Tensor, num_steps=20, ): - # Load the tokenizer - + # 单轮训练:复制当前模型,使用给定数据迭代若干步,返回更新后的副本 unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1]) params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) @@ -404,7 +412,6 @@ def train_one_epoch( args.center_crop, ) - # weight_dtype = torch.bfloat16 weight_dtype = torch.bfloat16 device = torch.device("cuda") @@ -416,6 +423,7 @@ def train_one_epoch( unet.train() text_encoder.train() + # 构造当前步的样本(instance + class),并生成文本 token step_data = train_dataset[step % len(train_dataset)] pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to( device, dtype=weight_dtype @@ -425,24 +433,20 @@ def train_one_epoch( latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor - # Sample noise that we'll add to the latents + # 随机采样时间步并加噪,模拟正向扩散 noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning + # 文本条件编码 encoder_hidden_states = text_encoder(input_ids)[0] - # Predict the noise residual + # UNet 预测噪声 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - # Get the target for loss depending on the prediction type + # 根据 scheduler 的预测类型选择目标 if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -450,18 +454,13 @@ def train_one_epoch( else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - # with prior preservation loss + # 可选先验保持:拆分 instance 与 class 部分分别计算 MSE if args.with_prior_preservation: model_pred, model_pred_prior = 
torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) - # Compute instance loss instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. loss = instance_loss + args.prior_loss_weight * prior_loss else: @@ -489,7 +488,7 @@ def pgd_attack( target_tensor: torch.Tensor, num_steps: int, ): - """Return new perturbed data""" + """PGD 对抗扰动:在噪声预算内迭代更新输入,返回新的扰动数据。""" unet, text_encoder = models weight_dtype = torch.bfloat16 @@ -515,24 +514,18 @@ def pgd_attack( latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor - # Sample noise that we'll add to the latents + # 采样时间步并加噪,准备 UNet 预测 noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - #noise_scheduler.config.num_train_timesteps timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning + # 文本条件与噪声预测 encoder_hidden_states = text_encoder(input_ids.to(device))[0] - - # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - # Get the target for loss depending on the prediction type + # 目标噪声或速度 if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -544,7 +537,7 @@ def pgd_attack( text_encoder.zero_grad() loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - # target-shift loss + # 若有目标图像 latent,加入目标对齐项(保持原有逻辑:损失为差值) if target_tensor is not None: xtm1_pred = torch.cat( [ @@ -561,6 +554,7 @@ def pgd_attack( loss.backward() + # PGD 更新并投影到 eps 球内,再裁剪到 [-1, 1] alpha = args.pgd_alpha eps = args.pgd_eps / 255 @@ -598,7 +592,7 @@ def main(args): if args.seed is not None: set_seed(args.seed) - # Generate class images if prior preservation is enabled. + # 先验保持:不足的 class 图像用基础模型生成补齐 if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) if not class_images_dir.exists(): @@ -645,10 +639,9 @@ def main(args): if torch.cuda.is_available(): torch.cuda.empty_cache() - # import correct text encoder class + # 加载 text encoder / UNet / tokenizer / scheduler / VAE text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - # Load scheduler and models text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", @@ -712,9 +705,9 @@ def main(args): ) target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda() + # 交替流程:训练 surrogate -> PGD 扰动 -> 用扰动数据再训练主模型,周期性导出对抗样本 f = [unet, text_encoder] for i in range(args.max_train_steps): - # 1. 
f' = f.clone() f_sur = copy.deepcopy(f) f_sur = train_one_epoch( args, @@ -746,6 +739,7 @@ def main(args): args.max_f_train_steps, ) + # 周期保存当前扰动图像,便于后续评估与复现 if (i + 1) % args.checkpointing_iterations == 0: save_folder = args.output_dir os.makedirs(save_folder, exist_ok=True) diff --git a/src/backend/app/algorithms/perturbation/caat.py b/src/backend/app/algorithms/perturbation/caat.py index c7e41cd..d399e7a 100644 --- a/src/backend/app/algorithms/perturbation/caat.py +++ b/src/backend/app/algorithms/perturbation/caat.py @@ -41,11 +41,13 @@ logger = get_logger(__name__) def freeze_params(params): + """冻结一组参数的梯度开关,使其在训练中保持不更新。""" for param in params: param.requires_grad = False def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + """从预训练目录读取 text_encoder 配置,自动选择匹配的文本编码器实现类。""" text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", @@ -65,7 +67,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st raise ValueError(f"{model_class} is not supported.") class PromptDataset(Dataset): - "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + """用于批量生成 class 图像的 prompt 数据集,便于采样阶段在多卡上分发任务。""" def __init__(self, prompt, num_samples): self.prompt = prompt @@ -83,8 +85,10 @@ class PromptDataset(Dataset): class CustomDiffusionDataset(Dataset): """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. + CAAT/Custom Diffusion 训练数据集。 + + 负责读取实例图像与可选的类图像,并为每张图像生成对应的 prompt token。 + 同时在实例图像上生成有效区域 mask,供训练时对 loss 做空间加权。 """ def __init__( @@ -99,6 +103,7 @@ class CustomDiffusionDataset(Dataset): hflip=False, aug=True, ): + # 训练图像与 mask 的目标尺寸 self.size = size self.mask_size = mask_size self.center_crop = center_crop @@ -106,6 +111,7 @@ class CustomDiffusionDataset(Dataset): self.interpolation = Image.BILINEAR self.aug = aug + # 记录实例与类数据的路径及对应 prompt self.instance_images_path = [] self.class_images_path = [] self.with_prior_preservation = with_prior_preservation @@ -115,6 +121,7 @@ class CustomDiffusionDataset(Dataset): ] self.instance_images_path.extend(inst_img_path) + # 启用先验保持时,额外读取 class 图像与 class prompt if with_prior_preservation: class_data_root = Path(concept["class_data_dir"]) if os.path.isdir(class_data_root): @@ -129,12 +136,16 @@ class CustomDiffusionDataset(Dataset): class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)] self.class_images_path.extend(class_img_path[:num_class_images]) + # 打乱实例顺序以增加训练随机性,并确定数据集长度 random.shuffle(self.instance_images_path) self.num_instance_images = len(self.instance_images_path) self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) + + # 可选水平翻转增强 self.flip = transforms.RandomHorizontalFlip(0.5 * hflip) + # 类图像走标准 transforms;实例图像会走自定义 preprocess 以生成 mask self.image_transforms = transforms.Compose( [ self.flip, @@ -149,6 +160,7 @@ class CustomDiffusionDataset(Dataset): return self._length def preprocess(self, image, scale, resample): + """对实例图像做缩放与随机放置,并生成对应的有效区域 mask。""" outer, inner = self.size, scale factor = self.size // self.mask_size if scale > self.size: @@ -171,13 +183,15 @@ class CustomDiffusionDataset(Dataset): def __getitem__(self, index): example = {} + + # 读取实例图像与对应 prompt instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images] instance_image = Image.open(instance_image) 
if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") instance_image = self.flip(instance_image) - # apply resize augmentation and create a valid image region mask + # 对实例图像做随机缩放增强,并生成有效区域 mask random_scale = self.size if self.aug: random_scale = ( @@ -187,11 +201,13 @@ class CustomDiffusionDataset(Dataset): ) instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation) + # 根据缩放幅度对 prompt 加入轻量描述,模拟尺度变化的语义提示 if random_scale < 0.6 * self.size: instance_prompt = np.random.choice(["a far away ", "very small "]) + instance_prompt elif random_scale > self.size: instance_prompt = np.random.choice(["zoomed in ", "close up "]) + instance_prompt + # 实例图像与 mask 进入训练:图像已归一化到 [-1, 1] example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1) example["mask"] = torch.from_numpy(mask) example["instance_prompt_ids"] = self.tokenizer( @@ -202,6 +218,7 @@ class CustomDiffusionDataset(Dataset): return_tensors="pt", ).input_ids + # 先验保持:追加 class 图像、class mask 与 class prompt token if self.with_prior_preservation: class_image, class_prompt = self.class_images_path[index % self.num_class_images] class_image = Image.open(class_image) @@ -222,6 +239,7 @@ class CustomDiffusionDataset(Dataset): def parse_args(input_args=None): + """解析 CAAT 训练参数:包含 PGD 超参、数据与模型路径、训练步数与优化器设置。""" parser = argparse.ArgumentParser(description="CAAT training script.") parser.add_argument( "--alpha", @@ -494,6 +512,7 @@ def parse_args(input_args=None): if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank + # 先验保持模式需要 class 数据与 prompt;多概念模式下由 concepts_list 提供 if args.with_prior_preservation: if args.concepts_list is None: if args.class_data_dir is None: @@ -501,7 +520,6 @@ def parse_args(input_args=None): if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: - # logger is not available yet if args.class_data_dir is not None: warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: @@ -511,8 +529,8 @@ def parse_args(input_args=None): def main(args): + # 初始化 accelerate 环境与日志目录 logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( @@ -534,11 +552,14 @@ def main(args): transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() + # 记录实验配置到 tracker,便于后续复现实验 accelerator.init_trackers("CAAT", config=vars(args)) - # If passed along, set the training seed now. + # 固定随机种子以提高可复现性 if args.seed is not None: set_seed(args.seed) + + # 将单概念参数统一封装为 concepts_list,或从 json 中读取多概念配置 if args.concepts_list is None: args.concepts_list = [ { @@ -552,7 +573,7 @@ def main(args): with open(args.concepts_list, "r") as f: args.concepts_list = json.load(f) - # Generate class images if prior preservation is enabled. 
+ # 启用先验保持时,若 class 图像不足则使用基础模型补齐 if args.with_prior_preservation: for i, concept in enumerate(args.concepts_list): class_images_dir = Path(concept["class_data_dir"]) @@ -604,10 +625,11 @@ def main(args): if torch.cuda.is_available(): torch.cuda.empty_cache() + # 创建输出目录 if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # Load the tokenizer + # 加载 tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_name, @@ -622,10 +644,8 @@ def main(args): use_fast=False, ) - # import correct text encoder class + # 加载 text encoder / scheduler / VAE / UNet text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - - # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision @@ -636,22 +656,24 @@ def main(args): ) + # 冻结主干权重:该方法只训练 attention processor 中的增量参数 vae.requires_grad_(False) text_encoder.requires_grad_(False) unet.requires_grad_(False) - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. + + # 推理组件使用半精度可节省显存;训练增量层由 optimizer 管理 weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move unet, vae and text_encoder to device and cast to weight_dtype + # 将模型移动到训练设备并统一 dtype text_encoder.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) + # 根据是否启用 xformers 选择 attention processor 的实现 attention_class = CustomDiffusionAttnProcessor if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -666,24 +688,12 @@ def main(args): else: raise ValueError("xformers is not available. Make sure it is installed correctly") - # now we will add new Custom Diffusion weights to the attention layers - # It's important to realize here how many attention weights will be added and of which sizes - # The sizes of the attention layers consist only of two different variables: - # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. - # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. - - # Let's first see how many attention processors we will have to set. 
- # For Stable Diffusion, it should be equal to: - # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 - # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 - # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 - # => 32 layers - - # Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer + # 训练策略:只训练 cross-attention 的 KV(或全部 Q/K/V/out),其余保持冻结 train_kv = True train_q_out = False if args.freeze_model == "crossattn_kv" else True custom_diffusion_attn_procs = {} + # 从 UNet state_dict 中取出原始权重,作为自定义 attention processor 的初始化 st = unet.state_dict() for name, _ in unet.attn_processors.items(): @@ -697,6 +707,8 @@ def main(args): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] layer_name = name.split(".processor")[0] + + # KV 投影权重始终可训练;若启用 train_q_out 则额外训练 Q 与 out 投影 weights = { "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"], "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"], @@ -705,6 +717,8 @@ def main(args): weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"] weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"] weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"] + + # 仅对 cross-attention 层注入可训练 processor;self-attention 走冻结版本 if cross_attention_dim is not None: custom_diffusion_attn_procs[name] = attention_class( train_kv=train_kv, @@ -723,24 +737,27 @@ def main(args): del st + # 将新的 attention processor 注入 UNet,并用 AttnProcsLayers 封装成可训练模块 unet.set_attn_processor(custom_diffusion_attn_procs) custom_diffusion_layers = AttnProcsLayers(unet.attn_processors) + # 将增量层注册到 checkpoint,保证训练状态可保存/恢复 accelerator.register_for_checkpointing(custom_diffusion_layers) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + + # 允许 TF32 可提升部分 GPU 上的矩阵运算速度 if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True + # 先验保持时,通常将学习率扩大以补偿额外损失项带来的梯度分摊 args.learning_rate = args.learning_rate if args.with_prior_preservation: args.learning_rate = args.learning_rate * 2.0 - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + # 选择优化器实现:可选 8-bit AdamW 以降低显存占用 if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -753,7 +770,7 @@ def main(args): else: optimizer_class = torch.optim.AdamW - # Optimizer creation + # 仅优化 custom_diffusion_layers 的参数,其余主干保持冻结 optimizer = optimizer_class( custom_diffusion_layers.parameters(), lr=args.learning_rate, @@ -762,7 +779,7 @@ def main(args): eps=args.adam_epsilon, ) - # Dataset creation: + # 构建训练数据集;mask_size 通过 VAE latent 分辨率自动推导 train_dataset = CustomDiffusionDataset( concepts_list=args.concepts_list, tokenizer=tokenizer, @@ -780,15 +797,17 @@ def main(args): ) - # Prepare for PGD + # 为 PGD 准备可训练的图像张量:对实例图像做与训练一致的 transforms pertubed_images = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path] pertubed_images = [train_dataset.image_transforms(i) for i in pertubed_images] pertubed_images = torch.stack(pertubed_images).contiguous() pertubed_images.requires_grad_() + # 保留原始图像张量,用于 PGD 的投影约束 original_images = pertubed_images.clone().detach() original_images.requires_grad_(False) + # 文本 token:对所有实例图像重复同一个 instance_prompt(保持原脚本行为) input_ids = 
train_dataset.tokenizer( args.instance_prompt, truncation=True, @@ -798,6 +817,7 @@ def main(args): ).input_ids.repeat(len(original_images), 1) def get_one_mask(image): + """与训练同样的随机缩放逻辑,生成单张实例图像的有效区域 mask。""" random_scale = train_dataset.size if train_dataset.aug: random_scale = ( @@ -812,6 +832,7 @@ def main(args): one_mask += class_mask return one_mask + # 预先为每张图像生成 mask,并堆叠为 batch 形式供训练损失使用 images_open_list = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path] mask_list = [] for image in images_open_list: @@ -831,12 +852,13 @@ def main(args): num_training_steps=args.max_train_steps * accelerator.num_processes, ) + # 将可训练模块、优化器、对抗图像张量与 mask 一并交给 accelerate 管理设备与并行 custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask = accelerator.prepare( custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask ) - # Train! + # 训练主循环:每步同时更新 attention 增量层与对抗图像(PGD) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num pertubed_images = {len(pertubed_images)}") @@ -844,36 +866,31 @@ def main(args): global_step = 0 first_epoch = 0 - # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") for epoch in range(first_epoch, args.max_train_steps): unet.train() for _ in range(1): with accelerator.accumulate(unet), accelerator.accumulate(text_encoder): - # Convert images to latent space + # 将图像编码到 latent 空间并加噪,形成 UNet 的训练输入 pertubed_images.requires_grad = True latents = vae.encode(pertubed_images.to(accelerator.device).to(dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor - # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning + # 文本条件编码 encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0] - # Predict the noise residual + # UNet 预测噪声 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - # Get the target for loss depending on the prediction type + # 选择监督目标(epsilon 或 v) if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -881,39 +898,32 @@ def main(args): else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - # unet.zero_grad() - # text_encoder.zero_grad() - + # loss 计算:可选先验保持;实例部分可结合 mask 做空间加权 if args.with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
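(补充示例,编辑添加:下面的函数是对本 hunk 中"先验保持 + mask 空间加权"损失组合的独立示意,便于脱离训练循环单独理解;函数与参数名为假设,并非 diff 原有代码,mask 需可广播到预测张量的形状,例如 (B, 1, h, w)。)

import torch
import torch.nn.functional as F

def masked_prior_preservation_loss(model_pred, target, mask, prior_loss_weight):
    # batch 前半为实例样本,后半为 class(先验)样本
    inst_pred, prior_pred = torch.chunk(model_pred, 2, dim=0)
    inst_target, prior_target = torch.chunk(target, 2, dim=0)
    # 实例部分:逐元素 MSE 后按有效区域 mask 做空间加权平均(与本 hunk 的写法一致)
    per_pixel = F.mse_loss(inst_pred.float(), inst_target.float(), reduction="none")
    instance_loss = ((per_pixel * mask).sum(dim=[1, 2, 3]) / mask.sum(dim=[1, 2, 3])).mean()
    # 先验部分:普通 MSE,按 prior_loss_weight 叠加
    prior_loss = F.mse_loss(prior_pred.float(), prior_target.float(), reduction="mean")
    return instance_loss + prior_loss_weight * prior_loss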
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) mask = torch.chunk(mask, 2, dim=0)[0].to(accelerator.device) - # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() - # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: mask = mask.to(accelerator.device) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # torch.Size([5, 4, 64, 64]) - - #loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean() accelerator.backward(loss) - + # 梯度裁剪:只裁剪可训练的 custom_diffusion_layers 参数 if accelerator.sync_gradients: params_to_clip = ( custom_diffusion_layers.parameters() ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + # PGD 更新:基于 pertubed_images 的梯度做投影更新,并保持在 eps 约束内 alpha = args.alpha eps = args.eps adv_images = pertubed_images + alpha * pertubed_images.grad.sign() @@ -925,7 +935,6 @@ def main(args): lr_scheduler.step() optimizer.zero_grad(set_to_none=args.set_grads_to_none) - # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 @@ -938,6 +947,7 @@ def main(args): if global_step >= args.max_train_steps: break + # 训练结束后在主进程保存最终对抗图像,文件名包含原始图片名以便对齐 if accelerator.is_main_process: logger.info("***** Final save of perturbed images *****") save_folder = args.output_dir @@ -955,7 +965,6 @@ def main(args): img_name = img_names[i] save_path = os.path.join(save_folder, f"final_noise_{img_name}") - # 图像转换和保存 Image.fromarray( (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy() ).save(save_path) @@ -970,3 +979,4 @@ if __name__ == "__main__": args = parse_args() main(args) print("<-------end-------->") + \ No newline at end of file diff --git a/src/backend/app/algorithms/perturbation/glaze.py b/src/backend/app/algorithms/perturbation/glaze.py index 8db1371..9c72d0c 100644 --- a/src/backend/app/algorithms/perturbation/glaze.py +++ b/src/backend/app/algorithms/perturbation/glaze.py @@ -1,8 +1,3 @@ -""" -Glaze: 艺术风格保护算法 -基于原始 Glaze 项目重构,适配 4090D GPU 直接运行 -""" - import argparse import os import gc @@ -24,6 +19,7 @@ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution def parse_args(input_args=None): + """解析命令行参数,包含模型路径、输入输出目录、风格迁移配置与扰动优化超参。""" parser = argparse.ArgumentParser(description="Glaze: 艺术风格保护算法") parser.add_argument( "--pretrained_model_name_or_path", @@ -126,7 +122,7 @@ def parse_args(input_args=None): parser.add_argument( '--style_transfer_iter', type=int, - default=15, + default=15, help='风格迁移的扩散步数' ) parser.add_argument( @@ -159,6 +155,7 @@ def parse_args(input_args=None): else: args = parser.parse_args() + # 兼容 accelerate/分布式启动时的 LOCAL_RANK 注入 env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -167,7 +164,7 @@ def parse_args(input_args=None): def get_eps_from_intensity(intensity): - """根据强度级别 (0-100) 计算 epsilon 值""" + """将强度等级映射为 epsilon,用于以更直观的方式控制扰动幅度。""" if intensity <= 50: actual_eps = 0.025 + 0.025 * intensity / 50 else: @@ -176,7 +173,7 @@ def get_eps_from_intensity(intensity): 
def img2tensor(cur_img, device='cuda'): - """将 PIL 图像转换为 [-1, 1] 范围的张量""" + """将 PIL 图像转换为 [-1, 1] 范围的张量,并按 (1,C,H,W) 形式返回。""" cur_img = np.array(cur_img) img = (cur_img / 127.5 - 1).astype(np.float32) img = rearrange(img, 'h w c -> c h w') @@ -185,7 +182,7 @@ def img2tensor(cur_img, device='cuda'): def tensor2img(cur_img): - """将 [-1, 1] 范围的张量转换为 PIL 图像""" + """将 [-1, 1] 范围的张量转换为 PIL 图像,便于保存与可视化。""" if len(cur_img.shape) == 3: cur_img = cur_img.unsqueeze(0) cur_img = torch.clamp((cur_img.detach() + 1) / 2, min=0, max=1) @@ -195,7 +192,7 @@ def tensor2img(cur_img): def load_img(path): - """加载图像并处理 EXIF 旋转信息""" + """加载图像并修正 EXIF 方向,统一输出为 RGB;失败则返回 None。""" if not os.path.exists(path): return None try: @@ -210,14 +207,14 @@ def load_img(path): class GlazeDataset(Dataset): - """用于加载待处理图像的数据集""" + """从目录读取待处理图像,并返回图像张量、路径与原始 PIL 图像对象。""" def __init__(self, instance_data_root, size=512, center_crop=False): self.size = size self.center_crop = center_crop self.instance_images_path = [] - # 支持的图像格式 + # 过滤常见图像后缀,并避免重复处理已输出的 *_glazed 文件 valid_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.webp', '.tiff'} for p in Path(instance_data_root).iterdir(): @@ -227,6 +224,7 @@ class GlazeDataset(Dataset): self.instance_images_path = sorted(self.instance_images_path) self.num_instance_images = len(self.instance_images_path) + # 这里不做 Normalize,保持输入在 [0,1],后续在编码前再转换到 [-1,1] self.image_transforms = transforms.Compose([ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), @@ -241,8 +239,8 @@ class GlazeDataset(Dataset): img_path = self.instance_images_path[index % self.num_instance_images] instance_image = load_img(str(img_path)) + # 对异常样本返回占位图,避免训练/处理流程中断 if instance_image is None: - # 返回空白图像作为占位 instance_image = Image.new('RGB', (self.size, self.size), (0, 0, 0)) example['index'] = index % self.num_instance_images @@ -253,20 +251,20 @@ class GlazeDataset(Dataset): class GlazeOptimizer: - """Glaze 优化器核心类""" + """Glaze 核心优化器:负责生成目标风格参考,并在特征空间内优化输入扰动。""" def __init__(self, args, device): self.args = args self.device = device self.half = args.half_precision and device == 'cuda' - # 计算 epsilon + # eps 控制扰动最大幅度,可由 intensity 自动换算或直接手动指定 if args.intensity is not None: self.max_change = get_eps_from_intensity(args.intensity) else: self.max_change = args.eps - # 计算步长 + # 步长默认取 eps 的一半,并在迭代中做衰减以降低后期振荡 if args.step_size is not None: self.step_size = args.step_size else: @@ -275,12 +273,12 @@ class GlazeOptimizer: print(f"扰动预算 (epsilon): {self.max_change:.4f}") print(f"步长: {self.step_size:.4f}") - # 模型占位符 + # 模型在需要时惰性加载,减少启动开销与显存占用峰值 self.vae = None self.sd_pipeline = None def load_vae(self): - """加载 VAE 编码器""" + """加载 VAE 编码器,用于将图像映射到特征空间并参与梯度计算。""" print("加载 VAE 模型...") self.vae = AutoencoderKL.from_pretrained( self.args.pretrained_model_name_or_path, @@ -294,24 +292,24 @@ class GlazeOptimizer: # 注意:不设置 requires_grad_(False),因为我们需要通过它计算梯度 def load_sd_pipeline(self): - """加载 Stable Diffusion img2img 管道用于风格迁移""" + """加载 Stable Diffusion img2img 管道,用于生成目标风格参考图像。""" print("加载 Stable Diffusion 管道...") - # 始终使用 FP32 加载以避免 CPU 卸载问题 + # 始终使用 FP32 加载以避免 CPU offload 等路径带来的精度与兼容问题 self.sd_pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( self.args.pretrained_model_name_or_path, revision=self.args.revision, torch_dtype=torch.float32, - safety_checker=None, # 禁用 NSFW 检查器 + safety_checker=None, requires_safety_checker=False ) self.sd_pipeline.to(self.device) self.sd_pipeline.enable_attention_slicing() - # 如果使用半精度且在 GPU 上,转换为 
FP16 + # 在 GPU 上启用半精度,以减少显存并加速推理 if self.half and self.device == 'cuda': self.sd_pipeline.to(torch.float16) - # 尝试启用 xformers + # 可选启用 xformers 注意力实现,以进一步降低显存占用 if self.args.enable_xformers_memory_efficient_attention: try: self.sd_pipeline.enable_xformers_memory_efficient_attention() @@ -320,7 +318,7 @@ class GlazeOptimizer: print(f"无法启用 xformers: {e}") def unload_sd_pipeline(self): - """卸载 SD 管道以释放显存""" + """释放 SD 管道占用的显存,避免与后续优化阶段竞争资源。""" if self.sd_pipeline is not None: del self.sd_pipeline self.sd_pipeline = None @@ -330,36 +328,35 @@ class GlazeOptimizer: def vae_encode(self, x): """ - 使用 VAE 编码图像 - 注意:这里不使用 no_grad,以便支持梯度计算 + 使用 VAE 编码图像并返回 posterior 均值。 + + 这里保留梯度计算通路,使输入扰动可以通过特征距离损失进行优化。 """ posterior = self.vae.encode(x).latent_dist return posterior.mean def vae_encode_no_grad(self, x): - """使用 VAE 编码图像(不计算梯度版本,用于目标编码)""" + """不计算梯度的 VAE 编码版本,用于提取目标图像特征以节省显存。""" with torch.no_grad(): posterior = self.vae.encode(x).latent_dist return posterior.mean.detach() def style_transfer(self, img): """ - 使用 Stable Diffusion 进行风格迁移 - 生成目标风格图像 + 使用 SD img2img 将输入图像迁移到目标风格,得到用于对齐的目标风格参考图像。 """ if self.sd_pipeline is None: self.load_sd_pipeline() - # 调整图像大小 + # 将原图缩放到不超过 512,并以左上角对齐的方式填充到 512x512 画布 img_copy = img.copy() img_copy.thumbnail((512, 512), Image.LANCZOS) - # 创建 512x512 画布 canvas = np.zeros((512, 512, 3), dtype=np.uint8) canvas[:img_copy.size[1], :img_copy.size[0], :] = np.array(img_copy) padded_img = Image.fromarray(canvas) - # 运行风格迁移 + # 生成目标风格图像,仅用于提供参考,不需要梯度 with torch.no_grad(): result = self.sd_pipeline( prompt=self.args.target_style, @@ -371,19 +368,18 @@ class GlazeOptimizer: target_img = result.images[0] - # 裁剪回原始大小 + # 将输出裁剪回缩放后的有效区域,再 resize 回原图尺寸以对齐后续分块 cropped_target = np.array(target_img)[:img_copy.size[1], : img_copy.size[0], :] cropped_target = Image.fromarray(cropped_target) - - # 调整到原图大小 full_target = cropped_target.resize(img.size, Image.LANCZOS) return full_target def segment_image(self, img): """ - 将图像分割成 512x512 的方块 - 返回: (方块列表, 最后一个方块的偏移, 方块大小) + 将输入图像切分为若干正方形分块,并将每个分块缩放到 512x512。 + + 返回值包含分块列表、最后一个分块的对齐偏移、以及原始正方形分块的边长。 """ img_array = np.array(img).astype(np.float32) og_width, og_height = img.size @@ -391,9 +387,8 @@ class GlazeOptimizer: squares_ls = [] last_index = 0 - # 判断是宽图还是高图 + # 以短边为正方形边长,沿长边方向切块 if og_height <= og_width: - # 宽图:按水平方向分割 square_size = og_height cur_idx = 0 @@ -409,7 +404,6 @@ class GlazeOptimizer: squares_ls.append(cropped_img) cur_idx += og_height else: - # 高图:按垂直方向分割 square_size = og_width cur_idx = 0 @@ -428,20 +422,16 @@ class GlazeOptimizer: return squares_ls, last_index, square_size def put_back_cloak(self, og_img_array, cloak_list, last_index): - """ - 将扰动贴回原图 - """ + """将每个分块的扰动增量贴回原图位置,并裁剪到合法像素范围。""" og_height, og_width, _ = og_img_array.shape if og_height <= og_width: - # 宽图 for idx, cur_cloak in enumerate(cloak_list): if idx < len(cloak_list) - 1: og_img_array[0:og_height, idx * og_height:(idx + 1) * og_height, : ] += cur_cloak else: og_img_array[0:og_height, idx * og_height:(idx + 1) * og_height, :] += cur_cloak[0:og_height, last_index:] else: - # 高图 for idx, cur_cloak in enumerate(cloak_list): if idx < len(cloak_list) - 1: og_img_array[idx * og_width:(idx + 1) * og_width, 0:og_width, :] += cur_cloak @@ -452,9 +442,7 @@ class GlazeOptimizer: return og_img_array def get_cloak(self, og_segment_img, res_adv_tensor, square_size): - """ - 计算单个方块的扰动 (cloak) - """ + """将对抗结果与原分块对齐后取差值,得到该分块需要回贴到原图的扰动增量。""" resize_back = og_segment_img.resize((square_size, square_size), Image.LANCZOS) res_adv_img = 
tensor2img(res_adv_tensor).resize((square_size, square_size), Image.LANCZOS) cloak = np.array(res_adv_img).astype(np.float32) - np.array(resize_back).astype(np.float32) @@ -462,13 +450,13 @@ class GlazeOptimizer: def compute_adversarial(self, source_segments, target_segments, square_size, progress_callback=None): """ - 计算对抗扰动 - 核心优化算法 + 对每个分块执行 PGD 式优化,使源分块在 VAE 特征空间上逼近目标风格分块。 + + 该模块是核心优化过程:损失为 adv_emb 与 target_emb 的距离,并对扰动做 epsilon 约束投影。 """ results = [] for seg_idx, (source_seg, target_seg) in enumerate(zip(source_segments, target_segments)): - # 转换为张量 source_tensor = img2tensor(source_seg, self.device) target_tensor = img2tensor(target_seg, self.device) @@ -476,18 +464,15 @@ class GlazeOptimizer: source_tensor = source_tensor.half() target_tensor = target_tensor.half() - # 获取目标编码(不需要梯度) target_emb = self.vae_encode_no_grad(target_tensor) - # 初始化:源图像和扰动 X_batch = source_tensor.clone().detach() modifiers = torch.zeros_like(X_batch, requires_grad=True) - # 调整大小的变换 + # 通过先缩放回原分块尺寸再缩放到 512,模拟回贴后的尺度影响 resizer_large = torchvision.transforms.Resize(square_size) resizer_512 = torchvision.transforms.Resize((512, 512)) - # PGD 优化循环 pbar = tqdm(range(self.args.max_train_steps), desc=f"优化方块 {seg_idx + 1}/{len(source_segments)}", leave=False) @@ -495,61 +480,45 @@ class GlazeOptimizer: best_modifier = None for step in pbar: - # 动态调整步长 + # 使用随步数衰减的步长,提升收敛稳定性 decay = 1 - (step / self.args.max_train_steps) actual_step_size = self.step_size * decay - # 确保 modifiers 需要梯度 if not modifiers.requires_grad: modifiers = modifiers.detach().clone().requires_grad_(True) - # 应用扰动并裁剪 X_adv = torch.clamp(modifiers + X_batch, -1, 1) - # 调整大小(模拟实际处理) X_adv_resized = resizer_large(X_adv) X_adv_resized = resizer_512(X_adv_resized) - # 计算损失:最小化与目标编码的距离 adv_emb = self.vae_encode(X_adv_resized) loss = (adv_emb - target_emb).norm() - # 反向传播 loss.backward() - # 获取梯度 grad = modifiers.grad.detach() - # PGD 更新:沿梯度符号方向移动 + # 沿梯度符号方向更新,并投影到 [-eps, eps] 的约束范围内 with torch.no_grad(): update = grad.sign() * actual_step_size - modifiers_new = modifiers - update # 最小化损失,所以是减 - - # 投影到 epsilon 球 + modifiers_new = modifiers - update modifiers_new = torch.clamp(modifiers_new, -self.max_change, self.max_change) - - # 保存最佳结果 best_modifier = modifiers_new.detach().clone() - # 重新初始化 modifiers 用于下一轮 modifiers = best_modifier.clone().requires_grad_(True) - # 更新进度条 pbar.set_postfix({'loss': f'{loss.item():.4f}'}) - # 回调 if progress_callback: progress_callback(seg_idx, step, loss.item()) - # 最终对抗样本 with torch.no_grad(): best_adv = torch.clamp(best_modifier + X_batch, -1, 1) - # 计算 cloak cloak = self.get_cloak(source_seg, best_adv, square_size) results.append(cloak) - # 清理显存 del source_tensor, target_tensor, X_batch, modifiers, best_modifier if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -557,34 +526,27 @@ class GlazeOptimizer: return results def process_image(self, img, run_idx=0): - """ - 处理单张图像 - """ + """处理单张图像:生成目标风格参考→分块→逐块优化→回贴合成。""" print(f"\n=== 处理图像 (运行 {run_idx + 1}/{self.args.n_runs}) ===") - # 1.生成目标风格图像 print("生成目标风格图像...") target_img = self.style_transfer(img) - # 释放 SD 管道显存 + # 风格参考生成后立即释放 SD 管道,优先保证后续优化阶段显存充足 self.unload_sd_pipeline() - # 确保 VAE 已加载 if self.vae is None: self.load_vae() - # 2.分割图像 print("分割图像...") source_segments, last_index, square_size = self.segment_image(img) target_segments, _, _ = self.segment_image(target_img) print(f"图像被分割为 {len(source_segments)} 个方块,大小: {square_size}x{square_size}") - # 3.计算对抗扰动 print("计算对抗扰动...") cloak_list = self.compute_adversarial(source_segments, target_segments, 
square_size) - # 4.将扰动贴回原图 print("合成最终图像...") og_array = np.array(img).astype(np.float32) cloaked_array = self.put_back_cloak(og_array, cloak_list, last_index) @@ -594,13 +556,13 @@ class GlazeOptimizer: def main(args): - # 设置随机种子 + # 设置随机种子,保证风格迁移与优化过程的随机分支可复现 if args.seed is not None: torch.manual_seed(args.seed) np.random.seed(args.seed) random.seed(args.seed) - # 检测设备 + # 选择运行设备并打印基础信息 if torch.cuda.is_available(): device = 'cuda' print(f"使用 GPU: {torch.cuda.get_device_name(0)}") @@ -609,10 +571,8 @@ def main(args): device = 'cpu' print("使用 CPU") - # 创建输出目录 os.makedirs(args.output_dir, exist_ok=True) - # 加载数据集 dataset = GlazeDataset( instance_data_root=args.instance_data_dir, size=args.resolution, @@ -629,10 +589,9 @@ def main(args): print(f"优化步数: {args.max_train_steps}") print(f"运行次数: {args.n_runs}") - # 创建优化器 optimizer = GlazeOptimizer(args, device) - # 处理每张图像 + # 逐张处理,并按参数拼接输出文件名,便于回溯实验条件 for img_idx in range(len(dataset)): img_data = dataset[img_idx] img_path = img_data['path'] @@ -644,13 +603,10 @@ def main(args): best_result = None - # 多次运行取最佳结果 + # 多次运行可缓解随机性影响;保持原逻辑:以最后一次成功结果作为输出 for run_idx in range(args.n_runs): try: cloaked_img = optimizer.process_image(original_img, run_idx) - - # 简单起见,这里取最后一次运行的结果 - # 完整版本应该用 CLIP 评估选择最佳结果 best_result = cloaked_img except Exception as e: @@ -660,16 +616,14 @@ def main(args): continue if best_result is not None: - # 保存结果 orig_name = Path(img_path).stem orig_ext = Path(img_path).suffix - # 构建输出文件名 intensity_str = f"intensity{args.intensity}" if args.intensity else f"eps{int(args.eps*255)}" output_name = f"{orig_name}_glazed_{intensity_str}_steps{args.max_train_steps}{orig_ext}" output_path = os.path.join(args.output_dir, output_name) - # 保存图像 + # 按扩展名选择保存格式,避免某些格式默认压缩带来额外失真 if output_path.lower().endswith('.png'): best_result.save(output_path, 'PNG') else: diff --git a/src/backend/app/algorithms/perturbation/pid.py b/src/backend/app/algorithms/perturbation/pid.py index a556737..7354cf4 100644 --- a/src/backend/app/algorithms/perturbation/pid.py +++ b/src/backend/app/algorithms/perturbation/pid.py @@ -14,90 +14,88 @@ from diffusers import AutoencoderKL def parse_args(input_args=None): + """ + 配置解析函数:定义模型路径、数据集位置、攻击参数等命令行输入 + """ parser = argparse.ArgumentParser(description="Simple example of a training script.") + + # 基础模型与路径配置 parser.add_argument( "--pretrained_model_name_or_path", type=str, default=None, required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", + help="HuggingFace 模型标识或本地预训练模型路径", ) parser.add_argument( "--revision", type=str, default=None, required=False, - help=( - "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" - " float32 precision." 
- ), + help="指定模型的特定版本(如 branch, tag 或 commit id)", ) parser.add_argument( "--instance_data_dir", type=str, default=None, required=True, - help="A folder containing the training data of instance images.", + help="包含训练实例图像的文件夹路径", ) parser.add_argument( "--output_dir", type=str, default="text-inversion-model", - help="The output directory where the model predictions and checkpoints will be written.", + help="训练结果和检查点的保存目录", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + + # 训练超参数配置 + parser.add_argument("--seed", type=int, default=None, help="用于可复现训练的随机种子") parser.add_argument( "--resolution", type=int, default=512, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), + help="输入图像的分辨率,所有图像将调整为此大小", ) parser.add_argument( "--center_crop", default=False, action="store_true", - help=( - "Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping." - ), + help="是否对图像进行中心裁剪,否则进行随机裁剪", ) parser.add_argument( "--max_train_steps", type=int, default=None, - help="Total number of updating steps", + help="最大训练更新步数", ) parser.add_argument( "--dataloader_num_workers", type=int, default=0, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), + help="数据加载的子进程数,0 表示在主进程中加载", ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--local_rank", type=int, default=-1, help="分布式训练的本地排名") parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + "--enable_xformers_memory_efficient_attention", action="store_true", help="是否启用 xformers 以优化内存占用" ) + + # 对抗扰动攻击专用参数 parser.add_argument( '--eps', type=float, default=12.75, - help='pertubation budget' + help='扰动预算限制(基于 255 像素刻度)' ) parser.add_argument( '--step_size', type=float, default=1/255, - help='step size of each update' + help='每一迭代步的扰动更新步长' ) parser.add_argument( '--attack_type', choices=['var', 'mean', 'KL', 'add-log', 'latent_vector', 'add'], - help='what is the attack target' + help='对抗攻击的目标类型(如方差、均值或 KL 散度)' ) if input_args is not None: @@ -105,6 +103,7 @@ def parse_args(input_args=None): else: args = parser.parse_args() + # 处理分布式环境下的 rank 变量 env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -114,8 +113,7 @@ def parse_args(input_args=None): class PIDDataset(Dataset): """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. 
+ 数据集类:负责加载图像、处理 EXIF 信息并应用预处理变换 """ def __init__( @@ -128,6 +126,8 @@ class PIDDataset(Dataset): self.center_crop = center_crop self.instance_images_path = list(Path(instance_data_root).iterdir()) self.num_instance_images = len(self.instance_images_path) + + # 图像预处理流水线:缩放 -> 裁剪 -> 转换为张量 self.image_transforms = transforms.Compose([ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), @@ -139,8 +139,11 @@ class PIDDataset(Dataset): def __getitem__(self, index): example = {} instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + + # 自动修正图像的方向(基于 EXIF 元数据) instance_image = exif_transpose(instance_image) + # 统一强制转换为 RGB 格式 if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") @@ -150,19 +153,23 @@ class PIDDataset(Dataset): def main(args): - # Set random seed + """ + 主训练流程:初始化模型、生成对抗扰动并进行 PGD 优化 + """ + # 设定随机种子以保证实验的可重复性 if args.seed is not None: torch.manual_seed(args.seed) + weight_dtype = torch.float32 device = torch.device('cuda') - # VAE encoder + # 初始化 VAE 编码器(保持冻结,仅用于提取特征或计算损失) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) vae.requires_grad_(False) vae.to(device, dtype=weight_dtype) - # Dataset and DataLoaders creation: + # 创建数据集和数据加载器(Batch Size 固定为 1 以适配扰动一一对应关系) dataset = PIDDataset( instance_data_root=args.instance_data_dir, size=args.resolution, @@ -170,103 +177,117 @@ def main(args): ) dataloader = torch.utils.data.DataLoader( dataset, - batch_size=1, # some parts of code don't support batching + batch_size=1, shuffle=True, num_workers=args.dataloader_num_workers, ) - # Wrapper of the perturbations generator + # 对抗攻击模型封装 class AttackModel(torch.nn.Module): def __init__(self): super().__init__() to_tensor = transforms.ToTensor() self.epsilon = args.eps/255 + # 为数据集中每一张图初始化一个随机微小扰动(Delta) self.delta = [torch.empty_like(to_tensor(Image.open(path))).uniform_(-self.epsilon, self.epsilon) for path in dataset.instance_images_path] self.size = dataset.size def forward(self, vae, x, index, poison=False): - # Check whether we need to add perturbation + # 若处于攻击模式,则给输入图像加上扰动张量 if poison: self.delta[index].requires_grad_(True) x = x + self.delta[index].to(dtype=weight_dtype) - # Normalize to [-1, 1] + # 归一化图像到 [-1, 1] 区间,符合 VAE 输入要求 input_x = 2 * x - 1 return vae.encode(input_x.to(device)) attackmodel = AttackModel() - # Just to zero-out the gradient + # 定义优化器(注意:此处 LR 为 0,实际更新通过手动 PGD 符号梯度完成) optimizer = torch.optim.SGD(attackmodel.delta, lr=0) - # Progress bar + # 设置进度条 progress_bar = tqdm(range(0, args.max_train_steps), desc="Steps") - # Make sure the dir exists + # 确保输出目录存在 os.makedirs(args.output_dir, exist_ok=True) - # Start optimizing the perturbation + # 核心优化循环 for step in progress_bar: total_loss = 0.0 for batch in dataloader: - # Save images - if step%25 == 0: + # 定期保存添加扰动后的图像,以便观察视觉效果 + if step % 25 == 0: to_image = transforms.ToPILImage() for i in range(0, len(dataset.instance_images_path)): img = dataset[i]['pixel_values'] img = to_image(img + attackmodel.delta[i]) img.save(os.path.join(args.output_dir, f"{i}.png")) - - # Select target loss + # 分别计算原始图像和中毒(添加扰动)图像在潜空间的分布 clean_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], False) poison_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], True) + clean_latent = clean_embedding.latent_dist poison_latent = poison_embedding.latent_dist + # 
根据攻击类型计算相应的损失函数(旨在拉开或改变分布特征) if args.attack_type == 'var': loss = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean") elif args.attack_type == 'mean': loss = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean") elif args.attack_type == 'KL': + # 计算两个正态分布之间的 KL 散度 sigma_2, mu_2 = poison_latent.std, poison_latent.mean sigma_1, mu_1 = clean_latent.std, clean_latent.mean KL_diver = torch.log(sigma_2 / sigma_1) - 0.5 + (sigma_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * sigma_2 ** 2) loss = KL_diver.flatten().mean() elif args.attack_type == 'latent_vector': + # 直接对采样后的潜向量计算 MSE clean_vector = clean_latent.sample() poison_vector = poison_latent.sample() loss = F.mse_loss(clean_vector, poison_vector, reduction="mean") elif args.attack_type == 'add': + # 同时攻击均值和标准差 loss_2 = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean") loss_1 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean") loss = loss_1 + loss_2 elif args.attack_type == 'add-log': + # 攻击对数方差和均值(数值更稳定的变体) loss_1 = F.mse_loss(clean_latent.var.log(), poison_latent.var.log(), reduction="mean") loss_2 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction='mean') loss = loss_1 + loss_2 - + # 清除梯度并执行反向传播 optimizer.zero_grad() loss.backward() - # Perform PGD update on the loss + # 执行 PGD (Projected Gradient Descent) 更新步骤 delta = attackmodel.delta[batch['index']] delta.requires_grad_(False) + + # 沿梯度上升方向更新(最大化损失),实现攻击效果 delta += delta.grad.sign() * args.step_size + + # 约束 1:将扰动范围裁剪在 epsilon 预算内 delta = torch.clamp(delta, -attackmodel.epsilon, attackmodel.epsilon) + + # 约束 2:确保最终生成的图像像素值在 [0, 1] 合法区间内 delta = torch.clamp(delta, -batch['pixel_values'].detach().cpu(), 1-batch['pixel_values'].detach().cpu()) + + # 写回更新后的扰动并移除 Batch 维度 attackmodel.delta[batch['index']] = delta.detach().squeeze(0) total_loss += loss.detach().cpu() - # Logging steps + # 更新进度条状态栏 logs = {"loss": total_loss.item()} progress_bar.set_postfix(**logs) if __name__ == "__main__": args = parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/src/backend/app/algorithms/perturbation/simac.py b/src/backend/app/algorithms/perturbation/simac.py index cab654d..b7dad35 100644 --- a/src/backend/app/algorithms/perturbation/simac.py +++ b/src/backend/app/algorithms/perturbation/simac.py @@ -31,18 +31,14 @@ logger = get_logger(__name__) def _cuda_gc() -> None: - """Try to release unreferenced CUDA memory and reduce fragmentation. - - This is a best-effort helper. It does not change algorithmic behavior but can - make long runs less prone to OOM due to fragmentation/reserved-memory growth. 
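(补充示例,编辑添加:回到上文 pid.py 的 'KL' 攻击目标,下面用一个可独立运行的小片段验证其闭式 KL 公式与 torch.distributions 的参考实现一致;数值仅作演示,并非 diff 原有代码。)

import torch
from torch.distributions import Normal, kl_divergence

mu_1, sigma_1 = torch.tensor(0.3), torch.tensor(0.8)    # clean 分布参数(示例数值)
mu_2, sigma_2 = torch.tensor(-0.1), torch.tensor(1.2)   # poison 分布参数(示例数值)

# 与 pid.py 中相同的闭式 KL(N1 || N2)
closed_form = torch.log(sigma_2 / sigma_1) - 0.5 + (sigma_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * sigma_2 ** 2)
reference = kl_divergence(Normal(mu_1, sigma_1), Normal(mu_2, sigma_2))
assert torch.allclose(closed_form, reference)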
- """ + """尽力释放未引用的 CUDA 内存,降低碎片化风险,不改变算法行为。""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() class DreamBoothDatasetFromTensor(Dataset): - """Just like DreamBoothDataset, but take instance_images_tensor instead of path.""" + """基于内存张量的 DreamBooth 数据集:直接返回图像张量与 prompt token,减少磁盘 IO。""" def __init__( self, @@ -114,6 +110,7 @@ class DreamBoothDatasetFromTensor(Dataset): def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + # 依据 text_encoder 配置识别架构,加载对应实现 text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", @@ -133,6 +130,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st def parse_args(input_args=None): + # 解析全量参数:模型、数据、对抗超参、先验保持、训练与日志设置 parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", @@ -364,7 +362,7 @@ def parse_args(input_args=None): class PromptDataset(Dataset): - """A simple dataset to prepare the prompts to generate class images on multiple GPUs.""" + """多 GPU 生成 class 图像的提示词数据集""" def __init__(self, prompt, num_samples): self.prompt = prompt @@ -485,11 +483,11 @@ def train_one_epoch( f"instance_loss: {instance_loss.detach().item()}" ) - # Best-effort: free per-step tensors earlier (no behavior change). + # 尽早释放当前步的中间张量,降低显存占用 del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states del model_pred, target, loss, prior_loss, instance_loss - # Best-effort: release optimizer state + dataset refs sooner. + # 释放优化器与数据集引用,进一步回收显存 del optimizer, train_dataset, params_to_optimize _cuda_gc() @@ -497,6 +495,7 @@ def train_one_epoch( def set_unet_attr(unet): + # 覆写若干 up_block 的 resnet forward,捕获中间特征以供特征对齐损失使用 def conv_forward(self): def forward(input_tensor, temb): self.in_layers_features = input_tensor @@ -554,6 +553,7 @@ def set_unet_attr(unet): def save_feature_maps(up_blocks, down_blocks): + # 收集指定 up_block 的输出特征,用于对抗攻击中的特征对齐 out_layers_features_list_3 = [] res_3_list = [0, 1, 2] @@ -577,11 +577,7 @@ def pgd_attack( num_steps: int, time_list, ): - """Return new perturbed data. - - Note: This function keeps the external behavior identical, but tries to reduce - memory pressure by freeing tensors early and avoiding lingering references. - """ + """PGD 对抗扰动:按预选时间步迭代更新图像,可附加特征对齐正则;尝试提前释放无用张量。""" unet, text_encoder = models weight_dtype = torch.bfloat16 device = torch.device("cuda") @@ -594,7 +590,6 @@ def pgd_attack( perturbed_images = data_tensor.detach().clone() perturbed_images.requires_grad_(True) - # Keep input_ids on CPU; move to GPU only when encoding. input_ids = tokenizer( args.instance_prompt, truncation=True, @@ -611,6 +606,7 @@ def pgd_attack( noise = torch.randn_like(latents) + # 为每个样本从其时间步列表中随机选择一个时间步 timesteps = [] for i in range(len(data_tensor)): ts = time_list[i] @@ -635,6 +631,7 @@ def pgd_attack( noise_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks) + # 计算干净样本的对应特征用于对齐(不反传) with torch.no_grad(): clean_latents = vae.encode(data_tensor.to(device, dtype=weight_dtype)).latent_dist.sample() clean_latents = clean_latents * vae.config.scaling_factor @@ -652,8 +649,7 @@ def pgd_attack( text_encoder.zero_grad(set_to_none=True) loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - # Keep original behavior: feature loss does not backprop (added as Python float). 
- loss = loss + target_loss.detach().item() + loss = loss + target_loss.detach().item() # 特征对齐损失保持为常数项,不反传 loss.backward() alpha = args.pgd_alpha @@ -666,7 +662,7 @@ def pgd_attack( f"PGD loss - step {step}, loss: {loss.detach().item()}, target_loss : {target_loss.detach().item()}" ) - # Best-effort: free per-step tensors early. + # 尽早释放当前步的中间张量 del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target del noise_out_layers_features_3, clean_latents, noisy_clean_latents, clean_out_layers_features_3 del target_loss, loss, adv_images, eta @@ -685,10 +681,7 @@ def select_timestep( original_images: torch.Tensor, target_tensor: torch.Tensor, ): - """Return timestep lists for each image. - - External behavior unchanged; add best-effort per-loop cleanup to lower memory pressure. - """ + """为每张图选择一个时间步列表:通过多次梯度采样筛掉部分时间步,减少攻击开销,同时保持外部行为不变。""" unet, text_encoder = models weight_dtype = torch.bfloat16 device = torch.device("cuda") @@ -721,6 +714,7 @@ def select_timestep( select_mask = torch.where(input_mask == 1, True, False) res_time_seq = torch.masked_select(time_seq, select_mask) + # 如果剩余时间步仍多,随机抽取部分时间步估计梯度,删除一段时间步区间 if len(res_time_seq) > 100: min_score, max_score = 0.0, 0.0 for inner_try in range(0, 5): @@ -776,6 +770,7 @@ def select_timestep( del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss, score + # 删除一段时间步以缩小候选集合,并记录当前选中的时间步 print("del_t", del_t, "max_t", select_t) if del_t < args.delta_t: del_t = args.delta_t @@ -832,6 +827,7 @@ def select_timestep( def setup_seeds(): + # 设置统一随机种子并关闭 cudnn 非确定性,保证结果可复现 seed = 42 random.seed(seed) np.random.seed(seed) @@ -869,7 +865,7 @@ def main(args): set_seed(args.seed) setup_seeds() - # Generate class images if prior preservation is enabled. + # 先验保持:若 class 图像不足,则通过基础 pipeline 生成补齐 if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) class_images_dir.mkdir(parents=True, exist_ok=True) @@ -1024,7 +1020,7 @@ def main(args): time_list, ) - # Free surrogate ASAP (best-effort, behavior unchanged). + # 及时释放 surrogate,保持显存占用稳定 del f_sur _cuda_gc() @@ -1061,7 +1057,7 @@ def main(args): print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)") - # Best-effort cleanup at the end of each outer iteration. + # 外层迭代结束后的清理 _cuda_gc() diff --git a/src/backend/app/algorithms/processors/coords_processor.py b/src/backend/app/algorithms/processors/coords_processor.py new file mode 100644 index 0000000..ae2a908 --- /dev/null +++ b/src/backend/app/algorithms/processors/coords_processor.py @@ -0,0 +1,132 @@ +import pandas as pd +from pathlib import Path +import sys +import logging +import numpy as np +import statsmodels.api as sm + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +def apply_lowess_and_clipping_scaling(input_csv_path, output_csv_path, lowess_frac, target_range, clipping_percentile): + """ + 应用 Lowess 局部加权回归进行平滑(提取总体趋势),然后使用百分位数裁剪后的 Min-Max 边界来缩放。 + 目标:生成最平滑、最接近单调下降的客观趋势。 + """ + input_path = Path(input_csv_path) + output_path = Path(output_csv_path) + + if not input_path.exists(): + logging.error(f"错误:未找到输入文件 {input_csv_path}") + return + + logging.info(f"读取原始数据: {input_csv_path}") + df = pd.read_csv(input_path) + + df = df.loc[:,~df.columns.duplicated()].copy() + + # 定义原始数据列名 + raw_x_col = 'X_Feature_L2_Norm' + raw_y_col = 'Y_Feature_Variance' + raw_z_col = 'Z_LDM_Loss' + + # --------------------------- 1. 
Lowess 局部加权回归平滑 (提取总体趋势) ---------------------------
+    logging.info(f"应用 Lowess 局部加权回归,平滑因子 frac={lowess_frac}。")
+
+    x_coords = df['step'].values
+
+    for raw_col in [raw_x_col, raw_y_col, raw_z_col]:
+        y_coords = df[raw_col].values
+
+        smoothed_data = sm.nonparametric.lowess(
+            endog=y_coords,
+            exog=x_coords,
+            frac=lowess_frac,
+            it=0
+        )
+        df[f'{raw_col}_LOWESS'] = smoothed_data[:, 1]
+
+
+    # --------------------------- 2. 百分位数边界缩放与方向统一 ---------------------------
+    p = clipping_percentile
+    logging.info(f"应用百分位数边界 (p={p}) 进行线性缩放,目标范围 [0, {target_range:.2f}]")
+
+    scale_cols_map = {
+        'X_Feature_L2_Norm': f'{raw_x_col}_LOWESS',
+        'Y_Feature_Variance': f'{raw_y_col}_LOWESS',
+        'Z_LDM_Loss': f'{raw_z_col}_LOWESS'
+    }
+
+    for final_col, lowess_col in scale_cols_map.items():
+
+        data = df[lowess_col]
+
+        # 裁剪:计算裁剪后的 min/max (定义缩放窗口)
+        lower_bound = data.quantile(p)
+        upper_bound = data.quantile(1.0 - p)
+
+        min_val = lower_bound
+        max_val = upper_bound
+        data_range = max_val - min_val
+
+        if data_range <= 0 or np.isnan(data_range):
+            df[final_col] = 0.0
+            logging.warning(f"列 {final_col} 裁剪后的范围为 {data_range:.4f},跳过缩放。")
+            continue
+
+        # 归一化: (data - Min_window) / Range_window
+        normalized_data = (data - min_val) / data_range
+
+        # **优化方向统一逻辑 (所有指标都应是越小越好):**
+        if final_col in ['X_Feature_L2_Norm', 'Y_Feature_Variance']:
+            # X/Y 反转:将 Max 映射到 0,Min 映射到 TargetRange
+            final_scaled_data = (1.0 - normalized_data) * target_range
+        else: # Z_LDM_Loss
+            # Z 标准缩放:Min 映射到 0,Max 映射到 TargetRange
+            final_scaled_data = normalized_data * target_range
+
+        # 保留负值,以确保平滑过渡
+        df[final_col] = final_scaled_data
+
+        logging.info(f" - 列 {final_col}:裁剪边界: [{min_val:.4f}, {max_val:.4f}]。缩放后范围不再严格约束 [0, {target_range:.2f}],以保留趋势。")
+
+
+    # --------------------------- 3. 最终保存 ---------------------------
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    final_cols = ['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss']
+
+
+    df[final_cols].to_csv(
+        output_path,
+        index=False,
+        float_format='%.3f'
+    )
+
+    logging.info(f"Lowess平滑和缩放后的数据已保存到: {output_csv_path}")
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 6:
+        logging.error("使用方法: python coords_processor.py <输入CSV路径> <输出CSV路径> <Lowess平滑因子 frac (例如 0.3)> <目标视觉范围 (例如 30)> <离散点裁剪百分比 (例如 0.15)>")
+    else:
+        input_csv = sys.argv[1]
+        output_csv = sys.argv[2]
+        try:
+            lowess_frac = float(sys.argv[3])
+            target_range = float(sys.argv[4])
+            clipping_p = float(sys.argv[5])
+
+            if not (0.0 < lowess_frac <= 1.0):
+                raise ValueError("Lowess 平滑因子 frac 必须在 (0.0, 1.0] 之间。")
+            if target_range <= 0:
+                raise ValueError("目标视觉范围必须大于 0。")
+            if not (0 <= clipping_p < 0.5):
+                raise ValueError("裁剪百分比必须在 [0, 0.5) 之间。")
+
+            if not Path(output_csv).suffix:
+                output_csv = str(Path(output_csv) / "scaled_coords.csv")
+
+            apply_lowess_and_clipping_scaling(input_csv, output_csv, lowess_frac, target_range, clipping_p)
+        except ValueError as e:
+            logging.error(f"参数错误: {e}")
\ No newline at end of file
diff --git a/src/backend/app/algorithms/processors/image_processor.py b/src/backend/app/algorithms/processors/image_processor.py
new file mode 100644
index 0000000..287a197
--- /dev/null
+++ b/src/backend/app/algorithms/processors/image_processor.py
@@ -0,0 +1,149 @@
+"""
+图片处理功能,用于把原始图片剪裁为中心正方形,指定分辨率,并保存为指定格式,还可以选择是否序列化改名。
+"""
+
+import argparse
+import os
+from pathlib import Path
+from PIL import Image
+
+# --- 1.
参数解析 --- +def parse_args(input_args=None): + """ + 解析命令行参数。 + """ + parser = argparse.ArgumentParser(description="Image Processor for Centering, Resizing, and Format Conversion.") + + # 路径和分辨率参数 + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="A folder containing the original images to be processed and overwritten.", + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help="The target resolution (width and height) for the output images (e.g., 512 for 512x512).", + ) + # 格式参数 + parser.add_argument( + "--target_format", + type=str, + default="png", + choices=["jpeg", "png", "webp", "jpg"], + help="The target format for the saved images (e.g., 'png', 'jpg', 'webp'). The original file will be overwritten, potentially changing the file extension.", + ) + + # 序列化数字重命名参数 + parser.add_argument( + "--rename_sequential", + action="store_true", # 当这个参数存在时,其值为 True + help="If set, images will be sequentially renamed (e.g., 001.jpg, 002.jpg...) instead of preserving the original filename. WARNING: This WILL delete the originals.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + return args + +# --- 2. 核心图像处理逻辑 --- +def process_image(image_path: Path, output_path: Path, resolution: int, target_format: str, delete_original: bool): + """ + 加载图像,居中取最大正方形,升降分辨率,并保存为目标格式。 + + Args: + image_path: 原始图片路径。 + output_path: 最终保存路径。 + resolution: 目标分辨率。 + target_format: 目标文件格式。 + delete_original: 是否删除原始文件。 + """ + try: + # 加载图像并统一转换为 RGB 模式 + img = Image.open(image_path).convert("RGB") + + # 居中取最大正方形 + width, height = img.size + min_dim = min(width, height) + + # 计算裁剪框 (以最短边为尺寸的中心正方形) + left = (width - min_dim) // 2 + top = (height - min_dim) // 2 + right = left + min_dim + bottom = top + min_dim + + # 裁剪中心正方形 + img = img.crop((left, top, right, bottom)) + + # 升降分辨率到指定 resolution + # 使用 LANCZOS 高质量重采样方法 + img = img.resize((resolution, resolution), resample=Image.Resampling.LANCZOS) + + # 准备输出格式 + save_format = target_format.upper().replace('JPEG', 'JPG') + + # 保存图片 + # 对于 JPEG/JPG,设置 quality 参数 + if save_format == 'JPG': + img.save(output_path, format='JPEG', quality=95) + else: + img.save(output_path, format=save_format) + + # 根据标记决定是否删除原始文件 + if delete_original and image_path.resolve() != output_path.resolve(): + os.remove(image_path) + + print(f"Processed: {image_path.name} -> {output_path.name} ({resolution}x{resolution} {save_format})") + + except Exception as e: + print(f"Error processing {image_path.name}: {e}") + +# --- 3. 主函数 --- +def main(args): + # 路径准备 + input_dir = Path(args.input_dir) + + if not input_dir.is_dir(): + print(f"Error: Input directory not found at {input_dir}") + return + + # 查找所有图片文件 (支持 jpg, jpeg, png, webp) + valid_suffixes = ['.jpg', '.jpeg', '.png', '.webp'] + image_paths = sorted([p for p in input_dir.iterdir() if p.suffix.lower() in valid_suffixes]) # 排序以确保重命名顺序一致 + + if not image_paths: + print(f"No image files found in {input_dir}") + return + + print(f"Found {len(image_paths)} images in {input_dir}. Starting processing...") + + # 准备目标格式的扩展名 + extension = args.target_format.lower().replace('jpeg', 'jpg') + + # 迭代处理图片 + for i, image_path in enumerate(image_paths): + + # 决定输出路径 + if args.rename_sequential: + # 顺序重命名逻辑:001, 002, 003... 
(至少三位数字) + new_name = f"{i + 1:03d}.{extension}" + output_path = input_dir / new_name + # 如果原始文件与新文件名称不同,则需要删除原始文件 + delete_original = True + else: + # 保持原始文件名,但修改后缀 + output_path = image_path.with_suffix(f'.{extension}') + # 只有当原始后缀与目标后缀不同时,才需要删除原始文件(防止遗留旧格式) + delete_original = (image_path.suffix.lower() != f'.{extension}') + + process_image(image_path, output_path, args.resolution, args.target_format, delete_original) + + print("Processing complete.") + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/src/backend/app/scripts/attack_caat_with_prior.sh b/src/backend/app/scripts/attack_caat_with_prior.sh index a7e149e..ffed2d0 100644 --- a/src/backend/app/scripts/attack_caat_with_prior.sh +++ b/src/backend/app/scripts/attack_caat_with_prior.sh @@ -1,3 +1,4 @@ +#需要环境:conda activate caat export HF_HUB_OFFLINE=1 # 强制使用本地模型缓存,避免联网下载模型 #export HF_HOME="/root/autodl-tmp/huggingface_cache" diff --git a/src/backend/app/scripts/attack_glaze.sh b/src/backend/app/scripts/attack_glaze.sh index cd75d59..ba57d7a 100644 --- a/src/backend/app/scripts/attack_glaze.sh +++ b/src/backend/app/scripts/attack_glaze.sh @@ -1,5 +1,4 @@ -#!/bin/bash - +#需要环境:conda activate pid #============================================================================= # Glaze 风格保护攻击脚本 # 用于保护艺术作品免受 AI 模型的风格模仿 diff --git a/src/backend/app/scripts/attack_glaze_style_trans.sh b/src/backend/app/scripts/attack_glaze_style_trans.sh index e5ca490..b56dcd7 100644 --- a/src/backend/app/scripts/attack_glaze_style_trans.sh +++ b/src/backend/app/scripts/attack_glaze_style_trans.sh @@ -1,5 +1,4 @@ -#!/bin/bash - +#需要环境:conda activate pid #============================================================================= # Glaze 风格保护攻击脚本 # 用于保护艺术作品免受 AI 模型的风格模仿 diff --git a/src/backend/app/scripts/attack_pid.sh b/src/backend/app/scripts/attack_pid.sh index 761890e..768c5cc 100644 --- a/src/backend/app/scripts/attack_pid.sh +++ b/src/backend/app/scripts/attack_pid.sh @@ -4,11 +4,17 @@ export HF_HUB_OFFLINE=1 # 强制使用本地模型缓存,避免联网下载模型 +### SD v2.1 +# export HF_HOME="/root/autodl-tmp/huggingface_cache" +# export MODEL_PATH="stabilityai/stable-diffusion-2-1" + ### SD v1.5 +# export HF_HOME="/root/autodl-tmp/huggingface_cache" +# export MODEL_PATH="runwayml/stable-diffusion-v1-5" export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14" -export TASKNAME="task003" +export TASKNAME="task001" ### Data to be protected export INSTANCE_DIR="../../static/originals/${TASKNAME}" ### Path to save the protected data @@ -26,8 +32,7 @@ echo "Clearing output directory: $OUTPUT_DIR" # 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) 
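For reference, the per-image work done by process_image in image_processor.py above (largest centered square, then LANCZOS resampling to the target resolution) can be exercised standalone with the sketch below; it assumes Pillow >= 9.1 for Image.Resampling and uses placeholder file names:

from PIL import Image

def center_square_resize(path, resolution=512):
    # Crop the largest centered square, then resample to the target resolution.
    img = Image.open(path).convert("RGB")
    w, h = img.size
    side = min(w, h)
    left, top = (w - side) // 2, (h - side) // 2
    img = img.crop((left, top, left + side, top + side))
    return img.resize((resolution, resolution), resample=Image.Resampling.LANCZOS)

# center_square_resize("input.jpg").save("output.png", format="PNG")

Cropping to the shorter side first guarantees a square input to resize, so the aspect ratio is never distorted.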
find "$OUTPUT_DIR" -mindepth 1 -delete -export PYTHONWARNINGS="ignore" -#忽略所有警告 + ### Generation command # --max_train_steps: Optimizaiton steps @@ -41,6 +46,6 @@ CUDA_VISIBLE_DEVICES=0 python ../algorithms/pid.py \ --resolution=512 \ --max_train_steps=1000 \ --center_crop \ - --eps 12 \ - --step_size 0.002 \ - --attack_type add-log \ No newline at end of file + --eps 12.75 \ + --attack_type add-log + diff --git a/src/backend/app/scripts/attack_quick.sh b/src/backend/app/scripts/attack_quick.sh deleted file mode 100644 index 29456eb..0000000 --- a/src/backend/app/scripts/attack_quick.sh +++ /dev/null @@ -1,46 +0,0 @@ -#需要环境:conda activate pid -### Generate images protected by PID - -export HF_HUB_OFFLINE=1 -# 强制使用本地模型缓存,避免联网下载模型 - -### SD v1.5 -export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14" - - -export TASKNAME="task003" -### Data to be protected -export INSTANCE_DIR="../../static/originals/${TASKNAME}" -### Path to save the protected data -export OUTPUT_DIR="../../static/perturbed/${TASKNAME}" - -# ------------------------- 自动创建依赖路径 ------------------------- -echo "Creating required directories..." -mkdir -p "$INSTANCE_DIR" -mkdir -p "$OUTPUT_DIR" -echo "Directories created successfully." - - -# ------------------------- 训练前清空 OUTPUT_DIR ------------------------- -echo "Clearing output directory: $OUTPUT_DIR" -# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) -find "$OUTPUT_DIR" -mindepth 1 -delete - -export PYTHONWARNINGS="ignore" -#忽略所有警告 - -### Generation command -# --max_train_steps: Optimizaiton steps -# --attack_type: target loss to update, choices=['var', 'mean', 'KL', 'add-log', 'latent_vector', 'add'], -# Please refer to the file content for more usage - -CUDA_VISIBLE_DEVICES=0 python ../algorithms/pid.py \ - --pretrained_model_name_or_path=$MODEL_PATH \ - --instance_data_dir=$INSTANCE_DIR \ - --output_dir=$OUTPUT_DIR \ - --resolution=512 \ - --max_train_steps=120 \ - --eps 16 \ - --step_size 0.01 \ - --attack_type add-log \ - --center_crop \ No newline at end of file diff --git a/src/backend/app/scripts/eva_heat_dif.sh b/src/backend/app/scripts/eva_heat_dif.sh index 4448234..76c5ce8 100644 --- a/src/backend/app/scripts/eva_heat_dif.sh +++ b/src/backend/app/scripts/eva_heat_dif.sh @@ -1,3 +1,4 @@ +#需要环境:conda activate pid # ----------------- 1. 环境与模型配置 ----------------- # 强制 Hugging Face 库使用本地模型缓存 (离线模式) diff --git a/src/backend/app/scripts/eva_nums.sh b/src/backend/app/scripts/eva_nums.sh index 91a4268..0d8273b 100644 --- a/src/backend/app/scripts/eva_nums.sh +++ b/src/backend/app/scripts/eva_nums.sh @@ -1,3 +1,4 @@ +#需要环境:conda activate pid # ----------------- 1. 环境与路径配置 ----------------- export TASKNAME="task001"