|
|
|
|
@ -1,11 +1,9 @@
|
|
|
|
|
import argparse
|
|
|
|
|
import copy
|
|
|
|
|
import gc
|
|
|
|
|
import hashlib
|
|
|
|
|
import itertools
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import datasets
|
|
|
|
|
@ -26,77 +24,12 @@ from torchvision import transforms
|
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
from transformers import AutoTokenizer, PretrainedConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------
|
|
|
|
|
# Lightweight debug helpers (low overhead)
|
|
|
|
|
# -----------------------------
|
|
|
|
|
def _cuda_gc() -> None:
|
|
|
|
|
"""Best-effort CUDA memory cleanup (does not change algorithmic behavior)."""
|
|
|
|
|
gc.collect()
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fmt_bytes(n: int) -> str:
|
|
|
|
|
return f"{n / (1024**2):.1f}MB"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_cuda(prefix: str, accelerator: Accelerator | None = None, sync: bool = False, extra: dict | None = None):
|
|
|
|
|
"""Log CUDA memory stats without copying tensors to CPU."""
|
|
|
|
|
if not torch.cuda.is_available():
|
|
|
|
|
logger.info(f"[mem] {prefix} cuda_not_available")
|
|
|
|
|
return
|
|
|
|
|
if sync:
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
|
alloc = torch.cuda.memory_allocated()
|
|
|
|
|
reserv = torch.cuda.memory_reserved()
|
|
|
|
|
max_alloc = torch.cuda.max_memory_allocated()
|
|
|
|
|
max_reserv = torch.cuda.max_memory_reserved()
|
|
|
|
|
dev = str(accelerator.device) if accelerator is not None else "cuda"
|
|
|
|
|
msg = (
|
|
|
|
|
f"[mem] {prefix} dev={dev} alloc={_fmt_bytes(alloc)} reserv={_fmt_bytes(reserv)} "
|
|
|
|
|
f"max_alloc={_fmt_bytes(max_alloc)} max_reserv={_fmt_bytes(max_reserv)}"
|
|
|
|
|
)
|
|
|
|
|
if extra:
|
|
|
|
|
msg += " " + " ".join([f"{k}={v}" for k, v in extra.items()])
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_path_stats(prefix: str, p: Path):
|
|
|
|
|
"""Log directory/file existence and file count (best-effort)."""
|
|
|
|
|
try:
|
|
|
|
|
exists = p.exists()
|
|
|
|
|
is_dir = p.is_dir() if exists else False
|
|
|
|
|
n_files = 0
|
|
|
|
|
if exists and is_dir:
|
|
|
|
|
n_files = sum(1 for x in p.iterdir() if x.is_file())
|
|
|
|
|
logger.info(f"[path] {prefix} path={str(p)} exists={exists} is_dir={is_dir} files={n_files}")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.info(f"[path] {prefix} path={str(p)} stat_error={repr(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_args(args):
|
|
|
|
|
for k in sorted(vars(args).keys()):
|
|
|
|
|
logger.info(f"[args] {k}={getattr(args, k)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_tensor_meta(prefix: str, t: torch.Tensor | None):
|
|
|
|
|
if t is None:
|
|
|
|
|
logger.info(f"[tensor] {prefix} None")
|
|
|
|
|
return
|
|
|
|
|
logger.info(
|
|
|
|
|
f"[tensor] {prefix} shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
|
|
|
|
|
f"requires_grad={t.requires_grad} is_leaf={t.is_leaf}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------
|
|
|
|
|
# Dataset
|
|
|
|
|
# -----------------------------
|
|
|
|
|
class DreamBoothDatasetFromTensor(Dataset):
|
|
|
|
|
"""基于内存张量的 DreamBooth 数据集:直接使用张量输入,返回图像与对应 prompt token。"""
|
|
|
|
|
"""Just like DreamBoothDataset, but take instance_images_tensor instead of path"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
@ -120,19 +53,10 @@ class DreamBoothDatasetFromTensor(Dataset):
|
|
|
|
|
if class_data_root is not None:
|
|
|
|
|
self.class_data_root = Path(class_data_root)
|
|
|
|
|
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
# Only keep files to avoid directories affecting length.
|
|
|
|
|
self.class_images_path = [p for p in self.class_data_root.iterdir() if p.is_file()]
|
|
|
|
|
self.class_images_path = list(self.class_data_root.iterdir())
|
|
|
|
|
self.num_class_images = len(self.class_images_path)
|
|
|
|
|
self._length = max(self.num_class_images, self.num_instance_images)
|
|
|
|
|
self.class_prompt = class_prompt
|
|
|
|
|
|
|
|
|
|
# Early, explicit failure instead of ZeroDivisionError later.
|
|
|
|
|
if self.num_class_images == 0:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"class_data_dir is empty: {self.class_data_root}. "
|
|
|
|
|
f"Prior preservation requires class images. "
|
|
|
|
|
f"Please generate class images first, or fix class_data_dir, or disable --with_prior_preservation."
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
self.class_data_root = None
|
|
|
|
|
|
|
|
|
|
@ -161,10 +85,8 @@ class DreamBoothDatasetFromTensor(Dataset):
|
|
|
|
|
).input_ids
|
|
|
|
|
|
|
|
|
|
if self.class_data_root:
|
|
|
|
|
if self.num_class_images == 0:
|
|
|
|
|
raise ValueError(f"class_data_dir became empty at runtime: {self.class_data_root}")
|
|
|
|
|
class_image = Image.open(self.class_images_path[index % self.num_class_images])
|
|
|
|
|
if class_image.mode != "RGB":
|
|
|
|
|
if not class_image.mode == "RGB":
|
|
|
|
|
class_image = class_image.convert("RGB")
|
|
|
|
|
example["class_images"] = self.image_transforms(class_image)
|
|
|
|
|
example["class_prompt_ids"] = self.tokenizer(
|
|
|
|
|
@ -178,9 +100,6 @@ class DreamBoothDatasetFromTensor(Dataset):
|
|
|
|
|
return example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------
|
|
|
|
|
# Model helper
|
|
|
|
|
# -----------------------------
|
|
|
|
|
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
|
|
|
|
text_encoder_config = PretrainedConfig.from_pretrained(
|
|
|
|
|
pretrained_model_name_or_path,
|
|
|
|
|
@ -201,47 +120,217 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
|
|
|
|
|
raise ValueError(f"{model_class} is not supported.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------
|
|
|
|
|
# Args
|
|
|
|
|
# -----------------------------
|
|
|
|
|
def parse_args(input_args=None):
|
|
|
|
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
|
|
|
|
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True)
|
|
|
|
|
parser.add_argument("--revision", type=str, default=None, required=False)
|
|
|
|
|
parser.add_argument("--tokenizer_name", type=str, default=None)
|
|
|
|
|
parser.add_argument("--instance_data_dir_for_train", type=str, default=None, required=True)
|
|
|
|
|
parser.add_argument("--instance_data_dir_for_adversarial", type=str, default=None, required=True)
|
|
|
|
|
parser.add_argument("--class_data_dir", type=str, default=None, required=False)
|
|
|
|
|
parser.add_argument("--instance_prompt", type=str, default=None, required=True)
|
|
|
|
|
parser.add_argument("--class_prompt", type=str, default=None)
|
|
|
|
|
parser.add_argument("--with_prior_preservation", default=False, action="store_true")
|
|
|
|
|
parser.add_argument("--prior_loss_weight", type=float, default=1.0)
|
|
|
|
|
parser.add_argument("--num_class_images", type=int, default=100)
|
|
|
|
|
parser.add_argument("--output_dir", type=str, default="text-inversion-model")
|
|
|
|
|
parser.add_argument("--seed", type=int, default=None)
|
|
|
|
|
parser.add_argument("--resolution", type=int, default=512)
|
|
|
|
|
parser.add_argument("--center_crop", default=False, action="store_true")
|
|
|
|
|
parser.add_argument("--train_text_encoder", action="store_true")
|
|
|
|
|
parser.add_argument("--train_batch_size", type=int, default=4)
|
|
|
|
|
parser.add_argument("--sample_batch_size", type=int, default=8)
|
|
|
|
|
parser.add_argument("--max_train_steps", type=int, default=20)
|
|
|
|
|
parser.add_argument("--max_f_train_steps", type=int, default=10)
|
|
|
|
|
parser.add_argument("--max_adv_train_steps", type=int, default=10)
|
|
|
|
|
parser.add_argument("--checkpointing_iterations", type=int, default=5)
|
|
|
|
|
parser.add_argument("--learning_rate", type=float, default=5e-6)
|
|
|
|
|
parser.add_argument("--logging_dir", type=str, default="logs")
|
|
|
|
|
parser.add_argument("--allow_tf32", action="store_true")
|
|
|
|
|
parser.add_argument("--report_to", type=str, default="tensorboard")
|
|
|
|
|
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
|
|
|
|
|
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")
|
|
|
|
|
parser.add_argument("--pgd_alpha", type=float, default=1.0 / 255)
|
|
|
|
|
parser.add_argument("--pgd_eps", type=int, default=0.05)
|
|
|
|
|
parser.add_argument("--target_image_path", default=None)
|
|
|
|
|
|
|
|
|
|
# Debug / diagnostics (low-overhead)
|
|
|
|
|
parser.add_argument("--debug", action="store_true", help="Enable detailed logs for failure points.")
|
|
|
|
|
parser.add_argument("--debug_cuda_sync", action="store_true", help="Synchronize CUDA for more accurate mem logs.")
|
|
|
|
|
parser.add_argument("--debug_step0_only", action="store_true", help="Only print per-step logs for step 0.")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--pretrained_model_name_or_path",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
required=True,
|
|
|
|
|
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--revision",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
required=False,
|
|
|
|
|
help=(
|
|
|
|
|
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
|
|
|
|
" float32 precision."
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--tokenizer_name",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--instance_data_dir_for_train",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
required=True,
|
|
|
|
|
help="A folder containing the training data of instance images.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--instance_data_dir_for_adversarial",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
required=True,
|
|
|
|
|
help="A folder containing the images to add adversarial noise",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--class_data_dir",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
required=False,
|
|
|
|
|
help="A folder containing the training data of class images.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--instance_prompt",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
required=True,
|
|
|
|
|
help="The prompt with identifier specifying the instance",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--class_prompt",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
help="The prompt to specify images in the same class as provided instance images.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--with_prior_preservation",
|
|
|
|
|
default=False,
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Flag to add prior preservation loss.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--prior_loss_weight",
|
|
|
|
|
type=float,
|
|
|
|
|
default=1.0,
|
|
|
|
|
help="The weight of prior preservation loss.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--num_class_images",
|
|
|
|
|
type=int,
|
|
|
|
|
default=100,
|
|
|
|
|
help=(
|
|
|
|
|
"Minimal class images for prior preservation loss. If there are not enough images already present in"
|
|
|
|
|
" class_data_dir, additional images will be sampled with class_prompt."
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--output_dir",
|
|
|
|
|
type=str,
|
|
|
|
|
default="text-inversion-model",
|
|
|
|
|
help="The output directory where the model predictions and checkpoints will be written.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--resolution",
|
|
|
|
|
type=int,
|
|
|
|
|
default=512,
|
|
|
|
|
help=(
|
|
|
|
|
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
|
|
|
|
" resolution"
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--center_crop",
|
|
|
|
|
default=False,
|
|
|
|
|
action="store_true",
|
|
|
|
|
help=(
|
|
|
|
|
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
|
|
|
|
" cropped. The images will be resized to the resolution first before cropping."
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--train_text_encoder",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--train_batch_size",
|
|
|
|
|
type=int,
|
|
|
|
|
default=4,
|
|
|
|
|
help="Batch size (per device) for the training dataloader.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--sample_batch_size",
|
|
|
|
|
type=int,
|
|
|
|
|
default=8,
|
|
|
|
|
help="Batch size (per device) for sampling images.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--max_train_steps",
|
|
|
|
|
type=int,
|
|
|
|
|
default=20,
|
|
|
|
|
help="Total number of training steps to perform.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--max_f_train_steps",
|
|
|
|
|
type=int,
|
|
|
|
|
default=10,
|
|
|
|
|
help="Total number of sub-steps to train surogate model.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--max_adv_train_steps",
|
|
|
|
|
type=int,
|
|
|
|
|
default=10,
|
|
|
|
|
help="Total number of sub-steps to train adversarial noise.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--checkpointing_iterations",
|
|
|
|
|
type=int,
|
|
|
|
|
default=5,
|
|
|
|
|
help=("Save a checkpoint of the training state every X iterations."),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--learning_rate",
|
|
|
|
|
type=float,
|
|
|
|
|
default=5e-6,
|
|
|
|
|
help="Initial learning rate (after the potential warmup period) to use.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--logging_dir",
|
|
|
|
|
type=str,
|
|
|
|
|
default="logs",
|
|
|
|
|
help=(
|
|
|
|
|
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
|
|
|
|
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--allow_tf32",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help=(
|
|
|
|
|
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
|
|
|
|
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--report_to",
|
|
|
|
|
type=str,
|
|
|
|
|
default="tensorboard",
|
|
|
|
|
help=(
|
|
|
|
|
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
|
|
|
|
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--mixed_precision",
|
|
|
|
|
type=str,
|
|
|
|
|
default="fp16",
|
|
|
|
|
choices=["no", "fp16", "bf16"],
|
|
|
|
|
help=(
|
|
|
|
|
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
|
|
|
|
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
|
|
|
|
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--enable_xformers_memory_efficient_attention",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Whether or not to use xformers.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--pgd_alpha",
|
|
|
|
|
type=float,
|
|
|
|
|
default=1.0 / 255,
|
|
|
|
|
help="The step size for pgd.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--pgd_eps",
|
|
|
|
|
type=int,
|
|
|
|
|
default=0.05,
|
|
|
|
|
help="The noise budget for pgd.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--target_image_path",
|
|
|
|
|
default=None,
|
|
|
|
|
help="target image for attacking",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if input_args is not None:
|
|
|
|
|
args = parser.parse_args(input_args)
|
|
|
|
|
@ -251,11 +340,8 @@ def parse_args(input_args=None):
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------
|
|
|
|
|
# Class image prompt dataset
|
|
|
|
|
# -----------------------------
|
|
|
|
|
class PromptDataset(Dataset):
|
|
|
|
|
"""用于批量生成 class 图像的提示词数据集,可在多 GPU 环境下并行采样。"""
|
|
|
|
|
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
|
|
|
|
|
|
|
|
|
def __init__(self, prompt, num_samples):
|
|
|
|
|
self.prompt = prompt
|
|
|
|
|
@ -271,9 +357,6 @@ class PromptDataset(Dataset):
|
|
|
|
|
return example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------
|
|
|
|
|
# IO
|
|
|
|
|
# -----------------------------
|
|
|
|
|
def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
|
|
|
|
|
image_transforms = transforms.Compose(
|
|
|
|
|
[
|
|
|
|
|
@ -289,10 +372,17 @@ def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
|
|
|
|
|
return images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------
|
|
|
|
|
# Core routines
|
|
|
|
|
# -----------------------------
|
|
|
|
|
def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: torch.Tensor, num_steps=20):
|
|
|
|
|
def train_one_epoch(
|
|
|
|
|
args,
|
|
|
|
|
models,
|
|
|
|
|
tokenizer,
|
|
|
|
|
noise_scheduler,
|
|
|
|
|
vae,
|
|
|
|
|
data_tensor: torch.Tensor,
|
|
|
|
|
num_steps=20,
|
|
|
|
|
):
|
|
|
|
|
# Load the tokenizer
|
|
|
|
|
|
|
|
|
|
unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
|
|
|
|
|
params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
|
|
|
|
|
|
|
|
|
|
@ -304,17 +394,17 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor:
|
|
|
|
|
eps=1e-08,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# IMPORTANT: only pass class_data_dir when with_prior_preservation is enabled.
|
|
|
|
|
train_dataset = DreamBoothDatasetFromTensor(
|
|
|
|
|
data_tensor,
|
|
|
|
|
args.instance_prompt,
|
|
|
|
|
tokenizer,
|
|
|
|
|
args.class_data_dir if args.with_prior_preservation else None,
|
|
|
|
|
args.class_data_dir,
|
|
|
|
|
args.class_prompt,
|
|
|
|
|
args.resolution,
|
|
|
|
|
args.center_crop,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# weight_dtype = torch.bfloat16
|
|
|
|
|
weight_dtype = torch.bfloat16
|
|
|
|
|
device = torch.device("cuda")
|
|
|
|
|
|
|
|
|
|
@ -326,35 +416,33 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor:
|
|
|
|
|
unet.train()
|
|
|
|
|
text_encoder.train()
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
step_data = train_dataset[step % len(train_dataset)]
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"[err] train_one_epoch dataset getitem failed at step={step}: {repr(e)}")
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
|
|
|
|
|
device, dtype=weight_dtype
|
|
|
|
|
)
|
|
|
|
|
input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
|
|
|
|
|
except KeyError as e:
|
|
|
|
|
logger.error(
|
|
|
|
|
f"[err] missing key in step_data at step={step}: missing={str(e)}. "
|
|
|
|
|
f"with_prior_preservation={args.with_prior_preservation}"
|
|
|
|
|
)
|
|
|
|
|
raise
|
|
|
|
|
step_data = train_dataset[step % len(train_dataset)]
|
|
|
|
|
pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
|
|
|
|
|
device, dtype=weight_dtype
|
|
|
|
|
)
|
|
|
|
|
input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
|
|
|
|
|
|
|
|
|
|
latents = vae.encode(pixel_values).latent_dist.sample()
|
|
|
|
|
latents = latents * vae.config.scaling_factor
|
|
|
|
|
|
|
|
|
|
# Sample noise that we'll add to the latents
|
|
|
|
|
noise = torch.randn_like(latents)
|
|
|
|
|
bsz = latents.shape[0]
|
|
|
|
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
|
|
|
|
|
# Sample a random timestep for each image
|
|
|
|
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
|
|
|
|
timesteps = timesteps.long()
|
|
|
|
|
|
|
|
|
|
# Add noise to the latents according to the noise magnitude at each timestep
|
|
|
|
|
# (this is the forward diffusion process)
|
|
|
|
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
|
|
|
|
|
|
|
|
|
# Get the text embedding for conditioning
|
|
|
|
|
encoder_hidden_states = text_encoder(input_ids)[0]
|
|
|
|
|
|
|
|
|
|
# Predict the noise residual
|
|
|
|
|
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
|
|
|
|
|
|
|
|
|
# Get the target for loss depending on the prediction type
|
|
|
|
|
if noise_scheduler.config.prediction_type == "epsilon":
|
|
|
|
|
target = noise
|
|
|
|
|
elif noise_scheduler.config.prediction_type == "v_prediction":
|
|
|
|
|
@ -362,37 +450,47 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor:
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
|
|
|
|
|
|
|
|
|
# with prior preservation loss
|
|
|
|
|
if args.with_prior_preservation:
|
|
|
|
|
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
|
|
|
|
target, target_prior = torch.chunk(target, 2, dim=0)
|
|
|
|
|
|
|
|
|
|
# Compute instance loss
|
|
|
|
|
instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
|
|
|
|
|
|
|
|
|
# Compute prior loss
|
|
|
|
|
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
|
|
|
|
|
|
|
|
|
# Add the prior loss to the instance loss.
|
|
|
|
|
loss = instance_loss + args.prior_loss_weight * prior_loss
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
prior_loss = torch.tensor(0.0, device=device)
|
|
|
|
|
instance_loss = torch.tensor(0.0, device=device)
|
|
|
|
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
|
|
|
|
|
|
|
|
|
loss.backward()
|
|
|
|
|
torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
|
|
|
|
|
optimizer.step()
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"[train_one_epoch] step={step} loss={loss.detach().item():.6f} "
|
|
|
|
|
f"prior={prior_loss.detach().item():.6f} inst={instance_loss.detach().item():.6f}"
|
|
|
|
|
print(
|
|
|
|
|
f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, instance_loss: {instance_loss.detach().item()}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states
|
|
|
|
|
del model_pred, target, loss, prior_loss, instance_loss
|
|
|
|
|
|
|
|
|
|
del optimizer, train_dataset, params_to_optimize
|
|
|
|
|
_cuda_gc()
|
|
|
|
|
return [unet, text_encoder]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, original_images, target_tensor, num_steps: int):
|
|
|
|
|
def pgd_attack(
|
|
|
|
|
args,
|
|
|
|
|
models,
|
|
|
|
|
tokenizer,
|
|
|
|
|
noise_scheduler,
|
|
|
|
|
vae,
|
|
|
|
|
data_tensor: torch.Tensor,
|
|
|
|
|
original_images: torch.Tensor,
|
|
|
|
|
target_tensor: torch.Tensor,
|
|
|
|
|
num_steps: int,
|
|
|
|
|
):
|
|
|
|
|
"""Return new perturbed data"""
|
|
|
|
|
|
|
|
|
|
unet, text_encoder = models
|
|
|
|
|
weight_dtype = torch.bfloat16
|
|
|
|
|
device = torch.device("cuda")
|
|
|
|
|
@ -417,14 +515,24 @@ def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, origi
|
|
|
|
|
latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
|
|
|
|
|
latents = latents * vae.config.scaling_factor
|
|
|
|
|
|
|
|
|
|
# Sample noise that we'll add to the latents
|
|
|
|
|
noise = torch.randn_like(latents)
|
|
|
|
|
bsz = latents.shape[0]
|
|
|
|
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
|
|
|
|
|
# Sample a random timestep for each image
|
|
|
|
|
#noise_scheduler.config.num_train_timesteps
|
|
|
|
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
|
|
|
|
timesteps = timesteps.long()
|
|
|
|
|
# Add noise to the latents according to the noise magnitude at each timestep
|
|
|
|
|
# (this is the forward diffusion process)
|
|
|
|
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
|
|
|
|
|
|
|
|
|
# Get the text embedding for conditioning
|
|
|
|
|
encoder_hidden_states = text_encoder(input_ids.to(device))[0]
|
|
|
|
|
|
|
|
|
|
# Predict the noise residual
|
|
|
|
|
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
|
|
|
|
|
|
|
|
|
# Get the target for loss depending on the prediction type
|
|
|
|
|
if noise_scheduler.config.prediction_type == "epsilon":
|
|
|
|
|
target = noise
|
|
|
|
|
elif noise_scheduler.config.prediction_type == "v_prediction":
|
|
|
|
|
@ -432,10 +540,11 @@ def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, origi
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
|
|
|
|
|
|
|
|
|
unet.zero_grad(set_to_none=True)
|
|
|
|
|
text_encoder.zero_grad(set_to_none=True)
|
|
|
|
|
unet.zero_grad()
|
|
|
|
|
text_encoder.zero_grad()
|
|
|
|
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
|
|
|
|
|
|
|
|
|
# target-shift loss
|
|
|
|
|
if target_tensor is not None:
|
|
|
|
|
xtm1_pred = torch.cat(
|
|
|
|
|
[
|
|
|
|
|
@ -452,25 +561,16 @@ def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, origi
|
|
|
|
|
|
|
|
|
|
loss.backward()
|
|
|
|
|
|
|
|
|
|
alpha = args.pgd_alpha
|
|
|
|
|
alpha = args.pgd_alpha
|
|
|
|
|
eps = args.pgd_eps / 255
|
|
|
|
|
|
|
|
|
|
adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
|
|
|
|
|
eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
|
|
|
|
|
perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
|
|
|
|
|
|
|
|
|
|
logger.info(f"[pgd] step={step} loss={loss.detach().item():.6f} alpha={alpha} eps={eps}")
|
|
|
|
|
|
|
|
|
|
del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss
|
|
|
|
|
del adv_images, eta
|
|
|
|
|
|
|
|
|
|
_cuda_gc()
|
|
|
|
|
print(f"PGD loss - step {step}, loss: {loss.detach().item()}")
|
|
|
|
|
return perturbed_images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------
|
|
|
|
|
# Main
|
|
|
|
|
# -----------------------------
|
|
|
|
|
def main(args):
|
|
|
|
|
logging_dir = Path(args.output_dir, args.logging_dir)
|
|
|
|
|
|
|
|
|
|
@ -486,7 +586,6 @@ def main(args):
|
|
|
|
|
level=logging.INFO,
|
|
|
|
|
)
|
|
|
|
|
logger.info(accelerator.state, main_process_only=False)
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
datasets.utils.logging.set_verbosity_warning()
|
|
|
|
|
transformers.utils.logging.set_verbosity_warning()
|
|
|
|
|
@ -496,35 +595,15 @@ def main(args):
|
|
|
|
|
transformers.utils.logging.set_verbosity_error()
|
|
|
|
|
diffusers.utils.logging.set_verbosity_error()
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
logger.info(f"[run] using_file={__file__}")
|
|
|
|
|
log_args(args)
|
|
|
|
|
|
|
|
|
|
if args.seed is not None:
|
|
|
|
|
set_seed(args.seed)
|
|
|
|
|
|
|
|
|
|
if args.debug and accelerator.is_local_main_process:
|
|
|
|
|
log_cuda("startup", accelerator, sync=args.debug_cuda_sync)
|
|
|
|
|
|
|
|
|
|
# -------------------------
|
|
|
|
|
# Prior preservation: generate class images if needed
|
|
|
|
|
# -------------------------
|
|
|
|
|
# Generate class images if prior preservation is enabled.
|
|
|
|
|
if args.with_prior_preservation:
|
|
|
|
|
if args.class_data_dir is None:
|
|
|
|
|
raise ValueError("--with_prior_preservation requires --class_data_dir")
|
|
|
|
|
if args.class_prompt is None:
|
|
|
|
|
raise ValueError("--with_prior_preservation requires --class_prompt")
|
|
|
|
|
|
|
|
|
|
class_images_dir = Path(args.class_data_dir)
|
|
|
|
|
class_images_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
log_path_stats("class_dir_before", class_images_dir)
|
|
|
|
|
|
|
|
|
|
cur_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file())
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
logger.info(f"[class_gen] cur_class_images={cur_class_images} target={args.num_class_images}")
|
|
|
|
|
|
|
|
|
|
if not class_images_dir.exists():
|
|
|
|
|
class_images_dir.mkdir(parents=True)
|
|
|
|
|
cur_class_images = len(list(class_images_dir.iterdir()))
|
|
|
|
|
if cur_class_images < args.num_class_images:
|
|
|
|
|
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
|
|
|
|
if args.mixed_precision == "fp32":
|
|
|
|
|
@ -533,12 +612,6 @@ def main(args):
|
|
|
|
|
torch_dtype = torch.float16
|
|
|
|
|
elif args.mixed_precision == "bf16":
|
|
|
|
|
torch_dtype = torch.bfloat16
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
logger.info(f"[class_gen] will_generate={args.num_class_images - cur_class_images} torch_dtype={torch_dtype}")
|
|
|
|
|
if args.debug:
|
|
|
|
|
log_cuda("before_pipeline_load", accelerator, sync=args.debug_cuda_sync)
|
|
|
|
|
|
|
|
|
|
pipeline = DiffusionPipeline.from_pretrained(
|
|
|
|
|
args.pretrained_model_name_or_path,
|
|
|
|
|
torch_dtype=torch_dtype,
|
|
|
|
|
@ -548,67 +621,56 @@ def main(args):
|
|
|
|
|
pipeline.set_progress_bar_config(disable=True)
|
|
|
|
|
|
|
|
|
|
num_new_images = args.num_class_images - cur_class_images
|
|
|
|
|
logger.info(f"Number of class images to sample: {num_new_images}.")
|
|
|
|
|
|
|
|
|
|
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
|
|
|
|
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
|
|
|
|
|
|
|
|
|
sample_dataloader = accelerator.prepare(sample_dataloader)
|
|
|
|
|
pipeline.to(accelerator.device)
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process and args.debug:
|
|
|
|
|
log_cuda("after_pipeline_to_device", accelerator, sync=args.debug_cuda_sync)
|
|
|
|
|
|
|
|
|
|
for example in tqdm(
|
|
|
|
|
sample_dataloader,
|
|
|
|
|
desc="Generating class images",
|
|
|
|
|
disable=not accelerator.is_local_main_process,
|
|
|
|
|
):
|
|
|
|
|
images = pipeline(example["prompt"]).images
|
|
|
|
|
if accelerator.is_local_main_process and args.debug:
|
|
|
|
|
logger.info(f"[class_gen] batch_prompts={len(example['prompt'])} generated_images={len(images)}")
|
|
|
|
|
|
|
|
|
|
for i, image in enumerate(images):
|
|
|
|
|
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
|
|
|
|
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
|
|
|
|
image.save(image_filename)
|
|
|
|
|
|
|
|
|
|
del pipeline, sample_dataset, sample_dataloader
|
|
|
|
|
_cuda_gc()
|
|
|
|
|
|
|
|
|
|
accelerator.wait_for_everyone()
|
|
|
|
|
|
|
|
|
|
final_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file())
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
logger.info(f"[class_gen] done final_class_images={final_class_images}")
|
|
|
|
|
log_path_stats("class_dir_after", class_images_dir)
|
|
|
|
|
if final_class_images == 0:
|
|
|
|
|
raise RuntimeError(f"class image generation failed: {class_images_dir} is still empty.")
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
accelerator.wait_for_everyone()
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
logger.info("[class_gen] skipped (already enough images)")
|
|
|
|
|
else:
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
logger.info("[class_gen] disabled (with_prior_preservation is False)")
|
|
|
|
|
del pipeline
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
# -------------------------
|
|
|
|
|
# Load models / tokenizer / scheduler / VAE
|
|
|
|
|
# -------------------------
|
|
|
|
|
# import correct text encoder class
|
|
|
|
|
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process and args.debug:
|
|
|
|
|
log_cuda("before_load_models", accelerator, sync=args.debug_cuda_sync)
|
|
|
|
|
|
|
|
|
|
# Load scheduler and models
|
|
|
|
|
text_encoder = text_encoder_cls.from_pretrained(
|
|
|
|
|
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
|
|
|
|
args.pretrained_model_name_or_path,
|
|
|
|
|
subfolder="text_encoder",
|
|
|
|
|
revision=args.revision,
|
|
|
|
|
)
|
|
|
|
|
unet = UNet2DConditionModel.from_pretrained(
|
|
|
|
|
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
|
|
|
|
|
args.pretrained_model_name_or_path,
|
|
|
|
|
subfolder="tokenizer",
|
|
|
|
|
revision=args.revision,
|
|
|
|
|
use_fast=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
|
|
|
|
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).cuda()
|
|
|
|
|
|
|
|
|
|
vae = AutoencoderKL.from_pretrained(
|
|
|
|
|
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
|
|
|
|
|
).cuda()
|
|
|
|
|
|
|
|
|
|
vae.requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
if not args.train_text_encoder:
|
|
|
|
|
@ -617,60 +679,52 @@ def main(args):
|
|
|
|
|
if args.allow_tf32:
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
|
|
|
|
|
|
clean_data = load_data(
|
|
|
|
|
args.instance_data_dir_for_train,
|
|
|
|
|
size=args.resolution,
|
|
|
|
|
center_crop=args.center_crop,
|
|
|
|
|
)
|
|
|
|
|
perturbed_data = load_data(
|
|
|
|
|
args.instance_data_dir_for_adversarial,
|
|
|
|
|
size=args.resolution,
|
|
|
|
|
center_crop=args.center_crop,
|
|
|
|
|
)
|
|
|
|
|
original_data = perturbed_data.clone()
|
|
|
|
|
original_data.requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
if args.enable_xformers_memory_efficient_attention:
|
|
|
|
|
if is_xformers_available():
|
|
|
|
|
unet.enable_xformers_memory_efficient_attention()
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
logger.info("[xformers] enabled")
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process and args.debug:
|
|
|
|
|
log_cuda("after_load_models", accelerator, sync=args.debug_cuda_sync)
|
|
|
|
|
|
|
|
|
|
# -------------------------
|
|
|
|
|
# Load data tensors
|
|
|
|
|
# -------------------------
|
|
|
|
|
train_dir = Path(args.instance_data_dir_for_train)
|
|
|
|
|
adv_dir = Path(args.instance_data_dir_for_adversarial)
|
|
|
|
|
if accelerator.is_local_main_process and args.debug:
|
|
|
|
|
log_path_stats("train_dir", train_dir)
|
|
|
|
|
log_path_stats("adv_dir", adv_dir)
|
|
|
|
|
|
|
|
|
|
clean_data = load_data(train_dir, size=args.resolution, center_crop=args.center_crop)
|
|
|
|
|
perturbed_data = load_data(adv_dir, size=args.resolution, center_crop=args.center_crop)
|
|
|
|
|
original_data = perturbed_data.clone()
|
|
|
|
|
original_data.requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process and args.debug:
|
|
|
|
|
log_tensor_meta("clean_data_cpu", clean_data)
|
|
|
|
|
log_tensor_meta("perturbed_data_cpu", perturbed_data)
|
|
|
|
|
|
|
|
|
|
target_latent_tensor = None
|
|
|
|
|
if args.target_image_path is not None:
|
|
|
|
|
target_image_path = Path(args.target_image_path)
|
|
|
|
|
if not target_image_path.is_file():
|
|
|
|
|
raise ValueError(f"Target image path does not exist: {target_image_path}")
|
|
|
|
|
assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist"
|
|
|
|
|
|
|
|
|
|
target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution))
|
|
|
|
|
target_image = np.array(target_image)[None].transpose(0, 3, 1, 2)
|
|
|
|
|
|
|
|
|
|
target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0
|
|
|
|
|
target_latent_tensor = vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16)
|
|
|
|
|
target_latent_tensor = target_latent_tensor * vae.config.scaling_factor
|
|
|
|
|
target_latent_tensor = (
|
|
|
|
|
vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor
|
|
|
|
|
)
|
|
|
|
|
target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda()
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process and args.debug:
|
|
|
|
|
log_tensor_meta("target_latent_tensor", target_latent_tensor)
|
|
|
|
|
|
|
|
|
|
f = [unet, text_encoder]
|
|
|
|
|
for i in range(args.max_train_steps):
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
logger.info(f"[outer] i={i}/{args.max_train_steps}")
|
|
|
|
|
|
|
|
|
|
# 1. f' = f.clone()
|
|
|
|
|
f_sur = copy.deepcopy(f)
|
|
|
|
|
f_sur = train_one_epoch(args, f_sur, tokenizer, noise_scheduler, vae, clean_data, args.max_f_train_steps)
|
|
|
|
|
|
|
|
|
|
f_sur = train_one_epoch(
|
|
|
|
|
args,
|
|
|
|
|
f_sur,
|
|
|
|
|
tokenizer,
|
|
|
|
|
noise_scheduler,
|
|
|
|
|
vae,
|
|
|
|
|
clean_data,
|
|
|
|
|
args.max_f_train_steps,
|
|
|
|
|
)
|
|
|
|
|
perturbed_data = pgd_attack(
|
|
|
|
|
args,
|
|
|
|
|
f_sur,
|
|
|
|
|
@ -682,31 +736,33 @@ def main(args):
|
|
|
|
|
target_latent_tensor,
|
|
|
|
|
args.max_adv_train_steps,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
f = train_one_epoch(args, f, tokenizer, noise_scheduler, vae, perturbed_data, args.max_f_train_steps)
|
|
|
|
|
f = train_one_epoch(
|
|
|
|
|
args,
|
|
|
|
|
f,
|
|
|
|
|
tokenizer,
|
|
|
|
|
noise_scheduler,
|
|
|
|
|
vae,
|
|
|
|
|
perturbed_data,
|
|
|
|
|
args.max_f_train_steps,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if (i + 1) % args.checkpointing_iterations == 0:
|
|
|
|
|
save_folder = args.output_dir
|
|
|
|
|
os.makedirs(save_folder, exist_ok=True)
|
|
|
|
|
noised_imgs = perturbed_data.detach()
|
|
|
|
|
|
|
|
|
|
img_filenames = [Path(instance_path).stem for instance_path in list(adv_dir.iterdir()) if instance_path.is_file()]
|
|
|
|
|
|
|
|
|
|
img_filenames = [
|
|
|
|
|
Path(instance_path).stem
|
|
|
|
|
for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir())
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
for img_pixel, img_name in zip(noised_imgs, img_filenames):
|
|
|
|
|
save_path = os.path.join(save_folder, f"perturbed_{img_name}.png")
|
|
|
|
|
save_path = os.path.join(save_folder, f"perturbed_{img_name}.png")
|
|
|
|
|
Image.fromarray(
|
|
|
|
|
(img_pixel * 127.5 + 128)
|
|
|
|
|
.clamp(0, 255)
|
|
|
|
|
.to(torch.uint8)
|
|
|
|
|
.permute(1, 2, 0)
|
|
|
|
|
.cpu()
|
|
|
|
|
.numpy()
|
|
|
|
|
(img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
|
|
|
|
|
).save(save_path)
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process:
|
|
|
|
|
logger.info(f"[save] step={i+1} saved={len(img_filenames)} to {save_folder}")
|
|
|
|
|
|
|
|
|
|
_cuda_gc()
|
|
|
|
|
|
|
|
|
|
print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|