From 68d4aa4598377d29445a0325544b24423bec96c2 Mon Sep 17 00:00:00 2001 From: p9kh64cfp <1047063963@qq.com> Date: Tue, 31 Dec 2024 11:21:34 +0800 Subject: [PATCH] ADD file via upload --- predict_improved_alexnet.py | 38 +++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 predict_improved_alexnet.py diff --git a/predict_improved_alexnet.py b/predict_improved_alexnet.py new file mode 100644 index 0000000..149bd1a --- /dev/null +++ b/predict_improved_alexnet.py @@ -0,0 +1,38 @@ +import torch +from PIL import Image#从PIL库中导入Image模块,用于处理图像数据 +from torchvision import transforms#用于对图像进行预处理操作 +import json#用于读取和解析JSON格式的文件 + +# 调用 +def improved_alexnet_predict(imgf,model): + #数据预处理 + data_transform = transforms.Compose( + [transforms.Resize((224, 224)),#将图像尺寸统一为224*224 + #transforms.Grayscale(num_output_channels=3), + transforms.ToTensor(),#将图像数据转化为张量,用于训练 + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#归一化 + img = Image.open(imgf)#使用PIL库中的Image模块打开指定路径的图像文件 + # [N, C, H, W] + img = data_transform(img)#调用函数对打开的图像数据进行数据预处理操作 + # expand batch dimension + img = torch.unsqueeze(img, dim=0)#在数据维度上增加一个维度batch,用于模型的输入 + + # read class_indict + try: + json_file = open('./config.json', 'r', encoding='utf-8') + class_indict = json.load(json_file)#将标签文件加载到字典中 + #print(class_indict) + except Exception as e: + print(e) + exit(-1) + + with torch.no_grad(): + # predict class + output = torch.squeeze(model(img))#对输入的图像进行模型推理得到输出 + predict = torch.softmax(output, dim=0)#对输出进行softmax操作得到预测概率分布 + predict_cla = torch.argmax(predict).numpy()#找到概率最大的类别作为预测结果的类别,预测结果的类别用numpy数组的形式返回 + res=class_indict[str(predict_cla)] + print(res)#输出结果 + return res + +