From 98796e154377e4537fabc78e460134a214e06428 Mon Sep 17 00:00:00 2001 From: Ryan <3266408525@qq.com> Date: Sat, 13 Dec 2025 21:59:32 +0800 Subject: [PATCH 1/3] =?UTF-8?q?improve:=20=E6=94=B9=E8=BF=9BDreambooth?= =?UTF-8?q?=E5=BE=AE=E8=B0=83=E8=B6=85=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../algorithms/finetune/train_db_gen_trace.py | 1253 +++++------------ src/backend/app/scripts/finetune_db.sh | 68 +- 2 files changed, 386 insertions(+), 935 deletions(-) diff --git a/src/backend/app/algorithms/finetune/train_db_gen_trace.py b/src/backend/app/algorithms/finetune/train_db_gen_trace.py index ae9b980..34efebc 100644 --- a/src/backend/app/algorithms/finetune/train_db_gen_trace.py +++ b/src/backend/app/algorithms/finetune/train_db_gen_trace.py @@ -1,35 +1,21 @@ #!/usr/bin/env python # coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and import argparse +import contextlib import copy import gc -import importlib import itertools import logging import math import os -import shutil import warnings from pathlib import Path -import pandas as pd import numpy as np +import pandas as pd import torch import torch.nn.functional as F -import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger @@ -44,7 +30,6 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig -import diffusers from diffusers import ( AutoencoderKL, DDPMScheduler, @@ -54,29 +39,33 @@ from diffusers import ( ) from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr -from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils import is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
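With diffusers no longer imported at module scope and the `check_min_version` call commented out below, nothing asserts a compatible diffusers release at startup. A minimal local guard is sketched here, assuming the `0.30.0` floor implied by the commented-out call and that `packaging` is installed (diffusers itself depends on it):

```python
# Hypothetical stand-in for the disabled check_min_version call; the version
# floor is taken from the commented-out "0.30.0.dev0" string below.
import diffusers
from packaging import version

MIN_DIFFUSERS_VERSION = "0.30.0"
if version.parse(diffusers.__version__) < version.parse(MIN_DIFFUSERS_VERSION):
    raise ImportError(
        f"This script targets diffusers>={MIN_DIFFUSERS_VERSION}, "
        f"found {diffusers.__version__}."
    )
```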
-# check_min_version("0.30.0.dev0") - logger = get_logger(__name__) +# ------------------------------------------------------------------------- +# 功能模块:模型卡保存 +# 1) 该模块用于生成/更新 README.md,记录训练来源与关键配置 +# 2) 支持将训练后验证生成的示例图片写入输出目录并写入引用 +# 3) 便于后续将模型上传到 Hub 时展示效果与实验信息 +# 4) 不参与训练与梯度计算,不影响参数更新与收敛行为 +# 5) 既可服务于 Hub 发布,也可用于本地实验的结果归档 +# ------------------------------------------------------------------------- def save_model_card( repo_id: str, - images: list = None, - base_model: str = None, - train_text_encoder=False, - prompt: str = None, - repo_folder: str = None, - pipeline: DiffusionPipeline = None, + images: list | None = None, + base_model: str | None = None, + train_text_encoder: bool = False, + prompt: str | None = None, + repo_folder: str | None = None, + pipeline: DiffusionPipeline | None = None, ): img_str = "" if images is not None: @@ -87,11 +76,13 @@ def save_model_card( model_description = f""" # DreamBooth - {repo_id} -This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). -You can find some example images in the following. \n +本模型由 {base_model} 进行 DreamBooth 微调得到。 +训练使用的实例提示词为:{prompt} + +下面展示部分训练后推理样例图像: {img_str} -DreamBooth for the text encoder was enabled: {train_text_encoder}. +是否训练了文本编码器:{train_text_encoder} """ model_card = load_or_create_model_card( repo_id_or_path=repo_id, @@ -103,111 +94,107 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}. inference=True, ) - tags = ["text-to-image", "dreambooth", "diffusers-training"] - if isinstance(pipeline, StableDiffusionPipeline): - tags.extend(["stable-diffusion", "stable-diffusion-diffusers"]) - else: - tags.extend(["if", "if-diffusers"]) + tags = ["text-to-image", "dreambooth", "diffusers-training", "stable-diffusion", "stable-diffusion-diffusers"] model_card = populate_model_card(model_card, tags=tags) - model_card.save(os.path.join(repo_folder, "README.md")) -def log_validation( - text_encoder, - tokenizer, - unet, - vae, - args, - accelerator, - weight_dtype, - global_step, - prompt_embeds, - negative_prompt_embeds, +# ------------------------------------------------------------------------- +# 功能模块:训练后纯文本推理(validation) +# 1) 该模块仅在训练完全结束后执行,不参与训练过程与优化器状态 +# 2) 该模块从 output_dir 重新加载微调后的 pipeline,避免与训练对象耦合 +# 3) 推理只接受文本提示词,不输入任何图像,不走 img2img 相关路径 +# 4) 可设置推理步数与随机种子,方便提高细节并保证可复现 +# 5) 输出 PIL 图片列表,可保存到目录并写入日志系统便于对比分析 +# ------------------------------------------------------------------------- +def run_validation_txt2img( + finetuned_model_dir: str, + prompt: str, + negative_prompt: str, + num_images: int, + num_inference_steps: int, + guidance_scale: float, + seed: int | None, + accelerator: Accelerator, + weight_dtype: torch.dtype, + global_step: int, ): logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." 
+ f"开始 validation 文生图:数量={num_images},步数={num_inference_steps},guidance={guidance_scale},提示词={prompt}" ) - pipeline_args = {} - - if vae is not None: - pipeline_args["vae"] = vae - - # create pipeline (note: unet and vae are loaded again in float32) - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - tokenizer=tokenizer, - text_encoder=text_encoder, - unet=unet, - revision=args.revision, - variant=args.variant, + pipe = StableDiffusionPipeline.from_pretrained( + finetuned_model_dir, torch_dtype=weight_dtype, - **pipeline_args, + safety_checker=None, + local_files_only=True, ) - - - # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" + if not isinstance(pipe, StableDiffusionPipeline): + raise TypeError(f"加载的 pipeline 类型异常:{type(pipe)},需要 StableDiffusionPipeline 才能保证纯文本生图。") - scheduler_args["variance_type"] = variance_type + pipe = pipe.to(accelerator.device) + pipe.set_progress_bar_config(disable=True) + pipe.safety_checker = lambda images, clip_input: (images, [False for _ in range(len(images))]) - module = importlib.import_module("diffusers") - scheduler_class = getattr(module, args.validation_scheduler) - pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args) - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - pipeline.safety_checker = lambda images, clip_input: (images, [False for i in range(0, len(images))]) # disable safety checker + pipe.enable_attention_slicing() + pipe.enable_vae_slicing() - if args.pre_compute_text_embeddings: - pipeline_args = { - "prompt_embeds": prompt_embeds, - "negative_prompt_embeds": negative_prompt_embeds, - } + if accelerator.device.type == "cuda": + if accelerator.mixed_precision == "bf16": + infer_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16) + elif accelerator.mixed_precision == "fp16": + infer_ctx = torch.autocast(device_type="cuda", dtype=torch.float16) + else: + infer_ctx = contextlib.nullcontext() else: - pipeline_args = {"prompt": args.validation_prompt} + infer_ctx = contextlib.nullcontext() - # run inference - generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - if args.validation_images is None: - for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] - images.append(image) - else: - for image in args.validation_images: - image = Image.open(image) - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] - images.append(image) + with infer_ctx: + for i in range(num_images): + generator = None + if seed is not None: + generator = torch.Generator(device=accelerator.device).manual_seed(seed + i) + + out = pipe( + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt is not None else "", + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + ) + images.append(out.images[0]) for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, global_step, 
dataformats="NHWC") + tracker.writer.add_images("validation_txt2img", np_images, global_step, dataformats="NHWC") if tracker.name == "wandb": tracker.log( { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + "validation_txt2img": [ + wandb.Image(img, caption=f"{i}: {prompt}") for i, img in enumerate(images) ] } ) - del pipeline - torch.cuda.empty_cache() + del pipe + if torch.cuda.is_available(): + torch.cuda.empty_cache() return images -def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): +# ------------------------------------------------------------------------- +# 功能模块:从模型目录推断 TextEncoder 类型 +# 1) 不同扩散模型对应不同文本编码器架构,需动态识别加载类 +# 2) 通过读取 text_encoder/config 来获取 architectures 字段 +# 3) 该模块返回类对象,用于后续 from_pretrained 加载权重 +# 4) 便于同一训练脚本兼容多模型,而不写死具体实现 +# 5) 若架构不支持会直接报错,避免训练过程走到一半才失败 +# ------------------------------------------------------------------------- +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str | None): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", @@ -228,382 +215,107 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st return T5EncoderModel else: - raise ValueError(f"{model_class} is not supported.") + raise ValueError(f"{model_class} 不受支持。") +# ------------------------------------------------------------------------- +# 功能模块:命令行参数解析 +# 1) 本模块定义 DreamBooth 训练参数与训练后 validation 参数 +# 2) 训练负责微调权重与记录坐标,validation 只负责训练后文生图输出 +# 3) 不提供训练中间验证参数,避免任何中途采样影响训练流程 +# 4) 对关键参数组合做合法性检查,减少运行中途异常 +# 5) 支持通过 shell 脚本传参实现批量实验、对比与复现 +# ------------------------------------------------------------------------- def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--variant", - type=str, - default=None, - help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", - ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) - parser.add_argument( - "--instance_data_dir", - type=str, - default=None, - required=True, - help="A folder containing the training data of instance images.", - ) - parser.add_argument( - "--class_data_dir", - type=str, - default=None, - required=False, - help="A folder containing the training data of class images.", - ) - parser.add_argument( - "--instance_prompt", - type=str, - default=None, - required=True, - help="The prompt with identifier specifying the instance", - ) - parser.add_argument( - "--class_prompt", - type=str, - default=None, - help="The prompt to specify images in the same class as provided instance images.", - ) - parser.add_argument( - "--with_prior_preservation", - default=False, - action="store_true", - help="Flag to add prior preservation loss.", - ) - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") - parser.add_argument( - "--num_class_images", - type=int, - default=100, - help=( - "Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt." - ), - ) - parser.add_argument( - "--output_dir", - type=str, - default="dreambooth-model", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--resolution", - type=int, - default=512, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), - ) - parser.add_argument( - "--center_crop", - default=False, - action="store_true", - help=( - "Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping." - ), - ) - parser.add_argument( - "--train_text_encoder", - action="store_true", - help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", - ) - parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." - ) - parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 
- ) + parser = argparse.ArgumentParser(description="DreamBooth 训练脚本(训练后纯文字生图 validation)") + + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True) + parser.add_argument("--revision", type=str, default=None, required=False) + parser.add_argument("--variant", type=str, default=None) + parser.add_argument("--tokenizer_name", type=str, default=None) + + parser.add_argument("--instance_data_dir", type=str, default=None, required=True) + parser.add_argument("--class_data_dir", type=str, default=None, required=False) + + parser.add_argument("--instance_prompt", type=str, default=None, required=True) + parser.add_argument("--class_prompt", type=str, default=None) + + parser.add_argument("--with_prior_preservation", default=False, action="store_true") + parser.add_argument("--prior_loss_weight", type=float, default=1.0) + parser.add_argument("--num_class_images", type=int, default=100) + + parser.add_argument("--output_dir", type=str, default="dreambooth-model") + parser.add_argument("--seed", type=int, default=None) + + parser.add_argument("--resolution", type=int, default=512) + parser.add_argument("--center_crop", default=False, action="store_true") + + parser.add_argument("--train_text_encoder", action="store_true") + + parser.add_argument("--train_batch_size", type=int, default=4) + parser.add_argument("--sample_batch_size", type=int, default=4) + parser.add_argument("--num_train_epochs", type=int, default=1) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " - "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." - "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." - "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" - "instructions." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more details" - ), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
- ), - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-6, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) + parser.add_argument("--max_train_steps", type=int, default=None) + + parser.add_argument("--checkpointing_steps", type=int, default=500) + + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--gradient_checkpointing", action="store_true") + + parser.add_argument("--learning_rate", type=float, default=5e-6) + parser.add_argument("--scale_lr", action="store_true", default=False) + parser.add_argument( "--lr_scheduler", type=str, - default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument( - "--lr_num_cycles", - type=int, - default=1, - help="Number of hard resets of the lr in cosine_with_restarts scheduler.", - ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") - parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--validation_prompt", - type=str, - default=None, - help="A prompt that is used during validation to verify that the model is learning.", - ) - parser.add_argument( - "--num_validation_images", - type=int, - default=4, - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--validation_steps", - type=int, - default=100, - help=( - "Run validation every X steps. Validation consists of running the prompt" - " `args.validation_prompt` multiple times: `args.num_validation_images`" - " and logging the images." - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--prior_generation_precision", - type=str, - default=None, - choices=["no", "fp32", "fp16", "bf16"], - help=( - "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." - ), - ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." - ) - parser.add_argument( - "--set_grads_to_none", - action="store_true", - help=( - "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" - " behaviors, so disable this argument if it causes any problems. More info:" - " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" - ), + default="cosine_with_restarts", + help='学习率调度策略,可用 "linear"/"cosine"/"cosine_with_restarts"/"constant_with_warmup" 等', ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + parser.add_argument("--lr_num_cycles", type=int, default=1) + parser.add_argument("--lr_power", type=float, default=1.0) - parser.add_argument( - "--offset_noise", - action="store_true", - default=False, - help=( - "Fine-tuning against a modified noise" - " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." - ), - ) - parser.add_argument( - "--snr_gamma", - type=float, - default=None, - help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " - "More details here: https://arxiv.org/abs/2303.09556.", - ) - parser.add_argument( - "--pre_compute_text_embeddings", - action="store_true", - help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. 
This is not compatible with `--train_text_encoder`.", - ) - parser.add_argument( - "--tokenizer_max_length", - type=int, - default=None, - required=False, - help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", - ) - parser.add_argument( - "--text_encoder_use_attention_mask", - action="store_true", - required=False, - help="Whether to use attention mask for the text encoder", - ) - parser.add_argument( - "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" - ) - parser.add_argument( - "--validation_images", - required=False, - default=None, - nargs="+", - help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", - ) - parser.add_argument( - "--class_labels_conditioning", - required=False, - default=None, - help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", - ) - parser.add_argument( - "--validation_scheduler", - type=str, - default="DPMSolverMultistepScheduler", - choices=["DPMSolverMultistepScheduler", "DDPMScheduler"], - help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.", - ) + parser.add_argument("--use_8bit_adam", action="store_true") + parser.add_argument("--dataloader_num_workers", type=int, default=0) - parser.add_argument( - "--validation_image_output_dir", - type=str, - default=None, - help="The directory where validation images will be saved. If None, images will be saved inside a subdirectory of `output_dir`.", - ) + parser.add_argument("--adam_beta1", type=float, default=0.9) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-2) + parser.add_argument("--adam_epsilon", type=float, default=1e-08) + parser.add_argument("--max_grad_norm", default=1.0, type=float) - # [START] 为可视化方案增加的参数 (通用指标) - parser.add_argument( - "--coords_save_path", - type=str, - default=None, - help="The path to save the intermediate coordinates (X, Y, Z metrics) for visualization.", - ) - parser.add_argument( - "--coords_log_interval", - type=int, - default=10, - help="Log and record intermediate coordinates every X steps.", - ) - # [END] 为可视化方案增加的参数 (通用指标) + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--hub_token", type=str, default=None) + parser.add_argument("--hub_model_id", type=str, default=None) + + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument("--allow_tf32", action="store_true") + parser.add_argument("--report_to", type=str, default="tensorboard") + + parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"]) + parser.add_argument("--prior_generation_precision", type=str, default=None, choices=["no", "fp32", "fp16", "bf16"]) + parser.add_argument("--local_rank", type=int, default=-1) + + parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true") + parser.add_argument("--set_grads_to_none", action="store_true") + + parser.add_argument("--offset_noise", action="store_true", default=False) + parser.add_argument("--snr_gamma", type=float, default=None) + + parser.add_argument("--tokenizer_max_length", type=int, default=None, required=False) + parser.add_argument("--text_encoder_use_attention_mask", action="store_true", required=False) + 
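The validation flags registered just below are consumed only once, after training finishes. A sketch of how they are meant to flow into `run_validation_txt2img` (the signature is the one defined in this patch; the call site and file-naming scheme are illustrative):

```python
# Illustrative post-training call; argument names mirror the flags below and
# the run_validation_txt2img signature defined earlier in this patch.
from pathlib import Path

images = run_validation_txt2img(
    finetuned_model_dir=args.output_dir,
    prompt=args.validation_prompt,
    negative_prompt=args.validation_negative_prompt,
    num_images=args.num_validation_images,
    num_inference_steps=args.validation_num_inference_steps,
    guidance_scale=args.validation_guidance_scale,
    seed=args.seed,
    accelerator=accelerator,
    weight_dtype=weight_dtype,
    global_step=global_step,
)
out_dir = Path(args.validation_image_output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for i, img in enumerate(images):
    img.save(out_dir / f"validation_{i:02d}.png")  # naming is a placeholder
```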
parser.add_argument("--skip_save_text_encoder", action="store_true", required=False) + + parser.add_argument("--validation_prompt", type=str, required=True) + parser.add_argument("--validation_negative_prompt", type=str, default="") + parser.add_argument("--num_validation_images", type=int, default=10) + parser.add_argument("--validation_num_inference_steps", type=int, default=100) + parser.add_argument("--validation_guidance_scale", type=float, default=7.5) + parser.add_argument("--validation_image_output_dir", type=str, required=True) + + parser.add_argument("--coords_save_path", type=str, default=None) + parser.add_argument("--coords_log_interval", type=int, default=10) if input_args is not None: args = parser.parse_args(input_args) @@ -616,28 +328,27 @@ def parse_args(input_args=None): if args.with_prior_preservation: if args.class_data_dir is None: - raise ValueError("You must specify a data directory for class images.") + raise ValueError("启用先验保持时必须提供 class_data_dir。") if args.class_prompt is None: - raise ValueError("You must specify prompt for class images.") + raise ValueError("启用先验保持时必须提供 class_prompt。") else: - # logger is not available yet if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + warnings.warn("未启用先验保持时无需提供 class_data_dir。") if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") - - if args.train_text_encoder and args.pre_compute_text_embeddings: - raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + warnings.warn("未启用先验保持时无需提供 class_prompt。") return args +# ------------------------------------------------------------------------- +# 功能模块:DreamBooth 训练数据集 +# 1) 从 instance 与 class 目录读取图像,并统一做尺寸、裁剪与归一化 +# 2) 同时提供实例提示词与类别提示词的 token id 作为文本输入 +# 3) 先验保持模式下会返回两套图像与文本信息用于拼接训练 +# 4) 数据集长度按 instance 与 class 的最大值取,便于循环采样 +# 5) 数据集只负责准备输入,模型推理、损失计算与优化在主循环中完成 +# ------------------------------------------------------------------------- class DreamBoothDataset(Dataset): - """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. 
- """ - def __init__( self, instance_data_root, @@ -648,22 +359,18 @@ class DreamBoothDataset(Dataset): class_num=None, size=512, center_crop=False, - encoder_hidden_states=None, - class_prompt_encoder_hidden_states=None, tokenizer_max_length=None, ): self.size = size self.center_crop = center_crop self.tokenizer = tokenizer - self.encoder_hidden_states = encoder_hidden_states - self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states self.tokenizer_max_length = tokenizer_max_length self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): - raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.") + raise ValueError(f"实例图像目录不存在:{self.instance_data_root}") - self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.instance_images_path = [p for p in Path(instance_data_root).iterdir() if p.is_file()] self.num_instance_images = len(self.instance_images_path) self.instance_prompt = instance_prompt self._length = self.num_instance_images @@ -671,7 +378,7 @@ class DreamBoothDataset(Dataset): if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images_path = list(self.class_data_root.iterdir()) + self.class_images_path = [p for p in self.class_data_root.iterdir() if p.is_file()] if class_num is not None: self.num_class_images = min(len(self.class_images_path), class_num) else: @@ -693,82 +400,77 @@ class DreamBoothDataset(Dataset): def __len__(self): return self._length + def _tokenize(self, prompt: str): + max_length = self.tokenizer_max_length if self.tokenizer_max_length is not None else self.tokenizer.model_max_length + return self.tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + def __getitem__(self, index): example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) instance_image = exif_transpose(instance_image) - - if not instance_image.mode == "RGB": + if instance_image.mode != "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) - if self.encoder_hidden_states is not None: - example["instance_prompt_ids"] = self.encoder_hidden_states - else: - text_inputs = tokenize_prompt( - self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length - ) - example["instance_prompt_ids"] = text_inputs.input_ids - example["instance_attention_mask"] = text_inputs.attention_mask + text_inputs = self._tokenize(self.instance_prompt) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) - - if not class_image.mode == "RGB": + if class_image.mode != "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) - if self.class_prompt_encoder_hidden_states is not None: - example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states - else: - class_text_inputs = tokenize_prompt( - self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length - ) - example["class_prompt_ids"] = class_text_inputs.input_ids - example["class_attention_mask"] = class_text_inputs.attention_mask + class_text_inputs = self._tokenize(self.class_prompt) + 
example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask return example +# ------------------------------------------------------------------------- +# 功能模块:批处理拼接与张量规整 +# 1) 将单条样本组成的列表拼接为 batch 字典,供训练循环直接使用 +# 2) 将图像张量 stack 成 (B,C,H,W) 并转换为 float,提高后续 VAE 兼容性 +# 3) 将 input_ids 与 attention_mask 在 batch 维度 cat,便于文本编码器计算 +# 4) 先验保持模式下将 instance 与 class 在 batch 维度拼接,减少前向次数 +# 5) 该模块不做任何损失与梯度计算,只负责打包输入数据结构 +# ------------------------------------------------------------------------- def collate_fn(examples, with_prior_preservation=False): - has_attention_mask = "instance_attention_mask" in examples[0] - input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] + attention_mask = [example["instance_attention_mask"] for example in examples] - if has_attention_mask: - attention_mask = [example["instance_attention_mask"] for example in examples] - - # Concat class and instance examples for prior preservation. - # We do this to avoid doing two forward passes. if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] + attention_mask += [example["class_attention_mask"] for example in examples] - if has_attention_mask: - attention_mask += [example["class_attention_mask"] for example in examples] - - pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - + pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float() input_ids = torch.cat(input_ids, dim=0) + attention_mask = torch.cat(attention_mask, dim=0) - batch = { - "input_ids": input_ids, - "pixel_values": pixel_values, - } - - if has_attention_mask: - attention_mask = torch.cat(attention_mask, dim=0) - batch["attention_mask"] = attention_mask - - return batch + return {"input_ids": input_ids, "pixel_values": pixel_values, "attention_mask": attention_mask} +# ------------------------------------------------------------------------- +# 功能模块:生成 class 图像的提示词数据集 +# 1) 该数据集用于先验保持时批量生成类别图像,提供固定 prompt +# 2) 每条样本返回 prompt 与索引,索引用于生成稳定的文件名 +# 3) 与训练数据集分离,避免采样逻辑影响训练数据读取与增强 +# 4) 支持多进程环境下由 accelerate 分配采样 batch,提高生成效率 +# 5) 该模块只在 with_prior_preservation 启用且 class 数据不足时使用 +# ------------------------------------------------------------------------- class PromptDataset(Dataset): - """A simple dataset to prepare the prompts to generate class images on multiple GPUs.""" - def __init__(self, prompt, num_samples): self.prompt = prompt self.num_samples = num_samples @@ -777,66 +479,57 @@ class PromptDataset(Dataset): return self.num_samples def __getitem__(self, index): - example = {} - example["prompt"] = self.prompt - example["index"] = index - return example + return {"prompt": self.prompt, "index": index} +# ------------------------------------------------------------------------- +# 功能模块:判断预训练模型是否包含 VAE +# 1) 通过检查 vae/config.json 是否存在来决定是否加载 VAE +# 2) 同时支持本地目录与 Hub 结构,便于离线缓存模式运行 +# 3) 若不存在 VAE 子目录,将跳过加载并在训练中使用像素空间输入 +# 4) 该判断只发生在初始化阶段,不影响训练过程与日志记录 +# 5) 对 Stable Diffusion 类模型通常都会包含 VAE,属于常规路径 +# ------------------------------------------------------------------------- def model_has_vae(args): config_file_name = Path("vae", AutoencoderKL.config_name).as_posix() if os.path.isdir(args.pretrained_model_name_or_path): config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) return 
-    else:
-        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
-        return any(file.rfilename == config_file_name for file in files_in_repo)
-
-def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
-    if tokenizer_max_length is not None:
-        max_length = tokenizer_max_length
-    else:
-        max_length = tokenizer.model_max_length
-
-    text_inputs = tokenizer(
-        prompt,
-        truncation=True,
-        padding="max_length",
-        max_length=max_length,
-        return_tensors="pt",
-    )
+    files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
+    return any(file.rfilename == config_file_name for file in files_in_repo)
 
-    return text_inputs
 
-def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+# -------------------------------------------------------------------------
+# Module: text-encoder forward pass
+# 1) Feeds input_ids and attention_mask through the text encoder to obtain
+#    conditioning embeddings
+# 2) The attention mask is optional, to accommodate differing text-encoder
+#    behaviors
+# 3) The resulting prompt_embeds condition the UNet and drive semantics and
+#    identity binding
+# 4) Called on every training step, so device and dtype consistency matters
+# 5) Returns a (B, T, D) tensor that enters the UNet together with the
+#    timestep
+# -------------------------------------------------------------------------
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask: bool):
     text_input_ids = input_ids.to(text_encoder.device)
-
     if text_encoder_use_attention_mask:
         attention_mask = attention_mask.to(text_encoder.device)
     else:
         attention_mask = None
-
-    prompt_embeds = text_encoder(
-        text_input_ids,
-        attention_mask=attention_mask,
-        return_dict=False,
-    )
-    prompt_embeds = prompt_embeds[0]
-
-    return prompt_embeds
+    return text_encoder(text_input_ids, attention_mask=attention_mask, return_dict=False)[0]
 
 
+# -------------------------------------------------------------------------
+# Module: main training flow
+# 1) Builds the accelerate environment, loads model components, and prepares
+#    data and the optimizer
+# 2) Supports prior preservation: tops up class images automatically and
+#    trains on merged instance/class batches
+# 3) Logs loss, learning rate, and coordinate metrics during training and
+#    writes a CSV for visualization
+# 4) After training saves the fine-tuned pipeline to output_dir as a
+#    standalone, ready-to-use model
+# 5) Runs validation after saving, generating images from the prompt alone
+#    and writing the results to the output directory
+# -------------------------------------------------------------------------
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
-        raise ValueError(
-            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
-            " Please use `huggingface-cli login` to authenticate with the Hub."
-        )
+        raise ValueError("Do not use --report_to=wandb together with --hub_token; it risks exposing the token.")
 
     logging_dir = Path(args.output_dir, args.logging_dir)
-
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
 
     accelerator = Accelerator(
@@ -846,46 +539,28 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
-    # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
 
-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-
-    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
-    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
-    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
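The upstream guard deleted in the surrounding hunk is not re-added anywhere in this patch. If the accelerate limitation it worked around still applies, retaining the check is cheap; the condition below is copied from the deleted lines:

```python
# Possible re-instatement of the removed guard (condition copied verbatim):
# accelerate's accumulate() in this script wraps only the UNet, so
# accumulating gradients while also training the text encoder across
# processes was unsupported upstream.
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
    raise ValueError(
        "Gradient accumulation is not supported when training the text encoder "
        "in distributed training. Please set gradient_accumulation_steps to 1."
    )
```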
- if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: - raise ValueError( - "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." - ) - - # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() + warnings.filterwarnings("ignore", category=UserWarning) else: transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) - # Generate class images if prior preservation is enabled. if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) - if not class_images_dir.exists(): - class_images_dir.mkdir(parents=True) + class_images_dir.mkdir(parents=True, exist_ok=True) cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: @@ -896,52 +571,49 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 - pipeline = DiffusionPipeline.from_pretrained( + + pipe = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None, revision=args.revision, variant=args.variant, ) - pipeline.set_progress_bar_config(disable=True) + pipe.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images - logger.info(f"Number of class images to sample: {num_new_images}.") + logger.info(f"需要补足 class 图像数量:{num_new_images}") sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - sample_dataloader = accelerator.prepare(sample_dataloader) - pipeline.to(accelerator.device) - - for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process - ): - images = pipeline(example["prompt"]).images + pipe.to(accelerator.device) + for example in tqdm(sample_dataloader, desc="生成 class 图像", disable=not accelerator.is_local_main_process): + images = pipe(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) - del pipeline + del pipe if torch.cuda.is_available(): torch.cuda.empty_cache() - # Handle the repository creation if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - + os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, ).repo_id + else: + repo_id = None - # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) - elif 
args.pretrained_model_name_or_path: + else: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", @@ -949,21 +621,18 @@ def main(args): use_fast=False, ) - # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) + vae = None if model_has_vae(args): vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) - else: - vae = None unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant @@ -974,30 +643,22 @@ def main(args): model = model._orig_mod if is_compiled_module(model) else model return model - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for model in models: sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder" model.save_pretrained(os.path.join(output_dir, sub_dir)) - - # make sure to pop weight so that corresponding model is not saved again weights.pop() def load_model_hook(models, input_dir): while len(models) > 0: - # pop models so that they are not loaded again model = models.pop() - if isinstance(model, type(unwrap_model(text_encoder))): - # load transformers style into model load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") model.config = load_model.config else: - # load diffusers style into model load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") model.register_to_config(**load_model.config) - model.load_state_dict(load_model.state_dict()) del load_model @@ -1011,61 +672,33 @@ def main(args): text_encoder.requires_grad_(False) if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers + if not is_xformers_available(): + raise ValueError("xformers 不可用,请确认安装成功。") + import xformers - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warning( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. Make sure it is installed correctly") + if version.parse(xformers.__version__) == version.parse("0.0.16"): + logger.warning("xformers 0.0.16 在部分 GPU 上训练不稳定,建议升级。") + unet.enable_xformers_memory_efficient_attention() if args.gradient_checkpointing: unet.enable_gradient_checkpointing() if args.train_text_encoder: text_encoder.gradient_checkpointing_enable() - # Check that all trainable models are in full precision - low_precision_error_string = ( - "Please make sure to always have all model weights in full float32 precision when starting training - even if" - " doing mixed precision training. copy of the weights should still be float32." 
- ) - - if unwrap_model(unet).dtype != torch.float32: - raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}") - - if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: - raise ValueError( - f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" - ) - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) + args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + optimizer_class = torch.optim.AdamW if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." - ) - + raise ImportError("使用 8-bit Adam 需要安装 bitsandbytes:pip install bitsandbytes") optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - # Optimizer creation params_to_optimize = ( itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() ) @@ -1077,45 +710,6 @@ def main(args): eps=args.adam_epsilon, ) - if args.pre_compute_text_embeddings: - - def compute_text_embeddings(prompt): - with torch.no_grad(): - text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) - prompt_embeds = encode_prompt( - text_encoder, - text_inputs.input_ids, - text_inputs.attention_mask, - text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, - ) - - return prompt_embeds - - pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) - validation_prompt_negative_prompt_embeds = compute_text_embeddings("") - - if args.validation_prompt is not None: - validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) - else: - validation_prompt_encoder_hidden_states = None - - if args.class_prompt is not None: - pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) - else: - pre_computed_class_prompt_encoder_hidden_states = None - - text_encoder = None - tokenizer = None - - gc.collect() - torch.cuda.empty_cache() - else: - pre_computed_encoder_hidden_states = None - validation_prompt_encoder_hidden_states = None - validation_prompt_negative_prompt_embeds = None - pre_computed_class_prompt_encoder_hidden_states = None - - # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, @@ -1125,8 +719,6 @@ def main(args): tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, - encoder_hidden_states=pre_computed_encoder_hidden_states, - class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states, tokenizer_max_length=args.tokenizer_max_length, ) @@ -1138,7 +730,6 @@ def main(args): num_workers=args.dataloader_num_workers, ) - # Scheduler and math around the number of training steps. 
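The bookkeeping that follows converts between epochs and optimizer steps. A worked example with illustrative numbers (none of these values come from the patch):

```python
import math

# Illustrative values only.
num_batches = 10   # len(train_dataloader)
grad_accum = 2     # args.gradient_accumulation_steps
num_update_steps_per_epoch = math.ceil(num_batches / grad_accum)   # 5

# If --max_train_steps is given, the epoch count is derived from it:
max_train_steps = 800
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)  # 160
```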
overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
@@ -1154,7 +745,6 @@
         power=args.lr_power,
     )
 
-    # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -1164,234 +754,136 @@ def main(args):
             unet, optimizer, train_dataloader, lr_scheduler
         )
 
-    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
-    # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    # Move vae and text_encoder to device and cast to weight_dtype
     if vae is not None:
         vae.to(accelerator.device, dtype=weight_dtype)
-    if not args.train_text_encoder and text_encoder is not None:
+    if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
 
-    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-    # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
         tracker_config = vars(copy.deepcopy(args))
-        tracker_config.pop("validation_images")
         accelerator.init_trackers("dreambooth", config=tracker_config)
 
-    # [START] 为可视化方案增加的初始化 (通用指标)
     coords_list = []
-    # [END] 为可视化方案增加的初始化 (通用指标)
 
-    # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches per epoch = {len(train_dataloader)}")
+    logger.info(f"  Num epochs = {args.num_train_epochs}")
+    logger.info(f"  Per-device batch size = {args.train_batch_size}")
+    logger.info(f"  Total train batch size = {total_batch_size}")
+    logger.info(f"  Gradient accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
 
-    logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
-    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
-    logger.info(f"  Num Epochs = {args.num_train_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w.
parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - # Check if output_dir contains saved state files (simplified check for state files) - # We look for the presence of the required files saved by accelerator.save_state in the output directory - required_files = ["pytorch_model.bin", "optimizer.bin"] # Simplified check - - has_saved_state = all(os.path.exists(os.path.join(args.output_dir, f)) for f in required_files) - - if args.resume_from_checkpoint == "latest" and not has_saved_state: - accelerator.print( - f"Checkpoint does not exist in '{args.output_dir}'. Starting a new training run." - ) - args.resume_from_checkpoint = None - initial_global_step = 0 - else: - accelerator.print(f"Resuming from checkpoint at {args.output_dir}") - # Load state directly from args.output_dir - accelerator.load_state(args.output_dir) - - # Since we are loading from the main directory, we trust accelerator.load_state - # to restore the global_step correctly from the state saved in that directory. - # We initialize global_step/initial_global_step/first_epoch using the restored state after load_state. - # For simplicity, we keep the original logic's initialization structure but adjust the path/logic. - # Accelerator will internally restore the true global_step. We set temporary values. - # Note: A cleaner solution often involves saving/loading a separate 'step.json' file for global_step tracking - # when relying on in-place saving without automatic tracking of step in folder names. - # For this simple replacement, we let the accelerator handle it. - global_step = 0 - - initial_global_step = global_step - first_epoch = global_step // num_update_steps_per_epoch - else: - initial_global_step = 0 - progress_bar = tqdm( range(0, args.max_train_steps), - initial=initial_global_step, + initial=0, desc="Steps", - # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process, ) - for epoch in range(first_epoch, args.num_train_epochs): + for epoch in range(0, args.num_train_epochs): unet.train() if args.train_text_encoder: text_encoder.train() + for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=weight_dtype) if vae is not None: - # Convert images to latent space model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() model_input = model_input * vae.config.scaling_factor else: model_input = pixel_values - # Sample noise that we'll add to the model input if args.offset_noise: noise = torch.randn_like(model_input) + 0.1 * torch.randn( model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device ) else: noise = torch.randn_like(model_input) - bsz, channels, height, width = model_input.shape - # Sample a random timestep for each image + + bsz = model_input.shape[0] timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device - ) - timesteps = timesteps.long() + ).long() - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) - # Get the text embedding for conditioning - if args.pre_compute_text_embeddings: - encoder_hidden_states = batch["input_ids"] - else: - encoder_hidden_states = encode_prompt( - text_encoder, - batch["input_ids"], - batch["attention_mask"], - text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, - ) - - if unwrap_model(unet).config.in_channels == channels * 2: - noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) - - if args.class_labels_conditioning == "timesteps": - class_labels = timesteps - else: - class_labels = None + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) - # Predict the noise residual - model_pred = unet( - noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False - )[0] + model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states, return_dict=False)[0] if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1) - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + raise ValueError(f"未知 prediction_type:{noise_scheduler.config.prediction_type}") if args.with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) - # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - # Compute instance loss if args.snr_gamma is None: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") else: - # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. - # Since we predict the noise instead of x_0, the original formulation is slightly changed. 
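The Min-SNR branch kept below clamps the per-timestep loss weight at `snr_gamma` (https://arxiv.org/abs/2303.09556). A toy computation mirroring the tensor math, with illustrative SNR values:

```python
import torch

# Toy per-sample SNRs; snr_gamma=5.0 matches the paper's recommended value.
snr = torch.tensor([0.5, 5.0, 50.0])
snr_gamma = 5.0
base_weight = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
# base_weight -> [1.0, 1.0, 0.1]: noisy (low-SNR) timesteps keep full weight,
# easy (high-SNR) timesteps are down-weighted.
# epsilon-prediction uses base_weight directly; v-prediction adds 1 so the
# weight never drops below one.
```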
-                    # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
                     base_weight = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
-
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
+                    mse_loss_weights = base_weight + 1 if noise_scheduler.config.prediction_type == "v_prediction" else base_weight

                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()

                 if args.with_prior_preservation:
-                    # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss

-                # [START] 为可视化方案增加的 X轴 (特征范数) 和 Y轴 (特征方差) 计算 (通用指标)
                 if args.coords_save_path is not None:
-                    # 修正 X轴 计算:将 torch.linalg.norm 替换为传统的 torch.norm
-                    # 传统的 torch.norm 支持对多个维度求范数 (dim=[1, 2, 3])
-                    # X轴: UNet 预测特征 L2 范数 (衡量预测的“强度”)
-                    # torch.norm(..., p=2, dim=...) 表示 L2 范数
-                    X_i_feature_norm = torch.norm(
-                        model_pred.detach().float(),
-                        p=2,
-                        dim=[1, 2, 3]  # 对 C, H, W 维度求 L2 范数
-                    ).mean().item()  # 对 Batch 维度求平均
-                    # Y轴: UNet 预测特征方差 (衡量预测的“混乱度/稳定性”)
-                    # var() 默认对所有维度求方差
-                    Y_i_feature_var = torch.var(
-                        model_pred.detach().float()
-                    ).item()
-                    # Z轴: LDM 损失 (衡量预测的“准确度”)
+                    X_i_feature_norm = torch.norm(model_pred.detach().float(), p=2, dim=[1, 2, 3]).mean().item()
+                    Y_i_feature_var = torch.var(model_pred.detach().float()).item()
                     Z_i = loss.detach().item()

-                    # 记录坐标 (仅在主进程进行)
                     if accelerator.is_main_process and global_step % args.coords_log_interval == 0:
                         coords_list.append([global_step, X_i_feature_norm, Y_i_feature_var, Z_i])
                         if global_step % (args.coords_log_interval * 10) == 0:
                             df = pd.DataFrame(
                                 coords_list,
-                                columns=['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss']
+                                columns=["step", "X_Feature_L2_Norm", "Y_Feature_Variance", "Z_LDM_Loss"],
                             )
                             save_file_path = Path(args.coords_save_path)
                             if not save_file_path.suffix:
                                 save_file_path = save_file_path / "coords.csv"
                             save_file_path.parent.mkdir(parents=True, exist_ok=True)
                             df.to_csv(save_file_path, index=False)
-                            logger.info(
-                                f"Step {global_step}: 已记录可视化坐标,周期保存批次坐标到 {save_file_path}"
-                            )
-                # [END] 为可视化方案增加的 X轴 (特征范数) 和 Y轴 (特征方差) 计算 (通用指标)
+                            logger.info(f"Coordinates written to: {save_file_path}")

                 accelerator.backward(loss)
+
                 if accelerator.sync_gradients:
                     params_to_clip = (
                         itertools.chain(unet.parameters(), text_encoder.parameters())
@@ -1399,47 +891,18 @@ def main(args):
                         else unet.parameters()
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                     optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=args.set_grads_to_none)

-            # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
-                    if global_step % args.checkpointing_steps == 0:
-                        # Save state directly to output_dir, replacing previous checkpoint
-                        # checkpoints_total_limit logic is skipped as we are only keeping one checkpoint (the latest)
-                        save_path = args.output_dir
-                        accelerator.save_state(save_path)
-                        logger.info(f"Saved state directly to {save_path}, replacing previous checkpoint at step {global_step}")
-
-                    images = []
-
-                    if args.validation_prompt is not None and (global_step + 1) % args.validation_steps == 0:
-                        images = log_validation(
-                            unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
-                            tokenizer,
-                            unwrap_model(unet),
-                            vae,
-                            args,
-                            accelerator,
-                            weight_dtype,
-                            global_step,
-                            validation_prompt_encoder_hidden_states,
-                            validation_prompt_negative_prompt_embeds,
-                        )
-
-                        # Save validation images directly to output_dir
-                        save_path = Path(args.validation_image_output_dir) if args.validation_image_output_dir else Path(args.output_dir)
-                        save_path.mkdir(parents=True, exist_ok=True)
-                        logger.info(f"Saving validation images directly to {save_path}, overwriting previous images.")
-
-                        for i, image in enumerate(images):
-                            # The file name is constant, thus overwriting
-                            image.save(save_path / f"validation_image_{i}.png")
+                if accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
+                    accelerator.save_state(args.output_dir)
+                    logger.info(f"Saved training state to: {args.output_dir}")

             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
@@ -1448,15 +911,14 @@ def main(args):
             if global_step >= args.max_train_steps:
                 break

-    # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
+
+    images = []
     if accelerator.is_main_process:
         pipeline_args = {}
-
-        if text_encoder is not None:
+        if not args.skip_save_text_encoder:
             pipeline_args["text_encoder"] = unwrap_model(text_encoder)
-
-        if args.skip_save_text_encoder:
+        else:
             pipeline_args["text_encoder"] = None

         pipeline = DiffusionPipeline.from_pretrained(
@@ -1466,21 +928,36 @@ def main(args):
             variant=args.variant,
             **pipeline_args,
         )
+        pipeline.save_pretrained(args.output_dir)

-        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-        scheduler_args = {}
-
-        if "variance_type" in pipeline.scheduler.config:
-            variance_type = pipeline.scheduler.config.variance_type
-
-            if variance_type in ["learned", "learned_range"]:
-                variance_type = "fixed_small"
-
-            scheduler_args["variance_type"] = variance_type
+        del unet
+        del optimizer
+        del lr_scheduler
+        if vae is not None:
+            del vae
+        if not args.train_text_encoder:
+            del text_encoder
+        gc.collect()
+        torch.cuda.empty_cache()

-        pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+        images = run_validation_txt2img(
+            finetuned_model_dir=args.output_dir,
+            prompt=args.validation_prompt,
+            negative_prompt=args.validation_negative_prompt,
+            num_images=args.num_validation_images,
+            num_inference_steps=args.validation_num_inference_steps,
+            guidance_scale=args.validation_guidance_scale,
+            seed=args.seed,
+            accelerator=accelerator,
+            weight_dtype=weight_dtype,
+            global_step=global_step,
+        )

-        pipeline.save_pretrained(args.output_dir)
+        out_dir = Path(args.validation_image_output_dir)
+        out_dir.mkdir(parents=True, exist_ok=True)
+        for i, image in enumerate(images):
+            image.save(out_dir / f"validation_image_{i}.png")
+        logger.info(f"Validation images saved to: {out_dir}")

         if args.push_to_hub:
             save_model_card(
@@ -1498,25 +975,19 @@ def main(args):
                 commit_message="End of training",
                 ignore_patterns=["step_*", "epoch_*"],
             )
-
-    # [START] 为可视化方案增加的最终保存 (通用指标)
+
     if args.coords_save_path is not None and coords_list:
-        df = pd.DataFrame(
-            coords_list,
-            columns=['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss']
-        )
-        # 假设 args.coords_save_path 是目标文件路径
+        df = pd.DataFrame(coords_list, columns=["step", "X_Feature_L2_Norm", "Y_Feature_Variance", "Z_LDM_Loss"])
         save_file_path = Path(args.coords_save_path)
         if not save_file_path.suffix:
             save_file_path = save_file_path / "coords.csv"
         save_file_path.parent.mkdir(parents=True, exist_ok=True)
         df.to_csv(save_file_path, index=False)
-        logger.info(f"训练结束:已将所有 {len(coords_list)} 坐标保存到 {save_file_path}")
-        # [END] 为可视化方案增加的最终保存 (通用指标)
+        logger.info(f"Training finished: coordinates saved to {save_file_path} ({len(coords_list)} records)")

     accelerator.end_training()


 if __name__ == "__main__":
     args = parse_args()
-    main(args)
+    main(args)
\ No newline at end of file
diff --git a/src/backend/app/scripts/finetune_db.sh b/src/backend/app/scripts/finetune_db.sh
index ef0bcd3..051dbe8 100644
--- a/src/backend/app/scripts/finetune_db.sh
+++ b/src/backend/app/scripts/finetune_db.sh
@@ -1,34 +1,23 @@
-#需要环境:conda activate pid
-### Trianing model
+# Required environment: conda activate pid
 export HF_HUB_OFFLINE=1
-# 强制使用本地模型缓存,避免联网下载模型
-### SD v2.1
-# export HF_HOME="/root/autodl-tmp/huggingface_cache"
-# export MODEL_PATH="stabilityai/stable-diffusion-2-1"
-
-### SD v1.5
-# export HF_HOME="/root/autodl-tmp/huggingface_cache"
-# export MODEL_PATH="runwayml/stable-diffusion-v1-5"
+# SD v1.5 local path
 export MODEL_PATH="../../static/hf_models/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"
-
 export TASKNAME="task001"
-export TYPE="clean" #clean or perturbed
-
+export TYPE="perturbed" # clean or perturbed
 if [ "$TYPE" == "clean" ]; then
     export INSTANCE_DIR="../../static/originals/${TASKNAME}"
 else
     export INSTANCE_DIR="../../static/perturbed/${TASKNAME}"
 fi
-export DREAMBOOTH_OUTPUT_DIR="../../static/hf_models/fine_tuned/${TYPE}/${TASKNAME}"
-export OUTPUT_INFER_DIR="../../static/model_outputs/${TYPE}/${TASKNAME}"
-export CLASS_DIR="../../static/class/${TASKNAME}"
-export COORD_DIR="../../static/eva_res/position/${TASKNAME}"
+export DREAMBOOTH_OUTPUT_DIR="../../static/hf_models/fine_tuned/${TYPE}/${TASKNAME}"
+export OUTPUT_INFER_DIR="../../static/model_outputs/${TYPE}/${TASKNAME}"
+export CLASS_DIR="../../static/class/${TASKNAME}"
+export COORD_DIR="../../static/eva_res/position/${TASKNAME}"

-# ------------------------- 自动创建依赖路径 -------------------------
 echo "Creating required directories..."
 mkdir -p "$INSTANCE_DIR"
 mkdir -p "$DREAMBOOTH_OUTPUT_DIR"
@@ -36,52 +25,43 @@ mkdir -p "$OUTPUT_INFER_DIR"
 mkdir -p "$CLASS_DIR"
 mkdir -p "$COORD_DIR"

-
-# ------------------------- 自动清除旧文件 -------------------------
 echo "Clearing output directory: $DREAMBOOTH_OUTPUT_DIR and $OUTPUT_INFER_DIR and $COORD_DIR"
-# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..)
find "$DREAMBOOTH_OUTPUT_DIR" -mindepth 1 -delete find "$OUTPUT_INFER_DIR" -mindepth 1 -delete find "$COORD_DIR" -mindepth 1 -delete - -# ------------------------- Fine-tune DreamBooth on images ------------------------- CUDA_VISIBLE_DEVICES=0 accelerate launch ../finetune_infras/train_db_gen_trace.py \ - --pretrained_model_name_or_path=$MODEL_PATH \ + --pretrained_model_name_or_path=$MODEL_PATH \ --instance_data_dir=$INSTANCE_DIR \ --class_data_dir=$CLASS_DIR \ --output_dir=$DREAMBOOTH_OUTPUT_DIR \ - --validation_image_output_dir=$OUTPUT_INFER_DIR \ --with_prior_preservation \ - --prior_loss_weight=1.0 \ - --instance_prompt="a photo of sks person" \ - --class_prompt="a photo of person" \ + --train_text_encoder \ + --prior_loss_weight=0.4 \ + --instance_prompt="a selfie photo of person" \ + --class_prompt="a selfie photo of person" \ --resolution=512 \ --train_batch_size=1 \ --gradient_accumulation_steps=1 \ - --learning_rate=2e-6 \ - --lr_scheduler="constant" \ - --lr_warmup_steps=0 \ - --num_class_images=200 \ - --max_train_steps=1000 \ - --checkpointing_steps=500 \ - --center_crop \ + --learning_rate=5e-7 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=50 \ + --num_class_images=100 \ + --max_train_steps=800 \ + --checkpointing_steps=400 \ --mixed_precision=bf16 \ --prior_generation_precision=bf16 \ - --sample_batch_size=5 \ - --validation_prompt="a photo of sks person" \ - --num_validation_images 10 \ - --validation_steps 500 \ + --sample_batch_size=5 \ + --validation_prompt="a selfie photo of person, head-and-shoulders, face looking at the camera, Eiffel Tower clearly visible behind, outdoor daytime, realistic" \ + --num_validation_images=5 \ + --validation_num_inference_steps=120 \ + --validation_guidance_scale=7.0 \ + --validation_image_output_dir=$OUTPUT_INFER_DIR \ --coords_save_path=$COORD_DIR \ --coords_log_interval=10 - -# ------------------------- 训练后清空 CLASS_DIR ------------------------- -# 注意:这会在 accelerate launch 成功结束后执行 echo "Clearing class directory: $CLASS_DIR" -# 确保目录存在,避免清理命令失败 mkdir -p "$CLASS_DIR" -# 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..) find "$CLASS_DIR" -mindepth 1 -delete echo "Script finished." 
\ No newline at end of file
-- 
2.34.1


From 5b1a298ae586dc45455d40dac8c3a512c93ad26b Mon Sep 17 00:00:00 2001
From: Ryan <3266408525@qq.com>
Date: Sat, 13 Dec 2025 22:00:10 +0800
Subject: [PATCH 2/3] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E9=98=B2?=
 =?UTF-8?q?=E5=AE=9A=E5=88=B6=E7=94=9F=E6=88=90=E5=8F=82=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../app/scripts/attack_anti_customize_gen.sh | 64 +++++++++++++++++++
 1 file changed, 64 insertions(+)
 create mode 100644 src/backend/app/scripts/attack_anti_customize_gen.sh

diff --git a/src/backend/app/scripts/attack_anti_customize_gen.sh b/src/backend/app/scripts/attack_anti_customize_gen.sh
new file mode 100644
index 0000000..5e98ede
--- /dev/null
+++ b/src/backend/app/scripts/attack_anti_customize_gen.sh
@@ -0,0 +1,64 @@
+# Required environment: conda activate simac
+export HF_HUB_OFFLINE=1
+export MODEL_PATH="../../static/hf_models/diffusers/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06"
+export TASKNAME="task001"
+
+# ------------------------- Train ASPL on set CLEAN_ADV_DIR -------------------------
+export CLEAN_TRAIN_DIR="../../static/originals/${TASKNAME}"
+export CLEAN_ADV_DIR="../../static/originals/${TASKNAME}"
+export OUTPUT_DIR="../../static/perturbed/${TASKNAME}"
+export CLASS_DIR="../../static/class/${TASKNAME}"
+
+# ------------------------- Create required directories -------------------------
+echo "Creating required directories..."
+mkdir -p "$CLEAN_TRAIN_DIR"
+mkdir -p "$CLEAN_ADV_DIR"
+mkdir -p "$OUTPUT_DIR"
+mkdir -p "$CLASS_DIR"
+echo "Directories created successfully."
+
+
+# ------------------------- Clear OUTPUT_DIR before training -------------------------
+echo "Clearing output directory: $OUTPUT_DIR"
+# Make sure the directory exists so the cleanup command cannot fail
+# Note: it was already created above; kept here for clarity, safe to remove
+mkdir -p "$OUTPUT_DIR"
+# Find and delete all files and subdirectories inside (but not . or ..)
+find "$OUTPUT_DIR" -mindepth 1 -delete
+find "$CLASS_DIR" -mindepth 1 -delete
+
+
+
+accelerate launch ../algorithms/simac.py \
+  --pretrained_model_name_or_path=$MODEL_PATH \
+  --enable_xformers_memory_efficient_attention \
+  --instance_data_dir_for_train=$CLEAN_TRAIN_DIR \
+  --instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \
+  --instance_prompt="a photo of person" \
+  --class_data_dir=$CLASS_DIR \
+  --num_class_images=100 \
+  --class_prompt="a photo of person" \
+  --output_dir=$OUTPUT_DIR \
+  --center_crop \
+  --with_prior_preservation \
+  --prior_loss_weight=1.0 \
+  --resolution=384 \
+  --train_batch_size=1 \
+  --max_train_steps=100 \
+  --max_f_train_steps=3 \
+  --max_adv_train_steps=6 \
+  --checkpointing_iterations=20 \
+  --learning_rate=5e-7 \
+  --pgd_alpha=0.005 \
+  --pgd_eps=10 \
+  --seed=0
+
+# ------------------------- Clear CLASS_DIR after training -------------------------
+# Note: this runs after accelerate launch finishes successfully
+echo "Clearing class directory: $CLASS_DIR"
+# Make sure the directory exists so the cleanup command cannot fail
+mkdir -p "$CLASS_DIR"
+# Find and delete all files and subdirectories inside (but not . or ..)
+find "$CLASS_DIR" -mindepth 1 -delete
+
+echo "Script finished."
\ No newline at end of file
-- 
2.34.1


From 138f684491d85d66be4e99ca1119045781f852ab Mon Sep 17 00:00:00 2001
From: Ryan <3266408525@qq.com>
Date: Sat, 13 Dec 2025 22:00:41 +0800
Subject: [PATCH 3/3] =?UTF-8?q?improve:=20=E6=94=B9=E8=BF=9BSimAC=E5=8A=A0?=
 =?UTF-8?q?=E5=99=AA=E8=B6=85=E5=8F=82=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/backend/app/scripts/attack_simac.sh | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/backend/app/scripts/attack_simac.sh b/src/backend/app/scripts/attack_simac.sh
index f044dd1..660d6a1 100644
--- a/src/backend/app/scripts/attack_simac.sh
+++ b/src/backend/app/scripts/attack_simac.sh
@@ -20,8 +20,12 @@ echo "Directories created successfully."

 # ------------------------- 训练前清空 OUTPUT_DIR -------------------------
 echo "Clearing output directory: $OUTPUT_DIR"
+# Make sure the directory exists so the cleanup command cannot fail
+# Note: it was already created above; kept here for clarity, safe to remove
+mkdir -p "$OUTPUT_DIR"
 # 查找并删除目录下的所有文件和子目录(但不删除 . 或 ..)
 find "$OUTPUT_DIR" -mindepth 1 -delete
+find "$CLASS_DIR" -mindepth 1 -delete



@@ -30,9 +34,9 @@ accelerate launch ../algorithms/simac.py \
   --enable_xformers_memory_efficient_attention \
   --instance_data_dir_for_train=$CLEAN_TRAIN_DIR \
   --instance_data_dir_for_adversarial=$CLEAN_ADV_DIR \
-  --instance_prompt="a photo of sks person" \
+  --instance_prompt="a photo of person" \
   --class_data_dir=$CLASS_DIR \
-  --num_class_images=200 \
+  --num_class_images=100 \
   --class_prompt="a photo of person" \
   --output_dir=$OUTPUT_DIR \
   --center_crop \
@@ -40,13 +44,13 @@ accelerate launch ../algorithms/simac.py \
   --prior_loss_weight=1.0 \
   --resolution=384 \
   --train_batch_size=1 \
-  --max_train_steps=50 \
+  --max_train_steps=60 \
   --max_f_train_steps=3 \
   --max_adv_train_steps=6 \
   --checkpointing_iterations=10 \
-  --learning_rate=5e-7 \
+  --learning_rate=2e-6 \
   --pgd_alpha=0.005 \
-  --pgd_eps=8 \
+  --pgd_eps=10 \
   --seed=0

 # ------------------------- 训练后清空 CLASS_DIR -------------------------
-- 
2.34.1
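
For readers of patch 1: the prior-preservation objective kept in the training loop splits each batch into an instance half and a class half, then mixes the two MSE terms. A minimal standalone sketch with toy tensors standing in for model_pred and target; the 0.4 weight mirrors the --prior_loss_weight value set in finetune_db.sh:

import torch
import torch.nn.functional as F

prior_loss_weight = 0.4  # value passed as --prior_loss_weight in finetune_db.sh
# Toy stand-ins: first half of the batch = instance samples, second half = class samples.
model_pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)

# Split predictions/targets into the instance part and the prior (class) part.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Final objective: instance loss plus weighted prior-preservation loss.
loss = instance_loss + prior_loss_weight * prior_loss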
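The Min-SNR loss weighting that patch 1 keeps in the training loop (Section 3.4 of arXiv:2303.09556) can also be reproduced in isolation. A minimal sketch, assuming a default-configured DDPMScheduler and a hypothetical snr_gamma of 5.0 (the script reads it from --snr_gamma):

import torch
from diffusers import DDPMScheduler
from diffusers.training_utils import compute_snr

scheduler = DDPMScheduler(num_train_timesteps=1000)  # epsilon prediction by default
snr_gamma = 5.0  # hypothetical stand-in for args.snr_gamma
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (4,))

snr = compute_snr(scheduler, timesteps)
# Clamp the SNR at snr_gamma, then divide by the true SNR, exactly as in the hunk above.
base_weight = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
# v-prediction floors the effective weight at one; epsilon uses base_weight as-is.
weights = base_weight + 1 if scheduler.config.prediction_type == "v_prediction" else base_weight
# The per-sample MSE is then reduced over all but the batch dimension and multiplied by weights.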
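The coordinate trace written via --coords_save_path is a CSV with columns step, X_Feature_L2_Norm, Y_Feature_Variance, and Z_LDM_Loss. A minimal sketch of how the logged trajectory might be inspected afterwards, assuming matplotlib is installed (not part of the patches; the path below is the hypothetical file that $COORD_DIR resolves to):

import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical path: $COORD_DIR/coords.csv, as written by train_db_gen_trace.py.
df = pd.read_csv("../../static/eva_res/position/task001/coords.csv")

fig = plt.figure()
ax = fig.add_subplot(projection="3d")  # 3D trajectory: norm vs. variance vs. loss
ax.plot(df["X_Feature_L2_Norm"], df["Y_Feature_Variance"], df["Z_LDM_Loss"])
ax.set_xlabel("X: feature L2 norm")
ax.set_ylabel("Y: feature variance")
ax.set_zlabel("Z: LDM loss")
fig.savefig("coords_trajectory.png")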