"""Evaluate a trained ResNet checkpoint on the chest X-ray test set.

Loads the model weights, runs inference over the test image folder, prints
the accuracy of each batch, and finally prints a scikit-learn classification
report over all test samples.
"""
import random

import numpy as np
import torch
import torchvision
from sklearn import metrics
from torch.utils.data import DataLoader

from modules import MyTransforms
from modules.ResNet import ResNet

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root folder of the test split (read through torchvision's ImageFolder).
DATA_PATH = "./datasets/chest_xray/test"
# Checkpoint whose state dict is loaded into the model for evaluation.
MODEL_PATH = "./models/512/model_epoch_2.pth"
# Number of worker processes and batch size handed to the DataLoader.
NUM_WORKERS = 2
BATCH_SIZE = 16
# Seed shared by all random number generators seeded below.
RANDOM_SEED = 114514

# Seed the PyTorch, NumPy and stdlib RNGs so runs are reproducible.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)


def test():
    """Evaluate the checkpoint at MODEL_PATH on the test set under DATA_PATH.

    Builds the ResNet model, loads its weights, and runs inference over the
    whole test ImageFolder with gradient tracking disabled. The accuracy of
    every batch is printed as it is processed, followed by scikit-learn's
    classification report computed over all samples.

    Returns:
        None. All results are written to standard output.
    """
    # Load the model.
    model = ResNet()
    model.to(DEVICE)
    # map_location remaps tensors that were saved from a different device
    # (e.g. a CUDA-trained checkpoint opened on a CPU-only machine) to DEVICE.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

    # Data preprocessing.
    test_transform = MyTransforms.TestCompose()

    # Load the dataset.
    test_dataset = torchvision.datasets.ImageFolder(
        root=DATA_PATH,
        transform=test_transform,
    )
    print(test_dataset.class_to_idx)
    print(f"Number of testing samples: {len(test_dataset)}")

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )

    # Evaluate the model.
    model.eval()
    y_test = []
    y_pred = []

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(inputs)
            # Index of the highest score along dim 1 is the predicted class.
            predicted = outputs.argmax(dim=1)

            print('step: ', i, ', correct:', 100.0 * (predicted == labels).sum().item() / labels.size(0), '%')

            y_test.extend(labels.tolist())
            y_pred.extend(predicted.tolist())

    print("classification_report:")
    # Passing `labels` explicitly keeps the report aligned with target_names
    # even if some class never appears in either y_test or y_pred; without it
    # scikit-learn raises ValueError on the length mismatch.
    class_indices = list(range(len(test_dataset.classes)))
    print(metrics.classification_report(
        y_test,
        y_pred,
        labels=class_indices,
        target_names=test_dataset.classes,
    ))


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    test()