You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
cyz_software/predict_improved_alexnet.py

39 lines
1.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import torch
from PIL import Image#从PIL库中导入Image模块用于处理图像数据
from torchvision import transforms#用于对图像进行预处理操作
import json#用于读取和解析JSON格式的文件
# 调用
def improved_alexnet_predict(imgf,model):
#数据预处理
data_transform = transforms.Compose(
[transforms.Resize((224, 224)),#将图像尺寸统一为224*224
#transforms.Grayscale(num_output_channels=3),
transforms.ToTensor(),#将图像数据转化为张量,用于训练
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#归一化
img = Image.open(imgf)#使用PIL库中的Image模块打开指定路径的图像文件
# [N, C, H, W]
img = data_transform(img)#调用函数对打开的图像数据进行数据预处理操作
# expand batch dimension
img = torch.unsqueeze(img, dim=0)#在数据维度上增加一个维度batch用于模型的输入
# read class_indict
try:
json_file = open('./config.json', 'r', encoding='utf-8')
class_indict = json.load(json_file)#将标签文件加载到字典中
#print(class_indict)
except Exception as e:
print(e)
exit(-1)
with torch.no_grad():
# predict class
output = torch.squeeze(model(img))#对输入的图像进行模型推理得到输出
predict = torch.softmax(output, dim=0)#对输出进行softmax操作得到预测概率分布
predict_cla = torch.argmax(predict).numpy()#找到概率最大的类别作为预测结果的类别预测结果的类别用numpy数组的形式返回
res=class_indict[str(predict_cla)]
print(res)#输出结果
return res