refactor: 重构controllers层

pull/7/head
梁浩 3 months ago
parent 0ce6747b76
commit ae261a6dc9

@ -7,7 +7,6 @@ from flask import Blueprint, request, jsonify
from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity
from app import db
from app.database import User, UserConfig
from app.services.auth_service import AuthService
from functools import wraps
import re

@ -1,203 +1,128 @@
"""
图像管理控制器
处理图像下载查看等功能
"""
from flask import Blueprint, send_file, jsonify, request, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity
from app.database import Image, EvaluationResult
from app.services.image_service import ImageService
import os
image_bp = Blueprint('image', __name__)
@image_bp.route('/file/<int:image_id>', methods=['GET'])
@jwt_required()
def get_image_file(image_id):
"""获取图片文件"""
try:
current_user_id = get_jwt_identity()
# 查找图片记录
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
# 检查文件是否存在
if not os.path.exists(image.file_path):
return jsonify({'error': '图片文件不存在'}), 404
return send_file(image.file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取图片失败: {str(e)}'}), 500
@image_bp.route('/download/<int:image_id>', methods=['GET'])
@jwt_required()
def download_image(image_id):
"""下载图片文件"""
try:
current_user_id = get_jwt_identity()
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
if not os.path.exists(image.file_path):
return jsonify({'error': '图片文件不存在'}), 404
return send_file(
image.file_path,
as_attachment=True,
download_name=image.original_filename or f"image_{image_id}.jpg"
)
except Exception as e:
return jsonify({'error': f'下载图片失败: {str(e)}'}), 500
@image_bp.route('/batch/<int:batch_id>/download', methods=['GET'])
@jwt_required()
def download_batch_images(batch_id):
"""批量下载任务中的加噪后图片"""
try:
current_user_id = get_jwt_identity()
# 获取任务中的加噪图片
perturbed_images = Image.query.join(Image.image_type).filter(
Image.batch_id == batch_id,
Image.user_id == current_user_id,
Image.image_type.has(type_code='perturbed')
).all()
if not perturbed_images:
return jsonify({'error': '没有找到加噪后的图片'}), 404
# 创建ZIP文件
import zipfile
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file:
with zipfile.ZipFile(tmp_file.name, 'w') as zip_file:
for image in perturbed_images:
if os.path.exists(image.file_path):
arcname = image.original_filename or f"perturbed_{image.id}.jpg"
zip_file.write(image.file_path, arcname)
return send_file(
tmp_file.name,
as_attachment=True,
download_name=f"batch_{batch_id}_perturbed_images.zip",
mimetype='application/zip'
)
except Exception as e:
return jsonify({'error': f'批量下载失败: {str(e)}'}), 500
@image_bp.route('/<int:image_id>/evaluations', methods=['GET'])
@jwt_required()
def get_image_evaluations(image_id):
"""获取图片的评估结果"""
try:
current_user_id = get_jwt_identity()
# 验证图片权限
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
# 获取以该图片为参考或目标的评估结果
evaluations = EvaluationResult.query.filter(
(EvaluationResult.reference_image_id == image_id) |
(EvaluationResult.target_image_id == image_id)
).all()
return jsonify({
'image_id': image_id,
'evaluations': [eval_result.to_dict() for eval_result in evaluations]
}), 200
except Exception as e:
return jsonify({'error': f'获取评估结果失败: {str(e)}'}), 500
@image_bp.route('/compare', methods=['POST'])
@jwt_required()
def compare_images():
"""对比两张图片"""
try:
current_user_id = get_jwt_identity()
data = request.get_json()
image1_id = data.get('image1_id')
image2_id = data.get('image2_id')
if not image1_id or not image2_id:
return jsonify({'error': '请提供两张图片的ID'}), 400
# 验证图片权限
image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first()
image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first()
if not image1 or not image2:
return jsonify({'error': '图片不存在或无权限'}), 404
# 查找现有的评估结果
evaluation = EvaluationResult.query.filter_by(
reference_image_id=image1_id,
target_image_id=image2_id
).first()
if not evaluation:
# 如果没有评估结果,返回基本对比信息
return jsonify({
'image1': image1.to_dict(),
'image2': image2.to_dict(),
'evaluation': None,
'message': '暂无评估数据,请等待任务处理完成'
}), 200
return jsonify({
'image1': image1.to_dict(),
'image2': image2.to_dict(),
'evaluation': evaluation.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'图片对比失败: {str(e)}'}), 500
@image_bp.route('/heatmap/<path:heatmap_path>', methods=['GET'])
@jwt_required()
def get_heatmap(heatmap_path):
"""获取热力图文件"""
try:
# 安全检查,防止路径遍历攻击
if '..' in heatmap_path or heatmap_path.startswith('/'):
return jsonify({'error': '无效的文件路径'}), 400
# 修正路径构建 - 获取项目根目录backend目录
project_root = os.path.dirname(current_app.root_path)
full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path))
if not os.path.exists(full_path):
return jsonify({'error': '热力图文件不存在'}), 404
return send_file(full_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取热力图失败: {str(e)}'}), 500
@image_bp.route('/delete/<int:image_id>', methods=['DELETE'])
@jwt_required()
def delete_image(image_id):
"""删除图片"""
try:
current_user_id = get_jwt_identity()
result = ImageService.delete_image(image_id, current_user_id)
if result['success']:
return jsonify({'message': '图片删除成功'}), 200
else:
return jsonify({'error': result['error']}), 400
except Exception as e:
return jsonify({'error': f'删除图片失败: {str(e)}'}), 500
"""
图像管理控制器
负责图片上传下载等操作
"""
from flask import Blueprint, request, jsonify, send_file
from app.controllers.auth_controller import int_jwt_required
from app.services.task_service import TaskService
from app.services.image_service import ImageService
image_bp = Blueprint('image', __name__)
# ==================== 图片上传 ====================
@image_bp.route('/original', methods=['POST'])
@int_jwt_required
def upload_original_images(current_user_id):
task_id = request.form.get('task_id', type=int)
if not task_id:
return ImageService.json_error('缺少 task_id 参数')
task = TaskService.load_task_for_user(task_id, current_user_id)
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
task_type = TaskService.get_task_type_code(task)
if task_type not in {'perturbation', 'finetune'}:
return ImageService.json_error('任务类型不支持图片上传', 400)
files = request.files.getlist('files')
target_dir = TaskService.get_original_images_path(task.user_id, task.flow_id)
success, result = ImageService.save_original_images(task, files, target_dir)
if not success:
status_code = 400
if isinstance(result, str) and (result.startswith('未配置图片类型') or '失败' in result):
status_code = 500
return ImageService.json_error(result, status_code)
return jsonify({
'message': '图片上传成功',
'images': [ImageService.serialize_image(img) for img in result],
'flow_id': task.flow_id
}), 201
# ==================== 结果下载 ====================
@image_bp.route('/perturbation/<int:task_id>/download', methods=['GET'])
@int_jwt_required
def download_perturbation_result(task_id, current_user_id):
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_perturbed_images_path(task.user_id, task.flow_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('结果文件不存在', 404)
filename = f"perturbation_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
@image_bp.route('/heatmap/<int:task_id>/download', methods=['GET'])
@int_jwt_required
def download_heatmap_result(task_id, current_user_id):
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_heatmap_path(task.user_id, task.flow_id, task.tasks_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('热力图文件不存在', 404)
filename = f"heatmap_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
@image_bp.route('/finetune/<int:task_id>/download', methods=['GET'])
@int_jwt_required
def download_finetune_result(task_id, current_user_id):
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
if not task.finetune:
return ImageService.json_error('微调任务配置不存在', 404)
try:
source = TaskService.determine_finetune_source(task)
except ValueError as exc:
return ImageService.json_error(str(exc), 500)
if source == 'perturbation':
directories = {
'original_generate': TaskService.get_original_generated_path(task.user_id, task.flow_id, task.tasks_id),
'perturbed_generate': TaskService.get_perturbed_generated_path(task.user_id, task.flow_id, task.tasks_id)
}
else:
directories = {
'uploaded_generate': TaskService.get_uploaded_generated_path(task.user_id, task.flow_id, task.tasks_id)
}
zipped, has_files = ImageService.zip_multiple_directories(directories)
if not has_files:
return ImageService.json_error('微调结果文件不存在', 404)
filename = f"finetune_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
@image_bp.route('/evaluate/<int:task_id>/download', methods=['GET'])
@int_jwt_required
def download_evaluate_result(task_id, current_user_id):
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_evaluate_path(task.user_id, task.flow_id, task.tasks_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('评估结果文件不存在', 404)
filename = f"evaluate_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')

File diff suppressed because it is too large Load Diff

@ -1,129 +1,119 @@
"""
用户管理控制器
处理用户配置等功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required
from app import db
from app.database import User, UserConfig, Perturbation, Finetune
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
user_bp = Blueprint('user', __name__)
@user_bp.route('/config', methods=['GET'])
@int_jwt_required
def get_user_config(current_user_id):
"""获取用户配置"""
try:
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
# 如果没有配置,创建默认配置
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
db.session.commit()
return jsonify({
'config': user_config.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500
@user_bp.route('/config', methods=['PUT'])
@int_jwt_required
def update_user_config(current_user_id):
"""更新用户配置"""
try:
data = request.get_json()
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
# 更新配置字段
if 'perturbation_configs_id' in data:
user_config.perturbation_configs_id = data['perturbation_configs_id']
if 'perturbation_intensity' in data:
intensity = float(data['perturbation_intensity'])
if 0 < epsilon <= 255:
user_config.perturbation_intensity = intensity
else:
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
if 'finetune_config_id' in data:
user_config.finetune_config_id = data['finetune_config_id']
db.session.commit()
return jsonify({
'message': '用户配置更新成功',
'config': user_config.to_dict()
}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500
@user_bp.route('/algorithms', methods=['GET'])
@jwt_required()
def get_available_algorithms():
"""获取可用的算法列表"""
try:
perturbation_configs = Perturbation.query.all()
finetune_configs = Finetune.query.all()
return jsonify({
'perturbation_algorithms': [
{
'id': config.id,
'method_code': config.method_code,
'method_name': config.method_name,
'description': config.description,
} for config in perturbation_configs
],
'finetune_methods': [
{
'id': config.id,
'method_code': config.method_code,
'method_name': config.method_name,
'description': config.description
} for config in finetune_configs
]
}), 200
except Exception as e:
return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500
@user_bp.route('/stats', methods=['GET'])
@int_jwt_required
def get_user_stats(current_user_id):
"""获取用户统计信息"""
try:
from app.database import Task, Image
# 统计用户的任务和图片数量
total_tasks = Task.query.filter_by(user_id=current_user_id).count()
completed_tasks = Task.query.filter_by(user_id=current_user_id, status='completed').count()
processing_tasks = Task.query.filter_by(user_id=current_user_id, status='processing').count()
failed_tasks = Task.query.filter_by(user_id=current_user_id, status='failed').count()
total_images = Image.query.join(Task, Image.task_id == Task.id).filter(Task.user_id == current_user_id).count()
return jsonify({
'stats': {
'total_tasks': total_tasks,
'completed_tasks': completed_tasks,
'processing_tasks': processing_tasks,
'failed_tasks': failed_tasks,
'total_images': total_images
}
}), 200
except Exception as e:
return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500
"""
用户管理控制器
负责用户配置任务汇总等接口
"""
from flask import Blueprint, request, jsonify
from app import db
from app.controllers.auth_controller import int_jwt_required
from app.database import UserConfig, Task, TaskType, TaskStatus
user_bp = Blueprint('user', __name__)
def _json_error(message, status_code=400):
return jsonify({'error': message}), status_code
def _get_or_create_user_config(user_id):
config = UserConfig.query.filter_by(user_id=user_id).first()
if not config:
config = UserConfig(user_id=user_id)
db.session.add(config)
db.session.commit()
return config
def _serialize_config(config):
return {
'user_configs_id': config.user_configs_id,
'user_id': config.user_id,
'data_type_id': config.data_type_id,
'perturbation_configs_id': config.perturbation_configs_id,
'perturbation_intensity': config.perturbation_intensity,
'finetune_configs_id': config.finetune_configs_id,
'created_at': config.created_at.isoformat() if config.created_at else None,
'updated_at': config.updated_at.isoformat() if config.updated_at else None,
}
def _serialize_task(task):
status_code = task.task_status.task_status_code if task.task_status else None
task_type_code = task.task_type.task_type_code if task.task_type else None
return {
'task_id': task.tasks_id,
'flow_id': task.flow_id,
'task_type': task_type_code,
'status': status_code,
'created_at': task.created_at.isoformat() if task.created_at else None,
'started_at': task.started_at.isoformat() if task.started_at else None,
'finished_at': task.finished_at.isoformat() if task.finished_at else None,
'description': task.description,
'error_message': task.error_message
}
@user_bp.route('/config', methods=['GET'])
@int_jwt_required
def get_user_config(current_user_id):
config = _get_or_create_user_config(current_user_id)
return jsonify({'config': _serialize_config(config)}), 200
@user_bp.route('/config', methods=['PUT'])
@int_jwt_required
def update_user_config(current_user_id):
config = _get_or_create_user_config(current_user_id)
data = request.get_json() or {}
allowed_fields = {'data_type_id', 'perturbation_configs_id', 'perturbation_intensity', 'finetune_configs_id'}
for key, value in data.items():
if key in allowed_fields:
if key == 'perturbation_intensity' and value is not None:
try:
value = float(value)
except (TypeError, ValueError):
return _json_error('perturbation_intensity 参数格式不正确')
setattr(config, key, value)
try:
db.session.commit()
return jsonify({'message': '配置已更新', 'config': _serialize_config(config)}), 200
except Exception as exc:
db.session.rollback()
return _json_error(f'更新配置失败: {exc}', 500)
@user_bp.route('/tasks', methods=['GET'])
@int_jwt_required
def list_user_tasks(current_user_id):
task_type_code = request.args.get('type')
status_code = request.args.get('status')
query = Task.query.filter_by(user_id=current_user_id)
if task_type_code:
task_type = TaskType.query.filter_by(task_type_code=task_type_code).first()
if not task_type:
return _json_error('任务类型不存在', 404)
query = query.filter(Task.tasks_type_id == task_type.task_type_id)
if status_code:
status = TaskStatus.query.filter_by(task_status_code=status_code).first()
if not status:
return _json_error('任务状态不存在', 404)
query = query.filter(Task.tasks_status_id == status.task_status_id)
tasks = query.order_by(Task.created_at.desc()).all()
return jsonify({'tasks': [_serialize_task(task) for task in tasks]}), 200
@user_bp.route('/tasks/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_user_task(task_id, current_user_id):
task = Task.query.filter_by(tasks_id=task_id, user_id=current_user_id).first()
if not task:
return _json_error('任务不存在或无权限', 404)
return jsonify({'task': _serialize_task(task)}), 200

@ -3,16 +3,18 @@
处理图像上传保存等功能
"""
import io
import os
import uuid
import zipfile
import fcntl
import time
from datetime import datetime
from werkzeug.utils import secure_filename
from flask import current_app
from flask import current_app, jsonify
from PIL import Image as PILImage
from app import db
from app.database import Image
from app.database import Image, ImageType
from app.utils.file_utils import allowed_file
class ImageService:
@ -254,4 +256,173 @@ class ImageService:
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'删除图片失败: {str(e)}'}
return {'success': False, 'error': f'删除图片失败: {str(e)}'}
# ==================== 控制器辅助功能 ====================
DEFAULT_TARGET_SIZE = 512
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp'}
@staticmethod
def json_error(message, status_code=400):
"""统一错误响应"""
return jsonify({'error': message}), status_code
@staticmethod
def get_image_type_by_code(code):
"""根据代码获取图片类型"""
return ImageType.query.filter_by(image_code=code).first()
@staticmethod
def save_original_images(task, files, target_dir, image_type_code='original', target_size=None):
"""保存原图上传"""
if not files:
return False, '未检测到文件上传'
image_type = ImageService.get_image_type_by_code(image_type_code)
if not image_type:
return False, f'未配置图片类型: {image_type_code}'
os.makedirs(target_dir, exist_ok=True)
saved_records = []
saved_paths = []
size = target_size or ImageService.DEFAULT_TARGET_SIZE
try:
for file in files:
if not file or not file.filename:
continue
if not allowed_file(file.filename):
continue
extension = os.path.splitext(file.filename)[1].lower()
if extension not in ImageService.IMAGE_EXTENSIONS:
continue
processed = ImageService._prepare_image(file, size)
filename, path, width, height, file_size = ImageService._save_processed_image(processed, target_dir)
image = ImageService._create_image_record(
task,
image_type.image_types_id,
filename,
path,
width,
height,
file_size
)
saved_records.append(image)
saved_paths.append(path)
if not saved_records:
db.session.rollback()
return False, '未上传有效的图片文件'
db.session.commit()
return True, saved_records
except Exception as exc:
db.session.rollback()
for path in saved_paths:
if os.path.exists(path):
try:
os.remove(path)
except OSError:
pass
return False, f'上传图片失败: {exc}'
@staticmethod
def _prepare_image(file_storage, target_size):
"""裁剪并缩放上传图片"""
file_storage.stream.seek(0)
image = PILImage.open(file_storage.stream).convert('RGB')
width, height = image.size
min_dim = min(width, height)
left = (width - min_dim) // 2
top = (height - min_dim) // 2
image = image.crop((left, top, left + min_dim, top + min_dim))
return image.resize((target_size, target_size), resample=PILImage.Resampling.LANCZOS)
@staticmethod
def _save_processed_image(image, target_dir):
"""将处理后的图片保存为PNG"""
timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f')
filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png"
path = os.path.join(target_dir, filename)
image.save(path, format='PNG')
return filename, path, image.width, image.height, os.path.getsize(path)
@staticmethod
def _create_image_record(task, image_type_id, filename, path, width, height, file_size, father_id=None):
"""创建图片数据库记录"""
image = Image(
task_id=task.tasks_id,
image_types_id=image_type_id,
father_id=father_id,
stored_filename=filename,
file_path=path,
file_size=file_size,
width=width,
height=height
)
db.session.add(image)
return image
@staticmethod
def zip_directory(directory):
"""打包目录为zip"""
buffer = io.BytesIO()
has_files = False
with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
if os.path.isdir(directory):
for root, _, files in os.walk(directory):
for filename in files:
file_path = os.path.join(root, filename)
arcname = os.path.relpath(file_path, directory)
zipf.write(file_path, arcname)
has_files = True
buffer.seek(0)
return buffer, has_files
@staticmethod
def zip_multiple_directories(directories):
"""打包多个目录"""
buffer = io.BytesIO()
has_files = False
with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
if isinstance(directories, dict):
iterable = directories.items()
else:
iterable = ((os.path.basename(d.rstrip(os.sep)) or 'output', d) for d in directories)
for label, directory in iterable:
if not os.path.isdir(directory):
continue
for root, _, files in os.walk(directory):
for filename in files:
file_path = os.path.join(root, filename)
rel_path = os.path.relpath(file_path, directory)
arcname = os.path.join(label or 'output', rel_path)
zipf.write(file_path, arcname)
has_files = True
buffer.seek(0)
return buffer, has_files
@staticmethod
def serialize_image(image):
"""图片序列化"""
if not image:
return None
return {
'image_id': image.images_id,
'task_id': image.task_id,
'stored_filename': image.stored_filename,
'file_path': image.file_path,
'file_size': image.file_size,
'width': image.width,
'height': image.height,
'image_type': image.image_type.image_code if image.image_type else None
}

@ -7,7 +7,7 @@
import os
import logging
from datetime import datetime
from flask import current_app
from flask import current_app, jsonify
from redis import Redis
from rq import Queue
from rq.job import Job
@ -16,7 +16,7 @@ from app.database import (
Task, TaskStatus, TaskType,
Perturbation, Finetune, Heatmap, Evaluate,
Image, ImageType, DataType,
PerturbationConfig, FinetuneConfig
PerturbationConfig, FinetuneConfig, User
)
from config.algorithm_config import AlgorithmConfig
from config.settings import Config
@ -116,6 +116,135 @@ class TaskService:
str(flow_id)
)
# ==================== 通用辅助功能 ====================
@staticmethod
def json_error(message, status_code=400):
"""统一的错误响应"""
return jsonify({'error': message}), status_code
@staticmethod
def get_task_type(code):
"""根据任务类型代码获取TaskType"""
return TaskType.query.filter_by(task_type_code=code).first()
@staticmethod
def require_task_type(code):
"""确保任务类型存在"""
task_type = TaskService.get_task_type(code)
if not task_type:
raise ValueError(f"Task type '{code}' is not configured")
return task_type
@staticmethod
def get_status_by_code(code):
"""根据状态代码获取TaskStatus"""
return TaskStatus.query.filter_by(task_status_code=code).first()
@staticmethod
def ensure_status(code):
"""确保任务状态存在"""
status = TaskService.get_status_by_code(code)
if not status:
raise ValueError(f"Task status '{code}' is not configured")
return status
@staticmethod
def generate_flow_id():
"""生成唯一的flow_id"""
base = int(datetime.utcnow().timestamp() * 1000)
while Task.query.filter_by(flow_id=base).first():
base += 1
return base
@staticmethod
def ensure_task_owner(task, user_id):
"""验证任务归属"""
return bool(task and task.user_id == user_id)
@staticmethod
def get_task_type_code(task):
"""获取任务类型代码"""
return task.task_type.task_type_code if task and task.task_type else None
@staticmethod
def load_task_for_user(task_id, user_id, expected_type=None):
"""根据任务ID加载用户的任务可选检查类型"""
task = Task.query.get(task_id)
if not TaskService.ensure_task_owner(task, user_id):
return None
if expected_type:
task_type = TaskService.get_task_type_code(task)
if task_type != expected_type:
return None
return task
@staticmethod
def determine_finetune_source(finetune_task):
"""判断微调任务来源"""
perturb_type = TaskService.require_task_type('perturbation')
sibling_perturbation = Task.query.filter(
Task.flow_id == finetune_task.flow_id,
Task.tasks_type_id == perturb_type.task_type_id,
Task.tasks_id != finetune_task.tasks_id
).first()
return 'perturbation' if sibling_perturbation else 'uploaded'
@staticmethod
def serialize_task(task):
"""任务序列化"""
task_type = TaskService.get_task_type_code(task)
status = task.task_status.task_status_code if task and task.task_status else None
base = {
'task_id': task.tasks_id,
'flow_id': task.flow_id,
'task_type': task_type,
'status': status,
'user_id': task.user_id,
'description': task.description,
'created_at': task.created_at.isoformat() if task.created_at else None,
'started_at': task.started_at.isoformat() if task.started_at else None,
'finished_at': task.finished_at.isoformat() if task.finished_at else None,
'error_message': task.error_message,
}
if task_type == 'perturbation' and task.perturbation:
base['perturbation'] = {
'data_type_id': task.perturbation.data_type_id,
'perturbation_configs_id': task.perturbation.perturbation_configs_id,
'perturbation_intensity': float(task.perturbation.perturbation_intensity),
'perturbation_name': task.perturbation.perturbation_name,
}
elif task_type == 'finetune' and task.finetune:
try:
source = TaskService.determine_finetune_source(task)
except ValueError:
source = 'uploaded'
base['finetune'] = {
'finetune_configs_id': task.finetune.finetune_configs_id,
'data_type_id': task.finetune.data_type_id,
'finetune_name': task.finetune.finetune_name,
'source': source
}
elif task_type == 'heatmap' and task.heatmap:
base['heatmap'] = {
'perturbed_image_id': task.heatmap.images_id,
'heatmap_name': task.heatmap.heatmap_name
}
elif task_type == 'evaluate' and task.evaluation:
base['evaluate'] = {
'finetune_configs_id': task.evaluation.finetune_configs_id,
'evaluate_name': task.evaluation.evaluate_name,
'evaluation_results_id': task.evaluation.evaluation_results_id
}
return base
@staticmethod
def get_user(user_id):
"""获取用户"""
return User.query.get(user_id)
# ==================== Redis/RQ 连接管理 ====================
@staticmethod

Loading…
Cancel
Save