Merge lianghao_branch into develop #23

Merged
hnu202326010204 merged 9 commits from lianghao_branch into develop 1 month ago

@ -15,7 +15,7 @@
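- `500 Internal Server Error`: internal server error.
- **JWT identity errors**: endpoints decorated with `@jwt_required` return the standard Flask-JWT-Extended 401 response when the token is missing or invalid; endpoints decorated with `@int_jwt_required` return `{"error": "无效的用户身份标识"}` (401) when the identity cannot be converted to an integer.
- **Task type codes**: `perturbation` (noise addition), `finetune` (fine-tuning), `heatmap` (heatmap), `evaluate` (evaluation).
- **Task status codes**: must stay consistent with the `task_status` table (e.g. `pending`, `processing`, `completed`, `failed`).
- **Task status codes**: must stay consistent with the `task_status` table (e.g. `waiting`, `processing`, `completed`, `failed`).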
---
@ -208,7 +208,7 @@
### GET `/api/task`
**Purpose**: returns summaries of all of the current user's tasks, with optional filters.
**Authentication**: required
**Query parameters**: `task_type`=`perturbation|finetune|heatmap|evaluate|all`, `task_status`=`pending|processing|completed|failed|all`.
**Query parameters**: `task_type`=`perturbation|finetune|heatmap|evaluate|all`, `task_status`=`waiting|processing|completed|failed|all`.
**Success response** `200 OK`
```json
{
@ -369,7 +369,7 @@
- `403 {"error": "普通用户仅可使用人脸数据集"}`
- `400 {"error": "加噪配置不存在"}`
- `400 {"error": "非法的 flow_id 参数"}`
- `500 {"error": "Task status 'pending' is not configured"}` / `{...}`
- `500 {"error": "Task status 'waiting' is not configured"}` / `{...}`
- `500 {"error": "创建任务失败: ..."}`
##### PATCH `/api/task/perturbation/<task_id>`
@ -390,7 +390,7 @@
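- `500 {"error": "更新任务失败: ..."}` (the database commit failed or a parameter type conversion raised an exception)
##### POST `/api/task/perturbation/<task_id>/start`
**Purpose**: submits the perturbation task to the asynchronous queue.
**Purpose**: submits the perturbation task to the asynchronous queue and resets the task status to `waiting`.
**Success response** `200 OK`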
```json
{"message": "任务已启动", "job_id": "pert_901"}
@ -435,11 +435,11 @@
- `404 {"error": "加噪任务不存在或无权限"}`
- `400 {"error": "仅支持已完成的加噪任务创建热力图"}`
- `400 {"error": "扰动图片不存在或不属于该任务"}`
- `500 {"error": "Task type 'heatmap' is not configured"}` or `{"error": "Task status 'pending' is not configured"}`
- `500 {"error": "Task type 'heatmap' is not configured"}` or `{"error": "Task status 'waiting' is not configured"}`
- `500 {"error": "创建热力图任务失败: ..."}`
##### POST `/api/task/heatmap/<task_id>/start`
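**Purpose**: triggers execution of the heatmap task.
**Purpose**: triggers execution of the heatmap task and resets the task status to `waiting`.
**Success response** `200 OK`: `{"message": "任务已启动", "job_id": "hm_1201"}`
**Error responses**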
- `401 {"error": "无效的用户身份标识"}`
@ -478,16 +478,21 @@
{
"perturbation_task_id": 901,
"finetune_configs_id": 4,
"finetune_name": "LoRA-人脸"
"finetune_name": "LoRA-人脸",
"data_type_id": 3,
"custom_prompt": "a photo of sks person"
}
```
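> `data_type_id` is optional; if omitted, it is inherited from the perturbation task.
> `custom_prompt` is optional and customizes the prompt used during fine-tuning.
**Success response**: `{"message": "微调任务已创建", "task": {...}}`. In the task object, the `finetune.source` field is `perturbation`.
**Error responses**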
- `401 {"error": "无效的用户身份标识"}`
- `400 {"error": "缺少必要参数: perturbation_task_id 或 finetune_configs_id"}`
- `404 {"error": "加噪任务不存在或无权限"}`
- `400 {"error": "微调配置不存在"}`
- `500 {"error": "Task status 'pending' is not configured"}` or `{"error": "Task type 'finetune' is not configured"}`
- `500 {"error": "Task status 'waiting' is not configured"}` or `{"error": "Task type 'finetune' is not configured"}`
- `500 {"error": "创建微调任务失败: ..."}`
##### POST `/api/task/finetune/from-upload`
@ -496,10 +501,13 @@
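- `finetune_configs_id` (number, required)
- `data_type_id` (number, required)
- `finetune_name` (string, optional)
- `custom_prompt` (string, optional)
- `description` (string, optional)
- `flow_id` (number, optional)
- `files` (one or more image files, optional)
> `custom_prompt` is optional and customizes the prompt used during fine-tuning.
**Success response** `201 Created`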
```json
{
@ -518,12 +526,12 @@
- `400 {"error": "缺少必要参数: data_type_id"}`
- `400 {"error": "数据集类型不存在"}`
- `400 {"error": "非法的 flow_id 参数"}` or `{...}`
- `500 {"error": "Task status 'pending' is not configured"}` or `{...}`
- `500 {"error": "Task status 'waiting' is not configured"}` or `{...}`
- `500 {"error": "创建微调任务失败: ..."}`
##### POST `/api/task/finetune/<task_id>/start`
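Success response: `{"message": "任务已启动", "job_id": "ft_982"}`
**Purpose**: starts background execution of the specified fine-tuning task.
**Purpose**: starts background execution of the specified fine-tuning task and resets the task status to `waiting`.
**Error responses**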
- `401 {"error": "无效的用户身份标识"}`
- `404 {"error": "任务不存在或无权限"}`
@ -563,12 +571,12 @@
- `400 {"error": "该微调任务已存在评估,请勿重复创建"}`
- `400 {"error": "数值评估仅支持基于加噪任务的微调结果"}`
- `400 {"error": "微调任务未配置详情"}`
- `500 {"error": "Task type 'evaluate' is not configured"}` or `{"error": "Task status 'pending' is not configured"}`
- `500 {"error": "Task type 'evaluate' is not configured"}` or `{"error": "Task status 'waiting' is not configured"}`
- `500 {"error": "创建评估任务失败: ..."}`
##### POST `/api/task/evaluate/<task_id>/start`
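Success response: `{"message": "任务已启动", "job_id": "eval_1301"}`
**Purpose**: pushes the evaluation task into the execution queue.
**Purpose**: pushes the evaluation task into the execution queue and resets the task status to `waiting`.
**Error responses**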
- `401 {"error": "无效的用户身份标识"}`
- `404 {"error": "任务不存在或无权限"}`
@ -818,12 +826,6 @@
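- `403 {"error": "需要管理员权限"}` (expected; the current code may raise a 500 directly during the authorization stage)
- `500 {"error": "获取系统统计失败: ..."}` (the task status statistics query raised an error)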
add:
---
## Auth module supplement
@ -1565,3 +1567,9 @@ Authorization: Bearer <token>
- `401 {"error": "无效的用户身份标识"}`
- `404 {"error": "任务不存在或无权限"}`
- `500 {"error": "读取日志失败: ..."}`
---
## Documentation changelog
- [POST /api/task/finetune/from-perturbation](#post-apitaskfinetunefrom-perturbation): added the `custom_prompt` parameter.
- [POST /api/task/finetune/from-upload](#post-apitaskfinetunefrom-upload): added the `custom_prompt` parameter.
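
The optional fields documented for `POST /api/task/finetune/from-perturbation` can be illustrated with a small client-side helper. This is only a sketch: `build_finetune_payload` is a hypothetical name that does not appear in the PR. It leaves out optional keys that were not supplied, so that server-side defaults (such as inheriting `data_type_id` from the perturbation task) can take effect.

```python
import json


def build_finetune_payload(perturbation_task_id, finetune_configs_id,
                           finetune_name=None, data_type_id=None,
                           custom_prompt=None):
    """Build the JSON body for the from-perturbation endpoint (hypothetical helper)."""
    payload = {
        "perturbation_task_id": perturbation_task_id,
        "finetune_configs_id": finetune_configs_id,
    }
    optional = {
        "finetune_name": finetune_name,
        "data_type_id": data_type_id,
        "custom_prompt": custom_prompt,
    }
    # Only send optional keys the caller actually provided.
    payload.update({k: v for k, v in optional.items() if v is not None})
    return payload


body = json.dumps(build_finetune_payload(901, 4, custom_prompt="a photo of sks person"))
```

Omitting `data_type_id` here relies on the inheritance behavior described in the endpoint notes.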

@ -174,6 +174,19 @@ def create_perturbation_task(current_user_id):
if not PerturbationConfig.query.get(perturbation_configs_id):
    return TaskService.json_error('加噪配置不存在')
# Validate the uploaded images
files = request.files.getlist('files') if hasattr(request, 'files') else []
if not files:
    return TaskService.json_error('请上传至少一张图片')
allowed_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'}
for file in files:
    if not file.filename:
        continue
    _, ext = os.path.splitext(file.filename)
    if ext.lower() not in allowed_extensions:
        return TaskService.json_error(f'不支持的文件格式: {file.filename}。仅支持图片格式。')
try:
    flow_id = data.get('flow_id')
    flow_id = int(flow_id) if flow_id is not None else TaskService.generate_flow_id()
@ -208,7 +221,6 @@ def create_perturbation_task(current_user_id):
db.session.commit()
# Automatically upload the images
files = request.files.getlist('files') if hasattr(request, 'files') else []
target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id)
success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, [])
@ -428,6 +440,14 @@ def create_finetune_from_perturbation(current_user_id):
if not FinetuneConfig.query.get(finetune_configs_id):
    return TaskService.json_error('微调配置不存在')
# Determine data_type_id: prefer the user's input, otherwise inherit it from the perturbation task
data_type_id = data.get('data_type_id')
if not data_type_id and perturbation_task.perturbation:
    data_type_id = perturbation_task.perturbation.data_type_id
if data_type_id and not DataType.query.get(data_type_id):
    return TaskService.json_error('数据集类型不存在')
try:
    waiting_status = TaskService.ensure_status('waiting')
    finetune_type = TaskService.require_task_type('finetune')
@ -448,8 +468,9 @@ def create_finetune_from_perturbation(current_user_id):
finetune = Finetune(
tasks_id=task.tasks_id,
finetune_configs_id=finetune_configs_id,
data_type_id=data.get('data_type_id'),
finetune_name=data.get('finetune_name')
data_type_id=data_type_id,
finetune_name=data.get('finetune_name'),
custom_prompt=data.get('custom_prompt')
)
db.session.add(finetune)
db.session.commit()
@ -503,6 +524,19 @@ def create_finetune_from_upload(current_user_id):
if not data_type:
    return TaskService.json_error('数据集类型不存在')
# Validate the uploaded images
files = request.files.getlist('files') if hasattr(request, 'files') else []
if not files:
    return TaskService.json_error('请上传至少一张图片')
allowed_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'}
for file in files:
    if not file.filename:
        continue
    _, ext = os.path.splitext(file.filename)
    if ext.lower() not in allowed_extensions:
        return TaskService.json_error(f'不支持的文件格式: {file.filename}。仅支持图片格式。')
try:
    flow_id = data.get('flow_id')
    if flow_id is not None:
@ -536,13 +570,13 @@ def create_finetune_from_upload(current_user_id):
tasks_id=task.tasks_id,
finetune_configs_id=finetune_configs_id,
data_type_id=data_type_id,
finetune_name=data.get('finetune_name')
finetune_name=data.get('finetune_name'),
custom_prompt=data.get('custom_prompt')
)
db.session.add(finetune)
db.session.commit()
# Automatically upload the images (fine-tuning task uploads only)
files = request.files.getlist('files') if hasattr(request, 'files') else []
target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id)
success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, [])
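
Both `create_perturbation_task` and `create_finetune_from_upload` repeat the same extension check. As a rough sketch, the check could be expressed as one standalone function. The name `find_unsupported_filename` is hypothetical and not part of the PR; the allowed set and the skipping of empty filenames mirror the diff.

```python
import os

# Same extension set as in the diff.
ALLOWED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'}


def find_unsupported_filename(filenames, allowed=ALLOWED_EXTENSIONS):
    """Return the first filename with a disallowed extension, or None.

    Empty filenames are skipped, as in the route code.
    """
    for name in filenames:
        if not name:
            continue
        _, ext = os.path.splitext(name)
        if ext.lower() not in allowed:
            return name
    return None


offender = find_unsupported_filename(["portrait.JPG", "notes.txt"])
```

This check looks only at the filename extension, not at the file contents, and a list made up entirely of empty filenames passes it.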

@ -98,7 +98,10 @@ class DataType(db.Model):
__tablename__ = 'data_type'
data_type_id = db.Column(Integer, primary_key=True)
data_type_code = db.Column(String(50), nullable=False)
data_type_prompt = db.Column(Text, comment='数据集相关的Prompt')
instance_prompt = db.Column(Text, comment='数据集相关的Prompt (Instance Prompt Template, e.g. "a photo of sks person")')
class_prompt = db.Column(String(255), comment='类别Prompt (e.g. "a photo of person")')
placeholder_token = db.Column(String(50), comment='TI Placeholder (e.g. "<sks-concept>")')
initializer_token = db.Column(String(50), comment='TI Initializer (e.g. "person")')
description = db.Column(Text)
def __repr__(self):
@ -245,6 +248,7 @@ class Finetune(db.Model):
finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), nullable=False, comment='微调配置ID')
data_type_id = db.Column(Integer, ForeignKey('data_type.data_type_id'), default=None, comment='微调所用数据集')
finetune_name = db.Column(String(100), comment='微调任务名称')
custom_prompt = db.Column(String(255), comment='用户自定义提示词(如 a photo of sks cat)')
task = db.relationship('Task', back_populates='finetune')
finetune_config = db.relationship('FinetuneConfig')

@ -231,6 +231,7 @@ class TaskService:
'finetune_configs_id': task.finetune.finetune_configs_id,
'data_type_id': task.finetune.data_type_id,
'finetune_name': task.finetune.finetune_name,
'custom_prompt': task.finetune.custom_prompt,
'source': source
}
elif task_type == 'heatmap' and task.heatmap:
@ -403,6 +404,12 @@ class TaskService:
if not perturbation:
    logger.error(f"Perturbation task {task_id} not found")
    return None
# Update the task status to waiting
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
if waiting_status:
    task.tasks_status_id = waiting_status.task_status_id
    db.session.commit()
# Get the user ID
user_id = task.user_id
@ -480,6 +487,12 @@ class TaskService:
if not finetune:
    logger.error(f"Finetune task {task_id} not found")
    return None
# Update the task status to waiting
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
if waiting_status:
    task.tasks_status_id = waiting_status.task_status_id
    db.session.commit()
# Get the user ID
user_id = task.user_id
@ -641,6 +654,12 @@ class TaskService:
if not heatmap:
    logger.error(f"Heatmap task {task_id} not found")
    return None
# Update the task status to waiting
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
if waiting_status:
    task.tasks_status_id = waiting_status.task_status_id
    db.session.commit()
# Get the perturbed image ID from the heatmap object
perturbed_image_id = heatmap.images_id
@ -729,6 +748,12 @@ class TaskService:
if not evaluate:
    logger.error(f"Evaluate task {task_id} not found")
    return None
# Update the task status to waiting
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
if waiting_status:
    task.tasks_status_id = waiting_status.task_status_id
    db.session.commit()
finetune = Finetune.query.get(evaluate.finetune_task_id)
if not finetune:
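
The same "reset to `waiting`" block appears in all four start methods. The sketch below shows the pattern in isolation, using plain dataclasses in place of the SQLAlchemy models. `StatusRow`, `TaskRow`, `reset_to_waiting`, and the injected `find_status` and `commit` callables are all hypothetical stand-ins, not names from the PR.

```python
from dataclasses import dataclass


@dataclass
class StatusRow:          # stand-in for a task_status row
    task_status_id: int
    task_status_code: str


@dataclass
class TaskRow:            # stand-in for a task row
    tasks_id: int
    tasks_status_id: int


def reset_to_waiting(task, find_status, commit):
    """Point the task at the 'waiting' status and commit.

    Returns False without touching the task when the status is not
    configured, so the caller can decide whether to log or abort.
    """
    waiting = find_status('waiting')
    if waiting is None:
        return False
    task.tasks_status_id = waiting.task_status_id
    commit()
    return True


statuses = {'waiting': StatusRow(1, 'waiting'), 'failed': StatusRow(4, 'failed')}
commits = []
task = TaskRow(tasks_id=901, tasks_status_id=4)
ok = reset_to_waiting(task, statuses.get, lambda: commits.append(True))
```

Returning a flag is a design choice made for this sketch. The code in the diff continues without any signal when the `waiting` row is missing.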

@ -66,20 +66,55 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}")
# Fetch the dataset type's prompts from the database
# Looked up via data_type_id on the Finetune table
instance_prompt = "a photo of sks person"  # default
class_prompt = "a photo of person"  # default
validation_prompt = "a photo of sks person"  # default
# Fetch the DataType configuration
data_type = DataType.query.get(finetune.data_type_id) if finetune.data_type_id else None
if finetune.data_type_id:
    data_type = DataType.query.get(finetune.data_type_id)
    if data_type and data_type.data_type_prompt:
        instance_prompt = data_type.data_type_prompt
        validation_prompt = instance_prompt
        # Derive class_prompt from instance_prompt by removing "sks"
        class_prompt = instance_prompt.replace('sks ', '')
        logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'")
# Defaults (fallback)
instance_prompt = "a photo of sks person"
class_prompt = "a photo of person"
placeholder_token = "<sks-concept>"
initializer_token = "person"
if data_type:
    if data_type.instance_prompt:
        instance_prompt = data_type.instance_prompt
    if data_type.class_prompt:
        class_prompt = data_type.class_prompt
    if data_type.placeholder_token:
        placeholder_token = data_type.placeholder_token
    if data_type.initializer_token:
        initializer_token = data_type.initializer_token
    logger.info(f"DataType Config - Template: '{instance_prompt}', Class: '{class_prompt}'")
# Adjust the instance prompt according to the fine-tuning method
if finetune_method == 'textual_inversion':
    # TI: replace 'sks' with placeholder_token
    instance_prompt_prefix = instance_prompt.replace('sks', placeholder_token)
else:
    # DreamBooth/LoRA: use the template as-is
    instance_prompt_prefix = instance_prompt
# Build the validation prompt (append the suffix)
prompt_suffix = finetune.custom_prompt.strip() if finetune.custom_prompt else ""
if prompt_suffix:
    validation_prompt = f"{instance_prompt_prefix}, {prompt_suffix}"
else:
    validation_prompt = instance_prompt_prefix
instance_prompt = instance_prompt_prefix
logger.info(f"Prompts Finalized - Instance: '{instance_prompt}', Class: '{class_prompt}', Validation: '{validation_prompt}'")
# Set the TI-specific parameters
if custom_params is None:
    custom_params = {}
if finetune_method == 'textual_inversion':
    custom_params['placeholder_token'] = placeholder_token
    custom_params['initializer_token'] = initializer_token
    logger.info(f"TI Tokens - Placeholder: '{placeholder_token}', Initializer: '{initializer_token}'")
# Completely clear the output directories (to avoid stale files, especially textual_inversion tokens)
logger.info(f"Completely clearing output directories...")
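
The prompt handling in this hunk can be read as a small pure function. The sketch below reproduces that logic outside the worker. `resolve_prompts` is a hypothetical name, and the `'lora'` method string in the usage line is only illustrative, since the diff names `'textual_inversion'` explicitly and treats every other method the same way.

```python
def resolve_prompts(finetune_method, instance_template, placeholder_token,
                    custom_prompt=None):
    """Return (instance_prompt, validation_prompt) following the hunk's rules."""
    if finetune_method == 'textual_inversion':
        # TI: swap the literal 'sks' for the placeholder token.
        instance_prompt = instance_template.replace('sks', placeholder_token)
    else:
        # Other methods: use the template unchanged.
        instance_prompt = instance_template
    # custom_prompt is appended as a suffix, not substituted.
    suffix = custom_prompt.strip() if custom_prompt else ""
    if suffix:
        validation_prompt = f"{instance_prompt}, {suffix}"
    else:
        validation_prompt = instance_prompt
    return instance_prompt, validation_prompt


ti = resolve_prompts('textual_inversion', 'a photo of sks person', '<sks-concept>')
lora = resolve_prompts('lora', 'a photo of sks person', '<sks-concept>', ' wearing a hat ')
```

Two things follow from the code itself. `str.replace` substitutes every occurrence of `sks`, including inside longer words. And because `custom_prompt` is appended, passing a complete prompt such as `a photo of sks person` produces a validation prompt that repeats the phrase.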
@ -251,7 +286,12 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
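# Special handling for arguments containing spaces or special characters, to prevent parse errors in the shell script used by conda run
# In particular, the < and > characters are interpreted by the shell as redirection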
if isinstance(value, str) and any(c in value for c in [' ', '<', '>', '&', '|', ';', '(', ')']):
cmd_args.append(f"--{key}='{value}'")
else:
cmd_args.append(f"--{key}={value}")
# 设置环境变量
env = os.environ.copy()
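
Wrapping a value in single quotes by hand breaks as soon as the value itself contains a single quote. As an alternative sketch, the standard library's `shlex.quote` handles that case. `format_cli_arg` is a hypothetical helper, and the treatment of boolean flags is an assumption about code that the hunk does not show. This approach only makes sense if the argument string is later re-parsed by a POSIX shell, which is what the comment in the diff suggests for `conda run`.

```python
import shlex


def format_cli_arg(key, value):
    """Format one --key=value argument, shell-quoting the value if needed.

    Booleans become a bare flag when true and are dropped (None) when false;
    this mirrors an assumed convention, not code visible in the diff.
    """
    if isinstance(value, bool):
        return f"--{key}" if value else None
    # shlex.quote leaves "safe" strings untouched and single-quotes the rest.
    return f"--{key}={shlex.quote(str(value))}"


token_arg = format_cli_arg('placeholder_token', '<sks-concept>')
plain_arg = format_cli_arg('resolution', 512)
```

Unlike a fixed list of special characters, `shlex.quote` quotes anything outside a conservative safe set, so characters such as `$`, backticks, and embedded quotes are covered as well.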

@ -169,10 +169,9 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
if not data_type:
    raise ValueError(f"DataType {perturbation.data_type_id} not found")
# Extract instance_prompt from data_type_prompt
instance_prompt = data_type.data_type_prompt or 'a photo of sks person'
# Derive class_prompt from instance_prompt by removing "sks"
class_prompt = instance_prompt.replace('sks ', '')
# Fetch the dataset type's prompts from the database
instance_prompt = data_type.instance_prompt or 'a photo of sks person'
class_prompt = data_type.class_prompt or 'a photo of person'
logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'")
@ -205,14 +204,14 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
    cmd_args.extend([
        f"--instance_data_dir={input_dir}",
        f"--output_dir={output_dir}",
        f"--eps={int(epsilon)}",
        f"--eps={float(epsilon)}",
    ])
elif algorithm_code == 'pid':
    # PID parameter structure
    cmd_args.extend([
        f"--instance_data_dir={input_dir}",
        f"--output_dir={output_dir}",
        f"--eps={int(epsilon)}",
        f"--eps={float(epsilon)}",
    ])
elif algorithm_code == 'glaze':
    # Glaze parameter structure
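
The switch from `int(epsilon)` to `float(epsilon)` changes the `--eps` string for any non-integer or integer-valued input. The two helper names below are hypothetical and exist only to put the old and new formatting side by side.

```python
def eps_flag_int(epsilon):
    """Old formatting: truncates toward zero."""
    return f"--eps={int(epsilon)}"


def eps_flag_float(epsilon):
    """New formatting: keeps the fractional part."""
    return f"--eps={float(epsilon)}"


old_small = eps_flag_int(0.05)
new_small = eps_flag_float(0.05)
new_whole = eps_flag_float(8)
```

With the old formatting a fractional budget such as `0.05` was silently sent as `0`. With the new formatting an integer budget is sent as `8.0`, so the change assumes the target scripts parse `--eps` as a float; the diff does not show whether they do.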

@ -238,9 +238,9 @@ class AlgorithmConfig:
'conda_env': CONDA_ENVS['textual_inversion'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'placeholder_token': 'sks',
'placeholder_token': '<sks-concept>',
'initializer_token': 'person',
'instance_prompt': 'a photo of sks person',
'instance_prompt': 'a photo of <sks-concept> person',
'resolution': 512,
'train_batch_size': 1,
'gradient_accumulation_steps': 1,
@ -251,7 +251,7 @@ class AlgorithmConfig:
'checkpointing_steps': 2,
'seed': 0,
'mixed_precision': 'fp16',
'validation_prompt': 'a photo of sks person',
'validation_prompt': 'a photo of <sks-concept> person',
'num_validation_images': 4,
'validation_epochs': 50,
'coords_log_interval': 10

@ -85,8 +85,22 @@ def init_database():
# Initialize the dataset type data
data_types = [
    {'data_type_code': 'facial', 'data_type_prompt': 'a photo of sks person', 'description': '人脸类型的数据集'},
    {'data_type_code': 'art', 'data_type_prompt': 'a painting in the style of sks', 'description': '艺术品类型的数据集'}
    {
        'data_type_code': 'face',
        'instance_prompt': 'a photo of sks person',
        'class_prompt': 'a photo of person',
        'placeholder_token': '<sks-concept>',
        'initializer_token': 'person',
        'description': '人脸类型的数据集'
    },
    {
        'data_type_code': 'art',
        'instance_prompt': 'a painting in sks style',
        'class_prompt': 'a painting',
        'placeholder_token': '<sks-style>',
        'initializer_token': 'painting',
        'description': '艺术品类型的数据集'
    }
]
for data_type in data_types:
    existing = DataType.query.filter_by(data_type_code=data_type['data_type_code']).first()
