ADD file via upload

main
p9kh64cfp 8 months ago
parent 9932aec78c
commit f6b2161822

@ -0,0 +1,303 @@
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
#from model_improved_alexnet import ImprovedAlexNet#调用模型
#from model_improved_alexnet_plus import ImprovedAlexNet
from Se_AlexNet import ImprovedAlexNet
#from model_alexnet import ImprovedAlexNet
import os
import json
import time
# 配置环境
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
print(device)
# 定义数据转换器
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),#调整图像尺寸为224*224
transforms.RandomHorizontalFlip(),#数据增强,对图像进行随机水平翻转
transforms.ToTensor(),#将图像数据转化为张量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),#数据标准化
"val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224)
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
# 数据集路径
data_root = os.path.abspath(os.path.join(os.getcwd()))
image_path = data_root + "/data_set/sn_data/"
train_dataset = datasets.ImageFolder(root=image_path + "/train",
transform=data_transform["train"])
#print(train_dataset)
train_num = len(train_dataset)#获取训练数据集的样本数量
# 类别列表:将训练数据集的类别索引信息保存到一个 JSON 文件中,方便在以后的程序中快速查找类别与索引之间的对应关系。
class_list = train_dataset.class_to_idx
#print(class_list)
cla_dict = dict((val, key) for key, val in class_list.items())
#print(cla_dict)
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
#print(json_str)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
# 设置数据加载器
batch_size = 32#数据集的批量大小为 32
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,#在每个epoch开始时打乱数据顺序可以增加数据的随机性有助于模型的泛化能力
num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
transform=data_transform["val"])#加载验证数据集
val_num = len(validate_dataset)#获取验证数据集的样本数量
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=4, shuffle=True,
num_workers=0)#同上
# 网络配置
net = ImprovedAlexNet(num_classes=11)#定义了net网络模型模型实例化
#预训练权重
#model_weight_path = "models/alexnet.pth"
#net.load_state_dict(torch.load(model_weight_path, map_location=torch.device('cpu')), strict=False)
#model = torch.load('path_to_model.pth', map_location=torch.device('cpu'))
#model = torch.load('model.pth', map_location=torch.device('cpu'))
net.to(device)
# 损失函数
loss_function = nn.CrossEntropyLoss()
# pata = list(net.parameters())
optimizer = optim.Adam(net.parameters(), lr=0.0002)
save_path = './models/model.pth'
best_acc = 0.0#初始化最佳验证准确率,用于保存模型训练过程中的最佳准确率
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
import torchmetrics
import matplotlib.pyplot as plt
plt.rcParams["font.sans-serif"]=["SimHei"] #设置字体
plt.rcParams["axes.unicode_minus"]=False #该语句解决图像中的“-”负号的乱码问题,设置负号显示为正常的减号符号,而不是方块
# 初始化metrics 初始化指标, mdmc_average='samplewise',
#参数 average 表示 Precision 的计算方式,'macro' 表示计算每个类别的 Precision并对所有类别的 Precision 取平均值
train_accuracy = torchmetrics.Accuracy(num_classes=len(class_list), average='macro', task='multiclass').to(device)
val_accuracy = torchmetrics.Accuracy(num_classes=len(class_list), average='macro', task='multiclass').to(device)
train_precision = torchmetrics.Precision(num_classes=len(class_list), average='macro', task='multiclass').to(device)#取平均值
val_precision = torchmetrics.Precision(num_classes=len(class_list), average='macro', task='multiclass').to(device)
train_recall = torchmetrics.Recall(num_classes=len(class_list), average='macro', task='multiclass').to(device)
val_recall = torchmetrics.Recall(num_classes=len(class_list), average='macro', task='multiclass').to(device)
train_f1 = torchmetrics.F1Score(num_classes=len(class_list), average='macro', task='multiclass').to(device)
val_f1 = torchmetrics.F1Score(num_classes=len(class_list), average='macro', task='multiclass').to(device)
# 用于绘图的列表,记录每次迭代的数据
epoch_count = []
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
train_precisions = []
val_precisions = []
train_recalls = []
val_recalls = []
train_f1s = []
val_f1s = []
# 初始化记录最优情况下的性能指标
best_precision = 0.0
best_recall = 0.0
best_f1 = 0.0
best_confusion_matrix = None
best_epoch = 0
epochs=70
# 训练模型
for epoch in range(epochs):
# train
net.train()
running_loss = 0.0
t1 = time.perf_counter()#记录当前时间用于计算每个epoch的训练时间
for step, data in enumerate(train_loader, start=0):
images, labels = data#获取数据和标签
optimizer.zero_grad()#清除之前保存的梯度,以便进行新一轮的梯度计算和优化
outputs = net(images.to(device))#输出
loss = loss_function(outputs, labels.to(device))#计算输出outputs与真实标签labels之间的损失值
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()#累加当前batch的损失值
train_accuracy(outputs.to(device), labels.to(device))#计算准确度
train_precision(outputs.to(device), labels.to(device))#计算精确度
train_recall(outputs.to(device), labels.to(device))#计算召回率
train_f1(outputs.to(device), labels.to(device))#计算F1分数
# print train process
rate = (step + 1) / len(train_loader)#显示进度条
a = "*" * int(rate * 50)
b = "." * int((1 - rate) * 50)
print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
print()
#print(time.perf_counter()-t1)
# validate
net.eval()
val_accuracy.reset()
val_precision.reset()
val_recall.reset()
val_f1.reset()
val_loss = 0.0
acc = 0.0 # accumulate accurate number / epoch
all_predicts = []
all_labels = []
with torch.no_grad():
for val_data in validate_loader:
val_images, val_labels = val_data
outputs = net(val_images.to(device))
predict_y = torch.max(outputs, dim=1)[1]
acc += (predict_y == val_labels.to(device)).sum().item()
val_loss = loss_function(outputs.to(device), val_labels.to(device))
val_accuracy(outputs.to(device), val_labels.to(device))
val_precision(outputs.to(device), val_labels.to(device))
val_recall(outputs.to(device), val_labels.to(device))
val_f1(outputs.to(device), val_labels.to(device))
# 保存所有的预测和标签,用于后续的性能评估
all_predicts.extend(predict_y.cpu().numpy())
all_labels.extend(val_labels.cpu().numpy())
# 计算性能指标
val_accurate = acc / val_num
precision = precision_score(all_labels, all_predicts, average='macro')
recall = recall_score(all_labels, all_predicts, average='macro')
f1 = f1_score(all_labels, all_predicts, average='macro')
confusion = confusion_matrix(all_labels, all_predicts)#混淆矩阵
# 检查是否是最佳性能,如果是,更新最佳性能指标和保存模型
if val_accurate > best_acc:
best_acc = val_accurate
best_precision = precision
best_recall = recall
best_f1 = f1
best_confusion_matrix = confusion
best_epoch = epoch
torch.save(net.state_dict(), save_path)
# 计算平均损失和准确率
epoch_loss = running_loss / len(train_loader)
epoch_val_loss = val_loss.item() / len(validate_loader)
epoch_train_accuracy = train_accuracy.compute()
epoch_val_accuracy = val_accuracy.compute()
epoch_train_precision = train_precision.compute()
epoch_val_precision = val_precision.compute()
epoch_train_recall = train_recall.compute()
epoch_val_recall = val_recall.compute()
epoch_train_f1 = train_f1.compute()
epoch_val_f1 = val_f1.compute()
# 记录历史数据以便绘图
epoch_count.append(epoch + 1)
train_losses.append(epoch_loss)
val_losses.append(epoch_val_loss)
train_accuracies.append(epoch_train_accuracy.item())
val_accuracies.append(epoch_val_accuracy.item())
train_precisions.append(epoch_train_precision.item())
val_precisions.append(epoch_val_precision.item())
train_recalls.append(epoch_train_recall.item())
val_recalls.append(epoch_val_recall.item())
train_f1s.append(epoch_train_f1.item())
val_f1s.append(epoch_val_f1.item())
# 打印每个epoch的指标
print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {epoch_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}, '
f'Train Accuracy: {epoch_train_accuracy:.4f}, Validation Accuracy: {epoch_val_accuracy:.4f}, '
f'Train Precision: {epoch_train_precision:.4f}, Validation Precision: {epoch_val_precision:.4f}, '
f'Train Recall: {epoch_train_recall:.4f}, Validation Recall: {epoch_val_recall:.4f}, '
f'Train F1: {epoch_train_f1:.4f}, Validation F1: {epoch_val_f1:.4f}')
print('Finished Training')
# 保存最佳性能指标到txt
with open('./logs/model_performance.txt', 'w', encoding='utf-8') as fb:
fb.write(f'Accuracy: {best_acc}\n')
fb.write(f'Precision: {best_precision}\n')
fb.write(f'Recall: {best_recall}\n')
fb.write(f'F1-Score: {best_f1}\n')
fb.write(f'Confusion Matrix: \n{best_confusion_matrix}\n')
# 绘制训练和验证损失图
plt.figure(figsize=(12, 8))
plt.subplot(2, 3, 1)
plt.plot(epoch_count, train_losses, label='Train Loss')
plt.plot(epoch_count, val_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss')
plt.legend()
# 绘制训练和验证准确率图
plt.subplot(2, 3, 2)
plt.plot(epoch_count, train_accuracies, label='Train Accuracy')
plt.plot(epoch_count, val_accuracies, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Test Accuracy')
plt.legend()
# 绘制训练和验证精确度图
plt.subplot(2, 3, 3)
plt.plot(epoch_count, train_precisions, label='Train Precision')
plt.plot(epoch_count, val_precisions, label='Test Precision')
plt.xlabel('Epoch')
plt.ylabel('Precision')
plt.title('Training and Test Precision')
plt.legend()#为图表添加说明性文本或标签的方法,以便在数据可视化中更好地传达数据信息
# 绘制训练和验证召回率图
plt.subplot(2, 3, 4)
plt.plot(epoch_count, train_recalls, label='Train Recall')
plt.plot(epoch_count, val_recalls, label='Test Recall')
plt.xlabel('Epoch')
plt.ylabel('Recall')
plt.title('Training and Test Recall')
plt.legend()
# 绘制训练和验证F1分数图
plt.subplot(2, 3, 5)
plt.plot(epoch_count, train_f1s, label='Train F1 Score')
plt.plot(epoch_count, val_f1s, label='Test F1 Score')
plt.xlabel('Epoch')
plt.ylabel('F1 Score')
plt.title('Training and Test F1 Score')
plt.legend()
plt.tight_layout()
plt.show()
if best_confusion_matrix is not None:
# 设置图的大小
plt.figure(figsize=(10, 8))
# 绘制混淆矩阵的热图Seaborn库中的heatmap函数
sns.heatmap(best_confusion_matrix, annot=True, fmt="d", cmap='Blues',
xticklabels=cla_dict.values(), yticklabels=cla_dict.values()) # 使用类别名称作为标签
# 设置图的标题和坐标轴标签
plt.title('Confusion Matrix at Best Performance')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
# 显示图表
plt.show()
else:
print("Best confusion matrix is not available.")
# 写出准确率
with open('./logs/model.txt', 'w', encoding='utf-8') as fb:
fb.write(str(best_acc))
Loading…
Cancel
Save