You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

54 lines
2.6 KiB

import model_elem
import cv2
import torch
import torchvision
import numpy as np
import torchvision.transforms as transforms
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
# Keypoint R-CNN (ResNet-50 FPN backbone) shipped with torchvision; it detects
# 17 human body keypoints per person.
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Build the model with its default pretrained weights, move it to `device`
# and switch it to evaluation (inference) mode.
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
    weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
)
model = model.to(device)
model.eval()
def Capture_Point(image, confidence=0.9):
    """Detect people in ``image``, annotate them, and return their keypoints.

    Runs the module-level Keypoint R-CNN ``model`` on a single image.
    For every detection whose score is strictly greater than ``confidence``:
    - a rectangle is drawn, labelled ``"<class>:<score>"``;
    - each keypoint with a positive third column is drawn as a dot, labelled
      with its 1-based keypoint number.

    Args:
        image: HxWx4 (RGBA) array, converted with ``cv2.COLOR_RGBA2RGB``.
            An HxWx3 array is also accepted and is assumed to already be RGB.
            Presumably uint8; ``ToTensor`` only rescales uint8 input to [0, 1].
        confidence: Exclusive lower bound on the detection score for a
            detection to be drawn and have its keypoints reported.

    Returns:
        ``(annotated, keypoints)``:
        - ``annotated`` is a new 3-channel array with the drawings; the
          input array itself is not modified.
        - ``keypoints`` maps the keypoint number as a string ("1", "2", ...)
          to its integer ``(x, y)`` pixel position.
    """
    # Work on a new RGB array so the caller's image is never drawn on.
    if image.ndim == 3 and image.shape[2] == 3:
        image = image.copy()
    else:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    image_t = transforms.ToTensor()(image).to(device)
    # Inference only: no_grad avoids building an autograd graph on every call.
    with torch.no_grad():
        pred = model([image_t])[0]

    labels = pred["labels"].cpu().numpy()
    scores = pred["scores"].cpu().numpy()
    boxes = pred["boxes"].cpu().numpy()
    keypoints_all = pred["keypoints"].cpu().numpy()

    # Indices of detections above the threshold. enumerate keeps each
    # detection distinct; list.index() mapped equal scores to the first match.
    keep = [i for i, score in enumerate(scores) if score > confidence]

    for i in keep:
        x1, y1, x2, y2 = (int(v) for v in boxes[i])
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 255))
        label = model_elem.OBJECT_LIST[labels[i]] + ":" + str(np.round(scores[i], 2))
        cv2.putText(image, label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 255), 2)

    # Keypoint number (1-based, as a string) -> (x, y).
    # All detections share the same keys, so with several people a later
    # detection overwrites an earlier one.
    my_result = {}
    for i in keep:
        for number, (kx, ky, visibility) in enumerate(keypoints_all[i], start=1):
            # NOTE(review): torchvision documents this column as keypoint
            # visibility (per-keypoint scores are in pred["keypoints_scores"]),
            # not a confidence. Confirm this check filters anything.
            if visibility > 0:
                point = (int(kx), int(ky))
                cv2.circle(image, point, 1, (0, 0, 255), 4)
                key = str(number)
                cv2.putText(image, key, point, cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 255), 2)
                my_result[key] = point
    return image, my_result