diff --git a/src/backend/app/controllers/image_controller.py b/src/backend/app/controllers/image_controller.py index f469396..0d6bd03 100644 --- a/src/backend/app/controllers/image_controller.py +++ b/src/backend/app/controllers/image_controller.py @@ -42,7 +42,7 @@ def upload_original_images(current_user_id): return jsonify({ 'message': '图片上传成功', - 'images': [image_to_base64(img) for img in result], + 'images': [ImageService.image_to_base64(img) for img in result], 'flow_id': task.flow_id }), 201 @@ -100,7 +100,7 @@ def get_task_images(task_id, current_user_id): return jsonify({ 'task_id': task_id, - 'images': [image_to_base64(img) for img in images], + 'images': [ImageService.image_to_base64(img) for img in images], 'total': len(images) }), 200 @@ -125,7 +125,7 @@ def get_perturbation_images(task_id, current_user_id): return jsonify({ 'task_id': task_id, 'task_type': 'perturbation', - 'images': [image_to_base64(img) for img in images], + 'images': [ImageService.image_to_base64(img) for img in images], 'total': len(images) }), 200 @@ -150,7 +150,7 @@ def get_heatmap_images(task_id, current_user_id): return jsonify({ 'task_id': task_id, 'task_type': 'heatmap', - 'images': [image_to_base64(img) for img in images], + 'images': [ImageService.image_to_base64(img) for img in images], 'total': len(images) }), 200 @@ -192,8 +192,8 @@ def get_finetune_images(task_id, current_user_id): image_types_id=perturbed_gen_type.image_types_id ).all() - result['original_generate'] = [image_to_base64(img) for img in original_images] - result['perturbed_generate'] = [image_to_base64(img) for img in perturbed_images] + result['original_generate'] = [ImageService.image_to_base64(img) for img in original_images] + result['perturbed_generate'] = [ImageService.image_to_base64(img) for img in perturbed_images] result['total'] = len(original_images) + len(perturbed_images) else: uploaded_gen_type = ImageType.query.filter_by(image_code='uploaded_generate').first() @@ -205,7 +205,7 @@ def get_finetune_images(task_id, current_user_id): image_types_id=uploaded_gen_type.image_types_id ).all() - result['uploaded_generate'] = [image_to_base64(img) for img in uploaded_images] + result['uploaded_generate'] = [ImageService.image_to_base64(img) for img in uploaded_images] result['total'] = len(uploaded_images) return jsonify(result), 200 @@ -231,7 +231,7 @@ def get_evaluate_images(task_id, current_user_id): return jsonify({ 'task_id': task_id, 'task_type': 'evaluate', - 'images': [image_to_base64(img) for img in images], + 'images': [ImageService.image_to_base64(img) for img in images], 'total': len(images) }), 200 diff --git a/src/backend/app/services/image_service.py b/src/backend/app/services/image_service.py index ce88ab4..1bdddd0 100644 --- a/src/backend/app/services/image_service.py +++ b/src/backend/app/services/image_service.py @@ -3,6 +3,7 @@ 处理图像上传、保存等功能 """ +import base64 import io import os import uuid @@ -18,12 +19,12 @@ from app.utils.file_utils import allowed_file class ImageService: @staticmethod - def save_to_uploads(file, batch_id, user_id): + def save_to_uploads(file, task_id, user_id): """ 上传图片到uploads临时目录,返回临时文件路径和原始文件名。 """ project_root = os.path.dirname(current_app.root_path) - upload_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id)) + upload_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(task_id)) os.makedirs(upload_dir, exist_ok=True) orig_ext = os.path.splitext(file.filename)[1].lower() temp_name = f"{uuid.uuid4().hex}{orig_ext}" @@ -32,7 +33,7 @@ class ImageService: return temp_path, file.filename @staticmethod - def preprocess_image(temp_path, original_filename, batch_id, user_id, image_type_id, resolution=512, target_format='png'): + def preprocess_image(temp_path, original_filename, task_id, user_id, image_types_id, resolution=512, target_format='png'): """ 对图片进行中心裁剪、缩放、格式转换、重命名,保存到static/originals,返回数据库对象。 原图命名格式: 0000.png, 0001.png, ..., 9999.png @@ -53,23 +54,23 @@ class ImageService: img = img.resize((resolution, resolution), resample=PILImage.Resampling.LANCZOS) project_root = os.path.dirname(current_app.root_path) - static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(batch_id)) + static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(task_id)) os.makedirs(static_dir, exist_ok=True) from app.database import ImageType - original_type = ImageType.query.filter_by(type_code='original').first() - target_image_type_id = original_type.id if original_type else image_type_id + original_type = ImageType.query.filter_by(image_code='original').first() + target_image_types_id = original_type.image_types_id if original_type else image_types_id # 首次查询最大序号 max_seq_result = db.session.execute( db.text(""" SELECT COALESCE(MAX(CAST(SUBSTRING_INDEX(stored_filename, '.', 1) AS UNSIGNED)), -1) as max_seq FROM images - WHERE batch_id = :batch_id - AND image_type_id = :image_type_id + WHERE task_id = :task_id + AND image_types_id = :image_types_id AND stored_filename REGEXP '^[0-9]{4}\\.' """), - {'batch_id': batch_id, 'image_type_id': target_image_type_id} + {'task_id': task_id, 'image_types_id': target_image_types_id} ).fetchone() # 强制类型转换,确保安全 @@ -89,7 +90,7 @@ class ImageService: try: # 检查数据库中是否已存在此文件名 existing = Image.query.filter_by( - batch_id=batch_id, + task_id=task_id, stored_filename=new_name ).first() @@ -105,13 +106,11 @@ class ImageService: # 创建数据库记录 image = Image( - user_id=user_id, - batch_id=batch_id, - original_filename=original_filename, + task_id=task_id, + image_types_id=image_types_id, stored_filename=new_name, file_path=final_path, file_size=os.path.getsize(final_path), - image_type_id=image_type_id, width=img.width, height=img.height ) @@ -133,7 +132,7 @@ class ImageService: if final_path and os.path.exists(final_path): try: os.remove(final_path) - except: + except Exception: pass # 继续循环尝试下一个序号 time.sleep(0.005) @@ -151,26 +150,24 @@ class ImageService: if final_path and os.path.exists(final_path): try: os.remove(final_path) - except: + except Exception: pass return {'success': False, 'error': f'图片预处理失败: {str(e)}'} - - """图像处理服务""" - + @staticmethod - def save_image(file, batch_id, user_id, image_type_id, resolution=512, target_format='png'): + def save_image(file, task_id, user_id, image_types_id, resolution=512, target_format='png'): """保存单张图片,自动上传到uploads并预处理""" try: if not file or not allowed_file(file.filename): return {'success': False, 'error': '不支持的文件格式'} - temp_path, orig_name = ImageService.save_to_uploads(file, batch_id, user_id) - return ImageService.preprocess_image(temp_path, orig_name, batch_id, user_id, image_type_id, resolution, target_format) + temp_path, orig_name = ImageService.save_to_uploads(file, task_id, user_id) + return ImageService.preprocess_image(temp_path, orig_name, task_id, user_id, image_types_id, resolution, target_format) except Exception as e: db.session.rollback() return {'success': False, 'error': f'保存图片失败: {str(e)}'} @staticmethod - def extract_and_save_zip(zip_file, batch_id, user_id, image_type_id): + def extract_and_save_zip(zip_file, task_id, user_id, image_types_id): """解压并保存压缩包中的图片""" results = [] temp_dir = None @@ -209,7 +206,7 @@ class ImageService: shutil.copy2(self.path, destination) virtual_file = FileWrapper(file_path, filename) - result = ImageService.save_image(virtual_file, batch_id, user_id, image_type_id) + result = ImageService.save_image(virtual_file, task_id, user_id, image_types_id) results.append(result) return results @@ -223,7 +220,7 @@ class ImageService: import shutil try: shutil.rmtree(temp_dir) - except: + except Exception: pass @staticmethod @@ -233,15 +230,19 @@ class ImageService: return None # 这里返回相对路径,前端可以拼接完整URL - return f"/api/image/file/{image.id}" + return f"/api/image/file/{image.images_id}" @staticmethod def delete_image(image_id, user_id): - """删除图片""" + """删除图片(通过关联的task验证权限)""" try: - image = Image.query.filter_by(id=image_id, user_id=user_id).first() + image = Image.query.filter_by(images_id=image_id).first() if not image: - return {'success': False, 'error': '图片不存在或无权限'} + return {'success': False, 'error': '图片不存在'} + + # 通过关联的task验证用户权限 + if not image.task or image.task.user_id != user_id: + return {'success': False, 'error': '无权限删除该图片'} # 删除文件 if os.path.exists(image.file_path):