diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c9cc180 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +__pycache__/ + +venv/ + +*.png +*.jpg +*.jpeg + +.env \ No newline at end of file diff --git a/src/backend/README.md b/src/backend/README.md new file mode 100644 index 0000000..6728238 --- /dev/null +++ b/src/backend/README.md @@ -0,0 +1,243 @@ +# MuseGuard 后端框架 + +基于对抗性扰动的多风格图像生成防护系统 - 后端API服务 + +## 项目结构 + +``` +backend/ +├── app/ # 主应用目录 +│ ├── algorithms/ # 算法实现 +│ │ ├── perturbation_engine.py # 对抗性扰动引擎 +│ │ └── evaluation_engine.py # 评估引擎 +│ ├── controllers/ # 控制器(路由处理) +│ │ ├── auth_controller.py # 认证控制器 +│ │ ├── user_controller.py # 用户配置控制器 +│ │ ├── task_controller.py # 任务控制器 +| | ├── demo_controller.py # 首页示例控制器 +│ │ ├── image_controller.py # 图像控制器 +│ │ └── admin_controller.py # 管理员控制器 +│ ├── models/ # 数据模型 +│ │ └── __init__.py # SQLAlchemy模型定义 +│ ├── services/ # 业务逻辑服务 +│ │ ├── auth_service.py # 认证服务 +│ │ ├── task_service.py # 任务处理服务 +│ │ └── image_service.py # 图像处理服务 +│ └── utils/ # 工具类 +│ └── file_utils.py # 文件处理工具 +├── config/ # 配置文件 +│ └── settings.py # 应用配置 +├── uploads/ # 文件上传目录 +├── static/ # 静态文件 +│ ├── originals/ # 重命名后的原始图片 +│ ├── perturbed/ # 加噪后的图片 +│ ├── model_outputs/ # 模型生成的图片 +│ │ ├── clean/ # 原图的模型生成结果 +│ │ └── perturbed/ # 加噪图的模型生成结果 +│ ├── heatmaps/ # 热力图 +│ └── demo/ # 演示图片 +│ ├── original/ # 演示原始图片 +│ ├── perturbed/ # 演示加噪图片 +│ └── comparisons/ # 演示对比图 +├── app.py # Flask应用工厂 +├── run.py # 启动脚本 +├── init_db.py # 数据库初始化脚本 +└── requirements.txt # Python依赖 +``` + +## 功能特性 + +### 用户功能 +- ✅ 用户注册(邮箱验证,同一邮箱只能注册一次) +- ✅ 用户登录/登出 +- ✅ 密码修改 +- ✅ 任务创建和管理 +- ✅ 图片上传(单张/压缩包批量) +- ✅ 加噪处理(4种算法:SimAC、CAAT、PID、ASPL) +- ✅ 扰动强度自定义 +- ✅ 防净化版本选择 +- ✅ 智能配置记忆:自动保存用户上次选择的配置 +- ✅ 处理结果下载 +- ✅ 图片质量对比查看(FID、LPIPS、SSIM、PSNR、热力图) +- ✅ 模型生成对比分析 +- ✅ 预设演示图片浏览 + +### 管理员功能 +- ✅ 用户管理(增删改查) +- ✅ 系统统计信息查看 + +### 算法实现 +- ✅ ASPL算法虚拟实现(原始版本 + 防净化版本) +- ✅ SimAC算法虚拟实现(原始版本 + 防净化版本) +- ✅ CAAT算法虚拟实现(原始版本 + 防净化版本) +- ✅ PID算法虚拟实现(原始版本 + 防净化版本) +- ✅ 图像质量评估指标计算 +- ✅ 模型生成效果对比 +- ✅ 热力图生成 + +## 安装和运行 + +### 1. 环境准备 + +#### 使用虚拟环境(推荐) + +**为什么需要虚拟环境?** +- ✅ **避免依赖冲突**:不同项目使用不同版本的包 +- ✅ **环境隔离**:不污染系统Python环境 +- ✅ **版本一致性**:确保团队环境统一 +- ✅ **易于管理**:可以随时删除重建 + +```bash +# 创建虚拟环境 +python -m venv venv + +# 激活虚拟环境 +# Windows: +venv\\Scripts\\activate +# Linux/Mac: +source venv/bin/activate + +# 更新pip(推荐) +python -m pip install --upgrade pip + +# 安装依赖 +pip install -r requirements.txt +``` + +### 2. 数据库配置 + +确保已安装MySQL数据库并创建数据库。 + +修改 `config/.env` 中的数据库连接配置: + +### 3. 初始化数据库 + +```bash +# 运行数据库初始化脚本 +python init_db.py +``` + +### 4. 启动应用 + +```bash +# 开发模式启动 +python run.py + +# 或者使用Flask命令 +flask run +``` + +应用将在 `http://localhost:5000` 启动 + +### 5. 系统测试 + +访问 `http://localhost:5000/static/test.html` 进入功能测试页面: + +## API接口文档 + +### 认证接口 (`/api/auth`) + +- `POST /register` - 用户注册 +- `POST /login` - 用户登录 +- `POST /change-password` - 修改密码 +- `GET /profile` - 获取用户信息 +- `POST /logout` - 用户登出 + +### 任务管理 (`/api/task`) + +- `POST /create` - 创建任务(使用默认配置) +- `POST /upload/` - 上传图片到指定任务 +- `GET //config` - 获取任务配置(显示用户上次选择) +- `PUT //config` - 更新任务配置(自动保存为用户偏好) +- `GET /load-config` - 加载用户上次配置 +- `POST /save-config` - 保存用户配置偏好 +- `POST /start/` - 开始处理任务 +- `GET /list` - 获取任务列表 +- `GET /` - 获取任务详情 +- `GET //status` - 获取处理状态 + +### 图片管理 (`/api/image`) + +- `GET /file/` - 查看图片 +- `GET /download/` - 下载图片 +- `GET /batch//download` - 批量下载 +- `GET //evaluations` - 获取评估结果 +- `POST /compare` - 对比图片 +- `GET /heatmap/` - 获取热力图 +- `DELETE /delete/` - 删除图片 + +### 用户设置 (`/api/user`) + +- `GET /config` - 获取用户配置(已弃用,配置集成到任务流程中) +- `PUT /config` - 更新用户配置(已弃用,通过任务配置自动保存) +- `GET /algorithms` - 获取可用算法(动态从数据库加载) +- `GET /stats` - 获取用户统计 + +### 管理员功能 (`/api/admin`) + +- `GET /users` - 用户列表 +- `GET /users/` - 用户详情 +- `POST /users` - 创建用户 +- `PUT /users/` - 更新用户 +- `DELETE /users/` - 删除用户 +- `GET /stats` - 系统统计 + +### 演示功能 (`/api/demo`) + +- `GET /images` - 获取演示图片列表 +- `GET /image/original/` - 获取演示原始图片 +- `GET /image/perturbed/` - 获取演示加噪图片 +- `GET /image/comparison/` - 获取演示对比图片 +- `GET /algorithms` - 获取算法演示信息 +- `GET /stats` - 获取演示统计数据 + +## 默认账户 + +系统初始化后会创建3个管理员账户: + +- 用户名:`admin1`, `admin2`, `admin3` +- 默认密码:`admin123` +- 邮箱:`admin1@museguard.com` 等 + +## 技术栈 + +- **Web框架**: Flask 2.3.3 +- **数据库ORM**: SQLAlchemy 3.0.5 +- **数据库**: MySQL(通过PyMySQL连接) +- **认证**: JWT (Flask-JWT-Extended) +- **跨域**: Flask-CORS +- **图像处理**: Pillow + NumPy +- **数学计算**: NumPy + +## 开发说明 + +### 虚拟实现说明 + +当前所有算法都是**虚拟实现**,用于框架搭建和测试: + +1. **对抗性扰动算法**: 使用随机噪声模拟真实算法效果 +2. **评估指标**: 基于像素差异的简化计算 +3. **模型生成**: 通过图像变换模拟DreamBooth/LoRA效果 + +### 扩展指南 + +要集成真实算法: + +1. 替换 `app/algorithms/perturbation_engine.py` 中的虚拟实现 +2. 替换 `app/algorithms/evaluation_engine.py` 中的评估计算 +3. 根据需要调整配置参数 + +### 目录权限 + +确保以下目录有写入权限: + +- `uploads/` - 用户上传文件 +- `static/originals/` - 重命名后的原始图片 +- `static/perturbed/` - 加噪后的图片 +- `static/model_outputs/` - 模型生成的图片 +- `static/heatmaps/` - 热力图文件 +- `static/demo/` - 演示图片(需要手动添加演示文件) + +## 许可证 + +本项目仅用于学习和研究目的。 \ No newline at end of file diff --git a/src/backend/app.py b/src/backend/app.py new file mode 100644 index 0000000..3d8a748 --- /dev/null +++ b/src/backend/app.py @@ -0,0 +1,46 @@ +""" +MuseGuard 后端主应用入口 +基于对抗性扰动的多风格图像生成防护系统 +""" + +from flask import Flask +from flask_sqlalchemy import SQLAlchemy +from flask_migrate import Migrate +from flask_jwt_extended import JWTManager +from flask_cors import CORS +from config.settings import Config + +# 初始化扩展 +db = SQLAlchemy() +migrate = Migrate() +jwt = JWTManager() + +def create_app(config_class=Config): + """Flask应用工厂函数""" + app = Flask(__name__) + app.config.from_object(config_class) + + # 初始化扩展 + db.init_app(app) + migrate.init_app(app, db) + jwt.init_app(app) + CORS(app) + + # 注册蓝图 + from app.controllers.auth_controller import auth_bp + from app.controllers.user_controller import user_bp + from app.controllers.task_controller import task_bp + from app.controllers.image_controller import image_bp + from app.controllers.admin_controller import admin_bp + + app.register_blueprint(auth_bp, url_prefix='/api/auth') + app.register_blueprint(user_bp, url_prefix='/api/user') + app.register_blueprint(task_bp, url_prefix='/api/task') + app.register_blueprint(image_bp, url_prefix='/api/image') + app.register_blueprint(admin_bp, url_prefix='/api/admin') + + return app + +if __name__ == '__main__': + app = create_app() + app.run(debug=True, host='0.0.0.0', port=5000) \ No newline at end of file diff --git a/src/backend/app/__init__.py b/src/backend/app/__init__.py new file mode 100644 index 0000000..4ed0625 --- /dev/null +++ b/src/backend/app/__init__.py @@ -0,0 +1,83 @@ +""" +MuseGuard 后端应用包 +基于对抗性扰动的多风格图像生成防护系统 +""" + +from flask import Flask +from flask_sqlalchemy import SQLAlchemy +from flask_migrate import Migrate +from flask_jwt_extended import JWTManager +from flask_cors import CORS +import os + +# 初始化扩展 +db = SQLAlchemy() +migrate = Migrate() +jwt = JWTManager() +cors = CORS() + +def create_app(config_name=None): + """应用工厂函数""" + # 配置静态文件和模板文件路径 + app = Flask(__name__, + static_folder='../static', + static_url_path='/static') + + # 加载配置 + if config_name is None: + config_name = os.environ.get('FLASK_ENV', 'development') + + from config.settings import config + app.config.from_object(config[config_name]) + + # 初始化扩展 + db.init_app(app) + migrate.init_app(app, db) + jwt.init_app(app) + cors.init_app(app) + + # 注册蓝图 + from app.controllers.auth_controller import auth_bp + from app.controllers.user_controller import user_bp + from app.controllers.task_controller import task_bp + from app.controllers.image_controller import image_bp + from app.controllers.admin_controller import admin_bp + from app.controllers.demo_controller import demo_bp + + app.register_blueprint(auth_bp, url_prefix='/api/auth') + app.register_blueprint(user_bp, url_prefix='/api/user') + app.register_blueprint(task_bp, url_prefix='/api/task') + app.register_blueprint(image_bp, url_prefix='/api/image') + app.register_blueprint(admin_bp, url_prefix='/api/admin') + app.register_blueprint(demo_bp, url_prefix='/api/demo') + + # 注册错误处理器 + @app.errorhandler(404) + def not_found_error(error): + return {'error': 'Not found'}, 404 + + @app.errorhandler(500) + def internal_error(error): + db.session.rollback() + return {'error': 'Internal server error'}, 500 + + # 根路由 + @app.route('/') + def index(): + return { + 'message': 'MuseGuard API Server', + 'version': '1.0.0', + 'status': 'running', + 'endpoints': { + 'health': '/health', + 'api_docs': '/api', + 'test_page': '/static/test.html' + } + } + + # 健康检查端点 + @app.route('/health') + def health_check(): + return {'status': 'healthy', 'message': 'MuseGuard backend is running'} + + return app \ No newline at end of file diff --git a/src/backend/app/algorithms/evaluation_engine.py b/src/backend/app/algorithms/evaluation_engine.py new file mode 100644 index 0000000..b6ab294 --- /dev/null +++ b/src/backend/app/algorithms/evaluation_engine.py @@ -0,0 +1,300 @@ +""" +评估引擎 +实现图像质量评估和模型生成对比的虚拟版本 +""" + +import os +import numpy as np +from PIL import Image, ImageDraw +import uuid +import random +from flask import current_app + +class EvaluationEngine: + """评估处理引擎""" + + @staticmethod + def evaluate_image_quality(original_path, perturbed_path): + """ + 评估图像质量指标 + + Args: + original_path: 原始图片路径 + perturbed_path: 扰动后图片路径 + + Returns: + 包含各种评估指标的字典 + """ + try: + # 加载图片 + with Image.open(original_path) as orig_img, Image.open(perturbed_path) as pert_img: + # 确保图片尺寸一致 + if orig_img.size != pert_img.size: + pert_img = pert_img.resize(orig_img.size) + + # 转换为RGB + if orig_img.mode != 'RGB': + orig_img = orig_img.convert('RGB') + if pert_img.mode != 'RGB': + pert_img = pert_img.convert('RGB') + + # 计算虚拟评估指标 + fid_score = EvaluationEngine._calculate_mock_fid(orig_img, pert_img) + lpips_score = EvaluationEngine._calculate_mock_lpips(orig_img, pert_img) + ssim_score = EvaluationEngine._calculate_mock_ssim(orig_img, pert_img) + psnr_score = EvaluationEngine._calculate_mock_psnr(orig_img, pert_img) + + # 生成热力图 + heatmap_path = EvaluationEngine._generate_difference_heatmap( + orig_img, pert_img, "quality" + ) + + return { + 'fid': fid_score, + 'lpips': lpips_score, + 'ssim': ssim_score, + 'psnr': psnr_score, + 'heatmap_path': heatmap_path + } + + except Exception as e: + print(f"图像质量评估失败: {str(e)}") + return { + 'fid': None, + 'lpips': None, + 'ssim': None, + 'psnr': None, + 'heatmap_path': None + } + + @staticmethod + def evaluate_model_generation(original_path, perturbed_path, finetune_method): + """ + 评估模型生成对比 + + Args: + original_path: 原始图片路径 + perturbed_path: 扰动后图片路径 + finetune_method: 微调方法 (dreambooth, lora) + + Returns: + 包含生成对比评估指标的字典 + """ + try: + # 模拟训练过程并生成对比图片 + with Image.open(original_path) as orig_img, Image.open(perturbed_path) as pert_img: + # 模拟原始图片训练后的生成结果 + orig_generated = EvaluationEngine._simulate_model_generation(orig_img, finetune_method, False) + + # 模拟扰动图片训练后的生成结果 + pert_generated = EvaluationEngine._simulate_model_generation(pert_img, finetune_method, True) + + # 计算生成质量对比指标 + fid_score = EvaluationEngine._calculate_generation_fid(orig_generated, pert_generated) + lpips_score = EvaluationEngine._calculate_generation_lpips(orig_generated, pert_generated) + ssim_score = EvaluationEngine._calculate_generation_ssim(orig_generated, pert_generated) + psnr_score = EvaluationEngine._calculate_generation_psnr(orig_generated, pert_generated) + + # 生成对比热力图 + heatmap_path = EvaluationEngine._generate_difference_heatmap( + orig_generated, pert_generated, "generation" + ) + + return { + 'fid': fid_score, + 'lpips': lpips_score, + 'ssim': ssim_score, + 'psnr': psnr_score, + 'heatmap_path': heatmap_path + } + + except Exception as e: + print(f"模型生成评估失败: {str(e)}") + return { + 'fid': None, + 'lpips': None, + 'ssim': None, + 'psnr': None, + 'heatmap_path': None + } + + @staticmethod + def _calculate_mock_fid(img1, img2): + """模拟FID计算""" + # 简单的像素差异统计作为FID近似 + arr1 = np.array(img1, dtype=np.float32) + arr2 = np.array(img2, dtype=np.float32) + + mse = np.mean((arr1 - arr2) ** 2) + # FID通常在0-300之间,扰动图片应该有较小的差异 + mock_fid = min(mse / 10.0 + random.uniform(0.5, 2.0), 50.0) + return round(mock_fid, 4) + + @staticmethod + def _calculate_mock_lpips(img1, img2): + """模拟LPIPS计算""" + arr1 = np.array(img1, dtype=np.float32) / 255.0 + arr2 = np.array(img2, dtype=np.float32) / 255.0 + + # 简单的感知差异模拟 + diff = np.abs(arr1 - arr2) + # 模拟感知权重(边缘区域权重更高) + gray1 = np.mean(arr1, axis=2) + edges = np.abs(np.gradient(gray1)[0]) + np.abs(np.gradient(gray1)[1]) + + weighted_diff = np.mean(diff * (1 + edges.reshape(edges.shape + (1,)))) + mock_lpips = min(weighted_diff + random.uniform(0.001, 0.01), 1.0) + return round(mock_lpips, 4) + + @staticmethod + def _calculate_mock_ssim(img1, img2): + """模拟SSIM计算""" + arr1 = np.array(img1, dtype=np.float32) + arr2 = np.array(img2, dtype=np.float32) + + # 简单的结构相似性模拟 + mu1 = np.mean(arr1) + mu2 = np.mean(arr2) + + sigma1_sq = np.var(arr1) + sigma2_sq = np.var(arr2) + sigma12 = np.mean((arr1 - mu1) * (arr2 - mu2)) + + c1 = (0.01 * 255) ** 2 + c2 = (0.03 * 255) ** 2 + + ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / \ + ((mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2)) + + # SSIM在-1到1之间,但通常在0.8-1.0之间为良好 + mock_ssim = max(0.85 + random.uniform(-0.05, 0.1), 0.0) + return round(mock_ssim, 4) + + @staticmethod + def _calculate_mock_psnr(img1, img2): + """模拟PSNR计算""" + arr1 = np.array(img1, dtype=np.float32) + arr2 = np.array(img2, dtype=np.float32) + + mse = np.mean((arr1 - arr2) ** 2) + if mse == 0: + return 100.0 + + max_pixel = 255.0 + psnr = 20 * np.log10(max_pixel / np.sqrt(mse)) + + # 对于轻微扰动,PSNR应该比较高(30-50dB) + mock_psnr = max(min(psnr + random.uniform(-2, 2), 50.0), 20.0) + return round(mock_psnr, 4) + + @staticmethod + def _simulate_model_generation(img, finetune_method, is_perturbed): + """模拟模型生成过程""" + # 对输入图像进行变换以模拟生成结果 + arr = np.array(img, dtype=np.float32) + + if finetune_method == "dreambooth": + # DreamBooth风格变换 + if is_perturbed: + # 扰动图片训练的模型生成质量较差 + noise_level = 15.0 + style_change = 0.3 + else: + # 原始图片训练的模型生成质量较好 + noise_level = 5.0 + style_change = 0.1 + + elif finetune_method == "lora": + # LoRA风格变换 + if is_perturbed: + noise_level = 12.0 + style_change = 0.25 + else: + noise_level = 3.0 + style_change = 0.08 + else: + noise_level = 8.0 + style_change = 0.15 + + # 添加噪声模拟生成质量 + noise = np.random.normal(0, noise_level, arr.shape) + generated_arr = arr + noise + + # 模拟风格变化 + if style_change > 0: + # 简单的颜色偏移 + generated_arr[:,:,0] *= (1 + style_change * random.uniform(-0.5, 0.5)) + generated_arr[:,:,1] *= (1 + style_change * random.uniform(-0.5, 0.5)) + generated_arr[:,:,2] *= (1 + style_change * random.uniform(-0.5, 0.5)) + + generated_arr = np.clip(generated_arr, 0, 255).astype(np.uint8) + return Image.fromarray(generated_arr) + + @staticmethod + def _calculate_generation_fid(img1, img2): + """计算生成图片的FID(质量越差FID越高越好)""" + fid = EvaluationEngine._calculate_mock_fid(img1, img2) + # 对于防护效果,FID越高越好 + return max(fid, 10.0) + + @staticmethod + def _calculate_generation_lpips(img1, img2): + """计算生成图片的LPIPS(差异越大越好)""" + lpips = EvaluationEngine._calculate_mock_lpips(img1, img2) + return max(lpips, 0.1) + + @staticmethod + def _calculate_generation_ssim(img1, img2): + """计算生成图片的SSIM(相似度越低越好)""" + ssim = EvaluationEngine._calculate_mock_ssim(img1, img2) + # 对于防护效果,SSIM越低越好 + return min(ssim, 0.9) + + @staticmethod + def _calculate_generation_psnr(img1, img2): + """计算生成图片的PSNR(质量差异越大越好)""" + psnr = EvaluationEngine._calculate_mock_psnr(img1, img2) + return min(psnr, 35.0) + + @staticmethod + def _generate_difference_heatmap(img1, img2, eval_type): + """生成差异热力图""" + try: + arr1 = np.array(img1, dtype=np.float32) + arr2 = np.array(img2, dtype=np.float32) + + # 计算像素差异 + diff = np.abs(arr1 - arr2) + diff_gray = np.mean(diff, axis=2) + + # 归一化到0-255 + diff_normalized = (diff_gray / diff_gray.max() * 255).astype(np.uint8) + + # 创建热力图 + heatmap = Image.fromarray(diff_normalized, mode='L') + heatmap = heatmap.convert('RGB') + + # 应用颜色映射(蓝色-绿色-红色) + heatmap_array = np.array(heatmap) + colored_heatmap = np.zeros_like(arr1, dtype=np.uint8) + + # 简单的颜色映射 + intensity = heatmap_array[:,:,0] / 255.0 + colored_heatmap[:,:,0] = (intensity * 255).astype(np.uint8) # 红色通道 + colored_heatmap[:,:,1] = ((1 - intensity) * intensity * 4 * 255).astype(np.uint8) # 绿色通道 + colored_heatmap[:,:,2] = ((1 - intensity) * 255).astype(np.uint8) # 蓝色通道 + + # 保存热力图 + heatmap_dir = os.path.join('static', 'heatmaps') + os.makedirs(heatmap_dir, exist_ok=True) + + heatmap_filename = f"{eval_type}_heatmap_{uuid.uuid4().hex[:8]}.png" + heatmap_path = os.path.join(heatmap_dir, heatmap_filename) + + Image.fromarray(colored_heatmap).save(heatmap_path) + + return heatmap_path + + except Exception as e: + print(f"生成热力图失败: {str(e)}") + return None \ No newline at end of file diff --git a/src/backend/app/algorithms/perturbation_engine.py b/src/backend/app/algorithms/perturbation_engine.py new file mode 100644 index 0000000..417b1f4 --- /dev/null +++ b/src/backend/app/algorithms/perturbation_engine.py @@ -0,0 +1,176 @@ +""" +对抗性扰动算法引擎 +实现各种加噪算法的虚拟版本 +""" + +import os +import numpy as np +from PIL import Image, ImageEnhance, ImageFilter +import uuid +from flask import current_app + +class PerturbationEngine: + """对抗性扰动处理引擎""" + + @staticmethod + def apply_perturbation(image_path, algorithm, epsilon, use_strong_protection=False, output_path=None): + """ + 应用对抗性扰动 + + Args: + image_path: 原始图片路径 + algorithm: 算法名称 (simac, caat, pid) + epsilon: 扰动强度 + use_strong_protection: 是否使用防净化版本 + + Returns: + 处理后图片的路径 + """ + try: + if not os.path.exists(image_path): + raise FileNotFoundError(f"图片文件不存在: {image_path}") + + # 加载图片 + with Image.open(image_path) as img: + # 转换为RGB模式 + if img.mode != 'RGB': + img = img.convert('RGB') + + # 根据算法选择处理方法 + if algorithm == 'simac': + perturbed_img = PerturbationEngine._apply_simac(img, epsilon, use_strong_protection) + elif algorithm == 'caat': + perturbed_img = PerturbationEngine._apply_caat(img, epsilon, use_strong_protection) + elif algorithm == 'pid': + perturbed_img = PerturbationEngine._apply_pid(img, epsilon, use_strong_protection) + else: + raise ValueError(f"不支持的算法: {algorithm}") + + # 使用输入的output_path参数 + if output_path is None: + # 如果没有提供输出路径,使用默认路径 + from flask import current_app + project_root = os.path.dirname(current_app.root_path) + perturbed_dir = os.path.join(project_root, current_app.config['PERTURBED_IMAGES_FOLDER']) + os.makedirs(perturbed_dir, exist_ok=True) + + file_extension = os.path.splitext(image_path)[1] + output_filename = f"perturbed_{uuid.uuid4().hex[:8]}{file_extension}" + output_path = os.path.join(perturbed_dir, output_filename) + + # 保存处理后的图片 + perturbed_img.save(output_path, quality=95) + + return output_path + + except Exception as e: + print(f"应用扰动时出错: {str(e)}") + return None + + @staticmethod + def _apply_simac(img, epsilon, use_strong_protection): + """ + SimAC算法的虚拟实现 + Simple Anti-Customization Method for Protecting Face Privacy + """ + # 将PIL图像转换为numpy数组 + img_array = np.array(img, dtype=np.float32) + + # 生成随机噪声(模拟对抗性扰动) + noise_scale = epsilon / 255.0 + + if use_strong_protection: + # 防净化版本:添加更复杂的扰动模式 + noise = np.random.normal(0, noise_scale * 0.8, img_array.shape) + # 添加结构化噪声 + h, w = img_array.shape[:2] + for i in range(0, h, 8): + for j in range(0, w, 8): + block_noise = np.random.normal(0, noise_scale * 0.4, (min(8, h-i), min(8, w-j), 3)) + noise[i:i+8, j:j+8] += block_noise + else: + # 标准版本:简单高斯噪声 + noise = np.random.normal(0, noise_scale, img_array.shape) + + # 应用噪声 + perturbed_array = img_array + noise * 255.0 + perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8) + + # 转换回PIL图像 + result_img = Image.fromarray(perturbed_array) + + # 轻微的图像增强以模拟算法特性 + enhancer = ImageEnhance.Contrast(result_img) + result_img = enhancer.enhance(1.02) + + return result_img + + @staticmethod + def _apply_caat(img, epsilon, use_strong_protection): + """ + CAAT算法的虚拟实现 + Perturbing Attention Gives You More Bang for the Buck + """ + img_array = np.array(img, dtype=np.float32) + + noise_scale = epsilon / 255.0 + + if use_strong_protection: + # 防净化版本:注意力区域重点扰动 + # 模拟注意力图(简单的边缘检测) + gray_img = img.convert('L') + edge_img = gray_img.filter(ImageFilter.FIND_EDGES) + attention_map = np.array(edge_img, dtype=np.float32) / 255.0 + + # 在注意力区域添加更强的噪声 + noise = np.random.normal(0, noise_scale * 0.6, img_array.shape) + for c in range(3): + noise[:,:,c] += attention_map * np.random.normal(0, noise_scale * 0.8, attention_map.shape) + else: + # 标准版本:均匀分布噪声 + noise = np.random.uniform(-noise_scale, noise_scale, img_array.shape) + + perturbed_array = img_array + noise * 255.0 + perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8) + + result_img = Image.fromarray(perturbed_array) + + # 轻微模糊以模拟注意力扰动效果 + result_img = result_img.filter(ImageFilter.BoxBlur(0.5)) + + return result_img + + @staticmethod + def _apply_pid(img, epsilon, use_strong_protection): + """ + PID算法的虚拟实现 + Prompt-Independent Data Protection Against Latent Diffusion Models + """ + img_array = np.array(img, dtype=np.float32) + + noise_scale = epsilon / 255.0 + + if use_strong_protection: + # 防净化版本:频域扰动 + # 简单的频域变换模拟 + noise = np.random.laplace(0, noise_scale * 0.7, img_array.shape) + # 添加周期性扰动 + h, w = img_array.shape[:2] + for i in range(h): + for j in range(w): + periodic_noise = noise_scale * 0.3 * np.sin(i * 0.1) * np.cos(j * 0.1) + noise[i, j] += periodic_noise + else: + # 标准版本:拉普拉斯噪声 + noise = np.random.laplace(0, noise_scale * 0.5, img_array.shape) + + perturbed_array = img_array + noise * 255.0 + perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8) + + result_img = Image.fromarray(perturbed_array) + + # 色彩微调以模拟潜在空间扰动 + enhancer = ImageEnhance.Color(result_img) + result_img = enhancer.enhance(0.98) + + return result_img \ No newline at end of file diff --git a/src/backend/app/controllers/admin_controller.py b/src/backend/app/controllers/admin_controller.py new file mode 100644 index 0000000..581fa21 --- /dev/null +++ b/src/backend/app/controllers/admin_controller.py @@ -0,0 +1,239 @@ +""" +管理员控制器 +处理管理员功能 +""" + +from flask import Blueprint, request, jsonify +from flask_jwt_extended import jwt_required, get_jwt_identity +from app import db +from app.models import User, Batch, Image + +admin_bp = Blueprint('admin', __name__) + +def admin_required(f): + """管理员权限装饰器""" + from functools import wraps + + @wraps(f) + def decorated_function(*args, **kwargs): + current_user_id = get_jwt_identity() + user = User.query.get(current_user_id) + + if not user or user.role != 'admin': + return jsonify({'error': '需要管理员权限'}), 403 + + return f(*args, **kwargs) + + return decorated_function + +@admin_bp.route('/users', methods=['GET']) +@jwt_required() +@admin_required +def list_users(): + """获取用户列表""" + try: + page = request.args.get('page', 1, type=int) + per_page = request.args.get('per_page', 20, type=int) + + users = User.query.paginate(page=page, per_page=per_page, error_out=False) + + return jsonify({ + 'users': [user.to_dict() for user in users.items], + 'total': users.total, + 'pages': users.pages, + 'current_page': page + }), 200 + + except Exception as e: + return jsonify({'error': f'获取用户列表失败: {str(e)}'}), 500 + +@admin_bp.route('/users/', methods=['GET']) +@jwt_required() +@admin_required +def get_user_detail(user_id): + """获取用户详情""" + try: + user = User.query.get(user_id) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + # 获取用户统计信息 + total_tasks = Batch.query.filter_by(user_id=user_id).count() + total_images = Image.query.filter_by(user_id=user_id).count() + + user_dict = user.to_dict() + user_dict['stats'] = { + 'total_tasks': total_tasks, + 'total_images': total_images + } + + return jsonify({'user': user_dict}), 200 + + except Exception as e: + return jsonify({'error': f'获取用户详情失败: {str(e)}'}), 500 + +@admin_bp.route('/users', methods=['POST']) +@jwt_required() +@admin_required +def create_user(): + """创建用户""" + try: + data = request.get_json() + username = data.get('username') + password = data.get('password') + email = data.get('email') + role = data.get('role', 'user') + max_concurrent_tasks = data.get('max_concurrent_tasks', 0) + + if not username or not password: + return jsonify({'error': '用户名和密码不能为空'}), 400 + + # 检查用户名是否已存在 + if User.query.filter_by(username=username).first(): + return jsonify({'error': '用户名已存在'}), 400 + + # 检查邮箱是否已存在 + if email and User.query.filter_by(email=email).first(): + return jsonify({'error': '邮箱已被使用'}), 400 + + # 创建用户 + user = User( + username=username, + email=email, + role=role, + max_concurrent_tasks=max_concurrent_tasks + ) + user.set_password(password) + + db.session.add(user) + db.session.commit() + + return jsonify({ + 'message': '用户创建成功', + 'user': user.to_dict() + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'创建用户失败: {str(e)}'}), 500 + +@admin_bp.route('/users/', methods=['PUT']) +@jwt_required() +@admin_required +def update_user(user_id): + """更新用户信息""" + try: + user = User.query.get(user_id) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + data = request.get_json() + + # 更新字段 + if 'username' in data: + new_username = data['username'] + if new_username != user.username: + if User.query.filter_by(username=new_username).first(): + return jsonify({'error': '用户名已存在'}), 400 + user.username = new_username + + if 'email' in data: + new_email = data['email'] + if new_email != user.email: + if User.query.filter_by(email=new_email).first(): + return jsonify({'error': '邮箱已被使用'}), 400 + user.email = new_email + + if 'role' in data: + user.role = data['role'] + + if 'max_concurrent_tasks' in data: + user.max_concurrent_tasks = data['max_concurrent_tasks'] + + if 'is_active' in data: + user.is_active = bool(data['is_active']) + + if 'password' in data and data['password']: + user.set_password(data['password']) + + db.session.commit() + + return jsonify({ + 'message': '用户信息更新成功', + 'user': user.to_dict() + }), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'更新用户失败: {str(e)}'}), 500 + +@admin_bp.route('/users/', methods=['DELETE']) +@jwt_required() +@admin_required +def delete_user(user_id): + """删除用户""" + try: + current_user_id = get_jwt_identity() + + # 不能删除自己 + if user_id == current_user_id: + return jsonify({'error': '不能删除自己的账户'}), 400 + + user = User.query.get(user_id) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + # 删除用户(级联删除相关数据) + db.session.delete(user) + db.session.commit() + + return jsonify({'message': '用户删除成功'}), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'删除用户失败: {str(e)}'}), 500 + +@admin_bp.route('/stats', methods=['GET']) +@jwt_required() +@admin_required +def get_system_stats(): + """获取系统统计信息""" + try: + from app.models import EvaluationResult + + total_users = User.query.count() + active_users = User.query.filter_by(is_active=True).count() + admin_users = User.query.filter_by(role='admin').count() + + total_tasks = Batch.query.count() + completed_tasks = Batch.query.filter_by(status='completed').count() + processing_tasks = Batch.query.filter_by(status='processing').count() + failed_tasks = Batch.query.filter_by(status='failed').count() + + total_images = Image.query.count() + total_evaluations = EvaluationResult.query.count() + + return jsonify({ + 'stats': { + 'users': { + 'total': total_users, + 'active': active_users, + 'admin': admin_users + }, + 'tasks': { + 'total': total_tasks, + 'completed': completed_tasks, + 'processing': processing_tasks, + 'failed': failed_tasks + }, + 'images': { + 'total': total_images + }, + 'evaluations': { + 'total': total_evaluations + } + } + }), 200 + + except Exception as e: + return jsonify({'error': f'获取系统统计失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/controllers/auth_controller.py b/src/backend/app/controllers/auth_controller.py new file mode 100644 index 0000000..648780c --- /dev/null +++ b/src/backend/app/controllers/auth_controller.py @@ -0,0 +1,156 @@ +""" +用户认证控制器 +处理注册、登录、密码修改等功能 +""" + +from flask import Blueprint, request, jsonify +from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity +from app import db +from app.models import User, UserConfig +from app.services.auth_service import AuthService +from functools import wraps +import re + +def int_jwt_required(f): + """获取JWT身份并转换为整数的装饰器""" + @wraps(f) + def wrapped(*args, **kwargs): + try: + current_user_id = int(get_jwt_identity()) + return f(*args, current_user_id=current_user_id, **kwargs) + except (TypeError, ValueError): + return jsonify({'error': '无效的用户身份标识'}), 401 + return jwt_required()(wrapped) + +auth_bp = Blueprint('auth', __name__) + +@auth_bp.route('/register', methods=['POST']) +def register(): + """用户注册""" + try: + data = request.get_json() + username = data.get('username') + password = data.get('password') + email = data.get('email') + + # 验证输入 + if not username or not password or not email: + return jsonify({'error': '用户名、密码和邮箱不能为空'}), 400 + + # 验证邮箱格式 + email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + if not re.match(email_pattern, email): + return jsonify({'error': '邮箱格式不正确'}), 400 + + # 检查用户名是否已存在 + if User.query.filter_by(username=username).first(): + return jsonify({'error': '用户名已存在'}), 400 + + # 检查邮箱是否已注册 + if User.query.filter_by(email=email).first(): + return jsonify({'error': '该邮箱已被注册,同一邮箱只能注册一次'}), 400 + + # 创建用户 + user = User(username=username, email=email) + user.set_password(password) + + db.session.add(user) + db.session.commit() + + # 创建用户默认配置 + user_config = UserConfig(user_id=user.id) + db.session.add(user_config) + db.session.commit() + + return jsonify({ + 'message': '注册成功', + 'user': user.to_dict() + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'注册失败: {str(e)}'}), 500 + +@auth_bp.route('/login', methods=['POST']) +def login(): + """用户登录""" + try: + data = request.get_json() + username = data.get('username') + password = data.get('password') + + if not username or not password: + return jsonify({'error': '用户名和密码不能为空'}), 400 + + # 查找用户 + user = User.query.filter_by(username=username).first() + + if not user or not user.check_password(password): + return jsonify({'error': '用户名或密码错误'}), 401 + + if not user.is_active: + return jsonify({'error': '账户已被禁用'}), 401 + + # 创建访问令牌 - 确保用户ID为字符串类型 + access_token = create_access_token(identity=str(user.id)) + + return jsonify({ + 'message': '登录成功', + 'access_token': access_token, + 'user': user.to_dict() + }), 200 + + except Exception as e: + return jsonify({'error': f'登录失败: {str(e)}'}), 500 + +@auth_bp.route('/change-password', methods=['POST']) +@int_jwt_required +def change_password(current_user_id): + """修改密码""" + try: + user = User.query.get(current_user_id) + + if not user: + return jsonify({'error': '用户不存在'}), 404 + + data = request.get_json() + old_password = data.get('old_password') + new_password = data.get('new_password') + + if not old_password or not new_password: + return jsonify({'error': '旧密码和新密码不能为空'}), 400 + + # 验证旧密码 + if not user.check_password(old_password): + return jsonify({'error': '旧密码错误'}), 401 + + # 设置新密码 + user.set_password(new_password) + db.session.commit() + + return jsonify({'message': '密码修改成功'}), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'密码修改失败: {str(e)}'}), 500 + +@auth_bp.route('/profile', methods=['GET']) +@int_jwt_required +def get_profile(current_user_id): + """获取用户信息""" + try: + user = User.query.get(current_user_id) + + if not user: + return jsonify({'error': '用户不存在'}), 404 + + return jsonify({'user': user.to_dict()}), 200 + + except Exception as e: + return jsonify({'error': f'获取用户信息失败: {str(e)}'}), 500 + +@auth_bp.route('/logout', methods=['POST']) +@jwt_required() +def logout(): + """用户登出(客户端删除token即可)""" + return jsonify({'message': '登出成功'}), 200 \ No newline at end of file diff --git a/src/backend/app/controllers/demo_controller.py b/src/backend/app/controllers/demo_controller.py new file mode 100644 index 0000000..482958e --- /dev/null +++ b/src/backend/app/controllers/demo_controller.py @@ -0,0 +1,177 @@ +""" +演示图片控制器 +处理预设图像对比图的展示功能 +""" + +from flask import Blueprint, send_file, jsonify, current_app +from flask_jwt_extended import jwt_required +from app.models import PerturbationConfig, FinetuneConfig +import os +import glob + +demo_bp = Blueprint('demo', __name__) + +@demo_bp.route('/images', methods=['GET']) +def list_demo_images(): + """获取所有演示图片列表""" + try: + demo_images = [] + + # 获取演示原始图片 - 修正路径构建 + # 获取项目根目录(backend目录) + project_root = os.path.dirname(current_app.root_path) + original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER']) + + if os.path.exists(original_folder): + original_files = glob.glob(os.path.join(original_folder, '*')) + for file_path in original_files: + if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')): + filename = os.path.basename(file_path) + name_without_ext = os.path.splitext(filename)[0] + + # 查找对应的加噪图片 + perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER']) + perturbed_files = glob.glob(os.path.join(perturbed_folder, f"{name_without_ext}*")) + + # 查找对应的对比图 + comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER']) + comparison_files = glob.glob(os.path.join(comparison_folder, f"{name_without_ext}*")) + + demo_image = { + 'id': name_without_ext, + 'name': name_without_ext, + 'original': f"/api/demo/image/original/{filename}", + 'perturbed': [f"/api/demo/image/perturbed/{os.path.basename(f)}" for f in perturbed_files], + 'comparisons': [f"/api/demo/image/comparison/{os.path.basename(f)}" for f in comparison_files] + } + demo_images.append(demo_image) + + return jsonify({ + 'demo_images': demo_images, + 'total': len(demo_images) + }), 200 + + except Exception as e: + return jsonify({'error': f'获取演示图片列表失败: {str(e)}'}), 500 + +@demo_bp.route('/image/original/', methods=['GET']) +def get_demo_original_image(filename): + """获取演示原始图片""" + try: + project_root = os.path.dirname(current_app.root_path) + file_path = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'], filename) + + if not os.path.exists(file_path): + return jsonify({'error': '图片不存在'}), 404 + + return send_file(file_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取原始图片失败: {str(e)}'}), 500 + +@demo_bp.route('/image/perturbed/', methods=['GET']) +def get_demo_perturbed_image(filename): + """获取演示加噪图片""" + try: + project_root = os.path.dirname(current_app.root_path) + file_path = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'], filename) + + if not os.path.exists(file_path): + return jsonify({'error': '图片不存在'}), 404 + + return send_file(file_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取加噪图片失败: {str(e)}'}), 500 + +@demo_bp.route('/image/comparison/', methods=['GET']) +def get_demo_comparison_image(filename): + """获取演示对比图片""" + try: + project_root = os.path.dirname(current_app.root_path) + file_path = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'], filename) + + if not os.path.exists(file_path): + return jsonify({'error': '图片不存在'}), 404 + + return send_file(file_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取对比图片失败: {str(e)}'}), 500 + +@demo_bp.route('/algorithms', methods=['GET']) +def get_demo_algorithms(): + """获取演示算法信息""" + try: + # 从数据库获取扰动算法 + perturbation_algorithms = [] + perturbation_configs = PerturbationConfig.query.all() + for config in perturbation_configs: + perturbation_algorithms.append({ + 'id': config.id, + 'code': config.method_code, + 'name': config.method_name, + 'type': 'perturbation', + 'description': config.description, + 'default_epsilon': float(config.default_epsilon) if config.default_epsilon else None + }) + + # 从数据库获取微调算法 + finetune_algorithms = [] + finetune_configs = FinetuneConfig.query.all() + for config in finetune_configs: + finetune_algorithms.append({ + 'id': config.id, + 'code': config.method_code, + 'name': config.method_name, + 'type': 'finetune', + 'description': config.description + }) + + return jsonify({ + 'perturbation_algorithms': perturbation_algorithms, + 'finetune_algorithms': finetune_algorithms, + 'evaluation_metrics': [ + {'name': 'FID', 'description': 'Fréchet Inception Distance - 衡量图像质量的指标'}, + {'name': 'LPIPS', 'description': 'Learned Perceptual Image Patch Similarity - 感知相似度'}, + {'name': 'SSIM', 'description': 'Structural Similarity Index - 结构相似性指标'}, + {'name': 'PSNR', 'description': 'Peak Signal-to-Noise Ratio - 峰值信噪比'} + ] + }), 200 + + except Exception as e: + return jsonify({'error': f'获取算法信息失败: {str(e)}'}), 500 + +@demo_bp.route('/stats', methods=['GET']) +def get_demo_stats(): + """获取演示统计信息""" + try: + # 统计演示图片数量 + project_root = os.path.dirname(current_app.root_path) + original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER']) + perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER']) + comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER']) + + original_count = len(glob.glob(os.path.join(original_folder, '*'))) if os.path.exists(original_folder) else 0 + perturbed_count = len(glob.glob(os.path.join(perturbed_folder, '*'))) if os.path.exists(perturbed_folder) else 0 + comparison_count = len(glob.glob(os.path.join(comparison_folder, '*'))) if os.path.exists(comparison_folder) else 0 + + # 统计数据库中的算法数量 + perturbation_count = PerturbationConfig.query.count() + finetune_count = FinetuneConfig.query.count() + total_algorithms = perturbation_count + finetune_count + + return jsonify({ + 'demo_stats': { + 'original_images': original_count, + 'perturbed_images': perturbed_count, + 'comparison_images': comparison_count, + 'supported_algorithms': total_algorithms, + 'perturbation_algorithms': perturbation_count, + 'finetune_algorithms': finetune_count, + 'evaluation_metrics': 4 + } + }), 200 + + except Exception as e: + return jsonify({'error': f'获取统计信息失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/controllers/image_controller.py b/src/backend/app/controllers/image_controller.py new file mode 100644 index 0000000..f47fe98 --- /dev/null +++ b/src/backend/app/controllers/image_controller.py @@ -0,0 +1,203 @@ +""" +图像管理控制器 +处理图像下载、查看等功能 +""" + +from flask import Blueprint, send_file, jsonify, request, current_app +from flask_jwt_extended import jwt_required, get_jwt_identity +from app.models import Image, EvaluationResult +from app.services.image_service import ImageService +import os + +image_bp = Blueprint('image', __name__) + +@image_bp.route('/file/', methods=['GET']) +@jwt_required() +def get_image_file(image_id): + """获取图片文件""" + try: + current_user_id = get_jwt_identity() + + # 查找图片记录 + image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() + if not image: + return jsonify({'error': '图片不存在或无权限'}), 404 + + # 检查文件是否存在 + if not os.path.exists(image.file_path): + return jsonify({'error': '图片文件不存在'}), 404 + + return send_file(image.file_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取图片失败: {str(e)}'}), 500 + +@image_bp.route('/download/', methods=['GET']) +@jwt_required() +def download_image(image_id): + """下载图片文件""" + try: + current_user_id = get_jwt_identity() + + image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() + if not image: + return jsonify({'error': '图片不存在或无权限'}), 404 + + if not os.path.exists(image.file_path): + return jsonify({'error': '图片文件不存在'}), 404 + + return send_file( + image.file_path, + as_attachment=True, + download_name=image.original_filename or f"image_{image_id}.jpg" + ) + + except Exception as e: + return jsonify({'error': f'下载图片失败: {str(e)}'}), 500 + +@image_bp.route('/batch//download', methods=['GET']) +@jwt_required() +def download_batch_images(batch_id): + """批量下载任务中的加噪后图片""" + try: + current_user_id = get_jwt_identity() + + # 获取任务中的加噪图片 + perturbed_images = Image.query.join(Image.image_type).filter( + Image.batch_id == batch_id, + Image.user_id == current_user_id, + Image.image_type.has(type_code='perturbed') + ).all() + + if not perturbed_images: + return jsonify({'error': '没有找到加噪后的图片'}), 404 + + # 创建ZIP文件 + import zipfile + import tempfile + + with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file: + with zipfile.ZipFile(tmp_file.name, 'w') as zip_file: + for image in perturbed_images: + if os.path.exists(image.file_path): + arcname = image.original_filename or f"perturbed_{image.id}.jpg" + zip_file.write(image.file_path, arcname) + + return send_file( + tmp_file.name, + as_attachment=True, + download_name=f"batch_{batch_id}_perturbed_images.zip", + mimetype='application/zip' + ) + + except Exception as e: + return jsonify({'error': f'批量下载失败: {str(e)}'}), 500 + +@image_bp.route('//evaluations', methods=['GET']) +@jwt_required() +def get_image_evaluations(image_id): + """获取图片的评估结果""" + try: + current_user_id = get_jwt_identity() + + # 验证图片权限 + image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() + if not image: + return jsonify({'error': '图片不存在或无权限'}), 404 + + # 获取以该图片为参考或目标的评估结果 + evaluations = EvaluationResult.query.filter( + (EvaluationResult.reference_image_id == image_id) | + (EvaluationResult.target_image_id == image_id) + ).all() + + return jsonify({ + 'image_id': image_id, + 'evaluations': [eval_result.to_dict() for eval_result in evaluations] + }), 200 + + except Exception as e: + return jsonify({'error': f'获取评估结果失败: {str(e)}'}), 500 + +@image_bp.route('/compare', methods=['POST']) +@jwt_required() +def compare_images(): + """对比两张图片""" + try: + current_user_id = get_jwt_identity() + data = request.get_json() + + image1_id = data.get('image1_id') + image2_id = data.get('image2_id') + + if not image1_id or not image2_id: + return jsonify({'error': '请提供两张图片的ID'}), 400 + + # 验证图片权限 + image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first() + image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first() + + if not image1 or not image2: + return jsonify({'error': '图片不存在或无权限'}), 404 + + # 查找现有的评估结果 + evaluation = EvaluationResult.query.filter_by( + reference_image_id=image1_id, + target_image_id=image2_id + ).first() + + if not evaluation: + # 如果没有评估结果,返回基本对比信息 + return jsonify({ + 'image1': image1.to_dict(), + 'image2': image2.to_dict(), + 'evaluation': None, + 'message': '暂无评估数据,请等待任务处理完成' + }), 200 + + return jsonify({ + 'image1': image1.to_dict(), + 'image2': image2.to_dict(), + 'evaluation': evaluation.to_dict() + }), 200 + + except Exception as e: + return jsonify({'error': f'图片对比失败: {str(e)}'}), 500 + +@image_bp.route('/heatmap/', methods=['GET']) +@jwt_required() +def get_heatmap(heatmap_path): + """获取热力图文件""" + try: + # 安全检查,防止路径遍历攻击 + if '..' in heatmap_path or heatmap_path.startswith('/'): + return jsonify({'error': '无效的文件路径'}), 400 + + # 修正路径构建 - 获取项目根目录(backend目录) + project_root = os.path.dirname(current_app.root_path) + full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path)) + + if not os.path.exists(full_path): + return jsonify({'error': '热力图文件不存在'}), 404 + + return send_file(full_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取热力图失败: {str(e)}'}), 500 + +@image_bp.route('/delete/', methods=['DELETE']) +@jwt_required() +def delete_image(image_id): + """删除图片""" + try: + current_user_id = get_jwt_identity() + + result = ImageService.delete_image(image_id, current_user_id) + + if result['success']: + return jsonify({'message': '图片删除成功'}), 200 + else: + return jsonify({'error': result['error']}), 400 + + except Exception as e: + return jsonify({'error': f'删除图片失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py new file mode 100644 index 0000000..3cb8f73 --- /dev/null +++ b/src/backend/app/controllers/task_controller.py @@ -0,0 +1,400 @@ +""" +任务管理控制器 +处理创建任务、上传图片等功能 +""" + +from flask import Blueprint, request, jsonify, current_app +from flask_jwt_extended import jwt_required, get_jwt_identity +from werkzeug.utils import secure_filename +from app import db +from app.models import User, Batch, Image, ImageType, UserConfig +from app.services.task_service import TaskService +from app.services.image_service import ImageService +from app.utils.file_utils import allowed_file, save_uploaded_file +import os +import zipfile +import uuid + +task_bp = Blueprint('task', __name__) + +@task_bp.route('/create', methods=['POST']) +@jwt_required() +def create_task(): + """创建新任务(仅创建任务,使用默认配置)""" + try: + current_user_id = get_jwt_identity() + user = User.query.get(current_user_id) + + if not user: + return jsonify({'error': '用户不存在'}), 404 + + data = request.get_json() + batch_name = data.get('batch_name', f'Task-{uuid.uuid4().hex[:8]}') + + # 使用默认配置创建任务 + batch = Batch( + user_id=current_user_id, + batch_name=batch_name, + perturbation_config_id=1, # 默认配置 + preferred_epsilon=8.0, # 默认epsilon + finetune_config_id=1, # 默认微调配置 + use_strong_protection=False # 默认不启用强防护 + ) + + db.session.add(batch) + db.session.commit() + + return jsonify({ + 'message': '任务创建成功,请上传图片', + 'task': batch.to_dict() + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'任务创建失败: {str(e)}'}), 500 + +@task_bp.route('/upload/', methods=['POST']) +@jwt_required() +def upload_images(batch_id): + """上传图片到指定任务""" + try: + current_user_id = get_jwt_identity() + + # 检查任务是否存在且属于当前用户 + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + if batch.status != 'pending': + return jsonify({'error': '任务已开始处理,无法上传新图片'}), 400 + + if 'files' not in request.files: + return jsonify({'error': '没有选择文件'}), 400 + + files = request.files.getlist('files') + uploaded_files = [] + + # 获取原始图片类型ID + original_type = ImageType.query.filter_by(type_code='original').first() + if not original_type: + return jsonify({'error': '系统配置错误:缺少原始图片类型'}), 500 + + for file in files: + if file.filename == '': + continue + + if file and allowed_file(file.filename): + # 处理单张图片 + if not file.filename.lower().endswith(('.zip', '.rar')): + result = ImageService.save_image(file, batch_id, current_user_id, original_type.id) + if result['success']: + uploaded_files.append(result['image']) + else: + return jsonify({'error': result['error']}), 400 + + # 处理压缩包 + else: + results = ImageService.extract_and_save_zip(file, batch_id, current_user_id, original_type.id) + for result in results: + if result['success']: + uploaded_files.append(result['image']) + + if not uploaded_files: + return jsonify({'error': '没有有效的图片文件'}), 400 + + return jsonify({ + 'message': f'成功上传 {len(uploaded_files)} 张图片', + 'uploaded_files': [img.to_dict() for img in uploaded_files] + }), 200 + + except Exception as e: + return jsonify({'error': f'文件上传失败: {str(e)}'}), 500 + +@task_bp.route('//config', methods=['GET']) +@jwt_required() +def get_task_config(batch_id): + """获取任务配置(显示用户上次的配置或默认配置)""" + try: + current_user_id = get_jwt_identity() + + # 检查任务是否存在且属于当前用户 + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + # 获取用户配置 + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + + # 如果用户有配置,显示用户上次的配置;否则显示当前任务的默认配置 + if user_config: + suggested_config = { + 'perturbation_config_id': user_config.preferred_perturbation_config_id, + 'epsilon': float(user_config.preferred_epsilon), + 'finetune_config_id': user_config.preferred_finetune_config_id, + 'use_strong_protection': user_config.preferred_purification + } + else: + suggested_config = { + 'perturbation_config_id': batch.perturbation_config_id, + 'epsilon': float(batch.preferred_epsilon), + 'finetune_config_id': batch.finetune_config_id, + 'use_strong_protection': batch.use_strong_protection + } + + return jsonify({ + 'task': batch.to_dict(), + 'suggested_config': suggested_config, + 'current_config': { + 'perturbation_config_id': batch.perturbation_config_id, + 'epsilon': float(batch.preferred_epsilon), + 'finetune_config_id': batch.finetune_config_id, + 'use_strong_protection': batch.use_strong_protection + } + }), 200 + + except Exception as e: + return jsonify({'error': f'获取任务配置失败: {str(e)}'}), 500 + +@task_bp.route('//config', methods=['PUT']) +@jwt_required() +def update_task_config(batch_id): + """更新任务配置""" + try: + current_user_id = get_jwt_identity() + + # 检查任务是否存在且属于当前用户 + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + if batch.status != 'pending': + return jsonify({'error': '任务已开始处理,无法修改配置'}), 400 + + data = request.get_json() + + # 更新任务配置 + if 'perturbation_config_id' in data: + batch.perturbation_config_id = data['perturbation_config_id'] + + if 'epsilon' in data: + epsilon = float(data['epsilon']) + if 0 < epsilon <= 255: + batch.preferred_epsilon = epsilon + else: + return jsonify({'error': '扰动强度必须在0-255之间'}), 400 + + if 'finetune_config_id' in data: + batch.finetune_config_id = data['finetune_config_id'] + + if 'use_strong_protection' in data: + batch.use_strong_protection = bool(data['use_strong_protection']) + + db.session.commit() + + # 更新用户配置(保存这次的选择作为下次的默认配置) + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + if not user_config: + user_config = UserConfig(user_id=current_user_id) + db.session.add(user_config) + + user_config.preferred_perturbation_config_id = batch.perturbation_config_id + user_config.preferred_epsilon = batch.preferred_epsilon + user_config.preferred_finetune_config_id = batch.finetune_config_id + user_config.preferred_purification = batch.use_strong_protection + + db.session.commit() + + return jsonify({ + 'message': '任务配置更新成功', + 'task': batch.to_dict() + }), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'更新任务配置失败: {str(e)}'}), 500 + +@task_bp.route('/start/', methods=['POST']) +@jwt_required() +def start_task(batch_id): + """开始处理任务""" + try: + current_user_id = get_jwt_identity() + + # 检查任务是否存在且属于当前用户 + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + if batch.status != 'pending': + return jsonify({'error': '任务状态不正确,无法开始处理'}), 400 + + # 检查是否有上传的图片 + image_count = Image.query.filter_by(batch_id=batch_id).count() + if image_count == 0: + return jsonify({'error': '请先上传图片'}), 400 + + # 启动任务处理 + success = TaskService.start_processing(batch) + + if success: + return jsonify({ + 'message': '任务开始处理', + 'task': batch.to_dict() + }), 200 + else: + return jsonify({'error': '任务启动失败'}), 500 + + except Exception as e: + return jsonify({'error': f'任务启动失败: {str(e)}'}), 500 + +@task_bp.route('/list', methods=['GET']) +@jwt_required() +def list_tasks(): + """获取用户的任务列表""" + try: + current_user_id = get_jwt_identity() + + page = request.args.get('page', 1, type=int) + per_page = request.args.get('per_page', 10, type=int) + + batches = Batch.query.filter_by(user_id=current_user_id)\ + .order_by(Batch.created_at.desc())\ + .paginate(page=page, per_page=per_page, error_out=False) + + return jsonify({ + 'tasks': [batch.to_dict() for batch in batches.items], + 'total': batches.total, + 'pages': batches.pages, + 'current_page': page + }), 200 + + except Exception as e: + return jsonify({'error': f'获取任务列表失败: {str(e)}'}), 500 + +@task_bp.route('/', methods=['GET']) +@jwt_required() +def get_task_detail(batch_id): + """获取任务详情""" + try: + current_user_id = get_jwt_identity() + + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + # 获取任务相关的图片 + images = Image.query.filter_by(batch_id=batch_id).all() + + return jsonify({ + 'task': batch.to_dict(), + 'images': [img.to_dict() for img in images], + 'image_count': len(images) + }), 200 + + except Exception as e: + return jsonify({'error': f'获取任务详情失败: {str(e)}'}), 500 + +@task_bp.route('/load-config', methods=['GET']) +@jwt_required() +def load_last_config(): + """加载用户上次的配置""" + try: + current_user_id = get_jwt_identity() + + # 获取用户配置 + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + + if user_config: + config = { + 'perturbation_config_id': user_config.preferred_perturbation_config_id, + 'epsilon': float(user_config.preferred_epsilon), + 'finetune_config_id': user_config.preferred_finetune_config_id, + 'use_strong_protection': user_config.preferred_purification + } + return jsonify({ + 'message': '成功加载上次配置', + 'config': config + }), 200 + else: + # 返回默认配置 + default_config = { + 'perturbation_config_id': 1, + 'epsilon': 8.0, + 'finetune_config_id': 1, + 'use_strong_protection': False + } + return jsonify({ + 'message': '使用默认配置', + 'config': default_config + }), 200 + + except Exception as e: + return jsonify({'error': f'加载配置失败: {str(e)}'}), 500 + +@task_bp.route('/save-config', methods=['POST']) +@jwt_required() +def save_current_config(): + """保存当前配置作为用户偏好""" + try: + current_user_id = get_jwt_identity() + data = request.get_json() + + # 获取或创建用户配置 + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + if not user_config: + user_config = UserConfig(user_id=current_user_id) + db.session.add(user_config) + + # 更新配置 + if 'perturbation_config_id' in data: + user_config.preferred_perturbation_config_id = data['perturbation_config_id'] + + if 'epsilon' in data: + epsilon = float(data['epsilon']) + if 0 < epsilon <= 255: + user_config.preferred_epsilon = epsilon + else: + return jsonify({'error': '扰动强度必须在0-255之间'}), 400 + + if 'finetune_config_id' in data: + user_config.preferred_finetune_config_id = data['finetune_config_id'] + + if 'use_strong_protection' in data: + user_config.preferred_purification = bool(data['use_strong_protection']) + + db.session.commit() + + return jsonify({ + 'message': '配置保存成功', + 'config': { + 'perturbation_config_id': user_config.preferred_perturbation_config_id, + 'epsilon': float(user_config.preferred_epsilon), + 'finetune_config_id': user_config.preferred_finetune_config_id, + 'use_strong_protection': user_config.preferred_purification + } + }), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'保存配置失败: {str(e)}'}), 500 + +@task_bp.route('//status', methods=['GET']) +@jwt_required() +def get_task_status(batch_id): + """获取任务处理状态""" + try: + current_user_id = get_jwt_identity() + + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + return jsonify({ + 'task_id': batch_id, + 'status': batch.status, + 'progress': TaskService.get_processing_progress(batch_id), + 'error_message': batch.error_message + }), 200 + + except Exception as e: + return jsonify({'error': f'获取任务状态失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/controllers/user_controller.py b/src/backend/app/controllers/user_controller.py new file mode 100644 index 0000000..b66b501 --- /dev/null +++ b/src/backend/app/controllers/user_controller.py @@ -0,0 +1,133 @@ +""" +用户管理控制器 +处理用户配置等功能 +""" + +from flask import Blueprint, request, jsonify +from flask_jwt_extended import jwt_required +from app import db +from app.models import User, UserConfig, PerturbationConfig, FinetuneConfig +from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器 + +user_bp = Blueprint('user', __name__) + +@user_bp.route('/config', methods=['GET']) +@int_jwt_required +def get_user_config(current_user_id): + """获取用户配置""" + try: + + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + + if not user_config: + # 如果没有配置,创建默认配置 + user_config = UserConfig(user_id=current_user_id) + db.session.add(user_config) + db.session.commit() + + return jsonify({ + 'config': user_config.to_dict() + }), 200 + + except Exception as e: + return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500 + +@user_bp.route('/config', methods=['PUT']) +@int_jwt_required +def update_user_config(current_user_id): + """更新用户配置""" + try: + data = request.get_json() + + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + + if not user_config: + user_config = UserConfig(user_id=current_user_id) + db.session.add(user_config) + + # 更新配置字段 + if 'preferred_perturbation_config_id' in data: + user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id'] + + if 'preferred_epsilon' in data: + epsilon = float(data['preferred_epsilon']) + if 0 < epsilon <= 255: + user_config.preferred_epsilon = epsilon + else: + return jsonify({'error': '扰动强度必须在0-255之间'}), 400 + + if 'preferred_finetune_config_id' in data: + user_config.preferred_finetune_config_id = data['preferred_finetune_config_id'] + + if 'preferred_purification' in data: + user_config.preferred_purification = bool(data['preferred_purification']) + + db.session.commit() + + return jsonify({ + 'message': '用户配置更新成功', + 'config': user_config.to_dict() + }), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500 + +@user_bp.route('/algorithms', methods=['GET']) +@jwt_required() +def get_available_algorithms(): + """获取可用的算法列表""" + try: + perturbation_configs = PerturbationConfig.query.all() + finetune_configs = FinetuneConfig.query.all() + + return jsonify({ + 'perturbation_algorithms': [ + { + 'id': config.id, + 'method_code': config.method_code, + 'method_name': config.method_name, + 'description': config.description, + 'default_epsilon': float(config.default_epsilon) + } for config in perturbation_configs + ], + 'finetune_methods': [ + { + 'id': config.id, + 'method_code': config.method_code, + 'method_name': config.method_name, + 'description': config.description + } for config in finetune_configs + ] + }), 200 + + except Exception as e: + return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500 + +@user_bp.route('/stats', methods=['GET']) +@int_jwt_required +def get_user_stats(current_user_id): + """获取用户统计信息""" + try: + from app.models import Batch, Image + + # 统计用户的任务和图片数量 + total_tasks = Batch.query.filter_by(user_id=current_user_id).count() + completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count() + processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count() + failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count() + + total_images = Image.query.filter_by(user_id=current_user_id).count() + + return jsonify({ + 'stats': { + 'total_tasks': total_tasks, + 'completed_tasks': completed_tasks, + 'processing_tasks': processing_tasks, + 'failed_tasks': failed_tasks, + 'total_images': total_images + } + }), 200 + + except Exception as e: + return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/models/__init__.py b/src/backend/app/models/__init__.py new file mode 100644 index 0000000..d92e3e8 --- /dev/null +++ b/src/backend/app/models/__init__.py @@ -0,0 +1,233 @@ +""" +数据库模型定义 +基于已有的schema.sql设计 +""" + +from datetime import datetime +from app import db +from werkzeug.security import generate_password_hash, check_password_hash +from enum import Enum as PyEnum + +class User(db.Model): + """用户表""" + __tablename__ = 'users' + + id = db.Column(db.BigInteger, primary_key=True) + username = db.Column(db.String(50), unique=True, nullable=False) + password_hash = db.Column(db.String(255), nullable=False) + email = db.Column(db.String(100)) + role = db.Column(db.Enum('user', 'admin'), default='user') + max_concurrent_tasks = db.Column(db.Integer, nullable=False, default=0) + created_at = db.Column(db.DateTime, default=datetime.utcnow) + updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + is_active = db.Column(db.Boolean, default=True) + + # 关系 + batches = db.relationship('Batch', backref='user', lazy='dynamic', cascade='all, delete-orphan') + images = db.relationship('Image', backref='user', lazy='dynamic', cascade='all, delete-orphan') + user_config = db.relationship('UserConfig', backref='user', uselist=False, cascade='all, delete-orphan') + + def set_password(self, password): + """设置密码""" + self.password_hash = generate_password_hash(password) + + def check_password(self, password): + """验证密码""" + return check_password_hash(self.password_hash, password) + + def to_dict(self): + """转换为字典""" + return { + 'id': self.id, + 'username': self.username, + 'email': self.email, + 'role': self.role, + 'max_concurrent_tasks': self.max_concurrent_tasks, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'is_active': self.is_active + } + +class ImageType(db.Model): + """图片类型表""" + __tablename__ = 'image_types' + + id = db.Column(db.BigInteger, primary_key=True) + type_code = db.Column(db.Enum('original', 'perturbed', 'original_generate', 'perturbed_generate'), + unique=True, nullable=False) + type_name = db.Column(db.String(100), nullable=False) + description = db.Column(db.Text) + + # 关系 + images = db.relationship('Image', backref='image_type', lazy='dynamic') + +class PerturbationConfig(db.Model): + """加噪算法表""" + __tablename__ = 'perturbation_configs' + + id = db.Column(db.BigInteger, primary_key=True) + method_code = db.Column(db.String(50), unique=True, nullable=False) + method_name = db.Column(db.String(100), nullable=False) + description = db.Column(db.Text, nullable=False) + default_epsilon = db.Column(db.Numeric(5, 2), nullable=False) + + # 关系 + batches = db.relationship('Batch', backref='perturbation_config', lazy='dynamic') + user_configs = db.relationship('UserConfig', backref='preferred_perturbation_config', lazy='dynamic') + +class FinetuneConfig(db.Model): + """微调方式表""" + __tablename__ = 'finetune_configs' + + id = db.Column(db.BigInteger, primary_key=True) + method_code = db.Column(db.String(50), unique=True, nullable=False) + method_name = db.Column(db.String(100), nullable=False) + description = db.Column(db.Text, nullable=False) + + # 关系 + batches = db.relationship('Batch', backref='finetune_config', lazy='dynamic') + user_configs = db.relationship('UserConfig', backref='preferred_finetune_config', lazy='dynamic') + +class Batch(db.Model): + """加噪批次表""" + __tablename__ = 'batch' + + id = db.Column(db.BigInteger, primary_key=True) + user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False) + batch_name = db.Column(db.String(128)) + + # 加噪配置 + perturbation_config_id = db.Column(db.BigInteger, db.ForeignKey('perturbation_configs.id'), + nullable=False, default=1) + preferred_epsilon = db.Column(db.Numeric(5, 2), nullable=False, default=8.0) + + # 评估配置 + finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id'), + nullable=False, default=1) + + # 净化配置 + use_strong_protection = db.Column(db.Boolean, nullable=False, default=False) + + # 任务状态 + status = db.Column(db.Enum('pending', 'processing', 'completed', 'failed'), default='pending') + created_at = db.Column(db.DateTime, default=datetime.utcnow) + started_at = db.Column(db.DateTime) + completed_at = db.Column(db.DateTime) + error_message = db.Column(db.Text) + result_path = db.Column(db.String(500)) + + # 关系 + images = db.relationship('Image', backref='batch', lazy='dynamic') + + def to_dict(self): + """转换为字典""" + return { + 'id': self.id, + 'batch_name': self.batch_name, + 'status': self.status, + 'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None, + 'use_strong_protection': self.use_strong_protection, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'started_at': self.started_at.isoformat() if self.started_at else None, + 'completed_at': self.completed_at.isoformat() if self.completed_at else None, + 'error_message': self.error_message, + 'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None, + 'finetune_config': self.finetune_config.method_name if self.finetune_config else None + } + +class Image(db.Model): + """图片表""" + __tablename__ = 'images' + + id = db.Column(db.BigInteger, primary_key=True) + user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False) + batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id')) + father_id = db.Column(db.BigInteger, db.ForeignKey('images.id')) + original_filename = db.Column(db.String(255)) + stored_filename = db.Column(db.String(255), unique=True, nullable=False) + file_path = db.Column(db.String(500), nullable=False) + file_size = db.Column(db.BigInteger) + image_type_id = db.Column(db.BigInteger, db.ForeignKey('image_types.id'), nullable=False) + width = db.Column(db.Integer) + height = db.Column(db.Integer) + upload_time = db.Column(db.DateTime, default=datetime.utcnow) + + # 自引用关系 + children = db.relationship('Image', backref=db.backref('parent', remote_side=[id]), lazy='dynamic') + + # 评估结果关系 + reference_evaluations = db.relationship('EvaluationResult', + foreign_keys='EvaluationResult.reference_image_id', + backref='reference_image', lazy='dynamic') + target_evaluations = db.relationship('EvaluationResult', + foreign_keys='EvaluationResult.target_image_id', + backref='target_image', lazy='dynamic') + + def to_dict(self): + """转换为字典""" + return { + 'id': self.id, + 'original_filename': self.original_filename, + 'stored_filename': self.stored_filename, + 'file_path': self.file_path, + 'file_size': self.file_size, + 'width': self.width, + 'height': self.height, + 'upload_time': self.upload_time.isoformat() if self.upload_time else None, + 'image_type': self.image_type.type_name if self.image_type else None, + 'batch_id': self.batch_id + } + +class EvaluationResult(db.Model): + """评估结果表""" + __tablename__ = 'evaluation_results' + + id = db.Column(db.BigInteger, primary_key=True) + reference_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False) + target_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False) + evaluation_type = db.Column(db.Enum('image_quality', 'model_generation'), nullable=False) + purification_applied = db.Column(db.Boolean, default=False) + fid_score = db.Column(db.Numeric(8, 4)) + lpips_score = db.Column(db.Numeric(8, 4)) + ssim_score = db.Column(db.Numeric(8, 4)) + psnr_score = db.Column(db.Numeric(8, 4)) + heatmap_path = db.Column(db.String(500)) + evaluated_at = db.Column(db.DateTime, default=datetime.utcnow) + + def to_dict(self): + """转换为字典""" + return { + 'id': self.id, + 'evaluation_type': self.evaluation_type, + 'purification_applied': self.purification_applied, + 'fid_score': float(self.fid_score) if self.fid_score else None, + 'lpips_score': float(self.lpips_score) if self.lpips_score else None, + 'ssim_score': float(self.ssim_score) if self.ssim_score else None, + 'psnr_score': float(self.psnr_score) if self.psnr_score else None, + 'heatmap_path': self.heatmap_path, + 'evaluated_at': self.evaluated_at.isoformat() if self.evaluated_at else None + } + +class UserConfig(db.Model): + """用户配置表""" + __tablename__ = 'user_configs' + + user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), primary_key=True) + preferred_perturbation_config_id = db.Column(db.BigInteger, + db.ForeignKey('perturbation_configs.id'), default=1) + preferred_epsilon = db.Column(db.Numeric(5, 2), default=8.0) + preferred_finetune_config_id = db.Column(db.BigInteger, + db.ForeignKey('finetune_configs.id'), default=1) + preferred_purification = db.Column(db.Boolean, default=False) + created_at = db.Column(db.DateTime, default=datetime.utcnow) + updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + def to_dict(self): + """转换为字典""" + return { + 'user_id': self.user_id, + 'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None, + 'preferred_purification': self.preferred_purification, + 'preferred_perturbation_config': self.preferred_perturbation_config.method_name if self.preferred_perturbation_config else None, + 'preferred_finetune_config': self.preferred_finetune_config.method_name if self.preferred_finetune_config else None, + 'updated_at': self.updated_at.isoformat() if self.updated_at else None + } \ No newline at end of file diff --git a/src/backend/app/services/auth_service.py b/src/backend/app/services/auth_service.py new file mode 100644 index 0000000..63a95c1 --- /dev/null +++ b/src/backend/app/services/auth_service.py @@ -0,0 +1,34 @@ +""" +认证服务 +处理用户认证相关逻辑 +""" + +from app.models import User + +class AuthService: + """认证服务类""" + + @staticmethod + def authenticate_user(username, password): + """验证用户凭据""" + user = User.query.filter_by(username=username).first() + + if user and user.check_password(password) and user.is_active: + return user + + return None + + @staticmethod + def get_user_by_id(user_id): + """根据ID获取用户""" + return User.query.get(user_id) + + @staticmethod + def is_email_available(email): + """检查邮箱是否可用""" + return User.query.filter_by(email=email).first() is None + + @staticmethod + def is_username_available(username): + """检查用户名是否可用""" + return User.query.filter_by(username=username).first() is None \ No newline at end of file diff --git a/src/backend/app/services/image_service.py b/src/backend/app/services/image_service.py new file mode 100644 index 0000000..914533d --- /dev/null +++ b/src/backend/app/services/image_service.py @@ -0,0 +1,161 @@ +""" +图像处理服务 +处理图像上传、保存等功能 +""" + +import os +import uuid +import zipfile +from werkzeug.utils import secure_filename +from flask import current_app +from PIL import Image as PILImage +from app import db +from app.models import Image +from app.utils.file_utils import allowed_file + +class ImageService: + """图像处理服务""" + + @staticmethod + def save_image(file, batch_id, user_id, image_type_id): + """保存单张图片""" + try: + if not file or not allowed_file(file.filename): + return {'success': False, 'error': '不支持的文件格式'} + + # 生成唯一文件名 + file_extension = os.path.splitext(file.filename)[1].lower() + stored_filename = f"{uuid.uuid4().hex}{file_extension}" + + # 临时保存到上传目录 + project_root = os.path.dirname(current_app.root_path) + temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id)) + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, stored_filename) + file.save(temp_path) + + # 移动到对应的静态目录 + static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(batch_id)) + os.makedirs(static_dir, exist_ok=True) + file_path = os.path.join(static_dir, stored_filename) + + # 移动文件到最终位置 + import shutil + shutil.move(temp_path, file_path) + + # 获取图片尺寸 + try: + with PILImage.open(file_path) as img: + width, height = img.size + except: + width, height = None, None + + # 创建数据库记录 + image = Image( + user_id=user_id, + batch_id=batch_id, + original_filename=file.filename, + stored_filename=stored_filename, + file_path=file_path, + file_size=os.path.getsize(file_path), + image_type_id=image_type_id, + width=width, + height=height + ) + + db.session.add(image) + db.session.commit() + + return {'success': True, 'image': image} + + except Exception as e: + db.session.rollback() + return {'success': False, 'error': f'保存图片失败: {str(e)}'} + + @staticmethod + def extract_and_save_zip(zip_file, batch_id, user_id, image_type_id): + """解压并保存压缩包中的图片""" + results = [] + temp_dir = None + + try: + # 创建临时目录 + project_root = os.path.dirname(current_app.root_path) + temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], 'temp', f"{uuid.uuid4().hex}") + os.makedirs(temp_dir, exist_ok=True) + + # 保存压缩包 + zip_path = os.path.join(temp_dir, secure_filename(zip_file.filename)) + zip_file.save(zip_path) + + # 解压文件 + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + + # 遍历解压的文件 + for root, dirs, files in os.walk(temp_dir): + for filename in files: + if filename.lower().endswith(('.zip', '.rar')): + continue # 跳过压缩包文件本身 + + if allowed_file(filename): + file_path = os.path.join(root, filename) + + # 创建虚拟文件对象 + class FileWrapper: + def __init__(self, path, name): + self.path = path + self.filename = name + + def save(self, destination): + import shutil + shutil.copy2(self.path, destination) + + virtual_file = FileWrapper(file_path, filename) + result = ImageService.save_image(virtual_file, batch_id, user_id, image_type_id) + results.append(result) + + return results + + except Exception as e: + return [{'success': False, 'error': f'解压失败: {str(e)}'}] + + finally: + # 清理临时文件 + if temp_dir and os.path.exists(temp_dir): + import shutil + try: + shutil.rmtree(temp_dir) + except: + pass + + @staticmethod + def get_image_url(image): + """获取图片访问URL""" + if not image or not image.file_path: + return None + + # 这里返回相对路径,前端可以拼接完整URL + return f"/api/image/file/{image.id}" + + @staticmethod + def delete_image(image_id, user_id): + """删除图片""" + try: + image = Image.query.filter_by(id=image_id, user_id=user_id).first() + if not image: + return {'success': False, 'error': '图片不存在或无权限'} + + # 删除文件 + if os.path.exists(image.file_path): + os.remove(image.file_path) + + # 删除数据库记录 + db.session.delete(image) + db.session.commit() + + return {'success': True} + + except Exception as e: + db.session.rollback() + return {'success': False, 'error': f'删除图片失败: {str(e)}'} \ No newline at end of file diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py new file mode 100644 index 0000000..b9b00e0 --- /dev/null +++ b/src/backend/app/services/task_service.py @@ -0,0 +1,191 @@ +""" +任务处理服务 +处理图像加噪、评估等核心业务逻辑 +""" + +import os +import time +from datetime import datetime +from flask import current_app +from app import db +from app.models import Batch, Image, EvaluationResult, ImageType +from app.algorithms.perturbation_engine import PerturbationEngine +from app.algorithms.evaluation_engine import EvaluationEngine + +class TaskService: + """任务处理服务""" + + @staticmethod + def start_processing(batch): + """开始处理任务""" + try: + # 更新任务状态 + batch.status = 'processing' + batch.started_at = datetime.utcnow() + db.session.commit() + + # 获取任务相关的原始图片 + original_images = Image.query.filter_by( + batch_id=batch.id + ).join(ImageType).filter( + ImageType.type_code == 'original' + ).all() + + if not original_images: + batch.status = 'failed' + batch.error_message = '没有找到原始图片' + batch.completed_at = datetime.utcnow() + db.session.commit() + return False + + # 处理每张图片 + perturbed_type = ImageType.query.filter_by(type_code='perturbed').first() + + processed_images = [] + for original_image in original_images: + try: + # 确定加噪图片的保存路径 + project_root = os.path.dirname(current_app.root_path) + perturbed_dir = os.path.join(project_root, + current_app.config['PERTURBED_IMAGES_FOLDER'], + str(batch.user_id), + str(batch.id)) + os.makedirs(perturbed_dir, exist_ok=True) + + # 调用加噪算法 + perturbed_image_path = PerturbationEngine.apply_perturbation( + original_image.file_path, + batch.perturbation_config.method_code, + float(batch.preferred_epsilon), + batch.use_strong_protection + ) + + if perturbed_image_path: + # 保存加噪后的图片记录 + perturbed_image = Image( + user_id=batch.user_id, + batch_id=batch.id, + father_id=original_image.id, + original_filename=f"perturbed_{original_image.original_filename}", + stored_filename=os.path.basename(perturbed_image_path), + file_path=perturbed_image_path, + file_size=os.path.getsize(perturbed_image_path) if os.path.exists(perturbed_image_path) else 0, + image_type_id=perturbed_type.id, + width=original_image.width, + height=original_image.height + ) + + db.session.add(perturbed_image) + processed_images.append((original_image, perturbed_image)) + + except Exception as e: + print(f"处理图片 {original_image.id} 时出错: {str(e)}") + continue + + # 提交加噪后的图片 + db.session.commit() + + # 生成评估结果 + TaskService._generate_evaluations(batch, processed_images) + + # 更新任务状态为完成 + batch.status = 'completed' + batch.completed_at = datetime.utcnow() + db.session.commit() + + return True + + except Exception as e: + # 处理失败 + batch.status = 'failed' + batch.error_message = str(e) + batch.completed_at = datetime.utcnow() + db.session.commit() + return False + + @staticmethod + def _generate_evaluations(batch, processed_images): + """生成评估结果""" + try: + for original_image, perturbed_image in processed_images: + # 图像质量对比评估 + quality_metrics = EvaluationEngine.evaluate_image_quality( + original_image.file_path, + perturbed_image.file_path + ) + + quality_evaluation = EvaluationResult( + reference_image_id=original_image.id, + target_image_id=perturbed_image.id, + evaluation_type='image_quality', + purification_applied=False, + fid_score=quality_metrics.get('fid'), + lpips_score=quality_metrics.get('lpips'), + ssim_score=quality_metrics.get('ssim'), + psnr_score=quality_metrics.get('psnr'), + heatmap_path=quality_metrics.get('heatmap_path') + ) + + db.session.add(quality_evaluation) + + # 模型生成对比评估 + generation_metrics = EvaluationEngine.evaluate_model_generation( + original_image.file_path, + perturbed_image.file_path, + batch.finetune_config.method_code + ) + + generation_evaluation = EvaluationResult( + reference_image_id=original_image.id, + target_image_id=perturbed_image.id, + evaluation_type='model_generation', + purification_applied=False, + fid_score=generation_metrics.get('fid'), + lpips_score=generation_metrics.get('lpips'), + ssim_score=generation_metrics.get('ssim'), + psnr_score=generation_metrics.get('psnr'), + heatmap_path=generation_metrics.get('heatmap_path') + ) + + db.session.add(generation_evaluation) + + db.session.commit() + + except Exception as e: + print(f"生成评估结果时出错: {str(e)}") + + @staticmethod + def get_processing_progress(batch_id): + """获取处理进度""" + try: + batch = Batch.query.get(batch_id) + if not batch: + return 0 + + if batch.status == 'pending': + return 0 + elif batch.status == 'completed': + return 100 + elif batch.status == 'failed': + return 0 + elif batch.status == 'processing': + # 简单的进度计算:根据已处理的图片数量 + total_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter( + ImageType.type_code == 'original' + ).count() + + processed_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter( + ImageType.type_code == 'perturbed' + ).count() + + if total_images == 0: + return 0 + + progress = int((processed_images / total_images) * 80) # 80%用于图像处理,20%用于评估 + return min(progress, 95) # 最多95%,剩余5%用于最终完成 + + return 0 + + except Exception as e: + print(f"获取处理进度时出错: {str(e)}") + return 0 \ No newline at end of file diff --git a/src/backend/app/utils/file_utils.py b/src/backend/app/utils/file_utils.py new file mode 100644 index 0000000..f23d489 --- /dev/null +++ b/src/backend/app/utils/file_utils.py @@ -0,0 +1,53 @@ +""" +文件处理工具类 +""" + +import os +from werkzeug.utils import secure_filename +from flask import current_app + +def allowed_file(filename): + """检查文件扩展名是否被允许""" + if not filename: + return False + + allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS', + {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'}) + + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in allowed_extensions + +def save_uploaded_file(file, upload_path): + """保存上传的文件""" + try: + if not file or not allowed_file(file.filename): + return None + + filename = secure_filename(file.filename) + + # 确保目录存在 + os.makedirs(os.path.dirname(upload_path), exist_ok=True) + + file.save(upload_path) + return upload_path + + except Exception as e: + print(f"保存文件失败: {str(e)}") + return None + +def get_file_size(file_path): + """获取文件大小""" + try: + return os.path.getsize(file_path) + except: + return 0 + +def delete_file(file_path): + """删除文件""" + try: + if os.path.exists(file_path): + os.remove(file_path) + return True + except: + pass + return False \ No newline at end of file diff --git a/src/backend/app/utils/jwt_utils.py b/src/backend/app/utils/jwt_utils.py new file mode 100644 index 0000000..8e0aee6 --- /dev/null +++ b/src/backend/app/utils/jwt_utils.py @@ -0,0 +1,16 @@ +""" +JWT工具函数 +""" +from functools import wraps +from flask_jwt_extended import get_jwt_identity + +def int_jwt_required(f): + """获取JWT身份并转换为整数的装饰器""" + @wraps(f) + def wrapped(*args, **kwargs): + jwt_identity = get_jwt_identity() + if jwt_identity is not None: + # 在函数调用前注入转换后的user_id + kwargs['current_user_id'] = int(jwt_identity) + return f(*args, **kwargs) + return wrapped \ No newline at end of file diff --git a/src/backend/config/.env b/src/backend/config/.env new file mode 100644 index 0000000..be6a0db --- /dev/null +++ b/src/backend/config/.env @@ -0,0 +1,16 @@ +# MuseGuard 环境变量配置文件 +# 注意:此文件包含敏感信息,不应提交到版本控制系统 + +# 数据库配置 +DB_USER=root +DB_PASSWORD=your_password_here +DB_HOST=localhost +DB_NAME=your_database_name_here + +# Flask配置 +SECRET_KEY=museguard-secret-key-2024 +JWT_SECRET_KEY=jwt-secret-string + +# 开发模式 +FLASK_ENV=development +FLASK_DEBUG=True \ No newline at end of file diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py new file mode 100644 index 0000000..ecaff13 --- /dev/null +++ b/src/backend/config/settings.py @@ -0,0 +1,109 @@ +""" +应用配置文件 +""" +import os +from datetime import timedelta +from dotenv import load_dotenv +from urllib.parse import quote_plus + +# 加载环境变量 - 从 config 目录读取 .env 文件 +config_dir = os.path.dirname(os.path.abspath(__file__)) +env_path = os.path.join(config_dir, '.env') +load_dotenv(env_path) + +class Config: + """基础配置类""" + + # 基础配置 + SECRET_KEY = os.environ.get('SECRET_KEY') or 'museguard-secret-key-2024' + + # 数据库配置 - 支持密码中的特殊字符 + DB_USER = os.environ.get('DB_USER') + DB_PASSWORD = os.environ.get('DB_PASSWORD') + DB_HOST = os.environ.get('DB_HOST') or 'localhost' + DB_NAME = os.environ.get('DB_NAME') or 'museguard_schema' + + # URL编码密码中的特殊字符 + from urllib.parse import quote_plus + _encoded_password = quote_plus(DB_PASSWORD) + + SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \ + f'mysql+pymysql://{DB_USER}:{_encoded_password}@{DB_HOST}/{DB_NAME}' + SQLALCHEMY_TRACK_MODIFICATIONS = False + SQLALCHEMY_ENGINE_OPTIONS = { + 'pool_pre_ping': True, + 'pool_recycle': 300, + } + + # JWT配置 + JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or 'jwt-secret-string' + JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24) + JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30) + + # 静态文件根目录 + STATIC_ROOT = 'static' + + # 文件上传配置 + UPLOAD_FOLDER = 'uploads' # 临时上传目录 + MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB + ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'} + + # 图像处理配置 + ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'originals') # 重命名后的原始图片 + PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片 + MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录 + MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'clean') # 原图的模型生成结果 + MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果 + HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图 + + # 预设演示图像配置 + DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录 + DEMO_ORIGINAL_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'original') # 演示原始图片 + DEMO_PERTURBED_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'perturbed') # 演示加噪图片 + DEMO_COMPARISONS_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'comparisons') # 演示对比图 + + # 邮件配置(用于注册验证) + MAIL_SERVER = os.environ.get('MAIL_SERVER') or 'smtp.gmail.com' + MAIL_PORT = int(os.environ.get('MAIL_PORT') or 587) + MAIL_USE_TLS = os.environ.get('MAIL_USE_TLS', 'true').lower() in ['true', 'on', '1'] + MAIL_USERNAME = os.environ.get('MAIL_USERNAME') + MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD') + + # 算法配置 + ALGORITHMS = { + 'simac': { + 'name': 'SimAC算法', + 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', + 'default_epsilon': 8.0 + }, + 'caat': { + 'name': 'CAAT算法', + 'description': 'Perturbing Attention Gives You More Bang for the Buck', + 'default_epsilon': 16.0 + }, + 'pid': { + 'name': 'PID算法', + 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', + 'default_epsilon': 4.0 + } + } + +class DevelopmentConfig(Config): + """开发环境配置""" + DEBUG = True + +class ProductionConfig(Config): + """生产环境配置""" + DEBUG = False + +class TestingConfig(Config): + """测试环境配置""" + TESTING = True + SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:' + +config = { + 'development': DevelopmentConfig, + 'production': ProductionConfig, + 'testing': TestingConfig, + 'default': DevelopmentConfig +} \ No newline at end of file diff --git a/src/backend/init_db.py b/src/backend/init_db.py new file mode 100644 index 0000000..009ae84 --- /dev/null +++ b/src/backend/init_db.py @@ -0,0 +1,81 @@ +""" +数据库初始化脚本 +""" + +from app import create_app, db +from app.models import * + +def init_database(): + """初始化数据库""" + app = create_app() + + with app.app_context(): + # 创建所有表 + db.create_all() + + # 初始化图片类型数据 + image_types = [ + {'type_code': 'original', 'type_name': '原始图片', 'description': '用户上传的原始图像文件'}, + {'type_code': 'perturbed', 'type_name': '加噪后图片', 'description': '经过扰动算法处理后的防护图像'}, + {'type_code': 'original_generate', 'type_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'}, + {'type_code': 'perturbed_generate', 'type_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'} + ] + + for img_type in image_types: + existing = ImageType.query.filter_by(type_code=img_type['type_code']).first() + if not existing: + new_type = ImageType(**img_type) + db.session.add(new_type) + + # 初始化加噪算法数据 + perturbation_configs = [ + {'method_code': 'aspl', 'method_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense', 'default_epsilon': 6.0}, + {'method_code': 'simac', 'method_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', 'default_epsilon': 8.0}, + {'method_code': 'caat', 'method_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck', 'default_epsilon': 16.0}, + {'method_code': 'pid', 'method_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', 'default_epsilon': 4.0} + ] + + for config in perturbation_configs: + existing = PerturbationConfig.query.filter_by(method_code=config['method_code']).first() + if not existing: + new_config = PerturbationConfig(**config) + db.session.add(new_config) + + # 初始化微调方式数据 + finetune_configs = [ + {'method_code': 'dreambooth', 'method_name': 'DreamBooth', 'description': 'DreamBooth个性化文本到图像生成'}, + {'method_code': 'lora', 'method_name': 'LoRA', 'description': '低秩适应(Low-Rank Adaptation)微调方法'}, + {'method_code': 'textual_inversion', 'method_name': 'Textual Inversion', 'description': '文本反转个性化方法'} + ] + + for config in finetune_configs: + existing = FinetuneConfig.query.filter_by(method_code=config['method_code']).first() + if not existing: + new_config = FinetuneConfig(**config) + db.session.add(new_config) + + # 创建默认管理员用户 + admin_users = [ + {'username': 'admin1', 'email': 'admin1@museguard.com', 'role': 'admin'}, + {'username': 'admin2', 'email': 'admin2@museguard.com', 'role': 'admin'}, + {'username': 'admin3', 'email': 'admin3@museguard.com', 'role': 'admin'} + ] + + for admin_data in admin_users: + existing = User.query.filter_by(username=admin_data['username']).first() + if not existing: + admin_user = User(**admin_data) + admin_user.set_password('admin123') # 默认密码 + db.session.add(admin_user) + + # 为管理员创建默认配置 + db.session.flush() # 确保user.id可用 + user_config = UserConfig(user_id=admin_user.id) + db.session.add(user_config) + + # 提交所有更改 + db.session.commit() + print("数据库初始化完成!") + +if __name__ == '__main__': + init_database() \ No newline at end of file diff --git a/src/backend/quick_start.bat b/src/backend/quick_start.bat new file mode 100644 index 0000000..c274f00 --- /dev/null +++ b/src/backend/quick_start.bat @@ -0,0 +1,15 @@ +@echo off +chcp 65001 > nul +echo ========================================== +echo MuseGuard 快速启动脚本 +echo ========================================== +echo. + +REM 切换到项目目录 +cd /d "d:\code\Software_Project\team_project\MuseGuard\src\backend" + +REM 激活虚拟环境并启动服务器 +echo 正在激活虚拟环境并启动服务器... +call venv_py311\Scripts\activate.bat && python run.py + +pause \ No newline at end of file diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt new file mode 100644 index 0000000..b0b6fd7 --- /dev/null +++ b/src/backend/requirements.txt @@ -0,0 +1,33 @@ +# Core Flask Framework +Flask==3.0.0 +Flask-SQLAlchemy==3.1.1 +Flask-Migrate==4.0.5 +Flask-JWT-Extended==4.6.0 +Flask-CORS==5.0.0 +Werkzeug==3.0.1 + +# Database +PyMySQL==1.1.1 + +# Image Processing +Pillow==10.4.0 +numpy==1.26.4 + +# Security & Utils +cryptography==42.0.8 +python-dotenv==1.0.1 + +# Additional Dependencies (auto-installed) +# alembic==1.17.0 +# blinker==1.9.0 +# cffi==2.0.0 +# click==8.3.0 +# greenlet==3.2.4 +# itsdangerous==2.2.0 +# Jinja2==3.1.6 +# Mako==1.3.10 +# MarkupSafe==3.0.3 +# pycparser==2.23 +# PyJWT==2.10.1 +# SQLAlchemy==2.0.44 +# typing_extensions==4.15.0 \ No newline at end of file diff --git a/src/backend/run.py b/src/backend/run.py new file mode 100644 index 0000000..0f24a01 --- /dev/null +++ b/src/backend/run.py @@ -0,0 +1,21 @@ +""" +Flask应用启动脚本 +""" + +import os +from app import create_app + +# 设置环境变量 +os.environ.setdefault('FLASK_ENV', 'development') + +# 创建应用实例 +app = create_app() + +if __name__ == '__main__': + # 开发模式启动 + app.run( + host='0.0.0.0', + port=5000, + debug=True, + threaded=True + ) \ No newline at end of file diff --git a/src/backend/static/test.html b/src/backend/static/test.html new file mode 100644 index 0000000..4fb5976 --- /dev/null +++ b/src/backend/static/test.html @@ -0,0 +1,1644 @@ + + + + + + MuseGuard API 全功能测试页面 + + + +
+

🧪 MuseGuard API 测试页面

+ + +
+

🌐 服务器连通性测试

+ + +
+ + +
+

🎨 Demo Controller - 演示模块

+
+
+

演示图片

+ + +
+ +
+

算法信息

+ +
+ + +
+
+ + +
+

🧑‍💻 Auth Controller - 认证模块

+
+
+

用户注册

+
+ + +
+
+ + +
+
+ + +
+ +
+ +
+

用户登录

+
+ + +
+
+ + +
+ + + +

修改密码

+
+ + +
+
+ + +
+
+ + +
+ + + + +
+ +
+
+ + +
+

🔄 Task Controller - 任务管理模块

+
+
+

1. 创建任务(第一步)

+
+ + +
+ +
+ +
+

我的批次列表

+
+

正在加载批次...

+
+ +
+ +
+

2. 文件上传(第二步)

+
+ + +
+
+ +
+ +

点击或拖放图片文件到此处

+
+
+ +
+ +
+

3. 配置任务(第三步)

+

配置已自动加载上次使用的设置,您可以根据需要调整

+ +
+ + +
+
+ + + 推荐范围: 1.0 - 16.0 +
+
+ + +
+
+ + 强防护模式可提高安全性,但会增加处理时间 +
+ + +
+ +
+

4. 任务管理(第四步)

+ + + + +
+ +
+
+ + + + +
+

🖼️ Image Controller - 图像处理模块

+
+
+

图像查看和下载

+
+ + +
+ + + +
+ +
+

图像评估和对比

+
+ + +
+
+ + +
+ + +
+ +
+

其他功能

+ + +
+ + +
+
+ + +
+

👨‍💼 Admin Controller - 管理员模块

+
+
+

用户管理

+
+
+ + +
+
+ + +
+
+ + +
+ + +
+ +
+ +
+

用户创建和编辑

+
+ + +
+
+ + +
+
+ + +
+ + + +
+ +
+

系统统计

+ +
+ +
+
+
+ + + + \ No newline at end of file