图片接口更改完成

pull/12/head
杨博文 5 months ago
parent fd8e716886
commit 3cd2fd0809

@ -1,14 +1,15 @@
"""
图像管理控制器
负责图片上传下载等操作
负责图片上传查询获取等操作
"""
import os
import base64
from flask import Blueprint, request, jsonify, send_file
from app.controllers.auth_controller import int_jwt_required
from app.services.task_service import TaskService
from app.services.image_service import ImageService
from app.database import Image, ImageType
image_bp = Blueprint('image', __name__)
@ -41,48 +42,123 @@ def upload_original_images(current_user_id):
return jsonify({
'message': '图片上传成功',
'images': [ImageService.serialize_image(img) for img in result],
'images': [image_to_base64(img) for img in result],
'flow_id': task.flow_id
}), 201
# ==================== 结果下载 ====================
# ==================== 单张图片获取 ====================
@image_bp.route('/file/<int:image_id>', methods=['GET'])
@int_jwt_required
def get_image_file(image_id, current_user_id):
"""获取单张图片文件(直接返回图片二进制)"""
image = Image.query.get(image_id)
if not image:
return ImageService.json_error('图片不存在', 404)
task = image.task
if not task or task.user_id != current_user_id:
return ImageService.json_error('无权限访问该图片', 403)
if not os.path.exists(image.file_path):
return ImageService.json_error('图片文件不存在', 404)
ext = os.path.splitext(image.file_path)[1].lower()
mime_types = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp',
}
mimetype = mime_types.get(ext, 'application/octet-stream')
return send_file(image.file_path, mimetype=mimetype)
# ==================== 任务图片获取(返回 base64 ====================
@image_bp.route('/task/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_task_images(task_id, current_user_id):
"""获取任务的所有图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id)
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
image_type_code = request.args.get('type')
query = Image.query.filter_by(task_id=task_id)
if image_type_code:
image_type = ImageType.query.filter_by(image_code=image_type_code).first()
if image_type:
query = query.filter_by(image_types_id=image_type.image_types_id)
images = query.all()
return jsonify({
'task_id': task_id,
'images': [image_to_base64(img) for img in images],
'total': len(images)
}), 200
@image_bp.route('/perturbation/<int:task_id>/download', methods=['GET'])
@image_bp.route('/perturbation/<int:task_id>', methods=['GET'])
@int_jwt_required
def download_perturbation_result(task_id, current_user_id):
def get_perturbation_images(task_id, current_user_id):
"""获取加噪任务的结果图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_perturbed_images_path(task.user_id, task.flow_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('结果文件不存在', 404)
perturbed_type = ImageType.query.filter_by(image_code='perturbed').first()
if not perturbed_type:
return ImageService.json_error('图片类型未配置', 500)
filename = f"perturbation_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
images = Image.query.filter_by(
task_id=task_id,
image_types_id=perturbed_type.image_types_id
).all()
return jsonify({
'task_id': task_id,
'task_type': 'perturbation',
'images': [image_to_base64(img) for img in images],
'total': len(images)
}), 200
@image_bp.route('/heatmap/<int:task_id>/download', methods=['GET'])
@image_bp.route('/heatmap/<int:task_id>', methods=['GET'])
@int_jwt_required
def download_heatmap_result(task_id, current_user_id):
def get_heatmap_images(task_id, current_user_id):
"""获取热力图任务的结果图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_heatmap_path(task.user_id, task.flow_id, task.tasks_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('热力图文件不存在', 404)
heatmap_type = ImageType.query.filter_by(image_code='heatmap').first()
if not heatmap_type:
return ImageService.json_error('图片类型未配置', 500)
filename = f"heatmap_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
images = Image.query.filter_by(
task_id=task_id,
image_types_id=heatmap_type.image_types_id
).all()
return jsonify({
'task_id': task_id,
'task_type': 'heatmap',
'images': [image_to_base64(img) for img in images],
'total': len(images)
}), 200
@image_bp.route('/finetune/<int:task_id>/download', methods=['GET'])
@image_bp.route('/finetune/<int:task_id>', methods=['GET'])
@int_jwt_required
def download_finetune_result(task_id, current_user_id):
def get_finetune_images(task_id, current_user_id):
"""获取微调任务的生成图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
@ -94,35 +170,89 @@ def download_finetune_result(task_id, current_user_id):
source = TaskService.determine_finetune_source(task)
except ValueError as exc:
return ImageService.json_error(str(exc), 500)
result = {'task_id': task_id, 'task_type': 'finetune', 'source': source}
if source == 'perturbation':
directories = {
'original_generate': TaskService.get_original_generated_path(task.user_id, task.flow_id, task.tasks_id),
'perturbed_generate': TaskService.get_perturbed_generated_path(task.user_id, task.flow_id, task.tasks_id)
}
original_gen_type = ImageType.query.filter_by(image_code='original_generate').first()
perturbed_gen_type = ImageType.query.filter_by(image_code='perturbed_generate').first()
original_images = []
perturbed_images = []
if original_gen_type:
original_images = Image.query.filter_by(
task_id=task_id,
image_types_id=original_gen_type.image_types_id
).all()
if perturbed_gen_type:
perturbed_images = Image.query.filter_by(
task_id=task_id,
image_types_id=perturbed_gen_type.image_types_id
).all()
result['original_generate'] = [image_to_base64(img) for img in original_images]
result['perturbed_generate'] = [image_to_base64(img) for img in perturbed_images]
result['total'] = len(original_images) + len(perturbed_images)
else:
directories = {
'uploaded_generate': TaskService.get_uploaded_generated_path(task.user_id, task.flow_id, task.tasks_id)
}
uploaded_gen_type = ImageType.query.filter_by(image_code='uploaded_generate').first()
uploaded_images = []
zipped, has_files = ImageService.zip_multiple_directories(directories)
if not has_files:
return ImageService.json_error('微调结果文件不存在', 404)
if uploaded_gen_type:
uploaded_images = Image.query.filter_by(
task_id=task_id,
image_types_id=uploaded_gen_type.image_types_id
).all()
filename = f"finetune_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
result['uploaded_generate'] = [image_to_base64(img) for img in uploaded_images]
result['total'] = len(uploaded_images)
return jsonify(result), 200
@image_bp.route('/evaluate/<int:task_id>/download', methods=['GET'])
@image_bp.route('/evaluate/<int:task_id>', methods=['GET'])
@int_jwt_required
def download_evaluate_result(task_id, current_user_id):
def get_evaluate_images(task_id, current_user_id):
"""获取评估任务的结果图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_evaluate_path(task.user_id, task.flow_id, task.tasks_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('评估结果文件不存在', 404)
report_type = ImageType.query.filter_by(image_code='report').first()
if not report_type:
return ImageService.json_error('图片类型未配置', 500)
images = Image.query.filter_by(
task_id=task_id,
image_types_id=report_type.image_types_id
).all()
return jsonify({
'task_id': task_id,
'task_type': 'evaluate',
'images': [image_to_base64(img) for img in images],
'total': len(images)
}), 200
# ==================== 图片删除 ====================
@image_bp.route('/<int:image_id>', methods=['DELETE'])
@int_jwt_required
def delete_image(image_id, current_user_id):
"""删除单张图片"""
image = Image.query.get(image_id)
if not image:
return ImageService.json_error('图片不存在', 404)
task = image.task
if not task or task.user_id != current_user_id:
return ImageService.json_error('无权限删除该图片', 403)
result = ImageService.delete_image(image_id, current_user_id)
if not result.get('success'):
return ImageService.json_error(result.get('error', '删除失败'), 500)
filename = f"evaluate_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
return jsonify({'message': '图片删除成功'}), 200

@ -7,7 +7,6 @@ import io
import os
import uuid
import zipfile
import fcntl
import time
from datetime import datetime
from werkzeug.utils import secure_filename
@ -425,4 +424,32 @@ class ImageService:
'width': image.width,
'height': image.height,
'image_type': image.image_type.image_code if image.image_type else None
}
@staticmethod
def image_to_base64(image):
"""将图片转换为 base64 编码"""
if not image or not os.path.exists(image.file_path):
return None
ext = os.path.splitext(image.file_path)[1].lower()
mime_types = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp',
}
mimetype = mime_types.get(ext, 'image/png')
with open(image.file_path, 'rb') as f:
data = base64.b64encode(f.read()).decode('utf-8')
return {
'image_id': image.images_id,
'filename': image.stored_filename,
'data': f'data:{mimetype};base64,{data}',
'width': image.width,
'height': image.height
}
Loading…
Cancel
Save