import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class BasicBlock1D(nn.Module):
    """Basic residual block with two 3-tap convolutions (1D analogue of torchvision's BasicBlock)."""
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock1D, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
class Bottleneck1D(nn.Module):
    """Bottleneck residual block (1x1 reduce, 3-tap conv, 1x1 expand); output channels are out_channels * expansion."""
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck1D, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.conv3 = nn.Conv1d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
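

# --- Illustrative sketch (not part of the model definitions, never called by them) ---
# Shows why _make_layer below attaches a 1x1 downsample projection: the residual add
# requires the identity branch to match out_channels * expansion and the stride.
# The tensor sizes here are assumptions chosen only to exercise that path.
def _demo_bottleneck_block():
    x = torch.randn(2, 64, 100)  # [batch, channels, length]
    downsample = nn.Sequential(
        nn.Conv1d(64, 64 * Bottleneck1D.expansion, kernel_size=1, stride=2, bias=False),
        nn.BatchNorm1d(64 * Bottleneck1D.expansion),
    )
    block = Bottleneck1D(64, 64, stride=2, downsample=downsample)
    out = block(x)  # -> [2, 256, 50]: channels expanded 4x, length halved
    return out.shape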
class AudioResNet(nn.Module):
    def __init__(self, block, layers, num_classes=7):
        """
        Build a 1D ResNet for audio classification.

        Args:
            block: residual block type to use (BasicBlock1D or Bottleneck1D)
            layers: list with the number of blocks in each of the four stages
            num_classes: number of output classes, defaults to 7 emotions
        """
        super(AudioResNet, self).__init__()
        self.in_channels = 64

        # Initial convolution; reduces the sequence length
        self.conv1 = nn.Conv1d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Stacked residual stages
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Global average pooling and classifier
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Input x shape: [batch_size, 1, 24000]
        x = self.conv1(x)        # [batch_size, 64, 12000]
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)      # [batch_size, 64, 6000]

        x = self.layer1(x)       # [batch_size, 64*expansion, 6000]
        x = self.layer2(x)       # [batch_size, 128*expansion, 3000]
        x = self.layer3(x)       # [batch_size, 256*expansion, 1500]
        x = self.layer4(x)       # [batch_size, 512*expansion, 750]

        x = self.avgpool(x)      # [batch_size, 512*expansion, 1]
        x = torch.flatten(x, 1)  # [batch_size, 512*expansion]
        x = self.fc(x)           # [batch_size, num_classes]

        return x
# ResNet variants of different depths
def waveform_resnet18(num_classes=7):
    """Build a ResNet-18-style audio classification model."""
    return AudioResNet(BasicBlock1D, [2, 2, 2, 2], num_classes)


def waveform_resnet34(num_classes=7):
    """Build a ResNet-34-style audio classification model."""
    return AudioResNet(BasicBlock1D, [3, 4, 6, 3], num_classes)


def waveform_resnet50(num_classes=7):
    """Build a ResNet-50-style audio classification model."""
    return AudioResNet(Bottleneck1D, [3, 4, 6, 3], num_classes)


def waveform_resnet101(num_classes=7):
    """Build a ResNet-101-style audio classification model."""
    return AudioResNet(Bottleneck1D, [3, 4, 23, 3], num_classes)
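

# --- Usage sketch (illustrative only, not called by the module) ---
# Builds the ResNet-18-style waveform model defined above and runs a dummy batch
# through it. The 24000-sample length matches the shape assumed in AudioResNet.forward;
# the batch size of 4 is an arbitrary choice for demonstration.
def _demo_waveform_model():
    model = waveform_resnet18(num_classes=7)
    waveforms = torch.randn(4, 1, 24000)  # [batch, channel, samples]
    logits = model(waveforms)             # -> [4, 7]
    return logits.shape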
class SpectrogramResNet(nn.Module):
    """
    Emotion classification of audio spectrograms using a pretrained ResNet backbone.
    """

    def __init__(self, model_name='resnet18', num_classes=6, pretrained=True):
        """
        Initialize the spectrogram ResNet classifier.

        Args:
            model_name: ResNet variant to use ('resnet18', 'resnet34', 'resnet50', 'resnet101')
            num_classes: number of emotion classes
            pretrained: whether to load pretrained weights
        """
        super(SpectrogramResNet, self).__init__()

        # Select the pretrained ResNet backbone
        if model_name == 'resnet18':
            base_model = models.resnet18(pretrained=pretrained)
        elif model_name == 'resnet34':
            base_model = models.resnet34(pretrained=pretrained)
        elif model_name == 'resnet50':
            base_model = models.resnet50(pretrained=pretrained)
        elif model_name == 'resnet101':
            base_model = models.resnet101(pretrained=pretrained)
        else:
            raise ValueError(f"Unsupported model name: {model_name}")

        # Replace the first convolution so the model accepts single-channel input
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Initialize the first layer from pretrained weights (if available)
        if pretrained:
            # Average the pretrained three-channel weights into a single channel
            with torch.no_grad():
                self.conv1.weight.data = base_model.conv1.weight.data.mean(dim=1, keepdim=True)

        # Reuse the remaining pretrained layers
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4
        self.avgpool = base_model.avgpool

        # Replace the fully connected layer to match the target number of classes
        in_features = base_model.fc.in_features
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: spectrogram of shape [batch_size, 1, 128, 128]
        Returns:
            class predictions of shape [batch_size, num_classes]
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


# Convenience functions for creating the different model variants
def spectrogram_resnet18(num_classes=6, pretrained=True):
    return SpectrogramResNet('resnet18', num_classes, pretrained)


def spectrogram_resnet34(num_classes=6, pretrained=True):
    return SpectrogramResNet('resnet34', num_classes, pretrained)


def spectrogram_resnet50(num_classes=6, pretrained=True):
    return SpectrogramResNet('resnet50', num_classes, pretrained)


def spectrogram_resnet101(num_classes=6, pretrained=True):
    return SpectrogramResNet('resnet101', num_classes, pretrained)
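

# --- Minimal smoke test (illustrative only; pretrained=False avoids a weight download) ---
# Runs dummy single-channel 128x128 spectrograms, the input size assumed in the
# SpectrogramResNet.forward docstring, and prints the output shape, which should
# match [batch_size, num_classes].
if __name__ == "__main__":
    model = spectrogram_resnet18(num_classes=6, pretrained=False)
    spectrograms = torch.randn(4, 1, 128, 128)  # [batch, channel, freq, time]
    logits = model(spectrograms)
    print(logits.shape)                          # expected: torch.Size([4, 6])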