You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

54 lines
2.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import model_elem
import cv2
import torch
import torchvision
import numpy as np
import torchvision.transforms as transforms
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
# Load torchvision's keypointrcnn_resnet50_fpn() network, which detects
# 17 human-body keypoints.
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build the network with its default weights and place it on the chosen
# device (Module.to returns the module itself), then switch it to
# evaluation mode for inference.
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
    weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT,
).to(device)
model.eval()
def Capture_Point(image, confidence=0.9):
    """Detect objects and keypoints in an image and draw them onto a copy.

    The input is converted with ``cv2.COLOR_RGBA2RGB`` (so a 4-channel image
    is expected), run through the module-level ``model``, and every detection
    whose score is strictly greater than ``confidence`` is annotated on the
    converted image: a bounding box captioned ``"<class>:<score>"`` and a
    numbered dot for each keypoint whose visibility value is positive.

    Args:
        image: Image array accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_RGBA2RGB``.
        confidence: Score threshold; only detections scoring strictly above
            it are kept.

    Returns:
        A tuple ``(annotated_image, keypoints)``. ``keypoints`` maps the
        1-based keypoint number (as a string) to its integer ``(x, y)``
        position. The dict is keyed by keypoint number only, so when several
        detections are kept, a later detection's entry overwrites an earlier
        one with the same number.
    """
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    # Convert the image to a tensor for the network.
    image_t = transforms.ToTensor()(image)

    # Inference only: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        pred = model([image_t.to(device)])[0]

    labels = pred["labels"].cpu().numpy()
    scores = pred["scores"].cpu().numpy()
    boxes = pred["boxes"].cpu().numpy()
    keypoints = pred["keypoints"].cpu().numpy()

    # Keep detections scoring above the threshold. Selecting by position
    # (rather than list.index on the score) keeps detections that happen to
    # share the same score distinct.
    keep = [idx for idx, score in enumerate(scores) if score > confidence]

    # Draw the bounding box and "<class>:<score>" caption of each detection.
    for idx in keep:
        x1, y1, x2, y2 = (int(v) for v in boxes[idx])
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 255))
        caption = model_elem.OBJECT_LIST[labels[idx]] + ":" + str(np.round(scores[idx], 2))
        cv2.putText(image, caption, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 255), 2)

    # Draw and collect the keypoints of each kept detection.
    my_result = {}
    for idx in keep:
        # Each keypoint row holds (x, y, visibility).
        for number, (kx, ky, visibility) in enumerate(keypoints[idx], start=1):
            if visibility > 0.:
                point = (int(kx), int(ky))
                tag = str(number)
                cv2.circle(image, point, 1, (0, 0, 255), 4)
                cv2.putText(image, tag, point, cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 255), 2)
                my_result[tag] = point
    return image, my_result