|
|
|
|
@ -354,11 +354,16 @@ class TaskService:
|
|
|
|
|
logger.warning(f"Could not cancel RQ job: {e}")
|
|
|
|
|
|
|
|
|
|
# 更新数据库状态
|
|
|
|
|
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
|
|
|
|
|
if failed_status:
|
|
|
|
|
task.tasks_status_id = failed_status.task_status_id
|
|
|
|
|
task.finished_at = datetime.utcnow()
|
|
|
|
|
db.session.commit()
|
|
|
|
|
try:
|
|
|
|
|
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
|
|
|
|
|
if failed_status:
|
|
|
|
|
task.tasks_status_id = failed_status.task_status_id
|
|
|
|
|
task.finished_at = datetime.utcnow()
|
|
|
|
|
db.session.commit()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
logger.error(f"Failed to update task status: {e}")
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
@ -406,7 +411,7 @@ class TaskService:
|
|
|
|
|
logger.error(f"Perturbation config not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
algorithm_code = pert_config.perturbation_algorithm_code
|
|
|
|
|
algorithm_code = pert_config.perturbation_code
|
|
|
|
|
|
|
|
|
|
# 加入RQ队列
|
|
|
|
|
from app.workers.perturbation_worker import run_perturbation_task
|
|
|
|
|
@ -421,7 +426,7 @@ class TaskService:
|
|
|
|
|
output_dir=output_dir,
|
|
|
|
|
class_dir=class_dir,
|
|
|
|
|
algorithm_code=algorithm_code,
|
|
|
|
|
epsilon=pert_config.epsilon,
|
|
|
|
|
epsilon=pert_config.perturbation_intensity,
|
|
|
|
|
job_id=job_id,
|
|
|
|
|
job_timeout='4h'
|
|
|
|
|
)
|
|
|
|
|
@ -479,13 +484,14 @@ class TaskService:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 检测微调类型:查找相同flow_id的Perturbation任务
|
|
|
|
|
perturbation_tasks = Task.query.filter(
|
|
|
|
|
perturb_type = TaskService.require_task_type('perturbation')
|
|
|
|
|
sibling_perturbation = Task.query.filter(
|
|
|
|
|
Task.flow_id == task.flow_id,
|
|
|
|
|
Task.tasks_type_id == 1, # perturbation类型
|
|
|
|
|
Task.tasks_type_id == perturb_type.task_type_id,
|
|
|
|
|
Task.tasks_id != task_id
|
|
|
|
|
).all()
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
has_perturbation = len(perturbation_tasks) > 0
|
|
|
|
|
has_perturbation = sibling_perturbation is not None
|
|
|
|
|
|
|
|
|
|
# 路径配置
|
|
|
|
|
input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
|
|
|
|
|
@ -513,38 +519,41 @@ class TaskService:
|
|
|
|
|
from app.workers.finetune_worker import run_finetune_task
|
|
|
|
|
|
|
|
|
|
queue = TaskService._get_queue()
|
|
|
|
|
job_id = f"ft_{task_id}"
|
|
|
|
|
job_id_original = f"ft_{task_id}_original"
|
|
|
|
|
job_id_perturbed = f"ft_{task_id}_perturbed"
|
|
|
|
|
|
|
|
|
|
job_original = queue.enqueue(
|
|
|
|
|
run_finetune_task,
|
|
|
|
|
task_id=task_id,
|
|
|
|
|
finetune_method=ft_config.finetune_method,
|
|
|
|
|
tranin_images_dir=original_input_dir,
|
|
|
|
|
finetune_method=ft_config.finetune_code,
|
|
|
|
|
train_images_dir=original_input_dir,
|
|
|
|
|
output_model_dir=original_output_dir,
|
|
|
|
|
class_dir=class_dir,
|
|
|
|
|
coords_save_path=coords_save_path,
|
|
|
|
|
validation_images_dir=original_output_dir,
|
|
|
|
|
validation_output_dir=original_output_dir,
|
|
|
|
|
is_perturbed=False,
|
|
|
|
|
custom_params=None,
|
|
|
|
|
job_id=job_id,
|
|
|
|
|
job_id=job_id_original,
|
|
|
|
|
job_timeout='8h'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
job_perturbed = queue.enqueue(
|
|
|
|
|
run_finetune_task,
|
|
|
|
|
task_id=task_id,
|
|
|
|
|
finetune_method=ft_config.finetune_method,
|
|
|
|
|
tranin_images_dir=perturbed_input_dir,
|
|
|
|
|
finetune_method=ft_config.finetune_code,
|
|
|
|
|
train_images_dir=perturbed_input_dir,
|
|
|
|
|
output_model_dir=perturbed_output_dir,
|
|
|
|
|
class_dir=class_dir,
|
|
|
|
|
coords_save_path=coords_save_path,
|
|
|
|
|
validation_images_dir=perturbed_output_dir,
|
|
|
|
|
validation_output_dir=perturbed_output_dir,
|
|
|
|
|
is_perturbed=True,
|
|
|
|
|
custom_params=None,
|
|
|
|
|
job_id=job_id,
|
|
|
|
|
job_id=job_id_perturbed,
|
|
|
|
|
job_timeout='8h'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Finetune task {task_id} enqueued with job_ids {job_id_original}, {job_id_perturbed}")
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# 类型2:用户上传图片的微调
|
|
|
|
|
logger.info(f"Finetune task {task_id}: type=uploaded")
|
|
|
|
|
@ -569,19 +578,19 @@ class TaskService:
|
|
|
|
|
job = queue.enqueue(
|
|
|
|
|
run_finetune_task,
|
|
|
|
|
task_id=task_id,
|
|
|
|
|
finetune_method=ft_config.finetune_method,
|
|
|
|
|
tranin_images_dir=input_dir,
|
|
|
|
|
finetune_method=ft_config.finetune_code,
|
|
|
|
|
train_images_dir=input_dir,
|
|
|
|
|
output_model_dir=uploaded_output_dir,
|
|
|
|
|
class_dir=class_dir,
|
|
|
|
|
coords_save_path=coords_save_path,
|
|
|
|
|
validation_images_dir=uploaded_output_dir,
|
|
|
|
|
validation_output_dir=uploaded_output_dir,
|
|
|
|
|
is_perturbed=False,
|
|
|
|
|
custom_params=None,
|
|
|
|
|
job_id=job_id,
|
|
|
|
|
job_timeout='8h'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}")
|
|
|
|
|
|
|
|
|
|
logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}")
|
|
|
|
|
return job_id
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
@ -591,13 +600,12 @@ class TaskService:
|
|
|
|
|
# ==================== Heatmap 任务 ====================
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def start_heatmap_task(task_id, perturbed_image_id):
|
|
|
|
|
def start_heatmap_task(task_id):
|
|
|
|
|
"""
|
|
|
|
|
启动热力图任务
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
task_id: 任务ID
|
|
|
|
|
perturbed_image_id: 扰动图片ID
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
job_id
|
|
|
|
|
@ -615,13 +623,19 @@ class TaskService:
|
|
|
|
|
logger.error(f"Heatmap task {task_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 从heatmap对象获取扰动图片ID
|
|
|
|
|
perturbed_image_id = heatmap.images_id
|
|
|
|
|
if not perturbed_image_id:
|
|
|
|
|
logger.error(f"Heatmap task {task_id} has no associated perturbed image")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 获取扰动图片信息
|
|
|
|
|
perturbed_image = Image.query.get(perturbed_image_id)
|
|
|
|
|
if not perturbed_image:
|
|
|
|
|
logger.error(f"Perturbed image {perturbed_image_id} not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
user_id = perturbed_image.user_id
|
|
|
|
|
user_id = task.user_id
|
|
|
|
|
|
|
|
|
|
# 获取原图(通过father_id关系)
|
|
|
|
|
if not perturbed_image.father_id:
|
|
|
|
|
@ -633,19 +647,19 @@ class TaskService:
|
|
|
|
|
logger.error(f"Original image not found")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 构建图片路径
|
|
|
|
|
# 构建图片路径(使用 stored_filename)
|
|
|
|
|
original_image_path = TaskService._build_path(
|
|
|
|
|
Config.ORIGINAL_IMAGES_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(task.flow_id),
|
|
|
|
|
original_image.image_name
|
|
|
|
|
original_image.stored_filename
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
perturbed_image_path = TaskService._build_path(
|
|
|
|
|
Config.PERTURBED_IMAGES_FOLDER,
|
|
|
|
|
str(user_id),
|
|
|
|
|
str(task.flow_id),
|
|
|
|
|
perturbed_image.image_name
|
|
|
|
|
perturbed_image.stored_filename
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 输出目录
|
|
|
|
|
|