|
|
|
@ -0,0 +1,54 @@
|
|
|
|
|
import model_elem
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
import torch
|
|
|
|
|
import torchvision
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torchvision.transforms as transforms
|
|
|
|
|
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
|
|
|
|
|
# 加载pytorch提供的keypointrcnn_resnet50_fpn()网络模型,可以对17个人体关键点进行检测。
|
|
|
|
|
#如果有可用的GPU,则使用GPU,否则使用CPU。
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
#pytorch模型的python接口
|
|
|
|
|
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
|
|
|
|
|
model.to(device)
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
|
def Capture_Point(image, confidence=0.9):
|
|
|
|
|
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
|
|
|
|
# 准备需要检测的图像
|
|
|
|
|
transform_d = transforms.Compose([transforms.ToTensor()])
|
|
|
|
|
image_t = transform_d(image) ## 对图像进行变换
|
|
|
|
|
pred = model([image_t.to(device)]) ## 将模型作用到图像上
|
|
|
|
|
# 检测出目标的类别和得分
|
|
|
|
|
pred_class = [model_elem.OBJECT_LIST[ii] for ii in list(pred[0]['labels'].cpu().numpy())]
|
|
|
|
|
pred_score = list(pred[0]['scores'].detach().cpu().numpy())
|
|
|
|
|
# 检测出目标的边界框
|
|
|
|
|
pred_boxes = [[ii[0], ii[1], ii[2], ii[3]] for ii in list(pred[0]['boxes'].detach().cpu().numpy())]
|
|
|
|
|
## 只保留识别的概率大约 confidence 的结果。
|
|
|
|
|
pred_index = [pred_score.index(x) for x in pred_score if x > confidence]
|
|
|
|
|
for index in pred_index:
|
|
|
|
|
box = pred_boxes[index]
|
|
|
|
|
box = [int(i) for i in box]
|
|
|
|
|
cv2.rectangle(image,(int(box[0]),int(box[1])),(int(box[2]),int(box[3])),(0,255,255))
|
|
|
|
|
texts = pred_class[index] + ":" + str(np.round(pred_score[index], 2))
|
|
|
|
|
cv2.putText(image, texts,(box[0], box[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 255), 2)
|
|
|
|
|
|
|
|
|
|
pred_keypoint = pred[0]["keypoints"]
|
|
|
|
|
# 检测到实例的关键点
|
|
|
|
|
pred_keypoint = pred_keypoint[pred_index].detach().cpu().numpy()
|
|
|
|
|
# 对实例数量索引
|
|
|
|
|
my_result = {}
|
|
|
|
|
for index in range(pred_keypoint.shape[0]):
|
|
|
|
|
# 对每个实例的关键点索引
|
|
|
|
|
keypoints = pred_keypoint[index]
|
|
|
|
|
for ii in range(keypoints.shape[0]): ##ii为第几个坐标点
|
|
|
|
|
x = int(keypoints[ii, 0]) #x坐标
|
|
|
|
|
y = int(keypoints[ii, 1]) #y坐标
|
|
|
|
|
visi = keypoints[ii, 2] #置信度
|
|
|
|
|
if visi > 0.:
|
|
|
|
|
cv2.circle(image, (int(x),int(y)), 1, (0,0,255),4)
|
|
|
|
|
texts = str(ii+1)
|
|
|
|
|
cv2.putText(image,texts, (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 255), 2)
|
|
|
|
|
my_result[texts] = (int(x), int(y))
|
|
|
|
|
return image,my_result
|