ADD file via upload

main
p9kh64cfp 8 months ago
parent 16f181daea
commit 68d4aa4598

@ -0,0 +1,38 @@
import torch
from PIL import Image#从PIL库中导入Image模块用于处理图像数据
from torchvision import transforms#用于对图像进行预处理操作
import json#用于读取和解析JSON格式的文件
# 调用
def improved_alexnet_predict(imgf,model):
#数据预处理
data_transform = transforms.Compose(
[transforms.Resize((224, 224)),#将图像尺寸统一为224*224
#transforms.Grayscale(num_output_channels=3),
transforms.ToTensor(),#将图像数据转化为张量,用于训练
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#归一化
img = Image.open(imgf)#使用PIL库中的Image模块打开指定路径的图像文件
# [N, C, H, W]
img = data_transform(img)#调用函数对打开的图像数据进行数据预处理操作
# expand batch dimension
img = torch.unsqueeze(img, dim=0)#在数据维度上增加一个维度batch用于模型的输入
# read class_indict
try:
json_file = open('./config.json', 'r', encoding='utf-8')
class_indict = json.load(json_file)#将标签文件加载到字典中
#print(class_indict)
except Exception as e:
print(e)
exit(-1)
with torch.no_grad():
# predict class
output = torch.squeeze(model(img))#对输入的图像进行模型推理得到输出
predict = torch.softmax(output, dim=0)#对输出进行softmax操作得到预测概率分布
predict_cla = torch.argmax(predict).numpy()#找到概率最大的类别作为预测结果的类别预测结果的类别用numpy数组的形式返回
res=class_indict[str(predict_cla)]
print(res)#输出结果
return res
Loading…
Cancel
Save