diff --git a/.gitignore b/.gitignore index 370ae49..ecfa54e 100644 --- a/.gitignore +++ b/.gitignore @@ -35,4 +35,24 @@ uploads/ .github/ # pycharm 配置 -.idea/ \ No newline at end of file +.idea/ + +# pytest配置 +*.ini + +# 测试相关 +.pytest_cache/ +.coverage +.coverage.* +htmlcov/ +.tox/ +.nox/ +coverage.xml +*.cover +*.py,cover +.hypothesis/ +pytest_cache/ +test-results/ +test-reports/ +tests/ +run_tests.py \ No newline at end of file diff --git a/src/backend/app/controllers/auth_controller.py b/src/backend/app/controllers/auth_controller.py index b84c7a7..37abb2b 100644 --- a/src/backend/app/controllers/auth_controller.py +++ b/src/backend/app/controllers/auth_controller.py @@ -9,7 +9,7 @@ from app import db from app.database import User, UserConfig from functools import wraps import re -from app.services.email_service import send_verification_code, verify_code +from app.services.email import VerificationService def int_jwt_required(f): """获取JWT身份并转换为整数的装饰器""" @@ -40,8 +40,11 @@ def send_email_verification_code(): if not re.match(email_pattern, email): return jsonify({'error': '邮箱格式不正确'}), 400 - send_verification_code(email, purpose=purpose) - return jsonify({'message': '验证码已发送'}), 200 + verification_service = VerificationService() + if verification_service.send_code(email, purpose): + return jsonify({'message': '验证码已发送'}), 200 + else: + return jsonify({'error': '验证码发送失败,请稍后重试'}), 500 except Exception as e: return jsonify({'error': f'发送验证码失败: {str(e)}'}), 500 @@ -72,8 +75,8 @@ def register(): if User.query.filter_by(email=email).first(): return jsonify({'error': '该邮箱已被注册,同一邮箱只能注册一次'}), 400 - # 验证验证码 - if not code or not verify_code(email, code, purpose='register'): + verification_service = VerificationService() + if not code or not verification_service.verify_code(email, code, purpose = 'register'): return jsonify({'error': '验证码无效或已过期'}), 400 # 创建用户(默认为普通用户,role_id=3) @@ -160,12 +163,15 @@ def change_password(current_user_id): db.session.rollback() return jsonify({'error': f'密码修改失败: {str(e)}'}), 500 -@auth_bp.route('/change-email', methods = ['POST']) +@auth_bp.route('/change-email', methods=['POST']) @int_jwt_required def change_email(current_user_id): """修改邮箱""" try: - user = User.query.filter_by(current_user_id) + user = User.query.filter_by(user_id=current_user_id).first() + if not user: + return jsonify({'error': '用户不存在'}), 404 + data = request.get_json() new_email = data.get('new_email') code = data.get('code') @@ -173,10 +179,15 @@ def change_email(current_user_id): if not new_email: return jsonify({'error': '新邮箱不能为空'}), 400 - if not User.query.filter(new_email).first(): - return jsonify({'error':'该邮箱已被使用'}), 400 + # 验证邮箱格式 + email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + if not re.match(email_pattern, new_email): + return jsonify({'error': '邮箱格式不正确'}), 400 + + if User.query.filter_by(email=new_email).first(): + return jsonify({'error': '该邮箱已被使用'}), 400 - if not code or not verify_code(email, code, purpose='register'): + if not code or not verify_code(new_email, code, purpose='change_email'): return jsonify({'error': '验证码无效或已过期'}), 400 user.email = new_email @@ -187,22 +198,25 @@ def change_email(current_user_id): db.session.rollback() return jsonify({'error': f'邮箱修改失败: {str(e)}'}), 500 -@auth_bp.route('/change-username', methods = ['POST']) +@auth_bp.route('/change-username', methods=['POST']) @int_jwt_required def change_username(current_user_id): """修改用户名""" try: - user = User.query.filter_by(current_user_id) + user = User.query.filter_by(user_id=current_user_id).first() + if not user: + return jsonify({'error': '用户不存在'}), 404 + data = request.get_json() new_username = data.get('new_username') if not new_username: return jsonify({'error': '新名称不能为空'}), 400 - if not User.query.filter(new_username).first(): - return jsonify({'error':'该用户名已被使用'}), 400 + if User.query.filter_by(username=new_username).first(): + return jsonify({'error': '该用户名已被使用'}), 400 - user.name = new_username + user.username = new_username db.session.commit() return jsonify({'message': '用户名修改成功'}), 200 diff --git a/src/backend/app/controllers/image_controller.py b/src/backend/app/controllers/image_controller.py index 6f0c800..eff5edb 100644 --- a/src/backend/app/controllers/image_controller.py +++ b/src/backend/app/controllers/image_controller.py @@ -4,11 +4,12 @@ """ import os -import base64 -from flask import Blueprint, request, jsonify, send_file +import uuid +from flask import Blueprint, request, jsonify, send_file, Response from app.controllers.auth_controller import int_jwt_required from app.services.task_service import TaskService from app.services.image_service import ImageService +from app.services.image.image_serializer import get_image_serializer from app.database import Image, ImageType image_bp = Blueprint('image', __name__) @@ -40,9 +41,10 @@ def upload_original_images(current_user_id): status_code = 500 return ImageService.json_error(result, status_code) + serializer = get_image_serializer() return jsonify({ 'message': '图片上传成功', - 'images': [ImageService.image_to_base64(img) for img in result], + 'images': [serializer.to_dict(img) for img in result], 'flow_id': task.flow_id }), 201 @@ -78,162 +80,7 @@ def get_image_file(image_id, current_user_id): return send_file(image.file_path, mimetype=mimetype) -# ==================== 任务图片获取(返回 base64) ==================== -@image_bp.route('/task/', methods=['GET']) -@int_jwt_required -def get_task_images(task_id, current_user_id): - """获取任务的所有图片(base64格式)""" - task = TaskService.load_task_for_user(task_id, current_user_id) - if not task: - return ImageService.json_error('任务不存在或无权限', 404) - - image_type_code = request.args.get('type') - - query = Image.query.filter_by(task_id=task_id) - if image_type_code: - image_type = ImageType.query.filter_by(image_code=image_type_code).first() - if image_type: - query = query.filter_by(image_types_id=image_type.image_types_id) - - images = query.all() - - return jsonify({ - 'task_id': task_id, - 'images': [ImageService.image_to_base64(img) for img in images], - 'total': len(images) - }), 200 - - -@image_bp.route('/perturbation/', methods=['GET']) -@int_jwt_required -def get_perturbation_images(task_id, current_user_id): - """获取加噪任务的结果图片(base64格式)""" - task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') - if not task: - return ImageService.json_error('任务不存在或无权限', 404) - - perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() - if not perturbed_type: - return ImageService.json_error('图片类型未配置', 500) - - images = Image.query.filter_by( - task_id=task_id, - image_types_id=perturbed_type.image_types_id - ).all() - - return jsonify({ - 'task_id': task_id, - 'task_type': 'perturbation', - 'images': [ImageService.image_to_base64(img) for img in images], - 'total': len(images) - }), 200 - - -@image_bp.route('/heatmap/', methods=['GET']) -@int_jwt_required -def get_heatmap_images(task_id, current_user_id): - """获取热力图任务的结果图片(base64格式)""" - task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap') - if not task: - return ImageService.json_error('任务不存在或无权限', 404) - - heatmap_type = ImageType.query.filter_by(image_code='heatmap').first() - if not heatmap_type: - return ImageService.json_error('图片类型未配置', 500) - - images = Image.query.filter_by( - task_id=task_id, - image_types_id=heatmap_type.image_types_id - ).all() - - return jsonify({ - 'task_id': task_id, - 'task_type': 'heatmap', - 'images': [ImageService.image_to_base64(img) for img in images], - 'total': len(images) - }), 200 - - -@image_bp.route('/finetune/', methods=['GET']) -@int_jwt_required -def get_finetune_images(task_id, current_user_id): - """获取微调任务的生成图片(base64格式)""" - task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune') - if not task: - return ImageService.json_error('任务不存在或无权限', 404) - - if not task.finetune: - return ImageService.json_error('微调任务配置不存在', 404) - - try: - source = TaskService.determine_finetune_source(task) - except ValueError as exc: - return ImageService.json_error(str(exc), 500) - - result = {'task_id': task_id, 'task_type': 'finetune', 'source': source} - - if source == 'perturbation': - original_gen_type = ImageType.query.filter_by(image_code='original_generate').first() - perturbed_gen_type = ImageType.query.filter_by(image_code='perturbed_generate').first() - - original_images = [] - perturbed_images = [] - - if original_gen_type: - original_images = Image.query.filter_by( - task_id=task_id, - image_types_id=original_gen_type.image_types_id - ).all() - - if perturbed_gen_type: - perturbed_images = Image.query.filter_by( - task_id=task_id, - image_types_id=perturbed_gen_type.image_types_id - ).all() - - result['original_generate'] = [ImageService.image_to_base64(img) for img in original_images] - result['perturbed_generate'] = [ImageService.image_to_base64(img) for img in perturbed_images] - result['total'] = len(original_images) + len(perturbed_images) - else: - uploaded_gen_type = ImageType.query.filter_by(image_code='uploaded_generate').first() - uploaded_images = [] - - if uploaded_gen_type: - uploaded_images = Image.query.filter_by( - task_id=task_id, - image_types_id=uploaded_gen_type.image_types_id - ).all() - - result['uploaded_generate'] = [ImageService.image_to_base64(img) for img in uploaded_images] - result['total'] = len(uploaded_images) - - return jsonify(result), 200 - - -@image_bp.route('/evaluate/', methods=['GET']) -@int_jwt_required -def get_evaluate_images(task_id, current_user_id): - """获取评估任务的结果图片(base64格式)""" - task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate') - if not task: - return ImageService.json_error('任务不存在或无权限', 404) - - report_type = ImageType.query.filter_by(image_code='report').first() - if not report_type: - return ImageService.json_error('图片类型未配置', 500) - - images = Image.query.filter_by( - task_id=task_id, - image_types_id=report_type.image_types_id - ).all() - - return jsonify({ - 'task_id': task_id, - 'task_type': 'evaluate', - 'images': [ImageService.image_to_base64(img) for img in images], - 'total': len(images) - }), 200 # ==================== 图片删除 ==================== @@ -257,188 +104,133 @@ def delete_image(image_id, current_user_id): return jsonify({'message': '图片删除成功'}), 200 -# ==================== 统一预览接口 ==================== -@image_bp.route('/preview/flow/', methods=['GET']) -@int_jwt_required -def preview_flow_images(flow_id, current_user_id): - """ - 获取工作流下所有图片的统一预览接口 - - 返回数据结构: - { - "flow_id": 123, - "original": [...], # 原图 - "perturbed": [...], # 加噪图 - "original_generate": [...], # 原图微调生成 - "perturbed_generate": [...], # 加噪图微调生成 - "uploaded_generate": [...], # 上传图微调生成 - "heatmap": [...], # 热力图 - "report": [...] # 评估报告图 - } - """ - from app.database import Task - - # 验证用户对该flow的访问权限 - tasks = Task.query.filter_by(flow_id=flow_id, user_id=current_user_id).all() - if not tasks: - return ImageService.json_error('工作流不存在或无权限', 404) - - # 获取所有图片类型 - image_types = { - 'original': ImageType.query.filter_by(image_code='original').first(), - 'perturbed': ImageType.query.filter_by(image_code='perturbed').first(), - 'original_generate': ImageType.query.filter_by(image_code='original_generate').first(), - 'perturbed_generate': ImageType.query.filter_by(image_code='perturbed_generate').first(), - 'uploaded_generate': ImageType.query.filter_by(image_code='uploaded_generate').first(), - 'heatmap': ImageType.query.filter_by(image_code='heatmap').first(), - 'report': ImageType.query.filter_by(image_code='report').first(), - } - - # 收集所有任务ID - task_ids = [t.tasks_id for t in tasks] - - result = { - 'flow_id': flow_id, - 'original': [], - 'perturbed': [], - 'original_generate': [], - 'perturbed_generate': [], - 'uploaded_generate': [], - 'heatmap': [], - 'report': [] - } - - # 查询各类型图片 - for type_code, image_type in image_types.items(): - if image_type: - images = Image.query.filter( - Image.task_id.in_(task_ids), - Image.image_types_id == image_type.image_types_id - ).all() - result[type_code] = [ImageService.image_to_base64(img) for img in images if img] - - # 统计总数 - result['total'] = sum(len(result[k]) for k in result if k not in ['flow_id', 'total']) - - return jsonify(result), 200 -@image_bp.route('/preview/task/', methods=['GET']) +# ==================== 二进制流式传输接口 ==================== + +@image_bp.route('/binary/task/', methods=['GET']) @int_jwt_required -def preview_task_images(task_id, current_user_id): +def get_task_images_binary(task_id, current_user_id): """ - 获取单个任务的所有图片预览 - - 根据任务类型返回相应的图片: - - perturbation: 原图 + 加噪图 - - finetune: 原图 + 生成图(original_generate/perturbed_generate/uploaded_generate) - - heatmap: 原图 + 加噪图 + 热力图 - - evaluate: 生成图 + 报告图 + 以 multipart/mixed 格式流式返回任务的所有图片二进制数据 + + Query参数: + type: 可选,指定图片类型代码 + + 响应格式: multipart/mixed + 每个part包含: + - Content-Type: 图片MIME类型 + - Content-Disposition: 文件名 + - X-Image-Id: 图片ID + - X-Image-Type: 图片类型代码 + - X-Image-Width: 宽度 + - X-Image-Height: 高度 + - 图片二进制数据 """ - from app.database import Task, TaskType - task = TaskService.load_task_for_user(task_id, current_user_id) if not task: return ImageService.json_error('任务不存在或无权限', 404) - task_type_code = TaskService.get_task_type_code(task) + image_type_code = request.args.get('type') - result = { - 'task_id': task_id, - 'flow_id': task.flow_id, - 'task_type': task_type_code, - 'images': {} - } + query = Image.query.filter_by(task_id=task_id) + if image_type_code: + image_type = ImageType.query.filter_by(image_code=image_type_code).first() + if image_type: + query = query.filter_by(image_types_id=image_type.image_types_id) - # 根据任务类型获取相关图片 - if task_type_code == 'perturbation': - result['images'] = ImageService._get_perturbation_preview(task) - elif task_type_code == 'finetune': - result['images'] = ImageService._get_finetune_preview(task) - elif task_type_code == 'heatmap': - result['images'] = ImageService._get_heatmap_preview(task) - elif task_type_code == 'evaluate': - result['images'] = ImageService._get_evaluate_preview(task) + images = query.all() - return jsonify(result), 200 - - - + if not images: + return ImageService.json_error('没有找到图片', 404) + + # 按类型分组 + images_dict = {} + for img in images: + type_code = img.image_type.image_code if img.image_type else 'unknown' + if type_code not in images_dict: + images_dict[type_code] = [] + images_dict[type_code].append(img) + + boundary = uuid.uuid4().hex + serializer = get_image_serializer() + + return Response( + serializer.generate_multipart_stream(images_dict, boundary), + mimetype=f'multipart/mixed; boundary={boundary}', + headers={ + 'X-Total-Images': str(len(images)), + 'X-Task-Id': str(task_id) + } + ) -@image_bp.route('/preview/compare/', methods=['GET']) +@image_bp.route('/binary/flow/', methods=['GET']) @int_jwt_required -def preview_compare_images(flow_id, current_user_id): +def get_flow_images_binary(flow_id, current_user_id): """ - 获取对比预览数据,用于展示原图vs加噪图、原图生成vs加噪图生成的对比 + 以 multipart/mixed 格式流式返回工作流的所有图片二进制数据 - 返回配对的图片数据,便于前端展示对比效果 + Query参数: + types: 可选,逗号分隔的图片类型代码列表 + + 响应格式: multipart/mixed """ from app.database import Task - # 验证权限 tasks = Task.query.filter_by(flow_id=flow_id, user_id=current_user_id).all() if not tasks: return ImageService.json_error('工作流不存在或无权限', 404) task_ids = [t.tasks_id for t in tasks] - # 获取图片类型 - original_type = ImageType.query.filter_by(image_code='original').first() - perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() - original_gen_type = ImageType.query.filter_by(image_code='original_generate').first() - perturbed_gen_type = ImageType.query.filter_by(image_code='perturbed_generate').first() + # 解析请求的图片类型 + type_codes = request.args.get('types', '').split(',') if request.args.get('types') else None - result = { - 'flow_id': flow_id, - 'perturbation_pairs': [], # 原图 vs 加噪图 - 'generation_pairs': [] # 原图生成 vs 加噪图生成 - } + # 构建查询 + query = Image.query.filter(Image.task_id.in_(task_ids)) - # 构建原图vs加噪图对比 - if original_type and perturbed_type: - originals = Image.query.filter( - Image.task_id.in_(task_ids), - Image.image_types_id == original_type.image_types_id - ).all() - - for orig in originals: - # 查找对应的加噪图(通过father_id关联) - perturbed = Image.query.filter( - Image.task_id.in_(task_ids), - Image.image_types_id == perturbed_type.image_types_id, - Image.father_id == orig.images_id - ).first() - - if perturbed: - result['perturbation_pairs'].append({ - 'original': ImageService.image_to_base64(orig), - 'perturbed': ImageService.image_to_base64(perturbed) - }) + if type_codes: + type_ids = [] + for code in type_codes: + code = code.strip() + if code: + img_type = ImageType.query.filter_by(image_code=code).first() + if img_type: + type_ids.append(img_type.image_types_id) + if type_ids: + query = query.filter(Image.image_types_id.in_(type_ids)) - # 构建生成图对比(按文件名匹配) - if original_gen_type and perturbed_gen_type: - original_gens = Image.query.filter( - Image.task_id.in_(task_ids), - Image.image_types_id == original_gen_type.image_types_id - ).all() - - perturbed_gens = Image.query.filter( - Image.task_id.in_(task_ids), - Image.image_types_id == perturbed_gen_type.image_types_id - ).all() - - # 按文件名建立映射 - perturbed_map = {img.stored_filename: img for img in perturbed_gens} - - for orig_gen in original_gens: - perturbed_gen = perturbed_map.get(orig_gen.stored_filename) - if perturbed_gen: - result['generation_pairs'].append({ - 'original_generate': ImageService.image_to_base64(orig_gen), - 'perturbed_generate': ImageService.image_to_base64(perturbed_gen) - }) + images = query.all() - return jsonify(result), 200 - \ No newline at end of file + if not images: + return ImageService.json_error('没有找到图片', 404) + + # 按类型分组 + images_dict = {} + for img in images: + type_code = img.image_type.image_code if img.image_type else 'unknown' + if type_code not in images_dict: + images_dict[type_code] = [] + images_dict[type_code].append(img) + + boundary = uuid.uuid4().hex + serializer = get_image_serializer() + + return Response( + serializer.generate_multipart_stream(images_dict, boundary), + mimetype=f'multipart/mixed; boundary={boundary}', + headers={ + 'X-Total-Images': str(len(images)), + 'X-Flow-Id': str(flow_id) + } + ) + +""" 前端解析预览图片方式 +const response = await fetch(`/api/image/binary/task/${taskId}`); +const contentType = response.headers.get('content-type'); +const boundary = contentType.match(/boundary=(.+)/)[1]; +const buffer = await response.arrayBuffer(); +// 按 --boundary 分割解析每个 part +""" \ No newline at end of file diff --git a/src/backend/app/repositories/__init__.py b/src/backend/app/repositories/__init__.py new file mode 100644 index 0000000..db0b71a --- /dev/null +++ b/src/backend/app/repositories/__init__.py @@ -0,0 +1,63 @@ +""" +Repository 层 + +提供数据访问抽象,将数据库操作从 Service 层分离。 + +使用方式: + from app.repositories import TaskRepository, ImageRepository + + task_repo = TaskRepository() + task = task_repo.get_by_id(task_id) + + image_repo = ImageRepository() + images = image_repo.get_by_task(task_id) + +设计原则: + - 单一职责:每个 Repository 只负责一个实体的数据访问 + - 依赖倒置:Service 层依赖 Repository 抽象 + - 开闭原则:通过继承 BaseRepository 扩展新实体 +""" +from .base_repository import BaseRepository +from .task_repository import ( + TaskRepository, + PerturbationRepository, + FinetuneRepository, + HeatmapRepository, + EvaluateRepository, + EvaluationResultRepository, +) +from .image_repository import ImageRepository +from .user_repository import UserRepository, UserConfigRepository, RoleRepository +from .config_repository import ( + TaskTypeRepository, + TaskStatusRepository, + ImageTypeRepository, + PerturbationConfigRepository, + FinetuneConfigRepository, + DataTypeRepository, +) + +__all__ = [ + # Base + 'BaseRepository', + # Task + 'TaskRepository', + 'PerturbationRepository', + 'FinetuneRepository', + 'HeatmapRepository', + 'EvaluateRepository', + 'EvaluationResultRepository', + # Image + 'ImageRepository', + # User + 'UserRepository', + 'UserConfigRepository', + 'RoleRepository', + # Config + 'TaskTypeRepository', + 'TaskStatusRepository', + 'ImageTypeRepository', + 'PerturbationConfigRepository', + 'FinetuneConfigRepository', + 'DataTypeRepository', +] diff --git a/src/backend/app/repositories/base_repository.py b/src/backend/app/repositories/base_repository.py new file mode 100644 index 0000000..eabe268 --- /dev/null +++ b/src/backend/app/repositories/base_repository.py @@ -0,0 +1,151 @@ +""" +Repository 基类 + +提供通用的 CRUD 操作,子类可以扩展特定实体的查询方法。 +""" +import logging +from abc import ABC, abstractmethod +from typing import TypeVar, Generic, Optional, List, Any, Type + +from app import db + +logger = logging.getLogger(__name__) + +T = TypeVar('T') + + +class BaseRepository(ABC, Generic[T]): + """ + Repository 抽象基类 + + 提供通用的数据访问方法: + - get_by_id: 根据 ID 获取单个实体 + - get_all: 获取所有实体 + - create: 创建新实体 + - update: 更新实体 + - delete: 删除实体 + - save: 保存更改 + + 子类需实现: + - _get_model_class(): 返回对应的 SQLAlchemy Model 类 + """ + + @abstractmethod + def _get_model_class(self) -> Type[T]: + """返回对应的 Model 类""" + pass + + @abstractmethod + def _get_primary_key_name(self) -> str: + """返回主键字段名""" + pass + + def get_by_id(self, entity_id: int) -> Optional[T]: + """根据 ID 获取实体""" + return self._get_model_class().query.get(entity_id) + + def get_all(self) -> List[T]: + """获取所有实体""" + return self._get_model_class().query.all() + + def find_by(self, **kwargs) -> List[T]: + """根据条件查询""" + return self._get_model_class().query.filter_by(**kwargs).all() + + def find_one_by(self, **kwargs) -> Optional[T]: + """根据条件查询单个实体""" + return self._get_model_class().query.filter_by(**kwargs).first() + + def exists(self, entity_id: int) -> bool: + """检查实体是否存在""" + return self.get_by_id(entity_id) is not None + + def count(self, **kwargs) -> int: + """统计数量""" + if kwargs: + return self._get_model_class().query.filter_by(**kwargs).count() + return self._get_model_class().query.count() + + def create(self, **kwargs) -> T: + """ + 创建新实体(不自动提交) + + Args: + **kwargs: 实体属性 + + Returns: + 创建的实体对象 + """ + model_class = self._get_model_class() + entity = model_class(**kwargs) + db.session.add(entity) + return entity + + def add(self, entity: T) -> T: + """ + 添加实体到 session(不自动提交) + + Args: + entity: 实体对象 + + Returns: + 添加的实体对象 + """ + db.session.add(entity) + return entity + + def delete(self, entity: T) -> bool: + """ + 删除实体(不自动提交) + + Args: + entity: 要删除的实体 + + Returns: + 是否成功 + """ + try: + db.session.delete(entity) + return True + except Exception as e: + logger.error(f"Delete failed: {e}") + return False + + def delete_by_id(self, entity_id: int) -> bool: + """ + 根据 ID 删除实体(不自动提交) + + Args: + entity_id: 实体 ID + + Returns: + 是否成功 + """ + entity = self.get_by_id(entity_id) + if entity: + return self.delete(entity) + return False + + def save(self) -> bool: + """ + 提交当前事务 + + Returns: + 是否成功 + """ + try: + db.session.commit() + return True + except Exception as e: + db.session.rollback() + logger.error(f"Save failed: {e}") + return False + + def rollback(self): + """回滚当前事务""" + db.session.rollback() + + def refresh(self, entity: T) -> T: + """刷新实体状态""" + db.session.refresh(entity) + return entity diff --git a/src/backend/app/repositories/config_repository.py b/src/backend/app/repositories/config_repository.py new file mode 100644 index 0000000..ba5ad07 --- /dev/null +++ b/src/backend/app/repositories/config_repository.py @@ -0,0 +1,154 @@ +""" +配置字典表 Repository + +负责各种配置表的数据访问: +- TaskType: 任务类型 +- TaskStatus: 任务状态 +- ImageType: 图片类型 +- PerturbationConfig: 加噪算法配置 +- FinetuneConfig: 微调配置 +- DataType: 数据集类型 +""" +import logging +from typing import Optional, List, Type + +from app.database import ( + TaskType, TaskStatus, ImageType, + PerturbationConfig, FinetuneConfig, DataType +) +from .base_repository import BaseRepository + +logger = logging.getLogger(__name__) + + +class TaskTypeRepository(BaseRepository[TaskType]): + """任务类型数据访问""" + + def _get_model_class(self) -> Type[TaskType]: + return TaskType + + def _get_primary_key_name(self) -> str: + return 'task_type_id' + + def get_by_code(self, code: str) -> Optional[TaskType]: + """根据代码获取任务类型""" + return TaskType.query.filter_by(task_type_code=code).first() + + def require(self, code: str) -> TaskType: + """获取任务类型,不存在则抛出异常""" + task_type = self.get_by_code(code) + if not task_type: + raise ValueError(f"Task type '{code}' is not configured") + return task_type + + +class TaskStatusRepository(BaseRepository[TaskStatus]): + """任务状态数据访问""" + + def _get_model_class(self) -> Type[TaskStatus]: + return TaskStatus + + def _get_primary_key_name(self) -> str: + return 'task_status_id' + + def get_by_code(self, code: str) -> Optional[TaskStatus]: + """根据代码获取任务状态""" + return TaskStatus.query.filter_by(task_status_code=code).first() + + def require(self, code: str) -> TaskStatus: + """获取任务状态,不存在则抛出异常""" + status = self.get_by_code(code) + if not status: + raise ValueError(f"Task status '{code}' is not configured") + return status + + def get_waiting(self) -> Optional[TaskStatus]: + """获取等待状态""" + return self.get_by_code('waiting') + + def get_processing(self) -> Optional[TaskStatus]: + """获取处理中状态""" + return self.get_by_code('processing') + + def get_completed(self) -> Optional[TaskStatus]: + """获取完成状态""" + return self.get_by_code('completed') + + def get_failed(self) -> Optional[TaskStatus]: + """获取失败状态""" + return self.get_by_code('failed') + + +class ImageTypeRepository(BaseRepository[ImageType]): + """图片类型数据访问""" + + def _get_model_class(self) -> Type[ImageType]: + return ImageType + + def _get_primary_key_name(self) -> str: + return 'image_types_id' + + def get_by_code(self, code: str) -> Optional[ImageType]: + """根据代码获取图片类型""" + return ImageType.query.filter_by(image_code=code).first() + + def require(self, code: str) -> ImageType: + """获取图片类型,不存在则抛出异常""" + image_type = self.get_by_code(code) + if not image_type: + raise ValueError(f"Image type '{code}' is not configured") + return image_type + + +class PerturbationConfigRepository(BaseRepository[PerturbationConfig]): + """加噪算法配置数据访问""" + + def _get_model_class(self) -> Type[PerturbationConfig]: + return PerturbationConfig + + def _get_primary_key_name(self) -> str: + return 'perturbation_configs_id' + + def get_by_code(self, code: str) -> Optional[PerturbationConfig]: + """根据代码获取加噪配置""" + return PerturbationConfig.query.filter_by(perturbation_code=code).first() + + def get_all_active(self) -> List[PerturbationConfig]: + """获取所有可用的加噪配置""" + return self.get_all() + + +class FinetuneConfigRepository(BaseRepository[FinetuneConfig]): + """微调配置数据访问""" + + def _get_model_class(self) -> Type[FinetuneConfig]: + return FinetuneConfig + + def _get_primary_key_name(self) -> str: + return 'finetune_configs_id' + + def get_by_code(self, code: str) -> Optional[FinetuneConfig]: + """根据代码获取微调配置""" + return FinetuneConfig.query.filter_by(finetune_code=code).first() + + def get_all_active(self) -> List[FinetuneConfig]: + """获取所有可用的微调配置""" + return self.get_all() + + +class DataTypeRepository(BaseRepository[DataType]): + """数据集类型数据访问""" + + def _get_model_class(self) -> Type[DataType]: + return DataType + + def _get_primary_key_name(self) -> str: + return 'data_type_id' + + def get_by_code(self, code: str) -> Optional[DataType]: + """根据代码获取数据集类型""" + return DataType.query.filter_by(data_type_code=code).first() + + def get_all_active(self) -> List[DataType]: + """获取所有可用的数据集类型""" + return self.get_all() diff --git a/src/backend/app/repositories/image_repository.py b/src/backend/app/repositories/image_repository.py new file mode 100644 index 0000000..39cc32f --- /dev/null +++ b/src/backend/app/repositories/image_repository.py @@ -0,0 +1,121 @@ +""" +图片 Repository + +负责 Image 实体的数据访问。 +""" +import logging +from typing import Optional, List, Type + +from app.database import Image, ImageType +from .base_repository import BaseRepository + +logger = logging.getLogger(__name__) + + +class ImageRepository(BaseRepository[Image]): + """ + 图片数据访问 + + 提供图片相关的查询和操作方法。 + """ + + def _get_model_class(self) -> Type[Image]: + return Image + + def _get_primary_key_name(self) -> str: + return 'images_id' + + # ==================== 查询方法 ==================== + + def get_by_task(self, task_id: int) -> List[Image]: + """获取任务的所有图片""" + return Image.query.filter_by(task_id=task_id).all() + + def get_by_task_and_type(self, task_id: int, type_code: str) -> List[Image]: + """获取任务指定类型的图片""" + image_type = ImageType.query.filter_by(image_code=type_code).first() + if not image_type: + return [] + return Image.query.filter_by( + task_id=task_id, + image_types_id=image_type.image_types_id + ).all() + + def get_first_by_task_and_type(self, task_id: int, type_code: str) -> Optional[Image]: + """获取任务指定类型的第一张图片""" + images = self.get_by_task_and_type(task_id, type_code) + return images[0] if images else None + + def get_by_type(self, type_code: str) -> List[Image]: + """获取指定类型的所有图片""" + image_type = ImageType.query.filter_by(image_code=type_code).first() + if not image_type: + return [] + return Image.query.filter_by(image_types_id=image_type.image_types_id).all() + + def get_children(self, parent_id: int) -> List[Image]: + """获取子图片(派生图片)""" + return Image.query.filter_by(father_id=parent_id).all() + + def get_parent(self, image_id: int) -> Optional[Image]: + """获取父图片""" + image = self.get_by_id(image_id) + if image and image.father_id: + return self.get_by_id(image.father_id) + return None + + def get_by_path(self, file_path: str) -> Optional[Image]: + """根据文件路径获取图片""" + return Image.query.filter_by(file_path=file_path).first() + + def get_by_filename(self, filename: str) -> Optional[Image]: + """根据存储文件名获取图片""" + return Image.query.filter_by(stored_filename=filename).first() + + # ==================== 统计方法 ==================== + + def count_by_task(self, task_id: int) -> int: + """统计任务的图片数量""" + return Image.query.filter_by(task_id=task_id).count() + + def count_by_task_and_type(self, task_id: int, type_code: str) -> int: + """统计任务指定类型的图片数量""" + image_type = ImageType.query.filter_by(image_code=type_code).first() + if not image_type: + return 0 + return Image.query.filter_by( + task_id=task_id, + image_types_id=image_type.image_types_id + ).count() + + # ==================== 权限验证 ==================== + + def is_owner(self, image: Image, user_id: int) -> bool: + """验证图片归属(通过关联的任务)""" + if image and image.task: + return image.task.user_id == user_id + return False + + def get_for_user(self, image_id: int, user_id: int) -> Optional[Image]: + """获取用户的图片(带权限验证)""" + image = self.get_by_id(image_id) + if self.is_owner(image, user_id): + return image + return None + + # ==================== 批量操作 ==================== + + def delete_by_task(self, task_id: int) -> int: + """删除任务的所有图片记录""" + images = self.get_by_task(task_id) + count = 0 + for image in images: + if self.delete(image): + count += 1 + return count + + def get_type_code(self, image: Image) -> Optional[str]: + """获取图片类型代码""" + if image and image.image_type: + return image.image_type.image_code + return None diff --git a/src/backend/app/repositories/task_repository.py b/src/backend/app/repositories/task_repository.py new file mode 100644 index 0000000..446b958 --- /dev/null +++ b/src/backend/app/repositories/task_repository.py @@ -0,0 +1,232 @@ +""" +任务 Repository + +负责 Task 及其子表(Perturbation, Finetune, Heatmap, Evaluate)的数据访问。 +""" +import logging +from typing import Optional, List, Type +from datetime import datetime + +from app.database import ( + Task, Perturbation, Finetune, Heatmap, Evaluate, + TaskType, TaskStatus, EvaluationResult +) +from .base_repository import BaseRepository + +logger = logging.getLogger(__name__) + + +class TaskRepository(BaseRepository[Task]): + """ + 任务数据访问 + + 提供任务相关的查询和操作方法。 + """ + + def _get_model_class(self) -> Type[Task]: + return Task + + def _get_primary_key_name(self) -> str: + return 'tasks_id' + + # ==================== 查询方法 ==================== + + def get_by_user(self, user_id: int) -> List[Task]: + """获取用户的所有任务""" + return Task.query.filter_by(user_id=user_id).all() + + def get_by_user_and_type(self, user_id: int, type_code: str) -> List[Task]: + """获取用户指定类型的任务""" + task_type = TaskType.query.filter_by(task_type_code=type_code).first() + if not task_type: + return [] + return Task.query.filter_by( + user_id=user_id, + tasks_type_id=task_type.task_type_id + ).all() + + def get_by_flow(self, flow_id: int) -> List[Task]: + """获取同一工作流的所有任务""" + return Task.query.filter_by(flow_id=flow_id).all() + + def get_by_flow_and_type(self, flow_id: int, type_code: str) -> Optional[Task]: + """获取工作流中指定类型的任务""" + task_type = TaskType.query.filter_by(task_type_code=type_code).first() + if not task_type: + return None + return Task.query.filter_by( + flow_id=flow_id, + tasks_type_id=task_type.task_type_id + ).first() + + def get_by_status(self, status_code: str) -> List[Task]: + """获取指定状态的任务""" + status = TaskStatus.query.filter_by(task_status_code=status_code).first() + if not status: + return [] + return Task.query.filter_by(tasks_status_id=status.task_status_id).all() + + def get_user_tasks_by_status(self, user_id: int, status_code: str) -> List[Task]: + """获取用户指定状态的任务""" + status = TaskStatus.query.filter_by(task_status_code=status_code).first() + if not status: + return [] + return Task.query.filter_by( + user_id=user_id, + tasks_status_id=status.task_status_id + ).all() + + def get_pending_tasks(self, user_id: int) -> List[Task]: + """获取用户待处理的任务(waiting + processing)""" + waiting = TaskStatus.query.filter_by(task_status_code='waiting').first() + processing = TaskStatus.query.filter_by(task_status_code='processing').first() + + status_ids = [] + if waiting: + status_ids.append(waiting.task_status_id) + if processing: + status_ids.append(processing.task_status_id) + + if not status_ids: + return [] + + return Task.query.filter( + Task.user_id == user_id, + Task.tasks_status_id.in_(status_ids) + ).all() + + def count_pending_tasks(self, user_id: int) -> int: + """统计用户待处理任务数""" + return len(self.get_pending_tasks(user_id)) + + # ==================== 状态更新 ==================== + + def update_status(self, task: Task, status_code: str) -> bool: + """ + 更新任务状态 + + Args: + task: 任务对象 + status_code: 状态代码 + + Returns: + 是否成功 + """ + status = TaskStatus.query.filter_by(task_status_code=status_code).first() + if not status: + logger.error(f"Status '{status_code}' not found") + return False + + task.tasks_status_id = status.task_status_id + + # 自动更新时间戳 + if status_code == 'processing': + task.started_at = datetime.utcnow() + elif status_code in ('completed', 'failed'): + task.finished_at = datetime.utcnow() + + return True + + def set_error(self, task: Task, error_message: str) -> bool: + """设置任务错误信息并标记为失败""" + task.error_message = error_message + return self.update_status(task, 'failed') + + # ==================== 权限验证 ==================== + + def is_owner(self, task: Task, user_id: int) -> bool: + """验证任务归属""" + return task is not None and task.user_id == user_id + + def get_for_user(self, task_id: int, user_id: int) -> Optional[Task]: + """获取用户的任务(带权限验证)""" + task = self.get_by_id(task_id) + if self.is_owner(task, user_id): + return task + return None + + # ==================== 任务类型判断 ==================== + + def get_type_code(self, task: Task) -> Optional[str]: + """获取任务类型代码""" + if task and task.task_type: + return task.task_type.task_type_code + return None + + def is_type(self, task: Task, type_code: str) -> bool: + """判断任务是否为指定类型""" + return self.get_type_code(task) == type_code + + +class PerturbationRepository(BaseRepository[Perturbation]): + """加噪任务详情数据访问""" + + def _get_model_class(self) -> Type[Perturbation]: + return Perturbation + + def _get_primary_key_name(self) -> str: + return 'tasks_id' + + def get_by_task(self, task_id: int) -> Optional[Perturbation]: + """根据任务 ID 获取加噪详情""" + return Perturbation.query.filter_by(tasks_id=task_id).first() + + +class FinetuneRepository(BaseRepository[Finetune]): + """微调任务详情数据访问""" + + def _get_model_class(self) -> Type[Finetune]: + return Finetune + + def _get_primary_key_name(self) -> str: + return 'tasks_id' + + def get_by_task(self, task_id: int) -> Optional[Finetune]: + """根据任务 ID 获取微调详情""" + return Finetune.query.filter_by(tasks_id=task_id).first() + + +class HeatmapRepository(BaseRepository[Heatmap]): + """热力图任务详情数据访问""" + + def _get_model_class(self) -> Type[Heatmap]: + return Heatmap + + def _get_primary_key_name(self) -> str: + return 'tasks_id' + + def get_by_task(self, task_id: int) -> Optional[Heatmap]: + """根据任务 ID 获取热力图详情""" + return Heatmap.query.filter_by(tasks_id=task_id).first() + + def get_by_image(self, image_id: int) -> List[Heatmap]: + """根据图片 ID 获取相关热力图任务""" + return Heatmap.query.filter_by(images_id=image_id).all() + + +class EvaluateRepository(BaseRepository[Evaluate]): + """评估任务详情数据访问""" + + def _get_model_class(self) -> Type[Evaluate]: + return Evaluate + + def _get_primary_key_name(self) -> str: + return 'tasks_id' + + def get_by_task(self, task_id: int) -> Optional[Evaluate]: + """根据任务 ID 获取评估详情""" + return Evaluate.query.filter_by(tasks_id=task_id).first() + + def get_by_finetune(self, finetune_task_id: int) -> Optional[Evaluate]: + """根据微调任务 ID 获取评估任务""" + return Evaluate.query.filter_by(finetune_task_id=finetune_task_id).first() + + +class EvaluationResultRepository(BaseRepository[EvaluationResult]): + """评估结果数据访问""" + + def _get_model_class(self) -> Type[EvaluationResult]: + return EvaluationResult + + def _get_primary_key_name(self) -> str: + return 'evaluation_results_id' diff --git a/src/backend/app/repositories/user_repository.py b/src/backend/app/repositories/user_repository.py new file mode 100644 index 0000000..f05b7c8 --- /dev/null +++ b/src/backend/app/repositories/user_repository.py @@ -0,0 +1,130 @@ +""" +用户 Repository + +负责 User 和 UserConfig 实体的数据访问。 +""" +import logging +from typing import Optional, List, Type + +from app.database import User, UserConfig, Role +from .base_repository import BaseRepository + +logger = logging.getLogger(__name__) + + +class UserRepository(BaseRepository[User]): + """ + 用户数据访问 + + 提供用户相关的查询和操作方法。 + """ + + def _get_model_class(self) -> Type[User]: + return User + + def _get_primary_key_name(self) -> str: + return 'user_id' + + # ==================== 查询方法 ==================== + + def get_by_username(self, username: str) -> Optional[User]: + """根据用户名获取用户""" + return User.query.filter_by(username=username).first() + + def get_by_email(self, email: str) -> Optional[User]: + """根据邮箱获取用户""" + return User.query.filter_by(email=email).first() + + def get_active_users(self) -> List[User]: + """获取所有激活的用户""" + return User.query.filter_by(is_active=True).all() + + def get_by_role(self, role_code: str) -> List[User]: + """获取指定角色的用户""" + role = Role.query.filter_by(role_code=role_code).first() + if not role: + return [] + return User.query.filter_by(role_id=role.role_id).all() + + # ==================== 验证方法 ==================== + + def username_exists(self, username: str) -> bool: + """检查用户名是否存在""" + return self.get_by_username(username) is not None + + def email_exists(self, email: str) -> bool: + """检查邮箱是否存在""" + return self.get_by_email(email) is not None + + def authenticate(self, username: str, password: str) -> Optional[User]: + """ + 验证用户凭据 + + Args: + username: 用户名 + password: 密码 + + Returns: + 验证成功返回用户对象,否则返回 None + """ + user = self.get_by_username(username) + if user and user.is_active and user.check_password(password): + return user + return None + + # ==================== 角色相关 ==================== + + def get_role_code(self, user: User) -> Optional[str]: + """获取用户角色代码""" + if user and user.role: + return user.role.role_code + return None + + def is_admin(self, user: User) -> bool: + """判断是否为管理员""" + return self.get_role_code(user) == 'admin' + + def is_vip(self, user: User) -> bool: + """判断是否为 VIP""" + return self.get_role_code(user) == 'vip' + + def get_max_concurrent_tasks(self, user: User) -> int: + """获取用户最大并发任务数""" + if user and user.role: + return user.role.max_concurrent_tasks or 1 + return 1 + + +class UserConfigRepository(BaseRepository[UserConfig]): + """用户配置数据访问""" + + def _get_model_class(self) -> Type[UserConfig]: + return UserConfig + + def _get_primary_key_name(self) -> str: + return 'user_configs_id' + + def get_by_user(self, user_id: int) -> Optional[UserConfig]: + """根据用户 ID 获取配置""" + return UserConfig.query.filter_by(user_id=user_id).first() + + def get_or_create(self, user_id: int) -> UserConfig: + """获取或创建用户配置""" + config = self.get_by_user(user_id) + if not config: + config = self.create(user_id=user_id) + return config + + +class RoleRepository(BaseRepository[Role]): + """角色数据访问""" + + def _get_model_class(self) -> Type[Role]: + return Role + + def _get_primary_key_name(self) -> str: + return 'role_id' + + def get_by_code(self, role_code: str) -> Optional[Role]: + """根据角色代码获取角色""" + return Role.query.filter_by(role_code=role_code).first() diff --git a/src/backend/app/services/cache/__init__.py b/src/backend/app/services/cache/__init__.py new file mode 100644 index 0000000..bd06492 --- /dev/null +++ b/src/backend/app/services/cache/__init__.py @@ -0,0 +1,4 @@ +"""缓存服务模块""" +from .redis_client import RedisClient + +__all__ = ['RedisClient'] diff --git a/src/backend/app/services/cache/redis_client.py b/src/backend/app/services/cache/redis_client.py new file mode 100644 index 0000000..f705466 --- /dev/null +++ b/src/backend/app/services/cache/redis_client.py @@ -0,0 +1,66 @@ +""" +Redis 客户端封装(单例模式) + +职责单一:只负责 Redis 连接管理和基础操作 +""" +import logging +from typing import Optional +import redis +from flask import current_app + +logger = logging.getLogger(__name__) + + +class RedisClient: + """ + Redis 客户端单例 + + 使用方式: + client = RedisClient() + client.set('key', 'value', ex=300) + value = client.get('key') + """ + _instance: Optional['RedisClient'] = None + _pools: dict[str, redis.ConnectionPool] = {} + + def __new__(cls) -> 'RedisClient': + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def _get_connection(self) -> redis.Redis: + """获取 Redis 连接,复用连接池""" + redis_url = current_app.config.get('REDIS_URL', 'redis://localhost:6379/0') + + if redis_url not in self._pools: + self._pools[redis_url] = redis.ConnectionPool.from_url( + redis_url, + decode_responses=True + ) + + return redis.Redis(connection_pool=self._pools[redis_url]) + + def get(self, key: str) -> Optional[str]: + """获取值""" + try: + return self._get_connection().get(key) + except Exception: + logger.exception(f"Redis GET 失败: {key}") + return None + + def set(self, key: str, value: str, ex: Optional[int] = None) -> bool: + """设置值,可选过期时间(秒)""" + try: + self._get_connection().set(key, value, ex=ex) + return True + except Exception: + logger.exception(f"Redis SET 失败: {key}") + return False + + def delete(self, key: str) -> bool: + """删除键""" + try: + return self._get_connection().delete(key) == 1 + except Exception: + logger.exception(f"Redis DELETE 失败: {key}") + return False diff --git a/src/backend/app/services/dump.rdb b/src/backend/app/services/dump.rdb deleted file mode 100644 index 18dbae9..0000000 Binary files a/src/backend/app/services/dump.rdb and /dev/null differ diff --git a/src/backend/app/services/email/__init__.py b/src/backend/app/services/email/__init__.py new file mode 100644 index 0000000..f288ef7 --- /dev/null +++ b/src/backend/app/services/email/__init__.py @@ -0,0 +1,5 @@ +"""邮件服务模块""" +from .email_sender import EmailSender +from .verification_service import VerificationService + +__all__ = ['EmailSender', 'VerificationService'] diff --git a/src/backend/app/services/email/email_sender.py b/src/backend/app/services/email/email_sender.py new file mode 100644 index 0000000..685381c --- /dev/null +++ b/src/backend/app/services/email/email_sender.py @@ -0,0 +1,67 @@ +""" +邮件发送服务 + +职责单一:只负责发送邮件 +""" +import logging +from typing import Optional +from flask import current_app +from flask_mail import Message + +logger = logging.getLogger(__name__) + + +class EmailSender: + """ + 邮件发送器 + + 使用方式: + sender = EmailSender() + sender.send('user@example.com', '标题', '内容') + """ + + def send( + self, + to: str, + subject: str, + body: str, + html: Optional[str] = None + ) -> bool: + """ + 发送邮件 + + Args: + to: 收件人邮箱 + subject: 邮件主题 + body: 纯文本内容 + html: HTML 内容(可选) + + Returns: + 是否发送成功 + """ + try: + mail = current_app.extensions.get('mail') + if mail is None: + logger.error('Flask-Mail 未初始化,无法发送邮件') + return False + + sender = ( + current_app.config.get('MAIL_DEFAULT_SENDER') or + current_app.config.get('MAIL_USERNAME') + ) + + msg = Message( + subject=subject, + recipients=[to], + body=body, + html=html, + sender=sender + ) + + mail.send(msg) + logger.info(f'邮件发送成功: {to}') + return True + + except Exception: + logger.exception(f'邮件发送失败: {to}') + return False diff --git a/src/backend/app/services/email/verification_service.py b/src/backend/app/services/email/verification_service.py new file mode 100644 index 0000000..9267544 --- /dev/null +++ b/src/backend/app/services/email/verification_service.py @@ -0,0 +1,211 @@ +""" +验证码服务 + +职责:验证码的生成、存储、校验 +通过组合 RedisClient 和 EmailSender 实现(依赖注入) +""" +import random +import string +import logging +from typing import Optional +from flask import current_app + +from app.services.cache import RedisClient +from app.services.email.email_sender import EmailSender + +logger = logging.getLogger(__name__) + + +class VerificationService: + """ + 验证码服务 + + 使用方式: + # 方式1:使用默认依赖 + service = VerificationService() + + # 方式2:注入自定义依赖(便于测试) + service = VerificationService( + redis_client=mock_redis, + email_sender=mock_sender + ) + + # 发送验证码 + service.send_code('user@example.com', purpose='register') + + # 校验验证码 + is_valid = service.verify_code('user@example.com', '123456', purpose='register') + """ + + # 验证码 Redis key 前缀 + KEY_PREFIX = 'verify' + + def __init__( + self, + redis_client: Optional[RedisClient] = None, + email_sender: Optional[EmailSender] = None + ): + """ + 初始化验证码服务 + + Args: + redis_client: Redis 客户端,默认使用单例 + email_sender: 邮件发送器,默认创建新实例 + """ + self._redis = redis_client or RedisClient() + self._email = email_sender or EmailSender() + + def _build_key(self, email: str, purpose: str) -> str: + """构建 Redis 存储 key""" + return f"{self.KEY_PREFIX}:{purpose}:{email}" + + @staticmethod + def _generate_code(length: int = 6) -> str: + """生成数字验证码""" + return ''.join(random.choices(string.digits, k=length)) + + def _get_expire_seconds(self, custom_expire: Optional[int] = None) -> int: + """获取过期时间(秒)""" + if custom_expire is not None: + return custom_expire + return current_app.config.get('VERIFICATION_CODE_EXPIRES', 300) + + def _build_email_body(self, code: str, expire_seconds: int) -> str: + """构建邮件内容""" + template = current_app.config.get( + 'VERIFICATION_EMAIL_TEMPLATE', + '您的验证码为:{code},有效期 {expire_minutes} 分钟。请勿泄露给他人。' + ) + return template.format( + code=code, + expire_seconds=expire_seconds, + expire_minutes=expire_seconds // 60 + ) + + def send_code( + self, + email: str, + purpose: str = 'register', + length: int = 6, + expire_seconds: Optional[int] = None + ) -> bool: + """ + 生成并发送验证码 + + Args: + email: 目标邮箱 + purpose: 用途(register/reset_password/change_email 等) + length: 验证码长度 + expire_seconds: 过期时间(秒),默认从配置读取 + + Returns: + 是否发送成功 + """ + expire = self._get_expire_seconds(expire_seconds) + code = self._generate_code(length) + key = self._build_key(email, purpose) + + # 存储到 Redis + if not self._redis.set(key, code, ex=expire): + logger.error(f'验证码存储失败: {email}') + return False + + # 发送邮件 + subject = current_app.config.get('VERIFICATION_EMAIL_SUBJECT', '您的验证码') + body = self._build_email_body(code, expire) + + if not self._email.send(email, subject, body): + # 邮件发送失败,清理 Redis 中的验证码 + self._redis.delete(key) + return False + + logger.info(f'验证码已发送: {email} (purpose={purpose})') + return True + + def verify_code( + self, + email: str, + code: str, + purpose: str = 'register', + delete_on_success: bool = True + ) -> bool: + """ + 校验验证码 + + Args: + email: 邮箱 + code: 用户输入的验证码 + purpose: 用途 + delete_on_success: 校验成功后是否删除 + + Returns: + 是否校验通过 + """ + key = self._build_key(email, purpose) + stored = self._redis.get(key) + + if stored is None: + logger.debug(f'验证码不存在或已过期: {email}') + return False + + if str(stored) != str(code): + logger.debug(f'验证码不匹配: {email}') + return False + + # 校验成功,删除验证码(防止重复使用) + if delete_on_success: + if not self._redis.delete(key): + logger.warning(f'验证码删除失败: {key}') + + logger.info(f'验证码校验成功: {email} (purpose={purpose})') + return True + + def clear_code(self, email: str, purpose: str = 'register') -> bool: + """ + 清除验证码(管理员操作或用户取消) + + Args: + email: 邮箱 + purpose: 用途 + + Returns: + 是否删除成功 + """ + key = self._build_key(email, purpose) + return self._redis.delete(key) + + +# ============================================================ +# 兼容层:保持原有函数接口,内部委托给 VerificationService +# 便于渐进式迁移,后续可逐步移除 +# ============================================================ + +_default_service: Optional[VerificationService] = None + + +def _get_service() -> VerificationService: + """获取默认服务实例(懒加载)""" + global _default_service + if _default_service is None: + _default_service = VerificationService() + return _default_service + + +def send_verification_code( + email: str, + purpose: str = 'register', + length: int = 6, + expire_seconds: Optional[int] = None +) -> bool: + """【兼容接口】发送验证码""" + return _get_service().send_code(email, purpose, length, expire_seconds) + + +def verify_code(email: str, code: str, purpose: str = 'register') -> bool: + """【兼容接口】校验验证码""" + return _get_service().verify_code(email, code, purpose) + + +def clear_verification_code(email: str, purpose: str = 'register') -> bool: + """【兼容接口】清除验证码""" + return _get_service().clear_code(email, purpose) diff --git a/src/backend/app/services/email_service.py b/src/backend/app/services/email_service.py deleted file mode 100644 index 53ea96f..0000000 --- a/src/backend/app/services/email_service.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -验证码服务(Redis 存储 + Flask-Mail 发送) - -提供: -- `send_verification_code(email, purpose='register', length=6, expire_seconds=None)` -- `verify_code(email, code, purpose='register')` - -依赖: -- `redis`(通过 `REDIS_URL` 配置,默认为 redis://localhost:6379/0) -- Flask-Mail 已在应用中初始化(通过 `current_app.extensions['mail']` 获取) -""" -import random -import string -import logging -from typing import Optional - -import redis -from flask import current_app -from flask_mail import Message - -logger = logging.getLogger(__name__) - -pool = redis.ConnectionPool().from_url('redis://localhost:6379/0', decode_responses=True) - -def _get_redis_client() -> redis.Redis: - """根据 REDIS_URL 创建 redis 客户端""" - redis_url = current_app.config.get('REDIS_URL', 'redis://localhost:6379/0') - - # 不再使用全局的 pool,而是每次根据 URL 获取连接(redis-py 内部自己会管理连接池) - # 或者如果你想复用连接池,应该在 app 启动时初始化 pool,而不是在全局写死 localhost - return redis.Redis.from_url(redis_url, decode_responses=True) - - -def _generate_code(length: int = 6) -> str: - """生成指定长度的数字验证码(默认 6 位)。""" - # 只产生数字字符串更常用于验证码 - return ''.join(random.choices(string.digits, k=length)) - - -def send_verification_code(email: str, - purpose: str = 'register', - length: int = 6, - expire_seconds: Optional[int] = None) -> bool: - """生成验证码,保存到 Redis,并使用 Flask-Mail 发送给 `email`。 - - 返回 True 表示发送成功,False 表示失败(或抛出异常时捕获后返回 False)。 - """ - # 读取过期时间(秒),优先使用传入值,其次使用 app config - if expire_seconds is None: - expire_seconds = current_app.config.get('VERIFICATION_CODE_EXPIRES', 300) - - code = _generate_code(length) - key = f"verify:{purpose}:{email}" - - try: - r = _get_redis_client() - # 使用字符串保存验证码,并设置过期时间 - r.set(key, code, ex=expire_seconds) - except Exception as e: - logger.exception("保存验证码到 Redis 失败") - return False - - # 尝试发送邮件 - try: - mail = current_app.extensions.get('mail') - subject = "您的验证码" # current_app.config.get('VERIFICATION_EMAIL_SUBJECT', '您的验证码') - sender = '1798231811@qq.com' # current_app.config.get('MAIL_DEFAULT_SENDER') or current_app.config.get('MAIL_USERNAME') - body = f'您的验证码为:{code},有效期 {expire_seconds} 秒。' # current_app.config.get('VERIFICATION_EMAIL_TEMPLATE', - # f'您的验证码为:{code},有效期 {expire_seconds} 秒。') - - # 优先使用简单文本邮件,项目中可按需替换为 HTML 模板 - msg = Message(subject=subject, recipients=[email], body=body, sender=sender) - - if mail is None: - # 如果 Flask-Mail 未初始化,记录日志并返回 False - logger.error('Flask-Mail 未初始化,无法发送邮件') - return False - - mail.send(msg) - logger.info('已发送验证码到 %s (purpose=%s)', email, purpose) - return True - except Exception: - logger.exception('发送验证码邮件失败') - return False - - -def verify_code(email: str, code: str, purpose: str = 'register') -> bool: - """校验验证码是否正确。成功可配置是否从 Redis 删除该 key。 - - 返回 True 表示校验通过;False 表示失败或异常。 - """ - key = f"verify:{purpose}:{email}" - try: - r = _get_redis_client() - stored = r.get(key) - if stored is None: - return False - - matched = (str(stored) == str(code)) - if matched : - try: - r.delete(key) - except Exception: - logger.warning('校验成功,但删除 Redis key 失败: %s', key) - return matched - except Exception: - logger.exception('校验验证码时发生异常') - return False - - -def clear_verification_code(email: str, purpose: str = 'register') -> bool: - """显式删除指定 email 的验证码(例如用于管理员撤销)。""" - key = f"verify:{purpose}:{email}" - try: - r = _get_redis_client() - return r.delete(key) == 1 - except Exception: - logger.exception('删除验证码失败') - return False - -if __name__ == '__main__': - # 简单测试发送和验证功能 - test_email = "3310207578@qq.com" - if send_verification_code(test_email, expire_seconds=600): - print("验证码发送成功") - code = input("请输入收到的验证码: ") - if verify_code(test_email, code): - print("验证码验证成功") - else: - print("验证码验证失败") - else: - print("验证码发送失败") \ No newline at end of file diff --git a/src/backend/app/services/image/__init__.py b/src/backend/app/services/image/__init__.py new file mode 100644 index 0000000..0676150 --- /dev/null +++ b/src/backend/app/services/image/__init__.py @@ -0,0 +1,23 @@ +""" +图片服务模块 + +按职责拆分为: +- ImageProcessor: 图片预处理(裁剪、缩放、格式转换) +- ImageStorage: 图片存储管理(保存、删除) +- ImageSerializer: 图片序列化(JSON、Base64) +- ZipService: 打包服务 +- ImagePreviewService: 预览图片服务 +""" +from .image_processor import ImageProcessor +from .image_storage import ImageStorage +from .image_serializer import ImageSerializer +from .zip_service import ZipService +from .image_preview import ImagePreviewService + +__all__ = [ + 'ImageProcessor', + 'ImageStorage', + 'ImageSerializer', + 'ZipService', + 'ImagePreviewService', +] diff --git a/src/backend/app/services/image/image_preview.py b/src/backend/app/services/image/image_preview.py new file mode 100644 index 0000000..a21ed8f --- /dev/null +++ b/src/backend/app/services/image/image_preview.py @@ -0,0 +1,125 @@ +""" +图片预览服务 + +职责单一:获取各类任务的预览图片(返回图片ID列表,前端通过二进制接口获取) +""" +import logging +from typing import Dict, List, Optional + +from app.database import Task, Image + +logger = logging.getLogger(__name__) + + +def _get_image_repo(): + """懒加载获取 ImageRepository""" + from app.repositories import ImageRepository + return ImageRepository() + + +def _get_task_repo(): + """懒加载获取 TaskRepository""" + from app.repositories import TaskRepository + return TaskRepository() + + +class ImagePreviewService: + """ + 图片预览服务 + + 负责获取各类任务的图片ID列表,前端通过 /binary/task 或 /binary/flow 接口获取二进制数据 + """ + + def _image_to_meta(self, image: Image) -> Optional[Dict]: + """将图片转换为元数据字典""" + if not image: + return None + return { + 'image_id': image.images_id, + 'filename': image.stored_filename, + 'width': image.width, + 'height': image.height + } + + def get_perturbation_preview(self, task: Task) -> Dict[str, List]: + """获取加噪任务的预览图片元数据""" + image_repo = _get_image_repo() + + originals = image_repo.get_by_task_and_type(task.tasks_id, 'original') + perturbeds = image_repo.get_by_task_and_type(task.tasks_id, 'perturbed') + + return { + 'original': [self._image_to_meta(img) for img in originals if img], + 'perturbed': [self._image_to_meta(img) for img in perturbeds if img] + } + + def get_finetune_preview(self, task: Task) -> Dict[str, List]: + """获取微调任务的预览图片元数据""" + image_repo = _get_image_repo() + task_repo = _get_task_repo() + + images = { + 'original': [], + 'original_generate': [], + 'perturbed_generate': [], + 'uploaded_generate': [] + } + + flow_tasks = task_repo.get_by_flow(task.flow_id) + for flow_task in flow_tasks: + if flow_task.user_id == task.user_id: + originals = image_repo.get_by_task_and_type(flow_task.tasks_id, 'original') + images['original'].extend([ + self._image_to_meta(img) for img in originals if img + ]) + + for type_code in ['original_generate', 'perturbed_generate', 'uploaded_generate']: + generated = image_repo.get_by_task_and_type(task.tasks_id, type_code) + images[type_code] = [self._image_to_meta(img) for img in generated if img] + + return images + + def get_heatmap_preview(self, task: Task) -> Dict[str, List]: + """获取热力图任务的预览图片元数据""" + image_repo = _get_image_repo() + heatmaps = image_repo.get_by_task_and_type(task.tasks_id, 'heatmap') + + return { + 'heatmap': [self._image_to_meta(img) for img in heatmaps if img] + } + + def get_evaluate_preview(self, task: Task) -> Dict[str, List]: + """获取评估任务的预览图片元数据""" + image_repo = _get_image_repo() + reports = image_repo.get_by_task_and_type(task.tasks_id, 'report') + + return { + 'report': [self._image_to_meta(img) for img in reports if img] + } + + def get_preview_by_task_type(self, task: Task, task_type: str) -> Dict[str, List]: + """根据任务类型获取预览图片元数据""" + handlers = { + 'perturbation': self.get_perturbation_preview, + 'finetune': self.get_finetune_preview, + 'heatmap': self.get_heatmap_preview, + 'evaluate': self.get_evaluate_preview, + } + + handler = handlers.get(task_type) + if handler: + return handler(task) + + return {} + + +# 全局单例 +_default_preview_service: Optional[ImagePreviewService] = None + + +def get_preview_service() -> ImagePreviewService: + """获取默认的预览服务实例""" + global _default_preview_service + if _default_preview_service is None: + _default_preview_service = ImagePreviewService() + return _default_preview_service diff --git a/src/backend/app/services/image/image_processor.py b/src/backend/app/services/image/image_processor.py new file mode 100644 index 0000000..d59d23f --- /dev/null +++ b/src/backend/app/services/image/image_processor.py @@ -0,0 +1,125 @@ +""" +图片预处理服务 + +职责单一:图片的裁剪、缩放、格式转换 +""" +import logging +from typing import Tuple +from PIL import Image as PILImage + +logger = logging.getLogger(__name__) + + +class ImageProcessor: + """ + 图片预处理器 + + 负责图片的: + - 中心裁剪 + - 缩放到指定尺寸 + - 格式转换 + + 使用方式: + processor = ImageProcessor(target_size=512) + pil_image = processor.process(file_storage) + processor.save(pil_image, '/path/to/output.png') + """ + + DEFAULT_SIZE = 512 + DEFAULT_FORMAT = 'PNG' + + def __init__(self, target_size: int = None): + """ + 初始化处理器 + + Args: + target_size: 目标尺寸(正方形边长),默认 512 + """ + self._target_size = target_size or self.DEFAULT_SIZE + + @property + def target_size(self) -> int: + return self._target_size + + def process_from_file(self, file_storage) -> PILImage.Image: + """ + 从上传文件处理图片 + + Args: + file_storage: Flask 文件存储对象 + + Returns: + 处理后的 PIL Image 对象 + """ + file_storage.stream.seek(0) + image = PILImage.open(file_storage.stream).convert('RGB') + return self._crop_and_resize(image) + + def process_from_path(self, file_path: str) -> PILImage.Image: + """ + 从文件路径处理图片 + + Args: + file_path: 图片文件路径 + + Returns: + 处理后的 PIL Image 对象 + """ + image = PILImage.open(file_path).convert('RGB') + return self._crop_and_resize(image) + + def _crop_and_resize(self, image: PILImage.Image) -> PILImage.Image: + """ + 中心裁剪并缩放 + + Args: + image: 原始 PIL Image + + Returns: + 处理后的 PIL Image + """ + width, height = image.size + min_dim = min(width, height) + + # 中心裁剪为正方形 + left = (width - min_dim) // 2 + top = (height - min_dim) // 2 + image = image.crop((left, top, left + min_dim, top + min_dim)) + + # 缩放到目标尺寸 + return image.resize( + (self._target_size, self._target_size), + resample=PILImage.Resampling.LANCZOS + ) + + def save( + self, + image: PILImage.Image, + output_path: str, + format: str = None, + quality: int = 95 + ) -> Tuple[int, int, int]: + """ + 保存处理后的图片 + + Args: + image: PIL Image 对象 + output_path: 输出路径 + format: 输出格式(PNG/JPEG),默认从路径推断 + quality: JPEG 质量(仅对 JPEG 有效) + + Returns: + (width, height, file_size) 元组 + """ + import os + + if format is None: + ext = os.path.splitext(output_path)[1].lower() + format = 'JPEG' if ext in ('.jpg', '.jpeg') else 'PNG' + + if format.upper() == 'JPEG': + image.save(output_path, format='JPEG', quality=quality) + else: + image.save(output_path, format=format.upper()) + + return image.width, image.height, os.path.getsize(output_path) diff --git a/src/backend/app/services/image/image_serializer.py b/src/backend/app/services/image/image_serializer.py new file mode 100644 index 0000000..a528328 --- /dev/null +++ b/src/backend/app/services/image/image_serializer.py @@ -0,0 +1,147 @@ +""" +图片序列化服务 + +职责单一:图片的序列化(JSON、二进制流) +""" +import os +import logging +from typing import Optional, Dict, Any, List, Generator + +from app.database import Image + +logger = logging.getLogger(__name__) + + +class ImageSerializer: + """ + 图片序列化器 + + 负责: + - 将图片对象转换为 JSON 字典 + - 生成 multipart/mixed 二进制流 + + 使用方式: + serializer = ImageSerializer() + data = serializer.to_dict(image) + """ + + MIME_TYPES = { + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.gif': 'image/gif', + '.bmp': 'image/bmp', + '.webp': 'image/webp', + } + + def to_dict(self, image: Image) -> Optional[Dict[str, Any]]: + """ + 将图片对象序列化为字典 + + Args: + image: Image 数据库对象 + + Returns: + 字典或 None + """ + if not image: + return None + + return { + 'image_id': image.images_id, + 'task_id': image.task_id, + 'stored_filename': image.stored_filename, + 'file_path': image.file_path, + 'file_size': image.file_size, + 'width': image.width, + 'height': image.height, + 'image_type': image.image_type.image_code if image.image_type else None + } + + def get_url(self, image: Image) -> Optional[str]: + """ + 获取图片访问 URL + + Args: + image: Image 数据库对象 + + Returns: + URL 字符串或 None + """ + if not image or not image.file_path: + return None + + return f"/api/image/file/{image.images_id}" + + def serialize_list(self, images: list) -> list: + """ + 批量序列化图片列表 + + Args: + images: Image 对象列表 + + Returns: + 序列化后的列表 + """ + return [self.to_dict(img) for img in images if img] + + def generate_multipart_stream( + self, + images_dict: Dict[str, List[Image]], + boundary: str + ) -> Generator[bytes, None, None]: + """ + 生成 multipart/mixed 格式的二进制流 + + Args: + images_dict: 图片字典 {type_code: [Image, ...]} + boundary: multipart 边界字符串 + + Yields: + 二进制数据块 + """ + for type_code, images in images_dict.items(): + for image in images: + if not image or not image.file_path: + continue + if not os.path.exists(image.file_path): + logger.warning(f"图片文件不存在: {image.file_path}") + continue + + ext = os.path.splitext(image.file_path)[1].lower() + mimetype = self.MIME_TYPES.get(ext, 'application/octet-stream') + + # 构建 part header + header = ( + f"--{boundary}\r\n" + f"Content-Type: {mimetype}\r\n" + f"Content-Disposition: attachment; filename=\"{image.stored_filename}\"\r\n" + f"X-Image-Id: {image.images_id}\r\n" + f"X-Image-Type: {type_code}\r\n" + f"X-Image-Width: {image.width or 0}\r\n" + f"X-Image-Height: {image.height or 0}\r\n" + f"\r\n" + ) + yield header.encode('utf-8') + + # 流式读取文件内容 + with open(image.file_path, 'rb') as f: + while chunk := f.read(8192): + yield chunk + + yield b"\r\n" + + # 结束边界 + yield f"--{boundary}--\r\n".encode('utf-8') + + +# 全局单例 +_default_serializer: Optional[ImageSerializer] = None + + +def get_image_serializer() -> ImageSerializer: + """获取默认的序列化器实例""" + global _default_serializer + if _default_serializer is None: + _default_serializer = ImageSerializer() + return _default_serializer diff --git a/src/backend/app/services/image/image_storage.py b/src/backend/app/services/image/image_storage.py new file mode 100644 index 0000000..ec2beb5 --- /dev/null +++ b/src/backend/app/services/image/image_storage.py @@ -0,0 +1,280 @@ +""" +图片存储服务 + +职责单一:图片的保存、删除、文件管理 +""" +import os +import uuid +import logging +from typing import Optional, List, Tuple +from datetime import datetime + +from app import db +from app.database import Image, ImageType +from app.utils.file_utils import allowed_file +from app.services.image.image_processor import ImageProcessor + +logger = logging.getLogger(__name__) + + +def _get_image_repo(): + """懒加载获取 ImageRepository""" + from app.repositories import ImageRepository + return ImageRepository() + + +def _get_image_type_repo(): + """懒加载获取 ImageTypeRepository""" + from app.repositories import ImageTypeRepository + return ImageTypeRepository() + + +class ImageStorage: + """ + 图片存储管理器 + + 负责: + - 保存上传的图片到指定目录 + - 创建数据库记录 + - 删除图片文件和记录 + - 管理临时文件 + + 使用方式: + storage = ImageStorage() + result = storage.save_uploaded_image(file, task, target_dir) + storage.delete_image(image_id, user_id) + """ + + IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.gif', '.bmp', '.tiff'} + + def __init__(self, processor: Optional[ImageProcessor] = None, image_repo=None): + """ + 初始化存储管理器 + + Args: + processor: 图片处理器,默认创建新实例 + image_repo: 图片 Repository,默认懒加载 + """ + self._processor = processor or ImageProcessor() + self._image_repo = image_repo + + @property + def processor(self) -> ImageProcessor: + return self._processor + + @property + def image_repo(self): + """懒加载 ImageRepository""" + if self._image_repo is None: + self._image_repo = _get_image_repo() + return self._image_repo + + def save_uploaded_image( + self, + file, + task, + target_dir: str, + image_type_code: str = 'original' + ) -> dict: + """ + 保存单张上传的图片 + + Args: + file: 上传的文件对象 + task: 关联的任务对象 + target_dir: 目标存储目录 + image_type_code: 图片类型代码 + + Returns: + {'success': True, 'image': Image} 或 {'success': False, 'error': str} + """ + if not file or not file.filename: + return {'success': False, 'error': '无效的文件'} + + if not allowed_file(file.filename): + return {'success': False, 'error': '不支持的文件格式'} + + ext = os.path.splitext(file.filename)[1].lower() + if ext not in self.IMAGE_EXTENSIONS: + return {'success': False, 'error': f'不支持的图片格式: {ext}'} + + image_type = _get_image_type_repo().get_by_code(image_type_code) + if not image_type: + return {'success': False, 'error': f'未配置图片类型: {image_type_code}'} + + os.makedirs(target_dir, exist_ok=True) + + try: + # 处理图片 + processed = self._processor.process_from_file(file) + + # 生成文件名并保存 + filename, path, width, height, file_size = self._save_with_unique_name( + processed, target_dir + ) + + # 创建数据库记录 + image = self._create_record( + task=task, + image_type_id=image_type.image_types_id, + filename=filename, + path=path, + width=width, + height=height, + file_size=file_size + ) + + db.session.commit() + return {'success': True, 'image': image} + + except Exception as e: + db.session.rollback() + logger.error(f"保存图片失败: {e}") + return {'success': False, 'error': f'保存图片失败: {str(e)}'} + + def save_multiple_images( + self, + files: List, + task, + target_dir: str, + image_type_code: str = 'original' + ) -> Tuple[bool, any]: + """ + 批量保存上传的图片 + + Args: + files: 文件列表 + task: 关联的任务对象 + target_dir: 目标存储目录 + image_type_code: 图片类型代码 + + Returns: + (success, result) - result 是 Image 列表或错误信息 + """ + if not files: + return False, '未检测到文件上传' + + image_type = _get_image_type_repo().get_by_code(image_type_code) + if not image_type: + return False, f'未配置图片类型: {image_type_code}' + + os.makedirs(target_dir, exist_ok=True) + + saved_records = [] + saved_paths = [] + + try: + for file in files: + if not file or not file.filename: + continue + if not allowed_file(file.filename): + continue + + ext = os.path.splitext(file.filename)[1].lower() + if ext not in self.IMAGE_EXTENSIONS: + continue + + processed = self._processor.process_from_file(file) + filename, path, width, height, file_size = self._save_with_unique_name( + processed, target_dir + ) + + image = self._create_record( + task=task, + image_type_id=image_type.image_types_id, + filename=filename, + path=path, + width=width, + height=height, + file_size=file_size + ) + saved_records.append(image) + saved_paths.append(path) + + if not saved_records: + db.session.rollback() + return False, '未上传有效的图片文件' + + db.session.commit() + return True, saved_records + + except Exception as e: + db.session.rollback() + # 清理已保存的文件 + for path in saved_paths: + if os.path.exists(path): + try: + os.remove(path) + except OSError: + pass + return False, f'上传图片失败: {str(e)}' + + def delete_image(self, image_id: int, user_id: int) -> dict: + """ + 删除图片(验证权限) + + Args: + image_id: 图片 ID + user_id: 用户 ID(用于权限验证) + + Returns: + {'success': True} 或 {'success': False, 'error': str} + """ + try: + # 使用 Repository 获取并验证权限 + image = self.image_repo.get_for_user(image_id, user_id) + if not image: + return {'success': False, 'error': '图片不存在或无权限'} + + # 删除文件 + if image.file_path and os.path.exists(image.file_path): + os.remove(image.file_path) + + # 使用 Repository 删除记录 + if self.image_repo.delete(image) and self.image_repo.save(): + return {'success': True} + return {'success': False, 'error': '删除记录失败'} + + except Exception as e: + self.image_repo.rollback() + logger.error(f"删除图片失败: {e}") + return {'success': False, 'error': f'删除图片失败: {str(e)}'} + + def _save_with_unique_name(self, image, target_dir: str) -> Tuple[str, str, int, int, int]: + """保存图片并生成唯一文件名""" + timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f') + filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png" + path = os.path.join(target_dir, filename) + + width, height, file_size = self._processor.save(image, path) + return filename, path, width, height, file_size + + def _create_record( + self, + task, + image_type_id: int, + filename: str, + path: str, + width: int, + height: int, + file_size: int, + father_id: int = None + ) -> Image: + """创建数据库记录""" + image = Image( + task_id=task.tasks_id, + image_types_id=image_type_id, + father_id=father_id, + stored_filename=filename, + file_path=path, + file_size=file_size, + width=width, + height=height + ) + db.session.add(image) + return image + + @staticmethod + def get_image_type_by_code(code: str) -> Optional[ImageType]: + """根据代码获取图片类型(委托给 Repository)""" + return _get_image_type_repo().get_by_code(code) diff --git a/src/backend/app/services/image/zip_service.py b/src/backend/app/services/image/zip_service.py new file mode 100644 index 0000000..3dcf28b --- /dev/null +++ b/src/backend/app/services/image/zip_service.py @@ -0,0 +1,139 @@ +""" +打包服务 + +职责单一:目录和文件的 ZIP 打包 +""" +import os +import io +import zipfile +import logging +from typing import Union, Dict, List, Tuple + +logger = logging.getLogger(__name__) + + +class ZipService: + """ + ZIP 打包服务 + + 负责: + - 将单个目录打包为 ZIP + - 将多个目录打包为 ZIP + + 使用方式: + zip_service = ZipService() + buffer, has_files = zip_service.zip_directory('/path/to/dir') + buffer, has_files = zip_service.zip_multiple({'label1': '/path1', 'label2': '/path2'}) + """ + + def zip_directory(self, directory: str) -> Tuple[io.BytesIO, bool]: + """ + 将单个目录打包为 ZIP + + Args: + directory: 目录路径 + + Returns: + (BytesIO 缓冲区, 是否包含文件) + """ + buffer = io.BytesIO() + has_files = False + + with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf: + if os.path.isdir(directory): + for root, _, files in os.walk(directory): + for filename in files: + file_path = os.path.join(root, filename) + arcname = os.path.relpath(file_path, directory) + zipf.write(file_path, arcname) + has_files = True + + buffer.seek(0) + return buffer, has_files + + def zip_multiple( + self, + directories: Union[Dict[str, str], List[str]] + ) -> Tuple[io.BytesIO, bool]: + """ + 将多个目录打包为 ZIP + + Args: + directories: 目录字典 {label: path} 或目录列表 + + Returns: + (BytesIO 缓冲区, 是否包含文件) + """ + buffer = io.BytesIO() + has_files = False + + with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf: + # 统一处理字典和列表 + if isinstance(directories, dict): + iterable = directories.items() + else: + iterable = ( + (os.path.basename(d.rstrip(os.sep)) or 'output', d) + for d in directories + ) + + for label, directory in iterable: + if not os.path.isdir(directory): + continue + + for root, _, files in os.walk(directory): + for filename in files: + file_path = os.path.join(root, filename) + rel_path = os.path.relpath(file_path, directory) + arcname = os.path.join(label or 'output', rel_path) + zipf.write(file_path, arcname) + has_files = True + + buffer.seek(0) + return buffer, has_files + + def zip_files( + self, + files: List[str], + base_dir: str = None + ) -> Tuple[io.BytesIO, bool]: + """ + 将文件列表打包为 ZIP + + Args: + files: 文件路径列表 + base_dir: 基础目录(用于计算相对路径) + + Returns: + (BytesIO 缓冲区, 是否包含文件) + """ + buffer = io.BytesIO() + has_files = False + + with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf: + for file_path in files: + if not os.path.isfile(file_path): + continue + + if base_dir: + arcname = os.path.relpath(file_path, base_dir) + else: + arcname = os.path.basename(file_path) + + zipf.write(file_path, arcname) + has_files = True + + buffer.seek(0) + return buffer, has_files + + +# 全局单例 +_default_zip_service: ZipService = None + + +def get_zip_service() -> ZipService: + """获取默认的打包服务实例""" + global _default_zip_service + if _default_zip_service is None: + _default_zip_service = ZipService() + return _default_zip_service diff --git a/src/backend/app/services/image_service.py b/src/backend/app/services/image_service.py index 4471b79..f4021ed 100644 --- a/src/backend/app/services/image_service.py +++ b/src/backend/app/services/image_service.py @@ -1,212 +1,227 @@ """ -图像处理服务 -处理图像上传、保存等功能 +图像处理服务(兼容入口) + +已重构为面向对象设计,此文件保留原有接口以保持向后兼容。 + +新代码请直接使用: + from app.services.image import ( + ImageProcessor, # 图片预处理 + ImageStorage, # 图片存储 + ImageSerializer, # 图片序列化 + ZipService, # 打包服务 + ImagePreviewService # 预览服务 + ) + +相关类: + - ImageProcessor: 裁剪、缩放、格式转换 + - ImageStorage: 保存、删除、文件管理 + - ImageSerializer: JSON、Base64 序列化 + - ZipService: ZIP 打包 + - ImagePreviewService: 任务预览图片 """ - -import base64 -import io import os -import uuid -import zipfile -import time -from datetime import datetime -from werkzeug.utils import secure_filename -from flask import current_app, jsonify -from PIL import Image as PILImage +import logging +from typing import Optional, List, Tuple, Dict, Any +from flask import jsonify + from app import db from app.database import Image, ImageType from app.utils.file_utils import allowed_file +from app.services.image.image_processor import ImageProcessor +from app.services.image.image_storage import ImageStorage +from app.services.image.image_serializer import ImageSerializer, get_image_serializer +from app.services.image.zip_service import ZipService, get_zip_service +from app.services.image.image_preview import ImagePreviewService, get_preview_service + +logger = logging.getLogger(__name__) + +# 全局实例 +_storage: Optional[ImageStorage] = None +_serializer: Optional[ImageSerializer] = None +_zip_service: Optional[ZipService] = None +_preview_service: Optional[ImagePreviewService] = None + + +def _get_storage() -> ImageStorage: + global _storage + if _storage is None: + _storage = ImageStorage() + return _storage + + +def _get_serializer() -> ImageSerializer: + global _serializer + if _serializer is None: + _serializer = ImageSerializer() + return _serializer + + +def _get_zip_service() -> ZipService: + global _zip_service + if _zip_service is None: + _zip_service = ZipService() + return _zip_service + + +def _get_preview_service() -> ImagePreviewService: + global _preview_service + if _preview_service is None: + _preview_service = ImagePreviewService() + return _preview_service + class ImageService: + """ + 图像处理服务(兼容类) + + 内部委托给新的服务类,保持原有 API 不变 + """ + + DEFAULT_TARGET_SIZE = 512 + IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp'} + + # ==================== 存储相关(委托给 ImageStorage)==================== + + @staticmethod + def save_image(file, task_id, user_id, image_types_id, resolution=512, target_format='png'): + """保存单张图片""" + # 简化实现,委托给新服务 + from app.database import Task + task = Task.query.get(task_id) + if not task: + return {'success': False, 'error': '任务不存在'} + + from flask import current_app + project_root = os.path.dirname(current_app.root_path) + target_dir = os.path.join( + project_root, + current_app.config.get('ORIGINAL_IMAGES_FOLDER', 'static/originals'), + str(user_id), + str(task_id) + ) + + return _get_storage().save_uploaded_image(file, task, target_dir) + + @staticmethod + def save_original_images(task, files, target_dir, image_type_code='original', target_size=None): + """保存原图上传""" + return _get_storage().save_multiple_images(files, task, target_dir, image_type_code) + + @staticmethod + def delete_image(image_id, user_id): + """删除图片""" + return _get_storage().delete_image(image_id, user_id) + + @staticmethod + def get_image_type_by_code(code): + """根据代码获取图片类型""" + return ImageStorage.get_image_type_by_code(code) + + # ==================== 序列化相关(委托给 ImageSerializer)==================== + + @staticmethod + def serialize_image(image): + """图片序列化""" + return _get_serializer().to_dict(image) + + @staticmethod + def get_image_url(image): + """获取图片访问URL""" + return _get_serializer().get_url(image) + + # ==================== 打包相关(委托给 ZipService)==================== + + @staticmethod + def zip_directory(directory): + """打包目录为zip""" + return _get_zip_service().zip_directory(directory) + + @staticmethod + def zip_multiple_directories(directories): + """打包多个目录""" + return _get_zip_service().zip_multiple(directories) + + + + # ==================== 工具方法 ==================== + + @staticmethod + def json_error(message, status_code=400): + """统一错误响应""" + return jsonify({'error': message}), status_code + + # ==================== 保留的原有方法(复杂逻辑暂不迁移)==================== + @staticmethod def save_to_uploads(file, task_id, user_id): - """ - 上传图片到uploads临时目录,返回临时文件路径和原始文件名。 - """ + """上传图片到uploads临时目录""" + import uuid + from flask import current_app + project_root = os.path.dirname(current_app.root_path) - upload_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(task_id)) + upload_dir = os.path.join( + project_root, + current_app.config['UPLOAD_FOLDER'], + str(user_id), + str(task_id) + ) os.makedirs(upload_dir, exist_ok=True) + orig_ext = os.path.splitext(file.filename)[1].lower() temp_name = f"{uuid.uuid4().hex}{orig_ext}" temp_path = os.path.join(upload_dir, temp_name) file.save(temp_path) - return temp_path, file.filename - - @staticmethod - def preprocess_image(temp_path, original_filename, task_id, user_id, image_types_id, resolution=512, target_format='png'): - """ - 对图片进行中心裁剪、缩放、格式转换、重命名,保存到static/originals,返回数据库对象。 - 原图命名格式: 0000.png, 0001.png, ..., 9999.png - 使用数据库事务和重试机制确保并发安全 - """ - final_path = None - max_retries = 50 - try: - img = PILImage.open(temp_path).convert("RGB") - width, height = img.size - min_dim = min(width, height) - left = (width - min_dim) // 2 - top = (height - min_dim) // 2 - right = left + min_dim - bottom = top + min_dim - img = img.crop((left, top, right, bottom)) - img = img.resize((resolution, resolution), resample=PILImage.Resampling.LANCZOS) - - project_root = os.path.dirname(current_app.root_path) - static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(task_id)) - os.makedirs(static_dir, exist_ok=True) - - from app.database import ImageType - original_type = ImageType.query.filter_by(image_code='original').first() - target_image_types_id = original_type.image_types_id if original_type else image_types_id - - # 首次查询最大序号 - max_seq_result = db.session.execute( - db.text(""" - SELECT COALESCE(MAX(CAST(SUBSTRING_INDEX(stored_filename, '.', 1) AS UNSIGNED)), -1) as max_seq - FROM images - WHERE task_id = :task_id - AND image_types_id = :image_types_id - AND stored_filename REGEXP '^[0-9]{4}\\.' - """), - {'task_id': task_id, 'image_types_id': target_image_types_id} - ).fetchone() - - # 强制类型转换,确保安全 - try: - base_sequence = int(max_seq_result[0]) if max_seq_result[0] is not None else -1 - except Exception: - base_sequence = -1 - base_sequence += 1 - - # 重试机制:从base_sequence开始尝试连续的序号 - for attempt in range(max_retries): - sequence_number = int(base_sequence) + int(attempt) - fmt_str = str(target_format).lower() if target_format else 'png' - new_name = f"{sequence_number:04d}.{fmt_str}" - final_path = os.path.join(static_dir, new_name) - - try: - # 检查数据库中是否已存在此文件名 - existing = Image.query.filter_by( - task_id=task_id, - stored_filename=new_name - ).first() - - if existing: - # 已存在,尝试下一个序号 - continue - - # 保存图片文件 - if target_format.lower() in ['jpg', 'jpeg']: - img.save(final_path, format='JPEG', quality=95) - else: - img.save(final_path, format=target_format.upper()) - - # 创建数据库记录 - image = Image( - task_id=task_id, - image_types_id=image_types_id, - stored_filename=new_name, - file_path=final_path, - file_size=os.path.getsize(final_path), - width=img.width, - height=img.height - ) - db.session.add(image) - db.session.commit() - - # 删除临时文件 - if os.path.exists(temp_path): - os.remove(temp_path) - - return {'success': True, 'image': image} - - except Exception as e: - db.session.rollback() - error_msg = str(e) - - # 如果是唯一性冲突,清理文件并尝试下一个序号 - if 'Duplicate entry' in error_msg or '1062' in error_msg: - if final_path and os.path.exists(final_path): - try: - os.remove(final_path) - except Exception: - pass - # 继续循环尝试下一个序号 - time.sleep(0.005) - continue - else: - # 其他错误直接抛出 - raise - - # 所有尝试都失败 - raise Exception(f"无法生成唯一文件名,已尝试序号 {base_sequence} 到 {base_sequence + max_retries - 1}") - - except Exception as e: - db.session.rollback() - # 清理可能已保存的文件 - if final_path and os.path.exists(final_path): - try: - os.remove(final_path) - except Exception: - pass - return {'success': False, 'error': f'图片预处理失败: {str(e)}'} - - @staticmethod - def save_image(file, task_id, user_id, image_types_id, resolution=512, target_format='png'): - """保存单张图片,自动上传到uploads并预处理""" - try: - if not file or not allowed_file(file.filename): - return {'success': False, 'error': '不支持的文件格式'} - temp_path, orig_name = ImageService.save_to_uploads(file, task_id, user_id) - return ImageService.preprocess_image(temp_path, orig_name, task_id, user_id, image_types_id, resolution, target_format) - except Exception as e: - db.session.rollback() - return {'success': False, 'error': f'保存图片失败: {str(e)}'} + return temp_path, file.filename @staticmethod def extract_and_save_zip(zip_file, task_id, user_id, image_types_id): """解压并保存压缩包中的图片""" + import uuid + import zipfile + import shutil + from flask import current_app + from werkzeug.utils import secure_filename + results = [] temp_dir = None try: - # 创建临时目录 project_root = os.path.dirname(current_app.root_path) - temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], 'temp', f"{uuid.uuid4().hex}") + temp_dir = os.path.join( + project_root, + current_app.config['UPLOAD_FOLDER'], + 'temp', + f"{uuid.uuid4().hex}" + ) os.makedirs(temp_dir, exist_ok=True) - # 保存压缩包 zip_path = os.path.join(temp_dir, secure_filename(zip_file.filename)) zip_file.save(zip_path) - # 解压文件 with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(temp_dir) - # 遍历解压的文件 for root, dirs, files in os.walk(temp_dir): for filename in files: if filename.lower().endswith(('.zip', '.rar')): - continue # 跳过压缩包文件本身 + continue if allowed_file(filename): file_path = os.path.join(root, filename) - # 创建虚拟文件对象 class FileWrapper: def __init__(self, path, name): self.path = path self.filename = name def save(self, destination): - import shutil shutil.copy2(self.path, destination) virtual_file = FileWrapper(file_path, filename) - result = ImageService.save_image(virtual_file, task_id, user_id, image_types_id) + result = ImageService.save_image( + virtual_file, task_id, user_id, image_types_id + ) results.append(result) return results @@ -215,7 +230,6 @@ class ImageService: return [{'success': False, 'error': f'解压失败: {str(e)}'}] finally: - # 清理临时文件 if temp_dir and os.path.exists(temp_dir): import shutil try: @@ -223,134 +237,25 @@ class ImageService: except Exception: pass - @staticmethod - def get_image_url(image): - """获取图片访问URL""" - if not image or not image.file_path: - return None - - # 这里返回相对路径,前端可以拼接完整URL - return f"/api/image/file/{image.images_id}" - - @staticmethod - def delete_image(image_id, user_id): - """删除图片(通过关联的task验证权限)""" - try: - image = Image.query.filter_by(images_id=image_id).first() - if not image: - return {'success': False, 'error': '图片不存在'} - - # 通过关联的task验证用户权限 - if not image.task or image.task.user_id != user_id: - return {'success': False, 'error': '无权限删除该图片'} - - # 删除文件 - if os.path.exists(image.file_path): - os.remove(image.file_path) - - # 删除数据库记录 - db.session.delete(image) - db.session.commit() - - return {'success': True} - - except Exception as e: - db.session.rollback() - return {'success': False, 'error': f'删除图片失败: {str(e)}'} - - # ==================== 控制器辅助功能 ==================== - - DEFAULT_TARGET_SIZE = 512 - IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp'} - - @staticmethod - def json_error(message, status_code=400): - """统一错误响应""" - return jsonify({'error': message}), status_code - - @staticmethod - def get_image_type_by_code(code): - """根据代码获取图片类型""" - return ImageType.query.filter_by(image_code=code).first() - - @staticmethod - def save_original_images(task, files, target_dir, image_type_code='original', target_size=None): - """保存原图上传""" - if not files: - return False, '未检测到文件上传' - - image_type = ImageService.get_image_type_by_code(image_type_code) - if not image_type: - return False, f'未配置图片类型: {image_type_code}' - - os.makedirs(target_dir, exist_ok=True) - - saved_records = [] - saved_paths = [] - size = target_size or ImageService.DEFAULT_TARGET_SIZE - - try: - for file in files: - if not file or not file.filename: - continue - if not allowed_file(file.filename): - continue - - extension = os.path.splitext(file.filename)[1].lower() - if extension not in ImageService.IMAGE_EXTENSIONS: - continue - - processed = ImageService._prepare_image(file, size) - filename, path, width, height, file_size = ImageService._save_processed_image(processed, target_dir) - image = ImageService._create_image_record( - task, - image_type.image_types_id, - filename, - path, - width, - height, - file_size - ) - saved_records.append(image) - saved_paths.append(path) - - if not saved_records: - db.session.rollback() - return False, '未上传有效的图片文件' - - db.session.commit() - return True, saved_records - except Exception as exc: - db.session.rollback() - for path in saved_paths: - if os.path.exists(path): - try: - os.remove(path) - except OSError: - pass - return False, f'上传图片失败: {exc}' - @staticmethod def _prepare_image(file_storage, target_size): """裁剪并缩放上传图片""" - file_storage.stream.seek(0) - image = PILImage.open(file_storage.stream).convert('RGB') - width, height = image.size - min_dim = min(width, height) - left = (width - min_dim) // 2 - top = (height - min_dim) // 2 - image = image.crop((left, top, left + min_dim, top + min_dim)) - return image.resize((target_size, target_size), resample=PILImage.Resampling.LANCZOS) - + processor = ImageProcessor(target_size=target_size) + return processor.process_from_file(file_storage) + @staticmethod def _save_processed_image(image, target_dir): """将处理后的图片保存为PNG""" + import uuid + from datetime import datetime + timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f') filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png" path = os.path.join(target_dir, filename) image.save(path, format='PNG') + return filename, path, image.width, image.height, os.path.getsize(path) - + @staticmethod def _create_image_record(task, image_type_id, filename, path, width, height, file_size, father_id=None): """创建图片数据库记录""" @@ -366,184 +271,3 @@ class ImageService: ) db.session.add(image) return image - - @staticmethod - def zip_directory(directory): - """打包目录为zip""" - buffer = io.BytesIO() - has_files = False - - with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf: - if os.path.isdir(directory): - for root, _, files in os.walk(directory): - for filename in files: - file_path = os.path.join(root, filename) - arcname = os.path.relpath(file_path, directory) - zipf.write(file_path, arcname) - has_files = True - - buffer.seek(0) - return buffer, has_files - - @staticmethod - def zip_multiple_directories(directories): - """打包多个目录""" - buffer = io.BytesIO() - has_files = False - - with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf: - if isinstance(directories, dict): - iterable = directories.items() - else: - iterable = ((os.path.basename(d.rstrip(os.sep)) or 'output', d) for d in directories) - - for label, directory in iterable: - if not os.path.isdir(directory): - continue - for root, _, files in os.walk(directory): - for filename in files: - file_path = os.path.join(root, filename) - rel_path = os.path.relpath(file_path, directory) - arcname = os.path.join(label or 'output', rel_path) - zipf.write(file_path, arcname) - has_files = True - - buffer.seek(0) - return buffer, has_files - - @staticmethod - def serialize_image(image): - """图片序列化""" - if not image: - return None - return { - 'image_id': image.images_id, - 'task_id': image.task_id, - 'stored_filename': image.stored_filename, - 'file_path': image.file_path, - 'file_size': image.file_size, - 'width': image.width, - 'height': image.height, - 'image_type': image.image_type.image_code if image.image_type else None - } - - @staticmethod - def image_to_base64(image): - """将图片转换为 base64 编码""" - if not image or not os.path.exists(image.file_path): - return None - - ext = os.path.splitext(image.file_path)[1].lower() - mime_types = { - '.png': 'image/png', - '.jpg': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.gif': 'image/gif', - '.bmp': 'image/bmp', - '.webp': 'image/webp', - } - mimetype = mime_types.get(ext, 'image/png') - - with open(image.file_path, 'rb') as f: - data = base64.b64encode(f.read()).decode('utf-8') - - return { - 'image_id': image.images_id, - 'filename': image.stored_filename, - 'data': f'data:{mimetype};base64,{data}', - 'width': image.width, - 'height': image.height - } - - ## ==================== 获取预览图片服务 ==================== - def _get_perturbation_preview(task): - """获取加噪任务的预览图片""" - images = {'original': [], 'perturbed': []} - - original_type = ImageType.query.filter_by(image_code='original').first() - perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() - - if original_type: - originals = Image.query.filter_by( - task_id=task.tasks_id, - image_types_id=original_type.image_types_id - ).all() - images['original'] = [ImageService.image_to_base64(img) for img in originals] - - if perturbed_type: - perturbeds = Image.query.filter_by( - task_id=task.tasks_id, - image_types_id=perturbed_type.image_types_id - ).all() - images['perturbed'] = [ImageService.image_to_base64(img) for img in perturbeds] - - return images - - - def _get_finetune_preview(task): - """获取微调任务的预览图片""" - images = { - 'original': [], - 'original_generate': [], - 'perturbed_generate': [], - 'uploaded_generate': [] - } - - # 获取原图(从同一flow_id的perturbation任务或当前任务) - original_type = ImageType.query.filter_by(image_code='original').first() - if original_type: - # 查找同flow下的原图 - from app.database import Task - flow_tasks = Task.query.filter_by(flow_id=task.flow_id, user_id=task.user_id).all() - task_ids = [t.tasks_id for t in flow_tasks] - originals = Image.query.filter( - Image.task_id.in_(task_ids), - Image.image_types_id == original_type.image_types_id - ).all() - images['original'] = [ImageService.image_to_base64(img) for img in originals] - - # 获取生成图 - for type_code in ['original_generate', 'perturbed_generate', 'uploaded_generate']: - img_type = ImageType.query.filter_by(image_code=type_code).first() - if img_type: - generated = Image.query.filter_by( - task_id=task.tasks_id, - image_types_id=img_type.image_types_id - ).all() - images[type_code] = [ImageService.image_to_base64(img) for img in generated] - - return images - - - def _get_heatmap_preview(task): - """获取热力图任务的预览图片(热力图本身已包含原图和加噪图的对比)""" - images = {'heatmap': []} - - # 获取热力图(已是完整的对比报告图) - heatmap_type = ImageType.query.filter_by(image_code='heatmap').first() - if heatmap_type: - heatmaps = Image.query.filter_by( - task_id=task.tasks_id, - image_types_id=heatmap_type.image_types_id - ).all() - images['heatmap'] = [ImageService.image_to_base64(img) for img in heatmaps] - - return images - - - def _get_evaluate_preview(task): - """获取评估任务的预览图片""" - images = { - 'report': [] - } - - # 获取报告图 - report_type = ImageType.query.filter_by(image_code='report').first() - if report_type: - reports = Image.query.filter_by( - task_id=task.tasks_id, - image_types_id=report_type.image_types_id - ).all() - images['report'] = [ImageService.image_to_base64(img) for img in reports] - - return images \ No newline at end of file diff --git a/src/backend/app/services/storage/__init__.py b/src/backend/app/services/storage/__init__.py new file mode 100644 index 0000000..dca7c38 --- /dev/null +++ b/src/backend/app/services/storage/__init__.py @@ -0,0 +1,4 @@ +"""存储服务模块""" +from .path_manager import PathManager + +__all__ = ['PathManager'] diff --git a/src/backend/app/services/storage/path_manager.py b/src/backend/app/services/storage/path_manager.py new file mode 100644 index 0000000..93078fc --- /dev/null +++ b/src/backend/app/services/storage/path_manager.py @@ -0,0 +1,310 @@ +""" +路径管理器 + +职责单一:统一管理项目中所有路径的生成逻辑 +遵循开闭原则:新增路径类型只需添加方法,无需修改现有代码 +""" +import os +from typing import Union +from flask import current_app +from config.settings import Config + + +class PathManager: + """ + 路径管理器 + + 统一管理所有文件存储路径的生成,包括: + - 原图路径 + - 加噪图路径 + - 生成图路径(原图/加噪/上传) + - 热力图路径 + - 评估结果路径 + - 类别数据路径 + - 模型数据路径 + - 坐标文件路径 + + 使用方式: + pm = PathManager() + path = pm.get_original_images_path(user_id=1, flow_id=123) + """ + + def __init__(self, project_root: str = None): + """ + 初始化路径管理器 + + Args: + project_root: 项目根目录,默认从 Flask app 获取 + """ + self._project_root = project_root + + @property + def project_root(self) -> str: + """获取项目根目录(懒加载)""" + if self._project_root is None: + self._project_root = os.path.dirname(current_app.root_path) + return self._project_root + + def _build_path(self, *parts: Union[str, int]) -> str: + """ + 构建完整路径 + + Args: + *parts: 路径组成部分,会自动转换为字符串 + + Returns: + 完整的绝对路径 + """ + str_parts = [str(p) for p in parts] + return os.path.join(self.project_root, *str_parts) + + # ==================== 图片存储路径 ==================== + + def get_original_images_path(self, user_id: int, flow_id: int) -> str: + """ + 原图存储路径 + + 格式: {ORIGINAL_IMAGES_FOLDER}/{user_id}/{flow_id} + """ + return self._build_path( + Config.ORIGINAL_IMAGES_FOLDER, + user_id, + flow_id + ) + + def get_perturbed_images_path(self, user_id: int, flow_id: int) -> str: + """ + 加噪图存储路径 + + 格式: {PERTURBED_IMAGES_FOLDER}/{user_id}/{flow_id} + """ + return self._build_path( + Config.PERTURBED_IMAGES_FOLDER, + user_id, + flow_id + ) + + # ==================== 生成图存储路径 ==================== + + def get_original_generated_path( + self, + user_id: int, + flow_id: int, + task_id: int + ) -> str: + """ + 原图生成图存储路径 + + 格式: {MODEL_ORIGINAL_FOLDER}/{user_id}/{flow_id}/{task_id} + """ + return self._build_path( + Config.MODEL_ORIGINAL_FOLDER, + user_id, + flow_id, + task_id + ) + + def get_perturbed_generated_path( + self, + user_id: int, + flow_id: int, + task_id: int + ) -> str: + """ + 加噪图生成图存储路径 + + 格式: {MODEL_PERTURBED_FOLDER}/{user_id}/{flow_id}/{task_id} + """ + return self._build_path( + Config.MODEL_PERTURBED_FOLDER, + user_id, + flow_id, + task_id + ) + + def get_uploaded_generated_path( + self, + user_id: int, + flow_id: int, + task_id: int + ) -> str: + """ + 上传图生成图存储路径 + + 格式: {MODEL_UPLOADED_FOLDER}/{user_id}/{flow_id}/{task_id} + """ + return self._build_path( + Config.MODEL_UPLOADED_FOLDER, + user_id, + flow_id, + task_id + ) + + # ==================== 结果存储路径 ==================== + + def get_heatmap_path( + self, + user_id: int, + flow_id: int, + task_id: int + ) -> str: + """ + 热力图存储路径 + + 格式: {HEATDIF_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id} + """ + return self._build_path( + Config.HEATDIF_SAVE_FOLDER, + user_id, + flow_id, + task_id + ) + + def get_evaluate_path( + self, + user_id: int, + flow_id: int, + task_id: int + ) -> str: + """ + 评估结果存储路径 + + 格式: {NUMBERS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id} + """ + return self._build_path( + Config.NUMBERS_SAVE_FOLDER, + user_id, + flow_id, + task_id + ) + + # ==================== 数据路径 ==================== + + def get_class_data_path(self, user_id: int, flow_id: int) -> str: + """ + 类别数据存储路径 + + 格式: {CLASS_DATA_FOLDER}/{user_id}/{flow_id} + """ + return self._build_path( + Config.CLASS_DATA_FOLDER, + user_id, + flow_id + ) + + def get_model_data_path(self) -> str: + """ + 模型数据存储路径(全局共享) + + 格式: {MODEL_DATA_FOLDER} + """ + return self._build_path(Config.MODEL_DATA_FOLDER) + + # ==================== 坐标文件路径 ==================== + + def get_coords_base_path( + self, + user_id: int, + flow_id: int, + task_id: int + ) -> str: + """ + 坐标文件基础路径 + + 格式: {COORDS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id} + """ + return self._build_path( + Config.COORDS_SAVE_FOLDER, + user_id, + flow_id, + task_id + ) + + def get_original_coords_path( + self, + user_id: int, + flow_id: int, + task_id: int + ) -> str: + """ + 原图坐标文件路径(3D可视化用) + + 格式: {COORDS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id}/original_coords.csv + """ + return os.path.join( + self.get_coords_base_path(user_id, flow_id, task_id), + 'original_coords.csv' + ) + + def get_perturbed_coords_path( + self, + user_id: int, + flow_id: int, + task_id: int + ) -> str: + """ + 加噪图坐标文件路径(3D可视化用) + + 格式: {COORDS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id}/perturbed_coords.csv + """ + return os.path.join( + self.get_coords_base_path(user_id, flow_id, task_id), + 'perturbed_coords.csv' + ) + + def get_uploaded_coords_path( + self, + user_id: int, + flow_id: int, + task_id: int + ) -> str: + """ + 上传图坐标文件路径(3D可视化用) + + 格式: {COORDS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id}/coords.csv + """ + return os.path.join( + self.get_coords_base_path(user_id, flow_id, task_id), + 'coords.csv' + ) + + # ==================== 图片文件完整路径 ==================== + + def get_original_image_file_path( + self, + user_id: int, + flow_id: int, + filename: str + ) -> str: + """获取原图文件的完整路径""" + return os.path.join( + self.get_original_images_path(user_id, flow_id), + filename + ) + + def get_perturbed_image_file_path( + self, + user_id: int, + flow_id: int, + filename: str + ) -> str: + """获取加噪图文件的完整路径""" + return os.path.join( + self.get_perturbed_images_path(user_id, flow_id), + filename + ) + + +# ============================================================ +# 全局单例实例(便于简单场景使用) +# ============================================================ + +_default_manager: PathManager = None + + +def get_path_manager() -> PathManager: + """获取默认的路径管理器实例""" + global _default_manager + if _default_manager is None: + _default_manager = PathManager() + return _default_manager diff --git a/src/backend/app/services/task/__init__.py b/src/backend/app/services/task/__init__.py new file mode 100644 index 0000000..c0fdafc --- /dev/null +++ b/src/backend/app/services/task/__init__.py @@ -0,0 +1,28 @@ +""" +任务处理模块 + +提供面向对象的任务处理器,使用模板方法模式统一任务生命周期。 + +使用方式: + from app.services.task import TaskHandlerFactory + + handler = TaskHandlerFactory.create('perturbation') + job_id = handler.start(task_id) +""" +from .base_handler import BaseTaskHandler +from .perturbation_handler import PerturbationTaskHandler +from .finetune_handler import FinetuneTaskHandler +from .heatmap_handler import HeatmapTaskHandler +from .evaluate_handler import EvaluateTaskHandler +from .task_factory import TaskHandlerFactory +from .task_queue import TaskQueue + +__all__ = [ + 'BaseTaskHandler', + 'PerturbationTaskHandler', + 'FinetuneTaskHandler', + 'HeatmapTaskHandler', + 'EvaluateTaskHandler', + 'TaskHandlerFactory', + 'TaskQueue', +] diff --git a/src/backend/app/services/task/base_handler.py b/src/backend/app/services/task/base_handler.py new file mode 100644 index 0000000..6b0993c --- /dev/null +++ b/src/backend/app/services/task/base_handler.py @@ -0,0 +1,146 @@ +""" +任务处理器基类 + +使用模板方法模式定义任务启动的统一流程,子类实现具体细节。 +""" +import logging +from abc import ABC, abstractmethod +from typing import Optional, Any + +from app.database import Task +from app.services.storage import PathManager +from app.services.task.task_queue import TaskQueue + +logger = logging.getLogger(__name__) + + +def _get_task_repo(): + """懒加载获取 TaskRepository""" + from app.repositories import TaskRepository + return TaskRepository() + + +class BaseTaskHandler(ABC): + """ + 任务处理器抽象基类 + + 定义任务启动的模板方法,子类需实现: + - _get_task_type_code(): 返回任务类型代码 + - _load_task_detail(): 加载任务详情 + - _validate(): 验证任务数据 + - _build_worker_params(): 构建 worker 参数 + - _get_worker_func(): 返回 worker 函数 + """ + + def __init__( + self, + path_manager: Optional[PathManager] = None, + task_queue: Optional[TaskQueue] = None, + task_repo=None + ): + self._path_manager = path_manager or PathManager() + self._task_queue = task_queue or TaskQueue() + self._task_repo = task_repo + + @property + def path_manager(self) -> PathManager: + return self._path_manager + + @property + def task_queue(self) -> TaskQueue: + return self._task_queue + + @property + def task_repo(self): + """懒加载 TaskRepository""" + if self._task_repo is None: + self._task_repo = _get_task_repo() + return self._task_repo + + def start(self, task_id: int) -> Optional[str]: + """ + 启动任务(模板方法) + """ + try: + task = self._load_task(task_id) + if not task: + logger.error(f"Task {task_id} not found") + return None + + detail = self._load_task_detail(task_id) + if not detail: + logger.error(f"{self._get_task_type_code()} detail for task {task_id} not found") + return None + + error = self._validate(task, detail) + if error: + logger.error(f"Task {task_id} validation failed: {error}") + return None + + self._update_status(task, 'waiting') + params = self._build_worker_params(task, detail) + job_id = self._enqueue(task_id, params) + + if job_id: + logger.info(f"{self._get_task_type_code()} task {task_id} started with job_id {job_id}") + + return job_id + + except Exception as e: + logger.error(f"Error starting {self._get_task_type_code()} task {task_id}: {e}") + return None + + def _load_task(self, task_id: int) -> Optional[Task]: + """使用 Repository 加载任务""" + return self.task_repo.get_by_id(task_id) + + def _update_status(self, task: Task, status_code: str) -> bool: + """使用 Repository 更新任务状态""" + if self.task_repo.update_status(task, status_code): + return self.task_repo.save() + return False + + def _enqueue(self, task_id: int, params: dict) -> Optional[str]: + job_id = self._get_job_id(task_id) + worker_func = self._get_worker_func() + timeout = self._get_timeout() + + return self._task_queue.enqueue( + worker_func, + job_id=job_id, + timeout=timeout, + **params + ) + + @abstractmethod + def _get_task_type_code(self) -> str: + pass + + @abstractmethod + def _load_task_detail(self, task_id: int) -> Optional[Any]: + pass + + @abstractmethod + def _validate(self, task: Task, detail: Any) -> Optional[str]: + pass + + @abstractmethod + def _build_worker_params(self, task: Task, detail: Any) -> dict: + pass + + @abstractmethod + def _get_worker_func(self): + pass + + def _get_job_id(self, task_id: int) -> str: + prefix_map = { + 'perturbation': 'pert', + 'finetune': 'ft', + 'heatmap': 'hm', + 'evaluate': 'eval' + } + prefix = prefix_map.get(self._get_task_type_code(), 'task') + return f"{prefix}_{task_id}" + + def _get_timeout(self) -> str: + return '4h' diff --git a/src/backend/app/services/task/evaluate_handler.py b/src/backend/app/services/task/evaluate_handler.py new file mode 100644 index 0000000..1351452 --- /dev/null +++ b/src/backend/app/services/task/evaluate_handler.py @@ -0,0 +1,87 @@ +""" +评估任务处理器 + +处理模型评估(Evaluate)任务的启动逻辑 +""" +import logging +from typing import Optional + +from app.database import Evaluate, Task +from app.services.task.base_handler import BaseTaskHandler + +logger = logging.getLogger(__name__) + + +def _get_evaluate_repo(): + """懒加载获取 EvaluateRepository""" + from app.repositories import EvaluateRepository + return EvaluateRepository() + + +def _get_finetune_repo(): + """懒加载获取 FinetuneRepository""" + from app.repositories import FinetuneRepository + return FinetuneRepository() + + +class EvaluateTaskHandler(BaseTaskHandler): + """ + 评估任务处理器 + + 处理流程: + 1. 加载 Evaluate 详情 + 2. 验证关联的微调任务存在 + 3. 从微调任务获取路径信息 + 4. 构建评估参数 + 5. 入队执行 evaluate_worker + """ + + def _get_task_type_code(self) -> str: + return 'evaluate' + + def _load_task_detail(self, task_id: int) -> Optional[Evaluate]: + return _get_evaluate_repo().get_by_task(task_id) + + def _validate(self, task: Task, detail: Evaluate) -> Optional[str]: + """验证评估任务配置""" + # 检查关联的微调任务 + if not detail.finetune_task_id: + return "Evaluate task has no associated finetune task" + + finetune = _get_finetune_repo().get_by_task(detail.finetune_task_id) + if not finetune: + return f"Finetune task {detail.finetune_task_id} not found" + + finetune_task = finetune.task + if not finetune_task: + return f"Finetune task {detail.finetune_task_id} missing Task relation" + + return None + + def _build_worker_params(self, task: Task, detail: Evaluate) -> dict: + """构建评估任务参数""" + pm = self.path_manager + + # 获取关联的微调任务信息 + finetune = _get_finetune_repo().get_by_task(detail.finetune_task_id) + finetune_task = finetune.task + + user_id = finetune_task.user_id + flow_id = finetune_task.flow_id + finetune_task_id = finetune_task.tasks_id + + return { + 'task_id': task.tasks_id, + 'clean_ref_dir': pm.get_original_images_path(user_id, flow_id), + 'clean_output_dir': pm.get_original_generated_path(user_id, flow_id, finetune_task_id), + 'perturbed_output_dir': pm.get_perturbed_generated_path(user_id, flow_id, finetune_task_id), + 'output_dir': pm.get_evaluate_path(user_id, flow_id, task.tasks_id), + 'image_size': 512, + } + + def _get_worker_func(self): + from app.workers.evaluate_worker import run_evaluate_task + return run_evaluate_task + + def _get_timeout(self) -> str: + return '2h' diff --git a/src/backend/app/services/task/finetune_handler.py b/src/backend/app/services/task/finetune_handler.py new file mode 100644 index 0000000..2c3d26b --- /dev/null +++ b/src/backend/app/services/task/finetune_handler.py @@ -0,0 +1,214 @@ +""" +微调任务处理器 + +处理模型微调(Finetune)任务的启动逻辑 +支持两种类型:基于加噪结果的微调 和 用户上传图片的微调 +""" +import logging +from typing import Optional, List + +from app.database import Finetune, Task +from app.services.task.base_handler import BaseTaskHandler + +logger = logging.getLogger(__name__) + + +def _get_finetune_repo(): + """懒加载获取 FinetuneRepository""" + from app.repositories import FinetuneRepository + return FinetuneRepository() + + +def _get_finetune_config_repo(): + """懒加载获取 FinetuneConfigRepository""" + from app.repositories import FinetuneConfigRepository + return FinetuneConfigRepository() + + +class FinetuneTaskHandler(BaseTaskHandler): + """ + 微调任务处理器 + + 支持两种微调类型: + 1. 基于加噪结果的微调(perturbation-based) + - 同一 flow_id 下存在 Perturbation 任务 + - 同时处理原图和加噪图,生成两个 job + + 2. 用户上传图片的微调(uploaded) + - 独立的 flow_id,无关联的 Perturbation 任务 + - 仅处理上传的原图,生成一个 job + """ + + def _get_task_type_code(self) -> str: + return 'finetune' + + def _load_task_detail(self, task_id: int) -> Optional[Finetune]: + return _get_finetune_repo().get_by_task(task_id) + + def _validate(self, task: Task, detail: Finetune) -> Optional[str]: + """验证微调任务配置""" + config = _get_finetune_config_repo().get_by_id(detail.finetune_configs_id) + if not config: + return f"Finetune config {detail.finetune_configs_id} not found" + return None + + def _has_perturbation_sibling(self, task: Task) -> bool: + """检查是否存在同 flow_id 的加噪任务""" + sibling = self.task_repo.get_by_flow_and_type(task.flow_id, 'perturbation') + return sibling is not None and sibling.tasks_id != task.tasks_id + + def _build_worker_params(self, task: Task, detail: Finetune) -> dict: + """构建微调任务参数(单个 job 的情况)""" + # 此方法用于上传图片微调的情况 + user_id = task.user_id + flow_id = task.flow_id + task_id = task.tasks_id + pm = self.path_manager + + config = _get_finetune_config_repo().get_by_id(detail.finetune_configs_id) + + return { + 'task_id': task_id, + 'finetune_method': config.finetune_code, + 'train_images_dir': pm.get_original_images_path(user_id, flow_id), + 'output_model_dir': pm.get_model_data_path(), + 'class_dir': pm.get_class_data_path(user_id, flow_id), + 'coords_save_path': pm.get_uploaded_coords_path(user_id, flow_id, task_id), + 'validation_output_dir': pm.get_uploaded_generated_path(user_id, flow_id, task_id), + 'is_perturbed': False, + 'custom_params': None, + } + + def _build_perturbation_based_params(self, task: Task, detail: Finetune) -> List[dict]: + """构建基于加噪的微调参数(两个 job)""" + user_id = task.user_id + flow_id = task.flow_id + task_id = task.tasks_id + pm = self.path_manager + + config = _get_finetune_config_repo().get_by_id(detail.finetune_configs_id) + + # 原图微调参数 + original_params = { + 'task_id': task_id, + 'finetune_method': config.finetune_code, + 'train_images_dir': pm.get_original_images_path(user_id, flow_id), + 'output_model_dir': pm.get_model_data_path(), + 'class_dir': pm.get_class_data_path(user_id, flow_id), + 'coords_save_path': pm.get_original_coords_path(user_id, flow_id, task_id), + 'validation_output_dir': pm.get_original_generated_path(user_id, flow_id, task_id), + 'is_perturbed': False, + 'custom_params': None, + } + + # 加噪图微调参数 + perturbed_params = { + 'task_id': task_id, + 'finetune_method': config.finetune_code, + 'train_images_dir': pm.get_perturbed_images_path(user_id, flow_id), + 'output_model_dir': pm.get_model_data_path(), + 'class_dir': pm.get_class_data_path(user_id, flow_id), + 'coords_save_path': pm.get_perturbed_coords_path(user_id, flow_id, task_id), + 'validation_output_dir': pm.get_perturbed_generated_path(user_id, flow_id, task_id), + 'is_perturbed': True, + 'custom_params': None, + } + + return [original_params, perturbed_params] + + def _get_worker_func(self): + from app.workers.finetune_worker import run_finetune_task + return run_finetune_task + + def _get_timeout(self) -> str: + return '8h' + + def start(self, task_id: int) -> Optional[str]: + """ + 启动微调任务(重写模板方法以支持双 job) + + Returns: + 单个 job_id 或逗号分隔的多个 job_id + """ + try: + # 加载任务 + task = self._load_task(task_id) + if not task: + logger.error(f"Task {task_id} not found") + return None + + detail = self._load_task_detail(task_id) + if not detail: + logger.error(f"Finetune detail for task {task_id} not found") + return None + + # 验证 + error = self._validate(task, detail) + if error: + logger.error(f"Task {task_id} validation failed: {error}") + return None + + # 更新状态 + self._update_status(task, 'waiting') + + # 判断微调类型 + if self._has_perturbation_sibling(task): + # 基于加噪的微调:创建两个 job + logger.info(f"Finetune task {task_id}: type=perturbation-based") + return self._start_perturbation_based(task_id, task, detail) + else: + # 上传图片的微调:创建一个 job + logger.info(f"Finetune task {task_id}: type=uploaded") + return self._start_uploaded(task_id, task, detail) + + except Exception as e: + logger.error(f"Error starting finetune task {task_id}: {e}") + return None + + def _start_perturbation_based(self, task_id: int, task: Task, detail: Finetune) -> Optional[str]: + """启动基于加噪的微调(两个 job)""" + params_list = self._build_perturbation_based_params(task, detail) + worker_func = self._get_worker_func() + timeout = self._get_timeout() + + job_id_original = f"ft_{task_id}_original" + job_id_perturbed = f"ft_{task_id}_perturbed" + + # 入队原图微调 + result1 = self.task_queue.enqueue( + worker_func, + job_id=job_id_original, + timeout=timeout, + **params_list[0] + ) + + # 入队加噪图微调 + result2 = self.task_queue.enqueue( + worker_func, + job_id=job_id_perturbed, + timeout=timeout, + **params_list[1] + ) + + if result1 and result2: + logger.info(f"Finetune task {task_id} enqueued: {job_id_original}, {job_id_perturbed}") + return f"{job_id_original},{job_id_perturbed}" + + return None + + def _start_uploaded(self, task_id: int, task: Task, detail: Finetune) -> Optional[str]: + """启动上传图片的微调(单个 job)""" + params = self._build_worker_params(task, detail) + job_id = f"ft_{task_id}" + + result = self.task_queue.enqueue( + self._get_worker_func(), + job_id=job_id, + timeout=self._get_timeout(), + **params + ) + + if result: + logger.info(f"Finetune task {task_id} enqueued: {job_id}") + + return result diff --git a/src/backend/app/services/task/heatmap_handler.py b/src/backend/app/services/task/heatmap_handler.py new file mode 100644 index 0000000..b9992ae --- /dev/null +++ b/src/backend/app/services/task/heatmap_handler.py @@ -0,0 +1,105 @@ +""" +热力图任务处理器 + +处理热力图(Heatmap)生成任务的启动逻辑 +""" +import os +import logging +from typing import Optional + +from app.database import Heatmap, Task +from app.services.task.base_handler import BaseTaskHandler + +logger = logging.getLogger(__name__) + + +def _get_heatmap_repo(): + """懒加载获取 HeatmapRepository""" + from app.repositories import HeatmapRepository + return HeatmapRepository() + + +def _get_image_repo(): + """懒加载获取 ImageRepository""" + from app.repositories import ImageRepository + return ImageRepository() + + +class HeatmapTaskHandler(BaseTaskHandler): + """ + 热力图任务处理器 + + 处理流程: + 1. 加载 Heatmap 详情 + 2. 验证关联的加噪图片存在 + 3. 通过 father_id 找到原图 + 4. 构建图片路径和输出路径 + 5. 入队执行 heatmap_worker + """ + + def _get_task_type_code(self) -> str: + return 'heatmap' + + def _load_task_detail(self, task_id: int) -> Optional[Heatmap]: + return _get_heatmap_repo().get_by_task(task_id) + + def _validate(self, task: Task, detail: Heatmap) -> Optional[str]: + """验证热力图任务配置""" + image_repo = _get_image_repo() + + # 检查加噪图片 ID + if not detail.images_id: + return "Heatmap task has no associated perturbed image" + + # 检查加噪图片存在 + perturbed_image = image_repo.get_by_id(detail.images_id) + if not perturbed_image: + return f"Perturbed image {detail.images_id} not found" + + # 检查原图存在(通过 father_id) + if not perturbed_image.father_id: + return f"Perturbed image {detail.images_id} has no father_id" + + original_image = image_repo.get_by_id(perturbed_image.father_id) + if not original_image: + return f"Original image (father_id={perturbed_image.father_id}) not found" + + return None + + def _build_worker_params(self, task: Task, detail: Heatmap) -> dict: + """构建热力图任务参数""" + image_repo = _get_image_repo() + user_id = task.user_id + flow_id = task.flow_id + task_id = task.tasks_id + pm = self.path_manager + + # 获取加噪图片 + perturbed_image = image_repo.get_by_id(detail.images_id) + original_image = image_repo.get_by_id(perturbed_image.father_id) + + # 构建图片完整路径 + original_image_path = os.path.join( + pm.get_original_images_path(user_id, flow_id), + original_image.stored_filename + ) + + perturbed_image_path = os.path.join( + pm.get_perturbed_images_path(user_id, flow_id), + perturbed_image.stored_filename + ) + + return { + 'task_id': task_id, + 'original_image_path': original_image_path, + 'perturbed_image_path': perturbed_image_path, + 'output_dir': pm.get_heatmap_path(user_id, flow_id, task_id), + 'perturbed_image_id': detail.images_id, + } + + def _get_worker_func(self): + from app.workers.heatmap_worker import run_heatmap_task + return run_heatmap_task + + def _get_timeout(self) -> str: + return '2h' diff --git a/src/backend/app/services/task/perturbation_handler.py b/src/backend/app/services/task/perturbation_handler.py new file mode 100644 index 0000000..a75bc09 --- /dev/null +++ b/src/backend/app/services/task/perturbation_handler.py @@ -0,0 +1,61 @@ +""" +加噪任务处理器 +""" +import logging +from typing import Optional + +from app.database import Perturbation, Task +from app.services.task.base_handler import BaseTaskHandler + +logger = logging.getLogger(__name__) + + +def _get_perturbation_repo(): + """懒加载获取 PerturbationRepository""" + from app.repositories import PerturbationRepository + return PerturbationRepository() + + +def _get_perturbation_config_repo(): + """懒加载获取 PerturbationConfigRepository""" + from app.repositories import PerturbationConfigRepository + return PerturbationConfigRepository() + + +class PerturbationTaskHandler(BaseTaskHandler): + """加噪任务处理器""" + + def _get_task_type_code(self) -> str: + return 'perturbation' + + def _load_task_detail(self, task_id: int) -> Optional[Perturbation]: + return _get_perturbation_repo().get_by_task(task_id) + + def _validate(self, task: Task, detail: Perturbation) -> Optional[str]: + config = _get_perturbation_config_repo().get_by_id(detail.perturbation_configs_id) + if not config: + return f"Perturbation config {detail.perturbation_configs_id} not found" + return None + + def _build_worker_params(self, task: Task, detail: Perturbation) -> dict: + user_id = task.user_id + flow_id = task.flow_id + pm = self.path_manager + + config = _get_perturbation_config_repo().get_by_id(detail.perturbation_configs_id) + + return { + 'task_id': task.tasks_id, + 'input_dir': pm.get_original_images_path(user_id, flow_id), + 'output_dir': pm.get_perturbed_images_path(user_id, flow_id), + 'class_dir': pm.get_class_data_path(user_id, flow_id), + 'algorithm_code': config.perturbation_code, + 'epsilon': detail.perturbation_intensity, + } + + def _get_worker_func(self): + from app.workers.perturbation_worker import run_perturbation_task + return run_perturbation_task + + def _get_timeout(self) -> str: + return '4h' diff --git a/src/backend/app/services/task/task_factory.py b/src/backend/app/services/task/task_factory.py new file mode 100644 index 0000000..145e114 --- /dev/null +++ b/src/backend/app/services/task/task_factory.py @@ -0,0 +1,126 @@ +""" +任务处理器工厂 + +使用工厂模式根据任务类型创建对应的处理器 +""" +import logging +from typing import Optional, Type, Dict + +from app.services.task.base_handler import BaseTaskHandler +from app.services.task.perturbation_handler import PerturbationTaskHandler +from app.services.task.finetune_handler import FinetuneTaskHandler +from app.services.task.heatmap_handler import HeatmapTaskHandler +from app.services.task.evaluate_handler import EvaluateTaskHandler +from app.services.storage import PathManager +from app.services.task.task_queue import TaskQueue + +logger = logging.getLogger(__name__) + + +class TaskHandlerFactory: + """ + 任务处理器工厂 + + 根据任务类型代码创建对应的处理器实例 + + 使用方式: + # 方式1:使用默认依赖 + handler = TaskHandlerFactory.create('perturbation') + job_id = handler.start(task_id=123) + + # 方式2:注入自定义依赖(便于测试) + handler = TaskHandlerFactory.create( + 'finetune', + path_manager=mock_pm, + task_queue=mock_queue + ) + + 支持的任务类型: + - perturbation: 加噪任务 + - finetune: 微调任务 + - heatmap: 热力图任务 + - evaluate: 评估任务 + """ + + # 任务类型到处理器类的映射 + _handlers: Dict[str, Type[BaseTaskHandler]] = { + 'perturbation': PerturbationTaskHandler, + 'finetune': FinetuneTaskHandler, + 'heatmap': HeatmapTaskHandler, + 'evaluate': EvaluateTaskHandler, + } + + @classmethod + def create( + cls, + task_type: str, + path_manager: Optional[PathManager] = None, + task_queue: Optional[TaskQueue] = None + ) -> BaseTaskHandler: + """ + 创建任务处理器 + + Args: + task_type: 任务类型代码 + path_manager: 路径管理器(可选) + task_queue: 任务队列(可选) + + Returns: + 对应的任务处理器实例 + + Raises: + ValueError: 未知的任务类型 + """ + handler_class = cls._handlers.get(task_type) + + if handler_class is None: + available = ', '.join(cls._handlers.keys()) + raise ValueError( + f"Unknown task type: '{task_type}'. " + f"Available types: {available}" + ) + + return handler_class( + path_manager=path_manager, + task_queue=task_queue + ) + + @classmethod + def register(cls, task_type: str, handler_class: Type[BaseTaskHandler]) -> None: + """ + 注册新的任务处理器(扩展点) + + Args: + task_type: 任务类型代码 + handler_class: 处理器类 + """ + cls._handlers[task_type] = handler_class + logger.info(f"Registered task handler: {task_type} -> {handler_class.__name__}") + + @classmethod + def get_supported_types(cls) -> list: + """获取所有支持的任务类型""" + return list(cls._handlers.keys()) + + +# ============================================================ +# 便捷函数:直接启动任务 +# ============================================================ + +def start_task(task_type: str, task_id: int) -> Optional[str]: + """ + 便捷函数:根据类型启动任务 + + Args: + task_type: 任务类型代码 + task_id: 任务 ID + + Returns: + job_id 或 None + """ + try: + handler = TaskHandlerFactory.create(task_type) + return handler.start(task_id) + except ValueError as e: + logger.error(str(e)) + return None diff --git a/src/backend/app/services/task/task_queue.py b/src/backend/app/services/task/task_queue.py new file mode 100644 index 0000000..de1a1df --- /dev/null +++ b/src/backend/app/services/task/task_queue.py @@ -0,0 +1,116 @@ +""" +任务队列管理 + +封装 Redis Queue (RQ) 的连接和队列操作 +""" +import logging +from typing import Optional, Callable, Any +from redis import Redis +from rq import Queue +from rq.job import Job +from config.algorithm_config import AlgorithmConfig + +logger = logging.getLogger(__name__) + + +class TaskQueue: + """ + 任务队列管理器 + + 封装 RQ 队列操作,提供统一的任务入队接口 + + 使用方式: + queue = TaskQueue() + job_id = queue.enqueue( + worker_func, + job_id='task_123', + timeout='4h', + task_id=123, + input_dir='/path/to/input' + ) + """ + + _instance: Optional['TaskQueue'] = None + + def __new__(cls) -> 'TaskQueue': + """单例模式""" + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + self._redis_url = AlgorithmConfig.REDIS_URL + self._queue_name = AlgorithmConfig.RQ_QUEUE_NAME + self._connection: Optional[Redis] = None + self._queue: Optional[Queue] = None + self._initialized = True + + @property + def connection(self) -> Redis: + """获取 Redis 连接(懒加载)""" + if self._connection is None: + self._connection = Redis.from_url(self._redis_url) + return self._connection + + @property + def queue(self) -> Queue: + """获取 RQ 队列(懒加载)""" + if self._queue is None: + self._queue = Queue(self._queue_name, connection=self.connection) + return self._queue + + def enqueue( + self, + func: Callable, + job_id: str, + timeout: str = '4h', + **kwargs: Any + ) -> Optional[str]: + """ + 将任务加入队列 + + Args: + func: 要执行的 worker 函数 + job_id: 任务唯一标识 + timeout: 超时时间(如 '4h', '30m') + **kwargs: 传递给 worker 函数的参数 + + Returns: + job_id 或 None(失败时) + """ + try: + self.queue.enqueue( + func, + job_id=job_id, + job_timeout=timeout, + **kwargs + ) + logger.info(f"Task enqueued: {job_id}") + return job_id + except Exception as e: + logger.error(f"Failed to enqueue task {job_id}: {e}") + return None + + def fetch_job(self, job_id: str) -> Optional[Job]: + """获取任务信息""" + try: + return Job.fetch(job_id, connection=self.connection) + except Exception as e: + logger.warning(f"Failed to fetch job {job_id}: {e}") + return None + + def cancel_job(self, job_id: str) -> bool: + """取消任务""" + try: + job = self.fetch_job(job_id) + if job: + job.cancel() + logger.info(f"Job cancelled: {job_id}") + return True + return False + except Exception as e: + logger.warning(f"Failed to cancel job {job_id}: {e}") + return False diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index e71fdee..5eb7fbf 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -2,128 +2,152 @@ 任务处理服务(适配新数据库结构和路径配置) 处理加噪、微调、热力图、评估等核心业务逻辑 使用Redis Queue进行异步任务处理 + +已重构为面向对象设计,推荐使用: + from app.services.task import TaskHandlerFactory + + handler = TaskHandlerFactory.create('perturbation') + job_id = handler.start(task_id) + +数据访问已迁移到 Repository 层: + from app.repositories import TaskRepository, TaskTypeRepository + + task_repo = TaskRepository() + task = task_repo.get_by_id(task_id) """ -import os import logging from datetime import datetime -from flask import current_app, jsonify +from typing import Optional +from flask import jsonify from redis import Redis -from rq import Queue from rq.job import Job -from app import db -from app.database import ( - Task, TaskStatus, TaskType, - Perturbation, Finetune, Heatmap, Evaluate, - Image, ImageType, DataType, - PerturbationConfig, FinetuneConfig, User -) +from app.services.storage import PathManager from config.algorithm_config import AlgorithmConfig -from config.settings import Config logger = logging.getLogger(__name__) +# 全局单例实例 +_path_manager: Optional[PathManager] = None +_task_repo = None +_task_type_repo = None +_task_status_repo = None +_user_repo = None + + +def _get_path_manager() -> PathManager: + """获取路径管理器单例""" + global _path_manager + if _path_manager is None: + _path_manager = PathManager() + return _path_manager + + +def _get_task_repo(): + """获取任务 Repository 单例(懒加载)""" + global _task_repo + if _task_repo is None: + from app.repositories import TaskRepository + _task_repo = TaskRepository() + return _task_repo + + +def _get_task_type_repo(): + """获取任务类型 Repository 单例(懒加载)""" + global _task_type_repo + if _task_type_repo is None: + from app.repositories import TaskTypeRepository + _task_type_repo = TaskTypeRepository() + return _task_type_repo + + +def _get_task_status_repo(): + """获取任务状态 Repository 单例(懒加载)""" + global _task_status_repo + if _task_status_repo is None: + from app.repositories import TaskStatusRepository + _task_status_repo = TaskStatusRepository() + return _task_status_repo + + +def _get_user_repo(): + """获取用户 Repository 单例(懒加载)""" + global _user_repo + if _user_repo is None: + from app.repositories import UserRepository + _user_repo = UserRepository() + return _user_repo + + +def _get_task_handler(task_type: str): + """获取任务处理器(懒加载导入避免循环依赖)""" + from app.services.task import TaskHandlerFactory + return TaskHandlerFactory.create(task_type) + class TaskService: """任务处理服务""" - # ==================== 路径工具函数 ==================== + # ==================== 路径代理方法(委托给 PathManager)==================== + # 保持向后兼容,内部委托给 PathManager @staticmethod def _get_project_root(): """获取项目根目录""" - return os.path.dirname(current_app.root_path) + return _get_path_manager().project_root @staticmethod def _build_path(*parts): """构建路径""" - return os.path.join(TaskService._get_project_root(), *parts) + return _get_path_manager()._build_path(*parts) @staticmethod def get_original_images_path(user_id, flow_id): - """原图路径: ORIGINAL_IMAGES_FOLDER/user_id/flow_id""" - return TaskService._build_path( - Config.ORIGINAL_IMAGES_FOLDER, - str(user_id), - str(flow_id) - ) + """原图路径""" + return _get_path_manager().get_original_images_path(user_id, flow_id) @staticmethod def get_perturbed_images_path(user_id, flow_id): - """加噪图路径: PERTURBED_IMAGES_FOLDER/user_id/flow_id""" - return TaskService._build_path( - Config.PERTURBED_IMAGES_FOLDER, - str(user_id), - str(flow_id) - ) + """加噪图路径""" + return _get_path_manager().get_perturbed_images_path(user_id, flow_id) @staticmethod def get_original_generated_path(user_id, flow_id, task_id): - """原图生成图路径: MODEL_ORIGINAL_FOLDER/user_id/flow_id/task_id""" - return TaskService._build_path( - Config.MODEL_ORIGINAL_FOLDER, - str(user_id), - str(flow_id), - str(task_id) - ) + """原图生成图路径""" + return _get_path_manager().get_original_generated_path(user_id, flow_id, task_id) @staticmethod def get_perturbed_generated_path(user_id, flow_id, task_id): - """加噪图生成图路径: MODEL_PERTURBED_FOLDER/user_id/flow_id/task_id""" - return TaskService._build_path( - Config.MODEL_PERTURBED_FOLDER, - str(user_id), - str(flow_id), - str(task_id) - ) + """加噪图生成图路径""" + return _get_path_manager().get_perturbed_generated_path(user_id, flow_id, task_id) @staticmethod def get_uploaded_generated_path(user_id, flow_id, task_id): - """上传图生成图路径: MODEL_UPLOADED_FOLDER/user_id/flow_id/task_id""" - return TaskService._build_path( - Config.MODEL_UPLOADED_FOLDER, - str(user_id), - str(flow_id), - str(task_id) - ) + """上传图生成图路径""" + return _get_path_manager().get_uploaded_generated_path(user_id, flow_id, task_id) @staticmethod def get_heatmap_path(user_id, flow_id, task_id): - """热力图路径: HEATDIF_SAVE_FOLDER/user_id/flow_id/task_id""" - return TaskService._build_path( - Config.HEATDIF_SAVE_FOLDER, - str(user_id), - str(flow_id), - str(task_id) - ) + """热力图路径""" + return _get_path_manager().get_heatmap_path(user_id, flow_id, task_id) @staticmethod def get_evaluate_path(user_id, flow_id, task_id): - """数值结果路径: NUMBERS_SAVE_FOLDER/user_id/flow_id/task_id""" - return TaskService._build_path( - Config.NUMBERS_SAVE_FOLDER, - str(user_id), - str(flow_id), - str(task_id) - ) + """数值结果路径""" + return _get_path_manager().get_evaluate_path(user_id, flow_id, task_id) @staticmethod def get_class_data_path(user_id, flow_id): - """类别数据路径: CLASS_DATA_FOLDER/user_id/flow_id""" - return TaskService._build_path( - Config.CLASS_DATA_FOLDER, - str(user_id), - str(flow_id) - ) - - @staticmethod - def get_model_data_path(user_id, flow_id): - """模型数据路径: MODEL_DATA_FOLDER/user_id/flow_id""" - return TaskService._build_path( - Config.MODEL_DATA_FOLDER - ) + """类别数据路径""" + return _get_path_manager().get_class_data_path(user_id, flow_id) + + @staticmethod + def get_model_data_path(user_id=None, flow_id=None): + """模型数据路径""" + return _get_path_manager().get_model_data_path() # ==================== 通用辅助功能 ==================== + # 以下方法委托给 Repository 层,保持向后兼容 @staticmethod def json_error(message, status_code=400): @@ -132,70 +156,63 @@ class TaskService: @staticmethod def get_task_type(code): - """根据任务类型代码获取TaskType""" - return TaskType.query.filter_by(task_type_code=code).first() + """根据任务类型代码获取TaskType(委托给 TaskTypeRepository)""" + return _get_task_type_repo().get_by_code(code) @staticmethod def require_task_type(code): - """确保任务类型存在""" - task_type = TaskService.get_task_type(code) - if not task_type: - raise ValueError(f"Task type '{code}' is not configured") - return task_type + """确保任务类型存在(委托给 TaskTypeRepository)""" + return _get_task_type_repo().require(code) @staticmethod def get_status_by_code(code): - """根据状态代码获取TaskStatus""" - return TaskStatus.query.filter_by(task_status_code=code).first() + """根据状态代码获取TaskStatus(委托给 TaskStatusRepository)""" + return _get_task_status_repo().get_by_code(code) @staticmethod def ensure_status(code): - """确保任务状态存在""" - status = TaskService.get_status_by_code(code) - if not status: - raise ValueError(f"Task status '{code}' is not configured") - return status + """确保任务状态存在(委托给 TaskStatusRepository)""" + return _get_task_status_repo().require(code) @staticmethod def generate_flow_id(): """生成唯一的flow_id""" base = int(datetime.utcnow().timestamp() * 1000) - while Task.query.filter_by(flow_id=base).first(): + task_repo = _get_task_repo() + while task_repo.find_one_by(flow_id=base): base += 1 return base @staticmethod def ensure_task_owner(task, user_id): - """验证任务归属""" - return bool(task and task.user_id == user_id) + """验证任务归属(委托给 TaskRepository)""" + return _get_task_repo().is_owner(task, user_id) @staticmethod def get_task_type_code(task): - """获取任务类型代码""" - return task.task_type.task_type_code if task and task.task_type else None + """获取任务类型代码(委托给 TaskRepository)""" + return _get_task_repo().get_type_code(task) @staticmethod def load_task_for_user(task_id, user_id, expected_type=None): """根据任务ID加载用户的任务,可选检查类型""" - task = Task.query.get(task_id) - if not TaskService.ensure_task_owner(task, user_id): + task_repo = _get_task_repo() + task = task_repo.get_for_user(task_id, user_id) + if not task: + return None + if expected_type and not task_repo.is_type(task, expected_type): return None - if expected_type: - task_type = TaskService.get_task_type_code(task) - if task_type != expected_type: - return None return task @staticmethod def determine_finetune_source(finetune_task): """判断微调任务来源""" - perturb_type = TaskService.require_task_type('perturbation') - sibling_perturbation = Task.query.filter( - Task.flow_id == finetune_task.flow_id, - Task.tasks_type_id == perturb_type.task_type_id, - Task.tasks_id != finetune_task.tasks_id - ).first() - return 'perturbation' if sibling_perturbation else 'uploaded' + task_repo = _get_task_repo() + sibling = task_repo.get_by_flow_and_type(finetune_task.flow_id, 'perturbation') + # 排除自身 + if sibling and sibling.tasks_id != finetune_task.tasks_id: + return 'perturbation' + return 'uploaded' @staticmethod def serialize_task(task): @@ -251,8 +268,8 @@ class TaskService: @staticmethod def get_user(user_id): - """获取用户""" - return User.query.get(user_id) + """获取用户(委托给 UserRepository)""" + return _get_user_repo().get_by_id(user_id) # ==================== Redis/RQ 连接管理 ==================== @@ -281,17 +298,14 @@ class TaskService: 任务状态信息 """ try: - task = Task.query.get(task_id) + task_repo = _get_task_repo() + task = task_repo.get_by_id(task_id) if not task: return {'status': 'not_found', 'error': 'Task not found'} - # 获取任务状态名称 - status = TaskStatus.query.get(task.tasks_status_id) - status_code = status.task_status_code if status else 'unknown' - - # 获取任务类型 - task_type = TaskType.query.get(task.tasks_type_id) - type_code = task_type.task_type_code if task_type else 'unknown' + # 使用 Repository 获取状态和类型代码 + status_code = task.task_status.task_status_code if task.task_status else 'unknown' + type_code = task_repo.get_type_code(task) or 'unknown' result = { 'task_id': task_id, @@ -345,13 +359,13 @@ class TaskService: 是否成功取消 """ try: - task = Task.query.get(task_id) + task_repo = _get_task_repo() + task = task_repo.get_by_id(task_id) if not task: return False - # 获取任务类型 - task_type = TaskType.query.get(task.tasks_type_id) - type_code = task_type.task_type_code if task_type else None + # 获取任务类型代码 + type_code = task_repo.get_type_code(task) # 尝试从队列中删除任务 try: @@ -362,19 +376,10 @@ class TaskService: except Exception as e: logger.warning(f"Could not cancel RQ job: {e}") - # 更新数据库状态 - try: - failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() - if failed_status: - task.tasks_status_id = failed_status.task_status_id - task.finished_at = datetime.utcnow() - db.session.commit() - except Exception as e: - db.session.rollback() - logger.error(f"Failed to update task status: {e}") - return False - - return True + # 使用 Repository 更新状态 + if task_repo.update_status(task, 'failed'): + return task_repo.save() + return False except Exception as e: logger.error(f"Error cancelling task: {e}") @@ -385,7 +390,7 @@ class TaskService: @staticmethod def start_perturbation_task(task_id): """ - 启动加噪任务 + 启动加噪任务(委托给 PerturbationTaskHandler) Args: task_id: 任务ID @@ -393,249 +398,33 @@ class TaskService: Returns: job_id """ - try: - # 获取任务 - task = Task.query.get(task_id) - if not task: - logger.error(f"Task {task_id} not found") - return None - - # 获取Perturbation任务详情 - perturbation = Perturbation.query.get(task_id) - if not perturbation: - logger.error(f"Perturbation task {task_id} not found") - return None - - # 更新任务状态为 waiting - waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() - if waiting_status: - task.tasks_status_id = waiting_status.task_status_id - db.session.commit() - - # 获取用户ID - user_id = task.user_id - - # 路径配置 - input_dir = TaskService.get_original_images_path(user_id, task.flow_id) - output_dir = TaskService.get_perturbed_images_path(user_id, task.flow_id) - class_dir = TaskService.get_class_data_path(user_id, task.flow_id) - - # 获取算法配置 - pert_config = PerturbationConfig.query.get(perturbation.perturbation_configs_id) - if not pert_config: - logger.error(f"Perturbation config not found") - return None - - algorithm_code = pert_config.perturbation_code - - # 加入RQ队列 - from app.workers.perturbation_worker import run_perturbation_task - - queue = TaskService._get_queue() - job_id = f"pert_{task_id}" - - job = queue.enqueue( - run_perturbation_task, - task_id=task_id, - input_dir=input_dir, - output_dir=output_dir, - class_dir=class_dir, - algorithm_code=algorithm_code, - epsilon=perturbation.perturbation_intensity, - job_id=job_id, - job_timeout='4h' - ) - - logger.info(f"Perturbation task {task_id} enqueued with job_id {job_id}") - return job_id - - except Exception as e: - logger.error(f"Error starting perturbation task: {e}") - return None + return _get_task_handler('perturbation').start(task_id) # ==================== Finetune 任务 ==================== @staticmethod def start_finetune_task(task_id): """ - 启动微调任务(支持两种类型) - - 类型1:基于加噪结果的微调 - - 有相同flow_id的Perturbation任务 - - 输入:原图 + 加噪图 - - 输出到:original_generated 和 perturbed_generated + 启动微调任务(委托给 FinetuneTaskHandler) - 类型2:用户上传图片的微调 - - 找不到相同flow_id的其他任务 - - 输入:仅原图 - - 输出到:uploaded_generated + 支持两种类型: + - 基于加噪结果的微调 + - 用户上传图片的微调 Args: task_id: 任务ID Returns: - job_id + job_id 或逗号分隔的多个 job_id """ - try: - # 获取任务 - task = Task.query.get(task_id) - if not task: - logger.error(f"Task {task_id} not found") - return None - - # 获取Finetune任务详情 - finetune = Finetune.query.get(task_id) - if not finetune: - logger.error(f"Finetune task {task_id} not found") - return None - - # 更新任务状态为 waiting - waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() - if waiting_status: - task.tasks_status_id = waiting_status.task_status_id - db.session.commit() - - # 获取用户ID - user_id = task.user_id - - # 获取微调配置 - ft_config = FinetuneConfig.query.get(finetune.finetune_configs_id) - if not ft_config: - logger.error(f"Finetune config not found") - return None - - # 检测微调类型:查找相同flow_id的Perturbation任务 - perturb_type = TaskService.require_task_type('perturbation') - sibling_perturbation = Task.query.filter( - Task.flow_id == task.flow_id, - Task.tasks_type_id == perturb_type.task_type_id, - Task.tasks_id != task_id - ).first() - - has_perturbation = sibling_perturbation is not None - - # 路径配置 - input_dir = TaskService.get_original_images_path(user_id, task.flow_id) - class_dir = TaskService.get_class_data_path(user_id, task.flow_id) - model_data_dir = TaskService.get_model_data_path(user_id, task.flow_id) - - if has_perturbation: - # 类型1:基于加噪结果的微调 - logger.info(f"Finetune task {task_id}: type=perturbation-based") - - perturbed_input_dir = TaskService.get_perturbed_images_path(user_id, task.flow_id) - original_input_dir = TaskService.get_original_images_path(user_id, task.flow_id) - perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, task.flow_id, task_id) - original_output_dir = TaskService.get_original_generated_path(user_id, task.flow_id, task_id) - - - # 获取坐标保存路径(3D可视化) - original_coords_save_path = TaskService._build_path( - Config.COORDS_SAVE_FOLDER, - str(user_id), - str(task.flow_id), - str(task_id), - 'original_coords.csv' - ) - - # 获取加噪坐标保存路径(3D可视化) - perturbed_coords_save_path = TaskService._build_path( - Config.COORDS_SAVE_FOLDER, - str(user_id), - str(task.flow_id), - str(task_id), - 'perturbed_coords.csv' - ) - - # 加入RQ队列 - from app.workers.finetune_worker import run_finetune_task - - queue = TaskService._get_queue() - job_id_original = f"ft_{task_id}_original" - job_id_perturbed = f"ft_{task_id}_perturbed" - - job_original = queue.enqueue( - run_finetune_task, - task_id=task_id, - finetune_method=ft_config.finetune_code, - train_images_dir=original_input_dir, - output_model_dir=model_data_dir, - class_dir=class_dir, - coords_save_path=original_coords_save_path, - validation_output_dir=original_output_dir, - finetune_type="original", - custom_params=None, - job_id=job_id_original, - job_timeout='8h' - ) - - job_perturbed = queue.enqueue( - run_finetune_task, - task_id=task_id, - finetune_method=ft_config.finetune_code, - train_images_dir=perturbed_input_dir, - output_model_dir=model_data_dir, - class_dir=class_dir, - coords_save_path=perturbed_coords_save_path, - validation_output_dir=perturbed_output_dir, - finetune_type="perturbed", - custom_params=None, - job_id=job_id_perturbed, - job_timeout='8h' - ) - - logger.info(f"Finetune task {task_id} enqueued with job_ids {job_id_original}, {job_id_perturbed}") - return f"{job_id_original},{job_id_perturbed}" - - else: - # 类型2:用户上传图片的微调 - logger.info(f"Finetune task {task_id}: type=uploaded") - - uploaded_output_dir = TaskService.get_uploaded_generated_path(user_id, task.flow_id, task_id) - - # 获取坐标保存路径 - coords_save_path = TaskService._build_path( - Config.COORDS_SAVE_FOLDER, - str(user_id), - str(task.flow_id), - str(task_id), - 'coords.csv' - ) - - # 加入RQ队列 - from app.workers.finetune_worker import run_finetune_task - - queue = TaskService._get_queue() - job_id = f"ft_{task_id}" - - job = queue.enqueue( - run_finetune_task, - task_id=task_id, - finetune_method=ft_config.finetune_code, - train_images_dir=input_dir, - output_model_dir=model_data_dir, - class_dir=class_dir, - coords_save_path=coords_save_path, - validation_output_dir=uploaded_output_dir, - finetune_type="uploaded", - custom_params=None, - job_id=job_id, - job_timeout='8h' - ) - - logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}") - return job_id - - except Exception as e: - logger.error(f"Error starting finetune task: {e}") - return None + return _get_task_handler('finetune').start(task_id) # ==================== Heatmap 任务 ==================== @staticmethod def start_heatmap_task(task_id): """ - 启动热力图任务 + 启动热力图任务(委托给 HeatmapTaskHandler) Args: task_id: 任务ID @@ -643,93 +432,14 @@ class TaskService: Returns: job_id """ - try: - # 获取任务 - task = Task.query.get(task_id) - if not task: - logger.error(f"Task {task_id} not found") - return None - - # 获取Heatmap任务详情 - heatmap = Heatmap.query.get(task_id) - if not heatmap: - logger.error(f"Heatmap task {task_id} not found") - return None - - # 更新任务状态为 waiting - waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() - if waiting_status: - task.tasks_status_id = waiting_status.task_status_id - db.session.commit() - - # 从heatmap对象获取扰动图片ID - perturbed_image_id = heatmap.images_id - if not perturbed_image_id: - logger.error(f"Heatmap task {task_id} has no associated perturbed image") - return None - - # 获取扰动图片信息 - perturbed_image = Image.query.get(perturbed_image_id) - if not perturbed_image: - logger.error(f"Perturbed image {perturbed_image_id} not found") - return None - - user_id = task.user_id - - # 获取原图(通过father_id关系) - if not perturbed_image.father_id: - logger.error(f"Perturbed image {perturbed_image_id} has no father_id") - return None - - original_image = Image.query.get(perturbed_image.father_id) - if not original_image: - logger.error(f"Original image not found") - return None - - # 构建图片路径(使用 stored_filename) - original_image_path = os.path.join( - TaskService.get_original_images_path(user_id, task.flow_id), - original_image.stored_filename - ) - - perturbed_image_path = os.path.join( - TaskService.get_perturbed_images_path(user_id, task.flow_id), - perturbed_image.stored_filename - ) - - # 输出目录 - output_dir = TaskService.get_heatmap_path(user_id, task.flow_id, task_id) - - # 加入RQ队列 - from app.workers.heatmap_worker import run_heatmap_task - - queue = TaskService._get_queue() - job_id = f"hm_{task_id}" - - job = queue.enqueue( - run_heatmap_task, - task_id=task_id, - original_image_path=original_image_path, - perturbed_image_path=perturbed_image_path, - output_dir=output_dir, - perturbed_image_id=perturbed_image_id, - job_id=job_id, - job_timeout='2h' - ) - - logger.info(f"Heatmap task {task_id} enqueued with job_id {job_id}") - return job_id - - except Exception as e: - logger.error(f"Error starting heatmap task: {e}") - return None + return _get_task_handler('heatmap').start(task_id) # ==================== Evaluate 任务 ==================== @staticmethod def start_evaluate_task(task_id): """ - 启动评估任务 + 启动评估任务(委托给 EvaluateTaskHandler) Args: task_id: 任务ID @@ -737,64 +447,4 @@ class TaskService: Returns: job_id """ - try: - # 获取任务 - task = Task.query.get(task_id) - if not task: - logger.error(f"Task {task_id} not found") - return None - - # 获取Evaluate任务详情 - evaluate = Evaluate.query.get(task_id) - if not evaluate: - logger.error(f"Evaluate task {task_id} not found") - return None - - # 更新任务状态为 waiting - waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first() - if waiting_status: - task.tasks_status_id = waiting_status.task_status_id - db.session.commit() - - finetune = Finetune.query.get(evaluate.finetune_task_id) - if not finetune: - logger.error(f"Finetune task {evaluate.finetune_task_id} not found for evaluation {task_id}") - return None - - finetune_task = finetune.task - if not finetune_task: - logger.error(f"Finetune task {evaluate.finetune_task_id} missing Task relation") - return None - - user_id = finetune_task.user_id - - # 路径配置 - clean_ref_dir = TaskService.get_original_images_path(user_id, finetune_task.flow_id) - clean_output_dir = TaskService.get_original_generated_path(user_id, finetune_task.flow_id, finetune_task.tasks_id) - perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, finetune_task.flow_id, finetune_task.tasks_id) - output_dir = TaskService.get_evaluate_path(user_id, finetune_task.flow_id, task_id) - - # 加入RQ队列 - from app.workers.evaluate_worker import run_evaluate_task - - queue = TaskService._get_queue() - job_id = f"eval_{task_id}" - - job = queue.enqueue( - run_evaluate_task, - task_id=task_id, - clean_ref_dir=clean_ref_dir, - clean_output_dir=clean_output_dir, - perturbed_output_dir=perturbed_output_dir, - output_dir=output_dir, - image_size=512, - job_id=job_id, - job_timeout='2h' - ) - - logger.info(f"Evaluate task {task_id} enqueued with job_id {job_id}") - return job_id - - except Exception as e: - logger.error(f"Error starting evaluate task: {e}") - return None + return _get_task_handler('evaluate').start(task_id)