图片接口更改完成 #12

Merged
ppy4sjqvf merged 3 commits from ybw-branch into develop 1 month ago

@ -8,6 +8,7 @@ from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager
from flask_cors import CORS
from flask_mail import Mail
from config.settings import Config
# 初始化扩展

@ -61,8 +61,9 @@ def get_user_detail(user_id):
total_tasks = Task.query.filter_by(user_id=user_id).count()
# 查找用户的所有图片
user_tasks = Task.query.filter_by(user_id=user_id).all()
task_ids = [task.task_id for task in user_tasks]
total_images = Image.query.filter_by(task_id in (task_ids)).count()
task_ids = [task.tasks_id for task in user_tasks]
total_images = Image.query.filter(Image.task_id.in_(task_ids)).count() if task_ids else 0
user_dict = user.to_dict()
user_dict['stats'] = {
@ -101,11 +102,15 @@ def create_user():
if email and User.query.filter_by(email=email).first():
return jsonify({'error': '邮箱已被使用'}), 400
# 角色映射
role_map = {'admin': 1, 'vip': 2, 'normal': 3, 'user': 3}
role_id = role_map.get(role_code, 3)
# 创建用户
user = User(
username=username,
email=email,
role_id=User.role_to_id(role),
role_id=role_id,
)
user.set_password(password)
@ -200,17 +205,25 @@ def delete_user(user_id):
def get_system_stats():
"""获取系统统计信息"""
try:
from app.database import EvaluationResult
from app.database import TaskStatus
total_users = User.query.count()
active_users = User.query.filter_by(is_active = True).count()
admin_users = User.query.filter_by(role_id = 0).count()
active_users = User.query.filter_by(is_active=True).count()
admin_users = User.query.filter_by(role_id=1).count()
total_tasks = Task.query.count()
completed_tasks = Task.query.filter_by(status='completed').count()
processing_tasks = Task.query.filter_by(status='processing').count()
failed_tasks = Task.query.filter_by(status='failed').count()
waiting_tasks = Task.query.filter_by(status='waiting').count()
# 通过 TaskStatus 表查询各状态的任务数
def count_tasks_by_status(status_code):
status = TaskStatus.query.filter_by(task_status_code=status_code).first()
if status:
return Task.query.filter_by(tasks_status_id=status.task_status_id).count()
return 0
completed_tasks = count_tasks_by_status('completed')
processing_tasks = count_tasks_by_status('processing')
failed_tasks = count_tasks_by_status('failed')
waiting_tasks = count_tasks_by_status('waiting')
total_images = Image.query.count()

@ -25,7 +25,7 @@ def int_jwt_required(f):
auth_bp = Blueprint('auth', __name__)
@auth_bp.route('/code', methods=['GET'])
def send_email_verification_code(email: str = "3310207578@qq.com", purpose: str = 'register'):
def send_email_verification_code(email: str = "3310207578@qq.com", purpose: str = 'register'):
email = "3310207578@qq.com"
send_verification_code(email, purpose=purpose)
return jsonify({'message': '验证码已发送'}), 200
@ -60,15 +60,15 @@ def register():
if not code or not verify_code(email, code, purpose='register'):
return jsonify({'error': '验证码无效或已过期'}), 400
# 创建用户
user = User(username=username, email=email)
# 创建用户默认为普通用户role_id=3
user = User(username=username, email=email, role_id=3)
user.set_password(password)
db.session.add(user)
db.session.commit()
# 创建用户默认配置
user_config = UserConfig(user_id=user.id)
user_config = UserConfig(user_id=user.user_id)
db.session.add(user_config)
db.session.commit()
@ -102,7 +102,7 @@ def login():
return jsonify({'error': '账户已被禁用'}), 401
# 创建访问令牌 - 确保用户ID为字符串类型
access_token = create_access_token(identity=str(user.id))
access_token = create_access_token(identity=str(user.user_id))
return jsonify({
'message': '登录成功',
@ -144,6 +144,57 @@ def change_password(current_user_id):
db.session.rollback()
return jsonify({'error': f'密码修改失败: {str(e)}'}), 500
@auth_bp.route('/change-email', methods = ['POST'])
@int_jwt_required
def change_email(current_user_id)
"""修改邮箱"""
try:
user = User.query.filter_by(current_user_id)
data = request.get_json()
new_email = data.get('new_email')
code = data.get('code')
if not new_email:
return jsonify({'error': '新邮箱不能为空'}), 400
if not User.query.filter(new_email).first():
return jsonify({'error':'该邮箱已被使用'}) 400
if not code or not verify_code(email, code, purpose='register'):
return jsonify({'error': '验证码无效或已过期'}), 400
user.email = new_email
db.session.commit()
return jsonify({'message': '邮箱修改成功'}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'邮箱修改失败: {str(e)}'}), 500
@auth_bp.route('/change-username', methods = ['POST'])
@int_jwt_required
def change_username(current_user_id)
"""修改用户名"""
try:
user = User.query.filter_by(current_user_id)
data = request.get_json()
new_username = data.get('new_username')
if not new_username:
return jsonify({'error': '新名称不能为空'}), 400
if not User.query.filter(new_username).first():
return jsonify({'error':'该用户名已被使用'}) 400
user.name = new_username
db.session.commit()
return jsonify({'message': '用户名修改成功'}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'用户名修改失败: {str(e)}'}), 500
@auth_bp.route('/profile', methods=['GET'])
@int_jwt_required
def get_profile(current_user_id):
@ -163,4 +214,4 @@ def get_profile(current_user_id):
@jwt_required()
def logout():
"""用户登出客户端删除token即可"""
return jsonify({'message': '登出成功'}), 200
return jsonify({'message': '登出成功'}), 200

@ -1,14 +1,15 @@
"""
图像管理控制器
负责图片上传下载等操作
负责图片上传查询获取等操作
"""
import os
import base64
from flask import Blueprint, request, jsonify, send_file
from app.controllers.auth_controller import int_jwt_required
from app.services.task_service import TaskService
from app.services.image_service import ImageService
from app.database import Image, ImageType
image_bp = Blueprint('image', __name__)
@ -41,48 +42,123 @@ def upload_original_images(current_user_id):
return jsonify({
'message': '图片上传成功',
'images': [ImageService.serialize_image(img) for img in result],
'images': [image_to_base64(img) for img in result],
'flow_id': task.flow_id
}), 201
# ==================== 结果下载 ====================
# ==================== 单张图片获取 ====================
@image_bp.route('/file/<int:image_id>', methods=['GET'])
@int_jwt_required
def get_image_file(image_id, current_user_id):
"""获取单张图片文件(直接返回图片二进制)"""
image = Image.query.get(image_id)
if not image:
return ImageService.json_error('图片不存在', 404)
task = image.task
if not task or task.user_id != current_user_id:
return ImageService.json_error('无权限访问该图片', 403)
if not os.path.exists(image.file_path):
return ImageService.json_error('图片文件不存在', 404)
ext = os.path.splitext(image.file_path)[1].lower()
mime_types = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp',
}
mimetype = mime_types.get(ext, 'application/octet-stream')
return send_file(image.file_path, mimetype=mimetype)
# ==================== 任务图片获取(返回 base64 ====================
@image_bp.route('/task/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_task_images(task_id, current_user_id):
"""获取任务的所有图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id)
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
image_type_code = request.args.get('type')
query = Image.query.filter_by(task_id=task_id)
if image_type_code:
image_type = ImageType.query.filter_by(image_code=image_type_code).first()
if image_type:
query = query.filter_by(image_types_id=image_type.image_types_id)
images = query.all()
return jsonify({
'task_id': task_id,
'images': [image_to_base64(img) for img in images],
'total': len(images)
}), 200
@image_bp.route('/perturbation/<int:task_id>/download', methods=['GET'])
@image_bp.route('/perturbation/<int:task_id>', methods=['GET'])
@int_jwt_required
def download_perturbation_result(task_id, current_user_id):
def get_perturbation_images(task_id, current_user_id):
"""获取加噪任务的结果图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='perturbation')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_perturbed_images_path(task.user_id, task.flow_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('结果文件不存在', 404)
perturbed_type = ImageType.query.filter_by(image_code='perturbed').first()
if not perturbed_type:
return ImageService.json_error('图片类型未配置', 500)
filename = f"perturbation_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
images = Image.query.filter_by(
task_id=task_id,
image_types_id=perturbed_type.image_types_id
).all()
return jsonify({
'task_id': task_id,
'task_type': 'perturbation',
'images': [image_to_base64(img) for img in images],
'total': len(images)
}), 200
@image_bp.route('/heatmap/<int:task_id>/download', methods=['GET'])
@image_bp.route('/heatmap/<int:task_id>', methods=['GET'])
@int_jwt_required
def download_heatmap_result(task_id, current_user_id):
def get_heatmap_images(task_id, current_user_id):
"""获取热力图任务的结果图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='heatmap')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_heatmap_path(task.user_id, task.flow_id, task.tasks_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('热力图文件不存在', 404)
heatmap_type = ImageType.query.filter_by(image_code='heatmap').first()
if not heatmap_type:
return ImageService.json_error('图片类型未配置', 500)
filename = f"heatmap_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
images = Image.query.filter_by(
task_id=task_id,
image_types_id=heatmap_type.image_types_id
).all()
return jsonify({
'task_id': task_id,
'task_type': 'heatmap',
'images': [image_to_base64(img) for img in images],
'total': len(images)
}), 200
@image_bp.route('/finetune/<int:task_id>/download', methods=['GET'])
@image_bp.route('/finetune/<int:task_id>', methods=['GET'])
@int_jwt_required
def download_finetune_result(task_id, current_user_id):
def get_finetune_images(task_id, current_user_id):
"""获取微调任务的生成图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='finetune')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
@ -94,35 +170,89 @@ def download_finetune_result(task_id, current_user_id):
source = TaskService.determine_finetune_source(task)
except ValueError as exc:
return ImageService.json_error(str(exc), 500)
result = {'task_id': task_id, 'task_type': 'finetune', 'source': source}
if source == 'perturbation':
directories = {
'original_generate': TaskService.get_original_generated_path(task.user_id, task.flow_id, task.tasks_id),
'perturbed_generate': TaskService.get_perturbed_generated_path(task.user_id, task.flow_id, task.tasks_id)
}
original_gen_type = ImageType.query.filter_by(image_code='original_generate').first()
perturbed_gen_type = ImageType.query.filter_by(image_code='perturbed_generate').first()
original_images = []
perturbed_images = []
if original_gen_type:
original_images = Image.query.filter_by(
task_id=task_id,
image_types_id=original_gen_type.image_types_id
).all()
if perturbed_gen_type:
perturbed_images = Image.query.filter_by(
task_id=task_id,
image_types_id=perturbed_gen_type.image_types_id
).all()
result['original_generate'] = [image_to_base64(img) for img in original_images]
result['perturbed_generate'] = [image_to_base64(img) for img in perturbed_images]
result['total'] = len(original_images) + len(perturbed_images)
else:
directories = {
'uploaded_generate': TaskService.get_uploaded_generated_path(task.user_id, task.flow_id, task.tasks_id)
}
uploaded_gen_type = ImageType.query.filter_by(image_code='uploaded_generate').first()
uploaded_images = []
zipped, has_files = ImageService.zip_multiple_directories(directories)
if not has_files:
return ImageService.json_error('微调结果文件不存在', 404)
if uploaded_gen_type:
uploaded_images = Image.query.filter_by(
task_id=task_id,
image_types_id=uploaded_gen_type.image_types_id
).all()
filename = f"finetune_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
result['uploaded_generate'] = [image_to_base64(img) for img in uploaded_images]
result['total'] = len(uploaded_images)
return jsonify(result), 200
@image_bp.route('/evaluate/<int:task_id>/download', methods=['GET'])
@image_bp.route('/evaluate/<int:task_id>', methods=['GET'])
@int_jwt_required
def download_evaluate_result(task_id, current_user_id):
def get_evaluate_images(task_id, current_user_id):
"""获取评估任务的结果图片base64格式"""
task = TaskService.load_task_for_user(task_id, current_user_id, expected_type='evaluate')
if not task:
return ImageService.json_error('任务不存在或无权限', 404)
directory = TaskService.get_evaluate_path(task.user_id, task.flow_id, task.tasks_id)
zipped, has_files = ImageService.zip_directory(directory)
if not has_files:
return ImageService.json_error('评估结果文件不存在', 404)
report_type = ImageType.query.filter_by(image_code='report').first()
if not report_type:
return ImageService.json_error('图片类型未配置', 500)
images = Image.query.filter_by(
task_id=task_id,
image_types_id=report_type.image_types_id
).all()
return jsonify({
'task_id': task_id,
'task_type': 'evaluate',
'images': [image_to_base64(img) for img in images],
'total': len(images)
}), 200
# ==================== 图片删除 ====================
@image_bp.route('/<int:image_id>', methods=['DELETE'])
@int_jwt_required
def delete_image(image_id, current_user_id):
"""删除单张图片"""
image = Image.query.get(image_id)
if not image:
return ImageService.json_error('图片不存在', 404)
task = image.task
if not task or task.user_id != current_user_id:
return ImageService.json_error('无权限删除该图片', 403)
result = ImageService.delete_image(image_id, current_user_id)
if not result.get('success'):
return ImageService.json_error(result.get('error', '删除失败'), 500)
filename = f"evaluate_{task_id}.zip"
return send_file(zipped, download_name=filename, as_attachment=True, mimetype='application/zip')
return jsonify({'message': '图片删除成功'}), 200

@ -7,7 +7,6 @@ import io
import os
import uuid
import zipfile
import fcntl
import time
from datetime import datetime
from werkzeug.utils import secure_filename
@ -425,4 +424,32 @@ class ImageService:
'width': image.width,
'height': image.height,
'image_type': image.image_type.image_code if image.image_type else None
}
@staticmethod
def image_to_base64(image):
"""将图片转换为 base64 编码"""
if not image or not os.path.exists(image.file_path):
return None
ext = os.path.splitext(image.file_path)[1].lower()
mime_types = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp',
}
mimetype = mime_types.get(ext, 'image/png')
with open(image.file_path, 'rb') as f:
data = base64.b64encode(f.read()).decode('utf-8')
return {
'image_id': image.images_id,
'filename': image.stored_filename,
'data': f'data:{mimetype};base64,{data}',
'width': image.width,
'height': image.height
}
Loading…
Cancel
Save