实现图片预览功能 #19

Merged
ppy4sjqvf merged 1 commits from ybw-branch into develop 1 month ago

@ -255,4 +255,190 @@ def delete_image(image_id, current_user_id):
return ImageService.json_error(result.get('error', '删除失败'), 500)
return jsonify({'message': '图片删除成功'}), 200
# ==================== 统一预览接口 ====================
@image_bp.route('/preview/flow/<int:flow_id>', methods=['GET'])
@int_jwt_required
def preview_flow_images(flow_id, current_user_id):
"""
获取工作流下所有图片的统一预览接口
返回数据结构:
{
"flow_id": 123,
"original": [...], # 原图
"perturbed": [...], # 加噪图
"original_generate": [...], # 原图微调生成
"perturbed_generate": [...], # 加噪图微调生成
"uploaded_generate": [...], # 上传图微调生成
"heatmap": [...], # 热力图
"report": [...] # 评估报告图
}
"""
from app.database import Task
# 验证用户对该flow的访问权限
tasks = Task.query.filter_by(flow_id=flow_id, user_id=current_user_id).all()
if not tasks:
return ImageService.json_error('工作流不存在或无权限', 404)
# 获取所有图片类型
image_types = {
'original': ImageType.query.filter_by(image_code='original').first(),
'perturbed': ImageType.query.filter_by(image_code='perturbed').first(),
'original_generate': ImageType.query.filter_by(image_code='original_generate').first(),
'perturbed_generate': ImageType.query.filter_by(image_code='perturbed_generate').first(),
'uploaded_generate': ImageType.query.filter_by(image_code='uploaded_generate').first(),
'heatmap': ImageType.query.filter_by(image_code='heatmap').first(),
'report': ImageType.query.filter_by(image_code='report').first(),
}
# 收集所有任务ID
task_ids = [t.tasks_id for t in tasks]
result = {
'flow_id': flow_id,
'original': [],
'perturbed': [],
'original_generate': [],
'perturbed_generate': [],
'uploaded_generate': [],
'heatmap': [],
'report': []
}
# 查询各类型图片
for type_code, image_type in image_types.items():
if image_type:
images = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == image_type.image_types_id
).all()
result[type_code] = [ImageService.image_to_base64(img) for img in images if img]
# 统计总数
result['total'] = sum(len(result[k]) for k in result if k not in ['flow_id', 'total'])
return jsonify(result), 200
@image_bp.route('/preview/task/<int:task_id>', methods=['GET'])
@int_jwt_required
def preview_task_images(task_id, current_user_id):
"""
获取单个任务的所有图片预览
根据任务类型返回相应的图片:
- perturbation: 原图 + 加噪图
- finetune: 原图 + 生成图(original_generate/perturbed_generate/uploaded_generate)
- heatmap: 原图 + 加噪图 + 热力图
- evaluate: 生成图 + 报告图
"""
from app.database import Task, TaskType
task = TaskService.load_task_for_user(task_id, current_user_id)
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
task_type_code = TaskService.get_task_type_code(task)
result = {
'task_id': task_id,
'flow_id': task.flow_id,
'task_type': task_type_code,
'images': {}
}
# 根据任务类型获取相关图片
if task_type_code == 'perturbation':
result['images'] = ImageService._get_perturbation_preview(task)
elif task_type_code == 'finetune':
result['images'] = ImageService._get_finetune_preview(task)
elif task_type_code == 'heatmap':
result['images'] = ImageService._get_heatmap_preview(task)
elif task_type_code == 'evaluate':
result['images'] = ImageService._get_evaluate_preview(task)
return jsonify(result), 200
@image_bp.route('/preview/compare/<int:flow_id>', methods=['GET'])
@int_jwt_required
def preview_compare_images(flow_id, current_user_id):
"""
获取对比预览数据用于展示原图vs加噪图原图生成vs加噪图生成的对比
返回配对的图片数据便于前端展示对比效果
"""
from app.database import Task
# 验证权限
tasks = Task.query.filter_by(flow_id=flow_id, user_id=current_user_id).all()
if not tasks:
return ImageService.json_error('工作流不存在或无权限', 404)
task_ids = [t.tasks_id for t in tasks]
# 获取图片类型
original_type = ImageType.query.filter_by(image_code='original').first()
perturbed_type = ImageType.query.filter_by(image_code='perturbed').first()
original_gen_type = ImageType.query.filter_by(image_code='original_generate').first()
perturbed_gen_type = ImageType.query.filter_by(image_code='perturbed_generate').first()
result = {
'flow_id': flow_id,
'perturbation_pairs': [], # 原图 vs 加噪图
'generation_pairs': [] # 原图生成 vs 加噪图生成
}
# 构建原图vs加噪图对比
if original_type and perturbed_type:
originals = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == original_type.image_types_id
).all()
for orig in originals:
# 查找对应的加噪图通过father_id关联
perturbed = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == perturbed_type.image_types_id,
Image.father_id == orig.images_id
).first()
if perturbed:
result['perturbation_pairs'].append({
'original': ImageService.image_to_base64(orig),
'perturbed': ImageService.image_to_base64(perturbed)
})
# 构建生成图对比(按文件名匹配)
if original_gen_type and perturbed_gen_type:
original_gens = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == original_gen_type.image_types_id
).all()
perturbed_gens = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == perturbed_gen_type.image_types_id
).all()
# 按文件名建立映射
perturbed_map = {img.stored_filename: img for img in perturbed_gens}
for orig_gen in original_gens:
perturbed_gen = perturbed_map.get(orig_gen.stored_filename)
if perturbed_gen:
result['generation_pairs'].append({
'original_generate': ImageService.image_to_base64(orig_gen),
'perturbed_generate': ImageService.image_to_base64(perturbed_gen)
})
return jsonify(result), 200

@ -453,4 +453,111 @@ class ImageService:
'data': f'data:{mimetype};base64,{data}',
'width': image.width,
'height': image.height
}
}
## ==================== 获取预览图片服务 ====================
def _get_perturbation_preview(task):
"""获取加噪任务的预览图片"""
images = {'original': [], 'perturbed': []}
original_type = ImageType.query.filter_by(image_code='original').first()
perturbed_type = ImageType.query.filter_by(image_code='perturbed').first()
if original_type:
originals = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=original_type.image_types_id
).all()
images['original'] = [ImageService.image_to_base64(img) for img in originals]
if perturbed_type:
perturbeds = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=perturbed_type.image_types_id
).all()
images['perturbed'] = [ImageService.image_to_base64(img) for img in perturbeds]
return images
def _get_finetune_preview(task):
"""获取微调任务的预览图片"""
images = {
'original': [],
'original_generate': [],
'perturbed_generate': [],
'uploaded_generate': []
}
# 获取原图从同一flow_id的perturbation任务或当前任务
original_type = ImageType.query.filter_by(image_code='original').first()
if original_type:
# 查找同flow下的原图
from app.database import Task
flow_tasks = Task.query.filter_by(flow_id=task.flow_id, user_id=task.user_id).all()
task_ids = [t.tasks_id for t in flow_tasks]
originals = Image.query.filter(
Image.task_id.in_(task_ids),
Image.image_types_id == original_type.image_types_id
).all()
images['original'] = [ImageService.image_to_base64(img) for img in originals]
# 获取生成图
for type_code in ['original_generate', 'perturbed_generate', 'uploaded_generate']:
img_type = ImageType.query.filter_by(image_code=type_code).first()
if img_type:
generated = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=img_type.image_types_id
).all()
images[type_code] = [ImageService.image_to_base64(img) for img in generated]
return images
def _get_heatmap_preview(task):
"""获取热力图任务的预览图片(热力图本身已包含原图和加噪图的对比)"""
images = {'heatmap': []}
# 获取热力图(已是完整的对比报告图)
heatmap_type = ImageType.query.filter_by(image_code='heatmap').first()
if heatmap_type:
heatmaps = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=heatmap_type.image_types_id
).all()
images['heatmap'] = [ImageService.image_to_base64(img) for img in heatmaps]
return images
def _get_evaluate_preview(task):
"""获取评估任务的预览图片"""
images = {
'original_generate': [],
'perturbed_generate': [],
'report': []
}
# 获取报告图
report_type = ImageType.query.filter_by(image_code='report').first()
if report_type:
reports = Image.query.filter_by(
task_id=task.tasks_id,
image_types_id=report_type.image_types_id
).all()
images['report'] = [ImageService.image_to_base64(img) for img in reports]
# 获取关联的微调任务生成图
if task.evaluation and task.evaluation.finetune_task_id:
finetune_task_id = task.evaluation.finetune_task_id
for type_code in ['original_generate', 'perturbed_generate']:
img_type = ImageType.query.filter_by(image_code=type_code).first()
if img_type:
generated = Image.query.filter_by(
task_id=finetune_task_id,
image_types_id=img_type.image_types_id
).all()
images[type_code] = [ImageService.image_to_base64(img) for img in generated]
return images
Loading…
Cancel
Save