import os
import random

import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms


class AudioDataset(Dataset):
    def __init__(self, root_dir, use_mfcc=False):
        self.root_dir = root_dir
        self.use_mfcc = use_mfcc
        self.file_paths = []
        self.labels = []
        self.label_to_int = {'angry': 0, 'fear': 1, 'happy': 2, 'neutral': 3, 'sad': 4, 'surprise': 5}
        self._load_dataset()
        self.spec_transform = transforms.Compose([
            # Resize to 128x128
            transforms.Resize((128, 128)),
            # Normalize
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        self.mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)

    def _load_dataset(self):
        # Walk root_dir: each subdirectory name is a class label containing .wav files
        for label in os.listdir(self.root_dir):
            label_dir = os.path.join(self.root_dir, label)
            if os.path.isdir(label_dir):
                for file_name in os.listdir(label_dir):
                    if file_name.endswith('.wav'):
                        self.file_paths.append(os.path.join(label_dir, file_name))
                        self.labels.append(label)
        # Convert string labels to integers
        self.labels = [self.label_to_int[label] for label in self.labels]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        waveform, sample_rate = torchaudio.load(file_path)
        # Resample the audio to 16 kHz
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
        sample_rate = 16000
        # Randomly crop to 1.5 s
        length = int(sample_rate * 1.5)
        if waveform.size(1) > length:
            start = random.randint(0, waveform.size(1) - length)
            waveform = waveform[:, start:start + length]
        else:
            # Pad with trailing zeros if the clip is shorter than 1.5 s
            padding = length - waveform.size(1)
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        if self.use_mfcc:
            waveform = self.mfcc_transform(waveform)
            waveform = self.spec_transform(waveform)
        return waveform, label


def process_audio_file(file_path, use_mfcc):
    # Standalone preprocessing for a single file (e.g. at inference time),
    # mirroring the resample/crop/pad/MFCC pipeline in AudioDataset.__getitem__
    waveform, sample_rate = torchaudio.load(file_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    sample_rate = 16000
    length = int(sample_rate * 1.5)
    if waveform.size(1) > length:
        start = random.randint(0, waveform.size(1) - length)
        waveform = waveform[:, start:start + length]
    else:
        padding = length - waveform.size(1)
        waveform = torch.nn.functional.pad(waveform, (0, padding))
    if use_mfcc:
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)
        spec_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        waveform = mfcc_transform(waveform)
        waveform = spec_transform(waveform)
    return waveform


if __name__ == '__main__':
    # Usage example
    root_dir = r'dataset\train'
    dataset = AudioDataset(root_dir, use_mfcc=True)
    print(len(dataset))
    waveform, label = dataset[123]
    print(waveform.shape, label)
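
    # A minimal sketch of downstream use, assuming an 80/20 train/val split and
    # batch size 32 (illustrative values, not taken from the original project);
    # DataLoader and random_split are already imported above.
    train_len = int(0.8 * len(dataset))
    train_set, val_set = random_split(dataset, [train_len, len(dataset) - train_len])
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False)
    for batch_waveforms, batch_labels in train_loader:
        print(batch_waveforms.shape, batch_labels.shape)
        break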