Keypoint list traversal, keypoint relative position calculation

main
o__-xpf-__o 1 week ago
parent a0ff98ae48
commit 0e5e49df79

@ -1,2 +1,2 @@
#Fri Apr 25 19:53:31 CST 2025
java.home=/Applications/Android Studio.app/Contents/jbr/Contents/Home
#Sat May 24 23:39:46 CST 2025
java.home=D\:\\Andr\\jbr

Binary file not shown.

@ -5,6 +5,15 @@
<SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" />
</SelectionState>
<SelectionState runConfigName="MovenetLightningTest">
<option name="selectionMode" value="DROPDOWN" />
</SelectionState>
<SelectionState runConfigName="MovenetThunderTest">
<option name="selectionMode" value="DROPDOWN" />
</SelectionState>
<SelectionState runConfigName="MovenetMultiPoseTest">
<option name="selectionMode" value="DROPDOWN" />
</SelectionState>
</selectionStates>
</component>
</project>

@ -9,8 +9,7 @@ You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
See the License for the specific language governing permissions and limitations under the License.
==============================================================================
*/
@ -26,37 +25,49 @@ import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
// Run the tests with the AndroidJUnit4 test framework
@RunWith(AndroidJUnit4::class)
class MovenetLightningTest {
// Constants used by the tests
companion object {
private const val TEST_INPUT_IMAGE1 = "image1.png"
private const val TEST_INPUT_IMAGE2 = "image2.jpg"
private const val ACCEPTABLE_ERROR = 21f
private const val TEST_INPUT_IMAGE1 = "image1.png" // test input image 1
private const val TEST_INPUT_IMAGE2 = "image2.jpg" // test input image 2
private const val ACCEPTABLE_ERROR = 21f // acceptable error threshold
}
// Pose detector instance under test
private lateinit var poseDetector: PoseDetector
// Application context
private lateinit var appContext: Context
// Expected detection results
private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>
// Initialization run before each test
@Before
fun setup() {
// Get the application context
appContext = InstrumentationRegistry.getInstrumentation().targetContext
// Create a PoseDetector instance using the Lightning variant of the MoveNet model on the CPU
poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning)
// Load the expected detection results (data from the CSV file)
expectedDetectionResult =
EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")
}
// Test the pose estimation result for the first input image
@Test
fun testPoseEstimationResultWithImage1() {
// Load input image 1
val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)
// As Movenet use previous frame to optimize detection result, we run it multiple times
// using the same image to improve result.
// MoveNet uses the previous frame to optimize the detection result, so run the same image several times to improve it
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
// Take the first person from the pose estimation result
val person = poseDetector.estimatePoses(input)[0]
// Verify that the error between the detection result and the expected values is within the acceptable range
EvaluationUtils.assertPoseDetectionResult(
person,
expectedDetectionResult[0],
@ -64,20 +75,23 @@ class MovenetLightningTest {
)
}
// Test the pose estimation result for the second input image
@Test
fun testPoseEstimationResultWithImage2() {
// Load input image 2
val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)
// As Movenet use previous frame to optimize detection result, we run it multiple times
// using the same image to improve result.
// MoveNet uses the previous frame to optimize the detection result, so run the same image several times to improve it
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
// Take the first person from the pose estimation result
val person = poseDetector.estimatePoses(input)[0]
// Verify that the error between the detection result and the expected values is within the acceptable range
EvaluationUtils.assertPoseDetectionResult(
person,
expectedDetectionResult[1],
ACCEPTABLE_ERROR
)
}
}
}

@ -25,7 +25,12 @@ import android.os.Process
import android.view.SurfaceView
import android.view.View
import android.view.WindowManager
import android.widget.*
import android.widget.AdapterView
import android.widget.ArrayAdapter
import android.widget.CompoundButton
import android.widget.Spinner
import android.widget.TextView
import android.widget.Toast
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.appcompat.widget.SwitchCompat
@ -36,7 +41,13 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import org.tensorflow.lite.examples.poseestimation.camera.CameraSource
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.ml.*
import org.tensorflow.lite.examples.poseestimation.ml.ModelType
import org.tensorflow.lite.examples.poseestimation.ml.MoveNet
import org.tensorflow.lite.examples.poseestimation.ml.MoveNetMultiPose
import org.tensorflow.lite.examples.poseestimation.ml.PoseClassifier
import org.tensorflow.lite.examples.poseestimation.ml.PoseNet
import org.tensorflow.lite.examples.poseestimation.ml.TrackerType
import org.tensorflow.lite.examples.poseestimation.ml.Type
class MainActivity : AppCompatActivity() {
companion object {
@ -267,99 +278,95 @@ class MainActivity : AppCompatActivity() {
}
}
// Change model when app is running
// Change the model while the app is running
private fun changeModel(position: Int) {
if (modelPos == position) return
modelPos = position
createPoseEstimator()
createPoseEstimator() // Recreate the pose estimator
}
// Change device (accelerator) type when app is running
// Change the device (accelerator) type while the app is running
private fun changeDevice(position: Int) {
val targetDevice = when (position) {
0 -> Device.CPU
1 -> Device.GPU
else -> Device.NNAPI
0 -> Device.CPU // Use the CPU
1 -> Device.GPU // Use the GPU
else -> Device.NNAPI // Use NNAPI
}
if (device == targetDevice) return
device = targetDevice
createPoseEstimator()
createPoseEstimator() // Recreate the pose estimator
}
// Change tracker for Movenet MultiPose model
// Change the tracker for the MoveNet MultiPose model
private fun changeTracker(position: Int) {
cameraSource?.setTracker(
when (position) {
1 -> TrackerType.BOUNDING_BOX
2 -> TrackerType.KEYPOINTS
else -> TrackerType.OFF
1 -> TrackerType.BOUNDING_BOX // Bounding-box tracker
2 -> TrackerType.KEYPOINTS // Keypoint tracker
else -> TrackerType.OFF // Tracking off
}
)
}
// Create the pose estimator, with different settings for each model type
private fun createPoseEstimator() {
// For MoveNet MultiPose, hide score and disable pose classifier as the model returns
// multiple Person instances.
// For MoveNet MultiPose, hide the score and disable the pose classifier because the model returns multiple Person instances.
val poseDetector = when (modelPos) {
0 -> {
// MoveNet Lightning (SinglePose)
// MoveNet Lightning (single pose)
showPoseClassifier(true)
showDetectionScore(true)
showTracker(false)
MoveNet.create(this, device, ModelType.Lightning)
}
1 -> {
// MoveNet Thunder (SinglePose)
// MoveNet Thunder (single pose)
showPoseClassifier(true)
showDetectionScore(true)
showTracker(false)
MoveNet.create(this, device, ModelType.Thunder)
}
2 -> {
// MoveNet (Lightning) MultiPose
// MoveNet (Lightning) MultiPose (multi-person)
showPoseClassifier(false)
showDetectionScore(false)
// Movenet MultiPose Dynamic does not support GPUDelegate
// MoveNet MultiPose Dynamic does not support GpuDelegate
if (device == Device.GPU) {
showToast(getString(R.string.tfe_pe_gpu_error))
showToast(getString(R.string.tfe_pe_gpu_error)) // Show an error message when the GPU device is selected
}
showTracker(true)
MoveNetMultiPose.create(
this,
device,
Type.Dynamic
)
MoveNetMultiPose.create(this, device, Type.Dynamic)
}
3 -> {
// PoseNet (SinglePose)
// PoseNet (single pose)
showPoseClassifier(true)
showDetectionScore(true)
showTracker(false)
PoseNet.create(this, device)
}
else -> {
null
null // Return null when no model matches
}
}
poseDetector?.let { detector ->
cameraSource?.setDetector(detector)
cameraSource?.setDetector(detector) // Set the detector on the camera source
}
}
// Show/hide the pose classification option.
// Show/hide the pose classification option
private fun showPoseClassifier(isVisible: Boolean) {
vClassificationOption.visibility = if (isVisible) View.VISIBLE else View.GONE
if (!isVisible) {
swClassification.isChecked = false
swClassification.isChecked = false // When hidden, uncheck the classification switch
}
}
// Show/hide the detection score.
// Show/hide the detection score
private fun showDetectionScore(isVisible: Boolean) {
tvScore.visibility = if (isVisible) View.VISIBLE else View.GONE
}
// Show/hide classification result.
// Show/hide the classification result
private fun showClassificationResult(isVisible: Boolean) {
val visibility = if (isVisible) View.VISIBLE else View.GONE
tvClassificationValue1.visibility = visibility
@ -367,31 +374,31 @@ class MainActivity : AppCompatActivity() {
tvClassificationValue3.visibility = visibility
}
// Show/hide the tracking options.
// Show/hide the tracking options
private fun showTracker(isVisible: Boolean) {
if (isVisible) {
// Show tracker options and enable Bounding Box tracker.
// Show the tracker options and enable the bounding-box tracker
vTrackerOption.visibility = View.VISIBLE
spnTracker.setSelection(1)
} else {
// Set tracker type to off and hide tracker option.
// Turn the tracker off and hide the tracking options
vTrackerOption.visibility = View.GONE
spnTracker.setSelection(0)
}
}
// Request the camera permission
private fun requestPermission() {
when (PackageManager.PERMISSION_GRANTED) {
ContextCompat.checkSelfPermission(
this,
Manifest.permission.CAMERA
) -> {
// You can use the API that requires the permission.
// The permission is already granted, so the permission-protected API can be used directly
openCamera()
}
else -> {
// You can directly ask for the permission.
// The registered ActivityResultCallback gets the result of this request.
// Otherwise, request the permission directly
requestPermissionLauncher.launch(
Manifest.permission.CAMERA
)
@ -399,6 +406,7 @@ class MainActivity : AppCompatActivity() {
}
}
// Show a toast message
private fun showToast(message: String) {
Toast.makeText(this, message, Toast.LENGTH_LONG).show()
}

@ -1,41 +1,51 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.data
// The BodyPart enum represents the parts of the human body; each value is one body part
// The constructor takes an integer position giving that part's index in the model output
enum class BodyPart(val position: Int) {
NOSE(0),
LEFT_EYE(1),
RIGHT_EYE(2),
LEFT_EAR(3),
RIGHT_EAR(4),
LEFT_SHOULDER(5),
RIGHT_SHOULDER(6),
LEFT_ELBOW(7),
RIGHT_ELBOW(8),
LEFT_WRIST(9),
RIGHT_WRIST(10),
LEFT_HIP(11),
RIGHT_HIP(12),
LEFT_KNEE(13),
RIGHT_KNEE(14),
LEFT_ANKLE(15),
RIGHT_ANKLE(16);
companion object{
NOSE(0), // nose, position 0
LEFT_EYE(1), // left eye, position 1
RIGHT_EYE(2), // right eye, position 2
LEFT_EAR(3), // left ear, position 3
RIGHT_EAR(4), // right ear, position 4
LEFT_SHOULDER(5), // left shoulder, position 5
RIGHT_SHOULDER(6), // right shoulder, position 6
LEFT_ELBOW(7), // left elbow, position 7
RIGHT_ELBOW(8), // right elbow, position 8
LEFT_WRIST(9), // left wrist, position 9
RIGHT_WRIST(10), // right wrist, position 10
LEFT_HIP(11), // left hip, position 11
RIGHT_HIP(12), // right hip, position 12
LEFT_KNEE(13), // left knee, position 13
RIGHT_KNEE(14), // right knee, position 14
LEFT_ANKLE(15), // left ankle, position 15
RIGHT_ANKLE(16); // right ankle, position 16
companion object {
// Map each position value to its BodyPart
private val map = values().associateBy(BodyPart::position)
// Returns the BodyPart for the given position
// Throws an exception if the position has no corresponding BodyPart
fun fromInt(position: Int): BodyPart = map.getValue(position)
}
}
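A short usage note on the position mapping above (a minimal sketch; values follow the enum as defined):

// Mapping an index in the model output back to a BodyPart, and vice versa
val part = BodyPart.fromInt(5)              // LEFT_SHOULDER
val index = BodyPart.LEFT_SHOULDER.position // 5
// An index with no matching BodyPart makes fromInt throw, since map.getValue fails on a missing key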

@ -17,5 +17,56 @@ limitations under the License.
package org.tensorflow.lite.examples.poseestimation.data
import android.graphics.PointF
import kotlin.math.acos
import kotlin.math.atan2
import kotlin.math.sqrt
data class KeyPoint(val bodyPart: BodyPart, var coordinate: PointF, val score: Float)
data class KeyPoint(
// bodyPart: the body part this keypoint corresponds to (a BodyPart enum value)
val bodyPart: BodyPart,
// coordinate: the keypoint's position as a PointF holding the x and y coordinates
var coordinate: PointF,
// score: the keypoint's confidence; higher values mean the model is more confident in the prediction
val score: Float
) {
// Function 1: angles of the line between two points relative to the horizontal and vertical axes
fun abtPoints(other: KeyPoint): Pair<Float, Float> {
val dx = other.coordinate.x - this.coordinate.x
val dy = other.coordinate.y - this.coordinate.y
// Angle relative to the horizontal axis (atan2 returns radians; convert to degrees)
val angleX = Math.toDegrees(atan2(dy.toDouble(), dx.toDouble())).toFloat()
// Angle relative to the vertical axis
val angleY = Math.toDegrees(atan2(dx.toDouble(), dy.toDouble())).toFloat()
return Pair(angleX, angleY)
}
// Function 2: angle ABC (the angle at vertex B)
fun abttPoints(A: KeyPoint, B: KeyPoint, C: KeyPoint): Float {
// Vectors BA and BC
val BAx = A.coordinate.x - B.coordinate.x
val BAy = A.coordinate.y - B.coordinate.y
val BCx = C.coordinate.x - B.coordinate.x
val BCy = C.coordinate.y - B.coordinate.y
// Dot product of BA and BC
val dotProduct = BAx * BCx + BAy * BCy
// Magnitudes of BA and BC
val magBA = sqrt(BAx * BAx + BAy * BAy)
val magBC = sqrt(BCx * BCx + BCy * BCy)
// Cosine of the angle, clamped to [-1, 1] to guard against floating-point error in acos
val cosAngle = (dotProduct / (magBA * magBC)).coerceIn(-1f, 1f)
// Angle in degrees (converted from radians)
val angle = Math.toDegrees(acos(cosAngle.toDouble())).toFloat()
return angle
}
}
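A minimal sketch of how these new KeyPoint helpers could be called on a detected Person; the keypoint selection and printout are illustrative, not part of the project.

import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Person

// Computes the tilt of the left upper arm and the inner angle at the left elbow
fun describeLeftArm(person: Person) {
    val shoulder = person.keyPoints.first { it.bodyPart == BodyPart.LEFT_SHOULDER }
    val elbow = person.keyPoints.first { it.bodyPart == BodyPart.LEFT_ELBOW }
    val wrist = person.keyPoints.first { it.bodyPart == BodyPart.LEFT_WRIST }

    // Angles of the shoulder-to-elbow line against the horizontal and vertical axes, in degrees
    val (angleToHorizontal, angleToVertical) = shoulder.abtPoints(elbow)

    // Inner angle at the elbow between the upper arm and the forearm, in degrees
    val elbowAngle = elbow.abttPoints(shoulder, elbow, wrist)

    println("upper arm tilt: $angleToHorizontal / $angleToVertical, elbow angle: $elbowAngle")
}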

@ -16,11 +16,24 @@ limitations under the License.
package org.tensorflow.lite.examples.poseestimation.data
// Import Android's RectF class, used to represent a rectangular region
import android.graphics.RectF
// The Person data class represents one detected person, holding its id, keypoints, bounding box, and score
// It also covers multi-person pose estimation scenarios
data class Person(
var id: Int = -1, // default id is -1
// id: unique identifier for this person, default -1; each person gets a unique id so different people can be told apart
var id: Int = -1,
// keyPoints: the list of this person's keypoints
// Each KeyPoint holds a body part together with its coordinates and confidence score
val keyPoints: List<KeyPoint>,
val boundingBox: RectF? = null, // Only MoveNet MultiPose return bounding box.
// boundingBox: the person's bounding box as a RectF
// Only the MoveNet MultiPose model returns bounding-box information
// The rectangle marks the approximate position and extent of the person
val boundingBox: RectF? = null, // Bounding box returned by MoveNet MultiPose; may be null
// score: the person's confidence; higher values mean the model is more confident in the detection
val score: Float
)

@ -15,10 +15,23 @@ limitations under the License.
*/
package org.tensorflow.lite.examples.poseestimation.data
// The TorsoAndBodyDistance data class holds the maximum distances of the torso keypoints and of the full-body keypoints from the crop center along each axis
// These values are used to size the crop region around the detected person
data class TorsoAndBodyDistance(
// maxTorsoYDistance: maximum distance of the torso keypoints from the center along the Y (vertical) axis
val maxTorsoYDistance: Float,
// maxTorsoXDistance: maximum distance of the torso keypoints from the center along the X (horizontal) axis
val maxTorsoXDistance: Float,
// maxBodyYDistance: maximum distance of all body keypoints from the center along the Y (vertical) axis
val maxBodyYDistance: Float,
// maxBodyXDistance: maximum distance of all body keypoints from the center along the X (horizontal) axis
val maxBodyXDistance: Float
)
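For orientation, a hedged sketch (not the exact upstream determineRectF code) of how these distances can be combined with the expansion ratios defined in MoveNet to get the half-length of the square crop; clipping against the image bounds is omitted.

import org.tensorflow.lite.examples.poseestimation.data.TorsoAndBodyDistance

// Crop half-length: torso distances scaled by the torso ratio, body distances by the body ratio
fun cropHalfLength(d: TorsoAndBodyDistance, torsoRatio: Float = 1.9f, bodyRatio: Float = 1.2f): Float =
    maxOf(
        d.maxTorsoXDistance * torsoRatio,
        d.maxTorsoYDistance * torsoRatio,
        d.maxBodyXDistance * bodyRatio,
        d.maxBodyYDistance * bodyRatio
    )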

@ -33,40 +33,41 @@ import kotlin.math.abs
import kotlin.math.max
import kotlin.math.min
// Model types: Lightning and Thunder are supported
enum class ModelType {
Lightning,
Thunder
}
// The MoveNet class implements the PoseDetector interface and performs pose estimation
class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) :
PoseDetector {
companion object {
private const val MIN_CROP_KEYPOINT_SCORE = .2f
private const val CPU_NUM_THREADS = 4
private const val MIN_CROP_KEYPOINT_SCORE = .2f // Minimum keypoint score
private const val CPU_NUM_THREADS = 4 // Number of CPU threads
// Parameters that control how large crop region should be expanded from previous frames'
// body keypoints.
private const val TORSO_EXPANSION_RATIO = 1.9f
private const val BODY_EXPANSION_RATIO = 1.2f
// Parameters that control how far the crop region is expanded from the previous frame's body keypoints
private const val TORSO_EXPANSION_RATIO = 1.9f // Torso expansion ratio
private const val BODY_EXPANSION_RATIO = 1.2f // Body expansion ratio
// TFLite file names.
// TFLite model file names
private const val LIGHTNING_FILENAME = "movenet_lightning.tflite"
private const val THUNDER_FILENAME = "movenet_thunder.tflite"
// allow specifying model type.
// Create a MoveNet instance for the given device and model type
fun create(context: Context, device: Device, modelType: ModelType): MoveNet {
val options = Interpreter.Options()
var gpuDelegate: GpuDelegate? = null
options.setNumThreads(CPU_NUM_THREADS)
options.setNumThreads(CPU_NUM_THREADS) // Set the number of CPU threads
when (device) {
Device.CPU -> {
}
Device.GPU -> {
gpuDelegate = GpuDelegate()
gpuDelegate = GpuDelegate() // Use GPU acceleration
options.addDelegate(gpuDelegate)
}
Device.NNAPI -> options.setUseNNAPI(true)
Device.NNAPI -> options.setUseNNAPI(true) // Use NNAPI acceleration
}
return MoveNet(
Interpreter(
@ -80,26 +81,27 @@ class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
)
}
// default to lightning.
// Defaults to the Lightning model
fun create(context: Context, device: Device): MoveNet =
create(context, device, ModelType.Lightning)
}
private var cropRegion: RectF? = null
private var lastInferenceTimeNanos: Long = -1
private val inputWidth = interpreter.getInputTensor(0).shape()[1]
private val inputHeight = interpreter.getInputTensor(0).shape()[2]
private var outputShape: IntArray = interpreter.getOutputTensor(0).shape()
private var cropRegion: RectF? = null // Current crop region
private var lastInferenceTimeNanos: Long = -1 // Duration of the last inference
private val inputWidth = interpreter.getInputTensor(0).shape()[1] // Input image width
private val inputHeight = interpreter.getInputTensor(0).shape()[2] // Input image height
private var outputShape: IntArray = interpreter.getOutputTensor(0).shape() // Output tensor shape
// Estimate poses in the given bitmap
override fun estimatePoses(bitmap: Bitmap): List<Person> {
val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos() // Record the start time
if (cropRegion == null) {
cropRegion = initRectF(bitmap.width, bitmap.height)
cropRegion = initRectF(bitmap.width, bitmap.height) // Initialize the crop region
}
var totalScore = 0f
var totalScore = 0f // Total score
val numKeyPoints = outputShape[2]
val keyPoints = mutableListOf<KeyPoint>()
val numKeyPoints = outputShape[2] // Number of keypoints
val keyPoints = mutableListOf<KeyPoint>() // Detected keypoints
cropRegion?.run {
val rect = RectF(
@ -107,35 +109,35 @@ class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
(top * bitmap.height),
(right * bitmap.width),
(bottom * bitmap.height)
)
) // Crop rectangle computed from the crop region
val detectBitmap = Bitmap.createBitmap(
rect.width().toInt(),
rect.height().toInt(),
Bitmap.Config.ARGB_8888
)
) // Create the cropped bitmap
Canvas(detectBitmap).drawBitmap(
bitmap,
-rect.left,
-rect.top,
null
)
val inputTensor = processInputImage(detectBitmap, inputWidth, inputHeight)
val outputTensor = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32)
val inputTensor = processInputImage(detectBitmap, inputWidth, inputHeight) // Preprocess the input image
val outputTensor = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32) // Create the output tensor
val widthRatio = detectBitmap.width.toFloat() / inputWidth
val heightRatio = detectBitmap.height.toFloat() / inputHeight
val positions = mutableListOf<Float>()
inputTensor?.let { input ->
interpreter.run(input.buffer, outputTensor.buffer.rewind())
interpreter.run(input.buffer, outputTensor.buffer.rewind()) // Run inference
val output = outputTensor.floatArray
for (idx in 0 until numKeyPoints) {
val x = output[idx * 3 + 1] * inputWidth * widthRatio
val y = output[idx * 3 + 0] * inputHeight * heightRatio
val x = output[idx * 3 + 1] * inputWidth * widthRatio // x coordinate
val y = output[idx * 3 + 0] * inputHeight * heightRatio // y coordinate
positions.add(x)
positions.add(y)
val score = output[idx * 3 + 2]
val score = output[idx * 3 + 2] // Keypoint score
keyPoints.add(
KeyPoint(
BodyPart.fromInt(idx),
@ -152,8 +154,8 @@ class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
val matrix = Matrix()
val points = positions.toFloatArray()
matrix.postTranslate(rect.left, rect.top)
matrix.mapPoints(points)
matrix.postTranslate(rect.left, rect.top) // Apply the crop offset
matrix.mapPoints(points) // Map the points back to the original image coordinates
keyPoints.forEachIndexed { index, keyPoint ->
keyPoint.coordinate =
PointF(
@ -161,16 +163,18 @@ class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
points[index * 2 + 1]
)
}
// new crop region
// Update the crop region
cropRegion = determineRectF(keyPoints, bitmap.width, bitmap.height)
}
lastInferenceTimeNanos =
SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
return listOf(Person(keyPoints = keyPoints, score = totalScore / numKeyPoints))
SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos // Compute the inference time
return listOf(Person(keyPoints = keyPoints, score = totalScore / numKeyPoints)) // Return the pose result
}
// Return the duration of the last inference
override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos
// Release resources
override fun close() {
gpuDelegate?.close()
interpreter.close()
@ -178,27 +182,25 @@ class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
}
/**
* Prepare input image for detection
* Prepares the input image for detection
*/
private fun processInputImage(bitmap: Bitmap, inputWidth: Int, inputHeight: Int): TensorImage? {
val width: Int = bitmap.width
val height: Int = bitmap.height
val size = if (height > width) width else height
val size = if (height > width) width else height // Use the shorter side
val imageProcessor = ImageProcessor.Builder().apply {
add(ResizeWithCropOrPadOp(size, size))
add(ResizeOp(inputWidth, inputHeight, ResizeOp.ResizeMethod.BILINEAR))
add(ResizeWithCropOrPadOp(size, size)) // Crop or pad to a square
add(ResizeOp(inputWidth, inputHeight, ResizeOp.ResizeMethod.BILINEAR)) // Resize to the model input size
}.build()
val tensorImage = TensorImage(DataType.UINT8)
tensorImage.load(bitmap)
return imageProcessor.process(tensorImage)
tensorImage.load(bitmap) // Load the bitmap
return imageProcessor.process(tensorImage) // Return the processed image
}
/**
* Defines the default crop region.
* The function provides the initial crop region (pads the full image from both
* sides to make it a square image) when the algorithm cannot reliably determine
* the crop region from the previous frame.
* Initializes the default crop region
* Provides the initial crop region (pads the full image on both sides to make it square) when the algorithm cannot reliably determine the crop region from the previous frame
*/
private fun initRectF(imageWidth: Int, imageHeight: Int): RectF {
val xMin: Float
@ -225,9 +227,8 @@ class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
}
/**
* Checks whether there are enough torso keypoints.
* This function checks whether the model is confident at predicting one of the
* shoulders/hips which is required to determine a good crop region.
* Checks whether there are enough torso keypoints
* Verifies that the model is confident in predicting at least one of the shoulders/hips, which is required to determine a good crop region
*/
private fun torsoVisible(keyPoints: List<KeyPoint>): Boolean {
return ((keyPoints[BodyPart.LEFT_HIP.position].score > MIN_CROP_KEYPOINT_SCORE).or(
@ -240,13 +241,9 @@ class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
}
/**
* Determines the region to crop the image for the model to run inference on.
* The algorithm uses the detected joints from the previous frame to estimate
* the square region that encloses the full body of the target person and
* centers at the midpoint of two hip joints. The crop size is determined by
* the distances between each joints and the center point.
* When the model is not confident with the four torso joint predictions, the
* function returns a default crop which is the full image padded to square.
* Determines the crop region from the previous frame's keypoints
* The algorithm uses the joints detected in the previous frame to estimate a square region enclosing the target person's full body, centered at the midpoint of the two hip joints; the crop size is determined by the distance of each joint from that center
* When the model is not confident in the four torso joint predictions, the function returns the default crop: the full image padded to a square
*/
private fun determineRectF(
keyPoints: List<KeyPoint>,
@ -306,10 +303,9 @@ class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
}
/**
* Calculates the maximum distance from each keypoints to the center location.
* The function returns the maximum distances from the two sets of keypoints:
* full 17 keypoints and 4 torso keypoints. The returned information will be
* used to determine the crop size. See determineRectF for more detail.
* Calculates the maximum distance from the keypoints to the center location
* Returns the maximum distances for two sets of keypoints: all 17 keypoints and the 4 torso keypoints; the result is used to determine the crop size (see determineRectF for details)
*/
private fun determineTorsoAndBodyDistances(
keyPoints: List<KeyPoint>,

@ -21,52 +21,60 @@ import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.support.common.FileUtil
// The PoseClassifier class performs pose classification
class PoseClassifier(
private val interpreter: Interpreter,
private val labels: List<String>
private val interpreter: Interpreter, // TensorFlow Lite interpreter instance
private val labels: List<String> // Labels used to turn model outputs into class names
) {
private val input = interpreter.getInputTensor(0).shape()
private val output = interpreter.getOutputTensor(0).shape()
private val input = interpreter.getInputTensor(0).shape() // Input tensor shape
private val output = interpreter.getOutputTensor(0).shape() // Output tensor shape
companion object {
private const val MODEL_FILENAME = "classifier.tflite"
private const val LABELS_FILENAME = "labels.txt"
private const val CPU_NUM_THREADS = 4
private const val MODEL_FILENAME = "classifier.tflite" // 模型文件名
private const val LABELS_FILENAME = "labels.txt" // 标签文件名
private const val CPU_NUM_THREADS = 4 // 设置使用的线程数
// 创建PoseClassifier的工厂方法
fun create(context: Context): PoseClassifier {
// Configure the interpreter options, including the thread count
val options = Interpreter.Options().apply {
setNumThreads(CPU_NUM_THREADS)
}
// Return a new PoseClassifier instance
return PoseClassifier(
Interpreter(
FileUtil.loadMappedFile(
context, MODEL_FILENAME
context, MODEL_FILENAME // Load the model file
), options
),
FileUtil.loadLabels(context, LABELS_FILENAME)
FileUtil.loadLabels(context, LABELS_FILENAME) // Load the labels file
)
}
}
// Classify a pose
fun classify(person: Person?): List<Pair<String, Float>> {
// Preprocess the pose estimation result to a flat array
val inputVector = FloatArray(input[1])
// Preprocess the pose estimation result into a flat array
val inputVector = FloatArray(input[1]) // Input vector
person?.keyPoints?.forEachIndexed { index, keyPoint ->
// Fill the input vector with each keypoint's coordinates (y, x) and score
inputVector[index * 3] = keyPoint.coordinate.y
inputVector[index * 3 + 1] = keyPoint.coordinate.x
inputVector[index * 3 + 2] = keyPoint.score
}
// Postprocess the model output to human readable class names
val outputTensor = FloatArray(output[1])
interpreter.run(arrayOf(inputVector), arrayOf(outputTensor))
val output = mutableListOf<Pair<String, Float>>()
// Postprocess the model output into human-readable class names
val outputTensor = FloatArray(output[1]) // Output buffer
interpreter.run(arrayOf(inputVector), arrayOf(outputTensor)) // Run inference
val output = mutableListOf<Pair<String, Float>>() // Results list
outputTensor.forEachIndexed { index, score ->
// Pair each score with its label
output.add(Pair(labels[index], score))
}
return output
return output // Return the results
}
// Close the interpreter
fun close() {
interpreter.close()
}

@ -19,9 +19,12 @@ package org.tensorflow.lite.examples.poseestimation.ml
import android.graphics.Bitmap
import org.tensorflow.lite.examples.poseestimation.data.Person
// The PoseDetector interface extends AutoCloseable, so implementations must release their resources on close
interface PoseDetector : AutoCloseable {
// Estimates the poses in the input Bitmap and returns a list of Person objects, each representing one pose estimation result
fun estimatePoses(bitmap: Bitmap): List<Person>
// Returns the duration of the last inference in nanoseconds
fun lastInferenceTimeNanos(): Long
}
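A minimal sketch of driving a PoseDetector outside the camera pipeline, mirroring what the instrumentation tests do; the function name and printout are illustrative.

import android.content.Context
import android.graphics.Bitmap
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.ml.ModelType
import org.tensorflow.lite.examples.poseestimation.ml.MoveNet
import org.tensorflow.lite.examples.poseestimation.ml.PoseDetector

fun detectOnce(context: Context, bitmap: Bitmap) {
    val detector: PoseDetector = MoveNet.create(context, Device.CPU, ModelType.Lightning)
    try {
        val people = detector.estimatePoses(bitmap)
        val timeMs = detector.lastInferenceTimeNanos() / 1_000_000.0
        println("Detected ${people.size} person(s) in $timeMs ms")
    } finally {
        detector.close() // PoseDetector is AutoCloseable: release the interpreter and any delegate
    }
}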

@ -34,18 +34,25 @@ import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
import kotlin.math.acos
import kotlin.math.atan2
import kotlin.math.exp
import kotlin.math.sqrt
class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) :
PoseDetector {
// The PoseNet class implements the PoseDetector interface and performs pose detection
class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) : PoseDetector {
companion object {
// Number of CPU threads
private const val CPU_NUM_THREADS = 4
// Mean and standard deviation used for image normalization
private const val MEAN = 127.5f
private const val STD = 127.5f
private const val TAG = "Posenet"
// Model file name
private const val MODEL_FILENAME = "posenet.tflite"
// Creates a PoseNet instance, choosing CPU, GPU, or NNAPI depending on the device type
fun create(context: Context, device: Device): PoseNet {
val options = Interpreter.Options()
var gpuDelegate: GpuDelegate? = null
@ -54,33 +61,35 @@ class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
Device.CPU -> {
}
Device.GPU -> {
gpuDelegate = GpuDelegate()
gpuDelegate = GpuDelegate() // Use GPU acceleration
options.addDelegate(gpuDelegate)
}
Device.NNAPI -> options.setUseNNAPI(true)
Device.NNAPI -> options.setUseNNAPI(true) // Use NNAPI acceleration
}
return PoseNet(
Interpreter(
FileUtil.loadMappedFile(
context,
MODEL_FILENAME
), options
FileUtil.loadMappedFile(context, MODEL_FILENAME), options
),
gpuDelegate
)
}
}
// Duration of the last inference
private var lastInferenceTimeNanos: Long = -1
// Width and height of the input tensor
private val inputWidth = interpreter.getInputTensor(0).shape()[1]
private val inputHeight = interpreter.getInputTensor(0).shape()[2]
private var cropHeight = 0f
private var cropWidth = 0f
private var cropSize = 0
// Estimates the poses in the given image
@Suppress("UNCHECKED_CAST")
override fun estimatePoses(bitmap: Bitmap): List<Person> {
// Record the estimation start time
val estimationStartTimeNanos = SystemClock.elapsedRealtimeNanos()
// Preprocess the input image and convert it to a TensorBuffer
val inputArray = arrayOf(processInputImage(bitmap).tensorBuffer.buffer)
Log.i(
TAG,
@ -90,20 +99,27 @@ class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
)
)
// Initialize the output map
val outputMap = initOutputMap(interpreter)
// Record the inference start time
val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
// Run inference
interpreter.runForMultipleInputsOutputs(inputArray, outputMap)
// Record the inference time
lastInferenceTimeNanos = SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
Log.i(
TAG,
String.format("Interpreter took %.2f ms", 1.0f * lastInferenceTimeNanos / 1_000_000)
)
// Read the heatmap and offset outputs
val heatmaps = outputMap[0] as Array<Array<Array<FloatArray>>>
val offsets = outputMap[1] as Array<Array<Array<FloatArray>>>
// Record the post-processing start time
val postProcessingStartTimeNanos = SystemClock.elapsedRealtimeNanos()
// Post-process to extract the keypoints
val person = postProcessModelOuputs(heatmaps, offsets)
Log.i(
TAG,
@ -113,12 +129,42 @@ class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
)
)
// Return the result wrapped in a Person object
return listOf(person)
}
/**
* Convert heatmaps and offsets output of Posenet into a list of keypoints
* Converts the PoseNet heatmap and offset outputs into a list of keypoints
*/
// Function 1: slope angle of the line between two points
fun calculateAngle(x1: Float, y1: Float, x2: Float, y2: Float): Float {
val dy = y2 - y1
val dx = x2 - x1
return Math.toDegrees(atan2(dy.toDouble(), dx.toDouble())).toFloat()
}
// Function 2: angle ABC (the angle at vertex B)
fun abttPoints(A: KeyPoint, B: KeyPoint, C: KeyPoint): Float {
// Vectors BA and BC
val BAx = A.coordinate.x - B.coordinate.x
val BAy = A.coordinate.y - B.coordinate.y
val BCx = C.coordinate.x - B.coordinate.x
val BCy = C.coordinate.y - B.coordinate.y
// Dot product of BA and BC
val dotProduct = BAx * BCx + BAy * BCy
// Magnitudes of BA and BC
val magBA = sqrt(BAx * BAx + BAy * BAy)
val magBC = sqrt(BCx * BCx + BCy * BCy)
// Cosine of the angle, clamped to [-1, 1] to guard against floating-point error in acos
val cosAngle = (dotProduct / (magBA * magBC)).coerceIn(-1f, 1f)
// Angle in degrees (converted from radians)
val angle = Math.toDegrees(acos(cosAngle.toDouble())).toFloat()
return angle
}
private fun postProcessModelOuputs(
heatmaps: Array<Array<Array<FloatArray>>>,
offsets: Array<Array<Array<FloatArray>>>
@ -127,7 +173,7 @@ class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
val width = heatmaps[0][0].size
val numKeypoints = heatmaps[0][0][0].size
// Finds the (row, col) locations of where the keypoints are most likely to be.
// Find the (row, col) location where each keypoint is most likely to be
val keypointPositions = Array(numKeypoints) { Pair(0, 0) }
for (keypoint in 0 until numKeypoints) {
var maxVal = heatmaps[0][0][0][keypoint]
@ -145,7 +191,7 @@ class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
keypointPositions[keypoint] = Pair(maxRow, maxCol)
}
// Calculating the x and y coordinates of the keypoints with offset adjustment.
// Compute the x and y coordinates of the keypoints with offset adjustment
val xCoords = IntArray(numKeypoints)
val yCoords = IntArray(numKeypoints)
val confidenceScores = FloatArray(numKeypoints)
@ -168,6 +214,7 @@ class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
confidenceScores[idx] = sigmoid(heatmaps[0][positionY][positionX][idx])
}
// Build the keypoint list and accumulate the total score
val keypointList = mutableListOf<KeyPoint>()
var totalScore = 0.0f
enumValues<BodyPart>().forEachIndexed { idx, it ->
@ -180,21 +227,52 @@ class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
)
totalScore += confidenceScores[idx]
}
// Angle data and pass-rate data are added in this section
// Iterate over the keypoints
keypointList.forEach { keypoint ->
// Position and confidence of the current keypoint
val x = keypoint.coordinate.x
val y = keypoint.coordinate.y
// Add the form-quality metric here
// As an example, compute the angle of the line between the left and right shoulders (illustrative only; it depends on your model)
if (keypoint.bodyPart == BodyPart.LEFT_SHOULDER || keypoint.bodyPart == BodyPart.RIGHT_SHOULDER) {
// Left- and right-shoulder keypoints
val leftShoulder = keypointList.first { it.bodyPart == BodyPart.LEFT_SHOULDER }
val rightShoulder = keypointList.first { it.bodyPart == BodyPart.RIGHT_SHOULDER }
// Compute the shoulder-line angle
val angle = calculateAngle(
leftShoulder.coordinate.x,
leftShoulder.coordinate.y,
rightShoulder.coordinate.x,
rightShoulder.coordinate.y
)
}
// Add pass-rate data here
}
// Return a Person object containing the keypoints and the overall score
return Person(keyPoints = keypointList.toList(), score = totalScore / numKeypoints)
}
// Return the duration of the last inference
override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos
// Close the PoseNet interpreter and the GPU delegate
override fun close() {
gpuDelegate?.close()
interpreter.close()
}
/**
* Scale and crop the input image to a TensorImage.
* Scales and crops the input bitmap into a TensorImage
*/
private fun processInputImage(bitmap: Bitmap): TensorImage {
// reset crop width and height
// Reset the crop width and height
cropWidth = 0f
cropHeight = 0f
cropSize = if (bitmap.width > bitmap.height) {
@ -205,15 +283,21 @@ class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
bitmap.height
}
// Set up the image processor: crop/pad to a square, resize, and normalize
val imageProcessor = ImageProcessor.Builder().apply {
add(ResizeWithCropOrPadOp(cropSize, cropSize))
add(ResizeOp(inputWidth, inputHeight, ResizeOp.ResizeMethod.BILINEAR))
add(NormalizeOp(MEAN, STD))
}.build()
// Create a TensorImage and load the bitmap
val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(bitmap)
// Process the image and return it
return imageProcessor.process(tensorImage)
}
}
/**
* Initializes an outputMap of 1 * x * y * z FloatArrays for the model processing to populate.
@ -258,4 +342,3 @@ class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: Gpu
private fun sigmoid(x: Float): Float {
return (1.0f / (1.0f + exp(-x)))
}
}
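Following the "pass rate" idea in the shoulder loop above, a hedged sketch of turning the shoulder-line tilt into a pass/fail check; the 10-degree threshold and this helper are illustrative assumptions, not part of the upstream project.

import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Person
import kotlin.math.abs
import kotlin.math.atan2

// Returns true when the shoulders are roughly level (tilt within the threshold)
fun isShoulderLevel(person: Person, maxTiltDegrees: Float = 10f): Boolean {
    val left = person.keyPoints.first { it.bodyPart == BodyPart.LEFT_SHOULDER }
    val right = person.keyPoints.first { it.bodyPart == BodyPart.RIGHT_SHOULDER }
    val tilt = Math.toDegrees(
        atan2(
            (right.coordinate.y - left.coordinate.y).toDouble(),
            (right.coordinate.x - left.coordinate.x).toDouble()
        )
    ).toFloat()
    return abs(tilt) <= maxTiltDegrees
}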

@ -4,5 +4,5 @@
# Location of the SDK. This is only used by Gradle.
# For customization when using a Version Control System, please read the
# header note.
#Fri Apr 25 19:53:17 CST 2025
sdk.dir=/Users/ziyue/Library/Android/sdk
#Sat May 24 23:39:55 CST 2025
sdk.dir=C\:\\Users\\26891\\AppData\\Local\\Android\\Sdk
