diff --git a/src/backend/app.py b/src/backend/app.py index 0855bf5..0edf176 100644 --- a/src/backend/app.py +++ b/src/backend/app.py @@ -8,6 +8,7 @@ from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate from flask_jwt_extended import JWTManager from flask_cors import CORS +from flask_mail import Mail from config.settings import Config # 初始化扩展 diff --git a/src/backend/app/controllers/admin_controller.py b/src/backend/app/controllers/admin_controller.py index a07f245..e84e5dd 100644 --- a/src/backend/app/controllers/admin_controller.py +++ b/src/backend/app/controllers/admin_controller.py @@ -61,8 +61,9 @@ def get_user_detail(user_id): total_tasks = Task.query.filter_by(user_id=user_id).count() # 查找用户的所有图片 user_tasks = Task.query.filter_by(user_id=user_id).all() - task_ids = [task.task_id for task in user_tasks] - total_images = Image.query.filter_by(task_id in (task_ids)).count() + task_ids = [task.tasks_id for task in user_tasks] + total_images = Image.query.filter(Image.task_id.in_(task_ids)).count() if task_ids else 0 + user_dict = user.to_dict() user_dict['stats'] = { @@ -101,11 +102,15 @@ def create_user(): if email and User.query.filter_by(email=email).first(): return jsonify({'error': '邮箱已被使用'}), 400 + # 角色映射 + role_map = {'admin': 1, 'vip': 2, 'normal': 3, 'user': 3} + role_id = role_map.get(role_code, 3) + # 创建用户 user = User( username=username, email=email, - role_id=User.role_to_id(role), + role_id=role_id, ) user.set_password(password) @@ -200,17 +205,25 @@ def delete_user(user_id): def get_system_stats(): """获取系统统计信息""" try: - from app.database import EvaluationResult + from app.database import TaskStatus total_users = User.query.count() - active_users = User.query.filter_by(is_active = True).count() - admin_users = User.query.filter_by(role_id = 0).count() + active_users = User.query.filter_by(is_active=True).count() + admin_users = User.query.filter_by(role_id=1).count() total_tasks = Task.query.count() - completed_tasks = Task.query.filter_by(status='completed').count() - processing_tasks = Task.query.filter_by(status='processing').count() - failed_tasks = Task.query.filter_by(status='failed').count() - waiting_tasks = Task.query.filter_by(status='waiting').count() + + # 通过 TaskStatus 表查询各状态的任务数 + def count_tasks_by_status(status_code): + status = TaskStatus.query.filter_by(task_status_code=status_code).first() + if status: + return Task.query.filter_by(tasks_status_id=status.task_status_id).count() + return 0 + + completed_tasks = count_tasks_by_status('completed') + processing_tasks = count_tasks_by_status('processing') + failed_tasks = count_tasks_by_status('failed') + waiting_tasks = count_tasks_by_status('waiting') total_images = Image.query.count() diff --git a/src/backend/app/controllers/auth_controller.py b/src/backend/app/controllers/auth_controller.py index dac390c..9236d2d 100644 --- a/src/backend/app/controllers/auth_controller.py +++ b/src/backend/app/controllers/auth_controller.py @@ -25,7 +25,7 @@ def int_jwt_required(f): auth_bp = Blueprint('auth', __name__) @auth_bp.route('/code', methods=['GET']) -def send_email_verification_code(email: str = "3310207578@qq.com", purpose: str = 'register'): +def send_email_verification_code(email: str = "3310207578@qq.com", purpose: str = 'register'): email = "3310207578@qq.com" send_verification_code(email, purpose=purpose) return jsonify({'message': '验证码已发送'}), 200 @@ -60,15 +60,15 @@ def register(): if not code or not verify_code(email, code, purpose='register'): return jsonify({'error': '验证码无效或已过期'}), 400 - # 创建用户 - user = User(username=username, email=email) + # 创建用户(默认为普通用户,role_id=3) + user = User(username=username, email=email, role_id=3) user.set_password(password) db.session.add(user) db.session.commit() # 创建用户默认配置 - user_config = UserConfig(user_id=user.id) + user_config = UserConfig(user_id=user.user_id) db.session.add(user_config) db.session.commit() @@ -102,7 +102,7 @@ def login(): return jsonify({'error': '账户已被禁用'}), 401 # 创建访问令牌 - 确保用户ID为字符串类型 - access_token = create_access_token(identity=str(user.id)) + access_token = create_access_token(identity=str(user.user_id)) return jsonify({ 'message': '登录成功', @@ -144,6 +144,57 @@ def change_password(current_user_id): db.session.rollback() return jsonify({'error': f'密码修改失败: {str(e)}'}), 500 +@auth_bp.route('/change-email', methods = ['POST']) +@int_jwt_required +def change_email(current_user_id) + """修改邮箱""" + try: + user = User.query.filter_by(current_user_id) + data = request.get_json() + new_email = data.get('new_email') + code = data.get('code') + + if not new_email: + return jsonify({'error': '新邮箱不能为空'}), 400 + + if not User.query.filter(new_email).first(): + return jsonify({'error':'该邮箱已被使用'}), 400 + + if not code or not verify_code(email, code, purpose='register'): + return jsonify({'error': '验证码无效或已过期'}), 400 + + user.email = new_email + db.session.commit() + return jsonify({'message': '邮箱修改成功'}), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'邮箱修改失败: {str(e)}'}), 500 + +@auth_bp.route('/change-username', methods = ['POST']) +@int_jwt_required +def change_username(current_user_id) + """修改用户名""" + try: + user = User.query.filter_by(current_user_id) + data = request.get_json() + new_username = data.get('new_username') + + if not new_username: + return jsonify({'error': '新名称不能为空'}), 400 + + if not User.query.filter(new_username).first(): + return jsonify({'error':'该用户名已被使用'}), 400 + + user.name = new_username + db.session.commit() + return jsonify({'message': '用户名修改成功'}), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'用户名修改失败: {str(e)}'}), 500 + + @auth_bp.route('/profile', methods=['GET']) @int_jwt_required def get_profile(current_user_id): @@ -163,4 +214,4 @@ def get_profile(current_user_id): @jwt_required() def logout(): """用户登出(客户端删除token即可)""" - return jsonify({'message': '登出成功'}), 200 \ No newline at end of file + return jsonify({'message': '登出成功'}), 200 diff --git a/src/backend/app/controllers/image_controller.py b/src/backend/app/controllers/image_controller.py index 08aef56..f469396 100644 --- a/src/backend/app/controllers/image_controller.py +++ b/src/backend/app/controllers/image_controller.py @@ -1,14 +1,15 @@ - """ 图像管理控制器 -负责图片上传、下载等操作 +负责图片上传、查询、获取等操作 """ +import os +import base64 from flask import Blueprint, request, jsonify, send_file from app.controllers.auth_controller import int_jwt_required from app.services.task_service import TaskService from app.services.image_service import ImageService - +from app.database import Image, ImageType image_bp = Blueprint('image', __name__) @@ -41,48 +42,123 @@ def upload_original_images(current_user_id): return jsonify({ 'message': '图片上传成功', - 'images': [ImageService.serialize_image(img) for img in result], + 'images': [image_to_base64(img) for img in result], 'flow_id': task.flow_id }), 201 -# ==================== 结果下载 ==================== +# ==================== 单张图片获取 ==================== + +@image_bp.route('/file/', methods=['GET']) +@int_jwt_required +def get_image_file(image_id, current_user_id): + """获取单张图片文件(直接返回图片二进制)""" + image = Image.query.get(image_id) + if not image: + return ImageService.json_error('图片不存在', 404) + + task = image.task + if not task or task.user_id != current_user_id: + return ImageService.json_error('无权限访问该图片', 403) + + if not os.path.exists(image.file_path): + return ImageService.json_error('图片文件不存在', 404) + + ext = os.path.splitext(image.file_path)[1].lower() + mime_types = { + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.gif': 'image/gif', + '.bmp': 'image/bmp', + '.webp': 'image/webp', + } + mimetype = mime_types.get(ext, 'application/octet-stream') + + return send_file(image.file_path, mimetype=mimetype) + + +# ==================== 任务图片获取(返回 base64) ==================== + +@image_bp.route('/task/', methods=['GET']) +@int_jwt_required +def get_task_images(task_id, current_user_id): + """获取任务的所有图片(base64格式)""" + task = TaskService.load_task_for_user(task_id, current_user_id) + if not task: + return ImageService.json_error('任务不存在或无权限', 404) + + image_type_code = request.args.get('type') + + query = Image.query.filter_by(task_id=task_id) + if image_type_code: + image_type = ImageType.query.filter_by(image_code=image_type_code).first() + if image_type: + query = query.filter_by(image_types_id=image_type.image_types_id) + + images = query.all() + + return jsonify({ + 'task_id': task_id, + 'images': [image_to_base64(img) for img in images], + 'total': len(images) + }), 200 + -@image_bp.route('/perturbation//download', methods=['GET']) +@image_bp.route('/perturbation/', methods=['GET']) @int_jwt_required -def download_perturbation_result(task_id, current_user_id): +def get_perturbation_images(task_id, current_user_id): + """获取加噪任务的结果图片(base64格式)""" task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation') if not task: return ImageService.json_error('任务不存在或无权限', 404) - directory = TaskService.get_perturbed_images_path(task.user_id, task.flow_id) - zipped, has_files = ImageService.zip_directory(directory) - if not has_files: - return ImageService.json_error('结果文件不存在', 404) + perturbed_type = ImageType.query.filter_by(image_code='perturbed').first() + if not perturbed_type: + return ImageService.json_error('图片类型未配置', 500) - filename = f"perturbation_{task_id}.zip" - return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + images = Image.query.filter_by( + task_id=task_id, + image_types_id=perturbed_type.image_types_id + ).all() + + return jsonify({ + 'task_id': task_id, + 'task_type': 'perturbation', + 'images': [image_to_base64(img) for img in images], + 'total': len(images) + }), 200 -@image_bp.route('/heatmap//download', methods=['GET']) +@image_bp.route('/heatmap/', methods=['GET']) @int_jwt_required -def download_heatmap_result(task_id, current_user_id): +def get_heatmap_images(task_id, current_user_id): + """获取热力图任务的结果图片(base64格式)""" task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap') if not task: return ImageService.json_error('任务不存在或无权限', 404) - directory = TaskService.get_heatmap_path(task.user_id, task.flow_id, task.tasks_id) - zipped, has_files = ImageService.zip_directory(directory) - if not has_files: - return ImageService.json_error('热力图文件不存在', 404) + heatmap_type = ImageType.query.filter_by(image_code='heatmap').first() + if not heatmap_type: + return ImageService.json_error('图片类型未配置', 500) - filename = f"heatmap_{task_id}.zip" - return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + images = Image.query.filter_by( + task_id=task_id, + image_types_id=heatmap_type.image_types_id + ).all() + return jsonify({ + 'task_id': task_id, + 'task_type': 'heatmap', + 'images': [image_to_base64(img) for img in images], + 'total': len(images) + }), 200 -@image_bp.route('/finetune//download', methods=['GET']) + +@image_bp.route('/finetune/', methods=['GET']) @int_jwt_required -def download_finetune_result(task_id, current_user_id): +def get_finetune_images(task_id, current_user_id): + """获取微调任务的生成图片(base64格式)""" task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune') if not task: return ImageService.json_error('任务不存在或无权限', 404) @@ -94,35 +170,89 @@ def download_finetune_result(task_id, current_user_id): source = TaskService.determine_finetune_source(task) except ValueError as exc: return ImageService.json_error(str(exc), 500) + + result = {'task_id': task_id, 'task_type': 'finetune', 'source': source} + if source == 'perturbation': - directories = { - 'original_generate': TaskService.get_original_generated_path(task.user_id, task.flow_id, task.tasks_id), - 'perturbed_generate': TaskService.get_perturbed_generated_path(task.user_id, task.flow_id, task.tasks_id) - } + original_gen_type = ImageType.query.filter_by(image_code='original_generate').first() + perturbed_gen_type = ImageType.query.filter_by(image_code='perturbed_generate').first() + + original_images = [] + perturbed_images = [] + + if original_gen_type: + original_images = Image.query.filter_by( + task_id=task_id, + image_types_id=original_gen_type.image_types_id + ).all() + + if perturbed_gen_type: + perturbed_images = Image.query.filter_by( + task_id=task_id, + image_types_id=perturbed_gen_type.image_types_id + ).all() + + result['original_generate'] = [image_to_base64(img) for img in original_images] + result['perturbed_generate'] = [image_to_base64(img) for img in perturbed_images] + result['total'] = len(original_images) + len(perturbed_images) else: - directories = { - 'uploaded_generate': TaskService.get_uploaded_generated_path(task.user_id, task.flow_id, task.tasks_id) - } + uploaded_gen_type = ImageType.query.filter_by(image_code='uploaded_generate').first() + uploaded_images = [] - zipped, has_files = ImageService.zip_multiple_directories(directories) - if not has_files: - return ImageService.json_error('微调结果文件不存在', 404) + if uploaded_gen_type: + uploaded_images = Image.query.filter_by( + task_id=task_id, + image_types_id=uploaded_gen_type.image_types_id + ).all() - filename = f"finetune_{task_id}.zip" - return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + result['uploaded_generate'] = [image_to_base64(img) for img in uploaded_images] + result['total'] = len(uploaded_images) + return jsonify(result), 200 -@image_bp.route('/evaluate//download', methods=['GET']) + +@image_bp.route('/evaluate/', methods=['GET']) @int_jwt_required -def download_evaluate_result(task_id, current_user_id): +def get_evaluate_images(task_id, current_user_id): + """获取评估任务的结果图片(base64格式)""" task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate') if not task: return ImageService.json_error('任务不存在或无权限', 404) - directory = TaskService.get_evaluate_path(task.user_id, task.flow_id, task.tasks_id) - zipped, has_files = ImageService.zip_directory(directory) - if not has_files: - return ImageService.json_error('评估结果文件不存在', 404) + report_type = ImageType.query.filter_by(image_code='report').first() + if not report_type: + return ImageService.json_error('图片类型未配置', 500) + + images = Image.query.filter_by( + task_id=task_id, + image_types_id=report_type.image_types_id + ).all() + + return jsonify({ + 'task_id': task_id, + 'task_type': 'evaluate', + 'images': [image_to_base64(img) for img in images], + 'total': len(images) + }), 200 + + +# ==================== 图片删除 ==================== + +@image_bp.route('/', methods=['DELETE']) +@int_jwt_required +def delete_image(image_id, current_user_id): + """删除单张图片""" + image = Image.query.get(image_id) + if not image: + return ImageService.json_error('图片不存在', 404) + + task = image.task + if not task or task.user_id != current_user_id: + return ImageService.json_error('无权限删除该图片', 403) + + result = ImageService.delete_image(image_id, current_user_id) + if not result.get('success'): + return ImageService.json_error(result.get('error', '删除失败'), 500) - filename = f"evaluate_{task_id}.zip" - return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip') + return jsonify({'message': '图片删除成功'}), 200 + \ No newline at end of file diff --git a/src/backend/app/services/image_service.py b/src/backend/app/services/image_service.py index 933ad78..ce88ab4 100644 --- a/src/backend/app/services/image_service.py +++ b/src/backend/app/services/image_service.py @@ -7,7 +7,6 @@ import io import os import uuid import zipfile -import fcntl import time from datetime import datetime from werkzeug.utils import secure_filename @@ -425,4 +424,32 @@ class ImageService: 'width': image.width, 'height': image.height, 'image_type': image.image_type.image_code if image.image_type else None + } + + @staticmethod + def image_to_base64(image): + """将图片转换为 base64 编码""" + if not image or not os.path.exists(image.file_path): + return None + + ext = os.path.splitext(image.file_path)[1].lower() + mime_types = { + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.gif': 'image/gif', + '.bmp': 'image/bmp', + '.webp': 'image/webp', + } + mimetype = mime_types.get(ext, 'image/png') + + with open(image.file_path, 'rb') as f: + data = base64.b64encode(f.read()).decode('utf-8') + + return { + 'image_id': image.images_id, + 'filename': image.stored_filename, + 'data': f'data:{mimetype};base64,{data}', + 'width': image.width, + 'height': image.height } \ No newline at end of file