diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index 4457370..4dcc0f9 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -1,4 +1,3 @@ - """ 任务管理控制器 适配新数据库结构,提供加噪、微调、热力图、数值评估等任务相关接口 @@ -14,6 +13,7 @@ from app.database import ( Image ) from app.services.task_service import TaskService +from app.services.image_service import ImageService task_bp = Blueprint('task', __name__) @@ -115,10 +115,17 @@ def list_perturbation_configs(current_user_id): @task_bp.route('/perturbation', methods=['POST']) @int_jwt_required def create_perturbation_task(current_user_id): - data = request.get_json() or {} - data_type_id = data.get('data_type_id') - perturbation_configs_id = data.get('perturbation_configs_id') - intensity = data.get('perturbation_intensity') + # 兼容 form-data 和 json,推荐 form-data + if request.content_type and request.content_type.startswith('multipart/form-data'): + data = request.form + else: + data = request.get_json() or {} + + # 参数解析,form-data 推荐全部用字符串 + data_type_id = data.get('data_type_id', type=int) if hasattr(data, 'get') else int(data.get('data_type_id', 0)) + perturbation_configs_id = data.get('perturbation_configs_id', type=int) if hasattr(data, 'get') else int(data.get('perturbation_configs_id', 0)) + intensity = data.get('perturbation_intensity', type=float) if hasattr(data, 'get') else float(data.get('perturbation_intensity', 0)) + description = data.get('description') if not all([data_type_id, perturbation_configs_id, intensity]): return TaskService.json_error('缺少必要的任务参数') @@ -169,9 +176,17 @@ def create_perturbation_task(current_user_id): db.session.add(perturbation) db.session.commit() + # 自动上传图片 + files = request.files.getlist('files') if hasattr(request, 'files') else [] + target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id) + success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, []) + + # 创建任务成功后自动启动任务 + job_id = TaskService.start_perturbation_task(task.tasks_id) return jsonify({ - 'message': '加噪任务已创建', - 'task': TaskService.serialize_task(task) + 'message': '加噪任务已创建并已启动', + 'task': TaskService.serialize_task(task), + 'job_id': job_id }), 201 except Exception as exc: db.session.rollback() @@ -300,7 +315,13 @@ def create_heatmap_task(current_user_id): db.session.add(heatmap) db.session.commit() - return jsonify({'message': '热力图任务已创建', 'task': TaskService.serialize_task(task)}), 201 + # 创建任务成功后自动启动任务 + job_id = TaskService.start_heatmap_task(task.tasks_id) + return jsonify({ + 'message': '热力图任务已创建并已启动', + 'task': TaskService.serialize_task(task), + 'job_id': job_id + }), 201 except Exception as exc: db.session.rollback() return TaskService.json_error(f'创建热力图任务失败: {exc}', 500) @@ -402,7 +423,13 @@ def create_finetune_from_perturbation(current_user_id): db.session.add(finetune) db.session.commit() - return jsonify({'message': '微调任务已创建', 'task': TaskService.serialize_task(task)}), 201 + # 创建任务成功后自动启动任务 + job_id = TaskService.start_finetune_task(task.tasks_id) + return jsonify({ + 'message': '微调任务已创建并已启动', + 'task': TaskService.serialize_task(task), + 'job_id': job_id + }), 201 except Exception as exc: db.session.rollback() return TaskService.json_error(f'创建微调任务失败: {exc}', 500) @@ -419,10 +446,20 @@ def create_finetune_from_upload(current_user_id): if role_code not in ('vip', 'admin'): return TaskService.json_error('仅限VIP或管理员使用上传微调功能', 403) - data = request.get_json() or {} - finetune_configs_id = data.get('finetune_configs_id') + # 兼容 form-data 和 json,推荐 form-data + if request.content_type and request.content_type.startswith('multipart/form-data'): + data = request.form + else: + data = request.get_json() or {} + + finetune_configs_id = data.get('finetune_configs_id', type=int) if hasattr(data, 'get') else int(data.get('finetune_configs_id', 0)) + data_type_id = data.get('data_type_id', type=int) if hasattr(data, 'get') else int(data.get('data_type_id', 0)) + description = data.get('description') + if not finetune_configs_id: return TaskService.json_error('缺少必要参数: finetune_configs_id') + if not data_type_id: + return TaskService.json_error('缺少必要参数: data_type_id') finetune_config = FinetuneConfig.query.get(finetune_configs_id) if not finetune_config: @@ -473,9 +510,18 @@ def create_finetune_from_upload(current_user_id): db.session.add(finetune) db.session.commit() + # 自动上传图片(仅上传微调任务) + files = request.files.getlist('files') if hasattr(request, 'files') else [] + target_dir = TaskService.get_original_images_path(current_user_id, task.flow_id) + success, result = ImageService.save_original_images(task, files, target_dir) if files else (True, []) + + # 自动启动任务 + job_id = TaskService.start_finetune_task(task.tasks_id) return jsonify({ - 'message': '上传微调任务已创建', - 'task': TaskService.serialize_task(task) + 'message': '上传微调任务已创建并已启动', + 'task': TaskService.serialize_task(task), + 'images': [ImageService.serialize_image(img) for img in result], + 'job_id': job_id }), 201 except Exception as exc: db.session.rollback() @@ -590,7 +636,13 @@ def create_evaluate_task(current_user_id): db.session.add(evaluate) db.session.commit() - return jsonify({'message': '评估任务已创建', 'task': TaskService.serialize_task(task)}), 201 + # 创建任务成功后自动启动任务 + job_id = TaskService.start_evaluate_task(task.tasks_id) + return jsonify({ + 'message': '评估任务已创建并已启动', + 'task': TaskService.serialize_task(task), + 'job_id': job_id + }), 201 except Exception as exc: db.session.rollback() return TaskService.json_error(f'创建评估任务失败: {exc}', 500)