|
|
|
@ -0,0 +1,38 @@
|
|
|
|
|
import torch
|
|
|
|
|
from PIL import Image#从PIL库中导入Image模块,用于处理图像数据
|
|
|
|
|
from torchvision import transforms#用于对图像进行预处理操作
|
|
|
|
|
import json#用于读取和解析JSON格式的文件
|
|
|
|
|
|
|
|
|
|
# 调用
|
|
|
|
|
def improved_alexnet_predict(imgf,model):
|
|
|
|
|
#数据预处理
|
|
|
|
|
data_transform = transforms.Compose(
|
|
|
|
|
[transforms.Resize((224, 224)),#将图像尺寸统一为224*224
|
|
|
|
|
#transforms.Grayscale(num_output_channels=3),
|
|
|
|
|
transforms.ToTensor(),#将图像数据转化为张量,用于训练
|
|
|
|
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#归一化
|
|
|
|
|
img = Image.open(imgf)#使用PIL库中的Image模块打开指定路径的图像文件
|
|
|
|
|
# [N, C, H, W]
|
|
|
|
|
img = data_transform(img)#调用函数对打开的图像数据进行数据预处理操作
|
|
|
|
|
# expand batch dimension
|
|
|
|
|
img = torch.unsqueeze(img, dim=0)#在数据维度上增加一个维度batch,用于模型的输入
|
|
|
|
|
|
|
|
|
|
# read class_indict
|
|
|
|
|
try:
|
|
|
|
|
json_file = open('./config.json', 'r', encoding='utf-8')
|
|
|
|
|
class_indict = json.load(json_file)#将标签文件加载到字典中
|
|
|
|
|
#print(class_indict)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(e)
|
|
|
|
|
exit(-1)
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
# predict class
|
|
|
|
|
output = torch.squeeze(model(img))#对输入的图像进行模型推理得到输出
|
|
|
|
|
predict = torch.softmax(output, dim=0)#对输出进行softmax操作得到预测概率分布
|
|
|
|
|
predict_cla = torch.argmax(predict).numpy()#找到概率最大的类别作为预测结果的类别,预测结果的类别用numpy数组的形式返回
|
|
|
|
|
res=class_indict[str(predict_cla)]
|
|
|
|
|
print(res)#输出结果
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|