# NOTE: The following lines are residue from the web code viewer this script
# was copied from; they are not part of the program.
# "You cannot select more than 25 topics. Topics must start with a letter or
#  number, can include dashes ('-') and can be up to 35 characters long."
# File info: 304 lines, 12 KiB.
# "This file contains ambiguous Unicode characters!"
# "This file contains ambiguous Unicode characters that may be confused with
#  others in your current locale. If your use case is intentional and
#  legitimate, you can safely ignore this warning. Use the Escape button to
#  highlight these characters."

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
#from model_improved_alexnet import ImprovedAlexNet#调用模型
#from model_improved_alexnet_plus import ImprovedAlexNet
from Se_AlexNet import ImprovedAlexNet
#from model_alexnet import ImprovedAlexNet
import os
import json
import time
# Environment: training is pinned to the CPU. To use a GPU when one is
# available, replace the assignment below with
#   torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
print(device)

# Image preprocessing pipelines, keyed by dataset split.
# Both splits share the same normalisation: mean 0.5 / std 0.5 per RGB
# channel, i.e. (x - 0.5) / 0.5 applied to the tensor produced by ToTensor().
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
data_transform = {
    # Training: random resized 224x224 crop plus random horizontal flip
    # (data augmentation), then tensor conversion and normalisation.
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _normalize,
    ]),
    # Validation: deterministic resize to exactly 224x224. The tuple form is
    # required; Resize(224) would only scale the shorter side to 224.
    "val": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        _normalize,
    ]),
}
# Dataset location: <cwd>/data_set/sn_data/{train,val}, each expected to use
# the torchvision ImageFolder layout (one sub-directory per class).
data_root = os.path.abspath(os.getcwd())
image_path = os.path.join(data_root, "data_set", "sn_data")
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])
train_num = len(train_dataset)  # number of training samples

# Class list: ImageFolder's {class_name: index} mapping. The inverted
# {index: class_name} mapping is written to class_indices.json so other
# scripts can map a predicted index back to its class name.
class_list = train_dataset.class_to_idx
cla_dict = {idx: name for name, idx in class_list.items()}
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

# Data loaders.
batch_size = 32  # training batch size
# Training samples are reshuffled at the start of every epoch, which adds
# randomness and helps generalisation.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                        transform=data_transform["val"])
val_num = len(validate_dataset)  # number of validation samples
# Metrics aggregated over the whole validation set do not depend on sample
# order, so the validation set is not shuffled (keeps runs reproducible).
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=4, shuffle=False,
                                              num_workers=0)
# Network: the ImprovedAlexNet model imported at the top of the file. The
# number of output classes comes from the dataset instead of being
# hard-coded, so the classifier head always matches the class list (and the
# num_classes used by the metrics).
net = ImprovedAlexNet(num_classes=len(class_list))
# Optional pretrained weights; strict=False ignores missing/unexpected keys:
# model_weight_path = "models/alexnet.pth"
# net.load_state_dict(torch.load(model_weight_path, map_location=device), strict=False)
net.to(device)

# Loss function and optimiser.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0002)

# Where the best checkpoint is saved. Create its directory up front so the
# torch.save() call inside the training loop cannot fail on a missing folder.
save_path = './models/model.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
best_acc = 0.0  # best validation accuracy seen so far during training
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
import torchmetrics
import matplotlib.pyplot as plt
# Matplotlib: use the SimHei font (for CJK text) and render the minus sign as
# a plain hyphen so it is not drawn as an empty box with that font.
plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False

# Metric objects (torchmetrics), one per split. average='macro' computes each
# metric per class and then takes the unweighted mean over all classes.
# NOTE: macro "accuracy" is therefore the mean per-class recall; it can differ
# from the plain fraction of correct predictions (val_accurate) that the
# training loop uses to pick the best model.
_metric_kwargs = dict(num_classes=len(class_list), average='macro', task='multiclass')
train_accuracy = torchmetrics.Accuracy(**_metric_kwargs).to(device)
val_accuracy = torchmetrics.Accuracy(**_metric_kwargs).to(device)
train_precision = torchmetrics.Precision(**_metric_kwargs).to(device)
val_precision = torchmetrics.Precision(**_metric_kwargs).to(device)
train_recall = torchmetrics.Recall(**_metric_kwargs).to(device)
val_recall = torchmetrics.Recall(**_metric_kwargs).to(device)
train_f1 = torchmetrics.F1Score(**_metric_kwargs).to(device)
val_f1 = torchmetrics.F1Score(**_metric_kwargs).to(device)

# Per-epoch histories used for plotting (one entry appended per epoch).
epoch_count = []
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
train_precisions, val_precisions = [], []
train_recalls, val_recalls = [], []
train_f1s, val_f1s = [], []

# Metrics captured at the epoch with the best validation accuracy.
best_precision = best_recall = best_f1 = 0.0
best_confusion_matrix = None
best_epoch = 0
epochs = 70  # total number of training epochs
# Training loop. Each epoch runs one pass over the training set, then a full
# validation pass. Whenever validation accuracy improves, the weights are
# saved to save_path and the epoch's metrics are kept in the best_* variables.
num_classes = len(class_list)
train_metrics = (train_accuracy, train_precision, train_recall, train_f1)
val_metrics = (val_accuracy, val_precision, val_recall, val_f1)
for epoch in range(epochs):
    # ---- train ----
    net.train()
    # torchmetrics keep accumulating state until reset(); reset every epoch so
    # the reported train metrics describe this epoch only.
    for metric in train_metrics:
        metric.reset()
    running_loss = 0.0
    for step, (images, labels) in enumerate(train_loader, start=0):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()  # clear gradients left from the previous step
        outputs = net(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()  # accumulate this batch's mean loss
        for metric in train_metrics:
            metric.update(outputs.detach(), labels)

        # Text progress bar (50 characters wide) with the current batch loss.
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss.item()), end="")
    print()

    # ---- validate ----
    net.eval()
    for metric in val_metrics:
        metric.reset()
    val_loss = 0.0  # sum of per-batch mean losses over the validation set
    acc = 0.0  # number of correct predictions this epoch
    all_predicts = []
    all_labels = []
    with torch.no_grad():
        for val_images, val_labels in validate_loader:
            val_images, val_labels = val_images.to(device), val_labels.to(device)
            outputs = net(val_images)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels).sum().item()
            # Accumulate over all batches; averaged per batch further below.
            val_loss += loss_function(outputs, val_labels).item()
            for metric in val_metrics:
                metric.update(outputs, val_labels)
            # Keep every prediction/label for the sklearn metrics below.
            all_predicts.extend(predict_y.cpu().numpy())
            all_labels.extend(val_labels.cpu().numpy())

    # Validation metrics. val_accurate is the plain fraction of correct
    # predictions; sklearn's macro scores treat absent classes as 0.
    val_accurate = acc / val_num
    precision = precision_score(all_labels, all_predicts, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_predicts, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_predicts, average='macro', zero_division=0)
    # Passing every class index keeps the matrix num_classes x num_classes,
    # matching the class-name tick labels even when a class is missing from
    # this epoch's labels and predictions.
    confusion = confusion_matrix(all_labels, all_predicts, labels=list(range(num_classes)))

    # New best validation accuracy: remember its metrics and save the weights.
    if val_accurate > best_acc:
        best_acc = val_accurate
        best_precision = precision
        best_recall = recall
        best_f1 = f1
        best_confusion_matrix = confusion
        best_epoch = epoch
        torch.save(net.state_dict(), save_path)

    # Per-epoch averages and metric values.
    epoch_loss = running_loss / len(train_loader)
    epoch_val_loss = val_loss / len(validate_loader)
    epoch_train_accuracy = train_accuracy.compute()
    epoch_val_accuracy = val_accuracy.compute()
    epoch_train_precision = train_precision.compute()
    epoch_val_precision = val_precision.compute()
    epoch_train_recall = train_recall.compute()
    epoch_val_recall = val_recall.compute()
    epoch_train_f1 = train_f1.compute()
    epoch_val_f1 = val_f1.compute()

    # Record history for plotting.
    epoch_count.append(epoch + 1)
    train_losses.append(epoch_loss)
    val_losses.append(epoch_val_loss)
    train_accuracies.append(epoch_train_accuracy.item())
    val_accuracies.append(epoch_val_accuracy.item())
    train_precisions.append(epoch_train_precision.item())
    val_precisions.append(epoch_val_precision.item())
    train_recalls.append(epoch_train_recall.item())
    val_recalls.append(epoch_val_recall.item())
    train_f1s.append(epoch_train_f1.item())
    val_f1s.append(epoch_val_f1.item())

    # Per-epoch summary.
    print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {epoch_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}, '
          f'Train Accuracy: {epoch_train_accuracy:.4f}, Validation Accuracy: {epoch_val_accuracy:.4f}, '
          f'Train Precision: {epoch_train_precision:.4f}, Validation Precision: {epoch_val_precision:.4f}, '
          f'Train Recall: {epoch_train_recall:.4f}, Validation Recall: {epoch_val_recall:.4f}, '
          f'Train F1: {epoch_train_f1:.4f}, Validation F1: {epoch_val_f1:.4f}')
print('Finished Training')
# Save the metrics of the best epoch (highest validation accuracy) to a text
# file. Create the logs directory first so a long training run does not end
# in a FileNotFoundError at this point.
os.makedirs('./logs', exist_ok=True)
with open('./logs/model_performance.txt', 'w', encoding='utf-8') as fb:
    fb.write(f'Accuracy: {best_acc}\n')
    fb.write(f'Precision: {best_precision}\n')
    fb.write(f'Recall: {best_recall}\n')
    fb.write(f'F1-Score: {best_f1}\n')
    fb.write(f'Confusion Matrix: \n{best_confusion_matrix}\n')
# Training-vs-validation curves: one subplot per metric on a 2x3 grid (the
# sixth cell stays empty). The legend/title word "Test" refers to the
# validation-set histories. Each entry: metric name, training history,
# validation history.
_curve_panels = [
    ('Loss', train_losses, val_losses),
    ('Accuracy', train_accuracies, val_accuracies),
    ('Precision', train_precisions, val_precisions),
    ('Recall', train_recalls, val_recalls),
    ('F1 Score', train_f1s, val_f1s),
]
plt.figure(figsize=(12, 8))
for panel_idx, (metric_name, train_hist, val_hist) in enumerate(_curve_panels, start=1):
    plt.subplot(2, 3, panel_idx)
    plt.plot(epoch_count, train_hist, label=f'Train {metric_name}')
    plt.plot(epoch_count, val_hist, label=f'Test {metric_name}')
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.title(f'Training and Test {metric_name}')
    plt.legend()  # show the Train/Test legend entries for this panel
plt.tight_layout()
plt.show()
# Heat map of the confusion matrix recorded at the best epoch, with class
# names (in the order stored in cla_dict) on both axes.
if best_confusion_matrix is None:
    print("Best confusion matrix is not available.")
else:
    class_names = list(cla_dict.values())
    plt.figure(figsize=(10, 8))
    sns.heatmap(best_confusion_matrix, annot=True, fmt="d", cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix at Best Performance')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()
# Write the best validation accuracy on its own (a bare number) to a second
# log file, creating the logs directory if it does not exist yet.
os.makedirs('./logs', exist_ok=True)
with open('./logs/model.txt', 'w', encoding='utf-8') as fb:
    fb.write(str(best_acc))