代码重构完成 #30

Merged
ppy4sjqvf merged 9 commits from class into develop 3 weeks ago

22
.gitignore vendored

@ -35,4 +35,24 @@ uploads/
.github/
# pycharm 配置
.idea/
.idea/
# pytest配置
*.ini
# 测试相关
.pytest_cache/
.coverage
.coverage.*
htmlcov/
.tox/
.nox/
coverage.xml
*.cover
*.py,cover
.hypothesis/
pytest_cache/
test-results/
test-reports/
tests/
run_tests.py

@ -9,7 +9,7 @@ from app import db
from app.database import User, UserConfig
from functools import wraps
import re
from app.services.email_service import send_verification_code, verify_code
from app.services.email import VerificationService
def int_jwt_required(f):
"""获取JWT身份并转换为整数的装饰器"""
@ -40,8 +40,11 @@ def send_email_verification_code():
if not re.match(email_pattern, email):
return jsonify({'error': '邮箱格式不正确'}), 400
send_verification_code(email, purpose=purpose)
return jsonify({'message': '验证码已发送'}), 200
verification_service = VerificationService()
if verification_service.send_code(email, purpose):
return jsonify({'message': '验证码已发送'}), 200
else:
return jsonify({'error': '验证码发送失败,请稍后重试'}), 500
except Exception as e:
return jsonify({'error': f'发送验证码失败: {str(e)}'}), 500
@ -72,8 +75,8 @@ def register():
if User.query.filter_by(email=email).first():
return jsonify({'error': '该邮箱已被注册,同一邮箱只能注册一次'}), 400
# 验证验证码
if not code or not verify_code(email, code, purpose='register'):
verification_service = VerificationService()
if not code or not verification_service.verify_code(email, code, purpose = 'register'):
return jsonify({'error': '验证码无效或已过期'}), 400
# 创建用户默认为普通用户role_id=3
@ -160,12 +163,15 @@ def change_password(current_user_id):
db.session.rollback()
return jsonify({'error': f'密码修改失败: {str(e)}'}), 500
@auth_bp.route('/change-email', methods = ['POST'])
@auth_bp.route('/change-email', methods=['POST'])
@int_jwt_required
def change_email(current_user_id):
"""修改邮箱"""
try:
user = User.query.filter_by(current_user_id)
user = User.query.filter_by(user_id=current_user_id).first()
if not user:
return jsonify({'error': '用户不存在'}), 404
data = request.get_json()
new_email = data.get('new_email')
code = data.get('code')
@ -173,10 +179,15 @@ def change_email(current_user_id):
if not new_email:
return jsonify({'error': '新邮箱不能为空'}), 400
if not User.query.filter(new_email).first():
return jsonify({'error':'该邮箱已被使用'}), 400
# 验证邮箱格式
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(email_pattern, new_email):
return jsonify({'error': '邮箱格式不正确'}), 400
if User.query.filter_by(email=new_email).first():
return jsonify({'error': '该邮箱已被使用'}), 400
if not code or not verify_code(email, code, purpose='register'):
if not code or not verify_code(new_email, code, purpose='change_email'):
return jsonify({'error': '验证码无效或已过期'}), 400
user.email = new_email
@ -187,22 +198,25 @@ def change_email(current_user_id):
db.session.rollback()
return jsonify({'error': f'邮箱修改失败: {str(e)}'}), 500
@auth_bp.route('/change-username', methods = ['POST'])
@auth_bp.route('/change-username', methods=['POST'])
@int_jwt_required
def change_username(current_user_id):
"""修改用户名"""
try:
user = User.query.filter_by(current_user_id)
user = User.query.filter_by(user_id=current_user_id).first()
if not user:
return jsonify({'error': '用户不存在'}), 404
data = request.get_json()
new_username = data.get('new_username')
if not new_username:
return jsonify({'error': '新名称不能为空'}), 400
if not User.query.filter(new_username).first():
return jsonify({'error':'该用户名已被使用'}), 400
if User.query.filter_by(username=new_username).first():
return jsonify({'error': '该用户名已被使用'}), 400
user.name = new_username
user.username = new_username
db.session.commit()
return jsonify({'message': '用户名修改成功'}), 200

@ -4,11 +4,12 @@
"""
import os
import base64
from flask import Blueprint, request, jsonify, send_file
import uuid
from flask import Blueprint, request, jsonify, send_file, Response
from app.controllers.auth_controller import int_jwt_required
from app.services.task_service import TaskService
from app.services.image_service import ImageService
from app.services.image.image_serializer import get_image_serializer
from app.database import Image, ImageType
image_bp = Blueprint('image', __name__)
@ -40,9 +41,10 @@ def upload_original_images(current_user_id):
status_code = 500
return ImageService.json_error(result, status_code)
serializer = get_image_serializer()
return jsonify({
'message': '图片上传成功',
'images': [ImageService.image_to_base64(img) for img in result],
'images': [serializer.to_dict(img) for img in result],
'flow_id': task.flow_id
}), 201
@ -78,162 +80,7 @@ def get_image_file(image_id, current_user_id):
return send_file(image.file_path, mimetype=mimetype)
# ==================== 任务图片获取(返回 base64 ====================
@image_bp.route('/task/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_task_images(task_id, current_user_id):
"""获取任务的所有图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id)
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
image_type_code = request.args.get('type')
query = Image.query.filter_by(task_id=task_id)
if image_type_code:
image_type = ImageType.query.filter_by(image_code=image_type_code).first()
if image_type:
query = query.filter_by(image_types_id=image_type.image_types_id)
images = query.all()
return jsonify({
'task_id': task_id,
'images': [ImageService.image_to_base64(img) for img in images],
'total': len(images)
}), 200
@image_bp.route('/perturbation/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_perturbation_images(task_id, current_user_id):
"""获取加噪任务的结果图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
perturbed_type = ImageType.query.filter_by(image_code='perturbed').first()
if not perturbed_type:
return ImageService.json_error('图片类型未配置', 500)
images = Image.query.filter_by(
task_id=task_id,
image_types_id=perturbed_type.image_types_id
).all()
return jsonify({
'task_id': task_id,
'task_type': 'perturbation',
'images': [ImageService.image_to_base64(img) for img in images],
'total': len(images)
}), 200
@image_bp.route('/heatmap/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_heatmap_images(task_id, current_user_id):
"""获取热力图任务的结果图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
heatmap_type = ImageType.query.filter_by(image_code='heatmap').first()
if not heatmap_type:
return ImageService.json_error('图片类型未配置', 500)
images = Image.query.filter_by(
task_id=task_id,
image_types_id=heatmap_type.image_types_id
).all()
return jsonify({
'task_id': task_id,
'task_type': 'heatmap',
'images': [ImageService.image_to_base64(img) for img in images],
'total': len(images)
}), 200
@image_bp.route('/finetune/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_finetune_images(task_id, current_user_id):
"""获取微调任务的生成图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
if not task.finetune:
return ImageService.json_error('微调任务配置不存在', 404)
try:
source = TaskService.determine_finetune_source(task)
except ValueError as exc:
return ImageService.json_error(str(exc), 500)
result = {'task_id': task_id, 'task_type': 'finetune', 'source': source}
if source == 'perturbation':
original_gen_type = ImageType.query.filter_by(image_code='original_generate').first()
perturbed_gen_type = ImageType.query.filter_by(image_code='perturbed_generate').first()
original_images = []
perturbed_images = []
if original_gen_type:
original_images = Image.query.filter_by(
task_id=task_id,
image_types_id=original_gen_type.image_types_id
).all()
if perturbed_gen_type:
perturbed_images = Image.query.filter_by(
task_id=task_id,
image_types_id=perturbed_gen_type.image_types_id
).all()
result['original_generate'] = [ImageService.image_to_base64(img) for img in original_images]
result['perturbed_generate'] = [ImageService.image_to_base64(img) for img in perturbed_images]
result['total'] = len(original_images) + len(perturbed_images)
else:
uploaded_gen_type = ImageType.query.filter_by(image_code='uploaded_generate').first()
uploaded_images = []
if uploaded_gen_type:
uploaded_images = Image.query.filter_by(
task_id=task_id,
image_types_id=uploaded_gen_type.image_types_id
).all()
result['uploaded_generate'] = [ImageService.image_to_base64(img) for img in uploaded_images]
result['total'] = len(uploaded_images)
return jsonify(result), 200
@image_bp.route('/evaluate/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_evaluate_images(task_id, current_user_id):
"""获取评估任务的结果图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
report_type = ImageType.query.filter_by(image_code='report').first()
if not report_type:
return ImageService.json_error('图片类型未配置', 500)
images = Image.query.filter_by(
task_id=task_id,
image_types_id=report_type.image_types_id
).all()
return jsonify({
'task_id': task_id,
'task_type': 'evaluate',
'images': [ImageService.image_to_base64(img) for img in images],
'total': len(images)
}), 200
# ==================== 图片删除 ====================
@ -257,188 +104,133 @@ def delete_image(image_id, current_user_id):
return jsonify({'message': '图片删除成功'}), 200
# ==================== 统一预览接口 ====================
@image_bp.route('/preview/flow/<int:flow_id>', methods=['GET'])
@int_jwt_required
def preview_flow_images(flow_id, current_user_id):
"""
获取工作流下所有图片的统一预览接口
返回数据结构:
{
"flow_id": 123,
"original": [...], # 原图
"perturbed": [...], # 加噪图
"original_generate": [...], # 原图微调生成
"perturbed_generate": [...], # 加噪图微调生成
"uploaded_generate": [...], # 上传图微调生成
"heatmap": [...], # 热力图
"report": [...] # 评估报告图
}
"""
from app.database import Task
# 验证用户对该flow的访问权限
tasks = Task.query.filter_by(flow_id=flow_id, user_id=current_user_id).all()
if not tasks:
return ImageService.json_error('工作流不存在或无权限', 404)
# 获取所有图片类型
image_types = {
'original': ImageType.query.filter_by(image_code='original').first(),
'perturbed': ImageType.query.filter_by(image_code='perturbed').first(),
'original_generate': ImageType.query.filter_by(image_code='original_generate').first(),
'perturbed_generate': ImageType.query.filter_by(image_code='perturbed_generate').first(),
'uploaded_generate': ImageType.query.filter_by(image_code='uploaded_generate').first(),
'heatmap': ImageType.query.filter_by(image_code='heatmap').first(),
'report': ImageType.query.filter_by(image_code='report').first(),
}
# 收集所有任务ID
task_ids = [t.tasks_id for t in tasks]
result = {
'flow_id': flow_id,
'original': [],
'perturbed': [],
'original_generate': [],
'perturbed_generate': [],
'uploaded_generate': [],
'heatmap': [],
'report': []
}
# 查询各类型图片
for type_code, image_type in image_types.items():
if image_type:
images = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == image_type.image_types_id
).all()
result[type_code] = [ImageService.image_to_base64(img) for img in images if img]
# 统计总数
result['total'] = sum(len(result[k]) for k in result if k not in ['flow_id', 'total'])
return jsonify(result), 200
@image_bp.route('/preview/task/<int:task_id>', methods=['GET'])
# ==================== 二进制流式传输接口 ====================
@image_bp.route('/binary/task/<int:task_id>', methods=['GET'])
@int_jwt_required
def preview_task_images(task_id, current_user_id):
def get_task_images_binary(task_id, current_user_id):
"""
获取单个任务的所有图片预览
根据任务类型返回相应的图片:
- perturbation: 原图 + 加噪图
- finetune: 原图 + 生成图(original_generate/perturbed_generate/uploaded_generate)
- heatmap: 原图 + 加噪图 + 热力图
- evaluate: 生成图 + 报告图
multipart/mixed 格式流式返回任务的所有图片二进制数据
Query参数:
type: 可选指定图片类型代码
响应格式: multipart/mixed
每个part包含:
- Content-Type: 图片MIME类型
- Content-Disposition: 文件名
- X-Image-Id: 图片ID
- X-Image-Type: 图片类型代码
- X-Image-Width: 宽度
- X-Image-Height: 高度
- 图片二进制数据
"""
from app.database import Task, TaskType
task = TaskService.load_task_for_user(task_id, current_user_id)
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
task_type_code = TaskService.get_task_type_code(task)
image_type_code = request.args.get('type')
result = {
'task_id': task_id,
'flow_id': task.flow_id,
'task_type': task_type_code,
'images': {}
}
query = Image.query.filter_by(task_id=task_id)
if image_type_code:
image_type = ImageType.query.filter_by(image_code=image_type_code).first()
if image_type:
query = query.filter_by(image_types_id=image_type.image_types_id)
# 根据任务类型获取相关图片
if task_type_code == 'perturbation':
result['images'] = ImageService._get_perturbation_preview(task)
elif task_type_code == 'finetune':
result['images'] = ImageService._get_finetune_preview(task)
elif task_type_code == 'heatmap':
result['images'] = ImageService._get_heatmap_preview(task)
elif task_type_code == 'evaluate':
result['images'] = ImageService._get_evaluate_preview(task)
images = query.all()
return jsonify(result), 200
if not images:
return ImageService.json_error('没有找到图片', 404)
# 按类型分组
images_dict = {}
for img in images:
type_code = img.image_type.image_code if img.image_type else 'unknown'
if type_code not in images_dict:
images_dict[type_code] = []
images_dict[type_code].append(img)
boundary = uuid.uuid4().hex
serializer = get_image_serializer()
return Response(
serializer.generate_multipart_stream(images_dict, boundary),
mimetype=f'multipart/mixed; boundary={boundary}',
headers={
'X-Total-Images': str(len(images)),
'X-Task-Id': str(task_id)
}
)
@image_bp.route('/preview/compare/<int:flow_id>', methods=['GET'])
@image_bp.route('/binary/flow/<int:flow_id>', methods=['GET'])
@int_jwt_required
def preview_compare_images(flow_id, current_user_id):
def get_flow_images_binary(flow_id, current_user_id):
"""
获取对比预览数据用于展示原图vs加噪图原图生成vs加噪图生成的对比
multipart/mixed 格式流式返回工作流的所有图片二进制数据
返回配对的图片数据便于前端展示对比效果
Query参数:
types: 可选逗号分隔的图片类型代码列表
响应格式: multipart/mixed
"""
from app.database import Task
# 验证权限
tasks = Task.query.filter_by(flow_id=flow_id, user_id=current_user_id).all()
if not tasks:
return ImageService.json_error('工作流不存在或无权限', 404)
task_ids = [t.tasks_id for t in tasks]
# 获取图片类型
original_type = ImageType.query.filter_by(image_code='original').first()
perturbed_type = ImageType.query.filter_by(image_code='perturbed').first()
original_gen_type = ImageType.query.filter_by(image_code='original_generate').first()
perturbed_gen_type = ImageType.query.filter_by(image_code='perturbed_generate').first()
# 解析请求的图片类型
type_codes = request.args.get('types', '').split(',') if request.args.get('types') else None
result = {
'flow_id': flow_id,
'perturbation_pairs': [], # 原图 vs 加噪图
'generation_pairs': [] # 原图生成 vs 加噪图生成
}
# 构建查询
query = Image.query.filter(Image.task_id.in_(task_ids))
# 构建原图vs加噪图对比
if original_type and perturbed_type:
originals = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == original_type.image_types_id
).all()
for orig in originals:
# 查找对应的加噪图通过father_id关联
perturbed = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == perturbed_type.image_types_id,
Image.father_id == orig.images_id
).first()
if perturbed:
result['perturbation_pairs'].append({
'original': ImageService.image_to_base64(orig),
'perturbed': ImageService.image_to_base64(perturbed)
})
if type_codes:
type_ids = []
for code in type_codes:
code = code.strip()
if code:
img_type = ImageType.query.filter_by(image_code=code).first()
if img_type:
type_ids.append(img_type.image_types_id)
if type_ids:
query = query.filter(Image.image_types_id.in_(type_ids))
# 构建生成图对比(按文件名匹配)
if original_gen_type and perturbed_gen_type:
original_gens = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == original_gen_type.image_types_id
).all()
perturbed_gens = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == perturbed_gen_type.image_types_id
).all()
# 按文件名建立映射
perturbed_map = {img.stored_filename: img for img in perturbed_gens}
for orig_gen in original_gens:
perturbed_gen = perturbed_map.get(orig_gen.stored_filename)
if perturbed_gen:
result['generation_pairs'].append({
'original_generate': ImageService.image_to_base64(orig_gen),
'perturbed_generate': ImageService.image_to_base64(perturbed_gen)
})
images = query.all()
return jsonify(result), 200
if not images:
return ImageService.json_error('没有找到图片', 404)
# 按类型分组
images_dict = {}
for img in images:
type_code = img.image_type.image_code if img.image_type else 'unknown'
if type_code not in images_dict:
images_dict[type_code] = []
images_dict[type_code].append(img)
boundary = uuid.uuid4().hex
serializer = get_image_serializer()
return Response(
serializer.generate_multipart_stream(images_dict, boundary),
mimetype=f'multipart/mixed; boundary={boundary}',
headers={
'X-Total-Images': str(len(images)),
'X-Flow-Id': str(flow_id)
}
)
""" 前端解析预览图片方式
const response = await fetch(`/api/image/binary/task/${taskId}`);
const contentType = response.headers.get('content-type');
const boundary = contentType.match(/boundary=(.+)/)[1];
const buffer = await response.arrayBuffer();
// --boundary 分割解析每个 part
"""

@ -0,0 +1,63 @@
"""
Repository
提供数据访问抽象将数据库操作从 Service 层分离
使用方式:
from app.repositories import TaskRepository, ImageRepository
task_repo = TaskRepository()
task = task_repo.get_by_id(task_id)
image_repo = ImageRepository()
images = image_repo.get_by_task(task_id)
设计原则:
- 单一职责每个 Repository 只负责一个实体的数据访问
- 依赖倒置Service 层依赖 Repository 抽象
- 开闭原则通过继承 BaseRepository 扩展新实体
"""
from .base_repository import BaseRepository
from .task_repository import (
TaskRepository,
PerturbationRepository,
FinetuneRepository,
HeatmapRepository,
EvaluateRepository,
EvaluationResultRepository,
)
from .image_repository import ImageRepository
from .user_repository import UserRepository, UserConfigRepository, RoleRepository
from .config_repository import (
TaskTypeRepository,
TaskStatusRepository,
ImageTypeRepository,
PerturbationConfigRepository,
FinetuneConfigRepository,
DataTypeRepository,
)
__all__ = [
# Base
'BaseRepository',
# Task
'TaskRepository',
'PerturbationRepository',
'FinetuneRepository',
'HeatmapRepository',
'EvaluateRepository',
'EvaluationResultRepository',
# Image
'ImageRepository',
# User
'UserRepository',
'UserConfigRepository',
'RoleRepository',
# Config
'TaskTypeRepository',
'TaskStatusRepository',
'ImageTypeRepository',
'PerturbationConfigRepository',
'FinetuneConfigRepository',
'DataTypeRepository',
]

@ -0,0 +1,151 @@
"""
Repository 基类
提供通用的 CRUD 操作子类可以扩展特定实体的查询方法
"""
import logging
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Optional, List, Any, Type
from app import db
logger = logging.getLogger(__name__)
T = TypeVar('T')
class BaseRepository(ABC, Generic[T]):
"""
Repository 抽象基类
提供通用的数据访问方法:
- get_by_id: 根据 ID 获取单个实体
- get_all: 获取所有实体
- create: 创建新实体
- update: 更新实体
- delete: 删除实体
- save: 保存更改
子类需实现:
- _get_model_class(): 返回对应的 SQLAlchemy Model
"""
@abstractmethod
def _get_model_class(self) -> Type[T]:
"""返回对应的 Model 类"""
pass
@abstractmethod
def _get_primary_key_name(self) -> str:
"""返回主键字段名"""
pass
def get_by_id(self, entity_id: int) -> Optional[T]:
"""根据 ID 获取实体"""
return self._get_model_class().query.get(entity_id)
def get_all(self) -> List[T]:
"""获取所有实体"""
return self._get_model_class().query.all()
def find_by(self, **kwargs) -> List[T]:
"""根据条件查询"""
return self._get_model_class().query.filter_by(**kwargs).all()
def find_one_by(self, **kwargs) -> Optional[T]:
"""根据条件查询单个实体"""
return self._get_model_class().query.filter_by(**kwargs).first()
def exists(self, entity_id: int) -> bool:
"""检查实体是否存在"""
return self.get_by_id(entity_id) is not None
def count(self, **kwargs) -> int:
"""统计数量"""
if kwargs:
return self._get_model_class().query.filter_by(**kwargs).count()
return self._get_model_class().query.count()
def create(self, **kwargs) -> T:
"""
创建新实体不自动提交
Args:
**kwargs: 实体属性
Returns:
创建的实体对象
"""
model_class = self._get_model_class()
entity = model_class(**kwargs)
db.session.add(entity)
return entity
def add(self, entity: T) -> T:
"""
添加实体到 session不自动提交
Args:
entity: 实体对象
Returns:
添加的实体对象
"""
db.session.add(entity)
return entity
def delete(self, entity: T) -> bool:
"""
删除实体不自动提交
Args:
entity: 要删除的实体
Returns:
是否成功
"""
try:
db.session.delete(entity)
return True
except Exception as e:
logger.error(f"Delete failed: {e}")
return False
def delete_by_id(self, entity_id: int) -> bool:
"""
根据 ID 删除实体不自动提交
Args:
entity_id: 实体 ID
Returns:
是否成功
"""
entity = self.get_by_id(entity_id)
if entity:
return self.delete(entity)
return False
def save(self) -> bool:
"""
提交当前事务
Returns:
是否成功
"""
try:
db.session.commit()
return True
except Exception as e:
db.session.rollback()
logger.error(f"Save failed: {e}")
return False
def rollback(self):
"""回滚当前事务"""
db.session.rollback()
def refresh(self, entity: T) -> T:
"""刷新实体状态"""
db.session.refresh(entity)
return entity

@ -0,0 +1,154 @@
"""
配置字典表 Repository
负责各种配置表的数据访问
- TaskType: 任务类型
- TaskStatus: 任务状态
- ImageType: 图片类型
- PerturbationConfig: 加噪算法配置
- FinetuneConfig: 微调配置
- DataType: 数据集类型
"""
import logging
from typing import Optional, List, Type
from app.database import (
TaskType, TaskStatus, ImageType,
PerturbationConfig, FinetuneConfig, DataType
)
from .base_repository import BaseRepository
logger = logging.getLogger(__name__)
class TaskTypeRepository(BaseRepository[TaskType]):
"""任务类型数据访问"""
def _get_model_class(self) -> Type[TaskType]:
return TaskType
def _get_primary_key_name(self) -> str:
return 'task_type_id'
def get_by_code(self, code: str) -> Optional[TaskType]:
"""根据代码获取任务类型"""
return TaskType.query.filter_by(task_type_code=code).first()
def require(self, code: str) -> TaskType:
"""获取任务类型,不存在则抛出异常"""
task_type = self.get_by_code(code)
if not task_type:
raise ValueError(f"Task type '{code}' is not configured")
return task_type
class TaskStatusRepository(BaseRepository[TaskStatus]):
"""任务状态数据访问"""
def _get_model_class(self) -> Type[TaskStatus]:
return TaskStatus
def _get_primary_key_name(self) -> str:
return 'task_status_id'
def get_by_code(self, code: str) -> Optional[TaskStatus]:
"""根据代码获取任务状态"""
return TaskStatus.query.filter_by(task_status_code=code).first()
def require(self, code: str) -> TaskStatus:
"""获取任务状态,不存在则抛出异常"""
status = self.get_by_code(code)
if not status:
raise ValueError(f"Task status '{code}' is not configured")
return status
def get_waiting(self) -> Optional[TaskStatus]:
"""获取等待状态"""
return self.get_by_code('waiting')
def get_processing(self) -> Optional[TaskStatus]:
"""获取处理中状态"""
return self.get_by_code('processing')
def get_completed(self) -> Optional[TaskStatus]:
"""获取完成状态"""
return self.get_by_code('completed')
def get_failed(self) -> Optional[TaskStatus]:
"""获取失败状态"""
return self.get_by_code('failed')
class ImageTypeRepository(BaseRepository[ImageType]):
"""图片类型数据访问"""
def _get_model_class(self) -> Type[ImageType]:
return ImageType
def _get_primary_key_name(self) -> str:
return 'image_types_id'
def get_by_code(self, code: str) -> Optional[ImageType]:
"""根据代码获取图片类型"""
return ImageType.query.filter_by(image_code=code).first()
def require(self, code: str) -> ImageType:
"""获取图片类型,不存在则抛出异常"""
image_type = self.get_by_code(code)
if not image_type:
raise ValueError(f"Image type '{code}' is not configured")
return image_type
class PerturbationConfigRepository(BaseRepository[PerturbationConfig]):
"""加噪算法配置数据访问"""
def _get_model_class(self) -> Type[PerturbationConfig]:
return PerturbationConfig
def _get_primary_key_name(self) -> str:
return 'perturbation_configs_id'
def get_by_code(self, code: str) -> Optional[PerturbationConfig]:
"""根据代码获取加噪配置"""
return PerturbationConfig.query.filter_by(perturbation_code=code).first()
def get_all_active(self) -> List[PerturbationConfig]:
"""获取所有可用的加噪配置"""
return self.get_all()
class FinetuneConfigRepository(BaseRepository[FinetuneConfig]):
"""微调配置数据访问"""
def _get_model_class(self) -> Type[FinetuneConfig]:
return FinetuneConfig
def _get_primary_key_name(self) -> str:
return 'finetune_configs_id'
def get_by_code(self, code: str) -> Optional[FinetuneConfig]:
"""根据代码获取微调配置"""
return FinetuneConfig.query.filter_by(finetune_code=code).first()
def get_all_active(self) -> List[FinetuneConfig]:
"""获取所有可用的微调配置"""
return self.get_all()
class DataTypeRepository(BaseRepository[DataType]):
"""数据集类型数据访问"""
def _get_model_class(self) -> Type[DataType]:
return DataType
def _get_primary_key_name(self) -> str:
return 'data_type_id'
def get_by_code(self, code: str) -> Optional[DataType]:
"""根据代码获取数据集类型"""
return DataType.query.filter_by(data_type_code=code).first()
def get_all_active(self) -> List[DataType]:
"""获取所有可用的数据集类型"""
return self.get_all()

@ -0,0 +1,121 @@
"""
图片 Repository
负责 Image 实体的数据访问
"""
import logging
from typing import Optional, List, Type
from app.database import Image, ImageType
from .base_repository import BaseRepository
logger = logging.getLogger(__name__)
class ImageRepository(BaseRepository[Image]):
"""
图片数据访问
提供图片相关的查询和操作方法
"""
def _get_model_class(self) -> Type[Image]:
return Image
def _get_primary_key_name(self) -> str:
return 'images_id'
# ==================== 查询方法 ====================
def get_by_task(self, task_id: int) -> List[Image]:
"""获取任务的所有图片"""
return Image.query.filter_by(task_id=task_id).all()
def get_by_task_and_type(self, task_id: int, type_code: str) -> List[Image]:
"""获取任务指定类型的图片"""
image_type = ImageType.query.filter_by(image_code=type_code).first()
if not image_type:
return []
return Image.query.filter_by(
task_id=task_id,
image_types_id=image_type.image_types_id
).all()
def get_first_by_task_and_type(self, task_id: int, type_code: str) -> Optional[Image]:
"""获取任务指定类型的第一张图片"""
images = self.get_by_task_and_type(task_id, type_code)
return images[0] if images else None
def get_by_type(self, type_code: str) -> List[Image]:
"""获取指定类型的所有图片"""
image_type = ImageType.query.filter_by(image_code=type_code).first()
if not image_type:
return []
return Image.query.filter_by(image_types_id=image_type.image_types_id).all()
def get_children(self, parent_id: int) -> List[Image]:
"""获取子图片(派生图片)"""
return Image.query.filter_by(father_id=parent_id).all()
def get_parent(self, image_id: int) -> Optional[Image]:
"""获取父图片"""
image = self.get_by_id(image_id)
if image and image.father_id:
return self.get_by_id(image.father_id)
return None
def get_by_path(self, file_path: str) -> Optional[Image]:
"""根据文件路径获取图片"""
return Image.query.filter_by(file_path=file_path).first()
def get_by_filename(self, filename: str) -> Optional[Image]:
"""根据存储文件名获取图片"""
return Image.query.filter_by(stored_filename=filename).first()
# ==================== 统计方法 ====================
def count_by_task(self, task_id: int) -> int:
"""统计任务的图片数量"""
return Image.query.filter_by(task_id=task_id).count()
def count_by_task_and_type(self, task_id: int, type_code: str) -> int:
"""统计任务指定类型的图片数量"""
image_type = ImageType.query.filter_by(image_code=type_code).first()
if not image_type:
return 0
return Image.query.filter_by(
task_id=task_id,
image_types_id=image_type.image_types_id
).count()
# ==================== 权限验证 ====================
def is_owner(self, image: Image, user_id: int) -> bool:
"""验证图片归属(通过关联的任务)"""
if image and image.task:
return image.task.user_id == user_id
return False
def get_for_user(self, image_id: int, user_id: int) -> Optional[Image]:
"""获取用户的图片(带权限验证)"""
image = self.get_by_id(image_id)
if self.is_owner(image, user_id):
return image
return None
# ==================== 批量操作 ====================
def delete_by_task(self, task_id: int) -> int:
"""删除任务的所有图片记录"""
images = self.get_by_task(task_id)
count = 0
for image in images:
if self.delete(image):
count += 1
return count
def get_type_code(self, image: Image) -> Optional[str]:
"""获取图片类型代码"""
if image and image.image_type:
return image.image_type.image_code
return None

@ -0,0 +1,232 @@
"""
任务 Repository
负责 Task 及其子表Perturbation, Finetune, Heatmap, Evaluate的数据访问
"""
import logging
from typing import Optional, List, Type
from datetime import datetime
from app.database import (
Task, Perturbation, Finetune, Heatmap, Evaluate,
TaskType, TaskStatus, EvaluationResult
)
from .base_repository import BaseRepository
logger = logging.getLogger(__name__)
class TaskRepository(BaseRepository[Task]):
"""
任务数据访问
提供任务相关的查询和操作方法
"""
def _get_model_class(self) -> Type[Task]:
return Task
def _get_primary_key_name(self) -> str:
return 'tasks_id'
# ==================== 查询方法 ====================
def get_by_user(self, user_id: int) -> List[Task]:
"""获取用户的所有任务"""
return Task.query.filter_by(user_id=user_id).all()
def get_by_user_and_type(self, user_id: int, type_code: str) -> List[Task]:
"""获取用户指定类型的任务"""
task_type = TaskType.query.filter_by(task_type_code=type_code).first()
if not task_type:
return []
return Task.query.filter_by(
user_id=user_id,
tasks_type_id=task_type.task_type_id
).all()
def get_by_flow(self, flow_id: int) -> List[Task]:
"""获取同一工作流的所有任务"""
return Task.query.filter_by(flow_id=flow_id).all()
def get_by_flow_and_type(self, flow_id: int, type_code: str) -> Optional[Task]:
"""获取工作流中指定类型的任务"""
task_type = TaskType.query.filter_by(task_type_code=type_code).first()
if not task_type:
return None
return Task.query.filter_by(
flow_id=flow_id,
tasks_type_id=task_type.task_type_id
).first()
def get_by_status(self, status_code: str) -> List[Task]:
"""获取指定状态的任务"""
status = TaskStatus.query.filter_by(task_status_code=status_code).first()
if not status:
return []
return Task.query.filter_by(tasks_status_id=status.task_status_id).all()
def get_user_tasks_by_status(self, user_id: int, status_code: str) -> List[Task]:
"""获取用户指定状态的任务"""
status = TaskStatus.query.filter_by(task_status_code=status_code).first()
if not status:
return []
return Task.query.filter_by(
user_id=user_id,
tasks_status_id=status.task_status_id
).all()
def get_pending_tasks(self, user_id: int) -> List[Task]:
"""获取用户待处理的任务waiting + processing"""
waiting = TaskStatus.query.filter_by(task_status_code='waiting').first()
processing = TaskStatus.query.filter_by(task_status_code='processing').first()
status_ids = []
if waiting:
status_ids.append(waiting.task_status_id)
if processing:
status_ids.append(processing.task_status_id)
if not status_ids:
return []
return Task.query.filter(
Task.user_id == user_id,
Task.tasks_status_id.in_(status_ids)
).all()
def count_pending_tasks(self, user_id: int) -> int:
"""统计用户待处理任务数"""
return len(self.get_pending_tasks(user_id))
# ==================== 状态更新 ====================
def update_status(self, task: Task, status_code: str) -> bool:
"""
更新任务状态
Args:
task: 任务对象
status_code: 状态代码
Returns:
是否成功
"""
status = TaskStatus.query.filter_by(task_status_code=status_code).first()
if not status:
logger.error(f"Status '{status_code}' not found")
return False
task.tasks_status_id = status.task_status_id
# 自动更新时间戳
if status_code == 'processing':
task.started_at = datetime.utcnow()
elif status_code in ('completed', 'failed'):
task.finished_at = datetime.utcnow()
return True
def set_error(self, task: Task, error_message: str) -> bool:
"""设置任务错误信息并标记为失败"""
task.error_message = error_message
return self.update_status(task, 'failed')
# ==================== 权限验证 ====================
def is_owner(self, task: Task, user_id: int) -> bool:
"""验证任务归属"""
return task is not None and task.user_id == user_id
def get_for_user(self, task_id: int, user_id: int) -> Optional[Task]:
"""获取用户的任务(带权限验证)"""
task = self.get_by_id(task_id)
if self.is_owner(task, user_id):
return task
return None
# ==================== 任务类型判断 ====================
def get_type_code(self, task: Task) -> Optional[str]:
"""获取任务类型代码"""
if task and task.task_type:
return task.task_type.task_type_code
return None
def is_type(self, task: Task, type_code: str) -> bool:
"""判断任务是否为指定类型"""
return self.get_type_code(task) == type_code
class PerturbationRepository(BaseRepository[Perturbation]):
"""加噪任务详情数据访问"""
def _get_model_class(self) -> Type[Perturbation]:
return Perturbation
def _get_primary_key_name(self) -> str:
return 'tasks_id'
def get_by_task(self, task_id: int) -> Optional[Perturbation]:
"""根据任务 ID 获取加噪详情"""
return Perturbation.query.filter_by(tasks_id=task_id).first()
class FinetuneRepository(BaseRepository[Finetune]):
"""微调任务详情数据访问"""
def _get_model_class(self) -> Type[Finetune]:
return Finetune
def _get_primary_key_name(self) -> str:
return 'tasks_id'
def get_by_task(self, task_id: int) -> Optional[Finetune]:
"""根据任务 ID 获取微调详情"""
return Finetune.query.filter_by(tasks_id=task_id).first()
class HeatmapRepository(BaseRepository[Heatmap]):
"""热力图任务详情数据访问"""
def _get_model_class(self) -> Type[Heatmap]:
return Heatmap
def _get_primary_key_name(self) -> str:
return 'tasks_id'
def get_by_task(self, task_id: int) -> Optional[Heatmap]:
"""根据任务 ID 获取热力图详情"""
return Heatmap.query.filter_by(tasks_id=task_id).first()
def get_by_image(self, image_id: int) -> List[Heatmap]:
"""根据图片 ID 获取相关热力图任务"""
return Heatmap.query.filter_by(images_id=image_id).all()
class EvaluateRepository(BaseRepository[Evaluate]):
"""评估任务详情数据访问"""
def _get_model_class(self) -> Type[Evaluate]:
return Evaluate
def _get_primary_key_name(self) -> str:
return 'tasks_id'
def get_by_task(self, task_id: int) -> Optional[Evaluate]:
"""根据任务 ID 获取评估详情"""
return Evaluate.query.filter_by(tasks_id=task_id).first()
def get_by_finetune(self, finetune_task_id: int) -> Optional[Evaluate]:
"""根据微调任务 ID 获取评估任务"""
return Evaluate.query.filter_by(finetune_task_id=finetune_task_id).first()
class EvaluationResultRepository(BaseRepository[EvaluationResult]):
"""评估结果数据访问"""
def _get_model_class(self) -> Type[EvaluationResult]:
return EvaluationResult
def _get_primary_key_name(self) -> str:
return 'evaluation_results_id'

@ -0,0 +1,130 @@
"""
用户 Repository
负责 User UserConfig 实体的数据访问
"""
import logging
from typing import Optional, List, Type
from app.database import User, UserConfig, Role
from .base_repository import BaseRepository
logger = logging.getLogger(__name__)
class UserRepository(BaseRepository[User]):
"""
用户数据访问
提供用户相关的查询和操作方法
"""
def _get_model_class(self) -> Type[User]:
return User
def _get_primary_key_name(self) -> str:
return 'user_id'
# ==================== 查询方法 ====================
def get_by_username(self, username: str) -> Optional[User]:
"""根据用户名获取用户"""
return User.query.filter_by(username=username).first()
def get_by_email(self, email: str) -> Optional[User]:
"""根据邮箱获取用户"""
return User.query.filter_by(email=email).first()
def get_active_users(self) -> List[User]:
"""获取所有激活的用户"""
return User.query.filter_by(is_active=True).all()
def get_by_role(self, role_code: str) -> List[User]:
"""获取指定角色的用户"""
role = Role.query.filter_by(role_code=role_code).first()
if not role:
return []
return User.query.filter_by(role_id=role.role_id).all()
# ==================== 验证方法 ====================
def username_exists(self, username: str) -> bool:
"""检查用户名是否存在"""
return self.get_by_username(username) is not None
def email_exists(self, email: str) -> bool:
"""检查邮箱是否存在"""
return self.get_by_email(email) is not None
def authenticate(self, username: str, password: str) -> Optional[User]:
"""
验证用户凭据
Args:
username: 用户名
password: 密码
Returns:
验证成功返回用户对象否则返回 None
"""
user = self.get_by_username(username)
if user and user.is_active and user.check_password(password):
return user
return None
# ==================== 角色相关 ====================
def get_role_code(self, user: User) -> Optional[str]:
"""获取用户角色代码"""
if user and user.role:
return user.role.role_code
return None
def is_admin(self, user: User) -> bool:
"""判断是否为管理员"""
return self.get_role_code(user) == 'admin'
def is_vip(self, user: User) -> bool:
"""判断是否为 VIP"""
return self.get_role_code(user) == 'vip'
def get_max_concurrent_tasks(self, user: User) -> int:
"""获取用户最大并发任务数"""
if user and user.role:
return user.role.max_concurrent_tasks or 1
return 1
class UserConfigRepository(BaseRepository[UserConfig]):
"""用户配置数据访问"""
def _get_model_class(self) -> Type[UserConfig]:
return UserConfig
def _get_primary_key_name(self) -> str:
return 'user_configs_id'
def get_by_user(self, user_id: int) -> Optional[UserConfig]:
"""根据用户 ID 获取配置"""
return UserConfig.query.filter_by(user_id=user_id).first()
def get_or_create(self, user_id: int) -> UserConfig:
"""获取或创建用户配置"""
config = self.get_by_user(user_id)
if not config:
config = self.create(user_id=user_id)
return config
class RoleRepository(BaseRepository[Role]):
"""角色数据访问"""
def _get_model_class(self) -> Type[Role]:
return Role
def _get_primary_key_name(self) -> str:
return 'role_id'
def get_by_code(self, role_code: str) -> Optional[Role]:
"""根据角色代码获取角色"""
return Role.query.filter_by(role_code=role_code).first()

@ -0,0 +1,4 @@
"""缓存服务模块"""
from .redis_client import RedisClient
__all__ = ['RedisClient']

@ -0,0 +1,66 @@
"""
Redis 客户端封装单例模式
职责单一只负责 Redis 连接管理和基础操作
"""
import logging
from typing import Optional
import redis
from flask import current_app
logger = logging.getLogger(__name__)
class RedisClient:
"""
Redis 客户端单例
使用方式:
client = RedisClient()
client.set('key', 'value', ex=300)
value = client.get('key')
"""
_instance: Optional['RedisClient'] = None
_pools: dict[str, redis.ConnectionPool] = {}
def __new__(cls) -> 'RedisClient':
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def _get_connection(self) -> redis.Redis:
"""获取 Redis 连接,复用连接池"""
redis_url = current_app.config.get('REDIS_URL', 'redis://localhost:6379/0')
if redis_url not in self._pools:
self._pools[redis_url] = redis.ConnectionPool.from_url(
redis_url,
decode_responses=True
)
return redis.Redis(connection_pool=self._pools[redis_url])
def get(self, key: str) -> Optional[str]:
"""获取值"""
try:
return self._get_connection().get(key)
except Exception:
logger.exception(f"Redis GET 失败: {key}")
return None
def set(self, key: str, value: str, ex: Optional[int] = None) -> bool:
"""设置值,可选过期时间(秒)"""
try:
self._get_connection().set(key, value, ex=ex)
return True
except Exception:
logger.exception(f"Redis SET 失败: {key}")
return False
def delete(self, key: str) -> bool:
"""删除键"""
try:
return self._get_connection().delete(key) == 1
except Exception:
logger.exception(f"Redis DELETE 失败: {key}")
return False

Binary file not shown.

@ -0,0 +1,5 @@
"""邮件服务模块"""
from .email_sender import EmailSender
from .verification_service import VerificationService
__all__ = ['EmailSender', 'VerificationService']

@ -0,0 +1,67 @@
"""
邮件发送服务
职责单一只负责发送邮件
"""
import logging
from typing import Optional
from flask import current_app
from flask_mail import Message
logger = logging.getLogger(__name__)
class EmailSender:
"""
邮件发送器
使用方式:
sender = EmailSender()
sender.send('user@example.com', '标题', '内容')
"""
def send(
self,
to: str,
subject: str,
body: str,
html: Optional[str] = None
) -> bool:
"""
发送邮件
Args:
to: 收件人邮箱
subject: 邮件主题
body: 纯文本内容
html: HTML 内容可选
Returns:
是否发送成功
"""
try:
mail = current_app.extensions.get('mail')
if mail is None:
logger.error('Flask-Mail 未初始化,无法发送邮件')
return False
sender = (
current_app.config.get('MAIL_DEFAULT_SENDER') or
current_app.config.get('MAIL_USERNAME')
)
msg = Message(
subject=subject,
recipients=[to],
body=body,
html=html,
sender=sender
)
mail.send(msg)
logger.info(f'邮件发送成功: {to}')
return True
except Exception:
logger.exception(f'邮件发送失败: {to}')
return False

@ -0,0 +1,211 @@
"""
验证码服务
职责验证码的生成存储校验
通过组合 RedisClient EmailSender 实现依赖注入
"""
import random
import string
import logging
from typing import Optional
from flask import current_app
from app.services.cache import RedisClient
from app.services.email.email_sender import EmailSender
logger = logging.getLogger(__name__)
class VerificationService:
"""
验证码服务
使用方式:
# 方式1使用默认依赖
service = VerificationService()
# 方式2注入自定义依赖便于测试
service = VerificationService(
redis_client=mock_redis,
email_sender=mock_sender
)
# 发送验证码
service.send_code('user@example.com', purpose='register')
# 校验验证码
is_valid = service.verify_code('user@example.com', '123456', purpose='register')
"""
# 验证码 Redis key 前缀
KEY_PREFIX = 'verify'
def __init__(
self,
redis_client: Optional[RedisClient] = None,
email_sender: Optional[EmailSender] = None
):
"""
初始化验证码服务
Args:
redis_client: Redis 客户端默认使用单例
email_sender: 邮件发送器默认创建新实例
"""
self._redis = redis_client or RedisClient()
self._email = email_sender or EmailSender()
def _build_key(self, email: str, purpose: str) -> str:
"""构建 Redis 存储 key"""
return f"{self.KEY_PREFIX}:{purpose}:{email}"
@staticmethod
def _generate_code(length: int = 6) -> str:
"""生成数字验证码"""
return ''.join(random.choices(string.digits, k=length))
def _get_expire_seconds(self, custom_expire: Optional[int] = None) -> int:
"""获取过期时间(秒)"""
if custom_expire is not None:
return custom_expire
return current_app.config.get('VERIFICATION_CODE_EXPIRES', 300)
def _build_email_body(self, code: str, expire_seconds: int) -> str:
"""构建邮件内容"""
template = current_app.config.get(
'VERIFICATION_EMAIL_TEMPLATE',
'您的验证码为:{code},有效期 {expire_minutes} 分钟。请勿泄露给他人。'
)
return template.format(
code=code,
expire_seconds=expire_seconds,
expire_minutes=expire_seconds // 60
)
def send_code(
self,
email: str,
purpose: str = 'register',
length: int = 6,
expire_seconds: Optional[int] = None
) -> bool:
"""
生成并发送验证码
Args:
email: 目标邮箱
purpose: 用途register/reset_password/change_email
length: 验证码长度
expire_seconds: 过期时间默认从配置读取
Returns:
是否发送成功
"""
expire = self._get_expire_seconds(expire_seconds)
code = self._generate_code(length)
key = self._build_key(email, purpose)
# 存储到 Redis
if not self._redis.set(key, code, ex=expire):
logger.error(f'验证码存储失败: {email}')
return False
# 发送邮件
subject = current_app.config.get('VERIFICATION_EMAIL_SUBJECT', '您的验证码')
body = self._build_email_body(code, expire)
if not self._email.send(email, subject, body):
# 邮件发送失败,清理 Redis 中的验证码
self._redis.delete(key)
return False
logger.info(f'验证码已发送: {email} (purpose={purpose})')
return True
def verify_code(
self,
email: str,
code: str,
purpose: str = 'register',
delete_on_success: bool = True
) -> bool:
"""
校验验证码
Args:
email: 邮箱
code: 用户输入的验证码
purpose: 用途
delete_on_success: 校验成功后是否删除
Returns:
是否校验通过
"""
key = self._build_key(email, purpose)
stored = self._redis.get(key)
if stored is None:
logger.debug(f'验证码不存在或已过期: {email}')
return False
if str(stored) != str(code):
logger.debug(f'验证码不匹配: {email}')
return False
# 校验成功,删除验证码(防止重复使用)
if delete_on_success:
if not self._redis.delete(key):
logger.warning(f'验证码删除失败: {key}')
logger.info(f'验证码校验成功: {email} (purpose={purpose})')
return True
def clear_code(self, email: str, purpose: str = 'register') -> bool:
"""
清除验证码管理员操作或用户取消
Args:
email: 邮箱
purpose: 用途
Returns:
是否删除成功
"""
key = self._build_key(email, purpose)
return self._redis.delete(key)
# ============================================================
# 兼容层:保持原有函数接口,内部委托给 VerificationService
# 便于渐进式迁移,后续可逐步移除
# ============================================================
_default_service: Optional[VerificationService] = None
def _get_service() -> VerificationService:
"""获取默认服务实例(懒加载)"""
global _default_service
if _default_service is None:
_default_service = VerificationService()
return _default_service
def send_verification_code(
email: str,
purpose: str = 'register',
length: int = 6,
expire_seconds: Optional[int] = None
) -> bool:
"""【兼容接口】发送验证码"""
return _get_service().send_code(email, purpose, length, expire_seconds)
def verify_code(email: str, code: str, purpose: str = 'register') -> bool:
"""【兼容接口】校验验证码"""
return _get_service().verify_code(email, code, purpose)
def clear_verification_code(email: str, purpose: str = 'register') -> bool:
"""【兼容接口】清除验证码"""
return _get_service().clear_code(email, purpose)

@ -1,132 +0,0 @@
"""
验证码服务Redis 存储 + Flask-Mail 发送
提供
- `send_verification_code(email, purpose='register', length=6, expire_seconds=None)`
- `verify_code(email, code, purpose='register')`
依赖
- `redis`通过 `REDIS_URL` 配置默认为 redis://localhost:6379/0
- Flask-Mail 已在应用中初始化通过 `current_app.extensions['mail']` 获取
"""
import random
import string
import logging
from typing import Optional
import redis
from flask import current_app
from flask_mail import Message
logger = logging.getLogger(__name__)
pool = redis.ConnectionPool().from_url('redis://localhost:6379/0', decode_responses=True)
def _get_redis_client() -> redis.Redis:
"""根据 REDIS_URL 创建 redis 客户端"""
redis_url = current_app.config.get('REDIS_URL', 'redis://localhost:6379/0')
# 不再使用全局的 pool而是每次根据 URL 获取连接redis-py 内部自己会管理连接池)
# 或者如果你想复用连接池,应该在 app 启动时初始化 pool而不是在全局写死 localhost
return redis.Redis.from_url(redis_url, decode_responses=True)
def _generate_code(length: int = 6) -> str:
"""生成指定长度的数字验证码(默认 6 位)。"""
# 只产生数字字符串更常用于验证码
return ''.join(random.choices(string.digits, k=length))
def send_verification_code(email: str,
purpose: str = 'register',
length: int = 6,
expire_seconds: Optional[int] = None) -> bool:
"""生成验证码,保存到 Redis并使用 Flask-Mail 发送给 `email`。
返回 True 表示发送成功False 表示失败或抛出异常时捕获后返回 False
"""
# 读取过期时间(秒),优先使用传入值,其次使用 app config
if expire_seconds is None:
expire_seconds = current_app.config.get('VERIFICATION_CODE_EXPIRES', 300)
code = _generate_code(length)
key = f"verify:{purpose}:{email}"
try:
r = _get_redis_client()
# 使用字符串保存验证码,并设置过期时间
r.set(key, code, ex=expire_seconds)
except Exception as e:
logger.exception("保存验证码到 Redis 失败")
return False
# 尝试发送邮件
try:
mail = current_app.extensions.get('mail')
subject = "您的验证码" # current_app.config.get('VERIFICATION_EMAIL_SUBJECT', '您的验证码')
sender = '1798231811@qq.com' # current_app.config.get('MAIL_DEFAULT_SENDER') or current_app.config.get('MAIL_USERNAME')
body = f'您的验证码为:{code},有效期 {expire_seconds} 秒。' # current_app.config.get('VERIFICATION_EMAIL_TEMPLATE',
# f'您的验证码为:{code},有效期 {expire_seconds} 秒。')
# 优先使用简单文本邮件,项目中可按需替换为 HTML 模板
msg = Message(subject=subject, recipients=[email], body=body, sender=sender)
if mail is None:
# 如果 Flask-Mail 未初始化,记录日志并返回 False
logger.error('Flask-Mail 未初始化,无法发送邮件')
return False
mail.send(msg)
logger.info('已发送验证码到 %s (purpose=%s)', email, purpose)
return True
except Exception:
logger.exception('发送验证码邮件失败')
return False
def verify_code(email: str, code: str, purpose: str = 'register') -> bool:
"""校验验证码是否正确。成功可配置是否从 Redis 删除该 key。
返回 True 表示校验通过False 表示失败或异常
"""
key = f"verify:{purpose}:{email}"
try:
r = _get_redis_client()
stored = r.get(key)
if stored is None:
return False
matched = (str(stored) == str(code))
if matched :
try:
r.delete(key)
except Exception:
logger.warning('校验成功,但删除 Redis key 失败: %s', key)
return matched
except Exception:
logger.exception('校验验证码时发生异常')
return False
def clear_verification_code(email: str, purpose: str = 'register') -> bool:
"""显式删除指定 email 的验证码(例如用于管理员撤销)。"""
key = f"verify:{purpose}:{email}"
try:
r = _get_redis_client()
return r.delete(key) == 1
except Exception:
logger.exception('删除验证码失败')
return False
if __name__ == '__main__':
# 简单测试发送和验证功能
test_email = "3310207578@qq.com"
if send_verification_code(test_email, expire_seconds=600):
print("验证码发送成功")
code = input("请输入收到的验证码: ")
if verify_code(test_email, code):
print("验证码验证成功")
else:
print("验证码验证失败")
else:
print("验证码发送失败")

@ -0,0 +1,23 @@
"""
图片服务模块
按职责拆分为:
- ImageProcessor: 图片预处理裁剪缩放格式转换
- ImageStorage: 图片存储管理保存删除
- ImageSerializer: 图片序列化JSONBase64
- ZipService: 打包服务
- ImagePreviewService: 预览图片服务
"""
from .image_processor import ImageProcessor
from .image_storage import ImageStorage
from .image_serializer import ImageSerializer
from .zip_service import ZipService
from .image_preview import ImagePreviewService
__all__ = [
'ImageProcessor',
'ImageStorage',
'ImageSerializer',
'ZipService',
'ImagePreviewService',
]

@ -0,0 +1,125 @@
"""
图片预览服务
职责单一获取各类任务的预览图片返回图片ID列表前端通过二进制接口获取
"""
import logging
from typing import Dict, List, Optional
from app.database import Task, Image
logger = logging.getLogger(__name__)
def _get_image_repo():
"""懒加载获取 ImageRepository"""
from app.repositories import ImageRepository
return ImageRepository()
def _get_task_repo():
"""懒加载获取 TaskRepository"""
from app.repositories import TaskRepository
return TaskRepository()
class ImagePreviewService:
"""
图片预览服务
负责获取各类任务的图片ID列表前端通过 /binary/task /binary/flow 接口获取二进制数据
"""
def _image_to_meta(self, image: Image) -> Optional[Dict]:
"""将图片转换为元数据字典"""
if not image:
return None
return {
'image_id': image.images_id,
'filename': image.stored_filename,
'width': image.width,
'height': image.height
}
def get_perturbation_preview(self, task: Task) -> Dict[str, List]:
"""获取加噪任务的预览图片元数据"""
image_repo = _get_image_repo()
originals = image_repo.get_by_task_and_type(task.tasks_id, 'original')
perturbeds = image_repo.get_by_task_and_type(task.tasks_id, 'perturbed')
return {
'original': [self._image_to_meta(img) for img in originals if img],
'perturbed': [self._image_to_meta(img) for img in perturbeds if img]
}
def get_finetune_preview(self, task: Task) -> Dict[str, List]:
"""获取微调任务的预览图片元数据"""
image_repo = _get_image_repo()
task_repo = _get_task_repo()
images = {
'original': [],
'original_generate': [],
'perturbed_generate': [],
'uploaded_generate': []
}
flow_tasks = task_repo.get_by_flow(task.flow_id)
for flow_task in flow_tasks:
if flow_task.user_id == task.user_id:
originals = image_repo.get_by_task_and_type(flow_task.tasks_id, 'original')
images['original'].extend([
self._image_to_meta(img) for img in originals if img
])
for type_code in ['original_generate', 'perturbed_generate', 'uploaded_generate']:
generated = image_repo.get_by_task_and_type(task.tasks_id, type_code)
images[type_code] = [self._image_to_meta(img) for img in generated if img]
return images
def get_heatmap_preview(self, task: Task) -> Dict[str, List]:
"""获取热力图任务的预览图片元数据"""
image_repo = _get_image_repo()
heatmaps = image_repo.get_by_task_and_type(task.tasks_id, 'heatmap')
return {
'heatmap': [self._image_to_meta(img) for img in heatmaps if img]
}
def get_evaluate_preview(self, task: Task) -> Dict[str, List]:
"""获取评估任务的预览图片元数据"""
image_repo = _get_image_repo()
reports = image_repo.get_by_task_and_type(task.tasks_id, 'report')
return {
'report': [self._image_to_meta(img) for img in reports if img]
}
def get_preview_by_task_type(self, task: Task, task_type: str) -> Dict[str, List]:
"""根据任务类型获取预览图片元数据"""
handlers = {
'perturbation': self.get_perturbation_preview,
'finetune': self.get_finetune_preview,
'heatmap': self.get_heatmap_preview,
'evaluate': self.get_evaluate_preview,
}
handler = handlers.get(task_type)
if handler:
return handler(task)
return {}
# 全局单例
_default_preview_service: Optional[ImagePreviewService] = None
def get_preview_service() -> ImagePreviewService:
"""获取默认的预览服务实例"""
global _default_preview_service
if _default_preview_service is None:
_default_preview_service = ImagePreviewService()
return _default_preview_service

@ -0,0 +1,125 @@
"""
图片预处理服务
职责单一图片的裁剪缩放格式转换
"""
import logging
from typing import Tuple
from PIL import Image as PILImage
logger = logging.getLogger(__name__)
class ImageProcessor:
"""
图片预处理器
负责图片的:
- 中心裁剪
- 缩放到指定尺寸
- 格式转换
使用方式:
processor = ImageProcessor(target_size=512)
pil_image = processor.process(file_storage)
processor.save(pil_image, '/path/to/output.png')
"""
DEFAULT_SIZE = 512
DEFAULT_FORMAT = 'PNG'
def __init__(self, target_size: int = None):
"""
初始化处理器
Args:
target_size: 目标尺寸正方形边长默认 512
"""
self._target_size = target_size or self.DEFAULT_SIZE
@property
def target_size(self) -> int:
return self._target_size
def process_from_file(self, file_storage) -> PILImage.Image:
"""
从上传文件处理图片
Args:
file_storage: Flask 文件存储对象
Returns:
处理后的 PIL Image 对象
"""
file_storage.stream.seek(0)
image = PILImage.open(file_storage.stream).convert('RGB')
return self._crop_and_resize(image)
def process_from_path(self, file_path: str) -> PILImage.Image:
"""
从文件路径处理图片
Args:
file_path: 图片文件路径
Returns:
处理后的 PIL Image 对象
"""
image = PILImage.open(file_path).convert('RGB')
return self._crop_and_resize(image)
def _crop_and_resize(self, image: PILImage.Image) -> PILImage.Image:
"""
中心裁剪并缩放
Args:
image: 原始 PIL Image
Returns:
处理后的 PIL Image
"""
width, height = image.size
min_dim = min(width, height)
# 中心裁剪为正方形
left = (width - min_dim) // 2
top = (height - min_dim) // 2
image = image.crop((left, top, left + min_dim, top + min_dim))
# 缩放到目标尺寸
return image.resize(
(self._target_size, self._target_size),
resample=PILImage.Resampling.LANCZOS
)
def save(
self,
image: PILImage.Image,
output_path: str,
format: str = None,
quality: int = 95
) -> Tuple[int, int, int]:
"""
保存处理后的图片
Args:
image: PIL Image 对象
output_path: 输出路径
format: 输出格式PNG/JPEG默认从路径推断
quality: JPEG 质量仅对 JPEG 有效
Returns:
(width, height, file_size) 元组
"""
import os
if format is None:
ext = os.path.splitext(output_path)[1].lower()
format = 'JPEG' if ext in ('.jpg', '.jpeg') else 'PNG'
if format.upper() == 'JPEG':
image.save(output_path, format='JPEG', quality=quality)
else:
image.save(output_path, format=format.upper())
return image.width, image.height, os.path.getsize(output_path)

@ -0,0 +1,147 @@
"""
图片序列化服务
职责单一图片的序列化JSON二进制流
"""
import os
import logging
from typing import Optional, Dict, Any, List, Generator
from app.database import Image
logger = logging.getLogger(__name__)
class ImageSerializer:
"""
图片序列化器
负责:
- 将图片对象转换为 JSON 字典
- 生成 multipart/mixed 二进制流
使用方式:
serializer = ImageSerializer()
data = serializer.to_dict(image)
"""
MIME_TYPES = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp',
}
def to_dict(self, image: Image) -> Optional[Dict[str, Any]]:
"""
将图片对象序列化为字典
Args:
image: Image 数据库对象
Returns:
字典或 None
"""
if not image:
return None
return {
'image_id': image.images_id,
'task_id': image.task_id,
'stored_filename': image.stored_filename,
'file_path': image.file_path,
'file_size': image.file_size,
'width': image.width,
'height': image.height,
'image_type': image.image_type.image_code if image.image_type else None
}
def get_url(self, image: Image) -> Optional[str]:
"""
获取图片访问 URL
Args:
image: Image 数据库对象
Returns:
URL 字符串或 None
"""
if not image or not image.file_path:
return None
return f"/api/image/file/{image.images_id}"
def serialize_list(self, images: list) -> list:
"""
批量序列化图片列表
Args:
images: Image 对象列表
Returns:
序列化后的列表
"""
return [self.to_dict(img) for img in images if img]
def generate_multipart_stream(
self,
images_dict: Dict[str, List[Image]],
boundary: str
) -> Generator[bytes, None, None]:
"""
生成 multipart/mixed 格式的二进制流
Args:
images_dict: 图片字典 {type_code: [Image, ...]}
boundary: multipart 边界字符串
Yields:
二进制数据块
"""
for type_code, images in images_dict.items():
for image in images:
if not image or not image.file_path:
continue
if not os.path.exists(image.file_path):
logger.warning(f"图片文件不存在: {image.file_path}")
continue
ext = os.path.splitext(image.file_path)[1].lower()
mimetype = self.MIME_TYPES.get(ext, 'application/octet-stream')
# 构建 part header
header = (
f"--{boundary}\r\n"
f"Content-Type: {mimetype}\r\n"
f"Content-Disposition: attachment; filename=\"{image.stored_filename}\"\r\n"
f"X-Image-Id: {image.images_id}\r\n"
f"X-Image-Type: {type_code}\r\n"
f"X-Image-Width: {image.width or 0}\r\n"
f"X-Image-Height: {image.height or 0}\r\n"
f"\r\n"
)
yield header.encode('utf-8')
# 流式读取文件内容
with open(image.file_path, 'rb') as f:
while chunk := f.read(8192):
yield chunk
yield b"\r\n"
# 结束边界
yield f"--{boundary}--\r\n".encode('utf-8')
# 全局单例
_default_serializer: Optional[ImageSerializer] = None
def get_image_serializer() -> ImageSerializer:
"""获取默认的序列化器实例"""
global _default_serializer
if _default_serializer is None:
_default_serializer = ImageSerializer()
return _default_serializer

@ -0,0 +1,280 @@
"""
图片存储服务
职责单一图片的保存删除文件管理
"""
import os
import uuid
import logging
from typing import Optional, List, Tuple
from datetime import datetime
from app import db
from app.database import Image, ImageType
from app.utils.file_utils import allowed_file
from app.services.image.image_processor import ImageProcessor
logger = logging.getLogger(__name__)
def _get_image_repo():
"""懒加载获取 ImageRepository"""
from app.repositories import ImageRepository
return ImageRepository()
def _get_image_type_repo():
"""懒加载获取 ImageTypeRepository"""
from app.repositories import ImageTypeRepository
return ImageTypeRepository()
class ImageStorage:
"""
图片存储管理器
负责:
- 保存上传的图片到指定目录
- 创建数据库记录
- 删除图片文件和记录
- 管理临时文件
使用方式:
storage = ImageStorage()
result = storage.save_uploaded_image(file, task, target_dir)
storage.delete_image(image_id, user_id)
"""
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.gif', '.bmp', '.tiff'}
def __init__(self, processor: Optional[ImageProcessor] = None, image_repo=None):
"""
初始化存储管理器
Args:
processor: 图片处理器默认创建新实例
image_repo: 图片 Repository默认懒加载
"""
self._processor = processor or ImageProcessor()
self._image_repo = image_repo
@property
def processor(self) -> ImageProcessor:
return self._processor
@property
def image_repo(self):
"""懒加载 ImageRepository"""
if self._image_repo is None:
self._image_repo = _get_image_repo()
return self._image_repo
def save_uploaded_image(
self,
file,
task,
target_dir: str,
image_type_code: str = 'original'
) -> dict:
"""
保存单张上传的图片
Args:
file: 上传的文件对象
task: 关联的任务对象
target_dir: 目标存储目录
image_type_code: 图片类型代码
Returns:
{'success': True, 'image': Image} {'success': False, 'error': str}
"""
if not file or not file.filename:
return {'success': False, 'error': '无效的文件'}
if not allowed_file(file.filename):
return {'success': False, 'error': '不支持的文件格式'}
ext = os.path.splitext(file.filename)[1].lower()
if ext not in self.IMAGE_EXTENSIONS:
return {'success': False, 'error': f'不支持的图片格式: {ext}'}
image_type = _get_image_type_repo().get_by_code(image_type_code)
if not image_type:
return {'success': False, 'error': f'未配置图片类型: {image_type_code}'}
os.makedirs(target_dir, exist_ok=True)
try:
# 处理图片
processed = self._processor.process_from_file(file)
# 生成文件名并保存
filename, path, width, height, file_size = self._save_with_unique_name(
processed, target_dir
)
# 创建数据库记录
image = self._create_record(
task=task,
image_type_id=image_type.image_types_id,
filename=filename,
path=path,
width=width,
height=height,
file_size=file_size
)
db.session.commit()
return {'success': True, 'image': image}
except Exception as e:
db.session.rollback()
logger.error(f"保存图片失败: {e}")
return {'success': False, 'error': f'保存图片失败: {str(e)}'}
def save_multiple_images(
self,
files: List,
task,
target_dir: str,
image_type_code: str = 'original'
) -> Tuple[bool, any]:
"""
批量保存上传的图片
Args:
files: 文件列表
task: 关联的任务对象
target_dir: 目标存储目录
image_type_code: 图片类型代码
Returns:
(success, result) - result Image 列表或错误信息
"""
if not files:
return False, '未检测到文件上传'
image_type = _get_image_type_repo().get_by_code(image_type_code)
if not image_type:
return False, f'未配置图片类型: {image_type_code}'
os.makedirs(target_dir, exist_ok=True)
saved_records = []
saved_paths = []
try:
for file in files:
if not file or not file.filename:
continue
if not allowed_file(file.filename):
continue
ext = os.path.splitext(file.filename)[1].lower()
if ext not in self.IMAGE_EXTENSIONS:
continue
processed = self._processor.process_from_file(file)
filename, path, width, height, file_size = self._save_with_unique_name(
processed, target_dir
)
image = self._create_record(
task=task,
image_type_id=image_type.image_types_id,
filename=filename,
path=path,
width=width,
height=height,
file_size=file_size
)
saved_records.append(image)
saved_paths.append(path)
if not saved_records:
db.session.rollback()
return False, '未上传有效的图片文件'
db.session.commit()
return True, saved_records
except Exception as e:
db.session.rollback()
# 清理已保存的文件
for path in saved_paths:
if os.path.exists(path):
try:
os.remove(path)
except OSError:
pass
return False, f'上传图片失败: {str(e)}'
def delete_image(self, image_id: int, user_id: int) -> dict:
"""
删除图片验证权限
Args:
image_id: 图片 ID
user_id: 用户 ID用于权限验证
Returns:
{'success': True} {'success': False, 'error': str}
"""
try:
# 使用 Repository 获取并验证权限
image = self.image_repo.get_for_user(image_id, user_id)
if not image:
return {'success': False, 'error': '图片不存在或无权限'}
# 删除文件
if image.file_path and os.path.exists(image.file_path):
os.remove(image.file_path)
# 使用 Repository 删除记录
if self.image_repo.delete(image) and self.image_repo.save():
return {'success': True}
return {'success': False, 'error': '删除记录失败'}
except Exception as e:
self.image_repo.rollback()
logger.error(f"删除图片失败: {e}")
return {'success': False, 'error': f'删除图片失败: {str(e)}'}
def _save_with_unique_name(self, image, target_dir: str) -> Tuple[str, str, int, int, int]:
"""保存图片并生成唯一文件名"""
timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f')
filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png"
path = os.path.join(target_dir, filename)
width, height, file_size = self._processor.save(image, path)
return filename, path, width, height, file_size
def _create_record(
self,
task,
image_type_id: int,
filename: str,
path: str,
width: int,
height: int,
file_size: int,
father_id: int = None
) -> Image:
"""创建数据库记录"""
image = Image(
task_id=task.tasks_id,
image_types_id=image_type_id,
father_id=father_id,
stored_filename=filename,
file_path=path,
file_size=file_size,
width=width,
height=height
)
db.session.add(image)
return image
@staticmethod
def get_image_type_by_code(code: str) -> Optional[ImageType]:
"""根据代码获取图片类型(委托给 Repository"""
return _get_image_type_repo().get_by_code(code)

@ -0,0 +1,139 @@
"""
打包服务
职责单一目录和文件的 ZIP 打包
"""
import os
import io
import zipfile
import logging
from typing import Union, Dict, List, Tuple
logger = logging.getLogger(__name__)
class ZipService:
"""
ZIP 打包服务
负责:
- 将单个目录打包为 ZIP
- 将多个目录打包为 ZIP
使用方式:
zip_service = ZipService()
buffer, has_files = zip_service.zip_directory('/path/to/dir')
buffer, has_files = zip_service.zip_multiple({'label1': '/path1', 'label2': '/path2'})
"""
def zip_directory(self, directory: str) -> Tuple[io.BytesIO, bool]:
"""
将单个目录打包为 ZIP
Args:
directory: 目录路径
Returns:
(BytesIO 缓冲区, 是否包含文件)
"""
buffer = io.BytesIO()
has_files = False
with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
if os.path.isdir(directory):
for root, _, files in os.walk(directory):
for filename in files:
file_path = os.path.join(root, filename)
arcname = os.path.relpath(file_path, directory)
zipf.write(file_path, arcname)
has_files = True
buffer.seek(0)
return buffer, has_files
def zip_multiple(
self,
directories: Union[Dict[str, str], List[str]]
) -> Tuple[io.BytesIO, bool]:
"""
将多个目录打包为 ZIP
Args:
directories: 目录字典 {label: path} 或目录列表
Returns:
(BytesIO 缓冲区, 是否包含文件)
"""
buffer = io.BytesIO()
has_files = False
with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
# 统一处理字典和列表
if isinstance(directories, dict):
iterable = directories.items()
else:
iterable = (
(os.path.basename(d.rstrip(os.sep)) or 'output', d)
for d in directories
)
for label, directory in iterable:
if not os.path.isdir(directory):
continue
for root, _, files in os.walk(directory):
for filename in files:
file_path = os.path.join(root, filename)
rel_path = os.path.relpath(file_path, directory)
arcname = os.path.join(label or 'output', rel_path)
zipf.write(file_path, arcname)
has_files = True
buffer.seek(0)
return buffer, has_files
def zip_files(
self,
files: List[str],
base_dir: str = None
) -> Tuple[io.BytesIO, bool]:
"""
将文件列表打包为 ZIP
Args:
files: 文件路径列表
base_dir: 基础目录用于计算相对路径
Returns:
(BytesIO 缓冲区, 是否包含文件)
"""
buffer = io.BytesIO()
has_files = False
with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
for file_path in files:
if not os.path.isfile(file_path):
continue
if base_dir:
arcname = os.path.relpath(file_path, base_dir)
else:
arcname = os.path.basename(file_path)
zipf.write(file_path, arcname)
has_files = True
buffer.seek(0)
return buffer, has_files
# 全局单例
_default_zip_service: ZipService = None
def get_zip_service() -> ZipService:
"""获取默认的打包服务实例"""
global _default_zip_service
if _default_zip_service is None:
_default_zip_service = ZipService()
return _default_zip_service

@ -1,212 +1,227 @@
"""
图像处理服务
处理图像上传保存等功能
图像处理服务兼容入口
已重构为面向对象设计此文件保留原有接口以保持向后兼容
新代码请直接使用:
from app.services.image import (
ImageProcessor, # 图片预处理
ImageStorage, # 图片存储
ImageSerializer, # 图片序列化
ZipService, # 打包服务
ImagePreviewService # 预览服务
)
相关类:
- ImageProcessor: 裁剪缩放格式转换
- ImageStorage: 保存删除文件管理
- ImageSerializer: JSONBase64 序列化
- ZipService: ZIP 打包
- ImagePreviewService: 任务预览图片
"""
import base64
import io
import os
import uuid
import zipfile
import time
from datetime import datetime
from werkzeug.utils import secure_filename
from flask import current_app, jsonify
from PIL import Image as PILImage
import logging
from typing import Optional, List, Tuple, Dict, Any
from flask import jsonify
from app import db
from app.database import Image, ImageType
from app.utils.file_utils import allowed_file
from app.services.image.image_processor import ImageProcessor
from app.services.image.image_storage import ImageStorage
from app.services.image.image_serializer import ImageSerializer, get_image_serializer
from app.services.image.zip_service import ZipService, get_zip_service
from app.services.image.image_preview import ImagePreviewService, get_preview_service
logger = logging.getLogger(__name__)
# 全局实例
_storage: Optional[ImageStorage] = None
_serializer: Optional[ImageSerializer] = None
_zip_service: Optional[ZipService] = None
_preview_service: Optional[ImagePreviewService] = None
def _get_storage() -> ImageStorage:
global _storage
if _storage is None:
_storage = ImageStorage()
return _storage
def _get_serializer() -> ImageSerializer:
global _serializer
if _serializer is None:
_serializer = ImageSerializer()
return _serializer
def _get_zip_service() -> ZipService:
global _zip_service
if _zip_service is None:
_zip_service = ZipService()
return _zip_service
def _get_preview_service() -> ImagePreviewService:
global _preview_service
if _preview_service is None:
_preview_service = ImagePreviewService()
return _preview_service
class ImageService:
"""
图像处理服务兼容类
内部委托给新的服务类保持原有 API 不变
"""
DEFAULT_TARGET_SIZE = 512
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp'}
# ==================== 存储相关(委托给 ImageStorage====================
@staticmethod
def save_image(file, task_id, user_id, image_types_id, resolution=512, target_format='png'):
"""保存单张图片"""
# 简化实现,委托给新服务
from app.database import Task
task = Task.query.get(task_id)
if not task:
return {'success': False, 'error': '任务不存在'}
from flask import current_app
project_root = os.path.dirname(current_app.root_path)
target_dir = os.path.join(
project_root,
current_app.config.get('ORIGINAL_IMAGES_FOLDER', 'static/originals'),
str(user_id),
str(task_id)
)
return _get_storage().save_uploaded_image(file, task, target_dir)
@staticmethod
def save_original_images(task, files, target_dir, image_type_code='original', target_size=None):
"""保存原图上传"""
return _get_storage().save_multiple_images(files, task, target_dir, image_type_code)
@staticmethod
def delete_image(image_id, user_id):
"""删除图片"""
return _get_storage().delete_image(image_id, user_id)
@staticmethod
def get_image_type_by_code(code):
"""根据代码获取图片类型"""
return ImageStorage.get_image_type_by_code(code)
# ==================== 序列化相关(委托给 ImageSerializer====================
@staticmethod
def serialize_image(image):
"""图片序列化"""
return _get_serializer().to_dict(image)
@staticmethod
def get_image_url(image):
"""获取图片访问URL"""
return _get_serializer().get_url(image)
# ==================== 打包相关(委托给 ZipService====================
@staticmethod
def zip_directory(directory):
"""打包目录为zip"""
return _get_zip_service().zip_directory(directory)
@staticmethod
def zip_multiple_directories(directories):
"""打包多个目录"""
return _get_zip_service().zip_multiple(directories)
# ==================== 工具方法 ====================
@staticmethod
def json_error(message, status_code=400):
"""统一错误响应"""
return jsonify({'error': message}), status_code
# ==================== 保留的原有方法(复杂逻辑暂不迁移)====================
@staticmethod
def save_to_uploads(file, task_id, user_id):
"""
上传图片到uploads临时目录返回临时文件路径和原始文件名
"""
"""上传图片到uploads临时目录"""
import uuid
from flask import current_app
project_root = os.path.dirname(current_app.root_path)
upload_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(task_id))
upload_dir = os.path.join(
project_root,
current_app.config['UPLOAD_FOLDER'],
str(user_id),
str(task_id)
)
os.makedirs(upload_dir, exist_ok=True)
orig_ext = os.path.splitext(file.filename)[1].lower()
temp_name = f"{uuid.uuid4().hex}{orig_ext}"
temp_path = os.path.join(upload_dir, temp_name)
file.save(temp_path)
return temp_path, file.filename
@staticmethod
def preprocess_image(temp_path, original_filename, task_id, user_id, image_types_id, resolution=512, target_format='png'):
"""
对图片进行中心裁剪缩放格式转换重命名保存到static/originals返回数据库对象
原图命名格式: 0000.png, 0001.png, ..., 9999.png
使用数据库事务和重试机制确保并发安全
"""
final_path = None
max_retries = 50
try:
img = PILImage.open(temp_path).convert("RGB")
width, height = img.size
min_dim = min(width, height)
left = (width - min_dim) // 2
top = (height - min_dim) // 2
right = left + min_dim
bottom = top + min_dim
img = img.crop((left, top, right, bottom))
img = img.resize((resolution, resolution), resample=PILImage.Resampling.LANCZOS)
project_root = os.path.dirname(current_app.root_path)
static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(task_id))
os.makedirs(static_dir, exist_ok=True)
from app.database import ImageType
original_type = ImageType.query.filter_by(image_code='original').first()
target_image_types_id = original_type.image_types_id if original_type else image_types_id
# 首次查询最大序号
max_seq_result = db.session.execute(
db.text("""
SELECT COALESCE(MAX(CAST(SUBSTRING_INDEX(stored_filename, '.', 1) AS UNSIGNED)), -1) as max_seq
FROM images
WHERE task_id = :task_id
AND image_types_id = :image_types_id
AND stored_filename REGEXP '^[0-9]{4}\\.'
"""),
{'task_id': task_id, 'image_types_id': target_image_types_id}
).fetchone()
# 强制类型转换,确保安全
try:
base_sequence = int(max_seq_result[0]) if max_seq_result[0] is not None else -1
except Exception:
base_sequence = -1
base_sequence += 1
# 重试机制从base_sequence开始尝试连续的序号
for attempt in range(max_retries):
sequence_number = int(base_sequence) + int(attempt)
fmt_str = str(target_format).lower() if target_format else 'png'
new_name = f"{sequence_number:04d}.{fmt_str}"
final_path = os.path.join(static_dir, new_name)
try:
# 检查数据库中是否已存在此文件名
existing = Image.query.filter_by(
task_id=task_id,
stored_filename=new_name
).first()
if existing:
# 已存在,尝试下一个序号
continue
# 保存图片文件
if target_format.lower() in ['jpg', 'jpeg']:
img.save(final_path, format='JPEG', quality=95)
else:
img.save(final_path, format=target_format.upper())
# 创建数据库记录
image = Image(
task_id=task_id,
image_types_id=image_types_id,
stored_filename=new_name,
file_path=final_path,
file_size=os.path.getsize(final_path),
width=img.width,
height=img.height
)
db.session.add(image)
db.session.commit()
# 删除临时文件
if os.path.exists(temp_path):
os.remove(temp_path)
return {'success': True, 'image': image}
except Exception as e:
db.session.rollback()
error_msg = str(e)
# 如果是唯一性冲突,清理文件并尝试下一个序号
if 'Duplicate entry' in error_msg or '1062' in error_msg:
if final_path and os.path.exists(final_path):
try:
os.remove(final_path)
except Exception:
pass
# 继续循环尝试下一个序号
time.sleep(0.005)
continue
else:
# 其他错误直接抛出
raise
# 所有尝试都失败
raise Exception(f"无法生成唯一文件名,已尝试序号 {base_sequence}{base_sequence + max_retries - 1}")
except Exception as e:
db.session.rollback()
# 清理可能已保存的文件
if final_path and os.path.exists(final_path):
try:
os.remove(final_path)
except Exception:
pass
return {'success': False, 'error': f'图片预处理失败: {str(e)}'}
@staticmethod
def save_image(file, task_id, user_id, image_types_id, resolution=512, target_format='png'):
"""保存单张图片自动上传到uploads并预处理"""
try:
if not file or not allowed_file(file.filename):
return {'success': False, 'error': '不支持的文件格式'}
temp_path, orig_name = ImageService.save_to_uploads(file, task_id, user_id)
return ImageService.preprocess_image(temp_path, orig_name, task_id, user_id, image_types_id, resolution, target_format)
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'保存图片失败: {str(e)}'}
return temp_path, file.filename
@staticmethod
def extract_and_save_zip(zip_file, task_id, user_id, image_types_id):
"""解压并保存压缩包中的图片"""
import uuid
import zipfile
import shutil
from flask import current_app
from werkzeug.utils import secure_filename
results = []
temp_dir = None
try:
# 创建临时目录
project_root = os.path.dirname(current_app.root_path)
temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], 'temp', f"{uuid.uuid4().hex}")
temp_dir = os.path.join(
project_root,
current_app.config['UPLOAD_FOLDER'],
'temp',
f"{uuid.uuid4().hex}"
)
os.makedirs(temp_dir, exist_ok=True)
# 保存压缩包
zip_path = os.path.join(temp_dir, secure_filename(zip_file.filename))
zip_file.save(zip_path)
# 解压文件
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
# 遍历解压的文件
for root, dirs, files in os.walk(temp_dir):
for filename in files:
if filename.lower().endswith(('.zip', '.rar')):
continue # 跳过压缩包文件本身
continue
if allowed_file(filename):
file_path = os.path.join(root, filename)
# 创建虚拟文件对象
class FileWrapper:
def __init__(self, path, name):
self.path = path
self.filename = name
def save(self, destination):
import shutil
shutil.copy2(self.path, destination)
virtual_file = FileWrapper(file_path, filename)
result = ImageService.save_image(virtual_file, task_id, user_id, image_types_id)
result = ImageService.save_image(
virtual_file, task_id, user_id, image_types_id
)
results.append(result)
return results
@ -215,7 +230,6 @@ class ImageService:
return [{'success': False, 'error': f'解压失败: {str(e)}'}]
finally:
# 清理临时文件
if temp_dir and os.path.exists(temp_dir):
import shutil
try:
@ -223,134 +237,25 @@ class ImageService:
except Exception:
pass
@staticmethod
def get_image_url(image):
"""获取图片访问URL"""
if not image or not image.file_path:
return None
# 这里返回相对路径前端可以拼接完整URL
return f"/api/image/file/{image.images_id}"
@staticmethod
def delete_image(image_id, user_id):
"""删除图片通过关联的task验证权限"""
try:
image = Image.query.filter_by(images_id=image_id).first()
if not image:
return {'success': False, 'error': '图片不存在'}
# 通过关联的task验证用户权限
if not image.task or image.task.user_id != user_id:
return {'success': False, 'error': '无权限删除该图片'}
# 删除文件
if os.path.exists(image.file_path):
os.remove(image.file_path)
# 删除数据库记录
db.session.delete(image)
db.session.commit()
return {'success': True}
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'删除图片失败: {str(e)}'}
# ==================== 控制器辅助功能 ====================
DEFAULT_TARGET_SIZE = 512
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp'}
@staticmethod
def json_error(message, status_code=400):
"""统一错误响应"""
return jsonify({'error': message}), status_code
@staticmethod
def get_image_type_by_code(code):
"""根据代码获取图片类型"""
return ImageType.query.filter_by(image_code=code).first()
@staticmethod
def save_original_images(task, files, target_dir, image_type_code='original', target_size=None):
"""保存原图上传"""
if not files:
return False, '未检测到文件上传'
image_type = ImageService.get_image_type_by_code(image_type_code)
if not image_type:
return False, f'未配置图片类型: {image_type_code}'
os.makedirs(target_dir, exist_ok=True)
saved_records = []
saved_paths = []
size = target_size or ImageService.DEFAULT_TARGET_SIZE
try:
for file in files:
if not file or not file.filename:
continue
if not allowed_file(file.filename):
continue
extension = os.path.splitext(file.filename)[1].lower()
if extension not in ImageService.IMAGE_EXTENSIONS:
continue
processed = ImageService._prepare_image(file, size)
filename, path, width, height, file_size = ImageService._save_processed_image(processed, target_dir)
image = ImageService._create_image_record(
task,
image_type.image_types_id,
filename,
path,
width,
height,
file_size
)
saved_records.append(image)
saved_paths.append(path)
if not saved_records:
db.session.rollback()
return False, '未上传有效的图片文件'
db.session.commit()
return True, saved_records
except Exception as exc:
db.session.rollback()
for path in saved_paths:
if os.path.exists(path):
try:
os.remove(path)
except OSError:
pass
return False, f'上传图片失败: {exc}'
@staticmethod
def _prepare_image(file_storage, target_size):
"""裁剪并缩放上传图片"""
file_storage.stream.seek(0)
image = PILImage.open(file_storage.stream).convert('RGB')
width, height = image.size
min_dim = min(width, height)
left = (width - min_dim) // 2
top = (height - min_dim) // 2
image = image.crop((left, top, left + min_dim, top + min_dim))
return image.resize((target_size, target_size), resample=PILImage.Resampling.LANCZOS)
processor = ImageProcessor(target_size=target_size)
return processor.process_from_file(file_storage)
@staticmethod
def _save_processed_image(image, target_dir):
"""将处理后的图片保存为PNG"""
import uuid
from datetime import datetime
timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f')
filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png"
path = os.path.join(target_dir, filename)
image.save(path, format='PNG')
return filename, path, image.width, image.height, os.path.getsize(path)
@staticmethod
def _create_image_record(task, image_type_id, filename, path, width, height, file_size, father_id=None):
"""创建图片数据库记录"""
@ -366,184 +271,3 @@ class ImageService:
)
db.session.add(image)
return image
@staticmethod
def zip_directory(directory):
"""打包目录为zip"""
buffer = io.BytesIO()
has_files = False
with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
if os.path.isdir(directory):
for root, _, files in os.walk(directory):
for filename in files:
file_path = os.path.join(root, filename)
arcname = os.path.relpath(file_path, directory)
zipf.write(file_path, arcname)
has_files = True
buffer.seek(0)
return buffer, has_files
@staticmethod
def zip_multiple_directories(directories):
"""打包多个目录"""
buffer = io.BytesIO()
has_files = False
with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
if isinstance(directories, dict):
iterable = directories.items()
else:
iterable = ((os.path.basename(d.rstrip(os.sep)) or 'output', d) for d in directories)
for label, directory in iterable:
if not os.path.isdir(directory):
continue
for root, _, files in os.walk(directory):
for filename in files:
file_path = os.path.join(root, filename)
rel_path = os.path.relpath(file_path, directory)
arcname = os.path.join(label or 'output', rel_path)
zipf.write(file_path, arcname)
has_files = True
buffer.seek(0)
return buffer, has_files
@staticmethod
def serialize_image(image):
"""图片序列化"""
if not image:
return None
return {
'image_id': image.images_id,
'task_id': image.task_id,
'stored_filename': image.stored_filename,
'file_path': image.file_path,
'file_size': image.file_size,
'width': image.width,
'height': image.height,
'image_type': image.image_type.image_code if image.image_type else None
}
@staticmethod
def image_to_base64(image):
"""将图片转换为 base64 编码"""
if not image or not os.path.exists(image.file_path):
return None
ext = os.path.splitext(image.file_path)[1].lower()
mime_types = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp',
}
mimetype = mime_types.get(ext, 'image/png')
with open(image.file_path, 'rb') as f:
data = base64.b64encode(f.read()).decode('utf-8')
return {
'image_id': image.images_id,
'filename': image.stored_filename,
'data': f'data:{mimetype};base64,{data}',
'width': image.width,
'height': image.height
}
## ==================== 获取预览图片服务 ====================
def _get_perturbation_preview(task):
"""获取加噪任务的预览图片"""
images = {'original': [], 'perturbed': []}
original_type = ImageType.query.filter_by(image_code='original').first()
perturbed_type = ImageType.query.filter_by(image_code='perturbed').first()
if original_type:
originals = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=original_type.image_types_id
).all()
images['original'] = [ImageService.image_to_base64(img) for img in originals]
if perturbed_type:
perturbeds = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=perturbed_type.image_types_id
).all()
images['perturbed'] = [ImageService.image_to_base64(img) for img in perturbeds]
return images
def _get_finetune_preview(task):
"""获取微调任务的预览图片"""
images = {
'original': [],
'original_generate': [],
'perturbed_generate': [],
'uploaded_generate': []
}
# 获取原图从同一flow_id的perturbation任务或当前任务
original_type = ImageType.query.filter_by(image_code='original').first()
if original_type:
# 查找同flow下的原图
from app.database import Task
flow_tasks = Task.query.filter_by(flow_id=task.flow_id, user_id=task.user_id).all()
task_ids = [t.tasks_id for t in flow_tasks]
originals = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == original_type.image_types_id
).all()
images['original'] = [ImageService.image_to_base64(img) for img in originals]
# 获取生成图
for type_code in ['original_generate', 'perturbed_generate', 'uploaded_generate']:
img_type = ImageType.query.filter_by(image_code=type_code).first()
if img_type:
generated = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=img_type.image_types_id
).all()
images[type_code] = [ImageService.image_to_base64(img) for img in generated]
return images
def _get_heatmap_preview(task):
"""获取热力图任务的预览图片(热力图本身已包含原图和加噪图的对比)"""
images = {'heatmap': []}
# 获取热力图(已是完整的对比报告图)
heatmap_type = ImageType.query.filter_by(image_code='heatmap').first()
if heatmap_type:
heatmaps = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=heatmap_type.image_types_id
).all()
images['heatmap'] = [ImageService.image_to_base64(img) for img in heatmaps]
return images
def _get_evaluate_preview(task):
"""获取评估任务的预览图片"""
images = {
'report': []
}
# 获取报告图
report_type = ImageType.query.filter_by(image_code='report').first()
if report_type:
reports = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=report_type.image_types_id
).all()
images['report'] = [ImageService.image_to_base64(img) for img in reports]
return images

@ -0,0 +1,4 @@
"""存储服务模块"""
from .path_manager import PathManager
__all__ = ['PathManager']

@ -0,0 +1,310 @@
"""
路径管理器
职责单一统一管理项目中所有路径的生成逻辑
遵循开闭原则新增路径类型只需添加方法无需修改现有代码
"""
import os
from typing import Union
from flask import current_app
from config.settings import Config
class PathManager:
"""
路径管理器
统一管理所有文件存储路径的生成包括
- 原图路径
- 加噪图路径
- 生成图路径原图/加噪/上传
- 热力图路径
- 评估结果路径
- 类别数据路径
- 模型数据路径
- 坐标文件路径
使用方式:
pm = PathManager()
path = pm.get_original_images_path(user_id=1, flow_id=123)
"""
def __init__(self, project_root: str = None):
"""
初始化路径管理器
Args:
project_root: 项目根目录默认从 Flask app 获取
"""
self._project_root = project_root
@property
def project_root(self) -> str:
"""获取项目根目录(懒加载)"""
if self._project_root is None:
self._project_root = os.path.dirname(current_app.root_path)
return self._project_root
def _build_path(self, *parts: Union[str, int]) -> str:
"""
构建完整路径
Args:
*parts: 路径组成部分会自动转换为字符串
Returns:
完整的绝对路径
"""
str_parts = [str(p) for p in parts]
return os.path.join(self.project_root, *str_parts)
# ==================== 图片存储路径 ====================
def get_original_images_path(self, user_id: int, flow_id: int) -> str:
"""
原图存储路径
格式: {ORIGINAL_IMAGES_FOLDER}/{user_id}/{flow_id}
"""
return self._build_path(
Config.ORIGINAL_IMAGES_FOLDER,
user_id,
flow_id
)
def get_perturbed_images_path(self, user_id: int, flow_id: int) -> str:
"""
加噪图存储路径
格式: {PERTURBED_IMAGES_FOLDER}/{user_id}/{flow_id}
"""
return self._build_path(
Config.PERTURBED_IMAGES_FOLDER,
user_id,
flow_id
)
# ==================== 生成图存储路径 ====================
def get_original_generated_path(
self,
user_id: int,
flow_id: int,
task_id: int
) -> str:
"""
原图生成图存储路径
格式: {MODEL_ORIGINAL_FOLDER}/{user_id}/{flow_id}/{task_id}
"""
return self._build_path(
Config.MODEL_ORIGINAL_FOLDER,
user_id,
flow_id,
task_id
)
def get_perturbed_generated_path(
self,
user_id: int,
flow_id: int,
task_id: int
) -> str:
"""
加噪图生成图存储路径
格式: {MODEL_PERTURBED_FOLDER}/{user_id}/{flow_id}/{task_id}
"""
return self._build_path(
Config.MODEL_PERTURBED_FOLDER,
user_id,
flow_id,
task_id
)
def get_uploaded_generated_path(
self,
user_id: int,
flow_id: int,
task_id: int
) -> str:
"""
上传图生成图存储路径
格式: {MODEL_UPLOADED_FOLDER}/{user_id}/{flow_id}/{task_id}
"""
return self._build_path(
Config.MODEL_UPLOADED_FOLDER,
user_id,
flow_id,
task_id
)
# ==================== 结果存储路径 ====================
def get_heatmap_path(
self,
user_id: int,
flow_id: int,
task_id: int
) -> str:
"""
热力图存储路径
格式: {HEATDIF_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id}
"""
return self._build_path(
Config.HEATDIF_SAVE_FOLDER,
user_id,
flow_id,
task_id
)
def get_evaluate_path(
self,
user_id: int,
flow_id: int,
task_id: int
) -> str:
"""
评估结果存储路径
格式: {NUMBERS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id}
"""
return self._build_path(
Config.NUMBERS_SAVE_FOLDER,
user_id,
flow_id,
task_id
)
# ==================== 数据路径 ====================
def get_class_data_path(self, user_id: int, flow_id: int) -> str:
"""
类别数据存储路径
格式: {CLASS_DATA_FOLDER}/{user_id}/{flow_id}
"""
return self._build_path(
Config.CLASS_DATA_FOLDER,
user_id,
flow_id
)
def get_model_data_path(self) -> str:
"""
模型数据存储路径全局共享
格式: {MODEL_DATA_FOLDER}
"""
return self._build_path(Config.MODEL_DATA_FOLDER)
# ==================== 坐标文件路径 ====================
def get_coords_base_path(
self,
user_id: int,
flow_id: int,
task_id: int
) -> str:
"""
坐标文件基础路径
格式: {COORDS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id}
"""
return self._build_path(
Config.COORDS_SAVE_FOLDER,
user_id,
flow_id,
task_id
)
def get_original_coords_path(
self,
user_id: int,
flow_id: int,
task_id: int
) -> str:
"""
原图坐标文件路径3D可视化用
格式: {COORDS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id}/original_coords.csv
"""
return os.path.join(
self.get_coords_base_path(user_id, flow_id, task_id),
'original_coords.csv'
)
def get_perturbed_coords_path(
self,
user_id: int,
flow_id: int,
task_id: int
) -> str:
"""
加噪图坐标文件路径3D可视化用
格式: {COORDS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id}/perturbed_coords.csv
"""
return os.path.join(
self.get_coords_base_path(user_id, flow_id, task_id),
'perturbed_coords.csv'
)
def get_uploaded_coords_path(
self,
user_id: int,
flow_id: int,
task_id: int
) -> str:
"""
上传图坐标文件路径3D可视化用
格式: {COORDS_SAVE_FOLDER}/{user_id}/{flow_id}/{task_id}/coords.csv
"""
return os.path.join(
self.get_coords_base_path(user_id, flow_id, task_id),
'coords.csv'
)
# ==================== 图片文件完整路径 ====================
def get_original_image_file_path(
self,
user_id: int,
flow_id: int,
filename: str
) -> str:
"""获取原图文件的完整路径"""
return os.path.join(
self.get_original_images_path(user_id, flow_id),
filename
)
def get_perturbed_image_file_path(
self,
user_id: int,
flow_id: int,
filename: str
) -> str:
"""获取加噪图文件的完整路径"""
return os.path.join(
self.get_perturbed_images_path(user_id, flow_id),
filename
)
# ============================================================
# 全局单例实例(便于简单场景使用)
# ============================================================
_default_manager: PathManager = None
def get_path_manager() -> PathManager:
"""获取默认的路径管理器实例"""
global _default_manager
if _default_manager is None:
_default_manager = PathManager()
return _default_manager

@ -0,0 +1,28 @@
"""
任务处理模块
提供面向对象的任务处理器使用模板方法模式统一任务生命周期
使用方式:
from app.services.task import TaskHandlerFactory
handler = TaskHandlerFactory.create('perturbation')
job_id = handler.start(task_id)
"""
from .base_handler import BaseTaskHandler
from .perturbation_handler import PerturbationTaskHandler
from .finetune_handler import FinetuneTaskHandler
from .heatmap_handler import HeatmapTaskHandler
from .evaluate_handler import EvaluateTaskHandler
from .task_factory import TaskHandlerFactory
from .task_queue import TaskQueue
__all__ = [
'BaseTaskHandler',
'PerturbationTaskHandler',
'FinetuneTaskHandler',
'HeatmapTaskHandler',
'EvaluateTaskHandler',
'TaskHandlerFactory',
'TaskQueue',
]

@ -0,0 +1,146 @@
"""
任务处理器基类
使用模板方法模式定义任务启动的统一流程子类实现具体细节
"""
import logging
from abc import ABC, abstractmethod
from typing import Optional, Any
from app.database import Task
from app.services.storage import PathManager
from app.services.task.task_queue import TaskQueue
logger = logging.getLogger(__name__)
def _get_task_repo():
"""懒加载获取 TaskRepository"""
from app.repositories import TaskRepository
return TaskRepository()
class BaseTaskHandler(ABC):
"""
任务处理器抽象基类
定义任务启动的模板方法子类需实现:
- _get_task_type_code(): 返回任务类型代码
- _load_task_detail(): 加载任务详情
- _validate(): 验证任务数据
- _build_worker_params(): 构建 worker 参数
- _get_worker_func(): 返回 worker 函数
"""
def __init__(
self,
path_manager: Optional[PathManager] = None,
task_queue: Optional[TaskQueue] = None,
task_repo=None
):
self._path_manager = path_manager or PathManager()
self._task_queue = task_queue or TaskQueue()
self._task_repo = task_repo
@property
def path_manager(self) -> PathManager:
return self._path_manager
@property
def task_queue(self) -> TaskQueue:
return self._task_queue
@property
def task_repo(self):
"""懒加载 TaskRepository"""
if self._task_repo is None:
self._task_repo = _get_task_repo()
return self._task_repo
def start(self, task_id: int) -> Optional[str]:
"""
启动任务模板方法
"""
try:
task = self._load_task(task_id)
if not task:
logger.error(f"Task {task_id} not found")
return None
detail = self._load_task_detail(task_id)
if not detail:
logger.error(f"{self._get_task_type_code()} detail for task {task_id} not found")
return None
error = self._validate(task, detail)
if error:
logger.error(f"Task {task_id} validation failed: {error}")
return None
self._update_status(task, 'waiting')
params = self._build_worker_params(task, detail)
job_id = self._enqueue(task_id, params)
if job_id:
logger.info(f"{self._get_task_type_code()} task {task_id} started with job_id {job_id}")
return job_id
except Exception as e:
logger.error(f"Error starting {self._get_task_type_code()} task {task_id}: {e}")
return None
def _load_task(self, task_id: int) -> Optional[Task]:
"""使用 Repository 加载任务"""
return self.task_repo.get_by_id(task_id)
def _update_status(self, task: Task, status_code: str) -> bool:
"""使用 Repository 更新任务状态"""
if self.task_repo.update_status(task, status_code):
return self.task_repo.save()
return False
def _enqueue(self, task_id: int, params: dict) -> Optional[str]:
job_id = self._get_job_id(task_id)
worker_func = self._get_worker_func()
timeout = self._get_timeout()
return self._task_queue.enqueue(
worker_func,
job_id=job_id,
timeout=timeout,
**params
)
@abstractmethod
def _get_task_type_code(self) -> str:
pass
@abstractmethod
def _load_task_detail(self, task_id: int) -> Optional[Any]:
pass
@abstractmethod
def _validate(self, task: Task, detail: Any) -> Optional[str]:
pass
@abstractmethod
def _build_worker_params(self, task: Task, detail: Any) -> dict:
pass
@abstractmethod
def _get_worker_func(self):
pass
def _get_job_id(self, task_id: int) -> str:
prefix_map = {
'perturbation': 'pert',
'finetune': 'ft',
'heatmap': 'hm',
'evaluate': 'eval'
}
prefix = prefix_map.get(self._get_task_type_code(), 'task')
return f"{prefix}_{task_id}"
def _get_timeout(self) -> str:
return '4h'

@ -0,0 +1,87 @@
"""
评估任务处理器
处理模型评估Evaluate任务的启动逻辑
"""
import logging
from typing import Optional
from app.database import Evaluate, Task
from app.services.task.base_handler import BaseTaskHandler
logger = logging.getLogger(__name__)
def _get_evaluate_repo():
"""懒加载获取 EvaluateRepository"""
from app.repositories import EvaluateRepository
return EvaluateRepository()
def _get_finetune_repo():
"""懒加载获取 FinetuneRepository"""
from app.repositories import FinetuneRepository
return FinetuneRepository()
class EvaluateTaskHandler(BaseTaskHandler):
"""
评估任务处理器
处理流程:
1. 加载 Evaluate 详情
2. 验证关联的微调任务存在
3. 从微调任务获取路径信息
4. 构建评估参数
5. 入队执行 evaluate_worker
"""
def _get_task_type_code(self) -> str:
return 'evaluate'
def _load_task_detail(self, task_id: int) -> Optional[Evaluate]:
return _get_evaluate_repo().get_by_task(task_id)
def _validate(self, task: Task, detail: Evaluate) -> Optional[str]:
"""验证评估任务配置"""
# 检查关联的微调任务
if not detail.finetune_task_id:
return "Evaluate task has no associated finetune task"
finetune = _get_finetune_repo().get_by_task(detail.finetune_task_id)
if not finetune:
return f"Finetune task {detail.finetune_task_id} not found"
finetune_task = finetune.task
if not finetune_task:
return f"Finetune task {detail.finetune_task_id} missing Task relation"
return None
def _build_worker_params(self, task: Task, detail: Evaluate) -> dict:
"""构建评估任务参数"""
pm = self.path_manager
# 获取关联的微调任务信息
finetune = _get_finetune_repo().get_by_task(detail.finetune_task_id)
finetune_task = finetune.task
user_id = finetune_task.user_id
flow_id = finetune_task.flow_id
finetune_task_id = finetune_task.tasks_id
return {
'task_id': task.tasks_id,
'clean_ref_dir': pm.get_original_images_path(user_id, flow_id),
'clean_output_dir': pm.get_original_generated_path(user_id, flow_id, finetune_task_id),
'perturbed_output_dir': pm.get_perturbed_generated_path(user_id, flow_id, finetune_task_id),
'output_dir': pm.get_evaluate_path(user_id, flow_id, task.tasks_id),
'image_size': 512,
}
def _get_worker_func(self):
from app.workers.evaluate_worker import run_evaluate_task
return run_evaluate_task
def _get_timeout(self) -> str:
return '2h'

@ -0,0 +1,214 @@
"""
微调任务处理器
处理模型微调Finetune任务的启动逻辑
支持两种类型基于加噪结果的微调 用户上传图片的微调
"""
import logging
from typing import Optional, List
from app.database import Finetune, Task
from app.services.task.base_handler import BaseTaskHandler
logger = logging.getLogger(__name__)
def _get_finetune_repo():
"""懒加载获取 FinetuneRepository"""
from app.repositories import FinetuneRepository
return FinetuneRepository()
def _get_finetune_config_repo():
"""懒加载获取 FinetuneConfigRepository"""
from app.repositories import FinetuneConfigRepository
return FinetuneConfigRepository()
class FinetuneTaskHandler(BaseTaskHandler):
"""
微调任务处理器
支持两种微调类型:
1. 基于加噪结果的微调perturbation-based
- 同一 flow_id 下存在 Perturbation 任务
- 同时处理原图和加噪图生成两个 job
2. 用户上传图片的微调uploaded
- 独立的 flow_id无关联的 Perturbation 任务
- 仅处理上传的原图生成一个 job
"""
def _get_task_type_code(self) -> str:
return 'finetune'
def _load_task_detail(self, task_id: int) -> Optional[Finetune]:
return _get_finetune_repo().get_by_task(task_id)
def _validate(self, task: Task, detail: Finetune) -> Optional[str]:
"""验证微调任务配置"""
config = _get_finetune_config_repo().get_by_id(detail.finetune_configs_id)
if not config:
return f"Finetune config {detail.finetune_configs_id} not found"
return None
def _has_perturbation_sibling(self, task: Task) -> bool:
"""检查是否存在同 flow_id 的加噪任务"""
sibling = self.task_repo.get_by_flow_and_type(task.flow_id, 'perturbation')
return sibling is not None and sibling.tasks_id != task.tasks_id
def _build_worker_params(self, task: Task, detail: Finetune) -> dict:
"""构建微调任务参数(单个 job 的情况)"""
# 此方法用于上传图片微调的情况
user_id = task.user_id
flow_id = task.flow_id
task_id = task.tasks_id
pm = self.path_manager
config = _get_finetune_config_repo().get_by_id(detail.finetune_configs_id)
return {
'task_id': task_id,
'finetune_method': config.finetune_code,
'train_images_dir': pm.get_original_images_path(user_id, flow_id),
'output_model_dir': pm.get_model_data_path(),
'class_dir': pm.get_class_data_path(user_id, flow_id),
'coords_save_path': pm.get_uploaded_coords_path(user_id, flow_id, task_id),
'validation_output_dir': pm.get_uploaded_generated_path(user_id, flow_id, task_id),
'is_perturbed': False,
'custom_params': None,
}
def _build_perturbation_based_params(self, task: Task, detail: Finetune) -> List[dict]:
"""构建基于加噪的微调参数(两个 job"""
user_id = task.user_id
flow_id = task.flow_id
task_id = task.tasks_id
pm = self.path_manager
config = _get_finetune_config_repo().get_by_id(detail.finetune_configs_id)
# 原图微调参数
original_params = {
'task_id': task_id,
'finetune_method': config.finetune_code,
'train_images_dir': pm.get_original_images_path(user_id, flow_id),
'output_model_dir': pm.get_model_data_path(),
'class_dir': pm.get_class_data_path(user_id, flow_id),
'coords_save_path': pm.get_original_coords_path(user_id, flow_id, task_id),
'validation_output_dir': pm.get_original_generated_path(user_id, flow_id, task_id),
'is_perturbed': False,
'custom_params': None,
}
# 加噪图微调参数
perturbed_params = {
'task_id': task_id,
'finetune_method': config.finetune_code,
'train_images_dir': pm.get_perturbed_images_path(user_id, flow_id),
'output_model_dir': pm.get_model_data_path(),
'class_dir': pm.get_class_data_path(user_id, flow_id),
'coords_save_path': pm.get_perturbed_coords_path(user_id, flow_id, task_id),
'validation_output_dir': pm.get_perturbed_generated_path(user_id, flow_id, task_id),
'is_perturbed': True,
'custom_params': None,
}
return [original_params, perturbed_params]
def _get_worker_func(self):
from app.workers.finetune_worker import run_finetune_task
return run_finetune_task
def _get_timeout(self) -> str:
return '8h'
def start(self, task_id: int) -> Optional[str]:
"""
启动微调任务重写模板方法以支持双 job
Returns:
单个 job_id 或逗号分隔的多个 job_id
"""
try:
# 加载任务
task = self._load_task(task_id)
if not task:
logger.error(f"Task {task_id} not found")
return None
detail = self._load_task_detail(task_id)
if not detail:
logger.error(f"Finetune detail for task {task_id} not found")
return None
# 验证
error = self._validate(task, detail)
if error:
logger.error(f"Task {task_id} validation failed: {error}")
return None
# 更新状态
self._update_status(task, 'waiting')
# 判断微调类型
if self._has_perturbation_sibling(task):
# 基于加噪的微调:创建两个 job
logger.info(f"Finetune task {task_id}: type=perturbation-based")
return self._start_perturbation_based(task_id, task, detail)
else:
# 上传图片的微调:创建一个 job
logger.info(f"Finetune task {task_id}: type=uploaded")
return self._start_uploaded(task_id, task, detail)
except Exception as e:
logger.error(f"Error starting finetune task {task_id}: {e}")
return None
def _start_perturbation_based(self, task_id: int, task: Task, detail: Finetune) -> Optional[str]:
"""启动基于加噪的微调(两个 job"""
params_list = self._build_perturbation_based_params(task, detail)
worker_func = self._get_worker_func()
timeout = self._get_timeout()
job_id_original = f"ft_{task_id}_original"
job_id_perturbed = f"ft_{task_id}_perturbed"
# 入队原图微调
result1 = self.task_queue.enqueue(
worker_func,
job_id=job_id_original,
timeout=timeout,
**params_list[0]
)
# 入队加噪图微调
result2 = self.task_queue.enqueue(
worker_func,
job_id=job_id_perturbed,
timeout=timeout,
**params_list[1]
)
if result1 and result2:
logger.info(f"Finetune task {task_id} enqueued: {job_id_original}, {job_id_perturbed}")
return f"{job_id_original},{job_id_perturbed}"
return None
def _start_uploaded(self, task_id: int, task: Task, detail: Finetune) -> Optional[str]:
"""启动上传图片的微调(单个 job"""
params = self._build_worker_params(task, detail)
job_id = f"ft_{task_id}"
result = self.task_queue.enqueue(
self._get_worker_func(),
job_id=job_id,
timeout=self._get_timeout(),
**params
)
if result:
logger.info(f"Finetune task {task_id} enqueued: {job_id}")
return result

@ -0,0 +1,105 @@
"""
热力图任务处理器
处理热力图Heatmap生成任务的启动逻辑
"""
import os
import logging
from typing import Optional
from app.database import Heatmap, Task
from app.services.task.base_handler import BaseTaskHandler
logger = logging.getLogger(__name__)
def _get_heatmap_repo():
"""懒加载获取 HeatmapRepository"""
from app.repositories import HeatmapRepository
return HeatmapRepository()
def _get_image_repo():
"""懒加载获取 ImageRepository"""
from app.repositories import ImageRepository
return ImageRepository()
class HeatmapTaskHandler(BaseTaskHandler):
"""
热力图任务处理器
处理流程:
1. 加载 Heatmap 详情
2. 验证关联的加噪图片存在
3. 通过 father_id 找到原图
4. 构建图片路径和输出路径
5. 入队执行 heatmap_worker
"""
def _get_task_type_code(self) -> str:
return 'heatmap'
def _load_task_detail(self, task_id: int) -> Optional[Heatmap]:
return _get_heatmap_repo().get_by_task(task_id)
def _validate(self, task: Task, detail: Heatmap) -> Optional[str]:
"""验证热力图任务配置"""
image_repo = _get_image_repo()
# 检查加噪图片 ID
if not detail.images_id:
return "Heatmap task has no associated perturbed image"
# 检查加噪图片存在
perturbed_image = image_repo.get_by_id(detail.images_id)
if not perturbed_image:
return f"Perturbed image {detail.images_id} not found"
# 检查原图存在(通过 father_id
if not perturbed_image.father_id:
return f"Perturbed image {detail.images_id} has no father_id"
original_image = image_repo.get_by_id(perturbed_image.father_id)
if not original_image:
return f"Original image (father_id={perturbed_image.father_id}) not found"
return None
def _build_worker_params(self, task: Task, detail: Heatmap) -> dict:
"""构建热力图任务参数"""
image_repo = _get_image_repo()
user_id = task.user_id
flow_id = task.flow_id
task_id = task.tasks_id
pm = self.path_manager
# 获取加噪图片
perturbed_image = image_repo.get_by_id(detail.images_id)
original_image = image_repo.get_by_id(perturbed_image.father_id)
# 构建图片完整路径
original_image_path = os.path.join(
pm.get_original_images_path(user_id, flow_id),
original_image.stored_filename
)
perturbed_image_path = os.path.join(
pm.get_perturbed_images_path(user_id, flow_id),
perturbed_image.stored_filename
)
return {
'task_id': task_id,
'original_image_path': original_image_path,
'perturbed_image_path': perturbed_image_path,
'output_dir': pm.get_heatmap_path(user_id, flow_id, task_id),
'perturbed_image_id': detail.images_id,
}
def _get_worker_func(self):
from app.workers.heatmap_worker import run_heatmap_task
return run_heatmap_task
def _get_timeout(self) -> str:
return '2h'

@ -0,0 +1,61 @@
"""
加噪任务处理器
"""
import logging
from typing import Optional
from app.database import Perturbation, Task
from app.services.task.base_handler import BaseTaskHandler
logger = logging.getLogger(__name__)
def _get_perturbation_repo():
"""懒加载获取 PerturbationRepository"""
from app.repositories import PerturbationRepository
return PerturbationRepository()
def _get_perturbation_config_repo():
"""懒加载获取 PerturbationConfigRepository"""
from app.repositories import PerturbationConfigRepository
return PerturbationConfigRepository()
class PerturbationTaskHandler(BaseTaskHandler):
"""加噪任务处理器"""
def _get_task_type_code(self) -> str:
return 'perturbation'
def _load_task_detail(self, task_id: int) -> Optional[Perturbation]:
return _get_perturbation_repo().get_by_task(task_id)
def _validate(self, task: Task, detail: Perturbation) -> Optional[str]:
config = _get_perturbation_config_repo().get_by_id(detail.perturbation_configs_id)
if not config:
return f"Perturbation config {detail.perturbation_configs_id} not found"
return None
def _build_worker_params(self, task: Task, detail: Perturbation) -> dict:
user_id = task.user_id
flow_id = task.flow_id
pm = self.path_manager
config = _get_perturbation_config_repo().get_by_id(detail.perturbation_configs_id)
return {
'task_id': task.tasks_id,
'input_dir': pm.get_original_images_path(user_id, flow_id),
'output_dir': pm.get_perturbed_images_path(user_id, flow_id),
'class_dir': pm.get_class_data_path(user_id, flow_id),
'algorithm_code': config.perturbation_code,
'epsilon': detail.perturbation_intensity,
}
def _get_worker_func(self):
from app.workers.perturbation_worker import run_perturbation_task
return run_perturbation_task
def _get_timeout(self) -> str:
return '4h'

@ -0,0 +1,126 @@
"""
任务处理器工厂
使用工厂模式根据任务类型创建对应的处理器
"""
import logging
from typing import Optional, Type, Dict
from app.services.task.base_handler import BaseTaskHandler
from app.services.task.perturbation_handler import PerturbationTaskHandler
from app.services.task.finetune_handler import FinetuneTaskHandler
from app.services.task.heatmap_handler import HeatmapTaskHandler
from app.services.task.evaluate_handler import EvaluateTaskHandler
from app.services.storage import PathManager
from app.services.task.task_queue import TaskQueue
logger = logging.getLogger(__name__)
class TaskHandlerFactory:
"""
任务处理器工厂
根据任务类型代码创建对应的处理器实例
使用方式:
# 方式1使用默认依赖
handler = TaskHandlerFactory.create('perturbation')
job_id = handler.start(task_id=123)
# 方式2注入自定义依赖便于测试
handler = TaskHandlerFactory.create(
'finetune',
path_manager=mock_pm,
task_queue=mock_queue
)
支持的任务类型:
- perturbation: 加噪任务
- finetune: 微调任务
- heatmap: 热力图任务
- evaluate: 评估任务
"""
# 任务类型到处理器类的映射
_handlers: Dict[str, Type[BaseTaskHandler]] = {
'perturbation': PerturbationTaskHandler,
'finetune': FinetuneTaskHandler,
'heatmap': HeatmapTaskHandler,
'evaluate': EvaluateTaskHandler,
}
@classmethod
def create(
cls,
task_type: str,
path_manager: Optional[PathManager] = None,
task_queue: Optional[TaskQueue] = None
) -> BaseTaskHandler:
"""
创建任务处理器
Args:
task_type: 任务类型代码
path_manager: 路径管理器可选
task_queue: 任务队列可选
Returns:
对应的任务处理器实例
Raises:
ValueError: 未知的任务类型
"""
handler_class = cls._handlers.get(task_type)
if handler_class is None:
available = ', '.join(cls._handlers.keys())
raise ValueError(
f"Unknown task type: '{task_type}'. "
f"Available types: {available}"
)
return handler_class(
path_manager=path_manager,
task_queue=task_queue
)
@classmethod
def register(cls, task_type: str, handler_class: Type[BaseTaskHandler]) -> None:
"""
注册新的任务处理器扩展点
Args:
task_type: 任务类型代码
handler_class: 处理器类
"""
cls._handlers[task_type] = handler_class
logger.info(f"Registered task handler: {task_type} -> {handler_class.__name__}")
@classmethod
def get_supported_types(cls) -> list:
"""获取所有支持的任务类型"""
return list(cls._handlers.keys())
# ============================================================
# 便捷函数:直接启动任务
# ============================================================
def start_task(task_type: str, task_id: int) -> Optional[str]:
"""
便捷函数根据类型启动任务
Args:
task_type: 任务类型代码
task_id: 任务 ID
Returns:
job_id None
"""
try:
handler = TaskHandlerFactory.create(task_type)
return handler.start(task_id)
except ValueError as e:
logger.error(str(e))
return None

@ -0,0 +1,116 @@
"""
任务队列管理
封装 Redis Queue (RQ) 的连接和队列操作
"""
import logging
from typing import Optional, Callable, Any
from redis import Redis
from rq import Queue
from rq.job import Job
from config.algorithm_config import AlgorithmConfig
logger = logging.getLogger(__name__)
class TaskQueue:
"""
任务队列管理器
封装 RQ 队列操作提供统一的任务入队接口
使用方式:
queue = TaskQueue()
job_id = queue.enqueue(
worker_func,
job_id='task_123',
timeout='4h',
task_id=123,
input_dir='/path/to/input'
)
"""
_instance: Optional['TaskQueue'] = None
def __new__(cls) -> 'TaskQueue':
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._redis_url = AlgorithmConfig.REDIS_URL
self._queue_name = AlgorithmConfig.RQ_QUEUE_NAME
self._connection: Optional[Redis] = None
self._queue: Optional[Queue] = None
self._initialized = True
@property
def connection(self) -> Redis:
"""获取 Redis 连接(懒加载)"""
if self._connection is None:
self._connection = Redis.from_url(self._redis_url)
return self._connection
@property
def queue(self) -> Queue:
"""获取 RQ 队列(懒加载)"""
if self._queue is None:
self._queue = Queue(self._queue_name, connection=self.connection)
return self._queue
def enqueue(
self,
func: Callable,
job_id: str,
timeout: str = '4h',
**kwargs: Any
) -> Optional[str]:
"""
将任务加入队列
Args:
func: 要执行的 worker 函数
job_id: 任务唯一标识
timeout: 超时时间 '4h', '30m'
**kwargs: 传递给 worker 函数的参数
Returns:
job_id None失败时
"""
try:
self.queue.enqueue(
func,
job_id=job_id,
job_timeout=timeout,
**kwargs
)
logger.info(f"Task enqueued: {job_id}")
return job_id
except Exception as e:
logger.error(f"Failed to enqueue task {job_id}: {e}")
return None
def fetch_job(self, job_id: str) -> Optional[Job]:
"""获取任务信息"""
try:
return Job.fetch(job_id, connection=self.connection)
except Exception as e:
logger.warning(f"Failed to fetch job {job_id}: {e}")
return None
def cancel_job(self, job_id: str) -> bool:
"""取消任务"""
try:
job = self.fetch_job(job_id)
if job:
job.cancel()
logger.info(f"Job cancelled: {job_id}")
return True
return False
except Exception as e:
logger.warning(f"Failed to cancel job {job_id}: {e}")
return False

@ -2,128 +2,152 @@
任务处理服务适配新数据库结构和路径配置
处理加噪微调热力图评估等核心业务逻辑
使用Redis Queue进行异步任务处理
已重构为面向对象设计推荐使用:
from app.services.task import TaskHandlerFactory
handler = TaskHandlerFactory.create('perturbation')
job_id = handler.start(task_id)
数据访问已迁移到 Repository :
from app.repositories import TaskRepository, TaskTypeRepository
task_repo = TaskRepository()
task = task_repo.get_by_id(task_id)
"""
import os
import logging
from datetime import datetime
from flask import current_app, jsonify
from typing import Optional
from flask import jsonify
from redis import Redis
from rq import Queue
from rq.job import Job
from app import db
from app.database import (
Task, TaskStatus, TaskType,
Perturbation, Finetune, Heatmap, Evaluate,
Image, ImageType, DataType,
PerturbationConfig, FinetuneConfig, User
)
from app.services.storage import PathManager
from config.algorithm_config import AlgorithmConfig
from config.settings import Config
logger = logging.getLogger(__name__)
# 全局单例实例
_path_manager: Optional[PathManager] = None
_task_repo = None
_task_type_repo = None
_task_status_repo = None
_user_repo = None
def _get_path_manager() -> PathManager:
"""获取路径管理器单例"""
global _path_manager
if _path_manager is None:
_path_manager = PathManager()
return _path_manager
def _get_task_repo():
"""获取任务 Repository 单例(懒加载)"""
global _task_repo
if _task_repo is None:
from app.repositories import TaskRepository
_task_repo = TaskRepository()
return _task_repo
def _get_task_type_repo():
"""获取任务类型 Repository 单例(懒加载)"""
global _task_type_repo
if _task_type_repo is None:
from app.repositories import TaskTypeRepository
_task_type_repo = TaskTypeRepository()
return _task_type_repo
def _get_task_status_repo():
"""获取任务状态 Repository 单例(懒加载)"""
global _task_status_repo
if _task_status_repo is None:
from app.repositories import TaskStatusRepository
_task_status_repo = TaskStatusRepository()
return _task_status_repo
def _get_user_repo():
"""获取用户 Repository 单例(懒加载)"""
global _user_repo
if _user_repo is None:
from app.repositories import UserRepository
_user_repo = UserRepository()
return _user_repo
def _get_task_handler(task_type: str):
"""获取任务处理器(懒加载导入避免循环依赖)"""
from app.services.task import TaskHandlerFactory
return TaskHandlerFactory.create(task_type)
class TaskService:
"""任务处理服务"""
# ==================== 路径工具函数 ====================
# ==================== 路径代理方法(委托给 PathManager====================
# 保持向后兼容,内部委托给 PathManager
@staticmethod
def _get_project_root():
"""获取项目根目录"""
return os.path.dirname(current_app.root_path)
return _get_path_manager().project_root
@staticmethod
def _build_path(*parts):
"""构建路径"""
return os.path.join(TaskService._get_project_root(), *parts)
return _get_path_manager()._build_path(*parts)
@staticmethod
def get_original_images_path(user_id, flow_id):
"""原图路径: ORIGINAL_IMAGES_FOLDER/user_id/flow_id"""
return TaskService._build_path(
Config.ORIGINAL_IMAGES_FOLDER,
str(user_id),
str(flow_id)
)
"""原图路径"""
return _get_path_manager().get_original_images_path(user_id, flow_id)
@staticmethod
def get_perturbed_images_path(user_id, flow_id):
"""加噪图路径: PERTURBED_IMAGES_FOLDER/user_id/flow_id"""
return TaskService._build_path(
Config.PERTURBED_IMAGES_FOLDER,
str(user_id),
str(flow_id)
)
"""加噪图路径"""
return _get_path_manager().get_perturbed_images_path(user_id, flow_id)
@staticmethod
def get_original_generated_path(user_id, flow_id, task_id):
"""原图生成图路径: MODEL_ORIGINAL_FOLDER/user_id/flow_id/task_id"""
return TaskService._build_path(
Config.MODEL_ORIGINAL_FOLDER,
str(user_id),
str(flow_id),
str(task_id)
)
"""原图生成图路径"""
return _get_path_manager().get_original_generated_path(user_id, flow_id, task_id)
@staticmethod
def get_perturbed_generated_path(user_id, flow_id, task_id):
"""加噪图生成图路径: MODEL_PERTURBED_FOLDER/user_id/flow_id/task_id"""
return TaskService._build_path(
Config.MODEL_PERTURBED_FOLDER,
str(user_id),
str(flow_id),
str(task_id)
)
"""加噪图生成图路径"""
return _get_path_manager().get_perturbed_generated_path(user_id, flow_id, task_id)
@staticmethod
def get_uploaded_generated_path(user_id, flow_id, task_id):
"""上传图生成图路径: MODEL_UPLOADED_FOLDER/user_id/flow_id/task_id"""
return TaskService._build_path(
Config.MODEL_UPLOADED_FOLDER,
str(user_id),
str(flow_id),
str(task_id)
)
"""上传图生成图路径"""
return _get_path_manager().get_uploaded_generated_path(user_id, flow_id, task_id)
@staticmethod
def get_heatmap_path(user_id, flow_id, task_id):
"""热力图路径: HEATDIF_SAVE_FOLDER/user_id/flow_id/task_id"""
return TaskService._build_path(
Config.HEATDIF_SAVE_FOLDER,
str(user_id),
str(flow_id),
str(task_id)
)
"""热力图路径"""
return _get_path_manager().get_heatmap_path(user_id, flow_id, task_id)
@staticmethod
def get_evaluate_path(user_id, flow_id, task_id):
"""数值结果路径: NUMBERS_SAVE_FOLDER/user_id/flow_id/task_id"""
return TaskService._build_path(
Config.NUMBERS_SAVE_FOLDER,
str(user_id),
str(flow_id),
str(task_id)
)
"""数值结果路径"""
return _get_path_manager().get_evaluate_path(user_id, flow_id, task_id)
@staticmethod
def get_class_data_path(user_id, flow_id):
"""类别数据路径: CLASS_DATA_FOLDER/user_id/flow_id"""
return TaskService._build_path(
Config.CLASS_DATA_FOLDER,
str(user_id),
str(flow_id)
)
@staticmethod
def get_model_data_path(user_id, flow_id):
"""模型数据路径: MODEL_DATA_FOLDER/user_id/flow_id"""
return TaskService._build_path(
Config.MODEL_DATA_FOLDER
)
"""类别数据路径"""
return _get_path_manager().get_class_data_path(user_id, flow_id)
@staticmethod
def get_model_data_path(user_id=None, flow_id=None):
"""模型数据路径"""
return _get_path_manager().get_model_data_path()
# ==================== 通用辅助功能 ====================
# 以下方法委托给 Repository 层,保持向后兼容
@staticmethod
def json_error(message, status_code=400):
@ -132,70 +156,63 @@ class TaskService:
@staticmethod
def get_task_type(code):
"""根据任务类型代码获取TaskType"""
return TaskType.query.filter_by(task_type_code=code).first()
"""根据任务类型代码获取TaskType(委托给 TaskTypeRepository"""
return _get_task_type_repo().get_by_code(code)
@staticmethod
def require_task_type(code):
"""确保任务类型存在"""
task_type = TaskService.get_task_type(code)
if not task_type:
raise ValueError(f"Task type '{code}' is not configured")
return task_type
"""确保任务类型存在(委托给 TaskTypeRepository"""
return _get_task_type_repo().require(code)
@staticmethod
def get_status_by_code(code):
"""根据状态代码获取TaskStatus"""
return TaskStatus.query.filter_by(task_status_code=code).first()
"""根据状态代码获取TaskStatus(委托给 TaskStatusRepository"""
return _get_task_status_repo().get_by_code(code)
@staticmethod
def ensure_status(code):
"""确保任务状态存在"""
status = TaskService.get_status_by_code(code)
if not status:
raise ValueError(f"Task status '{code}' is not configured")
return status
"""确保任务状态存在(委托给 TaskStatusRepository"""
return _get_task_status_repo().require(code)
@staticmethod
def generate_flow_id():
"""生成唯一的flow_id"""
base = int(datetime.utcnow().timestamp() * 1000)
while Task.query.filter_by(flow_id=base).first():
task_repo = _get_task_repo()
while task_repo.find_one_by(flow_id=base):
base += 1
return base
@staticmethod
def ensure_task_owner(task, user_id):
"""验证任务归属"""
return bool(task and task.user_id == user_id)
"""验证任务归属(委托给 TaskRepository"""
return _get_task_repo().is_owner(task, user_id)
@staticmethod
def get_task_type_code(task):
"""获取任务类型代码"""
return task.task_type.task_type_code if task and task.task_type else None
"""获取任务类型代码(委托给 TaskRepository"""
return _get_task_repo().get_type_code(task)
@staticmethod
def load_task_for_user(task_id, user_id, expected_type=None):
"""根据任务ID加载用户的任务可选检查类型"""
task = Task.query.get(task_id)
if not TaskService.ensure_task_owner(task, user_id):
task_repo = _get_task_repo()
task = task_repo.get_for_user(task_id, user_id)
if not task:
return None
if expected_type and not task_repo.is_type(task, expected_type):
return None
if expected_type:
task_type = TaskService.get_task_type_code(task)
if task_type != expected_type:
return None
return task
@staticmethod
def determine_finetune_source(finetune_task):
"""判断微调任务来源"""
perturb_type = TaskService.require_task_type('perturbation')
sibling_perturbation = Task.query.filter(
Task.flow_id == finetune_task.flow_id,
Task.tasks_type_id == perturb_type.task_type_id,
Task.tasks_id != finetune_task.tasks_id
).first()
return 'perturbation' if sibling_perturbation else 'uploaded'
task_repo = _get_task_repo()
sibling = task_repo.get_by_flow_and_type(finetune_task.flow_id, 'perturbation')
# 排除自身
if sibling and sibling.tasks_id != finetune_task.tasks_id:
return 'perturbation'
return 'uploaded'
@staticmethod
def serialize_task(task):
@ -251,8 +268,8 @@ class TaskService:
@staticmethod
def get_user(user_id):
"""获取用户"""
return User.query.get(user_id)
"""获取用户(委托给 UserRepository"""
return _get_user_repo().get_by_id(user_id)
# ==================== Redis/RQ 连接管理 ====================
@ -281,17 +298,14 @@ class TaskService:
任务状态信息
"""
try:
task = Task.query.get(task_id)
task_repo = _get_task_repo()
task = task_repo.get_by_id(task_id)
if not task:
return {'status': 'not_found', 'error': 'Task not found'}
# 获取任务状态名称
status = TaskStatus.query.get(task.tasks_status_id)
status_code = status.task_status_code if status else 'unknown'
# 获取任务类型
task_type = TaskType.query.get(task.tasks_type_id)
type_code = task_type.task_type_code if task_type else 'unknown'
# 使用 Repository 获取状态和类型代码
status_code = task.task_status.task_status_code if task.task_status else 'unknown'
type_code = task_repo.get_type_code(task) or 'unknown'
result = {
'task_id': task_id,
@ -345,13 +359,13 @@ class TaskService:
是否成功取消
"""
try:
task = Task.query.get(task_id)
task_repo = _get_task_repo()
task = task_repo.get_by_id(task_id)
if not task:
return False
# 获取任务类型
task_type = TaskType.query.get(task.tasks_type_id)
type_code = task_type.task_type_code if task_type else None
# 获取任务类型代码
type_code = task_repo.get_type_code(task)
# 尝试从队列中删除任务
try:
@ -362,19 +376,10 @@ class TaskService:
except Exception as e:
logger.warning(f"Could not cancel RQ job: {e}")
# 更新数据库状态
try:
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
if failed_status:
task.tasks_status_id = failed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
except Exception as e:
db.session.rollback()
logger.error(f"Failed to update task status: {e}")
return False
return True
# 使用 Repository 更新状态
if task_repo.update_status(task, 'failed'):
return task_repo.save()
return False
except Exception as e:
logger.error(f"Error cancelling task: {e}")
@ -385,7 +390,7 @@ class TaskService:
@staticmethod
def start_perturbation_task(task_id):
"""
启动加噪任务
启动加噪任务委托给 PerturbationTaskHandler
Args:
task_id: 任务ID
@ -393,249 +398,33 @@ class TaskService:
Returns:
job_id
"""
try:
# 获取任务
task = Task.query.get(task_id)
if not task:
logger.error(f"Task {task_id} not found")
return None
# 获取Perturbation任务详情
perturbation = Perturbation.query.get(task_id)
if not perturbation:
logger.error(f"Perturbation task {task_id} not found")
return None
# 更新任务状态为 waiting
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
if waiting_status:
task.tasks_status_id = waiting_status.task_status_id
db.session.commit()
# 获取用户ID
user_id = task.user_id
# 路径配置
input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
output_dir = TaskService.get_perturbed_images_path(user_id, task.flow_id)
class_dir = TaskService.get_class_data_path(user_id, task.flow_id)
# 获取算法配置
pert_config = PerturbationConfig.query.get(perturbation.perturbation_configs_id)
if not pert_config:
logger.error(f"Perturbation config not found")
return None
algorithm_code = pert_config.perturbation_code
# 加入RQ队列
from app.workers.perturbation_worker import run_perturbation_task
queue = TaskService._get_queue()
job_id = f"pert_{task_id}"
job = queue.enqueue(
run_perturbation_task,
task_id=task_id,
input_dir=input_dir,
output_dir=output_dir,
class_dir=class_dir,
algorithm_code=algorithm_code,
epsilon=perturbation.perturbation_intensity,
job_id=job_id,
job_timeout='4h'
)
logger.info(f"Perturbation task {task_id} enqueued with job_id {job_id}")
return job_id
except Exception as e:
logger.error(f"Error starting perturbation task: {e}")
return None
return _get_task_handler('perturbation').start(task_id)
# ==================== Finetune 任务 ====================
@staticmethod
def start_finetune_task(task_id):
"""
启动微调任务支持两种类型
类型1基于加噪结果的微调
- 有相同flow_id的Perturbation任务
- 输入原图 + 加噪图
- 输出到original_generated perturbed_generated
启动微调任务委托给 FinetuneTaskHandler
类型2用户上传图片的微调
- 找不到相同flow_id的其他任务
- 输入仅原图
- 输出到uploaded_generated
支持两种类型:
- 基于加噪结果的微调
- 用户上传图片的微调
Args:
task_id: 任务ID
Returns:
job_id
job_id 或逗号分隔的多个 job_id
"""
try:
# 获取任务
task = Task.query.get(task_id)
if not task:
logger.error(f"Task {task_id} not found")
return None
# 获取Finetune任务详情
finetune = Finetune.query.get(task_id)
if not finetune:
logger.error(f"Finetune task {task_id} not found")
return None
# 更新任务状态为 waiting
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
if waiting_status:
task.tasks_status_id = waiting_status.task_status_id
db.session.commit()
# 获取用户ID
user_id = task.user_id
# 获取微调配置
ft_config = FinetuneConfig.query.get(finetune.finetune_configs_id)
if not ft_config:
logger.error(f"Finetune config not found")
return None
# 检测微调类型查找相同flow_id的Perturbation任务
perturb_type = TaskService.require_task_type('perturbation')
sibling_perturbation = Task.query.filter(
Task.flow_id == task.flow_id,
Task.tasks_type_id == perturb_type.task_type_id,
Task.tasks_id != task_id
).first()
has_perturbation = sibling_perturbation is not None
# 路径配置
input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
class_dir = TaskService.get_class_data_path(user_id, task.flow_id)
model_data_dir = TaskService.get_model_data_path(user_id, task.flow_id)
if has_perturbation:
# 类型1基于加噪结果的微调
logger.info(f"Finetune task {task_id}: type=perturbation-based")
perturbed_input_dir = TaskService.get_perturbed_images_path(user_id, task.flow_id)
original_input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, task.flow_id, task_id)
original_output_dir = TaskService.get_original_generated_path(user_id, task.flow_id, task_id)
# 获取坐标保存路径3D可视化
original_coords_save_path = TaskService._build_path(
Config.COORDS_SAVE_FOLDER,
str(user_id),
str(task.flow_id),
str(task_id),
'original_coords.csv'
)
# 获取加噪坐标保存路径3D可视化
perturbed_coords_save_path = TaskService._build_path(
Config.COORDS_SAVE_FOLDER,
str(user_id),
str(task.flow_id),
str(task_id),
'perturbed_coords.csv'
)
# 加入RQ队列
from app.workers.finetune_worker import run_finetune_task
queue = TaskService._get_queue()
job_id_original = f"ft_{task_id}_original"
job_id_perturbed = f"ft_{task_id}_perturbed"
job_original = queue.enqueue(
run_finetune_task,
task_id=task_id,
finetune_method=ft_config.finetune_code,
train_images_dir=original_input_dir,
output_model_dir=model_data_dir,
class_dir=class_dir,
coords_save_path=original_coords_save_path,
validation_output_dir=original_output_dir,
finetune_type="original",
custom_params=None,
job_id=job_id_original,
job_timeout='8h'
)
job_perturbed = queue.enqueue(
run_finetune_task,
task_id=task_id,
finetune_method=ft_config.finetune_code,
train_images_dir=perturbed_input_dir,
output_model_dir=model_data_dir,
class_dir=class_dir,
coords_save_path=perturbed_coords_save_path,
validation_output_dir=perturbed_output_dir,
finetune_type="perturbed",
custom_params=None,
job_id=job_id_perturbed,
job_timeout='8h'
)
logger.info(f"Finetune task {task_id} enqueued with job_ids {job_id_original}, {job_id_perturbed}")
return f"{job_id_original},{job_id_perturbed}"
else:
# 类型2用户上传图片的微调
logger.info(f"Finetune task {task_id}: type=uploaded")
uploaded_output_dir = TaskService.get_uploaded_generated_path(user_id, task.flow_id, task_id)
# 获取坐标保存路径
coords_save_path = TaskService._build_path(
Config.COORDS_SAVE_FOLDER,
str(user_id),
str(task.flow_id),
str(task_id),
'coords.csv'
)
# 加入RQ队列
from app.workers.finetune_worker import run_finetune_task
queue = TaskService._get_queue()
job_id = f"ft_{task_id}"
job = queue.enqueue(
run_finetune_task,
task_id=task_id,
finetune_method=ft_config.finetune_code,
train_images_dir=input_dir,
output_model_dir=model_data_dir,
class_dir=class_dir,
coords_save_path=coords_save_path,
validation_output_dir=uploaded_output_dir,
finetune_type="uploaded",
custom_params=None,
job_id=job_id,
job_timeout='8h'
)
logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}")
return job_id
except Exception as e:
logger.error(f"Error starting finetune task: {e}")
return None
return _get_task_handler('finetune').start(task_id)
# ==================== Heatmap 任务 ====================
@staticmethod
def start_heatmap_task(task_id):
"""
启动热力图任务
启动热力图任务委托给 HeatmapTaskHandler
Args:
task_id: 任务ID
@ -643,93 +432,14 @@ class TaskService:
Returns:
job_id
"""
try:
# 获取任务
task = Task.query.get(task_id)
if not task:
logger.error(f"Task {task_id} not found")
return None
# 获取Heatmap任务详情
heatmap = Heatmap.query.get(task_id)
if not heatmap:
logger.error(f"Heatmap task {task_id} not found")
return None
# 更新任务状态为 waiting
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
if waiting_status:
task.tasks_status_id = waiting_status.task_status_id
db.session.commit()
# 从heatmap对象获取扰动图片ID
perturbed_image_id = heatmap.images_id
if not perturbed_image_id:
logger.error(f"Heatmap task {task_id} has no associated perturbed image")
return None
# 获取扰动图片信息
perturbed_image = Image.query.get(perturbed_image_id)
if not perturbed_image:
logger.error(f"Perturbed image {perturbed_image_id} not found")
return None
user_id = task.user_id
# 获取原图通过father_id关系
if not perturbed_image.father_id:
logger.error(f"Perturbed image {perturbed_image_id} has no father_id")
return None
original_image = Image.query.get(perturbed_image.father_id)
if not original_image:
logger.error(f"Original image not found")
return None
# 构建图片路径(使用 stored_filename
original_image_path = os.path.join(
TaskService.get_original_images_path(user_id, task.flow_id),
original_image.stored_filename
)
perturbed_image_path = os.path.join(
TaskService.get_perturbed_images_path(user_id, task.flow_id),
perturbed_image.stored_filename
)
# 输出目录
output_dir = TaskService.get_heatmap_path(user_id, task.flow_id, task_id)
# 加入RQ队列
from app.workers.heatmap_worker import run_heatmap_task
queue = TaskService._get_queue()
job_id = f"hm_{task_id}"
job = queue.enqueue(
run_heatmap_task,
task_id=task_id,
original_image_path=original_image_path,
perturbed_image_path=perturbed_image_path,
output_dir=output_dir,
perturbed_image_id=perturbed_image_id,
job_id=job_id,
job_timeout='2h'
)
logger.info(f"Heatmap task {task_id} enqueued with job_id {job_id}")
return job_id
except Exception as e:
logger.error(f"Error starting heatmap task: {e}")
return None
return _get_task_handler('heatmap').start(task_id)
# ==================== Evaluate 任务 ====================
@staticmethod
def start_evaluate_task(task_id):
"""
启动评估任务
启动评估任务委托给 EvaluateTaskHandler
Args:
task_id: 任务ID
@ -737,64 +447,4 @@ class TaskService:
Returns:
job_id
"""
try:
# 获取任务
task = Task.query.get(task_id)
if not task:
logger.error(f"Task {task_id} not found")
return None
# 获取Evaluate任务详情
evaluate = Evaluate.query.get(task_id)
if not evaluate:
logger.error(f"Evaluate task {task_id} not found")
return None
# 更新任务状态为 waiting
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
if waiting_status:
task.tasks_status_id = waiting_status.task_status_id
db.session.commit()
finetune = Finetune.query.get(evaluate.finetune_task_id)
if not finetune:
logger.error(f"Finetune task {evaluate.finetune_task_id} not found for evaluation {task_id}")
return None
finetune_task = finetune.task
if not finetune_task:
logger.error(f"Finetune task {evaluate.finetune_task_id} missing Task relation")
return None
user_id = finetune_task.user_id
# 路径配置
clean_ref_dir = TaskService.get_original_images_path(user_id, finetune_task.flow_id)
clean_output_dir = TaskService.get_original_generated_path(user_id, finetune_task.flow_id, finetune_task.tasks_id)
perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, finetune_task.flow_id, finetune_task.tasks_id)
output_dir = TaskService.get_evaluate_path(user_id, finetune_task.flow_id, task_id)
# 加入RQ队列
from app.workers.evaluate_worker import run_evaluate_task
queue = TaskService._get_queue()
job_id = f"eval_{task_id}"
job = queue.enqueue(
run_evaluate_task,
task_id=task_id,
clean_ref_dir=clean_ref_dir,
clean_output_dir=clean_output_dir,
perturbed_output_dir=perturbed_output_dir,
output_dir=output_dir,
image_size=512,
job_id=job_id,
job_timeout='2h'
)
logger.info(f"Evaluate task {task_id} enqueued with job_id {job_id}")
return job_id
except Exception as e:
logger.error(f"Error starting evaluate task: {e}")
return None
return _get_task_handler('evaluate').start(task_id)

Loading…
Cancel
Save