将lianghao_branch合并到develop #13

Merged
hnu202326010204 merged 2 commits from lianghao_branch into develop 1 month ago

@ -1,4 +1,3 @@
"""
任务管理控制器
适配新数据库结构提供加噪微调热力图数值评估等任务相关接口
@ -14,6 +13,7 @@ from app.database import (
Image
)
from app.services.task_service import TaskService
from app.services.image_service import ImageService
task_bp = Blueprint('task', __name__)
@ -115,10 +115,17 @@ def list_perturbation_configs(current_user_id):
@task_bp.route('/perturbation', methods=['POST'])
@int_jwt_required
def create_perturbation_task(current_user_id):
data = request.get_json() or {}
data_type_id = data.get('data_type_id')
perturbation_configs_id = data.get('perturbation_configs_id')
intensity = data.get('perturbation_intensity')
# 兼容 form-data 和 json推荐 form-data
if request.content_type and request.content_type.startswith('multipart/form-data'):
data = request.form
else:
data = request.get_json() or {}
# 参数解析form-data 推荐全部用字符串
data_type_id = data.get('data_type_id', type=int) if hasattr(data, 'get') else int(data.get('data_type_id', 0))
perturbation_configs_id = data.get('perturbation_configs_id', type=int) if hasattr(data, 'get') else int(data.get('perturbation_configs_id', 0))
intensity = data.get('perturbation_intensity', type=float) if hasattr(data, 'get') else float(data.get('perturbation_intensity', 0))
description = data.get('description')
if not all([data_type_id, perturbation_configs_id, intensity]):
return TaskService.json_error('缺少必要的任务参数')
@ -169,9 +176,17 @@ def create_perturbation_task(current_user_id):
db.session.add(perturbation)
db.session.commit()
# 自动上传图片
files = request.files.getlist('files') if hasattr(request, 'files') else []
target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id)
success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, [])
# 创建任务成功后自动启动任务
job_id = TaskService.start_perturbation_task(task.tasks_id)
return jsonify({
'message': '加噪任务已创建',
'task': TaskService.serialize_task(task)
'message': '加噪任务已创建并已启动',
'task': TaskService.serialize_task(task),
'job_id': job_id
}), 201
except Exception as exc:
db.session.rollback()
@ -300,7 +315,13 @@ def create_heatmap_task(current_user_id):
db.session.add(heatmap)
db.session.commit()
return jsonify({'message': '热力图任务已创建', 'task': TaskService.serialize_task(task)}), 201
# 创建任务成功后自动启动任务
job_id = TaskService.start_heatmap_task(task.tasks_id)
return jsonify({
'message': '热力图任务已创建并已启动',
'task': TaskService.serialize_task(task),
'job_id': job_id
}), 201
except Exception as exc:
db.session.rollback()
return TaskService.json_error(f'创建热力图任务失败: {exc}', 500)
@ -402,7 +423,13 @@ def create_finetune_from_perturbation(current_user_id):
db.session.add(finetune)
db.session.commit()
return jsonify({'message': '微调任务已创建', 'task': TaskService.serialize_task(task)}), 201
# 创建任务成功后自动启动任务
job_id = TaskService.start_finetune_task(task.tasks_id)
return jsonify({
'message': '微调任务已创建并已启动',
'task': TaskService.serialize_task(task),
'job_id': job_id
}), 201
except Exception as exc:
db.session.rollback()
return TaskService.json_error(f'创建微调任务失败: {exc}', 500)
@ -419,10 +446,20 @@ def create_finetune_from_upload(current_user_id):
if role_code not in ('vip', 'admin'):
return TaskService.json_error('仅限VIP或管理员使用上传微调功能', 403)
data = request.get_json() or {}
finetune_configs_id = data.get('finetune_configs_id')
# 兼容 form-data 和 json推荐 form-data
if request.content_type and request.content_type.startswith('multipart/form-data'):
data = request.form
else:
data = request.get_json() or {}
finetune_configs_id = data.get('finetune_configs_id', type=int) if hasattr(data, 'get') else int(data.get('finetune_configs_id', 0))
data_type_id = data.get('data_type_id', type=int) if hasattr(data, 'get') else int(data.get('data_type_id', 0))
description = data.get('description')
if not finetune_configs_id:
return TaskService.json_error('缺少必要参数: finetune_configs_id')
if not data_type_id:
return TaskService.json_error('缺少必要参数: data_type_id')
finetune_config = FinetuneConfig.query.get(finetune_configs_id)
if not finetune_config:
@ -473,9 +510,18 @@ def create_finetune_from_upload(current_user_id):
db.session.add(finetune)
db.session.commit()
# 自动上传图片(仅上传微调任务)
files = request.files.getlist('files') if hasattr(request, 'files') else []
target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id)
success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, [])
# 自动启动任务
job_id = TaskService.start_finetune_task(task.tasks_id)
return jsonify({
'message': '上传微调任务已创建',
'task': TaskService.serialize_task(task)
'message': '上传微调任务已创建并已启动',
'task': TaskService.serialize_task(task),
'images': [ImageService.serialize_image(img) for img in result],
'job_id': job_id
}), 201
except Exception as exc:
db.session.rollback()
@ -590,7 +636,13 @@ def create_evaluate_task(current_user_id):
db.session.add(evaluate)
db.session.commit()
return jsonify({'message': '评估任务已创建', 'task': TaskService.serialize_task(task)}), 201
# 创建任务成功后自动启动任务
job_id = TaskService.start_evaluate_task(task.tasks_id)
return jsonify({
'message': '评估任务已创建并已启动',
'task': TaskService.serialize_task(task),
'job_id': job_id
}), 201
except Exception as exc:
db.session.rollback()
return TaskService.json_error(f'创建评估任务失败: {exc}', 500)

Loading…
Cancel
Save