fix: 修复上传图生成图图片类型错误

pull/27/head
梁浩 4 months ago
parent 7aec5df48b
commit 737a95e09b

@ -562,7 +562,7 @@ class TaskService:
class_dir=class_dir,
coords_save_path=original_coords_save_path,
validation_output_dir=original_output_dir,
is_perturbed=False,
finetune_type="original",
custom_params=None,
job_id=job_id_original,
job_timeout='8h'
@ -577,7 +577,7 @@ class TaskService:
class_dir=class_dir,
coords_save_path=perturbed_coords_save_path,
validation_output_dir=perturbed_output_dir,
is_perturbed=True,
finetune_type="perturbed",
custom_params=None,
job_id=job_id_perturbed,
job_timeout='8h'
@ -616,7 +616,7 @@ class TaskService:
class_dir=class_dir,
coords_save_path=coords_save_path,
validation_output_dir=uploaded_output_dir,
is_perturbed=False,
finetune_type="uploaded",
custom_params=None,
job_id=job_id,
job_timeout='8h'

@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
def run_finetune_task(task_id, finetune_method, train_images_dir,
output_model_dir, class_dir, coords_save_path, validation_output_dir,
is_perturbed=False, custom_params=None):
finetune_type, custom_params=None):
"""
执行微调任务仅使用真实算法
@ -32,7 +32,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
class_dir: 类别图片目录
coords_save_path: 坐标保存路径
validation_output_dir: 验证图片输出目录
is_perturbed: 是否使用扰动图片训练
finetune_type: 微调类型 (original, perturbed, uploaded)
custom_params: 自定义参数
Returns:
@ -64,7 +64,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
task.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}")
logger.info(f"Method: {finetune_method}, finetune_type: {finetune_type}")
# 获取 DataType 配置
data_type = DataType.query.get(finetune.data_type_id) if finetune.data_type_id else None
@ -138,8 +138,8 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
if not is_perturbed:
# 原图微调:清除旧日志,创建新日志
if finetune_type == "original" or finetune_type == "uploaded":
# 原图/上传微调:清除旧日志,创建新日志
old_logs = glob.glob(os.path.join(log_dir, f'finetune_{finetune_method}_task_{task_id}_*.log'))
for old_log in old_logs:
try:
@ -152,7 +152,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
log_dir,
f'finetune_{finetune_method}_task_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
else:
elif finetune_type == "perturbed":
# 扰动图微调:尝试复用现有日志
old_logs = glob.glob(os.path.join(log_dir, f'finetune_{finetune_method}_task_{task_id}_*.log'))
if old_logs:
@ -170,11 +170,11 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
result = _run_real_finetune(
finetune_method, task_id, train_images_dir, output_model_dir,
class_dir, coords_save_path, validation_output_dir,
instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params, log_file
instance_prompt, class_prompt, validation_prompt, finetune_type, custom_params, log_file
)
# 保存生成的验证图片到数据库
_save_generated_images(task_id, validation_output_dir, is_perturbed)
_save_generated_images(task_id, validation_output_dir, finetune_type)
# 更新任务状态为完成
completed_status = TaskStatus.query.filter_by(task_status_code='completed').first()
@ -205,7 +205,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_dir,
class_dir, coords_save_path, validation_output_dir,
instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params, log_file):
instance_prompt, class_prompt, validation_prompt, finetune_type, custom_params, log_file):
"""
运行真实微调算法参考sh脚本配置
@ -220,7 +220,7 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
instance_prompt: 实例提示词
class_prompt: 类别提示词
validation_prompt: 验证提示词
is_perturbed: 是否使用扰动图片
finetune_type: 微调类型 (original, perturbed, uploaded)
custom_params: 自定义参数
log_file: 日志文件路径
"""
@ -309,7 +309,7 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
# 执行命令
# 使用追加模式 'a',以便在同一日志文件中记录原图和扰动图的微调过程
with open(log_file, 'a') as f:
if is_perturbed:
if finetune_type == "perturbed":
f.write(f"\n\n{'='*30}\nStarting Perturbed Finetune Task\n{'='*30}\n\n")
process = subprocess.Popen(
@ -386,7 +386,7 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
}
def _save_generated_images(task_id, output_dir, is_perturbed):
def _save_generated_images(task_id, output_dir, finetune_type):
"""
保存微调生成的验证图片到数据库适配新数据库结构
@ -398,7 +398,7 @@ def _save_generated_images(task_id, output_dir, is_perturbed):
Args:
task_id: 任务ID
output_dir: 生成图片输出目录
is_perturbed: 是否为扰动图片训练生成
finetune_type: 微调类型 (original, perturbed, uploaded)
"""
from app import db
from app.database import Task, Image, ImageType
@ -410,12 +410,15 @@ def _save_generated_images(task_id, output_dir, is_perturbed):
raise ValueError(f"Task {task_id} not found")
# 获取图片类型
if is_perturbed:
if finetune_type == "perturbed":
generated_type = ImageType.query.filter_by(image_code='perturbed_generate').first()
input_type = ImageType.query.filter_by(image_code='perturbed').first()
else:
elif finetune_type == "original":
generated_type = ImageType.query.filter_by(image_code='original_generate').first()
input_type = ImageType.query.filter_by(image_code='original').first()
elif finetune_type == "uploaded":
generated_type = ImageType.query.filter_by(image_code='uploaded_generate').first()
input_type = ImageType.query.filter_by(image_code='original').first()
if not generated_type or not input_type:
raise ValueError("Required image types not found in database")

Loading…
Cancel
Save