parent
6cc65e11ee
commit
783a656688
@ -0,0 +1,92 @@
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
from torchvision import transforms
|
||||
|
||||
class AudioDataset(Dataset):
    """Emotion-labelled speech dataset.

    Scans ``root_dir`` for per-label subdirectories of ``.wav`` files and
    yields fixed-length clips (1.5 s at 16 kHz). When ``use_mfcc`` is True,
    each clip is converted to a 128x128, roughly [-1, 1]-normalized MFCC
    "image" suitable for CNN input; otherwise the raw waveform tensor is
    returned.
    """

    def __init__(self, root_dir, use_mfcc=False):
        """
        Args:
            root_dir: Directory whose immediate subdirectory names are the
                emotion labels; each subdirectory holds ``.wav`` files.
            use_mfcc: If True, ``__getitem__`` returns the resized/normalized
                MFCC tensor instead of the raw waveform.
        """
        self.root_dir = root_dir
        self.use_mfcc = use_mfcc
        self.file_paths = []
        self.labels = []
        # Fixed label vocabulary. A subdirectory with any other name would
        # raise KeyError in _load_dataset, which surfaces bad data early.
        self.label_to_int = {'angry': 0, 'fear': 1, 'happy': 2,
                             'neutral': 3, 'sad': 4, 'surprise': 5}
        self._load_dataset()
        self.spec_transform = transforms.Compose([
            # Resize the MFCC matrix to a fixed 128x128 "image".
            transforms.Resize((128, 128)),
            # Shift/scale values toward [-1, 1] (single channel).
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        self.mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)

    def _load_dataset(self):
        """Collect every ``.wav`` path under ``root_dir`` and its integer label."""
        # NOTE: the original version also built an unused ``label_set``;
        # that dead code has been removed.
        for label in os.listdir(self.root_dir):
            label_dir = os.path.join(self.root_dir, label)
            if os.path.isdir(label_dir):
                for file_name in os.listdir(label_dir):
                    if file_name.endswith('.wav'):
                        self.file_paths.append(os.path.join(label_dir, file_name))
                        self.labels.append(label)
        # Convert the string labels to integer class ids.
        self.labels = [self.label_to_int[label] for label in self.labels]

    def __len__(self):
        """Return the number of audio files discovered."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load, resample, crop/pad, and optionally MFCC-transform one clip.

        Returns:
            (tensor, int): the processed waveform (or MFCC image when
            ``use_mfcc`` is set) and the integer class label.
        """
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        waveform, sample_rate = torchaudio.load(file_path)
        # Resample everything to a common 16 kHz rate.
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
        sample_rate = 16000
        # Random 1.5 s crop — acts as data augmentation, so repeated reads
        # of the same index may return different crops.
        length = int(sample_rate * 1.5)
        if waveform.size(1) > length:
            start = random.randint(0, waveform.size(1) - length)
            waveform = waveform[:, start:start + length]
        else:
            # Shorter than 1.5 s: zero-pad on the right up to the target length.
            padding = length - waveform.size(1)
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        if self.use_mfcc:
            waveform = self.mfcc_transform(waveform)
            waveform = self.spec_transform(waveform)
        return waveform, label
|
||||
|
||||
|
||||
def process_audio_file(file_path, use_mfcc):
    """Preprocess a single audio file the same way ``AudioDataset`` does.

    Loads the file, resamples it to 16 kHz, randomly crops (or right-pads)
    it to exactly 1.5 s, and — when ``use_mfcc`` is True — converts it to a
    normalized 128x128 MFCC tensor.

    Args:
        file_path: Path to a ``.wav`` file readable by torchaudio.
        use_mfcc: Whether to apply the MFCC + resize/normalize pipeline.

    Returns:
        The processed waveform tensor (raw or MFCC image).
    """
    audio, orig_sr = torchaudio.load(file_path)
    target_sr = 16000
    audio = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)(audio)

    # Force a fixed 1.5 s duration: crop at a random offset when too long,
    # zero-pad on the right when too short.
    clip_len = int(target_sr * 1.5)
    n_samples = audio.size(1)
    if n_samples > clip_len:
        offset = random.randint(0, n_samples - clip_len)
        audio = audio[:, offset:offset + clip_len]
    else:
        audio = torch.nn.functional.pad(audio, (0, clip_len - n_samples))

    if not use_mfcc:
        return audio

    mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)
    spec_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    return spec_transform(mfcc_transform(audio))
|
||||
if __name__ == '__main__':
    # Usage example: build the training dataset and sanity-check one item.
    # NOTE(review): raw string with backslash is a Windows-style relative
    # path — confirm it resolves on the deployment OS.
    root_dir = r'dataset\train'
    dataset = AudioDataset(root_dir, use_mfcc=True)
    print(len(dataset))
    # Assumes the dataset contains at least 124 files — IndexError otherwise.
    waveform, label = dataset[123]
    print(waveform.shape, label)
|
Loading…
Reference in new issue