Merge lianghao_branch into develop #10

Merged
hnu202326010204 merged 5 commits from lianghao_branch into develop 1 month ago

@ -1,520 +0,0 @@
"""Stable Diffusion 注意力热力图差异可视化工具 (可靠版 - 语义阶段聚合)。
本模块使用一种健壮的方法通过在 Stable Diffusion 扩散模型U-Net
**早期时间步 (语义阶段)** 捕获并累加交叉注意力权重这种方法能确保捕获到的
注意力图信号集中且可靠用于对比分析干净输入和扰动输入生成的图像对模型
注意力机制的影响差异
典型用法:
python eva_gen_heatmap.py \\
--model_path /path/to/sd_model \\
--image_path_a /path/to/clean_image.png \\
--image_path_b /path/to/noisy_image.png \\
--prompt_text "a photo of sks person" \\
--target_word "sks" \\
--output_dir output/heatmap_reports
"""
# Argument parsing and file-path handling
import argparse
import os
from pathlib import Path
from typing import Dict, Any, List, Tuple

# Numerics and deep-learning dependencies
import torch
import torch.nn.functional as F
import numpy as np
import itertools
import warnings

# Visualization dependencies
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Diffusers and Transformers dependencies
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import Attention
from transformers import CLIPTokenizer

# Image processing and data loading
from PIL import Image
from torchvision import transforms

# Suppress non-essential warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# ============== Core module: attention capture and aggregation ==============
class AttentionMapProcessor:
    """Custom attention processor that captures the weights of U-Net cross-attention layers.

    By replacing the processors of the original `Attention` modules, this class
    captures and stores the attention weights (`attention_probs`) of every
    cross-attention layer during the model's forward pass.

    Attributes:
        attention_maps (Dict[str, List[torch.Tensor]]): Captured attention maps.
            Keys are layer names; values are the maps captured for that layer
            at the different timesteps.
        pipeline (StableDiffusionPipeline): The Stable Diffusion pipeline being processed.
        original_processors (Dict[str, Any]): The original attention processors,
            kept so they can be restored later.
        current_layer_name (Optional[str]): Name of the attention layer currently
            being processed.
    """

    def __init__(self, pipeline: StableDiffusionPipeline):
        """Initializes the attention processor.

        Args:
            pipeline: A Stable Diffusion pipeline instance.
        """
        self.attention_maps: Dict[str, List[torch.Tensor]] = {}
        self.pipeline = pipeline
        self.original_processors = {}
        self.current_layer_name = None
        self._set_processors()
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Runs the attention computation and captures the weights.

        This method replaces the original `Attention.processor` and performs the
        capture whenever cross-attention is computed.

        Args:
            attn: The current `Attention` module instance.
            hidden_states: U-Net hidden states (query).
            encoder_hidden_states: Text-encoder output (key/value), i.e. the
                cross-attention input.
            attention_mask: Attention mask.

        Returns:
            The resulting output hidden states.
        """
        # For self-attention (encoder_hidden_states is None), delegate to the saved
        # original processor. `attn.processor` now points back at this instance,
        # so calling it here would recurse forever.
        if encoder_hidden_states is None:
            original = self.original_processors[self.current_layer_name]
            return original(
                attn, hidden_states, encoder_hidden_states, attention_mask
            )
        # 1. Compute Q, K, V
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        # 2. Prepare for the matrix multiplication
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        # 3. Compute attention scores (Q @ K^T)
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0], query.shape[1], key.shape[1],
                dtype=query.dtype, device=query.device
            ),
            query,
            key.transpose(1, 2),
            beta=0,
            alpha=attn.scale,
        )
        # 4. Compute attention probabilities
        attention_probs = attention_scores.softmax(dim=-1)
        layer_name = self.current_layer_name
        # 5. Store the captured attention map
        if layer_name not in self.attention_maps:
            self.attention_maps[layer_name] = []
        # Store this timestep's attention weights
        self.attention_maps[layer_name].append(attention_probs.detach().cpu())
        # 6. Compute the output (attention @ V)
        value = attn.head_to_batch_dim(value)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # 7. Output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
    def _set_processors(self):
        """Registers the custom processor on every U-Net cross-attention layer.

        Walks all U-Net submodules, finds the cross-attention layers (`Attention`
        modules whose name contains `attn2`), and replaces their processors with
        this instance.
        """
        for name, module in self.pipeline.unet.named_modules():
            if isinstance(module, Attention) and 'attn2' in name:
                # Keep the original processor so it can be restored later
                self.original_processors[name] = module.processor

                # Build a closure that records the current layer name before dispatching
                def set_layer_name(current_name):
                    def new_call(*args, **kwargs):
                        self.current_layer_name = current_name
                        return self.__call__(*args, **kwargs)
                    return new_call

                module.processor = set_layer_name(name)

    def remove(self):
        """Restores the U-Net's original attention processors and clears the hooks."""
        for name, original_processor in self.original_processors.items():
            module = self.pipeline.unet.get_submodule(name)
            module.processor = original_processor
        self.attention_maps = {}
def aggregate_word_attention(
    attention_maps: Dict[str, List[torch.Tensor]],
    tokenizer: CLIPTokenizer,
    target_word: str,
    input_ids: torch.Tensor
) -> np.ndarray:
    """Aggregates the target word's attention across all layers and semantic timesteps, then normalizes.

    Aggregation steps:
        1. Find the token indices that correspond to the target word.
        2. For each layer, average the attention maps over all captured timesteps.
        3. Extract the sub-maps for the target tokens, sum over the token
           dimension, and average over the attention heads.
        4. Upsample maps of different resolutions to a common size (64x64).
        5. Sum the per-layer results.
        6. Normalize the final map to [0, 1].

    Args:
        attention_maps: Dictionary of attention maps captured per layer and timestep.
        tokenizer: CLIP tokenizer instance.
        target_word: The keyword to focus on.
        input_ids: Token ID tensor for the prompt.

    Returns:
        The final aggregated attention heatmap, upsampled to 64x64 (NumPy array).

    Raises:
        ValueError: If the target word cannot be found in the prompt.
        RuntimeError: If no attention data was captured.
    """
    # 1. Find the target word's token indices
    prompt_tokens = tokenizer.convert_ids_to_tokens(
        input_ids.squeeze().cpu().tolist()
    )
    target_lower = target_word.lower()
    target_indices = []
    for i, token in enumerate(prompt_tokens):
        # CLIP's BPE tokens carry a '</w>' end-of-word suffix; strip it before matching
        cleaned_token = token.replace('</w>', '').replace('Ġ', '').replace('_', '').lower()
        # Match the target word (or tokens starting with it), skipping special tokens
        if (input_ids.squeeze()[i].item() not in tokenizer.all_special_ids and
                (target_lower in cleaned_token or
                 cleaned_token.startswith(target_lower))):
            target_indices.append(i)
    if not target_indices:
        print(f"[WARN] Target word '{target_word}' not found. Check the prompt or target word.")
        raise ValueError("Could not identify token indices for the target word.")
    # 2. Aggregation
    all_attention_data = []
    # Largest U-Net resolution: 64x64, i.e. 4096 spatial positions
    TARGET_SPATIAL_SIZE = 4096
    TARGET_MAP_SIZE = 64
    for layer_name, step_maps in attention_maps.items():
        if not step_maps:
            continue
        # Average this layer over all captured timesteps; shape: (batch, heads, spatial_res, target_tokens_len)
        avg_map_over_time = torch.stack(step_maps).mean(dim=0)
        # Drop the batch dimension (assumes batch size = 1); shape: (heads, spatial_res, target_tokens_len)
        attention_map = avg_map_over_time.squeeze(0)
        # Extract the target tokens' attention; shape: (heads, spatial_res, target_indices_len)
        target_token_maps = attention_map[:, :, target_indices]
        # Sum over target tokens (dim=-1), average over heads (dim=0); shape: (spatial_res,)
        aggregated_map_flat = target_token_maps.sum(dim=-1).mean(dim=0).float()
        # 3. Cross-resolution upsampling
        if aggregated_map_flat.shape[0] != TARGET_SPATIAL_SIZE:
            # Current map size: e.g. 16x16 (256) or 32x32 (1024)
            map_size = int(np.sqrt(aggregated_map_flat.shape[0]))
            map_2d = aggregated_map_flat.reshape(map_size, map_size)
            map_to_interp = map_2d.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
            # Bilinear upsampling to 64x64
            resized_map_2d = F.interpolate(
                map_to_interp,
                size=(TARGET_MAP_SIZE, TARGET_MAP_SIZE),
                mode='bilinear',
                align_corners=False
            )
            resized_map_flat = resized_map_2d.squeeze().flatten()
            all_attention_data.append(resized_map_flat)
        else:
            # Already 64x64: use as-is
            all_attention_data.append(aggregated_map_flat)
    if not all_attention_data:
        raise RuntimeError("No attention data was captured. Check the model or parameter settings.")
    # 4. Sum across all layers
    final_map_flat = torch.stack(all_attention_data).sum(dim=0).cpu().numpy()
    # 5. Final normalization to [0, 1]
    final_map_flat = final_map_flat / (final_map_flat.max() + 1e-6)
    map_size = int(np.sqrt(final_map_flat.shape[0]))
    final_map_np = final_map_flat.reshape(map_size, map_size)  # 64x64
    return final_map_np
def get_attention_map_from_image(
    pipeline: StableDiffusionPipeline,
    image_path: str,
    prompt_text: str,
    target_word: str
) -> Tuple[Image.Image, np.ndarray]:
    """Runs a multi-timestep forward pass and captures the attention map for an image/prompt pair.

    Only the semantic stage (the early timesteps) of the diffusion process is
    run, which keeps the captured attention weights high-signal.

    Args:
        pipeline: A Stable Diffusion pipeline instance.
        image_path: Path of the input image to process.
        prompt_text: The prompt text used to generate the image.
        target_word: The keyword to focus on and visualize.

    Returns:
        A tuple of (original image, final upsampled attention map).
    """
    print(f"\n-> Processing image: {Path(image_path).name}")
    image = Image.open(image_path).convert("RGB").resize((512, 512))
    image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    image_tensor = (
        image_transform(image)
        .unsqueeze(0)
        .to(pipeline.device)
        .to(pipeline.unet.dtype)
    )
    # 1. Encode into latent space
    with torch.no_grad():
        latent = (
            pipeline.vae.encode(image_tensor).latent_dist.sample() *
            pipeline.vae.config.scaling_factor
        )
    # 2. Encode the prompt
    text_input = pipeline.tokenizer(
        prompt_text,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )
    input_ids = text_input.input_ids
    with torch.no_grad():
        # Get the text embeddings
        prompt_embeds = pipeline.text_encoder(
            input_ids.to(pipeline.device)
        )[0]
    # 3. Define the semantic timesteps
    scheduler = pipeline.scheduler
    # Set the number of diffusion steps (e.g. 50)
    scheduler.set_timesteps(50, device=pipeline.device)
    # Capture only the 10 earliest, most semantically rich steps
    semantic_steps = scheduler.timesteps[:10]
    print(f"-> Capturing attention over {len(semantic_steps)} semantic-stage timesteps...")
    processor = AttentionMapProcessor(pipeline)
    try:
        # 4. Run the multi-step U-Net forward pass
        with torch.no_grad():
            # Run U-Net predictions at the selected semantic timesteps
            for t in semantic_steps:
                pipeline.unet(latent, t, prompt_embeds, return_dict=False)
        # 5. Aggregate the captured data
        raw_map_np = aggregate_word_attention(
            processor.attention_maps,
            pipeline.tokenizer,
            target_word,
            input_ids
        )
    except Exception as e:
        print(f"[ERROR] Attention aggregation failed: {e}")
        raw_map_np = np.zeros(image.size)
    finally:
        # Always restore the original processors (clean up the hooks)
        processor.remove()
    # 6. Upsample the attention map to the image size (512x512)
    # Upsampling via PIL
    heat_map_pil = Image.fromarray((raw_map_np * 255).astype(np.uint8))
    heat_map_np_resized = (
        np.array(heat_map_pil.resize(
            image.size,
            resample=Image.Resampling.LANCZOS  # high-quality Lanczos filter
        )) / 255.0
    )
    return image, heat_map_np_resized
def main():
    """Parses arguments, loads the model, computes the difference, and renders the report."""
    parser = argparse.ArgumentParser(description="SD image attention difference report generator")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Local path of the Stable Diffusion model.")
    parser.add_argument("--image_path_a", type=str, required=True,
                        help="Path of the clean input image (X).")
    parser.add_argument("--image_path_b", type=str, required=True,
                        help="Path of the perturbed input image (X').")
    parser.add_argument("--prompt_text", type=str, default="a photo of sks person",
                        help="Prompt text used to generate the images.")
    parser.add_argument("--target_word", type=str, default="sks",
                        help="Keyword to focus on and visualize in the attention maps.")
    parser.add_argument("--output_dir", type=str, default="output",
                        help="Output directory for the report PNG.")
    args = parser.parse_args()
    print("--- Generating Stable Diffusion attention difference report ---")
    # ---------------- Prepare the model ----------------
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float16 if device == 'cuda' else torch.float32
    try:
        # Load the Stable Diffusion pipeline
        pipe = StableDiffusionPipeline.from_pretrained(
            args.model_path,
            torch_dtype=dtype,
            local_files_only=True,
            safety_checker=None,
            # Load the scheduler config from its subfolder
            scheduler=DPMSolverMultistepScheduler.from_pretrained(args.model_path, subfolder="scheduler")
        ).to(device)
    except Exception as e:
        print(f"[ERROR] Failed to load the model; check the path and environment: {e}")
        return
    # ---------------- Collect the data ----------------
    # Attention map M_A for the clean image A
    img_A, map_A = get_attention_map_from_image(pipe, args.image_path_a, args.prompt_text, args.target_word)
    # Attention map M_B for the perturbed image B
    img_B, map_B = get_attention_map_from_image(pipe, args.image_path_b, args.prompt_text, args.target_word)
    if map_A.shape != map_B.shape:
        print("Error: attention map shapes do not match. Aborting.")
        return
    # Difference map: Delta = M_A - M_B
    diff_map = map_A - map_B
    # L2 norm (difference distance)
    l2_diff = np.linalg.norm(diff_map)
    print(f"\nDone. L2 norm of the attention map difference: {l2_diff:.4f}")
    # ---------------- Render the report ----------------
    # Matplotlib font setup
    plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['DejaVu Serif', 'Times New Roman', 'serif'],
        'mathtext.fontset': 'cm'
    })
    fig = plt.figure(figsize=(12, 16), dpi=120)
    # 3x4 grid for precise control over images and legends
    gs = gridspec.GridSpec(3, 4, figure=fig,
                           height_ratios=[1, 1, 1.3],
                           hspace=0.3, wspace=0.1)
    # --- Row 1: original images ---
    ax_img_a = fig.add_subplot(gs[0, 0:2])
    ax_img_b = fig.add_subplot(gs[0, 2:4])
    # Clean image
    ax_img_a.imshow(img_A)
    ax_img_a.set_title(f"Clean Image ($X$)\nFilename: {Path(args.image_path_a).name}", fontsize=14, pad=10)
    ax_img_a.axis('off')
    # Perturbed image
    ax_img_b.imshow(img_B)
    ax_img_b.set_title(f"Noisy Image ($X'$)\nFilename: {Path(args.image_path_b).name}", fontsize=14, pad=10)
    ax_img_b.axis('off')
    # --- Row 2: attention heatmaps (jet colormap) ---
    ax_map_a = fig.add_subplot(gs[1, 0:2])
    ax_map_b = fig.add_subplot(gs[1, 2:4])
    # Attention map A
    im_map_a = ax_map_a.imshow(map_A, cmap='jet', vmin=0, vmax=1)
    ax_map_a.set_title(f"Attention Heatmap ($M_X$)\nTarget: \"{args.target_word}\"", fontsize=14, pad=10)
    ax_map_a.axis('off')
    # Attention map B
    im_map_b = ax_map_b.imshow(map_B, cmap='jet', vmin=0, vmax=1)
    ax_map_b.set_title(f"Attention Heatmap ($M_{{X'}}$)\nTarget: \"{args.target_word}\"", fontsize=14, pad=10)
    ax_map_b.axis('off')
    # Colorbar for attention map B
    divider = make_axes_locatable(ax_map_b)
    cax_map = divider.append_axes("right", size="5%", pad=0.05)
    cbar1 = fig.colorbar(im_map_b, cax=cax_map)
    cbar1.set_label('Attention Intensity', fontsize=10)
    # --- Row 3: difference comparison (centered) ---
    # The difference map occupies the middle two grid columns
    ax_diff = fig.add_subplot(gs[2, 1:3])
    vmax_diff = np.max(np.abs(diff_map))
    # TwoSlopeNorm keeps zero at the center of the colorbar
    norm_diff = TwoSlopeNorm(vmin=-vmax_diff, vcenter=0., vmax=vmax_diff)
    # Coolwarm colormap: blue marks negative differences (M_X' > M_X), red positive (M_X > M_X')
    im_diff = ax_diff.imshow(diff_map, cmap='coolwarm', norm=norm_diff)
    title_text = (
        r"Difference Map: $\Delta = M_X - M_{X'}$" + "\n" +
        rf"$L_2$ Norm Distance: $\mathbf{{{l2_diff:.4f}}}$"
    )
    ax_diff.set_title(title_text, fontsize=16, pad=12)
    ax_diff.axis('off')
    # Difference-map colorbar (center aligned)
    cbar2 = fig.colorbar(im_diff, ax=ax_diff, fraction=0.046, pad=0.04)
    cbar2.set_label(r'Scale: Red ($+$) $\leftrightarrow$ Blue ($-$)', fontsize=12)
    # ---------------- Final touches and saving ----------------
    fig.suptitle("Museguard: SD Attention Analysis Report", fontsize=20, fontweight='bold', y=0.95)
    output_filename = "heatmap_dif.png"
    output_path = Path(args.output_dir) / output_filename
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, bbox_inches='tight', facecolor='white')
    print(f"\nAnalysis report saved to:\n{output_path.resolve()}")

if __name__ == "__main__":
    main()
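Reviewer note: the heart of aggregate_word_attention is the upsample-and-sum step. As a standalone sketch (dummy tensors only, not code from this PR), the math reduces to:

import torch
import torch.nn.functional as F

# Dummy flattened per-layer maps at 16x16 (256) and 32x32 (1024) resolution
layer_maps = [torch.rand(256), torch.rand(1024)]
accumulated = torch.zeros(64 * 64)
for flat in layer_maps:
    size = int(flat.shape[0] ** 0.5)
    # Reshape to [1, 1, H, W] and bilinearly upsample to the 64x64 target grid
    resized = F.interpolate(flat.reshape(1, 1, size, size),
                            size=(64, 64), mode='bilinear', align_corners=False)
    accumulated += resized.flatten()
# Normalize the summed map to [0, 1], exactly as the script does
final_map = (accumulated / (accumulated.max() + 1e-6)).reshape(64, 64)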

@ -1,513 +0,0 @@
"""图像生成质量多维度评估工具 (专业重构版)。
本脚本用于对比评估两组图像Clean vs Perturbed的生成质量
支持生成包含指标对比表和深度差异分析的 PNG 报告
Style Guide: Google Python Style Guide
"""
import os
import time
import subprocess
import tempfile
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Dict, Optional, Tuple, Any

import torch
import clip
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image
from torchvision import transforms
from facenet_pytorch import MTCNN, InceptionResnetV1
from piq import ssim, psnr
import torch_fidelity as fid

# Suppress non-essential warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# -----------------------------------------------------------------------------
# Global configuration and styling
# -----------------------------------------------------------------------------
# Matplotlib LaTeX-flavored style
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['DejaVu Serif', 'Times New Roman', 'serif'],
    'mathtext.fontset': 'cm',
    'axes.unicode_minus': False
})

# Metric metadata: target direction and analysis thresholds per metric
METRIC_ANALYSIS_META = {
    'FID': {'higher_is_better': False, 'th': [2.0, 10.0, 30.0]},
    'SSIM': {'higher_is_better': True, 'th': [0.01, 0.05, 0.15]},
    'PSNR': {'higher_is_better': True, 'th': [0.5, 2.0, 5.0]},
    'FDS': {'higher_is_better': True, 'th': [0.02, 0.05, 0.1]},
    'CLIP_IQS': {'higher_is_better': True, 'th': [0.01, 0.03, 0.08]},
    'BRISQUE': {'higher_is_better': False, 'th': [2.0, 5.0, 10.0]},
}

# Degradation weights used by the holistic analysis
ANALYSIS_WEIGHTS = {'Severe': 3, 'Significant': 2, 'Slight': 1, 'Negligible': 0}
# -----------------------------------------------------------------------------
# Model loading (lazy or global preloading)
# -----------------------------------------------------------------------------
_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
try:
    CLIP_MODEL, CLIP_PREPROCESS = clip.load('ViT-B/32', _DEVICE)
    CLIP_MODEL.eval()
except Exception as e:
    print(f"[Warning] Failed to load the CLIP model: {e}")
    CLIP_MODEL, CLIP_PREPROCESS = None, None

def _get_clip_text_features(text: str) -> torch.Tensor:
    """Helper: computes CLIP features for a text string."""
    if CLIP_MODEL is None:
        return None
    tokens = clip.tokenize(text).to(_DEVICE)
    with torch.no_grad():
        features = CLIP_MODEL.encode_text(tokens)
        features /= features.norm(dim=-1, keepdim=True)
    return features
# -----------------------------------------------------------------------------
# Core computation
# -----------------------------------------------------------------------------
def calculate_metrics(
    ref_dir: str,
    gen_dir: str,
    image_size: int = 512
) -> Dict[str, float]:
    """Computes a set of quality metrics between two image sets.

    Covers FDS, SSIM, PSNR, CLIP_IQS, and FID.

    Args:
        ref_dir: Directory of reference images.
        gen_dir: Directory of generated images.
        image_size: Processing resolution.

    Returns:
        A dict of metric names and values; an empty dict if the directories are invalid.
    """
    metrics = {}

    # 1. Data loading
    def load_images(directory):
        imgs = []
        if os.path.exists(directory):
            for f in os.listdir(directory):
                if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                    try:
                        path = os.path.join(directory, f)
                        imgs.append(Image.open(path).convert("RGB"))
                    except Exception:
                        pass
        return imgs

    ref_imgs = load_images(ref_dir)
    gen_imgs = load_images(gen_dir)
    if not ref_imgs or not gen_imgs:
        print(f"[Error] Failed to load images or directory is empty: \nRef: {ref_dir}\nGen: {gen_dir}")
        return {}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        # --- FDS (Face Detection Similarity) ---
        print(">>> Computing FDS...")
        mtcnn = MTCNN(image_size=image_size, margin=0, device=device)
        resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)

        def get_face_embeds(img_list):
            embeds = []
            for img in img_list:
                face = mtcnn(img)
                if face is not None:
                    embeds.append(resnet(face.unsqueeze(0).to(device)))
            return torch.stack(embeds) if embeds else None

        ref_embeds = get_face_embeds(ref_imgs)
        gen_embeds = get_face_embeds(gen_imgs)
        if ref_embeds is not None and gen_embeds is not None:
            # Mean cosine similarity of each generated face against all reference faces
            sims = []
            for g_emb in gen_embeds:
                sim = torch.cosine_similarity(g_emb, ref_embeds).mean()
                sims.append(sim)
            metrics['FDS'] = torch.tensor(sims).mean().item()
        else:
            metrics['FDS'] = 0.0
        # Free GPU memory
        del mtcnn, resnet
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # --- SSIM & PSNR ---
        print(">>> Computing SSIM & PSNR...")
        tfm = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])
        # Stack the reference set into [N, C, H, W]
        ref_tensor = torch.stack([tfm(img) for img in ref_imgs]).to(device)
        ssim_accum, psnr_accum = 0.0, 0.0
        for img in gen_imgs:
            gen_tensor = tfm(img).unsqueeze(0).to(device)  # [1, C, H, W]
            # Broadcast to match the reference set
            gen_expanded = gen_tensor.expand_as(ref_tensor)
            # Mean structural similarity of one generated image against the whole reference set
            val_ssim = ssim(gen_expanded, ref_tensor, data_range=1.0)
            val_psnr = psnr(gen_expanded, ref_tensor, data_range=1.0)
            ssim_accum += val_ssim.item()
            psnr_accum += val_psnr.item()
        metrics['SSIM'] = ssim_accum / len(gen_imgs)
        metrics['PSNR'] = psnr_accum / len(gen_imgs)
        # --- CLIP IQS ---
        print(">>> Computing CLIP IQS...")
        if CLIP_MODEL:
            iqs_accum = 0.0
            txt_feat = _get_clip_text_features("good image")
            for img in gen_imgs:
                img_tensor = CLIP_PREPROCESS(img).unsqueeze(0).to(device)
                img_feat = CLIP_MODEL.encode_image(img_tensor)
                img_feat /= img_feat.norm(dim=-1, keepdim=True)
                iqs_accum += (img_feat @ txt_feat.T).item()
            metrics['CLIP_IQS'] = iqs_accum / len(gen_imgs)
        else:
            metrics['CLIP_IQS'] = np.nan
        # --- FID ---
        print(">>> Computing FID...")
        try:
            fid_res = fid.calculate_metrics(
                input1=ref_dir,
                input2=gen_dir,
                cuda=True,
                fid=True,
                verbose=False
            )
            metrics['FID'] = fid_res['frechet_inception_distance']
        except Exception as e:
            print(f"[Error] FID computation failed: {e}")
            metrics['FID'] = np.nan
    return metrics
def run_brisque_cleanly(img_dir: str) -> float:
    """Runs the external BRISQUE script cleanly via subprocess and a temporary directory.

    Args:
        img_dir: Image directory path.

    Returns:
        The BRISQUE score, or NaN on failure.
    """
    print(">>> Computing BRISQUE (external)...")
    script_path = Path(__file__).parent / 'libsvm' / 'python' / 'brisquequality.py'
    if not script_path.exists():
        print(f"[Error] BRISQUE script not found: {script_path}")
        return np.nan
    abs_img_dir = os.path.abspath(img_dir)
    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            cmd = [
                "python", str(script_path),
                abs_img_dir,
                temp_dir
            ]
            # Run from the script's own directory
            subprocess.run(
                cmd,
                cwd=script_path.parent,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE
            )
            # Read the log file written to the temporary directory
            log_file = Path(temp_dir) / 'log.txt'
            if log_file.exists():
                content = log_file.read_text(encoding='utf-8').strip()
                try:
                    return float(content.split()[-1])
                except ValueError:
                    return float(content)
            else:
                return np.nan
        except Exception as e:
            print(f"[Error] BRISQUE execution failed: {e}")
            return np.nan
# -----------------------------------------------------------------------------
# Report visualization and analysis
# -----------------------------------------------------------------------------
def analyze_metric_diff(
    metric_name: str,
    clean_val: float,
    pert_val: float
) -> Tuple[str, str, str]:
    """Produces a graded, principled difference analysis.

    Args:
        metric_name: Metric name.
        clean_val: Score of the clean set.
        pert_val: Score of the perturbed set.

    Returns:
        (header arrow symbol, difference description, severity grade)
    """
    cfg = METRIC_ANALYSIS_META.get(metric_name)
    if not cfg:
        return "-", "Configuration not found.", "Negligible"
    diff = pert_val - clean_val
    abs_diff = abs(diff)
    # Decide whether the change is an improvement or a degradation
    is_better = (cfg['higher_is_better'] and diff > 0) or (not cfg['higher_is_better'] and diff < 0)
    is_worse = not is_better
    # Grade the magnitude
    th = cfg['th']
    if abs_diff < th[0]:
        degree = "Negligible"
    elif abs_diff < th[1]:
        degree = "Slight"
    elif abs_diff < th[2]:
        degree = "Significant"
    else:
        degree = "Severe"
    # Compose the wording
    header_arrow = r"$\uparrow$" if cfg['higher_is_better'] else r"$\downarrow$"
    if degree == "Negligible":
        analysis_text = f"Negligible change (diff < {th[0]:.4f})."
    elif is_worse:
        analysis_text = f"{degree} degradation."
    else:
        analysis_text = f"Unexpected {degree} change."
    return header_arrow, analysis_text, degree
def generate_visual_report(
    ref_dir: str,
    clean_dir: str,
    pert_dir: str,
    clean_metrics: Dict,
    pert_metrics: Dict,
    output_path: str
):
    """Renders and saves the comparison analysis report (PNG)."""
    def get_sample(d):
        if not os.path.exists(d):
            return None, "N/A"
        files = [f for f in os.listdir(d) if f.lower().endswith(('.png', '.jpg'))]
        if not files:
            return None, "Empty"
        return Image.open(os.path.join(d, files[0])).convert("RGB"), files[0]

    img_ref, name_ref = get_sample(ref_dir)
    img_clean, name_clean = get_sample(clean_dir)
    img_pert, name_pert = get_sample(pert_dir)
    # Layout: extra height to accommodate the text block
    fig = plt.figure(figsize=(12, 16.5), dpi=120)
    gs = gridspec.GridSpec(3, 2, height_ratios=[1, 1, 1.5], hspace=0.25, wspace=0.1)
    # 1. Image panels
    ax_ref = fig.add_subplot(gs[0, :])
    if img_ref:
        ax_ref.imshow(img_ref)
    ax_ref.set_title(f"Reference Image ($X$)\n{name_ref}", fontsize=12, fontweight='bold', pad=10)
    ax_ref.axis('off')
    ax_c = fig.add_subplot(gs[1, 0])
    if img_clean:
        ax_c.imshow(img_clean)
    ax_c.set_title(f"Clean Output ($Y$)\n{name_clean}", fontsize=12, fontweight='bold', pad=10)
    ax_c.axis('off')
    ax_p = fig.add_subplot(gs[1, 1])
    if img_pert:
        ax_p.imshow(img_pert)
    ax_p.set_title(f"Perturbed Output ($Y'$)\n{name_pert}", fontsize=12, fontweight='bold', pad=10)
    ax_p.axis('off')
    # 2. Data table and analysis panel
    ax_data = fig.add_subplot(gs[2, :])
    ax_data.axis('off')
    metrics_list = ['FID', 'SSIM', 'PSNR', 'FDS', 'CLIP_IQS', 'BRISQUE']
    table_data = []
    analysis_lines = []
    degradation_score = 0
    # Walk the metrics, building the table rows and analysis lines
    for m in metrics_list:
        c_val = clean_metrics.get(m, np.nan)
        p_val = pert_metrics.get(m, np.nan)
        c_str = f"{c_val:.4f}" if not np.isnan(c_val) else "N/A"
        p_str = f"{p_val:.4f}" if not np.isnan(p_val) else "N/A"
        diff_str = "-"
        header_arrow = ""
        if not np.isnan(c_val) and not np.isnan(p_val):
            # Graded analysis
            header_arrow, text_desc, degree = analyze_metric_diff(m, c_val, p_val)
            # Raw difference
            diff = p_val - c_val
            # Sign of the difference itself (Diff > 0 or Diff < 0)
            diff_arrow = r"$\nearrow$" if diff > 0 else r"$\searrow$"
            if abs(diff) < 1e-4:
                diff_arrow = r"$\rightarrow$"
            diff_str = f"{diff:+.4f} {diff_arrow}"
            analysis_lines.append(f"{m}: Change {diff:+.4f}. Analysis: {text_desc}")
            # Accumulate the degradation score
            cfg = METRIC_ANALYSIS_META.get(m)
            is_worse = (cfg['higher_is_better'] and diff < 0) or (not cfg['higher_is_better'] and diff > 0)
            if is_worse:
                degradation_score += ANALYSIS_WEIGHTS.get(degree, 0)
        # First table column: metric name plus the desired-direction arrow
        name_with_arrow = f"{m} ({header_arrow})" if header_arrow else m
        table_data.append([name_with_arrow, c_str, p_str, diff_str])
    # Draw the table
    table = ax_data.table(
        cellText=table_data,
        colLabels=["Metric (Goal)", "Clean ($Y$)", "Perturbed ($Y'$)", r"Diff ($\Delta$)"],
        loc='upper center',
        cellLoc='center',
        colWidths=[0.25, 0.25, 0.25, 0.25]
    )
    table.scale(1, 2.0)
    table.set_fontsize(11)
    # Style the header row and first column
    for (row, col), cell in table.get_celld().items():
        if row == 0:
            cell.set_text_props(weight='bold', color='white')
            cell.set_facecolor('#404040')
        elif col == 0:
            cell.set_text_props(weight='bold')
            cell.set_facecolor('#f5f5f5')
    # 3. Bottom summary text block
    if not analysis_lines:
        analysis_lines.append("• All metrics are missing or invalid.")
    full_text = "Quantitative Difference Analysis:\n" + "\n".join(analysis_lines)
    # Overall verdict (based on the holistic degradation score)
    conclusion = "\n\n>>> EXECUTIVE SUMMARY (Holistic Judgment):\n"
    if degradation_score >= 8:
        conclusion += "CRITICAL DEGRADATION. Significant quality loss observed. Attack highly effective."
    elif degradation_score >= 4:
        conclusion += "MODERATE DEGRADATION. Observable quality drop in key metrics. Attack effective."
    elif degradation_score > 0:
        conclusion += "MINOR DEGRADATION. Slight quality loss detected. Attack partially effective."
    else:
        conclusion += "INEFFECTIVE ATTACK. No significant or unexpected statistical quality loss observed."
    full_text += conclusion
    ax_data.text(
        0.05,
        0.30,
        full_text,
        ha='left',
        va='top',
        fontsize=12, family='monospace', wrap=True,
        transform=ax_data.transAxes
    )
    fig.suptitle("Museguard: Quality Assurance Report", fontsize=18, fontweight='bold', y=0.95)
    plt.savefig(output_path, bbox_inches='tight', facecolor='white')
    print(f"\n[Success] Report generated: {output_path}")
# -----------------------------------------------------------------------------
# Entry point
# -----------------------------------------------------------------------------
def main():
    parser = ArgumentParser()
    parser.add_argument('--clean_output_dir', type=str, required=True)
    parser.add_argument('--perturbed_output_dir', type=str, required=True)
    parser.add_argument('--clean_ref_dir', type=str, required=True)
    parser.add_argument('--png_output_path', type=str, required=True)
    parser.add_argument('--size', type=int, default=512)
    args = parser.parse_args()
    Path(args.png_output_path).parent.mkdir(parents=True, exist_ok=True)
    print("========================================")
    print("  Image Quality Evaluation Toolkit")
    print("========================================")
    # 1. Evaluate the clean set
    print(f"\n[1/2] Evaluating Clean Set: {os.path.basename(args.clean_output_dir)}")
    c_metrics = calculate_metrics(args.clean_ref_dir, args.clean_output_dir, args.size)
    if c_metrics:
        c_metrics['BRISQUE'] = run_brisque_cleanly(args.clean_output_dir)
    # 2. Evaluate the perturbed set
    print(f"\n[2/2] Evaluating Perturbed Set: {os.path.basename(args.perturbed_output_dir)}")
    p_metrics = calculate_metrics(args.clean_ref_dir, args.perturbed_output_dir, args.size)
    if p_metrics:
        p_metrics['BRISQUE'] = run_brisque_cleanly(args.perturbed_output_dir)
    # 3. Generate the report
    if c_metrics and p_metrics:
        generate_visual_report(
            args.clean_ref_dir,
            args.clean_output_dir,
            args.perturbed_output_dir,
            c_metrics,
            p_metrics,
            args.png_output_path
        )
    else:
        print("\n[Fatal] Evaluation data incomplete; aborting report generation.")

if __name__ == '__main__':
    main()
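Reviewer note: the grading logic in analyze_metric_diff and the holistic score are easiest to see on a worked example (dummy numbers, not results from this PR):

# FID thresholds from METRIC_ANALYSIS_META: bounds for Negligible/Slight/Significant
th = [2.0, 10.0, 30.0]
clean_fid, pert_fid = 18.3, 31.9
abs_diff = abs(pert_fid - clean_fid)  # 13.6
for degree, bound in zip(["Negligible", "Slight", "Significant"], th):
    if abs_diff < bound:
        break
else:
    degree = "Severe"
# 13.6 falls below the 30.0 bound -> "Significant"; since FID rose (worse),
# ANALYSIS_WEIGHTS adds 2 to the degradation score, nudging the verdict
# toward "MODERATE DEGRADATION" once other metrics contribute.
print(degree)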

@ -102,7 +102,7 @@ def login():
return jsonify({'error': 'Account has been disabled'}), 401
# Create the access token - make sure the user ID is a string
access_token = create_access_token(identity=str(user.id))
access_token = create_access_token(identity=str(user.user_id))
return jsonify({
'message': 'Login successful',

@ -143,7 +143,7 @@ def create_perturbation_task(current_user_id):
return TaskService.json_error('Invalid flow_id parameter')
try:
pending_status = TaskService.ensure_status('pending')
waiting_status = TaskService.ensure_status('waiting')
perturb_type = TaskService.require_task_type('perturbation')
except ValueError as exc:
return TaskService.json_error(str(exc), 500)
@ -153,7 +153,7 @@ def create_perturbation_task(current_user_id):
flow_id=flow_id,
tasks_type_id=perturb_type.task_type_id,
user_id=current_user_id,
tasks_status_id=pending_status.task_status_id,
tasks_status_id=waiting_status.task_status_id,
description=data.get('description')
)
db.session.add(task)
@ -270,9 +270,14 @@ def create_heatmap_task(current_user_id):
if not perturbed_image or perturbed_image.task_id != perturbation_task_id:
return TaskService.json_error('Perturbed image does not exist or does not belong to this task')
if perturbed_image.image_type:
image_code = perturbed_image.image_type.image_code
if image_code != 'perturbed':
return TaskService.json_error(f'Only perturbed images can be used to generate heatmaps; current image type: {perturbed_image.image_type.image_name}', 400)
try:
heatmap_type = TaskService.require_task_type('heatmap')
pending_status = TaskService.ensure_status('pending')
waiting_status = TaskService.ensure_status('waiting')
except ValueError as exc:
return TaskService.json_error(str(exc), 500)
@ -281,7 +286,7 @@ def create_heatmap_task(current_user_id):
flow_id=perturbation_task.flow_id,
tasks_type_id=heatmap_type.task_type_id,
user_id=current_user_id,
tasks_status_id=pending_status.task_status_id,
tasks_status_id=waiting_status.task_status_id,
description=data.get('description')
)
db.session.add(task)
@ -311,7 +316,7 @@ def start_heatmap_task(task_id, current_user_id):
if not task.heatmap:
return TaskService.json_error('Heatmap task has no associated image', 400)
job_id = TaskService.start_heatmap_task(task_id, task.heatmap.images_id)
job_id = TaskService.start_heatmap_task(task_id)
if not job_id:
return TaskService.json_error('Failed to start task', 500)
return jsonify({'message': 'Task started', 'job_id': job_id}), 200
@ -372,7 +377,7 @@ def create_finetune_from_perturbation(current_user_id):
return TaskService.json_error('Finetune config does not exist')
try:
pending_status = TaskService.ensure_status('pending')
waiting_status = TaskService.ensure_status('waiting')
finetune_type = TaskService.require_task_type('finetune')
except ValueError as exc:
return TaskService.json_error(str(exc), 500)
@ -382,7 +387,7 @@ def create_finetune_from_perturbation(current_user_id):
flow_id=perturbation_task.flow_id,
tasks_type_id=finetune_type.task_type_id,
user_id=current_user_id,
tasks_status_id=pending_status.task_status_id,
tasks_status_id=waiting_status.task_status_id,
description=data.get('description')
)
db.session.add(task)
@ -443,7 +448,7 @@ def create_finetune_from_upload(current_user_id):
return TaskService.json_error('Invalid flow_id parameter')
try:
pending_status = TaskService.ensure_status('pending')
waiting_status = TaskService.ensure_status('waiting')
finetune_type = TaskService.require_task_type('finetune')
except ValueError as exc:
return TaskService.json_error(str(exc), 500)
@ -453,7 +458,7 @@ def create_finetune_from_upload(current_user_id):
flow_id=flow_id,
tasks_type_id=finetune_type.task_type_id,
user_id=current_user_id,
tasks_status_id=pending_status.task_status_id,
tasks_status_id=waiting_status.task_status_id,
description=data.get('description')
)
db.session.add(task)
@ -487,7 +492,24 @@ def start_finetune_task(task_id, current_user_id):
job_id = TaskService.start_finetune_task(task_id)
if not job_id:
return TaskService.json_error('Failed to start task', 500)
return jsonify({'message': 'Task started', 'job_id': job_id}), 200
# Handle the returned job_id (a single id, or several joined by commas)
if ',' in job_id:
# Perturbation-based finetune: two job_ids are returned
job_ids = job_id.split(',')
return jsonify({
'message': 'Finetune tasks started (original and perturbed images)',
'job_id': job_id,
'job_ids': job_ids,
'type': 'perturbation-based'
}), 200
else:
# Finetune on uploaded images: a single job_id is returned
return jsonify({
'message': 'Finetune task started',
'job_id': job_id,
'type': 'uploaded'
}), 200
@task_bp.route('/finetune', methods=['GET'])
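Reviewer note on the dual job_id change above: a client now has to branch on the returned type field. A hypothetical consumer (the URL and route are illustrative, not from this PR) would look like:

import requests

resp = requests.post('http://localhost:5000/api/tasks/finetune/1/start').json()
if resp.get('type') == 'perturbation-based':
    # Two RQ jobs, e.g. 'ft_1_original' and 'ft_1_perturbed'
    original_job, perturbed_job = resp['job_ids']
else:
    # Single upload-based job
    original_job, perturbed_job = resp['job_id'], None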
@ -545,7 +567,7 @@ def create_evaluate_task(current_user_id):
try:
evaluate_type = TaskService.require_task_type('evaluate')
pending_status = TaskService.ensure_status('pending')
waiting_status = TaskService.ensure_status('waiting')
except ValueError as exc:
return TaskService.json_error(str(exc), 500)
@ -554,7 +576,7 @@ def create_evaluate_task(current_user_id):
flow_id=finetune_task.flow_id,
tasks_type_id=evaluate_type.task_type_id,
user_id=current_user_id,
tasks_status_id=pending_status.task_status_id,
tasks_status_id=waiting_status.task_status_id,
description=data.get('description')
)
db.session.add(task)

@ -130,7 +130,7 @@ class TaskStatus(db.Model):
"""任务状态表"""
__tablename__ = 'task_status'
task_status_id = db.Column(Integer, primary_key=True, autoincrement=True)
task_status_code = db.Column(String(50), nullable=False, comment='Status code (Pending, Processing, Done, Failed)')
task_status_code = db.Column(String(50), nullable=False, comment='Status code (Waiting, Processing, Completed, Failed)')
task_status_name = db.Column(String(100), nullable=False)
description = db.Column(Text)
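Reviewer note: this comment change (Pending/Done -> Waiting/Completed) implies existing rows may still hold the old codes. A one-off migration sketch along these lines would reconcile them (the old-to-new mapping and import paths are assumptions, not part of this PR):

from app import db                     # assumed app factory exposing the SQLAlchemy handle
from app.models import TaskStatus      # assumed model import path

RENAMES = {'pending': 'waiting', 'done': 'completed'}
for old_code, new_code in RENAMES.items():
    # Update any legacy rows in place; codes are stored lower-case elsewhere in this PR
    TaskStatus.query.filter_by(task_status_code=old_code).update(
        {'task_status_code': new_code})
db.session.commit()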

@ -115,6 +115,13 @@ class TaskService:
str(user_id),
str(flow_id)
)
@staticmethod
def get_model_data_path(user_id, flow_id):
"""Model data path: MODEL_DATA_FOLDER/user_id/flow_id"""
return TaskService._build_path(
Config.MODEL_DATA_FOLDER,
str(user_id),
str(flow_id)
)
# ==================== Common helpers ====================
@ -354,11 +361,16 @@ class TaskService:
logger.warning(f"Could not cancel RQ job: {e}")
# Update the task status in the database
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
if failed_status:
task.tasks_status_id = failed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
try:
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
if failed_status:
task.tasks_status_id = failed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
except Exception as e:
db.session.rollback()
logger.error(f"Failed to update task status: {e}")
return False
return True
@ -406,7 +418,7 @@ class TaskService:
logger.error(f"Perturbation config not found")
return None
algorithm_code = pert_config.perturbation_algorithm_code
algorithm_code = pert_config.perturbation_code
# Enqueue into the RQ queue
from app.workers.perturbation_worker import run_perturbation_task
@ -421,7 +433,7 @@ class TaskService:
output_dir=output_dir,
class_dir=class_dir,
algorithm_code=algorithm_code,
epsilon=pert_config.epsilon,
epsilon=perturbation.perturbation_intensity,
job_id=job_id,
job_timeout='4h'
)
@ -479,17 +491,19 @@ class TaskService:
return None
# Detect the finetune type: look for a perturbation task with the same flow_id
perturbation_tasks = Task.query.filter(
perturb_type = TaskService.require_task_type('perturbation')
sibling_perturbation = Task.query.filter(
Task.flow_id == task.flow_id,
Task.tasks_type_id == 1, # perturbation type
Task.tasks_type_id == perturb_type.task_type_id,
Task.tasks_id != task_id
).all()
).first()
has_perturbation = len(perturbation_tasks) > 0
has_perturbation = sibling_perturbation is not None
# Path configuration
input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
class_dir = TaskService.get_class_data_path(user_id, task.flow_id)
model_data_dir = TaskService.get_model_data_path(user_id, task.flow_id)
if has_perturbation:
# Type 1: finetune based on perturbation results
@ -499,52 +513,66 @@ class TaskService:
original_input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, task.flow_id, task_id)
original_output_dir = TaskService.get_original_generated_path(user_id, task.flow_id, task_id)
# Coordinate save path (for 3D visualization)
coords_save_path = TaskService._build_path(
original_coords_save_path = TaskService._build_path(
Config.COORDS_SAVE_FOLDER,
str(user_id),
str(task.flow_id),
str(task_id),
'coords.json'
'original_coords.json'
)
# Perturbed-coordinate save path (for 3D visualization)
perturbed_coords_save_path = TaskService._build_path(
Config.COORDS_SAVE_FOLDER,
str(user_id),
str(task.flow_id),
str(task_id),
'perturbed_coords.json'
)
# Enqueue into the RQ queue
from app.workers.finetune_worker import run_finetune_task
queue = TaskService._get_queue()
job_id = f"ft_{task_id}"
job_id_original = f"ft_{task_id}_original"
job_id_perturbed = f"ft_{task_id}_perturbed"
job_original = queue.enqueue(
run_finetune_task,
task_id=task_id,
finetune_method=ft_config.finetune_method,
tranin_images_dir=original_input_dir,
output_model_dir=original_output_dir,
finetune_method=ft_config.finetune_code,
train_images_dir=original_input_dir,
output_model_dir=model_data_dir,
class_dir=class_dir,
coords_save_path=coords_save_path,
validation_images_dir=original_output_dir,
coords_save_path=original_coords_save_path,
validation_output_dir=original_output_dir,
is_perturbed=False,
custom_params=None,
job_id=job_id,
job_id=job_id_original,
job_timeout='8h'
)
job_perturbed = queue.enqueue(
run_finetune_task,
task_id=task_id,
finetune_method=ft_config.finetune_method,
tranin_images_dir=perturbed_input_dir,
output_model_dir=perturbed_output_dir,
finetune_method=ft_config.finetune_code,
train_images_dir=perturbed_input_dir,
output_model_dir=model_data_dir,
class_dir=class_dir,
coords_save_path=coords_save_path,
validation_images_dir=perturbed_output_dir,
coords_save_path=perturbed_coords_save_path,
validation_output_dir=perturbed_output_dir,
is_perturbed=True,
custom_params=None,
job_id=job_id,
job_id=job_id_perturbed,
job_timeout='8h'
)
logger.info(f"Finetune task {task_id} enqueued with job_ids {job_id_original}, {job_id_perturbed}")
return f"{job_id_original},{job_id_perturbed}"
else:
# Type 2: finetune on user-uploaded images
logger.info(f"Finetune task {task_id}: type=uploaded")
@ -569,20 +597,20 @@ class TaskService:
job = queue.enqueue(
run_finetune_task,
task_id=task_id,
finetune_method=ft_config.finetune_method,
tranin_images_dir=input_dir,
output_model_dir=uploaded_output_dir,
finetune_method=ft_config.finetune_code,
train_images_dir=input_dir,
output_model_dir=model_data_dir,
class_dir=class_dir,
coords_save_path=coords_save_path,
validation_images_dir=uploaded_output_dir,
validation_output_dir=uploaded_output_dir,
is_perturbed=False,
custom_params=None,
job_id=job_id,
job_timeout='8h'
)
logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}")
return job_id
logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}")
return job_id
except Exception as e:
logger.error(f"Error starting finetune task: {e}")
@ -591,13 +619,12 @@ class TaskService:
# ==================== Heatmap tasks ====================
@staticmethod
def start_heatmap_task(task_id, perturbed_image_id):
def start_heatmap_task(task_id):
"""
Start a heatmap task
Args:
task_id: Task ID
perturbed_image_id: Perturbed image ID
Returns:
job_id
@ -615,13 +642,19 @@ class TaskService:
logger.error(f"Heatmap task {task_id} not found")
return None
# Get the perturbed image ID from the heatmap object
perturbed_image_id = heatmap.images_id
if not perturbed_image_id:
logger.error(f"Heatmap task {task_id} has no associated perturbed image")
return None
# Fetch the perturbed image record
perturbed_image = Image.query.get(perturbed_image_id)
if not perturbed_image:
logger.error(f"Perturbed image {perturbed_image_id} not found")
return None
user_id = perturbed_image.user_id
user_id = task.user_id
# Get the original image (via the father_id relation)
if not perturbed_image.father_id:
@ -633,31 +666,20 @@ class TaskService:
logger.error(f"Original image not found")
return None
# Build the image paths
original_image_path = TaskService._build_path(
Config.ORIGINAL_IMAGES_FOLDER,
str(user_id),
str(task.flow_id),
original_image.image_name
# Build the image paths (using stored_filename)
original_image_path = os.path.join(
TaskService.get_original_images_path(user_id, task.flow_id),
original_image.stored_filename
)
perturbed_image_path = TaskService._build_path(
Config.PERTURBED_IMAGES_FOLDER,
str(user_id),
str(task.flow_id),
perturbed_image.image_name
perturbed_image_path = os.path.join(
TaskService.get_perturbed_images_path(user_id, task.flow_id),
perturbed_image.stored_filename
)
# Output directory
output_dir = TaskService.get_heatmap_path(user_id, task.flow_id, task_id)
# Get the model path
sd_version = AlgorithmConfig.STABLE_DIFFUSION_VERSION
model_path = AlgorithmConfig.SD_MODEL_PATHS.get(sd_version)
if not model_path:
logger.error(f"Model path not found for SD version {sd_version}")
return None
# Enqueue into the RQ queue
from app.workers.heatmap_worker import run_heatmap_task
@ -670,7 +692,6 @@ class TaskService:
original_image_path=original_image_path,
perturbed_image_path=perturbed_image_path,
output_dir=output_dir,
model_path=model_path,
perturbed_image_id=perturbed_image_id,
job_id=job_id,
job_timeout='2h'

@ -18,7 +18,7 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images_dir,
def run_finetune_task(task_id, finetune_method, train_images_dir,
output_model_dir, class_dir, coords_save_path, validation_output_dir,
is_perturbed=False, custom_params=None):
"""
@ -26,7 +26,6 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images
Args:
task_id: Task ID
finetune_config_id: Finetune config ID
finetune_method: Finetune method (dreambooth, lora, textual_inversion)
train_images_dir: Training images directory
output_model_dir: Model output directory
@ -54,13 +53,9 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images
# Fetch the finetune task details
finetune = Finetune.query.filter_by(
tasks_id=task_id,
finetune_configs_id=finetune_config_id
tasks_id=task_id
).first()
if not finetune:
raise ValueError(f"Finetune task ({task_id}, {finetune_config_id}) not found")
# Mark the task as processing
processing_status = TaskStatus.query.filter_by(task_status_code='processing').first()
if processing_status:
@ -69,7 +64,6 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images
task.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Starting finetune task {task_id} (config: {finetune_config_id})")
logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}")
# Fetch the dataset-type prompt from the database
@ -89,7 +83,7 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images
# Clear the output directories (avoid stale leftovers)
logger.info(f"Clearing output directories...")
for dir_path in [output_model_dir, validation_output_dir, coords_save_path]:
for dir_path in [output_model_dir, validation_output_dir]:
if os.path.exists(dir_path):
for item in os.listdir(dir_path):
item_path = os.path.join(dir_path, item)
@ -97,6 +91,16 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images
os.unlink(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
# Remove any old coords.json file
if os.path.exists(coords_save_path):
if os.path.isfile(coords_save_path):
os.unlink(coords_save_path)
logger.info(f"Removed old coords file: {coords_save_path}")
# Make sure the parent directory of coords.json exists
coords_dir = os.path.dirname(coords_save_path)
os.makedirs(coords_dir, exist_ok=True)
# Run the real finetune algorithm
result = _run_real_finetune(
@ -208,10 +212,6 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
else:
raise ValueError(f"Unsupported finetune method: {finetune_method}")
# Add the is_perturbed flag
if is_perturbed:
cmd_args.append("--is_perturbed")
# Append the other default parameters
for key, value in params.items():
if isinstance(value, bool):
@ -268,29 +268,48 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
if process.returncode != 0:
raise RuntimeError(f"Finetune failed with code {process.returncode}. Check log: {log_file}")
# Clean class_dir and output_model_dir: delete all files to free disk space
logger.info(f"Cleaning class_dir and output_model_dir to save disk space...")
# Clean class_dir (mirrors the sh script)
# Clean class_dir
if finetune_method in ['dreambooth', 'lora']:
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
logger.info(f"Removing all files in class_dir: {class_dir}")
shutil.rmtree(class_dir)
os.makedirs(class_dir)
# Clean non-image files from output_model_dir
logger.info(f"Cleaning non-image files in output directory: {output_model_dir}")
# Clean output_model_dir: delete all model files; only validation images (in validation_output_dir) are kept
if os.path.exists(output_model_dir):
import shutil
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'}
for item in os.listdir(output_model_dir):
item_path = os.path.join(output_model_dir, item)
if os.path.isfile(item_path):
_, ext = os.path.splitext(item)
if ext.lower() not in image_extensions:
try:
os.remove(item_path)
logger.info(f"Removed non-image file: {item}")
except Exception as e:
logger.warning(f"Failed to remove {item}: {str(e)}")
logger.info(f"Removing all files in output_model_dir: {output_model_dir}")
shutil.rmtree(output_model_dir)
os.makedirs(output_model_dir)
logger.info(f"Cleanup completed. Only validation images and coords.json are kept.")
# # Clean class_dir (mirrors the sh script)
# if finetune_method in ['dreambooth', 'lora']:
# logger.info(f"Cleaning class directory: {class_dir}")
# if os.path.exists(class_dir):
# shutil.rmtree(class_dir)
# os.makedirs(class_dir)
# # Clean non-image files from output_model_dir
# logger.info(f"Cleaning non-image files in output directory: {output_model_dir}")
# if os.path.exists(output_model_dir):
# image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'}
# for item in os.listdir(output_model_dir):
# item_path = os.path.join(output_model_dir, item)
# if os.path.isfile(item_path):
# _, ext = os.path.splitext(item)
# if ext.lower() not in image_extensions:
# try:
# os.remove(item_path)
# logger.info(f"Removed non-image file: {item}")
# except Exception as e:
# logger.warning(f"Failed to remove {item}: {str(e)}")
return {
'status': 'success',

@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
def run_heatmap_task(task_id, original_image_path, perturbed_image_path,
output_dir, model_path, perturbed_image_id=None):
output_dir, perturbed_image_id=None):
"""
Run the heatmap generation task (real algorithm only)
@ -27,7 +27,6 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path,
original_image_path: Original image path
perturbed_image_path: Perturbed image path
output_dir: Output directory
model_path: Stable Diffusion model path
perturbed_image_id: Perturbed image ID (used to establish the father relation)
Returns:
@ -97,7 +96,7 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path,
# Run the real heatmap algorithm
result = _run_real_heatmap(
task_id, original_image_path, perturbed_image_path,
prompt_text, target_word, output_dir, model_path
prompt_text, target_word, output_dir
)
# Save the heatmap file to the database
@ -132,7 +131,7 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path,
def _run_real_heatmap(task_id, original_image_path, perturbed_image_path,
prompt_text, target_word, output_dir, model_path):
prompt_text, target_word, output_dir):
"""运行真实热力图算法"""
from config.algorithm_config import AlgorithmConfig
@ -145,6 +144,14 @@ def _run_real_heatmap(task_id, original_image_path, perturbed_image_path,
if not script_path:
raise ValueError("Heatmap script not configured")
default_params = evaluate_config.get('default_params', {})
model_path = default_params.get('model_path')
if not model_path:
raise ValueError(f"{evaluate_config} ?{default_params} ?{model_path} Model path not configured in AlgorithmConfig.EVALUATE_SCRIPTS['heatmap']")
logger.info(f"Using model path from config: {model_path}")
# Build the command-line arguments
cmd_args = [

@ -179,21 +179,21 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
f"--instance_data_dir_for_adversarial={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir={class_dir}",
f"--pgd_eps={epsilon}",
f"--pgd_eps={int(epsilon)}",
])
elif algorithm_code == 'caat':
# CAAT argument structure
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={epsilon}",
f"--eps={int(epsilon)}",
])
elif algorithm_code == 'pid':
# PID argument structure
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={epsilon}",
f"--eps={int(epsilon)}",
])
else:
raise ValueError(f"Unsupported algorithm code: {algorithm_code}")

@ -60,15 +60,15 @@ class AlgorithmConfig:
'enable_xformers_memory_efficient_attention': True,
'instance_prompt': 'a photo of sks person',
'class_prompt': 'a photo of person',
'num_class_images': 200,
'num_class_images': 5,
'center_crop': True,
'with_prior_preservation': True,
'prior_loss_weight': 1.0,
'resolution': 384,
'train_batch_size': 1,
'max_train_steps': 2,
'max_f_train_steps': 3,
'max_adv_train_steps': 6,
'max_train_steps': 1,
'max_f_train_steps': 1,
'max_adv_train_steps': 1,
'checkpointing_iterations': 1,
'learning_rate': 5e-7,
'pgd_alpha': 0.005,
@ -84,15 +84,15 @@ class AlgorithmConfig:
'enable_xformers_memory_efficient_attention': True,
'instance_prompt': 'a photo of sks person',
'class_prompt': 'a photo of person',
'num_class_images': 200,
'num_class_images': 5,
'center_crop': True,
'with_prior_preservation': True,
'prior_loss_weight': 1.0,
'resolution': 384,
'train_batch_size': 1,
'max_train_steps': 2,
'max_f_train_steps': 3,
'max_adv_train_steps': 6,
'max_f_train_steps': 1,
'max_adv_train_steps': 1,
'checkpointing_iterations': 1,
'learning_rate': 5e-7,
'pgd_alpha': 0.005,
@ -109,7 +109,7 @@ class AlgorithmConfig:
'resolution': 512,
'learning_rate': 1e-5,
'lr_warmup_steps': 0,
'max_train_steps': 10,
'max_train_steps': 2,
'hflip': True,
'mixed_precision': 'bf16',
'alpha': 5e-3
@ -122,7 +122,7 @@ class AlgorithmConfig:
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'resolution': 512,
'max_train_steps': 10,
'max_train_steps': 2,
'center_crop': True,
'attack_type': 'add-log'
}
@ -173,16 +173,16 @@ class AlgorithmConfig:
'learning_rate': 2e-6,
'lr_scheduler': 'constant',
'lr_warmup_steps': 0,
'num_class_images': 200,
'max_train_steps': 1000,
'checkpointing_steps': 500,
'num_class_images': 5,
'max_train_steps': 4,
'checkpointing_steps': 2,
'center_crop': True,
'mixed_precision': 'bf16',
'prior_generation_precision': 'bf16',
'sample_batch_size': 5,
'validation_prompt': 'a photo of sks person',
'num_validation_images': 10,
'validation_steps': 500,
'num_validation_images': 2,
'validation_steps': 1,
'coords_log_interval': 10
}
},
@ -202,14 +202,14 @@ class AlgorithmConfig:
'learning_rate': 1e-4,
'lr_scheduler': 'constant',
'lr_warmup_steps': 0,
'num_class_images': 200,
'max_train_steps': 1000,
'checkpointing_steps': 500,
'num_class_images': 1,
'max_train_steps': 4,
'checkpointing_steps': 2,
'seed': 0,
'mixed_precision': 'fp16',
'rank': 4,
'validation_prompt': 'a photo of sks person',
'num_validation_images': 10,
'num_validation_images': 2,
'coords_log_interval': 10
}
},
@ -228,8 +228,8 @@ class AlgorithmConfig:
'learning_rate': 5e-4,
'lr_scheduler': 'constant',
'lr_warmup_steps': 0,
'max_train_steps': 1000,
'checkpointing_steps': 500,
'max_train_steps': 4,
'checkpointing_steps': 2,
'seed': 0,
'mixed_precision': 'fp16',
'validation_prompt': 'a photo of sks person',
@ -252,7 +252,7 @@ class AlgorithmConfig:
'virtual_script': None,
'conda_env': CONDA_ENVS['pid'], # Use the same environment as finetuning
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'model_path': MODELS_DIR['model2'],
}
},
'numbers': {

@ -68,6 +68,7 @@ class Config:
MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # Model outputs generated from perturbed images
# Finetune-training configuration
CLASS_DATA_FOLDER = os.path.join(STATIC_ROOT, 'class') # Class data directory (used for prior preservation)
MODEL_DATA_FOLDER = os.path.join(STATIC_ROOT, 'model_data') # Model data directory (storage for finetune training data)
# Visualization and analysis configuration
EVA_RES_FOLDER = os.path.join(STATIC_ROOT, 'eva_res') # Root directory for evaluation results
COORDS_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'position') # 3D coordinate visualization data (training trajectories)
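Reviewer note: combined with the get_model_data_path helper above, per-task training data resolves to STATIC_ROOT/model_data/<user_id>/<flow_id>. A tiny sketch (the STATIC_ROOT value is illustrative only):

import os

STATIC_ROOT = '/app/static'  # assumed value for illustration
MODEL_DATA_FOLDER = os.path.join(STATIC_ROOT, 'model_data')
print(os.path.join(MODEL_DATA_FOLDER, '42', '7'))
# -> /app/static/model_data/42/7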
