ADD file via upload

main
p82b7rtam 2 months ago
parent 3badce0659
commit 00f0e079b6

171
main

@ -0,0 +1,171 @@
"""
图像合成器模块 - 基础版本和传统方法。
"""
import numpy as np
from typing import Tuple, Dict, Any
from scipy.optimize import minimize
import warnings
from abc import ABC, abstractmethod
warnings.filterwarnings("ignore", category=UserWarning)
class BaseImageSynthesizer(ABC):
"""图像合成器基类"""
@abstractmethod
def set_content(self, image: np.ndarray):
"""设置内容图像"""
pass
@abstractmethod
def set_style_features(self, features: Dict[str, Any]):
"""设置风格特征"""
pass
@abstractmethod
def synthesize(self, initial_image: np.ndarray = None, **kwargs) -> np.ndarray:
"""合成图像"""
pass
class ImageSynthesizer(BaseImageSynthesizer):
"""
传统图像合成器,负责生成最终的风格化图像。
"""
def __init__(self, config: Dict[str, Any]):
"""
初始化合成器。
Args:
config: 配置字典。
"""
self.config = config
self.content_image: np.ndarray = None
self.style_features: Dict[str, Any] = {}
self.synthesized_image: np.ndarray = None
def set_content(self, image: np.ndarray):
"""
设置内容图像。
Args:
image: 内容图像,形状为 (H, W, C)。
"""
self.content_image = image.astype(np.float32) / 255.0
def set_style_features(self, features: Dict[str, Any]):
"""
设置风格特征。
Args:
features: 由 StyleFeatureExtractor 提取的特征字典。
"""
self.style_features = features
def synthesize(self, initial_image: np.ndarray = None, max_iter: int = 200) -> np.ndarray:
"""
合成风格化图像。
Args:
initial_image: 初始图像如果为None则使用内容图像。
max_iter: 最大迭代次数。
Returns:
合成后的风格化图像,形状为 (H, W, C),值域 [0, 255]。
"""
if initial_image is None:
initial_image = self.content_image.copy()
# 将图像展平为一维向量以便于优化
h, w, c = self.content_image.shape
x0 = initial_image.ravel()
# 定义损失函数
def loss_function(x):
# 重塑回图像形状
img = x.reshape((h, w, c))
# 1. 内容损失:保持与原图相似
content_loss = np.mean((img - self.content_image) ** 2)
# 2. 风格损失:匹配风格特征
style_loss = self._calculate_style_loss(img)
# 总损失
total_loss = content_loss + self.config.get('style_weight', 1.0) * style_loss
return total_loss
# 执行优化
result = minimize(loss_function, x0, method='L-BFGS-B',
options={'maxiter': max_iter, 'disp': False})
# 重塑结果并转换回[0, 255]
synthesized = result.x.reshape((h, w, c))
synthesized = np.clip(synthesized, 0, 1) * 255
self.synthesized_image = synthesized.astype(np.uint8)
return self.synthesized_image
def _calculate_style_loss(self, image: np.ndarray) -> float:
"""
计算当前图像与目标风格特征之间的损失。
Args:
image: 当前合成的图像。
Returns:
风格损失值。
"""
loss = 0.0
# 1. 颜色直方图损失
hist_r, _ = np.histogram(image[:, :, 0].ravel(), bins=32, range=(0, 1))
hist_g, _ = np.histogram(image[:, :, 1].ravel(), bins=32, range=(0, 1))
hist_b, _ = np.histogram(image[:, :, 2].ravel(), bins=32, range=(0, 1))
current_hist = np.concatenate([hist_r, hist_g, hist_b])
target_hist = self.style_features['color_histogram']
# 使用KL散度作为距离度量
epsilon = 1e-10
kl_div = np.sum(target_hist * np.log((target_hist + epsilon) / (current_hist + epsilon)))
loss += kl_div
# 2. 纹理强度损失
diff_h = np.abs(np.diff(image, axis=1)).mean()
diff_v = np.abs(np.diff(image, axis=0)).mean()
current_texture = np.array([diff_h, diff_v])
target_texture = self.style_features['texture_strength']
texture_loss = np.mean((current_texture - target_texture) ** 2)
loss += texture_loss
# 3. 平均颜色损失
current_mean = image.mean(axis=(0, 1))
target_mean = self.style_features['mean_color']
mean_loss = np.mean((current_mean - target_mean) ** 2)
loss += mean_loss
return loss
# 工厂函数
def create_image_synthesizer(synthesizer_type: str, config: Dict[str, Any]) -> BaseImageSynthesizer:
"""
创建图像合成器工厂函数。
Args:
synthesizer_type: 合成器类型 ('traditional' 或 'neural')
config: 配置字典
Returns:
图像合成器实例
"""
if synthesizer_type == 'traditional':
return ImageSynthesizer(config)
elif synthesizer_type == 'neural':
from .neural_image_synthesizer import NeuralImageSynthesizer
return NeuralImageSynthesizer(config)
else:
raise ValueError(f"不支持的合成器类型: {synthesizer_type}")
Loading…
Cancel
Save