feat: 添加专题防护算法后端调用

pull/29/head
梁浩 5 months ago
parent 5cc9719244
commit e493ef7b84

@ -143,6 +143,24 @@ def list_perturbation_configs(current_user_id):
]}), 200
@task_bp.route('/perturbation/style-presets', methods=['GET'])
@int_jwt_required
def get_style_presets(current_user_id):
"""获取风格迁移防护的预设风格列表"""
presets = AlgorithmConfig.get_style_protection_presets()
return jsonify({
'presets': [
{
'style_code': code,
'name': info['name'],
'prompt': info['prompt'],
'description': info['description']
}
for code, info in presets.items()
]
}), 200
@task_bp.route('/perturbation', methods=['POST'])
@int_jwt_required
def create_perturbation_task(current_user_id):
@ -157,6 +175,7 @@ def create_perturbation_task(current_user_id):
perturbation_configs_id = data.get('perturbation_configs_id', type=int) if hasattr(data, 'get') else int(data.get('perturbation_configs_id', 0))
intensity = data.get('perturbation_intensity', type=float) if hasattr(data, 'get') else float(data.get('perturbation_intensity', 0))
description = data.get('description')
target_style = data.get('target_style') # 可选参数仅用于style_protection算法
if not all([data_type_id, perturbation_configs_id, intensity]):
return TaskService.json_error('缺少必要的任务参数')
@ -171,8 +190,19 @@ def create_perturbation_task(current_user_id):
role_code = user.role.role_code if user.role else 'user'
if role_code in ('user', 'normal') and data_type.data_type_code != 'face':
return TaskService.json_error('普通用户仅可使用人脸数据集', 403)
if not PerturbationConfig.query.get(perturbation_configs_id):
pert_config = PerturbationConfig.query.get(perturbation_configs_id)
if not pert_config:
return TaskService.json_error('加噪配置不存在')
# 如果是风格迁移防护算法验证target_style参数
if pert_config.perturbation_code == 'style_protection':
if not target_style:
return TaskService.json_error('风格迁移防护算法必须指定target_style参数')
# 验证风格代码是否有效
style_prompt = AlgorithmConfig.get_style_prompt(target_style)
if not style_prompt:
return TaskService.json_error(f'无效的风格代码: {target_style}。请使用 /api/task/perturbation/style-presets 查看可用风格')
# 验证上传的图片
files = request.files.getlist('files') if hasattr(request, 'files') else []
@ -215,7 +245,8 @@ def create_perturbation_task(current_user_id):
data_type_id=data_type_id,
perturbation_configs_id=perturbation_configs_id,
perturbation_intensity=float(intensity),
perturbation_name=data.get('perturbation_name')
perturbation_name=data.get('perturbation_name'),
target_style=target_style # 保存用户选择的风格
)
db.session.add(perturbation)
db.session.commit()
@ -260,6 +291,14 @@ def update_perturbation_task(task_id, current_user_id):
pert.perturbation_intensity = float(data['perturbation_intensity'])
if 'perturbation_name' in data:
pert.perturbation_name = data['perturbation_name']
if 'target_style' in data:
# 如果更新target_style验证风格代码有效性
target_style = data['target_style']
if target_style:
style_prompt = AlgorithmConfig.get_style_prompt(target_style)
if not style_prompt:
return TaskService.json_error(f'无效的风格代码: {target_style}')
pert.target_style = target_style
if 'description' in data:
task.description = data['description']

@ -229,6 +229,7 @@ class Perturbation(db.Model):
perturbation_name = db.Column(String(100), comment='加噪任务自定义名称')
perturbation_configs_id = db.Column(Integer, ForeignKey('perturbation_configs.perturbation_configs_id'), nullable=False, comment='使用的算法')
perturbation_intensity = db.Column(Float, nullable=False, comment='扰动强度')
target_style = db.Column(String(100), comment='风格迁移防护的目标风格(仅用于style_protection算法)')
# 关系
task = db.relationship('Task', back_populates='perturbation')

@ -221,6 +221,7 @@ class TaskService:
'perturbation_configs_id': task.perturbation.perturbation_configs_id,
'perturbation_intensity': float(task.perturbation.perturbation_intensity),
'perturbation_name': task.perturbation.perturbation_name,
'target_style': task.perturbation.target_style,
}
elif task_type == 'finetune' and task.finetune:
try:

@ -184,14 +184,23 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
if 'class_prompt' in default_params:
default_params['class_prompt'] = class_prompt
# 如果是风格迁移防护算法,使用用户选择的风格
if algorithm_code == 'style_protection' and perturbation.target_style:
style_prompt = AlgorithmConfig.get_style_prompt(perturbation.target_style)
if style_prompt:
default_params['target_style'] = style_prompt
logger.info(f"Using user-selected style: {perturbation.target_style} -> '{style_prompt}'")
else:
logger.warning(f"Invalid target_style '{perturbation.target_style}', using default")
# 合并自定义参数
params = {**default_params, **(custom_params or {})}
# 根据算法构建命令参数参考sh脚本
cmd_args = []
if algorithm_code in ['aspl', 'simac']:
# ASPL和SimAC使用相同的参数结构
if algorithm_code in ['aspl', 'simac', 'anti_customize']:
# ASPL、SimAC 和防定制生成使用相同的参数结构
cmd_args.extend([
f"--instance_data_dir_for_train={input_dir}",
f"--instance_data_dir_for_adversarial={input_dir}",
@ -214,15 +223,15 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
f"--class_data_dir={class_dir}",
f"--eps={float(epsilon)}",
])
elif algorithm_code == 'pid':
# PID参数结构
elif algorithm_code in ['pid', 'anti_face_edit']:
# PID 和防人脸编辑参数结构
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={float(epsilon)}",
f"--eps={int(epsilon)}",
])
elif algorithm_code == 'glaze':
# Glaze参数结构
elif algorithm_code in ['glaze', 'style_protection']:
# Glaze 和风格迁移防护参数结构
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
@ -283,7 +292,7 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
)
# 清理class_dir
if algorithm_code in ['aspl', 'simac', 'caat_pro']:
if algorithm_code in ['aspl', 'simac', 'anti_customize', 'caat_pro']:
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
shutil.rmtree(class_dir)

@ -32,6 +32,33 @@ class AlgorithmConfig:
# 日志目录
LOGS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs')
# 风格迁移防护预设风格列表
STYLE_PROTECTION_PRESETS = {
'van_gogh': {
'name': '梵高印象派',
'prompt': 'impressionism painting by van gogh',
'description': '模仿梵高的印象派绘画风格'
},
'kandinsky': {
'name': '康定斯基抽象派',
'prompt': 'abstract art by kandinsky',
'description': '模仿康定斯基的抽象艺术风格'
},
'picasso': {
'name': '毕加索立体派',
'prompt': 'cubist painting by picasso',
'description': '模仿毕加索的立体主义风格'
},
'baroque': {
'name': '巴洛克风格',
'prompt': 'baroque style painting',
'description': '经典巴洛克艺术风格'
}
}
# 日志目录
LOGS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs')
# Conda环境配置从环境变量读取支持自定义
CONDA_ENVS = {
'aspl': os.getenv('CONDA_ENV_ASPL', 'simac'),
@ -40,6 +67,9 @@ class AlgorithmConfig:
'caat_pro': os.getenv('CONDA_ENV_CAAT_PRO', 'caat'),
'pid': os.getenv('CONDA_ENV_PID', 'pid'),
'glaze': os.getenv('CONDA_ENV_GLAZE', 'pid'),
'anti_customize': os.getenv('CONDA_ENV_ANTI_CUSTOMIZE', 'simac'),
'anti_face_edit': os.getenv('CONDA_ENV_ANTI_FACE_EDIT', 'pid'),
'style_protection': os.getenv('CONDA_ENV_STYLE_PROTECTION', 'pid'),
'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'pid'),
'lora': os.getenv('CONDA_ENV_LORA', 'pid'),
'textual_inversion': os.getenv('CONDA_ENV_TI', 'pid'),
@ -159,7 +189,65 @@ class AlgorithmConfig:
'center_crop': True,
'max_train_steps': 150,
'eps': 0.05,
'target_style': 'cubism painting by picasso',
'target_style': 'impressionism painting by van gogh',
'style_strength': 0.75,
'n_runs': 3,
'style_transfer_iter': 15,
'guidance_scale': 7.5,
'seed': 42
}
},
'anti_customize': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'simac.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['anti_customize'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model1'],
'enable_xformers_memory_efficient_attention': True,
'instance_prompt': 'a photo of sks person',
'class_prompt': 'a photo of person',
'num_class_images': 100,
'center_crop': True,
'with_prior_preservation': True,
'prior_loss_weight': 1.0,
'resolution': 384,
'train_batch_size': 1,
'max_train_steps': 100,
'max_f_train_steps': 3,
'max_adv_train_steps': 6,
'checkpointing_iterations': 20,
'learning_rate': 5e-7,
'pgd_alpha': 0.005,
'seed': 0
}
},
'anti_face_edit': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'pid.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['anti_face_edit'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'resolution': 512,
'max_train_steps': 2000,
'center_crop': True,
'step_size': 0.002,
'save_every': 200,
'attack_type': 'add-log',
'seed': 0,
'dataloader_num_workers': 2
}
},
'style_protection': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'glaze.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['style_protection'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'resolution': 512,
'center_crop': True,
'max_train_steps': 150,
'eps': 0.04,
'target_style': 'impressionism painting by van gogh',
'style_strength': 0.75,
'n_runs': 3,
'style_transfer_iter': 15,
@ -195,6 +283,17 @@ class AlgorithmConfig:
config = cls.get_perturbation_config(algorithm_code)
return config.get('default_params', {}).copy()
@classmethod
def get_style_protection_presets(cls):
"""获取风格迁移防护的预设风格列表"""
return cls.STYLE_PROTECTION_PRESETS
@classmethod
def get_style_prompt(cls, style_code):
"""根据风格代码获取对应的提示词"""
preset = cls.STYLE_PROTECTION_PRESETS.get(style_code)
return preset['prompt'] if preset else None
# ========== 微调算法配置 ==========
FINETUNE_SCRIPTS = {
'dreambooth': {

@ -62,7 +62,10 @@ def init_database():
{'perturbation_code': 'caat', 'perturbation_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck'},
{'perturbation_code': 'caat_pro', 'perturbation_name': 'CAAT Pro算法', 'description': 'CAAT with Prior Preservation - Enhanced version with class data preservation'},
{'perturbation_code': 'pid', 'perturbation_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models'},
{'perturbation_code': 'glaze', 'perturbation_name': 'Glaze算法', 'description': 'Protecting Artists from Style Mimicry by Text-to-Image Models'}
{'perturbation_code': 'glaze', 'perturbation_name': 'Glaze算法', 'description': 'Protecting Artists from Style Mimicry by Text-to-Image Models'},
{'perturbation_code': 'anti_customize', 'perturbation_name': '防定制生成', 'description': 'Anti-Customization Generation - 专门防止人脸定制化生成'},
{'perturbation_code': 'anti_face_edit', 'perturbation_name': '防人脸编辑', 'description': 'Anti-Face-Editing - 专门防止人脸图像被编辑'},
{'perturbation_code': 'style_protection', 'perturbation_name': '风格迁移防护', 'description': 'Style Transfer Protection - 保护艺术作品免受风格模仿'}
]
for config in perturbation_configs:
@ -96,7 +99,7 @@ def init_database():
},
{
'data_type_code': 'art',
'instance_prompt': 'a painting in sks style',
'instance_prompt': 'a painting in <sks-style> style',
'class_prompt': 'a painting',
'placeholder_token': '<sks-style>',
'initializer_token': 'painting',

Loading…
Cancel
Save