fix: 修复任务模块字段与参数匹配的问题

pull/10/head
梁浩 5 months ago
parent 197cbc6e8f
commit 13ca7c461e

@ -311,7 +311,7 @@ def start_heatmap_task(task_id, current_user_id):
if not task.heatmap:
return TaskService.json_error('热力图任务未配置对应图片', 400)
job_id = TaskService.start_heatmap_task(task_id, task.heatmap.images_id)
job_id = TaskService.start_heatmap_task(task_id)
if not job_id:
return TaskService.json_error('任务启动失败', 500)
return jsonify({'message': '任务已启动', 'job_id': job_id}), 200

@ -354,11 +354,16 @@ class TaskService:
logger.warning(f"Could not cancel RQ job: {e}")
# 更新数据库状态
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
if failed_status:
task.tasks_status_id = failed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
try:
failed_status = TaskStatus.query.filter_by(task_status_code='failed').first()
if failed_status:
task.tasks_status_id = failed_status.task_status_id
task.finished_at = datetime.utcnow()
db.session.commit()
except Exception as e:
db.session.rollback()
logger.error(f"Failed to update task status: {e}")
return False
return True
@ -406,7 +411,7 @@ class TaskService:
logger.error(f"Perturbation config not found")
return None
algorithm_code = pert_config.perturbation_algorithm_code
algorithm_code = pert_config.perturbation_code
# 加入RQ队列
from app.workers.perturbation_worker import run_perturbation_task
@ -421,7 +426,7 @@ class TaskService:
output_dir=output_dir,
class_dir=class_dir,
algorithm_code=algorithm_code,
epsilon=pert_config.epsilon,
epsilon=pert_config.perturbation_intensity,
job_id=job_id,
job_timeout='4h'
)
@ -479,13 +484,14 @@ class TaskService:
return None
# 检测微调类型查找相同flow_id的Perturbation任务
perturbation_tasks = Task.query.filter(
perturb_type = TaskService.require_task_type('perturbation')
sibling_perturbation = Task.query.filter(
Task.flow_id == task.flow_id,
Task.tasks_type_id == 1, # perturbation类型
Task.tasks_type_id == perturb_type.task_type_id,
Task.tasks_id != task_id
).all()
).first()
has_perturbation = len(perturbation_tasks) > 0
has_perturbation = sibling_perturbation is not None
# 路径配置
input_dir = TaskService.get_original_images_path(user_id, task.flow_id)
@ -513,38 +519,41 @@ class TaskService:
from app.workers.finetune_worker import run_finetune_task
queue = TaskService._get_queue()
job_id = f"ft_{task_id}"
job_id_original = f"ft_{task_id}_original"
job_id_perturbed = f"ft_{task_id}_perturbed"
job_original = queue.enqueue(
run_finetune_task,
task_id=task_id,
finetune_method=ft_config.finetune_method,
tranin_images_dir=original_input_dir,
finetune_method=ft_config.finetune_code,
train_images_dir=original_input_dir,
output_model_dir=original_output_dir,
class_dir=class_dir,
coords_save_path=coords_save_path,
validation_images_dir=original_output_dir,
validation_output_dir=original_output_dir,
is_perturbed=False,
custom_params=None,
job_id=job_id,
job_id=job_id_original,
job_timeout='8h'
)
job_perturbed = queue.enqueue(
run_finetune_task,
task_id=task_id,
finetune_method=ft_config.finetune_method,
tranin_images_dir=perturbed_input_dir,
finetune_method=ft_config.finetune_code,
train_images_dir=perturbed_input_dir,
output_model_dir=perturbed_output_dir,
class_dir=class_dir,
coords_save_path=coords_save_path,
validation_images_dir=perturbed_output_dir,
validation_output_dir=perturbed_output_dir,
is_perturbed=True,
custom_params=None,
job_id=job_id,
job_id=job_id_perturbed,
job_timeout='8h'
)
logger.info(f"Finetune task {task_id} enqueued with job_ids {job_id_original}, {job_id_perturbed}")
else:
# 类型2用户上传图片的微调
logger.info(f"Finetune task {task_id}: type=uploaded")
@ -569,19 +578,19 @@ class TaskService:
job = queue.enqueue(
run_finetune_task,
task_id=task_id,
finetune_method=ft_config.finetune_method,
tranin_images_dir=input_dir,
finetune_method=ft_config.finetune_code,
train_images_dir=input_dir,
output_model_dir=uploaded_output_dir,
class_dir=class_dir,
coords_save_path=coords_save_path,
validation_images_dir=uploaded_output_dir,
validation_output_dir=uploaded_output_dir,
is_perturbed=False,
custom_params=None,
job_id=job_id,
job_timeout='8h'
)
logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}")
logger.info(f"Finetune task {task_id} enqueued with job_id {job_id}")
return job_id
except Exception as e:
@ -591,13 +600,12 @@ class TaskService:
# ==================== Heatmap 任务 ====================
@staticmethod
def start_heatmap_task(task_id, perturbed_image_id):
def start_heatmap_task(task_id):
"""
启动热力图任务
Args:
task_id: 任务ID
perturbed_image_id: 扰动图片ID
Returns:
job_id
@ -615,13 +623,19 @@ class TaskService:
logger.error(f"Heatmap task {task_id} not found")
return None
# 从heatmap对象获取扰动图片ID
perturbed_image_id = heatmap.images_id
if not perturbed_image_id:
logger.error(f"Heatmap task {task_id} has no associated perturbed image")
return None
# 获取扰动图片信息
perturbed_image = Image.query.get(perturbed_image_id)
if not perturbed_image:
logger.error(f"Perturbed image {perturbed_image_id} not found")
return None
user_id = perturbed_image.user_id
user_id = task.user_id
# 获取原图通过father_id关系
if not perturbed_image.father_id:
@ -633,19 +647,19 @@ class TaskService:
logger.error(f"Original image not found")
return None
# 构建图片路径
# 构建图片路径(使用 stored_filename
original_image_path = TaskService._build_path(
Config.ORIGINAL_IMAGES_FOLDER,
str(user_id),
str(task.flow_id),
original_image.image_name
original_image.stored_filename
)
perturbed_image_path = TaskService._build_path(
Config.PERTURBED_IMAGES_FOLDER,
str(user_id),
str(task.flow_id),
perturbed_image.image_name
perturbed_image.stored_filename
)
# 输出目录

@ -18,7 +18,7 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images_dir,
def run_finetune_task(task_id, finetune_method, train_images_dir,
output_model_dir, class_dir, coords_save_path, validation_output_dir,
is_perturbed=False, custom_params=None):
"""
@ -26,7 +26,6 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images
Args:
task_id: 任务ID
finetune_config_id: 微调配置ID
finetune_method: 微调方法 (dreambooth, lora, textual_inversion)
train_images_dir: 训练图片目录
output_model_dir: 模型输出目录
@ -54,13 +53,9 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images
# 获取微调任务详情
finetune = Finetune.query.filter_by(
tasks_id=task_id,
finetune_configs_id=finetune_config_id
tasks_id=task_id
).first()
if not finetune:
raise ValueError(f"Finetune task ({task_id}, {finetune_config_id}) not found")
# 更新任务状态为处理中
processing_status = TaskStatus.query.filter_by(task_status_code='processing').first()
if processing_status:
@ -69,7 +64,6 @@ def run_finetune_task(task_id, finetune_config_id, finetune_method, train_images
task.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Starting finetune task {task_id} (config: {finetune_config_id})")
logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}")
# 从数据库获取数据集类型的提示词

Loading…
Cancel
Save