diff --git a/image_synthesizer.py b/image_synthesizer.py deleted file mode 100644 index 50558e5..0000000 --- a/image_synthesizer.py +++ /dev/null @@ -1,171 +0,0 @@ -""" -图像合成器模块 - 基础版本和传统方法。 -""" - -import numpy as np -from typing import Tuple, Dict, Any -from scipy.optimize import minimize -import warnings -from abc import ABC, abstractmethod - -warnings.filterwarnings("ignore", category=UserWarning) - - -class BaseImageSynthesizer(ABC): - """图像合成器基类""" - - @abstractmethod - def set_content(self, image: np.ndarray): - """设置内容图像""" - pass - - @abstractmethod - def set_style_features(self, features: Dict[str, Any]): - """设置风格特征""" - pass - - @abstractmethod - def synthesize(self, initial_image: np.ndarray = None, **kwargs) -> np.ndarray: - """合成图像""" - pass - - -class ImageSynthesizer(BaseImageSynthesizer): - """ - 传统图像合成器,负责生成最终的风格化图像。 - """ - - def __init__(self, config: Dict[str, Any]): - """ - 初始化合成器。 - - Args: - config: 配置字典。 - """ - self.config = config - self.content_image: np.ndarray = None - self.style_features: Dict[str, Any] = {} - self.synthesized_image: np.ndarray = None - - def set_content(self, image: np.ndarray): - """ - 设置内容图像。 - - Args: - image: 内容图像,形状为 (H, W, C)。 - """ - self.content_image = image.astype(np.float32) / 255.0 - - def set_style_features(self, features: Dict[str, Any]): - """ - 设置风格特征。 - - Args: - features: 由 StyleFeatureExtractor 提取的特征字典。 - """ - self.style_features = features - - def synthesize(self, initial_image: np.ndarray = None, max_iter: int = 200) -> np.ndarray: - """ - 合成风格化图像。 - - Args: - initial_image: 初始图像,如果为None,则使用内容图像。 - max_iter: 最大迭代次数。 - - Returns: - 合成后的风格化图像,形状为 (H, W, C),值域 [0, 255]。 - """ - if initial_image is None: - initial_image = self.content_image.copy() - - # 将图像展平为一维向量以便于优化 - h, w, c = self.content_image.shape - x0 = initial_image.ravel() - - # 定义损失函数 - def loss_function(x): - # 重塑回图像形状 - img = x.reshape((h, w, c)) - - # 1. 内容损失:保持与原图相似 - content_loss = np.mean((img - self.content_image) ** 2) - - # 2. 风格损失:匹配风格特征 - style_loss = self._calculate_style_loss(img) - - # 总损失 - total_loss = content_loss + self.config.get('style_weight', 1.0) * style_loss - - return total_loss - - # 执行优化 - result = minimize(loss_function, x0, method='L-BFGS-B', - options={'maxiter': max_iter, 'disp': False}) - - # 重塑结果并转换回[0, 255] - synthesized = result.x.reshape((h, w, c)) - synthesized = np.clip(synthesized, 0, 1) * 255 - self.synthesized_image = synthesized.astype(np.uint8) - - return self.synthesized_image - - def _calculate_style_loss(self, image: np.ndarray) -> float: - """ - 计算当前图像与目标风格特征之间的损失。 - - Args: - image: 当前合成的图像。 - - Returns: - 风格损失值。 - """ - loss = 0.0 - - # 1. 颜色直方图损失 - hist_r, _ = np.histogram(image[:, :, 0].ravel(), bins=32, range=(0, 1)) - hist_g, _ = np.histogram(image[:, :, 1].ravel(), bins=32, range=(0, 1)) - hist_b, _ = np.histogram(image[:, :, 2].ravel(), bins=32, range=(0, 1)) - current_hist = np.concatenate([hist_r, hist_g, hist_b]) - target_hist = self.style_features['color_histogram'] - # 使用KL散度作为距离度量 - epsilon = 1e-10 - kl_div = np.sum(target_hist * np.log((target_hist + epsilon) / (current_hist + epsilon))) - loss += kl_div - - # 2. 纹理强度损失 - diff_h = np.abs(np.diff(image, axis=1)).mean() - diff_v = np.abs(np.diff(image, axis=0)).mean() - current_texture = np.array([diff_h, diff_v]) - target_texture = self.style_features['texture_strength'] - texture_loss = np.mean((current_texture - target_texture) ** 2) - loss += texture_loss - - # 3. 平均颜色损失 - current_mean = image.mean(axis=(0, 1)) - target_mean = self.style_features['mean_color'] - mean_loss = np.mean((current_mean - target_mean) ** 2) - loss += mean_loss - - return loss - - -# 工厂函数 -def create_image_synthesizer(synthesizer_type: str, config: Dict[str, Any]) -> BaseImageSynthesizer: - """ - 创建图像合成器工厂函数。 - - Args: - synthesizer_type: 合成器类型 ('traditional' 或 'neural') - config: 配置字典 - - Returns: - 图像合成器实例 - """ - if synthesizer_type == 'traditional': - return ImageSynthesizer(config) - elif synthesizer_type == 'neural': - from .neural_image_synthesizer import NeuralImageSynthesizer - return NeuralImageSynthesizer(config) - else: - raise ValueError(f"不支持的合成器类型: {synthesizer_type}") \ No newline at end of file