You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
29 lines
926 B
29 lines
926 B
"""Classify a single image with a VGG16 model trained on CIFAR-10.

Loads the weights in ``vgg16_cifar10_best.pth``, preprocesses ``RC4.jpg`` into a
normalized 32x32 tensor, runs one forward pass, and prints the most likely
CIFAR-10 class together with its softmax probability.
"""
import torch
import torchvision.transforms as transforms
from PIL import Image

from MyV16 import VGG16_CIFAR10

CHECKPOINT_PATH = 'vgg16_cifar10_best.pth'
IMAGE_PATH = "RC4.jpg"

# Load the pretrained weights. map_location="cpu" lets a checkpoint that was
# saved from a GPU run be loaded on a CPU-only machine; the model is never
# moved to a GPU here, so this does not change behaviour elsewhere.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model = VGG16_CIFAR10()
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
model.eval()  # switch layers such as dropout/batch-norm (if any) to inference mode

# Preprocessing.
# Resize to an explicit (height, width) of (32, 32): with an int size,
# Resize only matches the *shorter* edge to 32, so a non-square photo would
# produce a non-square tensor unlike the 32x32 CIFAR-10 training images.
# The Normalize mean/std must be identical to the ones used during training.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
])

# Open the file in a context manager so its handle is released; convert()
# returns an independent in-memory image that remains valid after the `with`.
with Image.open(IMAGE_PATH) as raw_image:
    image = raw_image.convert('RGB')
input_tensor = transform(image).unsqueeze(0)  # add a batch dimension of size 1

# Inference (no autograd bookkeeping needed).
with torch.no_grad():
    outputs = model(input_tensor)
    # Assumes the model returns (batch, num_classes) logits -- take the only
    # sample in the batch and turn its logits into class probabilities.
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
    predicted_class = torch.argmax(probabilities).item()

# Report the result. The tuple follows the standard CIFAR-10 label order and
# must match the label indices used when the model was trained.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
confidence = probabilities[predicted_class].item()
print(f"预测类别: {classes[predicted_class]} (概率: {confidence:.2%})")