@ -32,6 +32,33 @@ class AlgorithmConfig:
# 日志目录
LOGS_DIR = os . path . join ( os . path . dirname ( os . path . dirname ( __file__ ) ) , ' logs ' )
# 风格迁移防护预设风格列表
STYLE_PROTECTION_PRESETS = {
' van_gogh ' : {
' name ' : ' 梵高印象派 ' ,
' prompt ' : ' impressionism painting by van gogh ' ,
' description ' : ' 模仿梵高的印象派绘画风格 '
} ,
' kandinsky ' : {
' name ' : ' 康定斯基抽象派 ' ,
' prompt ' : ' abstract art by kandinsky ' ,
' description ' : ' 模仿康定斯基的抽象艺术风格 '
} ,
' picasso ' : {
' name ' : ' 毕加索立体派 ' ,
' prompt ' : ' cubist painting by picasso ' ,
' description ' : ' 模仿毕加索的立体主义风格 '
} ,
' baroque ' : {
' name ' : ' 巴洛克风格 ' ,
' prompt ' : ' baroque style painting ' ,
' description ' : ' 经典巴洛克艺术风格 '
}
}
# 日志目录
LOGS_DIR = os . path . join ( os . path . dirname ( os . path . dirname ( __file__ ) ) , ' logs ' )
# Conda环境配置( 从环境变量读取, 支持自定义)
CONDA_ENVS = {
' aspl ' : os . getenv ( ' CONDA_ENV_ASPL ' , ' simac ' ) ,
@ -40,6 +67,9 @@ class AlgorithmConfig:
' caat_pro ' : os . getenv ( ' CONDA_ENV_CAAT_PRO ' , ' caat ' ) ,
' pid ' : os . getenv ( ' CONDA_ENV_PID ' , ' pid ' ) ,
' glaze ' : os . getenv ( ' CONDA_ENV_GLAZE ' , ' pid ' ) ,
' anti_customize ' : os . getenv ( ' CONDA_ENV_ANTI_CUSTOMIZE ' , ' simac ' ) ,
' anti_face_edit ' : os . getenv ( ' CONDA_ENV_ANTI_FACE_EDIT ' , ' pid ' ) ,
' style_protection ' : os . getenv ( ' CONDA_ENV_STYLE_PROTECTION ' , ' pid ' ) ,
' dreambooth ' : os . getenv ( ' CONDA_ENV_DREAMBOOTH ' , ' pid ' ) ,
' lora ' : os . getenv ( ' CONDA_ENV_LORA ' , ' pid ' ) ,
' textual_inversion ' : os . getenv ( ' CONDA_ENV_TI ' , ' pid ' ) ,
@ -159,7 +189,65 @@ class AlgorithmConfig:
' center_crop ' : True ,
' max_train_steps ' : 150 ,
' eps ' : 0.05 ,
' target_style ' : ' cubism painting by picasso ' ,
' target_style ' : ' impressionism painting by van gogh ' ,
' style_strength ' : 0.75 ,
' n_runs ' : 3 ,
' style_transfer_iter ' : 15 ,
' guidance_scale ' : 7.5 ,
' seed ' : 42
}
} ,
' anti_customize ' : {
' real_script ' : os . path . join ( ALGORITHMS_DIR , ' perturbation ' , ' simac.py ' ) ,
' virtual_script ' : os . path . join ( ALGORITHMS_DIR , ' perturbation_engine.py ' ) ,
' conda_env ' : CONDA_ENVS [ ' anti_customize ' ] ,
' default_params ' : {
' pretrained_model_name_or_path ' : MODELS_DIR [ ' model1 ' ] ,
' enable_xformers_memory_efficient_attention ' : True ,
' instance_prompt ' : ' a photo of sks person ' ,
' class_prompt ' : ' a photo of person ' ,
' num_class_images ' : 100 ,
' center_crop ' : True ,
' with_prior_preservation ' : True ,
' prior_loss_weight ' : 1.0 ,
' resolution ' : 384 ,
' train_batch_size ' : 1 ,
' max_train_steps ' : 100 ,
' max_f_train_steps ' : 3 ,
' max_adv_train_steps ' : 6 ,
' checkpointing_iterations ' : 20 ,
' learning_rate ' : 5e-7 ,
' pgd_alpha ' : 0.005 ,
' seed ' : 0
}
} ,
' anti_face_edit ' : {
' real_script ' : os . path . join ( ALGORITHMS_DIR , ' perturbation ' , ' pid.py ' ) ,
' virtual_script ' : os . path . join ( ALGORITHMS_DIR , ' perturbation_engine.py ' ) ,
' conda_env ' : CONDA_ENVS [ ' anti_face_edit ' ] ,
' default_params ' : {
' pretrained_model_name_or_path ' : MODELS_DIR [ ' model2 ' ] ,
' resolution ' : 512 ,
' max_train_steps ' : 2000 ,
' center_crop ' : True ,
' step_size ' : 0.002 ,
' save_every ' : 200 ,
' attack_type ' : ' add-log ' ,
' seed ' : 0 ,
' dataloader_num_workers ' : 2
}
} ,
' style_protection ' : {
' real_script ' : os . path . join ( ALGORITHMS_DIR , ' perturbation ' , ' glaze.py ' ) ,
' virtual_script ' : os . path . join ( ALGORITHMS_DIR , ' perturbation_engine.py ' ) ,
' conda_env ' : CONDA_ENVS [ ' style_protection ' ] ,
' default_params ' : {
' pretrained_model_name_or_path ' : MODELS_DIR [ ' model2 ' ] ,
' resolution ' : 512 ,
' center_crop ' : True ,
' max_train_steps ' : 150 ,
' eps ' : 0.04 ,
' target_style ' : ' impressionism painting by van gogh ' ,
' style_strength ' : 0.75 ,
' n_runs ' : 3 ,
' style_transfer_iter ' : 15 ,
@ -195,6 +283,17 @@ class AlgorithmConfig:
config = cls . get_perturbation_config ( algorithm_code )
return config . get ( ' default_params ' , { } ) . copy ( )
@classmethod
def get_style_protection_presets ( cls ) :
""" 获取风格迁移防护的预设风格列表 """
return cls . STYLE_PROTECTION_PRESETS
@classmethod
def get_style_prompt ( cls , style_code ) :
""" 根据风格代码获取对应的提示词 """
preset = cls . STYLE_PROTECTION_PRESETS . get ( style_code )
return preset [ ' prompt ' ] if preset else None
# ========== 微调算法配置 ==========
FINETUNE_SCRIPTS = {
' dreambooth ' : {