|
|
|
|
@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
def run_finetune_task(task_id, finetune_method, train_images_dir,
|
|
|
|
|
output_model_dir, class_dir, coords_save_path, validation_output_dir,
|
|
|
|
|
is_perturbed=False, custom_params=None):
|
|
|
|
|
finetune_type, custom_params=None):
|
|
|
|
|
"""
|
|
|
|
|
执行微调任务(仅使用真实算法)
|
|
|
|
|
|
|
|
|
|
@ -32,7 +32,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
|
|
|
|
|
class_dir: 类别图片目录
|
|
|
|
|
coords_save_path: 坐标保存路径
|
|
|
|
|
validation_output_dir: 验证图片输出目录
|
|
|
|
|
is_perturbed: 是否使用扰动图片训练
|
|
|
|
|
finetune_type: 微调类型 (original, perturbed, uploaded)
|
|
|
|
|
custom_params: 自定义参数
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
@ -64,7 +64,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
|
|
|
|
|
task.started_at = datetime.utcnow()
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
logger.info(f"Method: {finetune_method}, is_perturbed: {is_perturbed}")
|
|
|
|
|
logger.info(f"Method: {finetune_method}, finetune_type: {finetune_type}")
|
|
|
|
|
|
|
|
|
|
# 获取 DataType 配置
|
|
|
|
|
data_type = DataType.query.get(finetune.data_type_id) if finetune.data_type_id else None
|
|
|
|
|
@ -138,8 +138,8 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
|
|
|
|
|
log_dir = AlgorithmConfig.LOGS_DIR
|
|
|
|
|
os.makedirs(log_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
if not is_perturbed:
|
|
|
|
|
# 原图微调:清除旧日志,创建新日志
|
|
|
|
|
if finetune_type == "original" or finetune_type == "uploaded":
|
|
|
|
|
# 原图/上传微调:清除旧日志,创建新日志
|
|
|
|
|
old_logs = glob.glob(os.path.join(log_dir, f'finetune_{finetune_method}_task_{task_id}_*.log'))
|
|
|
|
|
for old_log in old_logs:
|
|
|
|
|
try:
|
|
|
|
|
@ -152,7 +152,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
|
|
|
|
|
log_dir,
|
|
|
|
|
f'finetune_{finetune_method}_task_{task_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
elif finetune_type == "perturbed":
|
|
|
|
|
# 扰动图微调:尝试复用现有日志
|
|
|
|
|
old_logs = glob.glob(os.path.join(log_dir, f'finetune_{finetune_method}_task_{task_id}_*.log'))
|
|
|
|
|
if old_logs:
|
|
|
|
|
@ -170,11 +170,11 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
|
|
|
|
|
result = _run_real_finetune(
|
|
|
|
|
finetune_method, task_id, train_images_dir, output_model_dir,
|
|
|
|
|
class_dir, coords_save_path, validation_output_dir,
|
|
|
|
|
instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params, log_file
|
|
|
|
|
instance_prompt, class_prompt, validation_prompt, finetune_type, custom_params, log_file
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 保存生成的验证图片到数据库
|
|
|
|
|
_save_generated_images(task_id, validation_output_dir, is_perturbed)
|
|
|
|
|
_save_generated_images(task_id, validation_output_dir, finetune_type)
|
|
|
|
|
|
|
|
|
|
# 更新任务状态为完成
|
|
|
|
|
completed_status = TaskStatus.query.filter_by(task_status_code='completed').first()
|
|
|
|
|
@ -205,7 +205,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
|
|
|
|
|
|
|
|
|
|
def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_dir,
|
|
|
|
|
class_dir, coords_save_path, validation_output_dir,
|
|
|
|
|
instance_prompt, class_prompt, validation_prompt, is_perturbed, custom_params, log_file):
|
|
|
|
|
instance_prompt, class_prompt, validation_prompt, finetune_type, custom_params, log_file):
|
|
|
|
|
"""
|
|
|
|
|
运行真实微调算法(参考sh脚本配置)
|
|
|
|
|
|
|
|
|
|
@ -220,7 +220,7 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
|
|
|
|
|
instance_prompt: 实例提示词
|
|
|
|
|
class_prompt: 类别提示词
|
|
|
|
|
validation_prompt: 验证提示词
|
|
|
|
|
is_perturbed: 是否使用扰动图片
|
|
|
|
|
finetune_type: 微调类型 (original, perturbed, uploaded)
|
|
|
|
|
custom_params: 自定义参数
|
|
|
|
|
log_file: 日志文件路径
|
|
|
|
|
"""
|
|
|
|
|
@ -309,7 +309,7 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
|
|
|
|
|
# 执行命令
|
|
|
|
|
# 使用追加模式 'a',以便在同一日志文件中记录原图和扰动图的微调过程
|
|
|
|
|
with open(log_file, 'a') as f:
|
|
|
|
|
if is_perturbed:
|
|
|
|
|
if finetune_type == "perturbed":
|
|
|
|
|
f.write(f"\n\n{'='*30}\nStarting Perturbed Finetune Task\n{'='*30}\n\n")
|
|
|
|
|
|
|
|
|
|
process = subprocess.Popen(
|
|
|
|
|
@ -386,7 +386,7 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_generated_images(task_id, output_dir, is_perturbed):
|
|
|
|
|
def _save_generated_images(task_id, output_dir, finetune_type):
|
|
|
|
|
"""
|
|
|
|
|
保存微调生成的验证图片到数据库(适配新数据库结构)
|
|
|
|
|
|
|
|
|
|
@ -398,7 +398,7 @@ def _save_generated_images(task_id, output_dir, is_perturbed):
|
|
|
|
|
Args:
|
|
|
|
|
task_id: 任务ID
|
|
|
|
|
output_dir: 生成图片输出目录
|
|
|
|
|
is_perturbed: 是否为扰动图片训练生成
|
|
|
|
|
finetune_type: 微调类型 (original, perturbed, uploaded)
|
|
|
|
|
"""
|
|
|
|
|
from app import db
|
|
|
|
|
from app.database import Task, Image, ImageType
|
|
|
|
|
@ -410,12 +410,15 @@ def _save_generated_images(task_id, output_dir, is_perturbed):
|
|
|
|
|
raise ValueError(f"Task {task_id} not found")
|
|
|
|
|
|
|
|
|
|
# 获取图片类型
|
|
|
|
|
if is_perturbed:
|
|
|
|
|
if finetune_type == "perturbed":
|
|
|
|
|
generated_type = ImageType.query.filter_by(image_code='perturbed_generate').first()
|
|
|
|
|
input_type = ImageType.query.filter_by(image_code='perturbed').first()
|
|
|
|
|
else:
|
|
|
|
|
elif finetune_type == "original":
|
|
|
|
|
generated_type = ImageType.query.filter_by(image_code='original_generate').first()
|
|
|
|
|
input_type = ImageType.query.filter_by(image_code='original').first()
|
|
|
|
|
elif finetune_type == "uploaded":
|
|
|
|
|
generated_type = ImageType.query.filter_by(image_code='uploaded_generate').first()
|
|
|
|
|
input_type = ImageType.query.filter_by(image_code='original').first()
|
|
|
|
|
|
|
|
|
|
if not generated_type or not input_type:
|
|
|
|
|
raise ValueError("Required image types not found in database")
|
|
|
|
|
|