From acc0f10ca451f1ecdb09b4ef529d8abb82395e05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 13 Dec 2025 01:06:53 +0800 Subject: [PATCH 1/8] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E8=87=AA=E5=AE=9A=E4=B9=89=E6=8F=90=E7=A4=BA=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/controllers/task_controller.py | 6 ++++-- src/backend/app/database/__init__.py | 1 + src/backend/app/services/task_service.py | 1 + src/backend/app/workers/finetune_worker.py | 16 +++++++++++----- src/backend/config/algorithm_config.py | 6 +++--- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index 0a30e7e..9566a69 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -449,7 +449,8 @@ def create_finetune_from_perturbation(current_user_id): tasks_id=task.tasks_id, finetune_configs_id=finetune_configs_id, data_type_id=data.get('data_type_id'), - finetune_name=data.get('finetune_name') + finetune_name=data.get('finetune_name'), + custom_prompt=data.get('custom_prompt') ) db.session.add(finetune) db.session.commit() @@ -536,7 +537,8 @@ def create_finetune_from_upload(current_user_id): tasks_id=task.tasks_id, finetune_configs_id=finetune_configs_id, data_type_id=data_type_id, - finetune_name=data.get('finetune_name') + finetune_name=data.get('finetune_name'), + custom_prompt=data.get('custom_prompt') ) db.session.add(finetune) db.session.commit() diff --git a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py index aedb9e8..8331e3e 100644 --- a/src/backend/app/database/__init__.py +++ b/src/backend/app/database/__init__.py @@ -245,6 +245,7 @@ class Finetune(db.Model): finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), nullable=False, comment='微调配置ID') data_type_id = db.Column(Integer, ForeignKey('data_type.data_type_id'), default=None, comment='微调所用数据集') finetune_name = db.Column(String(100), comment='微调任务名称') + custom_prompt = db.Column(String(255), comment='用户自定义提示词(如 a photo of sks cat)') task = db.relationship('Task', back_populates='finetune') finetune_config = db.relationship('FinetuneConfig') diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index 4fdd4fc..45aaa28 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -231,6 +231,7 @@ class TaskService: 'finetune_configs_id': task.finetune.finetune_configs_id, 'data_type_id': task.finetune.data_type_id, 'finetune_name': task.finetune.finetune_name, + 'custom_prompt': task.finetune.custom_prompt, 'source': source } elif task_type == 'heatmap' and task.heatmap: diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index 57eef16..cf989fd 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -66,20 +66,26 @@ def run_finetune_task(task_id, finetune_method, train_images_dir, logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}") - # 从数据库获取数据集类型的提示词 - # 从Finetune表的data_type_id获取 + # 从数据库获取提示词 + # 优先级:用户自定义 custom_prompt > 数据集类型 data_type_id > 默认值 instance_prompt = "a photo of sks person" # 默认值 class_prompt = "a photo of person" # 默认值 validation_prompt = "a photo of sks person" # 默认值 - + initializer_token = "person" # 默认初始化词 + + if finetune.custom_prompt: + # 使用用户自定义提示词 (例如: "cat") -> "a photo of sks cat" + validation_prompt = f"a photo of sks {finetune.custom_prompt}" + if finetune.data_type_id: data_type = DataType.query.get(finetune.data_type_id) if data_type and data_type.data_type_prompt: instance_prompt = data_type.data_type_prompt validation_prompt = instance_prompt # 从instance_prompt生成class_prompt(移除"sks") - class_prompt = instance_prompt.replace('sks ', '') - logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'") + class_prompt = instance_prompt + + logger.info(f"Using prompts - instance: '{instance_prompt}', class: '{class_prompt}', validation: '{validation_prompt}'") # 彻底清空输出目录(避免旧文件残留,特别是 textual_inversion 的 token) logger.info(f"Completely clearing output directories...") diff --git a/src/backend/config/algorithm_config.py b/src/backend/config/algorithm_config.py index 17fb2f5..072f6b5 100644 --- a/src/backend/config/algorithm_config.py +++ b/src/backend/config/algorithm_config.py @@ -238,9 +238,9 @@ class AlgorithmConfig: 'conda_env': CONDA_ENVS['textual_inversion'], 'default_params': { 'pretrained_model_name_or_path': MODELS_DIR['model2'], - 'placeholder_token': 'sks', + 'placeholder_token': '', 'initializer_token': 'person', - 'instance_prompt': 'a photo of sks person', + 'instance_prompt': 'a photo of person', 'resolution': 512, 'train_batch_size': 1, 'gradient_accumulation_steps': 1, @@ -251,7 +251,7 @@ class AlgorithmConfig: 'checkpointing_steps': 2, 'seed': 0, 'mixed_precision': 'fp16', - 'validation_prompt': 'a photo of sks person', + 'validation_prompt': 'a photo of person', 'num_validation_images': 4, 'validation_epochs': 50, 'coords_log_interval': 10 -- 2.34.1 From f59d14bdf8031c6974e14f77b61acb4aa8ca1c8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 13 Dec 2025 01:50:13 +0800 Subject: [PATCH 2/8] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0prompt=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/database/__init__.py | 5 +- src/backend/app/workers/finetune_worker.py | 67 ++++++++++++++++------ src/backend/init_db.py | 18 +++++- 3 files changed, 68 insertions(+), 22 deletions(-) diff --git a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py index 8331e3e..300610c 100644 --- a/src/backend/app/database/__init__.py +++ b/src/backend/app/database/__init__.py @@ -98,7 +98,10 @@ class DataType(db.Model): __tablename__ = 'data_type' data_type_id = db.Column(Integer, primary_key=True) data_type_code = db.Column(String(50), nullable=False) - data_type_prompt = db.Column(Text, comment='数据集相关的Prompt') + instance_prompt = db.Column(Text, comment='数据集相关的Prompt (Instance Prompt Template, e.g. "a photo of sks person")') + class_prompt = db.Column(String(255), comment='类别Prompt (e.g. "a photo of person")') + placeholder_token = db.Column(String(50), comment='TI Placeholder (e.g. "")') + initializer_token = db.Column(String(50), comment='TI Initializer (e.g. "person")') description = db.Column(Text) def __repr__(self): diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index cf989fd..5738f1e 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -66,26 +66,55 @@ def run_finetune_task(task_id, finetune_method, train_images_dir, logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}") - # 从数据库获取提示词 - # 优先级:用户自定义 custom_prompt > 数据集类型 data_type_id > 默认值 - instance_prompt = "a photo of sks person" # 默认值 - class_prompt = "a photo of person" # 默认值 - validation_prompt = "a photo of sks person" # 默认值 - initializer_token = "person" # 默认初始化词 + # 获取 DataType 配置 + data_type = DataType.query.get(finetune.data_type_id) if finetune.data_type_id else None + + # 默认值 (Fallback) + instance_prompt = "a photo of sks person" + class_prompt = "a photo of person" + placeholder_token = "" + initializer_token = "person" + + if data_type: + if data_type.instance_prompt: + instance_prompt = data_type.instance_prompt + if data_type.class_prompt: + class_prompt = data_type.class_prompt + if data_type.placeholder_token: + placeholder_token = data_type.placeholder_token + if data_type.initializer_token: + initializer_token = data_type.initializer_token + + logger.info(f"DataType Config - Template: '{instance_prompt}', Class: '{class_prompt}'") - if finetune.custom_prompt: - # 使用用户自定义提示词 (例如: "cat") -> "a photo of sks cat" - validation_prompt = f"a photo of sks {finetune.custom_prompt}" - - if finetune.data_type_id: - data_type = DataType.query.get(finetune.data_type_id) - if data_type and data_type.data_type_prompt: - instance_prompt = data_type.data_type_prompt - validation_prompt = instance_prompt - # 从instance_prompt生成class_prompt(移除"sks") - class_prompt = instance_prompt - - logger.info(f"Using prompts - instance: '{instance_prompt}', class: '{class_prompt}', validation: '{validation_prompt}'") + # 根据微调方法调整 Instance Prompt + if finetune_method == 'textual_inversion': + # TI: 将 'sks' 替换为 placeholder_token + instance_prompt_prefix = instance_prompt.replace('sks', placeholder_token) + else: + # DreamBooth/LoRA: 直接使用模板 + instance_prompt_prefix = instance_prompt + + # 处理 Validation Prompt (拼接后缀) + prompt_suffix = finetune.custom_prompt.strip() if finetune.custom_prompt else "" + + if prompt_suffix: + validation_prompt = f"{instance_prompt_prefix}, {prompt_suffix}" + else: + validation_prompt = instance_prompt_prefix + + instance_prompt = instance_prompt_prefix + + logger.info(f"Prompts Finalized - Instance: '{instance_prompt}', Class: '{class_prompt}', Validation: '{validation_prompt}'") + + # 设置 TI 特有参数 + if custom_params is None: + custom_params = {} + + if finetune_method == 'textual_inversion': + custom_params['placeholder_token'] = placeholder_token + custom_params['initializer_token'] = initializer_token + logger.info(f"TI Tokens - Placeholder: '{placeholder_token}', Initializer: '{initializer_token}'") # 彻底清空输出目录(避免旧文件残留,特别是 textual_inversion 的 token) logger.info(f"Completely clearing output directories...") diff --git a/src/backend/init_db.py b/src/backend/init_db.py index eecfdc5..7e70583 100644 --- a/src/backend/init_db.py +++ b/src/backend/init_db.py @@ -85,8 +85,22 @@ def init_database(): # 初始化数据集类型数据 data_types = [ - {'data_type_code': 'facial', 'data_type_prompt': 'a photo of sks person', 'description': '人脸类型的数据集'}, - {'data_type_code': 'art', 'data_type_prompt': 'a painting in the style of sks', 'description': '艺术品类型的数据集'} + { + 'data_type_code': 'face', + 'instance_prompt': 'a photo of sks person', + 'class_prompt': 'a photo of person', + 'placeholder_token': '', + 'initializer_token': 'person', + 'description': '人脸类型的数据集' + }, + { + 'data_type_code': 'art', + 'instance_prompt': 'a painting in sks style', + 'class_prompt': 'a painting', + 'placeholder_token': '', + 'initializer_token': 'painting', + 'description': '艺术品类型的数据集' + } ] for data_type in data_types: existing = DataType.query.filter_by(data_type_code=data_type['data_type_code']).first() -- 2.34.1 From d9781182dd32fba8299c85792314d9b5055185f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 13 Dec 2025 02:34:58 +0800 Subject: [PATCH 3/8] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=AF=B9?= =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E5=9B=BE=E7=89=87=E7=9A=84=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E7=9A=84=E6=96=87=E4=BB=B6=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/controllers/task_controller.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index 9566a69..2c5c94d 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -174,6 +174,19 @@ def create_perturbation_task(current_user_id): if not PerturbationConfig.query.get(perturbation_configs_id): return TaskService.json_error('加噪配置不存在') + # 验证上传的图片 + files = request.files.getlist('files') if hasattr(request, 'files') else [] + if not files: + return TaskService.json_error('请上传至少一张图片') + + allowed_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'} + for file in files: + if not file.filename: + continue + _, ext = os.path.splitext(file.filename) + if ext.lower() not in allowed_extensions: + return TaskService.json_error(f'不支持的文件格式: {file.filename}。仅支持图片格式。') + try: flow_id = data.get('flow_id') flow_id = int(flow_id) if flow_id is not None else TaskService.generate_flow_id() @@ -208,7 +221,6 @@ def create_perturbation_task(current_user_id): db.session.commit() # 自动上传图片 - files = request.files.getlist('files') if hasattr(request, 'files') else [] target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id) success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, []) @@ -504,6 +516,19 @@ def create_finetune_from_upload(current_user_id): if not data_type: return TaskService.json_error('数据集类型不存在') + # 验证上传的图片 + files = request.files.getlist('files') if hasattr(request, 'files') else [] + if not files: + return TaskService.json_error('请上传至少一张图片') + + allowed_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'} + for file in files: + if not file.filename: + continue + _, ext = os.path.splitext(file.filename) + if ext.lower() not in allowed_extensions: + return TaskService.json_error(f'不支持的文件格式: {file.filename}。仅支持图片格式。') + try: flow_id = data.get('flow_id') if flow_id is not None: @@ -544,7 +569,6 @@ def create_finetune_from_upload(current_user_id): db.session.commit() # 自动上传图片(仅上传微调任务) - files = request.files.getlist('files') if hasattr(request, 'files') else [] target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id) success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, []) -- 2.34.1 From c0554206692448ebc7c54f6820857008abcfdb5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 13 Dec 2025 02:37:27 +0800 Subject: [PATCH 4/8] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=8A=A0=E5=99=AA?= =?UTF-8?q?=E4=BB=BB=E5=8A=A1=E5=AD=97=E6=AE=B5=E9=94=99=E8=AF=AF=E5=92=8C?= =?UTF-8?q?=E5=BE=AE=E8=B0=83=E6=89=A7=E8=A1=8C=E8=84=9A=E6=9C=AC=E7=89=B9?= =?UTF-8?q?=E6=AE=8A=E7=AC=A6=E5=8F=B7=E9=94=99=E8=AF=AF=E8=A7=A3=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/workers/finetune_worker.py | 7 ++++++- src/backend/app/workers/perturbation_worker.py | 7 +++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index 5738f1e..713089b 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -286,7 +286,12 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_ if value: cmd_args.append(f"--{key}") else: - cmd_args.append(f"--{key}={value}") + # 特殊处理包含空格或特殊字符的参数,防止 conda run 的 shell 脚本解析错误 + # 特别是 < > 符号会被 shell 解释为重定向 + if isinstance(value, str) and any(c in value for c in [' ', '<', '>', '&', '|', ';', '(', ')']): + cmd_args.append(f"--{key}='{value}'") + else: + cmd_args.append(f"--{key}={value}") # 设置环境变量 env = os.environ.copy() diff --git a/src/backend/app/workers/perturbation_worker.py b/src/backend/app/workers/perturbation_worker.py index 1850920..a88a60b 100644 --- a/src/backend/app/workers/perturbation_worker.py +++ b/src/backend/app/workers/perturbation_worker.py @@ -169,10 +169,9 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id, if not data_type: raise ValueError(f"DataType {perturbation.data_type_id} not found") - # 从data_type_prompt中提取instance_prompt - instance_prompt = data_type.data_type_prompt or 'a photo of sks person' - # 从instance_prompt生成class_prompt(移除"sks") - class_prompt = instance_prompt.replace('sks ', '') + # 从数据库获取数据集类型的提示词 + instance_prompt = data_type.instance_prompt or 'a photo of sks person' + class_prompt = data_type.class_prompt or 'a photo of person' logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'") -- 2.34.1 From b4a6ab47975bd004ea4b740e56079cf2170baf04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 13 Dec 2025 09:36:47 +0800 Subject: [PATCH 5/8] =?UTF-8?q?feat:=20=E5=88=9B=E5=BB=BA=E5=BE=AE?= =?UTF-8?q?=E8=B0=83=E4=BB=BB=E5=8A=A1=E6=97=B6=EF=BC=8C=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E7=B1=BB=E5=9E=8B=E4=B8=BA=E5=AF=B9=E5=BA=94?= =?UTF-8?q?=E5=8A=A0=E5=99=AA=E4=BB=BB=E5=8A=A1=E9=80=89=E6=8B=A9=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/controllers/task_controller.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index 2c5c94d..8da27a1 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -440,6 +440,14 @@ def create_finetune_from_perturbation(current_user_id): if not FinetuneConfig.query.get(finetune_configs_id): return TaskService.json_error('微调配置不存在') + # 确定 data_type_id:优先使用用户输入,否则继承自加噪任务 + data_type_id = data.get('data_type_id') + if not data_type_id and perturbation_task.perturbation: + data_type_id = perturbation_task.perturbation.data_type_id + + if data_type_id and not DataType.query.get(data_type_id): + return TaskService.json_error('数据集类型不存在') + try: waiting_status = TaskService.ensure_status('waiting') finetune_type = TaskService.require_task_type('finetune') @@ -460,7 +468,7 @@ def create_finetune_from_perturbation(current_user_id): finetune = Finetune( tasks_id=task.tasks_id, finetune_configs_id=finetune_configs_id, - data_type_id=data.get('data_type_id'), + data_type_id=data_type_id, finetune_name=data.get('finetune_name'), custom_prompt=data.get('custom_prompt') ) -- 2.34.1 From 53e62ed5cd4d35a480b84b7242413978f03a6baf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 13 Dec 2025 09:38:01 +0800 Subject: [PATCH 6/8] =?UTF-8?q?fix:=20=E5=BC=80=E5=A7=8B=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E6=97=B6=E5=B0=86=E4=BB=BB=E5=8A=A1=E7=8A=B6=E6=80=81=E7=BD=AE?= =?UTF-8?q?=E4=B8=BA'waiting'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/services/task_service.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index 45aaa28..9e6bb13 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -404,6 +404,12 @@ class TaskService: if not perturbation: logger.error(f"Perturbation task {task_id} not found") return None + + # 更新任务状态为 waiting + waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if waiting_status: + task.tasks_status_id = waiting_status.task_status_id + db.session.commit() # 获取用户ID user_id = task.user_id @@ -481,6 +487,12 @@ class TaskService: if not finetune: logger.error(f"Finetune task {task_id} not found") return None + + # 更新任务状态为 waiting + waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if waiting_status: + task.tasks_status_id = waiting_status.task_status_id + db.session.commit() # 获取用户ID user_id = task.user_id @@ -642,6 +654,12 @@ class TaskService: if not heatmap: logger.error(f"Heatmap task {task_id} not found") return None + + # 更新任务状态为 waiting + waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if waiting_status: + task.tasks_status_id = waiting_status.task_status_id + db.session.commit() # 从heatmap对象获取扰动图片ID perturbed_image_id = heatmap.images_id @@ -730,6 +748,12 @@ class TaskService: if not evaluate: logger.error(f"Evaluate task {task_id} not found") return None + + # 更新任务状态为 waiting + waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if waiting_status: + task.tasks_status_id = waiting_status.task_status_id + db.session.commit() finetune = Finetune.query.get(evaluate.finetune_task_id) if not finetune: -- 2.34.1 From 354eafade2bcaaf9c84b9201a2a715d305d55c50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 13 Dec 2025 09:38:51 +0800 Subject: [PATCH 7/8] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E4=BC=A0=E5=85=A5?= =?UTF-8?q?=E5=8A=A0=E5=99=AA=E7=AE=97=E6=B3=95=E7=9A=84=E5=99=AA=E5=A3=B0?= =?UTF-8?q?=E5=BC=BA=E5=BA=A6=E7=9A=84=E6=95=B0=E6=8D=AE=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/workers/perturbation_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/backend/app/workers/perturbation_worker.py b/src/backend/app/workers/perturbation_worker.py index a88a60b..7578af4 100644 --- a/src/backend/app/workers/perturbation_worker.py +++ b/src/backend/app/workers/perturbation_worker.py @@ -204,14 +204,14 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id, cmd_args.extend([ f"--instance_data_dir={input_dir}", f"--output_dir={output_dir}", - f"--eps={int(epsilon)}", + f"--eps={float(epsilon)}", ]) elif algorithm_code == 'pid': # PID参数结构 cmd_args.extend([ f"--instance_data_dir={input_dir}", f"--output_dir={output_dir}", - f"--eps={int(epsilon)}", + f"--eps={float(epsilon)}", ]) elif algorithm_code == 'glaze': # Glaze参数结构 -- 2.34.1 From 8ded86686757b1411949d2b314106a8886dd7c49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 13 Dec 2025 10:15:29 +0800 Subject: [PATCH 8/8] =?UTF-8?q?docs:=20=E6=9B=B4=E6=96=B0=E5=90=8E?= =?UTF-8?q?=E7=AB=AFapi=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- doc/project/02-设计文档/backend-api.md | 44 +++++++++++++--------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/doc/project/02-设计文档/backend-api.md b/doc/project/02-设计文档/backend-api.md index 423dbc2..fef2e25 100644 --- a/doc/project/02-设计文档/backend-api.md +++ b/doc/project/02-设计文档/backend-api.md @@ -15,7 +15,7 @@ - `500 Internal Server Error`:服务器内部错误。 - **JWT 身份错误**:使用 `@jwt_required` 的接口在缺少或失效 Token 时会由 Flask-JWT-Extended 返回标准 401 响应;使用 `@int_jwt_required` 的接口若无法将身份标识转换为整数,则返回 `{"error": "无效的用户身份标识"}`(401)。 - **任务类型代码**:`perturbation`(加噪)、`finetune`(微调)、`heatmap`(热力图)、`evaluate`(评估)。 -- **任务状态代码**:需与 `task_status` 表保持一致(如 `pending`、`processing`、`completed`、`failed` 等)。 +- **任务状态代码**:需与 `task_status` 表保持一致(如 `waiting`、`processing`、`completed`、`failed` 等)。 --- @@ -208,7 +208,7 @@ ### GET `/api/task` **功能**:以可选筛选条件返回当前用户的所有任务摘要。 **认证**:是 -**查询参数**:`task_type`=`perturbation|finetune|heatmap|evaluate|all`,`task_status`=`pending|processing|completed|failed|all`。 +**查询参数**:`task_type`=`perturbation|finetune|heatmap|evaluate|all`,`task_status`=`waiting|processing|completed|failed|all`。 **成功响应** `200 OK`: ```json { @@ -369,7 +369,7 @@ - `403 {"error": "普通用户仅可使用人脸数据集"}` - `400 {"error": "加噪配置不存在"}` - `400 {"error": "非法的 flow_id 参数"}` -- `500 {"error": "Task status 'pending' is not configured"}` / `{...}` +- `500 {"error": "Task status 'waiting' is not configured"}` / `{...}` - `500 {"error": "创建任务失败: ..."}` ##### PATCH `/api/task/perturbation/` @@ -390,7 +390,7 @@ - `500 {"error": "更新任务失败: ..."}`(数据库提交失败或参数类型转换异常) ##### POST `/api/task/perturbation//start` -**功能**:向异步队列提交该加噪任务。 +**功能**:向异步队列提交该加噪任务,并将任务状态重置为 `waiting`。 **成功响应** `200 OK`: ```json {"message": "任务已启动", "job_id": "pert_901"} @@ -435,11 +435,11 @@ - `404 {"error": "加噪任务不存在或无权限"}` - `400 {"error": "仅支持已完成的加噪任务创建热力图"}` - `400 {"error": "扰动图片不存在或不属于该任务"}` -- `500 {"error": "Task type 'heatmap' is not configured"}` 或 `{"error": "Task status 'pending' is not configured"}` +- `500 {"error": "Task type 'heatmap' is not configured"}` 或 `{"error": "Task status 'waiting' is not configured"}` - `500 {"error": "创建热力图任务失败: ..."}` ##### POST `/api/task/heatmap//start` -**功能**:触发热力图任务执行。 +**功能**:触发热力图任务执行,并将任务状态重置为 `waiting`。 **成功响应** `200 OK`:`{"message": "任务已启动", "job_id": "hm_1201"}` **错误响应**: - `401 {"error": "无效的用户身份标识"}` @@ -478,16 +478,21 @@ { "perturbation_task_id": 901, "finetune_configs_id": 4, - "finetune_name": "LoRA-人脸" + "finetune_name": "LoRA-人脸", + "data_type_id": 3, + "custom_prompt": "a photo of sks person" } ``` +> `data_type_id` 为可选参数,若不填则自动继承自加噪任务。 +> `custom_prompt` 为可选参数,用于自定义微调训练时的提示词。 + **成功响应**:`{"message": "微调任务已创建", "task": {...}}` 任务对象中 `finetune.source` 字段为 `perturbation`。 **错误响应**: - `401 {"error": "无效的用户身份标识"}` - `400 {"error": "缺少必要参数: perturbation_task_id 或 finetune_configs_id"}` - `404 {"error": "加噪任务不存在或无权限"}` - `400 {"error": "微调配置不存在"}` -- `500 {"error": "Task status 'pending' is not configured"}` 或 `{"error": "Task type 'finetune' is not configured"}` +- `500 {"error": "Task status 'waiting' is not configured"}` 或 `{"error": "Task type 'finetune' is not configured"}` - `500 {"error": "创建微调任务失败: ..."}` ##### POST `/api/task/finetune/from-upload` @@ -496,10 +501,13 @@ - `finetune_configs_id`(数字,必填) - `data_type_id`(数字,必填) - `finetune_name`(字符串,可选) +- `custom_prompt`(字符串,可选) - `description`(字符串,可选) - `flow_id`(数字,可选) - `files`(一个或多个图片文件,可选) +> `custom_prompt` 为可选参数,用于自定义微调训练时的提示词。 + **成功响应** `201 Created`: ```json { @@ -518,12 +526,12 @@ - `400 {"error": "缺少必要参数: data_type_id"}` - `400 {"error": "数据集类型不存在"}` - `400 {"error": "非法的 flow_id 参数"}` 或 `{...}` -- `500 {"error": "Task status 'pending' is not configured"}` 或 `{...}` +- `500 {"error": "Task status 'waiting' is not configured"}` 或 `{...}` - `500 {"error": "创建微调任务失败: ..."}` ##### POST `/api/task/finetune//start` 成功响应 `{"message": "任务已启动", "job_id": "ft_982"}`。 -**功能**:启动指定微调任务的后台执行。 +**功能**:启动指定微调任务的后台执行,并将任务状态重置为 `waiting`。 **错误响应**: - `401 {"error": "无效的用户身份标识"}` - `404 {"error": "任务不存在或无权限"}` @@ -563,12 +571,12 @@ - `400 {"error": "该微调任务已存在评估,请勿重复创建"}` - `400 {"error": "数值评估仅支持基于加噪任务的微调结果"}` - `400 {"error": "微调任务未配置详情"}` -- `500 {"error": "Task type 'evaluate' is not configured"}` 或 `{"error": "Task status 'pending' is not configured"}` +- `500 {"error": "Task type 'evaluate' is not configured"}` 或 `{"error": "Task status 'waiting' is not configured"}` - `500 {"error": "创建评估任务失败: ..."}` ##### POST `/api/task/evaluate//start` 成功响应 `{"message": "任务已启动", "job_id": "eval_1301"}`。 -**功能**:推送评估任务进入执行队列。 +**功能**:推送评估任务进入执行队列,并将任务状态重置为 `waiting`。 **错误响应**: - `401 {"error": "无效的用户身份标识"}` - `404 {"error": "任务不存在或无权限"}` @@ -818,12 +826,6 @@ - `403 {"error": "需要管理员权限"}`(预期;当前代码在鉴权阶段可能直接抛出 500)。 - `500 {"error": "获取系统统计失败: ..."}`(任务状态统计语句抛错) - - - - -add: - --- ## Auth 模块补充 @@ -1565,3 +1567,9 @@ Authorization: Bearer - `401 {"error": "无效的用户身份标识"}` - `404 {"error": "任务不存在或无权限"}` - `500 {"error": "读取日志失败: ..."}` +--- + +## 文档更新记录 + +- [POST /api/task/finetune/from-perturbation](#post-apitaskfinetunefrom-perturbation):新增 `custom_prompt` 参数。 +- [POST /api/task/finetune/from-upload](#post-apitaskfinetunefrom-upload):新增 `custom_prompt` 参数。 -- 2.34.1