You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

29 lines
926 B

import torch
import torchvision.transforms as transforms
from PIL import Image
from MyV16 import VGG16_CIFAR10
# Single-image inference script: restore the trained weights, preprocess
# "RC4.jpg", run one forward pass and print the most probable class.

# Load the trained weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU be restored
# on a machine without CUDA; the input tensor built below stays on the CPU,
# so the weights have to live there as well.
model = VGG16_CIFAR10()
model.load_state_dict(torch.load('vgg16_cifar10_best.pth', map_location='cpu'))
model.eval()  # switch the module to evaluation mode before predicting

# Preprocessing.
# Resize((32, 32)) forces BOTH edges to 32 px. Resize(32) only scales the
# shorter edge and keeps the aspect ratio, so a non-square photo would yield
# a tensor that is not 32x32.
# NOTE(review): the network presumably expects 32x32 CIFAR-sized input and the
# Normalize constants are presumably the training-set mean/std -- confirm
# against the training script.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
])

# Open the file in a context manager so the handle is released; convert()
# returns an independent RGB copy that remains usable after the file closes.
with Image.open("RC4.jpg") as raw_image:
    image = raw_image.convert('RGB')

# unsqueeze(0) adds the leading batch dimension for a batch of one image.
input_tensor = transform(image).unsqueeze(0)

# Inference: no gradients are needed, so disable autograd tracking.
with torch.no_grad():
    outputs = model(input_tensor)
    # outputs[0] selects the single image of the batch; softmax over its
    # only remaining dimension turns the scores into probabilities.
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
    predicted_class = torch.argmax(probabilities).item()

# Report the result. The tuple order maps a class index to its label.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(f"预测类别: {classes[predicted_class]} (概率: {probabilities[predicted_class]:.2%})")