|
|
|
|
@ -1,4 +1,3 @@
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
任务管理控制器
|
|
|
|
|
适配新数据库结构,提供加噪、微调、热力图、数值评估等任务相关接口
|
|
|
|
|
@ -14,6 +13,7 @@ from app.database import (
|
|
|
|
|
Image
|
|
|
|
|
)
|
|
|
|
|
from app.services.task_service import TaskService
|
|
|
|
|
from app.services.image_service import ImageService
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
task_bp = Blueprint('task', __name__)
|
|
|
|
|
@ -115,10 +115,17 @@ def list_perturbation_configs(current_user_id):
|
|
|
|
|
@task_bp.route('/perturbation', methods=['POST'])
|
|
|
|
|
@int_jwt_required
|
|
|
|
|
def create_perturbation_task(current_user_id):
|
|
|
|
|
data = request.get_json() or {}
|
|
|
|
|
data_type_id = data.get('data_type_id')
|
|
|
|
|
perturbation_configs_id = data.get('perturbation_configs_id')
|
|
|
|
|
intensity = data.get('perturbation_intensity')
|
|
|
|
|
# 兼容 form-data 和 json,推荐 form-data
|
|
|
|
|
if request.content_type and request.content_type.startswith('multipart/form-data'):
|
|
|
|
|
data = request.form
|
|
|
|
|
else:
|
|
|
|
|
data = request.get_json() or {}
|
|
|
|
|
|
|
|
|
|
# 参数解析,form-data 推荐全部用字符串
|
|
|
|
|
data_type_id = data.get('data_type_id', type=int) if hasattr(data, 'get') else int(data.get('data_type_id', 0))
|
|
|
|
|
perturbation_configs_id = data.get('perturbation_configs_id', type=int) if hasattr(data, 'get') else int(data.get('perturbation_configs_id', 0))
|
|
|
|
|
intensity = data.get('perturbation_intensity', type=float) if hasattr(data, 'get') else float(data.get('perturbation_intensity', 0))
|
|
|
|
|
description = data.get('description')
|
|
|
|
|
|
|
|
|
|
if not all([data_type_id, perturbation_configs_id, intensity]):
|
|
|
|
|
return TaskService.json_error('缺少必要的任务参数')
|
|
|
|
|
@ -169,9 +176,17 @@ def create_perturbation_task(current_user_id):
|
|
|
|
|
db.session.add(perturbation)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
# 自动上传图片
|
|
|
|
|
files = request.files.getlist('files') if hasattr(request, 'files') else []
|
|
|
|
|
target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id)
|
|
|
|
|
success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, [])
|
|
|
|
|
|
|
|
|
|
# 创建任务成功后自动启动任务
|
|
|
|
|
job_id = TaskService.start_perturbation_task(task.tasks_id)
|
|
|
|
|
return jsonify({
|
|
|
|
|
'message': '加噪任务已创建',
|
|
|
|
|
'task': TaskService.serialize_task(task)
|
|
|
|
|
'message': '加噪任务已创建并已启动',
|
|
|
|
|
'task': TaskService.serialize_task(task),
|
|
|
|
|
'job_id': job_id
|
|
|
|
|
}), 201
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
@ -300,7 +315,13 @@ def create_heatmap_task(current_user_id):
|
|
|
|
|
db.session.add(heatmap)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
return jsonify({'message': '热力图任务已创建', 'task': TaskService.serialize_task(task)}), 201
|
|
|
|
|
# 创建任务成功后自动启动任务
|
|
|
|
|
job_id = TaskService.start_heatmap_task(task.tasks_id)
|
|
|
|
|
return jsonify({
|
|
|
|
|
'message': '热力图任务已创建并已启动',
|
|
|
|
|
'task': TaskService.serialize_task(task),
|
|
|
|
|
'job_id': job_id
|
|
|
|
|
}), 201
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
return TaskService.json_error(f'创建热力图任务失败: {exc}', 500)
|
|
|
|
|
@ -402,7 +423,13 @@ def create_finetune_from_perturbation(current_user_id):
|
|
|
|
|
db.session.add(finetune)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
return jsonify({'message': '微调任务已创建', 'task': TaskService.serialize_task(task)}), 201
|
|
|
|
|
# 创建任务成功后自动启动任务
|
|
|
|
|
job_id = TaskService.start_finetune_task(task.tasks_id)
|
|
|
|
|
return jsonify({
|
|
|
|
|
'message': '微调任务已创建并已启动',
|
|
|
|
|
'task': TaskService.serialize_task(task),
|
|
|
|
|
'job_id': job_id
|
|
|
|
|
}), 201
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
return TaskService.json_error(f'创建微调任务失败: {exc}', 500)
|
|
|
|
|
@ -419,10 +446,20 @@ def create_finetune_from_upload(current_user_id):
|
|
|
|
|
if role_code not in ('vip', 'admin'):
|
|
|
|
|
return TaskService.json_error('仅限VIP或管理员使用上传微调功能', 403)
|
|
|
|
|
|
|
|
|
|
data = request.get_json() or {}
|
|
|
|
|
finetune_configs_id = data.get('finetune_configs_id')
|
|
|
|
|
# 兼容 form-data 和 json,推荐 form-data
|
|
|
|
|
if request.content_type and request.content_type.startswith('multipart/form-data'):
|
|
|
|
|
data = request.form
|
|
|
|
|
else:
|
|
|
|
|
data = request.get_json() or {}
|
|
|
|
|
|
|
|
|
|
finetune_configs_id = data.get('finetune_configs_id', type=int) if hasattr(data, 'get') else int(data.get('finetune_configs_id', 0))
|
|
|
|
|
data_type_id = data.get('data_type_id', type=int) if hasattr(data, 'get') else int(data.get('data_type_id', 0))
|
|
|
|
|
description = data.get('description')
|
|
|
|
|
|
|
|
|
|
if not finetune_configs_id:
|
|
|
|
|
return TaskService.json_error('缺少必要参数: finetune_configs_id')
|
|
|
|
|
if not data_type_id:
|
|
|
|
|
return TaskService.json_error('缺少必要参数: data_type_id')
|
|
|
|
|
|
|
|
|
|
finetune_config = FinetuneConfig.query.get(finetune_configs_id)
|
|
|
|
|
if not finetune_config:
|
|
|
|
|
@ -473,9 +510,18 @@ def create_finetune_from_upload(current_user_id):
|
|
|
|
|
db.session.add(finetune)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
# 自动上传图片(仅上传微调任务)
|
|
|
|
|
files = request.files.getlist('files') if hasattr(request, 'files') else []
|
|
|
|
|
target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id)
|
|
|
|
|
success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, [])
|
|
|
|
|
|
|
|
|
|
# 自动启动任务
|
|
|
|
|
job_id = TaskService.start_finetune_task(task.tasks_id)
|
|
|
|
|
return jsonify({
|
|
|
|
|
'message': '上传微调任务已创建',
|
|
|
|
|
'task': TaskService.serialize_task(task)
|
|
|
|
|
'message': '上传微调任务已创建并已启动',
|
|
|
|
|
'task': TaskService.serialize_task(task),
|
|
|
|
|
'images': [ImageService.serialize_image(img) for img in result],
|
|
|
|
|
'job_id': job_id
|
|
|
|
|
}), 201
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
@ -590,7 +636,13 @@ def create_evaluate_task(current_user_id):
|
|
|
|
|
db.session.add(evaluate)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
return jsonify({'message': '评估任务已创建', 'task': TaskService.serialize_task(task)}), 201
|
|
|
|
|
# 创建任务成功后自动启动任务
|
|
|
|
|
job_id = TaskService.start_evaluate_task(task.tasks_id)
|
|
|
|
|
return jsonify({
|
|
|
|
|
'message': '评估任务已创建并已启动',
|
|
|
|
|
'task': TaskService.serialize_task(task),
|
|
|
|
|
'job_id': job_id
|
|
|
|
|
}), 201
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
return TaskService.json_error(f'创建评估任务失败: {exc}', 500)
|
|
|
|
|
|