From 9e4144ba6403437e4e553b6b9728b4fc82909880 Mon Sep 17 00:00:00 2001 From: pbyf83ift <1030227026@qq.com> Date: Tue, 2 Jul 2024 22:04:07 +0800 Subject: [PATCH] ADD file via upload --- calculate_model.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 calculate_model.py diff --git a/calculate_model.py b/calculate_model.py new file mode 100644 index 0000000..e4cad47 --- /dev/null +++ b/calculate_model.py @@ -0,0 +1,54 @@ +import model_elem + +import cv2 +import torch +import torchvision +import numpy as np +import torchvision.transforms as transforms +from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights +# 加载pytorch提供的keypointrcnn_resnet50_fpn()网络模型,可以对17个人体关键点进行检测。 +#如果有可用的GPU,则使用GPU,否则使用CPU。 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +#pytorch模型的python接口 +model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT) +model.to(device) +model.eval() + +def Capture_Point(image, confidence=0.9): + image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) + # 准备需要检测的图像 + transform_d = transforms.Compose([transforms.ToTensor()]) + image_t = transform_d(image) ## 对图像进行变换 + pred = model([image_t.to(device)]) ## 将模型作用到图像上 + # 检测出目标的类别和得分 + pred_class = [model_elem.OBJECT_LIST[ii] for ii in list(pred[0]['labels'].cpu().numpy())] + pred_score = list(pred[0]['scores'].detach().cpu().numpy()) + # 检测出目标的边界框 + pred_boxes = [[ii[0], ii[1], ii[2], ii[3]] for ii in list(pred[0]['boxes'].detach().cpu().numpy())] + ## 只保留识别的概率大约 confidence 的结果。 + pred_index = [pred_score.index(x) for x in pred_score if x > confidence] + for index in pred_index: + box = pred_boxes[index] + box = [int(i) for i in box] + cv2.rectangle(image,(int(box[0]),int(box[1])),(int(box[2]),int(box[3])),(0,255,255)) + texts = pred_class[index] + ":" + str(np.round(pred_score[index], 2)) + cv2.putText(image, texts,(box[0], box[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 255), 2) + + pred_keypoint = pred[0]["keypoints"] + # 检测到实例的关键点 + pred_keypoint = pred_keypoint[pred_index].detach().cpu().numpy() + # 对实例数量索引 + my_result = {} + for index in range(pred_keypoint.shape[0]): + # 对每个实例的关键点索引 + keypoints = pred_keypoint[index] + for ii in range(keypoints.shape[0]): ##ii为第几个坐标点 + x = int(keypoints[ii, 0]) #x坐标 + y = int(keypoints[ii, 1]) #y坐标 + visi = keypoints[ii, 2] #置信度 + if visi > 0.: + cv2.circle(image, (int(x),int(y)), 1, (0,0,255),4) + texts = str(ii+1) + cv2.putText(image,texts, (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 255), 2) + my_result[texts] = (int(x), int(y)) + return image,my_result \ No newline at end of file