|
|
@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
from torchvision import transforms, datasets, utils
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
import torch.optim as optim
|
|
|
|
|
|
|
|
#from model_improved_alexnet import ImprovedAlexNet#调用模型
|
|
|
|
|
|
|
|
#from model_improved_alexnet_plus import ImprovedAlexNet
|
|
|
|
|
|
|
|
from Se_AlexNet import ImprovedAlexNet
|
|
|
|
|
|
|
|
#from model_alexnet import ImprovedAlexNet
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 配置环境
|
|
|
|
|
|
|
|
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
device = torch.device("cpu")
|
|
|
|
|
|
|
|
print(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 定义数据转换器
|
|
|
|
|
|
|
|
data_transform = {
|
|
|
|
|
|
|
|
"train": transforms.Compose([transforms.RandomResizedCrop(224),#调整图像尺寸为224*224
|
|
|
|
|
|
|
|
transforms.RandomHorizontalFlip(),#数据增强,对图像进行随机水平翻转
|
|
|
|
|
|
|
|
transforms.ToTensor(),#将图像数据转化为张量
|
|
|
|
|
|
|
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),#数据标准化
|
|
|
|
|
|
|
|
"val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224)
|
|
|
|
|
|
|
|
transforms.ToTensor(),
|
|
|
|
|
|
|
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 数据集路径
|
|
|
|
|
|
|
|
data_root = os.path.abspath(os.path.join(os.getcwd()))
|
|
|
|
|
|
|
|
image_path = data_root + "/data_set/sn_data/"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_dataset = datasets.ImageFolder(root=image_path + "/train",
|
|
|
|
|
|
|
|
transform=data_transform["train"])
|
|
|
|
|
|
|
|
#print(train_dataset)
|
|
|
|
|
|
|
|
train_num = len(train_dataset)#获取训练数据集的样本数量
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 类别列表:将训练数据集的类别索引信息保存到一个 JSON 文件中,方便在以后的程序中快速查找类别与索引之间的对应关系。
|
|
|
|
|
|
|
|
class_list = train_dataset.class_to_idx
|
|
|
|
|
|
|
|
#print(class_list)
|
|
|
|
|
|
|
|
cla_dict = dict((val, key) for key, val in class_list.items())
|
|
|
|
|
|
|
|
#print(cla_dict)
|
|
|
|
|
|
|
|
# write dict into json file
|
|
|
|
|
|
|
|
json_str = json.dumps(cla_dict, indent=4)
|
|
|
|
|
|
|
|
#print(json_str)
|
|
|
|
|
|
|
|
with open('class_indices.json', 'w') as json_file:
|
|
|
|
|
|
|
|
json_file.write(json_str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 设置数据加载器
|
|
|
|
|
|
|
|
batch_size = 32#数据集的批量大小为 32
|
|
|
|
|
|
|
|
train_loader = torch.utils.data.DataLoader(train_dataset,
|
|
|
|
|
|
|
|
batch_size=batch_size, shuffle=True,#在每个epoch开始时打乱数据顺序,可以增加数据的随机性,有助于模型的泛化能力
|
|
|
|
|
|
|
|
num_workers=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
|
|
|
|
|
|
|
|
transform=data_transform["val"])#加载验证数据集
|
|
|
|
|
|
|
|
val_num = len(validate_dataset)#获取验证数据集的样本数量
|
|
|
|
|
|
|
|
validate_loader = torch.utils.data.DataLoader(validate_dataset,
|
|
|
|
|
|
|
|
batch_size=4, shuffle=True,
|
|
|
|
|
|
|
|
num_workers=0)#同上
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 网络配置
|
|
|
|
|
|
|
|
net = ImprovedAlexNet(num_classes=11)#定义了net网络模型,模型实例化
|
|
|
|
|
|
|
|
#预训练权重
|
|
|
|
|
|
|
|
#model_weight_path = "models/alexnet.pth"
|
|
|
|
|
|
|
|
#net.load_state_dict(torch.load(model_weight_path, map_location=torch.device('cpu')), strict=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#model = torch.load('path_to_model.pth', map_location=torch.device('cpu'))
|
|
|
|
|
|
|
|
#model = torch.load('model.pth', map_location=torch.device('cpu'))
|
|
|
|
|
|
|
|
net.to(device)
|
|
|
|
|
|
|
|
# 损失函数
|
|
|
|
|
|
|
|
loss_function = nn.CrossEntropyLoss()
|
|
|
|
|
|
|
|
# pata = list(net.parameters())
|
|
|
|
|
|
|
|
optimizer = optim.Adam(net.parameters(), lr=0.0002)
|
|
|
|
|
|
|
|
save_path = './models/model.pth'
|
|
|
|
|
|
|
|
best_acc = 0.0#初始化最佳验证准确率,用于保存模型训练过程中的最佳准确率
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
|
|
|
|
|
|
|
|
import seaborn as sns
|
|
|
|
|
|
|
|
import torchmetrics
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
plt.rcParams["font.sans-serif"]=["SimHei"] #设置字体
|
|
|
|
|
|
|
|
plt.rcParams["axes.unicode_minus"]=False #该语句解决图像中的“-”负号的乱码问题,设置负号显示为正常的减号符号,而不是方块
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化metrics 初始化指标, mdmc_average='samplewise',
|
|
|
|
|
|
|
|
#参数 average 表示 Precision 的计算方式,'macro' 表示计算每个类别的 Precision,并对所有类别的 Precision 取平均值
|
|
|
|
|
|
|
|
train_accuracy = torchmetrics.Accuracy(num_classes=len(class_list), average='macro', task='multiclass').to(device)
|
|
|
|
|
|
|
|
val_accuracy = torchmetrics.Accuracy(num_classes=len(class_list), average='macro', task='multiclass').to(device)
|
|
|
|
|
|
|
|
train_precision = torchmetrics.Precision(num_classes=len(class_list), average='macro', task='multiclass').to(device)#取平均值
|
|
|
|
|
|
|
|
val_precision = torchmetrics.Precision(num_classes=len(class_list), average='macro', task='multiclass').to(device)
|
|
|
|
|
|
|
|
train_recall = torchmetrics.Recall(num_classes=len(class_list), average='macro', task='multiclass').to(device)
|
|
|
|
|
|
|
|
val_recall = torchmetrics.Recall(num_classes=len(class_list), average='macro', task='multiclass').to(device)
|
|
|
|
|
|
|
|
train_f1 = torchmetrics.F1Score(num_classes=len(class_list), average='macro', task='multiclass').to(device)
|
|
|
|
|
|
|
|
val_f1 = torchmetrics.F1Score(num_classes=len(class_list), average='macro', task='multiclass').to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 用于绘图的列表,记录每次迭代的数据
|
|
|
|
|
|
|
|
epoch_count = []
|
|
|
|
|
|
|
|
train_losses = []
|
|
|
|
|
|
|
|
val_losses = []
|
|
|
|
|
|
|
|
train_accuracies = []
|
|
|
|
|
|
|
|
val_accuracies = []
|
|
|
|
|
|
|
|
train_precisions = []
|
|
|
|
|
|
|
|
val_precisions = []
|
|
|
|
|
|
|
|
train_recalls = []
|
|
|
|
|
|
|
|
val_recalls = []
|
|
|
|
|
|
|
|
train_f1s = []
|
|
|
|
|
|
|
|
val_f1s = []
|
|
|
|
|
|
|
|
# 初始化记录最优情况下的性能指标
|
|
|
|
|
|
|
|
best_precision = 0.0
|
|
|
|
|
|
|
|
best_recall = 0.0
|
|
|
|
|
|
|
|
best_f1 = 0.0
|
|
|
|
|
|
|
|
best_confusion_matrix = None
|
|
|
|
|
|
|
|
best_epoch = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
epochs=70
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 训练模型
|
|
|
|
|
|
|
|
for epoch in range(epochs):
|
|
|
|
|
|
|
|
# train
|
|
|
|
|
|
|
|
net.train()
|
|
|
|
|
|
|
|
running_loss = 0.0
|
|
|
|
|
|
|
|
t1 = time.perf_counter()#记录当前时间,用于计算每个epoch的训练时间
|
|
|
|
|
|
|
|
for step, data in enumerate(train_loader, start=0):
|
|
|
|
|
|
|
|
images, labels = data#获取数据和标签
|
|
|
|
|
|
|
|
optimizer.zero_grad()#清除之前保存的梯度,以便进行新一轮的梯度计算和优化
|
|
|
|
|
|
|
|
outputs = net(images.to(device))#输出
|
|
|
|
|
|
|
|
loss = loss_function(outputs, labels.to(device))#计算输出outputs与真实标签labels之间的损失值
|
|
|
|
|
|
|
|
loss.backward()
|
|
|
|
|
|
|
|
optimizer.step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# print statistics
|
|
|
|
|
|
|
|
running_loss += loss.item()#累加当前batch的损失值
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_accuracy(outputs.to(device), labels.to(device))#计算准确度
|
|
|
|
|
|
|
|
train_precision(outputs.to(device), labels.to(device))#计算精确度
|
|
|
|
|
|
|
|
train_recall(outputs.to(device), labels.to(device))#计算召回率
|
|
|
|
|
|
|
|
train_f1(outputs.to(device), labels.to(device))#计算F1分数
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# print train process
|
|
|
|
|
|
|
|
rate = (step + 1) / len(train_loader)#显示进度条
|
|
|
|
|
|
|
|
a = "*" * int(rate * 50)
|
|
|
|
|
|
|
|
b = "." * int((1 - rate) * 50)
|
|
|
|
|
|
|
|
print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
|
|
|
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
#print(time.perf_counter()-t1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# validate
|
|
|
|
|
|
|
|
net.eval()
|
|
|
|
|
|
|
|
val_accuracy.reset()
|
|
|
|
|
|
|
|
val_precision.reset()
|
|
|
|
|
|
|
|
val_recall.reset()
|
|
|
|
|
|
|
|
val_f1.reset()
|
|
|
|
|
|
|
|
val_loss = 0.0
|
|
|
|
|
|
|
|
acc = 0.0 # accumulate accurate number / epoch
|
|
|
|
|
|
|
|
all_predicts = []
|
|
|
|
|
|
|
|
all_labels = []
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
|
|
|
for val_data in validate_loader:
|
|
|
|
|
|
|
|
val_images, val_labels = val_data
|
|
|
|
|
|
|
|
outputs = net(val_images.to(device))
|
|
|
|
|
|
|
|
predict_y = torch.max(outputs, dim=1)[1]
|
|
|
|
|
|
|
|
acc += (predict_y == val_labels.to(device)).sum().item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
val_loss = loss_function(outputs.to(device), val_labels.to(device))
|
|
|
|
|
|
|
|
val_accuracy(outputs.to(device), val_labels.to(device))
|
|
|
|
|
|
|
|
val_precision(outputs.to(device), val_labels.to(device))
|
|
|
|
|
|
|
|
val_recall(outputs.to(device), val_labels.to(device))
|
|
|
|
|
|
|
|
val_f1(outputs.to(device), val_labels.to(device))
|
|
|
|
|
|
|
|
# 保存所有的预测和标签,用于后续的性能评估
|
|
|
|
|
|
|
|
all_predicts.extend(predict_y.cpu().numpy())
|
|
|
|
|
|
|
|
all_labels.extend(val_labels.cpu().numpy())
|
|
|
|
|
|
|
|
# 计算性能指标
|
|
|
|
|
|
|
|
val_accurate = acc / val_num
|
|
|
|
|
|
|
|
precision = precision_score(all_labels, all_predicts, average='macro')
|
|
|
|
|
|
|
|
recall = recall_score(all_labels, all_predicts, average='macro')
|
|
|
|
|
|
|
|
f1 = f1_score(all_labels, all_predicts, average='macro')
|
|
|
|
|
|
|
|
confusion = confusion_matrix(all_labels, all_predicts)#混淆矩阵
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 检查是否是最佳性能,如果是,更新最佳性能指标和保存模型
|
|
|
|
|
|
|
|
if val_accurate > best_acc:
|
|
|
|
|
|
|
|
best_acc = val_accurate
|
|
|
|
|
|
|
|
best_precision = precision
|
|
|
|
|
|
|
|
best_recall = recall
|
|
|
|
|
|
|
|
best_f1 = f1
|
|
|
|
|
|
|
|
best_confusion_matrix = confusion
|
|
|
|
|
|
|
|
best_epoch = epoch
|
|
|
|
|
|
|
|
torch.save(net.state_dict(), save_path)
|
|
|
|
|
|
|
|
# 计算平均损失和准确率
|
|
|
|
|
|
|
|
epoch_loss = running_loss / len(train_loader)
|
|
|
|
|
|
|
|
epoch_val_loss = val_loss.item() / len(validate_loader)
|
|
|
|
|
|
|
|
epoch_train_accuracy = train_accuracy.compute()
|
|
|
|
|
|
|
|
epoch_val_accuracy = val_accuracy.compute()
|
|
|
|
|
|
|
|
epoch_train_precision = train_precision.compute()
|
|
|
|
|
|
|
|
epoch_val_precision = val_precision.compute()
|
|
|
|
|
|
|
|
epoch_train_recall = train_recall.compute()
|
|
|
|
|
|
|
|
epoch_val_recall = val_recall.compute()
|
|
|
|
|
|
|
|
epoch_train_f1 = train_f1.compute()
|
|
|
|
|
|
|
|
epoch_val_f1 = val_f1.compute()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 记录历史数据以便绘图
|
|
|
|
|
|
|
|
epoch_count.append(epoch + 1)
|
|
|
|
|
|
|
|
train_losses.append(epoch_loss)
|
|
|
|
|
|
|
|
val_losses.append(epoch_val_loss)
|
|
|
|
|
|
|
|
train_accuracies.append(epoch_train_accuracy.item())
|
|
|
|
|
|
|
|
val_accuracies.append(epoch_val_accuracy.item())
|
|
|
|
|
|
|
|
train_precisions.append(epoch_train_precision.item())
|
|
|
|
|
|
|
|
val_precisions.append(epoch_val_precision.item())
|
|
|
|
|
|
|
|
train_recalls.append(epoch_train_recall.item())
|
|
|
|
|
|
|
|
val_recalls.append(epoch_val_recall.item())
|
|
|
|
|
|
|
|
train_f1s.append(epoch_train_f1.item())
|
|
|
|
|
|
|
|
val_f1s.append(epoch_val_f1.item())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 打印每个epoch的指标
|
|
|
|
|
|
|
|
print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {epoch_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}, '
|
|
|
|
|
|
|
|
f'Train Accuracy: {epoch_train_accuracy:.4f}, Validation Accuracy: {epoch_val_accuracy:.4f}, '
|
|
|
|
|
|
|
|
f'Train Precision: {epoch_train_precision:.4f}, Validation Precision: {epoch_val_precision:.4f}, '
|
|
|
|
|
|
|
|
f'Train Recall: {epoch_train_recall:.4f}, Validation Recall: {epoch_val_recall:.4f}, '
|
|
|
|
|
|
|
|
f'Train F1: {epoch_train_f1:.4f}, Validation F1: {epoch_val_f1:.4f}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print('Finished Training')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 保存最佳性能指标到txt
|
|
|
|
|
|
|
|
with open('./logs/model_performance.txt', 'w', encoding='utf-8') as fb:
|
|
|
|
|
|
|
|
fb.write(f'Accuracy: {best_acc}\n')
|
|
|
|
|
|
|
|
fb.write(f'Precision: {best_precision}\n')
|
|
|
|
|
|
|
|
fb.write(f'Recall: {best_recall}\n')
|
|
|
|
|
|
|
|
fb.write(f'F1-Score: {best_f1}\n')
|
|
|
|
|
|
|
|
fb.write(f'Confusion Matrix: \n{best_confusion_matrix}\n')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 绘制训练和验证损失图
|
|
|
|
|
|
|
|
plt.figure(figsize=(12, 8))
|
|
|
|
|
|
|
|
plt.subplot(2, 3, 1)
|
|
|
|
|
|
|
|
plt.plot(epoch_count, train_losses, label='Train Loss')
|
|
|
|
|
|
|
|
plt.plot(epoch_count, val_losses, label='Test Loss')
|
|
|
|
|
|
|
|
plt.xlabel('Epoch')
|
|
|
|
|
|
|
|
plt.ylabel('Loss')
|
|
|
|
|
|
|
|
plt.title('Training and Test Loss')
|
|
|
|
|
|
|
|
plt.legend()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 绘制训练和验证准确率图
|
|
|
|
|
|
|
|
plt.subplot(2, 3, 2)
|
|
|
|
|
|
|
|
plt.plot(epoch_count, train_accuracies, label='Train Accuracy')
|
|
|
|
|
|
|
|
plt.plot(epoch_count, val_accuracies, label='Test Accuracy')
|
|
|
|
|
|
|
|
plt.xlabel('Epoch')
|
|
|
|
|
|
|
|
plt.ylabel('Accuracy')
|
|
|
|
|
|
|
|
plt.title('Training and Test Accuracy')
|
|
|
|
|
|
|
|
plt.legend()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 绘制训练和验证精确度图
|
|
|
|
|
|
|
|
plt.subplot(2, 3, 3)
|
|
|
|
|
|
|
|
plt.plot(epoch_count, train_precisions, label='Train Precision')
|
|
|
|
|
|
|
|
plt.plot(epoch_count, val_precisions, label='Test Precision')
|
|
|
|
|
|
|
|
plt.xlabel('Epoch')
|
|
|
|
|
|
|
|
plt.ylabel('Precision')
|
|
|
|
|
|
|
|
plt.title('Training and Test Precision')
|
|
|
|
|
|
|
|
plt.legend()#为图表添加说明性文本或标签的方法,以便在数据可视化中更好地传达数据信息
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 绘制训练和验证召回率图
|
|
|
|
|
|
|
|
plt.subplot(2, 3, 4)
|
|
|
|
|
|
|
|
plt.plot(epoch_count, train_recalls, label='Train Recall')
|
|
|
|
|
|
|
|
plt.plot(epoch_count, val_recalls, label='Test Recall')
|
|
|
|
|
|
|
|
plt.xlabel('Epoch')
|
|
|
|
|
|
|
|
plt.ylabel('Recall')
|
|
|
|
|
|
|
|
plt.title('Training and Test Recall')
|
|
|
|
|
|
|
|
plt.legend()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 绘制训练和验证F1分数图
|
|
|
|
|
|
|
|
plt.subplot(2, 3, 5)
|
|
|
|
|
|
|
|
plt.plot(epoch_count, train_f1s, label='Train F1 Score')
|
|
|
|
|
|
|
|
plt.plot(epoch_count, val_f1s, label='Test F1 Score')
|
|
|
|
|
|
|
|
plt.xlabel('Epoch')
|
|
|
|
|
|
|
|
plt.ylabel('F1 Score')
|
|
|
|
|
|
|
|
plt.title('Training and Test F1 Score')
|
|
|
|
|
|
|
|
plt.legend()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plt.tight_layout()
|
|
|
|
|
|
|
|
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if best_confusion_matrix is not None:
|
|
|
|
|
|
|
|
# 设置图的大小
|
|
|
|
|
|
|
|
plt.figure(figsize=(10, 8))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 绘制混淆矩阵的热图,Seaborn库中的heatmap函数
|
|
|
|
|
|
|
|
sns.heatmap(best_confusion_matrix, annot=True, fmt="d", cmap='Blues',
|
|
|
|
|
|
|
|
xticklabels=cla_dict.values(), yticklabels=cla_dict.values()) # 使用类别名称作为标签
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 设置图的标题和坐标轴标签
|
|
|
|
|
|
|
|
plt.title('Confusion Matrix at Best Performance')
|
|
|
|
|
|
|
|
plt.xlabel('Predicted Label')
|
|
|
|
|
|
|
|
plt.ylabel('True Label')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 显示图表
|
|
|
|
|
|
|
|
plt.show()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
print("Best confusion matrix is not available.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 写出准确率
|
|
|
|
|
|
|
|
with open('./logs/model.txt', 'w', encoding='utf-8') as fb:
|
|
|
|
|
|
|
|
fb.write(str(best_acc))
|
|
|
|
|
|
|
|
|