将lianghao_branch合并到develop #1
Merged
hnu202326010204
merged 3 commits from lianghao_branch into develop 2 months ago
@ -0,0 +1,9 @@
|
||||
__pycache__/
|
||||
|
||||
venv/
|
||||
|
||||
*.png
|
||||
*.jpg
|
||||
*.jpeg
|
||||
|
||||
.env
|
||||
@ -0,0 +1,46 @@
|
||||
"""
|
||||
MuseGuard 后端主应用入口
|
||||
基于对抗性扰动的多风格图像生成防护系统
|
||||
"""
|
||||
|
||||
from flask import Flask
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_cors import CORS
|
||||
from config.settings import Config
|
||||
|
||||
# 初始化扩展
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
jwt = JWTManager()
|
||||
|
||||
def create_app(config_class=Config):
|
||||
"""Flask应用工厂函数"""
|
||||
app = Flask(__name__)
|
||||
app.config.from_object(config_class)
|
||||
|
||||
# 初始化扩展
|
||||
db.init_app(app)
|
||||
migrate.init_app(app, db)
|
||||
jwt.init_app(app)
|
||||
CORS(app)
|
||||
|
||||
# 注册蓝图
|
||||
from app.controllers.auth_controller import auth_bp
|
||||
from app.controllers.user_controller import user_bp
|
||||
from app.controllers.task_controller import task_bp
|
||||
from app.controllers.image_controller import image_bp
|
||||
from app.controllers.admin_controller import admin_bp
|
||||
|
||||
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
||||
app.register_blueprint(user_bp, url_prefix='/api/user')
|
||||
app.register_blueprint(task_bp, url_prefix='/api/task')
|
||||
app.register_blueprint(image_bp, url_prefix='/api/image')
|
||||
app.register_blueprint(admin_bp, url_prefix='/api/admin')
|
||||
|
||||
return app
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = create_app()
|
||||
app.run(debug=True, host='0.0.0.0', port=5000)
|
||||
@ -0,0 +1,83 @@
|
||||
"""
|
||||
MuseGuard 后端应用包
|
||||
基于对抗性扰动的多风格图像生成防护系统
|
||||
"""
|
||||
|
||||
from flask import Flask
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_cors import CORS
|
||||
import os
|
||||
|
||||
# 初始化扩展
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
jwt = JWTManager()
|
||||
cors = CORS()
|
||||
|
||||
def create_app(config_name=None):
|
||||
"""应用工厂函数"""
|
||||
# 配置静态文件和模板文件路径
|
||||
app = Flask(__name__,
|
||||
static_folder='../static',
|
||||
static_url_path='/static')
|
||||
|
||||
# 加载配置
|
||||
if config_name is None:
|
||||
config_name = os.environ.get('FLASK_ENV', 'development')
|
||||
|
||||
from config.settings import config
|
||||
app.config.from_object(config[config_name])
|
||||
|
||||
# 初始化扩展
|
||||
db.init_app(app)
|
||||
migrate.init_app(app, db)
|
||||
jwt.init_app(app)
|
||||
cors.init_app(app)
|
||||
|
||||
# 注册蓝图
|
||||
from app.controllers.auth_controller import auth_bp
|
||||
from app.controllers.user_controller import user_bp
|
||||
from app.controllers.task_controller import task_bp
|
||||
from app.controllers.image_controller import image_bp
|
||||
from app.controllers.admin_controller import admin_bp
|
||||
from app.controllers.demo_controller import demo_bp
|
||||
|
||||
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
||||
app.register_blueprint(user_bp, url_prefix='/api/user')
|
||||
app.register_blueprint(task_bp, url_prefix='/api/task')
|
||||
app.register_blueprint(image_bp, url_prefix='/api/image')
|
||||
app.register_blueprint(admin_bp, url_prefix='/api/admin')
|
||||
app.register_blueprint(demo_bp, url_prefix='/api/demo')
|
||||
|
||||
# 注册错误处理器
|
||||
@app.errorhandler(404)
|
||||
def not_found_error(error):
|
||||
return {'error': 'Not found'}, 404
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(error):
|
||||
db.session.rollback()
|
||||
return {'error': 'Internal server error'}, 500
|
||||
|
||||
# 根路由
|
||||
@app.route('/')
|
||||
def index():
|
||||
return {
|
||||
'message': 'MuseGuard API Server',
|
||||
'version': '1.0.0',
|
||||
'status': 'running',
|
||||
'endpoints': {
|
||||
'health': '/health',
|
||||
'api_docs': '/api',
|
||||
'test_page': '/static/test.html'
|
||||
}
|
||||
}
|
||||
|
||||
# 健康检查端点
|
||||
@app.route('/health')
|
||||
def health_check():
|
||||
return {'status': 'healthy', 'message': 'MuseGuard backend is running'}
|
||||
|
||||
return app
|
||||
@ -0,0 +1,176 @@
|
||||
"""
|
||||
对抗性扰动算法引擎
|
||||
实现各种加噪算法的虚拟版本
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image, ImageEnhance, ImageFilter
|
||||
import uuid
|
||||
from flask import current_app
|
||||
|
||||
class PerturbationEngine:
|
||||
"""对抗性扰动处理引擎"""
|
||||
|
||||
@staticmethod
|
||||
def apply_perturbation(image_path, algorithm, epsilon, use_strong_protection=False, output_path=None):
|
||||
"""
|
||||
应用对抗性扰动
|
||||
|
||||
Args:
|
||||
image_path: 原始图片路径
|
||||
algorithm: 算法名称 (simac, caat, pid)
|
||||
epsilon: 扰动强度
|
||||
use_strong_protection: 是否使用防净化版本
|
||||
|
||||
Returns:
|
||||
处理后图片的路径
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(image_path):
|
||||
raise FileNotFoundError(f"图片文件不存在: {image_path}")
|
||||
|
||||
# 加载图片
|
||||
with Image.open(image_path) as img:
|
||||
# 转换为RGB模式
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
|
||||
# 根据算法选择处理方法
|
||||
if algorithm == 'simac':
|
||||
perturbed_img = PerturbationEngine._apply_simac(img, epsilon, use_strong_protection)
|
||||
elif algorithm == 'caat':
|
||||
perturbed_img = PerturbationEngine._apply_caat(img, epsilon, use_strong_protection)
|
||||
elif algorithm == 'pid':
|
||||
perturbed_img = PerturbationEngine._apply_pid(img, epsilon, use_strong_protection)
|
||||
else:
|
||||
raise ValueError(f"不支持的算法: {algorithm}")
|
||||
|
||||
# 使用输入的output_path参数
|
||||
if output_path is None:
|
||||
# 如果没有提供输出路径,使用默认路径
|
||||
from flask import current_app
|
||||
project_root = os.path.dirname(current_app.root_path)
|
||||
perturbed_dir = os.path.join(project_root, current_app.config['PERTURBED_IMAGES_FOLDER'])
|
||||
os.makedirs(perturbed_dir, exist_ok=True)
|
||||
|
||||
file_extension = os.path.splitext(image_path)[1]
|
||||
output_filename = f"perturbed_{uuid.uuid4().hex[:8]}{file_extension}"
|
||||
output_path = os.path.join(perturbed_dir, output_filename)
|
||||
|
||||
# 保存处理后的图片
|
||||
perturbed_img.save(output_path, quality=95)
|
||||
|
||||
return output_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"应用扰动时出错: {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _apply_simac(img, epsilon, use_strong_protection):
|
||||
"""
|
||||
SimAC算法的虚拟实现
|
||||
Simple Anti-Customization Method for Protecting Face Privacy
|
||||
"""
|
||||
# 将PIL图像转换为numpy数组
|
||||
img_array = np.array(img, dtype=np.float32)
|
||||
|
||||
# 生成随机噪声(模拟对抗性扰动)
|
||||
noise_scale = epsilon / 255.0
|
||||
|
||||
if use_strong_protection:
|
||||
# 防净化版本:添加更复杂的扰动模式
|
||||
noise = np.random.normal(0, noise_scale * 0.8, img_array.shape)
|
||||
# 添加结构化噪声
|
||||
h, w = img_array.shape[:2]
|
||||
for i in range(0, h, 8):
|
||||
for j in range(0, w, 8):
|
||||
block_noise = np.random.normal(0, noise_scale * 0.4, (min(8, h-i), min(8, w-j), 3))
|
||||
noise[i:i+8, j:j+8] += block_noise
|
||||
else:
|
||||
# 标准版本:简单高斯噪声
|
||||
noise = np.random.normal(0, noise_scale, img_array.shape)
|
||||
|
||||
# 应用噪声
|
||||
perturbed_array = img_array + noise * 255.0
|
||||
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
|
||||
|
||||
# 转换回PIL图像
|
||||
result_img = Image.fromarray(perturbed_array)
|
||||
|
||||
# 轻微的图像增强以模拟算法特性
|
||||
enhancer = ImageEnhance.Contrast(result_img)
|
||||
result_img = enhancer.enhance(1.02)
|
||||
|
||||
return result_img
|
||||
|
||||
@staticmethod
|
||||
def _apply_caat(img, epsilon, use_strong_protection):
|
||||
"""
|
||||
CAAT算法的虚拟实现
|
||||
Perturbing Attention Gives You More Bang for the Buck
|
||||
"""
|
||||
img_array = np.array(img, dtype=np.float32)
|
||||
|
||||
noise_scale = epsilon / 255.0
|
||||
|
||||
if use_strong_protection:
|
||||
# 防净化版本:注意力区域重点扰动
|
||||
# 模拟注意力图(简单的边缘检测)
|
||||
gray_img = img.convert('L')
|
||||
edge_img = gray_img.filter(ImageFilter.FIND_EDGES)
|
||||
attention_map = np.array(edge_img, dtype=np.float32) / 255.0
|
||||
|
||||
# 在注意力区域添加更强的噪声
|
||||
noise = np.random.normal(0, noise_scale * 0.6, img_array.shape)
|
||||
for c in range(3):
|
||||
noise[:,:,c] += attention_map * np.random.normal(0, noise_scale * 0.8, attention_map.shape)
|
||||
else:
|
||||
# 标准版本:均匀分布噪声
|
||||
noise = np.random.uniform(-noise_scale, noise_scale, img_array.shape)
|
||||
|
||||
perturbed_array = img_array + noise * 255.0
|
||||
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
|
||||
|
||||
result_img = Image.fromarray(perturbed_array)
|
||||
|
||||
# 轻微模糊以模拟注意力扰动效果
|
||||
result_img = result_img.filter(ImageFilter.BoxBlur(0.5))
|
||||
|
||||
return result_img
|
||||
|
||||
@staticmethod
|
||||
def _apply_pid(img, epsilon, use_strong_protection):
|
||||
"""
|
||||
PID算法的虚拟实现
|
||||
Prompt-Independent Data Protection Against Latent Diffusion Models
|
||||
"""
|
||||
img_array = np.array(img, dtype=np.float32)
|
||||
|
||||
noise_scale = epsilon / 255.0
|
||||
|
||||
if use_strong_protection:
|
||||
# 防净化版本:频域扰动
|
||||
# 简单的频域变换模拟
|
||||
noise = np.random.laplace(0, noise_scale * 0.7, img_array.shape)
|
||||
# 添加周期性扰动
|
||||
h, w = img_array.shape[:2]
|
||||
for i in range(h):
|
||||
for j in range(w):
|
||||
periodic_noise = noise_scale * 0.3 * np.sin(i * 0.1) * np.cos(j * 0.1)
|
||||
noise[i, j] += periodic_noise
|
||||
else:
|
||||
# 标准版本:拉普拉斯噪声
|
||||
noise = np.random.laplace(0, noise_scale * 0.5, img_array.shape)
|
||||
|
||||
perturbed_array = img_array + noise * 255.0
|
||||
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
|
||||
|
||||
result_img = Image.fromarray(perturbed_array)
|
||||
|
||||
# 色彩微调以模拟潜在空间扰动
|
||||
enhancer = ImageEnhance.Color(result_img)
|
||||
result_img = enhancer.enhance(0.98)
|
||||
|
||||
return result_img
|
||||
@ -0,0 +1,239 @@
|
||||
"""
|
||||
管理员控制器
|
||||
处理管理员功能
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required, get_jwt_identity
|
||||
from app import db
|
||||
from app.models import User, Batch, Image
|
||||
|
||||
admin_bp = Blueprint('admin', __name__)
|
||||
|
||||
def admin_required(f):
|
||||
"""管理员权限装饰器"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
current_user_id = get_jwt_identity()
|
||||
user = User.query.get(current_user_id)
|
||||
|
||||
if not user or user.role != 'admin':
|
||||
return jsonify({'error': '需要管理员权限'}), 403
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
@admin_bp.route('/users', methods=['GET'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def list_users():
|
||||
"""获取用户列表"""
|
||||
try:
|
||||
page = request.args.get('page', 1, type=int)
|
||||
per_page = request.args.get('per_page', 20, type=int)
|
||||
|
||||
users = User.query.paginate(page=page, per_page=per_page, error_out=False)
|
||||
|
||||
return jsonify({
|
||||
'users': [user.to_dict() for user in users.items],
|
||||
'total': users.total,
|
||||
'pages': users.pages,
|
||||
'current_page': page
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户列表失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users/<int:user_id>', methods=['GET'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def get_user_detail(user_id):
|
||||
"""获取用户详情"""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
# 获取用户统计信息
|
||||
total_tasks = Batch.query.filter_by(user_id=user_id).count()
|
||||
total_images = Image.query.filter_by(user_id=user_id).count()
|
||||
|
||||
user_dict = user.to_dict()
|
||||
user_dict['stats'] = {
|
||||
'total_tasks': total_tasks,
|
||||
'total_images': total_images
|
||||
}
|
||||
|
||||
return jsonify({'user': user_dict}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户详情失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users', methods=['POST'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def create_user():
|
||||
"""创建用户"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
email = data.get('email')
|
||||
role = data.get('role', 'user')
|
||||
max_concurrent_tasks = data.get('max_concurrent_tasks', 0)
|
||||
|
||||
if not username or not password:
|
||||
return jsonify({'error': '用户名和密码不能为空'}), 400
|
||||
|
||||
# 检查用户名是否已存在
|
||||
if User.query.filter_by(username=username).first():
|
||||
return jsonify({'error': '用户名已存在'}), 400
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if email and User.query.filter_by(email=email).first():
|
||||
return jsonify({'error': '邮箱已被使用'}), 400
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
role=role,
|
||||
max_concurrent_tasks=max_concurrent_tasks
|
||||
)
|
||||
user.set_password(password)
|
||||
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户创建成功',
|
||||
'user': user.to_dict()
|
||||
}), 201
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'创建用户失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users/<int:user_id>', methods=['PUT'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def update_user(user_id):
|
||||
"""更新用户信息"""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
# 更新字段
|
||||
if 'username' in data:
|
||||
new_username = data['username']
|
||||
if new_username != user.username:
|
||||
if User.query.filter_by(username=new_username).first():
|
||||
return jsonify({'error': '用户名已存在'}), 400
|
||||
user.username = new_username
|
||||
|
||||
if 'email' in data:
|
||||
new_email = data['email']
|
||||
if new_email != user.email:
|
||||
if User.query.filter_by(email=new_email).first():
|
||||
return jsonify({'error': '邮箱已被使用'}), 400
|
||||
user.email = new_email
|
||||
|
||||
if 'role' in data:
|
||||
user.role = data['role']
|
||||
|
||||
if 'max_concurrent_tasks' in data:
|
||||
user.max_concurrent_tasks = data['max_concurrent_tasks']
|
||||
|
||||
if 'is_active' in data:
|
||||
user.is_active = bool(data['is_active'])
|
||||
|
||||
if 'password' in data and data['password']:
|
||||
user.set_password(data['password'])
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户信息更新成功',
|
||||
'user': user.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'更新用户失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users/<int:user_id>', methods=['DELETE'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def delete_user(user_id):
|
||||
"""删除用户"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
# 不能删除自己
|
||||
if user_id == current_user_id:
|
||||
return jsonify({'error': '不能删除自己的账户'}), 400
|
||||
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
# 删除用户(级联删除相关数据)
|
||||
db.session.delete(user)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'message': '用户删除成功'}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'删除用户失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/stats', methods=['GET'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def get_system_stats():
|
||||
"""获取系统统计信息"""
|
||||
try:
|
||||
from app.models import EvaluationResult
|
||||
|
||||
total_users = User.query.count()
|
||||
active_users = User.query.filter_by(is_active=True).count()
|
||||
admin_users = User.query.filter_by(role='admin').count()
|
||||
|
||||
total_tasks = Batch.query.count()
|
||||
completed_tasks = Batch.query.filter_by(status='completed').count()
|
||||
processing_tasks = Batch.query.filter_by(status='processing').count()
|
||||
failed_tasks = Batch.query.filter_by(status='failed').count()
|
||||
|
||||
total_images = Image.query.count()
|
||||
total_evaluations = EvaluationResult.query.count()
|
||||
|
||||
return jsonify({
|
||||
'stats': {
|
||||
'users': {
|
||||
'total': total_users,
|
||||
'active': active_users,
|
||||
'admin': admin_users
|
||||
},
|
||||
'tasks': {
|
||||
'total': total_tasks,
|
||||
'completed': completed_tasks,
|
||||
'processing': processing_tasks,
|
||||
'failed': failed_tasks
|
||||
},
|
||||
'images': {
|
||||
'total': total_images
|
||||
},
|
||||
'evaluations': {
|
||||
'total': total_evaluations
|
||||
}
|
||||
}
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取系统统计失败: {str(e)}'}), 500
|
||||
@ -0,0 +1,400 @@
|
||||
"""
|
||||
任务管理控制器
|
||||
处理创建任务、上传图片等功能
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify, current_app
|
||||
from flask_jwt_extended import jwt_required, get_jwt_identity
|
||||
from werkzeug.utils import secure_filename
|
||||
from app import db
|
||||
from app.models import User, Batch, Image, ImageType, UserConfig
|
||||
from app.services.task_service import TaskService
|
||||
from app.services.image_service import ImageService
|
||||
from app.utils.file_utils import allowed_file, save_uploaded_file
|
||||
import os
|
||||
import zipfile
|
||||
import uuid
|
||||
|
||||
task_bp = Blueprint('task', __name__)
|
||||
|
||||
@task_bp.route('/create', methods=['POST'])
|
||||
@jwt_required()
|
||||
def create_task():
|
||||
"""创建新任务(仅创建任务,使用默认配置)"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
user = User.query.get(current_user_id)
|
||||
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
data = request.get_json()
|
||||
batch_name = data.get('batch_name', f'Task-{uuid.uuid4().hex[:8]}')
|
||||
|
||||
# 使用默认配置创建任务
|
||||
batch = Batch(
|
||||
user_id=current_user_id,
|
||||
batch_name=batch_name,
|
||||
perturbation_config_id=1, # 默认配置
|
||||
preferred_epsilon=8.0, # 默认epsilon
|
||||
finetune_config_id=1, # 默认微调配置
|
||||
use_strong_protection=False # 默认不启用强防护
|
||||
)
|
||||
|
||||
db.session.add(batch)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '任务创建成功,请上传图片',
|
||||
'task': batch.to_dict()
|
||||
}), 201
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'任务创建失败: {str(e)}'}), 500
|
||||
|
||||
@task_bp.route('/upload/<int:batch_id>', methods=['POST'])
|
||||
@jwt_required()
|
||||
def upload_images(batch_id):
|
||||
"""上传图片到指定任务"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
# 检查任务是否存在且属于当前用户
|
||||
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
|
||||
if not batch:
|
||||
return jsonify({'error': '任务不存在或无权限'}), 404
|
||||
|
||||
if batch.status != 'pending':
|
||||
return jsonify({'error': '任务已开始处理,无法上传新图片'}), 400
|
||||
|
||||
if 'files' not in request.files:
|
||||
return jsonify({'error': '没有选择文件'}), 400
|
||||
|
||||
files = request.files.getlist('files')
|
||||
uploaded_files = []
|
||||
|
||||
# 获取原始图片类型ID
|
||||
original_type = ImageType.query.filter_by(type_code='original').first()
|
||||
if not original_type:
|
||||
return jsonify({'error': '系统配置错误:缺少原始图片类型'}), 500
|
||||
|
||||
for file in files:
|
||||
if file.filename == '':
|
||||
continue
|
||||
|
||||
if file and allowed_file(file.filename):
|
||||
# 处理单张图片
|
||||
if not file.filename.lower().endswith(('.zip', '.rar')):
|
||||
result = ImageService.save_image(file, batch_id, current_user_id, original_type.id)
|
||||
if result['success']:
|
||||
uploaded_files.append(result['image'])
|
||||
else:
|
||||
return jsonify({'error': result['error']}), 400
|
||||
|
||||
# 处理压缩包
|
||||
else:
|
||||
results = ImageService.extract_and_save_zip(file, batch_id, current_user_id, original_type.id)
|
||||
for result in results:
|
||||
if result['success']:
|
||||
uploaded_files.append(result['image'])
|
||||
|
||||
if not uploaded_files:
|
||||
return jsonify({'error': '没有有效的图片文件'}), 400
|
||||
|
||||
return jsonify({
|
||||
'message': f'成功上传 {len(uploaded_files)} 张图片',
|
||||
'uploaded_files': [img.to_dict() for img in uploaded_files]
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'文件上传失败: {str(e)}'}), 500
|
||||
|
||||
@task_bp.route('/<int:batch_id>/config', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_task_config(batch_id):
|
||||
"""获取任务配置(显示用户上次的配置或默认配置)"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
# 检查任务是否存在且属于当前用户
|
||||
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
|
||||
if not batch:
|
||||
return jsonify({'error': '任务不存在或无权限'}), 404
|
||||
|
||||
# 获取用户配置
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
# 如果用户有配置,显示用户上次的配置;否则显示当前任务的默认配置
|
||||
if user_config:
|
||||
suggested_config = {
|
||||
'perturbation_config_id': user_config.preferred_perturbation_config_id,
|
||||
'epsilon': float(user_config.preferred_epsilon),
|
||||
'finetune_config_id': user_config.preferred_finetune_config_id,
|
||||
'use_strong_protection': user_config.preferred_purification
|
||||
}
|
||||
else:
|
||||
suggested_config = {
|
||||
'perturbation_config_id': batch.perturbation_config_id,
|
||||
'epsilon': float(batch.preferred_epsilon),
|
||||
'finetune_config_id': batch.finetune_config_id,
|
||||
'use_strong_protection': batch.use_strong_protection
|
||||
}
|
||||
|
||||
return jsonify({
|
||||
'task': batch.to_dict(),
|
||||
'suggested_config': suggested_config,
|
||||
'current_config': {
|
||||
'perturbation_config_id': batch.perturbation_config_id,
|
||||
'epsilon': float(batch.preferred_epsilon),
|
||||
'finetune_config_id': batch.finetune_config_id,
|
||||
'use_strong_protection': batch.use_strong_protection
|
||||
}
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取任务配置失败: {str(e)}'}), 500
|
||||
|
||||
@task_bp.route('/<int:batch_id>/config', methods=['PUT'])
|
||||
@jwt_required()
|
||||
def update_task_config(batch_id):
|
||||
"""更新任务配置"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
# 检查任务是否存在且属于当前用户
|
||||
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
|
||||
if not batch:
|
||||
return jsonify({'error': '任务不存在或无权限'}), 404
|
||||
|
||||
if batch.status != 'pending':
|
||||
return jsonify({'error': '任务已开始处理,无法修改配置'}), 400
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
# 更新任务配置
|
||||
if 'perturbation_config_id' in data:
|
||||
batch.perturbation_config_id = data['perturbation_config_id']
|
||||
|
||||
if 'epsilon' in data:
|
||||
epsilon = float(data['epsilon'])
|
||||
if 0 < epsilon <= 255:
|
||||
batch.preferred_epsilon = epsilon
|
||||
else:
|
||||
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
|
||||
|
||||
if 'finetune_config_id' in data:
|
||||
batch.finetune_config_id = data['finetune_config_id']
|
||||
|
||||
if 'use_strong_protection' in data:
|
||||
batch.use_strong_protection = bool(data['use_strong_protection'])
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# 更新用户配置(保存这次的选择作为下次的默认配置)
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
if not user_config:
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
|
||||
user_config.preferred_perturbation_config_id = batch.perturbation_config_id
|
||||
user_config.preferred_epsilon = batch.preferred_epsilon
|
||||
user_config.preferred_finetune_config_id = batch.finetune_config_id
|
||||
user_config.preferred_purification = batch.use_strong_protection
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '任务配置更新成功',
|
||||
'task': batch.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'更新任务配置失败: {str(e)}'}), 500
|
||||
|
||||
@task_bp.route('/start/<int:batch_id>', methods=['POST'])
|
||||
@jwt_required()
|
||||
def start_task(batch_id):
|
||||
"""开始处理任务"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
# 检查任务是否存在且属于当前用户
|
||||
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
|
||||
if not batch:
|
||||
return jsonify({'error': '任务不存在或无权限'}), 404
|
||||
|
||||
if batch.status != 'pending':
|
||||
return jsonify({'error': '任务状态不正确,无法开始处理'}), 400
|
||||
|
||||
# 检查是否有上传的图片
|
||||
image_count = Image.query.filter_by(batch_id=batch_id).count()
|
||||
if image_count == 0:
|
||||
return jsonify({'error': '请先上传图片'}), 400
|
||||
|
||||
# 启动任务处理
|
||||
success = TaskService.start_processing(batch)
|
||||
|
||||
if success:
|
||||
return jsonify({
|
||||
'message': '任务开始处理',
|
||||
'task': batch.to_dict()
|
||||
}), 200
|
||||
else:
|
||||
return jsonify({'error': '任务启动失败'}), 500
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'任务启动失败: {str(e)}'}), 500
|
||||
|
||||
@task_bp.route('/list', methods=['GET'])
|
||||
@jwt_required()
|
||||
def list_tasks():
|
||||
"""获取用户的任务列表"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
page = request.args.get('page', 1, type=int)
|
||||
per_page = request.args.get('per_page', 10, type=int)
|
||||
|
||||
batches = Batch.query.filter_by(user_id=current_user_id)\
|
||||
.order_by(Batch.created_at.desc())\
|
||||
.paginate(page=page, per_page=per_page, error_out=False)
|
||||
|
||||
return jsonify({
|
||||
'tasks': [batch.to_dict() for batch in batches.items],
|
||||
'total': batches.total,
|
||||
'pages': batches.pages,
|
||||
'current_page': page
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取任务列表失败: {str(e)}'}), 500
|
||||
|
||||
@task_bp.route('/<int:batch_id>', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_task_detail(batch_id):
|
||||
"""获取任务详情"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
|
||||
if not batch:
|
||||
return jsonify({'error': '任务不存在或无权限'}), 404
|
||||
|
||||
# 获取任务相关的图片
|
||||
images = Image.query.filter_by(batch_id=batch_id).all()
|
||||
|
||||
return jsonify({
|
||||
'task': batch.to_dict(),
|
||||
'images': [img.to_dict() for img in images],
|
||||
'image_count': len(images)
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取任务详情失败: {str(e)}'}), 500
|
||||
|
||||
@task_bp.route('/load-config', methods=['GET'])
|
||||
@jwt_required()
|
||||
def load_last_config():
|
||||
"""加载用户上次的配置"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
# 获取用户配置
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
if user_config:
|
||||
config = {
|
||||
'perturbation_config_id': user_config.preferred_perturbation_config_id,
|
||||
'epsilon': float(user_config.preferred_epsilon),
|
||||
'finetune_config_id': user_config.preferred_finetune_config_id,
|
||||
'use_strong_protection': user_config.preferred_purification
|
||||
}
|
||||
return jsonify({
|
||||
'message': '成功加载上次配置',
|
||||
'config': config
|
||||
}), 200
|
||||
else:
|
||||
# 返回默认配置
|
||||
default_config = {
|
||||
'perturbation_config_id': 1,
|
||||
'epsilon': 8.0,
|
||||
'finetune_config_id': 1,
|
||||
'use_strong_protection': False
|
||||
}
|
||||
return jsonify({
|
||||
'message': '使用默认配置',
|
||||
'config': default_config
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'加载配置失败: {str(e)}'}), 500
|
||||
|
||||
@task_bp.route('/save-config', methods=['POST'])
|
||||
@jwt_required()
|
||||
def save_current_config():
|
||||
"""保存当前配置作为用户偏好"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
data = request.get_json()
|
||||
|
||||
# 获取或创建用户配置
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
if not user_config:
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
|
||||
# 更新配置
|
||||
if 'perturbation_config_id' in data:
|
||||
user_config.preferred_perturbation_config_id = data['perturbation_config_id']
|
||||
|
||||
if 'epsilon' in data:
|
||||
epsilon = float(data['epsilon'])
|
||||
if 0 < epsilon <= 255:
|
||||
user_config.preferred_epsilon = epsilon
|
||||
else:
|
||||
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
|
||||
|
||||
if 'finetune_config_id' in data:
|
||||
user_config.preferred_finetune_config_id = data['finetune_config_id']
|
||||
|
||||
if 'use_strong_protection' in data:
|
||||
user_config.preferred_purification = bool(data['use_strong_protection'])
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '配置保存成功',
|
||||
'config': {
|
||||
'perturbation_config_id': user_config.preferred_perturbation_config_id,
|
||||
'epsilon': float(user_config.preferred_epsilon),
|
||||
'finetune_config_id': user_config.preferred_finetune_config_id,
|
||||
'use_strong_protection': user_config.preferred_purification
|
||||
}
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'保存配置失败: {str(e)}'}), 500
|
||||
|
||||
@task_bp.route('/<int:batch_id>/status', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_task_status(batch_id):
|
||||
"""获取任务处理状态"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
|
||||
if not batch:
|
||||
return jsonify({'error': '任务不存在或无权限'}), 404
|
||||
|
||||
return jsonify({
|
||||
'task_id': batch_id,
|
||||
'status': batch.status,
|
||||
'progress': TaskService.get_processing_progress(batch_id),
|
||||
'error_message': batch.error_message
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取任务状态失败: {str(e)}'}), 500
|
||||
@ -0,0 +1,133 @@
|
||||
"""
|
||||
用户管理控制器
|
||||
处理用户配置等功能
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required
|
||||
from app import db
|
||||
from app.models import User, UserConfig, PerturbationConfig, FinetuneConfig
|
||||
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
|
||||
|
||||
user_bp = Blueprint('user', __name__)
|
||||
|
||||
@user_bp.route('/config', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_config(current_user_id):
|
||||
"""获取用户配置"""
|
||||
try:
|
||||
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
if not user_config:
|
||||
# 如果没有配置,创建默认配置
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'config': user_config.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/config', methods=['PUT'])
|
||||
@int_jwt_required
|
||||
def update_user_config(current_user_id):
|
||||
"""更新用户配置"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
if not user_config:
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
|
||||
# 更新配置字段
|
||||
if 'preferred_perturbation_config_id' in data:
|
||||
user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id']
|
||||
|
||||
if 'preferred_epsilon' in data:
|
||||
epsilon = float(data['preferred_epsilon'])
|
||||
if 0 < epsilon <= 255:
|
||||
user_config.preferred_epsilon = epsilon
|
||||
else:
|
||||
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
|
||||
|
||||
if 'preferred_finetune_config_id' in data:
|
||||
user_config.preferred_finetune_config_id = data['preferred_finetune_config_id']
|
||||
|
||||
if 'preferred_purification' in data:
|
||||
user_config.preferred_purification = bool(data['preferred_purification'])
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户配置更新成功',
|
||||
'config': user_config.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/algorithms', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_available_algorithms():
|
||||
"""获取可用的算法列表"""
|
||||
try:
|
||||
perturbation_configs = PerturbationConfig.query.all()
|
||||
finetune_configs = FinetuneConfig.query.all()
|
||||
|
||||
return jsonify({
|
||||
'perturbation_algorithms': [
|
||||
{
|
||||
'id': config.id,
|
||||
'method_code': config.method_code,
|
||||
'method_name': config.method_name,
|
||||
'description': config.description,
|
||||
'default_epsilon': float(config.default_epsilon)
|
||||
} for config in perturbation_configs
|
||||
],
|
||||
'finetune_methods': [
|
||||
{
|
||||
'id': config.id,
|
||||
'method_code': config.method_code,
|
||||
'method_name': config.method_name,
|
||||
'description': config.description
|
||||
} for config in finetune_configs
|
||||
]
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/stats', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_stats(current_user_id):
|
||||
"""获取用户统计信息"""
|
||||
try:
|
||||
from app.models import Batch, Image
|
||||
|
||||
# 统计用户的任务和图片数量
|
||||
total_tasks = Batch.query.filter_by(user_id=current_user_id).count()
|
||||
completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count()
|
||||
processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count()
|
||||
failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count()
|
||||
|
||||
total_images = Image.query.filter_by(user_id=current_user_id).count()
|
||||
|
||||
return jsonify({
|
||||
'stats': {
|
||||
'total_tasks': total_tasks,
|
||||
'completed_tasks': completed_tasks,
|
||||
'processing_tasks': processing_tasks,
|
||||
'failed_tasks': failed_tasks,
|
||||
'total_images': total_images
|
||||
}
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500
|
||||
@ -0,0 +1,233 @@
|
||||
"""
|
||||
数据库模型定义
|
||||
基于已有的schema.sql设计
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from app import db
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
class User(db.Model):
|
||||
"""用户表"""
|
||||
__tablename__ = 'users'
|
||||
|
||||
id = db.Column(db.BigInteger, primary_key=True)
|
||||
username = db.Column(db.String(50), unique=True, nullable=False)
|
||||
password_hash = db.Column(db.String(255), nullable=False)
|
||||
email = db.Column(db.String(100))
|
||||
role = db.Column(db.Enum('user', 'admin'), default='user')
|
||||
max_concurrent_tasks = db.Column(db.Integer, nullable=False, default=0)
|
||||
created_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
is_active = db.Column(db.Boolean, default=True)
|
||||
|
||||
# 关系
|
||||
batches = db.relationship('Batch', backref='user', lazy='dynamic', cascade='all, delete-orphan')
|
||||
images = db.relationship('Image', backref='user', lazy='dynamic', cascade='all, delete-orphan')
|
||||
user_config = db.relationship('UserConfig', backref='user', uselist=False, cascade='all, delete-orphan')
|
||||
|
||||
def set_password(self, password):
|
||||
"""设置密码"""
|
||||
self.password_hash = generate_password_hash(password)
|
||||
|
||||
def check_password(self, password):
|
||||
"""验证密码"""
|
||||
return check_password_hash(self.password_hash, password)
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'id': self.id,
|
||||
'username': self.username,
|
||||
'email': self.email,
|
||||
'role': self.role,
|
||||
'max_concurrent_tasks': self.max_concurrent_tasks,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'is_active': self.is_active
|
||||
}
|
||||
|
||||
class ImageType(db.Model):
|
||||
"""图片类型表"""
|
||||
__tablename__ = 'image_types'
|
||||
|
||||
id = db.Column(db.BigInteger, primary_key=True)
|
||||
type_code = db.Column(db.Enum('original', 'perturbed', 'original_generate', 'perturbed_generate'),
|
||||
unique=True, nullable=False)
|
||||
type_name = db.Column(db.String(100), nullable=False)
|
||||
description = db.Column(db.Text)
|
||||
|
||||
# 关系
|
||||
images = db.relationship('Image', backref='image_type', lazy='dynamic')
|
||||
|
||||
class PerturbationConfig(db.Model):
|
||||
"""加噪算法表"""
|
||||
__tablename__ = 'perturbation_configs'
|
||||
|
||||
id = db.Column(db.BigInteger, primary_key=True)
|
||||
method_code = db.Column(db.String(50), unique=True, nullable=False)
|
||||
method_name = db.Column(db.String(100), nullable=False)
|
||||
description = db.Column(db.Text, nullable=False)
|
||||
default_epsilon = db.Column(db.Numeric(5, 2), nullable=False)
|
||||
|
||||
# 关系
|
||||
batches = db.relationship('Batch', backref='perturbation_config', lazy='dynamic')
|
||||
user_configs = db.relationship('UserConfig', backref='preferred_perturbation_config', lazy='dynamic')
|
||||
|
||||
class FinetuneConfig(db.Model):
|
||||
"""微调方式表"""
|
||||
__tablename__ = 'finetune_configs'
|
||||
|
||||
id = db.Column(db.BigInteger, primary_key=True)
|
||||
method_code = db.Column(db.String(50), unique=True, nullable=False)
|
||||
method_name = db.Column(db.String(100), nullable=False)
|
||||
description = db.Column(db.Text, nullable=False)
|
||||
|
||||
# 关系
|
||||
batches = db.relationship('Batch', backref='finetune_config', lazy='dynamic')
|
||||
user_configs = db.relationship('UserConfig', backref='preferred_finetune_config', lazy='dynamic')
|
||||
|
||||
class Batch(db.Model):
|
||||
"""加噪批次表"""
|
||||
__tablename__ = 'batch'
|
||||
|
||||
id = db.Column(db.BigInteger, primary_key=True)
|
||||
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
|
||||
batch_name = db.Column(db.String(128))
|
||||
|
||||
# 加噪配置
|
||||
perturbation_config_id = db.Column(db.BigInteger, db.ForeignKey('perturbation_configs.id'),
|
||||
nullable=False, default=1)
|
||||
preferred_epsilon = db.Column(db.Numeric(5, 2), nullable=False, default=8.0)
|
||||
|
||||
# 评估配置
|
||||
finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id'),
|
||||
nullable=False, default=1)
|
||||
|
||||
# 净化配置
|
||||
use_strong_protection = db.Column(db.Boolean, nullable=False, default=False)
|
||||
|
||||
# 任务状态
|
||||
status = db.Column(db.Enum('pending', 'processing', 'completed', 'failed'), default='pending')
|
||||
created_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
started_at = db.Column(db.DateTime)
|
||||
completed_at = db.Column(db.DateTime)
|
||||
error_message = db.Column(db.Text)
|
||||
result_path = db.Column(db.String(500))
|
||||
|
||||
# 关系
|
||||
images = db.relationship('Image', backref='batch', lazy='dynamic')
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'id': self.id,
|
||||
'batch_name': self.batch_name,
|
||||
'status': self.status,
|
||||
'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None,
|
||||
'use_strong_protection': self.use_strong_protection,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'started_at': self.started_at.isoformat() if self.started_at else None,
|
||||
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
|
||||
'error_message': self.error_message,
|
||||
'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None,
|
||||
'finetune_config': self.finetune_config.method_name if self.finetune_config else None
|
||||
}
|
||||
|
||||
class Image(db.Model):
|
||||
"""图片表"""
|
||||
__tablename__ = 'images'
|
||||
|
||||
id = db.Column(db.BigInteger, primary_key=True)
|
||||
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
|
||||
batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id'))
|
||||
father_id = db.Column(db.BigInteger, db.ForeignKey('images.id'))
|
||||
original_filename = db.Column(db.String(255))
|
||||
stored_filename = db.Column(db.String(255), unique=True, nullable=False)
|
||||
file_path = db.Column(db.String(500), nullable=False)
|
||||
file_size = db.Column(db.BigInteger)
|
||||
image_type_id = db.Column(db.BigInteger, db.ForeignKey('image_types.id'), nullable=False)
|
||||
width = db.Column(db.Integer)
|
||||
height = db.Column(db.Integer)
|
||||
upload_time = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
|
||||
# 自引用关系
|
||||
children = db.relationship('Image', backref=db.backref('parent', remote_side=[id]), lazy='dynamic')
|
||||
|
||||
# 评估结果关系
|
||||
reference_evaluations = db.relationship('EvaluationResult',
|
||||
foreign_keys='EvaluationResult.reference_image_id',
|
||||
backref='reference_image', lazy='dynamic')
|
||||
target_evaluations = db.relationship('EvaluationResult',
|
||||
foreign_keys='EvaluationResult.target_image_id',
|
||||
backref='target_image', lazy='dynamic')
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'id': self.id,
|
||||
'original_filename': self.original_filename,
|
||||
'stored_filename': self.stored_filename,
|
||||
'file_path': self.file_path,
|
||||
'file_size': self.file_size,
|
||||
'width': self.width,
|
||||
'height': self.height,
|
||||
'upload_time': self.upload_time.isoformat() if self.upload_time else None,
|
||||
'image_type': self.image_type.type_name if self.image_type else None,
|
||||
'batch_id': self.batch_id
|
||||
}
|
||||
|
||||
class EvaluationResult(db.Model):
|
||||
"""评估结果表"""
|
||||
__tablename__ = 'evaluation_results'
|
||||
|
||||
id = db.Column(db.BigInteger, primary_key=True)
|
||||
reference_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False)
|
||||
target_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False)
|
||||
evaluation_type = db.Column(db.Enum('image_quality', 'model_generation'), nullable=False)
|
||||
purification_applied = db.Column(db.Boolean, default=False)
|
||||
fid_score = db.Column(db.Numeric(8, 4))
|
||||
lpips_score = db.Column(db.Numeric(8, 4))
|
||||
ssim_score = db.Column(db.Numeric(8, 4))
|
||||
psnr_score = db.Column(db.Numeric(8, 4))
|
||||
heatmap_path = db.Column(db.String(500))
|
||||
evaluated_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'id': self.id,
|
||||
'evaluation_type': self.evaluation_type,
|
||||
'purification_applied': self.purification_applied,
|
||||
'fid_score': float(self.fid_score) if self.fid_score else None,
|
||||
'lpips_score': float(self.lpips_score) if self.lpips_score else None,
|
||||
'ssim_score': float(self.ssim_score) if self.ssim_score else None,
|
||||
'psnr_score': float(self.psnr_score) if self.psnr_score else None,
|
||||
'heatmap_path': self.heatmap_path,
|
||||
'evaluated_at': self.evaluated_at.isoformat() if self.evaluated_at else None
|
||||
}
|
||||
|
||||
class UserConfig(db.Model):
|
||||
"""用户配置表"""
|
||||
__tablename__ = 'user_configs'
|
||||
|
||||
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), primary_key=True)
|
||||
preferred_perturbation_config_id = db.Column(db.BigInteger,
|
||||
db.ForeignKey('perturbation_configs.id'), default=1)
|
||||
preferred_epsilon = db.Column(db.Numeric(5, 2), default=8.0)
|
||||
preferred_finetune_config_id = db.Column(db.BigInteger,
|
||||
db.ForeignKey('finetune_configs.id'), default=1)
|
||||
preferred_purification = db.Column(db.Boolean, default=False)
|
||||
created_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'user_id': self.user_id,
|
||||
'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None,
|
||||
'preferred_purification': self.preferred_purification,
|
||||
'preferred_perturbation_config': self.preferred_perturbation_config.method_name if self.preferred_perturbation_config else None,
|
||||
'preferred_finetune_config': self.preferred_finetune_config.method_name if self.preferred_finetune_config else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None
|
||||
}
|
||||
@ -0,0 +1,34 @@
|
||||
"""
|
||||
认证服务
|
||||
处理用户认证相关逻辑
|
||||
"""
|
||||
|
||||
from app.models import User
|
||||
|
||||
class AuthService:
|
||||
"""认证服务类"""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(username, password):
|
||||
"""验证用户凭据"""
|
||||
user = User.query.filter_by(username=username).first()
|
||||
|
||||
if user and user.check_password(password) and user.is_active:
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id):
|
||||
"""根据ID获取用户"""
|
||||
return User.query.get(user_id)
|
||||
|
||||
@staticmethod
|
||||
def is_email_available(email):
|
||||
"""检查邮箱是否可用"""
|
||||
return User.query.filter_by(email=email).first() is None
|
||||
|
||||
@staticmethod
|
||||
def is_username_available(username):
|
||||
"""检查用户名是否可用"""
|
||||
return User.query.filter_by(username=username).first() is None
|
||||
@ -0,0 +1,53 @@
|
||||
"""
|
||||
文件处理工具类
|
||||
"""
|
||||
|
||||
import os
|
||||
from werkzeug.utils import secure_filename
|
||||
from flask import current_app
|
||||
|
||||
def allowed_file(filename):
|
||||
"""检查文件扩展名是否被允许"""
|
||||
if not filename:
|
||||
return False
|
||||
|
||||
allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS',
|
||||
{'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'})
|
||||
|
||||
return '.' in filename and \
|
||||
filename.rsplit('.', 1)[1].lower() in allowed_extensions
|
||||
|
||||
def save_uploaded_file(file, upload_path):
|
||||
"""保存上传的文件"""
|
||||
try:
|
||||
if not file or not allowed_file(file.filename):
|
||||
return None
|
||||
|
||||
filename = secure_filename(file.filename)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(upload_path), exist_ok=True)
|
||||
|
||||
file.save(upload_path)
|
||||
return upload_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存文件失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_file_size(file_path):
|
||||
"""获取文件大小"""
|
||||
try:
|
||||
return os.path.getsize(file_path)
|
||||
except:
|
||||
return 0
|
||||
|
||||
def delete_file(file_path):
|
||||
"""删除文件"""
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
@ -0,0 +1,16 @@
|
||||
"""
|
||||
JWT工具函数
|
||||
"""
|
||||
from functools import wraps
|
||||
from flask_jwt_extended import get_jwt_identity
|
||||
|
||||
def int_jwt_required(f):
|
||||
"""获取JWT身份并转换为整数的装饰器"""
|
||||
@wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
jwt_identity = get_jwt_identity()
|
||||
if jwt_identity is not None:
|
||||
# 在函数调用前注入转换后的user_id
|
||||
kwargs['current_user_id'] = int(jwt_identity)
|
||||
return f(*args, **kwargs)
|
||||
return wrapped
|
||||
@ -0,0 +1,16 @@
|
||||
# MuseGuard 环境变量配置文件
|
||||
# 注意:此文件包含敏感信息,不应提交到版本控制系统
|
||||
|
||||
# 数据库配置
|
||||
DB_USER=root
|
||||
DB_PASSWORD=your_password_here
|
||||
DB_HOST=localhost
|
||||
DB_NAME=your_database_name_here
|
||||
|
||||
# Flask配置
|
||||
SECRET_KEY=museguard-secret-key-2024
|
||||
JWT_SECRET_KEY=jwt-secret-string
|
||||
|
||||
# 开发模式
|
||||
FLASK_ENV=development
|
||||
FLASK_DEBUG=True
|
||||
@ -0,0 +1,109 @@
|
||||
"""
|
||||
应用配置文件
|
||||
"""
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from dotenv import load_dotenv
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
# 加载环境变量 - 从 config 目录读取 .env 文件
|
||||
config_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
env_path = os.path.join(config_dir, '.env')
|
||||
load_dotenv(env_path)
|
||||
|
||||
class Config:
|
||||
"""基础配置类"""
|
||||
|
||||
# 基础配置
|
||||
SECRET_KEY = os.environ.get('SECRET_KEY') or 'museguard-secret-key-2024'
|
||||
|
||||
# 数据库配置 - 支持密码中的特殊字符
|
||||
DB_USER = os.environ.get('DB_USER')
|
||||
DB_PASSWORD = os.environ.get('DB_PASSWORD')
|
||||
DB_HOST = os.environ.get('DB_HOST') or 'localhost'
|
||||
DB_NAME = os.environ.get('DB_NAME') or 'museguard_schema'
|
||||
|
||||
# URL编码密码中的特殊字符
|
||||
from urllib.parse import quote_plus
|
||||
_encoded_password = quote_plus(DB_PASSWORD)
|
||||
|
||||
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \
|
||||
f'mysql+pymysql://{DB_USER}:{_encoded_password}@{DB_HOST}/{DB_NAME}'
|
||||
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
||||
SQLALCHEMY_ENGINE_OPTIONS = {
|
||||
'pool_pre_ping': True,
|
||||
'pool_recycle': 300,
|
||||
}
|
||||
|
||||
# JWT配置
|
||||
JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or 'jwt-secret-string'
|
||||
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)
|
||||
JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
|
||||
|
||||
# 静态文件根目录
|
||||
STATIC_ROOT = 'static'
|
||||
|
||||
# 文件上传配置
|
||||
UPLOAD_FOLDER = 'uploads' # 临时上传目录
|
||||
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
|
||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'}
|
||||
|
||||
# 图像处理配置
|
||||
ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'originals') # 重命名后的原始图片
|
||||
PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片
|
||||
MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录
|
||||
MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'clean') # 原图的模型生成结果
|
||||
MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果
|
||||
HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图
|
||||
|
||||
# 预设演示图像配置
|
||||
DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录
|
||||
DEMO_ORIGINAL_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'original') # 演示原始图片
|
||||
DEMO_PERTURBED_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'perturbed') # 演示加噪图片
|
||||
DEMO_COMPARISONS_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'comparisons') # 演示对比图
|
||||
|
||||
# 邮件配置(用于注册验证)
|
||||
MAIL_SERVER = os.environ.get('MAIL_SERVER') or 'smtp.gmail.com'
|
||||
MAIL_PORT = int(os.environ.get('MAIL_PORT') or 587)
|
||||
MAIL_USE_TLS = os.environ.get('MAIL_USE_TLS', 'true').lower() in ['true', 'on', '1']
|
||||
MAIL_USERNAME = os.environ.get('MAIL_USERNAME')
|
||||
MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD')
|
||||
|
||||
# 算法配置
|
||||
ALGORITHMS = {
|
||||
'simac': {
|
||||
'name': 'SimAC算法',
|
||||
'description': 'Simple Anti-Customization Method for Protecting Face Privacy',
|
||||
'default_epsilon': 8.0
|
||||
},
|
||||
'caat': {
|
||||
'name': 'CAAT算法',
|
||||
'description': 'Perturbing Attention Gives You More Bang for the Buck',
|
||||
'default_epsilon': 16.0
|
||||
},
|
||||
'pid': {
|
||||
'name': 'PID算法',
|
||||
'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models',
|
||||
'default_epsilon': 4.0
|
||||
}
|
||||
}
|
||||
|
||||
class DevelopmentConfig(Config):
|
||||
"""开发环境配置"""
|
||||
DEBUG = True
|
||||
|
||||
class ProductionConfig(Config):
|
||||
"""生产环境配置"""
|
||||
DEBUG = False
|
||||
|
||||
class TestingConfig(Config):
|
||||
"""测试环境配置"""
|
||||
TESTING = True
|
||||
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
|
||||
|
||||
config = {
|
||||
'development': DevelopmentConfig,
|
||||
'production': ProductionConfig,
|
||||
'testing': TestingConfig,
|
||||
'default': DevelopmentConfig
|
||||
}
|
||||
@ -0,0 +1,15 @@
|
||||
@echo off
|
||||
chcp 65001 > nul
|
||||
echo ==========================================
|
||||
echo MuseGuard 快速启动脚本
|
||||
echo ==========================================
|
||||
echo.
|
||||
|
||||
REM 切换到项目目录
|
||||
cd /d "d:\code\Software_Project\team_project\MuseGuard\src\backend"
|
||||
|
||||
REM 激活虚拟环境并启动服务器
|
||||
echo 正在激活虚拟环境并启动服务器...
|
||||
call venv_py311\Scripts\activate.bat && python run.py
|
||||
|
||||
pause
|
||||
@ -0,0 +1,33 @@
|
||||
# Core Flask Framework
|
||||
Flask==3.0.0
|
||||
Flask-SQLAlchemy==3.1.1
|
||||
Flask-Migrate==4.0.5
|
||||
Flask-JWT-Extended==4.6.0
|
||||
Flask-CORS==5.0.0
|
||||
Werkzeug==3.0.1
|
||||
|
||||
# Database
|
||||
PyMySQL==1.1.1
|
||||
|
||||
# Image Processing
|
||||
Pillow==10.4.0
|
||||
numpy==1.26.4
|
||||
|
||||
# Security & Utils
|
||||
cryptography==42.0.8
|
||||
python-dotenv==1.0.1
|
||||
|
||||
# Additional Dependencies (auto-installed)
|
||||
# alembic==1.17.0
|
||||
# blinker==1.9.0
|
||||
# cffi==2.0.0
|
||||
# click==8.3.0
|
||||
# greenlet==3.2.4
|
||||
# itsdangerous==2.2.0
|
||||
# Jinja2==3.1.6
|
||||
# Mako==1.3.10
|
||||
# MarkupSafe==3.0.3
|
||||
# pycparser==2.23
|
||||
# PyJWT==2.10.1
|
||||
# SQLAlchemy==2.0.44
|
||||
# typing_extensions==4.15.0
|
||||
@ -0,0 +1,21 @@
|
||||
"""
|
||||
Flask应用启动脚本
|
||||
"""
|
||||
|
||||
import os
|
||||
from app import create_app
|
||||
|
||||
# 设置环境变量
|
||||
os.environ.setdefault('FLASK_ENV', 'development')
|
||||
|
||||
# 创建应用实例
|
||||
app = create_app()
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 开发模式启动
|
||||
app.run(
|
||||
host='0.0.0.0',
|
||||
port=5000,
|
||||
debug=True,
|
||||
threaded=True
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue