将lianghao_branch合并到develop #2
Merged
hnu202326010204
merged 3 commits from lianghao_branch into develop 2 months ago
@ -1,9 +1,18 @@
|
||||
__pycache__/
|
||||
|
||||
venv/
|
||||
|
||||
*.png
|
||||
*.jpg
|
||||
*.jpeg
|
||||
|
||||
.env
|
||||
__pycache__/
|
||||
|
||||
venv/
|
||||
python=3.11/
|
||||
|
||||
*.png
|
||||
*.jpg
|
||||
*.jpeg
|
||||
|
||||
# 环境配置文件(包含敏感信息)
|
||||
*.env
|
||||
|
||||
# 日志文件
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# 上传文件临时目录
|
||||
uploads/
|
||||
@ -1,25 +1,25 @@
|
||||
# MuseGuard
|
||||
|
||||
占位:项目总说明。后续将补充以下内容:
|
||||
|
||||
## 简介
|
||||
(占位)
|
||||
|
||||
## 项目目标
|
||||
(占位)
|
||||
|
||||
## 技术栈
|
||||
(占位)
|
||||
|
||||
## 快速开始
|
||||
(占位)
|
||||
|
||||
## 目录结构说明
|
||||
(占位)
|
||||
|
||||
## 贡献指南
|
||||
(占位)
|
||||
|
||||
## 许可证
|
||||
(占位)
|
||||
|
||||
# MuseGuard
|
||||
|
||||
占位:项目总说明。后续将补充以下内容:
|
||||
|
||||
## 简介
|
||||
(占位)
|
||||
|
||||
## 项目目标
|
||||
(占位)
|
||||
|
||||
## 技术栈
|
||||
(占位)
|
||||
|
||||
## 快速开始
|
||||
(占位)
|
||||
|
||||
## 目录结构说明
|
||||
(占位)
|
||||
|
||||
## 贡献指南
|
||||
(占位)
|
||||
|
||||
## 许可证
|
||||
(占位)
|
||||
|
||||
|
||||
@ -0,0 +1,29 @@
|
||||
# Python 编译缓存
|
||||
__pycache__/
|
||||
|
||||
# 图片文件
|
||||
*.png
|
||||
*.jpg
|
||||
*.jpeg
|
||||
|
||||
# 环境配置文件(包含敏感信息)
|
||||
*.env
|
||||
|
||||
# 日志及进程文件
|
||||
logs/
|
||||
*.log
|
||||
*.pid
|
||||
|
||||
# 上传文件临时目录
|
||||
uploads/
|
||||
|
||||
# 微调生成文件
|
||||
*.json
|
||||
*.bin
|
||||
*.pkl
|
||||
*.safetensors
|
||||
*.pt
|
||||
*.txt
|
||||
|
||||
# 模型文件
|
||||
hf_models/
|
||||
@ -1,46 +1,46 @@
|
||||
"""
|
||||
MuseGuard 后端主应用入口
|
||||
基于对抗性扰动的多风格图像生成防护系统
|
||||
"""
|
||||
|
||||
from flask import Flask
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_cors import CORS
|
||||
from config.settings import Config
|
||||
|
||||
# 初始化扩展
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
jwt = JWTManager()
|
||||
|
||||
def create_app(config_class=Config):
|
||||
"""Flask应用工厂函数"""
|
||||
app = Flask(__name__)
|
||||
app.config.from_object(config_class)
|
||||
|
||||
# 初始化扩展
|
||||
db.init_app(app)
|
||||
migrate.init_app(app, db)
|
||||
jwt.init_app(app)
|
||||
CORS(app)
|
||||
|
||||
# 注册蓝图
|
||||
from app.controllers.auth_controller import auth_bp
|
||||
from app.controllers.user_controller import user_bp
|
||||
from app.controllers.task_controller import task_bp
|
||||
from app.controllers.image_controller import image_bp
|
||||
from app.controllers.admin_controller import admin_bp
|
||||
|
||||
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
||||
app.register_blueprint(user_bp, url_prefix='/api/user')
|
||||
app.register_blueprint(task_bp, url_prefix='/api/task')
|
||||
app.register_blueprint(image_bp, url_prefix='/api/image')
|
||||
app.register_blueprint(admin_bp, url_prefix='/api/admin')
|
||||
|
||||
return app
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = create_app()
|
||||
app.run(debug=True, host='0.0.0.0', port=5000)
|
||||
"""
|
||||
MuseGuard 后端主应用入口
|
||||
基于对抗性扰动的多风格图像生成防护系统
|
||||
"""
|
||||
|
||||
from flask import Flask
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_cors import CORS
|
||||
from config.settings import Config
|
||||
|
||||
# 初始化扩展
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
jwt = JWTManager()
|
||||
|
||||
def create_app(config_class=Config):
|
||||
"""Flask应用工厂函数"""
|
||||
app = Flask(__name__)
|
||||
app.config.from_object(config_class)
|
||||
|
||||
# 初始化扩展
|
||||
db.init_app(app)
|
||||
migrate.init_app(app, db)
|
||||
jwt.init_app(app)
|
||||
CORS(app)
|
||||
|
||||
# 注册蓝图
|
||||
from app.controllers.auth_controller import auth_bp
|
||||
from app.controllers.user_controller import user_bp
|
||||
from app.controllers.task_controller import task_bp
|
||||
from app.controllers.image_controller import image_bp
|
||||
from app.controllers.admin_controller import admin_bp
|
||||
|
||||
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
||||
app.register_blueprint(user_bp, url_prefix='/api/user')
|
||||
app.register_blueprint(task_bp, url_prefix='/api/task')
|
||||
app.register_blueprint(image_bp, url_prefix='/api/image')
|
||||
app.register_blueprint(admin_bp, url_prefix='/api/admin')
|
||||
|
||||
return app
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = create_app()
|
||||
app.run(debug=True, host='0.0.0.0', port=6006)
|
||||
@ -1,83 +1,83 @@
|
||||
"""
|
||||
MuseGuard 后端应用包
|
||||
基于对抗性扰动的多风格图像生成防护系统
|
||||
"""
|
||||
|
||||
from flask import Flask
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_cors import CORS
|
||||
import os
|
||||
|
||||
# 初始化扩展
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
jwt = JWTManager()
|
||||
cors = CORS()
|
||||
|
||||
def create_app(config_name=None):
|
||||
"""应用工厂函数"""
|
||||
# 配置静态文件和模板文件路径
|
||||
app = Flask(__name__,
|
||||
static_folder='../static',
|
||||
static_url_path='/static')
|
||||
|
||||
# 加载配置
|
||||
if config_name is None:
|
||||
config_name = os.environ.get('FLASK_ENV', 'development')
|
||||
|
||||
from config.settings import config
|
||||
app.config.from_object(config[config_name])
|
||||
|
||||
# 初始化扩展
|
||||
db.init_app(app)
|
||||
migrate.init_app(app, db)
|
||||
jwt.init_app(app)
|
||||
cors.init_app(app)
|
||||
|
||||
# 注册蓝图
|
||||
from app.controllers.auth_controller import auth_bp
|
||||
from app.controllers.user_controller import user_bp
|
||||
from app.controllers.task_controller import task_bp
|
||||
from app.controllers.image_controller import image_bp
|
||||
from app.controllers.admin_controller import admin_bp
|
||||
from app.controllers.demo_controller import demo_bp
|
||||
|
||||
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
||||
app.register_blueprint(user_bp, url_prefix='/api/user')
|
||||
app.register_blueprint(task_bp, url_prefix='/api/task')
|
||||
app.register_blueprint(image_bp, url_prefix='/api/image')
|
||||
app.register_blueprint(admin_bp, url_prefix='/api/admin')
|
||||
app.register_blueprint(demo_bp, url_prefix='/api/demo')
|
||||
|
||||
# 注册错误处理器
|
||||
@app.errorhandler(404)
|
||||
def not_found_error(error):
|
||||
return {'error': 'Not found'}, 404
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(error):
|
||||
db.session.rollback()
|
||||
return {'error': 'Internal server error'}, 500
|
||||
|
||||
# 根路由
|
||||
@app.route('/')
|
||||
def index():
|
||||
return {
|
||||
'message': 'MuseGuard API Server',
|
||||
'version': '1.0.0',
|
||||
'status': 'running',
|
||||
'endpoints': {
|
||||
'health': '/health',
|
||||
'api_docs': '/api',
|
||||
'test_page': '/static/test.html'
|
||||
}
|
||||
}
|
||||
|
||||
# 健康检查端点
|
||||
@app.route('/health')
|
||||
def health_check():
|
||||
return {'status': 'healthy', 'message': 'MuseGuard backend is running'}
|
||||
|
||||
"""
|
||||
MuseGuard 后端应用包
|
||||
基于对抗性扰动的多风格图像生成防护系统
|
||||
"""
|
||||
|
||||
from flask import Flask
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_cors import CORS
|
||||
import os
|
||||
|
||||
# 初始化扩展
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
jwt = JWTManager()
|
||||
cors = CORS()
|
||||
|
||||
def create_app(config_name=None):
|
||||
"""应用工厂函数"""
|
||||
# 配置静态文件和模板文件路径
|
||||
app = Flask(__name__,
|
||||
static_folder='../static',
|
||||
static_url_path='/static')
|
||||
|
||||
# 加载配置
|
||||
if config_name is None:
|
||||
config_name = os.environ.get('FLASK_ENV', 'development')
|
||||
|
||||
from config.settings import config
|
||||
app.config.from_object(config[config_name])
|
||||
|
||||
# 初始化扩展
|
||||
db.init_app(app)
|
||||
migrate.init_app(app, db)
|
||||
jwt.init_app(app)
|
||||
cors.init_app(app)
|
||||
|
||||
# 注册蓝图
|
||||
from app.controllers.auth_controller import auth_bp
|
||||
from app.controllers.user_controller import user_bp
|
||||
from app.controllers.task_controller import task_bp
|
||||
from app.controllers.image_controller import image_bp
|
||||
from app.controllers.admin_controller import admin_bp
|
||||
from app.controllers.demo_controller import demo_bp
|
||||
|
||||
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
||||
app.register_blueprint(user_bp, url_prefix='/api/user')
|
||||
app.register_blueprint(task_bp, url_prefix='/api/task')
|
||||
app.register_blueprint(image_bp, url_prefix='/api/image')
|
||||
app.register_blueprint(admin_bp, url_prefix='/api/admin')
|
||||
app.register_blueprint(demo_bp, url_prefix='/api/demo')
|
||||
|
||||
# 注册错误处理器
|
||||
@app.errorhandler(404)
|
||||
def not_found_error(error):
|
||||
return {'error': 'Not found'}, 404
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(error):
|
||||
db.session.rollback()
|
||||
return {'error': 'Internal server error'}, 500
|
||||
|
||||
# 根路由
|
||||
@app.route('/')
|
||||
def index():
|
||||
return {
|
||||
'message': 'MuseGuard API Server',
|
||||
'version': '1.0.0',
|
||||
'status': 'running',
|
||||
'endpoints': {
|
||||
'health': '/health',
|
||||
'api_docs': '/api',
|
||||
'test_page': '/static/test.html'
|
||||
}
|
||||
}
|
||||
|
||||
# 健康检查端点
|
||||
@app.route('/health')
|
||||
def health_check():
|
||||
return {'status': 'healthy', 'message': 'MuseGuard backend is running'}
|
||||
|
||||
return app
|
||||
@ -0,0 +1,87 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from torchvision.utils import make_grid
|
||||
from pytorch_lightning import seed_everything
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
parser = argparse.ArgumentParser(description="Inference")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="./test-infer/",
|
||||
help="The output directory where predictions are saved",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="The output directory where predictions are saved",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--v",
|
||||
type=str,
|
||||
default="sks",
|
||||
help="The output directory where predictions are saved",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed_everything(args.seed)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# define prompts
|
||||
prompts = [
|
||||
f"a photo of {args.v} person",
|
||||
f"a dslr portrait of {args.v} person",
|
||||
f"a photo of {args.v} person looking at the mirror",
|
||||
f"a photo of {args.v} person in front of eiffel tower",
|
||||
]
|
||||
|
||||
|
||||
# create & load model
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=torch.float32,
|
||||
safety_checker=None,
|
||||
local_files_only=True,
|
||||
).to("cuda")
|
||||
|
||||
for prompt in prompts:
|
||||
print(">>>>>>", prompt)
|
||||
norm_prompt = prompt.lower().replace(",", "").replace(" ", "_")
|
||||
out_path = f"{args.output_dir}/{norm_prompt}"
|
||||
os.makedirs(out_path, exist_ok=True)
|
||||
all_samples = list()
|
||||
for i in range(5):
|
||||
images = pipe([prompt] * 6, num_inference_steps=100, guidance_scale=7.5,).images
|
||||
for idx, image in enumerate(images):
|
||||
image.save(f"{out_path}/{i}_{idx}.png")
|
||||
image = np.array(image, dtype=np.float32)
|
||||
image /= 255.0
|
||||
image = np.transpose(image, (2, 0, 1))
|
||||
image = torch.from_numpy(image) # numpy->tensor
|
||||
all_samples.append(image)
|
||||
grid = torch.stack(all_samples, 0)
|
||||
# grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=8)
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
img = Image.fromarray(grid.astype(np.uint8))
|
||||
img.save(f"{args.output_dir}/{prompt}.png")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
del pipe
|
||||
torch.cuda.empty_cache()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,773 @@
|
||||
import argparse
|
||||
import copy
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DreamBoothDatasetFromTensor(Dataset):
|
||||
"""Just like DreamBoothDataset, but take instance_images_tensor instead of path"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instance_images_tensor,
|
||||
instance_prompt,
|
||||
tokenizer,
|
||||
class_data_root=None,
|
||||
class_prompt=None,
|
||||
size=512,
|
||||
center_crop=False,
|
||||
):
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.instance_images_tensor = instance_images_tensor
|
||||
self.num_instance_images = len(self.instance_images_tensor)
|
||||
self.instance_prompt = instance_prompt
|
||||
self._length = self.num_instance_images
|
||||
|
||||
if class_data_root is not None:
|
||||
self.class_data_root = Path(class_data_root)
|
||||
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
||||
self.class_images_path = list(self.class_data_root.iterdir())
|
||||
self.num_class_images = len(self.class_images_path)
|
||||
self._length = max(self.num_class_images, self.num_instance_images)
|
||||
self.class_prompt = class_prompt
|
||||
else:
|
||||
self.class_data_root = None
|
||||
|
||||
self.image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = self.instance_images_tensor[index % self.num_instance_images]
|
||||
example["instance_images"] = instance_image
|
||||
example["instance_prompt_ids"] = self.tokenizer(
|
||||
self.instance_prompt,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
if self.class_data_root:
|
||||
class_image = Image.open(self.class_images_path[index % self.num_class_images])
|
||||
if not class_image.mode == "RGB":
|
||||
class_image = class_image.convert("RGB")
|
||||
example["class_images"] = self.image_transforms(class_image)
|
||||
example["class_prompt_ids"] = self.tokenizer(
|
||||
self.class_prompt,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
if model_class == "CLIPTextModel":
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
return CLIPTextModel
|
||||
elif model_class == "RobertaSeriesModelWithTransformation":
|
||||
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
|
||||
|
||||
return RobertaSeriesModelWithTransformation
|
||||
else:
|
||||
raise ValueError(f"{model_class} is not supported.")
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=(
|
||||
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
||||
" float32 precision."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_dir_for_train",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A folder containing the training data of instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_dir_for_adversarial",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A folder containing the images to add adversarial noise",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="A folder containing the training data of class images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="The prompt with identifier specifying the instance",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt to specify images in the same class as provided instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Flag to add prior preservation loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_loss_weight",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The weight of prior preservation loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_class_images",
|
||||
type=int,
|
||||
default=100,
|
||||
help=(
|
||||
"Minimal class images for prior preservation loss. If there are not enough images already present in"
|
||||
" class_data_dir, additional images will be sampled with class_prompt."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for sampling images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Total number of training steps to perform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_f_train_steps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Total number of sub-steps to train surogate model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_adv_train_steps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Total number of sub-steps to train adversarial noise.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpointing_iterations",
|
||||
type=int,
|
||||
default=5,
|
||||
help=("Save a checkpoint of the training state every X iterations."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention",
|
||||
action="store_true",
|
||||
help="Whether or not to use xformers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pgd_alpha",
|
||||
type=float,
|
||||
default=1.0 / 255,
|
||||
help="The step size for pgd.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pgd_eps",
|
||||
type=int,
|
||||
default=0.05,
|
||||
help="The noise budget for pgd.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_image_path",
|
||||
default=None,
|
||||
help="target image for attacking",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
||||
|
||||
def __init__(self, prompt, num_samples):
|
||||
self.prompt = prompt
|
||||
self.num_samples = num_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
example["prompt"] = self.prompt
|
||||
example["index"] = index
|
||||
return example
|
||||
|
||||
|
||||
def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
images = [image_transforms(Image.open(i).convert("RGB")) for i in list(Path(data_dir).iterdir())]
|
||||
images = torch.stack(images)
|
||||
return images
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
args,
|
||||
models,
|
||||
tokenizer,
|
||||
noise_scheduler,
|
||||
vae,
|
||||
data_tensor: torch.Tensor,
|
||||
num_steps=20,
|
||||
):
|
||||
# Load the tokenizer
|
||||
|
||||
unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
|
||||
params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=1e-2,
|
||||
eps=1e-08,
|
||||
)
|
||||
|
||||
train_dataset = DreamBoothDatasetFromTensor(
|
||||
data_tensor,
|
||||
args.instance_prompt,
|
||||
tokenizer,
|
||||
args.class_data_dir,
|
||||
args.class_prompt,
|
||||
args.resolution,
|
||||
args.center_crop,
|
||||
)
|
||||
|
||||
# weight_dtype = torch.bfloat16
|
||||
weight_dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
|
||||
vae.to(device, dtype=weight_dtype)
|
||||
text_encoder.to(device, dtype=weight_dtype)
|
||||
unet.to(device, dtype=weight_dtype)
|
||||
|
||||
for step in range(num_steps):
|
||||
unet.train()
|
||||
text_encoder.train()
|
||||
|
||||
step_data = train_dataset[step % len(train_dataset)]
|
||||
pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
|
||||
device, dtype=weight_dtype
|
||||
)
|
||||
input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
|
||||
|
||||
latents = vae.encode(pixel_values).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(input_ids)[0]
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# with prior preservation loss
|
||||
if args.with_prior_preservation:
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
|
||||
# Compute instance loss
|
||||
instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = instance_loss + args.prior_loss_weight * prior_loss
|
||||
|
||||
else:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
print(
|
||||
f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, instance_loss: {instance_loss.detach().item()}"
|
||||
)
|
||||
|
||||
return [unet, text_encoder]
|
||||
|
||||
|
||||
def pgd_attack(
|
||||
args,
|
||||
models,
|
||||
tokenizer,
|
||||
noise_scheduler,
|
||||
vae,
|
||||
data_tensor: torch.Tensor,
|
||||
original_images: torch.Tensor,
|
||||
target_tensor: torch.Tensor,
|
||||
num_steps: int,
|
||||
):
|
||||
"""Return new perturbed data"""
|
||||
|
||||
unet, text_encoder = models
|
||||
weight_dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
|
||||
vae.to(device, dtype=weight_dtype)
|
||||
text_encoder.to(device, dtype=weight_dtype)
|
||||
unet.to(device, dtype=weight_dtype)
|
||||
|
||||
perturbed_images = data_tensor.detach().clone()
|
||||
perturbed_images.requires_grad_(True)
|
||||
|
||||
input_ids = tokenizer(
|
||||
args.instance_prompt,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids.repeat(len(data_tensor), 1)
|
||||
|
||||
for step in range(num_steps):
|
||||
perturbed_images.requires_grad = True
|
||||
latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
#noise_scheduler.config.num_train_timesteps
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(input_ids.to(device))[0]
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
unet.zero_grad()
|
||||
text_encoder.zero_grad()
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# target-shift loss
|
||||
if target_tensor is not None:
|
||||
xtm1_pred = torch.cat(
|
||||
[
|
||||
noise_scheduler.step(
|
||||
model_pred[idx : idx + 1],
|
||||
timesteps[idx : idx + 1],
|
||||
noisy_latents[idx : idx + 1],
|
||||
).prev_sample
|
||||
for idx in range(len(model_pred))
|
||||
]
|
||||
)
|
||||
xtm1_target = noise_scheduler.add_noise(target_tensor, noise, timesteps - 1)
|
||||
loss = loss - F.mse_loss(xtm1_pred, xtm1_target)
|
||||
|
||||
loss.backward()
|
||||
|
||||
alpha = args.pgd_alpha
|
||||
eps = args.pgd_eps / 255
|
||||
|
||||
adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
|
||||
eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
|
||||
perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
|
||||
print(f"PGD loss - step {step}, loss: {loss.detach().item()}")
|
||||
return perturbed_images
|
||||
|
||||
|
||||
def main(args):
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
logging_dir=logging_dir,
|
||||
)
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Generate class images if prior preservation is enabled.
|
||||
if args.with_prior_preservation:
|
||||
class_images_dir = Path(args.class_data_dir)
|
||||
if not class_images_dir.exists():
|
||||
class_images_dir.mkdir(parents=True)
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
if cur_class_images < args.num_class_images:
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
if args.mixed_precision == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
elif args.mixed_precision == "fp16":
|
||||
torch_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
torch_dtype = torch.bfloat16
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
safety_checker=None,
|
||||
revision=args.revision,
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
num_new_images = args.num_class_images - cur_class_images
|
||||
logger.info(f"Number of class images to sample: {num_new_images}.")
|
||||
|
||||
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
||||
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
||||
|
||||
sample_dataloader = accelerator.prepare(sample_dataloader)
|
||||
pipeline.to(accelerator.device)
|
||||
|
||||
for example in tqdm(
|
||||
sample_dataloader,
|
||||
desc="Generating class images",
|
||||
disable=not accelerator.is_local_main_process,
|
||||
):
|
||||
images = pipeline(example["prompt"]).images
|
||||
|
||||
for i, image in enumerate(images):
|
||||
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
||||
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# import correct text encoder class
|
||||
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
||||
|
||||
# Load scheduler and models
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
|
||||
).cuda()
|
||||
|
||||
vae.requires_grad_(False)
|
||||
|
||||
if not args.train_text_encoder:
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
clean_data = load_data(
|
||||
args.instance_data_dir_for_train,
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
)
|
||||
perturbed_data = load_data(
|
||||
args.instance_data_dir_for_adversarial,
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
)
|
||||
original_data = perturbed_data.clone()
|
||||
original_data.requires_grad_(False)
|
||||
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
target_latent_tensor = None
|
||||
if args.target_image_path is not None:
|
||||
target_image_path = Path(args.target_image_path)
|
||||
assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist"
|
||||
|
||||
target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution))
|
||||
target_image = np.array(target_image)[None].transpose(0, 3, 1, 2)
|
||||
|
||||
target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0
|
||||
target_latent_tensor = (
|
||||
vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor
|
||||
)
|
||||
target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda()
|
||||
|
||||
f = [unet, text_encoder]
|
||||
for i in range(args.max_train_steps):
|
||||
# 1. f' = f.clone()
|
||||
f_sur = copy.deepcopy(f)
|
||||
f_sur = train_one_epoch(
|
||||
args,
|
||||
f_sur,
|
||||
tokenizer,
|
||||
noise_scheduler,
|
||||
vae,
|
||||
clean_data,
|
||||
args.max_f_train_steps,
|
||||
)
|
||||
perturbed_data = pgd_attack(
|
||||
args,
|
||||
f_sur,
|
||||
tokenizer,
|
||||
noise_scheduler,
|
||||
vae,
|
||||
perturbed_data,
|
||||
original_data,
|
||||
target_latent_tensor,
|
||||
args.max_adv_train_steps,
|
||||
)
|
||||
f = train_one_epoch(
|
||||
args,
|
||||
f,
|
||||
tokenizer,
|
||||
noise_scheduler,
|
||||
vae,
|
||||
perturbed_data,
|
||||
args.max_f_train_steps,
|
||||
)
|
||||
|
||||
if (i + 1) % args.checkpointing_iterations == 0:
|
||||
save_folder = args.output_dir
|
||||
os.makedirs(save_folder, exist_ok=True)
|
||||
noised_imgs = perturbed_data.detach()
|
||||
|
||||
img_filenames = [
|
||||
Path(instance_path).stem
|
||||
for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir())
|
||||
]
|
||||
|
||||
for img_pixel, img_name in zip(noised_imgs, img_filenames):
|
||||
save_path = os.path.join(save_folder, f"perturbed_{img_name}.png")
|
||||
logger.info(f"即将保存图片到: {save_path}")
|
||||
Image.fromarray(
|
||||
(img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
|
||||
).save(save_path)
|
||||
|
||||
logger.info(f"图片已保存到: {save_path}")
|
||||
|
||||
print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@ -0,0 +1,972 @@
|
||||
import argparse
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
if model_class == "CLIPTextModel":
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
return CLIPTextModel
|
||||
elif model_class == "RobertaSeriesModelWithTransformation":
|
||||
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
|
||||
|
||||
return RobertaSeriesModelWithTransformation
|
||||
else:
|
||||
raise ValueError(f"{model_class} is not supported.")
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
||||
|
||||
def __init__(self, prompt, num_samples):
|
||||
self.prompt = prompt
|
||||
self.num_samples = num_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
example["prompt"] = self.prompt
|
||||
example["index"] = index
|
||||
return example
|
||||
|
||||
|
||||
class CustomDiffusionDataset(Dataset):
|
||||
"""
|
||||
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
||||
It pre-processes the images and the tokenizes prompts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
concepts_list,
|
||||
tokenizer,
|
||||
size=512,
|
||||
mask_size=64,
|
||||
center_crop=False,
|
||||
with_prior_preservation=False,
|
||||
num_class_images=200,
|
||||
hflip=False,
|
||||
aug=True,
|
||||
):
|
||||
self.size = size
|
||||
self.mask_size = mask_size
|
||||
self.center_crop = center_crop
|
||||
self.tokenizer = tokenizer
|
||||
self.interpolation = Image.BILINEAR
|
||||
self.aug = aug
|
||||
|
||||
self.instance_images_path = []
|
||||
self.class_images_path = []
|
||||
self.with_prior_preservation = with_prior_preservation
|
||||
for concept in concepts_list:
|
||||
inst_img_path = [
|
||||
(x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()
|
||||
]
|
||||
self.instance_images_path.extend(inst_img_path)
|
||||
|
||||
if with_prior_preservation:
|
||||
class_data_root = Path(concept["class_data_dir"])
|
||||
if os.path.isdir(class_data_root):
|
||||
class_images_path = list(class_data_root.iterdir())
|
||||
class_prompt = [concept["class_prompt"] for _ in range(len(class_images_path))]
|
||||
else:
|
||||
with open(class_data_root, "r") as f:
|
||||
class_images_path = f.read().splitlines()
|
||||
with open(concept["class_prompt"], "r") as f:
|
||||
class_prompt = f.read().splitlines()
|
||||
|
||||
class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)]
|
||||
self.class_images_path.extend(class_img_path[:num_class_images])
|
||||
|
||||
random.shuffle(self.instance_images_path)
|
||||
self.num_instance_images = len(self.instance_images_path)
|
||||
self.num_class_images = len(self.class_images_path)
|
||||
self._length = max(self.num_class_images, self.num_instance_images)
|
||||
self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
|
||||
|
||||
self.image_transforms = transforms.Compose(
|
||||
[
|
||||
self.flip,
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def preprocess(self, image, scale, resample):
|
||||
outer, inner = self.size, scale
|
||||
factor = self.size // self.mask_size
|
||||
if scale > self.size:
|
||||
outer, inner = scale, self.size
|
||||
top, left = np.random.randint(0, outer - inner + 1), np.random.randint(0, outer - inner + 1)
|
||||
image = image.resize((scale, scale), resample=resample)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
|
||||
mask = np.zeros((self.size // factor, self.size // factor))
|
||||
if scale > self.size:
|
||||
instance_image = image[top : top + inner, left : left + inner, :]
|
||||
mask = np.ones((self.size // factor, self.size // factor))
|
||||
else:
|
||||
instance_image[top : top + inner, left : left + inner, :] = image
|
||||
mask[
|
||||
top // factor + 1 : (top + scale) // factor - 1, left // factor + 1 : (left + scale) // factor - 1
|
||||
] = 1.0
|
||||
return instance_image, mask
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images]
|
||||
instance_image = Image.open(instance_image)
|
||||
if not instance_image.mode == "RGB":
|
||||
instance_image = instance_image.convert("RGB")
|
||||
instance_image = self.flip(instance_image)
|
||||
|
||||
# apply resize augmentation and create a valid image region mask
|
||||
random_scale = self.size
|
||||
if self.aug:
|
||||
random_scale = (
|
||||
np.random.randint(self.size // 3, self.size + 1)
|
||||
if np.random.uniform() < 0.66
|
||||
else np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
|
||||
)
|
||||
instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation)
|
||||
|
||||
if random_scale < 0.6 * self.size:
|
||||
instance_prompt = np.random.choice(["a far away ", "very small "]) + instance_prompt
|
||||
elif random_scale > self.size:
|
||||
instance_prompt = np.random.choice(["zoomed in ", "close up "]) + instance_prompt
|
||||
|
||||
example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1)
|
||||
example["mask"] = torch.from_numpy(mask)
|
||||
example["instance_prompt_ids"] = self.tokenizer(
|
||||
instance_prompt,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
if self.with_prior_preservation:
|
||||
class_image, class_prompt = self.class_images_path[index % self.num_class_images]
|
||||
class_image = Image.open(class_image)
|
||||
if not class_image.mode == "RGB":
|
||||
class_image = class_image.convert("RGB")
|
||||
example["class_images"] = self.image_transforms(class_image)
|
||||
example["class_mask"] = torch.ones_like(example["mask"])
|
||||
example["class_prompt_ids"] = self.tokenizer(
|
||||
class_prompt,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
return example
|
||||
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="CAAT training script.")
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=5e-3,
|
||||
required=True,
|
||||
help="PGD alpha.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eps",
|
||||
type=float,
|
||||
default=0.1,
|
||||
required=True,
|
||||
help="PGD eps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A folder containing the training data of instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A folder containing the training data of class images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt with identifier specifying the instance",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt to specify images in the same class as provided instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Flag to add prior preservation loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_loss_weight",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The weight of prior preservation loss."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_class_images",
|
||||
type=int,
|
||||
default=200,
|
||||
help=(
|
||||
"Minimal class images for prior preservation loss. If there are not enough images already present in"
|
||||
" class_data_dir, additional images will be sampled with class_prompt."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="outputs",
|
||||
help="The output directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="A seed for reproducible training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=250,
|
||||
help="Total number of training steps to perform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=250,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
||||
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Max number of checkpoints to store."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-5,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_model",
|
||||
type=str,
|
||||
default="crossattn_kv",
|
||||
choices=["crossattn_kv", "crossattn"],
|
||||
help="crossattn to enable fine-tuning of all params in the cross attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_generation_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp32", "fp16", "bf16"],
|
||||
help=(
|
||||
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concepts_list",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--set_grads_to_none",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
|
||||
" behaviors, so disable this argument if it causes any problems. More info:"
|
||||
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initializer_token", type=str, default="ktn+pll+ucd", help="A token to use as initializer word."
|
||||
)
|
||||
parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
|
||||
parser.add_argument(
|
||||
"--noaug",
|
||||
action="store_true",
|
||||
help="Dont apply augmentation during data augmentation when this flag is enabled.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.with_prior_preservation:
|
||||
if args.concepts_list is None:
|
||||
if args.class_data_dir is None:
|
||||
raise ValueError("You must specify a data directory for class images.")
|
||||
if args.class_prompt is None:
|
||||
raise ValueError("You must specify prompt for class images.")
|
||||
else:
|
||||
# logger is not available yet
|
||||
if args.class_data_dir is not None:
|
||||
warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
|
||||
if args.class_prompt is not None:
|
||||
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
accelerator.init_trackers("CAAT", config=vars(args))
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
if args.concepts_list is None:
|
||||
args.concepts_list = [
|
||||
{
|
||||
"instance_prompt": args.instance_prompt,
|
||||
"class_prompt": args.class_prompt,
|
||||
"instance_data_dir": args.instance_data_dir,
|
||||
"class_data_dir": args.class_data_dir,
|
||||
}
|
||||
]
|
||||
else:
|
||||
with open(args.concepts_list, "r") as f:
|
||||
args.concepts_list = json.load(f)
|
||||
|
||||
# Generate class images if prior preservation is enabled.
|
||||
if args.with_prior_preservation:
|
||||
for i, concept in enumerate(args.concepts_list):
|
||||
class_images_dir = Path(concept["class_data_dir"])
|
||||
if not class_images_dir.exists():
|
||||
class_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
if args.prior_generation_precision == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
elif args.prior_generation_precision == "fp16":
|
||||
torch_dtype = torch.float16
|
||||
elif args.prior_generation_precision == "bf16":
|
||||
torch_dtype = torch.bfloat16
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
safety_checker=None,
|
||||
revision=args.revision,
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
num_new_images = args.num_class_images - cur_class_images
|
||||
logger.info(f"Number of class images to sample: {num_new_images}.")
|
||||
|
||||
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
||||
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
||||
|
||||
sample_dataloader = accelerator.prepare(sample_dataloader)
|
||||
pipeline.to(accelerator.device)
|
||||
|
||||
for example in tqdm(
|
||||
sample_dataloader,
|
||||
desc="Generating class images",
|
||||
disable=not accelerator.is_local_main_process,
|
||||
):
|
||||
images = pipeline(example["prompt"]).images
|
||||
|
||||
for i, image in enumerate(images):
|
||||
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
||||
image_filename = (
|
||||
class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
||||
)
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load the tokenizer
|
||||
if args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name,
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
# import correct text encoder class
|
||||
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
)
|
||||
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
attention_class = CustomDiffusionAttnProcessor
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
attention_class = CustomDiffusionXFormersAttnProcessor
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
# now we will add new Custom Diffusion weights to the attention layers
|
||||
# It's important to realize here how many attention weights will be added and of which sizes
|
||||
# The sizes of the attention layers consist only of two different variables:
|
||||
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
|
||||
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
|
||||
|
||||
# Let's first see how many attention processors we will have to set.
|
||||
# For Stable Diffusion, it should be equal to:
|
||||
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
|
||||
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
|
||||
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
|
||||
# => 32 layers
|
||||
|
||||
# Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer
|
||||
train_kv = True
|
||||
train_q_out = False if args.freeze_model == "crossattn_kv" else True
|
||||
custom_diffusion_attn_procs = {}
|
||||
|
||||
st = unet.state_dict()
|
||||
|
||||
for name, _ in unet.attn_processors.items():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
layer_name = name.split(".processor")[0]
|
||||
weights = {
|
||||
"to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
|
||||
"to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
|
||||
}
|
||||
if train_q_out:
|
||||
weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
|
||||
weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
|
||||
weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
|
||||
if cross_attention_dim is not None:
|
||||
custom_diffusion_attn_procs[name] = attention_class(
|
||||
train_kv=train_kv,
|
||||
train_q_out=train_q_out,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
).to(unet.device)
|
||||
custom_diffusion_attn_procs[name].load_state_dict(weights)
|
||||
else:
|
||||
custom_diffusion_attn_procs[name] = attention_class(
|
||||
train_kv=False,
|
||||
train_q_out=False,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
del st
|
||||
|
||||
|
||||
unet.set_attn_processor(custom_diffusion_attn_procs)
|
||||
custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)
|
||||
|
||||
accelerator.register_for_checkpointing(custom_diffusion_layers)
|
||||
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
args.learning_rate = args.learning_rate
|
||||
if args.with_prior_preservation:
|
||||
args.learning_rate = args.learning_rate * 2.0
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
# Optimizer creation
|
||||
optimizer = optimizer_class(
|
||||
custom_diffusion_layers.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# Dataset creation:
|
||||
train_dataset = CustomDiffusionDataset(
|
||||
concepts_list=args.concepts_list,
|
||||
tokenizer=tokenizer,
|
||||
with_prior_preservation=args.with_prior_preservation,
|
||||
size=args.resolution,
|
||||
mask_size=vae.encode(
|
||||
torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device)
|
||||
)
|
||||
.latent_dist.sample()
|
||||
.size()[-1],
|
||||
center_crop=args.center_crop,
|
||||
num_class_images=args.num_class_images,
|
||||
hflip=args.hflip,
|
||||
aug=not args.noaug,
|
||||
)
|
||||
|
||||
|
||||
# Prepare for PGD
|
||||
pertubed_images = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path]
|
||||
pertubed_images = [train_dataset.image_transforms(i) for i in pertubed_images]
|
||||
pertubed_images = torch.stack(pertubed_images).contiguous()
|
||||
pertubed_images.requires_grad_()
|
||||
|
||||
original_images = pertubed_images.clone().detach()
|
||||
original_images.requires_grad_(False)
|
||||
|
||||
input_ids = train_dataset.tokenizer(
|
||||
args.instance_prompt,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=train_dataset.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids.repeat(len(original_images), 1)
|
||||
|
||||
def get_one_mask(image):
|
||||
random_scale = train_dataset.size
|
||||
if train_dataset.aug:
|
||||
random_scale = (
|
||||
np.random.randint(train_dataset.size // 3, train_dataset.size + 1)
|
||||
if np.random.uniform() < 0.66
|
||||
else np.random.randint(int(1.2 * train_dataset.size), int(1.4 * train_dataset.size))
|
||||
)
|
||||
_, one_mask = train_dataset.preprocess(image, random_scale, train_dataset.interpolation)
|
||||
one_mask = torch.from_numpy(one_mask)
|
||||
if args.with_prior_preservation:
|
||||
class_mask = torch.ones_like(one_mask)
|
||||
one_mask += class_mask
|
||||
return one_mask
|
||||
|
||||
images_open_list = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path]
|
||||
mask_list = []
|
||||
for image in images_open_list:
|
||||
mask_list.append(get_one_mask(image))
|
||||
|
||||
mask = torch.stack(mask_list)
|
||||
mask = mask.to(memory_format=torch.contiguous_format).float()
|
||||
mask = mask.unsqueeze(1)
|
||||
del images_open_list
|
||||
|
||||
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask = accelerator.prepare(
|
||||
custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask
|
||||
)
|
||||
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num pertubed_images = {len(pertubed_images)}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
for epoch in range(first_epoch, args.max_train_steps):
|
||||
unet.train()
|
||||
for _ in range(1):
|
||||
with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
pertubed_images.requires_grad = True
|
||||
latents = vae.encode(pertubed_images.to(accelerator.device).to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0]
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# unet.zero_grad()
|
||||
# text_encoder.zero_grad()
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
mask = torch.chunk(mask, 2, dim=0)[0].to(accelerator.device)
|
||||
# Compute instance loss
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = loss + args.prior_loss_weight * prior_loss
|
||||
else:
|
||||
mask = mask.to(accelerator.device)
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # torch.Size([5, 4, 64, 64])
|
||||
|
||||
#loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
|
||||
loss = loss.mean()
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
custom_diffusion_layers.parameters()
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
alpha = args.alpha
|
||||
eps = args.eps
|
||||
adv_images = pertubed_images + alpha * pertubed_images.grad.sign()
|
||||
eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
|
||||
pertubed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if accelerator.is_main_process:
|
||||
logger.info("***** Final save of perturbed images *****")
|
||||
save_folder = args.output_dir
|
||||
|
||||
noised_imgs = pertubed_images.detach().cpu()
|
||||
|
||||
img_names = [
|
||||
str(instance_path[0]).split("/")[-1] for instance_path in train_dataset.instance_images_path
|
||||
]
|
||||
|
||||
num_images_to_save = len(img_names)
|
||||
|
||||
for i in range(num_images_to_save):
|
||||
img_pixel = noised_imgs[i]
|
||||
img_name = img_names[i]
|
||||
save_path = os.path.join(save_folder, f"perturbed_{img_name}")
|
||||
|
||||
# 图像转换和保存
|
||||
Image.fromarray(
|
||||
(img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
|
||||
).save(save_path)
|
||||
|
||||
logger.info(f"Saved {num_images_to_save} final perturbed images to {save_folder}")
|
||||
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
print("<-------end-------->")
|
||||
@ -0,0 +1,274 @@
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=(
|
||||
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
||||
" float32 precision."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A folder containing the training data of instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of updating steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--eps',
|
||||
type=float,
|
||||
default=12.75,
|
||||
help='pertubation budget'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--step_size',
|
||||
type=float,
|
||||
default=1/255,
|
||||
help='step size of each update'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--attack_type',
|
||||
choices=['var', 'mean', 'KL', 'add-log', 'latent_vector', 'add'],
|
||||
help='what is the attack target'
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class PIDDataset(Dataset):
|
||||
"""
|
||||
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
||||
It pre-processes the images and the tokenizes prompts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instance_data_root,
|
||||
size=512,
|
||||
center_crop=False
|
||||
):
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
||||
self.num_instance_images = len(self.instance_images_path)
|
||||
self.image_transforms = transforms.Compose([
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_instance_images
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
|
||||
instance_image = exif_transpose(instance_image)
|
||||
|
||||
if not instance_image.mode == "RGB":
|
||||
instance_image = instance_image.convert("RGB")
|
||||
|
||||
example['index'] = index % self.num_instance_images
|
||||
example['pixel_values'] = self.image_transforms(instance_image)
|
||||
return example
|
||||
|
||||
|
||||
def main(args):
|
||||
# Set random seed
|
||||
if args.seed is not None:
|
||||
torch.manual_seed(args.seed)
|
||||
weight_dtype = torch.float32
|
||||
device = torch.device('cuda')
|
||||
|
||||
# VAE encoder
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
||||
vae.requires_grad_(False)
|
||||
vae.to(device, dtype=weight_dtype)
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
dataset = PIDDataset(
|
||||
instance_data_root=args.instance_data_dir,
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=1, # some parts of code don't support batching
|
||||
shuffle=True,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Wrapper of the perturbations generator
|
||||
class AttackModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
to_tensor = transforms.ToTensor()
|
||||
self.epsilon = args.eps/255
|
||||
self.delta = [torch.empty_like(to_tensor(Image.open(path))).uniform_(-self.epsilon, self.epsilon)
|
||||
for path in dataset.instance_images_path]
|
||||
self.size = dataset.size
|
||||
|
||||
def forward(self, vae, x, index, poison=False):
|
||||
# Check whether we need to add perturbation
|
||||
if poison:
|
||||
self.delta[index].requires_grad_(True)
|
||||
x = x + self.delta[index].to(dtype=weight_dtype)
|
||||
|
||||
# Normalize to [-1, 1]
|
||||
input_x = 2 * x - 1
|
||||
return vae.encode(input_x.to(device))
|
||||
|
||||
attackmodel = AttackModel()
|
||||
|
||||
# Just to zero-out the gradient
|
||||
optimizer = torch.optim.SGD(attackmodel.delta, lr=0)
|
||||
|
||||
# Progress bar
|
||||
progress_bar = tqdm(range(0, args.max_train_steps), desc="Steps")
|
||||
|
||||
# Make sure the dir exists
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Start optimizing the perturbation
|
||||
for step in progress_bar:
|
||||
|
||||
total_loss = 0.0
|
||||
for batch in dataloader:
|
||||
# Save images
|
||||
if step%25 == 0:
|
||||
to_image = transforms.ToPILImage()
|
||||
for i in range(0, len(dataset.instance_images_path)):
|
||||
img = dataset[i]['pixel_values']
|
||||
img = to_image(img + attackmodel.delta[i])
|
||||
# 使用原文件名,添加perturbed_前缀
|
||||
original_filename = Path(dataset.instance_images_path[i]).stem
|
||||
img.save(os.path.join(args.output_dir, f"perturbed_{original_filename}.png"))
|
||||
|
||||
|
||||
# Select target loss
|
||||
clean_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], False)
|
||||
poison_embedding = attackmodel(vae, batch['pixel_values'], batch['index'], True)
|
||||
clean_latent = clean_embedding.latent_dist
|
||||
poison_latent = poison_embedding.latent_dist
|
||||
|
||||
if args.attack_type == 'var':
|
||||
loss = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean")
|
||||
elif args.attack_type == 'mean':
|
||||
loss = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean")
|
||||
elif args.attack_type == 'KL':
|
||||
sigma_2, mu_2 = poison_latent.std, poison_latent.mean
|
||||
sigma_1, mu_1 = clean_latent.std, clean_latent.mean
|
||||
KL_diver = torch.log(sigma_2 / sigma_1) - 0.5 + (sigma_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * sigma_2 ** 2)
|
||||
loss = KL_diver.flatten().mean()
|
||||
elif args.attack_type == 'latent_vector':
|
||||
clean_vector = clean_latent.sample()
|
||||
poison_vector = poison_latent.sample()
|
||||
loss = F.mse_loss(clean_vector, poison_vector, reduction="mean")
|
||||
elif args.attack_type == 'add':
|
||||
loss_2 = F.mse_loss(clean_latent.std, poison_latent.std, reduction="mean")
|
||||
loss_1 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction="mean")
|
||||
loss = loss_1 + loss_2
|
||||
elif args.attack_type == 'add-log':
|
||||
loss_1 = F.mse_loss(clean_latent.var.log(), poison_latent.var.log(), reduction="mean")
|
||||
loss_2 = F.mse_loss(clean_latent.mean, poison_latent.mean, reduction='mean')
|
||||
loss = loss_1 + loss_2
|
||||
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Perform PGD update on the loss
|
||||
delta = attackmodel.delta[batch['index']]
|
||||
delta.requires_grad_(False)
|
||||
delta += delta.grad.sign() * 1/255
|
||||
delta = torch.clamp(delta, -attackmodel.epsilon, attackmodel.epsilon)
|
||||
delta = torch.clamp(delta, -batch['pixel_values'].detach().cpu(), 1-batch['pixel_values'].detach().cpu())
|
||||
attackmodel.delta[batch['index']] = delta.detach().squeeze(0)
|
||||
|
||||
total_loss += loss.detach().cpu()
|
||||
|
||||
# Logging steps
|
||||
logs = {"loss": total_loss.item()}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,176 +0,0 @@
|
||||
"""
|
||||
对抗性扰动算法引擎
|
||||
实现各种加噪算法的虚拟版本
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image, ImageEnhance, ImageFilter
|
||||
import uuid
|
||||
from flask import current_app
|
||||
|
||||
class PerturbationEngine:
|
||||
"""对抗性扰动处理引擎"""
|
||||
|
||||
@staticmethod
|
||||
def apply_perturbation(image_path, algorithm, epsilon, use_strong_protection=False, output_path=None):
|
||||
"""
|
||||
应用对抗性扰动
|
||||
|
||||
Args:
|
||||
image_path: 原始图片路径
|
||||
algorithm: 算法名称 (simac, caat, pid)
|
||||
epsilon: 扰动强度
|
||||
use_strong_protection: 是否使用防净化版本
|
||||
|
||||
Returns:
|
||||
处理后图片的路径
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(image_path):
|
||||
raise FileNotFoundError(f"图片文件不存在: {image_path}")
|
||||
|
||||
# 加载图片
|
||||
with Image.open(image_path) as img:
|
||||
# 转换为RGB模式
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
|
||||
# 根据算法选择处理方法
|
||||
if algorithm == 'simac':
|
||||
perturbed_img = PerturbationEngine._apply_simac(img, epsilon, use_strong_protection)
|
||||
elif algorithm == 'caat':
|
||||
perturbed_img = PerturbationEngine._apply_caat(img, epsilon, use_strong_protection)
|
||||
elif algorithm == 'pid':
|
||||
perturbed_img = PerturbationEngine._apply_pid(img, epsilon, use_strong_protection)
|
||||
else:
|
||||
raise ValueError(f"不支持的算法: {algorithm}")
|
||||
|
||||
# 使用输入的output_path参数
|
||||
if output_path is None:
|
||||
# 如果没有提供输出路径,使用默认路径
|
||||
from flask import current_app
|
||||
project_root = os.path.dirname(current_app.root_path)
|
||||
perturbed_dir = os.path.join(project_root, current_app.config['PERTURBED_IMAGES_FOLDER'])
|
||||
os.makedirs(perturbed_dir, exist_ok=True)
|
||||
|
||||
file_extension = os.path.splitext(image_path)[1]
|
||||
output_filename = f"perturbed_{uuid.uuid4().hex[:8]}{file_extension}"
|
||||
output_path = os.path.join(perturbed_dir, output_filename)
|
||||
|
||||
# 保存处理后的图片
|
||||
perturbed_img.save(output_path, quality=95)
|
||||
|
||||
return output_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"应用扰动时出错: {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _apply_simac(img, epsilon, use_strong_protection):
|
||||
"""
|
||||
SimAC算法的虚拟实现
|
||||
Simple Anti-Customization Method for Protecting Face Privacy
|
||||
"""
|
||||
# 将PIL图像转换为numpy数组
|
||||
img_array = np.array(img, dtype=np.float32)
|
||||
|
||||
# 生成随机噪声(模拟对抗性扰动)
|
||||
noise_scale = epsilon / 255.0
|
||||
|
||||
if use_strong_protection:
|
||||
# 防净化版本:添加更复杂的扰动模式
|
||||
noise = np.random.normal(0, noise_scale * 0.8, img_array.shape)
|
||||
# 添加结构化噪声
|
||||
h, w = img_array.shape[:2]
|
||||
for i in range(0, h, 8):
|
||||
for j in range(0, w, 8):
|
||||
block_noise = np.random.normal(0, noise_scale * 0.4, (min(8, h-i), min(8, w-j), 3))
|
||||
noise[i:i+8, j:j+8] += block_noise
|
||||
else:
|
||||
# 标准版本:简单高斯噪声
|
||||
noise = np.random.normal(0, noise_scale, img_array.shape)
|
||||
|
||||
# 应用噪声
|
||||
perturbed_array = img_array + noise * 255.0
|
||||
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
|
||||
|
||||
# 转换回PIL图像
|
||||
result_img = Image.fromarray(perturbed_array)
|
||||
|
||||
# 轻微的图像增强以模拟算法特性
|
||||
enhancer = ImageEnhance.Contrast(result_img)
|
||||
result_img = enhancer.enhance(1.02)
|
||||
|
||||
return result_img
|
||||
|
||||
@staticmethod
|
||||
def _apply_caat(img, epsilon, use_strong_protection):
|
||||
"""
|
||||
CAAT算法的虚拟实现
|
||||
Perturbing Attention Gives You More Bang for the Buck
|
||||
"""
|
||||
img_array = np.array(img, dtype=np.float32)
|
||||
|
||||
noise_scale = epsilon / 255.0
|
||||
|
||||
if use_strong_protection:
|
||||
# 防净化版本:注意力区域重点扰动
|
||||
# 模拟注意力图(简单的边缘检测)
|
||||
gray_img = img.convert('L')
|
||||
edge_img = gray_img.filter(ImageFilter.FIND_EDGES)
|
||||
attention_map = np.array(edge_img, dtype=np.float32) / 255.0
|
||||
|
||||
# 在注意力区域添加更强的噪声
|
||||
noise = np.random.normal(0, noise_scale * 0.6, img_array.shape)
|
||||
for c in range(3):
|
||||
noise[:,:,c] += attention_map * np.random.normal(0, noise_scale * 0.8, attention_map.shape)
|
||||
else:
|
||||
# 标准版本:均匀分布噪声
|
||||
noise = np.random.uniform(-noise_scale, noise_scale, img_array.shape)
|
||||
|
||||
perturbed_array = img_array + noise * 255.0
|
||||
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
|
||||
|
||||
result_img = Image.fromarray(perturbed_array)
|
||||
|
||||
# 轻微模糊以模拟注意力扰动效果
|
||||
result_img = result_img.filter(ImageFilter.BoxBlur(0.5))
|
||||
|
||||
return result_img
|
||||
|
||||
@staticmethod
|
||||
def _apply_pid(img, epsilon, use_strong_protection):
|
||||
"""
|
||||
PID算法的虚拟实现
|
||||
Prompt-Independent Data Protection Against Latent Diffusion Models
|
||||
"""
|
||||
img_array = np.array(img, dtype=np.float32)
|
||||
|
||||
noise_scale = epsilon / 255.0
|
||||
|
||||
if use_strong_protection:
|
||||
# 防净化版本:频域扰动
|
||||
# 简单的频域变换模拟
|
||||
noise = np.random.laplace(0, noise_scale * 0.7, img_array.shape)
|
||||
# 添加周期性扰动
|
||||
h, w = img_array.shape[:2]
|
||||
for i in range(h):
|
||||
for j in range(w):
|
||||
periodic_noise = noise_scale * 0.3 * np.sin(i * 0.1) * np.cos(j * 0.1)
|
||||
noise[i, j] += periodic_noise
|
||||
else:
|
||||
# 标准版本:拉普拉斯噪声
|
||||
noise = np.random.laplace(0, noise_scale * 0.5, img_array.shape)
|
||||
|
||||
perturbed_array = img_array + noise * 255.0
|
||||
perturbed_array = np.clip(perturbed_array, 0, 255).astype(np.uint8)
|
||||
|
||||
result_img = Image.fromarray(perturbed_array)
|
||||
|
||||
# 色彩微调以模拟潜在空间扰动
|
||||
enhancer = ImageEnhance.Color(result_img)
|
||||
result_img = enhancer.enhance(0.98)
|
||||
|
||||
return result_img
|
||||
@ -1,239 +1,239 @@
|
||||
"""
|
||||
管理员控制器
|
||||
处理管理员功能
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required, get_jwt_identity
|
||||
from app import db
|
||||
from app.models import User, Batch, Image
|
||||
|
||||
admin_bp = Blueprint('admin', __name__)
|
||||
|
||||
def admin_required(f):
|
||||
"""管理员权限装饰器"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
current_user_id = get_jwt_identity()
|
||||
user = User.query.get(current_user_id)
|
||||
|
||||
if not user or user.role != 'admin':
|
||||
return jsonify({'error': '需要管理员权限'}), 403
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
@admin_bp.route('/users', methods=['GET'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def list_users():
|
||||
"""获取用户列表"""
|
||||
try:
|
||||
page = request.args.get('page', 1, type=int)
|
||||
per_page = request.args.get('per_page', 20, type=int)
|
||||
|
||||
users = User.query.paginate(page=page, per_page=per_page, error_out=False)
|
||||
|
||||
return jsonify({
|
||||
'users': [user.to_dict() for user in users.items],
|
||||
'total': users.total,
|
||||
'pages': users.pages,
|
||||
'current_page': page
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户列表失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users/<int:user_id>', methods=['GET'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def get_user_detail(user_id):
|
||||
"""获取用户详情"""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
# 获取用户统计信息
|
||||
total_tasks = Batch.query.filter_by(user_id=user_id).count()
|
||||
total_images = Image.query.filter_by(user_id=user_id).count()
|
||||
|
||||
user_dict = user.to_dict()
|
||||
user_dict['stats'] = {
|
||||
'total_tasks': total_tasks,
|
||||
'total_images': total_images
|
||||
}
|
||||
|
||||
return jsonify({'user': user_dict}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户详情失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users', methods=['POST'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def create_user():
|
||||
"""创建用户"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
email = data.get('email')
|
||||
role = data.get('role', 'user')
|
||||
max_concurrent_tasks = data.get('max_concurrent_tasks', 0)
|
||||
|
||||
if not username or not password:
|
||||
return jsonify({'error': '用户名和密码不能为空'}), 400
|
||||
|
||||
# 检查用户名是否已存在
|
||||
if User.query.filter_by(username=username).first():
|
||||
return jsonify({'error': '用户名已存在'}), 400
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if email and User.query.filter_by(email=email).first():
|
||||
return jsonify({'error': '邮箱已被使用'}), 400
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
role=role,
|
||||
max_concurrent_tasks=max_concurrent_tasks
|
||||
)
|
||||
user.set_password(password)
|
||||
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户创建成功',
|
||||
'user': user.to_dict()
|
||||
}), 201
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'创建用户失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users/<int:user_id>', methods=['PUT'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def update_user(user_id):
|
||||
"""更新用户信息"""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
# 更新字段
|
||||
if 'username' in data:
|
||||
new_username = data['username']
|
||||
if new_username != user.username:
|
||||
if User.query.filter_by(username=new_username).first():
|
||||
return jsonify({'error': '用户名已存在'}), 400
|
||||
user.username = new_username
|
||||
|
||||
if 'email' in data:
|
||||
new_email = data['email']
|
||||
if new_email != user.email:
|
||||
if User.query.filter_by(email=new_email).first():
|
||||
return jsonify({'error': '邮箱已被使用'}), 400
|
||||
user.email = new_email
|
||||
|
||||
if 'role' in data:
|
||||
user.role = data['role']
|
||||
|
||||
if 'max_concurrent_tasks' in data:
|
||||
user.max_concurrent_tasks = data['max_concurrent_tasks']
|
||||
|
||||
if 'is_active' in data:
|
||||
user.is_active = bool(data['is_active'])
|
||||
|
||||
if 'password' in data and data['password']:
|
||||
user.set_password(data['password'])
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户信息更新成功',
|
||||
'user': user.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'更新用户失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users/<int:user_id>', methods=['DELETE'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def delete_user(user_id):
|
||||
"""删除用户"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
# 不能删除自己
|
||||
if user_id == current_user_id:
|
||||
return jsonify({'error': '不能删除自己的账户'}), 400
|
||||
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
# 删除用户(级联删除相关数据)
|
||||
db.session.delete(user)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'message': '用户删除成功'}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'删除用户失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/stats', methods=['GET'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def get_system_stats():
|
||||
"""获取系统统计信息"""
|
||||
try:
|
||||
from app.models import EvaluationResult
|
||||
|
||||
total_users = User.query.count()
|
||||
active_users = User.query.filter_by(is_active=True).count()
|
||||
admin_users = User.query.filter_by(role='admin').count()
|
||||
|
||||
total_tasks = Batch.query.count()
|
||||
completed_tasks = Batch.query.filter_by(status='completed').count()
|
||||
processing_tasks = Batch.query.filter_by(status='processing').count()
|
||||
failed_tasks = Batch.query.filter_by(status='failed').count()
|
||||
|
||||
total_images = Image.query.count()
|
||||
total_evaluations = EvaluationResult.query.count()
|
||||
|
||||
return jsonify({
|
||||
'stats': {
|
||||
'users': {
|
||||
'total': total_users,
|
||||
'active': active_users,
|
||||
'admin': admin_users
|
||||
},
|
||||
'tasks': {
|
||||
'total': total_tasks,
|
||||
'completed': completed_tasks,
|
||||
'processing': processing_tasks,
|
||||
'failed': failed_tasks
|
||||
},
|
||||
'images': {
|
||||
'total': total_images
|
||||
},
|
||||
'evaluations': {
|
||||
'total': total_evaluations
|
||||
}
|
||||
}
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
"""
|
||||
管理员控制器
|
||||
处理管理员功能
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required, get_jwt_identity
|
||||
from app import db
|
||||
from app.database import User, Batch, Image
|
||||
|
||||
admin_bp = Blueprint('admin', __name__)
|
||||
|
||||
def admin_required(f):
|
||||
"""管理员权限装饰器"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
current_user_id = get_jwt_identity()
|
||||
user = User.query.get(current_user_id)
|
||||
|
||||
if not user or user.role != 'admin':
|
||||
return jsonify({'error': '需要管理员权限'}), 403
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
@admin_bp.route('/users', methods=['GET'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def list_users():
|
||||
"""获取用户列表"""
|
||||
try:
|
||||
page = request.args.get('page', 1, type=int)
|
||||
per_page = request.args.get('per_page', 20, type=int)
|
||||
|
||||
users = User.query.paginate(page=page, per_page=per_page, error_out=False)
|
||||
|
||||
return jsonify({
|
||||
'users': [user.to_dict() for user in users.items],
|
||||
'total': users.total,
|
||||
'pages': users.pages,
|
||||
'current_page': page
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户列表失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users/<int:user_id>', methods=['GET'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def get_user_detail(user_id):
|
||||
"""获取用户详情"""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
# 获取用户统计信息
|
||||
total_tasks = Batch.query.filter_by(user_id=user_id).count()
|
||||
total_images = Image.query.filter_by(user_id=user_id).count()
|
||||
|
||||
user_dict = user.to_dict()
|
||||
user_dict['stats'] = {
|
||||
'total_tasks': total_tasks,
|
||||
'total_images': total_images
|
||||
}
|
||||
|
||||
return jsonify({'user': user_dict}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户详情失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users', methods=['POST'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def create_user():
|
||||
"""创建用户"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
email = data.get('email')
|
||||
role = data.get('role', 'user')
|
||||
max_concurrent_tasks = data.get('max_concurrent_tasks', 0)
|
||||
|
||||
if not username or not password:
|
||||
return jsonify({'error': '用户名和密码不能为空'}), 400
|
||||
|
||||
# 检查用户名是否已存在
|
||||
if User.query.filter_by(username=username).first():
|
||||
return jsonify({'error': '用户名已存在'}), 400
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if email and User.query.filter_by(email=email).first():
|
||||
return jsonify({'error': '邮箱已被使用'}), 400
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
role=role,
|
||||
max_concurrent_tasks=max_concurrent_tasks
|
||||
)
|
||||
user.set_password(password)
|
||||
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户创建成功',
|
||||
'user': user.to_dict()
|
||||
}), 201
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'创建用户失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users/<int:user_id>', methods=['PUT'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def update_user(user_id):
|
||||
"""更新用户信息"""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
# 更新字段
|
||||
if 'username' in data:
|
||||
new_username = data['username']
|
||||
if new_username != user.username:
|
||||
if User.query.filter_by(username=new_username).first():
|
||||
return jsonify({'error': '用户名已存在'}), 400
|
||||
user.username = new_username
|
||||
|
||||
if 'email' in data:
|
||||
new_email = data['email']
|
||||
if new_email != user.email:
|
||||
if User.query.filter_by(email=new_email).first():
|
||||
return jsonify({'error': '邮箱已被使用'}), 400
|
||||
user.email = new_email
|
||||
|
||||
if 'role' in data:
|
||||
user.role = data['role']
|
||||
|
||||
if 'max_concurrent_tasks' in data:
|
||||
user.max_concurrent_tasks = data['max_concurrent_tasks']
|
||||
|
||||
if 'is_active' in data:
|
||||
user.is_active = bool(data['is_active'])
|
||||
|
||||
if 'password' in data and data['password']:
|
||||
user.set_password(data['password'])
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户信息更新成功',
|
||||
'user': user.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'更新用户失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/users/<int:user_id>', methods=['DELETE'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def delete_user(user_id):
|
||||
"""删除用户"""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
# 不能删除自己
|
||||
if user_id == current_user_id:
|
||||
return jsonify({'error': '不能删除自己的账户'}), 400
|
||||
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return jsonify({'error': '用户不存在'}), 404
|
||||
|
||||
# 删除用户(级联删除相关数据)
|
||||
db.session.delete(user)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'message': '用户删除成功'}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'删除用户失败: {str(e)}'}), 500
|
||||
|
||||
@admin_bp.route('/stats', methods=['GET'])
|
||||
@jwt_required()
|
||||
@admin_required
|
||||
def get_system_stats():
|
||||
"""获取系统统计信息"""
|
||||
try:
|
||||
from app.database import EvaluationResult
|
||||
|
||||
total_users = User.query.count()
|
||||
active_users = User.query.filter_by(is_active=True).count()
|
||||
admin_users = User.query.filter_by(role='admin').count()
|
||||
|
||||
total_tasks = Batch.query.count()
|
||||
completed_tasks = Batch.query.filter_by(status='completed').count()
|
||||
processing_tasks = Batch.query.filter_by(status='processing').count()
|
||||
failed_tasks = Batch.query.filter_by(status='failed').count()
|
||||
|
||||
total_images = Image.query.count()
|
||||
total_evaluations = EvaluationResult.query.count()
|
||||
|
||||
return jsonify({
|
||||
'stats': {
|
||||
'users': {
|
||||
'total': total_users,
|
||||
'active': active_users,
|
||||
'admin': admin_users
|
||||
},
|
||||
'tasks': {
|
||||
'total': total_tasks,
|
||||
'completed': completed_tasks,
|
||||
'processing': processing_tasks,
|
||||
'failed': failed_tasks
|
||||
},
|
||||
'images': {
|
||||
'total': total_images
|
||||
},
|
||||
'evaluations': {
|
||||
'total': total_evaluations
|
||||
}
|
||||
}
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取系统统计失败: {str(e)}'}), 500
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,133 +1,133 @@
|
||||
"""
|
||||
用户管理控制器
|
||||
处理用户配置等功能
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required
|
||||
from app import db
|
||||
from app.models import User, UserConfig, PerturbationConfig, FinetuneConfig
|
||||
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
|
||||
|
||||
user_bp = Blueprint('user', __name__)
|
||||
|
||||
@user_bp.route('/config', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_config(current_user_id):
|
||||
"""获取用户配置"""
|
||||
try:
|
||||
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
if not user_config:
|
||||
# 如果没有配置,创建默认配置
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'config': user_config.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/config', methods=['PUT'])
|
||||
@int_jwt_required
|
||||
def update_user_config(current_user_id):
|
||||
"""更新用户配置"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
if not user_config:
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
|
||||
# 更新配置字段
|
||||
if 'preferred_perturbation_config_id' in data:
|
||||
user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id']
|
||||
|
||||
if 'preferred_epsilon' in data:
|
||||
epsilon = float(data['preferred_epsilon'])
|
||||
if 0 < epsilon <= 255:
|
||||
user_config.preferred_epsilon = epsilon
|
||||
else:
|
||||
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
|
||||
|
||||
if 'preferred_finetune_config_id' in data:
|
||||
user_config.preferred_finetune_config_id = data['preferred_finetune_config_id']
|
||||
|
||||
if 'preferred_purification' in data:
|
||||
user_config.preferred_purification = bool(data['preferred_purification'])
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户配置更新成功',
|
||||
'config': user_config.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/algorithms', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_available_algorithms():
|
||||
"""获取可用的算法列表"""
|
||||
try:
|
||||
perturbation_configs = PerturbationConfig.query.all()
|
||||
finetune_configs = FinetuneConfig.query.all()
|
||||
|
||||
return jsonify({
|
||||
'perturbation_algorithms': [
|
||||
{
|
||||
'id': config.id,
|
||||
'method_code': config.method_code,
|
||||
'method_name': config.method_name,
|
||||
'description': config.description,
|
||||
'default_epsilon': float(config.default_epsilon)
|
||||
} for config in perturbation_configs
|
||||
],
|
||||
'finetune_methods': [
|
||||
{
|
||||
'id': config.id,
|
||||
'method_code': config.method_code,
|
||||
'method_name': config.method_name,
|
||||
'description': config.description
|
||||
} for config in finetune_configs
|
||||
]
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/stats', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_stats(current_user_id):
|
||||
"""获取用户统计信息"""
|
||||
try:
|
||||
from app.models import Batch, Image
|
||||
|
||||
# 统计用户的任务和图片数量
|
||||
total_tasks = Batch.query.filter_by(user_id=current_user_id).count()
|
||||
completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count()
|
||||
processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count()
|
||||
failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count()
|
||||
|
||||
total_images = Image.query.filter_by(user_id=current_user_id).count()
|
||||
|
||||
return jsonify({
|
||||
'stats': {
|
||||
'total_tasks': total_tasks,
|
||||
'completed_tasks': completed_tasks,
|
||||
'processing_tasks': processing_tasks,
|
||||
'failed_tasks': failed_tasks,
|
||||
'total_images': total_images
|
||||
}
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
"""
|
||||
用户管理控制器
|
||||
处理用户配置等功能
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required
|
||||
from app import db
|
||||
from app.database import User, UserConfig, PerturbationConfig, FinetuneConfig
|
||||
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
|
||||
|
||||
user_bp = Blueprint('user', __name__)
|
||||
|
||||
@user_bp.route('/config', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_config(current_user_id):
|
||||
"""获取用户配置"""
|
||||
try:
|
||||
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
if not user_config:
|
||||
# 如果没有配置,创建默认配置
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'config': user_config.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/config', methods=['PUT'])
|
||||
@int_jwt_required
|
||||
def update_user_config(current_user_id):
|
||||
"""更新用户配置"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
if not user_config:
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
|
||||
# 更新配置字段
|
||||
if 'preferred_perturbation_config_id' in data:
|
||||
user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id']
|
||||
|
||||
if 'preferred_epsilon' in data:
|
||||
epsilon = float(data['preferred_epsilon'])
|
||||
if 0 < epsilon <= 255:
|
||||
user_config.preferred_epsilon = epsilon
|
||||
else:
|
||||
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
|
||||
|
||||
if 'preferred_finetune_config_id' in data:
|
||||
user_config.preferred_finetune_config_id = data['preferred_finetune_config_id']
|
||||
|
||||
if 'preferred_purification' in data:
|
||||
user_config.preferred_purification = bool(data['preferred_purification'])
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户配置更新成功',
|
||||
'config': user_config.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/algorithms', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_available_algorithms():
|
||||
"""获取可用的算法列表"""
|
||||
try:
|
||||
perturbation_configs = PerturbationConfig.query.all()
|
||||
finetune_configs = FinetuneConfig.query.all()
|
||||
|
||||
return jsonify({
|
||||
'perturbation_algorithms': [
|
||||
{
|
||||
'id': config.id,
|
||||
'method_code': config.method_code,
|
||||
'method_name': config.method_name,
|
||||
'description': config.description,
|
||||
'default_epsilon': float(config.default_epsilon)
|
||||
} for config in perturbation_configs
|
||||
],
|
||||
'finetune_methods': [
|
||||
{
|
||||
'id': config.id,
|
||||
'method_code': config.method_code,
|
||||
'method_name': config.method_name,
|
||||
'description': config.description
|
||||
} for config in finetune_configs
|
||||
]
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/stats', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_stats(current_user_id):
|
||||
"""获取用户统计信息"""
|
||||
try:
|
||||
from app.database import Batch, Image
|
||||
|
||||
# 统计用户的任务和图片数量
|
||||
total_tasks = Batch.query.filter_by(user_id=current_user_id).count()
|
||||
completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count()
|
||||
processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count()
|
||||
failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count()
|
||||
|
||||
total_images = Image.query.filter_by(user_id=current_user_id).count()
|
||||
|
||||
return jsonify({
|
||||
'stats': {
|
||||
'total_tasks': total_tasks,
|
||||
'completed_tasks': completed_tasks,
|
||||
'processing_tasks': processing_tasks,
|
||||
'failed_tasks': failed_tasks,
|
||||
'total_images': total_images
|
||||
}
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500
|
||||
@ -1,34 +1,34 @@
|
||||
"""
|
||||
认证服务
|
||||
处理用户认证相关逻辑
|
||||
"""
|
||||
|
||||
from app.models import User
|
||||
|
||||
class AuthService:
|
||||
"""认证服务类"""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(username, password):
|
||||
"""验证用户凭据"""
|
||||
user = User.query.filter_by(username=username).first()
|
||||
|
||||
if user and user.check_password(password) and user.is_active:
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id):
|
||||
"""根据ID获取用户"""
|
||||
return User.query.get(user_id)
|
||||
|
||||
@staticmethod
|
||||
def is_email_available(email):
|
||||
"""检查邮箱是否可用"""
|
||||
return User.query.filter_by(email=email).first() is None
|
||||
|
||||
@staticmethod
|
||||
def is_username_available(username):
|
||||
"""检查用户名是否可用"""
|
||||
"""
|
||||
认证服务
|
||||
处理用户认证相关逻辑
|
||||
"""
|
||||
|
||||
from app.database import User
|
||||
|
||||
class AuthService:
|
||||
"""认证服务类"""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(username, password):
|
||||
"""验证用户凭据"""
|
||||
user = User.query.filter_by(username=username).first()
|
||||
|
||||
if user and user.check_password(password) and user.is_active:
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id):
|
||||
"""根据ID获取用户"""
|
||||
return User.query.get(user_id)
|
||||
|
||||
@staticmethod
|
||||
def is_email_available(email):
|
||||
"""检查邮箱是否可用"""
|
||||
return User.query.filter_by(email=email).first() is None
|
||||
|
||||
@staticmethod
|
||||
def is_username_available(username):
|
||||
"""检查用户名是否可用"""
|
||||
return User.query.filter_by(username=username).first() is None
|
||||
@ -1,53 +1,53 @@
|
||||
"""
|
||||
文件处理工具类
|
||||
"""
|
||||
|
||||
import os
|
||||
from werkzeug.utils import secure_filename
|
||||
from flask import current_app
|
||||
|
||||
def allowed_file(filename):
|
||||
"""检查文件扩展名是否被允许"""
|
||||
if not filename:
|
||||
return False
|
||||
|
||||
allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS',
|
||||
{'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'})
|
||||
|
||||
return '.' in filename and \
|
||||
filename.rsplit('.', 1)[1].lower() in allowed_extensions
|
||||
|
||||
def save_uploaded_file(file, upload_path):
|
||||
"""保存上传的文件"""
|
||||
try:
|
||||
if not file or not allowed_file(file.filename):
|
||||
return None
|
||||
|
||||
filename = secure_filename(file.filename)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(upload_path), exist_ok=True)
|
||||
|
||||
file.save(upload_path)
|
||||
return upload_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存文件失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_file_size(file_path):
|
||||
"""获取文件大小"""
|
||||
try:
|
||||
return os.path.getsize(file_path)
|
||||
except:
|
||||
return 0
|
||||
|
||||
def delete_file(file_path):
|
||||
"""删除文件"""
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
"""
|
||||
文件处理工具类
|
||||
"""
|
||||
|
||||
import os
|
||||
from werkzeug.utils import secure_filename
|
||||
from flask import current_app
|
||||
|
||||
def allowed_file(filename):
|
||||
"""检查文件扩展名是否被允许"""
|
||||
if not filename:
|
||||
return False
|
||||
|
||||
allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS',
|
||||
{'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'})
|
||||
|
||||
return '.' in filename and \
|
||||
filename.rsplit('.', 1)[1].lower() in allowed_extensions
|
||||
|
||||
def save_uploaded_file(file, upload_path):
|
||||
"""保存上传的文件"""
|
||||
try:
|
||||
if not file or not allowed_file(file.filename):
|
||||
return None
|
||||
|
||||
filename = secure_filename(file.filename)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(upload_path), exist_ok=True)
|
||||
|
||||
file.save(upload_path)
|
||||
return upload_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存文件失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_file_size(file_path):
|
||||
"""获取文件大小"""
|
||||
try:
|
||||
return os.path.getsize(file_path)
|
||||
except:
|
||||
return 0
|
||||
|
||||
def delete_file(file_path):
|
||||
"""删除文件"""
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
@ -1,16 +1,16 @@
|
||||
"""
|
||||
JWT工具函数
|
||||
"""
|
||||
from functools import wraps
|
||||
from flask_jwt_extended import get_jwt_identity
|
||||
|
||||
def int_jwt_required(f):
|
||||
"""获取JWT身份并转换为整数的装饰器"""
|
||||
@wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
jwt_identity = get_jwt_identity()
|
||||
if jwt_identity is not None:
|
||||
# 在函数调用前注入转换后的user_id
|
||||
kwargs['current_user_id'] = int(jwt_identity)
|
||||
return f(*args, **kwargs)
|
||||
"""
|
||||
JWT工具函数
|
||||
"""
|
||||
from functools import wraps
|
||||
from flask_jwt_extended import get_jwt_identity
|
||||
|
||||
def int_jwt_required(f):
|
||||
"""获取JWT身份并转换为整数的装饰器"""
|
||||
@wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
jwt_identity = get_jwt_identity()
|
||||
if jwt_identity is not None:
|
||||
# 在函数调用前注入转换后的user_id
|
||||
kwargs['current_user_id'] = int(jwt_identity)
|
||||
return f(*args, **kwargs)
|
||||
return wrapped
|
||||
@ -1,16 +0,0 @@
|
||||
# MuseGuard 环境变量配置文件
|
||||
# 注意:此文件包含敏感信息,不应提交到版本控制系统
|
||||
|
||||
# 数据库配置
|
||||
DB_USER=root
|
||||
DB_PASSWORD=your_password_here
|
||||
DB_HOST=localhost
|
||||
DB_NAME=your_database_name_here
|
||||
|
||||
# Flask配置
|
||||
SECRET_KEY=museguard-secret-key-2024
|
||||
JWT_SECRET_KEY=jwt-secret-string
|
||||
|
||||
# 开发模式
|
||||
FLASK_ENV=development
|
||||
FLASK_DEBUG=True
|
||||
@ -1,109 +1,109 @@
|
||||
"""
|
||||
应用配置文件
|
||||
"""
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from dotenv import load_dotenv
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
# 加载环境变量 - 从 config 目录读取 .env 文件
|
||||
config_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
env_path = os.path.join(config_dir, '.env')
|
||||
load_dotenv(env_path)
|
||||
|
||||
class Config:
|
||||
"""基础配置类"""
|
||||
|
||||
# 基础配置
|
||||
SECRET_KEY = os.environ.get('SECRET_KEY') or 'museguard-secret-key-2024'
|
||||
|
||||
# 数据库配置 - 支持密码中的特殊字符
|
||||
DB_USER = os.environ.get('DB_USER')
|
||||
DB_PASSWORD = os.environ.get('DB_PASSWORD')
|
||||
DB_HOST = os.environ.get('DB_HOST') or 'localhost'
|
||||
DB_NAME = os.environ.get('DB_NAME') or 'museguard_schema'
|
||||
|
||||
# URL编码密码中的特殊字符
|
||||
from urllib.parse import quote_plus
|
||||
_encoded_password = quote_plus(DB_PASSWORD)
|
||||
|
||||
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \
|
||||
f'mysql+pymysql://{DB_USER}:{_encoded_password}@{DB_HOST}/{DB_NAME}'
|
||||
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
||||
SQLALCHEMY_ENGINE_OPTIONS = {
|
||||
'pool_pre_ping': True,
|
||||
'pool_recycle': 300,
|
||||
}
|
||||
|
||||
# JWT配置
|
||||
JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or 'jwt-secret-string'
|
||||
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)
|
||||
JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
|
||||
|
||||
# 静态文件根目录
|
||||
STATIC_ROOT = 'static'
|
||||
|
||||
# 文件上传配置
|
||||
UPLOAD_FOLDER = 'uploads' # 临时上传目录
|
||||
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
|
||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'}
|
||||
|
||||
# 图像处理配置
|
||||
ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'originals') # 重命名后的原始图片
|
||||
PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片
|
||||
MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录
|
||||
MODEL_CLEAN_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'clean') # 原图的模型生成结果
|
||||
MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果
|
||||
HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图
|
||||
|
||||
# 预设演示图像配置
|
||||
DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录
|
||||
DEMO_ORIGINAL_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'original') # 演示原始图片
|
||||
DEMO_PERTURBED_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'perturbed') # 演示加噪图片
|
||||
DEMO_COMPARISONS_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'comparisons') # 演示对比图
|
||||
|
||||
# 邮件配置(用于注册验证)
|
||||
MAIL_SERVER = os.environ.get('MAIL_SERVER') or 'smtp.gmail.com'
|
||||
MAIL_PORT = int(os.environ.get('MAIL_PORT') or 587)
|
||||
MAIL_USE_TLS = os.environ.get('MAIL_USE_TLS', 'true').lower() in ['true', 'on', '1']
|
||||
MAIL_USERNAME = os.environ.get('MAIL_USERNAME')
|
||||
MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD')
|
||||
|
||||
# 算法配置
|
||||
ALGORITHMS = {
|
||||
'simac': {
|
||||
'name': 'SimAC算法',
|
||||
'description': 'Simple Anti-Customization Method for Protecting Face Privacy',
|
||||
'default_epsilon': 8.0
|
||||
},
|
||||
'caat': {
|
||||
'name': 'CAAT算法',
|
||||
'description': 'Perturbing Attention Gives You More Bang for the Buck',
|
||||
'default_epsilon': 16.0
|
||||
},
|
||||
'pid': {
|
||||
'name': 'PID算法',
|
||||
'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models',
|
||||
'default_epsilon': 4.0
|
||||
}
|
||||
}
|
||||
|
||||
class DevelopmentConfig(Config):
|
||||
"""开发环境配置"""
|
||||
DEBUG = True
|
||||
|
||||
class ProductionConfig(Config):
|
||||
"""生产环境配置"""
|
||||
DEBUG = False
|
||||
|
||||
class TestingConfig(Config):
|
||||
"""测试环境配置"""
|
||||
TESTING = True
|
||||
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
|
||||
|
||||
config = {
|
||||
'development': DevelopmentConfig,
|
||||
'production': ProductionConfig,
|
||||
'testing': TestingConfig,
|
||||
'default': DevelopmentConfig
|
||||
"""
|
||||
应用配置文件
|
||||
"""
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from dotenv import load_dotenv
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
# 加载环境变量 - 从 config 目录读取 .env 文件
|
||||
config_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
env_path = os.path.join(config_dir, 'settings.env')
|
||||
load_dotenv(env_path)
|
||||
|
||||
class Config:
|
||||
"""基础配置类"""
|
||||
|
||||
# 基础配置
|
||||
SECRET_KEY = os.environ.get('SECRET_KEY') or 'museguard-secret-key-2024'
|
||||
|
||||
# 数据库配置 - 支持密码中的特殊字符
|
||||
DB_USER = os.environ.get('DB_USER') or 'root'
|
||||
DB_PASSWORD = os.environ.get('DB_PASSWORD') or ''
|
||||
DB_HOST = os.environ.get('DB_HOST') or 'localhost'
|
||||
DB_NAME = os.environ.get('DB_NAME') or 'museguard_schema'
|
||||
|
||||
# URL编码密码中的特殊字符
|
||||
from urllib.parse import quote_plus
|
||||
_encoded_password = quote_plus(DB_PASSWORD) if DB_PASSWORD else ''
|
||||
|
||||
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \
|
||||
f'mysql+pymysql://{DB_USER}:{_encoded_password}@{DB_HOST}/{DB_NAME}'
|
||||
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
||||
SQLALCHEMY_ENGINE_OPTIONS = {
|
||||
'pool_pre_ping': True,
|
||||
'pool_recycle': 300,
|
||||
}
|
||||
|
||||
# JWT配置
|
||||
JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or 'jwt-secret-string'
|
||||
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)
|
||||
JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
|
||||
|
||||
# 静态文件根目录
|
||||
STATIC_ROOT = 'static'
|
||||
|
||||
# 文件上传配置
|
||||
UPLOAD_FOLDER = 'uploads' # 临时上传目录
|
||||
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
|
||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'zip', 'rar'}
|
||||
|
||||
# 图像处理配置
|
||||
ORIGINAL_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'original') # 重命名后的原始图片
|
||||
PERTURBED_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'perturbed') # 加噪后的图片
|
||||
MODEL_OUTPUTS_FOLDER = os.path.join(STATIC_ROOT, 'model_outputs') # 模型生成的图片根目录
|
||||
MODEL_ORIGINAL_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'original') # 原图的模型生成结果
|
||||
MODEL_PERTURBED_FOLDER = os.path.join(MODEL_OUTPUTS_FOLDER, 'perturbed') # 加噪图的模型生成结果
|
||||
HEATMAP_FOLDER = os.path.join(STATIC_ROOT, 'heatmaps') # 热力图
|
||||
|
||||
# 预设演示图像配置
|
||||
DEMO_IMAGES_FOLDER = os.path.join(STATIC_ROOT, 'demo') # 演示图片根目录
|
||||
DEMO_ORIGINAL_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'original') # 演示原始图片
|
||||
DEMO_PERTURBED_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'perturbed') # 演示加噪图片
|
||||
DEMO_COMPARISONS_FOLDER = os.path.join(DEMO_IMAGES_FOLDER, 'comparisons') # 演示对比图
|
||||
|
||||
# 邮件配置(用于注册验证)
|
||||
MAIL_SERVER = os.environ.get('MAIL_SERVER') or 'smtp.gmail.com'
|
||||
MAIL_PORT = int(os.environ.get('MAIL_PORT') or 587)
|
||||
MAIL_USE_TLS = os.environ.get('MAIL_USE_TLS', 'true').lower() in ['true', 'on', '1']
|
||||
MAIL_USERNAME = os.environ.get('MAIL_USERNAME')
|
||||
MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD')
|
||||
|
||||
# 算法配置
|
||||
ALGORITHMS = {
|
||||
'simac': {
|
||||
'name': 'SimAC算法',
|
||||
'description': 'Simple Anti-Customization Method for Protecting Face Privacy',
|
||||
'default_epsilon': 8.0
|
||||
},
|
||||
'caat': {
|
||||
'name': 'CAAT算法',
|
||||
'description': 'Perturbing Attention Gives You More Bang for the Buck',
|
||||
'default_epsilon': 16.0
|
||||
},
|
||||
'pid': {
|
||||
'name': 'PID算法',
|
||||
'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models',
|
||||
'default_epsilon': 4.0
|
||||
}
|
||||
}
|
||||
|
||||
class DevelopmentConfig(Config):
|
||||
"""开发环境配置"""
|
||||
DEBUG = True
|
||||
|
||||
class ProductionConfig(Config):
|
||||
"""生产环境配置"""
|
||||
DEBUG = False
|
||||
|
||||
class TestingConfig(Config):
|
||||
"""测试环境配置"""
|
||||
TESTING = True
|
||||
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
|
||||
|
||||
config = {
|
||||
'development': DevelopmentConfig,
|
||||
'production': ProductionConfig,
|
||||
'testing': TestingConfig,
|
||||
'default': DevelopmentConfig
|
||||
}
|
||||
@ -1,15 +0,0 @@
|
||||
@echo off
|
||||
chcp 65001 > nul
|
||||
echo ==========================================
|
||||
echo MuseGuard 快速启动脚本
|
||||
echo ==========================================
|
||||
echo.
|
||||
|
||||
REM 切换到项目目录
|
||||
cd /d "d:\code\Software_Project\team_project\MuseGuard\src\backend"
|
||||
|
||||
REM 激活虚拟环境并启动服务器
|
||||
echo 正在激活虚拟环境并启动服务器...
|
||||
call venv_py311\Scripts\activate.bat && python run.py
|
||||
|
||||
pause
|
||||
@ -1,33 +1,37 @@
|
||||
# Core Flask Framework
|
||||
Flask==3.0.0
|
||||
Flask-SQLAlchemy==3.1.1
|
||||
Flask-Migrate==4.0.5
|
||||
Flask-JWT-Extended==4.6.0
|
||||
Flask-CORS==5.0.0
|
||||
Werkzeug==3.0.1
|
||||
|
||||
# Database
|
||||
PyMySQL==1.1.1
|
||||
|
||||
# Image Processing
|
||||
Pillow==10.4.0
|
||||
numpy==1.26.4
|
||||
|
||||
# Security & Utils
|
||||
cryptography==42.0.8
|
||||
python-dotenv==1.0.1
|
||||
|
||||
# Additional Dependencies (auto-installed)
|
||||
# alembic==1.17.0
|
||||
# blinker==1.9.0
|
||||
# cffi==2.0.0
|
||||
# click==8.3.0
|
||||
# greenlet==3.2.4
|
||||
# itsdangerous==2.2.0
|
||||
# Jinja2==3.1.6
|
||||
# Mako==1.3.10
|
||||
# MarkupSafe==3.0.3
|
||||
# pycparser==2.23
|
||||
# PyJWT==2.10.1
|
||||
# SQLAlchemy==2.0.44
|
||||
# Core Flask Framework
|
||||
Flask==3.0.0
|
||||
Flask-SQLAlchemy==3.1.1
|
||||
Flask-Migrate==4.0.5
|
||||
Flask-JWT-Extended==4.6.0
|
||||
Flask-CORS==5.0.0
|
||||
Werkzeug==3.0.1
|
||||
|
||||
# Database
|
||||
PyMySQL==1.1.1
|
||||
|
||||
# Image Processing
|
||||
Pillow==10.4.0
|
||||
numpy==1.26.4
|
||||
|
||||
# Security & Utils
|
||||
cryptography==42.0.8
|
||||
python-dotenv==1.0.1
|
||||
|
||||
# Task Queue
|
||||
redis==5.0.1
|
||||
rq==1.16.2
|
||||
|
||||
# Additional Dependencies (auto-installed)
|
||||
# alembic==1.17.0
|
||||
# blinker==1.9.0
|
||||
# cffi==2.0.0
|
||||
# click==8.3.0
|
||||
# greenlet==3.2.4
|
||||
# itsdangerous==2.2.0
|
||||
# Jinja2==3.1.6
|
||||
# Mako==1.3.10
|
||||
# MarkupSafe==3.0.3
|
||||
# pycparser==2.23
|
||||
# PyJWT==2.10.1
|
||||
# SQLAlchemy==2.0.44
|
||||
# typing_extensions==4.15.0
|
||||
@ -1,21 +1,21 @@
|
||||
"""
|
||||
Flask应用启动脚本
|
||||
"""
|
||||
|
||||
import os
|
||||
from app import create_app
|
||||
|
||||
# 设置环境变量
|
||||
os.environ.setdefault('FLASK_ENV', 'development')
|
||||
|
||||
# 创建应用实例
|
||||
app = create_app()
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 开发模式启动
|
||||
app.run(
|
||||
host='0.0.0.0',
|
||||
port=5000,
|
||||
debug=True,
|
||||
threaded=True
|
||||
"""
|
||||
Flask应用启动脚本
|
||||
"""
|
||||
|
||||
import os
|
||||
from app import create_app
|
||||
|
||||
# 设置环境变量
|
||||
os.environ.setdefault('FLASK_ENV', 'development')
|
||||
|
||||
# 创建应用实例
|
||||
app = create_app()
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 开发模式启动
|
||||
app.run(
|
||||
host='0.0.0.0',
|
||||
port=6006,
|
||||
debug=True,
|
||||
threaded=True
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,86 @@
|
||||
#!/bin/bash
|
||||
# MuseGuard 后端服务状态检查脚本
|
||||
|
||||
echo "========================================"
|
||||
echo " MuseGuard 后端服务状态"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
|
||||
# 获取脚本所在目录
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# 检查Flask应用
|
||||
echo "📌 Flask 应用:"
|
||||
if [ -f logs/flask.pid ]; then
|
||||
FLASK_PID=$(cat logs/flask.pid)
|
||||
if ps -p $FLASK_PID > /dev/null 2>&1; then
|
||||
echo " ✅ 运行中 (PID: $FLASK_PID)"
|
||||
echo " 📍 URL: http://127.0.0.1:6006"
|
||||
echo " 📍 测试: http://127.0.0.1:6006/static/test.html"
|
||||
else
|
||||
echo " ❌ 未运行 (PID文件存在但进程不存在)"
|
||||
fi
|
||||
else
|
||||
if pgrep -f "python run.py" > /dev/null 2>&1; then
|
||||
FLASK_PID=$(pgrep -f "python run.py")
|
||||
echo " ⚠️ 运行中但无PID文件 (PID: $FLASK_PID)"
|
||||
else
|
||||
echo " ❌ 未运行"
|
||||
fi
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# 检查Worker
|
||||
echo "📌 RQ Worker:"
|
||||
if [ -f logs/worker.pid ]; then
|
||||
WORKER_PID=$(cat logs/worker.pid)
|
||||
if ps -p $WORKER_PID > /dev/null 2>&1; then
|
||||
echo " ✅ 运行中 (PID: $WORKER_PID)"
|
||||
else
|
||||
echo " ❌ 未运行 (PID文件存在但进程不存在)"
|
||||
fi
|
||||
else
|
||||
if pgrep -f "python worker.py" > /dev/null 2>&1; then
|
||||
WORKER_PID=$(pgrep -f "python worker.py")
|
||||
echo " ⚠️ 运行中但无PID文件 (PID: $WORKER_PID)"
|
||||
else
|
||||
echo " ❌ 未运行"
|
||||
fi
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# 检查Redis
|
||||
echo "📌 Redis:"
|
||||
if redis-cli ping > /dev/null 2>&1; then
|
||||
echo " ✅ 运行中"
|
||||
else
|
||||
echo " ❌ 未运行"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# 检查日志文件
|
||||
echo "📌 日志文件:"
|
||||
if [ -f logs/flask.log ]; then
|
||||
FLASK_LOG_SIZE=$(du -h logs/flask.log | cut -f1)
|
||||
echo " Flask: logs/flask.log ($FLASK_LOG_SIZE)"
|
||||
else
|
||||
echo " Flask: 无日志文件"
|
||||
fi
|
||||
|
||||
if [ -f logs/worker.log ]; then
|
||||
WORKER_LOG_SIZE=$(du -h logs/worker.log | cut -f1)
|
||||
echo " Worker: logs/worker.log ($WORKER_LOG_SIZE)"
|
||||
else
|
||||
echo " Worker: 无日志文件"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "========================================"
|
||||
echo " 快速操作"
|
||||
echo "========================================"
|
||||
echo "启动服务: ./start.sh"
|
||||
echo "停止服务: ./stop.sh"
|
||||
echo "查看Flask日志: tail -f logs/flask.log"
|
||||
echo "查看Worker日志: tail -f logs/worker.log"
|
||||
echo ""
|
||||
@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
# MuseGuard 后端服务停止脚本
|
||||
|
||||
echo "========================================"
|
||||
echo " 停止 MuseGuard 后端服务"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
|
||||
# 获取脚本所在目录
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# 从PID文件停止进程
|
||||
if [ -f logs/flask.pid ]; then
|
||||
FLASK_PID=$(cat logs/flask.pid)
|
||||
if ps -p $FLASK_PID > /dev/null 2>&1; then
|
||||
echo "停止 Flask 应用 (PID: $FLASK_PID)..."
|
||||
kill $FLASK_PID
|
||||
echo "[成功] Flask 应用已停止"
|
||||
else
|
||||
echo "[提示] Flask 应用未运行"
|
||||
fi
|
||||
rm logs/flask.pid
|
||||
else
|
||||
echo "[提示] 未找到 Flask PID 文件"
|
||||
fi
|
||||
|
||||
if [ -f logs/worker.pid ]; then
|
||||
WORKER_PID=$(cat logs/worker.pid)
|
||||
if ps -p $WORKER_PID > /dev/null 2>&1; then
|
||||
echo "停止 RQ Worker (PID: $WORKER_PID)..."
|
||||
kill $WORKER_PID
|
||||
echo "[成功] RQ Worker 已停止"
|
||||
else
|
||||
echo "[提示] RQ Worker 未运行"
|
||||
fi
|
||||
rm logs/worker.pid
|
||||
else
|
||||
echo "[提示] 未找到 Worker PID 文件"
|
||||
fi
|
||||
|
||||
# 确保所有相关进程都停止
|
||||
echo ""
|
||||
echo "清理所有相关进程..."
|
||||
pkill -f "python run.py" 2>/dev/null && echo "清理了额外的 Flask 进程"
|
||||
pkill -f "python worker.py" 2>/dev/null && echo "清理了额外的 Worker 进程"
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " 服务已停止"
|
||||
echo "========================================"
|
||||
@ -0,0 +1,42 @@
|
||||
"""
|
||||
RQ Worker 启动脚本
|
||||
用于启动后台任务处理器
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目路径到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from redis import Redis
|
||||
from rq import Worker, Queue
|
||||
from config.algorithm_config import AlgorithmConfig
|
||||
from app import create_app
|
||||
|
||||
# 创建Flask应用上下文
|
||||
app = create_app()
|
||||
|
||||
def main():
|
||||
"""启动worker"""
|
||||
with app.app_context():
|
||||
# 连接Redis
|
||||
redis_conn = Redis.from_url(AlgorithmConfig.REDIS_URL)
|
||||
|
||||
# 创建队列
|
||||
queue = Queue(AlgorithmConfig.RQ_QUEUE_NAME, connection=redis_conn)
|
||||
|
||||
# 创建worker
|
||||
worker = Worker([queue], connection=redis_conn)
|
||||
|
||||
print(f"🚀 RQ Worker启动成功!")
|
||||
print(f"📡 Redis: {AlgorithmConfig.REDIS_URL}")
|
||||
print(f"📋 Queue: {AlgorithmConfig.RQ_QUEUE_NAME}")
|
||||
print(f"🔄 使用{'真实' if AlgorithmConfig.USE_REAL_ALGORITHMS else '虚拟'}算法")
|
||||
print(f"⏳ 等待任务...")
|
||||
|
||||
# 启动worker
|
||||
worker.work()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Loading…
Reference in new issue