"""
Image synthesizer module - base interface and the traditional
(optimization-based) implementation.
"""

import numpy as np
from typing import Tuple, Dict, Any, Optional
from scipy.optimize import minimize
import warnings
from abc import ABC, abstractmethod

# NOTE: this filter is process-wide; importing this module silences
# UserWarning for every other module as well.
warnings.filterwarnings("ignore", category=UserWarning)


class BaseImageSynthesizer(ABC):
    """Abstract interface shared by all image synthesizers."""

    @abstractmethod
    def set_content(self, image: np.ndarray):
        """Set the content image."""
        pass

    @abstractmethod
    def set_style_features(self, features: Dict[str, Any]):
        """Set the target style features."""
        pass

    @abstractmethod
    def synthesize(self, initial_image: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
        """Synthesize and return the stylized image."""
        pass


class ImageSynthesizer(BaseImageSynthesizer):
    """
    Traditional image synthesizer that produces the final stylized image by
    minimizing a content loss plus a weighted style loss with L-BFGS-B.
    """

    # Keys that `_calculate_style_loss` reads from the style feature dict.
    _REQUIRED_STYLE_KEYS = ('color_histogram', 'texture_strength', 'mean_color')

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the synthesizer.

        Args:
            config: Configuration dictionary. Only the optional key
                'style_weight' (default 1.0) is read by this class.
        """
        self.config = config
        self.content_image: Optional[np.ndarray] = None
        self.style_features: Dict[str, Any] = {}
        self.synthesized_image: Optional[np.ndarray] = None

    def set_content(self, image: np.ndarray):
        """
        Set the content image.

        The image is stored as a float32 copy divided by 255.

        Args:
            image: Content image of shape (H, W, C), with values expected
                in [0, 255].
        """
        # np.asarray lets array-likes (e.g. nested lists) through;
        # astype() always copies, so the caller's array is never aliased.
        self.content_image = np.asarray(image).astype(np.float32) / 255.0

    def set_style_features(self, features: Dict[str, Any]):
        """
        Set the target style features.

        Args:
            features: Feature dictionary (produced by StyleFeatureExtractor
                according to the original documentation). It must provide
                'color_histogram', 'texture_strength' and 'mean_color'.
        """
        self.style_features = features

    def synthesize(self, initial_image: Optional[np.ndarray] = None,
                   max_iter: int = 200, **kwargs) -> np.ndarray:
        """
        Synthesize the stylized image.

        Args:
            initial_image: Starting point of the optimization, on the same
                [0, 1] scale as the stored content image. If None, the
                content image is used.
            max_iter: Maximum number of L-BFGS-B iterations.
            **kwargs: Accepted for compatibility with
                `BaseImageSynthesizer.synthesize`; ignored here.

        Returns:
            The stylized image as a uint8 array of shape (H, W, C) with
            values in [0, 255]. It is also stored in `synthesized_image`.

        Raises:
            ValueError: If no content image has been set, the content image
                is not 3-dimensional, or `initial_image` does not have the
                same number of elements as the content image.
            KeyError: If a required style feature is missing.
        """
        content = self.content_image
        if content is None:
            raise ValueError("content image is not set; call set_content() first")
        if content.ndim != 3:
            raise ValueError(
                f"content image must have shape (H, W, C), got {content.shape}"
            )

        # Fail before optimizing rather than from inside the loss callback.
        missing = [key for key in self._REQUIRED_STYLE_KEYS
                   if key not in self.style_features]
        if missing:
            raise KeyError(f"missing style features: {missing}")

        h, w, c = content.shape

        start = content if initial_image is None else np.asarray(initial_image)
        if start.size != content.size:
            raise ValueError(
                f"initial_image has {start.size} elements, "
                f"expected {content.size} to match the content image"
            )

        # Flatten into a fresh float64 vector for the optimizer.
        x0 = start.astype(np.float64).ravel()

        # Loop-invariant, so read it once instead of on every evaluation.
        style_weight = self.config.get('style_weight', 1.0)

        def loss_function(x):
            # Restore the image shape from the flat parameter vector.
            img = x.reshape((h, w, c))

            # 1. Content loss: stay close to the content image.
            content_loss = np.mean((img - content) ** 2)

            # 2. Style loss: match the target style features.
            style_loss = self._calculate_style_loss(img)

            return content_loss + style_weight * style_loss

        # NOTE: no analytic gradient is supplied, so SciPy estimates it by
        # finite differences (about one loss evaluation per pixel value per
        # gradient); this is only practical for small images.
        # The L-BFGS-B `disp` option is deliberately not passed: it is
        # deprecated in recent SciPy and the default is already silent.
        result = minimize(loss_function, x0, method='L-BFGS-B',
                          options={'maxiter': max_iter})

        # Back to [0, 255]. Round to nearest before casting: a bare
        # astype(uint8) truncates, turning e.g. 99.99999 into 99.
        synthesized = np.clip(result.x.reshape((h, w, c)), 0, 1) * 255
        self.synthesized_image = np.rint(synthesized).astype(np.uint8)

        return self.synthesized_image

    def _calculate_style_loss(self, image: np.ndarray) -> float:
        """
        Compute the loss between an image and the target style features.

        The loss is the sum of three terms: a KL divergence between colour
        histograms, a squared error on texture strength and a squared error
        on mean colour.

        Args:
            image: Current synthesized image; the first three channels are
                used for the histogram and values are binned over [0, 1].

        Returns:
            The style loss value.
        """
        loss = 0.0

        # 1. Colour histogram loss: 32 bins per channel over the first
        #    three channels, concatenated into one 96-element vector.
        current_hist = np.concatenate([
            np.histogram(image[:, :, ch].ravel(), bins=32, range=(0, 1))[0]
            for ch in range(3)
        ]).astype(np.float64)
        target_hist = np.asarray(self.style_features['color_histogram'],
                                 dtype=np.float64)

        # KL divergence is only meaningful between probability
        # distributions. Normalize both so the term does not depend on the
        # image size or on how the target was scaled (raw counts,
        # per-channel frequencies or densities).
        current_total = current_hist.sum()
        if current_total > 0:
            current_hist = current_hist / current_total
        target_total = target_hist.sum()
        if target_total > 0:
            target_hist = target_hist / target_total

        # epsilon keeps the ratio and the logarithm finite for empty bins.
        epsilon = 1e-10
        kl_div = np.sum(
            target_hist * np.log((target_hist + epsilon) / (current_hist + epsilon))
        )
        loss += kl_div

        # 2. Texture strength loss: mean absolute difference between
        #    neighbouring pixels along axis 1 and along axis 0.
        diff_h = np.abs(np.diff(image, axis=1)).mean()
        diff_v = np.abs(np.diff(image, axis=0)).mean()
        current_texture = np.array([diff_h, diff_v])
        target_texture = np.asarray(self.style_features['texture_strength'],
                                    dtype=np.float64)
        loss += np.mean((current_texture - target_texture) ** 2)

        # 3. Mean colour loss: per-channel mean over all pixels.
        current_mean = image.mean(axis=(0, 1))
        target_mean = np.asarray(self.style_features['mean_color'],
                                 dtype=np.float64)
        loss += np.mean((current_mean - target_mean) ** 2)

        return float(loss)


# Factory function
def create_image_synthesizer(synthesizer_type: str, config: Dict[str, Any]) -> BaseImageSynthesizer:
    """
    Create an image synthesizer.

    Args:
        synthesizer_type: Synthesizer type, 'traditional' or 'neural'.
        config: Configuration dictionary passed to the synthesizer.

    Returns:
        The image synthesizer instance.

    Raises:
        ValueError: If `synthesizer_type` is not supported.
    """
    if synthesizer_type == 'traditional':
        return ImageSynthesizer(config)
    if synthesizer_type == 'neural':
        # Imported lazily so the neural module is only loaded on request.
        from .neural_image_synthesizer import NeuralImageSynthesizer
        return NeuralImageSynthesizer(config)
    raise ValueError(f"不支持的合成器类型: {synthesizer_type}")