将lianghao_branch合并到develop #27

Merged
hnu202326010204 merged 8 commits from lianghao_branch into develop 4 weeks ago

@ -1057,6 +1057,109 @@
- `404 {"error": "微调任务配置不存在"}`
- `500 {"error": "..."}`(判定来源失败)
#### GET `/api/task/finetune/<task_id>/coords`
**功能**获取指定微调任务的3D可视化坐标数据用于前端绘制训练轨迹图。根据微调任务类型返回不同数量的CSV数据基于加噪任务的微调返回2个坐标文件原图和加噪图上传图片的微调返回1个坐标文件。
**认证**:是
**路径参数**`task_id`(整数)- 微调任务ID。
**查询参数**:无。
**成功响应** `200 OK`(基于加噪任务的微调):
```json
{
"task_id": 982,
"flow_id": 60001,
"source": "perturbation",
"coords": [
{
"type": "original",
"filename": "original_coords.csv",
"path": "/root/autodl-tmp/MuseGuard/src/backend/static/eva_res/position/42/60001/982/original_coords.csv",
"data": [
{
"step": "0",
"x": "0.123",
"y": "0.456",
"z": "0.789"
},
{
"step": "10",
"x": "0.234",
"y": "0.567",
"z": "0.890"
}
]
},
{
"type": "perturbed",
"filename": "perturbed_coords.csv",
"path": "/root/autodl-tmp/MuseGuard/src/backend/static/eva_res/position/42/60001/982/perturbed_coords.csv",
"data": [
{
"step": "0",
"x": "0.321",
"y": "0.654",
"z": "0.987"
},
{
"step": "10",
"x": "0.432",
"y": "0.765",
"z": "0.098"
}
]
}
]
}
```
**成功响应** `200 OK`(上传图片的微调):
```json
{
"task_id": 983,
"flow_id": 60002,
"source": "uploaded",
"coords": [
{
"type": "uploaded",
"filename": "coords.csv",
"path": "/root/autodl-tmp/MuseGuard/src/backend/static/eva_res/position/42/60002/983/coords.csv",
"data": [
{
"step": "0",
"x": "0.111",
"y": "0.222",
"z": "0.333"
},
{
"step": "10",
"x": "0.444",
"y": "0.555",
"z": "0.666"
}
]
}
]
}
```
**错误响应**
- `401 {"error": "无效的用户身份标识"}`
- `404 {"error": "任务不存在或无权限"}`
- `404 {"error": "微调任务详情不存在"}`
- `404 {"error": "原图坐标文件不存在"}` / `{"error": "加噪图坐标文件不存在"}` / `{"error": "坐标文件不存在"}`
- `500 {"error": "读取原图坐标文件失败: ..."}` / `{"error": "读取加噪图坐标文件失败: ..."}` / `{"error": "读取坐标文件失败: ..."}`
- `500 {"error": "..."}`(判定来源失败)
**说明**
- 基于加噪任务的微调(`source=perturbation`):返回 `original_coords.csv``perturbed_coords.csv` 两个文件
- 上传图片的微调(`source=uploaded`):返回 `coords.csv` 一个文件
- CSV 数据格式包含 `step`(训练步数)、`x`、`y`、`z`(三维坐标)字段
- `data` 数组中每个对象代表一个训练步骤的坐标点
### GET `/api/image/evaluate/<task_id>`
**功能**获取评估任务的结果图片base64格式
**认证**:是
@ -1555,21 +1658,49 @@ Authorization: Bearer <token>
## 五、任务日志相关
### GET `/api/task/<task_id>/logs`
**功能**:获取指定任务的运行日志。
**功能**:获取指定任务的运行日志。系统会自动查找该任务最新的日志文件(匹配模式:`*task_{task_id}_*.log`),如果存在多个日志文件,返回修改时间最新的一个。适用于所有任务类型(加噪、微调、热力图、评估)。
**认证**:是
**成功响应** `200 OK`
**路径参数**`task_id`(整数)- 任务ID。
**查询参数**:无。
**成功响应** `200 OK`(有日志内容):
```json
{
"logs": "2025-12-17 10:00:00 - INFO - Task started\n2025-12-17 10:00:01 - INFO - Loading model...\n2025-12-17 10:00:05 - INFO - Processing images...\n2025-12-17 10:05:30 - INFO - Task completed successfully"
}
```
**成功响应** `200 OK`(暂无日志):
```json
{
"logs": "2025-12-12 10:00:00 - INFO - Starting task...\n2025-12-12 10:00:01 - INFO - Processing..."
"logs": "暂无日志"
}
```
**错误响应**
- `401 {"error": "无效的用户身份标识"}`
- `404 {"error": "任务不存在或无权限"}`
- `500 {"error": "读取日志失败: ..."}`
- `500 {"error": "读取日志失败: <错误详情>"}`
**说明**
- 日志文件命名规则:`<type>_task_<task_id>_<timestamp>.log`
- 日志内容为纯文本格式,包含任务执行过程中的信息、警告和错误
- 如果任务尚未启动或日志文件不存在,返回 "暂无日志"
- 日志文件存储在服务器配置的 `LOGS_DIR` 目录下
**使用场景**
- 实时监控任务执行进度
- 调试任务失败原因
- 分析算法执行过程
- 追踪系统运行状态
---
## 文档更新记录
- [POST /api/task/finetune/from-perturbation](#post-apitaskfinetunefrom-perturbation):新增 `custom_prompt` 参数。
- [POST /api/task/finetune/from-upload](#post-apitaskfinetunefrom-upload):新增 `custom_prompt` 参数。
- [GET /api/task/finetune/<task_id>/coords](#get-apitaskfinetunetask_idcoords)完善3D可视化坐标数据接口文档新增详细的请求响应格式说明和错误处理。
- [GET /api/task/<task_id>/logs](#get-apitasktask_idlogs):完善任务日志接口文档,新增详细的功能说明、响应格式、错误处理和使用场景。

@ -222,7 +222,9 @@ def main(args):
for i in range(0, len(dataset.instance_images_path)):
img = dataset[i]["pixel_values"]
img = to_image(img + attackmodel.delta[i])
img.save(os.path.join(args.output_dir, f"{i}.png"))
# 获取原始文件名(不含扩展名)
original_filename = Path(dataset.instance_images_path[i]).stem
img.save(os.path.join(args.output_dir, f"pid_{original_filename}.png"))
# Select target loss
clean_embedding = attackmodel(vae, batch["pixel_values"], batch["index"], False)
@ -270,6 +272,20 @@ def main(args):
# Logging steps
logs = {"loss": total_loss.item()}
progress_bar.set_postfix(**logs)
# 训练结束后保存最终的加噪图片
print("\nSaving final perturbed images...")
to_image = transforms.ToPILImage()
for i in range(0, len(dataset.instance_images_path)):
img = dataset[i]["pixel_values"]
img = to_image(img + attackmodel.delta[i])
# 获取原始文件名(不含扩展名)
original_filename = Path(dataset.instance_images_path[i]).stem
save_path = os.path.join(args.output_dir, f"pid_{original_filename}.png")
img.save(save_path)
print(f"Saved: {save_path}")
print(f"\nAll {len(dataset.instance_images_path)} perturbed images saved to {args.output_dir}")
if __name__ == "__main__":

@ -169,7 +169,7 @@ def create_perturbation_task(current_user_id):
if not data_type:
return TaskService.json_error('数据集类型不存在')
role_code = user.role.role_code if user.role else 'user'
if role_code in ('user', 'normal') and data_type.data_type_code != 'facial':
if role_code in ('user', 'normal') and data_type.data_type_code != 'face':
return TaskService.json_error('普通用户仅可使用人脸数据集', 403)
if not PerturbationConfig.query.get(perturbation_configs_id):
return TaskService.json_error('加噪配置不存在')

@ -562,7 +562,7 @@ class TaskService:
class_dir=class_dir,
coords_save_path=original_coords_save_path,
validation_output_dir=original_output_dir,
is_perturbed=False,
finetune_type="original",
custom_params=None,
job_id=job_id_original,
job_timeout='8h'
@ -577,7 +577,7 @@ class TaskService:
class_dir=class_dir,
coords_save_path=perturbed_coords_save_path,
validation_output_dir=perturbed_output_dir,
is_perturbed=True,
finetune_type="perturbed",
custom_params=None,
job_id=job_id_perturbed,
job_timeout='8h'
@ -616,7 +616,7 @@ class TaskService:
class_dir=class_dir,
coords_save_path=coords_save_path,
validation_output_dir=uploaded_output_dir,
is_perturbed=False,
finetune_type="uploaded",
custom_params=None,
job_id=job_id,
job_timeout='8h'

@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
def run_finetune_task(task_id, finetune_method, train_images_dir,
output_model_dir, class_dir, coords_save_path, validation_output_dir,
is_perturbed=False, custom_params=None):
finetune_type, custom_params=None):
"""
执行微调任务仅使用真实算法
@ -32,7 +32,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
class_dir: 类别图片目录
coords_save_path: 坐标保存路径
validation_output_dir: 验证图片输出目录
is_perturbed: 是否使用扰动图片训练
finetune_type: 微调类型 (original, perturbed, uploaded)
custom_params: 自定义参数
Returns:
@ -64,7 +64,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
task.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}")
logger.info(f"Method: {finetune_method}, finetune_type: {finetune_type}")
# 获取 DataType 配置
data_type = DataType.query.get(finetune.data_type_id) if finetune.data_type_id else None
@ -138,8 +138,8 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
if not is_perturbed:
# 原图微调:清除旧日志,创建新日志
if finetune_type == "original" or finetune_type == "uploaded":
# 原图/上传微调:清除旧日志,创建新日志
old_logs = glob.glob(os.path.join(log_dir, f'finetune_{finetune_method}_task_{task_id}_*.log'))
for old_log in old_logs:
try:
@ -152,7 +152,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
log_dir,
f'finetune_{finetune_method}_task_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
else:
elif finetune_type == "perturbed":
# 扰动图微调:尝试复用现有日志
old_logs = glob.glob(os.path.join(log_dir, f'finetune_{finetune_method}_task_{task_id}_*.log'))
if old_logs:
@ -170,11 +170,11 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
result = _run_real_finetune(
finetune_method, task_id, train_images_dir, output_model_dir,
class_dir, coords_save_path, validation_output_dir,
instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params, log_file
instance_prompt, class_prompt, validation_prompt, finetune_type, custom_params, log_file
)
# 保存生成的验证图片到数据库
_save_generated_images(task_id, validation_output_dir, is_perturbed)
_save_generated_images(task_id, validation_output_dir, finetune_type)
# 更新任务状态为完成
completed_status = TaskStatus.query.filter_by(task_status_code='completed').first()
@ -205,7 +205,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_dir,
class_dir, coords_save_path, validation_output_dir,
instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params, log_file):
instance_prompt, class_prompt, validation_prompt, finetune_type, custom_params, log_file):
"""
运行真实微调算法参考sh脚本配置
@ -220,7 +220,7 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
instance_prompt: 实例提示词
class_prompt: 类别提示词
validation_prompt: 验证提示词
is_perturbed: 是否使用扰动图片
finetune_type: 微调类型 (original, perturbed, uploaded)
custom_params: 自定义参数
log_file: 日志文件路径
"""
@ -309,7 +309,7 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
# 执行命令
# 使用追加模式 'a',以便在同一日志文件中记录原图和扰动图的微调过程
with open(log_file, 'a') as f:
if is_perturbed:
if finetune_type == "perturbed":
f.write(f"\n\n{'='*30}\nStarting Perturbed Finetune Task\n{'='*30}\n\n")
process = subprocess.Popen(
@ -386,7 +386,7 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
}
def _save_generated_images(task_id, output_dir, is_perturbed):
def _save_generated_images(task_id, output_dir, finetune_type):
"""
保存微调生成的验证图片到数据库适配新数据库结构
@ -398,7 +398,7 @@ def _save_generated_images(task_id, output_dir, is_perturbed):
Args:
task_id: 任务ID
output_dir: 生成图片输出目录
is_perturbed: 是否为扰动图片训练生成
finetune_type: 微调类型 (original, perturbed, uploaded)
"""
from app import db
from app.database import Task, Image, ImageType
@ -410,12 +410,15 @@ def _save_generated_images(task_id, output_dir, is_perturbed):
raise ValueError(f"Task {task_id} not found")
# 获取图片类型
if is_perturbed:
if finetune_type == "perturbed":
generated_type = ImageType.query.filter_by(image_code='perturbed_generate').first()
input_type = ImageType.query.filter_by(image_code='perturbed').first()
else:
elif finetune_type == "original":
generated_type = ImageType.query.filter_by(image_code='original_generate').first()
input_type = ImageType.query.filter_by(image_code='original').first()
elif finetune_type == "uploaded":
generated_type = ImageType.query.filter_by(image_code='uploaded_generate').first()
input_type = ImageType.query.filter_by(image_code='original').first()
if not generated_type or not input_type:
raise ValueError("Required image types not found in database")

@ -75,12 +75,9 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path,
perturbation = Perturbation.query.get(pert_task.tasks_id)
if perturbation and perturbation.data_type_id:
data_type = DataType.query.get(perturbation.data_type_id)
if data_type and data_type.data_type_prompt:
prompt_text = data_type.data_type_prompt
# 提取target_word去除"sks"后的第一个名词)
words = prompt_text.replace('sks ', '').split()
if words:
target_word = words[-1] # 取最后一个词作为target
if data_type and data_type.instance_prompt:
prompt_text = data_type.instance_prompt
target_word = data_type.initializer_token
logger.info(f"Using prompts from database - prompt: '{prompt_text}', target: '{target_word}'")
break
@ -119,12 +116,13 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path,
)
# 保存热力图文件到数据库
heatmap_file = os.path.join(output_dir, 'heatmap_dif.png')
heatmap_file = os.path.join(output_dir, 'dual_heatmap_report.png')
if os.path.exists(heatmap_file):
heatmap.heatmap_name = 'heatmap_dif.png'
heatmap.heatmap_name = 'dual_heatmap_report.png'
# 保存热力图到Image表
_save_heatmap_image(task_id, heatmap_file, perturbed_image_id)
db.session.commit()
else:
logger.error(f"Heatmap file not found: {heatmap_file}")
# 更新任务状态为完成
completed_status = TaskStatus.query.filter_by(task_status_code='completed').first()

@ -342,20 +342,19 @@ def _save_perturbed_images(task_id, output_dir):
perturbed_filename = os.path.basename(perturbed_path)
# 尝试匹配原始图片(建立父子关系)
# 算法可能输出同名文件或带前缀的文件
# 策略:检查加噪图的文件名中是否包含原图的文件名(去除扩展名)
father_image = None
# 策略1: 完全匹配文件名
if perturbed_filename in original_map:
father_image = original_map[perturbed_filename]
else:
# 策略2: 移除可能的前缀如perturbed_
for prefix in ['perturbed_', 'adv_', 'protected_']:
if perturbed_filename.startswith(prefix):
clean_name = perturbed_filename[len(prefix):]
if clean_name in original_map:
father_image = original_map[clean_name]
break
# 获取加噪图的文件名(不含扩展名)
perturbed_name_without_ext = os.path.splitext(perturbed_filename)[0]
# 遍历所有原图,检查原图文件名(不含扩展名)是否包含在加噪图文件名中
for original_filename, original_img in original_map.items():
original_name_without_ext = os.path.splitext(original_filename)[0]
if original_name_without_ext in perturbed_name_without_ext:
father_image = original_img
logger.info(f"Matched perturbed image '{perturbed_filename}' to original '{original_filename}'")
break
if not father_image:
logger.warning(f"Could not find father image for {perturbed_filename}, saving without father_id")

@ -67,9 +67,9 @@ class AlgorithmConfig:
'prior_loss_weight': 1.0,
'resolution': 384,
'train_batch_size': 1,
'max_train_steps': 1,
'max_f_train_steps': 1,
'max_adv_train_steps': 1,
'max_train_steps': 5,
'max_f_train_steps': 5,
'max_adv_train_steps': 5,
'checkpointing_iterations': 1,
'learning_rate': 5e-7,
'pgd_alpha': 0.005,
@ -201,8 +201,7 @@ class AlgorithmConfig:
'sample_batch_size': 5,
'validation_prompt': 'a photo of sks person',
'num_validation_images': 2,
'validation_steps': 1,
'coords_log_interval': 10
'coords_log_interval': 1
}
},
'lora': {
@ -229,7 +228,7 @@ class AlgorithmConfig:
'rank': 4,
'validation_prompt': 'a photo of sks person',
'num_validation_images': 2,
'coords_log_interval': 10
'coords_log_interval': 1
}
},
'textual_inversion': {
@ -254,7 +253,7 @@ class AlgorithmConfig:
'validation_prompt': 'a photo of <sks-concept> person',
'num_validation_images': 4,
'validation_epochs': 50,
'coords_log_interval': 10
'coords_log_interval': 1
}
}
}

Loading…
Cancel
Save