You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
dsj/feature_extractor.py

794 lines
34 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
多模态网络流图数据抽取 - 初步特征提取模块支持CSV导出
功能:协议特征、时序特征、统计特征提取
"""
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Any
import json
import csv
from collections import Counter, defaultdict
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
import warnings
warnings.filterwarnings('ignore')
class FeatureExtractor:
"""特征提取器类"""
def __init__(self, csv_file: str):
"""
初始化特征提取器
Args:
csv_file (str): 预处理后的CSV文件路径
"""
self.csv_file = csv_file
self.df = None
self.protocol_features = {}
self.temporal_features = {}
self.statistical_features = {}
self.extracted_features = {}
def load_data(self) -> pd.DataFrame:
"""
加载预处理后的数据
Returns:
pd.DataFrame: 数据框
"""
print(f"正在加载预处理数据: {self.csv_file}")
try:
self.df = pd.read_csv(self.csv_file, encoding='utf-8')
# 兼容两列表CSV若缺少必要列仅使用可用信息降级运行
if '源IP' in self.df.columns and '目标IP' in self.df.columns:
if '时间戳' not in self.df.columns:
self.df['时间戳'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if '协议' not in self.df.columns:
self.df['协议'] = 'Unknown'
if '数据包大小' not in self.df.columns:
self.df['数据包大小'] = 0
if '载荷大小' not in self.df.columns:
self.df['载荷大小'] = 0
print(f"成功加载 {len(self.df)} 条记录")
# 转换时间戳
if '时间戳' in self.df.columns:
self.df['时间戳'] = pd.to_datetime(self.df['时间戳'], errors='coerce')
# 填补无效时间戳
self.df['时间戳'] = self.df['时间戳'].fillna(pd.Timestamp.now())
return self.df
except Exception as e:
print(f"加载数据时出错: {e}")
return None
def extract_protocol_features(self) -> Dict[str, Any]:
"""
协议特征提取
Returns:
Dict[str, Any]: 协议特征字典
"""
print("开始协议特征提取...")
if self.df is None:
self.load_data()
protocol_features = {}
# 1. 统计各协议类型分布
print("1. 统计协议类型分布...")
if '协议' in self.df.columns:
protocol_distribution = self.df['协议'].value_counts().to_dict()
else:
protocol_distribution = {'Unknown': len(self.df)}
protocol_features['protocol_distribution'] = protocol_distribution
# 计算协议占比
total_packets = len(self.df)
protocol_percentages = {protocol: count/total_packets for protocol, count in protocol_distribution.items()}
protocol_features['protocol_percentages'] = protocol_percentages
# 2. 生成协议 One-Hot 编码
print("2. 生成协议One-Hot编码...")
try:
# 使用sklearn的LabelEncoder和OneHotEncoder
label_encoder = LabelEncoder()
proto_series = self.df['协议'].astype(str) if '协议' in self.df.columns else pd.Series(['Unknown']*len(self.df))
protocol_encoded = label_encoder.fit_transform(proto_series)
onehot_encoder = OneHotEncoder(sparse_output=False)
protocol_onehot = onehot_encoder.fit_transform(protocol_encoded.reshape(-1, 1))
# 创建协议One-Hot编码DataFrame
protocol_names = label_encoder.classes_
protocol_onehot_df = pd.DataFrame(
protocol_onehot,
columns=[f'protocol_{name}' for name in protocol_names]
)
protocol_features['protocol_onehot_encoding'] = {
'encoded_data': protocol_onehot_df.to_dict('records'),
'protocol_names': protocol_names.tolist(),
'encoding_shape': protocol_onehot.shape
}
except Exception as e:
print(f"One-Hot编码生成失败: {e}")
protocol_features['protocol_onehot_encoding'] = None
# 3. 协议相关的流量统计
print("3. 计算协议相关流量统计...")
protocol_stats = {}
for protocol in (self.df['协议'].unique() if '协议' in self.df.columns else ['Unknown']):
protocol_data = self.df[self.df['协议'] == protocol] if '协议' in self.df.columns else self.df
protocol_stats[protocol] = {
'packet_count': len(protocol_data),
'total_bytes': (protocol_data['数据包大小'].sum() if '数据包大小' in protocol_data.columns else 0),
'avg_packet_size': (protocol_data['数据包大小'].mean() if '数据包大小' in protocol_data.columns else 0),
'max_packet_size': (protocol_data['数据包大小'].max() if '数据包大小' in protocol_data.columns else 0),
'min_packet_size': (protocol_data['数据包大小'].min() if '数据包大小' in protocol_data.columns else 0),
'payload_bytes': (protocol_data['载荷大小'].sum() if '载荷大小' in protocol_data.columns else 0),
'avg_payload_size': (protocol_data['载荷大小'].mean() if '载荷大小' in protocol_data.columns else 0)
}
protocol_features['protocol_statistics'] = protocol_stats
# 4. 协议组合分析
print("4. 分析协议组合...")
# 按源IP和目标IP分组分析每个连接的协议组合
connection_protocols = self.df.groupby(['源IP', '目标IP'])['协议'].apply(list).to_dict()
protocol_combinations = {}
for connection, protocols in connection_protocols.items():
unique_protocols = list(set(protocols))
protocol_key = '-'.join(sorted(unique_protocols))
if protocol_key not in protocol_combinations:
protocol_combinations[protocol_key] = 0
protocol_combinations[protocol_key] += 1
protocol_features['protocol_combinations'] = protocol_combinations
self.protocol_features = protocol_features
return protocol_features
def extract_temporal_features(self, time_window: int = 60) -> Dict[str, Any]:
"""
时序特征提取
Args:
time_window (int): 时间窗口(秒)
Returns:
Dict[str, Any]: 时序特征字典
"""
print("开始时序特征提取...")
if self.df is None:
self.load_data()
temporal_features = {}
# 1. 计算流量时间分布
print("1. 计算流量时间分布...")
if '时间戳' in self.df.columns:
self.df['时间戳'] = pd.to_datetime(self.df['时间戳'], errors='coerce').fillna(pd.Timestamp.now())
else:
# 无时间戳时,构造一个等间隔伪时间轴以便窗口统计
base = pd.Timestamp.now()
self.df['时间戳'] = [base + timedelta(seconds=i) for i in range(len(self.df))]
# 按时间窗口分组
start_time = self.df['时间戳'].min()
end_time = self.df['时间戳'].max()
time_windows = []
current_time = start_time
while current_time < end_time:
window_end = current_time + timedelta(seconds=time_window)
window_data = self.df[
(self.df['时间戳'] >= current_time) &
(self.df['时间戳'] < window_end)
]
time_windows.append({
'start_time': current_time.isoformat(),
'end_time': window_end.isoformat(),
'packet_count': len(window_data),
'total_bytes': (window_data['数据包大小'].sum() if '数据包大小' in window_data.columns else 0),
'unique_connections': len(window_data.groupby(['源IP', '目标IP'])) if ('源IP' in window_data.columns and '目标IP' in window_data.columns) else 0,
'protocols': (window_data['协议'].value_counts().to_dict() if '协议' in window_data.columns else {'Unknown': len(window_data)})
})
current_time = window_end
temporal_features['time_distribution'] = time_windows
# 2. 提取活跃度变化率
print("2. 提取活跃度变化率...")
if len(time_windows) > 1:
packet_counts = [window['packet_count'] for window in time_windows]
byte_counts = [window['total_bytes'] for window in time_windows]
# 计算变化率
packet_change_rates = np.diff(packet_counts) / np.array(packet_counts[:-1])
byte_change_rates = np.diff(byte_counts) / np.array(byte_counts[:-1])
# 处理除零情况
packet_change_rates = np.nan_to_num(packet_change_rates, nan=0.0, posinf=0.0, neginf=0.0)
byte_change_rates = np.nan_to_num(byte_change_rates, nan=0.0, posinf=0.0, neginf=0.0)
temporal_features['activity_change_rates'] = {
'packet_change_rates': packet_change_rates.tolist(),
'byte_change_rates': byte_change_rates.tolist(),
'avg_packet_change_rate': float(np.mean(packet_change_rates)),
'avg_byte_change_rate': float(np.mean(byte_change_rates)),
'max_packet_change_rate': float(np.max(packet_change_rates)),
'max_byte_change_rate': float(np.max(byte_change_rates))
}
# 3. 时间序列分析
print("3. 时间序列分析...")
# 按小时统计流量
self.df['hour'] = self.df['时间戳'].dt.hour
agg_map = {}
if '数据包大小' in self.df.columns:
agg_map['数据包大小'] = ['count', 'sum', 'mean']
if '载荷大小' in self.df.columns:
agg_map['载荷大小'] = ['sum', 'mean']
if not agg_map:
agg_map['hour'] = ['count']
hourly_stats = self.df.groupby('hour').agg(agg_map).round(2)
# 转换MultiIndex为可序列化的格式
hourly_dict = {}
for col in hourly_stats.columns:
col_name = f"{col[0]}_{col[1]}" if isinstance(col, tuple) else str(col)
hourly_dict[col_name] = hourly_stats[col].to_dict()
temporal_features['hourly_distribution'] = hourly_dict
# 按分钟统计流量
self.df['minute'] = self.df['时间戳'].dt.minute
agg_map_min = {}
if '数据包大小' in self.df.columns:
agg_map_min['数据包大小'] = ['count', 'sum', 'mean']
else:
agg_map_min['minute'] = ['count']
minute_stats = self.df.groupby('minute').agg(agg_map_min).round(2)
# 转换MultiIndex为可序列化的格式
minute_dict = {}
for col in minute_stats.columns:
col_name = f"{col[0]}_{col[1]}" if isinstance(col, tuple) else str(col)
minute_dict[col_name] = minute_stats[col].to_dict()
temporal_features['minute_distribution'] = minute_dict
# 4. 流量峰值检测
print("4. 流量峰值检测...")
if len(time_windows) > 3:
packet_counts = [window['packet_count'] for window in time_windows]
byte_counts = [window['total_bytes'] for window in time_windows]
# 使用滑动窗口检测峰值
window_size = min(3, len(packet_counts) // 2)
peaks = []
for i in range(window_size, len(packet_counts) - window_size):
current_packets = packet_counts[i]
current_bytes = byte_counts[i]
# 检查是否为峰值
is_packet_peak = all(current_packets >= packet_counts[i-j] for j in range(1, window_size+1)) and \
all(current_packets >= packet_counts[i+j] for j in range(1, window_size+1))
is_byte_peak = all(current_bytes >= byte_counts[i-j] for j in range(1, window_size+1)) and \
all(current_bytes >= byte_counts[i+j] for j in range(1, window_size+1))
if is_packet_peak or is_byte_peak:
peaks.append({
'time_index': i,
'packet_count': current_packets,
'byte_count': current_bytes,
'is_packet_peak': is_packet_peak,
'is_byte_peak': is_byte_peak
})
temporal_features['traffic_peaks'] = peaks
self.temporal_features = temporal_features
return temporal_features
def extract_statistical_features(self) -> Dict[str, Any]:
"""
统计特征提取
Returns:
Dict[str, Any]: 统计特征字典
"""
print("开始统计特征提取...")
if self.df is None:
self.load_data()
statistical_features = {}
# 1. 流量大小统计
print("1. 流量大小统计...")
packet_sizes = self.df['数据包大小']
payload_sizes = self.df['载荷大小']
size_statistics = {
'packet_size_stats': {
'count': len(packet_sizes),
'mean': float(packet_sizes.mean()),
'std': float(packet_sizes.std()),
'min': float(packet_sizes.min()),
'max': float(packet_sizes.max()),
'median': float(packet_sizes.median()),
'q25': float(packet_sizes.quantile(0.25)),
'q75': float(packet_sizes.quantile(0.75))
},
'payload_size_stats': {
'count': len(payload_sizes),
'mean': float(payload_sizes.mean()),
'std': float(payload_sizes.std()),
'min': float(payload_sizes.min()),
'max': float(payload_sizes.max()),
'median': float(payload_sizes.median()),
'q25': float(payload_sizes.quantile(0.25)),
'q75': float(payload_sizes.quantile(0.75))
}
}
statistical_features['size_statistics'] = size_statistics
# 2. 速率方差计算
print("2. 速率方差计算...")
# 按时间窗口计算速率
time_windows = []
start_time = self.df['时间戳'].min()
end_time = self.df['时间戳'].max()
window_size = 60 # 60秒窗口
current_time = start_time
while current_time < end_time:
window_end = current_time + timedelta(seconds=window_size)
window_data = self.df[
(self.df['时间戳'] >= current_time) &
(self.df['时间戳'] < window_end)
]
if len(window_data) > 0:
packet_rate = len(window_data) / window_size # 包/秒
byte_rate = (window_data['数据包大小'].sum() / window_size) if '数据包大小' in self.df.columns else 0
time_windows.append({
'start_time': current_time.isoformat(),
'packet_rate': packet_rate,
'byte_rate': byte_rate
})
current_time = window_end
if len(time_windows) > 1:
packet_rates = [w['packet_rate'] for w in time_windows]
byte_rates = [w['byte_rate'] for w in time_windows]
rate_variance = {
'packet_rate_variance': float(np.var(packet_rates)),
'byte_rate_variance': float(np.var(byte_rates)),
'packet_rate_std': float(np.std(packet_rates)),
'byte_rate_std': float(np.std(byte_rates)),
'packet_rate_mean': float(np.mean(packet_rates)),
'byte_rate_mean': float(np.mean(byte_rates))
}
statistical_features['rate_variance'] = rate_variance
# 3. 连接统计特征
print("3. 连接统计特征...")
# 按源IP和目标IP分组
connections = self.df.groupby(['源IP', '目标IP']).agg({
'数据包大小': ['count', 'sum', 'mean', 'std'],
'载荷大小': ['sum', 'mean', 'std'],
'时间戳': ['min', 'max']
}).round(2)
# 计算连接持续时间
connections['duration'] = (connections[('时间戳', 'max')] - connections[('时间戳', 'min')]).dt.total_seconds()
# 转换MultiIndex为可序列化的格式
connection_stats = {
'total_connections': len(connections),
'avg_packets_per_connection': float(connections[('数据包大小', 'count')].mean()),
'avg_bytes_per_connection': float(connections[('数据包大小', 'sum')].mean()),
'avg_duration_per_connection': float(connections['duration'].mean()),
'max_duration': float(connections['duration'].max()),
'min_duration': float(connections['duration'].min())
}
statistical_features['connection_statistics'] = connection_stats
# 4. 流量模式分析
print("4. 流量模式分析...")
# 计算流量的自相关
if len(time_windows) > 10:
packet_rates = [w['packet_rate'] for w in time_windows]
byte_rates = [w['byte_rate'] for w in time_windows]
# 计算自相关系数
packet_autocorr = np.corrcoef(packet_rates[:-1], packet_rates[1:])[0, 1] if len(packet_rates) > 1 else 0
byte_autocorr = np.corrcoef(byte_rates[:-1], byte_rates[1:])[0, 1] if len(byte_rates) > 1 else 0
pattern_analysis = {
'packet_autocorrelation': float(packet_autocorr),
'byte_autocorrelation': float(byte_autocorr),
'traffic_regularity': abs(packet_autocorr) + abs(byte_autocorr)
}
statistical_features['pattern_analysis'] = pattern_analysis
self.statistical_features = statistical_features
return statistical_features
def extract_all_features(self, time_window: int = 60) -> Dict[str, Any]:
"""
提取所有特征
Args:
time_window (int): 时间窗口(秒)
Returns:
Dict[str, Any]: 所有特征的字典
"""
print("开始提取所有特征...")
# 提取各类特征
protocol_features = self.extract_protocol_features()
temporal_features = self.extract_temporal_features(time_window)
statistical_features = self.extract_statistical_features()
# 合并所有特征
all_features = {
'protocol_features': protocol_features,
'temporal_features': temporal_features,
'statistical_features': statistical_features,
'extraction_time': datetime.now().isoformat(),
'data_info': {
'total_records': len(self.df),
'time_range': {
'start': self.df['时间戳'].min().isoformat(),
'end': self.df['时间戳'].max().isoformat()
},
'unique_ips': len(set(self.df['源IP'].tolist() + self.df['目标IP'].tolist())),
'unique_connections': len(self.df.groupby(['源IP', '目标IP']))
}
}
self.extracted_features = all_features
return all_features
def export_features(self, output_dir: str = "feature_extraction_results") -> Dict[str, str]:
"""
导出提取的特征支持CSV和JSON格式
Args:
output_dir (str): 输出目录
Returns:
Dict[str, str]: 输出文件路径
"""
import os
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_files = {}
# 导出所有特征CSV扁平化
if self.extracted_features:
features_file = os.path.join(output_dir, f"extracted_features_{timestamp}.csv")
flat_rows = []
def add_prefix(prefix, obj):
if isinstance(obj, dict):
for k, v in obj.items():
add_prefix(f"{prefix}.{k}" if prefix else str(k), v)
elif isinstance(obj, list):
for i, v in enumerate(obj):
add_prefix(f"{prefix}[{i}]", v)
else:
flat_rows.append({'': prefix, '': obj})
add_prefix('', self.extracted_features)
pd.DataFrame(flat_rows).to_csv(features_file, index=False, encoding='utf-8')
output_files['all_features_csv'] = features_file
print(f"所有特征(CSV)已导出: {features_file}")
# 导出协议特征CSV
if self.protocol_features:
# CSV格式
protocol_csv_file = os.path.join(output_dir, f"protocol_features_{timestamp}.csv")
self._export_protocol_features_csv(protocol_csv_file)
output_files['protocol_features_csv'] = protocol_csv_file
print(f"协议特征(CSV)已导出: {protocol_csv_file}")
# 导出时序特征CSV
if self.temporal_features:
# CSV格式
temporal_csv_file = os.path.join(output_dir, f"temporal_features_{timestamp}.csv")
self._export_temporal_features_csv(temporal_csv_file)
output_files['temporal_features_csv'] = temporal_csv_file
print(f"时序特征(CSV)已导出: {temporal_csv_file}")
# 导出统计特征CSV
if self.statistical_features:
# CSV格式
statistical_csv_file = os.path.join(output_dir, f"statistical_features_{timestamp}.csv")
self._export_statistical_features_csv(statistical_csv_file)
output_files['statistical_features_csv'] = statistical_csv_file
print(f"统计特征(CSV)已导出: {statistical_csv_file}")
return output_files
def _export_protocol_features_csv(self, output_file: str):
"""导出协议特征到CSV"""
csv_data = []
# 协议分布
if 'protocol_distribution' in self.protocol_features:
for protocol, count in self.protocol_features['protocol_distribution'].items():
csv_data.append({
'特征类型': '协议分布',
'协议': protocol,
'数据包数量': count,
'占比': self.protocol_features.get('protocol_percentages', {}).get(protocol, 0)
})
# 协议统计
if 'protocol_statistics' in self.protocol_features:
for protocol, stats in self.protocol_features['protocol_statistics'].items():
csv_data.append({
'特征类型': '协议统计',
'协议': protocol,
'数据包数量': stats.get('packet_count', 0),
'总字节数': stats.get('total_bytes', 0),
'平均包大小': stats.get('avg_packet_size', 0),
'最大包大小': stats.get('max_packet_size', 0),
'最小包大小': stats.get('min_packet_size', 0),
'载荷字节数': stats.get('payload_bytes', 0),
'平均载荷大小': stats.get('avg_payload_size', 0)
})
# 协议组合
if 'protocol_combinations' in self.protocol_features:
for combination, count in self.protocol_features['protocol_combinations'].items():
csv_data.append({
'特征类型': '协议组合',
'协议组合': combination,
'连接数量': count
})
# 写入CSV
if csv_data:
df = pd.DataFrame(csv_data)
df.to_csv(output_file, index=False, encoding='utf-8')
def _export_temporal_features_csv(self, output_file: str):
"""导出时序特征到CSV"""
csv_data = []
# 时间分布
if 'time_distribution' in self.temporal_features:
for i, window in enumerate(self.temporal_features['time_distribution']):
csv_data.append({
'特征类型': '时间分布',
'时间窗口': i,
'开始时间': window.get('start_time', ''),
'结束时间': window.get('end_time', ''),
'数据包数量': window.get('packet_count', 0),
'总字节数': window.get('total_bytes', 0),
'唯一连接数': window.get('unique_connections', 0)
})
# 活跃度变化率
if 'activity_change_rates' in self.temporal_features:
change_rates = self.temporal_features['activity_change_rates']
csv_data.append({
'特征类型': '活跃度变化率',
'平均包变化率': change_rates.get('avg_packet_change_rate', 0),
'平均字节变化率': change_rates.get('avg_byte_change_rate', 0),
'最大包变化率': change_rates.get('max_packet_change_rate', 0),
'最大字节变化率': change_rates.get('max_byte_change_rate', 0)
})
# 流量峰值
if 'traffic_peaks' in self.temporal_features:
for peak in self.temporal_features['traffic_peaks']:
csv_data.append({
'特征类型': '流量峰值',
'时间索引': peak.get('time_index', 0),
'数据包数量': peak.get('packet_count', 0),
'字节数量': peak.get('byte_count', 0),
'是否包峰值': peak.get('is_packet_peak', False),
'是否字节峰值': peak.get('is_byte_peak', False)
})
# 写入CSV
if csv_data:
df = pd.DataFrame(csv_data)
df.to_csv(output_file, index=False, encoding='utf-8')
def _export_statistical_features_csv(self, output_file: str):
"""导出统计特征到CSV"""
csv_data = []
# 包大小统计
if 'size_statistics' in self.statistical_features:
size_stats = self.statistical_features['size_statistics']
# 数据包大小统计
packet_stats = size_stats.get('packet_size_stats', {})
csv_data.append({
'特征类型': '包大小统计',
'统计项': '数据包大小',
'数量': packet_stats.get('count', 0),
'平均值': packet_stats.get('mean', 0),
'标准差': packet_stats.get('std', 0),
'最小值': packet_stats.get('min', 0),
'最大值': packet_stats.get('max', 0),
'中位数': packet_stats.get('median', 0),
'25分位数': packet_stats.get('q25', 0),
'75分位数': packet_stats.get('q75', 0)
})
# 载荷大小统计
payload_stats = size_stats.get('payload_size_stats', {})
csv_data.append({
'特征类型': '载荷大小统计',
'统计项': '载荷大小',
'数量': payload_stats.get('count', 0),
'平均值': payload_stats.get('mean', 0),
'标准差': payload_stats.get('std', 0),
'最小值': payload_stats.get('min', 0),
'最大值': payload_stats.get('max', 0),
'中位数': payload_stats.get('median', 0),
'25分位数': payload_stats.get('q25', 0),
'75分位数': payload_stats.get('q75', 0)
})
# 速率方差
if 'rate_variance' in self.statistical_features:
rate_stats = self.statistical_features['rate_variance']
csv_data.append({
'特征类型': '速率方差',
'包速率方差': rate_stats.get('packet_rate_variance', 0),
'字节速率方差': rate_stats.get('byte_rate_variance', 0),
'包速率标准差': rate_stats.get('packet_rate_std', 0),
'字节速率标准差': rate_stats.get('byte_rate_std', 0),
'包速率平均值': rate_stats.get('packet_rate_mean', 0),
'字节速率平均值': rate_stats.get('byte_rate_mean', 0)
})
# 连接统计
if 'connection_statistics' in self.statistical_features:
conn_stats = self.statistical_features['connection_statistics']
csv_data.append({
'特征类型': '连接统计',
'总连接数': conn_stats.get('total_connections', 0),
'平均每连接包数': conn_stats.get('avg_packets_per_connection', 0),
'平均每连接字节数': conn_stats.get('avg_bytes_per_connection', 0),
'平均连接持续时间': conn_stats.get('avg_duration_per_connection', 0),
'最大连接持续时间': conn_stats.get('max_duration', 0),
'最小连接持续时间': conn_stats.get('min_duration', 0)
})
# 模式分析
if 'pattern_analysis' in self.statistical_features:
pattern_stats = self.statistical_features['pattern_analysis']
csv_data.append({
'特征类型': '模式分析',
'包自相关': pattern_stats.get('packet_autocorrelation', 0),
'字节自相关': pattern_stats.get('byte_autocorrelation', 0),
'流量规律性': pattern_stats.get('traffic_regularity', 0)
})
# 写入CSV
if csv_data:
df = pd.DataFrame(csv_data)
df.to_csv(output_file, index=False, encoding='utf-8')
def print_feature_summary(self):
"""打印特征提取摘要"""
print("\n" + "="*60)
print("特征提取摘要")
print("="*60)
if self.protocol_features:
print("协议特征:")
print(f" 协议类型数: {len(self.protocol_features.get('protocol_distribution', {}))}")
print(f" 协议分布: {self.protocol_features.get('protocol_distribution', {})}")
if self.temporal_features:
print("\n时序特征:")
time_windows = self.temporal_features.get('time_distribution', [])
print(f" 时间窗口数: {len(time_windows)}")
if 'activity_change_rates' in self.temporal_features:
change_rates = self.temporal_features['activity_change_rates']
print(f" 平均包变化率: {change_rates.get('avg_packet_change_rate', 0):.4f}")
print(f" 平均字节变化率: {change_rates.get('avg_byte_change_rate', 0):.4f}")
if self.statistical_features:
print("\n统计特征:")
size_stats = self.statistical_features.get('size_statistics', {})
if 'packet_size_stats' in size_stats:
packet_stats = size_stats['packet_size_stats']
print(f" 平均包大小: {packet_stats.get('mean', 0):.2f} 字节")
print(f" 包大小标准差: {packet_stats.get('std', 0):.2f}")
if 'rate_variance' in self.statistical_features:
rate_stats = self.statistical_features['rate_variance']
print(f" 包速率方差: {rate_stats.get('packet_rate_variance', 0):.2f}")
print(f" 字节速率方差: {rate_stats.get('byte_rate_variance', 0):.2f}")
print("="*60)
def main():
"""主函数"""
print("多模态网络流图数据抽取 - 初步特征提取模块支持CSV导出")
print("="*70)
# 查找最新的预处理CSV文件
import glob
import os
csv_files = glob.glob("preprocessing_results/cleaned_data_*.csv")
if not csv_files:
print("错误找不到预处理后的CSV文件")
print("请先运行数据预处理流水线")
return
# 使用最新的CSV文件
latest_csv = max(csv_files, key=os.path.getctime)
print(f"使用预处理数据文件: {latest_csv}")
try:
# 创建特征提取器
extractor = FeatureExtractor(latest_csv)
# 执行特征提取
print("\n1. 加载数据...")
extractor.load_data()
print("\n2. 提取所有特征...")
extractor.extract_all_features(time_window=60)
print("\n3. 导出特征...")
output_files = extractor.export_features()
print("\n4. 生成摘要...")
extractor.print_feature_summary()
print(f"\n特征提取完成!")
print("输出文件:")
for key, file_path in output_files.items():
print(f" {key}: {file_path}")
except Exception as e:
print(f"特征提取过程中出错: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()