From ae261a6dc96df3bbf15880d0cb2dfd689936d478 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Mon, 1 Dec 2025 09:25:51 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84controllers?= =?UTF-8?q?=E5=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/controllers/auth_controller.py | 1 - .../app/controllers/image_controller.py | 331 ++--- .../app/controllers/task_controller.py | 1141 ++++++++--------- .../app/controllers/user_controller.py | 248 ++-- src/backend/app/services/image_service.py | 177 ++- src/backend/app/services/task_service.py | 133 +- 6 files changed, 1087 insertions(+), 944 deletions(-) diff --git a/src/backend/app/controllers/auth_controller.py b/src/backend/app/controllers/auth_controller.py index 751a2a5..bd93886 100644 --- a/src/backend/app/controllers/auth_controller.py +++ b/src/backend/app/controllers/auth_controller.py @@ -7,7 +7,6 @@ from flask import Blueprint, request, jsonify from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity from app import db from app.database import User, UserConfig -from app.services.auth_service import AuthService from functools import wraps import re diff --git a/src/backend/app/controllers/image_controller.py b/src/backend/app/controllers/image_controller.py index d52a4a9..08aef56 100644 --- a/src/backend/app/controllers/image_controller.py +++ b/src/backend/app/controllers/image_controller.py @@ -1,203 +1,128 @@ -""" -图像管理控制器 -处理图像下载、查看等功能 -""" - -from flask import Blueprint, send_file, jsonify, request, current_app -from flask_jwt_extended import jwt_required, get_jwt_identity -from app.database import Image, EvaluationResult -from app.services.image_service import ImageService -import os - -image_bp = Blueprint('image', __name__) - -@image_bp.route('/file/', methods=['GET']) -@jwt_required() -def get_image_file(image_id): - """获取图片文件""" - try: - current_user_id = get_jwt_identity() - - # 查找图片记录 - image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() - if not image: - return jsonify({'error': '图片不存在或无权限'}), 404 - - # 检查文件是否存在 - if not os.path.exists(image.file_path): - return jsonify({'error': '图片文件不存在'}), 404 - - return send_file(image.file_path, as_attachment=False) - - except Exception as e: - return jsonify({'error': f'获取图片失败: {str(e)}'}), 500 - -@image_bp.route('/download/', methods=['GET']) -@jwt_required() -def download_image(image_id): - """下载图片文件""" - try: - current_user_id = get_jwt_identity() - - image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() - if not image: - return jsonify({'error': '图片不存在或无权限'}), 404 - - if not os.path.exists(image.file_path): - return jsonify({'error': '图片文件不存在'}), 404 - - return send_file( - image.file_path, - as_attachment=True, - download_name=image.original_filename or f"image_{image_id}.jpg" - ) - - except Exception as e: - return jsonify({'error': f'下载图片失败: {str(e)}'}), 500 - -@image_bp.route('/batch//download', methods=['GET']) -@jwt_required() -def download_batch_images(batch_id): - """批量下载任务中的加噪后图片""" - try: - current_user_id = get_jwt_identity() - - # 获取任务中的加噪图片 - perturbed_images = Image.query.join(Image.image_type).filter( - Image.batch_id == batch_id, - Image.user_id == current_user_id, - Image.image_type.has(type_code='perturbed') - ).all() - - if not perturbed_images: - return jsonify({'error': '没有找到加噪后的图片'}), 404 - - # 创建ZIP文件 - import zipfile - import tempfile - - with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file: - with zipfile.ZipFile(tmp_file.name, 'w') as zip_file: - for image in perturbed_images: - if os.path.exists(image.file_path): - arcname = image.original_filename or f"perturbed_{image.id}.jpg" - zip_file.write(image.file_path, arcname) - - return send_file( - tmp_file.name, - as_attachment=True, - download_name=f"batch_{batch_id}_perturbed_images.zip", - mimetype='application/zip' - ) - - except Exception as e: - return jsonify({'error': f'批量下载失败: {str(e)}'}), 500 - -@image_bp.route('//evaluations', methods=['GET']) -@jwt_required() -def get_image_evaluations(image_id): - """获取图片的评估结果""" - try: - current_user_id = get_jwt_identity() - - # 验证图片权限 - image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() - if not image: - return jsonify({'error': '图片不存在或无权限'}), 404 - - # 获取以该图片为参考或目标的评估结果 - evaluations = EvaluationResult.query.filter( - (EvaluationResult.reference_image_id == image_id) | - (EvaluationResult.target_image_id == image_id) - ).all() - - return jsonify({ - 'image_id': image_id, - 'evaluations': [eval_result.to_dict() for eval_result in evaluations] - }), 200 - - except Exception as e: - return jsonify({'error': f'获取评估结果失败: {str(e)}'}), 500 - -@image_bp.route('/compare', methods=['POST']) -@jwt_required() -def compare_images(): - """对比两张图片""" - try: - current_user_id = get_jwt_identity() - data = request.get_json() - - image1_id = data.get('image1_id') - image2_id = data.get('image2_id') - - if not image1_id or not image2_id: - return jsonify({'error': '请提供两张图片的ID'}), 400 - - # 验证图片权限 - image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first() - image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first() - - if not image1 or not image2: - return jsonify({'error': '图片不存在或无权限'}), 404 - - # 查找现有的评估结果 - evaluation = EvaluationResult.query.filter_by( - reference_image_id=image1_id, - target_image_id=image2_id - ).first() - - if not evaluation: - # 如果没有评估结果,返回基本对比信息 - return jsonify({ - 'image1': image1.to_dict(), - 'image2': image2.to_dict(), - 'evaluation': None, - 'message': '暂无评估数据,请等待任务处理完成' - }), 200 - - return jsonify({ - 'image1': image1.to_dict(), - 'image2': image2.to_dict(), - 'evaluation': evaluation.to_dict() - }), 200 - - except Exception as e: - return jsonify({'error': f'图片对比失败: {str(e)}'}), 500 - -@image_bp.route('/heatmap/', methods=['GET']) -@jwt_required() -def get_heatmap(heatmap_path): - """获取热力图文件""" - try: - # 安全检查,防止路径遍历攻击 - if '..' in heatmap_path or heatmap_path.startswith('/'): - return jsonify({'error': '无效的文件路径'}), 400 - - # 修正路径构建 - 获取项目根目录(backend目录) - project_root = os.path.dirname(current_app.root_path) - full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path)) - - if not os.path.exists(full_path): - return jsonify({'error': '热力图文件不存在'}), 404 - - return send_file(full_path, as_attachment=False) - - except Exception as e: - return jsonify({'error': f'获取热力图失败: {str(e)}'}), 500 - -@image_bp.route('/delete/', methods=['DELETE']) -@jwt_required() -def delete_image(image_id): - """删除图片""" - try: - current_user_id = get_jwt_identity() - - result = ImageService.delete_image(image_id, current_user_id) - - if result['success']: - return jsonify({'message': '图片删除成功'}), 200 - else: - return jsonify({'error': result['error']}), 400 - - except Exception as e: - return jsonify({'error': f'删除图片失败: {str(e)}'}), 500 \ No newline at end of file + +""" +图像管理控制器 +负责图片上传、下载等操作 +""" + +from flask import Blueprint, request, jsonify, send_file +from app.controllers.auth_controller import int_jwt_required +from app.services.task_service import TaskService +from app.services.image_service import ImageService + + +image_bp = Blueprint('image', __name__) + + +# ==================== 图片上传 ==================== + +@image_bp.route('/original', methods=['POST']) +@int_jwt_required +def upload_original_images(current_user_id): + task_id = request.form.get('task_id', type=int) + if not task_id: + return ImageService.json_error('缺少 task_id 参数') + + task = TaskService.load_task_for_user(task_id, current_user_id) + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + task_type = TaskService.get_task_type_code(task) + if task_type not in {'perturbation', 'finetune'}: + return ImageService.json_error('任务类型不支持图片上传', 400) + + files = request.files.getlist('files') + target_dir = TaskService.get_original_images_path(task.user_id, task.flow_id) + success, result = ImageService.save_original_images(task, files, target_dir) + if not success: + status_code = 400 + if isinstance(result, str) and (result.startswith('未配置图片类型') or '失败' in result): + status_code = 500 + return ImageService.json_error(result, status_code) + + return jsonify({ + 'message': '图片上传成功', + 'images': [ImageService.serialize_image(img) for img in result], + 'flow_id': task.flow_id + }), 201 + + +# ==================== 结果下载 ==================== + +@image_bp.route('/perturbation//download', methods=['GET']) +@int_jwt_required +def download_perturbation_result(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + directory = TaskService.get_perturbed_images_path(task.user_id, task.flow_id) + zipped, has_files = ImageService.zip_directory(directory) + if not has_files: + return ImageService.json_error('结果文件不存在', 404) + + filename = f"perturbation_{task_id}.zip" + return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + + +@image_bp.route('/heatmap//download', methods=['GET']) +@int_jwt_required +def download_heatmap_result(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap') + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + directory = TaskService.get_heatmap_path(task.user_id, task.flow_id, task.tasks_id) + zipped, has_files = ImageService.zip_directory(directory) + if not has_files: + return ImageService.json_error('热力图文件不存在', 404) + + filename = f"heatmap_{task_id}.zip" + return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + + +@image_bp.route('/finetune//download', methods=['GET']) +@int_jwt_required +def download_finetune_result(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune') + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + if not task.finetune: + return ImageService.json_error('微调任务配置不存在', 404) + + try: + source = TaskService.determine_finetune_source(task) + except ValueError as exc: + return ImageService.json_error(str(exc), 500) + if source == 'perturbation': + directories = { + 'original_generate': TaskService.get_original_generated_path(task.user_id, task.flow_id, task.tasks_id), + 'perturbed_generate': TaskService.get_perturbed_generated_path(task.user_id, task.flow_id, task.tasks_id) + } + else: + directories = { + 'uploaded_generate': TaskService.get_uploaded_generated_path(task.user_id, task.flow_id, task.tasks_id) + } + + zipped, has_files = ImageService.zip_multiple_directories(directories) + if not has_files: + return ImageService.json_error('微调结果文件不存在', 404) + + filename = f"finetune_{task_id}.zip" + return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + + +@image_bp.route('/evaluate//download', methods=['GET']) +@int_jwt_required +def download_evaluate_result(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate') + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + directory = TaskService.get_evaluate_path(task.user_id, task.flow_id, task.tasks_id) + zipped, has_files = ImageService.zip_directory(directory) + if not has_files: + return ImageService.json_error('评估结果文件不存在', 404) + + filename = f"evaluate_{task_id}.zip" + return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index f49e668..2fbe51b 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -1,606 +1,535 @@ -""" -任务管理控制器 -处理创建任务、上传图片等功能 -""" - -from flask import Blueprint, request, jsonify, current_app -from flask_jwt_extended import jwt_required, get_jwt_identity -from werkzeug.utils import secure_filename -from app import db -from app.database import User, Role, PerturbationConfig, FinetuneConfig, UserConfig, Image, ImageType, DataType, TaskType, TaskStatus, Task, Perturbation, Finetune, EvaluationResult, Evaluate, Heatmap -from app.services.task_service import TaskService -from app.services.image_service import ImageService -from app.utils.file_utils import allowed_file, save_uploaded_file -import os -import zipfile -import uuid - -task_bp = Blueprint('task', __name__) - -@task_bp.route('/create', methods=['POST']) -@jwt_required() -def create_task(): - """创建新任务(使用用户配置作为默认配置)""" - try: - current_user_id = get_jwt_identity() - user = User.query.get(current_user_id) - - if not user: - return jsonify({'error': '用户不存在'}), 404 - - - data = request.get_json() - batch_name = data.get('batch_name', f'Task-{uuid.uuid4().hex[:8]}') - - # 优先使用前端传来的参数,没有则用用户配置,没有再用默认 - perturbation_config_id = data.get('perturbation_config_id') - preferred_epsilon = data.get('epsilon') - use_strong_protection = data.get('use_strong_protection') - - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - if user_config: - if perturbation_config_id is None: - perturbation_config_id = user_config.preferred_perturbation_config_id or 1 - if preferred_epsilon is None: - preferred_epsilon = user_config.preferred_epsilon or 8.0 - if use_strong_protection is None: - use_strong_protection = user_config.preferred_purification or False - else: - perturbation_config_id = perturbation_config_id or 1 - preferred_epsilon = preferred_epsilon or 8.0 - use_strong_protection = use_strong_protection if use_strong_protection is not None else False - - # 类型转换,防止前端传字符串 - try: - perturbation_config_id = int(perturbation_config_id) - except Exception: - perturbation_config_id = 1 - try: - preferred_epsilon = float(preferred_epsilon) - except Exception: - preferred_epsilon = 8.0 - use_strong_protection = bool(use_strong_protection) - - # 创建任务(只包含扰动相关配置,不包含微调配置) - batch = Batch( - user_id=current_user_id, - batch_name=batch_name, - perturbation_config_id=perturbation_config_id, - preferred_epsilon=preferred_epsilon, - use_strong_protection=use_strong_protection - ) - - db.session.add(batch) - db.session.commit() - - # 自动创建关联的微调任务(如果用户有默认微调配置则自动设置) - finetune_config_id = None - if user_config and user_config.preferred_finetune_config_id: - finetune_config_id = user_config.preferred_finetune_config_id - - finetune_batch = FinetuneBatch( - batch_id=batch.id, - user_id=current_user_id, - finetune_config_id=finetune_config_id, - status='pending' - ) - db.session.add(finetune_batch) - db.session.commit() - - return jsonify({ - 'message': '任务创建成功,请上传图片', - 'task': batch.to_dict(), - 'finetune_task_id': finetune_batch.id, - 'finetune_config_set': finetune_config_id is not None - }), 201 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'任务创建失败: {str(e)}'}), 500 - -@task_bp.route('/upload/', methods=['POST']) -@jwt_required() -def upload_images(batch_id): - """上传图片到指定任务""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - if batch.status != 'pending': - return jsonify({'error': '任务已开始处理,无法上传新图片'}), 400 - - if 'files' not in request.files: - return jsonify({'error': '没有选择文件'}), 400 - - files = request.files.getlist('files') - uploaded_files = [] - - # 获取原始图片类型ID - original_type = ImageType.query.filter_by(type_code='original').first() - if not original_type: - return jsonify({'error': '系统配置错误:缺少原始图片类型'}), 500 - - for file in files: - if file.filename == '': - continue - if file and allowed_file(file.filename): - # 处理单张图片 - if not file.filename.lower().endswith(('.zip', '.rar')): - # 统一走save_image,内部已实现上传到uploads和预处理 - result = ImageService.save_image(file, batch_id, current_user_id, original_type.id) - if result['success']: - uploaded_files.append(result['image']) - else: - return jsonify({'error': result['error']}), 400 - else: - # 压缩包内图片也会走save_image - results = ImageService.extract_and_save_zip(file, batch_id, current_user_id, original_type.id) - for result in results: - if result['success']: - uploaded_files.append(result['image']) - - if not uploaded_files: - return jsonify({'error': '没有有效的图片文件'}), 400 - - return jsonify({ - 'message': f'成功上传 {len(uploaded_files)} 张图片', - 'uploaded_files': [img.to_dict() for img in uploaded_files] - }), 200 - - except Exception as e: - return jsonify({'error': f'文件上传失败: {str(e)}'}), 500 - -@task_bp.route('//config', methods=['GET']) -@jwt_required() -def get_task_config(batch_id): - """获取任务配置(显示用户上次的配置或默认配置)""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - # 获取用户配置 - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - # 如果用户有配置,显示用户上次的配置;否则显示当前任务的默认配置 - if user_config: - suggested_config = { - 'perturbation_config_id': user_config.preferred_perturbation_config_id, - 'epsilon': float(user_config.preferred_epsilon), - 'use_strong_protection': user_config.preferred_purification - } - else: - suggested_config = { - 'perturbation_config_id': batch.perturbation_config_id, - 'epsilon': float(batch.preferred_epsilon), - 'use_strong_protection': batch.use_strong_protection - } - - return jsonify({ - 'task': batch.to_dict(), - 'suggested_config': suggested_config, - 'current_config': { - 'perturbation_config_id': batch.perturbation_config_id, - 'epsilon': float(batch.preferred_epsilon), - 'use_strong_protection': batch.use_strong_protection - } - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务配置失败: {str(e)}'}), 500 - -@task_bp.route('//config', methods=['PUT']) -@jwt_required() -def update_task_config(batch_id): - """更新任务配置(仅更新任务本身,不影响用户配置)""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - if batch.status != 'pending': - return jsonify({'error': '任务已开始处理,无法修改配置'}), 400 - - data = request.get_json() - - # 更新任务配置(仅扰动相关) - if 'perturbation_config_id' in data: - batch.perturbation_config_id = data['perturbation_config_id'] - - if 'epsilon' in data: - epsilon = float(data['epsilon']) - if 0 < epsilon <= 255: - batch.preferred_epsilon = epsilon - else: - return jsonify({'error': '扰动强度必须在0-255之间'}), 400 - - if 'use_strong_protection' in data: - batch.use_strong_protection = bool(data['use_strong_protection']) - - db.session.commit() - - return jsonify({ - 'message': '任务配置更新成功', - 'task': batch.to_dict() - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'更新任务配置失败: {str(e)}'}), 500 - -@task_bp.route('/start/', methods=['POST']) -@jwt_required() -def start_task(batch_id): - """开始处理任务""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - if batch.status not in ['pending', 'failed', 'canceled']: - return jsonify({'error': '任务状态不正确,无法开始处理'}), 400 - # 如果是失败或取消,重置状态为pending - if batch.status in ['failed', 'canceled']: - batch.status = 'pending' - batch.error_message = None - db.session.commit() - - # 检查是否有上传的图片 - image_count = Image.query.filter_by(batch_id=batch_id).count() - if image_count == 0: - return jsonify({'error': '请先上传图片'}), 400 - - # 启动任务处理 - success = TaskService.start_processing(batch) - - if success: - return jsonify({ - 'message': '任务开始处理', - 'task': batch.to_dict() - }), 200 - else: - return jsonify({'error': '任务启动失败'}), 500 - - except Exception as e: - return jsonify({'error': f'任务启动失败: {str(e)}'}), 500 - -@task_bp.route('/list', methods=['GET']) -@jwt_required() -def list_tasks(): - """获取用户的任务列表""" - try: - current_user_id = get_jwt_identity() - - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 10, type=int) - - batches = Batch.query.filter_by(user_id=current_user_id)\ - .order_by(Batch.created_at.desc())\ - .paginate(page=page, per_page=per_page, error_out=False) - - return jsonify({ - 'tasks': [batch.to_dict() for batch in batches.items], - 'total': batches.total, - 'pages': batches.pages, - 'current_page': page - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务列表失败: {str(e)}'}), 500 - -@task_bp.route('/', methods=['GET']) -@jwt_required() -def get_task_detail(batch_id): - """获取任务详情""" - try: - current_user_id = get_jwt_identity() - - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - # 获取任务相关的图片 - images = Image.query.filter_by(batch_id=batch_id).all() - - return jsonify({ - 'task': batch.to_dict(), - 'images': [img.to_dict() for img in images], - 'image_count': len(images) - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务详情失败: {str(e)}'}), 500 - -@task_bp.route('//status', methods=['GET']) -@jwt_required() -def get_task_status(batch_id): - """获取任务处理状态""" - try: - current_user_id = get_jwt_identity() - - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - return jsonify({ - 'task_id': batch_id, - 'status': batch.status, - 'progress': TaskService.get_processing_progress(batch_id), - 'error_message': batch.error_message - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务状态失败: {str(e)}'}), 500 - -# ==================== 微调任务管理接口 ==================== - -@task_bp.route('/finetune/configs', methods=['GET']) -@jwt_required() -def get_finetune_configs(): - """获取所有可用的微调配置""" - try: - configs = FinetuneConfig.query.all() - return jsonify({ - 'configs': [{ - 'id': config.id, - 'method_code': config.method_code, - 'method_name': config.method_name, - 'description': config.description - } for config in configs] - }), 200 - - except Exception as e: - return jsonify({'error': f'获取微调配置失败: {str(e)}'}), 500 - -@task_bp.route('/finetune/list', methods=['GET']) -@jwt_required() -def list_finetune_tasks(): - """获取用户的所有微调任务列表""" - try: - current_user_id = get_jwt_identity() - - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 10, type=int) - - finetune_tasks = FinetuneBatch.query.filter_by(user_id=current_user_id)\ - .order_by(FinetuneBatch.created_at.desc())\ - .paginate(page=page, per_page=per_page, error_out=False) - - results = [] - for ft in finetune_tasks.items: - ft_dict = ft.to_dict() - # 添加关联的扰动任务信息 - ft_dict['batch_info'] = ft.batch.to_dict() if ft.batch else None - results.append(ft_dict) - - return jsonify({ - 'finetune_tasks': results, - 'total': finetune_tasks.total, - 'pages': finetune_tasks.pages, - 'current_page': page - }), 200 - - except Exception as e: - return jsonify({'error': f'获取微调任务列表失败: {str(e)}'}), 500 - -@task_bp.route('/finetune/', methods=['GET']) -@jwt_required() -def get_finetune_task(finetune_id): - """获取微调任务详情""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - result = finetune_task.to_dict() - result['batch_info'] = finetune_task.batch.to_dict() if finetune_task.batch else None - - return jsonify({'finetune_task': result}), 200 - - except Exception as e: - return jsonify({'error': f'获取微调任务详情失败: {str(e)}'}), 500 - -@task_bp.route('/finetune/by-batch/', methods=['GET']) -@jwt_required() -def get_finetune_by_batch(batch_id): - """根据扰动任务ID获取关联的微调任务""" - try: - current_user_id = get_jwt_identity() - - # 验证扰动任务权限 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '扰动任务不存在或无权限'}), 404 - - finetune_task = FinetuneBatch.query.filter_by(batch_id=batch_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '该扰动任务没有关联的微调任务'}), 404 - - result = finetune_task.to_dict() - result['batch_info'] = batch.to_dict() - - return jsonify({'finetune_task': result}), 200 - - except Exception as e: - return jsonify({'error': f'获取微调任务失败: {str(e)}'}), 500 - -@task_bp.route('/finetune//config', methods=['GET']) -@jwt_required() -def get_finetune_config(finetune_id): - """获取微调任务配置(显示用户默认配置或当前配置)""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - # 获取用户配置 - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - # 如果用户有配置,显示用户默认配置;否则显示系统默认 - if user_config and user_config.preferred_finetune_config_id: - suggested_config = { - 'finetune_config_id': user_config.preferred_finetune_config_id, - 'finetune_config_name': user_config.preferred_finetune_config.method_name if user_config.preferred_finetune_config else None - } - else: - # 默认使用第一个微调配置 - default_config = FinetuneConfig.query.first() - suggested_config = { - 'finetune_config_id': default_config.id if default_config else 1, - 'finetune_config_name': default_config.method_name if default_config else None - } - - # 当前微调任务的配置 - current_config = { - 'finetune_config_id': finetune_task.finetune_config_id, - 'finetune_config_name': finetune_task.finetune_config.method_name if finetune_task.finetune_config else None - } - - return jsonify({ - 'finetune_task': finetune_task.to_dict(), - 'suggested_config': suggested_config, - 'current_config': current_config - }), 200 - - except Exception as e: - return jsonify({'error': f'获取微调配置失败: {str(e)}'}), 500 - -@task_bp.route('/finetune//config', methods=['PUT']) -@jwt_required() -def update_finetune_config(finetune_id): - """更新微调任务配置(仅限 pending 状态)""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - if finetune_task.status != 'pending': - return jsonify({'error': '只能修改待处理状态的微调任务配置'}), 400 - - data = request.get_json() - finetune_config_id = data.get('finetune_config_id') - - if not finetune_config_id: - return jsonify({'error': '请提供微调方法ID'}), 400 - - # 验证微调配置是否存在 - finetune_config = FinetuneConfig.query.get(finetune_config_id) - if not finetune_config: - return jsonify({'error': '微调配置不存在'}), 404 - - finetune_task.finetune_config_id = finetune_config_id - db.session.commit() - - return jsonify({ - 'message': '微调配置更新成功', - 'finetune_task': finetune_task.to_dict() - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'更新微调配置失败: {str(e)}'}), 500 - -@task_bp.route('/finetune//start', methods=['POST']) -@jwt_required() -def start_finetune(finetune_id): - """启动微调任务""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - # 检查扰动任务是否已完成 - if finetune_task.batch.status != 'completed': - return jsonify({'error': '扰动任务尚未完成,无法开始微调'}), 400 - - # 检查是否已设置微调配置 - if not finetune_task.finetune_config_id: - return jsonify({'error': '请先设置微调方法'}), 400 - - # 检查状态 - if finetune_task.status not in ['pending', 'failed']: - return jsonify({'error': f'微调任务状态为 {finetune_task.status},无法启动'}), 400 - - # 启动微调任务 - job_ids = TaskService.start_finetune_task(finetune_task) - - if job_ids: - return jsonify({ - 'message': '微调任务已启动', - 'finetune_task_id': finetune_id, - 'job_ids': job_ids - }), 200 - else: - return jsonify({'error': '微调任务启动失败'}), 500 - - except Exception as e: - return jsonify({'error': f'启动微调任务失败: {str(e)}'}), 500 - -@task_bp.route('/finetune//status', methods=['GET']) -@jwt_required() -def get_finetune_task_status(finetune_id): - """获取微调任务状态""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - # 获取详细状态 - status_info = TaskService.get_finetune_task_status(finetune_id) - - return jsonify({ - 'finetune_task_id': finetune_id, - 'status': finetune_task.status, - 'finetune_config': finetune_task.finetune_config.to_dict() if finetune_task.finetune_config else None, - 'details': status_info, - 'error_message': finetune_task.error_message - }), 200 - - except Exception as e: - return jsonify({'error': f'获取微调任务状态失败: {str(e)}'}), 500 - -@task_bp.route('/finetune/', methods=['DELETE']) -@jwt_required() -def delete_finetune_task(finetune_id): - """删除微调任务(仅限 pending 或 failed 状态)""" - try: - current_user_id = get_jwt_identity() - - finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first() - if not finetune_task: - return jsonify({'error': '微调任务不存在或无权限'}), 404 - - if finetune_task.status not in ['pending', 'failed']: - return jsonify({'error': '只能删除待处理或失败状态的微调任务'}), 400 - - db.session.delete(finetune_task) - db.session.commit() - - return jsonify({'message': '微调任务删除成功'}), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'删除微调任务失败: {str(e)}'}), 500 + +""" +任务管理控制器 +适配新数据库结构,提供加噪、微调、热力图、数值评估等任务相关接口 +""" + +from flask import Blueprint, request, jsonify +from app import db +from app.controllers.auth_controller import int_jwt_required +from app.database import ( + Task, + Perturbation, Finetune, Heatmap, Evaluate, + PerturbationConfig, FinetuneConfig, DataType, + Image +) +from app.services.task_service import TaskService + + +task_bp = Blueprint('task', __name__) + + +# ==================== 通用任务接口 ==================== + +@task_bp.route('//status', methods=['GET']) +@int_jwt_required +def get_task_status(task_id, current_user_id): + task = Task.query.get(task_id) + if not TaskService.ensure_task_owner(task, current_user_id): + return TaskService.json_error('任务不存在或无权限', 404) + status = TaskService.get_task_status(task_id) + return jsonify(status), 200 + + +@task_bp.route('//cancel', methods=['POST']) +@int_jwt_required +def cancel_task(task_id, current_user_id): + task = Task.query.get(task_id) + if not TaskService.ensure_task_owner(task, current_user_id): + return TaskService.json_error('任务不存在或无权限', 404) + if TaskService.cancel_task(task_id): + return jsonify({'message': '任务已取消'}), 200 + return TaskService.json_error('取消任务失败', 500) + + +# ==================== 加噪任务 ==================== + +@task_bp.route('/perturbation/configs', methods=['GET']) +@int_jwt_required +def list_perturbation_configs(current_user_id): + configs = PerturbationConfig.query.order_by(PerturbationConfig.perturbation_configs_id).all() + return jsonify({'configs': [ + { + 'perturbation_configs_id': cfg.perturbation_configs_id, + 'perturbation_code': cfg.perturbation_code, + 'perturbation_name': cfg.perturbation_name, + 'description': cfg.description, + } + for cfg in configs + ]}), 200 + + +@task_bp.route('/perturbation', methods=['POST']) +@int_jwt_required +def create_perturbation_task(current_user_id): + data = request.get_json() or {} + data_type_id = data.get('data_type_id') + perturbation_configs_id = data.get('perturbation_configs_id') + intensity = data.get('perturbation_intensity') + + if not all([data_type_id, perturbation_configs_id, intensity]): + return TaskService.json_error('缺少必要的任务参数') + + if not DataType.query.get(data_type_id): + return TaskService.json_error('数据集类型不存在') + if not PerturbationConfig.query.get(perturbation_configs_id): + return TaskService.json_error('加噪配置不存在') + + try: + flow_id = data.get('flow_id') + flow_id = int(flow_id) if flow_id is not None else TaskService.generate_flow_id() + except Exception: + return TaskService.json_error('非法的 flow_id 参数') + + try: + pending_status = TaskService.ensure_status('pending') + perturb_type = TaskService.require_task_type('perturbation') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=flow_id, + tasks_type_id=perturb_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + perturbation = Perturbation( + tasks_id=task.tasks_id, + data_type_id=data_type_id, + perturbation_configs_id=perturbation_configs_id, + perturbation_intensity=float(intensity), + perturbation_name=data.get('perturbation_name') + ) + db.session.add(perturbation) + db.session.commit() + + return jsonify({ + 'message': '加噪任务已创建', + 'task': TaskService.serialize_task(task) + }), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建任务失败: {exc}', 500) + + +@task_bp.route('/perturbation/', methods=['PATCH']) +@int_jwt_required +def update_perturbation_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + data = request.get_json() or {} + pert = task.perturbation + if not pert: + return TaskService.json_error('任务配置不存在', 404) + + if 'data_type_id' in data: + if not DataType.query.get(data['data_type_id']): + return TaskService.json_error('数据集类型不存在') + pert.data_type_id = data['data_type_id'] + if 'perturbation_configs_id' in data: + if not PerturbationConfig.query.get(data['perturbation_configs_id']): + return TaskService.json_error('加噪配置不存在') + pert.perturbation_configs_id = data['perturbation_configs_id'] + if 'perturbation_intensity' in data: + pert.perturbation_intensity = float(data['perturbation_intensity']) + if 'perturbation_name' in data: + pert.perturbation_name = data['perturbation_name'] + if 'description' in data: + task.description = data['description'] + + try: + db.session.commit() + return jsonify({'message': '任务已更新', 'task': TaskService.serialize_task(task)}), 200 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'更新任务失败: {exc}', 500) + + +@task_bp.route('/perturbation//start', methods=['POST']) +@int_jwt_required +def start_perturbation_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + job_id = TaskService.start_perturbation_task(task_id) + if not job_id: + return TaskService.json_error('任务启动失败', 500) + return jsonify({'message': '任务已启动', 'job_id': job_id}), 200 + + +@task_bp.route('/perturbation', methods=['GET']) +@int_jwt_required +def list_perturbation_tasks(current_user_id): + try: + perturb_type = TaskService.require_task_type('perturbation') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + tasks = Task.query.filter_by(user_id=current_user_id, tasks_type_id=perturb_type.task_type_id).order_by(Task.created_at.desc()).all() + return jsonify({'tasks': [TaskService.serialize_task(task) for task in tasks]}), 200 + + +@task_bp.route('/perturbation/', methods=['GET']) +@int_jwt_required +def get_perturbation_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + return jsonify({'task': TaskService.serialize_task(task)}), 200 + + +# ==================== 热力图任务 ==================== + +@task_bp.route('/heatmap', methods=['POST']) +@int_jwt_required +def create_heatmap_task(current_user_id): + data = request.get_json() or {} + perturbation_task_id = data.get('perturbation_task_id') + perturbed_image_id = data.get('perturbed_image_id') + + if not perturbation_task_id or not perturbed_image_id: + return TaskService.json_error('缺少必要参数: perturbation_task_id 或 perturbed_image_id') + + perturbation_task = TaskService.load_task_for_user(perturbation_task_id, current_user_id, expected_type='perturbation') + if not perturbation_task: + return TaskService.json_error('加噪任务不存在或无权限', 404) + + status_code = perturbation_task.task_status.task_status_code if perturbation_task.task_status else None + if status_code != 'completed': + return TaskService.json_error('仅支持已完成的加噪任务创建热力图') + + perturbed_image = Image.query.get(perturbed_image_id) + if not perturbed_image or perturbed_image.task_id != perturbation_task_id: + return TaskService.json_error('扰动图片不存在或不属于该任务') + + try: + heatmap_type = TaskService.require_task_type('heatmap') + pending_status = TaskService.ensure_status('pending') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=perturbation_task.flow_id, + tasks_type_id=heatmap_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + heatmap = Heatmap( + tasks_id=task.tasks_id, + images_id=perturbed_image_id, + heatmap_name=data.get('heatmap_name') + ) + db.session.add(heatmap) + db.session.commit() + + return jsonify({'message': '热力图任务已创建', 'task': TaskService.serialize_task(task)}), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建热力图任务失败: {exc}', 500) + + +@task_bp.route('/heatmap//start', methods=['POST']) +@int_jwt_required +def start_heatmap_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + if not task.heatmap: + return TaskService.json_error('热力图任务未配置对应图片', 400) + + job_id = TaskService.start_heatmap_task(task_id, task.heatmap.images_id) + if not job_id: + return TaskService.json_error('任务启动失败', 500) + return jsonify({'message': '任务已启动', 'job_id': job_id}), 200 + + +@task_bp.route('/heatmap', methods=['GET']) +@int_jwt_required +def list_heatmap_tasks(current_user_id): + try: + heatmap_type = TaskService.require_task_type('heatmap') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + tasks = Task.query.filter_by(user_id=current_user_id, tasks_type_id=heatmap_type.task_type_id).order_by(Task.created_at.desc()).all() + return jsonify({'tasks': [TaskService.serialize_task(task) for task in tasks]}), 200 + + +@task_bp.route('/heatmap/', methods=['GET']) +@int_jwt_required +def get_heatmap_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + return jsonify({'task': TaskService.serialize_task(task)}), 200 + + +# ==================== 微调任务 ==================== + +@task_bp.route('/finetune/configs', methods=['GET']) +@int_jwt_required +def list_finetune_configs(current_user_id): + configs = FinetuneConfig.query.order_by(FinetuneConfig.finetune_configs_id).all() + return jsonify({'configs': [ + { + 'finetune_configs_id': cfg.finetune_configs_id, + 'finetune_code': cfg.finetune_code, + 'finetune_name': cfg.finetune_name, + 'description': cfg.description, + } + for cfg in configs + ]}), 200 + + +@task_bp.route('/finetune/from-perturbation', methods=['POST']) +@int_jwt_required +def create_finetune_from_perturbation(current_user_id): + data = request.get_json() or {} + perturbation_task_id = data.get('perturbation_task_id') + finetune_configs_id = data.get('finetune_configs_id') + + if not perturbation_task_id or not finetune_configs_id: + return TaskService.json_error('缺少必要参数: perturbation_task_id 或 finetune_configs_id') + + perturbation_task = TaskService.load_task_for_user(perturbation_task_id, current_user_id, expected_type='perturbation') + if not perturbation_task: + return TaskService.json_error('加噪任务不存在或无权限', 404) + + if not FinetuneConfig.query.get(finetune_configs_id): + return TaskService.json_error('微调配置不存在') + + try: + pending_status = TaskService.ensure_status('pending') + finetune_type = TaskService.require_task_type('finetune') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=perturbation_task.flow_id, + tasks_type_id=finetune_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + finetune = Finetune( + tasks_id=task.tasks_id, + finetune_configs_id=finetune_configs_id, + data_type_id=data.get('data_type_id'), + finetune_name=data.get('finetune_name') + ) + db.session.add(finetune) + db.session.commit() + + return jsonify({'message': '微调任务已创建', 'task': TaskService.serialize_task(task)}), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建微调任务失败: {exc}', 500) + + +@task_bp.route('/finetune/from-upload', methods=['POST']) +@int_jwt_required +def create_finetune_from_upload(current_user_id): + user = TaskService.get_user(current_user_id) + if not user: + return TaskService.json_error('用户不存在', 404) + + role_code = user.role.role_code if user.role else 'user' + if role_code not in ('vip', 'admin'): + return TaskService.json_error('仅限VIP或管理员使用上传微调功能', 403) + + data = request.get_json() or {} + finetune_configs_id = data.get('finetune_configs_id') + if not finetune_configs_id: + return TaskService.json_error('缺少必要参数: finetune_configs_id') + + if not FinetuneConfig.query.get(finetune_configs_id): + return TaskService.json_error('微调配置不存在') + + try: + flow_id = data.get('flow_id') + if flow_id is not None: + flow_id = int(flow_id) + existing = Task.query.filter_by(flow_id=flow_id).first() + if existing: + return TaskService.json_error('flow_id 已被占用,请勿复用已有任务流') + else: + flow_id = TaskService.generate_flow_id() + except Exception: + return TaskService.json_error('非法的 flow_id 参数') + + try: + pending_status = TaskService.ensure_status('pending') + finetune_type = TaskService.require_task_type('finetune') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=flow_id, + tasks_type_id=finetune_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + finetune = Finetune( + tasks_id=task.tasks_id, + finetune_configs_id=finetune_configs_id, + data_type_id=data.get('data_type_id'), + finetune_name=data.get('finetune_name') + ) + db.session.add(finetune) + db.session.commit() + + return jsonify({ + 'message': '上传微调任务已创建', + 'task': TaskService.serialize_task(task) + }), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建微调任务失败: {exc}', 500) + + +@task_bp.route('/finetune//start', methods=['POST']) +@int_jwt_required +def start_finetune_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + job_id = TaskService.start_finetune_task(task_id) + if not job_id: + return TaskService.json_error('任务启动失败', 500) + return jsonify({'message': '任务已启动', 'job_id': job_id}), 200 + + +@task_bp.route('/finetune', methods=['GET']) +@int_jwt_required +def list_finetune_tasks(current_user_id): + source = request.args.get('source') + try: + finetune_type = TaskService.require_task_type('finetune') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + query = Task.query.filter_by(user_id=current_user_id, tasks_type_id=finetune_type.task_type_id) + + tasks = query.order_by(Task.created_at.desc()).all() + serialized = [] + for task in tasks: + task_dict = TaskService.serialize_task(task) + if source and task_dict.get('finetune', {}).get('source') != source: + continue + serialized.append(task_dict) + return jsonify({'tasks': serialized}), 200 + + +@task_bp.route('/finetune/', methods=['GET']) +@int_jwt_required +def get_finetune_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + return jsonify({'task': TaskService.serialize_task(task)}), 200 + + +# ==================== 数值评估任务 ==================== + +@task_bp.route('/evaluate', methods=['POST']) +@int_jwt_required +def create_evaluate_task(current_user_id): + data = request.get_json() or {} + finetune_task_id = data.get('finetune_task_id') + if not finetune_task_id: + return TaskService.json_error('缺少必要参数: finetune_task_id') + + finetune_task = TaskService.load_task_for_user(finetune_task_id, current_user_id, expected_type='finetune') + if not finetune_task: + return TaskService.json_error('微调任务不存在或无权限', 404) + + # 仅允许基于加噪微调创建评估 + if TaskService.determine_finetune_source(finetune_task) != 'perturbation': + return TaskService.json_error('数值评估仅支持基于加噪任务的微调结果') + + if not finetune_task.finetune: + return TaskService.json_error('微调任务未配置详情', 400) + + try: + evaluate_type = TaskService.require_task_type('evaluate') + pending_status = TaskService.ensure_status('pending') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + + try: + task = Task( + flow_id=finetune_task.flow_id, + tasks_type_id=evaluate_type.task_type_id, + user_id=current_user_id, + tasks_status_id=pending_status.task_status_id, + description=data.get('description') + ) + db.session.add(task) + db.session.flush() + + evaluate = Evaluate( + tasks_id=task.tasks_id, + finetune_configs_id=finetune_task.finetune.finetune_configs_id, + evaluate_name=data.get('evaluate_name') + ) + db.session.add(evaluate) + db.session.commit() + + return jsonify({'message': '评估任务已创建', 'task': TaskService.serialize_task(task)}), 201 + except Exception as exc: + db.session.rollback() + return TaskService.json_error(f'创建评估任务失败: {exc}', 500) + + +@task_bp.route('/evaluate//start', methods=['POST']) +@int_jwt_required +def start_evaluate_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + + job_id = TaskService.start_evaluate_task(task_id) + if not job_id: + return TaskService.json_error('任务启动失败', 500) + return jsonify({'message': '任务已启动', 'job_id': job_id}), 200 + + +@task_bp.route('/evaluate', methods=['GET']) +@int_jwt_required +def list_evaluate_tasks(current_user_id): + try: + evaluate_type = TaskService.require_task_type('evaluate') + except ValueError as exc: + return TaskService.json_error(str(exc), 500) + tasks = Task.query.filter_by(user_id=current_user_id, tasks_type_id=evaluate_type.task_type_id).order_by(Task.created_at.desc()).all() + return jsonify({'tasks': [TaskService.serialize_task(task) for task in tasks]}), 200 + + +@task_bp.route('/evaluate/', methods=['GET']) +@int_jwt_required +def get_evaluate_task(task_id, current_user_id): + task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate') + if not task: + return TaskService.json_error('任务不存在或无权限', 404) + return jsonify({'task': TaskService.serialize_task(task)}), 200 diff --git a/src/backend/app/controllers/user_controller.py b/src/backend/app/controllers/user_controller.py index 3d99fda..3325b9e 100644 --- a/src/backend/app/controllers/user_controller.py +++ b/src/backend/app/controllers/user_controller.py @@ -1,129 +1,119 @@ -""" -用户管理控制器 -处理用户配置等功能 -""" - -from flask import Blueprint, request, jsonify -from flask_jwt_extended import jwt_required -from app import db -from app.database import User, UserConfig, Perturbation, Finetune -from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器 - -user_bp = Blueprint('user', __name__) - -@user_bp.route('/config', methods=['GET']) -@int_jwt_required -def get_user_config(current_user_id): - """获取用户配置""" - try: - - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - if not user_config: - # 如果没有配置,创建默认配置 - user_config = UserConfig(user_id=current_user_id) - db.session.add(user_config) - db.session.commit() - - return jsonify({ - 'config': user_config.to_dict() - }), 200 - - except Exception as e: - return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500 - -@user_bp.route('/config', methods=['PUT']) -@int_jwt_required -def update_user_config(current_user_id): - """更新用户配置""" - try: - data = request.get_json() - - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - if not user_config: - user_config = UserConfig(user_id=current_user_id) - db.session.add(user_config) - - # 更新配置字段 - if 'perturbation_configs_id' in data: - user_config.perturbation_configs_id = data['perturbation_configs_id'] - - if 'perturbation_intensity' in data: - intensity = float(data['perturbation_intensity']) - if 0 < epsilon <= 255: - user_config.perturbation_intensity = intensity - else: - return jsonify({'error': '扰动强度必须在0-255之间'}), 400 - - if 'finetune_config_id' in data: - user_config.finetune_config_id = data['finetune_config_id'] - - db.session.commit() - - return jsonify({ - 'message': '用户配置更新成功', - 'config': user_config.to_dict() - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500 - -@user_bp.route('/algorithms', methods=['GET']) -@jwt_required() -def get_available_algorithms(): - """获取可用的算法列表""" - try: - perturbation_configs = Perturbation.query.all() - finetune_configs = Finetune.query.all() - - return jsonify({ - 'perturbation_algorithms': [ - { - 'id': config.id, - 'method_code': config.method_code, - 'method_name': config.method_name, - 'description': config.description, - } for config in perturbation_configs - ], - 'finetune_methods': [ - { - 'id': config.id, - 'method_code': config.method_code, - 'method_name': config.method_name, - 'description': config.description - } for config in finetune_configs - ] - }), 200 - - except Exception as e: - return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500 - -@user_bp.route('/stats', methods=['GET']) -@int_jwt_required -def get_user_stats(current_user_id): - """获取用户统计信息""" - try: - from app.database import Task, Image - - # 统计用户的任务和图片数量 - total_tasks = Task.query.filter_by(user_id=current_user_id).count() - completed_tasks = Task.query.filter_by(user_id=current_user_id, status='completed').count() - processing_tasks = Task.query.filter_by(user_id=current_user_id, status='processing').count() - failed_tasks = Task.query.filter_by(user_id=current_user_id, status='failed').count() - - total_images = Image.query.join(Task, Image.task_id == Task.id).filter(Task.user_id == current_user_id).count() - - return jsonify({ - 'stats': { - 'total_tasks': total_tasks, - 'completed_tasks': completed_tasks, - 'processing_tasks': processing_tasks, - 'failed_tasks': failed_tasks, - 'total_images': total_images - } - }), 200 - - except Exception as e: - return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500 \ No newline at end of file + +""" +用户管理控制器 +负责用户配置、任务汇总等接口 +""" + +from flask import Blueprint, request, jsonify +from app import db +from app.controllers.auth_controller import int_jwt_required +from app.database import UserConfig, Task, TaskType, TaskStatus + + +user_bp = Blueprint('user', __name__) + + +def _json_error(message, status_code=400): + return jsonify({'error': message}), status_code + + +def _get_or_create_user_config(user_id): + config = UserConfig.query.filter_by(user_id=user_id).first() + if not config: + config = UserConfig(user_id=user_id) + db.session.add(config) + db.session.commit() + return config + + +def _serialize_config(config): + return { + 'user_configs_id': config.user_configs_id, + 'user_id': config.user_id, + 'data_type_id': config.data_type_id, + 'perturbation_configs_id': config.perturbation_configs_id, + 'perturbation_intensity': config.perturbation_intensity, + 'finetune_configs_id': config.finetune_configs_id, + 'created_at': config.created_at.isoformat() if config.created_at else None, + 'updated_at': config.updated_at.isoformat() if config.updated_at else None, + } + + +def _serialize_task(task): + status_code = task.task_status.task_status_code if task.task_status else None + task_type_code = task.task_type.task_type_code if task.task_type else None + return { + 'task_id': task.tasks_id, + 'flow_id': task.flow_id, + 'task_type': task_type_code, + 'status': status_code, + 'created_at': task.created_at.isoformat() if task.created_at else None, + 'started_at': task.started_at.isoformat() if task.started_at else None, + 'finished_at': task.finished_at.isoformat() if task.finished_at else None, + 'description': task.description, + 'error_message': task.error_message + } + + +@user_bp.route('/config', methods=['GET']) +@int_jwt_required +def get_user_config(current_user_id): + config = _get_or_create_user_config(current_user_id) + return jsonify({'config': _serialize_config(config)}), 200 + + +@user_bp.route('/config', methods=['PUT']) +@int_jwt_required +def update_user_config(current_user_id): + config = _get_or_create_user_config(current_user_id) + data = request.get_json() or {} + + allowed_fields = {'data_type_id', 'perturbation_configs_id', 'perturbation_intensity', 'finetune_configs_id'} + for key, value in data.items(): + if key in allowed_fields: + if key == 'perturbation_intensity' and value is not None: + try: + value = float(value) + except (TypeError, ValueError): + return _json_error('perturbation_intensity 参数格式不正确') + setattr(config, key, value) + + try: + db.session.commit() + return jsonify({'message': '配置已更新', 'config': _serialize_config(config)}), 200 + except Exception as exc: + db.session.rollback() + return _json_error(f'更新配置失败: {exc}', 500) + + +@user_bp.route('/tasks', methods=['GET']) +@int_jwt_required +def list_user_tasks(current_user_id): + task_type_code = request.args.get('type') + status_code = request.args.get('status') + + query = Task.query.filter_by(user_id=current_user_id) + + if task_type_code: + task_type = TaskType.query.filter_by(task_type_code=task_type_code).first() + if not task_type: + return _json_error('任务类型不存在', 404) + query = query.filter(Task.tasks_type_id == task_type.task_type_id) + + if status_code: + status = TaskStatus.query.filter_by(task_status_code=status_code).first() + if not status: + return _json_error('任务状态不存在', 404) + query = query.filter(Task.tasks_status_id == status.task_status_id) + + tasks = query.order_by(Task.created_at.desc()).all() + return jsonify({'tasks': [_serialize_task(task) for task in tasks]}), 200 + + +@user_bp.route('/tasks/', methods=['GET']) +@int_jwt_required +def get_user_task(task_id, current_user_id): + task = Task.query.filter_by(tasks_id=task_id, user_id=current_user_id).first() + if not task: + return _json_error('任务不存在或无权限', 404) + return jsonify({'task': _serialize_task(task)}), 200 diff --git a/src/backend/app/services/image_service.py b/src/backend/app/services/image_service.py index 5c64d9e..933ad78 100644 --- a/src/backend/app/services/image_service.py +++ b/src/backend/app/services/image_service.py @@ -3,16 +3,18 @@ 处理图像上传、保存等功能 """ +import io import os import uuid import zipfile import fcntl import time +from datetime import datetime from werkzeug.utils import secure_filename -from flask import current_app +from flask import current_app, jsonify from PIL import Image as PILImage from app import db -from app.database import Image +from app.database import Image, ImageType from app.utils.file_utils import allowed_file class ImageService: @@ -254,4 +256,173 @@ class ImageService: except Exception as e: db.session.rollback() - return {'success': False, 'error': f'删除图片失败: {str(e)}'} \ No newline at end of file + return {'success': False, 'error': f'删除图片失败: {str(e)}'} + + # ==================== 控制器辅助功能 ==================== + + DEFAULT_TARGET_SIZE = 512 + IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp'} + + @staticmethod + def json_error(message, status_code=400): + """统一错误响应""" + return jsonify({'error': message}), status_code + + @staticmethod + def get_image_type_by_code(code): + """根据代码获取图片类型""" + return ImageType.query.filter_by(image_code=code).first() + + @staticmethod + def save_original_images(task, files, target_dir, image_type_code='original', target_size=None): + """保存原图上传""" + if not files: + return False, '未检测到文件上传' + + image_type = ImageService.get_image_type_by_code(image_type_code) + if not image_type: + return False, f'未配置图片类型: {image_type_code}' + + os.makedirs(target_dir, exist_ok=True) + + saved_records = [] + saved_paths = [] + size = target_size or ImageService.DEFAULT_TARGET_SIZE + + try: + for file in files: + if not file or not file.filename: + continue + if not allowed_file(file.filename): + continue + + extension = os.path.splitext(file.filename)[1].lower() + if extension not in ImageService.IMAGE_EXTENSIONS: + continue + + processed = ImageService._prepare_image(file, size) + filename, path, width, height, file_size = ImageService._save_processed_image(processed, target_dir) + image = ImageService._create_image_record( + task, + image_type.image_types_id, + filename, + path, + width, + height, + file_size + ) + saved_records.append(image) + saved_paths.append(path) + + if not saved_records: + db.session.rollback() + return False, '未上传有效的图片文件' + + db.session.commit() + return True, saved_records + except Exception as exc: + db.session.rollback() + for path in saved_paths: + if os.path.exists(path): + try: + os.remove(path) + except OSError: + pass + return False, f'上传图片失败: {exc}' + + @staticmethod + def _prepare_image(file_storage, target_size): + """裁剪并缩放上传图片""" + file_storage.stream.seek(0) + image = PILImage.open(file_storage.stream).convert('RGB') + width, height = image.size + min_dim = min(width, height) + left = (width - min_dim) // 2 + top = (height - min_dim) // 2 + image = image.crop((left, top, left + min_dim, top + min_dim)) + return image.resize((target_size, target_size), resample=PILImage.Resampling.LANCZOS) + + @staticmethod + def _save_processed_image(image, target_dir): + """将处理后的图片保存为PNG""" + timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f') + filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png" + path = os.path.join(target_dir, filename) + image.save(path, format='PNG') + return filename, path, image.width, image.height, os.path.getsize(path) + + @staticmethod + def _create_image_record(task, image_type_id, filename, path, width, height, file_size, father_id=None): + """创建图片数据库记录""" + image = Image( + task_id=task.tasks_id, + image_types_id=image_type_id, + father_id=father_id, + stored_filename=filename, + file_path=path, + file_size=file_size, + width=width, + height=height + ) + db.session.add(image) + return image + + @staticmethod + def zip_directory(directory): + """打包目录为zip""" + buffer = io.BytesIO() + has_files = False + + with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf: + if os.path.isdir(directory): + for root, _, files in os.walk(directory): + for filename in files: + file_path = os.path.join(root, filename) + arcname = os.path.relpath(file_path, directory) + zipf.write(file_path, arcname) + has_files = True + + buffer.seek(0) + return buffer, has_files + + @staticmethod + def zip_multiple_directories(directories): + """打包多个目录""" + buffer = io.BytesIO() + has_files = False + + with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf: + if isinstance(directories, dict): + iterable = directories.items() + else: + iterable = ((os.path.basename(d.rstrip(os.sep)) or 'output', d) for d in directories) + + for label, directory in iterable: + if not os.path.isdir(directory): + continue + for root, _, files in os.walk(directory): + for filename in files: + file_path = os.path.join(root, filename) + rel_path = os.path.relpath(file_path, directory) + arcname = os.path.join(label or 'output', rel_path) + zipf.write(file_path, arcname) + has_files = True + + buffer.seek(0) + return buffer, has_files + + @staticmethod + def serialize_image(image): + """图片序列化""" + if not image: + return None + return { + 'image_id': image.images_id, + 'task_id': image.task_id, + 'stored_filename': image.stored_filename, + 'file_path': image.file_path, + 'file_size': image.file_size, + 'width': image.width, + 'height': image.height, + 'image_type': image.image_type.image_code if image.image_type else None + } \ No newline at end of file diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index d60f1ec..80384ef 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -7,7 +7,7 @@ import os import logging from datetime import datetime -from flask import current_app +from flask import current_app, jsonify from redis import Redis from rq import Queue from rq.job import Job @@ -16,7 +16,7 @@ from app.database import ( Task, TaskStatus, TaskType, Perturbation, Finetune, Heatmap, Evaluate, Image, ImageType, DataType, - PerturbationConfig, FinetuneConfig + PerturbationConfig, FinetuneConfig, User ) from config.algorithm_config import AlgorithmConfig from config.settings import Config @@ -116,6 +116,135 @@ class TaskService: str(flow_id) ) + # ==================== 通用辅助功能 ==================== + + @staticmethod + def json_error(message, status_code=400): + """统一的错误响应""" + return jsonify({'error': message}), status_code + + @staticmethod + def get_task_type(code): + """根据任务类型代码获取TaskType""" + return TaskType.query.filter_by(task_type_code=code).first() + + @staticmethod + def require_task_type(code): + """确保任务类型存在""" + task_type = TaskService.get_task_type(code) + if not task_type: + raise ValueError(f"Task type '{code}' is not configured") + return task_type + + @staticmethod + def get_status_by_code(code): + """根据状态代码获取TaskStatus""" + return TaskStatus.query.filter_by(task_status_code=code).first() + + @staticmethod + def ensure_status(code): + """确保任务状态存在""" + status = TaskService.get_status_by_code(code) + if not status: + raise ValueError(f"Task status '{code}' is not configured") + return status + + @staticmethod + def generate_flow_id(): + """生成唯一的flow_id""" + base = int(datetime.utcnow().timestamp() * 1000) + while Task.query.filter_by(flow_id=base).first(): + base += 1 + return base + + @staticmethod + def ensure_task_owner(task, user_id): + """验证任务归属""" + return bool(task and task.user_id == user_id) + + @staticmethod + def get_task_type_code(task): + """获取任务类型代码""" + return task.task_type.task_type_code if task and task.task_type else None + + @staticmethod + def load_task_for_user(task_id, user_id, expected_type=None): + """根据任务ID加载用户的任务,可选检查类型""" + task = Task.query.get(task_id) + if not TaskService.ensure_task_owner(task, user_id): + return None + if expected_type: + task_type = TaskService.get_task_type_code(task) + if task_type != expected_type: + return None + return task + + @staticmethod + def determine_finetune_source(finetune_task): + """判断微调任务来源""" + perturb_type = TaskService.require_task_type('perturbation') + sibling_perturbation = Task.query.filter( + Task.flow_id == finetune_task.flow_id, + Task.tasks_type_id == perturb_type.task_type_id, + Task.tasks_id != finetune_task.tasks_id + ).first() + return 'perturbation' if sibling_perturbation else 'uploaded' + + @staticmethod + def serialize_task(task): + """任务序列化""" + task_type = TaskService.get_task_type_code(task) + status = task.task_status.task_status_code if task and task.task_status else None + base = { + 'task_id': task.tasks_id, + 'flow_id': task.flow_id, + 'task_type': task_type, + 'status': status, + 'user_id': task.user_id, + 'description': task.description, + 'created_at': task.created_at.isoformat() if task.created_at else None, + 'started_at': task.started_at.isoformat() if task.started_at else None, + 'finished_at': task.finished_at.isoformat() if task.finished_at else None, + 'error_message': task.error_message, + } + + if task_type == 'perturbation' and task.perturbation: + base['perturbation'] = { + 'data_type_id': task.perturbation.data_type_id, + 'perturbation_configs_id': task.perturbation.perturbation_configs_id, + 'perturbation_intensity': float(task.perturbation.perturbation_intensity), + 'perturbation_name': task.perturbation.perturbation_name, + } + elif task_type == 'finetune' and task.finetune: + try: + source = TaskService.determine_finetune_source(task) + except ValueError: + source = 'uploaded' + base['finetune'] = { + 'finetune_configs_id': task.finetune.finetune_configs_id, + 'data_type_id': task.finetune.data_type_id, + 'finetune_name': task.finetune.finetune_name, + 'source': source + } + elif task_type == 'heatmap' and task.heatmap: + base['heatmap'] = { + 'perturbed_image_id': task.heatmap.images_id, + 'heatmap_name': task.heatmap.heatmap_name + } + elif task_type == 'evaluate' and task.evaluation: + base['evaluate'] = { + 'finetune_configs_id': task.evaluation.finetune_configs_id, + 'evaluate_name': task.evaluation.evaluate_name, + 'evaluation_results_id': task.evaluation.evaluation_results_id + } + + return base + + @staticmethod + def get_user(user_id): + """获取用户""" + return User.query.get(user_id) + # ==================== Redis/RQ 连接管理 ==================== @staticmethod