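# NOTE: The lines below are residue from the repository web page this file was
# copied from; they are not part of the training script.
#   You cannot select more than 25 topics.
#   Topics must start with a letter or number, can include dashes ('-'), and
#   can be up to 35 characters long.
#   231 lines, 6.8 KiB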
import os
import random
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torchvision
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, random_split

from modules import ResNet
from modules import MyTransforms

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of epochs the training loop runs.
EPOCH = 15

# Added to the epoch number in checkpoint file names
# (model_epoch_{epoch + 1 + EPOCH_OFFSET}.pth).
EPOCH_OFFSET = 0

# Initial learning rate handed to the Adam optimizer.
LR = 0.001

# Root folder passed to torchvision.datasets.ImageFolder.
DATA_PATH = "./datasets/chest_xray/train"
# Checkpoint to load before training; weights are loaded only if this path exists.
MODEL_PATH = ""
# Directory that receives per-epoch checkpoints and the accuracy plot.
SAVE_MODEL_PATH = "./models/512"
# NOTE(review): TRAIN_RATIO / TEST_RATIO are not referenced by the visible
# training code, which splits the data 0.9 / 0.1 -- confirm which is intended.
TRAIN_RATIO = 0.8
TEST_RATIO = 0.2
RANDOM_SEED = 114514
# DataLoader worker processes and mini-batch size.
NUM_WORKERS = 4
BATCH_SIZE = 32

# Seed every random number generator in use to keep runs reproducible.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

def percentage(correct, total):
    """Return ``correct / total`` as a percentage, or 0.0 when ``total`` is 0."""
    return 100.0 * correct / total if total else 0.0


def update_label_counts(labels, predictions, correct, total):
    """Accumulate per-class sample and hit counts in place.

    Args:
        labels: iterable of ground-truth class indices.
        predictions: iterable of predicted class indices, aligned with ``labels``.
        correct: dict mapping class index -> number of correct predictions.
        total: dict mapping class index -> number of samples seen.
    """
    for label, pred in zip(labels, predictions):
        total[label] += 1
        if label == pred:
            correct[label] += 1


def label_accuracy(correct, total):
    """Return per-class accuracy in percent.

    A class without any samples scores 0.0 instead of raising
    ZeroDivisionError.
    """
    return {label: percentage(correct[label], total[label]) for label in correct}


def evaluate(model, loader, device, class_ids):
    """Measure the accuracy of ``model`` on ``loader`` without tracking gradients.

    The model is switched to eval mode and is left in that mode; callers that
    continue training must call ``model.train()`` again.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        loader: iterable yielding ``(inputs, labels)`` tensor pairs.
        device: device the batches are moved to before the forward pass.
        class_ids: class indices to report per-class accuracy for.

    Returns:
        Tuple ``(overall accuracy in percent, {class index: accuracy in percent})``.
    """
    hits = 0
    seen = 0
    hits_per_label = dict.fromkeys(class_ids, 0)
    seen_per_label = dict.fromkeys(class_ids, 0)

    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)

            _, predicted = torch.max(outputs, 1)
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
            # One host transfer per batch instead of one .item() per sample.
            update_label_counts(labels.tolist(), predicted.tolist(),
                                hits_per_label, seen_per_label)

    return percentage(hits, seen), label_accuracy(hits_per_label, seen_per_label)


def plot_accuracy(train_total_acc, val_total_acc, train_label_acc, val_label_acc,
                  label_dict, save_dir):
    """Plot overall and per-class accuracy per epoch, save the figure and show it.

    Args:
        train_total_acc: overall training accuracy (percent), one entry per epoch.
        val_total_acc: overall validation accuracy (percent), one entry per epoch.
        train_label_acc: per-epoch dicts of class index -> training accuracy.
        val_label_acc: per-epoch dicts of class index -> validation accuracy.
        label_dict: dict mapping class index -> class name (used in the legend).
        save_dir: directory that receives ``accuracy_plots.png``.
    """
    epochs = range(1, len(train_total_acc) + 1)

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_total_acc, 'b-', label='Train Acc')
    plt.plot(epochs, val_total_acc, 'r-', label='Val Acc')
    plt.title('Total Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.subplot(1, 2, 2)
    for label in sorted(label_dict):
        train_curve = [acc[label] for acc in train_label_acc]
        val_curve = [acc[label] for acc in val_label_acc]
        train_color, val_color = ('b-', 'r--') if label == 0 else ('g-', 'y--')
        plt.plot(epochs, train_curve, train_color, label=f'Train Label {label_dict[label]}')
        plt.plot(epochs, val_curve, val_color, label=f'Val Label {label_dict[label]}')

    plt.title('Label Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'accuracy_plots.png'))
    plt.show()


def main():
    """Train ResNet.ResNet on the ImageFolder dataset at DATA_PATH.

    Reads the module-level configuration constants, saves a checkpoint after
    every epoch into SAVE_MODEL_PATH, validates after every epoch and finally
    plots the accuracy curves.

    Raises:
        ValueError: if a class of the dataset has no samples, since its loss
            weight would be undefined.
    """
    print('device:', DEVICE)

    # Data preprocessing.
    dataset_transform = MyTransforms.TrainCompose()

    # Load the dataset.
    dataset = torchvision.datasets.ImageFolder(
        root=DATA_PATH,
        transform=dataset_transform
    )

    # Label dictionary, inverted to map class index -> class name.
    label_dict = {index: name for name, index in dataset.class_to_idx.items()}
    print(label_dict)
    class_ids = sorted(label_dict)

    # Per-class loss weights (inverse class frequency) to counter class imbalance.
    class_counts = Counter(dataset.targets)
    empty_classes = [label_dict[c] for c in class_ids if class_counts[c] == 0]
    if empty_classes:
        raise ValueError(f'no samples found for class(es): {empty_classes}')
    class_weights = torch.tensor(
        [len(dataset) / class_counts[c] for c in class_ids]
    ).to(DEVICE)
    print('class weights:', class_weights)

    # NOTE(review): the split is hard-coded to 0.9 / 0.1 while TRAIN_RATIO and
    # TEST_RATIO (0.8 / 0.2) are unused -- confirm which ratio is intended.
    # Both subsets come from the same ImageFolder, so validation samples pass
    # through MyTransforms.TrainCompose as well; if that pipeline contains
    # random augmentation, validation accuracy will be noisy -- TODO confirm.
    train_dataset, val_dataset = random_split(dataset, [0.9, 0.1])
    print(f"Number of training samples: {len(train_dataset)}")
    print(f"Number of validation samples: {len(val_dataset)}")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS
    )

    model = ResNet.ResNet().to(DEVICE)

    if MODEL_PATH and os.path.exists(MODEL_PATH):
        # A checkpoint path was given and exists: load its weights.
        print('loading model...')
        print(MODEL_PATH)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    else:
        # No usable checkpoint: keep the freshly initialised weights.
        print('initializing model...')

    # Loss function.
    loss_func = nn.CrossEntropyLoss(weight=class_weights)

    # Optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Learning-rate schedule. The scheduler is stepped once per batch (see the
    # training loop), so step_size=10 means "every 10 optimizer steps".
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

    # Accuracy history for the final report and plots.
    train_total_acc = []
    train_label_acc = []
    val_total_acc = []
    val_label_acc = []

    # Create the checkpoint directory.
    os.makedirs(SAVE_MODEL_PATH, exist_ok=True)

    min_lr = 1e-6
    lr_frozen = False  # set once the learning rate has been clamped to min_lr

    for epoch in range(EPOCH):
        # evaluate() leaves the model in eval mode at the end of every epoch,
        # so training mode has to be restored here, not just once up front.
        model.train()

        # Checked once per epoch: after the rate drops below the floor, pin it
        # there and stop stepping the scheduler.
        if optimizer.param_groups[0]['lr'] < min_lr:
            lr_frozen = True
            optimizer.param_groups[0]['lr'] = min_lr
        print('Epoch [{}/{}], lr {}'.format(epoch + 1, EPOCH, optimizer.param_groups[0]['lr']))

        train_correct = 0
        train_total = 0
        label_correct = dict.fromkeys(class_ids, 0)
        label_total = dict.fromkeys(class_ids, 0)

        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(inputs)
            loss = loss_func(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if not lr_frozen:
                scheduler.step()

            # Overall training accuracy.
            _, predicted = torch.max(outputs, 1)
            batch_correct = (predicted == labels).sum().item()
            train_total += labels.size(0)
            train_correct += batch_correct

            # Per-class training accuracy.
            update_label_counts(labels.tolist(), predicted.tolist(),
                                label_correct, label_total)

            print('step: ', i, ',loss: ', loss.item(), ', correct:', 100.0 * batch_correct / labels.size(0), '%')

        train_total_acc.append(percentage(train_correct, train_total))
        train_label_acc.append(label_accuracy(label_correct, label_total))

        # Save a checkpoint; the message reports the same epoch number that is
        # used in the file name.
        saved_epoch = epoch + 1 + EPOCH_OFFSET
        print(f'Saving model(epoch={saved_epoch})...')
        torch.save(model.state_dict(),
                   os.path.join(SAVE_MODEL_PATH, 'model_epoch_{}.pth'.format(saved_epoch)))

        # Validation accuracy.
        val_acc, val_acc_per_label = evaluate(model, val_loader, DEVICE, class_ids)
        val_total_acc.append(val_acc)
        val_label_acc.append(val_acc_per_label)

        print(f'Train Accuracy: {train_total_acc[-1]:.2f}%')
        print(f'Validation Accuracy: {val_total_acc[-1]:.2f}%')
        print(f'Train Label Accuracy: {train_label_acc[-1]}')
        print(f'Validation Label Accuracy: {val_label_acc[-1]}')

    # Plot the results.
    plot_accuracy(train_total_acc, val_total_acc, train_label_acc, val_label_acc,
                  label_dict, SAVE_MODEL_PATH)


if __name__ == '__main__':
    main()