将lianghao_branch合并到develop #29

Merged
hnu202326010204 merged 4 commits from lianghao_branch into develop 3 weeks ago

@ -325,42 +325,145 @@
"configs": [
{
"perturbation_configs_id": 1,
"perturbation_code": "aspl",
"perturbation_name": "ASPL算法",
"description": "Advanced Semantic Protection Layer for Enhanced Privacy Defense"
},
{
"perturbation_configs_id": 2,
"perturbation_code": "simac",
"perturbation_name": "SimAC算法",
"description": "Simple Anti-Customization Method"
"description": "Simple Anti-Customization Method for Protecting Face Privacy"
},
{
"perturbation_configs_id": 2,
"perturbation_configs_id": 3,
"perturbation_code": "caat",
"perturbation_name": "CAAT算法",
"description": "Perturbing Attention..."
"description": "Perturbing Attention Gives You More Bang for the Buck"
},
{
"perturbation_configs_id": 4,
"perturbation_code": "caat_pro",
"perturbation_name": "CAAT Pro算法",
"description": "CAAT with Prior Preservation - Enhanced version with class data preservation"
},
{
"perturbation_configs_id": 5,
"perturbation_code": "pid",
"perturbation_name": "PID算法",
"description": "Prompt-Independent Data Protection Against Latent Diffusion Models"
},
{
"perturbation_configs_id": 6,
"perturbation_code": "glaze",
"perturbation_name": "Glaze算法",
"description": "Protecting Artists from Style Mimicry by Text-to-Image Models"
},
{
"perturbation_configs_id": 7,
"perturbation_code": "anti_customize",
"perturbation_name": "防定制生成",
"description": "Anti-Customization Generation - 专门防止人脸定制化生成"
},
{
"perturbation_configs_id": 8,
"perturbation_code": "anti_face_edit",
"perturbation_name": "防人脸编辑",
"description": "Anti-Face-Editing - 专门防止人脸图像被编辑"
},
{
"perturbation_configs_id": 9,
"perturbation_code": "style_protection",
"perturbation_name": "风格迁移防护",
"description": "Style Transfer Protection - 保护艺术作品免受风格模仿"
}
]
}
```
**说明**
- `perturbation_configs_id=7,8` 仅适用于人脸数据集(`data_type_id=1`
- `perturbation_configs_id=9` 仅适用于艺术作品数据集(`data_type_id=2`)且**必须**指定 `target_style` 参数
**错误响应**
- `401 {"error": "无效的用户身份标识"}`
- `500 Internal Server Error`
##### GET `/api/task/perturbation/style-presets`
**功能**:获取风格迁移防护算法的预设风格列表(仅用于 `style_protection` 算法)。
**认证**:是
**成功响应** `200 OK`
```json
{
"presets": [
{
"style_code": "van_gogh",
"name": "梵高印象派",
"prompt": "impressionism painting by van gogh",
"description": "模仿梵高的印象派绘画风格"
},
{
"style_code": "kandinsky",
"name": "康定斯基抽象派",
"prompt": "abstract art by kandinsky",
"description": "模仿康定斯基的抽象艺术风格"
},
{
"style_code": "picasso",
"name": "毕加索立体派",
"prompt": "cubist painting by picasso",
"description": "模仿毕加索的立体主义风格"
},
{
"style_code": "baroque",
"name": "巴洛克风格",
"prompt": "baroque style painting",
"description": "经典巴洛克艺术风格"
}
]
}
```
**错误响应**
- `401 {"error": "无效的用户身份标识"}`
##### POST `/api/task/perturbation`
**功能**:基于指定数据集与配置创建加噪任务,并支持图片文件一并上传。
**请求格式**`multipart/form-data`,字段:
- `data_type_id`(数字,必填)
- `perturbation_configs_id`(数字,必填)
- `perturbation_intensity`(数字,必填)
- `perturbation_name`(字符串,可选)
- `description`(字符串,可选)
- `flow_id`(数字,可选)
- `files`(一个或多个图片文件,可选)
- `data_type_id`(数字,必填)- 数据集类型ID1=人脸2=艺术作品
- `perturbation_configs_id`(数字,必填)- 算法配置ID参考上方配置列表
- `perturbation_intensity`(数字,必填)- 扰动强度epsilon值
- `target_style`(字符串,条件必填)- **仅当使用 `style_protection` 算法时必填**,可选值:`van_gogh`、`kandinsky`、`picasso`、`baroque`
- `perturbation_name`(字符串,可选)- 任务自定义名称
- `description`(字符串,可选)- 任务描述
- `flow_id`(数字,可选)- 流程ID若不指定则自动生成
- `files`(文件数组,必填)- 一个或多个图片文件jpg/jpeg/png/bmp/gif/webp/tiff
**成功响应** `201 Created`
```json
{
"message": "加噪任务已创建并已启动",
"task": { ... },
"job_id": "pert_901"
"task": {
"task_id": 123,
"flow_id": 1734700000000,
"task_type": "perturbation",
"status": "waiting",
"user_id": 1,
"description": null,
"created_at": "2025-12-20T10:00:00",
"started_at": null,
"finished_at": null,
"error_message": null,
"perturbation": {
"data_type_id": 2,
"perturbation_configs_id": 9,
"perturbation_intensity": 0.04,
"perturbation_name": "梵高风格保护",
"target_style": "van_gogh"
}
},
"job_id": "pert_123"
}
```
**错误响应**
- `401 {"error": "无效的用户身份标识"}`
- `400 {"error": "缺少必要的任务参数"}`
@ -368,27 +471,76 @@
- `400 {"error": "数据集类型不存在"}`
- `403 {"error": "普通用户仅可使用人脸数据集"}`
- `400 {"error": "加噪配置不存在"}`
- `400 {"error": "风格迁移防护算法必须指定target_style参数"}`
- `400 {"error": "无效的风格代码: xxx。请使用 /api/task/perturbation/style-presets 查看可用风格"}`
- `400 {"error": "请上传至少一张图片"}`
- `400 {"error": "不支持的文件格式: xxx。仅支持图片格式。"}`
- `400 {"error": "非法的 flow_id 参数"}`
- `500 {"error": "Task status 'waiting' is not configured"}` / `{...}`
- `500 {"error": "创建任务失败: ..."}`
**特殊说明**
1. **算法与数据集类型限制**
- `anti_customize`防定制生成ID=7`anti_face_edit`防人脸编辑ID=8仅适用于人脸数据集`data_type_id=1`
- `style_protection`风格迁移防护ID=9仅适用于艺术作品数据集`data_type_id=2`
2. **风格选择**
- 使用 `style_protection` 算法时,**必须**通过 `target_style` 参数指定预设风格
- 可用风格代码:`van_gogh`、`kandinsky`、`picasso`、`baroque`
- 使用 `GET /api/task/perturbation/style-presets` 查看完整风格列表
3. **扰动强度**
- 不同算法的 `perturbation_intensity` 取值范围不同
- 大部分算法0.01-0.1(浮点数)
- 部分算法如SimAC系列整数值如16
##### PATCH `/api/task/perturbation/<task_id>`
**功能**:调整已有加噪任务的参数或描述。
**请求体**:可选字段 同创建接口。
**请求体**JSON格式所有字段可选
```json
{
"data_type_id": 2,
"perturbation_configs_id": 9,
"perturbation_intensity": 0.05,
"perturbation_name": "更新后的任务名",
"target_style": "picasso",
"description": "更新后的描述"
}
```
**成功响应** `200 OK`
```json
{
"message": "任务已更新",
"task": { ... 同 `serialize_task` 输出 ... }
"task": {
"task_id": 123,
"flow_id": 1734700000000,
"task_type": "perturbation",
"status": "waiting",
"perturbation": {
"data_type_id": 2,
"perturbation_configs_id": 9,
"perturbation_intensity": 0.05,
"perturbation_name": "更新后的任务名",
"target_style": "picasso"
}
}
}
```
**错误响应**
- `401 {"error": "无效的用户身份标识"}`
- `404 {"error": "任务不存在或无权限"}`
- `404 {"error": "任务配置不存在"}`
- `400 {"error": "数据集类型不存在"}`
- `400 {"error": "加噪配置不存在"}`
- `400 {"error": "无效的风格代码: xxx"}`
- `500 {"error": "更新任务失败: ..."}`(数据库提交失败或参数类型转换异常)
**说明**
- 仅可更新未执行或执行失败的任务
- 更新 `target_style` 时会自动验证风格代码有效性
##### POST `/api/task/perturbation/<task_id>/start`
**功能**:向异步队列提交该加噪任务,并将任务状态重置为 `waiting`
**成功响应** `200 OK`
@ -406,15 +558,43 @@
**错误响应**
- `401 {"error": "无效的用户身份标识"}`
- `500 {"error": "Task type 'perturbation' is not configured"}`(数据库缺少任务类型配置时触发)
##### GET `/api/task/perturbation/<task_id>`
返回单个任务结构。
##### GET `/api/task/perturbation/<task_id>`
**功能**:查看指定加噪任务的完整信息。
**成功响应** `200 OK`
```json
{
"task": {
"task_id": 123,
"flow_id": 1734700000000,
"task_type": "perturbation",
"status": "completed",
"user_id": 1,
"description": "艺术作品保护",
"created_at": "2025-12-20T10:00:00",
"started_at": "2025-12-20T10:01:00",
"finished_at": "2025-12-20T10:15:00",
"error_message": null,
"perturbation": {
"data_type_id": 2,
"perturbation_configs_id": 9,
"perturbation_intensity": 0.04,
"perturbation_name": "梵高风格保护",
"target_style": "van_gogh"
}
}
}
```
**错误响应**
- `401 {"error": "无效的用户身份标识"}`
- `404 {"error": "任务不存在或无权限"}`
**说明**
- `target_style` 字段仅在使用 `style_protection` 算法时有值
- 其他算法该字段为 `null`
#### 热力图任务相关
##### POST `/api/task/heatmap`
@ -1700,6 +1880,14 @@ Authorization: Bearer <token>
## 文档更新记录
### 2025-12-20 风格迁移防护功能更新
- [GET /api/task/perturbation/configs](#get-apitaskperturbationconfigs)更新算法配置列表新增9种算法的完整信息及适用范围说明。
- [GET /api/task/perturbation/style-presets](#get-apitaskperturbationstyle-presets)**新增接口**用于获取风格迁移防护算法的4种预设风格梵高/康定斯基/毕加索/巴洛克)。
- [POST /api/task/perturbation](#post-apitaskperturbation):新增 `target_style` 参数(风格迁移防护算法必填),完善请求示例和错误处理,新增算法与数据集类型限制说明。
- [PATCH /api/task/perturbation/<task_id>](#patch-apitaskperturbationtask_id):新增 `target_style` 参数支持,完善请求响应示例。
- [GET /api/task/perturbation/<task_id>](#get-apitaskperturbationtask_id):更新响应示例,包含 `target_style` 字段说明。
### 历史更新
- [POST /api/task/finetune/from-perturbation](#post-apitaskfinetunefrom-perturbation):新增 `custom_prompt` 参数。
- [POST /api/task/finetune/from-upload](#post-apitaskfinetunefrom-upload):新增 `custom_prompt` 参数。
- [GET /api/task/finetune/<task_id>/coords](#get-apitaskfinetunetask_idcoords)完善3D可视化坐标数据接口文档新增详细的请求响应格式说明和错误处理。

@ -1,132 +0,0 @@
import pandas as pd
from pathlib import Path
import sys
import logging
import numpy as np
import statsmodels.api as sm
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def apply_lowess_and_clipping_scaling(input_csv_path, output_csv_path, lowess_frac, target_range, clipping_percentile):
"""
应用 Lowess 局部加权回归进行平滑提取总体趋势然后使用百分位数裁剪后的 Min-Max 边界来缩放
目标生成最平滑最接近单调下降的客观趋势
"""
input_path = Path(input_csv_path)
output_path = Path(output_csv_path)
if not input_path.exists():
logging.error(f"错误:未找到输入文件 {input_csv_path}")
return
logging.info(f"读取原始数据: {input_csv_path}")
df = pd.read_csv(input_path)
df = df.loc[:,~df.columns.duplicated()].copy()
# 定义原始数据列名
raw_x_col = 'X_Feature_L2_Norm'
raw_y_col = 'Y_Feature_Variance'
raw_z_col = 'Z_LDM_Loss'
# --------------------------- 1. Lowess 局部加权回归平滑 (提取总体趋势) ---------------------------
logging.info(f"应用 Lowess 局部加权回归,平滑因子 frac={lowess_frac}")
x_coords = df['step'].values
for raw_col in [raw_x_col, raw_y_col, raw_z_col]:
y_coords = df[raw_col].values
smoothed_data = sm.nonparametric.lowess(
endog=y_coords,
exog=x_coords,
frac=lowess_frac,
it=0
)
df[f'{raw_col}_LOWESS'] = smoothed_data[:, 1]
# --------------------------- 2. 百分位数边界缩放与方向统一 ---------------------------
p = clipping_percentile
logging.info(f"应用百分位数边界 (p={p}) 进行线性缩放,目标范围 [0, {target_range:.2f}]")
scale_cols_map = {
'X_Feature_L2_Norm': f'{raw_x_col}_LOWESS',
'Y_Feature_Variance': f'{raw_y_col}_LOWESS',
'Z_LDM_Loss': f'{raw_z_col}_LOWESS'
}
for final_col, lowess_col in scale_cols_map.items():
data = df[lowess_col]
# 裁剪:计算裁剪后的 min/max (定义缩放窗口)
lower_bound = data.quantile(p)
upper_bound = data.quantile(1.0 - p)
min_val = lower_bound
max_val = upper_bound
data_range = max_val - min_val
if data_range <= 0 or data_range == np.nan:
df[final_col] = 0.0
logging.warning(f"{final_col} 裁剪后的范围为 {data_range:.4f},跳过缩放。")
continue
# 归一化: (data - Min_window) / Range_window
normalized_data = (data - min_val) / data_range
# **优化方向统一逻辑 (所有指标都应是越小越好):**
if final_col in ['X_Feature_L2_Norm', 'Y_Feature_Variance']:
# X/Y 反转:将 Max 映射到 0Min 映射到 TargetRange
final_scaled_data = (1.0 - normalized_data) * target_range
else: # Z_LDM_Loss
# Z 标准缩放Min 映射到 0Max 映射到 TargetRange
final_scaled_data = normalized_data * target_range
# 保留负值,以确保平滑过渡
df[final_col] = final_scaled_data
logging.info(f" - 列 {final_col}:裁剪边界: [{min_val:.4f}, {max_val:.4f}]。缩放后范围不再严格约束 [0, {target_range:.2f}],以保留趋势。")
# --------------------------- 3. 最终保存 ---------------------------
output_path.parent.mkdir(parents=True, exist_ok=True)
final_cols = ['step', 'X_Feature_L2_Norm', 'Y_Feature_Variance', 'Z_LDM_Loss']
df[final_cols].to_csv(
output_path,
index=False,
float_format='%.3f'
)
logging.info(f"Lowess平滑和缩放后的数据已保存到: {output_csv_path}")
if __name__ == '__main__':
if len(sys.argv) != 6:
logging.error("使用方法: python smooth_coords.py <输入CSV路径> <输出CSV路径> <Lowess 平滑因子 frac (例如 0.4)> <目标视觉范围 (例如 30)> <离散点裁剪百分比 (例如 0.15)>")
else:
input_csv = sys.argv[1]
output_csv = sys.argv[2]
try:
lowess_frac = float(sys.argv[3])
target_range = float(sys.argv[4])
clipping_p = float(sys.argv[5])
if not (0.0 < lowess_frac <= 1.0):
raise ValueError("Lowess 平滑因子 frac 必须在 (0.0, 1.0] 之间。")
if target_range <= 0:
raise ValueError("目标视觉范围必须大于 0。")
if not (0 <= clipping_p < 0.5):
raise ValueError("裁剪百分比必须在 [0, 0.5) 之间。")
if not Path(output_csv).suffix:
output_csv = str(Path(output_csv) / "scaled_coords.csv")
apply_lowess_and_clipping_scaling(input_csv, output_csv, lowess_frac, target_range, clipping_p)
except ValueError as e:
logging.error(f"参数错误: {e}")

@ -1,149 +0,0 @@
"""
图片处理功能用于把原始图片剪裁为中心正方形指定分辨率并保存为指定格式还可以选择是否序列化改名
"""
import argparse
import os
from pathlib import Path
from PIL import Image
# --- 1. 参数解析 ---
def parse_args(input_args=None):
"""
解析命令行参数
"""
parser = argparse.ArgumentParser(description="Image Processor for Centering, Resizing, and Format Conversion.")
# 路径和分辨率参数
parser.add_argument(
"--input_dir",
type=str,
required=True,
help="A folder containing the original images to be processed and overwritten.",
)
parser.add_argument(
"--resolution",
type=int,
default=512,
help="The target resolution (width and height) for the output images (e.g., 512 for 512x512).",
)
# 格式参数
parser.add_argument(
"--target_format",
type=str,
default="png",
choices=["jpeg", "png", "webp", "jpg"],
help="The target format for the saved images (e.g., 'png', 'jpg', 'webp'). The original file will be overwritten, potentially changing the file extension.",
)
# 序列化数字重命名参数
parser.add_argument(
"--rename_sequential",
action="store_true", # 当这个参数存在时,其值为 True
help="If set, images will be sequentially renamed (e.g., 001.jpg, 002.jpg...) instead of preserving the original filename. WARNING: This WILL delete the originals.",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
return args
# --- 2. 核心图像处理逻辑 ---
def process_image(image_path: Path, output_path: Path, resolution: int, target_format: str, delete_original: bool):
"""
加载图像居中取最大正方形升降分辨率并保存为目标格式
Args:
image_path: 原始图片路径
output_path: 最终保存路径
resolution: 目标分辨率
target_format: 目标文件格式
delete_original: 是否删除原始文件
"""
try:
# 加载图像并统一转换为 RGB 模式
img = Image.open(image_path).convert("RGB")
# 居中取最大正方形
width, height = img.size
min_dim = min(width, height)
# 计算裁剪框 (以最短边为尺寸的中心正方形)
left = (width - min_dim) // 2
top = (height - min_dim) // 2
right = left + min_dim
bottom = top + min_dim
# 裁剪中心正方形
img = img.crop((left, top, right, bottom))
# 升降分辨率到指定 resolution
# 使用 LANCZOS 高质量重采样方法
img = img.resize((resolution, resolution), resample=Image.Resampling.LANCZOS)
# 准备输出格式
save_format = target_format.upper().replace('JPEG', 'JPG')
# 保存图片
# 对于 JPEG/JPG设置 quality 参数
if save_format == 'JPG':
img.save(output_path, format='JPEG', quality=95)
else:
img.save(output_path, format=save_format)
# 根据标记决定是否删除原始文件
if delete_original and image_path.resolve() != output_path.resolve():
os.remove(image_path)
print(f"Processed: {image_path.name} -> {output_path.name} ({resolution}x{resolution} {save_format})")
except Exception as e:
print(f"Error processing {image_path.name}: {e}")
# --- 3. 主函数 ---
def main(args):
# 路径准备
input_dir = Path(args.input_dir)
if not input_dir.is_dir():
print(f"Error: Input directory not found at {input_dir}")
return
# 查找所有图片文件 (支持 jpg, jpeg, png, webp)
valid_suffixes = ['.jpg', '.jpeg', '.png', '.webp']
image_paths = sorted([p for p in input_dir.iterdir() if p.suffix.lower() in valid_suffixes]) # 排序以确保重命名顺序一致
if not image_paths:
print(f"No image files found in {input_dir}")
return
print(f"Found {len(image_paths)} images in {input_dir}. Starting processing...")
# 准备目标格式的扩展名
extension = args.target_format.lower().replace('jpeg', 'jpg')
# 迭代处理图片
for i, image_path in enumerate(image_paths):
# 决定输出路径
if args.rename_sequential:
# 顺序重命名逻辑001, 002, 003... (至少三位数字)
new_name = f"{i + 1:03d}.{extension}"
output_path = input_dir / new_name
# 如果原始文件与新文件名称不同,则需要删除原始文件
delete_original = True
else:
# 保持原始文件名,但修改后缀
output_path = image_path.with_suffix(f'.{extension}')
# 只有当原始后缀与目标后缀不同时,才需要删除原始文件(防止遗留旧格式)
delete_original = (image_path.suffix.lower() != f'.{extension}')
process_image(image_path, output_path, args.resolution, args.target_format, delete_original)
print("Processing complete.")
if __name__ == "__main__":
args = parse_args()
main(args)

@ -143,6 +143,24 @@ def list_perturbation_configs(current_user_id):
]}), 200
@task_bp.route('/perturbation/style-presets', methods=['GET'])
@int_jwt_required
def get_style_presets(current_user_id):
"""获取风格迁移防护的预设风格列表"""
presets = AlgorithmConfig.get_style_protection_presets()
return jsonify({
'presets': [
{
'style_code': code,
'name': info['name'],
'prompt': info['prompt'],
'description': info['description']
}
for code, info in presets.items()
]
}), 200
@task_bp.route('/perturbation', methods=['POST'])
@int_jwt_required
def create_perturbation_task(current_user_id):
@ -157,6 +175,7 @@ def create_perturbation_task(current_user_id):
perturbation_configs_id = data.get('perturbation_configs_id', type=int) if hasattr(data, 'get') else int(data.get('perturbation_configs_id', 0))
intensity = data.get('perturbation_intensity', type=float) if hasattr(data, 'get') else float(data.get('perturbation_intensity', 0))
description = data.get('description')
target_style = data.get('target_style') # 可选参数仅用于style_protection算法
if not all([data_type_id, perturbation_configs_id, intensity]):
return TaskService.json_error('缺少必要的任务参数')
@ -171,8 +190,19 @@ def create_perturbation_task(current_user_id):
role_code = user.role.role_code if user.role else 'user'
if role_code in ('user', 'normal') and data_type.data_type_code != 'face':
return TaskService.json_error('普通用户仅可使用人脸数据集', 403)
if not PerturbationConfig.query.get(perturbation_configs_id):
pert_config = PerturbationConfig.query.get(perturbation_configs_id)
if not pert_config:
return TaskService.json_error('加噪配置不存在')
# 如果是风格迁移防护算法验证target_style参数
if pert_config.perturbation_code == 'style_protection':
if not target_style:
return TaskService.json_error('风格迁移防护算法必须指定target_style参数')
# 验证风格代码是否有效
style_prompt = AlgorithmConfig.get_style_prompt(target_style)
if not style_prompt:
return TaskService.json_error(f'无效的风格代码: {target_style}。请使用 /api/task/perturbation/style-presets 查看可用风格')
# 验证上传的图片
files = request.files.getlist('files') if hasattr(request, 'files') else []
@ -215,7 +245,8 @@ def create_perturbation_task(current_user_id):
data_type_id=data_type_id,
perturbation_configs_id=perturbation_configs_id,
perturbation_intensity=float(intensity),
perturbation_name=data.get('perturbation_name')
perturbation_name=data.get('perturbation_name'),
target_style=target_style # 保存用户选择的风格
)
db.session.add(perturbation)
db.session.commit()
@ -260,6 +291,14 @@ def update_perturbation_task(task_id, current_user_id):
pert.perturbation_intensity = float(data['perturbation_intensity'])
if 'perturbation_name' in data:
pert.perturbation_name = data['perturbation_name']
if 'target_style' in data:
# 如果更新target_style验证风格代码有效性
target_style = data['target_style']
if target_style:
style_prompt = AlgorithmConfig.get_style_prompt(target_style)
if not style_prompt:
return TaskService.json_error(f'无效的风格代码: {target_style}')
pert.target_style = target_style
if 'description' in data:
task.description = data['description']

@ -229,6 +229,7 @@ class Perturbation(db.Model):
perturbation_name = db.Column(String(100), comment='加噪任务自定义名称')
perturbation_configs_id = db.Column(Integer, ForeignKey('perturbation_configs.perturbation_configs_id'), nullable=False, comment='使用的算法')
perturbation_intensity = db.Column(Float, nullable=False, comment='扰动强度')
target_style = db.Column(String(100), comment='风格迁移防护的目标风格(仅用于style_protection算法)')
# 关系
task = db.relationship('Task', back_populates='perturbation')

@ -221,6 +221,7 @@ class TaskService:
'perturbation_configs_id': task.perturbation.perturbation_configs_id,
'perturbation_intensity': float(task.perturbation.perturbation_intensity),
'perturbation_name': task.perturbation.perturbation_name,
'target_style': task.perturbation.target_style,
}
elif task_type == 'finetune' and task.finetune:
try:

@ -184,14 +184,23 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
if 'class_prompt' in default_params:
default_params['class_prompt'] = class_prompt
# 如果是风格迁移防护算法,使用用户选择的风格
if algorithm_code == 'style_protection' and perturbation.target_style:
style_prompt = AlgorithmConfig.get_style_prompt(perturbation.target_style)
if style_prompt:
default_params['target_style'] = style_prompt
logger.info(f"Using user-selected style: {perturbation.target_style} -> '{style_prompt}'")
else:
logger.warning(f"Invalid target_style '{perturbation.target_style}', using default")
# 合并自定义参数
params = {**default_params, **(custom_params or {})}
# 根据算法构建命令参数参考sh脚本
cmd_args = []
if algorithm_code in ['aspl', 'simac']:
# ASPL和SimAC使用相同的参数结构
if algorithm_code in ['aspl', 'simac', 'anti_customize']:
# ASPL、SimAC 和防定制生成使用相同的参数结构
cmd_args.extend([
f"--instance_data_dir_for_train={input_dir}",
f"--instance_data_dir_for_adversarial={input_dir}",
@ -206,15 +215,23 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
f"--output_dir={output_dir}",
f"--eps={float(epsilon)}",
])
elif algorithm_code == 'pid':
# PID参数结构
elif algorithm_code == 'caat_pro':
# CAAT Pro参数结构带prior preservation
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--class_data_dir={class_dir}",
f"--eps={float(epsilon)}",
])
elif algorithm_code == 'glaze':
# Glaze参数结构
elif algorithm_code in ['pid', 'anti_face_edit']:
# PID 和防人脸编辑参数结构
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
f"--eps={int(epsilon)}",
])
elif algorithm_code in ['glaze', 'style_protection']:
# Glaze 和风格迁移防护参数结构
cmd_args.extend([
f"--instance_data_dir={input_dir}",
f"--output_dir={output_dir}",
@ -275,7 +292,7 @@ def _run_real_algorithm(script_path, conda_env, algorithm_code, task_id,
)
# 清理class_dir
if algorithm_code in ['aspl', 'simac']:
if algorithm_code in ['aspl', 'simac', 'anti_customize', 'caat_pro']:
logger.info(f"Cleaning class directory: {class_dir}")
if os.path.exists(class_dir):
shutil.rmtree(class_dir)

@ -32,13 +32,44 @@ class AlgorithmConfig:
# 日志目录
LOGS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs')
# 风格迁移防护预设风格列表
STYLE_PROTECTION_PRESETS = {
'van_gogh': {
'name': '梵高印象派',
'prompt': 'impressionism painting by van gogh',
'description': '模仿梵高的印象派绘画风格'
},
'kandinsky': {
'name': '康定斯基抽象派',
'prompt': 'abstract art by kandinsky',
'description': '模仿康定斯基的抽象艺术风格'
},
'picasso': {
'name': '毕加索立体派',
'prompt': 'cubist painting by picasso',
'description': '模仿毕加索的立体主义风格'
},
'baroque': {
'name': '巴洛克风格',
'prompt': 'baroque style painting',
'description': '经典巴洛克艺术风格'
}
}
# 日志目录
LOGS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs')
# Conda环境配置从环境变量读取支持自定义
CONDA_ENVS = {
'aspl': os.getenv('CONDA_ENV_ASPL', 'simac'),
'simac': os.getenv('CONDA_ENV_SIMAC', 'simac'),
'caat': os.getenv('CONDA_ENV_CAAT', 'caat'),
'caat_pro': os.getenv('CONDA_ENV_CAAT_PRO', 'caat'),
'pid': os.getenv('CONDA_ENV_PID', 'pid'),
'glaze': os.getenv('CONDA_ENV_GLAZE', 'pid'),
'anti_customize': os.getenv('CONDA_ENV_ANTI_CUSTOMIZE', 'simac'),
'anti_face_edit': os.getenv('CONDA_ENV_ANTI_FACE_EDIT', 'pid'),
'style_protection': os.getenv('CONDA_ENV_STYLE_PROTECTION', 'pid'),
'dreambooth': os.getenv('CONDA_ENV_DREAMBOOTH', 'pid'),
'lora': os.getenv('CONDA_ENV_LORA', 'pid'),
'textual_inversion': os.getenv('CONDA_ENV_TI', 'pid'),
@ -116,6 +147,26 @@ class AlgorithmConfig:
'alpha': 5e-3
}
},
'caat_pro': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'caat.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['caat_pro'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'with_prior_preservation': True,
'instance_prompt': 'a photo of a person',
'class_prompt': 'person',
'num_class_images': 200,
'resolution': 512,
'learning_rate': 1e-5,
'lr_warmup_steps': 0,
'max_train_steps': 250,
'hflip': True,
'mixed_precision': 'bf16',
'alpha': 5e-3,
'eps': 0.05
}
},
'pid': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'pid.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
@ -138,7 +189,65 @@ class AlgorithmConfig:
'center_crop': True,
'max_train_steps': 150,
'eps': 0.05,
'target_style': 'cubism painting by picasso',
'target_style': 'impressionism painting by van gogh',
'style_strength': 0.75,
'n_runs': 3,
'style_transfer_iter': 15,
'guidance_scale': 7.5,
'seed': 42
}
},
'anti_customize': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'simac.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['anti_customize'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model1'],
'enable_xformers_memory_efficient_attention': True,
'instance_prompt': 'a photo of sks person',
'class_prompt': 'a photo of person',
'num_class_images': 100,
'center_crop': True,
'with_prior_preservation': True,
'prior_loss_weight': 1.0,
'resolution': 384,
'train_batch_size': 1,
'max_train_steps': 100,
'max_f_train_steps': 3,
'max_adv_train_steps': 6,
'checkpointing_iterations': 20,
'learning_rate': 5e-7,
'pgd_alpha': 0.005,
'seed': 0
}
},
'anti_face_edit': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'pid.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['anti_face_edit'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'resolution': 512,
'max_train_steps': 2000,
'center_crop': True,
'step_size': 0.002,
'save_every': 200,
'attack_type': 'add-log',
'seed': 0,
'dataloader_num_workers': 2
}
},
'style_protection': {
'real_script': os.path.join(ALGORITHMS_DIR, 'perturbation', 'glaze.py'),
'virtual_script': os.path.join(ALGORITHMS_DIR, 'perturbation_engine.py'),
'conda_env': CONDA_ENVS['style_protection'],
'default_params': {
'pretrained_model_name_or_path': MODELS_DIR['model2'],
'resolution': 512,
'center_crop': True,
'max_train_steps': 150,
'eps': 0.04,
'target_style': 'impressionism painting by van gogh',
'style_strength': 0.75,
'n_runs': 3,
'style_transfer_iter': 15,
@ -174,6 +283,17 @@ class AlgorithmConfig:
config = cls.get_perturbation_config(algorithm_code)
return config.get('default_params', {}).copy()
@classmethod
def get_style_protection_presets(cls):
"""获取风格迁移防护的预设风格列表"""
return cls.STYLE_PROTECTION_PRESETS
@classmethod
def get_style_prompt(cls, style_code):
"""根据风格代码获取对应的提示词"""
preset = cls.STYLE_PROTECTION_PRESETS.get(style_code)
return preset['prompt'] if preset else None
# ========== 微调算法配置 ==========
FINETUNE_SCRIPTS = {
'dreambooth': {

@ -60,8 +60,12 @@ def init_database():
{'perturbation_code': 'aspl', 'perturbation_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense'},
{'perturbation_code': 'simac', 'perturbation_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy'},
{'perturbation_code': 'caat', 'perturbation_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck'},
{'perturbation_code': 'caat_pro', 'perturbation_name': 'CAAT Pro算法', 'description': 'CAAT with Prior Preservation - Enhanced version with class data preservation'},
{'perturbation_code': 'pid', 'perturbation_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models'},
{'perturbation_code': 'glaze', 'perturbation_name': 'Glaze算法', 'description': 'Protecting Artists from Style Mimicry by Text-to-Image Models'}
{'perturbation_code': 'glaze', 'perturbation_name': 'Glaze算法', 'description': 'Protecting Artists from Style Mimicry by Text-to-Image Models'},
{'perturbation_code': 'anti_customize', 'perturbation_name': '防定制生成', 'description': 'Anti-Customization Generation - 专门防止人脸定制化生成'},
{'perturbation_code': 'anti_face_edit', 'perturbation_name': '防人脸编辑', 'description': 'Anti-Face-Editing - 专门防止人脸图像被编辑'},
{'perturbation_code': 'style_protection', 'perturbation_name': '风格迁移防护', 'description': 'Style Transfer Protection - 保护艺术作品免受风格模仿'}
]
for config in perturbation_configs:
@ -95,7 +99,7 @@ def init_database():
},
{
'data_type_code': 'art',
'instance_prompt': 'a painting in sks style',
'instance_prompt': 'a painting in <sks-style> style',
'class_prompt': 'a painting',
'placeholder_token': '<sks-style>',
'initializer_token': 'painting',

Loading…
Cancel
Save