将lianghao_branch合并到develop #14

Merged
hnu202326010204 merged 6 commits from lianghao_branch into develop 1 month ago

7
.gitignore vendored

@ -6,6 +6,9 @@ __pycache__/
*.jpg
*.jpeg
# 数据文件
*.csv
# 环境配置文件(包含敏感信息)
*.env
@ -28,8 +31,8 @@ uploads/
# vscode 配置
.vscode/
#github 工作流配置
# github 工作流配置
.github/
#pycharm 配置
# pycharm 配置
.idea/

@ -1,35 +0,0 @@
# Python 编译缓存
__pycache__/
# 图片文件
*.png
*.jpg
*.jpeg
# 环境配置文件(包含敏感信息)
*.env
# 日志及进程文件
logs/
*.log
*.pid
# 上传文件临时目录
uploads/
# 微调生成文件
*.json
*.bin
*.pkl
*.safetensors
*.pt
*.txt
# 模型文件
hf_models/
#数据库迁移文件
migrations/
#测试文件
test/

@ -1377,7 +1377,11 @@ def main(args):
coords_list,
columns=['step', 'X_LoRA_Weight_Norm', 'Y_Grad_Norm', 'Z_LDM_Loss']
)
save_path = Path(args.positions_save_path) / "coords_live.csv"
# 假设 args.positions_save_path 是目标文件路径 (如 ./data/coords.csv)
save_path = Path(args.positions_save_path)
if not save_path.suffix:
save_path = save_path / "coords.csv"
save_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(save_path, index=False)
@ -1508,8 +1512,12 @@ def main(args):
coords_list,
columns=['step', 'X_LoRA_Weight_Norm', 'Y_Grad_Norm', 'Z_LDM_Loss']
)
save_path = Path(args.positions_save_path) / "coords.csv"
# 假设 args.positions_save_path 是目标文件路径 (如 ./data/coords.csv)
save_path = Path(args.positions_save_path)
if not save_path.suffix:
save_path = save_path / "coords.csv"
save_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(save_path, index=False)
logger.info(f"训练结束:已将所有 {len(coords_list)} 步可视化坐标数据保存到 {save_path}")

@ -679,3 +679,114 @@ def get_evaluate_task(task_id, current_user_id):
if not task:
return TaskService.json_error('任务不存在或无权限', 404)
return jsonify({'task': TaskService.serialize_task(task)}), 200
# ==================== 3D可视化坐标接口 ====================
def _read_coords_csv(csv_path):
    """Load one coordinates CSV file as a list of row dicts.

    Args:
        csv_path: Filesystem path of the CSV file to read.

    Returns:
        A list with one dict per data row, keyed by the CSV header names
        (values are strings, as produced by ``csv.DictReader``).

    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
        Exception: Any other error raised while opening, decoding or
            parsing the file is propagated to the caller.
    """
    import csv

    # newline='' is what the csv module documents for file objects, so that
    # newlines embedded in quoted fields are interpreted correctly.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        return list(csv.DictReader(f))


@task_bp.route('/finetune/<int:task_id>/coords', methods=['GET'])
@int_jwt_required
def get_finetune_coords(task_id, current_user_id):
    """
    Return the 3D-visualization coordinate CSV data of a finetune task.

    Response content:
    - finetune based on a perturbation task: original_coords.csv and
      perturbed_coords.csv (in that order)
    - finetune based on uploaded images: coords.csv

    Returns a JSON body with ``task_id``, ``flow_id``, ``source`` and a
    ``coords`` list (one entry per CSV file, each with ``type``,
    ``filename``, ``path`` and the parsed ``data`` rows) and status 200.
    Returns a JSON error with 404 when the task, its finetune detail or a
    required CSV file is missing, and 500 when the finetune source cannot
    be determined or a CSV file cannot be read.
    """
    import os
    from config.settings import Config

    # Verify the task exists and belongs to the current user.
    task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune')
    if not task:
        return TaskService.json_error('任务不存在或无权限', 404)

    # The finetune detail row must exist as well.
    finetune = Finetune.query.get(task_id)
    if not finetune:
        return TaskService.json_error('微调任务详情不存在', 404)

    # Determine which kind of finetune this is.
    try:
        source = TaskService.determine_finetune_source(task)
    except ValueError as exc:
        return TaskService.json_error(str(exc), 500)

    # Directory that holds this task's coordinate CSV files.
    coords_base_path = TaskService._build_path(
        Config.COORDS_SAVE_FOLDER,
        str(current_user_id),
        str(task.flow_id),
        str(task_id)
    )

    # (type, filename, label used inside the user-facing error messages)
    if source == 'perturbation':
        file_specs = [
            ('original', 'original_coords.csv', '原图坐标'),
            ('perturbed', 'perturbed_coords.csv', '加噪图坐标'),
        ]
    else:  # source == 'uploaded'
        file_specs = [
            ('uploaded', 'coords.csv', '坐标'),
        ]

    result = {
        'task_id': task_id,
        'flow_id': task.flow_id,
        'source': source,
        'coords': []
    }

    for coords_type, filename, label in file_specs:
        csv_path = os.path.join(coords_base_path, filename)
        try:
            rows = _read_coords_csv(csv_path)
        except FileNotFoundError:
            # Try-then-handle instead of os.path.exists() + open(): avoids
            # the race where the file disappears between check and read.
            return TaskService.json_error(f'{label}文件不存在', 404)
        except Exception as e:
            # Request boundary: report any read/parse failure as a JSON 500.
            return TaskService.json_error(f'读取{label}文件失败: {e}', 500)
        result['coords'].append({
            'type': coords_type,
            'filename': filename,
            # NOTE(review): this exposes a server filesystem path to the
            # client; kept for backward compatibility — confirm whether
            # any consumer actually needs it.
            'path': csv_path,
            'data': rows
        })

    return jsonify(result), 200

@ -521,7 +521,7 @@ class TaskService:
str(user_id),
str(task.flow_id),
str(task_id),
'original_coords.json'
'original_coords.csv'
)
# 获取加噪坐标保存路径3D可视化
@ -530,7 +530,7 @@ class TaskService:
str(user_id),
str(task.flow_id),
str(task_id),
'perturbed_coords.json'
'perturbed_coords.csv'
)
# 加入RQ队列
@ -585,7 +585,7 @@ class TaskService:
str(user_id),
str(task.flow_id),
str(task_id),
'coords.json'
'coords.csv'
)
# 加入RQ队列

@ -81,16 +81,13 @@ def run_finetune_task(task_id, finetune_method, train_images_dir,
class_prompt = instance_prompt.replace('sks ', '')
logger.info(f"Using prompts from database - instance: '{instance_prompt}', class: '{class_prompt}'")
# 清空输出目录(避免旧文件残留)
logger.info(f"Clearing output directories...")
# 彻底清空输出目录(避免旧文件残留,特别是 textual_inversion 的 token)
logger.info(f"Completely clearing output directories...")
for dir_path in [output_model_dir, validation_output_dir]:
if os.path.exists(dir_path):
for item in os.listdir(dir_path):
item_path = os.path.join(dir_path, item)
if os.path.isfile(item_path):
os.unlink(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
logger.info(f"Removing and recreating directory: {dir_path}")
shutil.rmtree(dir_path)
os.makedirs(dir_path, exist_ok=True)
# 清理旧的坐标文件(coords_save_path,现为 .csv)
if os.path.exists(coords_save_path):
@ -176,9 +173,12 @@ def _run_real_finetune(finetune_method, task_id, train_images_dir, output_model_
raise ValueError(f"Finetune method {finetune_method} not configured")
# 覆盖提示词参数(从数据库读取)
default_params['instance_prompt'] = instance_prompt
default_params['class_prompt'] = class_prompt
default_params['validation_prompt'] = validation_prompt
if 'instance_prompt' in default_params:
default_params['instance_prompt'] = instance_prompt
if 'class_prompt' in default_params:
default_params['class_prompt'] = class_prompt
if 'validation_prompt' in default_params:
default_params['validation_prompt'] = validation_prompt
# 合并自定义参数
params = {**default_params, **(custom_params or {})}

@ -162,7 +162,8 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
default_params = AlgorithmConfig.get_default_params(algorithm_code)
# 覆盖提示词参数(从数据库读取)
default_params['instance_prompt'] = instance_prompt
if 'instance_prompt' in default_params:
default_params['instance_prompt'] = instance_prompt
if 'class_prompt' in default_params:
default_params['class_prompt'] = class_prompt

Loading…
Cancel
Save