From 70395a74c01218791adbb1da2eb726e49304444c Mon Sep 17 00:00:00 2001 From: pyhqos7bg Date: Thu, 30 May 2024 15:17:35 +0800 Subject: [PATCH] ADD file via upload --- model_predict.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 model_predict.py diff --git a/model_predict.py b/model_predict.py new file mode 100644 index 0000000..c0a71b6 --- /dev/null +++ b/model_predict.py @@ -0,0 +1,33 @@ +from PIL import Image +import numpy as np +import paddle.fluid as fluid +from 口罩检测.util import train_parameters +from 口罩检测.VGGNet import VGGNet +def load_image(img_path): + img =Image.open(img_path) + if img.mode !='RGB': + img = img.covert('RGB') + img = img.resize((244,244),Image.BILINEAR) + img = np.array(img).astype('float32') + img = img.transpose((2,0,1)) + img = img/255.0 + return img + +label_dict = train_parameters['label_dict'] + +#模型预测 +with fluid.dygraph.guard(): + model,_ = fluid.dygraph.load_dygraph('vgg') + vgg = VGGNet() + vgg.eval() + infer_path='./unmask.jpg' + img = Image.open(infer_path) + + x_data = load_image(infer_path) + x_data = np.array(x_data) + x_data = x_data[np.newaxis,:,:,:] + x_data= fluid.dygraph.to_variable(x_data) + out = vgg(x_data) + result = np.argmax(out.numpy()) + print(label_dict) + print("被预测的图片为:{}".format(label_dict[str(result)]))