Merge lianghao_branch into develop #7

Merged
hnu202326010204 merged 15 commits from lianghao_branch into develop 1 month ago

19
.gitignore vendored

@@ -1,8 +1,7 @@
# Python bytecode cache
__pycache__/
venv/
python=3.11/
# Image files
*.png
*.jpg
*.jpeg
@@ -10,11 +9,21 @@ python=3.11/
# Environment config files (contain sensitive information)
*.env
# Log files
# Log and process files
logs/
*.log
*.pid
# Temporary directory for uploaded files
uploads/
.github/
# Files generated by fine-tuning
*.json
*.bin
*.pkl
*.safetensors
*.pt
*.txt
# VS Code settings
.vscode/

@@ -0,0 +1,520 @@
"""Stable Diffusion attention-heatmap difference visualization tool (robust version: semantic-stage aggregation).

This module uses a robust approach: it captures and accumulates cross-attention
weights during the **early timesteps (the semantic stage)** of the Stable
Diffusion U-Net. This keeps the captured attention maps focused and reliable,
so they can be used to compare how clean and perturbed inputs differ in their
effect on the model's attention mechanism.

Typical usage:
    python eva_gen_heatmap.py \\
        --model_path /path/to/sd_model \\
        --image_path_a /path/to/clean_image.png \\
        --image_path_b /path/to/noisy_image.png \\
        --prompt_text "a photo of sks person" \\
        --target_word "sks" \\
        --output_dir output/heatmap_reports
"""
# Argument parsing and file-path handling
import argparse
import os
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
# Numerical and deep-learning dependencies
import torch
import torch.nn.functional as F
import numpy as np
import itertools
import warnings
# Visualization dependencies
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Diffusers and Transformers dependencies
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import Attention
from transformers import CLIPTokenizer
# Image processing and data loading
from PIL import Image
from torchvision import transforms
# Suppress non-essential warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# ============== Core module: attention capture and aggregation ==============
class AttentionMapProcessor:
    """Custom attention processor that captures U-Net cross-attention weights.

    By replacing the processors of the original `Attention` modules, this class
    captures and stores the attention weights (`attention_probs`) of every
    cross-attention layer during the model's forward pass.

    Attributes:
        attention_maps (Dict[str, List[torch.Tensor]]): Captured attention maps,
            keyed by layer name; each value is the list of maps captured for that
            layer at different timesteps.
        pipeline (StableDiffusionPipeline): The Stable Diffusion pipeline being processed.
        original_processors (Dict[str, Any]): The original attention processors, kept for restoration.
        current_layer_name (Optional[str]): Name of the attention layer currently being processed.
    """

    def __init__(self, pipeline: StableDiffusionPipeline):
        """Initialize the attention processor.

        Args:
            pipeline: A Stable Diffusion pipeline instance.
        """
        self.attention_maps: Dict[str, List[torch.Tensor]] = {}
        self.pipeline = pipeline
        self.original_processors = {}
        self.current_layer_name = None
        self._set_processors()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Run the attention computation and capture the weights.

        This method replaces the original `Attention.processor` and captures
        the weights whenever cross-attention is computed.

        Args:
            attn: The current `Attention` module instance.
            hidden_states: U-Net hidden states (query).
            encoder_hidden_states: Text-encoder output (key/value), i.e. the cross-attention input.
            attention_mask: Attention mask.

        Returns:
            The computed output hidden states.
        """
        # For self-attention (encoder_hidden_states is None), fall back to the
        # original processor. Calling `attn.processor` here would recurse into
        # this hook, since the module's processor has been replaced.
        if encoder_hidden_states is None:
            return self.original_processors[self.current_layer_name](
                attn, hidden_states, encoder_hidden_states, attention_mask
            )
        # 1. Compute Q, K, V
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        # 2. Prepare for the matrix multiplication
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        # 3. Compute attention scores (Q @ K^T)
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0], query.shape[1], key.shape[1],
                dtype=query.dtype, device=query.device
            ),
            query,
            key.transpose(1, 2),
            beta=0,
            alpha=attn.scale,
        )
        # 4. Compute attention probabilities
        attention_probs = attention_scores.softmax(dim=-1)
        layer_name = self.current_layer_name
        # 5. Store the captured attention map
        if layer_name not in self.attention_maps:
            self.attention_maps[layer_name] = []
        # Store this timestep's attention weights,
        # shaped (batch * heads, spatial_res, text_tokens)
        self.attention_maps[layer_name].append(attention_probs.detach().cpu())
        # 6. Compute the output (Attention @ V)
        value = attn.head_to_batch_dim(value)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # 7. Output projection and dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
    def _set_processors(self):
        """Register the custom processor on every cross-attention layer in the U-Net.

        Walks all U-Net submodules, finds the cross-attention layers (`Attention`
        modules whose names contain `attn2`), and replaces their processors with
        this instance.
        """
        for name, module in self.pipeline.unet.named_modules():
            if isinstance(module, Attention) and 'attn2' in name:
                # Keep the original processor so it can be restored later
                self.original_processors[name] = module.processor

                # Wrap the call in a closure that records the current layer name
                def set_layer_name(current_name):
                    def new_call(*args, **kwargs):
                        self.current_layer_name = current_name
                        return self.__call__(*args, **kwargs)
                    return new_call

                module.processor = set_layer_name(name)

    def remove(self):
        """Restore the U-Net's original attention processors and clear the hooks."""
        for name, original_processor in self.original_processors.items():
            module = self.pipeline.unet.get_submodule(name)
            module.processor = original_processor
        self.attention_maps = {}
def aggregate_word_attention(
    attention_maps: Dict[str, List[torch.Tensor]],
    tokenizer: CLIPTokenizer,
    target_word: str,
    input_ids: torch.Tensor
) -> np.ndarray:
    """Aggregate the target word's attention across all layers and semantic timesteps, then normalize.

    Aggregation steps:
    1. Identify the token indices corresponding to the target word.
    2. For each layer, average the attention maps over all captured timesteps.
    3. Extract the sub-maps for the target tokens, summing over the token
       dimension and averaging over the attention heads.
    4. Upsample attention maps of different resolutions to a common size (64x64).
    5. Sum the results across all layers.
    6. Normalize the final map to [0, 1].

    Args:
        attention_maps: Dictionary of attention maps captured per layer and timestep.
        tokenizer: A CLIP tokenizer instance.
        target_word: The keyword to focus on.
        input_ids: Token-ID tensor for the prompt.

    Returns:
        The final aggregated attention heatmap, upsampled to 64x64 (NumPy array).

    Raises:
        ValueError: If the target word cannot be found in the prompt.
        RuntimeError: If no attention data was captured.
    """
    # 1. Identify the target word's token indices
    prompt_tokens = tokenizer.convert_ids_to_tokens(
        input_ids.squeeze().cpu().tolist()
    )
    target_lower = target_word.lower()
    target_indices = []
    for i, token in enumerate(prompt_tokens):
        cleaned_token = token.replace('Ġ', '').replace('_', '').lower()
        # Match tokens that contain or start with the target word, excluding special tokens
        if (input_ids.squeeze()[i] not in tokenizer.all_special_ids and
                (target_lower in cleaned_token or
                 cleaned_token.startswith(target_lower))):
            target_indices.append(i)
    if not target_indices:
        print(f"[WARN] Target word '{target_word}' not recognized. Check the prompt or target word.")
        raise ValueError("Could not identify token indices for the target word.")
    # 2. Aggregation
    all_attention_data = []
    # Highest U-Net attention resolution: 64x64 (total number of spatial positions)
    TARGET_SPATIAL_SIZE = 4096
    TARGET_MAP_SIZE = 64
    for layer_name, step_maps in attention_maps.items():
        if not step_maps:
            continue
        # Average over all captured timesteps; shape: (batch * heads, spatial_res, text_tokens)
        avg_map_over_time = torch.stack(step_maps).mean(dim=0)
        # Drop a leading singleton dimension if present (assumes batch size = 1)
        attention_map = avg_map_over_time.squeeze(0)
        # Extract the target tokens' attention; shape: (batch * heads, spatial_res, num_target_tokens)
        target_token_maps = attention_map[:, :, target_indices]
        # Sum over target tokens (dim=-1), average over the head dimension (dim=0); shape: (spatial_res,)
        aggregated_map_flat = target_token_maps.sum(dim=-1).mean(dim=0).float()
        # 3. Cross-resolution upsampling
        if aggregated_map_flat.shape[0] != TARGET_SPATIAL_SIZE:
            # Current map size: e.g. 16x16 (256) or 32x32 (1024)
            map_size = int(np.sqrt(aggregated_map_flat.shape[0]))
            map_2d = aggregated_map_flat.reshape(map_size, map_size)
            map_to_interp = map_2d.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
            # Bilinearly upsample to 64x64
            resized_map_2d = F.interpolate(
                map_to_interp,
                size=(TARGET_MAP_SIZE, TARGET_MAP_SIZE),
                mode='bilinear',
                align_corners=False
            )
            resized_map_flat = resized_map_2d.squeeze().flatten()
            all_attention_data.append(resized_map_flat)
        else:
            # Already 64x64; use as-is
            all_attention_data.append(aggregated_map_flat)
    if not all_attention_data:
        raise RuntimeError("No attention data captured. The model or parameters may be misconfigured.")
    # 4. Sum the results across all layers
    final_map_flat = torch.stack(all_attention_data).sum(dim=0).cpu().numpy()
    # 5. Normalize to [0, 1]
    final_map_flat = final_map_flat / (final_map_flat.max() + 1e-6)
    map_size = int(np.sqrt(final_map_flat.shape[0]))
    final_map_np = final_map_flat.reshape(map_size, map_size)  # 64x64
    return final_map_np
def get_attention_map_from_image(
    pipeline: StableDiffusionPipeline,
    image_path: str,
    prompt_text: str,
    target_word: str
) -> Tuple[Image.Image, np.ndarray]:
    """Run a multi-timestep forward pass and capture the attention map for an image and prompt.

    Running only the semantic stage of the diffusion process (the early
    timesteps) keeps the captured attention weights high-signal.

    Args:
        pipeline: A Stable Diffusion pipeline instance.
        image_path: Path to the input image.
        prompt_text: Prompt text used to generate the image.
        target_word: The keyword to focus on and visualize.

    Returns:
        A tuple of (original image, final upsampled attention map).
    """
    print(f"\n-> Processing image: {Path(image_path).name}")
    image = Image.open(image_path).convert("RGB").resize((512, 512))
    image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    image_tensor = (
        image_transform(image)
        .unsqueeze(0)
        .to(pipeline.device)
        .to(pipeline.unet.dtype)
    )
    # 1. Encode into latent space
    with torch.no_grad():
        latent = (
            pipeline.vae.encode(image_tensor).latent_dist.sample() *
            pipeline.vae.config.scaling_factor
        )
    # 2. Encode the prompt
    text_input = pipeline.tokenizer(
        prompt_text,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )
    input_ids = text_input.input_ids
    with torch.no_grad():
        # Get the text embeddings
        prompt_embeds = pipeline.text_encoder(
            input_ids.to(pipeline.device)
        )[0]
    # 3. Define the semantic timesteps
    scheduler = pipeline.scheduler
    # Set the number of diffusion steps (e.g. 50)
    scheduler.set_timesteps(50, device=pipeline.device)
    # Capture only the 10 earliest, most semantically rich steps
    semantic_steps = scheduler.timesteps[:10]
    print(f"-> Capturing attention over {len(semantic_steps)} semantic-stage timesteps...")
    processor = AttentionMapProcessor(pipeline)
    try:
        # 4. Run the multi-step U-Net forward pass
        with torch.no_grad():
            # Run U-Net predictions at the selected semantic timesteps
            for t in semantic_steps:
                pipeline.unet(latent, t, prompt_embeds, return_dict=False)
        # 5. Aggregate the captured data
        raw_map_np = aggregate_word_attention(
            processor.attention_maps,
            pipeline.tokenizer,
            target_word,
            input_ids
        )
    except Exception as e:
        print(f"[ERROR] Attention aggregation failed: {e}")
        # Fall back to an empty map; the hooks are cleaned up in `finally`
        raw_map_np = np.zeros(image.size)
    finally:
        processor.remove()
    # 6. Upsample the attention map to the image size (512x512)
    # Upsample with PIL
    heat_map_pil = Image.fromarray((raw_map_np * 255).astype(np.uint8))
    heat_map_np_resized = (
        np.array(heat_map_pil.resize(
            image.size,
            resample=Image.Resampling.LANCZOS  # High-quality Lanczos filter
        )) / 255.0
    )
    return image, heat_map_np_resized
def main():
    """Parse arguments, load the model, compute the difference, and generate the report."""
    parser = argparse.ArgumentParser(description="SD image attention-difference visualization report generator")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Local path to the Stable Diffusion model.")
    parser.add_argument("--image_path_a", type=str, required=True,
                        help="Path to the clean input image (X).")
    parser.add_argument("--image_path_b", type=str, required=True,
                        help="Path to the perturbed input image (X').")
    parser.add_argument("--prompt_text", type=str, default="a photo of sks person",
                        help="Prompt text used to generate the images.")
    parser.add_argument("--target_word", type=str, default="sks",
                        help="Keyword to focus on and visualize in the attention maps.")
    parser.add_argument("--output_dir", type=str, default="output",
                        help="Output directory for the report PNG.")
    args = parser.parse_args()
    print("--- Generating Stable Diffusion attention-difference report ---")
    # ---------------- Prepare the model ----------------
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float16 if device == 'cuda' else torch.float32
    try:
        # Load the Stable Diffusion pipeline
        pipe = StableDiffusionPipeline.from_pretrained(
            args.model_path,
            torch_dtype=dtype,
            local_files_only=True,
            safety_checker=None,
            # Load the scheduler config from its subfolder
            scheduler=DPMSolverMultistepScheduler.from_pretrained(args.model_path, subfolder="scheduler")
        ).to(device)
    except Exception as e:
        print(f"[ERROR] Failed to load the model; check the path and environment dependencies: {e}")
        return
    # ---------------- Gather the data ----------------
    # Attention map M_A for the clean image A
    img_A, map_A = get_attention_map_from_image(pipe, args.image_path_a, args.prompt_text, args.target_word)
    # Attention map M_B for the perturbed image B
    img_B, map_B = get_attention_map_from_image(pipe, args.image_path_b, args.prompt_text, args.target_word)
    if map_A.shape != map_B.shape:
        print("Error: attention map sizes do not match. Aborting.")
        return
    # Difference map: Delta = M_A - M_B
    diff_map = map_A - map_B
    # L2 norm of the difference
    l2_diff = np.linalg.norm(diff_map)
    print(f"\nDone. L2 norm of the attention-map difference: {l2_diff:.4f}")
    # ---------------- Render the report ----------------
    # Configure Matplotlib fonts
    plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['DejaVu Serif', 'Times New Roman', 'serif'],
        'mathtext.fontset': 'cm'
    })
    fig = plt.figure(figsize=(12, 16), dpi=120)
    # 3x4 grid for precise placement of images and legends
    gs = gridspec.GridSpec(3, 4, figure=fig,
                           height_ratios=[1, 1, 1.3],
                           hspace=0.3, wspace=0.1)
    # --- Row 1: original images ---
    ax_img_a = fig.add_subplot(gs[0, 0:2])
    ax_img_b = fig.add_subplot(gs[0, 2:4])
    # Clean image
    ax_img_a.imshow(img_A)
    ax_img_a.set_title(f"Clean Image ($X$)\nFilename: {Path(args.image_path_a).name}", fontsize=14, pad=10)
    ax_img_a.axis('off')
    # Perturbed image
    ax_img_b.imshow(img_B)
    ax_img_b.set_title(f"Noisy Image ($X'$)\nFilename: {Path(args.image_path_b).name}", fontsize=14, pad=10)
    ax_img_b.axis('off')
    # --- Row 2: attention heatmaps (jet colormap) ---
    ax_map_a = fig.add_subplot(gs[1, 0:2])
    ax_map_b = fig.add_subplot(gs[1, 2:4])
    # Attention map A
    im_map_a = ax_map_a.imshow(map_A, cmap='jet', vmin=0, vmax=1)
    ax_map_a.set_title(f"Attention Heatmap ($M_X$)\nTarget: \"{args.target_word}\"", fontsize=14, pad=10)
    ax_map_a.axis('off')
    # Attention map B
    im_map_b = ax_map_b.imshow(map_B, cmap='jet', vmin=0, vmax=1)
    ax_map_b.set_title(f"Attention Heatmap ($M_{{X'}}$)\nTarget: \"{args.target_word}\"", fontsize=14, pad=10)
    ax_map_b.axis('off')
    # Colorbar for attention map B
    divider = make_axes_locatable(ax_map_b)
    cax_map = divider.append_axes("right", size="5%", pad=0.05)
    cbar1 = fig.colorbar(im_map_b, cax=cax_map)
    cbar1.set_label('Attention Intensity', fontsize=10)
    # --- Row 3: difference comparison (centered) ---
    # The difference map spans the middle two columns of the grid
    ax_diff = fig.add_subplot(gs[2, 1:3])
    vmax_diff = np.max(np.abs(diff_map))
    # TwoSlopeNorm keeps zero at the center of the colorbar
    norm_diff = TwoSlopeNorm(vmin=-vmax_diff, vcenter=0., vmax=vmax_diff)
    # Coolwarm colormap: blue for negative differences (M_X' > M_X), red for positive (M_X > M_X')
    im_diff = ax_diff.imshow(diff_map, cmap='coolwarm', norm=norm_diff)
    title_text = (
        r"Difference Map: $\Delta = M_X - M_{X'}$" +
        "\n" + rf"$L_2$ Norm Distance: $\mathbf{{{l2_diff:.4f}}}$"
    )
    ax_diff.set_title(title_text, fontsize=16, pad=12)
    ax_diff.axis('off')
    # Colorbar for the difference map (center-aligned)
    cbar2 = fig.colorbar(im_diff, ax=ax_diff, fraction=0.046, pad=0.04)
    cbar2.set_label(r'Scale: Red ($+$) $\leftrightarrow$ Blue ($-$)', fontsize=12)
    # ---------------- Final touches and saving ----------------
    fig.suptitle("Museguard: SD Attention Analysis Report", fontsize=20, fontweight='bold', y=0.95)
    output_filename = "heatmap_dif.png"
    output_path = Path(args.output_dir) / output_filename
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, bbox_inches='tight', facecolor='white')
    print(f"\nAnalysis report saved to:\n{output_path.resolve()}")

if __name__ == "__main__":
    main()

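Reviewer note: the aggregation in `aggregate_word_attention` reduces to "average each layer's maps over timesteps, pick out the target tokens, upsample to a common 64x64 grid, sum across layers, normalize". A minimal self-contained sketch of that reduction on dummy tensors (shapes assumed to mirror the captured `(batch*heads, spatial, tokens)` maps; not part of this PR):

import torch
import torch.nn.functional as F

# Two fake cross-attention layers at different spatial resolutions (16x16 and 32x32),
# each with 3 captured timesteps, 8 heads, and 77 text tokens.
maps = {
    'up.attn2': [torch.rand(8, 256, 77) for _ in range(3)],
    'down.attn2': [torch.rand(8, 1024, 77) for _ in range(3)],
}
target_indices = [5]  # assumed token position(s) of the target word
acc = []
for step_maps in maps.values():
    m = torch.stack(step_maps).mean(0)           # average over timesteps
    m = m[:, :, target_indices].sum(-1).mean(0)  # sum tokens, average heads -> (spatial,)
    side = int(m.shape[0] ** 0.5)
    m = F.interpolate(m.reshape(1, 1, side, side),  # upsample to the common 64x64 grid
                      size=(64, 64), mode='bilinear', align_corners=False)
    acc.append(m.flatten())
final = torch.stack(acc).sum(0)
final = (final / (final.max() + 1e-6)).reshape(64, 64)  # normalized heatmap in [0, 1]
print(final.shape)  # torch.Size([64, 64])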
@@ -0,0 +1,513 @@
"""Multi-dimensional image-generation quality evaluation tool (professional refactor).

This script compares the generation quality of two image sets (Clean vs.
Perturbed) and produces a PNG report with a metric-comparison table and an
in-depth difference analysis.

Style Guide: Google Python Style Guide
"""
import os
import time
import subprocess
import tempfile
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Dict, Optional, Tuple, Any
import torch
import clip
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image
from torchvision import transforms
from facenet_pytorch import MTCNN, InceptionResnetV1
from piq import ssim, psnr
import torch_fidelity as fid
# Suppress non-essential warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# -----------------------------------------------------------------------------
# Global configuration and styles
# -----------------------------------------------------------------------------
# Matplotlib LaTeX-style configuration
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['DejaVu Serif', 'Times New Roman', 'serif'],
    'mathtext.fontset': 'cm',
    'axes.unicode_minus': False
})
# Metric metadata: desired direction and analysis thresholds for each metric
METRIC_ANALYSIS_META = {
    'FID': {'higher_is_better': False, 'th': [2.0, 10.0, 30.0]},
    'SSIM': {'higher_is_better': True, 'th': [0.01, 0.05, 0.15]},
    'PSNR': {'higher_is_better': True, 'th': [0.5, 2.0, 5.0]},
    'FDS': {'higher_is_better': True, 'th': [0.02, 0.05, 0.1]},
    'CLIP_IQS': {'higher_is_better': True, 'th': [0.01, 0.03, 0.08]},
    'BRISQUE': {'higher_is_better': False, 'th': [2.0, 5.0, 10.0]},
}
# Degradation weights used by the holistic analysis
ANALYSIS_WEIGHTS = {'Severe': 3, 'Significant': 2, 'Slight': 1, 'Negligible': 0}
# -----------------------------------------------------------------------------
# Model loading (lazy loading or global preload)
# -----------------------------------------------------------------------------
# Pick the device up front so CPU-only machines do not crash on a hard-coded 'cuda'
CLIP_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
try:
    CLIP_MODEL, CLIP_PREPROCESS = clip.load('ViT-B/32', CLIP_DEVICE)
    CLIP_MODEL.eval()
except Exception as e:
    print(f"[Warning] Failed to load CLIP model: {e}")
    CLIP_MODEL, CLIP_PREPROCESS = None, None
def _get_clip_text_features(text: str) -> torch.Tensor:
    """Helper: compute CLIP text features."""
    if CLIP_MODEL is None:
        return None
    tokens = clip.tokenize(text).to(CLIP_DEVICE)
    with torch.no_grad():
        features = CLIP_MODEL.encode_text(tokens)
        features /= features.norm(dim=-1, keepdim=True)
    return features
# -----------------------------------------------------------------------------
# Core computation
# -----------------------------------------------------------------------------
def calculate_metrics(
    ref_dir: str,
    gen_dir: str,
    image_size: int = 512
) -> Dict[str, float]:
    """Compute quality metrics between two image sets.

    Covers FDS, SSIM, PSNR, CLIP_IQS, and FID.

    Args:
        ref_dir: Directory of reference images.
        gen_dir: Directory of generated images.
        image_size: Processing resolution.

    Returns:
        A dict mapping metric names to values; empty if either directory is invalid.
    """
    metrics = {}

    # 1. Data loading
    def load_images(directory):
        imgs = []
        if os.path.exists(directory):
            for f in os.listdir(directory):
                if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                    try:
                        path = os.path.join(directory, f)
                        imgs.append(Image.open(path).convert("RGB"))
                    except Exception:
                        pass
        return imgs

    ref_imgs = load_images(ref_dir)
    gen_imgs = load_images(gen_dir)
    if not ref_imgs or not gen_imgs:
        print(f"[Error] Failed to load images or directory is empty: \nRef: {ref_dir}\nGen: {gen_dir}")
        return {}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        # --- FDS (Face Detection Similarity) ---
        print(">>> Computing FDS...")
        mtcnn = MTCNN(image_size=image_size, margin=0, device=device)
        resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)

        def get_face_embeds(img_list):
            embeds = []
            for img in img_list:
                face = mtcnn(img)
                if face is not None:
                    embeds.append(resnet(face.unsqueeze(0).to(device)))
            # Concatenate the [1, 512] embeddings into [N, 512]; stacking would
            # add an extra singleton dim and break the cosine similarity below
            return torch.cat(embeds) if embeds else None

        ref_embeds = get_face_embeds(ref_imgs)
        gen_embeds = get_face_embeds(gen_imgs)
        if ref_embeds is not None and gen_embeds is not None:
            # Mean cosine similarity of each generated face against all reference faces
            sims = []
            for g_emb in gen_embeds:
                sim = torch.cosine_similarity(g_emb, ref_embeds).mean()
                sims.append(sim)
            metrics['FDS'] = torch.tensor(sims).mean().item()
        else:
            metrics['FDS'] = 0.0
        # Free GPU memory
        del mtcnn, resnet
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # --- SSIM & PSNR ---
        print(">>> Computing SSIM & PSNR...")
        tfm = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])
        # Stack the reference set into [N, C, H, W]
        ref_tensor = torch.stack([tfm(img) for img in ref_imgs]).to(device)
        ssim_accum, psnr_accum = 0.0, 0.0
        for img in gen_imgs:
            gen_tensor = tfm(img).unsqueeze(0).to(device)  # [1, C, H, W]
            # Expand to match the reference set
            gen_expanded = gen_tensor.expand_as(ref_tensor)
            # Mean structural similarity of one generated image against the whole reference set
            val_ssim = ssim(gen_expanded, ref_tensor, data_range=1.0)
            val_psnr = psnr(gen_expanded, ref_tensor, data_range=1.0)
            ssim_accum += val_ssim.item()
            psnr_accum += val_psnr.item()
        metrics['SSIM'] = ssim_accum / len(gen_imgs)
        metrics['PSNR'] = psnr_accum / len(gen_imgs)
        # --- CLIP IQS ---
        print(">>> Computing CLIP IQS...")
        if CLIP_MODEL:
            iqs_accum = 0.0
            txt_feat = _get_clip_text_features("good image")
            for img in gen_imgs:
                img_tensor = CLIP_PREPROCESS(img).unsqueeze(0).to(device)
                img_feat = CLIP_MODEL.encode_image(img_tensor)
                img_feat /= img_feat.norm(dim=-1, keepdim=True)
                iqs_accum += (img_feat @ txt_feat.T).item()
            metrics['CLIP_IQS'] = iqs_accum / len(gen_imgs)
        else:
            metrics['CLIP_IQS'] = np.nan
        # --- FID ---
        print(">>> Computing FID...")
        try:
            fid_res = fid.calculate_metrics(
                input1=ref_dir,
                input2=gen_dir,
                cuda=True,
                fid=True,
                verbose=False
            )
            metrics['FID'] = fid_res['frechet_inception_distance']
        except Exception as e:
            print(f"[Error] FID computation failed: {e}")
            metrics['FID'] = np.nan
    return metrics
def run_brisque_cleanly(img_dir: str) -> float:
    """Run the external BRISQUE script cleanly via subprocess and a temporary directory.

    Args:
        img_dir: Image directory path.

    Returns:
        The BRISQUE score, or NaN on failure.
    """
    print(">>> Computing BRISQUE (external)...")
    script_path = Path(__file__).parent / 'libsvm' / 'python' / 'brisquequality.py'
    if not script_path.exists():
        print(f"[Error] BRISQUE script not found: {script_path}")
        return np.nan
    abs_img_dir = os.path.abspath(img_dir)
    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            cmd = [
                "python", str(script_path),
                abs_img_dir,
                temp_dir
            ]
            # Run from the script's own directory
            subprocess.run(
                cmd,
                cwd=script_path.parent,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE
            )
            # Read the temporary log file
            log_file = Path(temp_dir) / 'log.txt'
            if log_file.exists():
                content = log_file.read_text(encoding='utf-8').strip()
                try:
                    return float(content.split()[-1])
                except ValueError:
                    return float(content)
            else:
                return np.nan
        except Exception as e:
            print(f"[Error] BRISQUE execution failed: {e}")
            return np.nan
# -----------------------------------------------------------------------------
# Report visualization and analysis
# -----------------------------------------------------------------------------
def analyze_metric_diff(
    metric_name: str,
    clean_val: float,
    pert_val: float
) -> Tuple[str, str, str]:
    """Produce a graded, principled difference analysis.

    Args:
        metric_name: Metric name.
        clean_val: Score on the clean set.
        pert_val: Score on the perturbed set.

    Returns:
        (header arrow symbol, difference description, severity level)
    """
    cfg = METRIC_ANALYSIS_META.get(metric_name)
    if not cfg:
        return "-", "Configuration not found.", "Negligible"
    diff = pert_val - clean_val
    abs_diff = abs(diff)
    # Determine whether the change is an improvement or a degradation
    is_better = (cfg['higher_is_better'] and diff > 0) or (not cfg['higher_is_better'] and diff < 0)
    is_worse = not is_better
    # Determine the severity
    th = cfg['th']
    if abs_diff < th[0]:
        degree = "Negligible"
    elif abs_diff < th[1]:
        degree = "Slight"
    elif abs_diff < th[2]:
        degree = "Significant"
    else:
        degree = "Severe"
    # Assemble the wording
    header_arrow = r"$\uparrow$" if cfg['higher_is_better'] else r"$\downarrow$"
    if degree == "Negligible":
        analysis_text = f"Negligible change (diff < {th[0]:.4f})."
    elif is_worse:
        analysis_text = f"{degree} degradation."
    else:
        analysis_text = f"Unexpected {degree} change."
    return header_arrow, analysis_text, degree
def generate_visual_report(
    ref_dir: str,
    clean_dir: str,
    pert_dir: str,
    clean_metrics: Dict,
    pert_metrics: Dict,
    output_path: str
):
    """Render and save the comparison analysis report (PNG)."""
    def get_sample(d):
        if not os.path.exists(d):
            return None, "N/A"
        files = [f for f in os.listdir(d) if f.lower().endswith(('.png', '.jpg'))]
        if not files:
            return None, "Empty"
        return Image.open(os.path.join(d, files[0])).convert("RGB"), files[0]

    img_ref, name_ref = get_sample(ref_dir)
    img_clean, name_clean = get_sample(clean_dir)
    img_pert, name_pert = get_sample(pert_dir)
    # Layout: extra height to make room for the text block
    fig = plt.figure(figsize=(12, 16.5), dpi=120)
    gs = gridspec.GridSpec(3, 2, height_ratios=[1, 1, 1.5], hspace=0.25, wspace=0.1)
    # 1. Image panels
    ax_ref = fig.add_subplot(gs[0, :])
    if img_ref:
        ax_ref.imshow(img_ref)
    ax_ref.set_title(f"Reference Image ($X$)\n{name_ref}", fontsize=12, fontweight='bold', pad=10)
    ax_ref.axis('off')
    ax_c = fig.add_subplot(gs[1, 0])
    if img_clean:
        ax_c.imshow(img_clean)
    ax_c.set_title(f"Clean Output ($Y$)\n{name_clean}", fontsize=12, fontweight='bold', pad=10)
    ax_c.axis('off')
    ax_p = fig.add_subplot(gs[1, 1])
    if img_pert:
        ax_p.imshow(img_pert)
    ax_p.set_title(f"Perturbed Output ($Y'$)\n{name_pert}", fontsize=12, fontweight='bold', pad=10)
    ax_p.axis('off')
    # 2. Table and analysis panel
    ax_data = fig.add_subplot(gs[2, :])
    ax_data.axis('off')
    metrics_list = ['FID', 'SSIM', 'PSNR', 'FDS', 'CLIP_IQS', 'BRISQUE']
    table_data = []
    analysis_lines = []
    degradation_score = 0
    # Walk the metrics to build the table rows and analysis lines
    for m in metrics_list:
        c_val = clean_metrics.get(m, np.nan)
        p_val = pert_metrics.get(m, np.nan)
        c_str = f"{c_val:.4f}" if not np.isnan(c_val) else "N/A"
        p_str = f"{p_val:.4f}" if not np.isnan(p_val) else "N/A"
        diff_str = "-"
        header_arrow = ""
        if not np.isnan(c_val) and not np.isnan(p_val):
            # Detailed analysis
            header_arrow, text_desc, degree = analyze_metric_diff(m, c_val, p_val)
            # Raw difference
            diff = p_val - c_val
            # Arrow for the sign of the difference itself (diff > 0 or diff < 0)
            diff_arrow = r"$\nearrow$" if diff > 0 else r"$\searrow$"
            if abs(diff) < 1e-4:
                diff_arrow = r"$\rightarrow$"
            diff_str = f"{diff:+.4f} {diff_arrow}"
            analysis_lines.append(f"{m}: Change {diff:+.4f}. Analysis: {text_desc}")
            # Accumulate the degradation score
            cfg = METRIC_ANALYSIS_META.get(m)
            is_worse = (cfg['higher_is_better'] and diff < 0) or (not cfg['higher_is_better'] and diff > 0)
            if is_worse:
                degradation_score += ANALYSIS_WEIGHTS.get(degree, 0)
        # First table column: name + desired-direction arrow
        name_with_arrow = f"{m} ({header_arrow})" if header_arrow else m
        table_data.append([name_with_arrow, c_str, p_str, diff_str])
    # Draw the table
    table = ax_data.table(
        cellText=table_data,
        colLabels=["Metric (Goal)", "Clean ($Y$)", "Perturbed ($Y'$)", r"Diff ($\Delta$)"],
        loc='upper center',
        cellLoc='center',
        colWidths=[0.25, 0.25, 0.25, 0.25]
    )
    table.scale(1, 2.0)
    table.set_fontsize(11)
    # Style the header row and first column
    for (row, col), cell in table.get_celld().items():
        if row == 0:
            cell.set_text_props(weight='bold', color='white')
            cell.set_facecolor('#404040')
        elif col == 0:
            cell.set_text_props(weight='bold')
            cell.set_facecolor('#f5f5f5')
    # 3. Bottom summary text box
    if not analysis_lines:
        analysis_lines.append("• All metrics are missing or invalid.")
    full_text = "Quantitative Difference Analysis:\n" + "\n".join(analysis_lines)
    # Overall verdict (based on the holistic degradation score)
    conclusion = "\n\n>>> EXECUTIVE SUMMARY (Holistic Judgment):\n"
    if degradation_score >= 8:
        conclusion += "CRITICAL DEGRADATION. Significant quality loss observed. Attack highly effective."
    elif degradation_score >= 4:
        conclusion += "MODERATE DEGRADATION. Observable quality drop in key metrics. Attack effective."
    elif degradation_score > 0:
        conclusion += "MINOR DEGRADATION. Slight quality loss detected. Attack partially effective."
    else:
        conclusion += "INEFFECTIVE ATTACK. No significant or unexpected statistical quality loss observed."
    full_text += conclusion
    ax_data.text(
        0.05,
        0.30,
        full_text,
        ha='left',
        va='top',
        fontsize=12, family='monospace', wrap=True,
        transform=ax_data.transAxes
    )
    fig.suptitle("Museguard: Quality Assurance Report", fontsize=18, fontweight='bold', y=0.95)
    plt.savefig(output_path, bbox_inches='tight', facecolor='white')
    print(f"\n[Success] Report generated: {output_path}")
# -----------------------------------------------------------------------------
# Entry point
# -----------------------------------------------------------------------------
def main():
    parser = ArgumentParser()
    parser.add_argument('--clean_output_dir', type=str, required=True)
    parser.add_argument('--perturbed_output_dir', type=str, required=True)
    parser.add_argument('--clean_ref_dir', type=str, required=True)
    parser.add_argument('--png_output_path', type=str, required=True)
    parser.add_argument('--size', type=int, default=512)
    args = parser.parse_args()
    Path(args.png_output_path).parent.mkdir(parents=True, exist_ok=True)
    print("========================================")
    print("   Image Quality Evaluation Toolkit")
    print("========================================")
    # 1. Evaluate the clean set
    print(f"\n[1/2] Evaluating Clean Set: {os.path.basename(args.clean_output_dir)}")
    c_metrics = calculate_metrics(args.clean_ref_dir, args.clean_output_dir, args.size)
    if c_metrics:
        c_metrics['BRISQUE'] = run_brisque_cleanly(args.clean_output_dir)
    # 2. Evaluate the perturbed set
    print(f"\n[2/2] Evaluating Perturbed Set: {os.path.basename(args.perturbed_output_dir)}")
    p_metrics = calculate_metrics(args.clean_ref_dir, args.perturbed_output_dir, args.size)
    if p_metrics:
        p_metrics['BRISQUE'] = run_brisque_cleanly(args.perturbed_output_dir)
    # 3. Generate the report
    if c_metrics and p_metrics:
        generate_visual_report(
            args.clean_ref_dir,
            args.clean_output_dir,
            args.perturbed_output_dir,
            c_metrics,
            p_metrics,
            args.png_output_path
        )
    else:
        print("\n[Fatal] Evaluation data incomplete; aborting report generation.")

if __name__ == '__main__':
    main()

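Reviewer note: the severity grading walks the `th` thresholds in order and the first bound that exceeds the absolute difference wins. A standalone sketch of that threshold logic (re-implemented inline so it runs without the module; the FID values are illustrative):

META = {'FID': {'higher_is_better': False, 'th': [2.0, 10.0, 30.0]}}

def grade(metric, clean, perturbed):
    cfg = META[metric]
    diff = perturbed - clean
    worse = (cfg['higher_is_better'] and diff < 0) or (not cfg['higher_is_better'] and diff > 0)
    for th, level in zip(cfg['th'], ['Negligible', 'Slight', 'Significant']):
        if abs(diff) < th:
            return level, worse
    return 'Severe', worse

# FID rose by 12.3 points; FID is lower-is-better, so this is a degradation
print(grade('FID', 14.2, 26.5))  # ('Significant', True)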
@@ -1519,4 +1519,4 @@ def main(args):
if __name__ == "__main__":
args = parse_args()
main(args)
main(args)

@@ -1520,4 +1520,4 @@ def main(args):
if __name__ == "__main__":
args = parse_args()
main(args)
main(args)

@@ -0,0 +1,132 @@
import pandas as pd
from pathlib import Path
import sys
import logging
import numpy as np
import statsmodels.api as sm

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def apply_lowess_and_clipping_scaling(input_csv_path, output_csv_path, lowess_frac, target_range, clipping_percentile):
    """
    Smooth with LOWESS (locally weighted regression) to extract the overall trend,
    then scale using min-max bounds taken from percentile clipping.
    Goal: produce the smoothest, most nearly monotonically decreasing objective trend.
    """
    input_path = Path(input_csv_path)
    output_path = Path(output_csv_path)
    if not input_path.exists():
        logging.error(f"Error: input file not found: {input_csv_path}")
        return
    logging.info(f"Reading raw data: {input_csv_path}")
    df = pd.read_csv(input_path)
    df = df.loc[:, ~df.columns.duplicated()].copy()
    # Raw data column names
    raw_x_col = 'X_Feature_L2_Norm'
    raw_y_col = 'Y_Feature_Variance'
    raw_z_col = 'Z_LDM_Loss'
    # --------------------------- 1. LOWESS smoothing (extract the overall trend) ---------------------------
    logging.info(f"Applying LOWESS with smoothing factor frac={lowess_frac}")
    x_coords = df['step'].values
    for raw_col in [raw_x_col, raw_y_col, raw_z_col]:
        y_coords = df[raw_col].values
        smoothed_data = sm.nonparametric.lowess(
            endog=y_coords,
            exog=x_coords,
            frac=lowess_frac,
            it=0
        )
        df[f'{raw_col}_LOWESS'] = smoothed_data[:, 1]
    # --------------------------- 2. Percentile-bounded scaling and direction alignment ---------------------------
    p = clipping_percentile
    logging.info(f"Applying percentile bounds (p={p}) for linear scaling into target range [0, {target_range:.2f}]")
    scale_cols_map = {
        'X_Feature_L2_Norm': f'{raw_x_col}_LOWESS',
        'Y_Feature_Variance': f'{raw_y_col}_LOWESS',
        'Z_LDM_Loss': f'{raw_z_col}_LOWESS'
    }
    for final_col, lowess_col in scale_cols_map.items():
        data = df[lowess_col]
        # Clipping: compute the clipped min/max (defines the scaling window)
        lower_bound = data.quantile(p)
        upper_bound = data.quantile(1.0 - p)
        min_val = lower_bound
        max_val = upper_bound
        data_range = max_val - min_val
        if data_range <= 0 or np.isnan(data_range):  # note: `== np.nan` is always False
            df[final_col] = 0.0
            logging.warning(f"Clipped range for {final_col} is {data_range:.4f}; skipping scaling.")
            continue
        # Normalize: (data - window min) / window range
        normalized_data = (data - min_val) / data_range
        # Direction alignment (every metric should read "lower is better"):
        if final_col in ['X_Feature_L2_Norm', 'Y_Feature_Variance']:
            # X/Y inverted: map the max to 0 and the min to target_range
            final_scaled_data = (1.0 - normalized_data) * target_range
        else:  # Z_LDM_Loss
            # Z standard scaling: map the min to 0 and the max to target_range
            final_scaled_data = normalized_data * target_range
        # Keep negative values to preserve smooth transitions
        df[final_col] = final_scaled_data
        logging.info(f" - Column {final_col}: clipping bounds [{min_val:.4f}, {max_val:.4f}]. "
                     f"The scaled range is no longer strictly constrained to [0, {target_range:.2f}], to preserve the trend.")
    # --------------------------- 3. Save ---------------------------
    output_path.parent.mkdir(parents=True, exist_ok=True)
    final_cols = ['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss']
    df[final_cols].to_csv(
        output_path,
        index=False,
        float_format='%.3f'
    )
    logging.info(f"LOWESS-smoothed and scaled data saved to: {output_csv_path}")

if __name__ == '__main__':
    if len(sys.argv) != 6:
        logging.error("Usage: python smooth_coords.py <input CSV path> <output CSV path> "
                      "<LOWESS smoothing factor frac (e.g. 0.4)> <target visual range (e.g. 30)> "
                      "<outlier clipping percentile (e.g. 0.15)>")
    else:
        input_csv = sys.argv[1]
        output_csv = sys.argv[2]
        try:
            lowess_frac = float(sys.argv[3])
            target_range = float(sys.argv[4])
            clipping_p = float(sys.argv[5])
            if not (0.0 < lowess_frac <= 1.0):
                raise ValueError("LOWESS smoothing factor frac must be in (0.0, 1.0].")
            if target_range <= 0:
                raise ValueError("Target visual range must be greater than 0.")
            if not (0 <= clipping_p < 0.5):
                raise ValueError("Clipping percentile must be in [0, 0.5).")
            if not Path(output_csv).suffix:
                output_csv = str(Path(output_csv) / "scaled_coords.csv")
            apply_lowess_and_clipping_scaling(input_csv, output_csv, lowess_frac, target_range, clipping_p)
        except ValueError as e:
            logging.error(f"Parameter error: {e}")

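Reviewer note: the LOWESS-plus-clipped-scaling pipeline can be exercised end to end on synthetic data. A minimal sketch (the noisy downward trend and the frac/percentile values are made up for illustration):

import numpy as np
import pandas as pd
import statsmodels.api as sm

rng = np.random.default_rng(0)
steps = np.arange(200)
noisy = np.exp(-steps / 80) + rng.normal(0, 0.05, steps.size)  # noisy decreasing series

# LOWESS returns (N, 2) pairs sorted by the x values; column 1 is the smoothed y
smoothed = sm.nonparametric.lowess(endog=noisy, exog=steps, frac=0.4, it=0)[:, 1]

# Percentile-clipped min-max scaling to a [0, 30] visual range, as in the script
p = 0.15
s = pd.Series(smoothed)
lo, hi = s.quantile(p), s.quantile(1.0 - p)
scaled = (s - lo) / (hi - lo) * 30.0
print(scaled.iloc[[0, -1]].round(2))  # values outside [0, 30] are deliberately kept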
@@ -0,0 +1,149 @@
"""
Image-processing utility: center-crops original images to their largest square,
resizes them to a target resolution, saves them in a target format, and can
optionally rename them sequentially.
"""
import argparse
import os
from pathlib import Path
from PIL import Image

# --- 1. Argument parsing ---
def parse_args(input_args=None):
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Image Processor for Centering, Resizing, and Format Conversion.")
    # Path and resolution arguments
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="A folder containing the original images to be processed and overwritten.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help="The target resolution (width and height) for the output images (e.g., 512 for 512x512).",
    )
    # Format argument
    parser.add_argument(
        "--target_format",
        type=str,
        default="png",
        choices=["jpeg", "png", "webp", "jpg"],
        help="The target format for the saved images (e.g., 'png', 'jpg', 'webp'). The original file will be overwritten, potentially changing the file extension.",
    )
    # Sequential-rename flag
    parser.add_argument(
        "--rename_sequential",
        action="store_true",  # True when the flag is present
        help="If set, images will be sequentially renamed (e.g., 001.jpg, 002.jpg...) instead of preserving the original filename. WARNING: This WILL delete the originals.",
    )
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()
    return args
# --- 2. Core image-processing logic ---
def process_image(image_path: Path, output_path: Path, resolution: int, target_format: str, delete_original: bool):
    """
    Load an image, crop the largest centered square, resize, and save in the target format.

    Args:
        image_path: Path to the original image.
        output_path: Final save path.
        resolution: Target resolution.
        target_format: Target file format.
        delete_original: Whether to delete the original file.
    """
    try:
        # Load the image and normalize to RGB
        img = Image.open(image_path).convert("RGB")
        # Largest centered square
        width, height = img.size
        min_dim = min(width, height)
        # Compute the crop box (a centered square with the shorter side as its size)
        left = (width - min_dim) // 2
        top = (height - min_dim) // 2
        right = left + min_dim
        bottom = top + min_dim
        # Crop the centered square
        img = img.crop((left, top, right, bottom))
        # Resize to the target resolution
        # using the high-quality LANCZOS resampling filter
        img = img.resize((resolution, resolution), resample=Image.Resampling.LANCZOS)
        # Prepare the output format
        save_format = target_format.upper().replace('JPEG', 'JPG')
        # Save the image
        # For JPEG/JPG, set the quality parameter
        if save_format == 'JPG':
            img.save(output_path, format='JPEG', quality=95)
        else:
            img.save(output_path, format=save_format)
        # Delete the original file if requested
        if delete_original and image_path.resolve() != output_path.resolve():
            os.remove(image_path)
        print(f"Processed: {image_path.name} -> {output_path.name} ({resolution}x{resolution} {save_format})")
    except Exception as e:
        print(f"Error processing {image_path.name}: {e}")
# --- 3. Entry point ---
def main(args):
    # Prepare paths
    input_dir = Path(args.input_dir)
    if not input_dir.is_dir():
        print(f"Error: Input directory not found at {input_dir}")
        return
    # Collect all image files (jpg, jpeg, png, webp)
    valid_suffixes = ['.jpg', '.jpeg', '.png', '.webp']
    image_paths = sorted([p for p in input_dir.iterdir() if p.suffix.lower() in valid_suffixes])  # sorted for a stable rename order
    if not image_paths:
        print(f"No image files found in {input_dir}")
        return
    print(f"Found {len(image_paths)} images in {input_dir}. Starting processing...")
    # Extension for the target format
    extension = args.target_format.lower().replace('jpeg', 'jpg')
    # Process the images
    for i, image_path in enumerate(image_paths):
        # Decide the output path
        if args.rename_sequential:
            # Sequential renaming: 001, 002, 003... (at least three digits)
            new_name = f"{i + 1:03d}.{extension}"
            output_path = input_dir / new_name
            # The original must be deleted whenever its name differs from the new one
            delete_original = True
        else:
            # Keep the original filename but swap the extension
            output_path = image_path.with_suffix(f'.{extension}')
            # Delete the original only if its extension differs from the target (avoids leaving the old format behind)
            delete_original = (image_path.suffix.lower() != f'.{extension}')
        process_image(image_path, output_path, args.resolution, args.target_format, delete_original)
    print("Processing complete.")

if __name__ == "__main__":
    args = parse_args()
    main(args)

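Reviewer note: the crop-box arithmetic in `process_image` is easy to sanity-check by hand. For a 1280x720 landscape image the largest centered square is 720x720 (a quick worked check, not part of the PR):

width, height = 1280, 720
min_dim = min(width, height)   # 720
left = (width - min_dim) // 2  # 280
top = (height - min_dim) // 2  # 0
print((left, top, left + min_dim, top + min_dim))  # (280, 0, 1000, 720)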
@@ -230,7 +230,7 @@ def get_system_stats():
'total': total_tasks,
'completed': completed_tasks,
'processing': processing_tasks,
'failed': failed_tasks
'failed': failed_tasks,
'waiting': waiting_tasks
},
'images': {

@@ -7,7 +7,6 @@ from flask import Blueprint, request, jsonify
from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity
from app import db
from app.database import User, UserConfig
from app.services.auth_service import AuthService
from functools import wraps
import re

@@ -1,203 +1,128 @@
"""
Image management controller
Handles image download, viewing, and related functionality
"""
from flask import Blueprint, send_file, jsonify, request, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity
from app.database import Image, EvaluationResult
from app.services.image_service import ImageService
import os

image_bp = Blueprint('image', __name__)

@image_bp.route('/file/<int:image_id>', methods=['GET'])
@jwt_required()
def get_image_file(image_id):
    """Fetch an image file"""
    try:
        current_user_id = get_jwt_identity()
        # Look up the image record
        image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
        if not image:
            return jsonify({'error': 'Image not found or permission denied'}), 404
        # Check that the file exists
        if not os.path.exists(image.file_path):
            return jsonify({'error': 'Image file does not exist'}), 404
        return send_file(image.file_path, as_attachment=False)
    except Exception as e:
        return jsonify({'error': f'Failed to fetch image: {str(e)}'}), 500

@image_bp.route('/download/<int:image_id>', methods=['GET'])
@jwt_required()
def download_image(image_id):
    """Download an image file"""
    try:
        current_user_id = get_jwt_identity()
        image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
        if not image:
            return jsonify({'error': 'Image not found or permission denied'}), 404
        if not os.path.exists(image.file_path):
            return jsonify({'error': 'Image file does not exist'}), 404
        return send_file(
            image.file_path,
            as_attachment=True,
            download_name=image.original_filename or f"image_{image_id}.jpg"
        )
    except Exception as e:
        return jsonify({'error': f'Failed to download image: {str(e)}'}), 500

@image_bp.route('/batch/<int:batch_id>/download', methods=['GET'])
@jwt_required()
def download_batch_images(batch_id):
    """Batch-download the perturbed images of a task"""
    try:
        current_user_id = get_jwt_identity()
        # Fetch the task's perturbed images
        perturbed_images = Image.query.join(Image.image_type).filter(
            Image.batch_id == batch_id,
            Image.user_id == current_user_id,
            Image.image_type.has(type_code='perturbed')
        ).all()
        if not perturbed_images:
            return jsonify({'error': 'No perturbed images found'}), 404
        # Build the ZIP file
        import zipfile
        import tempfile
        with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file:
            with zipfile.ZipFile(tmp_file.name, 'w') as zip_file:
                for image in perturbed_images:
                    if os.path.exists(image.file_path):
                        arcname = image.original_filename or f"perturbed_{image.id}.jpg"
                        zip_file.write(image.file_path, arcname)
            return send_file(
                tmp_file.name,
                as_attachment=True,
                download_name=f"batch_{batch_id}_perturbed_images.zip",
                mimetype='application/zip'
            )
    except Exception as e:
        return jsonify({'error': f'Batch download failed: {str(e)}'}), 500

@image_bp.route('/<int:image_id>/evaluations', methods=['GET'])
@jwt_required()
def get_image_evaluations(image_id):
    """Fetch an image's evaluation results"""
    try:
        current_user_id = get_jwt_identity()
        # Verify access to the image
        image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
        if not image:
            return jsonify({'error': 'Image not found or permission denied'}), 404
        # Fetch evaluations where this image is the reference or the target
        evaluations = EvaluationResult.query.filter(
            (EvaluationResult.reference_image_id == image_id) |
            (EvaluationResult.target_image_id == image_id)
        ).all()
        return jsonify({
            'image_id': image_id,
            'evaluations': [eval_result.to_dict() for eval_result in evaluations]
        }), 200
    except Exception as e:
        return jsonify({'error': f'Failed to fetch evaluation results: {str(e)}'}), 500

@image_bp.route('/compare', methods=['POST'])
@jwt_required()
def compare_images():
    """Compare two images"""
    try:
        current_user_id = get_jwt_identity()
        data = request.get_json()
        image1_id = data.get('image1_id')
        image2_id = data.get('image2_id')
        if not image1_id or not image2_id:
            return jsonify({'error': 'Please provide both image IDs'}), 400
        # Verify access to both images
        image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first()
        image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first()
        if not image1 or not image2:
            return jsonify({'error': 'Image not found or permission denied'}), 404
        # Look for an existing evaluation result
        evaluation = EvaluationResult.query.filter_by(
            reference_image_id=image1_id,
            target_image_id=image2_id
        ).first()
        if not evaluation:
            # No evaluation yet; return basic comparison info
            return jsonify({
                'image1': image1.to_dict(),
                'image2': image2.to_dict(),
                'evaluation': None,
                'message': 'No evaluation data yet; please wait for the task to finish'
            }), 200
        return jsonify({
            'image1': image1.to_dict(),
            'image2': image2.to_dict(),
            'evaluation': evaluation.to_dict()
        }), 200
    except Exception as e:
        return jsonify({'error': f'Image comparison failed: {str(e)}'}), 500

@image_bp.route('/heatmap/<path:heatmap_path>', methods=['GET'])
@jwt_required()
def get_heatmap(heatmap_path):
    """Fetch a heatmap file"""
    try:
        # Safety check against path-traversal attacks
        if '..' in heatmap_path or heatmap_path.startswith('/'):
            return jsonify({'error': 'Invalid file path'}), 400
        # Corrected path construction: resolve the project root (the backend directory)
        project_root = os.path.dirname(current_app.root_path)
        full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path))
        if not os.path.exists(full_path):
            return jsonify({'error': 'Heatmap file does not exist'}), 404
        return send_file(full_path, as_attachment=False)
    except Exception as e:
        return jsonify({'error': f'Failed to fetch heatmap: {str(e)}'}), 500

@image_bp.route('/delete/<int:image_id>', methods=['DELETE'])
@jwt_required()
def delete_image(image_id):
    """Delete an image"""
    try:
        current_user_id = get_jwt_identity()
        result = ImageService.delete_image(image_id, current_user_id)
        if result['success']:
            return jsonify({'message': 'Image deleted successfully'}), 200
        else:
            return jsonify({'error': result['error']}), 400
    except Exception as e:
        return jsonify({'error': f'Failed to delete image: {str(e)}'}), 500
"""
图像管理控制器
负责图片上传下载等操作
"""
from flask import Blueprint, request, jsonify, send_file
from app.controllers.auth_controller import int_jwt_required
from app.services.task_service import TaskService
from app.services.image_service import ImageService
image_bp = Blueprint('image', __name__)
# ==================== 图片上传 ====================
@image_bp.route('/original', methods=['POST'])
@int_jwt_required
def upload_original_images(current_user_id):
task_id = request.form.get('task_id', type=int)
if not task_id:
return ImageService.json_error('缺少 task_id 参数')
task = TaskService.load_task_for_user(task_id, current_user_id)
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
task_type = TaskService.get_task_type_code(task)
if task_type not in {'perturbation', 'finetune'}:
return ImageService.json_error('任务类型不支持图片上传', 400)
files = request.files.getlist('files')
target_dir = TaskService.get_original_images_path(task.user_id, task.flow_id)
success, result = ImageService.save_original_images(task, files, target_dir)
if not success:
status_code = 400
if isinstance(result, str) and (result.startswith('未配置图片类型') or '失败' in result):
status_code = 500
return ImageService.json_error(result, status_code)
return jsonify({
'message': '图片上传成功',
'images': [ImageService.serialize_image(img) for img in result],
'flow_id': task.flow_id
}), 201
# ==================== 结果下载 ====================
@image_bp.route('/perturbation/<int:task_id>/download', methods=['GET'])
@int_jwt_required
def download_perturbation_result(task_id, current_user_id):
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_perturbed_images_path(task.user_id, task.flow_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('结果文件不存在', 404)
filename = f"perturbation_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
@image_bp.route('/heatmap/<int:task_id>/download', methods=['GET'])
@int_jwt_required
def download_heatmap_result(task_id, current_user_id):
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_heatmap_path(task.user_id, task.flow_id, task.tasks_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('热力图文件不存在', 404)
filename = f"heatmap_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
@image_bp.route('/finetune/<int:task_id>/download', methods=['GET'])
@int_jwt_required
def download_finetune_result(task_id, current_user_id):
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
if not task.finetune:
return ImageService.json_error('微调任务配置不存在', 404)
try:
source = TaskService.determine_finetune_source(task)
except ValueError as exc:
return ImageService.json_error(str(exc), 500)
if source == 'perturbation':
directories = {
'original_generate': TaskService.get_original_generated_path(task.user_id, task.flow_id, task.tasks_id),
'perturbed_generate': TaskService.get_perturbed_generated_path(task.user_id, task.flow_id, task.tasks_id)
}
else:
directories = {
'uploaded_generate': TaskService.get_uploaded_generated_path(task.user_id, task.flow_id, task.tasks_id)
}
zipped, has_files = ImageService.zip_multiple_directories(directories)
if not has_files:
return ImageService.json_error('微调结果文件不存在', 404)
filename = f"finetune_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
@image_bp.route('/evaluate/<int:task_id>/download', methods=['GET'])
@int_jwt_required
def download_evaluate_result(task_id, current_user_id):
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_evaluate_path(task.user_id, task.flow_id, task.tasks_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('评估结果文件不存在', 404)
filename = f"evaluate_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')

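Reviewer note: a hedged sketch of how a client might call the new download routes (the URL prefix, port, and token handling are assumptions; they are not defined in this diff):

import requests

BASE = "http://localhost:5000/api/image"  # assumed blueprint prefix
token = "<JWT>"                           # obtained from the auth endpoints
resp = requests.get(f"{BASE}/perturbation/42/download",
                    headers={"Authorization": f"Bearer {token}"})
if resp.ok:
    with open("perturbation_42.zip", "wb") as f:
        f.write(resp.content)
else:
    print(resp.status_code, resp.json().get("error"))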
File diff suppressed because it is too large

@@ -1,129 +1,119 @@
"""
User management controller
Handles user configuration and related functionality
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required
from app import db
from app.database import User, UserConfig, Perturbation, Finetune
from app.controllers.auth_controller import int_jwt_required  # Import the JWT decorator

user_bp = Blueprint('user', __name__)

@user_bp.route('/config', methods=['GET'])
@int_jwt_required
def get_user_config(current_user_id):
    """Fetch the user configuration"""
    try:
        user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
        if not user_config:
            # Create a default configuration if none exists
            user_config = UserConfig(user_id=current_user_id)
            db.session.add(user_config)
            db.session.commit()
        return jsonify({
            'config': user_config.to_dict()
        }), 200
    except Exception as e:
        return jsonify({'error': f'Failed to fetch user configuration: {str(e)}'}), 500

@user_bp.route('/config', methods=['PUT'])
@int_jwt_required
def update_user_config(current_user_id):
    """Update the user configuration"""
    try:
        data = request.get_json()
        user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
        if not user_config:
            user_config = UserConfig(user_id=current_user_id)
            db.session.add(user_config)
        # Update configuration fields
        if 'perturbation_configs_id' in data:
            user_config.perturbation_configs_id = data['perturbation_configs_id']
        if 'perturbation_intensity' in data:
            intensity = float(data['perturbation_intensity'])
            if 0 < intensity <= 255:
                user_config.perturbation_intensity = intensity
            else:
                return jsonify({'error': 'Perturbation intensity must be between 0 and 255'}), 400
        if 'finetune_config_id' in data:
            user_config.finetune_config_id = data['finetune_config_id']
        db.session.commit()
        return jsonify({
            'message': 'User configuration updated successfully',
            'config': user_config.to_dict()
        }), 200
    except Exception as e:
        db.session.rollback()
        return jsonify({'error': f'Failed to update user configuration: {str(e)}'}), 500

@user_bp.route('/algorithms', methods=['GET'])
@jwt_required()
def get_available_algorithms():
    """List the available algorithms"""
    try:
        perturbation_configs = Perturbation.query.all()
        finetune_configs = Finetune.query.all()
        return jsonify({
            'perturbation_algorithms': [
                {
                    'id': config.id,
                    'method_code': config.method_code,
                    'method_name': config.method_name,
                    'description': config.description,
                } for config in perturbation_configs
            ],
            'finetune_methods': [
                {
                    'id': config.id,
                    'method_code': config.method_code,
                    'method_name': config.method_name,
                    'description': config.description
                } for config in finetune_configs
            ]
        }), 200
    except Exception as e:
        return jsonify({'error': f'Failed to fetch the algorithm list: {str(e)}'}), 500

@user_bp.route('/stats', methods=['GET'])
@int_jwt_required
def get_user_stats(current_user_id):
    """Fetch user statistics"""
    try:
        from app.database import Task, Image
        # Count the user's tasks and images
        total_tasks = Task.query.filter_by(user_id=current_user_id).count()
        completed_tasks = Task.query.filter_by(user_id=current_user_id, status='completed').count()
        processing_tasks = Task.query.filter_by(user_id=current_user_id, status='processing').count()
        failed_tasks = Task.query.filter_by(user_id=current_user_id, status='failed').count()
        total_images = Image.query.join(Task, Image.task_id == Task.id).filter(Task.user_id == current_user_id).count()
        return jsonify({
            'stats': {
                'total_tasks': total_tasks,
                'completed_tasks': completed_tasks,
                'processing_tasks': processing_tasks,
                'failed_tasks': failed_tasks,
                'total_images': total_images
            }
        }), 200
    except Exception as e:
        return jsonify({'error': f'Failed to fetch user statistics: {str(e)}'}), 500
"""
用户管理控制器
负责用户配置任务汇总等接口
"""
from flask import Blueprint, request, jsonify
from app import db
from app.controllers.auth_controller import int_jwt_required
from app.database import UserConfig, Task, TaskType, TaskStatus
user_bp = Blueprint('user', __name__)
def _json_error(message, status_code=400):
return jsonify({'error': message}), status_code
def _get_or_create_user_config(user_id):
config = UserConfig.query.filter_by(user_id=user_id).first()
if not config:
config = UserConfig(user_id=user_id)
db.session.add(config)
db.session.commit()
return config
def _serialize_config(config):
return {
'user_configs_id': config.user_configs_id,
'user_id': config.user_id,
'data_type_id': config.data_type_id,
'perturbation_configs_id': config.perturbation_configs_id,
'perturbation_intensity': config.perturbation_intensity,
'finetune_configs_id': config.finetune_configs_id,
'created_at': config.created_at.isoformat() if config.created_at else None,
'updated_at': config.updated_at.isoformat() if config.updated_at else None,
}
def _serialize_task(task):
status_code = task.task_status.task_status_code if task.task_status else None
task_type_code = task.task_type.task_type_code if task.task_type else None
return {
'task_id': task.tasks_id,
'flow_id': task.flow_id,
'task_type': task_type_code,
'status': status_code,
'created_at': task.created_at.isoformat() if task.created_at else None,
'started_at': task.started_at.isoformat() if task.started_at else None,
'finished_at': task.finished_at.isoformat() if task.finished_at else None,
'description': task.description,
'error_message': task.error_message
}
@user_bp.route('/config', methods=['GET'])
@int_jwt_required
def get_user_config(current_user_id):
config = _get_or_create_user_config(current_user_id)
return jsonify({'config': _serialize_config(config)}), 200
@user_bp.route('/config', methods=['PUT'])
@int_jwt_required
def update_user_config(current_user_id):
config = _get_or_create_user_config(current_user_id)
data = request.get_json() or {}
allowed_fields = {'data_type_id', 'perturbation_configs_id', 'perturbation_intensity', 'finetune_configs_id'}
for key, value in data.items():
if key in allowed_fields:
if key == 'perturbation_intensity' and value is not None:
try:
value = float(value)
except (TypeError, ValueError):
return _json_error('perturbation_intensity 参数格式不正确')
setattr(config, key, value)
try:
db.session.commit()
return jsonify({'message': '配置已更新', 'config': _serialize_config(config)}), 200
except Exception as exc:
db.session.rollback()
return _json_error(f'更新配置失败: {exc}', 500)
@user_bp.route('/tasks', methods=['GET'])
@int_jwt_required
def list_user_tasks(current_user_id):
task_type_code = request.args.get('type')
status_code = request.args.get('status')
query = Task.query.filter_by(user_id=current_user_id)
if task_type_code:
task_type = TaskType.query.filter_by(task_type_code=task_type_code).first()
if not task_type:
return _json_error('任务类型不存在', 404)
query = query.filter(Task.tasks_type_id == task_type.task_type_id)
if status_code:
status = TaskStatus.query.filter_by(task_status_code=status_code).first()
if not status:
return _json_error('任务状态不存在', 404)
query = query.filter(Task.tasks_status_id == status.task_status_id)
tasks = query.order_by(Task.created_at.desc()).all()
return jsonify({'tasks': [_serialize_task(task) for task in tasks]}), 200
@user_bp.route('/tasks/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_user_task(task_id, current_user_id):
task = Task.query.filter_by(tasks_id=task_id, user_id=current_user_id).first()
if not task:
return _json_error('任务不存在或无权限', 404)
return jsonify({'task': _serialize_task(task)}), 200

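Reviewer note: the `/tasks` endpoint filters via the `type` and `status` query parameters. A sketch of a client call (blueprint prefix and token are assumptions, as above):

import requests

resp = requests.get(
    "http://localhost:5000/api/user/tasks",  # assumed blueprint prefix
    params={"type": "finetune", "status": "completed"},
    headers={"Authorization": "Bearer <JWT>"},
)
for task in resp.json().get("tasks", []):
    print(task["task_id"], task["flow_id"], task["status"])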
@ -191,6 +191,7 @@ class Task(db.Model):
"""任务总表"""
__tablename__ = 'tasks'
tasks_id = db.Column(BigInteger, primary_key=True, autoincrement=True, comment='任务ID')
flow_id = db.Column(BigInteger, nullable=False, index=True, comment='工作流ID标识关联的任务组')
tasks_type_id = db.Column(Integer, ForeignKey('task_type.task_type_id'), nullable=False, comment='任务类型')
user_id = db.Column(Integer, ForeignKey('users.user_id'), nullable=False, index=True, comment='归属用户')
tasks_status_id = db.Column(Integer, ForeignKey('task_status.task_status_id'), nullable=False, comment='任务状态ID')
@ -205,15 +206,11 @@ class Task(db.Model):
task_status = db.relationship('TaskStatus', backref='tasks')
images = db.relationship('Image', backref='task', lazy='dynamic', cascade='all, delete-orphan')
# --- 变更部分 ---
# 与子表的一对一关系 (perturbation, heatmap)
# 与子表的一对一关系
perturbation = db.relationship('Perturbation', uselist=False, back_populates='task', cascade='all, delete-orphan')
heatmap = db.relationship('Heatmap', uselist=False, back_populates='task', cascade='all, delete-orphan')
# 与子表的一对多关系 (finetune, evaluate)
finetunes = db.relationship('Finetune', back_populates='task', cascade='all, delete-orphan')
evaluations = db.relationship('Evaluate', back_populates='task', cascade='all, delete-orphan')
# --- 变更结束 ---
finetune = db.relationship('Finetune', uselist=False, back_populates='task', cascade='all, delete-orphan')
evaluation = db.relationship('Evaluate', uselist=False, back_populates='task', cascade='all, delete-orphan')
def __repr__(self):
return f'<Task {self.tasks_id}>'
@ -239,26 +236,22 @@ class Perturbation(db.Model):
return f'<Perturbation TaskID {self.tasks_id}>'
# ----------------------------
# 7. 任务子表:微调任务 (finetune) - [已更新为复合主键]
# 7. 任务子表:微调任务 (finetune)
# ----------------------------
class Finetune(db.Model):
"""微调任务详情表"""
__tablename__ = 'finetune'
# --- 变更部分:复合主键 ---
tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表关联')
finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), primary_key=True, comment='微调配置ID')
# --- 变更结束 ---
tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表1:1关联')
finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), nullable=False, comment='微调配置ID')
data_type_id = db.Column(Integer, ForeignKey('data_type.data_type_id'), default=None, comment='微调所用数据集')
finetune_name = db.Column(String(100), comment='微调任务名称')
# --- 变更部分:更新 back_populates ---
task = db.relationship('Task', back_populates='finetunes')
# --- 变更结束 ---
task = db.relationship('Task', back_populates='finetune')
finetune_config = db.relationship('FinetuneConfig')
data_type = db.relationship('DataType')
def __repr__(self):
return f'<Finetune TaskID={self.tasks_id}, ConfigID={self.finetune_configs_id}>'
return f'<Finetune TaskID={self.tasks_id}>'
# ----------------------------
# 8. 评估结果表 (evaluation_results)
@ -276,26 +269,22 @@ class EvaluationResult(db.Model):
return f'<EvaluationResult {self.evaluation_results_id}>'
# ----------------------------
# 9. 任务子表:评估任务 (evaluate) - [已更新为复合主键]
# 9. 任务子表:评估任务 (evaluate)
# ----------------------------
class Evaluate(db.Model):
"""指标计算任务表"""
__tablename__ = 'evaluate'
# --- 变更部分:复合主键 ---
tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True)
finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), primary_key=True, comment='关联的微调配置(如果是针对微调的评估)')
# --- 变更结束 ---
tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表1:1关联')
finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), nullable=False, comment='关联的微调配置')
evaluate_name = db.Column(String(100))
evaluation_results_id = db.Column(BigInteger, ForeignKey('evaluation_results.evaluation_results_id'), unique=True, default=None, comment='关联的结果ID')
# --- 变更部分:更新 back_populates ---
task = db.relationship('Task', back_populates='evaluations')
# --- 变更结束 ---
task = db.relationship('Task', back_populates='evaluation')
finetune_config = db.relationship('FinetuneConfig')
evaluation_result = db.relationship('EvaluationResult', backref='evaluate_task', uselist=False)
def __repr__(self):
return f'<Evaluate TaskID={self.tasks_id}, ConfigID={self.finetune_configs_id}>'
return f'<Evaluate TaskID={self.tasks_id}>'
# ----------------------------
# 10. 任务子表:热力图计算任务 (heatmap)
@ -303,14 +292,16 @@ class Evaluate(db.Model):
class Heatmap(db.Model):
"""热力图计算任务表"""
__tablename__ = 'heatmap'
tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True)
tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表1:1关联')
images_id = db.Column(BigInteger, ForeignKey('images.images_id', ondelete='CASCADE'), nullable=False, comment='关联的加噪图ID')
heatmap_name = db.Column(String(100))
# 关系
task = db.relationship('Task', back_populates='heatmap')
perturbation_image = db.relationship('Image', foreign_keys=[images_id])
def __repr__(self):
return f'<Heatmap TaskID {self.tasks_id}>'
return f'<Heatmap TaskID={self.tasks_id}>'
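
下面补一段假设性的用法示意,演示上述 1:1 关系(Task ↔ Finetune)在应用上下文中的典型写法;其中 flow、config 等变量仅为示例,并非仓库既有代码:

# 假设性示意:先落 Task 拿到主键,再写 1:1 子表
task = Task(flow_id=flow.flow_id, tasks_type_id=2)
db.session.add(task)
db.session.flush()  # 先取得自增的 tasks_id,再写入子表
finetune = Finetune(
    tasks_id=task.tasks_id,
    finetune_configs_id=config.finetune_configs_id,
    finetune_name='demo-finetune'
)
db.session.add(finetune)
db.session.commit()
# uselist=False + back_populates 使两侧可直接互访
assert task.finetune is finetune and finetune.task is task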
# ----------------------------
# 11. 图片表 (images)

@ -1,34 +0,0 @@
"""
认证服务
处理用户认证相关逻辑
"""
from app.database import User
class AuthService:
"""认证服务类"""
@staticmethod
def authenticate_user(username, password):
"""验证用户凭据"""
user = User.query.filter_by(username=username).first()
if user and user.check_password(password) and user.is_active:
return user
return None
@staticmethod
def get_user_by_id(user_id):
"""根据ID获取用户"""
return User.query.get(user_id)
@staticmethod
def is_email_available(email):
"""检查邮箱是否可用"""
return User.query.filter_by(email=email).first() is None
@staticmethod
def is_username_available(username):
"""检查用户名是否可用"""
return User.query.filter_by(username=username).first() is None

@ -3,16 +3,18 @@
处理图像上传保存等功能
"""
import io
import os
import uuid
import zipfile
import fcntl
import time
from datetime import datetime
from werkzeug.utils import secure_filename
from flask import current_app
from flask import current_app, jsonify
from PIL import Image as PILImage
from app import db
from app.database import Image
from app.database import Image, ImageType
from app.utils.file_utils import allowed_file
class ImageService:
@ -254,4 +256,173 @@ class ImageService:
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'删除图片失败: {str(e)}'}
return {'success': False, 'error': f'删除图片失败: {str(e)}'}
# ==================== 控制器辅助功能 ====================
DEFAULT_TARGET_SIZE = 512
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp'}
@staticmethod
def json_error(message, status_code=400):
"""统一错误响应"""
return jsonify({'error': message}), status_code
@staticmethod
def get_image_type_by_code(code):
"""根据代码获取图片类型"""
return ImageType.query.filter_by(image_code=code).first()
@staticmethod
def save_original_images(task, files, target_dir, image_type_code='original', target_size=None):
"""保存原图上传"""
if not files:
return False, '未检测到文件上传'
image_type = ImageService.get_image_type_by_code(image_type_code)
if not image_type:
return False, f'未配置图片类型: {image_type_code}'
os.makedirs(target_dir, exist_ok=True)
saved_records = []
saved_paths = []
size = target_size or ImageService.DEFAULT_TARGET_SIZE
try:
for file in files:
if not file or not file.filename:
continue
if not allowed_file(file.filename):
continue
extension = os.path.splitext(file.filename)[1].lower()
if extension not in ImageService.IMAGE_EXTENSIONS:
continue
processed = ImageService._prepare_image(file, size)
filename, path, width, height, file_size = ImageService._save_processed_image(processed, target_dir)
image = ImageService._create_image_record(
task,
image_type.image_types_id,
filename,
path,
width,
height,
file_size
)
saved_records.append(image)
saved_paths.append(path)
if not saved_records:
db.session.rollback()
return False, '未上传有效的图片文件'
db.session.commit()
return True, saved_records
except Exception as exc:
db.session.rollback()
for path in saved_paths:
if os.path.exists(path):
try:
os.remove(path)
except OSError:
pass
return False, f'上传图片失败: {exc}'
@staticmethod
def _prepare_image(file_storage, target_size):
"""裁剪并缩放上传图片"""
file_storage.stream.seek(0)
image = PILImage.open(file_storage.stream).convert('RGB')
width, height = image.size
min_dim = min(width, height)
left = (width - min_dim) // 2
top = (height - min_dim) // 2
image = image.crop((left, top, left + min_dim, top + min_dim))
return image.resize((target_size, target_size), resample=PILImage.Resampling.LANCZOS)
@staticmethod
def _save_processed_image(image, target_dir):
"""将处理后的图片保存为PNG"""
timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f')
filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png"
path = os.path.join(target_dir, filename)
image.save(path, format='PNG')
return filename, path, image.width, image.height, os.path.getsize(path)
@staticmethod
def _create_image_record(task, image_type_id, filename, path, width, height, file_size, father_id=None):
"""创建图片数据库记录"""
image = Image(
task_id=task.tasks_id,
image_types_id=image_type_id,
father_id=father_id,
stored_filename=filename,
file_path=path,
file_size=file_size,
width=width,
height=height
)
db.session.add(image)
return image
@staticmethod
def zip_directory(directory):
"""打包目录为zip"""
buffer = io.BytesIO()
has_files = False
with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
if os.path.isdir(directory):
for root, _, files in os.walk(directory):
for filename in files:
file_path = os.path.join(root, filename)
arcname = os.path.relpath(file_path, directory)
zipf.write(file_path, arcname)
has_files = True
buffer.seek(0)
return buffer, has_files
@staticmethod
def zip_multiple_directories(directories):
"""打包多个目录"""
buffer = io.BytesIO()
has_files = False
with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
if isinstance(directories, dict):
iterable = directories.items()
else:
iterable = ((os.path.basename(d.rstrip(os.sep)) or 'output', d) for d in directories)
for label, directory in iterable:
if not os.path.isdir(directory):
continue
for root, _, files in os.walk(directory):
for filename in files:
file_path = os.path.join(root, filename)
rel_path = os.path.relpath(file_path, directory)
arcname = os.path.join(label or 'output', rel_path)
zipf.write(file_path, arcname)
has_files = True
buffer.seek(0)
return buffer, has_files
@staticmethod
def serialize_image(image):
"""图片序列化"""
if not image:
return None
return {
'image_id': image.images_id,
'task_id': image.task_id,
'stored_filename': image.stored_filename,
'file_path': image.file_path,
'file_size': image.file_size,
'width': image.width,
'height': image.height,
'image_type': image.image_type.image_code if image.image_type else None
}
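
紧接上面的辅助方法,这里给出一段假设性的控制器用法示意(路由、目录与 app 对象均为示例,并非仓库既有代码):

from flask import request, jsonify

@app.route('/tasks/<int:task_id>/images', methods=['POST'])
def upload_images(task_id):
    task = Task.query.get(task_id)
    if not task:
        return ImageService.json_error('任务不存在', 404)
    ok, result = ImageService.save_original_images(
        task,
        request.files.getlist('files'),
        target_dir=f'/data/tasks/{task_id}/original'
    )
    if not ok:
        return ImageService.json_error(result)
    return jsonify([ImageService.serialize_image(img) for img in result])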

File diff suppressed because it is too large

@ -0,0 +1,242 @@
"""
RQ Worker 数值评估任务处理器仅使用真实算法
生成原始图与扰动图微调后的模型生成效果对比报告
"""
import os
import subprocess
import logging
import shutil
from datetime import datetime
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def run_evaluate_task(task_id, clean_ref_dir, clean_output_dir,
perturbed_output_dir, output_dir, image_size=512):
"""
执行数值评估任务仅使用真实算法
Args:
task_id: 任务ID
clean_ref_dir: 干净参考图片目录原始上传的图片
clean_output_dir: 干净图片训练后的生成结果目录
perturbed_output_dir: 扰动图片训练后的生成结果目录
output_dir: 输出目录
image_size: 图片处理尺寸
Returns:
任务执行结果
"""
from config.algorithm_config import AlgorithmConfig
from app import create_app, db
from app.database import Evaluate, Task, TaskStatus
app = create_app()
with app.app_context():
try:
# 获取任务
task = Task.query.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
# 获取评估任务详情
evaluate = Evaluate.query.get(task_id)
if not evaluate:
raise ValueError(f"Evaluate task {task_id} not found")
# 更新任务状态为处理中
processing_status = TaskStatus.query.filter_by(task_status_code='processing').first()
if processing_status:
task.tasks_status_id = processing_status.task_status_id
task.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Starting evaluate task {task_id}")
# 确保目录存在并清空
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Clearing output directory: {output_dir}")
for item in os.listdir(output_dir):
item_path = os.path.join(output_dir, item)
if os.path.isfile(item_path):
os.unlink(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
# 运行真实评估算法
result = _run_real_evaluate(
task_id, clean_ref_dir, clean_output_dir,
perturbed_output_dir, output_dir, image_size
)
# 保存评估结果文件路径到数据库
report_file = os.path.join(output_dir, 'nums_dif.png')
if os.path.exists(report_file):
# 保存报告图到Image表
_save_report_image(task.tasks_id, report_file)
# 更新任务状态为完成
completed_status = TaskStatus.query.filter_by(task_status_code='completed').first()
if completed_status:
task.tasks_status_id = completed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
logger.info(f"Evaluate task {task_id} completed")
return result
except Exception as e:
logger.error(f"Evaluate task {task_id} failed: {str(e)}", exc_info=True)
# 更新任务状态为失败
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
if failed_status:
task.tasks_status_id = failed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
return {'success': False, 'error': str(e)}
def _run_real_evaluate(task_id, clean_ref_dir, clean_output_dir,
perturbed_output_dir, output_dir, image_size):
"""运行真实数值评估算法"""
from config.algorithm_config import AlgorithmConfig
logger.info(f"Running real evaluate generation")
# 获取评估脚本配置
evaluate_config = AlgorithmConfig.EVALUATE_SCRIPTS.get('numbers', {})
script_path = evaluate_config.get('real_script')
conda_env = evaluate_config.get('conda_env')
if not script_path:
raise ValueError("Evaluate script not configured")
# 输出文件路径
png_output_path = os.path.join(output_dir, 'nums_dif.png')
# 构建命令行参数
cmd_args = [
f"--clean_ref_dir={clean_ref_dir}",
f"--clean_output_dir={clean_output_dir}",
f"--perturbed_output_dir={perturbed_output_dir}",
f"--png_output_path={png_output_path}",
f"--size={image_size}",
]
# 构建完整命令
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'python', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置环境变量(强制离线模式)
env = os.environ.copy()
env['HF_HUB_OFFLINE'] = '1'
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(
log_dir,
f'evaluate_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
env=env
)
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
if process.returncode != 0:
raise RuntimeError(f"Evaluate generation failed with code {process.returncode}. Check log: {log_file}")
return {
'status': 'success',
'output_dir': output_dir,
'log_file': log_file
}
def _save_report_image(task_id, report_file_path):
"""
保存评估报告图到数据库Image表
Args:
task_id: 任务ID
report_file_path: 报告图文件完整路径
"""
from app import db
from app.database import Image, ImageType
from PIL import Image as PILImage
try:
# 获取报告图片类型
report_type = ImageType.query.filter_by(image_code='report').first()
if not report_type:
logger.error("Image type 'report' not found")
return
# 获取文件名
report_filename = os.path.basename(report_file_path)
# 检查是否已经保存过
existing = Image.query.filter_by(
task_id=task_id,
stored_filename=report_filename,
image_types_id=report_type.image_types_id
).first()
if existing:
logger.info(f"Report image {report_filename} already exists, skipping")
return
# 读取图片尺寸
try:
with PILImage.open(report_file_path) as img:
width, height = img.size
        except Exception:
width, height = None, None
# 保存到数据库 (report不需要father_id)
report_image = Image(
task_id=task_id,
image_types_id=report_type.image_types_id,
father_id=None,
stored_filename=report_filename,
file_path=report_file_path,
file_size=os.path.getsize(report_file_path),
width=width,
height=height
)
db.session.add(report_image)
db.session.commit()
logger.info(f"Saved report image: {report_filename}")
except Exception as e:
logger.error(f"Error saving report image: {str(e)}")
db.session.rollback()
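
作为补充,下面是一段假设性的派发示意,演示如何用 RQ 将 run_evaluate_task 交给后台 worker 执行(队列名、Redis 地址与目录均为示例):

from redis import Redis
from rq import Queue

queue = Queue('evaluate', connection=Redis.from_url('redis://localhost:6379/0'))
job = queue.enqueue(
    run_evaluate_task,
    task_id=42,
    clean_ref_dir='/data/tasks/42/original',
    clean_output_dir='/data/tasks/42/clean_gen',
    perturbed_output_dir='/data/tasks/42/perturbed_gen',
    output_dir='/data/tasks/42/evaluate',
    job_timeout=3600,  # 评估链路较慢,放宽超时
)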

@ -1,489 +1,403 @@
"""
RQ Worker 微调任务处理器
在后台执行模型微调任务
"""
import os
import subprocess
import logging
from datetime import datetime
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def _check_and_update_finetune_status(finetune_task):
"""
检查微调任务状态并更新
当原始和扰动图片的微调都完成时更新任务状态为completed
Args:
finetune_task: FinetuneBatch对象
"""
from app import db
from rq.job import Job
from redis import Redis
from config.algorithm_config import AlgorithmConfig
try:
# 刷新数据库对象,确保获取最新状态
db.session.refresh(finetune_task)
# 如果状态已经是completed或failed不再检查
if finetune_task.status in ['completed', 'failed']:
return
redis_conn = Redis.from_url(AlgorithmConfig.REDIS_URL)
original_job_done = False
perturbed_job_done = False
has_original_job = False
has_perturbed_job = False
# 检查原始图片微调任务
if finetune_task.original_job_id:
has_original_job = True
try:
original_job = Job.fetch(finetune_task.original_job_id, connection=redis_conn)
status = original_job.get_status()
logger.info(f"Original job {finetune_task.original_job_id} status: {status}")
if status == 'finished':
original_job_done = True
elif status == 'failed':
# 如果原始任务失败,整个微调任务标记为失败
finetune_task.status = 'failed'
finetune_task.error_message = f"Original finetune job failed: {original_job.exc_info}"
finetune_task.completed_at = datetime.utcnow()
db.session.commit()
logger.error(f"FinetuneBatch {finetune_task.id} failed: original job failed")
return
except Exception as e:
logger.error(f"Error checking original job: {str(e)}")
# 检查扰动图片微调任务
if finetune_task.perturbed_job_id:
has_perturbed_job = True
try:
perturbed_job = Job.fetch(finetune_task.perturbed_job_id, connection=redis_conn)
status = perturbed_job.get_status()
logger.info(f"Perturbed job {finetune_task.perturbed_job_id} status: {status}")
if status == 'finished':
perturbed_job_done = True
elif status == 'failed':
# 如果扰动任务失败,整个微调任务标记为失败
finetune_task.status = 'failed'
finetune_task.error_message = f"Perturbed finetune job failed: {perturbed_job.exc_info}"
finetune_task.completed_at = datetime.utcnow()
db.session.commit()
logger.error(f"FinetuneBatch {finetune_task.id} failed: perturbed job failed")
return
except Exception as e:
logger.error(f"Error checking perturbed job: {str(e)}")
# 如果两个任务都完成更新状态为completed
if has_original_job and has_perturbed_job and original_job_done and perturbed_job_done:
finetune_task.status = 'completed'
finetune_task.completed_at = datetime.utcnow()
db.session.commit()
logger.info(f"FinetuneBatch {finetune_task.id} completed - both jobs finished")
else:
logger.info(f"FinetuneBatch {finetune_task.id} not all jobs finished yet: original={original_job_done}, perturbed={perturbed_job_done}")
except Exception as e:
logger.error(f"Error checking finetune status: {str(e)}", exc_info=True)
def run_finetune_task(finetune_batch_id, batch_id, finetune_method, train_images_dir, output_model_dir,
class_dir, inference_prompts, is_perturbed=False, custom_params=None):
"""
执行微调任务
Args:
finetune_batch_id: 微调任务ID
batch_id: 扰动任务批次ID
finetune_method: 微调方法 (dreambooth, lora)
train_images_dir: 训练图片目录原始或扰动
output_model_dir: 模型输出目录
class_dir: 类别图片目录
inference_prompts: 推理提示词
is_perturbed: 是否是扰动图片训练
custom_params: 自定义参数
Returns:
任务执行结果
"""
from config.algorithm_config import AlgorithmConfig
from app import create_app, db
from app.database import FinetuneBatch, Batch, Image, ImageType
app = create_app()
with app.app_context():
try:
finetune_task = FinetuneBatch.query.get(finetune_batch_id)
if not finetune_task:
raise ValueError(f"FinetuneBatch {finetune_batch_id} not found")
batch = Batch.query.get(batch_id)
if not batch:
raise ValueError(f"Batch {batch_id} not found")
# 更新微调任务状态为处理中
if finetune_task.status == 'queued':
finetune_task.status = 'processing'
db.session.commit()
logger.info(f"Starting finetune task for FinetuneBatch {finetune_batch_id}, Batch {batch_id}")
logger.info(f"Method: {finetune_method}, Perturbed: {is_perturbed}")
# 确保目录存在
os.makedirs(output_model_dir, exist_ok=True)
os.makedirs(class_dir, exist_ok=True)
# 获取配置
use_real = AlgorithmConfig.USE_REAL_ALGORITHMS
if use_real:
# 使用真实微调算法
result = _run_real_finetune(
finetune_method, batch_id, train_images_dir, output_model_dir,
class_dir, inference_prompts, is_perturbed, custom_params
)
else:
# 使用虚拟微调实现
result = _run_virtual_finetune(
finetune_method, batch_id, train_images_dir, output_model_dir,
is_perturbed
)
# 保存生成的图片到数据库
_save_generated_images(batch_id, output_model_dir, is_perturbed)
# 检查两个任务是否都已完成
_check_and_update_finetune_status(finetune_task)
logger.info(f"Finetune task completed for FinetuneBatch {finetune_batch_id}")
return result
except Exception as e:
logger.error(f"Finetune task failed for FinetuneBatch {finetune_batch_id}: {str(e)}", exc_info=True)
# 更新微调任务状态为失败
if finetune_task:
finetune_task.status = 'failed'
finetune_task.error_message = str(e)
finetune_task.completed_at = datetime.utcnow()
db.session.commit()
raise
def _run_real_finetune(finetune_method, batch_id, train_images_dir, output_model_dir,
class_dir, inference_prompts, is_perturbed, custom_params):
"""运行真实微调算法"""
from config.algorithm_config import AlgorithmConfig
logger.info(f"Running real finetune: {finetune_method}")
# 获取微调脚本路径和环境
finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {})
script_path = finetune_config.get('real_script')
conda_env = finetune_config.get('conda_env')
default_params = finetune_config.get('default_params', {})
if not script_path:
raise ValueError(f"Finetune method {finetune_method} not configured")
# 合并参数
params = {**default_params, **(custom_params or {})}
# 构建命令行参数
cmd_args = [
f"--instance_data_dir={train_images_dir}",
f"--output_dir={output_model_dir}",
f"--class_data_dir={class_dir}",
]
# 添加is_perturbed标志
if is_perturbed:
cmd_args.append("--is_perturbed")
# 添加其他参数
for key, value in params.items():
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# 构建完整命令
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'accelerate', 'launch', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
image_type = 'perturbed' if is_perturbed else 'original'
log_file = os.path.join(
log_dir,
f'finetune_{image_type}_{batch_id}_{finetune_method}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True
)
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
if process.returncode != 0:
raise RuntimeError(f"Finetune failed with code {process.returncode}. Check log: {log_file}")
# 清理class_dir
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
import shutil
for item in os.listdir(class_dir):
item_path = os.path.join(class_dir, item)
if os.path.isfile(item_path):
os.remove(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
# 清理output_model_dir中的非图片文件
logger.info(f"Cleaning non-image files in output directory: {output_model_dir}")
if os.path.exists(output_model_dir):
import shutil
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'}
for item in os.listdir(output_model_dir):
item_path = os.path.join(output_model_dir, item)
# 如果是目录,直接删除
if os.path.isdir(item_path):
logger.info(f"Removing directory: {item_path}")
shutil.rmtree(item_path)
# 如果是文件,检查是否为图片
elif os.path.isfile(item_path):
_, ext = os.path.splitext(item.lower())
if ext not in image_extensions:
logger.info(f"Removing non-image file: {item_path}")
os.remove(item_path)
return {
'status': 'success',
'output_dir': output_model_dir,
'log_file': log_file
}
def _run_virtual_finetune(finetune_method, batch_id, train_images_dir, output_model_dir, is_perturbed):
"""运行虚拟微调实现"""
from config.algorithm_config import AlgorithmConfig
import glob
logger.info(f"Running virtual finetune: {finetune_method}")
# 获取微调配置
finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {})
if not finetune_config:
raise ValueError(f"Finetune method {finetune_method} not configured")
conda_env = finetune_config.get('conda_env')
default_params = finetune_config.get('default_params', {})
# 获取虚拟微调脚本路径
script_name = 'train_dreambooth_gen.py' if finetune_method == 'dreambooth' else 'train_lora_gen.py'
script_path = os.path.abspath(os.path.join(
os.path.dirname(__file__),
'../algorithms/finetune_virtual',
script_name
))
if not os.path.exists(script_path):
raise FileNotFoundError(f"Virtual finetune script not found: {script_path}")
logger.info(f"Virtual script path: {script_path}")
logger.info(f"Conda environment: {conda_env}")
# 创建输出目录
os.makedirs(output_model_dir, exist_ok=True)
validation_output_dir = os.path.join(output_model_dir, 'generated')
os.makedirs(validation_output_dir, exist_ok=True)
# 构建命令行参数(与真实微调参数一致)
cmd_args = [
f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}",
f"--instance_data_dir={train_images_dir}",
f"--output_dir={output_model_dir}",
f"--validation_image_output_dir={validation_output_dir}",
f"--class_data_dir=/tmp/class_placeholder",
]
# 添加is_perturbed标志
if is_perturbed:
cmd_args.append("--is_perturbed")
# 添加其他默认参数
for key, value in default_params.items():
if key == 'pretrained_model_name_or_path':
continue # 已添加
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# 使用conda run执行虚拟脚本
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'python', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
image_type = 'perturbed' if is_perturbed else 'original'
log_file = os.path.join(
log_dir,
f'virtual_{finetune_method}_{image_type}_{batch_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True
)
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
if process.returncode != 0:
raise RuntimeError(f"Virtual finetune failed with code {process.returncode}. Check log: {log_file}")
# 统计生成的图片
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
generated_files = []
for ext in image_extensions:
generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext)))
generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext.upper())))
logger.info(f"Virtual finetune completed. Generated {len(generated_files)} images")
return {
'status': 'success',
'output_dir': output_model_dir,
'generated_count': len(generated_files),
'generated_files': generated_files,
'log_file': log_file
}
def _save_generated_images(batch_id, output_model_dir, is_perturbed):
"""保存生成的图片到数据库"""
from app import db
from app.database import Batch, Image, ImageType
import glob
try:
batch = Batch.query.get(batch_id)
if not batch:
return
# 确定图片类型
if is_perturbed:
image_type = ImageType.query.filter_by(type_code='perturbed_generate').first()
else:
image_type = ImageType.query.filter_by(type_code='original_generate').first()
if not image_type:
logger.error(f"Image type not found for is_perturbed={is_perturbed}")
return
# 查找生成的图片
generated_dir = os.path.join(output_model_dir, 'generated')
if not os.path.exists(generated_dir):
# 尝试直接从output_model_dir查找
generated_dir = output_model_dir
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(generated_dir, ext)))
image_files.extend(glob.glob(os.path.join(generated_dir, ext.upper())))
logger.info(f"Found {len(image_files)} generated images to save")
# 保存到数据库
saved_count = 0
for image_path in image_files:
try:
from PIL import Image as PILImage
filename = os.path.basename(image_path)
# 检查是否已经保存过使用filename作为stored_filename
existing = Image.query.filter_by(
batch_id=batch_id,
stored_filename=filename
).first()
if existing:
logger.info(f"Image already exists: {filename}")
continue
with PILImage.open(image_path) as img:
width, height = img.size
# 生成图片不设置父图片关系(多对多关系,无法确定具体父图片)
# 创建图片记录直接使用filename算法已经生成了正确格式
generated_image = Image(
user_id=batch.user_id,
batch_id=batch_id,
father_id=None, # 微调生成图片无特定父图片
original_filename=filename,
stored_filename=filename, # 算法输出已经是正确格式
file_path=image_path,
file_size=os.path.getsize(image_path),
image_type_id=image_type.id,
width=width,
height=height
)
db.session.add(generated_image)
saved_count += 1
logger.info(f"Saved generated image: {filename}")
except Exception as e:
logger.error(f"Failed to save {image_path}: {str(e)}")
db.session.commit()
logger.info(f"Successfully saved {saved_count} generated images to database")
except Exception as e:
logger.error(f"Error saving generated images: {str(e)}")
db.session.rollback()
"""
RQ Worker 微调任务处理器 - 适配新数据库结构
仅支持真实算法移除虚拟算法调用
"""
import os
import subprocess
import logging
import glob
import shutil
from datetime import datetime
from PIL import Image as PILImage
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images_dir,
output_model_dir, class_dir, coords_save_path, validation_output_dir,
is_perturbed=False, custom_params=None):
"""
执行微调任务仅使用真实算法
Args:
task_id: 任务ID
finetune_config_id: 微调配置ID
finetune_method: 微调方法 (dreambooth, lora, textual_inversion)
train_images_dir: 训练图片目录
output_model_dir: 模型输出目录
class_dir: 类别图片目录
coords_save_path: 坐标保存路径
validation_output_dir: 验证图片输出目录
is_perturbed: 是否使用扰动图片训练
custom_params: 自定义参数
Returns:
任务执行结果
"""
from config.algorithm_config import AlgorithmConfig
from app import create_app, db
from app.database import Task, Finetune, DataType, TaskStatus
app = create_app()
with app.app_context():
try:
# 获取任务
task = Task.query.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
# 获取微调任务详情
finetune = Finetune.query.filter_by(
tasks_id=task_id,
finetune_configs_id=finetune_config_id
).first()
if not finetune:
raise ValueError(f"Finetune task ({task_id}, {finetune_config_id}) not found")
# 更新任务状态为处理中
processing_status = TaskStatus.query.filter_by(task_status_code='processing').first()
if processing_status:
task.tasks_status_id = processing_status.task_status_id
if not task.started_at:
task.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Starting finetune task {task_id} (config: {finetune_config_id})")
logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}")
# 从数据库获取数据集类型的提示词
# 从Finetune表的data_type_id获取
instance_prompt = "a photo of sks person" # 默认值
class_prompt = "a photo of person" # 默认值
validation_prompt = "a photo of sks person" # 默认值
if finetune.data_type_id:
data_type = DataType.query.get(finetune.data_type_id)
if data_type and data_type.data_type_prompt:
instance_prompt = data_type.data_type_prompt
validation_prompt = instance_prompt
# 从instance_prompt生成class_prompt移除"sks"
class_prompt = instance_prompt.replace('sks ', '')
logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'")
# 清空输出目录(避免旧文件残留)
logger.info(f"Clearing output directories...")
for dir_path in [output_model_dir, validation_output_dir, coords_save_path]:
if os.path.exists(dir_path):
for item in os.listdir(dir_path):
item_path = os.path.join(dir_path, item)
if os.path.isfile(item_path):
os.unlink(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
# 运行真实微调算法
result = _run_real_finetune(
finetune_method, task_id, train_images_dir, output_model_dir,
class_dir, coords_save_path, validation_output_dir,
instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params
)
# 保存生成的验证图片到数据库
_save_generated_images(task_id, validation_output_dir, is_perturbed)
# 更新任务状态为完成
completed_status = TaskStatus.query.filter_by(task_status_code='completed').first()
if completed_status:
task.tasks_status_id = completed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
logger.info(f"Finetune task {task_id} completed successfully")
return result
except Exception as e:
logger.error(f"Finetune task {task_id} failed: {str(e)}", exc_info=True)
# 更新任务状态为失败
try:
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
if failed_status:
task.tasks_status_id = failed_status.task_status_id
task.finished_at = datetime.utcnow()
task.error_message = str(e)
db.session.commit()
            except Exception:
pass
raise
def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_dir,
class_dir, coords_save_path, validation_output_dir,
instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params):
"""
运行真实微调算法参考sh脚本配置
Args:
finetune_method: 微调方法
task_id: 任务ID
train_images_dir: 训练图片目录
output_model_dir: 模型输出目录
class_dir: 类别数据目录
coords_save_path: 坐标保存路径
validation_output_dir: 验证图片输出目录
instance_prompt: 实例提示词
class_prompt: 类别提示词
validation_prompt: 验证提示词
is_perturbed: 是否使用扰动图片
custom_params: 自定义参数
"""
from config.algorithm_config import AlgorithmConfig
logger.info(f"Running real finetune: {finetune_method}")
logger.info(f"Instance prompt: '{instance_prompt}'")
logger.info(f"Class prompt: '{class_prompt}'")
logger.info(f"Validation prompt: '{validation_prompt}'")
# 获取微调脚本路径和环境
finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {})
script_path = finetune_config.get('real_script')
conda_env = finetune_config.get('conda_env')
    default_params = dict(finetune_config.get('default_params', {}))  # 拷贝一份,避免下面写入提示词时污染共享配置
if not script_path:
raise ValueError(f"Finetune method {finetune_method} not configured")
# 覆盖提示词参数(从数据库读取)
default_params['instance_prompt'] = instance_prompt
default_params['class_prompt'] = class_prompt
default_params['validation_prompt'] = validation_prompt
# 合并自定义参数
params = {**default_params, **(custom_params or {})}
# 根据微调方法构建命令参数参考sh脚本
cmd_args = [
f"--instance_data_dir={train_images_dir}",
f"--output_dir={output_model_dir}",
f"--validation_image_output_dir={validation_output_dir}",
]
if finetune_method == 'dreambooth':
# DreamBooth 特有参数
cmd_args.extend([
f"--class_data_dir={class_dir}",
f"--coords_save_path={coords_save_path}",
])
elif finetune_method == 'lora':
# LoRA 特有参数 (positions_save_path 等同于 coords_save_path)
cmd_args.extend([
f"--class_data_dir={class_dir}",
f"--positions_save_path={coords_save_path}",
])
elif finetune_method == 'textual_inversion':
# Textual Inversion 特有参数 (不需要 class_data_dir)
cmd_args.extend([
f"--coords_save_path={coords_save_path}",
])
else:
raise ValueError(f"Unsupported finetune method: {finetune_method}")
# 添加is_perturbed标志
if is_perturbed:
cmd_args.append("--is_perturbed")
# 添加其他默认参数
for key, value in params.items():
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# 设置环境变量
env = os.environ.copy()
env['HF_HUB_OFFLINE'] = '1' # 强制离线模式
env['CUDA_VISIBLE_DEVICES'] = '0' # 默认使用第一块GPU
# 构建完整命令
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'accelerate', 'launch', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
image_type = 'perturbed' if is_perturbed else 'original'
log_file = os.path.join(
log_dir,
f'finetune_{image_type}_{task_id}_{finetune_method}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
env=env
)
# 实时输出日志
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
logger.info(f"Finetune execution completed with return code: {process.returncode}")
logger.info(f"Output directory: {output_model_dir}")
logger.info(f"Log file: {log_file}")
if process.returncode != 0:
raise RuntimeError(f"Finetune failed with code {process.returncode}. Check log: {log_file}")
# 清理class_dir参考sh脚本
if finetune_method in ['dreambooth', 'lora']:
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
shutil.rmtree(class_dir)
os.makedirs(class_dir)
# 清理output_model_dir中的非图片文件
logger.info(f"Cleaning non-image files in output directory: {output_model_dir}")
if os.path.exists(output_model_dir):
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'}
for item in os.listdir(output_model_dir):
item_path = os.path.join(output_model_dir, item)
if os.path.isfile(item_path):
_, ext = os.path.splitext(item)
if ext.lower() not in image_extensions:
try:
os.remove(item_path)
logger.info(f"Removed non-image file: {item}")
except Exception as e:
logger.warning(f"Failed to remove {item}: {str(e)}")
return {
'status': 'success',
'output_dir': output_model_dir,
'log_file': log_file
}
def _save_generated_images(task_id, output_dir, is_perturbed):
"""
保存微调生成的验证图片到数据库适配新数据库结构
新数据库结构
- Task表tasks_id (主键), flow_id, tasks_type_id
- Image表images_id (主键), task_id (外键), image_types_id, father_id
- 生成图的father_id设置为输入图片的第一张
Args:
task_id: 任务ID
output_dir: 生成图片输出目录
is_perturbed: 是否为扰动图片训练生成
"""
from app import db
from app.database import Task, Image, ImageType
try:
# 验证任务存在
task = Task.query.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
# 获取图片类型
if is_perturbed:
generated_type = ImageType.query.filter_by(image_code='perturbed_generate').first()
input_type = ImageType.query.filter_by(image_code='perturbed').first()
else:
generated_type = ImageType.query.filter_by(image_code='original_generate').first()
input_type = ImageType.query.filter_by(image_code='original').first()
if not generated_type or not input_type:
raise ValueError("Required image types not found in database")
# 获取输入图片的第一张作为father_id
first_input_image = Image.query.filter_by(
task_id=task_id,
image_types_id=input_type.image_types_id
).order_by(Image.images_id.asc()).first()
father_id = first_input_image.images_id if first_input_image else None
logger.info(f"Will set father_id={father_id} for all generated images")
# 查找输出目录中的所有生成图片
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp', '*.tiff']
generated_files = []
for ext in image_extensions:
generated_files.extend(glob.glob(os.path.join(output_dir, ext)))
generated_files.extend(glob.glob(os.path.join(output_dir, ext.upper())))
logger.info(f"Found {len(generated_files)} generated images in output directory")
saved_count = 0
for generated_path in generated_files:
try:
generated_filename = os.path.basename(generated_path)
# 检查是否已存在
existing = Image.query.filter_by(
task_id=task_id,
stored_filename=generated_filename,
image_types_id=generated_type.image_types_id
).first()
if existing:
logger.info(f"Image {generated_filename} already exists, skipping")
continue
# 读取图片尺寸
width, height = None, None
try:
with PILImage.open(generated_path) as img:
width, height = img.size
except Exception as e:
logger.warning(f"Could not read image dimensions for {generated_filename}: {e}")
# 保存到数据库所有生成图的father_id统一设置为输入的第一张图片
generated_image = Image(
task_id=task_id,
image_types_id=generated_type.image_types_id,
father_id=father_id, # 统一设置为输入的第一张图片
stored_filename=generated_filename,
file_path=generated_path,
file_size=os.path.getsize(generated_path),
width=width,
height=height
)
db.session.add(generated_image)
saved_count += 1
logger.info(f"Saved generated image: {generated_filename} (father: {father_id})")
except Exception as e:
logger.error(f"Error saving generated image {generated_filename}: {str(e)}")
continue
db.session.commit()
logger.info(f"Successfully saved {saved_count} generated images to database")
except Exception as e:
logger.error(f"Error saving generated images: {str(e)}")
db.session.rollback()
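
同样给出一段假设性的入队示意:对同一任务分别以原图与扰动图各触发一次微调(queue、task、config、目录等变量均为示例):

for is_perturbed, train_dir in [(False, clean_dir), (True, perturbed_dir)]:
    queue.enqueue(
        run_finetune_task,
        task_id=task.tasks_id,
        finetune_config_id=config.finetune_configs_id,
        finetune_method='dreambooth',
        train_images_dir=train_dir,
        output_model_dir=f'{base}/model_{"b" if is_perturbed else "a"}',
        class_dir=f'{base}/class',
        coords_save_path=f'{base}/coords',
        validation_output_dir=f'{base}/generated',
        is_perturbed=is_perturbed,
        job_timeout=7200,
    )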

@ -0,0 +1,268 @@
"""
RQ Worker 热力图任务处理器 - 适配新数据库结构
生成原始图与扰动图的注意力差异热力图
仅支持真实算法移除虚拟算法调用
"""
import os
import subprocess
import logging
import shutil
from datetime import datetime
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def run_heatmap_task(task_id, original_image_path, perturbed_image_path,
output_dir, model_path, perturbed_image_id=None):
"""
执行热力图生成任务仅使用真实算法
Args:
task_id: 任务ID
original_image_path: 原始图片路径
perturbed_image_path: 扰动图片路径
output_dir: 输出目录
model_path: Stable Diffusion模型路径
perturbed_image_id: 扰动图片ID用于建立father关系
Returns:
任务执行结果
"""
from config.algorithm_config import AlgorithmConfig
from app import create_app, db
from app.database import Heatmap, Task, TaskStatus, DataType, Perturbation
app = create_app()
with app.app_context():
try:
# 获取任务
task = Task.query.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
# 获取热力图任务详情
heatmap = Heatmap.query.get(task_id)
if not heatmap:
raise ValueError(f"Heatmap task {task_id} not found")
# 更新任务状态为处理中
processing_status = TaskStatus.query.filter_by(task_status_code='processing').first()
if processing_status:
task.tasks_status_id = processing_status.task_status_id
task.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Starting heatmap task {task_id}")
# 从数据库获取提示词从关联的Perturbation任务获取
prompt_text = "a photo of sks person" # 默认值
target_word = "person" # 默认值
# 通过flow_id查找关联的Perturbation任务
perturbation_tasks = Task.query.filter_by(
flow_id=task.flow_id,
tasks_type_id=1 # perturbation类型
).all()
if perturbation_tasks:
for pert_task in perturbation_tasks:
perturbation = Perturbation.query.get(pert_task.tasks_id)
if perturbation and perturbation.data_type_id:
data_type = DataType.query.get(perturbation.data_type_id)
if data_type and data_type.data_type_prompt:
prompt_text = data_type.data_type_prompt
                            # 提取target_word(去除"sks"后取末尾的词作为目标词)
words = prompt_text.replace('sks ', '').split()
if words:
target_word = words[-1] # 取最后一个词作为target
logger.info(f"Using prompts from database - prompt: '{prompt_text}', target: '{target_word}'")
break
# 确保目录存在并清空
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Clearing output directory: {output_dir}")
for item in os.listdir(output_dir):
item_path = os.path.join(output_dir, item)
if os.path.isfile(item_path):
os.unlink(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
# 运行真实热力图算法
result = _run_real_heatmap(
task_id, original_image_path, perturbed_image_path,
prompt_text, target_word, output_dir, model_path
)
# 保存热力图文件到数据库
heatmap_file = os.path.join(output_dir, 'heatmap_dif.png')
if os.path.exists(heatmap_file):
heatmap.heatmap_name = 'heatmap_dif.png'
# 保存热力图到Image表
_save_heatmap_image(task_id, heatmap_file, perturbed_image_id)
db.session.commit()
# 更新任务状态为完成
completed_status = TaskStatus.query.filter_by(task_status_code='completed').first()
if completed_status:
task.tasks_status_id = completed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
logger.info(f"Heatmap task {task_id} completed")
return result
except Exception as e:
logger.error(f"Heatmap task {task_id} failed: {str(e)}", exc_info=True)
# 更新任务状态为失败
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
if failed_status:
task.tasks_status_id = failed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
return {'success': False, 'error': str(e)}
def _run_real_heatmap(task_id, original_image_path, perturbed_image_path,
prompt_text, target_word, output_dir, model_path):
"""运行真实热力图算法"""
from config.algorithm_config import AlgorithmConfig
logger.info(f"Running real heatmap generation")
# 获取热力图脚本配置
evaluate_config = AlgorithmConfig.EVALUATE_SCRIPTS.get('heatmap', {})
script_path = evaluate_config.get('real_script')
conda_env = evaluate_config.get('conda_env')
if not script_path:
raise ValueError("Heatmap script not configured")
# 构建命令行参数
cmd_args = [
f"--model_path={model_path}",
f"--image_path_a={original_image_path}",
f"--image_path_b={perturbed_image_path}",
f"--prompt_text={prompt_text}",
f"--target_word={target_word}",
f"--output_dir={output_dir}",
]
# 构建完整命令
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'python', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置环境变量(强制离线模式)
env = os.environ.copy()
env['HF_HUB_OFFLINE'] = '1'
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(
log_dir,
f'heatmap_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
env=env
)
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
if process.returncode != 0:
raise RuntimeError(f"Heatmap generation failed with code {process.returncode}. Check log: {log_file}")
return {
'status': 'success',
'output_dir': output_dir,
'log_file': log_file
}
def _save_heatmap_image(task_id, heatmap_file_path, father_image_id=None):
"""
保存热力图到数据库Image表
Args:
task_id: 任务ID
heatmap_file_path: 热力图文件完整路径
father_image_id: 父图片ID(原始图片ID)
"""
from app import db
from app.database import Image, ImageType
from PIL import Image as PILImage
try:
# 获取热力图图片类型
heatmap_type = ImageType.query.filter_by(image_code='heatmap').first()
if not heatmap_type:
logger.error("Image type 'heatmap' not found")
return
# 获取文件名
heatmap_filename = os.path.basename(heatmap_file_path)
# 检查是否已经保存过
existing = Image.query.filter_by(
task_id=task_id,
stored_filename=heatmap_filename,
image_types_id=heatmap_type.image_types_id
).first()
if existing:
logger.info(f"Heatmap image {heatmap_filename} already exists, skipping")
return
# 读取图片尺寸
try:
with PILImage.open(heatmap_file_path) as img:
width, height = img.size
        except Exception:
width, height = None, None
# 保存到数据库
heatmap_image = Image(
task_id=task_id,
image_types_id=heatmap_type.image_types_id,
father_id=father_image_id, # 设置父图片关系
stored_filename=heatmap_filename,
file_path=heatmap_file_path,
file_size=os.path.getsize(heatmap_file_path),
width=width,
height=height
)
db.session.add(heatmap_image)
db.session.commit()
logger.info(f"Saved heatmap image: {heatmap_filename} (father: {father_image_id})")
except Exception as e:
logger.error(f"Error saving heatmap image: {str(e)}")
db.session.rollback()
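
热力图与加噪两个 worker 都从 data_type_prompt 推导提示词,这段逻辑可以抽成一个纯函数,便于单测;下面是按现有代码行为整理的假设性示意:

def derive_prompts(data_type_prompt, fallback='a photo of sks person'):
    """从数据集提示词推导 instance/class prompt 与 target_word(与 worker 内逻辑一致)。"""
    prompt_text = data_type_prompt or fallback
    class_prompt = prompt_text.replace('sks ', '')  # 去掉占位标识词 sks
    words = class_prompt.split()
    target_word = words[-1] if words else 'person'  # 与 worker 一致:取末尾的词
    return prompt_text, class_prompt, target_word

# derive_prompts('a photo of sks person')
# -> ('a photo of sks person', 'a photo of person', 'person')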

@ -1,422 +1,394 @@
"""
RQ Worker任务处理器
在后台执行对抗性扰动算法
"""
import os
import sys
import subprocess
import logging
from datetime import datetime
from pathlib import Path
# 设置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def run_perturbation_task(batch_id, algorithm_code, epsilon, use_strong_protection,
input_dir, output_dir, class_dir, custom_params=None):
"""
执行对抗性扰动任务
Args:
batch_id: 任务批次ID
algorithm_code: 算法代码
epsilon: 扰动强度
use_strong_protection: 是否使用防净化版本
input_dir: 输入图片目录
output_dir: 输出目录
class_dir: 类别图片目录
custom_params: 自定义参数
Returns:
任务执行结果
"""
from config.algorithm_config import AlgorithmConfig
from app import create_app, db
from app.database import Batch
# 创建应用上下文
app = create_app()
with app.app_context():
try:
# 更新任务状态
batch = Batch.query.get(batch_id)
if not batch:
raise ValueError(f"Batch {batch_id} not found")
batch.status = 'processing'
batch.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Starting perturbation task for batch {batch_id}")
logger.info(f"Algorithm: {algorithm_code}, Epsilon: {epsilon}")
# 获取算法配置
use_real = AlgorithmConfig.USE_REAL_ALGORITHMS
script_path = AlgorithmConfig.get_script_path(algorithm_code)
conda_env = AlgorithmConfig.get_conda_env(algorithm_code)
# 确保目录存在
os.makedirs(output_dir, exist_ok=True)
os.makedirs(class_dir, exist_ok=True)
if use_real:
# 使用真实算法
result = _run_real_algorithm(
script_path, conda_env, algorithm_code, batch_id,
epsilon, use_strong_protection, input_dir, output_dir,
class_dir, custom_params
)
else:
# 使用虚拟实现
result = _run_virtual_algorithm(
algorithm_code, batch_id, epsilon, use_strong_protection,
input_dir, output_dir
)
# 更新任务状态为完成
batch.status = 'completed'
batch.completed_at = datetime.utcnow()
db.session.commit()
# 保存扰动图片到数据库
_save_perturbed_images(batch_id, output_dir)
logger.info(f"Task completed successfully for batch {batch_id}")
return result
except Exception as e:
logger.error(f"Task failed for batch {batch_id}: {str(e)}", exc_info=True)
# 更新任务状态为失败
if batch:
batch.status = 'failed'
batch.error_message = str(e)
batch.completed_at = datetime.utcnow()
db.session.commit()
raise
def _run_real_algorithm(script_path, conda_env, algorithm_code, batch_id,
epsilon, use_strong_protection, input_dir, output_dir,
class_dir, custom_params):
"""运行真实算法"""
from config.algorithm_config import AlgorithmConfig
logger.info(f"Running real algorithm: {algorithm_code}")
logger.info(f"Conda environment: {conda_env}")
logger.info(f"Script path: {script_path}")
# 获取默认参数
default_params = AlgorithmConfig.get_default_params(algorithm_code)
# 合并自定义参数
params = {**default_params, **(custom_params or {})}
cmd_args = []
if algorithm_code == 'aspl':
cmd_args.extend([
f"--instance_data_dir_for_train={input_dir}",
f"--instance_data_dir_for_adversarial={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir={class_dir}",
f"--pgd_eps={str(epsilon)}",
])
elif algorithm_code == 'simac':
cmd_args.extend([
f"--instance_data_dir_for_train={input_dir}",
f"--instance_data_dir_for_adversarial={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir={class_dir}",
f"--pgd_eps={str(epsilon)}",
])
elif algorithm_code == 'caat':
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={str(epsilon)}",
])
elif algorithm_code == 'pid':
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={str(epsilon)}",
])
else:
raise ValueError(f"Unsupported algorithm code: {algorithm_code}")
# 添加其他参数
for key, value in params.items():
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# 构建完整命令
# 使用conda run避免环境嵌套问题
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'accelerate', 'launch', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f'batch_{batch_id}_{algorithm_code}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True
)
# 实时输出日志
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
logger.info(f"output_dir: {output_dir}")
logger.info(f"log_file: {log_file}")
if process.returncode != 0:
raise RuntimeError(f"Algorithm execution failed with code {process.returncode}. Check log: {log_file}")
# 清理class_dir
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
import shutil
for item in os.listdir(class_dir):
item_path = os.path.join(class_dir, item)
if os.path.isfile(item_path):
os.remove(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
return {
'status': 'success',
'output_dir': output_dir,
'log_file': log_file
}
def _run_virtual_algorithm(algorithm_code, batch_id, epsilon, use_strong_protection,
input_dir, output_dir):
"""运行虚拟算法实现"""
from config.algorithm_config import AlgorithmConfig
import glob
logger.info(f"Running virtual algorithm: {algorithm_code}")
# 获取算法配置
algo_config = AlgorithmConfig.PERTURBATION_SCRIPTS.get(algorithm_code)
if not algo_config:
raise ValueError(f"Algorithm {algorithm_code} not configured")
conda_env = algo_config.get('conda_env')
default_params = algo_config.get('default_params', {})
# 获取虚拟算法脚本路径
script_path = os.path.abspath(os.path.join(
os.path.dirname(__file__),
'../algorithms/perturbation_virtual',
f'{algorithm_code}.py'
))
if not os.path.exists(script_path):
raise FileNotFoundError(f"Virtual script not found: {script_path}")
logger.info(f"Virtual script path: {script_path}")
logger.info(f"Conda environment: {conda_env}")
# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)
# 构建命令行参数(与真实算法参数一致)
cmd_args = [
f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}",
f"--instance_data_dir_for_train={input_dir}",
f"--instance_data_dir_for_adversarial={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir=/tmp/class_placeholder",
f"--pgd_eps={epsilon}",
]
# 添加其他默认参数
for key, value in default_params.items():
if key == 'pretrained_model_name_or_path':
continue # 已添加
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# 使用conda run执行虚拟脚本
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'python', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(
log_dir,
f'virtual_{algorithm_code}_{batch_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True
)
# 实时输出日志
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
if process.returncode != 0:
raise RuntimeError(f"Virtual algorithm failed with code {process.returncode}. Check log: {log_file}")
# 统计处理的图片
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
processed_files = []
for ext in image_extensions:
processed_files.extend(glob.glob(os.path.join(output_dir, ext)))
processed_files.extend(glob.glob(os.path.join(output_dir, ext.upper())))
logger.info(f"Virtual algorithm completed. Processed {len(processed_files)} images")
return {
'status': 'success',
'output_dir': output_dir,
'processed_count': len(processed_files),
'processed_files': processed_files,
'log_file': log_file
}
def _save_perturbed_images(batch_id, output_dir):
"""保存扰动图片到数据库"""
from app import db
from app.database import Batch, Image, ImageType
import glob
from PIL import Image as PILImage
try:
batch = Batch.query.get(batch_id)
if not batch:
logger.error(f"Batch {batch_id} not found")
return
# 获取扰动图片类型
perturbed_type = ImageType.query.filter_by(type_code='perturbed').first()
if not perturbed_type:
logger.error("Perturbed image type not found")
return
# 获取原始图片列表
original_type = ImageType.query.filter_by(type_code='original').first()
original_images = Image.query.filter_by(
batch_id=batch_id,
image_type_id=original_type.id
).all()
# 创建原图映射字典: stored_filename -> Image对象
original_map = {img.stored_filename: img for img in original_images}
# 查找输出目录中的扰动图片
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
perturbed_files = []
for ext in image_extensions:
perturbed_files.extend(glob.glob(os.path.join(output_dir, ext)))
perturbed_files.extend(glob.glob(os.path.join(output_dir, ext.upper())))
logger.info(f"Found {len(perturbed_files)} perturbed images to save")
saved_count = 0
for perturbed_path in perturbed_files:
try:
filename = os.path.basename(perturbed_path)
# 扰动图片命名格式: perturbed_{原图名}.ext
# 提取原图名
parent_image = None
if filename.startswith('perturbed_'):
# 去掉perturbed_前缀得到原图名
original_filename = filename[len('perturbed_'):]
# 尝试从映射中查找
parent_image = original_map.get(original_filename)
if not parent_image:
logger.warning(f"Parent image not found for {filename}, original should be: {original_filename}")
# 获取图片尺寸
with PILImage.open(perturbed_path) as img:
width, height = img.size
# 检查是否已经保存过使用filename作为stored_filename
existing = Image.query.filter_by(
batch_id=batch_id,
stored_filename=filename
).first()
if existing:
logger.info(f"Image already exists: {filename}")
continue
# 创建扰动图片记录直接使用filename因为算法已经添加了perturbed_前缀
perturbed_image = Image(
user_id=batch.user_id,
batch_id=batch_id,
father_id=parent_image.id if parent_image else None,
original_filename=filename,
stored_filename=filename, # 算法输出已经是perturbed_格式
file_path=perturbed_path,
file_size=os.path.getsize(perturbed_path),
image_type_id=perturbed_type.id,
width=width,
height=height
)
db.session.add(perturbed_image)
saved_count += 1
logger.info(f"Saved perturbed image: {filename} (parent: {parent_image.stored_filename if parent_image else 'None'})")
except Exception as e:
logger.error(f"Failed to save {perturbed_path}: {str(e)}")
db.session.commit()
logger.info(f"Successfully saved {saved_count} perturbed images to database")
except Exception as e:
logger.error(f"Error saving perturbed images: {str(e)}")
db.session.rollback()
"""
RQ Worker任务处理器 - 加噪任务
适配新数据库结构: Task + Perturbation + Images
仅支持真实算法移除虚拟算法调用
"""
import os
import subprocess
import logging
import glob
import shutil
from datetime import datetime
from pathlib import Path
from PIL import Image as PILImage
# 设置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def run_perturbation_task(task_id, algorithm_code, epsilon, input_dir, output_dir,
class_dir, custom_params=None):
"""
执行对抗性扰动任务仅使用真实算法
Args:
task_id: 任务ID对应 tasks 表的 tasks_id
algorithm_code: 算法代码 (aspl/simac/caat/pid)
epsilon: 扰动强度
input_dir: 输入图片目录
output_dir: 输出目录
class_dir: 类别图片目录aspl/simac需要
custom_params: 自定义参数字典
Returns:
任务执行结果
"""
from config.algorithm_config import AlgorithmConfig
from app import create_app, db
from app.database import Task, Perturbation, TaskStatus
# 创建应用上下文
app = create_app()
with app.app_context():
try:
# 获取任务
task = Task.query.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
# 获取加噪任务详情
perturbation = Perturbation.query.get(task_id)
if not perturbation:
raise ValueError(f"Perturbation task {task_id} not found")
# 更新任务状态为处理中
processing_status = TaskStatus.query.filter_by(task_status_code='processing').first()
if processing_status:
task.tasks_status_id = processing_status.task_status_id
task.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Starting perturbation task {task_id}")
logger.info(f"Algorithm: {algorithm_code}, Epsilon: {epsilon}")
# 获取算法配置
script_path = AlgorithmConfig.get_script_path(algorithm_code)
conda_env = AlgorithmConfig.get_conda_env(algorithm_code)
# 确保目录存在
os.makedirs(output_dir, exist_ok=True)
os.makedirs(class_dir, exist_ok=True)
# 清空输出目录(避免旧文件残留)
logger.info(f"Clearing output directory: {output_dir}")
if os.path.exists(output_dir):
for item in os.listdir(output_dir):
item_path = os.path.join(output_dir, item)
if os.path.isfile(item_path):
os.unlink(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
# 运行真实算法
result = _run_real_algorithm(
script_path, conda_env, algorithm_code, task_id,
epsilon, input_dir, output_dir, class_dir, custom_params
)
# 保存扰动图片到数据库
_save_perturbed_images(task_id, output_dir)
# 更新任务状态为完成
completed_status = TaskStatus.query.filter_by(task_status_code='completed').first()
if completed_status:
task.tasks_status_id = completed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
logger.info(f"Perturbation task {task_id} completed successfully")
return result
except Exception as e:
logger.error(f"Perturbation task {task_id} failed: {str(e)}", exc_info=True)
# 更新任务状态为失败
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
if failed_status:
task.tasks_status_id = failed_status.task_status_id
task.finished_at = datetime.utcnow()
task.error_message = str(e)
db.session.commit()
raise
def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
epsilon, input_dir, output_dir, class_dir, custom_params):
"""
运行真实算法参考sh脚本配置
Args:
script_path: 算法脚本路径
conda_env: Conda环境名称
algorithm_code: 算法代码
task_id: 任务ID
epsilon: 扰动强度
input_dir: 输入目录
output_dir: 输出目录
class_dir: 类别数据目录
custom_params: 自定义参数
"""
from config.algorithm_config import AlgorithmConfig
from app import db
from app.database import Perturbation, DataType
logger.info(f"Running real algorithm: {algorithm_code}")
logger.info(f"Conda environment: {conda_env}")
logger.info(f"Script path: {script_path}")
# 从数据库获取数据集类型的提示词
perturbation = Perturbation.query.get(task_id)
if not perturbation:
raise ValueError(f"Perturbation task {task_id} not found")
data_type = DataType.query.get(perturbation.data_type_id)
if not data_type:
raise ValueError(f"DataType {perturbation.data_type_id} not found")
# 从data_type_prompt中提取instance_prompt
instance_prompt = data_type.data_type_prompt or 'a photo of sks person'
# 从instance_prompt生成class_prompt移除"sks"
class_prompt = instance_prompt.replace('sks ', '')
logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'")
# 获取默认参数
    default_params = dict(AlgorithmConfig.get_default_params(algorithm_code))  # 拷贝一份,避免下面写入提示词时污染共享配置
# 覆盖提示词参数(从数据库读取)
default_params['instance_prompt'] = instance_prompt
if 'class_prompt' in default_params:
default_params['class_prompt'] = class_prompt
# 合并自定义参数
params = {**default_params, **(custom_params or {})}
# 根据算法构建命令参数参考sh脚本
cmd_args = []
if algorithm_code in ['aspl', 'simac']:
# ASPL和SimAC使用相同的参数结构
cmd_args.extend([
f"--instance_data_dir_for_train={input_dir}",
f"--instance_data_dir_for_adversarial={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir={class_dir}",
f"--pgd_eps={epsilon}",
])
elif algorithm_code == 'caat':
# CAAT argument structure
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={epsilon}",
])
elif algorithm_code == 'pid':
# PID argument structure
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={epsilon}",
])
else:
raise ValueError(f"Unsupported algorithm code: {algorithm_code}")
# Append the remaining default parameters
for key, value in params.items():
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# Set up environment variables
env = os.environ.copy()
env['HF_HUB_OFFLINE'] = '1' # Force offline mode
# Assemble the full command
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'accelerate', 'launch', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# Set up the log file
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(
log_dir,
f'task_{task_id}_{algorithm_code}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# Execute the command
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
env=env
)
# Stream the log output in real time
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
logger.info(f"Algorithm execution completed with return code: {process.returncode}")
logger.info(f"Output directory: {output_dir}")
logger.info(f"Log file: {log_file}")
if process.returncode != 0:
raise RuntimeError(
f"Algorithm execution failed with code {process.returncode}. Check log: {log_file}"
)
# Clean up class_dir
if algorithm_code in ['aspl', 'simac']:
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
shutil.rmtree(class_dir)
os.makedirs(class_dir)
return {
'status': 'success',
'output_dir': output_dir,
'log_file': log_file
}
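For reference, a minimal self-contained sketch of how the parameter loop above expands a params dict into CLI arguments: a boolean becomes a bare flag when true and is omitted when false, while everything else becomes `--key=value`. The sample values here are illustrative only, not the real defaults:

sample_params = {
    'center_crop': True,               # True  -> bare flag
    'with_prior_preservation': False,  # False -> omitted entirely
    'resolution': 512,                 # other -> --key=value
}
cmd_args = []
for key, value in sample_params.items():
    if isinstance(value, bool):
        if value:
            cmd_args.append(f"--{key}")
    else:
        cmd_args.append(f"--{key}={value}")
print(cmd_args)  # ['--center_crop', '--resolution=512']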
def _save_perturbed_images(task_id, output_dir):
"""
保存扰动图片到数据库适配新数据库结构
新数据库结构
- Task表tasks_id (主键), flow_id, tasks_type_id
- Image表images_id (主键), task_id (外键), image_types_id, father_id
Args:
task_id: 任务ID
output_dir: 扰动图片输出目录
"""
from app import db
from app.database import Task, Image, ImageType
try:
# Verify that the task exists
task = Task.query.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
# Fetch the image types
original_type = ImageType.query.filter_by(image_code='original').first()
perturbed_type = ImageType.query.filter_by(image_code='perturbed').first()
if not original_type or not perturbed_type:
raise ValueError("Required image types not found in database")
# Fetch all original images of this task (used to establish parent-child links)
original_images = Image.query.filter_by(
task_id=task_id,
image_types_id=original_type.image_types_id
).all()
# Build a filename-to-image map for the originals
original_map = {img.stored_filename: img for img in original_images}
logger.info(f"Found {len(original_images)} original images for task {task_id}")
# Collect all image files in the output directory
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp', '*.tiff']
perturbed_files = []
for ext in image_extensions:
perturbed_files.extend(glob.glob(os.path.join(output_dir, ext)))
perturbed_files.extend(glob.glob(os.path.join(output_dir, ext.upper())))
logger.info(f"Found {len(perturbed_files)} perturbed images in output directory")
saved_count = 0
for perturbed_path in perturbed_files:
try:
perturbed_filename = os.path.basename(perturbed_path)
# Try to match the original image (to establish the parent link);
# algorithms may emit the same filename or a prefixed one
father_image = None
# Strategy 1: exact filename match
if perturbed_filename in original_map:
father_image = original_map[perturbed_filename]
else:
# Strategy 2: strip a possible prefix (e.g. perturbed_)
for prefix in ['perturbed_', 'adv_', 'protected_']:
if perturbed_filename.startswith(prefix):
clean_name = perturbed_filename[len(prefix):]
if clean_name in original_map:
father_image = original_map[clean_name]
break
if not father_image:
logger.warning(f"Could not find father image for {perturbed_filename}, saving without father_id")
# Skip if the record already exists
existing = Image.query.filter_by(
task_id=task_id,
stored_filename=perturbed_filename,
image_types_id=perturbed_type.image_types_id
).first()
if existing:
logger.info(f"Image {perturbed_filename} already exists, skipping")
continue
# Read the image dimensions
width, height = None, None
try:
with PILImage.open(perturbed_path) as img:
width, height = img.size
except Exception as e:
logger.warning(f"Could not read image dimensions for {perturbed_filename}: {e}")
# Save to the database
perturbed_image = Image(
task_id=task_id,
image_types_id=perturbed_type.image_types_id,
father_id=father_image.images_id if father_image else None,
stored_filename=perturbed_filename,
file_path=perturbed_path,
file_size=os.path.getsize(perturbed_path),
width=width,
height=height
)
db.session.add(perturbed_image)
saved_count += 1
logger.info(
f"Saved: {perturbed_filename} "
f"(father: {father_image.stored_filename if father_image else 'None'})"
)
except Exception as e:
logger.error(f"Error saving {perturbed_filename}: {str(e)}")
continue
# Commit all changes
db.session.commit()
logger.info(f"Successfully saved {saved_count}/{len(perturbed_files)} perturbed images")
except Exception as e:
logger.error(f"Error in _save_perturbed_images: {str(e)}", exc_info=True)
db.session.rollback()
raise
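The two matching strategies above can also be read as one pure function, which makes them easy to unit-test. A sketch only (the `find_father_image` name is hypothetical, not part of this PR):

def find_father_image(perturbed_filename, original_map,
                      prefixes=('perturbed_', 'adv_', 'protected_')):
    """Return the original Image record matching a perturbed filename, or None."""
    # Strategy 1: exact filename match
    if perturbed_filename in original_map:
        return original_map[perturbed_filename]
    # Strategy 2: strip a known algorithm prefix and retry
    for prefix in prefixes:
        if perturbed_filename.startswith(prefix):
            clean_name = perturbed_filename[len(prefix):]
            if clean_name in original_map:
                return original_map[clean_name]
    return None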

@ -40,6 +40,7 @@ class AlgorithmConfig:
'pid': os.getenv('CONDA_ENV_PID', 'pid'),
'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'pid'),
'lora': os.getenv('CONDA_ENV_LORA', 'pid'),
'textual_inversion': os.getenv('CONDA_ENV_TI', 'pid'),
}
# Model path configuration
@ -157,7 +158,7 @@ class AlgorithmConfig:
# ========== Fine-tuning algorithm configuration ==========
FINETUNE_SCRIPTS = {
'dreambooth': {
'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_dreambooth_gen.py'),
'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_db_gen_trace.py'),
'virtual_script': None, # the virtual implementation lives in the worker
'conda_env': CONDA_ENVS['dreambooth'],
'default_params': {
@ -169,23 +170,24 @@ class AlgorithmConfig:
'resolution': 512,
'train_batch_size': 1,
'gradient_accumulation_steps': 1,
'learning_rate': 1e-4,
'learning_rate': 2e-6,
'lr_scheduler': 'constant',
'lr_warmup_steps': 0,
'num_class_images': 1,
'max_train_steps': 1,
'checkpointing_steps': 1,
'num_class_images': 200,
'max_train_steps': 1000,
'checkpointing_steps': 500,
'center_crop': True,
'mixed_precision': 'bf16',
'prior_generation_precision': 'bf16',
'sample_batch_size': 1,
'sample_batch_size': 5,
'validation_prompt': 'a photo of sks person',
'num_validation_images': 1,
'validation_steps': 1
'num_validation_images': 10,
'validation_steps': 500,
'coords_log_interval': 10
}
},
'lora': {
'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_lora_gen.py'),
'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_lora_gen_trace.py'),
'virtual_script': None,
'conda_env': CONDA_ENVS['lora'],
'default_params': {
@ -200,14 +202,40 @@ class AlgorithmConfig:
'learning_rate': 1e-4,
'lr_scheduler': 'constant',
'lr_warmup_steps': 0,
'num_class_images': 1,
'max_train_steps': 1,
'checkpointing_steps': 1,
'num_class_images': 200,
'max_train_steps': 1000,
'checkpointing_steps': 500,
'seed': 0,
'mixed_precision': 'fp16',
'rank': 4,
'validation_prompt': 'a photo of sks person',
'num_validation_images': 1
'num_validation_images': 10,
'coords_log_interval': 10
}
},
'textual_inversion': {
'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_ti_gen_trace.py'),
'virtual_script': None,
'conda_env': CONDA_ENVS['textual_inversion'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'placeholder_token': 'sks',
'initializer_token': 'person',
'instance_prompt': 'a photo of sks person',
'resolution': 512,
'train_batch_size': 1,
'gradient_accumulation_steps': 1,
'learning_rate': 5e-4,
'lr_scheduler': 'constant',
'lr_warmup_steps': 0,
'max_train_steps': 1000,
'checkpointing_steps': 500,
'seed': 0,
'mixed_precision': 'fp16',
'validation_prompt': 'a photo of sks person',
'num_validation_images': 4,
'validation_epochs': 50,
'coords_log_interval': 10
}
}
}
@ -216,3 +244,28 @@ class AlgorithmConfig:
def get_finetune_config(cls, finetune_method):
"""获取微调算法配置"""
return cls.FINETUNE_SCRIPTS.get(finetune_method, {})
# ========== Evaluation algorithm configuration ==========
EVALUATE_SCRIPTS = {
'heatmap': {
'real_script': os.path.join(ALGORITHMS_DIR, 'evaluate', 'eva_gen_heatmap.py'),
'virtual_script': None,
'conda_env': CONDA_ENVS['pid'], # reuse the fine-tuning environment
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
}
},
'numbers': {
'real_script': os.path.join(ALGORITHMS_DIR, 'evaluate', 'eva_gen_nums.py'),
'virtual_script': None,
'conda_env': CONDA_ENVS['pid'],
'default_params': {
'image_size': 512,
}
}
}
@classmethod
def get_evaluate_config(cls, evaluate_method):
"""获取评估算法配置"""
return cls.EVALUATE_SCRIPTS.get(evaluate_method, {})
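Usage sketch for the new accessor: a worker could resolve an evaluation task roughly like this (illustrative only; the surrounding variable names are hypothetical):

config = AlgorithmConfig.get_evaluate_config('heatmap')
script_path = config.get('real_script')
conda_env = config.get('conda_env')
params = dict(config.get('default_params', {}))  # copy before overriding
# params can then be merged with task-specific overrides, as the
# perturbation worker does with its own default_params.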

@ -51,10 +51,18 @@ class Config:
# Image processing configuration
ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'original') # renamed original images
PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # perturbed images
MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # root directory for model-generated images
MODEL_UPLOADED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'uploaded') # model outputs for uploaded images
MODEL_ORIGINAL_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'original') # model outputs for original images
MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # model outputs for perturbed images
HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # heatmaps
# Fine-tuning configuration
CLASS_DATA_FOLDER = os.path.join(STATIC_ROOT, 'class') # class-data directory (for prior preservation)
# Visualization and analysis configuration
EVA_RES_FOLDER = os.path.join(STATIC_ROOT, 'eva_res') # root directory for evaluation results
COORDS_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'position') # 3D coordinate data for training-trajectory visualization
POSITIONS_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'position') # position data (same path as coords; intended for LoRA, currently unused)
HEATDIF_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'heatdif') # heatmap-difference data
NUMBERS_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'numbers') # numeric evaluation results
# Preset demo image configuration
DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # root directory for demo images
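Since several of these folders are written to by background workers, a startup hook could create them all up front. A hypothetical snippet (not part of this PR; the import path of Config is an assumption):

import os
from config import Config  # import path is an assumption

for folder in (Config.ORIGINAL_IMAGES_FOLDER, Config.PERTURBED_IMAGES_FOLDER,
               Config.MODEL_OUTPUTS_FOLDER, Config.HEATMAP_FOLDER,
               Config.CLASS_DATA_FOLDER, Config.EVA_RES_FOLDER,
               Config.COORDS_SAVE_FOLDER, Config.HEATDIF_SAVE_FOLDER,
               Config.NUMBERS_SAVE_FOLDER, Config.DEMO_IMAGES_FOLDER):
    os.makedirs(folder, exist_ok=True)  # idempotent, safe to re-run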

@ -15,9 +15,9 @@ def init_database():
# Initialize role data
roles = [
{'role_id': 0, 'name': '管理员', 'max_concurrent_tasks': 15, 'description': '系统管理员,拥有最高权限'},
{'role_id': 1, 'name': 'VIP用户', 'max_concurrent_tasks': 10, 'description': '付费用户,享有较高的资源使用权限'},
{'role_id': 2, 'name': '普通用户', 'max_concurrent_tasks': 5, 'description': '免费用户,享有基本的资源使用权限'}
{'role_id': 1, 'role_code': 'admin', 'name': '管理员', 'max_concurrent_tasks': 15, 'description': '系统管理员,拥有最高权限'},
{'role_id': 2, 'role_code': 'vip', 'name': 'VIP用户', 'max_concurrent_tasks': 10, 'description': '付费用户,享有较高的资源使用权限'},
{'role_id': 3, 'role_code': 'normal', 'name': '普通用户', 'max_concurrent_tasks': 5, 'description': '免费用户,享有基本的资源使用权限'}
]
for role_data in roles:
existing = Role.query.filter_by(role_id=role_data['role_id']).first()
@ -26,29 +26,31 @@ def init_database():
db.session.add(new_role)
# Initialize task status data
task_status = [
{'status_code': 'waiting', 'status_name': '待处理', 'description': '任务已创建,等待处理'},
{'status_code': 'processing', 'status_name': '进行中', 'description': '任务正在处理中'},
{'status_code': 'completed', 'status_name': '已完成', 'description':'任务已成功完成'},
{'status_code': 'failed', 'status_name': '失败', 'description': '任务处理失败'}
task_statuses = [
{'task_status_code': 'waiting', 'task_status_name': '待处理', 'description': '任务已创建,等待处理'},
{'task_status_code': 'processing', 'task_status_name': '进行中', 'description': '任务正在处理中'},
{'task_status_code': 'completed', 'task_status_name': '已完成', 'description': '任务已成功完成'},
{'task_status_code': 'failed', 'task_status_name': '失败', 'description': '任务处理失败'}
]
for status in task_status:
existing = TaskStatus.query.filter_by(status_code=status['status_code']).first()
for status in task_statuses:
existing = TaskStatus.query.filter_by(task_status_code=status['task_status_code']).first()
if not existing:
new_status = TaskStatus(**status)
db.session.add(new_status)
# Initialize image type data
image_types = [
{'image_code': 'original', 'image_name': '原始图片', 'description': '用户上传的原始图像文件'},
{'image_code': 'perturbed', 'image_name': '加噪后图片', 'description': '经过扰动算法处理后的防护图像'},
{'image_code': 'original_generate', 'image_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'},
{'image_code': 'perturbed_generate', 'image_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'},
{'image_code': 'heatmap', 'image_name': '生成的热力图', 'description': '热力图'}
{'image_code': 'original', 'image_name': '原始图', 'description': '用户上传的原始图像'},
{'image_code': 'perturbed', 'image_name': '加噪图', 'description': '经过扰动算法处理后的防护图像'},
{'image_code': 'uploaded_generate', 'image_name': '上传图片生成图', 'description': '使用上传图片训练后生成的图像'},
{'image_code': 'original_generate', 'image_name': '原始图像生成图', 'description': '使用原始图像训练后生成的图像'},
{'image_code': 'perturbed_generate', 'image_name': '加噪图像生成图', 'description': '使用加噪图像训练后生成的图像'},
{'image_code': 'heatmap', 'image_name': '热力图', 'description': '原始图与加噪图的差异热力图'},
{'image_code': 'report', 'image_name': '报告图', 'description': '任务评估指标可视化图表'}
]
for img_type in image_types:
existing = ImageType.query.filter_by(type_code=img_type['type_code']).first()
existing = ImageType.query.filter_by(image_code=img_type['image_code']).first()
if not existing:
new_type = ImageType(**img_type)
db.session.add(new_type)
@ -62,7 +64,7 @@ def init_database():
]
for config in perturbation_configs:
existing = PerturbationConfig.query.filter_by(method_code=config['method_code']).first()
existing = PerturbationConfig.query.filter_by(perturbation_code=config['perturbation_code']).first()
if not existing:
new_config = PerturbationConfig(**config)
db.session.add(new_config)
@ -75,47 +77,53 @@ def init_database():
]
for config in finetune_configs:
existing = FinetuneConfig.query.filter_by(method_code=config['method_code']).first()
existing = FinetuneConfig.query.filter_by(finetune_code=config['finetune_code']).first()
if not existing:
new_config = FinetuneConfig(**config)
db.session.add(new_config)
# Initialize dataset type data
dataset_types = [
{'data_type_id': 0, 'dataset_code': 'facial', 'dataset_name': '人脸数据集', 'description': '人脸类型的数据集'},
{'data_type_id': 1, 'dataset_code': 'art', 'dataset_name': '艺术品数据集', 'description': '艺术品类型的数据集'}
data_types = [
{'data_type_code': 'facial', 'data_type_prompt': 'a photo of sks person', 'description': '人脸类型的数据集'},
{'data_type_code': 'art', 'data_type_prompt': 'a painting in the style of sks', 'description': '艺术品类型的数据集'}
]
for dataset in dataset_types:
existing = DataType.query.filter_by(data_type_id=dataset['data_type_id']).first()
for data_type in data_types:
existing = DataType.query.filter_by(data_type_code=data_type['data_type_code']).first()
if not existing:
new_dataset = DataType(**dataset)
db.session.add(new_dataset)
new_data_type = DataType(**data_type)
db.session.add(new_data_type)
# Initialize task type data
# Initialize task type data (ordered by execution flow)
task_types = [
{'task_type_id': 0, 'task_code': 'perturbation', 'task_name': '加噪任务', 'description': '对图像进行加噪处理的任务'},
{'task_type_id': 1, 'task_code': 'finetune', 'task_name': '微调任务', 'description': '对模型进行微调训练的任务'},
{'task_type_id': 2, 'task_code': 'generation', 'task_name': '生成任务', 'description': '利用微调后模型进行图像生成的任务'}
{'task_type_id': 3, 'task_code': 'heatmap', 'task_name': '热力图任务', 'description': '计算X和X的热力图的任务'}
{'task_type_code': 'perturbation', 'task_type_name': '加噪任务', 'description': '对图像进行扰动处理,生成防护图像'},
{'task_type_code': 'heatmap', 'task_type_name': '热力图任务', 'description': '可视化原始图与加噪图的差异热力图'},
{'task_type_code': 'finetune', 'task_type_name': '微调任务', 'description': '使用图像数据集对模型进行微调训练'},
{'task_type_code': 'evaluate', 'task_type_name': '评估任务', 'description': '评估微调后模型的生成效果和防护性能'}
]
for task_type in task_types:
existing = TaskType.query.filter_by(task_type_code=task_type['task_type_code']).first()
if not existing:
new_task_type = TaskType(**task_type)
db.session.add(new_task_type)
# Create default admin users
admin_users = [
{'username': 'admin1', 'email': 'admin1@museguard.com', 'role_id': 0},
{'username': 'admin2', 'email': 'admin2@museguard.com', 'role_id': 0},
{'username': 'admin3', 'email': 'admin3@museguard.com', 'role_id': 0}
# Create default test users (one per role)
test_users = [
{'username': 'admin_test', 'email': 'admin@test.com', 'password': 'admin123', 'role_id': 1},
{'username': 'vip_test', 'email': 'vip@test.com', 'password': 'vip123', 'role_id': 2},
{'username': 'normal_test', 'email': 'normal@test.com', 'password': 'normal123', 'role_id': 3}
]
for admin_data in admin_users:
existing = User.query.filter_by(username=admin_data['username']).first()
for user_data in test_users:
existing = User.query.filter_by(username=user_data['username']).first()
if not existing:
admin_user = User(**admin_data)
admin_user.set_password('admin123') # 默认密码
db.session.add(admin_user)
password = user_data.pop('password') # pop the password
test_user = User(**user_data)
test_user.set_password(password)
db.session.add(test_user)
# Create default config for the admin users
# Create default config for the test users
db.session.flush() # flush so the new user's user_id is available
user_config = UserConfig(user_id=admin_user.id)
user_config = UserConfig(user_id=test_user.user_id)
db.session.add(user_config)
# Commit all changes
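For completeness, a usage sketch: init_database is idempotent (every insert is guarded by an existence check), so it can safely run on every startup inside an application context. The create_app factory name is an assumption, not shown in this diff:

from app import create_app  # create_app is an assumed factory name

app = create_app()
with app.app_context():
    init_database()  # safe to re-run; existing rows are skipped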
