将lianghao_branch合并到develop #1

Merged
hnu202326010204 merged 3 commits from lianghao_branch into develop 2 months ago

9
.gitignore vendored

@ -0,0 +1,9 @@
__pycache__/
venv/
*.png
*.jpg
*.jpeg
.env

@ -0,0 +1,243 @@
# MuseGuard 后端框架
基于对抗性扰动的多风格图像生成防护系统 - 后端API服务
## 项目结构
```
backend/
├── app/ # 主应用目录
│ ├── algorithms/ # 算法实现
│ │ ├── perturbation_engine.py # 对抗性扰动引擎
│ │ └── evaluation_engine.py # 评估引擎
│ ├── controllers/ # 控制器(路由处理)
│ │ ├── auth_controller.py # 认证控制器
│ │ ├── user_controller.py # 用户配置控制器
│ │ ├── task_controller.py # 任务控制器
| | ├── demo_controller.py # 首页示例控制器
│ │ ├── image_controller.py # 图像控制器
│ │ └── admin_controller.py # 管理员控制器
│ ├── models/ # 数据模型
│ │ └── __init__.py # SQLAlchemy模型定义
│ ├── services/ # 业务逻辑服务
│ │ ├── auth_service.py # 认证服务
│ │ ├── task_service.py # 任务处理服务
│ │ └── image_service.py # 图像处理服务
│ └── utils/ # 工具类
│ └── file_utils.py # 文件处理工具
├── config/ # 配置文件
│ └── settings.py # 应用配置
├── uploads/ # 文件上传目录
├── static/ # 静态文件
│ ├── originals/ # 重命名后的原始图片
│ ├── perturbed/ # 加噪后的图片
│ ├── model_outputs/ # 模型生成的图片
│ │ ├── clean/ # 原图的模型生成结果
│ │ └── perturbed/ # 加噪图的模型生成结果
│ ├── heatmaps/ # 热力图
│ └── demo/ # 演示图片
│ ├── original/ # 演示原始图片
│ ├── perturbed/ # 演示加噪图片
│ └── comparisons/ # 演示对比图
├── app.py # Flask应用工厂
├── run.py # 启动脚本
├── init_db.py # 数据库初始化脚本
└── requirements.txt # Python依赖
```
## 功能特性
### 用户功能
- ✅ 用户注册(邮箱验证,同一邮箱只能注册一次)
- ✅ 用户登录/登出
- ✅ 密码修改
- ✅ 任务创建和管理
- ✅ 图片上传(单张/压缩包批量)
- ✅ 加噪处理4种算法SimAC、CAAT、PID、ASPL
- ✅ 扰动强度自定义
- ✅ 防净化版本选择
- ✅ 智能配置记忆:自动保存用户上次选择的配置
- ✅ 处理结果下载
- ✅ 图片质量对比查看FID、LPIPS、SSIM、PSNR、热力图
- ✅ 模型生成对比分析
- ✅ 预设演示图片浏览
### 管理员功能
- ✅ 用户管理(增删改查)
- ✅ 系统统计信息查看
### 算法实现
- ✅ ASPL算法虚拟实现原始版本 + 防净化版本)
- ✅ SimAC算法虚拟实现原始版本 + 防净化版本)
- ✅ CAAT算法虚拟实现原始版本 + 防净化版本)
- ✅ PID算法虚拟实现原始版本 + 防净化版本)
- ✅ 图像质量评估指标计算
- ✅ 模型生成效果对比
- ✅ 热力图生成
## 安装和运行
### 1. 环境准备
#### 使用虚拟环境(推荐)
**为什么需要虚拟环境?**
- ✅ **避免依赖冲突**:不同项目使用不同版本的包
- ✅ **环境隔离**不污染系统Python环境
- ✅ **版本一致性**:确保团队环境统一
- ✅ **易于管理**:可以随时删除重建
```bash
# 创建虚拟环境
python -m venv venv
# 激活虚拟环境
# Windows:
venv\\Scripts\\activate
# Linux/Mac:
source venv/bin/activate
# 更新pip推荐
python -m pip install --upgrade pip
# 安装依赖
pip install -r requirements.txt
```
### 2. 数据库配置
确保已安装MySQL数据库并创建数据库。
修改 `config/.env` 中的数据库连接配置:
### 3. 初始化数据库
```bash
# 运行数据库初始化脚本
python init_db.py
```
### 4. 启动应用
```bash
# 开发模式启动
python run.py
# 或者使用Flask命令
flask run
```
应用将在 `http://localhost:5000` 启动
### 5. 系统测试
访问 `http://localhost:5000/static/test.html` 进入功能测试页面:
## API接口文档
### 认证接口 (`/api/auth`)
- `POST /register` - 用户注册
- `POST /login` - 用户登录
- `POST /change-password` - 修改密码
- `GET /profile` - 获取用户信息
- `POST /logout` - 用户登出
### 任务管理 (`/api/task`)
- `POST /create` - 创建任务(使用默认配置)
- `POST /upload/<batch_id>` - 上传图片到指定任务
- `GET /<batch_id>/config` - 获取任务配置(显示用户上次选择)
- `PUT /<batch_id>/config` - 更新任务配置(自动保存为用户偏好)
- `GET /load-config` - 加载用户上次配置
- `POST /save-config` - 保存用户配置偏好
- `POST /start/<batch_id>` - 开始处理任务
- `GET /list` - 获取任务列表
- `GET /<batch_id>` - 获取任务详情
- `GET /<batch_id>/status` - 获取处理状态
### 图片管理 (`/api/image`)
- `GET /file/<image_id>` - 查看图片
- `GET /download/<image_id>` - 下载图片
- `GET /batch/<batch_id>/download` - 批量下载
- `GET /<image_id>/evaluations` - 获取评估结果
- `POST /compare` - 对比图片
- `GET /heatmap/<path>` - 获取热力图
- `DELETE /delete/<image_id>` - 删除图片
### 用户设置 (`/api/user`)
- `GET /config` - 获取用户配置(已弃用,配置集成到任务流程中)
- `PUT /config` - 更新用户配置(已弃用,通过任务配置自动保存)
- `GET /algorithms` - 获取可用算法(动态从数据库加载)
- `GET /stats` - 获取用户统计
### 管理员功能 (`/api/admin`)
- `GET /users` - 用户列表
- `GET /users/<user_id>` - 用户详情
- `POST /users` - 创建用户
- `PUT /users/<user_id>` - 更新用户
- `DELETE /users/<user_id>` - 删除用户
- `GET /stats` - 系统统计
### 演示功能 (`/api/demo`)
- `GET /images` - 获取演示图片列表
- `GET /image/original/<filename>` - 获取演示原始图片
- `GET /image/perturbed/<filename>` - 获取演示加噪图片
- `GET /image/comparison/<filename>` - 获取演示对比图片
- `GET /algorithms` - 获取算法演示信息
- `GET /stats` - 获取演示统计数据
## 默认账户
系统初始化后会创建3个管理员账户
- 用户名:`admin1`, `admin2`, `admin3`
- 默认密码:`admin123`
- 邮箱:`admin1@museguard.com` 等
## 技术栈
- **Web框架**: Flask 2.3.3
- **数据库ORM**: SQLAlchemy 3.0.5
- **数据库**: MySQL通过PyMySQL连接
- **认证**: JWT (Flask-JWT-Extended)
- **跨域**: Flask-CORS
- **图像处理**: Pillow + NumPy
- **数学计算**: NumPy
## 开发说明
### 虚拟实现说明
当前所有算法都是**虚拟实现**,用于框架搭建和测试:
1. **对抗性扰动算法**: 使用随机噪声模拟真实算法效果
2. **评估指标**: 基于像素差异的简化计算
3. **模型生成**: 通过图像变换模拟DreamBooth/LoRA效果
### 扩展指南
要集成真实算法:
1. 替换 `app/algorithms/perturbation_engine.py` 中的虚拟实现
2. 替换 `app/algorithms/evaluation_engine.py` 中的评估计算
3. 根据需要调整配置参数
### 目录权限
确保以下目录有写入权限:
- `uploads/` - 用户上传文件
- `static/originals/` - 重命名后的原始图片
- `static/perturbed/` - 加噪后的图片
- `static/model_outputs/` - 模型生成的图片
- `static/heatmaps/` - 热力图文件
- `static/demo/` - 演示图片(需要手动添加演示文件)
## 许可证
本项目仅用于学习和研究目的。

@ -0,0 +1,46 @@
"""
MuseGuard 后端主应用入口
基于对抗性扰动的多风格图像生成防护系统
"""
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager
from flask_cors import CORS
from config.settings import Config
# 初始化扩展
db = SQLAlchemy()
migrate = Migrate()
jwt = JWTManager()
def create_app(config_class=Config):
"""Flask应用工厂函数"""
app = Flask(__name__)
app.config.from_object(config_class)
# 初始化扩展
db.init_app(app)
migrate.init_app(app, db)
jwt.init_app(app)
CORS(app)
# 注册蓝图
from app.controllers.auth_controller import auth_bp
from app.controllers.user_controller import user_bp
from app.controllers.task_controller import task_bp
from app.controllers.image_controller import image_bp
from app.controllers.admin_controller import admin_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(user_bp, url_prefix='/api/user')
app.register_blueprint(task_bp, url_prefix='/api/task')
app.register_blueprint(image_bp, url_prefix='/api/image')
app.register_blueprint(admin_bp, url_prefix='/api/admin')
return app
if __name__ == '__main__':
app = create_app()
app.run(debug=True, host='0.0.0.0', port=5000)

@ -0,0 +1,83 @@
"""
MuseGuard 后端应用包
基于对抗性扰动的多风格图像生成防护系统
"""
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager
from flask_cors import CORS
import os
# 初始化扩展
db = SQLAlchemy()
migrate = Migrate()
jwt = JWTManager()
cors = CORS()
def create_app(config_name=None):
"""应用工厂函数"""
# 配置静态文件和模板文件路径
app = Flask(__name__,
static_folder='../static',
static_url_path='/static')
# 加载配置
if config_name is None:
config_name = os.environ.get('FLASK_ENV', 'development')
from config.settings import config
app.config.from_object(config[config_name])
# 初始化扩展
db.init_app(app)
migrate.init_app(app, db)
jwt.init_app(app)
cors.init_app(app)
# 注册蓝图
from app.controllers.auth_controller import auth_bp
from app.controllers.user_controller import user_bp
from app.controllers.task_controller import task_bp
from app.controllers.image_controller import image_bp
from app.controllers.admin_controller import admin_bp
from app.controllers.demo_controller import demo_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(user_bp, url_prefix='/api/user')
app.register_blueprint(task_bp, url_prefix='/api/task')
app.register_blueprint(image_bp, url_prefix='/api/image')
app.register_blueprint(admin_bp, url_prefix='/api/admin')
app.register_blueprint(demo_bp, url_prefix='/api/demo')
# 注册错误处理器
@app.errorhandler(404)
def not_found_error(error):
return {'error': 'Not found'}, 404
@app.errorhandler(500)
def internal_error(error):
db.session.rollback()
return {'error': 'Internal server error'}, 500
# 根路由
@app.route('/')
def index():
return {
'message': 'MuseGuard API Server',
'version': '1.0.0',
'status': 'running',
'endpoints': {
'health': '/health',
'api_docs': '/api',
'test_page': '/static/test.html'
}
}
# 健康检查端点
@app.route('/health')
def health_check():
return {'status': 'healthy', 'message': 'MuseGuard backend is running'}
return app

@ -0,0 +1,300 @@
"""
评估引擎
实现图像质量评估和模型生成对比的虚拟版本
"""
import os
import numpy as np
from PIL import Image, ImageDraw
import uuid
import random
from flask import current_app
class EvaluationEngine:
"""评估处理引擎"""
@staticmethod
def evaluate_image_quality(original_path, perturbed_path):
"""
评估图像质量指标
Args:
original_path: 原始图片路径
perturbed_path: 扰动后图片路径
Returns:
包含各种评估指标的字典
"""
try:
# 加载图片
with Image.open(original_path) as orig_img, Image.open(perturbed_path) as pert_img:
# 确保图片尺寸一致
if orig_img.size != pert_img.size:
pert_img = pert_img.resize(orig_img.size)
# 转换为RGB
if orig_img.mode != 'RGB':
orig_img = orig_img.convert('RGB')
if pert_img.mode != 'RGB':
pert_img = pert_img.convert('RGB')
# 计算虚拟评估指标
fid_score = EvaluationEngine._calculate_mock_fid(orig_img, pert_img)
lpips_score = EvaluationEngine._calculate_mock_lpips(orig_img, pert_img)
ssim_score = EvaluationEngine._calculate_mock_ssim(orig_img, pert_img)
psnr_score = EvaluationEngine._calculate_mock_psnr(orig_img, pert_img)
# 生成热力图
heatmap_path = EvaluationEngine._generate_difference_heatmap(
orig_img, pert_img, "quality"
)
return {
'fid': fid_score,
'lpips': lpips_score,
'ssim': ssim_score,
'psnr': psnr_score,
'heatmap_path': heatmap_path
}
except Exception as e:
print(f"图像质量评估失败: {str(e)}")
return {
'fid': None,
'lpips': None,
'ssim': None,
'psnr': None,
'heatmap_path': None
}
@staticmethod
def evaluate_model_generation(original_path, perturbed_path, finetune_method):
"""
评估模型生成对比
Args:
original_path: 原始图片路径
perturbed_path: 扰动后图片路径
finetune_method: 微调方法 (dreambooth, lora)
Returns:
包含生成对比评估指标的字典
"""
try:
# 模拟训练过程并生成对比图片
with Image.open(original_path) as orig_img, Image.open(perturbed_path) as pert_img:
# 模拟原始图片训练后的生成结果
orig_generated = EvaluationEngine._simulate_model_generation(orig_img, finetune_method, False)
# 模拟扰动图片训练后的生成结果
pert_generated = EvaluationEngine._simulate_model_generation(pert_img, finetune_method, True)
# 计算生成质量对比指标
fid_score = EvaluationEngine._calculate_generation_fid(orig_generated, pert_generated)
lpips_score = EvaluationEngine._calculate_generation_lpips(orig_generated, pert_generated)
ssim_score = EvaluationEngine._calculate_generation_ssim(orig_generated, pert_generated)
psnr_score = EvaluationEngine._calculate_generation_psnr(orig_generated, pert_generated)
# 生成对比热力图
heatmap_path = EvaluationEngine._generate_difference_heatmap(
orig_generated, pert_generated, "generation"
)
return {
'fid': fid_score,
'lpips': lpips_score,
'ssim': ssim_score,
'psnr': psnr_score,
'heatmap_path': heatmap_path
}
except Exception as e:
print(f"模型生成评估失败: {str(e)}")
return {
'fid': None,
'lpips': None,
'ssim': None,
'psnr': None,
'heatmap_path': None
}
@staticmethod
def _calculate_mock_fid(img1, img2):
"""模拟FID计算"""
# 简单的像素差异统计作为FID近似
arr1 = np.array(img1, dtype=np.float32)
arr2 = np.array(img2, dtype=np.float32)
mse = np.mean((arr1 - arr2) ** 2)
# FID通常在0-300之间扰动图片应该有较小的差异
mock_fid = min(mse / 10.0 + random.uniform(0.5, 2.0), 50.0)
return round(mock_fid, 4)
@staticmethod
def _calculate_mock_lpips(img1, img2):
"""模拟LPIPS计算"""
arr1 = np.array(img1, dtype=np.float32) / 255.0
arr2 = np.array(img2, dtype=np.float32) / 255.0
# 简单的感知差异模拟
diff = np.abs(arr1 - arr2)
# 模拟感知权重(边缘区域权重更高)
gray1 = np.mean(arr1, axis=2)
edges = np.abs(np.gradient(gray1)[0]) + np.abs(np.gradient(gray1)[1])
weighted_diff = np.mean(diff * (1 + edges.reshape(edges.shape + (1,))))
mock_lpips = min(weighted_diff + random.uniform(0.001, 0.01), 1.0)
return round(mock_lpips, 4)
@staticmethod
def _calculate_mock_ssim(img1, img2):
"""模拟SSIM计算"""
arr1 = np.array(img1, dtype=np.float32)
arr2 = np.array(img2, dtype=np.float32)
# 简单的结构相似性模拟
mu1 = np.mean(arr1)
mu2 = np.mean(arr2)
sigma1_sq = np.var(arr1)
sigma2_sq = np.var(arr2)
sigma12 = np.mean((arr1 - mu1) * (arr2 - mu2))
c1 = (0.01 * 255) ** 2
c2 = (0.03 * 255) ** 2
ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / \
((mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2))
# SSIM在-1到1之间但通常在0.8-1.0之间为良好
mock_ssim = max(0.85 + random.uniform(-0.05, 0.1), 0.0)
return round(mock_ssim, 4)
@staticmethod
def _calculate_mock_psnr(img1, img2):
"""模拟PSNR计算"""
arr1 = np.array(img1, dtype=np.float32)
arr2 = np.array(img2, dtype=np.float32)
mse = np.mean((arr1 - arr2) ** 2)
if mse == 0:
return 100.0
max_pixel = 255.0
psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
# 对于轻微扰动PSNR应该比较高30-50dB
mock_psnr = max(min(psnr + random.uniform(-2, 2), 50.0), 20.0)
return round(mock_psnr, 4)
@staticmethod
def _simulate_model_generation(img, finetune_method, is_perturbed):
"""模拟模型生成过程"""
# 对输入图像进行变换以模拟生成结果
arr = np.array(img, dtype=np.float32)
if finetune_method == "dreambooth":
# DreamBooth风格变换
if is_perturbed:
# 扰动图片训练的模型生成质量较差
noise_level = 15.0
style_change = 0.3
else:
# 原始图片训练的模型生成质量较好
noise_level = 5.0
style_change = 0.1
elif finetune_method == "lora":
# LoRA风格变换
if is_perturbed:
noise_level = 12.0
style_change = 0.25
else:
noise_level = 3.0
style_change = 0.08
else:
noise_level = 8.0
style_change = 0.15
# 添加噪声模拟生成质量
noise = np.random.normal(0, noise_level, arr.shape)
generated_arr = arr + noise
# 模拟风格变化
if style_change > 0:
# 简单的颜色偏移
generated_arr[:,:,0] *= (1 + style_change * random.uniform(-0.5, 0.5))
generated_arr[:,:,1] *= (1 + style_change * random.uniform(-0.5, 0.5))
generated_arr[:,:,2] *= (1 + style_change * random.uniform(-0.5, 0.5))
generated_arr = np.clip(generated_arr, 0, 255).astype(np.uint8)
return Image.fromarray(generated_arr)
@staticmethod
def _calculate_generation_fid(img1, img2):
"""计算生成图片的FID质量越差FID越高越好"""
fid = EvaluationEngine._calculate_mock_fid(img1, img2)
# 对于防护效果FID越高越好
return max(fid, 10.0)
@staticmethod
def _calculate_generation_lpips(img1, img2):
"""计算生成图片的LPIPS差异越大越好"""
lpips = EvaluationEngine._calculate_mock_lpips(img1, img2)
return max(lpips, 0.1)
@staticmethod
def _calculate_generation_ssim(img1, img2):
"""计算生成图片的SSIM相似度越低越好"""
ssim = EvaluationEngine._calculate_mock_ssim(img1, img2)
# 对于防护效果SSIM越低越好
return min(ssim, 0.9)
@staticmethod
def _calculate_generation_psnr(img1, img2):
"""计算生成图片的PSNR质量差异越大越好"""
psnr = EvaluationEngine._calculate_mock_psnr(img1, img2)
return min(psnr, 35.0)
@staticmethod
def _generate_difference_heatmap(img1, img2, eval_type):
"""生成差异热力图"""
try:
arr1 = np.array(img1, dtype=np.float32)
arr2 = np.array(img2, dtype=np.float32)
# 计算像素差异
diff = np.abs(arr1 - arr2)
diff_gray = np.mean(diff, axis=2)
# 归一化到0-255
diff_normalized = (diff_gray / diff_gray.max() * 255).astype(np.uint8)
# 创建热力图
heatmap = Image.fromarray(diff_normalized, mode='L')
heatmap = heatmap.convert('RGB')
# 应用颜色映射(蓝色-绿色-红色)
heatmap_array = np.array(heatmap)
colored_heatmap = np.zeros_like(arr1, dtype=np.uint8)
# 简单的颜色映射
intensity = heatmap_array[:,:,0] / 255.0
colored_heatmap[:,:,0] = (intensity * 255).astype(np.uint8) # 红色通道
colored_heatmap[:,:,1] = ((1 - intensity) * intensity * 4 * 255).astype(np.uint8) # 绿色通道
colored_heatmap[:,:,2] = ((1 - intensity) * 255).astype(np.uint8) # 蓝色通道
# 保存热力图
heatmap_dir = os.path.join('static', 'heatmaps')
os.makedirs(heatmap_dir, exist_ok=True)
heatmap_filename = f"{eval_type}_heatmap_{uuid.uuid4().hex[:8]}.png"
heatmap_path = os.path.join(heatmap_dir, heatmap_filename)
Image.fromarray(colored_heatmap).save(heatmap_path)
return heatmap_path
except Exception as e:
print(f"生成热力图失败: {str(e)}")
return None

@ -0,0 +1,176 @@
"""
对抗性扰动算法引擎
实现各种加噪算法的虚拟版本
"""
import os
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import uuid
from flask import current_app
class PerturbationEngine:
"""对抗性扰动处理引擎"""
@staticmethod
def apply_perturbation(image_path, algorithm, epsilon, use_strong_protection=False, output_path=None):
"""
应用对抗性扰动
Args:
image_path: 原始图片路径
algorithm: 算法名称 (simac, caat, pid)
epsilon: 扰动强度
use_strong_protection: 是否使用防净化版本
Returns:
处理后图片的路径
"""
try:
if not os.path.exists(image_path):
raise FileNotFoundError(f"图片文件不存在: {image_path}")
# 加载图片
with Image.open(image_path) as img:
# 转换为RGB模式
if img.mode != 'RGB':
img = img.convert('RGB')
# 根据算法选择处理方法
if algorithm == 'simac':
perturbed_img = PerturbationEngine._apply_simac(img, epsilon, use_strong_protection)
elif algorithm == 'caat':
perturbed_img = PerturbationEngine._apply_caat(img, epsilon, use_strong_protection)
elif algorithm == 'pid':
perturbed_img = PerturbationEngine._apply_pid(img, epsilon, use_strong_protection)
else:
raise ValueError(f"不支持的算法: {algorithm}")
# 使用输入的output_path参数
if output_path is None:
# 如果没有提供输出路径,使用默认路径
from flask import current_app
project_root = os.path.dirname(current_app.root_path)
perturbed_dir = os.path.join(project_root, current_app.config['PERTURBED_IMAGES_FOLDER'])
os.makedirs(perturbed_dir, exist_ok=True)
file_extension = os.path.splitext(image_path)[1]
output_filename = f"perturbed_{uuid.uuid4().hex[:8]}{file_extension}"
output_path = os.path.join(perturbed_dir, output_filename)
# 保存处理后的图片
perturbed_img.save(output_path, quality=95)
return output_path
except Exception as e:
print(f"应用扰动时出错: {str(e)}")
return None
@staticmethod
def _apply_simac(img, epsilon, use_strong_protection):
"""
SimAC算法的虚拟实现
Simple Anti-Customization Method for Protecting Face Privacy
"""
# 将PIL图像转换为numpy数组
img_array = np.array(img, dtype=np.float32)
# 生成随机噪声(模拟对抗性扰动)
noise_scale = epsilon / 255.0
if use_strong_protection:
# 防净化版本:添加更复杂的扰动模式
noise = np.random.normal(0, noise_scale * 0.8, img_array.shape)
# 添加结构化噪声
h, w = img_array.shape[:2]
for i in range(0, h, 8):
for j in range(0, w, 8):
block_noise = np.random.normal(0, noise_scale * 0.4, (min(8, h-i), min(8, w-j), 3))
noise[i:i+8, j:j+8] += block_noise
else:
# 标准版本:简单高斯噪声
noise = np.random.normal(0, noise_scale, img_array.shape)
# 应用噪声
perturbed_array = img_array + noise * 255.0
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
# 转换回PIL图像
result_img = Image.fromarray(perturbed_array)
# 轻微的图像增强以模拟算法特性
enhancer = ImageEnhance.Contrast(result_img)
result_img = enhancer.enhance(1.02)
return result_img
@staticmethod
def _apply_caat(img, epsilon, use_strong_protection):
"""
CAAT算法的虚拟实现
Perturbing Attention Gives You More Bang for the Buck
"""
img_array = np.array(img, dtype=np.float32)
noise_scale = epsilon / 255.0
if use_strong_protection:
# 防净化版本:注意力区域重点扰动
# 模拟注意力图(简单的边缘检测)
gray_img = img.convert('L')
edge_img = gray_img.filter(ImageFilter.FIND_EDGES)
attention_map = np.array(edge_img, dtype=np.float32) / 255.0
# 在注意力区域添加更强的噪声
noise = np.random.normal(0, noise_scale * 0.6, img_array.shape)
for c in range(3):
noise[:,:,c] += attention_map * np.random.normal(0, noise_scale * 0.8, attention_map.shape)
else:
# 标准版本:均匀分布噪声
noise = np.random.uniform(-noise_scale, noise_scale, img_array.shape)
perturbed_array = img_array + noise * 255.0
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
result_img = Image.fromarray(perturbed_array)
# 轻微模糊以模拟注意力扰动效果
result_img = result_img.filter(ImageFilter.BoxBlur(0.5))
return result_img
@staticmethod
def _apply_pid(img, epsilon, use_strong_protection):
"""
PID算法的虚拟实现
Prompt-Independent Data Protection Against Latent Diffusion Models
"""
img_array = np.array(img, dtype=np.float32)
noise_scale = epsilon / 255.0
if use_strong_protection:
# 防净化版本:频域扰动
# 简单的频域变换模拟
noise = np.random.laplace(0, noise_scale * 0.7, img_array.shape)
# 添加周期性扰动
h, w = img_array.shape[:2]
for i in range(h):
for j in range(w):
periodic_noise = noise_scale * 0.3 * np.sin(i * 0.1) * np.cos(j * 0.1)
noise[i, j] += periodic_noise
else:
# 标准版本:拉普拉斯噪声
noise = np.random.laplace(0, noise_scale * 0.5, img_array.shape)
perturbed_array = img_array + noise * 255.0
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
result_img = Image.fromarray(perturbed_array)
# 色彩微调以模拟潜在空间扰动
enhancer = ImageEnhance.Color(result_img)
result_img = enhancer.enhance(0.98)
return result_img

@ -0,0 +1,239 @@
"""
管理员控制器
处理管理员功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from app import db
from app.models import User, Batch, Image
admin_bp = Blueprint('admin', __name__)
def admin_required(f):
"""管理员权限装饰器"""
from functools import wraps
@wraps(f)
def decorated_function(*args, **kwargs):
current_user_id = get_jwt_identity()
user = User.query.get(current_user_id)
if not user or user.role != 'admin':
return jsonify({'error': '需要管理员权限'}), 403
return f(*args, **kwargs)
return decorated_function
@admin_bp.route('/users', methods=['GET'])
@jwt_required()
@admin_required
def list_users():
"""获取用户列表"""
try:
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 20, type=int)
users = User.query.paginate(page=page, per_page=per_page, error_out=False)
return jsonify({
'users': [user.to_dict() for user in users.items],
'total': users.total,
'pages': users.pages,
'current_page': page
}), 200
except Exception as e:
return jsonify({'error': f'获取用户列表失败: {str(e)}'}), 500
@admin_bp.route('/users/<int:user_id>', methods=['GET'])
@jwt_required()
@admin_required
def get_user_detail(user_id):
"""获取用户详情"""
try:
user = User.query.get(user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
# 获取用户统计信息
total_tasks = Batch.query.filter_by(user_id=user_id).count()
total_images = Image.query.filter_by(user_id=user_id).count()
user_dict = user.to_dict()
user_dict['stats'] = {
'total_tasks': total_tasks,
'total_images': total_images
}
return jsonify({'user': user_dict}), 200
except Exception as e:
return jsonify({'error': f'获取用户详情失败: {str(e)}'}), 500
@admin_bp.route('/users', methods=['POST'])
@jwt_required()
@admin_required
def create_user():
"""创建用户"""
try:
data = request.get_json()
username = data.get('username')
password = data.get('password')
email = data.get('email')
role = data.get('role', 'user')
max_concurrent_tasks = data.get('max_concurrent_tasks', 0)
if not username or not password:
return jsonify({'error': '用户名和密码不能为空'}), 400
# 检查用户名是否已存在
if User.query.filter_by(username=username).first():
return jsonify({'error': '用户名已存在'}), 400
# 检查邮箱是否已存在
if email and User.query.filter_by(email=email).first():
return jsonify({'error': '邮箱已被使用'}), 400
# 创建用户
user = User(
username=username,
email=email,
role=role,
max_concurrent_tasks=max_concurrent_tasks
)
user.set_password(password)
db.session.add(user)
db.session.commit()
return jsonify({
'message': '用户创建成功',
'user': user.to_dict()
}), 201
except Exception as e:
db.session.rollback()
return jsonify({'error': f'创建用户失败: {str(e)}'}), 500
@admin_bp.route('/users/<int:user_id>', methods=['PUT'])
@jwt_required()
@admin_required
def update_user(user_id):
"""更新用户信息"""
try:
user = User.query.get(user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
data = request.get_json()
# 更新字段
if 'username' in data:
new_username = data['username']
if new_username != user.username:
if User.query.filter_by(username=new_username).first():
return jsonify({'error': '用户名已存在'}), 400
user.username = new_username
if 'email' in data:
new_email = data['email']
if new_email != user.email:
if User.query.filter_by(email=new_email).first():
return jsonify({'error': '邮箱已被使用'}), 400
user.email = new_email
if 'role' in data:
user.role = data['role']
if 'max_concurrent_tasks' in data:
user.max_concurrent_tasks = data['max_concurrent_tasks']
if 'is_active' in data:
user.is_active = bool(data['is_active'])
if 'password' in data and data['password']:
user.set_password(data['password'])
db.session.commit()
return jsonify({
'message': '用户信息更新成功',
'user': user.to_dict()
}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'更新用户失败: {str(e)}'}), 500
@admin_bp.route('/users/<int:user_id>', methods=['DELETE'])
@jwt_required()
@admin_required
def delete_user(user_id):
"""删除用户"""
try:
current_user_id = get_jwt_identity()
# 不能删除自己
if user_id == current_user_id:
return jsonify({'error': '不能删除自己的账户'}), 400
user = User.query.get(user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
# 删除用户(级联删除相关数据)
db.session.delete(user)
db.session.commit()
return jsonify({'message': '用户删除成功'}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'删除用户失败: {str(e)}'}), 500
@admin_bp.route('/stats', methods=['GET'])
@jwt_required()
@admin_required
def get_system_stats():
"""获取系统统计信息"""
try:
from app.models import EvaluationResult
total_users = User.query.count()
active_users = User.query.filter_by(is_active=True).count()
admin_users = User.query.filter_by(role='admin').count()
total_tasks = Batch.query.count()
completed_tasks = Batch.query.filter_by(status='completed').count()
processing_tasks = Batch.query.filter_by(status='processing').count()
failed_tasks = Batch.query.filter_by(status='failed').count()
total_images = Image.query.count()
total_evaluations = EvaluationResult.query.count()
return jsonify({
'stats': {
'users': {
'total': total_users,
'active': active_users,
'admin': admin_users
},
'tasks': {
'total': total_tasks,
'completed': completed_tasks,
'processing': processing_tasks,
'failed': failed_tasks
},
'images': {
'total': total_images
},
'evaluations': {
'total': total_evaluations
}
}
}), 200
except Exception as e:
return jsonify({'error': f'获取系统统计失败: {str(e)}'}), 500

@ -0,0 +1,156 @@
"""
用户认证控制器
处理注册登录密码修改等功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity
from app import db
from app.models import User, UserConfig
from app.services.auth_service import AuthService
from functools import wraps
import re
def int_jwt_required(f):
"""获取JWT身份并转换为整数的装饰器"""
@wraps(f)
def wrapped(*args, **kwargs):
try:
current_user_id = int(get_jwt_identity())
return f(*args, current_user_id=current_user_id, **kwargs)
except (TypeError, ValueError):
return jsonify({'error': '无效的用户身份标识'}), 401
return jwt_required()(wrapped)
auth_bp = Blueprint('auth', __name__)
@auth_bp.route('/register', methods=['POST'])
def register():
"""用户注册"""
try:
data = request.get_json()
username = data.get('username')
password = data.get('password')
email = data.get('email')
# 验证输入
if not username or not password or not email:
return jsonify({'error': '用户名、密码和邮箱不能为空'}), 400
# 验证邮箱格式
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(email_pattern, email):
return jsonify({'error': '邮箱格式不正确'}), 400
# 检查用户名是否已存在
if User.query.filter_by(username=username).first():
return jsonify({'error': '用户名已存在'}), 400
# 检查邮箱是否已注册
if User.query.filter_by(email=email).first():
return jsonify({'error': '该邮箱已被注册,同一邮箱只能注册一次'}), 400
# 创建用户
user = User(username=username, email=email)
user.set_password(password)
db.session.add(user)
db.session.commit()
# 创建用户默认配置
user_config = UserConfig(user_id=user.id)
db.session.add(user_config)
db.session.commit()
return jsonify({
'message': '注册成功',
'user': user.to_dict()
}), 201
except Exception as e:
db.session.rollback()
return jsonify({'error': f'注册失败: {str(e)}'}), 500
@auth_bp.route('/login', methods=['POST'])
def login():
"""用户登录"""
try:
data = request.get_json()
username = data.get('username')
password = data.get('password')
if not username or not password:
return jsonify({'error': '用户名和密码不能为空'}), 400
# 查找用户
user = User.query.filter_by(username=username).first()
if not user or not user.check_password(password):
return jsonify({'error': '用户名或密码错误'}), 401
if not user.is_active:
return jsonify({'error': '账户已被禁用'}), 401
# 创建访问令牌 - 确保用户ID为字符串类型
access_token = create_access_token(identity=str(user.id))
return jsonify({
'message': '登录成功',
'access_token': access_token,
'user': user.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'登录失败: {str(e)}'}), 500
@auth_bp.route('/change-password', methods=['POST'])
@int_jwt_required
def change_password(current_user_id):
"""修改密码"""
try:
user = User.query.get(current_user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
data = request.get_json()
old_password = data.get('old_password')
new_password = data.get('new_password')
if not old_password or not new_password:
return jsonify({'error': '旧密码和新密码不能为空'}), 400
# 验证旧密码
if not user.check_password(old_password):
return jsonify({'error': '旧密码错误'}), 401
# 设置新密码
user.set_password(new_password)
db.session.commit()
return jsonify({'message': '密码修改成功'}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'密码修改失败: {str(e)}'}), 500
@auth_bp.route('/profile', methods=['GET'])
@int_jwt_required
def get_profile(current_user_id):
"""获取用户信息"""
try:
user = User.query.get(current_user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
return jsonify({'user': user.to_dict()}), 200
except Exception as e:
return jsonify({'error': f'获取用户信息失败: {str(e)}'}), 500
@auth_bp.route('/logout', methods=['POST'])
@jwt_required()
def logout():
"""用户登出客户端删除token即可"""
return jsonify({'message': '登出成功'}), 200

@ -0,0 +1,177 @@
"""
演示图片控制器
处理预设图像对比图的展示功能
"""
from flask import Blueprint, send_file, jsonify, current_app
from flask_jwt_extended import jwt_required
from app.models import PerturbationConfig, FinetuneConfig
import os
import glob
demo_bp = Blueprint('demo', __name__)
@demo_bp.route('/images', methods=['GET'])
def list_demo_images():
"""获取所有演示图片列表"""
try:
demo_images = []
# 获取演示原始图片 - 修正路径构建
# 获取项目根目录backend目录
project_root = os.path.dirname(current_app.root_path)
original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'])
if os.path.exists(original_folder):
original_files = glob.glob(os.path.join(original_folder, '*'))
for file_path in original_files:
if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
filename = os.path.basename(file_path)
name_without_ext = os.path.splitext(filename)[0]
# 查找对应的加噪图片
perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'])
perturbed_files = glob.glob(os.path.join(perturbed_folder, f"{name_without_ext}*"))
# 查找对应的对比图
comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'])
comparison_files = glob.glob(os.path.join(comparison_folder, f"{name_without_ext}*"))
demo_image = {
'id': name_without_ext,
'name': name_without_ext,
'original': f"/api/demo/image/original/{filename}",
'perturbed': [f"/api/demo/image/perturbed/{os.path.basename(f)}" for f in perturbed_files],
'comparisons': [f"/api/demo/image/comparison/{os.path.basename(f)}" for f in comparison_files]
}
demo_images.append(demo_image)
return jsonify({
'demo_images': demo_images,
'total': len(demo_images)
}), 200
except Exception as e:
return jsonify({'error': f'获取演示图片列表失败: {str(e)}'}), 500
@demo_bp.route('/image/original/<filename>', methods=['GET'])
def get_demo_original_image(filename):
"""获取演示原始图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取原始图片失败: {str(e)}'}), 500
@demo_bp.route('/image/perturbed/<filename>', methods=['GET'])
def get_demo_perturbed_image(filename):
"""获取演示加噪图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取加噪图片失败: {str(e)}'}), 500
@demo_bp.route('/image/comparison/<filename>', methods=['GET'])
def get_demo_comparison_image(filename):
"""获取演示对比图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取对比图片失败: {str(e)}'}), 500
@demo_bp.route('/algorithms', methods=['GET'])
def get_demo_algorithms():
"""获取演示算法信息"""
try:
# 从数据库获取扰动算法
perturbation_algorithms = []
perturbation_configs = PerturbationConfig.query.all()
for config in perturbation_configs:
perturbation_algorithms.append({
'id': config.id,
'code': config.method_code,
'name': config.method_name,
'type': 'perturbation',
'description': config.description,
'default_epsilon': float(config.default_epsilon) if config.default_epsilon else None
})
# 从数据库获取微调算法
finetune_algorithms = []
finetune_configs = FinetuneConfig.query.all()
for config in finetune_configs:
finetune_algorithms.append({
'id': config.id,
'code': config.method_code,
'name': config.method_name,
'type': 'finetune',
'description': config.description
})
return jsonify({
'perturbation_algorithms': perturbation_algorithms,
'finetune_algorithms': finetune_algorithms,
'evaluation_metrics': [
{'name': 'FID', 'description': 'Fréchet Inception Distance - 衡量图像质量的指标'},
{'name': 'LPIPS', 'description': 'Learned Perceptual Image Patch Similarity - 感知相似度'},
{'name': 'SSIM', 'description': 'Structural Similarity Index - 结构相似性指标'},
{'name': 'PSNR', 'description': 'Peak Signal-to-Noise Ratio - 峰值信噪比'}
]
}), 200
except Exception as e:
return jsonify({'error': f'获取算法信息失败: {str(e)}'}), 500
@demo_bp.route('/stats', methods=['GET'])
def get_demo_stats():
"""获取演示统计信息"""
try:
# 统计演示图片数量
project_root = os.path.dirname(current_app.root_path)
original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'])
perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'])
comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'])
original_count = len(glob.glob(os.path.join(original_folder, '*'))) if os.path.exists(original_folder) else 0
perturbed_count = len(glob.glob(os.path.join(perturbed_folder, '*'))) if os.path.exists(perturbed_folder) else 0
comparison_count = len(glob.glob(os.path.join(comparison_folder, '*'))) if os.path.exists(comparison_folder) else 0
# 统计数据库中的算法数量
perturbation_count = PerturbationConfig.query.count()
finetune_count = FinetuneConfig.query.count()
total_algorithms = perturbation_count + finetune_count
return jsonify({
'demo_stats': {
'original_images': original_count,
'perturbed_images': perturbed_count,
'comparison_images': comparison_count,
'supported_algorithms': total_algorithms,
'perturbation_algorithms': perturbation_count,
'finetune_algorithms': finetune_count,
'evaluation_metrics': 4
}
}), 200
except Exception as e:
return jsonify({'error': f'获取统计信息失败: {str(e)}'}), 500

@ -0,0 +1,203 @@
"""
图像管理控制器
处理图像下载查看等功能
"""
from flask import Blueprint, send_file, jsonify, request, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity
from app.models import Image, EvaluationResult
from app.services.image_service import ImageService
import os
image_bp = Blueprint('image', __name__)
@image_bp.route('/file/<int:image_id>', methods=['GET'])
@jwt_required()
def get_image_file(image_id):
"""获取图片文件"""
try:
current_user_id = get_jwt_identity()
# 查找图片记录
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
# 检查文件是否存在
if not os.path.exists(image.file_path):
return jsonify({'error': '图片文件不存在'}), 404
return send_file(image.file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取图片失败: {str(e)}'}), 500
@image_bp.route('/download/<int:image_id>', methods=['GET'])
@jwt_required()
def download_image(image_id):
"""下载图片文件"""
try:
current_user_id = get_jwt_identity()
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
if not os.path.exists(image.file_path):
return jsonify({'error': '图片文件不存在'}), 404
return send_file(
image.file_path,
as_attachment=True,
download_name=image.original_filename or f"image_{image_id}.jpg"
)
except Exception as e:
return jsonify({'error': f'下载图片失败: {str(e)}'}), 500
@image_bp.route('/batch/<int:batch_id>/download', methods=['GET'])
@jwt_required()
def download_batch_images(batch_id):
"""批量下载任务中的加噪后图片"""
try:
current_user_id = get_jwt_identity()
# 获取任务中的加噪图片
perturbed_images = Image.query.join(Image.image_type).filter(
Image.batch_id == batch_id,
Image.user_id == current_user_id,
Image.image_type.has(type_code='perturbed')
).all()
if not perturbed_images:
return jsonify({'error': '没有找到加噪后的图片'}), 404
# 创建ZIP文件
import zipfile
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file:
with zipfile.ZipFile(tmp_file.name, 'w') as zip_file:
for image in perturbed_images:
if os.path.exists(image.file_path):
arcname = image.original_filename or f"perturbed_{image.id}.jpg"
zip_file.write(image.file_path, arcname)
return send_file(
tmp_file.name,
as_attachment=True,
download_name=f"batch_{batch_id}_perturbed_images.zip",
mimetype='application/zip'
)
except Exception as e:
return jsonify({'error': f'批量下载失败: {str(e)}'}), 500
@image_bp.route('/<int:image_id>/evaluations', methods=['GET'])
@jwt_required()
def get_image_evaluations(image_id):
"""获取图片的评估结果"""
try:
current_user_id = get_jwt_identity()
# 验证图片权限
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
# 获取以该图片为参考或目标的评估结果
evaluations = EvaluationResult.query.filter(
(EvaluationResult.reference_image_id == image_id) |
(EvaluationResult.target_image_id == image_id)
).all()
return jsonify({
'image_id': image_id,
'evaluations': [eval_result.to_dict() for eval_result in evaluations]
}), 200
except Exception as e:
return jsonify({'error': f'获取评估结果失败: {str(e)}'}), 500
@image_bp.route('/compare', methods=['POST'])
@jwt_required()
def compare_images():
"""对比两张图片"""
try:
current_user_id = get_jwt_identity()
data = request.get_json()
image1_id = data.get('image1_id')
image2_id = data.get('image2_id')
if not image1_id or not image2_id:
return jsonify({'error': '请提供两张图片的ID'}), 400
# 验证图片权限
image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first()
image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first()
if not image1 or not image2:
return jsonify({'error': '图片不存在或无权限'}), 404
# 查找现有的评估结果
evaluation = EvaluationResult.query.filter_by(
reference_image_id=image1_id,
target_image_id=image2_id
).first()
if not evaluation:
# 如果没有评估结果,返回基本对比信息
return jsonify({
'image1': image1.to_dict(),
'image2': image2.to_dict(),
'evaluation': None,
'message': '暂无评估数据,请等待任务处理完成'
}), 200
return jsonify({
'image1': image1.to_dict(),
'image2': image2.to_dict(),
'evaluation': evaluation.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'图片对比失败: {str(e)}'}), 500
@image_bp.route('/heatmap/<path:heatmap_path>', methods=['GET'])
@jwt_required()
def get_heatmap(heatmap_path):
"""获取热力图文件"""
try:
# 安全检查,防止路径遍历攻击
if '..' in heatmap_path or heatmap_path.startswith('/'):
return jsonify({'error': '无效的文件路径'}), 400
# 修正路径构建 - 获取项目根目录backend目录
project_root = os.path.dirname(current_app.root_path)
full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path))
if not os.path.exists(full_path):
return jsonify({'error': '热力图文件不存在'}), 404
return send_file(full_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取热力图失败: {str(e)}'}), 500
@image_bp.route('/delete/<int:image_id>', methods=['DELETE'])
@jwt_required()
def delete_image(image_id):
"""删除图片"""
try:
current_user_id = get_jwt_identity()
result = ImageService.delete_image(image_id, current_user_id)
if result['success']:
return jsonify({'message': '图片删除成功'}), 200
else:
return jsonify({'error': result['error']}), 400
except Exception as e:
return jsonify({'error': f'删除图片失败: {str(e)}'}), 500

@ -0,0 +1,400 @@
"""
任务管理控制器
处理创建任务上传图片等功能
"""
from flask import Blueprint, request, jsonify, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity
from werkzeug.utils import secure_filename
from app import db
from app.models import User, Batch, Image, ImageType, UserConfig
from app.services.task_service import TaskService
from app.services.image_service import ImageService
from app.utils.file_utils import allowed_file, save_uploaded_file
import os
import zipfile
import uuid
task_bp = Blueprint('task', __name__)
@task_bp.route('/create', methods=['POST'])
@jwt_required()
def create_task():
"""创建新任务(仅创建任务,使用默认配置)"""
try:
current_user_id = get_jwt_identity()
user = User.query.get(current_user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
data = request.get_json()
batch_name = data.get('batch_name', f'Task-{uuid.uuid4().hex[:8]}')
# 使用默认配置创建任务
batch = Batch(
user_id=current_user_id,
batch_name=batch_name,
perturbation_config_id=1, # 默认配置
preferred_epsilon=8.0, # 默认epsilon
finetune_config_id=1, # 默认微调配置
use_strong_protection=False # 默认不启用强防护
)
db.session.add(batch)
db.session.commit()
return jsonify({
'message': '任务创建成功,请上传图片',
'task': batch.to_dict()
}), 201
except Exception as e:
db.session.rollback()
return jsonify({'error': f'任务创建失败: {str(e)}'}), 500
@task_bp.route('/upload/<int:batch_id>', methods=['POST'])
@jwt_required()
def upload_images(batch_id):
"""上传图片到指定任务"""
try:
current_user_id = get_jwt_identity()
# 检查任务是否存在且属于当前用户
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
if not batch:
return jsonify({'error': '任务不存在或无权限'}), 404
if batch.status != 'pending':
return jsonify({'error': '任务已开始处理,无法上传新图片'}), 400
if 'files' not in request.files:
return jsonify({'error': '没有选择文件'}), 400
files = request.files.getlist('files')
uploaded_files = []
# 获取原始图片类型ID
original_type = ImageType.query.filter_by(type_code='original').first()
if not original_type:
return jsonify({'error': '系统配置错误:缺少原始图片类型'}), 500
for file in files:
if file.filename == '':
continue
if file and allowed_file(file.filename):
# 处理单张图片
if not file.filename.lower().endswith(('.zip', '.rar')):
result = ImageService.save_image(file, batch_id, current_user_id, original_type.id)
if result['success']:
uploaded_files.append(result['image'])
else:
return jsonify({'error': result['error']}), 400
# 处理压缩包
else:
results = ImageService.extract_and_save_zip(file, batch_id, current_user_id, original_type.id)
for result in results:
if result['success']:
uploaded_files.append(result['image'])
if not uploaded_files:
return jsonify({'error': '没有有效的图片文件'}), 400
return jsonify({
'message': f'成功上传 {len(uploaded_files)} 张图片',
'uploaded_files': [img.to_dict() for img in uploaded_files]
}), 200
except Exception as e:
return jsonify({'error': f'文件上传失败: {str(e)}'}), 500
@task_bp.route('/<int:batch_id>/config', methods=['GET'])
@jwt_required()
def get_task_config(batch_id):
"""获取任务配置(显示用户上次的配置或默认配置)"""
try:
current_user_id = get_jwt_identity()
# 检查任务是否存在且属于当前用户
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
if not batch:
return jsonify({'error': '任务不存在或无权限'}), 404
# 获取用户配置
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
# 如果用户有配置,显示用户上次的配置;否则显示当前任务的默认配置
if user_config:
suggested_config = {
'perturbation_config_id': user_config.preferred_perturbation_config_id,
'epsilon': float(user_config.preferred_epsilon),
'finetune_config_id': user_config.preferred_finetune_config_id,
'use_strong_protection': user_config.preferred_purification
}
else:
suggested_config = {
'perturbation_config_id': batch.perturbation_config_id,
'epsilon': float(batch.preferred_epsilon),
'finetune_config_id': batch.finetune_config_id,
'use_strong_protection': batch.use_strong_protection
}
return jsonify({
'task': batch.to_dict(),
'suggested_config': suggested_config,
'current_config': {
'perturbation_config_id': batch.perturbation_config_id,
'epsilon': float(batch.preferred_epsilon),
'finetune_config_id': batch.finetune_config_id,
'use_strong_protection': batch.use_strong_protection
}
}), 200
except Exception as e:
return jsonify({'error': f'获取任务配置失败: {str(e)}'}), 500
@task_bp.route('/<int:batch_id>/config', methods=['PUT'])
@jwt_required()
def update_task_config(batch_id):
"""更新任务配置"""
try:
current_user_id = get_jwt_identity()
# 检查任务是否存在且属于当前用户
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
if not batch:
return jsonify({'error': '任务不存在或无权限'}), 404
if batch.status != 'pending':
return jsonify({'error': '任务已开始处理,无法修改配置'}), 400
data = request.get_json()
# 更新任务配置
if 'perturbation_config_id' in data:
batch.perturbation_config_id = data['perturbation_config_id']
if 'epsilon' in data:
epsilon = float(data['epsilon'])
if 0 < epsilon <= 255:
batch.preferred_epsilon = epsilon
else:
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
if 'finetune_config_id' in data:
batch.finetune_config_id = data['finetune_config_id']
if 'use_strong_protection' in data:
batch.use_strong_protection = bool(data['use_strong_protection'])
db.session.commit()
# 更新用户配置(保存这次的选择作为下次的默认配置)
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
user_config.preferred_perturbation_config_id = batch.perturbation_config_id
user_config.preferred_epsilon = batch.preferred_epsilon
user_config.preferred_finetune_config_id = batch.finetune_config_id
user_config.preferred_purification = batch.use_strong_protection
db.session.commit()
return jsonify({
'message': '任务配置更新成功',
'task': batch.to_dict()
}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'更新任务配置失败: {str(e)}'}), 500
@task_bp.route('/start/<int:batch_id>', methods=['POST'])
@jwt_required()
def start_task(batch_id):
"""开始处理任务"""
try:
current_user_id = get_jwt_identity()
# 检查任务是否存在且属于当前用户
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
if not batch:
return jsonify({'error': '任务不存在或无权限'}), 404
if batch.status != 'pending':
return jsonify({'error': '任务状态不正确,无法开始处理'}), 400
# 检查是否有上传的图片
image_count = Image.query.filter_by(batch_id=batch_id).count()
if image_count == 0:
return jsonify({'error': '请先上传图片'}), 400
# 启动任务处理
success = TaskService.start_processing(batch)
if success:
return jsonify({
'message': '任务开始处理',
'task': batch.to_dict()
}), 200
else:
return jsonify({'error': '任务启动失败'}), 500
except Exception as e:
return jsonify({'error': f'任务启动失败: {str(e)}'}), 500
@task_bp.route('/list', methods=['GET'])
@jwt_required()
def list_tasks():
"""获取用户的任务列表"""
try:
current_user_id = get_jwt_identity()
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
batches = Batch.query.filter_by(user_id=current_user_id)\
.order_by(Batch.created_at.desc())\
.paginate(page=page, per_page=per_page, error_out=False)
return jsonify({
'tasks': [batch.to_dict() for batch in batches.items],
'total': batches.total,
'pages': batches.pages,
'current_page': page
}), 200
except Exception as e:
return jsonify({'error': f'获取任务列表失败: {str(e)}'}), 500
@task_bp.route('/<int:batch_id>', methods=['GET'])
@jwt_required()
def get_task_detail(batch_id):
"""获取任务详情"""
try:
current_user_id = get_jwt_identity()
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
if not batch:
return jsonify({'error': '任务不存在或无权限'}), 404
# 获取任务相关的图片
images = Image.query.filter_by(batch_id=batch_id).all()
return jsonify({
'task': batch.to_dict(),
'images': [img.to_dict() for img in images],
'image_count': len(images)
}), 200
except Exception as e:
return jsonify({'error': f'获取任务详情失败: {str(e)}'}), 500
@task_bp.route('/load-config', methods=['GET'])
@jwt_required()
def load_last_config():
"""加载用户上次的配置"""
try:
current_user_id = get_jwt_identity()
# 获取用户配置
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if user_config:
config = {
'perturbation_config_id': user_config.preferred_perturbation_config_id,
'epsilon': float(user_config.preferred_epsilon),
'finetune_config_id': user_config.preferred_finetune_config_id,
'use_strong_protection': user_config.preferred_purification
}
return jsonify({
'message': '成功加载上次配置',
'config': config
}), 200
else:
# 返回默认配置
default_config = {
'perturbation_config_id': 1,
'epsilon': 8.0,
'finetune_config_id': 1,
'use_strong_protection': False
}
return jsonify({
'message': '使用默认配置',
'config': default_config
}), 200
except Exception as e:
return jsonify({'error': f'加载配置失败: {str(e)}'}), 500
@task_bp.route('/save-config', methods=['POST'])
@jwt_required()
def save_current_config():
"""保存当前配置作为用户偏好"""
try:
current_user_id = get_jwt_identity()
data = request.get_json()
# 获取或创建用户配置
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
# 更新配置
if 'perturbation_config_id' in data:
user_config.preferred_perturbation_config_id = data['perturbation_config_id']
if 'epsilon' in data:
epsilon = float(data['epsilon'])
if 0 < epsilon <= 255:
user_config.preferred_epsilon = epsilon
else:
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
if 'finetune_config_id' in data:
user_config.preferred_finetune_config_id = data['finetune_config_id']
if 'use_strong_protection' in data:
user_config.preferred_purification = bool(data['use_strong_protection'])
db.session.commit()
return jsonify({
'message': '配置保存成功',
'config': {
'perturbation_config_id': user_config.preferred_perturbation_config_id,
'epsilon': float(user_config.preferred_epsilon),
'finetune_config_id': user_config.preferred_finetune_config_id,
'use_strong_protection': user_config.preferred_purification
}
}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'保存配置失败: {str(e)}'}), 500
@task_bp.route('/<int:batch_id>/status', methods=['GET'])
@jwt_required()
def get_task_status(batch_id):
"""获取任务处理状态"""
try:
current_user_id = get_jwt_identity()
batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
if not batch:
return jsonify({'error': '任务不存在或无权限'}), 404
return jsonify({
'task_id': batch_id,
'status': batch.status,
'progress': TaskService.get_processing_progress(batch_id),
'error_message': batch.error_message
}), 200
except Exception as e:
return jsonify({'error': f'获取任务状态失败: {str(e)}'}), 500

@ -0,0 +1,133 @@
"""
用户管理控制器
处理用户配置等功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required
from app import db
from app.models import User, UserConfig, PerturbationConfig, FinetuneConfig
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
user_bp = Blueprint('user', __name__)
@user_bp.route('/config', methods=['GET'])
@int_jwt_required
def get_user_config(current_user_id):
"""获取用户配置"""
try:
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
# 如果没有配置,创建默认配置
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
db.session.commit()
return jsonify({
'config': user_config.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500
@user_bp.route('/config', methods=['PUT'])
@int_jwt_required
def update_user_config(current_user_id):
"""更新用户配置"""
try:
data = request.get_json()
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
# 更新配置字段
if 'preferred_perturbation_config_id' in data:
user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id']
if 'preferred_epsilon' in data:
epsilon = float(data['preferred_epsilon'])
if 0 < epsilon <= 255:
user_config.preferred_epsilon = epsilon
else:
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
if 'preferred_finetune_config_id' in data:
user_config.preferred_finetune_config_id = data['preferred_finetune_config_id']
if 'preferred_purification' in data:
user_config.preferred_purification = bool(data['preferred_purification'])
db.session.commit()
return jsonify({
'message': '用户配置更新成功',
'config': user_config.to_dict()
}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500
@user_bp.route('/algorithms', methods=['GET'])
@jwt_required()
def get_available_algorithms():
"""获取可用的算法列表"""
try:
perturbation_configs = PerturbationConfig.query.all()
finetune_configs = FinetuneConfig.query.all()
return jsonify({
'perturbation_algorithms': [
{
'id': config.id,
'method_code': config.method_code,
'method_name': config.method_name,
'description': config.description,
'default_epsilon': float(config.default_epsilon)
} for config in perturbation_configs
],
'finetune_methods': [
{
'id': config.id,
'method_code': config.method_code,
'method_name': config.method_name,
'description': config.description
} for config in finetune_configs
]
}), 200
except Exception as e:
return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500
@user_bp.route('/stats', methods=['GET'])
@int_jwt_required
def get_user_stats(current_user_id):
"""获取用户统计信息"""
try:
from app.models import Batch, Image
# 统计用户的任务和图片数量
total_tasks = Batch.query.filter_by(user_id=current_user_id).count()
completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count()
processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count()
failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count()
total_images = Image.query.filter_by(user_id=current_user_id).count()
return jsonify({
'stats': {
'total_tasks': total_tasks,
'completed_tasks': completed_tasks,
'processing_tasks': processing_tasks,
'failed_tasks': failed_tasks,
'total_images': total_images
}
}), 200
except Exception as e:
return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500

@ -0,0 +1,233 @@
"""
数据库模型定义
基于已有的schema.sql设计
"""
from datetime import datetime
from app import db
from werkzeug.security import generate_password_hash, check_password_hash
from enum import Enum as PyEnum
class User(db.Model):
"""用户表"""
__tablename__ = 'users'
id = db.Column(db.BigInteger, primary_key=True)
username = db.Column(db.String(50), unique=True, nullable=False)
password_hash = db.Column(db.String(255), nullable=False)
email = db.Column(db.String(100))
role = db.Column(db.Enum('user', 'admin'), default='user')
max_concurrent_tasks = db.Column(db.Integer, nullable=False, default=0)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = db.Column(db.Boolean, default=True)
# 关系
batches = db.relationship('Batch', backref='user', lazy='dynamic', cascade='all, delete-orphan')
images = db.relationship('Image', backref='user', lazy='dynamic', cascade='all, delete-orphan')
user_config = db.relationship('UserConfig', backref='user', uselist=False, cascade='all, delete-orphan')
def set_password(self, password):
"""设置密码"""
self.password_hash = generate_password_hash(password)
def check_password(self, password):
"""验证密码"""
return check_password_hash(self.password_hash, password)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'username': self.username,
'email': self.email,
'role': self.role,
'max_concurrent_tasks': self.max_concurrent_tasks,
'created_at': self.created_at.isoformat() if self.created_at else None,
'is_active': self.is_active
}
class ImageType(db.Model):
"""图片类型表"""
__tablename__ = 'image_types'
id = db.Column(db.BigInteger, primary_key=True)
type_code = db.Column(db.Enum('original', 'perturbed', 'original_generate', 'perturbed_generate'),
unique=True, nullable=False)
type_name = db.Column(db.String(100), nullable=False)
description = db.Column(db.Text)
# 关系
images = db.relationship('Image', backref='image_type', lazy='dynamic')
class PerturbationConfig(db.Model):
"""加噪算法表"""
__tablename__ = 'perturbation_configs'
id = db.Column(db.BigInteger, primary_key=True)
method_code = db.Column(db.String(50), unique=True, nullable=False)
method_name = db.Column(db.String(100), nullable=False)
description = db.Column(db.Text, nullable=False)
default_epsilon = db.Column(db.Numeric(5, 2), nullable=False)
# 关系
batches = db.relationship('Batch', backref='perturbation_config', lazy='dynamic')
user_configs = db.relationship('UserConfig', backref='preferred_perturbation_config', lazy='dynamic')
class FinetuneConfig(db.Model):
"""微调方式表"""
__tablename__ = 'finetune_configs'
id = db.Column(db.BigInteger, primary_key=True)
method_code = db.Column(db.String(50), unique=True, nullable=False)
method_name = db.Column(db.String(100), nullable=False)
description = db.Column(db.Text, nullable=False)
# 关系
batches = db.relationship('Batch', backref='finetune_config', lazy='dynamic')
user_configs = db.relationship('UserConfig', backref='preferred_finetune_config', lazy='dynamic')
class Batch(db.Model):
"""加噪批次表"""
__tablename__ = 'batch'
id = db.Column(db.BigInteger, primary_key=True)
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
batch_name = db.Column(db.String(128))
# 加噪配置
perturbation_config_id = db.Column(db.BigInteger, db.ForeignKey('perturbation_configs.id'),
nullable=False, default=1)
preferred_epsilon = db.Column(db.Numeric(5, 2), nullable=False, default=8.0)
# 评估配置
finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id'),
nullable=False, default=1)
# 净化配置
use_strong_protection = db.Column(db.Boolean, nullable=False, default=False)
# 任务状态
status = db.Column(db.Enum('pending', 'processing', 'completed', 'failed'), default='pending')
created_at = db.Column(db.DateTime, default=datetime.utcnow)
started_at = db.Column(db.DateTime)
completed_at = db.Column(db.DateTime)
error_message = db.Column(db.Text)
result_path = db.Column(db.String(500))
# 关系
images = db.relationship('Image', backref='batch', lazy='dynamic')
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'batch_name': self.batch_name,
'status': self.status,
'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None,
'use_strong_protection': self.use_strong_protection,
'created_at': self.created_at.isoformat() if self.created_at else None,
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'error_message': self.error_message,
'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None,
'finetune_config': self.finetune_config.method_name if self.finetune_config else None
}
class Image(db.Model):
"""图片表"""
__tablename__ = 'images'
id = db.Column(db.BigInteger, primary_key=True)
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id'))
father_id = db.Column(db.BigInteger, db.ForeignKey('images.id'))
original_filename = db.Column(db.String(255))
stored_filename = db.Column(db.String(255), unique=True, nullable=False)
file_path = db.Column(db.String(500), nullable=False)
file_size = db.Column(db.BigInteger)
image_type_id = db.Column(db.BigInteger, db.ForeignKey('image_types.id'), nullable=False)
width = db.Column(db.Integer)
height = db.Column(db.Integer)
upload_time = db.Column(db.DateTime, default=datetime.utcnow)
# 自引用关系
children = db.relationship('Image', backref=db.backref('parent', remote_side=[id]), lazy='dynamic')
# 评估结果关系
reference_evaluations = db.relationship('EvaluationResult',
foreign_keys='EvaluationResult.reference_image_id',
backref='reference_image', lazy='dynamic')
target_evaluations = db.relationship('EvaluationResult',
foreign_keys='EvaluationResult.target_image_id',
backref='target_image', lazy='dynamic')
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'original_filename': self.original_filename,
'stored_filename': self.stored_filename,
'file_path': self.file_path,
'file_size': self.file_size,
'width': self.width,
'height': self.height,
'upload_time': self.upload_time.isoformat() if self.upload_time else None,
'image_type': self.image_type.type_name if self.image_type else None,
'batch_id': self.batch_id
}
class EvaluationResult(db.Model):
"""评估结果表"""
__tablename__ = 'evaluation_results'
id = db.Column(db.BigInteger, primary_key=True)
reference_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False)
target_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False)
evaluation_type = db.Column(db.Enum('image_quality', 'model_generation'), nullable=False)
purification_applied = db.Column(db.Boolean, default=False)
fid_score = db.Column(db.Numeric(8, 4))
lpips_score = db.Column(db.Numeric(8, 4))
ssim_score = db.Column(db.Numeric(8, 4))
psnr_score = db.Column(db.Numeric(8, 4))
heatmap_path = db.Column(db.String(500))
evaluated_at = db.Column(db.DateTime, default=datetime.utcnow)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'evaluation_type': self.evaluation_type,
'purification_applied': self.purification_applied,
'fid_score': float(self.fid_score) if self.fid_score else None,
'lpips_score': float(self.lpips_score) if self.lpips_score else None,
'ssim_score': float(self.ssim_score) if self.ssim_score else None,
'psnr_score': float(self.psnr_score) if self.psnr_score else None,
'heatmap_path': self.heatmap_path,
'evaluated_at': self.evaluated_at.isoformat() if self.evaluated_at else None
}
class UserConfig(db.Model):
"""用户配置表"""
__tablename__ = 'user_configs'
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), primary_key=True)
preferred_perturbation_config_id = db.Column(db.BigInteger,
db.ForeignKey('perturbation_configs.id'), default=1)
preferred_epsilon = db.Column(db.Numeric(5, 2), default=8.0)
preferred_finetune_config_id = db.Column(db.BigInteger,
db.ForeignKey('finetune_configs.id'), default=1)
preferred_purification = db.Column(db.Boolean, default=False)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def to_dict(self):
"""转换为字典"""
return {
'user_id': self.user_id,
'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None,
'preferred_purification': self.preferred_purification,
'preferred_perturbation_config': self.preferred_perturbation_config.method_name if self.preferred_perturbation_config else None,
'preferred_finetune_config': self.preferred_finetune_config.method_name if self.preferred_finetune_config else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None
}

@ -0,0 +1,34 @@
"""
认证服务
处理用户认证相关逻辑
"""
from app.models import User
class AuthService:
"""认证服务类"""
@staticmethod
def authenticate_user(username, password):
"""验证用户凭据"""
user = User.query.filter_by(username=username).first()
if user and user.check_password(password) and user.is_active:
return user
return None
@staticmethod
def get_user_by_id(user_id):
"""根据ID获取用户"""
return User.query.get(user_id)
@staticmethod
def is_email_available(email):
"""检查邮箱是否可用"""
return User.query.filter_by(email=email).first() is None
@staticmethod
def is_username_available(username):
"""检查用户名是否可用"""
return User.query.filter_by(username=username).first() is None

@ -0,0 +1,161 @@
"""
图像处理服务
处理图像上传保存等功能
"""
import os
import uuid
import zipfile
from werkzeug.utils import secure_filename
from flask import current_app
from PIL import Image as PILImage
from app import db
from app.models import Image
from app.utils.file_utils import allowed_file
class ImageService:
"""图像处理服务"""
@staticmethod
def save_image(file, batch_id, user_id, image_type_id):
"""保存单张图片"""
try:
if not file or not allowed_file(file.filename):
return {'success': False, 'error': '不支持的文件格式'}
# 生成唯一文件名
file_extension = os.path.splitext(file.filename)[1].lower()
stored_filename = f"{uuid.uuid4().hex}{file_extension}"
# 临时保存到上传目录
project_root = os.path.dirname(current_app.root_path)
temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id))
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.join(temp_dir, stored_filename)
file.save(temp_path)
# 移动到对应的静态目录
static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(batch_id))
os.makedirs(static_dir, exist_ok=True)
file_path = os.path.join(static_dir, stored_filename)
# 移动文件到最终位置
import shutil
shutil.move(temp_path, file_path)
# 获取图片尺寸
try:
with PILImage.open(file_path) as img:
width, height = img.size
except:
width, height = None, None
# 创建数据库记录
image = Image(
user_id=user_id,
batch_id=batch_id,
original_filename=file.filename,
stored_filename=stored_filename,
file_path=file_path,
file_size=os.path.getsize(file_path),
image_type_id=image_type_id,
width=width,
height=height
)
db.session.add(image)
db.session.commit()
return {'success': True, 'image': image}
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'保存图片失败: {str(e)}'}
@staticmethod
def extract_and_save_zip(zip_file, batch_id, user_id, image_type_id):
"""解压并保存压缩包中的图片"""
results = []
temp_dir = None
try:
# 创建临时目录
project_root = os.path.dirname(current_app.root_path)
temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], 'temp', f"{uuid.uuid4().hex}")
os.makedirs(temp_dir, exist_ok=True)
# 保存压缩包
zip_path = os.path.join(temp_dir, secure_filename(zip_file.filename))
zip_file.save(zip_path)
# 解压文件
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
# 遍历解压的文件
for root, dirs, files in os.walk(temp_dir):
for filename in files:
if filename.lower().endswith(('.zip', '.rar')):
continue # 跳过压缩包文件本身
if allowed_file(filename):
file_path = os.path.join(root, filename)
# 创建虚拟文件对象
class FileWrapper:
def __init__(self, path, name):
self.path = path
self.filename = name
def save(self, destination):
import shutil
shutil.copy2(self.path, destination)
virtual_file = FileWrapper(file_path, filename)
result = ImageService.save_image(virtual_file, batch_id, user_id, image_type_id)
results.append(result)
return results
except Exception as e:
return [{'success': False, 'error': f'解压失败: {str(e)}'}]
finally:
# 清理临时文件
if temp_dir and os.path.exists(temp_dir):
import shutil
try:
shutil.rmtree(temp_dir)
except:
pass
@staticmethod
def get_image_url(image):
"""获取图片访问URL"""
if not image or not image.file_path:
return None
# 这里返回相对路径前端可以拼接完整URL
return f"/api/image/file/{image.id}"
@staticmethod
def delete_image(image_id, user_id):
"""删除图片"""
try:
image = Image.query.filter_by(id=image_id, user_id=user_id).first()
if not image:
return {'success': False, 'error': '图片不存在或无权限'}
# 删除文件
if os.path.exists(image.file_path):
os.remove(image.file_path)
# 删除数据库记录
db.session.delete(image)
db.session.commit()
return {'success': True}
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'删除图片失败: {str(e)}'}

@ -0,0 +1,191 @@
"""
任务处理服务
处理图像加噪评估等核心业务逻辑
"""
import os
import time
from datetime import datetime
from flask import current_app
from app import db
from app.models import Batch, Image, EvaluationResult, ImageType
from app.algorithms.perturbation_engine import PerturbationEngine
from app.algorithms.evaluation_engine import EvaluationEngine
class TaskService:
"""任务处理服务"""
@staticmethod
def start_processing(batch):
"""开始处理任务"""
try:
# 更新任务状态
batch.status = 'processing'
batch.started_at = datetime.utcnow()
db.session.commit()
# 获取任务相关的原始图片
original_images = Image.query.filter_by(
batch_id=batch.id
).join(ImageType).filter(
ImageType.type_code == 'original'
).all()
if not original_images:
batch.status = 'failed'
batch.error_message = '没有找到原始图片'
batch.completed_at = datetime.utcnow()
db.session.commit()
return False
# 处理每张图片
perturbed_type = ImageType.query.filter_by(type_code='perturbed').first()
processed_images = []
for original_image in original_images:
try:
# 确定加噪图片的保存路径
project_root = os.path.dirname(current_app.root_path)
perturbed_dir = os.path.join(project_root,
current_app.config['PERTURBED_IMAGES_FOLDER'],
str(batch.user_id),
str(batch.id))
os.makedirs(perturbed_dir, exist_ok=True)
# 调用加噪算法
perturbed_image_path = PerturbationEngine.apply_perturbation(
original_image.file_path,
batch.perturbation_config.method_code,
float(batch.preferred_epsilon),
batch.use_strong_protection
)
if perturbed_image_path:
# 保存加噪后的图片记录
perturbed_image = Image(
user_id=batch.user_id,
batch_id=batch.id,
father_id=original_image.id,
original_filename=f"perturbed_{original_image.original_filename}",
stored_filename=os.path.basename(perturbed_image_path),
file_path=perturbed_image_path,
file_size=os.path.getsize(perturbed_image_path) if os.path.exists(perturbed_image_path) else 0,
image_type_id=perturbed_type.id,
width=original_image.width,
height=original_image.height
)
db.session.add(perturbed_image)
processed_images.append((original_image, perturbed_image))
except Exception as e:
print(f"处理图片 {original_image.id} 时出错: {str(e)}")
continue
# 提交加噪后的图片
db.session.commit()
# 生成评估结果
TaskService._generate_evaluations(batch, processed_images)
# 更新任务状态为完成
batch.status = 'completed'
batch.completed_at = datetime.utcnow()
db.session.commit()
return True
except Exception as e:
# 处理失败
batch.status = 'failed'
batch.error_message = str(e)
batch.completed_at = datetime.utcnow()
db.session.commit()
return False
@staticmethod
def _generate_evaluations(batch, processed_images):
"""生成评估结果"""
try:
for original_image, perturbed_image in processed_images:
# 图像质量对比评估
quality_metrics = EvaluationEngine.evaluate_image_quality(
original_image.file_path,
perturbed_image.file_path
)
quality_evaluation = EvaluationResult(
reference_image_id=original_image.id,
target_image_id=perturbed_image.id,
evaluation_type='image_quality',
purification_applied=False,
fid_score=quality_metrics.get('fid'),
lpips_score=quality_metrics.get('lpips'),
ssim_score=quality_metrics.get('ssim'),
psnr_score=quality_metrics.get('psnr'),
heatmap_path=quality_metrics.get('heatmap_path')
)
db.session.add(quality_evaluation)
# 模型生成对比评估
generation_metrics = EvaluationEngine.evaluate_model_generation(
original_image.file_path,
perturbed_image.file_path,
batch.finetune_config.method_code
)
generation_evaluation = EvaluationResult(
reference_image_id=original_image.id,
target_image_id=perturbed_image.id,
evaluation_type='model_generation',
purification_applied=False,
fid_score=generation_metrics.get('fid'),
lpips_score=generation_metrics.get('lpips'),
ssim_score=generation_metrics.get('ssim'),
psnr_score=generation_metrics.get('psnr'),
heatmap_path=generation_metrics.get('heatmap_path')
)
db.session.add(generation_evaluation)
db.session.commit()
except Exception as e:
print(f"生成评估结果时出错: {str(e)}")
@staticmethod
def get_processing_progress(batch_id):
"""获取处理进度"""
try:
batch = Batch.query.get(batch_id)
if not batch:
return 0
if batch.status == 'pending':
return 0
elif batch.status == 'completed':
return 100
elif batch.status == 'failed':
return 0
elif batch.status == 'processing':
# 简单的进度计算:根据已处理的图片数量
total_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter(
ImageType.type_code == 'original'
).count()
processed_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter(
ImageType.type_code == 'perturbed'
).count()
if total_images == 0:
return 0
progress = int((processed_images / total_images) * 80) # 80%用于图像处理20%用于评估
return min(progress, 95) # 最多95%剩余5%用于最终完成
return 0
except Exception as e:
print(f"获取处理进度时出错: {str(e)}")
return 0

@ -0,0 +1,53 @@
"""
文件处理工具类
"""
import os
from werkzeug.utils import secure_filename
from flask import current_app
def allowed_file(filename):
"""检查文件扩展名是否被允许"""
if not filename:
return False
allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS',
{'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'})
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in allowed_extensions
def save_uploaded_file(file, upload_path):
"""保存上传的文件"""
try:
if not file or not allowed_file(file.filename):
return None
filename = secure_filename(file.filename)
# 确保目录存在
os.makedirs(os.path.dirname(upload_path), exist_ok=True)
file.save(upload_path)
return upload_path
except Exception as e:
print(f"保存文件失败: {str(e)}")
return None
def get_file_size(file_path):
"""获取文件大小"""
try:
return os.path.getsize(file_path)
except:
return 0
def delete_file(file_path):
"""删除文件"""
try:
if os.path.exists(file_path):
os.remove(file_path)
return True
except:
pass
return False

@ -0,0 +1,16 @@
"""
JWT工具函数
"""
from functools import wraps
from flask_jwt_extended import get_jwt_identity
def int_jwt_required(f):
"""获取JWT身份并转换为整数的装饰器"""
@wraps(f)
def wrapped(*args, **kwargs):
jwt_identity = get_jwt_identity()
if jwt_identity is not None:
# 在函数调用前注入转换后的user_id
kwargs['current_user_id'] = int(jwt_identity)
return f(*args, **kwargs)
return wrapped

@ -0,0 +1,16 @@
# MuseGuard 环境变量配置文件
# 注意:此文件包含敏感信息,不应提交到版本控制系统
# 数据库配置
DB_USER=root
DB_PASSWORD=your_password_here
DB_HOST=localhost
DB_NAME=your_database_name_here
# Flask配置
SECRET_KEY=museguard-secret-key-2024
JWT_SECRET_KEY=jwt-secret-string
# 开发模式
FLASK_ENV=development
FLASK_DEBUG=True

@ -0,0 +1,109 @@
"""
应用配置文件
"""
import os
from datetime import timedelta
from dotenv import load_dotenv
from urllib.parse import quote_plus
# 加载环境变量 - 从 config 目录读取 .env 文件
config_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(config_dir, '.env')
load_dotenv(env_path)
class Config:
"""基础配置类"""
# 基础配置
SECRET_KEY = os.environ.get('SECRET_KEY') or 'museguard-secret-key-2024'
# 数据库配置 - 支持密码中的特殊字符
DB_USER = os.environ.get('DB_USER')
DB_PASSWORD = os.environ.get('DB_PASSWORD')
DB_HOST = os.environ.get('DB_HOST') or 'localhost'
DB_NAME = os.environ.get('DB_NAME') or 'museguard_schema'
# URL编码密码中的特殊字符
from urllib.parse import quote_plus
_encoded_password = quote_plus(DB_PASSWORD)
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \
f'mysql+pymysql://{DB_USER}:{_encoded_password}@{DB_HOST}/{DB_NAME}'
SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_ENGINE_OPTIONS = {
'pool_pre_ping': True,
'pool_recycle': 300,
}
# JWT配置
JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or 'jwt-secret-string'
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)
JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
# 静态文件根目录
STATIC_ROOT = 'static'
# 文件上传配置
UPLOAD_FOLDER = 'uploads' # 临时上传目录
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'}
# 图像处理配置
ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'originals') # 重命名后的原始图片
PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片
MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录
MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'clean') # 原图的模型生成结果
MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果
HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图
# 预设演示图像配置
DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录
DEMO_ORIGINAL_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'original') # 演示原始图片
DEMO_PERTURBED_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'perturbed') # 演示加噪图片
DEMO_COMPARISONS_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'comparisons') # 演示对比图
# 邮件配置(用于注册验证)
MAIL_SERVER = os.environ.get('MAIL_SERVER') or 'smtp.gmail.com'
MAIL_PORT = int(os.environ.get('MAIL_PORT') or 587)
MAIL_USE_TLS = os.environ.get('MAIL_USE_TLS', 'true').lower() in ['true', 'on', '1']
MAIL_USERNAME = os.environ.get('MAIL_USERNAME')
MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD')
# 算法配置
ALGORITHMS = {
'simac': {
'name': 'SimAC算法',
'description': 'Simple Anti-Customization Method for Protecting Face Privacy',
'default_epsilon': 8.0
},
'caat': {
'name': 'CAAT算法',
'description': 'Perturbing Attention Gives You More Bang for the Buck',
'default_epsilon': 16.0
},
'pid': {
'name': 'PID算法',
'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models',
'default_epsilon': 4.0
}
}
class DevelopmentConfig(Config):
"""开发环境配置"""
DEBUG = True
class ProductionConfig(Config):
"""生产环境配置"""
DEBUG = False
class TestingConfig(Config):
"""测试环境配置"""
TESTING = True
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
config = {
'development': DevelopmentConfig,
'production': ProductionConfig,
'testing': TestingConfig,
'default': DevelopmentConfig
}

@ -0,0 +1,81 @@
"""
数据库初始化脚本
"""
from app import create_app, db
from app.models import *
def init_database():
"""初始化数据库"""
app = create_app()
with app.app_context():
# 创建所有表
db.create_all()
# 初始化图片类型数据
image_types = [
{'type_code': 'original', 'type_name': '原始图片', 'description': '用户上传的原始图像文件'},
{'type_code': 'perturbed', 'type_name': '加噪后图片', 'description': '经过扰动算法处理后的防护图像'},
{'type_code': 'original_generate', 'type_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'},
{'type_code': 'perturbed_generate', 'type_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'}
]
for img_type in image_types:
existing = ImageType.query.filter_by(type_code=img_type['type_code']).first()
if not existing:
new_type = ImageType(**img_type)
db.session.add(new_type)
# 初始化加噪算法数据
perturbation_configs = [
{'method_code': 'aspl', 'method_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense', 'default_epsilon': 6.0},
{'method_code': 'simac', 'method_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', 'default_epsilon': 8.0},
{'method_code': 'caat', 'method_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck', 'default_epsilon': 16.0},
{'method_code': 'pid', 'method_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', 'default_epsilon': 4.0}
]
for config in perturbation_configs:
existing = PerturbationConfig.query.filter_by(method_code=config['method_code']).first()
if not existing:
new_config = PerturbationConfig(**config)
db.session.add(new_config)
# 初始化微调方式数据
finetune_configs = [
{'method_code': 'dreambooth', 'method_name': 'DreamBooth', 'description': 'DreamBooth个性化文本到图像生成'},
{'method_code': 'lora', 'method_name': 'LoRA', 'description': '低秩适应Low-Rank Adaptation微调方法'},
{'method_code': 'textual_inversion', 'method_name': 'Textual Inversion', 'description': '文本反转个性化方法'}
]
for config in finetune_configs:
existing = FinetuneConfig.query.filter_by(method_code=config['method_code']).first()
if not existing:
new_config = FinetuneConfig(**config)
db.session.add(new_config)
# 创建默认管理员用户
admin_users = [
{'username': 'admin1', 'email': 'admin1@museguard.com', 'role': 'admin'},
{'username': 'admin2', 'email': 'admin2@museguard.com', 'role': 'admin'},
{'username': 'admin3', 'email': 'admin3@museguard.com', 'role': 'admin'}
]
for admin_data in admin_users:
existing = User.query.filter_by(username=admin_data['username']).first()
if not existing:
admin_user = User(**admin_data)
admin_user.set_password('admin123') # 默认密码
db.session.add(admin_user)
# 为管理员创建默认配置
db.session.flush() # 确保user.id可用
user_config = UserConfig(user_id=admin_user.id)
db.session.add(user_config)
# 提交所有更改
db.session.commit()
print("数据库初始化完成!")
if __name__ == '__main__':
init_database()

@ -0,0 +1,15 @@
@echo off
chcp 65001 > nul
echo ==========================================
echo MuseGuard 快速启动脚本
echo ==========================================
echo.
REM 切换到项目目录
cd /d "d:\code\Software_Project\team_project\MuseGuard\src\backend"
REM 激活虚拟环境并启动服务器
echo 正在激活虚拟环境并启动服务器...
call venv_py311\Scripts\activate.bat && python run.py
pause

@ -0,0 +1,33 @@
# Core Flask Framework
Flask==3.0.0
Flask-SQLAlchemy==3.1.1
Flask-Migrate==4.0.5
Flask-JWT-Extended==4.6.0
Flask-CORS==5.0.0
Werkzeug==3.0.1
# Database
PyMySQL==1.1.1
# Image Processing
Pillow==10.4.0
numpy==1.26.4
# Security & Utils
cryptography==42.0.8
python-dotenv==1.0.1
# Additional Dependencies (auto-installed)
# alembic==1.17.0
# blinker==1.9.0
# cffi==2.0.0
# click==8.3.0
# greenlet==3.2.4
# itsdangerous==2.2.0
# Jinja2==3.1.6
# Mako==1.3.10
# MarkupSafe==3.0.3
# pycparser==2.23
# PyJWT==2.10.1
# SQLAlchemy==2.0.44
# typing_extensions==4.15.0

@ -0,0 +1,21 @@
"""
Flask应用启动脚本
"""
import os
from app import create_app
# 设置环境变量
os.environ.setdefault('FLASK_ENV', 'development')
# 创建应用实例
app = create_app()
if __name__ == '__main__':
# 开发模式启动
app.run(
host='0.0.0.0',
port=5000,
debug=True,
threaded=True
)

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save