Normalize comments across algorithm modules #37

Merged
hnu202326010204 merged 2 commits from hufan_branch into develop 2 weeks ago

@ -1,43 +1,38 @@
"""Stable Diffusion 双模态注意力热力图差异可视化工具。
"""
Stable Diffusion 双模态注意力热力图差异可视化工具
"""
# 通用参数解析与文件路径管理
import argparse
import os
from pathlib import Path
from typing import Dict, Any, List, Tuple
# 数值计算与深度学习依赖
import torch
import torch.nn.functional as F
import numpy as np
import warnings
# 可视化依赖
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Diffusers 与 Transformers 依赖
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import Attention
from transformers import CLIPTokenizer
# 图像处理与数据读取
from PIL import Image
from torchvision import transforms
# 抑制非必要的警告输出
# 关闭与本脚本输出无关的常见警告,减少控制台干扰
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# ============== 核心模块:双模态注意力捕获 ==============
# 双模态注意力捕获模块:在 U-Net 前向过程中同时收集交叉注意力与自注意力
class AttentionMapProcessor:
"""自定义注意力处理器,用于同时捕获 U-Net 的交叉注意力和自注意力权重。"""
# 自定义注意力处理器,用于拦截注意力计算并缓存注意力概率图
def __init__(self, pipeline: StableDiffusionPipeline):
self.cross_attention_maps: Dict[str, List[torch.Tensor]] = {}
self.self_attention_maps: Dict[str, List[torch.Tensor]] = {}
@ -53,21 +48,23 @@ class AttentionMapProcessor:
encoder_hidden_states: torch.Tensor = None,
attention_mask: torch.Tensor = None
) -> torch.Tensor:
"""重载执行注意力计算并捕获权重 (支持 Self 和 Cross)。"""
# 同时支持 Cross-Attention 与 Self-Attention,区别在于 Key/Value 的来源
is_cross = encoder_hidden_states is not None
sequence_input = encoder_hidden_states if is_cross else hidden_states
# 按 diffusers 的注意力实现方式构造 Q/K/V
query = attn.to_q(hidden_states)
key = attn.to_k(sequence_input)
value = attn.to_v(sequence_input)
# 将多头维度展开到 batch 维,便于矩阵乘
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
# 计算缩放点积注意力分数
attention_scores = torch.baddbmm(
torch.empty(
query.shape[0], query.shape[1], key.shape[1],
dtype=query.dtype, device=query.device
),
query,
@ -76,44 +73,52 @@ class AttentionMapProcessor:
alpha=attn.scale,
)
# softmax 得到注意力概率,并缓存到 CPU 侧用于后续聚合
attention_probs = attention_scores.softmax(dim=-1)
layer_name = self.current_layer_name
map_to_store = attention_probs.detach().cpu()
# 按层名分别记录交叉注意力与自注意力,便于之后按层聚合
if is_cross:
if layer_name not in self.cross_attention_maps:
self.cross_attention_maps[layer_name] = []
self.cross_attention_maps[layer_name].append(map_to_store)
else:
# 内存保护:仅捕获中低分辨率层的自注意力 (防止 4096*4096 矩阵爆内存)
# 自注意力矩阵在高分辨率层会非常大,这里仅保留较小规模层以避免内存问题
spatial_size = map_to_store.shape[-2]
if spatial_size <= 1024:
if layer_name not in self.self_attention_maps:
self.self_attention_maps[layer_name] = []
self.self_attention_maps[layer_name].append(map_to_store)
# 按注意力权重加权求和并回到原始维度,继续 U-Net 的后续计算
value = attn.head_to_batch_dim(value)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# 线性层与 dropout 等输出映射,与原 Attention 模块保持一致
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def _set_processors(self):
# 遍历 U-Net 中所有 Attention 模块,将其 processor 替换为可记录层名的包装调用
for name, module in self.pipeline.unet.named_modules():
if isinstance(module, Attention):
if 'attn1' in name or 'attn2' in name:
self.original_processors[name] = module.processor
def set_layer_name(current_name):
def new_call(*args, **kwargs):
self.current_layer_name = current_name
return self.__call__(*args, **kwargs)
return new_call
module.processor = set_layer_name(name)
def remove(self):
# 还原所有 Attention 模块的原始 processor并清空缓存数据
for name, original_processor in self.original_processors.items():
module = self.pipeline.unet.get_submodule(name)
module.processor = original_processor
@ -121,7 +126,7 @@ class AttentionMapProcessor:
self.self_attention_maps = {}
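# 示意草图(非本文件原有实现,仅帮助理解上面的注意力分数计算;假设 beta=0,
# 与 diffusers 常见实现一致,折叠的 diff 中未显示该参数):
def _sanity_check_scaled_dot_product():
    import torch
    q = torch.randn(2, 8, 16)   # [batch*heads, q_len, head_dim]
    k = torch.randn(2, 8, 16)   # [batch*heads, k_len, head_dim]
    scale = 16 ** -0.5
    # baddbmm 形式:空张量 + alpha * (q @ k^T),beta=0 时空张量内容被忽略
    scores_baddbmm = torch.baddbmm(
        torch.empty(2, 8, 8), q, k.transpose(-1, -2), beta=0, alpha=scale
    )
    scores_direct = torch.bmm(q, k.transpose(-1, -2)) * scale
    assert torch.allclose(scores_baddbmm, scores_direct, atol=1e-5)
    return scores_direct.softmax(dim=-1)  # 每个 query 行的概率和为 1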
# ============== 聚合逻辑 ==============
# 注意力图聚合模块:将多层、多步的注意力数据统一聚合到固定大小的 2D 热力图
def aggregate_cross_attention(
attention_maps: Dict[str, List[torch.Tensor]],
@ -129,7 +134,7 @@ def aggregate_cross_attention(
target_word: str,
input_ids: torch.Tensor
) -> np.ndarray:
"""聚合交叉注意力:关注 Prompt 中的特定 Target Word。"""
# 将 Prompt 的 token 切分结果与目标词进行匹配,找到目标词对应的 token 索引
prompt_tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze().cpu().tolist())
target_lower = target_word.lower()
target_indices = []
@ -140,20 +145,25 @@ def aggregate_cross_attention(
(target_lower in cleaned_token or cleaned_token.startswith(target_lower))):
target_indices.append(i)
# 未命中目标词时返回全零图,避免后续流程崩溃
if not target_indices:
print(f"[WARN] Cross-Attn: 目标词汇 '{target_word}' 未识别。")
return np.zeros((64, 64))
all_attention_data = []
TARGET_SPATIAL_SIZE = 4096
TARGET_MAP_SIZE = 64
# 逐层将注意力概率进行时间步平均,再对目标 token 通道求和得到空间关注强度
for layer_name, step_maps in attention_maps.items():
if not step_maps: continue
if not step_maps:
continue
avg_map = torch.stack(step_maps).mean(dim=0)
if avg_map.dim() == 4: avg_map = avg_map.squeeze(0)
if avg_map.dim() == 4:
avg_map = avg_map.squeeze(0)
target_map = avg_map[:, :, target_indices].sum(dim=-1).mean(dim=0).float()
# 不同层的空间分辨率不同,统一插值到固定尺寸以便跨层融合
if target_map.shape[0] != TARGET_SPATIAL_SIZE:
map_size = int(np.sqrt(target_map.shape[0]))
map_2d = target_map.reshape(map_size, map_size).unsqueeze(0).unsqueeze(0)
@ -162,8 +172,10 @@ def aggregate_cross_attention(
else:
all_attention_data.append(target_map)
if not all_attention_data: return np.zeros((64, 64))
if not all_attention_data:
return np.zeros((64, 64))
# 跨层求和并归一化到 0-1便于可视化对比
final_map_flat = torch.stack(all_attention_data).sum(dim=0).cpu().numpy()
final_map_flat = final_map_flat / (final_map_flat.max() + 1e-6)
return final_map_flat.reshape(TARGET_MAP_SIZE, TARGET_MAP_SIZE)
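# 示意草图(非本文件原有实现):演示上面“把不同分辨率的展平注意力图统一插值到 64x64”的做法,
# 假设某层展平后的长度为 1024(即 32x32)。
def _resize_flat_map_demo():
    import torch
    import torch.nn.functional as F
    flat = torch.rand(1024)                                      # 展平的空间强度
    side = int(flat.numel() ** 0.5)                              # 32
    map_2d = flat.reshape(side, side).unsqueeze(0).unsqueeze(0)  # [1, 1, 32, 32]
    resized = F.interpolate(map_2d, size=(64, 64), mode='bilinear', align_corners=False)
    normalized = resized / (resized.max() + 1e-6)                # 归一化到 0-1
    return normalized.squeeze()                                  # [64, 64]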
@ -172,82 +184,56 @@ def aggregate_cross_attention(
def aggregate_self_attention(
attention_maps: Dict[str, List[torch.Tensor]]
) -> np.ndarray:
"""聚合自注意力:计算高频空间能量 (Laplacian High-Frequency Energy)。
原理:
风格和纹理通常体现为注意力图中的高频变化。
通过对每个 Query Attention Map 应用拉普拉斯算子(Laplacian Kernel),
我们可以提取出那些变化剧烈的区域(边缘、纹理、接缝)。
最后聚合这些高频能量,得到的图在空间结构上与原图对齐,但亮度代表了纹理/风格复杂度。
"""
# 将自注意力矩阵转为与空间对齐的强度图,这里使用拉普拉斯算子提取高频能量作为纹理强度代理
all_attention_data = []
TARGET_MAP_SIZE = 64
# 定义拉普拉斯卷积核用于提取高频信息
laplacian_kernel = torch.tensor([
[0, 1, 0],
[1, -4, 1],
[0, 1, 0]
], dtype=torch.float32).view(1, 1, 3, 3)
# 逐层对自注意力矩阵进行时间步与多头平均,再对每个 query 的注意力图做高频响应统计
for layer_name, step_maps in attention_maps.items():
if not step_maps: continue
# [Heads, H*W, H*W] -> [H*W, H*W] 取平均
if not step_maps:
continue
avg_matrix = torch.stack(step_maps).mean(dim=0).mean(dim=0).float()
# 获取当前层尺寸
current_pixels = avg_matrix.shape[0]
map_size = int(np.sqrt(current_pixels))
# 如果尺寸太小,高频信息没有意义,跳过极小层
# 极小尺度的注意力图通常缺少有效纹理结构信息,这里直接跳过
if map_size < 16:
continue
# 重塑为图像形式: [Batch(Pixels), Channels(1), H, W]
# 这里我们将 avg_matrix 视为:对于每一个 query pixel (行),它关注的 spatial map (列)
# 我们想知道每个 pixel 关注的区域是不是包含很多高频纹理
attn_maps = avg_matrix.reshape(current_pixels, 1, map_size, map_size) # [N, 1, H, W]
# 将 Kernel 移到同一设备
attn_maps = avg_matrix.reshape(current_pixels, 1, map_size, map_size)
kernel = laplacian_kernel.to(avg_matrix.device)
# 批量卷积计算高频响应 (High-Pass Filter)
# padding=1 保持尺寸不变
# 对每个 query 的空间注意力图做拉普拉斯卷积,得到高频响应
high_freq_response = F.conv2d(attn_maps, kernel, padding=1)
# 计算能量 (取绝对值或平方),这里取绝对值代表梯度的强度
# 用绝对值表示高频强度,并对每个 query 累计其响应作为空间分数
high_freq_energy = torch.abs(high_freq_response)
# 现在我们得到了 [N, 1, H, W] 的高频能量图。
# 我们需要将其聚合回一张 [H, W] 的图。
# 含义:对于图像上的位置 (i, j),其作为 Query 时,所关注的区域包含了多少高频信息?
# 或者:作为 Key 时,它贡献了多少高频信息?
# 这里采用 "Query-based Aggregation"
# 计算每个 Query pixel 对高频信息的总响应
# shape: [N, 1, H, W] -> sum(dim=(2,3)) -> [N]
# 这表示:位置 N 的像素,其注意力主要集中在高频纹理区域的程度。
spatial_score_flat = high_freq_energy.sum(dim=(2, 3)).squeeze() # [H*W]
# 归一化这一层的分数,防止数值爆炸
spatial_score_flat = high_freq_energy.sum(dim=(2, 3)).squeeze()
# 层内归一化避免不同层的数值尺度影响跨层融合
spatial_score_flat = spatial_score_flat / (spatial_score_flat.max() + 1e-6)
# 重塑为 2D 空间图
map_2d = spatial_score_flat.reshape(map_size, map_size).unsqueeze(0).unsqueeze(0)
# 插值统一到目标尺寸
resized = F.interpolate(map_2d, size=(TARGET_MAP_SIZE, TARGET_MAP_SIZE), mode='bilinear', align_corners=False)
all_attention_data.append(resized.squeeze().flatten())
if not all_attention_data: return np.zeros((64, 64))
if not all_attention_data:
return np.zeros((64, 64))
# 聚合所有层
# 跨层求和并做 0-1 归一化,得到最终纹理强度热力图
final_map_flat = torch.stack(all_attention_data).sum(dim=0).cpu().numpy()
# 最终归一化,保持 0-1 范围,方便可视化
final_map_flat = (final_map_flat - final_map_flat.min()) / (final_map_flat.max() - final_map_flat.min() + 1e-6)
return final_map_flat.reshape(TARGET_MAP_SIZE, TARGET_MAP_SIZE)
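# 示意草图(非本文件原有实现):在一张带竖直边缘的玩具图上演示拉普拉斯高频能量的含义——
# 平坦区域响应接近 0,边缘附近响应明显。
def _laplacian_energy_demo():
    import torch
    import torch.nn.functional as F
    kernel = torch.tensor([[0., 1., 0.],
                           [1., -4., 1.],
                           [0., 1., 0.]]).view(1, 1, 3, 3)
    img = torch.zeros(1, 1, 8, 8)
    img[..., :, 4:] = 1.0                        # 左半为 0、右半为 1
    response = F.conv2d(img, kernel, padding=1)  # padding=1 保持尺寸不变
    return response.abs().sum()                  # 高频能量集中在边缘列附近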
@ -257,7 +243,7 @@ def get_dual_attention_maps(
prompt_text: str,
target_word: str
) -> Tuple[Image.Image, np.ndarray, np.ndarray]:
"""同时获取 Cross-Attention 和 Self-Attention 热力图。"""
# 对输入图像进行编码,并在少量时间步上运行 U-Net 来提取注意力分布
print(f"\n-> 正在处理图片: {Path(image_path).name}")
image = Image.open(image_path).convert("RGB").resize((512, 512))
image_tensor = transforms.Compose([
@ -267,35 +253,40 @@ def get_dual_attention_maps(
with torch.no_grad():
latent = (pipeline.vae.encode(image_tensor).latent_dist.sample() * pipeline.vae.config.scaling_factor)
text_input = pipeline.tokenizer(prompt_text, padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt")
text_input = pipeline.tokenizer(
prompt_text,
padding="max_length",
max_length=pipeline.tokenizer.model_max_length,
truncation=True,
return_tensors="pt"
)
with torch.no_grad():
prompt_embeds = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
scheduler = pipeline.scheduler
scheduler.set_timesteps(50, device=pipeline.device)
semantic_steps = scheduler.timesteps[:10]
processor = AttentionMapProcessor(pipeline)
try:
with torch.no_grad():
for t in semantic_steps:
pipeline.unet(latent, t, prompt_embeds, return_dict=False)
cross_map_raw = aggregate_cross_attention(
processor.cross_attention_maps, pipeline.tokenizer, target_word, text_input.input_ids
)
self_map_raw = aggregate_self_attention(processor.self_attention_maps)
except Exception as e:
print(f"[ERROR] 注意力聚合失败: {e}")
# import traceback
# traceback.print_exc()
cross_map_raw = np.zeros((64, 64))
self_map_raw = np.zeros((64, 64))
finally:
processor.remove()
# 将 64x64 热力图上采样到与原图一致的空间大小,便于直接叠加或对比展示
def upsample(map_np):
pil_img = Image.fromarray((map_np * 255).astype(np.uint8))
return np.array(pil_img.resize(image.size, resample=Image.Resampling.LANCZOS)) / 255.0
@ -304,6 +295,7 @@ def get_dual_attention_maps(
def main():
# 解析运行参数并生成对比报告图像
parser = argparse.ArgumentParser(description="SD 双模态注意力差异分析报告")
parser.add_argument("--model_path", type=str, required=True, help="Stable Diffusion 模型路径")
parser.add_argument("--image_path_a", type=str, required=True, help="Clean Image")
@ -314,90 +306,100 @@ def main():
args = parser.parse_args()
print(f"--- 正在生成 Museguard 双模态分析报告 (High-Freq Energy Mode) ---")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32
try:
pipe = StableDiffusionPipeline.from_pretrained(
args.model_path, torch_dtype=dtype, local_files_only=True, safety_checker=None,
args.model_path,
torch_dtype=dtype,
local_files_only=True,
safety_checker=None,
scheduler=DPMSolverMultistepScheduler.from_pretrained(args.model_path, subfolder="scheduler")
).to(device)
except Exception as e:
print(f"[ERROR] 模型加载失败: {e}"); return
print(f"[ERROR] 模型加载失败: {e}")
return
img_A, cross_A, self_A = get_dual_attention_maps(pipe, args.image_path_a, args.prompt_text, args.target_word)
img_B, cross_B, self_B = get_dual_attention_maps(pipe, args.image_path_b, args.prompt_text, args.target_word)
diff_cross = cross_A - cross_B
l2_cross = np.linalg.norm(diff_cross)
diff_self = self_A - self_B
l2_self = np.linalg.norm(diff_self)
print(f"\nCross-Attn L2 Diff: {l2_cross:.4f}")
print(f"Self-Attn L2 Diff: {l2_self:.4f}")
# ---------------- 绘制增强版报告 ----------------
# 使用统一布局展示原图、两类注意力图及其差分图,并输出为单张报告图片
plt.rcParams.update({'font.family': 'serif', 'mathtext.fontset': 'cm'})
fig = plt.figure(figsize=(14, 22), dpi=100)
gs = gridspec.GridSpec(4, 4, figure=fig, height_ratios=[1, 1, 1, 1.2], hspace=0.3, wspace=0.1)
# Row 1: Images
ax_img_a = fig.add_subplot(gs[0, 0:2])
ax_img_b = fig.add_subplot(gs[0, 2:4])
ax_img_a.imshow(img_A); ax_img_a.set_title("Clean Image ($X$)", fontsize=14); ax_img_a.axis('off')
ax_img_b.imshow(img_B); ax_img_b.set_title("Noisy Image ($X'$)", fontsize=14); ax_img_b.axis('off')
ax_img_a.imshow(img_A)
ax_img_a.set_title("Clean Image ($X$)", fontsize=14)
ax_img_a.axis('off')
ax_img_b.imshow(img_B)
ax_img_b.set_title("Noisy Image ($X'$)", fontsize=14)
ax_img_b.axis('off')
# Row 2: Cross Attention
ax_cA = fig.add_subplot(gs[1, 0:2])
ax_cB = fig.add_subplot(gs[1, 2:4])
ax_cA.imshow(cross_A, cmap='jet', vmin=0, vmax=1)
ax_cA.set_title(f"Cross-Attn ($M^{{cross}}_X$)\nTarget: \"{args.target_word}\"", fontsize=14); ax_cA.axis('off')
ax_cA.set_title(f"Cross-Attn ($M^{{cross}}_X$)\nTarget: \"{args.target_word}\"", fontsize=14)
ax_cA.axis('off')
im_cB = ax_cB.imshow(cross_B, cmap='jet', vmin=0, vmax=1)
ax_cB.set_title(f"Cross-Attn ($M^{{cross}}_{{X'}}$)", fontsize=14); ax_cB.axis('off')
ax_cB.set_title(f"Cross-Attn ($M^{{cross}}_{{X'}}$)", fontsize=14)
ax_cB.axis('off')
divider = make_axes_locatable(ax_cB)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im_cB, cax=cax, label='Semantic Alignment')
# Row 3: Self Attention (High-Frequency Energy Mode)
ax_sA = fig.add_subplot(gs[2, 0:2])
ax_sB = fig.add_subplot(gs[2, 2:4])
# 恢复使用与 Cross Attention 一致的 'jet' colormap
ax_sA.imshow(self_A, cmap='jet', vmin=0, vmax=1)
ax_sA.set_title(f"Self-Attn ($M^{{self}}_X$)\nHigh-Freq Energy (Texture)", fontsize=14); ax_sA.axis('off')
ax_sA.set_title(f"Self-Attn ($M^{{self}}_X$)\nHigh-Freq Energy (Texture)", fontsize=14)
ax_sA.axis('off')
im_sB = ax_sB.imshow(self_B, cmap='jet', vmin=0, vmax=1)
ax_sB.set_title(f"Self-Attn ($M^{{self}}_{{X'}}$)", fontsize=14); ax_sB.axis('off')
ax_sB.set_title(f"Self-Attn ($M^{{self}}_{{X'}}$)", fontsize=14)
ax_sB.axis('off')
divider = make_axes_locatable(ax_sB)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im_sB, cax=cax, label='Texture Intensity')
# Row 4: Differences
ax_diff_c = fig.add_subplot(gs[3, 0:2])
ax_diff_s = fig.add_subplot(gs[3, 2:4])
vmax_c = max(np.max(np.abs(diff_cross)), 0.1)
norm_c = TwoSlopeNorm(vmin=-vmax_c, vcenter=0., vmax=vmax_c)
im_dc = ax_diff_c.imshow(diff_cross, cmap='coolwarm', norm=norm_c)
ax_diff_c.set_title(f"Cross Diff ($\Delta_{{cross}}$)\n$L_2$: {l2_cross:.4f}", fontsize=14); ax_diff_c.axis('off')
ax_diff_c.set_title(f"Cross Diff ($\\Delta_{{cross}}$)\n$L_2$: {l2_cross:.4f}", fontsize=14)
ax_diff_c.axis('off')
plt.colorbar(im_dc, ax=ax_diff_c, fraction=0.046, pad=0.04)
vmax_s = max(np.max(np.abs(diff_self)), 0.1)
norm_s = TwoSlopeNorm(vmin=-vmax_s, vcenter=0., vmax=vmax_s)
im_ds = ax_diff_s.imshow(diff_self, cmap='coolwarm', norm=norm_s)
ax_diff_s.set_title(f"Self Diff ($\Delta_{{self}}$)\n$L_2$: {l2_self:.4f}", fontsize=14); ax_diff_s.axis('off')
ax_diff_s.set_title(f"Self Diff ($\\Delta_{{self}}$)\n$L_2$: {l2_self:.4f}", fontsize=14)
ax_diff_s.axis('off')
plt.colorbar(im_ds, ax=ax_diff_s, fraction=0.046, pad=0.04)
fig.suptitle(f"Museguard: Dual-Mode Analysis (High-Freq Energy)", fontsize=20, fontweight='bold', y=0.92)
out_path = Path(args.output_dir) / "dual_heatmap_report.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path, bbox_inches='tight', facecolor='white')
print(f"\n报告已保存至: {out_path}")
if __name__ == "__main__":
main()
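# 示意草图(非上文脚本的一部分):说明差分图为什么使用 TwoSlopeNorm——把 0 固定为色标中心,
# 正负差异分别落在 coolwarm 两端,便于区分注意力增强与减弱。
def _diverging_diff_plot_demo(diff_map, out_path="diff_demo.png"):
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.colors import TwoSlopeNorm
    vmax = max(np.max(np.abs(diff_map)), 0.1)
    norm = TwoSlopeNorm(vmin=-vmax, vcenter=0.0, vmax=vmax)
    fig, ax = plt.subplots()
    im = ax.imshow(diff_map, cmap='coolwarm', norm=norm)
    fig.colorbar(im, ax=ax)
    fig.savefig(out_path, bbox_inches='tight')
    plt.close(fig)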

@ -1,9 +1,6 @@
"""图像生成质量多维度评估工具 (专业重构版)。
本脚本用于对比评估两组图像(Clean vs Perturbed)的生成质量,
"""
用于对比评估两组图像(Clean vs Perturbed)的生成质量,
支持生成包含指标对比表和深度差异分析的 PNG 报告。
Style Guide: Google Python Style Guide
"""
import os
@ -27,15 +24,11 @@ from facenet_pytorch import MTCNN, InceptionResnetV1
from piq import ssim, psnr
import torch_fidelity as fid
# 抑制非必要的警告输出
# 关闭与评估过程无关的常见警告,避免影响关键信息阅读
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# -----------------------------------------------------------------------------
# 全局配置与样式
# -----------------------------------------------------------------------------
# Matplotlib LaTeX 风格配置
# 全局样式配置:统一图表字体、数学公式风格与负号显示效果
plt.rcParams.update({
'font.family': 'serif',
'font.serif': ['DejaVu Serif', 'Times New Roman', 'serif'],
@ -43,7 +36,7 @@ plt.rcParams.update({
'axes.unicode_minus': False
})
# 指标元数据配置:定义指标目标方向和分析阈值
# 指标配置:给出每个指标的优劣方向以及用于分级判断的阈值
METRIC_ANALYSIS_META = {
'FID': {'higher_is_better': False, 'th': [2.0, 10.0, 30.0]},
'SSIM': {'higher_is_better': True, 'th': [0.01, 0.05, 0.15]},
@ -52,13 +45,12 @@ METRIC_ANALYSIS_META = {
'CLIP_IQS': {'higher_is_better': True, 'th': [0.01, 0.03, 0.08]},
'BRISQUE': {'higher_is_better': False, 'th': [2.0, 5.0, 10.0]},
}
# 用于综合分析的降级权重
# 综合结论中用于累加的权重,用于把分级差异映射成总体降级强度
ANALYSIS_WEIGHTS = {'Severe': 3, 'Significant': 2, 'Slight': 1, 'Negligible': 0}
# -----------------------------------------------------------------------------
# 模型加载 (惰性加载或全局预加载)
# -----------------------------------------------------------------------------
# 模型加载模块:在脚本启动时尝试预加载 CLIP失败时自动降级为不计算该项指标
try:
CLIP_MODEL, CLIP_PREPROCESS = clip.load('ViT-B/32', 'cuda')
@ -67,8 +59,9 @@ except Exception as e:
print(f"[Warning] CLIP 模型加载失败: {e}")
CLIP_MODEL, CLIP_PREPROCESS = None, None
def _get_clip_text_features(text: str) -> torch.Tensor:
"""辅助函数:获取文本的 CLIP 特征。"""
# 将文本编码为 CLIP 特征并归一化,用于后续与图像特征计算相似度
if CLIP_MODEL is None:
return None
tokens = clip.tokenize(text).to('cuda')
@ -77,31 +70,19 @@ def _get_clip_text_features(text: str) -> torch.Tensor:
features /= features.norm(dim=-1, keepdim=True)
return features
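# 示意草图(非本文件原有实现):展示 CLIP_IQS 的基本思路——图像特征与 "good image"
# 文本特征归一化后做余弦相似度;假设 CLIP_MODEL / CLIP_PREPROCESS 已按上文成功加载。
def _clip_iqs_single_image(img) -> float:
    import torch
    if CLIP_MODEL is None:
        return float('nan')
    text_feat = _get_clip_text_features("good image")
    image_input = CLIP_PREPROCESS(img).unsqueeze(0).to('cuda')
    with torch.no_grad():
        image_feat = CLIP_MODEL.encode_image(image_input)
        image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)
    return (image_feat @ text_feat.T).item()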
# -----------------------------------------------------------------------------
# 核心计算逻辑
# -----------------------------------------------------------------------------
# 指标计算模块:对两个图像集合计算多项指标,用于后续报告展示与差异分析
def calculate_metrics(
ref_dir: str,
gen_dir: str,
image_size: int = 512
) -> Dict[str, float]:
"""计算图像集之间的多项质量评估指标。
包括 FDS, SSIM, PSNR, CLIP_IQS, FID
Args:
ref_dir: 参考图片目录路径
gen_dir: 生成图片目录路径
image_size: 图像处理尺寸
Returns:
包含各项指标名称和数值的字典;若目录无效返回空字典。
"""
# 从目录读取图像并在同一设备上计算 FDS、SSIM、PSNR、CLIP_IQS 与 FID
metrics = {}
# 1. 数据加载
def load_images(directory):
# 读取目录下常见格式图像并转换为 RGB忽略无法打开的文件
imgs = []
if os.path.exists(directory):
for f in os.listdir(directory):
@ -116,6 +97,7 @@ def calculate_metrics(
ref_imgs = load_images(ref_dir)
gen_imgs = load_images(gen_dir)
# 若任一集合为空则直接返回,避免后续指标计算出错
if not ref_imgs or not gen_imgs:
print(f"[Error] 图片加载失败或目录为空: \nRef: {ref_dir}\nGen: {gen_dir}")
return {}
@ -123,12 +105,13 @@ def calculate_metrics(
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with torch.no_grad():
# --- FDS (Face Detection Similarity) ---
# FDS使用人脸检测与人脸特征模型度量身份相似度
print(">>> 计算 FDS...")
mtcnn = MTCNN(image_size=image_size, margin=0, device=device)
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
def get_face_embeds(img_list):
# 对每张图做检测与对齐,成功则提取人脸特征并收集为张量
embeds = []
for img in img_list:
face = mtcnn(img)
@ -140,7 +123,6 @@ def calculate_metrics(
gen_embeds = get_face_embeds(gen_imgs)
if ref_embeds is not None and gen_embeds is not None:
# 计算生成集每张脸与参考集所有脸的余弦相似度均值
sims = []
for g_emb in gen_embeds:
sim = torch.cosine_similarity(g_emb, ref_embeds).mean()
@ -149,39 +131,35 @@ def calculate_metrics(
else:
metrics['FDS'] = 0.0
# 清理显存
# 释放中间模型并回收显存,避免后续指标计算显存不足
del mtcnn, resnet
if torch.cuda.is_available():
torch.cuda.empty_cache()
# --- SSIM & PSNR ---
# SSIM 与 PSNR以参考集合为基准对每张生成图计算与参考集合的平均相似度
print(">>> 计算 SSIM & PSNR...")
tfm = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor()
])
# 将参考集堆叠为 [N, C, H, W]
ref_tensor = torch.stack([tfm(img) for img in ref_imgs]).to(device)
ssim_accum, psnr_accum = 0.0, 0.0
for img in gen_imgs:
gen_tensor = tfm(img).unsqueeze(0).to(device) # [1, C, H, W]
# 扩展维度以匹配参考集
gen_tensor = tfm(img).unsqueeze(0).to(device)
gen_expanded = gen_tensor.expand_as(ref_tensor)
# 计算单张生成图相对于整个参考集的平均结构相似度
val_ssim = ssim(gen_expanded, ref_tensor, data_range=1.0)
val_psnr = psnr(gen_expanded, ref_tensor, data_range=1.0)
ssim_accum += val_ssim.item()
psnr_accum += val_psnr.item()
metrics['SSIM'] = ssim_accum / len(gen_imgs)
metrics['PSNR'] = psnr_accum / len(gen_imgs)
# --- CLIP IQS ---
# CLIP_IQS用“good image”作为文本锚点计算生成图与该文本概念的相似度均值
print(">>> 计算 CLIP IQS...")
if CLIP_MODEL:
iqs_accum = 0.0
@ -195,7 +173,7 @@ def calculate_metrics(
else:
metrics['CLIP_IQS'] = np.nan
# --- FID ---
# FID使用 torch_fidelity 计算两个目录的分布距离
print(">>> 计算 FID...")
try:
fid_res = fid.calculate_metrics(
@ -214,23 +192,16 @@ def calculate_metrics(
def run_brisque_cleanly(img_dir: str) -> float:
"""使用 subprocess 和临时目录优雅地执行外部 BRISQUE 脚本。
Args:
img_dir: 图像目录路径
Returns:
BRISQUE 分数,若失败返回 NaN。
"""
# 通过子进程调用外部 BRISQUE 脚本,并用临时目录承载其输出文件
print(f">>> 计算 BRISQUE (External)...")
script_path = Path(__file__).parent / 'libsvm' / 'python' / 'brisquequality.py'
if not script_path.exists():
print(f"[Error] 找不到 BRISQUE 脚本: {script_path}")
return np.nan
abs_img_dir = os.path.abspath(img_dir)
with tempfile.TemporaryDirectory() as temp_dir:
try:
cmd = [
@ -238,17 +209,16 @@ def run_brisque_cleanly(img_dir: str) -> float:
abs_img_dir,
temp_dir
]
# 在脚本所在目录执行
subprocess.run(
cmd,
cwd=script_path.parent,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
# 读取临时生成的日志文件
# 从临时目录读取脚本写出的 log.txt并解析其中的最终分数
log_file = Path(temp_dir) / 'log.txt'
if log_file.exists():
content = log_file.read_text(encoding='utf-8').strip()
@ -258,47 +228,33 @@ def run_brisque_cleanly(img_dir: str) -> float:
return float(content)
else:
return np.nan
except Exception as e:
print(f"[Error] BRISQUE 执行出错: {e}")
return np.nan
# -----------------------------------------------------------------------------
# 报告可视化与分析逻辑
# -----------------------------------------------------------------------------
# 报告生成模块:对指标差异进行分级解释,并渲染成包含样例图与表格的 PNG 报告
def analyze_metric_diff(
metric_name: str,
clean_val: float,
pert_val: float
) -> Tuple[str, str, str]:
"""生成科学的分级差异分析文本。
Args:
metric_name: 指标名称
clean_val: 干净图得分
pert_val: 扰动图得分
Returns:
(表头箭头符号, 差异描述文本, 状态等级)
"""
# 根据指标配置计算差异,并输出用于表格与文本解释的分析结果
cfg = METRIC_ANALYSIS_META.get(metric_name)
if not cfg:
return "-", "Configuration not found.", "Negligible"
diff = pert_val - clean_val
abs_diff = abs(diff)
# 判定好坏:
is_better = (cfg['higher_is_better'] and diff > 0) or (not cfg['higher_is_better'] and diff < 0)
is_worse = not is_better
# 确定程度
th = cfg['th']
if abs_diff < th[0]:
degree = "Negligible"
elif abs_diff < th[1]:
degree = "Slight"
elif abs_diff < th[2]:
@ -306,9 +262,8 @@ def analyze_metric_diff(
else:
degree = "Severe"
# 组装文案
header_arrow = r"$\uparrow$" if cfg['higher_is_better'] else r"$\downarrow$"
if degree == "Negligible":
analysis_text = f"Negligible change (diff < {th[0]:.4f})."
elif is_worse:
@ -320,31 +275,29 @@ def analyze_metric_diff(
def generate_visual_report(
ref_dir: str,
clean_dir: str,
pert_dir: str,
clean_metrics: Dict,
pert_metrics: Dict,
output_path: str
):
"""渲染并保存专业对比分析报告 (PNG)。"""
# 从三个目录各取一张样例图,并将指标对比表与差异解释一起绘制到同一张图中
def get_sample(d):
if not os.path.exists(d): return None, "N/A"
files = [f for f in os.listdir(d) if f.lower().endswith(('.png','.jpg'))]
if not files: return None, "Empty"
if not os.path.exists(d):
return None, "N/A"
files = [f for f in os.listdir(d) if f.lower().endswith(('.png', '.jpg'))]
if not files:
return None, "Empty"
return Image.open(os.path.join(d, files[0])).convert("RGB"), files[0]
img_ref, name_ref = get_sample(ref_dir)
img_clean, name_clean = get_sample(clean_dir)
img_pert, name_pert = get_sample(pert_dir)
# 布局设置
# 增加高度以容纳文本
fig = plt.figure(figsize=(12, 16.5), dpi=120)
gs = gridspec.GridSpec(3, 2, height_ratios=[1, 1, 1.5], hspace=0.25, wspace=0.1)
# 1. 图像展示区
ax_ref = fig.add_subplot(gs[0, :])
if img_ref:
ax_ref.imshow(img_ref)
@ -363,80 +316,73 @@ def generate_visual_report(
ax_p.set_title(f"Perturbed Output ($Y'$)\n{name_pert}", fontsize=12, fontweight='bold', pad=10)
ax_p.axis('off')
# 2. 数据表格与分析区
ax_data = fig.add_subplot(gs[2, :])
ax_data.axis('off')
metrics_list = ['FID', 'SSIM', 'PSNR', 'FDS', 'CLIP_IQS', 'BRISQUE']
table_data = []
analysis_lines = []
degradation_score = 0
# 遍历指标生成数据和分析
# 为每个指标生成表格行,并收集对应的差异解释文本
for m in metrics_list:
c_val = clean_metrics.get(m, np.nan)
p_val = pert_metrics.get(m, np.nan)
c_str = f"{c_val:.4f}" if not np.isnan(c_val) else "N/A"
p_str = f"{p_val:.4f}" if not np.isnan(p_val) else "N/A"
diff_str = "-"
header_arrow = ""
if not np.isnan(c_val) and not np.isnan(p_val):
# 获取深度分析
header_arrow, text_desc, degree = analyze_metric_diff(m, c_val, p_val)
# 计算差异值
diff = p_val - c_val
# 差异值本身的符号 (Diff > 0 或 Diff < 0)
diff_arrow = r"$\nearrow$" if diff > 0 else r"$\searrow$"
if abs(diff) < 1e-4: diff_arrow = r"$\rightarrow$"
if abs(diff) < 1e-4:
diff_arrow = r"$\rightarrow$"
diff_str = f"{diff:+.4f} {diff_arrow}"
analysis_lines.append(f"{m}: Change {diff:+.4f}. Analysis: {text_desc}")
# 累计降级分数
cfg = METRIC_ANALYSIS_META.get(m)
is_worse = (cfg['higher_is_better'] and diff < 0) or (not cfg['higher_is_better'] and diff > 0)
if is_worse:
degradation_score += ANALYSIS_WEIGHTS.get(degree, 0)
# 表格第一列:名称 + 期望方向箭头
name_with_arrow = f"{m} ({header_arrow})" if header_arrow else m
table_data.append([name_with_arrow, c_str, p_str, diff_str])
# 绘制表格
table = ax_data.table(
cellText=table_data,
colLabels=["Metric (Goal)", "Clean ($Y$)", "Perturbed ($Y'$)", "Diff ($\Delta$)"],
colLabels=["Metric (Goal)", "Clean ($Y$)", "Perturbed ($Y'$)", "Diff ($\\Delta$)"],
loc='upper center',
cellLoc='center',
colWidths=[0.25, 0.25, 0.25, 0.25]
)
table.scale(1, 2.0)
table.set_fontsize(11)
# 美化表头
# 对表头与第一列做基础样式区分,提升可读性
for (row, col), cell in table.get_celld().items():
if row == 0:
cell.set_text_props(weight='bold', color='white')
cell.set_facecolor('#404040')
elif col == 0:
cell.set_text_props(weight='bold')
cell.set_facecolor('#f5f5f5')
# 3. 底部综合分析文本框
# 汇总差异分析文本并给出基于权重的总体结论
if not analysis_lines:
analysis_lines.append("• All metrics are missing or invalid.")
full_text = "Quantitative Difference Analysis:\n" + "\n".join(analysis_lines)
# 总体结论判断 (基于 holistic degradation score)
conclusion = "\n\n>>> EXECUTIVE SUMMARY (Holistic Judgment):\n"
if degradation_score >= 8:
conclusion += "CRITICAL DEGRADATION. Significant quality loss observed. Attack highly effective."
elif degradation_score >= 4:
@ -448,9 +394,7 @@ def generate_visual_report(
full_text += conclusion
# ---------------------------------------------------------------------
# 4. Metric definitions (ASCII-only / English-only to avoid font issues)
# ---------------------------------------------------------------------
# 在报告底部补充指标含义说明,便于非专业读者理解各项指标的侧重点
metric_definitions = [
"",
"",
@ -485,52 +429,45 @@ def generate_visual_report(
ax_data.text(
0.05,
0.30,
full_text,
ha='left',
va='top',
fontsize=12, family='monospace', wrap=True,
transform=ax_data.transAxes
)
fig.suptitle("Museguard: Quality Assurance Report", fontsize=18, fontweight='bold', y=0.95)
plt.savefig(output_path, bbox_inches='tight', facecolor='white')
print(f"\n[Success] 报告已生成: {output_path}")
# -----------------------------------------------------------------------------
# 主入口
# -----------------------------------------------------------------------------
def main():
# 解析参数,分别评估 Clean 与 Perturbed 两组输出,并生成汇总报告
parser = ArgumentParser()
parser.add_argument('--clean_output_dir', type=str, required=True)
parser.add_argument('--perturbed_output_dir', type=str, required=True)
parser.add_argument('--clean_ref_dir', type=str, required=True)
parser.add_argument('--png_output_path', type=str, required=True)
parser.add_argument('--size', type=int, default=512)
args = parser.parse_args()
Path(args.png_output_path).parent.mkdir(parents=True, exist_ok=True)
print("========================================")
print(" Image Quality Evaluation Toolkit")
print("========================================")
# 1. 计算 Clean 组
print(f"\n[1/2] Evaluating Clean Set: {os.path.basename(args.clean_output_dir)}")
c_metrics = calculate_metrics(args.clean_ref_dir, args.clean_output_dir, args.size)
if c_metrics:
c_metrics['BRISQUE'] = run_brisque_cleanly(args.clean_output_dir)
# 2. 计算 Perturbed 组
print(f"\n[2/2] Evaluating Perturbed Set: {os.path.basename(args.perturbed_output_dir)}")
p_metrics = calculate_metrics(args.clean_ref_dir, args.perturbed_output_dir, args.size)
if p_metrics:
p_metrics['BRISQUE'] = run_brisque_cleanly(args.perturbed_output_dir)
# 3. 生成报告
if c_metrics and p_metrics:
generate_visual_report(
args.clean_ref_dir,
@ -543,5 +480,6 @@ def main():
else:
print("\n[Fatal] 评估数据不完整,中止报告生成。")
if __name__ == '__main__':
main()

@ -1,6 +1,3 @@
#!/usr/bin/env python
# coding=utf-8
import argparse
import contextlib
import copy
@ -44,20 +41,14 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
# 可选启用 wandb 记录,未安装时不影响训练主流程
if is_wandb_available():
import wandb
logger = get_logger(__name__)
# -------------------------------------------------------------------------
# 功能模块:模型卡保存
# 1) 该模块用于生成/更新 README.md记录训练来源与关键配置
# 2) 支持将训练后验证生成的示例图片写入输出目录并写入引用
# 3) 便于后续将模型上传到 Hub 时展示效果与实验信息
# 4) 不参与训练与梯度计算,不影响参数更新与收敛行为
# 5) 既可服务于 Hub 发布,也可用于本地实验的结果归档
# -------------------------------------------------------------------------
# 将训练信息与样例图写入模型卡,便于本地归档与推送到 HuggingFace Hub
def save_model_card(
repo_id: str,
images: list | None = None,
@ -68,6 +59,7 @@ def save_model_card(
pipeline: DiffusionPipeline | None = None,
):
img_str = ""
# 将推理样例落盘到输出目录,并在 README.md 中插入相对路径引用
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
@ -99,14 +91,7 @@ def save_model_card(
model_card.save(os.path.join(repo_folder, "README.md"))
# -------------------------------------------------------------------------
# 功能模块:训练后纯文本推理(validation)
# 1) 该模块仅在训练完全结束后执行,不参与训练过程与优化器状态
# 2) 该模块从 output_dir 重新加载微调后的 pipeline避免与训练对象耦合
# 3) 推理只接受文本提示词,不输入任何图像,不走 img2img 相关路径
# 4) 可设置推理步数与随机种子,方便提高细节并保证可复现
# 5) 输出 PIL 图片列表,可保存到目录并写入日志系统便于对比分析
# -------------------------------------------------------------------------
# 训练结束后的纯文本推理:从输出目录重新加载 pipeline保证推理与训练对象解耦
def run_validation_txt2img(
finetuned_model_dir: str,
prompt: str,
@ -123,6 +108,7 @@ def run_validation_txt2img(
f"开始 validation 文生图:数量={num_images},步数={num_inference_steps}guidance={guidance_scale},提示词={prompt}"
)
# 只加载 txt2img 所需组件,并禁用 safety_checker 以避免额外开销与拦截
pipe = StableDiffusionPipeline.from_pretrained(
finetuned_model_dir,
torch_dtype=weight_dtype,
@ -130,6 +116,7 @@ def run_validation_txt2img(
local_files_only=True,
)
# 保证是 StableDiffusionPipeline避免加载到不兼容的管线导致参数不一致
if not isinstance(pipe, StableDiffusionPipeline):
raise TypeError(f"加载的 pipeline 类型异常:{type(pipe)},需要 StableDiffusionPipeline 才能保证纯文本生图。")
@ -137,9 +124,11 @@ def run_validation_txt2img(
pipe.set_progress_bar_config(disable=True)
pipe.safety_checker = lambda images, clip_input: (images, [False for _ in range(len(images))])
# 使用 slicing 降低推理时显存占用,便于在训练机上额外运行验证
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
# 根据 accelerate 的混精配置选择 autocast上下文外不改变全局 dtype 行为
if accelerator.device.type == "cuda":
if accelerator.mixed_precision == "bf16":
infer_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
@ -150,6 +139,7 @@ def run_validation_txt2img(
else:
infer_ctx = contextlib.nullcontext()
# 为每张图单独设置种子偏移,保证同一次验证多图可复现且互不相同
images = []
with infer_ctx:
for i in range(num_images):
@ -166,6 +156,7 @@ def run_validation_txt2img(
)
images.append(out.images[0])
# 将验证图片写入 tracker便于对比不同 step 的训练效果
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
@ -179,6 +170,7 @@ def run_validation_txt2img(
}
)
# 显式释放管线与缓存,避免与训练过程竞争显存
del pipe
if torch.cuda.is_available():
torch.cuda.empty_cache()
@ -186,14 +178,7 @@ def run_validation_txt2img(
return images
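# 示意草图(非本文件原有实现):上文“为每张图单独设置种子偏移”的一种常见写法,
# 实际实现位于被折叠的 diff 中,此处仅演示 torch.Generator 的用法。
def _per_image_generators(device, base_seed: int, num_images: int):
    import torch
    return [torch.Generator(device=device).manual_seed(base_seed + i) for i in range(num_images)]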
# -------------------------------------------------------------------------
# 功能模块:从模型目录推断 TextEncoder 类型
# 1) 不同扩散模型对应不同文本编码器架构,需动态识别加载类
# 2) 通过读取 text_encoder/config 来获取 architectures 字段
# 3) 该模块返回类对象,用于后续 from_pretrained 加载权重
# 4) 便于同一训练脚本兼容多模型,而不写死具体实现
# 5) 若架构不支持会直接报错,避免训练过程走到一半才失败
# -------------------------------------------------------------------------
# 动态识别文本编码器架构,保证脚本可用于不同系列的扩散模型
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str | None):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
@ -204,68 +189,63 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
elif model_class == "T5EncoderModel":
from transformers import T5EncoderModel
return T5EncoderModel
else:
raise ValueError(f"{model_class} 不受支持。")
# -------------------------------------------------------------------------
# 功能模块:命令行参数解析
# 1) 本模块定义 DreamBooth 训练参数与训练后 validation 参数
# 2) 训练负责微调权重与记录坐标validation 只负责训练后文生图输出
# 3) 不提供训练中间验证参数,避免任何中途采样影响训练流程
# 4) 对关键参数组合做合法性检查,减少运行中途异常
# 5) 支持通过 shell 脚本传参实现批量实验、对比与复现
# -------------------------------------------------------------------------
# 参数解析:包含训练参数、先验保持参数、训练后验证参数,以及坐标记录参数
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="DreamBooth 训练脚本(训练后纯文字生图 validation")
# 预训练模型与 tokenizer 相关配置
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True)
parser.add_argument("--revision", type=str, default=None, required=False)
parser.add_argument("--variant", type=str, default=None)
parser.add_argument("--tokenizer_name", type=str, default=None)
# 数据路径与提示词配置
parser.add_argument("--instance_data_dir", type=str, default=None, required=True)
parser.add_argument("--class_data_dir", type=str, default=None, required=False)
parser.add_argument("--instance_prompt", type=str, default=None, required=True)
parser.add_argument("--class_prompt", type=str, default=None)
# 先验保持相关开关与权重
parser.add_argument("--with_prior_preservation", default=False, action="store_true")
parser.add_argument("--prior_loss_weight", type=float, default=1.0)
parser.add_argument("--num_class_images", type=int, default=100)
# 输出与可复现配置
parser.add_argument("--output_dir", type=str, default="dreambooth-model")
parser.add_argument("--seed", type=int, default=None)
# 图像预处理配置
parser.add_argument("--resolution", type=int, default=512)
parser.add_argument("--center_crop", default=False, action="store_true")
# 是否同时训练 text encoder
parser.add_argument("--train_text_encoder", action="store_true")
# 训练批次与 epoch/step 配置
parser.add_argument("--train_batch_size", type=int, default=4)
parser.add_argument("--sample_batch_size", type=int, default=4)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument("--max_train_steps", type=int, default=None)
parser.add_argument("--checkpointing_steps", type=int, default=500)
# 梯度累积与显存优化相关开关
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--gradient_checkpointing", action="store_true")
# 学习率与 scheduler 配置
parser.add_argument("--learning_rate", type=float, default=5e-6)
parser.add_argument("--scale_lr", action="store_true", default=False)
parser.add_argument(
"--lr_scheduler",
type=str,
@ -276,37 +256,42 @@ def parse_args(input_args=None):
parser.add_argument("--lr_num_cycles", type=int, default=1)
parser.add_argument("--lr_power", type=float, default=1.0)
# 优化器相关配置
parser.add_argument("--use_8bit_adam", action="store_true")
parser.add_argument("--dataloader_num_workers", type=int, default=0)
parser.add_argument("--adam_beta1", type=float, default=0.9)
parser.add_argument("--adam_beta2", type=float, default=0.999)
parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
parser.add_argument("--max_grad_norm", default=1.0, type=float)
# Hub 上传相关配置
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)
# 日志与混精配置
parser.add_argument("--logging_dir", type=str, default="logs")
parser.add_argument("--allow_tf32", action="store_true")
parser.add_argument("--report_to", type=str, default="tensorboard")
parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
parser.add_argument("--prior_generation_precision", type=str, default=None, choices=["no", "fp32", "fp16", "bf16"])
parser.add_argument("--local_rank", type=int, default=-1)
# 注意力与梯度相关的显存优化开关
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")
parser.add_argument("--set_grads_to_none", action="store_true")
# 噪声与损失加权相关参数
parser.add_argument("--offset_noise", action="store_true", default=False)
parser.add_argument("--snr_gamma", type=float, default=None)
# tokenizer 与 text encoder 行为相关参数
parser.add_argument("--tokenizer_max_length", type=int, default=None, required=False)
parser.add_argument("--text_encoder_use_attention_mask", action="store_true", required=False)
parser.add_argument("--skip_save_text_encoder", action="store_true", required=False)
# 训练后验证参数(本脚本不做中途验证,仅训练结束后跑一次)
parser.add_argument("--validation_prompt", type=str, required=True)
parser.add_argument("--validation_negative_prompt", type=str, default="")
parser.add_argument("--num_validation_images", type=int, default=10)
@ -314,6 +299,7 @@ def parse_args(input_args=None):
parser.add_argument("--validation_guidance_scale", type=float, default=7.5)
parser.add_argument("--validation_image_output_dir", type=str, required=True)
# 训练过程坐标记录(用于可视化与轨迹分析)
parser.add_argument("--coords_save_path", type=str, default=None)
parser.add_argument("--coords_log_interval", type=int, default=10)
@ -322,10 +308,12 @@ def parse_args(input_args=None):
else:
args = parser.parse_args()
# 兼容 accelerate 启动时写入的 LOCAL_RANK 环境变量
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# 先验保持开启时必须提供 class 数据与 class prompt
if args.with_prior_preservation:
if args.class_data_dir is None:
raise ValueError("启用先验保持时必须提供 class_data_dir。")
@ -340,14 +328,7 @@ def parse_args(input_args=None):
return args
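# 启动示意(脚本文件名与各路径均为假设,仅演示上述参数的组合方式):
#   accelerate launch train_dreambooth.py \
#     --pretrained_model_name_or_path ./sd-model \
#     --instance_data_dir ./instance_images \
#     --instance_prompt "a photo of sks person" \
#     --output_dir ./dreambooth-out \
#     --validation_prompt "a photo of sks person" \
#     --validation_image_output_dir ./dreambooth-out/validation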
# -------------------------------------------------------------------------
# 功能模块DreamBooth 训练数据集
# 1) 从 instance 与 class 目录读取图像,并统一做尺寸、裁剪与归一化
# 2) 同时提供实例提示词与类别提示词的 token id 作为文本输入
# 3) 先验保持模式下会返回两套图像与文本信息用于拼接训练
# 4) 数据集长度按 instance 与 class 的最大值取,便于循环采样
# 5) 数据集只负责准备输入,模型推理、损失计算与优化在主循环中完成
# -------------------------------------------------------------------------
# DreamBooth 数据集:负责读取图片、做裁剪归一化,并产出 prompt 的 input_ids 与 attention_mask
class DreamBoothDataset(Dataset):
def __init__(
self,
@ -370,11 +351,13 @@ class DreamBoothDataset(Dataset):
if not self.instance_data_root.exists():
raise ValueError(f"实例图像目录不存在:{self.instance_data_root}")
# instance 图片路径列表会循环采样,长度以 instance 数为基础
self.instance_images_path = [p for p in Path(instance_data_root).iterdir() if p.is_file()]
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
# 在先验保持模式下同时读取 class 图片,并将长度设为两者最大值以便循环匹配
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
@ -388,6 +371,7 @@ class DreamBoothDataset(Dataset):
else:
self.class_data_root = None
# 训练用图像预处理:先 resize再 crop然后归一化到 [-1, 1]
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
@ -400,6 +384,7 @@ class DreamBoothDataset(Dataset):
def __len__(self):
return self._length
# 将 prompt 统一分词为固定长度,避免动态长度导致批处理不稳定
def _tokenize(self, prompt: str):
max_length = self.tokenizer_max_length if self.tokenizer_max_length is not None else self.tokenizer.model_max_length
return self.tokenizer(
@ -413,16 +398,19 @@ class DreamBoothDataset(Dataset):
def __getitem__(self, index):
example = {}
# 读取 instance 图片并处理 EXIF 方向,保证训练输入方向一致
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
instance_image = exif_transpose(instance_image)
if instance_image.mode != "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
# instance prompt 的 input_ids 与 attention_mask
text_inputs = self._tokenize(self.instance_prompt)
example["instance_prompt_ids"] = text_inputs.input_ids
example["instance_attention_mask"] = text_inputs.attention_mask
# 先验保持时额外返回 class 图片与 class prompt 的 token
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
class_image = exif_transpose(class_image)
@ -437,14 +425,7 @@ class DreamBoothDataset(Dataset):
return example
# -------------------------------------------------------------------------
# 功能模块:批处理拼接与张量规整
# 1) 将单条样本组成的列表拼接为 batch 字典,供训练循环直接使用
# 2) 将图像张量 stack 成 (B,C,H,W) 并转换为 float提高后续 VAE 兼容性
# 3) 将 input_ids 与 attention_mask 在 batch 维度 cat便于文本编码器计算
# 4) 先验保持模式下将 instance 与 class 在 batch 维度拼接,减少前向次数
# 5) 该模块不做任何损失与梯度计算,只负责打包输入数据结构
# -------------------------------------------------------------------------
# 批处理拼接:将样本列表组装为 batch并在先验保持时拼接 instance 与 class 以减少前向次数
def collate_fn(examples, with_prior_preservation=False):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
@ -455,6 +436,7 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values += [example["class_images"] for example in examples]
attention_mask += [example["class_attention_mask"] for example in examples]
# 图像张量 stack 为 (B, C, H, W),并确保是连续内存与 float 类型
pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
input_ids = torch.cat(input_ids, dim=0)
attention_mask = torch.cat(attention_mask, dim=0)
@ -462,14 +444,7 @@ def collate_fn(examples, with_prior_preservation=False):
return {"input_ids": input_ids, "pixel_values": pixel_values, "attention_mask": attention_mask}
# -------------------------------------------------------------------------
# 功能模块:生成 class 图像的提示词数据集
# 1) 该数据集用于先验保持时批量生成类别图像,提供固定 prompt
# 2) 每条样本返回 prompt 与索引,索引用于生成稳定的文件名
# 3) 与训练数据集分离,避免采样逻辑影响训练数据读取与增强
# 4) 支持多进程环境下由 accelerate 分配采样 batch提高生成效率
# 5) 该模块只在 with_prior_preservation 启用且 class 数据不足时使用
# -------------------------------------------------------------------------
# class 图像生成专用数据集:仅提供 prompt 与 index用于加速生成与落盘命名
class PromptDataset(Dataset):
def __init__(self, prompt, num_samples):
self.prompt = prompt
@ -482,14 +457,7 @@ class PromptDataset(Dataset):
return {"prompt": self.prompt, "index": index}
# -------------------------------------------------------------------------
# 功能模块:判断预训练模型是否包含 VAE
# 1) 通过检查 vae/config.json 是否存在来决定是否加载 VAE
# 2) 同时支持本地目录与 Hub 结构,便于离线缓存模式运行
# 3) 若不存在 VAE 子目录,将跳过加载并在训练中使用像素空间输入
# 4) 该判断只发生在初始化阶段,不影响训练过程与日志记录
# 5) 对 Stable Diffusion 类模型通常都会包含 VAE属于常规路径
# -------------------------------------------------------------------------
# 判断模型是否包含 VAE用于兼容可能没有 vae 子目录的模型结构
def model_has_vae(args):
config_file_name = Path("vae", AutoencoderKL.config_name).as_posix()
if os.path.isdir(args.pretrained_model_name_or_path):
@ -500,14 +468,7 @@ def model_has_vae(args):
return any(file.rfilename == config_file_name for file in files_in_repo)
# -------------------------------------------------------------------------
# 功能模块:文本编码器前向
# 1) 将 input_ids 与 attention_mask 输入 text encoder 得到条件嵌入
# 2) 可选择是否启用 attention_mask以适配不同文本编码器行为
# 3) 输出的 prompt_embeds 作为 UNet 条件输入,影响生成语义与身份绑定
# 4) 该函数在训练循环中频繁调用,需要保持设备与 dtype 的一致性
# 5) 返回张量为 (B, T, D),后续会与 timestep 一起输入 UNet
# -------------------------------------------------------------------------
# 文本编码:得到 prompt_embeds作为 UNet 的条件输入,控制语义与身份绑定
def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask: bool):
text_input_ids = input_ids.to(text_encoder.device)
if text_encoder_use_attention_mask:
@ -517,18 +478,13 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
return text_encoder(text_input_ids, attention_mask=attention_mask, return_dict=False)[0]
# -------------------------------------------------------------------------
# 功能模块:主训练流程
# 1) 负责构建 accelerate 环境、加载模型组件、准备数据与优化器
# 2) 支持先验保持:自动补足 class 图像并将 instance/class 合并训练
# 3) 训练循环中记录 loss、学习率与坐标指标输出 CSV 便于可视化分析
# 4) 训练结束后保存微调后的 pipeline 到 output_dir作为独立可用模型
# 5) 在保存完成后运行 validation仅用提示词进行文生图并将结果写入输出目录
# -------------------------------------------------------------------------
# 训练主流程:包含 class 数据补全、组件加载、训练循环、坐标记录、模型保存与训练后验证
def main(args):
# 避免将 hub token 暴露到 wandb 等第三方日志系统中
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError("不要同时使用 wandb 与 hub_token避免凭证泄露风险。")
# accelerate 项目配置:统一 output_dir 与 logging_dir便于多卡与断点保存
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@ -539,9 +495,11 @@ def main(args):
project_config=accelerator_project_config,
)
# MPS 下关闭 AMP避免混精行为不一致导致训练异常
if torch.backends.mps.is_available():
accelerator.native_amp = False
# 初始化日志格式并打印 accelerate 状态,便于排查分布式配置问题
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
@ -549,21 +507,25 @@ def main(args):
)
logger.info(accelerator.state, main_process_only=False)
# 主进程输出更多 warning非主进程尽量保持安静以减少干扰
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
warnings.filterwarnings("ignore", category=UserWarning)
else:
transformers.utils.logging.set_verbosity_error()
# 设置随机种子,保证数据增强、噪声采样与验证结果可复现
if args.seed is not None:
set_seed(args.seed)
# 先验保持:当 class 图片不足时,使用 base model 生成补齐并保存到 class_data_dir
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
class_images_dir.mkdir(parents=True, exist_ok=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
# 生成 class 图片时可单独指定 dtype减少生成时的显存占用
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
@ -592,6 +554,7 @@ def main(args):
for example in tqdm(sample_dataloader, desc="生成 class 图像", disable=not accelerator.is_local_main_process):
images = pipe(example["prompt"]).images
for i, image in enumerate(images):
# 用图像内容 hash 防止同名冲突,并方便追溯生成来源
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
@ -600,6 +563,7 @@ def main(args):
if torch.cuda.is_available():
torch.cuda.empty_cache()
# 准备输出目录与 Hub 仓库,仅主进程执行以避免竞争写入
if accelerator.is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
@ -611,6 +575,7 @@ def main(args):
else:
repo_id = None
# tokenizer 加载:优先使用显式指定的 tokenizer_name否则从模型目录的 tokenizer 子目录读取
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
else:
@ -621,9 +586,10 @@ def main(args):
use_fast=False,
)
# 组件加载scheduler、text_encoder、可选 VAE、以及 UNet
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
@ -638,11 +604,13 @@ def main(args):
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# unwrap用于从 accelerator 包装对象中拿到可保存的原始模型
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# 自定义 hook让 accelerator.save_state 按 diffusers 的子目录结构保存 unet/text_encoder
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for model in models:
@ -650,6 +618,7 @@ def main(args):
model.save_pretrained(os.path.join(output_dir, sub_dir))
weights.pop()
# 自定义 hook断点恢复时从 output_dir 读取 unet/text_encoder 并覆盖当前实例参数
def load_model_hook(models, input_dir):
while len(models) > 0:
model = models.pop()
@ -665,12 +634,15 @@ def main(args):
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# VAE 不参与训练,仅用于编码到 latent 空间
if vae is not None:
vae.requires_grad_(False)
# 默认只训练 UNet若开启 train_text_encoder 则同时训练文本编码器
if not args.train_text_encoder:
text_encoder.requires_grad_(False)
# xformers开启后可显著降低注意力显存占用但需要正确安装依赖
if args.enable_xformers_memory_efficient_attention:
if not is_xformers_available():
raise ValueError("xformers 不可用,请确认安装成功。")
@ -680,17 +652,21 @@ def main(args):
logger.warning("xformers 0.0.16 在部分 GPU 上训练不稳定,建议升级。")
unet.enable_xformers_memory_efficient_attention()
# gradient checkpointing以计算换显存适合大模型与大分辨率训练
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
# TF32在 Ampere 上可加速 matmul通常对训练稳定性影响较小
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# scale_lr按总 batch 规模放大学习率,便于多卡/大 batch 配置保持等效训练
if args.scale_lr:
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
# 优化器:可选 8-bit Adam 降低显存占用
optimizer_class = torch.optim.AdamW
if args.use_8bit_adam:
try:
@ -710,6 +686,7 @@ def main(args):
eps=args.adam_epsilon,
)
# 数据集与 dataloader根据 with_prior_preservation 决定是否加载 class 数据
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
@ -730,12 +707,14 @@ def main(args):
num_workers=args.dataloader_num_workers,
)
# 训练步数:若未指定 max_train_steps则由 epoch 与 dataloader 推导
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
# scheduler基于总训练步数与 warmup 设置学习率曲线
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
@ -745,6 +724,7 @@ def main(args):
power=args.lr_power,
)
# accelerate.prepare把模型、优化器、数据加载器与 scheduler 放入分布式与混精管理
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@ -754,27 +734,33 @@ def main(args):
unet, optimizer, train_dataloader, lr_scheduler
)
# weight_dtype训练时模型权重与输入的 dtype用于混精与显存控制
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# VAE 始终与训练设备一致,并与 weight_dtype 对齐
if vae is not None:
vae.to(accelerator.device, dtype=weight_dtype)
# 若不训练 text encoder则把它当作推理组件统一 cast 到混精 dtype
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
# 重新计算 epoch在 prepare 之后 dataloader 规模可能变化,因此再推导一次更稳妥
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# 初始化 tracker主进程写入配置便于实验复现与对比
if accelerator.is_main_process:
tracker_config = vars(copy.deepcopy(args))
accelerator.init_trackers("dreambooth", config=tracker_config)
# coords_list用于记录训练过程的三维指标轨迹并写入 CSV
coords_list = []
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@ -795,6 +781,7 @@ def main(args):
disable=not accelerator.is_local_main_process,
)
# 训练循环:每步完成 latent 构造、噪声添加、UNet 预测、loss 计算与反传更新
for epoch in range(0, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
@ -802,14 +789,17 @@ def main(args):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# 输入图像对齐 dtype避免混精下出现不必要的类型转换
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
# 使用 VAE 将图像编码到 latent 空间,若无 VAE 则直接在像素空间训练
if vae is not None:
model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
else:
model_input = pixel_values
# 采样噪声,并可选叠加 offset noise 以改变噪声分布形态
if args.offset_noise:
noise = torch.randn_like(model_input) + 0.1 * torch.randn(
model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
@ -817,13 +807,16 @@ def main(args):
else:
noise = torch.randn_like(model_input)
# 为每个样本随机选择一个扩散时间步
bsz = model_input.shape[0]
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
).long()
# 前向扩散:给输入加噪声,形成 UNet 的输入
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# 文本编码:得到条件嵌入,用于指导 UNet 的去噪方向
encoder_hidden_states = encode_prompt(
text_encoder,
batch["input_ids"],
@ -831,11 +824,14 @@ def main(args):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)
# UNet 输出噪声预测(或速度预测),返回的第一个元素为预测张量
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states, return_dict=False)[0]
# 某些模型会同时预测方差,将通道拆分后仅保留噪声相关部分参与训练
if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1)
# 根据 scheduler 的 prediction_type 构造监督目标
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
@ -843,11 +839,13 @@ def main(args):
else:
raise ValueError(f"未知 prediction_type{noise_scheduler.config.prediction_type}")
# 先验保持batch 被拼接为 instance+class因此这里按 batch 维拆开分别计算 prior loss
if args.with_prior_preservation:
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# loss默认 MSE若提供 snr_gamma 则用 SNR 加权以平衡不同时间步的贡献
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
@ -863,6 +861,7 @@ def main(args):
if args.with_prior_preservation:
loss = loss + args.prior_loss_weight * prior_loss
# 训练轨迹记录:用模型输出统计量作为特征指标,配合 loss 形成三维轨迹
if args.coords_save_path is not None:
X_i_feature_norm = torch.norm(model_pred.detach().float(), p=2, dim=[1, 2, 3]).mean().item()
Y_i_feature_var = torch.var(model_pred.detach().float()).item()
@ -882,8 +881,10 @@ def main(args):
df.to_csv(save_file_path, index=False)
logger.info(f"坐标已写入:{save_file_path}")
# 反向传播accelerate 负责混精与分布式同步
accelerator.backward(loss)
# 梯度裁剪:仅在同步梯度时执行,避免对未同步的局部梯度产生偏差
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
@ -892,18 +893,22 @@ def main(args):
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
# 参数更新optimizer 与 scheduler 逐步推进
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# 只有在完成一次“真实更新”后才推进 global_step 与进度条
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# 按固定步数保存训练状态,用于断点恢复或中途回滚
if accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
accelerator.save_state(args.output_dir)
logger.info(f"已保存训练状态到:{args.output_dir}")
# 每步记录 loss 与 lr便于在 dashboard 中观察收敛曲线
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
@ -915,6 +920,7 @@ def main(args):
images = []
if accelerator.is_main_process:
# 将训练好的组件写成独立 pipeline确保 output_dir 可直接用于推理
pipeline_args = {}
if not args.skip_save_text_encoder:
pipeline_args["text_encoder"] = unwrap_model(text_encoder)
@ -930,6 +936,7 @@ def main(args):
)
pipeline.save_pretrained(args.output_dir)
# 释放训练对象,减少后续 validation 的显存压力
del unet
del optimizer
del lr_scheduler
@ -940,6 +947,7 @@ def main(args):
gc.collect()
torch.cuda.empty_cache()
# 训练结束后运行一次 txt2img 验证,并将结果保存到指定目录
images = run_validation_txt2img(
finetuned_model_dir=args.output_dir,
prompt=args.validation_prompt,
@ -959,6 +967,7 @@ def main(args):
image.save(out_dir / f"validation_image_{i}.png")
logger.info(f"validation 图像已保存到:{out_dir}")
# 推送到 Hub写模型卡并上传 output_dir忽略 step/epoch 目录)
if args.push_to_hub:
save_model_card(
repo_id,
@ -976,6 +985,7 @@ def main(args):
ignore_patterns=["step_*", "epoch_*"],
)
# 训练结束后再落一次坐标,保证最后一段数据不会因日志频率而遗漏
if args.coords_save_path is not None and coords_list:
df = pd.DataFrame(coords_list, columns=["step", "X_Feature_L2_Norm", "Y_Feature_Variance", "Z_LDM_Loss"])
save_file_path = Path(args.coords_save_path)

File diff suppressed because it is too large

@ -1,18 +1,3 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import copy
import gc
@ -51,13 +36,12 @@ from diffusers import (
StableDiffusionPipeline,
UNet2DConditionModel,
)
# Removed LoRA import: from diffusers.loaders import LoraLoaderMixin
# 本脚本只训练 Textual Inversion 的 token embedding不涉及 LoRA 权重
from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
from diffusers.utils import (
check_min_version,
convert_state_dict_to_diffusers,
# Removed LoRA import: convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@ -65,28 +49,23 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
# wandb 为可选依赖,仅在环境可用时启用
if is_wandb_available():
import wandb
# 说明:
# 1) 本文件用于训练 Textual Inversion仅训练一个新 token 的向量)。
# 2) 训练过程冻结 UNet/VAE/TextEncoder 的主体权重,仅更新新 token 对应的 embedding 行。
# 3) 训练过程会按步保存 embedding并进行验证推理用于观察训练效果。
# 4) 文件还包含可视化坐标采集逻辑X=特征范数Y=特征方差Z=loss并写入 CSV。
# 5) 为了保证推理阶段的一致性,验证推理会从基础模型加载,并再加载 learned_embeds.bin 作为增量能力。
# 训练目标为 Textual Inversion只学习一个新 token 的 embedding 行
# 训练过程中冻结 UNet/VAE/TextEncoder 主体参数,只允许 placeholder token 对应的 embedding 更新
# 训练会周期性保存 learned_embeds.bin 与 tokenizer并在保存点执行验证推理以观察学习效果
# 可选记录训练轨迹坐标:(X=UNet 预测特征范数, Y=UNet 预测特征方差, Z=loss) 并写入 CSV
logger = get_logger(__name__)
def _load_textual_inversion_compat(pipeline: DiffusionPipeline, emb_dir: str, token: str):
"""
说明
1) 不同 diffusers 版本对 load_textual_inversion 的参数命名不一致
2) 有些版本支持 token=...有些版本支持 tokens=[...]还有些只支持路径
3) 本函数用于在不同版本之间提供兼容调用优先传入 token 名提高确定性
4) 若当前版本不接受这些参数会自动降级为仅传路径的调用方式
5) 该函数不会保存或覆盖基础模型文件只在运行时向 pipeline 注入增量 embedding
"""
# 兼容不同 diffusers 版本的 Textual Inversion 加载接口
# 优先显式指定 token 名,确保加载的 embedding 与 placeholder 对应
# 若接口参数不兼容则自动降级为只传路径的调用方式
# 该操作仅在运行时向 pipeline 注入 embedding不会修改基础模型目录
try:
pipeline.load_textual_inversion(emb_dir, token=token)
return
@ -112,12 +91,10 @@ def save_model_card(
pipeline: DiffusionPipeline = None,
placeholder_token: str = None,
):
# 说明:
# 1) 该函数用于生成并保存 README 模型卡片与示例图片,便于上传 Hub 或本地记录。
# 2) 对于 Textual Inversion模型文件主要是 learned_embeds.bin 与 tokenizer。
# 3) 该模型卡会说明训练所用的 placeholder token 与训练 prompt。
# 4) 生成的图片会被保存在 repo_folder 下,方便查看训练效果。
# 5) 本函数不会修改模型参数,只做文档与示例资产的持久化。
# 生成并保存模型卡 README同时保存示例图片到输出目录
# Textual Inversion 的核心产物是 learned_embeds.bin 与 tokenizer 增量词表
# 模型卡用于说明基础模型、训练 prompt 与 placeholder token便于复现与展示
# 本函数只写文档与图片文件,不改变任何训练参数或模型权重
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
@ -156,12 +133,10 @@ def log_validation(
epoch,
is_final_validation=False,
):
# 说明:
# 1) 该函数用于在训练过程中做验证推理,观察当前 embedding 学到了什么。
# 2) 会将 scheduler 替换为更适合推理的 DPMSolverMultistepScheduler。
# 3) 会关闭安全检查器,避免被过滤导致无法看到结果。
# 4) 既支持纯文生图,也支持某些管线的传图推理(依赖 args.validation_images
# 5) 会把结果写入 trackertensorboard/wandb并释放 GPU 显存。
# 验证推理:在训练过程中生成样例图,用于观察 embedding 的学习方向
# 推理阶段使用 DPM-Solver 调度器提升速度,并禁用安全检查器避免结果被过滤
# 支持纯文本推理与带初始图像的推理形式(由 validation_images 控制)
# 推理结果会写入 trackertensorboard/wandb并在结束后释放显存
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
@ -219,12 +194,9 @@ def log_validation(
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
# 说明:
# 1) Stable Diffusion 不同变体可能使用不同的 text encoder 架构。
# 2) 该函数读取 text_encoder 的配置,判断其 architectures 字段来确定具体类。
# 3) 常见情况是 CLIPTextModel也可能是 Roberta 或 T5 系列。
# 4) 返回的类用于 from_pretrained 加载 text_encoder保证结构匹配。
# 5) 如果遇到未知架构会直接报错,避免后续 silent bug。
# 通过模型配置自动识别 text encoder 的具体架构,并返回对应的实现类
# 该识别逻辑用于兼容不同的 Stable Diffusion 系列与其他扩散管线变体
# 若架构不在支持列表中则直接报错,避免训练中途出现不匹配问题
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
@ -249,12 +221,11 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def parse_args(input_args=None):
# 说明:
# 1) 该函数定义所有可配置参数,支持命令行调用与被后端服务传参调用。
# 2) 训练相关参数包含学习率、步数、批大小、混合精度、保存间隔等。
# 3) Textual Inversion 需要 placeholder_token 与 initializer_token并且 prompt 必须包含 placeholder。
# 4) 验证推理参数用于在训练中生成图片,输出到指定目录用于可视化或服务返回。
# 5) coords_* 参数用于记录 3D 可视化坐标数据,不影响训练但会增加少量开销。
# 参数解析:定义训练、保存、验证推理、断点恢复与坐标记录所需的全部参数
# Textual Inversion 需要 placeholder_token 与 initializer_token并且 instance_prompt 必须包含 placeholder_token
# 训练步数由 max_train_steps 或 num_train_epochs 推导,保存间隔由 checkpointing_steps 控制
# 验证推理参数决定生成样例图的 prompt、数量与保存目录
# coords_* 参数用于训练轨迹输出 CSV不改变训练逻辑仅增加统计与写盘开销
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
@ -559,12 +530,11 @@ def parse_args(input_args=None):
class DreamBoothDataset(Dataset):
# 说明:
# 1) 该数据集负责读取实例图片,并把图片变换到训练所需的张量格式。
# 2) 同时会对 instance_prompt 做 tokenizer 编码,生成 input_ids 与 attention_mask。
# 3) Textual Inversion 不做 prior preservation因此长度等于实例图片数量。
# 4) 图像会先 resize 再 crop并归一化到 [-1,1]Normalize([0.5],[0.5]))。
# 5) 返回的字典字段会在 collate_fn 中被组装成 batch供 UNet 前向与损失计算使用。
# 数据集:负责读取实例图片并做预处理,同时对 instance_prompt 做分词编码
# 图像会先 resize 再裁剪,并归一化到 [-1, 1],以匹配 Stable Diffusion 的训练输入
# 每个样本输出 image 张量与 token 张量,字段名与训练循环一致
# Textual Inversion 不需要 class 数据或先验保持,因此长度等于实例图片数量
# 本类只准备数据,不参与任何梯度计算与模型更新
def __init__(
self,
instance_data_root,
@ -610,6 +580,7 @@ class DreamBoothDataset(Dataset):
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
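# 修正 EXIF 方向信息,避免使用被旋转的图像参与训练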
instance_image = exif_transpose(instance_image)
@ -625,12 +596,11 @@ class DreamBoothDataset(Dataset):
def collate_fn(examples):
# 说明:
# 1) 该函数负责将 Dataset 返回的若干条样本组装成一个 batch。
# 2) 对图像张量做 stack得到 (B,C,H,W) 的 pixel_values。
# 3) 对 token 的 input_ids 做 cat得到 (B,seq_len) 的输入矩阵。
# 4) attention_mask 保持与 input_ids 对齐,用于 text encoder 的有效 token 标记。
# 5) 输出 batch 会被训练循环直接使用,字段命名与后续代码保持一致。
# batch 拼接:把多条样本打包成训练循环可直接使用的 batch 字典
# 图像张量 stack 为 (B, C, H, W),并转换为连续内存以提高算子效率
# input_ids 与 attention_mask 沿 batch 维拼接,保持与 text encoder 的输入格式一致
# 输出字段命名与训练主循环一致,避免额外适配
# 本函数不做任何增强或损失计算,只负责规整数据结构
has_attention_mask = "instance_attention_mask" in examples[0]
input_ids = [example["instance_prompt_ids"] for example in examples]
@ -656,12 +626,11 @@ def collate_fn(examples):
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
# 说明:
# 1) 对文本 prompt 做 tokenizer 编码,生成 input_ids 与 attention_mask。
# 2) 使用固定长度 padding="max_length" 保证 batch 拼接简单一致。
# 3) truncation=True 防止超过最大长度导致报错。
# 4) tokenizer_max_length 允许外部指定最大长度;否则使用 tokenizer.model_max_length。
# 5) 返回 transformers 的 BatchEncoding后续直接取 input_ids 与 attention_mask 使用即可。
# 分词编码:把 prompt 转为固定长度 input_ids 与 attention_mask
# truncation=True 防止超长输入报错padding="max_length" 保持 batch 形状稳定
# tokenizer_max_length 若提供则覆盖默认长度,否则使用 tokenizer.model_max_length
# 返回 BatchEncoding训练代码从中读取 input_ids 与 attention_mask
# 此处不引入任何额外逻辑,保证 token 化行为可复现
if tokenizer_max_length is not None:
max_length = tokenizer_max_length
else:
@ -678,12 +647,11 @@ def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
# 说明:
# 1) 将 token id 输入 Text Encoder得到用于 UNet 条件输入的 prompt_embeds。
# 2) 如果启用 attention_mask会把 mask 一并传入,以减少 padding token 的影响。
# 3) 输出的 prompt_embeds 通常形状为 (B, seq_len, hidden_dim)。
# 4) UNet 会把该 embedding 作为 cross-attention 的条件,实现文本引导。
# 5) 该函数不涉及梯度以外的副作用embedding 的更新由上层训练流程控制。
# 文本编码:将 token id 输入 text encoder 得到 prompt_embeds用作 UNet 条件输入
# 可选启用 attention_mask以减少 padding token 对编码结果的影响
# 输出为 (B, seq_len, hidden_dim),与 UNet cross-attention 的条件维度对齐
# 本函数不做保存与缓存,训练时是否更新 embedding 由上层控制
# 为避免设备不一致,输入会被移动到 text_encoder.device
text_input_ids = input_ids.to(text_encoder.device)
if text_encoder_use_attention_mask:
@ -701,12 +669,11 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
def main(args):
# 说明:
# 1) 主函数负责训练用的全部初始化accelerate、模型加载、数据集/优化器/调度器。
# 2) Textual Inversion 的关键是新增一个 placeholder token并只训练该 token 的 embedding。
# 3) 训练过程中会定期保存 learned_embeds.bin 与 tokenizer并执行验证推理输出图片。
# 4) 验证推理从基础模型加载,再加载 learned_embeds.bin避免对基础模型权重产生写回影响。
# 5) 若开启 coords_save_path会按你原有逻辑采集并保存可视化坐标数据不改变其行为。
# 主流程:构建 accelerate 环境,加载基础模型组件,并创建可训练的 placeholder token
# placeholder token 会被加入 tokenizer并用 initializer token 的 embedding 进行初始化
# 训练时冻结 UNet/VAE/TextEncoder 的主体权重,仅更新 placeholder token 的 embedding 行
# 训练循环包含 latent 编码、加噪、条件编码、UNet 预测与 MSE 损失,并进行反向传播更新
# 训练期间按 checkpointing_steps 保存状态,并用注入 embedding 的 pipeline 做验证推理输出样例图
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
@ -810,6 +777,7 @@ def main(args):
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# 只对 embedding 层开放梯度,并配合 mask 将可训练范围限制为 placeholder_token 一行
embedding_layer = text_encoder.get_input_embeddings()
embedding_layer.weight.requires_grad = True
trainable_token_embeds = embedding_layer.weight
@ -849,23 +817,18 @@ def main(args):
unet.enable_gradient_checkpointing()
def unwrap_model(model):
# 说明:
# 1) accelerate 在分布式或混合精度下会包装模型,保存/取权重时需要先 unwrap。
# 2) 如果启用 torch.compile模型会被再次包装需取 _orig_mod 获取真实模块。
# 3) 该函数用于在保存 embedding、验证推理、访问模型权重时统一处理。
# 4) 返回的模型对象是“原始模型”,便于直接访问 embedding 权重与 config。
# 5) 该函数自身不做任何训练逻辑修改,只是一个安全的模型访问入口。
# 从 accelerate 包装中取出原始模型,便于保存与访问真实 embedding 权重
# 对 torch.compile 场景,进一步解包以获得真实模块
# 本函数不会改变模型状态,只提供统一的访问方式
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
def save_model_hook(models, weights, output_dir):
# 说明:
# 1) 该 Hook 用于让 accelerate.save_state 保存为 Textual Inversion 需要的最小产物。
# 2) 主要保存 learned_embeds.bin仅包含 placeholder_token 对应的 embedding 行)。
# 3) 同时保存 tokenizer以便后续复现训练 token 的 id 映射与 tokenizer 配置。
# 4) 不保存 UNet/VAE/TextEncoder 的完整权重,避免体积巨大且不符合“增量”设计。
# 5) 保存行为只发生在主进程,避免分布式重复写盘导致文件冲突。
# accelerate 保存钩子:只保存 Textual Inversion 的最小产物
# learned_embeds.bin 只包含 placeholder_token 对应的 embedding 行,体积小且易于分发
# tokenizer 一并保存,用于恢复 token_id 映射与 placeholder_token 的存在性
# 不保存 UNet/VAE/TextEncoder 全量权重,保持增量训练的设计目标
if accelerator.is_main_process:
text_encoder_unwrapped = unwrap_model(text_encoder)
trained_embeddings = text_encoder_unwrapped.get_input_embeddings().weight[
@ -880,12 +843,10 @@ def main(args):
tokenizer.save_pretrained(output_dir)
def load_model_hook(models, input_dir):
# 说明:
# 1) 该 Hook 用于从 checkpoint 恢复训练时,将 learned_embeds.bin 写回到 text_encoder embedding。
# 2) 对于 Textual Inversion恢复的关键是 placeholder_token 对应 embedding 行,而非整个模型。
# 3) 同时通过 checkpoint 内的 tokenizer 获取 placeholder_token 的 token_id以保证写入位置一致。
# 4) 若 checkpoint 缺失 learned_embeds.bin会打印警告并跳过允许从头开始训练。
# 5) 该逻辑只改变当前训练进程内的权重状态,不会修改基础模型目录的文件。
# accelerate 加载钩子:从 learned_embeds.bin 恢复 placeholder_token 的 embedding 行
# 通过 checkpoint 内 tokenizer 获取 placeholder_token_id确保写回位置正确
# 若文件缺失则跳过恢复,允许从头训练或使用外部初始化
# 该操作只影响当前训练进程内存中的 embedding不会修改基础模型目录
text_encoder_ = None
while len(models) > 0:
@ -1060,6 +1021,7 @@ def main(args):
)
for epoch in range(first_epoch, args.num_train_epochs):
# 训练时 UNet/TextEncoder 保持 train(),但实际只有 embedding 行可更新
unet.train()
text_encoder.train()
@ -1067,12 +1029,14 @@ def main(args):
with accelerator.accumulate(unet):
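# 读取 batch 中的图像张量,并转换到训练使用的精度 weight_dtype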
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
# 训练输入优先在 latent 空间,提升计算效率并匹配扩散模型训练范式
if vae is not None:
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
else:
model_input = pixel_values
# 为每个样本采样时间步与噪声,构造前向扩散后的 noisy 输入
noise = torch.randn_like(model_input)
bsz, channels, height, width = model_input.shape
@ -1082,6 +1046,7 @@ def main(args):
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# 文本条件编码:使用当前 text encoder 生成 prompt_embeds引导 UNet 去噪
encoder_hidden_states = encode_prompt(
text_encoder,
batch["input_ids"],
@ -1097,6 +1062,7 @@ def main(args):
else:
class_labels = None
# UNet 预测噪声残差,训练目标是最小化预测与真实噪声(或速度)的均方误差
model_pred = unet(
noisy_model_input,
timesteps,
@ -1108,6 +1074,7 @@ def main(args):
if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1)
# 轨迹记录:用 UNet 输出统计量作为 X/Y配合 loss 形成训练动态曲线
if args.coords_save_path is not None:
X_i_feature_norm = torch.norm(model_pred.detach().float(), p=2, dim=[1, 2, 3]).mean().item()
Y_i_feature_var = model_pred.detach().float().var(dim=[1, 2, 3]).mean().item()
@ -1130,6 +1097,7 @@ def main(args):
lr_scheduler.step()
optimizer.zero_grad()
# 每次更新后强制把非 placeholder 的 embedding 恢复为固定值,保证只学习目标 token
if accelerator.num_processes > 1:
unwrapped_text_encoder = unwrap_model(text_encoder)
trainable_embeds = unwrapped_text_encoder.get_input_embeddings().weight
@ -1143,6 +1111,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
# 坐标保存:按固定步数把 (X,Y,Z) 追加到列表,并覆盖写入 CSV 以防训练中断丢失
if args.coords_save_path is not None and (
global_step % args.coords_log_interval == 0
or global_step == 1
@ -1167,6 +1136,7 @@ def main(args):
f"Step {global_step}: 已记录并保存可视化坐标 (X={X_i_feature_norm:.4f}, Y={Y_i_feature_var:.4f}, Z={Z_i:.4f}) 到 {save_file_path}"
)
# checkpoint保存训练状态并用基础模型 + 注入 embedding 的方式生成验证图像
if accelerator.is_main_process:
if (global_step + 1) % args.checkpointing_steps == 0:
output_dir = args.output_dir
@ -1213,7 +1183,10 @@ def main(args):
break
accelerator.wait_for_everyone()
if accelerator.is_main_process:
# 训练结束保存最终产物learned_embeds.bin 与 tokenizer
# learned_embeds.bin 只包含 placeholder token 的 embedding 行,用于后续推理时注入到基础模型
text_encoder = unwrap_model(text_encoder)
trained_embeddings = text_encoder.get_input_embeddings().weight[
@ -1226,6 +1199,7 @@ def main(args):
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
tokenizer.save_pretrained(args.output_dir)
# 最终验证:重新加载基础模型并注入 embedding生成样例图用于输出与模型卡展示
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
@ -1265,6 +1239,7 @@ def main(args):
ignore_patterns=["step_*", "epoch_*"],
)
# 训练结束补写一次坐标文件,确保最后阶段数据不会遗漏
if args.coords_save_path is not None and coords_list:
df = pd.DataFrame(
coords_list,

@ -29,7 +29,7 @@ logger = get_logger(__name__)
class DreamBoothDatasetFromTensor(Dataset):
"""Just like DreamBoothDataset, but take instance_images_tensor instead of path"""
"""基于内存张量的 DreamBooth 数据集:直接使用张量输入,返回图像与对应 prompt token。"""
def __init__(
self,
@ -41,15 +41,18 @@ class DreamBoothDatasetFromTensor(Dataset):
size=512,
center_crop=False,
):
# 保存图像处理参数与 tokenizer
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
# 实例数据:直接来自传入的张量列表
self.instance_images_tensor = instance_images_tensor
self.num_instance_images = len(self.instance_images_tensor)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
# 可选类数据:用于先验保持,长度取实例与类数据的最大值
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
@ -60,6 +63,7 @@ class DreamBoothDatasetFromTensor(Dataset):
else:
self.class_data_root = None
# 统一的图像预处理
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
@ -73,6 +77,7 @@ class DreamBoothDatasetFromTensor(Dataset):
return self._length
def __getitem__(self, index):
# 取出实例图像张量与对应 prompt token
example = {}
instance_image = self.instance_images_tensor[index % self.num_instance_images]
example["instance_images"] = instance_image
@ -84,6 +89,7 @@ class DreamBoothDatasetFromTensor(Dataset):
return_tensors="pt",
).input_ids
# 若有类数据,则同时返回类图像与类 prompt token
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
@ -101,6 +107,7 @@ class DreamBoothDatasetFromTensor(Dataset):
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
# 根据 text_encoder 配置识别其架构,选择正确的模型类
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
@ -121,6 +128,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def parse_args(input_args=None):
# 解析命令行参数:模型路径、数据路径、对抗参数、先验保持、训练与日志配置
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
@ -341,7 +349,7 @@ def parse_args(input_args=None):
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
"""用于批量生成 class 图像的提示词数据集,可在多 GPU 环境下并行采样。"""
def __init__(self, prompt, num_samples):
self.prompt = prompt
@ -358,6 +366,7 @@ class PromptDataset(Dataset):
def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
# 读取目录下所有图片,按训练要求 resize/crop/normalize返回堆叠后的张量
image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
@ -381,8 +390,7 @@ def train_one_epoch(
data_tensor: torch.Tensor,
num_steps=20,
):
# Load the tokenizer
# 单轮训练:复制当前模型,使用给定数据迭代若干步,返回更新后的副本
unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
@ -404,7 +412,6 @@ def train_one_epoch(
args.center_crop,
)
# weight_dtype = torch.bfloat16
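# 该函数内部固定使用 bfloat16 精度,并在单卡 cuda 上执行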
weight_dtype = torch.bfloat16
device = torch.device("cuda")
@ -416,6 +423,7 @@ def train_one_epoch(
unet.train()
text_encoder.train()
# 构造当前步的样本instance + class并生成文本 token
step_data = train_dataset[step % len(train_dataset)]
pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
device, dtype=weight_dtype
@ -425,24 +433,20 @@ def train_one_epoch(
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
# 随机采样时间步并加噪,模拟正向扩散
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
# 文本条件编码
encoder_hidden_states = text_encoder(input_ids)[0]
# Predict the noise residual
# UNet 预测噪声
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
# 根据 scheduler 的预测类型选择目标
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
@ -450,18 +454,13 @@ def train_one_epoch(
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# with prior preservation loss
# 可选先验保持:拆分 instance 与 class 部分分别计算 MSE
if args.with_prior_preservation:
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = instance_loss + args.prior_loss_weight * prior_loss
else:
@ -489,7 +488,7 @@ def pgd_attack(
target_tensor: torch.Tensor,
num_steps: int,
):
"""Return new perturbed data"""
"""PGD 对抗扰动:在噪声预算内迭代更新输入,返回新的扰动数据。"""
unet, text_encoder = models
weight_dtype = torch.bfloat16
@ -515,24 +514,18 @@ def pgd_attack(
latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
# 采样时间步并加噪,准备 UNet 预测
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
#noise_scheduler.config.num_train_timesteps
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
# 文本条件与噪声预测
encoder_hidden_states = text_encoder(input_ids.to(device))[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
# 目标噪声或速度
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
@ -544,7 +537,7 @@ def pgd_attack(
text_encoder.zero_grad()
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# target-shift loss
# 若有目标图像 latent加入目标对齐项保持原有逻辑损失为差值
if target_tensor is not None:
xtm1_pred = torch.cat(
[
@ -561,6 +554,7 @@ def pgd_attack(
loss.backward()
# PGD 更新并投影到 eps 球内,再裁剪到 [-1, 1]
alpha = args.pgd_alpha
eps = args.pgd_eps / 255
@ -598,7 +592,7 @@ def main(args):
if args.seed is not None:
set_seed(args.seed)
# Generate class images if prior preservation is enabled.
# 先验保持:不足的 class 图像用基础模型生成补齐
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
@ -645,10 +639,9 @@ def main(args):
if torch.cuda.is_available():
torch.cuda.empty_cache()
# import correct text encoder class
# 加载 text encoder / UNet / tokenizer / scheduler / VAE
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
@ -712,9 +705,9 @@ def main(args):
)
target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda()
# 交替流程:训练 surrogate -> PGD 扰动 -> 用扰动数据再训练主模型,周期性导出对抗样本
f = [unet, text_encoder]
for i in range(args.max_train_steps):
# 1. f' = f.clone()
f_sur = copy.deepcopy(f)
f_sur = train_one_epoch(
args,
@ -746,6 +739,7 @@ def main(args):
args.max_f_train_steps,
)
# 周期保存当前扰动图像,便于后续评估与复现
if (i + 1) % args.checkpointing_iterations == 0:
save_folder = args.output_dir
os.makedirs(save_folder, exist_ok=True)

@ -41,11 +41,13 @@ logger = get_logger(__name__)
def freeze_params(params):
"""冻结一组参数的梯度开关,使其在训练中保持不更新。"""
for param in params:
param.requires_grad = False
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
"""从预训练目录读取 text_encoder 配置,自动选择匹配的文本编码器实现类。"""
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
@ -65,7 +67,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
raise ValueError(f"{model_class} is not supported.")
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
"""用于批量生成 class 图像的 prompt 数据集,便于采样阶段在多卡上分发任务。"""
def __init__(self, prompt, num_samples):
self.prompt = prompt
@ -83,8 +85,10 @@ class PromptDataset(Dataset):
class CustomDiffusionDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and the tokenizes prompts.
CAAT/Custom Diffusion 训练数据集
负责读取实例图像与可选的类图像并为每张图像生成对应的 prompt token
同时在实例图像上生成有效区域 mask供训练时对 loss 做空间加权
"""
def __init__(
@ -99,6 +103,7 @@ class CustomDiffusionDataset(Dataset):
hflip=False,
aug=True,
):
# 训练图像与 mask 的目标尺寸
self.size = size
self.mask_size = mask_size
self.center_crop = center_crop
@ -106,6 +111,7 @@ class CustomDiffusionDataset(Dataset):
self.interpolation = Image.BILINEAR
self.aug = aug
# 记录实例与类数据的路径及对应 prompt
self.instance_images_path = []
self.class_images_path = []
self.with_prior_preservation = with_prior_preservation
@ -115,6 +121,7 @@ class CustomDiffusionDataset(Dataset):
]
self.instance_images_path.extend(inst_img_path)
# 启用先验保持时,额外读取 class 图像与 class prompt
if with_prior_preservation:
class_data_root = Path(concept["class_data_dir"])
if os.path.isdir(class_data_root):
@ -129,12 +136,16 @@ class CustomDiffusionDataset(Dataset):
class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)]
self.class_images_path.extend(class_img_path[:num_class_images])
# 打乱实例顺序以增加训练随机性,并确定数据集长度
random.shuffle(self.instance_images_path)
self.num_instance_images = len(self.instance_images_path)
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
# 可选水平翻转增强
self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
# 类图像走标准 transforms实例图像会走自定义 preprocess 以生成 mask
self.image_transforms = transforms.Compose(
[
self.flip,
@ -149,6 +160,7 @@ class CustomDiffusionDataset(Dataset):
return self._length
def preprocess(self, image, scale, resample):
"""对实例图像做缩放与随机放置,并生成对应的有效区域 mask。"""
outer, inner = self.size, scale
factor = self.size // self.mask_size
if scale > self.size:
@ -171,13 +183,15 @@ class CustomDiffusionDataset(Dataset):
def __getitem__(self, index):
example = {}
# 读取实例图像与对应 prompt
instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images]
instance_image = Image.open(instance_image)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
instance_image = self.flip(instance_image)
# apply resize augmentation and create a valid image region mask
# 对实例图像做随机缩放增强,并生成有效区域 mask
random_scale = self.size
if self.aug:
random_scale = (
@ -187,11 +201,13 @@ class CustomDiffusionDataset(Dataset):
)
instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation)
# 根据缩放幅度对 prompt 加入轻量描述,模拟尺度变化的语义提示
if random_scale < 0.6 * self.size:
instance_prompt = np.random.choice(["a far away ", "very small "]) + instance_prompt
elif random_scale > self.size:
instance_prompt = np.random.choice(["zoomed in ", "close up "]) + instance_prompt
# 实例图像与 mask 进入训练:图像已归一化到 [-1, 1]
example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1)
example["mask"] = torch.from_numpy(mask)
example["instance_prompt_ids"] = self.tokenizer(
@ -202,6 +218,7 @@ class CustomDiffusionDataset(Dataset):
return_tensors="pt",
).input_ids
# 先验保持:追加 class 图像、class mask 与 class prompt token
if self.with_prior_preservation:
class_image, class_prompt = self.class_images_path[index % self.num_class_images]
class_image = Image.open(class_image)
@ -222,6 +239,7 @@ class CustomDiffusionDataset(Dataset):
def parse_args(input_args=None):
"""解析 CAAT 训练参数:包含 PGD 超参、数据与模型路径、训练步数与优化器设置。"""
parser = argparse.ArgumentParser(description="CAAT training script.")
parser.add_argument(
"--alpha",
@ -494,6 +512,7 @@ def parse_args(input_args=None):
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# 先验保持模式需要 class 数据与 prompt多概念模式下由 concepts_list 提供
if args.with_prior_preservation:
if args.concepts_list is None:
if args.class_data_dir is None:
@ -501,7 +520,6 @@ def parse_args(input_args=None):
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
else:
# logger is not available yet
if args.class_data_dir is not None:
warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
if args.class_prompt is not None:
@ -511,8 +529,8 @@ def parse_args(input_args=None):
def main(args):
# 初始化 accelerate 环境与日志目录
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
@ -534,11 +552,14 @@ def main(args):
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# 记录实验配置到 tracker便于后续复现实验
accelerator.init_trackers("CAAT", config=vars(args))
# If passed along, set the training seed now.
# 固定随机种子以提高可复现性
if args.seed is not None:
set_seed(args.seed)
# 将单概念参数统一封装为 concepts_list或从 json 中读取多概念配置
if args.concepts_list is None:
args.concepts_list = [
{
@ -552,7 +573,7 @@ def main(args):
with open(args.concepts_list, "r") as f:
args.concepts_list = json.load(f)
# Generate class images if prior preservation is enabled.
# 启用先验保持时,若 class 图像不足则使用基础模型补齐
if args.with_prior_preservation:
for i, concept in enumerate(args.concepts_list):
class_images_dir = Path(concept["class_data_dir"])
@ -604,10 +625,11 @@ def main(args):
if torch.cuda.is_available():
torch.cuda.empty_cache()
# 创建输出目录
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer
# 加载 tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name,
@ -622,10 +644,8 @@ def main(args):
use_fast=False,
)
# import correct text encoder class
# 加载 text encoder / scheduler / VAE / UNet
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
@ -636,22 +656,24 @@ def main(args):
)
# 冻结主干权重:该方法只训练 attention processor 中的增量参数
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
# 推理组件使用半精度可节省显存;训练增量层由 optimizer 管理
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
# 将模型移动到训练设备并统一 dtype
text_encoder.to(accelerator.device, dtype=weight_dtype)
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
# 根据是否启用 xformers 选择 attention processor 的实现
attention_class = CustomDiffusionAttnProcessor
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@ -666,24 +688,12 @@ def main(args):
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# now we will add new Custom Diffusion weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
# => 32 layers
# Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer
# 训练策略:只训练 cross-attention 的 KV或全部 Q/K/V/out其余保持冻结
train_kv = True
train_q_out = False if args.freeze_model == "crossattn_kv" else True
custom_diffusion_attn_procs = {}
# 从 UNet state_dict 中取出原始权重,作为自定义 attention processor 的初始化
st = unet.state_dict()
for name, _ in unet.attn_processors.items():
@ -697,6 +707,8 @@ def main(args):
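# down_blocks 分支:按所在 block 的序号取出对应的 hidden_size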
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
layer_name = name.split(".processor")[0]
# KV 投影权重始终可训练;若启用 train_q_out 则额外训练 Q 与 out 投影
weights = {
"to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
"to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
@ -705,6 +717,8 @@ def main(args):
weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
# 仅对 cross-attention 层注入可训练 processorself-attention 走冻结版本
if cross_attention_dim is not None:
custom_diffusion_attn_procs[name] = attention_class(
train_kv=train_kv,
@ -723,24 +737,27 @@ def main(args):
del st
# 将新的 attention processor 注入 UNet并用 AttnProcsLayers 封装成可训练模块
unet.set_attn_processor(custom_diffusion_attn_procs)
custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)
# 将增量层注册到 checkpoint保证训练状态可保存/恢复
accelerator.register_for_checkpointing(custom_diffusion_layers)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
# 允许 TF32 可提升部分 GPU 上的矩阵运算速度
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# 先验保持时,通常将学习率扩大以补偿额外损失项带来的梯度分摊
args.learning_rate = args.learning_rate
if args.with_prior_preservation:
args.learning_rate = args.learning_rate * 2.0
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
# 选择优化器实现:可选 8-bit AdamW 以降低显存占用
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
@ -753,7 +770,7 @@ def main(args):
else:
optimizer_class = torch.optim.AdamW
# Optimizer creation
# 仅优化 custom_diffusion_layers 的参数,其余主干保持冻结
optimizer = optimizer_class(
custom_diffusion_layers.parameters(),
lr=args.learning_rate,
@ -762,7 +779,7 @@ def main(args):
eps=args.adam_epsilon,
)
# Dataset creation:
# 构建训练数据集mask_size 通过 VAE latent 分辨率自动推导
train_dataset = CustomDiffusionDataset(
concepts_list=args.concepts_list,
tokenizer=tokenizer,
@ -780,15 +797,17 @@ def main(args):
)
# Prepare for PGD
# 为 PGD 准备可训练的图像张量:对实例图像做与训练一致的 transforms
pertubed_images = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path]
pertubed_images = [train_dataset.image_transforms(i) for i in pertubed_images]
pertubed_images = torch.stack(pertubed_images).contiguous()
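# 打开 requires_grad使后续能直接对扰动图像的像素求梯度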
pertubed_images.requires_grad_()
# 保留原始图像张量,用于 PGD 的投影约束
original_images = pertubed_images.clone().detach()
original_images.requires_grad_(False)
# 文本 token对所有实例图像重复同一个 instance_prompt保持原脚本行为
input_ids = train_dataset.tokenizer(
args.instance_prompt,
truncation=True,
@ -798,6 +817,7 @@ def main(args):
).input_ids.repeat(len(original_images), 1)
def get_one_mask(image):
"""与训练同样的随机缩放逻辑,生成单张实例图像的有效区域 mask。"""
random_scale = train_dataset.size
if train_dataset.aug:
random_scale = (
@ -812,6 +832,7 @@ def main(args):
one_mask += class_mask
return one_mask
# 预先为每张图像生成 mask并堆叠为 batch 形式供训练损失使用
images_open_list = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path]
mask_list = []
for image in images_open_list:
@ -831,12 +852,13 @@ def main(args):
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
# 将可训练模块、优化器、对抗图像张量与 mask 一并交给 accelerate 管理设备与并行
custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask = accelerator.prepare(
custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask
)
# Train!
# 训练主循环:每步同时更新 attention 增量层与对抗图像PGD
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num pertubed_images = {len(pertubed_images)}")
@ -844,36 +866,31 @@ def main(args):
global_step = 0
first_epoch = 0
# Only show the progress bar once on each machine.
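# 进度条只在每台机器的本地主进程显示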
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.max_train_steps):
unet.train()
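# 内层循环固定只执行一次,仅用于保留 accumulate 上下文的结构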
for _ in range(1):
with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
# Convert images to latent space
# 将图像编码到 latent 空间并加噪,形成 UNet 的训练输入
pertubed_images.requires_grad = True
latents = vae.encode(pertubed_images.to(accelerator.device).to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
# 文本条件编码
encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0]
# Predict the noise residual
# UNet 预测噪声
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
# 选择监督目标epsilon 或 v
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
@ -881,39 +898,32 @@ def main(args):
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# unet.zero_grad()
# text_encoder.zero_grad()
# loss 计算:可选先验保持;实例部分可结合 mask 做空间加权
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
mask = torch.chunk(mask, 2, dim=0)[0].to(accelerator.device)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
mask = mask.to(accelerator.device)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # torch.Size([5, 4, 64, 64])
#loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean()
accelerator.backward(loss)
# 梯度裁剪:只裁剪可训练的 custom_diffusion_layers 参数
if accelerator.sync_gradients:
params_to_clip = (
custom_diffusion_layers.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
# PGD 更新:基于 pertubed_images 的梯度做投影更新,并保持在 eps 约束内
alpha = args.alpha
eps = args.eps
adv_images = pertubed_images + alpha * pertubed_images.grad.sign()
@ -925,7 +935,6 @@ def main(args):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
@ -938,6 +947,7 @@ def main(args):
if global_step >= args.max_train_steps:
break
# 训练结束后在主进程保存最终对抗图像,文件名包含原始图片名以便对齐
if accelerator.is_main_process:
logger.info("***** Final save of perturbed images *****")
save_folder = args.output_dir
@ -955,7 +965,6 @@ def main(args):
img_name = img_names[i]
save_path = os.path.join(save_folder, f"final_noise_{img_name}")
# 图像转换和保存
Image.fromarray(
(img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
).save(save_path)
@ -970,3 +979,4 @@ if __name__ == "__main__":
args = parse_args()
main(args)
print("<-------end-------->")

@ -1,8 +1,3 @@
"""
Glaze: 艺术风格保护算法
基于原始 Glaze 项目重构适配 4090D GPU 直接运行
"""
import argparse
import os
import gc
@ -24,6 +19,7 @@ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
def parse_args(input_args=None):
"""解析命令行参数,包含模型路径、输入输出目录、风格迁移配置与扰动优化超参。"""
parser = argparse.ArgumentParser(description="Glaze: 艺术风格保护算法")
parser.add_argument(
"--pretrained_model_name_or_path",
@ -126,7 +122,7 @@ def parse_args(input_args=None):
parser.add_argument(
'--style_transfer_iter',
type=int,
default=15,
default=15,
help='风格迁移的扩散步数'
)
parser.add_argument(
@ -159,6 +155,7 @@ def parse_args(input_args=None):
else:
args = parser.parse_args()
# 兼容 accelerate/分布式启动时的 LOCAL_RANK 注入
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@ -167,7 +164,7 @@ def parse_args(input_args=None):
def get_eps_from_intensity(intensity):
"""根据强度级别 (0-100) 计算 epsilon 值"""
"""将强度等级映射为 epsilon用于以更直观的方式控制扰动幅度。"""
if intensity <= 50:
actual_eps = 0.025 + 0.025 * intensity / 50
else:
@ -176,7 +173,7 @@ def get_eps_from_intensity(intensity):
def img2tensor(cur_img, device='cuda'):
"""将 PIL 图像转换为 [-1, 1] 范围的张量"""
"""将 PIL 图像转换为 [-1, 1] 范围的张量,并按 (1,C,H,W) 形式返回。"""
cur_img = np.array(cur_img)
img = (cur_img / 127.5 - 1).astype(np.float32)
img = rearrange(img, 'h w c -> c h w')
@ -185,7 +182,7 @@ def img2tensor(cur_img, device='cuda'):
def tensor2img(cur_img):
"""将 [-1, 1] 范围的张量转换为 PIL 图像"""
"""将 [-1, 1] 范围的张量转换为 PIL 图像,便于保存与可视化。"""
if len(cur_img.shape) == 3:
cur_img = cur_img.unsqueeze(0)
cur_img = torch.clamp((cur_img.detach() + 1) / 2, min=0, max=1)
@ -195,7 +192,7 @@ def tensor2img(cur_img):
def load_img(path):
"""加载图像并处理 EXIF 旋转信息"""
"""加载图像并修正 EXIF 方向,统一输出为 RGB失败则返回 None。"""
if not os.path.exists(path):
return None
try:
@ -210,14 +207,14 @@ def load_img(path):
class GlazeDataset(Dataset):
"""用于加载待处理图像的数据集"""
"""从目录读取待处理图像,并返回图像张量、路径与原始 PIL 图像对象。"""
def __init__(self, instance_data_root, size=512, center_crop=False):
self.size = size
self.center_crop = center_crop
self.instance_images_path = []
# 支持的图像格式
# 过滤常见图像后缀,并避免重复处理已输出的 *_glazed 文件
valid_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.webp', '.tiff'}
for p in Path(instance_data_root).iterdir():
@ -227,6 +224,7 @@ class GlazeDataset(Dataset):
self.instance_images_path = sorted(self.instance_images_path)
self.num_instance_images = len(self.instance_images_path)
# 这里不做 Normalize保持输入在 [0,1],后续在编码前再转换到 [-1,1]
self.image_transforms = transforms.Compose([
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
@ -241,8 +239,8 @@ class GlazeDataset(Dataset):
img_path = self.instance_images_path[index % self.num_instance_images]
instance_image = load_img(str(img_path))
# 对异常样本返回占位图,避免训练/处理流程中断
if instance_image is None:
# 返回空白图像作为占位
instance_image = Image.new('RGB', (self.size, self.size), (0, 0, 0))
example['index'] = index % self.num_instance_images
@ -253,20 +251,20 @@ class GlazeDataset(Dataset):
class GlazeOptimizer:
"""Glaze 优化器核心类"""
"""Glaze 核心优化器:负责生成目标风格参考,并在特征空间内优化输入扰动。"""
def __init__(self, args, device):
self.args = args
self.device = device
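# 仅当开启 half_precision 且运行在 GPU 上时才启用半精度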
self.half = args.half_precision and device == 'cuda'
# 计算 epsilon
# eps 控制扰动最大幅度,可由 intensity 自动换算或直接手动指定
if args.intensity is not None:
self.max_change = get_eps_from_intensity(args.intensity)
else:
self.max_change = args.eps
# 计算步长
# 步长默认取 eps 的一半,并在迭代中做衰减以降低后期振荡
if args.step_size is not None:
self.step_size = args.step_size
else:
@ -275,12 +273,12 @@ class GlazeOptimizer:
print(f"扰动预算 (epsilon): {self.max_change:.4f}")
print(f"步长: {self.step_size:.4f}")
# 模型占位符
# 模型在需要时惰性加载,减少启动开销与显存占用峰值
self.vae = None
self.sd_pipeline = None
def load_vae(self):
"""加载 VAE 编码器"""
"""加载 VAE 编码器,用于将图像映射到特征空间并参与梯度计算。"""
print("加载 VAE 模型...")
self.vae = AutoencoderKL.from_pretrained(
self.args.pretrained_model_name_or_path,
@ -294,24 +292,24 @@ class GlazeOptimizer:
# 注意:不设置 requires_grad_(False),因为我们需要通过它计算梯度
def load_sd_pipeline(self):
"""加载 Stable Diffusion img2img 管道用于风格迁移"""
"""加载 Stable Diffusion img2img 管道,用于生成目标风格参考图像。"""
print("加载 Stable Diffusion 管道...")
# 始终使用 FP32 加载以避免 CPU 卸载问题
# 始终使用 FP32 加载以避免 CPU offload 等路径带来的精度与兼容问题
self.sd_pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
self.args.pretrained_model_name_or_path,
revision=self.args.revision,
torch_dtype=torch.float32,
safety_checker=None, # 禁用 NSFW 检查器
safety_checker=None,
requires_safety_checker=False
)
self.sd_pipeline.to(self.device)
self.sd_pipeline.enable_attention_slicing()
# 如果使用半精度且在 GPU 上,转换为 FP16
# 在 GPU 上启用半精度,以减少显存并加速推理
if self.half and self.device == 'cuda':
self.sd_pipeline.to(torch.float16)
# 尝试启用 xformers
# 可选启用 xformers 注意力实现,以进一步降低显存占用
if self.args.enable_xformers_memory_efficient_attention:
try:
self.sd_pipeline.enable_xformers_memory_efficient_attention()
@ -320,7 +318,7 @@ class GlazeOptimizer:
print(f"无法启用 xformers: {e}")
def unload_sd_pipeline(self):
"""卸载 SD 管道以释放显存"""
"""释放 SD 管道占用的显存,避免与后续优化阶段竞争资源。"""
if self.sd_pipeline is not None:
del self.sd_pipeline
self.sd_pipeline = None
@ -330,36 +328,35 @@ class GlazeOptimizer:
def vae_encode(self, x):
"""
使用 VAE 编码图像
注意这里不使用 no_grad以便支持梯度计算
使用 VAE 编码图像并返回 posterior 均值
这里保留梯度计算通路使输入扰动可以通过特征距离损失进行优化
"""
posterior = self.vae.encode(x).latent_dist
return posterior.mean
def vae_encode_no_grad(self, x):
"""使用 VAE 编码图像(不计算梯度版本,用于目标编码)"""
"""不计算梯度的 VAE 编码版本,用于提取目标图像特征以节省显存。"""
with torch.no_grad():
posterior = self.vae.encode(x).latent_dist
return posterior.mean.detach()
def style_transfer(self, img):
"""
使用 Stable Diffusion 进行风格迁移
生成目标风格图像
使用 SD img2img 将输入图像迁移到目标风格得到用于对齐的目标风格参考图像
"""
if self.sd_pipeline is None:
self.load_sd_pipeline()
# 调整图像大小
# 将原图缩放到不超过 512并以左上角对齐的方式填充到 512x512 画布
img_copy = img.copy()
img_copy.thumbnail((512, 512), Image.LANCZOS)
# 创建 512x512 画布
canvas = np.zeros((512, 512, 3), dtype=np.uint8)
canvas[:img_copy.size[1], :img_copy.size[0], :] = np.array(img_copy)
padded_img = Image.fromarray(canvas)
# 运行风格迁移
# 生成目标风格图像,仅用于提供参考,不需要梯度
with torch.no_grad():
result = self.sd_pipeline(
prompt=self.args.target_style,
@ -371,19 +368,18 @@ class GlazeOptimizer:
target_img = result.images[0]
# 裁剪回原始大小
# 将输出裁剪回缩放后的有效区域,再 resize 回原图尺寸以对齐后续分块
cropped_target = np.array(target_img)[:img_copy.size[1], : img_copy.size[0], :]
cropped_target = Image.fromarray(cropped_target)
# 调整到原图大小
full_target = cropped_target.resize(img.size, Image.LANCZOS)
return full_target
def segment_image(self, img):
"""
将图像分割成 512x512 的方块
返回: (方块列表, 最后一个方块的偏移, 方块大小)
将输入图像切分为若干正方形分块并将每个分块缩放到 512x512
返回值包含分块列表最后一个分块的对齐偏移以及原始正方形分块的边长
"""
img_array = np.array(img).astype(np.float32)
og_width, og_height = img.size
@ -391,9 +387,8 @@ class GlazeOptimizer:
squares_ls = []
last_index = 0
# 判断是宽图还是高图
# 以短边为正方形边长,沿长边方向切块
if og_height <= og_width:
# 宽图:按水平方向分割
square_size = og_height
cur_idx = 0
@ -409,7 +404,6 @@ class GlazeOptimizer:
squares_ls.append(cropped_img)
cur_idx += og_height
else:
# 高图:按垂直方向分割
square_size = og_width
cur_idx = 0
@ -428,20 +422,16 @@ class GlazeOptimizer:
return squares_ls, last_index, square_size
def put_back_cloak(self, og_img_array, cloak_list, last_index):
"""
将扰动贴回原图
"""
"""将每个分块的扰动增量贴回原图位置,并裁剪到合法像素范围。"""
og_height, og_width, _ = og_img_array.shape
if og_height <= og_width:
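# 宽图:分块沿水平方向依次回贴;最后一块按 last_index 对齐剩余区域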
# 宽图
for idx, cur_cloak in enumerate(cloak_list):
if idx < len(cloak_list) - 1:
og_img_array[0:og_height, idx * og_height:(idx + 1) * og_height, : ] += cur_cloak
else:
og_img_array[0:og_height, idx * og_height:(idx + 1) * og_height, :] += cur_cloak[0:og_height, last_index:]
else:
# 高图
for idx, cur_cloak in enumerate(cloak_list):
if idx < len(cloak_list) - 1:
og_img_array[idx * og_width:(idx + 1) * og_width, 0:og_width, :] += cur_cloak
@ -452,9 +442,7 @@ class GlazeOptimizer:
return og_img_array
def get_cloak(self, og_segment_img, res_adv_tensor, square_size):
"""
计算单个方块的扰动 (cloak)
"""
"""将对抗结果与原分块对齐后取差值,得到该分块需要回贴到原图的扰动增量。"""
resize_back = og_segment_img.resize((square_size, square_size), Image.LANCZOS)
res_adv_img = tensor2img(res_adv_tensor).resize((square_size, square_size), Image.LANCZOS)
cloak = np.array(res_adv_img).astype(np.float32) - np.array(resize_back).astype(np.float32)
@ -462,13 +450,13 @@ class GlazeOptimizer:
def compute_adversarial(self, source_segments, target_segments, square_size, progress_callback=None):
"""
计算对抗扰动
核心优化算法
对每个分块执行 PGD 式优化使源分块在 VAE 特征空间上逼近目标风格分块
该模块是核心优化过程,损失为 adv_emb 与 target_emb 之间的范数距离,并对扰动做 epsilon 约束投影。
"""
results = []
for seg_idx, (source_seg, target_seg) in enumerate(zip(source_segments, target_segments)):
# 转换为张量
source_tensor = img2tensor(source_seg, self.device)
target_tensor = img2tensor(target_seg, self.device)
@ -476,18 +464,15 @@ class GlazeOptimizer:
source_tensor = source_tensor.half()
target_tensor = target_tensor.half()
# 获取目标编码(不需要梯度)
target_emb = self.vae_encode_no_grad(target_tensor)
# 初始化:源图像和扰动
X_batch = source_tensor.clone().detach()
modifiers = torch.zeros_like(X_batch, requires_grad=True)
# 调整大小的变换
# 通过先缩放回原分块尺寸再缩放到 512模拟回贴后的尺度影响
resizer_large = torchvision.transforms.Resize(square_size)
resizer_512 = torchvision.transforms.Resize((512, 512))
# PGD 优化循环
pbar = tqdm(range(self.args.max_train_steps),
desc=f"优化方块 {seg_idx + 1}/{len(source_segments)}",
leave=False)
@ -495,61 +480,45 @@ class GlazeOptimizer:
best_modifier = None
for step in pbar:
# 动态调整步长
# 使用随步数衰减的步长,提升收敛稳定性
decay = 1 - (step / self.args.max_train_steps)
actual_step_size = self.step_size * decay
# 确保 modifiers 需要梯度
if not modifiers.requires_grad:
modifiers = modifiers.detach().clone().requires_grad_(True)
# 应用扰动并裁剪
X_adv = torch.clamp(modifiers + X_batch, -1, 1)
# 调整大小(模拟实际处理)
X_adv_resized = resizer_large(X_adv)
X_adv_resized = resizer_512(X_adv_resized)
# 计算损失:最小化与目标编码的距离
adv_emb = self.vae_encode(X_adv_resized)
loss = (adv_emb - target_emb).norm()
# 反向传播
loss.backward()
# 获取梯度
grad = modifiers.grad.detach()
# PGD 更新:沿梯度符号方向移动
# 沿梯度符号方向更新,并投影到 [-eps, eps] 的约束范围内
with torch.no_grad():
update = grad.sign() * actual_step_size
modifiers_new = modifiers - update # 最小化损失,所以是减
# 投影到 epsilon 球
modifiers_new = modifiers - update
modifiers_new = torch.clamp(modifiers_new, -self.max_change, self.max_change)
# 保存最佳结果
best_modifier = modifiers_new.detach().clone()
# 重新初始化 modifiers 用于下一轮
modifiers = best_modifier.clone().requires_grad_(True)
# 更新进度条
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
# 回调
if progress_callback:
progress_callback(seg_idx, step, loss.item())
# 最终对抗样本
with torch.no_grad():
best_adv = torch.clamp(best_modifier + X_batch, -1, 1)
# 计算 cloak
cloak = self.get_cloak(source_seg, best_adv, square_size)
results.append(cloak)
# 清理显存
del source_tensor, target_tensor, X_batch, modifiers, best_modifier
if torch.cuda.is_available():
torch.cuda.empty_cache()
@ -557,34 +526,27 @@ class GlazeOptimizer:
return results
def process_image(self, img, run_idx=0):
"""
处理单张图像
"""
"""处理单张图像:生成目标风格参考→分块→逐块优化→回贴合成。"""
print(f"\n=== 处理图像 (运行 {run_idx + 1}/{self.args.n_runs}) ===")
# 1.生成目标风格图像
print("生成目标风格图像...")
target_img = self.style_transfer(img)
# 释放 SD 管道显存
# 风格参考生成后立即释放 SD 管道,优先保证后续优化阶段显存充足
self.unload_sd_pipeline()
# 确保 VAE 已加载
if self.vae is None:
self.load_vae()
# 2.分割图像
print("分割图像...")
source_segments, last_index, square_size = self.segment_image(img)
target_segments, _, _ = self.segment_image(target_img)
print(f"图像被分割为 {len(source_segments)} 个方块,大小: {square_size}x{square_size}")
# 3.计算对抗扰动
print("计算对抗扰动...")
cloak_list = self.compute_adversarial(source_segments, target_segments, square_size)
# 4.将扰动贴回原图
print("合成最终图像...")
og_array = np.array(img).astype(np.float32)
cloaked_array = self.put_back_cloak(og_array, cloak_list, last_index)
@ -594,13 +556,13 @@ class GlazeOptimizer:
def main(args):
# 设置随机种子
# 设置随机种子,保证风格迁移与优化过程的随机分支可复现
if args.seed is not None:
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
# 检测设备
# 选择运行设备并打印基础信息
if torch.cuda.is_available():
device = 'cuda'
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
@ -609,10 +571,8 @@ def main(args):
device = 'cpu'
print("使用 CPU")
# 创建输出目录
os.makedirs(args.output_dir, exist_ok=True)
# 加载数据集
dataset = GlazeDataset(
instance_data_root=args.instance_data_dir,
size=args.resolution,
@ -629,10 +589,9 @@ def main(args):
print(f"优化步数: {args.max_train_steps}")
print(f"运行次数: {args.n_runs}")
# 创建优化器
optimizer = GlazeOptimizer(args, device)
# 处理每张图像
# 逐张处理,并按参数拼接输出文件名,便于回溯实验条件
for img_idx in range(len(dataset)):
img_data = dataset[img_idx]
img_path = img_data['path']
@ -644,13 +603,10 @@ def main(args):
best_result = None
# 多次运行取最佳结果
# 多次运行可缓解随机性影响;保持原逻辑:以最后一次成功结果作为输出
for run_idx in range(args.n_runs):
try:
cloaked_img = optimizer.process_image(original_img, run_idx)
# 简单起见,这里取最后一次运行的结果
# 完整版本应该用 CLIP 评估选择最佳结果
best_result = cloaked_img
except Exception as e:
@ -660,16 +616,14 @@ def main(args):
continue
if best_result is not None:
# 保存结果
orig_name = Path(img_path).stem
orig_ext = Path(img_path).suffix
# 构建输出文件名
intensity_str = f"intensity{args.intensity}" if args.intensity else f"eps{int(args.eps*255)}"
output_name = f"{orig_name}_glazed_{intensity_str}_steps{args.max_train_steps}{orig_ext}"
output_path = os.path.join(args.output_dir, output_name)
# 保存图像
# 按扩展名选择保存格式,避免某些格式默认压缩带来额外失真
if output_path.lower().endswith('.png'):
best_result.save(output_path, 'PNG')
else:

@ -14,90 +14,88 @@ from diffusers import AutoencoderKL
def parse_args(input_args=None):
"""
配置解析函数定义模型路径数据集位置攻击参数等命令行输入
"""
parser = argparse.ArgumentParser(description="Simple example of a training script.")
# 基础模型与路径配置
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
help="HuggingFace 模型标识或本地预训练模型路径",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
help="指定模型的特定版本(如 branch, tag 或 commit id",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
help="包含训练实例图像的文件夹路径",
)
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
help="训练结果和检查点的保存目录",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
# 训练超参数配置
parser.add_argument("--seed", type=int, default=None, help="用于可复现训练的随机种子")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
help="输入图像的分辨率,所有图像将调整为此大小",
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
help="是否对图像进行中心裁剪,否则进行随机裁剪",
)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of updating steps",
help="最大训练更新步数",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
help="数据加载的子进程数0 表示在主进程中加载",
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--local_rank", type=int, default=-1, help="分布式训练的本地排名")
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
"--enable_xformers_memory_efficient_attention", action="store_true", help="是否启用 xformers 以优化内存占用"
)
# 对抗扰动攻击专用参数
parser.add_argument(
'--eps',
type=float,
default=12.75,
help='pertubation budget'
help='扰动预算限制(基于 255 像素刻度)'
)
parser.add_argument(
'--step_size',
type=float,
default=1/255,
help='step size of each update'
help='每一迭代步的扰动更新步长'
)
parser.add_argument(
'--attack_type',
choices=['var', 'mean', 'KL', 'add-log', 'latent_vector', 'add'],
help='what is the attack target'
help='对抗攻击的目标类型(如方差、均值或 KL 散度)'
)
if input_args is not None:
@ -105,6 +103,7 @@ def parse_args(input_args=None):
else:
args = parser.parse_args()
# 处理分布式环境下的 rank 变量
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@ -114,8 +113,7 @@ def parse_args(input_args=None):
class PIDDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and the tokenizes prompts.
数据集类:负责加载图像、处理 EXIF 信息并应用预处理变换。
"""
def __init__(
@ -128,6 +126,8 @@ class PIDDataset(Dataset):
self.center_crop = center_crop
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
# 图像预处理流水线:缩放 -> 裁剪 -> 转换为张量
self.image_transforms = transforms.Compose([
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
@ -139,8 +139,11 @@ class PIDDataset(Dataset):
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
# 自动修正图像的方向(基于 EXIF 元数据)
instance_image = exif_transpose(instance_image)
# 统一强制转换为 RGB 格式
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
@ -150,19 +153,23 @@ class PIDDataset(Dataset):
def main(args):
# Set random seed
"""
主训练流程:初始化模型、生成对抗扰动并进行 PGD 优化。
"""
# 设定随机种子以保证实验的可重复性
if args.seed is not None:
torch.manual_seed(args.seed)
weight_dtype = torch.float32
device = torch.device('cuda')
# VAE encoder
# 初始化 VAE 编码器(保持冻结,仅用于提取特征或计算损失)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
vae.requires_grad_(False)
vae.to(device, dtype=weight_dtype)
# Dataset and DataLoaders creation:
# 创建数据集和数据加载器Batch Size 固定为 1 以适配扰动一一对应关系)
dataset = PIDDataset(
instance_data_root=args.instance_data_dir,
size=args.resolution,
@ -170,103 +177,117 @@ def main(args):
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1, # some parts of code don't support batching
batch_size=1,
shuffle=True,
num_workers=args.dataloader_num_workers,
)
# Wrapper of the perturbations generator
# 对抗攻击模型封装
class AttackModel(torch.nn.Module):
def __init__(self):
super().__init__()
to_tensor = transforms.ToTensor()
self.epsilon = args.eps/255
# 为数据集中每一张图初始化一个随机微小扰动Delta
self.delta = [torch.empty_like(to_tensor(Image.open(path))).uniform_(-self.epsilon, self.epsilon)
for path in dataset.instance_images_path]
self.size = dataset.size
def forward(self, vae, x, index, poison=False):
# Check whether we need to add perturbation
# 若处于攻击模式,则给输入图像加上扰动张量
if poison:
self.delta[index].requires_grad_(True)
x = x + self.delta[index].to(dtype=weight_dtype)
# Normalize to [-1, 1]
# 归一化图像到 [-1, 1] 区间,符合 VAE 输入要求
input_x = 2 * x - 1
return vae.encode(input_x.to(device))
attackmodel = AttackModel()
# Just to zero-out the gradient
# 定义优化器(注意:此处 LR 为 0实际更新通过手动 PGD 符号梯度完成)
optimizer = torch.optim.SGD(attackmodel.delta, lr=0)
# Progress bar
# 设置进度条
progress_bar = tqdm(range(0, args.max_train_steps), desc="Steps")
# Make sure the dir exists
# 确保输出目录存在
os.makedirs(args.output_dir, exist_ok=True)
# Start optimizing the perturbation
# 核心优化循环
for step in progress_bar:
total_loss = 0.0
for batch in dataloader:
# Save images
if step%25 == 0:
# 定期保存添加扰动后的图像,以便观察视觉效果
if step % 25 == 0:
to_image = transforms.ToPILImage()
for i in range(0, len(dataset.instance_images_path)):
img = dataset[i]['pixel_values']
img = to_image(img + attackmodel.delta[i])
img.save(os.path.join(args.output_dir, f"{i}.png"))
# Select target loss
# 分别计算原始图像和中毒(添加扰动)图像在潜空间的分布
clean_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], False)
poison_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], True)
clean_latent = clean_embedding.latent_dist
poison_latent = poison_embedding.latent_dist
# 根据攻击类型计算相应的损失函数(旨在拉开或改变分布特征)
if args.attack_type == 'var':
loss = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean")
elif args.attack_type == 'mean':
loss = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean")
elif args.attack_type == 'KL':
# 计算两个正态分布之间的 KL 散度
sigma_2, mu_2 = poison_latent.std, poison_latent.mean
sigma_1, mu_1 = clean_latent.std, clean_latent.mean
KL_diver = torch.log(sigma_2 / sigma_1) - 0.5 + (sigma_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * sigma_2 ** 2)
loss = KL_diver.flatten().mean()
elif args.attack_type == 'latent_vector':
# 直接对采样后的潜向量计算 MSE
clean_vector = clean_latent.sample()
poison_vector = poison_latent.sample()
loss = F.mse_loss(clean_vector, poison_vector, reduction="mean")
elif args.attack_type == 'add':
# 同时攻击均值和标准差
loss_2 = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean")
loss_1 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean")
loss = loss_1 + loss_2
elif args.attack_type == 'add-log':
# 攻击对数方差和均值(数值更稳定的变体)
loss_1 = F.mse_loss(clean_latent.var.log(), poison_latent.var.log(), reduction="mean")
loss_2 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction='mean')
loss = loss_1 + loss_2
# 清除梯度并执行反向传播
optimizer.zero_grad()
loss.backward()
# Perform PGD update on the loss
# 执行 PGD (Projected Gradient Descent) 更新步骤
delta = attackmodel.delta[batch['index']]
delta.requires_grad_(False)
# 沿梯度上升方向更新(最大化损失),实现攻击效果
delta += delta.grad.sign() * args.step_size
# 约束 1将扰动范围裁剪在 epsilon 预算内
delta = torch.clamp(delta, -attackmodel.epsilon, attackmodel.epsilon)
# 约束 2确保最终生成的图像像素值在 [0, 1] 合法区间内
delta = torch.clamp(delta, -batch['pixel_values'].detach().cpu(), 1-batch['pixel_values'].detach().cpu())
# 写回更新后的扰动并移除 Batch 维度
attackmodel.delta[batch['index']] = delta.detach().squeeze(0)
total_loss += loss.detach().cpu()
# Logging steps
# 更新进度条状态栏
logs = {"loss": total_loss.item()}
progress_bar.set_postfix(**logs)
if __name__ == "__main__":
args = parse_args()
main(args)
main(args)
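补充说明:上文 attack_type 为 KL 时手写的对角高斯 KL 散度公式,可用下面这个与仓库代码无关的最小验证示意进行核对(数据为随机构造,仅用于确认公式的方向与常数项):
import torch
from torch.distributions import Normal, kl_divergence

mu_1, sigma_1 = torch.randn(4), torch.rand(4) + 0.1   # 干净分布参数(示例数据)
mu_2, sigma_2 = torch.randn(4), torch.rand(4) + 0.1   # 中毒分布参数(示例数据)

# 与 pid.py 中相同的手写公式KL(N1 || N2)
manual = torch.log(sigma_2 / sigma_1) - 0.5 + (sigma_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * sigma_2 ** 2)
# torch.distributions 的参考实现
reference = kl_divergence(Normal(mu_1, sigma_1), Normal(mu_2, sigma_2))

print(torch.allclose(manual, reference, atol=1e-6))   # 预期输出 True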

@ -31,18 +31,14 @@ logger = get_logger(__name__)
def _cuda_gc() -> None:
"""Try to release unreferenced CUDA memory and reduce fragmentation.
This is a best-effort helper. It does not change algorithmic behavior but can
make long runs less prone to OOM due to fragmentation/reserved-memory growth.
"""
"""尽力释放未引用的 CUDA 内存,降低碎片化风险,不改变算法行为。"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
class DreamBoothDatasetFromTensor(Dataset):
"""Just like DreamBoothDataset, but take instance_images_tensor instead of path."""
"""基于内存张量的 DreamBooth 数据集:直接返回图像张量与 prompt token减少磁盘 IO。"""
def __init__(
self,
@ -114,6 +110,7 @@ class DreamBoothDatasetFromTensor(Dataset):
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
# 依据 text_encoder 配置识别架构,加载对应实现
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
@ -133,6 +130,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def parse_args(input_args=None):
# 解析全量参数:模型、数据、对抗超参、先验保持、训练与日志设置
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
@ -364,7 +362,7 @@ def parse_args(input_args=None):
class PromptDataset(Dataset):
"""A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
"""多 GPU 生成 class 图像的提示词数据集"""
def __init__(self, prompt, num_samples):
self.prompt = prompt
@ -485,11 +483,11 @@ def train_one_epoch(
f"instance_loss: {instance_loss.detach().item()}"
)
# Best-effort: free per-step tensors earlier (no behavior change).
# 尽早释放当前步的中间张量,降低显存占用
del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states
del model_pred, target, loss, prior_loss, instance_loss
# Best-effort: release optimizer state + dataset refs sooner.
# 释放优化器与数据集引用,进一步回收显存
del optimizer, train_dataset, params_to_optimize
_cuda_gc()
@ -497,6 +495,7 @@ def train_one_epoch(
def set_unet_attr(unet):
# 覆写若干 up_block 的 resnet forward捕获中间特征以供特征对齐损失使用
def conv_forward(self):
def forward(input_tensor, temb):
self.in_layers_features = input_tensor
@ -554,6 +553,7 @@ def set_unet_attr(unet):
def save_feature_maps(up_blocks, down_blocks):
# 收集指定 up_block 的输出特征,用于对抗攻击中的特征对齐
out_layers_features_list_3 = []
res_3_list = [0, 1, 2]
@ -577,11 +577,7 @@ def pgd_attack(
num_steps: int,
time_list,
):
"""Return new perturbed data.
Note: This function keeps the external behavior identical, but tries to reduce
memory pressure by freeing tensors early and avoiding lingering references.
"""
"""PGD 对抗扰动:按预选时间步迭代更新图像,可附加特征对齐正则;尝试提前释放无用张量。"""
unet, text_encoder = models
weight_dtype = torch.bfloat16
device = torch.device("cuda")
@ -594,7 +590,6 @@ def pgd_attack(
perturbed_images = data_tensor.detach().clone()
perturbed_images.requires_grad_(True)
# Keep input_ids on CPU; move to GPU only when encoding.
input_ids = tokenizer(
args.instance_prompt,
truncation=True,
@ -611,6 +606,7 @@ def pgd_attack(
noise = torch.randn_like(latents)
# 为每个样本从其时间步列表中随机选择一个时间步
timesteps = []
for i in range(len(data_tensor)):
ts = time_list[i]
@ -635,6 +631,7 @@ def pgd_attack(
noise_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks)
# 计算干净样本的对应特征用于对齐(不反传)
with torch.no_grad():
clean_latents = vae.encode(data_tensor.to(device, dtype=weight_dtype)).latent_dist.sample()
clean_latents = clean_latents * vae.config.scaling_factor
@ -652,8 +649,7 @@ def pgd_attack(
text_encoder.zero_grad(set_to_none=True)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Keep original behavior: feature loss does not backprop (added as Python float).
loss = loss + target_loss.detach().item()
loss = loss + target_loss.detach().item() # 特征对齐损失保持为常数项,不反传
loss.backward()
alpha = args.pgd_alpha
@ -666,7 +662,7 @@ def pgd_attack(
f"PGD loss - step {step}, loss: {loss.detach().item()}, target_loss : {target_loss.detach().item()}"
)
# Best-effort: free per-step tensors early.
# 尽早释放当前步的中间张量
del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target
del noise_out_layers_features_3, clean_latents, noisy_clean_latents, clean_out_layers_features_3
del target_loss, loss, adv_images, eta
@ -685,10 +681,7 @@ def select_timestep(
original_images: torch.Tensor,
target_tensor: torch.Tensor,
):
"""Return timestep lists for each image.
External behavior unchanged; add best-effort per-loop cleanup to lower memory pressure.
"""
"""为每张图选择一个时间步列表:通过多次梯度采样筛掉部分时间步,减少攻击开销,同时保持外部行为不变。"""
unet, text_encoder = models
weight_dtype = torch.bfloat16
device = torch.device("cuda")
@ -721,6 +714,7 @@ def select_timestep(
select_mask = torch.where(input_mask == 1, True, False)
res_time_seq = torch.masked_select(time_seq, select_mask)
# 如果剩余时间步仍多,随机抽取部分时间步估计梯度,删除一段时间步区间
if len(res_time_seq) > 100:
min_score, max_score = 0.0, 0.0
for inner_try in range(0, 5):
@ -776,6 +770,7 @@ def select_timestep(
del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss, score
# 删除一段时间步以缩小候选集合,并记录当前选中的时间步
print("del_t", del_t, "max_t", select_t)
if del_t < args.delta_t:
del_t = args.delta_t
@ -832,6 +827,7 @@ def select_timestep(
def setup_seeds():
# 设置统一随机种子并关闭 cudnn 非确定性,保证结果可复现
seed = 42
random.seed(seed)
np.random.seed(seed)
@ -869,7 +865,7 @@ def main(args):
set_seed(args.seed)
setup_seeds()
# Generate class images if prior preservation is enabled.
# 先验保持:若 class 图像不足,则通过基础 pipeline 生成补齐
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
class_images_dir.mkdir(parents=True, exist_ok=True)
@ -1024,7 +1020,7 @@ def main(args):
time_list,
)
# Free surrogate ASAP (best-effort, behavior unchanged).
# 及时释放 surrogate保持显存占用稳定
del f_sur
_cuda_gc()
@ -1061,7 +1057,7 @@ def main(args):
print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)")
# Best-effort cleanup at the end of each outer iteration.
# 外层迭代结束后的清理
_cuda_gc()
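补充说明pid.py 与本文件 pgd_attack 中的 PGD 更新都遵循同一套流程:符号梯度上升、投影回 eps 球、再裁剪到合法像素区间。下面是一个脱离具体仓库的最小示意,其中函数名 pgd_step 与各张量均为说明用途的假设:
import torch

def pgd_step(delta: torch.Tensor, grad: torch.Tensor, x: torch.Tensor,
             step_size: float, epsilon: float) -> torch.Tensor:
    # 1) 沿符号梯度上升,最大化目标损失
    delta = delta + step_size * grad.sign()
    # 2) 投影回 eps 球,满足扰动预算
    delta = torch.clamp(delta, -epsilon, epsilon)
    # 3) 保证 x + delta 仍落在 [0, 1] 的合法像素区间
    delta = torch.clamp(delta, -x, 1.0 - x)
    return delta.detach()

x = torch.rand(3, 64, 64)            # 示例图像张量
delta = torch.zeros_like(x)
grad = torch.randn_like(x)           # 假设已由 loss.backward() 得到
delta = pgd_step(delta, grad, x, step_size=1/255, epsilon=12.75/255)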

@ -0,0 +1,132 @@
import pandas as pd
from pathlib import Path
import sys
import logging
import numpy as np
import statsmodels.api as sm
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def apply_lowess_and_clipping_scaling(input_csv_path, output_csv_path, lowess_frac, target_range, clipping_percentile):
"""
应用 Lowess 局部加权回归进行平滑以提取总体趋势,然后使用百分位数裁剪后的 Min-Max 边界进行线性缩放。
目标:生成最平滑、最接近单调下降的客观趋势。
"""
input_path = Path(input_csv_path)
output_path = Path(output_csv_path)
if not input_path.exists():
logging.error(f"错误:未找到输入文件 {input_csv_path}")
return
logging.info(f"读取原始数据: {input_csv_path}")
df = pd.read_csv(input_path)
df = df.loc[:,~df.columns.duplicated()].copy()
# 定义原始数据列名
raw_x_col = 'X_Feature_L2_Norm'
raw_y_col = 'Y_Feature_Variance'
raw_z_col = 'Z_LDM_Loss'
# --------------------------- 1. Lowess 局部加权回归平滑 (提取总体趋势) ---------------------------
logging.info(f"应用 Lowess 局部加权回归,平滑因子 frac={lowess_frac}")
x_coords = df['step'].values
for raw_col in [raw_x_col, raw_y_col, raw_z_col]:
y_coords = df[raw_col].values
smoothed_data = sm.nonparametric.lowess(
endog=y_coords,
exog=x_coords,
frac=lowess_frac,
it=0
)
df[f'{raw_col}_LOWESS'] = smoothed_data[:, 1]
# --------------------------- 2. 百分位数边界缩放与方向统一 ---------------------------
p = clipping_percentile
logging.info(f"应用百分位数边界 (p={p}) 进行线性缩放,目标范围 [0, {target_range:.2f}]")
scale_cols_map = {
'X_Feature_L2_Norm': f'{raw_x_col}_LOWESS',
'Y_Feature_Variance': f'{raw_y_col}_LOWESS',
'Z_LDM_Loss': f'{raw_z_col}_LOWESS'
}
for final_col, lowess_col in scale_cols_map.items():
data = df[lowess_col]
# 裁剪:计算裁剪后的 min/max (定义缩放窗口)
lower_bound = data.quantile(p)
upper_bound = data.quantile(1.0 - p)
min_val = lower_bound
max_val = upper_bound
data_range = max_val - min_val
if data_range <= 0 or np.isnan(data_range):
df[final_col] = 0.0
logging.warning(f"{final_col} 裁剪后的范围为 {data_range:.4f},跳过缩放。")
continue
# 归一化: (data - Min_window) / Range_window
normalized_data = (data - min_val) / data_range
# 优化方向统一逻辑(所有指标都应是越小越好):
if final_col in ['X_Feature_L2_Norm', 'Y_Feature_Variance']:
# X/Y 反转:将 Max 映射到 0Min 映射到 TargetRange
final_scaled_data = (1.0 - normalized_data) * target_range
else: # Z_LDM_Loss
# Z 标准缩放Min 映射到 0Max 映射到 TargetRange
final_scaled_data = normalized_data * target_range
# 保留负值,以确保平滑过渡
df[final_col] = final_scaled_data
logging.info(f" - 列 {final_col}:裁剪边界: [{min_val:.4f}, {max_val:.4f}]。缩放后范围不再严格约束 [0, {target_range:.2f}],以保留趋势。")
# --------------------------- 3. 最终保存 ---------------------------
output_path.parent.mkdir(parents=True, exist_ok=True)
final_cols = ['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss']
df[final_cols].to_csv(
output_path,
index=False,
float_format='%.3f'
)
logging.info(f"Lowess平滑和缩放后的数据已保存到: {output_csv_path}")
if __name__ == '__main__':
if len(sys.argv) != 6:
logging.error("使用方法: python smooth_coords.py <输入CSV路径> <输出CSV路径> <Lowess 平滑因子 frac (例如 0.4)> <目标视觉范围 (例如 30)> <离散点裁剪百分比 (例如 0.15)>")
else:
input_csv = sys.argv[1]
output_csv = sys.argv[2]
try:
lowess_frac = float(sys.argv[3])
target_range = float(sys.argv[4])
clipping_p = float(sys.argv[5])
if not (0.0 < lowess_frac <= 1.0):
raise ValueError("Lowess 平滑因子 frac 必须在 (0.0, 1.0] 之间。")
if target_range <= 0:
raise ValueError("目标视觉范围必须大于 0。")
if not (0 <= clipping_p < 0.5):
raise ValueError("裁剪百分比必须在 [0, 0.5) 之间。")
if not Path(output_csv).suffix:
output_csv = str(Path(output_csv) / "scaled_coords.csv")
apply_lowess_and_clipping_scaling(input_csv, output_csv, lowess_frac, target_range, clipping_p)
except ValueError as e:
logging.error(f"参数错误: {e}")
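补充说明:脚本核心的两步变换(先用 Lowess 提取趋势,再按百分位边界线性缩放并统一方向)可用下面基于人工构造数据的最小示意概括,参数取值仅为示例:
import numpy as np
import statsmodels.api as sm

steps = np.arange(200, dtype=float)
raw = np.exp(-steps / 80) + 0.05 * np.random.randn(200)        # 一条带噪声的下降曲线

# 第一步Lowess 提取总体趋势(返回按 exog 排序的 [exog, fitted] 两列)
trend = sm.nonparametric.lowess(endog=raw, exog=steps, frac=0.4, it=0)[:, 1]

# 第二步:按百分位边界线性缩放,并反转方向(示例按“越小越好”处理)
p = 0.15
lo, hi = np.quantile(trend, p), np.quantile(trend, 1.0 - p)
normalized = (trend - lo) / (hi - lo)                          # 允许越界,以保留趋势
scaled = (1.0 - normalized) * 30.0                             # 目标视觉范围约 30
命令行用法示意路径与参数均为示例python smooth_coords.py input.csv output.csv 0.4 30 0.15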

@ -0,0 +1,149 @@
"""
图片处理功能:把原始图片裁剪为中心正方形、缩放到指定分辨率并保存为指定格式,还可以选择是否按序号重命名。
"""
import argparse
import os
from pathlib import Path
from PIL import Image
# --- 1. 参数解析 ---
def parse_args(input_args=None):
"""
解析命令行参数
"""
parser = argparse.ArgumentParser(description="Image Processor for Centering, Resizing, and Format Conversion.")
# 路径和分辨率参数
parser.add_argument(
"--input_dir",
type=str,
required=True,
help="A folder containing the original images to be processed and overwritten.",
)
parser.add_argument(
"--resolution",
type=int,
default=512,
help="The target resolution (width and height) for the output images (e.g., 512 for 512x512).",
)
# 格式参数
parser.add_argument(
"--target_format",
type=str,
default="png",
choices=["jpeg", "png", "webp", "jpg"],
help="The target format for the saved images (e.g., 'png', 'jpg', 'webp'). The original file will be overwritten, potentially changing the file extension.",
)
# 序列化数字重命名参数
parser.add_argument(
"--rename_sequential",
action="store_true", # 当这个参数存在时,其值为 True
help="If set, images will be sequentially renamed (e.g., 001.jpg, 002.jpg...) instead of preserving the original filename. WARNING: This WILL delete the originals.",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
return args
# --- 2. 核心图像处理逻辑 ---
def process_image(image_path: Path, output_path: Path, resolution: int, target_format: str, delete_original: bool):
"""
加载图像,居中裁剪最大正方形,缩放到目标分辨率,并保存为目标格式。
Args:
image_path: 原始图片路径
output_path: 最终保存路径
resolution: 目标分辨率
target_format: 目标文件格式
delete_original: 是否删除原始文件
"""
try:
# 加载图像并统一转换为 RGB 模式
img = Image.open(image_path).convert("RGB")
# 居中取最大正方形
width, height = img.size
min_dim = min(width, height)
# 计算裁剪框 (以最短边为尺寸的中心正方形)
left = (width - min_dim) // 2
top = (height - min_dim) // 2
right = left + min_dim
bottom = top + min_dim
# 裁剪中心正方形
img = img.crop((left, top, right, bottom))
# 缩放到指定 resolution上采样或下采样
# 使用 LANCZOS 高质量重采样方法
img = img.resize((resolution, resolution), resample=Image.Resampling.LANCZOS)
# 准备输出格式
save_format = target_format.upper().replace('JPEG', 'JPG')
# 保存图片
# 对于 JPEG/JPG设置 quality 参数
if save_format == 'JPG':
img.save(output_path, format='JPEG', quality=95)
else:
img.save(output_path, format=save_format)
# 根据标记决定是否删除原始文件
if delete_original and image_path.resolve() != output_path.resolve():
os.remove(image_path)
print(f"Processed: {image_path.name} -> {output_path.name} ({resolution}x{resolution} {save_format})")
except Exception as e:
print(f"Error processing {image_path.name}: {e}")
# --- 3. 主函数 ---
def main(args):
# 路径准备
input_dir = Path(args.input_dir)
if not input_dir.is_dir():
print(f"Error: Input directory not found at {input_dir}")
return
# 查找所有图片文件 (支持 jpg, jpeg, png, webp)
valid_suffixes = ['.jpg', '.jpeg', '.png', '.webp']
image_paths = sorted([p for p in input_dir.iterdir() if p.suffix.lower() in valid_suffixes]) # 排序以确保重命名顺序一致
if not image_paths:
print(f"No image files found in {input_dir}")
return
print(f"Found {len(image_paths)} images in {input_dir}. Starting processing...")
# 准备目标格式的扩展名
extension = args.target_format.lower().replace('jpeg', 'jpg')
# 迭代处理图片
for i, image_path in enumerate(image_paths):
# 决定输出路径
if args.rename_sequential:
# 顺序重命名逻辑001, 002, 003... (至少三位数字)
new_name = f"{i + 1:03d}.{extension}"
output_path = input_dir / new_name
# 如果原始文件与新文件名称不同,则需要删除原始文件
delete_original = True
else:
# 保持原始文件名,但修改后缀
output_path = image_path.with_suffix(f'.{extension}')
# 只有当原始后缀与目标后缀不同时,才需要删除原始文件(防止遗留旧格式)
delete_original = (image_path.suffix.lower() != f'.{extension}')
process_image(image_path, output_path, args.resolution, args.target_format, delete_original)
print("Processing complete.")
if __name__ == "__main__":
args = parse_args()
main(args)
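补充说明:脚本中使用的 Image.Resampling.LANCZOS 依赖较新版本的 Pillow此处作为版本假设提示若运行环境较旧可参考下面的最小兼容示意回退到旧常量居中裁剪与缩放逻辑保持不变
from PIL import Image

try:
    LANCZOS = Image.Resampling.LANCZOS   # 较新 Pillow 提供的枚举写法
except AttributeError:
    LANCZOS = Image.LANCZOS              # 旧版本 Pillow 的等价常量

def center_square_resize(img: Image.Image, resolution: int) -> Image.Image:
    # 以最短边为尺寸居中裁剪正方形,再缩放到目标分辨率
    w, h = img.size
    m = min(w, h)
    left, top = (w - m) // 2, (h - m) // 2
    return img.crop((left, top, left + m, top + m)).resize((resolution, resolution), resample=LANCZOS)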

@ -1,3 +1,4 @@
#需要环境conda activate caat
export HF_HUB_OFFLINE=1
# 强制使用本地模型缓存,避免联网下载模型
#export HF_HOME="/root/autodl-tmp/huggingface_cache"

@ -1,5 +1,4 @@
#!/bin/bash
#需要环境conda activate pid
#=============================================================================
# Glaze 风格保护攻击脚本
# 用于保护艺术作品免受 AI 模型的风格模仿

@ -1,5 +1,4 @@
#!/bin/bash
#需要环境conda activate pid
#=============================================================================
# Glaze 风格保护攻击脚本
# 用于保护艺术作品免受 AI 模型的风格模仿

@ -4,11 +4,17 @@
export HF_HUB_OFFLINE=1
# 强制使用本地模型缓存,避免联网下载模型
### SD v2.1
# export HF_HOME="/root/autodl-tmp/huggingface_cache"
# export MODEL_PATH="stabilityai/stable-diffusion-2-1"
### SD v1.5
# export HF_HOME="/root/autodl-tmp/huggingface_cache"
# export MODEL_PATH="runwayml/stable-diffusion-v1-5"
export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
export TASKNAME="task003"
export TASKNAME="task001"
### Data to be protected
export INSTANCE_DIR="../../static/originals/${TASKNAME}"
### Path to save the protected data
@ -26,8 +32,7 @@ echo "Clearing output directory: $OUTPUT_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$OUTPUT_DIR" -mindepth 1 -delete
export PYTHONWARNINGS="ignore"
#忽略所有警告
### Generation command
# --max_train_steps: Optimization steps
@ -41,6 +46,6 @@ CUDA_VISIBLE_DEVICES=0 python ../algorithms/pid.py \
--resolution=512 \
--max_train_steps=1000 \
--center_crop \
--eps 12 \
--step_size 0.002 \
--attack_type add-log
--eps 12.75 \
--attack_type add-log
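补充说明:--eps 12.75 基于 255 像素刻度pid.py 内部会将其除以 255 换算到 [0, 1] 区间,下面的小算例仅用于演示该换算:
eps_255 = 12.75
eps_01 = eps_255 / 255    # = 0.05,即每个像素最多偏移取值范围的 5%
print(eps_01)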

@ -1,46 +0,0 @@
#需要环境conda activate pid
### Generate images protected by PID
export HF_HUB_OFFLINE=1
# 强制使用本地模型缓存,避免联网下载模型
### SD v1.5
export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
export TASKNAME="task003"
### Data to be protected
export INSTANCE_DIR="../../static/originals/${TASKNAME}"
### Path to save the protected data
export OUTPUT_DIR="../../static/perturbed/${TASKNAME}"
# ------------------------- 自动创建依赖路径 -------------------------
echo "Creating required directories..."
mkdir -p "$INSTANCE_DIR"
mkdir -p "$OUTPUT_DIR"
echo "Directories created successfully."
# ------------------------- 训练前清空 OUTPUT_DIR -------------------------
echo "Clearing output directory: $OUTPUT_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$OUTPUT_DIR" -mindepth 1 -delete
export PYTHONWARNINGS="ignore"
#忽略所有警告
### Generation command
# --max_train_steps: Optimizaiton steps
# --attack_type: target loss to update, choices=['var', 'mean', 'KL', 'add-log', 'latent_vector', 'add'],
# Please refer to the file content for more usage
CUDA_VISIBLE_DEVICES=0 python ../algorithms/pid.py \
--pretrained_model_name_or_path=$MODEL_PATH \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--max_train_steps=120 \
--eps 16 \
--step_size 0.01 \
--attack_type add-log \
--center_crop

@ -1,3 +1,4 @@
#需要环境conda activate pid
# ----------------- 1. 环境与模型配置 -----------------
# 强制 Hugging Face 库使用本地模型缓存 (离线模式)

@ -1,3 +1,4 @@
#需要环境conda activate pid
# ----------------- 1. 环境与路径配置 -----------------
export TASKNAME="task001"
