You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

171 lines
5.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
图像合成器模块 - 基础版本和传统方法。
"""
import numpy as np
from typing import Tuple, Dict, Any
from scipy.optimize import minimize
import warnings
from abc import ABC, abstractmethod
# NOTE(review): this runs at import time and silences *every* UserWarning for
# the whole process, not only those raised by this module. If the intent is to
# hide warnings from the optimizer, consider scoping it with
# warnings.catch_warnings() around that call instead.
warnings.filterwarnings("ignore", category=UserWarning)
class BaseImageSynthesizer(ABC):
    """Abstract interface shared by every image synthesizer implementation.

    Concrete subclasses must override all three abstract methods below;
    the base class itself cannot be instantiated.
    """

    @abstractmethod
    def set_content(self, image: np.ndarray):
        """Store *image* as the content image to be stylized."""
        ...

    @abstractmethod
    def set_style_features(self, features: Dict[str, Any]):
        """Store the style *features* the synthesized image should match."""
        ...

    @abstractmethod
    def synthesize(self, initial_image: np.ndarray = None, **kwargs) -> np.ndarray:
        """Produce and return the stylized image.

        Args:
            initial_image: Optional starting image for the synthesis.
            **kwargs: Implementation-specific options.
        """
        ...
class ImageSynthesizer(BaseImageSynthesizer):
    """
    Traditional (optimization-based) image synthesizer that produces the final
    stylized image.

    The output is found by minimizing, with L-BFGS-B over every pixel value,
    ``content_loss + style_weight * style_loss``. The content loss is the MSE
    to the content image. The style loss combines a color-histogram KL
    divergence, a texture-strength MSE and a mean-color MSE.
    """

    # Histogram bins per color channel used by the color-histogram loss.
    _HIST_BINS = 32
    # Small constant that keeps the KL divergence finite for empty bins.
    _EPSILON = 1e-10

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the synthesizer.

        Args:
            config: Configuration dictionary. ``synthesize`` reads the key
                ``'style_weight'`` (default 1.0).
        """
        self.config = config
        self.content_image: np.ndarray = None
        self.style_features: Dict[str, Any] = {}
        self.synthesized_image: np.ndarray = None

    def set_content(self, image: np.ndarray):
        """
        Set the content image.

        Args:
            image: Content image of shape (H, W, C). It is divided by 255 and
                stored as float32, so it is presumably in the uint8 range
                [0, 255] (confirm against callers).
        """
        self.content_image = image.astype(np.float32) / 255.0

    def set_style_features(self, features: Dict[str, Any]):
        """
        Set the target style features.

        Args:
            features: Feature dict (the original docs say it comes from
                StyleFeatureExtractor). The style loss reads the keys
                ``'color_histogram'``, ``'texture_strength'`` and
                ``'mean_color'``. The latter two are compared against an
                image on the [0, 1] scale.
        """
        self.style_features = features

    def synthesize(self, initial_image: np.ndarray = None, max_iter: int = 200) -> np.ndarray:
        """
        Synthesize the stylized image.

        Args:
            initial_image: Starting point for the optimization, with the same
                shape as the content image and values on the same [0, 1]
                scale. If None, the content image is used.
            max_iter: Maximum number of L-BFGS-B iterations.

        Returns:
            The stylized image, shape (H, W, C), dtype uint8, values in
            [0, 255]. It is also stored in ``self.synthesized_image``.

        Raises:
            RuntimeError: If ``set_content`` has not been called.
            ValueError: If ``initial_image`` does not match the content shape.
        """
        if self.content_image is None:
            raise RuntimeError("set_content() must be called before synthesize()")
        h, w, c = self.content_image.shape

        start = self.content_image if initial_image is None else np.asarray(initial_image)
        if start.shape != (h, w, c):
            raise ValueError(
                f"initial_image shape {start.shape} does not match "
                f"content image shape {(h, w, c)}"
            )
        # Optimize over the flattened pixel vector in float64.
        x0 = start.astype(np.float64).ravel()
        style_weight = float(self.config.get('style_weight', 1.0))

        def loss_and_grad(x):
            loss, grad = self._loss_and_grad(x.reshape((h, w, c)), style_weight)
            return loss, grad.ravel()

        # jac=True supplies the analytic gradient. Without it, L-BFGS-B falls
        # back to finite differences: one extra loss evaluation per pixel value
        # for every gradient. The [0, 1] bounds keep pixels inside the
        # histogram range and the range the result is clipped to.
        result = minimize(
            loss_and_grad,
            x0,
            method='L-BFGS-B',
            jac=True,
            bounds=[(0.0, 1.0)] * x0.size,
            options={'maxiter': max_iter},
        )

        # Map back to [0, 255]. Round rather than truncate, so that e.g.
        # 0.99999 * 255 does not floor to the next-lower integer.
        synthesized = np.clip(result.x.reshape((h, w, c)), 0.0, 1.0) * 255.0
        self.synthesized_image = np.rint(synthesized).astype(np.uint8)
        return self.synthesized_image

    def _loss_and_grad(self, image: np.ndarray, style_weight: float) -> Tuple[float, np.ndarray]:
        """Return the total loss for ``image`` and its gradient (same shape as ``image``)."""
        # Content loss: mean squared error to the content image.
        diff = image - self.content_image
        content_loss = float(np.mean(diff ** 2))
        content_grad = (2.0 / diff.size) * diff

        style_loss, style_grad = self._style_loss_and_grad(image)
        return (content_loss + style_weight * style_loss,
                content_grad + style_weight * style_grad)

    def _calculate_style_loss(self, image: np.ndarray) -> float:
        """
        Compute the style loss between the current image and the target features.

        Args:
            image: Current synthesized image, shape (H, W, C) with at least
                3 channels, values on the [0, 1] scale.

        Returns:
            The style loss value.
        """
        return self._style_loss_and_grad(image)[0]

    def _style_loss_and_grad(self, image: np.ndarray) -> Tuple[float, np.ndarray]:
        """
        Compute the style loss for ``image`` together with its gradient.

        Returns:
            ``(loss, grad)`` where ``grad`` has the same shape as ``image``.
        """
        feats = self.style_features
        h, w, _ = image.shape
        grad = np.zeros(image.shape, dtype=np.float64)

        # 1. Color-histogram loss: KL(target || current) over the concatenated
        #    per-channel histograms of the first three channels. Both sides are
        #    normalized to sum to 1, so this is a proper KL divergence that
        #    does not depend on image size or on how the target was scaled.
        #    Histogram counts are piecewise constant in the pixel values, so
        #    this term contributes zero gradient (almost everywhere).
        current_hist = np.concatenate([
            np.histogram(image[:, :, ch].ravel(), bins=self._HIST_BINS, range=(0, 1))[0]
            for ch in range(3)
        ])
        current_hist = self._normalize(current_hist.astype(np.float64))
        target_hist = self._normalize(np.asarray(feats['color_histogram'], dtype=np.float64))
        eps = self._EPSILON
        loss = float(np.sum(target_hist * np.log((target_hist + eps) / (current_hist + eps))))

        # 2. Texture-strength loss: MSE between the mean absolute horizontal
        #    and vertical neighbor differences and the target pair.
        diff_h = np.diff(image, axis=1)
        diff_v = np.diff(image, axis=0)
        current_texture = np.array([np.abs(diff_h).mean(), np.abs(diff_v).mean()])
        texture_loss, d_texture = self._mse_and_grad(current_texture, feats['texture_strength'])
        loss += texture_loss
        # d|a - b|/da = sign(a - b) and d|a - b|/db = -sign(a - b), with each
        # difference weighted by 1/count inside the mean.
        if diff_h.size:
            g = d_texture[0] * np.sign(diff_h) / diff_h.size
            grad[:, 1:, :] += g
            grad[:, :-1, :] -= g
        if diff_v.size:
            g = d_texture[1] * np.sign(diff_v) / diff_v.size
            grad[1:, :, :] += g
            grad[:-1, :, :] -= g

        # 3. Mean-color loss: MSE between the per-channel means and the target.
        current_mean = image.mean(axis=(0, 1))
        mean_loss, d_mean = self._mse_and_grad(current_mean, feats['mean_color'])
        loss += mean_loss
        # Each pixel contributes 1 / (H * W) to its channel's mean.
        grad += d_mean / (h * w)

        return loss, grad

    @staticmethod
    def _normalize(hist: np.ndarray) -> np.ndarray:
        """Scale ``hist`` to sum to 1; an all-zero histogram is returned unchanged."""
        total = hist.sum()
        return hist / total if total > 0 else hist

    @staticmethod
    def _mse_and_grad(current: np.ndarray, target: Any) -> Tuple[float, np.ndarray]:
        """
        Return ``mean((current - target) ** 2)`` and its gradient w.r.t. ``current``.

        ``target`` may be anything NumPy broadcasts against ``current``. If
        broadcasting adds leading axes, their gradient contributions are
        summed back onto ``current``'s shape.
        """
        resid = current - np.asarray(target, dtype=np.float64)
        loss = float(np.mean(resid ** 2))
        grad = (2.0 / resid.size) * resid
        grad = grad.reshape(-1, current.size).sum(axis=0).reshape(current.shape)
        return loss, grad
# Factory function
def create_image_synthesizer(synthesizer_type: str, config: Dict[str, Any]) -> BaseImageSynthesizer:
    """
    Create an image synthesizer of the requested kind.

    Args:
        synthesizer_type: ``'traditional'`` selects ImageSynthesizer and
            ``'neural'`` selects NeuralImageSynthesizer.
        config: Configuration dictionary, passed unchanged to the chosen
            synthesizer's constructor.

    Returns:
        A new synthesizer instance.

    Raises:
        ValueError: If ``synthesizer_type`` is neither supported value.
    """
    if synthesizer_type not in ('traditional', 'neural'):
        raise ValueError(f"不支持的合成器类型: {synthesizer_type}")
    if synthesizer_type == 'neural':
        # Imported only inside this branch, so the neural module is loaded
        # only when a neural synthesizer is actually requested.
        from .neural_image_synthesizer import NeuralImageSynthesizer
        return NeuralImageSynthesizer(config)
    return ImageSynthesizer(config)