feat: 添加caat_pro算法适配

pull/29/head
梁浩 4 months ago
parent 10bd92f70b
commit 84f41f8f76

@ -206,6 +206,14 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
f"--output_dir={output_dir}",
f"--eps={float(epsilon)}",
])
elif algorithm_code == 'caat_pro':
# CAAT Pro参数结构带prior preservation
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir={class_dir}",
f"--eps={float(epsilon)}",
])
elif algorithm_code == 'pid':
# PID参数结构
cmd_args.extend([
@ -275,7 +283,7 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
)
# 清理class_dir
if algorithm_code in ['aspl', 'simac']:
if algorithm_code in ['aspl', 'simac', 'caat_pro']:
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
shutil.rmtree(class_dir)

@ -37,6 +37,7 @@ class AlgorithmConfig:
'aspl': os.getenv('CONDA_ENV_ASPL', 'simac'),
'simac': os.getenv('CONDA_ENV_SIMAC', 'simac'),
'caat': os.getenv('CONDA_ENV_CAAT', 'caat'),
'caat_pro': os.getenv('CONDA_ENV_CAAT_PRO', 'caat'),
'pid': os.getenv('CONDA_ENV_PID', 'pid'),
'glaze': os.getenv('CONDA_ENV_GLAZE', 'pid'),
'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'pid'),
@ -116,6 +117,26 @@ class AlgorithmConfig:
'alpha': 5e-3
}
},
'caat_pro': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'caat.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['caat_pro'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'with_prior_preservation': True,
'instance_prompt': 'a photo of a person',
'class_prompt': 'person',
'num_class_images': 200,
'resolution': 512,
'learning_rate': 1e-5,
'lr_warmup_steps': 0,
'max_train_steps': 250,
'hflip': True,
'mixed_precision': 'bf16',
'alpha': 5e-3,
'eps': 0.05
}
},
'pid': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'pid.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),

@ -60,6 +60,7 @@ def init_database():
{'perturbation_code': 'aspl', 'perturbation_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense'},
{'perturbation_code': 'simac', 'perturbation_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy'},
{'perturbation_code': 'caat', 'perturbation_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck'},
{'perturbation_code': 'caat_pro', 'perturbation_name': 'CAAT Pro算法', 'description': 'CAAT with Prior Preservation - Enhanced version with class data preservation'},
{'perturbation_code': 'pid', 'perturbation_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models'},
{'perturbation_code': 'glaze', 'perturbation_name': 'Glaze算法', 'description': 'Protecting Artists from Style Mimicry by Text-to-Image Models'}
]

Loading…
Cancel
Save