|
|
|
|
@ -106,24 +106,54 @@ def cancel_task(task_id, current_user_id):
|
|
|
|
|
return jsonify({'message': '任务已取消'}), 200
|
|
|
|
|
return TaskService.json_error('取消任务失败', 500)
|
|
|
|
|
|
|
|
|
|
@task_bp.route('/<int:task_id>/restart', methods=['POST'])
|
|
|
|
|
@int_jwt_required
|
|
|
|
|
def restart_task(task_id, current_user_id):
|
|
|
|
|
task = Task.query.get(task_id)
|
|
|
|
|
if not TaskService.ensure_task_owner(task, current_user_id):
|
|
|
|
|
return TaskService.json_error('任务不存在或无权限', 404)
|
|
|
|
|
# 只允许cancelled/failed状态重启
|
|
|
|
|
status_code = task.task_status.task_status_code if task and task.task_status else None
|
|
|
|
|
if status_code not in ("cancelled", "failed"):
|
|
|
|
|
return TaskService.json_error('仅取消或失败的任务可重启', 400)
|
|
|
|
|
if not TaskService.restart_task(task_id):
|
|
|
|
|
return TaskService.json_error('重启任务失败', 500)
|
|
|
|
|
# 自动启动任务(按类型分发)
|
|
|
|
|
type_code = TaskService.get_task_type_code(task)
|
|
|
|
|
if type_code == 'perturbation':
|
|
|
|
|
job_id = TaskService.start_perturbation_task(task_id)
|
|
|
|
|
elif type_code == 'finetune':
|
|
|
|
|
job_id = TaskService.start_finetune_task(task_id)
|
|
|
|
|
elif type_code == 'heatmap':
|
|
|
|
|
job_id = TaskService.start_heatmap_task(task_id)
|
|
|
|
|
elif type_code == 'evaluate':
|
|
|
|
|
job_id = TaskService.start_evaluate_task(task_id)
|
|
|
|
|
else:
|
|
|
|
|
job_id = None
|
|
|
|
|
return jsonify({'message': '任务已重启', 'job_id': job_id}), 200
|
|
|
|
|
|
|
|
|
|
@task_bp.route('/<int:task_id>', methods=['DELETE'])
|
|
|
|
|
@int_jwt_required
|
|
|
|
|
def delete_task(task_id, current_user_id):
|
|
|
|
|
task = Task.query.get(task_id)
|
|
|
|
|
if not TaskService.ensure_task_owner(task, current_user_id):
|
|
|
|
|
return TaskService.json_error('任务不存在或无权限', 404)
|
|
|
|
|
status_code = task.task_status.task_status_code if task and task.task_status else None
|
|
|
|
|
if status_code not in ("cancelled", "failed"):
|
|
|
|
|
return TaskService.json_error('仅取消或失败的任务可删除', 400)
|
|
|
|
|
success, err = TaskService.delete_task(task_id, user_id=current_user_id)
|
|
|
|
|
if not success:
|
|
|
|
|
return TaskService.json_error(f'删除任务失败: {err}', 500)
|
|
|
|
|
return jsonify({'message': '任务已删除'}), 200
|
|
|
|
|
|
|
|
|
|
@task_bp.route('/quota', methods=['GET'])
|
|
|
|
|
@int_jwt_required
|
|
|
|
|
def get_task_quota(current_user_id):
|
|
|
|
|
user = TaskService.get_user(current_user_id)
|
|
|
|
|
if not user:
|
|
|
|
|
quota = TaskService.get_user_task_quota(current_user_id)
|
|
|
|
|
if quota is None:
|
|
|
|
|
return TaskService.json_error('用户不存在', 404)
|
|
|
|
|
|
|
|
|
|
role = user.role
|
|
|
|
|
max_tasks = role.max_concurrent_tasks if role and role.max_concurrent_tasks is not None else 0
|
|
|
|
|
current_count = Task.query.filter_by(user_id=current_user_id).count()
|
|
|
|
|
remaining = max(max_tasks - current_count, 0)
|
|
|
|
|
|
|
|
|
|
return jsonify({
|
|
|
|
|
'max_tasks': max_tasks,
|
|
|
|
|
'current_tasks': current_count,
|
|
|
|
|
'remaining_tasks': remaining
|
|
|
|
|
}), 200
|
|
|
|
|
return jsonify(quota), 200
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ==================== 加噪任务 ====================
|
|
|
|
|
@ -223,6 +253,11 @@ def create_perturbation_task(current_user_id):
|
|
|
|
|
except Exception:
|
|
|
|
|
return TaskService.json_error('非法的 flow_id 参数')
|
|
|
|
|
|
|
|
|
|
# 检查配额
|
|
|
|
|
quota = TaskService.get_user_task_quota(current_user_id)
|
|
|
|
|
if quota and quota['remaining_tasks'] <= 0:
|
|
|
|
|
return TaskService.json_error('任务配额已满,请等待现有任务完成', 403)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
waiting_status = TaskService.ensure_status('waiting')
|
|
|
|
|
perturb_type = TaskService.require_task_type('perturbation')
|
|
|
|
|
@ -372,6 +407,11 @@ def create_heatmap_task(current_user_id):
|
|
|
|
|
if image_code != 'perturbed':
|
|
|
|
|
return TaskService.json_error(f'仅支持加噪图生成热力图,当前图片类型为: {perturbed_image.image_type.image_name}', 400)
|
|
|
|
|
|
|
|
|
|
# 检查配额
|
|
|
|
|
quota = TaskService.get_user_task_quota(current_user_id)
|
|
|
|
|
if quota and quota['remaining_tasks'] <= 0:
|
|
|
|
|
return TaskService.json_error('任务配额已满,请等待现有任务完成', 403)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
heatmap_type = TaskService.require_task_type('heatmap')
|
|
|
|
|
waiting_status = TaskService.ensure_status('waiting')
|
|
|
|
|
@ -487,6 +527,11 @@ def create_finetune_from_perturbation(current_user_id):
|
|
|
|
|
if data_type_id and not DataType.query.get(data_type_id):
|
|
|
|
|
return TaskService.json_error('数据集类型不存在')
|
|
|
|
|
|
|
|
|
|
# 检查配额
|
|
|
|
|
quota = TaskService.get_user_task_quota(current_user_id)
|
|
|
|
|
if quota and quota['remaining_tasks'] <= 0:
|
|
|
|
|
return TaskService.json_error('任务配额已满,请等待现有任务完成', 403)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
waiting_status = TaskService.ensure_status('waiting')
|
|
|
|
|
finetune_type = TaskService.require_task_type('finetune')
|
|
|
|
|
@ -588,6 +633,11 @@ def create_finetune_from_upload(current_user_id):
|
|
|
|
|
except Exception:
|
|
|
|
|
return TaskService.json_error('非法的 flow_id 参数')
|
|
|
|
|
|
|
|
|
|
# 检查配额
|
|
|
|
|
quota = TaskService.get_user_task_quota(current_user_id)
|
|
|
|
|
if quota and quota['remaining_tasks'] <= 0:
|
|
|
|
|
return TaskService.json_error('任务配额已满,请等待现有任务完成', 403)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
waiting_status = TaskService.ensure_status('waiting')
|
|
|
|
|
finetune_type = TaskService.require_task_type('finetune')
|
|
|
|
|
@ -715,6 +765,11 @@ def create_evaluate_task(current_user_id):
|
|
|
|
|
if not finetune_task.finetune:
|
|
|
|
|
return TaskService.json_error('微调任务未配置详情', 400)
|
|
|
|
|
|
|
|
|
|
# 检查配额
|
|
|
|
|
quota = TaskService.get_user_task_quota(current_user_id)
|
|
|
|
|
if quota and quota['remaining_tasks'] <= 0:
|
|
|
|
|
return TaskService.json_error('任务配额已满,请等待现有任务完成', 403)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
evaluate_type = TaskService.require_task_type('evaluate')
|
|
|
|
|
waiting_status = TaskService.ensure_status('waiting')
|
|
|
|
|
|