|
|
|
|
@ -1,103 +0,0 @@
|
|
|
|
|
"""
|
|
|
|
|
风格特征提取器模块 - 基础版本和传统方法。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from typing import Tuple, Dict, Any
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseFeatureExtractor(ABC):
|
|
|
|
|
"""特征提取器基类"""
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def extract(self, image: np.ndarray) -> Dict[str, Any]:
|
|
|
|
|
"""提取特征抽象方法"""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def get_features(self) -> Dict[str, Any]:
|
|
|
|
|
"""获取特征抽象方法"""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StyleFeatureExtractor(BaseFeatureExtractor):
|
|
|
|
|
"""
|
|
|
|
|
传统风格特征提取器。
|
|
|
|
|
通过计算图像的直方图和局部纹理特征来定义"风格"。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: Dict[str, Any]):
|
|
|
|
|
"""
|
|
|
|
|
初始化提取器。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
config: 配置字典,包含特征提取参数。
|
|
|
|
|
"""
|
|
|
|
|
self.config = config
|
|
|
|
|
self.style_features: Dict[str, Any] = {}
|
|
|
|
|
|
|
|
|
|
def extract(self, image: np.ndarray) -> Dict[str, Any]:
|
|
|
|
|
"""
|
|
|
|
|
从图像中提取风格特征。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
image: 输入图像,形状为 (H, W, C),值域 [0, 255]。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
一个包含各种风格特征的字典。
|
|
|
|
|
"""
|
|
|
|
|
# 确保图像在[0, 1]范围
|
|
|
|
|
image_normalized = image.astype(np.float32) / 255.0
|
|
|
|
|
|
|
|
|
|
features = {}
|
|
|
|
|
|
|
|
|
|
# 1. 颜色直方图 (RGB通道)
|
|
|
|
|
hist_r, _ = np.histogram(image_normalized[:, :, 0].ravel(), bins=32, range=(0, 1))
|
|
|
|
|
hist_g, _ = np.histogram(image_normalized[:, :, 1].ravel(), bins=32, range=(0, 1))
|
|
|
|
|
hist_b, _ = np.histogram(image_normalized[:, :, 2].ravel(), bins=32, range=(0, 1))
|
|
|
|
|
features['color_histogram'] = np.concatenate([hist_r, hist_g, hist_b])
|
|
|
|
|
|
|
|
|
|
# 2. 局部纹理特征
|
|
|
|
|
diff_h = np.abs(np.diff(image_normalized, axis=1)).mean()
|
|
|
|
|
diff_v = np.abs(np.diff(image_normalized, axis=0)).mean()
|
|
|
|
|
features['texture_strength'] = np.array([diff_h, diff_v])
|
|
|
|
|
|
|
|
|
|
# 3. 平均颜色
|
|
|
|
|
features['mean_color'] = image_normalized.mean(axis=(0, 1))
|
|
|
|
|
|
|
|
|
|
# 4. 颜色标准差
|
|
|
|
|
features['color_std'] = image_normalized.std(axis=(0, 1))
|
|
|
|
|
|
|
|
|
|
self.style_features = features
|
|
|
|
|
return features
|
|
|
|
|
|
|
|
|
|
def get_features(self) -> Dict[str, Any]:
|
|
|
|
|
"""
|
|
|
|
|
获取最近一次提取的风格特征。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
风格特征字典。
|
|
|
|
|
"""
|
|
|
|
|
return self.style_features.copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 工厂函数,便于创建不同类型的提取器
|
|
|
|
|
def create_feature_extractor(extractor_type: str, config: Dict[str, Any]) -> BaseFeatureExtractor:
|
|
|
|
|
"""
|
|
|
|
|
创建特征提取器工厂函数。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
extractor_type: 提取器类型 ('traditional' 或 'neural')
|
|
|
|
|
config: 配置字典
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
特征提取器实例
|
|
|
|
|
"""
|
|
|
|
|
if extractor_type == 'traditional':
|
|
|
|
|
return StyleFeatureExtractor(config)
|
|
|
|
|
elif extractor_type == 'neural':
|
|
|
|
|
from .neural_feature_extractor import NeuralStyleFeatureExtractor
|
|
|
|
|
return NeuralStyleFeatureExtractor(config)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"不支持的提取器类型: {extractor_type}")
|