From fe90dc173e3d8eed9849c594d43e3825f07607dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Tue, 6 Jan 2026 19:16:31 +0800 Subject: [PATCH 1/5] =?UTF-8?q?improve:=20=E4=BC=98=E5=8C=96=E7=AE=97?= =?UTF-8?q?=E6=B3=95=E8=B6=85=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/scripts/attack_anti_face_edit.sh | 2 - src/backend/app/scripts/attack_aspl.sh | 46 +++++++++---------- src/backend/app/scripts/attack_caat.sh | 26 ++++++----- .../app/scripts/attack_caat_with_prior.sh | 8 ++-- src/backend/app/scripts/attack_simac.sh | 14 +++--- src/backend/config/algorithm_config.py | 12 +++-- 6 files changed, 57 insertions(+), 51 deletions(-) diff --git a/src/backend/app/scripts/attack_anti_face_edit.sh b/src/backend/app/scripts/attack_anti_face_edit.sh index c66ced4..857d433 100644 --- a/src/backend/app/scripts/attack_anti_face_edit.sh +++ b/src/backend/app/scripts/attack_anti_face_edit.sh @@ -50,8 +50,6 @@ CUDA_VISIBLE_DEVICES=0 python ../algorithms/pid.py \ --center_crop \ --eps 10 \ --step_size 0.002 \ - --save_every 200 \ --attack_type add-log \ --seed 0 \ --dataloader_num_workers 2 - diff --git a/src/backend/app/scripts/attack_aspl.sh b/src/backend/app/scripts/attack_aspl.sh index a4f9e53..f1dc6b9 100644 --- a/src/backend/app/scripts/attack_aspl.sh +++ b/src/backend/app/scripts/attack_aspl.sh @@ -24,29 +24,29 @@ echo "Clearing output directory: $OUTPUT_DIR" find "$OUTPUT_DIR" -mindepth 1 -delete -accelerate launch ../algorithms/aspl.py \ -  --pretrained_model_name_or_path=$MODEL_PATH  \ -  --enable_xformers_memory_efficient_attention \ -  --instance_data_dir_for_train=$CLEAN_TRAIN_DIR \ -  --instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \ -  --instance_prompt="a photo of sks person" \ -  --class_data_dir=$CLASS_DIR \ -  --num_class_images=200 \ -  --class_prompt="a photo of person" \ -  --output_dir=$OUTPUT_DIR \ -  --center_crop \ -  --with_prior_preservation \ -  --prior_loss_weight=1.0 \ -  --resolution=384 \ -  --train_batch_size=1 \ -  --max_train_steps=50 \ -  --max_f_train_steps=3 \ -  --max_adv_train_steps=6 \ -  --checkpointing_iterations=10 \ -  --learning_rate=5e-7 \ -  --pgd_alpha=0.005 \ -  --pgd_eps=8 \ -  --seed=0 +accelerate launch --num_processes 1 --num_machines 1 ../algorithms/aspl.py \ + --pretrained_model_name_or_path="$MODEL_PATH" \ + --enable_xformers_memory_efficient_attention \ + --instance_data_dir_for_train="$CLEAN_TRAIN_DIR" \ + --instance_data_dir_for_adversarial="$CLEAN_ADV_DIR" \ + --instance_prompt="a photo of sks person" \ + --class_data_dir="$CLASS_DIR" \ + --num_class_images=200 \ + --class_prompt="a photo of person" \ + --output_dir="$OUTPUT_DIR" \ + --center_crop \ + --with_prior_preservation \ + --prior_loss_weight=1.0 \ + --resolution=384 \ + --train_batch_size=1 \ + --max_train_steps=50 \ + --max_f_train_steps=3 \ + --max_adv_train_steps=6 \ + --checkpointing_iterations=10 \ + --learning_rate=5e-7 \ + --pgd_alpha=0.005 \ + --pgd_eps=8 \ + --seed=0 # ------------------------- 训练后清空 CLASS_DIR ------------------------- # 注意:这会在 accelerate launch 成功结束后执行 diff --git a/src/backend/app/scripts/attack_caat.sh b/src/backend/app/scripts/attack_caat.sh index 00a9f8c..fe394e6 100644 --- a/src/backend/app/scripts/attack_caat.sh +++ b/src/backend/app/scripts/attack_caat.sh @@ -21,17 +21,19 @@ echo "Clearing output directory: $OUTPUT_DIR" # 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) 
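+# Delete every file and subdirectory inside $OUTPUT_DIR while keeping the directory itself (-mindepth 1 excludes the top level)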
find "$OUTPUT_DIR" -mindepth 1 -delete - +#--debug_oom_step0_only \ accelerate launch ../algorithms/caat.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --instance_data_dir=$INSTANCE_DIR \ - --output_dir=$OUTPUT_DIR \ - --instance_prompt="a photo of a person" \ - --resolution=512 \ - --learning_rate=1e-5 \ - --lr_warmup_steps=0 \ - --max_train_steps=250 \ + --pretrained_model_name_or_path="$MODEL_NAME" \ + --instance_data_dir="$INSTANCE_DIR" \ + --output_dir="$OUTPUT_DIR" \ + --instance_prompt="a photo of person" \ + --resolution 512 \ + --learning_rate 1e-5 \ + --lr_warmup_steps 0 \ + --max_train_steps 250 \ --hflip \ - --mixed_precision bf16 \ - --alpha=5e-3 \ - --eps=0.05 \ No newline at end of file + --mixed_precision bf16 \ + --alpha 5e-3 \ + --eps 0.05 \ + --debug_oom \ + --debug_oom_sync \ No newline at end of file diff --git a/src/backend/app/scripts/attack_caat_with_prior.sh b/src/backend/app/scripts/attack_caat_with_prior.sh index a7e149e..ecb7c92 100644 --- a/src/backend/app/scripts/attack_caat_with_prior.sh +++ b/src/backend/app/scripts/attack_caat_with_prior.sh @@ -22,13 +22,13 @@ echo "Clearing output directory: $OUTPUT_DIR" # 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) find "$OUTPUT_DIR" -mindepth 1 -delete - +#--debug_oom_step0_only \ accelerate launch ../algorithms/caat.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --instance_data_dir=$INSTANCE_DIR \ --output_dir=$OUTPUT_DIR \ --with_prior_preservation \ - --instance_prompt="a photo of a person" \ + --instance_prompt="a photo of person" \ --num_class_images=200 \ --class_data_dir=$CLASS_DIR \ --class_prompt='person' \ @@ -39,7 +39,9 @@ accelerate launch ../algorithms/caat.py \ --hflip \ --mixed_precision bf16 \ --alpha=5e-3 \ - --eps=0.05 + --eps=0.05 \ + --debug_oom \ + --debug_oom_sync # ------------------------- 【步骤 2】训练后清空 CLASS_DIR ------------------------- diff --git a/src/backend/app/scripts/attack_simac.sh b/src/backend/app/scripts/attack_simac.sh index 660d6a1..a6b9f20 100644 --- a/src/backend/app/scripts/attack_simac.sh +++ b/src/backend/app/scripts/attack_simac.sh @@ -25,20 +25,20 @@ echo "Clearing output directory: $OUTPUT_DIR" mkdir -p "$OUTPUT_DIR" # 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) 
find "$OUTPUT_DIR" -mindepth 1 -delete -find "$CLASS_DIR" -mindepth 1 -delete +# find "$CLASS_DIR" -mindepth 1 -delete -accelerate launch ../algorithms/simac.py \ - --pretrained_model_name_or_path=$MODEL_PATH \ +accelerate launch --num_processes 1 --num_machines 1 ../algorithms/simac.py \ + --pretrained_model_name_or_path="$MODEL_PATH" \ --enable_xformers_memory_efficient_attention \ - --instance_data_dir_for_train=$CLEAN_TRAIN_DIR \ - --instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \ + --instance_data_dir_for_train="$CLEAN_TRAIN_DIR" \ + --instance_data_dir_for_adversarial="$CLEAN_ADV_DIR" \ --instance_prompt="a photo of person" \ - --class_data_dir=$CLASS_DIR \ + --class_data_dir="$CLASS_DIR" \ --num_class_images=100 \ --class_prompt="a photo of person" \ - --output_dir=$OUTPUT_DIR \ + --output_dir="$OUTPUT_DIR" \ --center_crop \ --with_prior_preservation \ --prior_loss_weight=1.0 \ diff --git a/src/backend/config/algorithm_config.py b/src/backend/config/algorithm_config.py index 53662d5..49c7631 100644 --- a/src/backend/config/algorithm_config.py +++ b/src/backend/config/algorithm_config.py @@ -145,7 +145,10 @@ class AlgorithmConfig: 'max_train_steps': 250, 'hflip': True, 'mixed_precision': 'bf16', - 'alpha': 5e-3 + 'alpha': 5e-3, + 'eps': 0.05, + 'debug_oom': True, + 'debug_oom_sync': True } }, 'caat_pro': { @@ -156,7 +159,7 @@ class AlgorithmConfig: 'pretrained_model_name_or_path': MODELS_DIR['model2'], 'with_prior_preservation': True, 'instance_prompt': 'a selfie photo of person', - 'class_prompt': 'a selfie photo of person', + 'class_prompt': 'person', 'num_class_images': 200, 'resolution': 512, 'learning_rate': 1e-5, @@ -165,7 +168,9 @@ class AlgorithmConfig: 'hflip': True, 'mixed_precision': 'bf16', 'alpha': 5e-3, - 'eps': 0.05 + 'eps': 0.05, + 'debug_oom': True, + 'debug_oom_sync': True } }, 'pid': { @@ -233,7 +238,6 @@ class AlgorithmConfig: 'max_train_steps': 2000, 'center_crop': True, 'step_size': 0.002, - 'save_every': 200, 'attack_type': 'add-log', 'seed': 0, 'dataloader_num_workers': 2 -- 2.34.1 From b5af0d22ab26a382ca80c9f9facf24d93f39c7ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Tue, 6 Jan 2026 19:16:56 +0800 Subject: [PATCH 2/5] =?UTF-8?q?improve:=20=E4=BC=98=E5=8C=96=E7=AE=97?= =?UTF-8?q?=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/algorithms/perturbation/aspl.py | 648 +++++------ .../app/algorithms/perturbation/caat.py | 746 ++++++------ .../app/algorithms/perturbation/simac.py | 1001 +++++++---------- 3 files changed, 1058 insertions(+), 1337 deletions(-) diff --git a/src/backend/app/algorithms/perturbation/aspl.py b/src/backend/app/algorithms/perturbation/aspl.py index 6f26194..c96ae16 100644 --- a/src/backend/app/algorithms/perturbation/aspl.py +++ b/src/backend/app/algorithms/perturbation/aspl.py @@ -1,9 +1,11 @@ import argparse import copy +import gc import hashlib import itertools import logging import os +import random from pathlib import Path import datasets @@ -24,12 +26,77 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig - logger = get_logger(__name__) +# ----------------------------- +# Lightweight debug helpers (low overhead) +# ----------------------------- +def _cuda_gc() -> None: + """Best-effort CUDA memory cleanup (does not change algorithmic behavior).""" + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def _fmt_bytes(n: int) -> str: 
+ return f"{n / (1024**2):.1f}MB" + + +def log_cuda(prefix: str, accelerator: Accelerator | None = None, sync: bool = False, extra: dict | None = None): + """Log CUDA memory stats without copying tensors to CPU.""" + if not torch.cuda.is_available(): + logger.info(f"[mem] {prefix} cuda_not_available") + return + if sync: + torch.cuda.synchronize() + alloc = torch.cuda.memory_allocated() + reserv = torch.cuda.memory_reserved() + max_alloc = torch.cuda.max_memory_allocated() + max_reserv = torch.cuda.max_memory_reserved() + dev = str(accelerator.device) if accelerator is not None else "cuda" + msg = ( + f"[mem] {prefix} dev={dev} alloc={_fmt_bytes(alloc)} reserv={_fmt_bytes(reserv)} " + f"max_alloc={_fmt_bytes(max_alloc)} max_reserv={_fmt_bytes(max_reserv)}" + ) + if extra: + msg += " " + " ".join([f"{k}={v}" for k, v in extra.items()]) + logger.info(msg) + + +def log_path_stats(prefix: str, p: Path): + """Log directory/file existence and file count (best-effort).""" + try: + exists = p.exists() + is_dir = p.is_dir() if exists else False + n_files = 0 + if exists and is_dir: + n_files = sum(1 for x in p.iterdir() if x.is_file()) + logger.info(f"[path] {prefix} path={str(p)} exists={exists} is_dir={is_dir} files={n_files}") + except Exception as e: + logger.info(f"[path] {prefix} path={str(p)} stat_error={repr(e)}") + + +def log_args(args): + for k in sorted(vars(args).keys()): + logger.info(f"[args] {k}={getattr(args, k)}") + + +def log_tensor_meta(prefix: str, t: torch.Tensor | None): + if t is None: + logger.info(f"[tensor] {prefix} None") + return + logger.info( + f"[tensor] {prefix} shape={tuple(t.shape)} dtype={t.dtype} device={t.device} " + f"requires_grad={t.requires_grad} is_leaf={t.is_leaf}" + ) + + +# ----------------------------- +# Dataset +# ----------------------------- class DreamBoothDatasetFromTensor(Dataset): - """Just like DreamBoothDataset, but take instance_images_tensor instead of path""" + """基于内存张量的 DreamBooth 数据集:直接使用张量输入,返回图像与对应 prompt token。""" def __init__( self, @@ -53,10 +120,19 @@ class DreamBoothDatasetFromTensor(Dataset): if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images_path = list(self.class_data_root.iterdir()) + # Only keep files to avoid directories affecting length. + self.class_images_path = [p for p in self.class_data_root.iterdir() if p.is_file()] self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt + + # Early, explicit failure instead of ZeroDivisionError later. + if self.num_class_images == 0: + raise ValueError( + f"class_data_dir is empty: {self.class_data_root}. " + f"Prior preservation requires class images. " + f"Please generate class images first, or fix class_data_dir, or disable --with_prior_preservation." 
+ ) else: self.class_data_root = None @@ -85,8 +161,10 @@ class DreamBoothDatasetFromTensor(Dataset): ).input_ids if self.class_data_root: + if self.num_class_images == 0: + raise ValueError(f"class_data_dir became empty at runtime: {self.class_data_root}") class_image = Image.open(self.class_images_path[index % self.num_class_images]) - if not class_image.mode == "RGB": + if class_image.mode != "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) example["class_prompt_ids"] = self.tokenizer( @@ -100,6 +178,9 @@ class DreamBoothDatasetFromTensor(Dataset): return example +# ----------------------------- +# Model helper +# ----------------------------- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -120,217 +201,47 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st raise ValueError(f"{model_class} is not supported.") +# ----------------------------- +# Args +# ----------------------------- def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help=( - "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" - " float32 precision." - ), - ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) - parser.add_argument( - "--instance_data_dir_for_train", - type=str, - default=None, - required=True, - help="A folder containing the training data of instance images.", - ) - parser.add_argument( - "--instance_data_dir_for_adversarial", - type=str, - default=None, - required=True, - help="A folder containing the images to add adversarial noise", - ) - parser.add_argument( - "--class_data_dir", - type=str, - default=None, - required=False, - help="A folder containing the training data of class images.", - ) - parser.add_argument( - "--instance_prompt", - type=str, - default=None, - required=True, - help="The prompt with identifier specifying the instance", - ) - parser.add_argument( - "--class_prompt", - type=str, - default=None, - help="The prompt to specify images in the same class as provided instance images.", - ) - parser.add_argument( - "--with_prior_preservation", - default=False, - action="store_true", - help="Flag to add prior preservation loss.", - ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss.", - ) - parser.add_argument( - "--num_class_images", - type=int, - default=100, - help=( - "Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt." 
- ), - ) - parser.add_argument( - "--output_dir", - type=str, - default="text-inversion-model", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--resolution", - type=int, - default=512, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), - ) - parser.add_argument( - "--center_crop", - default=False, - action="store_true", - help=( - "Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping." - ), - ) - parser.add_argument( - "--train_text_encoder", - action="store_true", - help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", - ) - parser.add_argument( - "--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--sample_batch_size", - type=int, - default=8, - help="Batch size (per device) for sampling images.", - ) - parser.add_argument( - "--max_train_steps", - type=int, - default=20, - help="Total number of training steps to perform.", - ) - parser.add_argument( - "--max_f_train_steps", - type=int, - default=10, - help="Total number of sub-steps to train surogate model.", - ) - parser.add_argument( - "--max_adv_train_steps", - type=int, - default=10, - help="Total number of sub-steps to train adversarial noise.", - ) - parser.add_argument( - "--checkpointing_iterations", - type=int, - default=5, - help=("Save a checkpoint of the training state every X iterations."), - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-6, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default="fp16", - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
-        ),
-    )
-    parser.add_argument(
-        "--enable_xformers_memory_efficient_attention",
-        action="store_true",
-        help="Whether or not to use xformers.",
-    )
-    parser.add_argument(
-        "--pgd_alpha",
-        type=float,
-        default=1.0 / 255,
-        help="The step size for pgd.",
-    )
-    parser.add_argument(
-        "--pgd_eps",
-        type=int,
-        default=0.05,
-        help="The noise budget for pgd.",
-    )
-    parser.add_argument(
-        "--target_image_path",
-        default=None,
-        help="target image for attacking",
-    )
+    parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True)
+    parser.add_argument("--revision", type=str, default=None, required=False)
+    parser.add_argument("--tokenizer_name", type=str, default=None)
+    parser.add_argument("--instance_data_dir_for_train", type=str, default=None, required=True)
+    parser.add_argument("--instance_data_dir_for_adversarial", type=str, default=None, required=True)
+    parser.add_argument("--class_data_dir", type=str, default=None, required=False)
+    parser.add_argument("--instance_prompt", type=str, default=None, required=True)
+    parser.add_argument("--class_prompt", type=str, default=None)
+    parser.add_argument("--with_prior_preservation", default=False, action="store_true")
+    parser.add_argument("--prior_loss_weight", type=float, default=1.0)
+    parser.add_argument("--num_class_images", type=int, default=100)
+    parser.add_argument("--output_dir", type=str, default="text-inversion-model")
+    parser.add_argument("--seed", type=int, default=None)
+    parser.add_argument("--resolution", type=int, default=512)
+    parser.add_argument("--center_crop", default=False, action="store_true")
+    parser.add_argument("--train_text_encoder", action="store_true")
+    parser.add_argument("--train_batch_size", type=int, default=4)
+    parser.add_argument("--sample_batch_size", type=int, default=8)
+    parser.add_argument("--max_train_steps", type=int, default=20)
+    parser.add_argument("--max_f_train_steps", type=int, default=10)
+    parser.add_argument("--max_adv_train_steps", type=int, default=10)
+    parser.add_argument("--checkpointing_iterations", type=int, default=5)
+    parser.add_argument("--learning_rate", type=float, default=5e-6)
+    parser.add_argument("--logging_dir", type=str, default="logs")
+    parser.add_argument("--allow_tf32", action="store_true")
+    parser.add_argument("--report_to", type=str, default="tensorboard")
+    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
+    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")
+    parser.add_argument("--pgd_alpha", type=float, default=1.0 / 255)
+    parser.add_argument("--pgd_eps", type=float, default=0.05)
+    parser.add_argument("--target_image_path", default=None)
+
+    # Debug / diagnostics (low-overhead)
+    parser.add_argument("--debug", action="store_true", help="Enable detailed logs for failure points.")
+    parser.add_argument("--debug_cuda_sync", action="store_true", help="Synchronize CUDA for more accurate mem logs.")
+    parser.add_argument("--debug_step0_only", action="store_true", help="Only print per-step logs for step 0.")

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -340,8 +251,11 @@ def parse_args(input_args=None):
     return args


+# -----------------------------
+# Class image prompt dataset
+# -----------------------------
 class PromptDataset(Dataset):
-    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+ """用于批量生成 class 图像的提示词数据集,可在多 GPU 环境下并行采样。""" def __init__(self, prompt, num_samples): self.prompt = prompt @@ -357,6 +271,9 @@ class PromptDataset(Dataset): return example +# ----------------------------- +# IO +# ----------------------------- def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: image_transforms = transforms.Compose( [ @@ -372,17 +289,10 @@ def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: return images -def train_one_epoch( - args, - models, - tokenizer, - noise_scheduler, - vae, - data_tensor: torch.Tensor, - num_steps=20, -): - # Load the tokenizer - +# ----------------------------- +# Core routines +# ----------------------------- +def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: torch.Tensor, num_steps=20): unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1]) params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) @@ -394,17 +304,17 @@ def train_one_epoch( eps=1e-08, ) + # IMPORTANT: only pass class_data_dir when with_prior_preservation is enabled. train_dataset = DreamBoothDatasetFromTensor( data_tensor, args.instance_prompt, tokenizer, - args.class_data_dir, + args.class_data_dir if args.with_prior_preservation else None, args.class_prompt, args.resolution, args.center_crop, ) - # weight_dtype = torch.bfloat16 weight_dtype = torch.bfloat16 device = torch.device("cuda") @@ -416,33 +326,35 @@ def train_one_epoch( unet.train() text_encoder.train() - step_data = train_dataset[step % len(train_dataset)] - pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to( - device, dtype=weight_dtype - ) - input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device) + try: + step_data = train_dataset[step % len(train_dataset)] + except Exception as e: + logger.error(f"[err] train_one_epoch dataset getitem failed at step={step}: {repr(e)}") + raise + + try: + pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to( + device, dtype=weight_dtype + ) + input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device) + except KeyError as e: + logger.error( + f"[err] missing key in step_data at step={step}: missing={str(e)}. 
" + f"with_prior_preservation={args.with_prior_preservation}" + ) + raise latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor - # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning encoder_hidden_states = text_encoder(input_ids)[0] - - # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -450,47 +362,37 @@ def train_one_epoch( else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - # with prior preservation loss if args.with_prior_preservation: model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) - # Compute instance loss instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. loss = instance_loss + args.prior_loss_weight * prior_loss - else: + prior_loss = torch.tensor(0.0, device=device) + instance_loss = torch.tensor(0.0, device=device) loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss.backward() torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True) optimizer.step() optimizer.zero_grad() - print( - f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, instance_loss: {instance_loss.detach().item()}" + + logger.info( + f"[train_one_epoch] step={step} loss={loss.detach().item():.6f} " + f"prior={prior_loss.detach().item():.6f} inst={instance_loss.detach().item():.6f}" ) - return [unet, text_encoder] + del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states + del model_pred, target, loss, prior_loss, instance_loss + del optimizer, train_dataset, params_to_optimize + _cuda_gc() + return [unet, text_encoder] -def pgd_attack( - args, - models, - tokenizer, - noise_scheduler, - vae, - data_tensor: torch.Tensor, - original_images: torch.Tensor, - target_tensor: torch.Tensor, - num_steps: int, -): - """Return new perturbed data""" +def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, original_images, target_tensor, num_steps: int): unet, text_encoder = models weight_dtype = torch.bfloat16 device = torch.device("cuda") @@ -515,24 +417,14 @@ def pgd_attack( latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor - # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - #noise_scheduler.config.num_train_timesteps - timesteps = 
torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) - timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning encoder_hidden_states = text_encoder(input_ids.to(device))[0] - - # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -540,11 +432,10 @@ def pgd_attack( else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - unet.zero_grad() - text_encoder.zero_grad() + unet.zero_grad(set_to_none=True) + text_encoder.zero_grad(set_to_none=True) loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - # target-shift loss if target_tensor is not None: xtm1_pred = torch.cat( [ @@ -561,16 +452,25 @@ def pgd_attack( loss.backward() - alpha = args.pgd_alpha + alpha = args.pgd_alpha eps = args.pgd_eps / 255 adv_images = perturbed_images + alpha * perturbed_images.grad.sign() eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps) perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_() - print(f"PGD loss - step {step}, loss: {loss.detach().item()}") + + logger.info(f"[pgd] step={step} loss={loss.detach().item():.6f} alpha={alpha} eps={eps}") + + del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss + del adv_images, eta + + _cuda_gc() return perturbed_images +# ----------------------------- +# Main +# ----------------------------- def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -586,6 +486,7 @@ def main(args): level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_warning() @@ -595,15 +496,35 @@ def main(args): transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() + if accelerator.is_local_main_process: + logger.info(f"[run] using_file={__file__}") + log_args(args) + if args.seed is not None: set_seed(args.seed) - # Generate class images if prior preservation is enabled. 
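+    # Take a baseline [mem] reading before anything is loaded onto the GPU, so later readings have a reference point.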
+ if args.debug and accelerator.is_local_main_process: + log_cuda("startup", accelerator, sync=args.debug_cuda_sync) + + # ------------------------- + # Prior preservation: generate class images if needed + # ------------------------- if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("--with_prior_preservation requires --class_data_dir") + if args.class_prompt is None: + raise ValueError("--with_prior_preservation requires --class_prompt") + class_images_dir = Path(args.class_data_dir) - if not class_images_dir.exists(): - class_images_dir.mkdir(parents=True) - cur_class_images = len(list(class_images_dir.iterdir())) + class_images_dir.mkdir(parents=True, exist_ok=True) + + if accelerator.is_local_main_process: + log_path_stats("class_dir_before", class_images_dir) + + cur_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file()) + if accelerator.is_local_main_process: + logger.info(f"[class_gen] cur_class_images={cur_class_images} target={args.num_class_images}") + if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 if args.mixed_precision == "fp32": @@ -612,6 +533,12 @@ def main(args): torch_dtype = torch.float16 elif args.mixed_precision == "bf16": torch_dtype = torch.bfloat16 + + if accelerator.is_local_main_process: + logger.info(f"[class_gen] will_generate={args.num_class_images - cur_class_images} torch_dtype={torch_dtype}") + if args.debug: + log_cuda("before_pipeline_load", accelerator, sync=args.debug_cuda_sync) + pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -621,56 +548,67 @@ def main(args): pipeline.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images - logger.info(f"Number of class images to sample: {num_new_images}.") - sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) + if accelerator.is_local_main_process and args.debug: + log_cuda("after_pipeline_to_device", accelerator, sync=args.debug_cuda_sync) + for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images + if accelerator.is_local_main_process and args.debug: + logger.info(f"[class_gen] batch_prompts={len(example['prompt'])} generated_images={len(images)}") for i, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) - del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + del pipeline, sample_dataset, sample_dataloader + _cuda_gc() + + accelerator.wait_for_everyone() + + final_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file()) + if accelerator.is_local_main_process: + logger.info(f"[class_gen] done final_class_images={final_class_images}") + log_path_stats("class_dir_after", class_images_dir) + if final_class_images == 0: + raise RuntimeError(f"class image generation failed: {class_images_dir} is still empty.") + + else: + accelerator.wait_for_everyone() + if accelerator.is_local_main_process: + logger.info("[class_gen] skipped (already enough images)") + else: + if 
accelerator.is_local_main_process: + logger.info("[class_gen] disabled (with_prior_preservation is False)") - # import correct text encoder class + # ------------------------- + # Load models / tokenizer / scheduler / VAE + # ------------------------- text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - # Load scheduler and models + if accelerator.is_local_main_process and args.debug: + log_cuda("before_load_models", accelerator, sync=args.debug_cuda_sync) + text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=args.revision, - use_fast=False, + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False ) - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision - ).cuda() - + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).cuda() vae.requires_grad_(False) if not args.train_text_encoder: @@ -679,52 +617,60 @@ def main(args): if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True - clean_data = load_data( - args.instance_data_dir_for_train, - size=args.resolution, - center_crop=args.center_crop, - ) - perturbed_data = load_data( - args.instance_data_dir_for_adversarial, - size=args.resolution, - center_crop=args.center_crop, - ) - original_data = perturbed_data.clone() - original_data.requires_grad_(False) - if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() + if accelerator.is_local_main_process: + logger.info("[xformers] enabled") else: raise ValueError("xformers is not available. 
Make sure it is installed correctly") + if accelerator.is_local_main_process and args.debug: + log_cuda("after_load_models", accelerator, sync=args.debug_cuda_sync) + + # ------------------------- + # Load data tensors + # ------------------------- + train_dir = Path(args.instance_data_dir_for_train) + adv_dir = Path(args.instance_data_dir_for_adversarial) + if accelerator.is_local_main_process and args.debug: + log_path_stats("train_dir", train_dir) + log_path_stats("adv_dir", adv_dir) + + clean_data = load_data(train_dir, size=args.resolution, center_crop=args.center_crop) + perturbed_data = load_data(adv_dir, size=args.resolution, center_crop=args.center_crop) + original_data = perturbed_data.clone() + original_data.requires_grad_(False) + + if accelerator.is_local_main_process and args.debug: + log_tensor_meta("clean_data_cpu", clean_data) + log_tensor_meta("perturbed_data_cpu", perturbed_data) + target_latent_tensor = None if args.target_image_path is not None: target_image_path = Path(args.target_image_path) - assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist" + if not target_image_path.is_file(): + raise ValueError(f"Target image path does not exist: {target_image_path}") target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution)) target_image = np.array(target_image)[None].transpose(0, 3, 1, 2) target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0 - target_latent_tensor = ( - vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor - ) + target_latent_tensor = vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) + target_latent_tensor = target_latent_tensor * vae.config.scaling_factor target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda() + if accelerator.is_local_main_process and args.debug: + log_tensor_meta("target_latent_tensor", target_latent_tensor) + f = [unet, text_encoder] for i in range(args.max_train_steps): - # 1. 
f' = f.clone() + if accelerator.is_local_main_process: + logger.info(f"[outer] i={i}/{args.max_train_steps}") + f_sur = copy.deepcopy(f) - f_sur = train_one_epoch( - args, - f_sur, - tokenizer, - noise_scheduler, - vae, - clean_data, - args.max_f_train_steps, - ) + f_sur = train_one_epoch(args, f_sur, tokenizer, noise_scheduler, vae, clean_data, args.max_f_train_steps) + perturbed_data = pgd_attack( args, f_sur, @@ -736,33 +682,31 @@ def main(args): target_latent_tensor, args.max_adv_train_steps, ) - f = train_one_epoch( - args, - f, - tokenizer, - noise_scheduler, - vae, - perturbed_data, - args.max_f_train_steps, - ) + + f = train_one_epoch(args, f, tokenizer, noise_scheduler, vae, perturbed_data, args.max_f_train_steps) if (i + 1) % args.checkpointing_iterations == 0: save_folder = args.output_dir os.makedirs(save_folder, exist_ok=True) noised_imgs = perturbed_data.detach() - - img_filenames = [ - Path(instance_path).stem - for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir()) - ] + + img_filenames = [Path(instance_path).stem for instance_path in list(adv_dir.iterdir()) if instance_path.is_file()] for img_pixel, img_name in zip(noised_imgs, img_filenames): - save_path = os.path.join(save_folder, f"perturbed_{img_name}.png") + save_path = os.path.join(save_folder, f"perturbed_{img_name}.png") Image.fromarray( - (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy() + (img_pixel * 127.5 + 128) + .clamp(0, 255) + .to(torch.uint8) + .permute(1, 2, 0) + .cpu() + .numpy() ).save(save_path) - - print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)") + + if accelerator.is_local_main_process: + logger.info(f"[save] step={i+1} saved={len(img_filenames)} to {save_folder}") + + _cuda_gc() if __name__ == "__main__": diff --git a/src/backend/app/algorithms/perturbation/caat.py b/src/backend/app/algorithms/perturbation/caat.py index c7e41cd..b5c331a 100644 --- a/src/backend/app/algorithms/perturbation/caat.py +++ b/src/backend/app/algorithms/perturbation/caat.py @@ -64,8 +64,9 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st else: raise ValueError(f"{model_class} is not supported.") + class PromptDataset(Dataset): - "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + """用于批量生成 class 图像的 prompt 数据集。""" def __init__(self, prompt, num_samples): self.prompt = prompt @@ -82,10 +83,7 @@ class PromptDataset(Dataset): class CustomDiffusionDataset(Dataset): - """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. 
- """ + """CAAT/Custom Diffusion 训练数据集。""" def __init__( self, @@ -133,6 +131,7 @@ class CustomDiffusionDataset(Dataset): self.num_instance_images = len(self.instance_images_path) self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) + self.flip = transforms.RandomHorizontalFlip(0.5 * hflip) self.image_transforms = transforms.Compose( @@ -165,19 +164,20 @@ class CustomDiffusionDataset(Dataset): else: instance_image[top : top + inner, left : left + inner, :] = image mask[ - top // factor + 1 : (top + scale) // factor - 1, left // factor + 1 : (left + scale) // factor - 1 + top // factor + 1 : (top + scale) // factor - 1, + left // factor + 1 : (left + scale) // factor - 1, ] = 1.0 return instance_image, mask def __getitem__(self, index): example = {} + instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images] instance_image = Image.open(instance_image) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") instance_image = self.flip(instance_image) - # apply resize augmentation and create a valid image region mask random_scale = self.size if self.aug: random_scale = ( @@ -220,270 +220,59 @@ class CustomDiffusionDataset(Dataset): return example - def parse_args(input_args=None): + """解析 CAAT 训练参数。""" parser = argparse.ArgumentParser(description="CAAT training script.") - parser.add_argument( - "--alpha", - type=float, - default=5e-3, - required=True, - help="PGD alpha.", - ) - parser.add_argument( - "--eps", - type=float, - default=0.1, - required=True, - help="PGD eps.", - ) - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) - parser.add_argument( - "--instance_data_dir", - type=str, - default=None, - help="A folder containing the training data of instance images.", - ) - parser.add_argument( - "--class_data_dir", - type=str, - default=None, - help="A folder containing the training data of class images.", - ) - parser.add_argument( - "--instance_prompt", - type=str, - default=None, - help="The prompt with identifier specifying the instance", - ) - parser.add_argument( - "--class_prompt", - type=str, - default=None, - help="The prompt to specify images in the same class as provided instance images.", - ) - parser.add_argument( - "--with_prior_preservation", - default=False, - action="store_true", - help="Flag to add prior preservation loss.", - ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss." - ) - parser.add_argument( - "--num_class_images", - type=int, - default=200, - help=( - "Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt." - ), - ) - parser.add_argument( - "--output_dir", - type=str, - default="outputs", - help="The output directory.", - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="A seed for reproducible training." 
- ) - parser.add_argument( - "--resolution", - type=int, - default=512, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), - ) - parser.add_argument( - "--center_crop", - default=False, - action="store_true", - help=( - "Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping." - ), - ) - parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." - ) - parser.add_argument( - "--max_train_steps", - type=int, - default=250, - help="Total number of training steps to perform.", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=250, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-5, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=2, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - parser.add_argument( - "--freeze_model", - type=str, - default="crossattn_kv", - choices=["crossattn_kv", "crossattn"], - help="crossattn to enable fine-tuning of all params in the cross attention", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." - ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. 
Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--prior_generation_precision", - type=str, - default=None, - choices=["no", "fp32", "fp16", "bf16"], - help=( - "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." - ), - ) - parser.add_argument( - "--concepts_list", - type=str, - default=None, - help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.", - ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." - ) - parser.add_argument( - "--set_grads_to_none", - action="store_true", - help=( - "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" - " behaviors, so disable this argument if it causes any problems. More info:" - " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" - ), - ) - parser.add_argument( - "--initializer_token", type=str, default="ktn+pll+ucd", help="A token to use as initializer word." 
- ) - parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.") - parser.add_argument( - "--noaug", - action="store_true", - help="Dont apply augmentation during data augmentation when this flag is enabled.", - ) + + parser.add_argument("--alpha", type=float, default=5e-3, required=True, help="PGD alpha.") + parser.add_argument("--eps", type=float, default=0.1, required=True, help="PGD eps.") + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True) + parser.add_argument("--revision", type=str, default=None, required=False) + parser.add_argument("--tokenizer_name", type=str, default=None) + parser.add_argument("--instance_data_dir", type=str, default=None) + parser.add_argument("--class_data_dir", type=str, default=None) + parser.add_argument("--instance_prompt", type=str, default=None) + parser.add_argument("--class_prompt", type=str, default=None) + parser.add_argument("--with_prior_preservation", default=False, action="store_true") + parser.add_argument("--prior_loss_weight", type=float, default=1.0) + parser.add_argument("--num_class_images", type=int, default=200) + parser.add_argument("--output_dir", type=str, default="outputs") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--resolution", type=int, default=512) + parser.add_argument("--center_crop", default=False, action="store_true") + parser.add_argument("--sample_batch_size", type=int, default=4) + parser.add_argument("--max_train_steps", type=int, default=250) + parser.add_argument("--checkpointing_steps", type=int, default=250) + parser.add_argument("--checkpoints_total_limit", type=int, default=None) + parser.add_argument("--gradient_checkpointing", action="store_true") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--dataloader_num_workers", type=int, default=2) + parser.add_argument("--freeze_model", type=str, default="crossattn_kv", choices=["crossattn_kv", "crossattn"]) + parser.add_argument("--lr_scheduler", type=str, default="constant") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + parser.add_argument("--use_8bit_adam", action="store_true") + parser.add_argument("--adam_beta1", type=float, default=0.9) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-2) + parser.add_argument("--adam_epsilon", type=float, default=1e-08) + parser.add_argument("--max_grad_norm", default=1.0, type=float) + parser.add_argument("--hub_model_id", type=str, default=None) + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument("--allow_tf32", action="store_true") + parser.add_argument("--report_to", type=str, default="tensorboard") + parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"]) + parser.add_argument("--prior_generation_precision", type=str, default=None, choices=["no", "fp32", "fp16", "bf16"]) + parser.add_argument("--concepts_list", type=str, default=None) + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true") + parser.add_argument("--set_grads_to_none", action="store_true") + parser.add_argument("--initializer_token", type=str, default="ktn+pll+ucd") + parser.add_argument("--hflip", action="store_true") + parser.add_argument("--noaug", action="store_true") + + parser.add_argument("--debug_oom", 
action="store_true", help="开启显存与关键张量日志,用于定位第0步OOM") + parser.add_argument("--debug_oom_sync", action="store_true", help="打印前强制同步CUDA,日志更准但更慢") + parser.add_argument("--debug_oom_step0_only", action="store_true", help="只打印第0步相关日志,降低干扰") if input_args is not None: args = parser.parse_args(input_args) @@ -501,7 +290,6 @@ def parse_args(input_args=None): if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: - # logger is not available yet if args.class_data_dir is not None: warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: @@ -510,9 +298,76 @@ def parse_args(input_args=None): return args +def _fmt_bytes(n: int) -> str: + gb = n / (1024**3) + return f"{gb:.2f}GB" + + +def _debug_should_print(args, global_step: int) -> bool: + if not args.debug_oom: + return False + if args.debug_oom_step0_only: + return global_step == 0 + return True + + +def log_cuda(prefix: str, args, accelerator: Accelerator, extra: dict | None = None): + """打印CUDA显存状态与可选附加信息,不改变训练逻辑。""" + if not args.debug_oom: + return + if not torch.cuda.is_available(): + logger.info(f"[mem] {prefix} cuda_not_available") + return + + if args.debug_oom_sync: + torch.cuda.synchronize() + + allocated = torch.cuda.memory_allocated() + reserved = torch.cuda.memory_reserved() + max_alloc = torch.cuda.max_memory_allocated() + max_reserved = torch.cuda.max_memory_reserved() + + msg = ( + f"[mem] {prefix} " + f"alloc={_fmt_bytes(allocated)} reserv={_fmt_bytes(reserved)} " + f"max_alloc={_fmt_bytes(max_alloc)} max_reserv={_fmt_bytes(max_reserved)} " + f"device={accelerator.device}" + ) + if extra: + kv = " ".join([f"{k}={v}" for k, v in extra.items()]) + msg = msg + " " + kv + + logger.info(msg) + + +def log_tensor(prefix: str, t: torch.Tensor | None, args, accelerator: Accelerator): + """打印张量的shape/dtype/device/grad状态,避免误把大张量复制到CPU。""" + if not args.debug_oom: + return + if t is None: + logger.info(f"[tensor] {prefix} None") + return + logger.info( + f"[tensor] {prefix} shape={tuple(t.shape)} dtype={t.dtype} device={t.device} " + f"requires_grad={t.requires_grad} is_leaf={t.is_leaf}" + ) + + +def log_trainable_params(prefix: str, module: torch.nn.Module, args): + """打印模块可训练参数规模,确认是否意外训练了大量参数。""" + if not args.debug_oom: + return + trainable = [(n, p.numel(), str(p.dtype), str(p.device)) for n, p in module.named_parameters() if p.requires_grad] + total = sum(x[1] for x in trainable) + logger.info(f"[trainable] {prefix} tensors={len(trainable)} total_params={total}") + for n, numel, dtype, dev in trainable[:30]: + logger.info(f"[trainable] {prefix} name={n} numel={numel} dtype={dtype} device={dev}") + if len(trainable) > 30: + logger.info(f"[trainable] {prefix} ... (total {len(trainable)} trainable tensors)") + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( @@ -526,6 +381,7 @@ def main(args): datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) + logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() @@ -536,9 +392,17 @@ def main(args): accelerator.init_trackers("CAAT", config=vars(args)) - # If passed along, set the training seed now. 
+    if accelerator.is_local_main_process:
+        logger.info("========== CAAT arguments ==========")
+        for k in sorted(vars(args).keys()):
+            logger.info(f"{k}: {getattr(args, k)}")
+        logger.info("===============================")
+
+    log_cuda("startup", args, accelerator)
+
     if args.seed is not None:
         set_seed(args.seed)
+
     if args.concepts_list is None:
         args.concepts_list = [
             {
@@ -552,7 +416,6 @@ def main(args):
         with open(args.concepts_list, "r") as f:
             args.concepts_list = json.load(f)
 
-    # Generate class images if prior preservation is enabled.
     if args.with_prior_preservation:
         for i, concept in enumerate(args.concepts_list):
             class_images_dir = Path(concept["class_data_dir"])
@@ -560,7 +423,6 @@ def main(args):
             class_images_dir.mkdir(parents=True, exist_ok=True)
             cur_class_images = len(list(class_images_dir.iterdir()))
-
             if cur_class_images < args.num_class_images:
                 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
                 if args.prior_generation_precision == "fp32":
@@ -569,6 +431,9 @@ def main(args):
                     torch_dtype = torch.float16
                 elif args.prior_generation_precision == "bf16":
                     torch_dtype = torch.bfloat16
+
+                log_cuda("before_prior_pipeline_load", args, accelerator, extra={"torch_dtype": str(torch_dtype)})
+
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     torch_dtype=torch_dtype,
@@ -577,115 +442,90 @@ def main(args):
                 )
                 pipeline.set_progress_bar_config(disable=True)
 
-                num_new_images = args.num_class_images - cur_class_images
-                logger.info(f"Number of class images to sample: {num_new_images}.")
-
-                sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+                sample_dataset = PromptDataset(args.class_prompt, args.num_class_images - cur_class_images)
                 sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
 
                 sample_dataloader = accelerator.prepare(sample_dataloader)
                 pipeline.to(accelerator.device)
 
+                log_cuda("after_prior_pipeline_to_device", args, accelerator)
+
                 for example in tqdm(
                     sample_dataloader,
                     desc="Generating class images",
                     disable=not accelerator.is_local_main_process,
                 ):
                     images = pipeline(example["prompt"]).images
-
                     for i, image in enumerate(images):
                         hash_image = hashlib.sha1(image.tobytes()).hexdigest()
-                        image_filename = (
-                            class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
-                        )
+                        image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                         image.save(image_filename)
 
                 del pipeline
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
 
+                log_cuda("after_prior_pipeline_del", args, accelerator)
+
     if args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)
 
-    # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            args.tokenizer_name,
-            revision=args.revision,
-            use_fast=False,
-        )
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
     elif args.pretrained_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path,
-            subfolder="tokenizer",
-            revision=args.revision,
-            use_fast=False,
+            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
         )
 
-    # import correct text encoder class
     text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
-
-    # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder =
text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision - ) + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision) + log_cuda("after_load_models_cpu_or_meta", args, accelerator) vae.requires_grad_(False) text_encoder.requires_grad_(False) unet.requires_grad_(False) - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move unet, vae and text_encoder to device and cast to weight_dtype + if accelerator.is_local_main_process and args.debug_oom: + logger.info(f"[debug] weight_dtype={weight_dtype} mixed_precision={accelerator.mixed_precision}") + text_encoder.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) + log_cuda("after_models_to_device", args, accelerator) + attention_class = CustomDiffusionAttnProcessor if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) + logger.info(f"[debug] xformers_version={xformers_version}") if xformers_version == version.parse("0.0.16"): - logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + logger.warning( + "xFormers 0.0.16 may be unstable for training on some GPUs; consider upgrading to >=0.0.17." ) attention_class = CustomDiffusionXFormersAttnProcessor else: raise ValueError("xformers is not available. Make sure it is installed correctly") - # now we will add new Custom Diffusion weights to the attention layers - # It's important to realize here how many attention weights will be added and of which sizes - # The sizes of the attention layers consist only of two different variables: - # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. - # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. - - # Let's first see how many attention processors we will have to set. 
- # For Stable Diffusion, it should be equal to: - # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 - # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 - # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 - # => 32 layers - - # Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer train_kv = True train_q_out = False if args.freeze_model == "crossattn_kv" else True custom_diffusion_attn_procs = {} st = unet.state_dict() - for name, _ in unet.attn_processors.items(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): @@ -696,7 +536,9 @@ def main(args): elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] + layer_name = name.split(".processor")[0] + weights = { "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"], "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"], @@ -705,6 +547,7 @@ def main(args): weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"] weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"] weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"] + if cross_attention_dim is not None: custom_diffusion_attn_procs[name] = attention_class( train_kv=train_kv, @@ -722,38 +565,37 @@ def main(args): ) del st - unet.set_attn_processor(custom_diffusion_attn_procs) custom_diffusion_layers = AttnProcsLayers(unet.attn_processors) - accelerator.register_for_checkpointing(custom_diffusion_layers) + log_trainable_params("unet_after_set_attn_processor", unet, args) + log_trainable_params("custom_diffusion_layers", custom_diffusion_layers, args) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if accelerator.is_local_main_process and args.debug_oom: + logger.info("[debug] gradient_checkpointing enabled") + if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True + if accelerator.is_local_main_process and args.debug_oom: + logger.info("[debug] allow_tf32 enabled") - args.learning_rate = args.learning_rate if args.with_prior_preservation: args.learning_rate = args.learning_rate * 2.0 - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
-            )
-
+            raise ImportError("To use 8-bit Adam, please install bitsandbytes: `pip install bitsandbytes`.")
         optimizer_class = bnb.optim.AdamW8bit
+        if accelerator.is_local_main_process and args.debug_oom:
+            logger.info("[debug] using 8-bit AdamW")
     else:
         optimizer_class = torch.optim.AdamW
 
-    # Optimizer creation
     optimizer = optimizer_class(
         custom_diffusion_layers.parameters(),
         lr=args.learning_rate,
@@ -762,25 +604,29 @@ def main(args):
         eps=args.adam_epsilon,
     )
 
-    # Dataset creation:
+    # Match the CAAT code: infer mask_size from a single VAE encode
+    mask_size = (
+        vae.encode(torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device))
+        .latent_dist.sample()
+        .size()[-1]
+    )
+    if accelerator.is_local_main_process and args.debug_oom:
+        logger.info(f"[debug] inferred mask_size={mask_size}")
+
     train_dataset = CustomDiffusionDataset(
         concepts_list=args.concepts_list,
         tokenizer=tokenizer,
         with_prior_preservation=args.with_prior_preservation,
         size=args.resolution,
-        mask_size=vae.encode(
-            torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device)
-        )
-        .latent_dist.sample()
-        .size()[-1],
+        mask_size=mask_size,
         center_crop=args.center_crop,
         num_class_images=args.num_class_images,
         hflip=args.hflip,
         aug=not args.noaug,
     )
+    log_cuda("after_build_dataset", args, accelerator, extra={"num_instance_images": train_dataset.num_instance_images})
 
-    # Prepare for PGD
     pertubed_images = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path]
     pertubed_images = [train_dataset.image_transforms(i) for i in pertubed_images]
     pertubed_images = torch.stack(pertubed_images).contiguous()
@@ -822,7 +668,10 @@ def main(args):
     mask = mask.unsqueeze(1)
 
     del images_open_list
-
+    log_tensor("pertubed_images_before_prepare", pertubed_images, args, accelerator)
+    log_tensor("original_images_before_prepare", original_images, args, accelerator)
+    log_tensor("mask_before_prepare", mask, args, accelerator)
+    log_tensor("input_ids_cpu", input_ids, args, accelerator)
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
@@ -831,106 +680,172 @@ def main(args):
         num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
+    log_cuda("before_accelerator_prepare", args, accelerator)
+
     custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask = accelerator.prepare(
         custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask
    )
 
+    log_cuda("after_accelerator_prepare", args, accelerator)
+    log_tensor("pertubed_images_after_prepare", pertubed_images, args, accelerator)
+    log_tensor("original_images_after_prepare", original_images, args, accelerator)
+    log_tensor("mask_after_prepare", mask, args, accelerator)
 
-    # Train!
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len(train_dataset)}")
     logger.info(f"  Num pertubed_images = {len(pertubed_images)}")
     logger.info(f"  Total optimization steps = {args.max_train_steps}")
+
     global_step = 0
     first_epoch = 0
 
-    # Only show the progress bar once on each machine.
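The mask_size probe above runs a full VAE encode of a dummy image just to read off the latent width. For the standard diffusers AutoencoderKL configs used by SD 1.x, each down block halves the resolution, so the same value can be computed without any allocation (a sketch, valid only under that assumption):

    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)  # 8 for SD 1.x VAEs
    mask_size = args.resolution // vae_scale_factor                   # e.g. 512 // 8 = 64

The encode-based probe stays authoritative for non-standard VAEs; the arithmetic shortcut only holds when every down block downsamples by 2.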
     progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Steps")
+
     for epoch in range(first_epoch, args.max_train_steps):
         unet.train()
+
         for _ in range(1):
             with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
-                # Convert images to latent space
-                pertubed_images.requires_grad = True
-                latents = vae.encode(pertubed_images.to(accelerator.device).to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * vae.config.scaling_factor
-
-                # Sample noise that we'll add to the latents
-                noise = torch.randn_like(latents)
-                bsz = latents.shape[0]
-                # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-                timesteps = timesteps.long()
-
-                # Add noise to the latents according to the noise magnitude at each timestep
-                # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0]
-
-                # Predict the noise residual
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                # Get the target for loss depending on the prediction type
-                if noise_scheduler.config.prediction_type == "epsilon":
-                    target = noise
-                elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                else:
-                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+                if _debug_should_print(args, global_step):
+                    log_cuda("step_begin", args, accelerator, extra={"global_step": global_step})
+                    logger.info(f"[debug] step={global_step} starting forward path")
+
+                # Key instrumentation: the reported OOM hits before noising step 0 even starts, so every sub-stage below is checkpointed
+                try:
+                    pertubed_images.requires_grad = True
+                    if _debug_should_print(args, global_step):
+                        log_tensor("pertubed_images_pre_vae", pertubed_images, args, accelerator)
+                        log_cuda("before_vae_encode", args, accelerator, extra={"global_step": global_step})
 
-                # unet.zero_grad()
-                # text_encoder.zero_grad()
+                    latents_dist = vae.encode(
+                        pertubed_images.to(accelerator.device).to(dtype=weight_dtype)
+                    ).latent_dist
 
-                if args.with_prior_preservation:
-                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-                    target, target_prior = torch.chunk(target, 2, dim=0)
-                    mask = torch.chunk(mask, 2, dim=0)[0].to(accelerator.device)
-                    # Compute instance loss
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-                    loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
+                    if _debug_should_print(args, global_step):
+                        log_cuda("after_vae_encode", args, accelerator, extra={"global_step": global_step})
 
-                    # Compute prior loss
-                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+                    latents = latents_dist.sample()
+                    latents = latents * vae.config.scaling_factor
+
+                    if _debug_should_print(args, global_step):
+                        log_tensor("latents", latents, args, accelerator)
+                        log_cuda("after_latents_sample", args, accelerator, extra={"global_step": global_step})
 
-                    # Add the prior loss to the instance loss.
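One caveat about the checkpoints above: max_alloc/max_reserv are cumulative peaks since process start, so they show that a peak occurred but not which sub-stage set it. A sketch (not part of the patch) of attributing the peak to a single suspect stage with torch's peak-stat reset:

    torch.cuda.reset_peak_memory_stats()
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
    torch.cuda.synchronize()
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"unet forward peak: {peak_gb:.2f}GB")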
- loss = loss + args.prior_loss_weight * prior_loss - else: - mask = mask.to(accelerator.device) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # torch.Size([5, 4, 64, 64]) + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() + + if _debug_should_print(args, global_step): + log_tensor("noise", noise, args, accelerator) + log_tensor("timesteps", timesteps, args, accelerator) + log_cuda("before_add_noise", args, accelerator, extra={"global_step": global_step}) + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + if _debug_should_print(args, global_step): + log_tensor("noisy_latents", noisy_latents, args, accelerator) + log_cuda("after_add_noise", args, accelerator, extra={"global_step": global_step}) + + if _debug_should_print(args, global_step): + log_cuda("before_text_encoder", args, accelerator, extra={"global_step": global_step}) + + encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0] + + if _debug_should_print(args, global_step): + log_tensor("encoder_hidden_states", encoder_hidden_states, args, accelerator) + log_cuda("after_text_encoder", args, accelerator, extra={"global_step": global_step}) + + if _debug_should_print(args, global_step): + log_cuda("before_unet_forward", args, accelerator, extra={"global_step": global_step}) + + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if _debug_should_print(args, global_step): + log_tensor("model_pred", model_pred, args, accelerator) + log_cuda("after_unet_forward", args, accelerator, extra={"global_step": global_step}) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + mask_inst = torch.chunk(mask, 2, dim=0)[0].to(accelerator.device) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = ((loss * mask_inst).sum([1, 2, 3]) / mask_inst.sum([1, 2, 3])).mean() + + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + loss = loss + args.prior_loss_weight * prior_loss + else: + mask_inst = mask.to(accelerator.device) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean() + + if _debug_should_print(args, global_step): + logger.info(f"[debug] step={global_step} loss_value={loss.detach().float().item()}") + log_cuda("before_backward", args, accelerator, extra={"global_step": global_step}) + + accelerator.backward(loss) + + if _debug_should_print(args, global_step): + log_cuda("after_backward", args, accelerator, extra={"global_step": global_step}) + if pertubed_images.grad is None: + logger.info(f"[debug] step={global_step} pertubed_images.grad=None") + else: + logger.info( + f"[debug] step={global_step} pertubed_images.grad_abs_mean=" + f"{pertubed_images.grad.abs().mean().item():.6e}" + ) + + if accelerator.sync_gradients: + params_to_clip = custom_diffusion_layers.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + alpha = args.alpha + eps = args.eps + + if 
_debug_should_print(args, global_step):
+                        logger.info(f"[debug] step={global_step} alpha={alpha} eps={eps}")
+                        log_cuda("before_pgd_update", args, accelerator, extra={"global_step": global_step})
 
-                #loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
-                loss = loss.mean()
+                    adv_images = pertubed_images + alpha * pertubed_images.grad.sign()
+                    eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
+                    pertubed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
 
-                accelerator.backward(loss)
+                    if _debug_should_print(args, global_step):
+                        log_tensor("pertubed_images_after_pgd", pertubed_images, args, accelerator)
+                        log_cuda("after_pgd_update", args, accelerator, extra={"global_step": global_step})
 
+                    optimizer.step()
+                    lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+                    if _debug_should_print(args, global_step):
+                        log_cuda("after_optimizer_step", args, accelerator, extra={"global_step": global_step})
 
-                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        custom_diffusion_layers.parameters()
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                except RuntimeError as e:
+                    # Catch the OOM and log as much context as possible to pinpoint the sub-stage where memory blew up
+                    if "out of memory" in str(e).lower() or "cuda" in str(e).lower():
+                        logger.error(f"[OOM] step={global_step} caught RuntimeError: {e}")
+                        log_cuda("oom_caught", args, accelerator, extra={"global_step": global_step})
+                        logger.error("[OOM] if it fires near before_unet_forward/after_unet_forward, the UNet forward peak is the usual cause")
+                        logger.error("[OOM] if it fires near before_backward/after_backward, activations saved for backward usually set the peak")
+                        logger.error("[OOM] if it fires near after_accelerator_prepare, prepare() or resident model memory is usually too high")
+                    raise
 
-                alpha = args.alpha
-                eps = args.eps
-                adv_images = pertubed_images + alpha * pertubed_images.grad.sign()
-                eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
-                pertubed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
-
-                optimizer.step()
-
-                lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
-
                 logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
                 accelerator.log(logs, step=global_step)
@@ -941,27 +856,18 @@ def main(args):
 
     if accelerator.is_main_process:
         logger.info("***** Final save of perturbed images *****")
         save_folder = args.output_dir
+        noised_imgs = pertubed_images.detach().cpu()
 
-        noised_imgs = pertubed_images.detach().cpu()
-
-        img_names = [
-            str(instance_path[0]).split("/")[-1] for instance_path in train_dataset.instance_images_path
-        ]
-
-        num_images_to_save = len(img_names)
-
-        for i in range(num_images_to_save):
+        img_names = [str(instance_path[0]).split("/")[-1] for instance_path in train_dataset.instance_images_path]
+        for i in range(len(img_names)):
             img_pixel = noised_imgs[i]
             img_name = img_names[i]
             save_path = os.path.join(save_folder, f"final_noise_{img_name}")
-
-            # Convert and save the image
             Image.fromarray(
                 (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
            ).save(save_path)
-
-        logger.info(f"Saved {num_images_to_save} final perturbed images to {save_folder}")
+        logger.info(f"Saved {len(img_names)} final perturbed images to {save_folder}")
 
     accelerator.end_training()
 
 
@@ -969,4 +875,4 @@ def main(args):
 if __name__ == "__main__":
     args = parse_args()
     main(args)
-    print("<-------end-------->")
+    print("<-------end-------->")
\
No newline at end of file
diff --git a/src/backend/app/algorithms/perturbation/simac.py b/src/backend/app/algorithms/perturbation/simac.py
index cab654d..af93f64 100644
--- a/src/backend/app/algorithms/perturbation/simac.py
+++ b/src/backend/app/algorithms/perturbation/simac.py
@@ -10,7 +10,7 @@ from pathlib import Path
 
 import datasets
 import diffusers
-import transformers
+import transformers
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -30,19 +30,84 @@ from transformers import AutoTokenizer, PretrainedConfig
 logger = get_logger(__name__)
 
 
+# -----------------------------
+# Lightweight debug helpers
+# -----------------------------
 def _cuda_gc() -> None:
-    """Try to release unreferenced CUDA memory and reduce fragmentation.
-
-    This is a best-effort helper. It does not change algorithmic behavior but can
-    make long runs less prone to OOM due to fragmentation/reserved-memory growth.
-    """
+    """Best-effort CUDA memory cleanup (does not change algorithmic behavior)."""
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
 
+def _fmt_bytes(n: int) -> str:
+    return f"{n / (1024**2):.1f}MB"
+
+
+def log_cuda(prefix: str, accelerator: Accelerator | None = None, sync: bool = False, extra: dict | None = None):
+    """Log CUDA memory stats without copying tensors to CPU."""
+    if not torch.cuda.is_available():
+        logger.info(f"[mem] {prefix} cuda_not_available")
+        return
+    if sync:
+        torch.cuda.synchronize()
+    alloc = torch.cuda.memory_allocated()
+    reserv = torch.cuda.memory_reserved()
+    max_alloc = torch.cuda.max_memory_allocated()
+    max_reserv = torch.cuda.max_memory_reserved()
+    dev = str(accelerator.device) if accelerator is not None else "cuda"
+    msg = (
+        f"[mem] {prefix} dev={dev} alloc={_fmt_bytes(alloc)} reserv={_fmt_bytes(reserv)} "
+        f"max_alloc={_fmt_bytes(max_alloc)} max_reserv={_fmt_bytes(max_reserv)}"
+    )
+    if extra:
+        msg += " " + " ".join([f"{k}={v}" for k, v in extra.items()])
+    logger.info(msg)
+
+
+def log_path_stats(prefix: str, p: Path):
+    """Log directory/file existence and file count (best-effort)."""
+    try:
+        exists = p.exists()
+        is_dir = p.is_dir() if exists else False
+        n_files = 0
+        if exists and is_dir:
+            n_files = sum(1 for x in p.iterdir() if x.is_file())
+        logger.info(f"[path] {prefix} path={str(p)} exists={exists} is_dir={is_dir} files={n_files}")
+    except Exception as e:
+        logger.info(f"[path] {prefix} path={str(p)} stat_error={repr(e)}")
+
+
+def log_args(args):
+    for k in sorted(vars(args).keys()):
+        logger.info(f"[args] {k}={getattr(args, k)}")
+
+
+def log_tensor_meta(prefix: str, t: torch.Tensor | None):
+    if t is None:
+        logger.info(f"[tensor] {prefix} None")
+        return
+    logger.info(
+        f"[tensor] {prefix} shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
+        f"requires_grad={t.requires_grad} is_leaf={t.is_leaf}"
+    )
+
+
+def setup_seeds():
+    seed = 42
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    cudnn.benchmark = False
+    cudnn.deterministic = True
+
+
+# -----------------------------
+# Dataset
+# -----------------------------
 class DreamBoothDatasetFromTensor(Dataset):
-    """Just like DreamBoothDataset, but take instance_images_tensor instead of path."""
+    """DreamBooth dataset backed by in-memory tensors: returns image tensors and prompt tokens directly."""
 
     def __init__(
         self,
@@ -66,10 +131,19 @@ class DreamBoothDatasetFromTensor(Dataset):
         if class_data_root is not None:
             self.class_data_root = Path(class_data_root)
             self.class_data_root.mkdir(parents=True, exist_ok=True)
-            self.class_images_path =
list(self.class_data_root.iterdir()) + # Only keep files to avoid directories affecting length. + self.class_images_path = [p for p in self.class_data_root.iterdir() if p.is_file()] self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt + + # Early, explicit failure instead of ZeroDivisionError later. + if self.num_class_images == 0: + raise ValueError( + f"class_data_dir is empty: {self.class_data_root}. " + f"Prior preservation requires class images. " + f"Please generate class images first, or fix class_data_dir, or disable --with_prior_preservation." + ) else: self.class_data_root = None @@ -98,6 +172,9 @@ class DreamBoothDatasetFromTensor(Dataset): ).input_ids if self.class_data_root: + # Defensive: if class_images become empty due to external deletion, raise a clear error. + if self.num_class_images == 0: + raise ValueError(f"class_data_dir became empty at runtime: {self.class_data_root}") class_image = Image.open(self.class_images_path[index % self.num_class_images]) if class_image.mode != "RGB": class_image = class_image.convert("RGB") @@ -113,6 +190,9 @@ class DreamBoothDatasetFromTensor(Dataset): return example +# ----------------------------- +# Model helpers +# ----------------------------- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -132,370 +212,6 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st raise ValueError(f"{model_class} is not supported.") -def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help=( - "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" - " float32 precision." 
- ), - ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) - parser.add_argument( - "--instance_data_dir_for_train", - type=str, - default=None, - required=True, - help="A folder containing the training data of instance images.", - ) - parser.add_argument( - "--instance_data_dir_for_adversarial", - type=str, - default=None, - required=True, - help="A folder containing the images to add adversarial noise", - ) - parser.add_argument( - "--class_data_dir", - type=str, - default=None, - required=False, - help="A folder containing the training data of class images.", - ) - parser.add_argument( - "--instance_prompt", - type=str, - default=None, - required=True, - help="The prompt with identifier specifying the instance", - ) - parser.add_argument( - "--class_prompt", - type=str, - default=None, - help="The prompt to specify images in the same class as provided instance images.", - ) - parser.add_argument( - "--with_prior_preservation", - default=False, - action="store_true", - help="Flag to add prior preservation loss.", - ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss.", - ) - parser.add_argument( - "--num_class_images", - type=int, - default=100, - help=( - "Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt." - ), - ) - parser.add_argument( - "--output_dir", - type=str, - default="text-inversion-model", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--resolution", - type=int, - default=512, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), - ) - parser.add_argument( - "--center_crop", - default=False, - action="store_true", - help=( - "Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping." - ), - ) - parser.add_argument( - "--train_text_encoder", - action="store_true", - help="Whether to train the text encoder. 
If set, the text encoder should be float32 precision.", - ) - parser.add_argument( - "--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--sample_batch_size", - type=int, - default=8, - help="Batch size (per device) for sampling images.", - ) - parser.add_argument( - "--max_train_steps", - type=int, - default=20, - help="Total number of training steps to perform.", - ) - parser.add_argument( - "--max_f_train_steps", - type=int, - default=10, - help="Total number of sub-steps to train surogate model.", - ) - parser.add_argument( - "--max_adv_train_steps", - type=int, - default=10, - help="Total number of sub-steps to train adversarial noise.", - ) - parser.add_argument( - "--checkpointing_iterations", - type=int, - default=5, - help=("Save a checkpoint of the training state every X iterations."), - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-6, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default="fp16", - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
- ), - ) - parser.add_argument( - "--enable_xformers_memory_efficient_attention", - action="store_true", - help="Whether or not to use xformers.", - ) - parser.add_argument( - "--pgd_alpha", - type=float, - default=0.005, - help="The step size for pgd.", - ) - parser.add_argument( - "--pgd_eps", - type=int, - default=16, - help="The noise budget for pgd.", - ) - parser.add_argument( - "--target_image_path", - default=None, - help="target image for attacking", - ) - parser.add_argument( - "--max_steps", - type=int, - default=50, - help=("Maximum steps for adaptive greedy timestep selection."), - ) - parser.add_argument( - "--delta_t", - type=int, - default=20, - help=("delete 2*delta_t for each adaptive greedy timestep selection."), - ) - if input_args is not None: - args = parser.parse_args(input_args) - else: - args = parser.parse_args() - - return args - - -class PromptDataset(Dataset): - """A simple dataset to prepare the prompts to generate class images on multiple GPUs.""" - - def __init__(self, prompt, num_samples): - self.prompt = prompt - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, index): - example = {} - example["prompt"] = self.prompt - example["index"] = index - return example - - -def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: - image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - images = [image_transforms(Image.open(i).convert("RGB")) for i in list(Path(data_dir).iterdir())] - images = torch.stack(images) - return images - - -def train_one_epoch( - args, - models, - tokenizer, - noise_scheduler, - vae, - data_tensor: torch.Tensor, - num_steps=20, -): - unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1]) - params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) - - optimizer = torch.optim.AdamW( - params_to_optimize, - lr=args.learning_rate, - betas=(0.9, 0.999), - weight_decay=1e-2, - eps=1e-08, - ) - - train_dataset = DreamBoothDatasetFromTensor( - data_tensor, - args.instance_prompt, - tokenizer, - args.class_data_dir, - args.class_prompt, - args.resolution, - args.center_crop, - ) - - weight_dtype = torch.bfloat16 - device = torch.device("cuda") - - vae.to(device, dtype=weight_dtype) - text_encoder.to(device, dtype=weight_dtype) - unet.to(device, dtype=weight_dtype) - - for step in range(num_steps): - unet.train() - text_encoder.train() - - step_data = train_dataset[step % len(train_dataset)] - pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to( - device, dtype=weight_dtype - ) - input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device) - - latents = vae.encode(pixel_values).latent_dist.sample() - latents = latents * vae.config.scaling_factor - - noise = torch.randn_like(latents) - bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) - timesteps = timesteps.long() - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - encoder_hidden_states = text_encoder(input_ids)[0] - - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif 
noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.with_prior_preservation: - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - loss = instance_loss + args.prior_loss_weight * prior_loss - else: - prior_loss = torch.tensor(0.0, device=device) - instance_loss = torch.tensor(0.0, device=device) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - loss.backward() - torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True) - optimizer.step() - optimizer.zero_grad() - - print( - f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, " - f"instance_loss: {instance_loss.detach().item()}" - ) - - # Best-effort: free per-step tensors earlier (no behavior change). - del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states - del model_pred, target, loss, prior_loss, instance_loss - - # Best-effort: release optimizer state + dataset refs sooner. - del optimizer, train_dataset, params_to_optimize - _cuda_gc() - - return [unet, text_encoder] - - def set_unet_attr(unet): def conv_forward(self): def forward(input_tensor, temb): @@ -565,65 +281,139 @@ def save_feature_maps(up_blocks, down_blocks): return out_layers_features_list_3 -def pgd_attack( - args, - models, - tokenizer, - noise_scheduler, - vae, - data_tensor: torch.Tensor, - original_images: torch.Tensor, - target_tensor: torch.Tensor, - num_steps: int, - time_list, -): - """Return new perturbed data. - - Note: This function keeps the external behavior identical, but tries to reduce - memory pressure by freeing tensors early and avoiding lingering references. 
- """ - unet, text_encoder = models +# ----------------------------- +# Args +# ----------------------------- +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True) + parser.add_argument("--revision", type=str, default=None, required=False) + parser.add_argument("--tokenizer_name", type=str, default=None) + parser.add_argument("--instance_data_dir_for_train", type=str, default=None, required=True) + parser.add_argument("--instance_data_dir_for_adversarial", type=str, default=None, required=True) + parser.add_argument("--class_data_dir", type=str, default=None, required=False) + parser.add_argument("--instance_prompt", type=str, default=None, required=True) + parser.add_argument("--class_prompt", type=str, default=None) + parser.add_argument("--with_prior_preservation", default=False, action="store_true") + parser.add_argument("--prior_loss_weight", type=float, default=1.0) + parser.add_argument("--num_class_images", type=int, default=100) + parser.add_argument("--output_dir", type=str, default="text-inversion-model") + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--resolution", type=int, default=512) + parser.add_argument("--center_crop", default=False, action="store_true") + parser.add_argument("--train_text_encoder", action="store_true") + parser.add_argument("--train_batch_size", type=int, default=4) + parser.add_argument("--sample_batch_size", type=int, default=8) + parser.add_argument("--max_train_steps", type=int, default=20) + parser.add_argument("--max_f_train_steps", type=int, default=10) + parser.add_argument("--max_adv_train_steps", type=int, default=10) + parser.add_argument("--checkpointing_iterations", type=int, default=5) + parser.add_argument("--learning_rate", type=float, default=5e-6) + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument("--allow_tf32", action="store_true") + parser.add_argument("--report_to", type=str, default="tensorboard") + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"]) + parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true") + parser.add_argument("--pgd_alpha", type=float, default=0.005) + parser.add_argument("--pgd_eps", type=int, default=16) + parser.add_argument("--target_image_path", default=None) + parser.add_argument("--max_steps", type=int, default=50) + parser.add_argument("--delta_t", type=int, default=20) + + # Debug / diagnostics (low-overhead) + parser.add_argument("--debug", action="store_true", help="Enable detailed logs for failure points.") + parser.add_argument("--debug_cuda_sync", action="store_true", help="Synchronize CUDA for more accurate mem logs.") + parser.add_argument("--debug_step0_only", action="store_true", help="Only print per-step logs for step 0.") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + return args + + +# ----------------------------- +# IO helpers +# ----------------------------- +def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: + image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + images = 
[image_transforms(Image.open(i).convert("RGB")) for i in list(Path(data_dir).iterdir())]
+    images = torch.stack(images)
+    return images
+
+
+# -----------------------------
+# Train / Attack
+# -----------------------------
+def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: torch.Tensor, num_steps=20, accelerator=None):
+    unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
+    # Materialize to a list: AdamW consumes the iterable, and clip_grad_norm_ below
+    # must traverse the same parameters again (a bare chain would be empty by then).
+    params_to_optimize = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+
+    optimizer = torch.optim.AdamW(
+        params_to_optimize,
+        lr=args.learning_rate,
+        betas=(0.9, 0.999),
+        weight_decay=1e-2,
+        eps=1e-08,
+    )
+
+    train_dataset = DreamBoothDatasetFromTensor(
+        data_tensor,
+        args.instance_prompt,
+        tokenizer,
+        args.class_data_dir if args.with_prior_preservation else None,
+        args.class_prompt,
+        args.resolution,
+        args.center_crop,
+    )
+
     weight_dtype = torch.bfloat16
     device = torch.device("cuda")
 
     vae.to(device, dtype=weight_dtype)
     text_encoder.to(device, dtype=weight_dtype)
     unet.to(device, dtype=weight_dtype)
-    set_unet_attr(unet)
-    perturbed_images = data_tensor.detach().clone()
-    perturbed_images.requires_grad_(True)
+    for step in range(num_steps):
+        # NOTE: --debug_step0_only is parsed but not wired up here yet;
+        # the per-step log below prints on every step.
 
-    # Keep input_ids on CPU; move to GPU only when encoding.
-    input_ids = tokenizer(
-        args.instance_prompt,
-        truncation=True,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        return_tensors="pt",
-    ).input_ids.repeat(len(data_tensor), 1)
+        unet.train()
+        text_encoder.train()
 
-    for step in range(num_steps):
-        perturbed_images.requires_grad_(True)
+        try:
+            step_data = train_dataset[step % len(train_dataset)]
+        except Exception as e:
+            logger.error(f"[err] train_one_epoch dataset getitem failed at step={step}: {repr(e)}")
+            raise
+
+        # This will fail fast if class_images missing (KeyError), better than silent wrong behavior.
+        try:
+            pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(device, dtype=weight_dtype)
+            input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
+        except KeyError as e:
+            logger.error(
+                f"[err] missing key in step_data at step={step}: missing={str(e)}. 
" + f"with_prior_preservation={args.with_prior_preservation}" + ) + raise - latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample() + latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor noise = torch.randn_like(latents) - - timesteps = [] - for i in range(len(data_tensor)): - ts = time_list[i] - ts_index = torch.randint(0, len(ts), (1,)) - timestep = torch.IntTensor([ts[ts_index]]) - timestep = timestep.long() - timesteps.append(timestep) - timesteps = torch.cat(timesteps).to(device) + bsz = latents.shape[0] + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - encoder_hidden_states = text_encoder(input_ids.to(device))[0] - + encoder_hidden_states = text_encoder(input_ids)[0] model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if noise_scheduler.config.prediction_type == "epsilon": @@ -633,62 +423,38 @@ def pgd_attack( else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - noise_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks) - - with torch.no_grad(): - clean_latents = vae.encode(data_tensor.to(device, dtype=weight_dtype)).latent_dist.sample() - clean_latents = clean_latents * vae.config.scaling_factor - noisy_clean_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps) - _ = unet(noisy_clean_latents, timesteps, encoder_hidden_states).sample - clean_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks) - - target_loss = F.mse_loss( - noise_out_layers_features_3.float(), - clean_out_layers_features_3.float(), - reduction="mean", - ) + if args.with_prior_preservation: + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) - unet.zero_grad(set_to_none=True) - text_encoder.zero_grad(set_to_none=True) + instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + loss = instance_loss + args.prior_loss_weight * prior_loss + else: + prior_loss = torch.tensor(0.0, device=device) + instance_loss = torch.tensor(0.0, device=device) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - # Keep original behavior: feature loss does not backprop (added as Python float). - loss = loss + target_loss.detach().item() loss.backward() + torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True) + optimizer.step() + optimizer.zero_grad() - alpha = args.pgd_alpha - eps = args.pgd_eps / 255 - adv_images = perturbed_images + alpha * perturbed_images.grad.sign() - eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps) - perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_() - - print( - f"PGD loss - step {step}, loss: {loss.detach().item()}, target_loss : {target_loss.detach().item()}" + logger.info( + f"[train_one_epoch] step={step} loss={loss.detach().item():.6f} " + f"prior={prior_loss.detach().item():.6f} inst={instance_loss.detach().item():.6f}" ) - # Best-effort: free per-step tensors early. 
- del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target - del noise_out_layers_features_3, clean_latents, noisy_clean_latents, clean_out_layers_features_3 - del target_loss, loss, adv_images, eta + # Free some step tensors early. + del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states + del model_pred, target, loss, prior_loss, instance_loss + del optimizer, train_dataset, params_to_optimize _cuda_gc() - return perturbed_images + return [unet, text_encoder] -def select_timestep( - args, - models, - tokenizer, - noise_scheduler, - vae, - data_tensor: torch.Tensor, - original_images: torch.Tensor, - target_tensor: torch.Tensor, -): - """Return timestep lists for each image. - - External behavior unchanged; add best-effort per-loop cleanup to lower memory pressure. - """ +def select_timestep(args, models, tokenizer, noise_scheduler, vae, data_tensor, original_images, target_tensor): unet, text_encoder = models weight_dtype = torch.bfloat16 device = torch.device("cuda") @@ -731,13 +497,10 @@ def select_timestep( noise = torch.randn_like(latents) bsz = latents.shape[0] inner_index = torch.randint(0, len(res_time_seq), (bsz,)) - timesteps = torch.IntTensor([res_time_seq[inner_index]]).to(device) - timesteps = timesteps.long() + timesteps = torch.IntTensor([res_time_seq[inner_index]]).to(device).long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = text_encoder(input_ids.to(device))[0] - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if noise_scheduler.config.prediction_type == "epsilon": @@ -769,14 +532,15 @@ def select_timestep( max_score = score select_t = res_time_seq[inner_index].item() - print( - f"PGD loss - step {step}, index : {inner_try + 1}, loss: {loss.detach().item()}, " - f"score: {score}, t : {res_time_seq[inner_index]}, ts_len: {len(res_time_seq)}" - ) + if args.debug: + logger.info( + f"[select_timestep] img={img_id} outer={step} inner={inner_try} loss={loss.detach().item():.6f} " + f"score={score.item() if torch.is_tensor(score) else score} t={res_time_seq[inner_index].item()} " + f"len={len(res_time_seq)}" + ) del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss, score - print("del_t", del_t, "max_t", select_t) if del_t < args.delta_t: del_t = args.delta_t elif del_t > (1000 - args.delta_t): @@ -790,8 +554,7 @@ def select_timestep( latents = latents * vae.config.scaling_factor noise = torch.randn_like(latents) - timesteps = torch.IntTensor([select_t]).to(device) - timesteps = timesteps.long() + timesteps = torch.IntTensor([select_t]).to(device).long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) encoder_hidden_states = text_encoder(input_ids.to(device))[0] @@ -806,7 +569,6 @@ def select_timestep( unet.zero_grad(set_to_none=True) text_encoder.zero_grad(set_to_none=True) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss.backward() @@ -814,7 +576,6 @@ def select_timestep( eps = args.pgd_eps / 255 adv_image = id_image + alpha * id_image.grad.sign() eta = torch.clamp(adv_image - original_image, min=-eps, max=+eps) - _ = torch.sum(torch.abs(id_image.grad.sign())) id_image = torch.clamp(original_image + eta, min=-1, max=+1).detach_() del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss, adv_image, eta @@ -831,16 +592,99 @@ def select_timestep( return time_list -def setup_seeds(): - seed = 42 - 
random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    cudnn.benchmark = False
-    cudnn.deterministic = True
+def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, original_images, target_tensor, num_steps, time_list):
+    unet, text_encoder = models
+    weight_dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    vae.to(device, dtype=weight_dtype)
+    text_encoder.to(device, dtype=weight_dtype)
+    unet.to(device, dtype=weight_dtype)
+    set_unet_attr(unet)
+
+    perturbed_images = data_tensor.detach().clone()
+    perturbed_images.requires_grad_(True)
+
+    input_ids = tokenizer(
+        args.instance_prompt,
+        truncation=True,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        return_tensors="pt",
+    ).input_ids.repeat(len(data_tensor), 1)
+
+    for step in range(num_steps):
+        # NOTE: --debug_step0_only is parsed but not wired up here yet;
+        # the per-step log below prints on every step.
+
+        perturbed_images.requires_grad_(True)
+
+        latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
+        latents = latents * vae.config.scaling_factor
+
+        noise = torch.randn_like(latents)
+
+        timesteps = []
+        for i in range(len(data_tensor)):
+            ts = time_list[i]
+            if len(ts) == 0:
+                raise ValueError(f"time_list[{i}] is empty; select_timestep failed.")
+            ts_index = torch.randint(0, len(ts), (1,))
+            timestep = torch.IntTensor([ts[ts_index]]).long()
+            timesteps.append(timestep)
+        timesteps = torch.cat(timesteps).to(device)
+
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        encoder_hidden_states = text_encoder(input_ids.to(device))[0]
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        noise_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks)
+
+        with torch.no_grad():
+            clean_latents = vae.encode(data_tensor.to(device, dtype=weight_dtype)).latent_dist.sample()
+            clean_latents = clean_latents * vae.config.scaling_factor
+            noisy_clean_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps)
+            _ = unet(noisy_clean_latents, timesteps, encoder_hidden_states).sample
+            clean_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks)
+
+        target_loss = F.mse_loss(noise_out_layers_features_3.float(), clean_out_layers_features_3.float(), reduction="mean")
+
+        unet.zero_grad(set_to_none=True)
+        text_encoder.zero_grad(set_to_none=True)
+
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        loss = loss + target_loss.detach().item()
+        loss.backward()
+
+        alpha = args.pgd_alpha
+        eps = args.pgd_eps / 255
+        adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
+        eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
+        perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
+
+        logger.info(
+            f"[pgd] step={step} loss={loss.detach().item():.6f} target_loss={target_loss.detach().item():.6f} "
+            f"alpha={alpha} eps={eps}"
+        )
+
+        del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target
+        del noise_out_layers_features_3, clean_latents, noisy_clean_latents, clean_out_layers_features_3
+        del target_loss, loss, adv_images, eta
+        _cuda_gc()
+    return perturbed_images
+
+
+# 
----------------------------- +# Main +# ----------------------------- def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -855,6 +699,7 @@ def main(args): datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) + logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() @@ -865,15 +710,28 @@ def main(args): transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() + if accelerator.is_local_main_process: + logger.info(f"[run] using_file={__file__}") + log_args(args) + if args.seed is not None: set_seed(args.seed) setup_seeds() - # Generate class images if prior preservation is enabled. + if args.debug and accelerator.is_local_main_process: + log_cuda("startup", accelerator, sync=args.debug_cuda_sync) + + # ------------------------- + # Prior preservation: generate class images if needed + # ------------------------- if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) class_images_dir.mkdir(parents=True, exist_ok=True) - cur_class_images = len(list(class_images_dir.iterdir())) + log_path_stats("class_dir_before", class_images_dir) + + cur_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file()) + if accelerator.is_local_main_process: + logger.info(f"[class_gen] cur_class_images={cur_class_images} target={args.num_class_images}") if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 @@ -884,6 +742,10 @@ def main(args): elif args.mixed_precision == "bf16": torch_dtype = torch.bfloat16 + if accelerator.is_local_main_process: + logger.info(f"[class_gen] will_generate={args.num_class_images - cur_class_images} torch_dtype={torch_dtype}") + log_cuda("before_pipeline_load", accelerator, sync=args.debug_cuda_sync) + pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -893,20 +755,24 @@ def main(args): pipeline.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images - logger.info(f"Number of class images to sample: {num_new_images}.") - sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) + if accelerator.is_local_main_process and args.debug: + log_cuda("after_pipeline_to_device", accelerator, sync=args.debug_cuda_sync) + for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images + if accelerator.is_local_main_process and args.debug: + logger.info(f"[class_gen] batch_prompts={len(example['prompt'])} generated_images={len(images)}") + for i, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" @@ -915,32 +781,45 @@ def main(args): del pipeline, sample_dataset, sample_dataloader _cuda_gc() + # IMPORTANT: sync all processes before training reads the directory + accelerator.wait_for_everyone() + + # Post-check: ensure class images exist + final_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file()) + if accelerator.is_local_main_process: + logger.info(f"[class_gen] done final_class_images={final_class_images}") + 
log_path_stats("class_dir_after", class_images_dir) + if final_class_images == 0: + raise RuntimeError(f"class image generation failed: {class_images_dir} is still empty.") + + else: + accelerator.wait_for_everyone() + if accelerator.is_local_main_process: + logger.info("[class_gen] skipped (already enough images)") + else: + if accelerator.is_local_main_process: + logger.info("[class_gen] disabled (with_prior_preservation is False)") + + # ------------------------- + # Load models + # ------------------------- text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + if accelerator.is_local_main_process and args.debug: + log_cuda("before_load_models", accelerator, sync=args.debug_cuda_sync) + text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=args.revision, - use_fast=False, + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False ) - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="vae", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision ).cuda() vae.requires_grad_(False) @@ -950,93 +829,85 @@ def main(args): if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True - clean_data = load_data( - args.instance_data_dir_for_train, - size=args.resolution, - center_crop=args.center_crop, - ) - perturbed_data = load_data( - args.instance_data_dir_for_adversarial, - size=args.resolution, - center_crop=args.center_crop, - ) - original_data = perturbed_data.clone() - original_data.requires_grad_(False) - if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() + if accelerator.is_local_main_process: + logger.info("[xformers] enabled") else: raise ValueError("xformers is not available. 
Make sure it is installed correctly") + if accelerator.is_local_main_process and args.debug: + log_cuda("after_load_models", accelerator, sync=args.debug_cuda_sync) + + # ------------------------- + # Load data tensors + # ------------------------- + train_dir = Path(args.instance_data_dir_for_train) + adv_dir = Path(args.instance_data_dir_for_adversarial) + if accelerator.is_local_main_process and args.debug: + log_path_stats("train_dir", train_dir) + log_path_stats("adv_dir", adv_dir) + + clean_data = load_data(train_dir, size=args.resolution, center_crop=args.center_crop) + perturbed_data = load_data(adv_dir, size=args.resolution, center_crop=args.center_crop) + original_data = perturbed_data.clone() + original_data.requires_grad_(False) + + if accelerator.is_local_main_process and args.debug: + log_tensor_meta("clean_data_cpu", clean_data) + log_tensor_meta("perturbed_data_cpu", perturbed_data) + target_latent_tensor = None if args.target_image_path is not None: target_image_path = Path(args.target_image_path) - assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist" + if not target_image_path.is_file(): + raise ValueError(f"Target image path does not exist: {target_image_path}") target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution)) target_image = np.array(target_image)[None].transpose(0, 3, 1, 2) target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0 - target_latent_tensor = ( - vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor - ) + target_latent_tensor = vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) + target_latent_tensor = target_latent_tensor * vae.config.scaling_factor target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda() - f = [unet, text_encoder] + if accelerator.is_local_main_process and args.debug: + log_tensor_meta("target_latent_tensor", target_latent_tensor) - time_list = select_timestep( - args, - f, - tokenizer, - noise_scheduler, - vae, - perturbed_data, - original_data, - target_latent_tensor, - ) - for t in time_list: - print(t) + f = [unet, text_encoder] + # ------------------------- + # Select timesteps + # ------------------------- + if accelerator.is_local_main_process: + logger.info("[phase] select_timestep begin") + time_list = select_timestep(args, f, tokenizer, noise_scheduler, vae, perturbed_data, original_data, target_latent_tensor) + if accelerator.is_local_main_process: + logger.info("[phase] select_timestep end") + if args.debug: + for i, t in enumerate(time_list[: min(10, len(time_list))]): + logger.info(f"[time_list] idx={i} len={len(t)} first={t[0].item() if len(t)>0 else 'NA'}") + + # ------------------------- + # Main training loop + # ------------------------- for i in range(args.max_train_steps): - f_sur = copy.deepcopy(f) + if accelerator.is_local_main_process: + logger.info(f"[outer] i={i}/{args.max_train_steps}") - f_sur = train_one_epoch( - args, - f_sur, - tokenizer, - noise_scheduler, - vae, - clean_data, - args.max_f_train_steps, - ) + f_sur = copy.deepcopy(f) + f_sur = train_one_epoch(args, f_sur, tokenizer, noise_scheduler, vae, clean_data, args.max_f_train_steps, accelerator=accelerator) perturbed_data = pgd_attack( - args, - f_sur, - tokenizer, - noise_scheduler, - vae, - perturbed_data, - original_data, - target_latent_tensor, - args.max_adv_train_steps, - time_list, + args, f_sur, 
tokenizer, noise_scheduler, vae, + perturbed_data, original_data, target_latent_tensor, args.max_adv_train_steps, time_list ) - # Free surrogate ASAP (best-effort, behavior unchanged). del f_sur _cuda_gc() - f = train_one_epoch( - args, - f, - tokenizer, - noise_scheduler, - vae, - perturbed_data, - args.max_f_train_steps, - ) + f = train_one_epoch(args, f, tokenizer, noise_scheduler, vae, perturbed_data, args.max_f_train_steps, accelerator=accelerator) if (i + 1) % args.checkpointing_iterations == 0: save_folder = args.output_dir @@ -1059,9 +930,9 @@ def main(args): .numpy() ).save(save_path) - print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)") + if accelerator.is_local_main_process: + logger.info(f"[save] step={i+1} saved={len(img_names)} to {save_folder}") - # Best-effort cleanup at the end of each outer iteration. _cuda_gc() -- 2.34.1 From e447ec0984cf70424844a738f07feedb9447e7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Wed, 7 Jan 2026 14:38:45 +0800 Subject: [PATCH 3/5] =?UTF-8?q?Revert=20"improve:=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E7=AE=97=E6=B3=95"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit b5af0d22ab26a382ca80c9f9facf24d93f39c7ca. --- .../app/algorithms/perturbation/aspl.py | 648 ++++++----- .../app/algorithms/perturbation/caat.py | 746 ++++++------ .../app/algorithms/perturbation/simac.py | 1001 ++++++++++------- 3 files changed, 1337 insertions(+), 1058 deletions(-) diff --git a/src/backend/app/algorithms/perturbation/aspl.py b/src/backend/app/algorithms/perturbation/aspl.py index c96ae16..6f26194 100644 --- a/src/backend/app/algorithms/perturbation/aspl.py +++ b/src/backend/app/algorithms/perturbation/aspl.py @@ -1,11 +1,9 @@ import argparse import copy -import gc import hashlib import itertools import logging import os -import random from pathlib import Path import datasets @@ -26,77 +24,12 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig + logger = get_logger(__name__) -# ----------------------------- -# Lightweight debug helpers (low overhead) -# ----------------------------- -def _cuda_gc() -> None: - """Best-effort CUDA memory cleanup (does not change algorithmic behavior).""" - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def _fmt_bytes(n: int) -> str: - return f"{n / (1024**2):.1f}MB" - - -def log_cuda(prefix: str, accelerator: Accelerator | None = None, sync: bool = False, extra: dict | None = None): - """Log CUDA memory stats without copying tensors to CPU.""" - if not torch.cuda.is_available(): - logger.info(f"[mem] {prefix} cuda_not_available") - return - if sync: - torch.cuda.synchronize() - alloc = torch.cuda.memory_allocated() - reserv = torch.cuda.memory_reserved() - max_alloc = torch.cuda.max_memory_allocated() - max_reserv = torch.cuda.max_memory_reserved() - dev = str(accelerator.device) if accelerator is not None else "cuda" - msg = ( - f"[mem] {prefix} dev={dev} alloc={_fmt_bytes(alloc)} reserv={_fmt_bytes(reserv)} " - f"max_alloc={_fmt_bytes(max_alloc)} max_reserv={_fmt_bytes(max_reserv)}" - ) - if extra: - msg += " " + " ".join([f"{k}={v}" for k, v in extra.items()]) - logger.info(msg) - - -def log_path_stats(prefix: str, p: Path): - """Log directory/file existence and file count (best-effort).""" - try: - exists = p.exists() - is_dir = p.is_dir() if exists else False - n_files = 0 - if exists and is_dir: 
- n_files = sum(1 for x in p.iterdir() if x.is_file()) - logger.info(f"[path] {prefix} path={str(p)} exists={exists} is_dir={is_dir} files={n_files}") - except Exception as e: - logger.info(f"[path] {prefix} path={str(p)} stat_error={repr(e)}") - - -def log_args(args): - for k in sorted(vars(args).keys()): - logger.info(f"[args] {k}={getattr(args, k)}") - - -def log_tensor_meta(prefix: str, t: torch.Tensor | None): - if t is None: - logger.info(f"[tensor] {prefix} None") - return - logger.info( - f"[tensor] {prefix} shape={tuple(t.shape)} dtype={t.dtype} device={t.device} " - f"requires_grad={t.requires_grad} is_leaf={t.is_leaf}" - ) - - -# ----------------------------- -# Dataset -# ----------------------------- class DreamBoothDatasetFromTensor(Dataset): - """基于内存张量的 DreamBooth 数据集:直接使用张量输入,返回图像与对应 prompt token。""" + """Just like DreamBoothDataset, but take instance_images_tensor instead of path""" def __init__( self, @@ -120,19 +53,10 @@ class DreamBoothDatasetFromTensor(Dataset): if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) - # Only keep files to avoid directories affecting length. - self.class_images_path = [p for p in self.class_data_root.iterdir() if p.is_file()] + self.class_images_path = list(self.class_data_root.iterdir()) self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt - - # Early, explicit failure instead of ZeroDivisionError later. - if self.num_class_images == 0: - raise ValueError( - f"class_data_dir is empty: {self.class_data_root}. " - f"Prior preservation requires class images. " - f"Please generate class images first, or fix class_data_dir, or disable --with_prior_preservation." 
- ) else: self.class_data_root = None @@ -161,10 +85,8 @@ class DreamBoothDatasetFromTensor(Dataset): ).input_ids if self.class_data_root: - if self.num_class_images == 0: - raise ValueError(f"class_data_dir became empty at runtime: {self.class_data_root}") class_image = Image.open(self.class_images_path[index % self.num_class_images]) - if class_image.mode != "RGB": + if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) example["class_prompt_ids"] = self.tokenizer( @@ -178,9 +100,6 @@ class DreamBoothDatasetFromTensor(Dataset): return example -# ----------------------------- -# Model helper -# ----------------------------- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -201,47 +120,217 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st raise ValueError(f"{model_class} is not supported.") -# ----------------------------- -# Args -# ----------------------------- def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True) - parser.add_argument("--revision", type=str, default=None, required=False) - parser.add_argument("--tokenizer_name", type=str, default=None) - parser.add_argument("--instance_data_dir_for_train", type=str, default=None, required=True) - parser.add_argument("--instance_data_dir_for_adversarial", type=str, default=None, required=True) - parser.add_argument("--class_data_dir", type=str, default=None, required=False) - parser.add_argument("--instance_prompt", type=str, default=None, required=True) - parser.add_argument("--class_prompt", type=str, default=None) - parser.add_argument("--with_prior_preservation", default=False, action="store_true") - parser.add_argument("--prior_loss_weight", type=float, default=1.0) - parser.add_argument("--num_class_images", type=int, default=100) - parser.add_argument("--output_dir", type=str, default="text-inversion-model") - parser.add_argument("--seed", type=int, default=None) - parser.add_argument("--resolution", type=int, default=512) - parser.add_argument("--center_crop", default=False, action="store_true") - parser.add_argument("--train_text_encoder", action="store_true") - parser.add_argument("--train_batch_size", type=int, default=4) - parser.add_argument("--sample_batch_size", type=int, default=8) - parser.add_argument("--max_train_steps", type=int, default=20) - parser.add_argument("--max_f_train_steps", type=int, default=10) - parser.add_argument("--max_adv_train_steps", type=int, default=10) - parser.add_argument("--checkpointing_iterations", type=int, default=5) - parser.add_argument("--learning_rate", type=float, default=5e-6) - parser.add_argument("--logging_dir", type=str, default="logs") - parser.add_argument("--allow_tf32", action="store_true") - parser.add_argument("--report_to", type=str, default="tensorboard") - parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"]) - parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true") - parser.add_argument("--pgd_alpha", type=float, default=1.0 / 255) - parser.add_argument("--pgd_eps", type=int, default=0.05) - parser.add_argument("--target_image_path", default=None) - - # Debug / diagnostics 
(low-overhead) - parser.add_argument("--debug", action="store_true", help="Enable detailed logs for failure points.") - parser.add_argument("--debug_cuda_sync", action="store_true", help="Synchronize CUDA for more accurate mem logs.") - parser.add_argument("--debug_step0_only", action="store_true", help="Only print per-step logs for step 0.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." + ), + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir_for_train", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--instance_data_dir_for_adversarial", + type=str, + default=None, + required=True, + help="A folder containing the images to add adversarial noise", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. 
If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=8, + help="Batch size (per device) for sampling images.", + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=20, + help="Total number of training steps to perform.", + ) + parser.add_argument( + "--max_f_train_steps", + type=int, + default=10, + help="Total number of sub-steps to train surogate model.", + ) + parser.add_argument( + "--max_adv_train_steps", + type=int, + default=10, + help="Total number of sub-steps to train adversarial noise.", + ) + parser.add_argument( + "--checkpointing_iterations", + type=int, + default=5, + help=("Save a checkpoint of the training state every X iterations."), + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="fp16", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", + ) + parser.add_argument( + "--pgd_alpha", + type=float, + default=1.0 / 255, + help="The step size for pgd.", + ) + parser.add_argument( + "--pgd_eps", + type=int, + default=0.05, + help="The noise budget for pgd.", + ) + parser.add_argument( + "--target_image_path", + default=None, + help="target image for attacking", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -251,11 +340,8 @@ def parse_args(input_args=None): return args -# ----------------------------- -# Class image prompt dataset -# ----------------------------- class PromptDataset(Dataset): - """用于批量生成 class 图像的提示词数据集,可在多 GPU 环境下并行采样。""" + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
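+    # Each example is {"prompt": str, "index": int}; the sampling loop later offsets
+    # the index by cur_class_images, so freshly generated files never collide with
+    # images already present in class_data_dir.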
def __init__(self, prompt, num_samples): self.prompt = prompt @@ -271,9 +357,6 @@ class PromptDataset(Dataset): return example -# ----------------------------- -# IO -# ----------------------------- def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: image_transforms = transforms.Compose( [ @@ -289,10 +372,17 @@ def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: return images -# ----------------------------- -# Core routines -# ----------------------------- -def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: torch.Tensor, num_steps=20): +def train_one_epoch( + args, + models, + tokenizer, + noise_scheduler, + vae, + data_tensor: torch.Tensor, + num_steps=20, +): + # Load the tokenizer + unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1]) params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) @@ -304,17 +394,17 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: eps=1e-08, ) - # IMPORTANT: only pass class_data_dir when with_prior_preservation is enabled. train_dataset = DreamBoothDatasetFromTensor( data_tensor, args.instance_prompt, tokenizer, - args.class_data_dir if args.with_prior_preservation else None, + args.class_data_dir, args.class_prompt, args.resolution, args.center_crop, ) + # weight_dtype = torch.bfloat16 weight_dtype = torch.bfloat16 device = torch.device("cuda") @@ -326,35 +416,33 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: unet.train() text_encoder.train() - try: - step_data = train_dataset[step % len(train_dataset)] - except Exception as e: - logger.error(f"[err] train_one_epoch dataset getitem failed at step={step}: {repr(e)}") - raise - - try: - pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to( - device, dtype=weight_dtype - ) - input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device) - except KeyError as e: - logger.error( - f"[err] missing key in step_data at step={step}: missing={str(e)}. 
" - f"with_prior_preservation={args.with_prior_preservation}" - ) - raise + step_data = train_dataset[step % len(train_dataset)] + pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to( + device, dtype=weight_dtype + ) + input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device) latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor + # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long() + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # Get the text embedding for conditioning encoder_hidden_states = text_encoder(input_ids)[0] + + # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -362,37 +450,47 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + # with prior preservation loss if args.with_prior_preservation: model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) + # Compute instance loss instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. 
loss = instance_loss + args.prior_loss_weight * prior_loss + else: - prior_loss = torch.tensor(0.0, device=device) - instance_loss = torch.tensor(0.0, device=device) loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss.backward() torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True) optimizer.step() optimizer.zero_grad() - - logger.info( - f"[train_one_epoch] step={step} loss={loss.detach().item():.6f} " - f"prior={prior_loss.detach().item():.6f} inst={instance_loss.detach().item():.6f}" + print( + f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, instance_loss: {instance_loss.detach().item()}" ) - del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states - del model_pred, target, loss, prior_loss, instance_loss - - del optimizer, train_dataset, params_to_optimize - _cuda_gc() return [unet, text_encoder] -def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, original_images, target_tensor, num_steps: int): +def pgd_attack( + args, + models, + tokenizer, + noise_scheduler, + vae, + data_tensor: torch.Tensor, + original_images: torch.Tensor, + target_tensor: torch.Tensor, + num_steps: int, +): + """Return new perturbed data""" + unet, text_encoder = models weight_dtype = torch.bfloat16 device = torch.device("cuda") @@ -417,14 +515,24 @@ def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, origi latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor + # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long() + # Sample a random timestep for each image + #noise_scheduler.config.num_train_timesteps + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # Get the text embedding for conditioning encoder_hidden_states = text_encoder(input_ids.to(device))[0] + + # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -432,10 +540,11 @@ def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, origi else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - unet.zero_grad(set_to_none=True) - text_encoder.zero_grad(set_to_none=True) + unet.zero_grad() + text_encoder.zero_grad() loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + # target-shift loss if target_tensor is not None: xtm1_pred = torch.cat( [ @@ -452,25 +561,16 @@ def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, origi loss.backward() - alpha = args.pgd_alpha + alpha = args.pgd_alpha eps = args.pgd_eps / 255 adv_images = perturbed_images + alpha * perturbed_images.grad.sign() eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps) perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_() - 
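The update above is the standard L-infinity projected signed-gradient step that pgd_attack alternates with train_one_epoch in the outer loop; note the script takes eps on a 0-255 scale and rescales it via eps = args.pgd_eps / 255 before clamping. A minimal self-contained sketch of the same projection logic, with a placeholder loss_fn standing in for the denoising MSE computed in this hunk:

import torch

def pgd_step(x_adv: torch.Tensor, x_orig: torch.Tensor, loss_fn, alpha: float, eps: float) -> torch.Tensor:
    """One PGD iteration: ascend the loss, then project back into the L-inf ball around x_orig."""
    x_adv = x_adv.detach().requires_grad_(True)
    loss = loss_fn(x_adv)                              # placeholder for the diffusion loss above
    loss.backward()
    stepped = x_adv + alpha * x_adv.grad.sign()        # signed-gradient ascent
    eta = torch.clamp(stepped - x_orig, -eps, eps)     # enforce ||eta||_inf <= eps
    return torch.clamp(x_orig + eta, -1.0, 1.0).detach()  # stay in the VAE's [-1, 1] pixel range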
- logger.info(f"[pgd] step={step} loss={loss.detach().item():.6f} alpha={alpha} eps={eps}") - - del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss - del adv_images, eta - - _cuda_gc() + print(f"PGD loss - step {step}, loss: {loss.detach().item()}") return perturbed_images -# ----------------------------- -# Main -# ----------------------------- def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -486,7 +586,6 @@ def main(args): level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_warning() @@ -496,35 +595,15 @@ def main(args): transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() - if accelerator.is_local_main_process: - logger.info(f"[run] using_file={__file__}") - log_args(args) - if args.seed is not None: set_seed(args.seed) - if args.debug and accelerator.is_local_main_process: - log_cuda("startup", accelerator, sync=args.debug_cuda_sync) - - # ------------------------- - # Prior preservation: generate class images if needed - # ------------------------- + # Generate class images if prior preservation is enabled. if args.with_prior_preservation: - if args.class_data_dir is None: - raise ValueError("--with_prior_preservation requires --class_data_dir") - if args.class_prompt is None: - raise ValueError("--with_prior_preservation requires --class_prompt") - class_images_dir = Path(args.class_data_dir) - class_images_dir.mkdir(parents=True, exist_ok=True) - - if accelerator.is_local_main_process: - log_path_stats("class_dir_before", class_images_dir) - - cur_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file()) - if accelerator.is_local_main_process: - logger.info(f"[class_gen] cur_class_images={cur_class_images} target={args.num_class_images}") - + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 if args.mixed_precision == "fp32": @@ -533,12 +612,6 @@ def main(args): torch_dtype = torch.float16 elif args.mixed_precision == "bf16": torch_dtype = torch.bfloat16 - - if accelerator.is_local_main_process: - logger.info(f"[class_gen] will_generate={args.num_class_images - cur_class_images} torch_dtype={torch_dtype}") - if args.debug: - log_cuda("before_pipeline_load", accelerator, sync=args.debug_cuda_sync) - pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -548,67 +621,56 @@ def main(args): pipeline.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - if accelerator.is_local_main_process and args.debug: - log_cuda("after_pipeline_to_device", accelerator, sync=args.debug_cuda_sync) - for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images - if accelerator.is_local_main_process 
and args.debug: - logger.info(f"[class_gen] batch_prompts={len(example['prompt'])} generated_images={len(images)}") for i, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) - del pipeline, sample_dataset, sample_dataloader - _cuda_gc() - - accelerator.wait_for_everyone() - - final_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file()) - if accelerator.is_local_main_process: - logger.info(f"[class_gen] done final_class_images={final_class_images}") - log_path_stats("class_dir_after", class_images_dir) - if final_class_images == 0: - raise RuntimeError(f"class image generation failed: {class_images_dir} is still empty.") - - else: - accelerator.wait_for_everyone() - if accelerator.is_local_main_process: - logger.info("[class_gen] skipped (already enough images)") - else: - if accelerator.is_local_main_process: - logger.info("[class_gen] disabled (with_prior_preservation is False)") + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() - # ------------------------- - # Load models / tokenizer / scheduler / VAE - # ------------------------- + # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - if accelerator.is_local_main_process and args.debug: - log_cuda("before_load_models", accelerator, sync=args.debug_cuda_sync) - + # Load scheduler and models text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, ) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) + tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).cuda() + + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ).cuda() + vae.requires_grad_(False) if not args.train_text_encoder: @@ -617,60 +679,52 @@ def main(args): if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True + clean_data = load_data( + args.instance_data_dir_for_train, + size=args.resolution, + center_crop=args.center_crop, + ) + perturbed_data = load_data( + args.instance_data_dir_for_adversarial, + size=args.resolution, + center_crop=args.center_crop, + ) + original_data = perturbed_data.clone() + original_data.requires_grad_(False) + if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() - if accelerator.is_local_main_process: - logger.info("[xformers] enabled") else: raise ValueError("xformers is not available. 
Make sure it is installed correctly") - if accelerator.is_local_main_process and args.debug: - log_cuda("after_load_models", accelerator, sync=args.debug_cuda_sync) - - # ------------------------- - # Load data tensors - # ------------------------- - train_dir = Path(args.instance_data_dir_for_train) - adv_dir = Path(args.instance_data_dir_for_adversarial) - if accelerator.is_local_main_process and args.debug: - log_path_stats("train_dir", train_dir) - log_path_stats("adv_dir", adv_dir) - - clean_data = load_data(train_dir, size=args.resolution, center_crop=args.center_crop) - perturbed_data = load_data(adv_dir, size=args.resolution, center_crop=args.center_crop) - original_data = perturbed_data.clone() - original_data.requires_grad_(False) - - if accelerator.is_local_main_process and args.debug: - log_tensor_meta("clean_data_cpu", clean_data) - log_tensor_meta("perturbed_data_cpu", perturbed_data) - target_latent_tensor = None if args.target_image_path is not None: target_image_path = Path(args.target_image_path) - if not target_image_path.is_file(): - raise ValueError(f"Target image path does not exist: {target_image_path}") + assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist" target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution)) target_image = np.array(target_image)[None].transpose(0, 3, 1, 2) target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0 - target_latent_tensor = vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) - target_latent_tensor = target_latent_tensor * vae.config.scaling_factor + target_latent_tensor = ( + vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor + ) target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda() - if accelerator.is_local_main_process and args.debug: - log_tensor_meta("target_latent_tensor", target_latent_tensor) - f = [unet, text_encoder] for i in range(args.max_train_steps): - if accelerator.is_local_main_process: - logger.info(f"[outer] i={i}/{args.max_train_steps}") - + # 1. 
f' = f.clone() f_sur = copy.deepcopy(f) - f_sur = train_one_epoch(args, f_sur, tokenizer, noise_scheduler, vae, clean_data, args.max_f_train_steps) - + f_sur = train_one_epoch( + args, + f_sur, + tokenizer, + noise_scheduler, + vae, + clean_data, + args.max_f_train_steps, + ) perturbed_data = pgd_attack( args, f_sur, @@ -682,31 +736,33 @@ def main(args): target_latent_tensor, args.max_adv_train_steps, ) - - f = train_one_epoch(args, f, tokenizer, noise_scheduler, vae, perturbed_data, args.max_f_train_steps) + f = train_one_epoch( + args, + f, + tokenizer, + noise_scheduler, + vae, + perturbed_data, + args.max_f_train_steps, + ) if (i + 1) % args.checkpointing_iterations == 0: save_folder = args.output_dir os.makedirs(save_folder, exist_ok=True) noised_imgs = perturbed_data.detach() - - img_filenames = [Path(instance_path).stem for instance_path in list(adv_dir.iterdir()) if instance_path.is_file()] + + img_filenames = [ + Path(instance_path).stem + for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir()) + ] for img_pixel, img_name in zip(noised_imgs, img_filenames): - save_path = os.path.join(save_folder, f"perturbed_{img_name}.png") + save_path = os.path.join(save_folder, f"perturbed_{img_name}.png") Image.fromarray( - (img_pixel * 127.5 + 128) - .clamp(0, 255) - .to(torch.uint8) - .permute(1, 2, 0) - .cpu() - .numpy() + (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy() ).save(save_path) - - if accelerator.is_local_main_process: - logger.info(f"[save] step={i+1} saved={len(img_filenames)} to {save_folder}") - - _cuda_gc() + + print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)") if __name__ == "__main__": diff --git a/src/backend/app/algorithms/perturbation/caat.py b/src/backend/app/algorithms/perturbation/caat.py index b5c331a..c7e41cd 100644 --- a/src/backend/app/algorithms/perturbation/caat.py +++ b/src/backend/app/algorithms/perturbation/caat.py @@ -64,9 +64,8 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st else: raise ValueError(f"{model_class} is not supported.") - class PromptDataset(Dataset): - """用于批量生成 class 图像的 prompt 数据集。""" + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." def __init__(self, prompt, num_samples): self.prompt = prompt @@ -83,7 +82,10 @@ class PromptDataset(Dataset): class CustomDiffusionDataset(Dataset): - """CAAT/Custom Diffusion 训练数据集。""" + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. 
+ """ def __init__( self, @@ -131,7 +133,6 @@ class CustomDiffusionDataset(Dataset): self.num_instance_images = len(self.instance_images_path) self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) - self.flip = transforms.RandomHorizontalFlip(0.5 * hflip) self.image_transforms = transforms.Compose( @@ -164,20 +165,19 @@ class CustomDiffusionDataset(Dataset): else: instance_image[top : top + inner, left : left + inner, :] = image mask[ - top // factor + 1 : (top + scale) // factor - 1, - left // factor + 1 : (left + scale) // factor - 1, + top // factor + 1 : (top + scale) // factor - 1, left // factor + 1 : (left + scale) // factor - 1 ] = 1.0 return instance_image, mask def __getitem__(self, index): example = {} - instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images] instance_image = Image.open(instance_image) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") instance_image = self.flip(instance_image) + # apply resize augmentation and create a valid image region mask random_scale = self.size if self.aug: random_scale = ( @@ -220,59 +220,270 @@ class CustomDiffusionDataset(Dataset): return example + def parse_args(input_args=None): - """解析 CAAT 训练参数。""" parser = argparse.ArgumentParser(description="CAAT training script.") - - parser.add_argument("--alpha", type=float, default=5e-3, required=True, help="PGD alpha.") - parser.add_argument("--eps", type=float, default=0.1, required=True, help="PGD eps.") - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True) - parser.add_argument("--revision", type=str, default=None, required=False) - parser.add_argument("--tokenizer_name", type=str, default=None) - parser.add_argument("--instance_data_dir", type=str, default=None) - parser.add_argument("--class_data_dir", type=str, default=None) - parser.add_argument("--instance_prompt", type=str, default=None) - parser.add_argument("--class_prompt", type=str, default=None) - parser.add_argument("--with_prior_preservation", default=False, action="store_true") - parser.add_argument("--prior_loss_weight", type=float, default=1.0) - parser.add_argument("--num_class_images", type=int, default=200) - parser.add_argument("--output_dir", type=str, default="outputs") - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--resolution", type=int, default=512) - parser.add_argument("--center_crop", default=False, action="store_true") - parser.add_argument("--sample_batch_size", type=int, default=4) - parser.add_argument("--max_train_steps", type=int, default=250) - parser.add_argument("--checkpointing_steps", type=int, default=250) - parser.add_argument("--checkpoints_total_limit", type=int, default=None) - parser.add_argument("--gradient_checkpointing", action="store_true") - parser.add_argument("--learning_rate", type=float, default=1e-5) - parser.add_argument("--dataloader_num_workers", type=int, default=2) - parser.add_argument("--freeze_model", type=str, default="crossattn_kv", choices=["crossattn_kv", "crossattn"]) - parser.add_argument("--lr_scheduler", type=str, default="constant") - parser.add_argument("--lr_warmup_steps", type=int, default=500) - parser.add_argument("--use_8bit_adam", action="store_true") - parser.add_argument("--adam_beta1", type=float, default=0.9) - parser.add_argument("--adam_beta2", type=float, default=0.999) - parser.add_argument("--adam_weight_decay", type=float, default=1e-2) - 
parser.add_argument("--adam_epsilon", type=float, default=1e-08) - parser.add_argument("--max_grad_norm", default=1.0, type=float) - parser.add_argument("--hub_model_id", type=str, default=None) - parser.add_argument("--logging_dir", type=str, default="logs") - parser.add_argument("--allow_tf32", action="store_true") - parser.add_argument("--report_to", type=str, default="tensorboard") - parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"]) - parser.add_argument("--prior_generation_precision", type=str, default=None, choices=["no", "fp32", "fp16", "bf16"]) - parser.add_argument("--concepts_list", type=str, default=None) - parser.add_argument("--local_rank", type=int, default=-1) - parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true") - parser.add_argument("--set_grads_to_none", action="store_true") - parser.add_argument("--initializer_token", type=str, default="ktn+pll+ucd") - parser.add_argument("--hflip", action="store_true") - parser.add_argument("--noaug", action="store_true") - - parser.add_argument("--debug_oom", action="store_true", help="开启显存与关键张量日志,用于定位第0步OOM") - parser.add_argument("--debug_oom_sync", action="store_true", help="打印前强制同步CUDA,日志更准但更慢") - parser.add_argument("--debug_oom_step0_only", action="store_true", help="只打印第0步相关日志,降低干扰") + parser.add_argument( + "--alpha", + type=float, + default=5e-3, + required=True, + help="PGD alpha.", + ) + parser.add_argument( + "--eps", + type=float, + default=0.1, + required=True, + help="PGD eps.", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss." + ) + parser.add_argument( + "--num_class_images", + type=int, + default=200, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="outputs", + help="The output directory.", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="A seed for reproducible training." 
+ ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=250, + help="Total number of training steps to perform.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=250, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=2, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--freeze_model", + type=str, + default="crossattn_kv", + choices=["crossattn_kv", "crossattn"], + help="crossattn to enable fine-tuning of all params in the cross attention", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. 
Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument( + "--concepts_list", + type=str, + default=None, + help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.", + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--initializer_token", type=str, default="ktn+pll+ucd", help="A token to use as initializer word." 
+ ) + parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.") + parser.add_argument( + "--noaug", + action="store_true", + help="Dont apply augmentation during data augmentation when this flag is enabled.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -290,6 +501,7 @@ def parse_args(input_args=None): if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: + # logger is not available yet if args.class_data_dir is not None: warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: @@ -298,76 +510,9 @@ def parse_args(input_args=None): return args -def _fmt_bytes(n: int) -> str: - gb = n / (1024**3) - return f"{gb:.2f}GB" - - -def _debug_should_print(args, global_step: int) -> bool: - if not args.debug_oom: - return False - if args.debug_oom_step0_only: - return global_step == 0 - return True - - -def log_cuda(prefix: str, args, accelerator: Accelerator, extra: dict | None = None): - """打印CUDA显存状态与可选附加信息,不改变训练逻辑。""" - if not args.debug_oom: - return - if not torch.cuda.is_available(): - logger.info(f"[mem] {prefix} cuda_not_available") - return - - if args.debug_oom_sync: - torch.cuda.synchronize() - - allocated = torch.cuda.memory_allocated() - reserved = torch.cuda.memory_reserved() - max_alloc = torch.cuda.max_memory_allocated() - max_reserved = torch.cuda.max_memory_reserved() - - msg = ( - f"[mem] {prefix} " - f"alloc={_fmt_bytes(allocated)} reserv={_fmt_bytes(reserved)} " - f"max_alloc={_fmt_bytes(max_alloc)} max_reserv={_fmt_bytes(max_reserved)} " - f"device={accelerator.device}" - ) - if extra: - kv = " ".join([f"{k}={v}" for k, v in extra.items()]) - msg = msg + " " + kv - - logger.info(msg) - - -def log_tensor(prefix: str, t: torch.Tensor | None, args, accelerator: Accelerator): - """打印张量的shape/dtype/device/grad状态,避免误把大张量复制到CPU。""" - if not args.debug_oom: - return - if t is None: - logger.info(f"[tensor] {prefix} None") - return - logger.info( - f"[tensor] {prefix} shape={tuple(t.shape)} dtype={t.dtype} device={t.device} " - f"requires_grad={t.requires_grad} is_leaf={t.is_leaf}" - ) - - -def log_trainable_params(prefix: str, module: torch.nn.Module, args): - """打印模块可训练参数规模,确认是否意外训练了大量参数。""" - if not args.debug_oom: - return - trainable = [(n, p.numel(), str(p.dtype), str(p.device)) for n, p in module.named_parameters() if p.requires_grad] - total = sum(x[1] for x in trainable) - logger.info(f"[trainable] {prefix} tensors={len(trainable)} total_params={total}") - for n, numel, dtype, dev in trainable[:30]: - logger.info(f"[trainable] {prefix} name={n} numel={numel} dtype={dtype} device={dev}") - if len(trainable) > 30: - logger.info(f"[trainable] {prefix} ... 
(total {len(trainable)} trainable tensors)") - - def main(args): logging_dir = Path(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( @@ -381,7 +526,6 @@ def main(args): datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) - logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() @@ -392,17 +536,9 @@ def main(args): accelerator.init_trackers("CAAT", config=vars(args)) - if accelerator.is_local_main_process: - logger.info("========== CAAT 参数 ==========") - for k in sorted(vars(args).keys()): - logger.info(f"{k}: {getattr(args, k)}") - logger.info("===============================") - - log_cuda("startup", args, accelerator) - + # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) - if args.concepts_list is None: args.concepts_list = [ { @@ -416,6 +552,7 @@ def main(args): with open(args.concepts_list, "r") as f: args.concepts_list = json.load(f) + # Generate class images if prior preservation is enabled. if args.with_prior_preservation: for i, concept in enumerate(args.concepts_list): class_images_dir = Path(concept["class_data_dir"]) @@ -423,6 +560,7 @@ def main(args): class_images_dir.mkdir(parents=True, exist_ok=True) cur_class_images = len(list(class_images_dir.iterdir())) + if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 if args.prior_generation_precision == "fp32": @@ -431,9 +569,6 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 - - log_cuda("before_prior_pipeline_load", args, accelerator, extra={"torch_dtype": str(torch_dtype)}) - pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -442,90 +577,115 @@ def main(args): ) pipeline.set_progress_bar_config(disable=True) - sample_dataset = PromptDataset(args.class_prompt, args.num_class_images - cur_class_images) + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - log_cuda("after_prior_pipeline_to_device", args, accelerator) - for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images + for i, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image_filename = ( + class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) image.save(image_filename) del pipeline if torch.cuda.is_available(): torch.cuda.empty_cache() - log_cuda("after_prior_pipeline_del", args, accelerator) - if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + # Load the tokenizer if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + 
use_fast=False, + ) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) + # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) - log_cuda("after_load_models_cpu_or_meta", args, accelerator) vae.requires_grad_(False) text_encoder.requires_grad_(False) unet.requires_grad_(False) - + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - if accelerator.is_local_main_process and args.debug_oom: - logger.info(f"[debug] weight_dtype={weight_dtype} mixed_precision={accelerator.mixed_precision}") - + # Move unet, vae and text_encoder to device and cast to weight_dtype text_encoder.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - log_cuda("after_models_to_device", args, accelerator) - attention_class = CustomDiffusionAttnProcessor if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) - logger.info(f"[debug] xformers_version={xformers_version}") if xformers_version == version.parse("0.0.16"): - logger.warning( - "xFormers 0.0.16 may be unstable for training on some GPUs; consider upgrading to >=0.0.17." + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) attention_class = CustomDiffusionXFormersAttnProcessor else: raise ValueError("xformers is not available. Make sure it is installed correctly") + # now we will add new Custom Diffusion weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. 
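+    # (The counts below assume the standard Stable Diffusion 1.x UNet; other UNet
+    # configurations have different block and transformer-layer counts.)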
+    # For Stable Diffusion, it should be equal to:
+    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+    # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
+    # => 32 layers
+
+    # Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer
     train_kv = True
     train_q_out = False if args.freeze_model == "crossattn_kv" else True
     custom_diffusion_attn_procs = {}
 
     st = unet.state_dict()
+
     for name, _ in unet.attn_processors.items():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
         if name.startswith("mid_block"):
@@ -536,9 +696,7 @@ def main(args):
         elif name.startswith("down_blocks"):
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
-
         layer_name = name.split(".processor")[0]
-
         weights = {
             "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
             "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
@@ -547,7 +705,6 @@ def main(args):
             weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
             weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
             weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
-
         if cross_attention_dim is not None:
             custom_diffusion_attn_procs[name] = attention_class(
                 train_kv=train_kv,
@@ -565,37 +722,38 @@ def main(args):
             )
 
     del st
+
     unet.set_attn_processor(custom_diffusion_attn_procs)
     custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)
+
     accelerator.register_for_checkpointing(custom_diffusion_layers)
 
-    log_trainable_params("unet_after_set_attn_processor", unet, args)
-    log_trainable_params("custom_diffusion_layers", custom_diffusion_layers, args)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-        if accelerator.is_local_main_process and args.debug_oom:
-            logger.info("[debug] gradient_checkpointing enabled")
-
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
         torch.backends.cuda.matmul.allow_tf32 = True
-        if accelerator.is_local_main_process and args.debug_oom:
-            logger.info("[debug] allow_tf32 enabled")
+
     args.learning_rate = args.learning_rate
 
     if args.with_prior_preservation:
         args.learning_rate = args.learning_rate * 2.0
 
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
         except ImportError:
-            raise ImportError("To use 8-bit Adam, please install bitsandbytes: `pip install bitsandbytes`.")
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ ) + optimizer_class = bnb.optim.AdamW8bit - if accelerator.is_local_main_process and args.debug_oom: - logger.info("[debug] using 8-bit AdamW") else: optimizer_class = torch.optim.AdamW + # Optimizer creation optimizer = optimizer_class( custom_diffusion_layers.parameters(), lr=args.learning_rate, @@ -604,29 +762,25 @@ def main(args): eps=args.adam_epsilon, ) - # 与 CAAT 代码保持一致:通过一次 VAE encode 推导 mask_size - mask_size = ( - vae.encode(torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device)) - .latent_dist.sample() - .size()[-1] - ) - if accelerator.is_local_main_process and args.debug_oom: - logger.info(f"[debug] inferred mask_size={mask_size}") - + # Dataset creation: train_dataset = CustomDiffusionDataset( concepts_list=args.concepts_list, tokenizer=tokenizer, with_prior_preservation=args.with_prior_preservation, size=args.resolution, - mask_size=mask_size, + mask_size=vae.encode( + torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device) + ) + .latent_dist.sample() + .size()[-1], center_crop=args.center_crop, num_class_images=args.num_class_images, hflip=args.hflip, aug=not args.noaug, ) - log_cuda("after_build_dataset", args, accelerator, extra={"num_instance_images": train_dataset.num_instance_images}) + # Prepare for PGD pertubed_images = [Image.open(i[0]).convert("RGB") for i in train_dataset.instance_images_path] pertubed_images = [train_dataset.image_transforms(i) for i in pertubed_images] pertubed_images = torch.stack(pertubed_images).contiguous() @@ -668,10 +822,7 @@ def main(args): mask = mask.unsqueeze(1) del images_open_list - log_tensor("pertubed_images_before_prepare", pertubed_images, args, accelerator) - log_tensor("original_images_before_prepare", original_images, args, accelerator) - log_tensor("mask_before_prepare", mask, args, accelerator) - log_tensor("input_ids_cpu", input_ids, args, accelerator) + lr_scheduler = get_scheduler( args.lr_scheduler, @@ -680,172 +831,106 @@ def main(args): num_training_steps=args.max_train_steps * accelerator.num_processes, ) - log_cuda("before_accelerator_prepare", args, accelerator) - custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask = accelerator.prepare( custom_diffusion_layers, optimizer, pertubed_images, lr_scheduler, original_images, mask ) - log_cuda("after_accelerator_prepare", args, accelerator) - log_tensor("pertubed_images_after_prepare", pertubed_images, args, accelerator) - log_tensor("original_images_after_prepare", original_images, args, accelerator) - log_tensor("mask_after_prepare", mask, args, accelerator) + # Train! logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num pertubed_images = {len(pertubed_images)}") logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 first_epoch = 0 + # Only show the progress bar once on each machine. 
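+    # (The bar is created on every process but disabled away from the local main process.)
+    # Each step below performs two coupled updates: the optimizer fine-tunes the
+    # Custom Diffusion K/V attention layers, and a PGD ascent step perturbs the images:
+    #     adv = perturbed + alpha * sign(grad(loss, perturbed))
+    #     perturbed = clamp(original + clamp(adv - original, -eps, +eps), -1, +1)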
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") - for epoch in range(first_epoch, args.max_train_steps): unet.train() - for _ in range(1): with accelerator.accumulate(unet), accelerator.accumulate(text_encoder): - if _debug_should_print(args, global_step): - log_cuda("step_begin", args, accelerator, extra={"global_step": global_step}) - logger.info(f"[debug] step={global_step} starting forward path") - - # 关键定位:你说每次“加噪第0步开始前就爆”,这里把每个子阶段都打点 - try: - pertubed_images.requires_grad = True - if _debug_should_print(args, global_step): - log_tensor("pertubed_images_pre_vae", pertubed_images, args, accelerator) - log_cuda("before_vae_encode", args, accelerator, extra={"global_step": global_step}) + # Convert images to latent space + pertubed_images.requires_grad = True + latents = vae.encode(pertubed_images.to(accelerator.device).to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - latents_dist = vae.encode( - pertubed_images.to(accelerator.device).to(dtype=weight_dtype) - ).latent_dist + # unet.zero_grad() + # text_encoder.zero_grad() - if _debug_should_print(args, global_step): - log_cuda("after_vae_encode", args, accelerator, extra={"global_step": global_step}) + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
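+                # With prior preservation the batch is the concatenation
+                # [instance_images; class_images], so the first chunk is the instance
+                # prediction and the second is the prior (class) prediction.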
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + mask = torch.chunk(mask, 2, dim=0)[0].to(accelerator.device) + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() - latents = latents_dist.sample() - latents = latents * vae.config.scaling_factor - - if _debug_should_print(args, global_step): - log_tensor("latents", latents, args, accelerator) - log_cuda("after_latents_sample", args, accelerator, extra={"global_step": global_step}) + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - noise = torch.randn_like(latents) - bsz = latents.shape[0] - - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device - ).long() - - if _debug_should_print(args, global_step): - log_tensor("noise", noise, args, accelerator) - log_tensor("timesteps", timesteps, args, accelerator) - log_cuda("before_add_noise", args, accelerator, extra={"global_step": global_step}) - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - if _debug_should_print(args, global_step): - log_tensor("noisy_latents", noisy_latents, args, accelerator) - log_cuda("after_add_noise", args, accelerator, extra={"global_step": global_step}) - - if _debug_should_print(args, global_step): - log_cuda("before_text_encoder", args, accelerator, extra={"global_step": global_step}) - - encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0] - - if _debug_should_print(args, global_step): - log_tensor("encoder_hidden_states", encoder_hidden_states, args, accelerator) - log_cuda("after_text_encoder", args, accelerator, extra={"global_step": global_step}) - - if _debug_should_print(args, global_step): - log_cuda("before_unet_forward", args, accelerator, extra={"global_step": global_step}) - - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if _debug_should_print(args, global_step): - log_tensor("model_pred", model_pred, args, accelerator) - log_cuda("after_unet_forward", args, accelerator, extra={"global_step": global_step}) - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.with_prior_preservation: - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - mask_inst = torch.chunk(mask, 2, dim=0)[0].to(accelerator.device) - - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = ((loss * mask_inst).sum([1, 2, 3]) / mask_inst.sum([1, 2, 3])).mean() - - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - loss = loss + args.prior_loss_weight * prior_loss - else: - mask_inst = mask.to(accelerator.device) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.mean() - - if _debug_should_print(args, global_step): - logger.info(f"[debug] step={global_step} loss_value={loss.detach().float().item()}") - log_cuda("before_backward", args, accelerator, extra={"global_step": global_step}) - - accelerator.backward(loss) - - if _debug_should_print(args, global_step): - 
log_cuda("after_backward", args, accelerator, extra={"global_step": global_step}) - if pertubed_images.grad is None: - logger.info(f"[debug] step={global_step} pertubed_images.grad=None") - else: - logger.info( - f"[debug] step={global_step} pertubed_images.grad_abs_mean=" - f"{pertubed_images.grad.abs().mean().item():.6e}" - ) - - if accelerator.sync_gradients: - params_to_clip = custom_diffusion_layers.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - alpha = args.alpha - eps = args.eps - - if _debug_should_print(args, global_step): - logger.info(f"[debug] step={global_step} alpha={alpha} eps={eps}") - log_cuda("before_pgd_update", args, accelerator, extra={"global_step": global_step}) + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + mask = mask.to(accelerator.device) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # torch.Size([5, 4, 64, 64]) - adv_images = pertubed_images + alpha * pertubed_images.grad.sign() - eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps) - pertubed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_() + #loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() + loss = loss.mean() - if _debug_should_print(args, global_step): - log_tensor("pertubed_images_after_pgd", pertubed_images, args, accelerator) - log_cuda("after_pgd_update", args, accelerator, extra={"global_step": global_step}) + accelerator.backward(loss) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=args.set_grads_to_none) - - if _debug_should_print(args, global_step): - log_cuda("after_optimizer_step", args, accelerator, extra={"global_step": global_step}) - except RuntimeError as e: - # 捕获OOM并打印尽可能多的上下文,便于你定位爆显存发生在哪个子阶段 - if "out of memory" in str(e).lower() or "cuda" in str(e).lower(): - logger.error(f"[OOM] step={global_step} caught RuntimeError: {e}") - log_cuda("oom_caught", args, accelerator, extra={"global_step": global_step}) - logger.error("[OOM] 如果你看到oom发生在 before_unet_forward/after_unet_forward 附近,通常是UNet前向峰值") - logger.error("[OOM] 如果你看到oom发生在 before_backward/after_backward 附近,通常是反传保存激活导致峰值") - logger.error("[OOM] 如果你看到oom发生在 after_accelerator_prepare 附近,通常是prepare或模型常驻占用过高") - raise + if accelerator.sync_gradients: + params_to_clip = ( + custom_diffusion_layers.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + alpha = args.alpha + eps = args.eps + adv_images = pertubed_images + alpha * pertubed_images.grad.sign() + eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps) + pertubed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_() + + optimizer.step() + + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) @@ -856,18 +941,27 @@ def main(args): if accelerator.is_main_process: logger.info("***** Final save of perturbed images *****") save_folder = args.output_dir - noised_imgs = pertubed_images.detach().cpu() - img_names = [str(instance_path[0]).split("/")[-1] for instance_path in train_dataset.instance_images_path] - for i in range(len(img_names)): + noised_imgs = pertubed_images.detach().cpu() + + img_names = [ + 
str(instance_path[0]).split("/")[-1] for instance_path in train_dataset.instance_images_path
+            ]
+
+            num_images_to_save = len(img_names)
+
+            for i in range(num_images_to_save):
                 img_pixel = noised_imgs[i]
                 img_name = img_names[i]
                 save_path = os.path.join(save_folder, f"final_noise_{img_name}")
+
+                # Convert the tensor back to an image and save it.
                 Image.fromarray(
                     (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
                 ).save(save_path)
+
+            logger.info(f"Saved {num_images_to_save} final perturbed images to {save_folder}")
 
-            logger.info(f"Saved {len(img_names)} final perturbed images to {save_folder}")
     accelerator.end_training()
 
 
@@ -875,4 +969,4 @@ def main(args):
 if __name__ == "__main__":
     args = parse_args()
     main(args)
-    print("<-------end-------->")
\ No newline at end of file
+    print("<-------end-------->")
diff --git a/src/backend/app/algorithms/perturbation/simac.py b/src/backend/app/algorithms/perturbation/simac.py
index af93f64..cab654d 100644
--- a/src/backend/app/algorithms/perturbation/simac.py
+++ b/src/backend/app/algorithms/perturbation/simac.py
@@ -10,7 +10,7 @@ from pathlib import Path
 
 import datasets
 import diffusers
-import transformers 
+import transformers
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -30,84 +30,19 @@ from transformers import AutoTokenizer, PretrainedConfig
 
 logger = get_logger(__name__)
 
 
-# -----------------------------
-# Lightweight debug helpers
-# -----------------------------
 def _cuda_gc() -> None:
-    """Best-effort CUDA memory cleanup (does not change algorithmic behavior)."""
+    """Try to release unreferenced CUDA memory and reduce fragmentation.
+
+    This is a best-effort helper. It does not change algorithmic behavior but can
+    make long runs less prone to OOM due to fragmentation/reserved-memory growth.
+    """
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
 
-def _fmt_bytes(n: int) -> str:
-    return f"{n / (1024**2):.1f}MB"
-
-
-def log_cuda(prefix: str, accelerator: Accelerator | None = None, sync: bool = False, extra: dict | None = None):
-    """Log CUDA memory stats without copying tensors to CPU."""
-    if not torch.cuda.is_available():
-        logger.info(f"[mem] {prefix} cuda_not_available")
-        return
-    if sync:
-        torch.cuda.synchronize()
-    alloc = torch.cuda.memory_allocated()
-    reserv = torch.cuda.memory_reserved()
-    max_alloc = torch.cuda.max_memory_allocated()
-    max_reserv = torch.cuda.max_memory_reserved()
-    dev = str(accelerator.device) if accelerator is not None else "cuda"
-    msg = (
-        f"[mem] {prefix} dev={dev} alloc={_fmt_bytes(alloc)} reserv={_fmt_bytes(reserv)} "
-        f"max_alloc={_fmt_bytes(max_alloc)} max_reserv={_fmt_bytes(max_reserv)}"
-    )
-    if extra:
-        msg += " " + " ".join([f"{k}={v}" for k, v in extra.items()])
-    logger.info(msg)
-
-
-def log_path_stats(prefix: str, p: Path):
-    """Log directory/file existence and file count (best-effort)."""
-    try:
-        exists = p.exists()
-        is_dir = p.is_dir() if exists else False
-        n_files = 0
-        if exists and is_dir:
-            n_files = sum(1 for x in p.iterdir() if x.is_file())
-        logger.info(f"[path] {prefix} path={str(p)} exists={exists} is_dir={is_dir} files={n_files}")
-    except Exception as e:
-        logger.info(f"[path] {prefix} path={str(p)} stat_error={repr(e)}")
-
-
-def log_args(args):
-    for k in sorted(vars(args).keys()):
-        logger.info(f"[args] {k}={getattr(args, k)}")
-
-
-def log_tensor_meta(prefix: str, t: torch.Tensor | None):
-    if t is None:
-        logger.info(f"[tensor] {prefix} None")
-        return
-    logger.info(
-        f"[tensor] {prefix} shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
- f"requires_grad={t.requires_grad} is_leaf={t.is_leaf}" - ) - - -def setup_seeds(): - seed = 42 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - cudnn.benchmark = False - cudnn.deterministic = True - - -# ----------------------------- -# Dataset -# ----------------------------- class DreamBoothDatasetFromTensor(Dataset): - """基于内存张量的 DreamBooth 数据集:直接返回图像张量与 prompt token。""" + """Just like DreamBoothDataset, but take instance_images_tensor instead of path.""" def __init__( self, @@ -131,19 +66,10 @@ class DreamBoothDatasetFromTensor(Dataset): if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) - # Only keep files to avoid directories affecting length. - self.class_images_path = [p for p in self.class_data_root.iterdir() if p.is_file()] + self.class_images_path = list(self.class_data_root.iterdir()) self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt - - # Early, explicit failure instead of ZeroDivisionError later. - if self.num_class_images == 0: - raise ValueError( - f"class_data_dir is empty: {self.class_data_root}. " - f"Prior preservation requires class images. " - f"Please generate class images first, or fix class_data_dir, or disable --with_prior_preservation." - ) else: self.class_data_root = None @@ -172,9 +98,6 @@ class DreamBoothDatasetFromTensor(Dataset): ).input_ids if self.class_data_root: - # Defensive: if class_images become empty due to external deletion, raise a clear error. - if self.num_class_images == 0: - raise ValueError(f"class_data_dir became empty at runtime: {self.class_data_root}") class_image = Image.open(self.class_images_path[index % self.num_class_images]) if class_image.mode != "RGB": class_image = class_image.convert("RGB") @@ -190,9 +113,6 @@ class DreamBoothDatasetFromTensor(Dataset): return example -# ----------------------------- -# Model helpers -# ----------------------------- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -212,6 +132,370 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st raise ValueError(f"{model_class} is not supported.") +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." 
+ ), + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir_for_train", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--instance_data_dir_for_adversarial", + type=str, + default=None, + required=True, + help="A folder containing the images to add adversarial noise", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. 
If set, the text encoder should be float32 precision.",
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=4,
+        help="Batch size (per device) for the training dataloader.",
+    )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for sampling images.",
+    )
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=20,
+        help="Total number of training steps to perform.",
+    )
+    parser.add_argument(
+        "--max_f_train_steps",
+        type=int,
+        default=10,
+        help="Total number of sub-steps to train surrogate model.",
+    )
+    parser.add_argument(
+        "--max_adv_train_steps",
+        type=int,
+        default=10,
+        help="Total number of sub-steps to train adversarial noise.",
+    )
+    parser.add_argument(
+        "--checkpointing_iterations",
+        type=int,
+        default=5,
+        help=("Save a checkpoint of the training state every X iterations."),
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="fp16",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", + ) + parser.add_argument( + "--pgd_alpha", + type=float, + default=0.005, + help="The step size for pgd.", + ) + parser.add_argument( + "--pgd_eps", + type=int, + default=16, + help="The noise budget for pgd.", + ) + parser.add_argument( + "--target_image_path", + default=None, + help="target image for attacking", + ) + parser.add_argument( + "--max_steps", + type=int, + default=50, + help=("Maximum steps for adaptive greedy timestep selection."), + ) + parser.add_argument( + "--delta_t", + type=int, + default=20, + help=("delete 2*delta_t for each adaptive greedy timestep selection."), + ) + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + return args + + +class PromptDataset(Dataset): + """A simple dataset to prepare the prompts to generate class images on multiple GPUs.""" + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: + image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + images = [image_transforms(Image.open(i).convert("RGB")) for i in list(Path(data_dir).iterdir())] + images = torch.stack(images) + return images + + +def train_one_epoch( + args, + models, + tokenizer, + noise_scheduler, + vae, + data_tensor: torch.Tensor, + num_steps=20, +): + unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1]) + params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=args.learning_rate, + betas=(0.9, 0.999), + weight_decay=1e-2, + eps=1e-08, + ) + + train_dataset = DreamBoothDatasetFromTensor( + data_tensor, + args.instance_prompt, + tokenizer, + args.class_data_dir, + args.class_prompt, + args.resolution, + args.center_crop, + ) + + weight_dtype = torch.bfloat16 + device = torch.device("cuda") + + vae.to(device, dtype=weight_dtype) + text_encoder.to(device, dtype=weight_dtype) + unet.to(device, dtype=weight_dtype) + + for step in range(num_steps): + unet.train() + text_encoder.train() + + step_data = train_dataset[step % len(train_dataset)] + pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to( + device, dtype=weight_dtype + ) + input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device) + + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + noise = torch.randn_like(latents) + bsz = latents.shape[0] + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + encoder_hidden_states = text_encoder(input_ids)[0] + + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif 
noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + loss = instance_loss + args.prior_loss_weight * prior_loss + else: + prior_loss = torch.tensor(0.0, device=device) + instance_loss = torch.tensor(0.0, device=device) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + loss.backward() + torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True) + optimizer.step() + optimizer.zero_grad() + + print( + f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, " + f"instance_loss: {instance_loss.detach().item()}" + ) + + # Best-effort: free per-step tensors earlier (no behavior change). + del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states + del model_pred, target, loss, prior_loss, instance_loss + + # Best-effort: release optimizer state + dataset refs sooner. + del optimizer, train_dataset, params_to_optimize + _cuda_gc() + + return [unet, text_encoder] + + def set_unet_attr(unet): def conv_forward(self): def forward(input_tensor, temb): @@ -281,139 +565,65 @@ def save_feature_maps(up_blocks, down_blocks): return out_layers_features_list_3 -# ----------------------------- -# Args -# ----------------------------- -def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True) - parser.add_argument("--revision", type=str, default=None, required=False) - parser.add_argument("--tokenizer_name", type=str, default=None) - parser.add_argument("--instance_data_dir_for_train", type=str, default=None, required=True) - parser.add_argument("--instance_data_dir_for_adversarial", type=str, default=None, required=True) - parser.add_argument("--class_data_dir", type=str, default=None, required=False) - parser.add_argument("--instance_prompt", type=str, default=None, required=True) - parser.add_argument("--class_prompt", type=str, default=None) - parser.add_argument("--with_prior_preservation", default=False, action="store_true") - parser.add_argument("--prior_loss_weight", type=float, default=1.0) - parser.add_argument("--num_class_images", type=int, default=100) - parser.add_argument("--output_dir", type=str, default="text-inversion-model") - parser.add_argument("--seed", type=int, default=None) - parser.add_argument("--resolution", type=int, default=512) - parser.add_argument("--center_crop", default=False, action="store_true") - parser.add_argument("--train_text_encoder", action="store_true") - parser.add_argument("--train_batch_size", type=int, default=4) - parser.add_argument("--sample_batch_size", type=int, default=8) - parser.add_argument("--max_train_steps", type=int, default=20) - parser.add_argument("--max_f_train_steps", type=int, default=10) - parser.add_argument("--max_adv_train_steps", type=int, default=10) - parser.add_argument("--checkpointing_iterations", type=int, default=5) - parser.add_argument("--learning_rate", 
type=float, default=5e-6) - parser.add_argument("--logging_dir", type=str, default="logs") - parser.add_argument("--allow_tf32", action="store_true") - parser.add_argument("--report_to", type=str, default="tensorboard") - parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"]) - parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true") - parser.add_argument("--pgd_alpha", type=float, default=0.005) - parser.add_argument("--pgd_eps", type=int, default=16) - parser.add_argument("--target_image_path", default=None) - parser.add_argument("--max_steps", type=int, default=50) - parser.add_argument("--delta_t", type=int, default=20) - - # Debug / diagnostics (low-overhead) - parser.add_argument("--debug", action="store_true", help="Enable detailed logs for failure points.") - parser.add_argument("--debug_cuda_sync", action="store_true", help="Synchronize CUDA for more accurate mem logs.") - parser.add_argument("--debug_step0_only", action="store_true", help="Only print per-step logs for step 0.") - - if input_args is not None: - args = parser.parse_args(input_args) - else: - args = parser.parse_args() - return args - - -# ----------------------------- -# IO helpers -# ----------------------------- -def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor: - image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - images = [image_transforms(Image.open(i).convert("RGB")) for i in list(Path(data_dir).iterdir())] - images = torch.stack(images) - return images - - -# ----------------------------- -# Train / Attack -# ----------------------------- -def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: torch.Tensor, num_steps=20, accelerator=None): - unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1]) - params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) - - optimizer = torch.optim.AdamW( - params_to_optimize, - lr=args.learning_rate, - betas=(0.9, 0.999), - weight_decay=1e-2, - eps=1e-08, - ) - - train_dataset = DreamBoothDatasetFromTensor( - data_tensor, - args.instance_prompt, - tokenizer, - args.class_data_dir if args.with_prior_preservation else None, - args.class_prompt, - args.resolution, - args.center_crop, - ) - +def pgd_attack( + args, + models, + tokenizer, + noise_scheduler, + vae, + data_tensor: torch.Tensor, + original_images: torch.Tensor, + target_tensor: torch.Tensor, + num_steps: int, + time_list, +): + """Return new perturbed data. + + Note: This function keeps the external behavior identical, but tries to reduce + memory pressure by freeing tensors early and avoiding lingering references. + """ + unet, text_encoder = models weight_dtype = torch.bfloat16 device = torch.device("cuda") vae.to(device, dtype=weight_dtype) text_encoder.to(device, dtype=weight_dtype) unet.to(device, dtype=weight_dtype) + set_unet_attr(unet) - for step in range(num_steps): - if args.debug_step0_only and step != 0: - pass + perturbed_images = data_tensor.detach().clone() + perturbed_images.requires_grad_(True) - unet.train() - text_encoder.train() + # Keep input_ids on CPU; move to GPU only when encoding. 
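+    # tokenizer(...) pads to model_max_length, and repeat() gives one identical row per
+    # image so the text-embedding batch lines up with the image batch.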
+ input_ids = tokenizer( + args.instance_prompt, + truncation=True, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids.repeat(len(data_tensor), 1) - try: - step_data = train_dataset[step % len(train_dataset)] - except Exception as e: - logger.error(f"[err] train_one_epoch dataset getitem failed at step={step}: {repr(e)}") - raise - - # This will fail fast if class_images missing (KeyError), better than silent wrong behavior. - try: - pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(device, dtype=weight_dtype) - input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device) - except KeyError as e: - logger.error( - f"[err] missing key in step_data at step={step}: missing={str(e)}. " - f"with_prior_preservation={args.with_prior_preservation}" - ) - raise + for step in range(num_steps): + perturbed_images.requires_grad_(True) - latents = vae.encode(pixel_values).latent_dist.sample() + latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor noise = torch.randn_like(latents) - bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long() + + timesteps = [] + for i in range(len(data_tensor)): + ts = time_list[i] + ts_index = torch.randint(0, len(ts), (1,)) + timestep = torch.IntTensor([ts[ts_index]]) + timestep = timestep.long() + timesteps.append(timestep) + timesteps = torch.cat(timesteps).to(device) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = text_encoder(input_ids)[0] + + encoder_hidden_states = text_encoder(input_ids.to(device))[0] + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if noise_scheduler.config.prediction_type == "epsilon": @@ -423,38 +633,62 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - if args.with_prior_preservation: - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) + noise_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks) - instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - loss = instance_loss + args.prior_loss_weight * prior_loss - else: - prior_loss = torch.tensor(0.0, device=device) - instance_loss = torch.tensor(0.0, device=device) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + with torch.no_grad(): + clean_latents = vae.encode(data_tensor.to(device, dtype=weight_dtype)).latent_dist.sample() + clean_latents = clean_latents * vae.config.scaling_factor + noisy_clean_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps) + _ = unet(noisy_clean_latents, timesteps, encoder_hidden_states).sample + clean_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks) + + target_loss = F.mse_loss( + noise_out_layers_features_3.float(), + clean_out_layers_features_3.float(), + reduction="mean", + ) + + unet.zero_grad(set_to_none=True) + text_encoder.zero_grad(set_to_none=True) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + # Keep original behavior: feature loss does not backprop (added as 
Python float). + loss = loss + target_loss.detach().item() loss.backward() - torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True) - optimizer.step() - optimizer.zero_grad() - logger.info( - f"[train_one_epoch] step={step} loss={loss.detach().item():.6f} " - f"prior={prior_loss.detach().item():.6f} inst={instance_loss.detach().item():.6f}" + alpha = args.pgd_alpha + eps = args.pgd_eps / 255 + adv_images = perturbed_images + alpha * perturbed_images.grad.sign() + eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps) + perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_() + + print( + f"PGD loss - step {step}, loss: {loss.detach().item()}, target_loss : {target_loss.detach().item()}" ) - # Free some step tensors early. - del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states - del model_pred, target, loss, prior_loss, instance_loss + # Best-effort: free per-step tensors early. + del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target + del noise_out_layers_features_3, clean_latents, noisy_clean_latents, clean_out_layers_features_3 + del target_loss, loss, adv_images, eta - del optimizer, train_dataset, params_to_optimize _cuda_gc() - return [unet, text_encoder] + return perturbed_images -def select_timestep(args, models, tokenizer, noise_scheduler, vae, data_tensor, original_images, target_tensor): +def select_timestep( + args, + models, + tokenizer, + noise_scheduler, + vae, + data_tensor: torch.Tensor, + original_images: torch.Tensor, + target_tensor: torch.Tensor, +): + """Return timestep lists for each image. + + External behavior unchanged; add best-effort per-loop cleanup to lower memory pressure. + """ unet, text_encoder = models weight_dtype = torch.bfloat16 device = torch.device("cuda") @@ -497,10 +731,13 @@ def select_timestep(args, models, tokenizer, noise_scheduler, vae, data_tensor, noise = torch.randn_like(latents) bsz = latents.shape[0] inner_index = torch.randint(0, len(res_time_seq), (bsz,)) - timesteps = torch.IntTensor([res_time_seq[inner_index]]).to(device).long() + timesteps = torch.IntTensor([res_time_seq[inner_index]]).to(device) + timesteps = timesteps.long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + encoder_hidden_states = text_encoder(input_ids.to(device))[0] + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if noise_scheduler.config.prediction_type == "epsilon": @@ -532,15 +769,14 @@ def select_timestep(args, models, tokenizer, noise_scheduler, vae, data_tensor, max_score = score select_t = res_time_seq[inner_index].item() - if args.debug: - logger.info( - f"[select_timestep] img={img_id} outer={step} inner={inner_try} loss={loss.detach().item():.6f} " - f"score={score.item() if torch.is_tensor(score) else score} t={res_time_seq[inner_index].item()} " - f"len={len(res_time_seq)}" - ) + print( + f"PGD loss - step {step}, index : {inner_try + 1}, loss: {loss.detach().item()}, " + f"score: {score}, t : {res_time_seq[inner_index]}, ts_len: {len(res_time_seq)}" + ) del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss, score + print("del_t", del_t, "max_t", select_t) if del_t < args.delta_t: del_t = args.delta_t elif del_t > (1000 - args.delta_t): @@ -554,7 +790,8 @@ def select_timestep(args, models, tokenizer, noise_scheduler, vae, data_tensor, latents = latents * vae.config.scaling_factor noise = torch.randn_like(latents) - 
timesteps = torch.IntTensor([select_t]).to(device).long() + timesteps = torch.IntTensor([select_t]).to(device) + timesteps = timesteps.long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) encoder_hidden_states = text_encoder(input_ids.to(device))[0] @@ -569,6 +806,7 @@ def select_timestep(args, models, tokenizer, noise_scheduler, vae, data_tensor, unet.zero_grad(set_to_none=True) text_encoder.zero_grad(set_to_none=True) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss.backward() @@ -576,6 +814,7 @@ def select_timestep(args, models, tokenizer, noise_scheduler, vae, data_tensor, eps = args.pgd_eps / 255 adv_image = id_image + alpha * id_image.grad.sign() eta = torch.clamp(adv_image - original_image, min=-eps, max=+eps) + _ = torch.sum(torch.abs(id_image.grad.sign())) id_image = torch.clamp(original_image + eta, min=-1, max=+1).detach_() del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss, adv_image, eta @@ -592,99 +831,16 @@ def select_timestep(args, models, tokenizer, noise_scheduler, vae, data_tensor, return time_list -def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, original_images, target_tensor, num_steps, time_list): - unet, text_encoder = models - weight_dtype = torch.bfloat16 - device = torch.device("cuda") - - vae.to(device, dtype=weight_dtype) - text_encoder.to(device, dtype=weight_dtype) - unet.to(device, dtype=weight_dtype) - set_unet_attr(unet) - - perturbed_images = data_tensor.detach().clone() - perturbed_images.requires_grad_(True) - - input_ids = tokenizer( - args.instance_prompt, - truncation=True, - padding="max_length", - max_length=tokenizer.model_max_length, - return_tensors="pt", - ).input_ids.repeat(len(data_tensor), 1) - - for step in range(num_steps): - if args.debug_step0_only and step != 0: - pass - - perturbed_images.requires_grad_(True) - - latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample() - latents = latents * vae.config.scaling_factor - - noise = torch.randn_like(latents) - - timesteps = [] - for i in range(len(data_tensor)): - ts = time_list[i] - if len(ts) == 0: - raise ValueError(f"time_list[{i}] is empty; select_timestep failed.") - ts_index = torch.randint(0, len(ts), (1,)) - timestep = torch.IntTensor([ts[ts_index]]).long() - timesteps.append(timestep) - timesteps = torch.cat(timesteps).to(device) - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = text_encoder(input_ids.to(device))[0] - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - noise_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks) - - with torch.no_grad(): - clean_latents = vae.encode(data_tensor.to(device, dtype=weight_dtype)).latent_dist.sample() - clean_latents = clean_latents * vae.config.scaling_factor - noisy_clean_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps) - _ = unet(noisy_clean_latents, timesteps, encoder_hidden_states).sample - clean_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks) - - target_loss = F.mse_loss(noise_out_layers_features_3.float(), 
clean_out_layers_features_3.float(), reduction="mean") - - unet.zero_grad(set_to_none=True) - text_encoder.zero_grad(set_to_none=True) - - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - loss = loss + target_loss.detach().item() - loss.backward() - - alpha = args.pgd_alpha - eps = args.pgd_eps / 255 - adv_images = perturbed_images + alpha * perturbed_images.grad.sign() - eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps) - perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_() - - logger.info( - f"[pgd] step={step} loss={loss.detach().item():.6f} target_loss={target_loss.detach().item():.6f} " - f"alpha={alpha} eps={eps}" - ) - - del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target - del noise_out_layers_features_3, clean_latents, noisy_clean_latents, clean_out_layers_features_3 - del target_loss, loss, adv_images, eta - - _cuda_gc() - return perturbed_images +def setup_seeds(): + seed = 42 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + cudnn.benchmark = False + cudnn.deterministic = True -# ----------------------------- -# Main -# ----------------------------- def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -699,7 +855,6 @@ def main(args): datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) - logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() @@ -710,28 +865,15 @@ def main(args): transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() - if accelerator.is_local_main_process: - logger.info(f"[run] using_file={__file__}") - log_args(args) - if args.seed is not None: set_seed(args.seed) setup_seeds() - if args.debug and accelerator.is_local_main_process: - log_cuda("startup", accelerator, sync=args.debug_cuda_sync) - - # ------------------------- - # Prior preservation: generate class images if needed - # ------------------------- + # Generate class images if prior preservation is enabled. 
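+    # The attack assumes the class folder already holds num_class_images samples; any
+    # shortfall is sampled from the base model with class_prompt before training starts.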
if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) class_images_dir.mkdir(parents=True, exist_ok=True) - log_path_stats("class_dir_before", class_images_dir) - - cur_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file()) - if accelerator.is_local_main_process: - logger.info(f"[class_gen] cur_class_images={cur_class_images} target={args.num_class_images}") + cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 @@ -742,10 +884,6 @@ def main(args): elif args.mixed_precision == "bf16": torch_dtype = torch.bfloat16 - if accelerator.is_local_main_process: - logger.info(f"[class_gen] will_generate={args.num_class_images - cur_class_images} torch_dtype={torch_dtype}") - log_cuda("before_pipeline_load", accelerator, sync=args.debug_cuda_sync) - pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -755,24 +893,20 @@ def main(args): pipeline.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - if accelerator.is_local_main_process and args.debug: - log_cuda("after_pipeline_to_device", accelerator, sync=args.debug_cuda_sync) - for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images - if accelerator.is_local_main_process and args.debug: - logger.info(f"[class_gen] batch_prompts={len(example['prompt'])} generated_images={len(images)}") - for i, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" @@ -781,45 +915,32 @@ def main(args): del pipeline, sample_dataset, sample_dataloader _cuda_gc() - # IMPORTANT: sync all processes before training reads the directory - accelerator.wait_for_everyone() - - # Post-check: ensure class images exist - final_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file()) - if accelerator.is_local_main_process: - logger.info(f"[class_gen] done final_class_images={final_class_images}") - log_path_stats("class_dir_after", class_images_dir) - if final_class_images == 0: - raise RuntimeError(f"class image generation failed: {class_images_dir} is still empty.") - - else: - accelerator.wait_for_everyone() - if accelerator.is_local_main_process: - logger.info("[class_gen] skipped (already enough images)") - else: - if accelerator.is_local_main_process: - logger.info("[class_gen] disabled (with_prior_preservation is False)") - - # ------------------------- - # Load models - # ------------------------- text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - if accelerator.is_local_main_process and args.debug: - log_cuda("before_load_models", accelerator, sync=args.debug_cuda_sync) - text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder", + 
revision=args.revision, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, ) + tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, ).cuda() vae.requires_grad_(False) @@ -829,85 +950,93 @@ def main(args): if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True + clean_data = load_data( + args.instance_data_dir_for_train, + size=args.resolution, + center_crop=args.center_crop, + ) + perturbed_data = load_data( + args.instance_data_dir_for_adversarial, + size=args.resolution, + center_crop=args.center_crop, + ) + original_data = perturbed_data.clone() + original_data.requires_grad_(False) + if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() - if accelerator.is_local_main_process: - logger.info("[xformers] enabled") else: raise ValueError("xformers is not available. Make sure it is installed correctly") - if accelerator.is_local_main_process and args.debug: - log_cuda("after_load_models", accelerator, sync=args.debug_cuda_sync) - - # ------------------------- - # Load data tensors - # ------------------------- - train_dir = Path(args.instance_data_dir_for_train) - adv_dir = Path(args.instance_data_dir_for_adversarial) - if accelerator.is_local_main_process and args.debug: - log_path_stats("train_dir", train_dir) - log_path_stats("adv_dir", adv_dir) - - clean_data = load_data(train_dir, size=args.resolution, center_crop=args.center_crop) - perturbed_data = load_data(adv_dir, size=args.resolution, center_crop=args.center_crop) - original_data = perturbed_data.clone() - original_data.requires_grad_(False) - - if accelerator.is_local_main_process and args.debug: - log_tensor_meta("clean_data_cpu", clean_data) - log_tensor_meta("perturbed_data_cpu", perturbed_data) - target_latent_tensor = None if args.target_image_path is not None: target_image_path = Path(args.target_image_path) - if not target_image_path.is_file(): - raise ValueError(f"Target image path does not exist: {target_image_path}") + assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist" target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution)) target_image = np.array(target_image)[None].transpose(0, 3, 1, 2) target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0 - target_latent_tensor = vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) - target_latent_tensor = target_latent_tensor * vae.config.scaling_factor + target_latent_tensor = ( + vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor + ) target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda() - if accelerator.is_local_main_process and args.debug: - log_tensor_meta("target_latent_tensor", 
target_latent_tensor)
 
-    f = [unet, text_encoder]
-    # -------------------------
-    # Select timesteps
-    # -------------------------
-    if accelerator.is_local_main_process:
-        logger.info("[phase] select_timestep begin")
-    time_list = select_timestep(args, f, tokenizer, noise_scheduler, vae, perturbed_data, original_data, target_latent_tensor)
-    if accelerator.is_local_main_process:
-        logger.info("[phase] select_timestep end")
-        if args.debug:
-            for i, t in enumerate(time_list[: min(10, len(time_list))]):
-                logger.info(f"[time_list] idx={i} len={len(t)} first={t[0].item() if len(t)>0 else 'NA'}")
-
-    # -------------------------
-    # Main training loop
-    # -------------------------
-    for i in range(args.max_train_steps):
-        if accelerator.is_local_main_process:
-            logger.info(f"[outer] i={i}/{args.max_train_steps}")
+    time_list = select_timestep(
+        args,
+        f,
+        tokenizer,
+        noise_scheduler,
+        vae,
+        perturbed_data,
+        original_data,
+        target_latent_tensor,
+    )
+    for t in time_list:
+        print(t)
+    for i in range(args.max_train_steps):
         f_sur = copy.deepcopy(f)
-        f_sur = train_one_epoch(args, f_sur, tokenizer, noise_scheduler, vae, clean_data, args.max_f_train_steps, accelerator=accelerator)
+
+        f_sur = train_one_epoch(
+            args,
+            f_sur,
+            tokenizer,
+            noise_scheduler,
+            vae,
+            clean_data,
+            args.max_f_train_steps,
+        )
         perturbed_data = pgd_attack(
-            args, f_sur, tokenizer, noise_scheduler, vae,
-            perturbed_data, original_data, target_latent_tensor, args.max_adv_train_steps, time_list
+            args,
+            f_sur,
+            tokenizer,
+            noise_scheduler,
+            vae,
+            perturbed_data,
+            original_data,
+            target_latent_tensor,
+            args.max_adv_train_steps,
+            time_list,
         )
+        # Free surrogate ASAP (best-effort, behavior unchanged).
         del f_sur
         _cuda_gc()
 
-        f = train_one_epoch(args, f, tokenizer, noise_scheduler, vae, perturbed_data, args.max_f_train_steps, accelerator=accelerator)
+        f = train_one_epoch(
+            args,
+            f,
+            tokenizer,
+            noise_scheduler,
+            vae,
+            perturbed_data,
+            args.max_f_train_steps,
+        )
 
         if (i + 1) % args.checkpointing_iterations == 0:
             save_folder = args.output_dir
@@ -930,9 +1059,9 @@ def main(args):
                 .numpy()
             ).save(save_path)
 
-            if accelerator.is_local_main_process:
-                logger.info(f"[save] step={i+1} saved={len(img_names)} to {save_folder}")
+            print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)")
 
+        # Best-effort cleanup at the end of each outer iteration.
         _cuda_gc()
--
2.34.1


From 8627af3e8fb490dcbfd872aca18d17ee227b5d Mon Sep 17 00:00:00 2001
From: 梁浩 <971817787@qq.com>
Date: Wed, 7 Jan 2026 14:39:28 +0800
Subject: Revert "improve: optimize algorithm hyperparameters"
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit fe90dc173e3d8eed9849c594d43e3825f07607dd.
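Note on the ASPL flags this revert restores in attack_aspl.sh: --pgd_alpha is the
per-iteration attack step size and --pgd_eps the L-infinity perturbation budget.
A minimal sketch of the projected-gradient update those flags drive, assuming
image tensors in [-1, 1]; the eps rescaling of the script's pixel-scale value (8)
and the loss_fn helper are illustrative stand-ins, not code from aspl.py:

    import torch

    def pgd_step(perturbed, original, loss_fn, alpha=0.005, eps=8 / 255 * 2):
        # One L-infinity PGD ascent step around `original`. `loss_fn` stands in
        # for the surrogate model's denoising loss; maximizing it degrades any
        # DreamBooth run that later trains on the perturbed images.
        perturbed = perturbed.clone().detach().requires_grad_(True)
        loss_fn(perturbed).backward()
        with torch.no_grad():
            adv = perturbed + alpha * perturbed.grad.sign()   # ascend on the loss
            delta = torch.clamp(adv - original, -eps, eps)    # project into eps-ball
            adv = torch.clamp(original + delta, -1.0, 1.0)    # keep valid pixel range
        return adv.detach()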
--- .../app/scripts/attack_anti_face_edit.sh | 2 + src/backend/app/scripts/attack_aspl.sh | 46 +++++++++---------- src/backend/app/scripts/attack_caat.sh | 26 +++++------ .../app/scripts/attack_caat_with_prior.sh | 8 ++-- src/backend/app/scripts/attack_simac.sh | 14 +++--- src/backend/config/algorithm_config.py | 12 ++--- 6 files changed, 51 insertions(+), 57 deletions(-) diff --git a/src/backend/app/scripts/attack_anti_face_edit.sh b/src/backend/app/scripts/attack_anti_face_edit.sh index 857d433..c66ced4 100644 --- a/src/backend/app/scripts/attack_anti_face_edit.sh +++ b/src/backend/app/scripts/attack_anti_face_edit.sh @@ -50,6 +50,8 @@ CUDA_VISIBLE_DEVICES=0 python ../algorithms/pid.py \ --center_crop \ --eps 10 \ --step_size 0.002 \ + --save_every 200 \ --attack_type add-log \ --seed 0 \ --dataloader_num_workers 2 + diff --git a/src/backend/app/scripts/attack_aspl.sh b/src/backend/app/scripts/attack_aspl.sh index f1dc6b9..a4f9e53 100644 --- a/src/backend/app/scripts/attack_aspl.sh +++ b/src/backend/app/scripts/attack_aspl.sh @@ -24,29 +24,29 @@ echo "Clearing output directory: $OUTPUT_DIR" find "$OUTPUT_DIR" -mindepth 1 -delete -accelerate launch --num_processes 1 --num_machines 1 ../algorithms/aspl.py \ - --pretrained_model_name_or_path="$MODEL_PATH" \ - --enable_xformers_memory_efficient_attention \ - --instance_data_dir_for_train="$CLEAN_TRAIN_DIR" \ - --instance_data_dir_for_adversarial="$CLEAN_ADV_DIR" \ - --instance_prompt="a photo of sks person" \ - --class_data_dir="$CLASS_DIR" \ - --num_class_images=200 \ - --class_prompt="a photo of person" \ - --output_dir="$OUTPUT_DIR" \ - --center_crop \ - --with_prior_preservation \ - --prior_loss_weight=1.0 \ - --resolution=384 \ - --train_batch_size=1 \ - --max_train_steps=50 \ - --max_f_train_steps=3 \ - --max_adv_train_steps=6 \ - --checkpointing_iterations=10 \ - --learning_rate=5e-7 \ - --pgd_alpha=0.005 \ - --pgd_eps=8 \ - --seed=0 +accelerate launch ../algorithms/aspl.py \ +  --pretrained_model_name_or_path=$MODEL_PATH  \ +  --enable_xformers_memory_efficient_attention \ +  --instance_data_dir_for_train=$CLEAN_TRAIN_DIR \ +  --instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \ +  --instance_prompt="a photo of sks person" \ +  --class_data_dir=$CLASS_DIR \ +  --num_class_images=200 \ +  --class_prompt="a photo of person" \ +  --output_dir=$OUTPUT_DIR \ +  --center_crop \ +  --with_prior_preservation \ +  --prior_loss_weight=1.0 \ +  --resolution=384 \ +  --train_batch_size=1 \ +  --max_train_steps=50 \ +  --max_f_train_steps=3 \ +  --max_adv_train_steps=6 \ +  --checkpointing_iterations=10 \ +  --learning_rate=5e-7 \ +  --pgd_alpha=0.005 \ +  --pgd_eps=8 \ +  --seed=0 # ------------------------- 训练后清空 CLASS_DIR ------------------------- # 注意:这会在 accelerate launch 成功结束后执行 diff --git a/src/backend/app/scripts/attack_caat.sh b/src/backend/app/scripts/attack_caat.sh index fe394e6..00a9f8c 100644 --- a/src/backend/app/scripts/attack_caat.sh +++ b/src/backend/app/scripts/attack_caat.sh @@ -21,19 +21,17 @@ echo "Clearing output directory: $OUTPUT_DIR" # 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) 
find "$OUTPUT_DIR" -mindepth 1 -delete -#--debug_oom_step0_only \ + accelerate launch ../algorithms/caat.py \ - --pretrained_model_name_or_path="$MODEL_NAME" \ - --instance_data_dir="$INSTANCE_DIR" \ - --output_dir="$OUTPUT_DIR" \ - --instance_prompt="a photo of person" \ - --resolution 512 \ - --learning_rate 1e-5 \ - --lr_warmup_steps 0 \ - --max_train_steps 250 \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of a person" \ + --resolution=512 \ + --learning_rate=1e-5 \ + --lr_warmup_steps=0 \ + --max_train_steps=250 \ --hflip \ - --mixed_precision bf16 \ - --alpha 5e-3 \ - --eps 0.05 \ - --debug_oom \ - --debug_oom_sync \ No newline at end of file + --mixed_precision bf16 \ + --alpha=5e-3 \ + --eps=0.05 \ No newline at end of file diff --git a/src/backend/app/scripts/attack_caat_with_prior.sh b/src/backend/app/scripts/attack_caat_with_prior.sh index ecb7c92..a7e149e 100644 --- a/src/backend/app/scripts/attack_caat_with_prior.sh +++ b/src/backend/app/scripts/attack_caat_with_prior.sh @@ -22,13 +22,13 @@ echo "Clearing output directory: $OUTPUT_DIR" # 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) find "$OUTPUT_DIR" -mindepth 1 -delete -#--debug_oom_step0_only \ + accelerate launch ../algorithms/caat.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --instance_data_dir=$INSTANCE_DIR \ --output_dir=$OUTPUT_DIR \ --with_prior_preservation \ - --instance_prompt="a photo of person" \ + --instance_prompt="a photo of a person" \ --num_class_images=200 \ --class_data_dir=$CLASS_DIR \ --class_prompt='person' \ @@ -39,9 +39,7 @@ accelerate launch ../algorithms/caat.py \ --hflip \ --mixed_precision bf16 \ --alpha=5e-3 \ - --eps=0.05 \ - --debug_oom \ - --debug_oom_sync + --eps=0.05 # ------------------------- 【步骤 2】训练后清空 CLASS_DIR ------------------------- diff --git a/src/backend/app/scripts/attack_simac.sh b/src/backend/app/scripts/attack_simac.sh index a6b9f20..660d6a1 100644 --- a/src/backend/app/scripts/attack_simac.sh +++ b/src/backend/app/scripts/attack_simac.sh @@ -25,20 +25,20 @@ echo "Clearing output directory: $OUTPUT_DIR" mkdir -p "$OUTPUT_DIR" # 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) 
find "$OUTPUT_DIR" -mindepth 1 -delete -# find "$CLASS_DIR" -mindepth 1 -delete +find "$CLASS_DIR" -mindepth 1 -delete -accelerate launch --num_processes 1 --num_machines 1 ../algorithms/simac.py \ - --pretrained_model_name_or_path="$MODEL_PATH" \ +accelerate launch ../algorithms/simac.py \ + --pretrained_model_name_or_path=$MODEL_PATH \ --enable_xformers_memory_efficient_attention \ - --instance_data_dir_for_train="$CLEAN_TRAIN_DIR" \ - --instance_data_dir_for_adversarial="$CLEAN_ADV_DIR" \ + --instance_data_dir_for_train=$CLEAN_TRAIN_DIR \ + --instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \ --instance_prompt="a photo of person" \ - --class_data_dir="$CLASS_DIR" \ + --class_data_dir=$CLASS_DIR \ --num_class_images=100 \ --class_prompt="a photo of person" \ - --output_dir="$OUTPUT_DIR" \ + --output_dir=$OUTPUT_DIR \ --center_crop \ --with_prior_preservation \ --prior_loss_weight=1.0 \ diff --git a/src/backend/config/algorithm_config.py b/src/backend/config/algorithm_config.py index 49c7631..53662d5 100644 --- a/src/backend/config/algorithm_config.py +++ b/src/backend/config/algorithm_config.py @@ -145,10 +145,7 @@ class AlgorithmConfig: 'max_train_steps': 250, 'hflip': True, 'mixed_precision': 'bf16', - 'alpha': 5e-3, - 'eps': 0.05, - 'debug_oom': True, - 'debug_oom_sync': True + 'alpha': 5e-3 } }, 'caat_pro': { @@ -159,7 +156,7 @@ class AlgorithmConfig: 'pretrained_model_name_or_path': MODELS_DIR['model2'], 'with_prior_preservation': True, 'instance_prompt': 'a selfie photo of person', - 'class_prompt': 'person', + 'class_prompt': 'a selfie photo of person', 'num_class_images': 200, 'resolution': 512, 'learning_rate': 1e-5, @@ -168,9 +165,7 @@ class AlgorithmConfig: 'hflip': True, 'mixed_precision': 'bf16', 'alpha': 5e-3, - 'eps': 0.05, - 'debug_oom': True, - 'debug_oom_sync': True + 'eps': 0.05 } }, 'pid': { @@ -238,6 +233,7 @@ class AlgorithmConfig: 'max_train_steps': 2000, 'center_crop': True, 'step_size': 0.002, + 'save_every': 200, 'attack_type': 'add-log', 'seed': 0, 'dataloader_num_workers': 2 -- 2.34.1 From cb67dc73a30189a0cd3c5755223f6d0de4596302 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E6=B5=A9?= <971817787@qq.com> Date: Wed, 7 Jan 2026 14:49:33 +0800 Subject: [PATCH 5/5] =?UTF-8?q?fix:=20=E6=97=B6=E9=97=B4=E8=AE=BE=E7=BD=AE?= =?UTF-8?q?=E6=94=B9=E4=B8=BA=E6=9C=AC=E5=9C=B0=E7=B3=BB=E7=BB=9F=E6=97=B6?= =?UTF-8?q?=E5=8C=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/database/__init__.py | 10 +++++----- src/backend/app/repositories/task_repository.py | 4 ++-- src/backend/app/services/image/image_storage.py | 2 +- src/backend/app/services/image_service.py | 2 +- src/backend/app/services/task_service.py | 4 ++-- src/backend/app/services/vip_service.py | 4 ++-- src/backend/app/workers/evaluate_worker.py | 6 +++--- src/backend/app/workers/finetune_worker.py | 6 +++--- src/backend/app/workers/heatmap_worker.py | 6 +++--- src/backend/app/workers/perturbation_worker.py | 6 +++--- src/backend/tests/factories.py | 2 +- 11 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py index 990816e..502f80e 100644 --- a/src/backend/app/database/__init__.py +++ b/src/backend/app/database/__init__.py @@ -38,8 +38,8 @@ class User(db.Model): email = db.Column(String(100), unique=True, nullable=False, index=True, comment='邮箱') role_id = db.Column(Integer, ForeignKey('role.role_id'), nullable=False, comment='外键关联role表') is_active = 
db.Column(Boolean, default=True, comment='是否激活') - created_at = db.Column(DateTime, default=datetime.utcnow, comment='创建时间') - updated_at = db.Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, comment='更新时间') + created_at = db.Column(DateTime, default=datetime.now, comment='创建时间') + updated_at = db.Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='更新时间') # 关系 role = db.relationship('Role', backref=db.backref('users', lazy='dynamic')) @@ -177,8 +177,8 @@ class UserConfig(db.Model): perturbation_configs_id = db.Column(Integer, ForeignKey('perturbation_configs.perturbation_configs_id'), default=None, comment='默认加噪算法') perturbation_intensity = db.Column(Float, default=None, comment='默认扰动强度') finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), default=None, comment='默认微调方式') - created_at = db.Column(DateTime, default=datetime.utcnow) - updated_at = db.Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + created_at = db.Column(DateTime, default=datetime.now) + updated_at = db.Column(DateTime, default=datetime.now, onupdate=datetime.now) # 关系 data_type = db.relationship('DataType') @@ -212,7 +212,7 @@ class Task(db.Model): tasks_type_id = db.Column(Integer, ForeignKey('task_type.task_type_id'), nullable=False, comment='任务类型') user_id = db.Column(Integer, ForeignKey('users.user_id'), nullable=False, index=True, comment='归属用户') tasks_status_id = db.Column(Integer, ForeignKey('task_status.task_status_id'), nullable=False, comment='任务状态ID') - created_at = db.Column(DateTime, default=datetime.utcnow) + created_at = db.Column(DateTime, default=datetime.now) started_at = db.Column(DateTime, default=None) finished_at = db.Column(DateTime, default=None) error_message = db.Column(Text, comment='错误信息') diff --git a/src/backend/app/repositories/task_repository.py b/src/backend/app/repositories/task_repository.py index 446b958..872a51f 100644 --- a/src/backend/app/repositories/task_repository.py +++ b/src/backend/app/repositories/task_repository.py @@ -121,9 +121,9 @@ class TaskRepository(BaseRepository[Task]): # 自动更新时间戳 if status_code == 'processing': - task.started_at = datetime.utcnow() + task.started_at = datetime.now() elif status_code in ('completed', 'failed'): - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() return True diff --git a/src/backend/app/services/image/image_storage.py b/src/backend/app/services/image/image_storage.py index ec2beb5..8c705ae 100644 --- a/src/backend/app/services/image/image_storage.py +++ b/src/backend/app/services/image/image_storage.py @@ -242,7 +242,7 @@ class ImageStorage: def _save_with_unique_name(self, image, target_dir: str) -> Tuple[str, str, int, int, int]: """保存图片并生成唯一文件名""" - timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f') + timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f') filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png" path = os.path.join(target_dir, filename) diff --git a/src/backend/app/services/image_service.py b/src/backend/app/services/image_service.py index f4021ed..9ed2272 100644 --- a/src/backend/app/services/image_service.py +++ b/src/backend/app/services/image_service.py @@ -249,7 +249,7 @@ class ImageService: import uuid from datetime import datetime - timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f') + timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f') filename = f"{timestamp}_{uuid.uuid4().hex[:6]}.png" path = os.path.join(target_dir, filename) image.save(path, format='PNG') diff --git 
a/src/backend/app/services/task_service.py b/src/backend/app/services/task_service.py index 669a711..3ce39eb 100644 --- a/src/backend/app/services/task_service.py +++ b/src/backend/app/services/task_service.py @@ -234,7 +234,7 @@ class TaskService: @staticmethod def generate_flow_id(): """生成唯一的flow_id""" - base = int(datetime.utcnow().timestamp() * 1000) + base = int(datetime.now().timestamp() * 1000) task_repo = _get_task_repo() while task_repo.find_one_by(flow_id=base): base += 1 @@ -468,7 +468,7 @@ class TaskService: logger.warning(f"Could not cancel/stop RQ job: {e}") # 更新为cancelled if task_repo.update_status(task, 'cancelled'): - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() return task_repo.save() return False except Exception as e: diff --git a/src/backend/app/services/vip_service.py b/src/backend/app/services/vip_service.py index 8b883b2..56c6ea4 100644 --- a/src/backend/app/services/vip_service.py +++ b/src/backend/app/services/vip_service.py @@ -108,7 +108,7 @@ class VipService: code_info = { 'used': True, 'used_by': user_id, - 'used_at': datetime.utcnow().isoformat() + 'used_at': datetime.now().isoformat() } # 已使用的邀请码保留90天记录 @@ -137,7 +137,7 @@ class VipService: 'used': False, 'used_by': None, 'used_at': None, - 'created_at': datetime.utcnow().isoformat(), + 'created_at': datetime.now().isoformat(), 'expires_days': expires_days } diff --git a/src/backend/app/workers/evaluate_worker.py b/src/backend/app/workers/evaluate_worker.py index 0a3dd2e..5618cc3 100644 --- a/src/backend/app/workers/evaluate_worker.py +++ b/src/backend/app/workers/evaluate_worker.py @@ -55,7 +55,7 @@ def run_evaluate_task(task_id, clean_ref_dir, clean_output_dir, processing_status = TaskStatus.query.filter_by(task_status_code='processing').first() if processing_status: task.tasks_status_id = processing_status.task_status_id - task.started_at = datetime.utcnow() + task.started_at = datetime.now() db.session.commit() logger.info(f"Starting evaluate task {task_id}") @@ -104,7 +104,7 @@ def run_evaluate_task(task_id, clean_ref_dir, clean_output_dir, completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() if completed_status: task.tasks_status_id = completed_status.task_status_id - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() db.session.commit() logger.info(f"Evaluate task {task_id} completed") @@ -117,7 +117,7 @@ def run_evaluate_task(task_id, clean_ref_dir, clean_output_dir, failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() if failed_status: task.tasks_status_id = failed_status.task_status_id - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() db.session.commit() return {'success': False, 'error': str(e)} diff --git a/src/backend/app/workers/finetune_worker.py b/src/backend/app/workers/finetune_worker.py index d91b7c6..8c3c4bd 100644 --- a/src/backend/app/workers/finetune_worker.py +++ b/src/backend/app/workers/finetune_worker.py @@ -60,7 +60,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir, if processing_status: task.tasks_status_id = processing_status.task_status_id if not task.started_at: - task.started_at = datetime.utcnow() + task.started_at = datetime.now() db.session.commit() logger.info(f"Method: {finetune_method}, finetune_type: {finetune_type}") @@ -181,7 +181,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir, completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() if completed_status: 
task.tasks_status_id = completed_status.task_status_id - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() db.session.commit() logger.info(f"Finetune task {task_id} completed successfully") @@ -195,7 +195,7 @@ def run_finetune_task(task_id, finetune_method, train_images_dir, failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() if failed_status: task.tasks_status_id = failed_status.task_status_id - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() task.error_message = str(e) db.session.commit() except: diff --git a/src/backend/app/workers/heatmap_worker.py b/src/backend/app/workers/heatmap_worker.py index 0d57dd5..4b3d00b 100644 --- a/src/backend/app/workers/heatmap_worker.py +++ b/src/backend/app/workers/heatmap_worker.py @@ -54,7 +54,7 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path, processing_status = TaskStatus.query.filter_by(task_status_code='processing').first() if processing_status: task.tasks_status_id = processing_status.task_status_id - task.started_at = datetime.utcnow() + task.started_at = datetime.now() db.session.commit() logger.info(f"Starting heatmap task {task_id}") @@ -127,7 +127,7 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path, completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() if completed_status: task.tasks_status_id = completed_status.task_status_id - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() db.session.commit() logger.info(f"Heatmap task {task_id} completed") @@ -140,7 +140,7 @@ def run_heatmap_task(task_id, original_image_path, perturbed_image_path, failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() if failed_status: task.tasks_status_id = failed_status.task_status_id - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() db.session.commit() return {'success': False, 'error': str(e)} diff --git a/src/backend/app/workers/perturbation_worker.py b/src/backend/app/workers/perturbation_worker.py index 10ae359..4f891f6 100644 --- a/src/backend/app/workers/perturbation_worker.py +++ b/src/backend/app/workers/perturbation_worker.py @@ -59,7 +59,7 @@ def run_perturbation_task(task_id, algorithm_code, epsilon, input_dir, output_di processing_status = TaskStatus.query.filter_by(task_status_code='processing').first() if processing_status: task.tasks_status_id = processing_status.task_status_id - task.started_at = datetime.utcnow() + task.started_at = datetime.now() db.session.commit() logger.info(f"Starting perturbation task {task_id}") @@ -119,7 +119,7 @@ def run_perturbation_task(task_id, algorithm_code, epsilon, input_dir, output_di completed_status = TaskStatus.query.filter_by(task_status_code='completed').first() if completed_status: task.tasks_status_id = completed_status.task_status_id - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() db.session.commit() logger.info(f"Perturbation task {task_id} completed successfully") @@ -132,7 +132,7 @@ def run_perturbation_task(task_id, algorithm_code, epsilon, input_dir, output_di failed_status = TaskStatus.query.filter_by(task_status_code='failed').first() if failed_status: task.tasks_status_id = failed_status.task_status_id - task.finished_at = datetime.utcnow() + task.finished_at = datetime.now() task.error_message = str(e) db.session.commit() diff --git a/src/backend/tests/factories.py b/src/backend/tests/factories.py index 8785ee6..9cf34bc 100644 --- 
a/src/backend/tests/factories.py +++ b/src/backend/tests/factories.py @@ -157,7 +157,7 @@ class TaskFactory(BaseFactory): user_id = factory.LazyAttribute(lambda obj: UserFactory().user_id) tasks_status_id = 1 # waiting description = factory.Faker('sentence', locale='zh_CN') - created_at = factory.LazyFunction(datetime.utcnow) + created_at = factory.LazyFunction(datetime.now) class PerturbationTaskFactory(TaskFactory): -- 2.34.1
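A note on the utcnow -> now change applied throughout the last patch:
datetime.utcnow() returns a naive datetime holding UTC wall-clock time and is
deprecated since Python 3.12, while datetime.now() returns a naive datetime in
the local system timezone, which is what the UI here is meant to display. A
minimal sketch of the difference, including the timezone-aware form that
upstream Python recommends (whether aware values fit this schema depends on
the DateTime columns, which these models declare without timezone=True):

    from datetime import datetime, timezone

    naive_utc = datetime.utcnow()            # naive UTC wall clock; deprecated since Python 3.12
    naive_local = datetime.now()             # naive local wall clock (what the patch switches to)
    aware_utc = datetime.now(timezone.utc)   # timezone-aware replacement recommended upstream

    print(naive_utc.tzinfo, naive_local.tzinfo, aware_utc.tzinfo)  # -> None None UTC
    print(aware_utc.astimezone())  # converted to the local system timezone for display

Because SQLAlchemy re-evaluates a callable passed as default= on every insert,
default=datetime.now keeps the per-row semantics that default=datetime.utcnow
had; only the timezone of the stored naive value changes.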