parent
6cc65e11ee
commit
783a656688
@ -0,0 +1,92 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.utils.data import Dataset, DataLoader, random_split
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
class AudioDataset(Dataset):
    """Emotion-labelled audio clips read from a directory tree.

    ``root_dir`` is scanned for sub-directories. Each sub-directory name is
    used as a label and every ``.wav`` file directly inside it becomes one
    sample. Each item is resampled to 16 kHz, then cropped at a random offset
    or right-padded so that it is exactly 1.5 s long. When ``use_mfcc`` is
    true the clip is additionally converted to MFCC features, resized to
    128x128 and normalised with mean 0.5 / std 0.5.

    Args:
        root_dir: Directory containing one sub-directory per label.
        use_mfcc: If true, items are MFCC feature maps; otherwise the
            fixed-length waveform is returned as is.

    Raises:
        KeyError: If a sub-directory that contains ``.wav`` files has a name
            that is not a key of ``label_to_int``.
    """

    # Sample rate every clip is converted to, and the fixed clip duration.
    TARGET_SAMPLE_RATE = 16000
    CLIP_SECONDS = 1.5

    def __init__(self, root_dir, use_mfcc=False):
        self.root_dir = root_dir
        self.use_mfcc = use_mfcc
        self.file_paths = []
        self.labels = []
        self.label_to_int = {
            'angry': 0,
            'fear': 1,
            'happy': 2,
            'neutral': 3,
            'sad': 4,
            'surprise': 5,
        }
        # One Resample module per source sample rate, built lazily so the
        # resampling kernel is not recomputed for every item.
        self._resamplers = {}
        self._load_dataset()
        self.spec_transform = transforms.Compose([
            # Resize to 128x128.
            transforms.Resize((128, 128)),
            # Normalise.
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        self.mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=self.TARGET_SAMPLE_RATE, n_mfcc=128
        )

    def _load_dataset(self):
        """Fill ``file_paths`` and ``labels`` from the directory tree.

        Directory and file listings are sorted so that sample indices are
        reproducible; ``os.listdir`` itself gives no ordering guarantee.
        """
        for label in sorted(os.listdir(self.root_dir)):
            label_dir = os.path.join(self.root_dir, label)
            if not os.path.isdir(label_dir):
                continue
            for file_name in sorted(os.listdir(label_dir)):
                if file_name.endswith('.wav'):
                    self.file_paths.append(os.path.join(label_dir, file_name))
                    # Store the integer class id directly.
                    self.labels.append(self.label_to_int[label])

    def __len__(self):
        """Return the number of discovered ``.wav`` files."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(features_or_waveform, label)``."""
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        waveform, sample_rate = torchaudio.load(file_path)
        # Bring the audio to the common 16 kHz sample rate.
        waveform = self._resample(waveform, sample_rate)
        # Randomly crop (or pad) to exactly 1.5 s.
        length = int(self.TARGET_SAMPLE_RATE * self.CLIP_SECONDS)
        waveform = self._fit_length(waveform, length)
        if self.use_mfcc:
            waveform = self.mfcc_transform(waveform)
            waveform = self.spec_transform(waveform)
        return waveform, label

    def _resample(self, waveform, sample_rate):
        """Resample ``waveform`` from ``sample_rate`` to the target rate."""
        resampler = self._resamplers.get(sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.TARGET_SAMPLE_RATE
            )
            self._resamplers[sample_rate] = resampler
        return resampler(waveform)

    @staticmethod
    def _fit_length(waveform, length):
        """Return ``waveform`` with exactly ``length`` entries along dim 1.

        Longer inputs are cropped at a random start offset; shorter (or
        equal-length) inputs are padded at the end.
        """
        num_samples = waveform.size(1)
        if num_samples > length:
            start = random.randint(0, num_samples - length)
            return waveform[:, start:start + length]
        # Not long enough: pad on the right up to the requested length.
        return torch.nn.functional.pad(waveform, (0, length - num_samples))
|
||||||
|
|
||||||
|
|
||||||
|
def process_audio_file(file_path, use_mfcc):
    """Load a single audio file and turn it into a fixed-length model input.

    The audio is resampled to 16 kHz and then either cropped at a random
    offset or right-padded so that it holds exactly 1.5 s of samples.

    Args:
        file_path: Path of the audio file, passed to ``torchaudio.load``.
        use_mfcc: If truthy, the clip is converted to MFCC features, resized
            to 128x128 and normalised (mean 0.5, std 0.5) before returning.

    Returns:
        The processed tensor: the fixed-length waveform, or the resized and
        normalised MFCC feature map when ``use_mfcc`` is truthy.
    """
    target_rate = 16000
    audio, source_rate = torchaudio.load(file_path)
    to_target_rate = torchaudio.transforms.Resample(
        orig_freq=source_rate, new_freq=target_rate
    )
    audio = to_target_rate(audio)

    wanted = int(target_rate * 1.5)
    available = audio.size(1)
    if available <= wanted:
        # Too short (or exact): pad the tail up to the wanted length.
        audio = torch.nn.functional.pad(audio, (0, wanted - available))
    else:
        # Too long: keep a randomly positioned window of the wanted length.
        offset = random.randint(0, available - wanted)
        audio = audio[:, offset:offset + wanted]

    if not use_mfcc:
        return audio

    to_mfcc = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)
    resize_and_normalise = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    return resize_and_normalise(to_mfcc(audio))
|
||||||
|
if __name__ == '__main__':
    # Usage example: build the training dataset and inspect one sample.
    # os.path.join keeps the path valid on non-Windows systems as well.
    root_dir = os.path.join('dataset', 'train')
    dataset = AudioDataset(root_dir, use_mfcc=True)
    print(len(dataset))
    if len(dataset) > 0:
        # Clamp the demo index so small datasets do not raise IndexError.
        waveform, label = dataset[min(123, len(dataset) - 1)]
        print(waveform.shape, label)
|
Loading…
Reference in new issue