You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
NST/feature_extractor.py

103 lines
3.1 KiB

"""
风格特征提取器模块 - 基础版本和传统方法。
"""
import numpy as np
from typing import Tuple, Dict, Any
from abc import ABC, abstractmethod
class BaseFeatureExtractor(ABC):
"""特征提取器基类"""
@abstractmethod
def extract(self, image: np.ndarray) -> Dict[str, Any]:
"""提取特征抽象方法"""
pass
@abstractmethod
def get_features(self) -> Dict[str, Any]:
"""获取特征抽象方法"""
pass
class StyleFeatureExtractor(BaseFeatureExtractor):
"""
传统风格特征提取器。
通过计算图像的直方图和局部纹理特征来定义"风格"
"""
def __init__(self, config: Dict[str, Any]):
"""
初始化提取器。
Args:
config: 配置字典,包含特征提取参数。
"""
self.config = config
self.style_features: Dict[str, Any] = {}
def extract(self, image: np.ndarray) -> Dict[str, Any]:
"""
从图像中提取风格特征。
Args:
image: 输入图像,形状为 (H, W, C),值域 [0, 255]。
Returns:
一个包含各种风格特征的字典。
"""
# 确保图像在[0, 1]范围
image_normalized = image.astype(np.float32) / 255.0
features = {}
# 1. 颜色直方图 (RGB通道)
hist_r, _ = np.histogram(image_normalized[:, :, 0].ravel(), bins=32, range=(0, 1))
hist_g, _ = np.histogram(image_normalized[:, :, 1].ravel(), bins=32, range=(0, 1))
hist_b, _ = np.histogram(image_normalized[:, :, 2].ravel(), bins=32, range=(0, 1))
features['color_histogram'] = np.concatenate([hist_r, hist_g, hist_b])
# 2. 局部纹理特征
diff_h = np.abs(np.diff(image_normalized, axis=1)).mean()
diff_v = np.abs(np.diff(image_normalized, axis=0)).mean()
features['texture_strength'] = np.array([diff_h, diff_v])
# 3. 平均颜色
features['mean_color'] = image_normalized.mean(axis=(0, 1))
# 4. 颜色标准差
features['color_std'] = image_normalized.std(axis=(0, 1))
self.style_features = features
return features
def get_features(self) -> Dict[str, Any]:
"""
获取最近一次提取的风格特征。
Returns:
风格特征字典。
"""
return self.style_features.copy()
# 工厂函数,便于创建不同类型的提取器
def create_feature_extractor(extractor_type: str, config: Dict[str, Any]) -> BaseFeatureExtractor:
"""
创建特征提取器工厂函数。
Args:
extractor_type: 提取器类型 ('traditional''neural')
config: 配置字典
Returns:
特征提取器实例
"""
if extractor_type == 'traditional':
return StyleFeatureExtractor(config)
elif extractor_type == 'neural':
from .neural_feature_extractor import NeuralStyleFeatureExtractor
return NeuralStyleFeatureExtractor(config)
else:
raise ValueError(f"不支持的提取器类型: {extractor_type}")