"""Stable Diffusion 注意力热力图差异可视化工具 (可靠版 - 语义阶段聚合)。
|
|
|
|
|
|
|
|
|
|
本模块使用一种健壮的方法,通过在 Stable Diffusion 扩散模型(U-Net)的
|
|
|
|
|
**早期时间步 (语义阶段)** 捕获并累加交叉注意力权重。这种方法能确保捕获到的
|
|
|
|
|
注意力图信号集中且可靠,用于对比分析干净输入和扰动输入生成的图像对模型
|
|
|
|
|
注意力机制的影响差异。
|
|
|
|
|
|
|
|
|
|
典型用法:
|
|
|
|
|
python eva_gen_heatmap.py \\
|
|
|
|
|
--model_path /path/to/sd_model \\
|
|
|
|
|
--image_path_a /path/to/clean_image.png \\
|
|
|
|
|
--image_path_b /path/to/noisy_image.png \\
|
|
|
|
|
--prompt_text "a photo of sks person" \\
|
|
|
|
|
--target_word "sks" \\
|
|
|
|
|
--output_dir output/heatmap_reports
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# Argument parsing and file path handling
import argparse
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Numerical computing and deep learning dependencies
import torch
import torch.nn.functional as F
import numpy as np

# Visualization dependencies
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Diffusers and Transformers dependencies
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import Attention
from transformers import CLIPTokenizer

# Image processing and data loading
from PIL import Image
from torchvision import transforms

# Suppress non-essential warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ============== Core module: attention capture and aggregation ==============


class AttentionMapProcessor:
    """Custom attention processor that captures cross-attention weights in the U-Net.

    By replacing the processors of the original `Attention` modules, this class
    captures and stores the attention weights (`attention_probs`) of every
    cross-attention layer during the model's forward pass.

    Attributes:
        attention_maps (Dict[str, List[torch.Tensor]]): Captured attention maps,
            keyed by layer name; each value is the list of maps captured for
            that layer across timesteps.
        pipeline (StableDiffusionPipeline): The Stable Diffusion pipeline being hooked.
        original_processors (Dict[str, Any]): The original attention processors,
            kept so they can be restored later.
        current_layer_name (Optional[str]): Name of the attention layer currently
            being processed.
    """

    def __init__(self, pipeline: StableDiffusionPipeline):
        """Initialize the attention processor.

        Args:
            pipeline: A Stable Diffusion pipeline instance.
        """
        self.attention_maps: Dict[str, List[torch.Tensor]] = {}
        self.pipeline = pipeline
        self.original_processors: Dict[str, Any] = {}
        self.current_layer_name: Optional[str] = None
        self._set_processors()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Run the attention computation while capturing the weights.

        This method stands in for the original `Attention.processor` and
        captures the weights whenever cross-attention is computed.

        Args:
            attn: The current `Attention` module instance.
            hidden_states: U-Net hidden states (query).
            encoder_hidden_states: Text encoder output (key/value), i.e. the
                cross-attention input.
            attention_mask: Attention mask.

        Returns:
            The computed output hidden states.
        """
        # For self-attention (encoder_hidden_states is None), fall back to the
        # *original* processor. Calling `attn.processor` here would invoke our
        # own hook again and recurse forever.
        if encoder_hidden_states is None:
            original = self.original_processors[self.current_layer_name]
            return original(
                attn, hidden_states, encoder_hidden_states, attention_mask
            )

        # 1. Compute Q, K, V
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # 2. Reshape for batched matrix multiplication
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)

        # 3. Compute attention scores (Q @ K^T)
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0], query.shape[1], key.shape[1],
                dtype=query.dtype, device=query.device
            ),
            query,
            key.transpose(1, 2),
            beta=0,
            alpha=attn.scale,
        )

        # 4. Compute attention probabilities
        attention_probs = attention_scores.softmax(dim=-1)
        layer_name = self.current_layer_name

        # 5. Store the captured attention map
        if layer_name not in self.attention_maps:
            self.attention_maps[layer_name] = []

        # Store this timestep's attention weights
        self.attention_maps[layer_name].append(attention_probs.detach().cpu())

        # 6. Compute the output (attention @ V)
        value = attn.head_to_batch_dim(value)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # 7. Output projection (linear layer followed by dropout)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

    def _set_processors(self):
        """Register the custom processor on every cross-attention layer of the U-Net.

        Walks all submodules of the U-Net, finds the cross-attention layers
        (`Attention` modules whose name contains `attn2`), and replaces their
        processors with this instance.
        """
        for name, module in self.pipeline.unet.named_modules():
            if isinstance(module, Attention) and 'attn2' in name:
                # Keep the original processor so it can be restored later
                self.original_processors[name] = module.processor

                # Closure that records the layer name before dispatching to __call__
                def set_layer_name(current_name):
                    def new_call(*args, **kwargs):
                        self.current_layer_name = current_name
                        return self.__call__(*args, **kwargs)
                    return new_call

                module.processor = set_layer_name(name)

    def remove(self):
        """Restore the U-Net's original attention processors and clear the hooks."""
        for name, original_processor in self.original_processors.items():
            module = self.pipeline.unet.get_submodule(name)
            module.processor = original_processor
        self.attention_maps = {}
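
# Minimal usage sketch for the class above (a hypothetical helper, not called
# anywhere in this script). It assumes `pipe`, `latents`, `timestep`, and
# `prompt_embeds` are prepared the same way get_attention_map_from_image()
# below prepares them.
def _processor_lifecycle_example(pipe, latents, timestep, prompt_embeds):
    """Illustrate the capture lifecycle: hook, forward pass, read, restore."""
    processor = AttentionMapProcessor(pipe)  # installs the hooks
    try:
        with torch.no_grad():
            pipe.unet(latents, timestep, prompt_embeds, return_dict=False)
        # One captured map per hooked layer and forward pass
        return {name: len(maps) for name, maps in processor.attention_maps.items()}
    finally:
        processor.remove()  # always restore the original processors
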
def aggregate_word_attention(
    attention_maps: Dict[str, List[torch.Tensor]],
    tokenizer: CLIPTokenizer,
    target_word: str,
    input_ids: torch.Tensor
) -> np.ndarray:
    """Aggregate the target word's attention across all layers and semantic timesteps, then normalize.

    Aggregation steps:
        1. Identify the token indices corresponding to the target word.
        2. For each layer: average the attention maps over all captured timesteps.
        3. Extract the sub-map for the target tokens, sum over the token
           dimension, and average over the attention heads.
        4. Upsample maps of different resolutions to a common size (64x64).
        5. Sum the results over all layers.
        6. Normalize the final map to [0, 1].

    Args:
        attention_maps: Dictionary of attention maps captured per layer and timestep.
        tokenizer: A CLIP tokenizer instance.
        target_word: The keyword to focus on.
        input_ids: Token ID tensor for the prompt.

    Returns:
        The final aggregated attention heatmap, upsampled to 64x64 (NumPy array).

    Raises:
        ValueError: If the target word cannot be found in the prompt.
        RuntimeError: If no attention data was captured.
    """

    # 1. Identify the token indices of the target word
    prompt_tokens = tokenizer.convert_ids_to_tokens(
        input_ids.squeeze().cpu().tolist()
    )
    target_lower = target_word.lower()
    target_indices = []

    for i, token in enumerate(prompt_tokens):
        # CLIP's BPE marks word endings with a '</w>' suffix (e.g. "sks"
        # tokenizes to "sks</w>"), so strip it before matching.
        cleaned_token = token.replace('</w>', '').lower()
        # Match tokens containing the target word, skipping special tokens
        if (input_ids.squeeze()[i].item() not in tokenizer.all_special_ids and
                target_lower in cleaned_token):
            target_indices.append(i)

    if not target_indices:
        print(f"[WARN] Target word '{target_word}' not found. Check the prompt or target word.")
        raise ValueError("Could not identify token indices for the target word.")

    # 2. Aggregation
    all_attention_data = []
    # Total pixel count at the U-Net's highest attention resolution (64x64)
    TARGET_SPATIAL_SIZE = 4096
    TARGET_MAP_SIZE = 64

    for layer_name, step_maps in attention_maps.items():
        if not step_maps:
            continue

        # Average over all captured timesteps for this layer.
        # Shape: (batch * heads, spatial_res, text_tokens)
        avg_map_over_time = torch.stack(step_maps).mean(dim=0)

        # With batch size 1 the leading dimension is just the heads, so this
        # squeeze is a no-op kept for safety. Shape: (heads, spatial_res, text_tokens)
        attention_map = avg_map_over_time.squeeze(0)

        # Extract the target tokens' maps. Shape: (heads, spatial_res, len(target_indices))
        target_token_maps = attention_map[:, :, target_indices]

        # Sum over target tokens (dim=-1), average over heads (dim=0). Shape: (spatial_res,)
        aggregated_map_flat = target_token_maps.sum(dim=-1).mean(dim=0).float()

        # 3. Cross-resolution upsampling
        if aggregated_map_flat.shape[0] != TARGET_SPATIAL_SIZE:
            # Current map size: 16x16 (256) or 32x32 (1024)
            map_size = int(np.sqrt(aggregated_map_flat.shape[0]))
            map_2d = aggregated_map_flat.reshape(map_size, map_size)
            map_to_interp = map_2d.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]

            # Bilinear upsampling to 64x64
            resized_map_2d = F.interpolate(
                map_to_interp,
                size=(TARGET_MAP_SIZE, TARGET_MAP_SIZE),
                mode='bilinear',
                align_corners=False
            )
            resized_map_flat = resized_map_2d.squeeze().flatten()
            all_attention_data.append(resized_map_flat)
        else:
            # Already 64x64; use as-is
            all_attention_data.append(aggregated_map_flat)

    if not all_attention_data:
        raise RuntimeError("No attention data captured. Check the model and parameters.")

    # 4. Sum the per-layer results
    final_map_flat = torch.stack(all_attention_data).sum(dim=0).cpu().numpy()

    # 5. Normalize to [0, 1]
    final_map_flat = final_map_flat / (final_map_flat.max() + 1e-6)

    map_size = int(np.sqrt(final_map_flat.shape[0]))
    final_map_np = final_map_flat.reshape(map_size, map_size)  # 64x64

    return final_map_np
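
# Self-contained sanity sketch (hypothetical, not called by the pipeline) for
# the cross-resolution step above: a synthetic 16x16 map (256 values) is
# upsampled to 64x64 (4096 values) exactly as aggregate_word_attention() does.
def _upsampling_example() -> torch.Tensor:
    """Show the reshape -> interpolate -> flatten path on dummy data."""
    flat = torch.rand(256)                        # e.g. one 16x16 attention map
    side = int(np.sqrt(flat.shape[0]))            # -> 16
    grid = flat.reshape(side, side)[None, None]   # [1, 1, 16, 16]
    up = F.interpolate(grid, size=(64, 64), mode='bilinear', align_corners=False)
    resized = up.squeeze().flatten()              # -> 4096 values
    assert resized.shape[0] == 64 * 64
    return resized
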
def get_attention_map_from_image(
    pipeline: StableDiffusionPipeline,
    image_path: str,
    prompt_text: str,
    target_word: str
) -> Tuple[Image.Image, np.ndarray]:
    """Run multi-timestep forward passes and capture the attention map for an image/prompt pair.

    Running only the semantic phase of the diffusion process (the early
    timesteps) keeps the captured attention weights high in signal quality.

    Args:
        pipeline: A Stable Diffusion pipeline instance.
        image_path: Path to the input image to process.
        prompt_text: Prompt text used for generation.
        target_word: The keyword to focus on and visualize.

    Returns:
        A tuple of (original image, final upsampled attention map).
    """
    print(f"\n-> Processing image: {Path(image_path).name}")
    image = Image.open(image_path).convert("RGB").resize((512, 512))
    image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    image_tensor = (
        image_transform(image)
        .unsqueeze(0)
        .to(pipeline.device)
        .to(pipeline.unet.dtype)
    )

    # 1. Encode into latent space
    with torch.no_grad():
        latent = (
            pipeline.vae.encode(image_tensor).latent_dist.sample() *
            pipeline.vae.config.scaling_factor
        )

    # 2. Encode the prompt
    text_input = pipeline.tokenizer(
        prompt_text,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )
    input_ids = text_input.input_ids

    with torch.no_grad():
        # Get the text embeddings
        prompt_embeds = pipeline.text_encoder(
            input_ids.to(pipeline.device)
        )[0]

    # 3. Define the semantic timesteps
    scheduler = pipeline.scheduler
    # Set the number of diffusion steps (e.g. 50)
    scheduler.set_timesteps(50, device=pipeline.device)

    # Capture only the 10 earliest, semantically richest steps
    semantic_steps = scheduler.timesteps[:10]
    print(f"-> Capturing attention over {len(semantic_steps)} semantic-phase timesteps...")

    processor = AttentionMapProcessor(pipeline)

    try:
        # 4. Run the multi-step U-Net forward passes
        with torch.no_grad():
            # Run U-Net predictions at the selected semantic timesteps
            for t in semantic_steps:
                pipeline.unet(latent, t, prompt_embeds, return_dict=False)

        # 5. Aggregate the captured data
        raw_map_np = aggregate_word_attention(
            processor.attention_maps,
            pipeline.tokenizer,
            target_word,
            input_ids
        )
    except Exception as e:
        print(f"[ERROR] Attention aggregation failed: {e}")
        # Fall back to an all-zero map so the report can still be rendered
        raw_map_np = np.zeros(image.size)
    finally:
        # Always restore the original processors
        processor.remove()

    # 6. Upsample the attention map to the image size (512x512) via PIL
    heat_map_pil = Image.fromarray((raw_map_np * 255).astype(np.uint8))
    heat_map_np_resized = (
        np.array(heat_map_pil.resize(
            image.size,
            resample=Image.Resampling.LANCZOS  # high-quality Lanczos filter
        )) / 255.0
    )

    return image, heat_map_np_resized
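
# How the two returned maps are compared downstream (see main() below): the
# report's distance is simply the L2 (Frobenius) norm of the element-wise
# difference between the two normalized 512x512 maps. Paths here are
# placeholders.
#
#     _, map_clean = get_attention_map_from_image(pipe, "clean.png", prompt, word)
#     _, map_noisy = get_attention_map_from_image(pipe, "noisy.png", prompt, word)
#     l2 = np.linalg.norm(map_clean - map_noisy)
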
def main():
    """Parse arguments, load the model, compute the difference, and render the report."""
    parser = argparse.ArgumentParser(description="SD image attention difference report generator")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Local path to the Stable Diffusion model.")
    parser.add_argument("--image_path_a", type=str, required=True,
                        help="Path to the clean input image (X).")
    parser.add_argument("--image_path_b", type=str, required=True,
                        help="Path to the perturbed input image (X').")
    parser.add_argument("--prompt_text", type=str, default="a photo of sks person",
                        help="Prompt text used for generation.")
    parser.add_argument("--target_word", type=str, default="sks",
                        help="Keyword to focus on and visualize in the attention maps.")
    parser.add_argument("--output_dir", type=str, default="output",
                        help="Output directory for the report PNG.")
    args = parser.parse_args()

    print("--- Generating Stable Diffusion attention difference report ---")

    # ---------------- Prepare the model ----------------
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float16 if device == 'cuda' else torch.float32

    try:
        # Load the Stable Diffusion pipeline
        pipe = StableDiffusionPipeline.from_pretrained(
            args.model_path,
            torch_dtype=dtype,
            local_files_only=True,
            safety_checker=None,
            # Load the scheduler config from its subfolder
            scheduler=DPMSolverMultistepScheduler.from_pretrained(args.model_path, subfolder="scheduler")
        ).to(device)
    except Exception as e:
        print(f"[ERROR] Failed to load the model; check the path and dependencies: {e}")
        return

    # ---------------- Collect the data ----------------
    # Attention map M_A for the clean image A
    img_A, map_A = get_attention_map_from_image(pipe, args.image_path_a, args.prompt_text, args.target_word)
    # Attention map M_B for the perturbed image B
    img_B, map_B = get_attention_map_from_image(pipe, args.image_path_b, args.prompt_text, args.target_word)

    if map_A.shape != map_B.shape:
        print("Error: attention map shapes do not match. Aborting.")
        return

    # Difference map: Delta = M_A - M_B
    diff_map = map_A - map_B
    # L2 norm of the difference
    l2_diff = np.linalg.norm(diff_map)
    print(f"\nDone. L2 norm of the attention map difference: {l2_diff:.4f}")

    # ---------------- Render the report ----------------

    # Matplotlib font settings
    plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['DejaVu Serif', 'Times New Roman', 'serif'],
        'mathtext.fontset': 'cm'
    })

    fig = plt.figure(figsize=(12, 16), dpi=120)

    # 3x4 grid for precise control of images and legends
    gs = gridspec.GridSpec(3, 4, figure=fig,
                           height_ratios=[1, 1, 1.3],
                           hspace=0.3, wspace=0.1)

    # --- Row 1: original images ---
    ax_img_a = fig.add_subplot(gs[0, 0:2])
    ax_img_b = fig.add_subplot(gs[0, 2:4])

    # Clean image
    ax_img_a.imshow(img_A)
    ax_img_a.set_title(f"Clean Image ($X$)\nFilename: {Path(args.image_path_a).name}", fontsize=14, pad=10)
    ax_img_a.axis('off')

    # Perturbed image
    ax_img_b.imshow(img_B)
    ax_img_b.set_title(f"Noisy Image ($X'$)\nFilename: {Path(args.image_path_b).name}", fontsize=14, pad=10)
    ax_img_b.axis('off')

    # --- Row 2: attention heatmaps (jet colormap) ---
    ax_map_a = fig.add_subplot(gs[1, 0:2])
    ax_map_b = fig.add_subplot(gs[1, 2:4])

    # Attention map A
    im_map_a = ax_map_a.imshow(map_A, cmap='jet', vmin=0, vmax=1)
    ax_map_a.set_title(f"Attention Heatmap ($M_X$)\nTarget: \"{args.target_word}\"", fontsize=14, pad=10)
    ax_map_a.axis('off')

    # Attention map B
    im_map_b = ax_map_b.imshow(map_B, cmap='jet', vmin=0, vmax=1)
    ax_map_b.set_title(f"Attention Heatmap ($M_{{X'}}$)\nTarget: \"{args.target_word}\"", fontsize=14, pad=10)
    ax_map_b.axis('off')

    # Colorbar for attention map B
    divider = make_axes_locatable(ax_map_b)
    cax_map = divider.append_axes("right", size="5%", pad=0.05)
    cbar1 = fig.colorbar(im_map_b, cax=cax_map)
    cbar1.set_label('Attention Intensity', fontsize=10)

    # --- Row 3: difference comparison (centered) ---
    # The difference map occupies the middle two grid columns
    ax_diff = fig.add_subplot(gs[2, 1:3])

    vmax_diff = np.max(np.abs(diff_map))
    # TwoSlopeNorm keeps zero at the center of the colorbar
    norm_diff = TwoSlopeNorm(vmin=-vmax_diff, vcenter=0., vmax=vmax_diff)

    # Coolwarm colormap: blue for negative differences (M_X' > M_X),
    # red for positive differences (M_X > M_X')
    im_diff = ax_diff.imshow(diff_map, cmap='coolwarm', norm=norm_diff)

    title_text = (
        r"Difference Map: $\Delta = M_X - M_{X'}$" + "\n" +
        rf"$L_2$ Norm Distance: $\mathbf{{{l2_diff:.4f}}}$"
    )
    ax_diff.set_title(title_text, fontsize=16, pad=12)
    ax_diff.axis('off')

    # Colorbar for the difference map (center-aligned)
    cbar2 = fig.colorbar(im_diff, ax=ax_diff, fraction=0.046, pad=0.04)
    cbar2.set_label(r'Scale: Red ($+$) $\leftrightarrow$ Blue ($-$)', fontsize=12)

    # ---------------- Final touches and saving ----------------
    fig.suptitle("Museguard: SD Attention Analysis Report", fontsize=20, fontweight='bold', y=0.95)

    output_filename = "heatmap_dif.png"
    output_path = Path(args.output_dir) / output_filename
    output_path.parent.mkdir(parents=True, exist_ok=True)

    plt.savefig(output_path, bbox_inches='tight', facecolor='white')
    print(f"\nAnalysis report saved to:\n{output_path.resolve()}")


if __name__ == "__main__":
    main()