|
|
|
|
@ -2,128 +2,152 @@
|
|
|
|
|
任务处理服务(适配新数据库结构和路径配置)
|
|
|
|
|
处理加噪、微调、热力图、评估等核心业务逻辑
|
|
|
|
|
使用Redis Queue进行异步任务处理
|
|
|
|
|
|
|
|
|
|
已重构为面向对象设计,推荐使用:
|
|
|
|
|
from app.services.task import TaskHandlerFactory
|
|
|
|
|
|
|
|
|
|
handler = TaskHandlerFactory.create('perturbation')
|
|
|
|
|
job_id = handler.start(task_id)
|
|
|
|
|
|
|
|
|
|
数据访问已迁移到 Repository 层:
|
|
|
|
|
from app.repositories import TaskRepository, TaskTypeRepository
|
|
|
|
|
|
|
|
|
|
task_repo = TaskRepository()
|
|
|
|
|
task = task_repo.get_by_id(task_id)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import logging
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from flask import current_app, jsonify
|
|
|
|
|
from typing import Optional
|
|
|
|
|
from flask import jsonify
|
|
|
|
|
from redis import Redis
|
|
|
|
|
from rq import Queue
|
|
|
|
|
from rq.job import Job
|
|
|
|
|
from app import db
|
|
|
|
|
from app.database import (
|
|
|
|
|
Task, TaskStatus, TaskType,
|
|
|
|
|
Perturbation, Finetune, Heatmap, Evaluate,
|
|
|
|
|
Image, ImageType, DataType,
|
|
|
|
|
PerturbationConfig, FinetuneConfig, User
|
|
|
|
|
)
|
|
|
|
|
from app.services.storage import PathManager
|
|
|
|
|
from config.algorithm_config import AlgorithmConfig
|
|
|
|
|
from config.settings import Config
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# 全局单例实例
|
|
|
|
|
_path_manager: Optional[PathManager] = None
|
|
|
|
|
_task_repo = None
|
|
|
|
|
_task_type_repo = None
|
|
|
|
|
_task_status_repo = None
|
|
|
|
|
_user_repo = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_path_manager() -> PathManager:
|
|
|
|
|
"""获取路径管理器单例"""
|
|
|
|
|
global _path_manager
|
|
|
|
|
if _path_manager is None:
|
|
|
|
|
_path_manager = PathManager()
|
|
|
|
|
return _path_manager
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_task_repo():
|
|
|
|
|
"""获取任务 Repository 单例(懒加载)"""
|
|
|
|
|
global _task_repo
|
|
|
|
|
if _task_repo is None:
|
|
|
|
|
from app.repositories import TaskRepository
|
|
|
|
|
_task_repo = TaskRepository()
|
|
|
|
|
return _task_repo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_task_type_repo():
|
|
|
|
|
"""获取任务类型 Repository 单例(懒加载)"""
|
|
|
|
|
global _task_type_repo
|
|
|
|
|
if _task_type_repo is None:
|
|
|
|
|
from app.repositories import TaskTypeRepository
|
|
|
|
|
_task_type_repo = TaskTypeRepository()
|
|
|
|
|
return _task_type_repo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_task_status_repo():
|
|
|
|
|
"""获取任务状态 Repository 单例(懒加载)"""
|
|
|
|
|
global _task_status_repo
|
|
|
|
|
if _task_status_repo is None:
|
|
|
|
|
from app.repositories import TaskStatusRepository
|
|
|
|
|
_task_status_repo = TaskStatusRepository()
|
|
|
|
|
return _task_status_repo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_user_repo():
|
|
|
|
|
"""获取用户 Repository 单例(懒加载)"""
|
|
|
|
|
global _user_repo
|
|
|
|
|
if _user_repo is None:
|
|
|
|
|
from app.repositories import UserRepository
|
|
|
|
|
_user_repo = UserRepository()
|
|
|
|
|
return _user_repo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_task_handler(task_type: str):
|
|
|
|
|
"""获取任务处理器(懒加载导入避免循环依赖)"""
|
|
|
|
|
from app.services.task import TaskHandlerFactory
|
|
|
|
|
return TaskHandlerFactory.create(task_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TaskService:
|
|
|
|
|
"""任务处理服务"""
|
|
|
|
|
|
|
|
|
|
# ==================== 路径工具函数 ====================
|
|
|
|
|
# ==================== 路径代理方法(委托给 PathManager)====================
|
|
|
|
|
# 保持向后兼容,内部委托给 PathManager
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _get_project_root():
|
|
|
|
|
"""获取项目根目录"""
|
|
|
|
|
return os.path.dirname(current_app.root_path)
|
|
|
|
|
return _get_path_manager().project_root
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _build_path(*parts):
|
|
|
|
|
"""构建路径"""
|
|
|
|
|
return os.path.join(TaskService._get_project_root(), *parts)
|
|
|
|
|
return _get_path_manager()._build_path(*parts)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_original_images_path(user_id, flow_id):
|
|
|
|
|
"""原图路径: ORIGINAL_IMAGES_FOLDER/user_id/flow_id"""
|
|
|
|
|
return TaskService._build_path(
|
|
|
|
|
Config.ORIGINAL_IMAGES_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(flow_id)
|
|
|
|
|
)
|
|
|
|
|
"""原图路径"""
|
|
|
|
|
return _get_path_manager().get_original_images_path(user_id, flow_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_perturbed_images_path(user_id, flow_id):
|
|
|
|
|
"""加噪图路径: PERTURBED_IMAGES_FOLDER/user_id/flow_id"""
|
|
|
|
|
return TaskService._build_path(
|
|
|
|
|
Config.PERTURBED_IMAGES_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(flow_id)
|
|
|
|
|
)
|
|
|
|
|
"""加噪图路径"""
|
|
|
|
|
return _get_path_manager().get_perturbed_images_path(user_id, flow_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_original_generated_path(user_id, flow_id, task_id):
|
|
|
|
|
"""原图生成图路径: MODEL_ORIGINAL_FOLDER/user_id/flow_id/task_id"""
|
|
|
|
|
return TaskService._build_path(
|
|
|
|
|
Config.MODEL_ORIGINAL_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(flow_id),
|
|
|
|
|
str(task_id)
|
|
|
|
|
)
|
|
|
|
|
"""原图生成图路径"""
|
|
|
|
|
return _get_path_manager().get_original_generated_path(user_id, flow_id, task_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_perturbed_generated_path(user_id, flow_id, task_id):
|
|
|
|
|
"""加噪图生成图路径: MODEL_PERTURBED_FOLDER/user_id/flow_id/task_id"""
|
|
|
|
|
return TaskService._build_path(
|
|
|
|
|
Config.MODEL_PERTURBED_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(flow_id),
|
|
|
|
|
str(task_id)
|
|
|
|
|
)
|
|
|
|
|
"""加噪图生成图路径"""
|
|
|
|
|
return _get_path_manager().get_perturbed_generated_path(user_id, flow_id, task_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_uploaded_generated_path(user_id, flow_id, task_id):
|
|
|
|
|
"""上传图生成图路径: MODEL_UPLOADED_FOLDER/user_id/flow_id/task_id"""
|
|
|
|
|
return TaskService._build_path(
|
|
|
|
|
Config.MODEL_UPLOADED_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(flow_id),
|
|
|
|
|
str(task_id)
|
|
|
|
|
)
|
|
|
|
|
"""上传图生成图路径"""
|
|
|
|
|
return _get_path_manager().get_uploaded_generated_path(user_id, flow_id, task_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_heatmap_path(user_id, flow_id, task_id):
|
|
|
|
|
"""热力图路径: HEATDIF_SAVE_FOLDER/user_id/flow_id/task_id"""
|
|
|
|
|
return TaskService._build_path(
|
|
|
|
|
Config.HEATDIF_SAVE_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(flow_id),
|
|
|
|
|
str(task_id)
|
|
|
|
|
)
|
|
|
|
|
"""热力图路径"""
|
|
|
|
|
return _get_path_manager().get_heatmap_path(user_id, flow_id, task_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_evaluate_path(user_id, flow_id, task_id):
|
|
|
|
|
"""数值结果路径: NUMBERS_SAVE_FOLDER/user_id/flow_id/task_id"""
|
|
|
|
|
return TaskService._build_path(
|
|
|
|
|
Config.NUMBERS_SAVE_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(flow_id),
|
|
|
|
|
str(task_id)
|
|
|
|
|
)
|
|
|
|
|
"""数值结果路径"""
|
|
|
|
|
return _get_path_manager().get_evaluate_path(user_id, flow_id, task_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_class_data_path(user_id, flow_id):
|
|
|
|
|
"""类别数据路径: CLASS_DATA_FOLDER/user_id/flow_id"""
|
|
|
|
|
return TaskService._build_path(
|
|
|
|
|
Config.CLASS_DATA_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(flow_id)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_model_data_path(user_id, flow_id):
|
|
|
|
|
"""模型数据路径: MODEL_DATA_FOLDER/user_id/flow_id"""
|
|
|
|
|
return TaskService._build_path(
|
|
|
|
|
Config.MODEL_DATA_FOLDER
|
|
|
|
|
)
|
|
|
|
|
"""类别数据路径"""
|
|
|
|
|
return _get_path_manager().get_class_data_path(user_id, flow_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_model_data_path(user_id=None, flow_id=None):
|
|
|
|
|
"""模型数据路径"""
|
|
|
|
|
return _get_path_manager().get_model_data_path()
|
|
|
|
|
|
|
|
|
|
# ==================== 通用辅助功能 ====================
|
|
|
|
|
# 以下方法委托给 Repository 层,保持向后兼容
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def json_error(message, status_code=400):
|
|
|
|
|
@ -132,70 +156,63 @@ class TaskService:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_task_type(code):
|
|
|
|
|
"""根据任务类型代码获取TaskType"""
|
|
|
|
|
return TaskType.query.filter_by(task_type_code=code).first()
|
|
|
|
|
"""根据任务类型代码获取TaskType(委托给 TaskTypeRepository)"""
|
|
|
|
|
return _get_task_type_repo().get_by_code(code)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def require_task_type(code):
|
|
|
|
|
"""确保任务类型存在"""
|
|
|
|
|
task_type = TaskService.get_task_type(code)
|
|
|
|
|
if not task_type:
|
|
|
|
|
raise ValueError(f"Task type '{code}' is not configured")
|
|
|
|
|
return task_type
|
|
|
|
|
"""确保任务类型存在(委托给 TaskTypeRepository)"""
|
|
|
|
|
return _get_task_type_repo().require(code)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_status_by_code(code):
|
|
|
|
|
"""根据状态代码获取TaskStatus"""
|
|
|
|
|
return TaskStatus.query.filter_by(task_status_code=code).first()
|
|
|
|
|
"""根据状态代码获取TaskStatus(委托给 TaskStatusRepository)"""
|
|
|
|
|
return _get_task_status_repo().get_by_code(code)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def ensure_status(code):
|
|
|
|
|
"""确保任务状态存在"""
|
|
|
|
|
status = TaskService.get_status_by_code(code)
|
|
|
|
|
if not status:
|
|
|
|
|
raise ValueError(f"Task status '{code}' is not configured")
|
|
|
|
|
return status
|
|
|
|
|
"""确保任务状态存在(委托给 TaskStatusRepository)"""
|
|
|
|
|
return _get_task_status_repo().require(code)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def generate_flow_id():
|
|
|
|
|
"""生成唯一的flow_id"""
|
|
|
|
|
base = int(datetime.utcnow().timestamp() * 1000)
|
|
|
|
|
while Task.query.filter_by(flow_id=base).first():
|
|
|
|
|
task_repo = _get_task_repo()
|
|
|
|
|
while task_repo.find_one_by(flow_id=base):
|
|
|
|
|
base += 1
|
|
|
|
|
return base
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def ensure_task_owner(task, user_id):
|
|
|
|
|
"""验证任务归属"""
|
|
|
|
|
return bool(task and task.user_id == user_id)
|
|
|
|
|
"""验证任务归属(委托给 TaskRepository)"""
|
|
|
|
|
return _get_task_repo().is_owner(task, user_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_task_type_code(task):
|
|
|
|
|
"""获取任务类型代码"""
|
|
|
|
|
return task.task_type.task_type_code if task and task.task_type else None
|
|
|
|
|
"""获取任务类型代码(委托给 TaskRepository)"""
|
|
|
|
|
return _get_task_repo().get_type_code(task)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def load_task_for_user(task_id, user_id, expected_type=None):
|
|
|
|
|
"""根据任务ID加载用户的任务,可选检查类型"""
|
|
|
|
|
task = Task.query.get(task_id)
|
|
|
|
|
if not TaskService.ensure_task_owner(task, user_id):
|
|
|
|
|
task_repo = _get_task_repo()
|
|
|
|
|
task = task_repo.get_for_user(task_id, user_id)
|
|
|
|
|
if not task:
|
|
|
|
|
return None
|
|
|
|
|
if expected_type and not task_repo.is_type(task, expected_type):
|
|
|
|
|
return None
|
|
|
|
|
if expected_type:
|
|
|
|
|
task_type = TaskService.get_task_type_code(task)
|
|
|
|
|
if task_type != expected_type:
|
|
|
|
|
return None
|
|
|
|
|
return task
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def determine_finetune_source(finetune_task):
|
|
|
|
|
"""判断微调任务来源"""
|
|
|
|
|
perturb_type = TaskService.require_task_type('perturbation')
|
|
|
|
|
sibling_perturbation = Task.query.filter(
|
|
|
|
|
Task.flow_id == finetune_task.flow_id,
|
|
|
|
|
Task.tasks_type_id == perturb_type.task_type_id,
|
|
|
|
|
Task.tasks_id != finetune_task.tasks_id
|
|
|
|
|
).first()
|
|
|
|
|
return 'perturbation' if sibling_perturbation else 'uploaded'
|
|
|
|
|
task_repo = _get_task_repo()
|
|
|
|
|
sibling = task_repo.get_by_flow_and_type(finetune_task.flow_id, 'perturbation')
|
|
|
|
|
# 排除自身
|
|
|
|
|
if sibling and sibling.tasks_id != finetune_task.tasks_id:
|
|
|
|
|
return 'perturbation'
|
|
|
|
|
return 'uploaded'
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def serialize_task(task):
|
|
|
|
|
@ -251,8 +268,8 @@ class TaskService:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_user(user_id):
|
|
|
|
|
"""获取用户"""
|
|
|
|
|
return User.query.get(user_id)
|
|
|
|
|
"""获取用户(委托给 UserRepository)"""
|
|
|
|
|
return _get_user_repo().get_by_id(user_id)
|
|
|
|
|
|
|
|
|
|
# ==================== Redis/RQ 连接管理 ====================
|
|
|
|
|
|
|
|
|
|
@ -281,17 +298,14 @@ class TaskService:
|
|
|
|
|
任务状态信息
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
task = Task.query.get(task_id)
|
|
|
|
|
task_repo = _get_task_repo()
|
|
|
|
|
task = task_repo.get_by_id(task_id)
|
|
|
|
|
if not task:
|
|
|
|
|
return {'status': 'not_found', 'error': 'Task not found'}
|
|
|
|
|
|
|
|
|
|
# 获取任务状态名称
|
|
|
|
|
status = TaskStatus.query.get(task.tasks_status_id)
|
|
|
|
|
status_code = status.task_status_code if status else 'unknown'
|
|
|
|
|
|
|
|
|
|
# 获取任务类型
|
|
|
|
|
task_type = TaskType.query.get(task.tasks_type_id)
|
|
|
|
|
type_code = task_type.task_type_code if task_type else 'unknown'
|
|
|
|
|
# 使用 Repository 获取状态和类型代码
|
|
|
|
|
status_code = task.task_status.task_status_code if task.task_status else 'unknown'
|
|
|
|
|
type_code = task_repo.get_type_code(task) or 'unknown'
|
|
|
|
|
|
|
|
|
|
result = {
|
|
|
|
|
'task_id': task_id,
|
|
|
|
|
@ -345,13 +359,13 @@ class TaskService:
|
|
|
|
|
是否成功取消
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
task = Task.query.get(task_id)
|
|
|
|
|
task_repo = _get_task_repo()
|
|
|
|
|
task = task_repo.get_by_id(task_id)
|
|
|
|
|
if not task:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
# 获取任务类型
|
|
|
|
|
task_type = TaskType.query.get(task.tasks_type_id)
|
|
|
|
|
type_code = task_type.task_type_code if task_type else None
|
|
|
|
|
# 获取任务类型代码
|
|
|
|
|
type_code = task_repo.get_type_code(task)
|
|
|
|
|
|
|
|
|
|
# 尝试从队列中删除任务
|
|
|
|
|
try:
|
|
|
|
|
@ -362,19 +376,10 @@ class TaskService:
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"Could not cancel RQ job: {e}")
|
|
|
|
|
|
|
|
|
|
# 更新数据库状态
|
|
|
|
|
try:
|
|
|
|
|
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
|
|
|
|
|
if failed_status:
|
|
|
|
|
task.tasks_status_id = failed_status.task_status_id
|
|
|
|
|
task.finished_at = datetime.utcnow()
|
|
|
|
|
db.session.commit()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
logger.error(f"Failed to update task status: {e}")
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
# 使用 Repository 更新状态
|
|
|
|
|
if task_repo.update_status(task, 'failed'):
|
|
|
|
|
return task_repo.save()
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error cancelling task: {e}")
|
|
|
|
|
@ -385,7 +390,7 @@ class TaskService:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def start_perturbation_task(task_id):
|
|
|
|
|
"""
|
|
|
|
|
启动加噪任务
|
|
|
|
|
启动加噪任务(委托给 PerturbationTaskHandler)
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
task_id: 任务ID
|
|
|
|
|
@ -393,249 +398,33 @@ class TaskService:
|
|
|
|
|
Returns:
|
|
|
|
|
job_id
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
# 获取任务
|
|
|
|
|
task = Task.query.get(task_id)
|
|
|
|
|
if not task:
|
|
|
|
|
logger.error(f"Task {task_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 获取Perturbation任务详情
|
|
|
|
|
perturbation = Perturbation.query.get(task_id)
|
|
|
|
|
if not perturbation:
|
|
|
|
|
logger.error(f"Perturbation task {task_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 更新任务状态为 waiting
|
|
|
|
|
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
|
|
|
|
|
if waiting_status:
|
|
|
|
|
task.tasks_status_id = waiting_status.task_status_id
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
# 获取用户ID
|
|
|
|
|
user_id = task.user_id
|
|
|
|
|
|
|
|
|
|
# 路径配置
|
|
|
|
|
input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
|
|
|
|
|
output_dir = TaskService.get_perturbed_images_path(user_id, task.flow_id)
|
|
|
|
|
class_dir = TaskService.get_class_data_path(user_id, task.flow_id)
|
|
|
|
|
|
|
|
|
|
# 获取算法配置
|
|
|
|
|
pert_config = PerturbationConfig.query.get(perturbation.perturbation_configs_id)
|
|
|
|
|
if not pert_config:
|
|
|
|
|
logger.error(f"Perturbation config not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
algorithm_code = pert_config.perturbation_code
|
|
|
|
|
|
|
|
|
|
# 加入RQ队列
|
|
|
|
|
from app.workers.perturbation_worker import run_perturbation_task
|
|
|
|
|
|
|
|
|
|
queue = TaskService._get_queue()
|
|
|
|
|
job_id = f"pert_{task_id}"
|
|
|
|
|
|
|
|
|
|
job = queue.enqueue(
|
|
|
|
|
run_perturbation_task,
|
|
|
|
|
task_id=task_id,
|
|
|
|
|
input_dir=input_dir,
|
|
|
|
|
output_dir=output_dir,
|
|
|
|
|
class_dir=class_dir,
|
|
|
|
|
algorithm_code=algorithm_code,
|
|
|
|
|
epsilon=perturbation.perturbation_intensity,
|
|
|
|
|
job_id=job_id,
|
|
|
|
|
job_timeout='4h'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Perturbation task {task_id} enqueued with job_id {job_id}")
|
|
|
|
|
return job_id
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error starting perturbation task: {e}")
|
|
|
|
|
return None
|
|
|
|
|
return _get_task_handler('perturbation').start(task_id)
|
|
|
|
|
|
|
|
|
|
# ==================== Finetune 任务 ====================
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def start_finetune_task(task_id):
|
|
|
|
|
"""
|
|
|
|
|
启动微调任务(支持两种类型)
|
|
|
|
|
|
|
|
|
|
类型1:基于加噪结果的微调
|
|
|
|
|
- 有相同flow_id的Perturbation任务
|
|
|
|
|
- 输入:原图 + 加噪图
|
|
|
|
|
- 输出到:original_generated 和 perturbed_generated
|
|
|
|
|
启动微调任务(委托给 FinetuneTaskHandler)
|
|
|
|
|
|
|
|
|
|
类型2:用户上传图片的微调
|
|
|
|
|
- 找不到相同flow_id的其他任务
|
|
|
|
|
- 输入:仅原图
|
|
|
|
|
- 输出到:uploaded_generated
|
|
|
|
|
支持两种类型:
|
|
|
|
|
- 基于加噪结果的微调
|
|
|
|
|
- 用户上传图片的微调
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
task_id: 任务ID
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
job_id
|
|
|
|
|
job_id 或逗号分隔的多个 job_id
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
# 获取任务
|
|
|
|
|
task = Task.query.get(task_id)
|
|
|
|
|
if not task:
|
|
|
|
|
logger.error(f"Task {task_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 获取Finetune任务详情
|
|
|
|
|
finetune = Finetune.query.get(task_id)
|
|
|
|
|
if not finetune:
|
|
|
|
|
logger.error(f"Finetune task {task_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 更新任务状态为 waiting
|
|
|
|
|
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
|
|
|
|
|
if waiting_status:
|
|
|
|
|
task.tasks_status_id = waiting_status.task_status_id
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
# 获取用户ID
|
|
|
|
|
user_id = task.user_id
|
|
|
|
|
|
|
|
|
|
# 获取微调配置
|
|
|
|
|
ft_config = FinetuneConfig.query.get(finetune.finetune_configs_id)
|
|
|
|
|
if not ft_config:
|
|
|
|
|
logger.error(f"Finetune config not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 检测微调类型:查找相同flow_id的Perturbation任务
|
|
|
|
|
perturb_type = TaskService.require_task_type('perturbation')
|
|
|
|
|
sibling_perturbation = Task.query.filter(
|
|
|
|
|
Task.flow_id == task.flow_id,
|
|
|
|
|
Task.tasks_type_id == perturb_type.task_type_id,
|
|
|
|
|
Task.tasks_id != task_id
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
has_perturbation = sibling_perturbation is not None
|
|
|
|
|
|
|
|
|
|
# 路径配置
|
|
|
|
|
input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
|
|
|
|
|
class_dir = TaskService.get_class_data_path(user_id, task.flow_id)
|
|
|
|
|
model_data_dir = TaskService.get_model_data_path(user_id, task.flow_id)
|
|
|
|
|
|
|
|
|
|
if has_perturbation:
|
|
|
|
|
# 类型1:基于加噪结果的微调
|
|
|
|
|
logger.info(f"Finetune task {task_id}: type=perturbation-based")
|
|
|
|
|
|
|
|
|
|
perturbed_input_dir = TaskService.get_perturbed_images_path(user_id, task.flow_id)
|
|
|
|
|
original_input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
|
|
|
|
|
perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, task.flow_id, task_id)
|
|
|
|
|
original_output_dir = TaskService.get_original_generated_path(user_id, task.flow_id, task_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取坐标保存路径(3D可视化)
|
|
|
|
|
original_coords_save_path = TaskService._build_path(
|
|
|
|
|
Config.COORDS_SAVE_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(task.flow_id),
|
|
|
|
|
str(task_id),
|
|
|
|
|
'original_coords.csv'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 获取加噪坐标保存路径(3D可视化)
|
|
|
|
|
perturbed_coords_save_path = TaskService._build_path(
|
|
|
|
|
Config.COORDS_SAVE_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(task.flow_id),
|
|
|
|
|
str(task_id),
|
|
|
|
|
'perturbed_coords.csv'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 加入RQ队列
|
|
|
|
|
from app.workers.finetune_worker import run_finetune_task
|
|
|
|
|
|
|
|
|
|
queue = TaskService._get_queue()
|
|
|
|
|
job_id_original = f"ft_{task_id}_original"
|
|
|
|
|
job_id_perturbed = f"ft_{task_id}_perturbed"
|
|
|
|
|
|
|
|
|
|
job_original = queue.enqueue(
|
|
|
|
|
run_finetune_task,
|
|
|
|
|
task_id=task_id,
|
|
|
|
|
finetune_method=ft_config.finetune_code,
|
|
|
|
|
train_images_dir=original_input_dir,
|
|
|
|
|
output_model_dir=model_data_dir,
|
|
|
|
|
class_dir=class_dir,
|
|
|
|
|
coords_save_path=original_coords_save_path,
|
|
|
|
|
validation_output_dir=original_output_dir,
|
|
|
|
|
finetune_type="original",
|
|
|
|
|
custom_params=None,
|
|
|
|
|
job_id=job_id_original,
|
|
|
|
|
job_timeout='8h'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
job_perturbed = queue.enqueue(
|
|
|
|
|
run_finetune_task,
|
|
|
|
|
task_id=task_id,
|
|
|
|
|
finetune_method=ft_config.finetune_code,
|
|
|
|
|
train_images_dir=perturbed_input_dir,
|
|
|
|
|
output_model_dir=model_data_dir,
|
|
|
|
|
class_dir=class_dir,
|
|
|
|
|
coords_save_path=perturbed_coords_save_path,
|
|
|
|
|
validation_output_dir=perturbed_output_dir,
|
|
|
|
|
finetune_type="perturbed",
|
|
|
|
|
custom_params=None,
|
|
|
|
|
job_id=job_id_perturbed,
|
|
|
|
|
job_timeout='8h'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Finetune task {task_id} enqueued with job_ids {job_id_original}, {job_id_perturbed}")
|
|
|
|
|
return f"{job_id_original},{job_id_perturbed}"
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# 类型2:用户上传图片的微调
|
|
|
|
|
logger.info(f"Finetune task {task_id}: type=uploaded")
|
|
|
|
|
|
|
|
|
|
uploaded_output_dir = TaskService.get_uploaded_generated_path(user_id, task.flow_id, task_id)
|
|
|
|
|
|
|
|
|
|
# 获取坐标保存路径
|
|
|
|
|
coords_save_path = TaskService._build_path(
|
|
|
|
|
Config.COORDS_SAVE_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(task.flow_id),
|
|
|
|
|
str(task_id),
|
|
|
|
|
'coords.csv'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 加入RQ队列
|
|
|
|
|
from app.workers.finetune_worker import run_finetune_task
|
|
|
|
|
|
|
|
|
|
queue = TaskService._get_queue()
|
|
|
|
|
job_id = f"ft_{task_id}"
|
|
|
|
|
|
|
|
|
|
job = queue.enqueue(
|
|
|
|
|
run_finetune_task,
|
|
|
|
|
task_id=task_id,
|
|
|
|
|
finetune_method=ft_config.finetune_code,
|
|
|
|
|
train_images_dir=input_dir,
|
|
|
|
|
output_model_dir=model_data_dir,
|
|
|
|
|
class_dir=class_dir,
|
|
|
|
|
coords_save_path=coords_save_path,
|
|
|
|
|
validation_output_dir=uploaded_output_dir,
|
|
|
|
|
finetune_type="uploaded",
|
|
|
|
|
custom_params=None,
|
|
|
|
|
job_id=job_id,
|
|
|
|
|
job_timeout='8h'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}")
|
|
|
|
|
return job_id
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error starting finetune task: {e}")
|
|
|
|
|
return None
|
|
|
|
|
return _get_task_handler('finetune').start(task_id)
|
|
|
|
|
|
|
|
|
|
# ==================== Heatmap 任务 ====================
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def start_heatmap_task(task_id):
|
|
|
|
|
"""
|
|
|
|
|
启动热力图任务
|
|
|
|
|
启动热力图任务(委托给 HeatmapTaskHandler)
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
task_id: 任务ID
|
|
|
|
|
@ -643,93 +432,14 @@ class TaskService:
|
|
|
|
|
Returns:
|
|
|
|
|
job_id
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
# 获取任务
|
|
|
|
|
task = Task.query.get(task_id)
|
|
|
|
|
if not task:
|
|
|
|
|
logger.error(f"Task {task_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 获取Heatmap任务详情
|
|
|
|
|
heatmap = Heatmap.query.get(task_id)
|
|
|
|
|
if not heatmap:
|
|
|
|
|
logger.error(f"Heatmap task {task_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 更新任务状态为 waiting
|
|
|
|
|
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
|
|
|
|
|
if waiting_status:
|
|
|
|
|
task.tasks_status_id = waiting_status.task_status_id
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
# 从heatmap对象获取扰动图片ID
|
|
|
|
|
perturbed_image_id = heatmap.images_id
|
|
|
|
|
if not perturbed_image_id:
|
|
|
|
|
logger.error(f"Heatmap task {task_id} has no associated perturbed image")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 获取扰动图片信息
|
|
|
|
|
perturbed_image = Image.query.get(perturbed_image_id)
|
|
|
|
|
if not perturbed_image:
|
|
|
|
|
logger.error(f"Perturbed image {perturbed_image_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
user_id = task.user_id
|
|
|
|
|
|
|
|
|
|
# 获取原图(通过father_id关系)
|
|
|
|
|
if not perturbed_image.father_id:
|
|
|
|
|
logger.error(f"Perturbed image {perturbed_image_id} has no father_id")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
original_image = Image.query.get(perturbed_image.father_id)
|
|
|
|
|
if not original_image:
|
|
|
|
|
logger.error(f"Original image not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 构建图片路径(使用 stored_filename)
|
|
|
|
|
original_image_path = os.path.join(
|
|
|
|
|
TaskService.get_original_images_path(user_id, task.flow_id),
|
|
|
|
|
original_image.stored_filename
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
perturbed_image_path = os.path.join(
|
|
|
|
|
TaskService.get_perturbed_images_path(user_id, task.flow_id),
|
|
|
|
|
perturbed_image.stored_filename
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 输出目录
|
|
|
|
|
output_dir = TaskService.get_heatmap_path(user_id, task.flow_id, task_id)
|
|
|
|
|
|
|
|
|
|
# 加入RQ队列
|
|
|
|
|
from app.workers.heatmap_worker import run_heatmap_task
|
|
|
|
|
|
|
|
|
|
queue = TaskService._get_queue()
|
|
|
|
|
job_id = f"hm_{task_id}"
|
|
|
|
|
|
|
|
|
|
job = queue.enqueue(
|
|
|
|
|
run_heatmap_task,
|
|
|
|
|
task_id=task_id,
|
|
|
|
|
original_image_path=original_image_path,
|
|
|
|
|
perturbed_image_path=perturbed_image_path,
|
|
|
|
|
output_dir=output_dir,
|
|
|
|
|
perturbed_image_id=perturbed_image_id,
|
|
|
|
|
job_id=job_id,
|
|
|
|
|
job_timeout='2h'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Heatmap task {task_id} enqueued with job_id {job_id}")
|
|
|
|
|
return job_id
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error starting heatmap task: {e}")
|
|
|
|
|
return None
|
|
|
|
|
return _get_task_handler('heatmap').start(task_id)
|
|
|
|
|
|
|
|
|
|
# ==================== Evaluate 任务 ====================
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def start_evaluate_task(task_id):
|
|
|
|
|
"""
|
|
|
|
|
启动评估任务
|
|
|
|
|
启动评估任务(委托给 EvaluateTaskHandler)
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
task_id: 任务ID
|
|
|
|
|
@ -737,64 +447,4 @@ class TaskService:
|
|
|
|
|
Returns:
|
|
|
|
|
job_id
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
# 获取任务
|
|
|
|
|
task = Task.query.get(task_id)
|
|
|
|
|
if not task:
|
|
|
|
|
logger.error(f"Task {task_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 获取Evaluate任务详情
|
|
|
|
|
evaluate = Evaluate.query.get(task_id)
|
|
|
|
|
if not evaluate:
|
|
|
|
|
logger.error(f"Evaluate task {task_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 更新任务状态为 waiting
|
|
|
|
|
waiting_status = TaskStatus.query.filter_by(task_status_code='waiting').first()
|
|
|
|
|
if waiting_status:
|
|
|
|
|
task.tasks_status_id = waiting_status.task_status_id
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
finetune = Finetune.query.get(evaluate.finetune_task_id)
|
|
|
|
|
if not finetune:
|
|
|
|
|
logger.error(f"Finetune task {evaluate.finetune_task_id} not found for evaluation {task_id}")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
finetune_task = finetune.task
|
|
|
|
|
if not finetune_task:
|
|
|
|
|
logger.error(f"Finetune task {evaluate.finetune_task_id} missing Task relation")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
user_id = finetune_task.user_id
|
|
|
|
|
|
|
|
|
|
# 路径配置
|
|
|
|
|
clean_ref_dir = TaskService.get_original_images_path(user_id, finetune_task.flow_id)
|
|
|
|
|
clean_output_dir = TaskService.get_original_generated_path(user_id, finetune_task.flow_id, finetune_task.tasks_id)
|
|
|
|
|
perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, finetune_task.flow_id, finetune_task.tasks_id)
|
|
|
|
|
output_dir = TaskService.get_evaluate_path(user_id, finetune_task.flow_id, task_id)
|
|
|
|
|
|
|
|
|
|
# 加入RQ队列
|
|
|
|
|
from app.workers.evaluate_worker import run_evaluate_task
|
|
|
|
|
|
|
|
|
|
queue = TaskService._get_queue()
|
|
|
|
|
job_id = f"eval_{task_id}"
|
|
|
|
|
|
|
|
|
|
job = queue.enqueue(
|
|
|
|
|
run_evaluate_task,
|
|
|
|
|
task_id=task_id,
|
|
|
|
|
clean_ref_dir=clean_ref_dir,
|
|
|
|
|
clean_output_dir=clean_output_dir,
|
|
|
|
|
perturbed_output_dir=perturbed_output_dir,
|
|
|
|
|
output_dir=output_dir,
|
|
|
|
|
image_size=512,
|
|
|
|
|
job_id=job_id,
|
|
|
|
|
job_timeout='2h'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Evaluate task {task_id} enqueued with job_id {job_id}")
|
|
|
|
|
return job_id
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error starting evaluate task: {e}")
|
|
|
|
|
return None
|
|
|
|
|
return _get_task_handler('evaluate').start(task_id)
|
|
|
|
|
|