From f6b2161822f8b65e4847a285699fb667a7d5d658 Mon Sep 17 00:00:00 2001 From: p9kh64cfp <1047063963@qq.com> Date: Tue, 31 Dec 2024 11:22:07 +0800 Subject: [PATCH] ADD file via upload --- train_improved_alexnet.py | 303 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 303 insertions(+) create mode 100644 train_improved_alexnet.py diff --git a/train_improved_alexnet.py b/train_improved_alexnet.py new file mode 100644 index 0000000..ef95c8a --- /dev/null +++ b/train_improved_alexnet.py @@ -0,0 +1,303 @@ +import torch +import torch.nn as nn +from torchvision import transforms, datasets, utils +import matplotlib.pyplot as plt +import numpy as np +import torch.optim as optim +#from model_improved_alexnet import ImprovedAlexNet#调用模型 +#from model_improved_alexnet_plus import ImprovedAlexNet +from Se_AlexNet import ImprovedAlexNet +#from model_alexnet import ImprovedAlexNet +import os +import json +import time + +# 配置环境 +#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") +print(device) + +# 定义数据转换器 +data_transform = { + "train": transforms.Compose([transforms.RandomResizedCrop(224),#调整图像尺寸为224*224 + transforms.RandomHorizontalFlip(),#数据增强,对图像进行随机水平翻转 + transforms.ToTensor(),#将图像数据转化为张量 + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),#数据标准化 + "val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224) + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])} + + +# 数据集路径 +data_root = os.path.abspath(os.path.join(os.getcwd())) +image_path = data_root + "/data_set/sn_data/" + +train_dataset = datasets.ImageFolder(root=image_path + "/train", + transform=data_transform["train"]) +#print(train_dataset) +train_num = len(train_dataset)#获取训练数据集的样本数量 + +# 类别列表:将训练数据集的类别索引信息保存到一个 JSON 文件中,方便在以后的程序中快速查找类别与索引之间的对应关系。 +class_list = train_dataset.class_to_idx +#print(class_list) +cla_dict = dict((val, key) for key, val in class_list.items()) +#print(cla_dict) +# write dict into json file +json_str = json.dumps(cla_dict, indent=4) +#print(json_str) +with open('class_indices.json', 'w') as json_file: + json_file.write(json_str) + +# 设置数据加载器 +batch_size = 32#数据集的批量大小为 32 +train_loader = torch.utils.data.DataLoader(train_dataset, + batch_size=batch_size, shuffle=True,#在每个epoch开始时打乱数据顺序,可以增加数据的随机性,有助于模型的泛化能力 + num_workers=0) + +validate_dataset = datasets.ImageFolder(root=image_path + "/val", + transform=data_transform["val"])#加载验证数据集 +val_num = len(validate_dataset)#获取验证数据集的样本数量 +validate_loader = torch.utils.data.DataLoader(validate_dataset, + batch_size=4, shuffle=True, + num_workers=0)#同上 + +# 网络配置 +net = ImprovedAlexNet(num_classes=11)#定义了net网络模型,模型实例化 +#预训练权重 +#model_weight_path = "models/alexnet.pth" +#net.load_state_dict(torch.load(model_weight_path, map_location=torch.device('cpu')), strict=False) + +#model = torch.load('path_to_model.pth', map_location=torch.device('cpu')) +#model = torch.load('model.pth', map_location=torch.device('cpu')) +net.to(device) +# 损失函数 +loss_function = nn.CrossEntropyLoss() +# pata = list(net.parameters()) +optimizer = optim.Adam(net.parameters(), lr=0.0002) +save_path = './models/model.pth' +best_acc = 0.0#初始化最佳验证准确率,用于保存模型训练过程中的最佳准确率 + + +from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix +import seaborn as sns +import torchmetrics +import matplotlib.pyplot as plt +plt.rcParams["font.sans-serif"]=["SimHei"] #设置字体 +plt.rcParams["axes.unicode_minus"]=False #该语句解决图像中的“-”负号的乱码问题,设置负号显示为正常的减号符号,而不是方块 + +# 初始化metrics 初始化指标, mdmc_average='samplewise', +#参数 average 表示 Precision 的计算方式,'macro' 表示计算每个类别的 Precision,并对所有类别的 Precision 取平均值 +train_accuracy = torchmetrics.Accuracy(num_classes=len(class_list), average='macro', task='multiclass').to(device) +val_accuracy = torchmetrics.Accuracy(num_classes=len(class_list), average='macro', task='multiclass').to(device) +train_precision = torchmetrics.Precision(num_classes=len(class_list), average='macro', task='multiclass').to(device)#取平均值 +val_precision = torchmetrics.Precision(num_classes=len(class_list), average='macro', task='multiclass').to(device) +train_recall = torchmetrics.Recall(num_classes=len(class_list), average='macro', task='multiclass').to(device) +val_recall = torchmetrics.Recall(num_classes=len(class_list), average='macro', task='multiclass').to(device) +train_f1 = torchmetrics.F1Score(num_classes=len(class_list), average='macro', task='multiclass').to(device) +val_f1 = torchmetrics.F1Score(num_classes=len(class_list), average='macro', task='multiclass').to(device) + +# 用于绘图的列表,记录每次迭代的数据 +epoch_count = [] +train_losses = [] +val_losses = [] +train_accuracies = [] +val_accuracies = [] +train_precisions = [] +val_precisions = [] +train_recalls = [] +val_recalls = [] +train_f1s = [] +val_f1s = [] +# 初始化记录最优情况下的性能指标 +best_precision = 0.0 +best_recall = 0.0 +best_f1 = 0.0 +best_confusion_matrix = None +best_epoch = 0 + +epochs=70 + +# 训练模型 +for epoch in range(epochs): + # train + net.train() + running_loss = 0.0 + t1 = time.perf_counter()#记录当前时间,用于计算每个epoch的训练时间 + for step, data in enumerate(train_loader, start=0): + images, labels = data#获取数据和标签 + optimizer.zero_grad()#清除之前保存的梯度,以便进行新一轮的梯度计算和优化 + outputs = net(images.to(device))#输出 + loss = loss_function(outputs, labels.to(device))#计算输出outputs与真实标签labels之间的损失值 + loss.backward() + optimizer.step() + + # print statistics + running_loss += loss.item()#累加当前batch的损失值 + + train_accuracy(outputs.to(device), labels.to(device))#计算准确度 + train_precision(outputs.to(device), labels.to(device))#计算精确度 + train_recall(outputs.to(device), labels.to(device))#计算召回率 + train_f1(outputs.to(device), labels.to(device))#计算F1分数 + + # print train process + rate = (step + 1) / len(train_loader)#显示进度条 + a = "*" * int(rate * 50) + b = "." * int((1 - rate) * 50) + print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="") + print() + #print(time.perf_counter()-t1) + + # validate + net.eval() + val_accuracy.reset() + val_precision.reset() + val_recall.reset() + val_f1.reset() + val_loss = 0.0 + acc = 0.0 # accumulate accurate number / epoch + all_predicts = [] + all_labels = [] + with torch.no_grad(): + for val_data in validate_loader: + val_images, val_labels = val_data + outputs = net(val_images.to(device)) + predict_y = torch.max(outputs, dim=1)[1] + acc += (predict_y == val_labels.to(device)).sum().item() + + val_loss = loss_function(outputs.to(device), val_labels.to(device)) + val_accuracy(outputs.to(device), val_labels.to(device)) + val_precision(outputs.to(device), val_labels.to(device)) + val_recall(outputs.to(device), val_labels.to(device)) + val_f1(outputs.to(device), val_labels.to(device)) + # 保存所有的预测和标签,用于后续的性能评估 + all_predicts.extend(predict_y.cpu().numpy()) + all_labels.extend(val_labels.cpu().numpy()) + # 计算性能指标 + val_accurate = acc / val_num + precision = precision_score(all_labels, all_predicts, average='macro') + recall = recall_score(all_labels, all_predicts, average='macro') + f1 = f1_score(all_labels, all_predicts, average='macro') + confusion = confusion_matrix(all_labels, all_predicts)#混淆矩阵 + + # 检查是否是最佳性能,如果是,更新最佳性能指标和保存模型 + if val_accurate > best_acc: + best_acc = val_accurate + best_precision = precision + best_recall = recall + best_f1 = f1 + best_confusion_matrix = confusion + best_epoch = epoch + torch.save(net.state_dict(), save_path) + # 计算平均损失和准确率 + epoch_loss = running_loss / len(train_loader) + epoch_val_loss = val_loss.item() / len(validate_loader) + epoch_train_accuracy = train_accuracy.compute() + epoch_val_accuracy = val_accuracy.compute() + epoch_train_precision = train_precision.compute() + epoch_val_precision = val_precision.compute() + epoch_train_recall = train_recall.compute() + epoch_val_recall = val_recall.compute() + epoch_train_f1 = train_f1.compute() + epoch_val_f1 = val_f1.compute() + + # 记录历史数据以便绘图 + epoch_count.append(epoch + 1) + train_losses.append(epoch_loss) + val_losses.append(epoch_val_loss) + train_accuracies.append(epoch_train_accuracy.item()) + val_accuracies.append(epoch_val_accuracy.item()) + train_precisions.append(epoch_train_precision.item()) + val_precisions.append(epoch_val_precision.item()) + train_recalls.append(epoch_train_recall.item()) + val_recalls.append(epoch_val_recall.item()) + train_f1s.append(epoch_train_f1.item()) + val_f1s.append(epoch_val_f1.item()) + + # 打印每个epoch的指标 + print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {epoch_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}, ' + f'Train Accuracy: {epoch_train_accuracy:.4f}, Validation Accuracy: {epoch_val_accuracy:.4f}, ' + f'Train Precision: {epoch_train_precision:.4f}, Validation Precision: {epoch_val_precision:.4f}, ' + f'Train Recall: {epoch_train_recall:.4f}, Validation Recall: {epoch_val_recall:.4f}, ' + f'Train F1: {epoch_train_f1:.4f}, Validation F1: {epoch_val_f1:.4f}') + +print('Finished Training') + +# 保存最佳性能指标到txt +with open('./logs/model_performance.txt', 'w', encoding='utf-8') as fb: + fb.write(f'Accuracy: {best_acc}\n') + fb.write(f'Precision: {best_precision}\n') + fb.write(f'Recall: {best_recall}\n') + fb.write(f'F1-Score: {best_f1}\n') + fb.write(f'Confusion Matrix: \n{best_confusion_matrix}\n') + +# 绘制训练和验证损失图 +plt.figure(figsize=(12, 8)) +plt.subplot(2, 3, 1) +plt.plot(epoch_count, train_losses, label='Train Loss') +plt.plot(epoch_count, val_losses, label='Test Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Training and Test Loss') +plt.legend() + +# 绘制训练和验证准确率图 +plt.subplot(2, 3, 2) +plt.plot(epoch_count, train_accuracies, label='Train Accuracy') +plt.plot(epoch_count, val_accuracies, label='Test Accuracy') +plt.xlabel('Epoch') +plt.ylabel('Accuracy') +plt.title('Training and Test Accuracy') +plt.legend() + +# 绘制训练和验证精确度图 +plt.subplot(2, 3, 3) +plt.plot(epoch_count, train_precisions, label='Train Precision') +plt.plot(epoch_count, val_precisions, label='Test Precision') +plt.xlabel('Epoch') +plt.ylabel('Precision') +plt.title('Training and Test Precision') +plt.legend()#为图表添加说明性文本或标签的方法,以便在数据可视化中更好地传达数据信息 + +# 绘制训练和验证召回率图 +plt.subplot(2, 3, 4) +plt.plot(epoch_count, train_recalls, label='Train Recall') +plt.plot(epoch_count, val_recalls, label='Test Recall') +plt.xlabel('Epoch') +plt.ylabel('Recall') +plt.title('Training and Test Recall') +plt.legend() + +# 绘制训练和验证F1分数图 +plt.subplot(2, 3, 5) +plt.plot(epoch_count, train_f1s, label='Train F1 Score') +plt.plot(epoch_count, val_f1s, label='Test F1 Score') +plt.xlabel('Epoch') +plt.ylabel('F1 Score') +plt.title('Training and Test F1 Score') +plt.legend() + +plt.tight_layout() +plt.show() + +if best_confusion_matrix is not None: + # 设置图的大小 + plt.figure(figsize=(10, 8)) + + # 绘制混淆矩阵的热图,Seaborn库中的heatmap函数 + sns.heatmap(best_confusion_matrix, annot=True, fmt="d", cmap='Blues', + xticklabels=cla_dict.values(), yticklabels=cla_dict.values()) # 使用类别名称作为标签 + + # 设置图的标题和坐标轴标签 + plt.title('Confusion Matrix at Best Performance') + plt.xlabel('Predicted Label') + plt.ylabel('True Label') + + # 显示图表 + plt.show() +else: + print("Best confusion matrix is not available.") + +# 写出准确率 +with open('./logs/model.txt', 'w', encoding='utf-8') as fb: + fb.write(str(best_acc)) +