From 066f039a45e28cb35e1f0e3c8e1d8c8ecb5b4878 Mon Sep 17 00:00:00 2001 From: p82b7rtam <1761133400@qq.com> Date: Mon, 29 Dec 2025 17:05:58 +0800 Subject: [PATCH] ADD file via upload --- feature_extractor.py | 103 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 feature_extractor.py diff --git a/feature_extractor.py b/feature_extractor.py new file mode 100644 index 0000000..32d85b6 --- /dev/null +++ b/feature_extractor.py @@ -0,0 +1,103 @@ +""" +风格特征提取器模块 - 基础版本和传统方法。 +""" + +import numpy as np +from typing import Tuple, Dict, Any +from abc import ABC, abstractmethod + + +class BaseFeatureExtractor(ABC): + """特征提取器基类""" + + @abstractmethod + def extract(self, image: np.ndarray) -> Dict[str, Any]: + """提取特征抽象方法""" + pass + + @abstractmethod + def get_features(self) -> Dict[str, Any]: + """获取特征抽象方法""" + pass + + +class StyleFeatureExtractor(BaseFeatureExtractor): + """ + 传统风格特征提取器。 + 通过计算图像的直方图和局部纹理特征来定义"风格"。 + """ + + def __init__(self, config: Dict[str, Any]): + """ + 初始化提取器。 + + Args: + config: 配置字典,包含特征提取参数。 + """ + self.config = config + self.style_features: Dict[str, Any] = {} + + def extract(self, image: np.ndarray) -> Dict[str, Any]: + """ + 从图像中提取风格特征。 + + Args: + image: 输入图像,形状为 (H, W, C),值域 [0, 255]。 + + Returns: + 一个包含各种风格特征的字典。 + """ + # 确保图像在[0, 1]范围 + image_normalized = image.astype(np.float32) / 255.0 + + features = {} + + # 1. 颜色直方图 (RGB通道) + hist_r, _ = np.histogram(image_normalized[:, :, 0].ravel(), bins=32, range=(0, 1)) + hist_g, _ = np.histogram(image_normalized[:, :, 1].ravel(), bins=32, range=(0, 1)) + hist_b, _ = np.histogram(image_normalized[:, :, 2].ravel(), bins=32, range=(0, 1)) + features['color_histogram'] = np.concatenate([hist_r, hist_g, hist_b]) + + # 2. 局部纹理特征 + diff_h = np.abs(np.diff(image_normalized, axis=1)).mean() + diff_v = np.abs(np.diff(image_normalized, axis=0)).mean() + features['texture_strength'] = np.array([diff_h, diff_v]) + + # 3. 平均颜色 + features['mean_color'] = image_normalized.mean(axis=(0, 1)) + + # 4. 颜色标准差 + features['color_std'] = image_normalized.std(axis=(0, 1)) + + self.style_features = features + return features + + def get_features(self) -> Dict[str, Any]: + """ + 获取最近一次提取的风格特征。 + + Returns: + 风格特征字典。 + """ + return self.style_features.copy() + + +# 工厂函数,便于创建不同类型的提取器 +def create_feature_extractor(extractor_type: str, config: Dict[str, Any]) -> BaseFeatureExtractor: + """ + 创建特征提取器工厂函数。 + + Args: + extractor_type: 提取器类型 ('traditional' 或 'neural') + config: 配置字典 + + Returns: + 特征提取器实例 + """ + if extractor_type == 'traditional': + return StyleFeatureExtractor(config) + elif extractor_type == 'neural': + from .neural_feature_extractor import NeuralStyleFeatureExtractor + return NeuralStyleFeatureExtractor(config) + else: + raise ValueError(f"不支持的提取器类型: {extractor_type}") \ No newline at end of file