diff --git a/doc/project/02-设计文档/backend-api.md b/doc/project/02-设计文档/backend-api.md index 423dbc2..fef2e25 100644 --- a/doc/project/02-设计文档/backend-api.md +++ b/doc/project/02-设计文档/backend-api.md @@ -15,7 +15,7 @@ - `500 Internal Server Error`:服务器内部错误。 - **JWT 身份错误**:使用 `@jwt_required` 的接口在缺少或失效 Token 时会由 Flask-JWT-Extended 返回标准 401 响应;使用 `@int_jwt_required` 的接口若无法将身份标识转换为整数,则返回 `{"error": "无效的用户身份标识"}`(401)。 - **任务类型代码**:`perturbation`(加噪)、`finetune`(微调)、`heatmap`(热力图)、`evaluate`(评估)。 -- **任务状态代码**:需与 `task_status` 表保持一致(如 `pending`、`processing`、`completed`、`failed` 等)。 +- **任务状态代码**:需与 `task_status` 表保持一致(如 `waiting`、`processing`、`completed`、`failed` 等)。 --- @@ -208,7 +208,7 @@ ### GET `/api/task` **功能**:以可选筛选条件返回当前用户的所有任务摘要。 **认证**:是 -**查询参数**:`task_type`=`perturbation|finetune|heatmap|evaluate|all`,`task_status`=`pending|processing|completed|failed|all`。 +**查询参数**:`task_type`=`perturbation|finetune|heatmap|evaluate|all`,`task_status`=`waiting|processing|completed|failed|all`。 **成功响应** `200 OK`: ```json { @@ -369,7 +369,7 @@ - `403 {"error": "普通用户仅可使用人脸数据集"}` - `400 {"error": "加噪配置不存在"}` - `400 {"error": "非法的 flow_id 参数"}` -- `500 {"error": "Task status 'pending' is not configured"}` / `{...}` +- `500 {"error": "Task status 'waiting' is not configured"}` / `{...}` - `500 {"error": "创建任务失败: ..."}` ##### PATCH `/api/task/perturbation/` @@ -390,7 +390,7 @@ - `500 {"error": "更新任务失败: ..."}`(数据库提交失败或参数类型转换异常) ##### POST `/api/task/perturbation//start` -**功能**:向异步队列提交该加噪任务。 +**功能**:向异步队列提交该加噪任务,并将任务状态重置为 `waiting`。 **成功响应** `200 OK`: ```json {"message": "任务已启动", "job_id": "pert_901"} @@ -435,11 +435,11 @@ - `404 {"error": "加噪任务不存在或无权限"}` - `400 {"error": "仅支持已完成的加噪任务创建热力图"}` - `400 {"error": "扰动图片不存在或不属于该任务"}` -- `500 {"error": "Task type 'heatmap' is not configured"}` 或 `{"error": "Task status 'pending' is not configured"}` +- `500 {"error": "Task type 'heatmap' is not configured"}` 或 `{"error": "Task status 'waiting' is not configured"}` - `500 {"error": "创建热力图任务失败: ..."}` ##### POST `/api/task/heatmap//start` -**功能**:触发热力图任务执行。 +**功能**:触发热力图任务执行,并将任务状态重置为 `waiting`。 **成功响应** `200 OK`:`{"message": "任务已启动", "job_id": "hm_1201"}` **错误响应**: - `401 {"error": "无效的用户身份标识"}` @@ -478,16 +478,21 @@ { "perturbation_task_id": 901, "finetune_configs_id": 4, - "finetune_name": "LoRA-人脸" + "finetune_name": "LoRA-人脸", + "data_type_id": 3, + "custom_prompt": "a photo of sks person" } ``` +> `data_type_id` 为可选参数,若不填则自动继承自加噪任务。 +> `custom_prompt` 为可选参数,用于自定义微调训练时的提示词。 + **成功响应**:`{"message": "微调任务已创建", "task": {...}}` 任务对象中 `finetune.source` 字段为 `perturbation`。 **错误响应**: - `401 {"error": "无效的用户身份标识"}` - `400 {"error": "缺少必要参数: perturbation_task_id 或 finetune_configs_id"}` - `404 {"error": "加噪任务不存在或无权限"}` - `400 {"error": "微调配置不存在"}` -- `500 {"error": "Task status 'pending' is not configured"}` 或 `{"error": "Task type 'finetune' is not configured"}` +- `500 {"error": "Task status 'waiting' is not configured"}` 或 `{"error": "Task type 'finetune' is not configured"}` - `500 {"error": "创建微调任务失败: ..."}` ##### POST `/api/task/finetune/from-upload` @@ -496,10 +501,13 @@ - `finetune_configs_id`(数字,必填) - `data_type_id`(数字,必填) - `finetune_name`(字符串,可选) +- `custom_prompt`(字符串,可选) - `description`(字符串,可选) - `flow_id`(数字,可选) - `files`(一个或多个图片文件,可选) +> `custom_prompt` 为可选参数,用于自定义微调训练时的提示词。 + **成功响应** `201 Created`: ```json { @@ -518,12 +526,12 @@ - `400 {"error": "缺少必要参数: data_type_id"}` - `400 {"error": "数据集类型不存在"}` - `400 {"error": "非法的 flow_id 参数"}` 或 `{...}` -- `500 {"error": "Task status 'pending' is not configured"}` 或 `{...}` +- `500 {"error": "Task status 'waiting' is not configured"}` 或 `{...}` - `500 {"error": "创建微调任务失败: ..."}` ##### POST `/api/task/finetune//start` 成功响应 `{"message": "任务已启动", "job_id": "ft_982"}`。 -**功能**:启动指定微调任务的后台执行。 +**功能**:启动指定微调任务的后台执行,并将任务状态重置为 `waiting`。 **错误响应**: - `401 {"error": "无效的用户身份标识"}` - `404 {"error": "任务不存在或无权限"}` @@ -563,12 +571,12 @@ - `400 {"error": "该微调任务已存在评估,请勿重复创建"}` - `400 {"error": "数值评估仅支持基于加噪任务的微调结果"}` - `400 {"error": "微调任务未配置详情"}` -- `500 {"error": "Task type 'evaluate' is not configured"}` 或 `{"error": "Task status 'pending' is not configured"}` +- `500 {"error": "Task type 'evaluate' is not configured"}` 或 `{"error": "Task status 'waiting' is not configured"}` - `500 {"error": "创建评估任务失败: ..."}` ##### POST `/api/task/evaluate//start` 成功响应 `{"message": "任务已启动", "job_id": "eval_1301"}`。 -**功能**:推送评估任务进入执行队列。 +**功能**:推送评估任务进入执行队列,并将任务状态重置为 `waiting`。 **错误响应**: - `401 {"error": "无效的用户身份标识"}` - `404 {"error": "任务不存在或无权限"}` @@ -818,12 +826,6 @@ - `403 {"error": "需要管理员权限"}`(预期;当前代码在鉴权阶段可能直接抛出 500)。 - `500 {"error": "获取系统统计失败: ..."}`(任务状态统计语句抛错) - - - - -add: - --- ## Auth 模块补充 @@ -1565,3 +1567,9 @@ Authorization: Bearer - `401 {"error": "无效的用户身份标识"}` - `404 {"error": "任务不存在或无权限"}` - `500 {"error": "读取日志失败: ..."}` +--- + +## 文档更新记录 + +- [POST /api/task/finetune/from-perturbation](#post-apitaskfinetunefrom-perturbation):新增 `custom_prompt` 参数。 +- [POST /api/task/finetune/from-upload](#post-apitaskfinetunefrom-upload):新增 `custom_prompt` 参数。 diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index 0a30e7e..8da27a1 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -174,6 +174,19 @@ def create_perturbation_task(current_user_id): if not PerturbationConfig.query.get(perturbation_configs_id): return TaskService.json_error('加噪配置不存在') + # 验证上传的图片 + files = request.files.getlist('files') if hasattr(request, 'files') else [] + if not files: + return TaskService.json_error('请上传至少一张图片') + + allowed_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'} + for file in files: + if not file.filename: + continue + _, ext = os.path.splitext(file.filename) + if ext.lower() not in allowed_extensions: + return TaskService.json_error(f'不支持的文件格式: {file.filename}。仅支持图片格式。') + try: flow_id = data.get('flow_id') flow_id = int(flow_id) if flow_id is not None else TaskService.generate_flow_id() @@ -208,7 +221,6 @@ def create_perturbation_task(current_user_id): db.session.commit() # 自动上传图片 - files = request.files.getlist('files') if hasattr(request, 'files') else [] target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id) success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, []) @@ -428,6 +440,14 @@ def create_finetune_from_perturbation(current_user_id): if not FinetuneConfig.query.get(finetune_configs_id): return TaskService.json_error('微调配置不存在') + # 确定 data_type_id:优先使用用户输入,否则继承自加噪任务 + data_type_id = data.get('data_type_id') + if not data_type_id and perturbation_task.perturbation: + data_type_id = perturbation_task.perturbation.data_type_id + + if data_type_id and not DataType.query.get(data_type_id): + return TaskService.json_error('数据集类型不存在') + try: waiting_status = TaskService.ensure_status('waiting') finetune_type = TaskService.require_task_type('finetune') @@ -448,8 +468,9 @@ def create_finetune_from_perturbation(current_user_id): finetune = Finetune( tasks_id=task.tasks_id, finetune_configs_id=finetune_configs_id, - data_type_id=data.get('data_type_id'), - finetune_name=data.get('finetune_name') + data_type_id=data_type_id, + finetune_name=data.get('finetune_name'), + custom_prompt=data.get('custom_prompt') ) db.session.add(finetune) db.session.commit() @@ -503,6 +524,19 @@ def create_finetune_from_upload(current_user_id): if not data_type: return TaskService.json_error('数据集类型不存在') + # 验证上传的图片 + files = request.files.getlist('files') if hasattr(request, 'files') else [] + if not files: + return TaskService.json_error('请上传至少一张图片') + + allowed_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'} + for file in files: + if not file.filename: + continue + _, ext = os.path.splitext(file.filename) + if ext.lower() not in allowed_extensions: + return TaskService.json_error(f'不支持的文件格式: {file.filename}。仅支持图片格式。') + try: flow_id = data.get('flow_id') if flow_id is not None: @@ -536,13 +570,13 @@ def create_finetune_from_upload(current_user_id): tasks_id=task.tasks_id, finetune_configs_id=finetune_configs_id, data_type_id=data_type_id, - finetune_name=data.get('finetune_name') + finetune_name=data.get('finetune_name'), + custom_prompt=data.get('custom_prompt') ) db.session.add(finetune) db.session.commit() # 自动上传图片(仅上传微调任务) - files = request.files.getlist('files') if hasattr(request, 'files') else [] target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id) success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, []) diff --git a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py index aedb9e8..300610c 100644 --- a/src/backend/app/database/__init__.py +++ b/src/backend/app/database/__init__.py @@ -98,7 +98,10 @@ class DataType(db.Model): __tablename__ = 'data_type' data_type_id = db.Column(Integer, primary_key=True) data_type_code = db.Column(String(50), nullable=False) - data_type_prompt = db.Column(Text, comment='数据集相关的Prompt') + instance_prompt = db.Column(Text, comment='数据集相关的Prompt (Instance Prompt Template, e.g. "a photo of sks person")') + class_prompt = db.Column(String(255), comment='类别Prompt (e.g. "a photo of person")') + placeholder_token = db.Column(String(50), comment='TI Placeholder (e.g. "")') + initializer_token = db.Column(String(50), comment='TI Initializer (e.g. "person")') description = db.Column(Text) def __repr__(self): @@ -245,6 +248,7 @@ class Finetune(db.Model): finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), nullable=False, comment='微调配置ID') data_type_id = db.Column(Integer, ForeignKey('data_type.data_type_id'), default=None, comment='微调所用数据集') finetune_name = db.Column(String(100), comment='微调任务名称') + custom_prompt = db.Column(String(255), comment='用户自定义提示词(如 a photo of sks cat)') task = db.relationship('Task', back_populates='finetune') finetune_config = db.relationship('FinetuneConfig') diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index 4fdd4fc..9e6bb13 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -231,6 +231,7 @@ class TaskService: 'finetune_configs_id': task.finetune.finetune_configs_id, 'data_type_id': task.finetune.data_type_id, 'finetune_name': task.finetune.finetune_name, + 'custom_prompt': task.finetune.custom_prompt, 'source': source } elif task_type == 'heatmap' and task.heatmap: @@ -403,6 +404,12 @@ class TaskService: if not perturbation: logger.error(f"Perturbation task {task_id} not found") return None + + # 更新任务状态为 waiting + waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if waiting_status: + task.tasks_status_id = waiting_status.task_status_id + db.session.commit() # 获取用户ID user_id = task.user_id @@ -480,6 +487,12 @@ class TaskService: if not finetune: logger.error(f"Finetune task {task_id} not found") return None + + # 更新任务状态为 waiting + waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if waiting_status: + task.tasks_status_id = waiting_status.task_status_id + db.session.commit() # 获取用户ID user_id = task.user_id @@ -641,6 +654,12 @@ class TaskService: if not heatmap: logger.error(f"Heatmap task {task_id} not found") return None + + # 更新任务状态为 waiting + waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if waiting_status: + task.tasks_status_id = waiting_status.task_status_id + db.session.commit() # 从heatmap对象获取扰动图片ID perturbed_image_id = heatmap.images_id @@ -729,6 +748,12 @@ class TaskService: if not evaluate: logger.error(f"Evaluate task {task_id} not found") return None + + # 更新任务状态为 waiting + waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() + if waiting_status: + task.tasks_status_id = waiting_status.task_status_id + db.session.commit() finetune = Finetune.query.get(evaluate.finetune_task_id) if not finetune: diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index 57eef16..713089b 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -66,20 +66,55 @@ def run_finetune_task(task_id, finetune_method, train_images_dir, logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}") - # 从数据库获取数据集类型的提示词 - # 从Finetune表的data_type_id获取 - instance_prompt = "a photo of sks person" # 默认值 - class_prompt = "a photo of person" # 默认值 - validation_prompt = "a photo of sks person" # 默认值 + # 获取 DataType 配置 + data_type = DataType.query.get(finetune.data_type_id) if finetune.data_type_id else None - if finetune.data_type_id: - data_type = DataType.query.get(finetune.data_type_id) - if data_type and data_type.data_type_prompt: - instance_prompt = data_type.data_type_prompt - validation_prompt = instance_prompt - # 从instance_prompt生成class_prompt(移除"sks") - class_prompt = instance_prompt.replace('sks ', '') - logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'") + # 默认值 (Fallback) + instance_prompt = "a photo of sks person" + class_prompt = "a photo of person" + placeholder_token = "" + initializer_token = "person" + + if data_type: + if data_type.instance_prompt: + instance_prompt = data_type.instance_prompt + if data_type.class_prompt: + class_prompt = data_type.class_prompt + if data_type.placeholder_token: + placeholder_token = data_type.placeholder_token + if data_type.initializer_token: + initializer_token = data_type.initializer_token + + logger.info(f"DataType Config - Template: '{instance_prompt}', Class: '{class_prompt}'") + + # 根据微调方法调整 Instance Prompt + if finetune_method == 'textual_inversion': + # TI: 将 'sks' 替换为 placeholder_token + instance_prompt_prefix = instance_prompt.replace('sks', placeholder_token) + else: + # DreamBooth/LoRA: 直接使用模板 + instance_prompt_prefix = instance_prompt + + # 处理 Validation Prompt (拼接后缀) + prompt_suffix = finetune.custom_prompt.strip() if finetune.custom_prompt else "" + + if prompt_suffix: + validation_prompt = f"{instance_prompt_prefix}, {prompt_suffix}" + else: + validation_prompt = instance_prompt_prefix + + instance_prompt = instance_prompt_prefix + + logger.info(f"Prompts Finalized - Instance: '{instance_prompt}', Class: '{class_prompt}', Validation: '{validation_prompt}'") + + # 设置 TI 特有参数 + if custom_params is None: + custom_params = {} + + if finetune_method == 'textual_inversion': + custom_params['placeholder_token'] = placeholder_token + custom_params['initializer_token'] = initializer_token + logger.info(f"TI Tokens - Placeholder: '{placeholder_token}', Initializer: '{initializer_token}'") # 彻底清空输出目录(避免旧文件残留,特别是 textual_inversion 的 token) logger.info(f"Completely clearing output directories...") @@ -251,7 +286,12 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_ if value: cmd_args.append(f"--{key}") else: - cmd_args.append(f"--{key}={value}") + # 特殊处理包含空格或特殊字符的参数,防止 conda run 的 shell 脚本解析错误 + # 特别是 < > 符号会被 shell 解释为重定向 + if isinstance(value, str) and any(c in value for c in [' ', '<', '>', '&', '|', ';', '(', ')']): + cmd_args.append(f"--{key}='{value}'") + else: + cmd_args.append(f"--{key}={value}") # 设置环境变量 env = os.environ.copy() diff --git a/src/backend/app/workers/perturbation_worker.py b/src/backend/app/workers/perturbation_worker.py index 1850920..7578af4 100644 --- a/src/backend/app/workers/perturbation_worker.py +++ b/src/backend/app/workers/perturbation_worker.py @@ -169,10 +169,9 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id, if not data_type: raise ValueError(f"DataType {perturbation.data_type_id} not found") - # 从data_type_prompt中提取instance_prompt - instance_prompt = data_type.data_type_prompt or 'a photo of sks person' - # 从instance_prompt生成class_prompt(移除"sks") - class_prompt = instance_prompt.replace('sks ', '') + # 从数据库获取数据集类型的提示词 + instance_prompt = data_type.instance_prompt or 'a photo of sks person' + class_prompt = data_type.class_prompt or 'a photo of person' logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'") @@ -205,14 +204,14 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id, cmd_args.extend([ f"--instance_data_dir={input_dir}", f"--output_dir={output_dir}", - f"--eps={int(epsilon)}", + f"--eps={float(epsilon)}", ]) elif algorithm_code == 'pid': # PID参数结构 cmd_args.extend([ f"--instance_data_dir={input_dir}", f"--output_dir={output_dir}", - f"--eps={int(epsilon)}", + f"--eps={float(epsilon)}", ]) elif algorithm_code == 'glaze': # Glaze参数结构 diff --git a/src/backend/config/algorithm_config.py b/src/backend/config/algorithm_config.py index 17fb2f5..072f6b5 100644 --- a/src/backend/config/algorithm_config.py +++ b/src/backend/config/algorithm_config.py @@ -238,9 +238,9 @@ class AlgorithmConfig: 'conda_env': CONDA_ENVS['textual_inversion'], 'default_params': { 'pretrained_model_name_or_path': MODELS_DIR['model2'], - 'placeholder_token': 'sks', + 'placeholder_token': '', 'initializer_token': 'person', - 'instance_prompt': 'a photo of sks person', + 'instance_prompt': 'a photo of person', 'resolution': 512, 'train_batch_size': 1, 'gradient_accumulation_steps': 1, @@ -251,7 +251,7 @@ class AlgorithmConfig: 'checkpointing_steps': 2, 'seed': 0, 'mixed_precision': 'fp16', - 'validation_prompt': 'a photo of sks person', + 'validation_prompt': 'a photo of person', 'num_validation_images': 4, 'validation_epochs': 50, 'coords_log_interval': 10 diff --git a/src/backend/init_db.py b/src/backend/init_db.py index eecfdc5..7e70583 100644 --- a/src/backend/init_db.py +++ b/src/backend/init_db.py @@ -85,8 +85,22 @@ def init_database(): # 初始化数据集类型数据 data_types = [ - {'data_type_code': 'facial', 'data_type_prompt': 'a photo of sks person', 'description': '人脸类型的数据集'}, - {'data_type_code': 'art', 'data_type_prompt': 'a painting in the style of sks', 'description': '艺术品类型的数据集'} + { + 'data_type_code': 'face', + 'instance_prompt': 'a photo of sks person', + 'class_prompt': 'a photo of person', + 'placeholder_token': '', + 'initializer_token': 'person', + 'description': '人脸类型的数据集' + }, + { + 'data_type_code': 'art', + 'instance_prompt': 'a painting in sks style', + 'class_prompt': 'a painting', + 'placeholder_token': '', + 'initializer_token': 'painting', + 'description': '艺术品类型的数据集' + } ] for data_type in data_types: existing = DataType.query.filter_by(data_type_code=data_type['data_type_code']).first()