diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..247f422
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,100 @@
+import os
+import random
+import torch
+import torchaudio
+from torch.utils.data import Dataset, DataLoader, random_split
+from torchvision import transforms
+
+class AudioDataset(Dataset):
+    def __init__(self, root_dir, use_mfcc=False):
+        self.root_dir = root_dir
+        self.use_mfcc = use_mfcc
+        self.file_paths = []
+        self.labels = []
+        self.label_to_int = {'angry': 0, 'fear': 1, 'happy': 2, 'neutral': 3, 'sad': 4, 'surprise': 5}
+        self._load_dataset()
+        self.spec_transform = transforms.Compose([
+            # Resize to 128x128
+            transforms.Resize((128, 128)),
+            # Normalize with mean=0.5, std=0.5
+            transforms.Normalize(mean=[0.5], std=[0.5])
+        ])
+        self.mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)
+
+    def _load_dataset(self):
+        # Walk the root directory: each subdirectory name is an emotion label
+        for label in os.listdir(self.root_dir):
+            label_dir = os.path.join(self.root_dir, label)
+            if os.path.isdir(label_dir):
+                for file_name in os.listdir(label_dir):
+                    if file_name.endswith('.wav'):
+                        self.file_paths.append(os.path.join(label_dir, file_name))
+                        self.labels.append(label)
+
+        # Convert string labels to integers
+        self.labels = [self.label_to_int[label] for label in self.labels]
+
+    def __len__(self):
+        return len(self.file_paths)
+
+    def __getitem__(self, idx):
+        file_path = self.file_paths[idx]
+        label = self.labels[idx]
+        waveform, sample_rate = torchaudio.load(file_path)
+        # Resample the audio to 16 kHz
+        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+        sample_rate = 16000
+        # Randomly crop to 1.5 s
+        length = int(sample_rate * 1.5)
+        if waveform.size(1) > length:
+            start = random.randint(0, waveform.size(1) - length)
+            waveform = waveform[:, start:start + length]
+        else:
+            # Pad audio shorter than 1.5 s
+            padding = length - waveform.size(1)
+            waveform = torch.nn.functional.pad(waveform, (0, padding))
+        if self.use_mfcc:
+            waveform = self.mfcc_transform(waveform)
+            waveform = self.spec_transform(waveform)
+
+        return waveform, label
+
+
+def process_audio_file(file_path, use_mfcc):
+    waveform, sample_rate = torchaudio.load(file_path)
+    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+    sample_rate = 16000
+    length = int(sample_rate * 1.5)
+    if waveform.size(1) > length:
+        start = random.randint(0, waveform.size(1) - length)
+        waveform = waveform[:, start:start + length]
+    else:
+        padding = length - waveform.size(1)
+        waveform = torch.nn.functional.pad(waveform, (0, padding))
+    if use_mfcc:
+        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)
+        spec_transform = transforms.Compose([
+            transforms.Resize((128, 128)),
+            transforms.Normalize(mean=[0.5], std=[0.5])
+        ])
+        waveform = mfcc_transform(waveform)
+        waveform = spec_transform(waveform)
+    return waveform
+
+
+if __name__ == '__main__':
+    # Usage example
+    root_dir = r'dataset\train'
+    dataset = AudioDataset(root_dir, use_mfcc=True)
+    print(len(dataset))
+    waveform, label = dataset[123]
+    print(waveform.shape, label)
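+
+    # A minimal usage sketch, not part of the original module: wrap the dataset
+    # with the random_split/DataLoader utilities imported above. The 80/20
+    # split and batch_size=32 are illustrative assumptions.
+    n_train = int(0.8 * len(dataset))
+    train_set, val_set = random_split(dataset, [n_train, len(dataset) - n_train])
+    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
+    batch_waveforms, batch_labels = next(iter(train_loader))
+    print(batch_waveforms.shape, batch_labels.shape)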