|
|
|
|
@ -5,6 +5,7 @@ from torch.utils.data import DataLoader
|
|
|
|
|
import torchvision
|
|
|
|
|
from modules import MyTransforms
|
|
|
|
|
from modules.ResNet import ResNet
|
|
|
|
|
from sklearn import metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
|
|
|
|
|
@ -50,10 +51,8 @@ def test():
|
|
|
|
|
# Evaluate the model
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
|
test_total = 0
|
|
|
|
|
test_correct = 0
|
|
|
|
|
test_label_correct = {0:0, 1:0}
|
|
|
|
|
test_label_total = {0:0, 1:0}
|
|
|
|
|
y_test = []
|
|
|
|
|
y_pred = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
@ -62,22 +61,14 @@ def test():
|
|
|
|
|
labels = labels.to(DEVICE)
|
|
|
|
|
outputs = model(inputs)
|
|
|
|
|
_, predicted = torch.max(outputs.data, 1)
|
|
|
|
|
test_total += labels.size(0)
|
|
|
|
|
test_correct += (predicted == labels).sum().item()
|
|
|
|
|
|
|
|
|
|
# Compute label accuracy
|
|
|
|
|
for label, pred in zip(labels, predicted):
|
|
|
|
|
label = label.item()
|
|
|
|
|
pred = pred.item()
|
|
|
|
|
if label == pred:
|
|
|
|
|
test_label_correct[label] += 1
|
|
|
|
|
test_label_total[label] += 1
|
|
|
|
|
|
|
|
|
|
print('step: ', i, ', correct:', 100.0 * (predicted == labels).sum().item() / labels.size(0), '%')
|
|
|
|
|
|
|
|
|
|
y_test.extend(labels.cpu().numpy().tolist())
|
|
|
|
|
y_pred.extend(predicted.cpu().numpy().tolist())
|
|
|
|
|
|
|
|
|
|
print(f'Test Accuracy: {100. * test_correct / test_total:.2f}%')
|
|
|
|
|
acc = {label: 100.0 * test_label_correct.get(label, 0) / test_label_total.get(label, 1) for label in test_label_correct}
|
|
|
|
|
print(f'Test Label Accuracy: {acc}')
|
|
|
|
|
print("classification_report:")
|
|
|
|
|
print(metrics.classification_report(y_test, y_pred, target_names=test_dataset.classes))
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the standalone evaluation pass.
if __name__ == '__main__':
    test()
|
|
|
|
|
|