import argparse
import copy
import gc
import hashlib
import itertools
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional

# NOTE: part of the third-party import block is elided in the source; the
# imports below are reconstructed from the names actually used in this file.
import datasets
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

logger = get_logger(__name__)


# -----------------------------
# Lightweight debug helpers (low overhead)
# -----------------------------
def _cuda_gc() -> None:
    """Best-effort CUDA memory cleanup (does not change algorithmic behavior)."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _fmt_bytes(n: int) -> str:
    return f"{n / (1024**2):.1f}MB"


def log_cuda(
    prefix: str,
    accelerator: Optional[Accelerator] = None,
    sync: bool = False,
    extra: Optional[Dict[str, Any]] = None,
) -> None:
    """Log CUDA memory stats without copying tensors to CPU."""
    if not torch.cuda.is_available():
        logger.info(f"[mem] {prefix} cuda_not_available")
        return
    if sync:
        torch.cuda.synchronize()
    alloc = torch.cuda.memory_allocated()
    reserv = torch.cuda.memory_reserved()
    max_alloc = torch.cuda.max_memory_allocated()
    max_reserv = torch.cuda.max_memory_reserved()
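    # NOTE: memory_allocated counts live tensor storage only, while
    # memory_reserved also includes blocks cached by PyTorch's caching
    # allocator, so reserved >= allocated.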
    dev = str(accelerator.device) if accelerator is not None else "cuda"
    msg = (
        f"[mem] {prefix} dev={dev} alloc={_fmt_bytes(alloc)} reserv={_fmt_bytes(reserv)} "
        f"max_alloc={_fmt_bytes(max_alloc)} max_reserv={_fmt_bytes(max_reserv)}"
    )
    if extra:
        msg += " " + " ".join([f"{k}={v}" for k, v in extra.items()])
    logger.info(msg)


def log_path_stats(prefix: str, p: Path) -> None:
    """Log directory/file existence and file count (best-effort)."""
    try:
        exists = p.exists()
        is_dir = p.is_dir() if exists else False
        n_files = 0
        if exists and is_dir:
            n_files = sum(1 for x in p.iterdir() if x.is_file())
        logger.info(f"[path] {prefix} path={str(p)} exists={exists} is_dir={is_dir} files={n_files}")
    except Exception as e:
        logger.info(f"[path] {prefix} path={str(p)} stat_error={repr(e)}")


def log_args(args: argparse.Namespace) -> None:
    for k in sorted(vars(args).keys()):
        logger.info(f"[args] {k}={getattr(args, k)}")


def log_tensor_meta(prefix: str, t: Optional[torch.Tensor]) -> None:
    if t is None:
        logger.info(f"[tensor] {prefix} None")
        return
    logger.info(
        f"[tensor] {prefix} shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
        f"requires_grad={t.requires_grad} is_leaf={t.is_leaf}"
    )


# -----------------------------
# Dataset
# -----------------------------
class DreamBoothDatasetFromTensor(Dataset):
    """DreamBooth dataset backed by in-memory tensors: takes tensor input directly and returns images with their prompt tokens."""

    def __init__(
        self,
        instance_images_tensor: torch.Tensor,
        instance_prompt: str,
        tokenizer: AutoTokenizer,
        class_data_root: Optional[str] = None,
        class_prompt: Optional[str] = None,
        size: int = 512,
        center_crop: bool = False,
    ):
        # Image-processing parameters and the tokenizer.
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        # Instance data: taken directly from the tensor passed in.
        self.instance_images_tensor = instance_images_tensor
        self.num_instance_images = len(self.instance_images_tensor)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        # Optional class data for prior preservation; the dataset length becomes
        # the larger of the instance and class counts.
        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            # Only keep files to avoid directories affecting length.
            self.class_images_path = [p for p in self.class_data_root.iterdir() if p.is_file()]
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt

            if self.num_class_images == 0:
                raise ValueError(
                    f"class_data_dir is empty: {self.class_data_root}. "
                    f"Prior preservation requires class images. "
                    f"Please generate class images first, or fix class_data_dir, "
                    f"or disable --with_prior_preservation."
                )
            if self.class_prompt is None:
                raise ValueError("class_prompt is required when class_data_root is provided.")
        else:
            self.class_data_root = None
            self.class_images_path = []
            self.num_class_images = 0
            self.class_prompt = None

        # Shared image preprocessing. The crop/normalize tail is elided in the
        # source; the standard DreamBooth transforms are restored here.
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Fetch the instance image tensor and its prompt tokens; the modulo lets
        # the index run over max(num_instance, num_class) entries.
        example: Dict[str, Any] = {}
        instance_image = self.instance_images_tensor[index % self.num_instance_images]
        example["instance_images"] = instance_image
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        # With class data present, also return a class image and its prompt tokens.
        if self.class_data_root is not None:
            if self.num_class_images == 0:
                raise ValueError(f"class_data_dir became empty at runtime: {self.class_data_root}")
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if class_image.mode != "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,  # type: ignore[arg-type]
                truncation=True,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids

        return example


# -----------------------------
# Model helper
# -----------------------------
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: Optional[str]):
    # Identify the text_encoder architecture from its config and pick the right class.
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    raise ValueError(f"{model_class} is not supported.")


# -----------------------------
# Args
# -----------------------------
def parse_args(input_args=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="ASPL training script with diagnostics.")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True)
    parser.add_argument("--revision", type=str, default=None, required=False)
    parser.add_argument("--tokenizer_name", type=str, default=None)

    parser.add_argument("--instance_data_dir_for_train", type=str, default=None, required=True)
    parser.add_argument("--instance_data_dir_for_adversarial", type=str, default=None, required=True)

    parser.add_argument("--class_data_dir", type=str, default=None, required=False)
    parser.add_argument("--instance_prompt", type=str, default=None, required=True)
    parser.add_argument("--class_prompt", type=str, default=None)

    parser.add_argument("--with_prior_preservation", default=False, action="store_true")
    parser.add_argument("--prior_loss_weight", type=float, default=1.0)
    parser.add_argument("--num_class_images", type=int, default=100)

    parser.add_argument("--output_dir", type=str, default="text-inversion-model")
    parser.add_argument("--seed", type=int, default=None)

    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--center_crop", default=False, action="store_true")

    parser.add_argument("--train_text_encoder", action="store_true")
    parser.add_argument("--train_batch_size", type=int, default=4)

    parser.add_argument("--sample_batch_size", type=int, default=8)

    parser.add_argument("--max_train_steps", type=int, default=20)
    parser.add_argument("--max_f_train_steps", type=int, default=10)
    parser.add_argument("--max_adv_train_steps", type=int, default=10)

    parser.add_argument("--checkpointing_iterations", type=int, default=5)

    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--logging_dir", type=str, default="logs")

    parser.add_argument("--allow_tf32", action="store_true")
    parser.add_argument("--report_to", type=str, default="tensorboard")
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])

    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")

    parser.add_argument("--pgd_alpha", type=float, default=1.0 / 255)
    parser.add_argument("--pgd_eps", type=float, default=0.05)  # kept as a float; divided by 255 later

    parser.add_argument("--target_image_path", default=None)

    # Debug / diagnostics (low-overhead)
    parser.add_argument("--debug", action="store_true", help="Enable detailed logs for failure points.")
    parser.add_argument("--debug_cuda_sync", action="store_true", help="Synchronize CUDA for more accurate mem logs.")
    parser.add_argument("--debug_step0_only", action="store_true", help="Only print per-step logs for step 0.")

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    return args
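
# Example launch (the script name, model id, and all paths below are
# illustrative only):
#   accelerate launch aspl.py \
#     --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
#     --instance_data_dir_for_train data/clean \
#     --instance_data_dir_for_adversarial data/adv \
#     --instance_prompt "a photo of sks person" \
#     --with_prior_preservation --class_data_dir data/class \
#     --class_prompt "a photo of a person" \
#     --output_dir outputs/aspl --debug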


# -----------------------------
# Class image prompt dataset
# -----------------------------
class PromptDataset(Dataset):
    """Prompt dataset for batch-generating class images; sampling can run in parallel in a multi-GPU setup."""

    def __init__(self, prompt: str, num_samples: int):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return {"prompt": self.prompt, "index": index}


# -----------------------------
# IO
# -----------------------------
def load_data(data_dir: Path, size: int = 512, center_crop: bool = True) -> torch.Tensor:
    # Read all images in the directory, resize/crop/normalize them as training
    # requires, and return a stacked tensor.
    image_transforms = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    images = [image_transforms(Image.open(p).convert("RGB")) for p in list(Path(data_dir).iterdir()) if p.is_file()]
    if len(images) == 0:
        raise ValueError(f"No image files found in directory: {data_dir}")
    return torch.stack(images)


# -----------------------------
# Core routines
# -----------------------------
def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: torch.Tensor, num_steps: int = 20):
    # One training round: copy the current models, iterate a number of steps on
    # the given data, and return the updated copies.
    unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
    # Materialize the parameter chain as a list so it can be consumed by both
    # the optimizer and clip_grad_norm_ (a bare itertools.chain would be
    # exhausted after its first use).
    params_to_optimize = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
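
    # NOTE: the optimizer construction is elided in the source; an AdamW over
    # params_to_optimize at args.learning_rate is assumed here.
    optimizer = torch.optim.AdamW(params_to_optimize, lr=args.learning_rate)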

    train_dataset = DreamBoothDatasetFromTensor(
        data_tensor,
        args.instance_prompt,
        tokenizer,
        args.class_data_dir if args.with_prior_preservation else None,
        args.class_prompt,
        args.resolution,
        args.center_crop,
    )

    # NOTE: a few unchanged setup lines are elided in the source; the dtype and
    # device below mirror the values pgd_attack sets explicitly.
    weight_dtype = torch.bfloat16
    device = torch.device("cuda")

    for step in range(num_steps):
        unet.train()
        text_encoder.train()

        # Build this step's sample (instance + class) and its text tokens.
        step_data = train_dataset[step % len(train_dataset)]
        pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
            device, dtype=weight_dtype
        )
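        # NOTE: this path assumes prior preservation is enabled; the dataset only
        # returns "class_images"/"class_prompt_ids" when class data is configured.
        # The input_ids construction is elided in the source and reconstructed
        # here from how it is consumed below.
        input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)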

        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample a random timestep and add noise, i.e. run the forward diffusion step.
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
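        # The add_noise call above implements forward diffusion q(x_t | x_0):
        #   noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise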

        # Text conditioning.
        encoder_hidden_states = text_encoder(input_ids)[0]

        # UNet noise prediction.
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Choose the target according to the scheduler's prediction type.
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # Optional prior preservation: split the instance and class halves and
        # compute their MSE losses separately.
        if args.with_prior_preservation:
            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
            target, target_prior = torch.chunk(target, 2, dim=0)
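
            # Combined objective below: loss = instance + prior_loss_weight * prior,
            # the standard DreamBooth prior-preservation loss.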
            instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
            loss = instance_loss + args.prior_loss_weight * prior_loss
        else:
            prior_loss = torch.tensor(0.0, device=device)
            instance_loss = torch.tensor(0.0, device=device)
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
        optimizer.step()
        optimizer.zero_grad()
        logger.info(
            f"[train_one_epoch] step={step} loss={loss.detach().item():.6f} "
            f"prior={prior_loss.detach().item():.6f} inst={instance_loss.detach().item():.6f}"
        )

        del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states
        del model_pred, target, loss, prior_loss, instance_loss

    del optimizer, train_dataset, params_to_optimize
    _cuda_gc()
    return [unet, text_encoder]


def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, original_images, target_tensor, num_steps: int):
    """PGD adversarial perturbation: iteratively update the input within the noise budget and return the new perturbed data."""
    unet, text_encoder = models
    weight_dtype = torch.bfloat16
    device = torch.device("cuda")
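    # NOTE: the buffer initialization and prompt tokenization are elided in the
    # source; they are reconstructed here from how the values are used below
    # (perturbed_images receives gradients, input_ids is repeated per image).
    perturbed_images = data_tensor.detach().clone()
    input_ids = tokenizer(
        args.instance_prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",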
    ).input_ids.repeat(len(data_tensor), 1)

    for step in range(num_steps):
        perturbed_images.requires_grad_(True)

        latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample timesteps and add noise to set up the UNet prediction.
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Text conditioning and noise prediction.
        encoder_hidden_states = text_encoder(input_ids.to(device))[0]
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Target noise or velocity.
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        unet.zero_grad(set_to_none=True)
        text_encoder.zero_grad(set_to_none=True)
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # With a target-image latent available, add the target-alignment term
        # (keeping the original logic: the loss is taken as the difference).
        if target_tensor is not None:
            xtm1_pred = torch.cat(
                [
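                    # ... (elided in the source) per-sample scheduler predictions ...
                ],
            )
            # (elided in the source) An xtm1_target is derived from target_tensor,
            # and its MSE against xtm1_pred adjusts the loss as described above.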

        loss.backward()

        # PGD update, projected into the eps-ball and then clipped to [-1, 1].
        alpha = args.pgd_alpha
        eps = float(args.pgd_eps) / 255.0
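        # One signed-gradient ascent step with an L-infinity projection:
        #   x' = clip_[-1,1]( x0 + clip_[-eps,eps]( x + alpha * sign(grad) - x0 ) )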
        adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
        eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
        perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
        logger.info(f"[pgd] step={step} loss={loss.detach().item():.6f} alpha={alpha} eps={eps}")

        del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss
        del adv_images, eta

    _cuda_gc()
    return perturbed_images


# -----------------------------
# Main
# -----------------------------
def main(args: argparse.Namespace) -> None:
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
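        # NOTE: the Accelerator kwargs and the logging.basicConfig call are
        # partly elided in the source; the values below are the conventional
        # ones for this script family and are an assumption.
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",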
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_warning()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if accelerator.is_local_main_process:
        logger.info(f"[run] using_file={__file__}")
        log_args(args)

    if args.seed is not None:
        set_seed(args.seed)

    if args.debug and accelerator.is_local_main_process:
        log_cuda("startup", accelerator, sync=args.debug_cuda_sync)

    # -------------------------
    # Prior preservation: generate class images if needed
    # -------------------------
    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("--with_prior_preservation requires --class_data_dir")
        if args.class_prompt is None:
            raise ValueError("--with_prior_preservation requires --class_prompt")

        class_images_dir = Path(args.class_data_dir)
        class_images_dir.mkdir(parents=True, exist_ok=True)

        if accelerator.is_local_main_process:
            log_path_stats("class_dir_before", class_images_dir)

        cur_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file())
        if accelerator.is_local_main_process:
            logger.info(f"[class_gen] cur_class_images={cur_class_images} target={args.num_class_images}")

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            if args.mixed_precision == "fp32":
                torch_dtype = torch.float32
            elif args.mixed_precision == "fp16":
                torch_dtype = torch.float16
            elif args.mixed_precision == "bf16":
                torch_dtype = torch.bfloat16

            if accelerator.is_local_main_process:
                logger.info(
                    f"[class_gen] will_generate={args.num_class_images - cur_class_images} torch_dtype={torch_dtype}"
                )
                if args.debug:
                    log_cuda("before_pipeline_load", accelerator, sync=args.debug_cuda_sync)

            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,  # NOTE: kwargs partly elided in the source; safety_checker=None is assumed
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
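
            # NOTE: the loop header is elided in the source; a conventional
            # accelerator-aware tqdm loop is assumed here.
            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)
            for example in tqdm(
                sample_dataloader,
                desc="Generating class images",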
                disable=not accelerator.is_local_main_process,
            ):
                images = pipeline(example["prompt"]).images
                if accelerator.is_local_main_process and args.debug:
                    logger.info(f"[class_gen] generated_images={len(images)}")

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline, sample_dataset, sample_dataloader
            _cuda_gc()

            accelerator.wait_for_everyone()

            final_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file())
            if accelerator.is_local_main_process:
                logger.info(f"[class_gen] done final_class_images={final_class_images}")
                log_path_stats("class_dir_after", class_images_dir)
            if final_class_images == 0:
                raise RuntimeError(f"class image generation failed: {class_images_dir} is still empty.")
        else:
            accelerator.wait_for_everyone()
            if accelerator.is_local_main_process:
                logger.info("[class_gen] skipped (already enough images)")
    else:
        if accelerator.is_local_main_process:
            logger.info("[class_gen] disabled (with_prior_preservation is False)")

    # -------------------------
    # Load models / tokenizer / scheduler / VAE
    # -------------------------
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    if accelerator.is_local_main_process and args.debug:
        log_cuda("before_load_models", accelerator, sync=args.debug_cuda_sync)

    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
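        revision=args.revision,
    )
    # NOTE: the UNet and tokenizer loads are elided in the source; the standard
    # diffusers/transformers calls are assumed here.
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",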
        revision=args.revision,
        use_fast=False,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
    ).cuda()
    vae.requires_grad_(False)

    if not args.train_text_encoder:
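        # (elided in the source) freeze the text encoder when it is not trained
        text_encoder.requires_grad_(False)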

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
            if accelerator.is_local_main_process:
                logger.info("[xformers] enabled")
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # -------------------------
    # Load data tensors
    # -------------------------
    train_dir = Path(args.instance_data_dir_for_train)
    adv_dir = Path(args.instance_data_dir_for_adversarial)
    if accelerator.is_local_main_process and args.debug:
        log_path_stats("train_dir", train_dir)
        log_path_stats("adv_dir", adv_dir)

    clean_data = load_data(train_dir, size=args.resolution, center_crop=args.center_crop)
    perturbed_data = load_data(adv_dir, size=args.resolution, center_crop=args.center_crop)
    original_data = perturbed_data.clone()
    original_data.requires_grad_(False)

    if accelerator.is_local_main_process and args.debug:
        log_tensor_meta("clean_data_cpu", clean_data)
        log_tensor_meta("perturbed_data_cpu", perturbed_data)

    target_latent_tensor: Optional[torch.Tensor] = None
    if args.target_image_path is not None:
        target_image_path = Path(args.target_image_path)
        if not target_image_path.is_file():
            raise ValueError(f"Target image path does not exist: {target_image_path}")

        target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution))
        target_image = np.array(target_image)[None].transpose(0, 3, 1, 2)

        target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0
        target_latent_tensor = vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16)
        target_latent_tensor = target_latent_tensor * vae.config.scaling_factor
        target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda()

    if accelerator.is_local_main_process and args.debug:
        log_tensor_meta("target_latent_tensor", target_latent_tensor)

    # Alternating ASPL loop: train a surrogate, craft the PGD perturbation, then
    # retrain the main model on the perturbed data; periodically export the
    # adversarial samples.
    f = [unet, text_encoder]
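    # Each outer iteration i does, in order:
    #   f_sur <- train_one_epoch(copy of f, clean_data)   # surrogate sees clean data
    #   perturbed_data <- pgd_attack(f_sur, ...)          # attack the surrogate
    #   f <- train_one_epoch(f, perturbed_data)           # main model adapts to the attack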
    for i in range(args.max_train_steps):
        if accelerator.is_local_main_process:
            logger.info(f"[outer] i={i}/{args.max_train_steps}")

        f_sur = copy.deepcopy(f)
        f_sur = train_one_epoch(args, f_sur, tokenizer, noise_scheduler, vae, clean_data, args.max_f_train_steps)

        perturbed_data = pgd_attack(
            args,
            f_sur,
            tokenizer,
            noise_scheduler,
            vae,
            perturbed_data,
            original_data,
            target_latent_tensor,
            args.max_adv_train_steps,
        )
        f = train_one_epoch(args, f, tokenizer, noise_scheduler, vae, perturbed_data, args.max_f_train_steps)

        # Periodically save the current perturbed images for later evaluation
        # and reproduction.
        if (i + 1) % args.checkpointing_iterations == 0:
            save_folder = args.output_dir
            os.makedirs(save_folder, exist_ok=True)
            noised_imgs = perturbed_data.detach()

            img_filenames = [p.stem for p in adv_dir.iterdir() if p.is_file()]
            for img_pixel, img_name in zip(noised_imgs, img_filenames):
                save_path = os.path.join(save_folder, f"perturbed_{img_name}.png")
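                # Decode from [-1, 1] back to uint8: x * 127.5 + 128 approximately
                # inverts the Normalize([0.5], [0.5]) preprocessing, then clamp.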
                Image.fromarray(
                    (img_pixel * 127.5 + 128)
                    .clamp(0, 255)
                    .to(torch.uint8)
                    .permute(1, 2, 0)
                    .cpu()
                    .numpy()
                ).save(save_path)

            if accelerator.is_local_main_process:
                logger.info(f"[save] step={i+1} saved={len(img_filenames)} to {save_folder} (files are overwritten)")

        _cuda_gc()


if __name__ == "__main__":
    parsed_args = parse_args()
    main(parsed_args)