Revert "improve: 优化算法"

This reverts commit b5af0d22ab.
pull/50/head
梁浩 4 months ago
parent b5af0d22ab
commit e447ec0984

@ -1,11 +1,9 @@
import argparse
import copy
import gc
import hashlib
import itertools
import logging
import os
import random
from pathlib import Path
import datasets
@ -26,77 +24,12 @@ from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
logger = get_logger(__name__)
# -----------------------------
# Lightweight debug helpers (low overhead)
# -----------------------------
def _cuda_gc() -> None:
"""Best-effort CUDA memory cleanup (does not change algorithmic behavior)."""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _fmt_bytes(n: int) -> str:
return f"{n / (1024**2):.1f}MB"
def log_cuda(prefix: str, accelerator: Accelerator | None = None, sync: bool = False, extra: dict | None = None):
    """Log CUDA memory stats without copying tensors to CPU.

    Args:
        prefix: tag identifying the call site in the log line.
        accelerator: optional Accelerator whose device is named in the log;
            falls back to the literal "cuda" when absent.
        sync: when True, synchronize CUDA first for more accurate counters.
        extra: optional key/value pairs appended verbatim to the log line.
    """
    if not torch.cuda.is_available():
        logger.info(f"[mem] {prefix} cuda_not_available")
        return
    if sync:
        torch.cuda.synchronize()
    dev = "cuda" if accelerator is None else str(accelerator.device)
    # Counter name -> current value; order matches the emitted log line.
    counters = {
        "alloc": torch.cuda.memory_allocated(),
        "reserv": torch.cuda.memory_reserved(),
        "max_alloc": torch.cuda.max_memory_allocated(),
        "max_reserv": torch.cuda.max_memory_reserved(),
    }
    parts = [f"[mem] {prefix} dev={dev}"]
    parts.extend(f"{name}={_fmt_bytes(value)}" for name, value in counters.items())
    if extra:
        parts.extend(f"{k}={v}" for k, v in extra.items())
    logger.info(" ".join(parts))
def log_path_stats(prefix: str, p: Path):
    """Log directory/file existence and file count (best-effort; never raises)."""
    try:
        exists = p.exists()
        is_dir = exists and p.is_dir()
        # Count only regular files so subdirectories don't inflate the number.
        n_files = sum(1 for entry in p.iterdir() if entry.is_file()) if is_dir else 0
        logger.info(f"[path] {prefix} path={str(p)} exists={exists} is_dir={is_dir} files={n_files}")
    except Exception as e:
        # Diagnostics must never take down training; report the stat failure instead.
        logger.info(f"[path] {prefix} path={str(p)} stat_error={repr(e)}")
def log_args(args):
    """Emit one '[args] key=value' line per parsed-argument attribute, sorted by key."""
    for key, value in sorted(vars(args).items()):
        logger.info(f"[args] {key}={value}")
def log_tensor_meta(prefix: str, t: torch.Tensor | None):
    """Log a tensor's shape/dtype/device/grad metadata, or 'None' for a missing tensor."""
    if t is None:
        logger.info(f"[tensor] {prefix} None")
        return
    meta = (
        f"shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
        f"requires_grad={t.requires_grad} is_leaf={t.is_leaf}"
    )
    logger.info(f"[tensor] {prefix} {meta}")
# -----------------------------
# Dataset
# -----------------------------
class DreamBoothDatasetFromTensor(Dataset):
"""基于内存张量的 DreamBooth 数据集:直接使用张量输入,返回图像与对应 prompt token。"""
"""Just like DreamBoothDataset, but take instance_images_tensor instead of path"""
def __init__(
self,
@ -120,19 +53,10 @@ class DreamBoothDatasetFromTensor(Dataset):
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
# Only keep files to avoid directories affecting length.
self.class_images_path = [p for p in self.class_data_root.iterdir() if p.is_file()]
self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
# Early, explicit failure instead of ZeroDivisionError later.
if self.num_class_images == 0:
raise ValueError(
f"class_data_dir is empty: {self.class_data_root}. "
f"Prior preservation requires class images. "
f"Please generate class images first, or fix class_data_dir, or disable --with_prior_preservation."
)
else:
self.class_data_root = None
@ -161,10 +85,8 @@ class DreamBoothDatasetFromTensor(Dataset):
).input_ids
if self.class_data_root:
if self.num_class_images == 0:
raise ValueError(f"class_data_dir became empty at runtime: {self.class_data_root}")
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if class_image.mode != "RGB":
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
@ -178,9 +100,6 @@ class DreamBoothDatasetFromTensor(Dataset):
return example
# -----------------------------
# Model helper
# -----------------------------
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
@ -201,47 +120,217 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
raise ValueError(f"{model_class} is not supported.")
# -----------------------------
# Args
# -----------------------------
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True)
parser.add_argument("--revision", type=str, default=None, required=False)
parser.add_argument("--tokenizer_name", type=str, default=None)
parser.add_argument("--instance_data_dir_for_train", type=str, default=None, required=True)
parser.add_argument("--instance_data_dir_for_adversarial", type=str, default=None, required=True)
parser.add_argument("--class_data_dir", type=str, default=None, required=False)
parser.add_argument("--instance_prompt", type=str, default=None, required=True)
parser.add_argument("--class_prompt", type=str, default=None)
parser.add_argument("--with_prior_preservation", default=False, action="store_true")
parser.add_argument("--prior_loss_weight", type=float, default=1.0)
parser.add_argument("--num_class_images", type=int, default=100)
parser.add_argument("--output_dir", type=str, default="text-inversion-model")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--resolution", type=int, default=512)
parser.add_argument("--center_crop", default=False, action="store_true")
parser.add_argument("--train_text_encoder", action="store_true")
parser.add_argument("--train_batch_size", type=int, default=4)
parser.add_argument("--sample_batch_size", type=int, default=8)
parser.add_argument("--max_train_steps", type=int, default=20)
parser.add_argument("--max_f_train_steps", type=int, default=10)
parser.add_argument("--max_adv_train_steps", type=int, default=10)
parser.add_argument("--checkpointing_iterations", type=int, default=5)
parser.add_argument("--learning_rate", type=float, default=5e-6)
parser.add_argument("--logging_dir", type=str, default="logs")
parser.add_argument("--allow_tf32", action="store_true")
parser.add_argument("--report_to", type=str, default="tensorboard")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")
parser.add_argument("--pgd_alpha", type=float, default=1.0 / 255)
parser.add_argument("--pgd_eps", type=int, default=0.05)
parser.add_argument("--target_image_path", default=None)
# Debug / diagnostics (low-overhead)
parser.add_argument("--debug", action="store_true", help="Enable detailed logs for failure points.")
parser.add_argument("--debug_cuda_sync", action="store_true", help="Synchronize CUDA for more accurate mem logs.")
parser.add_argument("--debug_step0_only", action="store_true", help="Only print per-step logs for step 0.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--instance_data_dir_for_train",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--instance_data_dir_for_adversarial",
type=str,
default=None,
required=True,
help="A folder containing the images to add adversarial noise",
)
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument(
"--prior_loss_weight",
type=float,
default=1.0,
help="The weight of prior preservation loss.",
)
parser.add_argument(
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=4,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--sample_batch_size",
type=int,
default=8,
help="Batch size (per device) for sampling images.",
)
parser.add_argument(
"--max_train_steps",
type=int,
default=20,
help="Total number of training steps to perform.",
)
parser.add_argument(
"--max_f_train_steps",
type=int,
default=10,
help="Total number of sub-steps to train surogate model.",
)
parser.add_argument(
"--max_adv_train_steps",
type=int,
default=10,
help="Total number of sub-steps to train adversarial noise.",
)
parser.add_argument(
"--checkpointing_iterations",
type=int,
default=5,
help=("Save a checkpoint of the training state every X iterations."),
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="fp16",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention",
action="store_true",
help="Whether or not to use xformers.",
)
parser.add_argument(
"--pgd_alpha",
type=float,
default=1.0 / 255,
help="The step size for pgd.",
)
parser.add_argument(
"--pgd_eps",
type=int,
default=0.05,
help="The noise budget for pgd.",
)
parser.add_argument(
"--target_image_path",
default=None,
help="target image for attacking",
)
if input_args is not None:
args = parser.parse_args(input_args)
@ -251,11 +340,8 @@ def parse_args(input_args=None):
return args
# -----------------------------
# Class image prompt dataset
# -----------------------------
class PromptDataset(Dataset):
"""用于批量生成 class 图像的提示词数据集,可在多 GPU 环境下并行采样。"""
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
@ -271,9 +357,6 @@ class PromptDataset(Dataset):
return example
# -----------------------------
# IO
# -----------------------------
def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
image_transforms = transforms.Compose(
[
@ -289,10 +372,17 @@ def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
return images
# -----------------------------
# Core routines
# -----------------------------
def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor: torch.Tensor, num_steps=20):
def train_one_epoch(
args,
models,
tokenizer,
noise_scheduler,
vae,
data_tensor: torch.Tensor,
num_steps=20,
):
# Load the tokenizer
unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
@ -304,17 +394,17 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor:
eps=1e-08,
)
# IMPORTANT: only pass class_data_dir when with_prior_preservation is enabled.
train_dataset = DreamBoothDatasetFromTensor(
data_tensor,
args.instance_prompt,
tokenizer,
args.class_data_dir if args.with_prior_preservation else None,
args.class_data_dir,
args.class_prompt,
args.resolution,
args.center_crop,
)
# weight_dtype = torch.bfloat16
weight_dtype = torch.bfloat16
device = torch.device("cuda")
@ -326,35 +416,33 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor:
unet.train()
text_encoder.train()
try:
step_data = train_dataset[step % len(train_dataset)]
except Exception as e:
logger.error(f"[err] train_one_epoch dataset getitem failed at step={step}: {repr(e)}")
raise
try:
pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
device, dtype=weight_dtype
)
input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
except KeyError as e:
logger.error(
f"[err] missing key in step_data at step={step}: missing={str(e)}. "
f"with_prior_preservation={args.with_prior_preservation}"
)
raise
step_data = train_dataset[step % len(train_dataset)]
pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
device, dtype=weight_dtype
)
input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(input_ids)[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
@ -362,37 +450,47 @@ def train_one_epoch(args, models, tokenizer, noise_scheduler, vae, data_tensor:
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# with prior preservation loss
if args.with_prior_preservation:
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = instance_loss + args.prior_loss_weight * prior_loss
else:
prior_loss = torch.tensor(0.0, device=device)
instance_loss = torch.tensor(0.0, device=device)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
loss.backward()
torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
optimizer.step()
optimizer.zero_grad()
logger.info(
f"[train_one_epoch] step={step} loss={loss.detach().item():.6f} "
f"prior={prior_loss.detach().item():.6f} inst={instance_loss.detach().item():.6f}"
print(
f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, instance_loss: {instance_loss.detach().item()}"
)
del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states
del model_pred, target, loss, prior_loss, instance_loss
del optimizer, train_dataset, params_to_optimize
_cuda_gc()
return [unet, text_encoder]
def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, original_images, target_tensor, num_steps: int):
def pgd_attack(
args,
models,
tokenizer,
noise_scheduler,
vae,
data_tensor: torch.Tensor,
original_images: torch.Tensor,
target_tensor: torch.Tensor,
num_steps: int,
):
"""Return new perturbed data"""
unet, text_encoder = models
weight_dtype = torch.bfloat16
device = torch.device("cuda")
@ -417,14 +515,24 @@ def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, origi
latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
# Sample a random timestep for each image
#noise_scheduler.config.num_train_timesteps
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(input_ids.to(device))[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
@ -432,10 +540,11 @@ def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, origi
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
unet.zero_grad(set_to_none=True)
text_encoder.zero_grad(set_to_none=True)
unet.zero_grad()
text_encoder.zero_grad()
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# target-shift loss
if target_tensor is not None:
xtm1_pred = torch.cat(
[
@ -452,25 +561,16 @@ def pgd_attack(args, models, tokenizer, noise_scheduler, vae, data_tensor, origi
loss.backward()
alpha = args.pgd_alpha
alpha = args.pgd_alpha
eps = args.pgd_eps / 255
adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
logger.info(f"[pgd] step={step} loss={loss.detach().item():.6f} alpha={alpha} eps={eps}")
del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss
del adv_images, eta
_cuda_gc()
print(f"PGD loss - step {step}, loss: {loss.detach().item()}")
return perturbed_images
# -----------------------------
# Main
# -----------------------------
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
@ -486,7 +586,6 @@ def main(args):
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
@ -496,35 +595,15 @@ def main(args):
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
if accelerator.is_local_main_process:
logger.info(f"[run] using_file={__file__}")
log_args(args)
if args.seed is not None:
set_seed(args.seed)
if args.debug and accelerator.is_local_main_process:
log_cuda("startup", accelerator, sync=args.debug_cuda_sync)
# -------------------------
# Prior preservation: generate class images if needed
# -------------------------
# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
if args.class_data_dir is None:
raise ValueError("--with_prior_preservation requires --class_data_dir")
if args.class_prompt is None:
raise ValueError("--with_prior_preservation requires --class_prompt")
class_images_dir = Path(args.class_data_dir)
class_images_dir.mkdir(parents=True, exist_ok=True)
if accelerator.is_local_main_process:
log_path_stats("class_dir_before", class_images_dir)
cur_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file())
if accelerator.is_local_main_process:
logger.info(f"[class_gen] cur_class_images={cur_class_images} target={args.num_class_images}")
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
if args.mixed_precision == "fp32":
@ -533,12 +612,6 @@ def main(args):
torch_dtype = torch.float16
elif args.mixed_precision == "bf16":
torch_dtype = torch.bfloat16
if accelerator.is_local_main_process:
logger.info(f"[class_gen] will_generate={args.num_class_images - cur_class_images} torch_dtype={torch_dtype}")
if args.debug:
log_cuda("before_pipeline_load", accelerator, sync=args.debug_cuda_sync)
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
@ -548,67 +621,56 @@ def main(args):
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
if accelerator.is_local_main_process and args.debug:
log_cuda("after_pipeline_to_device", accelerator, sync=args.debug_cuda_sync)
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not accelerator.is_local_main_process,
):
images = pipeline(example["prompt"]).images
if accelerator.is_local_main_process and args.debug:
logger.info(f"[class_gen] batch_prompts={len(example['prompt'])} generated_images={len(images)}")
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline, sample_dataset, sample_dataloader
_cuda_gc()
accelerator.wait_for_everyone()
final_class_images = sum(1 for p in class_images_dir.iterdir() if p.is_file())
if accelerator.is_local_main_process:
logger.info(f"[class_gen] done final_class_images={final_class_images}")
log_path_stats("class_dir_after", class_images_dir)
if final_class_images == 0:
raise RuntimeError(f"class image generation failed: {class_images_dir} is still empty.")
else:
accelerator.wait_for_everyone()
if accelerator.is_local_main_process:
logger.info("[class_gen] skipped (already enough images)")
else:
if accelerator.is_local_main_process:
logger.info("[class_gen] disabled (with_prior_preservation is False)")
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
# -------------------------
# Load models / tokenizer / scheduler / VAE
# -------------------------
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
if accelerator.is_local_main_process and args.debug:
log_cuda("before_load_models", accelerator, sync=args.debug_cuda_sync)
# Load scheduler and models
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).cuda()
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
).cuda()
vae.requires_grad_(False)
if not args.train_text_encoder:
@ -617,60 +679,52 @@ def main(args):
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
clean_data = load_data(
args.instance_data_dir_for_train,
size=args.resolution,
center_crop=args.center_crop,
)
perturbed_data = load_data(
args.instance_data_dir_for_adversarial,
size=args.resolution,
center_crop=args.center_crop,
)
original_data = perturbed_data.clone()
original_data.requires_grad_(False)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
if accelerator.is_local_main_process:
logger.info("[xformers] enabled")
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
if accelerator.is_local_main_process and args.debug:
log_cuda("after_load_models", accelerator, sync=args.debug_cuda_sync)
# -------------------------
# Load data tensors
# -------------------------
train_dir = Path(args.instance_data_dir_for_train)
adv_dir = Path(args.instance_data_dir_for_adversarial)
if accelerator.is_local_main_process and args.debug:
log_path_stats("train_dir", train_dir)
log_path_stats("adv_dir", adv_dir)
clean_data = load_data(train_dir, size=args.resolution, center_crop=args.center_crop)
perturbed_data = load_data(adv_dir, size=args.resolution, center_crop=args.center_crop)
original_data = perturbed_data.clone()
original_data.requires_grad_(False)
if accelerator.is_local_main_process and args.debug:
log_tensor_meta("clean_data_cpu", clean_data)
log_tensor_meta("perturbed_data_cpu", perturbed_data)
target_latent_tensor = None
if args.target_image_path is not None:
target_image_path = Path(args.target_image_path)
if not target_image_path.is_file():
raise ValueError(f"Target image path does not exist: {target_image_path}")
assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist"
target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution))
target_image = np.array(target_image)[None].transpose(0, 3, 1, 2)
target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0
target_latent_tensor = vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16)
target_latent_tensor = target_latent_tensor * vae.config.scaling_factor
target_latent_tensor = (
vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor
)
target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda()
if accelerator.is_local_main_process and args.debug:
log_tensor_meta("target_latent_tensor", target_latent_tensor)
f = [unet, text_encoder]
for i in range(args.max_train_steps):
if accelerator.is_local_main_process:
logger.info(f"[outer] i={i}/{args.max_train_steps}")
# 1. f' = f.clone()
f_sur = copy.deepcopy(f)
f_sur = train_one_epoch(args, f_sur, tokenizer, noise_scheduler, vae, clean_data, args.max_f_train_steps)
f_sur = train_one_epoch(
args,
f_sur,
tokenizer,
noise_scheduler,
vae,
clean_data,
args.max_f_train_steps,
)
perturbed_data = pgd_attack(
args,
f_sur,
@ -682,31 +736,33 @@ def main(args):
target_latent_tensor,
args.max_adv_train_steps,
)
f = train_one_epoch(args, f, tokenizer, noise_scheduler, vae, perturbed_data, args.max_f_train_steps)
f = train_one_epoch(
args,
f,
tokenizer,
noise_scheduler,
vae,
perturbed_data,
args.max_f_train_steps,
)
if (i + 1) % args.checkpointing_iterations == 0:
save_folder = args.output_dir
os.makedirs(save_folder, exist_ok=True)
noised_imgs = perturbed_data.detach()
img_filenames = [Path(instance_path).stem for instance_path in list(adv_dir.iterdir()) if instance_path.is_file()]
img_filenames = [
Path(instance_path).stem
for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir())
]
for img_pixel, img_name in zip(noised_imgs, img_filenames):
save_path = os.path.join(save_folder, f"perturbed_{img_name}.png")
save_path = os.path.join(save_folder, f"perturbed_{img_name}.png")
Image.fromarray(
(img_pixel * 127.5 + 128)
.clamp(0, 255)
.to(torch.uint8)
.permute(1, 2, 0)
.cpu()
.numpy()
(img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
).save(save_path)
if accelerator.is_local_main_process:
logger.info(f"[save] step={i+1} saved={len(img_filenames)} to {save_folder}")
_cuda_gc()
print(f"Saved perturbed images at step {i+1} to {save_folder} (Files are overwritten)")
if __name__ == "__main__":

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save