You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

92 lines
3.5 KiB

import os
import random
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
class AudioDataset(Dataset):
    """Speech-emotion dataset reading .wav files from per-label subfolders.

    Expected layout: ``root_dir/<label>/<file>.wav`` where ``<label>`` is one
    of angry / fear / happy / neutral / sad / surprise.  Each item is a
    fixed-length (1.5 s @ 16 kHz) mono clip; with ``use_mfcc=True`` it is
    returned as a 128x128 normalized MFCC "image", otherwise as the raw
    cropped/padded waveform tensor.
    """

    def __init__(self, root_dir, use_mfcc=False):
        self.root_dir = root_dir
        self.use_mfcc = use_mfcc
        self.file_paths = []
        self.labels = []
        # Fixed label -> class-index mapping (alphabetical order).
        self.label_to_int = {'angry': 0, 'fear': 1, 'happy': 2,
                             'neutral': 3, 'sad': 4, 'surprise': 5}
        self._load_dataset()
        # Applied to the MFCC matrix: resize to 128x128, then normalize the
        # single channel to roughly [-1, 1].
        self.spec_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        self.mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000,
                                                         n_mfcc=128)

    def _load_dataset(self):
        """Collect .wav paths and integer labels from root_dir subfolders."""
        for label in os.listdir(self.root_dir):
            label_dir = os.path.join(self.root_dir, label)
            # Skip stray files and directories that are not known labels
            # (the original code raised KeyError on unknown folder names).
            if not os.path.isdir(label_dir) or label not in self.label_to_int:
                continue
            for file_name in os.listdir(label_dir):
                if file_name.endswith('.wav'):
                    self.file_paths.append(os.path.join(label_dir, file_name))
                    self.labels.append(self.label_to_int[label])

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for sample *idx*.

        Features are a 128x128 normalized MFCC tensor when ``use_mfcc`` is
        True, otherwise the raw 1x24000 waveform (cropped/padded to 1.5 s).
        """
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        waveform, sample_rate = torchaudio.load(file_path)
        # Downmix multi-channel audio to mono: Normalize(mean=[0.5]) below
        # is single-channel and would fail on stereo input.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Resample to 16 kHz (skip the no-op when already at 16 kHz).
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=16000)(waveform)
        sample_rate = 16000
        # Random 1.5 s crop for data augmentation; zero-pad short clips.
        length = int(sample_rate * 1.5)
        if waveform.size(1) > length:
            start = random.randint(0, waveform.size(1) - length)
            waveform = waveform[:, start:start + length]
        else:
            padding = length - waveform.size(1)
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        if self.use_mfcc:
            waveform = self.mfcc_transform(waveform)
            waveform = self.spec_transform(waveform)
        return waveform, label
def process_audio_file(file_path, use_mfcc):
    """Load one .wav file and preprocess it like ``AudioDataset`` does.

    Args:
        file_path: path to a .wav file.
        use_mfcc: when True, return a 128x128 normalized MFCC tensor;
            otherwise return the raw cropped/padded 1x24000 waveform.

    Returns:
        A mono feature tensor for a fixed 1.5 s segment at 16 kHz.
        Note: the crop position is random, so repeated calls on a long
        file may return different segments.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    # Downmix to mono: the single-channel Normalize below cannot handle
    # stereo input.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample to 16 kHz (skip the no-op when already at 16 kHz).
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000)(waveform)
    sample_rate = 16000
    # Random 1.5 s crop; zero-pad clips shorter than that.
    length = int(sample_rate * 1.5)
    if waveform.size(1) > length:
        start = random.randint(0, waveform.size(1) - length)
        waveform = waveform[:, start:start + length]
    else:
        padding = length - waveform.size(1)
        waveform = torch.nn.functional.pad(waveform, (0, padding))
    if use_mfcc:
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000,
                                                    n_mfcc=128)
        spec_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        waveform = mfcc_transform(waveform)
        waveform = spec_transform(waveform)
    return waveform
if __name__ == '__main__':
    # Quick smoke test of the dataset pipeline.
    train_dir = r'dataset\train'
    ds = AudioDataset(train_dir, use_mfcc=True)
    print(len(ds))
    features, target = ds[123]
    print(features.shape, target)