From 6c40a2f8f52ca55a364d1bb2e2e18b3296290c9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sun, 30 Nov 2025 00:46:12 +0800 Subject: [PATCH 01/14] =?UTF-8?q?feat:=20=20=E4=BF=AE=E6=94=B9=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/init_db.py | 91 +++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 42 deletions(-) diff --git a/src/backend/init_db.py b/src/backend/init_db.py index 611e415..9f15d70 100644 --- a/src/backend/init_db.py +++ b/src/backend/init_db.py @@ -15,9 +15,9 @@ def init_database(): # 初始化角色数据 roles = [ - {'role_id': 0, 'name': '管理员', 'max_concurrent_tasks': 15, 'description': '系统管理员,拥有最高权限'}, - {'role_id': 1, 'name': 'VIP用户', 'max_concurrent_tasks': 10, 'description': '付费用户,享有较高的资源使用权限'}, - {'role_id': 2, 'name': '普通用户', 'max_concurrent_tasks': 5, 'description': '免费用户,享有基本的资源使用权限'} + {'role_id': 0, 'role_code': 'admin', 'name': '管理员', 'max_concurrent_tasks': 15, 'description': '系统管理员,拥有最高权限'}, + {'role_id': 1, 'role_code': 'vip', 'name': 'VIP用户', 'max_concurrent_tasks': 10, 'description': '付费用户,享有较高的资源使用权限'}, + {'role_id': 2, 'role_code': 'normal', 'name': '普通用户', 'max_concurrent_tasks': 5, 'description': '免费用户,享有基本的资源使用权限'} ] for role_data in roles: existing = Role.query.filter_by(role_id=role_data['role_id']).first() @@ -26,29 +26,30 @@ def init_database(): db.session.add(new_role) # 初始化任务状态数据 - task_status = [ - {'status_code': 'waiting', 'status_name': '待处理', 'description': '任务已创建,等待处理'}, - {'status_code': 'processing', 'status_name': '进行中', 'description': '任务正在处理中'}, - {'status_code': 'completed', 'status_name': '已完成', 'description':'任务已成功完成'}, - {'status_code': 'failed', 'status_name': '失败', 'description': '任务处理失败'} + task_statuses = [ + {'task_status_code': 'waiting', 'task_status_name': '待处理', 'description': '任务已创建,等待处理'}, + {'task_status_code': 
'processing', 'task_status_name': '进行中', 'description': '任务正在处理中'}, + {'task_status_code': 'completed', 'task_status_name': '已完成', 'description':'任务已成功完成'}, + {'task_status_code': 'failed', 'task_status_name': '失败', 'description': '任务处理失败'} ] - for status in task_status: - existing = TaskStatus.query.filter_by(status_code=status['status_code']).first() + for status in task_statuses: + existing = TaskStatus.query.filter_by(task_status_code=status['task_status_code']).first() if not existing: new_status = TaskStatus(**status) db.session.add(new_status) # 初始化图片类型数据 image_types = [ - {'image_code': 'original', 'image_name': '原始图片', 'description': '用户上传的原始图像文件'}, - {'image_code': 'perturbed', 'image_name': '加噪后图片', 'description': '经过扰动算法处理后的防护图像'}, - {'image_code': 'original_generate', 'image_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'}, - {'image_code': 'perturbed_generate', 'image_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'}, - {'image_code': 'heatmap', 'image_name': '生成的热力图', 'description': '热力图'} + {'image_code': 'original', 'image_name': '原始图', 'description': '用户上传的原始图像'}, + {'image_code': 'perturbed', 'image_name': '加噪图', 'description': '经过扰动算法处理后的防护图像'}, + {'image_code': 'original_generate', 'image_name': '原始图像生成图', 'description': '使用原始图像训练后生成的图像'}, + {'image_code': 'perturbed_generate', 'image_name': '加噪图像生成图', 'description': '使用加噪图像训练后生成的图像'}, + {'image_code': 'heatmap', 'image_name': '热力图', 'description': '原始图与加噪图的差异热力图'}, + {'image_code': 'report', 'image_name': '报告图', 'description': '任务评估指标可视化图表'} ] for img_type in image_types: - existing = ImageType.query.filter_by(type_code=img_type['type_code']).first() + existing = ImageType.query.filter_by(image_code=img_type['image_code']).first() if not existing: new_type = ImageType(**img_type) db.session.add(new_type) @@ -62,7 +63,7 @@ def init_database(): ] for config in perturbation_configs: - existing = PerturbationConfig.query.filter_by(method_code=config['method_code']).first() + existing 
= PerturbationConfig.query.filter_by(perturbation_code=config['perturbation_code']).first() if not existing: new_config = PerturbationConfig(**config) db.session.add(new_config) @@ -75,47 +76,53 @@ def init_database(): ] for config in finetune_configs: - existing = FinetuneConfig.query.filter_by(method_code=config['method_code']).first() + existing = FinetuneConfig.query.filter_by(finetune_code=config['finetune_code']).first() if not existing: new_config = FinetuneConfig(**config) db.session.add(new_config) # 初始化数据集类型数据 - dataset_types = [ - {'data_type_id': 0, 'dataset_code': 'facial', 'dataset_name': '人脸数据集', 'description': '人脸类型的数据集'}, - {'data_type_id': 1, 'dataset_code': 'art', 'dataset_name': '艺术品数据集', 'description': '艺术品类型的数据集'} + data_types = [ + {'data_type_code': 'facial', 'data_type_prompt': 'a photo of sks person', 'description': '人脸类型的数据集'}, + {'data_type_code': 'art', 'data_type_prompt': 'a painting in the style of sks', 'description': '艺术品类型的数据集'} ] - for dataset in dataset_types: - existing = DataType.query.filter_by(data_type_id=dataset['data_type_id']).first() + for data_type in data_types: + existing = DataType.query.filter_by(data_type_code=data_type['data_type_code']).first() if not existing: - new_dataset = DataType(**dataset) - db.session.add(new_dataset) + new_data_type = DataType(**data_type) + db.session.add(new_data_type) - # 初始化任务类型数据 + # 初始化任务类型数据(按执行逻辑顺序排列) task_types = [ - {'task_type_id': 0, 'task_code': 'perturbation', 'task_name': '加噪任务', 'description': '对图像进行加噪处理的任务'}, - {'task_type_id': 1, 'task_code': 'finetune', 'task_name': '微调任务', 'description': '对模型进行微调训练的任务'}, - {'task_type_id': 2, 'task_code': 'generation', 'task_name': '生成任务', 'description': '利用微调后模型进行图像生成的任务'} - {'task_type_id': 3, 'task_code': 'heatmap', 'task_name': '热力图任务', 'description': '计算X和X’的热力图的任务'} + {'task_type_code': 'perturbation', 'task_type_name': '加噪任务', 'description': '对图像进行扰动处理,生成防护图像'}, + {'task_type_code': 'heatmap', 'task_type_name': '热力图任务', 
'description': '可视化原始图与加噪图的差异热力图'}, + {'task_type_code': 'finetune', 'task_type_name': '微调任务', 'description': '使用图像数据集对模型进行微调训练'}, + {'task_type_code': 'evaluate', 'task_type_name': '评估任务', 'description': '评估微调后模型的生成效果和防护性能'} ] + for task_type in task_types: + existing = TaskType.query.filter_by(task_type_code=task_type['task_type_code']).first() + if not existing: + new_task_type = TaskType(**task_type) + db.session.add(new_task_type) - # 创建默认管理员用户 - admin_users = [ - {'username': 'admin1', 'email': 'admin1@museguard.com', 'role_id': 0}, - {'username': 'admin2', 'email': 'admin2@museguard.com', 'role_id': 0}, - {'username': 'admin3', 'email': 'admin3@museguard.com', 'role_id': 0} + # 创建默认测试用户(三种角色各一个) + test_users = [ + {'username': 'admin_test', 'email': 'admin@test.com', 'password': 'admin123', 'role_id': 0}, + {'username': 'vip_test', 'email': 'vip@test.com', 'password': 'vip123', 'role_id': 1}, + {'username': 'normal_test', 'email': 'normal@test.com', 'password': 'normal123', 'role_id': 2} ] - for admin_data in admin_users: - existing = User.query.filter_by(username=admin_data['username']).first() + for user_data in test_users: + existing = User.query.filter_by(username=user_data['username']).first() if not existing: - admin_user = User(**admin_data) - admin_user.set_password('admin123') # 默认密码 - db.session.add(admin_user) + password = user_data.pop('password') # 取出密码 + test_user = User(**user_data) + test_user.set_password(password) + db.session.add(test_user) - # 为管理员创建默认配置 + # 为测试用户创建默认配置 db.session.flush() # 确保user.id可用 - user_config = UserConfig(user_id=admin_user.id) + user_config = UserConfig(user_id=test_user.user_id) db.session.add(user_config) # 提交所有更改 -- 2.34.1 From 2b8807fcddbd4dfe8bd0d4254ea538e2ac9026f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sun, 30 Nov 2025 00:46:54 +0800 Subject: [PATCH 02/14] =?UTF-8?q?chore:=20=E6=9B=B4=E6=96=B0.gitignore?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 86aa95f..aa202e9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,7 @@ +# Python 编译缓存 __pycache__/ -venv/ -python=3.11/ - +# 图片文件 *.png *.jpg *.jpeg @@ -10,11 +9,21 @@ python=3.11/ # 环境配置文件(包含敏感信息) *.env -# 日志文件 +# 日志及进程文件 logs/ *.log +*.pid # 上传文件临时目录 uploads/ -.github/ \ No newline at end of file +# 微调生成文件 +*.json +*.bin +*.pkl +*.safetensors +*.pt +*.txt + +# vscode 配置 +.vscode/ \ No newline at end of file -- 2.34.1 From 7f8f74e3a659043a737d894b13f1b195a1ef56d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sun, 30 Nov 2025 04:29:40 +0800 Subject: [PATCH 03/14] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E7=AE=97?= =?UTF-8?q?=E6=B3=95=E6=96=87=E4=BB=B6=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=96=87?= =?UTF-8?q?=E6=9C=AC=E5=8F=8D=E8=BD=AC=E7=9A=84=E9=85=8D=E7=BD=AE=E4=BF=A1?= =?UTF-8?q?=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../algorithms/evaluate/eva_gen_heatmap.py | 520 ++++++ .../app/algorithms/evaluate/eva_gen_nums.py | 513 ++++++ src/backend/app/algorithms/finetune/infer.py | 87 - ...reambooth_gen.py => train_db_gen_trace.py} | 79 +- .../finetune/train_dreambooth_alone.py | 1035 ------------ ...in_lora_gen.py => train_lora_gen_trace.py} | 92 +- .../algorithms/finetune/train_ti_gen_trace.py | 1404 +++++++++++++++++ .../algorithms/processor/coords_processor.py | 132 ++ .../algorithms/processor/image_processor.py | 149 ++ src/backend/app/services/auth_service.py | 34 - src/backend/app/services/task_service.py | 29 +- src/backend/app/workers/finetune_worker.py | 44 +- src/backend/config/algorithm_config.py | 54 +- src/backend/config/settings.py | 6 + 14 files changed, 2989 insertions(+), 1189 deletions(-) create mode 100644 src/backend/app/algorithms/evaluate/eva_gen_heatmap.py create mode 100644 
src/backend/app/algorithms/evaluate/eva_gen_nums.py delete mode 100644 src/backend/app/algorithms/finetune/infer.py rename src/backend/app/algorithms/finetune/{train_dreambooth_gen.py => train_db_gen_trace.py} (93%) delete mode 100644 src/backend/app/algorithms/finetune/train_dreambooth_alone.py rename src/backend/app/algorithms/finetune/{train_lora_gen.py => train_lora_gen_trace.py} (92%) create mode 100644 src/backend/app/algorithms/finetune/train_ti_gen_trace.py create mode 100644 src/backend/app/algorithms/processor/coords_processor.py create mode 100644 src/backend/app/algorithms/processor/image_processor.py delete mode 100644 src/backend/app/services/auth_service.py diff --git a/src/backend/app/algorithms/evaluate/eva_gen_heatmap.py b/src/backend/app/algorithms/evaluate/eva_gen_heatmap.py new file mode 100644 index 0000000..fb27936 --- /dev/null +++ b/src/backend/app/algorithms/evaluate/eva_gen_heatmap.py @@ -0,0 +1,520 @@ +"""Stable Diffusion 注意力热力图差异可视化工具 (可靠版 - 语义阶段聚合)。 + +本模块使用一种健壮的方法,通过在 Stable Diffusion 扩散模型(U-Net)的 +**早期时间步 (语义阶段)** 捕获并累加交叉注意力权重。这种方法能确保捕获到的 +注意力图信号集中且可靠,用于对比分析干净输入和扰动输入生成的图像对模型 +注意力机制的影响差异。 + +典型用法: + python eva_gen_heatmap.py \\ + --model_path /path/to/sd_model \\ + --image_path_a /path/to/clean_image.png \\ + --image_path_b /path/to/noisy_image.png \\ + --prompt_text "a photo of sks person" \\ + --target_word "sks" \\ + --output_dir output/heatmap_reports +""" + +# 通用参数解析与文件路径管理 +import argparse +import os +from pathlib import Path +from typing import Dict, Any, List, Tuple + +# 数值计算与深度学习依赖 +import torch +import torch.nn.functional as F +import numpy as np +import itertools +import warnings + +# 可视化依赖 +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +from matplotlib.colors import TwoSlopeNorm +from mpl_toolkits.axes_grid1 import make_axes_locatable + +# Diffusers 与 Transformers 依赖 +from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler +from diffusers.models.attention_processor import 
Attention +from transformers import CLIPTokenizer + +# 图像处理与数据读取 +from PIL import Image +from torchvision import transforms + +# 抑制非必要的警告输出 +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + + +# ============== 核心模块:注意力捕获与聚合 ============== + +class AttentionMapProcessor: + """自定义注意力处理器,用于捕获 U-Net 交叉注意力层的权重。 + + 通过替换原始的 `Attention` 模块处理器,该类在模型前向传播过程中, + 将所有交叉注意力层的注意力权重(`attention_probs`)捕获并存储。 + + Attributes: + attention_maps (Dict[str, List[torch.Tensor]]): 存储捕获到的注意力图, + 键为层名称,值为该层在不同时间步捕获到的注意力图列表。 + pipeline (StableDiffusionPipeline): 正在处理的 Stable Diffusion 管线。 + original_processors (Dict[str, Any]): 存储原始的注意力处理器,用于恢复。 + current_layer_name (Optional[str]): 当前正在处理的注意力层的名称。 + """ + + def __init__(self, pipeline: StableDiffusionPipeline): + """初始化注意力处理器。 + + Args: + pipeline: Stable Diffusion 模型管线实例。 + """ + self.attention_maps: Dict[str, List[torch.Tensor]] = {} + self.pipeline = pipeline + self.original_processors = {} + self.current_layer_name = None + self._set_processors() + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: torch.Tensor = None + ) -> torch.Tensor: + """重载 __call__ 方法,执行注意力计算并捕获权重。 + + 此方法替代了原始的 `Attention.processor`,在计算交叉注意力时进行捕获。 + + Args: + attn: 当前的 `Attention` 模块实例。 + hidden_states: U-Net 隐状态 (query)。 + encoder_hidden_states: 文本编码器输出 (key/value),即交叉注意力输入。 + attention_mask: 注意力掩码。 + + Returns: + 计算后的输出隐状态。 + """ + # 如果不是交叉注意力(即 encoder_hidden_states 为 None),则调用原始处理器 + if encoder_hidden_states is None: + return attn.processor( + attn, hidden_states, encoder_hidden_states, attention_mask + ) + + # 1. 计算 Q, K, V + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # 2. 准备矩阵乘法 + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + + # 3. 
计算 Attention Scores (Q @ K^T) + attention_scores = torch.baddbmm( + torch.empty( + query.shape[0], query.shape[1], key.shape[1], + dtype=query.dtype, device=query.device + ), + query, + key.transpose(1, 2), + beta=0, + alpha=attn.scale, + ) + + # 4. 计算 Attention Probabilities + attention_probs = attention_scores.softmax(dim=-1) + layer_name = self.current_layer_name + + # 5. 存储捕获的注意力图 + if layer_name not in self.attention_maps: + self.attention_maps[layer_name] = [] + + # 存储当前时间步的注意力权重 + self.attention_maps[layer_name].append(attention_probs.detach().cpu()) + + # 6. 计算输出 (Attention @ V) + value = attn.head_to_batch_dim(value) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # 7. 输出层 + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + def _set_processors(self): + """注册自定义处理器,捕获 U-Net 中所有交叉注意力层的权重。 + + 遍历 U-Net 的所有子模块,找到所有交叉注意力层(`Attention` 且名称包含 `attn2`), + 并将其处理器替换为当前的实例。 + """ + for name, module in self.pipeline.unet.named_modules(): + if isinstance(module, Attention) and 'attn2' in name: + # 存储原始处理器以便后续恢复 + self.original_processors[name] = module.processor + + # 定义一个新的闭包函数,用于在调用前设置当前层的名称 + def set_layer_name(current_name): + def new_call(*args, **kwargs): + self.current_layer_name = current_name + return self.__call__(*args, **kwargs) + return new_call + + module.processor = set_layer_name(name) + + def remove(self): + """恢复 U-Net 的原始注意力处理器,清理钩子。""" + for name, original_processor in self.original_processors.items(): + module = self.pipeline.unet.get_submodule(name) + module.processor = original_processor + self.attention_maps = {} + + +def aggregate_word_attention( + attention_maps: Dict[str, List[torch.Tensor]], + tokenizer: CLIPTokenizer, + target_word: str, + input_ids: torch.Tensor +) -> np.ndarray: + """聚合所有层和语义时间步中目标词汇的注意力图,并进行归一化。 + + 聚合步骤: + 1. 识别目标词汇对应的 Token 索引。 + 2. 对每个层:将所有捕获时间步的注意力图求平均。 + 3. 
提取目标 Token 对应的注意力子图,并对 Token 维度求和,对 Attention Heads 求平均。 + 4. 将不同分辨率的注意力图上采样到统一尺寸(64x64)。 + 5. 对所有层的结果进行累加(求和)。 + 6. 最终归一化到 [0, 1]。 + + Args: + attention_maps: 包含各层和时间步捕获的注意力图的字典。 + tokenizer: CLIP 分词器实例。 + target_word: 需要聚焦的关键词。 + input_ids: Prompt 对应的 Token ID 张量。 + + Returns: + 最终聚合并上采样到 64x64 尺寸的注意力热力图 (NumPy 数组)。 + + Raises: + ValueError: 如果无法在 Prompt 中找到目标词汇。 + RuntimeError: 如果未捕获到任何注意力数据。 + """ + + # 1. 识别目标词汇的 Token 索引 + prompt_tokens = tokenizer.convert_ids_to_tokens( + input_ids.squeeze().cpu().tolist() + ) + target_lower = target_word.lower() + target_indices = [] + + for i, token in enumerate(prompt_tokens): + cleaned_token = token.replace('Ġ', '').replace('_', '').lower() + # 查找目标词汇或以目标词汇开头的 token 索引,并排除特殊 token + if (input_ids.squeeze()[i] not in tokenizer.all_special_ids and + (target_lower in cleaned_token or + cleaned_token.startswith(target_lower))): + target_indices.append(i) + + if not target_indices: + print(f"[WARN] 目标词汇 '{target_word}' 未识别。请检查 Prompt 或 Target Word。") + raise ValueError("无法识别目标词汇的 token 索引。") + + # 2. 聚合逻辑 + all_attention_data = [] + # U-Net 输出的最大分辨率(64x64),总像素点数 + TARGET_SPATIAL_SIZE = 4096 + TARGET_MAP_SIZE = 64 + + for layer_name, step_maps in attention_maps.items(): + if not step_maps: + continue + + # 对该层捕获的所有时间步求平均,形状: (batch, heads, spatial_res, target_tokens_len) + avg_map_over_time = torch.stack(step_maps).mean(dim=0) + + # 移除批次维度 (假设 batch size = 1),形状: (heads, spatial_res, target_tokens_len) + attention_map = avg_map_over_time.squeeze(0) + + # 提取目标 token 的注意力图。形状: (heads, spatial_res, target_indices_len) + target_token_maps = attention_map[:, :, target_indices] + + # 对目标 token 求和 (dim=-1),对注意力头求平均 (dim=0),形状: (spatial_res,) + aggregated_map_flat = target_token_maps.sum(dim=-1).mean(dim=0).float() + + # 3. 
跨分辨率上采样 + if aggregated_map_flat.shape[0] != TARGET_SPATIAL_SIZE: + # 当前图的尺寸:16x16 (256) 或 32x32 (1024) + map_size = int(np.sqrt(aggregated_map_flat.shape[0])) + map_2d = aggregated_map_flat.reshape(map_size, map_size) + map_to_interp = map_2d.unsqueeze(0).unsqueeze(0) # [1, 1, H, W] + + # 使用双线性插值上采样到 64x64 + resized_map_2d = F.interpolate( + map_to_interp, + size=(TARGET_MAP_SIZE, TARGET_MAP_SIZE), + mode='bilinear', + align_corners=False + ) + resized_map_flat = resized_map_2d.squeeze().flatten() + all_attention_data.append(resized_map_flat) + else: + # 如果已经是 64x64,直接使用 + all_attention_data.append(aggregated_map_flat) + + if not all_attention_data: + raise RuntimeError("未捕获到注意力数据。可能模型或参数设置有误。") + + # 4. 对所有层的结果进行累加 (求和) + final_map_flat = torch.stack(all_attention_data).sum(dim=0).cpu().numpy() + + # 5. 最终归一化到 [0, 1] + final_map_flat = final_map_flat / (final_map_flat.max() + 1e-6) + + map_size = int(np.sqrt(final_map_flat.shape[0])) + final_map_np = final_map_flat.reshape(map_size, map_size) # 64x64 + + return final_map_np + + +def get_attention_map_from_image( + pipeline: StableDiffusionPipeline, + image_path: str, + prompt_text: str, + target_word: str +) -> Tuple[Image.Image, np.ndarray]: + """执行多时间步前向传播,捕获指定图片和 Prompt 的注意力图。 + + 通过只运行扩散过程中的语义阶段(早期时间步)来确保捕获到的注意力权重 + 具有高信号质量。 + + Args: + pipeline: Stable Diffusion 模型管线实例。 + image_path: 待处理的输入图片路径。 + prompt_text: 用于生成图片的 Prompt 文本。 + target_word: 需要聚焦和可视化的关键词。 + + Returns: + 包含 (原始图片, 最终上采样后的注意力图) 的元组。 + """ + print(f"\n-> 正在处理图片: {Path(image_path).name}") + image = Image.open(image_path).convert("RGB").resize((512, 512)) + image_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + image_tensor = ( + image_transform(image) + .unsqueeze(0) + .to(pipeline.device) + .to(pipeline.unet.dtype) + ) + + # 1. 编码到 Latent 空间 + with torch.no_grad(): + latent = ( + pipeline.vae.encode(image_tensor).latent_dist.sample() * + pipeline.vae.config.scaling_factor + ) + + # 2. 
编码 Prompt + text_input = pipeline.tokenizer( + prompt_text, + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt" + ) + input_ids = text_input.input_ids + + with torch.no_grad(): + # 获取文本嵌入 + prompt_embeds = pipeline.text_encoder( + input_ids.to(pipeline.device) + )[0] + + # 3. 定义语义时间步 + scheduler = pipeline.scheduler + # 设置扩散步数 (例如 50 步) + scheduler.set_timesteps(50, device=pipeline.device) + + # 只选择语义最丰富的早期 10 步进行捕获 + semantic_steps = scheduler.timesteps[:10] + print(f"-> 正在对语义阶段的 {len(semantic_steps)} 个时间步进行注意力捕获...") + + + processor = AttentionMapProcessor(pipeline) + + try: + # 4. 运行多步 UNet Forward Pass + with torch.no_grad(): + # 在选定的语义时间步上运行 U-Net 预测 + for t in semantic_steps: + pipeline.unet(latent, t, prompt_embeds, return_dict=False) + + # 5. 聚合捕获到的数据 + raw_map_np = aggregate_word_attention( + processor.attention_maps, + pipeline.tokenizer, + target_word, + input_ids + ) + except Exception as e: + print(f"[ERROR] 注意力聚合失败: {e}") + # 确保清理钩子 + raw_map_np = np.zeros(image.size) + finally: + processor.remove() + + # 6. 
注意力图上采样到图片尺寸 (512x512) + # PIL 进行上采样 + heat_map_pil = Image.fromarray((raw_map_np * 255).astype(np.uint8)) + heat_map_np_resized = ( + np.array(heat_map_pil.resize( + image.size, + resample=Image.Resampling.LANCZOS # 使用高质量的 Lanczos 滤波器 + )) / 255.0 + ) + + return image, heat_map_np_resized + + +def main(): + """主函数,负责解析参数,加载模型,计算差异并生成可视化报告。""" + parser = argparse.ArgumentParser(description="SD 图片注意力差异可视化报告生成") + parser.add_argument("--model_path", type=str, required=True, + help="Stable Diffusion 模型本地路径。") + parser.add_argument("--image_path_a", type=str, required=True, + help="干净输入图片 (X) 路径。") + parser.add_argument("--image_path_b", type=str, required=True, + help="扰动输入图片 (X') 路径。") + parser.add_argument("--prompt_text", type=str, default="a photo of sks person", + help="用于生成图片的 Prompt 文本。") + parser.add_argument("--target_word", type=str, default="sks", + help="需要在注意力图中聚焦和可视化的关键词。") + parser.add_argument("--output_dir", type=str, default="output", + help="报告 PNG 文件的输出目录。") + args = parser.parse_args() + + print(f"--- 正在生成 Stable Diffusion 注意力差异报告 ---") + + # ---------------- 准备模型 ---------------- + device = 'cuda' if torch.cuda.is_available() else 'cpu' + dtype = torch.float16 if device == 'cuda' else torch.float32 + + try: + # 加载 Stable Diffusion 管线 + pipe = StableDiffusionPipeline.from_pretrained( + args.model_path, + torch_dtype=dtype, + local_files_only=True, + safety_checker=None, + # 从子文件夹加载调度器配置 + scheduler=DPMSolverMultistepScheduler.from_pretrained(args.model_path, subfolder="scheduler") + ).to(device) + except Exception as e: + print(f"[ERROR] 模型加载失败,请检查路径和环境依赖: {e}") + return + + # ---------------- 获取数据 ---------------- + # 获取干净图片 A 的注意力图 M_A + img_A, map_A = get_attention_map_from_image(pipe, args.image_path_a, args.prompt_text, args.target_word) + # 获取扰动图片 B 的注意力图 M_B + img_B, map_B = get_attention_map_from_image(pipe, args.image_path_b, args.prompt_text, args.target_word) + + if map_A.shape != map_B.shape: + print("错误:注意力图尺寸不匹配。中止处理。") + return + + 
# 计算差异图: Delta = M_A - M_B + diff_map = map_A - map_B + # 计算 L2 范数(差异距离) + l2_diff = np.linalg.norm(diff_map) + print(f"\n计算完毕,注意力图的 L2 范数差异值: {l2_diff:.4f}") + + # ---------------- 绘制专业报告 ---------------- + + # 设置 Matplotlib 字体样式 + plt.rcParams.update({ + 'font.family': 'serif', + 'font.serif': ['DejaVu Serif', 'Times New Roman', 'serif'], + 'mathtext.fontset': 'cm' + }) + + fig = plt.figure(figsize=(12, 16), dpi=120) + + # 3行 x 4列 网格布局,用于图片和图例的精确控制 + gs = gridspec.GridSpec(3, 4, figure=fig, + height_ratios=[1, 1, 1.3], + hspace=0.3, wspace=0.1) + + # --- 第一行:原始图片 --- + ax_img_a = fig.add_subplot(gs[0, 0:2]) + ax_img_b = fig.add_subplot(gs[0, 2:4]) + + # 干净图片 + ax_img_a.imshow(img_A) + ax_img_a.set_title(f"Clean Image ($X$)\nFilename: {Path(args.image_path_a).name}", fontsize=14, pad=10) + ax_img_a.axis('off') + + # 扰动图片 + ax_img_b.imshow(img_B) + ax_img_b.set_title(f"Noisy Image ($X'$)\nFilename: {Path(args.image_path_b).name}", fontsize=14, pad=10) + ax_img_b.axis('off') + + # --- 第二行:注意力热力图 (Jet配色) --- + ax_map_a = fig.add_subplot(gs[1, 0:2]) + ax_map_b = fig.add_subplot(gs[1, 2:4]) + + # 注意力图 A + im_map_a = ax_map_a.imshow(map_A, cmap='jet', vmin=0, vmax=1) + ax_map_a.set_title(f"Attention Heatmap ($M_X$)\nTarget: \"{args.target_word}\"", fontsize=14, pad=10) + ax_map_a.axis('off') + + # 注意力图 B + im_map_b = ax_map_b.imshow(map_B, cmap='jet', vmin=0, vmax=1) + ax_map_b.set_title(f"Attention Heatmap ($M_{{X'}}$)\nTarget: \"{args.target_word}\"", fontsize=14, pad=10) + ax_map_b.axis('off') + + # 为注意力图 B 绘制颜色指示条 + divider = make_axes_locatable(ax_map_b) + cax_map = divider.append_axes("right", size="5%", pad=0.05) + cbar1 = fig.colorbar(im_map_b, cax=cax_map) + cbar1.set_label('Attention Intensity', fontsize=10) + + # --- 第三行:差异对比 (完美居中) --- + # 差异图在网格的中间两列 + ax_diff = fig.add_subplot(gs[2, 1:3]) + + vmax_diff = np.max(np.abs(diff_map)) + # 使用 TwoSlopeNorm 确保 0 值位于色条中央 + norm_diff = TwoSlopeNorm(vmin=-vmax_diff, vcenter=0., vmax=vmax_diff) + + # 使用 Coolwarm 
配色,蓝色表示负差异 (M_X' > M_X),红色表示正差异 (M_X > M_X') + im_diff = ax_diff.imshow(diff_map, cmap='coolwarm', norm=norm_diff) + + title_text = ( + r"Difference Map: $\Delta = M_X - M_{X'}$" + + f"\n$L_2$ Norm Distance: $\mathbf{{{l2_diff:.4f}}}$" + ) + ax_diff.set_title(title_text, fontsize=16, pad=12) + ax_diff.axis('off') + + # 差异图颜色指示条 (居中对齐) + cbar2 = fig.colorbar(im_diff, ax=ax_diff, fraction=0.046, pad=0.04) + cbar2.set_label(r'Scale: Red ($+$) $\leftrightarrow$ Blue ($-$)', fontsize=12) + + # ---------------- 整体修饰与保存 ---------------- + fig.suptitle(f"Museguard: SD Attention Analysis Report", fontsize=20, fontweight='bold', y=0.95) + + output_filename = "heatmap_dif.png" + output_path = Path(args.output_dir) / output_filename + output_path.parent.mkdir(parents=True, exist_ok=True) + + plt.savefig(output_path, bbox_inches='tight', facecolor='white') + print(f"\n专业分析报告已保存至:\n{output_path.resolve()}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/backend/app/algorithms/evaluate/eva_gen_nums.py b/src/backend/app/algorithms/evaluate/eva_gen_nums.py new file mode 100644 index 0000000..31041bd --- /dev/null +++ b/src/backend/app/algorithms/evaluate/eva_gen_nums.py @@ -0,0 +1,513 @@ +"""图像生成质量多维度评估工具 (专业重构版)。 + +本脚本用于对比评估两组图像(Clean vs Perturbed)的生成质量。 +支持生成包含指标对比表和深度差异分析的 PNG 报告。 + +Style Guide: Google Python Style Guide +""" + +import os +import time +import subprocess +import tempfile +import warnings +from argparse import ArgumentParser +from pathlib import Path +from typing import Dict, Optional, Tuple, Any + +import torch +import clip +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +from PIL import Image +from torchvision import transforms +from facenet_pytorch import MTCNN, InceptionResnetV1 +from piq import ssim, psnr +import torch_fidelity as fid + +# 抑制非必要的警告输出 +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", 
category=FutureWarning) + +# ----------------------------------------------------------------------------- +# 全局配置与样式 +# ----------------------------------------------------------------------------- + +# Matplotlib LaTeX 风格配置 +plt.rcParams.update({ + 'font.family': 'serif', + 'font.serif': ['DejaVu Serif', 'Times New Roman', 'serif'], + 'mathtext.fontset': 'cm', + 'axes.unicode_minus': False +}) + +# 指标元数据配置:定义指标目标方向和分析阈值 +METRIC_ANALYSIS_META = { + 'FID': {'higher_is_better': False, 'th': [2.0, 10.0, 30.0]}, + 'SSIM': {'higher_is_better': True, 'th': [0.01, 0.05, 0.15]}, + 'PSNR': {'higher_is_better': True, 'th': [0.5, 2.0, 5.0]}, + 'FDS': {'higher_is_better': True, 'th': [0.02, 0.05, 0.1]}, + 'CLIP_IQS': {'higher_is_better': True, 'th': [0.01, 0.03, 0.08]}, + 'BRISQUE': {'higher_is_better': False, 'th': [2.0, 5.0, 10.0]}, +} +# 用于综合分析的降级权重 +ANALYSIS_WEIGHTS = {'Severe': 3, 'Significant': 2, 'Slight': 1, 'Negligible': 0} + + +# ----------------------------------------------------------------------------- +# 模型加载 (惰性加载或全局预加载) +# ----------------------------------------------------------------------------- + +try: + CLIP_MODEL, CLIP_PREPROCESS = clip.load('ViT-B/32', 'cuda') + CLIP_MODEL.eval() +except Exception as e: + print(f"[Warning] CLIP 模型加载失败: {e}") + CLIP_MODEL, CLIP_PREPROCESS = None, None + +def _get_clip_text_features(text: str) -> torch.Tensor: + """辅助函数:获取文本的 CLIP 特征。""" + if CLIP_MODEL is None: + return None + tokens = clip.tokenize(text).to('cuda') + with torch.no_grad(): + features = CLIP_MODEL.encode_text(tokens) + features /= features.norm(dim=-1, keepdim=True) + return features + +# ----------------------------------------------------------------------------- +# 核心计算逻辑 +# ----------------------------------------------------------------------------- + +def calculate_metrics( + ref_dir: str, + gen_dir: str, + image_size: int = 512 +) -> Dict[str, float]: + """计算图像集之间的多项质量评估指标。 + + 包括 FDS, SSIM, PSNR, CLIP_IQS, FID。 + + Args: + ref_dir: 参考图片目录路径。 + 
gen_dir: 生成图片目录路径。 + image_size: 图像处理尺寸。 + + Returns: + 包含各项指标名称和数值的字典。若目录无效返回空字典。 + """ + metrics = {} + + # 1. 数据加载 + def load_images(directory): + imgs = [] + if os.path.exists(directory): + for f in os.listdir(directory): + if f.lower().endswith(('.png', '.jpg', '.jpeg')): + try: + path = os.path.join(directory, f) + imgs.append(Image.open(path).convert("RGB")) + except Exception: + pass + return imgs + + ref_imgs = load_images(ref_dir) + gen_imgs = load_images(gen_dir) + + if not ref_imgs or not gen_imgs: + print(f"[Error] 图片加载失败或目录为空: \nRef: {ref_dir}\nGen: {gen_dir}") + return {} + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + with torch.no_grad(): + # --- FDS (Face Detection Similarity) --- + print(">>> 计算 FDS...") + mtcnn = MTCNN(image_size=image_size, margin=0, device=device) + resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device) + + def get_face_embeds(img_list): + embeds = [] + for img in img_list: + face = mtcnn(img) + if face is not None: + embeds.append(resnet(face.unsqueeze(0).to(device))) + return torch.stack(embeds) if embeds else None + + ref_embeds = get_face_embeds(ref_imgs) + gen_embeds = get_face_embeds(gen_imgs) + + if ref_embeds is not None and gen_embeds is not None: + # 计算生成集每张脸与参考集所有脸的余弦相似度均值 + sims = [] + for g_emb in gen_embeds: + sim = torch.cosine_similarity(g_emb, ref_embeds).mean() + sims.append(sim) + metrics['FDS'] = torch.tensor(sims).mean().item() + else: + metrics['FDS'] = 0.0 + + # 清理显存 + del mtcnn, resnet + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # --- SSIM & PSNR --- + print(">>> 计算 SSIM & PSNR...") + tfm = transforms.Compose([ + transforms.Resize((image_size, image_size)), + transforms.ToTensor() + ]) + + # 将参考集堆叠为 [N, C, H, W] + ref_tensor = torch.stack([tfm(img) for img in ref_imgs]).to(device) + + ssim_accum, psnr_accum = 0.0, 0.0 + for img in gen_imgs: + gen_tensor = tfm(img).unsqueeze(0).to(device) # [1, C, H, W] + + # 扩展维度以匹配参考集 + gen_expanded = 
gen_tensor.expand_as(ref_tensor) + + # 计算单张生成图相对于整个参考集的平均结构相似度 + val_ssim = ssim(gen_expanded, ref_tensor, data_range=1.0) + val_psnr = psnr(gen_expanded, ref_tensor, data_range=1.0) + + ssim_accum += val_ssim.item() + psnr_accum += val_psnr.item() + + metrics['SSIM'] = ssim_accum / len(gen_imgs) + metrics['PSNR'] = psnr_accum / len(gen_imgs) + + # --- CLIP IQS --- + print(">>> 计算 CLIP IQS...") + if CLIP_MODEL: + iqs_accum = 0.0 + txt_feat = _get_clip_text_features("good image") + for img in gen_imgs: + img_tensor = CLIP_PREPROCESS(img).unsqueeze(0).to(device) + img_feat = CLIP_MODEL.encode_image(img_tensor) + img_feat /= img_feat.norm(dim=-1, keepdim=True) + iqs_accum += (img_feat @ txt_feat.T).item() + metrics['CLIP_IQS'] = iqs_accum / len(gen_imgs) + else: + metrics['CLIP_IQS'] = np.nan + + # --- FID --- + print(">>> 计算 FID...") + try: + fid_res = fid.calculate_metrics( + input1=ref_dir, + input2=gen_dir, + cuda=True, + fid=True, + verbose=False + ) + metrics['FID'] = fid_res['frechet_inception_distance'] + except Exception as e: + print(f"[Error] FID 计算异常: {e}") + metrics['FID'] = np.nan + + return metrics + + +def run_brisque_cleanly(img_dir: str) -> float: + """使用 subprocess 和临时目录优雅地执行外部 BRISQUE 脚本。 + + Args: + img_dir: 图像目录路径。 + + Returns: + BRISQUE 分数,若失败返回 NaN。 + """ + print(f">>> 计算 BRISQUE (External)...") + + script_path = Path(__file__).parent / 'libsvm' / 'python' / 'brisquequality.py' + if not script_path.exists(): + print(f"[Error] 找不到 BRISQUE 脚本: {script_path}") + return np.nan + + abs_img_dir = os.path.abspath(img_dir) + + with tempfile.TemporaryDirectory() as temp_dir: + try: + cmd = [ + "python", str(script_path), + abs_img_dir, + temp_dir + ] + + # 在脚本所在目录执行 + subprocess.run( + cmd, + cwd=script_path.parent, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + + # 读取临时生成的日志文件 + log_file = Path(temp_dir) / 'log.txt' + if log_file.exists(): + content = log_file.read_text(encoding='utf-8').strip() + try: + return 
float(content.split()[-1]) + except ValueError: + return float(content) + else: + return np.nan + + except Exception as e: + print(f"[Error] BRISQUE 执行出错: {e}") + return np.nan + + +# ----------------------------------------------------------------------------- +# 报告可视化与分析逻辑 +# ----------------------------------------------------------------------------- + +def analyze_metric_diff( + metric_name: str, + clean_val: float, + pert_val: float +) -> Tuple[str, str, str]: + """生成科学的分级差异分析文本。 + + Args: + metric_name: 指标名称。 + clean_val: 干净图得分。 + pert_val: 扰动图得分。 + + Returns: + (表头箭头符号, 差异描述文本, 状态等级) + """ + + cfg = METRIC_ANALYSIS_META.get(metric_name) + if not cfg: + return "-", "Configuration not found.", "Negligible" + + diff = pert_val - clean_val + abs_diff = abs(diff) + + # 判定好坏: + is_better = (cfg['higher_is_better'] and diff > 0) or (not cfg['higher_is_better'] and diff < 0) + is_worse = not is_better + + # 确定程度 + th = cfg['th'] + if abs_diff < th[0]: + degree = "Negligible" + elif abs_diff < th[1]: + degree = "Slight" + elif abs_diff < th[2]: + degree = "Significant" + else: + degree = "Severe" + + # 组装文案 + header_arrow = r"$\uparrow$" if cfg['higher_is_better'] else r"$\downarrow$" + + if degree == "Negligible": + analysis_text = f"Negligible change (diff < {th[0]:.4f})." + elif is_worse: + analysis_text = f"{degree} degradation." + else: + analysis_text = f"Unexpected {degree} change." 
+ + return header_arrow, analysis_text, degree + + +def generate_visual_report( + ref_dir: str, + clean_dir: str, + pert_dir: str, + clean_metrics: Dict, + pert_metrics: Dict, + output_path: str +): + """渲染并保存专业对比分析报告 (PNG)。""" + + def get_sample(d): + if not os.path.exists(d): return None, "N/A" + files = [f for f in os.listdir(d) if f.lower().endswith(('.png','.jpg'))] + if not files: return None, "Empty" + return Image.open(os.path.join(d, files[0])).convert("RGB"), files[0] + + img_ref, name_ref = get_sample(ref_dir) + img_clean, name_clean = get_sample(clean_dir) + img_pert, name_pert = get_sample(pert_dir) + + # 布局设置 + # 增加高度以容纳文本 + fig = plt.figure(figsize=(12, 16.5), dpi=120) + gs = gridspec.GridSpec(3, 2, height_ratios=[1, 1, 1.5], hspace=0.25, wspace=0.1) + + # 1. 图像展示区 + ax_ref = fig.add_subplot(gs[0, :]) + if img_ref: + ax_ref.imshow(img_ref) + ax_ref.set_title(f"Reference Image ($X$)\n{name_ref}", fontsize=12, fontweight='bold', pad=10) + ax_ref.axis('off') + + ax_c = fig.add_subplot(gs[1, 0]) + if img_clean: + ax_c.imshow(img_clean) + ax_c.set_title(f"Clean Output ($Y$)\n{name_clean}", fontsize=12, fontweight='bold', pad=10) + ax_c.axis('off') + + ax_p = fig.add_subplot(gs[1, 1]) + if img_pert: + ax_p.imshow(img_pert) + ax_p.set_title(f"Perturbed Output ($Y'$)\n{name_pert}", fontsize=12, fontweight='bold', pad=10) + ax_p.axis('off') + + # 2. 
数据表格与分析区 + ax_data = fig.add_subplot(gs[2, :]) + ax_data.axis('off') + + metrics_list = ['FID', 'SSIM', 'PSNR', 'FDS', 'CLIP_IQS', 'BRISQUE'] + table_data = [] + analysis_lines = [] + + degradation_score = 0 + + # 遍历指标生成数据和分析 + for m in metrics_list: + c_val = clean_metrics.get(m, np.nan) + p_val = pert_metrics.get(m, np.nan) + + c_str = f"{c_val:.4f}" if not np.isnan(c_val) else "N/A" + p_str = f"{p_val:.4f}" if not np.isnan(p_val) else "N/A" + diff_str = "-" + + header_arrow = "" + + if not np.isnan(c_val) and not np.isnan(p_val): + # 获取深度分析 + header_arrow, text_desc, degree = analyze_metric_diff(m, c_val, p_val) + + # 计算差异值 + diff = p_val - c_val + # 差异值本身的符号 (Diff > 0 或 Diff < 0) + diff_arrow = r"$\nearrow$" if diff > 0 else r"$\searrow$" + if abs(diff) < 1e-4: diff_arrow = r"$\rightarrow$" + + diff_str = f"{diff:+.4f} {diff_arrow}" + + analysis_lines.append(f"• {m}: Change {diff:+.4f}. Analysis: {text_desc}") + + # 累计降级分数 + cfg = METRIC_ANALYSIS_META.get(m) + is_worse = (cfg['higher_is_better'] and diff < 0) or (not cfg['higher_is_better'] and diff > 0) + if is_worse: + degradation_score += ANALYSIS_WEIGHTS.get(degree, 0) + + # 表格第一列:名称 + 期望方向箭头 + name_with_arrow = f"{m} ({header_arrow})" if header_arrow else m + table_data.append([name_with_arrow, c_str, p_str, diff_str]) + + # 绘制表格 + table = ax_data.table( + cellText=table_data, + colLabels=["Metric (Goal)", "Clean ($Y$)", "Perturbed ($Y'$)", "Diff ($\Delta$)"], + loc='upper center', + cellLoc='center', + colWidths=[0.25, 0.25, 0.25, 0.25] + ) + table.scale(1, 2.0) + table.set_fontsize(11) + + # 美化表头 + for (row, col), cell in table.get_celld().items(): + if row == 0: + cell.set_text_props(weight='bold', color='white') + cell.set_facecolor('#404040') + elif col == 0: + cell.set_text_props(weight='bold') + cell.set_facecolor('#f5f5f5') + + # 3. 
底部综合分析文本框 + if not analysis_lines: + analysis_lines.append("• All metrics are missing or invalid.") + + full_text = "Quantitative Difference Analysis:\n" + "\n".join(analysis_lines) + + # 总体结论判断 (基于 holistic degradation score) + conclusion = "\n\n>>> EXECUTIVE SUMMARY (Holistic Judgment):\n" + + if degradation_score >= 8: + conclusion += "CRITICAL DEGRADATION. Significant quality loss observed. Attack highly effective." + elif degradation_score >= 4: + conclusion += "MODERATE DEGRADATION. Observable quality drop in key metrics. Attack effective." + elif degradation_score > 0: + conclusion += "MINOR DEGRADATION. Slight quality loss detected. Attack partially effective." + else: + conclusion += "INEFFECTIVE ATTACK. No significant or unexpected statistical quality loss observed." + + full_text += conclusion + + ax_data.text( + 0.05, + 0.30, + full_text, + ha='left', + va='top', + fontsize=12, family='monospace', wrap=True, + transform=ax_data.transAxes + ) + + fig.suptitle("Museguard: Quality Assurance Report", fontsize=18, fontweight='bold', y=0.95) + + plt.savefig(output_path, bbox_inches='tight', facecolor='white') + print(f"\n[Success] 报告已生成: {output_path}") + + +# ----------------------------------------------------------------------------- +# 主入口 +# ----------------------------------------------------------------------------- + +def main(): + parser = ArgumentParser() + parser.add_argument('--clean_output_dir', type=str, required=True) + parser.add_argument('--perturbed_output_dir', type=str, required=True) + parser.add_argument('--clean_ref_dir', type=str, required=True) + parser.add_argument('--png_output_path', type=str, required=True) + parser.add_argument('--size', type=int, default=512) + args = parser.parse_args() + + + Path(args.png_output_path).parent.mkdir(parents=True, exist_ok=True) + + print("========================================") + print(" Image Quality Evaluation Toolkit") + print("========================================") + + # 1. 
计算 Clean 组 + print(f"\n[1/2] Evaluating Clean Set: {os.path.basename(args.clean_output_dir)}") + c_metrics = calculate_metrics(args.clean_ref_dir, args.clean_output_dir, args.size) + if c_metrics: + c_metrics['BRISQUE'] = run_brisque_cleanly(args.clean_output_dir) + + # 2. 计算 Perturbed 组 + print(f"\n[2/2] Evaluating Perturbed Set: {os.path.basename(args.perturbed_output_dir)}") + p_metrics = calculate_metrics(args.clean_ref_dir, args.perturbed_output_dir, args.size) + if p_metrics: + p_metrics['BRISQUE'] = run_brisque_cleanly(args.perturbed_output_dir) + + # 3. 生成报告 + if c_metrics and p_metrics: + generate_visual_report( + args.clean_ref_dir, + args.clean_output_dir, + args.perturbed_output_dir, + c_metrics, + p_metrics, + args.png_output_path + ) + else: + print("\n[Fatal] 评估数据不完整,中止报告生成。") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/backend/app/algorithms/finetune/infer.py b/src/backend/app/algorithms/finetune/infer.py deleted file mode 100644 index 624fcc6..0000000 --- a/src/backend/app/algorithms/finetune/infer.py +++ /dev/null @@ -1,87 +0,0 @@ -import argparse -import os - -import torch -from diffusers import StableDiffusionPipeline -from torchvision.utils import make_grid -from pytorch_lightning import seed_everything -from PIL import Image -import numpy as np -from einops import rearrange - -parser = argparse.ArgumentParser(description="Inference") -parser.add_argument( - "--model_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", -) -parser.add_argument( - "--output_dir", - type=str, - default="./test-infer/", - help="The output directory where predictions are saved", -) -parser.add_argument( - "--seed", - type=int, - default=42, - help="The output directory where predictions are saved", -) -parser.add_argument( - "--v", - type=str, - default="sks", - help="The output directory where predictions are saved", -) - -args = 
parser.parse_args() - -if __name__ == "__main__": - seed_everything(args.seed) - os.makedirs(args.output_dir, exist_ok=True) - - # define prompts - prompts = [ - f"a photo of {args.v} person", - f"a dslr portrait of {args.v} person", - f"a photo of {args.v} person looking at the mirror", - f"a photo of {args.v} person in front of eiffel tower", - ] - - - # create & load model - pipe = StableDiffusionPipeline.from_pretrained( - args.model_path, - torch_dtype=torch.float32, - safety_checker=None, - local_files_only=True, - ).to("cuda") - - for prompt in prompts: - print(">>>>>>", prompt) - norm_prompt = prompt.lower().replace(",", "").replace(" ", "_") - out_path = f"{args.output_dir}/{norm_prompt}" - os.makedirs(out_path, exist_ok=True) - all_samples = list() - for i in range(5): - images = pipe([prompt] * 6, num_inference_steps=100, guidance_scale=7.5,).images - for idx, image in enumerate(images): - image.save(f"{out_path}/{i}_{idx}.png") - image = np.array(image, dtype=np.float32) - image /= 255.0 - image = np.transpose(image, (2, 0, 1)) - image = torch.from_numpy(image) # numpy->tensor - all_samples.append(image) - grid = torch.stack(all_samples, 0) - # grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=8) - # to image - grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() - img = Image.fromarray(grid.astype(np.uint8)) - img.save(f"{args.output_dir}/{prompt}.png") - torch.cuda.empty_cache() - - del pipe - torch.cuda.empty_cache() diff --git a/src/backend/app/algorithms/finetune/train_dreambooth_gen.py b/src/backend/app/algorithms/finetune/train_db_gen_trace.py similarity index 93% rename from src/backend/app/algorithms/finetune/train_dreambooth_gen.py rename to src/backend/app/algorithms/finetune/train_db_gen_trace.py index c34a908..76eaa6a 100644 --- a/src/backend/app/algorithms/finetune/train_dreambooth_gen.py +++ b/src/backend/app/algorithms/finetune/train_db_gen_trace.py @@ -24,6 +24,7 @@ import os import shutil import warnings from pathlib import Path +import pandas as pd import numpy as np import torch @@ -523,11 +524,6 @@ def parse_args(input_args=None): " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" ), ) - parser.add_argument( - "--is_perturbed", - action="store_true", - help="Whether training on perturbed images. Affects the generated image naming.", - ) parser.add_argument( "--offset_noise", @@ -594,6 +590,21 @@ def parse_args(input_args=None): help="The directory where validation images will be saved. 
If None, images will be saved inside a subdirectory of `output_dir`.", ) + # [START] 为可视化方案增加的参数 (通用指标) + parser.add_argument( + "--coords_save_path", + type=str, + default=None, + help="The path to save the intermediate coordinates (X, Y, Z metrics) for visualization.", + ) + parser.add_argument( + "--coords_log_interval", + type=int, + default=10, + help="Log and record intermediate coordinates every X steps.", + ) + # [END] 为可视化方案增加的参数 (通用指标) + if input_args is not None: args = parser.parse_args(input_args) else: @@ -1182,6 +1193,10 @@ def main(args): tracker_config.pop("validation_images") accelerator.init_trackers("dreambooth", config=tracker_config) + # [START] 为可视化方案增加的初始化 (通用指标) + coords_list = [] + # [END] 为可视化方案增加的初始化 (通用指标) + # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1339,6 +1354,43 @@ def main(args): # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss + # [START] 为可视化方案增加的 X轴 (特征范数) 和 Y轴 (特征方差) 计算 (通用指标) + if args.coords_save_path is not None: + # 修正 X轴 计算:将 torch.linalg.norm 替换为传统的 torch.norm + # 传统的 torch.norm 支持对多个维度求范数 (dim=[1, 2, 3]) + # X轴: UNet 预测特征 L2 范数 (衡量预测的“强度”) + # torch.norm(..., p=2, dim=...) 
表示 L2 范数 + X_i_feature_norm = torch.norm( + model_pred.detach().float(), + p=2, + dim=[1, 2, 3] # 对 C, H, W 维度求 L2 范数 + ).mean().item() # 对 Batch 维度求平均 + # Y轴: UNet 预测特征方差 (衡量预测的“混乱度/稳定性”) + # var() 默认对所有维度求方差 + Y_i_feature_var = torch.var( + model_pred.detach().float() + ).item() + # Z轴: LDM 损失 (衡量预测的“准确度”) + Z_i = loss.detach().item() + + # 记录坐标 (仅在主进程进行) + if accelerator.is_main_process and global_step % args.coords_log_interval == 0: + coords_list.append([global_step, X_i_feature_norm, Y_i_feature_var, Z_i]) + if global_step % (args.coords_log_interval * 10) == 0: + df = pd.DataFrame( + coords_list, + columns=['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss'] + ) + save_file_path = Path(args.coords_save_path) + if not save_file_path.suffix: + save_file_path = save_file_path / "coords.csv" + save_file_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(save_file_path, index=False) + logger.info( + f"Step {global_step}: 已记录可视化坐标,周期保存批次坐标到 {save_file_path}" + ) + # [END] 为可视化方案增加的 X轴 (特征范数) 和 Y轴 (特征方差) 计算 (通用指标) + accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( @@ -1449,10 +1501,25 @@ def main(args): commit_message="End of training", ignore_patterns=["step_*", "epoch_*"], ) + + # [START] 为可视化方案增加的最终保存 (通用指标) + if args.coords_save_path is not None and coords_list: + df = pd.DataFrame( + coords_list, + columns=['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss'] + ) + # 假设 args.coords_save_path 是目标文件路径 + save_file_path = Path(args.coords_save_path) + if not save_file_path.suffix: + save_file_path = save_file_path / "coords.csv" + save_file_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(save_file_path, index=False) + logger.info(f"训练结束:已将所有 {len(coords_list)} 坐标保存到 {save_file_path}") + # [END] 为可视化方案增加的最终保存 (通用指标) accelerator.end_training() if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git 
a/src/backend/app/algorithms/finetune/train_dreambooth_alone.py b/src/backend/app/algorithms/finetune/train_dreambooth_alone.py deleted file mode 100644 index 52a04a5..0000000 --- a/src/backend/app/algorithms/finetune/train_dreambooth_alone.py +++ /dev/null @@ -1,1035 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -import argparse -import hashlib -import itertools -import logging -import math -import os -import warnings -from pathlib import Path -from typing import Optional - -import datasets -import diffusers -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import set_seed -from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel -from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version -from diffusers.utils.import_utils import is_xformers_available -from huggingface_hub import HfFolder, create_repo, whoami -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig - - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.13.0.dev0") - -logger = get_logger(__name__) - - -def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): - text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, - subfolder="text_encoder", - revision=revision, - ) - model_class = text_encoder_config.architectures[0] - - if model_class == "CLIPTextModel": - from transformers import CLIPTextModel - - return CLIPTextModel - elif model_class == "RobertaSeriesModelWithTransformation": - from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation - - return RobertaSeriesModelWithTransformation - else: - raise ValueError(f"{model_class} is not supported.") - - -def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help=( - "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" - " float32 precision." 
- ), - ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) - parser.add_argument( - "--instance_data_dir", - type=str, - default=None, - required=True, - help="A folder containing the training data of instance images.", - ) - parser.add_argument( - "--class_data_dir", - type=str, - default=None, - required=False, - help="A folder containing the training data of class images.", - ) - parser.add_argument( - "--instance_prompt", - type=str, - default=None, - required=True, - help="The prompt with identifier specifying the instance", - ) - parser.add_argument( - "--class_prompt", - type=str, - default=None, - help="The prompt to specify images in the same class as provided instance images.", - ) - parser.add_argument( - "--inference_prompts", - type=str, - default=None, - help="The prompt used to generate images at inference.", - ) - parser.add_argument( - "--with_prior_preservation", - default=False, - action="store_true", - help="Flag to add prior preservation loss.", - ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss.", - ) - parser.add_argument( - "--num_class_images", - type=int, - default=100, - help=( - "Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt." 
- ), - ) - parser.add_argument( - "--output_dir", - type=str, - default="text-inversion-model", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--resolution", - type=int, - default=512, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), - ) - parser.add_argument( - "--center_crop", - default=False, - action="store_true", - help=( - "Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping." - ), - ) - parser.add_argument( - "--train_text_encoder", - action="store_true", - help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", - ) - parser.add_argument( - "--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--sample_batch_size", - type=int, - default=4, - help="Batch size (per device) for sampling images.", - ) - parser.add_argument("--num_train_epochs", type=int, default=1) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. 
Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-6, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", - ) - parser.add_argument( - "--lr_num_cycles", - type=int, - default=1, - help="Number of hard resets of the lr in cosine_with_restarts scheduler.", - ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) - parser.add_argument( - "--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.", - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
- ), - ) - parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam optimizer.", - ) - parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam optimizer.", - ) - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument( - "--adam_epsilon", - type=float, - default=1e-08, - help="Epsilon value for the Adam optimizer", - ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--prior_generation_precision", - type=str, - default=None, - choices=["no", "fp32", "fp16", "bf16"], - help=( - "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." - ), - ) - parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank", - ) - parser.add_argument( - "--enable_xformers_memory_efficient_attention", - action="store_true", - help="Whether or not to use xformers.", - ) - parser.add_argument( - "--set_grads_to_none", - action="store_true", - help=( - "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" - " behaviors, so disable this argument if it causes any problems. 
More info:" - " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" - ), - ) - - if input_args is not None: - args = parser.parse_args(input_args) - else: - args = parser.parse_args() - - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != args.local_rank: - args.local_rank = env_local_rank - - if args.with_prior_preservation: - if args.class_data_dir is None: - raise ValueError("You must specify a data directory for class images.") - if args.class_prompt is None: - raise ValueError("You must specify prompt for class images.") - else: - # logger is not available yet - if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") - if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") - - return args - - -class DreamBoothDataset(Dataset): - """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. 
- """ - - def __init__( - self, - instance_data_root, - instance_prompt, - tokenizer, - class_data_root=None, - class_prompt=None, - size=512, - center_crop=False, - ): - self.size = size - self.center_crop = center_crop - self.tokenizer = tokenizer - - self.instance_data_root = Path(instance_data_root) - if not self.instance_data_root.exists(): - raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.") - - self.instance_images_path = list(Path(instance_data_root).iterdir()) - self.num_instance_images = len(self.instance_images_path) - self.instance_prompt = instance_prompt - self._length = self.num_instance_images - - if class_data_root is not None: - self.class_data_root = Path(class_data_root) - self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images_path = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images_path) - self._length = max(self.num_class_images, self.num_instance_images) - self.class_prompt = class_prompt - else: - self.class_data_root = None - - self.image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - def __len__(self): - return self._length - - def __getitem__(self, index): - example = {} - instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - example["instance_images"] = self.image_transforms(instance_image) - example["instance_prompt_ids"] = self.tokenizer( - self.instance_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids - - if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) - if not 
class_image.mode == "RGB": - class_image = class_image.convert("RGB") - example["class_images"] = self.image_transforms(class_image) - example["class_prompt_ids"] = self.tokenizer( - self.class_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids - - return example - - -def collate_fn(examples, with_prior_preservation=False): - input_ids = [example["instance_prompt_ids"] for example in examples] - pixel_values = [example["instance_images"] for example in examples] - - # Concat class and instance examples for prior preservation. - # We do this to avoid doing two forward passes. - if with_prior_preservation: - input_ids += [example["class_prompt_ids"] for example in examples] - pixel_values += [example["class_images"] for example in examples] - - pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - - input_ids = torch.cat(input_ids, dim=0) - - batch = { - "input_ids": input_ids, - "pixel_values": pixel_values, - } - return batch - - -class PromptDataset(Dataset): - "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
- - def __init__(self, prompt, num_samples): - self.prompt = prompt - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, index): - example = {} - example["prompt"] = self.prompt - example["index"] = index - return example - - -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - -def infer(checkpoint_path, ckpt_pipeline, prompts=None, n_img=16, bs=8, n_steps=100, guidance_scale=7.5): - if ckpt_pipeline is None: - pipe = StableDiffusionPipeline.from_pretrained( - checkpoint_path, torch_dtype=torch.bfloat16, safety_checker=None - ).to("cuda") - else: - pipe = ckpt_pipeline.to("cuda") - pipe.enable_xformers_memory_efficient_attention() - pipe.disable_attention_slicing() - - for prompt in prompts: - print(prompt) - norm_prompt = prompt.lower().replace(",", "").replace(" ", "_") - out_path = f"{checkpoint_path}/dreambooth/{norm_prompt}" - os.makedirs(out_path, exist_ok=True) - for i in range(n_img // bs): - images = pipe( - [prompt] * bs, - num_inference_steps=n_steps, - guidance_scale=guidance_scale, - ).images - for idx, image in enumerate(images): - image.save(f"{out_path}/{i}_{idx}.png") - del pipe - - -class LatentsDataset(Dataset): - def __init__(self, latents_cache, text_encoder_cache): - self.latents_cache = latents_cache - self.text_encoder_cache = text_encoder_cache - - def __len__(self): - return len(self.latents_cache) - - def __getitem__(self, index): - return self.latents_cache[index], self.text_encoder_cache[index] - - -def main(args): - logging_dir = Path(args.output_dir, args.logging_dir) - - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - 
logging_dir=logging_dir, - ) - - # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate - # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. - if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: - raise ValueError( - "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." - ) - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Generate class images if prior preservation is enabled. 
- if args.with_prior_preservation: - class_images_dir = Path(args.class_data_dir) - if not class_images_dir.exists(): - class_images_dir.mkdir(parents=True) - cur_class_images = len(list(class_images_dir.iterdir())) - if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 - if args.prior_generation_precision == "fp32": - torch_dtype = torch.float32 - elif args.prior_generation_precision == "fp16": - torch_dtype = torch.float16 - elif args.prior_generation_precision == "bf16": - torch_dtype = torch.bfloat16 - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - torch_dtype=torch_dtype, - safety_checker=None, - revision=args.revision, - ) - pipeline.set_progress_bar_config(disable=True) - pipeline.enable_xformers_memory_efficient_attention() - pipeline.disable_attention_slicing() - - num_new_images = args.num_class_images - cur_class_images - logger.info(f"Number of class images to sample: {num_new_images}.") - - sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - - sample_dataloader = accelerator.prepare(sample_dataloader) - pipeline.to(accelerator.device) - - for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process, - ): - images = pipeline(example["prompt"]).images - - for i, image in enumerate(images): - hash_image = hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - image.save(image_filename) - - del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Handle the repository creation - if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - 
repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - # Load the tokenizer - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) - elif args.pretrained_model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=args.revision, - use_fast=False, - ) - - # import correct text encoder class - text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - - # Load scheduler and models - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision - ) - - vae.requires_grad_(False) - if not args.train_text_encoder: - text_encoder.requires_grad_(False) - - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. 
Make sure it is installed correctly") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - if args.train_text_encoder: - text_encoder.gradient_checkpointing_enable() - - # Check that all trainable models are in full precision - low_precision_error_string = ( - "Please make sure to always have all model weights in full float32 precision when starting training - even if" - " doing mixed precision training. copy of the weights should still be float32." - ) - - if accelerator.unwrap_model(unet).dtype != torch.float32: - raise ValueError( - f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" - ) - - if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: - raise ValueError( - f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." - f" {low_precision_error_string}" - ) - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) - - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
- ) - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - # Optimizer creation - params_to_optimize = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() - ) - optimizer = optimizer_class( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - - # Dataset and DataLoaders creation: - train_dataset = DreamBoothDataset( - instance_data_root=args.instance_data_dir, - instance_prompt=args.instance_prompt, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_prompt=args.class_prompt, - tokenizer=tokenizer, - size=args.resolution, - center_crop=args.center_crop, - ) - - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=args.train_batch_size, - shuffle=False, - collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), - ) - - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. 
- weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - vae.to(device, dtype=weight_dtype) - - latents_cache = [] - text_encoder_cache = [] - - for batch in tqdm(train_dataloader, desc="Caching latents"): - with torch.no_grad(): - batch["pixel_values"] = batch["pixel_values"].to(device, dtype=weight_dtype) - batch["input_ids"] = batch["input_ids"].to(device) - latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) - if args.train_text_encoder: - text_encoder_cache.append(batch["input_ids"]) - else: - text_encoder_cache.append(text_encoder(batch["input_ids"])[0]) - train_dataset = LatentsDataset(latents_cache, text_encoder_cache) - train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True) - scaling_factor = vae.config.scaling_factor - del vae - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_num_cycles, - power=args.lr_power, - ) - - # Prepare everything with our `accelerator`. 
- if args.train_text_encoder: - ( - unet, - text_encoder, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) - - # Move vae and text_encoder to device and cast to weight_dtype - if not args.train_text_encoder: - text_encoder.to(accelerator.device, dtype=weight_dtype) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - accelerator.init_trackers("dreambooth", config=vars(args)) - - # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num batches each epoch = {len(train_dataloader)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - - # Only show the progress bar once on each machine. 
- progress_bar = tqdm( - range(global_step, args.max_train_steps), - disable=not accelerator.is_local_main_process, - ) - progress_bar.set_description("Steps") - - for epoch in range(first_epoch, args.num_train_epochs): - unet.train() - if args.train_text_encoder: - text_encoder.train() - for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - - with accelerator.accumulate(unet): - # Convert images to latent space - latent_dist = batch[0][0] - latents = latent_dist.sample() - latents = latents * scaling_factor - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (bsz,), - device=latents.device, - ) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch[0][1])[0] - - # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
- model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - accelerator.backward(loss) - if accelerator.sync_gradients: - params_to_clip = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder - else unet.parameters() - ) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=args.set_grads_to_none) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: - save_path = args.output_dir - ckpt_pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), - revision=args.revision, - ) - if global_step < 1000: - prompts = args.inference_prompts.split(";") - infer(save_path, ckpt_pipeline, prompts, n_img=16, bs=4, n_steps=100) - else: - ckpt_pipeline.save_pretrained(save_path) - del ckpt_pipeline - prompts = args.inference_prompts.split(";") - ckpt_pipeline = None - # infer(save_path, ckpt_pipeline, prompts, n_img=16, bs=4, n_steps=100) - logger.info(f"Saved state to {save_path}") - - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - 
- # Create the pipeline using using the trained modules and save it. - accelerator.wait_for_everyone() - if accelerator.is_main_process: - print("Finish training") - - accelerator.end_training() - - -if __name__ == "__main__": - args = parse_args() - main(args) \ No newline at end of file diff --git a/src/backend/app/algorithms/finetune/train_lora_gen.py b/src/backend/app/algorithms/finetune/train_lora_gen_trace.py similarity index 92% rename from src/backend/app/algorithms/finetune/train_lora_gen.py rename to src/backend/app/algorithms/finetune/train_lora_gen_trace.py index 4a951b2..3dfc6fd 100644 --- a/src/backend/app/algorithms/finetune/train_lora_gen.py +++ b/src/backend/app/algorithms/finetune/train_lora_gen_trace.py @@ -528,11 +528,20 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + # [START] 为可视化方案增加的参数定义 parser.add_argument( - "--is_perturbed", - action="store_true", - help="Whether training on perturbed images. Affects the generated image naming.", + "--positions_save_path", + type=str, + default=None, + help="保存3D可视化坐标数据的路径 (X: LoRA权重L2范数, Y: 总梯度L2范数, Z: LDM损失)。", ) + parser.add_argument( + "--coords_log_interval", + type=int, + default=25, + help="保存坐标数据的步数间隔。", + ) + # [END] 为可视化方案增加的参数定义 if input_args is not None: args = parser.parse_args(input_args) @@ -1211,6 +1220,15 @@ def main(args): initial_global_step = 0 first_epoch = 0 + # [START] 为可视化方案增加的初始化和导入 + coords_list = [] + if args.positions_save_path is not None: + import pandas as pd + logger.info( + f"可视化指标采集已启用。数据将每 {args.coords_log_interval} 步保存一次到 {args.positions_save_path}" + ) + # [END] 为可视化方案增加的初始化和导入 + progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, @@ -1306,6 +1324,20 @@ def main(args): loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) + + # [START] 为可视化方案增加的指标采集 - Y轴 (梯度范数) + Y_i = float("nan") + if args.positions_save_path is not None: + # Y轴: 总梯度L2范数 
(在反向传播之后,优化器更新之前计算) + grad_norm_sq = 0.0 + for name, p in unet.named_parameters(): + # 只关注需要梯度更新的参数 + if p.grad is not None and p.requires_grad: + # 使用 float() 避免 torch.amp 带来的精度问题,确保准确计算L2范数 + grad_norm_sq += (p.grad.data.float() ** 2).sum().item() + Y_i = math.sqrt(grad_norm_sq) + # [END] 为可视化方案增加的指标采集 - Y轴 (梯度范数) + if accelerator.sync_gradients: accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) optimizer.step() @@ -1317,6 +1349,45 @@ def main(args): progress_bar.update(1) global_step += 1 + + # [START] 为可视化方案增加的指标采集 - X, Z轴和保存逻辑 + if args.positions_save_path is not None and ( + global_step % args.coords_log_interval == 0 + or global_step == 1 + or global_step == initial_global_step + 1 + ): + + # Z轴: LDM 损失 + Z_i = loss.detach().item() + + # X轴: 总LoRA权重L2范数 (在优化器更新之后计算) + lora_weight_norm_sq = 0.0 + for name, p in unet.named_parameters(): + # 只关注 LoRA 权重参数 ("lora" in name) + if "lora" in name and p.requires_grad: + lora_weight_norm_sq += (p.data.float() ** 2).sum().item() + X_i = math.sqrt(lora_weight_norm_sq) + + # 记录坐标数据 + coords_list.append([global_step, X_i, Y_i, Z_i]) + + # 实时保存到文件 (可选,但为了防止训练中断丢失数据,建议实时保存) + # 每次记录时都覆盖保存,确保文件始终是最新的 + df = pd.DataFrame( + coords_list, + columns=['step', 'X_LoRA_Weight_Norm', 'Y_Grad_Norm', 'Z_LDM_Loss'] + ) + save_path = Path(args.positions_save_path) / "coords_live.csv" + save_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(save_path, index=False) + + if global_step % (args.coords_log_interval * 10) == 0: + logger.info( + f"Step {global_step}: 已记录并保存可视化坐标 (X={X_i:.4f}, Y={Y_i:.4f}, Z={Z_i:.4f}) 到 {save_path}" + ) + # [END] 为可视化方案增加的指标采集 - X, Z轴和保存逻辑 + + if accelerator.is_main_process: if (global_step + 1) % args.checkpointing_steps == 0: # 1. 
保存模型参数:直接保存到 args.output_dir,覆盖上一轮 @@ -1435,9 +1506,22 @@ def main(args): ignore_patterns=["step_*", "epoch_*"], ) + # [START] 为可视化方案增加的最终保存 (防止最后一步数据没有被保存) + if args.positions_save_path is not None and coords_list: + df = pd.DataFrame( + coords_list, + columns=['step', 'X_LoRA_Weight_Norm', 'Y_Grad_Norm', 'Z_LDM_Loss'] + ) + save_path = Path(args.positions_save_path) / "coords.csv" + + save_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(save_path, index=False) + logger.info(f"训练结束:已将所有 {len(coords_list)} 步可视化坐标数据保存到 {save_path}") + # [END] 为可视化方案增加的最终保存 + accelerator.end_training() if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/src/backend/app/algorithms/finetune/train_ti_gen_trace.py b/src/backend/app/algorithms/finetune/train_ti_gen_trace.py new file mode 100644 index 0000000..b97d155 --- /dev/null +++ b/src/backend/app/algorithms/finetune/train_ti_gen_trace.py @@ -0,0 +1,1404 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import argparse +import copy +import gc +import logging +import math +import os +import shutil +import warnings +from pathlib import Path +import pandas as pd + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from packaging import version +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +# Removed LoRA import: from diffusers.loaders import LoraLoaderMixin +from diffusers.optimization import get_scheduler +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params +from diffusers.utils import ( + check_min_version, + convert_state_dict_to_diffusers, + # Removed LoRA import: convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+# check_min_version("0.30.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, + placeholder_token: str = None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + # Model card updated for Textual Inversion + model_description = f""" +# Textual Inversion - {repo_id} + +These are Textual Inversion weights (an embedding) for {base_model}. The weights were trained on {prompt} using a placeholder token `{placeholder_token}`. You can find some example images in the following. \n +{img_str} +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + prompt=prompt, + model_description=model_description, + inference=True, + ) + tags = ["text-to-image", "diffusers", "textual-inversion", "diffusers-training"] + if isinstance(pipeline, StableDiffusionPipeline): + tags.extend(["stable-diffusion", "stable-diffusion-diffusers"]) + else: + tags.extend(["if", "if-diffusers"]) + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + pipeline.safety_checker = lambda images, clip_input: (images, [False for i in range(0, len(images))]) # disable safety checker + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + if args.validation_images is None: + images = [] + for _ in range(args.num_validation_images): + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, generator=generator).images[0] + images.append(image) + else: + images = [] + for image in args.validation_images: + image = Image.open(image) + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + return images + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + 
revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance. 
**Must contain the placeholder token** (e.g., 'a photo of ').", + ) + parser.add_argument( + "--placeholder_token", + type=str, + default=None, + required=True, + help="The placeholder token (e.g., ) that will be learned and used in instance_prompt.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + + parser.add_argument( + "--output_dir", + type=str, + default="textual-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + + parser.add_argument( + "--validation_image_output_dir", + type=str, + default=None, + help="The directory where validation images will be saved. If None, images will be saved inside a subdirectory of `output_dir`.", + ) + + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." 
+ ), + ) + # Textual Inversion only trains the embedding, not the full text encoder + # parser.add_argument( + # "--train_text_encoder", + # action="store_true", + # help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + # ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
+ ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. 
Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) + + parser.add_argument( + "--initializer_token", + type=str, + default=None, + required=True, + help="A token to use as a proxy for the concept during training. 
class DreamBoothDataset(Dataset):
    """Instance-image dataset used for Textual Inversion training.

    Each item pairs one pre-processed instance image with the tokenized
    instance prompt. The class keeps the ``DreamBoothDataset`` name and
    constructor signature so existing call sites keep working, but no class
    (prior-preservation) images are handled.

    Args:
        instance_data_root: Directory that contains the instance images.
        instance_prompt: Prompt tokenized for every image.
        tokenizer: Callable tokenizer forwarded to ``tokenize_prompt``.
        size: Target resolution used for resizing and cropping.
        center_crop: Center-crop when true, otherwise random-crop.
        encoder_hidden_states: Must be ``None``; pre-computed text embeddings
            are rejected because the embedding itself is what gets trained.
        class_prompt_encoder_hidden_states: Must be ``None`` for the same
            reason.
        tokenizer_max_length: Optional override of the tokenizer's default
            maximum sequence length.

    Raises:
        ValueError: If pre-computed hidden states are supplied or the
            instance directory does not exist.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        # Reject unsupported arguments before doing any other work.
        if encoder_hidden_states is not None or class_prompt_encoder_hidden_states is not None:
            raise ValueError("Textual Inversion cannot use pre-computed encoder hidden states.")

        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        # Kept as attributes for structural parity; both are always None here.
        self.encoder_hidden_states = encoder_hidden_states
        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        # Only regular files can be opened as images; sub-directories would
        # make Image.open fail in __getitem__. Sorting makes the
        # index -> file mapping independent of filesystem listing order.
        self.instance_images_path = sorted(
            entry for entry in self.instance_data_root.iterdir() if entry.is_file()
        )
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        # No prior preservation, so the length is just the instance count.
        self._length = self.num_instance_images

        crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                crop,
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        """Return the number of instance images."""
        return self._length

    def __getitem__(self, index):
        """Return the transformed image and tokenized prompt for ``index``.

        The returned dict has the keys ``instance_images``,
        ``instance_prompt_ids`` and ``instance_attention_mask``.
        """
        image_path = self.instance_images_path[index % self.num_instance_images]

        # The context manager releases the file handle once the pixel data
        # has been turned into a tensor.
        with Image.open(image_path) as raw_image:
            instance_image = exif_transpose(raw_image)
            if instance_image.mode != "RGB":
                instance_image = instance_image.convert("RGB")
            pixel_tensor = self.image_transforms(instance_image)

        text_inputs = tokenize_prompt(
            self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
        )

        return {
            "instance_images": pixel_tensor,
            "instance_prompt_ids": text_inputs.input_ids,
            "instance_attention_mask": text_inputs.attention_mask,
        }


def collate_fn(examples):
    """Merge dataset items into one training batch.

    Args:
        examples: Non-empty list of dicts as produced by
            ``DreamBoothDataset.__getitem__``.

    Returns:
        A dict with ``input_ids`` and ``pixel_values`` tensors, plus
        ``attention_mask`` when the examples carry one. All entries are
        tensors whose first dimension is the batch dimension.
    """
    pixel_values = torch.stack([example["instance_images"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    # Each example's ids already have a leading dimension, so concatenate.
    input_ids = torch.cat([example["instance_prompt_ids"] for example in examples], dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }

    if "instance_attention_mask" in examples[0]:
        # Concatenate into one tensor exactly like input_ids; a plain list of
        # tensors cannot be moved to a device by encode_prompt.
        batch["attention_mask"] = torch.cat(
            [example["instance_attention_mask"] for example in examples], dim=0
        )

    return batch


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    """Tokenize ``prompt`` padded and truncated to a fixed length.

    Args:
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        prompt: Text to tokenize.
        tokenizer_max_length: Optional length override; the tokenizer's
            ``model_max_length`` is used when it is ``None``.

    Returns:
        Whatever the tokenizer returns when asked for ``"pt"`` tensors.
    """
    max_length = tokenizer.model_max_length if tokenizer_max_length is None else tokenizer_max_length

    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )


def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    """Run the text encoder and return its first output.

    Args:
        text_encoder: Callable model exposing a ``device`` attribute.
        input_ids: Token id tensor; moved to the encoder's device.
        attention_mask: Mask tensor, or ``None``.
        text_encoder_use_attention_mask: When truthy and a mask is available,
            the mask is forwarded to the encoder; otherwise ``None`` is passed.

    Returns:
        The first element of the encoder's tuple output.
    """
    device = text_encoder.device
    text_input_ids = input_ids.to(device)

    # Guard against a missing mask so enabling the flag cannot crash on None.
    if text_encoder_use_attention_mask and attention_mask is not None:
        attention_mask = attention_mask.to(device)
    else:
        attention_mask = None

    prompt_embeds = text_encoder(
        text_input_ids,
        attention_mask=attention_mask,
        return_dict=False,
    )
    return prompt_embeds[0]
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Prior preservation image generation logic removed + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # Add the placeholder token to the tokenizer vocabulary and initialize the new token embedding + # Get token IDs for initializer and placeholder tokens + initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + placeholder_token_ids = tokenizer.encode(args.placeholder_token, add_special_tokens=False) + + if len(initializer_token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + if placeholder_token_ids != tokenizer.unk_token_id: + # If the placeholder is already in the vocab, it's either an existing token or was already added. + # We need to make sure it's actually the placeholder and not an existing common word. + # However, for simplicity and matching standard TI, we assume it's a new token. + # The standard approach is to *add* the placeholder token, which results in a list of new tokens. 
+ + # Add the placeholder token to the tokenizer and get the new token ID + tokenizer.add_tokens(args.placeholder_token) + placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + else: + # This case handles when the placeholder token is already a single, known token, which is usually fine, + # but in TI we usually want to add a *new* token. We rely on the `add_tokens` method below. + tokenizer.add_tokens(args.placeholder_token) + placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + try: + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + except OSError: + # IF does not have a VAE so let's just set it to None + # We don't have to error out here + vae = None + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + # Textual Inversion specific setup: Resize token embeddings and initialize new token + text_encoder.resize_token_embeddings(len(tokenizer)) + + token_embeds = text_encoder.get_input_embeddings().weight.data + initializer_token_id = tokenizer.convert_tokens_to_ids(args.initializer_token) + + # Initialize the new token embedding with the initializer token's embedding + token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] + + # Freeze all models and then unfreeze the embedding layer + if vae is not None: + vae.requires_grad_(False) + 
text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + # Only train the newly added embedding (Textual Inversion) + embedding_layer = text_encoder.get_input_embeddings() + embedding_layer.weight.requires_grad = True + + # Freeze all but the placeholder token's embedding. We create a mask/indices for the placeholder token ID. + # Note: Textual Inversion typically only trains the new token's embedding. + # We use a trick to register the embedding layer as trainable, but ensure only the new embedding is updated. + + # The simplest way is to ensure all embedding weights are trainable, and let the optimizer only update + # the ones that appear in the batch. However, a safer way is to specifically mark only the placeholder + # embedding as trainable. + + # Get the embedding tensor + trainable_token_embeds = embedding_layer.weight + + # Mask to freeze all except the placeholder token's embedding + mask = torch.ones(len(tokenizer), dtype=torch.bool) + mask[placeholder_token_id] = False # We want the placeholder to be unmasked (trainable) + + # Freeze the embeddings that are NOT the placeholder token's + trainable_token_embeds.data[mask] = trainable_token_embeds.data[mask].float() + trainable_token_embeds.data[mask].requires_grad = False + + # Make sure the placeholder token's embedding is set to require gradients + trainable_token_embeds.data[placeholder_token_id].requires_grad = True + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and unet) to half-precision + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # Note: Only trainable parameters (new embeddings) must remain in float32 for fp16 training. 
+ unet.to(accelerator.device, dtype=weight_dtype) + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + # Textual Inversion doesn't train the full text encoder, so we only need to checkpoint UNet + + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # We only save the trained token embedding + text_encoder_unwrapped = unwrap_model(text_encoder) + + # Find the trained embedding + trained_embeddings = text_encoder_unwrapped.get_input_embeddings().weight[placeholder_token_id:placeholder_token_id+1] + + # Create a state dict to save + learned_embeds_dict = { + args.placeholder_token: trained_embeddings.detach().cpu() + } + + # Save the embedding file (similar to Textual Inversion pipelines) + torch.save(learned_embeds_dict, os.path.join(output_dir, "learned_embeds.bin")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + # Also save tokenizer for completeness + 
tokenizer.save_pretrained(output_dir) + + def load_model_hook(models, input_dir): + text_encoder_ = None + + while len(models) > 0: + model = models.pop() + if isinstance(model, type(unwrap_model(text_encoder))): + text_encoder_ = model + # UNet is not passed to the load hook for training state, only text_encoder's embedding matters + + # Load the embedding file + embedding_path = os.path.join(input_dir, "learned_embeds.bin") + if not os.path.exists(embedding_path): + logger.warning(f"Could not find learned_embeds.bin at {embedding_path}. This may be normal if starting a new run.") + return + + state_dict = torch.load(embedding_path, map_location="cpu") + + # We expect a dictionary where the key is the placeholder token + if args.placeholder_token not in state_dict: + raise ValueError( + f"Trained embedding not found for placeholder token '{args.placeholder_token}' in loaded state dict." + ) + + learned_embeds = state_dict[args.placeholder_token] + + # Load embedding into the text encoder + token_embeds = text_encoder_.get_input_embeddings().weight.data + + # Ensure the current tokenizer and text encoder size is consistent with the checkpoint + current_tokenizer = AutoTokenizer.from_pretrained(input_dir) + current_placeholder_token_id = current_tokenizer.convert_tokens_to_ids(args.placeholder_token) + + if current_placeholder_token_id == current_tokenizer.unk_token_id: + raise ValueError( + f"Placeholder token '{args.placeholder_token}' not found in the tokenizer loaded from checkpoint at {input_dir}. " + "Ensure your checkpoint contains the tokenizer with the added placeholder token." 
+ ) + + token_embeds[current_placeholder_token_id] = learned_embeds.to(token_embeds.dtype).to(token_embeds.device) + + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Only upcast trainable parameters (embedding) into fp32 if mixed precision is used + if accelerator.mixed_precision == "fp16": + # The embedding layer is the only part that needs to be checked + cast_training_params([text_encoder], dtype=torch.float32) + + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation: only includes the trainable embedding parameters + params_to_optimize = list(filter(lambda p: p.requires_grad, text_encoder.parameters())) + + if not params_to_optimize: + raise ValueError("No trainable parameters found. 
Check if the embedding layer is set to requires_grad=True.") + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Pre-computation is not supported for Textual Inversion, so this block is simplified + pre_computed_encoder_hidden_states = None + pre_computed_class_prompt_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples), + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. 
+ # Only UNet, Text Encoder and Optimizer are prepared (VAE is not optimized/frozen) + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = vars(copy.deepcopy(args)) + tracker_config.pop("validation_images") + accelerator.init_trackers("textual-inversion", config=tracker_config) # Updated project name + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + resume_path = args.output_dir + + try: + accelerator.print(f"Resuming from checkpoint at {resume_path}") + accelerator.load_state(resume_path) + + # After loading state, `accelerator` updates its internal state including `step` and `epoch` + initial_global_step = accelerator.state.global_step + global_step = initial_global_step + + # Recalculate first_epoch based on the loaded global_step + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + first_epoch = global_step // num_update_steps_per_epoch + + accelerator.print(f"Resumed at global step {global_step} and epoch {first_epoch}") + + except Exception as e: + accelerator.print( + f"Could not load state from '{resume_path}'. Starting a new training run. Error: {e}" + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + first_epoch = 0 + else: + initial_global_step = 0 + first_epoch = 0 + + # [START] 为可视化方案增加的初始化和导入 (保持不变) + coords_list = [] + # 提前定义 X, Y 指标的临时存储变量,用于跨代码块传递数据 + X_i_feature_norm = float("nan") + Y_i_feature_var = float("nan") + + if args.coords_save_path is not None: + logger.info( + f"可视化指标采集已启用。数据将每 {args.coords_log_interval} 步保存一次到 {args.coords_save_path}" + ) + # [END] 为可视化方案增加的初始化和导入 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() # UNet is frozen, but keep in train mode for modules like Dropout (if any) + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + + if vae is not None: + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + else: + model_input = pixel_values + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Get the text embedding for conditioning + # Since pre_compute_text_embeddings is false, we encode the prompt here + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None + + # Predict the noise residual + model_pred = unet( + noisy_model_input, + timesteps, + encoder_hidden_states, + class_labels=class_labels, + return_dict=False, + )[0] + + # If model predicts variance, throw away the prediction. 
+ if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + + # [START] 为可视化方案增加的 X轴 (特征范数) 和 Y轴 (特征方差) 计算 (通用指标) (保持不变) + if args.coords_save_path is not None: + # 修正 X轴 计算:将 torch.linalg.norm 替换为传统的 torch.norm + # 传统的 torch.norm 支持对多个维度求范数 (dim=[1, 2, 3]) + # X轴: UNet 预测特征 L2 范数 (衡量预测的“强度”) + # torch.norm(..., p=2, dim=...) 表示 L2 范数 + X_i_feature_norm = torch.norm( + model_pred.detach().float(), + p=2, + dim=[1, 2, 3] # 对 C, H, W 维度求 L2 范数 + ).mean().item() # 对 Batch 维度求平均 + + # Y轴: UNet 预测特征方差 (衡量预测的“混乱度/稳定性”) + # var() 默认对所有维度求方差,我们对 C, H, W 求方差,然后对 Batch 求平均 + Y_i_feature_var = model_pred.detach().float().var(dim=[1, 2, 3]).mean().item() + # [END] 为可视化方案增加的 X轴 (特征范数) 和 Y轴 (特征方差) 计算 (通用指标) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Prior preservation block removed for Textual Inversion. 
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + + if accelerator.sync_gradients: + # Only clip gradient for trainable parameters + # For Textual Inversion, only the embedding requires grad + accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Ensure only the placeholder token's embedding is updated and all others are clamped + # This is the "slicing" step typical of TI to ensure only the learned token moves + if accelerator.num_processes > 1: + # For DDP/Distributed training, we need to unwrap the model to apply the mask + unwrapped_text_encoder = unwrap_model(text_encoder) + trainable_embeds = unwrapped_text_encoder.get_input_embeddings().weight + else: + trainable_embeds = text_encoder.get_input_embeddings().weight + + # Clamp the non-placeholder embeddings (ensure they don't move) + trainable_embeds.data[mask] = trainable_embeds.data[mask].float().to(trainable_embeds.device) + trainable_embeds.data[placeholder_token_id] = trainable_embeds.data[placeholder_token_id].float() + + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + + # [START] 为可视化方案增加的 X, Y, Z轴 数据记录和保存 (通用指标) (保持不变) + if args.coords_save_path is not None and ( + global_step % args.coords_log_interval == 0 + or global_step == 1 + or global_step == initial_global_step + 1 + ): + + # Z轴: LDM 损失 (直接获取当前步的 loss) + Z_i = loss.detach().item() + + # 记录坐标数据 (X和Y已在前面计算) + coords_list.append([global_step, X_i_feature_norm, Y_i_feature_var, Z_i]) + + # 实时保存到文件 (覆盖保存,确保文件始终是最新的) + df = pd.DataFrame( + coords_list, + columns=['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss'] + ) + + # 假设 args.coords_save_path 是目标文件路径 (如 ./data/coords.csv) + save_file_path = Path(args.coords_save_path) + if not save_file_path.suffix: + save_file_path = 
save_file_path / "coords.csv" + save_file_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(save_file_path, index=False) + + if global_step % (args.coords_log_interval * 10) == 0: + logger.info( + f"Step {global_step}: 已记录并保存可视化坐标 (X={X_i_feature_norm:.4f}, Y={Y_i_feature_var:.4f}, Z={Z_i:.4f}) 到 {save_file_path}" + ) + # [END] 为可视化方案增加的 X, Y, Z轴 数据记录和保存 (通用指标) + + + if accelerator.is_main_process: + if (global_step + 1) % args.checkpointing_steps == 0: + # 1. 保存模型参数:直接保存到 args.output_dir,覆盖上一轮 + output_dir = args.output_dir + # accelerator.save_state handles saving the models using the registered hooks + accelerator.save_state(output_dir) + logger.info(f"Saving state to {output_dir} at step {global_step+1}") + + # 2. 推理调用模型:从 args.output_dir 加载最新的模型权重 + # Textual Inversion Pipeline loading + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=unwrap_model(text_encoder), # Use the unwrapped text encoder + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # Load the learned embedding into the pipeline's tokenizer/text_encoder + # (The load hook handles the actual embedding tensor update during accelerator.load_state) + # Here, we only need to load the tokenizer to ensure the pipeline has the placeholder token + pipeline.tokenizer = AutoTokenizer.from_pretrained(output_dir) + pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer)) + # 🌟 关键修复:手动加载 learned_embeds.bin 文件 + # 1. 加载 learned_embeds.bin + path = os.path.join(args.output_dir, "learned_embeds.bin") + if not os.path.exists(path): + # 如果文件名为 pytorch_model.bin (accelerate保存的完整模型),我们需要从模型中提取 + # 此处假设您只保存了 learned_embeds.bin + logger.warning("learned_embeds.bin not found. Skipping manual load.") + else: + # 加载权重字典 + loaded_embeds = torch.load(path, map_location="cpu") + + # 2. 提取唯一的 key (例如 'sks') 和 embedding tensor + token_name = list(loaded_embeds.keys())[0] + embedding = loaded_embeds[token_name] + + # 3. 
获取新 token 的 ID + token_id = pipeline.tokenizer.convert_tokens_to_ids(token_name) + + # 4. 将权重插入到 Text Encoder 的 Embedding Layer 中 + text_encoder_embeddings = pipeline.text_encoder.get_input_embeddings() + text_encoder_embeddings.weight.data[token_id] = embedding.to(text_encoder_embeddings.weight.dtype).to(text_encoder_embeddings.weight.device) + + # 保持 pipeline 在 GPU 上 + pipeline.to(accelerator.device) + + # Set pipeline args + pipeline_args = {"prompt": args.validation_prompt} + + images = log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + ) + + # 3. 推理生成结果保存:直接保存到指定目录/output_dir,不创建子文件夹 + base_save_path = Path(args.validation_image_output_dir or args.output_dir) + base_save_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Saving validation images to {base_save_path}") + + # 图片直接保存在 base_save_path,会覆盖上一轮的同名图片 + for i, image in enumerate(images): + image.save(base_save_path / f"image_{i}.png") + + # Clean up pipeline to save memory + del pipeline + gc.collect() + torch.cuda.empty_cache() + + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + + # Save the final embeddings and tokenizer + accelerator.wait_for_everyone() + if accelerator.is_main_process: + text_encoder = unwrap_model(text_encoder) + + # Final save of the learned_embeds.bin and tokenizer + trained_embeddings = text_encoder.get_input_embeddings().weight[placeholder_token_id:placeholder_token_id+1] + + learned_embeds_dict = { + args.placeholder_token: trained_embeddings.detach().cpu() + } + + torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin")) + tokenizer.save_pretrained(args.output_dir) + + # Final inference + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype + ) + # Load the 
final embedding + pipeline.tokenizer = AutoTokenizer.from_pretrained(args.output_dir) + pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer)) + pipeline.load_textual_inversion(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25} + images = log_validation( + pipeline, + args, + accelerator, + pipeline_args, + args.num_train_epochs, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=False, # TI is not full text encoder training + prompt=args.instance_prompt, + repo_folder=args.output_dir, + pipeline=pipeline, + placeholder_token=args.placeholder_token, # Added for TI + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training (Textual Inversion)", + ignore_patterns=["step_*", "epoch_*"], + ) + + # [START] 为可视化方案增加的最终保存 (通用指标) (保持不变) + if args.coords_save_path is not None and coords_list: + df = pd.DataFrame( + coords_list, + columns=['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss'] + ) + # 假设 args.coords_save_path 是目标文件路径 + save_file_path = Path(args.coords_save_path) + if not save_file_path.suffix: + save_file_path = save_file_path / "coords.csv" + save_file_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(save_file_path, index=False) + logger.info(f"训练结束:已将所有 {len(coords_list)} 步可视化坐标数据保存到 {save_file_path}") + # [END] 为可视化方案增加的最终保存 + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/src/backend/app/algorithms/processor/coords_processor.py b/src/backend/app/algorithms/processor/coords_processor.py new file mode 100644 index 0000000..ae2a908 --- /dev/null +++ b/src/backend/app/algorithms/processor/coords_processor.py @@ -0,0 +1,132 @@ 
import pandas as pd
from pathlib import Path
import sys
import logging
import numpy as np
import statsmodels.api as sm

# Configure logging once at import time (this module is used as a CLI script).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def apply_lowess_and_clipping_scaling(input_csv_path, output_csv_path, lowess_frac, target_range, clipping_percentile):
    """Smooth the training-trace metrics with LOWESS and rescale them to a common range.

    The input CSV must contain the columns ``step``, ``X_Feature_L2_Norm``,
    ``Y_Feature_Variance`` and ``Z_LDM_Loss``.  Each metric is smoothed against
    ``step`` with LOWESS, then mapped linearly using the
    ``clipping_percentile`` and ``1 - clipping_percentile`` quantiles of the
    smoothed curve as the scaling window.  X and Y are inverted (window
    maximum -> 0, window minimum -> ``target_range``); Z is mapped directly
    (window minimum -> 0, window maximum -> ``target_range``).  Values outside
    the window are *not* clipped, so the output may leave ``[0, target_range]``.

    Args:
        input_csv_path: Path of the raw coordinates CSV.
        output_csv_path: Path of the CSV to write (parent directories are created).
        lowess_frac: ``frac`` argument passed to ``lowess``.
        target_range: Upper end of the target range used for scaling.
        clipping_percentile: Quantile ``p`` defining the scaling window
            ``[quantile(p), quantile(1 - p)]``.

    Returns:
        None.  If the input file does not exist, an error is logged and
        nothing is written.
    """
    input_path = Path(input_csv_path)
    output_path = Path(output_csv_path)

    if not input_path.exists():
        logging.error(f"错误:未找到输入文件 {input_csv_path}")
        return

    logging.info(f"读取原始数据: {input_csv_path}")
    df = pd.read_csv(input_path)

    # Drop duplicated column labels, keeping the first occurrence of each.
    df = df.loc[:, ~df.columns.duplicated()].copy()

    # Names of the raw metric columns.
    raw_x_col = 'X_Feature_L2_Norm'
    raw_y_col = 'Y_Feature_Variance'
    raw_z_col = 'Z_LDM_Loss'

    # ------------------ 1. LOWESS smoothing (overall trend) ------------------
    logging.info(f"应用 Lowess 局部加权回归,平滑因子 frac={lowess_frac}。")

    x_coords = df['step'].values

    for raw_col in [raw_x_col, raw_y_col, raw_z_col]:
        y_coords = df[raw_col].values

        # it=0: plain locally weighted fit, no robustness re-weighting passes.
        smoothed_data = sm.nonparametric.lowess(
            endog=y_coords,
            exog=x_coords,
            frac=lowess_frac,
            it=0
        )
        # Column 1 holds the fitted values.
        # NOTE(review): lowess returns rows ordered by `step`; this assignment
        # assumes the CSV is already sorted by step and has no NaN rows.
        df[f'{raw_col}_LOWESS'] = smoothed_data[:, 1]

    # ------------- 2. Percentile-window scaling and direction unification -------------
    p = clipping_percentile
    logging.info(f"应用百分位数边界 (p={p}) 进行线性缩放,目标范围 [0, {target_range:.2f}]")

    scale_cols_map = {
        'X_Feature_L2_Norm': f'{raw_x_col}_LOWESS',
        'Y_Feature_Variance': f'{raw_y_col}_LOWESS',
        'Z_LDM_Loss': f'{raw_z_col}_LOWESS'
    }

    for final_col, lowess_col in scale_cols_map.items():

        data = df[lowess_col]

        # The quantiles define the scaling window; the data itself is not clipped.
        min_val = data.quantile(p)
        max_val = data.quantile(1.0 - p)
        data_range = max_val - min_val

        # `x == np.nan` is always False, so NaN must be detected with isnan();
        # otherwise an all-NaN window would fall through and emit NaN columns.
        if np.isnan(data_range) or data_range <= 0:
            df[final_col] = 0.0
            logging.warning(f"列 {final_col} 裁剪后的范围为 {data_range:.4f},跳过缩放。")
            continue

        # Normalise: (data - window_min) / window_range
        normalized_data = (data - min_val) / data_range

        if final_col in ['X_Feature_L2_Norm', 'Y_Feature_Variance']:
            # X/Y are inverted: window max -> 0, window min -> target_range.
            final_scaled_data = (1.0 - normalized_data) * target_range
        else:  # Z_LDM_Loss
            # Z is scaled directly: window min -> 0, window max -> target_range.
            final_scaled_data = normalized_data * target_range

        # Out-of-window (including negative) values are kept on purpose so the
        # curve stays smooth.  This overwrites the raw column of the same name.
        df[final_col] = final_scaled_data

        logging.info(f" - 列 {final_col}:裁剪边界: [{min_val:.4f}, {max_val:.4f}]。缩放后范围不再严格约束 [0, {target_range:.2f}],以保留趋势。")

    # ------------------------------ 3. Save ------------------------------
    output_path.parent.mkdir(parents=True, exist_ok=True)

    final_cols = ['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss']

    df[final_cols].to_csv(
        output_path,
        index=False,
        float_format='%.3f'
    )

    logging.info(f"Lowess平滑和缩放后的数据已保存到: {output_csv_path}")


if __name__ == '__main__':
    # Five positional arguments are required: input, output, frac, range, percentile.
    if len(sys.argv) != 6:
        logging.error("使用方法: python coords_processor.py <输入CSV路径> <输出CSV路径> <Lowess 平滑因子 frac (例如 0.3)> <目标视觉范围 (例如 30)> <离散点裁剪百分比 (例如 0.15)>")
    else:
        input_csv = sys.argv[1]
        output_csv = sys.argv[2]
        try:
            lowess_frac = float(sys.argv[3])
            target_range = float(sys.argv[4])
            clipping_p = float(sys.argv[5])

            if not (0.0 < lowess_frac <= 1.0):
                raise ValueError("Lowess 平滑因子 frac 必须在 (0.0, 1.0] 之间。")
            if target_range <= 0:
                raise ValueError("目标视觉范围必须大于 0。")
            if not (0 <= clipping_p < 0.5):
                raise ValueError("裁剪百分比必须在 [0, 0.5) 之间。")

            # An output path without a suffix is treated as a directory.
            if not Path(output_csv).suffix:
                output_csv = str(Path(output_csv) / "scaled_coords.csv")

            apply_lowess_and_clipping_scaling(input_csv, output_csv, lowess_frac, target_range, clipping_p)
        except ValueError as e:
            logging.error(f"参数错误: {e}")


# ---------------------------------------------------------------------------
# Next file in the patch: src/backend/app/algorithms/processor/image_processor.py
# Image processing utility: crop each source image to its centred square,
# resize it to the requested resolution, save it in the requested format and
# optionally rename the files sequentially.
# ---------------------------------------------------------------------------
import argparse
import os
from pathlib import Path
from PIL import Image
def _positive_int(value):
    """argparse ``type=`` callable: parse *value* as an int and reject values below 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args(input_args=None):
    """Parse the command-line arguments of the image processor.

    Args:
        input_args: Optional list of argument strings.  When ``None`` the
            arguments are read from ``sys.argv`` (argparse's default).

    Returns:
        argparse.Namespace with ``input_dir``, ``resolution``,
        ``target_format`` and ``rename_sequential``.

    Raises:
        SystemExit: On invalid arguments (argparse behaviour), including a
            ``--resolution`` that is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Image Processor for Centering, Resizing, and Format Conversion.")

    # Path and resolution.
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="A folder containing the original images to be processed and overwritten.",
    )
    parser.add_argument(
        "--resolution",
        # Rejected up front: a size < 1 would otherwise fail once per image at resize time.
        type=_positive_int,
        default=512,
        help="The target resolution (width and height) for the output images (e.g., 512 for 512x512).",
    )
    # Output format.
    parser.add_argument(
        "--target_format",
        type=str,
        default="png",
        choices=["jpeg", "png", "webp", "jpg"],
        help="The target format for the saved images (e.g., 'png', 'jpg', 'webp'). The original file will be overwritten, potentially changing the file extension.",
    )

    # Sequential numeric renaming.
    parser.add_argument(
        "--rename_sequential",
        action="store_true",  # True when the flag is present
        help="If set, images will be sequentially renamed (e.g., 001.jpg, 002.jpg...) instead of preserving the original filename. WARNING: This WILL delete the originals.",
    )

    # parse_args(None) falls back to sys.argv[1:], so no explicit branch is needed.
    return parser.parse_args(input_args)
def process_image(image_path: Path, output_path: Path, resolution: int, target_format: str, delete_original: bool):
    """Center-crop an image to a square, resize it and save it in the target format.

    Any exception is caught and reported with ``print`` so that one bad file
    does not abort a batch run; in that case nothing is deleted.

    Args:
        image_path: Path of the source image.
        output_path: Path the processed image is written to.
        resolution: Edge length (pixels) of the square output.
        target_format: Output format name (``jpg``/``jpeg`` are saved as JPEG).
        delete_original: Remove ``image_path`` after a successful save, unless
            it resolves to the same file as ``output_path``.
    """
    try:
        # Open inside a context manager so the file handle is released before
        # the file is overwritten or removed; convert() returns a loaded copy.
        with Image.open(image_path) as source:
            img = source.convert("RGB")

        # Largest centred square: edge length is the shorter side.
        width, height = img.size
        min_dim = min(width, height)
        left = (width - min_dim) // 2
        top = (height - min_dim) // 2
        right = left + min_dim
        bottom = top + min_dim
        img = img.crop((left, top, right, bottom))

        # Up- or down-sample with LANCZOS resampling.
        img = img.resize((resolution, resolution), resample=Image.Resampling.LANCZOS)

        save_format = target_format.upper().replace('JPEG', 'JPG')

        # Pillow's format name is 'JPEG'; a quality setting is passed for it.
        if save_format == 'JPG':
            img.save(output_path, format='JPEG', quality=95)
        else:
            img.save(output_path, format=save_format)

        if delete_original and image_path.resolve() != output_path.resolve():
            os.remove(image_path)

        print(f"Processed: {image_path.name} -> {output_path.name} ({resolution}x{resolution} {save_format})")

    except Exception as e:
        print(f"Error processing {image_path.name}: {e}")


def _process_sequential(image_paths, input_dir, extension, resolution, target_format):
    """Write images as 001.<ext>, 002.<ext>, ... without clobbering unread sources.

    Outputs are first written to a staging directory inside ``input_dir`` and
    only moved to their final names once every source has been processed.
    Writing ``NNN.<ext>`` directly could overwrite a source that has not been
    read yet (e.g. ``001.jpg`` -> ``001.png`` while ``001.png`` is still pending).
    """
    import tempfile

    staging_dir = Path(tempfile.mkdtemp(prefix=".staging_", dir=input_dir))
    staged = []
    for i, image_path in enumerate(image_paths):
        # Sequential names use at least three digits.
        final_path = input_dir / f"{i + 1:03d}.{extension}"
        staged_path = staging_dir / final_path.name
        process_image(image_path, staged_path, resolution, target_format, True)
        # process_image swallows errors; only successfully written files are moved.
        if staged_path.exists():
            staged.append((staged_path, final_path))

    for staged_path, final_path in staged:
        try:
            os.replace(staged_path, final_path)
        except OSError as e:
            print(f"Error moving {staged_path} to {final_path}: {e}")

    # rmdir only succeeds on an empty directory, so files that could not be
    # moved are kept in the staging directory rather than deleted.
    try:
        staging_dir.rmdir()
    except OSError:
        print(f"Staging directory {staging_dir} is not empty; leftover files were kept there.")


def main(args):
    """Process every supported image found directly inside ``args.input_dir``.

    Args:
        args: Namespace with ``input_dir``, ``resolution``, ``target_format``
            and ``rename_sequential`` (as produced by ``parse_args``).
    """
    input_dir = Path(args.input_dir)

    if not input_dir.is_dir():
        print(f"Error: Input directory not found at {input_dir}")
        return

    # Supported inputs; sorted so that sequential numbering is deterministic.
    valid_suffixes = ['.jpg', '.jpeg', '.png', '.webp']
    image_paths = sorted(
        p for p in input_dir.iterdir()
        if p.is_file() and p.suffix.lower() in valid_suffixes
    )

    if not image_paths:
        print(f"No image files found in {input_dir}")
        return

    print(f"Found {len(image_paths)} images in {input_dir}. Starting processing...")

    # File extension for the target format ('jpeg' is written as '.jpg').
    extension = args.target_format.lower().replace('jpeg', 'jpg')

    if args.rename_sequential:
        _process_sequential(image_paths, input_dir, extension, args.resolution, args.target_format)
    else:
        for image_path in image_paths:
            # Keep the original stem and switch to the target extension.
            output_path = image_path.with_suffix(f'.{extension}')
            # Remove the source only when its extension differs, so no
            # old-format copy is left behind.
            delete_original = (image_path.suffix.lower() != f'.{extension}')
            process_image(image_path, output_path, args.resolution, args.target_format, delete_original)

    print("Processing complete.")


if __name__ == "__main__":
    args = parse_args()
    main(args)
str(batch.user_id), str(batch.id) ) - # 类别图片目录(微调用) + # 类别图片目录(Prior Preservation) class_finetune_dir = os.path.join( - project_root, 'static', 'class_finetune', + project_root, + current_app.config['CLASS_DATA_FOLDER'], + str(batch.user_id), str(batch.id) + ) + + # 坐标可视化保存目录(训练轨迹) + coords_save_dir = os.path.join( + project_root, + current_app.config['COORDS_SAVE_FOLDER'], + str(batch.user_id), str(batch.id) + ) + + # 验证图片输出目录(分别对应 clean 和 perturbed) + validation_original_dir = os.path.join( + project_root, + current_app.config['MODEL_ORIGINAL_FOLDER'], + str(batch.user_id), str(batch.id) + ) + + validation_perturbed_dir = os.path.join( + project_root, + current_app.config['MODEL_PERTURBED_FOLDER'], str(batch.user_id), str(batch.id) ) @@ -440,6 +461,8 @@ class TaskService: train_images_dir=original_dir, output_model_dir=original_model_dir, class_dir=class_finetune_dir, + coords_save_path=coords_save_dir, + validation_output_dir=validation_original_dir, inference_prompts=inference_prompts, is_perturbed=False, custom_params=None, @@ -456,6 +479,8 @@ class TaskService: train_images_dir=perturbed_dir, output_model_dir=perturbed_model_dir, class_dir=class_finetune_dir, + coords_save_path=coords_save_dir, + validation_output_dir=validation_perturbed_dir, inference_prompts=inference_prompts, is_perturbed=True, custom_params=None, diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index 20698ee..eab354b 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -97,17 +97,20 @@ def _check_and_update_finetune_status(finetune_task): def run_finetune_task(finetune_batch_id, batch_id, finetune_method, train_images_dir, output_model_dir, - class_dir, inference_prompts, is_perturbed=False, custom_params=None): + class_dir, coords_save_path, validation_output_dir, inference_prompts, + is_perturbed=False, custom_params=None): """ 执行微调任务 Args: finetune_batch_id: 微调任务ID batch_id: 
扰动任务批次ID - finetune_method: 微调方法 (dreambooth, lora) + finetune_method: 微调方法 (dreambooth, lora, textual_inversion) train_images_dir: 训练图片目录(原始或扰动) output_model_dir: 模型输出目录 - class_dir: 类别图片目录 + class_dir: 类别图片目录(用于 prior preservation) + coords_save_path: 坐标保存路径(用于训练轨迹可视化) + validation_output_dir: 验证图片输出目录 inference_prompts: 推理提示词 is_perturbed: 是否是扰动图片训练 custom_params: 自定义参数 @@ -142,6 +145,8 @@ def run_finetune_task(finetune_batch_id, batch_id, finetune_method, train_images # 确保目录存在 os.makedirs(output_model_dir, exist_ok=True) os.makedirs(class_dir, exist_ok=True) + os.makedirs(coords_save_path, exist_ok=True) + os.makedirs(validation_output_dir, exist_ok=True) # 获取配置 use_real = AlgorithmConfig.USE_REAL_ALGORITHMS @@ -150,7 +155,8 @@ def run_finetune_task(finetune_batch_id, batch_id, finetune_method, train_images # 使用真实微调算法 result = _run_real_finetune( finetune_method, batch_id, train_images_dir, output_model_dir, - class_dir, inference_prompts, is_perturbed, custom_params + class_dir, coords_save_path, validation_output_dir, + inference_prompts, is_perturbed, custom_params ) else: # 使用虚拟微调实现 @@ -180,7 +186,8 @@ def run_finetune_task(finetune_batch_id, batch_id, finetune_method, train_images def _run_real_finetune(finetune_method, batch_id, train_images_dir, output_model_dir, - class_dir, inference_prompts, is_perturbed, custom_params): + class_dir, coords_save_path, validation_output_dir, + inference_prompts, is_perturbed, custom_params): """运行真实微调算法""" from config.algorithm_config import AlgorithmConfig @@ -198,18 +205,39 @@ def _run_real_finetune(finetune_method, batch_id, train_images_dir, output_model # 合并参数 params = {**default_params, **(custom_params or {})} - # 构建命令行参数 cmd_args = [ f"--instance_data_dir={train_images_dir}", f"--output_dir={output_model_dir}", - f"--class_data_dir={class_dir}", + f"--validation_image_output_dir={validation_output_dir}", ] + if finetune_method == 'dreambooth': + # DreamBooth 特有参数 + cmd_args.extend([ + 
f"--class_data_dir={class_dir}", + f"--coords_save_path={coords_save_path}", + ]) + + elif finetune_method == 'lora': + # LoRA 特有参数 (positions_save_path 等同于 coords_save_path) + cmd_args.extend([ + f"--class_data_dir={class_dir}", + f"--positions_save_path={coords_save_path}", + ]) + + elif finetune_method == 'textual_inversion': + # Textual Inversion 特有参数 (不需要 class_data_dir) + cmd_args.extend([ + f"--coords_save_path={coords_save_path}", + ]) + else: + raise ValueError(f"Unsupported finetune method: {finetune_method}") + # 添加is_perturbed标志 if is_perturbed: cmd_args.append("--is_perturbed") - # 添加其他参数 + # 添加其他默认参数 for key, value in params.items(): if isinstance(value, bool): if value: diff --git a/src/backend/config/algorithm_config.py b/src/backend/config/algorithm_config.py index dd79776..571b562 100644 --- a/src/backend/config/algorithm_config.py +++ b/src/backend/config/algorithm_config.py @@ -40,6 +40,7 @@ class AlgorithmConfig: 'pid': os.getenv('CONDA_ENV_PID', 'pid'), 'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'pid'), 'lora': os.getenv('CONDA_ENV_LORA', 'pid'), + 'textual_inversion': os.getenv('CONDA_ENV_TI', 'pid'), } # 模型路径配置 @@ -157,7 +158,7 @@ class AlgorithmConfig: # ========== 微调算法配置 ========== FINETUNE_SCRIPTS = { 'dreambooth': { - 'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_dreambooth_gen.py'), + 'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_db_gen_trace.py'), 'virtual_script': None, # 使用虚拟实现在worker中 'conda_env': CONDA_ENVS['dreambooth'], 'default_params': { @@ -169,23 +170,24 @@ class AlgorithmConfig: 'resolution': 512, 'train_batch_size': 1, 'gradient_accumulation_steps': 1, - 'learning_rate': 1e-4, + 'learning_rate': 2e-6, 'lr_scheduler': 'constant', 'lr_warmup_steps': 0, - 'num_class_images': 1, - 'max_train_steps': 1, - 'checkpointing_steps': 1, + 'num_class_images': 200, + 'max_train_steps': 1000, + 'checkpointing_steps': 500, 'center_crop': True, 'mixed_precision': 'bf16', 'prior_generation_precision': 
'bf16', - 'sample_batch_size': 1, + 'sample_batch_size': 5, 'validation_prompt': 'a photo of sks person', - 'num_validation_images': 1, - 'validation_steps': 1 + 'num_validation_images': 10, + 'validation_steps': 500, + 'coords_log_interval': 10 } }, 'lora': { - 'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_lora_gen.py'), + 'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_lora_gen_trace.py'), 'virtual_script': None, 'conda_env': CONDA_ENVS['lora'], 'default_params': { @@ -200,14 +202,40 @@ class AlgorithmConfig: 'learning_rate': 1e-4, 'lr_scheduler': 'constant', 'lr_warmup_steps': 0, - 'num_class_images': 1, - 'max_train_steps': 1, - 'checkpointing_steps': 1, + 'num_class_images': 200, + 'max_train_steps': 1000, + 'checkpointing_steps': 500, 'seed': 0, 'mixed_precision': 'fp16', 'rank': 4, 'validation_prompt': 'a photo of sks person', - 'num_validation_images': 1 + 'num_validation_images': 10, + 'coords_log_interval': 10 + } + }, + 'textual_inversion': { + 'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_ti_gen_trace.py'), + 'virtual_script': None, + 'conda_env': CONDA_ENVS['textual_inversion'], + 'default_params': { + 'pretrained_model_name_or_path': MODELS_DIR['model2'], + 'placeholder_token': 'sks', + 'initializer_token': 'person', + 'instance_prompt': 'a photo of sks person', + 'resolution': 512, + 'train_batch_size': 1, + 'gradient_accumulation_steps': 1, + 'learning_rate': 5e-4, + 'lr_scheduler': 'constant', + 'lr_warmup_steps': 0, + 'max_train_steps': 1000, + 'checkpointing_steps': 500, + 'seed': 0, + 'mixed_precision': 'fp16', + 'validation_prompt': 'a photo of sks person', + 'num_validation_images': 4, + 'validation_epochs': 50, + 'coords_log_interval': 10 } } } diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 5f4b3a6..bf99fe1 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -55,6 +55,12 @@ class Config: MODEL_ORIGINAL_FOLDER = 
os.path.join(MODEL_OUTPUTS_FOLDER, 'original') # 原图的模型生成结果 MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果 HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图 + # 微调训练相关配置 + CLASS_DATA_FOLDER = os.path.join(STATIC_ROOT, 'class') # 类别数据目录(用于 prior preservation) + # 可视化与分析配置 + EVA_RES_FOLDER = os.path.join(STATIC_ROOT, 'eva_res') # 评估结果根目录 + COORDS_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'position') # 3D坐标可视化数据(用于训练轨迹) + POSITIONS_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'position') # 位置数据(与coords相同,LoRA使用)未使用 # 预设演示图像配置 DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录 -- 2.34.1 From f01c3bac26d465a1aa7d2689a54d9c68050d4da9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sun, 30 Nov 2025 05:07:59 +0800 Subject: [PATCH 04/14] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E8=AF=84?= =?UTF-8?q?=E4=BC=B0=E7=AE=97=E6=B3=95=E5=90=8E=E7=AB=AF=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/services/task_service.py | 186 +++++++++++++++++++ src/backend/app/workers/evaluate_worker.py | 201 +++++++++++++++++++++ src/backend/app/workers/heatmap_worker.py | 198 ++++++++++++++++++++ src/backend/config/algorithm_config.py | 25 +++ src/backend/config/settings.py | 2 + 5 files changed, 612 insertions(+) create mode 100644 src/backend/app/workers/evaluate_worker.py create mode 100644 src/backend/app/workers/heatmap_worker.py diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index b9822de..dcea483 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -651,4 +651,190 @@ class TaskService: except Exception as e: print(f"生成最终评估时出错: {str(e)}") db.session.rollback() + + @staticmethod + def start_heatmap_task(heatmap_task, original_image_id, perturbed_image_id): + """ + 启动热力图生成任务 + + Args: + heatmap_task: Heatmap对象 + 
original_image_id: 原始图片ID(前端选择) + perturbed_image_id: 扰动图片ID(前端选择) + + Returns: + job_id + """ + try: + # 获取关联的主任务 + task = heatmap_task.task + if not task: + print(f"Heatmap task {heatmap_task.tasks_id} has no associated Task") + return None + + # 获取图片信息 + from app.database import Image + original_image = Image.query.get(original_image_id) + perturbed_image = Image.query.get(perturbed_image_id) + + if not original_image or not perturbed_image: + print("Selected images not found") + return None + + # 获取Prompt文本(从关联的Perturbation任务的数据集类型) + from app.database import Perturbation, DataType + perturbation = Perturbation.query.filter_by(tasks_id=task.tasks_id).first() + if perturbation and perturbation.data_type: + prompt_text = perturbation.data_type.data_type_prompt + # 从prompt中提取target_word(简单提取最后一个词) + target_word = prompt_text.split()[-1] if prompt_text else 'person' + else: + prompt_text = "a photo of sks person" + target_word = "person" + + project_root = os.path.dirname(current_app.root_path) + + # 输出目录 + output_dir = os.path.join( + project_root, + current_app.config['HEATDIF_SAVE_FOLDER'], + str(task.user_id), + str(task.tasks_id) + ) + + # 模型路径(从配置文件获取) + from config.algorithm_config import AlgorithmConfig + model_path = AlgorithmConfig.MODELS_DIR.get('model2') # 默认使用SD 1.5 + + # 获取队列 + queue = TaskService._get_queue() + + from app.workers.heatmap_worker import run_heatmap_task + + # 提交任务到队列 + job = queue.enqueue( + run_heatmap_task, + heatmap_id=heatmap_task.tasks_id, + task_id=task.tasks_id, + original_image_path=original_image.file_path, + perturbed_image_path=perturbed_image.file_path, + prompt_text=prompt_text, + target_word=target_word, + output_dir=output_dir, + model_path=model_path, + job_timeout=AlgorithmConfig.TASK_TIMEOUT, + job_id=f"heatmap_{heatmap_task.tasks_id}" + ) + + # 更新任务状态为queued + from app.database import TaskStatus + queued_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if queued_status: + task.tasks_status_id 
= queued_status.task_status_id + db.session.commit() + + return job.id + + except Exception as e: + print(f"启动热力图任务时出错: {str(e)}") + return None + + @staticmethod + def start_evaluate_task(evaluate_task): + """ + 启动数值评估任务 + + Args: + evaluate_task: Evaluate对象 + + Returns: + job_id + """ + try: + # 获取关联的主任务 + task = evaluate_task.task + if not task: + print(f"Evaluate task for Task {evaluate_task.tasks_id} has no associated Task") + return None + + # 获取关联的Finetune任务,以确定微调方法 + from app.database import Finetune + finetune = Finetune.query.filter_by( + tasks_id=evaluate_task.tasks_id, + finetune_configs_id=evaluate_task.finetune_configs_id + ).first() + + if not finetune: + print(f"No finetune task found for Evaluate task") + return None + + project_root = os.path.dirname(current_app.root_path) + + # 参考图片目录(原始上传的图片) + clean_ref_dir = os.path.join( + project_root, + current_app.config['ORIGINAL_IMAGES_FOLDER'], + str(task.user_id), + str(task.tasks_id) + ) + + # Clean输出目录(原始图训练后的生成结果) + clean_output_dir = os.path.join( + project_root, + current_app.config['MODEL_OUTPUTS_FOLDER'], + 'clean', + str(task.user_id), + str(task.tasks_id) + ) + + # Perturbed输出目录(扰动图训练后的生成结果) + perturbed_output_dir = os.path.join( + project_root, + current_app.config['MODEL_OUTPUTS_FOLDER'], + 'perturbed', + str(task.user_id), + str(task.tasks_id) + ) + + # 评估结果输出目录 + output_dir = os.path.join( + project_root, + current_app.config['NUMBERS_SAVE_FOLDER'], + str(task.user_id), + str(task.tasks_id) + ) + + # 获取队列 + queue = TaskService._get_queue() + + from app.workers.evaluate_worker import run_evaluate_task + from config.algorithm_config import AlgorithmConfig + + # 提交任务到队列 + job = queue.enqueue( + run_evaluate_task, + evaluate_id=evaluate_task.tasks_id, + task_id=task.tasks_id, + clean_ref_dir=clean_ref_dir, + clean_output_dir=clean_output_dir, + perturbed_output_dir=perturbed_output_dir, + output_dir=output_dir, + image_size=512, + job_timeout=AlgorithmConfig.TASK_TIMEOUT, + 
job_id=f"evaluate_{evaluate_task.tasks_id}_{evaluate_task.finetune_configs_id}" + ) + + # 更新任务状态为queued + from app.database import TaskStatus + queued_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if queued_status: + task.tasks_status_id = queued_status.task_status_id + db.session.commit() + + return job.id + + except Exception as e: + print(f"启动评估任务时出错: {str(e)}") + return None + diff --git a/src/backend/app/workers/evaluate_worker.py b/src/backend/app/workers/evaluate_worker.py new file mode 100644 index 0000000..aee7187 --- /dev/null +++ b/src/backend/app/workers/evaluate_worker.py @@ -0,0 +1,201 @@ +""" +RQ Worker 数值评估任务处理器 +生成原始图与扰动图微调后的模型生成效果对比报告 +""" + +import os +import subprocess +import logging +from datetime import datetime + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def run_evaluate_task(evaluate_id, task_id, clean_ref_dir, clean_output_dir, + perturbed_output_dir, output_dir, image_size=512): + """ + 执行数值评估任务 + + Args: + evaluate_id: 评估任务ID(复合主键之一) + task_id: 关联的主任务ID + clean_ref_dir: 干净参考图片目录(原始上传的图片) + clean_output_dir: 干净图片训练后的生成结果目录 + perturbed_output_dir: 扰动图片训练后的生成结果目录 + output_dir: 输出目录 + image_size: 图片处理尺寸 + + Returns: + 任务执行结果 + """ + from config.algorithm_config import AlgorithmConfig + from app import create_app, db + from app.database import Evaluate, Task, TaskStatus + + app = create_app() + + with app.app_context(): + try: + # Evaluate 使用复合主键,需要通过 tasks_id 和 finetune_configs_id 查询 + # 这里先通过 task_id 查询 + evaluate_task = Evaluate.query.filter_by(tasks_id=task_id).first() + if not evaluate_task: + raise ValueError(f"Evaluate task for Task {task_id} not found") + + task = Task.query.get(task_id) + if not task: + raise ValueError(f"Task {task_id} not found") + + # 更新任务状态为处理中 + processing_status = TaskStatus.query.filter_by(task_status_code='processing').first() + if processing_status: + 
task.tasks_status_id = processing_status.task_status_id + db.session.commit() + + logger.info(f"Starting evaluate task for Task {task_id}") + + # 确保目录存在 + os.makedirs(output_dir, exist_ok=True) + + # 获取配置 + use_real = AlgorithmConfig.USE_REAL_ALGORITHMS + + if use_real: + # 使用真实评估算法 + result = _run_real_evaluate( + task_id, clean_ref_dir, clean_output_dir, + perturbed_output_dir, output_dir, image_size + ) + else: + # 使用虚拟实现(生成占位符报告) + result = _run_virtual_evaluate(output_dir) + + # 保存评估结果文件路径到数据库 + report_file = os.path.join(output_dir, 'nums_dif.png') + if os.path.exists(report_file): + # 可以在这里保存评估结果的详细信息到 EvaluationResult 表 + pass + + # 更新任务状态为完成 + completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() + if completed_status: + task.tasks_status_id = completed_status.task_status_id + db.session.commit() + + logger.info(f"Evaluate task completed for Task {task_id}") + return result + + except Exception as e: + logger.error(f"Evaluate task failed for Task {task_id}: {str(e)}", exc_info=True) + + # 更新任务状态为失败 + failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() + if failed_status: + task.tasks_status_id = failed_status.task_status_id + db.session.commit() + + raise + + +def _run_real_evaluate(task_id, clean_ref_dir, clean_output_dir, + perturbed_output_dir, output_dir, image_size): + """运行真实数值评估算法""" + from config.algorithm_config import AlgorithmConfig + + logger.info(f"Running real evaluate generation") + + # 获取评估脚本配置 + evaluate_config = AlgorithmConfig.EVALUATE_SCRIPTS.get('numbers', {}) + script_path = evaluate_config.get('real_script') + conda_env = evaluate_config.get('conda_env') + + if not script_path: + raise ValueError("Evaluate script not configured") + + # 输出文件路径 + png_output_path = os.path.join(output_dir, 'nums_dif.png') + + # 构建命令行参数 + cmd_args = [ + f"--clean_ref_dir={clean_ref_dir}", + f"--clean_output_dir={clean_output_dir}", + f"--perturbed_output_dir={perturbed_output_dir}", + 
f"--png_output_path={png_output_path}", + f"--size={image_size}", + ] + + # 构建完整命令 + cmd = [ + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', + 'python', script_path + ] + cmd_args + + logger.info(f"Executing command: {' '.join(cmd)}") + + # 设置日志文件 + log_dir = AlgorithmConfig.LOGS_DIR + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join( + log_dir, + f'evaluate_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' + ) + + # 执行命令 + with open(log_file, 'w') as f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + for line in process.stdout: + f.write(line) + f.flush() + logger.info(line.strip()) + + process.wait() + + if process.returncode != 0: + raise RuntimeError(f"Evaluate generation failed with code {process.returncode}. Check log: {log_file}") + + return { + 'status': 'success', + 'output_dir': output_dir, + 'log_file': log_file + } + + +def _run_virtual_evaluate(output_dir): + """运行虚拟评估实现(生成占位符)""" + logger.info(f"Running virtual evaluate generation") + + # 创建占位符图片 + from PIL import Image, ImageDraw + + # 创建一个模拟的评估报告图 + img = Image.new('RGB', (1200, 1600), color=(255, 255, 255)) + draw = ImageDraw.Draw(img) + + # 添加标题文本 + draw.text((50, 50), "Virtual Evaluation Report Placeholder", fill=(0, 0, 0)) + draw.text((50, 100), "Real evaluation will be generated when USE_REAL_ALGORITHMS=true", fill=(128, 128, 128)) + draw.text((50, 150), "Metrics: FID, SSIM, PSNR, FDS, CLIP_IQS, BRISQUE", fill=(64, 64, 64)) + + # 保存 + output_file = os.path.join(output_dir, 'nums_dif.png') + img.save(output_file) + + logger.info(f"Virtual evaluation report saved to {output_file}") + + return { + 'status': 'success', + 'output_dir': output_dir, + 'virtual': True + } diff --git a/src/backend/app/workers/heatmap_worker.py b/src/backend/app/workers/heatmap_worker.py new file mode 100644 index 0000000..d379eb5 --- /dev/null +++ 
b/src/backend/app/workers/heatmap_worker.py @@ -0,0 +1,198 @@ +""" +RQ Worker 热力图任务处理器 +生成原始图与扰动图的注意力差异热力图 +""" + +import os +import subprocess +import logging +from datetime import datetime + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def run_heatmap_task(heatmap_id, task_id, original_image_path, perturbed_image_path, + prompt_text, target_word, output_dir, model_path): + """ + 执行热力图生成任务 + + Args: + heatmap_id: 热力图任务ID + task_id: 关联的主任务ID + original_image_path: 原始图片路径 + perturbed_image_path: 扰动图片路径 + prompt_text: Prompt文本(如 "a photo of sks person") + target_word: 目标关键词(如 "person") + output_dir: 输出目录 + model_path: Stable Diffusion模型路径 + + Returns: + 任务执行结果 + """ + from config.algorithm_config import AlgorithmConfig + from app import create_app, db + from app.database import Heatmap, Task, TaskStatus + + app = create_app() + + with app.app_context(): + try: + heatmap_task = Heatmap.query.get(heatmap_id) + if not heatmap_task: + raise ValueError(f"Heatmap task {heatmap_id} not found") + + task = Task.query.get(task_id) + if not task: + raise ValueError(f"Task {task_id} not found") + + # 更新任务状态为处理中 + processing_status = TaskStatus.query.filter_by(task_status_code='processing').first() + if processing_status: + task.tasks_status_id = processing_status.task_status_id + db.session.commit() + + logger.info(f"Starting heatmap task for Heatmap {heatmap_id}, Task {task_id}") + + # 确保目录存在 + os.makedirs(output_dir, exist_ok=True) + + # 获取配置 + use_real = AlgorithmConfig.USE_REAL_ALGORITHMS + + if use_real: + # 使用真实热力图算法 + result = _run_real_heatmap( + task_id, original_image_path, perturbed_image_path, + prompt_text, target_word, output_dir, model_path + ) + else: + # 使用虚拟实现(生成占位符图片) + result = _run_virtual_heatmap(output_dir) + + # 保存热力图文件路径到数据库 + heatmap_file = os.path.join(output_dir, 'heatmap_dif.png') + if os.path.exists(heatmap_file): + heatmap_task.heatmap_name 
= 'heatmap_dif.png' + db.session.commit() + + # 更新任务状态为完成 + completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() + if completed_status: + task.tasks_status_id = completed_status.task_status_id + db.session.commit() + + logger.info(f"Heatmap task completed for Heatmap {heatmap_id}") + return result + + except Exception as e: + logger.error(f"Heatmap task failed for Heatmap {heatmap_id}: {str(e)}", exc_info=True) + + # 更新任务状态为失败 + failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() + if failed_status: + task.tasks_status_id = failed_status.task_status_id + db.session.commit() + + raise + + +def _run_real_heatmap(task_id, original_image_path, perturbed_image_path, + prompt_text, target_word, output_dir, model_path): + """运行真实热力图算法""" + from config.algorithm_config import AlgorithmConfig + + logger.info(f"Running real heatmap generation") + + # 获取热力图脚本配置 + evaluate_config = AlgorithmConfig.EVALUATE_SCRIPTS.get('heatmap', {}) + script_path = evaluate_config.get('real_script') + conda_env = evaluate_config.get('conda_env') + + if not script_path: + raise ValueError("Heatmap script not configured") + + # 构建命令行参数 + cmd_args = [ + f"--model_path={model_path}", + f"--image_path_a={original_image_path}", + f"--image_path_b={perturbed_image_path}", + f"--prompt_text={prompt_text}", + f"--target_word={target_word}", + f"--output_dir={output_dir}", + ] + + # 构建完整命令 + cmd = [ + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', + 'python', script_path + ] + cmd_args + + logger.info(f"Executing command: {' '.join(cmd)}") + + # 设置日志文件 + log_dir = AlgorithmConfig.LOGS_DIR + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join( + log_dir, + f'heatmap_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' + ) + + # 执行命令 + with open(log_file, 'w') as f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + 
universal_newlines=True + ) + + for line in process.stdout: + f.write(line) + f.flush() + logger.info(line.strip()) + + process.wait() + + if process.returncode != 0: + raise RuntimeError(f"Heatmap generation failed with code {process.returncode}. Check log: {log_file}") + + return { + 'status': 'success', + 'output_dir': output_dir, + 'log_file': log_file + } + + +def _run_virtual_heatmap(output_dir): + """运行虚拟热力图实现(生成占位符)""" + logger.info(f"Running virtual heatmap generation") + + # 创建占位符图片 + from PIL import Image, ImageDraw, ImageFont + import numpy as np + + # 创建一个模拟的热力图 + img = Image.new('RGB', (1200, 1600), color=(255, 255, 255)) + draw = ImageDraw.Draw(img) + + # 添加标题文本 + draw.text((50, 50), "Virtual Heatmap Placeholder", fill=(0, 0, 0)) + draw.text((50, 100), "Real heatmap will be generated when USE_REAL_ALGORITHMS=true", fill=(128, 128, 128)) + + # 保存 + output_file = os.path.join(output_dir, 'heatmap_dif.png') + img.save(output_file) + + logger.info(f"Virtual heatmap saved to {output_file}") + + return { + 'status': 'success', + 'output_dir': output_dir, + 'virtual': True + } diff --git a/src/backend/config/algorithm_config.py b/src/backend/config/algorithm_config.py index 571b562..e95933a 100644 --- a/src/backend/config/algorithm_config.py +++ b/src/backend/config/algorithm_config.py @@ -244,3 +244,28 @@ class AlgorithmConfig: def get_finetune_config(cls, finetune_method): """获取微调算法配置""" return cls.FINETUNE_SCRIPTS.get(finetune_method, {}) + + # ========== 评估算法配置 ========== + EVALUATE_SCRIPTS = { + 'heatmap': { + 'real_script': os.path.join(ALGORITHMS_DIR, 'evaluate', 'eva_gen_heatmap.py'), + 'virtual_script': None, + 'conda_env': CONDA_ENVS['pid'], # 使用与微调相同的环境 + 'default_params': { + 'pretrained_model_name_or_path': MODELS_DIR['model2'], + } + }, + 'numbers': { + 'real_script': os.path.join(ALGORITHMS_DIR, 'evaluate', 'eva_gen_nums.py'), + 'virtual_script': None, + 'conda_env': CONDA_ENVS['pid'], + 'default_params': { + 'image_size': 512, + } + } + } + 
+ @classmethod + def get_evaluate_config(cls, evaluate_method): + """获取评估算法配置""" + return cls.EVALUATE_SCRIPTS.get(evaluate_method, {}) diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index bf99fe1..02d57b7 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -61,6 +61,8 @@ class Config: EVA_RES_FOLDER = os.path.join(STATIC_ROOT, 'eva_res') # 评估结果根目录 COORDS_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'position') # 3D坐标可视化数据(用于训练轨迹) POSITIONS_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'position') # 位置数据(与coords相同,LoRA使用)未使用 + HEATDIF_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'heatdif') # 热力图差异数据 + NUMBERS_SAVE_FOLDER = os.path.join(EVA_RES_FOLDER, 'numbers') # 数值结果数据 # 预设演示图像配置 DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录 -- 2.34.1 From f97bf59d4dd85559cf9e6e84d22fbf5b81b7012c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sun, 30 Nov 2025 11:56:54 +0800 Subject: [PATCH 05/14] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84worker?= =?UTF-8?q?=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/workers/evaluate_worker.py | 66 +- src/backend/app/workers/finetune_worker.py | 1013 ++++++++--------- src/backend/app/workers/heatmap_worker.py | 70 +- .../app/workers/perturbation_worker.py | 862 +++++++------- 4 files changed, 1068 insertions(+), 943 deletions(-) diff --git a/src/backend/app/workers/evaluate_worker.py b/src/backend/app/workers/evaluate_worker.py index aee7187..0f3b4b7 100644 --- a/src/backend/app/workers/evaluate_worker.py +++ b/src/backend/app/workers/evaluate_worker.py @@ -77,8 +77,8 @@ def run_evaluate_task(evaluate_id, task_id, clean_ref_dir, clean_output_dir, # 保存评估结果文件路径到数据库 report_file = os.path.join(output_dir, 'nums_dif.png') if os.path.exists(report_file): - # 可以在这里保存评估结果的详细信息到 EvaluationResult 表 - pass + # 保存报告图到Image表 + _save_report_image(task.tasks_id, 
report_file) # 更新任务状态为完成 completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() @@ -199,3 +199,65 @@ def _run_virtual_evaluate(output_dir): 'output_dir': output_dir, 'virtual': True } + + +def _save_report_image(task_id, report_file_path): + """ + 保存评估报告图到数据库Image表 + + Args: + task_id: 任务ID + report_file_path: 报告图文件完整路径 + """ + from app import db + from app.database import Image, ImageType + from PIL import Image as PILImage + + try: + # 获取报告图片类型 + report_type = ImageType.query.filter_by(image_code='report').first() + if not report_type: + logger.error("Image type 'report' not found") + return + + # 获取文件名 + report_filename = os.path.basename(report_file_path) + + # 检查是否已经保存过 + existing = Image.query.filter_by( + task_id=task_id, + stored_filename=report_filename, + image_types_id=report_type.image_types_id + ).first() + + if existing: + logger.info(f"Report image {report_filename} already exists, skipping") + return + + # 读取图片尺寸 + try: + with PILImage.open(report_file_path) as img: + width, height = img.size + except: + width, height = None, None + + # 保存到数据库 (report不需要father_id) + report_image = Image( + task_id=task_id, + image_types_id=report_type.image_types_id, + father_id=None, + stored_filename=report_filename, + file_path=report_file_path, + file_size=os.path.getsize(report_file_path), + width=width, + height=height + ) + + db.session.add(report_image) + db.session.commit() + + logger.info(f"Saved report image: {report_filename}") + + except Exception as e: + logger.error(f"Error saving report image: {str(e)}") + db.session.rollback() diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index eab354b..3eaf6da 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -1,517 +1,496 @@ -""" -RQ Worker 微调任务处理器 -在后台执行模型微调任务 -""" - -import os -import subprocess -import logging -from datetime import datetime - -logging.basicConfig( - 
level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - - -def _check_and_update_finetune_status(finetune_task): - """ - 检查微调任务状态并更新 - 当原始和扰动图片的微调都完成时,更新任务状态为completed - - Args: - finetune_task: FinetuneBatch对象 - """ - from app import db - from rq.job import Job - from redis import Redis - from config.algorithm_config import AlgorithmConfig - - try: - # 刷新数据库对象,确保获取最新状态 - db.session.refresh(finetune_task) - - # 如果状态已经是completed或failed,不再检查 - if finetune_task.status in ['completed', 'failed']: - return - - redis_conn = Redis.from_url(AlgorithmConfig.REDIS_URL) - - original_job_done = False - perturbed_job_done = False - has_original_job = False - has_perturbed_job = False - - # 检查原始图片微调任务 - if finetune_task.original_job_id: - has_original_job = True - try: - original_job = Job.fetch(finetune_task.original_job_id, connection=redis_conn) - status = original_job.get_status() - logger.info(f"Original job {finetune_task.original_job_id} status: {status}") - if status == 'finished': - original_job_done = True - elif status == 'failed': - # 如果原始任务失败,整个微调任务标记为失败 - finetune_task.status = 'failed' - finetune_task.error_message = f"Original finetune job failed: {original_job.exc_info}" - finetune_task.completed_at = datetime.utcnow() - db.session.commit() - logger.error(f"FinetuneBatch {finetune_task.id} failed: original job failed") - return - except Exception as e: - logger.error(f"Error checking original job: {str(e)}") - - # 检查扰动图片微调任务 - if finetune_task.perturbed_job_id: - has_perturbed_job = True - try: - perturbed_job = Job.fetch(finetune_task.perturbed_job_id, connection=redis_conn) - status = perturbed_job.get_status() - logger.info(f"Perturbed job {finetune_task.perturbed_job_id} status: {status}") - if status == 'finished': - perturbed_job_done = True - elif status == 'failed': - # 如果扰动任务失败,整个微调任务标记为失败 - finetune_task.status = 'failed' - finetune_task.error_message = f"Perturbed finetune job 
failed: {perturbed_job.exc_info}" - finetune_task.completed_at = datetime.utcnow() - db.session.commit() - logger.error(f"FinetuneBatch {finetune_task.id} failed: perturbed job failed") - return - except Exception as e: - logger.error(f"Error checking perturbed job: {str(e)}") - - # 如果两个任务都完成,更新状态为completed - if has_original_job and has_perturbed_job and original_job_done and perturbed_job_done: - finetune_task.status = 'completed' - finetune_task.completed_at = datetime.utcnow() - db.session.commit() - logger.info(f"FinetuneBatch {finetune_task.id} completed - both jobs finished") - else: - logger.info(f"FinetuneBatch {finetune_task.id} not all jobs finished yet: original={original_job_done}, perturbed={perturbed_job_done}") - - except Exception as e: - logger.error(f"Error checking finetune status: {str(e)}", exc_info=True) - - -def run_finetune_task(finetune_batch_id, batch_id, finetune_method, train_images_dir, output_model_dir, - class_dir, coords_save_path, validation_output_dir, inference_prompts, - is_perturbed=False, custom_params=None): - """ - 执行微调任务 - - Args: - finetune_batch_id: 微调任务ID - batch_id: 扰动任务批次ID - finetune_method: 微调方法 (dreambooth, lora, textual_inversion) - train_images_dir: 训练图片目录(原始或扰动) - output_model_dir: 模型输出目录 - class_dir: 类别图片目录(用于 prior preservation) - coords_save_path: 坐标保存路径(用于训练轨迹可视化) - validation_output_dir: 验证图片输出目录 - inference_prompts: 推理提示词 - is_perturbed: 是否是扰动图片训练 - custom_params: 自定义参数 - - Returns: - 任务执行结果 - """ - from config.algorithm_config import AlgorithmConfig - from app import create_app, db - from app.database import FinetuneBatch, Batch, Image, ImageType - - app = create_app() - - with app.app_context(): - try: - finetune_task = FinetuneBatch.query.get(finetune_batch_id) - if not finetune_task: - raise ValueError(f"FinetuneBatch {finetune_batch_id} not found") - - batch = Batch.query.get(batch_id) - if not batch: - raise ValueError(f"Batch {batch_id} not found") - - # 更新微调任务状态为处理中 - if finetune_task.status == 
'queued': - finetune_task.status = 'processing' - db.session.commit() - - logger.info(f"Starting finetune task for FinetuneBatch {finetune_batch_id}, Batch {batch_id}") - logger.info(f"Method: {finetune_method}, Perturbed: {is_perturbed}") - - # 确保目录存在 - os.makedirs(output_model_dir, exist_ok=True) - os.makedirs(class_dir, exist_ok=True) - os.makedirs(coords_save_path, exist_ok=True) - os.makedirs(validation_output_dir, exist_ok=True) - - # 获取配置 - use_real = AlgorithmConfig.USE_REAL_ALGORITHMS - - if use_real: - # 使用真实微调算法 - result = _run_real_finetune( - finetune_method, batch_id, train_images_dir, output_model_dir, - class_dir, coords_save_path, validation_output_dir, - inference_prompts, is_perturbed, custom_params - ) - else: - # 使用虚拟微调实现 - result = _run_virtual_finetune( - finetune_method, batch_id, train_images_dir, output_model_dir, - is_perturbed - ) - - # 保存生成的图片到数据库 - _save_generated_images(batch_id, output_model_dir, is_perturbed) - - # 检查两个任务是否都已完成 - _check_and_update_finetune_status(finetune_task) - - logger.info(f"Finetune task completed for FinetuneBatch {finetune_batch_id}") - return result - - except Exception as e: - logger.error(f"Finetune task failed for FinetuneBatch {finetune_batch_id}: {str(e)}", exc_info=True) - # 更新微调任务状态为失败 - if finetune_task: - finetune_task.status = 'failed' - finetune_task.error_message = str(e) - finetune_task.completed_at = datetime.utcnow() - db.session.commit() - raise - - -def _run_real_finetune(finetune_method, batch_id, train_images_dir, output_model_dir, - class_dir, coords_save_path, validation_output_dir, - inference_prompts, is_perturbed, custom_params): - """运行真实微调算法""" - from config.algorithm_config import AlgorithmConfig - - logger.info(f"Running real finetune: {finetune_method}") - - # 获取微调脚本路径和环境 - finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {}) - script_path = finetune_config.get('real_script') - conda_env = finetune_config.get('conda_env') - default_params = 
finetune_config.get('default_params', {}) - - if not script_path: - raise ValueError(f"Finetune method {finetune_method} not configured") - - # 合并参数 - params = {**default_params, **(custom_params or {})} - - cmd_args = [ - f"--instance_data_dir={train_images_dir}", - f"--output_dir={output_model_dir}", - f"--validation_image_output_dir={validation_output_dir}", - ] - - if finetune_method == 'dreambooth': - # DreamBooth 特有参数 - cmd_args.extend([ - f"--class_data_dir={class_dir}", - f"--coords_save_path={coords_save_path}", - ]) - - elif finetune_method == 'lora': - # LoRA 特有参数 (positions_save_path 等同于 coords_save_path) - cmd_args.extend([ - f"--class_data_dir={class_dir}", - f"--positions_save_path={coords_save_path}", - ]) - - elif finetune_method == 'textual_inversion': - # Textual Inversion 特有参数 (不需要 class_data_dir) - cmd_args.extend([ - f"--coords_save_path={coords_save_path}", - ]) - else: - raise ValueError(f"Unsupported finetune method: {finetune_method}") - - # 添加is_perturbed标志 - if is_perturbed: - cmd_args.append("--is_perturbed") - - # 添加其他默认参数 - for key, value in params.items(): - if isinstance(value, bool): - if value: - cmd_args.append(f"--{key}") - else: - cmd_args.append(f"--{key}={value}") - - # 构建完整命令 - cmd = [ - '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', - 'accelerate', 'launch', script_path - ] + cmd_args - - logger.info(f"Executing command: {' '.join(cmd)}") - - # 设置日志文件 - log_dir = AlgorithmConfig.LOGS_DIR - os.makedirs(log_dir, exist_ok=True) - image_type = 'perturbed' if is_perturbed else 'original' - log_file = os.path.join( - log_dir, - f'finetune_{image_type}_{batch_id}_{finetune_method}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' - ) - - # 执行命令 - with open(log_file, 'w') as f: - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - bufsize=1, - universal_newlines=True - ) - - for line in process.stdout: - f.write(line) - f.flush() - 
logger.info(line.strip()) - - process.wait() - - if process.returncode != 0: - raise RuntimeError(f"Finetune failed with code {process.returncode}. Check log: {log_file}") - - # 清理class_dir - logger.info(f"Cleaning class directory: {class_dir}") - if os.path.exists(class_dir): - import shutil - for item in os.listdir(class_dir): - item_path = os.path.join(class_dir, item) - if os.path.isfile(item_path): - os.remove(item_path) - elif os.path.isdir(item_path): - shutil.rmtree(item_path) - - # 清理output_model_dir中的非图片文件 - logger.info(f"Cleaning non-image files in output directory: {output_model_dir}") - if os.path.exists(output_model_dir): - import shutil - image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'} - for item in os.listdir(output_model_dir): - item_path = os.path.join(output_model_dir, item) - # 如果是目录,直接删除 - if os.path.isdir(item_path): - logger.info(f"Removing directory: {item_path}") - shutil.rmtree(item_path) - # 如果是文件,检查是否为图片 - elif os.path.isfile(item_path): - _, ext = os.path.splitext(item.lower()) - if ext not in image_extensions: - logger.info(f"Removing non-image file: {item_path}") - os.remove(item_path) - - return { - 'status': 'success', - 'output_dir': output_model_dir, - 'log_file': log_file - } - - -def _run_virtual_finetune(finetune_method, batch_id, train_images_dir, output_model_dir, is_perturbed): - """运行虚拟微调实现""" - from config.algorithm_config import AlgorithmConfig - import glob - - logger.info(f"Running virtual finetune: {finetune_method}") - - # 获取微调配置 - finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {}) - if not finetune_config: - raise ValueError(f"Finetune method {finetune_method} not configured") - - conda_env = finetune_config.get('conda_env') - default_params = finetune_config.get('default_params', {}) - - # 获取虚拟微调脚本路径 - script_name = 'train_dreambooth_gen.py' if finetune_method == 'dreambooth' else 'train_lora_gen.py' - script_path = os.path.abspath(os.path.join( - 
os.path.dirname(__file__), - '../algorithms/finetune_virtual', - script_name - )) - - if not os.path.exists(script_path): - raise FileNotFoundError(f"Virtual finetune script not found: {script_path}") - - logger.info(f"Virtual script path: {script_path}") - logger.info(f"Conda environment: {conda_env}") - - # 创建输出目录 - os.makedirs(output_model_dir, exist_ok=True) - validation_output_dir = os.path.join(output_model_dir, 'generated') - os.makedirs(validation_output_dir, exist_ok=True) - - # 构建命令行参数(与真实微调参数一致) - cmd_args = [ - f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}", - f"--instance_data_dir={train_images_dir}", - f"--output_dir={output_model_dir}", - f"--validation_image_output_dir={validation_output_dir}", - f"--class_data_dir=/tmp/class_placeholder", - ] - - # 添加is_perturbed标志 - if is_perturbed: - cmd_args.append("--is_perturbed") - - # 添加其他默认参数 - for key, value in default_params.items(): - if key == 'pretrained_model_name_or_path': - continue # 已添加 - if isinstance(value, bool): - if value: - cmd_args.append(f"--{key}") - else: - cmd_args.append(f"--{key}={value}") - - # 使用conda run执行虚拟脚本 - cmd = [ - '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', - 'python', script_path - ] + cmd_args - - logger.info(f"Executing command: {' '.join(cmd)}") - - # 设置日志文件 - log_dir = AlgorithmConfig.LOGS_DIR - os.makedirs(log_dir, exist_ok=True) - image_type = 'perturbed' if is_perturbed else 'original' - log_file = os.path.join( - log_dir, - f'virtual_{finetune_method}_{image_type}_{batch_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' - ) - - # 执行命令 - with open(log_file, 'w') as f: - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - bufsize=1, - universal_newlines=True - ) - - for line in process.stdout: - f.write(line) - f.flush() - logger.info(line.strip()) - - process.wait() - - if process.returncode != 0: - raise 
RuntimeError(f"Virtual finetune failed with code {process.returncode}. Check log: {log_file}") - - # 统计生成的图片 - image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] - generated_files = [] - for ext in image_extensions: - generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext))) - generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext.upper()))) - - logger.info(f"Virtual finetune completed. Generated {len(generated_files)} images") - - return { - 'status': 'success', - 'output_dir': output_model_dir, - 'generated_count': len(generated_files), - 'generated_files': generated_files, - 'log_file': log_file - } - - -def _save_generated_images(batch_id, output_model_dir, is_perturbed): - """保存生成的图片到数据库""" - from app import db - from app.database import Batch, Image, ImageType - import glob - - try: - batch = Batch.query.get(batch_id) - if not batch: - return - - # 确定图片类型 - if is_perturbed: - image_type = ImageType.query.filter_by(type_code='perturbed_generate').first() - else: - image_type = ImageType.query.filter_by(type_code='original_generate').first() - - if not image_type: - logger.error(f"Image type not found for is_perturbed={is_perturbed}") - return - - # 查找生成的图片 - generated_dir = os.path.join(output_model_dir, 'generated') - if not os.path.exists(generated_dir): - # 尝试直接从output_model_dir查找 - generated_dir = output_model_dir - - image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] - image_files = [] - for ext in image_extensions: - image_files.extend(glob.glob(os.path.join(generated_dir, ext))) - image_files.extend(glob.glob(os.path.join(generated_dir, ext.upper()))) - - logger.info(f"Found {len(image_files)} generated images to save") - - # 保存到数据库 - saved_count = 0 - for image_path in image_files: - try: - from PIL import Image as PILImage - filename = os.path.basename(image_path) - - # 检查是否已经保存过(使用filename作为stored_filename) - existing = Image.query.filter_by( - batch_id=batch_id, - 
stored_filename=filename - ).first() - - if existing: - logger.info(f"Image already exists: {filename}") - continue - - with PILImage.open(image_path) as img: - width, height = img.size - - # 生成图片不设置父图片关系(多对多关系,无法确定具体父图片) - # 创建图片记录(直接使用filename,算法已经生成了正确格式) - generated_image = Image( - user_id=batch.user_id, - batch_id=batch_id, - father_id=None, # 微调生成图片无特定父图片 - original_filename=filename, - stored_filename=filename, # 算法输出已经是正确格式 - file_path=image_path, - file_size=os.path.getsize(image_path), - image_type_id=image_type.id, - width=width, - height=height - ) - - db.session.add(generated_image) - saved_count += 1 - logger.info(f"Saved generated image: {filename}") - - except Exception as e: - logger.error(f"Failed to save {image_path}: {str(e)}") - - db.session.commit() - logger.info(f"Successfully saved {saved_count} generated images to database") - - except Exception as e: - logger.error(f"Error saving generated images: {str(e)}") - db.session.rollback() +""" +RQ Worker 微调任务处理器 - 适配新数据库结构 +支持两种微调模式: +1. 基于加噪任务的微调 (共享task_id) +2. 
直接上传图片的微调 (独立task_id) +""" + +import os +import subprocess +import logging +import glob +from datetime import datetime +from PIL import Image as PILImage + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images_dir, + output_model_dir, class_dir, coords_save_path, validation_output_dir, + is_perturbed=False, has_perturbation_task=False, custom_params=None): + """ + 执行微调任务 + + Args: + task_id: 任务ID + finetune_config_id: 微调配置ID + finetune_method: 微调方法 (dreambooth, lora, textual_inversion) + train_images_dir: 训练图片目录 + output_model_dir: 模型输出目录 + class_dir: 类别图片目录 + coords_save_path: 坐标保存路径 + validation_output_dir: 验证图片输出目录 + is_perturbed: 是否使用扰动图片训练 + has_perturbation_task: 是否基于加噪任务(True表示共享task_id, False表示独立任务) + custom_params: 自定义参数 + + Returns: + 任务执行结果 + """ + from config.algorithm_config import AlgorithmConfig + from app import create_app, db + from app.database import Task, Finetune, DataType, Perturbation, TaskStatus + + app = create_app() + + with app.app_context(): + try: + # 获取任务 + task = Task.query.get(task_id) + if not task: + raise ValueError(f"Task {task_id} not found") + + # 获取微调任务详情 + finetune = Finetune.query.filter_by( + tasks_id=task_id, + finetune_configs_id=finetune_config_id + ).first() + + if not finetune: + raise ValueError(f"Finetune task ({task_id}, {finetune_config_id}) not found") + + # 更新任务状态为处理中 + processing_status = TaskStatus.query.filter_by(task_status_code='processing').first() + if processing_status: + task.tasks_status_id = processing_status.task_status_id + if not task.started_at: + task.started_at = datetime.utcnow() + db.session.commit() + + logger.info(f"Starting finetune task {task_id} (config: {finetune_config_id})") + logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}") + + # 获取Prompt文本 + # 
优先从Finetune的data_type获取,如果没有则尝试从关联的Perturbation获取 + inference_prompts = "a photo of sks person" # 默认值 + + if finetune.data_type_id: + # 从微调任务的数据集类型获取 + data_type = DataType.query.get(finetune.data_type_id) + if data_type and data_type.data_type_prompt: + inference_prompts = data_type.data_type_prompt + logger.info(f"Using prompt from Finetune.data_type: {inference_prompts}") + elif has_perturbation_task: + # 如果是基于加噪任务,尝试从加噪任务获取 + perturbation = Perturbation.query.filter_by(tasks_id=task_id).first() + if perturbation and perturbation.data_type_id: + data_type = DataType.query.get(perturbation.data_type_id) + if data_type and data_type.data_type_prompt: + inference_prompts = data_type.data_type_prompt + logger.info(f"Using prompt from Perturbation.data_type: {inference_prompts}") + + # 获取配置 + use_real = AlgorithmConfig.USE_REAL_ALGORITHMS + + if use_real: + # 使用真实微调算法 + result = _run_real_finetune( + finetune_method, task_id, train_images_dir, output_model_dir, + class_dir, coords_save_path, validation_output_dir, + inference_prompts, is_perturbed, custom_params + ) + else: + # 使用虚拟实现 + result = _run_virtual_finetune( + finetune_method, task_id, train_images_dir, output_model_dir, is_perturbed + ) + + # 保存生成的验证图片到数据库 + _save_generated_images(task_id, validation_output_dir, is_perturbed, has_perturbation_task) + + # 更新任务状态为完成 + completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() + if completed_status: + task.tasks_status_id = completed_status.task_status_id + task.finished_at = datetime.utcnow() + db.session.commit() + + logger.info(f"Finetune task {task_id} completed successfully") + return result + + except Exception as e: + logger.error(f"Finetune task {task_id} failed: {str(e)}", exc_info=True) + + # 更新任务状态为失败 + failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() + if failed_status: + task.tasks_status_id = failed_status.task_status_id + task.finished_at = datetime.utcnow() + task.error_message = str(e) + 
db.session.commit() + + raise + + +def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_dir, + class_dir, coords_save_path, validation_output_dir, + inference_prompts, is_perturbed, custom_params): + """运行真实微调算法""" + from config.algorithm_config import AlgorithmConfig + + logger.info(f"Running real finetune: {finetune_method}") + + # 获取微调脚本路径和环境 + finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {}) + script_path = finetune_config.get('real_script') + conda_env = finetune_config.get('conda_env') + default_params = finetune_config.get('default_params', {}) + + if not script_path: + raise ValueError(f"Finetune method {finetune_method} not configured") + + # 合并参数 + params = {**default_params, **(custom_params or {})} + + cmd_args = [ + f"--instance_data_dir={train_images_dir}", + f"--output_dir={output_model_dir}", + f"--validation_image_output_dir={validation_output_dir}", + ] + + if finetune_method == 'dreambooth': + # DreamBooth 特有参数 + cmd_args.extend([ + f"--class_data_dir={class_dir}", + f"--coords_save_path={coords_save_path}", + ]) + + elif finetune_method == 'lora': + # LoRA 特有参数 (positions_save_path 等同于 coords_save_path) + cmd_args.extend([ + f"--class_data_dir={class_dir}", + f"--positions_save_path={coords_save_path}", + ]) + + elif finetune_method == 'textual_inversion': + # Textual Inversion 特有参数 (不需要 class_data_dir) + cmd_args.extend([ + f"--coords_save_path={coords_save_path}", + ]) + else: + raise ValueError(f"Unsupported finetune method: {finetune_method}") + + # 添加is_perturbed标志 + if is_perturbed: + cmd_args.append("--is_perturbed") + + # 添加其他默认参数 + for key, value in params.items(): + if isinstance(value, bool): + if value: + cmd_args.append(f"--{key}") + else: + cmd_args.append(f"--{key}={value}") + + # 构建完整命令 + cmd = [ + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', + 'accelerate', 'launch', script_path + ] + cmd_args + + logger.info(f"Executing command: {' 
'.join(cmd)}") + + # 设置日志文件 + log_dir = AlgorithmConfig.LOGS_DIR + os.makedirs(log_dir, exist_ok=True) + image_type = 'perturbed' if is_perturbed else 'original' + log_file = os.path.join( + log_dir, + f'finetune_{image_type}_{task_id}_{finetune_method}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' + ) + + # 执行命令 + with open(log_file, 'w') as f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + for line in process.stdout: + f.write(line) + f.flush() + logger.info(line.strip()) + + process.wait() + + if process.returncode != 0: + raise RuntimeError(f"Finetune failed with code {process.returncode}. Check log: {log_file}") + + # 清理class_dir + logger.info(f"Cleaning class directory: {class_dir}") + if os.path.exists(class_dir): + import shutil + shutil.rmtree(class_dir) + os.makedirs(class_dir) + + # 清理output_model_dir中的非图片文件 + logger.info(f"Cleaning non-image files in output directory: {output_model_dir}") + if os.path.exists(output_model_dir): + import shutil + image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'} + for item in os.listdir(output_model_dir): + item_path = os.path.join(output_model_dir, item) + if os.path.isfile(item_path): + _, ext = os.path.splitext(item) + if ext.lower() not in image_extensions: + try: + os.remove(item_path) + logger.info(f"Removed non-image file: {item}") + except Exception as e: + logger.warning(f"Failed to remove {item}: {str(e)}") + + return { + 'status': 'success', + 'output_dir': output_model_dir, + 'log_file': log_file + } + + +def _run_virtual_finetune(finetune_method, task_id, train_images_dir, output_model_dir, is_perturbed): + """运行虚拟微调实现""" + from config.algorithm_config import AlgorithmConfig + import shutil + + logger.info(f"Running virtual finetune: {finetune_method}") + + # 获取微调配置 + finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {}) + if not finetune_config: + raise 
ValueError(f"Finetune method {finetune_method} not configured") + + conda_env = finetune_config.get('conda_env') + default_params = finetune_config.get('default_params', {}) + + # 获取虚拟微调脚本路径 + script_name = 'train_dreambooth_gen.py' if finetune_method == 'dreambooth' else 'train_lora_gen.py' + script_path = os.path.abspath(os.path.join( + os.path.dirname(__file__), + '../algorithms/finetune_virtual', + script_name + )) + + if not os.path.exists(script_path): + raise FileNotFoundError(f"Virtual finetune script not found: {script_path}") + + logger.info(f"Virtual script path: {script_path}") + logger.info(f"Conda environment: {conda_env}") + + # 创建输出目录 + os.makedirs(output_model_dir, exist_ok=True) + validation_output_dir = os.path.join(output_model_dir, 'generated') + os.makedirs(validation_output_dir, exist_ok=True) + + # 构建命令行参数 + cmd_args = [ + f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}", + f"--instance_data_dir={train_images_dir}", + f"--output_dir={output_model_dir}", + f"--validation_image_output_dir={validation_output_dir}", + f"--class_data_dir=/tmp/class_placeholder", + ] + + # 添加is_perturbed标志 + if is_perturbed: + cmd_args.append("--is_perturbed") + + # 添加其他默认参数 + for key, value in default_params.items(): + if key == 'pretrained_model_name_or_path': + continue + if isinstance(value, bool): + if value: + cmd_args.append(f"--{key}") + else: + cmd_args.append(f"--{key}={value}") + + # 使用conda run执行虚拟脚本 + cmd = [ + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', + 'python', script_path + ] + cmd_args + + logger.info(f"Executing command: {' '.join(cmd)}") + + # 设置日志文件 + log_dir = AlgorithmConfig.LOGS_DIR + os.makedirs(log_dir, exist_ok=True) + image_type = 'perturbed' if is_perturbed else 'original' + log_file = os.path.join( + log_dir, + f'virtual_{finetune_method}_{image_type}_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' + ) + + # 执行命令 + with open(log_file, 
'w') as f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + for line in process.stdout: + f.write(line) + f.flush() + logger.info(line.strip()) + + process.wait() + + if process.returncode != 0: + raise RuntimeError(f"Virtual finetune failed with code {process.returncode}. Check log: {log_file}") + + # 统计生成的图片 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + generated_files = [] + for ext in image_extensions: + generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext))) + + logger.info(f"Virtual finetune completed. Generated {len(generated_files)} images") + + return { + 'status': 'success', + 'output_dir': output_model_dir, + 'generated_count': len(generated_files), + 'generated_files': generated_files, + 'log_file': log_file + } + + +def _save_generated_images(task_id, output_dir, is_perturbed, has_perturbation_task): + """ + 保存微调生成的验证图片到数据库 + + Args: + task_id: 任务ID + output_dir: 生成图片输出目录 + is_perturbed: 是否为扰动图片训练生成 + has_perturbation_task: 是否基于加噪任务 + """ + from app import db + from app.database import Task, Image, ImageType + + try: + task = Task.query.get(task_id) + if not task: + raise ValueError(f"Task {task_id} not found") + + # 获取生成图片类型 + if is_perturbed: + generated_type = ImageType.query.filter_by(image_code='perturbed_generate').first() + else: + generated_type = ImageType.query.filter_by(image_code='original_generate').first() + + if not generated_type: + raise ValueError(f"Image type '{'perturbed' if is_perturbed else 'original'}_generate' not found") + + # 如果基于加噪任务,获取对应的训练图片以建立father关系 + father_images_map = {} + if has_perturbation_task: + if is_perturbed: + # 扰动图片生成,父图片是扰动图 + perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() + father_images = Image.query.filter_by( + task_id=task_id, + image_types_id=perturbed_type.image_types_id + ).all() + else: + # 原始图片生成,父图片是原始图 + original_type = 
ImageType.query.filter_by(image_code='original').first() + father_images = Image.query.filter_by( + task_id=task_id, + image_types_id=original_type.image_types_id + ).all() + + # 创建映射: stored_filename -> Image对象 + father_images_map = {img.stored_filename: img for img in father_images} + + # 查找输出目录中的生成图片 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + generated_files = [] + for ext in image_extensions: + generated_files.extend(glob.glob(os.path.join(output_dir, ext))) + generated_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) + + logger.info(f"Found {len(generated_files)} generated images to save") + + saved_count = 0 + for generated_path in generated_files: + try: + # 获取文件名 + generated_filename = os.path.basename(generated_path) + + # 检查是否已经保存过 + existing = Image.query.filter_by( + task_id=task_id, + stored_filename=generated_filename, + image_types_id=generated_type.image_types_id + ).first() + + if existing: + logger.info(f"Generated image {generated_filename} already exists, skipping") + continue + + # 尝试根据文件名找到对应的父图片 + # 假设生成图片命名包含原始训练图片的名称 + father_id = None + if has_perturbation_task and father_images_map: + # 简单的匹配策略:查找文件名是否包含某个训练图片名 + for train_filename, train_image in father_images_map.items(): + base_name = os.path.splitext(train_filename)[0] + if base_name in generated_filename: + father_id = train_image.images_id + break + + # 读取图片尺寸 + try: + with PILImage.open(generated_path) as img: + width, height = img.size + except: + width, height = None, None + + # 保存到数据库 + generated_image = Image( + task_id=task_id, + image_types_id=generated_type.image_types_id, + father_id=father_id, # 如果找到对应的训练图片则设置father_id + stored_filename=generated_filename, + file_path=generated_path, + file_size=os.path.getsize(generated_path), + width=width, + height=height + ) + + db.session.add(generated_image) + saved_count += 1 + logger.info(f"Saved generated image: {generated_filename} (father: {father_id})") + + except Exception as e: + 
logger.error(f"Error saving generated image {generated_filename}: {str(e)}") + continue + + db.session.commit() + logger.info(f"Successfully saved {saved_count} generated images to database") + + except Exception as e: + logger.error(f"Error saving generated images: {str(e)}") + db.session.rollback() diff --git a/src/backend/app/workers/heatmap_worker.py b/src/backend/app/workers/heatmap_worker.py index d379eb5..4e00e50 100644 --- a/src/backend/app/workers/heatmap_worker.py +++ b/src/backend/app/workers/heatmap_worker.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) def run_heatmap_task(heatmap_id, task_id, original_image_path, perturbed_image_path, - prompt_text, target_word, output_dir, model_path): + prompt_text, target_word, output_dir, model_path, original_image_id=None): """ 执行热力图生成任务 @@ -29,6 +29,7 @@ def run_heatmap_task(heatmap_id, task_id, original_image_path, perturbed_image_p target_word: 目标关键词(如 "person") output_dir: 输出目录 model_path: Stable Diffusion模型路径 + original_image_id: 原始图片ID (用于建立father关系) Returns: 任务执行结果 @@ -73,10 +74,12 @@ def run_heatmap_task(heatmap_id, task_id, original_image_path, perturbed_image_p # 使用虚拟实现(生成占位符图片) result = _run_virtual_heatmap(output_dir) - # 保存热力图文件路径到数据库 + # 保存热力图文件到数据库 heatmap_file = os.path.join(output_dir, 'heatmap_dif.png') if os.path.exists(heatmap_file): heatmap_task.heatmap_name = 'heatmap_dif.png' + # 保存热力图到Image表 + _save_heatmap_image(task_id, heatmap_file, original_image_id) db.session.commit() # 更新任务状态为完成 @@ -196,3 +199,66 @@ def _run_virtual_heatmap(output_dir): 'output_dir': output_dir, 'virtual': True } + + +def _save_heatmap_image(task_id, heatmap_file_path, father_image_id=None): + """ + 保存热力图到数据库Image表 + + Args: + task_id: 任务ID + heatmap_file_path: 热力图文件完整路径 + father_image_id: 父图片ID(原始图片ID) + """ + from app import db + from app.database import Image, ImageType + from PIL import Image as PILImage + + try: + # 获取热力图图片类型 + heatmap_type = ImageType.query.filter_by(image_code='heatmap').first() + 
if not heatmap_type: + logger.error("Image type 'heatmap' not found") + return + + # 获取文件名 + heatmap_filename = os.path.basename(heatmap_file_path) + + # 检查是否已经保存过 + existing = Image.query.filter_by( + task_id=task_id, + stored_filename=heatmap_filename, + image_types_id=heatmap_type.image_types_id + ).first() + + if existing: + logger.info(f"Heatmap image {heatmap_filename} already exists, skipping") + return + + # 读取图片尺寸 + try: + with PILImage.open(heatmap_file_path) as img: + width, height = img.size + except: + width, height = None, None + + # 保存到数据库 + heatmap_image = Image( + task_id=task_id, + image_types_id=heatmap_type.image_types_id, + father_id=father_image_id, # 设置父图片关系 + stored_filename=heatmap_filename, + file_path=heatmap_file_path, + file_size=os.path.getsize(heatmap_file_path), + width=width, + height=height + ) + + db.session.add(heatmap_image) + db.session.commit() + + logger.info(f"Saved heatmap image: {heatmap_filename} (father: {father_image_id})") + + except Exception as e: + logger.error(f"Error saving heatmap image: {str(e)}") + db.session.rollback() diff --git a/src/backend/app/workers/perturbation_worker.py b/src/backend/app/workers/perturbation_worker.py index 774f80c..4c29c3a 100644 --- a/src/backend/app/workers/perturbation_worker.py +++ b/src/backend/app/workers/perturbation_worker.py @@ -1,422 +1,440 @@ -""" -RQ Worker任务处理器 -在后台执行对抗性扰动算法 -""" - -import os -import sys -import subprocess -import logging -from datetime import datetime -from pathlib import Path - -# 设置日志 -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - - -def run_perturbation_task(batch_id, algorithm_code, epsilon, use_strong_protection, - input_dir, output_dir, class_dir, custom_params=None): - """ - 执行对抗性扰动任务 - - Args: - batch_id: 任务批次ID - algorithm_code: 算法代码 - epsilon: 扰动强度 - use_strong_protection: 是否使用防净化版本 - input_dir: 输入图片目录 - output_dir: 输出目录 - class_dir: 类别图片目录 - 
custom_params: 自定义参数 - - Returns: - 任务执行结果 - """ - from config.algorithm_config import AlgorithmConfig - from app import create_app, db - from app.database import Batch - - # 创建应用上下文 - app = create_app() - - with app.app_context(): - try: - # 更新任务状态 - batch = Batch.query.get(batch_id) - if not batch: - raise ValueError(f"Batch {batch_id} not found") - - batch.status = 'processing' - batch.started_at = datetime.utcnow() - db.session.commit() - - logger.info(f"Starting perturbation task for batch {batch_id}") - logger.info(f"Algorithm: {algorithm_code}, Epsilon: {epsilon}") - - # 获取算法配置 - use_real = AlgorithmConfig.USE_REAL_ALGORITHMS - script_path = AlgorithmConfig.get_script_path(algorithm_code) - conda_env = AlgorithmConfig.get_conda_env(algorithm_code) - - # 确保目录存在 - os.makedirs(output_dir, exist_ok=True) - os.makedirs(class_dir, exist_ok=True) - - if use_real: - # 使用真实算法 - result = _run_real_algorithm( - script_path, conda_env, algorithm_code, batch_id, - epsilon, use_strong_protection, input_dir, output_dir, - class_dir, custom_params - ) - else: - # 使用虚拟实现 - result = _run_virtual_algorithm( - algorithm_code, batch_id, epsilon, use_strong_protection, - input_dir, output_dir - ) - - # 更新任务状态为完成 - batch.status = 'completed' - batch.completed_at = datetime.utcnow() - db.session.commit() - - # 保存扰动图片到数据库 - _save_perturbed_images(batch_id, output_dir) - - logger.info(f"Task completed successfully for batch {batch_id}") - return result - - except Exception as e: - logger.error(f"Task failed for batch {batch_id}: {str(e)}", exc_info=True) - - # 更新任务状态为失败 - if batch: - batch.status = 'failed' - batch.error_message = str(e) - batch.completed_at = datetime.utcnow() - db.session.commit() - - raise - - -def _run_real_algorithm(script_path, conda_env, algorithm_code, batch_id, - epsilon, use_strong_protection, input_dir, output_dir, - class_dir, custom_params): - """运行真实算法""" - from config.algorithm_config import AlgorithmConfig - - logger.info(f"Running real algorithm: 
{algorithm_code}") - logger.info(f"Conda environment: {conda_env}") - logger.info(f"Script path: {script_path}") - - # 获取默认参数 - default_params = AlgorithmConfig.get_default_params(algorithm_code) - - # 合并自定义参数 - params = {**default_params, **(custom_params or {})} - - cmd_args = [] - if algorithm_code == 'aspl': - cmd_args.extend([ - f"--instance_data_dir_for_train={input_dir}", - f"--instance_data_dir_for_adversarial={input_dir}", - f"--output_dir={output_dir}", - f"--class_data_dir={class_dir}", - f"--pgd_eps={str(epsilon)}", - ]) - elif algorithm_code == 'simac': - cmd_args.extend([ - f"--instance_data_dir_for_train={input_dir}", - f"--instance_data_dir_for_adversarial={input_dir}", - f"--output_dir={output_dir}", - f"--class_data_dir={class_dir}", - f"--pgd_eps={str(epsilon)}", - ]) - elif algorithm_code == 'caat': - cmd_args.extend([ - f"--instance_data_dir={input_dir}", - f"--output_dir={output_dir}", - f"--eps={str(epsilon)}", - ]) - elif algorithm_code == 'pid': - cmd_args.extend([ - f"--instance_data_dir={input_dir}", - f"--output_dir={output_dir}", - f"--eps={str(epsilon)}", - ]) - else: - raise ValueError(f"Unsupported algorithm code: {algorithm_code}") - - # 添加其他参数 - for key, value in params.items(): - if isinstance(value, bool): - if value: - cmd_args.append(f"--{key}") - else: - cmd_args.append(f"--{key}={value}") - - # 构建完整命令 - # 使用conda run避免环境嵌套问题 - cmd = [ - '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', - 'accelerate', 'launch', script_path - ] + cmd_args - - logger.info(f"Executing command: {' '.join(cmd)}") - - # 设置日志文件 - log_dir = AlgorithmConfig.LOGS_DIR - os.makedirs(log_dir, exist_ok=True) - log_file = os.path.join(log_dir, f'batch_{batch_id}_{algorithm_code}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log') - - # 执行命令 - with open(log_file, 'w') as f: - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - bufsize=1, - universal_newlines=True - ) - - # 实时输出日志 - 
for line in process.stdout: - f.write(line) - f.flush() - logger.info(line.strip()) - - process.wait() - - logger.info(f"output_dir: {output_dir}") - logger.info(f"log_file: {log_file}") - - if process.returncode != 0: - raise RuntimeError(f"Algorithm execution failed with code {process.returncode}. Check log: {log_file}") - - # 清理class_dir - logger.info(f"Cleaning class directory: {class_dir}") - if os.path.exists(class_dir): - import shutil - for item in os.listdir(class_dir): - item_path = os.path.join(class_dir, item) - if os.path.isfile(item_path): - os.remove(item_path) - elif os.path.isdir(item_path): - shutil.rmtree(item_path) - - return { - 'status': 'success', - 'output_dir': output_dir, - 'log_file': log_file - } - - -def _run_virtual_algorithm(algorithm_code, batch_id, epsilon, use_strong_protection, - input_dir, output_dir): - """运行虚拟算法实现""" - from config.algorithm_config import AlgorithmConfig - import glob - - logger.info(f"Running virtual algorithm: {algorithm_code}") - - # 获取算法配置 - algo_config = AlgorithmConfig.PERTURBATION_SCRIPTS.get(algorithm_code) - if not algo_config: - raise ValueError(f"Algorithm {algorithm_code} not configured") - - conda_env = algo_config.get('conda_env') - default_params = algo_config.get('default_params', {}) - - # 获取虚拟算法脚本路径 - script_path = os.path.abspath(os.path.join( - os.path.dirname(__file__), - '../algorithms/perturbation_virtual', - f'{algorithm_code}.py' - )) - - if not os.path.exists(script_path): - raise FileNotFoundError(f"Virtual script not found: {script_path}") - - logger.info(f"Virtual script path: {script_path}") - logger.info(f"Conda environment: {conda_env}") - - # 确保输出目录存在 - os.makedirs(output_dir, exist_ok=True) - - # 构建命令行参数(与真实算法参数一致) - cmd_args = [ - f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}", - f"--instance_data_dir_for_train={input_dir}", - f"--instance_data_dir_for_adversarial={input_dir}", - f"--output_dir={output_dir}", - 
f"--class_data_dir=/tmp/class_placeholder", - f"--pgd_eps={epsilon}", - ] - - # 添加其他默认参数 - for key, value in default_params.items(): - if key == 'pretrained_model_name_or_path': - continue # 已添加 - if isinstance(value, bool): - if value: - cmd_args.append(f"--{key}") - else: - cmd_args.append(f"--{key}={value}") - - # 使用conda run执行虚拟脚本 - cmd = [ - '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', - 'python', script_path - ] + cmd_args - - logger.info(f"Executing command: {' '.join(cmd)}") - - # 设置日志文件 - log_dir = AlgorithmConfig.LOGS_DIR - os.makedirs(log_dir, exist_ok=True) - log_file = os.path.join( - log_dir, - f'virtual_{algorithm_code}_{batch_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' - ) - - # 执行命令 - with open(log_file, 'w') as f: - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - bufsize=1, - universal_newlines=True - ) - - # 实时输出日志 - for line in process.stdout: - f.write(line) - f.flush() - logger.info(line.strip()) - - process.wait() - - if process.returncode != 0: - raise RuntimeError(f"Virtual algorithm failed with code {process.returncode}. Check log: {log_file}") - - # 统计处理的图片 - image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] - processed_files = [] - for ext in image_extensions: - processed_files.extend(glob.glob(os.path.join(output_dir, ext))) - processed_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) - - logger.info(f"Virtual algorithm completed. 
Processed {len(processed_files)} images") - - return { - 'status': 'success', - 'output_dir': output_dir, - 'processed_count': len(processed_files), - 'processed_files': processed_files, - 'log_file': log_file - } - - -def _save_perturbed_images(batch_id, output_dir): - """保存扰动图片到数据库""" - from app import db - from app.database import Batch, Image, ImageType - import glob - from PIL import Image as PILImage - - try: - batch = Batch.query.get(batch_id) - if not batch: - logger.error(f"Batch {batch_id} not found") - return - - # 获取扰动图片类型 - perturbed_type = ImageType.query.filter_by(type_code='perturbed').first() - if not perturbed_type: - logger.error("Perturbed image type not found") - return - - # 获取原始图片列表 - original_type = ImageType.query.filter_by(type_code='original').first() - original_images = Image.query.filter_by( - batch_id=batch_id, - image_type_id=original_type.id - ).all() - - # 创建原图映射字典: stored_filename -> Image对象 - original_map = {img.stored_filename: img for img in original_images} - - # 查找输出目录中的扰动图片 - image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] - perturbed_files = [] - for ext in image_extensions: - perturbed_files.extend(glob.glob(os.path.join(output_dir, ext))) - perturbed_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) - - logger.info(f"Found {len(perturbed_files)} perturbed images to save") - - saved_count = 0 - for perturbed_path in perturbed_files: - try: - filename = os.path.basename(perturbed_path) - - # 扰动图片命名格式: perturbed_{原图名}.ext - # 提取原图名 - parent_image = None - if filename.startswith('perturbed_'): - # 去掉perturbed_前缀,得到原图名 - original_filename = filename[len('perturbed_'):] - # 尝试从映射中查找 - parent_image = original_map.get(original_filename) - if not parent_image: - logger.warning(f"Parent image not found for {filename}, original should be: {original_filename}") - - # 获取图片尺寸 - with PILImage.open(perturbed_path) as img: - width, height = img.size - - # 检查是否已经保存过(使用filename作为stored_filename) - existing = 
Image.query.filter_by( - batch_id=batch_id, - stored_filename=filename - ).first() - - if existing: - logger.info(f"Image already exists: {filename}") - continue - - # 创建扰动图片记录(直接使用filename,因为算法已经添加了perturbed_前缀) - perturbed_image = Image( - user_id=batch.user_id, - batch_id=batch_id, - father_id=parent_image.id if parent_image else None, - original_filename=filename, - stored_filename=filename, # 算法输出已经是perturbed_格式 - file_path=perturbed_path, - file_size=os.path.getsize(perturbed_path), - image_type_id=perturbed_type.id, - width=width, - height=height - ) - - db.session.add(perturbed_image) - saved_count += 1 - logger.info(f"Saved perturbed image: {filename} (parent: {parent_image.stored_filename if parent_image else 'None'})") - - except Exception as e: - logger.error(f"Failed to save {perturbed_path}: {str(e)}") - - db.session.commit() - logger.info(f"Successfully saved {saved_count} perturbed images to database") - - except Exception as e: - logger.error(f"Error saving perturbed images: {str(e)}") - db.session.rollback() +""" +RQ Worker任务处理器 - 加噪任务 +适配新数据库结构: Task + Perturbation + Images +""" + +import os +import subprocess +import logging +import glob +from datetime import datetime +from pathlib import Path +from PIL import Image as PILImage + +# 设置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def run_perturbation_task(task_id, algorithm_code, epsilon, use_strong_protection, + input_dir, output_dir, class_dir, custom_params=None): + """ + 执行对抗性扰动任务 + + Args: + task_id: 任务ID(对应 tasks 表的 tasks_id) + algorithm_code: 算法代码 + epsilon: 扰动强度 + use_strong_protection: 是否使用防净化版本 + input_dir: 输入图片目录 + output_dir: 输出目录 + class_dir: 类别图片目录 + custom_params: 自定义参数 + + Returns: + 任务执行结果 + """ + from config.algorithm_config import AlgorithmConfig + from app import create_app, db + from app.database import Task, Perturbation, TaskStatus + + # 创建应用上下文 + app = 
create_app() + + with app.app_context(): + try: + # 获取任务 + task = Task.query.get(task_id) + if not task: + raise ValueError(f"Task {task_id} not found") + + # 获取加噪任务详情 + perturbation = Perturbation.query.get(task_id) + if not perturbation: + raise ValueError(f"Perturbation task {task_id} not found") + + # 更新任务状态为处理中 + processing_status = TaskStatus.query.filter_by(task_status_code='processing').first() + if processing_status: + task.tasks_status_id = processing_status.task_status_id + task.started_at = datetime.utcnow() + db.session.commit() + + logger.info(f"Starting perturbation task {task_id}") + logger.info(f"Algorithm: {algorithm_code}, Epsilon: {epsilon}") + + # 获取算法配置 + use_real = AlgorithmConfig.USE_REAL_ALGORITHMS + script_path = AlgorithmConfig.get_script_path(algorithm_code) + conda_env = AlgorithmConfig.get_conda_env(algorithm_code) + + # 确保目录存在 + os.makedirs(output_dir, exist_ok=True) + os.makedirs(class_dir, exist_ok=True) + + if use_real: + # 使用真实算法 + result = _run_real_algorithm( + script_path, conda_env, algorithm_code, task_id, + epsilon, use_strong_protection, input_dir, output_dir, + class_dir, custom_params + ) + else: + # 使用虚拟实现 + result = _run_virtual_algorithm( + algorithm_code, task_id, epsilon, use_strong_protection, + input_dir, output_dir + ) + + # 保存扰动图片到数据库 + _save_perturbed_images(task_id, output_dir) + + # 更新任务状态为完成 + completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() + if completed_status: + task.tasks_status_id = completed_status.task_status_id + task.finished_at = datetime.utcnow() + db.session.commit() + + logger.info(f"Perturbation task {task_id} completed successfully") + return result + + except Exception as e: + logger.error(f"Perturbation task {task_id} failed: {str(e)}", exc_info=True) + + # 更新任务状态为失败 + failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() + if failed_status: + task.tasks_status_id = failed_status.task_status_id + task.finished_at = datetime.utcnow() 
+ task.error_message = str(e) + db.session.commit() + + raise + + +def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id, + epsilon, use_strong_protection, input_dir, output_dir, + class_dir, custom_params): + """运行真实算法""" + from config.algorithm_config import AlgorithmConfig + + logger.info(f"Running real algorithm: {algorithm_code}") + logger.info(f"Conda environment: {conda_env}") + logger.info(f"Script path: {script_path}") + + # 获取默认参数 + default_params = AlgorithmConfig.get_default_params(algorithm_code) + + # 合并自定义参数 + params = {**default_params, **(custom_params or {})} + + cmd_args = [] + if algorithm_code == 'aspl': + cmd_args.extend([ + f"--instance_data_dir_for_train={input_dir}", + f"--instance_data_dir_for_adversarial={input_dir}", + f"--output_dir={output_dir}", + f"--class_data_dir={class_dir}", + f"--pgd_eps={str(epsilon)}", + ]) + elif algorithm_code == 'simac': + cmd_args.extend([ + f"--instance_data_dir_for_train={input_dir}", + f"--instance_data_dir_for_adversarial={input_dir}", + f"--output_dir={output_dir}", + f"--class_data_dir={class_dir}", + f"--pgd_eps={str(epsilon)}", + ]) + elif algorithm_code == 'caat': + cmd_args.extend([ + f"--instance_data_dir={input_dir}", + f"--output_dir={output_dir}", + f"--eps={str(epsilon)}", + ]) + elif algorithm_code == 'pid': + cmd_args.extend([ + f"--instance_data_dir={input_dir}", + f"--output_dir={output_dir}", + f"--eps={str(epsilon)}", + ]) + else: + raise ValueError(f"Unsupported algorithm code: {algorithm_code}") + + # 添加其他参数 + for key, value in params.items(): + if isinstance(value, bool): + if value: + cmd_args.append(f"--{key}") + else: + cmd_args.append(f"--{key}={value}") + + # 构建完整命令 + cmd = [ + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', + 'accelerate', 'launch', script_path + ] + cmd_args + + logger.info(f"Executing command: {' '.join(cmd)}") + + # 设置日志文件 + log_dir = AlgorithmConfig.LOGS_DIR + os.makedirs(log_dir, exist_ok=True) + log_file = 
os.path.join(log_dir, f'task_{task_id}_{algorithm_code}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log') + + # 执行命令 + with open(log_file, 'w') as f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + # 实时输出日志 + for line in process.stdout: + f.write(line) + f.flush() + logger.info(line.strip()) + + process.wait() + + logger.info(f"output_dir: {output_dir}") + logger.info(f"log_file: {log_file}") + + if process.returncode != 0: + raise RuntimeError(f"Algorithm execution failed with code {process.returncode}. Check log: {log_file}") + + # 清理class_dir + logger.info(f"Cleaning class directory: {class_dir}") + if os.path.exists(class_dir): + import shutil + shutil.rmtree(class_dir) + os.makedirs(class_dir) + + return { + 'status': 'success', + 'output_dir': output_dir, + 'log_file': log_file + } + + +def _run_virtual_algorithm(algorithm_code, task_id, epsilon, use_strong_protection, + input_dir, output_dir): + """运行虚拟算法实现""" + from config.algorithm_config import AlgorithmConfig + import shutil + + logger.info(f"Running virtual algorithm: {algorithm_code}") + + # 获取算法配置 + algo_config = AlgorithmConfig.PERTURBATION_SCRIPTS.get(algorithm_code) + if not algo_config: + raise ValueError(f"Algorithm {algorithm_code} not configured") + + conda_env = algo_config.get('conda_env') + default_params = algo_config.get('default_params', {}) + + # 获取虚拟算法脚本路径 + script_path = os.path.abspath(os.path.join( + os.path.dirname(__file__), + '../algorithms/perturbation_virtual', + f'{algorithm_code}.py' + )) + + if not os.path.exists(script_path): + raise FileNotFoundError(f"Virtual script not found: {script_path}") + + logger.info(f"Virtual script path: {script_path}") + logger.info(f"Conda environment: {conda_env}") + + # 确保输出目录存在 + os.makedirs(output_dir, exist_ok=True) + + # 构建命令行参数 + cmd_args = [ + f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 
'model_path')}", + f"--instance_data_dir_for_train={input_dir}", + f"--instance_data_dir_for_adversarial={input_dir}", + f"--output_dir={output_dir}", + f"--class_data_dir=/tmp/class_placeholder", + f"--pgd_eps={epsilon}", + ] + + # 添加其他默认参数 + for key, value in default_params.items(): + if key == 'pretrained_model_name_or_path': + continue + if isinstance(value, bool): + if value: + cmd_args.append(f"--{key}") + else: + cmd_args.append(f"--{key}={value}") + + # 使用conda run执行虚拟脚本 + cmd = [ + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', + 'python', script_path + ] + cmd_args + + logger.info(f"Executing command: {' '.join(cmd)}") + + # 设置日志文件 + log_dir = AlgorithmConfig.LOGS_DIR + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join( + log_dir, + f'virtual_{algorithm_code}_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' + ) + + # 执行命令 + with open(log_file, 'w') as f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + # 实时输出日志 + for line in process.stdout: + f.write(line) + f.flush() + logger.info(line.strip()) + + process.wait() + + if process.returncode != 0: + raise RuntimeError(f"Virtual algorithm failed with code {process.returncode}. Check log: {log_file}") + + # 统计处理的图片 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + processed_files = [] + for ext in image_extensions: + processed_files.extend(glob.glob(os.path.join(output_dir, ext))) + processed_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) + + logger.info(f"Virtual algorithm completed. 
Processed {len(processed_files)} images") + + return { + 'status': 'success', + 'output_dir': output_dir, + 'processed_count': len(processed_files), + 'processed_files': processed_files, + 'log_file': log_file + } + + +def _save_perturbed_images(task_id, output_dir): + """ + 保存扰动图片到数据库(适配新数据库结构) + + Args: + task_id: 任务ID + output_dir: 扰动图片输出目录 + """ + from app import db + from app.database import Task, Image, ImageType + + try: + task = Task.query.get(task_id) + if not task: + raise ValueError(f"Task {task_id} not found") + + # 获取扰动图片类型 + perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() + if not perturbed_type: + raise ValueError("Image type 'perturbed' not found") + + # 获取原始图片列表(同一个task_id下的原始图片) + original_type = ImageType.query.filter_by(image_code='original').first() + original_images = Image.query.filter_by( + task_id=task_id, + image_types_id=original_type.image_types_id + ).all() + + # 创建原图映射字典: stored_filename -> Image对象 + original_map = {img.stored_filename: img for img in original_images} + + # 查找输出目录中的扰动图片 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + perturbed_files = [] + for ext in image_extensions: + perturbed_files.extend(glob.glob(os.path.join(output_dir, ext))) + perturbed_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) + + logger.info(f"Found {len(perturbed_files)} perturbed images to save") + + saved_count = 0 + for perturbed_path in perturbed_files: + try: + # 获取文件名(不含路径) + perturbed_filename = os.path.basename(perturbed_path) + + # 尝试找到对应的原始图片 + # 假设扰动图片命名为: perturbed_{original_name}.ext + original_filename = perturbed_filename + if perturbed_filename.startswith('perturbed_'): + original_filename = perturbed_filename[len('perturbed_'):] + + original_image = original_map.get(original_filename) + if not original_image: + # 尝试完全匹配 + matching_images = [img for img in original_images if img.stored_filename == perturbed_filename] + if matching_images: + original_image = 
matching_images[0] + else: + logger.warning(f"Could not find original image for {perturbed_filename}") + # 即使找不到父图片也保存,但father_id设为None + + # 检查是否已经保存过 + existing = Image.query.filter_by( + task_id=task_id, + stored_filename=perturbed_filename, + image_types_id=perturbed_type.image_types_id + ).first() + + if existing: + logger.info(f"Perturbed image {perturbed_filename} already exists, skipping") + continue + + # 读取图片尺寸 + try: + with PILImage.open(perturbed_path) as img: + width, height = img.size + except: + width, height = None, None + + # 保存到数据库(使用新结构) + perturbed_image = Image( + task_id=task_id, + image_types_id=perturbed_type.image_types_id, + father_id=original_image.images_id if original_image else None, # 设置父图片关系 + stored_filename=perturbed_filename, + file_path=perturbed_path, + file_size=os.path.getsize(perturbed_path), + width=width, + height=height + ) + + db.session.add(perturbed_image) + saved_count += 1 + logger.info(f"Saved perturbed image: {perturbed_filename} (father: {original_image.images_id if original_image else 'None'})") + + except Exception as e: + logger.error(f"Error saving perturbed image {perturbed_filename}: {str(e)}") + continue + + db.session.commit() + logger.info(f"Successfully saved {saved_count} perturbed images to database") + + except Exception as e: + logger.error(f"Error saving perturbed images: {str(e)}") + db.session.rollback() -- 2.34.1 From bb9edee6e644431e09323c3b4df4b6c92c4e0fbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sun, 30 Nov 2025 20:03:15 +0800 Subject: [PATCH 06/14] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E4=B8=AD=E7=83=AD=E5=8A=9B=E5=9B=BE=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E4=B8=BB=E9=94=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/database/__init__.py | 50 +++++++++++++++------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git 
a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py index 8349728..0526edb 100644 --- a/src/backend/app/database/__init__.py +++ b/src/backend/app/database/__init__.py @@ -205,15 +205,13 @@ class Task(db.Model): task_status = db.relationship('TaskStatus', backref='tasks') images = db.relationship('Image', backref='task', lazy='dynamic', cascade='all, delete-orphan') - # --- 变更部分 --- - # 与子表的一对一关系 (perturbation, heatmap) + # 与子表的一对一关系 (perturbation) perturbation = db.relationship('Perturbation', uselist=False, back_populates='task', cascade='all, delete-orphan') - heatmap = db.relationship('Heatmap', uselist=False, back_populates='task', cascade='all, delete-orphan') - # 与子表的一对多关系 (finetune, evaluate) + # 与子表的一对多关系 (heatmap, finetune, evaluate) + heatmaps = db.relationship('Heatmap', back_populates='task', cascade='all, delete-orphan') finetunes = db.relationship('Finetune', back_populates='task', cascade='all, delete-orphan') evaluations = db.relationship('Evaluate', back_populates='task', cascade='all, delete-orphan') - # --- 变更结束 --- def __repr__(self): return f'' @@ -239,21 +237,21 @@ class Perturbation(db.Model): return f'' # ---------------------------- -# 7. 任务子表:微调任务 (finetune) - [已更新为复合主键] +# 7. 
任务子表:微调任务 (finetune) - [复合主键:tasks_id + finetune_configs_id] # ---------------------------- class Finetune(db.Model): - """微调任务详情表""" + """微调任务详情表 + 说明: + - tasks_id 与对应加噪任务相同(第一种模式)或独立(第二种上传模式) + - finetune_configs_id 表示使用的微调算法配置 + """ __tablename__ = 'finetune' - # --- 变更部分:复合主键 --- tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表关联') finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), primary_key=True, comment='微调配置ID') - # --- 变更结束 --- data_type_id = db.Column(Integer, ForeignKey('data_type.data_type_id'), default=None, comment='微调所用数据集') finetune_name = db.Column(String(100), comment='微调任务名称') - # --- 变更部分:更新 back_populates --- task = db.relationship('Task', back_populates='finetunes') - # --- 变更结束 --- finetune_config = db.relationship('FinetuneConfig') data_type = db.relationship('DataType') @@ -276,21 +274,21 @@ class EvaluationResult(db.Model): return f'' # ---------------------------- -# 9. 任务子表:评估任务 (evaluate) - [已更新为复合主键] +# 9. 
任务子表:评估任务 (evaluate) - [复合主键:tasks_id + finetune_configs_id] # ---------------------------- class Evaluate(db.Model): - """指标计算任务表""" + """指标计算任务表 + 说明: + - tasks_id 与对应微调任务相同 + - finetune_configs_id 与对应微调任务的配置相同 + """ __tablename__ = 'evaluate' - # --- 变更部分:复合主键 --- tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True) - finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), primary_key=True, comment='关联的微调配置(如果是针对微调的评估)') - # --- 变更结束 --- + finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), primary_key=True, comment='关联的微调配置') evaluate_name = db.Column(String(100)) evaluation_results_id = db.Column(BigInteger, ForeignKey('evaluation_results.evaluation_results_id'), unique=True, default=None, comment='关联的结果ID') - # --- 变更部分:更新 back_populates --- task = db.relationship('Task', back_populates='evaluations') - # --- 变更结束 --- finetune_config = db.relationship('FinetuneConfig') evaluation_result = db.relationship('EvaluationResult', backref='evaluate_task', uselist=False) @@ -298,19 +296,25 @@ class Evaluate(db.Model): return f'' # ---------------------------- -# 10. 任务子表:热力图计算任务 (heatmap) +# 10. 
任务子表:热力图计算任务 (heatmap) - [复合主键:tasks_id + images_id] # ---------------------------- class Heatmap(db.Model): - """热力图计算任务表""" + """热力图计算任务表 + 说明: + - tasks_id 与对应加噪任务相同 + - images_id 是该加噪任务的某张加噪图的ID + """ __tablename__ = 'heatmap' - tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True) + tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与加噪任务相同的任务ID') + images_id = db.Column(BigInteger, ForeignKey('images.images_id', ondelete='CASCADE'), primary_key=True, comment='加噪图的ID') heatmap_name = db.Column(String(100)) # 关系 - task = db.relationship('Task', back_populates='heatmap') + task = db.relationship('Task', back_populates='heatmaps') + perturbation_image = db.relationship('Image', foreign_keys=[images_id]) def __repr__(self): - return f'' + return f'' # ---------------------------- # 11. 图片表 (images) -- 2.34.1 From 71e667cdd50c5a3ea14d7a54e251658f9f690b5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sun, 30 Nov 2025 20:07:00 +0800 Subject: [PATCH 07/14] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E5=BE=AE?= =?UTF-8?q?=E8=B0=83=E7=94=9F=E6=88=90=E5=9B=BE=E7=89=87=E7=9A=84=E7=88=B6?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E4=BF=9D=E5=AD=98=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/workers/finetune_worker.py | 61 ++++++++++------------ 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index 3eaf6da..cdb8eec 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -380,6 +380,8 @@ def _save_generated_images(task_id, output_dir, is_perturbed, has_perturbation_t """ 保存微调生成的验证图片到数据库 + 核心逻辑:所有生成图的father_id设置为输入图片的第一张 + Args: task_id: 任务ID output_dir: 生成图片输出目录 @@ -403,26 +405,30 @@ def 
_save_generated_images(task_id, output_dir, is_perturbed, has_perturbation_t if not generated_type: raise ValueError(f"Image type '{'perturbed' if is_perturbed else 'original'}_generate' not found") - # 如果基于加噪任务,获取对应的训练图片以建立father关系 - father_images_map = {} - if has_perturbation_task: - if is_perturbed: - # 扰动图片生成,父图片是扰动图 - perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() - father_images = Image.query.filter_by( - task_id=task_id, - image_types_id=perturbed_type.image_types_id - ).all() - else: - # 原始图片生成,父图片是原始图 - original_type = ImageType.query.filter_by(image_code='original').first() - father_images = Image.query.filter_by( - task_id=task_id, - image_types_id=original_type.image_types_id - ).all() - - # 创建映射: stored_filename -> Image对象 - father_images_map = {img.stored_filename: img for img in father_images} + # 获取输入图片的第一张作为father_id + father_id = None + if is_perturbed: + # 扰动图片生成,父图片是第一张扰动图 + perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() + first_image = Image.query.filter_by( + task_id=task_id, + image_types_id=perturbed_type.image_types_id + ).order_by(Image.images_id.asc()).first() + if first_image: + father_id = first_image.images_id + else: + # 原始图片生成,父图片是第一张原始图 + original_type = ImageType.query.filter_by(image_code='original').first() + first_image = Image.query.filter_by( + task_id=task_id, + image_types_id=original_type.image_types_id + ).order_by(Image.images_id.asc()).first() + if first_image: + father_id = first_image.images_id + + logger.info(f"Will set father_id={father_id} for all generated images") + + logger.info(f"Will set father_id={father_id} for all generated images") # 查找输出目录中的生成图片 image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] @@ -450,17 +456,6 @@ def _save_generated_images(task_id, output_dir, is_perturbed, has_perturbation_t logger.info(f"Generated image {generated_filename} already exists, skipping") continue - # 尝试根据文件名找到对应的父图片 - # 假设生成图片命名包含原始训练图片的名称 - father_id 
= None - if has_perturbation_task and father_images_map: - # 简单的匹配策略:查找文件名是否包含某个训练图片名 - for train_filename, train_image in father_images_map.items(): - base_name = os.path.splitext(train_filename)[0] - if base_name in generated_filename: - father_id = train_image.images_id - break - # 读取图片尺寸 try: with PILImage.open(generated_path) as img: @@ -468,11 +463,11 @@ def _save_generated_images(task_id, output_dir, is_perturbed, has_perturbation_t except: width, height = None, None - # 保存到数据库 + # 保存到数据库,所有生成图的father_id统一设置为输入的第一张图片 generated_image = Image( task_id=task_id, image_types_id=generated_type.image_types_id, - father_id=father_id, # 如果找到对应的训练图片则设置father_id + father_id=father_id, # 统一设置为输入的第一张图片 stored_filename=generated_filename, file_path=generated_path, file_size=os.path.getsize(generated_path), -- 2.34.1 From de56e358b9bd181f60d869df24e85f583ee59192 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Mon, 1 Dec 2025 01:02:04 +0800 Subject: [PATCH 08/14] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E7=9B=B8=E5=85=B3=E6=95=B0=E6=8D=AE=E5=BA=93=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=B8=8A=E4=BC=A0=E5=9B=BE=E7=89=87=E7=94=9F?= =?UTF-8?q?=E6=88=90=E5=9B=BE=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/database/__init__.py | 1 + src/backend/config/settings.py | 3 ++- src/backend/init_db.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py index 0526edb..80afb22 100644 --- a/src/backend/app/database/__init__.py +++ b/src/backend/app/database/__init__.py @@ -191,6 +191,7 @@ class Task(db.Model): """任务总表""" __tablename__ = 'tasks' tasks_id = db.Column(BigInteger, primary_key=True, autoincrement=True, comment='任务ID') + flow_id = db.Column(BigInteger, nullable=False, index=True, comment='工作流ID,标识关联的任务组') tasks_type_id = 
db.Column(Integer, ForeignKey('task_type.task_type_id'), nullable=False, comment='任务类型') user_id = db.Column(Integer, ForeignKey('users.user_id'), nullable=False, index=True, comment='归属用户') tasks_status_id = db.Column(Integer, ForeignKey('task_status.task_status_id'), nullable=False, comment='任务状态ID') diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 02d57b7..f59610e 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -51,7 +51,8 @@ class Config: # 图像处理配置 ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'original') # 重命名后的原始图片 PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片 - MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录 + MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录、 + MODEL_UPLOADED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'uploaded') # 上传图的模型生成结果 MODEL_ORIGINAL_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'original') # 原图的模型生成结果 MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果 HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图 diff --git a/src/backend/init_db.py b/src/backend/init_db.py index 9f15d70..ca1c814 100644 --- a/src/backend/init_db.py +++ b/src/backend/init_db.py @@ -42,6 +42,7 @@ def init_database(): image_types = [ {'image_code': 'original', 'image_name': '原始图', 'description': '用户上传的原始图像'}, {'image_code': 'perturbed', 'image_name': '加噪图', 'description': '经过扰动算法处理后的防护图像'}, + {'image_code': 'uploaded_generate', 'image_name': '上传图片生成图', 'description': '使用上传图片训练后生成的图像'}, {'image_code': 'original_generate', 'image_name': '原始图像生成图', 'description': '使用原始图像训练后生成的图像'}, {'image_code': 'perturbed_generate', 'image_name': '加噪图像生成图', 'description': '使用加噪图像训练后生成的图像'}, {'image_code': 'heatmap', 'image_name': '热力图', 'description': '原始图与加噪图的差异热力图'}, -- 2.34.1 From a8e4659a545876e97c40a09f1ab563c5143db6c1 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Mon, 1 Dec 2025 01:03:27 +0800 Subject: [PATCH 09/14] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E7=9B=B8=E5=85=B3=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/database/__init__.py | 58 +++++++++++----------------- 1 file changed, 22 insertions(+), 36 deletions(-) diff --git a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py index 80afb22..85af8f4 100644 --- a/src/backend/app/database/__init__.py +++ b/src/backend/app/database/__init__.py @@ -206,13 +206,11 @@ class Task(db.Model): task_status = db.relationship('TaskStatus', backref='tasks') images = db.relationship('Image', backref='task', lazy='dynamic', cascade='all, delete-orphan') - # 与子表的一对一关系 (perturbation) + # 与子表的一对一关系 perturbation = db.relationship('Perturbation', uselist=False, back_populates='task', cascade='all, delete-orphan') - - # 与子表的一对多关系 (heatmap, finetune, evaluate) - heatmaps = db.relationship('Heatmap', back_populates='task', cascade='all, delete-orphan') - finetunes = db.relationship('Finetune', back_populates='task', cascade='all, delete-orphan') - evaluations = db.relationship('Evaluate', back_populates='task', cascade='all, delete-orphan') + heatmap = db.relationship('Heatmap', uselist=False, back_populates='task', cascade='all, delete-orphan') + finetune = db.relationship('Finetune', uselist=False, back_populates='task', cascade='all, delete-orphan') + evaluation = db.relationship('Evaluate', uselist=False, back_populates='task', cascade='all, delete-orphan') def __repr__(self): return f'' @@ -238,26 +236,22 @@ class Perturbation(db.Model): return f'' # ---------------------------- -# 7. 任务子表:微调任务 (finetune) - [复合主键:tasks_id + finetune_configs_id] +# 7. 
任务子表:微调任务 (finetune) # ---------------------------- class Finetune(db.Model): - """微调任务详情表 - 说明: - - tasks_id 与对应加噪任务相同(第一种模式)或独立(第二种上传模式) - - finetune_configs_id 表示使用的微调算法配置 - """ + """微调任务详情表""" __tablename__ = 'finetune' - tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表关联') - finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), primary_key=True, comment='微调配置ID') + tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表1:1关联') + finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), nullable=False, comment='微调配置ID') data_type_id = db.Column(Integer, ForeignKey('data_type.data_type_id'), default=None, comment='微调所用数据集') finetune_name = db.Column(String(100), comment='微调任务名称') - task = db.relationship('Task', back_populates='finetunes') + task = db.relationship('Task', back_populates='finetune') finetune_config = db.relationship('FinetuneConfig') data_type = db.relationship('DataType') def __repr__(self): - return f'' + return f'' # ---------------------------- # 8. 评估结果表 (evaluation_results) @@ -275,47 +269,39 @@ class EvaluationResult(db.Model): return f'' # ---------------------------- -# 9. 任务子表:评估任务 (evaluate) - [复合主键:tasks_id + finetune_configs_id] +# 9. 
任务子表:评估任务 (evaluate) # ---------------------------- class Evaluate(db.Model): - """指标计算任务表 - 说明: - - tasks_id 与对应微调任务相同 - - finetune_configs_id 与对应微调任务的配置相同 - """ + """指标计算任务表""" __tablename__ = 'evaluate' - tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True) - finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), primary_key=True, comment='关联的微调配置') + tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表1:1关联') + finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), nullable=False, comment='关联的微调配置') evaluate_name = db.Column(String(100)) evaluation_results_id = db.Column(BigInteger, ForeignKey('evaluation_results.evaluation_results_id'), unique=True, default=None, comment='关联的结果ID') - task = db.relationship('Task', back_populates='evaluations') + task = db.relationship('Task', back_populates='evaluation') finetune_config = db.relationship('FinetuneConfig') evaluation_result = db.relationship('EvaluationResult', backref='evaluate_task', uselist=False) def __repr__(self): - return f'' + return f'' # ---------------------------- -# 10. 任务子表:热力图计算任务 (heatmap) - [复合主键:tasks_id + images_id] +# 10. 
任务子表:热力图计算任务 (heatmap) # ---------------------------- class Heatmap(db.Model): - """热力图计算任务表 - 说明: - - tasks_id 与对应加噪任务相同 - - images_id 是该加噪任务的某张加噪图的ID - """ + """热力图计算任务表""" __tablename__ = 'heatmap' - tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与加噪任务相同的任务ID') - images_id = db.Column(BigInteger, ForeignKey('images.images_id', ondelete='CASCADE'), primary_key=True, comment='加噪图的ID') + tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表1:1关联') + images_id = db.Column(BigInteger, ForeignKey('images.images_id', ondelete='CASCADE'), nullable=False, comment='关联的加噪图ID') heatmap_name = db.Column(String(100)) # 关系 - task = db.relationship('Task', back_populates='heatmaps') + task = db.relationship('Task', back_populates='heatmap') perturbation_image = db.relationship('Image', foreign_keys=[images_id]) def __repr__(self): - return f'' + return f'' # ---------------------------- # 11. 
图片表 (images) -- 2.34.1 From 2095ccc19571ec046b577304fb3994db0c832ae5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Mon, 1 Dec 2025 01:51:07 +0800 Subject: [PATCH 10/14] =?UTF-8?q?refactor:=20=E5=8E=BB=E9=99=A4worker?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E7=9A=84=E8=99=9A=E6=8B=9F=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/workers/evaluate_worker.py | 95 ++--- src/backend/app/workers/finetune_worker.py | 312 ++++++---------- src/backend/app/workers/heatmap_worker.py | 130 +++---- .../app/workers/perturbation_worker.py | 332 ++++++++---------- 4 files changed, 359 insertions(+), 510 deletions(-) diff --git a/src/backend/app/workers/evaluate_worker.py b/src/backend/app/workers/evaluate_worker.py index 0f3b4b7..0e5acae 100644 --- a/src/backend/app/workers/evaluate_worker.py +++ b/src/backend/app/workers/evaluate_worker.py @@ -1,11 +1,12 @@ """ -RQ Worker 数值评估任务处理器 +RQ Worker 数值评估任务处理器(仅使用真实算法) 生成原始图与扰动图微调后的模型生成效果对比报告 """ import os import subprocess import logging +import shutil from datetime import datetime logging.basicConfig( @@ -15,14 +16,13 @@ logging.basicConfig( logger = logging.getLogger(__name__) -def run_evaluate_task(evaluate_id, task_id, clean_ref_dir, clean_output_dir, +def run_evaluate_task(task_id, clean_ref_dir, clean_output_dir, perturbed_output_dir, output_dir, image_size=512): """ - 执行数值评估任务 + 执行数值评估任务(仅使用真实算法) Args: - evaluate_id: 评估任务ID(复合主键之一) - task_id: 关联的主任务ID + task_id: 任务ID clean_ref_dir: 干净参考图片目录(原始上传的图片) clean_output_dir: 干净图片训练后的生成结果目录 perturbed_output_dir: 扰动图片训练后的生成结果目录 @@ -40,39 +40,40 @@ def run_evaluate_task(evaluate_id, task_id, clean_ref_dir, clean_output_dir, with app.app_context(): try: - # Evaluate 使用复合主键,需要通过 tasks_id 和 finetune_configs_id 查询 - # 这里先通过 task_id 查询 - evaluate_task = Evaluate.query.filter_by(tasks_id=task_id).first() - if not evaluate_task: - raise ValueError(f"Evaluate task for Task {task_id} not 
found") - + # 获取任务 task = Task.query.get(task_id) if not task: raise ValueError(f"Task {task_id} not found") + # 获取评估任务详情 + evaluate = Evaluate.query.get(task_id) + if not evaluate: + raise ValueError(f"Evaluate task {task_id} not found") + # 更新任务状态为处理中 processing_status = TaskStatus.query.filter_by(task_status_code='processing').first() if processing_status: task.tasks_status_id = processing_status.task_status_id + task.started_at = datetime.utcnow() db.session.commit() - logger.info(f"Starting evaluate task for Task {task_id}") + logger.info(f"Starting evaluate task {task_id}") - # 确保目录存在 + # 确保目录存在并清空 os.makedirs(output_dir, exist_ok=True) + logger.info(f"Clearing output directory: {output_dir}") + for item in os.listdir(output_dir): + item_path = os.path.join(output_dir, item) + if os.path.isfile(item_path): + os.unlink(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) - # 获取配置 - use_real = AlgorithmConfig.USE_REAL_ALGORITHMS - - if use_real: - # 使用真实评估算法 - result = _run_real_evaluate( - task_id, clean_ref_dir, clean_output_dir, - perturbed_output_dir, output_dir, image_size - ) - else: - # 使用虚拟实现(生成占位符报告) - result = _run_virtual_evaluate(output_dir) + # 运行真实评估算法 + result = _run_real_evaluate( + task_id, clean_ref_dir, clean_output_dir, + perturbed_output_dir, output_dir, image_size + ) # 保存评估结果文件路径到数据库 report_file = os.path.join(output_dir, 'nums_dif.png') @@ -84,21 +85,23 @@ def run_evaluate_task(evaluate_id, task_id, clean_ref_dir, clean_output_dir, completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() if completed_status: task.tasks_status_id = completed_status.task_status_id + task.completed_at = datetime.utcnow() db.session.commit() - logger.info(f"Evaluate task completed for Task {task_id}") + logger.info(f"Evaluate task {task_id} completed") return result except Exception as e: - logger.error(f"Evaluate task failed for Task {task_id}: {str(e)}", exc_info=True) + logger.error(f"Evaluate task {task_id} 
failed: {str(e)}", exc_info=True) # 更新任务状态为失败 failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() if failed_status: task.tasks_status_id = failed_status.task_status_id + task.completed_at = datetime.utcnow() db.session.commit() - raise + return {'success': False, 'error': str(e)} def _run_real_evaluate(task_id, clean_ref_dir, clean_output_dir, @@ -136,6 +139,10 @@ def _run_real_evaluate(task_id, clean_ref_dir, clean_output_dir, logger.info(f"Executing command: {' '.join(cmd)}") + # 设置环境变量(强制离线模式) + env = os.environ.copy() + env['HF_HUB_OFFLINE'] = '1' + # 设置日志文件 log_dir = AlgorithmConfig.LOGS_DIR os.makedirs(log_dir, exist_ok=True) @@ -152,7 +159,8 @@ def _run_real_evaluate(task_id, clean_ref_dir, clean_output_dir, stderr=subprocess.STDOUT, text=True, bufsize=1, - universal_newlines=True + universal_newlines=True, + env=env ) for line in process.stdout: @@ -172,35 +180,6 @@ def _run_real_evaluate(task_id, clean_ref_dir, clean_output_dir, } -def _run_virtual_evaluate(output_dir): - """运行虚拟评估实现(生成占位符)""" - logger.info(f"Running virtual evaluate generation") - - # 创建占位符图片 - from PIL import Image, ImageDraw - - # 创建一个模拟的评估报告图 - img = Image.new('RGB', (1200, 1600), color=(255, 255, 255)) - draw = ImageDraw.Draw(img) - - # 添加标题文本 - draw.text((50, 50), "Virtual Evaluation Report Placeholder", fill=(0, 0, 0)) - draw.text((50, 100), "Real evaluation will be generated when USE_REAL_ALGORITHMS=true", fill=(128, 128, 128)) - draw.text((50, 150), "Metrics: FID, SSIM, PSNR, FDS, CLIP_IQS, BRISQUE", fill=(64, 64, 64)) - - # 保存 - output_file = os.path.join(output_dir, 'nums_dif.png') - img.save(output_file) - - logger.info(f"Virtual evaluation report saved to {output_file}") - - return { - 'status': 'success', - 'output_dir': output_dir, - 'virtual': True - } - - def _save_report_image(task_id, report_file_path): """ 保存评估报告图到数据库Image表 diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index 
cdb8eec..9800085 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -1,14 +1,13 @@ """ RQ Worker 微调任务处理器 - 适配新数据库结构 -支持两种微调模式: -1. 基于加噪任务的微调 (共享task_id) -2. 直接上传图片的微调 (独立task_id) +仅支持真实算法,移除虚拟算法调用 """ import os import subprocess import logging import glob +import shutil from datetime import datetime from PIL import Image as PILImage @@ -21,9 +20,9 @@ logger = logging.getLogger(__name__) def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images_dir, output_model_dir, class_dir, coords_save_path, validation_output_dir, - is_perturbed=False, has_perturbation_task=False, custom_params=None): + is_perturbed=False, custom_params=None): """ - 执行微调任务 + 执行微调任务(仅使用真实算法) Args: task_id: 任务ID @@ -35,7 +34,6 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images coords_save_path: 坐标保存路径 validation_output_dir: 验证图片输出目录 is_perturbed: 是否使用扰动图片训练 - has_perturbation_task: 是否基于加噪任务(True表示共享task_id, False表示独立任务) custom_params: 自定义参数 Returns: @@ -43,7 +41,7 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images """ from config.algorithm_config import AlgorithmConfig from app import create_app, db - from app.database import Task, Finetune, DataType, Perturbation, TaskStatus + from app.database import Task, Finetune, DataType, TaskStatus app = create_app() @@ -74,43 +72,41 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images logger.info(f"Starting finetune task {task_id} (config: {finetune_config_id})") logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}") - # 获取Prompt文本 - # 优先从Finetune的data_type获取,如果没有则尝试从关联的Perturbation获取 - inference_prompts = "a photo of sks person" # 默认值 + # 从数据库获取数据集类型的提示词 + # 从Finetune表的data_type_id获取 + instance_prompt = "a photo of sks person" # 默认值 + class_prompt = "a photo of person" # 默认值 + validation_prompt = "a photo of sks person" # 默认值 if finetune.data_type_id: - # 
从微调任务的数据集类型获取 data_type = DataType.query.get(finetune.data_type_id) if data_type and data_type.data_type_prompt: - inference_prompts = data_type.data_type_prompt - logger.info(f"Using prompt from Finetune.data_type: {inference_prompts}") - elif has_perturbation_task: - # 如果是基于加噪任务,尝试从加噪任务获取 - perturbation = Perturbation.query.filter_by(tasks_id=task_id).first() - if perturbation and perturbation.data_type_id: - data_type = DataType.query.get(perturbation.data_type_id) - if data_type and data_type.data_type_prompt: - inference_prompts = data_type.data_type_prompt - logger.info(f"Using prompt from Perturbation.data_type: {inference_prompts}") + instance_prompt = data_type.data_type_prompt + validation_prompt = instance_prompt + # 从instance_prompt生成class_prompt(移除"sks") + class_prompt = instance_prompt.replace('sks ', '') + logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'") - # 获取配置 - use_real = AlgorithmConfig.USE_REAL_ALGORITHMS + # 清空输出目录(避免旧文件残留) + logger.info(f"Clearing output directories...") + for dir_path in [output_model_dir, validation_output_dir, coords_save_path]: + if os.path.exists(dir_path): + for item in os.listdir(dir_path): + item_path = os.path.join(dir_path, item) + if os.path.isfile(item_path): + os.unlink(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) - if use_real: - # 使用真实微调算法 - result = _run_real_finetune( - finetune_method, task_id, train_images_dir, output_model_dir, - class_dir, coords_save_path, validation_output_dir, - inference_prompts, is_perturbed, custom_params - ) - else: - # 使用虚拟实现 - result = _run_virtual_finetune( - finetune_method, task_id, train_images_dir, output_model_dir, is_perturbed - ) + # 运行真实微调算法 + result = _run_real_finetune( + finetune_method, task_id, train_images_dir, output_model_dir, + class_dir, coords_save_path, validation_output_dir, + instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params + ) # 保存生成的验证图片到数据库 - 
_save_generated_images(task_id, validation_output_dir, is_perturbed, has_perturbation_task) + _save_generated_images(task_id, validation_output_dir, is_perturbed) # 更新任务状态为完成 completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() @@ -126,23 +122,45 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images logger.error(f"Finetune task {task_id} failed: {str(e)}", exc_info=True) # 更新任务状态为失败 - failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() - if failed_status: - task.tasks_status_id = failed_status.task_status_id - task.finished_at = datetime.utcnow() - task.error_message = str(e) - db.session.commit() + try: + failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() + if failed_status: + task.tasks_status_id = failed_status.task_status_id + task.finished_at = datetime.utcnow() + task.error_message = str(e) + db.session.commit() + except: + pass raise def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_dir, class_dir, coords_save_path, validation_output_dir, - inference_prompts, is_perturbed, custom_params): - """运行真实微调算法""" + instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params): + """ + 运行真实微调算法(参考sh脚本配置) + + Args: + finetune_method: 微调方法 + task_id: 任务ID + train_images_dir: 训练图片目录 + output_model_dir: 模型输出目录 + class_dir: 类别数据目录 + coords_save_path: 坐标保存路径 + validation_output_dir: 验证图片输出目录 + instance_prompt: 实例提示词 + class_prompt: 类别提示词 + validation_prompt: 验证提示词 + is_perturbed: 是否使用扰动图片 + custom_params: 自定义参数 + """ from config.algorithm_config import AlgorithmConfig logger.info(f"Running real finetune: {finetune_method}") + logger.info(f"Instance prompt: '{instance_prompt}'") + logger.info(f"Class prompt: '{class_prompt}'") + logger.info(f"Validation prompt: '{validation_prompt}'") # 获取微调脚本路径和环境 finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {}) @@ -153,9 +171,15 @@ def 
_run_real_finetune(finetune_method, task_id, train_images_dir, output_model_ if not script_path: raise ValueError(f"Finetune method {finetune_method} not configured") - # 合并参数 + # 覆盖提示词参数(从数据库读取) + default_params['instance_prompt'] = instance_prompt + default_params['class_prompt'] = class_prompt + default_params['validation_prompt'] = validation_prompt + + # 合并自定义参数 params = {**default_params, **(custom_params or {})} + # 根据微调方法构建命令参数(参考sh脚本) cmd_args = [ f"--instance_data_dir={train_images_dir}", f"--output_dir={output_model_dir}", @@ -196,6 +220,11 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_ else: cmd_args.append(f"--{key}={value}") + # 设置环境变量 + env = os.environ.copy() + env['HF_HUB_OFFLINE'] = '1' # 强制离线模式 + env['CUDA_VISIBLE_DEVICES'] = '0' # 默认使用第一块GPU + # 构建完整命令 cmd = [ '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', @@ -221,9 +250,11 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_ stderr=subprocess.STDOUT, text=True, bufsize=1, - universal_newlines=True + universal_newlines=True, + env=env ) + # 实时输出日志 for line in process.stdout: f.write(line) f.flush() @@ -231,15 +262,19 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_ process.wait() + logger.info(f"Finetune execution completed with return code: {process.returncode}") + logger.info(f"Output directory: {output_model_dir}") + logger.info(f"Log file: {log_file}") + if process.returncode != 0: raise RuntimeError(f"Finetune failed with code {process.returncode}. 
Check log: {log_file}") - # 清理class_dir - logger.info(f"Cleaning class directory: {class_dir}") - if os.path.exists(class_dir): - import shutil - shutil.rmtree(class_dir) - os.makedirs(class_dir) + # 清理class_dir(参考sh脚本) + if finetune_method in ['dreambooth', 'lora']: + logger.info(f"Cleaning class directory: {class_dir}") + if os.path.exists(class_dir): + shutil.rmtree(class_dir) + os.makedirs(class_dir) # 清理output_model_dir中的非图片文件 logger.info(f"Cleaning non-image files in output directory: {output_model_dir}") @@ -264,188 +299,64 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_ } -def _run_virtual_finetune(finetune_method, task_id, train_images_dir, output_model_dir, is_perturbed): - """运行虚拟微调实现""" - from config.algorithm_config import AlgorithmConfig - import shutil - - logger.info(f"Running virtual finetune: {finetune_method}") - - # 获取微调配置 - finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {}) - if not finetune_config: - raise ValueError(f"Finetune method {finetune_method} not configured") - - conda_env = finetune_config.get('conda_env') - default_params = finetune_config.get('default_params', {}) - - # 获取虚拟微调脚本路径 - script_name = 'train_dreambooth_gen.py' if finetune_method == 'dreambooth' else 'train_lora_gen.py' - script_path = os.path.abspath(os.path.join( - os.path.dirname(__file__), - '../algorithms/finetune_virtual', - script_name - )) - - if not os.path.exists(script_path): - raise FileNotFoundError(f"Virtual finetune script not found: {script_path}") - - logger.info(f"Virtual script path: {script_path}") - logger.info(f"Conda environment: {conda_env}") - - # 创建输出目录 - os.makedirs(output_model_dir, exist_ok=True) - validation_output_dir = os.path.join(output_model_dir, 'generated') - os.makedirs(validation_output_dir, exist_ok=True) - - # 构建命令行参数 - cmd_args = [ - f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}", - 
f"--instance_data_dir={train_images_dir}", - f"--output_dir={output_model_dir}", - f"--validation_image_output_dir={validation_output_dir}", - f"--class_data_dir=/tmp/class_placeholder", - ] - - # 添加is_perturbed标志 - if is_perturbed: - cmd_args.append("--is_perturbed") - - # 添加其他默认参数 - for key, value in default_params.items(): - if key == 'pretrained_model_name_or_path': - continue - if isinstance(value, bool): - if value: - cmd_args.append(f"--{key}") - else: - cmd_args.append(f"--{key}={value}") - - # 使用conda run执行虚拟脚本 - cmd = [ - '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', - 'python', script_path - ] + cmd_args - - logger.info(f"Executing command: {' '.join(cmd)}") - - # 设置日志文件 - log_dir = AlgorithmConfig.LOGS_DIR - os.makedirs(log_dir, exist_ok=True) - image_type = 'perturbed' if is_perturbed else 'original' - log_file = os.path.join( - log_dir, - f'virtual_{finetune_method}_{image_type}_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' - ) - - # 执行命令 - with open(log_file, 'w') as f: - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - bufsize=1, - universal_newlines=True - ) - - for line in process.stdout: - f.write(line) - f.flush() - logger.info(line.strip()) - - process.wait() - - if process.returncode != 0: - raise RuntimeError(f"Virtual finetune failed with code {process.returncode}. Check log: {log_file}") - - # 统计生成的图片 - image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] - generated_files = [] - for ext in image_extensions: - generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext))) - - logger.info(f"Virtual finetune completed. 
Generated {len(generated_files)} images") - - return { - 'status': 'success', - 'output_dir': output_model_dir, - 'generated_count': len(generated_files), - 'generated_files': generated_files, - 'log_file': log_file - } - - -def _save_generated_images(task_id, output_dir, is_perturbed, has_perturbation_task): +def _save_generated_images(task_id, output_dir, is_perturbed): """ - 保存微调生成的验证图片到数据库 + 保存微调生成的验证图片到数据库(适配新数据库结构) - 核心逻辑:所有生成图的father_id设置为输入图片的第一张 + 新数据库结构: + - Task表:tasks_id (主键), flow_id, tasks_type_id + - Image表:images_id (主键), task_id (外键), image_types_id, father_id + - 生成图的father_id设置为输入图片的第一张 Args: task_id: 任务ID output_dir: 生成图片输出目录 is_perturbed: 是否为扰动图片训练生成 - has_perturbation_task: 是否基于加噪任务 """ from app import db from app.database import Task, Image, ImageType try: + # 验证任务存在 task = Task.query.get(task_id) if not task: raise ValueError(f"Task {task_id} not found") - # 获取生成图片类型 + # 获取图片类型 if is_perturbed: generated_type = ImageType.query.filter_by(image_code='perturbed_generate').first() + input_type = ImageType.query.filter_by(image_code='perturbed').first() else: generated_type = ImageType.query.filter_by(image_code='original_generate').first() + input_type = ImageType.query.filter_by(image_code='original').first() - if not generated_type: - raise ValueError(f"Image type '{'perturbed' if is_perturbed else 'original'}_generate' not found") + if not generated_type or not input_type: + raise ValueError("Required image types not found in database") # 获取输入图片的第一张作为father_id - father_id = None - if is_perturbed: - # 扰动图片生成,父图片是第一张扰动图 - perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() - first_image = Image.query.filter_by( - task_id=task_id, - image_types_id=perturbed_type.image_types_id - ).order_by(Image.images_id.asc()).first() - if first_image: - father_id = first_image.images_id - else: - # 原始图片生成,父图片是第一张原始图 - original_type = ImageType.query.filter_by(image_code='original').first() - first_image = Image.query.filter_by( - 
task_id=task_id, - image_types_id=original_type.image_types_id - ).order_by(Image.images_id.asc()).first() - if first_image: - father_id = first_image.images_id - - logger.info(f"Will set father_id={father_id} for all generated images") + first_input_image = Image.query.filter_by( + task_id=task_id, + image_types_id=input_type.image_types_id + ).order_by(Image.images_id.asc()).first() + father_id = first_input_image.images_id if first_input_image else None logger.info(f"Will set father_id={father_id} for all generated images") - # 查找输出目录中的生成图片 - image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + # 查找输出目录中的所有生成图片 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp', '*.tiff'] generated_files = [] for ext in image_extensions: generated_files.extend(glob.glob(os.path.join(output_dir, ext))) generated_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) - logger.info(f"Found {len(generated_files)} generated images to save") + logger.info(f"Found {len(generated_files)} generated images in output directory") saved_count = 0 for generated_path in generated_files: try: - # 获取文件名 generated_filename = os.path.basename(generated_path) - # 检查是否已经保存过 + # 检查是否已存在 existing = Image.query.filter_by( task_id=task_id, stored_filename=generated_filename, @@ -453,15 +364,16 @@ def _save_generated_images(task_id, output_dir, is_perturbed, has_perturbation_t ).first() if existing: - logger.info(f"Generated image {generated_filename} already exists, skipping") + logger.info(f"Image {generated_filename} already exists, skipping") continue # 读取图片尺寸 + width, height = None, None try: with PILImage.open(generated_path) as img: width, height = img.size - except: - width, height = None, None + except Exception as e: + logger.warning(f"Could not read image dimensions for {generated_filename}: {e}") # 保存到数据库,所有生成图的father_id统一设置为输入的第一张图片 generated_image = Image( diff --git a/src/backend/app/workers/heatmap_worker.py 
b/src/backend/app/workers/heatmap_worker.py index 4e00e50..88cb10d 100644 --- a/src/backend/app/workers/heatmap_worker.py +++ b/src/backend/app/workers/heatmap_worker.py @@ -1,11 +1,13 @@ """ -RQ Worker 热力图任务处理器 +RQ Worker 热力图任务处理器 - 适配新数据库结构 生成原始图与扰动图的注意力差异热力图 +仅支持真实算法,移除虚拟算法调用 """ import os import subprocess import logging +import shutil from datetime import datetime logging.basicConfig( @@ -15,92 +17,118 @@ logging.basicConfig( logger = logging.getLogger(__name__) -def run_heatmap_task(heatmap_id, task_id, original_image_path, perturbed_image_path, - prompt_text, target_word, output_dir, model_path, original_image_id=None): +def run_heatmap_task(task_id, original_image_path, perturbed_image_path, + output_dir, model_path, perturbed_image_id=None): """ - 执行热力图生成任务 + 执行热力图生成任务(仅使用真实算法) Args: - heatmap_id: 热力图任务ID - task_id: 关联的主任务ID + task_id: 任务ID original_image_path: 原始图片路径 perturbed_image_path: 扰动图片路径 - prompt_text: Prompt文本(如 "a photo of sks person") - target_word: 目标关键词(如 "person") output_dir: 输出目录 model_path: Stable Diffusion模型路径 - original_image_id: 原始图片ID (用于建立father关系) + perturbed_image_id: 扰动图片ID(用于建立father关系) Returns: 任务执行结果 """ from config.algorithm_config import AlgorithmConfig from app import create_app, db - from app.database import Heatmap, Task, TaskStatus + from app.database import Heatmap, Task, TaskStatus, DataType, Perturbation app = create_app() with app.app_context(): try: - heatmap_task = Heatmap.query.get(heatmap_id) - if not heatmap_task: - raise ValueError(f"Heatmap task {heatmap_id} not found") - + # 获取任务 task = Task.query.get(task_id) if not task: raise ValueError(f"Task {task_id} not found") + # 获取热力图任务详情 + heatmap = Heatmap.query.get(task_id) + if not heatmap: + raise ValueError(f"Heatmap task {task_id} not found") + # 更新任务状态为处理中 processing_status = TaskStatus.query.filter_by(task_status_code='processing').first() if processing_status: task.tasks_status_id = processing_status.task_status_id + task.started_at = datetime.utcnow() 
db.session.commit() - logger.info(f"Starting heatmap task for Heatmap {heatmap_id}, Task {task_id}") + logger.info(f"Starting heatmap task {task_id}") - # 确保目录存在 - os.makedirs(output_dir, exist_ok=True) + # 从数据库获取提示词(从关联的Perturbation任务获取) + prompt_text = "a photo of sks person" # 默认值 + target_word = "person" # 默认值 + + # 通过flow_id查找关联的Perturbation任务 + perturbation_tasks = Task.query.filter_by( + flow_id=task.flow_id, + tasks_type_id=1 # perturbation类型 + ).all() - # 获取配置 - use_real = AlgorithmConfig.USE_REAL_ALGORITHMS + if perturbation_tasks: + for pert_task in perturbation_tasks: + perturbation = Perturbation.query.get(pert_task.tasks_id) + if perturbation and perturbation.data_type_id: + data_type = DataType.query.get(perturbation.data_type_id) + if data_type and data_type.data_type_prompt: + prompt_text = data_type.data_type_prompt + # 提取target_word(去除"sks"后的第一个名词) + words = prompt_text.replace('sks ', '').split() + if words: + target_word = words[-1] # 取最后一个词作为target + logger.info(f"Using prompts from database - prompt: '{prompt_text}', target: '{target_word}'") + break - if use_real: - # 使用真实热力图算法 - result = _run_real_heatmap( - task_id, original_image_path, perturbed_image_path, - prompt_text, target_word, output_dir, model_path - ) - else: - # 使用虚拟实现(生成占位符图片) - result = _run_virtual_heatmap(output_dir) + # 确保目录存在并清空 + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Clearing output directory: {output_dir}") + for item in os.listdir(output_dir): + item_path = os.path.join(output_dir, item) + if os.path.isfile(item_path): + os.unlink(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) + + # 运行真实热力图算法 + result = _run_real_heatmap( + task_id, original_image_path, perturbed_image_path, + prompt_text, target_word, output_dir, model_path + ) # 保存热力图文件到数据库 heatmap_file = os.path.join(output_dir, 'heatmap_dif.png') if os.path.exists(heatmap_file): - heatmap_task.heatmap_name = 'heatmap_dif.png' + heatmap.heatmap_name = 'heatmap_dif.png' # 
保存热力图到Image表 - _save_heatmap_image(task_id, heatmap_file, original_image_id) + _save_heatmap_image(task_id, heatmap_file, perturbed_image_id) db.session.commit() # 更新任务状态为完成 completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() if completed_status: task.tasks_status_id = completed_status.task_status_id + task.completed_at = datetime.utcnow() db.session.commit() - logger.info(f"Heatmap task completed for Heatmap {heatmap_id}") + logger.info(f"Heatmap task {task_id} completed") return result except Exception as e: - logger.error(f"Heatmap task failed for Heatmap {heatmap_id}: {str(e)}", exc_info=True) + logger.error(f"Heatmap task {task_id} failed: {str(e)}", exc_info=True) # 更新任务状态为失败 failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() if failed_status: task.tasks_status_id = failed_status.task_status_id + task.completed_at = datetime.utcnow() db.session.commit() - raise + return {'success': False, 'error': str(e)} def _run_real_heatmap(task_id, original_image_path, perturbed_image_path, @@ -136,6 +164,10 @@ def _run_real_heatmap(task_id, original_image_path, perturbed_image_path, logger.info(f"Executing command: {' '.join(cmd)}") + # 设置环境变量(强制离线模式) + env = os.environ.copy() + env['HF_HUB_OFFLINE'] = '1' + # 设置日志文件 log_dir = AlgorithmConfig.LOGS_DIR os.makedirs(log_dir, exist_ok=True) @@ -152,7 +184,8 @@ def _run_real_heatmap(task_id, original_image_path, perturbed_image_path, stderr=subprocess.STDOUT, text=True, bufsize=1, - universal_newlines=True + universal_newlines=True, + env=env ) for line in process.stdout: @@ -172,35 +205,6 @@ def _run_real_heatmap(task_id, original_image_path, perturbed_image_path, } -def _run_virtual_heatmap(output_dir): - """运行虚拟热力图实现(生成占位符)""" - logger.info(f"Running virtual heatmap generation") - - # 创建占位符图片 - from PIL import Image, ImageDraw, ImageFont - import numpy as np - - # 创建一个模拟的热力图 - img = Image.new('RGB', (1200, 1600), color=(255, 255, 255)) - draw = ImageDraw.Draw(img) 
- - # 添加标题文本 - draw.text((50, 50), "Virtual Heatmap Placeholder", fill=(0, 0, 0)) - draw.text((50, 100), "Real heatmap will be generated when USE_REAL_ALGORITHMS=true", fill=(128, 128, 128)) - - # 保存 - output_file = os.path.join(output_dir, 'heatmap_dif.png') - img.save(output_file) - - logger.info(f"Virtual heatmap saved to {output_file}") - - return { - 'status': 'success', - 'output_dir': output_dir, - 'virtual': True - } - - def _save_heatmap_image(task_id, heatmap_file_path, father_image_id=None): """ 保存热力图到数据库Image表 diff --git a/src/backend/app/workers/perturbation_worker.py b/src/backend/app/workers/perturbation_worker.py index 4c29c3a..69a2a36 100644 --- a/src/backend/app/workers/perturbation_worker.py +++ b/src/backend/app/workers/perturbation_worker.py @@ -1,12 +1,14 @@ """ RQ Worker任务处理器 - 加噪任务 适配新数据库结构: Task + Perturbation + Images +仅支持真实算法,移除虚拟算法调用 """ import os import subprocess import logging import glob +import shutil from datetime import datetime from pathlib import Path from PIL import Image as PILImage @@ -19,20 +21,19 @@ logging.basicConfig( logger = logging.getLogger(__name__) -def run_perturbation_task(task_id, algorithm_code, epsilon, use_strong_protection, - input_dir, output_dir, class_dir, custom_params=None): +def run_perturbation_task(task_id, algorithm_code, epsilon, input_dir, output_dir, + class_dir, custom_params=None): """ - 执行对抗性扰动任务 + 执行对抗性扰动任务(仅使用真实算法) Args: task_id: 任务ID(对应 tasks 表的 tasks_id) - algorithm_code: 算法代码 + algorithm_code: 算法代码 (aspl/simac/caat/pid) epsilon: 扰动强度 - use_strong_protection: 是否使用防净化版本 input_dir: 输入图片目录 output_dir: 输出目录 - class_dir: 类别图片目录 - custom_params: 自定义参数 + class_dir: 类别图片目录(aspl/simac需要) + custom_params: 自定义参数字典 Returns: 任务执行结果 @@ -67,7 +68,6 @@ def run_perturbation_task(task_id, algorithm_code, epsilon, use_strong_protectio logger.info(f"Algorithm: {algorithm_code}, Epsilon: {epsilon}") # 获取算法配置 - use_real = AlgorithmConfig.USE_REAL_ALGORITHMS script_path = 
AlgorithmConfig.get_script_path(algorithm_code) conda_env = AlgorithmConfig.get_conda_env(algorithm_code) @@ -75,19 +75,21 @@ def run_perturbation_task(task_id, algorithm_code, epsilon, use_strong_protectio os.makedirs(output_dir, exist_ok=True) os.makedirs(class_dir, exist_ok=True) - if use_real: - # 使用真实算法 - result = _run_real_algorithm( - script_path, conda_env, algorithm_code, task_id, - epsilon, use_strong_protection, input_dir, output_dir, - class_dir, custom_params - ) - else: - # 使用虚拟实现 - result = _run_virtual_algorithm( - algorithm_code, task_id, epsilon, use_strong_protection, - input_dir, output_dir - ) + # 清空输出目录(避免旧文件残留) + logger.info(f"Clearing output directory: {output_dir}") + if os.path.exists(output_dir): + for item in os.listdir(output_dir): + item_path = os.path.join(output_dir, item) + if os.path.isfile(item_path): + os.unlink(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) + + # 运行真实算法 + result = _run_real_algorithm( + script_path, conda_env, algorithm_code, task_id, + epsilon, input_dir, output_dir, class_dir, custom_params + ) # 保存扰动图片到数据库 _save_perturbed_images(task_id, output_dir) @@ -117,54 +119,86 @@ def run_perturbation_task(task_id, algorithm_code, epsilon, use_strong_protectio def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id, - epsilon, use_strong_protection, input_dir, output_dir, - class_dir, custom_params): - """运行真实算法""" + epsilon, input_dir, output_dir, class_dir, custom_params): + """ + 运行真实算法(参考sh脚本配置) + + Args: + script_path: 算法脚本路径 + conda_env: Conda环境名称 + algorithm_code: 算法代码 + task_id: 任务ID + epsilon: 扰动强度 + input_dir: 输入目录 + output_dir: 输出目录 + class_dir: 类别数据目录 + custom_params: 自定义参数 + """ from config.algorithm_config import AlgorithmConfig + from app import db + from app.database import Perturbation, DataType logger.info(f"Running real algorithm: {algorithm_code}") logger.info(f"Conda environment: {conda_env}") logger.info(f"Script path: {script_path}") + # 从数据库获取数据集类型的提示词 + 
perturbation = Perturbation.query.get(task_id) + if not perturbation: + raise ValueError(f"Perturbation task {task_id} not found") + + data_type = DataType.query.get(perturbation.data_type_id) + if not data_type: + raise ValueError(f"DataType {perturbation.data_type_id} not found") + + # 从data_type_prompt中提取instance_prompt + instance_prompt = data_type.data_type_prompt or 'a photo of sks person' + # 从instance_prompt生成class_prompt(移除"sks") + class_prompt = instance_prompt.replace('sks ', '') + + logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'") + # 获取默认参数 default_params = AlgorithmConfig.get_default_params(algorithm_code) + # 覆盖提示词参数(从数据库读取) + default_params['instance_prompt'] = instance_prompt + if 'class_prompt' in default_params: + default_params['class_prompt'] = class_prompt + # 合并自定义参数 params = {**default_params, **(custom_params or {})} + # 根据算法构建命令参数(参考sh脚本) cmd_args = [] - if algorithm_code == 'aspl': - cmd_args.extend([ - f"--instance_data_dir_for_train={input_dir}", - f"--instance_data_dir_for_adversarial={input_dir}", - f"--output_dir={output_dir}", - f"--class_data_dir={class_dir}", - f"--pgd_eps={str(epsilon)}", - ]) - elif algorithm_code == 'simac': + + if algorithm_code in ['aspl', 'simac']: + # ASPL和SimAC使用相同的参数结构 cmd_args.extend([ f"--instance_data_dir_for_train={input_dir}", f"--instance_data_dir_for_adversarial={input_dir}", f"--output_dir={output_dir}", f"--class_data_dir={class_dir}", - f"--pgd_eps={str(epsilon)}", + f"--pgd_eps={epsilon}", ]) elif algorithm_code == 'caat': + # CAAT参数结构 cmd_args.extend([ f"--instance_data_dir={input_dir}", f"--output_dir={output_dir}", - f"--eps={str(epsilon)}", + f"--eps={epsilon}", ]) elif algorithm_code == 'pid': + # PID参数结构 cmd_args.extend([ f"--instance_data_dir={input_dir}", f"--output_dir={output_dir}", - f"--eps={str(epsilon)}", + f"--eps={epsilon}", ]) else: raise ValueError(f"Unsupported algorithm code: {algorithm_code}") - # 添加其他参数 + # 添加其他默认参数 for 
key, value in params.items(): if isinstance(value, bool): if value: @@ -172,6 +206,10 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id, else: cmd_args.append(f"--{key}={value}") + # 设置环境变量 + env = os.environ.copy() + env['HF_HUB_OFFLINE'] = '1' # 强制离线模式 + # 构建完整命令 cmd = [ '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', @@ -180,116 +218,12 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id, logger.info(f"Executing command: {' '.join(cmd)}") - # 设置日志文件 - log_dir = AlgorithmConfig.LOGS_DIR - os.makedirs(log_dir, exist_ok=True) - log_file = os.path.join(log_dir, f'task_{task_id}_{algorithm_code}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log') - - # 执行命令 - with open(log_file, 'w') as f: - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - bufsize=1, - universal_newlines=True - ) - - # 实时输出日志 - for line in process.stdout: - f.write(line) - f.flush() - logger.info(line.strip()) - - process.wait() - - logger.info(f"output_dir: {output_dir}") - logger.info(f"log_file: {log_file}") - - if process.returncode != 0: - raise RuntimeError(f"Algorithm execution failed with code {process.returncode}. 
Check log: {log_file}") - - # 清理class_dir - logger.info(f"Cleaning class directory: {class_dir}") - if os.path.exists(class_dir): - import shutil - shutil.rmtree(class_dir) - os.makedirs(class_dir) - - return { - 'status': 'success', - 'output_dir': output_dir, - 'log_file': log_file - } - - -def _run_virtual_algorithm(algorithm_code, task_id, epsilon, use_strong_protection, - input_dir, output_dir): - """运行虚拟算法实现""" - from config.algorithm_config import AlgorithmConfig - import shutil - - logger.info(f"Running virtual algorithm: {algorithm_code}") - - # 获取算法配置 - algo_config = AlgorithmConfig.PERTURBATION_SCRIPTS.get(algorithm_code) - if not algo_config: - raise ValueError(f"Algorithm {algorithm_code} not configured") - - conda_env = algo_config.get('conda_env') - default_params = algo_config.get('default_params', {}) - - # 获取虚拟算法脚本路径 - script_path = os.path.abspath(os.path.join( - os.path.dirname(__file__), - '../algorithms/perturbation_virtual', - f'{algorithm_code}.py' - )) - - if not os.path.exists(script_path): - raise FileNotFoundError(f"Virtual script not found: {script_path}") - - logger.info(f"Virtual script path: {script_path}") - logger.info(f"Conda environment: {conda_env}") - - # 确保输出目录存在 - os.makedirs(output_dir, exist_ok=True) - - # 构建命令行参数 - cmd_args = [ - f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}", - f"--instance_data_dir_for_train={input_dir}", - f"--instance_data_dir_for_adversarial={input_dir}", - f"--output_dir={output_dir}", - f"--class_data_dir=/tmp/class_placeholder", - f"--pgd_eps={epsilon}", - ] - - # 添加其他默认参数 - for key, value in default_params.items(): - if key == 'pretrained_model_name_or_path': - continue - if isinstance(value, bool): - if value: - cmd_args.append(f"--{key}") - else: - cmd_args.append(f"--{key}={value}") - - # 使用conda run执行虚拟脚本 - cmd = [ - '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', - 'python', script_path - ] + cmd_args - - 
logger.info(f"Executing command: {' '.join(cmd)}") - # 设置日志文件 log_dir = AlgorithmConfig.LOGS_DIR os.makedirs(log_dir, exist_ok=True) log_file = os.path.join( - log_dir, - f'virtual_{algorithm_code}_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' + log_dir, + f'task_{task_id}_{algorithm_code}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' ) # 执行命令 @@ -300,7 +234,8 @@ def _run_virtual_algorithm(algorithm_code, task_id, epsilon, use_strong_protecti stderr=subprocess.STDOUT, text=True, bufsize=1, - universal_newlines=True + universal_newlines=True, + env=env ) # 实时输出日志 @@ -311,23 +246,25 @@ def _run_virtual_algorithm(algorithm_code, task_id, epsilon, use_strong_protecti process.wait() + logger.info(f"Algorithm execution completed with return code: {process.returncode}") + logger.info(f"Output directory: {output_dir}") + logger.info(f"Log file: {log_file}") + if process.returncode != 0: - raise RuntimeError(f"Virtual algorithm failed with code {process.returncode}. Check log: {log_file}") - - # 统计处理的图片 - image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] - processed_files = [] - for ext in image_extensions: - processed_files.extend(glob.glob(os.path.join(output_dir, ext))) - processed_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) + raise RuntimeError( + f"Algorithm execution failed with code {process.returncode}. Check log: {log_file}" + ) - logger.info(f"Virtual algorithm completed. 
Processed {len(processed_files)} images") + # 清理class_dir + if algorithm_code in ['aspl', 'simac']: + logger.info(f"Cleaning class directory: {class_dir}") + if os.path.exists(class_dir): + shutil.rmtree(class_dir) + os.makedirs(class_dir) return { 'status': 'success', 'output_dir': output_dir, - 'processed_count': len(processed_files), - 'processed_files': processed_files, 'log_file': log_file } @@ -336,6 +273,10 @@ def _save_perturbed_images(task_id, output_dir): """ 保存扰动图片到数据库(适配新数据库结构) + 新数据库结构: + - Task表:tasks_id (主键), flow_id, tasks_type_id + - Image表:images_id (主键), task_id (外键), image_types_id, father_id + Args: task_id: 任务ID output_dir: 扰动图片输出目录 @@ -344,57 +285,63 @@ def _save_perturbed_images(task_id, output_dir): from app.database import Task, Image, ImageType try: + # 验证任务存在 task = Task.query.get(task_id) if not task: raise ValueError(f"Task {task_id} not found") - # 获取扰动图片类型 + # 获取图片类型 + original_type = ImageType.query.filter_by(image_code='original').first() perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() - if not perturbed_type: - raise ValueError("Image type 'perturbed' not found") - # 获取原始图片列表(同一个task_id下的原始图片) - original_type = ImageType.query.filter_by(image_code='original').first() + if not original_type or not perturbed_type: + raise ValueError("Required image types not found in database") + + # 获取该任务的所有原始图片(用于建立父子关系) original_images = Image.query.filter_by( task_id=task_id, image_types_id=original_type.image_types_id ).all() - # 创建原图映射字典: stored_filename -> Image对象 + # 创建原图文件名映射 original_map = {img.stored_filename: img for img in original_images} - # 查找输出目录中的扰动图片 - image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + logger.info(f"Found {len(original_images)} original images for task {task_id}") + + # 查找输出目录中的所有图片文件 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp', '*.tiff'] perturbed_files = [] for ext in image_extensions: perturbed_files.extend(glob.glob(os.path.join(output_dir, 
ext))) perturbed_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) - logger.info(f"Found {len(perturbed_files)} perturbed images to save") + logger.info(f"Found {len(perturbed_files)} perturbed images in output directory") saved_count = 0 for perturbed_path in perturbed_files: try: - # 获取文件名(不含路径) perturbed_filename = os.path.basename(perturbed_path) - # 尝试找到对应的原始图片 - # 假设扰动图片命名为: perturbed_{original_name}.ext - original_filename = perturbed_filename - if perturbed_filename.startswith('perturbed_'): - original_filename = perturbed_filename[len('perturbed_'):] + # 尝试匹配原始图片(建立父子关系) + # 算法可能输出同名文件或带前缀的文件 + father_image = None - original_image = original_map.get(original_filename) - if not original_image: - # 尝试完全匹配 - matching_images = [img for img in original_images if img.stored_filename == perturbed_filename] - if matching_images: - original_image = matching_images[0] - else: - logger.warning(f"Could not find original image for {perturbed_filename}") - # 即使找不到父图片也保存,但father_id设为None + # 策略1: 完全匹配文件名 + if perturbed_filename in original_map: + father_image = original_map[perturbed_filename] + else: + # 策略2: 移除可能的前缀(如perturbed_) + for prefix in ['perturbed_', 'adv_', 'protected_']: + if perturbed_filename.startswith(prefix): + clean_name = perturbed_filename[len(prefix):] + if clean_name in original_map: + father_image = original_map[clean_name] + break - # 检查是否已经保存过 + if not father_image: + logger.warning(f"Could not find father image for {perturbed_filename}, saving without father_id") + + # 检查是否已存在 existing = Image.query.filter_by( task_id=task_id, stored_filename=perturbed_filename, @@ -402,21 +349,22 @@ def _save_perturbed_images(task_id, output_dir): ).first() if existing: - logger.info(f"Perturbed image {perturbed_filename} already exists, skipping") + logger.info(f"Image {perturbed_filename} already exists, skipping") continue # 读取图片尺寸 + width, height = None, None try: with PILImage.open(perturbed_path) as img: width, height = img.size - except: - 
width, height = None, None + except Exception as e: + logger.warning(f"Could not read image dimensions for {perturbed_filename}: {e}") - # 保存到数据库(使用新结构) + # 保存到数据库 perturbed_image = Image( task_id=task_id, image_types_id=perturbed_type.image_types_id, - father_id=original_image.images_id if original_image else None, # 设置父图片关系 + father_id=father_image.images_id if father_image else None, stored_filename=perturbed_filename, file_path=perturbed_path, file_size=os.path.getsize(perturbed_path), @@ -426,15 +374,21 @@ def _save_perturbed_images(task_id, output_dir): db.session.add(perturbed_image) saved_count += 1 - logger.info(f"Saved perturbed image: {perturbed_filename} (father: {original_image.images_id if original_image else 'None'})") + + logger.info( + f"Saved: {perturbed_filename} " + f"(father: {father_image.stored_filename if father_image else 'None'})" + ) except Exception as e: - logger.error(f"Error saving perturbed image {perturbed_filename}: {str(e)}") + logger.error(f"Error saving {perturbed_filename}: {str(e)}") continue + # 提交所有更改 db.session.commit() - logger.info(f"Successfully saved {saved_count} perturbed images to database") + logger.info(f"Successfully saved {saved_count}/{len(perturbed_files)} perturbed images") except Exception as e: - logger.error(f"Error saving perturbed images: {str(e)}") + logger.error(f"Error in _save_perturbed_images: {str(e)}", exc_info=True) db.session.rollback() + raise -- 2.34.1 From db94e35efde05fbbe054b9cc221cf5ee5da01f68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Mon, 1 Dec 2025 03:45:38 +0800 Subject: [PATCH 11/14] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84services?= =?UTF-8?q?=E5=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/services/task_service.py | 1489 ++++++++++------------ src/backend/config/settings.py | 1 - 2 files changed, 649 insertions(+), 841 deletions(-) diff --git 
a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index dcea483..ae7d67e 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -1,840 +1,649 @@ -""" -任务处理服务 -处理图像加噪、评估等核心业务逻辑 -使用Redis Queue进行异步任务处理 -""" - -import os -from datetime import datetime -from flask import current_app -from redis import Redis -from rq import Queue -from rq.job import Job -from app import db -from app.database import Batch, Image, EvaluationResult, ImageType, FinetuneBatch -from config.algorithm_config import AlgorithmConfig - -class TaskService: - """任务处理服务""" - - @staticmethod - def _get_redis_connection(): - """获取Redis连接""" - return Redis.from_url(AlgorithmConfig.REDIS_URL) - - @staticmethod - def _get_queue(): - """获取RQ队列""" - redis_conn = TaskService._get_redis_connection() - return Queue(AlgorithmConfig.RQ_QUEUE_NAME, connection=redis_conn) - - @staticmethod - def start_processing(batch): - """ - 开始处理任务(异步) - - Args: - batch: Batch对象 - - Returns: - 任务ID (RQ job id) - """ - try: - # 检查是否有原始图片 - original_images = Image.query.filter_by( - batch_id=batch.id - ).join(ImageType).filter( - ImageType.type_code == 'original' - ).all() - - if not original_images: - batch.status = 'failed' - batch.error_message = '没有找到原始图片' - batch.completed_at = datetime.utcnow() - db.session.commit() - return None - - # 准备任务参数 - project_root = os.path.dirname(current_app.root_path) - - # 输入目录(原始图片) - input_dir = os.path.join( - project_root, - current_app.config['ORIGINAL_IMAGES_FOLDER'], - str(batch.user_id), - str(batch.id) - ) - - # 输出目录(扰动后图片) - output_dir = os.path.join( - project_root, - current_app.config['PERTURBED_IMAGES_FOLDER'], - str(batch.user_id), - str(batch.id) - ) - - # 类别图片目录(用于prior preservation) - class_dir = os.path.join( - project_root, - 'static', 'class', - str(batch.user_id), - str(batch.id) - ) - - # 获取队列 - queue = TaskService._get_queue() - - # 提交任务到队列 - from app.workers.perturbation_worker import 
run_perturbation_task - - job = queue.enqueue( - run_perturbation_task, - batch_id=batch.id, - algorithm_code=batch.perturbation_config.method_code, - epsilon=int(batch.preferred_epsilon), - use_strong_protection=batch.use_strong_protection, - input_dir=input_dir, - output_dir=output_dir, - class_dir=class_dir, - custom_params=None, - job_timeout=AlgorithmConfig.TASK_TIMEOUT, - job_id=f"batch_{batch.id}" - ) - - # 更新任务状态 - batch.status = 'queued' - db.session.commit() - - return job.id - - except Exception as e: - # 处理失败 - batch.status = 'failed' - batch.error_message = str(e) - batch.completed_at = datetime.utcnow() - db.session.commit() - return None - - @staticmethod - def get_task_status(batch_id): - """ - 获取任务状态 - - Args: - batch_id: 批次ID - - Returns: - 任务状态信息 - """ - try: - batch = Batch.query.get(batch_id) - if not batch: - return {'status': 'not_found'} - - # 如果任务已完成或失败,直接返回数据库状态 - if batch.status in ['completed', 'failed']: - return { - 'status': batch.status, - 'error': batch.error_message if batch.status == 'failed' else None, - 'started_at': batch.started_at, - 'completed_at': batch.completed_at - } - - # 尝试从RQ获取任务状态 - try: - redis_conn = TaskService._get_redis_connection() - job = Job.fetch(f"batch_{batch_id}", connection=redis_conn) - - rq_status = job.get_status() - - # 映射RQ状态到我们的状态 - status_map = { - 'queued': 'queued', - 'started': 'processing', - 'finished': 'completed', - 'failed': 'failed' - } - - return { - 'status': status_map.get(rq_status, batch.status), - 'rq_status': rq_status, - 'progress': job.meta.get('progress', 0) if hasattr(job, 'meta') else 0, - 'started_at': batch.started_at, - 'result': job.result if rq_status == 'finished' else None, - 'error': str(job.exc_info) if rq_status == 'failed' else None - } - except: - # 如果无法从RQ获取状态,返回数据库状态 - return { - 'status': batch.status, - 'started_at': batch.started_at - } - - except Exception as e: - return {'status': 'error', 'error': str(e)} - - @staticmethod - def cancel_task(batch_id): - """ 
- 取消任务 - - Args: - batch_id: 批次ID - - Returns: - 是否成功取消 - """ - try: - batch = Batch.query.get(batch_id) - if not batch: - return False - - # 尝试从队列中删除任务 - try: - redis_conn = TaskService._get_redis_connection() - job = Job.fetch(f"batch_{batch_id}", connection=redis_conn) - job.cancel() - except: - pass - - # 更新数据库状态 - batch.status = 'failed' - batch.error_message = 'Task cancelled by user' - batch.completed_at = datetime.utcnow() - db.session.commit() - - return True - - except Exception as e: - print(f"取消任务时出错: {str(e)}") - return False - - @staticmethod - def process_results_and_evaluations(batch_id): - """ - 处理任务结果并生成评估(在worker完成后调用) - - Args: - batch_id: 批次ID - """ - try: - batch = Batch.query.get(batch_id) - if not batch: - return - - # 获取输出目录中的图片 - project_root = os.path.dirname(current_app.root_path) - output_dir = os.path.join( - project_root, - current_app.config['PERTURBED_IMAGES_FOLDER'], - str(batch.user_id), - str(batch.id) - ) - - # 获取原始图片 - original_images = Image.query.filter_by( - batch_id=batch.id - ).join(ImageType).filter( - ImageType.type_code == 'original' - ).all() - - perturbed_type = ImageType.query.filter_by(type_code='perturbed').first() - - processed_images = [] - - # 为每张原始图片找到对应的扰动图片 - for original_image in original_images: - # 构建扰动图片路径 - original_name = os.path.splitext(original_image.original_filename)[0] - original_ext = os.path.splitext(original_image.original_filename)[1] - - perturbed_filename = f"{original_name}_perturbed{original_ext}" - perturbed_path = os.path.join(output_dir, perturbed_filename) - - if os.path.exists(perturbed_path): - # 保存扰动图片到数据库 - perturbed_image = Image( - user_id=batch.user_id, - batch_id=batch.id, - father_id=original_image.id, - original_filename=f"perturbed_{original_image.original_filename}", - stored_filename=os.path.basename(perturbed_path), - file_path=perturbed_path, - file_size=os.path.getsize(perturbed_path), - image_type_id=perturbed_type.id, - width=original_image.width, - 
height=original_image.height - ) - - db.session.add(perturbed_image) - processed_images.append((original_image, perturbed_image)) - - db.session.commit() - - # 生成评估结果 - TaskService._generate_evaluations(batch, processed_images) - - except Exception as e: - print(f"处理结果时出错: {str(e)}") @staticmethod - def _generate_evaluations(batch, processed_images): - """生成评估结果(虚拟实现)""" - try: - for original_image, perturbed_image in processed_images: - # TODO: 实现真实的评估引擎 - # 目前使用虚拟数据 - import random - - # 图像质量对比评估(虚拟数据) - quality_evaluation = EvaluationResult( - reference_image_id=original_image.id, - target_image_id=perturbed_image.id, - evaluation_type='image_quality', - purification_applied=False, - fid_score=round(random.uniform(0.1, 0.5), 4), - lpips_score=round(random.uniform(0.01, 0.1), 4), - ssim_score=round(random.uniform(0.85, 0.99), 4), - psnr_score=round(random.uniform(30, 45), 2), - heatmap_path=None - ) - - db.session.add(quality_evaluation) - - # 模型生成对比评估(虚拟数据) - generation_evaluation = EvaluationResult( - reference_image_id=original_image.id, - target_image_id=perturbed_image.id, - evaluation_type='model_generation', - purification_applied=False, - fid_score=round(random.uniform(0.2, 0.8), 4), - lpips_score=round(random.uniform(0.05, 0.2), 4), - ssim_score=round(random.uniform(0.7, 0.9), 4), - psnr_score=round(random.uniform(25, 40), 2), - heatmap_path=None - ) - - db.session.add(generation_evaluation) - - db.session.commit() - - except Exception as e: - print(f"生成评估结果时出错: {str(e)}") - - @staticmethod - def get_processing_progress(batch_id): - """获取处理进度""" - try: - batch = Batch.query.get(batch_id) - if not batch: - return 0 - - if batch.status == 'pending': - return 0 - elif batch.status == 'completed': - return 100 - elif batch.status == 'failed': - return 0 - elif batch.status == 'processing': - # 简单的进度计算:根据已处理的图片数量 - total_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter( - ImageType.type_code == 'original' - ).count() - - 
processed_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter( - ImageType.type_code == 'perturbed' - ).count() - - if total_images == 0: - return 0 - - progress = int((processed_images / total_images) * 80) # 80%用于图像处理,20%用于评估 - return min(progress, 95) # 最多95%,剩余5%用于最终完成 - - return 0 - - except Exception as e: - print(f"获取处理进度时出错: {str(e)}") - return 0 - - @staticmethod - def start_finetune_task(finetune_task): - """ - 启动微调任务(使用 FinetuneBatch) - - Args: - finetune_task: FinetuneBatch对象 - - Returns: - 包含两个job_id的字典 - """ - try: - # 获取关联的扰动任务 - batch = finetune_task.batch - if not batch: - print(f"FinetuneBatch {finetune_task.id} 没有关联的扰动任务") - return None - - # 检查是否有扰动图片 - perturbed_images = Image.query.filter_by( - batch_id=batch.id - ).join(ImageType).filter( - ImageType.type_code == 'perturbed' - ).all() - - if not perturbed_images: - print(f"Batch {batch.id} 没有扰动图片,无法启动微调任务") - finetune_task.status = 'failed' - finetune_task.error_message = '没有找到扰动图片' - db.session.commit() - return None - - project_root = os.path.dirname(current_app.root_path) - finetune_method = finetune_task.finetune_config.method_code - queue = TaskService._get_queue() - - from app.workers.finetune_worker import run_finetune_task - - # 原始图片目录 - original_dir = os.path.join( - project_root, - current_app.config['ORIGINAL_IMAGES_FOLDER'], - str(batch.user_id), - str(batch.id) - ) - - # 扰动图片目录 - perturbed_dir = os.path.join( - project_root, - current_app.config['PERTURBED_IMAGES_FOLDER'], - str(batch.user_id), - str(batch.id) - ) - - # 模型输出目录 - original_model_dir = os.path.join( - project_root, - current_app.config['MODEL_ORIGINAL_FOLDER'], - str(batch.user_id), str(batch.id) - ) - - perturbed_model_dir = os.path.join( - project_root, - current_app.config['MODEL_PERTURBED_FOLDER'], - str(batch.user_id), str(batch.id) - ) - - # 类别图片目录(Prior Preservation) - class_finetune_dir = os.path.join( - project_root, - current_app.config['CLASS_DATA_FOLDER'], - str(batch.user_id), 
str(batch.id) - ) - - # 坐标可视化保存目录(训练轨迹) - coords_save_dir = os.path.join( - project_root, - current_app.config['COORDS_SAVE_FOLDER'], - str(batch.user_id), str(batch.id) - ) - - # 验证图片输出目录(分别对应 clean 和 perturbed) - validation_original_dir = os.path.join( - project_root, - current_app.config['MODEL_ORIGINAL_FOLDER'], - str(batch.user_id), str(batch.id) - ) - - validation_perturbed_dir = os.path.join( - project_root, - current_app.config['MODEL_PERTURBED_FOLDER'], - str(batch.user_id), str(batch.id) - ) - - # 推理提示词 - inference_prompts = "a photo of sks person" - - # 1. 用原始图片微调模型 - job_original = queue.enqueue( - run_finetune_task, - finetune_batch_id=finetune_task.id, - batch_id=batch.id, - finetune_method=finetune_method, - train_images_dir=original_dir, - output_model_dir=original_model_dir, - class_dir=class_finetune_dir, - coords_save_path=coords_save_dir, - validation_output_dir=validation_original_dir, - inference_prompts=inference_prompts, - is_perturbed=False, - custom_params=None, - job_timeout=AlgorithmConfig.TASK_TIMEOUT, - job_id=f"finetune_original_{finetune_task.id}" - ) - - # 2. 
用扰动图片微调模型(依赖于原始图片微调完成) - job_perturbed = queue.enqueue( - run_finetune_task, - finetune_batch_id=finetune_task.id, - batch_id=batch.id, - finetune_method=finetune_method, - train_images_dir=perturbed_dir, - output_model_dir=perturbed_model_dir, - class_dir=class_finetune_dir, - coords_save_path=coords_save_dir, - validation_output_dir=validation_perturbed_dir, - inference_prompts=inference_prompts, - is_perturbed=True, - custom_params=None, - job_timeout=AlgorithmConfig.TASK_TIMEOUT, - job_id=f"finetune_perturbed_{finetune_task.id}", - depends_on=job_original - ) - - # 更新微调任务状态 - finetune_task.status = 'queued' - finetune_task.original_job_id = job_original.id - finetune_task.perturbed_job_id = job_perturbed.id - finetune_task.started_at = datetime.utcnow() - db.session.commit() - - return { - 'original_job_id': job_original.id, - 'perturbed_job_id': job_perturbed.id - } - - except Exception as e: - print(f"启动微调任务时出错: {str(e)}") - finetune_task.status = 'failed' - finetune_task.error_message = str(e) - finetune_task.completed_at = datetime.utcnow() - db.session.commit() - return None - - @staticmethod - def get_finetune_task_status(finetune_id): - """ - 获取微调任务状态(使用 FinetuneBatch) - 同时会检查并更新任务状态 - - Args: - finetune_id: 微调任务ID - - Returns: - 微调任务状态信息 - """ - try: - finetune_task = FinetuneBatch.query.get(finetune_id) - if not finetune_task: - return {'status': 'not_found'} - - # 如果任务不是最终状态,检查并更新状态 - if finetune_task.status not in ['completed', 'failed']: - from app.workers.finetune_worker import _check_and_update_finetune_status - _check_and_update_finetune_status(finetune_task) - # 刷新对象以获取最新状态 - db.session.refresh(finetune_task) - - # 如果任务已完成或失败,直接返回数据库状态 - if finetune_task.status in ['completed', 'failed']: - return { - 'status': finetune_task.status, - 'error': finetune_task.error_message if finetune_task.status == 'failed' else None, - 'started_at': finetune_task.started_at.isoformat() if finetune_task.started_at else None, - 'completed_at': 
finetune_task.completed_at.isoformat() if finetune_task.completed_at else None - } - - # 从RQ获取任务状态 - redis_conn = TaskService._get_redis_connection() - - original_job_status = 'not_found' - perturbed_job_status = 'not_found' - - try: - if finetune_task.original_job_id: - original_job = Job.fetch(finetune_task.original_job_id, connection=redis_conn) - original_job_status = original_job.get_status() - except: - pass - - try: - if finetune_task.perturbed_job_id: - perturbed_job = Job.fetch(finetune_task.perturbed_job_id, connection=redis_conn) - perturbed_job_status = perturbed_job.get_status() - except: - pass - - # 映射状态 - status_map = { - 'queued': 'queued', - 'started': 'processing', - 'finished': 'completed', - 'failed': 'failed', - 'not_found': 'not_started' - } - - return { - 'status': finetune_task.status, - 'original_finetune': status_map.get(original_job_status, 'unknown'), - 'perturbed_finetune': status_map.get(perturbed_job_status, 'unknown'), - 'started_at': finetune_task.started_at.isoformat() if finetune_task.started_at else None - } - - except Exception as e: - print(f"获取微调任务状态时出错: {str(e)}") - return {'status': 'error', 'error': str(e)} - - @staticmethod - def generate_final_evaluations(batch_id): - """ - 生成最终评估(对比原始和扰动图片微调后的模型生成效果) - - 此方法在两个微调任务都完成后调用 - - Args: - batch_id: 批次ID - """ - try: - batch = Batch.query.get(batch_id) - if not batch: - return - - # 获取原始图片生成的结果 - original_generated = Image.query.filter_by( - batch_id=batch_id - ).join(ImageType).filter( - ImageType.type_code == 'original_generate' - ).all() - - # 获取扰动图片生成的结果 - perturbed_generated = Image.query.filter_by( - batch_id=batch_id - ).join(ImageType).filter( - ImageType.type_code == 'perturbed_generate' - ).all() - - if not original_generated or not perturbed_generated: - print(f"Batch {batch_id} 缺少生成的图片,无法评估") - return - - # 配对评估 - for orig_gen in original_generated: - # 找到对应的扰动生成图片(基于相同的父图片) - matching_pert_gen = None - for pert_gen in perturbed_generated: - # 尝试匹配文件名或父图片关系 - if 
pert_gen.original_filename.replace('generated_', '') == orig_gen.original_filename.replace('generated_', ''): - matching_pert_gen = pert_gen - break - - if matching_pert_gen: - # TODO: 实现真实的评估引擎 - # 目前使用虚拟数据 - import random - - # 保存评估结果(虚拟数据) - generation_evaluation = EvaluationResult( - reference_image_id=orig_gen.id, - target_image_id=matching_pert_gen.id, - evaluation_type='model_generation', - purification_applied=False, - fid_score=round(random.uniform(0.3, 0.9), 4), - lpips_score=round(random.uniform(0.1, 0.3), 4), - ssim_score=round(random.uniform(0.6, 0.85), 4), - psnr_score=round(random.uniform(20, 35), 2), - heatmap_path=None - ) - - db.session.add(generation_evaluation) - - db.session.commit() - print(f"Batch {batch_id} 最终评估完成") - - except Exception as e: - print(f"生成最终评估时出错: {str(e)}") - db.session.rollback() - - @staticmethod - def start_heatmap_task(heatmap_task, original_image_id, perturbed_image_id): - """ - 启动热力图生成任务 - - Args: - heatmap_task: Heatmap对象 - original_image_id: 原始图片ID(前端选择) - perturbed_image_id: 扰动图片ID(前端选择) - - Returns: - job_id - """ - try: - # 获取关联的主任务 - task = heatmap_task.task - if not task: - print(f"Heatmap task {heatmap_task.tasks_id} has no associated Task") - return None - - # 获取图片信息 - from app.database import Image - original_image = Image.query.get(original_image_id) - perturbed_image = Image.query.get(perturbed_image_id) - - if not original_image or not perturbed_image: - print("Selected images not found") - return None - - # 获取Prompt文本(从关联的Perturbation任务的数据集类型) - from app.database import Perturbation, DataType - perturbation = Perturbation.query.filter_by(tasks_id=task.tasks_id).first() - if perturbation and perturbation.data_type: - prompt_text = perturbation.data_type.data_type_prompt - # 从prompt中提取target_word(简单提取最后一个词) - target_word = prompt_text.split()[-1] if prompt_text else 'person' - else: - prompt_text = "a photo of sks person" - target_word = "person" - - project_root = os.path.dirname(current_app.root_path) - - 
# 输出目录 - output_dir = os.path.join( - project_root, - current_app.config['HEATDIF_SAVE_FOLDER'], - str(task.user_id), - str(task.tasks_id) - ) - - # 模型路径(从配置文件获取) - from config.algorithm_config import AlgorithmConfig - model_path = AlgorithmConfig.MODELS_DIR.get('model2') # 默认使用SD 1.5 - - # 获取队列 - queue = TaskService._get_queue() - - from app.workers.heatmap_worker import run_heatmap_task - - # 提交任务到队列 - job = queue.enqueue( - run_heatmap_task, - heatmap_id=heatmap_task.tasks_id, - task_id=task.tasks_id, - original_image_path=original_image.file_path, - perturbed_image_path=perturbed_image.file_path, - prompt_text=prompt_text, - target_word=target_word, - output_dir=output_dir, - model_path=model_path, - job_timeout=AlgorithmConfig.TASK_TIMEOUT, - job_id=f"heatmap_{heatmap_task.tasks_id}" - ) - - # 更新任务状态为queued - from app.database import TaskStatus - queued_status = TaskStatus.query.filter_by(task_status_code='waiting').first() - if queued_status: - task.tasks_status_id = queued_status.task_status_id - db.session.commit() - - return job.id - - except Exception as e: - print(f"启动热力图任务时出错: {str(e)}") - return None - - @staticmethod - def start_evaluate_task(evaluate_task): - """ - 启动数值评估任务 - - Args: - evaluate_task: Evaluate对象 - - Returns: - job_id - """ - try: - # 获取关联的主任务 - task = evaluate_task.task - if not task: - print(f"Evaluate task for Task {evaluate_task.tasks_id} has no associated Task") - return None - - # 获取关联的Finetune任务,以确定微调方法 - from app.database import Finetune - finetune = Finetune.query.filter_by( - tasks_id=evaluate_task.tasks_id, - finetune_configs_id=evaluate_task.finetune_configs_id - ).first() - - if not finetune: - print(f"No finetune task found for Evaluate task") - return None - - project_root = os.path.dirname(current_app.root_path) - - # 参考图片目录(原始上传的图片) - clean_ref_dir = os.path.join( - project_root, - current_app.config['ORIGINAL_IMAGES_FOLDER'], - str(task.user_id), - str(task.tasks_id) - ) - - # Clean输出目录(原始图训练后的生成结果) - clean_output_dir 
= os.path.join( - project_root, - current_app.config['MODEL_OUTPUTS_FOLDER'], - 'clean', - str(task.user_id), - str(task.tasks_id) - ) - - # Perturbed输出目录(扰动图训练后的生成结果) - perturbed_output_dir = os.path.join( - project_root, - current_app.config['MODEL_OUTPUTS_FOLDER'], - 'perturbed', - str(task.user_id), - str(task.tasks_id) - ) - - # 评估结果输出目录 - output_dir = os.path.join( - project_root, - current_app.config['NUMBERS_SAVE_FOLDER'], - str(task.user_id), - str(task.tasks_id) - ) - - # 获取队列 - queue = TaskService._get_queue() - - from app.workers.evaluate_worker import run_evaluate_task - from config.algorithm_config import AlgorithmConfig - - # 提交任务到队列 - job = queue.enqueue( - run_evaluate_task, - evaluate_id=evaluate_task.tasks_id, - task_id=task.tasks_id, - clean_ref_dir=clean_ref_dir, - clean_output_dir=clean_output_dir, - perturbed_output_dir=perturbed_output_dir, - output_dir=output_dir, - image_size=512, - job_timeout=AlgorithmConfig.TASK_TIMEOUT, - job_id=f"evaluate_{evaluate_task.tasks_id}_{evaluate_task.finetune_configs_id}" - ) - - # 更新任务状态为queued - from app.database import TaskStatus - queued_status = TaskStatus.query.filter_by(task_status_code='waiting').first() - if queued_status: - task.tasks_status_id = queued_status.task_status_id - db.session.commit() - - return job.id - - except Exception as e: - print(f"启动评估任务时出错: {str(e)}") - return None - - +""" +任务处理服务(适配新数据库结构和路径配置) +处理加噪、微调、热力图、评估等核心业务逻辑 +使用Redis Queue进行异步任务处理 +""" + +import os +import logging +from datetime import datetime +from flask import current_app +from redis import Redis +from rq import Queue +from rq.job import Job +from app import db +from app.database import ( + Task, TaskStatus, TaskType, + Perturbation, Finetune, Heatmap, Evaluate, + Image, ImageType, DataType, + PerturbationConfig, FinetuneConfig +) +from config.algorithm_config import AlgorithmConfig +from config.settings import Config + +logger = logging.getLogger(__name__) + + +class TaskService: + """任务处理服务""" + + # 
==================== 路径工具函数 ==================== + + @staticmethod + def _get_project_root(): + """获取项目根目录""" + return os.path.dirname(current_app.root_path) + + @staticmethod + def _build_path(*parts): + """构建路径""" + return os.path.join(TaskService._get_project_root(), *parts) + + @staticmethod + def get_original_images_path(user_id, flow_id): + """原图路径: ORIGINAL_IMAGES_FOLDER/user_id/flow_id""" + return TaskService._build_path( + Config.ORIGINAL_IMAGES_FOLDER, + str(user_id), + str(flow_id) + ) + + @staticmethod + def get_perturbed_images_path(user_id, flow_id): + """加噪图路径: PERTURBED_IMAGES_FOLDER/user_id/flow_id""" + return TaskService._build_path( + Config.PERTURBED_IMAGES_FOLDER, + str(user_id), + str(flow_id) + ) + + @staticmethod + def get_original_generated_path(user_id, flow_id, task_id): + """原图生成图路径: MODEL_ORIGINAL_FOLDER/user_id/flow_id/task_id""" + return TaskService._build_path( + Config.MODEL_ORIGINAL_FOLDER, + str(user_id), + str(flow_id), + str(task_id) + ) + + @staticmethod + def get_perturbed_generated_path(user_id, flow_id, task_id): + """加噪图生成图路径: MODEL_PERTURBED_FOLDER/user_id/flow_id/task_id""" + return TaskService._build_path( + Config.MODEL_PERTURBED_FOLDER, + str(user_id), + str(flow_id), + str(task_id) + ) + + @staticmethod + def get_uploaded_generated_path(user_id, flow_id, task_id): + """上传图生成图路径: MODEL_UPLOADED_FOLDER/user_id/flow_id/task_id""" + return TaskService._build_path( + Config.MODEL_UPLOADED_FOLDER, + str(user_id), + str(flow_id), + str(task_id) + ) + + @staticmethod + def get_heatmap_path(user_id, flow_id, task_id): + """热力图路径: HEATDIF_SAVE_FOLDER/user_id/flow_id/task_id""" + return TaskService._build_path( + Config.HEATDIF_SAVE_FOLDER, + str(user_id), + str(flow_id), + str(task_id) + ) + + @staticmethod + def get_evaluate_path(user_id, flow_id, task_id): + """数值结果路径: NUMBERS_SAVE_FOLDER/user_id/flow_id/task_id""" + return TaskService._build_path( + Config.NUMBERS_SAVE_FOLDER, + str(user_id), + str(flow_id), + str(task_id) + 
) + + @staticmethod + def get_class_data_path(user_id, flow_id): + """类别数据路径: CLASS_DATA_FOLDER/user_id/flow_id""" + return TaskService._build_path( + Config.CLASS_DATA_FOLDER, + str(user_id), + str(flow_id) + ) + + # ==================== Redis/RQ 连接管理 ==================== + + @staticmethod + def _get_redis_connection(): + """获取Redis连接""" + return Redis.from_url(AlgorithmConfig.REDIS_URL) + + @staticmethod + def _get_queue(): + """获取RQ队列""" + redis_conn = TaskService._get_redis_connection() + return Queue(AlgorithmConfig.RQ_QUEUE_NAME, connection=redis_conn) + + # ==================== 通用任务状态查询 ==================== + + @staticmethod + def get_task_status(task_id): + """ + 获取任务状态(通用查询,适用于所有类型任务) + + Args: + task_id: 任务ID + + Returns: + 任务状态信息 + """ + try: + task = Task.query.get(task_id) + if not task: + return {'status': 'not_found', 'error': 'Task not found'} + + # 获取任务状态名称 + status = TaskStatus.query.get(task.tasks_status_id) + status_code = status.task_status_code if status else 'unknown' + + # 获取任务类型 + task_type = TaskType.query.get(task.tasks_type_id) + type_code = task_type.task_type_code if task_type else 'unknown' + + result = { + 'task_id': task_id, + 'type': type_code, + 'status': status_code, + 'flow_id': task.flow_id, + 'created_at': task.created_at.isoformat() if task.created_at else None, + 'started_at': task.started_at.isoformat() if task.started_at else None, + 'completed_at': task.completed_at.isoformat() if task.completed_at else None + } + + # 如果任务失败,尝试从RQ获取错误信息 + if status_code == 'failed': + try: + redis_conn = TaskService._get_redis_connection() + job_id = TaskService._get_job_id_for_task(task_id, type_code) + if job_id: + job = Job.fetch(job_id, connection=redis_conn) + if job.is_failed: + result['error'] = str(job.exc_info) if job.exc_info else 'Unknown error' + except: + pass + + return result + + except Exception as e: + logger.error(f"Error getting task status: {e}") + return {'status': 'error', 'error': str(e)} + + @staticmethod + def 
_get_job_id_for_task(task_id, task_type): + """根据任务类型生成job_id""" + type_prefix = { + 'perturbation': 'pert', + 'finetune': 'ft', + 'heatmap': 'hm', + 'evaluate': 'eval' + } + prefix = type_prefix.get(task_type, 'task') + return f"{prefix}_{task_id}" + + @staticmethod + def cancel_task(task_id): + """ + 取消任务(通用取消,适用于所有类型任务) + + Args: + task_id: 任务ID + + Returns: + 是否成功取消 + """ + try: + task = Task.query.get(task_id) + if not task: + return False + + # 获取任务类型 + task_type = TaskType.query.get(task.tasks_type_id) + type_code = task_type.task_type_code if task_type else None + + # 尝试从队列中删除任务 + try: + redis_conn = TaskService._get_redis_connection() + job_id = TaskService._get_job_id_for_task(task_id, type_code) + job = Job.fetch(job_id, connection=redis_conn) + job.cancel() + except Exception as e: + logger.warning(f"Could not cancel RQ job: {e}") + + # 更新数据库状态 + failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() + if failed_status: + task.tasks_status_id = failed_status.task_status_id + task.completed_at = datetime.utcnow() + db.session.commit() + + return True + + except Exception as e: + logger.error(f"Error cancelling task: {e}") + return False + + # ==================== Perturbation 任务 ==================== + + @staticmethod + def start_perturbation_task(task_id): + """ + 启动加噪任务 + + Args: + task_id: 任务ID + + Returns: + job_id + """ + try: + # 获取任务 + task = Task.query.get(task_id) + if not task: + logger.error(f"Task {task_id} not found") + return None + + # 获取Perturbation任务详情 + perturbation = Perturbation.query.get(task_id) + if not perturbation: + logger.error(f"Perturbation task {task_id} not found") + return None + + # 获取用户ID + user_id = task.user_id + + # 路径配置 + input_dir = TaskService.get_original_images_path(user_id, task.flow_id) + output_dir = TaskService.get_perturbed_images_path(user_id, task.flow_id) + class_dir = TaskService.get_class_data_path(user_id, task.flow_id) + + # 获取算法配置 + pert_config = 
PerturbationConfig.query.get(perturbation.perturbation_configs_id) + if not pert_config: + logger.error(f"Perturbation config not found") + return None + + algorithm_code = pert_config.perturbation_algorithm_code + + # 加入RQ队列 + from app.workers.perturbation_worker import run_perturbation_task + + queue = TaskService._get_queue() + job_id = f"pert_{task_id}" + + job = queue.enqueue( + run_perturbation_task, + task_id=task_id, + input_dir=input_dir, + output_dir=output_dir, + class_dir=class_dir, + algorithm_code=algorithm_code, + epsilon=pert_config.epsilon, + job_id=job_id, + job_timeout='4h' + ) + + logger.info(f"Perturbation task {task_id} enqueued with job_id {job_id}") + return job_id + + except Exception as e: + logger.error(f"Error starting perturbation task: {e}") + return None + + # ==================== Finetune 任务 ==================== + + @staticmethod + def start_finetune_task(task_id): + """ + 启动微调任务(支持两种类型) + + 类型1:基于加噪结果的微调 + - 有相同flow_id的Perturbation任务 + - 输入:原图 + 加噪图 + - 输出到:original_generated 和 perturbed_generated + + 类型2:用户上传图片的微调 + - 找不到相同flow_id的其他任务 + - 输入:仅原图 + - 输出到:uploaded_generated + + Args: + task_id: 任务ID + + Returns: + job_id + """ + try: + # 获取任务 + task = Task.query.get(task_id) + if not task: + logger.error(f"Task {task_id} not found") + return None + + # 获取Finetune任务详情 + finetune = Finetune.query.get(task_id) + if not finetune: + logger.error(f"Finetune task {task_id} not found") + return None + + # 获取用户ID + user_id = task.user_id + + # 获取微调配置 + ft_config = FinetuneConfig.query.get(finetune.finetune_configs_id) + if not ft_config: + logger.error(f"Finetune config not found") + return None + + # 检测微调类型:查找相同flow_id的Perturbation任务 + perturbation_tasks = Task.query.filter( + Task.flow_id == task.flow_id, + Task.tasks_type_id == 1, # perturbation类型 + Task.tasks_id != task_id + ).all() + + has_perturbation = len(perturbation_tasks) > 0 + + # 路径配置 + input_dir = TaskService.get_original_images_path(user_id, task.flow_id) + class_dir = 
TaskService.get_class_data_path(user_id, task.flow_id) + + if has_perturbation: + # 类型1:基于加噪结果的微调 + logger.info(f"Finetune task {task_id}: type=perturbation-based") + + perturbed_input_dir = TaskService.get_perturbed_images_path(user_id, task.flow_id) + original_input_dir = TaskService.get_original_images_path(user_id, task.flow_id) + perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, task.flow_id, task_id) + original_output_dir = TaskService.get_original_generated_path(user_id, task.flow_id, task_id) + + # 获取坐标保存路径(3D可视化) + coords_save_path = TaskService._build_path( + Config.COORDS_SAVE_FOLDER, + str(user_id), + str(task.flow_id), + str(task_id), + 'coords.json' + ) + + # 加入RQ队列 + from app.workers.finetune_worker import run_finetune_task + + queue = TaskService._get_queue() + job_id = f"ft_{task_id}" + + job_original = queue.enqueue( + run_finetune_task, + task_id=task_id, + finetune_method=ft_config.finetune_method, + tranin_images_dir=original_input_dir, + output_model_dir=original_output_dir, + class_dir=class_dir, + coords_save_path=coords_save_path, + validation_images_dir=original_output_dir, + is_perturbed=False, + custom_params=None, + job_id=job_id, + job_timeout='8h' + ) + + job_perturbed = queue.enqueue( + run_finetune_task, + task_id=task_id, + finetune_method=ft_config.finetune_method, + tranin_images_dir=perturbed_input_dir, + output_model_dir=perturbed_output_dir, + class_dir=class_dir, + coords_save_path=coords_save_path, + validation_images_dir=perturbed_output_dir, + is_perturbed=True, + custom_params=None, + job_id=job_id, + job_timeout='8h' + ) + + else: + # 类型2:用户上传图片的微调 + logger.info(f"Finetune task {task_id}: type=uploaded") + + uploaded_output_dir = TaskService.get_uploaded_generated_path(user_id, task.flow_id, task_id) + + # 获取坐标保存路径 + coords_save_path = TaskService._build_path( + Config.COORDS_SAVE_FOLDER, + str(user_id), + str(task.flow_id), + str(task_id), + 'coords.json' + ) + + # 加入RQ队列 + from 
app.workers.finetune_worker import run_finetune_task + + queue = TaskService._get_queue() + job_id = f"ft_{task_id}" + + job = queue.enqueue( + run_finetune_task, + task_id=task_id, + finetune_method=ft_config.finetune_method, + tranin_images_dir=input_dir, + output_model_dir=uploaded_output_dir, + class_dir=class_dir, + coords_save_path=coords_save_path, + validation_images_dir=uploaded_output_dir, + is_perturbed=False, + custom_params=None, + job_id=job_id, + job_timeout='8h' + ) + + logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}") + return job_id + + except Exception as e: + logger.error(f"Error starting finetune task: {e}") + return None + + # ==================== Heatmap 任务 ==================== + + @staticmethod + def start_heatmap_task(task_id, perturbed_image_id): + """ + 启动热力图任务 + + Args: + task_id: 任务ID + perturbed_image_id: 扰动图片ID + + Returns: + job_id + """ + try: + # 获取任务 + task = Task.query.get(task_id) + if not task: + logger.error(f"Task {task_id} not found") + return None + + # 获取Heatmap任务详情 + heatmap = Heatmap.query.get(task_id) + if not heatmap: + logger.error(f"Heatmap task {task_id} not found") + return None + + # 获取扰动图片信息 + perturbed_image = Image.query.get(perturbed_image_id) + if not perturbed_image: + logger.error(f"Perturbed image {perturbed_image_id} not found") + return None + + user_id = perturbed_image.user_id + + # 获取原图(通过father_id关系) + if not perturbed_image.father_id: + logger.error(f"Perturbed image {perturbed_image_id} has no father_id") + return None + + original_image = Image.query.get(perturbed_image.father_id) + if not original_image: + logger.error(f"Original image not found") + return None + + # 构建图片路径 + original_image_path = TaskService._build_path( + Config.ORIGINAL_IMAGES_FOLDER, + str(user_id), + str(task.flow_id), + original_image.image_name + ) + + perturbed_image_path = TaskService._build_path( + Config.PERTURBED_IMAGES_FOLDER, + str(user_id), + str(task.flow_id), + perturbed_image.image_name + ) + 
+ # 输出目录 + output_dir = TaskService.get_heatmap_path(user_id, task.flow_id, task_id) + + # 获取模型路径 + sd_version = AlgorithmConfig.STABLE_DIFFUSION_VERSION + model_path = AlgorithmConfig.SD_MODEL_PATHS.get(sd_version) + if not model_path: + logger.error(f"Model path not found for SD version {sd_version}") + return None + + # 加入RQ队列 + from app.workers.heatmap_worker import run_heatmap_task + + queue = TaskService._get_queue() + job_id = f"hm_{task_id}" + + job = queue.enqueue( + run_heatmap_task, + task_id=task_id, + original_image_path=original_image_path, + perturbed_image_path=perturbed_image_path, + output_dir=output_dir, + model_path=model_path, + perturbed_image_id=perturbed_image_id, + job_id=job_id, + job_timeout='2h' + ) + + logger.info(f"Heatmap task {task_id} enqueued with job_id {job_id}") + return job_id + + except Exception as e: + logger.error(f"Error starting heatmap task: {e}") + return None + + # ==================== Evaluate 任务 ==================== + + @staticmethod + def start_evaluate_task(task_id): + """ + 启动评估任务 + + Args: + task_id: 任务ID + + Returns: + job_id + """ + try: + # 获取任务 + task = Task.query.get(task_id) + if not task: + logger.error(f"Task {task_id} not found") + return None + + # 获取Evaluate任务详情 + evaluate = Evaluate.query.get(task_id) + if not evaluate: + logger.error(f"Evaluate task {task_id} not found") + return None + + # 获取用户ID + sample_image = Image.query.filter_by(tasks_id=task_id).first() + if not sample_image: + logger.error(f"No images found for task {task_id}") + return None + user_id = sample_image.user_id + + # 查找相同flow_id的Finetune任务 + finetune_tasks = Task.query.filter( + Task.flow_id == task.flow_id, + Task.tasks_type_id == 2, # finetune类型 + Task.tasks_id != task_id + ).all() + + if not finetune_tasks: + logger.error(f"No finetune task found for flow {task.flow_id}") + return None + + # 从Evaluate任务获取需要的微调配置ID + if not evaluate.finetune_configs_id: + logger.error(f"Evaluate task {task_id} has no finetune_configs_id") + 
return None + + # 查找使用相同微调算法的任务 + target_finetune_task = None + for ft_task in finetune_tasks: + ft = Finetune.query.get(ft_task.tasks_id) + if ft and ft.finetune_configs_id == evaluate.finetune_configs_id: + target_finetune_task = ft_task + break + + if not target_finetune_task: + logger.error(f"No finetune task with config {evaluate.finetune_configs_id} found for flow {task.flow_id}") + return None + + finetune_task = target_finetune_task + + # 路径配置 + clean_ref_dir = TaskService.get_original_images_path(user_id, task.flow_id) + clean_output_dir = TaskService.get_original_generated_path(user_id, task.flow_id, finetune_task.tasks_id) + perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, task.flow_id, finetune_task.tasks_id) + output_dir = TaskService.get_evaluate_path(user_id, task.flow_id, task_id) + + # 加入RQ队列 + from app.workers.evaluate_worker import run_evaluate_task + + queue = TaskService._get_queue() + job_id = f"eval_{task_id}" + + job = queue.enqueue( + run_evaluate_task, + task_id=task_id, + clean_ref_dir=clean_ref_dir, + clean_output_dir=clean_output_dir, + perturbed_output_dir=perturbed_output_dir, + output_dir=output_dir, + image_size=512, + job_id=job_id, + job_timeout='2h' + ) + + logger.info(f"Evaluate task {task_id} enqueued with job_id {job_id}") + return job_id + + except Exception as e: + logger.error(f"Error starting evaluate task: {e}") + return None diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index f59610e..6cc93c2 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -55,7 +55,6 @@ class Config: MODEL_UPLOADED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'uploaded') # 上传图的模型生成结果 MODEL_ORIGINAL_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'original') # 原图的模型生成结果 MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果 - HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图 # 微调训练相关配置 CLASS_DATA_FOLDER = 
os.path.join(STATIC_ROOT, 'class') # 类别数据目录(用于 prior preservation) # 可视化与分析配置 -- 2.34.1 From 0ce6747b768b62b9faa6766dcbfb217b63ab922b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Mon, 1 Dec 2025 08:53:52 +0800 Subject: [PATCH 12/14] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E5=AE=8C=E6=88=90=E5=AD=97=E6=AE=B5=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/services/task_service.py | 4 ++-- src/backend/app/workers/evaluate_worker.py | 4 ++-- src/backend/app/workers/heatmap_worker.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index ae7d67e..d60f1ec 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -162,7 +162,7 @@ class TaskService: 'flow_id': task.flow_id, 'created_at': task.created_at.isoformat() if task.created_at else None, 'started_at': task.started_at.isoformat() if task.started_at else None, - 'completed_at': task.completed_at.isoformat() if task.completed_at else None + 'finished_at': task.finished_at.isoformat() if task.finished_at else None } # 如果任务失败,尝试从RQ获取错误信息 @@ -228,7 +228,7 @@ class TaskService: failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() if failed_status: task.tasks_status_id = failed_status.task_status_id - task.completed_at = datetime.utcnow() + task.finished_at = datetime.utcnow() db.session.commit() return True diff --git a/src/backend/app/workers/evaluate_worker.py b/src/backend/app/workers/evaluate_worker.py index 0e5acae..4293cc5 100644 --- a/src/backend/app/workers/evaluate_worker.py +++ b/src/backend/app/workers/evaluate_worker.py @@ -85,7 +85,7 @@ def run_evaluate_task(task_id, clean_ref_dir, clean_output_dir, completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() if 
completed_status: task.tasks_status_id = completed_status.task_status_id - task.completed_at = datetime.utcnow() + task.finished_at = datetime.utcnow() db.session.commit() logger.info(f"Evaluate task {task_id} completed") @@ -98,7 +98,7 @@ def run_evaluate_task(task_id, clean_ref_dir, clean_output_dir, failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() if failed_status: task.tasks_status_id = failed_status.task_status_id - task.completed_at = datetime.utcnow() + task.finished_at = datetime.utcnow() db.session.commit() return {'success': False, 'error': str(e)} diff --git a/src/backend/app/workers/heatmap_worker.py b/src/backend/app/workers/heatmap_worker.py index 88cb10d..0337e96 100644 --- a/src/backend/app/workers/heatmap_worker.py +++ b/src/backend/app/workers/heatmap_worker.py @@ -112,7 +112,7 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path, completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() if completed_status: task.tasks_status_id = completed_status.task_status_id - task.completed_at = datetime.utcnow() + task.finished_at = datetime.utcnow() db.session.commit() logger.info(f"Heatmap task {task_id} completed") @@ -125,7 +125,7 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path, failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() if failed_status: task.tasks_status_id = failed_status.task_status_id - task.completed_at = datetime.utcnow() + task.finished_at = datetime.utcnow() db.session.commit() return {'success': False, 'error': str(e)} -- 2.34.1 From ae261a6dc96df3bbf15880d0cb2dfd689936d478 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Mon, 1 Dec 2025 09:25:51 +0800 Subject: [PATCH 13/14] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84controllers?= =?UTF-8?q?=E5=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
.../app/controllers/auth_controller.py | 1 - .../app/controllers/image_controller.py | 331 ++--- .../app/controllers/task_controller.py | 1141 ++++++++--------- .../app/controllers/user_controller.py | 248 ++-- src/backend/app/services/image_service.py | 177 ++- src/backend/app/services/task_service.py | 133 +- 6 files changed, 1087 insertions(+), 944 deletions(-) diff --git a/src/backend/app/controllers/auth_controller.py b/src/backend/app/controllers/auth_controller.py index 751a2a5..bd93886 100644 --- a/src/backend/app/controllers/auth_controller.py +++ b/src/backend/app/controllers/auth_controller.py @@ -7,7 +7,6 @@ from flask import Blueprint, request, jsonify from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity from app import db from app.database import User, UserConfig -from app.services.auth_service import AuthService from functools import wraps import re diff --git a/src/backend/app/controllers/image_controller.py b/src/backend/app/controllers/image_controller.py index d52a4a9..08aef56 100644 --- a/src/backend/app/controllers/image_controller.py +++ b/src/backend/app/controllers/image_controller.py @@ -1,203 +1,128 @@ -""" -图像管理控制器 -处理图像下载、查看等功能 -""" - -from flask import Blueprint, send_file, jsonify, request, current_app -from flask_jwt_extended import jwt_required, get_jwt_identity -from app.database import Image, EvaluationResult -from app.services.image_service import ImageService -import os - -image_bp = Blueprint('image', __name__) - -@image_bp.route('/file/', methods=['GET']) -@jwt_required() -def get_image_file(image_id): - """获取图片文件""" - try: - current_user_id = get_jwt_identity() - - # 查找图片记录 - image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() - if not image: - return jsonify({'error': '图片不存在或无权限'}), 404 - - # 检查文件是否存在 - if not os.path.exists(image.file_path): - return jsonify({'error': '图片文件不存在'}), 404 - - return send_file(image.file_path, as_attachment=False) - - except Exception as e: - 
return jsonify({'error': f'获取图片失败: {str(e)}'}), 500 - -@image_bp.route('/download/', methods=['GET']) -@jwt_required() -def download_image(image_id): - """下载图片文件""" - try: - current_user_id = get_jwt_identity() - - image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() - if not image: - return jsonify({'error': '图片不存在或无权限'}), 404 - - if not os.path.exists(image.file_path): - return jsonify({'error': '图片文件不存在'}), 404 - - return send_file( - image.file_path, - as_attachment=True, - download_name=image.original_filename or f"image_{image_id}.jpg" - ) - - except Exception as e: - return jsonify({'error': f'下载图片失败: {str(e)}'}), 500 - -@image_bp.route('/batch//download', methods=['GET']) -@jwt_required() -def download_batch_images(batch_id): - """批量下载任务中的加噪后图片""" - try: - current_user_id = get_jwt_identity() - - # 获取任务中的加噪图片 - perturbed_images = Image.query.join(Image.image_type).filter( - Image.batch_id == batch_id, - Image.user_id == current_user_id, - Image.image_type.has(type_code='perturbed') - ).all() - - if not perturbed_images: - return jsonify({'error': '没有找到加噪后的图片'}), 404 - - # 创建ZIP文件 - import zipfile - import tempfile - - with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file: - with zipfile.ZipFile(tmp_file.name, 'w') as zip_file: - for image in perturbed_images: - if os.path.exists(image.file_path): - arcname = image.original_filename or f"perturbed_{image.id}.jpg" - zip_file.write(image.file_path, arcname) - - return send_file( - tmp_file.name, - as_attachment=True, - download_name=f"batch_{batch_id}_perturbed_images.zip", - mimetype='application/zip' - ) - - except Exception as e: - return jsonify({'error': f'批量下载失败: {str(e)}'}), 500 - -@image_bp.route('//evaluations', methods=['GET']) -@jwt_required() -def get_image_evaluations(image_id): - """获取图片的评估结果""" - try: - current_user_id = get_jwt_identity() - - # 验证图片权限 - image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() - if not image: - return 
jsonify({'error': '图片不存在或无权限'}), 404 - - # 获取以该图片为参考或目标的评估结果 - evaluations = EvaluationResult.query.filter( - (EvaluationResult.reference_image_id == image_id) | - (EvaluationResult.target_image_id == image_id) - ).all() - - return jsonify({ - 'image_id': image_id, - 'evaluations': [eval_result.to_dict() for eval_result in evaluations] - }), 200 - - except Exception as e: - return jsonify({'error': f'获取评估结果失败: {str(e)}'}), 500 - -@image_bp.route('/compare', methods=['POST']) -@jwt_required() -def compare_images(): - """对比两张图片""" - try: - current_user_id = get_jwt_identity() - data = request.get_json() - - image1_id = data.get('image1_id') - image2_id = data.get('image2_id') - - if not image1_id or not image2_id: - return jsonify({'error': '请提供两张图片的ID'}), 400 - - # 验证图片权限 - image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first() - image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first() - - if not image1 or not image2: - return jsonify({'error': '图片不存在或无权限'}), 404 - - # 查找现有的评估结果 - evaluation = EvaluationResult.query.filter_by( - reference_image_id=image1_id, - target_image_id=image2_id - ).first() - - if not evaluation: - # 如果没有评估结果,返回基本对比信息 - return jsonify({ - 'image1': image1.to_dict(), - 'image2': image2.to_dict(), - 'evaluation': None, - 'message': '暂无评估数据,请等待任务处理完成' - }), 200 - - return jsonify({ - 'image1': image1.to_dict(), - 'image2': image2.to_dict(), - 'evaluation': evaluation.to_dict() - }), 200 - - except Exception as e: - return jsonify({'error': f'图片对比失败: {str(e)}'}), 500 - -@image_bp.route('/heatmap/', methods=['GET']) -@jwt_required() -def get_heatmap(heatmap_path): - """获取热力图文件""" - try: - # 安全检查,防止路径遍历攻击 - if '..' 
in heatmap_path or heatmap_path.startswith('/'): - return jsonify({'error': '无效的文件路径'}), 400 - - # 修正路径构建 - 获取项目根目录(backend目录) - project_root = os.path.dirname(current_app.root_path) - full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path)) - - if not os.path.exists(full_path): - return jsonify({'error': '热力图文件不存在'}), 404 - - return send_file(full_path, as_attachment=False) - - except Exception as e: - return jsonify({'error': f'获取热力图失败: {str(e)}'}), 500 - -@image_bp.route('/delete/', methods=['DELETE']) -@jwt_required() -def delete_image(image_id): - """删除图片""" - try: - current_user_id = get_jwt_identity() - - result = ImageService.delete_image(image_id, current_user_id) - - if result['success']: - return jsonify({'message': '图片删除成功'}), 200 - else: - return jsonify({'error': result['error']}), 400 - - except Exception as e: - return jsonify({'error': f'删除图片失败: {str(e)}'}), 500 \ No newline at end of file + +""" +图像管理控制器 +负责图片上传、下载等操作 +""" + +from flask import Blueprint, request, jsonify, send_file +from app.controllers.auth_controller import int_jwt_required +from app.services.task_service import TaskService +from app.services.image_service import ImageService + + +image_bp = Blueprint('image', __name__) + + +# ==================== 图片上传 ==================== + +@image_bp.route('/original', methods=['POST']) +@int_jwt_required +def upload_original_images(current_user_id): + task_id = request.form.get('task_id', type=int) + if not task_id: + return ImageService.json_error('缺少 task_id 参数') + + task = TaskService.load_task_for_user(task_id, current_user_id) + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + task_type = TaskService.get_task_type_code(task) + if task_type not in {'perturbation', 'finetune'}: + return ImageService.json_error('任务类型不支持图片上传', 400) + + files = request.files.getlist('files') + target_dir = TaskService.get_original_images_path(task.user_id, task.flow_id) + success, result = 
ImageService.save_original_images(task, files, target_dir) + if not success: + status_code = 400 + if isinstance(result, str) and (result.startswith('未配置图片类型') or '失败' in result): + status_code = 500 + return ImageService.json_error(result, status_code) + + return jsonify({ + 'message': '图片上传成功', + 'images': [ImageService.serialize_image(img) for img in result], + 'flow_id': task.flow_id + }), 201 + + +# ==================== 结果下载 ==================== + +@image_bp.route('/perturbation//download', methods=['GET']) +@int_jwt_required +def download_perturbation_result(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + directory = TaskService.get_perturbed_images_path(task.user_id, task.flow_id) + zipped, has_files = ImageService.zip_directory(directory) + if not has_files: + return ImageService.json_error('结果文件不存在', 404) + + filename = f"perturbation_{task_id}.zip" + return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + + +@image_bp.route('/heatmap//download', methods=['GET']) +@int_jwt_required +def download_heatmap_result(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap') + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + directory = TaskService.get_heatmap_path(task.user_id, task.flow_id, task.tasks_id) + zipped, has_files = ImageService.zip_directory(directory) + if not has_files: + return ImageService.json_error('热力图文件不存在', 404) + + filename = f"heatmap_{task_id}.zip" + return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + + +@image_bp.route('/finetune//download', methods=['GET']) +@int_jwt_required +def download_finetune_result(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune') + if 
not task: + return ImageService.json_error('任务不存在或无权限', 404) + + if not task.finetune: + return ImageService.json_error('微调任务配置不存在', 404) + + try: + source = TaskService.determine_finetune_source(task) + except ValueError as exc: + return ImageService.json_error(str(exc), 500) + if source == 'perturbation': + directories = { + 'original_generate': TaskService.get_original_generated_path(task.user_id, task.flow_id, task.tasks_id), + 'perturbed_generate': TaskService.get_perturbed_generated_path(task.user_id, task.flow_id, task.tasks_id) + } + else: + directories = { + 'uploaded_generate': TaskService.get_uploaded_generated_path(task.user_id, task.flow_id, task.tasks_id) + } + + zipped, has_files = ImageService.zip_multiple_directories(directories) + if not has_files: + return ImageService.json_error('微调结果文件不存在', 404) + + filename = f"finetune_{task_id}.zip" + return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + + +@image_bp.route('/evaluate//download', methods=['GET']) +@int_jwt_required +def download_evaluate_result(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate') + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + directory = TaskService.get_evaluate_path(task.user_id, task.flow_id, task.tasks_id) + zipped, has_files = ImageService.zip_directory(directory) + if not has_files: + return ImageService.json_error('评估结果文件不存在', 404) + + filename = f"evaluate_{task_id}.zip" + return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index f49e668..2fbe51b 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -1,606 +1,535 @@ -""" -任务管理控制器 -处理创建任务、上传图片等功能 -""" - -from flask import Blueprint, request, jsonify, current_app -from 
flask_jwt_extended import jwt_required, get_jwt_identity -from werkzeug.utils import secure_filename -from app import db -from app.database import User, Role, PerturbationConfig, FinetuneConfig, UserConfig, Image, ImageType, DataType, TaskType, TaskStatus, Task, Perturbation, Finetune, EvaluationResult, Evaluate, Heatmap -from app.services.task_service import TaskService -from app.services.image_service import ImageService -from app.utils.file_utils import allowed_file, save_uploaded_file -import os -import zipfile -import uuid - -task_bp = Blueprint('task', __name__) - -@task_bp.route('/create', methods=['POST']) -@jwt_required() -def create_task(): - """创建新任务(使用用户配置作为默认配置)""" - try: - current_user_id = get_jwt_identity() - user = User.query.get(current_user_id) - - if not user: - return jsonify({'error': '用户不存在'}), 404 - - - data = request.get_json() - batch_name = data.get('batch_name', f'Task-{uuid.uuid4().hex[:8]}') - - # 优先使用前端传来的参数,没有则用用户配置,没有再用默认 - perturbation_config_id = data.get('perturbation_config_id') - preferred_epsilon = data.get('epsilon') - use_strong_protection = data.get('use_strong_protection') - - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - if user_config: - if perturbation_config_id is None: - perturbation_config_id = user_config.preferred_perturbation_config_id or 1 - if preferred_epsilon is None: - preferred_epsilon = user_config.preferred_epsilon or 8.0 - if use_strong_protection is None: - use_strong_protection = user_config.preferred_purification or False - else: - perturbation_config_id = perturbation_config_id or 1 - preferred_epsilon = preferred_epsilon or 8.0 - use_strong_protection = use_strong_protection if use_strong_protection is not None else False - - # 类型转换,防止前端传字符串 - try: - perturbation_config_id = int(perturbation_config_id) - except Exception: - perturbation_config_id = 1 - try: - preferred_epsilon = float(preferred_epsilon) - except Exception: - preferred_epsilon = 8.0 - 
use_strong_protection = bool(use_strong_protection) - - # 创建任务(只包含扰动相关配置,不包含微调配置) - batch = Batch( - user_id=current_user_id, - batch_name=batch_name, - perturbation_config_id=perturbation_config_id, - preferred_epsilon=preferred_epsilon, - use_strong_protection=use_strong_protection - ) - - db.session.add(batch) - db.session.commit() - - # 自动创建关联的微调任务(如果用户有默认微调配置则自动设置) - finetune_config_id = None - if user_config and user_config.preferred_finetune_config_id: - finetune_config_id = user_config.preferred_finetune_config_id - - finetune_batch = FinetuneBatch( - batch_id=batch.id, - user_id=current_user_id, - finetune_config_id=finetune_config_id, - status='pending' - ) - db.session.add(finetune_batch) - db.session.commit() - - return jsonify({ - 'message': '任务创建成功,请上传图片', - 'task': batch.to_dict(), - 'finetune_task_id': finetune_batch.id, - 'finetune_config_set': finetune_config_id is not None - }), 201 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'任务创建失败: {str(e)}'}), 500 - -@task_bp.route('/upload/', methods=['POST']) -@jwt_required() -def upload_images(batch_id): - """上传图片到指定任务""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - if batch.status != 'pending': - return jsonify({'error': '任务已开始处理,无法上传新图片'}), 400 - - if 'files' not in request.files: - return jsonify({'error': '没有选择文件'}), 400 - - files = request.files.getlist('files') - uploaded_files = [] - - # 获取原始图片类型ID - original_type = ImageType.query.filter_by(type_code='original').first() - if not original_type: - return jsonify({'error': '系统配置错误:缺少原始图片类型'}), 500 - - for file in files: - if file.filename == '': - continue - if file and allowed_file(file.filename): - # 处理单张图片 - if not file.filename.lower().endswith(('.zip', '.rar')): - # 统一走save_image,内部已实现上传到uploads和预处理 - result = ImageService.save_image(file, 
batch_id, current_user_id, original_type.id) - if result['success']: - uploaded_files.append(result['image']) - else: - return jsonify({'error': result['error']}), 400 - else: - # 压缩包内图片也会走save_image - results = ImageService.extract_and_save_zip(file, batch_id, current_user_id, original_type.id) - for result in results: - if result['success']: - uploaded_files.append(result['image']) - - if not uploaded_files: - return jsonify({'error': '没有有效的图片文件'}), 400 - - return jsonify({ - 'message': f'成功上传 {len(uploaded_files)} 张图片', - 'uploaded_files': [img.to_dict() for img in uploaded_files] - }), 200 - - except Exception as e: - return jsonify({'error': f'文件上传失败: {str(e)}'}), 500 - -@task_bp.route('//config', methods=['GET']) -@jwt_required() -def get_task_config(batch_id): - """获取任务配置(显示用户上次的配置或默认配置)""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - # 获取用户配置 - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - # 如果用户有配置,显示用户上次的配置;否则显示当前任务的默认配置 - if user_config: - suggested_config = { - 'perturbation_config_id': user_config.preferred_perturbation_config_id, - 'epsilon': float(user_config.preferred_epsilon), - 'use_strong_protection': user_config.preferred_purification - } - else: - suggested_config = { - 'perturbation_config_id': batch.perturbation_config_id, - 'epsilon': float(batch.preferred_epsilon), - 'use_strong_protection': batch.use_strong_protection - } - - return jsonify({ - 'task': batch.to_dict(), - 'suggested_config': suggested_config, - 'current_config': { - 'perturbation_config_id': batch.perturbation_config_id, - 'epsilon': float(batch.preferred_epsilon), - 'use_strong_protection': batch.use_strong_protection - } - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务配置失败: {str(e)}'}), 500 - -@task_bp.route('//config', methods=['PUT']) -@jwt_required() 
-def update_task_config(batch_id): - """更新任务配置(仅更新任务本身,不影响用户配置)""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - if batch.status != 'pending': - return jsonify({'error': '任务已开始处理,无法修改配置'}), 400 - - data = request.get_json() - - # 更新任务配置(仅扰动相关) - if 'perturbation_config_id' in data: - batch.perturbation_config_id = data['perturbation_config_id'] - - if 'epsilon' in data: - epsilon = float(data['epsilon']) - if 0 < epsilon <= 255: - batch.preferred_epsilon = epsilon - else: - return jsonify({'error': '扰动强度必须在0-255之间'}), 400 - - if 'use_strong_protection' in data: - batch.use_strong_protection = bool(data['use_strong_protection']) - - db.session.commit() - - return jsonify({ - 'message': '任务配置更新成功', - 'task': batch.to_dict() - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'更新任务配置失败: {str(e)}'}), 500 - -@task_bp.route('/start/', methods=['POST']) -@jwt_required() -def start_task(batch_id): - """开始处理任务""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - if batch.status not in ['pending', 'failed', 'canceled']: - return jsonify({'error': '任务状态不正确,无法开始处理'}), 400 - # 如果是失败或取消,重置状态为pending - if batch.status in ['failed', 'canceled']: - batch.status = 'pending' - batch.error_message = None - db.session.commit() - - # 检查是否有上传的图片 - image_count = Image.query.filter_by(batch_id=batch_id).count() - if image_count == 0: - return jsonify({'error': '请先上传图片'}), 400 - - # 启动任务处理 - success = TaskService.start_processing(batch) - - if success: - return jsonify({ - 'message': '任务开始处理', - 'task': batch.to_dict() - }), 200 - else: - return jsonify({'error': '任务启动失败'}), 500 - - except Exception as e: - return jsonify({'error': 
f'任务启动失败: {str(e)}'}), 500 - -@task_bp.route('/list', methods=['GET']) -@jwt_required() -def list_tasks(): - """获取用户的任务列表""" - try: - current_user_id = get_jwt_identity() - - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 10, type=int) - - batches = Batch.query.filter_by(user_id=current_user_id)\ - .order_by(Batch.created_at.desc())\ - .paginate(page=page, per_page=per_page, error_out=False) - - return jsonify({ - 'tasks': [batch.to_dict() for batch in batches.items], - 'total': batches.total, - 'pages': batches.pages, - 'current_page': page - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务列表失败: {str(e)}'}), 500 - -@task_bp.route('/', methods=['GET']) -@jwt_required() -def get_task_detail(batch_id): - """获取任务详情""" - try: - current_user_id = get_jwt_identity() - - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - # 获取任务相关的图片 - images = Image.query.filter_by(batch_id=batch_id).all() - - return jsonify({ - 'task': batch.to_dict(), - 'images': [img.to_dict() for img in images], - 'image_count': len(images) - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务详情失败: {str(e)}'}), 500 - -@task_bp.route('//status', methods=['GET']) -@jwt_required() -def get_task_status(batch_id): - """获取任务处理状态""" - try: - current_user_id = get_jwt_identity() - - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - return jsonify({ - 'task_id': batch_id, - 'status': batch.status, - 'progress': TaskService.get_processing_progress(batch_id), - 'error_message': batch.error_message - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务状态失败: {str(e)}'}), 500 - -# ==================== 微调任务管理接口 ==================== - -@task_bp.route('/finetune/configs', methods=['GET']) -@jwt_required() -def get_finetune_configs(): - 
"""获取所有可用的微调配置""" - try: - configs = FinetuneConfig.query.all() - return jsonify({ - 'configs': [{ - 'id': config.id, - 'method_code': config.method_code, - 'method_name': config.method_name, - 'description': config.description - } for config in configs] - }), 200 - - except Exception as e: - return jsonify({'error': f'获取微调配置失败: {str(e)}'}), 500 - -@task_bp.route('/finetune/list', methods=['GET']) -@jwt_required() -def list_finetune_tasks(): - """获取用户的所有微调任务列表""" - try: - current_user_id = get_jwt_identity() - - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 10, type=int) - - finetune_tasks = FinetuneBatch.query.filter_by(user_id=current_user_id)\ - .order_by(FinetuneBatch.created_at.desc())\ - .paginate(page=page, per_page=per_page, error_out=False) - - results = [] - for ft in finetune_tasks.items: - ft_dict = ft.to_dict() - # 添加关联的扰动任务信息 - ft_dict['batch_info'] = ft.batch.to_dict() if ft.batch else None - results.append(ft_dict) - - return jsonify({ - 'finetune_tasks': results, - 'total': finetune_tasks.total, - 'pages': finetune_tasks.pages, - 'current_page': page - }), 200 - - except Exception as e: - return jsonify({'error': f'获取微调任务列表失败: {str(e)}'}), 500 - -@task_bp.route('/finetune/', methods=['GET']) -@jwt_required() -def get_finetune_task(finetune_id): - """获取微调任务详情""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - result = finetune_task.to_dict() - result['batch_info'] = finetune_task.batch.to_dict() if finetune_task.batch else None - - return jsonify({'finetune_task': result}), 200 - - except Exception as e: - return jsonify({'error': f'获取微调任务详情失败: {str(e)}'}), 500 - -@task_bp.route('/finetune/by-batch/', methods=['GET']) -@jwt_required() -def get_finetune_by_batch(batch_id): - """根据扰动任务ID获取关联的微调任务""" - try: - current_user_id = 
get_jwt_identity() - - # 验证扰动任务权限 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '扰动任务不存在或无权限'}), 404 - - finetune_task = FinetuneBatch.query.filter_by(batch_id=batch_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '该扰动任务没有关联的微调任务'}), 404 - - result = finetune_task.to_dict() - result['batch_info'] = batch.to_dict() - - return jsonify({'finetune_task': result}), 200 - - except Exception as e: - return jsonify({'error': f'获取微调任务失败: {str(e)}'}), 500 - -@task_bp.route('/finetune//config', methods=['GET']) -@jwt_required() -def get_finetune_config(finetune_id): - """获取微调任务配置(显示用户默认配置或当前配置)""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - # 获取用户配置 - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - # 如果用户有配置,显示用户默认配置;否则显示系统默认 - if user_config and user_config.preferred_finetune_config_id: - suggested_config = { - 'finetune_config_id': user_config.preferred_finetune_config_id, - 'finetune_config_name': user_config.preferred_finetune_config.method_name if user_config.preferred_finetune_config else None - } - else: - # 默认使用第一个微调配置 - default_config = FinetuneConfig.query.first() - suggested_config = { - 'finetune_config_id': default_config.id if default_config else 1, - 'finetune_config_name': default_config.method_name if default_config else None - } - - # 当前微调任务的配置 - current_config = { - 'finetune_config_id': finetune_task.finetune_config_id, - 'finetune_config_name': finetune_task.finetune_config.method_name if finetune_task.finetune_config else None - } - - return jsonify({ - 'finetune_task': finetune_task.to_dict(), - 'suggested_config': suggested_config, - 'current_config': current_config - }), 200 - - except Exception as e: - return jsonify({'error': 
f'获取微调配置失败: {str(e)}'}), 500 - -@task_bp.route('/finetune//config', methods=['PUT']) -@jwt_required() -def update_finetune_config(finetune_id): - """更新微调任务配置(仅限 pending 状态)""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - if finetune_task.status != 'pending': - return jsonify({'error': '只能修改待处理状态的微调任务配置'}), 400 - - data = request.get_json() - finetune_config_id = data.get('finetune_config_id') - - if not finetune_config_id: - return jsonify({'error': '请提供微调方法ID'}), 400 - - # 验证微调配置是否存在 - finetune_config = FinetuneConfig.query.get(finetune_config_id) - if not finetune_config: - return jsonify({'error': '微调配置不存在'}), 404 - - finetune_task.finetune_config_id = finetune_config_id - db.session.commit() - - return jsonify({ - 'message': '微调配置更新成功', - 'finetune_task': finetune_task.to_dict() - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'更新微调配置失败: {str(e)}'}), 500 - -@task_bp.route('/finetune//start', methods=['POST']) -@jwt_required() -def start_finetune(finetune_id): - """启动微调任务""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - # 检查扰动任务是否已完成 - if finetune_task.batch.status != 'completed': - return jsonify({'error': '扰动任务尚未完成,无法开始微调'}), 400 - - # 检查是否已设置微调配置 - if not finetune_task.finetune_config_id: - return jsonify({'error': '请先设置微调方法'}), 400 - - # 检查状态 - if finetune_task.status not in ['pending', 'failed']: - return jsonify({'error': f'微调任务状态为 {finetune_task.status},无法启动'}), 400 - - # 启动微调任务 - job_ids = TaskService.start_finetune_task(finetune_task) - - if job_ids: - return jsonify({ - 'message': '微调任务已启动', - 'finetune_task_id': finetune_id, - 'job_ids': job_ids - }), 200 - else: - 
return jsonify({'error': '微调任务启动失败'}), 500 - - except Exception as e: - return jsonify({'error': f'启动微调任务失败: {str(e)}'}), 500 - -@task_bp.route('/finetune//status', methods=['GET']) -@jwt_required() -def get_finetune_task_status(finetune_id): - """获取微调任务状态""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - # 获取详细状态 - status_info = TaskService.get_finetune_task_status(finetune_id) - - return jsonify({ - 'finetune_task_id': finetune_id, - 'status': finetune_task.status, - 'finetune_config': finetune_task.finetune_config.to_dict() if finetune_task.finetune_config else None, - 'details': status_info, - 'error_message': finetune_task.error_message - }), 200 - - except Exception as e: - return jsonify({'error': f'获取微调任务状态失败: {str(e)}'}), 500 - -@task_bp.route('/finetune/', methods=['DELETE']) -@jwt_required() -def delete_finetune_task(finetune_id): - """删除微调任务(仅限 pending 或 failed 状态)""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - if finetune_task.status not in ['pending', 'failed']: - return jsonify({'error': '只能删除待处理或失败状态的微调任务'}), 400 - - db.session.delete(finetune_task) - db.session.commit() - - return jsonify({'message': '微调任务删除成功'}), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'删除微调任务失败: {str(e)}'}), 500 + +""" +任务管理控制器 +适配新数据库结构,提供加噪、微调、热力图、数值评估等任务相关接口 +""" + +from flask import Blueprint, request, jsonify +from app import db +from app.controllers.auth_controller import int_jwt_required +from app.database import ( + Task, + Perturbation, Finetune, Heatmap, Evaluate, + PerturbationConfig, FinetuneConfig, DataType, + Image +) +from app.services.task_service import TaskService + + 
+task_bp = Blueprint('task', __name__) + + +# ==================== 通用任务接口 ==================== + +@task_bp.route('//status', methods=['GET']) +@int_jwt_required +def get_task_status(task_id, current_user_id): + task = Task.query.get(task_id) + if not TaskService.ensure_task_owner(task, current_user_id): + return TaskService.json_error('任务不存在或无权限', 404) + status = TaskService.get_task_status(task_id) + return jsonify(status), 200 + + +@task_bp.route('//cancel', methods=['POST']) +@int_jwt_required +def cancel_task(task_id, current_user_id): + task = Task.query.get(task_id) + if not TaskService.ensure_task_owner(task, current_user_id): + return TaskService.json_error('任务不存在或无权限', 404) + if TaskService.cancel_task(task_id): + return jsonify({'message': '任务已取消'}), 200 + return TaskService.json_error('取消任务失败', 500) + + +# ==================== 加噪任务 ==================== + +@task_bp.route('/perturbation/configs', methods=['GET']) +@int_jwt_required +def list_perturbation_configs(current_user_id): + configs = PerturbationConfig.query.order_by(PerturbationConfig.perturbation_configs_id).all() + return jsonify({'configs': [ + { + 'perturbation_configs_id': cfg.perturbation_configs_id, + 'perturbation_code': cfg.perturbation_code, + 'perturbation_name': cfg.perturbation_name, + 'description': cfg.description, + } + for cfg in configs + ]}), 200 + + +@task_bp.route('/perturbation', methods=['POST']) +@int_jwt_required +def create_perturbation_task(current_user_id): + data = request.get_json() or {} + data_type_id = data.get('data_type_id') + perturbation_configs_id = data.get('perturbation_configs_id') + intensity = data.get('perturbation_intensity') + + if not all([data_type_id, perturbation_configs_id, intensity]): + return TaskService.json_error('缺少必要的任务参数') + + if not DataType.query.get(data_type_id): + return TaskService.json_error('数据集类型不存在') + if not PerturbationConfig.query.get(perturbation_configs_id): + return TaskService.json_error('加噪配置不存在') + + try: + flow_id = 
data.get('flow_id') + flow_id = int(flow_id) if flow_id is not None else TaskService.generate_flow_id() + except Exception: + return TaskService.json_error('非法的 flow_id 参数') + + try: + pending_status = TaskService.ensure_status('pending') + perturb_type = TaskService.require_task_type('perturbation') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=flow_id, + tasks_type_id=perturb_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + perturbation = Perturbation( + tasks_id=task.tasks_id, + data_type_id=data_type_id, + perturbation_configs_id=perturbation_configs_id, + perturbation_intensity=float(intensity), + perturbation_name=data.get('perturbation_name') + ) + db.session.add(perturbation) + db.session.commit() + + return jsonify({ + 'message': '加噪任务已创建', + 'task': TaskService.serialize_task(task) + }), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建任务失败: {exc}', 500) + + +@task_bp.route('/perturbation/', methods=['PATCH']) +@int_jwt_required +def update_perturbation_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + data = request.get_json() or {} + pert = task.perturbation + if not pert: + return TaskService.json_error('任务配置不存在', 404) + + if 'data_type_id' in data: + if not DataType.query.get(data['data_type_id']): + return TaskService.json_error('数据集类型不存在') + pert.data_type_id = data['data_type_id'] + if 'perturbation_configs_id' in data: + if not PerturbationConfig.query.get(data['perturbation_configs_id']): + return TaskService.json_error('加噪配置不存在') + pert.perturbation_configs_id = data['perturbation_configs_id'] + if 'perturbation_intensity' in data: + pert.perturbation_intensity = 
float(data['perturbation_intensity']) + if 'perturbation_name' in data: + pert.perturbation_name = data['perturbation_name'] + if 'description' in data: + task.description = data['description'] + + try: + db.session.commit() + return jsonify({'message': '任务已更新', 'task': TaskService.serialize_task(task)}), 200 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'更新任务失败: {exc}', 500) + + +@task_bp.route('/perturbation//start', methods=['POST']) +@int_jwt_required +def start_perturbation_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + job_id = TaskService.start_perturbation_task(task_id) + if not job_id: + return TaskService.json_error('任务启动失败', 500) + return jsonify({'message': '任务已启动', 'job_id': job_id}), 200 + + +@task_bp.route('/perturbation', methods=['GET']) +@int_jwt_required +def list_perturbation_tasks(current_user_id): + try: + perturb_type = TaskService.require_task_type('perturbation') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + tasks = Task.query.filter_by(user_id=current_user_id, tasks_type_id=perturb_type.task_type_id).order_by(Task.created_at.desc()).all() + return jsonify({'tasks': [TaskService.serialize_task(task) for task in tasks]}), 200 + + +@task_bp.route('/perturbation/', methods=['GET']) +@int_jwt_required +def get_perturbation_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + return jsonify({'task': TaskService.serialize_task(task)}), 200 + + +# ==================== 热力图任务 ==================== + +@task_bp.route('/heatmap', methods=['POST']) +@int_jwt_required +def create_heatmap_task(current_user_id): + data = request.get_json() or {} + perturbation_task_id = data.get('perturbation_task_id') 
+ perturbed_image_id = data.get('perturbed_image_id') + + if not perturbation_task_id or not perturbed_image_id: + return TaskService.json_error('缺少必要参数: perturbation_task_id 或 perturbed_image_id') + + perturbation_task = TaskService.load_task_for_user(perturbation_task_id, current_user_id, expected_type='perturbation') + if not perturbation_task: + return TaskService.json_error('加噪任务不存在或无权限', 404) + + status_code = perturbation_task.task_status.task_status_code if perturbation_task.task_status else None + if status_code != 'completed': + return TaskService.json_error('仅支持已完成的加噪任务创建热力图') + + perturbed_image = Image.query.get(perturbed_image_id) + if not perturbed_image or perturbed_image.task_id != perturbation_task_id: + return TaskService.json_error('扰动图片不存在或不属于该任务') + + try: + heatmap_type = TaskService.require_task_type('heatmap') + pending_status = TaskService.ensure_status('pending') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=perturbation_task.flow_id, + tasks_type_id=heatmap_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + heatmap = Heatmap( + tasks_id=task.tasks_id, + images_id=perturbed_image_id, + heatmap_name=data.get('heatmap_name') + ) + db.session.add(heatmap) + db.session.commit() + + return jsonify({'message': '热力图任务已创建', 'task': TaskService.serialize_task(task)}), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建热力图任务失败: {exc}', 500) + + +@task_bp.route('/heatmap//start', methods=['POST']) +@int_jwt_required +def start_heatmap_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + if not task.heatmap: + return TaskService.json_error('热力图任务未配置对应图片', 400) + + job_id = 
TaskService.start_heatmap_task(task_id, task.heatmap.images_id) + if not job_id: + return TaskService.json_error('任务启动失败', 500) + return jsonify({'message': '任务已启动', 'job_id': job_id}), 200 + + +@task_bp.route('/heatmap', methods=['GET']) +@int_jwt_required +def list_heatmap_tasks(current_user_id): + try: + heatmap_type = TaskService.require_task_type('heatmap') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + tasks = Task.query.filter_by(user_id=current_user_id, tasks_type_id=heatmap_type.task_type_id).order_by(Task.created_at.desc()).all() + return jsonify({'tasks': [TaskService.serialize_task(task) for task in tasks]}), 200 + + +@task_bp.route('/heatmap/', methods=['GET']) +@int_jwt_required +def get_heatmap_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + return jsonify({'task': TaskService.serialize_task(task)}), 200 + + +# ==================== 微调任务 ==================== + +@task_bp.route('/finetune/configs', methods=['GET']) +@int_jwt_required +def list_finetune_configs(current_user_id): + configs = FinetuneConfig.query.order_by(FinetuneConfig.finetune_configs_id).all() + return jsonify({'configs': [ + { + 'finetune_configs_id': cfg.finetune_configs_id, + 'finetune_code': cfg.finetune_code, + 'finetune_name': cfg.finetune_name, + 'description': cfg.description, + } + for cfg in configs + ]}), 200 + + +@task_bp.route('/finetune/from-perturbation', methods=['POST']) +@int_jwt_required +def create_finetune_from_perturbation(current_user_id): + data = request.get_json() or {} + perturbation_task_id = data.get('perturbation_task_id') + finetune_configs_id = data.get('finetune_configs_id') + + if not perturbation_task_id or not finetune_configs_id: + return TaskService.json_error('缺少必要参数: perturbation_task_id 或 finetune_configs_id') + + perturbation_task = 
TaskService.load_task_for_user(perturbation_task_id, current_user_id, expected_type='perturbation') + if not perturbation_task: + return TaskService.json_error('加噪任务不存在或无权限', 404) + + if not FinetuneConfig.query.get(finetune_configs_id): + return TaskService.json_error('微调配置不存在') + + try: + pending_status = TaskService.ensure_status('pending') + finetune_type = TaskService.require_task_type('finetune') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=perturbation_task.flow_id, + tasks_type_id=finetune_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + finetune = Finetune( + tasks_id=task.tasks_id, + finetune_configs_id=finetune_configs_id, + data_type_id=data.get('data_type_id'), + finetune_name=data.get('finetune_name') + ) + db.session.add(finetune) + db.session.commit() + + return jsonify({'message': '微调任务已创建', 'task': TaskService.serialize_task(task)}), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建微调任务失败: {exc}', 500) + + +@task_bp.route('/finetune/from-upload', methods=['POST']) +@int_jwt_required +def create_finetune_from_upload(current_user_id): + user = TaskService.get_user(current_user_id) + if not user: + return TaskService.json_error('用户不存在', 404) + + role_code = user.role.role_code if user.role else 'user' + if role_code not in ('vip', 'admin'): + return TaskService.json_error('仅限VIP或管理员使用上传微调功能', 403) + + data = request.get_json() or {} + finetune_configs_id = data.get('finetune_configs_id') + if not finetune_configs_id: + return TaskService.json_error('缺少必要参数: finetune_configs_id') + + if not FinetuneConfig.query.get(finetune_configs_id): + return TaskService.json_error('微调配置不存在') + + try: + flow_id = data.get('flow_id') + if flow_id is not None: + flow_id = int(flow_id) + existing = 
Task.query.filter_by(flow_id=flow_id).first() + if existing: + return TaskService.json_error('flow_id 已被占用,请勿复用已有任务流') + else: + flow_id = TaskService.generate_flow_id() + except Exception: + return TaskService.json_error('非法的 flow_id 参数') + + try: + pending_status = TaskService.ensure_status('pending') + finetune_type = TaskService.require_task_type('finetune') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=flow_id, + tasks_type_id=finetune_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + finetune = Finetune( + tasks_id=task.tasks_id, + finetune_configs_id=finetune_configs_id, + data_type_id=data.get('data_type_id'), + finetune_name=data.get('finetune_name') + ) + db.session.add(finetune) + db.session.commit() + + return jsonify({ + 'message': '上传微调任务已创建', + 'task': TaskService.serialize_task(task) + }), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建微调任务失败: {exc}', 500) + + +@task_bp.route('/finetune//start', methods=['POST']) +@int_jwt_required +def start_finetune_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + job_id = TaskService.start_finetune_task(task_id) + if not job_id: + return TaskService.json_error('任务启动失败', 500) + return jsonify({'message': '任务已启动', 'job_id': job_id}), 200 + + +@task_bp.route('/finetune', methods=['GET']) +@int_jwt_required +def list_finetune_tasks(current_user_id): + source = request.args.get('source') + try: + finetune_type = TaskService.require_task_type('finetune') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + query = Task.query.filter_by(user_id=current_user_id, tasks_type_id=finetune_type.task_type_id) + + tasks = 
query.order_by(Task.created_at.desc()).all() + serialized = [] + for task in tasks: + task_dict = TaskService.serialize_task(task) + if source and task_dict.get('finetune', {}).get('source') != source: + continue + serialized.append(task_dict) + return jsonify({'tasks': serialized}), 200 + + +@task_bp.route('/finetune/', methods=['GET']) +@int_jwt_required +def get_finetune_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + return jsonify({'task': TaskService.serialize_task(task)}), 200 + + +# ==================== 数值评估任务 ==================== + +@task_bp.route('/evaluate', methods=['POST']) +@int_jwt_required +def create_evaluate_task(current_user_id): + data = request.get_json() or {} + finetune_task_id = data.get('finetune_task_id') + if not finetune_task_id: + return TaskService.json_error('缺少必要参数: finetune_task_id') + + finetune_task = TaskService.load_task_for_user(finetune_task_id, current_user_id, expected_type='finetune') + if not finetune_task: + return TaskService.json_error('微调任务不存在或无权限', 404) + + # 仅允许基于加噪微调创建评估 + if TaskService.determine_finetune_source(finetune_task) != 'perturbation': + return TaskService.json_error('数值评估仅支持基于加噪任务的微调结果') + + if not finetune_task.finetune: + return TaskService.json_error('微调任务未配置详情', 400) + + try: + evaluate_type = TaskService.require_task_type('evaluate') + pending_status = TaskService.ensure_status('pending') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=finetune_task.flow_id, + tasks_type_id=evaluate_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + evaluate = Evaluate( + tasks_id=task.tasks_id, + finetune_configs_id=finetune_task.finetune.finetune_configs_id, + 
evaluate_name=data.get('evaluate_name') + ) + db.session.add(evaluate) + db.session.commit() + + return jsonify({'message': '评估任务已创建', 'task': TaskService.serialize_task(task)}), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建评估任务失败: {exc}', 500) + + +@task_bp.route('/evaluate//start', methods=['POST']) +@int_jwt_required +def start_evaluate_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + job_id = TaskService.start_evaluate_task(task_id) + if not job_id: + return TaskService.json_error('任务启动失败', 500) + return jsonify({'message': '任务已启动', 'job_id': job_id}), 200 + + +@task_bp.route('/evaluate', methods=['GET']) +@int_jwt_required +def list_evaluate_tasks(current_user_id): + try: + evaluate_type = TaskService.require_task_type('evaluate') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + tasks = Task.query.filter_by(user_id=current_user_id, tasks_type_id=evaluate_type.task_type_id).order_by(Task.created_at.desc()).all() + return jsonify({'tasks': [TaskService.serialize_task(task) for task in tasks]}), 200 + + +@task_bp.route('/evaluate/', methods=['GET']) +@int_jwt_required +def get_evaluate_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + return jsonify({'task': TaskService.serialize_task(task)}), 200 diff --git a/src/backend/app/controllers/user_controller.py b/src/backend/app/controllers/user_controller.py index 3d99fda..3325b9e 100644 --- a/src/backend/app/controllers/user_controller.py +++ b/src/backend/app/controllers/user_controller.py @@ -1,129 +1,119 @@ -""" -用户管理控制器 -处理用户配置等功能 -""" - -from flask import Blueprint, request, jsonify -from flask_jwt_extended import jwt_required -from app import db -from 
app.database import User, UserConfig, Perturbation, Finetune -from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器 - -user_bp = Blueprint('user', __name__) - -@user_bp.route('/config', methods=['GET']) -@int_jwt_required -def get_user_config(current_user_id): - """获取用户配置""" - try: - - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - if not user_config: - # 如果没有配置,创建默认配置 - user_config = UserConfig(user_id=current_user_id) - db.session.add(user_config) - db.session.commit() - - return jsonify({ - 'config': user_config.to_dict() - }), 200 - - except Exception as e: - return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500 - -@user_bp.route('/config', methods=['PUT']) -@int_jwt_required -def update_user_config(current_user_id): - """更新用户配置""" - try: - data = request.get_json() - - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - if not user_config: - user_config = UserConfig(user_id=current_user_id) - db.session.add(user_config) - - # 更新配置字段 - if 'perturbation_configs_id' in data: - user_config.perturbation_configs_id = data['perturbation_configs_id'] - - if 'perturbation_intensity' in data: - intensity = float(data['perturbation_intensity']) - if 0 < epsilon <= 255: - user_config.perturbation_intensity = intensity - else: - return jsonify({'error': '扰动强度必须在0-255之间'}), 400 - - if 'finetune_config_id' in data: - user_config.finetune_config_id = data['finetune_config_id'] - - db.session.commit() - - return jsonify({ - 'message': '用户配置更新成功', - 'config': user_config.to_dict() - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500 - -@user_bp.route('/algorithms', methods=['GET']) -@jwt_required() -def get_available_algorithms(): - """获取可用的算法列表""" - try: - perturbation_configs = Perturbation.query.all() - finetune_configs = Finetune.query.all() - - return jsonify({ - 'perturbation_algorithms': [ - { - 'id': config.id, - 'method_code': 
config.method_code, - 'method_name': config.method_name, - 'description': config.description, - } for config in perturbation_configs - ], - 'finetune_methods': [ - { - 'id': config.id, - 'method_code': config.method_code, - 'method_name': config.method_name, - 'description': config.description - } for config in finetune_configs - ] - }), 200 - - except Exception as e: - return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500 - -@user_bp.route('/stats', methods=['GET']) -@int_jwt_required -def get_user_stats(current_user_id): - """获取用户统计信息""" - try: - from app.database import Task, Image - - # 统计用户的任务和图片数量 - total_tasks = Task.query.filter_by(user_id=current_user_id).count() - completed_tasks = Task.query.filter_by(user_id=current_user_id, status='completed').count() - processing_tasks = Task.query.filter_by(user_id=current_user_id, status='processing').count() - failed_tasks = Task.query.filter_by(user_id=current_user_id, status='failed').count() - - total_images = Image.query.join(Task, Image.task_id == Task.id).filter(Task.user_id == current_user_id).count() - - return jsonify({ - 'stats': { - 'total_tasks': total_tasks, - 'completed_tasks': completed_tasks, - 'processing_tasks': processing_tasks, - 'failed_tasks': failed_tasks, - 'total_images': total_images - } - }), 200 - - except Exception as e: - return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500 \ No newline at end of file + +""" +用户管理控制器 +负责用户配置、任务汇总等接口 +""" + +from flask import Blueprint, request, jsonify +from app import db +from app.controllers.auth_controller import int_jwt_required +from app.database import UserConfig, Task, TaskType, TaskStatus + + +user_bp = Blueprint('user', __name__) + + +def _json_error(message, status_code=400): + return jsonify({'error': message}), status_code + + +def _get_or_create_user_config(user_id): + config = UserConfig.query.filter_by(user_id=user_id).first() + if not config: + config = UserConfig(user_id=user_id) + db.session.add(config) + db.session.commit() + return 
config + + +def _serialize_config(config): + return { + 'user_configs_id': config.user_configs_id, + 'user_id': config.user_id, + 'data_type_id': config.data_type_id, + 'perturbation_configs_id': config.perturbation_configs_id, + 'perturbation_intensity': config.perturbation_intensity, + 'finetune_configs_id': config.finetune_configs_id, + 'created_at': config.created_at.isoformat() if config.created_at else None, + 'updated_at': config.updated_at.isoformat() if config.updated_at else None, + } + + +def _serialize_task(task): + status_code = task.task_status.task_status_code if task.task_status else None + task_type_code = task.task_type.task_type_code if task.task_type else None + return { + 'task_id': task.tasks_id, + 'flow_id': task.flow_id, + 'task_type': task_type_code, + 'status': status_code, + 'created_at': task.created_at.isoformat() if task.created_at else None, + 'started_at': task.started_at.isoformat() if task.started_at else None, + 'finished_at': task.finished_at.isoformat() if task.finished_at else None, + 'description': task.description, + 'error_message': task.error_message + } + + +@user_bp.route('/config', methods=['GET']) +@int_jwt_required +def get_user_config(current_user_id): + config = _get_or_create_user_config(current_user_id) + return jsonify({'config': _serialize_config(config)}), 200 + + +@user_bp.route('/config', methods=['PUT']) +@int_jwt_required +def update_user_config(current_user_id): + config = _get_or_create_user_config(current_user_id) + data = request.get_json() or {} + + allowed_fields = {'data_type_id', 'perturbation_configs_id', 'perturbation_intensity', 'finetune_configs_id'} + for key, value in data.items(): + if key in allowed_fields: + if key == 'perturbation_intensity' and value is not None: + try: + value = float(value) + except (TypeError, ValueError): + return _json_error('perturbation_intensity 参数格式不正确') + setattr(config, key, value) + + try: + db.session.commit() + return jsonify({'message': '配置已更新', 'config': 
_serialize_config(config)}), 200 + except Exception as exc: + db.session.rollback() + return _json_error(f'更新配置失败: {exc}', 500) + + +@user_bp.route('/tasks', methods=['GET']) +@int_jwt_required +def list_user_tasks(current_user_id): + task_type_code = request.args.get('type') + status_code = request.args.get('status') + + query = Task.query.filter_by(user_id=current_user_id) + + if task_type_code: + task_type = TaskType.query.filter_by(task_type_code=task_type_code).first() + if not task_type: + return _json_error('任务类型不存在', 404) + query = query.filter(Task.tasks_type_id == task_type.task_type_id) + + if status_code: + status = TaskStatus.query.filter_by(task_status_code=status_code).first() + if not status: + return _json_error('任务状态不存在', 404) + query = query.filter(Task.tasks_status_id == status.task_status_id) + + tasks = query.order_by(Task.created_at.desc()).all() + return jsonify({'tasks': [_serialize_task(task) for task in tasks]}), 200 + + +@user_bp.route('/tasks/', methods=['GET']) +@int_jwt_required +def get_user_task(task_id, current_user_id): + task = Task.query.filter_by(tasks_id=task_id, user_id=current_user_id).first() + if not task: + return _json_error('任务不存在或无权限', 404) + return jsonify({'task': _serialize_task(task)}), 200 diff --git a/src/backend/app/services/image_service.py b/src/backend/app/services/image_service.py index 5c64d9e..933ad78 100644 --- a/src/backend/app/services/image_service.py +++ b/src/backend/app/services/image_service.py @@ -3,16 +3,18 @@ 处理图像上传、保存等功能 """ +import io import os import uuid import zipfile import fcntl import time +from datetime import datetime from werkzeug.utils import secure_filename -from flask import current_app +from flask import current_app, jsonify from PIL import Image as PILImage from app import db -from app.database import Image +from app.database import Image, ImageType from app.utils.file_utils import allowed_file class ImageService: @@ -254,4 +256,173 @@ class ImageService: except Exception as e: 
db.session.rollback() - return {'success': False, 'error': f'删除图片失败: {str(e)}'} \ No newline at end of file + return {'success': False, 'error': f'删除图片失败: {str(e)}'} + + # ==================== 控制器辅助功能 ==================== + + DEFAULT_TARGET_SIZE = 512 + IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp'} + + @staticmethod + def json_error(message, status_code=400): + """统一错误响应""" + return jsonify({'error': message}), status_code + + @staticmethod + def get_image_type_by_code(code): + """根据代码获取图片类型""" + return ImageType.query.filter_by(image_code=code).first() + + @staticmethod + def save_original_images(task, files, target_dir, image_type_code='original', target_size=None): + """保存原图上传""" + if not files: + return False, '未检测到文件上传' + + image_type = ImageService.get_image_type_by_code(image_type_code) + if not image_type: + return False, f'未配置图片类型: {image_type_code}' + + os.makedirs(target_dir, exist_ok=True) + + saved_records = [] + saved_paths = [] + size = target_size or ImageService.DEFAULT_TARGET_SIZE + + try: + for file in files: + if not file or not file.filename: + continue + if not allowed_file(file.filename): + continue + + extension = os.path.splitext(file.filename)[1].lower() + if extension not in ImageService.IMAGE_EXTENSIONS: + continue + + processed = ImageService._prepare_image(file, size) + filename, path, width, height, file_size = ImageService._save_processed_image(processed, target_dir) + image = ImageService._create_image_record( + task, + image_type.image_types_id, + filename, + path, + width, + height, + file_size + ) + saved_records.append(image) + saved_paths.append(path) + + if not saved_records: + db.session.rollback() + return False, '未上传有效的图片文件' + + db.session.commit() + return True, saved_records + except Exception as exc: + db.session.rollback() + for path in saved_paths: + if os.path.exists(path): + try: + os.remove(path) + except OSError: + pass + return False, f'上传图片失败: {exc}' + + @staticmethod + def _prepare_image(file_storage, 
target_size): + """裁剪并缩放上传图片""" + file_storage.stream.seek(0) + image = PILImage.open(file_storage.stream).convert('RGB') + width, height = image.size + min_dim = min(width, height) + left = (width - min_dim) // 2 + top = (height - min_dim) // 2 + image = image.crop((left, top, left + min_dim, top + min_dim)) + return image.resize((target_size, target_size), resample=PILImage.Resampling.LANCZOS) + + @staticmethod + def _save_processed_image(image, target_dir): + """将处理后的图片保存为PNG""" + timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f') + filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png" + path = os.path.join(target_dir, filename) + image.save(path, format='PNG') + return filename, path, image.width, image.height, os.path.getsize(path) + + @staticmethod + def _create_image_record(task, image_type_id, filename, path, width, height, file_size, father_id=None): + """创建图片数据库记录""" + image = Image( + task_id=task.tasks_id, + image_types_id=image_type_id, + father_id=father_id, + stored_filename=filename, + file_path=path, + file_size=file_size, + width=width, + height=height + ) + db.session.add(image) + return image + + @staticmethod + def zip_directory(directory): + """打包目录为zip""" + buffer = io.BytesIO() + has_files = False + + with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf: + if os.path.isdir(directory): + for root, _, files in os.walk(directory): + for filename in files: + file_path = os.path.join(root, filename) + arcname = os.path.relpath(file_path, directory) + zipf.write(file_path, arcname) + has_files = True + + buffer.seek(0) + return buffer, has_files + + @staticmethod + def zip_multiple_directories(directories): + """打包多个目录""" + buffer = io.BytesIO() + has_files = False + + with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf: + if isinstance(directories, dict): + iterable = directories.items() + else: + iterable = ((os.path.basename(d.rstrip(os.sep)) or 'output', d) for d in directories) + + for label, directory in iterable: 
+ if not os.path.isdir(directory): + continue + for root, _, files in os.walk(directory): + for filename in files: + file_path = os.path.join(root, filename) + rel_path = os.path.relpath(file_path, directory) + arcname = os.path.join(label or 'output', rel_path) + zipf.write(file_path, arcname) + has_files = True + + buffer.seek(0) + return buffer, has_files + + @staticmethod + def serialize_image(image): + """图片序列化""" + if not image: + return None + return { + 'image_id': image.images_id, + 'task_id': image.task_id, + 'stored_filename': image.stored_filename, + 'file_path': image.file_path, + 'file_size': image.file_size, + 'width': image.width, + 'height': image.height, + 'image_type': image.image_type.image_code if image.image_type else None + } \ No newline at end of file diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index d60f1ec..80384ef 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -7,7 +7,7 @@ import os import logging from datetime import datetime -from flask import current_app +from flask import current_app, jsonify from redis import Redis from rq import Queue from rq.job import Job @@ -16,7 +16,7 @@ from app.database import ( Task, TaskStatus, TaskType, Perturbation, Finetune, Heatmap, Evaluate, Image, ImageType, DataType, - PerturbationConfig, FinetuneConfig + PerturbationConfig, FinetuneConfig, User ) from config.algorithm_config import AlgorithmConfig from config.settings import Config @@ -116,6 +116,135 @@ class TaskService: str(flow_id) ) + # ==================== 通用辅助功能 ==================== + + @staticmethod + def json_error(message, status_code=400): + """统一的错误响应""" + return jsonify({'error': message}), status_code + + @staticmethod + def get_task_type(code): + """根据任务类型代码获取TaskType""" + return TaskType.query.filter_by(task_type_code=code).first() + + @staticmethod + def require_task_type(code): + """确保任务类型存在""" + task_type = 
TaskService.get_task_type(code) + if not task_type: + raise ValueError(f"Task type '{code}' is not configured") + return task_type + + @staticmethod + def get_status_by_code(code): + """根据状态代码获取TaskStatus""" + return TaskStatus.query.filter_by(task_status_code=code).first() + + @staticmethod + def ensure_status(code): + """确保任务状态存在""" + status = TaskService.get_status_by_code(code) + if not status: + raise ValueError(f"Task status '{code}' is not configured") + return status + + @staticmethod + def generate_flow_id(): + """生成唯一的flow_id""" + base = int(datetime.utcnow().timestamp() * 1000) + while Task.query.filter_by(flow_id=base).first(): + base += 1 + return base + + @staticmethod + def ensure_task_owner(task, user_id): + """验证任务归属""" + return bool(task and task.user_id == user_id) + + @staticmethod + def get_task_type_code(task): + """获取任务类型代码""" + return task.task_type.task_type_code if task and task.task_type else None + + @staticmethod + def load_task_for_user(task_id, user_id, expected_type=None): + """根据任务ID加载用户的任务,可选检查类型""" + task = Task.query.get(task_id) + if not TaskService.ensure_task_owner(task, user_id): + return None + if expected_type: + task_type = TaskService.get_task_type_code(task) + if task_type != expected_type: + return None + return task + + @staticmethod + def determine_finetune_source(finetune_task): + """判断微调任务来源""" + perturb_type = TaskService.require_task_type('perturbation') + sibling_perturbation = Task.query.filter( + Task.flow_id == finetune_task.flow_id, + Task.tasks_type_id == perturb_type.task_type_id, + Task.tasks_id != finetune_task.tasks_id + ).first() + return 'perturbation' if sibling_perturbation else 'uploaded' + + @staticmethod + def serialize_task(task): + """任务序列化""" + task_type = TaskService.get_task_type_code(task) + status = task.task_status.task_status_code if task and task.task_status else None + base = { + 'task_id': task.tasks_id, + 'flow_id': task.flow_id, + 'task_type': task_type, + 'status': status, + 
'user_id': task.user_id, + 'description': task.description, + 'created_at': task.created_at.isoformat() if task.created_at else None, + 'started_at': task.started_at.isoformat() if task.started_at else None, + 'finished_at': task.finished_at.isoformat() if task.finished_at else None, + 'error_message': task.error_message, + } + + if task_type == 'perturbation' and task.perturbation: + base['perturbation'] = { + 'data_type_id': task.perturbation.data_type_id, + 'perturbation_configs_id': task.perturbation.perturbation_configs_id, + 'perturbation_intensity': float(task.perturbation.perturbation_intensity), + 'perturbation_name': task.perturbation.perturbation_name, + } + elif task_type == 'finetune' and task.finetune: + try: + source = TaskService.determine_finetune_source(task) + except ValueError: + source = 'uploaded' + base['finetune'] = { + 'finetune_configs_id': task.finetune.finetune_configs_id, + 'data_type_id': task.finetune.data_type_id, + 'finetune_name': task.finetune.finetune_name, + 'source': source + } + elif task_type == 'heatmap' and task.heatmap: + base['heatmap'] = { + 'perturbed_image_id': task.heatmap.images_id, + 'heatmap_name': task.heatmap.heatmap_name + } + elif task_type == 'evaluate' and task.evaluation: + base['evaluate'] = { + 'finetune_configs_id': task.evaluation.finetune_configs_id, + 'evaluate_name': task.evaluation.evaluate_name, + 'evaluation_results_id': task.evaluation.evaluation_results_id + } + + return base + + @staticmethod + def get_user(user_id): + """获取用户""" + return User.query.get(user_id) + # ==================== Redis/RQ 连接管理 ==================== @staticmethod -- 2.34.1 From e587b29b2f1fcb7c14ffbbe762e516f44a56cecf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Mon, 1 Dec 2025 10:52:36 +0800 Subject: [PATCH 14/14] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- src/backend/app/controllers/admin_controller.py | 2 +- src/backend/init_db.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/backend/app/controllers/admin_controller.py b/src/backend/app/controllers/admin_controller.py index e9893f2..e0fd396 100644 --- a/src/backend/app/controllers/admin_controller.py +++ b/src/backend/app/controllers/admin_controller.py @@ -230,7 +230,7 @@ def get_system_stats(): 'total': total_tasks, 'completed': completed_tasks, 'processing': processing_tasks, - 'failed': failed_tasks + 'failed': failed_tasks, 'waiting': waiting_tasks }, 'images': { diff --git a/src/backend/init_db.py b/src/backend/init_db.py index ca1c814..0db5d42 100644 --- a/src/backend/init_db.py +++ b/src/backend/init_db.py @@ -15,9 +15,9 @@ def init_database(): # 初始化角色数据 roles = [ - {'role_id': 0, 'role_code': 'admin', 'name': '管理员', 'max_concurrent_tasks': 15, 'description': '系统管理员,拥有最高权限'}, - {'role_id': 1, 'role_code': 'vip', 'name': 'VIP用户', 'max_concurrent_tasks': 10, 'description': '付费用户,享有较高的资源使用权限'}, - {'role_id': 2, 'role_code': 'normal', 'name': '普通用户', 'max_concurrent_tasks': 5, 'description': '免费用户,享有基本的资源使用权限'} + {'role_id': 1, 'role_code': 'admin', 'name': '管理员', 'max_concurrent_tasks': 15, 'description': '系统管理员,拥有最高权限'}, + {'role_id': 2, 'role_code': 'vip', 'name': 'VIP用户', 'max_concurrent_tasks': 10, 'description': '付费用户,享有较高的资源使用权限'}, + {'role_id': 3, 'role_code': 'normal', 'name': '普通用户', 'max_concurrent_tasks': 5, 'description': '免费用户,享有基本的资源使用权限'} ] for role_data in roles: existing = Role.query.filter_by(role_id=role_data['role_id']).first() @@ -108,9 +108,9 @@ def init_database(): # 创建默认测试用户(三种角色各一个) test_users = [ - {'username': 'admin_test', 'email': 'admin@test.com', 'password': 'admin123', 'role_id': 0}, - {'username': 'vip_test', 'email': 'vip@test.com', 'password': 'vip123', 'role_id': 1}, - {'username': 'normal_test', 'email': 'normal@test.com', 'password': 'normal123', 
'role_id': 2} + {'username': 'admin_test', 'email': 'admin@test.com', 'password': 'admin123', 'role_id': 1}, + {'username': 'vip_test', 'email': 'vip@test.com', 'password': 'vip123', 'role_id': 2}, + {'username': 'normal_test', 'email': 'normal@test.com', 'password': 'normal123', 'role_id': 3} ] for user_data in test_users: -- 2.34.1