You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

297 lines
10 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
mmWave_RA_RE.py - 深度学习特征生成基线脚本
功能: 从雷达数据立方体生成 Range-Doppler, Range-Azimuth, Range-Elevation 特征图
输入: Cube_data 目录下的 *_cube.npy 文件 [Frames, VirtRx, Chirps, Samples]
输出:
- *_features.npz: 包含 rd, ra, re 三个数组
- *_merged.npy: 拼接后的 [Frames, Range, Channels] 张量
"""
import os
import sys
import glob
import json
import numpy as np
from scipy.signal.windows import taylor, hamming
from tqdm import tqdm
# 导入配置解析器
try:
from process_data import parse_radar_cfg, parse_tx_order
except ImportError:
try:
from src.process_data import parse_radar_cfg, parse_tx_order
except ImportError:
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from process_data import parse_radar_cfg, parse_tx_order
# ===================================================================
# 配置区域
# ===================================================================
# Input / output locations. These are absolute, machine-specific Windows paths.
# NOTE(review): the input path points at a "lab" recording while the output path
# points at a "dormitory" folder - confirm this pairing is intentional.
INPUT_CUBE_DIR = r"F:\test\Data_bin\lab\Cube_data"
OUTPUT_FEATURE_DIR = r"F:\Data_bin\dormitory\Heatmap_Features_25"
# NOTE(review): CFG_PATH is not referenced anywhere in the visible part of this file.
CFG_PATH = r"F:\test\cfg\Radar.cfg"
# Algorithm parameters
DOWNSAMPLE_FACTOR = 4 # Frame decimation factor (1 = no decimation)
ANGLE_FFT_SIZE = 64 # Number of angle-FFT points
RANGE_WINDOW = "hamming" # Window applied before the range FFT
DOPPLER_WINDOW = "hamming" # Window applied before the Doppler FFT
# IWR6843ISK antenna topology (as described by the original author)
IDXS_AZIMUTH = np.arange(0, 8) # Horizontal array: TX0 (0-3) + TX2 (4-7)
# Vertical pairs (same horizontal position): (2,8), (3,9), (4,10), (5,11) - see elevation_phase_diff()
# ===================================================================
# 窗函数缓存 (避免重复计算)
# ===================================================================
_WINDOW_CACHE = {}


def get_window(name: str, length: int) -> np.ndarray:
    """Return the window called *name* with *length* taps, memoised per (name, length).

    Supported names are ``'hamming'`` and ``'taylor'`` (nbar=3, sll=40). Any
    other name yields an all-ones (rectangular) window.

    The returned array is the cached object itself, so callers must not
    modify it in place.

    Args:
        name: Window identifier.
        length: Number of samples in the window.

    Returns:
        A 1-D array of ``length`` window coefficients.
    """
    key = (name, length)
    cached = _WINDOW_CACHE.get(key)
    if cached is not None:
        return cached

    if name == 'hamming':
        window = hamming(length)
    elif name == 'taylor':
        try:
            window = taylor(length, nbar=3, sll=40)
        except Exception:
            # Deliberate best-effort: fall back to Hamming if the Taylor
            # design fails. Catch Exception rather than everything so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            window = hamming(length)
    else:
        window = np.ones(length)

    _WINDOW_CACHE[key] = window
    return window
# ===================================================================
# 核心处理函数
# ===================================================================
def normalize_complex_global(data, global_max=None):
    """Scale the magnitude of complex *data* by a global constant, keeping the phase.

    Args:
        data: Complex-valued array-like.
        global_max: Divisor for the magnitude. Any falsy value (``None`` or
            ``0``) selects the built-in default of ``1e4``.

    Returns:
        Complex array with the phase of *data* and magnitude
        ``|data| / (divisor + 1e-12)``.
    """
    if global_max:
        divisor = global_max
    else:
        divisor = 1e4

    # Split into polar form, rescale the radius only, then recombine.
    unit_phasor = np.exp(1j * np.angle(data))
    scaled_magnitude = np.abs(data) / (divisor + 1e-12)
    return scaled_magnitude * unit_phasor
def range_fft(cube: np.ndarray, window_type: str = 'hamming') -> np.ndarray:
    """Window the last axis of *cube* and take an FFT along it.

    Args:
        cube: Radar cube, documented by the caller as
            [Frames, VirtRx, Chirps, Samples].
        window_type: Name passed to ``get_window`` for the fast-time taper.

    Returns:
        Complex array [Frames, VirtRx, Chirps, RangeBins]; the FFT length
        equals the number of samples.
    """
    taper = get_window(window_type, cube.shape[-1])
    # Broadcast the 1-D taper across the three leading axes.
    return np.fft.fft(cube * taper[None, None, None, :], axis=-1)
def doppler_fft(range_cube: np.ndarray, window_type: str = 'hamming') -> np.ndarray:
    """Window axis 2 of *range_cube*, FFT along it and centre the zero bin.

    Args:
        range_cube: Range-transformed cube, documented by the caller as
            [Frames, VirtRx, Chirps, RangeBins].
        window_type: Name passed to ``get_window`` for the slow-time taper.

    Returns:
        Complex array [Frames, VirtRx, Doppler, RangeBins] with the Doppler
        axis fft-shifted.
    """
    taper = get_window(window_type, range_cube.shape[2])
    # Place the taper on axis 2 and broadcast over the remaining axes.
    spectrum = np.fft.fft(range_cube * taper[None, None, :, None], axis=2)
    return np.fft.fftshift(spectrum, axes=2)
def clutter_removal(data: np.ndarray, axis: int = 2) -> np.ndarray:
    """Remove the static component by subtracting the mean taken along *axis*.

    Args:
        data: Input array.
        axis: Axis along which the mean is computed (default 2).

    Returns:
        Array of the same shape as *data* whose mean along *axis* is zero.
    """
    # keepdims=True keeps the reduced axis so the subtraction broadcasts.
    return np.subtract(data, np.mean(data, axis=axis, keepdims=True))
def azimuth_fft(range_cube: np.ndarray, num_angle_bins: int = 64) -> np.ndarray:
    """Build a Range-Azimuth magnitude map from the horizontal virtual array.

    Args:
        range_cube: Range-transformed cube, documented by the caller as
            [Frames, VirtRx, Chirps, Range].
        num_angle_bins: Zero-padded length of the angle FFT.

    Returns:
        Real array [Frames, Angle, Range]: FFT magnitude (not power)
        averaged over the chirp axis.
    """
    # Keep only the channels listed in IDXS_AZIMUTH (the horizontal array).
    horizontal = range_cube[:, IDXS_AZIMUTH, :, :]

    # Taylor taper across the antenna axis for side-lobe suppression.
    taper = get_window('taylor', horizontal.shape[1])
    tapered = horizontal * taper[None, :, None, None]

    # Zero-padded FFT over antennas, with the zero angle bin centred.
    spectrum = np.fft.fftshift(
        np.fft.fft(tapered, n=num_angle_bins, axis=1),
        axes=1,
    )

    # Non-coherent average of the magnitude over chirps.
    return np.abs(spectrum).mean(axis=2)
def elevation_phase_diff(range_cube, num_angle_bins=64):
    """Build a Range-Elevation magnitude map from four vertical antenna pairs.

    Antenna layout assumed by the original author for the IWR6843ISK:
      - Base row (Z=0): indices 0-3 (TX0) and 4-7 (TX2)
      - Elevated row (Z=1): indices 8-11 (TX1, shifted right by 2 units)
    which gives four pairs sharing the same horizontal position:
      - X=2: index 2 (TX0_RX2) & index 8  (TX1_RX0)
      - X=3: index 3 (TX0_RX3) & index 9  (TX1_RX1)
      - X=4: index 4 (TX2_RX0) & index 10 (TX1_RX2)
      - X=5: index 5 (TX2_RX1) & index 11 (TX1_RX3)

    Args:
        range_cube: Range-transformed cube, documented by the caller as
            [Frames, VirtRx, Chirps, Range]; needs at least 12 channels.
        num_angle_bins: Zero-padded length of the 2-element angle FFT.

    Returns:
        Real array [Frames, Angle, Range]: FFT magnitude averaged over
        chirps and then over the four pairs.
    """
    elevation_pairs = [(2, 8), (3, 9), (4, 10), (5, 11)]

    # The taper depends only on the pair size (2), so build it once.
    # NOTE(review): a symmetric 2-point Hamming window evaluates to
    # [0.08, 0.08], i.e. it acts as a uniform gain rather than a taper -
    # confirm this scaling is intended.
    pair_taper = get_window('hamming', 2).reshape(1, -1, 1, 1)

    pair_maps = []
    for lower_idx, upper_idx in elevation_pairs:
        # [Frames, 2, Chirps, Range] for this vertically aligned pair.
        pair = range_cube[:, [lower_idx, upper_idx], :, :] * pair_taper

        # Zero-pad the 2-element FFT to num_angle_bins for a smooth
        # (interpolated) angular response; centre the zero bin.
        spectrum = np.fft.fftshift(
            np.fft.fft(pair, n=num_angle_bins, axis=1),
            axes=1,
        )

        # Magnitude is not normalised per pair, to avoid amplifying noise in
        # weak returns; average non-coherently over chirps.
        pair_maps.append(np.mean(np.abs(spectrum), axis=2))

    # Average the four pair maps to improve SNR.
    return np.mean(np.stack(pair_maps, axis=0), axis=0)
# ===================================================================
# 主处理流程
# ===================================================================
def process_single_file(cube_path: str, output_root: str) -> None:
    """Convert one ``*_cube.npy`` file into RD / RA / RE feature maps on disk.

    Writes two files directly into *output_root* (no per-sample sub-folder):
      - ``<name>_features.npz`` with arrays ``rd``, ``ra`` and ``re``
      - ``<name>_merged.npy``: float32 tensor [Frames, Range, D + A + E]

    Load and validation problems are reported through ``tqdm.write`` and the
    file is skipped; they are not raised.

    Args:
        cube_path: Path to a cube file shaped [Frames, VirtRx, Chirps, Samples].
        output_root: Existing directory that receives the outputs.
    """
    filename = os.path.basename(cube_path).replace('_cube.npy', '')
    out_dir = output_root  # flat layout; the directory is created by main()

    # Load
    try:
        cube = np.load(cube_path)
    except Exception as e:
        tqdm.write(f" [Error] 加载 {filename} 失败: {e}")
        return

    # Validate (on the full cube, before any frames are dropped)
    if cube.ndim != 4:
        tqdm.write(f" [Error] {filename} 维度错误: {cube.shape}")
        return
    n_rx = cube.shape[1]
    if n_rx < 12:
        tqdm.write(f" [Error] {filename} 通道数不足: {n_rx} < 12")
        return
    # One pass: isfinite is False exactly for NaN and +/-Inf.
    if not np.isfinite(cube).all():
        tqdm.write(f" [Error] {filename} 包含 NaN 或 Inf")
        return

    # Frame decimation. Every stage below operates independently per frame
    # (no reduction or transform touches axis 0), so dropping frames first
    # yields the same maps as transforming everything and slicing afterwards,
    # while avoiding the FFT work and memory for frames that are discarded.
    if DOWNSAMPLE_FACTOR > 1:
        cube = cube[::DOWNSAMPLE_FACTOR]

    # 1. Range FFT
    range_res = range_fft(cube, window_type=RANGE_WINDOW)
    # 2. Static clutter removal (mean over the chirp axis)
    range_res = clutter_removal(range_res, axis=2)
    # 3. Doppler FFT -> Range-Doppler map [Frames, Doppler, Range]
    dop_res = doppler_fft(range_res, window_type=DOPPLER_WINDOW)
    rd_map = np.mean(np.abs(dop_res), axis=1)
    # 4. Azimuth FFT -> Range-Azimuth map [Frames, Angle, Range]
    ra_map = azimuth_fft(range_res, ANGLE_FFT_SIZE)
    # 5. Elevation FFT -> Range-Elevation map [Frames, Angle, Range]
    re_map = elevation_phase_diff(range_res, ANGLE_FFT_SIZE)

    # 6. Save the individual feature maps (NPZ)
    npz_path = os.path.join(out_dir, f"{filename}_features.npz")
    np.savez_compressed(npz_path, rd=rd_map, ra=ra_map, re=re_map)

    # 7. Save the merged tensor (NPY) for Transformer / 3D-CNN style inputs:
    #    move Range to axis 1 and concatenate features on the last axis.
    rd_T = np.transpose(rd_map, (0, 2, 1))  # [F, R, D]
    ra_T = np.transpose(ra_map, (0, 2, 1))  # [F, R, A]
    re_T = np.transpose(re_map, (0, 2, 1))  # [F, R, E]
    merged = np.concatenate([rd_T, ra_T, re_T], axis=-1)
    merge_path = os.path.join(out_dir, f"{filename}_merged.npy")
    np.save(merge_path, merged.astype(np.float32))
    # No metadata.json is written.
def main():
    """Process every ``*_cube.npy`` under INPUT_CUBE_DIR into OUTPUT_FEATURE_DIR.

    Files are handled by a small process pool; if the pool raises, all files
    are processed again serially with per-file error reporting.
    """
    # Make sure the output directory exists.
    os.makedirs(OUTPUT_FEATURE_DIR, exist_ok=True)

    # Collect input files in a deterministic order.
    pattern = os.path.join(INPUT_CUBE_DIR, "*_cube.npy")
    cube_files = sorted(glob.glob(pattern))
    if not cube_files:
        print(f"未找到数据文件: {INPUT_CUBE_DIR}")
        return

    total = len(cube_files)
    print(f"发现 {total} 个数据文件")
    print(f"输出目录: {OUTPUT_FEATURE_DIR}")
    print(f"降采样因子: {DOWNSAMPLE_FACTOR}, 角度FFT: {ANGLE_FFT_SIZE}")
    print("-" * 50)

    # Parallel processing with at most four workers.
    from multiprocessing import Pool, cpu_count
    from functools import partial

    num_workers = min(4, cpu_count())
    print(f"使用 {num_workers} 个并行进程")

    # Bind the output directory so the pool only has to pass the file path.
    worker = partial(process_single_file, output_root=OUTPUT_FEATURE_DIR)

    try:
        with Pool(num_workers) as pool:
            # imap yields results in order, which lets tqdm show progress.
            for _ in tqdm(pool.imap(worker, cube_files), total=total, desc="Processing"):
                pass
    except Exception as pool_err:
        print(f"并行处理出错: {pool_err}")
        print("回退到串行处理...")
        for path in tqdm(cube_files, desc="Processing"):
            try:
                process_single_file(path, OUTPUT_FEATURE_DIR)
            except Exception as file_err:
                tqdm.write(f"[Error] 处理 {os.path.basename(path)} 失败: {file_err}")


if __name__ == "__main__":
    main()