"""Audio emotion-classification dataset utilities.

Loads ``.wav`` clips organized as ``root_dir/<label>/<file>.wav``,
resamples them to 16 kHz, random-crops/zero-pads them to 1.5 s, and
optionally converts them to normalized 128x128 MFCC "images" suitable
for CNN-style models.
"""

import os
import random

import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

# Target sample rate and clip duration shared by AudioDataset and
# process_audio_file.
_SAMPLE_RATE = 16000
_CLIP_SECONDS = 1.5


def _resample(waveform, sample_rate):
    """Resample *waveform* to the 16 kHz target rate.

    No-op when the audio is already at the target rate, avoiding the
    cost of building a Resample transform per call.
    """
    if sample_rate == _SAMPLE_RATE:
        return waveform
    return torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=_SAMPLE_RATE)(waveform)


def _fix_length(waveform):
    """Force *waveform* (channels, samples) to exactly 1.5 s.

    Longer clips are randomly cropped (simple data augmentation);
    shorter or exact-length clips are zero-padded on the right
    (padding of 0 when already exact).
    """
    length = int(_SAMPLE_RATE * _CLIP_SECONDS)
    if waveform.size(1) > length:
        start = random.randint(0, waveform.size(1) - length)
        return waveform[:, start:start + length]
    return torch.nn.functional.pad(waveform, (0, length - waveform.size(1)))


class AudioDataset(Dataset):
    """Dataset of emotion-labelled audio clips.

    Expects ``root_dir`` to contain one subdirectory per emotion label
    (angry / fear / happy / neutral / sad / surprise), each holding
    ``.wav`` files.

    Args:
        root_dir: Root directory of the dataset.
        use_mfcc: If True, ``__getitem__`` returns a normalized
            128x128 MFCC "image"; otherwise the raw 1.5 s waveform.
    """

    def __init__(self, root_dir, use_mfcc=False):
        self.root_dir = root_dir
        self.use_mfcc = use_mfcc
        self.file_paths = []
        self.labels = []
        # Fixed label -> class-index mapping; kept explicit so the
        # class indices are stable across runs and machines.
        self.label_to_int = {'angry': 0, 'fear': 1, 'happy': 2,
                             'neutral': 3, 'sad': 4, 'surprise': 5}
        self._load_dataset()
        # Resize the MFCC to 128x128 and normalize to roughly [-1, 1].
        self.spec_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        self.mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=_SAMPLE_RATE, n_mfcc=128)

    def _load_dataset(self):
        """Scan root_dir for ``<label>/<file>.wav`` entries and record them."""
        for label in os.listdir(self.root_dir):
            label_dir = os.path.join(self.root_dir, label)
            if not os.path.isdir(label_dir):
                continue
            for file_name in os.listdir(label_dir):
                if file_name.endswith('.wav'):
                    self.file_paths.append(os.path.join(label_dir, file_name))
                    self.labels.append(label)
        # Convert the collected string labels to integer class indices.
        self.labels = [self.label_to_int[label] for label in self.labels]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(tensor, label)`` for the clip at *idx*.

        The tensor is either the 1.5 s 16 kHz waveform or, when
        ``use_mfcc`` is set, a normalized 128x128 MFCC image.
        """
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        waveform, sample_rate = torchaudio.load(file_path)
        waveform = _fix_length(_resample(waveform, sample_rate))
        if self.use_mfcc:
            waveform = self.mfcc_transform(waveform)
            waveform = self.spec_transform(waveform)
        return waveform, label


def process_audio_file(file_path, use_mfcc):
    """Standalone version of ``AudioDataset.__getitem__``'s preprocessing.

    Loads one ``.wav`` file, resamples it to 16 kHz, random-crops or
    zero-pads it to 1.5 s and, when *use_mfcc* is True, converts it to
    a normalized 128x128 MFCC image.

    Args:
        file_path: Path to a ``.wav`` file.
        use_mfcc: Whether to apply the MFCC + resize/normalize pipeline.

    Returns:
        A tensor: the processed waveform or the MFCC image.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    waveform = _fix_length(_resample(waveform, sample_rate))
    if use_mfcc:
        mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=_SAMPLE_RATE, n_mfcc=128)
        spec_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        waveform = mfcc_transform(waveform)
        waveform = spec_transform(waveform)
    return waveform


if __name__ == '__main__':
    # Example usage. os.path.join keeps the path portable (the original
    # used a Windows-only backslash literal).
    root_dir = os.path.join('dataset', 'train')
    dataset = AudioDataset(root_dir, use_mfcc=True)
    print(len(dataset))
    waveform, label = dataset[123]
    print(waveform.shape, label)