@@ -1,26 +1,26 @@
import argparse
import copy
import gc
import hashlib
import itertools
import logging
import os
import random
from pathlib import Path

import datasets
import diffusers
-import random
-from torch.backends import cudnn
-import transformers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torch.backends import cudnn
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
@@ -30,8 +30,19 @@ from transformers import AutoTokenizer, PretrainedConfig
logger = get_logger(__name__)


+def _cuda_gc() -> None:
+    """Try to release unreferenced CUDA memory and reduce fragmentation.
+
+    This is a best-effort helper. It does not change algorithmic behavior, but it can
+    make long runs less prone to OOM caused by fragmentation / reserved-memory growth.
+    """
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
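
As an aside (a sketch, not part of the diff): fragmentation can also be attacked at process start via PyTorch's caching-allocator configuration. This assumes the variable is set before CUDA is first initialized; the 512 MB value is purely illustrative.

# Hypothetical companion to _cuda_gc(): cap the size of split blocks in the
# caching allocator to reduce fragmentation on long runs. Must take effect
# before the first CUDA allocation; the value here is illustrative.
import os

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")
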
class DreamBoothDatasetFromTensor(Dataset):
-    """Just like DreamBoothDataset, but take instance_images_tensor instead of path"""
+    """Just like DreamBoothDataset, but takes instance_images_tensor instead of a path."""

    def __init__(
        self,
@@ -88,7 +99,7 @@ class DreamBoothDatasetFromTensor(Dataset):

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
-            if not class_image.mode == "RGB":
+            if class_image.mode != "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
@@ -114,12 +125,11 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
        from transformers import CLIPTextModel

        return CLIPTextModel
-    elif model_class == "RobertaSeriesModelWithTransformation":
+    if model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
-    else:
-        raise ValueError(f"{model_class} is not supported.")
+    raise ValueError(f"{model_class} is not supported.")
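
For context (a sketch, not part of the diff): the first half of this dispatch usually reads the architecture name out of the text-encoder config, along the lines of the diffusers example scripts. The helper name below is illustrative.

# Illustrative: how model_class is typically obtained before the branches above.
from transformers import PretrainedConfig

def _text_encoder_class_name(pretrained_model_name_or_path, revision=None):
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    return config.architectures[0]  # e.g. "CLIPTextModel"
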
def parse_args(input_args=None):
@@ -337,17 +347,13 @@ def parse_args(input_args=None):
        "--max_steps",
        type=int,
        default=50,
-        help=(
-            "Maximum steps for adaptive greedy timestep selection."
-        ),
+        help=("Maximum steps for adaptive greedy timestep selection."),
    )
    parser.add_argument(
        "--delta_t",
        type=int,
        default=20,
-        help=(
-            "delete 2*delta_t for each adaptive greedy timestep selection."
-        ),
+        help=("Delete a window of 2*delta_t timesteps in each adaptive greedy timestep selection step."),
    )
    if input_args is not None:
        args = parser.parse_args(input_args)
@@ -358,7 +364,7 @@ def parse_args(input_args=None):

class PromptDataset(Dataset):
-    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""

    def __init__(self, prompt, num_samples):
        self.prompt = prompt

@@ -389,7 +395,6 @@ def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
    return images
def train_one_epoch(
    args,
    models,
@@ -399,8 +404,6 @@ def train_one_epoch(
    data_tensor: torch.Tensor,
    num_steps=20,
):
-    # Load the tokenizer
    unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
    params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
@@ -422,7 +425,6 @@ def train_one_epoch(
        args.center_crop,
    )

-    # weight_dtype = torch.bfloat16
    weight_dtype = torch.bfloat16
    device = torch.device("cuda")
@@ -443,24 +445,17 @@ def train_one_epoch(
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

-        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
-        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(input_ids)[0]

-        # Predict the noise residual
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
@@ -468,33 +463,39 @@ def train_one_epoch(
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

-        # with prior preservation loss
        if args.with_prior_preservation:
            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
            target, target_prior = torch.chunk(target, 2, dim=0)

            # Compute instance loss
            instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Compute prior loss
            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

            # Add the prior loss to the instance loss.
            loss = instance_loss + args.prior_loss_weight * prior_loss
        else:
            prior_loss = torch.tensor(0.0, device=device)
            instance_loss = torch.tensor(0.0, device=device)
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
        optimizer.step()
        optimizer.zero_grad()
        print(
-            f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, instance_loss: {instance_loss.detach().item()}"
+            f"Step #{step}, loss: {loss.detach().item()}, prior_loss: {prior_loss.detach().item()}, "
+            f"instance_loss: {instance_loss.detach().item()}"
        )

+        # Best-effort: free per-step tensors earlier (no behavior change).
+        del step_data, pixel_values, input_ids, latents, noise, timesteps, noisy_latents, encoder_hidden_states
+        del model_pred, target, loss, prior_loss, instance_loss

+    # Best-effort: release optimizer state and dataset references sooner.
+    del optimizer, train_dataset, params_to_optimize
+    _cuda_gc()

    return [unet, text_encoder]
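
For reference (a self-contained sketch, not part of the diff): with the batch laid out as [instance; class] along dim 0, the prior-preservation objective above reduces to two MSE terms. Dummy tensors stand in for the UNet outputs.

# Self-contained illustration of the prior-preservation loss.
import torch
import torch.nn.functional as F

prior_loss_weight = 1.0
model_pred = torch.randn(4, 4, 64, 64)  # first half: instance, second half: class
target = torch.randn(4, 4, 64, 64)

pred_inst, pred_prior = torch.chunk(model_pred, 2, dim=0)
tgt_inst, tgt_prior = torch.chunk(target, 2, dim=0)

loss = F.mse_loss(pred_inst.float(), tgt_inst.float()) + prior_loss_weight * F.mse_loss(
    pred_prior.float(), tgt_prior.float()
)
print(loss.item())
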
def set_unet_attr(unet):
    def conv_forward(self):
        def forward(input_tensor, temb):
@@ -505,7 +506,6 @@ def set_unet_attr(unet):
            hidden_states = self.nonlinearity(hidden_states)

            if self.upsample is not None:
-                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if hidden_states.shape[0] >= 64:
                    input_tensor = input_tensor.contiguous()
                    hidden_states = hidden_states.contiguous()
@@ -538,37 +538,33 @@ def set_unet_attr(unet):
                input_tensor = self.conv_shortcut(input_tensor)

            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
            return output_tensor

        return forward
    # [MODIFIED] Only hook up_blocks[3], the block the algorithm actually uses.
    conv_module_list = [
-        unet.up_blocks[3].resnets[0], unet.up_blocks[3].resnets[1], unet.up_blocks[3].resnets[2],
-    ]
+        unet.up_blocks[3].resnets[0],
+        unet.up_blocks[3].resnets[1],
+        unet.up_blocks[3].resnets[2],
+    ]
    for conv_module in conv_module_list:
        conv_module.forward = conv_forward(conv_module)
-        setattr(conv_module, 'in_layers_features', None)
-        setattr(conv_module, 'out_layers_features', None)
+        setattr(conv_module, "in_layers_features", None)
+        setattr(conv_module, "out_layers_features", None)
def save_feature_maps(up_blocks, down_blocks):
    out_layers_features_list_3 = []
    res_3_list = [0, 1, 2]
    # [MODIFIED] Only extract the features of up_blocks[3].
    block = up_blocks[3]
    for index in res_3_list:
        out_layers_features_list_3.append(block.resnets[index].out_layers_features)
    out_layers_features_list_3 = torch.stack(out_layers_features_list_3, dim=0)
    # [MODIFIED] Only return the features the algorithm actually uses.
    return out_layers_features_list_3
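
The monkey-patched forward above stashes intermediate activations on the module itself. The same capture pattern can be sketched with PyTorch's register_forward_hook, with the caveat that a forward hook only sees the module's final output, not the pre-residual out_layers_features that the patched forward records; attach_feature_hooks is an illustrative name, not part of the diff.

# Sketch: capture the outputs of up_blocks[3].resnets via forward hooks.
def attach_feature_hooks(unet, storage):
    handles = []
    for idx, resnet in enumerate(unet.up_blocks[3].resnets):
        def make_hook(i):
            def hook(module, inputs, output):
                storage[i] = output  # the block's final output (an approximation)
            return hook
        handles.append(resnet.register_forward_hook(make_hook(idx)))
    return handles  # call handle.remove() on each to detach the hooks
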
def pgd_attack(
    args,
    models,
@@ -579,10 +575,13 @@ def pgd_attack(
    original_images: torch.Tensor,
    target_tensor: torch.Tensor,
    num_steps: int,
-    time_list
+    time_list,
):
-    """Return new perturbed data"""
+    """Return the new perturbed data.
+
+    Note: this function keeps the external behavior identical, but tries to reduce
+    memory pressure by freeing tensors early and avoiding lingering references.
+    """
    unet, text_encoder = models
    weight_dtype = torch.bfloat16
    device = torch.device("cuda")
@@ -595,6 +594,7 @@ def pgd_attack(
    perturbed_images = data_tensor.detach().clone()
    perturbed_images.requires_grad_(True)

+    # Keep input_ids on CPU; move them to the GPU only when encoding.
    input_ids = tokenizer(
        args.instance_prompt,
        truncation=True,
@@ -604,12 +604,13 @@ def pgd_attack(
    ).input_ids.repeat(len(data_tensor), 1)
    for step in range(num_steps):
-        perturbed_images.requires_grad = True
+        perturbed_images.requires_grad_(True)
        latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

-        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        timesteps = []
        for i in range(len(data_tensor)):
            ts = time_list[i]
@@ -618,58 +619,62 @@ def pgd_attack(
            timestep = timestep.long()
            timesteps.append(timestep)
        timesteps = torch.cat(timesteps).to(device)
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(input_ids.to(device))[0]

-        # Predict the noise residual
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # [MODIFIED] Feature loss (only unpack the features that are actually needed).
        noise_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks)

        with torch.no_grad():
            clean_latents = vae.encode(data_tensor.to(device, dtype=weight_dtype)).latent_dist.sample()
            clean_latents = clean_latents * vae.config.scaling_factor
            noisy_clean_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps)
-            clean_model_pred = unet(noisy_clean_latents, timesteps, encoder_hidden_states).sample
-            # [MODIFIED] (only unpack the features that are actually needed)
-            clean_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks)
-        # [LOGIC UNCHANGED] The target loss function is unchanged.
-        target_loss = F.mse_loss(noise_out_layers_features_3.float(), clean_out_layers_features_3.float(), reduction="mean")
-        unet.zero_grad()
-        text_encoder.zero_grad()
+            _ = unet(noisy_clean_latents, timesteps, encoder_hidden_states).sample
+            clean_out_layers_features_3 = save_feature_maps(unet.up_blocks, unet.down_blocks)

+        target_loss = F.mse_loss(
+            noise_out_layers_features_3.float(),
+            clean_out_layers_features_3.float(),
+            reduction="mean",
+        )
+        unet.zero_grad(set_to_none=True)
+        text_encoder.zero_grad(set_to_none=True)

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-        loss = loss + target_loss.detach().item()  # keep the original (odd) loss.backward() logic
+        # Keep the original behavior: the feature loss does not backprop (it is added as a Python float).
+        loss = loss + target_loss.detach().item()
        loss.backward()

        alpha = args.pgd_alpha
        eps = args.pgd_eps / 255
        adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
        eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
        perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
-        print(f"PGD loss - step {step}, loss: {loss.detach().item()}, target_loss: {target_loss.detach().item()}")

-        # [MODIFIED] Explicitly release the feature tensors and empty the cache so the subsequent train_one_epoch has enough GPU memory.
-        # This block is added after the PGD loop ends (i.e., right before `return perturbed_images`).
-        del noise_out_layers_features_3
-        del clean_out_layers_features_3
-        del noise
-        del latents
-        del encoder_hidden_states
-        torch.cuda.empty_cache()
+        print(
+            f"PGD loss - step {step}, loss: {loss.detach().item()}, target_loss: {target_loss.detach().item()}"
+        )

+        # Best-effort: free per-step tensors early.
+        del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target
+        del noise_out_layers_features_3, clean_latents, noisy_clean_latents, clean_out_layers_features_3
+        del target_loss, loss, adv_images, eta

+    _cuda_gc()
    return perturbed_images
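
The update at the end of each iteration is a standard L-infinity PGD step: ascend along the gradient sign, project back into the eps-ball around the original image, then clamp to the valid pixel range. A self-contained sketch on a toy objective; all names below are illustrative, not from the script.

# Minimal L-infinity PGD step.
import torch

def pgd_step(x, x_orig, grad, alpha=0.005, eps=8 / 255):
    x_adv = x + alpha * grad.sign()               # ascend along the gradient sign
    eta = torch.clamp(x_adv - x_orig, -eps, eps)  # project into the eps-ball
    return torch.clamp(x_orig + eta, -1.0, 1.0)   # keep pixels in [-1, 1]

x_orig = torch.rand(1, 3, 8, 8) * 2 - 1
x = x_orig.clone().requires_grad_(True)
loss = (x ** 2).sum()                             # stand-in for the diffusion loss
loss.backward()
x = pgd_step(x.detach(), x_orig, x.grad)
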
def select_timestep(
    args,
    models,
@@ -679,9 +684,11 @@ def select_timestep(
    data_tensor: torch.Tensor,
    original_images: torch.Tensor,
    target_tensor: torch.Tensor,
):
-    """Return new perturbed data"""
+    """Return timestep lists for each image.
+
+    External behavior unchanged; best-effort per-loop cleanup added to lower memory pressure.
+    """
    unet, text_encoder = models
    weight_dtype = torch.bfloat16
    device = torch.device("cuda")
@@ -693,7 +700,6 @@ def select_timestep(
    perturbed_images = data_tensor.detach().clone()
    perturbed_images.requires_grad_(True)

    input_ids = tokenizer(
        args.instance_prompt,
        truncation=True,
@@ -701,93 +707,39 @@ def select_timestep(
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
-    time_list = []
-    for id in range(len(data_tensor)):
-        perturbed_image = perturbed_images[id, :].unsqueeze(0)
-        original_image = original_images[id, :].unsqueeze(0)
-        time_seq = torch.tensor(list(range(0, 1000)))
-        input_mask = torch.ones_like(time_seq)
-        id_image = perturbed_image.detach().clone()
-        for step in range(args.max_steps):
-            id_image.requires_grad_(True)
-            select_mask = torch.where(input_mask == 1, True, False)
-            res_time_seq = torch.masked_select(time_seq, select_mask)
-            if len(res_time_seq) > 100:
-                min_score, max_score = 0.0, 0.0
-                for index in range(0, 5):
-                    id_image.requires_grad_(True)
-                    latents = vae.encode(id_image.to(device, dtype=weight_dtype)).latent_dist.sample()
-                    latents = latents * vae.config.scaling_factor
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    inner_index = torch.randint(0, len(res_time_seq), (bsz,))
-                    timesteps = torch.IntTensor([res_time_seq[inner_index]]).to(device)
-                    timesteps = timesteps.long()
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-                    # Get the text embedding for conditioning
-                    encoder_hidden_states = text_encoder(input_ids.to(device))[0]
-                    # Predict the noise residual
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-                    unet.zero_grad()
-                    text_encoder.zero_grad()
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-                    loss.backward()
-                    score = torch.sum(torch.abs(id_image.grad.data))
-                    index = index + 1
-                    id_image.grad.zero_()
-                    if index == 1:
-                        min_score = score
-                        max_score = score
-                        del_t = res_time_seq[inner_index].item()
-                        select_t = res_time_seq[inner_index].item()
-                    else:
-                        if min_score > score:
-                            min_score = score
-                            del_t = res_time_seq[inner_index].item()
-                        if max_score < score:
-                            max_score = score
-                            select_t = res_time_seq[inner_index].item()
-                    print(f"PGD loss - step {step}, index: {index}, loss: {loss.detach().item()}, score: {score}, t: {res_time_seq[inner_index]}, ts_len: {len(res_time_seq)}")
-                print("del_t", del_t, "max_t", select_t)
-                if del_t < args.delta_t:
-                    del_t = args.delta_t
-                elif del_t > (1000 - args.delta_t):
-                    del_t = 1000 - args.delta_t
-                input_mask[del_t - 20 : del_t + 20] = input_mask[del_t - 20 : del_t + 20] - 1
-                input_mask = torch.clamp(input_mask, min=0, max=+1)
+    time_list = []
+    for img_id in range(len(data_tensor)):
+        perturbed_image = perturbed_images[img_id, :].unsqueeze(0)
+        original_image = original_images[img_id, :].unsqueeze(0)
+        time_seq = torch.tensor(list(range(0, 1000)))
+        input_mask = torch.ones_like(time_seq)
+        id_image = perturbed_image.detach().clone()
+        for step in range(args.max_steps):
+            id_image.requires_grad_(True)
+            select_mask = torch.where(input_mask == 1, True, False)
+            res_time_seq = torch.masked_select(time_seq, select_mask)
+            if len(res_time_seq) > 100:
+                min_score, max_score = 0.0, 0.0
+                for inner_try in range(0, 5):
+                    id_image.requires_grad_(True)
+                    latents = vae.encode(id_image.to(device, dtype=weight_dtype)).latent_dist.sample()
+                    latents = latents * vae.config.scaling_factor
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    inner_index = torch.randint(0, len(res_time_seq), (bsz,))
+                    timesteps = torch.IntTensor([res_time_seq[inner_index]]).to(device)
+                    timesteps = timesteps.long()
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                    encoder_hidden_states = text_encoder(input_ids.to(device))[0]
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-                    unet.zero_grad()
-                    text_encoder.zero_grad()
+                    unet.zero_grad(set_to_none=True)
+                    text_encoder.zero_grad(set_to_none=True)
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                    loss.backward()
-                    alpha = args.pgd_alpha
-                    eps = args.pgd_eps / 255
-                    adv_image = id_image + alpha * id_image.grad.sign()
-                    eta = torch.clamp(adv_image - original_image, min=-eps, max=+eps)
-                    score = torch.sum(torch.abs(id_image.grad.sign()))
-                    id_image = torch.clamp(original_image + eta, min=-1, max=+1).detach_()
+                    score = torch.sum(torch.abs(id_image.grad.data))
+                    id_image.grad.zero_()
+                    if inner_try == 0:
+                        min_score = score
+                        max_score = score
+                        del_t = res_time_seq[inner_index].item()
+                        select_t = res_time_seq[inner_index].item()
+                    else:
+                        if min_score > score:
+                            min_score = score
+                            del_t = res_time_seq[inner_index].item()
+                        if max_score < score:
+                            max_score = score
+                            select_t = res_time_seq[inner_index].item()
+                    print(
+                        f"PGD loss - step {step}, index: {inner_try + 1}, loss: {loss.detach().item()}, "
+                        f"score: {score}, t: {res_time_seq[inner_index]}, ts_len: {len(res_time_seq)}"
+                    )
+                    del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss, score
+                print("del_t", del_t, "max_t", select_t)
+                if del_t < args.delta_t:
+                    del_t = args.delta_t
+                elif del_t > (1000 - args.delta_t):
+                    del_t = 1000 - args.delta_t
+                # Delete a window of 2 * delta_t timesteps around del_t; the clamp above keeps it in range.
+                input_mask[del_t - args.delta_t : del_t + args.delta_t] = input_mask[del_t - args.delta_t : del_t + args.delta_t] - 1
+                input_mask = torch.clamp(input_mask, min=0, max=+1)
+                id_image.requires_grad_(True)
+                latents = vae.encode(id_image.to(device, dtype=weight_dtype)).latent_dist.sample()
+                latents = latents * vae.config.scaling_factor
+                noise = torch.randn_like(latents)
+                timesteps = torch.IntTensor([select_t]).to(device)
+                timesteps = timesteps.long()
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                encoder_hidden_states = text_encoder(input_ids.to(device))[0]
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+                unet.zero_grad(set_to_none=True)
+                text_encoder.zero_grad(set_to_none=True)
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                loss.backward()
+                alpha = args.pgd_alpha
+                eps = args.pgd_eps / 255
+                adv_image = id_image + alpha * id_image.grad.sign()
+                eta = torch.clamp(adv_image - original_image, min=-eps, max=+eps)
+                _ = torch.sum(torch.abs(id_image.grad.sign()))
+                id_image = torch.clamp(original_image + eta, min=-1, max=+1).detach_()
+                del latents, noise, timesteps, noisy_latents, encoder_hidden_states, model_pred, target, loss, adv_image, eta
+            else:
-                # print(id, res_time_seq, step, len(res_time_seq))
+                time_list.append(res_time_seq)
+                break
+        del perturbed_image, original_image, time_seq, input_mask, id_image
+        _cuda_gc()
+    del perturbed_images, input_ids
+    _cuda_gc()
    return time_list
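
Each round of the selection loop deletes a window of 2*delta_t candidate timesteps centered on del_t, after clamping del_t so the window stays inside [0, 1000). A tiny worked example with illustrative values (standalone, not part of the diff):

# Worked example of the greedy timestep-window deletion.
import torch

delta_t = 20
time_seq = torch.arange(1000)
input_mask = torch.ones_like(time_seq)

del_t = 5                                         # lowest-score timestep this round
del_t = max(delta_t, min(del_t, 1000 - delta_t))  # clamp as in the script -> 20
input_mask[del_t - delta_t : del_t + delta_t] -= 1
input_mask = torch.clamp(input_mask, 0, 1)

remaining = torch.masked_select(time_seq, input_mask == 1)
print(len(remaining))  # 1000 - 2 * delta_t = 960
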
def setup_seeds():
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
@@ -850,11 +868,11 @@ def main(args):
    if args.seed is not None:
        set_seed(args.seed)
+    setup_seeds()

    # Generate class images if prior preservation is enabled.
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
-            class_images_dir.mkdir(parents=True)
+            class_images_dir.mkdir(parents=True, exist_ok=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
@@ -865,12 +883,12 @@ def main(args):
                torch_dtype = torch.float16
            elif args.mixed_precision == "bf16":
                torch_dtype = torch.bfloat16

            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)
@@ -889,27 +907,25 @@ def main(args):
            disable=not accelerator.is_local_main_process,
        ):
            images = pipeline(example["prompt"]).images

            for i, image in enumerate(images):
                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                image.save(image_filename)

-        del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        del pipeline, sample_dataset, sample_dataloader
+        _cuda_gc()
    # Import the correct text encoder class.
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    # Load scheduler and models.
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision,
+        args.pretrained_model_name_or_path,
+        subfolder="unet",
+        revision=args.revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(
@@ -919,12 +935,13 @@ def main(args):
        use_fast=False,
    )

-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler",)
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision,
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
    ).cuda()
    vae.requires_grad_(False)

    if not args.train_text_encoder:
@@ -967,22 +984,23 @@ def main(args):
    target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda()
    f = [unet, text_encoder]

    time_list = select_timestep(
        args,
        f,
        tokenizer,
        noise_scheduler,
        vae,
        perturbed_data,
        original_data,
        target_latent_tensor,
    )
    for t in time_list:
        print(t)
    for i in range(args.max_train_steps):
        # 1. f' = f.clone()
        f_sur = copy.deepcopy(f)
        f_sur = train_one_epoch(
            args,
            f_sur,
@@ -992,6 +1010,7 @@ def main(args):
            clean_data,
            args.max_f_train_steps,
        )
        perturbed_data = pgd_attack(
            args,
            f_sur,
@@ -1002,8 +1021,13 @@ def main(args):
            original_data,
            target_latent_tensor,
            args.max_adv_train_steps,
-            time_list
+            time_list,
        )

+        # Free the surrogate ASAP (best-effort, behavior unchanged).
+        del f_sur
+        _cuda_gc()
        f = train_one_epoch(
            args,
            f,
@@ -1015,24 +1039,31 @@ def main(args):
        )

        if (i + 1) % args.checkpointing_iterations == 0:
            save_folder = args.output_dir
            os.makedirs(save_folder, exist_ok=True)
            noised_imgs = perturbed_data.detach()
            img_names = [
                str(instance_path).split("/")[-1].split(".")[0]
                for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir())
            ]
            for img_pixel, img_name in zip(noised_imgs, img_names):
                save_path = os.path.join(save_folder, f"perturbed_{img_name}.png")
                Image.fromarray(
-                    (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
+                    (img_pixel * 127.5 + 128)
+                    .clamp(0, 255)
+                    .to(torch.uint8)
+                    .permute(1, 2, 0)
+                    .cpu()
+                    .numpy()
                ).save(save_path)
            print(f"Saved perturbed images at step {i + 1} to {save_folder} (Files are overwritten)")

+        # Best-effort cleanup at the end of each outer iteration.
+        _cuda_gc()
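
The save path above maps tensors from [-1, 1] to 8-bit RGB before writing PNGs. A standalone sketch of that conversion (the output filename is illustrative):

# Check of the [-1, 1] -> uint8 image conversion used above.
import torch
from PIL import Image

img_pixel = torch.rand(3, 64, 64) * 2 - 1  # fake image in [-1, 1]
arr = (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
Image.fromarray(arr).save("perturbed_example.png")
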
if __name__ == "__main__":
    args = parse_args()