From 6c0da89eb4506600de6f301460503da156418d05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 8 Nov 2025 17:20:05 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=AF=B9=E6=8E=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 27 +- README.md | 50 +- src/backend/.env.example | 36 + src/backend/README.md | 484 +-- src/backend/app.py | 90 +- src/backend/app/__init__.py | 164 +- .../app/algorithms/evaluation_engine.py | 300 -- src/backend/app/algorithms/finetune/infer.py | 87 + .../finetune/train_dreambooth_alone.py | 1035 ++++++ .../finetune/train_dreambooth_gen.py | 1450 ++++++++ .../app/algorithms/finetune/train_lora_gen.py | 1434 +++++++ .../finetune_virtual/train_dreambooth_gen.py | 134 + .../finetune_virtual/train_lora_gen.py | 134 + .../app/algorithms/perturbation/aspl.py | 770 ++++ .../app/algorithms/perturbation/caat.py | 972 +++++ .../app/algorithms/perturbation/pid.py | 272 ++ .../app/algorithms/perturbation/simac.py | 1039 ++++++ .../app/algorithms/perturbation_engine.py | 176 - .../algorithms/perturbation_virtual/aspl.py | 87 + .../algorithms/perturbation_virtual/caat.py | 79 + .../algorithms/perturbation_virtual/pid.py | 79 + .../algorithms/perturbation_virtual/simac.py | 82 + .../perturbation_virtual/virtual_demo.py | 25 + .../app/controllers/admin_controller.py | 476 +-- .../app/controllers/auth_controller.py | 310 +- .../app/controllers/demo_controller.py | 352 +- .../app/controllers/image_controller.py | 404 +- .../app/controllers/task_controller.py | 717 ++-- .../app/controllers/user_controller.py | 264 +- .../app/{models => database}/__init__.py | 466 +-- src/backend/app/scripts/attack_aspl.sh | 62 + src/backend/app/scripts/attack_caat.sh | 39 + .../app/scripts/attack_caat_with_prior.sh | 55 + src/backend/app/scripts/attack_pid.sh | 53 + src/backend/app/scripts/attack_simac.sh | 63 + src/backend/app/scripts/db_gen.sh | 90 + src/backend/app/scripts/db_infer.sh | 85 + src/backend/app/scripts/lora_gen.sh | 84 + src/backend/app/services/auth_service.py | 66 +- src/backend/app/services/image_service.py | 320 +- src/backend/app/services/task_service.py | 725 +++- src/backend/app/utils/file_utils.py | 104 +- src/backend/app/utils/jwt_utils.py | 30 +- src/backend/app/workers/finetune_worker.py | 365 ++ .../app/workers/perturbation_worker.py | 296 ++ src/backend/config/.env | 4 +- src/backend/config/algorithm_config.py | 197 + src/backend/config/settings.py | 216 +- src/backend/init_db.py | 160 +- src/backend/quick_start.bat | 15 - src/backend/requirements.txt | 68 +- src/backend/run.py | 40 +- src/backend/start.sh | 92 + src/backend/static/test.html | 3286 ++++++++--------- src/backend/status.sh | 85 + src/backend/stop.sh | 51 + src/backend/worker.py | 42 + 57 files changed, 13923 insertions(+), 4765 deletions(-) create mode 100644 src/backend/.env.example delete mode 100644 src/backend/app/algorithms/evaluation_engine.py create mode 100644 src/backend/app/algorithms/finetune/infer.py create mode 100644 src/backend/app/algorithms/finetune/train_dreambooth_alone.py create mode 100644 src/backend/app/algorithms/finetune/train_dreambooth_gen.py create mode 100644 src/backend/app/algorithms/finetune/train_lora_gen.py create mode 100644 src/backend/app/algorithms/finetune_virtual/train_dreambooth_gen.py create mode 100644 src/backend/app/algorithms/finetune_virtual/train_lora_gen.py create mode 100644 
src/backend/app/algorithms/perturbation/aspl.py create mode 100644 src/backend/app/algorithms/perturbation/caat.py create mode 100644 src/backend/app/algorithms/perturbation/pid.py create mode 100644 src/backend/app/algorithms/perturbation/simac.py delete mode 100644 src/backend/app/algorithms/perturbation_engine.py create mode 100644 src/backend/app/algorithms/perturbation_virtual/aspl.py create mode 100644 src/backend/app/algorithms/perturbation_virtual/caat.py create mode 100644 src/backend/app/algorithms/perturbation_virtual/pid.py create mode 100644 src/backend/app/algorithms/perturbation_virtual/simac.py create mode 100644 src/backend/app/algorithms/perturbation_virtual/virtual_demo.py rename src/backend/app/{models => database}/__init__.py (96%) create mode 100644 src/backend/app/scripts/attack_aspl.sh create mode 100644 src/backend/app/scripts/attack_caat.sh create mode 100644 src/backend/app/scripts/attack_caat_with_prior.sh create mode 100644 src/backend/app/scripts/attack_pid.sh create mode 100644 src/backend/app/scripts/attack_simac.sh create mode 100644 src/backend/app/scripts/db_gen.sh create mode 100644 src/backend/app/scripts/db_infer.sh create mode 100644 src/backend/app/scripts/lora_gen.sh create mode 100644 src/backend/app/workers/finetune_worker.py create mode 100644 src/backend/app/workers/perturbation_worker.py create mode 100644 src/backend/config/algorithm_config.py delete mode 100644 src/backend/quick_start.bat create mode 100644 src/backend/start.sh create mode 100644 src/backend/status.sh create mode 100644 src/backend/stop.sh create mode 100644 src/backend/worker.py diff --git a/.gitignore b/.gitignore index c9cc180..ef6693c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,18 @@ -__pycache__/ - -venv/ - -*.png -*.jpg -*.jpeg - -.env \ No newline at end of file +__pycache__/ + +venv/ +python=3.11/ + +*.png +*.jpg +*.jpeg + +# 环境配置文件(包含敏感信息) +*.env + +# 日志文件 +logs/ +*.log + +# 上传文件临时目录 +uploads/ \ No newline at end of file diff --git a/README.md b/README.md index 57290a8..1681ad8 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,25 @@ -# MuseGuard - -占位:项目总说明。后续将补充以下内容: - -## 简介 -(占位) - -## 项目目标 -(占位) - -## 技术栈 -(占位) - -## 快速开始 -(占位) - -## 目录结构说明 -(占位) - -## 贡献指南 -(占位) - -## 许可证 -(占位) - +# MuseGuard + +占位:项目总说明。后续将补充以下内容: + +## 简介 +(占位) + +## 项目目标 +(占位) + +## 技术栈 +(占位) + +## 快速开始 +(占位) + +## 目录结构说明 +(占位) + +## 贡献指南 +(占位) + +## 许可证 +(占位) + diff --git a/src/backend/.env.example b/src/backend/.env.example new file mode 100644 index 0000000..6f93179 --- /dev/null +++ b/src/backend/.env.example @@ -0,0 +1,36 @@ +# ============================================ +# 数据库配置 +# ============================================ +DB_HOST=localhost +DB_PORT=3306 +DB_USER=root +DB_PASSWORD=your_password +DB_NAME=museguard_schema + +# ============================================ +# Flask应用配置 +# ============================================ +SECRET_KEY=your-secret-key-here +JWT_SECRET_KEY=your-jwt-secret-key-here +FLASK_ENV=development + +# ============================================ +# Redis配置(用于任务队列) +# ============================================ +REDIS_URL=redis://localhost:6379/0 + +# ============================================ +# 算法模式配置 +# ============================================ +# true: 使用真实算法(需要conda环境和完整依赖) +# false: 使用虚拟算法(快速测试,不需要GPU和模型) +USE_REAL_ALGORITHMS=false + +# ============================================ +# 邮件配置(可选,用于注册验证) +# ============================================ +MAIL_SERVER=smtp.gmail.com +MAIL_PORT=587 +MAIL_USE_TLS=true 
+MAIL_USERNAME=your_email@gmail.com +MAIL_PASSWORD=your_email_password diff --git a/src/backend/README.md b/src/backend/README.md index 6728238..5b53032 100644 --- a/src/backend/README.md +++ b/src/backend/README.md @@ -1,243 +1,243 @@ -# MuseGuard 后端框架 - -基于对抗性扰动的多风格图像生成防护系统 - 后端API服务 - -## 项目结构 - -``` -backend/ -├── app/ # 主应用目录 -│ ├── algorithms/ # 算法实现 -│ │ ├── perturbation_engine.py # 对抗性扰动引擎 -│ │ └── evaluation_engine.py # 评估引擎 -│ ├── controllers/ # 控制器(路由处理) -│ │ ├── auth_controller.py # 认证控制器 -│ │ ├── user_controller.py # 用户配置控制器 -│ │ ├── task_controller.py # 任务控制器 -| | ├── demo_controller.py # 首页示例控制器 -│ │ ├── image_controller.py # 图像控制器 -│ │ └── admin_controller.py # 管理员控制器 -│ ├── models/ # 数据模型 -│ │ └── __init__.py # SQLAlchemy模型定义 -│ ├── services/ # 业务逻辑服务 -│ │ ├── auth_service.py # 认证服务 -│ │ ├── task_service.py # 任务处理服务 -│ │ └── image_service.py # 图像处理服务 -│ └── utils/ # 工具类 -│ └── file_utils.py # 文件处理工具 -├── config/ # 配置文件 -│ └── settings.py # 应用配置 -├── uploads/ # 文件上传目录 -├── static/ # 静态文件 -│ ├── originals/ # 重命名后的原始图片 -│ ├── perturbed/ # 加噪后的图片 -│ ├── model_outputs/ # 模型生成的图片 -│ │ ├── clean/ # 原图的模型生成结果 -│ │ └── perturbed/ # 加噪图的模型生成结果 -│ ├── heatmaps/ # 热力图 -│ └── demo/ # 演示图片 -│ ├── original/ # 演示原始图片 -│ ├── perturbed/ # 演示加噪图片 -│ └── comparisons/ # 演示对比图 -├── app.py # Flask应用工厂 -├── run.py # 启动脚本 -├── init_db.py # 数据库初始化脚本 -└── requirements.txt # Python依赖 -``` - -## 功能特性 - -### 用户功能 -- ✅ 用户注册(邮箱验证,同一邮箱只能注册一次) -- ✅ 用户登录/登出 -- ✅ 密码修改 -- ✅ 任务创建和管理 -- ✅ 图片上传(单张/压缩包批量) -- ✅ 加噪处理(4种算法:SimAC、CAAT、PID、ASPL) -- ✅ 扰动强度自定义 -- ✅ 防净化版本选择 -- ✅ 智能配置记忆:自动保存用户上次选择的配置 -- ✅ 处理结果下载 -- ✅ 图片质量对比查看(FID、LPIPS、SSIM、PSNR、热力图) -- ✅ 模型生成对比分析 -- ✅ 预设演示图片浏览 - -### 管理员功能 -- ✅ 用户管理(增删改查) -- ✅ 系统统计信息查看 - -### 算法实现 -- ✅ ASPL算法虚拟实现(原始版本 + 防净化版本) -- ✅ SimAC算法虚拟实现(原始版本 + 防净化版本) -- ✅ CAAT算法虚拟实现(原始版本 + 防净化版本) -- ✅ PID算法虚拟实现(原始版本 + 防净化版本) -- ✅ 图像质量评估指标计算 -- ✅ 模型生成效果对比 -- ✅ 热力图生成 - -## 安装和运行 - -### 1. 环境准备 - -#### 使用虚拟环境(推荐) - -**为什么需要虚拟环境?** -- ✅ **避免依赖冲突**:不同项目使用不同版本的包 -- ✅ **环境隔离**:不污染系统Python环境 -- ✅ **版本一致性**:确保团队环境统一 -- ✅ **易于管理**:可以随时删除重建 - -```bash -# 创建虚拟环境 -python -m venv venv - -# 激活虚拟环境 -# Windows: -venv\\Scripts\\activate -# Linux/Mac: -source venv/bin/activate - -# 更新pip(推荐) -python -m pip install --upgrade pip - -# 安装依赖 -pip install -r requirements.txt -``` - -### 2. 数据库配置 - -确保已安装MySQL数据库并创建数据库。 - -修改 `config/.env` 中的数据库连接配置: - -### 3. 初始化数据库 - -```bash -# 运行数据库初始化脚本 -python init_db.py -``` - -### 4. 启动应用 - -```bash -# 开发模式启动 -python run.py - -# 或者使用Flask命令 -flask run -``` - -应用将在 `http://localhost:5000` 启动 - -### 5. 
系统测试 - -访问 `http://localhost:5000/static/test.html` 进入功能测试页面: - -## API接口文档 - -### 认证接口 (`/api/auth`) - -- `POST /register` - 用户注册 -- `POST /login` - 用户登录 -- `POST /change-password` - 修改密码 -- `GET /profile` - 获取用户信息 -- `POST /logout` - 用户登出 - -### 任务管理 (`/api/task`) - -- `POST /create` - 创建任务(使用默认配置) -- `POST /upload/` - 上传图片到指定任务 -- `GET //config` - 获取任务配置(显示用户上次选择) -- `PUT //config` - 更新任务配置(自动保存为用户偏好) -- `GET /load-config` - 加载用户上次配置 -- `POST /save-config` - 保存用户配置偏好 -- `POST /start/` - 开始处理任务 -- `GET /list` - 获取任务列表 -- `GET /` - 获取任务详情 -- `GET //status` - 获取处理状态 - -### 图片管理 (`/api/image`) - -- `GET /file/` - 查看图片 -- `GET /download/` - 下载图片 -- `GET /batch//download` - 批量下载 -- `GET //evaluations` - 获取评估结果 -- `POST /compare` - 对比图片 -- `GET /heatmap/` - 获取热力图 -- `DELETE /delete/` - 删除图片 - -### 用户设置 (`/api/user`) - -- `GET /config` - 获取用户配置(已弃用,配置集成到任务流程中) -- `PUT /config` - 更新用户配置(已弃用,通过任务配置自动保存) -- `GET /algorithms` - 获取可用算法(动态从数据库加载) -- `GET /stats` - 获取用户统计 - -### 管理员功能 (`/api/admin`) - -- `GET /users` - 用户列表 -- `GET /users/` - 用户详情 -- `POST /users` - 创建用户 -- `PUT /users/` - 更新用户 -- `DELETE /users/` - 删除用户 -- `GET /stats` - 系统统计 - -### 演示功能 (`/api/demo`) - -- `GET /images` - 获取演示图片列表 -- `GET /image/original/` - 获取演示原始图片 -- `GET /image/perturbed/` - 获取演示加噪图片 -- `GET /image/comparison/` - 获取演示对比图片 -- `GET /algorithms` - 获取算法演示信息 -- `GET /stats` - 获取演示统计数据 - -## 默认账户 - -系统初始化后会创建3个管理员账户: - -- 用户名:`admin1`, `admin2`, `admin3` -- 默认密码:`admin123` -- 邮箱:`admin1@museguard.com` 等 - -## 技术栈 - -- **Web框架**: Flask 2.3.3 -- **数据库ORM**: SQLAlchemy 3.0.5 -- **数据库**: MySQL(通过PyMySQL连接) -- **认证**: JWT (Flask-JWT-Extended) -- **跨域**: Flask-CORS -- **图像处理**: Pillow + NumPy -- **数学计算**: NumPy - -## 开发说明 - -### 虚拟实现说明 - -当前所有算法都是**虚拟实现**,用于框架搭建和测试: - -1. **对抗性扰动算法**: 使用随机噪声模拟真实算法效果 -2. **评估指标**: 基于像素差异的简化计算 -3. **模型生成**: 通过图像变换模拟DreamBooth/LoRA效果 - -### 扩展指南 - -要集成真实算法: - -1. 替换 `app/algorithms/perturbation_engine.py` 中的虚拟实现 -2. 替换 `app/algorithms/evaluation_engine.py` 中的评估计算 -3. 
根据需要调整配置参数 - -### 目录权限 - -确保以下目录有写入权限: - -- `uploads/` - 用户上传文件 -- `static/originals/` - 重命名后的原始图片 -- `static/perturbed/` - 加噪后的图片 -- `static/model_outputs/` - 模型生成的图片 -- `static/heatmaps/` - 热力图文件 -- `static/demo/` - 演示图片(需要手动添加演示文件) - -## 许可证 - +# MuseGuard 后端框架 + +基于对抗性扰动的多风格图像生成防护系统 - 后端API服务 + +## 项目结构 + +``` +backend/ +├── app/ # 主应用目录 +│ ├── algorithms/ # 算法实现 +│ │ ├── perturbation_engine.py # 对抗性扰动引擎 +│ │ └── evaluation_engine.py # 评估引擎 +│ ├── controllers/ # 控制器(路由处理) +│ │ ├── auth_controller.py # 认证控制器 +│ │ ├── user_controller.py # 用户配置控制器 +│ │ ├── task_controller.py # 任务控制器 +| | ├── demo_controller.py # 首页示例控制器 +│ │ ├── image_controller.py # 图像控制器 +│ │ └── admin_controller.py # 管理员控制器 +│ ├── models/ # 数据模型 +│ │ └── __init__.py # SQLAlchemy模型定义 +│ ├── services/ # 业务逻辑服务 +│ │ ├── auth_service.py # 认证服务 +│ │ ├── task_service.py # 任务处理服务 +│ │ └── image_service.py # 图像处理服务 +│ └── utils/ # 工具类 +│ └── file_utils.py # 文件处理工具 +├── config/ # 配置文件 +│ └── settings.py # 应用配置 +├── uploads/ # 文件上传目录 +├── static/ # 静态文件 +│ ├── originals/ # 重命名后的原始图片 +│ ├── perturbed/ # 加噪后的图片 +│ ├── model_outputs/ # 模型生成的图片 +│ │ ├── clean/ # 原图的模型生成结果 +│ │ └── perturbed/ # 加噪图的模型生成结果 +│ ├── heatmaps/ # 热力图 +│ └── demo/ # 演示图片 +│ ├── original/ # 演示原始图片 +│ ├── perturbed/ # 演示加噪图片 +│ └── comparisons/ # 演示对比图 +├── app.py # Flask应用工厂 +├── run.py # 启动脚本 +├── init_db.py # 数据库初始化脚本 +└── requirements.txt # Python依赖 +``` + +## 功能特性 + +### 用户功能 +- ✅ 用户注册(邮箱验证,同一邮箱只能注册一次) +- ✅ 用户登录/登出 +- ✅ 密码修改 +- ✅ 任务创建和管理 +- ✅ 图片上传(单张/压缩包批量) +- ✅ 加噪处理(4种算法:SimAC、CAAT、PID、ASPL) +- ✅ 扰动强度自定义 +- ✅ 防净化版本选择 +- ✅ 智能配置记忆:自动保存用户上次选择的配置 +- ✅ 处理结果下载 +- ✅ 图片质量对比查看(FID、LPIPS、SSIM、PSNR、热力图) +- ✅ 模型生成对比分析 +- ✅ 预设演示图片浏览 + +### 管理员功能 +- ✅ 用户管理(增删改查) +- ✅ 系统统计信息查看 + +### 算法实现 +- ✅ ASPL算法虚拟实现(原始版本 + 防净化版本) +- ✅ SimAC算法虚拟实现(原始版本 + 防净化版本) +- ✅ CAAT算法虚拟实现(原始版本 + 防净化版本) +- ✅ PID算法虚拟实现(原始版本 + 防净化版本) +- ✅ 图像质量评估指标计算 +- ✅ 模型生成效果对比 +- ✅ 热力图生成 + +## 安装和运行 + +### 1. 环境准备 + +#### 使用虚拟环境(推荐) + +**为什么需要虚拟环境?** +- ✅ **避免依赖冲突**:不同项目使用不同版本的包 +- ✅ **环境隔离**:不污染系统Python环境 +- ✅ **版本一致性**:确保团队环境统一 +- ✅ **易于管理**:可以随时删除重建 + +```bash +# 创建虚拟环境 +python -m venv venv + +# 激活虚拟环境 +# Windows: +venv\\Scripts\\activate +# Linux/Mac: +source venv/bin/activate + +# 更新pip(推荐) +python -m pip install --upgrade pip + +# 安装依赖 +pip install -r requirements.txt +``` + +### 2. 数据库配置 + +确保已安装MySQL数据库并创建数据库。 + +修改 `config/.env` 中的数据库连接配置: + +### 3. 初始化数据库 + +```bash +# 运行数据库初始化脚本 +python init_db.py +``` + +### 4. 启动应用 + +```bash +# 开发模式启动 +python run.py + +# 或者使用Flask命令 +flask run +``` + +应用将在 `http://localhost:5000` 启动 + +### 5. 
系统测试 + +访问 `http://localhost:5000/static/test.html` 进入功能测试页面: + +## API接口文档 + +### 认证接口 (`/api/auth`) + +- `POST /register` - 用户注册 +- `POST /login` - 用户登录 +- `POST /change-password` - 修改密码 +- `GET /profile` - 获取用户信息 +- `POST /logout` - 用户登出 + +### 任务管理 (`/api/task`) + +- `POST /create` - 创建任务(使用默认配置) +- `POST /upload/` - 上传图片到指定任务 +- `GET //config` - 获取任务配置(显示用户上次选择) +- `PUT //config` - 更新任务配置(自动保存为用户偏好) +- `GET /load-config` - 加载用户上次配置 +- `POST /save-config` - 保存用户配置偏好 +- `POST /start/` - 开始处理任务 +- `GET /list` - 获取任务列表 +- `GET /` - 获取任务详情 +- `GET //status` - 获取处理状态 + +### 图片管理 (`/api/image`) + +- `GET /file/` - 查看图片 +- `GET /download/` - 下载图片 +- `GET /batch//download` - 批量下载 +- `GET //evaluations` - 获取评估结果 +- `POST /compare` - 对比图片 +- `GET /heatmap/` - 获取热力图 +- `DELETE /delete/` - 删除图片 + +### 用户设置 (`/api/user`) + +- `GET /config` - 获取用户配置(已弃用,配置集成到任务流程中) +- `PUT /config` - 更新用户配置(已弃用,通过任务配置自动保存) +- `GET /algorithms` - 获取可用算法(动态从数据库加载) +- `GET /stats` - 获取用户统计 + +### 管理员功能 (`/api/admin`) + +- `GET /users` - 用户列表 +- `GET /users/` - 用户详情 +- `POST /users` - 创建用户 +- `PUT /users/` - 更新用户 +- `DELETE /users/` - 删除用户 +- `GET /stats` - 系统统计 + +### 演示功能 (`/api/demo`) + +- `GET /images` - 获取演示图片列表 +- `GET /image/original/` - 获取演示原始图片 +- `GET /image/perturbed/` - 获取演示加噪图片 +- `GET /image/comparison/` - 获取演示对比图片 +- `GET /algorithms` - 获取算法演示信息 +- `GET /stats` - 获取演示统计数据 + +## 默认账户 + +系统初始化后会创建3个管理员账户: + +- 用户名:`admin1`, `admin2`, `admin3` +- 默认密码:`admin123` +- 邮箱:`admin1@museguard.com` 等 + +## 技术栈 + +- **Web框架**: Flask 2.3.3 +- **数据库ORM**: SQLAlchemy 3.0.5 +- **数据库**: MySQL(通过PyMySQL连接) +- **认证**: JWT (Flask-JWT-Extended) +- **跨域**: Flask-CORS +- **图像处理**: Pillow + NumPy +- **数学计算**: NumPy + +## 开发说明 + +### 虚拟实现说明 + +当前所有算法都是**虚拟实现**,用于框架搭建和测试: + +1. **对抗性扰动算法**: 使用随机噪声模拟真实算法效果 +2. **评估指标**: 基于像素差异的简化计算 +3. **模型生成**: 通过图像变换模拟DreamBooth/LoRA效果 + +### 扩展指南 + +要集成真实算法: + +1. 替换 `app/algorithms/perturbation_engine.py` 中的虚拟实现 +2. 替换 `app/algorithms/evaluation_engine.py` 中的评估计算 +3. 
根据需要调整配置参数 + +### 目录权限 + +确保以下目录有写入权限: + +- `uploads/` - 用户上传文件 +- `static/originals/` - 重命名后的原始图片 +- `static/perturbed/` - 加噪后的图片 +- `static/model_outputs/` - 模型生成的图片 +- `static/heatmaps/` - 热力图文件 +- `static/demo/` - 演示图片(需要手动添加演示文件) + +## 许可证 + 本项目仅用于学习和研究目的。 \ No newline at end of file diff --git a/src/backend/app.py b/src/backend/app.py index 3d8a748..59d7da9 100644 --- a/src/backend/app.py +++ b/src/backend/app.py @@ -1,46 +1,46 @@ -""" -MuseGuard 后端主应用入口 -基于对抗性扰动的多风格图像生成防护系统 -""" - -from flask import Flask -from flask_sqlalchemy import SQLAlchemy -from flask_migrate import Migrate -from flask_jwt_extended import JWTManager -from flask_cors import CORS -from config.settings import Config - -# 初始化扩展 -db = SQLAlchemy() -migrate = Migrate() -jwt = JWTManager() - -def create_app(config_class=Config): - """Flask应用工厂函数""" - app = Flask(__name__) - app.config.from_object(config_class) - - # 初始化扩展 - db.init_app(app) - migrate.init_app(app, db) - jwt.init_app(app) - CORS(app) - - # 注册蓝图 - from app.controllers.auth_controller import auth_bp - from app.controllers.user_controller import user_bp - from app.controllers.task_controller import task_bp - from app.controllers.image_controller import image_bp - from app.controllers.admin_controller import admin_bp - - app.register_blueprint(auth_bp, url_prefix='/api/auth') - app.register_blueprint(user_bp, url_prefix='/api/user') - app.register_blueprint(task_bp, url_prefix='/api/task') - app.register_blueprint(image_bp, url_prefix='/api/image') - app.register_blueprint(admin_bp, url_prefix='/api/admin') - - return app - -if __name__ == '__main__': - app = create_app() +""" +MuseGuard 后端主应用入口 +基于对抗性扰动的多风格图像生成防护系统 +""" + +from flask import Flask +from flask_sqlalchemy import SQLAlchemy +from flask_migrate import Migrate +from flask_jwt_extended import JWTManager +from flask_cors import CORS +from config.settings import Config + +# 初始化扩展 +db = SQLAlchemy() +migrate = Migrate() +jwt = JWTManager() + +def create_app(config_class=Config): + """Flask应用工厂函数""" + app = Flask(__name__) + app.config.from_object(config_class) + + # 初始化扩展 + db.init_app(app) + migrate.init_app(app, db) + jwt.init_app(app) + CORS(app) + + # 注册蓝图 + from app.controllers.auth_controller import auth_bp + from app.controllers.user_controller import user_bp + from app.controllers.task_controller import task_bp + from app.controllers.image_controller import image_bp + from app.controllers.admin_controller import admin_bp + + app.register_blueprint(auth_bp, url_prefix='/api/auth') + app.register_blueprint(user_bp, url_prefix='/api/user') + app.register_blueprint(task_bp, url_prefix='/api/task') + app.register_blueprint(image_bp, url_prefix='/api/image') + app.register_blueprint(admin_bp, url_prefix='/api/admin') + + return app + +if __name__ == '__main__': + app = create_app() app.run(debug=True, host='0.0.0.0', port=5000) \ No newline at end of file diff --git a/src/backend/app/__init__.py b/src/backend/app/__init__.py index 4ed0625..f6890e5 100644 --- a/src/backend/app/__init__.py +++ b/src/backend/app/__init__.py @@ -1,83 +1,83 @@ -""" -MuseGuard 后端应用包 -基于对抗性扰动的多风格图像生成防护系统 -""" - -from flask import Flask -from flask_sqlalchemy import SQLAlchemy -from flask_migrate import Migrate -from flask_jwt_extended import JWTManager -from flask_cors import CORS -import os - -# 初始化扩展 -db = SQLAlchemy() -migrate = Migrate() -jwt = JWTManager() -cors = CORS() - -def create_app(config_name=None): - """应用工厂函数""" - # 配置静态文件和模板文件路径 - app = Flask(__name__, - static_folder='../static', - 
static_url_path='/static') - - # 加载配置 - if config_name is None: - config_name = os.environ.get('FLASK_ENV', 'development') - - from config.settings import config - app.config.from_object(config[config_name]) - - # 初始化扩展 - db.init_app(app) - migrate.init_app(app, db) - jwt.init_app(app) - cors.init_app(app) - - # 注册蓝图 - from app.controllers.auth_controller import auth_bp - from app.controllers.user_controller import user_bp - from app.controllers.task_controller import task_bp - from app.controllers.image_controller import image_bp - from app.controllers.admin_controller import admin_bp - from app.controllers.demo_controller import demo_bp - - app.register_blueprint(auth_bp, url_prefix='/api/auth') - app.register_blueprint(user_bp, url_prefix='/api/user') - app.register_blueprint(task_bp, url_prefix='/api/task') - app.register_blueprint(image_bp, url_prefix='/api/image') - app.register_blueprint(admin_bp, url_prefix='/api/admin') - app.register_blueprint(demo_bp, url_prefix='/api/demo') - - # 注册错误处理器 - @app.errorhandler(404) - def not_found_error(error): - return {'error': 'Not found'}, 404 - - @app.errorhandler(500) - def internal_error(error): - db.session.rollback() - return {'error': 'Internal server error'}, 500 - - # 根路由 - @app.route('/') - def index(): - return { - 'message': 'MuseGuard API Server', - 'version': '1.0.0', - 'status': 'running', - 'endpoints': { - 'health': '/health', - 'api_docs': '/api', - 'test_page': '/static/test.html' - } - } - - # 健康检查端点 - @app.route('/health') - def health_check(): - return {'status': 'healthy', 'message': 'MuseGuard backend is running'} - +""" +MuseGuard 后端应用包 +基于对抗性扰动的多风格图像生成防护系统 +""" + +from flask import Flask +from flask_sqlalchemy import SQLAlchemy +from flask_migrate import Migrate +from flask_jwt_extended import JWTManager +from flask_cors import CORS +import os + +# 初始化扩展 +db = SQLAlchemy() +migrate = Migrate() +jwt = JWTManager() +cors = CORS() + +def create_app(config_name=None): + """应用工厂函数""" + # 配置静态文件和模板文件路径 + app = Flask(__name__, + static_folder='../static', + static_url_path='/static') + + # 加载配置 + if config_name is None: + config_name = os.environ.get('FLASK_ENV', 'development') + + from config.settings import config + app.config.from_object(config[config_name]) + + # 初始化扩展 + db.init_app(app) + migrate.init_app(app, db) + jwt.init_app(app) + cors.init_app(app) + + # 注册蓝图 + from app.controllers.auth_controller import auth_bp + from app.controllers.user_controller import user_bp + from app.controllers.task_controller import task_bp + from app.controllers.image_controller import image_bp + from app.controllers.admin_controller import admin_bp + from app.controllers.demo_controller import demo_bp + + app.register_blueprint(auth_bp, url_prefix='/api/auth') + app.register_blueprint(user_bp, url_prefix='/api/user') + app.register_blueprint(task_bp, url_prefix='/api/task') + app.register_blueprint(image_bp, url_prefix='/api/image') + app.register_blueprint(admin_bp, url_prefix='/api/admin') + app.register_blueprint(demo_bp, url_prefix='/api/demo') + + # 注册错误处理器 + @app.errorhandler(404) + def not_found_error(error): + return {'error': 'Not found'}, 404 + + @app.errorhandler(500) + def internal_error(error): + db.session.rollback() + return {'error': 'Internal server error'}, 500 + + # 根路由 + @app.route('/') + def index(): + return { + 'message': 'MuseGuard API Server', + 'version': '1.0.0', + 'status': 'running', + 'endpoints': { + 'health': '/health', + 'api_docs': '/api', + 'test_page': '/static/test.html' + } + } + + # 健康检查端点 + 
@app.route('/health') + def health_check(): + return {'status': 'healthy', 'message': 'MuseGuard backend is running'} + return app \ No newline at end of file diff --git a/src/backend/app/algorithms/evaluation_engine.py b/src/backend/app/algorithms/evaluation_engine.py deleted file mode 100644 index b6ab294..0000000 --- a/src/backend/app/algorithms/evaluation_engine.py +++ /dev/null @@ -1,300 +0,0 @@ -""" -评估引擎 -实现图像质量评估和模型生成对比的虚拟版本 -""" - -import os -import numpy as np -from PIL import Image, ImageDraw -import uuid -import random -from flask import current_app - -class EvaluationEngine: - """评估处理引擎""" - - @staticmethod - def evaluate_image_quality(original_path, perturbed_path): - """ - 评估图像质量指标 - - Args: - original_path: 原始图片路径 - perturbed_path: 扰动后图片路径 - - Returns: - 包含各种评估指标的字典 - """ - try: - # 加载图片 - with Image.open(original_path) as orig_img, Image.open(perturbed_path) as pert_img: - # 确保图片尺寸一致 - if orig_img.size != pert_img.size: - pert_img = pert_img.resize(orig_img.size) - - # 转换为RGB - if orig_img.mode != 'RGB': - orig_img = orig_img.convert('RGB') - if pert_img.mode != 'RGB': - pert_img = pert_img.convert('RGB') - - # 计算虚拟评估指标 - fid_score = EvaluationEngine._calculate_mock_fid(orig_img, pert_img) - lpips_score = EvaluationEngine._calculate_mock_lpips(orig_img, pert_img) - ssim_score = EvaluationEngine._calculate_mock_ssim(orig_img, pert_img) - psnr_score = EvaluationEngine._calculate_mock_psnr(orig_img, pert_img) - - # 生成热力图 - heatmap_path = EvaluationEngine._generate_difference_heatmap( - orig_img, pert_img, "quality" - ) - - return { - 'fid': fid_score, - 'lpips': lpips_score, - 'ssim': ssim_score, - 'psnr': psnr_score, - 'heatmap_path': heatmap_path - } - - except Exception as e: - print(f"图像质量评估失败: {str(e)}") - return { - 'fid': None, - 'lpips': None, - 'ssim': None, - 'psnr': None, - 'heatmap_path': None - } - - @staticmethod - def evaluate_model_generation(original_path, perturbed_path, finetune_method): - """ - 评估模型生成对比 - - Args: - original_path: 原始图片路径 - perturbed_path: 扰动后图片路径 - finetune_method: 微调方法 (dreambooth, lora) - - Returns: - 包含生成对比评估指标的字典 - """ - try: - # 模拟训练过程并生成对比图片 - with Image.open(original_path) as orig_img, Image.open(perturbed_path) as pert_img: - # 模拟原始图片训练后的生成结果 - orig_generated = EvaluationEngine._simulate_model_generation(orig_img, finetune_method, False) - - # 模拟扰动图片训练后的生成结果 - pert_generated = EvaluationEngine._simulate_model_generation(pert_img, finetune_method, True) - - # 计算生成质量对比指标 - fid_score = EvaluationEngine._calculate_generation_fid(orig_generated, pert_generated) - lpips_score = EvaluationEngine._calculate_generation_lpips(orig_generated, pert_generated) - ssim_score = EvaluationEngine._calculate_generation_ssim(orig_generated, pert_generated) - psnr_score = EvaluationEngine._calculate_generation_psnr(orig_generated, pert_generated) - - # 生成对比热力图 - heatmap_path = EvaluationEngine._generate_difference_heatmap( - orig_generated, pert_generated, "generation" - ) - - return { - 'fid': fid_score, - 'lpips': lpips_score, - 'ssim': ssim_score, - 'psnr': psnr_score, - 'heatmap_path': heatmap_path - } - - except Exception as e: - print(f"模型生成评估失败: {str(e)}") - return { - 'fid': None, - 'lpips': None, - 'ssim': None, - 'psnr': None, - 'heatmap_path': None - } - - @staticmethod - def _calculate_mock_fid(img1, img2): - """模拟FID计算""" - # 简单的像素差异统计作为FID近似 - arr1 = np.array(img1, dtype=np.float32) - arr2 = np.array(img2, dtype=np.float32) - - mse = np.mean((arr1 - arr2) ** 2) - # FID通常在0-300之间,扰动图片应该有较小的差异 - mock_fid = min(mse / 10.0 + 
random.uniform(0.5, 2.0), 50.0) - return round(mock_fid, 4) - - @staticmethod - def _calculate_mock_lpips(img1, img2): - """模拟LPIPS计算""" - arr1 = np.array(img1, dtype=np.float32) / 255.0 - arr2 = np.array(img2, dtype=np.float32) / 255.0 - - # 简单的感知差异模拟 - diff = np.abs(arr1 - arr2) - # 模拟感知权重(边缘区域权重更高) - gray1 = np.mean(arr1, axis=2) - edges = np.abs(np.gradient(gray1)[0]) + np.abs(np.gradient(gray1)[1]) - - weighted_diff = np.mean(diff * (1 + edges.reshape(edges.shape + (1,)))) - mock_lpips = min(weighted_diff + random.uniform(0.001, 0.01), 1.0) - return round(mock_lpips, 4) - - @staticmethod - def _calculate_mock_ssim(img1, img2): - """模拟SSIM计算""" - arr1 = np.array(img1, dtype=np.float32) - arr2 = np.array(img2, dtype=np.float32) - - # 简单的结构相似性模拟 - mu1 = np.mean(arr1) - mu2 = np.mean(arr2) - - sigma1_sq = np.var(arr1) - sigma2_sq = np.var(arr2) - sigma12 = np.mean((arr1 - mu1) * (arr2 - mu2)) - - c1 = (0.01 * 255) ** 2 - c2 = (0.03 * 255) ** 2 - - ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / \ - ((mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2)) - - # SSIM在-1到1之间,但通常在0.8-1.0之间为良好 - mock_ssim = max(0.85 + random.uniform(-0.05, 0.1), 0.0) - return round(mock_ssim, 4) - - @staticmethod - def _calculate_mock_psnr(img1, img2): - """模拟PSNR计算""" - arr1 = np.array(img1, dtype=np.float32) - arr2 = np.array(img2, dtype=np.float32) - - mse = np.mean((arr1 - arr2) ** 2) - if mse == 0: - return 100.0 - - max_pixel = 255.0 - psnr = 20 * np.log10(max_pixel / np.sqrt(mse)) - - # 对于轻微扰动,PSNR应该比较高(30-50dB) - mock_psnr = max(min(psnr + random.uniform(-2, 2), 50.0), 20.0) - return round(mock_psnr, 4) - - @staticmethod - def _simulate_model_generation(img, finetune_method, is_perturbed): - """模拟模型生成过程""" - # 对输入图像进行变换以模拟生成结果 - arr = np.array(img, dtype=np.float32) - - if finetune_method == "dreambooth": - # DreamBooth风格变换 - if is_perturbed: - # 扰动图片训练的模型生成质量较差 - noise_level = 15.0 - style_change = 0.3 - else: - # 原始图片训练的模型生成质量较好 - noise_level = 5.0 - style_change = 0.1 - - elif finetune_method == "lora": - # LoRA风格变换 - if is_perturbed: - noise_level = 12.0 - style_change = 0.25 - else: - noise_level = 3.0 - style_change = 0.08 - else: - noise_level = 8.0 - style_change = 0.15 - - # 添加噪声模拟生成质量 - noise = np.random.normal(0, noise_level, arr.shape) - generated_arr = arr + noise - - # 模拟风格变化 - if style_change > 0: - # 简单的颜色偏移 - generated_arr[:,:,0] *= (1 + style_change * random.uniform(-0.5, 0.5)) - generated_arr[:,:,1] *= (1 + style_change * random.uniform(-0.5, 0.5)) - generated_arr[:,:,2] *= (1 + style_change * random.uniform(-0.5, 0.5)) - - generated_arr = np.clip(generated_arr, 0, 255).astype(np.uint8) - return Image.fromarray(generated_arr) - - @staticmethod - def _calculate_generation_fid(img1, img2): - """计算生成图片的FID(质量越差FID越高越好)""" - fid = EvaluationEngine._calculate_mock_fid(img1, img2) - # 对于防护效果,FID越高越好 - return max(fid, 10.0) - - @staticmethod - def _calculate_generation_lpips(img1, img2): - """计算生成图片的LPIPS(差异越大越好)""" - lpips = EvaluationEngine._calculate_mock_lpips(img1, img2) - return max(lpips, 0.1) - - @staticmethod - def _calculate_generation_ssim(img1, img2): - """计算生成图片的SSIM(相似度越低越好)""" - ssim = EvaluationEngine._calculate_mock_ssim(img1, img2) - # 对于防护效果,SSIM越低越好 - return min(ssim, 0.9) - - @staticmethod - def _calculate_generation_psnr(img1, img2): - """计算生成图片的PSNR(质量差异越大越好)""" - psnr = EvaluationEngine._calculate_mock_psnr(img1, img2) - return min(psnr, 35.0) - - @staticmethod - def _generate_difference_heatmap(img1, img2, eval_type): - """生成差异热力图""" - try: - arr1 = 
np.array(img1, dtype=np.float32)
-            arr2 = np.array(img2, dtype=np.float32)
-
-            # 计算像素差异
-            diff = np.abs(arr1 - arr2)
-            diff_gray = np.mean(diff, axis=2)
-
-            # 归一化到0-255
-            diff_normalized = (diff_gray / diff_gray.max() * 255).astype(np.uint8)
-
-            # 创建热力图
-            heatmap = Image.fromarray(diff_normalized, mode='L')
-            heatmap = heatmap.convert('RGB')
-
-            # 应用颜色映射(蓝色-绿色-红色)
-            heatmap_array = np.array(heatmap)
-            colored_heatmap = np.zeros_like(arr1, dtype=np.uint8)
-
-            # 简单的颜色映射
-            intensity = heatmap_array[:,:,0] / 255.0
-            colored_heatmap[:,:,0] = (intensity * 255).astype(np.uint8) # 红色通道
-            colored_heatmap[:,:,1] = ((1 - intensity) * intensity * 4 * 255).astype(np.uint8) # 绿色通道
-            colored_heatmap[:,:,2] = ((1 - intensity) * 255).astype(np.uint8) # 蓝色通道
-
-            # 保存热力图
-            heatmap_dir = os.path.join('static', 'heatmaps')
-            os.makedirs(heatmap_dir, exist_ok=True)
-
-            heatmap_filename = f"{eval_type}_heatmap_{uuid.uuid4().hex[:8]}.png"
-            heatmap_path = os.path.join(heatmap_dir, heatmap_filename)
-
-            Image.fromarray(colored_heatmap).save(heatmap_path)
-
-            return heatmap_path
-
-        except Exception as e:
-            print(f"生成热力图失败: {str(e)}")
-            return None
\ No newline at end of file
diff --git a/src/backend/app/algorithms/finetune/infer.py b/src/backend/app/algorithms/finetune/infer.py
new file mode 100644
index 0000000..624fcc6
--- /dev/null
+++ b/src/backend/app/algorithms/finetune/infer.py
@@ -0,0 +1,87 @@
+import argparse
+import os
+
+import torch
+from diffusers import StableDiffusionPipeline
+from torchvision.utils import make_grid
+from pytorch_lightning import seed_everything
+from PIL import Image
+import numpy as np
+from einops import rearrange
+
+parser = argparse.ArgumentParser(description="Inference")
+parser.add_argument(
+    "--model_path",
+    type=str,
+    default=None,
+    required=True,
+    help="Path to pretrained model or model identifier from huggingface.co/models.",
+)
+parser.add_argument(
+    "--output_dir",
+    type=str,
+    default="./test-infer/",
+    help="The output directory where predictions are saved",
+)
+parser.add_argument(
+    "--seed",
+    type=int,
+    default=42,
+    help="Random seed for reproducible sampling.",
+)
+parser.add_argument(
+    "--v",
+    type=str,
+    default="sks",
+    help="Identifier token bound to the subject during fine-tuning (e.g. sks).",
+)
+
+args = parser.parse_args()
+
+if __name__ == "__main__":
+    seed_everything(args.seed)
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # define prompts
+    prompts = [
+        f"a photo of {args.v} person",
+        f"a dslr portrait of {args.v} person",
+        f"a photo of {args.v} person looking at the mirror",
+        f"a photo of {args.v} person in front of eiffel tower",
+    ]
+
+
+    # create & load model
+    pipe = StableDiffusionPipeline.from_pretrained(
+        args.model_path,
+        torch_dtype=torch.float32,
+        safety_checker=None,
+        local_files_only=True,
+    ).to("cuda")
+
+    for prompt in prompts:
+        print(">>>>>>", prompt)
+        norm_prompt = prompt.lower().replace(",", "").replace(" ", "_")
+        out_path = f"{args.output_dir}/{norm_prompt}"
+        os.makedirs(out_path, exist_ok=True)
+        all_samples = list()
+        for i in range(5):
+            images = pipe([prompt] * 6, num_inference_steps=100, guidance_scale=7.5,).images
+            for idx, image in enumerate(images):
+                image.save(f"{out_path}/{i}_{idx}.png")
+                image = np.array(image, dtype=np.float32)
+                image /= 255.0
+                image = np.transpose(image, (2, 0, 1))
+                image = torch.from_numpy(image) # numpy->tensor
+                all_samples.append(image)
+        grid = torch.stack(all_samples, 0)
+        # grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+        grid = make_grid(grid, nrow=8)
+        # to image
+
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img.save(f"{args.output_dir}/{prompt}.png") + torch.cuda.empty_cache() + + del pipe + torch.cuda.empty_cache() diff --git a/src/backend/app/algorithms/finetune/train_dreambooth_alone.py b/src/backend/app/algorithms/finetune/train_dreambooth_alone.py new file mode 100644 index 0000000..52a04a5 --- /dev/null +++ b/src/backend/app/algorithms/finetune/train_dreambooth_alone.py @@ -0,0 +1,1035 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import hashlib +import itertools +import logging +import math +import os +import warnings +from pathlib import Path +from typing import Optional + +import datasets +import diffusers +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import HfFolder, create_repo, whoami +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.13.0.dev0") + +logger = get_logger(__name__) + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." 
+ ), + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--inference_prompts", + type=str, + default=None, + help="The prompt used to generate images at inference.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=4, + help="Batch size (per device) for sampling images.", + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. 
Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. 
Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. 
+ """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def collate_fn(examples, with_prior_preservation=False): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
+ + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def infer(checkpoint_path, ckpt_pipeline, prompts=None, n_img=16, bs=8, n_steps=100, guidance_scale=7.5): + if ckpt_pipeline is None: + pipe = StableDiffusionPipeline.from_pretrained( + checkpoint_path, torch_dtype=torch.bfloat16, safety_checker=None + ).to("cuda") + else: + pipe = ckpt_pipeline.to("cuda") + pipe.enable_xformers_memory_efficient_attention() + pipe.disable_attention_slicing() + + for prompt in prompts: + print(prompt) + norm_prompt = prompt.lower().replace(",", "").replace(" ", "_") + out_path = f"{checkpoint_path}/dreambooth/{norm_prompt}" + os.makedirs(out_path, exist_ok=True) + for i in range(n_img // bs): + images = pipe( + [prompt] * bs, + num_inference_steps=n_steps, + guidance_scale=guidance_scale, + ).images + for idx, image in enumerate(images): + image.save(f"{out_path}/{i}_{idx}.png") + del pipe + + +class LatentsDataset(Dataset): + def __init__(self, latents_cache, text_encoder_cache): + self.latents_cache = latents_cache + self.text_encoder_cache = text_encoder_cache + + def __len__(self): + return len(self.latents_cache) + + def __getitem__(self, index): + return self.latents_cache[index], self.text_encoder_cache[index] + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + logging_dir=logging_dir, + ) + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. 
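+    # The pretrained pipeline below samples generic images for args.class_prompt
+    # until class_data_dir holds num_class_images; during training they act as a
+    # regularizer so the fine-tuned model keeps the broad class concept while
+    # learning the instance token.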
+ if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + pipeline.enable_xformers_memory_efficient_attention() + pipeline.disable_attention_slicing() + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if 
args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + "Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training. copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: + raise ValueError( + f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." + f" {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=False, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. 
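+    # The dtype chosen below follows accelerate's mixed_precision setting:
+    # "fp16" -> torch.float16, "bf16" -> torch.bfloat16, otherwise float32
+    # (bf16 generally requires Ampere-or-newer GPUs). The VAE is used once,
+    # just below, to pre-compute and cache the latents, after which it is
+    # deleted to free GPU memory before training starts.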
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + vae.to(device, dtype=weight_dtype) + + latents_cache = [] + text_encoder_cache = [] + + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to(device, dtype=weight_dtype) + batch["input_ids"] = batch["input_ids"].to(device) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if args.train_text_encoder: + text_encoder_cache.append(batch["input_ids"]) + else: + text_encoder_cache.append(text_encoder(batch["input_ids"])[0]) + train_dataset = LatentsDataset(latents_cache, text_encoder_cache) + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True) + scaling_factor = vae.config.scaling_factor + del vae + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + ( + unet, + text_encoder, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # Move vae and text_encoder to device and cast to weight_dtype + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm( + range(global_step, args.max_train_steps), + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latent_dist = batch[0][0] + latents = latent_dist.sample() + latents = latents * scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch[0][1])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
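+                    # collate_fn stacked the instance examples first and the
+                    # class examples second, so splitting along dim=0 recovers
+                    # the two equally sized halves of the batch.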
+                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                    target, target_prior = torch.chunk(target, 2, dim=0)
+
+                    # Compute instance loss
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                    # Compute prior loss
+                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+                    # Add the prior loss to the instance loss.
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    params_to_clip = (
+                        itertools.chain(unet.parameters(), text_encoder.parameters())
+                        if args.train_text_encoder
+                        else unet.parameters()
+                    )
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
+                        save_path = args.output_dir
+                        ckpt_pipeline = DiffusionPipeline.from_pretrained(
+                            args.pretrained_model_name_or_path,
+                            unet=accelerator.unwrap_model(unet),
+                            text_encoder=accelerator.unwrap_model(text_encoder),
+                            revision=args.revision,
+                        )
+                        if global_step < 1000:
+                            prompts = args.inference_prompts.split(";")
+                            infer(save_path, ckpt_pipeline, prompts, n_img=16, bs=4, n_steps=100)
+                        else:
+                            ckpt_pipeline.save_pretrained(save_path)
+                            del ckpt_pipeline
+                            prompts = args.inference_prompts.split(";")
+                            ckpt_pipeline = None
+                            # infer(save_path, ckpt_pipeline, prompts, n_img=16, bs=4, n_steps=100)
+                        logger.info(f"Saved state to {save_path}")
+
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        print("Finished training")
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
\ No newline at end of file
diff --git a/src/backend/app/algorithms/finetune/train_dreambooth_gen.py b/src/backend/app/algorithms/finetune/train_dreambooth_gen.py
new file mode 100644
index 0000000..fb9721e
--- /dev/null
+++ b/src/backend/app/algorithms/finetune/train_dreambooth_gen.py
@@ -0,0 +1,1450 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import importlib
+import itertools
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, model_info, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DiffusionPipeline,
+    StableDiffusionPipeline,
+    UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+    import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+# check_min_version("0.30.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+    repo_id: str,
+    images: list = None,
+    base_model: str = None,
+    train_text_encoder=False,
+    prompt: str = None,
+    repo_folder: str = None,
+    pipeline: DiffusionPipeline = None,
+):
+    img_str = ""
+    if images is not None:
+        for i, image in enumerate(images):
+            image.save(os.path.join(repo_folder, f"image_{i}.png"))
+            img_str += f"![img_{i}](./image_{i}.png)\n"
+
+    model_description = f"""
+# DreamBooth - {repo_id}
+
+This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
+You can find some example images in the following. \n
+{img_str}
+
+DreamBooth for the text encoder was enabled: {train_text_encoder}.
+"""
+    model_card = load_or_create_model_card(
+        repo_id_or_path=repo_id,
+        from_training=True,
+        license="creativeml-openrail-m",
+        base_model=base_model,
+        prompt=prompt,
+        model_description=model_description,
+        inference=True,
+    )
+
+    tags = ["text-to-image", "dreambooth", "diffusers-training"]
+    if isinstance(pipeline, StableDiffusionPipeline):
+        tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
+    else:
+        tags.extend(["if", "if-diffusers"])
+    model_card = populate_model_card(model_card, tags=tags)
+
+    model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+    text_encoder,
+    tokenizer,
+    unet,
+    vae,
+    args,
+    accelerator,
+    weight_dtype,
+    global_step,
+    prompt_embeds,
+    negative_prompt_embeds,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+ ) + + pipeline_args = {} + + if vae is not None: + pipeline_args["vae"] = vae + + # create pipeline (note: unet and vae are loaded again in float32) + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + **pipeline_args, + ) + + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + module = importlib.import_module("diffusers") + scheduler_class = getattr(module, args.validation_scheduler) + pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + pipeline.safety_checker = lambda images, clip_input: (images, [False for i in range(0, len(images))]) # disable safety checker + + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} + + # run inference + generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + if args.validation_images is None: + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] + images.append(image) + else: + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + return images + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from 
huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 
+            " Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+            " See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+            " instructions."
+        ),
+    )
+    parser.add_argument(
+        "--checkpoints_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more details"
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--lr_num_cycles",
+        type=int,
+        default=1,
+        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+    )
+    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument(
+        "--validation_prompt",
+        type=str,
+        default=None,
+        help="A prompt that is used during validation to verify that the model is learning.",
+    )
+    parser.add_argument(
+        "--num_validation_images",
+        type=int,
+        default=4,
+        help="Number of images that should be generated during validation with `validation_prompt`.",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`"
+            " and logging the images."
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+    parser.add_argument(
+        "--prior_generation_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp32", "fp16", "bf16"],
+        help=(
+            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
+    parser.add_argument(
+        "--set_grads_to_none",
+        action="store_true",
+        help=(
+            "Save more memory by setting grads to None instead of zero. 
Be aware that this changes certain"
+            " behaviors, so disable this argument if it causes any problems. More info:"
+            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+        ),
+    )
+
+    parser.add_argument(
+        "--offset_noise",
+        action="store_true",
+        default=False,
+        help=(
+            "Fine-tuning against a modified noise."
+            " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
+        ),
+    )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
+    parser.add_argument(
+        "--pre_compute_text_embeddings",
+        action="store_true",
+        help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+    )
+    parser.add_argument(
+        "--tokenizer_max_length",
+        type=int,
+        default=None,
+        required=False,
+        help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+    )
+    parser.add_argument(
+        "--text_encoder_use_attention_mask",
+        action="store_true",
+        required=False,
+        help="Whether to use attention mask for the text encoder",
+    )
+    parser.add_argument(
+        "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
+    )
+    parser.add_argument(
+        "--validation_images",
+        required=False,
+        default=None,
+        nargs="+",
+        help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+    )
+    parser.add_argument(
+        "--class_labels_conditioning",
+        required=False,
+        default=None,
+        help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+    )
+    parser.add_argument(
+        "--validation_scheduler",
+        type=str,
+        default="DPMSolverMultistepScheduler",
+        choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
+        help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
+    )
+
+    parser.add_argument(
+        "--validation_image_output_dir",
+        type=str,
+        default=None,
+        help="The directory where validation images will be saved. 
If None, images will be saved inside a subdirectory of `output_dir`.",
+    )
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.with_prior_preservation:
+        if args.class_data_dir is None:
+            raise ValueError("You must specify a data directory for class images.")
+        if args.class_prompt is None:
+            raise ValueError("You must specify a prompt for class images.")
+    else:
+        # logger is not available yet
+        if args.class_data_dir is not None:
+            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+        if args.class_prompt is not None:
+            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+    if args.train_text_encoder and args.pre_compute_text_embeddings:
+        raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+    return args
+
+
+class DreamBoothDataset(Dataset):
+    """
+    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+    """
+
+    def __init__(
+        self,
+        instance_data_root,
+        instance_prompt,
+        tokenizer,
+        class_data_root=None,
+        class_prompt=None,
+        class_num=None,
+        size=512,
+        center_crop=False,
+        encoder_hidden_states=None,
+        class_prompt_encoder_hidden_states=None,
+        tokenizer_max_length=None,
+    ):
+        self.size = size
+        self.center_crop = center_crop
+        self.tokenizer = tokenizer
+        self.encoder_hidden_states = encoder_hidden_states
+        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+        self.tokenizer_max_length = tokenizer_max_length
+
+        self.instance_data_root = Path(instance_data_root)
+        if not self.instance_data_root.exists():
+            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exist.")
+
+        self.instance_images_path = list(Path(instance_data_root).iterdir())
+        self.num_instance_images = len(self.instance_images_path)
+        self.instance_prompt = instance_prompt
+        self._length = self.num_instance_images
+
+        if class_data_root is not None:
+            self.class_data_root = Path(class_data_root)
+            self.class_data_root.mkdir(parents=True, exist_ok=True)
+            self.class_images_path = list(self.class_data_root.iterdir())
+            if class_num is not None:
+                self.num_class_images = min(len(self.class_images_path), class_num)
+            else:
+                self.num_class_images = len(self.class_images_path)
+            self._length = max(self.num_class_images, self.num_instance_images)
+            self.class_prompt = class_prompt
+        else:
+            self.class_data_root = None
+
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+    def __len__(self):
+        return self._length
+
+    def __getitem__(self, index):
+        example = {}
+        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+        instance_image = exif_transpose(instance_image)
+
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+        example["instance_images"] = self.image_transforms(instance_image)
+
+        if self.encoder_hidden_states is not None:
+            example["instance_prompt_ids"] = self.encoder_hidden_states
+        else:
+            text_inputs = tokenize_prompt(
self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + + if self.class_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask + + return example + + +def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + if has_attention_mask: + attention_mask += [example["class_attention_mask"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + + if has_attention_mask: + attention_mask = torch.cat(attention_mask, dim=0) + batch["attention_mask"] = attention_mask + + return batch + + +class PromptDataset(Dataset): + """A simple dataset to prepare the prompts to generate class images on multiple GPUs.""" + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def model_has_vae(args): + config_file_name = Path("vae", AutoencoderKL.config_name).as_posix() + if os.path.isdir(args.pretrained_model_name_or_path): + config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) + return os.path.isfile(config_file_name) + else: + files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings + return any(file.rfilename == config_file_name for file in files_in_repo) + + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + 
else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + return_dict=False, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. 
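+    # Class images are only sampled when class_data_dir holds fewer than
+    # num_class_images files; the sha1 digest embedded in each filename (via
+    # insecure_hashlib) is a cheap uniqueness tag, not a security measure.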
+ if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + + if model_has_vae(args): + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + else: + vae = None + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + 
sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + if isinstance(model, type(unwrap_model(text_encoder))): + # load transformers style into model + load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") + model.config = load_model.config + else: + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if vae is not None: + vae.requires_grad_(False) + + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + "Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training. copy of the weights should still be float32." + ) + + if unwrap_model(unet).dtype != torch.float32: + raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}") + + if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: + raise ValueError( + f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.class_prompt is not None: + pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) + else: + pre_computed_class_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_class_prompt_encoder_hidden_states = None + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + class_num=args.num_class_images, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. 
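+    # accelerator.prepare wraps the model(s), optimizer, dataloader and LR
+    # scheduler for the current launch configuration (DDP, mixed precision,
+    # etc.); the text encoder is included only when it is being trained.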
+ if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + + if not args.train_text_encoder and text_encoder is not None: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = vars(copy.deepcopy(args)) + tracker_config.pop("validation_images") + accelerator.init_trackers("dreambooth", config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + # Check if output_dir contains saved state files (simplified check for state files) + # We look for the presence of the required files saved by accelerator.save_state in the output directory + required_files = ["pytorch_model.bin", "optimizer.bin"] # Simplified check + + has_saved_state = all(os.path.exists(os.path.join(args.output_dir, f)) for f in required_files) + + if args.resume_from_checkpoint == "latest" and not has_saved_state: + accelerator.print( + f"Checkpoint does not exist in '{args.output_dir}'. Starting a new training run." 
+ ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint at {args.output_dir}") + # Load state directly from args.output_dir + accelerator.load_state(args.output_dir) + + # Since we are loading from the main directory, we trust accelerator.load_state + # to restore the global_step correctly from the state saved in that directory. + # We initialize global_step/initial_global_step/first_epoch using the restored state after load_state. + # For simplicity, we keep the original logic's initialization structure but adjust the path/logic. + # Accelerator will internally restore the true global_step. We set temporary values. + # Note: A cleaner solution often involves saving/loading a separate 'step.json' file for global_step tracking + # when relying on in-place saving without automatic tracking of step in folder names. + # For this simple replacement, we let the accelerator handle it. + global_step = 0 + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + + if vae is not None: + # Convert images to latent space + model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + else: + model_input = pixel_values + + # Sample noise that we'll add to the model input + if args.offset_noise: + noise = torch.randn_like(model_input) + 0.1 * torch.randn( + model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device + ) + else: + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Get the text embedding for conditioning + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None + + # Predict the noise residual + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False + )[0] + + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif 
noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Compute instance loss + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + base_weight = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # Save state directly to output_dir, replacing previous checkpoint + # checkpoints_total_limit logic is skipped as we are only keeping one checkpoint (the latest) + save_path = args.output_dir + accelerator.save_state(save_path) + logger.info(f"Saved state directly to {save_path}, replacing previous checkpoint at step {global_step}") + + images = [] + + if args.validation_prompt is not None and (global_step + 1) % args.validation_steps == 0: + images = log_validation( + unwrap_model(text_encoder) if text_encoder is not None else text_encoder, + tokenizer, + unwrap_model(unet), + vae, + args, + accelerator, + weight_dtype, + global_step, + validation_prompt_encoder_hidden_states, + validation_prompt_negative_prompt_embeds, + ) + + # Save validation images directly to output_dir + save_path = Path(args.validation_image_output_dir) if args.validation_image_output_dir else Path(args.output_dir) + save_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Saving validation images directly to {save_path}, overwriting previous images.") + + for i, image in enumerate(images): + # The file name is constant, thus overwriting + image.save(save_path / f"validation_image_{i}.png") + + logs = {"loss": loss.detach().item(), 
"lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + pipeline_args = {} + + if text_encoder is not None: + pipeline_args["text_encoder"] = unwrap_model(text_encoder) + + if args.skip_save_text_encoder: + pipeline_args["text_encoder"] = None + + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=unwrap_model(unet), + revision=args.revision, + variant=args.variant, + **pipeline_args, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + pipeline=pipeline, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/src/backend/app/algorithms/finetune/train_lora_gen.py b/src/backend/app/algorithms/finetune/train_lora_gen.py new file mode 100644 index 0000000..29e6ad5 --- /dev/null +++ b/src/backend/app/algorithms/finetune/train_lora_gen.py @@ -0,0 +1,1434 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import argparse +import copy +import gc +import logging +import math +import os +import shutil +import warnings +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import LoraLoaderMixin +from diffusers.optimization import get_scheduler +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params +from diffusers.utils import ( + check_min_version, + convert_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +# check_min_version("0.30.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + model_description = f""" +# LoRA DreamBooth - {repo_id} + +These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n +{img_str} + +LoRA for the text encoder was enabled: {train_text_encoder}. +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + prompt=prompt, + model_description=model_description, + inference=True, + ) + tags = ["text-to-image", "diffusers", "lora", "diffusers-training"] + if isinstance(pipeline, StableDiffusionPipeline): + tags.extend(["stable-diffusion", "stable-diffusion-diffusers"]) + else: + tags.extend(["if", "if-diffusers"]) + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + pipeline.safety_checker = lambda images, clip_input: (images, [False for i in range(0, len(images))]) # disable safety checker + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + if args.validation_images is None: + images = [] + for _ in range(args.num_validation_images): + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, generator=generator).images[0] + images.append(image) + else: + images = [] + for image in args.validation_images: + image = Image.open(image) + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + return images + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + + parser.add_argument( + "--validation_image_output_dir", + type=str, + default=None, + help="The directory where validation images will be saved. If None, images will be saved inside a subdirectory of `output_dir`.", + ) + + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 
+ ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
+ ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--pre_compute_text_embeddings", + action="store_true", + help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", + ) + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. 
If not set, will default to the tokenizer's max length.",
+    )
+    parser.add_argument(
+        "--text_encoder_use_attention_mask",
+        action="store_true",
+        required=False,
+        help="Whether to use the attention mask for the text encoder.",
+    )
+    parser.add_argument(
+        "--validation_images",
+        required=False,
+        default=None,
+        nargs="+",
+        help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input, such as when training image variation or superresolution.",
+    )
+    parser.add_argument(
+        "--class_labels_conditioning",
+        required=False,
+        default=None,
+        help="The optional `class_label` conditioning to pass to the unet; available values are `timesteps`.",
+    )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.with_prior_preservation:
+        if args.class_data_dir is None:
+            raise ValueError("You must specify a data directory for class images.")
+        if args.class_prompt is None:
+            raise ValueError("You must specify a prompt for class images.")
+    else:
+        # logger is not available yet
+        if args.class_data_dir is not None:
+            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+        if args.class_prompt is not None:
+            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+    if args.train_text_encoder and args.pre_compute_text_embeddings:
+        raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+    return args
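+
+
+# Illustrative only: a typical `accelerate` launch of this script using the arguments
+# defined above. The model id and all paths are placeholder assumptions, not values
+# required by this repository:
+#
+#   accelerate launch train_lora_gen.py \
+#       --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
+#       --instance_data_dir ./instance_images \
+#       --instance_prompt "a photo of sks person" \
+#       --output_dir ./lora_out \
+#       --rank 4 \
+#       --max_train_steps 1000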
+ """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + class_num=None, + size=512, + center_crop=False, + encoder_hidden_states=None, + class_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + text_inputs = tokenize_prompt( + self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + + if self.class_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask + + return example + + +def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] for example in 
+
+
+def collate_fn(examples, with_prior_preservation=False):
+    has_attention_mask = "instance_attention_mask" in examples[0]
+
+    input_ids = [example["instance_prompt_ids"] for example in examples]
+    pixel_values = [example["instance_images"] for example in examples]
+
+    if has_attention_mask:
+        attention_mask = [example["instance_attention_mask"] for example in examples]
+
+    # Concat class and instance examples for prior preservation.
+    # We do this to avoid doing two forward passes.
+    if with_prior_preservation:
+        input_ids += [example["class_prompt_ids"] for example in examples]
+        pixel_values += [example["class_images"] for example in examples]
+        if has_attention_mask:
+            attention_mask += [example["class_attention_mask"] for example in examples]
+
+    pixel_values = torch.stack(pixel_values)
+    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+    input_ids = torch.cat(input_ids, dim=0)
+
+    batch = {
+        "input_ids": input_ids,
+        "pixel_values": pixel_values,
+    }
+
+    if has_attention_mask:
+        # Concatenate the per-example masks into one batch tensor to match input_ids;
+        # leaving this as a Python list would break the later `.to(device)` call in encode_prompt.
+        attention_mask = torch.cat(attention_mask, dim=0)
+        batch["attention_mask"] = attention_mask
+
+    return batch
+
+
+class PromptDataset(Dataset):
+    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+    def __init__(self, prompt, num_samples):
+        self.prompt = prompt
+        self.num_samples = num_samples
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, index):
+        example = {}
+        example["prompt"] = self.prompt
+        example["index"] = index
+        return example
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+    if tokenizer_max_length is not None:
+        max_length = tokenizer_max_length
+    else:
+        max_length = tokenizer.model_max_length
+
+    text_inputs = tokenizer(
+        prompt,
+        truncation=True,
+        padding="max_length",
+        max_length=max_length,
+        return_tensors="pt",
+    )
+
+    return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+    text_input_ids = input_ids.to(text_encoder.device)
+
+    if text_encoder_use_attention_mask:
+        attention_mask = attention_mask.to(text_encoder.device)
+    else:
+        attention_mask = None
+
+    prompt_embeds = text_encoder(
+        text_input_ids,
+        attention_mask=attention_mask,
+        return_dict=False,
+    )
+    prompt_embeds = prompt_embeds[0]
+
+    return prompt_embeds
+
+
+def main(args):
+    if args.report_to == "wandb" and args.hub_token is not None:
+        raise ValueError(
+            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+            " Please use `huggingface-cli login` to authenticate with the Hub."
+        )
+
+    logging_dir = Path(args.output_dir, args.logging_dir)
+
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        project_config=accelerator_project_config,
+    )
+
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
+    if args.report_to == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+    # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + try: + vae = 
AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + except OSError: + # IF does not have a VAE so let's just set it to None + # We don't have to error out here + vae = None + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + # We only train the additional adapter LoRA layers + if vae is not None: + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + # now we will add new LoRA weights to the attention layers + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"], + ) + unet.add_adapter(unet_lora_config) + + # The text encoder comes from 🤗 transformers, we will also attach adapters to it. + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder.add_adapter(text_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # there are only two options here. 
Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(unet))): + unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) + elif isinstance(model, type(unwrap_model(text_encoder))): + text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + LoraLoaderMixin.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + unet_ = None + text_encoder_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(unwrap_model(text_encoder))): + text_encoder_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + if args.train_text_encoder: + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [unet_] + if args.train_text_encoder: + models.append(text_encoder_) + + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.append(text_encoder) + + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) + if args.train_text_encoder: + params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters())) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.class_prompt is not None: + pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) + else: + pre_computed_class_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_class_prompt_encoder_hidden_states = None + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + class_num=args.num_class_images, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. 
+    if args.train_text_encoder:
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if accelerator.is_main_process:
+        tracker_config = vars(copy.deepcopy(args))
+        tracker_config.pop("validation_images")
+        accelerator.init_trackers("dreambooth-lora", config=tracker_config)
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        resume_path = args.output_dir
+
+        try:
+            accelerator.print(f"Resuming from checkpoint at {resume_path}")
+            accelerator.load_state(resume_path)
+
+            # `accelerator.load_state` restores the model/optimizer/scheduler state, but with
+            # in-place checkpoints there is no `checkpoint-<step>` directory name to parse the
+            # global step from, and `AcceleratorState` does not expose one. As in the DreamBooth
+            # variant of this script, the step counter therefore restarts from 0.
+            initial_global_step = 0
+            global_step = initial_global_step
+            first_epoch = global_step // num_update_steps_per_epoch
+
+            accelerator.print(f"Resumed model state from {resume_path}; step counter restarts at {global_step}")
+
+        except Exception as e:
+            accelerator.print(
+                f"Could not load state from '{resume_path}'. Starting a new training run. Error: {e}"
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+            first_epoch = 0
+    else:
+        initial_global_step = 0
+        first_epoch = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + + if vae is not None: + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + else: + model_input = pixel_values + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Get the text embedding for conditioning + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None + + # Predict the noise residual + model_pred = unet( + noisy_model_input, + timesteps, + encoder_hidden_states, + class_labels=class_labels, + return_dict=False, + )[0] + + # if model predicts variance, throw away the prediction. we will only train on the + # simplified training objective. This means that all schedulers using the fine tuned + # model must be configured to use one of the fixed variance variance types. + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. 
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+                if accelerator.is_main_process:
+                    if (global_step + 1) % args.checkpointing_steps == 0:
+                        # 1. Save the model parameters: write directly to args.output_dir, overwriting the previous round
+                        output_dir = args.output_dir
+                        # accelerator.save_state handles saving the models using the registered hooks
+                        accelerator.save_state(output_dir)
+                        logger.info(f"Saving state to {output_dir} at step {global_step+1}")
+
+                        # 2. Run inference with the model: use the latest weights written to args.output_dir
+                        # The base pipeline is re-loaded, and the LoRA weights are saved *to* args.output_dir
+                        # in the accelerator hook. Here, we must ensure we use the saved unet/text_encoder.
+                        pipeline = DiffusionPipeline.from_pretrained(
+                            args.pretrained_model_name_or_path,
+                            # Use the unwrapped models, which contain the latest trained LoRA weights
+                            unet=unwrap_model(unet),
+                            text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
+                            revision=args.revision,
+                            variant=args.variant,
+                            torch_dtype=weight_dtype,
+                        )
+
+                        if args.pre_compute_text_embeddings:
+                            pipeline_args = {
+                                "prompt_embeds": validation_prompt_encoder_hidden_states,
+                                "negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
+                            }
+                        else:
+                            pipeline_args = {"prompt": args.validation_prompt}
+
+                        images = log_validation(
+                            pipeline,
+                            args,
+                            accelerator,
+                            pipeline_args,
+                            epoch,
+                        )
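+
+                        # A minimal sketch (not executed here) of how the in-place checkpoint in
+                        # args.output_dir could be reloaded for standalone inference, mirroring the
+                        # final-inference block at the end of main(); the prompt is a placeholder:
+                        #
+                        #   pipe = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path)
+                        #   pipe.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+                        #   image = pipe("a photo of sks person").images[0]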
+                        # 3. Save the generated inference results directly to the target directory / output_dir, without creating subfolders
+                        base_save_path = Path(args.validation_image_output_dir or args.output_dir)
+                        base_save_path.mkdir(parents=True, exist_ok=True)
+                        logger.info(f"Saving validation images to {base_save_path}")
+
+                        # Images are written straight into base_save_path, overwriting same-named images from the previous round
+                        for i, image in enumerate(images):
+                            image.save(base_save_path / f"image_{i}.png")
+
+
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+
+    # Save the lora layers
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        unet = unwrap_model(unet)
+        unet = unet.to(torch.float32)
+
+        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+        if args.train_text_encoder:
+            text_encoder = unwrap_model(text_encoder)
+            text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
+        else:
+            text_encoder_state_dict = None
+
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=args.output_dir,
+            unet_lora_layers=unet_lora_state_dict,
+            text_encoder_lora_layers=text_encoder_state_dict,
+        )
+
+        # Final inference
+        # Load previous pipeline
+        pipeline = DiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
+        )
+
+        # load attention processors
+        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+
+        # run inference
+        images = []
+        if args.validation_prompt and args.num_validation_images > 0:
+            pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+            images = log_validation(
+                pipeline,
+                args,
+                accelerator,
+                pipeline_args,
+                epoch,
+                is_final_validation=True,
+            )
+
+        if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                train_text_encoder=args.train_text_encoder,
+                prompt=args.instance_prompt,
+                repo_folder=args.output_dir,
+                pipeline=pipeline,
+            )
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
\ No newline at end of file
diff --git a/src/backend/app/algorithms/finetune_virtual/train_dreambooth_gen.py b/src/backend/app/algorithms/finetune_virtual/train_dreambooth_gen.py
new file mode 100644
index 0000000..96475df
--- /dev/null
+++ b/src/backend/app/algorithms/finetune_virtual/train_dreambooth_gen.py
@@ -0,0 +1,134 @@
+"""
+Virtual DreamBooth fine-tuning implementation
+Used to test the backend flow; does not run actual model training
+"""
+import argparse
+import os
+import sys
+import platform
+import shutil
+import glob
+from PIL import Image, ImageDraw, ImageFont
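+
+# Illustrative only: how this mock script might be invoked for a quick backend test.
+# The paths are placeholder assumptions; the flags are the ones defined in main() below:
+#
+#   python train_dreambooth_gen.py \
+#       --pretrained_model_name_or_path /models/stable-diffusion \
+#       --instance_data_dir ./uploads/task_1/instance \
+#       --class_data_dir ./uploads/task_1/class \
+#       --output_dir ./outputs/task_1/model \
+#       --validation_image_output_dir ./outputs/task_1/images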
+
+def create_generated_image(source_image_path, output_path, index):
+    """Create a mock generated image (watermarked to show it was virtually generated)."""
+    with Image.open(source_image_path) as img:
+        # Copy the original image
+        generated = img.copy()
+        draw = ImageDraw.Draw(generated)
+
+        # Add the watermark text
+        width, height = generated.size
+        text = f"Virtual Generated #{index}"
+
+        # Simply draw the text onto the image
+        position = (10, height - 30)
+        draw.text(position, text, fill=(255, 255, 255))
+
+        generated.save(output_path, quality=95)
+
+def main():
+    parser = argparse.ArgumentParser(description="Virtual DreamBooth fine-tuning script")
+    parser.add_argument('--pretrained_model_name_or_path', required=True)
+    parser.add_argument('--instance_data_dir', required=True)
+    parser.add_argument('--class_data_dir', required=True)
+    parser.add_argument('--output_dir', required=True)
+    parser.add_argument('--validation_image_output_dir', required=True)
+    parser.add_argument('--with_prior_preservation', action='store_true')
+    parser.add_argument('--prior_loss_weight', type=float, default=1.0)
+    parser.add_argument('--instance_prompt', default='a photo of sks person')
+    parser.add_argument('--class_prompt', default='a photo of person')
+    parser.add_argument('--resolution', type=int, default=512)
+    parser.add_argument('--train_batch_size', type=int, default=1)
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
+    parser.add_argument('--learning_rate', type=float, default=2e-6)
+    parser.add_argument('--lr_scheduler', default='constant')
+    parser.add_argument('--lr_warmup_steps', type=int, default=0)
+    parser.add_argument('--num_class_images', type=int, default=200)
+    parser.add_argument('--max_train_steps', type=int, default=1000)
+    parser.add_argument('--checkpointing_steps', type=int, default=500)
+    parser.add_argument('--center_crop', action='store_true')
+    parser.add_argument('--mixed_precision', default='bf16')
+    parser.add_argument('--prior_generation_precision', default='bf16')
+    parser.add_argument('--sample_batch_size', type=int, default=5)
+    parser.add_argument('--validation_prompt', default='a photo of sks person')
+    parser.add_argument('--num_validation_images', type=int, default=10)
+    parser.add_argument('--validation_steps', type=int, default=500)
+
+    args = parser.parse_args()
+
+    print("=" * 80)
+    print("[VIRTUAL DREAMBOOTH] Virtual fine-tuning started")
+    print("=" * 80)
+    print("[VIRTUAL] Fine-tuning method: DreamBooth")
+    print(f"[VIRTUAL] Current conda environment: {os.environ.get('CONDA_DEFAULT_ENV', 'not detected')}")
+    print(f"[VIRTUAL] Python executable: {sys.executable}")
+    print(f"[VIRTUAL] Python version: {platform.python_version()}")
+    print(f"[VIRTUAL] Platform: {platform.system()} {platform.release()}")
+    print("-" * 80)
+    print("[VIRTUAL] Fine-tuning parameters:")
+    print(f"  - Model path: {args.pretrained_model_name_or_path}")
+    print(f"  - Instance data directory: {args.instance_data_dir}")
+    print(f"  - Model output directory: {args.output_dir}")
+    print(f"  - Validation image output directory: {args.validation_image_output_dir}")
+    print(f"  - Class data directory: {args.class_data_dir}")
+    print(f"  - Resolution: {args.resolution}")
+    print(f"  - Max training steps: {args.max_train_steps}")
+    print(f"  - Learning rate: {args.learning_rate}")
+    print(f"  - Number of validation images: {args.num_validation_images}")
+    print("-" * 80)
+
+    # Make sure the output directories exist
+    os.makedirs(args.output_dir, exist_ok=True)
+    os.makedirs(args.validation_image_output_dir, exist_ok=True)
+
+    # Collect the training images
+    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp']
+    image_files = []
+    for ext in image_extensions:
+        image_files.extend(glob.glob(os.path.join(args.instance_data_dir, ext)))
+        image_files.extend(glob.glob(os.path.join(args.instance_data_dir, ext.upper())))
+
+    print(f"[VIRTUAL] Found {len(image_files)} training images")
+
+    # Simulate the training process
+    print("-" * 80)
+    print("[VIRTUAL] Starting simulated training...")
+    for step in range(0, args.max_train_steps, args.checkpointing_steps):
+        print(f"[VIRTUAL] Training step: {step}/{args.max_train_steps}")
+
+    print(f"[VIRTUAL] Training step: {args.max_train_steps}/{args.max_train_steps} (done)")
i) + generated_count += 1 + print(f"[VIRTUAL] 生成图片 {generated_count}/{args.num_validation_images}: {filename}") + except Exception as e: + print(f"[VIRTUAL] 生成图片失败: {e}") + + # 保存模型文件标记 + model_marker = os.path.join(args.output_dir, "virtual_model.txt") + with open(model_marker, 'w') as f: + f.write("This is a virtual DreamBooth model marker.\n") + f.write(f"Training images: {len(image_files)}\n") + f.write(f"Max steps: {args.max_train_steps}\n") + f.write(f"Generated images: {generated_count}\n") + + print("-" * 80) + print(f"[VIRTUAL] 成功生成 {generated_count} 张图片") + print(f"[VIRTUAL] 模型保存到: {args.output_dir}") + print("[VIRTUAL] 虚拟微调执行完成") + print("=" * 80) + +if __name__ == "__main__": + main() diff --git a/src/backend/app/algorithms/finetune_virtual/train_lora_gen.py b/src/backend/app/algorithms/finetune_virtual/train_lora_gen.py new file mode 100644 index 0000000..199adc6 --- /dev/null +++ b/src/backend/app/algorithms/finetune_virtual/train_lora_gen.py @@ -0,0 +1,134 @@ +""" +LoRA微调虚拟实现 +用于测试后端流程,不执行实际的模型训练 +""" +import argparse +import os +import sys +import platform +import shutil +import glob +from PIL import Image, ImageDraw, ImageFont + +def create_generated_image(source_image_path, output_path, index): + """创建一个模拟生成的图片(添加水印表示是虚拟生成的)""" + with Image.open(source_image_path) as img: + # 复制原图 + generated = img.copy() + draw = ImageDraw.Draw(generated) + + # 添加水印文字 + width, height = generated.size + text = f"Virtual LoRA #{index}" + + # 简单在图片上绘制文字 + position = (10, height - 30) + draw.text(position, text, fill=(255, 255, 255)) + + generated.save(output_path, quality=95) + +def main(): + parser = argparse.ArgumentParser(description="LoRA虚拟微调脚本") + parser.add_argument('--pretrained_model_name_or_path', required=True) + parser.add_argument('--instance_data_dir', required=True) + parser.add_argument('--class_data_dir', required=True) + parser.add_argument('--output_dir', required=True) + parser.add_argument('--validation_image_output_dir', required=True) + parser.add_argument('--with_prior_preservation', action='store_true') + parser.add_argument('--prior_loss_weight', type=float, default=1.0) + parser.add_argument('--instance_prompt', default='a photo of sks person') + parser.add_argument('--class_prompt', default='a photo of person') + parser.add_argument('--resolution', type=int, default=512) + parser.add_argument('--train_batch_size', type=int, default=1) + parser.add_argument('--gradient_accumulation_steps', type=int, default=1) + parser.add_argument('--learning_rate', type=float, default=1e-4) + parser.add_argument('--lr_scheduler', default='constant') + parser.add_argument('--lr_warmup_steps', type=int, default=0) + parser.add_argument('--num_class_images', type=int, default=200) + parser.add_argument('--max_train_steps', type=int, default=1000) + parser.add_argument('--checkpointing_steps', type=int, default=500) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--mixed_precision', default='fp16') + parser.add_argument('--rank', type=int, default=4) + parser.add_argument('--validation_prompt', default='a photo of sks person') + parser.add_argument('--num_validation_images', type=int, default=10) + + args = parser.parse_args() + + print("=" * 80) + print("[VIRTUAL LORA] 虚拟微调执行开始") + print("=" * 80) + print(f"[VIRTUAL] 微调方法: LoRA (Low-Rank Adaptation)") + print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}") + print(f"[VIRTUAL] Python可执行文件: {sys.executable}") + print(f"[VIRTUAL] Python版本: {platform.python_version()}") + 
print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}") + print("-" * 80) + print("[VIRTUAL] 微调参数:") + print(f" - 模型路径: {args.pretrained_model_name_or_path}") + print(f" - 实例数据目录: {args.instance_data_dir}") + print(f" - 模型输出目录: {args.output_dir}") + print(f" - 验证图片输出目录: {args.validation_image_output_dir}") + print(f" - 类别目录: {args.class_data_dir}") + print(f" - 分辨率: {args.resolution}") + print(f" - LoRA rank: {args.rank}") + print(f" - 最大训练步数: {args.max_train_steps}") + print(f" - 学习率: {args.learning_rate}") + print(f" - 验证图片数量: {args.num_validation_images}") + print("-" * 80) + + # 确保输出目录存在 + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(args.validation_image_output_dir, exist_ok=True) + + # 获取训练图片 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + image_files = [] + for ext in image_extensions: + image_files.extend(glob.glob(os.path.join(args.instance_data_dir, ext))) + image_files.extend(glob.glob(os.path.join(args.instance_data_dir, ext.upper()))) + + print(f"[VIRTUAL] 找到 {len(image_files)} 张训练图片") + + # 模拟训练过程 + print("-" * 80) + print("[VIRTUAL] 开始模拟LoRA训练...") + for step in range(0, args.max_train_steps, args.checkpointing_steps): + print(f"[VIRTUAL] 训练步数: {step}/{args.max_train_steps}") + + print(f"[VIRTUAL] 训练步数: {args.max_train_steps}/{args.max_train_steps} (完成)") + + # 生成验证图片(从训练图片复制并添加标记) + print("-" * 80) + print(f"[VIRTUAL] 生成 {args.num_validation_images} 张验证图片...") + + generated_count = 0 + for i in range(min(args.num_validation_images, len(image_files))): + source_image = image_files[i % len(image_files)] + filename = f"generated_{i:04d}.png" + output_path = os.path.join(args.validation_image_output_dir, filename) + + try: + create_generated_image(source_image, output_path, i) + generated_count += 1 + print(f"[VIRTUAL] 生成图片 {generated_count}/{args.num_validation_images}: {filename}") + except Exception as e: + print(f"[VIRTUAL] 生成图片失败: {e}") + + # 保存LoRA模型文件标记 + model_marker = os.path.join(args.output_dir, "virtual_lora_model.txt") + with open(model_marker, 'w') as f: + f.write("This is a virtual LoRA model marker.\n") + f.write(f"Training images: {len(image_files)}\n") + f.write(f"LoRA rank: {args.rank}\n") + f.write(f"Max steps: {args.max_train_steps}\n") + f.write(f"Generated images: {generated_count}\n") + + print("-" * 80) + print(f"[VIRTUAL] 成功生成 {generated_count} 张图片") + print(f"[VIRTUAL] LoRA模型保存到: {args.output_dir}") + print("[VIRTUAL] 虚拟微调执行完成") + print("=" * 80) + +if __name__ == "__main__": + main() diff --git a/src/backend/app/algorithms/perturbation/aspl.py b/src/backend/app/algorithms/perturbation/aspl.py new file mode 100644 index 0000000..6f26194 --- /dev/null +++ b/src/backend/app/algorithms/perturbation/aspl.py @@ -0,0 +1,770 @@ +import argparse +import copy +import hashlib +import itertools +import logging +import os +from pathlib import Path + +import datasets +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.utils.import_utils import is_xformers_available +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + + +logger = get_logger(__name__) + + +class 
DreamBoothDatasetFromTensor(Dataset): + """Just like DreamBoothDataset, but take instance_images_tensor instead of path""" + + def __init__( + self, + instance_images_tensor, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_images_tensor = instance_images_tensor + self.num_instance_images = len(self.instance_images_tensor) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.instance_images_tensor[index % self.num_instance_images] + example["instance_images"] = instance_image + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." 
+ ), + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir_for_train", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--instance_data_dir_for_adversarial", + type=str, + default=None, + required=True, + help="A folder containing the images to add adversarial noise", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. 
If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=8, + help="Batch size (per device) for sampling images.", + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=20, + help="Total number of training steps to perform.", + ) + parser.add_argument( + "--max_f_train_steps", + type=int, + default=10, + help="Total number of sub-steps to train surogate model.", + ) + parser.add_argument( + "--max_adv_train_steps", + type=int, + default=10, + help="Total number of sub-steps to train adversarial noise.", + ) + parser.add_argument( + "--checkpointing_iterations", + type=int, + default=5, + help=("Save a checkpoint of the training state every X iterations."), + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="fp16", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", + ) + parser.add_argument( + "--pgd_alpha", + type=float, + default=1.0 / 255, + help="The step size for pgd.", + ) + parser.add_argument( + "--pgd_eps", + type=int, + default=0.05, + help="The noise budget for pgd.", + ) + parser.add_argument( + "--target_image_path", + default=None, + help="target image for attacking", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + return args + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
+
+    def __init__(self, prompt, num_samples):
+        self.prompt = prompt
+        self.num_samples = num_samples
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, index):
+        example = {}
+        example["prompt"] = self.prompt
+        example["index"] = index
+        return example
+
+
+def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
+    image_transforms = transforms.Compose(
+        [
+            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )
+
+    images = [image_transforms(Image.open(i).convert("RGB")) for i in list(Path(data_dir).iterdir())]
+    images = torch.stack(images)
+    return images
+
+
+def train_one_epoch(
+    args,
+    models,
+    tokenizer,
+    noise_scheduler,
+    vae,
+    data_tensor: torch.Tensor,
+    num_steps=20,
+):
+    # Fine-tune fresh copies of the models so the caller's weights stay untouched.
+    unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
+    # Materialize the parameter list: a bare itertools.chain would be exhausted by the
+    # optimizer below and leave nothing for gradient clipping.
+    params_to_optimize = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+
+    optimizer = torch.optim.AdamW(
+        params_to_optimize,
+        lr=args.learning_rate,
+        betas=(0.9, 0.999),
+        weight_decay=1e-2,
+        eps=1e-08,
+    )
+
+    train_dataset = DreamBoothDatasetFromTensor(
+        data_tensor,
+        args.instance_prompt,
+        tokenizer,
+        args.class_data_dir,
+        args.class_prompt,
+        args.resolution,
+        args.center_crop,
+    )
+
+    weight_dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    vae.to(device, dtype=weight_dtype)
+    text_encoder.to(device, dtype=weight_dtype)
+    unet.to(device, dtype=weight_dtype)
+
+    for step in range(num_steps):
+        unet.train()
+        text_encoder.train()
+
+        step_data = train_dataset[step % len(train_dataset)]
+        pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
+            device, dtype=weight_dtype
+        )
+        input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
+
+        latents = vae.encode(pixel_values).latent_dist.sample()
+        latents = latents * vae.config.scaling_factor
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = text_encoder(input_ids)[0]
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        # with prior preservation loss
+        if args.with_prior_preservation:
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = instance_loss + args.prior_loss_weight * prior_loss
+
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
+        optimizer.step()
+        optimizer.zero_grad()
+        # Only the prior-preservation branch defines prior_loss/instance_loss.
+        if args.with_prior_preservation:
+            print(
+                f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, instance_loss: {instance_loss.detach().item()}"
+            )
+        else:
+            print(f"Step #{step}, loss: {loss.detach().item()}")
+
+    return [unet, text_encoder]
+
+
+def pgd_attack(
+    args,
+    models,
+    tokenizer,
+    noise_scheduler,
+    vae,
+    data_tensor: torch.Tensor,
+    original_images: torch.Tensor,
+    target_tensor: torch.Tensor,
+    num_steps: int,
+):
+    """Return new perturbed data"""
+
+    unet, text_encoder = models
+    weight_dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    vae.to(device, dtype=weight_dtype)
+    text_encoder.to(device, dtype=weight_dtype)
+    unet.to(device, dtype=weight_dtype)
+
+    perturbed_images = data_tensor.detach().clone()
+    perturbed_images.requires_grad_(True)
+
+    input_ids = tokenizer(
+        args.instance_prompt,
+        truncation=True,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        return_tensors="pt",
+    ).input_ids.repeat(len(data_tensor), 1)
+
+    for step in range(num_steps):
+        perturbed_images.requires_grad = True
+        latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
+        latents = latents * vae.config.scaling_factor
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = text_encoder(input_ids.to(device))[0]
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        unet.zero_grad()
+        text_encoder.zero_grad()
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        # target-shift loss
+        if target_tensor is not None:
+            xtm1_pred = torch.cat(
+                [
+                    noise_scheduler.step(
+                        model_pred[idx : idx + 1],
+                        timesteps[idx : idx + 1],
+                        noisy_latents[idx : idx + 1],
+                    ).prev_sample
+                    for idx in range(len(model_pred))
+                ]
+            )
+            xtm1_target = noise_scheduler.add_noise(target_tensor, noise, timesteps - 1)
+            loss = loss - F.mse_loss(xtm1_pred, xtm1_target)
+
+        loss.backward()
+
+        alpha = args.pgd_alpha
+        eps = args.pgd_eps / 255
+
+        adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
+        eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
+        perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
+        print(f"PGD loss - step {step}, loss: {loss.detach().item()}")
+    return perturbed_images
+
+
+def main(args):
+    logging_dir = Path(args.output_dir, args.logging_dir)
+
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        logging_dir=logging_dir,
+    )
+
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    # Generate class images if prior preservation is enabled.
+    if args.with_prior_preservation:
+        class_images_dir = Path(args.class_data_dir)
+        if not class_images_dir.exists():
+            class_images_dir.mkdir(parents=True)
+        cur_class_images = len(list(class_images_dir.iterdir()))
+        if cur_class_images < args.num_class_images:
+            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            # "no" disables mixed precision, so generate the prior images in full fp32.
+            if args.mixed_precision == "no":
+                torch_dtype = torch.float32
+            elif args.mixed_precision == "fp16":
+                torch_dtype = torch.float16
+            elif args.mixed_precision == "bf16":
+                torch_dtype = torch.bfloat16
+            pipeline = DiffusionPipeline.from_pretrained(
+                args.pretrained_model_name_or_path,
+                torch_dtype=torch_dtype,
+                safety_checker=None,
+                revision=args.revision,
+            )
+            pipeline.set_progress_bar_config(disable=True)
+
+            num_new_images = args.num_class_images - cur_class_images
+            logger.info(f"Number of class images to sample: {num_new_images}.")
+
+            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+            sample_dataloader = accelerator.prepare(sample_dataloader)
+            pipeline.to(accelerator.device)
+
+            for example in tqdm(
+                sample_dataloader,
+                desc="Generating class images",
+                disable=not accelerator.is_local_main_process,
+            ):
+                images = pipeline(example["prompt"]).images
+
+                for i, image in enumerate(images):
+                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+                    image.save(image_filename)
+
+            del pipeline
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+    # import correct text encoder class
+    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+    # Load scheduler and models
+    text_encoder = text_encoder_cls.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        use_fast=False,
+    )
+
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+    ).cuda()
+
+    vae.requires_grad_(False)
+
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
+
+    if args.allow_tf32:
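+        # TF32 runs fp32 matmuls on tensor cores (Ampere or newer GPUs); it speeds up
+        # the full-precision paths here at a small, usually negligible, accuracy cost.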
torch.backends.cuda.matmul.allow_tf32 = True + + clean_data = load_data( + args.instance_data_dir_for_train, + size=args.resolution, + center_crop=args.center_crop, + ) + perturbed_data = load_data( + args.instance_data_dir_for_adversarial, + size=args.resolution, + center_crop=args.center_crop, + ) + original_data = perturbed_data.clone() + original_data.requires_grad_(False) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + target_latent_tensor = None + if args.target_image_path is not None: + target_image_path = Path(args.target_image_path) + assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist" + + target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution)) + target_image = np.array(target_image)[None].transpose(0, 3, 1, 2) + + target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0 + target_latent_tensor = ( + vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor + ) + target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda() + + f = [unet, text_encoder] + for i in range(args.max_train_steps): + # 1. f' = f.clone() + f_sur = copy.deepcopy(f) + f_sur = train_one_epoch( + args, + f_sur, + tokenizer, + noise_scheduler, + vae, + clean_data, + args.max_f_train_steps, + ) + perturbed_data = pgd_attack( + args, + f_sur, + tokenizer, + noise_scheduler, + vae, + perturbed_data, + original_data, + target_latent_tensor, + args.max_adv_train_steps, + ) + f = train_one_epoch( + args, + f, + tokenizer, + noise_scheduler, + vae, + perturbed_data, + args.max_f_train_steps, + ) + + if (i + 1) % args.checkpointing_iterations == 0: + save_folder = args.output_dir + os.makedirs(save_folder, exist_ok=True) + noised_imgs = perturbed_data.detach() + + img_filenames = [ + Path(instance_path).stem + for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir()) + ] + + for img_pixel, img_name in zip(noised_imgs, img_filenames): + save_path = os.path.join(save_folder, f"perturbed_{img_name}.png") + Image.fromarray( + (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy() + ).save(save_path) + + print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)") + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/src/backend/app/algorithms/perturbation/caat.py b/src/backend/app/algorithms/perturbation/caat.py new file mode 100644 index 0000000..c7e41cd --- /dev/null +++ b/src/backend/app/algorithms/perturbation/caat.py @@ -0,0 +1,972 @@ +import argparse +import hashlib +import itertools +import json +import logging +import os +import random +import warnings +import shutil +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from packaging import version +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers 
import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DiffusionPipeline,
+    UNet2DConditionModel,
+)
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.utils.import_utils import is_xformers_available
+
+
+logger = get_logger(__name__)
+
+
+def freeze_params(params):
+    for param in params:
+        param.requires_grad = False
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+    text_encoder_config = PretrainedConfig.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=revision,
+    )
+    model_class = text_encoder_config.architectures[0]
+
+    if model_class == "CLIPTextModel":
+        from transformers import CLIPTextModel
+
+        return CLIPTextModel
+    elif model_class == "RobertaSeriesModelWithTransformation":
+        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+        return RobertaSeriesModelWithTransformation
+    else:
+        raise ValueError(f"{model_class} is not supported.")
+
+class PromptDataset(Dataset):
+    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+    def __init__(self, prompt, num_samples):
+        self.prompt = prompt
+        self.num_samples = num_samples
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, index):
+        example = {}
+        example["prompt"] = self.prompt
+        example["index"] = index
+        return example
+
+
+class CustomDiffusionDataset(Dataset):
+    """
+    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+    """
+
+    def __init__(
+        self,
+        concepts_list,
+        tokenizer,
+        size=512,
+        mask_size=64,
+        center_crop=False,
+        with_prior_preservation=False,
+        num_class_images=200,
+        hflip=False,
+        aug=True,
+    ):
+        self.size = size
+        self.mask_size = mask_size
+        self.center_crop = center_crop
+        self.tokenizer = tokenizer
+        self.interpolation = Image.BILINEAR
+        self.aug = aug
+
+        self.instance_images_path = []
+        self.class_images_path = []
+        self.with_prior_preservation = with_prior_preservation
+        for concept in concepts_list:
+            inst_img_path = [
+                (x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()
+            ]
+            self.instance_images_path.extend(inst_img_path)
+
+            if with_prior_preservation:
+                class_data_root = Path(concept["class_data_dir"])
+                if os.path.isdir(class_data_root):
+                    class_images_path = list(class_data_root.iterdir())
+                    class_prompt = [concept["class_prompt"] for _ in range(len(class_images_path))]
+                else:
+                    with open(class_data_root, "r") as f:
+                        class_images_path = f.read().splitlines()
+                    with open(concept["class_prompt"], "r") as f:
+                        class_prompt = f.read().splitlines()
+
+                class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)]
+                self.class_images_path.extend(class_img_path[:num_class_images])
+
+        random.shuffle(self.instance_images_path)
+        self.num_instance_images = len(self.instance_images_path)
+        self.num_class_images = len(self.class_images_path)
+        self._length = max(self.num_class_images, self.num_instance_images)
+        self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
+
+        self.image_transforms = transforms.Compose(
+            [
+                self.flip,
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else
transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def preprocess(self, image, scale, resample): + outer, inner = self.size, scale + factor = self.size // self.mask_size + if scale > self.size: + outer, inner = scale, self.size + top, left = np.random.randint(0, outer - inner + 1), np.random.randint(0, outer - inner + 1) + image = image.resize((scale, scale), resample=resample) + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32) + mask = np.zeros((self.size // factor, self.size // factor)) + if scale > self.size: + instance_image = image[top : top + inner, left : left + inner, :] + mask = np.ones((self.size // factor, self.size // factor)) + else: + instance_image[top : top + inner, left : left + inner, :] = image + mask[ + top // factor + 1 : (top + scale) // factor - 1, left // factor + 1 : (left + scale) // factor - 1 + ] = 1.0 + return instance_image, mask + + def __getitem__(self, index): + example = {} + instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images] + instance_image = Image.open(instance_image) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + instance_image = self.flip(instance_image) + + # apply resize augmentation and create a valid image region mask + random_scale = self.size + if self.aug: + random_scale = ( + np.random.randint(self.size // 3, self.size + 1) + if np.random.uniform() < 0.66 + else np.random.randint(int(1.2 * self.size), int(1.4 * self.size)) + ) + instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation) + + if random_scale < 0.6 * self.size: + instance_prompt = np.random.choice(["a far away ", "very small "]) + instance_prompt + elif random_scale > self.size: + instance_prompt = np.random.choice(["zoomed in ", "close up "]) + instance_prompt + + example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1) + example["mask"] = torch.from_numpy(mask) + example["instance_prompt_ids"] = self.tokenizer( + instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.with_prior_preservation: + class_image, class_prompt = self.class_images_path[index % self.num_class_images] + class_image = Image.open(class_image) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_mask"] = torch.ones_like(example["mask"]) + example["class_prompt_ids"] = self.tokenizer( + class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="CAAT training script.") + parser.add_argument( + "--alpha", + type=float, + default=5e-3, + required=True, + help="PGD alpha.", + ) + parser.add_argument( + "--eps", + type=float, + default=0.1, + required=True, + help="PGD eps.", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained 
model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss." + ) + parser.add_argument( + "--num_class_images", + type=int, + default=200, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="outputs", + help="The output directory.", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="A seed for reproducible training." + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=250, + help="Total number of training steps to perform.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=250, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=2, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
+        ),
+    )
+    parser.add_argument(
+        "--freeze_model",
+        type=str,
+        default="crossattn_kv",
+        choices=["crossattn_kv", "crossattn"],
+        help="Pass `crossattn` to fine-tune all parameters in the cross-attention layers (the default trains only the key/value projections).",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+    parser.add_argument(
+        "--prior_generation_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp32", "fp16", "bf16"],
+        help=(
+            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+        ),
+    )
+    parser.add_argument(
+        "--concepts_list",
+        type=str,
+        default=None,
+        help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
+    parser.add_argument(
+        "--set_grads_to_none",
+        action="store_true",
+        help=(
+            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+            " behaviors, so disable this argument if it causes any problems. More info:"
+            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+        ),
+    )
+    parser.add_argument(
+        "--initializer_token", type=str, default="ktn+pll+ucd", help="A token to use as initializer word."
+    )
+    parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
+    parser.add_argument(
+        "--noaug",
+        action="store_true",
+        help="Don't apply data augmentation when this flag is enabled.",
+    )
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.with_prior_preservation:
+        if args.concepts_list is None:
+            if args.class_data_dir is None:
+                raise ValueError("You must specify a data directory for class images.")
+            if args.class_prompt is None:
+                raise ValueError("You must specify a prompt for class images.")
+    else:
+        # logger is not available yet
+        if args.class_data_dir is not None:
+            warnings.warn("You do not need --class_data_dir without --with_prior_preservation.")
+        if args.class_prompt is not None:
+            warnings.warn("You do not need --class_prompt without --with_prior_preservation.")
+
+    return args
+
+
+def main(args):
+    logging_dir = Path(args.output_dir, args.logging_dir)
+
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        project_config=accelerator_project_config,
+    )
+
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    accelerator.init_trackers("CAAT", config=vars(args))
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+    if args.concepts_list is None:
+        args.concepts_list = [
+            {
+                "instance_prompt": args.instance_prompt,
+                "class_prompt": args.class_prompt,
+                "instance_data_dir": args.instance_data_dir,
+                "class_data_dir": args.class_data_dir,
+            }
+        ]
+    else:
+        with open(args.concepts_list, "r") as f:
+            args.concepts_list = json.load(f)
+
+    # Generate class images if prior preservation is enabled.
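+    # Descriptive note: the block below only tops up class_data_dir when it holds fewer
+    # than --num_class_images files; images already present are counted and reused.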
+ if args.with_prior_preservation: + for i, concept in enumerate(args.concepts_list): + class_images_dir = Path(concept["class_data_dir"]) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True, exist_ok=True) + + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = ( + class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. 
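+    # Descriptive note: resolve the torch dtype once here; the frozen unet/vae/text_encoder
+    # are cast to it below because they are used for inference only.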
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + attention_class = CustomDiffusionAttnProcessor + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + attention_class = CustomDiffusionXFormersAttnProcessor + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # now we will add new Custom Diffusion weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. + # For Stable Diffusion, it should be equal to: + # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 + # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 + # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 + # => 32 layers + + # Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer + train_kv = True + train_q_out = False if args.freeze_model == "crossattn_kv" else True + custom_diffusion_attn_procs = {} + + st = unet.state_dict() + + for name, _ in unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + layer_name = name.split(".processor")[0] + weights = { + "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"], + "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"], + } + if train_q_out: + weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"] + weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"] + weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"] + if cross_attention_dim is not None: + custom_diffusion_attn_procs[name] = attention_class( + train_kv=train_kv, + train_q_out=train_q_out, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ).to(unet.device) + custom_diffusion_attn_procs[name].load_state_dict(weights) + else: + custom_diffusion_attn_procs[name] = attention_class( + train_kv=False, + 
train_q_out=False,
+                hidden_size=hidden_size,
+                cross_attention_dim=cross_attention_dim,
+            )
+    del st
+
+
+    unet.set_attn_processor(custom_diffusion_attn_procs)
+    custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)
+
+    accelerator.register_for_checkpointing(custom_diffusion_layers)
+
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.with_prior_preservation:
+        args.learning_rate = args.learning_rate * 2.0
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    # Optimizer creation
+    optimizer = optimizer_class(
+        custom_diffusion_layers.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+    )
+
+    # Dataset creation:
+    train_dataset = CustomDiffusionDataset(
+        concepts_list=args.concepts_list,
+        tokenizer=tokenizer,
+        with_prior_preservation=args.with_prior_preservation,
+        size=args.resolution,
+        mask_size=vae.encode(
+            torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device)
+        )
+        .latent_dist.sample()
+        .size()[-1],
+        center_crop=args.center_crop,
+        num_class_images=args.num_class_images,
+        hflip=args.hflip,
+        aug=not args.noaug,
+    )
+
+
+    # Prepare for PGD
+    perturbed_images = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path]
+    perturbed_images = [train_dataset.image_transforms(i) for i in perturbed_images]
+    perturbed_images = torch.stack(perturbed_images).contiguous()
+    perturbed_images.requires_grad_()
+
+    original_images = perturbed_images.clone().detach()
+    original_images.requires_grad_(False)
+
+    input_ids = train_dataset.tokenizer(
+        args.instance_prompt,
+        truncation=True,
+        padding="max_length",
+        max_length=train_dataset.tokenizer.model_max_length,
+        return_tensors="pt",
+    ).input_ids.repeat(len(original_images), 1)
+
+    def get_one_mask(image):
+        random_scale = train_dataset.size
+        if train_dataset.aug:
+            random_scale = (
+                np.random.randint(train_dataset.size // 3, train_dataset.size + 1)
+                if np.random.uniform() < 0.66
+                else np.random.randint(int(1.2 * train_dataset.size), int(1.4 * train_dataset.size))
+            )
+        _, one_mask = train_dataset.preprocess(image, random_scale, train_dataset.interpolation)
+        one_mask = torch.from_numpy(one_mask)
+        if args.with_prior_preservation:
+            class_mask = torch.ones_like(one_mask)
+            one_mask += class_mask
+        return one_mask
+
+    images_open_list = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path]
+    mask_list = []
+    for image in images_open_list:
+        mask_list.append(get_one_mask(image))
+
+    mask = torch.stack(mask_list)
+    mask = mask.to(memory_format=torch.contiguous_format).float()
+    mask = mask.unsqueeze(1)
+    del images_open_list
+
+
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+    )
+
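+    # Descriptive note: accelerate only wraps the modules, optimizer, and scheduler in
+    # prepare(); plain tensors such as the perturbed images and the mask are, as far as
+    # accelerate's type dispatch goes, returned unchanged and simply ride along here.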
+    custom_diffusion_layers, optimizer, perturbed_images, lr_scheduler, original_images, mask = accelerator.prepare(
+        custom_diffusion_layers, optimizer, perturbed_images, lr_scheduler, original_images, mask
+    )
+
+    # Train!
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num perturbed_images = {len(perturbed_images)}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Steps")
+    for epoch in range(first_epoch, args.max_train_steps):
+        unet.train()
+        for _ in range(1):
+            with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
+                # Convert images to latent space
+                perturbed_images.requires_grad = True
+                latents = vae.encode(perturbed_images.to(accelerator.device).to(dtype=weight_dtype)).latent_dist.sample()
+                latents = latents * vae.config.scaling_factor
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn_like(latents)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Get the text embedding for conditioning
+                encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0]
+
+                # Predict the noise residual
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                if args.with_prior_preservation:
+                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                    target, target_prior = torch.chunk(target, 2, dim=0)
+                    mask = torch.chunk(mask, 2, dim=0)[0].to(accelerator.device)
+                    # Compute instance loss
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                    loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
+
+                    # Compute prior loss
+                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+                    # Add the prior loss to the instance loss.
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    mask = mask.to(accelerator.device)
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")  # torch.Size([5, 4, 64, 64])
+                    loss = loss.mean()
+
+                accelerator.backward(loss)
+
+                if accelerator.sync_gradients:
+                    params_to_clip = custom_diffusion_layers.parameters()
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                alpha = args.alpha
+                eps = args.eps
+                adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
+                eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
+                perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
+
+                optimizer.step()
+
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+    if accelerator.is_main_process:
+        logger.info("***** Final save of perturbed images *****")
+        save_folder = args.output_dir
+
+        noised_imgs = perturbed_images.detach().cpu()
+
+        img_names = [
+            str(instance_path[0]).split("/")[-1] for instance_path in train_dataset.instance_images_path
+        ]
+
+        num_images_to_save = len(img_names)
+
+        for i in range(num_images_to_save):
+            img_pixel = noised_imgs[i]
+            img_name = img_names[i]
+            save_path = os.path.join(save_folder, f"final_noise_{img_name}")
+
+            # Convert the tensor back to an 8-bit image and save it
+            Image.fromarray(
+                (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
+            ).save(save_path)
+
+        logger.info(f"Saved {num_images_to_save} final perturbed images to {save_folder}")
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
+    print("<-------end-------->")
diff --git a/src/backend/app/algorithms/perturbation/pid.py b/src/backend/app/algorithms/perturbation/pid.py
new file mode 100644
index 0000000..e4e35fc
--- /dev/null
+++ b/src/backend/app/algorithms/perturbation/pid.py
@@ -0,0 +1,272 @@
+import argparse
+import os
+from pathlib import Path
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from diffusers import AutoencoderKL
+
+
+def parse_args(input_args=None):
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help=(
+            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+            " float32 precision."
+        ),
+    )
+    parser.add_argument(
+        "--instance_data_dir",
+        type=str,
+        default=None,
+        required=True,
+        help="A folder containing the training data of instance images.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="text-inversion-model",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=512,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--center_crop",
+        default=False,
+        action="store_true",
+        help=(
+            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+            " cropped. The images will be resized to the resolution first before cropping."
+        ),
+    )
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of updating steps",
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
+    parser.add_argument(
+        '--eps',
+        type=float,
+        default=12.75,
+        help='perturbation budget'
+    )
+    parser.add_argument(
+        '--step_size',
+        type=float,
+        default=1/255,
+        help='step size of each PGD update'
+    )
+    parser.add_argument(
+        '--attack_type',
+        choices=['var', 'mean', 'KL', 'add-log', 'latent_vector', 'add'],
+        help='which statistic of the VAE latent distribution to attack'
+    )
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    return args
+
+
+class PIDDataset(Dataset):
+    """
+    A dataset that loads and pre-processes the instance images to be perturbed.
+    Unlike the DreamBooth-style datasets, it uses no prompts or class images.
+ """ + + def __init__( + self, + instance_data_root, + size=512, + center_crop=False + ): + self.size = size + self.center_crop = center_crop + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(),]) + + def __len__(self): + return self.num_instance_images + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + + example['index'] = index % self.num_instance_images + example['pixel_values'] = self.image_transforms(instance_image) + return example + + +def main(args): + # Set random seed + if args.seed is not None: + torch.manual_seed(args.seed) + weight_dtype = torch.float32 + device = torch.device('cuda') + + # VAE encoder + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + vae.requires_grad_(False) + vae.to(device, dtype=weight_dtype) + + # Dataset and DataLoaders creation: + dataset = PIDDataset( + instance_data_root=args.instance_data_dir, + size=args.resolution, + center_crop=args.center_crop, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=1, # some parts of code don't support batching + shuffle=True, + num_workers=args.dataloader_num_workers, + ) + + # Wrapper of the perturbations generator + class AttackModel(torch.nn.Module): + def __init__(self): + super().__init__() + to_tensor = transforms.ToTensor() + self.epsilon = args.eps/255 + self.delta = [torch.empty_like(to_tensor(Image.open(path))).uniform_(-self.epsilon, self.epsilon) + for path in dataset.instance_images_path] + self.size = dataset.size + + def forward(self, vae, x, index, poison=False): + # Check whether we need to add perturbation + if poison: + self.delta[index].requires_grad_(True) + x = x + self.delta[index].to(dtype=weight_dtype) + + # Normalize to [-1, 1] + input_x = 2 * x - 1 + return vae.encode(input_x.to(device)) + + attackmodel = AttackModel() + + # Just to zero-out the gradient + optimizer = torch.optim.SGD(attackmodel.delta, lr=0) + + # Progress bar + progress_bar = tqdm(range(0, args.max_train_steps), desc="Steps") + + # Make sure the dir exists + os.makedirs(args.output_dir, exist_ok=True) + + # Start optimizing the perturbation + for step in progress_bar: + + total_loss = 0.0 + for batch in dataloader: + # Save images + if step%25 == 0: + to_image = transforms.ToPILImage() + for i in range(0, len(dataset.instance_images_path)): + img = dataset[i]['pixel_values'] + img = to_image(img + attackmodel.delta[i]) + img.save(os.path.join(args.output_dir, f"{i}.png")) + + + # Select target loss + clean_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], False) + poison_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], True) + clean_latent = clean_embedding.latent_dist + poison_latent = poison_embedding.latent_dist + + if args.attack_type == 'var': + loss = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean") + elif args.attack_type == 'mean': + loss = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean") + elif args.attack_type == 'KL': + 
+                sigma_2, mu_2 = poison_latent.std, poison_latent.mean
+                sigma_1, mu_1 = clean_latent.std, clean_latent.mean
+                KL_diver = torch.log(sigma_2 / sigma_1) - 0.5 + (sigma_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * sigma_2 ** 2)
+                loss = KL_diver.flatten().mean()
+            elif args.attack_type == 'latent_vector':
+                clean_vector = clean_latent.sample()
+                poison_vector = poison_latent.sample()
+                loss = F.mse_loss(clean_vector, poison_vector, reduction="mean")
+            elif args.attack_type == 'add':
+                loss_2 = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean")
+                loss_1 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean")
+                loss = loss_1 + loss_2
+            elif args.attack_type == 'add-log':
+                loss_1 = F.mse_loss(clean_latent.var.log(), poison_latent.var.log(), reduction="mean")
+                loss_2 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction='mean')
+                loss = loss_1 + loss_2
+            else:
+                raise ValueError(f"Unsupported attack type: {args.attack_type}")
+
+            optimizer.zero_grad()
+            loss.backward()
+
+            # Perform the PGD update on the perturbation
+            delta = attackmodel.delta[batch['index']]
+            delta.requires_grad_(False)
+            delta += delta.grad.sign() * args.step_size
+            delta = torch.clamp(delta, -attackmodel.epsilon, attackmodel.epsilon)
+            delta = torch.clamp(delta, -batch['pixel_values'].detach().cpu(), 1-batch['pixel_values'].detach().cpu())
+            attackmodel.delta[batch['index']] = delta.detach().squeeze(0)
+
+            total_loss += loss.detach().cpu()
+
+        # Logging steps
+        logs = {"loss": total_loss.item()}
+        progress_bar.set_postfix(**logs)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/src/backend/app/algorithms/perturbation/simac.py b/src/backend/app/algorithms/perturbation/simac.py
new file mode 100644
index 0000000..0fc673e
--- /dev/null
+++ b/src/backend/app/algorithms/perturbation/simac.py
@@ -0,0 +1,1039 @@
+import argparse
+import copy
+import hashlib
+import itertools
+import logging
+import os
+from pathlib import Path
+
+import datasets
+import diffusers
+import random
+from torch.backends import cudnn
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.utils.import_utils import is_xformers_available
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+
+logger = get_logger(__name__)
+
+
+class DreamBoothDatasetFromTensor(Dataset):
+    """Just like DreamBoothDataset, but takes an instance_images_tensor instead of a path."""
+
+    def __init__(
+        self,
+        instance_images_tensor,
+        instance_prompt,
+        tokenizer,
+        class_data_root=None,
+        class_prompt=None,
+        size=512,
+        center_crop=False,
+    ):
+        self.size = size
+        self.center_crop = center_crop
+        self.tokenizer = tokenizer
+
+        self.instance_images_tensor = instance_images_tensor
+        self.num_instance_images = len(self.instance_images_tensor)
+        self.instance_prompt = instance_prompt
+        self._length = self.num_instance_images
+
+        if class_data_root is not None:
+            self.class_data_root = Path(class_data_root)
+            self.class_data_root.mkdir(parents=True, exist_ok=True)
+            self.class_images_path = list(self.class_data_root.iterdir())
+            self.num_class_images = len(self.class_images_path)
+            self._length = max(self.num_class_images, self.num_instance_images)
+            self.class_prompt = class_prompt
+        else:
+
self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.instance_images_tensor[index % self.num_instance_images] + example["instance_images"] = instance_image + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." 
+ ), + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir_for_train", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--instance_data_dir_for_adversarial", + type=str, + default=None, + required=True, + help="A folder containing the images to add adversarial noise", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. 
If set, the text encoder should be float32 precision.",
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=4,
+        help="Batch size (per device) for the training dataloader.",
+    )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for sampling images.",
+    )
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=20,
+        help="Total number of training steps to perform.",
+    )
+    parser.add_argument(
+        "--max_f_train_steps",
+        type=int,
+        default=10,
+        help="Total number of sub-steps to train the surrogate model.",
+    )
+    parser.add_argument(
+        "--max_adv_train_steps",
+        type=int,
+        default=10,
+        help="Total number of sub-steps to train adversarial noise.",
+    )
+    parser.add_argument(
+        "--checkpointing_iterations",
+        type=int,
+        default=5,
+        help=("Save a checkpoint of the training state every X iterations."),
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="fp16",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
+            " 1.10 and an NVIDIA Ampere GPU. Defaults to the value of the accelerate config of the current system or"
+            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention",
+        action="store_true",
+        help="Whether or not to use xformers.",
+    )
+    parser.add_argument(
+        "--pgd_alpha",
+        type=float,
+        default=0.005,
+        help="The step size for PGD.",
+    )
+    parser.add_argument(
+        "--pgd_eps",
+        type=int,
+        default=16,
+        help="The noise budget for PGD, in units of 1/255.",
+    )
+    parser.add_argument(
+        "--target_image_path",
+        default=None,
+        help="Target image for attacking.",
+    )
+    parser.add_argument(
+        "--max_steps",
+        type=int,
+        default=50,
+        help=(
+            "Maximum steps for adaptive greedy timestep selection."
+        ),
+    )
+    parser.add_argument(
+        "--delta_t",
+        type=int,
+        default=20,
+        help=(
+            "Remove a window of 2*delta_t timesteps at each adaptive greedy timestep selection step."
+        ),
+    )
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    return args
+
+
+class PromptDataset(Dataset):
+    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+ + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: + image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + images = [image_transforms(Image.open(i).convert("RGB")) for i in list(Path(data_dir).iterdir())] + images = torch.stack(images) + return images + + + +def train_one_epoch( + args, + models, + tokenizer, + noise_scheduler, + vae, + data_tensor: torch.Tensor, + num_steps=20, +): + # Load the tokenizer + + unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1]) + params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=args.learning_rate, + betas=(0.9, 0.999), + weight_decay=1e-2, + eps=1e-08, + ) + + train_dataset = DreamBoothDatasetFromTensor( + data_tensor, + args.instance_prompt, + tokenizer, + args.class_data_dir, + args.class_prompt, + args.resolution, + args.center_crop, + ) + + # weight_dtype = torch.bfloat16 + weight_dtype = torch.bfloat16 + device = torch.device("cuda") + + vae.to(device, dtype=weight_dtype) + text_encoder.to(device, dtype=weight_dtype) + unet.to(device, dtype=weight_dtype) + + for step in range(num_steps): + unet.train() + text_encoder.train() + + step_data = train_dataset[step % len(train_dataset)] + pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to( + device, dtype=weight_dtype + ) + input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device) + + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(input_ids)[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # with prior preservation loss + if args.with_prior_preservation: + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = 
F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = instance_loss + args.prior_loss_weight * prior_loss
+            print(
+                f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, instance_loss: {instance_loss.detach().item()}"
+            )
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+            print(f"Step #{step}, loss: {loss.detach().item()}")
+
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
+        optimizer.step()
+        optimizer.zero_grad()
+
+    return [unet, text_encoder]
+
+def set_unet_attr(unet):
+    def conv_forward(self):
+        def forward(input_tensor, temb):
+            self.in_layers_features = input_tensor
+            hidden_states = input_tensor
+
+            hidden_states = self.norm1(hidden_states)
+            hidden_states = self.nonlinearity(hidden_states)
+
+            if self.upsample is not None:
+                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+                if hidden_states.shape[0] >= 64:
+                    input_tensor = input_tensor.contiguous()
+                    hidden_states = hidden_states.contiguous()
+                input_tensor = self.upsample(input_tensor)
+                hidden_states = self.upsample(hidden_states)
+            elif self.downsample is not None:
+                input_tensor = self.downsample(input_tensor)
+                hidden_states = self.downsample(hidden_states)
+
+            hidden_states = self.conv1(hidden_states)
+
+            if temb is not None:
+                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+            if temb is not None and self.time_embedding_norm == "default":
+                hidden_states = hidden_states + temb
+
+            hidden_states = self.norm2(hidden_states)
+
+            if temb is not None and self.time_embedding_norm == "scale_shift":
+                scale, shift = torch.chunk(temb, 2, dim=1)
+                hidden_states = hidden_states * (1 + scale) + shift
+
+            hidden_states = self.nonlinearity(hidden_states)
+
+            hidden_states = self.dropout(hidden_states)
+            hidden_states = self.conv2(hidden_states)
+            self.out_layers_features = hidden_states
+            if self.conv_shortcut is not None:
+                input_tensor = self.conv_shortcut(input_tensor)
+
+            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+            return output_tensor
+
+        return forward
+
+    # [MODIFIED] Only hook up_blocks[3], the block the algorithm actually uses
+    conv_module_list = [
+        unet.up_blocks[3].resnets[0], unet.up_blocks[3].resnets[1], unet.up_blocks[3].resnets[2],
+    ]
+    for conv_module in conv_module_list:
+        conv_module.forward = conv_forward(conv_module)
+        setattr(conv_module, 'in_layers_features', None)
+        setattr(conv_module, 'out_layers_features', None)
+
+
+def save_feature_maps(up_blocks, down_blocks):
+
+    out_layers_features_list_3 = []
+    res_3_list = [0, 1, 2]
+
+    # [MODIFIED] Only extract the features of up_blocks[3]
+    block = up_blocks[3]
+    for index in res_3_list:
+        out_layers_features_list_3.append(block.resnets[index].out_layers_features)
+
+    out_layers_features_list_3 = torch.stack(out_layers_features_list_3, dim=0)
+
+    # [MODIFIED] Only return the features the algorithm actually uses
+    return out_layers_features_list_3
+
+def pgd_attack(
+    args,
+    models,
+    tokenizer,
+    noise_scheduler,
+    vae,
+    data_tensor: torch.Tensor,
+    original_images: torch.Tensor,
+    target_tensor: torch.Tensor,
+    num_steps: int,
+    time_list
+):
+    """Return new perturbed data"""
+
+    unet, text_encoder = models
+    weight_dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    vae.to(device, dtype=weight_dtype)
+    text_encoder.to(device, dtype=weight_dtype)
+    unet.to(device, dtype=weight_dtype)
+    set_unet_attr(unet)
+
+    perturbed_images = data_tensor.detach().clone()
+    perturbed_images.requires_grad_(True)
+
+    input_ids = tokenizer(
+        args.instance_prompt,
+        truncation=True,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        return_tensors="pt",
+    ).input_ids.repeat(len(data_tensor), 1)
+
+    for step in range(num_steps):
+        perturbed_images.requires_grad = True
+        latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
+        latents = latents * vae.config.scaling_factor
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        timesteps = []
+        for i in range(len(data_tensor)):
+            ts = time_list[i]
+            ts_index = torch.randint(0, len(ts), (1,))
+            timestep = torch.IntTensor([ts[ts_index]])
+            timestep = timestep.long()
+            timesteps.append(timestep)
+        timesteps = torch.cat(timesteps).to(device)
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        # Get the text embedding for conditioning
+        encoder_hidden_states = text_encoder(input_ids.to(device))[0]
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        # [MODIFIED] feature loss (only unpack the features that are needed)
+        noise_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks)
+
+        with torch.no_grad():
+            clean_latents = vae.encode(data_tensor.to(device, dtype=weight_dtype)).latent_dist.sample()
+            clean_latents = clean_latents * vae.config.scaling_factor
+            noisy_clean_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps)
+            clean_model_pred = unet(noisy_clean_latents, timesteps, encoder_hidden_states).sample
+
+            # [MODIFIED] (only unpack the features that are needed)
+            clean_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks)
+
+        # [LOGIC UNCHANGED] the target loss function is unchanged
+        target_loss = F.mse_loss(noise_out_layers_features_3.float(), clean_out_layers_features_3.float(), reduction="mean")
+        unet.zero_grad()
+        text_encoder.zero_grad()
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        loss = loss + target_loss.detach().item()  # keeps the original (odd) loss.backward() logic: the detached scalar contributes no gradient
+        loss.backward()
+        alpha = args.pgd_alpha
+        eps = args.pgd_eps / 255
+        adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
+        eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
+        perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
+        print(f"PGD loss - step {step}, loss: {loss.detach().item()}, target_loss : {target_loss.detach().item()}")
+
+    # [MODIFIED] Explicitly free the feature tensors and empty the cache after the
+    # PGD loop (i.e. right before returning) so the subsequent train_one_epoch
+    # has enough GPU memory
+    del noise_out_layers_features_3
+    del clean_out_layers_features_3
+    del noise
+    del latents
+    del encoder_hidden_states
+    torch.cuda.empty_cache()
+
+    return perturbed_images
+
+def select_timestep(
+    args,
+    models,
+    tokenizer,
+    noise_scheduler,
+    vae,
+    data_tensor: torch.Tensor,
+    original_images: torch.Tensor,
+    target_tensor: torch.Tensor,
+    ):
+    """Return, per image, the list of timesteps to use for the PGD attack."""
+
+    unet, text_encoder = models
+    weight_dtype = torch.bfloat16
+    device = torch.device("cuda")
+
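+    # Adaptive greedy timestep selection (as implemented below): for each image,
+    # candidate timesteps are scored by the L1 norm of the image gradient of the
+    # diffusion loss; at every step a window of +/- delta_t around the
+    # lowest-scoring timestep is masked out, while a PGD step is taken at the
+    # highest-scoring one. Selection stops once at most 100 candidates remain.
+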
+    vae.to(device, dtype=weight_dtype)
+    text_encoder.to(device, dtype=weight_dtype)
+    unet.to(device, dtype=weight_dtype)
+
+    perturbed_images = data_tensor.detach().clone()
+    perturbed_images.requires_grad_(True)
+
+    input_ids = tokenizer(
+        args.instance_prompt,
+        truncation=True,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        return_tensors="pt",
+    ).input_ids
+
+    time_list = []
+    for id in range(len(data_tensor)):
+        perturbed_image = perturbed_images[id, :].unsqueeze(0)
+        original_image = original_images[id, :].unsqueeze(0)
+        time_seq = torch.tensor(list(range(0, 1000)))
+        input_mask = torch.ones_like(time_seq)
+        id_image = perturbed_image.detach().clone()
+        for step in range(args.max_steps):
+            id_image.requires_grad_(True)
+            select_mask = torch.where(input_mask == 1, True, False)
+            res_time_seq = torch.masked_select(time_seq, select_mask)
+            if len(res_time_seq) > 100:
+                min_score, max_score = 0.0, 0.0
+                for index in range(0, 5):
+                    id_image.requires_grad_(True)
+                    latents = vae.encode(id_image.to(device, dtype=weight_dtype)).latent_dist.sample()
+                    latents = latents * vae.config.scaling_factor
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    # Sample a random candidate timestep
+                    inner_index = torch.randint(0, len(res_time_seq), (bsz,))
+                    timesteps = torch.IntTensor([res_time_seq[inner_index]]).to(device)
+                    timesteps = timesteps.long()
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = text_encoder(input_ids.to(device))[0]
+                    # Predict the noise residual
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+                    unet.zero_grad()
+                    text_encoder.zero_grad()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                    loss.backward()
+                    score = torch.sum(torch.abs(id_image.grad.data))
+                    index = index + 1
+                    id_image.grad.zero_()
+                    if index == 1:
+                        min_score = score
+                        max_score = score
+                        del_t = res_time_seq[inner_index].item()
+                        select_t = res_time_seq[inner_index].item()
+                    else:
+                        if min_score > score:
+                            min_score = score
+                            del_t = res_time_seq[inner_index].item()
+                        if max_score < score:
+                            max_score = score
+                            select_t = res_time_seq[inner_index].item()
+                    print(f"PGD loss - step {step}, index : {index}, loss: {loss.detach().item()}, score: {score}, t : {res_time_seq[inner_index]}, ts_len: {len(res_time_seq)}")
+
+                print("del_t", del_t, "max_t", select_t)
+                if del_t < args.delta_t:
+                    del_t = args.delta_t
+                elif del_t > (1000 - args.delta_t):
+                    del_t = 1000 - args.delta_t
+                # [MODIFIED] use args.delta_t for the window width instead of a hard-coded 20
+                input_mask[del_t - args.delta_t: del_t + args.delta_t] = input_mask[del_t - args.delta_t: del_t + args.delta_t] - 1
+                input_mask = torch.clamp(input_mask, min=0, max=+1)
+
+                id_image.requires_grad_(True)
+                latents = vae.encode(id_image.to(device, dtype=weight_dtype)).latent_dist.sample()
+                latents = latents * vae.config.scaling_factor
+                # Sample noise that we'll add to the latents
+                noise = torch.randn_like(latents)
+                bsz = latents.shape[0]
+                timesteps =
torch.IntTensor([select_t]).to(device) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(input_ids.to(device))[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + unet.zero_grad() + text_encoder.zero_grad() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss.backward() + alpha = args.pgd_alpha + eps = args.pgd_eps / 255 + adv_image = id_image + alpha * id_image.grad.sign() + eta = torch.clamp(adv_image - original_image, min=-eps, max=+eps) + score = torch.sum(torch.abs(id_image.grad.sign())) + id_image = torch.clamp(original_image + eta, min=-1, max=+1).detach_() + + else: + # print(id, res_time_seq, step, len(res_time_seq)) + time_list.append(res_time_seq) + break + return time_list + +def setup_seeds(): + seed = 42 + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + cudnn.benchmark = False + cudnn.deterministic = True + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + mixed_precision=args.mixed_precision, + log_with=args.report_to, + logging_dir=logging_dir, + ) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + if args.seed is not None: + set_seed(args.seed) + setup_seeds() + # Generate class images if prior preservation is enabled. 
+ if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.mixed_precision == "fp32": + torch_dtype = torch.float32 + elif args.mixed_precision == "fp16": + torch_dtype = torch.float16 + elif args.mixed_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, + ) + + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", ) + + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, + ).cuda() + + vae.requires_grad_(False) + + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + clean_data = load_data( + args.instance_data_dir_for_train, + size=args.resolution, + center_crop=args.center_crop, + ) + perturbed_data = load_data( + args.instance_data_dir_for_adversarial, + size=args.resolution, + center_crop=args.center_crop, + ) + original_data = perturbed_data.clone() + original_data.requires_grad_(False) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + target_latent_tensor = None + if args.target_image_path is not None: + target_image_path = Path(args.target_image_path) + assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist" + + target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution)) + target_image = np.array(target_image)[None].transpose(0, 3, 1, 2) + + target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0 + target_latent_tensor = ( + vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor + ) + target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda() + + f = [unet, text_encoder] + + time_list = select_timestep( + args, + f, + tokenizer, + noise_scheduler, + vae, + perturbed_data, + original_data, + target_latent_tensor, + ) + for t in time_list: + print(t) + for i in range(args.max_train_steps): + # 1. f' = f.clone() + f_sur = copy.deepcopy(f) + f_sur = train_one_epoch( + args, + f_sur, + tokenizer, + noise_scheduler, + vae, + clean_data, + args.max_f_train_steps, + ) + perturbed_data = pgd_attack( + args, + f_sur, + tokenizer, + noise_scheduler, + vae, + perturbed_data, + original_data, + target_latent_tensor, + args.max_adv_train_steps, + time_list + ) + f = train_one_epoch( + args, + f, + tokenizer, + noise_scheduler, + vae, + perturbed_data, + args.max_f_train_steps, + ) + + if (i + 1) % args.checkpointing_iterations == 0: + + save_folder = args.output_dir + os.makedirs(save_folder, exist_ok=True) + noised_imgs = perturbed_data.detach() + + img_names = [ + str(instance_path).split("/")[-1].split(".")[0] + for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir()) + ] + + for img_pixel, img_name in zip(noised_imgs, img_names): + save_path = os.path.join(save_folder, f"perturbed_{img_name}.png") + Image.fromarray( + (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy() + ).save(save_path) + + print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)") + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/src/backend/app/algorithms/perturbation_engine.py b/src/backend/app/algorithms/perturbation_engine.py deleted file mode 100644 index 417b1f4..0000000 --- a/src/backend/app/algorithms/perturbation_engine.py +++ /dev/null @@ -1,176 +0,0 @@ -""" -对抗性扰动算法引擎 -实现各种加噪算法的虚拟版本 -""" - -import os -import numpy as np -from PIL import Image, ImageEnhance, ImageFilter -import uuid -from flask import current_app - -class PerturbationEngine: - """对抗性扰动处理引擎""" - - @staticmethod - def apply_perturbation(image_path, algorithm, epsilon, use_strong_protection=False, output_path=None): - """ - 应用对抗性扰动 - - Args: - image_path: 原始图片路径 - algorithm: 算法名称 (simac, caat, pid) - epsilon: 扰动强度 - use_strong_protection: 是否使用防净化版本 - - Returns: - 处理后图片的路径 - """ - try: - if not os.path.exists(image_path): - raise FileNotFoundError(f"图片文件不存在: {image_path}") - - # 加载图片 - with Image.open(image_path) as img: - # 转换为RGB模式 - if img.mode != 'RGB': - img = img.convert('RGB') - - # 根据算法选择处理方法 - if algorithm == 'simac': - perturbed_img = PerturbationEngine._apply_simac(img, epsilon, use_strong_protection) - elif algorithm == 'caat': - perturbed_img = PerturbationEngine._apply_caat(img, epsilon, use_strong_protection) - elif algorithm == 'pid': - perturbed_img = PerturbationEngine._apply_pid(img, 
epsilon, use_strong_protection) - else: - raise ValueError(f"不支持的算法: {algorithm}") - - # 使用输入的output_path参数 - if output_path is None: - # 如果没有提供输出路径,使用默认路径 - from flask import current_app - project_root = os.path.dirname(current_app.root_path) - perturbed_dir = os.path.join(project_root, current_app.config['PERTURBED_IMAGES_FOLDER']) - os.makedirs(perturbed_dir, exist_ok=True) - - file_extension = os.path.splitext(image_path)[1] - output_filename = f"perturbed_{uuid.uuid4().hex[:8]}{file_extension}" - output_path = os.path.join(perturbed_dir, output_filename) - - # 保存处理后的图片 - perturbed_img.save(output_path, quality=95) - - return output_path - - except Exception as e: - print(f"应用扰动时出错: {str(e)}") - return None - - @staticmethod - def _apply_simac(img, epsilon, use_strong_protection): - """ - SimAC算法的虚拟实现 - Simple Anti-Customization Method for Protecting Face Privacy - """ - # 将PIL图像转换为numpy数组 - img_array = np.array(img, dtype=np.float32) - - # 生成随机噪声(模拟对抗性扰动) - noise_scale = epsilon / 255.0 - - if use_strong_protection: - # 防净化版本:添加更复杂的扰动模式 - noise = np.random.normal(0, noise_scale * 0.8, img_array.shape) - # 添加结构化噪声 - h, w = img_array.shape[:2] - for i in range(0, h, 8): - for j in range(0, w, 8): - block_noise = np.random.normal(0, noise_scale * 0.4, (min(8, h-i), min(8, w-j), 3)) - noise[i:i+8, j:j+8] += block_noise - else: - # 标准版本:简单高斯噪声 - noise = np.random.normal(0, noise_scale, img_array.shape) - - # 应用噪声 - perturbed_array = img_array + noise * 255.0 - perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8) - - # 转换回PIL图像 - result_img = Image.fromarray(perturbed_array) - - # 轻微的图像增强以模拟算法特性 - enhancer = ImageEnhance.Contrast(result_img) - result_img = enhancer.enhance(1.02) - - return result_img - - @staticmethod - def _apply_caat(img, epsilon, use_strong_protection): - """ - CAAT算法的虚拟实现 - Perturbing Attention Gives You More Bang for the Buck - """ - img_array = np.array(img, dtype=np.float32) - - noise_scale = epsilon / 255.0 - - if use_strong_protection: - # 防净化版本:注意力区域重点扰动 - # 模拟注意力图(简单的边缘检测) - gray_img = img.convert('L') - edge_img = gray_img.filter(ImageFilter.FIND_EDGES) - attention_map = np.array(edge_img, dtype=np.float32) / 255.0 - - # 在注意力区域添加更强的噪声 - noise = np.random.normal(0, noise_scale * 0.6, img_array.shape) - for c in range(3): - noise[:,:,c] += attention_map * np.random.normal(0, noise_scale * 0.8, attention_map.shape) - else: - # 标准版本:均匀分布噪声 - noise = np.random.uniform(-noise_scale, noise_scale, img_array.shape) - - perturbed_array = img_array + noise * 255.0 - perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8) - - result_img = Image.fromarray(perturbed_array) - - # 轻微模糊以模拟注意力扰动效果 - result_img = result_img.filter(ImageFilter.BoxBlur(0.5)) - - return result_img - - @staticmethod - def _apply_pid(img, epsilon, use_strong_protection): - """ - PID算法的虚拟实现 - Prompt-Independent Data Protection Against Latent Diffusion Models - """ - img_array = np.array(img, dtype=np.float32) - - noise_scale = epsilon / 255.0 - - if use_strong_protection: - # 防净化版本:频域扰动 - # 简单的频域变换模拟 - noise = np.random.laplace(0, noise_scale * 0.7, img_array.shape) - # 添加周期性扰动 - h, w = img_array.shape[:2] - for i in range(h): - for j in range(w): - periodic_noise = noise_scale * 0.3 * np.sin(i * 0.1) * np.cos(j * 0.1) - noise[i, j] += periodic_noise - else: - # 标准版本:拉普拉斯噪声 - noise = np.random.laplace(0, noise_scale * 0.5, img_array.shape) - - perturbed_array = img_array + noise * 255.0 - perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8) - - result_img = 
Image.fromarray(perturbed_array) - - # 色彩微调以模拟潜在空间扰动 - enhancer = ImageEnhance.Color(result_img) - result_img = enhancer.enhance(0.98) - - return result_img \ No newline at end of file diff --git a/src/backend/app/algorithms/perturbation_virtual/aspl.py b/src/backend/app/algorithms/perturbation_virtual/aspl.py new file mode 100644 index 0000000..d20cb11 --- /dev/null +++ b/src/backend/app/algorithms/perturbation_virtual/aspl.py @@ -0,0 +1,87 @@ +""" +ASPL算法虚拟实现 +用于测试后端流程,不执行实际的扰动生成 +""" +import argparse +import os +import sys +import platform +import shutil +import glob + +def main(): + parser = argparse.ArgumentParser(description="ASPL虚拟算法脚本") + parser.add_argument('--pretrained_model_name_or_path', required=True) + parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true') + parser.add_argument('--instance_data_dir_for_train', required=True) + parser.add_argument('--instance_data_dir_for_adversarial', required=True) + parser.add_argument('--instance_prompt', default='a photo of sks person') + parser.add_argument('--class_data_dir', required=True) + parser.add_argument('--num_class_images', type=int, default=200) + parser.add_argument('--class_prompt', default='a photo of person') + parser.add_argument('--output_dir', required=True) + parser.add_argument('--center_crop', action='store_true') + parser.add_argument('--with_prior_preservation', action='store_true') + parser.add_argument('--prior_loss_weight', type=float, default=1.0) + parser.add_argument('--resolution', type=int, default=384) + parser.add_argument('--train_batch_size', type=int, default=1) + parser.add_argument('--max_train_steps', type=int, default=50) + parser.add_argument('--max_f_train_steps', type=int, default=3) + parser.add_argument('--max_adv_train_steps', type=int, default=6) + parser.add_argument('--checkpointing_iterations', type=int, default=10) + parser.add_argument('--learning_rate', type=float, default=5e-7) + parser.add_argument('--pgd_alpha', type=float, default=0.005) + parser.add_argument('--pgd_eps', type=float, default=8) + parser.add_argument('--seed', type=int, default=0) + + args = parser.parse_args() + + print("=" * 80) + print("[VIRTUAL ASPL] 虚拟算法执行开始") + print("=" * 80) + print(f"[VIRTUAL] 算法名称: ASPL (Anti-DreamBooth)") + print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}") + print(f"[VIRTUAL] Python可执行文件: {sys.executable}") + print(f"[VIRTUAL] Python版本: {platform.python_version()}") + print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}") + print("-" * 80) + print("[VIRTUAL] 算法参数:") + print(f" - 模型路径: {args.pretrained_model_name_or_path}") + print(f" - 输入训练目录: {args.instance_data_dir_for_train}") + print(f" - 输入对抗目录: {args.instance_data_dir_for_adversarial}") + print(f" - 输出目录: {args.output_dir}") + print(f" - 类别目录: {args.class_data_dir}") + print(f" - 分辨率: {args.resolution}") + print(f" - 扰动强度(pgd_eps): {args.pgd_eps}") + print(f" - PGD alpha: {args.pgd_alpha}") + print(f" - 最大训练步数: {args.max_train_steps}") + print(f" - 学习率: {args.learning_rate}") + print("-" * 80) + + # 确保输出目录存在 + os.makedirs(args.output_dir, exist_ok=True) + + # 复制图片到输出目录(模拟扰动生成) + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + image_files = [] + for ext in image_extensions: + image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext))) + image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext.upper()))) + + print(f"[VIRTUAL] 找到 {len(image_files)} 张输入图片") + + copied_count = 0 + for 
image_path in image_files: + filename = os.path.basename(image_path) + output_path = os.path.join(args.output_dir, filename) + shutil.copy(image_path, output_path) + copied_count += 1 + print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename}") + + print("-" * 80) + print(f"[VIRTUAL] 成功处理 {copied_count} 张图片") + print("[VIRTUAL] 虚拟算法执行完成") + print("=" * 80) + +if __name__ == "__main__": + main() diff --git a/src/backend/app/algorithms/perturbation_virtual/caat.py b/src/backend/app/algorithms/perturbation_virtual/caat.py new file mode 100644 index 0000000..bbc436a --- /dev/null +++ b/src/backend/app/algorithms/perturbation_virtual/caat.py @@ -0,0 +1,79 @@ +""" +CAAT算法虚拟实现 +用于测试后端流程,不执行实际的扰动生成 +""" +import argparse +import os +import sys +import platform +import shutil +import glob + +def main(): + parser = argparse.ArgumentParser(description="CAAT虚拟算法脚本") + parser.add_argument('--pretrained_model_name_or_path', required=True) + parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true') + parser.add_argument('--instance_data_dir_for_train', required=True) + parser.add_argument('--instance_data_dir_for_adversarial', required=True) + parser.add_argument('--instance_prompt', default='a photo of sks person') + parser.add_argument('--class_data_dir', required=True) + parser.add_argument('--num_class_images', type=int, default=200) + parser.add_argument('--class_prompt', default='a photo of person') + parser.add_argument('--output_dir', required=True) + parser.add_argument('--resolution', type=int, default=512) + parser.add_argument('--train_batch_size', type=int, default=2) + parser.add_argument('--max_train_steps', type=int, default=800) + parser.add_argument('--learning_rate', type=float, default=1e-5) + parser.add_argument('--pgd_eps', type=float, default=16) + parser.add_argument('--seed', type=int, default=0) + + args = parser.parse_args() + + print("=" * 80) + print("[VIRTUAL CAAT] 虚拟算法执行开始") + print("=" * 80) + print(f"[VIRTUAL] 算法名称: CAAT (Class-wise Adversarial Attack)") + print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}") + print(f"[VIRTUAL] Python可执行文件: {sys.executable}") + print(f"[VIRTUAL] Python版本: {platform.python_version()}") + print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}") + print("-" * 80) + print("[VIRTUAL] 算法参数:") + print(f" - 模型路径: {args.pretrained_model_name_or_path}") + print(f" - 输入训练目录: {args.instance_data_dir_for_train}") + print(f" - 输入对抗目录: {args.instance_data_dir_for_adversarial}") + print(f" - 输出目录: {args.output_dir}") + print(f" - 类别目录: {args.class_data_dir}") + print(f" - 分辨率: {args.resolution}") + print(f" - 扰动强度(pgd_eps): {args.pgd_eps}") + print(f" - 最大训练步数: {args.max_train_steps}") + print(f" - 学习率: {args.learning_rate}") + print("-" * 80) + + # 确保输出目录存在 + os.makedirs(args.output_dir, exist_ok=True) + + # 复制图片到输出目录(模拟扰动生成) + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + image_files = [] + for ext in image_extensions: + image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext))) + image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext.upper()))) + + print(f"[VIRTUAL] 找到 {len(image_files)} 张输入图片") + + copied_count = 0 + for image_path in image_files: + filename = os.path.basename(image_path) + output_path = os.path.join(args.output_dir, filename) + shutil.copy(image_path, output_path) + copied_count += 1 + print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename}") + + print("-" * 80) + 
print(f"[VIRTUAL] 成功处理 {copied_count} 张图片") + print("[VIRTUAL] 虚拟算法执行完成") + print("=" * 80) + +if __name__ == "__main__": + main() diff --git a/src/backend/app/algorithms/perturbation_virtual/pid.py b/src/backend/app/algorithms/perturbation_virtual/pid.py new file mode 100644 index 0000000..5453f01 --- /dev/null +++ b/src/backend/app/algorithms/perturbation_virtual/pid.py @@ -0,0 +1,79 @@ +""" +PID算法虚拟实现 +用于测试后端流程,不执行实际的扰动生成 +""" +import argparse +import os +import sys +import platform +import shutil +import glob + +def main(): + parser = argparse.ArgumentParser(description="PID虚拟算法脚本") + parser.add_argument('--pretrained_model_name_or_path', required=True) + parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true') + parser.add_argument('--instance_data_dir_for_train', required=True) + parser.add_argument('--instance_data_dir_for_adversarial', required=True) + parser.add_argument('--instance_prompt', default='a photo of sks person') + parser.add_argument('--class_data_dir', required=True) + parser.add_argument('--num_class_images', type=int, default=200) + parser.add_argument('--class_prompt', default='a photo of person') + parser.add_argument('--output_dir', required=True) + parser.add_argument('--resolution', type=int, default=512) + parser.add_argument('--train_batch_size', type=int, default=1) + parser.add_argument('--max_train_steps', type=int, default=600) + parser.add_argument('--learning_rate', type=float, default=3e-6) + parser.add_argument('--pgd_eps', type=float, default=4) + parser.add_argument('--seed', type=int, default=0) + + args = parser.parse_args() + + print("=" * 80) + print("[VIRTUAL PID] 虚拟算法执行开始") + print("=" * 80) + print(f"[VIRTUAL] 算法名称: PID (Perturbation Identity Defense)") + print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}") + print(f"[VIRTUAL] Python可执行文件: {sys.executable}") + print(f"[VIRTUAL] Python版本: {platform.python_version()}") + print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}") + print("-" * 80) + print("[VIRTUAL] 算法参数:") + print(f" - 模型路径: {args.pretrained_model_name_or_path}") + print(f" - 输入训练目录: {args.instance_data_dir_for_train}") + print(f" - 输入对抗目录: {args.instance_data_dir_for_adversarial}") + print(f" - 输出目录: {args.output_dir}") + print(f" - 类别目录: {args.class_data_dir}") + print(f" - 分辨率: {args.resolution}") + print(f" - 扰动强度(pgd_eps): {args.pgd_eps}") + print(f" - 最大训练步数: {args.max_train_steps}") + print(f" - 学习率: {args.learning_rate}") + print("-" * 80) + + # 确保输出目录存在 + os.makedirs(args.output_dir, exist_ok=True) + + # 复制图片到输出目录(模拟扰动生成) + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + image_files = [] + for ext in image_extensions: + image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext))) + image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext.upper()))) + + print(f"[VIRTUAL] 找到 {len(image_files)} 张输入图片") + + copied_count = 0 + for image_path in image_files: + filename = os.path.basename(image_path) + output_path = os.path.join(args.output_dir, filename) + shutil.copy(image_path, output_path) + copied_count += 1 + print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename}") + + print("-" * 80) + print(f"[VIRTUAL] 成功处理 {copied_count} 张图片") + print("[VIRTUAL] 虚拟算法执行完成") + print("=" * 80) + +if __name__ == "__main__": + main() diff --git a/src/backend/app/algorithms/perturbation_virtual/simac.py b/src/backend/app/algorithms/perturbation_virtual/simac.py new file mode 100644 index 
0000000..573331e --- /dev/null +++ b/src/backend/app/algorithms/perturbation_virtual/simac.py @@ -0,0 +1,82 @@ +""" +SimAC算法虚拟实现 +用于测试后端流程,不执行实际的扰动生成 +""" +import argparse +import os +import sys +import platform +import shutil +import glob + +def main(): + parser = argparse.ArgumentParser(description="SimAC虚拟算法脚本") + parser.add_argument('--pretrained_model_name_or_path', required=True) + parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true') + parser.add_argument('--instance_data_dir_for_train', required=True) + parser.add_argument('--instance_data_dir_for_adversarial', required=True) + parser.add_argument('--instance_prompt', default='a photo of sks person') + parser.add_argument('--class_data_dir', required=True) + parser.add_argument('--num_class_images', type=int, default=200) + parser.add_argument('--class_prompt', default='a photo of person') + parser.add_argument('--output_dir', required=True) + parser.add_argument('--center_crop', action='store_true') + parser.add_argument('--with_prior_preservation', action='store_true') + parser.add_argument('--prior_loss_weight', type=float, default=1.0) + parser.add_argument('--resolution', type=int, default=512) + parser.add_argument('--train_batch_size', type=int, default=1) + parser.add_argument('--max_train_steps', type=int, default=1000) + parser.add_argument('--learning_rate', type=float, default=5e-6) + parser.add_argument('--pgd_eps', type=float, default=8) + parser.add_argument('--seed', type=int, default=0) + + args = parser.parse_args() + + print("=" * 80) + print("[VIRTUAL SIMAC] 虚拟算法执行开始") + print("=" * 80) + print(f"[VIRTUAL] 算法名称: SimAC (Simple and Effective)") + print(f"[VIRTUAL] 当前Conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}") + print(f"[VIRTUAL] Python可执行文件: {sys.executable}") + print(f"[VIRTUAL] Python版本: {platform.python_version()}") + print(f"[VIRTUAL] 系统平台: {platform.system()} {platform.release()}") + print("-" * 80) + print("[VIRTUAL] 算法参数:") + print(f" - 模型路径: {args.pretrained_model_name_or_path}") + print(f" - 输入训练目录: {args.instance_data_dir_for_train}") + print(f" - 输入对抗目录: {args.instance_data_dir_for_adversarial}") + print(f" - 输出目录: {args.output_dir}") + print(f" - 类别目录: {args.class_data_dir}") + print(f" - 分辨率: {args.resolution}") + print(f" - 扰动强度(pgd_eps): {args.pgd_eps}") + print(f" - 最大训练步数: {args.max_train_steps}") + print(f" - 学习率: {args.learning_rate}") + print("-" * 80) + + # 确保输出目录存在 + os.makedirs(args.output_dir, exist_ok=True) + + # 复制图片到输出目录(模拟扰动生成) + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + image_files = [] + for ext in image_extensions: + image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext))) + image_files.extend(glob.glob(os.path.join(args.instance_data_dir_for_adversarial, ext.upper()))) + + print(f"[VIRTUAL] 找到 {len(image_files)} 张输入图片") + + copied_count = 0 + for image_path in image_files: + filename = os.path.basename(image_path) + output_path = os.path.join(args.output_dir, filename) + shutil.copy(image_path, output_path) + copied_count += 1 + print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename}") + + print("-" * 80) + print(f"[VIRTUAL] 成功处理 {copied_count} 张图片") + print("[VIRTUAL] 虚拟算法执行完成") + print("=" * 80) + +if __name__ == "__main__": + main() diff --git a/src/backend/app/algorithms/perturbation_virtual/virtual_demo.py b/src/backend/app/algorithms/perturbation_virtual/virtual_demo.py new file mode 100644 index 0000000..e9cd1ec --- /dev/null +++ 
b/src/backend/app/algorithms/perturbation_virtual/virtual_demo.py @@ -0,0 +1,25 @@ +import argparse +import os +import sys +import platform + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="虚拟扰动算法脚本") + parser.add_argument('--algorithm_code', required=True) + parser.add_argument('--batch_id', required=True) + parser.add_argument('--epsilon', required=True) + parser.add_argument('--use_strong_protection', required=True) + parser.add_argument('--input_dir', required=True) + parser.add_argument('--output_dir', required=True) + args = parser.parse_args() + + print(f"[VIRTUAL] 当前算法: {args.algorithm_code}") + print(f"[VIRTUAL] 批次ID: {args.batch_id}") + print(f"[VIRTUAL] 扰动强度: {args.epsilon}") + print(f"[VIRTUAL] 是否强防护: {args.use_strong_protection}") + print(f"[VIRTUAL] 输入目录: {args.input_dir}") + print(f"[VIRTUAL] 输出目录: {args.output_dir}") + print(f"[VIRTUAL] 当前conda环境: {os.environ.get('CONDA_DEFAULT_ENV', '未检测到')}") + print(f"[VIRTUAL] Python环境: {sys.executable}") + print(f"[VIRTUAL] Python版本: {platform.python_version()}") + # 不做图片处理,图片复制由worker完成 diff --git a/src/backend/app/controllers/admin_controller.py b/src/backend/app/controllers/admin_controller.py index 581fa21..aca0c8e 100644 --- a/src/backend/app/controllers/admin_controller.py +++ b/src/backend/app/controllers/admin_controller.py @@ -1,239 +1,239 @@ -""" -管理员控制器 -处理管理员功能 -""" - -from flask import Blueprint, request, jsonify -from flask_jwt_extended import jwt_required, get_jwt_identity -from app import db -from app.models import User, Batch, Image - -admin_bp = Blueprint('admin', __name__) - -def admin_required(f): - """管理员权限装饰器""" - from functools import wraps - - @wraps(f) - def decorated_function(*args, **kwargs): - current_user_id = get_jwt_identity() - user = User.query.get(current_user_id) - - if not user or user.role != 'admin': - return jsonify({'error': '需要管理员权限'}), 403 - - return f(*args, **kwargs) - - return decorated_function - -@admin_bp.route('/users', methods=['GET']) -@jwt_required() -@admin_required -def list_users(): - """获取用户列表""" - try: - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 20, type=int) - - users = User.query.paginate(page=page, per_page=per_page, error_out=False) - - return jsonify({ - 'users': [user.to_dict() for user in users.items], - 'total': users.total, - 'pages': users.pages, - 'current_page': page - }), 200 - - except Exception as e: - return jsonify({'error': f'获取用户列表失败: {str(e)}'}), 500 - -@admin_bp.route('/users/', methods=['GET']) -@jwt_required() -@admin_required -def get_user_detail(user_id): - """获取用户详情""" - try: - user = User.query.get(user_id) - if not user: - return jsonify({'error': '用户不存在'}), 404 - - # 获取用户统计信息 - total_tasks = Batch.query.filter_by(user_id=user_id).count() - total_images = Image.query.filter_by(user_id=user_id).count() - - user_dict = user.to_dict() - user_dict['stats'] = { - 'total_tasks': total_tasks, - 'total_images': total_images - } - - return jsonify({'user': user_dict}), 200 - - except Exception as e: - return jsonify({'error': f'获取用户详情失败: {str(e)}'}), 500 - -@admin_bp.route('/users', methods=['POST']) -@jwt_required() -@admin_required -def create_user(): - """创建用户""" - try: - data = request.get_json() - username = data.get('username') - password = data.get('password') - email = data.get('email') - role = data.get('role', 'user') - max_concurrent_tasks = data.get('max_concurrent_tasks', 0) - - if not username or not password: - return jsonify({'error': '用户名和密码不能为空'}), 400 - - # 检查用户名是否已存在 - 
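# Two notes on admin_required above, with a hedged sketch. (1) Decorator order
# matters: @jwt_required() sits above @admin_required on each route, so token
# verification runs before get_jwt_identity() is called here. (2) login() in
# auth_controller below issues tokens with identity=str(user.id), so
# get_jwt_identity() returns a string; User.query.get() tolerates that, but
# delete_user's comparison user_id == current_user_id pits int against str and
# never matches, so the "cannot delete your own account" guard does not fire.
# A sketch that coerces the identity once in the decorator, mirroring the
# int_jwt_required helper in auth_controller; an assumed fix, not this patch:
from functools import wraps
from flask import jsonify
from flask_jwt_extended import get_jwt_identity
from app.database import User

def admin_required(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        try:
            current_user_id = int(get_jwt_identity())  # normalize str identity
        except (TypeError, ValueError):
            return jsonify({'error': '无效的用户身份标识'}), 401
        user = User.query.get(current_user_id)
        if not user or user.role != 'admin':
            return jsonify({'error': '需要管理员权限'}), 403
        return f(*args, **kwargs)
    return decorated_function
# The same int coercion is still needed inside any route that compares the
# identity to integer IDs, as int_jwt_required below does for auth routes.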
if User.query.filter_by(username=username).first(): - return jsonify({'error': '用户名已存在'}), 400 - - # 检查邮箱是否已存在 - if email and User.query.filter_by(email=email).first(): - return jsonify({'error': '邮箱已被使用'}), 400 - - # 创建用户 - user = User( - username=username, - email=email, - role=role, - max_concurrent_tasks=max_concurrent_tasks - ) - user.set_password(password) - - db.session.add(user) - db.session.commit() - - return jsonify({ - 'message': '用户创建成功', - 'user': user.to_dict() - }), 201 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'创建用户失败: {str(e)}'}), 500 - -@admin_bp.route('/users/', methods=['PUT']) -@jwt_required() -@admin_required -def update_user(user_id): - """更新用户信息""" - try: - user = User.query.get(user_id) - if not user: - return jsonify({'error': '用户不存在'}), 404 - - data = request.get_json() - - # 更新字段 - if 'username' in data: - new_username = data['username'] - if new_username != user.username: - if User.query.filter_by(username=new_username).first(): - return jsonify({'error': '用户名已存在'}), 400 - user.username = new_username - - if 'email' in data: - new_email = data['email'] - if new_email != user.email: - if User.query.filter_by(email=new_email).first(): - return jsonify({'error': '邮箱已被使用'}), 400 - user.email = new_email - - if 'role' in data: - user.role = data['role'] - - if 'max_concurrent_tasks' in data: - user.max_concurrent_tasks = data['max_concurrent_tasks'] - - if 'is_active' in data: - user.is_active = bool(data['is_active']) - - if 'password' in data and data['password']: - user.set_password(data['password']) - - db.session.commit() - - return jsonify({ - 'message': '用户信息更新成功', - 'user': user.to_dict() - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'更新用户失败: {str(e)}'}), 500 - -@admin_bp.route('/users/', methods=['DELETE']) -@jwt_required() -@admin_required -def delete_user(user_id): - """删除用户""" - try: - current_user_id = get_jwt_identity() - - # 不能删除自己 - if user_id == current_user_id: - return jsonify({'error': '不能删除自己的账户'}), 400 - - user = User.query.get(user_id) - if not user: - return jsonify({'error': '用户不存在'}), 404 - - # 删除用户(级联删除相关数据) - db.session.delete(user) - db.session.commit() - - return jsonify({'message': '用户删除成功'}), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'删除用户失败: {str(e)}'}), 500 - -@admin_bp.route('/stats', methods=['GET']) -@jwt_required() -@admin_required -def get_system_stats(): - """获取系统统计信息""" - try: - from app.models import EvaluationResult - - total_users = User.query.count() - active_users = User.query.filter_by(is_active=True).count() - admin_users = User.query.filter_by(role='admin').count() - - total_tasks = Batch.query.count() - completed_tasks = Batch.query.filter_by(status='completed').count() - processing_tasks = Batch.query.filter_by(status='processing').count() - failed_tasks = Batch.query.filter_by(status='failed').count() - - total_images = Image.query.count() - total_evaluations = EvaluationResult.query.count() - - return jsonify({ - 'stats': { - 'users': { - 'total': total_users, - 'active': active_users, - 'admin': admin_users - }, - 'tasks': { - 'total': total_tasks, - 'completed': completed_tasks, - 'processing': processing_tasks, - 'failed': failed_tasks - }, - 'images': { - 'total': total_images - }, - 'evaluations': { - 'total': total_evaluations - } - } - }), 200 - - except Exception as e: +""" +管理员控制器 +处理管理员功能 +""" + +from flask import Blueprint, request, jsonify +from flask_jwt_extended import 
jwt_required, get_jwt_identity +from app import db +from app.database import User, Batch, Image + +admin_bp = Blueprint('admin', __name__) + +def admin_required(f): + """管理员权限装饰器""" + from functools import wraps + + @wraps(f) + def decorated_function(*args, **kwargs): + current_user_id = get_jwt_identity() + user = User.query.get(current_user_id) + + if not user or user.role != 'admin': + return jsonify({'error': '需要管理员权限'}), 403 + + return f(*args, **kwargs) + + return decorated_function + +@admin_bp.route('/users', methods=['GET']) +@jwt_required() +@admin_required +def list_users(): + """获取用户列表""" + try: + page = request.args.get('page', 1, type=int) + per_page = request.args.get('per_page', 20, type=int) + + users = User.query.paginate(page=page, per_page=per_page, error_out=False) + + return jsonify({ + 'users': [user.to_dict() for user in users.items], + 'total': users.total, + 'pages': users.pages, + 'current_page': page + }), 200 + + except Exception as e: + return jsonify({'error': f'获取用户列表失败: {str(e)}'}), 500 + +@admin_bp.route('/users/', methods=['GET']) +@jwt_required() +@admin_required +def get_user_detail(user_id): + """获取用户详情""" + try: + user = User.query.get(user_id) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + # 获取用户统计信息 + total_tasks = Batch.query.filter_by(user_id=user_id).count() + total_images = Image.query.filter_by(user_id=user_id).count() + + user_dict = user.to_dict() + user_dict['stats'] = { + 'total_tasks': total_tasks, + 'total_images': total_images + } + + return jsonify({'user': user_dict}), 200 + + except Exception as e: + return jsonify({'error': f'获取用户详情失败: {str(e)}'}), 500 + +@admin_bp.route('/users', methods=['POST']) +@jwt_required() +@admin_required +def create_user(): + """创建用户""" + try: + data = request.get_json() + username = data.get('username') + password = data.get('password') + email = data.get('email') + role = data.get('role', 'user') + max_concurrent_tasks = data.get('max_concurrent_tasks', 0) + + if not username or not password: + return jsonify({'error': '用户名和密码不能为空'}), 400 + + # 检查用户名是否已存在 + if User.query.filter_by(username=username).first(): + return jsonify({'error': '用户名已存在'}), 400 + + # 检查邮箱是否已存在 + if email and User.query.filter_by(email=email).first(): + return jsonify({'error': '邮箱已被使用'}), 400 + + # 创建用户 + user = User( + username=username, + email=email, + role=role, + max_concurrent_tasks=max_concurrent_tasks + ) + user.set_password(password) + + db.session.add(user) + db.session.commit() + + return jsonify({ + 'message': '用户创建成功', + 'user': user.to_dict() + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'创建用户失败: {str(e)}'}), 500 + +@admin_bp.route('/users/', methods=['PUT']) +@jwt_required() +@admin_required +def update_user(user_id): + """更新用户信息""" + try: + user = User.query.get(user_id) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + data = request.get_json() + + # 更新字段 + if 'username' in data: + new_username = data['username'] + if new_username != user.username: + if User.query.filter_by(username=new_username).first(): + return jsonify({'error': '用户名已存在'}), 400 + user.username = new_username + + if 'email' in data: + new_email = data['email'] + if new_email != user.email: + if User.query.filter_by(email=new_email).first(): + return jsonify({'error': '邮箱已被使用'}), 400 + user.email = new_email + + if 'role' in data: + user.role = data['role'] + + if 'max_concurrent_tasks' in data: + user.max_concurrent_tasks = data['max_concurrent_tasks'] + + if 'is_active' in data: 
+ user.is_active = bool(data['is_active']) + + if 'password' in data and data['password']: + user.set_password(data['password']) + + db.session.commit() + + return jsonify({ + 'message': '用户信息更新成功', + 'user': user.to_dict() + }), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'更新用户失败: {str(e)}'}), 500 + +@admin_bp.route('/users/', methods=['DELETE']) +@jwt_required() +@admin_required +def delete_user(user_id): + """删除用户""" + try: + current_user_id = get_jwt_identity() + + # 不能删除自己 + if user_id == current_user_id: + return jsonify({'error': '不能删除自己的账户'}), 400 + + user = User.query.get(user_id) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + # 删除用户(级联删除相关数据) + db.session.delete(user) + db.session.commit() + + return jsonify({'message': '用户删除成功'}), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'删除用户失败: {str(e)}'}), 500 + +@admin_bp.route('/stats', methods=['GET']) +@jwt_required() +@admin_required +def get_system_stats(): + """获取系统统计信息""" + try: + from app.database import EvaluationResult + + total_users = User.query.count() + active_users = User.query.filter_by(is_active=True).count() + admin_users = User.query.filter_by(role='admin').count() + + total_tasks = Batch.query.count() + completed_tasks = Batch.query.filter_by(status='completed').count() + processing_tasks = Batch.query.filter_by(status='processing').count() + failed_tasks = Batch.query.filter_by(status='failed').count() + + total_images = Image.query.count() + total_evaluations = EvaluationResult.query.count() + + return jsonify({ + 'stats': { + 'users': { + 'total': total_users, + 'active': active_users, + 'admin': admin_users + }, + 'tasks': { + 'total': total_tasks, + 'completed': completed_tasks, + 'processing': processing_tasks, + 'failed': failed_tasks + }, + 'images': { + 'total': total_images + }, + 'evaluations': { + 'total': total_evaluations + } + } + }), 200 + + except Exception as e: return jsonify({'error': f'获取系统统计失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/controllers/auth_controller.py b/src/backend/app/controllers/auth_controller.py index 648780c..751a2a5 100644 --- a/src/backend/app/controllers/auth_controller.py +++ b/src/backend/app/controllers/auth_controller.py @@ -1,156 +1,156 @@ -""" -用户认证控制器 -处理注册、登录、密码修改等功能 -""" - -from flask import Blueprint, request, jsonify -from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity -from app import db -from app.models import User, UserConfig -from app.services.auth_service import AuthService -from functools import wraps -import re - -def int_jwt_required(f): - """获取JWT身份并转换为整数的装饰器""" - @wraps(f) - def wrapped(*args, **kwargs): - try: - current_user_id = int(get_jwt_identity()) - return f(*args, current_user_id=current_user_id, **kwargs) - except (TypeError, ValueError): - return jsonify({'error': '无效的用户身份标识'}), 401 - return jwt_required()(wrapped) - -auth_bp = Blueprint('auth', __name__) - -@auth_bp.route('/register', methods=['POST']) -def register(): - """用户注册""" - try: - data = request.get_json() - username = data.get('username') - password = data.get('password') - email = data.get('email') - - # 验证输入 - if not username or not password or not email: - return jsonify({'error': '用户名、密码和邮箱不能为空'}), 400 - - # 验证邮箱格式 - email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' - if not re.match(email_pattern, email): - return jsonify({'error': '邮箱格式不正确'}), 400 - - # 检查用户名是否已存在 - if 
User.query.filter_by(username=username).first(): - return jsonify({'error': '用户名已存在'}), 400 - - # 检查邮箱是否已注册 - if User.query.filter_by(email=email).first(): - return jsonify({'error': '该邮箱已被注册,同一邮箱只能注册一次'}), 400 - - # 创建用户 - user = User(username=username, email=email) - user.set_password(password) - - db.session.add(user) - db.session.commit() - - # 创建用户默认配置 - user_config = UserConfig(user_id=user.id) - db.session.add(user_config) - db.session.commit() - - return jsonify({ - 'message': '注册成功', - 'user': user.to_dict() - }), 201 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'注册失败: {str(e)}'}), 500 - -@auth_bp.route('/login', methods=['POST']) -def login(): - """用户登录""" - try: - data = request.get_json() - username = data.get('username') - password = data.get('password') - - if not username or not password: - return jsonify({'error': '用户名和密码不能为空'}), 400 - - # 查找用户 - user = User.query.filter_by(username=username).first() - - if not user or not user.check_password(password): - return jsonify({'error': '用户名或密码错误'}), 401 - - if not user.is_active: - return jsonify({'error': '账户已被禁用'}), 401 - - # 创建访问令牌 - 确保用户ID为字符串类型 - access_token = create_access_token(identity=str(user.id)) - - return jsonify({ - 'message': '登录成功', - 'access_token': access_token, - 'user': user.to_dict() - }), 200 - - except Exception as e: - return jsonify({'error': f'登录失败: {str(e)}'}), 500 - -@auth_bp.route('/change-password', methods=['POST']) -@int_jwt_required -def change_password(current_user_id): - """修改密码""" - try: - user = User.query.get(current_user_id) - - if not user: - return jsonify({'error': '用户不存在'}), 404 - - data = request.get_json() - old_password = data.get('old_password') - new_password = data.get('new_password') - - if not old_password or not new_password: - return jsonify({'error': '旧密码和新密码不能为空'}), 400 - - # 验证旧密码 - if not user.check_password(old_password): - return jsonify({'error': '旧密码错误'}), 401 - - # 设置新密码 - user.set_password(new_password) - db.session.commit() - - return jsonify({'message': '密码修改成功'}), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'密码修改失败: {str(e)}'}), 500 - -@auth_bp.route('/profile', methods=['GET']) -@int_jwt_required -def get_profile(current_user_id): - """获取用户信息""" - try: - user = User.query.get(current_user_id) - - if not user: - return jsonify({'error': '用户不存在'}), 404 - - return jsonify({'user': user.to_dict()}), 200 - - except Exception as e: - return jsonify({'error': f'获取用户信息失败: {str(e)}'}), 500 - -@auth_bp.route('/logout', methods=['POST']) -@jwt_required() -def logout(): - """用户登出(客户端删除token即可)""" +""" +用户认证控制器 +处理注册、登录、密码修改等功能 +""" + +from flask import Blueprint, request, jsonify +from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity +from app import db +from app.database import User, UserConfig +from app.services.auth_service import AuthService +from functools import wraps +import re + +def int_jwt_required(f): + """获取JWT身份并转换为整数的装饰器""" + @wraps(f) + def wrapped(*args, **kwargs): + try: + current_user_id = int(get_jwt_identity()) + return f(*args, current_user_id=current_user_id, **kwargs) + except (TypeError, ValueError): + return jsonify({'error': '无效的用户身份标识'}), 401 + return jwt_required()(wrapped) + +auth_bp = Blueprint('auth', __name__) + +@auth_bp.route('/register', methods=['POST']) +def register(): + """用户注册""" + try: + data = request.get_json() + username = data.get('username') + password = data.get('password') + email = data.get('email') + + # 验证输入 + if not 
username or not password or not email: + return jsonify({'error': '用户名、密码和邮箱不能为空'}), 400 + + # 验证邮箱格式 + email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + if not re.match(email_pattern, email): + return jsonify({'error': '邮箱格式不正确'}), 400 + + # 检查用户名是否已存在 + if User.query.filter_by(username=username).first(): + return jsonify({'error': '用户名已存在'}), 400 + + # 检查邮箱是否已注册 + if User.query.filter_by(email=email).first(): + return jsonify({'error': '该邮箱已被注册,同一邮箱只能注册一次'}), 400 + + # 创建用户 + user = User(username=username, email=email) + user.set_password(password) + + db.session.add(user) + db.session.commit() + + # 创建用户默认配置 + user_config = UserConfig(user_id=user.id) + db.session.add(user_config) + db.session.commit() + + return jsonify({ + 'message': '注册成功', + 'user': user.to_dict() + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'注册失败: {str(e)}'}), 500 + +@auth_bp.route('/login', methods=['POST']) +def login(): + """用户登录""" + try: + data = request.get_json() + username = data.get('username') + password = data.get('password') + + if not username or not password: + return jsonify({'error': '用户名和密码不能为空'}), 400 + + # 查找用户 + user = User.query.filter_by(username=username).first() + + if not user or not user.check_password(password): + return jsonify({'error': '用户名或密码错误'}), 401 + + if not user.is_active: + return jsonify({'error': '账户已被禁用'}), 401 + + # 创建访问令牌 - 确保用户ID为字符串类型 + access_token = create_access_token(identity=str(user.id)) + + return jsonify({ + 'message': '登录成功', + 'access_token': access_token, + 'user': user.to_dict() + }), 200 + + except Exception as e: + return jsonify({'error': f'登录失败: {str(e)}'}), 500 + +@auth_bp.route('/change-password', methods=['POST']) +@int_jwt_required +def change_password(current_user_id): + """修改密码""" + try: + user = User.query.get(current_user_id) + + if not user: + return jsonify({'error': '用户不存在'}), 404 + + data = request.get_json() + old_password = data.get('old_password') + new_password = data.get('new_password') + + if not old_password or not new_password: + return jsonify({'error': '旧密码和新密码不能为空'}), 400 + + # 验证旧密码 + if not user.check_password(old_password): + return jsonify({'error': '旧密码错误'}), 401 + + # 设置新密码 + user.set_password(new_password) + db.session.commit() + + return jsonify({'message': '密码修改成功'}), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'密码修改失败: {str(e)}'}), 500 + +@auth_bp.route('/profile', methods=['GET']) +@int_jwt_required +def get_profile(current_user_id): + """获取用户信息""" + try: + user = User.query.get(current_user_id) + + if not user: + return jsonify({'error': '用户不存在'}), 404 + + return jsonify({'user': user.to_dict()}), 200 + + except Exception as e: + return jsonify({'error': f'获取用户信息失败: {str(e)}'}), 500 + +@auth_bp.route('/logout', methods=['POST']) +@jwt_required() +def logout(): + """用户登出(客户端删除token即可)""" return jsonify({'message': '登出成功'}), 200 \ No newline at end of file diff --git a/src/backend/app/controllers/demo_controller.py b/src/backend/app/controllers/demo_controller.py index 482958e..949f8a7 100644 --- a/src/backend/app/controllers/demo_controller.py +++ b/src/backend/app/controllers/demo_controller.py @@ -1,177 +1,177 @@ -""" -演示图片控制器 -处理预设图像对比图的展示功能 -""" - -from flask import Blueprint, send_file, jsonify, current_app -from flask_jwt_extended import jwt_required -from app.models import PerturbationConfig, FinetuneConfig -import os -import glob - -demo_bp = Blueprint('demo', __name__) - -@demo_bp.route('/images', 
methods=['GET']) -def list_demo_images(): - """获取所有演示图片列表""" - try: - demo_images = [] - - # 获取演示原始图片 - 修正路径构建 - # 获取项目根目录(backend目录) - project_root = os.path.dirname(current_app.root_path) - original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER']) - - if os.path.exists(original_folder): - original_files = glob.glob(os.path.join(original_folder, '*')) - for file_path in original_files: - if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')): - filename = os.path.basename(file_path) - name_without_ext = os.path.splitext(filename)[0] - - # 查找对应的加噪图片 - perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER']) - perturbed_files = glob.glob(os.path.join(perturbed_folder, f"{name_without_ext}*")) - - # 查找对应的对比图 - comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER']) - comparison_files = glob.glob(os.path.join(comparison_folder, f"{name_without_ext}*")) - - demo_image = { - 'id': name_without_ext, - 'name': name_without_ext, - 'original': f"/api/demo/image/original/{filename}", - 'perturbed': [f"/api/demo/image/perturbed/{os.path.basename(f)}" for f in perturbed_files], - 'comparisons': [f"/api/demo/image/comparison/{os.path.basename(f)}" for f in comparison_files] - } - demo_images.append(demo_image) - - return jsonify({ - 'demo_images': demo_images, - 'total': len(demo_images) - }), 200 - - except Exception as e: - return jsonify({'error': f'获取演示图片列表失败: {str(e)}'}), 500 - -@demo_bp.route('/image/original/', methods=['GET']) -def get_demo_original_image(filename): - """获取演示原始图片""" - try: - project_root = os.path.dirname(current_app.root_path) - file_path = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'], filename) - - if not os.path.exists(file_path): - return jsonify({'error': '图片不存在'}), 404 - - return send_file(file_path, as_attachment=False) - - except Exception as e: - return jsonify({'error': f'获取原始图片失败: {str(e)}'}), 500 - -@demo_bp.route('/image/perturbed/', methods=['GET']) -def get_demo_perturbed_image(filename): - """获取演示加噪图片""" - try: - project_root = os.path.dirname(current_app.root_path) - file_path = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'], filename) - - if not os.path.exists(file_path): - return jsonify({'error': '图片不存在'}), 404 - - return send_file(file_path, as_attachment=False) - - except Exception as e: - return jsonify({'error': f'获取加噪图片失败: {str(e)}'}), 500 - -@demo_bp.route('/image/comparison/', methods=['GET']) -def get_demo_comparison_image(filename): - """获取演示对比图片""" - try: - project_root = os.path.dirname(current_app.root_path) - file_path = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'], filename) - - if not os.path.exists(file_path): - return jsonify({'error': '图片不存在'}), 404 - - return send_file(file_path, as_attachment=False) - - except Exception as e: - return jsonify({'error': f'获取对比图片失败: {str(e)}'}), 500 - -@demo_bp.route('/algorithms', methods=['GET']) -def get_demo_algorithms(): - """获取演示算法信息""" - try: - # 从数据库获取扰动算法 - perturbation_algorithms = [] - perturbation_configs = PerturbationConfig.query.all() - for config in perturbation_configs: - perturbation_algorithms.append({ - 'id': config.id, - 'code': config.method_code, - 'name': config.method_name, - 'type': 'perturbation', - 'description': config.description, - 'default_epsilon': float(config.default_epsilon) if config.default_epsilon else None - }) - - # 从数据库获取微调算法 - finetune_algorithms = [] - finetune_configs 
= FinetuneConfig.query.all() - for config in finetune_configs: - finetune_algorithms.append({ - 'id': config.id, - 'code': config.method_code, - 'name': config.method_name, - 'type': 'finetune', - 'description': config.description - }) - - return jsonify({ - 'perturbation_algorithms': perturbation_algorithms, - 'finetune_algorithms': finetune_algorithms, - 'evaluation_metrics': [ - {'name': 'FID', 'description': 'Fréchet Inception Distance - 衡量图像质量的指标'}, - {'name': 'LPIPS', 'description': 'Learned Perceptual Image Patch Similarity - 感知相似度'}, - {'name': 'SSIM', 'description': 'Structural Similarity Index - 结构相似性指标'}, - {'name': 'PSNR', 'description': 'Peak Signal-to-Noise Ratio - 峰值信噪比'} - ] - }), 200 - - except Exception as e: - return jsonify({'error': f'获取算法信息失败: {str(e)}'}), 500 - -@demo_bp.route('/stats', methods=['GET']) -def get_demo_stats(): - """获取演示统计信息""" - try: - # 统计演示图片数量 - project_root = os.path.dirname(current_app.root_path) - original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER']) - perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER']) - comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER']) - - original_count = len(glob.glob(os.path.join(original_folder, '*'))) if os.path.exists(original_folder) else 0 - perturbed_count = len(glob.glob(os.path.join(perturbed_folder, '*'))) if os.path.exists(perturbed_folder) else 0 - comparison_count = len(glob.glob(os.path.join(comparison_folder, '*'))) if os.path.exists(comparison_folder) else 0 - - # 统计数据库中的算法数量 - perturbation_count = PerturbationConfig.query.count() - finetune_count = FinetuneConfig.query.count() - total_algorithms = perturbation_count + finetune_count - - return jsonify({ - 'demo_stats': { - 'original_images': original_count, - 'perturbed_images': perturbed_count, - 'comparison_images': comparison_count, - 'supported_algorithms': total_algorithms, - 'perturbation_algorithms': perturbation_count, - 'finetune_algorithms': finetune_count, - 'evaluation_metrics': 4 - } - }), 200 - - except Exception as e: +""" +演示图片控制器 +处理预设图像对比图的展示功能 +""" + +from flask import Blueprint, send_file, jsonify, current_app +from flask_jwt_extended import jwt_required +from app.database import PerturbationConfig, FinetuneConfig +import os +import glob + +demo_bp = Blueprint('demo', __name__) + +@demo_bp.route('/images', methods=['GET']) +def list_demo_images(): + """获取所有演示图片列表""" + try: + demo_images = [] + + # 获取演示原始图片 - 修正路径构建 + # 获取项目根目录(backend目录) + project_root = os.path.dirname(current_app.root_path) + original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER']) + + if os.path.exists(original_folder): + original_files = glob.glob(os.path.join(original_folder, '*')) + for file_path in original_files: + if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')): + filename = os.path.basename(file_path) + name_without_ext = os.path.splitext(filename)[0] + + # 查找对应的加噪图片 + perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER']) + perturbed_files = glob.glob(os.path.join(perturbed_folder, f"{name_without_ext}*")) + + # 查找对应的对比图 + comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER']) + comparison_files = glob.glob(os.path.join(comparison_folder, f"{name_without_ext}*")) + + demo_image = { + 'id': name_without_ext, + 'name': name_without_ext, + 'original': f"/api/demo/image/original/{filename}", + 'perturbed': 
[f"/api/demo/image/perturbed/{os.path.basename(f)}" for f in perturbed_files], + 'comparisons': [f"/api/demo/image/comparison/{os.path.basename(f)}" for f in comparison_files] + } + demo_images.append(demo_image) + + return jsonify({ + 'demo_images': demo_images, + 'total': len(demo_images) + }), 200 + + except Exception as e: + return jsonify({'error': f'获取演示图片列表失败: {str(e)}'}), 500 + +@demo_bp.route('/image/original/', methods=['GET']) +def get_demo_original_image(filename): + """获取演示原始图片""" + try: + project_root = os.path.dirname(current_app.root_path) + file_path = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'], filename) + + if not os.path.exists(file_path): + return jsonify({'error': '图片不存在'}), 404 + + return send_file(file_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取原始图片失败: {str(e)}'}), 500 + +@demo_bp.route('/image/perturbed/', methods=['GET']) +def get_demo_perturbed_image(filename): + """获取演示加噪图片""" + try: + project_root = os.path.dirname(current_app.root_path) + file_path = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'], filename) + + if not os.path.exists(file_path): + return jsonify({'error': '图片不存在'}), 404 + + return send_file(file_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取加噪图片失败: {str(e)}'}), 500 + +@demo_bp.route('/image/comparison/', methods=['GET']) +def get_demo_comparison_image(filename): + """获取演示对比图片""" + try: + project_root = os.path.dirname(current_app.root_path) + file_path = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'], filename) + + if not os.path.exists(file_path): + return jsonify({'error': '图片不存在'}), 404 + + return send_file(file_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取对比图片失败: {str(e)}'}), 500 + +@demo_bp.route('/algorithms', methods=['GET']) +def get_demo_algorithms(): + """获取演示算法信息""" + try: + # 从数据库获取扰动算法 + perturbation_algorithms = [] + perturbation_configs = PerturbationConfig.query.all() + for config in perturbation_configs: + perturbation_algorithms.append({ + 'id': config.id, + 'code': config.method_code, + 'name': config.method_name, + 'type': 'perturbation', + 'description': config.description, + 'default_epsilon': float(config.default_epsilon) if config.default_epsilon else None + }) + + # 从数据库获取微调算法 + finetune_algorithms = [] + finetune_configs = FinetuneConfig.query.all() + for config in finetune_configs: + finetune_algorithms.append({ + 'id': config.id, + 'code': config.method_code, + 'name': config.method_name, + 'type': 'finetune', + 'description': config.description + }) + + return jsonify({ + 'perturbation_algorithms': perturbation_algorithms, + 'finetune_algorithms': finetune_algorithms, + 'evaluation_metrics': [ + {'name': 'FID', 'description': 'Fréchet Inception Distance - 衡量图像质量的指标'}, + {'name': 'LPIPS', 'description': 'Learned Perceptual Image Patch Similarity - 感知相似度'}, + {'name': 'SSIM', 'description': 'Structural Similarity Index - 结构相似性指标'}, + {'name': 'PSNR', 'description': 'Peak Signal-to-Noise Ratio - 峰值信噪比'} + ] + }), 200 + + except Exception as e: + return jsonify({'error': f'获取算法信息失败: {str(e)}'}), 500 + +@demo_bp.route('/stats', methods=['GET']) +def get_demo_stats(): + """获取演示统计信息""" + try: + # 统计演示图片数量 + project_root = os.path.dirname(current_app.root_path) + original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER']) + perturbed_folder = os.path.join(project_root, 
current_app.config['DEMO_PERTURBED_FOLDER']) + comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER']) + + original_count = len(glob.glob(os.path.join(original_folder, '*'))) if os.path.exists(original_folder) else 0 + perturbed_count = len(glob.glob(os.path.join(perturbed_folder, '*'))) if os.path.exists(perturbed_folder) else 0 + comparison_count = len(glob.glob(os.path.join(comparison_folder, '*'))) if os.path.exists(comparison_folder) else 0 + + # 统计数据库中的算法数量 + perturbation_count = PerturbationConfig.query.count() + finetune_count = FinetuneConfig.query.count() + total_algorithms = perturbation_count + finetune_count + + return jsonify({ + 'demo_stats': { + 'original_images': original_count, + 'perturbed_images': perturbed_count, + 'comparison_images': comparison_count, + 'supported_algorithms': total_algorithms, + 'perturbation_algorithms': perturbation_count, + 'finetune_algorithms': finetune_count, + 'evaluation_metrics': 4 + } + }), 200 + + except Exception as e: return jsonify({'error': f'获取统计信息失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/controllers/image_controller.py b/src/backend/app/controllers/image_controller.py index f47fe98..d52a4a9 100644 --- a/src/backend/app/controllers/image_controller.py +++ b/src/backend/app/controllers/image_controller.py @@ -1,203 +1,203 @@ -""" -图像管理控制器 -处理图像下载、查看等功能 -""" - -from flask import Blueprint, send_file, jsonify, request, current_app -from flask_jwt_extended import jwt_required, get_jwt_identity -from app.models import Image, EvaluationResult -from app.services.image_service import ImageService -import os - -image_bp = Blueprint('image', __name__) - -@image_bp.route('/file/', methods=['GET']) -@jwt_required() -def get_image_file(image_id): - """获取图片文件""" - try: - current_user_id = get_jwt_identity() - - # 查找图片记录 - image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() - if not image: - return jsonify({'error': '图片不存在或无权限'}), 404 - - # 检查文件是否存在 - if not os.path.exists(image.file_path): - return jsonify({'error': '图片文件不存在'}), 404 - - return send_file(image.file_path, as_attachment=False) - - except Exception as e: - return jsonify({'error': f'获取图片失败: {str(e)}'}), 500 - -@image_bp.route('/download/', methods=['GET']) -@jwt_required() -def download_image(image_id): - """下载图片文件""" - try: - current_user_id = get_jwt_identity() - - image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() - if not image: - return jsonify({'error': '图片不存在或无权限'}), 404 - - if not os.path.exists(image.file_path): - return jsonify({'error': '图片文件不存在'}), 404 - - return send_file( - image.file_path, - as_attachment=True, - download_name=image.original_filename or f"image_{image_id}.jpg" - ) - - except Exception as e: - return jsonify({'error': f'下载图片失败: {str(e)}'}), 500 - -@image_bp.route('/batch//download', methods=['GET']) -@jwt_required() -def download_batch_images(batch_id): - """批量下载任务中的加噪后图片""" - try: - current_user_id = get_jwt_identity() - - # 获取任务中的加噪图片 - perturbed_images = Image.query.join(Image.image_type).filter( - Image.batch_id == batch_id, - Image.user_id == current_user_id, - Image.image_type.has(type_code='perturbed') - ).all() - - if not perturbed_images: - return jsonify({'error': '没有找到加噪后的图片'}), 404 - - # 创建ZIP文件 - import zipfile - import tempfile - - with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file: - with zipfile.ZipFile(tmp_file.name, 'w') as zip_file: - for image in perturbed_images: - if 
os.path.exists(image.file_path): - arcname = image.original_filename or f"perturbed_{image.id}.jpg" - zip_file.write(image.file_path, arcname) - - return send_file( - tmp_file.name, - as_attachment=True, - download_name=f"batch_{batch_id}_perturbed_images.zip", - mimetype='application/zip' - ) - - except Exception as e: - return jsonify({'error': f'批量下载失败: {str(e)}'}), 500 - -@image_bp.route('//evaluations', methods=['GET']) -@jwt_required() -def get_image_evaluations(image_id): - """获取图片的评估结果""" - try: - current_user_id = get_jwt_identity() - - # 验证图片权限 - image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() - if not image: - return jsonify({'error': '图片不存在或无权限'}), 404 - - # 获取以该图片为参考或目标的评估结果 - evaluations = EvaluationResult.query.filter( - (EvaluationResult.reference_image_id == image_id) | - (EvaluationResult.target_image_id == image_id) - ).all() - - return jsonify({ - 'image_id': image_id, - 'evaluations': [eval_result.to_dict() for eval_result in evaluations] - }), 200 - - except Exception as e: - return jsonify({'error': f'获取评估结果失败: {str(e)}'}), 500 - -@image_bp.route('/compare', methods=['POST']) -@jwt_required() -def compare_images(): - """对比两张图片""" - try: - current_user_id = get_jwt_identity() - data = request.get_json() - - image1_id = data.get('image1_id') - image2_id = data.get('image2_id') - - if not image1_id or not image2_id: - return jsonify({'error': '请提供两张图片的ID'}), 400 - - # 验证图片权限 - image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first() - image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first() - - if not image1 or not image2: - return jsonify({'error': '图片不存在或无权限'}), 404 - - # 查找现有的评估结果 - evaluation = EvaluationResult.query.filter_by( - reference_image_id=image1_id, - target_image_id=image2_id - ).first() - - if not evaluation: - # 如果没有评估结果,返回基本对比信息 - return jsonify({ - 'image1': image1.to_dict(), - 'image2': image2.to_dict(), - 'evaluation': None, - 'message': '暂无评估数据,请等待任务处理完成' - }), 200 - - return jsonify({ - 'image1': image1.to_dict(), - 'image2': image2.to_dict(), - 'evaluation': evaluation.to_dict() - }), 200 - - except Exception as e: - return jsonify({'error': f'图片对比失败: {str(e)}'}), 500 - -@image_bp.route('/heatmap/', methods=['GET']) -@jwt_required() -def get_heatmap(heatmap_path): - """获取热力图文件""" - try: - # 安全检查,防止路径遍历攻击 - if '..' 
in heatmap_path or heatmap_path.startswith('/'): - return jsonify({'error': '无效的文件路径'}), 400 - - # 修正路径构建 - 获取项目根目录(backend目录) - project_root = os.path.dirname(current_app.root_path) - full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path)) - - if not os.path.exists(full_path): - return jsonify({'error': '热力图文件不存在'}), 404 - - return send_file(full_path, as_attachment=False) - - except Exception as e: - return jsonify({'error': f'获取热力图失败: {str(e)}'}), 500 - -@image_bp.route('/delete/', methods=['DELETE']) -@jwt_required() -def delete_image(image_id): - """删除图片""" - try: - current_user_id = get_jwt_identity() - - result = ImageService.delete_image(image_id, current_user_id) - - if result['success']: - return jsonify({'message': '图片删除成功'}), 200 - else: - return jsonify({'error': result['error']}), 400 - - except Exception as e: +""" +图像管理控制器 +处理图像下载、查看等功能 +""" + +from flask import Blueprint, send_file, jsonify, request, current_app +from flask_jwt_extended import jwt_required, get_jwt_identity +from app.database import Image, EvaluationResult +from app.services.image_service import ImageService +import os + +image_bp = Blueprint('image', __name__) + +@image_bp.route('/file/<int:image_id>', methods=['GET']) +@jwt_required() +def get_image_file(image_id): + """获取图片文件""" + try: + current_user_id = get_jwt_identity() + + # 查找图片记录 + image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() + if not image: + return jsonify({'error': '图片不存在或无权限'}), 404 + + # 检查文件是否存在 + if not os.path.exists(image.file_path): + return jsonify({'error': '图片文件不存在'}), 404 + + return send_file(image.file_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取图片失败: {str(e)}'}), 500 + +@image_bp.route('/download/<int:image_id>', methods=['GET']) +@jwt_required() +def download_image(image_id): + """下载图片文件""" + try: + current_user_id = get_jwt_identity() + + image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() + if not image: + return jsonify({'error': '图片不存在或无权限'}), 404 + + if not os.path.exists(image.file_path): + return jsonify({'error': '图片文件不存在'}), 404 + + return send_file( + image.file_path, + as_attachment=True, + download_name=image.original_filename or f"image_{image_id}.jpg" + ) + + except Exception as e: + return jsonify({'error': f'下载图片失败: {str(e)}'}), 500 + +@image_bp.route('/batch/<int:batch_id>/download', methods=['GET']) +@jwt_required() +def download_batch_images(batch_id): + """批量下载任务中的加噪后图片""" + try: + current_user_id = get_jwt_identity() + + # 获取任务中的加噪图片 + perturbed_images = Image.query.join(Image.image_type).filter( + Image.batch_id == batch_id, + Image.user_id == current_user_id, + Image.image_type.has(type_code='perturbed') + ).all() + + if not perturbed_images: + return jsonify({'error': '没有找到加噪后的图片'}), 404 + + # 创建ZIP文件 + import zipfile + import tempfile + + with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file: + with zipfile.ZipFile(tmp_file.name, 'w') as zip_file: + for image in perturbed_images: + if os.path.exists(image.file_path): + arcname = image.original_filename or f"perturbed_{image.id}.jpg" + zip_file.write(image.file_path, arcname) + + return send_file( + tmp_file.name, + as_attachment=True, + download_name=f"batch_{batch_id}_perturbed_images.zip", + mimetype='application/zip' + ) + + except Exception as e: + return jsonify({'error': f'批量下载失败: {str(e)}'}), 500 +
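# download_batch_images above writes the archive via NamedTemporaryFile with
# delete=False and never removes the file afterwards, so every request leaks
# one .zip in the temp directory. A sketch of an in-memory alternative for
# modest batch sizes (an assumed refactor, not part of this patch); Flask's
# send_file accepts a BytesIO plus download_name in Flask >= 2.0:
import io
import os
import zipfile
from flask import send_file

def build_zip_response(perturbed_images, batch_id):
    """Zip the existing image files in memory and stream the archive."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        for image in perturbed_images:
            if os.path.exists(image.file_path):
                arcname = image.original_filename or f"perturbed_{image.id}.jpg"
                zip_file.write(image.file_path, arcname)
    buffer.seek(0)  # rewind before streaming
    return send_file(
        buffer,
        as_attachment=True,
        download_name=f"batch_{batch_id}_perturbed_images.zip",
        mimetype='application/zip'
    )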
+@image_bp.route('/<int:image_id>/evaluations', methods=['GET']) +@jwt_required() +def get_image_evaluations(image_id): + """获取图片的评估结果""" + try: + current_user_id = get_jwt_identity() + + # 验证图片权限 + image = Image.query.filter_by(id=image_id, user_id=current_user_id).first() + if not image: + return jsonify({'error': '图片不存在或无权限'}), 404 + + # 获取以该图片为参考或目标的评估结果 + evaluations = EvaluationResult.query.filter( + (EvaluationResult.reference_image_id == image_id) | + (EvaluationResult.target_image_id == image_id) + ).all() + + return jsonify({ + 'image_id': image_id, + 'evaluations': [eval_result.to_dict() for eval_result in evaluations] + }), 200 + + except Exception as e: + return jsonify({'error': f'获取评估结果失败: {str(e)}'}), 500 + +@image_bp.route('/compare', methods=['POST']) +@jwt_required() +def compare_images(): + """对比两张图片""" + try: + current_user_id = get_jwt_identity() + data = request.get_json() + + image1_id = data.get('image1_id') + image2_id = data.get('image2_id') + + if not image1_id or not image2_id: + return jsonify({'error': '请提供两张图片的ID'}), 400 + + # 验证图片权限 + image1 = Image.query.filter_by(id=image1_id, user_id=current_user_id).first() + image2 = Image.query.filter_by(id=image2_id, user_id=current_user_id).first() + + if not image1 or not image2: + return jsonify({'error': '图片不存在或无权限'}), 404 + + # 查找现有的评估结果 + evaluation = EvaluationResult.query.filter_by( + reference_image_id=image1_id, + target_image_id=image2_id + ).first() + + if not evaluation: + # 如果没有评估结果,返回基本对比信息 + return jsonify({ + 'image1': image1.to_dict(), + 'image2': image2.to_dict(), + 'evaluation': None, + 'message': '暂无评估数据,请等待任务处理完成' + }), 200 + + return jsonify({ + 'image1': image1.to_dict(), + 'image2': image2.to_dict(), + 'evaluation': evaluation.to_dict() + }), 200 + + except Exception as e: + return jsonify({'error': f'图片对比失败: {str(e)}'}), 500 +
+@image_bp.route('/heatmap/<path:heatmap_path>', methods=['GET']) +@jwt_required() +def get_heatmap(heatmap_path): + """获取热力图文件""" + try: + # 安全检查,防止路径遍历攻击 + if '..' in heatmap_path or heatmap_path.startswith('/'): + return jsonify({'error': '无效的文件路径'}), 400 + + # 修正路径构建 - 获取项目根目录(backend目录) + project_root = os.path.dirname(current_app.root_path) + full_path = os.path.join(project_root, 'static', 'heatmaps', os.path.basename(heatmap_path)) + + if not os.path.exists(full_path): + return jsonify({'error': '热力图文件不存在'}), 404 + + return send_file(full_path, as_attachment=False) + + except Exception as e: + return jsonify({'error': f'获取热力图失败: {str(e)}'}), 500 + +@image_bp.route('/delete/<int:image_id>', methods=['DELETE']) +@jwt_required() +def delete_image(image_id): + """删除图片""" + try: + current_user_id = get_jwt_identity() + + result = ImageService.delete_image(image_id, current_user_id) + + if result['success']: + return jsonify({'message': '图片删除成功'}), 200 + else: + return jsonify({'error': result['error']}), 400 + + except Exception as e: return jsonify({'error': f'删除图片失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index 3cb8f73..4a57f8b 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -1,400 +1,319 @@ -""" -任务管理控制器 -处理创建任务、上传图片等功能 -""" - -from flask import Blueprint, request, jsonify, current_app -from flask_jwt_extended import jwt_required, get_jwt_identity -from werkzeug.utils import secure_filename -from app import db -from app.models import User, Batch, Image, ImageType, UserConfig -from app.services.task_service import TaskService -from app.services.image_service import ImageService -from app.utils.file_utils import allowed_file, save_uploaded_file -import os -import zipfile -import uuid - -task_bp = Blueprint('task', __name__) - -@task_bp.route('/create', methods=['POST']) -@jwt_required() -def create_task(): - """创建新任务(仅创建任务,使用默认配置)""" - try: - current_user_id = get_jwt_identity() - user = User.query.get(current_user_id) - - if not user: - return jsonify({'error': '用户不存在'}), 404 - - data = request.get_json() - batch_name = data.get('batch_name', f'Task-{uuid.uuid4().hex[:8]}') - - # 使用默认配置创建任务 - batch = Batch( - user_id=current_user_id, - batch_name=batch_name, - perturbation_config_id=1, # 默认配置 - preferred_epsilon=8.0, # 默认epsilon - finetune_config_id=1, # 默认微调配置 - use_strong_protection=False # 默认不启用强防护 - ) - - db.session.add(batch) - db.session.commit() - - return jsonify({ - 'message': '任务创建成功,请上传图片', - 'task': batch.to_dict() - }), 201 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'任务创建失败: {str(e)}'}), 500 - -@task_bp.route('/upload/', methods=['POST']) -@jwt_required() -def upload_images(batch_id): - """上传图片到指定任务""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - if batch.status != 'pending': - return jsonify({'error': '任务已开始处理,无法上传新图片'}), 400 - - if 'files' not in request.files: - return jsonify({'error': '没有选择文件'}), 400 - - files = request.files.getlist('files') - uploaded_files = [] - - # 获取原始图片类型ID - original_type = ImageType.query.filter_by(type_code='original').first() - if not original_type: - return jsonify({'error': '系统配置错误:缺少原始图片类型'}), 500 - - for file in files: - if file.filename == '': - continue - - if file and allowed_file(file.filename): - # 处理单张图片 - if not file.filename.lower().endswith(('.zip', '.rar')): - result = ImageService.save_image(file, batch_id, 
current_user_id, original_type.id) - if result['success']: - uploaded_files.append(result['image']) - else: - return jsonify({'error': result['error']}), 400 - - # 处理压缩包 - else: - results = ImageService.extract_and_save_zip(file, batch_id, current_user_id, original_type.id) - for result in results: - if result['success']: - uploaded_files.append(result['image']) - - if not uploaded_files: - return jsonify({'error': '没有有效的图片文件'}), 400 - - return jsonify({ - 'message': f'成功上传 {len(uploaded_files)} 张图片', - 'uploaded_files': [img.to_dict() for img in uploaded_files] - }), 200 - - except Exception as e: - return jsonify({'error': f'文件上传失败: {str(e)}'}), 500 - -@task_bp.route('//config', methods=['GET']) -@jwt_required() -def get_task_config(batch_id): - """获取任务配置(显示用户上次的配置或默认配置)""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - # 获取用户配置 - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - # 如果用户有配置,显示用户上次的配置;否则显示当前任务的默认配置 - if user_config: - suggested_config = { - 'perturbation_config_id': user_config.preferred_perturbation_config_id, - 'epsilon': float(user_config.preferred_epsilon), - 'finetune_config_id': user_config.preferred_finetune_config_id, - 'use_strong_protection': user_config.preferred_purification - } - else: - suggested_config = { - 'perturbation_config_id': batch.perturbation_config_id, - 'epsilon': float(batch.preferred_epsilon), - 'finetune_config_id': batch.finetune_config_id, - 'use_strong_protection': batch.use_strong_protection - } - - return jsonify({ - 'task': batch.to_dict(), - 'suggested_config': suggested_config, - 'current_config': { - 'perturbation_config_id': batch.perturbation_config_id, - 'epsilon': float(batch.preferred_epsilon), - 'finetune_config_id': batch.finetune_config_id, - 'use_strong_protection': batch.use_strong_protection - } - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务配置失败: {str(e)}'}), 500 - -@task_bp.route('//config', methods=['PUT']) -@jwt_required() -def update_task_config(batch_id): - """更新任务配置""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - if batch.status != 'pending': - return jsonify({'error': '任务已开始处理,无法修改配置'}), 400 - - data = request.get_json() - - # 更新任务配置 - if 'perturbation_config_id' in data: - batch.perturbation_config_id = data['perturbation_config_id'] - - if 'epsilon' in data: - epsilon = float(data['epsilon']) - if 0 < epsilon <= 255: - batch.preferred_epsilon = epsilon - else: - return jsonify({'error': '扰动强度必须在0-255之间'}), 400 - - if 'finetune_config_id' in data: - batch.finetune_config_id = data['finetune_config_id'] - - if 'use_strong_protection' in data: - batch.use_strong_protection = bool(data['use_strong_protection']) - - db.session.commit() - - # 更新用户配置(保存这次的选择作为下次的默认配置) - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - if not user_config: - user_config = UserConfig(user_id=current_user_id) - db.session.add(user_config) - - user_config.preferred_perturbation_config_id = batch.perturbation_config_id - user_config.preferred_epsilon = batch.preferred_epsilon - user_config.preferred_finetune_config_id = batch.finetune_config_id - user_config.preferred_purification = batch.use_strong_protection - - db.session.commit() - - 
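# The two /config endpoints above define the configuration handshake: GET
# returns suggested_config (the user's last choices) alongside current_config,
# and PUT persists new values while the task is still pending. A client-side
# sketch using the requests library; the base URL and the /api/task mount
# point are assumptions, since blueprint registration is outside this hunk:
import requests

BASE = 'http://localhost:5000/api/task'  # assumed mount point

def configure_task(token, batch_id, epsilon=8.0):
    headers = {'Authorization': f'Bearer {token}'}
    # Start from the server-suggested configuration for this pending task
    suggested = requests.get(f'{BASE}/{batch_id}/config', headers=headers).json()['suggested_config']
    suggested['epsilon'] = epsilon  # server enforces 0 < epsilon <= 255
    # Persist the choice; it also becomes the default for the next task
    resp = requests.put(f'{BASE}/{batch_id}/config', headers=headers, json=suggested)
    resp.raise_for_status()
    return resp.json()['task']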
return jsonify({ - 'message': '任务配置更新成功', - 'task': batch.to_dict() - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'更新任务配置失败: {str(e)}'}), 500 - -@task_bp.route('/start/', methods=['POST']) -@jwt_required() -def start_task(batch_id): - """开始处理任务""" - try: - current_user_id = get_jwt_identity() - - # 检查任务是否存在且属于当前用户 - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - if batch.status != 'pending': - return jsonify({'error': '任务状态不正确,无法开始处理'}), 400 - - # 检查是否有上传的图片 - image_count = Image.query.filter_by(batch_id=batch_id).count() - if image_count == 0: - return jsonify({'error': '请先上传图片'}), 400 - - # 启动任务处理 - success = TaskService.start_processing(batch) - - if success: - return jsonify({ - 'message': '任务开始处理', - 'task': batch.to_dict() - }), 200 - else: - return jsonify({'error': '任务启动失败'}), 500 - - except Exception as e: - return jsonify({'error': f'任务启动失败: {str(e)}'}), 500 - -@task_bp.route('/list', methods=['GET']) -@jwt_required() -def list_tasks(): - """获取用户的任务列表""" - try: - current_user_id = get_jwt_identity() - - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 10, type=int) - - batches = Batch.query.filter_by(user_id=current_user_id)\ - .order_by(Batch.created_at.desc())\ - .paginate(page=page, per_page=per_page, error_out=False) - - return jsonify({ - 'tasks': [batch.to_dict() for batch in batches.items], - 'total': batches.total, - 'pages': batches.pages, - 'current_page': page - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务列表失败: {str(e)}'}), 500 - -@task_bp.route('/', methods=['GET']) -@jwt_required() -def get_task_detail(batch_id): - """获取任务详情""" - try: - current_user_id = get_jwt_identity() - - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - # 获取任务相关的图片 - images = Image.query.filter_by(batch_id=batch_id).all() - - return jsonify({ - 'task': batch.to_dict(), - 'images': [img.to_dict() for img in images], - 'image_count': len(images) - }), 200 - - except Exception as e: - return jsonify({'error': f'获取任务详情失败: {str(e)}'}), 500 - -@task_bp.route('/load-config', methods=['GET']) -@jwt_required() -def load_last_config(): - """加载用户上次的配置""" - try: - current_user_id = get_jwt_identity() - - # 获取用户配置 - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - if user_config: - config = { - 'perturbation_config_id': user_config.preferred_perturbation_config_id, - 'epsilon': float(user_config.preferred_epsilon), - 'finetune_config_id': user_config.preferred_finetune_config_id, - 'use_strong_protection': user_config.preferred_purification - } - return jsonify({ - 'message': '成功加载上次配置', - 'config': config - }), 200 - else: - # 返回默认配置 - default_config = { - 'perturbation_config_id': 1, - 'epsilon': 8.0, - 'finetune_config_id': 1, - 'use_strong_protection': False - } - return jsonify({ - 'message': '使用默认配置', - 'config': default_config - }), 200 - - except Exception as e: - return jsonify({'error': f'加载配置失败: {str(e)}'}), 500 - -@task_bp.route('/save-config', methods=['POST']) -@jwt_required() -def save_current_config(): - """保存当前配置作为用户偏好""" - try: - current_user_id = get_jwt_identity() - data = request.get_json() - - # 获取或创建用户配置 - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - if not user_config: - user_config = UserConfig(user_id=current_user_id) - 
db.session.add(user_config) - - # 更新配置 - if 'perturbation_config_id' in data: - user_config.preferred_perturbation_config_id = data['perturbation_config_id'] - - if 'epsilon' in data: - epsilon = float(data['epsilon']) - if 0 < epsilon <= 255: - user_config.preferred_epsilon = epsilon - else: - return jsonify({'error': '扰动强度必须在0-255之间'}), 400 - - if 'finetune_config_id' in data: - user_config.preferred_finetune_config_id = data['finetune_config_id'] - - if 'use_strong_protection' in data: - user_config.preferred_purification = bool(data['use_strong_protection']) - - db.session.commit() - - return jsonify({ - 'message': '配置保存成功', - 'config': { - 'perturbation_config_id': user_config.preferred_perturbation_config_id, - 'epsilon': float(user_config.preferred_epsilon), - 'finetune_config_id': user_config.preferred_finetune_config_id, - 'use_strong_protection': user_config.preferred_purification - } - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'保存配置失败: {str(e)}'}), 500 - -@task_bp.route('//status', methods=['GET']) -@jwt_required() -def get_task_status(batch_id): - """获取任务处理状态""" - try: - current_user_id = get_jwt_identity() - - batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() - if not batch: - return jsonify({'error': '任务不存在或无权限'}), 404 - - return jsonify({ - 'task_id': batch_id, - 'status': batch.status, - 'progress': TaskService.get_processing_progress(batch_id), - 'error_message': batch.error_message - }), 200 - - except Exception as e: +""" +任务管理控制器 +处理创建任务、上传图片等功能 +""" + +from flask import Blueprint, request, jsonify, current_app +from flask_jwt_extended import jwt_required, get_jwt_identity +from werkzeug.utils import secure_filename +from app import db +from app.database import User, Batch, Image, ImageType, UserConfig +from app.services.task_service import TaskService +from app.services.image_service import ImageService +from app.utils.file_utils import allowed_file, save_uploaded_file +import os +import zipfile +import uuid + +task_bp = Blueprint('task', __name__) + +@task_bp.route('/create', methods=['POST']) +@jwt_required() +def create_task(): + """创建新任务(使用用户配置作为默认配置)""" + try: + current_user_id = get_jwt_identity() + user = User.query.get(current_user_id) + + if not user: + return jsonify({'error': '用户不存在'}), 404 + + data = request.get_json() + batch_name = data.get('batch_name', f'Task-{uuid.uuid4().hex[:8]}') + + # 获取用户配置作为默认配置 + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + + if user_config: + # 使用用户上次的配置 + perturbation_config_id = user_config.preferred_perturbation_config_id or 1 + preferred_epsilon = user_config.preferred_epsilon or 8.0 + finetune_config_id = user_config.preferred_finetune_config_id or 1 + use_strong_protection = user_config.preferred_purification or False + else: + # 使用系统默认配置 + perturbation_config_id = 1 + preferred_epsilon = 8.0 + finetune_config_id = 1 + use_strong_protection = False + + # 创建任务 + batch = Batch( + user_id=current_user_id, + batch_name=batch_name, + perturbation_config_id=perturbation_config_id, + preferred_epsilon=preferred_epsilon, + finetune_config_id=finetune_config_id, + use_strong_protection=use_strong_protection + ) + + db.session.add(batch) + db.session.commit() + + return jsonify({ + 'message': '任务创建成功,请上传图片', + 'task': batch.to_dict() + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'任务创建失败: {str(e)}'}), 500 + +@task_bp.route('/upload/', methods=['POST']) +@jwt_required() +def 
upload_images(batch_id): + """上传图片到指定任务""" + try: + current_user_id = get_jwt_identity() + + # 检查任务是否存在且属于当前用户 + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + if batch.status != 'pending': + return jsonify({'error': '任务已开始处理,无法上传新图片'}), 400 + + if 'files' not in request.files: + return jsonify({'error': '没有选择文件'}), 400 + + files = request.files.getlist('files') + uploaded_files = [] + + # 获取原始图片类型ID + original_type = ImageType.query.filter_by(type_code='original').first() + if not original_type: + return jsonify({'error': '系统配置错误:缺少原始图片类型'}), 500 + + for file in files: + if file.filename == '': + continue + + if file and allowed_file(file.filename): + # 处理单张图片 + if not file.filename.lower().endswith(('.zip', '.rar')): + result = ImageService.save_image(file, batch_id, current_user_id, original_type.id) + if result['success']: + uploaded_files.append(result['image']) + else: + return jsonify({'error': result['error']}), 400 + + # 处理压缩包 + else: + results = ImageService.extract_and_save_zip(file, batch_id, current_user_id, original_type.id) + for result in results: + if result['success']: + uploaded_files.append(result['image']) + + if not uploaded_files: + return jsonify({'error': '没有有效的图片文件'}), 400 + + return jsonify({ + 'message': f'成功上传 {len(uploaded_files)} 张图片', + 'uploaded_files': [img.to_dict() for img in uploaded_files] + }), 200 + + except Exception as e: + return jsonify({'error': f'文件上传失败: {str(e)}'}), 500 + +@task_bp.route('//config', methods=['GET']) +@jwt_required() +def get_task_config(batch_id): + """获取任务配置(显示用户上次的配置或默认配置)""" + try: + current_user_id = get_jwt_identity() + + # 检查任务是否存在且属于当前用户 + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + # 获取用户配置 + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + + # 如果用户有配置,显示用户上次的配置;否则显示当前任务的默认配置 + if user_config: + suggested_config = { + 'perturbation_config_id': user_config.preferred_perturbation_config_id, + 'epsilon': float(user_config.preferred_epsilon), + 'finetune_config_id': user_config.preferred_finetune_config_id, + 'use_strong_protection': user_config.preferred_purification + } + else: + suggested_config = { + 'perturbation_config_id': batch.perturbation_config_id, + 'epsilon': float(batch.preferred_epsilon), + 'finetune_config_id': batch.finetune_config_id, + 'use_strong_protection': batch.use_strong_protection + } + + return jsonify({ + 'task': batch.to_dict(), + 'suggested_config': suggested_config, + 'current_config': { + 'perturbation_config_id': batch.perturbation_config_id, + 'epsilon': float(batch.preferred_epsilon), + 'finetune_config_id': batch.finetune_config_id, + 'use_strong_protection': batch.use_strong_protection + } + }), 200 + + except Exception as e: + return jsonify({'error': f'获取任务配置失败: {str(e)}'}), 500 + +@task_bp.route('//config', methods=['PUT']) +@jwt_required() +def update_task_config(batch_id): + """更新任务配置(仅更新任务本身,不影响用户配置)""" + try: + current_user_id = get_jwt_identity() + + # 检查任务是否存在且属于当前用户 + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + if batch.status != 'pending': + return jsonify({'error': '任务已开始处理,无法修改配置'}), 400 + + data = request.get_json() + + # 更新任务配置 + if 'perturbation_config_id' in data: + batch.perturbation_config_id = data['perturbation_config_id'] + + if 'epsilon' in data: + 
epsilon = float(data['epsilon']) + if 0 < epsilon <= 255: + batch.preferred_epsilon = epsilon + else: + return jsonify({'error': '扰动强度必须在0-255之间'}), 400 + + if 'finetune_config_id' in data: + batch.finetune_config_id = data['finetune_config_id'] + + if 'use_strong_protection' in data: + batch.use_strong_protection = bool(data['use_strong_protection']) + + db.session.commit() + + return jsonify({ + 'message': '任务配置更新成功', + 'task': batch.to_dict() + }), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'更新任务配置失败: {str(e)}'}), 500 + +@task_bp.route('/start/', methods=['POST']) +@jwt_required() +def start_task(batch_id): + """开始处理任务""" + try: + current_user_id = get_jwt_identity() + + # 检查任务是否存在且属于当前用户 + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + if batch.status != 'pending': + return jsonify({'error': '任务状态不正确,无法开始处理'}), 400 + + # 检查是否有上传的图片 + image_count = Image.query.filter_by(batch_id=batch_id).count() + if image_count == 0: + return jsonify({'error': '请先上传图片'}), 400 + + # 启动任务处理 + success = TaskService.start_processing(batch) + + if success: + return jsonify({ + 'message': '任务开始处理', + 'task': batch.to_dict() + }), 200 + else: + return jsonify({'error': '任务启动失败'}), 500 + + except Exception as e: + return jsonify({'error': f'任务启动失败: {str(e)}'}), 500 + +@task_bp.route('/list', methods=['GET']) +@jwt_required() +def list_tasks(): + """获取用户的任务列表""" + try: + current_user_id = get_jwt_identity() + + page = request.args.get('page', 1, type=int) + per_page = request.args.get('per_page', 10, type=int) + + batches = Batch.query.filter_by(user_id=current_user_id)\ + .order_by(Batch.created_at.desc())\ + .paginate(page=page, per_page=per_page, error_out=False) + + return jsonify({ + 'tasks': [batch.to_dict() for batch in batches.items], + 'total': batches.total, + 'pages': batches.pages, + 'current_page': page + }), 200 + + except Exception as e: + return jsonify({'error': f'获取任务列表失败: {str(e)}'}), 500 + +@task_bp.route('/', methods=['GET']) +@jwt_required() +def get_task_detail(batch_id): + """获取任务详情""" + try: + current_user_id = get_jwt_identity() + + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + # 获取任务相关的图片 + images = Image.query.filter_by(batch_id=batch_id).all() + + return jsonify({ + 'task': batch.to_dict(), + 'images': [img.to_dict() for img in images], + 'image_count': len(images) + }), 200 + + except Exception as e: + return jsonify({'error': f'获取任务详情失败: {str(e)}'}), 500 + +@task_bp.route('//status', methods=['GET']) +@jwt_required() +def get_task_status(batch_id): + """获取任务处理状态""" + try: + current_user_id = get_jwt_identity() + + batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first() + if not batch: + return jsonify({'error': '任务不存在或无权限'}), 404 + + return jsonify({ + 'task_id': batch_id, + 'status': batch.status, + 'progress': TaskService.get_processing_progress(batch_id), + 'error_message': batch.error_message + }), 200 + + except Exception as e: return jsonify({'error': f'获取任务状态失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/controllers/user_controller.py b/src/backend/app/controllers/user_controller.py index b66b501..2b680ed 100644 --- a/src/backend/app/controllers/user_controller.py +++ b/src/backend/app/controllers/user_controller.py @@ -1,133 +1,133 @@ -""" -用户管理控制器 -处理用户配置等功能 -""" - -from flask import Blueprint, 
request, jsonify -from flask_jwt_extended import jwt_required -from app import db -from app.models import User, UserConfig, PerturbationConfig, FinetuneConfig -from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器 - -user_bp = Blueprint('user', __name__) - -@user_bp.route('/config', methods=['GET']) -@int_jwt_required -def get_user_config(current_user_id): - """获取用户配置""" - try: - - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - if not user_config: - # 如果没有配置,创建默认配置 - user_config = UserConfig(user_id=current_user_id) - db.session.add(user_config) - db.session.commit() - - return jsonify({ - 'config': user_config.to_dict() - }), 200 - - except Exception as e: - return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500 - -@user_bp.route('/config', methods=['PUT']) -@int_jwt_required -def update_user_config(current_user_id): - """更新用户配置""" - try: - data = request.get_json() - - user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - - if not user_config: - user_config = UserConfig(user_id=current_user_id) - db.session.add(user_config) - - # 更新配置字段 - if 'preferred_perturbation_config_id' in data: - user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id'] - - if 'preferred_epsilon' in data: - epsilon = float(data['preferred_epsilon']) - if 0 < epsilon <= 255: - user_config.preferred_epsilon = epsilon - else: - return jsonify({'error': '扰动强度必须在0-255之间'}), 400 - - if 'preferred_finetune_config_id' in data: - user_config.preferred_finetune_config_id = data['preferred_finetune_config_id'] - - if 'preferred_purification' in data: - user_config.preferred_purification = bool(data['preferred_purification']) - - db.session.commit() - - return jsonify({ - 'message': '用户配置更新成功', - 'config': user_config.to_dict() - }), 200 - - except Exception as e: - db.session.rollback() - return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500 - -@user_bp.route('/algorithms', methods=['GET']) -@jwt_required() -def get_available_algorithms(): - """获取可用的算法列表""" - try: - perturbation_configs = PerturbationConfig.query.all() - finetune_configs = FinetuneConfig.query.all() - - return jsonify({ - 'perturbation_algorithms': [ - { - 'id': config.id, - 'method_code': config.method_code, - 'method_name': config.method_name, - 'description': config.description, - 'default_epsilon': float(config.default_epsilon) - } for config in perturbation_configs - ], - 'finetune_methods': [ - { - 'id': config.id, - 'method_code': config.method_code, - 'method_name': config.method_name, - 'description': config.description - } for config in finetune_configs - ] - }), 200 - - except Exception as e: - return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500 - -@user_bp.route('/stats', methods=['GET']) -@int_jwt_required -def get_user_stats(current_user_id): - """获取用户统计信息""" - try: - from app.models import Batch, Image - - # 统计用户的任务和图片数量 - total_tasks = Batch.query.filter_by(user_id=current_user_id).count() - completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count() - processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count() - failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count() - - total_images = Image.query.filter_by(user_id=current_user_id).count() - - return jsonify({ - 'stats': { - 'total_tasks': total_tasks, - 'completed_tasks': completed_tasks, - 'processing_tasks': processing_tasks, - 'failed_tasks': failed_tasks, - 'total_images': total_images - } 
- }), 200 - - except Exception as e: +""" +用户管理控制器 +处理用户配置等功能 +""" + +from flask import Blueprint, request, jsonify +from flask_jwt_extended import jwt_required +from app import db +from app.database import User, UserConfig, PerturbationConfig, FinetuneConfig +from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器 + +user_bp = Blueprint('user', __name__) + +@user_bp.route('/config', methods=['GET']) +@int_jwt_required +def get_user_config(current_user_id): + """获取用户配置""" + try: + + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + + if not user_config: + # 如果没有配置,创建默认配置 + user_config = UserConfig(user_id=current_user_id) + db.session.add(user_config) + db.session.commit() + + return jsonify({ + 'config': user_config.to_dict() + }), 200 + + except Exception as e: + return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500 + +@user_bp.route('/config', methods=['PUT']) +@int_jwt_required +def update_user_config(current_user_id): + """更新用户配置""" + try: + data = request.get_json() + + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() + + if not user_config: + user_config = UserConfig(user_id=current_user_id) + db.session.add(user_config) + + # 更新配置字段 + if 'preferred_perturbation_config_id' in data: + user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id'] + + if 'preferred_epsilon' in data: + epsilon = float(data['preferred_epsilon']) + if 0 < epsilon <= 255: + user_config.preferred_epsilon = epsilon + else: + return jsonify({'error': '扰动强度必须在0-255之间'}), 400 + + if 'preferred_finetune_config_id' in data: + user_config.preferred_finetune_config_id = data['preferred_finetune_config_id'] + + if 'preferred_purification' in data: + user_config.preferred_purification = bool(data['preferred_purification']) + + db.session.commit() + + return jsonify({ + 'message': '用户配置更新成功', + 'config': user_config.to_dict() + }), 200 + + except Exception as e: + db.session.rollback() + return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500 + +@user_bp.route('/algorithms', methods=['GET']) +@jwt_required() +def get_available_algorithms(): + """获取可用的算法列表""" + try: + perturbation_configs = PerturbationConfig.query.all() + finetune_configs = FinetuneConfig.query.all() + + return jsonify({ + 'perturbation_algorithms': [ + { + 'id': config.id, + 'method_code': config.method_code, + 'method_name': config.method_name, + 'description': config.description, + 'default_epsilon': float(config.default_epsilon) + } for config in perturbation_configs + ], + 'finetune_methods': [ + { + 'id': config.id, + 'method_code': config.method_code, + 'method_name': config.method_name, + 'description': config.description + } for config in finetune_configs + ] + }), 200 + + except Exception as e: + return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500 + +@user_bp.route('/stats', methods=['GET']) +@int_jwt_required +def get_user_stats(current_user_id): + """获取用户统计信息""" + try: + from app.database import Batch, Image + + # 统计用户的任务和图片数量 + total_tasks = Batch.query.filter_by(user_id=current_user_id).count() + completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count() + processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count() + failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count() + + total_images = Image.query.filter_by(user_id=current_user_id).count() + + return jsonify({ + 'stats': { + 'total_tasks': total_tasks, + 'completed_tasks': completed_tasks, + 
'processing_tasks': processing_tasks, + 'failed_tasks': failed_tasks, + 'total_images': total_images + } + }), 200 + + except Exception as e: return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500 \ No newline at end of file diff --git a/src/backend/app/models/__init__.py b/src/backend/app/database/__init__.py similarity index 96% rename from src/backend/app/models/__init__.py rename to src/backend/app/database/__init__.py index d92e3e8..5f7f8fd 100644 --- a/src/backend/app/models/__init__.py +++ b/src/backend/app/database/__init__.py @@ -1,233 +1,233 @@ -""" -数据库模型定义 -基于已有的schema.sql设计 -""" - -from datetime import datetime -from app import db -from werkzeug.security import generate_password_hash, check_password_hash -from enum import Enum as PyEnum - -class User(db.Model): - """用户表""" - __tablename__ = 'users' - - id = db.Column(db.BigInteger, primary_key=True) - username = db.Column(db.String(50), unique=True, nullable=False) - password_hash = db.Column(db.String(255), nullable=False) - email = db.Column(db.String(100)) - role = db.Column(db.Enum('user', 'admin'), default='user') - max_concurrent_tasks = db.Column(db.Integer, nullable=False, default=0) - created_at = db.Column(db.DateTime, default=datetime.utcnow) - updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - is_active = db.Column(db.Boolean, default=True) - - # 关系 - batches = db.relationship('Batch', backref='user', lazy='dynamic', cascade='all, delete-orphan') - images = db.relationship('Image', backref='user', lazy='dynamic', cascade='all, delete-orphan') - user_config = db.relationship('UserConfig', backref='user', uselist=False, cascade='all, delete-orphan') - - def set_password(self, password): - """设置密码""" - self.password_hash = generate_password_hash(password) - - def check_password(self, password): - """验证密码""" - return check_password_hash(self.password_hash, password) - - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'username': self.username, - 'email': self.email, - 'role': self.role, - 'max_concurrent_tasks': self.max_concurrent_tasks, - 'created_at': self.created_at.isoformat() if self.created_at else None, - 'is_active': self.is_active - } - -class ImageType(db.Model): - """图片类型表""" - __tablename__ = 'image_types' - - id = db.Column(db.BigInteger, primary_key=True) - type_code = db.Column(db.Enum('original', 'perturbed', 'original_generate', 'perturbed_generate'), - unique=True, nullable=False) - type_name = db.Column(db.String(100), nullable=False) - description = db.Column(db.Text) - - # 关系 - images = db.relationship('Image', backref='image_type', lazy='dynamic') - -class PerturbationConfig(db.Model): - """加噪算法表""" - __tablename__ = 'perturbation_configs' - - id = db.Column(db.BigInteger, primary_key=True) - method_code = db.Column(db.String(50), unique=True, nullable=False) - method_name = db.Column(db.String(100), nullable=False) - description = db.Column(db.Text, nullable=False) - default_epsilon = db.Column(db.Numeric(5, 2), nullable=False) - - # 关系 - batches = db.relationship('Batch', backref='perturbation_config', lazy='dynamic') - user_configs = db.relationship('UserConfig', backref='preferred_perturbation_config', lazy='dynamic') - -class FinetuneConfig(db.Model): - """微调方式表""" - __tablename__ = 'finetune_configs' - - id = db.Column(db.BigInteger, primary_key=True) - method_code = db.Column(db.String(50), unique=True, nullable=False) - method_name = db.Column(db.String(100), nullable=False) - description = db.Column(db.Text, nullable=False) - - # 关系 - 
batches = db.relationship('Batch', backref='finetune_config', lazy='dynamic') - user_configs = db.relationship('UserConfig', backref='preferred_finetune_config', lazy='dynamic') - -class Batch(db.Model): - """加噪批次表""" - __tablename__ = 'batch' - - id = db.Column(db.BigInteger, primary_key=True) - user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False) - batch_name = db.Column(db.String(128)) - - # 加噪配置 - perturbation_config_id = db.Column(db.BigInteger, db.ForeignKey('perturbation_configs.id'), - nullable=False, default=1) - preferred_epsilon = db.Column(db.Numeric(5, 2), nullable=False, default=8.0) - - # 评估配置 - finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id'), - nullable=False, default=1) - - # 净化配置 - use_strong_protection = db.Column(db.Boolean, nullable=False, default=False) - - # 任务状态 - status = db.Column(db.Enum('pending', 'processing', 'completed', 'failed'), default='pending') - created_at = db.Column(db.DateTime, default=datetime.utcnow) - started_at = db.Column(db.DateTime) - completed_at = db.Column(db.DateTime) - error_message = db.Column(db.Text) - result_path = db.Column(db.String(500)) - - # 关系 - images = db.relationship('Image', backref='batch', lazy='dynamic') - - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'batch_name': self.batch_name, - 'status': self.status, - 'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None, - 'use_strong_protection': self.use_strong_protection, - 'created_at': self.created_at.isoformat() if self.created_at else None, - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'completed_at': self.completed_at.isoformat() if self.completed_at else None, - 'error_message': self.error_message, - 'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None, - 'finetune_config': self.finetune_config.method_name if self.finetune_config else None - } - -class Image(db.Model): - """图片表""" - __tablename__ = 'images' - - id = db.Column(db.BigInteger, primary_key=True) - user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False) - batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id')) - father_id = db.Column(db.BigInteger, db.ForeignKey('images.id')) - original_filename = db.Column(db.String(255)) - stored_filename = db.Column(db.String(255), unique=True, nullable=False) - file_path = db.Column(db.String(500), nullable=False) - file_size = db.Column(db.BigInteger) - image_type_id = db.Column(db.BigInteger, db.ForeignKey('image_types.id'), nullable=False) - width = db.Column(db.Integer) - height = db.Column(db.Integer) - upload_time = db.Column(db.DateTime, default=datetime.utcnow) - - # 自引用关系 - children = db.relationship('Image', backref=db.backref('parent', remote_side=[id]), lazy='dynamic') - - # 评估结果关系 - reference_evaluations = db.relationship('EvaluationResult', - foreign_keys='EvaluationResult.reference_image_id', - backref='reference_image', lazy='dynamic') - target_evaluations = db.relationship('EvaluationResult', - foreign_keys='EvaluationResult.target_image_id', - backref='target_image', lazy='dynamic') - - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'original_filename': self.original_filename, - 'stored_filename': self.stored_filename, - 'file_path': self.file_path, - 'file_size': self.file_size, - 'width': self.width, - 'height': self.height, - 'upload_time': self.upload_time.isoformat() if self.upload_time else None, - 'image_type': 
self.image_type.type_name if self.image_type else None, - 'batch_id': self.batch_id - } - -class EvaluationResult(db.Model): - """评估结果表""" - __tablename__ = 'evaluation_results' - - id = db.Column(db.BigInteger, primary_key=True) - reference_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False) - target_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False) - evaluation_type = db.Column(db.Enum('image_quality', 'model_generation'), nullable=False) - purification_applied = db.Column(db.Boolean, default=False) - fid_score = db.Column(db.Numeric(8, 4)) - lpips_score = db.Column(db.Numeric(8, 4)) - ssim_score = db.Column(db.Numeric(8, 4)) - psnr_score = db.Column(db.Numeric(8, 4)) - heatmap_path = db.Column(db.String(500)) - evaluated_at = db.Column(db.DateTime, default=datetime.utcnow) - - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'evaluation_type': self.evaluation_type, - 'purification_applied': self.purification_applied, - 'fid_score': float(self.fid_score) if self.fid_score else None, - 'lpips_score': float(self.lpips_score) if self.lpips_score else None, - 'ssim_score': float(self.ssim_score) if self.ssim_score else None, - 'psnr_score': float(self.psnr_score) if self.psnr_score else None, - 'heatmap_path': self.heatmap_path, - 'evaluated_at': self.evaluated_at.isoformat() if self.evaluated_at else None - } - -class UserConfig(db.Model): - """用户配置表""" - __tablename__ = 'user_configs' - - user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), primary_key=True) - preferred_perturbation_config_id = db.Column(db.BigInteger, - db.ForeignKey('perturbation_configs.id'), default=1) - preferred_epsilon = db.Column(db.Numeric(5, 2), default=8.0) - preferred_finetune_config_id = db.Column(db.BigInteger, - db.ForeignKey('finetune_configs.id'), default=1) - preferred_purification = db.Column(db.Boolean, default=False) - created_at = db.Column(db.DateTime, default=datetime.utcnow) - updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - def to_dict(self): - """转换为字典""" - return { - 'user_id': self.user_id, - 'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None, - 'preferred_purification': self.preferred_purification, - 'preferred_perturbation_config': self.preferred_perturbation_config.method_name if self.preferred_perturbation_config else None, - 'preferred_finetune_config': self.preferred_finetune_config.method_name if self.preferred_finetune_config else None, - 'updated_at': self.updated_at.isoformat() if self.updated_at else None - } \ No newline at end of file +""" +数据库模型定义 +基于已有的schema.sql设计 +""" + +from datetime import datetime +from app import db +from werkzeug.security import generate_password_hash, check_password_hash +from enum import Enum as PyEnum + +class User(db.Model): + """用户表""" + __tablename__ = 'users' + + id = db.Column(db.BigInteger, primary_key=True) + username = db.Column(db.String(50), unique=True, nullable=False) + password_hash = db.Column(db.String(255), nullable=False) + email = db.Column(db.String(100)) + role = db.Column(db.Enum('user', 'admin'), default='user') + max_concurrent_tasks = db.Column(db.Integer, nullable=False, default=0) + created_at = db.Column(db.DateTime, default=datetime.utcnow) + updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + is_active = db.Column(db.Boolean, default=True) + + # 关系 + batches = db.relationship('Batch', backref='user', lazy='dynamic', cascade='all, 
delete-orphan') + images = db.relationship('Image', backref='user', lazy='dynamic', cascade='all, delete-orphan') + user_config = db.relationship('UserConfig', backref='user', uselist=False, cascade='all, delete-orphan') + + def set_password(self, password): + """设置密码""" + self.password_hash = generate_password_hash(password) + + def check_password(self, password): + """验证密码""" + return check_password_hash(self.password_hash, password) + + def to_dict(self): + """转换为字典""" + return { + 'id': self.id, + 'username': self.username, + 'email': self.email, + 'role': self.role, + 'max_concurrent_tasks': self.max_concurrent_tasks, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'is_active': self.is_active + } + +class ImageType(db.Model): + """图片类型表""" + __tablename__ = 'image_types' + + id = db.Column(db.BigInteger, primary_key=True) + type_code = db.Column(db.Enum('original', 'perturbed', 'original_generate', 'perturbed_generate'), + unique=True, nullable=False) + type_name = db.Column(db.String(100), nullable=False) + description = db.Column(db.Text) + + # 关系 + images = db.relationship('Image', backref='image_type', lazy='dynamic') + +class PerturbationConfig(db.Model): + """加噪算法表""" + __tablename__ = 'perturbation_configs' + + id = db.Column(db.BigInteger, primary_key=True) + method_code = db.Column(db.String(50), unique=True, nullable=False) + method_name = db.Column(db.String(100), nullable=False) + description = db.Column(db.Text, nullable=False) + default_epsilon = db.Column(db.Numeric(5, 2), nullable=False) + + # 关系 + batches = db.relationship('Batch', backref='perturbation_config', lazy='dynamic') + user_configs = db.relationship('UserConfig', backref='preferred_perturbation_config', lazy='dynamic') + +class FinetuneConfig(db.Model): + """微调方式表""" + __tablename__ = 'finetune_configs' + + id = db.Column(db.BigInteger, primary_key=True) + method_code = db.Column(db.String(50), unique=True, nullable=False) + method_name = db.Column(db.String(100), nullable=False) + description = db.Column(db.Text, nullable=False) + + # 关系 + batches = db.relationship('Batch', backref='finetune_config', lazy='dynamic') + user_configs = db.relationship('UserConfig', backref='preferred_finetune_config', lazy='dynamic') + +class Batch(db.Model): + """加噪批次表""" + __tablename__ = 'batch' + + id = db.Column(db.BigInteger, primary_key=True) + user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False) + batch_name = db.Column(db.String(128)) + + # 加噪配置 + perturbation_config_id = db.Column(db.BigInteger, db.ForeignKey('perturbation_configs.id'), + nullable=False, default=1) + preferred_epsilon = db.Column(db.Numeric(5, 2), nullable=False, default=8.0) + + # 评估配置 + finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id'), + nullable=False, default=1) + + # 净化配置 + use_strong_protection = db.Column(db.Boolean, nullable=False, default=False) + + # 任务状态 + status = db.Column(db.Enum('pending', 'queued', 'processing', 'completed', 'failed'), default='pending') + created_at = db.Column(db.DateTime, default=datetime.utcnow) + started_at = db.Column(db.DateTime) + completed_at = db.Column(db.DateTime) + error_message = db.Column(db.Text) + result_path = db.Column(db.String(500)) + + # 关系 + images = db.relationship('Image', backref='batch', lazy='dynamic') + + def to_dict(self): + """转换为字典""" + return { + 'id': self.id, + 'batch_name': self.batch_name, + 'status': self.status, + 'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon 
else None, + 'use_strong_protection': self.use_strong_protection, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'started_at': self.started_at.isoformat() if self.started_at else None, + 'completed_at': self.completed_at.isoformat() if self.completed_at else None, + 'error_message': self.error_message, + 'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None, + 'finetune_config': self.finetune_config.method_name if self.finetune_config else None + } + +class Image(db.Model): + """图片表""" + __tablename__ = 'images' + + id = db.Column(db.BigInteger, primary_key=True) + user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False) + batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id')) + father_id = db.Column(db.BigInteger, db.ForeignKey('images.id')) + original_filename = db.Column(db.String(255)) + stored_filename = db.Column(db.String(255), unique=True, nullable=False) + file_path = db.Column(db.String(500), nullable=False) + file_size = db.Column(db.BigInteger) + image_type_id = db.Column(db.BigInteger, db.ForeignKey('image_types.id'), nullable=False) + width = db.Column(db.Integer) + height = db.Column(db.Integer) + upload_time = db.Column(db.DateTime, default=datetime.utcnow) + + # 自引用关系 + children = db.relationship('Image', backref=db.backref('parent', remote_side=[id]), lazy='dynamic') + + # 评估结果关系 + reference_evaluations = db.relationship('EvaluationResult', + foreign_keys='EvaluationResult.reference_image_id', + backref='reference_image', lazy='dynamic') + target_evaluations = db.relationship('EvaluationResult', + foreign_keys='EvaluationResult.target_image_id', + backref='target_image', lazy='dynamic') + + def to_dict(self): + """转换为字典""" + return { + 'id': self.id, + 'original_filename': self.original_filename, + 'stored_filename': self.stored_filename, + 'file_path': self.file_path, + 'file_size': self.file_size, + 'width': self.width, + 'height': self.height, + 'upload_time': self.upload_time.isoformat() if self.upload_time else None, + 'image_type': self.image_type.type_name if self.image_type else None, + 'batch_id': self.batch_id + } + +class EvaluationResult(db.Model): + """评估结果表""" + __tablename__ = 'evaluation_results' + + id = db.Column(db.BigInteger, primary_key=True) + reference_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False) + target_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False) + evaluation_type = db.Column(db.Enum('image_quality', 'model_generation'), nullable=False) + purification_applied = db.Column(db.Boolean, default=False) + fid_score = db.Column(db.Numeric(8, 4)) + lpips_score = db.Column(db.Numeric(8, 4)) + ssim_score = db.Column(db.Numeric(8, 4)) + psnr_score = db.Column(db.Numeric(8, 4)) + heatmap_path = db.Column(db.String(500)) + evaluated_at = db.Column(db.DateTime, default=datetime.utcnow) + + def to_dict(self): + """转换为字典""" + return { + 'id': self.id, + 'evaluation_type': self.evaluation_type, + 'purification_applied': self.purification_applied, + 'fid_score': float(self.fid_score) if self.fid_score else None, + 'lpips_score': float(self.lpips_score) if self.lpips_score else None, + 'ssim_score': float(self.ssim_score) if self.ssim_score else None, + 'psnr_score': float(self.psnr_score) if self.psnr_score else None, + 'heatmap_path': self.heatmap_path, + 'evaluated_at': self.evaluated_at.isoformat() if self.evaluated_at else None + } + +class UserConfig(db.Model): + """用户配置表""" + __tablename__ 
= 'user_configs'
+
+    user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), primary_key=True)
+    preferred_perturbation_config_id = db.Column(db.BigInteger,
+                                                 db.ForeignKey('perturbation_configs.id'), default=1)
+    preferred_epsilon = db.Column(db.Numeric(5, 2), default=8.0)
+    preferred_finetune_config_id = db.Column(db.BigInteger,
+                                             db.ForeignKey('finetune_configs.id'), default=1)
+    preferred_purification = db.Column(db.Boolean, default=False)
+    created_at = db.Column(db.DateTime, default=datetime.utcnow)
+    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+    def to_dict(self):
+        """转换为字典"""
+        return {
+            'user_id': self.user_id,
+            'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None,
+            'preferred_purification': self.preferred_purification,
+            'preferred_perturbation_config': self.preferred_perturbation_config.method_name if self.preferred_perturbation_config else None,
+            'preferred_finetune_config': self.preferred_finetune_config.method_name if self.preferred_finetune_config else None,
+            'updated_at': self.updated_at.isoformat() if self.updated_at else None
+        }
diff --git a/src/backend/app/scripts/attack_aspl.sh b/src/backend/app/scripts/attack_aspl.sh
new file mode 100644
index 0000000..81e2361
--- /dev/null
+++ b/src/backend/app/scripts/attack_aspl.sh
@@ -0,0 +1,62 @@
+# Required environment: conda activate simac
+export HF_HUB_OFFLINE=1
+export MODEL_PATH="../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06"
+export TASKNAME="task001"
+
+# ------------------------- Train ASPL on set CLEAN_ADV_DIR -------------------------
+export CLEAN_TRAIN_DIR="../../static/originals/${TASKNAME}"
+export CLEAN_ADV_DIR="../../static/originals/${TASKNAME}"
+export OUTPUT_DIR="../../static/perturbed/${TASKNAME}"
+export CLASS_DIR="../../static/class/${TASKNAME}"
+
+# ------------------------- Auto-create required directories -------------------------
+echo "Creating required directories..."
+mkdir -p "$CLEAN_TRAIN_DIR"
+mkdir -p "$CLEAN_ADV_DIR"
+mkdir -p "$OUTPUT_DIR"
+mkdir -p "$CLASS_DIR"
+echo "Directories created successfully."
+
+
+# ------------------------- Clear OUTPUT_DIR before training -------------------------
+echo "Clearing output directory: $OUTPUT_DIR"
+# Make sure the directory exists so the cleanup command cannot fail
+# Note: it was already created above; this is kept for clarity and is safe to remove
+mkdir -p "$OUTPUT_DIR"
+# Find and delete every file and subdirectory inside (but not . or ..)
+find "$OUTPUT_DIR" -mindepth 1 -delete
+
+
+accelerate launch ../algorithms/aspl.py \
+    --pretrained_model_name_or_path=$MODEL_PATH \
+    --enable_xformers_memory_efficient_attention \
+    --instance_data_dir_for_train=$CLEAN_TRAIN_DIR \
+    --instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \
+    --instance_prompt="a photo of sks person" \
+    --class_data_dir=$CLASS_DIR \
+    --num_class_images=200 \
+    --class_prompt="a photo of person" \
+    --output_dir=$OUTPUT_DIR \
+    --center_crop \
+    --with_prior_preservation \
+    --prior_loss_weight=1.0 \
+    --resolution=384 \
+    --train_batch_size=1 \
+    --max_train_steps=50 \
+    --max_f_train_steps=3 \
+    --max_adv_train_steps=6 \
+    --checkpointing_iterations=10 \
+    --learning_rate=5e-7 \
+    --pgd_alpha=0.005 \
+    --pgd_eps=8 \
+    --seed=0
+
+# ------------------------- Clear CLASS_DIR after training -------------------------
+# Note: this runs after the accelerate launch command above completes
+echo "Clearing class directory: $CLASS_DIR"
+# Make sure the directory exists so the cleanup command cannot fail
+mkdir -p "$CLASS_DIR"
+# Find and delete every file and subdirectory inside (but not . or ..)
+find "$CLASS_DIR" -mindepth 1 -delete
+
+echo "Script finished."
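The "mkdir -p DIR; find DIR -mindepth 1 -delete" pairs above implement an idempotent ensure-then-clear step: the directory is guaranteed to exist and to be empty afterwards. Should the Python workers ever need the same guarantee without shelling out, a minimal sketch (a hypothetical helper, not part of this patch) looks like this:

import shutil
from pathlib import Path

def ensure_empty_dir(path: str) -> None:
    """Create `path` if missing, or wipe its contents if it already exists."""
    p = Path(path)
    if p.exists():
        shutil.rmtree(p)        # remove the directory tree entirely...
    p.mkdir(parents=True)       # ...then recreate it empty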
\ No newline at end of file diff --git a/src/backend/app/scripts/attack_caat.sh b/src/backend/app/scripts/attack_caat.sh new file mode 100644 index 0000000..90090db --- /dev/null +++ b/src/backend/app/scripts/attack_caat.sh @@ -0,0 +1,39 @@ +#需要环境:conda activate caat +export HF_HUB_OFFLINE=1 +# 强制使用本地模型缓存,避免联网下载模型 +#export HF_HOME="/root/autodl-tmp/huggingface_cache" +#export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export MODEL_NAME="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14" +export TASKNAME="task001" +### Data to be protected +export INSTANCE_DIR="../../static/originals/${TASKNAME}" +### Path to save the protected data +export OUTPUT_DIR="../../static/perturbed/${TASKNAME}" + +# ------------------------- 自动创建依赖路径 ------------------------- +echo "Creating required directories..." +mkdir -p "$INSTANCE_DIR" +mkdir -p "$OUTPUT_DIR" +echo "Directories created successfully." + +# ------------------------- 训练前清空 OUTPUT_DIR ------------------------- +echo "Clearing output directory: $OUTPUT_DIR" +# 确保目录存在,避免清理命令失败 +mkdir -p "$OUTPUT_DIR" +# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) +find "$OUTPUT_DIR" -mindepth 1 -delete + + +accelerate launch ../algorithms/caat.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of a person" \ + --resolution=512 \ + --learning_rate=1e-5 \ + --lr_warmup_steps=0 \ + --max_train_steps=250 \ + --hflip \ + --mixed_precision bf16 \ + --alpha=5e-3 \ + --eps=0.05 \ No newline at end of file diff --git a/src/backend/app/scripts/attack_caat_with_prior.sh b/src/backend/app/scripts/attack_caat_with_prior.sh new file mode 100644 index 0000000..bc73b23 --- /dev/null +++ b/src/backend/app/scripts/attack_caat_with_prior.sh @@ -0,0 +1,55 @@ +export HF_HUB_OFFLINE=1 +# 强制使用本地模型缓存,避免联网下载模型 +#export HF_HOME="/root/autodl-tmp/huggingface_cache" +#export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export MODEL_NAME="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14" +export TASKNAME="task001" +### Data to be protected +export INSTANCE_DIR="../../static/originals/${TASKNAME}" +### Path to save the protected data +export OUTPUT_DIR="../../static/perturbed/${TASKNAME}" +export CLASS_DIR="../../static/class/${TASKNAME}" + +# ------------------------- 自动创建依赖路径 ------------------------- +echo "Creating required directories..." +mkdir -p "$INSTANCE_DIR" +mkdir -p "$OUTPUT_DIR" +mkdir -p "$CLASS_DIR" +echo "Directories created successfully." + +# ------------------------- 训练前清空 OUTPUT_DIR ------------------------- +echo "Clearing output directory: $OUTPUT_DIR" +# 确保目录存在,避免清理命令失败 +mkdir -p "$OUTPUT_DIR" +# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) 
+find "$OUTPUT_DIR" -mindepth 1 -delete + + +accelerate launch ../algorithms/caat.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation \ + --instance_prompt="a photo of a person" \ + --num_class_images=200 \ + --class_data_dir=$CLASS_DIR \ + --class_prompt='person' \ + --resolution=512 \ + --learning_rate=1e-5 \ + --lr_warmup_steps=0 \ + --max_train_steps=250 \ + --hflip \ + --mixed_precision bf16 \ + --alpha=5e-3 \ + --eps=0.05 + + +# ------------------------- 【步骤 2】训练后清空 CLASS_DIR ------------------------- +# 注意:这会在 accelerate launch 成功结束后执行 +echo "Clearing class directory: $CLASS_DIR" +# 确保目录存在,避免清理命令失败 +mkdir -p "$CLASS_DIR" +# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) +find "$CLASS_DIR" -mindepth 1 -delete + +echo "Script finished." \ No newline at end of file diff --git a/src/backend/app/scripts/attack_pid.sh b/src/backend/app/scripts/attack_pid.sh new file mode 100644 index 0000000..40b9101 --- /dev/null +++ b/src/backend/app/scripts/attack_pid.sh @@ -0,0 +1,53 @@ +#需要环境:conda activate pid +### Generate images protected by PID + +export HF_HUB_OFFLINE=1 +# 强制使用本地模型缓存,避免联网下载模型 + +### SD v2.1 +# export HF_HOME="/root/autodl-tmp/huggingface_cache" +# export MODEL_PATH="stabilityai/stable-diffusion-2-1" + +### SD v1.5 +# export HF_HOME="/root/autodl-tmp/huggingface_cache" +# export MODEL_PATH="runwayml/stable-diffusion-v1-5" +export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14" + + +export TASKNAME="task001" +### Data to be protected +export INSTANCE_DIR="../../static/originals/${TASKNAME}" +### Path to save the protected data +export OUTPUT_DIR="../../static/perturbed/${TASKNAME}" + +# ------------------------- 自动创建依赖路径 ------------------------- +echo "Creating required directories..." +mkdir -p "$INSTANCE_DIR" +mkdir -p "$OUTPUT_DIR" +echo "Directories created successfully." + + +# ------------------------- 训练前清空 OUTPUT_DIR ------------------------- +echo "Clearing output directory: $OUTPUT_DIR" +# 确保目录存在,避免清理命令失败 +mkdir -p "$OUTPUT_DIR" +# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) 
+find "$OUTPUT_DIR" -mindepth 1 -delete + + + +### Generation command +# --max_train_steps: Optimizaiton steps +# --attack_type: target loss to update, choices=['var', 'mean', 'KL', 'add-log', 'latent_vector', 'add'], +# Please refer to the file content for more usage + +CUDA_VISIBLE_DEVICES=0 python ../algorithms/pid.py \ + --pretrained_model_name_or_path=$MODEL_PATH \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --resolution=512 \ + --max_train_steps=1000 \ + --center_crop \ + --eps 12.75 \ + --attack_type add-log + diff --git a/src/backend/app/scripts/attack_simac.sh b/src/backend/app/scripts/attack_simac.sh new file mode 100644 index 0000000..7186c26 --- /dev/null +++ b/src/backend/app/scripts/attack_simac.sh @@ -0,0 +1,63 @@ +#需要环境:conda activate simac +export HF_HUB_OFFLINE=1 +export MODEL_PATH="../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06" +export TASKNAME="task001" + +# ------------------------- Train ASPL on set CLEAN_ADV_DIR ------------------------- +export CLEAN_TRAIN_DIR="../../static/originals/${TASKNAME}" +export CLEAN_ADV_DIR="../../static/originals/${TASKNAME}" +export OUTPUT_DIR="../../static/perturbed/${TASKNAME}" +export CLASS_DIR="../../static/class/${TASKNAME}" + +# ------------------------- 自动创建依赖路径 ------------------------- +echo "Creating required directories..." +mkdir -p "$CLEAN_TRAIN_DIR" +mkdir -p "$CLEAN_ADV_DIR" +mkdir -p "$OUTPUT_DIR" +mkdir -p "$CLASS_DIR" +echo "Directories created successfully." + + +# ------------------------- 训练前清空 OUTPUT_DIR ------------------------- +echo "Clearing output directory: $OUTPUT_DIR" +# 确保目录存在,避免清理命令失败 +# 注意:虽然前面已经创建,但这里保留是为了代码逻辑清晰,也可以删除 +mkdir -p "$OUTPUT_DIR" +# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) +find "$OUTPUT_DIR" -mindepth 1 -delete + + + +accelerate launch ../algorithms/simac.py \ + --pretrained_model_name_or_path=$MODEL_PATH \ + --enable_xformers_memory_efficient_attention \ + --instance_data_dir_for_train=$CLEAN_TRAIN_DIR \ + --instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \ + --instance_prompt="a photo of sks person" \ + --class_data_dir=$CLASS_DIR \ + --num_class_images=200 \ + --class_prompt="a photo of person" \ + --output_dir=$OUTPUT_DIR \ + --center_crop \ + --with_prior_preservation \ + --prior_loss_weight=1.0 \ + --resolution=384 \ + --train_batch_size=1 \ + --max_train_steps=50 \ + --max_f_train_steps=3 \ + --max_adv_train_steps=6 \ + --checkpointing_iterations=10 \ + --learning_rate=5e-7 \ + --pgd_alpha=0.005 \ + --pgd_eps=8 \ + --seed=0 + +# ------------------------- 训练后清空 CLASS_DIR ------------------------- +# 注意:这会在 accelerate launch 成功结束后执行 +echo "Clearing class directory: $CLASS_DIR" +# 确保目录存在,避免清理命令失败 +mkdir -p "$CLASS_DIR" +# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) +find "$CLASS_DIR" -mindepth 1 -delete + +echo "Script finished." 
\ No newline at end of file
diff --git a/src/backend/app/scripts/db_gen.sh b/src/backend/app/scripts/db_gen.sh
new file mode 100644
index 0000000..01de68b
--- /dev/null
+++ b/src/backend/app/scripts/db_gen.sh
@@ -0,0 +1,90 @@
+# Required environment: conda activate pid
+### Training the BAD model (on clean or PID-poisoned data)
+
+export HF_HUB_OFFLINE=1
+# Force use of the local model cache so no model is downloaded from the network
+
+### SD v2.1
+# export HF_HOME="/root/autodl-tmp/huggingface_cache"
+# export MODEL_PATH="stabilityai/stable-diffusion-2-1"
+
+### SD v1.5
+# export HF_HOME="/root/autodl-tmp/huggingface_cache"
+# export MODEL_PATH="runwayml/stable-diffusion-v1-5"
+export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
+
+
+export TASKNAME="task001"
+export TYPE="clean"  # clean or perturbed
+
+
+if [ "$TYPE" == "clean" ]; then
+    export INSTANCE_DIR="../../static/originals/${TASKNAME}"
+else
+    export INSTANCE_DIR="../../static/perturbed/${TASKNAME}"
+fi
+export DREAMBOOTH_OUTPUT_DIR="../../static/hf_models/fine_tuned/${TYPE}/${TASKNAME}"
+export OUTPUT_INFER_DIR="../../static/model_outputs/${TYPE}/${TASKNAME}"
+export CLASS_DIR="../../static/class/${TASKNAME}"
+
+
+# ------------------------- Auto-create required directories -------------------------
+echo "Creating required directories..."
+mkdir -p "$INSTANCE_DIR"
+mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
+mkdir -p "$OUTPUT_INFER_DIR"
+mkdir -p "$CLASS_DIR"
+echo "Directories created successfully."
+
+
+# ------------------------- Clear the output directories before training -------------------------
+echo "Clearing output directory: $DREAMBOOTH_OUTPUT_DIR and $OUTPUT_INFER_DIR"
+# Make sure the directories exist so the cleanup commands cannot fail
+mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
+mkdir -p "$OUTPUT_INFER_DIR"
+# Find and delete every file and subdirectory inside (but not . or ..)
+find "$DREAMBOOTH_OUTPUT_DIR" -mindepth 1 -delete
+find "$OUTPUT_INFER_DIR" -mindepth 1 -delete
+
+
+
+
+# ------------------------- Fine-tune DreamBooth on images -------------------------
+CUDA_VISIBLE_DEVICES=0 accelerate launch ../finetune_infras/train_dreambooth_gen.py \
+    --pretrained_model_name_or_path=$MODEL_PATH \
+    --instance_data_dir=$INSTANCE_DIR \
+    --class_data_dir=$CLASS_DIR \
+    --output_dir=$DREAMBOOTH_OUTPUT_DIR \
+    --validation_image_output_dir=$OUTPUT_INFER_DIR \
+    --with_prior_preservation \
+    --prior_loss_weight=1.0 \
+    --instance_prompt="a photo of sks person" \
+    --class_prompt="a photo of person" \
+    --resolution=512 \
+    --train_batch_size=1 \
+    --gradient_accumulation_steps=1 \
+    --learning_rate=2e-6 \
+    --lr_scheduler="constant" \
+    --lr_warmup_steps=0 \
+    --num_class_images=200 \
+    --max_train_steps=1000 \
+    --checkpointing_steps=500 \
+    --center_crop \
+    --mixed_precision=bf16 \
+    --prior_generation_precision=bf16 \
+    --sample_batch_size=5 \
+    --validation_prompt="a photo of sks person" \
+    --num_validation_images 10 \
+    --validation_steps 500
+
+
+
+# ------------------------- Clear CLASS_DIR after training -------------------------
+# Note: this runs after the accelerate launch command above completes
+echo "Clearing class directory: $CLASS_DIR"
+# Make sure the directory exists so the cleanup command cannot fail
+mkdir -p "$CLASS_DIR"
+# Find and delete every file and subdirectory inside (but not . or ..)
+find "$CLASS_DIR" -mindepth 1 -delete
+
+echo "Script finished."
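Because db_gen.sh hardcodes TASKNAME and TYPE, judging how well a perturbation worked means running the script twice, once per TYPE, and comparing what the two fine-tuned models generate for the same prompt. A hypothetical worker-side loop, assuming the hardcoded exports are loosened to the ${TYPE:-clean} form so values passed through the environment take precedence:

import os
import subprocess

def finetune_both_variants(task_name: str, script: str = 'db_gen.sh') -> None:
    """Fine-tune once on the originals and once on the perturbed copies."""
    for data_type in ('clean', 'perturbed'):
        env = {**os.environ, 'TASKNAME': task_name, 'TYPE': data_type}
        # check=True raises on a non-zero exit, so the batch can be marked failed
        subprocess.run(['bash', script], env=env, check=True)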
\ No newline at end of file diff --git a/src/backend/app/scripts/db_infer.sh b/src/backend/app/scripts/db_infer.sh new file mode 100644 index 0000000..b78d260 --- /dev/null +++ b/src/backend/app/scripts/db_infer.sh @@ -0,0 +1,85 @@ +#需要环境:conda activate simac +export HF_HUB_OFFLINE=1 +export MODEL_PATH="../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06" + + + +export TASKNAME="task001" +export TYPE="perturbed" #clean or perturbed + + + +if [ "$TYPE" == "clean" ]; then + export INSTANCE_DIR="../../static/originals/${TASKNAME}" +else + export INSTANCE_DIR="../../static/perturbed/${TASKNAME}" +fi +export DREAMBOOTH_OUTPUT_DIR="../../static/hf_models/fine_tuned/${TYPE}/${TASKNAME}" +export OUTPUT_INFER_DIR="../../static/model_outputs/${TYPE}/${TASKNAME}" +export CLASS_DIR="../../static/class/${TASKNAME}" + + +# ------------------------- 自动创建依赖路径 ------------------------- +echo "Creating required directories..." +mkdir -p "$INSTANCE_DIR" +mkdir -p "$DREAMBOOTH_OUTPUT_DIR" +mkdir -p "$OUTPUT_INFER_DIR" +mkdir -p "$CLASS_DIR" +echo "Directories created successfully." + + + +# ------------------------- 训练前清空 OUTPUT_DIR ------------------------- +echo "Clearing output directory: $DREAMBOOTH_OUTPUT_DIR and $OUTPUT_INFER_DIR" +# 确保目录存在,避免清理命令失败 +mkdir -p "$DREAMBOOTH_OUTPUT_DIR" +mkdir -p "$OUTPUT_INFER_DIR" +# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) +find "$DREAMBOOTH_OUTPUT_DIR" -mindepth 1 -delete +find "$OUTPUT_INFER_DIR" -mindepth 1 -delete + + +# ------------------------- Fine-tune DreamBooth on images ------------------------- +accelerate launch ../finetune_infras/train_dreambooth_alone.py \ + --pretrained_model_name_or_path=$MODEL_PATH \ + --enable_xformers_memory_efficient_attention \ + --train_text_encoder \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$DREAMBOOTH_OUTPUT_DIR \ + --with_prior_preservation \ + --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks person" \ + --class_prompt="a photo of person" \ + --inference_prompt="a photo of sks person" \ + --resolution=384 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=2 \ + --learning_rate=1e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=1000 \ + --checkpointing_steps=1000 \ + --center_crop \ + --mixed_precision=bf16 \ + --prior_generation_precision=bf16 \ + --sample_batch_size=1 \ + --seed=0 + +# ------------------------- Inference ------------------------- +python ../finetune_infras/infer.py \ + --model_path $DREAMBOOTH_OUTPUT_DIR \ + --output_dir $OUTPUT_INFER_DIR + + + +# ------------------------- 训练后清空 CLASS_DIR ------------------------- +# 注意:这会在 accelerate launch 成功结束后执行 +echo "Clearing class directory: $CLASS_DIR" +# 确保目录存在,避免清理命令失败 +mkdir -p "$CLASS_DIR" +# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) +find "$CLASS_DIR" -mindepth 1 -delete + +echo "Script finished." 
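One caveat with db_infer.sh: without set -e, the inference step and the CLASS_DIR cleanup still run when accelerate launch exits non-zero, so infer.py can end up reading a half-written model directory. Adding set -euo pipefail near the top of these scripts would abort on the first failure; alternatively, a worker can sequence the two stages itself. A minimal sketch of the latter:

import subprocess

def train_then_infer(train_cmd: list[str], infer_cmd: list[str]) -> None:
    """Run fine-tuning, then inference, aborting if the first step fails."""
    subprocess.run(train_cmd, check=True)  # raises CalledProcessError on failure,
    subprocess.run(infer_cmd, check=True)  # so inference only sees a complete model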
\ No newline at end of file
diff --git a/src/backend/app/scripts/lora_gen.sh b/src/backend/app/scripts/lora_gen.sh
new file mode 100644
index 0000000..2237dd1
--- /dev/null
+++ b/src/backend/app/scripts/lora_gen.sh
@@ -0,0 +1,84 @@
+# Required environment: conda activate pid
+### Training the model
+export HF_HUB_OFFLINE=1
+# Force use of the local model cache so no model is downloaded from the network
+
+### SD v2.1
+# export HF_HOME="/root/autodl-tmp/huggingface_cache"
+# export MODEL_PATH="stabilityai/stable-diffusion-2-1"
+
+### SD v1.5
+# export HF_HOME="/root/autodl-tmp/huggingface_cache"
+# export MODEL_PATH="runwayml/stable-diffusion-v1-5"
+export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
+
+
+export TASKNAME="task001"
+export TYPE="clean"  # clean or perturbed
+
+
+if [ "$TYPE" == "clean" ]; then
+    export INSTANCE_DIR="../../static/originals/${TASKNAME}"
+else
+    export INSTANCE_DIR="../../static/perturbed/${TASKNAME}"
+fi
+export DREAMBOOTH_OUTPUT_DIR="../../static/hf_models/fine_tuned/${TYPE}/${TASKNAME}"
+export OUTPUT_INFER_DIR="../../static/model_outputs/${TYPE}/${TASKNAME}"
+export CLASS_DIR="../../static/class/${TASKNAME}"
+
+
+# ------------------------- Auto-create required directories -------------------------
+echo "Creating required directories..."
+mkdir -p "$INSTANCE_DIR"
+mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
+mkdir -p "$OUTPUT_INFER_DIR"
+mkdir -p "$CLASS_DIR"
+echo "Directories created successfully."
+
+# ------------------------- Clear the output directories before training -------------------------
+echo "Clearing output directory: $DREAMBOOTH_OUTPUT_DIR and $OUTPUT_INFER_DIR"
+# Make sure the directories exist so the cleanup commands cannot fail
+mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
+mkdir -p "$OUTPUT_INFER_DIR"
+# Find and delete every file and subdirectory inside (but not . or ..)
+find "$DREAMBOOTH_OUTPUT_DIR" -mindepth 1 -delete
+find "$OUTPUT_INFER_DIR" -mindepth 1 -delete
+
+
+
+# ------------------------- Fine-tune LoRA on images -------------------------
+CUDA_VISIBLE_DEVICES=0 accelerate launch ../finetune_infras/train_lora_gen.py \
+    --pretrained_model_name_or_path=$MODEL_PATH \
+    --instance_data_dir=$INSTANCE_DIR \
+    --class_data_dir=$CLASS_DIR \
+    --output_dir=$DREAMBOOTH_OUTPUT_DIR \
+    --validation_image_output_dir=$OUTPUT_INFER_DIR \
+    --with_prior_preservation --prior_loss_weight=1.0 \
+    --instance_prompt="a photo of sks person" \
+    --class_prompt="a photo of person" \
+    --resolution=512 \
+    --train_batch_size=1 \
+    --gradient_accumulation_steps=1 \
+    --learning_rate=1e-4 \
+    --lr_scheduler="constant" \
+    --lr_warmup_steps=0 \
+    --num_class_images=200 \
+    --max_train_steps=1000 \
+    --checkpointing_steps=500 \
+    --seed=0 \
+    --mixed_precision=fp16 \
+    --rank=4 \
+    --validation_prompt="a photo of sks person" \
+    --num_validation_images 10
+
+
+
+# ------------------------- Clear CLASS_DIR after training -------------------------
+# Note: this runs after the accelerate launch command above completes
+echo "Clearing class directory: $CLASS_DIR"
+# Make sure the directory exists so the cleanup command cannot fail
+mkdir -p "$CLASS_DIR"
+# Find and delete every file and subdirectory inside (but not . or ..)
+find "$CLASS_DIR" -mindepth 1 -delete
+
+echo "Script finished."
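The scripts above are the compute half of the pipeline; the task endpoints earlier in this patch are what drive them. A minimal client walkthrough, strictly as a sketch: the /api prefix, the token placeholder, and the path-parameter shapes are assumptions to check against the blueprint registration in app/__init__.py.

import requests

BASE = 'http://localhost:5000/api'
token = '...'  # JWT from the auth endpoints; login flow omitted here
h = {'Authorization': f'Bearer {token}'}

# create a task, upload an image, adjust its config, start it, poll its status
task_id = requests.post(f'{BASE}/task/create', json={'batch_name': 'demo'}, headers=h).json()['task']['id']
with open('portrait.png', 'rb') as f:
    requests.post(f'{BASE}/task/upload/{task_id}', files={'files': f}, headers=h)
requests.put(f'{BASE}/task/{task_id}/config', json={'epsilon': 8.0}, headers=h)
requests.post(f'{BASE}/task/start/{task_id}', headers=h)
print(requests.get(f'{BASE}/task/{task_id}/status', headers=h).json())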
\ No newline at end of file diff --git a/src/backend/app/services/auth_service.py b/src/backend/app/services/auth_service.py index 63a95c1..756d907 100644 --- a/src/backend/app/services/auth_service.py +++ b/src/backend/app/services/auth_service.py @@ -1,34 +1,34 @@ -""" -认证服务 -处理用户认证相关逻辑 -""" - -from app.models import User - -class AuthService: - """认证服务类""" - - @staticmethod - def authenticate_user(username, password): - """验证用户凭据""" - user = User.query.filter_by(username=username).first() - - if user and user.check_password(password) and user.is_active: - return user - - return None - - @staticmethod - def get_user_by_id(user_id): - """根据ID获取用户""" - return User.query.get(user_id) - - @staticmethod - def is_email_available(email): - """检查邮箱是否可用""" - return User.query.filter_by(email=email).first() is None - - @staticmethod - def is_username_available(username): - """检查用户名是否可用""" +""" +认证服务 +处理用户认证相关逻辑 +""" + +from app.database import User + +class AuthService: + """认证服务类""" + + @staticmethod + def authenticate_user(username, password): + """验证用户凭据""" + user = User.query.filter_by(username=username).first() + + if user and user.check_password(password) and user.is_active: + return user + + return None + + @staticmethod + def get_user_by_id(user_id): + """根据ID获取用户""" + return User.query.get(user_id) + + @staticmethod + def is_email_available(email): + """检查邮箱是否可用""" + return User.query.filter_by(email=email).first() is None + + @staticmethod + def is_username_available(username): + """检查用户名是否可用""" return User.query.filter_by(username=username).first() is None \ No newline at end of file diff --git a/src/backend/app/services/image_service.py b/src/backend/app/services/image_service.py index 914533d..b8857e5 100644 --- a/src/backend/app/services/image_service.py +++ b/src/backend/app/services/image_service.py @@ -1,161 +1,161 @@ -""" -图像处理服务 -处理图像上传、保存等功能 -""" - -import os -import uuid -import zipfile -from werkzeug.utils import secure_filename -from flask import current_app -from PIL import Image as PILImage -from app import db -from app.models import Image -from app.utils.file_utils import allowed_file - -class ImageService: - """图像处理服务""" - - @staticmethod - def save_image(file, batch_id, user_id, image_type_id): - """保存单张图片""" - try: - if not file or not allowed_file(file.filename): - return {'success': False, 'error': '不支持的文件格式'} - - # 生成唯一文件名 - file_extension = os.path.splitext(file.filename)[1].lower() - stored_filename = f"{uuid.uuid4().hex}{file_extension}" - - # 临时保存到上传目录 - project_root = os.path.dirname(current_app.root_path) - temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id)) - os.makedirs(temp_dir, exist_ok=True) - temp_path = os.path.join(temp_dir, stored_filename) - file.save(temp_path) - - # 移动到对应的静态目录 - static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(batch_id)) - os.makedirs(static_dir, exist_ok=True) - file_path = os.path.join(static_dir, stored_filename) - - # 移动文件到最终位置 - import shutil - shutil.move(temp_path, file_path) - - # 获取图片尺寸 - try: - with PILImage.open(file_path) as img: - width, height = img.size - except: - width, height = None, None - - # 创建数据库记录 - image = Image( - user_id=user_id, - batch_id=batch_id, - original_filename=file.filename, - stored_filename=stored_filename, - file_path=file_path, - file_size=os.path.getsize(file_path), - image_type_id=image_type_id, - width=width, - height=height - ) - - db.session.add(image) - db.session.commit() - - 
return {'success': True, 'image': image} - - except Exception as e: - db.session.rollback() - return {'success': False, 'error': f'保存图片失败: {str(e)}'} - - @staticmethod - def extract_and_save_zip(zip_file, batch_id, user_id, image_type_id): - """解压并保存压缩包中的图片""" - results = [] - temp_dir = None - - try: - # 创建临时目录 - project_root = os.path.dirname(current_app.root_path) - temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], 'temp', f"{uuid.uuid4().hex}") - os.makedirs(temp_dir, exist_ok=True) - - # 保存压缩包 - zip_path = os.path.join(temp_dir, secure_filename(zip_file.filename)) - zip_file.save(zip_path) - - # 解压文件 - with zipfile.ZipFile(zip_path, 'r') as zip_ref: - zip_ref.extractall(temp_dir) - - # 遍历解压的文件 - for root, dirs, files in os.walk(temp_dir): - for filename in files: - if filename.lower().endswith(('.zip', '.rar')): - continue # 跳过压缩包文件本身 - - if allowed_file(filename): - file_path = os.path.join(root, filename) - - # 创建虚拟文件对象 - class FileWrapper: - def __init__(self, path, name): - self.path = path - self.filename = name - - def save(self, destination): - import shutil - shutil.copy2(self.path, destination) - - virtual_file = FileWrapper(file_path, filename) - result = ImageService.save_image(virtual_file, batch_id, user_id, image_type_id) - results.append(result) - - return results - - except Exception as e: - return [{'success': False, 'error': f'解压失败: {str(e)}'}] - - finally: - # 清理临时文件 - if temp_dir and os.path.exists(temp_dir): - import shutil - try: - shutil.rmtree(temp_dir) - except: - pass - - @staticmethod - def get_image_url(image): - """获取图片访问URL""" - if not image or not image.file_path: - return None - - # 这里返回相对路径,前端可以拼接完整URL - return f"/api/image/file/{image.id}" - - @staticmethod - def delete_image(image_id, user_id): - """删除图片""" - try: - image = Image.query.filter_by(id=image_id, user_id=user_id).first() - if not image: - return {'success': False, 'error': '图片不存在或无权限'} - - # 删除文件 - if os.path.exists(image.file_path): - os.remove(image.file_path) - - # 删除数据库记录 - db.session.delete(image) - db.session.commit() - - return {'success': True} - - except Exception as e: - db.session.rollback() +""" +图像处理服务 +处理图像上传、保存等功能 +""" + +import os +import uuid +import zipfile +from werkzeug.utils import secure_filename +from flask import current_app +from PIL import Image as PILImage +from app import db +from app.database import Image +from app.utils.file_utils import allowed_file + +class ImageService: + """图像处理服务""" + + @staticmethod + def save_image(file, batch_id, user_id, image_type_id): + """保存单张图片""" + try: + if not file or not allowed_file(file.filename): + return {'success': False, 'error': '不支持的文件格式'} + + # 生成唯一文件名 + file_extension = os.path.splitext(file.filename)[1].lower() + stored_filename = f"{uuid.uuid4().hex}{file_extension}" + + # 临时保存到上传目录 + project_root = os.path.dirname(current_app.root_path) + temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id)) + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, stored_filename) + file.save(temp_path) + + # 移动到对应的静态目录 + static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(batch_id)) + os.makedirs(static_dir, exist_ok=True) + file_path = os.path.join(static_dir, stored_filename) + + # 移动文件到最终位置 + import shutil + shutil.move(temp_path, file_path) + + # 获取图片尺寸 + try: + with PILImage.open(file_path) as img: + width, height = img.size + except: + width, height = None, None + + # 创建数据库记录 + 
image = Image( + user_id=user_id, + batch_id=batch_id, + original_filename=file.filename, + stored_filename=stored_filename, + file_path=file_path, + file_size=os.path.getsize(file_path), + image_type_id=image_type_id, + width=width, + height=height + ) + + db.session.add(image) + db.session.commit() + + return {'success': True, 'image': image} + + except Exception as e: + db.session.rollback() + return {'success': False, 'error': f'保存图片失败: {str(e)}'} + + @staticmethod + def extract_and_save_zip(zip_file, batch_id, user_id, image_type_id): + """解压并保存压缩包中的图片""" + results = [] + temp_dir = None + + try: + # 创建临时目录 + project_root = os.path.dirname(current_app.root_path) + temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], 'temp', f"{uuid.uuid4().hex}") + os.makedirs(temp_dir, exist_ok=True) + + # 保存压缩包 + zip_path = os.path.join(temp_dir, secure_filename(zip_file.filename)) + zip_file.save(zip_path) + + # 解压文件 + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + + # 遍历解压的文件 + for root, dirs, files in os.walk(temp_dir): + for filename in files: + if filename.lower().endswith(('.zip', '.rar')): + continue # 跳过压缩包文件本身 + + if allowed_file(filename): + file_path = os.path.join(root, filename) + + # 创建虚拟文件对象 + class FileWrapper: + def __init__(self, path, name): + self.path = path + self.filename = name + + def save(self, destination): + import shutil + shutil.copy2(self.path, destination) + + virtual_file = FileWrapper(file_path, filename) + result = ImageService.save_image(virtual_file, batch_id, user_id, image_type_id) + results.append(result) + + return results + + except Exception as e: + return [{'success': False, 'error': f'解压失败: {str(e)}'}] + + finally: + # 清理临时文件 + if temp_dir and os.path.exists(temp_dir): + import shutil + try: + shutil.rmtree(temp_dir) + except: + pass + + @staticmethod + def get_image_url(image): + """获取图片访问URL""" + if not image or not image.file_path: + return None + + # 这里返回相对路径,前端可以拼接完整URL + return f"/api/image/file/{image.id}" + + @staticmethod + def delete_image(image_id, user_id): + """删除图片""" + try: + image = Image.query.filter_by(id=image_id, user_id=user_id).first() + if not image: + return {'success': False, 'error': '图片不存在或无权限'} + + # 删除文件 + if os.path.exists(image.file_path): + os.remove(image.file_path) + + # 删除数据库记录 + db.session.delete(image) + db.session.commit() + + return {'success': True} + + except Exception as e: + db.session.rollback() return {'success': False, 'error': f'删除图片失败: {str(e)}'} \ No newline at end of file diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index b9b00e0..aa9ab39 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -1,191 +1,534 @@ -""" -任务处理服务 -处理图像加噪、评估等核心业务逻辑 -""" - -import os -import time -from datetime import datetime -from flask import current_app -from app import db -from app.models import Batch, Image, EvaluationResult, ImageType -from app.algorithms.perturbation_engine import PerturbationEngine -from app.algorithms.evaluation_engine import EvaluationEngine - -class TaskService: - """任务处理服务""" - - @staticmethod - def start_processing(batch): - """开始处理任务""" - try: - # 更新任务状态 - batch.status = 'processing' - batch.started_at = datetime.utcnow() - db.session.commit() - - # 获取任务相关的原始图片 - original_images = Image.query.filter_by( - batch_id=batch.id - ).join(ImageType).filter( - ImageType.type_code == 'original' - ).all() - - if not original_images: - batch.status = 
'failed' - batch.error_message = '没有找到原始图片' - batch.completed_at = datetime.utcnow() - db.session.commit() - return False - - # 处理每张图片 - perturbed_type = ImageType.query.filter_by(type_code='perturbed').first() - - processed_images = [] - for original_image in original_images: - try: - # 确定加噪图片的保存路径 - project_root = os.path.dirname(current_app.root_path) - perturbed_dir = os.path.join(project_root, - current_app.config['PERTURBED_IMAGES_FOLDER'], - str(batch.user_id), - str(batch.id)) - os.makedirs(perturbed_dir, exist_ok=True) - - # 调用加噪算法 - perturbed_image_path = PerturbationEngine.apply_perturbation( - original_image.file_path, - batch.perturbation_config.method_code, - float(batch.preferred_epsilon), - batch.use_strong_protection - ) - - if perturbed_image_path: - # 保存加噪后的图片记录 - perturbed_image = Image( - user_id=batch.user_id, - batch_id=batch.id, - father_id=original_image.id, - original_filename=f"perturbed_{original_image.original_filename}", - stored_filename=os.path.basename(perturbed_image_path), - file_path=perturbed_image_path, - file_size=os.path.getsize(perturbed_image_path) if os.path.exists(perturbed_image_path) else 0, - image_type_id=perturbed_type.id, - width=original_image.width, - height=original_image.height - ) - - db.session.add(perturbed_image) - processed_images.append((original_image, perturbed_image)) - - except Exception as e: - print(f"处理图片 {original_image.id} 时出错: {str(e)}") - continue - - # 提交加噪后的图片 - db.session.commit() - - # 生成评估结果 - TaskService._generate_evaluations(batch, processed_images) - - # 更新任务状态为完成 - batch.status = 'completed' - batch.completed_at = datetime.utcnow() - db.session.commit() - - return True - - except Exception as e: - # 处理失败 - batch.status = 'failed' - batch.error_message = str(e) - batch.completed_at = datetime.utcnow() - db.session.commit() - return False - - @staticmethod - def _generate_evaluations(batch, processed_images): - """生成评估结果""" - try: - for original_image, perturbed_image in processed_images: - # 图像质量对比评估 - quality_metrics = EvaluationEngine.evaluate_image_quality( - original_image.file_path, - perturbed_image.file_path - ) - - quality_evaluation = EvaluationResult( - reference_image_id=original_image.id, - target_image_id=perturbed_image.id, - evaluation_type='image_quality', - purification_applied=False, - fid_score=quality_metrics.get('fid'), - lpips_score=quality_metrics.get('lpips'), - ssim_score=quality_metrics.get('ssim'), - psnr_score=quality_metrics.get('psnr'), - heatmap_path=quality_metrics.get('heatmap_path') - ) - - db.session.add(quality_evaluation) - - # 模型生成对比评估 - generation_metrics = EvaluationEngine.evaluate_model_generation( - original_image.file_path, - perturbed_image.file_path, - batch.finetune_config.method_code - ) - - generation_evaluation = EvaluationResult( - reference_image_id=original_image.id, - target_image_id=perturbed_image.id, - evaluation_type='model_generation', - purification_applied=False, - fid_score=generation_metrics.get('fid'), - lpips_score=generation_metrics.get('lpips'), - ssim_score=generation_metrics.get('ssim'), - psnr_score=generation_metrics.get('psnr'), - heatmap_path=generation_metrics.get('heatmap_path') - ) - - db.session.add(generation_evaluation) - - db.session.commit() - - except Exception as e: - print(f"生成评估结果时出错: {str(e)}") - - @staticmethod - def get_processing_progress(batch_id): - """获取处理进度""" - try: - batch = Batch.query.get(batch_id) - if not batch: - return 0 - - if batch.status == 'pending': - return 0 - elif batch.status == 'completed': - return 100 
- elif batch.status == 'failed': - return 0 - elif batch.status == 'processing': - # 简单的进度计算:根据已处理的图片数量 - total_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter( - ImageType.type_code == 'original' - ).count() - - processed_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter( - ImageType.type_code == 'perturbed' - ).count() - - if total_images == 0: - return 0 - - progress = int((processed_images / total_images) * 80) # 80%用于图像处理,20%用于评估 - return min(progress, 95) # 最多95%,剩余5%用于最终完成 - - return 0 - - except Exception as e: - print(f"获取处理进度时出错: {str(e)}") - return 0 \ No newline at end of file +""" +任务处理服务 +处理图像加噪、评估等核心业务逻辑 +使用Redis Queue进行异步任务处理 +""" + +import os +from datetime import datetime +from flask import current_app +from redis import Redis +from rq import Queue +from rq.job import Job +from app import db +from app.database import Batch, Image, EvaluationResult, ImageType +from config.algorithm_config import AlgorithmConfig + +class TaskService: + """任务处理服务""" + + @staticmethod + def _get_redis_connection(): + """获取Redis连接""" + return Redis.from_url(AlgorithmConfig.REDIS_URL) + + @staticmethod + def _get_queue(): + """获取RQ队列""" + redis_conn = TaskService._get_redis_connection() + return Queue(AlgorithmConfig.RQ_QUEUE_NAME, connection=redis_conn) + + @staticmethod + def start_processing(batch): + """ + 开始处理任务(异步) + + Args: + batch: Batch对象 + + Returns: + 任务ID (RQ job id) + """ + try: + # 检查是否有原始图片 + original_images = Image.query.filter_by( + batch_id=batch.id + ).join(ImageType).filter( + ImageType.type_code == 'original' + ).all() + + if not original_images: + batch.status = 'failed' + batch.error_message = '没有找到原始图片' + batch.completed_at = datetime.utcnow() + db.session.commit() + return None + + # 准备任务参数 + project_root = os.path.dirname(current_app.root_path) + + # 输入目录(原始图片) + input_dir = os.path.join( + project_root, + current_app.config['ORIGINAL_IMAGES_FOLDER'], + str(batch.user_id), + str(batch.id) + ) + + # 输出目录(扰动后图片) + output_dir = os.path.join( + project_root, + current_app.config['PERTURBED_IMAGES_FOLDER'], + str(batch.user_id), + str(batch.id) + ) + + # 类别图片目录(用于prior preservation) + class_dir = os.path.join( + project_root, + 'static', 'class', + str(batch.user_id), + str(batch.id) + ) + + # 获取队列 + queue = TaskService._get_queue() + + # 提交任务到队列 + from app.workers.perturbation_worker import run_perturbation_task + + job = queue.enqueue( + run_perturbation_task, + batch_id=batch.id, + algorithm_code=batch.perturbation_config.method_code, + epsilon=float(batch.preferred_epsilon), + use_strong_protection=batch.use_strong_protection, + input_dir=input_dir, + output_dir=output_dir, + class_dir=class_dir, + custom_params=None, + job_timeout=AlgorithmConfig.TASK_TIMEOUT, + job_id=f"batch_{batch.id}" + ) + + # 更新任务状态 + batch.status = 'queued' + db.session.commit() + + return job.id + + except Exception as e: + # 处理失败 + batch.status = 'failed' + batch.error_message = str(e) + batch.completed_at = datetime.utcnow() + db.session.commit() + return None + + @staticmethod + def get_task_status(batch_id): + """ + 获取任务状态 + + Args: + batch_id: 批次ID + + Returns: + 任务状态信息 + """ + try: + batch = Batch.query.get(batch_id) + if not batch: + return {'status': 'not_found'} + + # 如果任务已完成或失败,直接返回数据库状态 + if batch.status in ['completed', 'failed']: + return { + 'status': batch.status, + 'error': batch.error_message if batch.status == 'failed' else None, + 'started_at': batch.started_at, + 'completed_at': batch.completed_at + } + + # 尝试从RQ获取任务状态 + try: + 
redis_conn = TaskService._get_redis_connection()
+                job = Job.fetch(f"batch_{batch_id}", connection=redis_conn)
+
+                rq_status = job.get_status()
+
+                # Map RQ states onto our own status values
+                status_map = {
+                    'queued': 'queued',
+                    'started': 'processing',
+                    'finished': 'completed',
+                    'failed': 'failed'
+                }
+
+                return {
+                    'status': status_map.get(rq_status, batch.status),
+                    'rq_status': rq_status,
+                    'progress': job.meta.get('progress', 0) if hasattr(job, 'meta') else 0,
+                    'started_at': batch.started_at,
+                    'result': job.result if rq_status == 'finished' else None,
+                    'error': str(job.exc_info) if rq_status == 'failed' else None
+                }
+            except:
+                # If the RQ status is unavailable, fall back to the database status
+                return {
+                    'status': batch.status,
+                    'started_at': batch.started_at
+                }
+
+        except Exception as e:
+            return {'status': 'error', 'error': str(e)}
+
+    @staticmethod
+    def cancel_task(batch_id):
+        """
+        Cancel a task
+
+        Args:
+            batch_id: batch ID
+
+        Returns:
+            whether the task was cancelled successfully
+        """
+        try:
+            batch = Batch.query.get(batch_id)
+            if not batch:
+                return False
+
+            # Try to remove the job from the queue
+            try:
+                redis_conn = TaskService._get_redis_connection()
+                job = Job.fetch(f"batch_{batch_id}", connection=redis_conn)
+                job.cancel()
+            except:
+                pass
+
+            # Update the database status
+            batch.status = 'failed'
+            batch.error_message = 'Task cancelled by user'
+            batch.completed_at = datetime.utcnow()
+            db.session.commit()
+
+            return True
+
+        except Exception as e:
+            print(f"Error while cancelling task: {str(e)}")
+            return False
+
+    @staticmethod
+    def process_results_and_evaluations(batch_id):
+        """
+        Process task results and generate evaluations (called after the worker finishes)
+
+        Args:
+            batch_id: batch ID
+        """
+        try:
+            batch = Batch.query.get(batch_id)
+            if not batch:
+                return
+
+            # Locate the images in the output directory
+            project_root = os.path.dirname(current_app.root_path)
+            output_dir = os.path.join(
+                project_root,
+                current_app.config['PERTURBED_IMAGES_FOLDER'],
+                str(batch.user_id),
+                str(batch.id)
+            )
+
+            # Fetch the original images
+            original_images = Image.query.filter_by(
+                batch_id=batch.id
+            ).join(ImageType).filter(
+                ImageType.type_code == 'original'
+            ).all()
+
+            perturbed_type = ImageType.query.filter_by(type_code='perturbed').first()
+
+            processed_images = []
+
+            # Match each original image with its perturbed counterpart
+            for original_image in original_images:
+                # Build the expected perturbed-image path
+                original_name = os.path.splitext(original_image.original_filename)[0]
+                original_ext = os.path.splitext(original_image.original_filename)[1]
+
+                perturbed_filename = f"{original_name}_perturbed{original_ext}"
+                perturbed_path = os.path.join(output_dir, perturbed_filename)
+
+                if os.path.exists(perturbed_path):
+                    # Record the perturbed image in the database
+                    perturbed_image = Image(
+                        user_id=batch.user_id,
+                        batch_id=batch.id,
+                        father_id=original_image.id,
+                        original_filename=f"perturbed_{original_image.original_filename}",
+                        stored_filename=os.path.basename(perturbed_path),
+                        file_path=perturbed_path,
+                        file_size=os.path.getsize(perturbed_path),
+                        image_type_id=perturbed_type.id,
+                        width=original_image.width,
+                        height=original_image.height
+                    )
+
+                    db.session.add(perturbed_image)
+                    processed_images.append((original_image, perturbed_image))
+
+            db.session.commit()
+
+            # Generate evaluation results
+            TaskService._generate_evaluations(batch, processed_images)
+
+        except Exception as e:
+            print(f"Error while processing results: {str(e)}")
+
+    @staticmethod
+    def _generate_evaluations(batch, processed_images):
+        """Generate evaluation results (virtual implementation)"""
+        try:
+            for original_image, perturbed_image in processed_images:
+                # TODO: wire up the real evaluation engine
+                # For now this produces dummy data
+                import random
+
+                # Image-quality comparison (dummy data)
+                quality_evaluation = EvaluationResult(
+                    reference_image_id=original_image.id,
+                    target_image_id=perturbed_image.id,
+                    evaluation_type='image_quality',
+                    purification_applied=False,
+
fid_score=round(random.uniform(0.1, 0.5), 4), + lpips_score=round(random.uniform(0.01, 0.1), 4), + ssim_score=round(random.uniform(0.85, 0.99), 4), + psnr_score=round(random.uniform(30, 45), 2), + heatmap_path=None + ) + + db.session.add(quality_evaluation) + + # 模型生成对比评估(虚拟数据) + generation_evaluation = EvaluationResult( + reference_image_id=original_image.id, + target_image_id=perturbed_image.id, + evaluation_type='model_generation', + purification_applied=False, + fid_score=round(random.uniform(0.2, 0.8), 4), + lpips_score=round(random.uniform(0.05, 0.2), 4), + ssim_score=round(random.uniform(0.7, 0.9), 4), + psnr_score=round(random.uniform(25, 40), 2), + heatmap_path=None + ) + + db.session.add(generation_evaluation) + + db.session.commit() + + except Exception as e: + print(f"生成评估结果时出错: {str(e)}") + + @staticmethod + def get_processing_progress(batch_id): + """获取处理进度""" + try: + batch = Batch.query.get(batch_id) + if not batch: + return 0 + + if batch.status == 'pending': + return 0 + elif batch.status == 'completed': + return 100 + elif batch.status == 'failed': + return 0 + elif batch.status == 'processing': + # 简单的进度计算:根据已处理的图片数量 + total_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter( + ImageType.type_code == 'original' + ).count() + + processed_images = Image.query.filter_by(batch_id=batch_id).join(ImageType).filter( + ImageType.type_code == 'perturbed' + ).count() + + if total_images == 0: + return 0 + + progress = int((processed_images / total_images) * 80) # 80%用于图像处理,20%用于评估 + return min(progress, 95) # 最多95%,剩余5%用于最终完成 + + return 0 + + except Exception as e: + print(f"获取处理进度时出错: {str(e)}") + return 0 + + @staticmethod + def start_finetune_and_evaluation(batch): + """ + 启动微调和评估任务 + + 此方法在扰动任务完成后调用,分别使用原始图片和扰动图片微调模型, + 然后对比生成效果 + + Args: + batch: Batch对象 + + Returns: + 包含两个job_id的字典 + """ + try: + # 检查是否有扰动图片 + perturbed_images = Image.query.filter_by( + batch_id=batch.id + ).join(ImageType).filter( + ImageType.type_code == 'perturbed' + ).all() + + if not perturbed_images: + print(f"Batch {batch.id} 没有扰动图片,无法启动微调任务") + return None + + project_root = os.path.dirname(current_app.root_path) + finetune_method = batch.finetune_config.method_code + queue = TaskService._get_queue() + + from app.workers.finetune_worker import run_finetune_task + + # 原始图片目录 + original_dir = os.path.join( + project_root, + current_app.config['ORIGINAL_IMAGES_FOLDER'], + str(batch.user_id), + str(batch.id) + ) + + # 扰动图片目录 + perturbed_dir = os.path.join( + project_root, + current_app.config['PERTURBED_IMAGES_FOLDER'], + str(batch.user_id), + str(batch.id) + ) + + # 模型输出目录 + original_model_dir = os.path.join( + project_root, 'static', 'models', 'original', + str(batch.user_id), str(batch.id) + ) + + perturbed_model_dir = os.path.join( + project_root, 'static', 'models', 'perturbed', + str(batch.user_id), str(batch.id) + ) + + # 类别图片目录(微调用) + class_finetune_dir = os.path.join( + project_root, 'static', 'class_finetune', + str(batch.user_id), str(batch.id) + ) + + # 推理提示词 + inference_prompts = "a photo of sks person" + + # 1. 用原始图片微调模型 + job_original = queue.enqueue( + run_finetune_task, + batch_id=batch.id, + finetune_method=finetune_method, + train_images_dir=original_dir, + output_model_dir=original_model_dir, + class_dir=class_finetune_dir, + inference_prompts=inference_prompts, + is_perturbed=False, + custom_params=None, + job_timeout=AlgorithmConfig.TASK_TIMEOUT, + job_id=f"finetune_original_{batch.id}" + ) + + # 2. 
用扰动图片微调模型(依赖于原始图片微调完成) + job_perturbed = queue.enqueue( + run_finetune_task, + batch_id=batch.id, + finetune_method=finetune_method, + train_images_dir=perturbed_dir, + output_model_dir=perturbed_model_dir, + class_dir=class_finetune_dir, + inference_prompts=inference_prompts, + is_perturbed=True, + custom_params=None, + job_timeout=AlgorithmConfig.TASK_TIMEOUT, + job_id=f"finetune_perturbed_{batch.id}", + depends_on=job_original # 依赖关系 + ) + + return { + 'original_job_id': job_original.id, + 'perturbed_job_id': job_perturbed.id + } + + except Exception as e: + print(f"启动微调任务时出错: {str(e)}") + return None + + @staticmethod + def generate_final_evaluations(batch_id): + """ + 生成最终评估(对比原始和扰动图片微调后的模型生成效果) + + 此方法在两个微调任务都完成后调用 + + Args: + batch_id: 批次ID + """ + try: + batch = Batch.query.get(batch_id) + if not batch: + return + + # 获取原始图片生成的结果 + original_generated = Image.query.filter_by( + batch_id=batch_id + ).join(ImageType).filter( + ImageType.type_code == 'original_generate' + ).all() + + # 获取扰动图片生成的结果 + perturbed_generated = Image.query.filter_by( + batch_id=batch_id + ).join(ImageType).filter( + ImageType.type_code == 'perturbed_generate' + ).all() + + if not original_generated or not perturbed_generated: + print(f"Batch {batch_id} 缺少生成的图片,无法评估") + return + + # 配对评估 + for orig_gen in original_generated: + # 找到对应的扰动生成图片(基于相同的父图片) + matching_pert_gen = None + for pert_gen in perturbed_generated: + # 尝试匹配文件名或父图片关系 + if pert_gen.original_filename.replace('generated_', '') == orig_gen.original_filename.replace('generated_', ''): + matching_pert_gen = pert_gen + break + + if matching_pert_gen: + # TODO: 实现真实的评估引擎 + # 目前使用虚拟数据 + import random + + # 保存评估结果(虚拟数据) + generation_evaluation = EvaluationResult( + reference_image_id=orig_gen.id, + target_image_id=matching_pert_gen.id, + evaluation_type='model_generation', + purification_applied=False, + fid_score=round(random.uniform(0.3, 0.9), 4), + lpips_score=round(random.uniform(0.1, 0.3), 4), + ssim_score=round(random.uniform(0.6, 0.85), 4), + psnr_score=round(random.uniform(20, 35), 2), + heatmap_path=None + ) + + db.session.add(generation_evaluation) + + db.session.commit() + print(f"Batch {batch_id} 最终评估完成") + + except Exception as e: + print(f"生成最终评估时出错: {str(e)}") + db.session.rollback() \ No newline at end of file diff --git a/src/backend/app/utils/file_utils.py b/src/backend/app/utils/file_utils.py index f23d489..1748e05 100644 --- a/src/backend/app/utils/file_utils.py +++ b/src/backend/app/utils/file_utils.py @@ -1,53 +1,53 @@ -""" -文件处理工具类 -""" - -import os -from werkzeug.utils import secure_filename -from flask import current_app - -def allowed_file(filename): - """检查文件扩展名是否被允许""" - if not filename: - return False - - allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS', - {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'}) - - return '.' 
in filename and \ - filename.rsplit('.', 1)[1].lower() in allowed_extensions - -def save_uploaded_file(file, upload_path): - """保存上传的文件""" - try: - if not file or not allowed_file(file.filename): - return None - - filename = secure_filename(file.filename) - - # 确保目录存在 - os.makedirs(os.path.dirname(upload_path), exist_ok=True) - - file.save(upload_path) - return upload_path - - except Exception as e: - print(f"保存文件失败: {str(e)}") - return None - -def get_file_size(file_path): - """获取文件大小""" - try: - return os.path.getsize(file_path) - except: - return 0 - -def delete_file(file_path): - """删除文件""" - try: - if os.path.exists(file_path): - os.remove(file_path) - return True - except: - pass +""" +文件处理工具类 +""" + +import os +from werkzeug.utils import secure_filename +from flask import current_app + +def allowed_file(filename): + """检查文件扩展名是否被允许""" + if not filename: + return False + + allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS', + {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'}) + + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in allowed_extensions + +def save_uploaded_file(file, upload_path): + """保存上传的文件""" + try: + if not file or not allowed_file(file.filename): + return None + + filename = secure_filename(file.filename) + + # 确保目录存在 + os.makedirs(os.path.dirname(upload_path), exist_ok=True) + + file.save(upload_path) + return upload_path + + except Exception as e: + print(f"保存文件失败: {str(e)}") + return None + +def get_file_size(file_path): + """获取文件大小""" + try: + return os.path.getsize(file_path) + except: + return 0 + +def delete_file(file_path): + """删除文件""" + try: + if os.path.exists(file_path): + os.remove(file_path) + return True + except: + pass return False \ No newline at end of file diff --git a/src/backend/app/utils/jwt_utils.py b/src/backend/app/utils/jwt_utils.py index 8e0aee6..359d937 100644 --- a/src/backend/app/utils/jwt_utils.py +++ b/src/backend/app/utils/jwt_utils.py @@ -1,16 +1,16 @@ -""" -JWT工具函数 -""" -from functools import wraps -from flask_jwt_extended import get_jwt_identity - -def int_jwt_required(f): - """获取JWT身份并转换为整数的装饰器""" - @wraps(f) - def wrapped(*args, **kwargs): - jwt_identity = get_jwt_identity() - if jwt_identity is not None: - # 在函数调用前注入转换后的user_id - kwargs['current_user_id'] = int(jwt_identity) - return f(*args, **kwargs) +""" +JWT工具函数 +""" +from functools import wraps +from flask_jwt_extended import get_jwt_identity + +def int_jwt_required(f): + """获取JWT身份并转换为整数的装饰器""" + @wraps(f) + def wrapped(*args, **kwargs): + jwt_identity = get_jwt_identity() + if jwt_identity is not None: + # 在函数调用前注入转换后的user_id + kwargs['current_user_id'] = int(jwt_identity) + return f(*args, **kwargs) return wrapped \ No newline at end of file diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py new file mode 100644 index 0000000..40eb747 --- /dev/null +++ b/src/backend/app/workers/finetune_worker.py @@ -0,0 +1,365 @@ +""" +RQ Worker 微调任务处理器 +在后台执行模型微调任务 +""" + +import os +import subprocess +import logging +from datetime import datetime + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def run_finetune_task(batch_id, finetune_method, train_images_dir, output_model_dir, + class_dir, inference_prompts, is_perturbed=False, custom_params=None): + """ + 执行微调任务 + + Args: + batch_id: 任务批次ID + finetune_method: 微调方法 (dreambooth, lora) + train_images_dir: 训练图片目录(原始或扰动) + output_model_dir: 
模型输出目录 + class_dir: 类别图片目录 + inference_prompts: 推理提示词 + is_perturbed: 是否是扰动图片训练 + custom_params: 自定义参数 + + Returns: + 任务执行结果 + """ + from config.algorithm_config import AlgorithmConfig + from app import create_app, db + from app.database import Batch, Image, ImageType + + app = create_app() + + with app.app_context(): + try: + batch = Batch.query.get(batch_id) + if not batch: + raise ValueError(f"Batch {batch_id} not found") + + logger.info(f"Starting finetune task for batch {batch_id}") + logger.info(f"Method: {finetune_method}, Perturbed: {is_perturbed}") + + # 确保目录存在 + os.makedirs(output_model_dir, exist_ok=True) + os.makedirs(class_dir, exist_ok=True) + + # 获取配置 + use_real = AlgorithmConfig.USE_REAL_ALGORITHMS + + if use_real: + # 使用真实微调算法 + result = _run_real_finetune( + finetune_method, batch_id, train_images_dir, output_model_dir, + class_dir, inference_prompts, is_perturbed, custom_params + ) + else: + # 使用虚拟微调实现 + result = _run_virtual_finetune( + finetune_method, batch_id, train_images_dir, output_model_dir, + is_perturbed + ) + + # 保存生成的图片到数据库 + _save_generated_images(batch_id, output_model_dir, is_perturbed) + + logger.info(f"Finetune task completed for batch {batch_id}") + return result + + except Exception as e: + logger.error(f"Finetune task failed for batch {batch_id}: {str(e)}", exc_info=True) + raise + + +def _run_real_finetune(finetune_method, batch_id, train_images_dir, output_model_dir, + class_dir, inference_prompts, is_perturbed, custom_params): + """运行真实微调算法""" + from config.algorithm_config import AlgorithmConfig + + logger.info(f"Running real finetune: {finetune_method}") + + # 获取微调脚本路径和环境 + finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {}) + script_path = finetune_config.get('real_script') + conda_env = finetune_config.get('conda_env') + default_params = finetune_config.get('default_params', {}) + + if not script_path: + raise ValueError(f"Finetune method {finetune_method} not configured") + + # 合并参数 + params = {**default_params, **(custom_params or {})} + + # 构建命令行参数 + cmd_args = [ + f"--instance_data_dir={train_images_dir}", + f"--output_dir={output_model_dir}", + f"--class_data_dir={class_dir}", + f"--inference_prompts={inference_prompts}", + ] + + # 添加其他参数 + for key, value in params.items(): + if isinstance(value, bool): + if value: + cmd_args.append(f"--{key}") + else: + cmd_args.append(f"--{key}={value}") + + # 构建完整命令 + cmd = [ + 'conda', 'run', '-n', conda_env, '--no-capture-output', + 'accelerate', 'launch', script_path + ] + cmd_args + + logger.info(f"Executing command: {' '.join(cmd)}") + + # 设置日志文件 + log_dir = AlgorithmConfig.LOGS_DIR + os.makedirs(log_dir, exist_ok=True) + image_type = 'perturbed' if is_perturbed else 'original' + log_file = os.path.join( + log_dir, + f'finetune_{image_type}_{batch_id}_{finetune_method}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' + ) + + # 执行命令 + with open(log_file, 'w') as f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + for line in process.stdout: + f.write(line) + f.flush() + logger.info(line.strip()) + + process.wait() + + if process.returncode != 0: + raise RuntimeError(f"Finetune failed with code {process.returncode}. 
Check log: {log_file}") + + # 清理class_dir + logger.info(f"Cleaning class directory: {class_dir}") + if os.path.exists(class_dir): + import shutil + for item in os.listdir(class_dir): + item_path = os.path.join(class_dir, item) + if os.path.isfile(item_path): + os.remove(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) + + return { + 'status': 'success', + 'output_dir': output_model_dir, + 'log_file': log_file + } + + +def _run_virtual_finetune(finetune_method, batch_id, train_images_dir, output_model_dir, is_perturbed): + """运行虚拟微调实现""" + from config.algorithm_config import AlgorithmConfig + import glob + + logger.info(f"Running virtual finetune: {finetune_method}") + + # 获取微调配置 + finetune_config = AlgorithmConfig.FINETUNE_SCRIPTS.get(finetune_method, {}) + if not finetune_config: + raise ValueError(f"Finetune method {finetune_method} not configured") + + conda_env = finetune_config.get('conda_env') + default_params = finetune_config.get('default_params', {}) + + # 获取虚拟微调脚本路径 + script_name = 'train_dreambooth_gen.py' if finetune_method == 'dreambooth' else 'train_lora_gen.py' + script_path = os.path.abspath(os.path.join( + os.path.dirname(__file__), + '../algorithms/finetune_virtual', + script_name + )) + + if not os.path.exists(script_path): + raise FileNotFoundError(f"Virtual finetune script not found: {script_path}") + + logger.info(f"Virtual script path: {script_path}") + logger.info(f"Conda environment: {conda_env}") + + # 创建输出目录 + os.makedirs(output_model_dir, exist_ok=True) + validation_output_dir = os.path.join(output_model_dir, 'generated') + os.makedirs(validation_output_dir, exist_ok=True) + + # 构建命令行参数(与真实微调参数一致) + cmd_args = [ + f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}", + f"--instance_data_dir={train_images_dir}", + f"--output_dir={output_model_dir}", + f"--validation_image_output_dir={validation_output_dir}", + f"--class_data_dir=/tmp/class_placeholder", + ] + + # 添加其他默认参数 + for key, value in default_params.items(): + if key == 'pretrained_model_name_or_path': + continue # 已添加 + if isinstance(value, bool): + if value: + cmd_args.append(f"--{key}") + else: + cmd_args.append(f"--{key}={value}") + + # 使用conda run执行虚拟脚本 + cmd = [ + 'conda', 'run', '-n', conda_env, '--no-capture-output', + 'python', script_path + ] + cmd_args + + logger.info(f"Executing command: {' '.join(cmd)}") + + # 设置日志文件 + log_dir = AlgorithmConfig.LOGS_DIR + os.makedirs(log_dir, exist_ok=True) + image_type = 'perturbed' if is_perturbed else 'original' + log_file = os.path.join( + log_dir, + f'virtual_{finetune_method}_{image_type}_{batch_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' + ) + + # 执行命令 + with open(log_file, 'w') as f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + for line in process.stdout: + f.write(line) + f.flush() + logger.info(line.strip()) + + process.wait() + + if process.returncode != 0: + raise RuntimeError(f"Virtual finetune failed with code {process.returncode}. Check log: {log_file}") + + # 统计生成的图片 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + generated_files = [] + for ext in image_extensions: + generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext))) + generated_files.extend(glob.glob(os.path.join(validation_output_dir, ext.upper()))) + + logger.info(f"Virtual finetune completed. 
Generated {len(generated_files)} images") + + return { + 'status': 'success', + 'output_dir': output_model_dir, + 'generated_count': len(generated_files), + 'generated_files': generated_files, + 'log_file': log_file + } + + +def _save_generated_images(batch_id, output_model_dir, is_perturbed): + """保存生成的图片到数据库""" + from app import db + from app.database import Batch, Image, ImageType + import glob + + try: + batch = Batch.query.get(batch_id) + if not batch: + return + + # 确定图片类型 + if is_perturbed: + image_type = ImageType.query.filter_by(type_code='perturbed_generate').first() + else: + image_type = ImageType.query.filter_by(type_code='original_generate').first() + + if not image_type: + logger.error(f"Image type not found for is_perturbed={is_perturbed}") + return + + # 查找生成的图片 + generated_dir = os.path.join(output_model_dir, 'generated') + if not os.path.exists(generated_dir): + # 尝试直接从output_model_dir查找 + generated_dir = output_model_dir + + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + image_files = [] + for ext in image_extensions: + image_files.extend(glob.glob(os.path.join(generated_dir, ext))) + image_files.extend(glob.glob(os.path.join(generated_dir, ext.upper()))) + + logger.info(f"Found {len(image_files)} generated images to save") + + # 保存到数据库 + for image_path in image_files: + try: + from PIL import Image as PILImage + with PILImage.open(image_path) as img: + width, height = img.size + + # 查找对应的父图片(原始或扰动) + filename = os.path.basename(image_path) + # 去掉"generated_"前缀 + original_filename = filename.replace('generated_', '') + + # 尝试找到父图片 + if is_perturbed: + parent_type = ImageType.query.filter_by(type_code='perturbed').first() + else: + parent_type = ImageType.query.filter_by(type_code='original').first() + + parent_image = Image.query.filter_by( + batch_id=batch_id, + image_type_id=parent_type.id + ).filter(Image.original_filename.like(f"%{original_filename.split('_')[0]}%")).first() + + # 创建图片记录 + generated_image = Image( + user_id=batch.user_id, + batch_id=batch_id, + father_id=parent_image.id if parent_image else None, + original_filename=filename, + stored_filename=os.path.basename(image_path), + file_path=image_path, + file_size=os.path.getsize(image_path), + image_type_id=image_type.id, + width=width, + height=height + ) + + db.session.add(generated_image) + logger.info(f"Saved generated image: {filename}") + + except Exception as e: + logger.error(f"Failed to save {image_path}: {str(e)}") + + db.session.commit() + logger.info(f"Successfully saved {len(image_files)} generated images") + + except Exception as e: + logger.error(f"Error saving generated images: {str(e)}") + db.session.rollback() diff --git a/src/backend/app/workers/perturbation_worker.py b/src/backend/app/workers/perturbation_worker.py new file mode 100644 index 0000000..2eb24d2 --- /dev/null +++ b/src/backend/app/workers/perturbation_worker.py @@ -0,0 +1,296 @@ +""" +RQ Worker任务处理器 +在后台执行对抗性扰动算法 +""" + +import os +import sys +import subprocess +import logging +from datetime import datetime +from pathlib import Path + +# 设置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def run_perturbation_task(batch_id, algorithm_code, epsilon, use_strong_protection, + input_dir, output_dir, class_dir, custom_params=None): + """ + 执行对抗性扰动任务 + + Args: + batch_id: 任务批次ID + algorithm_code: 算法代码 + epsilon: 扰动强度 + use_strong_protection: 是否使用防净化版本 + input_dir: 输入图片目录 + output_dir: 输出目录 + class_dir: 
directory of class images (for prior preservation)
+        custom_params: custom parameter overrides
+
+    Returns:
+        task execution result
+    """
+    from config.algorithm_config import AlgorithmConfig
+    from app import create_app, db
+    from app.database import Batch
+
+    # Create the application context
+    app = create_app()
+
+    with app.app_context():
+        batch = None
+        try:
+            # Update the task status
+            batch = Batch.query.get(batch_id)
+            if not batch:
+                raise ValueError(f"Batch {batch_id} not found")
+
+            batch.status = 'processing'
+            batch.started_at = datetime.utcnow()
+            db.session.commit()
+
+            logger.info(f"Starting perturbation task for batch {batch_id}")
+            logger.info(f"Algorithm: {algorithm_code}, Epsilon: {epsilon}")
+
+            # Load the algorithm configuration
+            use_real = AlgorithmConfig.USE_REAL_ALGORITHMS
+            script_path = AlgorithmConfig.get_script_path(algorithm_code)
+            conda_env = AlgorithmConfig.get_conda_env(algorithm_code)
+
+            # Make sure the directories exist
+            os.makedirs(output_dir, exist_ok=True)
+            os.makedirs(class_dir, exist_ok=True)
+
+            if use_real:
+                # Run the real algorithm
+                result = _run_real_algorithm(
+                    script_path, conda_env, algorithm_code, batch_id,
+                    epsilon, use_strong_protection, input_dir, output_dir,
+                    class_dir, custom_params
+                )
+            else:
+                # Run the virtual implementation
+                result = _run_virtual_algorithm(
+                    algorithm_code, batch_id, epsilon, use_strong_protection,
+                    input_dir, output_dir
+                )
+
+            # Mark the task as completed
+            batch.status = 'completed'
+            batch.completed_at = datetime.utcnow()
+            db.session.commit()
+
+            logger.info(f"Task completed successfully for batch {batch_id}")
+            return result
+
+        except Exception as e:
+            logger.error(f"Task failed for batch {batch_id}: {str(e)}", exc_info=True)
+
+            # Mark the task as failed (batch stays None if the lookup itself failed)
+            if batch:
+                batch.status = 'failed'
+                batch.error_message = str(e)
+                batch.completed_at = datetime.utcnow()
+                db.session.commit()
+
+            raise
+
+
+def _run_real_algorithm(script_path, conda_env, algorithm_code, batch_id,
+                        epsilon, use_strong_protection, input_dir, output_dir,
+                        class_dir, custom_params):
+    """Run the real algorithm"""
+    from config.algorithm_config import AlgorithmConfig
+
+    logger.info(f"Running real algorithm: {algorithm_code}")
+    logger.info(f"Conda environment: {conda_env}")
+    logger.info(f"Script path: {script_path}")
+
+    # Fetch the default parameters
+    default_params = AlgorithmConfig.get_default_params(algorithm_code)
+
+    # Merge in any custom parameters
+    params = {**default_params, **(custom_params or {})}
+
+    # Build the command-line arguments
+    cmd_args = [
+        f"--instance_data_dir_for_train={input_dir}",
+        f"--instance_data_dir_for_adversarial={input_dir}",
+        f"--output_dir={output_dir}",
+        f"--class_data_dir={class_dir}",
+        f"--pgd_eps={epsilon}",
+    ]
+
+    # Append the remaining parameters
+    for key, value in params.items():
+        if isinstance(value, bool):
+            if value:
+                cmd_args.append(f"--{key}")
+        else:
+            cmd_args.append(f"--{key}={value}")
+
+    # Build the full command
+    # Use `conda run` to avoid nested-environment problems
+    cmd = [
+        'conda', 'run', '-n', conda_env, '--no-capture-output',
+        'accelerate', 'launch', script_path
+    ] + cmd_args
+
+    logger.info(f"Executing command: {' '.join(cmd)}")
+
+    # Set up the log file
+    log_dir = AlgorithmConfig.LOGS_DIR
+    os.makedirs(log_dir, exist_ok=True)
+    log_file = os.path.join(log_dir, f'batch_{batch_id}_{algorithm_code}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
+
+    # Run the command
+    with open(log_file, 'w') as f:
+        process = subprocess.Popen(
+            cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+            bufsize=1,
+            universal_newlines=True
+        )
+
+        # Stream the output to the log file in real time
+        for line in process.stdout:
+            f.write(line)
+            f.flush()
+            logger.info(line.strip())
+
+        process.wait()
+
+    if process.returncode != 0:
+        raise RuntimeError(f"Algorithm execution failed with code {process.returncode}.
Check log: {log_file}") + + # 清理class_dir + logger.info(f"Cleaning class directory: {class_dir}") + if os.path.exists(class_dir): + import shutil + for item in os.listdir(class_dir): + item_path = os.path.join(class_dir, item) + if os.path.isfile(item_path): + os.remove(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) + + return { + 'status': 'success', + 'output_dir': output_dir, + 'log_file': log_file + } + + +def _run_virtual_algorithm(algorithm_code, batch_id, epsilon, use_strong_protection, + input_dir, output_dir): + """运行虚拟算法实现""" + from config.algorithm_config import AlgorithmConfig + import glob + + logger.info(f"Running virtual algorithm: {algorithm_code}") + + # 获取算法配置 + algo_config = AlgorithmConfig.get_algorithm_config(algorithm_code) + if not algo_config: + raise ValueError(f"Algorithm {algorithm_code} not configured") + + conda_env = algo_config.get('conda_env') + default_params = algo_config.get('default_params', {}) + + # 获取虚拟算法脚本路径 + script_path = os.path.abspath(os.path.join( + os.path.dirname(__file__), + '../algorithms/perturbation_virtual', + f'{algorithm_code}.py' + )) + + if not os.path.exists(script_path): + raise FileNotFoundError(f"Virtual script not found: {script_path}") + + logger.info(f"Virtual script path: {script_path}") + logger.info(f"Conda environment: {conda_env}") + + # 确保输出目录存在 + os.makedirs(output_dir, exist_ok=True) + + # 构建命令行参数(与真实算法参数一致) + cmd_args = [ + f"--pretrained_model_name_or_path={default_params.get('pretrained_model_name_or_path', 'model_path')}", + f"--instance_data_dir_for_train={input_dir}", + f"--instance_data_dir_for_adversarial={input_dir}", + f"--output_dir={output_dir}", + f"--class_data_dir=/tmp/class_placeholder", + f"--pgd_eps={epsilon}", + ] + + # 添加其他默认参数 + for key, value in default_params.items(): + if key == 'pretrained_model_name_or_path': + continue # 已添加 + if isinstance(value, bool): + if value: + cmd_args.append(f"--{key}") + else: + cmd_args.append(f"--{key}={value}") + + # 使用conda run执行虚拟脚本 + cmd = [ + 'conda', 'run', '-n', conda_env, '--no-capture-output', + 'python', script_path + ] + cmd_args + + logger.info(f"Executing command: {' '.join(cmd)}") + + # 设置日志文件 + log_dir = AlgorithmConfig.LOGS_DIR + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join( + log_dir, + f'virtual_{algorithm_code}_{batch_id}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' + ) + + # 执行命令 + with open(log_file, 'w') as f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + # 实时输出日志 + for line in process.stdout: + f.write(line) + f.flush() + logger.info(line.strip()) + + process.wait() + + if process.returncode != 0: + raise RuntimeError(f"Virtual algorithm failed with code {process.returncode}. Check log: {log_file}") + + # 统计处理的图片 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + processed_files = [] + for ext in image_extensions: + processed_files.extend(glob.glob(os.path.join(output_dir, ext))) + processed_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) + + logger.info(f"Virtual algorithm completed. 
Processed {len(processed_files)} images")
+
+    return {
+        'status': 'success',
+        'output_dir': output_dir,
+        'processed_count': len(processed_files),
+        'processed_files': processed_files,
+        'log_file': log_file
+    }
diff --git a/src/backend/config/.env b/src/backend/config/.env
index be6a0db..f6d54c1 100644
--- a/src/backend/config/.env
+++ b/src/backend/config/.env
@@ -3,9 +3,9 @@
 # Database configuration
 DB_USER=root
-DB_PASSWORD=your_password_here
+DB_PASSWORD=971817787Lh
 DB_HOST=localhost
-DB_NAME=your_database_name_here
+DB_NAME=db
 
 # Flask configuration
 SECRET_KEY=museguard-secret-key-2024
diff --git a/src/backend/config/algorithm_config.py b/src/backend/config/algorithm_config.py
new file mode 100644
index 0000000..65c4d8c
--- /dev/null
+++ b/src/backend/config/algorithm_config.py
@@ -0,0 +1,197 @@
+"""
+Algorithm configuration
+Defines the parameters, conda environments, and script paths for the adversarial perturbation algorithms
+"""
+
+import os
+from dotenv import load_dotenv
+
+# Load the algorithm-specific environment variables
+config_dir = os.path.dirname(os.path.abspath(__file__))
+algorithm_env_path = os.path.join(config_dir, 'algorithm.env')
+load_dotenv(algorithm_env_path)
+
+class AlgorithmConfig:
+    """Base algorithm configuration class"""
+
+    # Root directory of the algorithm scripts
+    ALGORITHMS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'app', 'algorithms')
+
+    # Whether to use the real algorithms (read from the environment, defaults to False)
+    USE_REAL_ALGORITHMS = os.getenv('USE_REAL_ALGORITHMS', 'false').lower() == 'true'
+
+    # Redis configuration (read from the environment)
+    REDIS_URL = os.getenv('REDIS_URL', 'redis://localhost:6379/0')
+
+    # RQ queue name
+    RQ_QUEUE_NAME = os.getenv('RQ_QUEUE_NAME', 'perturbation_tasks')
+
+    # Task timeout (seconds)
+    TASK_TIMEOUT = int(os.getenv('TASK_TIMEOUT', '3600'))
+
+    # Log directory
+    LOGS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs')
+
+    # Conda environment configuration (read from the environment, customizable)
+    CONDA_ENVS = {
+        'aspl': os.getenv('CONDA_ENV_ASPL', 'simac'),
+        'simac': os.getenv('CONDA_ENV_SIMAC', 'simac'),
+        'caat': os.getenv('CONDA_ENV_CAAT', 'simac'),
+        'pid': os.getenv('CONDA_ENV_PID', 'simac'),
+        'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'simac'),
+        'lora': os.getenv('CONDA_ENV_LORA', 'simac'),
+    }
+
+    # Algorithm script configuration
+    ALGORITHM_SCRIPTS = {
+        'aspl': {
+            'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'aspl.py'),
+            'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_virtual', 'aspl.py'),
+            'conda_env': CONDA_ENVS['aspl'],
+            'default_params': {
+                'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06',
+                'enable_xformers_memory_efficient_attention': True,
+                'instance_prompt': 'a photo of sks person',
+                'class_prompt': 'a photo of person',
+                'num_class_images': 200,
+                'center_crop': True,
+                'with_prior_preservation': True,
+                'prior_loss_weight': 1.0,
+                'resolution': 384,
+                'train_batch_size': 1,
+                'max_train_steps': 50,
+                'max_f_train_steps': 3,
+                'max_adv_train_steps': 6,
+                'checkpointing_iterations': 10,
+                'learning_rate': 5e-7,
+                'pgd_alpha': 0.005,
+                'pgd_eps': 8,
+                'seed': 0
+            }
+        },
+        'simac': {
+            'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'simac.py'),
+            'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_virtual', 'simac.py'),
+            'conda_env': CONDA_ENVS['simac'],
+            'default_params': {
+                'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06',
+                'enable_xformers_memory_efficient_attention': True,
+                'instance_prompt': 'a photo of sks person',
+                'class_prompt': 'a photo of person',
+                'num_class_images': 200,
+                'resolution': 512,
+                'train_batch_size': 1,
+                'max_train_steps': 1000,
+                'learning_rate': 5e-6,
+                'pgd_eps': 8,
+                'seed': 0
+            }
+        },
+        'caat': {
+            'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'caat.py'),
+            'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_virtual', 'caat.py'),
+            'conda_env': CONDA_ENVS['caat'],
+            'default_params': {
+                'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06',
+                'enable_xformers_memory_efficient_attention': True,
+                'instance_prompt': 'a photo of sks person',
+                'class_prompt': 'a photo of person',
+                'num_class_images': 200,
+                'resolution': 512,
+                'train_batch_size': 2,
+                'max_train_steps': 800,
+                'learning_rate': 1e-5,
+                'pgd_eps': 16,
+                'seed': 0
+            }
+        },
+        'pid': {
+            'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'pid.py'),
+            'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_virtual', 'pid.py'),
+            'conda_env': CONDA_ENVS['pid'],
+            'default_params': {
+                'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06',
+                'enable_xformers_memory_efficient_attention': True,
+                'instance_prompt': 'a photo of sks person',
+                'class_prompt': 'a photo of person',
+                'num_class_images': 200,
+                'resolution': 512,
+                'train_batch_size': 1,
+                'max_train_steps': 600,
+                'learning_rate': 3e-6,
+                'pgd_eps': 4,
+                'seed': 0
+            }
+        }
+    }
+
+    @classmethod
+    def get_algorithm_config(cls, algorithm_code):
+        """Fetch the configuration for an algorithm"""
+        return cls.ALGORITHM_SCRIPTS.get(algorithm_code, {})
+
+    @classmethod
+    def get_script_path(cls, algorithm_code):
+        """Fetch the script path for an algorithm"""
+        config = cls.get_algorithm_config(algorithm_code)
+        if cls.USE_REAL_ALGORITHMS:
+            return config.get('real_script')
+        else:
+            return config.get('virtual_script')
+
+    @classmethod
+    def get_conda_env(cls, algorithm_code):
+        """Fetch the conda environment name for an algorithm"""
+        config = cls.get_algorithm_config(algorithm_code)
+        return config.get('conda_env')
+
+    @classmethod
+    def get_default_params(cls, algorithm_code):
+        """Fetch the default parameters for an algorithm"""
+        config = cls.get_algorithm_config(algorithm_code)
+        return config.get('default_params', {}).copy()
+
+    # ========== Finetune algorithm configuration ==========
+    FINETUNE_SCRIPTS = {
+        'dreambooth': {
+            'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_dreambooth_alone.py'),
+            'virtual_script': None,  # the virtual implementation lives in the worker
+            'conda_env': CONDA_ENVS['dreambooth'],
+            'default_params': {
+                'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06',
+                'enable_xformers_memory_efficient_attention': True,
+                'instance_prompt': 'a photo of sks person',
+                'class_prompt': 'a photo of person',
+                'num_class_images': 200,
+                'resolution': 512,
+                'train_batch_size': 1,
+                'num_train_epochs': 1,
+                'max_train_steps': 1000,
+                'learning_rate': 5e-6,
+                'with_prior_preservation': True,
+                'prior_loss_weight': 1.0,
+                'seed': 0
+            }
+        },
+        'lora': {
+            'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_lora_gen.py'),
+            'virtual_script': None,
+            'conda_env': CONDA_ENVS['lora'],
+            'default_params': {
+                'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06',
+                'enable_xformers_memory_efficient_attention': True,
+                'instance_prompt': 'a photo of sks person',
+                'resolution': 512,
+                'train_batch_size': 1,
+                'max_train_steps': 800,
+                'learning_rate': 1e-4,
+                'rank': 4,
+                'seed': 0
+            }
+        }
+    }
+
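+    # Note (editor's annotation): unlike ALGORITHM_SCRIPTS above, FINETUNE_SCRIPTS
+    # carries no usable 'virtual_script' entries; the virtual finetune script path
+    # is resolved inside app/workers/finetune_worker.py, which points at
+    # app/algorithms/finetune_virtual/.
+
+    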
@classmethod + def get_finetune_config(cls, finetune_method): + """获取微调算法配置""" + return cls.FINETUNE_SCRIPTS.get(finetune_method, {}) diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index ecaff13..1078878 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -1,109 +1,109 @@ -""" -应用配置文件 -""" -import os -from datetime import timedelta -from dotenv import load_dotenv -from urllib.parse import quote_plus - -# 加载环境变量 - 从 config 目录读取 .env 文件 -config_dir = os.path.dirname(os.path.abspath(__file__)) -env_path = os.path.join(config_dir, '.env') -load_dotenv(env_path) - -class Config: - """基础配置类""" - - # 基础配置 - SECRET_KEY = os.environ.get('SECRET_KEY') or 'museguard-secret-key-2024' - - # 数据库配置 - 支持密码中的特殊字符 - DB_USER = os.environ.get('DB_USER') - DB_PASSWORD = os.environ.get('DB_PASSWORD') - DB_HOST = os.environ.get('DB_HOST') or 'localhost' - DB_NAME = os.environ.get('DB_NAME') or 'museguard_schema' - - # URL编码密码中的特殊字符 - from urllib.parse import quote_plus - _encoded_password = quote_plus(DB_PASSWORD) - - SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \ - f'mysql+pymysql://{DB_USER}:{_encoded_password}@{DB_HOST}/{DB_NAME}' - SQLALCHEMY_TRACK_MODIFICATIONS = False - SQLALCHEMY_ENGINE_OPTIONS = { - 'pool_pre_ping': True, - 'pool_recycle': 300, - } - - # JWT配置 - JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or 'jwt-secret-string' - JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24) - JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30) - - # 静态文件根目录 - STATIC_ROOT = 'static' - - # 文件上传配置 - UPLOAD_FOLDER = 'uploads' # 临时上传目录 - MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB - ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'} - - # 图像处理配置 - ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'originals') # 重命名后的原始图片 - PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片 - MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录 - MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'clean') # 原图的模型生成结果 - MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果 - HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图 - - # 预设演示图像配置 - DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录 - DEMO_ORIGINAL_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'original') # 演示原始图片 - DEMO_PERTURBED_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'perturbed') # 演示加噪图片 - DEMO_COMPARISONS_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'comparisons') # 演示对比图 - - # 邮件配置(用于注册验证) - MAIL_SERVER = os.environ.get('MAIL_SERVER') or 'smtp.gmail.com' - MAIL_PORT = int(os.environ.get('MAIL_PORT') or 587) - MAIL_USE_TLS = os.environ.get('MAIL_USE_TLS', 'true').lower() in ['true', 'on', '1'] - MAIL_USERNAME = os.environ.get('MAIL_USERNAME') - MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD') - - # 算法配置 - ALGORITHMS = { - 'simac': { - 'name': 'SimAC算法', - 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', - 'default_epsilon': 8.0 - }, - 'caat': { - 'name': 'CAAT算法', - 'description': 'Perturbing Attention Gives You More Bang for the Buck', - 'default_epsilon': 16.0 - }, - 'pid': { - 'name': 'PID算法', - 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', - 'default_epsilon': 4.0 - } - } - -class DevelopmentConfig(Config): - """开发环境配置""" - DEBUG = True - -class ProductionConfig(Config): - """生产环境配置""" - DEBUG = False - -class TestingConfig(Config): - """测试环境配置""" - TESTING = True - SQLALCHEMY_DATABASE_URI = 
'sqlite:///:memory:' - -config = { - 'development': DevelopmentConfig, - 'production': ProductionConfig, - 'testing': TestingConfig, - 'default': DevelopmentConfig +""" +应用配置文件 +""" +import os +from datetime import timedelta +from dotenv import load_dotenv +from urllib.parse import quote_plus + +# 加载环境变量 - 从 config 目录读取 .env 文件 +config_dir = os.path.dirname(os.path.abspath(__file__)) +env_path = os.path.join(config_dir, '.env') +load_dotenv(env_path) + +class Config: + """基础配置类""" + + # 基础配置 + SECRET_KEY = os.environ.get('SECRET_KEY') or 'museguard-secret-key-2024' + + # 数据库配置 - 支持密码中的特殊字符 + DB_USER = os.environ.get('DB_USER') or 'root' + DB_PASSWORD = os.environ.get('DB_PASSWORD') or '' + DB_HOST = os.environ.get('DB_HOST') or 'localhost' + DB_NAME = os.environ.get('DB_NAME') or 'museguard_schema' + + # URL编码密码中的特殊字符 + from urllib.parse import quote_plus + _encoded_password = quote_plus(DB_PASSWORD) if DB_PASSWORD else '' + + SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \ + f'mysql+pymysql://{DB_USER}:{_encoded_password}@{DB_HOST}/{DB_NAME}' + SQLALCHEMY_TRACK_MODIFICATIONS = False + SQLALCHEMY_ENGINE_OPTIONS = { + 'pool_pre_ping': True, + 'pool_recycle': 300, + } + + # JWT配置 + JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or 'jwt-secret-string' + JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24) + JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30) + + # 静态文件根目录 + STATIC_ROOT = 'static' + + # 文件上传配置 + UPLOAD_FOLDER = 'uploads' # 临时上传目录 + MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB + ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'} + + # 图像处理配置 + ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'originals') # 重命名后的原始图片 + PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片 + MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录 + MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'clean') # 原图的模型生成结果 + MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果 + HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图 + + # 预设演示图像配置 + DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录 + DEMO_ORIGINAL_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'original') # 演示原始图片 + DEMO_PERTURBED_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'perturbed') # 演示加噪图片 + DEMO_COMPARISONS_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'comparisons') # 演示对比图 + + # 邮件配置(用于注册验证) + MAIL_SERVER = os.environ.get('MAIL_SERVER') or 'smtp.gmail.com' + MAIL_PORT = int(os.environ.get('MAIL_PORT') or 587) + MAIL_USE_TLS = os.environ.get('MAIL_USE_TLS', 'true').lower() in ['true', 'on', '1'] + MAIL_USERNAME = os.environ.get('MAIL_USERNAME') + MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD') + + # 算法配置 + ALGORITHMS = { + 'simac': { + 'name': 'SimAC算法', + 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', + 'default_epsilon': 8.0 + }, + 'caat': { + 'name': 'CAAT算法', + 'description': 'Perturbing Attention Gives You More Bang for the Buck', + 'default_epsilon': 16.0 + }, + 'pid': { + 'name': 'PID算法', + 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', + 'default_epsilon': 4.0 + } + } + +class DevelopmentConfig(Config): + """开发环境配置""" + DEBUG = True + +class ProductionConfig(Config): + """生产环境配置""" + DEBUG = False + +class TestingConfig(Config): + """测试环境配置""" + TESTING = True + SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:' + +config = { + 'development': DevelopmentConfig, + 'production': ProductionConfig, + 'testing': TestingConfig, + 
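+    # The app factory picks one of these entries by name, e.g.
+    # config[os.environ.get('FLASK_ENV', 'default')] (illustrative; run.py
+    # defaults FLASK_ENV to 'development').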
'default': DevelopmentConfig } \ No newline at end of file diff --git a/src/backend/init_db.py b/src/backend/init_db.py index 009ae84..258a634 100644 --- a/src/backend/init_db.py +++ b/src/backend/init_db.py @@ -1,81 +1,81 @@ -""" -数据库初始化脚本 -""" - -from app import create_app, db -from app.models import * - -def init_database(): - """初始化数据库""" - app = create_app() - - with app.app_context(): - # 创建所有表 - db.create_all() - - # 初始化图片类型数据 - image_types = [ - {'type_code': 'original', 'type_name': '原始图片', 'description': '用户上传的原始图像文件'}, - {'type_code': 'perturbed', 'type_name': '加噪后图片', 'description': '经过扰动算法处理后的防护图像'}, - {'type_code': 'original_generate', 'type_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'}, - {'type_code': 'perturbed_generate', 'type_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'} - ] - - for img_type in image_types: - existing = ImageType.query.filter_by(type_code=img_type['type_code']).first() - if not existing: - new_type = ImageType(**img_type) - db.session.add(new_type) - - # 初始化加噪算法数据 - perturbation_configs = [ - {'method_code': 'aspl', 'method_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense', 'default_epsilon': 6.0}, - {'method_code': 'simac', 'method_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', 'default_epsilon': 8.0}, - {'method_code': 'caat', 'method_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck', 'default_epsilon': 16.0}, - {'method_code': 'pid', 'method_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', 'default_epsilon': 4.0} - ] - - for config in perturbation_configs: - existing = PerturbationConfig.query.filter_by(method_code=config['method_code']).first() - if not existing: - new_config = PerturbationConfig(**config) - db.session.add(new_config) - - # 初始化微调方式数据 - finetune_configs = [ - {'method_code': 'dreambooth', 'method_name': 'DreamBooth', 'description': 'DreamBooth个性化文本到图像生成'}, - {'method_code': 'lora', 'method_name': 'LoRA', 'description': '低秩适应(Low-Rank Adaptation)微调方法'}, - {'method_code': 'textual_inversion', 'method_name': 'Textual Inversion', 'description': '文本反转个性化方法'} - ] - - for config in finetune_configs: - existing = FinetuneConfig.query.filter_by(method_code=config['method_code']).first() - if not existing: - new_config = FinetuneConfig(**config) - db.session.add(new_config) - - # 创建默认管理员用户 - admin_users = [ - {'username': 'admin1', 'email': 'admin1@museguard.com', 'role': 'admin'}, - {'username': 'admin2', 'email': 'admin2@museguard.com', 'role': 'admin'}, - {'username': 'admin3', 'email': 'admin3@museguard.com', 'role': 'admin'} - ] - - for admin_data in admin_users: - existing = User.query.filter_by(username=admin_data['username']).first() - if not existing: - admin_user = User(**admin_data) - admin_user.set_password('admin123') # 默认密码 - db.session.add(admin_user) - - # 为管理员创建默认配置 - db.session.flush() # 确保user.id可用 - user_config = UserConfig(user_id=admin_user.id) - db.session.add(user_config) - - # 提交所有更改 - db.session.commit() - print("数据库初始化完成!") - -if __name__ == '__main__': +""" +数据库初始化脚本 +""" + +from app import create_app, db +from app.database import * + +def init_database(): + """初始化数据库""" + app = create_app() + + with app.app_context(): + # 创建所有表 + db.create_all() + + # 初始化图片类型数据 + image_types = [ + {'type_code': 'original', 'type_name': '原始图片', 'description': '用户上传的原始图像文件'}, + {'type_code': 'perturbed', 'type_name': '加噪后图片', 
'description': '经过扰动算法处理后的防护图像'}, + {'type_code': 'original_generate', 'type_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'}, + {'type_code': 'perturbed_generate', 'type_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'} + ] + + for img_type in image_types: + existing = ImageType.query.filter_by(type_code=img_type['type_code']).first() + if not existing: + new_type = ImageType(**img_type) + db.session.add(new_type) + + # 初始化加噪算法数据 + perturbation_configs = [ + {'method_code': 'aspl', 'method_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense', 'default_epsilon': 6.0}, + {'method_code': 'simac', 'method_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', 'default_epsilon': 8.0}, + {'method_code': 'caat', 'method_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck', 'default_epsilon': 16.0}, + {'method_code': 'pid', 'method_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', 'default_epsilon': 4.0} + ] + + for config in perturbation_configs: + existing = PerturbationConfig.query.filter_by(method_code=config['method_code']).first() + if not existing: + new_config = PerturbationConfig(**config) + db.session.add(new_config) + + # 初始化微调方式数据 + finetune_configs = [ + {'method_code': 'dreambooth', 'method_name': 'DreamBooth', 'description': 'DreamBooth个性化文本到图像生成'}, + {'method_code': 'lora', 'method_name': 'LoRA', 'description': '低秩适应(Low-Rank Adaptation)微调方法'}, + {'method_code': 'textual_inversion', 'method_name': 'Textual Inversion', 'description': '文本反转个性化方法'} + ] + + for config in finetune_configs: + existing = FinetuneConfig.query.filter_by(method_code=config['method_code']).first() + if not existing: + new_config = FinetuneConfig(**config) + db.session.add(new_config) + + # 创建默认管理员用户 + admin_users = [ + {'username': 'admin1', 'email': 'admin1@museguard.com', 'role': 'admin'}, + {'username': 'admin2', 'email': 'admin2@museguard.com', 'role': 'admin'}, + {'username': 'admin3', 'email': 'admin3@museguard.com', 'role': 'admin'} + ] + + for admin_data in admin_users: + existing = User.query.filter_by(username=admin_data['username']).first() + if not existing: + admin_user = User(**admin_data) + admin_user.set_password('admin123') # 默认密码 + db.session.add(admin_user) + + # 为管理员创建默认配置 + db.session.flush() # 确保user.id可用 + user_config = UserConfig(user_id=admin_user.id) + db.session.add(user_config) + + # 提交所有更改 + db.session.commit() + print("数据库初始化完成!") + +if __name__ == '__main__': init_database() \ No newline at end of file diff --git a/src/backend/quick_start.bat b/src/backend/quick_start.bat deleted file mode 100644 index c274f00..0000000 --- a/src/backend/quick_start.bat +++ /dev/null @@ -1,15 +0,0 @@ -@echo off -chcp 65001 > nul -echo ========================================== -echo MuseGuard 快速启动脚本 -echo ========================================== -echo. - -REM 切换到项目目录 -cd /d "d:\code\Software_Project\team_project\MuseGuard\src\backend" - -REM 激活虚拟环境并启动服务器 -echo 正在激活虚拟环境并启动服务器... 
-call venv_py311\Scripts\activate.bat && python run.py - -pause \ No newline at end of file diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index b0b6fd7..5865774 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -1,33 +1,37 @@ -# Core Flask Framework -Flask==3.0.0 -Flask-SQLAlchemy==3.1.1 -Flask-Migrate==4.0.5 -Flask-JWT-Extended==4.6.0 -Flask-CORS==5.0.0 -Werkzeug==3.0.1 - -# Database -PyMySQL==1.1.1 - -# Image Processing -Pillow==10.4.0 -numpy==1.26.4 - -# Security & Utils -cryptography==42.0.8 -python-dotenv==1.0.1 - -# Additional Dependencies (auto-installed) -# alembic==1.17.0 -# blinker==1.9.0 -# cffi==2.0.0 -# click==8.3.0 -# greenlet==3.2.4 -# itsdangerous==2.2.0 -# Jinja2==3.1.6 -# Mako==1.3.10 -# MarkupSafe==3.0.3 -# pycparser==2.23 -# PyJWT==2.10.1 -# SQLAlchemy==2.0.44 +# Core Flask Framework +Flask==3.0.0 +Flask-SQLAlchemy==3.1.1 +Flask-Migrate==4.0.5 +Flask-JWT-Extended==4.6.0 +Flask-CORS==5.0.0 +Werkzeug==3.0.1 + +# Database +PyMySQL==1.1.1 + +# Image Processing +Pillow==10.4.0 +numpy==1.26.4 + +# Security & Utils +cryptography==42.0.8 +python-dotenv==1.0.1 + +# Task Queue +redis==5.0.1 +rq==1.16.2 + +# Additional Dependencies (auto-installed) +# alembic==1.17.0 +# blinker==1.9.0 +# cffi==2.0.0 +# click==8.3.0 +# greenlet==3.2.4 +# itsdangerous==2.2.0 +# Jinja2==3.1.6 +# Mako==1.3.10 +# MarkupSafe==3.0.3 +# pycparser==2.23 +# PyJWT==2.10.1 +# SQLAlchemy==2.0.44 # typing_extensions==4.15.0 \ No newline at end of file diff --git a/src/backend/run.py b/src/backend/run.py index 0f24a01..cca949e 100644 --- a/src/backend/run.py +++ b/src/backend/run.py @@ -1,21 +1,21 @@ -""" -Flask应用启动脚本 -""" - -import os -from app import create_app - -# 设置环境变量 -os.environ.setdefault('FLASK_ENV', 'development') - -# 创建应用实例 -app = create_app() - -if __name__ == '__main__': - # 开发模式启动 - app.run( - host='0.0.0.0', - port=5000, - debug=True, - threaded=True +""" +Flask应用启动脚本 +""" + +import os +from app import create_app + +# 设置环境变量 +os.environ.setdefault('FLASK_ENV', 'development') + +# 创建应用实例 +app = create_app() + +if __name__ == '__main__': + # 开发模式启动 + app.run( + host='0.0.0.0', + port=5000, + debug=True, + threaded=True ) \ No newline at end of file diff --git a/src/backend/start.sh b/src/backend/start.sh new file mode 100644 index 0000000..7730901 --- /dev/null +++ b/src/backend/start.sh @@ -0,0 +1,92 @@ +#!/bin/bash +# MuseGuard 后端快速启动脚本 + +echo "========================================" +echo " MuseGuard 后端服务启动" +echo "========================================" +echo "" + +# 获取脚本所在目录 +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR" + +# 激活conda环境 +echo "激活 conda 环境: flask" +source ~/anaconda3/etc/profile.d/conda.sh +conda activate flask + +# 检查conda环境是否激活成功 +if [ "$CONDA_DEFAULT_ENV" != "flask" ]; then + echo "[错误] 无法激活 flask 环境" + exit 1 +fi + +echo "[成功] conda 环境已激活: $CONDA_DEFAULT_ENV" +echo "" + +# 检查Redis是否运行 +echo "检查 Redis 连接..." +if redis-cli ping > /dev/null 2>&1; then + echo "[成功] Redis 连接正常" +else + echo "[警告] Redis 未运行,正在启动 Redis..." + redis-server --daemonize yes + sleep 2 + if redis-cli ping > /dev/null 2>&1; then + echo "[成功] Redis 已启动" + else + echo "[错误] 无法启动 Redis,请手动启动:" + echo " sudo systemctl start redis-server" + echo " 或: redis-server --daemonize yes" + exit 1 + fi +fi +echo "" + +# 创建日志目录 +mkdir -p logs + +# 停止已存在的进程 +echo "检查并停止已存在的进程..." +pkill -f "python run.py" 2>/dev/null +pkill -f "python worker.py" 2>/dev/null +sleep 1 + +# 启动Flask应用 +echo "启动 Flask 应用..." 
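+# Note: nohup plus '&' detaches the process from this shell so it survives logout;
+# '$!' on the following line captures its PID, which is written to logs/*.pid
+# below so stop.sh can terminate it later.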
+nohup python run.py > logs/flask.log 2>&1 &
+FLASK_PID=$!
+echo "Flask 应用已启动 (PID: $FLASK_PID)"
+echo ""
+
+# 等待Flask启动
+sleep 2
+
+# 启动RQ Worker
+echo "启动 RQ Worker..."
+nohup python worker.py > logs/worker.log 2>&1 &
+WORKER_PID=$!
+echo "RQ Worker 已启动 (PID: $WORKER_PID)"
+echo ""
+
+# 保存PID到文件
+echo $FLASK_PID > logs/flask.pid
+echo $WORKER_PID > logs/worker.pid
+
+echo "========================================"
+echo " 启动完成!"
+echo "========================================"
+echo ""
+echo "服务信息:"
+echo " - Flask API: http://127.0.0.1:5000"
+echo " - Flask PID: $FLASK_PID"
+echo " - Worker PID: $WORKER_PID"
+echo ""
+echo "查看日志:"
+echo " - Flask: tail -f logs/flask.log"
+echo " - Worker: tail -f logs/worker.log"
+echo ""
+echo "停止服务:"
+echo " - 执行: ./stop.sh"
+echo " - 或手动: kill $FLASK_PID $WORKER_PID"
+echo ""
diff --git a/src/backend/static/test.html b/src/backend/static/test.html
index 4fb5976..719ec40 100644
--- a/src/backend/static/test.html
+++ b/src/backend/static/test.html
@@ -1,1644 +1,1644 @@
[test.html: full 1644-line rewrite (line-ending normalization plus page updates). The HTML markup was lost when this patch text was extracted; both the removed and the added version carry the same recoverable section headings: 🧪 MuseGuard API 测试页面, 🌐 服务器连通性测试, 🎨 Demo Controller - 演示模块 (演示图片, 算法信息), 🧑‍💻 Auth Controller - 认证模块 (用户注册, 用户登录, 修改密码), 🔄 Task Controller - 任务管理模块 (1. 创建任务, 我的批次列表, 2. 文件上传, 3. 配置任务, 4. 任务管理), 🖼️ Image Controller - 图像处理模块 (图像查看和下载, 图像评估和对比, 其他功能), 👨‍💼 Admin Controller - 管理员模块 (用户管理, 用户创建和编辑, 系统统计).]
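A hand-run equivalent of the page's connectivity check (illustrative; assumes the stack was started with ./start.sh on the same host and that Redis is local — the exact API routes exercised by test.html are not recoverable here):

    curl -i http://127.0.0.1:5000/     # any HTTP response at all means Flask is up
    redis-cli ping                     # expect PONG, as start.sh verifies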
+ + + \ No newline at end of file diff --git a/src/backend/status.sh b/src/backend/status.sh new file mode 100644 index 0000000..559c15d --- /dev/null +++ b/src/backend/status.sh @@ -0,0 +1,85 @@ +#!/bin/bash +# MuseGuard 后端服务状态检查脚本 + +echo "========================================" +echo " MuseGuard 后端服务状态" +echo "========================================" +echo "" + +# 获取脚本所在目录 +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR" + +# 检查Flask应用 +echo "📌 Flask 应用:" +if [ -f logs/flask.pid ]; then + FLASK_PID=$(cat logs/flask.pid) + if ps -p $FLASK_PID > /dev/null 2>&1; then + echo " ✅ 运行中 (PID: $FLASK_PID)" + echo " 📍 URL: http://127.0.0.1:5000" + else + echo " ❌ 未运行 (PID文件存在但进程不存在)" + fi +else + if pgrep -f "python run.py" > /dev/null 2>&1; then + FLASK_PID=$(pgrep -f "python run.py") + echo " ⚠️ 运行中但无PID文件 (PID: $FLASK_PID)" + else + echo " ❌ 未运行" + fi +fi +echo "" + +# 检查Worker +echo "📌 RQ Worker:" +if [ -f logs/worker.pid ]; then + WORKER_PID=$(cat logs/worker.pid) + if ps -p $WORKER_PID > /dev/null 2>&1; then + echo " ✅ 运行中 (PID: $WORKER_PID)" + else + echo " ❌ 未运行 (PID文件存在但进程不存在)" + fi +else + if pgrep -f "python worker.py" > /dev/null 2>&1; then + WORKER_PID=$(pgrep -f "python worker.py") + echo " ⚠️ 运行中但无PID文件 (PID: $WORKER_PID)" + else + echo " ❌ 未运行" + fi +fi +echo "" + +# 检查Redis +echo "📌 Redis:" +if redis-cli ping > /dev/null 2>&1; then + echo " ✅ 运行中" +else + echo " ❌ 未运行" +fi +echo "" + +# 检查日志文件 +echo "📌 日志文件:" +if [ -f logs/flask.log ]; then + FLASK_LOG_SIZE=$(du -h logs/flask.log | cut -f1) + echo " Flask: logs/flask.log ($FLASK_LOG_SIZE)" +else + echo " Flask: 无日志文件" +fi + +if [ -f logs/worker.log ]; then + WORKER_LOG_SIZE=$(du -h logs/worker.log | cut -f1) + echo " Worker: logs/worker.log ($WORKER_LOG_SIZE)" +else + echo " Worker: 无日志文件" +fi +echo "" + +echo "========================================" +echo " 快速操作" +echo "========================================" +echo "启动服务: ./start.sh" +echo "停止服务: ./stop.sh" +echo "查看Flask日志: tail -f logs/flask.log" +echo "查看Worker日志: tail -f logs/worker.log" +echo "" diff --git a/src/backend/stop.sh b/src/backend/stop.sh new file mode 100644 index 0000000..ceffd8a --- /dev/null +++ b/src/backend/stop.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# MuseGuard 后端服务停止脚本 + +echo "========================================" +echo " 停止 MuseGuard 后端服务" +echo "========================================" +echo "" + +# 获取脚本所在目录 +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR" + +# 从PID文件停止进程 +if [ -f logs/flask.pid ]; then + FLASK_PID=$(cat logs/flask.pid) + if ps -p $FLASK_PID > /dev/null 2>&1; then + echo "停止 Flask 应用 (PID: $FLASK_PID)..." + kill $FLASK_PID + echo "[成功] Flask 应用已停止" + else + echo "[提示] Flask 应用未运行" + fi + rm logs/flask.pid +else + echo "[提示] 未找到 Flask PID 文件" +fi + +if [ -f logs/worker.pid ]; then + WORKER_PID=$(cat logs/worker.pid) + if ps -p $WORKER_PID > /dev/null 2>&1; then + echo "停止 RQ Worker (PID: $WORKER_PID)..." + kill $WORKER_PID + echo "[成功] RQ Worker 已停止" + else + echo "[提示] RQ Worker 未运行" + fi + rm logs/worker.pid +else + echo "[提示] 未找到 Worker PID 文件" +fi + +# 确保所有相关进程都停止 +echo "" +echo "清理所有相关进程..." 
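+# pkill -f matches against the full command line, so this also catches processes
+# started outside start.sh; '&& echo' only fires when something was actually killed.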
+pkill -f "python run.py" 2>/dev/null && echo "清理了额外的 Flask 进程" +pkill -f "python worker.py" 2>/dev/null && echo "清理了额外的 Worker 进程" + +echo "" +echo "========================================" +echo " 服务已停止" +echo "========================================" diff --git a/src/backend/worker.py b/src/backend/worker.py new file mode 100644 index 0000000..2c4d77c --- /dev/null +++ b/src/backend/worker.py @@ -0,0 +1,42 @@ +""" +RQ Worker 启动脚本 +用于启动后台任务处理器 +""" + +import sys +import os + +# 添加项目路径到Python路径 +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from redis import Redis +from rq import Worker, Queue +from config.algorithm_config import AlgorithmConfig +from app import create_app + +# 创建Flask应用上下文 +app = create_app() + +def main(): + """启动worker""" + with app.app_context(): + # 连接Redis + redis_conn = Redis.from_url(AlgorithmConfig.REDIS_URL) + + # 创建队列 + queue = Queue(AlgorithmConfig.RQ_QUEUE_NAME, connection=redis_conn) + + # 创建worker + worker = Worker([queue], connection=redis_conn) + + print(f"🚀 RQ Worker启动成功!") + print(f"📡 Redis: {AlgorithmConfig.REDIS_URL}") + print(f"📋 Queue: {AlgorithmConfig.RQ_QUEUE_NAME}") + print(f"🔄 使用{'真实' if AlgorithmConfig.USE_REAL_ALGORITHMS else '虚拟'}算法") + print(f"⏳ 等待任务...") + + # 启动worker + worker.work() + +if __name__ == '__main__': + main() -- 2.34.1 From 5b014a74d8315757131edc9fe5d94ec9ff01a7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sat, 15 Nov 2025 16:11:18 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E4=BA=91=E6=9C=8D=E5=8A=A1=E5=99=A8?= =?UTF-8?q?=E6=88=90=E5=8A=9F=E8=BF=90=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/.env.example | 36 - src/backend/.gitignore | 26 + src/backend/README.md | 118 +- src/backend/app.py | 2 +- .../finetune/train_dreambooth_gen.py | 12 +- .../app/algorithms/finetune/train_lora_gen.py | 11 +- .../finetune_virtual/train_dreambooth_gen.py | 24 +- .../finetune_virtual/train_lora_gen.py | 24 +- .../app/algorithms/perturbation/aspl.py | 3 + .../app/algorithms/perturbation/caat.py | 2 +- .../app/algorithms/perturbation/pid.py | 4 +- .../algorithms/perturbation_virtual/aspl.py | 7 +- .../algorithms/perturbation_virtual/caat.py | 11 +- .../algorithms/perturbation_virtual/pid.py | 9 +- .../algorithms/perturbation_virtual/simac.py | 11 +- .../app/controllers/task_controller.py | 345 ++- src/backend/app/database/__init__.py | 59 +- src/backend/app/services/image_service.py | 184 +- src/backend/app/services/task_service.py | 119 +- src/backend/app/workers/finetune_worker.py | 162 +- .../app/workers/perturbation_worker.py | 148 +- src/backend/config/.env | 16 - src/backend/config/algorithm_config.py | 121 +- src/backend/config/settings.py | 6 +- src/backend/run.py | 2 +- src/backend/start.sh | 23 +- src/backend/static/test.html | 2209 +++++------------ src/backend/static/test0.html | 1644 ++++++++++++ src/backend/status.sh | 3 +- src/backend/stop.sh | 0 30 files changed, 3404 insertions(+), 1937 deletions(-) delete mode 100644 src/backend/.env.example create mode 100644 src/backend/.gitignore delete mode 100644 src/backend/config/.env mode change 100644 => 100755 src/backend/start.sh create mode 100644 src/backend/static/test0.html mode change 100644 => 100755 src/backend/status.sh mode change 100644 => 100755 src/backend/stop.sh diff --git a/src/backend/.env.example b/src/backend/.env.example deleted file mode 100644 index 6f93179..0000000 --- a/src/backend/.env.example +++ /dev/null @@ -1,36 +0,0 @@ -# 
============================================
-# 数据库配置
-# ============================================
-DB_HOST=localhost
-DB_PORT=3306
-DB_USER=root
-DB_PASSWORD=your_password
-DB_NAME=museguard_schema
-
-# ============================================
-# Flask应用配置
-# ============================================
-SECRET_KEY=your-secret-key-here
-JWT_SECRET_KEY=your-jwt-secret-key-here
-FLASK_ENV=development
-
-# ============================================
-# Redis配置(用于任务队列)
-# ============================================
-REDIS_URL=redis://localhost:6379/0
-
-# ============================================
-# 算法模式配置
-# ============================================
-# true: 使用真实算法(需要conda环境和完整依赖)
-# false: 使用虚拟算法(快速测试,不需要GPU和模型)
-USE_REAL_ALGORITHMS=false
-
-# ============================================
-# 邮件配置(可选,用于注册验证)
-# ============================================
-MAIL_SERVER=smtp.gmail.com
-MAIL_PORT=587
-MAIL_USE_TLS=true
-MAIL_USERNAME=your_email@gmail.com
-MAIL_PASSWORD=your_email_password
diff --git a/src/backend/.gitignore b/src/backend/.gitignore
new file mode 100644
index 0000000..781efa2
--- /dev/null
+++ b/src/backend/.gitignore
@@ -0,0 +1,26 @@
+# Python 编译缓存
+__pycache__/
+
+# 图片文件
+*.png
+*.jpg
+*.jpeg
+
+# 环境配置文件(包含敏感信息)
+*.env
+
+# 日志及进程文件
+logs/
+*.log
+*.pid
+
+# 上传文件临时目录
+uploads/
+
+# 微调生成文件
+*.json
+*.bin
+*.pkl
+*.safetensors
+*.pt
+*.txt
\ No newline at end of file
diff --git a/src/backend/README.md b/src/backend/README.md
index 5b53032..f0d4969 100644
--- a/src/backend/README.md
+++ b/src/backend/README.md
@@ -2,6 +2,118 @@
 
 基于对抗性扰动的多风格图像生成防护系统 - 后端API服务
 
+
+## Linux 环境配置(MySQL、Redis、Python等)
+
+### 1. 安装系统依赖
+
+```bash
+sudo apt update
+sudo apt install -y build-essential python3 python3-venv python3-pip git
+```
+
+### 2. 安装 MySQL
+
+旧版系统使用 service 命令管理 MySQL:
+```bash
+# 启动 MySQL
+sudo service mysqld start
+# 停止 MySQL
+sudo service mysqld stop
+# 重启 MySQL
+sudo service mysqld restart
+# 查看 MySQL 状态
+sudo service mysqld status
+```
+
+新版系统使用 systemctl:
+```bash
+sudo apt install -y mysql-server
+sudo systemctl enable mysql
+sudo systemctl start mysql
+
+# 安全初始化(可选,建议设置root密码)
+sudo mysql_secure_installation
+
+# 登录MySQL创建数据库和用户
+mysql -u root -p
+```
+在MySQL命令行中执行:
+```sql
+CREATE DATABASE museguard DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+# CREATE USER 'museguard'@'localhost' IDENTIFIED BY 'yourpassword';
+# GRANT ALL PRIVILEGES ON museguard.* TO 'museguard'@'localhost';
+# FLUSH PRIVILEGES;
+EXIT;
+```
+
+### 3. 安装 Redis
+
+旧版系统使用 service 命令管理 Redis:
+```bash
+# 启动 Redis
+sudo service redis-server start
+# 停止 Redis
+sudo service redis-server stop
+# 重启 Redis
+sudo service redis-server restart
+# 查看 Redis 状态
+sudo service redis-server status
+```
+
+新版系统使用 systemctl:
+```bash
+sudo apt install -y redis-server
+sudo systemctl enable redis-server
+sudo systemctl start redis-server
+
+# 测试
+redis-cli ping
+# 返回PONG表示正常
+```
+
+### 4. Python 虚拟环境与依赖
+
+在 Linux 上也可换用 conda 管理环境(start.sh 默认激活名为 flask 的 conda 环境):
+```bash
+cd /path/to/your/project/src/backend
+python3 -m venv venv
+source venv/bin/activate
+python -m pip install --upgrade pip
+pip install -r requirements.txt
+```
+
+### 5. 配置数据库连接
+
+如需重建数据库,先在 MySQL 命令行中执行:
+```sql
+DROP DATABASE IF EXISTS museguard;
+CREATE DATABASE museguard DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+EXIT;
+```
+
+然后编辑 `config/settings.py`(或其读取的 `config/.env` 文件),设置如下内容:
+```python
+SQLALCHEMY_DATABASE_URI = 'mysql+pymysql://museguard:yourpassword@localhost:3306/museguard?charset=utf8mb4'
+REDIS_URL = 'redis://localhost:6379/0'
+```
+
+### 6. 初始化数据库
+
+```bash
+python init_db.py
+```
+
+### 7.
启动服务 + +```bash +# 启动Flask后端 +python run.py + +# 启动RQ Worker(另开终端) +source venv/bin/activate +cd /path/to/your/project/src/backend +rq worker museguard +``` + +--- + ## 项目结构 ``` @@ -240,4 +352,8 @@ flask run ## 许可证 -本项目仅用于学习和研究目的。 \ No newline at end of file +本项目仅用于学习和研究目的。 + + +https://docs.pingcode.com/baike/2645380 + diff --git a/src/backend/app.py b/src/backend/app.py index 59d7da9..8136e2a 100644 --- a/src/backend/app.py +++ b/src/backend/app.py @@ -43,4 +43,4 @@ def create_app(config_class=Config): if __name__ == '__main__': app = create_app() - app.run(debug=True, host='0.0.0.0', port=5000) \ No newline at end of file + app.run(debug=True, host='0.0.0.0', port=6006) \ No newline at end of file diff --git a/src/backend/app/algorithms/finetune/train_dreambooth_gen.py b/src/backend/app/algorithms/finetune/train_dreambooth_gen.py index fb9721e..c34a908 100644 --- a/src/backend/app/algorithms/finetune/train_dreambooth_gen.py +++ b/src/backend/app/algorithms/finetune/train_dreambooth_gen.py @@ -523,6 +523,11 @@ def parse_args(input_args=None): " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" ), ) + parser.add_argument( + "--is_perturbed", + action="store_true", + help="Whether training on perturbed images. Affects the generated image naming.", + ) parser.add_argument( "--offset_noise", @@ -1380,9 +1385,12 @@ def main(args): save_path.mkdir(parents=True, exist_ok=True) logger.info(f"Saving validation images directly to {save_path}, overwriting previous images.") + # 根据is_perturbed决定文件名前缀 + prefix = "generated_perturbed_" if args.is_perturbed else "generated_original_" + for i, image in enumerate(images): - # The file name is constant, thus overwriting - image.save(save_path / f"validation_image_{i}.png") + # 使用序号格式化: 0000-9999 + image.save(save_path / f"{prefix}{i:04d}.png") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) diff --git a/src/backend/app/algorithms/finetune/train_lora_gen.py b/src/backend/app/algorithms/finetune/train_lora_gen.py index 29e6ad5..4a951b2 100644 --- a/src/backend/app/algorithms/finetune/train_lora_gen.py +++ b/src/backend/app/algorithms/finetune/train_lora_gen.py @@ -528,6 +528,11 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--is_perturbed", + action="store_true", + help="Whether training on perturbed images. 
Affects the generated image naming.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -1354,9 +1359,13 @@ def main(args): base_save_path.mkdir(parents=True, exist_ok=True) logger.info(f"Saving validation images to {base_save_path}") + # 根据is_perturbed决定文件名前缀 + prefix = "generated_perturbed_" if args.is_perturbed else "generated_original_" + # 图片直接保存在 base_save_path,会覆盖上一轮的同名图片 for i, image in enumerate(images): - image.save(base_save_path / f"image_{i}.png") + # 使用序号格式化: 0000-9999 + image.save(base_save_path / f"{prefix}{i:04d}.png") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} diff --git a/src/backend/app/algorithms/finetune_virtual/train_dreambooth_gen.py b/src/backend/app/algorithms/finetune_virtual/train_dreambooth_gen.py index 96475df..8528ccb 100644 --- a/src/backend/app/algorithms/finetune_virtual/train_dreambooth_gen.py +++ b/src/backend/app/algorithms/finetune_virtual/train_dreambooth_gen.py @@ -8,24 +8,10 @@ import sys import platform import shutil import glob -from PIL import Image, ImageDraw, ImageFont def create_generated_image(source_image_path, output_path, index): - """创建一个模拟生成的图片(添加水印表示是虚拟生成的)""" - with Image.open(source_image_path) as img: - # 复制原图 - generated = img.copy() - draw = ImageDraw.Draw(generated) - - # 添加水印文字 - width, height = generated.size - text = f"Virtual Generated #{index}" - - # 简单在图片上绘制文字 - position = (10, height - 30) - draw.text(position, text, fill=(255, 255, 255)) - - generated.save(output_path, quality=95) + """创建一个模拟生成的图片(简单复制源图片)""" + shutil.copy2(source_image_path, output_path) def main(): parser = argparse.ArgumentParser(description="DreamBooth虚拟微调脚本") @@ -54,6 +40,7 @@ def main(): parser.add_argument('--validation_prompt', default='a photo of sks person') parser.add_argument('--num_validation_images', type=int, default=10) parser.add_argument('--validation_steps', type=int, default=500) + parser.add_argument('--is_perturbed', action='store_true', help='Whether training on perturbed images') args = parser.parse_args() @@ -103,10 +90,13 @@ def main(): print("-" * 80) print(f"[VIRTUAL] 生成 {args.num_validation_images} 张验证图片...") + # 根据is_perturbed决定文件名前缀 + prefix = "generated_perturbed_" if args.is_perturbed else "generated_original_" + generated_count = 0 for i in range(min(args.num_validation_images, len(image_files))): source_image = image_files[i % len(image_files)] - filename = f"generated_{i:04d}.png" + filename = f"{prefix}{i:04d}.png" output_path = os.path.join(args.validation_image_output_dir, filename) try: diff --git a/src/backend/app/algorithms/finetune_virtual/train_lora_gen.py b/src/backend/app/algorithms/finetune_virtual/train_lora_gen.py index 199adc6..7923640 100644 --- a/src/backend/app/algorithms/finetune_virtual/train_lora_gen.py +++ b/src/backend/app/algorithms/finetune_virtual/train_lora_gen.py @@ -8,24 +8,10 @@ import sys import platform import shutil import glob -from PIL import Image, ImageDraw, ImageFont def create_generated_image(source_image_path, output_path, index): - """创建一个模拟生成的图片(添加水印表示是虚拟生成的)""" - with Image.open(source_image_path) as img: - # 复制原图 - generated = img.copy() - draw = ImageDraw.Draw(generated) - - # 添加水印文字 - width, height = generated.size - text = f"Virtual LoRA #{index}" - - # 简单在图片上绘制文字 - position = (10, height - 30) - draw.text(position, text, fill=(255, 255, 255)) - - generated.save(output_path, quality=95) + """创建一个模拟生成的图片(简单复制源图片)""" + shutil.copy2(source_image_path, output_path) def main(): parser = 
argparse.ArgumentParser(description="LoRA虚拟微调脚本") @@ -52,6 +38,7 @@ def main(): parser.add_argument('--rank', type=int, default=4) parser.add_argument('--validation_prompt', default='a photo of sks person') parser.add_argument('--num_validation_images', type=int, default=10) + parser.add_argument('--is_perturbed', action='store_true', help='Whether training on perturbed images') args = parser.parse_args() @@ -102,10 +89,13 @@ def main(): print("-" * 80) print(f"[VIRTUAL] 生成 {args.num_validation_images} 张验证图片...") + # 根据is_perturbed决定文件名前缀 + prefix = "generated_perturbed_" if args.is_perturbed else "generated_original_" + generated_count = 0 for i in range(min(args.num_validation_images, len(image_files))): source_image = image_files[i % len(image_files)] - filename = f"generated_{i:04d}.png" + filename = f"{prefix}{i:04d}.png" output_path = os.path.join(args.validation_image_output_dir, filename) try: diff --git a/src/backend/app/algorithms/perturbation/aspl.py b/src/backend/app/algorithms/perturbation/aspl.py index 6f26194..8a7fc88 100644 --- a/src/backend/app/algorithms/perturbation/aspl.py +++ b/src/backend/app/algorithms/perturbation/aspl.py @@ -758,9 +758,12 @@ def main(args): for img_pixel, img_name in zip(noised_imgs, img_filenames): save_path = os.path.join(save_folder, f"perturbed_{img_name}.png") + logger.info(f"即将保存图片到: {save_path}") Image.fromarray( (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy() ).save(save_path) + + logger.info(f"图片已保存到: {save_path}") print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)") diff --git a/src/backend/app/algorithms/perturbation/caat.py b/src/backend/app/algorithms/perturbation/caat.py index c7e41cd..d15cdd4 100644 --- a/src/backend/app/algorithms/perturbation/caat.py +++ b/src/backend/app/algorithms/perturbation/caat.py @@ -953,7 +953,7 @@ def main(args): for i in range(num_images_to_save): img_pixel = noised_imgs[i] img_name = img_names[i] - save_path = os.path.join(save_folder, f"final_noise_{img_name}") + save_path = os.path.join(save_folder, f"perturbed_{img_name}") # 图像转换和保存 Image.fromarray( diff --git a/src/backend/app/algorithms/perturbation/pid.py b/src/backend/app/algorithms/perturbation/pid.py index e4e35fc..98b4fc7 100644 --- a/src/backend/app/algorithms/perturbation/pid.py +++ b/src/backend/app/algorithms/perturbation/pid.py @@ -217,7 +217,9 @@ def main(args): for i in range(0, len(dataset.instance_images_path)): img = dataset[i]['pixel_values'] img = to_image(img + attackmodel.delta[i]) - img.save(os.path.join(args.output_dir, f"{i}.png")) + # 使用原文件名,添加perturbed_前缀 + original_filename = Path(dataset.instance_images_path[i]).stem + img.save(os.path.join(args.output_dir, f"perturbed_{original_filename}.png")) # Select target loss diff --git a/src/backend/app/algorithms/perturbation_virtual/aspl.py b/src/backend/app/algorithms/perturbation_virtual/aspl.py index d20cb11..c9cda08 100644 --- a/src/backend/app/algorithms/perturbation_virtual/aspl.py +++ b/src/backend/app/algorithms/perturbation_virtual/aspl.py @@ -73,10 +73,13 @@ def main(): copied_count = 0 for image_path in image_files: filename = os.path.basename(image_path) - output_path = os.path.join(args.output_dir, filename) + # 添加perturbed_前缀 + name, ext = os.path.splitext(filename) + perturbed_filename = f"perturbed_{name}{ext}" + output_path = os.path.join(args.output_dir, perturbed_filename) shutil.copy(image_path, output_path) copied_count += 1 - print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: 
{filename}") + print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename} -> {perturbed_filename}") print("-" * 80) print(f"[VIRTUAL] 成功处理 {copied_count} 张图片") diff --git a/src/backend/app/algorithms/perturbation_virtual/caat.py b/src/backend/app/algorithms/perturbation_virtual/caat.py index bbc436a..3f0924e 100644 --- a/src/backend/app/algorithms/perturbation_virtual/caat.py +++ b/src/backend/app/algorithms/perturbation_virtual/caat.py @@ -24,6 +24,10 @@ def main(): parser.add_argument('--train_batch_size', type=int, default=2) parser.add_argument('--max_train_steps', type=int, default=800) parser.add_argument('--learning_rate', type=float, default=1e-5) + parser.add_argument('--lr_warmup_steps', type=int, default=0) + parser.add_argument('--hflip', action='store_true') + parser.add_argument('--mixed_precision', type=str, default='bf16') + parser.add_argument('--alpha', type=float, default=5e-3) parser.add_argument('--pgd_eps', type=float, default=16) parser.add_argument('--seed', type=int, default=0) @@ -65,10 +69,13 @@ def main(): copied_count = 0 for image_path in image_files: filename = os.path.basename(image_path) - output_path = os.path.join(args.output_dir, filename) + # 添加perturbed_前缀 + name, ext = os.path.splitext(filename) + perturbed_filename = f"perturbed_{name}{ext}" + output_path = os.path.join(args.output_dir, perturbed_filename) shutil.copy(image_path, output_path) copied_count += 1 - print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename}") + print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename} -> {perturbed_filename}") print("-" * 80) print(f"[VIRTUAL] 成功处理 {copied_count} 张图片") diff --git a/src/backend/app/algorithms/perturbation_virtual/pid.py b/src/backend/app/algorithms/perturbation_virtual/pid.py index 5453f01..c7b6052 100644 --- a/src/backend/app/algorithms/perturbation_virtual/pid.py +++ b/src/backend/app/algorithms/perturbation_virtual/pid.py @@ -23,6 +23,8 @@ def main(): parser.add_argument('--resolution', type=int, default=512) parser.add_argument('--train_batch_size', type=int, default=1) parser.add_argument('--max_train_steps', type=int, default=600) + parser.add_argument('--center_crop', action='store_true') + parser.add_argument('--attack_type', type=str, default='add-log') parser.add_argument('--learning_rate', type=float, default=3e-6) parser.add_argument('--pgd_eps', type=float, default=4) parser.add_argument('--seed', type=int, default=0) @@ -65,10 +67,13 @@ def main(): copied_count = 0 for image_path in image_files: filename = os.path.basename(image_path) - output_path = os.path.join(args.output_dir, filename) + # 添加perturbed_前缀 + name, ext = os.path.splitext(filename) + perturbed_filename = f"perturbed_{name}{ext}" + output_path = os.path.join(args.output_dir, perturbed_filename) shutil.copy(image_path, output_path) copied_count += 1 - print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename}") + print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename} -> {perturbed_filename}") print("-" * 80) print(f"[VIRTUAL] 成功处理 {copied_count} 张图片") diff --git a/src/backend/app/algorithms/perturbation_virtual/simac.py b/src/backend/app/algorithms/perturbation_virtual/simac.py index 573331e..f08e390 100644 --- a/src/backend/app/algorithms/perturbation_virtual/simac.py +++ b/src/backend/app/algorithms/perturbation_virtual/simac.py @@ -26,7 +26,11 @@ def main(): parser.add_argument('--resolution', type=int, default=512) parser.add_argument('--train_batch_size', type=int, default=1) 
parser.add_argument('--max_train_steps', type=int, default=1000) + parser.add_argument('--max_f_train_steps', type=int, default=3) + parser.add_argument('--max_adv_train_steps', type=int, default=6) + parser.add_argument('--checkpointing_iterations', type=int, default=10) parser.add_argument('--learning_rate', type=float, default=5e-6) + parser.add_argument('--pgd_alpha', type=float, default=0.005) parser.add_argument('--pgd_eps', type=float, default=8) parser.add_argument('--seed', type=int, default=0) @@ -68,10 +72,13 @@ def main(): copied_count = 0 for image_path in image_files: filename = os.path.basename(image_path) - output_path = os.path.join(args.output_dir, filename) + # 添加perturbed_前缀 + name, ext = os.path.splitext(filename) + perturbed_filename = f"perturbed_{name}{ext}" + output_path = os.path.join(args.output_dir, perturbed_filename) shutil.copy(image_path, output_path) copied_count += 1 - print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename}") + print(f"[VIRTUAL] 处理图片 {copied_count}/{len(image_files)}: {filename} -> {perturbed_filename}") print("-" * 80) print(f"[VIRTUAL] 成功处理 {copied_count} 张图片") diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index 4a57f8b..f534fec 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -7,7 +7,7 @@ from flask import Blueprint, request, jsonify, current_app from flask_jwt_extended import jwt_required, get_jwt_identity from werkzeug.utils import secure_filename from app import db -from app.database import User, Batch, Image, ImageType, UserConfig +from app.database import User, Batch, Image, ImageType, UserConfig, FinetuneBatch, FinetuneConfig from app.services.task_service import TaskService from app.services.image_service import ImageService from app.utils.file_utils import allowed_file, save_uploaded_file @@ -28,41 +28,70 @@ def create_task(): if not user: return jsonify({'error': '用户不存在'}), 404 + data = request.get_json() batch_name = data.get('batch_name', f'Task-{uuid.uuid4().hex[:8]}') - - # 获取用户配置作为默认配置 + + # 优先使用前端传来的参数,没有则用用户配置,没有再用默认 + perturbation_config_id = data.get('perturbation_config_id') + preferred_epsilon = data.get('epsilon') + use_strong_protection = data.get('use_strong_protection') + user_config = UserConfig.query.filter_by(user_id=current_user_id).first() - if user_config: - # 使用用户上次的配置 - perturbation_config_id = user_config.preferred_perturbation_config_id or 1 - preferred_epsilon = user_config.preferred_epsilon or 8.0 - finetune_config_id = user_config.preferred_finetune_config_id or 1 - use_strong_protection = user_config.preferred_purification or False + if perturbation_config_id is None: + perturbation_config_id = user_config.preferred_perturbation_config_id or 1 + if preferred_epsilon is None: + preferred_epsilon = user_config.preferred_epsilon or 8.0 + if use_strong_protection is None: + use_strong_protection = user_config.preferred_purification or False else: - # 使用系统默认配置 + perturbation_config_id = perturbation_config_id or 1 + preferred_epsilon = preferred_epsilon or 8.0 + use_strong_protection = use_strong_protection if use_strong_protection is not None else False + + # 类型转换,防止前端传字符串 + try: + perturbation_config_id = int(perturbation_config_id) + except Exception: perturbation_config_id = 1 + try: + preferred_epsilon = float(preferred_epsilon) + except Exception: preferred_epsilon = 8.0 - finetune_config_id = 1 - use_strong_protection = False - - # 创建任务 + 
use_strong_protection = bool(use_strong_protection) + + # 创建任务(只包含扰动相关配置,不包含微调配置) batch = Batch( user_id=current_user_id, batch_name=batch_name, perturbation_config_id=perturbation_config_id, preferred_epsilon=preferred_epsilon, - finetune_config_id=finetune_config_id, use_strong_protection=use_strong_protection ) - + db.session.add(batch) db.session.commit() + # 自动创建关联的微调任务(如果用户有默认微调配置则自动设置) + finetune_config_id = None + if user_config and user_config.preferred_finetune_config_id: + finetune_config_id = user_config.preferred_finetune_config_id + + finetune_batch = FinetuneBatch( + batch_id=batch.id, + user_id=current_user_id, + finetune_config_id=finetune_config_id, + status='pending' + ) + db.session.add(finetune_batch) + db.session.commit() + return jsonify({ 'message': '任务创建成功,请上传图片', - 'task': batch.to_dict() + 'task': batch.to_dict(), + 'finetune_task_id': finetune_batch.id, + 'finetune_config_set': finetune_config_id is not None }), 201 except Exception as e: @@ -98,18 +127,17 @@ def upload_images(batch_id): for file in files: if file.filename == '': continue - if file and allowed_file(file.filename): # 处理单张图片 if not file.filename.lower().endswith(('.zip', '.rar')): + # 统一走save_image,内部已实现上传到uploads和预处理 result = ImageService.save_image(file, batch_id, current_user_id, original_type.id) if result['success']: uploaded_files.append(result['image']) else: return jsonify({'error': result['error']}), 400 - - # 处理压缩包 else: + # 压缩包内图片也会走save_image results = ImageService.extract_and_save_zip(file, batch_id, current_user_id, original_type.id) for result in results: if result['success']: @@ -146,14 +174,12 @@ def get_task_config(batch_id): suggested_config = { 'perturbation_config_id': user_config.preferred_perturbation_config_id, 'epsilon': float(user_config.preferred_epsilon), - 'finetune_config_id': user_config.preferred_finetune_config_id, 'use_strong_protection': user_config.preferred_purification } else: suggested_config = { 'perturbation_config_id': batch.perturbation_config_id, 'epsilon': float(batch.preferred_epsilon), - 'finetune_config_id': batch.finetune_config_id, 'use_strong_protection': batch.use_strong_protection } @@ -163,7 +189,6 @@ def get_task_config(batch_id): 'current_config': { 'perturbation_config_id': batch.perturbation_config_id, 'epsilon': float(batch.preferred_epsilon), - 'finetune_config_id': batch.finetune_config_id, 'use_strong_protection': batch.use_strong_protection } }), 200 @@ -188,7 +213,7 @@ def update_task_config(batch_id): data = request.get_json() - # 更新任务配置 + # 更新任务配置(仅扰动相关) if 'perturbation_config_id' in data: batch.perturbation_config_id = data['perturbation_config_id'] @@ -199,9 +224,6 @@ def update_task_config(batch_id): else: return jsonify({'error': '扰动强度必须在0-255之间'}), 400 - if 'finetune_config_id' in data: - batch.finetune_config_id = data['finetune_config_id'] - if 'use_strong_protection' in data: batch.use_strong_protection = bool(data['use_strong_protection']) @@ -228,8 +250,13 @@ def start_task(batch_id): if not batch: return jsonify({'error': '任务不存在或无权限'}), 404 - if batch.status != 'pending': + if batch.status not in ['pending', 'failed', 'canceled']: return jsonify({'error': '任务状态不正确,无法开始处理'}), 400 + # 如果是失败或取消,重置状态为pending + if batch.status in ['failed', 'canceled']: + batch.status = 'pending' + batch.error_message = None + db.session.commit() # 检查是否有上传的图片 image_count = Image.query.filter_by(batch_id=batch_id).count() @@ -316,4 +343,264 @@ def get_task_status(batch_id): }), 200 except Exception as e: - return jsonify({'error': f'获取任务状态失败: 
{str(e)}'}), 500
\ No newline at end of file
+    return jsonify({'error': f'获取任务状态失败: {str(e)}'}), 500
+
+# ==================== 微调任务管理接口 ====================
+
+@task_bp.route('/finetune/configs', methods=['GET'])
+@jwt_required()
+def get_finetune_configs():
+    """获取所有可用的微调配置"""
+    try:
+        configs = FinetuneConfig.query.all()
+        return jsonify({
+            'configs': [{
+                'id': config.id,
+                'method_code': config.method_code,
+                'method_name': config.method_name,
+                'description': config.description
+            } for config in configs]
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': f'获取微调配置失败: {str(e)}'}), 500
+
+@task_bp.route('/finetune/list', methods=['GET'])
+@jwt_required()
+def list_finetune_tasks():
+    """获取用户的所有微调任务列表"""
+    try:
+        current_user_id = get_jwt_identity()
+
+        page = request.args.get('page', 1, type=int)
+        per_page = request.args.get('per_page', 10, type=int)
+
+        finetune_tasks = FinetuneBatch.query.filter_by(user_id=current_user_id)\
+            .order_by(FinetuneBatch.created_at.desc())\
+            .paginate(page=page, per_page=per_page, error_out=False)
+
+        results = []
+        for ft in finetune_tasks.items:
+            ft_dict = ft.to_dict()
+            # 添加关联的扰动任务信息
+            ft_dict['batch_info'] = ft.batch.to_dict() if ft.batch else None
+            results.append(ft_dict)
+
+        return jsonify({
+            'finetune_tasks': results,
+            'total': finetune_tasks.total,
+            'pages': finetune_tasks.pages,
+            'current_page': page
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': f'获取微调任务列表失败: {str(e)}'}), 500
+
+@task_bp.route('/finetune/<int:finetune_id>', methods=['GET'])
+@jwt_required()
+def get_finetune_task(finetune_id):
+    """获取微调任务详情"""
+    try:
+        current_user_id = get_jwt_identity()
+
+        finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first()
+        if not finetune_task:
+            return jsonify({'error': '微调任务不存在或无权限'}), 404
+
+        result = finetune_task.to_dict()
+        result['batch_info'] = finetune_task.batch.to_dict() if finetune_task.batch else None
+
+        return jsonify({'finetune_task': result}), 200
+
+    except Exception as e:
+        return jsonify({'error': f'获取微调任务详情失败: {str(e)}'}), 500
+
+@task_bp.route('/finetune/by-batch/<int:batch_id>', methods=['GET'])
+@jwt_required()
+def get_finetune_by_batch(batch_id):
+    """根据扰动任务ID获取关联的微调任务"""
+    try:
+        current_user_id = get_jwt_identity()
+
+        # 验证扰动任务权限
+        batch = Batch.query.filter_by(id=batch_id, user_id=current_user_id).first()
+        if not batch:
+            return jsonify({'error': '扰动任务不存在或无权限'}), 404
+
+        finetune_task = FinetuneBatch.query.filter_by(batch_id=batch_id, user_id=current_user_id).first()
+        if not finetune_task:
+            return jsonify({'error': '该扰动任务没有关联的微调任务'}), 404
+
+        result = finetune_task.to_dict()
+        result['batch_info'] = batch.to_dict()
+
+        return jsonify({'finetune_task': result}), 200
+
+    except Exception as e:
+        return jsonify({'error': f'获取微调任务失败: {str(e)}'}), 500
+
+@task_bp.route('/finetune/<int:finetune_id>/config', methods=['GET'])
+@jwt_required()
+def get_finetune_config(finetune_id):
+    """获取微调任务配置(显示用户默认配置或当前配置)"""
+    try:
+        current_user_id = get_jwt_identity()
+
+        finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first()
+        if not finetune_task:
+            return jsonify({'error': '微调任务不存在或无权限'}), 404
+
+        # 获取用户配置
+        user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
+
+        # 如果用户有配置,显示用户默认配置;否则显示系统默认
+        if user_config and user_config.preferred_finetune_config_id:
+            suggested_config = {
+                'finetune_config_id': user_config.preferred_finetune_config_id,
+                'finetune_config_name': user_config.preferred_finetune_config.method_name if user_config.preferred_finetune_config else None
+            }
+        else:
+            # 默认使用第一个微调配置
+            default_config = FinetuneConfig.query.first()
+            suggested_config = {
+                'finetune_config_id': default_config.id if default_config else 1,
+                'finetune_config_name': default_config.method_name if default_config else None
+            }
+
+        # 当前微调任务的配置
+        current_config = {
+            'finetune_config_id': finetune_task.finetune_config_id,
+            'finetune_config_name': finetune_task.finetune_config.method_name if finetune_task.finetune_config else None
+        }
+
+        return jsonify({
+            'finetune_task': finetune_task.to_dict(),
+            'suggested_config': suggested_config,
+            'current_config': current_config
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': f'获取微调配置失败: {str(e)}'}), 500
+
+@task_bp.route('/finetune/<int:finetune_id>/config', methods=['PUT'])
+@jwt_required()
+def update_finetune_config(finetune_id):
+    """更新微调任务配置(仅限 pending 状态)"""
+    try:
+        current_user_id = get_jwt_identity()
+
+        finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first()
+        if not finetune_task:
+            return jsonify({'error': '微调任务不存在或无权限'}), 404
+
+        if finetune_task.status != 'pending':
+            return jsonify({'error': '只能修改待处理状态的微调任务配置'}), 400
+
+        data = request.get_json()
+        finetune_config_id = data.get('finetune_config_id')
+
+        if not finetune_config_id:
+            return jsonify({'error': '请提供微调方法ID'}), 400
+
+        # 验证微调配置是否存在
+        finetune_config = FinetuneConfig.query.get(finetune_config_id)
+        if not finetune_config:
+            return jsonify({'error': '微调配置不存在'}), 404
+
+        finetune_task.finetune_config_id = finetune_config_id
+        db.session.commit()
+
+        return jsonify({
+            'message': '微调配置更新成功',
+            'finetune_task': finetune_task.to_dict()
+        }), 200
+
+    except Exception as e:
+        db.session.rollback()
+        return jsonify({'error': f'更新微调配置失败: {str(e)}'}), 500
+
+@task_bp.route('/finetune/<int:finetune_id>/start', methods=['POST'])
+@jwt_required()
+def start_finetune(finetune_id):
+    """启动微调任务"""
+    try:
+        current_user_id = get_jwt_identity()
+
+        finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first()
+        if not finetune_task:
+            return jsonify({'error': '微调任务不存在或无权限'}), 404
+
+        # 检查扰动任务是否已完成
+        if finetune_task.batch.status != 'completed':
+            return jsonify({'error': '扰动任务尚未完成,无法开始微调'}), 400
+
+        # 检查是否已设置微调配置
+        if not finetune_task.finetune_config_id:
+            return jsonify({'error': '请先设置微调方法'}), 400
+
+        # 检查状态
+        if finetune_task.status not in ['pending', 'failed']:
+            return jsonify({'error': f'微调任务状态为 {finetune_task.status},无法启动'}), 400
+
+        # 启动微调任务
+        job_ids = TaskService.start_finetune_task(finetune_task)
+
+        if job_ids:
+            return jsonify({
+                'message': '微调任务已启动',
+                'finetune_task_id': finetune_id,
+                'job_ids': job_ids
+            }), 200
+        else:
+            return jsonify({'error': '微调任务启动失败'}), 500
+
+    except Exception as e:
+        return jsonify({'error': f'启动微调任务失败: {str(e)}'}), 500
+
+@task_bp.route('/finetune/<int:finetune_id>/status', methods=['GET'])
+@jwt_required()
+def get_finetune_task_status(finetune_id):
+    """获取微调任务状态"""
+    try:
+        current_user_id = get_jwt_identity()
+
+        finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first()
+        if not finetune_task:
+            return jsonify({'error': '微调任务不存在或无权限'}), 404
+
+        # 获取详细状态
+        status_info = TaskService.get_finetune_task_status(finetune_id)
+
+        return jsonify({
+            'finetune_task_id': finetune_id,
+            'status': finetune_task.status,
+            'finetune_config': finetune_task.finetune_config.to_dict() if finetune_task.finetune_config else None,
+            'details': status_info,
+            'error_message': finetune_task.error_message
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': f'获取微调任务状态失败: {str(e)}'}), 500
+
+@task_bp.route('/finetune/<int:finetune_id>', methods=['DELETE'])
+@jwt_required()
+def delete_finetune_task(finetune_id):
+    """删除微调任务(仅限 pending 或 failed 状态)"""
+    try:
+        current_user_id = get_jwt_identity()
+
+        finetune_task = FinetuneBatch.query.filter_by(id=finetune_id, user_id=current_user_id).first()
+        if not finetune_task:
+            return jsonify({'error': '微调任务不存在或无权限'}), 404
+
+        if finetune_task.status not in ['pending', 'failed']:
+            return jsonify({'error': '只能删除待处理或失败状态的微调任务'}), 400
+
+        db.session.delete(finetune_task)
+        db.session.commit()
+
+        return jsonify({'message': '微调任务删除成功'}), 200
+
+    except Exception as e:
+        db.session.rollback()
+        return jsonify({'error': f'删除微调任务失败: {str(e)}'}), 500
diff --git a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py
index 5f7f8fd..eabbe24 100644
--- a/src/backend/app/database/__init__.py
+++ b/src/backend/app/database/__init__.py
@@ -84,8 +84,17 @@ class FinetuneConfig(db.Model):
     description = db.Column(db.Text, nullable=False)
 
     # 关系
-    batches = db.relationship('Batch', backref='finetune_config', lazy='dynamic')
+    finetune_tasks = db.relationship('FinetuneBatch', backref='finetune_config', lazy='dynamic')
     user_configs = db.relationship('UserConfig', backref='preferred_finetune_config', lazy='dynamic')
+
+    def to_dict(self):
+        """转换为字典"""
+        return {
+            'id': self.id,
+            'method_code': self.method_code,
+            'method_name': self.method_name,
+            'description': self.description
+        }
 
 class Batch(db.Model):
     """加噪批次表"""
@@ -100,10 +109,6 @@ class Batch(db.Model):
                                        nullable=False, default=1)
     preferred_epsilon = db.Column(db.Numeric(5, 2), nullable=False, default=8.0)
 
-    # 评估配置
-    finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id'),
-                                   nullable=False, default=1)
-
     # 净化配置
     use_strong_protection = db.Column(db.Boolean, nullable=False, default=False)
 
@@ -130,8 +135,7 @@ class Batch(db.Model):
             'started_at': self.started_at.isoformat() if self.started_at else None,
             'completed_at': self.completed_at.isoformat() if self.completed_at else None,
             'error_message': self.error_message,
-            'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None,
-            'finetune_config': self.finetune_config.method_name if self.finetune_config else None
+            'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None
         }
 
 class Image(db.Model):
     """图片表"""
@@ -207,6 +211,47 @@ class EvaluationResult(db.Model):
             'evaluated_at': self.evaluated_at.isoformat() if self.evaluated_at else None
         }
 
+class FinetuneBatch(db.Model):
+    """微调任务表"""
+    __tablename__ = 'finetune_batch'
+
+    id = db.Column(db.BigInteger, primary_key=True)
+    batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id'), nullable=False)
+    user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False)
+    finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id'))
+
+    # 任务状态
+    status = db.Column(db.Enum('pending', 'queued', 'processing', 'completed', 'failed'), default='pending')
+    created_at = db.Column(db.DateTime, default=datetime.utcnow)
+    started_at = db.Column(db.DateTime)
+    completed_at = db.Column(db.DateTime)
+    error_message = db.Column(db.Text)
+
+    # 任务ID(用于RQ任务追踪)
+    original_job_id = db.Column(db.String(255))  # 原始图片微调任务ID
+    perturbed_job_id = db.Column(db.String(255))  # 扰动图片微调任务ID
+
+    # 关系
+    batch = db.relationship('Batch', backref='finetune_tasks', lazy=True)
+    user = db.relationship('User',
backref='finetune_tasks', lazy=True) + + def to_dict(self): + """转换为字典""" + return { + 'id': self.id, + 'batch_id': self.batch_id, + 'user_id': self.user_id, + 'finetune_config_id': self.finetune_config_id, + 'finetune_config': self.finetune_config.method_name if self.finetune_config else None, + 'status': self.status, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'started_at': self.started_at.isoformat() if self.started_at else None, + 'completed_at': self.completed_at.isoformat() if self.completed_at else None, + 'error_message': self.error_message, + 'original_job_id': self.original_job_id, + 'perturbed_job_id': self.perturbed_job_id + } + class UserConfig(db.Model): """用户配置表""" __tablename__ = 'user_configs' diff --git a/src/backend/app/services/image_service.py b/src/backend/app/services/image_service.py index b8857e5..5c64d9e 100644 --- a/src/backend/app/services/image_service.py +++ b/src/backend/app/services/image_service.py @@ -6,6 +6,8 @@ import os import uuid import zipfile +import fcntl +import time from werkzeug.utils import secure_filename from flask import current_app from PIL import Image as PILImage @@ -14,60 +16,154 @@ from app.database import Image from app.utils.file_utils import allowed_file class ImageService: - """图像处理服务""" - @staticmethod - def save_image(file, batch_id, user_id, image_type_id): - """保存单张图片""" + def save_to_uploads(file, batch_id, user_id): + """ + 上传图片到uploads临时目录,返回临时文件路径和原始文件名。 + """ + project_root = os.path.dirname(current_app.root_path) + upload_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id)) + os.makedirs(upload_dir, exist_ok=True) + orig_ext = os.path.splitext(file.filename)[1].lower() + temp_name = f"{uuid.uuid4().hex}{orig_ext}" + temp_path = os.path.join(upload_dir, temp_name) + file.save(temp_path) + return temp_path, file.filename + + @staticmethod + def preprocess_image(temp_path, original_filename, batch_id, user_id, image_type_id, resolution=512, target_format='png'): + """ + 对图片进行中心裁剪、缩放、格式转换、重命名,保存到static/originals,返回数据库对象。 + 原图命名格式: 0000.png, 0001.png, ..., 9999.png + 使用数据库事务和重试机制确保并发安全 + """ + final_path = None + max_retries = 50 + try: - if not file or not allowed_file(file.filename): - return {'success': False, 'error': '不支持的文件格式'} - - # 生成唯一文件名 - file_extension = os.path.splitext(file.filename)[1].lower() - stored_filename = f"{uuid.uuid4().hex}{file_extension}" - - # 临时保存到上传目录 + img = PILImage.open(temp_path).convert("RGB") + width, height = img.size + min_dim = min(width, height) + left = (width - min_dim) // 2 + top = (height - min_dim) // 2 + right = left + min_dim + bottom = top + min_dim + img = img.crop((left, top, right, bottom)) + img = img.resize((resolution, resolution), resample=PILImage.Resampling.LANCZOS) + project_root = os.path.dirname(current_app.root_path) - temp_dir = os.path.join(project_root, current_app.config['UPLOAD_FOLDER'], str(user_id), str(batch_id)) - os.makedirs(temp_dir, exist_ok=True) - temp_path = os.path.join(temp_dir, stored_filename) - file.save(temp_path) - - # 移动到对应的静态目录 static_dir = os.path.join(project_root, current_app.config['ORIGINAL_IMAGES_FOLDER'], str(user_id), str(batch_id)) os.makedirs(static_dir, exist_ok=True) - file_path = os.path.join(static_dir, stored_filename) - # 移动文件到最终位置 - import shutil - shutil.move(temp_path, file_path) + from app.database import ImageType + original_type = ImageType.query.filter_by(type_code='original').first() + target_image_type_id = original_type.id if original_type 
else image_type_id - # 获取图片尺寸 - try: - with PILImage.open(file_path) as img: - width, height = img.size - except: - width, height = None, None + # 首次查询最大序号 + max_seq_result = db.session.execute( + db.text(""" + SELECT COALESCE(MAX(CAST(SUBSTRING_INDEX(stored_filename, '.', 1) AS UNSIGNED)), -1) as max_seq + FROM images + WHERE batch_id = :batch_id + AND image_type_id = :image_type_id + AND stored_filename REGEXP '^[0-9]{4}\\.' + """), + {'batch_id': batch_id, 'image_type_id': target_image_type_id} + ).fetchone() - # 创建数据库记录 - image = Image( - user_id=user_id, - batch_id=batch_id, - original_filename=file.filename, - stored_filename=stored_filename, - file_path=file_path, - file_size=os.path.getsize(file_path), - image_type_id=image_type_id, - width=width, - height=height - ) - - db.session.add(image) - db.session.commit() + # 强制类型转换,确保安全 + try: + base_sequence = int(max_seq_result[0]) if max_seq_result[0] is not None else -1 + except Exception: + base_sequence = -1 + base_sequence += 1 + + # 重试机制:从base_sequence开始尝试连续的序号 + for attempt in range(max_retries): + sequence_number = int(base_sequence) + int(attempt) + fmt_str = str(target_format).lower() if target_format else 'png' + new_name = f"{sequence_number:04d}.{fmt_str}" + final_path = os.path.join(static_dir, new_name) + + try: + # 检查数据库中是否已存在此文件名 + existing = Image.query.filter_by( + batch_id=batch_id, + stored_filename=new_name + ).first() + + if existing: + # 已存在,尝试下一个序号 + continue + + # 保存图片文件 + if target_format.lower() in ['jpg', 'jpeg']: + img.save(final_path, format='JPEG', quality=95) + else: + img.save(final_path, format=target_format.upper()) + + # 创建数据库记录 + image = Image( + user_id=user_id, + batch_id=batch_id, + original_filename=original_filename, + stored_filename=new_name, + file_path=final_path, + file_size=os.path.getsize(final_path), + image_type_id=image_type_id, + width=img.width, + height=img.height + ) + db.session.add(image) + db.session.commit() + + # 删除临时文件 + if os.path.exists(temp_path): + os.remove(temp_path) + + return {'success': True, 'image': image} + + except Exception as e: + db.session.rollback() + error_msg = str(e) + + # 如果是唯一性冲突,清理文件并尝试下一个序号 + if 'Duplicate entry' in error_msg or '1062' in error_msg: + if final_path and os.path.exists(final_path): + try: + os.remove(final_path) + except: + pass + # 继续循环尝试下一个序号 + time.sleep(0.005) + continue + else: + # 其他错误直接抛出 + raise - return {'success': True, 'image': image} + # 所有尝试都失败 + raise Exception(f"无法生成唯一文件名,已尝试序号 {base_sequence} 到 {base_sequence + max_retries - 1}") + except Exception as e: + db.session.rollback() + # 清理可能已保存的文件 + if final_path and os.path.exists(final_path): + try: + os.remove(final_path) + except: + pass + return {'success': False, 'error': f'图片预处理失败: {str(e)}'} + + """图像处理服务""" + + @staticmethod + def save_image(file, batch_id, user_id, image_type_id, resolution=512, target_format='png'): + """保存单张图片,自动上传到uploads并预处理""" + try: + if not file or not allowed_file(file.filename): + return {'success': False, 'error': '不支持的文件格式'} + temp_path, orig_name = ImageService.save_to_uploads(file, batch_id, user_id) + return ImageService.preprocess_image(temp_path, orig_name, batch_id, user_id, image_type_id, resolution, target_format) except Exception as e: db.session.rollback() return {'success': False, 'error': f'保存图片失败: {str(e)}'} diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index aa9ab39..60ad990 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py 
@@ -11,7 +11,7 @@ from redis import Redis from rq import Queue from rq.job import Job from app import db -from app.database import Batch, Image, EvaluationResult, ImageType +from app.database import Batch, Image, EvaluationResult, ImageType, FinetuneBatch from config.algorithm_config import AlgorithmConfig class TaskService: @@ -91,7 +91,7 @@ class TaskService: run_perturbation_task, batch_id=batch.id, algorithm_code=batch.perturbation_config.method_code, - epsilon=float(batch.preferred_epsilon), + epsilon=int(batch.preferred_epsilon), use_strong_protection=batch.use_strong_protection, input_dir=input_dir, output_dir=output_dir, @@ -356,20 +356,23 @@ class TaskService: return 0 @staticmethod - def start_finetune_and_evaluation(batch): + def start_finetune_task(finetune_task): """ - 启动微调和评估任务 - - 此方法在扰动任务完成后调用,分别使用原始图片和扰动图片微调模型, - 然后对比生成效果 + 启动微调任务(使用 FinetuneBatch) Args: - batch: Batch对象 + finetune_task: FinetuneBatch对象 Returns: 包含两个job_id的字典 """ try: + # 获取关联的扰动任务 + batch = finetune_task.batch + if not batch: + print(f"FinetuneBatch {finetune_task.id} 没有关联的扰动任务") + return None + # 检查是否有扰动图片 perturbed_images = Image.query.filter_by( batch_id=batch.id @@ -379,10 +382,13 @@ class TaskService: if not perturbed_images: print(f"Batch {batch.id} 没有扰动图片,无法启动微调任务") + finetune_task.status = 'failed' + finetune_task.error_message = '没有找到扰动图片' + db.session.commit() return None project_root = os.path.dirname(current_app.root_path) - finetune_method = batch.finetune_config.method_code + finetune_method = finetune_task.finetune_config.method_code queue = TaskService._get_queue() from app.workers.finetune_worker import run_finetune_task @@ -426,6 +432,7 @@ class TaskService: # 1. 用原始图片微调模型 job_original = queue.enqueue( run_finetune_task, + finetune_batch_id=finetune_task.id, batch_id=batch.id, finetune_method=finetune_method, train_images_dir=original_dir, @@ -435,12 +442,13 @@ class TaskService: is_perturbed=False, custom_params=None, job_timeout=AlgorithmConfig.TASK_TIMEOUT, - job_id=f"finetune_original_{batch.id}" + job_id=f"finetune_original_{finetune_task.id}" ) # 2. 
用扰动图片微调模型(依赖于原始图片微调完成) job_perturbed = queue.enqueue( run_finetune_task, + finetune_batch_id=finetune_task.id, batch_id=batch.id, finetune_method=finetune_method, train_images_dir=perturbed_dir, @@ -450,10 +458,17 @@ class TaskService: is_perturbed=True, custom_params=None, job_timeout=AlgorithmConfig.TASK_TIMEOUT, - job_id=f"finetune_perturbed_{batch.id}", - depends_on=job_original # 依赖关系 + job_id=f"finetune_perturbed_{finetune_task.id}", + depends_on=job_original ) + # 更新微调任务状态 + finetune_task.status = 'queued' + finetune_task.original_job_id = job_original.id + finetune_task.perturbed_job_id = job_perturbed.id + finetune_task.started_at = datetime.utcnow() + db.session.commit() + return { 'original_job_id': job_original.id, 'perturbed_job_id': job_perturbed.id @@ -461,8 +476,85 @@ class TaskService: except Exception as e: print(f"启动微调任务时出错: {str(e)}") + finetune_task.status = 'failed' + finetune_task.error_message = str(e) + finetune_task.completed_at = datetime.utcnow() + db.session.commit() return None + @staticmethod + def get_finetune_task_status(finetune_id): + """ + 获取微调任务状态(使用 FinetuneBatch) + 同时会检查并更新任务状态 + + Args: + finetune_id: 微调任务ID + + Returns: + 微调任务状态信息 + """ + try: + finetune_task = FinetuneBatch.query.get(finetune_id) + if not finetune_task: + return {'status': 'not_found'} + + # 如果任务不是最终状态,检查并更新状态 + if finetune_task.status not in ['completed', 'failed']: + from app.workers.finetune_worker import _check_and_update_finetune_status + _check_and_update_finetune_status(finetune_task) + # 刷新对象以获取最新状态 + db.session.refresh(finetune_task) + + # 如果任务已完成或失败,直接返回数据库状态 + if finetune_task.status in ['completed', 'failed']: + return { + 'status': finetune_task.status, + 'error': finetune_task.error_message if finetune_task.status == 'failed' else None, + 'started_at': finetune_task.started_at.isoformat() if finetune_task.started_at else None, + 'completed_at': finetune_task.completed_at.isoformat() if finetune_task.completed_at else None + } + + # 从RQ获取任务状态 + redis_conn = TaskService._get_redis_connection() + + original_job_status = 'not_found' + perturbed_job_status = 'not_found' + + try: + if finetune_task.original_job_id: + original_job = Job.fetch(finetune_task.original_job_id, connection=redis_conn) + original_job_status = original_job.get_status() + except: + pass + + try: + if finetune_task.perturbed_job_id: + perturbed_job = Job.fetch(finetune_task.perturbed_job_id, connection=redis_conn) + perturbed_job_status = perturbed_job.get_status() + except: + pass + + # 映射状态 + status_map = { + 'queued': 'queued', + 'started': 'processing', + 'finished': 'completed', + 'failed': 'failed', + 'not_found': 'not_started' + } + + return { + 'status': finetune_task.status, + 'original_finetune': status_map.get(original_job_status, 'unknown'), + 'perturbed_finetune': status_map.get(perturbed_job_status, 'unknown'), + 'started_at': finetune_task.started_at.isoformat() if finetune_task.started_at else None + } + + except Exception as e: + print(f"获取微调任务状态时出错: {str(e)}") + return {'status': 'error', 'error': str(e)} + @staticmethod def generate_final_evaluations(batch_id): """ @@ -531,4 +623,5 @@ class TaskService: except Exception as e: print(f"生成最终评估时出错: {str(e)}") - db.session.rollback() \ No newline at end of file + db.session.rollback() + diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index 40eb747..9a382ea 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -15,13 +15,95 @@ 
logging.basicConfig( logger = logging.getLogger(__name__) -def run_finetune_task(batch_id, finetune_method, train_images_dir, output_model_dir, +def _check_and_update_finetune_status(finetune_task): + """ + 检查微调任务状态并更新 + 当原始和扰动图片的微调都完成时,更新任务状态为completed + + Args: + finetune_task: FinetuneBatch对象 + """ + from app import db + from rq.job import Job + from redis import Redis + from config.algorithm_config import AlgorithmConfig + + try: + # 刷新数据库对象,确保获取最新状态 + db.session.refresh(finetune_task) + + # 如果状态已经是completed或failed,不再检查 + if finetune_task.status in ['completed', 'failed']: + return + + redis_conn = Redis.from_url(AlgorithmConfig.REDIS_URL) + + original_job_done = False + perturbed_job_done = False + has_original_job = False + has_perturbed_job = False + + # 检查原始图片微调任务 + if finetune_task.original_job_id: + has_original_job = True + try: + original_job = Job.fetch(finetune_task.original_job_id, connection=redis_conn) + status = original_job.get_status() + logger.info(f"Original job {finetune_task.original_job_id} status: {status}") + if status == 'finished': + original_job_done = True + elif status == 'failed': + # 如果原始任务失败,整个微调任务标记为失败 + finetune_task.status = 'failed' + finetune_task.error_message = f"Original finetune job failed: {original_job.exc_info}" + finetune_task.completed_at = datetime.utcnow() + db.session.commit() + logger.error(f"FinetuneBatch {finetune_task.id} failed: original job failed") + return + except Exception as e: + logger.error(f"Error checking original job: {str(e)}") + + # 检查扰动图片微调任务 + if finetune_task.perturbed_job_id: + has_perturbed_job = True + try: + perturbed_job = Job.fetch(finetune_task.perturbed_job_id, connection=redis_conn) + status = perturbed_job.get_status() + logger.info(f"Perturbed job {finetune_task.perturbed_job_id} status: {status}") + if status == 'finished': + perturbed_job_done = True + elif status == 'failed': + # 如果扰动任务失败,整个微调任务标记为失败 + finetune_task.status = 'failed' + finetune_task.error_message = f"Perturbed finetune job failed: {perturbed_job.exc_info}" + finetune_task.completed_at = datetime.utcnow() + db.session.commit() + logger.error(f"FinetuneBatch {finetune_task.id} failed: perturbed job failed") + return + except Exception as e: + logger.error(f"Error checking perturbed job: {str(e)}") + + # 如果两个任务都完成,更新状态为completed + if has_original_job and has_perturbed_job and original_job_done and perturbed_job_done: + finetune_task.status = 'completed' + finetune_task.completed_at = datetime.utcnow() + db.session.commit() + logger.info(f"FinetuneBatch {finetune_task.id} completed - both jobs finished") + else: + logger.info(f"FinetuneBatch {finetune_task.id} not all jobs finished yet: original={original_job_done}, perturbed={perturbed_job_done}") + + except Exception as e: + logger.error(f"Error checking finetune status: {str(e)}", exc_info=True) + + +def run_finetune_task(finetune_batch_id, batch_id, finetune_method, train_images_dir, output_model_dir, class_dir, inference_prompts, is_perturbed=False, custom_params=None): """ 执行微调任务 Args: - batch_id: 任务批次ID + finetune_batch_id: 微调任务ID + batch_id: 扰动任务批次ID finetune_method: 微调方法 (dreambooth, lora) train_images_dir: 训练图片目录(原始或扰动) output_model_dir: 模型输出目录 @@ -35,17 +117,26 @@ def run_finetune_task(batch_id, finetune_method, train_images_dir, output_model_ """ from config.algorithm_config import AlgorithmConfig from app import create_app, db - from app.database import Batch, Image, ImageType + from app.database import FinetuneBatch, Batch, Image, ImageType app = create_app() with 
app.app_context(): try: + finetune_task = FinetuneBatch.query.get(finetune_batch_id) + if not finetune_task: + raise ValueError(f"FinetuneBatch {finetune_batch_id} not found") + batch = Batch.query.get(batch_id) if not batch: raise ValueError(f"Batch {batch_id} not found") - logger.info(f"Starting finetune task for batch {batch_id}") + # 更新微调任务状态为处理中 + if finetune_task.status == 'queued': + finetune_task.status = 'processing' + db.session.commit() + + logger.info(f"Starting finetune task for FinetuneBatch {finetune_batch_id}, Batch {batch_id}") logger.info(f"Method: {finetune_method}, Perturbed: {is_perturbed}") # 确保目录存在 @@ -71,11 +162,20 @@ def run_finetune_task(batch_id, finetune_method, train_images_dir, output_model_ # 保存生成的图片到数据库 _save_generated_images(batch_id, output_model_dir, is_perturbed) - logger.info(f"Finetune task completed for batch {batch_id}") + # 检查两个任务是否都已完成 + _check_and_update_finetune_status(finetune_task) + + logger.info(f"Finetune task completed for FinetuneBatch {finetune_batch_id}") return result except Exception as e: - logger.error(f"Finetune task failed for batch {batch_id}: {str(e)}", exc_info=True) + logger.error(f"Finetune task failed for FinetuneBatch {finetune_batch_id}: {str(e)}", exc_info=True) + # 更新微调任务状态为失败 + if finetune_task: + finetune_task.status = 'failed' + finetune_task.error_message = str(e) + finetune_task.completed_at = datetime.utcnow() + db.session.commit() raise @@ -103,9 +203,12 @@ def _run_real_finetune(finetune_method, batch_id, train_images_dir, output_model f"--instance_data_dir={train_images_dir}", f"--output_dir={output_model_dir}", f"--class_data_dir={class_dir}", - f"--inference_prompts={inference_prompts}", ] + # 添加is_perturbed标志 + if is_perturbed: + cmd_args.append("--is_perturbed") + # 添加其他参数 for key, value in params.items(): if isinstance(value, bool): @@ -116,7 +219,7 @@ def _run_real_finetune(finetune_method, batch_id, train_images_dir, output_model # 构建完整命令 cmd = [ - 'conda', 'run', '-n', conda_env, '--no-capture-output', + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', 'accelerate', 'launch', script_path ] + cmd_args @@ -213,6 +316,10 @@ def _run_virtual_finetune(finetune_method, batch_id, train_images_dir, output_mo f"--class_data_dir=/tmp/class_placeholder", ] + # 添加is_perturbed标志 + if is_perturbed: + cmd_args.append("--is_perturbed") + # 添加其他默认参数 for key, value in default_params.items(): if key == 'pretrained_model_name_or_path': @@ -225,7 +332,7 @@ def _run_virtual_finetune(finetune_method, batch_id, train_images_dir, output_mo # 使用conda run执行虚拟脚本 cmd = [ - 'conda', 'run', '-n', conda_env, '--no-capture-output', + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', 'python', script_path ] + cmd_args @@ -315,35 +422,33 @@ def _save_generated_images(batch_id, output_model_dir, is_perturbed): logger.info(f"Found {len(image_files)} generated images to save") # 保存到数据库 + saved_count = 0 for image_path in image_files: try: from PIL import Image as PILImage - with PILImage.open(image_path) as img: - width, height = img.size - - # 查找对应的父图片(原始或扰动) filename = os.path.basename(image_path) - # 去掉"generated_"前缀 - original_filename = filename.replace('generated_', '') - - # 尝试找到父图片 - if is_perturbed: - parent_type = ImageType.query.filter_by(type_code='perturbed').first() - else: - parent_type = ImageType.query.filter_by(type_code='original').first() - parent_image = Image.query.filter_by( + # 检查是否已经保存过(使用filename作为stored_filename) + existing = Image.query.filter_by( 
batch_id=batch_id, - image_type_id=parent_type.id - ).filter(Image.original_filename.like(f"%{original_filename.split('_')[0]}%")).first() + stored_filename=filename + ).first() + + if existing: + logger.info(f"Image already exists: {filename}") + continue + + with PILImage.open(image_path) as img: + width, height = img.size - # 创建图片记录 + # 生成图片不设置父图片关系(多对多关系,无法确定具体父图片) + # 创建图片记录(直接使用filename,算法已经生成了正确格式) generated_image = Image( user_id=batch.user_id, batch_id=batch_id, - father_id=parent_image.id if parent_image else None, + father_id=None, # 微调生成图片无特定父图片 original_filename=filename, - stored_filename=os.path.basename(image_path), + stored_filename=filename, # 算法输出已经是正确格式 file_path=image_path, file_size=os.path.getsize(image_path), image_type_id=image_type.id, @@ -352,13 +457,14 @@ def _save_generated_images(batch_id, output_model_dir, is_perturbed): ) db.session.add(generated_image) + saved_count += 1 logger.info(f"Saved generated image: {filename}") except Exception as e: logger.error(f"Failed to save {image_path}: {str(e)}") db.session.commit() - logger.info(f"Successfully saved {len(image_files)} generated images") + logger.info(f"Successfully saved {saved_count} generated images to database") except Exception as e: logger.error(f"Error saving generated images: {str(e)}") diff --git a/src/backend/app/workers/perturbation_worker.py b/src/backend/app/workers/perturbation_worker.py index 2eb24d2..774f80c 100644 --- a/src/backend/app/workers/perturbation_worker.py +++ b/src/backend/app/workers/perturbation_worker.py @@ -85,6 +85,9 @@ def run_perturbation_task(batch_id, algorithm_code, epsilon, use_strong_protecti batch.completed_at = datetime.utcnow() db.session.commit() + # 保存扰动图片到数据库 + _save_perturbed_images(batch_id, output_dir) + logger.info(f"Task completed successfully for batch {batch_id}") return result @@ -117,14 +120,37 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, batch_id, # 合并自定义参数 params = {**default_params, **(custom_params or {})} - # 构建命令行参数 - cmd_args = [ - f"--instance_data_dir_for_train={input_dir}", - f"--instance_data_dir_for_adversarial={input_dir}", - f"--output_dir={output_dir}", - f"--class_data_dir={class_dir}", - f"--pgd_eps={epsilon}", - ] + cmd_args = [] + if algorithm_code == 'aspl': + cmd_args.extend([ + f"--instance_data_dir_for_train={input_dir}", + f"--instance_data_dir_for_adversarial={input_dir}", + f"--output_dir={output_dir}", + f"--class_data_dir={class_dir}", + f"--pgd_eps={str(epsilon)}", + ]) + elif algorithm_code == 'simac': + cmd_args.extend([ + f"--instance_data_dir_for_train={input_dir}", + f"--instance_data_dir_for_adversarial={input_dir}", + f"--output_dir={output_dir}", + f"--class_data_dir={class_dir}", + f"--pgd_eps={str(epsilon)}", + ]) + elif algorithm_code == 'caat': + cmd_args.extend([ + f"--instance_data_dir={input_dir}", + f"--output_dir={output_dir}", + f"--eps={str(epsilon)}", + ]) + elif algorithm_code == 'pid': + cmd_args.extend([ + f"--instance_data_dir={input_dir}", + f"--output_dir={output_dir}", + f"--eps={str(epsilon)}", + ]) + else: + raise ValueError(f"Unsupported algorithm code: {algorithm_code}") # 添加其他参数 for key, value in params.items(): @@ -137,7 +163,7 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, batch_id, # 构建完整命令 # 使用conda run避免环境嵌套问题 cmd = [ - 'conda', 'run', '-n', conda_env, '--no-capture-output', + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', 'accelerate', 'launch', script_path ] + cmd_args @@ -167,6 +193,9 @@ def 
_run_real_algorithm(script_path, conda_env, algorithm_code, batch_id, process.wait() + logger.info(f"output_dir: {output_dir}") + logger.info(f"log_file: {log_file}") + if process.returncode != 0: raise RuntimeError(f"Algorithm execution failed with code {process.returncode}. Check log: {log_file}") @@ -197,7 +226,7 @@ def _run_virtual_algorithm(algorithm_code, batch_id, epsilon, use_strong_protect logger.info(f"Running virtual algorithm: {algorithm_code}") # 获取算法配置 - algo_config = AlgorithmConfig.get_algorithm_config(algorithm_code) + algo_config = AlgorithmConfig.PERTURBATION_SCRIPTS.get(algorithm_code) if not algo_config: raise ValueError(f"Algorithm {algorithm_code} not configured") @@ -242,7 +271,7 @@ def _run_virtual_algorithm(algorithm_code, batch_id, epsilon, use_strong_protect # 使用conda run执行虚拟脚本 cmd = [ - 'conda', 'run', '-n', conda_env, '--no-capture-output', + '/root/miniconda3/bin/conda', 'run', '-n', conda_env, '--no-capture-output', 'python', script_path ] + cmd_args @@ -294,3 +323,100 @@ def _run_virtual_algorithm(algorithm_code, batch_id, epsilon, use_strong_protect 'processed_files': processed_files, 'log_file': log_file } + + +def _save_perturbed_images(batch_id, output_dir): + """保存扰动图片到数据库""" + from app import db + from app.database import Batch, Image, ImageType + import glob + from PIL import Image as PILImage + + try: + batch = Batch.query.get(batch_id) + if not batch: + logger.error(f"Batch {batch_id} not found") + return + + # 获取扰动图片类型 + perturbed_type = ImageType.query.filter_by(type_code='perturbed').first() + if not perturbed_type: + logger.error("Perturbed image type not found") + return + + # 获取原始图片列表 + original_type = ImageType.query.filter_by(type_code='original').first() + original_images = Image.query.filter_by( + batch_id=batch_id, + image_type_id=original_type.id + ).all() + + # 创建原图映射字典: stored_filename -> Image对象 + original_map = {img.stored_filename: img for img in original_images} + + # 查找输出目录中的扰动图片 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.gif', '*.bmp'] + perturbed_files = [] + for ext in image_extensions: + perturbed_files.extend(glob.glob(os.path.join(output_dir, ext))) + perturbed_files.extend(glob.glob(os.path.join(output_dir, ext.upper()))) + + logger.info(f"Found {len(perturbed_files)} perturbed images to save") + + saved_count = 0 + for perturbed_path in perturbed_files: + try: + filename = os.path.basename(perturbed_path) + + # 扰动图片命名格式: perturbed_{原图名}.ext + # 提取原图名 + parent_image = None + if filename.startswith('perturbed_'): + # 去掉perturbed_前缀,得到原图名 + original_filename = filename[len('perturbed_'):] + # 尝试从映射中查找 + parent_image = original_map.get(original_filename) + if not parent_image: + logger.warning(f"Parent image not found for {filename}, original should be: {original_filename}") + + # 获取图片尺寸 + with PILImage.open(perturbed_path) as img: + width, height = img.size + + # 检查是否已经保存过(使用filename作为stored_filename) + existing = Image.query.filter_by( + batch_id=batch_id, + stored_filename=filename + ).first() + + if existing: + logger.info(f"Image already exists: {filename}") + continue + + # 创建扰动图片记录(直接使用filename,因为算法已经添加了perturbed_前缀) + perturbed_image = Image( + user_id=batch.user_id, + batch_id=batch_id, + father_id=parent_image.id if parent_image else None, + original_filename=filename, + stored_filename=filename, # 算法输出已经是perturbed_格式 + file_path=perturbed_path, + file_size=os.path.getsize(perturbed_path), + image_type_id=perturbed_type.id, + width=width, + height=height + ) + + db.session.add(perturbed_image) + 
saved_count += 1 + logger.info(f"Saved perturbed image: {filename} (parent: {parent_image.stored_filename if parent_image else 'None'})") + + except Exception as e: + logger.error(f"Failed to save {perturbed_path}: {str(e)}") + + db.session.commit() + logger.info(f"Successfully saved {saved_count} perturbed images to database") + + except Exception as e: + logger.error(f"Error saving perturbed images: {str(e)}") + db.session.rollback() diff --git a/src/backend/config/.env b/src/backend/config/.env deleted file mode 100644 index f6d54c1..0000000 --- a/src/backend/config/.env +++ /dev/null @@ -1,16 +0,0 @@ -# MuseGuard 环境变量配置文件 -# 注意:此文件包含敏感信息,不应提交到版本控制系统 - -# 数据库配置 -DB_USER=root -DB_PASSWORD=971817787Lh -DB_HOST=localhost -DB_NAME=db - -# Flask配置 -SECRET_KEY=museguard-secret-key-2024 -JWT_SECRET_KEY=jwt-secret-string - -# 开发模式 -FLASK_ENV=development -FLASK_DEBUG=True \ No newline at end of file diff --git a/src/backend/config/algorithm_config.py b/src/backend/config/algorithm_config.py index 65c4d8c..51cd029 100644 --- a/src/backend/config/algorithm_config.py +++ b/src/backend/config/algorithm_config.py @@ -36,20 +36,20 @@ class AlgorithmConfig: CONDA_ENVS = { 'aspl': os.getenv('CONDA_ENV_ASPL', 'simac'), 'simac': os.getenv('CONDA_ENV_SIMAC', 'simac'), - 'caat': os.getenv('CONDA_ENV_CAAT', 'simac'), - 'pid': os.getenv('CONDA_ENV_PID', 'simac'), - 'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'simac'), - 'lora': os.getenv('CONDA_ENV_LORA', 'simac'), + 'caat': os.getenv('CONDA_ENV_CAAT', 'caat'), + 'pid': os.getenv('CONDA_ENV_PID', 'pid'), + 'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'pid'), + 'lora': os.getenv('CONDA_ENV_LORA', 'pid'), } # 算法脚本配置 - ALGORITHM_SCRIPTS = { + PERTURBATION_SCRIPTS = { 'aspl': { 'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'aspl.py'), 'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'), 'conda_env': CONDA_ENVS['aspl'], 'default_params': { - 'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', + 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', 'enable_xformers_memory_efficient_attention': True, 'instance_prompt': 'a photo of sks person', 'class_prompt': 'a photo of person', @@ -59,13 +59,12 @@ class AlgorithmConfig: 'prior_loss_weight': 1.0, 'resolution': 384, 'train_batch_size': 1, - 'max_train_steps': 50, + 'max_train_steps': 2, 'max_f_train_steps': 3, 'max_adv_train_steps': 6, - 'checkpointing_iterations': 10, + 'checkpointing_iterations': 1, 'learning_rate': 5e-7, 'pgd_alpha': 0.005, - 'pgd_eps': 8, 'seed': 0 } }, @@ -74,16 +73,22 @@ class AlgorithmConfig: 'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'), 'conda_env': CONDA_ENVS['simac'], 'default_params': { - 'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', + 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', 'enable_xformers_memory_efficient_attention': True, 'instance_prompt': 'a photo of sks person', 'class_prompt': 'a photo of person', 'num_class_images': 200, - 'resolution': 512, + 'center_crop': True, + 'with_prior_preservation': True, + 
'prior_loss_weight': 1.0, + 'resolution': 384, 'train_batch_size': 1, - 'max_train_steps': 1000, - 'learning_rate': 5e-6, - 'pgd_eps': 8, + 'max_train_steps': 2, + 'max_f_train_steps': 3, + 'max_adv_train_steps': 6, + 'checkpointing_iterations': 1, + 'learning_rate': 5e-7, + 'pgd_alpha': 0.005, 'seed': 0 } }, @@ -92,17 +97,15 @@ class AlgorithmConfig: 'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'), 'conda_env': CONDA_ENVS['caat'], 'default_params': { - 'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', - 'enable_xformers_memory_efficient_attention': True, - 'instance_prompt': 'a photo of sks person', - 'class_prompt': 'a photo of person', - 'num_class_images': 200, + 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14', + 'instance_prompt': 'a photo of a person', 'resolution': 512, - 'train_batch_size': 2, - 'max_train_steps': 800, 'learning_rate': 1e-5, - 'pgd_eps': 16, - 'seed': 0 + 'lr_warmup_steps': 0, + 'max_train_steps': 10, + 'hflip': True, + 'mixed_precision': 'bf16', + 'alpha': 5e-3 } }, 'pid': { @@ -110,30 +113,24 @@ class AlgorithmConfig: 'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'), 'conda_env': CONDA_ENVS['pid'], 'default_params': { - 'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', - 'enable_xformers_memory_efficient_attention': True, - 'instance_prompt': 'a photo of sks person', - 'class_prompt': 'a photo of person', - 'num_class_images': 200, + 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14', 'resolution': 512, - 'train_batch_size': 1, - 'max_train_steps': 600, - 'learning_rate': 3e-6, - 'pgd_eps': 4, - 'seed': 0 + 'max_train_steps': 10, + 'center_crop': True, + 'attack_type': 'add-log' } } } @classmethod - def get_algorithm_config(cls, algorithm_code): + def get_perturbation_config(cls, algorithm_code): """获取算法配置""" - return cls.ALGORITHM_SCRIPTS.get(algorithm_code, {}) + return cls.PERTURBATION_SCRIPTS.get(algorithm_code, {}) @classmethod def get_script_path(cls, algorithm_code): """获取算法脚本路径""" - config = cls.get_algorithm_config(algorithm_code) + config = cls.get_perturbation_config(algorithm_code) if cls.USE_REAL_ALGORITHMS: return config.get('real_script') else: @@ -142,51 +139,69 @@ class AlgorithmConfig: @classmethod def get_conda_env(cls, algorithm_code): """获取算法的conda环境名称""" - config = cls.get_algorithm_config(algorithm_code) + config = cls.get_perturbation_config(algorithm_code) return config.get('conda_env') @classmethod def get_default_params(cls, algorithm_code): """获取算法默认参数""" - config = cls.get_algorithm_config(algorithm_code) + config = cls.get_perturbation_config(algorithm_code) return config.get('default_params', {}).copy() # ========== 微调算法配置 ========== FINETUNE_SCRIPTS = { 'dreambooth': { - 'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_dreambooth_alone.py'), + 'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_dreambooth_gen.py'), 'virtual_script': None, # 使用虚拟实现在worker中 'conda_env': CONDA_ENVS['dreambooth'], 'default_params': { - 'pretrained_model_name_or_path': 
'../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', - 'enable_xformers_memory_efficient_attention': True, + 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14', + 'with_prior_preservation': True, + 'prior_loss_weight': 1.0, 'instance_prompt': 'a photo of sks person', 'class_prompt': 'a photo of person', - 'num_class_images': 200, 'resolution': 512, 'train_batch_size': 1, - 'num_train_epochs': 1, - 'max_train_steps': 1000, - 'learning_rate': 5e-6, - 'with_prior_preservation': True, - 'prior_loss_weight': 1.0, - 'seed': 0 + 'gradient_accumulation_steps': 1, + 'learning_rate': 1e-4, + 'lr_scheduler': 'constant', + 'lr_warmup_steps': 0, + 'num_class_images': 1, + 'max_train_steps': 1, + 'checkpointing_steps': 1, + 'center_crop': True, + 'mixed_precision': 'bf16', + 'prior_generation_precision': 'bf16', + 'sample_batch_size': 1, + 'validation_prompt': 'a photo of sks person', + 'num_validation_images': 1, + 'validation_steps': 1 } }, 'lora': { - 'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_lora.py'), + 'real_script': os.path.join(ALGORITHMS_DIR, 'finetune', 'train_lora_gen.py'), 'virtual_script': None, 'conda_env': CONDA_ENVS['lora'], 'default_params': { - 'pretrained_model_name_or_path': '../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', - 'enable_xformers_memory_efficient_attention': True, + 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14', + 'with_prior_preservation': True, + 'prior_loss_weight': 1.0, 'instance_prompt': 'a photo of sks person', + 'class_prompt': 'a photo of person', 'resolution': 512, 'train_batch_size': 1, - 'max_train_steps': 800, + 'gradient_accumulation_steps': 1, 'learning_rate': 1e-4, + 'lr_scheduler': 'constant', + 'lr_warmup_steps': 0, + 'num_class_images': 1, + 'max_train_steps': 1, + 'checkpointing_steps': 1, + 'seed': 0, + 'mixed_precision': 'fp16', 'rank': 4, - 'seed': 0 + 'validation_prompt': 'a photo of sks person', + 'num_validation_images': 1 } } } diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 1078878..caf49a7 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -8,7 +8,7 @@ from urllib.parse import quote_plus # 加载环境变量 - 从 config 目录读取 .env 文件 config_dir = os.path.dirname(os.path.abspath(__file__)) -env_path = os.path.join(config_dir, '.env') +env_path = os.path.join(config_dir, 'settings.env') load_dotenv(env_path) class Config: @@ -49,10 +49,10 @@ class Config: ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'} # 图像处理配置 - ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'originals') # 重命名后的原始图片 + ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'original') # 重命名后的原始图片 PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片 MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录 - MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'clean') # 原图的模型生成结果 + MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'original') # 原图的模型生成结果 MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果 HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图 diff --git 
a/src/backend/run.py b/src/backend/run.py index cca949e..ee3f721 100644 --- a/src/backend/run.py +++ b/src/backend/run.py @@ -15,7 +15,7 @@ if __name__ == '__main__': # 开发模式启动 app.run( host='0.0.0.0', - port=5000, + port=6006, debug=True, threaded=True ) \ No newline at end of file diff --git a/src/backend/start.sh b/src/backend/start.sh old mode 100644 new mode 100755 index 7730901..1728d7e --- a/src/backend/start.sh +++ b/src/backend/start.sh @@ -13,7 +13,7 @@ cd "$SCRIPT_DIR" # 激活conda环境 echo "激活 conda 环境: flask" source ~/anaconda3/etc/profile.d/conda.sh -conda activate flask +/root/miniconda3/bin/conda activate flask # 检查conda环境是否激活成功 if [ "$CONDA_DEFAULT_ENV" != "flask" ]; then @@ -24,6 +24,25 @@ fi echo "[成功] conda 环境已激活: $CONDA_DEFAULT_ENV" echo "" +# 检查数据库是否运行(以MySQL为例,可根据实际情况调整) +echo "检查数据库连接..." +if mysqladmin ping -uroot > /dev/null 2>&1; then + echo "[成功] MySQL 连接正常" +else + echo "[警告] MySQL 未运行,正在启动 MySQL..." + sudo systemctl start mysql || sudo service mysql start + sleep 2 + if mysqladmin ping -uroot > /dev/null 2>&1; then + echo "[成功] MySQL 已启动" + else + echo "[错误] 无法启动 MySQL,请手动启动:" + echo " sudo systemctl start mysql" + echo " 或: sudo service mysql start" + exit 1 + fi +fi +echo "" + # 检查Redis是否运行 echo "检查 Redis 连接..." if redis-cli ping > /dev/null 2>&1; then @@ -78,7 +97,7 @@ echo " 启动完成!" echo "========================================" echo "" echo "服务信息:" -echo " - Flask API: http://127.0.0.1:5000" +echo " - Flask API: http://127.0.0.1:6006" echo " - Flask PID: $FLASK_PID" echo " - Worker PID: $WORKER_PID" echo ""
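补充示例:start.sh 在拉起 Flask 与 worker 之前会先探测 MySQL 与 Redis。下面是一个等价思路的 Python 自检草稿,仅作示意,并非补丁内容;其中环境变量名 REDIS_URL、DB_HOST、DB_PORT 为假设,请以实际部署的 .env 约定为准。

```python
# 示意草稿:对应 start.sh 的启动前自检(假设的环境变量名,见上)
import os
import socket

from redis import Redis


def redis_ok() -> bool:
    """对应 start.sh 中的 `redis-cli ping`。"""
    try:
        return Redis.from_url(os.getenv('REDIS_URL', 'redis://localhost:6379/0')).ping()
    except Exception:
        return False


def mysql_ok() -> bool:
    """对应 start.sh 中的 `mysqladmin ping`,此处仅做 TCP 端口探测。"""
    try:
        host = os.getenv('DB_HOST', 'localhost')
        port = int(os.getenv('DB_PORT', '3306'))
        with socket.create_connection((host, port), timeout=2):
            return True
    except OSError:
        return False


if __name__ == '__main__':
    print('Redis:', 'ok' if redis_ok() else 'down')
    print('MySQL:', 'ok' if mysql_ok() else 'down')
```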
diff --git a/src/backend/static/test.html b/src/backend/static/test.html index 719ec40..b02d390 100644 --- a/src/backend/static/test.html +++ b/src/backend/static/test.html @@ -1,1644 +1,565 @@
-[此处原为旧版"MuseGuard API 全功能测试页面"的 1644 行 HTML:含服务器连通性测试,以及 Demo(演示图片、算法信息)、Auth(用户注册、用户登录、修改密码)、Task(创建任务、我的批次列表、文件上传、配置任务、任务管理)、Image(图像查看和下载、图像评估和对比、其他功能)、Admin(用户管理、用户创建和编辑、系统统计)各控制器的测试区块。HTML 标记在补丁提取时丢失,仅余文本片段,整段从略]
+[新版 test.html("基于对抗性扰动的多风格图像生成防护系统 - 测试页面")的 565 行 HTML,含统一的"后端原始返回数据"展示区;HTML 标记同样在提取时丢失,整段从略]
diff --git a/src/backend/static/test0.html b/src/backend/static/test0.html new file mode 100644 index 0000000..719ec40 --- /dev/null +++ b/src/backend/static/test0.html @@ -0,0 +1,1644 @@
+[test0.html 为上面被删除的旧版"MuseGuard API 全功能测试页面"的原样备份(1644 行,内容与删除部分逐行相同),整段从略]
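补充示例:上述测试页面覆盖的微调任务接口(见本补丁 task_controller 部分)可以用下面的 requests 草稿串起来,验证"设置微调方法 → 启动 → 轮询状态"的流程。仅作示意:/api/tasks 前缀与 token 的获取方式均为假设,实际路由前缀以 app/__init__.py 中 task_bp 的注册为准。

```python
# 示意草稿:微调任务接口的端到端调用(前缀与 token 为假设)
import time

import requests

BASE = 'http://127.0.0.1:6006/api/tasks'        # 假设的蓝图前缀
token = '<登录接口返回的 JWT>'
headers = {'Authorization': f'Bearer {token}'}
finetune_id = 1

# 1. 设置微调方法(仅 pending 状态的任务允许修改)
r = requests.put(f'{BASE}/finetune/{finetune_id}/config',
                 json={'finetune_config_id': 1}, headers=headers)
print(r.status_code, r.json())

# 2. 启动微调(要求关联的扰动任务已 completed)
r = requests.post(f'{BASE}/finetune/{finetune_id}/start', headers=headers)
print(r.status_code, r.json())

# 3. 轮询状态,直到 completed / failed
while True:
    info = requests.get(f'{BASE}/finetune/{finetune_id}/status',
                        headers=headers).json()
    print(info.get('status'), info.get('details'))
    if info.get('status') in ('completed', 'failed'):
        break
    time.sleep(10)
```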
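补充示例:TaskService.start_finetune_task 的核心是用 RQ 的 depends_on 把"原图微调"和"扰动图微调"两个作业串联,后者必须等前者 finished 之后才会执行。下面是该入队骨架的独立草稿,仅作示意:Redis 地址与各目录参数均为占位,参数含义以 run_finetune_task 的实际签名为准。

```python
# 示意草稿:依赖式入队(Redis 地址与目录参数为占位)
from redis import Redis
from rq import Queue

from app.workers.finetune_worker import run_finetune_task  # 实际的作业函数

queue = Queue(connection=Redis.from_url('redis://localhost:6379/0'))

job_original = queue.enqueue(
    run_finetune_task,
    finetune_batch_id=1, batch_id=1, finetune_method='lora',
    train_images_dir='...', output_model_dir='...', class_dir='...',
    inference_prompts='a photo of sks person', is_perturbed=False,
    job_id='finetune_original_1',
)
job_perturbed = queue.enqueue(
    run_finetune_task,
    finetune_batch_id=1, batch_id=1, finetune_method='lora',
    train_images_dir='...', output_model_dir='...', class_dir='...',
    inference_prompts='a photo of sks person', is_perturbed=True,
    job_id='finetune_perturbed_1',
    depends_on=job_original,   # 关键:等原图微调 finished 后才调度
)
```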
diff --git a/src/backend/status.sh b/src/backend/status.sh old mode 100644 new mode 100755 index 559c15d..2c128bb --- a/src/backend/status.sh +++ b/src/backend/status.sh @@ -16,7 +16,8 @@ if [ -f logs/flask.pid ]; then FLASK_PID=$(cat logs/flask.pid) if ps -p $FLASK_PID > /dev/null 2>&1; then echo " ✅ 运行中 (PID: $FLASK_PID)" - echo " 📍 URL: http://127.0.0.1:5000" + echo " 📍 URL: http://127.0.0.1:6006" + echo " 📍 测试: http://127.0.0.1:6006/static/test.html" else echo " ❌ 未运行 (PID文件存在但进程不存在)" fi diff --git a/src/backend/stop.sh b/src/backend/stop.sh old mode 100644 new mode 100755 -- 2.34.1
From 773537e524de861e42836c2fcc9806fd0c1d6557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Sun, 23 Nov 2025 17:57:37 +0800 Subject: [PATCH 3/3] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=BE=AE?= =?UTF-8?q?=E8=B0=83=E7=94=9F=E6=88=90=E6=96=87=E4=BB=B6=E5=A4=B9=E6=B8=85?= =?UTF-8?q?=E7=90=86=EF=BC=8C=E4=BF=AE=E6=94=B9=E6=A8=A1=E5=9E=8B=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/.gitignore | 5 ++++- src/backend/README.md | 19 +++++++++++++++++++ src/backend/app/services/task_service.py | 6 ++++-- src/backend/app/workers/finetune_worker.py | 18 ++++++++++++++++++ src/backend/config/algorithm_config.py | 18 ++++++++++++------ src/backend/config/settings.py | 2 +- src/backend/start.sh | 6 +++--- 7 files changed, 61 insertions(+), 13 deletions(-)
diff --git a/src/backend/.gitignore b/src/backend/.gitignore index 781efa2..32b09ca 100644 --- a/src/backend/.gitignore +++ b/src/backend/.gitignore @@ -23,4 +23,7 @@ uploads/ *.pkl *.safetensors *.pt -*.txt \ No newline at end of file +*.txt + +# 模型文件 +hf_models/ \ No newline at end of file
diff --git a/src/backend/README.md b/src/backend/README.md index f0d4969..09cb774 100644 --- a/src/backend/README.md +++ b/src/backend/README.md @@ -357,3 +357,22 @@ flask run https://docs.pingcode.com/baike/2645380
+
+## 开发清单
+
+- 功能流程正确(本地):测试网页、配置正确加载、微调算法执行时机
+- 云端正常调用算法
+- 算法正常执行
+- 云端部署,本地可直接访问
+- API 规范
+- 前端对接
+
+依赖安装:
+
+conda activate flask
+pip install accelerate
+# 或
+conda install -c conda-forge accelerate
+
diff --git a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index 60ad990..078d1bd 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -411,12 +411,14 @@ class TaskService: # 模型输出目录 original_model_dir = os.path.join( - project_root, 'static', 'models', 'original', + project_root, + current_app.config['MODEL_ORIGINAL_FOLDER'], str(batch.user_id), str(batch.id) ) perturbed_model_dir = os.path.join( - project_root, 'static', 'models', 'perturbed', + project_root, + current_app.config['MODEL_PERTURBED_FOLDER'], str(batch.user_id), str(batch.id) )
diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index 9a382ea..20698ee 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -265,6 +265,24 @@ def _run_real_finetune(finetune_method, batch_id, train_images_dir, output_model os.remove(item_path) elif os.path.isdir(item_path): shutil.rmtree(item_path) + + # 清理output_model_dir中的非图片文件 + logger.info(f"Cleaning non-image files in output directory: {output_model_dir}") + if os.path.exists(output_model_dir): + import shutil + image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff'} + for item in 
os.listdir(output_model_dir): + item_path = os.path.join(output_model_dir, item) + # 如果是目录,直接删除 + if os.path.isdir(item_path): + logger.info(f"Removing directory: {item_path}") + shutil.rmtree(item_path) + # 如果是文件,检查是否为图片 + elif os.path.isfile(item_path): + _, ext = os.path.splitext(item.lower()) + if ext not in image_extensions: + logger.info(f"Removing non-image file: {item_path}") + os.remove(item_path) return { 'status': 'success', diff --git a/src/backend/config/algorithm_config.py b/src/backend/config/algorithm_config.py index 51cd029..dd79776 100644 --- a/src/backend/config/algorithm_config.py +++ b/src/backend/config/algorithm_config.py @@ -41,6 +41,12 @@ class AlgorithmConfig: 'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'pid'), 'lora': os.getenv('CONDA_ENV_LORA', 'pid'), } + + # 模型路径配置 + MODELS_DIR = { + 'model1': os.getenv('MODEL_SD21', '/root/autodl-tmp/muse-guard_-backend/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06'), + 'model2': os.getenv('MODEL_SD15', '/root/autodl-tmp/muse-guard_-backend/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14'), + } # 算法脚本配置 PERTURBATION_SCRIPTS = { @@ -49,7 +55,7 @@ class AlgorithmConfig: 'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'), 'conda_env': CONDA_ENVS['aspl'], 'default_params': { - 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', + 'pretrained_model_name_or_path': MODELS_DIR['model1'], 'enable_xformers_memory_efficient_attention': True, 'instance_prompt': 'a photo of sks person', 'class_prompt': 'a photo of person', @@ -73,7 +79,7 @@ class AlgorithmConfig: 'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'), 'conda_env': CONDA_ENVS['simac'], 'default_params': { - 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06', + 'pretrained_model_name_or_path': MODELS_DIR['model1'], 'enable_xformers_memory_efficient_attention': True, 'instance_prompt': 'a photo of sks person', 'class_prompt': 'a photo of person', @@ -97,7 +103,7 @@ class AlgorithmConfig: 'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'), 'conda_env': CONDA_ENVS['caat'], 'default_params': { - 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14', + 'pretrained_model_name_or_path': MODELS_DIR['model2'], 'instance_prompt': 'a photo of a person', 'resolution': 512, 'learning_rate': 1e-5, @@ -113,7 +119,7 @@ class AlgorithmConfig: 'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'), 'conda_env': CONDA_ENVS['pid'], 'default_params': { - 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14', + 'pretrained_model_name_or_path': MODELS_DIR['model2'], 'resolution': 512, 'max_train_steps': 10, 'center_crop': True, @@ -155,7 +161,7 @@ class AlgorithmConfig: 'virtual_script': None, # 使用虚拟实现在worker中 'conda_env': CONDA_ENVS['dreambooth'], 'default_params': { - 'pretrained_model_name_or_path': 
'/root/autodl-tmp/backend/static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14', + 'pretrained_model_name_or_path': MODELS_DIR['model2'], 'with_prior_preservation': True, 'prior_loss_weight': 1.0, 'instance_prompt': 'a photo of sks person', @@ -183,7 +189,7 @@ class AlgorithmConfig: 'virtual_script': None, 'conda_env': CONDA_ENVS['lora'], 'default_params': { - 'pretrained_model_name_or_path': '/root/autodl-tmp/backend/static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14', + 'pretrained_model_name_or_path': MODELS_DIR['model2'], 'with_prior_preservation': True, 'prior_loss_weight': 1.0, 'instance_prompt': 'a photo of sks person', diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index caf49a7..5f4b3a6 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -52,7 +52,7 @@ class Config: ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'original') # 重命名后的原始图片 PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片 MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录 - MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'original') # 原图的模型生成结果 + MODEL_ORIGINAL_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'original') # 原图的模型生成结果 MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果 HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图 diff --git a/src/backend/start.sh b/src/backend/start.sh index 1728d7e..5a09f86 100755 --- a/src/backend/start.sh +++ b/src/backend/start.sh @@ -12,8 +12,8 @@ cd "$SCRIPT_DIR" # 激活conda环境 echo "激活 conda 环境: flask" -source ~/anaconda3/etc/profile.d/conda.sh -/root/miniconda3/bin/conda activate flask +source /root/miniconda3/etc/profile.d/conda.sh +conda activate flask # 检查conda环境是否激活成功 if [ "$CONDA_DEFAULT_ENV" != "flask" ]; then @@ -30,7 +30,7 @@ if mysqladmin ping -uroot > /dev/null 2>&1; then echo "[成功] MySQL 连接正常" else echo "[警告] MySQL 未运行,正在启动 MySQL..." - sudo systemctl start mysql || sudo service mysql start + service mysql start sleep 2 if mysqladmin ping -uroot > /dev/null 2>&1; then echo "[成功] MySQL 已启动" -- 2.34.1
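补充示例:PATCH 3/3 之后,模型路径统一收敛到 AlgorithmConfig.MODELS_DIR,并允许用环境变量 MODEL_SD21 / MODEL_SD15 覆盖;worker 端则以 params = {**default_params, **(custom_params or {})} 的浅合并方式用自定义参数覆盖默认参数。下面的草稿演示这两点,仅作示意:路径为示例;注意 MODELS_DIR 在类定义时求值,环境变量必须在首次 import 该模块之前设置。

```python
# 示意草稿:环境变量覆盖模型路径 + 默认参数浅合并(路径为示例)
import os

# 必须先设环境变量,再 import:MODELS_DIR 在类定义时调用 os.getenv
os.environ['MODEL_SD21'] = '/data/hf_models/stable-diffusion-2-1-base'
os.environ['MODEL_SD15'] = '/data/hf_models/stable-diffusion-v1-5'

from config.algorithm_config import AlgorithmConfig

# get_default_params 返回副本,worker 中再与自定义参数浅合并(后者覆盖前者)
default_params = AlgorithmConfig.get_default_params('pid')
custom_params = {'max_train_steps': 100}
params = {**default_params, **(custom_params or {})}

print(params['pretrained_model_name_or_path'])  # -> /data/hf_models/stable-diffusion-v1-5(pid 用 model2)
print(params['max_train_steps'])                # -> 100,覆盖默认值 10
```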