将lianghao_branch合并到develop #2

Merged
hnu202326010204 merged 3 commits from lianghao_branch into develop 2 months ago

27
.gitignore vendored

@ -1,9 +1,18 @@
__pycache__/
venv/
*.png
*.jpg
*.jpeg
.env
__pycache__/
venv/
python=3.11/
*.png
*.jpg
*.jpeg
# 环境配置文件(包含敏感信息)
*.env
# 日志文件
logs/
*.log
# 上传文件临时目录
uploads/

@ -1,25 +1,25 @@
# MuseGuard
占位:项目总说明。后续将补充以下内容:
## 简介
(占位)
## 项目目标
(占位)
## 技术栈
(占位)
## 快速开始
(占位)
## 目录结构说明
(占位)
## 贡献指南
(占位)
## 许可证
(占位)
# MuseGuard
占位:项目总说明。后续将补充以下内容:
## 简介
(占位)
## 项目目标
(占位)
## 技术栈
(占位)
## 快速开始
(占位)
## 目录结构说明
(占位)
## 贡献指南
(占位)
## 许可证
(占位)

@ -0,0 +1,29 @@
# Python 编译缓存
__pycache__/
# 图片文件
*.png
*.jpg
*.jpeg
# 环境配置文件(包含敏感信息)
*.env
# 日志及进程文件
logs/
*.log
*.pid
# 上传文件临时目录
uploads/
# 微调生成文件
*.json
*.bin
*.pkl
*.safetensors
*.pt
*.txt
# 模型文件
hf_models/

@ -1,243 +1,378 @@
# MuseGuard 后端框架
基于对抗性扰动的多风格图像生成防护系统 - 后端API服务
## 项目结构
```
backend/
├── app/ # 主应用目录
│ ├── algorithms/ # 算法实现
│ │ ├── perturbation_engine.py # 对抗性扰动引擎
│ │ └── evaluation_engine.py # 评估引擎
│ ├── controllers/ # 控制器(路由处理)
│ │ ├── auth_controller.py # 认证控制器
│ │ ├── user_controller.py # 用户配置控制器
│ │ ├── task_controller.py # 任务控制器
| | ├── demo_controller.py # 首页示例控制器
│ │ ├── image_controller.py # 图像控制器
│ │ └── admin_controller.py # 管理员控制器
│ ├── models/ # 数据模型
│ │ └── __init__.py # SQLAlchemy模型定义
│ ├── services/ # 业务逻辑服务
│ │ ├── auth_service.py # 认证服务
│ │ ├── task_service.py # 任务处理服务
│ │ └── image_service.py # 图像处理服务
│ └── utils/ # 工具类
│ └── file_utils.py # 文件处理工具
├── config/ # 配置文件
│ └── settings.py # 应用配置
├── uploads/ # 文件上传目录
├── static/ # 静态文件
│ ├── originals/ # 重命名后的原始图片
│ ├── perturbed/ # 加噪后的图片
│ ├── model_outputs/ # 模型生成的图片
│ │ ├── clean/ # 原图的模型生成结果
│ │ └── perturbed/ # 加噪图的模型生成结果
│ ├── heatmaps/ # 热力图
│ └── demo/ # 演示图片
│ ├── original/ # 演示原始图片
│ ├── perturbed/ # 演示加噪图片
│ └── comparisons/ # 演示对比图
├── app.py # Flask应用工厂
├── run.py # 启动脚本
├── init_db.py # 数据库初始化脚本
└── requirements.txt # Python依赖
```
## 功能特性
### 用户功能
- ✅ 用户注册(邮箱验证,同一邮箱只能注册一次)
- ✅ 用户登录/登出
- ✅ 密码修改
- ✅ 任务创建和管理
- ✅ 图片上传(单张/压缩包批量)
- ✅ 加噪处理4种算法SimAC、CAAT、PID、ASPL
- ✅ 扰动强度自定义
- ✅ 防净化版本选择
- ✅ 智能配置记忆:自动保存用户上次选择的配置
- ✅ 处理结果下载
- ✅ 图片质量对比查看FID、LPIPS、SSIM、PSNR、热力图
- ✅ 模型生成对比分析
- ✅ 预设演示图片浏览
### 管理员功能
- ✅ 用户管理(增删改查)
- ✅ 系统统计信息查看
### 算法实现
- ✅ ASPL算法虚拟实现原始版本 + 防净化版本)
- ✅ SimAC算法虚拟实现原始版本 + 防净化版本)
- ✅ CAAT算法虚拟实现原始版本 + 防净化版本)
- ✅ PID算法虚拟实现原始版本 + 防净化版本)
- ✅ 图像质量评估指标计算
- ✅ 模型生成效果对比
- ✅ 热力图生成
## 安装和运行
### 1. 环境准备
#### 使用虚拟环境(推荐)
**为什么需要虚拟环境?**
- ✅ **避免依赖冲突**:不同项目使用不同版本的包
- ✅ **环境隔离**不污染系统Python环境
- ✅ **版本一致性**:确保团队环境统一
- ✅ **易于管理**:可以随时删除重建
```bash
# 创建虚拟环境
python -m venv venv
# 激活虚拟环境
# Windows:
venv\\Scripts\\activate
# Linux/Mac:
source venv/bin/activate
# 更新pip推荐
python -m pip install --upgrade pip
# 安装依赖
pip install -r requirements.txt
```
### 2. 数据库配置
确保已安装MySQL数据库并创建数据库。
修改 `config/.env` 中的数据库连接配置:
### 3. 初始化数据库
```bash
# 运行数据库初始化脚本
python init_db.py
```
### 4. 启动应用
```bash
# 开发模式启动
python run.py
# 或者使用Flask命令
flask run
```
应用将在 `http://localhost:5000` 启动
### 5. 系统测试
访问 `http://localhost:5000/static/test.html` 进入功能测试页面:
## API接口文档
### 认证接口 (`/api/auth`)
- `POST /register` - 用户注册
- `POST /login` - 用户登录
- `POST /change-password` - 修改密码
- `GET /profile` - 获取用户信息
- `POST /logout` - 用户登出
### 任务管理 (`/api/task`)
- `POST /create` - 创建任务(使用默认配置)
- `POST /upload/<batch_id>` - 上传图片到指定任务
- `GET /<batch_id>/config` - 获取任务配置(显示用户上次选择)
- `PUT /<batch_id>/config` - 更新任务配置(自动保存为用户偏好)
- `GET /load-config` - 加载用户上次配置
- `POST /save-config` - 保存用户配置偏好
- `POST /start/<batch_id>` - 开始处理任务
- `GET /list` - 获取任务列表
- `GET /<batch_id>` - 获取任务详情
- `GET /<batch_id>/status` - 获取处理状态
### 图片管理 (`/api/image`)
- `GET /file/<image_id>` - 查看图片
- `GET /download/<image_id>` - 下载图片
- `GET /batch/<batch_id>/download` - 批量下载
- `GET /<image_id>/evaluations` - 获取评估结果
- `POST /compare` - 对比图片
- `GET /heatmap/<path>` - 获取热力图
- `DELETE /delete/<image_id>` - 删除图片
### 用户设置 (`/api/user`)
- `GET /config` - 获取用户配置(已弃用,配置集成到任务流程中)
- `PUT /config` - 更新用户配置(已弃用,通过任务配置自动保存)
- `GET /algorithms` - 获取可用算法(动态从数据库加载)
- `GET /stats` - 获取用户统计
### 管理员功能 (`/api/admin`)
- `GET /users` - 用户列表
- `GET /users/<user_id>` - 用户详情
- `POST /users` - 创建用户
- `PUT /users/<user_id>` - 更新用户
- `DELETE /users/<user_id>` - 删除用户
- `GET /stats` - 系统统计
### 演示功能 (`/api/demo`)
- `GET /images` - 获取演示图片列表
- `GET /image/original/<filename>` - 获取演示原始图片
- `GET /image/perturbed/<filename>` - 获取演示加噪图片
- `GET /image/comparison/<filename>` - 获取演示对比图片
- `GET /algorithms` - 获取算法演示信息
- `GET /stats` - 获取演示统计数据
## 默认账户
系统初始化后会创建3个管理员账户
- 用户名:`admin1`, `admin2`, `admin3`
- 默认密码:`admin123`
- 邮箱:`admin1@museguard.com` 等
## 技术栈
- **Web框架**: Flask 2.3.3
- **数据库ORM**: SQLAlchemy 3.0.5
- **数据库**: MySQL通过PyMySQL连接
- **认证**: JWT (Flask-JWT-Extended)
- **跨域**: Flask-CORS
- **图像处理**: Pillow + NumPy
- **数学计算**: NumPy
## 开发说明
### 虚拟实现说明
当前所有算法都是**虚拟实现**,用于框架搭建和测试:
1. **对抗性扰动算法**: 使用随机噪声模拟真实算法效果
2. **评估指标**: 基于像素差异的简化计算
3. **模型生成**: 通过图像变换模拟DreamBooth/LoRA效果
### 扩展指南
要集成真实算法:
1. 替换 `app/algorithms/perturbation_engine.py` 中的虚拟实现
2. 替换 `app/algorithms/evaluation_engine.py` 中的评估计算
3. 根据需要调整配置参数
### 目录权限
确保以下目录有写入权限:
- `uploads/` - 用户上传文件
- `static/originals/` - 重命名后的原始图片
- `static/perturbed/` - 加噪后的图片
- `static/model_outputs/` - 模型生成的图片
- `static/heatmaps/` - 热力图文件
- `static/demo/` - 演示图片(需要手动添加演示文件)
## 许可证
本项目仅用于学习和研究目的。
# MuseGuard 后端框架
基于对抗性扰动的多风格图像生成防护系统 - 后端API服务
## Linux 环境配置MySQL、Redis、Python等
### 1. 安装系统依赖
```bash
sudo apt update
sudo apt install -y build-essential python3 python3-venv python3-pip git
```
### 2. 安装 MySQL
老版使用
```bash
# 启动 Redis
sudo service mysqld start
# 停止 Redis
sudo service mysqld stop
# 重启 Redis
sudo service mysqld restart
# 查看 Redis 状态
sudo service mysqld status
```
```bash
sudo apt install -y mysql-server
sudo systemctl enable mysql
sudo systemctl start mysql
# 安全初始化可选建议设置root密码
sudo mysql_secure_installation
# 登录MySQL创建数据库和用户
mysql -u root -p
```
在MySQL命令行中执行
```sql
CREATE DATABASE museguard DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
# CREATE USER 'museguard'@'localhost' IDENTIFIED BY 'yourpassword';
# GRANT ALL PRIVILEGES ON museguard.* TO 'museguard'@'localhost';
# FLUSH PRIVILEGES;
EXIT;
```
### 3. 安装 Redis
老版使用service命令
```bash
# 启动 Redis
sudo service redis-server start
# 停止 Redis
sudo service redis-server stop
# 重启 Redis
sudo service redis-server restart
# 查看 Redis 状态
sudo service redis-server status
```
```bash
sudo apt install -y redis-server
sudo systemctl enable redis-server
sudo systemctl start redis-server
# 测试
redis-cli ping
# 返回PONG表示正常
```
### 4. Python 虚拟环境与依赖
在Linux换成conda管理
```bash
cd /path/to/your/project/src/backend
python3 -m venv venv
source venv/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
```
### 5. 配置数据库连接
编辑 `config/settings.py``setting.env` 文件,设置如下内容:
```python
DROP DATABASE IF EXISTS muse_guard;
CREATE DATABASE muse_guard DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
EXIT;
SQLALCHEMY_DATABASE_URI = 'mysql+pymysql://museguard:yourpassword@localhost:3306/museguard?charset=utf8mb4'
REDIS_URL = 'redis://localhost:6379/0'
```
### 6. 初始化数据库
```bash
python init_db.py
```
### 7. 启动服务
```bash
# 启动Flask后端
python run.py
# 启动RQ Worker另开终端
source venv/bin/activate
cd /path/to/your/project/src/backend
rq worker museguard
```
---
## 项目结构
```
backend/
├── app/ # 主应用目录
│ ├── algorithms/ # 算法实现
│ │ ├── perturbation_engine.py # 对抗性扰动引擎
│ │ └── evaluation_engine.py # 评估引擎
│ ├── controllers/ # 控制器(路由处理)
│ │ ├── auth_controller.py # 认证控制器
│ │ ├── user_controller.py # 用户配置控制器
│ │ ├── task_controller.py # 任务控制器
| | ├── demo_controller.py # 首页示例控制器
│ │ ├── image_controller.py # 图像控制器
│ │ └── admin_controller.py # 管理员控制器
│ ├── models/ # 数据模型
│ │ └── __init__.py # SQLAlchemy模型定义
│ ├── services/ # 业务逻辑服务
│ │ ├── auth_service.py # 认证服务
│ │ ├── task_service.py # 任务处理服务
│ │ └── image_service.py # 图像处理服务
│ └── utils/ # 工具类
│ └── file_utils.py # 文件处理工具
├── config/ # 配置文件
│ └── settings.py # 应用配置
├── uploads/ # 文件上传目录
├── static/ # 静态文件
│ ├── originals/ # 重命名后的原始图片
│ ├── perturbed/ # 加噪后的图片
│ ├── model_outputs/ # 模型生成的图片
│ │ ├── clean/ # 原图的模型生成结果
│ │ └── perturbed/ # 加噪图的模型生成结果
│ ├── heatmaps/ # 热力图
│ └── demo/ # 演示图片
│ ├── original/ # 演示原始图片
│ ├── perturbed/ # 演示加噪图片
│ └── comparisons/ # 演示对比图
├── app.py # Flask应用工厂
├── run.py # 启动脚本
├── init_db.py # 数据库初始化脚本
└── requirements.txt # Python依赖
```
## 功能特性
### 用户功能
- ✅ 用户注册(邮箱验证,同一邮箱只能注册一次)
- ✅ 用户登录/登出
- ✅ 密码修改
- ✅ 任务创建和管理
- ✅ 图片上传(单张/压缩包批量)
- ✅ 加噪处理4种算法SimAC、CAAT、PID、ASPL
- ✅ 扰动强度自定义
- ✅ 防净化版本选择
- ✅ 智能配置记忆:自动保存用户上次选择的配置
- ✅ 处理结果下载
- ✅ 图片质量对比查看FID、LPIPS、SSIM、PSNR、热力图
- ✅ 模型生成对比分析
- ✅ 预设演示图片浏览
### 管理员功能
- ✅ 用户管理(增删改查)
- ✅ 系统统计信息查看
### 算法实现
- ✅ ASPL算法虚拟实现原始版本 + 防净化版本)
- ✅ SimAC算法虚拟实现原始版本 + 防净化版本)
- ✅ CAAT算法虚拟实现原始版本 + 防净化版本)
- ✅ PID算法虚拟实现原始版本 + 防净化版本)
- ✅ 图像质量评估指标计算
- ✅ 模型生成效果对比
- ✅ 热力图生成
## 安装和运行
### 1. 环境准备
#### 使用虚拟环境(推荐)
**为什么需要虚拟环境?**
- ✅ **避免依赖冲突**:不同项目使用不同版本的包
- ✅ **环境隔离**不污染系统Python环境
- ✅ **版本一致性**:确保团队环境统一
- ✅ **易于管理**:可以随时删除重建
```bash
# 创建虚拟环境
python -m venv venv
# 激活虚拟环境
# Windows:
venv\\Scripts\\activate
# Linux/Mac:
source venv/bin/activate
# 更新pip推荐
python -m pip install --upgrade pip
# 安装依赖
pip install -r requirements.txt
```
### 2. 数据库配置
确保已安装MySQL数据库并创建数据库。
修改 `config/.env` 中的数据库连接配置:
### 3. 初始化数据库
```bash
# 运行数据库初始化脚本
python init_db.py
```
### 4. 启动应用
```bash
# 开发模式启动
python run.py
# 或者使用Flask命令
flask run
```
应用将在 `http://localhost:5000` 启动
### 5. 系统测试
访问 `http://localhost:5000/static/test.html` 进入功能测试页面:
## API接口文档
### 认证接口 (`/api/auth`)
- `POST /register` - 用户注册
- `POST /login` - 用户登录
- `POST /change-password` - 修改密码
- `GET /profile` - 获取用户信息
- `POST /logout` - 用户登出
### 任务管理 (`/api/task`)
- `POST /create` - 创建任务(使用默认配置)
- `POST /upload/<batch_id>` - 上传图片到指定任务
- `GET /<batch_id>/config` - 获取任务配置(显示用户上次选择)
- `PUT /<batch_id>/config` - 更新任务配置(自动保存为用户偏好)
- `GET /load-config` - 加载用户上次配置
- `POST /save-config` - 保存用户配置偏好
- `POST /start/<batch_id>` - 开始处理任务
- `GET /list` - 获取任务列表
- `GET /<batch_id>` - 获取任务详情
- `GET /<batch_id>/status` - 获取处理状态
### 图片管理 (`/api/image`)
- `GET /file/<image_id>` - 查看图片
- `GET /download/<image_id>` - 下载图片
- `GET /batch/<batch_id>/download` - 批量下载
- `GET /<image_id>/evaluations` - 获取评估结果
- `POST /compare` - 对比图片
- `GET /heatmap/<path>` - 获取热力图
- `DELETE /delete/<image_id>` - 删除图片
### 用户设置 (`/api/user`)
- `GET /config` - 获取用户配置(已弃用,配置集成到任务流程中)
- `PUT /config` - 更新用户配置(已弃用,通过任务配置自动保存)
- `GET /algorithms` - 获取可用算法(动态从数据库加载)
- `GET /stats` - 获取用户统计
### 管理员功能 (`/api/admin`)
- `GET /users` - 用户列表
- `GET /users/<user_id>` - 用户详情
- `POST /users` - 创建用户
- `PUT /users/<user_id>` - 更新用户
- `DELETE /users/<user_id>` - 删除用户
- `GET /stats` - 系统统计
### 演示功能 (`/api/demo`)
- `GET /images` - 获取演示图片列表
- `GET /image/original/<filename>` - 获取演示原始图片
- `GET /image/perturbed/<filename>` - 获取演示加噪图片
- `GET /image/comparison/<filename>` - 获取演示对比图片
- `GET /algorithms` - 获取算法演示信息
- `GET /stats` - 获取演示统计数据
## 默认账户
系统初始化后会创建3个管理员账户
- 用户名:`admin1`, `admin2`, `admin3`
- 默认密码:`admin123`
- 邮箱:`admin1@museguard.com` 等
## 技术栈
- **Web框架**: Flask 2.3.3
- **数据库ORM**: SQLAlchemy 3.0.5
- **数据库**: MySQL通过PyMySQL连接
- **认证**: JWT (Flask-JWT-Extended)
- **跨域**: Flask-CORS
- **图像处理**: Pillow + NumPy
- **数学计算**: NumPy
## 开发说明
### 虚拟实现说明
当前所有算法都是**虚拟实现**,用于框架搭建和测试:
1. **对抗性扰动算法**: 使用随机噪声模拟真实算法效果
2. **评估指标**: 基于像素差异的简化计算
3. **模型生成**: 通过图像变换模拟DreamBooth/LoRA效果
### 扩展指南
要集成真实算法:
1. 替换 `app/algorithms/perturbation_engine.py` 中的虚拟实现
2. 替换 `app/algorithms/evaluation_engine.py` 中的评估计算
3. 根据需要调整配置参数
### 目录权限
确保以下目录有写入权限:
- `uploads/` - 用户上传文件
- `static/originals/` - 重命名后的原始图片
- `static/perturbed/` - 加噪后的图片
- `static/model_outputs/` - 模型生成的图片
- `static/heatmaps/` - 热力图文件
- `static/demo/` - 演示图片(需要手动添加演示文件)
## 许可证
本项目仅用于学习和研究目的。
https://docs.pingcode.com/baike/2645380
功能流程正确(本地)
- 测试网页
- 配置正确加载
- 微调算法执行时机
云端正常调用算法
算法正常执行
云端部署,本地可直接访问
api规范
前端对接
conda activate flask
pip install accelerate
conda install -c conda-forge accelerate

@ -1,46 +1,46 @@
"""
MuseGuard 后端主应用入口
基于对抗性扰动的多风格图像生成防护系统
"""
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager
from flask_cors import CORS
from config.settings import Config
# 初始化扩展
db = SQLAlchemy()
migrate = Migrate()
jwt = JWTManager()
def create_app(config_class=Config):
"""Flask应用工厂函数"""
app = Flask(__name__)
app.config.from_object(config_class)
# 初始化扩展
db.init_app(app)
migrate.init_app(app, db)
jwt.init_app(app)
CORS(app)
# 注册蓝图
from app.controllers.auth_controller import auth_bp
from app.controllers.user_controller import user_bp
from app.controllers.task_controller import task_bp
from app.controllers.image_controller import image_bp
from app.controllers.admin_controller import admin_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(user_bp, url_prefix='/api/user')
app.register_blueprint(task_bp, url_prefix='/api/task')
app.register_blueprint(image_bp, url_prefix='/api/image')
app.register_blueprint(admin_bp, url_prefix='/api/admin')
return app
if __name__ == '__main__':
app = create_app()
app.run(debug=True, host='0.0.0.0', port=5000)
"""
MuseGuard 后端主应用入口
基于对抗性扰动的多风格图像生成防护系统
"""
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager
from flask_cors import CORS
from config.settings import Config
# 初始化扩展
db = SQLAlchemy()
migrate = Migrate()
jwt = JWTManager()
def create_app(config_class=Config):
"""Flask应用工厂函数"""
app = Flask(__name__)
app.config.from_object(config_class)
# 初始化扩展
db.init_app(app)
migrate.init_app(app, db)
jwt.init_app(app)
CORS(app)
# 注册蓝图
from app.controllers.auth_controller import auth_bp
from app.controllers.user_controller import user_bp
from app.controllers.task_controller import task_bp
from app.controllers.image_controller import image_bp
from app.controllers.admin_controller import admin_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(user_bp, url_prefix='/api/user')
app.register_blueprint(task_bp, url_prefix='/api/task')
app.register_blueprint(image_bp, url_prefix='/api/image')
app.register_blueprint(admin_bp, url_prefix='/api/admin')
return app
if __name__ == '__main__':
app = create_app()
app.run(debug=True, host='0.0.0.0', port=6006)

@ -1,83 +1,83 @@
"""
MuseGuard 后端应用包
基于对抗性扰动的多风格图像生成防护系统
"""
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager
from flask_cors import CORS
import os
# 初始化扩展
db = SQLAlchemy()
migrate = Migrate()
jwt = JWTManager()
cors = CORS()
def create_app(config_name=None):
"""应用工厂函数"""
# 配置静态文件和模板文件路径
app = Flask(__name__,
static_folder='../static',
static_url_path='/static')
# 加载配置
if config_name is None:
config_name = os.environ.get('FLASK_ENV', 'development')
from config.settings import config
app.config.from_object(config[config_name])
# 初始化扩展
db.init_app(app)
migrate.init_app(app, db)
jwt.init_app(app)
cors.init_app(app)
# 注册蓝图
from app.controllers.auth_controller import auth_bp
from app.controllers.user_controller import user_bp
from app.controllers.task_controller import task_bp
from app.controllers.image_controller import image_bp
from app.controllers.admin_controller import admin_bp
from app.controllers.demo_controller import demo_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(user_bp, url_prefix='/api/user')
app.register_blueprint(task_bp, url_prefix='/api/task')
app.register_blueprint(image_bp, url_prefix='/api/image')
app.register_blueprint(admin_bp, url_prefix='/api/admin')
app.register_blueprint(demo_bp, url_prefix='/api/demo')
# 注册错误处理器
@app.errorhandler(404)
def not_found_error(error):
return {'error': 'Not found'}, 404
@app.errorhandler(500)
def internal_error(error):
db.session.rollback()
return {'error': 'Internal server error'}, 500
# 根路由
@app.route('/')
def index():
return {
'message': 'MuseGuard API Server',
'version': '1.0.0',
'status': 'running',
'endpoints': {
'health': '/health',
'api_docs': '/api',
'test_page': '/static/test.html'
}
}
# 健康检查端点
@app.route('/health')
def health_check():
return {'status': 'healthy', 'message': 'MuseGuard backend is running'}
"""
MuseGuard 后端应用包
基于对抗性扰动的多风格图像生成防护系统
"""
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager
from flask_cors import CORS
import os
# 初始化扩展
db = SQLAlchemy()
migrate = Migrate()
jwt = JWTManager()
cors = CORS()
def create_app(config_name=None):
"""应用工厂函数"""
# 配置静态文件和模板文件路径
app = Flask(__name__,
static_folder='../static',
static_url_path='/static')
# 加载配置
if config_name is None:
config_name = os.environ.get('FLASK_ENV', 'development')
from config.settings import config
app.config.from_object(config[config_name])
# 初始化扩展
db.init_app(app)
migrate.init_app(app, db)
jwt.init_app(app)
cors.init_app(app)
# 注册蓝图
from app.controllers.auth_controller import auth_bp
from app.controllers.user_controller import user_bp
from app.controllers.task_controller import task_bp
from app.controllers.image_controller import image_bp
from app.controllers.admin_controller import admin_bp
from app.controllers.demo_controller import demo_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(user_bp, url_prefix='/api/user')
app.register_blueprint(task_bp, url_prefix='/api/task')
app.register_blueprint(image_bp, url_prefix='/api/image')
app.register_blueprint(admin_bp, url_prefix='/api/admin')
app.register_blueprint(demo_bp, url_prefix='/api/demo')
# 注册错误处理器
@app.errorhandler(404)
def not_found_error(error):
return {'error': 'Not found'}, 404
@app.errorhandler(500)
def internal_error(error):
db.session.rollback()
return {'error': 'Internal server error'}, 500
# 根路由
@app.route('/')
def index():
return {
'message': 'MuseGuard API Server',
'version': '1.0.0',
'status': 'running',
'endpoints': {
'health': '/health',
'api_docs': '/api',
'test_page': '/static/test.html'
}
}
# 健康检查端点
@app.route('/health')
def health_check():
return {'status': 'healthy', 'message': 'MuseGuard backend is running'}
return app

@ -1,300 +0,0 @@
"""
评估引擎
实现图像质量评估和模型生成对比的虚拟版本
"""
import os
import numpy as np
from PIL import Image, ImageDraw
import uuid
import random
from flask import current_app
class EvaluationEngine:
"""评估处理引擎"""
@staticmethod
def evaluate_image_quality(original_path, perturbed_path):
"""
评估图像质量指标
Args:
original_path: 原始图片路径
perturbed_path: 扰动后图片路径
Returns:
包含各种评估指标的字典
"""
try:
# 加载图片
with Image.open(original_path) as orig_img, Image.open(perturbed_path) as pert_img:
# 确保图片尺寸一致
if orig_img.size != pert_img.size:
pert_img = pert_img.resize(orig_img.size)
# 转换为RGB
if orig_img.mode != 'RGB':
orig_img = orig_img.convert('RGB')
if pert_img.mode != 'RGB':
pert_img = pert_img.convert('RGB')
# 计算虚拟评估指标
fid_score = EvaluationEngine._calculate_mock_fid(orig_img, pert_img)
lpips_score = EvaluationEngine._calculate_mock_lpips(orig_img, pert_img)
ssim_score = EvaluationEngine._calculate_mock_ssim(orig_img, pert_img)
psnr_score = EvaluationEngine._calculate_mock_psnr(orig_img, pert_img)
# 生成热力图
heatmap_path = EvaluationEngine._generate_difference_heatmap(
orig_img, pert_img, "quality"
)
return {
'fid': fid_score,
'lpips': lpips_score,
'ssim': ssim_score,
'psnr': psnr_score,
'heatmap_path': heatmap_path
}
except Exception as e:
print(f"图像质量评估失败: {str(e)}")
return {
'fid': None,
'lpips': None,
'ssim': None,
'psnr': None,
'heatmap_path': None
}
@staticmethod
def evaluate_model_generation(original_path, perturbed_path, finetune_method):
"""
评估模型生成对比
Args:
original_path: 原始图片路径
perturbed_path: 扰动后图片路径
finetune_method: 微调方法 (dreambooth, lora)
Returns:
包含生成对比评估指标的字典
"""
try:
# 模拟训练过程并生成对比图片
with Image.open(original_path) as orig_img, Image.open(perturbed_path) as pert_img:
# 模拟原始图片训练后的生成结果
orig_generated = EvaluationEngine._simulate_model_generation(orig_img, finetune_method, False)
# 模拟扰动图片训练后的生成结果
pert_generated = EvaluationEngine._simulate_model_generation(pert_img, finetune_method, True)
# 计算生成质量对比指标
fid_score = EvaluationEngine._calculate_generation_fid(orig_generated, pert_generated)
lpips_score = EvaluationEngine._calculate_generation_lpips(orig_generated, pert_generated)
ssim_score = EvaluationEngine._calculate_generation_ssim(orig_generated, pert_generated)
psnr_score = EvaluationEngine._calculate_generation_psnr(orig_generated, pert_generated)
# 生成对比热力图
heatmap_path = EvaluationEngine._generate_difference_heatmap(
orig_generated, pert_generated, "generation"
)
return {
'fid': fid_score,
'lpips': lpips_score,
'ssim': ssim_score,
'psnr': psnr_score,
'heatmap_path': heatmap_path
}
except Exception as e:
print(f"模型生成评估失败: {str(e)}")
return {
'fid': None,
'lpips': None,
'ssim': None,
'psnr': None,
'heatmap_path': None
}
@staticmethod
def _calculate_mock_fid(img1, img2):
"""模拟FID计算"""
# 简单的像素差异统计作为FID近似
arr1 = np.array(img1, dtype=np.float32)
arr2 = np.array(img2, dtype=np.float32)
mse = np.mean((arr1 - arr2) ** 2)
# FID通常在0-300之间扰动图片应该有较小的差异
mock_fid = min(mse / 10.0 + random.uniform(0.5, 2.0), 50.0)
return round(mock_fid, 4)
@staticmethod
def _calculate_mock_lpips(img1, img2):
"""模拟LPIPS计算"""
arr1 = np.array(img1, dtype=np.float32) / 255.0
arr2 = np.array(img2, dtype=np.float32) / 255.0
# 简单的感知差异模拟
diff = np.abs(arr1 - arr2)
# 模拟感知权重(边缘区域权重更高)
gray1 = np.mean(arr1, axis=2)
edges = np.abs(np.gradient(gray1)[0]) + np.abs(np.gradient(gray1)[1])
weighted_diff = np.mean(diff * (1 + edges.reshape(edges.shape + (1,))))
mock_lpips = min(weighted_diff + random.uniform(0.001, 0.01), 1.0)
return round(mock_lpips, 4)
@staticmethod
def _calculate_mock_ssim(img1, img2):
"""模拟SSIM计算"""
arr1 = np.array(img1, dtype=np.float32)
arr2 = np.array(img2, dtype=np.float32)
# 简单的结构相似性模拟
mu1 = np.mean(arr1)
mu2 = np.mean(arr2)
sigma1_sq = np.var(arr1)
sigma2_sq = np.var(arr2)
sigma12 = np.mean((arr1 - mu1) * (arr2 - mu2))
c1 = (0.01 * 255) ** 2
c2 = (0.03 * 255) ** 2
ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / \
((mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2))
# SSIM在-1到1之间但通常在0.8-1.0之间为良好
mock_ssim = max(0.85 + random.uniform(-0.05, 0.1), 0.0)
return round(mock_ssim, 4)
@staticmethod
def _calculate_mock_psnr(img1, img2):
"""模拟PSNR计算"""
arr1 = np.array(img1, dtype=np.float32)
arr2 = np.array(img2, dtype=np.float32)
mse = np.mean((arr1 - arr2) ** 2)
if mse == 0:
return 100.0
max_pixel = 255.0
psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
# 对于轻微扰动PSNR应该比较高30-50dB
mock_psnr = max(min(psnr + random.uniform(-2, 2), 50.0), 20.0)
return round(mock_psnr, 4)
@staticmethod
def _simulate_model_generation(img, finetune_method, is_perturbed):
"""模拟模型生成过程"""
# 对输入图像进行变换以模拟生成结果
arr = np.array(img, dtype=np.float32)
if finetune_method == "dreambooth":
# DreamBooth风格变换
if is_perturbed:
# 扰动图片训练的模型生成质量较差
noise_level = 15.0
style_change = 0.3
else:
# 原始图片训练的模型生成质量较好
noise_level = 5.0
style_change = 0.1
elif finetune_method == "lora":
# LoRA风格变换
if is_perturbed:
noise_level = 12.0
style_change = 0.25
else:
noise_level = 3.0
style_change = 0.08
else:
noise_level = 8.0
style_change = 0.15
# 添加噪声模拟生成质量
noise = np.random.normal(0, noise_level, arr.shape)
generated_arr = arr + noise
# 模拟风格变化
if style_change > 0:
# 简单的颜色偏移
generated_arr[:,:,0] *= (1 + style_change * random.uniform(-0.5, 0.5))
generated_arr[:,:,1] *= (1 + style_change * random.uniform(-0.5, 0.5))
generated_arr[:,:,2] *= (1 + style_change * random.uniform(-0.5, 0.5))
generated_arr = np.clip(generated_arr, 0, 255).astype(np.uint8)
return Image.fromarray(generated_arr)
@staticmethod
def _calculate_generation_fid(img1, img2):
"""计算生成图片的FID质量越差FID越高越好"""
fid = EvaluationEngine._calculate_mock_fid(img1, img2)
# 对于防护效果FID越高越好
return max(fid, 10.0)
@staticmethod
def _calculate_generation_lpips(img1, img2):
"""计算生成图片的LPIPS差异越大越好"""
lpips = EvaluationEngine._calculate_mock_lpips(img1, img2)
return max(lpips, 0.1)
@staticmethod
def _calculate_generation_ssim(img1, img2):
"""计算生成图片的SSIM相似度越低越好"""
ssim = EvaluationEngine._calculate_mock_ssim(img1, img2)
# 对于防护效果SSIM越低越好
return min(ssim, 0.9)
@staticmethod
def _calculate_generation_psnr(img1, img2):
"""计算生成图片的PSNR质量差异越大越好"""
psnr = EvaluationEngine._calculate_mock_psnr(img1, img2)
return min(psnr, 35.0)
@staticmethod
def _generate_difference_heatmap(img1, img2, eval_type):
"""生成差异热力图"""
try:
arr1 = np.array(img1, dtype=np.float32)
arr2 = np.array(img2, dtype=np.float32)
# 计算像素差异
diff = np.abs(arr1 - arr2)
diff_gray = np.mean(diff, axis=2)
# 归一化到0-255
diff_normalized = (diff_gray / diff_gray.max() * 255).astype(np.uint8)
# 创建热力图
heatmap = Image.fromarray(diff_normalized, mode='L')
heatmap = heatmap.convert('RGB')
# 应用颜色映射(蓝色-绿色-红色)
heatmap_array = np.array(heatmap)
colored_heatmap = np.zeros_like(arr1, dtype=np.uint8)
# 简单的颜色映射
intensity = heatmap_array[:,:,0] / 255.0
colored_heatmap[:,:,0] = (intensity * 255).astype(np.uint8) # 红色通道
colored_heatmap[:,:,1] = ((1 - intensity) * intensity * 4 * 255).astype(np.uint8) # 绿色通道
colored_heatmap[:,:,2] = ((1 - intensity) * 255).astype(np.uint8) # 蓝色通道
# 保存热力图
heatmap_dir = os.path.join('static', 'heatmaps')
os.makedirs(heatmap_dir, exist_ok=True)
heatmap_filename = f"{eval_type}_heatmap_{uuid.uuid4().hex[:8]}.png"
heatmap_path = os.path.join(heatmap_dir, heatmap_filename)
Image.fromarray(colored_heatmap).save(heatmap_path)
return heatmap_path
except Exception as e:
print(f"生成热力图失败: {str(e)}")
return None

@ -0,0 +1,87 @@
import argparse
import os
import torch
from diffusers import StableDiffusionPipeline
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from PIL import Image
import numpy as np
from einops import rearrange
parser = argparse.ArgumentParser(description="Inference")
parser.add_argument(
"--model_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--output_dir",
type=str,
default="./test-infer/",
help="The output directory where predictions are saved",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The output directory where predictions are saved",
)
parser.add_argument(
"--v",
type=str,
default="sks",
help="The output directory where predictions are saved",
)
args = parser.parse_args()
if __name__ == "__main__":
seed_everything(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
# define prompts
prompts = [
f"a photo of {args.v} person",
f"a dslr portrait of {args.v} person",
f"a photo of {args.v} person looking at the mirror",
f"a photo of {args.v} person in front of eiffel tower",
]
# create & load model
pipe = StableDiffusionPipeline.from_pretrained(
args.model_path,
torch_dtype=torch.float32,
safety_checker=None,
local_files_only=True,
).to("cuda")
for prompt in prompts:
print(">>>>>>", prompt)
norm_prompt = prompt.lower().replace(",", "").replace(" ", "_")
out_path = f"{args.output_dir}/{norm_prompt}"
os.makedirs(out_path, exist_ok=True)
all_samples = list()
for i in range(5):
images = pipe([prompt] * 6, num_inference_steps=100, guidance_scale=7.5,).images
for idx, image in enumerate(images):
image.save(f"{out_path}/{i}_{idx}.png")
image = np.array(image, dtype=np.float32)
image /= 255.0
image = np.transpose(image, (2, 0, 1))
image = torch.from_numpy(image) # numpy->tensor
all_samples.append(image)
grid = torch.stack(all_samples, 0)
# grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=8)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))
img.save(f"{args.output_dir}/{prompt}.png")
torch.cuda.empty_cache()
del pipe
torch.cuda.empty_cache()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,124 @@
"""
DreamBooth微调虚拟实现
用于测试后端流程不执行实际的模型训练
"""
import argparse
import os
import sys
import platform
import shutil
import glob
def create_generated_image(source_image_path, output_path, index):
"""创建一个模拟生成的图片(简单复制源图片)"""
shutil.copy2(source_image_path, output_path)
def main():
parser = argparse.ArgumentParser(description="DreamBooth虚拟微调脚本")
parser.add_argument('--pretrained_model_name_or_path', required=True)
parser.add_argument('--instance_data_dir', required=True)
parser.add_argument('--class_data_dir', required=True)
parser.add_argument('--output_dir', required=True)
parser.add_argument('--validation_image_output_dir', required=True)
parser.add_argument('--with_prior_preservation', action='store_true')
parser.add_argument('--prior_loss_weight', type=float, default=1.0)
parser.add_argument('--instance_prompt', default='a photo of sks person')
parser.add_argument('--class_prompt', default='a photo of person')
parser.add_argument('--resolution', type=int, default=512)
parser.add_argument('--train_batch_size', type=int, default=1)
parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=2e-6)
parser.add_argument('--lr_scheduler', default='constant')
parser.add_argument('--lr_warmup_steps', type=int, default=0)
parser.add_argument('--num_class_images', type=int, default=200)
parser.add_argument('--max_train_steps', type=int, default=1000)
parser.add_argument('--checkpointing_steps', type=int, default=500)
parser.add_argument('--center_crop', action='store_true')
parser.add_argument('--mixed_precision', default='bf16')
parser.add_argument('--prior_generation_precision', default='bf16')
parser.add_argument('--sample_batch_size', type=int, default=5)
parser.add_argument('--validation_prompt', default='a photo of sks person')
parser.add_argument('--num_validation_images', type=int, default=10)
parser.add_argument('--validation_steps', type=int, default=500)
parser.add_argument('--is_perturbed', action='store_true', help='Whether training on perturbed images')
args = parser.parse_args()
print("=" * 80)
print("[VIRTUAL DREAMBOOTH] 虚拟微调执行开始")
print("=" * 80)
print(f"[VIRTUAL] 微调方法: DreamBooth")
print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}")
print(f"[VIRTUAL] Python可执行文件: {sys.executable}")
print(f"[VIRTUAL] Python版本: {platform.python_version()}")
print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}")
print("-" * 80)
print("[VIRTUAL] 微调参数:")
print(f" - 模型路径: {args.pretrained_model_name_or_path}")
print(f" - 实例数据目录: {args.instance_data_dir}")
print(f" - 模型输出目录: {args.output_dir}")
print(f" - 验证图片输出目录: {args.validation_image_output_dir}")
print(f" - 类别目录: {args.class_data_dir}")
print(f" - 分辨率: {args.resolution}")
print(f" - 最大训练步数: {args.max_train_steps}")
print(f" - 学习率: {args.learning_rate}")
print(f" - 验证图片数量: {args.num_validation_images}")
print("-" * 80)
# 确保输出目录存在
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.validation_image_output_dir, exist_ok=True)
# 获取训练图片
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(args.instance_data_dir, ext)))
image_files.extend(glob.glob(os.path.join(args.instance_data_dir, ext.upper())))
print(f"[VIRTUAL] 找到 {len(image_files)} 张训练图片")
# 模拟训练过程
print("-" * 80)
print("[VIRTUAL] 开始模拟训练...")
for step in range(0, args.max_train_steps, args.checkpointing_steps):
print(f"[VIRTUAL] 训练步数: {step}/{args.max_train_steps}")
print(f"[VIRTUAL] 训练步数: {args.max_train_steps}/{args.max_train_steps} (完成)")
# 生成验证图片(从训练图片复制并添加标记)
print("-" * 80)
print(f"[VIRTUAL] 生成 {args.num_validation_images} 张验证图片...")
# 根据is_perturbed决定文件名前缀
prefix = "generated_perturbed_" if args.is_perturbed else "generated_original_"
generated_count = 0
for i in range(min(args.num_validation_images, len(image_files))):
source_image = image_files[i % len(image_files)]
filename = f"{prefix}{i:04d}.png"
output_path = os.path.join(args.validation_image_output_dir, filename)
try:
create_generated_image(source_image, output_path, i)
generated_count += 1
print(f"[VIRTUAL] 生成图片 {generated_count}/{args.num_validation_images}: {filename}")
except Exception as e:
print(f"[VIRTUAL] 生成图片失败: {e}")
# 保存模型文件标记
model_marker = os.path.join(args.output_dir, "virtual_model.txt")
with open(model_marker, 'w') as f:
f.write("This is a virtual DreamBooth model marker.\n")
f.write(f"Training images: {len(image_files)}\n")
f.write(f"Max steps: {args.max_train_steps}\n")
f.write(f"Generated images: {generated_count}\n")
print("-" * 80)
print(f"[VIRTUAL] 成功生成 {generated_count} 张图片")
print(f"[VIRTUAL] 模型保存到: {args.output_dir}")
print("[VIRTUAL] 虚拟微调执行完成")
print("=" * 80)
if __name__ == "__main__":
main()

@ -0,0 +1,124 @@
"""
LoRA微调虚拟实现
用于测试后端流程不执行实际的模型训练
"""
import argparse
import os
import sys
import platform
import shutil
import glob
def create_generated_image(source_image_path, output_path, index):
"""创建一个模拟生成的图片(简单复制源图片)"""
shutil.copy2(source_image_path, output_path)
def main():
parser = argparse.ArgumentParser(description="LoRA虚拟微调脚本")
parser.add_argument('--pretrained_model_name_or_path', required=True)
parser.add_argument('--instance_data_dir', required=True)
parser.add_argument('--class_data_dir', required=True)
parser.add_argument('--output_dir', required=True)
parser.add_argument('--validation_image_output_dir', required=True)
parser.add_argument('--with_prior_preservation', action='store_true')
parser.add_argument('--prior_loss_weight', type=float, default=1.0)
parser.add_argument('--instance_prompt', default='a photo of sks person')
parser.add_argument('--class_prompt', default='a photo of person')
parser.add_argument('--resolution', type=int, default=512)
parser.add_argument('--train_batch_size', type=int, default=1)
parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--lr_scheduler', default='constant')
parser.add_argument('--lr_warmup_steps', type=int, default=0)
parser.add_argument('--num_class_images', type=int, default=200)
parser.add_argument('--max_train_steps', type=int, default=1000)
parser.add_argument('--checkpointing_steps', type=int, default=500)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--mixed_precision', default='fp16')
parser.add_argument('--rank', type=int, default=4)
parser.add_argument('--validation_prompt', default='a photo of sks person')
parser.add_argument('--num_validation_images', type=int, default=10)
parser.add_argument('--is_perturbed', action='store_true', help='Whether training on perturbed images')
args = parser.parse_args()
print("=" * 80)
print("[VIRTUAL LORA] 虚拟微调执行开始")
print("=" * 80)
print(f"[VIRTUAL] 微调方法: LoRA (Low-Rank Adaptation)")
print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}")
print(f"[VIRTUAL] Python可执行文件: {sys.executable}")
print(f"[VIRTUAL] Python版本: {platform.python_version()}")
print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}")
print("-" * 80)
print("[VIRTUAL] 微调参数:")
print(f" - 模型路径: {args.pretrained_model_name_or_path}")
print(f" - 实例数据目录: {args.instance_data_dir}")
print(f" - 模型输出目录: {args.output_dir}")
print(f" - 验证图片输出目录: {args.validation_image_output_dir}")
print(f" - 类别目录: {args.class_data_dir}")
print(f" - 分辨率: {args.resolution}")
print(f" - LoRA rank: {args.rank}")
print(f" - 最大训练步数: {args.max_train_steps}")
print(f" - 学习率: {args.learning_rate}")
print(f" - 验证图片数量: {args.num_validation_images}")
print("-" * 80)
# 确保输出目录存在
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.validation_image_output_dir, exist_ok=True)
# 获取训练图片
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(args.instance_data_dir, ext)))
image_files.extend(glob.glob(os.path.join(args.instance_data_dir, ext.upper())))
print(f"[VIRTUAL] 找到 {len(image_files)} 张训练图片")
# 模拟训练过程
print("-" * 80)
print("[VIRTUAL] 开始模拟LoRA训练...")
for step in range(0, args.max_train_steps, args.checkpointing_steps):
print(f"[VIRTUAL] 训练步数: {step}/{args.max_train_steps}")
print(f"[VIRTUAL] 训练步数: {args.max_train_steps}/{args.max_train_steps} (完成)")
# 生成验证图片(从训练图片复制并添加标记)
print("-" * 80)
print(f"[VIRTUAL] 生成 {args.num_validation_images} 张验证图片...")
# 根据is_perturbed决定文件名前缀
prefix = "generated_perturbed_" if args.is_perturbed else "generated_original_"
generated_count = 0
for i in range(min(args.num_validation_images, len(image_files))):
source_image = image_files[i % len(image_files)]
filename = f"{prefix}{i:04d}.png"
output_path = os.path.join(args.validation_image_output_dir, filename)
try:
create_generated_image(source_image, output_path, i)
generated_count += 1
print(f"[VIRTUAL] 生成图片 {generated_count}/{args.num_validation_images}: {filename}")
except Exception as e:
print(f"[VIRTUAL] 生成图片失败: {e}")
# 保存LoRA模型文件标记
model_marker = os.path.join(args.output_dir, "virtual_lora_model.txt")
with open(model_marker, 'w') as f:
f.write("This is a virtual LoRA model marker.\n")
f.write(f"Training images: {len(image_files)}\n")
f.write(f"LoRA rank: {args.rank}\n")
f.write(f"Max steps: {args.max_train_steps}\n")
f.write(f"Generated images: {generated_count}\n")
print("-" * 80)
print(f"[VIRTUAL] 成功生成 {generated_count} 张图片")
print(f"[VIRTUAL] LoRA模型保存到: {args.output_dir}")
print("[VIRTUAL] 虚拟微调执行完成")
print("=" * 80)
if __name__ == "__main__":
main()

@ -0,0 +1,773 @@
import argparse
import copy
import hashlib
import itertools
import logging
import os
from pathlib import Path
import datasets
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
logger = get_logger(__name__)
class DreamBoothDatasetFromTensor(Dataset):
"""Just like DreamBoothDataset, but take instance_images_tensor instead of path"""
def __init__(
self,
instance_images_tensor,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.instance_images_tensor = instance_images_tensor
self.num_instance_images = len(self.instance_images_tensor)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image = self.instance_images_tensor[index % self.num_instance_images]
example["instance_images"] = instance_image
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
return example
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--instance_data_dir_for_train",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--instance_data_dir_for_adversarial",
type=str,
default=None,
required=True,
help="A folder containing the images to add adversarial noise",
)
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument(
"--prior_loss_weight",
type=float,
default=1.0,
help="The weight of prior preservation loss.",
)
parser.add_argument(
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=4,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--sample_batch_size",
type=int,
default=8,
help="Batch size (per device) for sampling images.",
)
parser.add_argument(
"--max_train_steps",
type=int,
default=20,
help="Total number of training steps to perform.",
)
parser.add_argument(
"--max_f_train_steps",
type=int,
default=10,
help="Total number of sub-steps to train surogate model.",
)
parser.add_argument(
"--max_adv_train_steps",
type=int,
default=10,
help="Total number of sub-steps to train adversarial noise.",
)
parser.add_argument(
"--checkpointing_iterations",
type=int,
default=5,
help=("Save a checkpoint of the training state every X iterations."),
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="fp16",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention",
action="store_true",
help="Whether or not to use xformers.",
)
parser.add_argument(
"--pgd_alpha",
type=float,
default=1.0 / 255,
help="The step size for pgd.",
)
parser.add_argument(
"--pgd_eps",
type=int,
default=0.05,
help="The noise budget for pgd.",
)
parser.add_argument(
"--target_image_path",
default=None,
help="target image for attacking",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
return args
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
images = [image_transforms(Image.open(i).convert("RGB")) for i in list(Path(data_dir).iterdir())]
images = torch.stack(images)
return images
def train_one_epoch(
args,
models,
tokenizer,
noise_scheduler,
vae,
data_tensor: torch.Tensor,
num_steps=20,
):
# Load the tokenizer
unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
optimizer = torch.optim.AdamW(
params_to_optimize,
lr=args.learning_rate,
betas=(0.9, 0.999),
weight_decay=1e-2,
eps=1e-08,
)
train_dataset = DreamBoothDatasetFromTensor(
data_tensor,
args.instance_prompt,
tokenizer,
args.class_data_dir,
args.class_prompt,
args.resolution,
args.center_crop,
)
# weight_dtype = torch.bfloat16
weight_dtype = torch.bfloat16
device = torch.device("cuda")
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
for step in range(num_steps):
unet.train()
text_encoder.train()
step_data = train_dataset[step % len(train_dataset)]
pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
device, dtype=weight_dtype
)
input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(input_ids)[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# with prior preservation loss
if args.with_prior_preservation:
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = instance_loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
loss.backward()
torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
optimizer.step()
optimizer.zero_grad()
print(
f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, instance_loss: {instance_loss.detach().item()}"
)
return [unet, text_encoder]
def pgd_attack(
args,
models,
tokenizer,
noise_scheduler,
vae,
data_tensor: torch.Tensor,
original_images: torch.Tensor,
target_tensor: torch.Tensor,
num_steps: int,
):
"""Return new perturbed data"""
unet, text_encoder = models
weight_dtype = torch.bfloat16
device = torch.device("cuda")
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
perturbed_images = data_tensor.detach().clone()
perturbed_images.requires_grad_(True)
input_ids = tokenizer(
args.instance_prompt,
truncation=True,
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids.repeat(len(data_tensor), 1)
for step in range(num_steps):
perturbed_images.requires_grad = True
latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
#noise_scheduler.config.num_train_timesteps
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(input_ids.to(device))[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
unet.zero_grad()
text_encoder.zero_grad()
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# target-shift loss
if target_tensor is not None:
xtm1_pred = torch.cat(
[
noise_scheduler.step(
model_pred[idx : idx + 1],
timesteps[idx : idx + 1],
noisy_latents[idx : idx + 1],
).prev_sample
for idx in range(len(model_pred))
]
)
xtm1_target = noise_scheduler.add_noise(target_tensor, noise, timesteps - 1)
loss = loss - F.mse_loss(xtm1_pred, xtm1_target)
loss.backward()
alpha = args.pgd_alpha
eps = args.pgd_eps / 255
adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
print(f"PGD loss - step {step}, loss: {loss.detach().item()}")
return perturbed_images
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
logging_dir=logging_dir,
)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
if args.seed is not None:
set_seed(args.seed)
# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
if args.mixed_precision == "fp32":
torch_dtype = torch.float32
elif args.mixed_precision == "fp16":
torch_dtype = torch.float16
elif args.mixed_precision == "bf16":
torch_dtype = torch.bfloat16
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not accelerator.is_local_main_process,
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
).cuda()
vae.requires_grad_(False)
if not args.train_text_encoder:
text_encoder.requires_grad_(False)
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
clean_data = load_data(
args.instance_data_dir_for_train,
size=args.resolution,
center_crop=args.center_crop,
)
perturbed_data = load_data(
args.instance_data_dir_for_adversarial,
size=args.resolution,
center_crop=args.center_crop,
)
original_data = perturbed_data.clone()
original_data.requires_grad_(False)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
target_latent_tensor = None
if args.target_image_path is not None:
target_image_path = Path(args.target_image_path)
assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist"
target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution))
target_image = np.array(target_image)[None].transpose(0, 3, 1, 2)
target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0
target_latent_tensor = (
vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor
)
target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda()
f = [unet, text_encoder]
for i in range(args.max_train_steps):
# 1. f' = f.clone()
f_sur = copy.deepcopy(f)
f_sur = train_one_epoch(
args,
f_sur,
tokenizer,
noise_scheduler,
vae,
clean_data,
args.max_f_train_steps,
)
perturbed_data = pgd_attack(
args,
f_sur,
tokenizer,
noise_scheduler,
vae,
perturbed_data,
original_data,
target_latent_tensor,
args.max_adv_train_steps,
)
f = train_one_epoch(
args,
f,
tokenizer,
noise_scheduler,
vae,
perturbed_data,
args.max_f_train_steps,
)
if (i + 1) % args.checkpointing_iterations == 0:
save_folder = args.output_dir
os.makedirs(save_folder, exist_ok=True)
noised_imgs = perturbed_data.detach()
img_filenames = [
Path(instance_path).stem
for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir())
]
for img_pixel, img_name in zip(noised_imgs, img_filenames):
save_path = os.path.join(save_folder, f"perturbed_{img_name}.png")
logger.info(f"即将保存图片到: {save_path}")
Image.fromarray(
(img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
).save(save_path)
logger.info(f"图片已保存到: {save_path}")
print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)")
if __name__ == "__main__":
args = parse_args()
main(args)

@ -0,0 +1,972 @@
import argparse
import hashlib
import itertools
import json
import logging
import os
import random
import warnings
import shutil
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
logger = get_logger(__name__)
def freeze_params(params):
for param in params:
param.requires_grad = False
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
class CustomDiffusionDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and the tokenizes prompts.
"""
def __init__(
self,
concepts_list,
tokenizer,
size=512,
mask_size=64,
center_crop=False,
with_prior_preservation=False,
num_class_images=200,
hflip=False,
aug=True,
):
self.size = size
self.mask_size = mask_size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.interpolation = Image.BILINEAR
self.aug = aug
self.instance_images_path = []
self.class_images_path = []
self.with_prior_preservation = with_prior_preservation
for concept in concepts_list:
inst_img_path = [
(x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()
]
self.instance_images_path.extend(inst_img_path)
if with_prior_preservation:
class_data_root = Path(concept["class_data_dir"])
if os.path.isdir(class_data_root):
class_images_path = list(class_data_root.iterdir())
class_prompt = [concept["class_prompt"] for _ in range(len(class_images_path))]
else:
with open(class_data_root, "r") as f:
class_images_path = f.read().splitlines()
with open(concept["class_prompt"], "r") as f:
class_prompt = f.read().splitlines()
class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)]
self.class_images_path.extend(class_img_path[:num_class_images])
random.shuffle(self.instance_images_path)
self.num_instance_images = len(self.instance_images_path)
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
self.image_transforms = transforms.Compose(
[
self.flip,
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def preprocess(self, image, scale, resample):
outer, inner = self.size, scale
factor = self.size // self.mask_size
if scale > self.size:
outer, inner = scale, self.size
top, left = np.random.randint(0, outer - inner + 1), np.random.randint(0, outer - inner + 1)
image = image.resize((scale, scale), resample=resample)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32)
instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
mask = np.zeros((self.size // factor, self.size // factor))
if scale > self.size:
instance_image = image[top : top + inner, left : left + inner, :]
mask = np.ones((self.size // factor, self.size // factor))
else:
instance_image[top : top + inner, left : left + inner, :] = image
mask[
top // factor + 1 : (top + scale) // factor - 1, left // factor + 1 : (left + scale) // factor - 1
] = 1.0
return instance_image, mask
def __getitem__(self, index):
example = {}
instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images]
instance_image = Image.open(instance_image)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
instance_image = self.flip(instance_image)
# apply resize augmentation and create a valid image region mask
random_scale = self.size
if self.aug:
random_scale = (
np.random.randint(self.size // 3, self.size + 1)
if np.random.uniform() < 0.66
else np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
)
instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation)
if random_scale < 0.6 * self.size:
instance_prompt = np.random.choice(["a far away ", "very small "]) + instance_prompt
elif random_scale > self.size:
instance_prompt = np.random.choice(["zoomed in ", "close up "]) + instance_prompt
example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1)
example["mask"] = torch.from_numpy(mask)
example["instance_prompt_ids"] = self.tokenizer(
instance_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
if self.with_prior_preservation:
class_image, class_prompt = self.class_images_path[index % self.num_class_images]
class_image = Image.open(class_image)
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_mask"] = torch.ones_like(example["mask"])
example["class_prompt_ids"] = self.tokenizer(
class_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
return example
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="CAAT training script.")
parser.add_argument(
"--alpha",
type=float,
default=5e-3,
required=True,
help="PGD alpha.",
)
parser.add_argument(
"--eps",
type=float,
default=0.1,
required=True,
help="PGD eps.",
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument(
"--prior_loss_weight",
type=float,
default=1.0,
help="The weight of prior preservation loss."
)
parser.add_argument(
"--num_class_images",
type=int,
default=200,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="outputs",
help="The output directory.",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="A seed for reproducible training."
)
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument(
"--max_train_steps",
type=int,
default=250,
help="Total number of training steps to perform.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=250,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=2,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--freeze_model",
type=str,
default="crossattn_kv",
choices=["crossattn_kv", "crossattn"],
help="crossattn to enable fine-tuning of all params in the cross attention",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
default=None,
choices=["no", "fp32", "fp16", "bf16"],
help=(
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
),
)
parser.add_argument(
"--concepts_list",
type=str,
default=None,
help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--set_grads_to_none",
action="store_true",
help=(
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
" behaviors, so disable this argument if it causes any problems. More info:"
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
),
)
parser.add_argument(
"--initializer_token", type=str, default="ktn+pll+ucd", help="A token to use as initializer word."
)
parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
parser.add_argument(
"--noaug",
action="store_true",
help="Dont apply augmentation during data augmentation when this flag is enabled.",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
if args.with_prior_preservation:
if args.concepts_list is None:
if args.class_data_dir is None:
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
else:
# logger is not available yet
if args.class_data_dir is not None:
warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
return args
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
accelerator.init_trackers("CAAT", config=vars(args))
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
if args.concepts_list is None:
args.concepts_list = [
{
"instance_prompt": args.instance_prompt,
"class_prompt": args.class_prompt,
"instance_data_dir": args.instance_data_dir,
"class_data_dir": args.class_data_dir,
}
]
else:
with open(args.concepts_list, "r") as f:
args.concepts_list = json.load(f)
# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
for i, concept in enumerate(args.concepts_list):
class_images_dir = Path(concept["class_data_dir"])
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
torch_dtype = torch.float16
elif args.prior_generation_precision == "bf16":
torch_dtype = torch.bfloat16
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not accelerator.is_local_main_process,
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = (
class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
)
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name,
revision=args.revision,
use_fast=False,
)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
text_encoder.to(accelerator.device, dtype=weight_dtype)
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
attention_class = CustomDiffusionAttnProcessor
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
attention_class = CustomDiffusionXFormersAttnProcessor
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# now we will add new Custom Diffusion weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
# => 32 layers
# Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer
train_kv = True
train_q_out = False if args.freeze_model == "crossattn_kv" else True
custom_diffusion_attn_procs = {}
st = unet.state_dict()
for name, _ in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
layer_name = name.split(".processor")[0]
weights = {
"to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
"to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
}
if train_q_out:
weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
if cross_attention_dim is not None:
custom_diffusion_attn_procs[name] = attention_class(
train_kv=train_kv,
train_q_out=train_q_out,
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
).to(unet.device)
custom_diffusion_attn_procs[name].load_state_dict(weights)
else:
custom_diffusion_attn_procs[name] = attention_class(
train_kv=False,
train_q_out=False,
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
del st
unet.set_attn_processor(custom_diffusion_attn_procs)
custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)
accelerator.register_for_checkpointing(custom_diffusion_layers)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
args.learning_rate = args.learning_rate
if args.with_prior_preservation:
args.learning_rate = args.learning_rate * 2.0
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
# Optimizer creation
optimizer = optimizer_class(
custom_diffusion_layers.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Dataset creation:
train_dataset = CustomDiffusionDataset(
concepts_list=args.concepts_list,
tokenizer=tokenizer,
with_prior_preservation=args.with_prior_preservation,
size=args.resolution,
mask_size=vae.encode(
torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device)
)
.latent_dist.sample()
.size()[-1],
center_crop=args.center_crop,
num_class_images=args.num_class_images,
hflip=args.hflip,
aug=not args.noaug,
)
# Prepare for PGD
pertubed_images = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path]
pertubed_images = [train_dataset.image_transforms(i) for i in pertubed_images]
pertubed_images = torch.stack(pertubed_images).contiguous()
pertubed_images.requires_grad_()
original_images = pertubed_images.clone().detach()
original_images.requires_grad_(False)
input_ids = train_dataset.tokenizer(
args.instance_prompt,
truncation=True,
padding="max_length",
max_length=train_dataset.tokenizer.model_max_length,
return_tensors="pt",
).input_ids.repeat(len(original_images), 1)
def get_one_mask(image):
random_scale = train_dataset.size
if train_dataset.aug:
random_scale = (
np.random.randint(train_dataset.size // 3, train_dataset.size + 1)
if np.random.uniform() < 0.66
else np.random.randint(int(1.2 * train_dataset.size), int(1.4 * train_dataset.size))
)
_, one_mask = train_dataset.preprocess(image, random_scale, train_dataset.interpolation)
one_mask = torch.from_numpy(one_mask)
if args.with_prior_preservation:
class_mask = torch.ones_like(one_mask)
one_mask += class_mask
return one_mask
images_open_list = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path]
mask_list = []
for image in images_open_list:
mask_list.append(get_one_mask(image))
mask = torch.stack(mask_list)
mask = mask.to(memory_format=torch.contiguous_format).float()
mask = mask.unsqueeze(1)
del images_open_list
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask = accelerator.prepare(
custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask
)
# Train!
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num pertubed_images = {len(pertubed_images)}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.max_train_steps):
unet.train()
for _ in range(1):
with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
# Convert images to latent space
pertubed_images.requires_grad = True
latents = vae.encode(pertubed_images.to(accelerator.device).to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# unet.zero_grad()
# text_encoder.zero_grad()
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
mask = torch.chunk(mask, 2, dim=0)[0].to(accelerator.device)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
mask = mask.to(accelerator.device)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # torch.Size([5, 4, 64, 64])
#loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
loss = loss.mean()
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
custom_diffusion_layers.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
alpha = args.alpha
eps = args.eps
adv_images = pertubed_images + alpha * pertubed_images.grad.sign()
eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
pertubed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
logger.info("***** Final save of perturbed images *****")
save_folder = args.output_dir
noised_imgs = pertubed_images.detach().cpu()
img_names = [
str(instance_path[0]).split("/")[-1] for instance_path in train_dataset.instance_images_path
]
num_images_to_save = len(img_names)
for i in range(num_images_to_save):
img_pixel = noised_imgs[i]
img_name = img_names[i]
save_path = os.path.join(save_folder, f"perturbed_{img_name}")
# 图像转换和保存
Image.fromarray(
(img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
).save(save_path)
logger.info(f"Saved {num_images_to_save} final perturbed images to {save_folder}")
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)
print("<-------end-------->")

@ -0,0 +1,274 @@
import argparse
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from diffusers import AutoencoderKL
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of updating steps",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
'--eps',
type=float,
default=12.75,
help='pertubation budget'
)
parser.add_argument(
'--step_size',
type=float,
default=1/255,
help='step size of each update'
)
parser.add_argument(
'--attack_type',
choices=['var', 'mean', 'KL', 'add-log', 'latent_vector', 'add'],
help='what is the attack target'
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
class PIDDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and the tokenizes prompts.
"""
def __init__(
self,
instance_data_root,
size=512,
center_crop=False
):
self.size = size
self.center_crop = center_crop
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.image_transforms = transforms.Compose([
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),])
def __len__(self):
return self.num_instance_images
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
instance_image = exif_transpose(instance_image)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example['index'] = index % self.num_instance_images
example['pixel_values'] = self.image_transforms(instance_image)
return example
def main(args):
# Set random seed
if args.seed is not None:
torch.manual_seed(args.seed)
weight_dtype = torch.float32
device = torch.device('cuda')
# VAE encoder
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
vae.requires_grad_(False)
vae.to(device, dtype=weight_dtype)
# Dataset and DataLoaders creation:
dataset = PIDDataset(
instance_data_root=args.instance_data_dir,
size=args.resolution,
center_crop=args.center_crop,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1, # some parts of code don't support batching
shuffle=True,
num_workers=args.dataloader_num_workers,
)
# Wrapper of the perturbations generator
class AttackModel(torch.nn.Module):
def __init__(self):
super().__init__()
to_tensor = transforms.ToTensor()
self.epsilon = args.eps/255
self.delta = [torch.empty_like(to_tensor(Image.open(path))).uniform_(-self.epsilon, self.epsilon)
for path in dataset.instance_images_path]
self.size = dataset.size
def forward(self, vae, x, index, poison=False):
# Check whether we need to add perturbation
if poison:
self.delta[index].requires_grad_(True)
x = x + self.delta[index].to(dtype=weight_dtype)
# Normalize to [-1, 1]
input_x = 2 * x - 1
return vae.encode(input_x.to(device))
attackmodel = AttackModel()
# Just to zero-out the gradient
optimizer = torch.optim.SGD(attackmodel.delta, lr=0)
# Progress bar
progress_bar = tqdm(range(0, args.max_train_steps), desc="Steps")
# Make sure the dir exists
os.makedirs(args.output_dir, exist_ok=True)
# Start optimizing the perturbation
for step in progress_bar:
total_loss = 0.0
for batch in dataloader:
# Save images
if step%25 == 0:
to_image = transforms.ToPILImage()
for i in range(0, len(dataset.instance_images_path)):
img = dataset[i]['pixel_values']
img = to_image(img + attackmodel.delta[i])
# 使用原文件名,添加perturbed_前缀
original_filename = Path(dataset.instance_images_path[i]).stem
img.save(os.path.join(args.output_dir, f"perturbed_{original_filename}.png"))
# Select target loss
clean_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], False)
poison_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], True)
clean_latent = clean_embedding.latent_dist
poison_latent = poison_embedding.latent_dist
if args.attack_type == 'var':
loss = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean")
elif args.attack_type == 'mean':
loss = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean")
elif args.attack_type == 'KL':
sigma_2, mu_2 = poison_latent.std, poison_latent.mean
sigma_1, mu_1 = clean_latent.std, clean_latent.mean
KL_diver = torch.log(sigma_2 / sigma_1) - 0.5 + (sigma_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * sigma_2 ** 2)
loss = KL_diver.flatten().mean()
elif args.attack_type == 'latent_vector':
clean_vector = clean_latent.sample()
poison_vector = poison_latent.sample()
loss = F.mse_loss(clean_vector, poison_vector, reduction="mean")
elif args.attack_type == 'add':
loss_2 = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean")
loss_1 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean")
loss = loss_1 + loss_2
elif args.attack_type == 'add-log':
loss_1 = F.mse_loss(clean_latent.var.log(), poison_latent.var.log(), reduction="mean")
loss_2 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction='mean')
loss = loss_1 + loss_2
optimizer.zero_grad()
loss.backward()
# Perform PGD update on the loss
delta = attackmodel.delta[batch['index']]
delta.requires_grad_(False)
delta += delta.grad.sign() * 1/255
delta = torch.clamp(delta, -attackmodel.epsilon, attackmodel.epsilon)
delta = torch.clamp(delta, -batch['pixel_values'].detach().cpu(), 1-batch['pixel_values'].detach().cpu())
attackmodel.delta[batch['index']] = delta.detach().squeeze(0)
total_loss += loss.detach().cpu()
# Logging steps
logs = {"loss": total_loss.item()}
progress_bar.set_postfix(**logs)
if __name__ == "__main__":
args = parse_args()
main(args)

File diff suppressed because it is too large Load Diff

@ -1,176 +0,0 @@
"""
对抗性扰动算法引擎
实现各种加噪算法的虚拟版本
"""
import os
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import uuid
from flask import current_app
class PerturbationEngine:
"""对抗性扰动处理引擎"""
@staticmethod
def apply_perturbation(image_path, algorithm, epsilon, use_strong_protection=False, output_path=None):
"""
应用对抗性扰动
Args:
image_path: 原始图片路径
algorithm: 算法名称 (simac, caat, pid)
epsilon: 扰动强度
use_strong_protection: 是否使用防净化版本
Returns:
处理后图片的路径
"""
try:
if not os.path.exists(image_path):
raise FileNotFoundError(f"图片文件不存在: {image_path}")
# 加载图片
with Image.open(image_path) as img:
# 转换为RGB模式
if img.mode != 'RGB':
img = img.convert('RGB')
# 根据算法选择处理方法
if algorithm == 'simac':
perturbed_img = PerturbationEngine._apply_simac(img, epsilon, use_strong_protection)
elif algorithm == 'caat':
perturbed_img = PerturbationEngine._apply_caat(img, epsilon, use_strong_protection)
elif algorithm == 'pid':
perturbed_img = PerturbationEngine._apply_pid(img, epsilon, use_strong_protection)
else:
raise ValueError(f"不支持的算法: {algorithm}")
# 使用输入的output_path参数
if output_path is None:
# 如果没有提供输出路径,使用默认路径
from flask import current_app
project_root = os.path.dirname(current_app.root_path)
perturbed_dir = os.path.join(project_root, current_app.config['PERTURBED_IMAGES_FOLDER'])
os.makedirs(perturbed_dir, exist_ok=True)
file_extension = os.path.splitext(image_path)[1]
output_filename = f"perturbed_{uuid.uuid4().hex[:8]}{file_extension}"
output_path = os.path.join(perturbed_dir, output_filename)
# 保存处理后的图片
perturbed_img.save(output_path, quality=95)
return output_path
except Exception as e:
print(f"应用扰动时出错: {str(e)}")
return None
@staticmethod
def _apply_simac(img, epsilon, use_strong_protection):
"""
SimAC算法的虚拟实现
Simple Anti-Customization Method for Protecting Face Privacy
"""
# 将PIL图像转换为numpy数组
img_array = np.array(img, dtype=np.float32)
# 生成随机噪声(模拟对抗性扰动)
noise_scale = epsilon / 255.0
if use_strong_protection:
# 防净化版本:添加更复杂的扰动模式
noise = np.random.normal(0, noise_scale * 0.8, img_array.shape)
# 添加结构化噪声
h, w = img_array.shape[:2]
for i in range(0, h, 8):
for j in range(0, w, 8):
block_noise = np.random.normal(0, noise_scale * 0.4, (min(8, h-i), min(8, w-j), 3))
noise[i:i+8, j:j+8] += block_noise
else:
# 标准版本:简单高斯噪声
noise = np.random.normal(0, noise_scale, img_array.shape)
# 应用噪声
perturbed_array = img_array + noise * 255.0
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
# 转换回PIL图像
result_img = Image.fromarray(perturbed_array)
# 轻微的图像增强以模拟算法特性
enhancer = ImageEnhance.Contrast(result_img)
result_img = enhancer.enhance(1.02)
return result_img
@staticmethod
def _apply_caat(img, epsilon, use_strong_protection):
"""
CAAT算法的虚拟实现
Perturbing Attention Gives You More Bang for the Buck
"""
img_array = np.array(img, dtype=np.float32)
noise_scale = epsilon / 255.0
if use_strong_protection:
# 防净化版本:注意力区域重点扰动
# 模拟注意力图(简单的边缘检测)
gray_img = img.convert('L')
edge_img = gray_img.filter(ImageFilter.FIND_EDGES)
attention_map = np.array(edge_img, dtype=np.float32) / 255.0
# 在注意力区域添加更强的噪声
noise = np.random.normal(0, noise_scale * 0.6, img_array.shape)
for c in range(3):
noise[:,:,c] += attention_map * np.random.normal(0, noise_scale * 0.8, attention_map.shape)
else:
# 标准版本:均匀分布噪声
noise = np.random.uniform(-noise_scale, noise_scale, img_array.shape)
perturbed_array = img_array + noise * 255.0
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
result_img = Image.fromarray(perturbed_array)
# 轻微模糊以模拟注意力扰动效果
result_img = result_img.filter(ImageFilter.BoxBlur(0.5))
return result_img
@staticmethod
def _apply_pid(img, epsilon, use_strong_protection):
"""
PID算法的虚拟实现
Prompt-Independent Data Protection Against Latent Diffusion Models
"""
img_array = np.array(img, dtype=np.float32)
noise_scale = epsilon / 255.0
if use_strong_protection:
# 防净化版本:频域扰动
# 简单的频域变换模拟
noise = np.random.laplace(0, noise_scale * 0.7, img_array.shape)
# 添加周期性扰动
h, w = img_array.shape[:2]
for i in range(h):
for j in range(w):
periodic_noise = noise_scale * 0.3 * np.sin(i * 0.1) * np.cos(j * 0.1)
noise[i, j] += periodic_noise
else:
# 标准版本:拉普拉斯噪声
noise = np.random.laplace(0, noise_scale * 0.5, img_array.shape)
perturbed_array = img_array + noise * 255.0
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
result_img = Image.fromarray(perturbed_array)
# 色彩微调以模拟潜在空间扰动
enhancer = ImageEnhance.Color(result_img)
result_img = enhancer.enhance(0.98)
return result_img

@ -0,0 +1,90 @@
"""
ASPL算法虚拟实现
用于测试后端流程不执行实际的扰动生成
"""
import argparse
import os
import sys
import platform
import shutil
import glob
def main():
parser = argparse.ArgumentParser(description="ASPL虚拟算法脚本")
parser.add_argument('--pretrained_model_name_or_path', required=True)
parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true')
parser.add_argument('--instance_data_dir_for_train', required=True)
parser.add_argument('--instance_data_dir_for_adversarial', required=True)
parser.add_argument('--instance_prompt', default='a photo of sks person')
parser.add_argument('--class_data_dir', required=True)
parser.add_argument('--num_class_images', type=int, default=200)
parser.add_argument('--class_prompt', default='a photo of person')
parser.add_argument('--output_dir', required=True)
parser.add_argument('--center_crop', action='store_true')
parser.add_argument('--with_prior_preservation', action='store_true')
parser.add_argument('--prior_loss_weight', type=float, default=1.0)
parser.add_argument('--resolution', type=int, default=384)
parser.add_argument('--train_batch_size', type=int, default=1)
parser.add_argument('--max_train_steps', type=int, default=50)
parser.add_argument('--max_f_train_steps', type=int, default=3)
parser.add_argument('--max_adv_train_steps', type=int, default=6)
parser.add_argument('--checkpointing_iterations', type=int, default=10)
parser.add_argument('--learning_rate', type=float, default=5e-7)
parser.add_argument('--pgd_alpha', type=float, default=0.005)
parser.add_argument('--pgd_eps', type=float, default=8)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
print("=" * 80)
print("[VIRTUAL ASPL] 虚拟算法执行开始")
print("=" * 80)
print(f"[VIRTUAL] 算法名称: ASPL (Anti-DreamBooth)")
print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}")
print(f"[VIRTUAL] Python可执行文件: {sys.executable}")
print(f"[VIRTUAL] Python版本: {platform.python_version()}")
print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}")
print("-" * 80)
print("[VIRTUAL] 算法参数:")
print(f" - 模型路径: {args.pretrained_model_name_or_path}")
print(f" - 输入训练目录: {args.instance_data_dir_for_train}")
print(f" - 输入对抗目录: {args.instance_data_dir_for_adversarial}")
print(f" - 输出目录: {args.output_dir}")
print(f" - 类别目录: {args.class_data_dir}")
print(f" - 分辨率: {args.resolution}")
print(f" - 扰动强度(pgd_eps): {args.pgd_eps}")
print(f" - PGD alpha: {args.pgd_alpha}")
print(f" - 最大训练步数: {args.max_train_steps}")
print(f" - 学习率: {args.learning_rate}")
print("-" * 80)
# 确保输出目录存在
os.makedirs(args.output_dir, exist_ok=True)
# 复制图片到输出目录(模拟扰动生成)
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext)))
image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext.upper())))
print(f"[VIRTUAL] 找到 {len(image_files)} 张输入图片")
copied_count = 0
for image_path in image_files:
filename = os.path.basename(image_path)
# 添加perturbed_前缀
name, ext = os.path.splitext(filename)
perturbed_filename = f"perturbed_{name}{ext}"
output_path = os.path.join(args.output_dir, perturbed_filename)
shutil.copy(image_path, output_path)
copied_count += 1
print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename} -> {perturbed_filename}")
print("-" * 80)
print(f"[VIRTUAL] 成功处理 {copied_count} 张图片")
print("[VIRTUAL] 虚拟算法执行完成")
print("=" * 80)
if __name__ == "__main__":
main()

@ -0,0 +1,86 @@
"""
CAAT算法虚拟实现
用于测试后端流程不执行实际的扰动生成
"""
import argparse
import os
import sys
import platform
import shutil
import glob
def main():
parser = argparse.ArgumentParser(description="CAAT虚拟算法脚本")
parser.add_argument('--pretrained_model_name_or_path', required=True)
parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true')
parser.add_argument('--instance_data_dir_for_train', required=True)
parser.add_argument('--instance_data_dir_for_adversarial', required=True)
parser.add_argument('--instance_prompt', default='a photo of sks person')
parser.add_argument('--class_data_dir', required=True)
parser.add_argument('--num_class_images', type=int, default=200)
parser.add_argument('--class_prompt', default='a photo of person')
parser.add_argument('--output_dir', required=True)
parser.add_argument('--resolution', type=int, default=512)
parser.add_argument('--train_batch_size', type=int, default=2)
parser.add_argument('--max_train_steps', type=int, default=800)
parser.add_argument('--learning_rate', type=float, default=1e-5)
parser.add_argument('--lr_warmup_steps', type=int, default=0)
parser.add_argument('--hflip', action='store_true')
parser.add_argument('--mixed_precision', type=str, default='bf16')
parser.add_argument('--alpha', type=float, default=5e-3)
parser.add_argument('--pgd_eps', type=float, default=16)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
print("=" * 80)
print("[VIRTUAL CAAT] 虚拟算法执行开始")
print("=" * 80)
print(f"[VIRTUAL] 算法名称: CAAT (Class-wise Adversarial Attack)")
print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}")
print(f"[VIRTUAL] Python可执行文件: {sys.executable}")
print(f"[VIRTUAL] Python版本: {platform.python_version()}")
print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}")
print("-" * 80)
print("[VIRTUAL] 算法参数:")
print(f" - 模型路径: {args.pretrained_model_name_or_path}")
print(f" - 输入训练目录: {args.instance_data_dir_for_train}")
print(f" - 输入对抗目录: {args.instance_data_dir_for_adversarial}")
print(f" - 输出目录: {args.output_dir}")
print(f" - 类别目录: {args.class_data_dir}")
print(f" - 分辨率: {args.resolution}")
print(f" - 扰动强度(pgd_eps): {args.pgd_eps}")
print(f" - 最大训练步数: {args.max_train_steps}")
print(f" - 学习率: {args.learning_rate}")
print("-" * 80)
# 确保输出目录存在
os.makedirs(args.output_dir, exist_ok=True)
# 复制图片到输出目录(模拟扰动生成)
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext)))
image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext.upper())))
print(f"[VIRTUAL] 找到 {len(image_files)} 张输入图片")
copied_count = 0
for image_path in image_files:
filename = os.path.basename(image_path)
# 添加perturbed_前缀
name, ext = os.path.splitext(filename)
perturbed_filename = f"perturbed_{name}{ext}"
output_path = os.path.join(args.output_dir, perturbed_filename)
shutil.copy(image_path, output_path)
copied_count += 1
print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename} -> {perturbed_filename}")
print("-" * 80)
print(f"[VIRTUAL] 成功处理 {copied_count} 张图片")
print("[VIRTUAL] 虚拟算法执行完成")
print("=" * 80)
if __name__ == "__main__":
main()

@ -0,0 +1,84 @@
"""
PID算法虚拟实现
用于测试后端流程不执行实际的扰动生成
"""
import argparse
import os
import sys
import platform
import shutil
import glob
def main():
parser = argparse.ArgumentParser(description="PID虚拟算法脚本")
parser.add_argument('--pretrained_model_name_or_path', required=True)
parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true')
parser.add_argument('--instance_data_dir_for_train', required=True)
parser.add_argument('--instance_data_dir_for_adversarial', required=True)
parser.add_argument('--instance_prompt', default='a photo of sks person')
parser.add_argument('--class_data_dir', required=True)
parser.add_argument('--num_class_images', type=int, default=200)
parser.add_argument('--class_prompt', default='a photo of person')
parser.add_argument('--output_dir', required=True)
parser.add_argument('--resolution', type=int, default=512)
parser.add_argument('--train_batch_size', type=int, default=1)
parser.add_argument('--max_train_steps', type=int, default=600)
parser.add_argument('--center_crop', action='store_true')
parser.add_argument('--attack_type', type=str, default='add-log')
parser.add_argument('--learning_rate', type=float, default=3e-6)
parser.add_argument('--pgd_eps', type=float, default=4)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
print("=" * 80)
print("[VIRTUAL PID] 虚拟算法执行开始")
print("=" * 80)
print(f"[VIRTUAL] 算法名称: PID (Perturbation Identity Defense)")
print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}")
print(f"[VIRTUAL] Python可执行文件: {sys.executable}")
print(f"[VIRTUAL] Python版本: {platform.python_version()}")
print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}")
print("-" * 80)
print("[VIRTUAL] 算法参数:")
print(f" - 模型路径: {args.pretrained_model_name_or_path}")
print(f" - 输入训练目录: {args.instance_data_dir_for_train}")
print(f" - 输入对抗目录: {args.instance_data_dir_for_adversarial}")
print(f" - 输出目录: {args.output_dir}")
print(f" - 类别目录: {args.class_data_dir}")
print(f" - 分辨率: {args.resolution}")
print(f" - 扰动强度(pgd_eps): {args.pgd_eps}")
print(f" - 最大训练步数: {args.max_train_steps}")
print(f" - 学习率: {args.learning_rate}")
print("-" * 80)
# 确保输出目录存在
os.makedirs(args.output_dir, exist_ok=True)
# 复制图片到输出目录(模拟扰动生成)
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext)))
image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext.upper())))
print(f"[VIRTUAL] 找到 {len(image_files)} 张输入图片")
copied_count = 0
for image_path in image_files:
filename = os.path.basename(image_path)
# 添加perturbed_前缀
name, ext = os.path.splitext(filename)
perturbed_filename = f"perturbed_{name}{ext}"
output_path = os.path.join(args.output_dir, perturbed_filename)
shutil.copy(image_path, output_path)
copied_count += 1
print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename} -> {perturbed_filename}")
print("-" * 80)
print(f"[VIRTUAL] 成功处理 {copied_count} 张图片")
print("[VIRTUAL] 虚拟算法执行完成")
print("=" * 80)
if __name__ == "__main__":
main()

@ -0,0 +1,89 @@
"""
SimAC算法虚拟实现
用于测试后端流程不执行实际的扰动生成
"""
import argparse
import os
import sys
import platform
import shutil
import glob
def main():
parser = argparse.ArgumentParser(description="SimAC虚拟算法脚本")
parser.add_argument('--pretrained_model_name_or_path', required=True)
parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true')
parser.add_argument('--instance_data_dir_for_train', required=True)
parser.add_argument('--instance_data_dir_for_adversarial', required=True)
parser.add_argument('--instance_prompt', default='a photo of sks person')
parser.add_argument('--class_data_dir', required=True)
parser.add_argument('--num_class_images', type=int, default=200)
parser.add_argument('--class_prompt', default='a photo of person')
parser.add_argument('--output_dir', required=True)
parser.add_argument('--center_crop', action='store_true')
parser.add_argument('--with_prior_preservation', action='store_true')
parser.add_argument('--prior_loss_weight', type=float, default=1.0)
parser.add_argument('--resolution', type=int, default=512)
parser.add_argument('--train_batch_size', type=int, default=1)
parser.add_argument('--max_train_steps', type=int, default=1000)
parser.add_argument('--max_f_train_steps', type=int, default=3)
parser.add_argument('--max_adv_train_steps', type=int, default=6)
parser.add_argument('--checkpointing_iterations', type=int, default=10)
parser.add_argument('--learning_rate', type=float, default=5e-6)
parser.add_argument('--pgd_alpha', type=float, default=0.005)
parser.add_argument('--pgd_eps', type=float, default=8)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
print("=" * 80)
print("[VIRTUAL SIMAC] 虚拟算法执行开始")
print("=" * 80)
print(f"[VIRTUAL] 算法名称: SimAC (Simple and Effective)")
print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}")
print(f"[VIRTUAL] Python可执行文件: {sys.executable}")
print(f"[VIRTUAL] Python版本: {platform.python_version()}")
print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}")
print("-" * 80)
print("[VIRTUAL] 算法参数:")
print(f" - 模型路径: {args.pretrained_model_name_or_path}")
print(f" - 输入训练目录: {args.instance_data_dir_for_train}")
print(f" - 输入对抗目录: {args.instance_data_dir_for_adversarial}")
print(f" - 输出目录: {args.output_dir}")
print(f" - 类别目录: {args.class_data_dir}")
print(f" - 分辨率: {args.resolution}")
print(f" - 扰动强度(pgd_eps): {args.pgd_eps}")
print(f" - 最大训练步数: {args.max_train_steps}")
print(f" - 学习率: {args.learning_rate}")
print("-" * 80)
# 确保输出目录存在
os.makedirs(args.output_dir, exist_ok=True)
# 复制图片到输出目录(模拟扰动生成)
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext)))
image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext.upper())))
print(f"[VIRTUAL] 找到 {len(image_files)} 张输入图片")
copied_count = 0
for image_path in image_files:
filename = os.path.basename(image_path)
# 添加perturbed_前缀
name, ext = os.path.splitext(filename)
perturbed_filename = f"perturbed_{name}{ext}"
output_path = os.path.join(args.output_dir, perturbed_filename)
shutil.copy(image_path, output_path)
copied_count += 1
print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename} -> {perturbed_filename}")
print("-" * 80)
print(f"[VIRTUAL] 成功处理 {copied_count} 张图片")
print("[VIRTUAL] 虚拟算法执行完成")
print("=" * 80)
if __name__ == "__main__":
main()

@ -0,0 +1,25 @@
import argparse
import os
import sys
import platform
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="虚拟扰动算法脚本")
parser.add_argument('--algorithm_code', required=True)
parser.add_argument('--batch_id', required=True)
parser.add_argument('--epsilon', required=True)
parser.add_argument('--use_strong_protection', required=True)
parser.add_argument('--input_dir', required=True)
parser.add_argument('--output_dir', required=True)
args = parser.parse_args()
print(f"[VIRTUAL] 当前算法: {args.algorithm_code}")
print(f"[VIRTUAL] 批次ID: {args.batch_id}")
print(f"[VIRTUAL] 扰动强度: {args.epsilon}")
print(f"[VIRTUAL] 是否强防护: {args.use_strong_protection}")
print(f"[VIRTUAL] 输入目录: {args.input_dir}")
print(f"[VIRTUAL] 输出目录: {args.output_dir}")
print(f"[VIRTUAL] 当前conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}")
print(f"[VIRTUAL] Python环境: {sys.executable}")
print(f"[VIRTUAL] Python版本: {platform.python_version()}")
# 不做图片处理图片复制由worker完成

@ -1,239 +1,239 @@
"""
管理员控制器
处理管理员功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from app import db
from app.models import User, Batch, Image
admin_bp = Blueprint('admin', __name__)
def admin_required(f):
"""管理员权限装饰器"""
from functools import wraps
@wraps(f)
def decorated_function(*args, **kwargs):
current_user_id = get_jwt_identity()
user = User.query.get(current_user_id)
if not user or user.role != 'admin':
return jsonify({'error': '需要管理员权限'}), 403
return f(*args, **kwargs)
return decorated_function
@admin_bp.route('/users', methods=['GET'])
@jwt_required()
@admin_required
def list_users():
"""获取用户列表"""
try:
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 20, type=int)
users = User.query.paginate(page=page, per_page=per_page, error_out=False)
return jsonify({
'users': [user.to_dict() for user in users.items],
'total': users.total,
'pages': users.pages,
'current_page': page
}), 200
except Exception as e:
return jsonify({'error': f'获取用户列表失败: {str(e)}'}), 500
@admin_bp.route('/users/<int:user_id>', methods=['GET'])
@jwt_required()
@admin_required
def get_user_detail(user_id):
"""获取用户详情"""
try:
user = User.query.get(user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
# 获取用户统计信息
total_tasks = Batch.query.filter_by(user_id=user_id).count()
total_images = Image.query.filter_by(user_id=user_id).count()
user_dict = user.to_dict()
user_dict['stats'] = {
'total_tasks': total_tasks,
'total_images': total_images
}
return jsonify({'user': user_dict}), 200
except Exception as e:
return jsonify({'error': f'获取用户详情失败: {str(e)}'}), 500
@admin_bp.route('/users', methods=['POST'])
@jwt_required()
@admin_required
def create_user():
"""创建用户"""
try:
data = request.get_json()
username = data.get('username')
password = data.get('password')
email = data.get('email')
role = data.get('role', 'user')
max_concurrent_tasks = data.get('max_concurrent_tasks', 0)
if not username or not password:
return jsonify({'error': '用户名和密码不能为空'}), 400
# 检查用户名是否已存在
if User.query.filter_by(username=username).first():
return jsonify({'error': '用户名已存在'}), 400
# 检查邮箱是否已存在
if email and User.query.filter_by(email=email).first():
return jsonify({'error': '邮箱已被使用'}), 400
# 创建用户
user = User(
username=username,
email=email,
role=role,
max_concurrent_tasks=max_concurrent_tasks
)
user.set_password(password)
db.session.add(user)
db.session.commit()
return jsonify({
'message': '用户创建成功',
'user': user.to_dict()
}), 201
except Exception as e:
db.session.rollback()
return jsonify({'error': f'创建用户失败: {str(e)}'}), 500
@admin_bp.route('/users/<int:user_id>', methods=['PUT'])
@jwt_required()
@admin_required
def update_user(user_id):
"""更新用户信息"""
try:
user = User.query.get(user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
data = request.get_json()
# 更新字段
if 'username' in data:
new_username = data['username']
if new_username != user.username:
if User.query.filter_by(username=new_username).first():
return jsonify({'error': '用户名已存在'}), 400
user.username = new_username
if 'email' in data:
new_email = data['email']
if new_email != user.email:
if User.query.filter_by(email=new_email).first():
return jsonify({'error': '邮箱已被使用'}), 400
user.email = new_email
if 'role' in data:
user.role = data['role']
if 'max_concurrent_tasks' in data:
user.max_concurrent_tasks = data['max_concurrent_tasks']
if 'is_active' in data:
user.is_active = bool(data['is_active'])
if 'password' in data and data['password']:
user.set_password(data['password'])
db.session.commit()
return jsonify({
'message': '用户信息更新成功',
'user': user.to_dict()
}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'更新用户失败: {str(e)}'}), 500
@admin_bp.route('/users/<int:user_id>', methods=['DELETE'])
@jwt_required()
@admin_required
def delete_user(user_id):
"""删除用户"""
try:
current_user_id = get_jwt_identity()
# 不能删除自己
if user_id == current_user_id:
return jsonify({'error': '不能删除自己的账户'}), 400
user = User.query.get(user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
# 删除用户(级联删除相关数据)
db.session.delete(user)
db.session.commit()
return jsonify({'message': '用户删除成功'}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'删除用户失败: {str(e)}'}), 500
@admin_bp.route('/stats', methods=['GET'])
@jwt_required()
@admin_required
def get_system_stats():
"""获取系统统计信息"""
try:
from app.models import EvaluationResult
total_users = User.query.count()
active_users = User.query.filter_by(is_active=True).count()
admin_users = User.query.filter_by(role='admin').count()
total_tasks = Batch.query.count()
completed_tasks = Batch.query.filter_by(status='completed').count()
processing_tasks = Batch.query.filter_by(status='processing').count()
failed_tasks = Batch.query.filter_by(status='failed').count()
total_images = Image.query.count()
total_evaluations = EvaluationResult.query.count()
return jsonify({
'stats': {
'users': {
'total': total_users,
'active': active_users,
'admin': admin_users
},
'tasks': {
'total': total_tasks,
'completed': completed_tasks,
'processing': processing_tasks,
'failed': failed_tasks
},
'images': {
'total': total_images
},
'evaluations': {
'total': total_evaluations
}
}
}), 200
except Exception as e:
"""
管理员控制器
处理管理员功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from app import db
from app.database import User, Batch, Image
admin_bp = Blueprint('admin', __name__)
def admin_required(f):
"""管理员权限装饰器"""
from functools import wraps
@wraps(f)
def decorated_function(*args, **kwargs):
current_user_id = get_jwt_identity()
user = User.query.get(current_user_id)
if not user or user.role != 'admin':
return jsonify({'error': '需要管理员权限'}), 403
return f(*args, **kwargs)
return decorated_function
@admin_bp.route('/users', methods=['GET'])
@jwt_required()
@admin_required
def list_users():
"""获取用户列表"""
try:
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 20, type=int)
users = User.query.paginate(page=page, per_page=per_page, error_out=False)
return jsonify({
'users': [user.to_dict() for user in users.items],
'total': users.total,
'pages': users.pages,
'current_page': page
}), 200
except Exception as e:
return jsonify({'error': f'获取用户列表失败: {str(e)}'}), 500
@admin_bp.route('/users/<int:user_id>', methods=['GET'])
@jwt_required()
@admin_required
def get_user_detail(user_id):
"""获取用户详情"""
try:
user = User.query.get(user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
# 获取用户统计信息
total_tasks = Batch.query.filter_by(user_id=user_id).count()
total_images = Image.query.filter_by(user_id=user_id).count()
user_dict = user.to_dict()
user_dict['stats'] = {
'total_tasks': total_tasks,
'total_images': total_images
}
return jsonify({'user': user_dict}), 200
except Exception as e:
return jsonify({'error': f'获取用户详情失败: {str(e)}'}), 500
@admin_bp.route('/users', methods=['POST'])
@jwt_required()
@admin_required
def create_user():
"""创建用户"""
try:
data = request.get_json()
username = data.get('username')
password = data.get('password')
email = data.get('email')
role = data.get('role', 'user')
max_concurrent_tasks = data.get('max_concurrent_tasks', 0)
if not username or not password:
return jsonify({'error': '用户名和密码不能为空'}), 400
# 检查用户名是否已存在
if User.query.filter_by(username=username).first():
return jsonify({'error': '用户名已存在'}), 400
# 检查邮箱是否已存在
if email and User.query.filter_by(email=email).first():
return jsonify({'error': '邮箱已被使用'}), 400
# 创建用户
user = User(
username=username,
email=email,
role=role,
max_concurrent_tasks=max_concurrent_tasks
)
user.set_password(password)
db.session.add(user)
db.session.commit()
return jsonify({
'message': '用户创建成功',
'user': user.to_dict()
}), 201
except Exception as e:
db.session.rollback()
return jsonify({'error': f'创建用户失败: {str(e)}'}), 500
@admin_bp.route('/users/<int:user_id>', methods=['PUT'])
@jwt_required()
@admin_required
def update_user(user_id):
"""更新用户信息"""
try:
user = User.query.get(user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
data = request.get_json()
# 更新字段
if 'username' in data:
new_username = data['username']
if new_username != user.username:
if User.query.filter_by(username=new_username).first():
return jsonify({'error': '用户名已存在'}), 400
user.username = new_username
if 'email' in data:
new_email = data['email']
if new_email != user.email:
if User.query.filter_by(email=new_email).first():
return jsonify({'error': '邮箱已被使用'}), 400
user.email = new_email
if 'role' in data:
user.role = data['role']
if 'max_concurrent_tasks' in data:
user.max_concurrent_tasks = data['max_concurrent_tasks']
if 'is_active' in data:
user.is_active = bool(data['is_active'])
if 'password' in data and data['password']:
user.set_password(data['password'])
db.session.commit()
return jsonify({
'message': '用户信息更新成功',
'user': user.to_dict()
}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'更新用户失败: {str(e)}'}), 500
@admin_bp.route('/users/<int:user_id>', methods=['DELETE'])
@jwt_required()
@admin_required
def delete_user(user_id):
"""删除用户"""
try:
current_user_id = get_jwt_identity()
# 不能删除自己
if user_id == current_user_id:
return jsonify({'error': '不能删除自己的账户'}), 400
user = User.query.get(user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
# 删除用户(级联删除相关数据)
db.session.delete(user)
db.session.commit()
return jsonify({'message': '用户删除成功'}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'删除用户失败: {str(e)}'}), 500
@admin_bp.route('/stats', methods=['GET'])
@jwt_required()
@admin_required
def get_system_stats():
"""获取系统统计信息"""
try:
from app.database import EvaluationResult
total_users = User.query.count()
active_users = User.query.filter_by(is_active=True).count()
admin_users = User.query.filter_by(role='admin').count()
total_tasks = Batch.query.count()
completed_tasks = Batch.query.filter_by(status='completed').count()
processing_tasks = Batch.query.filter_by(status='processing').count()
failed_tasks = Batch.query.filter_by(status='failed').count()
total_images = Image.query.count()
total_evaluations = EvaluationResult.query.count()
return jsonify({
'stats': {
'users': {
'total': total_users,
'active': active_users,
'admin': admin_users
},
'tasks': {
'total': total_tasks,
'completed': completed_tasks,
'processing': processing_tasks,
'failed': failed_tasks
},
'images': {
'total': total_images
},
'evaluations': {
'total': total_evaluations
}
}
}), 200
except Exception as e:
return jsonify({'error': f'获取系统统计失败: {str(e)}'}), 500

@ -1,156 +1,156 @@
"""
用户认证控制器
处理注册登录密码修改等功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity
from app import db
from app.models import User, UserConfig
from app.services.auth_service import AuthService
from functools import wraps
import re
def int_jwt_required(f):
"""获取JWT身份并转换为整数的装饰器"""
@wraps(f)
def wrapped(*args, **kwargs):
try:
current_user_id = int(get_jwt_identity())
return f(*args, current_user_id=current_user_id, **kwargs)
except (TypeError, ValueError):
return jsonify({'error': '无效的用户身份标识'}), 401
return jwt_required()(wrapped)
auth_bp = Blueprint('auth', __name__)
@auth_bp.route('/register', methods=['POST'])
def register():
"""用户注册"""
try:
data = request.get_json()
username = data.get('username')
password = data.get('password')
email = data.get('email')
# 验证输入
if not username or not password or not email:
return jsonify({'error': '用户名、密码和邮箱不能为空'}), 400
# 验证邮箱格式
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(email_pattern, email):
return jsonify({'error': '邮箱格式不正确'}), 400
# 检查用户名是否已存在
if User.query.filter_by(username=username).first():
return jsonify({'error': '用户名已存在'}), 400
# 检查邮箱是否已注册
if User.query.filter_by(email=email).first():
return jsonify({'error': '该邮箱已被注册,同一邮箱只能注册一次'}), 400
# 创建用户
user = User(username=username, email=email)
user.set_password(password)
db.session.add(user)
db.session.commit()
# 创建用户默认配置
user_config = UserConfig(user_id=user.id)
db.session.add(user_config)
db.session.commit()
return jsonify({
'message': '注册成功',
'user': user.to_dict()
}), 201
except Exception as e:
db.session.rollback()
return jsonify({'error': f'注册失败: {str(e)}'}), 500
@auth_bp.route('/login', methods=['POST'])
def login():
"""用户登录"""
try:
data = request.get_json()
username = data.get('username')
password = data.get('password')
if not username or not password:
return jsonify({'error': '用户名和密码不能为空'}), 400
# 查找用户
user = User.query.filter_by(username=username).first()
if not user or not user.check_password(password):
return jsonify({'error': '用户名或密码错误'}), 401
if not user.is_active:
return jsonify({'error': '账户已被禁用'}), 401
# 创建访问令牌 - 确保用户ID为字符串类型
access_token = create_access_token(identity=str(user.id))
return jsonify({
'message': '登录成功',
'access_token': access_token,
'user': user.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'登录失败: {str(e)}'}), 500
@auth_bp.route('/change-password', methods=['POST'])
@int_jwt_required
def change_password(current_user_id):
"""修改密码"""
try:
user = User.query.get(current_user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
data = request.get_json()
old_password = data.get('old_password')
new_password = data.get('new_password')
if not old_password or not new_password:
return jsonify({'error': '旧密码和新密码不能为空'}), 400
# 验证旧密码
if not user.check_password(old_password):
return jsonify({'error': '旧密码错误'}), 401
# 设置新密码
user.set_password(new_password)
db.session.commit()
return jsonify({'message': '密码修改成功'}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'密码修改失败: {str(e)}'}), 500
@auth_bp.route('/profile', methods=['GET'])
@int_jwt_required
def get_profile(current_user_id):
"""获取用户信息"""
try:
user = User.query.get(current_user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
return jsonify({'user': user.to_dict()}), 200
except Exception as e:
return jsonify({'error': f'获取用户信息失败: {str(e)}'}), 500
@auth_bp.route('/logout', methods=['POST'])
@jwt_required()
def logout():
"""用户登出客户端删除token即可"""
"""
用户认证控制器
处理注册登录密码修改等功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity
from app import db
from app.database import User, UserConfig
from app.services.auth_service import AuthService
from functools import wraps
import re
def int_jwt_required(f):
"""获取JWT身份并转换为整数的装饰器"""
@wraps(f)
def wrapped(*args, **kwargs):
try:
current_user_id = int(get_jwt_identity())
return f(*args, current_user_id=current_user_id, **kwargs)
except (TypeError, ValueError):
return jsonify({'error': '无效的用户身份标识'}), 401
return jwt_required()(wrapped)
auth_bp = Blueprint('auth', __name__)
@auth_bp.route('/register', methods=['POST'])
def register():
"""用户注册"""
try:
data = request.get_json()
username = data.get('username')
password = data.get('password')
email = data.get('email')
# 验证输入
if not username or not password or not email:
return jsonify({'error': '用户名、密码和邮箱不能为空'}), 400
# 验证邮箱格式
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(email_pattern, email):
return jsonify({'error': '邮箱格式不正确'}), 400
# 检查用户名是否已存在
if User.query.filter_by(username=username).first():
return jsonify({'error': '用户名已存在'}), 400
# 检查邮箱是否已注册
if User.query.filter_by(email=email).first():
return jsonify({'error': '该邮箱已被注册,同一邮箱只能注册一次'}), 400
# 创建用户
user = User(username=username, email=email)
user.set_password(password)
db.session.add(user)
db.session.commit()
# 创建用户默认配置
user_config = UserConfig(user_id=user.id)
db.session.add(user_config)
db.session.commit()
return jsonify({
'message': '注册成功',
'user': user.to_dict()
}), 201
except Exception as e:
db.session.rollback()
return jsonify({'error': f'注册失败: {str(e)}'}), 500
@auth_bp.route('/login', methods=['POST'])
def login():
"""用户登录"""
try:
data = request.get_json()
username = data.get('username')
password = data.get('password')
if not username or not password:
return jsonify({'error': '用户名和密码不能为空'}), 400
# 查找用户
user = User.query.filter_by(username=username).first()
if not user or not user.check_password(password):
return jsonify({'error': '用户名或密码错误'}), 401
if not user.is_active:
return jsonify({'error': '账户已被禁用'}), 401
# 创建访问令牌 - 确保用户ID为字符串类型
access_token = create_access_token(identity=str(user.id))
return jsonify({
'message': '登录成功',
'access_token': access_token,
'user': user.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'登录失败: {str(e)}'}), 500
@auth_bp.route('/change-password', methods=['POST'])
@int_jwt_required
def change_password(current_user_id):
"""修改密码"""
try:
user = User.query.get(current_user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
data = request.get_json()
old_password = data.get('old_password')
new_password = data.get('new_password')
if not old_password or not new_password:
return jsonify({'error': '旧密码和新密码不能为空'}), 400
# 验证旧密码
if not user.check_password(old_password):
return jsonify({'error': '旧密码错误'}), 401
# 设置新密码
user.set_password(new_password)
db.session.commit()
return jsonify({'message': '密码修改成功'}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'密码修改失败: {str(e)}'}), 500
@auth_bp.route('/profile', methods=['GET'])
@int_jwt_required
def get_profile(current_user_id):
"""获取用户信息"""
try:
user = User.query.get(current_user_id)
if not user:
return jsonify({'error': '用户不存在'}), 404
return jsonify({'user': user.to_dict()}), 200
except Exception as e:
return jsonify({'error': f'获取用户信息失败: {str(e)}'}), 500
@auth_bp.route('/logout', methods=['POST'])
@jwt_required()
def logout():
"""用户登出客户端删除token即可"""
return jsonify({'message': '登出成功'}), 200

@ -1,177 +1,177 @@
"""
演示图片控制器
处理预设图像对比图的展示功能
"""
from flask import Blueprint, send_file, jsonify, current_app
from flask_jwt_extended import jwt_required
from app.models import PerturbationConfig, FinetuneConfig
import os
import glob
demo_bp = Blueprint('demo', __name__)
@demo_bp.route('/images', methods=['GET'])
def list_demo_images():
"""获取所有演示图片列表"""
try:
demo_images = []
# 获取演示原始图片 - 修正路径构建
# 获取项目根目录backend目录
project_root = os.path.dirname(current_app.root_path)
original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'])
if os.path.exists(original_folder):
original_files = glob.glob(os.path.join(original_folder, '*'))
for file_path in original_files:
if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
filename = os.path.basename(file_path)
name_without_ext = os.path.splitext(filename)[0]
# 查找对应的加噪图片
perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'])
perturbed_files = glob.glob(os.path.join(perturbed_folder, f"{name_without_ext}*"))
# 查找对应的对比图
comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'])
comparison_files = glob.glob(os.path.join(comparison_folder, f"{name_without_ext}*"))
demo_image = {
'id': name_without_ext,
'name': name_without_ext,
'original': f"/api/demo/image/original/{filename}",
'perturbed': [f"/api/demo/image/perturbed/{os.path.basename(f)}" for f in perturbed_files],
'comparisons': [f"/api/demo/image/comparison/{os.path.basename(f)}" for f in comparison_files]
}
demo_images.append(demo_image)
return jsonify({
'demo_images': demo_images,
'total': len(demo_images)
}), 200
except Exception as e:
return jsonify({'error': f'获取演示图片列表失败: {str(e)}'}), 500
@demo_bp.route('/image/original/<filename>', methods=['GET'])
def get_demo_original_image(filename):
"""获取演示原始图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取原始图片失败: {str(e)}'}), 500
@demo_bp.route('/image/perturbed/<filename>', methods=['GET'])
def get_demo_perturbed_image(filename):
"""获取演示加噪图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取加噪图片失败: {str(e)}'}), 500
@demo_bp.route('/image/comparison/<filename>', methods=['GET'])
def get_demo_comparison_image(filename):
"""获取演示对比图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取对比图片失败: {str(e)}'}), 500
@demo_bp.route('/algorithms', methods=['GET'])
def get_demo_algorithms():
"""获取演示算法信息"""
try:
# 从数据库获取扰动算法
perturbation_algorithms = []
perturbation_configs = PerturbationConfig.query.all()
for config in perturbation_configs:
perturbation_algorithms.append({
'id': config.id,
'code': config.method_code,
'name': config.method_name,
'type': 'perturbation',
'description': config.description,
'default_epsilon': float(config.default_epsilon) if config.default_epsilon else None
})
# 从数据库获取微调算法
finetune_algorithms = []
finetune_configs = FinetuneConfig.query.all()
for config in finetune_configs:
finetune_algorithms.append({
'id': config.id,
'code': config.method_code,
'name': config.method_name,
'type': 'finetune',
'description': config.description
})
return jsonify({
'perturbation_algorithms': perturbation_algorithms,
'finetune_algorithms': finetune_algorithms,
'evaluation_metrics': [
{'name': 'FID', 'description': 'Fréchet Inception Distance - 衡量图像质量的指标'},
{'name': 'LPIPS', 'description': 'Learned Perceptual Image Patch Similarity - 感知相似度'},
{'name': 'SSIM', 'description': 'Structural Similarity Index - 结构相似性指标'},
{'name': 'PSNR', 'description': 'Peak Signal-to-Noise Ratio - 峰值信噪比'}
]
}), 200
except Exception as e:
return jsonify({'error': f'获取算法信息失败: {str(e)}'}), 500
@demo_bp.route('/stats', methods=['GET'])
def get_demo_stats():
"""获取演示统计信息"""
try:
# 统计演示图片数量
project_root = os.path.dirname(current_app.root_path)
original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'])
perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'])
comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'])
original_count = len(glob.glob(os.path.join(original_folder, '*'))) if os.path.exists(original_folder) else 0
perturbed_count = len(glob.glob(os.path.join(perturbed_folder, '*'))) if os.path.exists(perturbed_folder) else 0
comparison_count = len(glob.glob(os.path.join(comparison_folder, '*'))) if os.path.exists(comparison_folder) else 0
# 统计数据库中的算法数量
perturbation_count = PerturbationConfig.query.count()
finetune_count = FinetuneConfig.query.count()
total_algorithms = perturbation_count + finetune_count
return jsonify({
'demo_stats': {
'original_images': original_count,
'perturbed_images': perturbed_count,
'comparison_images': comparison_count,
'supported_algorithms': total_algorithms,
'perturbation_algorithms': perturbation_count,
'finetune_algorithms': finetune_count,
'evaluation_metrics': 4
}
}), 200
except Exception as e:
"""
演示图片控制器
处理预设图像对比图的展示功能
"""
from flask import Blueprint, send_file, jsonify, current_app
from flask_jwt_extended import jwt_required
from app.database import PerturbationConfig, FinetuneConfig
import os
import glob
demo_bp = Blueprint('demo', __name__)
@demo_bp.route('/images', methods=['GET'])
def list_demo_images():
"""获取所有演示图片列表"""
try:
demo_images = []
# 获取演示原始图片 - 修正路径构建
# 获取项目根目录backend目录
project_root = os.path.dirname(current_app.root_path)
original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'])
if os.path.exists(original_folder):
original_files = glob.glob(os.path.join(original_folder, '*'))
for file_path in original_files:
if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
filename = os.path.basename(file_path)
name_without_ext = os.path.splitext(filename)[0]
# 查找对应的加噪图片
perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'])
perturbed_files = glob.glob(os.path.join(perturbed_folder, f"{name_without_ext}*"))
# 查找对应的对比图
comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'])
comparison_files = glob.glob(os.path.join(comparison_folder, f"{name_without_ext}*"))
demo_image = {
'id': name_without_ext,
'name': name_without_ext,
'original': f"/api/demo/image/original/{filename}",
'perturbed': [f"/api/demo/image/perturbed/{os.path.basename(f)}" for f in perturbed_files],
'comparisons': [f"/api/demo/image/comparison/{os.path.basename(f)}" for f in comparison_files]
}
demo_images.append(demo_image)
return jsonify({
'demo_images': demo_images,
'total': len(demo_images)
}), 200
except Exception as e:
return jsonify({'error': f'获取演示图片列表失败: {str(e)}'}), 500
@demo_bp.route('/image/original/<filename>', methods=['GET'])
def get_demo_original_image(filename):
"""获取演示原始图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取原始图片失败: {str(e)}'}), 500
@demo_bp.route('/image/perturbed/<filename>', methods=['GET'])
def get_demo_perturbed_image(filename):
"""获取演示加噪图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取加噪图片失败: {str(e)}'}), 500
@demo_bp.route('/image/comparison/<filename>', methods=['GET'])
def get_demo_comparison_image(filename):
"""获取演示对比图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取对比图片失败: {str(e)}'}), 500
@demo_bp.route('/algorithms', methods=['GET'])
def get_demo_algorithms():
"""获取演示算法信息"""
try:
# 从数据库获取扰动算法
perturbation_algorithms = []
perturbation_configs = PerturbationConfig.query.all()
for config in perturbation_configs:
perturbation_algorithms.append({
'id': config.id,
'code': config.method_code,
'name': config.method_name,
'type': 'perturbation',
'description': config.description,
'default_epsilon': float(config.default_epsilon) if config.default_epsilon else None
})
# 从数据库获取微调算法
finetune_algorithms = []
finetune_configs = FinetuneConfig.query.all()
for config in finetune_configs:
finetune_algorithms.append({
'id': config.id,
'code': config.method_code,
'name': config.method_name,
'type': 'finetune',
'description': config.description
})
return jsonify({
'perturbation_algorithms': perturbation_algorithms,
'finetune_algorithms': finetune_algorithms,
'evaluation_metrics': [
{'name': 'FID', 'description': 'Fréchet Inception Distance - 衡量图像质量的指标'},
{'name': 'LPIPS', 'description': 'Learned Perceptual Image Patch Similarity - 感知相似度'},
{'name': 'SSIM', 'description': 'Structural Similarity Index - 结构相似性指标'},
{'name': 'PSNR', 'description': 'Peak Signal-to-Noise Ratio - 峰值信噪比'}
]
}), 200
except Exception as e:
return jsonify({'error': f'获取算法信息失败: {str(e)}'}), 500
@demo_bp.route('/stats', methods=['GET'])
def get_demo_stats():
"""获取演示统计信息"""
try:
# 统计演示图片数量
project_root = os.path.dirname(current_app.root_path)
original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'])
perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'])
comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'])
original_count = len(glob.glob(os.path.join(original_folder, '*'))) if os.path.exists(original_folder) else 0
perturbed_count = len(glob.glob(os.path.join(perturbed_folder, '*'))) if os.path.exists(perturbed_folder) else 0
comparison_count = len(glob.glob(os.path.join(comparison_folder, '*'))) if os.path.exists(comparison_folder) else 0
# 统计数据库中的算法数量
perturbation_count = PerturbationConfig.query.count()
finetune_count = FinetuneConfig.query.count()
total_algorithms = perturbation_count + finetune_count
return jsonify({
'demo_stats': {
'original_images': original_count,
'perturbed_images': perturbed_count,
'comparison_images': comparison_count,
'supported_algorithms': total_algorithms,
'perturbation_algorithms': perturbation_count,
'finetune_algorithms': finetune_count,
'evaluation_metrics': 4
}
}), 200
except Exception as e:
return jsonify({'error': f'获取统计信息失败: {str(e)}'}), 500

@ -1,203 +1,203 @@
"""
图像管理控制器
处理图像下载查看等功能
"""
from flask import Blueprint, send_file, jsonify, request, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity
from app.models import Image, EvaluationResult
from app.services.image_service import ImageService
import os
image_bp = Blueprint('image', __name__)
@image_bp.route('/file/<int:image_id>', methods=['GET'])
@jwt_required()
def get_image_file(image_id):
"""获取图片文件"""
try:
current_user_id = get_jwt_identity()
# 查找图片记录
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
# 检查文件是否存在
if not os.path.exists(image.file_path):
return jsonify({'error': '图片文件不存在'}), 404
return send_file(image.file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取图片失败: {str(e)}'}), 500
@image_bp.route('/download/<int:image_id>', methods=['GET'])
@jwt_required()
def download_image(image_id):
"""下载图片文件"""
try:
current_user_id = get_jwt_identity()
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
if not os.path.exists(image.file_path):
return jsonify({'error': '图片文件不存在'}), 404
return send_file(
image.file_path,
as_attachment=True,
download_name=image.original_filename or f"image_{image_id}.jpg"
)
except Exception as e:
return jsonify({'error': f'下载图片失败: {str(e)}'}), 500
@image_bp.route('/batch/<int:batch_id>/download', methods=['GET'])
@jwt_required()
def download_batch_images(batch_id):
"""批量下载任务中的加噪后图片"""
try:
current_user_id = get_jwt_identity()
# 获取任务中的加噪图片
perturbed_images = Image.query.join(Image.image_type).filter(
Image.batch_id == batch_id,
Image.user_id == current_user_id,
Image.image_type.has(type_code='perturbed')
).all()
if not perturbed_images:
return jsonify({'error': '没有找到加噪后的图片'}), 404
# 创建ZIP文件
import zipfile
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file:
with zipfile.ZipFile(tmp_file.name, 'w') as zip_file:
for image in perturbed_images:
if os.path.exists(image.file_path):
arcname = image.original_filename or f"perturbed_{image.id}.jpg"
zip_file.write(image.file_path, arcname)
return send_file(
tmp_file.name,
as_attachment=True,
download_name=f"batch_{batch_id}_perturbed_images.zip",
mimetype='application/zip'
)
except Exception as e:
return jsonify({'error': f'批量下载失败: {str(e)}'}), 500
@image_bp.route('/<int:image_id>/evaluations', methods=['GET'])
@jwt_required()
def get_image_evaluations(image_id):
"""获取图片的评估结果"""
try:
current_user_id = get_jwt_identity()
# 验证图片权限
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
# 获取以该图片为参考或目标的评估结果
evaluations = EvaluationResult.query.filter(
(EvaluationResult.reference_image_id == image_id) |
(EvaluationResult.target_image_id == image_id)
).all()
return jsonify({
'image_id': image_id,
'evaluations': [eval_result.to_dict() for eval_result in evaluations]
}), 200
except Exception as e:
return jsonify({'error': f'获取评估结果失败: {str(e)}'}), 500
@image_bp.route('/compare', methods=['POST'])
@jwt_required()
def compare_images():
"""对比两张图片"""
try:
current_user_id = get_jwt_identity()
data = request.get_json()
image1_id = data.get('image1_id')
image2_id = data.get('image2_id')
if not image1_id or not image2_id:
return jsonify({'error': '请提供两张图片的ID'}), 400
# 验证图片权限
image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first()
image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first()
if not image1 or not image2:
return jsonify({'error': '图片不存在或无权限'}), 404
# 查找现有的评估结果
evaluation = EvaluationResult.query.filter_by(
reference_image_id=image1_id,
target_image_id=image2_id
).first()
if not evaluation:
# 如果没有评估结果,返回基本对比信息
return jsonify({
'image1': image1.to_dict(),
'image2': image2.to_dict(),
'evaluation': None,
'message': '暂无评估数据,请等待任务处理完成'
}), 200
return jsonify({
'image1': image1.to_dict(),
'image2': image2.to_dict(),
'evaluation': evaluation.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'图片对比失败: {str(e)}'}), 500
@image_bp.route('/heatmap/<path:heatmap_path>', methods=['GET'])
@jwt_required()
def get_heatmap(heatmap_path):
"""获取热力图文件"""
try:
# 安全检查,防止路径遍历攻击
if '..' in heatmap_path or heatmap_path.startswith('/'):
return jsonify({'error': '无效的文件路径'}), 400
# 修正路径构建 - 获取项目根目录backend目录
project_root = os.path.dirname(current_app.root_path)
full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path))
if not os.path.exists(full_path):
return jsonify({'error': '热力图文件不存在'}), 404
return send_file(full_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取热力图失败: {str(e)}'}), 500
@image_bp.route('/delete/<int:image_id>', methods=['DELETE'])
@jwt_required()
def delete_image(image_id):
"""删除图片"""
try:
current_user_id = get_jwt_identity()
result = ImageService.delete_image(image_id, current_user_id)
if result['success']:
return jsonify({'message': '图片删除成功'}), 200
else:
return jsonify({'error': result['error']}), 400
except Exception as e:
"""
图像管理控制器
处理图像下载查看等功能
"""
from flask import Blueprint, send_file, jsonify, request, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity
from app.database import Image, EvaluationResult
from app.services.image_service import ImageService
import os
image_bp = Blueprint('image', __name__)
@image_bp.route('/file/<int:image_id>', methods=['GET'])
@jwt_required()
def get_image_file(image_id):
"""获取图片文件"""
try:
current_user_id = get_jwt_identity()
# 查找图片记录
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
# 检查文件是否存在
if not os.path.exists(image.file_path):
return jsonify({'error': '图片文件不存在'}), 404
return send_file(image.file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取图片失败: {str(e)}'}), 500
@image_bp.route('/download/<int:image_id>', methods=['GET'])
@jwt_required()
def download_image(image_id):
"""下载图片文件"""
try:
current_user_id = get_jwt_identity()
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
if not os.path.exists(image.file_path):
return jsonify({'error': '图片文件不存在'}), 404
return send_file(
image.file_path,
as_attachment=True,
download_name=image.original_filename or f"image_{image_id}.jpg"
)
except Exception as e:
return jsonify({'error': f'下载图片失败: {str(e)}'}), 500
@image_bp.route('/batch/<int:batch_id>/download', methods=['GET'])
@jwt_required()
def download_batch_images(batch_id):
"""批量下载任务中的加噪后图片"""
try:
current_user_id = get_jwt_identity()
# 获取任务中的加噪图片
perturbed_images = Image.query.join(Image.image_type).filter(
Image.batch_id == batch_id,
Image.user_id == current_user_id,
Image.image_type.has(type_code='perturbed')
).all()
if not perturbed_images:
return jsonify({'error': '没有找到加噪后的图片'}), 404
# 创建ZIP文件
import zipfile
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file:
with zipfile.ZipFile(tmp_file.name, 'w') as zip_file:
for image in perturbed_images:
if os.path.exists(image.file_path):
arcname = image.original_filename or f"perturbed_{image.id}.jpg"
zip_file.write(image.file_path, arcname)
return send_file(
tmp_file.name,
as_attachment=True,
download_name=f"batch_{batch_id}_perturbed_images.zip",
mimetype='application/zip'
)
except Exception as e:
return jsonify({'error': f'批量下载失败: {str(e)}'}), 500
@image_bp.route('/<int:image_id>/evaluations', methods=['GET'])
@jwt_required()
def get_image_evaluations(image_id):
"""获取图片的评估结果"""
try:
current_user_id = get_jwt_identity()
# 验证图片权限
image = Image.query.filter_by(id=image_id, user_id=current_user_id).first()
if not image:
return jsonify({'error': '图片不存在或无权限'}), 404
# 获取以该图片为参考或目标的评估结果
evaluations = EvaluationResult.query.filter(
(EvaluationResult.reference_image_id == image_id) |
(EvaluationResult.target_image_id == image_id)
).all()
return jsonify({
'image_id': image_id,
'evaluations': [eval_result.to_dict() for eval_result in evaluations]
}), 200
except Exception as e:
return jsonify({'error': f'获取评估结果失败: {str(e)}'}), 500
@image_bp.route('/compare', methods=['POST'])
@jwt_required()
def compare_images():
"""对比两张图片"""
try:
current_user_id = get_jwt_identity()
data = request.get_json()
image1_id = data.get('image1_id')
image2_id = data.get('image2_id')
if not image1_id or not image2_id:
return jsonify({'error': '请提供两张图片的ID'}), 400
# 验证图片权限
image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first()
image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first()
if not image1 or not image2:
return jsonify({'error': '图片不存在或无权限'}), 404
# 查找现有的评估结果
evaluation = EvaluationResult.query.filter_by(
reference_image_id=image1_id,
target_image_id=image2_id
).first()
if not evaluation:
# 如果没有评估结果,返回基本对比信息
return jsonify({
'image1': image1.to_dict(),
'image2': image2.to_dict(),
'evaluation': None,
'message': '暂无评估数据,请等待任务处理完成'
}), 200
return jsonify({
'image1': image1.to_dict(),
'image2': image2.to_dict(),
'evaluation': evaluation.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'图片对比失败: {str(e)}'}), 500
@image_bp.route('/heatmap/<path:heatmap_path>', methods=['GET'])
@jwt_required()
def get_heatmap(heatmap_path):
"""获取热力图文件"""
try:
# 安全检查,防止路径遍历攻击
if '..' in heatmap_path or heatmap_path.startswith('/'):
return jsonify({'error': '无效的文件路径'}), 400
# 修正路径构建 - 获取项目根目录backend目录
project_root = os.path.dirname(current_app.root_path)
full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path))
if not os.path.exists(full_path):
return jsonify({'error': '热力图文件不存在'}), 404
return send_file(full_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取热力图失败: {str(e)}'}), 500
@image_bp.route('/delete/<int:image_id>', methods=['DELETE'])
@jwt_required()
def delete_image(image_id):
"""删除图片"""
try:
current_user_id = get_jwt_identity()
result = ImageService.delete_image(image_id, current_user_id)
if result['success']:
return jsonify({'message': '图片删除成功'}), 200
else:
return jsonify({'error': result['error']}), 400
except Exception as e:
return jsonify({'error': f'删除图片失败: {str(e)}'}), 500

File diff suppressed because it is too large Load Diff

@ -1,133 +1,133 @@
"""
用户管理控制器
处理用户配置等功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required
from app import db
from app.models import User, UserConfig, PerturbationConfig, FinetuneConfig
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
user_bp = Blueprint('user', __name__)
@user_bp.route('/config', methods=['GET'])
@int_jwt_required
def get_user_config(current_user_id):
"""获取用户配置"""
try:
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
# 如果没有配置,创建默认配置
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
db.session.commit()
return jsonify({
'config': user_config.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500
@user_bp.route('/config', methods=['PUT'])
@int_jwt_required
def update_user_config(current_user_id):
"""更新用户配置"""
try:
data = request.get_json()
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
# 更新配置字段
if 'preferred_perturbation_config_id' in data:
user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id']
if 'preferred_epsilon' in data:
epsilon = float(data['preferred_epsilon'])
if 0 < epsilon <= 255:
user_config.preferred_epsilon = epsilon
else:
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
if 'preferred_finetune_config_id' in data:
user_config.preferred_finetune_config_id = data['preferred_finetune_config_id']
if 'preferred_purification' in data:
user_config.preferred_purification = bool(data['preferred_purification'])
db.session.commit()
return jsonify({
'message': '用户配置更新成功',
'config': user_config.to_dict()
}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500
@user_bp.route('/algorithms', methods=['GET'])
@jwt_required()
def get_available_algorithms():
"""获取可用的算法列表"""
try:
perturbation_configs = PerturbationConfig.query.all()
finetune_configs = FinetuneConfig.query.all()
return jsonify({
'perturbation_algorithms': [
{
'id': config.id,
'method_code': config.method_code,
'method_name': config.method_name,
'description': config.description,
'default_epsilon': float(config.default_epsilon)
} for config in perturbation_configs
],
'finetune_methods': [
{
'id': config.id,
'method_code': config.method_code,
'method_name': config.method_name,
'description': config.description
} for config in finetune_configs
]
}), 200
except Exception as e:
return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500
@user_bp.route('/stats', methods=['GET'])
@int_jwt_required
def get_user_stats(current_user_id):
"""获取用户统计信息"""
try:
from app.models import Batch, Image
# 统计用户的任务和图片数量
total_tasks = Batch.query.filter_by(user_id=current_user_id).count()
completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count()
processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count()
failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count()
total_images = Image.query.filter_by(user_id=current_user_id).count()
return jsonify({
'stats': {
'total_tasks': total_tasks,
'completed_tasks': completed_tasks,
'processing_tasks': processing_tasks,
'failed_tasks': failed_tasks,
'total_images': total_images
}
}), 200
except Exception as e:
"""
用户管理控制器
处理用户配置等功能
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required
from app import db
from app.database import User, UserConfig, PerturbationConfig, FinetuneConfig
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
user_bp = Blueprint('user', __name__)
@user_bp.route('/config', methods=['GET'])
@int_jwt_required
def get_user_config(current_user_id):
"""获取用户配置"""
try:
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
# 如果没有配置,创建默认配置
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
db.session.commit()
return jsonify({
'config': user_config.to_dict()
}), 200
except Exception as e:
return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500
@user_bp.route('/config', methods=['PUT'])
@int_jwt_required
def update_user_config(current_user_id):
"""更新用户配置"""
try:
data = request.get_json()
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
if not user_config:
user_config = UserConfig(user_id=current_user_id)
db.session.add(user_config)
# 更新配置字段
if 'preferred_perturbation_config_id' in data:
user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id']
if 'preferred_epsilon' in data:
epsilon = float(data['preferred_epsilon'])
if 0 < epsilon <= 255:
user_config.preferred_epsilon = epsilon
else:
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
if 'preferred_finetune_config_id' in data:
user_config.preferred_finetune_config_id = data['preferred_finetune_config_id']
if 'preferred_purification' in data:
user_config.preferred_purification = bool(data['preferred_purification'])
db.session.commit()
return jsonify({
'message': '用户配置更新成功',
'config': user_config.to_dict()
}), 200
except Exception as e:
db.session.rollback()
return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500
@user_bp.route('/algorithms', methods=['GET'])
@jwt_required()
def get_available_algorithms():
"""获取可用的算法列表"""
try:
perturbation_configs = PerturbationConfig.query.all()
finetune_configs = FinetuneConfig.query.all()
return jsonify({
'perturbation_algorithms': [
{
'id': config.id,
'method_code': config.method_code,
'method_name': config.method_name,
'description': config.description,
'default_epsilon': float(config.default_epsilon)
} for config in perturbation_configs
],
'finetune_methods': [
{
'id': config.id,
'method_code': config.method_code,
'method_name': config.method_name,
'description': config.description
} for config in finetune_configs
]
}), 200
except Exception as e:
return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500
@user_bp.route('/stats', methods=['GET'])
@int_jwt_required
def get_user_stats(current_user_id):
"""获取用户统计信息"""
try:
from app.database import Batch, Image
# 统计用户的任务和图片数量
total_tasks = Batch.query.filter_by(user_id=current_user_id).count()
completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count()
processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count()
failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count()
total_images = Image.query.filter_by(user_id=current_user_id).count()
return jsonify({
'stats': {
'total_tasks': total_tasks,
'completed_tasks': completed_tasks,
'processing_tasks': processing_tasks,
'failed_tasks': failed_tasks,
'total_images': total_images
}
}), 200
except Exception as e:
return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500

@ -1,233 +1,278 @@
"""
数据库模型定义
基于已有的schema.sql设计
"""
from datetime import datetime
from app import db
from werkzeug.security import generate_password_hash, check_password_hash
from enum import Enum as PyEnum
class User(db.Model):
"""用户表"""
__tablename__ = 'users'
id = db.Column(db.BigInteger, primary_key=True)
username = db.Column(db.String(50), unique=True, nullable=False)
password_hash = db.Column(db.String(255), nullable=False)
email = db.Column(db.String(100))
role = db.Column(db.Enum('user', 'admin'), default='user')
max_concurrent_tasks = db.Column(db.Integer, nullable=False, default=0)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = db.Column(db.Boolean, default=True)
# 关系
batches = db.relationship('Batch', backref='user', lazy='dynamic', cascade='all, delete-orphan')
images = db.relationship('Image', backref='user', lazy='dynamic', cascade='all, delete-orphan')
user_config = db.relationship('UserConfig', backref='user', uselist=False, cascade='all, delete-orphan')
def set_password(self, password):
"""设置密码"""
self.password_hash = generate_password_hash(password)
def check_password(self, password):
"""验证密码"""
return check_password_hash(self.password_hash, password)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'username': self.username,
'email': self.email,
'role': self.role,
'max_concurrent_tasks': self.max_concurrent_tasks,
'created_at': self.created_at.isoformat() if self.created_at else None,
'is_active': self.is_active
}
class ImageType(db.Model):
"""图片类型表"""
__tablename__ = 'image_types'
id = db.Column(db.BigInteger, primary_key=True)
type_code = db.Column(db.Enum('original', 'perturbed', 'original_generate', 'perturbed_generate'),
unique=True, nullable=False)
type_name = db.Column(db.String(100), nullable=False)
description = db.Column(db.Text)
# 关系
images = db.relationship('Image', backref='image_type', lazy='dynamic')
class PerturbationConfig(db.Model):
"""加噪算法表"""
__tablename__ = 'perturbation_configs'
id = db.Column(db.BigInteger, primary_key=True)
method_code = db.Column(db.String(50), unique=True, nullable=False)
method_name = db.Column(db.String(100), nullable=False)
description = db.Column(db.Text, nullable=False)
default_epsilon = db.Column(db.Numeric(5, 2), nullable=False)
# 关系
batches = db.relationship('Batch', backref='perturbation_config', lazy='dynamic')
user_configs = db.relationship('UserConfig', backref='preferred_perturbation_config', lazy='dynamic')
class FinetuneConfig(db.Model):
"""微调方式表"""
__tablename__ = 'finetune_configs'
id = db.Column(db.BigInteger, primary_key=True)
method_code = db.Column(db.String(50), unique=True, nullable=False)
method_name = db.Column(db.String(100), nullable=False)
description = db.Column(db.Text, nullable=False)
# 关系
batches = db.relationship('Batch', backref='finetune_config', lazy='dynamic')
user_configs = db.relationship('UserConfig', backref='preferred_finetune_config', lazy='dynamic')
class Batch(db.Model):
"""加噪批次表"""
__tablename__ = 'batch'
id = db.Column(db.BigInteger, primary_key=True)
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
batch_name = db.Column(db.String(128))
# 加噪配置
perturbation_config_id = db.Column(db.BigInteger, db.ForeignKey('perturbation_configs.id'),
nullable=False, default=1)
preferred_epsilon = db.Column(db.Numeric(5, 2), nullable=False, default=8.0)
# 评估配置
finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id'),
nullable=False, default=1)
# 净化配置
use_strong_protection = db.Column(db.Boolean, nullable=False, default=False)
# 任务状态
status = db.Column(db.Enum('pending', 'processing', 'completed', 'failed'), default='pending')
created_at = db.Column(db.DateTime, default=datetime.utcnow)
started_at = db.Column(db.DateTime)
completed_at = db.Column(db.DateTime)
error_message = db.Column(db.Text)
result_path = db.Column(db.String(500))
# 关系
images = db.relationship('Image', backref='batch', lazy='dynamic')
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'batch_name': self.batch_name,
'status': self.status,
'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None,
'use_strong_protection': self.use_strong_protection,
'created_at': self.created_at.isoformat() if self.created_at else None,
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'error_message': self.error_message,
'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None,
'finetune_config': self.finetune_config.method_name if self.finetune_config else None
}
class Image(db.Model):
"""图片表"""
__tablename__ = 'images'
id = db.Column(db.BigInteger, primary_key=True)
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id'))
father_id = db.Column(db.BigInteger, db.ForeignKey('images.id'))
original_filename = db.Column(db.String(255))
stored_filename = db.Column(db.String(255), unique=True, nullable=False)
file_path = db.Column(db.String(500), nullable=False)
file_size = db.Column(db.BigInteger)
image_type_id = db.Column(db.BigInteger, db.ForeignKey('image_types.id'), nullable=False)
width = db.Column(db.Integer)
height = db.Column(db.Integer)
upload_time = db.Column(db.DateTime, default=datetime.utcnow)
# 自引用关系
children = db.relationship('Image', backref=db.backref('parent', remote_side=[id]), lazy='dynamic')
# 评估结果关系
reference_evaluations = db.relationship('EvaluationResult',
foreign_keys='EvaluationResult.reference_image_id',
backref='reference_image', lazy='dynamic')
target_evaluations = db.relationship('EvaluationResult',
foreign_keys='EvaluationResult.target_image_id',
backref='target_image', lazy='dynamic')
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'original_filename': self.original_filename,
'stored_filename': self.stored_filename,
'file_path': self.file_path,
'file_size': self.file_size,
'width': self.width,
'height': self.height,
'upload_time': self.upload_time.isoformat() if self.upload_time else None,
'image_type': self.image_type.type_name if self.image_type else None,
'batch_id': self.batch_id
}
class EvaluationResult(db.Model):
"""评估结果表"""
__tablename__ = 'evaluation_results'
id = db.Column(db.BigInteger, primary_key=True)
reference_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False)
target_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False)
evaluation_type = db.Column(db.Enum('image_quality', 'model_generation'), nullable=False)
purification_applied = db.Column(db.Boolean, default=False)
fid_score = db.Column(db.Numeric(8, 4))
lpips_score = db.Column(db.Numeric(8, 4))
ssim_score = db.Column(db.Numeric(8, 4))
psnr_score = db.Column(db.Numeric(8, 4))
heatmap_path = db.Column(db.String(500))
evaluated_at = db.Column(db.DateTime, default=datetime.utcnow)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'evaluation_type': self.evaluation_type,
'purification_applied': self.purification_applied,
'fid_score': float(self.fid_score) if self.fid_score else None,
'lpips_score': float(self.lpips_score) if self.lpips_score else None,
'ssim_score': float(self.ssim_score) if self.ssim_score else None,
'psnr_score': float(self.psnr_score) if self.psnr_score else None,
'heatmap_path': self.heatmap_path,
'evaluated_at': self.evaluated_at.isoformat() if self.evaluated_at else None
}
class UserConfig(db.Model):
"""用户配置表"""
__tablename__ = 'user_configs'
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), primary_key=True)
preferred_perturbation_config_id = db.Column(db.BigInteger,
db.ForeignKey('perturbation_configs.id'), default=1)
preferred_epsilon = db.Column(db.Numeric(5, 2), default=8.0)
preferred_finetune_config_id = db.Column(db.BigInteger,
db.ForeignKey('finetune_configs.id'), default=1)
preferred_purification = db.Column(db.Boolean, default=False)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def to_dict(self):
"""转换为字典"""
return {
'user_id': self.user_id,
'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None,
'preferred_purification': self.preferred_purification,
'preferred_perturbation_config': self.preferred_perturbation_config.method_name if self.preferred_perturbation_config else None,
'preferred_finetune_config': self.preferred_finetune_config.method_name if self.preferred_finetune_config else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None
}
"""
数据库模型定义
基于已有的schema.sql设计
"""
from datetime import datetime
from app import db
from werkzeug.security import generate_password_hash, check_password_hash
from enum import Enum as PyEnum
class User(db.Model):
"""用户表"""
__tablename__ = 'users'
id = db.Column(db.BigInteger, primary_key=True)
username = db.Column(db.String(50), unique=True, nullable=False)
password_hash = db.Column(db.String(255), nullable=False)
email = db.Column(db.String(100))
role = db.Column(db.Enum('user', 'admin'), default='user')
max_concurrent_tasks = db.Column(db.Integer, nullable=False, default=0)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = db.Column(db.Boolean, default=True)
# 关系
batches = db.relationship('Batch', backref='user', lazy='dynamic', cascade='all, delete-orphan')
images = db.relationship('Image', backref='user', lazy='dynamic', cascade='all, delete-orphan')
user_config = db.relationship('UserConfig', backref='user', uselist=False, cascade='all, delete-orphan')
def set_password(self, password):
"""设置密码"""
self.password_hash = generate_password_hash(password)
def check_password(self, password):
"""验证密码"""
return check_password_hash(self.password_hash, password)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'username': self.username,
'email': self.email,
'role': self.role,
'max_concurrent_tasks': self.max_concurrent_tasks,
'created_at': self.created_at.isoformat() if self.created_at else None,
'is_active': self.is_active
}
class ImageType(db.Model):
"""图片类型表"""
__tablename__ = 'image_types'
id = db.Column(db.BigInteger, primary_key=True)
type_code = db.Column(db.Enum('original', 'perturbed', 'original_generate', 'perturbed_generate'),
unique=True, nullable=False)
type_name = db.Column(db.String(100), nullable=False)
description = db.Column(db.Text)
# 关系
images = db.relationship('Image', backref='image_type', lazy='dynamic')
class PerturbationConfig(db.Model):
"""加噪算法表"""
__tablename__ = 'perturbation_configs'
id = db.Column(db.BigInteger, primary_key=True)
method_code = db.Column(db.String(50), unique=True, nullable=False)
method_name = db.Column(db.String(100), nullable=False)
description = db.Column(db.Text, nullable=False)
default_epsilon = db.Column(db.Numeric(5, 2), nullable=False)
# 关系
batches = db.relationship('Batch', backref='perturbation_config', lazy='dynamic')
user_configs = db.relationship('UserConfig', backref='preferred_perturbation_config', lazy='dynamic')
class FinetuneConfig(db.Model):
"""微调方式表"""
__tablename__ = 'finetune_configs'
id = db.Column(db.BigInteger, primary_key=True)
method_code = db.Column(db.String(50), unique=True, nullable=False)
method_name = db.Column(db.String(100), nullable=False)
description = db.Column(db.Text, nullable=False)
# 关系
finetune_tasks = db.relationship('FinetuneBatch', backref='finetune_config', lazy='dynamic')
user_configs = db.relationship('UserConfig', backref='preferred_finetune_config', lazy='dynamic')
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'method_code': self.method_code,
'method_name': self.method_name,
'description': self.description
}
class Batch(db.Model):
"""加噪批次表"""
__tablename__ = 'batch'
id = db.Column(db.BigInteger, primary_key=True)
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
batch_name = db.Column(db.String(128))
# 加噪配置
perturbation_config_id = db.Column(db.BigInteger, db.ForeignKey('perturbation_configs.id'),
nullable=False, default=1)
preferred_epsilon = db.Column(db.Numeric(5, 2), nullable=False, default=8.0)
# 净化配置
use_strong_protection = db.Column(db.Boolean, nullable=False, default=False)
# 任务状态
status = db.Column(db.Enum('pending', 'queued', 'processing', 'completed', 'failed'), default='pending')
created_at = db.Column(db.DateTime, default=datetime.utcnow)
started_at = db.Column(db.DateTime)
completed_at = db.Column(db.DateTime)
error_message = db.Column(db.Text)
result_path = db.Column(db.String(500))
# 关系
images = db.relationship('Image', backref='batch', lazy='dynamic')
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'batch_name': self.batch_name,
'status': self.status,
'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None,
'use_strong_protection': self.use_strong_protection,
'created_at': self.created_at.isoformat() if self.created_at else None,
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'error_message': self.error_message,
'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None
}
class Image(db.Model):
"""图片表"""
__tablename__ = 'images'
id = db.Column(db.BigInteger, primary_key=True)
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id'))
father_id = db.Column(db.BigInteger, db.ForeignKey('images.id'))
original_filename = db.Column(db.String(255))
stored_filename = db.Column(db.String(255), unique=True, nullable=False)
file_path = db.Column(db.String(500), nullable=False)
file_size = db.Column(db.BigInteger)
image_type_id = db.Column(db.BigInteger, db.ForeignKey('image_types.id'), nullable=False)
width = db.Column(db.Integer)
height = db.Column(db.Integer)
upload_time = db.Column(db.DateTime, default=datetime.utcnow)
# 自引用关系
children = db.relationship('Image', backref=db.backref('parent', remote_side=[id]), lazy='dynamic')
# 评估结果关系
reference_evaluations = db.relationship('EvaluationResult',
foreign_keys='EvaluationResult.reference_image_id',
backref='reference_image', lazy='dynamic')
target_evaluations = db.relationship('EvaluationResult',
foreign_keys='EvaluationResult.target_image_id',
backref='target_image', lazy='dynamic')
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'original_filename': self.original_filename,
'stored_filename': self.stored_filename,
'file_path': self.file_path,
'file_size': self.file_size,
'width': self.width,
'height': self.height,
'upload_time': self.upload_time.isoformat() if self.upload_time else None,
'image_type': self.image_type.type_name if self.image_type else None,
'batch_id': self.batch_id
}
class EvaluationResult(db.Model):
"""评估结果表"""
__tablename__ = 'evaluation_results'
id = db.Column(db.BigInteger, primary_key=True)
reference_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False)
target_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False)
evaluation_type = db.Column(db.Enum('image_quality', 'model_generation'), nullable=False)
purification_applied = db.Column(db.Boolean, default=False)
fid_score = db.Column(db.Numeric(8, 4))
lpips_score = db.Column(db.Numeric(8, 4))
ssim_score = db.Column(db.Numeric(8, 4))
psnr_score = db.Column(db.Numeric(8, 4))
heatmap_path = db.Column(db.String(500))
evaluated_at = db.Column(db.DateTime, default=datetime.utcnow)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'evaluation_type': self.evaluation_type,
'purification_applied': self.purification_applied,
'fid_score': float(self.fid_score) if self.fid_score else None,
'lpips_score': float(self.lpips_score) if self.lpips_score else None,
'ssim_score': float(self.ssim_score) if self.ssim_score else None,
'psnr_score': float(self.psnr_score) if self.psnr_score else None,
'heatmap_path': self.heatmap_path,
'evaluated_at': self.evaluated_at.isoformat() if self.evaluated_at else None
}
class FinetuneBatch(db.Model):
"""微调任务表"""
__tablename__ = 'finetune_batch'
id = db.Column(db.BigInteger, primary_key=True)
batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id'), nullable=False)
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id'))
# 任务状态
status = db.Column(db.Enum('pending', 'queued', 'processing', 'completed', 'failed'), default='pending')
created_at = db.Column(db.DateTime, default=datetime.utcnow)
started_at = db.Column(db.DateTime)
completed_at = db.Column(db.DateTime)
error_message = db.Column(db.Text)
# 任务ID用于RQ任务追踪
original_job_id = db.Column(db.String(255)) # 原始图片微调任务ID
perturbed_job_id = db.Column(db.String(255)) # 扰动图片微调任务ID
# 关系
batch = db.relationship('Batch', backref='finetune_tasks', lazy=True)
user = db.relationship('User', backref='finetune_tasks', lazy=True)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'batch_id': self.batch_id,
'user_id': self.user_id,
'finetune_config_id': self.finetune_config_id,
'finetune_config': self.finetune_config.method_name if self.finetune_config else None,
'status': self.status,
'created_at': self.created_at.isoformat() if self.created_at else None,
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'error_message': self.error_message,
'original_job_id': self.original_job_id,
'perturbed_job_id': self.perturbed_job_id
}
class UserConfig(db.Model):
"""用户配置表"""
__tablename__ = 'user_configs'
user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), primary_key=True)
preferred_perturbation_config_id = db.Column(db.BigInteger,
db.ForeignKey('perturbation_configs.id'), default=1)
preferred_epsilon = db.Column(db.Numeric(5, 2), default=8.0)
preferred_finetune_config_id = db.Column(db.BigInteger,
db.ForeignKey('finetune_configs.id'), default=1)
preferred_purification = db.Column(db.Boolean, default=False)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def to_dict(self):
"""转换为字典"""
return {
'user_id': self.user_id,
'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None,
'preferred_purification': self.preferred_purification,
'preferred_perturbation_config': self.preferred_perturbation_config.method_name if self.preferred_perturbation_config else None,
'preferred_finetune_config': self.preferred_finetune_config.method_name if self.preferred_finetune_config else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None
}

@ -0,0 +1,62 @@
#需要环境conda activate simac
export HF_HUB_OFFLINE=1
export MODEL_PATH="../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06"
export TASKNAME="task001"
# ------------------------- Train ASPL on set CLEAN_ADV_DIR -------------------------
export CLEAN_TRAIN_DIR="../../static/originals/${TASKNAME}"
export CLEAN_ADV_DIR="../../static/originals/${TASKNAME}"
export OUTPUT_DIR="../../static/perturbed/${TASKNAME}"
export CLASS_DIR="../../static/class/${TASKNAME}"
# ------------------------- 自动创建依赖路径 -------------------------
echo "Creating required directories..."
mkdir -p "$CLEAN_TRAIN_DIR"
mkdir -p "$CLEAN_ADV_DIR"
mkdir -p "$OUTPUT_DIR"
mkdir -p "$CLASS_DIR"
echo "Directories created successfully."
# ------------------------- 训练前清空 OUTPUT_DIR -------------------------
echo "Clearing output directory: $OUTPUT_DIR"
# 确保目录存在,避免清理命令失败
# 注意:虽然前面已经创建,但这里保留是为了代码逻辑清晰,也可以删除
mkdir -p "$OUTPUT_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$OUTPUT_DIR" -mindepth 1 -delete
accelerate launch ../algorithms/aspl.py \
  --pretrained_model_name_or_path=$MODEL_PATH  \
  --enable_xformers_memory_efficient_attention \
  --instance_data_dir_for_train=$CLEAN_TRAIN_DIR \
  --instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \
  --instance_prompt="a photo of sks person" \
  --class_data_dir=$CLASS_DIR \
  --num_class_images=200 \
  --class_prompt="a photo of person" \
  --output_dir=$OUTPUT_DIR \
  --center_crop \
  --with_prior_preservation \
  --prior_loss_weight=1.0 \
  --resolution=384 \
  --train_batch_size=1 \
  --max_train_steps=50 \
  --max_f_train_steps=3 \
  --max_adv_train_steps=6 \
  --checkpointing_iterations=10 \
  --learning_rate=5e-7 \
  --pgd_alpha=0.005 \
  --pgd_eps=8 \
  --seed=0
# ------------------------- 训练后清空 CLASS_DIR -------------------------
# 注意:这会在 accelerate launch 成功结束后执行
echo "Clearing class directory: $CLASS_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$CLASS_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$CLASS_DIR" -mindepth 1 -delete
echo "Script finished."

@ -0,0 +1,39 @@
#需要环境conda activate caat
export HF_HUB_OFFLINE=1
# 强制使用本地模型缓存,避免联网下载模型
#export HF_HOME="/root/autodl-tmp/huggingface_cache"
#export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export MODEL_NAME="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
export TASKNAME="task001"
### Data to be protected
export INSTANCE_DIR="../../static/originals/${TASKNAME}"
### Path to save the protected data
export OUTPUT_DIR="../../static/perturbed/${TASKNAME}"
# ------------------------- 自动创建依赖路径 -------------------------
echo "Creating required directories..."
mkdir -p "$INSTANCE_DIR"
mkdir -p "$OUTPUT_DIR"
echo "Directories created successfully."
# ------------------------- 训练前清空 OUTPUT_DIR -------------------------
echo "Clearing output directory: $OUTPUT_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$OUTPUT_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$OUTPUT_DIR" -mindepth 1 -delete
accelerate launch ../algorithms/caat.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of a person" \
--resolution=512 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=250 \
--hflip \
--mixed_precision bf16 \
--alpha=5e-3 \
--eps=0.05

@ -0,0 +1,55 @@
export HF_HUB_OFFLINE=1
# 强制使用本地模型缓存,避免联网下载模型
#export HF_HOME="/root/autodl-tmp/huggingface_cache"
#export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export MODEL_NAME="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
export TASKNAME="task001"
### Data to be protected
export INSTANCE_DIR="../../static/originals/${TASKNAME}"
### Path to save the protected data
export OUTPUT_DIR="../../static/perturbed/${TASKNAME}"
export CLASS_DIR="../../static/class/${TASKNAME}"
# ------------------------- 自动创建依赖路径 -------------------------
echo "Creating required directories..."
mkdir -p "$INSTANCE_DIR"
mkdir -p "$OUTPUT_DIR"
mkdir -p "$CLASS_DIR"
echo "Directories created successfully."
# ------------------------- 训练前清空 OUTPUT_DIR -------------------------
echo "Clearing output directory: $OUTPUT_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$OUTPUT_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$OUTPUT_DIR" -mindepth 1 -delete
accelerate launch ../algorithms/caat.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation \
--instance_prompt="a photo of a person" \
--num_class_images=200 \
--class_data_dir=$CLASS_DIR \
--class_prompt='person' \
--resolution=512 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=250 \
--hflip \
--mixed_precision bf16 \
--alpha=5e-3 \
--eps=0.05
# ------------------------- 【步骤 2】训练后清空 CLASS_DIR -------------------------
# 注意:这会在 accelerate launch 成功结束后执行
echo "Clearing class directory: $CLASS_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$CLASS_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$CLASS_DIR" -mindepth 1 -delete
echo "Script finished."

@ -0,0 +1,53 @@
#需要环境conda activate pid
### Generate images protected by PID
export HF_HUB_OFFLINE=1
# 强制使用本地模型缓存,避免联网下载模型
### SD v2.1
# export HF_HOME="/root/autodl-tmp/huggingface_cache"
# export MODEL_PATH="stabilityai/stable-diffusion-2-1"
### SD v1.5
# export HF_HOME="/root/autodl-tmp/huggingface_cache"
# export MODEL_PATH="runwayml/stable-diffusion-v1-5"
export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
export TASKNAME="task001"
### Data to be protected
export INSTANCE_DIR="../../static/originals/${TASKNAME}"
### Path to save the protected data
export OUTPUT_DIR="../../static/perturbed/${TASKNAME}"
# ------------------------- 自动创建依赖路径 -------------------------
echo "Creating required directories..."
mkdir -p "$INSTANCE_DIR"
mkdir -p "$OUTPUT_DIR"
echo "Directories created successfully."
# ------------------------- 训练前清空 OUTPUT_DIR -------------------------
echo "Clearing output directory: $OUTPUT_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$OUTPUT_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$OUTPUT_DIR" -mindepth 1 -delete
### Generation command
# --max_train_steps: Optimizaiton steps
# --attack_type: target loss to update, choices=['var', 'mean', 'KL', 'add-log', 'latent_vector', 'add'],
# Please refer to the file content for more usage
CUDA_VISIBLE_DEVICES=0 python ../algorithms/pid.py \
--pretrained_model_name_or_path=$MODEL_PATH \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--max_train_steps=1000 \
--center_crop \
--eps 12.75 \
--attack_type add-log

@ -0,0 +1,63 @@
#需要环境conda activate simac
export HF_HUB_OFFLINE=1
export MODEL_PATH="../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06"
export TASKNAME="task001"
# ------------------------- Train ASPL on set CLEAN_ADV_DIR -------------------------
export CLEAN_TRAIN_DIR="../../static/originals/${TASKNAME}"
export CLEAN_ADV_DIR="../../static/originals/${TASKNAME}"
export OUTPUT_DIR="../../static/perturbed/${TASKNAME}"
export CLASS_DIR="../../static/class/${TASKNAME}"
# ------------------------- 自动创建依赖路径 -------------------------
echo "Creating required directories..."
mkdir -p "$CLEAN_TRAIN_DIR"
mkdir -p "$CLEAN_ADV_DIR"
mkdir -p "$OUTPUT_DIR"
mkdir -p "$CLASS_DIR"
echo "Directories created successfully."
# ------------------------- 训练前清空 OUTPUT_DIR -------------------------
echo "Clearing output directory: $OUTPUT_DIR"
# 确保目录存在,避免清理命令失败
# 注意:虽然前面已经创建,但这里保留是为了代码逻辑清晰,也可以删除
mkdir -p "$OUTPUT_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$OUTPUT_DIR" -mindepth 1 -delete
accelerate launch ../algorithms/simac.py \
--pretrained_model_name_or_path=$MODEL_PATH \
--enable_xformers_memory_efficient_attention \
--instance_data_dir_for_train=$CLEAN_TRAIN_DIR \
--instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \
--instance_prompt="a photo of sks person" \
--class_data_dir=$CLASS_DIR \
--num_class_images=200 \
--class_prompt="a photo of person" \
--output_dir=$OUTPUT_DIR \
--center_crop \
--with_prior_preservation \
--prior_loss_weight=1.0 \
--resolution=384 \
--train_batch_size=1 \
--max_train_steps=50 \
--max_f_train_steps=3 \
--max_adv_train_steps=6 \
--checkpointing_iterations=10 \
--learning_rate=5e-7 \
--pgd_alpha=0.005 \
--pgd_eps=8 \
--seed=0
# ------------------------- 训练后清空 CLASS_DIR -------------------------
# 注意:这会在 accelerate launch 成功结束后执行
echo "Clearing class directory: $CLASS_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$CLASS_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$CLASS_DIR" -mindepth 1 -delete
echo "Script finished."

@ -0,0 +1,90 @@
#需要环境conda activate pid
### Trianing BAD Model (PID-poisoned Data)
export HF_HUB_OFFLINE=1
# 强制使用本地模型缓存,避免联网下载模型
### SD v2.1
# export HF_HOME="/root/autodl-tmp/huggingface_cache"
# export MODEL_PATH="stabilityai/stable-diffusion-2-1"
### SD v1.5
# export HF_HOME="/root/autodl-tmp/huggingface_cache"
# export MODEL_PATH="runwayml/stable-diffusion-v1-5"
export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
export TASKNAME="task001"
export TYPE="clean" #clean or perturbed
if [ "$TYPE" == "clean" ]; then
export INSTANCE_DIR="../../static/originals/${TASKNAME}"
else
export INSTANCE_DIR="../../static/perturbed/${TASKNAME}"
fi
export DREAMBOOTH_OUTPUT_DIR="../../static/hf_models/fine_tuned/${TYPE}/${TASKNAME}"
export OUTPUT_INFER_DIR="../../static/model_outputs/${TYPE}/${TASKNAME}"
export CLASS_DIR="../../static/class/${TASKNAME}"
# ------------------------- 自动创建依赖路径 -------------------------
echo "Creating required directories..."
mkdir -p "$INSTANCE_DIR"
mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
mkdir -p "$OUTPUT_INFER_DIR"
mkdir -p "$CLASS_DIR"
echo "Directories created successfully."
# ------------------------- 训练前清空 OUTPUT_DIR -------------------------
echo "Clearing output directory: $DREAMBOOTH_OUTPUT_DIR and $OUTPUT_INFER_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
mkdir -p "$OUTPUT_INFER_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$DREAMBOOTH_OUTPUT_DIR" -mindepth 1 -delete
find "$OUTPUT_INFER_DIR" -mindepth 1 -delete
# ------------------------- Fine-tune DreamBooth on images -------------------------
CUDA_VISIBLE_DEVICES=0 accelerate launch ../finetune_infras/train_dreambooth_gen.py \
--pretrained_model_name_or_path=$MODEL_PATH \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$DREAMBOOTH_OUTPUT_DIR \
--validation_image_output_dir=$OUTPUT_INFER_DIR \
--with_prior_preservation \
--prior_loss_weight=1.0 \
--instance_prompt="a photo of sks person" \
--class_prompt="a photo of person" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=2e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=1000 \
--checkpointing_steps=500 \
--center_crop \
--mixed_precision=bf16 \
--prior_generation_precision=bf16 \
--sample_batch_size=5 \
--validation_prompt="a photo of sks person" \
--num_validation_images 10 \
--validation_steps 500
# ------------------------- 训练后清空 CLASS_DIR -------------------------
# 注意:这会在 accelerate launch 成功结束后执行
echo "Clearing class directory: $CLASS_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$CLASS_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$CLASS_DIR" -mindepth 1 -delete
echo "Script finished."

@ -0,0 +1,85 @@
#需要环境conda activate simac
export HF_HUB_OFFLINE=1
export MODEL_PATH="../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06"
export TASKNAME="task001"
export TYPE="perturbed" #clean or perturbed
if [ "$TYPE" == "clean" ]; then
export INSTANCE_DIR="../../static/originals/${TASKNAME}"
else
export INSTANCE_DIR="../../static/perturbed/${TASKNAME}"
fi
export DREAMBOOTH_OUTPUT_DIR="../../static/hf_models/fine_tuned/${TYPE}/${TASKNAME}"
export OUTPUT_INFER_DIR="../../static/model_outputs/${TYPE}/${TASKNAME}"
export CLASS_DIR="../../static/class/${TASKNAME}"
# ------------------------- 自动创建依赖路径 -------------------------
echo "Creating required directories..."
mkdir -p "$INSTANCE_DIR"
mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
mkdir -p "$OUTPUT_INFER_DIR"
mkdir -p "$CLASS_DIR"
echo "Directories created successfully."
# ------------------------- 训练前清空 OUTPUT_DIR -------------------------
echo "Clearing output directory: $DREAMBOOTH_OUTPUT_DIR and $OUTPUT_INFER_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
mkdir -p "$OUTPUT_INFER_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$DREAMBOOTH_OUTPUT_DIR" -mindepth 1 -delete
find "$OUTPUT_INFER_DIR" -mindepth 1 -delete
# ------------------------- Fine-tune DreamBooth on images -------------------------
accelerate launch ../finetune_infras/train_dreambooth_alone.py \
--pretrained_model_name_or_path=$MODEL_PATH \
--enable_xformers_memory_efficient_attention \
--train_text_encoder \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$DREAMBOOTH_OUTPUT_DIR \
--with_prior_preservation \
--prior_loss_weight=1.0 \
--instance_prompt="a photo of sks person" \
--class_prompt="a photo of person" \
--inference_prompt="a photo of sks person" \
--resolution=384 \
--train_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=1000 \
--checkpointing_steps=1000 \
--center_crop \
--mixed_precision=bf16 \
--prior_generation_precision=bf16 \
--sample_batch_size=1 \
--seed=0
# ------------------------- Inference -------------------------
python ../finetune_infras/infer.py \
--model_path $DREAMBOOTH_OUTPUT_DIR \
--output_dir $OUTPUT_INFER_DIR
# ------------------------- 训练后清空 CLASS_DIR -------------------------
# 注意:这会在 accelerate launch 成功结束后执行
echo "Clearing class directory: $CLASS_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$CLASS_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$CLASS_DIR" -mindepth 1 -delete
echo "Script finished."

@ -0,0 +1,84 @@
#需要环境conda activate pid
### Trianing model
export HF_HUB_OFFLINE=1
# 强制使用本地模型缓存,避免联网下载模型
### SD v2.1
# export HF_HOME="/root/autodl-tmp/huggingface_cache"
# export MODEL_PATH="stabilityai/stable-diffusion-2-1"
### SD v1.5
# export HF_HOME="/root/autodl-tmp/huggingface_cache"
# export MODEL_PATH="runwayml/stable-diffusion-v1-5"
export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
export TASKNAME="task001"
export TYPE="clean" #clean or perturbed
if [ "$TYPE" == "clean" ]; then
export INSTANCE_DIR="../../static/originals/${TASKNAME}"
else
export INSTANCE_DIR="../../static/perturbed/${TASKNAME}"
fi
export DREAMBOOTH_OUTPUT_DIR="../../static/hf_models/fine_tuned/${TYPE}/${TASKNAME}"
export OUTPUT_INFER_DIR="../../static/model_outputs/${TYPE}/${TASKNAME}"
export CLASS_DIR="../../static/class/${TASKNAME}"
# ------------------------- 自动创建依赖路径 -------------------------
echo "Creating required directories..."
mkdir -p "$INSTANCE_DIR"
mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
mkdir -p "$OUTPUT_INFER_DIR"
mkdir -p "$CLASS_DIR"
echo "Directories created successfully."
# ------------------------- 训练前清空 OUTPUT_DIR -------------------------
echo "Clearing output directory: $DREAMBOOTH_OUTPUT_DIR and $OUTPUT_INFER_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
mkdir -p "$OUTPUT_INFER_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$DREAMBOOTH_OUTPUT_DIR" -mindepth 1 -delete
find "$OUTPUT_INFER_DIR" -mindepth 1 -delete
# ------------------------- Fine-tune LoRA on images -------------------------
CUDA_VISIBLE_DEVICES=0 accelerate launch ../finetune_infras/train_lora_gen.py \
--pretrained_model_name_or_path=$MODEL_PATH \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$DREAMBOOTH_OUTPUT_DIR \
--validation_image_output_dir=$OUTPUT_INFER_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks person" \
--class_prompt="a photo of person" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=1000 \
--checkpointing_steps=500 \
--seed=0 \
--mixed_precision=fp16 \
--rank=4 \
--validation_prompt="a photo of sks person" \
--num_validation_images 10 \
# ------------------------- 训练后清空 CLASS_DIR -------------------------
# 注意:这会在 accelerate launch 成功结束后执行
echo "Clearing class directory: $CLASS_DIR"
# 确保目录存在,避免清理命令失败
mkdir -p "$CLASS_DIR"
# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..
find "$CLASS_DIR" -mindepth 1 -delete
echo "Script finished."

@ -1,34 +1,34 @@
"""
认证服务
处理用户认证相关逻辑
"""
from app.models import User
class AuthService:
"""认证服务类"""
@staticmethod
def authenticate_user(username, password):
"""验证用户凭据"""
user = User.query.filter_by(username=username).first()
if user and user.check_password(password) and user.is_active:
return user
return None
@staticmethod
def get_user_by_id(user_id):
"""根据ID获取用户"""
return User.query.get(user_id)
@staticmethod
def is_email_available(email):
"""检查邮箱是否可用"""
return User.query.filter_by(email=email).first() is None
@staticmethod
def is_username_available(username):
"""检查用户名是否可用"""
"""
认证服务
处理用户认证相关逻辑
"""
from app.database import User
class AuthService:
"""认证服务类"""
@staticmethod
def authenticate_user(username, password):
"""验证用户凭据"""
user = User.query.filter_by(username=username).first()
if user and user.check_password(password) and user.is_active:
return user
return None
@staticmethod
def get_user_by_id(user_id):
"""根据ID获取用户"""
return User.query.get(user_id)
@staticmethod
def is_email_available(email):
"""检查邮箱是否可用"""
return User.query.filter_by(email=email).first() is None
@staticmethod
def is_username_available(username):
"""检查用户名是否可用"""
return User.query.filter_by(username=username).first() is None

@ -1,161 +1,257 @@
"""
图像处理服务
处理图像上传保存等功能
"""
import os
import uuid
import zipfile
from werkzeug.utils import secure_filename
from flask import current_app
from PIL import Image as PILImage
from app import db
from app.models import Image
from app.utils.file_utils import allowed_file
class ImageService:
"""图像处理服务"""
@staticmethod
def save_image(file, batch_id, user_id, image_type_id):
"""保存单张图片"""
try:
if not file or not allowed_file(file.filename):
return {'success': False, 'error': '不支持的文件格式'}
# 生成唯一文件名
file_extension = os.path.splitext(file.filename)[1].lower()
stored_filename = f"{uuid.uuid4().hex}{file_extension}"
# 临时保存到上传目录
project_root = os.path.dirname(current_app.root_path)
temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id))
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.join(temp_dir, stored_filename)
file.save(temp_path)
# 移动到对应的静态目录
static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(batch_id))
os.makedirs(static_dir, exist_ok=True)
file_path = os.path.join(static_dir, stored_filename)
# 移动文件到最终位置
import shutil
shutil.move(temp_path, file_path)
# 获取图片尺寸
try:
with PILImage.open(file_path) as img:
width, height = img.size
except:
width, height = None, None
# 创建数据库记录
image = Image(
user_id=user_id,
batch_id=batch_id,
original_filename=file.filename,
stored_filename=stored_filename,
file_path=file_path,
file_size=os.path.getsize(file_path),
image_type_id=image_type_id,
width=width,
height=height
)
db.session.add(image)
db.session.commit()
return {'success': True, 'image': image}
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'保存图片失败: {str(e)}'}
@staticmethod
def extract_and_save_zip(zip_file, batch_id, user_id, image_type_id):
"""解压并保存压缩包中的图片"""
results = []
temp_dir = None
try:
# 创建临时目录
project_root = os.path.dirname(current_app.root_path)
temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], 'temp', f"{uuid.uuid4().hex}")
os.makedirs(temp_dir, exist_ok=True)
# 保存压缩包
zip_path = os.path.join(temp_dir, secure_filename(zip_file.filename))
zip_file.save(zip_path)
# 解压文件
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
# 遍历解压的文件
for root, dirs, files in os.walk(temp_dir):
for filename in files:
if filename.lower().endswith(('.zip', '.rar')):
continue # 跳过压缩包文件本身
if allowed_file(filename):
file_path = os.path.join(root, filename)
# 创建虚拟文件对象
class FileWrapper:
def __init__(self, path, name):
self.path = path
self.filename = name
def save(self, destination):
import shutil
shutil.copy2(self.path, destination)
virtual_file = FileWrapper(file_path, filename)
result = ImageService.save_image(virtual_file, batch_id, user_id, image_type_id)
results.append(result)
return results
except Exception as e:
return [{'success': False, 'error': f'解压失败: {str(e)}'}]
finally:
# 清理临时文件
if temp_dir and os.path.exists(temp_dir):
import shutil
try:
shutil.rmtree(temp_dir)
except:
pass
@staticmethod
def get_image_url(image):
"""获取图片访问URL"""
if not image or not image.file_path:
return None
# 这里返回相对路径前端可以拼接完整URL
return f"/api/image/file/{image.id}"
@staticmethod
def delete_image(image_id, user_id):
"""删除图片"""
try:
image = Image.query.filter_by(id=image_id, user_id=user_id).first()
if not image:
return {'success': False, 'error': '图片不存在或无权限'}
# 删除文件
if os.path.exists(image.file_path):
os.remove(image.file_path)
# 删除数据库记录
db.session.delete(image)
db.session.commit()
return {'success': True}
except Exception as e:
db.session.rollback()
"""
图像处理服务
处理图像上传保存等功能
"""
import os
import uuid
import zipfile
import fcntl
import time
from werkzeug.utils import secure_filename
from flask import current_app
from PIL import Image as PILImage
from app import db
from app.database import Image
from app.utils.file_utils import allowed_file
class ImageService:
@staticmethod
def save_to_uploads(file, batch_id, user_id):
"""
上传图片到uploads临时目录返回临时文件路径和原始文件名
"""
project_root = os.path.dirname(current_app.root_path)
upload_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id))
os.makedirs(upload_dir, exist_ok=True)
orig_ext = os.path.splitext(file.filename)[1].lower()
temp_name = f"{uuid.uuid4().hex}{orig_ext}"
temp_path = os.path.join(upload_dir, temp_name)
file.save(temp_path)
return temp_path, file.filename
@staticmethod
def preprocess_image(temp_path, original_filename, batch_id, user_id, image_type_id, resolution=512, target_format='png'):
"""
对图片进行中心裁剪缩放格式转换重命名保存到static/originals返回数据库对象
原图命名格式: 0000.png, 0001.png, ..., 9999.png
使用数据库事务和重试机制确保并发安全
"""
final_path = None
max_retries = 50
try:
img = PILImage.open(temp_path).convert("RGB")
width, height = img.size
min_dim = min(width, height)
left = (width - min_dim) // 2
top = (height - min_dim) // 2
right = left + min_dim
bottom = top + min_dim
img = img.crop((left, top, right, bottom))
img = img.resize((resolution, resolution), resample=PILImage.Resampling.LANCZOS)
project_root = os.path.dirname(current_app.root_path)
static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(batch_id))
os.makedirs(static_dir, exist_ok=True)
from app.database import ImageType
original_type = ImageType.query.filter_by(type_code='original').first()
target_image_type_id = original_type.id if original_type else image_type_id
# 首次查询最大序号
max_seq_result = db.session.execute(
db.text("""
SELECT COALESCE(MAX(CAST(SUBSTRING_INDEX(stored_filename, '.', 1) AS UNSIGNED)), -1) as max_seq
FROM images
WHERE batch_id = :batch_id
AND image_type_id = :image_type_id
AND stored_filename REGEXP '^[0-9]{4}\\.'
"""),
{'batch_id': batch_id, 'image_type_id': target_image_type_id}
).fetchone()
# 强制类型转换,确保安全
try:
base_sequence = int(max_seq_result[0]) if max_seq_result[0] is not None else -1
except Exception:
base_sequence = -1
base_sequence += 1
# 重试机制从base_sequence开始尝试连续的序号
for attempt in range(max_retries):
sequence_number = int(base_sequence) + int(attempt)
fmt_str = str(target_format).lower() if target_format else 'png'
new_name = f"{sequence_number:04d}.{fmt_str}"
final_path = os.path.join(static_dir, new_name)
try:
# 检查数据库中是否已存在此文件名
existing = Image.query.filter_by(
batch_id=batch_id,
stored_filename=new_name
).first()
if existing:
# 已存在,尝试下一个序号
continue
# 保存图片文件
if target_format.lower() in ['jpg', 'jpeg']:
img.save(final_path, format='JPEG', quality=95)
else:
img.save(final_path, format=target_format.upper())
# 创建数据库记录
image = Image(
user_id=user_id,
batch_id=batch_id,
original_filename=original_filename,
stored_filename=new_name,
file_path=final_path,
file_size=os.path.getsize(final_path),
image_type_id=image_type_id,
width=img.width,
height=img.height
)
db.session.add(image)
db.session.commit()
# 删除临时文件
if os.path.exists(temp_path):
os.remove(temp_path)
return {'success': True, 'image': image}
except Exception as e:
db.session.rollback()
error_msg = str(e)
# 如果是唯一性冲突,清理文件并尝试下一个序号
if 'Duplicate entry' in error_msg or '1062' in error_msg:
if final_path and os.path.exists(final_path):
try:
os.remove(final_path)
except:
pass
# 继续循环尝试下一个序号
time.sleep(0.005)
continue
else:
# 其他错误直接抛出
raise
# 所有尝试都失败
raise Exception(f"无法生成唯一文件名,已尝试序号 {base_sequence}{base_sequence + max_retries - 1}")
except Exception as e:
db.session.rollback()
# 清理可能已保存的文件
if final_path and os.path.exists(final_path):
try:
os.remove(final_path)
except:
pass
return {'success': False, 'error': f'图片预处理失败: {str(e)}'}
"""图像处理服务"""
@staticmethod
def save_image(file, batch_id, user_id, image_type_id, resolution=512, target_format='png'):
"""保存单张图片自动上传到uploads并预处理"""
try:
if not file or not allowed_file(file.filename):
return {'success': False, 'error': '不支持的文件格式'}
temp_path, orig_name = ImageService.save_to_uploads(file, batch_id, user_id)
return ImageService.preprocess_image(temp_path, orig_name, batch_id, user_id, image_type_id, resolution, target_format)
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'保存图片失败: {str(e)}'}
@staticmethod
def extract_and_save_zip(zip_file, batch_id, user_id, image_type_id):
"""解压并保存压缩包中的图片"""
results = []
temp_dir = None
try:
# 创建临时目录
project_root = os.path.dirname(current_app.root_path)
temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], 'temp', f"{uuid.uuid4().hex}")
os.makedirs(temp_dir, exist_ok=True)
# 保存压缩包
zip_path = os.path.join(temp_dir, secure_filename(zip_file.filename))
zip_file.save(zip_path)
# 解压文件
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
# 遍历解压的文件
for root, dirs, files in os.walk(temp_dir):
for filename in files:
if filename.lower().endswith(('.zip', '.rar')):
continue # 跳过压缩包文件本身
if allowed_file(filename):
file_path = os.path.join(root, filename)
# 创建虚拟文件对象
class FileWrapper:
def __init__(self, path, name):
self.path = path
self.filename = name
def save(self, destination):
import shutil
shutil.copy2(self.path, destination)
virtual_file = FileWrapper(file_path, filename)
result = ImageService.save_image(virtual_file, batch_id, user_id, image_type_id)
results.append(result)
return results
except Exception as e:
return [{'success': False, 'error': f'解压失败: {str(e)}'}]
finally:
# 清理临时文件
if temp_dir and os.path.exists(temp_dir):
import shutil
try:
shutil.rmtree(temp_dir)
except:
pass
@staticmethod
def get_image_url(image):
"""获取图片访问URL"""
if not image or not image.file_path:
return None
# 这里返回相对路径前端可以拼接完整URL
return f"/api/image/file/{image.id}"
@staticmethod
def delete_image(image_id, user_id):
"""删除图片"""
try:
image = Image.query.filter_by(id=image_id, user_id=user_id).first()
if not image:
return {'success': False, 'error': '图片不存在或无权限'}
# 删除文件
if os.path.exists(image.file_path):
os.remove(image.file_path)
# 删除数据库记录
db.session.delete(image)
db.session.commit()
return {'success': True}
except Exception as e:
db.session.rollback()
return {'success': False, 'error': f'删除图片失败: {str(e)}'}

@ -1,191 +1,629 @@
"""
任务处理服务
处理图像加噪评估等核心业务逻辑
"""
import os
import time
from datetime import datetime
from flask import current_app
from app import db
from app.models import Batch, Image, EvaluationResult, ImageType
from app.algorithms.perturbation_engine import PerturbationEngine
from app.algorithms.evaluation_engine import EvaluationEngine
class TaskService:
"""任务处理服务"""
@staticmethod
def start_processing(batch):
"""开始处理任务"""
try:
# 更新任务状态
batch.status = 'processing'
batch.started_at = datetime.utcnow()
db.session.commit()
# 获取任务相关的原始图片
original_images = Image.query.filter_by(
batch_id=batch.id
).join(ImageType).filter(
ImageType.type_code == 'original'
).all()
if not original_images:
batch.status = 'failed'
batch.error_message = '没有找到原始图片'
batch.completed_at = datetime.utcnow()
db.session.commit()
return False
# 处理每张图片
perturbed_type = ImageType.query.filter_by(type_code='perturbed').first()
processed_images = []
for original_image in original_images:
try:
# 确定加噪图片的保存路径
project_root = os.path.dirname(current_app.root_path)
perturbed_dir = os.path.join(project_root,
current_app.config['PERTURBED_IMAGES_FOLDER'],
str(batch.user_id),
str(batch.id))
os.makedirs(perturbed_dir, exist_ok=True)
# 调用加噪算法
perturbed_image_path = PerturbationEngine.apply_perturbation(
original_image.file_path,
batch.perturbation_config.method_code,
float(batch.preferred_epsilon),
batch.use_strong_protection
)
if perturbed_image_path:
# 保存加噪后的图片记录
perturbed_image = Image(
user_id=batch.user_id,
batch_id=batch.id,
father_id=original_image.id,
original_filename=f"perturbed_{original_image.original_filename}",
stored_filename=os.path.basename(perturbed_image_path),
file_path=perturbed_image_path,
file_size=os.path.getsize(perturbed_image_path) if os.path.exists(perturbed_image_path) else 0,
image_type_id=perturbed_type.id,
width=original_image.width,
height=original_image.height
)
db.session.add(perturbed_image)
processed_images.append((original_image, perturbed_image))
except Exception as e:
print(f"处理图片 {original_image.id} 时出错: {str(e)}")
continue
# 提交加噪后的图片
db.session.commit()
# 生成评估结果
TaskService._generate_evaluations(batch, processed_images)
# 更新任务状态为完成
batch.status = 'completed'
batch.completed_at = datetime.utcnow()
db.session.commit()
return True
except Exception as e:
# 处理失败
batch.status = 'failed'
batch.error_message = str(e)
batch.completed_at = datetime.utcnow()
db.session.commit()
return False
@staticmethod
def _generate_evaluations(batch, processed_images):
"""生成评估结果"""
try:
for original_image, perturbed_image in processed_images:
# 图像质量对比评估
quality_metrics = EvaluationEngine.evaluate_image_quality(
original_image.file_path,
perturbed_image.file_path
)
quality_evaluation = EvaluationResult(
reference_image_id=original_image.id,
target_image_id=perturbed_image.id,
evaluation_type='image_quality',
purification_applied=False,
fid_score=quality_metrics.get('fid'),
lpips_score=quality_metrics.get('lpips'),
ssim_score=quality_metrics.get('ssim'),
psnr_score=quality_metrics.get('psnr'),
heatmap_path=quality_metrics.get('heatmap_path')
)
db.session.add(quality_evaluation)
# 模型生成对比评估
generation_metrics = EvaluationEngine.evaluate_model_generation(
original_image.file_path,
perturbed_image.file_path,
batch.finetune_config.method_code
)
generation_evaluation = EvaluationResult(
reference_image_id=original_image.id,
target_image_id=perturbed_image.id,
evaluation_type='model_generation',
purification_applied=False,
fid_score=generation_metrics.get('fid'),
lpips_score=generation_metrics.get('lpips'),
ssim_score=generation_metrics.get('ssim'),
psnr_score=generation_metrics.get('psnr'),
heatmap_path=generation_metrics.get('heatmap_path')
)
db.session.add(generation_evaluation)
db.session.commit()
except Exception as e:
print(f"生成评估结果时出错: {str(e)}")
@staticmethod
def get_processing_progress(batch_id):
"""获取处理进度"""
try:
batch = Batch.query.get(batch_id)
if not batch:
return 0
if batch.status == 'pending':
return 0
elif batch.status == 'completed':
return 100
elif batch.status == 'failed':
return 0
elif batch.status == 'processing':
# 简单的进度计算:根据已处理的图片数量
total_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter(
ImageType.type_code == 'original'
).count()
processed_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter(
ImageType.type_code == 'perturbed'
).count()
if total_images == 0:
return 0
progress = int((processed_images / total_images) * 80) # 80%用于图像处理20%用于评估
return min(progress, 95) # 最多95%剩余5%用于最终完成
return 0
except Exception as e:
print(f"获取处理进度时出错: {str(e)}")
return 0
"""
任务处理服务
处理图像加噪评估等核心业务逻辑
使用Redis Queue进行异步任务处理
"""
import os
from datetime import datetime
from flask import current_app
from redis import Redis
from rq import Queue
from rq.job import Job
from app import db
from app.database import Batch, Image, EvaluationResult, ImageType, FinetuneBatch
from config.algorithm_config import AlgorithmConfig
class TaskService:
"""任务处理服务"""
@staticmethod
def _get_redis_connection():
"""获取Redis连接"""
return Redis.from_url(AlgorithmConfig.REDIS_URL)
@staticmethod
def _get_queue():
"""获取RQ队列"""
redis_conn = TaskService._get_redis_connection()
return Queue(AlgorithmConfig.RQ_QUEUE_NAME, connection=redis_conn)
@staticmethod
def start_processing(batch):
"""
开始处理任务异步
Args:
batch: Batch对象
Returns:
任务ID (RQ job id)
"""
try:
# 检查是否有原始图片
original_images = Image.query.filter_by(
batch_id=batch.id
).join(ImageType).filter(
ImageType.type_code == 'original'
).all()
if not original_images:
batch.status = 'failed'
batch.error_message = '没有找到原始图片'
batch.completed_at = datetime.utcnow()
db.session.commit()
return None
# 准备任务参数
project_root = os.path.dirname(current_app.root_path)
# 输入目录(原始图片)
input_dir = os.path.join(
project_root,
current_app.config['ORIGINAL_IMAGES_FOLDER'],
str(batch.user_id),
str(batch.id)
)
# 输出目录(扰动后图片)
output_dir = os.path.join(
project_root,
current_app.config['PERTURBED_IMAGES_FOLDER'],
str(batch.user_id),
str(batch.id)
)
# 类别图片目录用于prior preservation
class_dir = os.path.join(
project_root,
'static', 'class',
str(batch.user_id),
str(batch.id)
)
# 获取队列
queue = TaskService._get_queue()
# 提交任务到队列
from app.workers.perturbation_worker import run_perturbation_task
job = queue.enqueue(
run_perturbation_task,
batch_id=batch.id,
algorithm_code=batch.perturbation_config.method_code,
epsilon=int(batch.preferred_epsilon),
use_strong_protection=batch.use_strong_protection,
input_dir=input_dir,
output_dir=output_dir,
class_dir=class_dir,
custom_params=None,
job_timeout=AlgorithmConfig.TASK_TIMEOUT,
job_id=f"batch_{batch.id}"
)
# 更新任务状态
batch.status = 'queued'
db.session.commit()
return job.id
except Exception as e:
# 处理失败
batch.status = 'failed'
batch.error_message = str(e)
batch.completed_at = datetime.utcnow()
db.session.commit()
return None
@staticmethod
def get_task_status(batch_id):
"""
获取任务状态
Args:
batch_id: 批次ID
Returns:
任务状态信息
"""
try:
batch = Batch.query.get(batch_id)
if not batch:
return {'status': 'not_found'}
# 如果任务已完成或失败,直接返回数据库状态
if batch.status in ['completed', 'failed']:
return {
'status': batch.status,
'error': batch.error_message if batch.status == 'failed' else None,
'started_at': batch.started_at,
'completed_at': batch.completed_at
}
# 尝试从RQ获取任务状态
try:
redis_conn = TaskService._get_redis_connection()
job = Job.fetch(f"batch_{batch_id}", connection=redis_conn)
rq_status = job.get_status()
# 映射RQ状态到我们的状态
status_map = {
'queued': 'queued',
'started': 'processing',
'finished': 'completed',
'failed': 'failed'
}
return {
'status': status_map.get(rq_status, batch.status),
'rq_status': rq_status,
'progress': job.meta.get('progress', 0) if hasattr(job, 'meta') else 0,
'started_at': batch.started_at,
'result': job.result if rq_status == 'finished' else None,
'error': str(job.exc_info) if rq_status == 'failed' else None
}
except:
# 如果无法从RQ获取状态返回数据库状态
return {
'status': batch.status,
'started_at': batch.started_at
}
except Exception as e:
return {'status': 'error', 'error': str(e)}
@staticmethod
def cancel_task(batch_id):
"""
取消任务
Args:
batch_id: 批次ID
Returns:
是否成功取消
"""
try:
batch = Batch.query.get(batch_id)
if not batch:
return False
# 尝试从队列中删除任务
try:
redis_conn = TaskService._get_redis_connection()
job = Job.fetch(f"batch_{batch_id}", connection=redis_conn)
job.cancel()
except:
pass
# 更新数据库状态
batch.status = 'failed'
batch.error_message = 'Task cancelled by user'
batch.completed_at = datetime.utcnow()
db.session.commit()
return True
except Exception as e:
print(f"取消任务时出错: {str(e)}")
return False
@staticmethod
def process_results_and_evaluations(batch_id):
"""
处理任务结果并生成评估在worker完成后调用
Args:
batch_id: 批次ID
"""
try:
batch = Batch.query.get(batch_id)
if not batch:
return
# 获取输出目录中的图片
project_root = os.path.dirname(current_app.root_path)
output_dir = os.path.join(
project_root,
current_app.config['PERTURBED_IMAGES_FOLDER'],
str(batch.user_id),
str(batch.id)
)
# 获取原始图片
original_images = Image.query.filter_by(
batch_id=batch.id
).join(ImageType).filter(
ImageType.type_code == 'original'
).all()
perturbed_type = ImageType.query.filter_by(type_code='perturbed').first()
processed_images = []
# 为每张原始图片找到对应的扰动图片
for original_image in original_images:
# 构建扰动图片路径
original_name = os.path.splitext(original_image.original_filename)[0]
original_ext = os.path.splitext(original_image.original_filename)[1]
perturbed_filename = f"{original_name}_perturbed{original_ext}"
perturbed_path = os.path.join(output_dir, perturbed_filename)
if os.path.exists(perturbed_path):
# 保存扰动图片到数据库
perturbed_image = Image(
user_id=batch.user_id,
batch_id=batch.id,
father_id=original_image.id,
original_filename=f"perturbed_{original_image.original_filename}",
stored_filename=os.path.basename(perturbed_path),
file_path=perturbed_path,
file_size=os.path.getsize(perturbed_path),
image_type_id=perturbed_type.id,
width=original_image.width,
height=original_image.height
)
db.session.add(perturbed_image)
processed_images.append((original_image, perturbed_image))
db.session.commit()
# 生成评估结果
TaskService._generate_evaluations(batch, processed_images)
except Exception as e:
print(f"处理结果时出错: {str(e)}") @staticmethod
def _generate_evaluations(batch, processed_images):
"""生成评估结果(虚拟实现)"""
try:
for original_image, perturbed_image in processed_images:
# TODO: 实现真实的评估引擎
# 目前使用虚拟数据
import random
# 图像质量对比评估(虚拟数据)
quality_evaluation = EvaluationResult(
reference_image_id=original_image.id,
target_image_id=perturbed_image.id,
evaluation_type='image_quality',
purification_applied=False,
fid_score=round(random.uniform(0.1, 0.5), 4),
lpips_score=round(random.uniform(0.01, 0.1), 4),
ssim_score=round(random.uniform(0.85, 0.99), 4),
psnr_score=round(random.uniform(30, 45), 2),
heatmap_path=None
)
db.session.add(quality_evaluation)
# 模型生成对比评估(虚拟数据)
generation_evaluation = EvaluationResult(
reference_image_id=original_image.id,
target_image_id=perturbed_image.id,
evaluation_type='model_generation',
purification_applied=False,
fid_score=round(random.uniform(0.2, 0.8), 4),
lpips_score=round(random.uniform(0.05, 0.2), 4),
ssim_score=round(random.uniform(0.7, 0.9), 4),
psnr_score=round(random.uniform(25, 40), 2),
heatmap_path=None
)
db.session.add(generation_evaluation)
db.session.commit()
except Exception as e:
print(f"生成评估结果时出错: {str(e)}")
@staticmethod
def get_processing_progress(batch_id):
"""获取处理进度"""
try:
batch = Batch.query.get(batch_id)
if not batch:
return 0
if batch.status == 'pending':
return 0
elif batch.status == 'completed':
return 100
elif batch.status == 'failed':
return 0
elif batch.status == 'processing':
# 简单的进度计算:根据已处理的图片数量
total_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter(
ImageType.type_code == 'original'
).count()
processed_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter(
ImageType.type_code == 'perturbed'
).count()
if total_images == 0:
return 0
progress = int((processed_images / total_images) * 80) # 80%用于图像处理20%用于评估
return min(progress, 95) # 最多95%剩余5%用于最终完成
return 0
except Exception as e:
print(f"获取处理进度时出错: {str(e)}")
return 0
@staticmethod
def start_finetune_task(finetune_task):
"""
启动微调任务使用 FinetuneBatch
Args:
finetune_task: FinetuneBatch对象
Returns:
包含两个job_id的字典
"""
try:
# 获取关联的扰动任务
batch = finetune_task.batch
if not batch:
print(f"FinetuneBatch {finetune_task.id} 没有关联的扰动任务")
return None
# 检查是否有扰动图片
perturbed_images = Image.query.filter_by(
batch_id=batch.id
).join(ImageType).filter(
ImageType.type_code == 'perturbed'
).all()
if not perturbed_images:
print(f"Batch {batch.id} 没有扰动图片,无法启动微调任务")
finetune_task.status = 'failed'
finetune_task.error_message = '没有找到扰动图片'
db.session.commit()
return None
project_root = os.path.dirname(current_app.root_path)
finetune_method = finetune_task.finetune_config.method_code
queue = TaskService._get_queue()
from app.workers.finetune_worker import run_finetune_task
# 原始图片目录
original_dir = os.path.join(
project_root,
current_app.config['ORIGINAL_IMAGES_FOLDER'],
str(batch.user_id),
str(batch.id)
)
# 扰动图片目录
perturbed_dir = os.path.join(
project_root,
current_app.config['PERTURBED_IMAGES_FOLDER'],
str(batch.user_id),
str(batch.id)
)
# 模型输出目录
original_model_dir = os.path.join(
project_root,
current_app.config['MODEL_ORIGINAL_FOLDER'],
str(batch.user_id), str(batch.id)
)
perturbed_model_dir = os.path.join(
project_root,
current_app.config['MODEL_PERTURBED_FOLDER'],
str(batch.user_id), str(batch.id)
)
# 类别图片目录(微调用)
class_finetune_dir = os.path.join(
project_root, 'static', 'class_finetune',
str(batch.user_id), str(batch.id)
)
# 推理提示词
inference_prompts = "a photo of sks person"
# 1. 用原始图片微调模型
job_original = queue.enqueue(
run_finetune_task,
finetune_batch_id=finetune_task.id,
batch_id=batch.id,
finetune_method=finetune_method,
train_images_dir=original_dir,
output_model_dir=original_model_dir,
class_dir=class_finetune_dir,
inference_prompts=inference_prompts,
is_perturbed=False,
custom_params=None,
job_timeout=AlgorithmConfig.TASK_TIMEOUT,
job_id=f"finetune_original_{finetune_task.id}"
)
# 2. 用扰动图片微调模型(依赖于原始图片微调完成)
job_perturbed = queue.enqueue(
run_finetune_task,
finetune_batch_id=finetune_task.id,
batch_id=batch.id,
finetune_method=finetune_method,
train_images_dir=perturbed_dir,
output_model_dir=perturbed_model_dir,
class_dir=class_finetune_dir,
inference_prompts=inference_prompts,
is_perturbed=True,
custom_params=None,
job_timeout=AlgorithmConfig.TASK_TIMEOUT,
job_id=f"finetune_perturbed_{finetune_task.id}",
depends_on=job_original
)
# 更新微调任务状态
finetune_task.status = 'queued'
finetune_task.original_job_id = job_original.id
finetune_task.perturbed_job_id = job_perturbed.id
finetune_task.started_at = datetime.utcnow()
db.session.commit()
return {
'original_job_id': job_original.id,
'perturbed_job_id': job_perturbed.id
}
except Exception as e:
print(f"启动微调任务时出错: {str(e)}")
finetune_task.status = 'failed'
finetune_task.error_message = str(e)
finetune_task.completed_at = datetime.utcnow()
db.session.commit()
return None
@staticmethod
def get_finetune_task_status(finetune_id):
"""
获取微调任务状态使用 FinetuneBatch
同时会检查并更新任务状态
Args:
finetune_id: 微调任务ID
Returns:
微调任务状态信息
"""
try:
finetune_task = FinetuneBatch.query.get(finetune_id)
if not finetune_task:
return {'status': 'not_found'}
# 如果任务不是最终状态,检查并更新状态
if finetune_task.status not in ['completed', 'failed']:
from app.workers.finetune_worker import _check_and_update_finetune_status
_check_and_update_finetune_status(finetune_task)
# 刷新对象以获取最新状态
db.session.refresh(finetune_task)
# 如果任务已完成或失败,直接返回数据库状态
if finetune_task.status in ['completed', 'failed']:
return {
'status': finetune_task.status,
'error': finetune_task.error_message if finetune_task.status == 'failed' else None,
'started_at': finetune_task.started_at.isoformat() if finetune_task.started_at else None,
'completed_at': finetune_task.completed_at.isoformat() if finetune_task.completed_at else None
}
# 从RQ获取任务状态
redis_conn = TaskService._get_redis_connection()
original_job_status = 'not_found'
perturbed_job_status = 'not_found'
try:
if finetune_task.original_job_id:
original_job = Job.fetch(finetune_task.original_job_id, connection=redis_conn)
original_job_status = original_job.get_status()
except:
pass
try:
if finetune_task.perturbed_job_id:
perturbed_job = Job.fetch(finetune_task.perturbed_job_id, connection=redis_conn)
perturbed_job_status = perturbed_job.get_status()
except:
pass
# 映射状态
status_map = {
'queued': 'queued',
'started': 'processing',
'finished': 'completed',
'failed': 'failed',
'not_found': 'not_started'
}
return {
'status': finetune_task.status,
'original_finetune': status_map.get(original_job_status, 'unknown'),
'perturbed_finetune': status_map.get(perturbed_job_status, 'unknown'),
'started_at': finetune_task.started_at.isoformat() if finetune_task.started_at else None
}
except Exception as e:
print(f"获取微调任务状态时出错: {str(e)}")
return {'status': 'error', 'error': str(e)}
@staticmethod
def generate_final_evaluations(batch_id):
"""
生成最终评估对比原始和扰动图片微调后的模型生成效果
此方法在两个微调任务都完成后调用
Args:
batch_id: 批次ID
"""
try:
batch = Batch.query.get(batch_id)
if not batch:
return
# 获取原始图片生成的结果
original_generated = Image.query.filter_by(
batch_id=batch_id
).join(ImageType).filter(
ImageType.type_code == 'original_generate'
).all()
# 获取扰动图片生成的结果
perturbed_generated = Image.query.filter_by(
batch_id=batch_id
).join(ImageType).filter(
ImageType.type_code == 'perturbed_generate'
).all()
if not original_generated or not perturbed_generated:
print(f"Batch {batch_id} 缺少生成的图片,无法评估")
return
# 配对评估
for orig_gen in original_generated:
# 找到对应的扰动生成图片(基于相同的父图片)
matching_pert_gen = None
for pert_gen in perturbed_generated:
# 尝试匹配文件名或父图片关系
if pert_gen.original_filename.replace('generated_', '') == orig_gen.original_filename.replace('generated_', ''):
matching_pert_gen = pert_gen
break
if matching_pert_gen:
# TODO: 实现真实的评估引擎
# 目前使用虚拟数据
import random
# 保存评估结果(虚拟数据)
generation_evaluation = EvaluationResult(
reference_image_id=orig_gen.id,
target_image_id=matching_pert_gen.id,
evaluation_type='model_generation',
purification_applied=False,
fid_score=round(random.uniform(0.3, 0.9), 4),
lpips_score=round(random.uniform(0.1, 0.3), 4),
ssim_score=round(random.uniform(0.6, 0.85), 4),
psnr_score=round(random.uniform(20, 35), 2),
heatmap_path=None
)
db.session.add(generation_evaluation)
db.session.commit()
print(f"Batch {batch_id} 最终评估完成")
except Exception as e:
print(f"生成最终评估时出错: {str(e)}")
db.session.rollback()

@ -1,53 +1,53 @@
"""
文件处理工具类
"""
import os
from werkzeug.utils import secure_filename
from flask import current_app
def allowed_file(filename):
"""检查文件扩展名是否被允许"""
if not filename:
return False
allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS',
{'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'})
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in allowed_extensions
def save_uploaded_file(file, upload_path):
"""保存上传的文件"""
try:
if not file or not allowed_file(file.filename):
return None
filename = secure_filename(file.filename)
# 确保目录存在
os.makedirs(os.path.dirname(upload_path), exist_ok=True)
file.save(upload_path)
return upload_path
except Exception as e:
print(f"保存文件失败: {str(e)}")
return None
def get_file_size(file_path):
"""获取文件大小"""
try:
return os.path.getsize(file_path)
except:
return 0
def delete_file(file_path):
"""删除文件"""
try:
if os.path.exists(file_path):
os.remove(file_path)
return True
except:
pass
"""
文件处理工具类
"""
import os
from werkzeug.utils import secure_filename
from flask import current_app
def allowed_file(filename):
"""检查文件扩展名是否被允许"""
if not filename:
return False
allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS',
{'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'})
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in allowed_extensions
def save_uploaded_file(file, upload_path):
"""保存上传的文件"""
try:
if not file or not allowed_file(file.filename):
return None
filename = secure_filename(file.filename)
# 确保目录存在
os.makedirs(os.path.dirname(upload_path), exist_ok=True)
file.save(upload_path)
return upload_path
except Exception as e:
print(f"保存文件失败: {str(e)}")
return None
def get_file_size(file_path):
"""获取文件大小"""
try:
return os.path.getsize(file_path)
except:
return 0
def delete_file(file_path):
"""删除文件"""
try:
if os.path.exists(file_path):
os.remove(file_path)
return True
except:
pass
return False

@ -1,16 +1,16 @@
"""
JWT工具函数
"""
from functools import wraps
from flask_jwt_extended import get_jwt_identity
def int_jwt_required(f):
"""获取JWT身份并转换为整数的装饰器"""
@wraps(f)
def wrapped(*args, **kwargs):
jwt_identity = get_jwt_identity()
if jwt_identity is not None:
# 在函数调用前注入转换后的user_id
kwargs['current_user_id'] = int(jwt_identity)
return f(*args, **kwargs)
"""
JWT工具函数
"""
from functools import wraps
from flask_jwt_extended import get_jwt_identity
def int_jwt_required(f):
"""获取JWT身份并转换为整数的装饰器"""
@wraps(f)
def wrapped(*args, **kwargs):
jwt_identity = get_jwt_identity()
if jwt_identity is not None:
# 在函数调用前注入转换后的user_id
kwargs['current_user_id'] = int(jwt_identity)
return f(*args, **kwargs)
return wrapped

@ -0,0 +1,489 @@
"""
RQ Worker 微调任务处理器
在后台执行模型微调任务
"""
import os
import subprocess
import logging
from datetime import datetime
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def _check_and_update_finetune_status(finetune_task):
"""
检查微调任务状态并更新
当原始和扰动图片的微调都完成时更新任务状态为completed
Args:
finetune_task: FinetuneBatch对象
"""
from app import db
from rq.job import Job
from redis import Redis
from config.algorithm_config import AlgorithmConfig
try:
# 刷新数据库对象,确保获取最新状态
db.session.refresh(finetune_task)
# 如果状态已经是completed或failed不再检查
if finetune_task.status in ['completed', 'failed']:
return
redis_conn = Redis.from_url(AlgorithmConfig.REDIS_URL)
original_job_done = False
perturbed_job_done = False
has_original_job = False
has_perturbed_job = False
# 检查原始图片微调任务
if finetune_task.original_job_id:
has_original_job = True
try:
original_job = Job.fetch(finetune_task.original_job_id, connection=redis_conn)
status = original_job.get_status()
logger.info(f"Original job {finetune_task.original_job_id} status: {status}")
if status == 'finished':
original_job_done = True
elif status == 'failed':
# 如果原始任务失败,整个微调任务标记为失败
finetune_task.status = 'failed'
finetune_task.error_message = f"Original finetune job failed: {original_job.exc_info}"
finetune_task.completed_at = datetime.utcnow()
db.session.commit()
logger.error(f"FinetuneBatch {finetune_task.id} failed: original job failed")
return
except Exception as e:
logger.error(f"Error checking original job: {str(e)}")
# 检查扰动图片微调任务
if finetune_task.perturbed_job_id:
has_perturbed_job = True
try:
perturbed_job = Job.fetch(finetune_task.perturbed_job_id, connection=redis_conn)
status = perturbed_job.get_status()
logger.info(f"Perturbed job {finetune_task.perturbed_job_id} status: {status}")
if status == 'finished':
perturbed_job_done = True
elif status == 'failed':
# 如果扰动任务失败,整个微调任务标记为失败
finetune_task.status = 'failed'
finetune_task.error_message = f"Perturbed finetune job failed: {perturbed_job.exc_info}"
finetune_task.completed_at = datetime.utcnow()
db.session.commit()
logger.error(f"FinetuneBatch {finetune_task.id} failed: perturbed job failed")
return
except Exception as e:
logger.error(f"Error checking perturbed job: {str(e)}")
# 如果两个任务都完成更新状态为completed
if has_original_job and has_perturbed_job and original_job_done and perturbed_job_done:
finetune_task.status = 'completed'
finetune_task.completed_at = datetime.utcnow()
db.session.commit()
logger.info(f"FinetuneBatch {finetune_task.id} completed - both jobs finished")
else:
logger.info(f"FinetuneBatch {finetune_task.id} not all jobs finished yet: original={original_job_done}, perturbed={perturbed_job_done}")
except Exception as e:
logger.error(f"Error checking finetune status: {str(e)}", exc_info=True)
def run_finetune_task(finetune_batch_id, batch_id, finetune_method, train_images_dir, output_model_dir,
class_dir, inference_prompts, is_perturbed=False, custom_params=None):
"""
执行微调任务
Args:
finetune_batch_id: 微调任务ID
batch_id: 扰动任务批次ID
finetune_method: 微调方法 (dreambooth, lora)
train_images_dir: 训练图片目录原始或扰动
output_model_dir: 模型输出目录
class_dir: 类别图片目录
inference_prompts: 推理提示词
is_perturbed: 是否是扰动图片训练
custom_params: 自定义参数
Returns:
任务执行结果
"""
from config.algorithm_config import AlgorithmConfig
from app import create_app, db
from app.database import FinetuneBatch, Batch, Image, ImageType
app = create_app()
with app.app_context():
try:
finetune_task = FinetuneBatch.query.get(finetune_batch_id)
if not finetune_task:
raise ValueError(f"FinetuneBatch {finetune_batch_id} not found")
batch = Batch.query.get(batch_id)
if not batch:
raise ValueError(f"Batch {batch_id} not found")
# 更新微调任务状态为处理中
if finetune_task.status == 'queued':
finetune_task.status = 'processing'
db.session.commit()
logger.info(f"Starting finetune task for FinetuneBatch {finetune_batch_id}, Batch {batch_id}")
logger.info(f"Method: {finetune_method}, Perturbed: {is_perturbed}")
# 确保目录存在
os.makedirs(output_model_dir, exist_ok=True)
os.makedirs(class_dir, exist_ok=True)
# 获取配置
use_real = AlgorithmConfig.USE_REAL_ALGORITHMS
if use_real:
# 使用真实微调算法
result = _run_real_finetune(
finetune_method, batch_id, train_images_dir, output_model_dir,
class_dir, inference_prompts, is_perturbed, custom_params
)
else:
# 使用虚拟微调实现
result = _run_virtual_finetune(
finetune_method, batch_id, train_images_dir, output_model_dir,
is_perturbed
)
# 保存生成的图片到数据库
_save_generated_images(batch_id, output_model_dir, is_perturbed)
# 检查两个任务是否都已完成
_check_and_update_finetune_status(finetune_task)
logger.info(f"Finetune task completed for FinetuneBatch {finetune_batch_id}")
return result
except Exception as e:
logger.error(f"Finetune task failed for FinetuneBatch {finetune_batch_id}: {str(e)}", exc_info=True)
# 更新微调任务状态为失败
if finetune_task:
finetune_task.status = 'failed'
finetune_task.error_message = str(e)
finetune_task.completed_at = datetime.utcnow()
db.session.commit()
raise
def _run_real_finetune(finetune_method, batch_id, train_images_dir, output_model_dir,
class_dir, inference_prompts, is_perturbed, custom_params):
"""运行真实微调算法"""
from config.algorithm_config import AlgorithmConfig
logger.info(f"Running real finetune: {finetune_method}")
# 获取微调脚本路径和环境
finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {})
script_path = finetune_config.get('real_script')
conda_env = finetune_config.get('conda_env')
default_params = finetune_config.get('default_params', {})
if not script_path:
raise ValueError(f"Finetune method {finetune_method} not configured")
# 合并参数
params = {**default_params, **(custom_params or {})}
# 构建命令行参数
cmd_args = [
f"--instance_data_dir={train_images_dir}",
f"--output_dir={output_model_dir}",
f"--class_data_dir={class_dir}",
]
# 添加is_perturbed标志
if is_perturbed:
cmd_args.append("--is_perturbed")
# 添加其他参数
for key, value in params.items():
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# 构建完整命令
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'accelerate', 'launch', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
image_type = 'perturbed' if is_perturbed else 'original'
log_file = os.path.join(
log_dir,
f'finetune_{image_type}_{batch_id}_{finetune_method}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True
)
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
if process.returncode != 0:
raise RuntimeError(f"Finetune failed with code {process.returncode}. Check log: {log_file}")
# 清理class_dir
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
import shutil
for item in os.listdir(class_dir):
item_path = os.path.join(class_dir, item)
if os.path.isfile(item_path):
os.remove(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
# 清理output_model_dir中的非图片文件
logger.info(f"Cleaning non-image files in output directory: {output_model_dir}")
if os.path.exists(output_model_dir):
import shutil
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'}
for item in os.listdir(output_model_dir):
item_path = os.path.join(output_model_dir, item)
# 如果是目录,直接删除
if os.path.isdir(item_path):
logger.info(f"Removing directory: {item_path}")
shutil.rmtree(item_path)
# 如果是文件,检查是否为图片
elif os.path.isfile(item_path):
_, ext = os.path.splitext(item.lower())
if ext not in image_extensions:
logger.info(f"Removing non-image file: {item_path}")
os.remove(item_path)
return {
'status': 'success',
'output_dir': output_model_dir,
'log_file': log_file
}
def _run_virtual_finetune(finetune_method, batch_id, train_images_dir, output_model_dir, is_perturbed):
"""运行虚拟微调实现"""
from config.algorithm_config import AlgorithmConfig
import glob
logger.info(f"Running virtual finetune: {finetune_method}")
# 获取微调配置
finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {})
if not finetune_config:
raise ValueError(f"Finetune method {finetune_method} not configured")
conda_env = finetune_config.get('conda_env')
default_params = finetune_config.get('default_params', {})
# 获取虚拟微调脚本路径
script_name = 'train_dreambooth_gen.py' if finetune_method == 'dreambooth' else 'train_lora_gen.py'
script_path = os.path.abspath(os.path.join(
os.path.dirname(__file__),
'../algorithms/finetune_virtual',
script_name
))
if not os.path.exists(script_path):
raise FileNotFoundError(f"Virtual finetune script not found: {script_path}")
logger.info(f"Virtual script path: {script_path}")
logger.info(f"Conda environment: {conda_env}")
# 创建输出目录
os.makedirs(output_model_dir, exist_ok=True)
validation_output_dir = os.path.join(output_model_dir, 'generated')
os.makedirs(validation_output_dir, exist_ok=True)
# 构建命令行参数(与真实微调参数一致)
cmd_args = [
f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}",
f"--instance_data_dir={train_images_dir}",
f"--output_dir={output_model_dir}",
f"--validation_image_output_dir={validation_output_dir}",
f"--class_data_dir=/tmp/class_placeholder",
]
# 添加is_perturbed标志
if is_perturbed:
cmd_args.append("--is_perturbed")
# 添加其他默认参数
for key, value in default_params.items():
if key == 'pretrained_model_name_or_path':
continue # 已添加
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# 使用conda run执行虚拟脚本
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'python', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
image_type = 'perturbed' if is_perturbed else 'original'
log_file = os.path.join(
log_dir,
f'virtual_{finetune_method}_{image_type}_{batch_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True
)
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
if process.returncode != 0:
raise RuntimeError(f"Virtual finetune failed with code {process.returncode}. Check log: {log_file}")
# 统计生成的图片
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
generated_files = []
for ext in image_extensions:
generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext)))
generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext.upper())))
logger.info(f"Virtual finetune completed. Generated {len(generated_files)} images")
return {
'status': 'success',
'output_dir': output_model_dir,
'generated_count': len(generated_files),
'generated_files': generated_files,
'log_file': log_file
}
def _save_generated_images(batch_id, output_model_dir, is_perturbed):
"""保存生成的图片到数据库"""
from app import db
from app.database import Batch, Image, ImageType
import glob
try:
batch = Batch.query.get(batch_id)
if not batch:
return
# 确定图片类型
if is_perturbed:
image_type = ImageType.query.filter_by(type_code='perturbed_generate').first()
else:
image_type = ImageType.query.filter_by(type_code='original_generate').first()
if not image_type:
logger.error(f"Image type not found for is_perturbed={is_perturbed}")
return
# 查找生成的图片
generated_dir = os.path.join(output_model_dir, 'generated')
if not os.path.exists(generated_dir):
# 尝试直接从output_model_dir查找
generated_dir = output_model_dir
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(generated_dir, ext)))
image_files.extend(glob.glob(os.path.join(generated_dir, ext.upper())))
logger.info(f"Found {len(image_files)} generated images to save")
# 保存到数据库
saved_count = 0
for image_path in image_files:
try:
from PIL import Image as PILImage
filename = os.path.basename(image_path)
# 检查是否已经保存过使用filename作为stored_filename
existing = Image.query.filter_by(
batch_id=batch_id,
stored_filename=filename
).first()
if existing:
logger.info(f"Image already exists: {filename}")
continue
with PILImage.open(image_path) as img:
width, height = img.size
# 生成图片不设置父图片关系(多对多关系,无法确定具体父图片)
# 创建图片记录直接使用filename算法已经生成了正确格式
generated_image = Image(
user_id=batch.user_id,
batch_id=batch_id,
father_id=None, # 微调生成图片无特定父图片
original_filename=filename,
stored_filename=filename, # 算法输出已经是正确格式
file_path=image_path,
file_size=os.path.getsize(image_path),
image_type_id=image_type.id,
width=width,
height=height
)
db.session.add(generated_image)
saved_count += 1
logger.info(f"Saved generated image: {filename}")
except Exception as e:
logger.error(f"Failed to save {image_path}: {str(e)}")
db.session.commit()
logger.info(f"Successfully saved {saved_count} generated images to database")
except Exception as e:
logger.error(f"Error saving generated images: {str(e)}")
db.session.rollback()

@ -0,0 +1,422 @@
"""
RQ Worker任务处理器
在后台执行对抗性扰动算法
"""
import os
import sys
import subprocess
import logging
from datetime import datetime
from pathlib import Path
# 设置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def run_perturbation_task(batch_id, algorithm_code, epsilon, use_strong_protection,
input_dir, output_dir, class_dir, custom_params=None):
"""
执行对抗性扰动任务
Args:
batch_id: 任务批次ID
algorithm_code: 算法代码
epsilon: 扰动强度
use_strong_protection: 是否使用防净化版本
input_dir: 输入图片目录
output_dir: 输出目录
class_dir: 类别图片目录
custom_params: 自定义参数
Returns:
任务执行结果
"""
from config.algorithm_config import AlgorithmConfig
from app import create_app, db
from app.database import Batch
# 创建应用上下文
app = create_app()
with app.app_context():
try:
# 更新任务状态
batch = Batch.query.get(batch_id)
if not batch:
raise ValueError(f"Batch {batch_id} not found")
batch.status = 'processing'
batch.started_at = datetime.utcnow()
db.session.commit()
logger.info(f"Starting perturbation task for batch {batch_id}")
logger.info(f"Algorithm: {algorithm_code}, Epsilon: {epsilon}")
# 获取算法配置
use_real = AlgorithmConfig.USE_REAL_ALGORITHMS
script_path = AlgorithmConfig.get_script_path(algorithm_code)
conda_env = AlgorithmConfig.get_conda_env(algorithm_code)
# 确保目录存在
os.makedirs(output_dir, exist_ok=True)
os.makedirs(class_dir, exist_ok=True)
if use_real:
# 使用真实算法
result = _run_real_algorithm(
script_path, conda_env, algorithm_code, batch_id,
epsilon, use_strong_protection, input_dir, output_dir,
class_dir, custom_params
)
else:
# 使用虚拟实现
result = _run_virtual_algorithm(
algorithm_code, batch_id, epsilon, use_strong_protection,
input_dir, output_dir
)
# 更新任务状态为完成
batch.status = 'completed'
batch.completed_at = datetime.utcnow()
db.session.commit()
# 保存扰动图片到数据库
_save_perturbed_images(batch_id, output_dir)
logger.info(f"Task completed successfully for batch {batch_id}")
return result
except Exception as e:
logger.error(f"Task failed for batch {batch_id}: {str(e)}", exc_info=True)
# 更新任务状态为失败
if batch:
batch.status = 'failed'
batch.error_message = str(e)
batch.completed_at = datetime.utcnow()
db.session.commit()
raise
def _run_real_algorithm(script_path, conda_env, algorithm_code, batch_id,
epsilon, use_strong_protection, input_dir, output_dir,
class_dir, custom_params):
"""运行真实算法"""
from config.algorithm_config import AlgorithmConfig
logger.info(f"Running real algorithm: {algorithm_code}")
logger.info(f"Conda environment: {conda_env}")
logger.info(f"Script path: {script_path}")
# 获取默认参数
default_params = AlgorithmConfig.get_default_params(algorithm_code)
# 合并自定义参数
params = {**default_params, **(custom_params or {})}
cmd_args = []
if algorithm_code == 'aspl':
cmd_args.extend([
f"--instance_data_dir_for_train={input_dir}",
f"--instance_data_dir_for_adversarial={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir={class_dir}",
f"--pgd_eps={str(epsilon)}",
])
elif algorithm_code == 'simac':
cmd_args.extend([
f"--instance_data_dir_for_train={input_dir}",
f"--instance_data_dir_for_adversarial={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir={class_dir}",
f"--pgd_eps={str(epsilon)}",
])
elif algorithm_code == 'caat':
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={str(epsilon)}",
])
elif algorithm_code == 'pid':
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={str(epsilon)}",
])
else:
raise ValueError(f"Unsupported algorithm code: {algorithm_code}")
# 添加其他参数
for key, value in params.items():
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# 构建完整命令
# 使用conda run避免环境嵌套问题
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'accelerate', 'launch', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f'batch_{batch_id}_{algorithm_code}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True
)
# 实时输出日志
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
logger.info(f"output_dir: {output_dir}")
logger.info(f"log_file: {log_file}")
if process.returncode != 0:
raise RuntimeError(f"Algorithm execution failed with code {process.returncode}. Check log: {log_file}")
# 清理class_dir
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
import shutil
for item in os.listdir(class_dir):
item_path = os.path.join(class_dir, item)
if os.path.isfile(item_path):
os.remove(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
return {
'status': 'success',
'output_dir': output_dir,
'log_file': log_file
}
def _run_virtual_algorithm(algorithm_code, batch_id, epsilon, use_strong_protection,
input_dir, output_dir):
"""运行虚拟算法实现"""
from config.algorithm_config import AlgorithmConfig
import glob
logger.info(f"Running virtual algorithm: {algorithm_code}")
# 获取算法配置
algo_config = AlgorithmConfig.PERTURBATION_SCRIPTS.get(algorithm_code)
if not algo_config:
raise ValueError(f"Algorithm {algorithm_code} not configured")
conda_env = algo_config.get('conda_env')
default_params = algo_config.get('default_params', {})
# 获取虚拟算法脚本路径
script_path = os.path.abspath(os.path.join(
os.path.dirname(__file__),
'../algorithms/perturbation_virtual',
f'{algorithm_code}.py'
))
if not os.path.exists(script_path):
raise FileNotFoundError(f"Virtual script not found: {script_path}")
logger.info(f"Virtual script path: {script_path}")
logger.info(f"Conda environment: {conda_env}")
# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)
# 构建命令行参数(与真实算法参数一致)
cmd_args = [
f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}",
f"--instance_data_dir_for_train={input_dir}",
f"--instance_data_dir_for_adversarial={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir=/tmp/class_placeholder",
f"--pgd_eps={epsilon}",
]
# 添加其他默认参数
for key, value in default_params.items():
if key == 'pretrained_model_name_or_path':
continue # 已添加
if isinstance(value, bool):
if value:
cmd_args.append(f"--{key}")
else:
cmd_args.append(f"--{key}={value}")
# 使用conda run执行虚拟脚本
cmd = [
'/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output',
'python', script_path
] + cmd_args
logger.info(f"Executing command: {' '.join(cmd)}")
# 设置日志文件
log_dir = AlgorithmConfig.LOGS_DIR
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(
log_dir,
f'virtual_{algorithm_code}_{batch_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 执行命令
with open(log_file, 'w') as f:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True
)
# 实时输出日志
for line in process.stdout:
f.write(line)
f.flush()
logger.info(line.strip())
process.wait()
if process.returncode != 0:
raise RuntimeError(f"Virtual algorithm failed with code {process.returncode}. Check log: {log_file}")
# 统计处理的图片
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
processed_files = []
for ext in image_extensions:
processed_files.extend(glob.glob(os.path.join(output_dir, ext)))
processed_files.extend(glob.glob(os.path.join(output_dir, ext.upper())))
logger.info(f"Virtual algorithm completed. Processed {len(processed_files)} images")
return {
'status': 'success',
'output_dir': output_dir,
'processed_count': len(processed_files),
'processed_files': processed_files,
'log_file': log_file
}
def _save_perturbed_images(batch_id, output_dir):
"""保存扰动图片到数据库"""
from app import db
from app.database import Batch, Image, ImageType
import glob
from PIL import Image as PILImage
try:
batch = Batch.query.get(batch_id)
if not batch:
logger.error(f"Batch {batch_id} not found")
return
# 获取扰动图片类型
perturbed_type = ImageType.query.filter_by(type_code='perturbed').first()
if not perturbed_type:
logger.error("Perturbed image type not found")
return
# 获取原始图片列表
original_type = ImageType.query.filter_by(type_code='original').first()
original_images = Image.query.filter_by(
batch_id=batch_id,
image_type_id=original_type.id
).all()
# 创建原图映射字典: stored_filename -> Image对象
original_map = {img.stored_filename: img for img in original_images}
# 查找输出目录中的扰动图片
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
perturbed_files = []
for ext in image_extensions:
perturbed_files.extend(glob.glob(os.path.join(output_dir, ext)))
perturbed_files.extend(glob.glob(os.path.join(output_dir, ext.upper())))
logger.info(f"Found {len(perturbed_files)} perturbed images to save")
saved_count = 0
for perturbed_path in perturbed_files:
try:
filename = os.path.basename(perturbed_path)
# 扰动图片命名格式: perturbed_{原图名}.ext
# 提取原图名
parent_image = None
if filename.startswith('perturbed_'):
# 去掉perturbed_前缀得到原图名
original_filename = filename[len('perturbed_'):]
# 尝试从映射中查找
parent_image = original_map.get(original_filename)
if not parent_image:
logger.warning(f"Parent image not found for {filename}, original should be: {original_filename}")
# 获取图片尺寸
with PILImage.open(perturbed_path) as img:
width, height = img.size
# 检查是否已经保存过使用filename作为stored_filename
existing = Image.query.filter_by(
batch_id=batch_id,
stored_filename=filename
).first()
if existing:
logger.info(f"Image already exists: {filename}")
continue
# 创建扰动图片记录直接使用filename因为算法已经添加了perturbed_前缀
perturbed_image = Image(
user_id=batch.user_id,
batch_id=batch_id,
father_id=parent_image.id if parent_image else None,
original_filename=filename,
stored_filename=filename, # 算法输出已经是perturbed_格式
file_path=perturbed_path,
file_size=os.path.getsize(perturbed_path),
image_type_id=perturbed_type.id,
width=width,
height=height
)
db.session.add(perturbed_image)
saved_count += 1
logger.info(f"Saved perturbed image: {filename} (parent: {parent_image.stored_filename if parent_image else 'None'})")
except Exception as e:
logger.error(f"Failed to save {perturbed_path}: {str(e)}")
db.session.commit()
logger.info(f"Successfully saved {saved_count} perturbed images to database")
except Exception as e:
logger.error(f"Error saving perturbed images: {str(e)}")
db.session.rollback()

@ -1,16 +0,0 @@
# MuseGuard 环境变量配置文件
# 注意:此文件包含敏感信息,不应提交到版本控制系统
# 数据库配置
DB_USER=root
DB_PASSWORD=your_password_here
DB_HOST=localhost
DB_NAME=your_database_name_here
# Flask配置
SECRET_KEY=museguard-secret-key-2024
JWT_SECRET_KEY=jwt-secret-string
# 开发模式
FLASK_ENV=development
FLASK_DEBUG=True

@ -0,0 +1,218 @@
"""
算法配置
定义各种对抗性扰动算法的参数环境和脚本路径
"""
import os
from dotenv import load_dotenv
# 加载算法专用环境变量
config_dir = os.path.dirname(os.path.abspath(__file__))
algorithm_env_path = os.path.join(config_dir, 'algorithm.env')
load_dotenv(algorithm_env_path)
class AlgorithmConfig:
"""算法配置基类"""
# 算法脚本根目录
ALGORITHMS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'app', 'algorithms')
# 是否使用真实算法从环境变量读取默认False
USE_REAL_ALGORITHMS = os.getenv('USE_REAL_ALGORITHMS', 'false').lower() == 'true'
# Redis配置从环境变量读取
REDIS_URL = os.getenv('REDIS_URL', 'redis://localhost:6379/0')
# RQ队列名称
RQ_QUEUE_NAME = os.getenv('RQ_QUEUE_NAME', 'perturbation_tasks')
# 任务超时时间(秒)
TASK_TIMEOUT = int(os.getenv('TASK_TIMEOUT', '3600'))
# 日志目录
LOGS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs')
# Conda环境配置从环境变量读取支持自定义
CONDA_ENVS = {
'aspl': os.getenv('CONDA_ENV_ASPL', 'simac'),
'simac': os.getenv('CONDA_ENV_SIMAC', 'simac'),
'caat': os.getenv('CONDA_ENV_CAAT', 'caat'),
'pid': os.getenv('CONDA_ENV_PID', 'pid'),
'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'pid'),
'lora': os.getenv('CONDA_ENV_LORA', 'pid'),
}
# 模型路径配置
MODELS_DIR = {
'model1': os.getenv('MODEL_SD21', '/root/autodl-tmp/muse-guard_-backend/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06'),
'model2': os.getenv('MODEL_SD15', '/root/autodl-tmp/muse-guard_-backend/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14'),
}
# 算法脚本配置
PERTURBATION_SCRIPTS = {
'aspl': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'aspl.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['aspl'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model1'],
'enable_xformers_memory_efficient_attention': True,
'instance_prompt': 'a photo of sks person',
'class_prompt': 'a photo of person',
'num_class_images': 200,
'center_crop': True,
'with_prior_preservation': True,
'prior_loss_weight': 1.0,
'resolution': 384,
'train_batch_size': 1,
'max_train_steps': 2,
'max_f_train_steps': 3,
'max_adv_train_steps': 6,
'checkpointing_iterations': 1,
'learning_rate': 5e-7,
'pgd_alpha': 0.005,
'seed': 0
}
},
'simac': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'simac.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['simac'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model1'],
'enable_xformers_memory_efficient_attention': True,
'instance_prompt': 'a photo of sks person',
'class_prompt': 'a photo of person',
'num_class_images': 200,
'center_crop': True,
'with_prior_preservation': True,
'prior_loss_weight': 1.0,
'resolution': 384,
'train_batch_size': 1,
'max_train_steps': 2,
'max_f_train_steps': 3,
'max_adv_train_steps': 6,
'checkpointing_iterations': 1,
'learning_rate': 5e-7,
'pgd_alpha': 0.005,
'seed': 0
}
},
'caat': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'caat.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['caat'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'instance_prompt': 'a photo of a person',
'resolution': 512,
'learning_rate': 1e-5,
'lr_warmup_steps': 0,
'max_train_steps': 10,
'hflip': True,
'mixed_precision': 'bf16',
'alpha': 5e-3
}
},
'pid': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'pid.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['pid'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'resolution': 512,
'max_train_steps': 10,
'center_crop': True,
'attack_type': 'add-log'
}
}
}
@classmethod
def get_perturbation_config(cls, algorithm_code):
"""获取算法配置"""
return cls.PERTURBATION_SCRIPTS.get(algorithm_code, {})
@classmethod
def get_script_path(cls, algorithm_code):
"""获取算法脚本路径"""
config = cls.get_perturbation_config(algorithm_code)
if cls.USE_REAL_ALGORITHMS:
return config.get('real_script')
else:
return config.get('virtual_script')
@classmethod
def get_conda_env(cls, algorithm_code):
"""获取算法的conda环境名称"""
config = cls.get_perturbation_config(algorithm_code)
return config.get('conda_env')
@classmethod
def get_default_params(cls, algorithm_code):
"""获取算法默认参数"""
config = cls.get_perturbation_config(algorithm_code)
return config.get('default_params', {}).copy()
# ========== 微调算法配置 ==========
FINETUNE_SCRIPTS = {
'dreambooth': {
'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_dreambooth_gen.py'),
'virtual_script': None, # 使用虚拟实现在worker中
'conda_env': CONDA_ENVS['dreambooth'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'with_prior_preservation': True,
'prior_loss_weight': 1.0,
'instance_prompt': 'a photo of sks person',
'class_prompt': 'a photo of person',
'resolution': 512,
'train_batch_size': 1,
'gradient_accumulation_steps': 1,
'learning_rate': 1e-4,
'lr_scheduler': 'constant',
'lr_warmup_steps': 0,
'num_class_images': 1,
'max_train_steps': 1,
'checkpointing_steps': 1,
'center_crop': True,
'mixed_precision': 'bf16',
'prior_generation_precision': 'bf16',
'sample_batch_size': 1,
'validation_prompt': 'a photo of sks person',
'num_validation_images': 1,
'validation_steps': 1
}
},
'lora': {
'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_lora_gen.py'),
'virtual_script': None,
'conda_env': CONDA_ENVS['lora'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'with_prior_preservation': True,
'prior_loss_weight': 1.0,
'instance_prompt': 'a photo of sks person',
'class_prompt': 'a photo of person',
'resolution': 512,
'train_batch_size': 1,
'gradient_accumulation_steps': 1,
'learning_rate': 1e-4,
'lr_scheduler': 'constant',
'lr_warmup_steps': 0,
'num_class_images': 1,
'max_train_steps': 1,
'checkpointing_steps': 1,
'seed': 0,
'mixed_precision': 'fp16',
'rank': 4,
'validation_prompt': 'a photo of sks person',
'num_validation_images': 1
}
}
}
@classmethod
def get_finetune_config(cls, finetune_method):
"""获取微调算法配置"""
return cls.FINETUNE_SCRIPTS.get(finetune_method, {})

@ -1,109 +1,109 @@
"""
应用配置文件
"""
import os
from datetime import timedelta
from dotenv import load_dotenv
from urllib.parse import quote_plus
# 加载环境变量 - 从 config 目录读取 .env 文件
config_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(config_dir, '.env')
load_dotenv(env_path)
class Config:
"""基础配置类"""
# 基础配置
SECRET_KEY = os.environ.get('SECRET_KEY') or 'museguard-secret-key-2024'
# 数据库配置 - 支持密码中的特殊字符
DB_USER = os.environ.get('DB_USER')
DB_PASSWORD = os.environ.get('DB_PASSWORD')
DB_HOST = os.environ.get('DB_HOST') or 'localhost'
DB_NAME = os.environ.get('DB_NAME') or 'museguard_schema'
# URL编码密码中的特殊字符
from urllib.parse import quote_plus
_encoded_password = quote_plus(DB_PASSWORD)
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \
f'mysql+pymysql://{DB_USER}:{_encoded_password}@{DB_HOST}/{DB_NAME}'
SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_ENGINE_OPTIONS = {
'pool_pre_ping': True,
'pool_recycle': 300,
}
# JWT配置
JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or 'jwt-secret-string'
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)
JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
# 静态文件根目录
STATIC_ROOT = 'static'
# 文件上传配置
UPLOAD_FOLDER = 'uploads' # 临时上传目录
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'}
# 图像处理配置
ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'originals') # 重命名后的原始图片
PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片
MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录
MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'clean') # 原图的模型生成结果
MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果
HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图
# 预设演示图像配置
DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录
DEMO_ORIGINAL_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'original') # 演示原始图片
DEMO_PERTURBED_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'perturbed') # 演示加噪图片
DEMO_COMPARISONS_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'comparisons') # 演示对比图
# 邮件配置(用于注册验证)
MAIL_SERVER = os.environ.get('MAIL_SERVER') or 'smtp.gmail.com'
MAIL_PORT = int(os.environ.get('MAIL_PORT') or 587)
MAIL_USE_TLS = os.environ.get('MAIL_USE_TLS', 'true').lower() in ['true', 'on', '1']
MAIL_USERNAME = os.environ.get('MAIL_USERNAME')
MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD')
# 算法配置
ALGORITHMS = {
'simac': {
'name': 'SimAC算法',
'description': 'Simple Anti-Customization Method for Protecting Face Privacy',
'default_epsilon': 8.0
},
'caat': {
'name': 'CAAT算法',
'description': 'Perturbing Attention Gives You More Bang for the Buck',
'default_epsilon': 16.0
},
'pid': {
'name': 'PID算法',
'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models',
'default_epsilon': 4.0
}
}
class DevelopmentConfig(Config):
"""开发环境配置"""
DEBUG = True
class ProductionConfig(Config):
"""生产环境配置"""
DEBUG = False
class TestingConfig(Config):
"""测试环境配置"""
TESTING = True
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
config = {
'development': DevelopmentConfig,
'production': ProductionConfig,
'testing': TestingConfig,
'default': DevelopmentConfig
"""
应用配置文件
"""
import os
from datetime import timedelta
from dotenv import load_dotenv
from urllib.parse import quote_plus
# 加载环境变量 - 从 config 目录读取 .env 文件
config_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(config_dir, 'settings.env')
load_dotenv(env_path)
class Config:
"""基础配置类"""
# 基础配置
SECRET_KEY = os.environ.get('SECRET_KEY') or 'museguard-secret-key-2024'
# 数据库配置 - 支持密码中的特殊字符
DB_USER = os.environ.get('DB_USER') or 'root'
DB_PASSWORD = os.environ.get('DB_PASSWORD') or ''
DB_HOST = os.environ.get('DB_HOST') or 'localhost'
DB_NAME = os.environ.get('DB_NAME') or 'museguard_schema'
# URL编码密码中的特殊字符
from urllib.parse import quote_plus
_encoded_password = quote_plus(DB_PASSWORD) if DB_PASSWORD else ''
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \
f'mysql+pymysql://{DB_USER}:{_encoded_password}@{DB_HOST}/{DB_NAME}'
SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_ENGINE_OPTIONS = {
'pool_pre_ping': True,
'pool_recycle': 300,
}
# JWT配置
JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or 'jwt-secret-string'
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)
JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
# 静态文件根目录
STATIC_ROOT = 'static'
# 文件上传配置
UPLOAD_FOLDER = 'uploads' # 临时上传目录
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'}
# 图像处理配置
ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'original') # 重命名后的原始图片
PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片
MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录
MODEL_ORIGINAL_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'original') # 原图的模型生成结果
MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果
HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图
# 预设演示图像配置
DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录
DEMO_ORIGINAL_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'original') # 演示原始图片
DEMO_PERTURBED_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'perturbed') # 演示加噪图片
DEMO_COMPARISONS_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'comparisons') # 演示对比图
# 邮件配置(用于注册验证)
MAIL_SERVER = os.environ.get('MAIL_SERVER') or 'smtp.gmail.com'
MAIL_PORT = int(os.environ.get('MAIL_PORT') or 587)
MAIL_USE_TLS = os.environ.get('MAIL_USE_TLS', 'true').lower() in ['true', 'on', '1']
MAIL_USERNAME = os.environ.get('MAIL_USERNAME')
MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD')
# 算法配置
ALGORITHMS = {
'simac': {
'name': 'SimAC算法',
'description': 'Simple Anti-Customization Method for Protecting Face Privacy',
'default_epsilon': 8.0
},
'caat': {
'name': 'CAAT算法',
'description': 'Perturbing Attention Gives You More Bang for the Buck',
'default_epsilon': 16.0
},
'pid': {
'name': 'PID算法',
'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models',
'default_epsilon': 4.0
}
}
class DevelopmentConfig(Config):
"""开发环境配置"""
DEBUG = True
class ProductionConfig(Config):
"""生产环境配置"""
DEBUG = False
class TestingConfig(Config):
"""测试环境配置"""
TESTING = True
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
config = {
'development': DevelopmentConfig,
'production': ProductionConfig,
'testing': TestingConfig,
'default': DevelopmentConfig
}

@ -1,81 +1,81 @@
"""
数据库初始化脚本
"""
from app import create_app, db
from app.models import *
def init_database():
"""初始化数据库"""
app = create_app()
with app.app_context():
# 创建所有表
db.create_all()
# 初始化图片类型数据
image_types = [
{'type_code': 'original', 'type_name': '原始图片', 'description': '用户上传的原始图像文件'},
{'type_code': 'perturbed', 'type_name': '加噪后图片', 'description': '经过扰动算法处理后的防护图像'},
{'type_code': 'original_generate', 'type_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'},
{'type_code': 'perturbed_generate', 'type_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'}
]
for img_type in image_types:
existing = ImageType.query.filter_by(type_code=img_type['type_code']).first()
if not existing:
new_type = ImageType(**img_type)
db.session.add(new_type)
# 初始化加噪算法数据
perturbation_configs = [
{'method_code': 'aspl', 'method_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense', 'default_epsilon': 6.0},
{'method_code': 'simac', 'method_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', 'default_epsilon': 8.0},
{'method_code': 'caat', 'method_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck', 'default_epsilon': 16.0},
{'method_code': 'pid', 'method_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', 'default_epsilon': 4.0}
]
for config in perturbation_configs:
existing = PerturbationConfig.query.filter_by(method_code=config['method_code']).first()
if not existing:
new_config = PerturbationConfig(**config)
db.session.add(new_config)
# 初始化微调方式数据
finetune_configs = [
{'method_code': 'dreambooth', 'method_name': 'DreamBooth', 'description': 'DreamBooth个性化文本到图像生成'},
{'method_code': 'lora', 'method_name': 'LoRA', 'description': '低秩适应Low-Rank Adaptation微调方法'},
{'method_code': 'textual_inversion', 'method_name': 'Textual Inversion', 'description': '文本反转个性化方法'}
]
for config in finetune_configs:
existing = FinetuneConfig.query.filter_by(method_code=config['method_code']).first()
if not existing:
new_config = FinetuneConfig(**config)
db.session.add(new_config)
# 创建默认管理员用户
admin_users = [
{'username': 'admin1', 'email': 'admin1@museguard.com', 'role': 'admin'},
{'username': 'admin2', 'email': 'admin2@museguard.com', 'role': 'admin'},
{'username': 'admin3', 'email': 'admin3@museguard.com', 'role': 'admin'}
]
for admin_data in admin_users:
existing = User.query.filter_by(username=admin_data['username']).first()
if not existing:
admin_user = User(**admin_data)
admin_user.set_password('admin123') # 默认密码
db.session.add(admin_user)
# 为管理员创建默认配置
db.session.flush() # 确保user.id可用
user_config = UserConfig(user_id=admin_user.id)
db.session.add(user_config)
# 提交所有更改
db.session.commit()
print("数据库初始化完成!")
if __name__ == '__main__':
"""
数据库初始化脚本
"""
from app import create_app, db
from app.database import *
def init_database():
"""初始化数据库"""
app = create_app()
with app.app_context():
# 创建所有表
db.create_all()
# 初始化图片类型数据
image_types = [
{'type_code': 'original', 'type_name': '原始图片', 'description': '用户上传的原始图像文件'},
{'type_code': 'perturbed', 'type_name': '加噪后图片', 'description': '经过扰动算法处理后的防护图像'},
{'type_code': 'original_generate', 'type_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'},
{'type_code': 'perturbed_generate', 'type_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'}
]
for img_type in image_types:
existing = ImageType.query.filter_by(type_code=img_type['type_code']).first()
if not existing:
new_type = ImageType(**img_type)
db.session.add(new_type)
# 初始化加噪算法数据
perturbation_configs = [
{'method_code': 'aspl', 'method_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense', 'default_epsilon': 6.0},
{'method_code': 'simac', 'method_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', 'default_epsilon': 8.0},
{'method_code': 'caat', 'method_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck', 'default_epsilon': 16.0},
{'method_code': 'pid', 'method_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', 'default_epsilon': 4.0}
]
for config in perturbation_configs:
existing = PerturbationConfig.query.filter_by(method_code=config['method_code']).first()
if not existing:
new_config = PerturbationConfig(**config)
db.session.add(new_config)
# 初始化微调方式数据
finetune_configs = [
{'method_code': 'dreambooth', 'method_name': 'DreamBooth', 'description': 'DreamBooth个性化文本到图像生成'},
{'method_code': 'lora', 'method_name': 'LoRA', 'description': '低秩适应Low-Rank Adaptation微调方法'},
{'method_code': 'textual_inversion', 'method_name': 'Textual Inversion', 'description': '文本反转个性化方法'}
]
for config in finetune_configs:
existing = FinetuneConfig.query.filter_by(method_code=config['method_code']).first()
if not existing:
new_config = FinetuneConfig(**config)
db.session.add(new_config)
# 创建默认管理员用户
admin_users = [
{'username': 'admin1', 'email': 'admin1@museguard.com', 'role': 'admin'},
{'username': 'admin2', 'email': 'admin2@museguard.com', 'role': 'admin'},
{'username': 'admin3', 'email': 'admin3@museguard.com', 'role': 'admin'}
]
for admin_data in admin_users:
existing = User.query.filter_by(username=admin_data['username']).first()
if not existing:
admin_user = User(**admin_data)
admin_user.set_password('admin123') # 默认密码
db.session.add(admin_user)
# 为管理员创建默认配置
db.session.flush() # 确保user.id可用
user_config = UserConfig(user_id=admin_user.id)
db.session.add(user_config)
# 提交所有更改
db.session.commit()
print("数据库初始化完成!")
if __name__ == '__main__':
init_database()

@ -1,15 +0,0 @@
@echo off
chcp 65001 > nul
echo ==========================================
echo MuseGuard 快速启动脚本
echo ==========================================
echo.
REM 切换到项目目录
cd /d "d:\code\Software_Project\team_project\MuseGuard\src\backend"
REM 激活虚拟环境并启动服务器
echo 正在激活虚拟环境并启动服务器...
call venv_py311\Scripts\activate.bat && python run.py
pause

@ -1,33 +1,37 @@
# Core Flask Framework
Flask==3.0.0
Flask-SQLAlchemy==3.1.1
Flask-Migrate==4.0.5
Flask-JWT-Extended==4.6.0
Flask-CORS==5.0.0
Werkzeug==3.0.1
# Database
PyMySQL==1.1.1
# Image Processing
Pillow==10.4.0
numpy==1.26.4
# Security & Utils
cryptography==42.0.8
python-dotenv==1.0.1
# Additional Dependencies (auto-installed)
# alembic==1.17.0
# blinker==1.9.0
# cffi==2.0.0
# click==8.3.0
# greenlet==3.2.4
# itsdangerous==2.2.0
# Jinja2==3.1.6
# Mako==1.3.10
# MarkupSafe==3.0.3
# pycparser==2.23
# PyJWT==2.10.1
# SQLAlchemy==2.0.44
# Core Flask Framework
Flask==3.0.0
Flask-SQLAlchemy==3.1.1
Flask-Migrate==4.0.5
Flask-JWT-Extended==4.6.0
Flask-CORS==5.0.0
Werkzeug==3.0.1
# Database
PyMySQL==1.1.1
# Image Processing
Pillow==10.4.0
numpy==1.26.4
# Security & Utils
cryptography==42.0.8
python-dotenv==1.0.1
# Task Queue
redis==5.0.1
rq==1.16.2
# Additional Dependencies (auto-installed)
# alembic==1.17.0
# blinker==1.9.0
# cffi==2.0.0
# click==8.3.0
# greenlet==3.2.4
# itsdangerous==2.2.0
# Jinja2==3.1.6
# Mako==1.3.10
# MarkupSafe==3.0.3
# pycparser==2.23
# PyJWT==2.10.1
# SQLAlchemy==2.0.44
# typing_extensions==4.15.0

@ -1,21 +1,21 @@
"""
Flask应用启动脚本
"""
import os
from app import create_app
# 设置环境变量
os.environ.setdefault('FLASK_ENV', 'development')
# 创建应用实例
app = create_app()
if __name__ == '__main__':
# 开发模式启动
app.run(
host='0.0.0.0',
port=5000,
debug=True,
threaded=True
"""
Flask应用启动脚本
"""
import os
from app import create_app
# 设置环境变量
os.environ.setdefault('FLASK_ENV', 'development')
# 创建应用实例
app = create_app()
if __name__ == '__main__':
# 开发模式启动
app.run(
host='0.0.0.0',
port=6006,
debug=True,
threaded=True
)

@ -0,0 +1,111 @@
#!/bin/bash
# MuseGuard 后端快速启动脚本
echo "========================================"
echo " MuseGuard 后端服务启动"
echo "========================================"
echo ""
# 获取脚本所在目录
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$SCRIPT_DIR"
# 激活conda环境
echo "激活 conda 环境: flask"
source /root/miniconda3/etc/profile.d/conda.sh
conda activate flask
# 检查conda环境是否激活成功
if [ "$CONDA_DEFAULT_ENV" != "flask" ]; then
echo "[错误] 无法激活 flask 环境"
exit 1
fi
echo "[成功] conda 环境已激活: $CONDA_DEFAULT_ENV"
echo ""
# 检查数据库是否运行以MySQL为例可根据实际情况调整
echo "检查数据库连接..."
if mysqladmin ping -uroot > /dev/null 2>&1; then
echo "[成功] MySQL 连接正常"
else
echo "[警告] MySQL 未运行,正在启动 MySQL..."
service mysql start
sleep 2
if mysqladmin ping -uroot > /dev/null 2>&1; then
echo "[成功] MySQL 已启动"
else
echo "[错误] 无法启动 MySQL请手动启动"
echo " sudo systemctl start mysql"
echo " 或: sudo service mysql start"
exit 1
fi
fi
echo ""
# 检查Redis是否运行
echo "检查 Redis 连接..."
if redis-cli ping > /dev/null 2>&1; then
echo "[成功] Redis 连接正常"
else
echo "[警告] Redis 未运行,正在启动 Redis..."
redis-server --daemonize yes
sleep 2
if redis-cli ping > /dev/null 2>&1; then
echo "[成功] Redis 已启动"
else
echo "[错误] 无法启动 Redis请手动启动"
echo " sudo systemctl start redis-server"
echo " 或: redis-server --daemonize yes"
exit 1
fi
fi
echo ""
# 创建日志目录
mkdir -p logs
# 停止已存在的进程
echo "检查并停止已存在的进程..."
pkill -f "python run.py" 2>/dev/null
pkill -f "python worker.py" 2>/dev/null
sleep 1
# 启动Flask应用
echo "启动 Flask 应用..."
nohup python run.py > logs/flask.log 2>&1 &
FLASK_PID=$!
echo "Flask 应用已启动 (PID: $FLASK_PID)"
echo ""
# 等待Flask启动
sleep 2
# 启动RQ Worker
echo "启动 RQ Worker..."
nohup python worker.py > logs/worker.log 2>&1 &
WORKER_PID=$!
echo "RQ Worker 已启动 (PID: $WORKER_PID)"
echo ""
# 保存PID到文件
echo $FLASK_PID > logs/flask.pid
echo $WORKER_PID > logs/worker.pid
echo "========================================"
echo " 启动完成!"
echo "========================================"
echo ""
echo "服务信息:"
echo " - Flask API: http://127.0.0.1:6006"
echo " - Flask PID: $FLASK_PID"
echo " - Worker PID: $WORKER_PID"
echo ""
echo "查看日志:"
echo " - Flask: tail -f logs/flask.log"
echo " - Worker: tail -f logs/worker.log"
echo ""
echo "停止服务:"
echo " - 执行: ./stop.sh"
echo " - 或手动: kill $FLASK_PID $WORKER_PID"
echo ""

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,86 @@
#!/bin/bash
# MuseGuard 后端服务状态检查脚本
echo "========================================"
echo " MuseGuard 后端服务状态"
echo "========================================"
echo ""
# 获取脚本所在目录
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$SCRIPT_DIR"
# 检查Flask应用
echo "📌 Flask 应用:"
if [ -f logs/flask.pid ]; then
FLASK_PID=$(cat logs/flask.pid)
if ps -p $FLASK_PID > /dev/null 2>&1; then
echo " ✅ 运行中 (PID: $FLASK_PID)"
echo " 📍 URL: http://127.0.0.1:6006"
echo " 📍 测试: http://127.0.0.1:6006/static/test.html"
else
echo " ❌ 未运行 (PID文件存在但进程不存在)"
fi
else
if pgrep -f "python run.py" > /dev/null 2>&1; then
FLASK_PID=$(pgrep -f "python run.py")
echo " ⚠️ 运行中但无PID文件 (PID: $FLASK_PID)"
else
echo " ❌ 未运行"
fi
fi
echo ""
# 检查Worker
echo "📌 RQ Worker:"
if [ -f logs/worker.pid ]; then
WORKER_PID=$(cat logs/worker.pid)
if ps -p $WORKER_PID > /dev/null 2>&1; then
echo " ✅ 运行中 (PID: $WORKER_PID)"
else
echo " ❌ 未运行 (PID文件存在但进程不存在)"
fi
else
if pgrep -f "python worker.py" > /dev/null 2>&1; then
WORKER_PID=$(pgrep -f "python worker.py")
echo " ⚠️ 运行中但无PID文件 (PID: $WORKER_PID)"
else
echo " ❌ 未运行"
fi
fi
echo ""
# 检查Redis
echo "📌 Redis:"
if redis-cli ping > /dev/null 2>&1; then
echo " ✅ 运行中"
else
echo " ❌ 未运行"
fi
echo ""
# 检查日志文件
echo "📌 日志文件:"
if [ -f logs/flask.log ]; then
FLASK_LOG_SIZE=$(du -h logs/flask.log | cut -f1)
echo " Flask: logs/flask.log ($FLASK_LOG_SIZE)"
else
echo " Flask: 无日志文件"
fi
if [ -f logs/worker.log ]; then
WORKER_LOG_SIZE=$(du -h logs/worker.log | cut -f1)
echo " Worker: logs/worker.log ($WORKER_LOG_SIZE)"
else
echo " Worker: 无日志文件"
fi
echo ""
echo "========================================"
echo " 快速操作"
echo "========================================"
echo "启动服务: ./start.sh"
echo "停止服务: ./stop.sh"
echo "查看Flask日志: tail -f logs/flask.log"
echo "查看Worker日志: tail -f logs/worker.log"
echo ""

@ -0,0 +1,51 @@
#!/bin/bash
# MuseGuard 后端服务停止脚本
echo "========================================"
echo " 停止 MuseGuard 后端服务"
echo "========================================"
echo ""
# 获取脚本所在目录
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$SCRIPT_DIR"
# 从PID文件停止进程
if [ -f logs/flask.pid ]; then
FLASK_PID=$(cat logs/flask.pid)
if ps -p $FLASK_PID > /dev/null 2>&1; then
echo "停止 Flask 应用 (PID: $FLASK_PID)..."
kill $FLASK_PID
echo "[成功] Flask 应用已停止"
else
echo "[提示] Flask 应用未运行"
fi
rm logs/flask.pid
else
echo "[提示] 未找到 Flask PID 文件"
fi
if [ -f logs/worker.pid ]; then
WORKER_PID=$(cat logs/worker.pid)
if ps -p $WORKER_PID > /dev/null 2>&1; then
echo "停止 RQ Worker (PID: $WORKER_PID)..."
kill $WORKER_PID
echo "[成功] RQ Worker 已停止"
else
echo "[提示] RQ Worker 未运行"
fi
rm logs/worker.pid
else
echo "[提示] 未找到 Worker PID 文件"
fi
# 确保所有相关进程都停止
echo ""
echo "清理所有相关进程..."
pkill -f "python run.py" 2>/dev/null && echo "清理了额外的 Flask 进程"
pkill -f "python worker.py" 2>/dev/null && echo "清理了额外的 Worker 进程"
echo ""
echo "========================================"
echo " 服务已停止"
echo "========================================"

@ -0,0 +1,42 @@
"""
RQ Worker 启动脚本
用于启动后台任务处理器
"""
import sys
import os
# 添加项目路径到Python路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from redis import Redis
from rq import Worker, Queue
from config.algorithm_config import AlgorithmConfig
from app import create_app
# 创建Flask应用上下文
app = create_app()
def main():
"""启动worker"""
with app.app_context():
# 连接Redis
redis_conn = Redis.from_url(AlgorithmConfig.REDIS_URL)
# 创建队列
queue = Queue(AlgorithmConfig.RQ_QUEUE_NAME, connection=redis_conn)
# 创建worker
worker = Worker([queue], connection=redis_conn)
print(f"🚀 RQ Worker启动成功!")
print(f"📡 Redis: {AlgorithmConfig.REDIS_URL}")
print(f"📋 Queue: {AlgorithmConfig.RQ_QUEUE_NAME}")
print(f"🔄 使用{'真实' if AlgorithmConfig.USE_REAL_ALGORITHMS else '虚拟'}算法")
print(f"⏳ 等待任务...")
# 启动worker
worker.work()
if __name__ == '__main__':
main()
Loading…
Cancel
Save