图片服务修复 #17

Merged
ppy4sjqvf merged 1 commits from ybw-branch into develop 1 month ago

@ -42,7 +42,7 @@ def upload_original_images(current_user_id):
return jsonify({
'message': '图片上传成功',
'images': [image_to_base64(img) for img in result],
'images': [ImageService.image_to_base64(img) for img in result],
'flow_id': task.flow_id
}), 201
@ -100,7 +100,7 @@ def get_task_images(task_id, current_user_id):
return jsonify({
'task_id': task_id,
'images': [image_to_base64(img) for img in images],
'images': [ImageService.image_to_base64(img) for img in images],
'total': len(images)
}), 200
@ -125,7 +125,7 @@ def get_perturbation_images(task_id, current_user_id):
return jsonify({
'task_id': task_id,
'task_type': 'perturbation',
'images': [image_to_base64(img) for img in images],
'images': [ImageService.image_to_base64(img) for img in images],
'total': len(images)
}), 200
@ -150,7 +150,7 @@ def get_heatmap_images(task_id, current_user_id):
return jsonify({
'task_id': task_id,
'task_type': 'heatmap',
'images': [image_to_base64(img) for img in images],
'images': [ImageService.image_to_base64(img) for img in images],
'total': len(images)
}), 200
@ -192,8 +192,8 @@ def get_finetune_images(task_id, current_user_id):
image_types_id=perturbed_gen_type.image_types_id
).all()
result['original_generate'] = [image_to_base64(img) for img in original_images]
result['perturbed_generate'] = [image_to_base64(img) for img in perturbed_images]
result['original_generate'] = [ImageService.image_to_base64(img) for img in original_images]
result['perturbed_generate'] = [ImageService.image_to_base64(img) for img in perturbed_images]
result['total'] = len(original_images) + len(perturbed_images)
else:
uploaded_gen_type = ImageType.query.filter_by(image_code='uploaded_generate').first()
@ -205,7 +205,7 @@ def get_finetune_images(task_id, current_user_id):
image_types_id=uploaded_gen_type.image_types_id
).all()
result['uploaded_generate'] = [image_to_base64(img) for img in uploaded_images]
result['uploaded_generate'] = [ImageService.image_to_base64(img) for img in uploaded_images]
result['total'] = len(uploaded_images)
return jsonify(result), 200
@ -231,7 +231,7 @@ def get_evaluate_images(task_id, current_user_id):
return jsonify({
'task_id': task_id,
'task_type': 'evaluate',
'images': [image_to_base64(img) for img in images],
'images': [ImageService.image_to_base64(img) for img in images],
'total': len(images)
}), 200

@ -3,6 +3,7 @@
处理图像上传保存等功能
"""
import base64
import io
import os
import uuid
@ -18,12 +19,12 @@ from app.utils.file_utils import allowed_file
class ImageService:
@staticmethod
def save_to_uploads(file, batch_id, user_id):
def save_to_uploads(file, task_id, user_id):
"""
上传图片到uploads临时目录返回临时文件路径和原始文件名
"""
project_root = os.path.dirname(current_app.root_path)
upload_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id))
upload_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(task_id))
os.makedirs(upload_dir, exist_ok=True)
orig_ext = os.path.splitext(file.filename)[1].lower()
temp_name = f"{uuid.uuid4().hex}{orig_ext}"
@ -32,7 +33,7 @@ class ImageService:
return temp_path, file.filename
@staticmethod
def preprocess_image(temp_path, original_filename, batch_id, user_id, image_type_id, resolution=512, target_format='png'):
def preprocess_image(temp_path, original_filename, task_id, user_id, image_types_id, resolution=512, target_format='png'):
"""
对图片进行中心裁剪缩放格式转换重命名保存到static/originals返回数据库对象
原图命名格式: 0000.png, 0001.png, ..., 9999.png
@ -53,23 +54,23 @@ class ImageService:
img = img.resize((resolution, resolution), resample=PILImage.Resampling.LANCZOS)
project_root = os.path.dirname(current_app.root_path)
static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(batch_id))
static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(task_id))
os.makedirs(static_dir, exist_ok=True)
from app.database import ImageType
original_type = ImageType.query.filter_by(type_code='original').first()
target_image_type_id = original_type.id if original_type else image_type_id
original_type = ImageType.query.filter_by(image_code='original').first()
target_image_types_id = original_type.image_types_id if original_type else image_types_id
# 首次查询最大序号
max_seq_result = db.session.execute(
db.text("""
SELECT COALESCE(MAX(CAST(SUBSTRING_INDEX(stored_filename, '.', 1) AS UNSIGNED)), -1) as max_seq
FROM images
WHERE batch_id = :batch_id
AND image_type_id = :image_type_id
WHERE task_id = :task_id
AND image_types_id = :image_types_id
AND stored_filename REGEXP '^[0-9]{4}\\.'
"""),
{'batch_id': batch_id, 'image_type_id': target_image_type_id}
{'task_id': task_id, 'image_types_id': target_image_types_id}
).fetchone()
# 强制类型转换,确保安全
@ -89,7 +90,7 @@ class ImageService:
try:
# 检查数据库中是否已存在此文件名
existing = Image.query.filter_by(
batch_id=batch_id,
task_id=task_id,
stored_filename=new_name
).first()
@ -105,13 +106,11 @@ class ImageService:
# 创建数据库记录
image = Image(
user_id=user_id,
batch_id=batch_id,
original_filename=original_filename,
task_id=task_id,
image_types_id=image_types_id,
stored_filename=new_name,
file_path=final_path,
file_size=os.path.getsize(final_path),
image_type_id=image_type_id,
width=img.width,
height=img.height
)
@ -133,7 +132,7 @@ class ImageService:
if final_path and os.path.exists(final_path):
try:
os.remove(final_path)
except:
except Exception:
pass
# 继续循环尝试下一个序号
time.sleep(0.005)
@ -151,26 +150,24 @@ class ImageService:
if final_path and os.path.exists(final_path):
try:
os.remove(final_path)
except:
except Exception:
pass
return {'success': False, 'error': f'图片预处理失败: {str(e)}'}
"""图像处理服务"""
@staticmethod
def save_image(file, batch_id, user_id, image_type_id, resolution=512, target_format='png'):
def save_image(file, task_id, user_id, image_types_id, resolution=512, target_format='png'):
"""保存单张图片自动上传到uploads并预处理"""
try:
if not file or not allowed_file(file.filename):
return {'success': False, 'error': '不支持的文件格式'}
temp_path, orig_name = ImageService.save_to_uploads(file, batch_id, user_id)
return ImageService.preprocess_image(temp_path, orig_name, batch_id, user_id, image_type_id, resolution, target_format)
temp_path, orig_name = ImageService.save_to_uploads(file, task_id, user_id)
return ImageService.preprocess_image(temp_path, orig_name, task_id, user_id, image_types_id, resolution, target_format)
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'保存图片失败: {str(e)}'}
@staticmethod
def extract_and_save_zip(zip_file, batch_id, user_id, image_type_id):
def extract_and_save_zip(zip_file, task_id, user_id, image_types_id):
"""解压并保存压缩包中的图片"""
results = []
temp_dir = None
@ -209,7 +206,7 @@ class ImageService:
shutil.copy2(self.path, destination)
virtual_file = FileWrapper(file_path, filename)
result = ImageService.save_image(virtual_file, batch_id, user_id, image_type_id)
result = ImageService.save_image(virtual_file, task_id, user_id, image_types_id)
results.append(result)
return results
@ -223,7 +220,7 @@ class ImageService:
import shutil
try:
shutil.rmtree(temp_dir)
except:
except Exception:
pass
@staticmethod
@ -233,15 +230,19 @@ class ImageService:
return None
# 这里返回相对路径前端可以拼接完整URL
return f"/api/image/file/{image.id}"
return f"/api/image/file/{image.images_id}"
@staticmethod
def delete_image(image_id, user_id):
"""删除图片"""
"""删除图片通过关联的task验证权限"""
try:
image = Image.query.filter_by(id=image_id, user_id=user_id).first()
image = Image.query.filter_by(images_id=image_id).first()
if not image:
return {'success': False, 'error': '图片不存在或无权限'}
return {'success': False, 'error': '图片不存在'}
# 通过关联的task验证用户权限
if not image.task or image.task.user_id != user_id:
return {'success': False, 'error': '无权限删除该图片'}
# 删除文件
if os.path.exists(image.file_path):

Loading…
Cancel
Save