|
|
|
|
@ -1,171 +0,0 @@
|
|
|
|
|
"""
|
|
|
|
|
图像合成器模块 - 基础版本和传统方法。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from typing import Tuple, Dict, Any
|
|
|
|
|
from scipy.optimize import minimize
|
|
|
|
|
import warnings
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
|
|
|
|
# Install a global filter that ignores every UserWarning once this module is
# imported.
# NOTE(review): the filter is not scoped to this module, so UserWarnings from
# unrelated code are hidden as well - confirm that this is intended.
warnings.filterwarnings("ignore", category=UserWarning)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseImageSynthesizer(ABC):
    """Abstract interface shared by all image synthesizers.

    A concrete synthesizer has to implement ``set_content``,
    ``set_style_features`` and ``synthesize``; the class cannot be
    instantiated until all three are overridden.
    """

    @abstractmethod
    def set_content(self, image: np.ndarray):
        """Set the content image.

        Args:
            image: The content image.
        """

    @abstractmethod
    def set_style_features(self, features: Dict[str, Any]):
        """Set the style features.

        Args:
            features: Dictionary of style features.
        """

    @abstractmethod
    def synthesize(self, initial_image: np.ndarray = None, **kwargs) -> np.ndarray:
        """Synthesize an image.

        Args:
            initial_image: Optional starting image; defaults to ``None``.
            **kwargs: Implementation-specific options.

        Returns:
            The synthesized image.
        """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageSynthesizer(BaseImageSynthesizer):
    """Traditional, optimisation-based image synthesizer.

    Produces the final stylized image by minimising, with L-BFGS-B, the sum
    of a content loss (mean squared difference to the content image) and a
    weighted hand-crafted style loss (see ``_calculate_style_loss``).
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialise the synthesizer.

        Args:
            config: Configuration dictionary. This class only reads the
                optional ``'style_weight'`` entry (default ``1.0``).
        """
        self.config = config
        # Content image divided by 255 (see ``set_content``); None until set.
        self.content_image: np.ndarray = None
        # Target style features; empty until ``set_style_features`` is called.
        self.style_features: Dict[str, Any] = {}
        # Result of the most recent ``synthesize`` call; None before that.
        self.synthesized_image: np.ndarray = None

    def set_content(self, image: np.ndarray):
        """Set the content image.

        The image is stored as float32 divided by 255.
        NOTE(review): this presumably expects 8-bit values in [0, 255] so
        that the stored image lies in [0, 1] - confirm against callers.

        Args:
            image: Content image of shape (H, W, C).
        """
        self.content_image = image.astype(np.float32) / 255.0

    def set_style_features(self, features: Dict[str, Any]):
        """Set the target style features.

        Args:
            features: Feature dictionary (produced by a
                ``StyleFeatureExtractor`` according to the original
                documentation). The style loss reads the keys
                ``'color_histogram'``, ``'texture_strength'`` and
                ``'mean_color'``.
        """
        self.style_features = features

    def synthesize(self, initial_image: np.ndarray = None, max_iter: int = 200,
                   **kwargs) -> np.ndarray:
        """Synthesize the stylized image.

        Args:
            initial_image: Starting point of the optimisation. When ``None``
                the content image is used. Must have the same shape as the
                content image.
                NOTE(review): presumably expected in the same [0, 1] scale
                as the stored content image - confirm against callers.
            max_iter: Maximum number of L-BFGS-B iterations.
            **kwargs: Accepted for compatibility with
                ``BaseImageSynthesizer.synthesize`` and ignored here.

        Returns:
            The stylized image as a uint8 array of shape (H, W, C) with
            values in [0, 255]. It is also stored in
            ``self.synthesized_image``.

        Raises:
            ValueError: If no content image has been set, or if
                ``initial_image`` does not match the content image's shape.
        """
        content = self.content_image
        if content is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError("尚未设置内容图像,请先调用 set_content()")

        start = content if initial_image is None else np.asarray(initial_image)
        if start.shape != content.shape:
            raise ValueError(
                f"初始图像形状 {start.shape} 与内容图像形状 {content.shape} 不一致"
            )

        h, w, c = content.shape
        # The gradient is estimated by finite differences, which are
        # unreliable at float32 precision, so optimise a float64 vector.
        x0 = np.asarray(start, dtype=np.float64).ravel()

        # Loop-invariant: read the weight once, not on every loss evaluation.
        style_weight = self.config.get('style_weight', 1.0)

        def loss_function(x):
            """Total loss of the flattened candidate image ``x``."""
            img = x.reshape((h, w, c))
            # 1. Content loss: stay close to the content image.
            content_loss = np.mean((img - content) ** 2)
            # 2. Style loss: match the target style features.
            style_loss = self._calculate_style_loss(img)
            return content_loss + style_weight * style_loss

        # ``disp`` is not passed: L-BFGS-B prints nothing by default.
        result = minimize(loss_function, x0, method='L-BFGS-B',
                          options={'maxiter': max_iter})

        # Back to image shape and to [0, 255]. Round to nearest before the
        # integer cast so float round-off cannot turn e.g. 199.99999 into 199.
        scaled = np.clip(result.x.reshape((h, w, c)), 0, 1) * 255
        self.synthesized_image = np.rint(scaled).astype(np.uint8)

        return self.synthesized_image

    def _calculate_style_loss(self, image: np.ndarray) -> float:
        """Compute the loss between ``image`` and the target style features.

        The loss is the sum of three terms: a KL-divergence term on colour
        histograms, a mean-squared term on texture strength and a
        mean-squared term on the mean colour.

        Args:
            image: Current candidate image, indexed as (row, column, channel)
                with at least three channels.

        Returns:
            The style loss value.
        """
        # 1. Colour histogram term: 32-bin histograms over [0, 1] for the
        #    first three channels, concatenated (values outside the range
        #    are not counted).
        # NOTE(review): these are raw counts; confirm the target
        # 'color_histogram' uses the same 3 x 32 layout and normalisation.
        current_hist = np.concatenate([
            np.histogram(image[:, :, channel].ravel(), bins=32, range=(0, 1))[0]
            for channel in range(3)
        ])
        target_hist = self.style_features['color_histogram']
        # KL divergence as the distance; epsilon avoids log(0) and division
        # by zero for empty bins.
        epsilon = 1e-10
        kl_div = np.sum(
            target_hist * np.log((target_hist + epsilon) / (current_hist + epsilon))
        )

        # 2. Texture strength term: mean absolute difference between
        #    neighbouring pixels along axis 1 and along axis 0.
        diff_h = np.abs(np.diff(image, axis=1)).mean()
        diff_v = np.abs(np.diff(image, axis=0)).mean()
        current_texture = np.array([diff_h, diff_v])
        target_texture = self.style_features['texture_strength']
        texture_loss = np.mean((current_texture - target_texture) ** 2)

        # 3. Mean colour term: per-channel mean over all pixels.
        current_mean = image.mean(axis=(0, 1))
        target_mean = self.style_features['mean_color']
        mean_loss = np.mean((current_mean - target_mean) ** 2)

        return 0.0 + kl_div + texture_loss + mean_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 工厂函数
|
|
|
|
|
def create_image_synthesizer(synthesizer_type: str, config: Dict[str, Any]) -> BaseImageSynthesizer:
    """Factory that builds the requested image synthesizer.

    Args:
        synthesizer_type: Kind of synthesizer, ``'traditional'`` or
            ``'neural'``.
        config: Configuration dictionary handed to the synthesizer's
            constructor.

    Returns:
        The image synthesizer instance.

    Raises:
        ValueError: If ``synthesizer_type`` is neither ``'traditional'``
            nor ``'neural'``.
    """
    if synthesizer_type == 'neural':
        # Imported only when the neural synthesizer is actually requested.
        from .neural_image_synthesizer import NeuralImageSynthesizer
        return NeuralImageSynthesizer(config)

    if synthesizer_type != 'traditional':
        raise ValueError(f"不支持的合成器类型: {synthesizer_type}")

    return ImageSynthesizer(config)
|