diff --git a/.DS_Store b/.DS_Store index ff3856e..d09b377 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/android/.DS_Store b/android/.DS_Store new file mode 100644 index 0000000..c832e3b Binary files /dev/null and b/android/.DS_Store differ diff --git a/android/.gradle/8.5/checksums/checksums.lock b/android/.gradle/8.5/checksums/checksums.lock index e776488..456b3a5 100644 Binary files a/android/.gradle/8.5/checksums/checksums.lock and b/android/.gradle/8.5/checksums/checksums.lock differ diff --git a/android/.gradle/8.5/dependencies-accessors/dependencies-accessors.lock b/android/.gradle/8.5/dependencies-accessors/dependencies-accessors.lock index 1737bfd..9e81dd8 100644 Binary files a/android/.gradle/8.5/dependencies-accessors/dependencies-accessors.lock and b/android/.gradle/8.5/dependencies-accessors/dependencies-accessors.lock differ diff --git a/android/.gradle/8.5/executionHistory/executionHistory.bin b/android/.gradle/8.5/executionHistory/executionHistory.bin index bc72ced..508368f 100644 Binary files a/android/.gradle/8.5/executionHistory/executionHistory.bin and b/android/.gradle/8.5/executionHistory/executionHistory.bin differ diff --git a/android/.gradle/8.5/executionHistory/executionHistory.lock b/android/.gradle/8.5/executionHistory/executionHistory.lock index 5c8f309..0f6277c 100644 Binary files a/android/.gradle/8.5/executionHistory/executionHistory.lock and b/android/.gradle/8.5/executionHistory/executionHistory.lock differ diff --git a/android/.gradle/8.5/fileHashes/fileHashes.bin b/android/.gradle/8.5/fileHashes/fileHashes.bin index 62b32e1..b5390e3 100644 Binary files a/android/.gradle/8.5/fileHashes/fileHashes.bin and b/android/.gradle/8.5/fileHashes/fileHashes.bin differ diff --git a/android/.gradle/8.5/fileHashes/fileHashes.lock b/android/.gradle/8.5/fileHashes/fileHashes.lock index 50a19aa..35955aa 100644 Binary files a/android/.gradle/8.5/fileHashes/fileHashes.lock and b/android/.gradle/8.5/fileHashes/fileHashes.lock differ diff --git a/android/.gradle/8.5/fileHashes/resourceHashesCache.bin b/android/.gradle/8.5/fileHashes/resourceHashesCache.bin index 3bf848f..520d066 100644 Binary files a/android/.gradle/8.5/fileHashes/resourceHashesCache.bin and b/android/.gradle/8.5/fileHashes/resourceHashesCache.bin differ diff --git a/android/.gradle/buildOutputCleanup/buildOutputCleanup.lock b/android/.gradle/buildOutputCleanup/buildOutputCleanup.lock index d301889..aad4f8a 100644 Binary files a/android/.gradle/buildOutputCleanup/buildOutputCleanup.lock and b/android/.gradle/buildOutputCleanup/buildOutputCleanup.lock differ diff --git a/android/.gradle/buildOutputCleanup/cache.properties b/android/.gradle/buildOutputCleanup/cache.properties index f3476bd..83f09d4 100644 --- a/android/.gradle/buildOutputCleanup/cache.properties +++ b/android/.gradle/buildOutputCleanup/cache.properties @@ -1,2 +1,2 @@ -#Mon Apr 21 08:14:16 CST 2025 +#Fri Apr 25 19:53:19 CST 2025 gradle.version=8.5 diff --git a/android/.gradle/buildOutputCleanup/outputFiles.bin b/android/.gradle/buildOutputCleanup/outputFiles.bin index dabf4cc..7aa3d7a 100644 Binary files a/android/.gradle/buildOutputCleanup/outputFiles.bin and b/android/.gradle/buildOutputCleanup/outputFiles.bin differ diff --git a/android/.gradle/config.properties b/android/.gradle/config.properties index 9f9d1d6..6861a6c 100644 --- a/android/.gradle/config.properties +++ b/android/.gradle/config.properties @@ -1,2 +1,2 @@ -#Mon Apr 21 08:12:48 CST 2025 +#Fri Apr 25 19:53:31 CST 2025 java.home=/Applications/Android 
Studio.app/Contents/jbr/Contents/Home diff --git a/android/.gradle/file-system.probe b/android/.gradle/file-system.probe index d7ed2a9..3821b15 100644 Binary files a/android/.gradle/file-system.probe and b/android/.gradle/file-system.probe differ diff --git a/android/.idea/android.iml b/android/.idea/android.iml new file mode 100644 index 0000000..12bbf74 --- /dev/null +++ b/android/.idea/android.iml @@ -0,0 +1 @@ + diff --git a/android/.idea/assetWizardSettings.xml b/android/.idea/assetWizardSettings.xml new file mode 100644 index 0000000..2a9c5e0 --- /dev/null +++ b/android/.idea/assetWizardSettings.xml @@ -0,0 +1,14 @@ + + + + + + \ No newline at end of file diff --git a/android/.idea/caches/deviceStreaming.xml b/android/.idea/caches/deviceStreaming.xml deleted file mode 100644 index 9e9ba09..0000000 --- a/android/.idea/caches/deviceStreaming.xml +++ /dev/null @@ -1,607 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/android/.idea/misc.xml b/android/.idea/misc.xml index 74dd639..d15a481 100644 --- a/android/.idea/misc.xml +++ b/android/.idea/misc.xml @@ -2,9 +2,6 @@ - - - - \ No newline at end of file diff --git a/android/.idea/render.experimental.xml b/android/.idea/render.experimental.xml new file mode 100644 index 0000000..8ec256a --- /dev/null +++ b/android/.idea/render.experimental.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/android/app/.DS_Store b/android/app/.DS_Store new file mode 100644 index 0000000..e302732 Binary files /dev/null and b/android/app/.DS_Store differ diff --git a/android/app/src/.DS_Store b/android/app/src/.DS_Store new file mode 100644 index 0000000..beeb7cb Binary files /dev/null and b/android/app/src/.DS_Store differ diff --git a/android/app/src/main/.DS_Store b/android/app/src/main/.DS_Store new file mode 100644 index 0000000..8ab4390 Binary files /dev/null and b/android/app/src/main/.DS_Store differ diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml index 910942c..f802450 100644 --- a/android/app/src/main/AndroidManifest.xml +++ b/android/app/src/main/AndroidManifest.xml @@ -13,12 +13,23 @@ android:roundIcon="@drawable/ic_launcher" android:supportsRtl="true" android:theme="@style/Theme.PoseEstimation"> - + + + + + + + + \ No newline at end of file diff --git a/android/app/src/main/java/.DS_Store b/android/app/src/main/java/.DS_Store new file mode 100644 index 0000000..7c45a62 Binary files /dev/null and b/android/app/src/main/java/.DS_Store differ diff --git a/android/app/src/main/java/org/.DS_Store b/android/app/src/main/java/org/.DS_Store new file mode 100644 index 0000000..c2ccece Binary files /dev/null and b/android/app/src/main/java/org/.DS_Store differ diff --git a/android/app/src/main/java/org/tensorflow/.DS_Store b/android/app/src/main/java/org/tensorflow/.DS_Store new file mode 100644 index 0000000..dadc68c Binary files /dev/null and b/android/app/src/main/java/org/tensorflow/.DS_Store differ diff --git a/android/app/src/main/java/org/tensorflow/lite/.DS_Store b/android/app/src/main/java/org/tensorflow/lite/.DS_Store new file mode 100644 index 0000000..bd8e1d2 Binary files /dev/null and b/android/app/src/main/java/org/tensorflow/lite/.DS_Store differ diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/.DS_Store b/android/app/src/main/java/org/tensorflow/lite/examples/.DS_Store new file mode 100644 index 0000000..ed82e15 Binary files /dev/null and b/android/app/src/main/java/org/tensorflow/lite/examples/.DS_Store differ diff --git 
a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/AgeSelectionActivity.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/AgeSelectionActivity.kt
new file mode 100644
index 0000000..6dda173
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/AgeSelectionActivity.kt
@@ -0,0 +1,120 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.content.Intent
+import android.os.Bundle
+import android.view.MotionEvent
+import android.widget.ImageButton
+import android.widget.TextView
+import androidx.appcompat.app.AppCompatActivity
+import com.google.android.material.button.MaterialButton
+import kotlin.math.abs
+
+class AgeSelectionActivity : AppCompatActivity() {
+    private lateinit var selectedAgeText: TextView
+    private lateinit var age1Above: TextView
+    private lateinit var age2Above: TextView
+    private lateinit var age1Below: TextView
+    private lateinit var age2Below: TextView
+    private lateinit var nextButton: MaterialButton
+    private lateinit var backButton: ImageButton
+    private var selectedGender: String? = null
+
+    private var currentAge = 25
+    private val minAge = 12
+    private val maxAge = 90
+
+    private var lastY: Float = 0f
+    private val scrollSensitivity = 15f // Adjust this value to change the swipe sensitivity
+
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+        setContentView(R.layout.activity_age_selection)
+
+        selectedGender = intent.getStringExtra("selected_gender")
+
+        selectedAgeText = findViewById(R.id.selectedAgeText)
+        age1Above = findViewById(R.id.age1Above)
+        age2Above = findViewById(R.id.age2Above)
+        age1Below = findViewById(R.id.age1Below)
+        age2Below = findViewById(R.id.age2Below)
+        nextButton = findViewById(R.id.nextButton)
+        backButton = findViewById(R.id.backButton)
+
+        setupUI()
+        setupClickListeners()
+    }
+
+    private fun setupUI() {
+        updateAgeDisplay()
+        nextButton.isEnabled = true
+    }
+
+    private fun updateAgeDisplay() {
+        selectedAgeText.text = currentAge.toString()
+
+        // Update the ages shown above (smaller numbers)
+        if (currentAge - 1 >= minAge) {
+            age1Above.text = (currentAge - 1).toString()
+        } else {
+            age1Above.text = ""
+        }
+        if (currentAge - 2 >= minAge) {
+            age2Above.text = (currentAge - 2).toString()
+        } else {
+            age2Above.text = ""
+        }
+
+        // Update the ages shown below (larger numbers)
+        if (currentAge + 1 <= maxAge) {
+            age1Below.text = (currentAge + 1).toString()
+        } else {
+            age1Below.text = ""
+        }
+        if (currentAge + 2 <= maxAge) {
+            age2Below.text = (currentAge + 2).toString()
+        } else {
+            age2Below.text = ""
+        }
+    }
+
+    private fun setupClickListeners() {
+        nextButton.setOnClickListener {
+            val intent = Intent(this, WeightSelectionActivity::class.java)
+            intent.putExtra("selected_gender", selectedGender)
+            intent.putExtra("selected_age", currentAge)
+            startActivity(intent)
+            finish()
+        }
+
+        backButton.setOnClickListener {
+            finish()
+        }
+    }
+
+    override fun onTouchEvent(event: MotionEvent): Boolean {
+        when (event.action) {
+            MotionEvent.ACTION_DOWN -> {
+                lastY = event.y
+                return true
+            }
+            MotionEvent.ACTION_MOVE -> {
+                val currentY = event.y
+                val deltaY = currentY - lastY
+
+                // Work out how many steps the age should change
+                val change = -(deltaY / scrollSensitivity).toInt()
+                if (abs(change) > 0) {
+                    // Update the age
+                    val newAge = (currentAge + change).coerceIn(minAge, maxAge)
+                    if (newAge != currentAge) {
+                        currentAge = newAge
+                        updateAgeDisplay()
+                        lastY = currentY
+                    }
+                }
+                return true
+            }
+        }
+        return super.onTouchEvent(event)
+    }
+}
\ No newline at end of file
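Note: the vertical-swipe-to-adjust logic above reappears almost verbatim in WeightSelectionActivity and HeightSelectionActivity later in this diff. A minimal sketch of how the shared gesture math could be factored out; the class and callback names are illustrative and not part of this change:

```kotlin
import android.view.MotionEvent
import kotlin.math.abs

// Hypothetical helper, not in this diff: turns vertical drag distance into
// discrete value steps clamped to [min, max], mirroring the activities' logic.
class SwipeValuePicker(
    private var value: Int,
    private val min: Int,
    private val max: Int,
    private val sensitivity: Float = 15f,
    private val onValueChanged: (Int) -> Unit
) {
    private var lastY = 0f

    // Delegate an Activity's onTouchEvent here; returns true when consumed.
    fun onTouchEvent(event: MotionEvent): Boolean {
        when (event.action) {
            MotionEvent.ACTION_DOWN -> lastY = event.y
            MotionEvent.ACTION_MOVE -> {
                val change = -((event.y - lastY) / sensitivity).toInt()
                if (abs(change) > 0) {
                    val newValue = (value + change).coerceIn(min, max)
                    if (newValue != value) {
                        value = newValue
                        onValueChanged(newValue)
                        lastY = event.y
                    }
                }
            }
            else -> return false
        }
        return true
    }
}
```

Each of the three selection activities could then hold one instance and forward touch events to it instead of duplicating the delta arithmetic.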
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/CustomWeightPicker.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/CustomWeightPicker.kt
new file mode 100644
index 0000000..67fa700
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/CustomWeightPicker.kt
@@ -0,0 +1,105 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.content.Context
+import android.graphics.Canvas
+import android.graphics.Paint
+import android.graphics.Path
+import android.util.AttributeSet
+import android.view.MotionEvent
+import android.view.View
+import androidx.core.content.ContextCompat
+import kotlin.math.abs
+import kotlin.math.max
+import kotlin.math.min
+
+class CustomWeightPicker @JvmOverloads constructor(
+    context: Context,
+    attrs: AttributeSet? = null,
+    defStyleAttr: Int = 0
+) : View(context, attrs, defStyleAttr) {
+
+    private val paint = Paint().apply {
+        isAntiAlias = true
+        color = ContextCompat.getColor(context, R.color.purple_500)
+        textSize = 36f // make tick labels legible; Paint's default is very small
+    }
+
+    private val path = Path()
+    private var lastX = 0f
+    private var scrollOffset = 0f
+    private var selectedWeight = 54
+    private val minWeight = 30
+    private val maxWeight = 200
+    private val scaleWidth = 40f // width of one scale tick
+    private val scaleHeight = 20f // height of the tick marks
+    private val centerLineWidth = 2f
+
+    private var onWeightChangeListener: ((Int) -> Unit)? = null
+
+    override fun onDraw(canvas: Canvas) {
+        super.onDraw(canvas)
+
+        val centerX = width / 2f
+        val centerY = height / 2f
+
+        // Fixed indicator line marking the currently selected weight
+        paint.strokeWidth = centerLineWidth
+        canvas.drawLine(centerX, centerY - scaleHeight * 1.5f, centerX, centerY + scaleHeight * 1.5f, paint)
+
+        // Draw the ticks around the selected weight; scrollOffset holds the
+        // residual drag distance that has not yet amounted to a full step.
+        for (i in -10..10) {
+            val x = centerX + scrollOffset + i * scaleWidth
+            val weight = selectedWeight + i
+
+            if (weight in minWeight..maxWeight) {
+                // Tick line
+                canvas.drawLine(x, centerY - scaleHeight, x, centerY + scaleHeight, paint)
+
+                // Tick label on every other tick
+                if (i % 2 == 0) {
+                    canvas.drawText(weight.toString(), x - 10f, centerY + scaleHeight + 20f, paint)
+                }
+            }
+        }
+    }
+
+    override fun onTouchEvent(event: MotionEvent): Boolean {
+        when (event.action) {
+            MotionEvent.ACTION_DOWN -> {
+                lastX = event.x
+                return true
+            }
+            MotionEvent.ACTION_MOVE -> {
+                scrollOffset += event.x - lastX
+
+                // Derive the change from the accumulated offset so that small
+                // per-event deltas are not truncated to zero by toInt().
+                val steps = (scrollOffset / scaleWidth).toInt()
+                if (steps != 0) {
+                    val newWeight = (selectedWeight - steps).coerceIn(minWeight, maxWeight)
+                    if (newWeight != selectedWeight) {
+                        selectedWeight = newWeight
+                        onWeightChangeListener?.invoke(selectedWeight)
+                    }
+                    scrollOffset -= steps * scaleWidth // keep only the residual
+                }
+
+                lastX = event.x
+                invalidate()
+                return true
+            }
+        }
+        return super.onTouchEvent(event)
+    }
+
+    fun setOnWeightChangeListener(listener: (Int) -> Unit) {
+        onWeightChangeListener = listener
+    }
+
+    fun getSelectedWeight(): Int = selectedWeight
+}
\ No newline at end of file
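The picker above only reports changes through its listener, and no Kotlin file in this diff references it yet (the layout XML content was lost in this capture, so it may be wired there). A hypothetical wiring sketch; `R.id.weightPicker` is assumed, not defined in this diff:

```kotlin
// Illustrative only: assumes a CustomWeightPicker with id weightPicker
// has been placed in the activity's layout.
val picker = findViewById<CustomWeightPicker>(R.id.weightPicker)
picker.setOnWeightChangeListener { weight ->
    selectedWeightText.text = weight.toString()
}
```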
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/GenderSelectionActivity.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/GenderSelectionActivity.kt
new file mode 100644
index 0000000..202b14e
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/GenderSelectionActivity.kt
@@ -0,0 +1,70 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.content.Intent
+import android.os.Bundle
+import android.widget.TextView
+import androidx.appcompat.app.AppCompatActivity
+import androidx.cardview.widget.CardView
+import com.google.android.material.button.MaterialButton
+
+class GenderSelectionActivity : AppCompatActivity() {
+    private lateinit var maleButton: CardView
+    private lateinit var femaleButton: CardView
+    private lateinit var maleText: TextView
+    private lateinit var femaleText: TextView
+    private lateinit var nextButton: MaterialButton
+    private var selectedGender: String? = null
+
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+        setContentView(R.layout.activity_gender_selection)
+
+        maleButton = findViewById(R.id.maleButton)
+        femaleButton = findViewById(R.id.femaleButton)
+        maleText = findViewById(R.id.maleText)
+        femaleText = findViewById(R.id.femaleText)
+        nextButton = findViewById(R.id.nextButton)
+
+        setupClickListeners()
+    }
+
+    private fun setupClickListeners() {
+        maleButton.setOnClickListener {
+            updateSelection("male")
+        }
+
+        femaleButton.setOnClickListener {
+            updateSelection("female")
+        }
+
+        nextButton.setOnClickListener {
+            // Go to the age selection screen and pass the gender along
+            val intent = Intent(this, AgeSelectionActivity::class.java)
+            intent.putExtra("selected_gender", selectedGender)
+            startActivity(intent)
+        }
+    }
+
+    private fun updateSelection(gender: String) {
+        selectedGender = gender
+
+        // Update the UI state
+        when (gender) {
+            "male" -> {
+                maleButton.setCardBackgroundColor(getColor(android.R.color.holo_purple))
+                femaleButton.setCardBackgroundColor(getColor(android.R.color.darker_gray))
+                maleText.setTextColor(getColor(android.R.color.holo_purple))
+                femaleText.setTextColor(getColor(android.R.color.white))
+            }
+            "female" -> {
+                femaleButton.setCardBackgroundColor(getColor(android.R.color.holo_purple))
+                maleButton.setCardBackgroundColor(getColor(android.R.color.darker_gray))
+                femaleText.setTextColor(getColor(android.R.color.holo_purple))
+                maleText.setTextColor(getColor(android.R.color.white))
+            }
+        }
+
+        // Enable the Next button
+        nextButton.isEnabled = true
+    }
+}
\ No newline at end of file
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/HeightSelectionActivity.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/HeightSelectionActivity.kt
new file mode 100644
index 0000000..84f0974
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/HeightSelectionActivity.kt
@@ -0,0 +1,129 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.content.Intent
+import android.os.Bundle
+import android.view.MotionEvent
+import android.widget.ImageButton
+import android.widget.LinearLayout
+import android.widget.TextView
+import androidx.appcompat.app.AppCompatActivity
+import com.google.android.material.button.MaterialButton
+import kotlin.math.abs
+
+class HeightSelectionActivity : AppCompatActivity() {
+    private lateinit var selectedHeightText: TextView
+    private lateinit var heightUnit: TextView
+    private lateinit var height1Above: TextView
+    private lateinit var height2Above: TextView
+    private lateinit var height1Below: TextView
+    private lateinit var height2Below: TextView
+    private lateinit var nextButton: MaterialButton
+    private lateinit var backButton: ImageButton
+    private var selectedGender: String? = null
+    private var selectedAge: Int = 0
+    private var selectedWeight: Int = 0
+
+    private var currentHeight = 167
+    private val minHeight = 100
+    private val maxHeight = 220
+
+    private var lastY: Float = 0f
+    private val scrollSensitivity = 15f // Adjust this value to change the swipe sensitivity
+
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+        setContentView(R.layout.activity_height_selection)
+
+        selectedGender = intent.getStringExtra("selected_gender")
+        selectedAge = intent.getIntExtra("selected_age", 0)
+        selectedWeight = intent.getIntExtra("selected_weight", 0)
+
+        selectedHeightText = findViewById(R.id.selectedHeightText)
+        heightUnit = findViewById(R.id.heightUnit)
+        height1Above = findViewById(R.id.height1Above)
+        height2Above = findViewById(R.id.height2Above)
+        height1Below = findViewById(R.id.height1Below)
+        height2Below = findViewById(R.id.height2Below)
+        nextButton = findViewById(R.id.nextButton)
+        backButton = findViewById(R.id.backButton)
+
+        setupUI()
+        setupClickListeners()
+    }
+
+    private fun setupUI() {
+        updateHeightDisplay()
+        nextButton.isEnabled = true
+    }
+
+    private fun updateHeightDisplay() {
+        selectedHeightText.text = currentHeight.toString()
+        // Heights shown above
+        if (currentHeight - 1 >= minHeight) {
+            height1Above.text = (currentHeight - 1).toString()
+        } else {
+            height1Above.text = ""
+        }
+        if (currentHeight - 2 >= minHeight) {
+            height2Above.text = (currentHeight - 2).toString()
+        } else {
+            height2Above.text = ""
+        }
+        // Heights shown below
+        if (currentHeight + 1 <= maxHeight) {
+            height1Below.text = (currentHeight + 1).toString()
+        } else {
+            height1Below.text = ""
+        }
+        if (currentHeight + 2 <= maxHeight) {
+            height2Below.text = (currentHeight + 2).toString()
+        } else {
+            height2Below.text = ""
+        }
+    }
+
+    private fun setupClickListeners() {
+        nextButton.setOnClickListener {
+            val intent = Intent(this, MainActivity::class.java)
+            intent.putExtra("selected_gender", selectedGender)
+            intent.putExtra("selected_age", selectedAge)
+            intent.putExtra("selected_weight", selectedWeight)
+            intent.putExtra("selected_height", currentHeight)
+            startActivity(intent)
+            finish()
+        }
+
+        backButton.setOnClickListener {
+            val intent = Intent(this, WeightSelectionActivity::class.java)
+            intent.putExtra("selected_gender", selectedGender)
+            intent.putExtra("selected_age", selectedAge)
+            intent.putExtra("selected_weight", selectedWeight)
+            startActivity(intent)
+            finish()
+        }
+    }
+
+    override fun onTouchEvent(event: MotionEvent): Boolean {
+        when (event.action) {
+            MotionEvent.ACTION_DOWN -> {
+                lastY = event.y
+                return true
+            }
+            MotionEvent.ACTION_MOVE -> {
+                val currentY = event.y
+                val deltaY = currentY - lastY
+                val change = -(deltaY / scrollSensitivity).toInt()
+                if (abs(change) > 0) {
+                    val newHeight = (currentHeight + change).coerceIn(minHeight, maxHeight)
+                    if (newHeight != currentHeight) {
+                        currentHeight = newHeight
+                        updateHeightDisplay()
+                        lastY = currentY
+                    }
+                }
+                return true
+            }
+        }
+        return super.onTouchEvent(event)
+    }
+}
\ No newline at end of file
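The onboarding chain (gender → age → weight → height) forwards each answer as a separate Intent extra. A sketch of an alternative that carries one object through the chain; the class and extra key are illustrative, and nothing in this diff uses them:

```kotlin
import java.io.Serializable

// Hypothetical aggregate for the onboarding answers; not part of this diff.
data class UserProfile(
    var gender: String? = null,
    var age: Int = 0,
    var weight: Int = 0,
    var height: Int = 0
) : Serializable

// Forward it with:   intent.putExtra("user_profile", profile)
// Read it back with: intent.getSerializableExtra("user_profile") as? UserProfile
```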
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/Onboarding1Fragment.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/Onboarding1Fragment.kt
new file mode 100644
index 0000000..e4e910b
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/Onboarding1Fragment.kt
@@ -0,0 +1,16 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.os.Bundle
+import android.view.LayoutInflater
+import android.view.View
+import android.view.ViewGroup
+import androidx.fragment.app.Fragment
+
+class Onboarding1Fragment : Fragment() {
+    override fun onCreateView(
+        inflater: LayoutInflater, container: ViewGroup?,
+        savedInstanceState: Bundle?
+    ): View? {
+        return inflater.inflate(R.layout.activity_onboarding1, container, false)
+    }
+}
\ No newline at end of file
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/Onboarding2Fragment.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/Onboarding2Fragment.kt
new file mode 100644
index 0000000..4157c36
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/Onboarding2Fragment.kt
@@ -0,0 +1,16 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.os.Bundle
+import android.view.LayoutInflater
+import android.view.View
+import android.view.ViewGroup
+import androidx.fragment.app.Fragment
+
+class Onboarding2Fragment : Fragment() {
+    override fun onCreateView(
+        inflater: LayoutInflater, container: ViewGroup?,
+        savedInstanceState: Bundle?
+    ): View? {
+        return inflater.inflate(R.layout.activity_onboarding2, container, false)
+    }
+}
\ No newline at end of file
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/Onboarding3Fragment.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/Onboarding3Fragment.kt
new file mode 100644
index 0000000..3ad9896
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/Onboarding3Fragment.kt
@@ -0,0 +1,29 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.content.Intent
+import android.os.Bundle
+import android.view.LayoutInflater
+import android.view.View
+import android.view.ViewGroup
+import android.widget.FrameLayout
+import androidx.fragment.app.Fragment
+
+class Onboarding3Fragment : Fragment() {
+    override fun onCreateView(
+        inflater: LayoutInflater, container: ViewGroup?,
+        savedInstanceState: Bundle?
+    ): View? {
+        val view = inflater.inflate(R.layout.activity_onboarding3, container, false)
+
+        // Find the "Start now" button and set its click listener
+        val startButton = view.findViewById<FrameLayout>(R.id.small_butto_container)
+        startButton.setOnClickListener {
+            // Go to the gender selection screen
+            val intent = Intent(requireActivity(), GenderSelectionActivity::class.java)
+            startActivity(intent)
+            requireActivity().finish() // finish the hosting OnboardingActivity
+        }
+
+        return view
+    }
+}
\ No newline at end of file
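This diff adds three page-indicator drawables (indicator1–3), but the OnboardingActivity that follows never switches them as pages change (the layout XML content was lost in this capture, so it may be handled elsewhere). A hypothetical sketch of wiring them to ViewPager2 inside OnboardingActivity.onCreate; `R.id.indicatorImage` is assumed, not defined in this diff:

```kotlin
// Illustrative only: assumes an ImageView with id indicatorImage in
// activity_onboarding.xml, which this diff does not show.
val indicator = findViewById<ImageView>(R.id.indicatorImage)
viewPager.registerOnPageChangeCallback(object : ViewPager2.OnPageChangeCallback() {
    override fun onPageSelected(position: Int) {
        indicator.setImageResource(
            when (position) {
                0 -> R.drawable.indicator1
                1 -> R.drawable.indicator2
                else -> R.drawable.indicator3
            }
        )
    }
})
```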
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/OnboardingActivity.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/OnboardingActivity.kt
new file mode 100644
index 0000000..aa1a487
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/OnboardingActivity.kt
@@ -0,0 +1,15 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.os.Bundle
+import androidx.appcompat.app.AppCompatActivity
+import androidx.viewpager2.widget.ViewPager2
+
+class OnboardingActivity : AppCompatActivity() {
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+        setContentView(R.layout.activity_onboarding)
+
+        val viewPager = findViewById<ViewPager2>(R.id.viewPager)
+        viewPager.adapter = OnboardingAdapter(this)
+    }
+}
\ No newline at end of file
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/OnboardingAdapter.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/OnboardingAdapter.kt
new file mode 100644
index 0000000..7c91147
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/OnboardingAdapter.kt
@@ -0,0 +1,18 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import androidx.fragment.app.Fragment
+import androidx.fragment.app.FragmentActivity
+import androidx.viewpager2.adapter.FragmentStateAdapter
+
+class OnboardingAdapter(activity: FragmentActivity) : FragmentStateAdapter(activity) {
+    override fun getItemCount(): Int = 3
+
+    override fun createFragment(position: Int): Fragment {
+        return when (position) {
+            0 -> Onboarding1Fragment()
+            1 -> Onboarding2Fragment()
+            2 -> Onboarding3Fragment()
+            else -> Onboarding1Fragment()
+        }
+    }
+}
\ No newline at end of file
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/SplashActivity.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/SplashActivity.kt
new file mode 100644
index 0000000..546417c
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/SplashActivity.kt
@@ -0,0 +1,50 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.content.Intent
+import android.os.Bundle
+import android.os.Handler
+import android.os.Looper
+import android.view.View
+import android.widget.RelativeLayout
+import android.widget.TextView
+import androidx.appcompat.app.AppCompatActivity
+
+class SplashActivity : AppCompatActivity() {
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+        setContentView(R.layout.activity_splash)
+
+        val textView = findViewById<TextView>(R.id.some_id)
+        val vector1 = findViewById<View>(R.id.vector1)
+        val vector2 = findViewById<View>(R.id.vector2)
+
+        // Get the screen height
+        val screenHeight = resources.displayMetrics.heightPixels
+        val margin = screenHeight / 8
+
+        textView.post {
+            val textViewLocation = IntArray(2)
+            textView.getLocationOnScreen(textViewLocation)
+            val textViewTop = textViewLocation[1]
+            val textViewBottom = textViewTop + textView.height
+
+            // Place vector1 above the TextView, aligned to the right edge
+            val params1 = vector1.layoutParams as RelativeLayout.LayoutParams
+            params1.addRule(RelativeLayout.ALIGN_PARENT_RIGHT)
+            params1.topMargin = textViewTop - margin - vector1.layoutParams.height
+            vector1.layoutParams = params1
+
+            // Place vector2 below the TextView, aligned to the left edge
+            val params2 = vector2.layoutParams as RelativeLayout.LayoutParams
+            params2.addRule(RelativeLayout.ALIGN_PARENT_LEFT)
+            params2.topMargin = textViewBottom + margin
+            vector2.layoutParams = params2
+        }
+
+        // Navigate to the onboarding screens after 2 seconds
+        Handler(Looper.getMainLooper()).postDelayed({
+            startActivity(Intent(this, OnboardingActivity::class.java))
+            finish()
+        }, 2000)
+    }
+}
\ No newline at end of file
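One caveat in SplashActivity above: the Runnable posted with postDelayed still fires if the activity is destroyed within the two seconds. A hedged sketch of the usual guard; this is a variant, not what the diff does:

```kotlin
import android.content.Intent
import android.os.Bundle
import android.os.Handler
import android.os.Looper
import androidx.appcompat.app.AppCompatActivity

// Hypothetical variant that cancels the pending navigation when the
// activity is destroyed before the delay elapses.
class SafeSplashActivity : AppCompatActivity() {
    private val handler = Handler(Looper.getMainLooper())
    private val goToOnboarding = Runnable {
        startActivity(Intent(this, OnboardingActivity::class.java))
        finish()
    }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_splash)
        handler.postDelayed(goToOnboarding, 2000)
    }

    override fun onDestroy() {
        handler.removeCallbacks(goToOnboarding)
        super.onDestroy()
    }
}
```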
diff --git a/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/WeightSelectionActivity.kt b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/WeightSelectionActivity.kt
new file mode 100644
index 0000000..1384a9b
--- /dev/null
+++ b/android/app/src/main/java/org/tensorflow/lite/examples/poseestimation/WeightSelectionActivity.kt
@@ -0,0 +1,128 @@
+package org.tensorflow.lite.examples.poseestimation
+
+import android.content.Intent
+import android.os.Bundle
+import android.view.MotionEvent
+import android.widget.ImageButton
+import android.widget.TextView
+import androidx.appcompat.app.AppCompatActivity
+import com.google.android.material.button.MaterialButton
+import kotlin.math.abs
+
+class WeightSelectionActivity : AppCompatActivity() {
+    private lateinit var selectedWeightText: TextView
+    private lateinit var weight1Above: TextView
+    private lateinit var weight2Above: TextView
+    private lateinit var weight1Below: TextView
+    private lateinit var weight2Below: TextView
+    private lateinit var nextButton: MaterialButton
+    private lateinit var backButton: ImageButton
+
+    private var selectedGender: String? = null
+    private var selectedAge: Int = 0
+    private var currentWeight = 54
+
+    private var lastY: Float = 0f
+    private val scrollSensitivity = 15f // Adjust this value to change the swipe sensitivity
+    private val minWeight = 30
+    private val maxWeight = 200
+
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+        setContentView(R.layout.activity_weight_selection)
+
+        // Read the data passed from the previous screen
+        selectedGender = intent.getStringExtra("selected_gender")
+        selectedAge = intent.getIntExtra("selected_age", 0)
+
+        // Initialize the views
+        selectedWeightText = findViewById(R.id.selectedWeightText)
+        weight1Above = findViewById(R.id.weight1Above)
+        weight2Above = findViewById(R.id.weight2Above)
+        weight1Below = findViewById(R.id.weight1Below)
+        weight2Below = findViewById(R.id.weight2Below)
+        nextButton = findViewById(R.id.nextButton)
+        backButton = findViewById(R.id.backButton)
+
+        setupUI()
+        setupClickListeners()
+    }
+
+    private fun setupUI() {
+        updateWeightDisplay()
+        nextButton.isEnabled = true
+    }
+
+    private fun updateWeightDisplay() {
+        selectedWeightText.text = currentWeight.toString()
+
+        // Update the weights shown above (smaller numbers)
+        if (currentWeight - 1 >= minWeight) {
+            weight1Above.text = (currentWeight - 1).toString()
+        } else {
+            weight1Above.text = ""
+        }
+        if (currentWeight - 2 >= minWeight) {
+            weight2Above.text = (currentWeight - 2).toString()
+        } else {
+            weight2Above.text = ""
+        }
+
+        // Update the weights shown below (larger numbers)
+        if (currentWeight + 1 <= maxWeight) {
+            weight1Below.text = (currentWeight + 1).toString()
+        } else {
+            weight1Below.text = ""
+        }
+        if (currentWeight + 2 <= maxWeight) {
+            weight2Below.text = (currentWeight + 2).toString()
+        } else {
+            weight2Below.text = ""
+        }
+    }
+
+    private fun setupClickListeners() {
+        nextButton.setOnClickListener {
+            val intent = Intent(this, HeightSelectionActivity::class.java)
+            intent.putExtra("selected_gender", selectedGender)
+            intent.putExtra("selected_age", selectedAge)
+            intent.putExtra("selected_weight", currentWeight)
+            startActivity(intent)
+            finish()
+        }
+
+        backButton.setOnClickListener {
+            val intent = Intent(this, AgeSelectionActivity::class.java)
+            intent.putExtra("selected_gender", selectedGender)
+            startActivity(intent)
+            finish()
+        }
+    }
+
+    override fun onTouchEvent(event: MotionEvent): Boolean {
+        when (event.action) {
+            MotionEvent.ACTION_DOWN -> {
+                lastY = event.y
+                return true
+            }
+            MotionEvent.ACTION_MOVE -> {
+                val currentY = event.y
+                val deltaY = currentY - lastY
+
+                // Work out how many steps the weight should change
+                val change = -(deltaY / scrollSensitivity).toInt()
+                if (abs(change) > 0) {
+                    // Update the weight
+                    val newWeight = (currentWeight + change).coerceIn(minWeight, maxWeight)
+                    if (newWeight != currentWeight) {
+                        currentWeight = newWeight
+                        updateWeightDisplay()
+                        lastY = currentY
+                    }
+                }
+                return true
+            }
+        }
+        return super.onTouchEvent(event)
+    }
+}
\ No newline at end of file
diff --git a/android/app/src/main/res/drawable/circle_button_background.xml b/android/app/src/main/res/drawable/circle_button_background.xml new file mode 100644 index 0000000..70ad4dd --- /dev/null +++ b/android/app/src/main/res/drawable/circle_button_background.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/android/app/src/main/res/drawable/female.png b/android/app/src/main/res/drawable/female.png new file mode 100644 index 0000000..c32634a Binary files /dev/null and b/android/app/src/main/res/drawable/female.png differ diff --git a/android/app/src/main/res/drawable/ic_back.xml b/android/app/src/main/res/drawable/ic_back.xml new file mode 100644 index 0000000..791df00 --- /dev/null +++ b/android/app/src/main/res/drawable/ic_back.xml @@ -0,0 +1,10 @@ + + + + \ No newline at end of file diff --git a/android/app/src/main/res/drawable/indicator1.png b/android/app/src/main/res/drawable/indicator1.png new file mode 100644 index 0000000..6dd2dab Binary files /dev/null and b/android/app/src/main/res/drawable/indicator1.png differ diff --git a/android/app/src/main/res/drawable/indicator2.png b/android/app/src/main/res/drawable/indicator2.png new file mode 100644 index 0000000..8f9e328 Binary files /dev/null and b/android/app/src/main/res/drawable/indicator2.png differ diff --git a/android/app/src/main/res/drawable/indicator3.png b/android/app/src/main/res/drawable/indicator3.png new file mode 100644 index 0000000..94fdbb2 Binary files /dev/null and b/android/app/src/main/res/drawable/indicator3.png differ diff --git a/android/app/src/main/res/drawable/male.png b/android/app/src/main/res/drawable/male.png new file mode 100644 index 0000000..1ac479c Binary files /dev/null and b/android/app/src/main/res/drawable/male.png differ diff --git a/android/app/src/main/res/drawable/onboarding1_man.png b/android/app/src/main/res/drawable/onboarding1_man.png new file mode 100644 index 0000000..1b661d4 Binary files /dev/null and b/android/app/src/main/res/drawable/onboarding1_man.png differ diff --git a/android/app/src/main/res/drawable/onboarding2_woman.png b/android/app/src/main/res/drawable/onboarding2_woman.png new file mode 100644 index 0000000..0865aad Binary files /dev/null and b/android/app/src/main/res/drawable/onboarding2_woman.png differ diff --git a/android/app/src/main/res/drawable/onboarding3_man.png b/android/app/src/main/res/drawable/onboarding3_man.png new file mode 100644 index 0000000..58103e6 Binary files /dev/null and b/android/app/src/main/res/drawable/onboarding3_man.png differ diff --git a/android/app/src/main/res/drawable/small_butto.xml b/android/app/src/main/res/drawable/small_butto.xml new file mode 100644 index 0000000..3bcd78e --- /dev/null +++ b/android/app/src/main/res/drawable/small_butto.xml @@ -0,0 +1,18 @@ + + + + + + \ No newline at end of file diff --git a/android/app/src/main/res/drawable/vector.xml b/android/app/src/main/res/drawable/vector.xml new file mode 100644 index 0000000..69f30de --- /dev/null +++ b/android/app/src/main/res/drawable/vector.xml @@ -0,0 +1,14 @@ + + + + + diff --git a/android/app/src/main/res/font/.placeholder b/android/app/src/main/res/font/.placeholder new file mode 100644 index 0000000..bbc10f8 --- /dev/null +++ b/android/app/src/main/res/font/.placeholder @@ -0,0 +1 @@ +// This file is only a placeholder; put the actual font files in this directory. \ No newline at end of file diff --git a/android/app/src/main/res/layout/activity_age_selection.xml b/android/app/src/main/res/layout/activity_age_selection.xml new file mode 100644 index 0000000..38f37ef --- /dev/null +++ b/android/app/src/main/res/layout/activity_age_selection.xml @@ -0,0 +1,147 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/app/src/main/res/layout/activity_gender_selection.xml b/android/app/src/main/res/layout/activity_gender_selection.xml new file mode 100644 index 0000000..0e26131 --- /dev/null +++ b/android/app/src/main/res/layout/activity_gender_selection.xml @@ -0,0 +1,112 @@ + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/app/src/main/res/layout/activity_height_selection.xml b/android/app/src/main/res/layout/activity_height_selection.xml new file mode 100644 index 0000000..6f2700c --- /dev/null +++ b/android/app/src/main/res/layout/activity_height_selection.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/app/src/main/res/layout/activity_onboarding.xml
b/android/app/src/main/res/layout/activity_onboarding.xml new file mode 100644 index 0000000..75e823b --- /dev/null +++ b/android/app/src/main/res/layout/activity_onboarding.xml @@ -0,0 +1,12 @@ + + + + + \ No newline at end of file diff --git a/android/app/src/main/res/layout/activity_onboarding1.xml b/android/app/src/main/res/layout/activity_onboarding1.xml new file mode 100644 index 0000000..9126c96 --- /dev/null +++ b/android/app/src/main/res/layout/activity_onboarding1.xml @@ -0,0 +1,51 @@ + + + + + + + + + + + + + diff --git a/android/app/src/main/res/layout/activity_onboarding2.xml b/android/app/src/main/res/layout/activity_onboarding2.xml new file mode 100644 index 0000000..8784432 --- /dev/null +++ b/android/app/src/main/res/layout/activity_onboarding2.xml @@ -0,0 +1,51 @@ + + + + + + + + + + + + + diff --git a/android/app/src/main/res/layout/activity_onboarding3.xml b/android/app/src/main/res/layout/activity_onboarding3.xml new file mode 100644 index 0000000..62346ee --- /dev/null +++ b/android/app/src/main/res/layout/activity_onboarding3.xml @@ -0,0 +1,82 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/android/app/src/main/res/layout/activity_splash.xml b/android/app/src/main/res/layout/activity_splash.xml new file mode 100644 index 0000000..b7f3476 --- /dev/null +++ b/android/app/src/main/res/layout/activity_splash.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/android/app/src/main/res/layout/activity_weight_selection.xml b/android/app/src/main/res/layout/activity_weight_selection.xml new file mode 100644 index 0000000..de0443a --- /dev/null +++ b/android/app/src/main/res/layout/activity_weight_selection.xml @@ -0,0 +1,151 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/app/src/main/res/values/colors.xml b/android/app/src/main/res/values/colors.xml index 648cfde..cfa4818 100644 --- a/android/app/src/main/res/values/colors.xml +++ b/android/app/src/main/res/values/colors.xml @@ -6,4 +6,6 @@ #FF018786 #FF000000 #FFFFFFFF + #757575 + #E0E0E0 diff --git a/android/app/src/main/res/values/strings.xml b/android/app/src/main/res/values/strings.xml index 79bfc21..e0e323a 100644 --- a/android/app/src/main/res/values/strings.xml +++ b/android/app/src/main/res/values/strings.xml @@ -1,4 +1,5 @@ + TFL Pose Estimation This app needs camera permission. Score: %.2f @@ -27,4 +28,12 @@ BoundingBox Keypoint + + + 形动力 + Meet your coach,\nstart your journey + Create a workout plan\nto stay fit + Action is the\nkey to all success + Start Now + Next diff --git a/android/app/src/main/res/values/themes.xml b/android/app/src/main/res/values/themes.xml index 0414b39..29c4a51 100644 --- a/android/app/src/main/res/values/themes.xml +++ b/android/app/src/main/res/values/themes.xml @@ -1,5 +1,5 @@ - + + diff --git a/android/local.properties b/android/local.properties index 5112cf9..bc2a713 100644 --- a/android/local.properties +++ b/android/local.properties @@ -4,5 +4,5 @@ # Location of the SDK. This is only used by Gradle. # For customization when using a Version Control System, please read the # header note. 
-#Mon Apr 21 08:12:48 CST 2025 +#Fri Apr 25 19:53:17 CST 2025 sdk.dir=/Users/ziyue/Library/Android/sdk diff --git a/android1/.gradle/8.5/checksums/checksums.lock b/android1/.gradle/8.5/checksums/checksums.lock new file mode 100644 index 0000000..5edef46 Binary files /dev/null and b/android1/.gradle/8.5/checksums/checksums.lock differ diff --git a/android/.gradle/8.5/checksums/md5-checksums.bin b/android1/.gradle/8.5/checksums/md5-checksums.bin similarity index 85% rename from android/.gradle/8.5/checksums/md5-checksums.bin rename to android1/.gradle/8.5/checksums/md5-checksums.bin index 8c46e40..b64501d 100644 Binary files a/android/.gradle/8.5/checksums/md5-checksums.bin and b/android1/.gradle/8.5/checksums/md5-checksums.bin differ diff --git a/android/.gradle/8.5/checksums/sha1-checksums.bin b/android1/.gradle/8.5/checksums/sha1-checksums.bin similarity index 94% rename from android/.gradle/8.5/checksums/sha1-checksums.bin rename to android1/.gradle/8.5/checksums/sha1-checksums.bin index 26e9184..59cbbbb 100644 Binary files a/android/.gradle/8.5/checksums/sha1-checksums.bin and b/android1/.gradle/8.5/checksums/sha1-checksums.bin differ diff --git a/android1/.gradle/8.5/dependencies-accessors/dependencies-accessors.lock b/android1/.gradle/8.5/dependencies-accessors/dependencies-accessors.lock new file mode 100644 index 0000000..1737bfd Binary files /dev/null and b/android1/.gradle/8.5/dependencies-accessors/dependencies-accessors.lock differ diff --git a/android1/.gradle/8.5/dependencies-accessors/gc.properties b/android1/.gradle/8.5/dependencies-accessors/gc.properties new file mode 100644 index 0000000..e69de29 diff --git a/android1/.gradle/8.5/executionHistory/executionHistory.bin b/android1/.gradle/8.5/executionHistory/executionHistory.bin new file mode 100644 index 0000000..5ea72cf Binary files /dev/null and b/android1/.gradle/8.5/executionHistory/executionHistory.bin differ diff --git a/android1/.gradle/8.5/executionHistory/executionHistory.lock b/android1/.gradle/8.5/executionHistory/executionHistory.lock new file mode 100644 index 0000000..7abdfdb Binary files /dev/null and b/android1/.gradle/8.5/executionHistory/executionHistory.lock differ diff --git a/android1/.gradle/8.5/fileChanges/last-build.bin b/android1/.gradle/8.5/fileChanges/last-build.bin new file mode 100644 index 0000000..f76dd23 Binary files /dev/null and b/android1/.gradle/8.5/fileChanges/last-build.bin differ diff --git a/android1/.gradle/8.5/fileHashes/fileHashes.bin b/android1/.gradle/8.5/fileHashes/fileHashes.bin new file mode 100644 index 0000000..cf82da6 Binary files /dev/null and b/android1/.gradle/8.5/fileHashes/fileHashes.bin differ diff --git a/android1/.gradle/8.5/fileHashes/fileHashes.lock b/android1/.gradle/8.5/fileHashes/fileHashes.lock new file mode 100644 index 0000000..0459c1e Binary files /dev/null and b/android1/.gradle/8.5/fileHashes/fileHashes.lock differ diff --git a/android1/.gradle/8.5/fileHashes/resourceHashesCache.bin b/android1/.gradle/8.5/fileHashes/resourceHashesCache.bin new file mode 100644 index 0000000..3bf848f Binary files /dev/null and b/android1/.gradle/8.5/fileHashes/resourceHashesCache.bin differ diff --git a/android1/.gradle/8.5/gc.properties b/android1/.gradle/8.5/gc.properties new file mode 100644 index 0000000..e69de29 diff --git a/android1/.gradle/buildOutputCleanup/buildOutputCleanup.lock b/android1/.gradle/buildOutputCleanup/buildOutputCleanup.lock new file mode 100644 index 0000000..e5679c3 Binary files /dev/null and 
b/android1/.gradle/buildOutputCleanup/buildOutputCleanup.lock differ diff --git a/android1/.gradle/buildOutputCleanup/cache.properties b/android1/.gradle/buildOutputCleanup/cache.properties new file mode 100644 index 0000000..f3476bd --- /dev/null +++ b/android1/.gradle/buildOutputCleanup/cache.properties @@ -0,0 +1,2 @@ +#Mon Apr 21 08:14:16 CST 2025 +gradle.version=8.5 diff --git a/android1/.gradle/buildOutputCleanup/outputFiles.bin b/android1/.gradle/buildOutputCleanup/outputFiles.bin new file mode 100644 index 0000000..f1bac22 Binary files /dev/null and b/android1/.gradle/buildOutputCleanup/outputFiles.bin differ diff --git a/android1/.gradle/config.properties b/android1/.gradle/config.properties new file mode 100644 index 0000000..9f9d1d6 --- /dev/null +++ b/android1/.gradle/config.properties @@ -0,0 +1,2 @@ +#Mon Apr 21 08:12:48 CST 2025 +java.home=/Applications/Android Studio.app/Contents/jbr/Contents/Home diff --git a/android1/.gradle/file-system.probe b/android1/.gradle/file-system.probe new file mode 100644 index 0000000..06ff6fa Binary files /dev/null and b/android1/.gradle/file-system.probe differ diff --git a/android1/.gradle/vcs-1/gc.properties b/android1/.gradle/vcs-1/gc.properties new file mode 100644 index 0000000..e69de29 diff --git a/android1/.idea/.gitignore b/android1/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/android1/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/android1/.idea/.name b/android1/.idea/.name new file mode 100644 index 0000000..1c0e605 --- /dev/null +++ b/android1/.idea/.name @@ -0,0 +1 @@ +TFLite Pose Estimation \ No newline at end of file diff --git a/android1/.idea/AndroidProjectSystem.xml b/android1/.idea/AndroidProjectSystem.xml new file mode 100644 index 0000000..4a53bee --- /dev/null +++ b/android1/.idea/AndroidProjectSystem.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/android1/.idea/compiler.xml b/android1/.idea/compiler.xml new file mode 100644 index 0000000..b86273d --- /dev/null +++ b/android1/.idea/compiler.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/android1/.idea/deploymentTargetSelector.xml b/android1/.idea/deploymentTargetSelector.xml new file mode 100644 index 0000000..b268ef3 --- /dev/null +++ b/android1/.idea/deploymentTargetSelector.xml @@ -0,0 +1,10 @@ + + + + + + + + + \ No newline at end of file diff --git a/android1/.idea/gradle.xml b/android1/.idea/gradle.xml new file mode 100644 index 0000000..639c779 --- /dev/null +++ b/android1/.idea/gradle.xml @@ -0,0 +1,19 @@ + + + + + + + \ No newline at end of file diff --git a/android1/.idea/kotlinc.xml b/android1/.idea/kotlinc.xml new file mode 100644 index 0000000..ae3f30a --- /dev/null +++ b/android1/.idea/kotlinc.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/android1/.idea/migrations.xml b/android1/.idea/migrations.xml new file mode 100644 index 0000000..f8051a6 --- /dev/null +++ b/android1/.idea/migrations.xml @@ -0,0 +1,10 @@ + + + + + + \ No newline at end of file diff --git a/android1/.idea/misc.xml b/android1/.idea/misc.xml new file mode 100644 index 0000000..b2c751a --- /dev/null +++ b/android1/.idea/misc.xml @@ -0,0 +1,9 @@ + + + + + + + + \ No newline at end of file diff --git a/android1/.idea/runConfigurations.xml b/android1/.idea/runConfigurations.xml new file mode 100644 index 0000000..16660f1 --- /dev/null +++ b/android1/.idea/runConfigurations.xml @@ -0,0 +1,17 @@ + + + + + + \ No newline at end of file 
diff --git a/android1/.idea/vcs.xml b/android1/.idea/vcs.xml
new file mode 100644
index 0000000..6c0b863
--- /dev/null
+++ b/android1/.idea/vcs.xml
@@ -0,0 +1,6 @@
+ + + + + +
\ No newline at end of file
diff --git a/android1/README.md b/android1/README.md
new file mode 100644
index 0000000..2293f99
--- /dev/null
+++ b/android1/README.md
@@ -0,0 +1,73 @@
+# TensorFlow Lite Pose Estimation Android Demo
+
+### Overview
+This is an app that continuously detects the body parts in the frames seen by
+your device's camera. These instructions walk you through building and running
+the demo on an Android device. Camera captures are discarded immediately after
+use; nothing is stored or saved.
+
+The app demonstrates how to use 4 models:
+
+* Single pose models: The model can estimate the pose of only one person in the
+input image. If the input image contains multiple persons, the detection result
+can be largely incorrect.
+  * PoseNet
+  * MoveNet Lightning
+  * MoveNet Thunder
+* Multi pose models: The model can estimate the poses of multiple persons in the
+input image.
+  * MoveNet MultiPose: Supports up to 6 persons.
+
+See this [blog post](https://blog.tensorflow.org/2021/05/next-generation-pose-detection-with-movenet-and-tensorflowjs.html)
+for a comparison between these models.
+
+![Demo Image](posenetimage.png)
+
+## Build the demo using Android Studio
+
+### Prerequisites
+
+* If you don't have it already, install **[Android Studio](
+  https://developer.android.com/studio/index.html)** 4.2 or
+  above, following the instructions on the website.
+
+* An Android device and an Android development environment with minimum API 21.
+
+### Building
+* Open Android Studio, and from the `Welcome` screen, select
+`Open an existing Android Studio project`.
+
+* From the `Open File or Project` window that appears, navigate to and select
+  the `lite/examples/pose_estimation/android` directory from wherever you
+  cloned the `tensorflow/examples` GitHub repo. Click `OK`.
+
+* If it asks you to do a `Gradle Sync`, click `OK`.
+
+* You may also need to install various platforms and tools if you get errors
+  like `Failed to find target with hash string 'android-21'` and similar. Click
+  the `Run` button (the green arrow) or select `Run` > `Run 'android'` from the
+  top menu. You may need to rebuild the project using `Build` > `Rebuild Project`.
+
+* If it asks you to use `Instant Run`, click `Proceed Without Instant Run`.
+
+* Also, you need to have an Android device plugged in with developer options
+  enabled at this point. See **[here](
+  https://developer.android.com/studio/run/device)** for more details
+  on setting up developer devices.
+
+
+### Model used
+Downloading, extracting, and placing the models in the assets folder is managed
+automatically by `download.gradle`.
+
+If you explicitly want to download the models, you can download them from here:
+
+* [Posenet](https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite)
+* [Movenet Lightning](https://kaggle.com/models/google/movenet/frameworks/tfLite/variations/singlepose-lightning)
+* [Movenet Thunder](https://www.kaggle.com/models/google/movenet/frameworks/tfLite/variations/singlepose-thunder)
+* [Movenet MultiPose](https://www.kaggle.com/models/google/movenet/frameworks/tfLite/variations/multipose-lightning-tflite-float16)
+
+### Additional Note
+_Please do not delete the assets folder content_. If you explicitly deleted the
+files, then please choose `Build` > `Rebuild` from the menu to re-download the
+deleted model files into the assets folder.
diff --git a/android1/app/.gitignore b/android1/app/.gitignore
new file mode 100644
index 0000000..42afabf
--- /dev/null
+++ b/android1/app/.gitignore
@@ -0,0 +1 @@
+/build
\ No newline at end of file
diff --git a/android1/app/build.gradle b/android1/app/build.gradle
new file mode 100644
index 0000000..d1545ae
--- /dev/null
+++ b/android1/app/build.gradle
@@ -0,0 +1,56 @@
+plugins {
+    id 'com.android.application'
+    id 'kotlin-android'
+}
+
+android {
+    compileSdkVersion 30
+    buildToolsVersion "30.0.3"
+
+    defaultConfig {
+        applicationId "org.tensorflow.lite.examples.poseestimation"
+        minSdkVersion 23
+        targetSdkVersion 30
+        versionCode 1
+        versionName "1.0"
+
+        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
+    }
+
+    namespace "org.tensorflow.lite.examples.poseestimation"
+
+    buildTypes {
+        release {
+            minifyEnabled false
+            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+        }
+    }
+    compileOptions {
+        sourceCompatibility JavaVersion.VERSION_1_8
+        targetCompatibility JavaVersion.VERSION_1_8
+    }
+    kotlinOptions {
+        jvmTarget = '1.8'
+    }
+}
+
+// Download the tflite models
+apply from: "download.gradle"
+
+dependencies {
+
+    implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
+    implementation 'androidx.core:core-ktx:1.5.0'
+    implementation 'androidx.appcompat:appcompat:1.3.0'
+    implementation 'com.google.android.material:material:1.3.0'
+    implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
+    implementation "androidx.activity:activity-ktx:1.2.3"
+    implementation 'androidx.fragment:fragment-ktx:1.3.5'
+    implementation 'org.tensorflow:tensorflow-lite:2.14.0'
+    implementation 'org.tensorflow:tensorflow-lite-gpu:2.5.0'
+    implementation 'org.tensorflow:tensorflow-lite-support:0.3.0'
+
+    androidTestImplementation 'androidx.test.ext:junit:1.1.2'
+    androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
+    androidTestImplementation "com.google.truth:truth:1.1.3"
+}
diff --git a/android1/app/download.gradle b/android1/app/download.gradle
new file mode 100644
index 0000000..423344d
--- /dev/null
+++ b/android1/app/download.gradle
@@ -0,0 +1,67 @@
+task downloadPosenetModel(type: DownloadUrlTask) {
+    def modelPosenetDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite"
+    doFirst {
+        println "Downloading ${modelPosenetDownloadUrl}"
+    }
+    sourceUrl = "${modelPosenetDownloadUrl}"
+    target = file("src/main/assets/posenet.tflite")
+}
+
+task downloadMovenetLightningModel(type: DownloadUrlTask) {
+    def modelMovenetLightningDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/float16/4?lite-format=tflite"
+    doFirst {
+        println "Downloading ${modelMovenetLightningDownloadUrl}"
+    }
+    sourceUrl = "${modelMovenetLightningDownloadUrl}"
+    target = file("src/main/assets/movenet_lightning.tflite")
+}
+
+task downloadMovenetThunderModel(type: DownloadUrlTask) {
+    def modelMovenetThunderDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/singlepose/thunder/tflite/float16/4?lite-format=tflite"
+    doFirst {
+        println "Downloading ${modelMovenetThunderDownloadUrl}"
+    }
+    sourceUrl = "${modelMovenetThunderDownloadUrl}"
+    target = file("src/main/assets/movenet_thunder.tflite")
+}
+
+task downloadMovenetMultiPoseModel(type: DownloadUrlTask) {
+    def
modelMovenetThunderDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/multipose/lightning/tflite/float16/1?lite-format=tflite" + doFirst { + println "Downloading ${modelMovenetThunderDownloadUrl}" + } + sourceUrl = "${modelMovenetThunderDownloadUrl}" + target = file("src/main/assets/movenet_multipose_fp16.tflite") +} + +task downloadPoseClassifierModel(type: DownloadUrlTask) { + def modelPoseClassifierDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/pose_classifier/yoga_classifier.tflite" + doFirst { + println "Downloading ${modelPoseClassifierDownloadUrl}" + } + sourceUrl = "${modelPoseClassifierDownloadUrl}" + target = file("src/main/assets/classifier.tflite") +} + +task downloadModel { + dependsOn downloadPosenetModel + dependsOn downloadMovenetLightningModel + dependsOn downloadMovenetThunderModel + dependsOn downloadPoseClassifierModel + dependsOn downloadMovenetMultiPoseModel +} + +class DownloadUrlTask extends DefaultTask { + @Input + String sourceUrl + + @OutputFile + File target + + @TaskAction + void download() { + ant.get(src: sourceUrl, dest: target) + } +} + +preBuild.dependsOn downloadModel diff --git a/android1/app/proguard-rules.pro b/android1/app/proguard-rules.pro new file mode 100644 index 0000000..f1b4245 --- /dev/null +++ b/android1/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile diff --git a/android1/app/src/androidTest/assets/image1.png b/android1/app/src/androidTest/assets/image1.png new file mode 100644 index 0000000..4085b3d Binary files /dev/null and b/android1/app/src/androidTest/assets/image1.png differ diff --git a/android1/app/src/androidTest/assets/image2.jpg b/android1/app/src/androidTest/assets/image2.jpg new file mode 100644 index 0000000..8db9892 Binary files /dev/null and b/android1/app/src/androidTest/assets/image2.jpg differ diff --git a/android1/app/src/androidTest/assets/image3.jpeg b/android1/app/src/androidTest/assets/image3.jpeg new file mode 100644 index 0000000..f310928 Binary files /dev/null and b/android1/app/src/androidTest/assets/image3.jpeg differ diff --git a/android1/app/src/androidTest/assets/image_credits.txt b/android1/app/src/androidTest/assets/image_credits.txt new file mode 100644 index 0000000..b3c3888 --- /dev/null +++ b/android1/app/src/androidTest/assets/image_credits.txt @@ -0,0 +1,3 @@ +Image1: https://pixabay.com/illustrations/woman-stand-wait-person-shoes-1427073/ +Image2: https://pixabay.com/photos/businessman-suit-germany-black-1146791/ +Image3: https://pixabay.com/photos/tree-pose-yoga-yogini-lifestyle-4823155/ diff --git a/android1/app/src/androidTest/assets/pose_landmark_truth.csv b/android1/app/src/androidTest/assets/pose_landmark_truth.csv new file mode 100644 index 0000000..2a57769 --- /dev/null +++ b/android1/app/src/androidTest/assets/pose_landmark_truth.csv @@ -0,0 +1,3 @@ +nose_x,nose_y,left_eye_x,left_eye_y,right_eye_x,right_eye_y,left_ear_x,left_ear_y,right_ear_x,right_ear_y,left_shoulder_x,left_shoulder_y,right_shoulder_x,right_shoulder_y,left_elbow_x,left_elbow_y,right_elbow_x,right_elbow_y,left_wrist_x,left_wrist_y,right_wrist_x,right_wrist_y,left_hip_x,left_hip_y,right_hip_x,right_hip_y,left_knee_x,left_knee_y,right_knee_x,right_knee_y,left_ankle_x,left_ankle_y,right_ankle_x,right_ankle_y +186,89,200,77,177,78,224,86,167,85,244,158,154,154,258,248,143,239,265,327,136,313,234,311,170,311,247,446,134,445,262,561,92,571 +182,84,191,73,171,74,202,75,157,77,220,119,139,136,260,192,185,230,268,209,246,217,221,288,176,294,205,421,174,421,186,538,155,564 diff --git a/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/EvaluationUtils.kt b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/EvaluationUtils.kt new file mode 100644 index 0000000..e7ce855 --- /dev/null +++ b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/EvaluationUtils.kt @@ -0,0 +1,113 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.graphics.Canvas +import android.graphics.PointF +import androidx.test.platform.app.InstrumentationRegistry +import com.google.common.truth.Truth.assertThat +import com.google.common.truth.Truth.assertWithMessage +import org.tensorflow.lite.examples.poseestimation.data.BodyPart +import org.tensorflow.lite.examples.poseestimation.data.Person +import java.io.BufferedReader +import java.io.InputStreamReader +import kotlin.math.pow + +object EvaluationUtils { + + /** + * Assert whether the detected person from the image matches the expected result. + * The detection result is accepted as correct if it is within the acceptableError range from + * the expected result. + */ + fun assertPoseDetectionResult( + person: Person, + expectedResult: Map<BodyPart, PointF>, + acceptableError: Float + ) { + // Check if the model is confident enough in detecting the person + assertThat(person.score).isGreaterThan(0.5f) + + for ((bodyPart, expectedPointF) in expectedResult) { + val keypoint = person.keyPoints.firstOrNull { it.bodyPart == bodyPart } + assertWithMessage("$bodyPart must exist").that(keypoint).isNotNull() + + val detectedPointF = keypoint!!.coordinate + val distanceFromExpectedPointF = distance(detectedPointF, expectedPointF) + assertWithMessage("Detected $bodyPart must be close to expected result") + .that(distanceFromExpectedPointF).isAtMost(acceptableError) + } + } + + /** + * Load an image from the assets folder using its name. + */ + fun loadBitmapAssetByName(name: String): Bitmap { + val testContext = InstrumentationRegistry.getInstrumentation().context + val testInput = testContext.assets.open(name) + return BitmapFactory.decodeStream(testInput) + } + + /** + * Load a CSV file from the assets folder. + */ + fun loadCSVAsset(name: String): List<Map<BodyPart, PointF>> { + val data = mutableListOf<Map<BodyPart, PointF>>() + val testContext = InstrumentationRegistry.getInstrumentation().context + val testInput = testContext.assets.open(name) + val inputStreamReader = InputStreamReader(testInput) + val reader = BufferedReader(inputStreamReader) + // Skip header line + reader.readLine() + + // Read expected coordinates from each following line + reader.forEachLine { + val listPoint = it.split(",") + val map = mutableMapOf<BodyPart, PointF>() + for (i in listPoint.indices step 2) { + map[BodyPart.fromInt(i / 2)] = + PointF(listPoint[i].toFloat(), listPoint[i + 1].toFloat()) + } + data.add(map) + } + return data + } + + /** + * Calculate the distance between two points. + */ + private fun distance(point1: PointF, point2: PointF): Float { + return ((point1.x - point2.x).pow(2) + (point1.y - point2.y).pow(2)).pow(0.5f) + } + + /** + * Concatenate two images of the same height horizontally. + */ + fun hConcat(image1: Bitmap, image2: Bitmap): Bitmap { + if (image1.height != image2.height) { + throw Exception("Input images are not the same height.") + } + val finalBitmap = + Bitmap.createBitmap(image1.width + image2.width, image1.height, Bitmap.Config.ARGB_8888) + val canvas = Canvas(finalBitmap) + canvas.drawBitmap(image1, 0f, 0f, null) + canvas.drawBitmap(image2, image1.width.toFloat(), 0f, null) + return finalBitmap + } +} \ No newline at end of file diff --git a/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/MovenetLightningTest.kt b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/MovenetLightningTest.kt new file
mode 100644 index 0000000..7941b5e --- /dev/null +++ b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/MovenetLightningTest.kt @@ -0,0 +1,83 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import android.graphics.PointF +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.tensorflow.lite.examples.poseestimation.data.BodyPart +import org.tensorflow.lite.examples.poseestimation.data.Device + +@RunWith(AndroidJUnit4::class) +class MovenetLightningTest { + + companion object { + private const val TEST_INPUT_IMAGE1 = "image1.png" + private const val TEST_INPUT_IMAGE2 = "image2.jpg" + private const val ACCEPTABLE_ERROR = 21f + } + + private lateinit var poseDetector: PoseDetector + private lateinit var appContext: Context + private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>> + + @Before + fun setup() { + appContext = InstrumentationRegistry.getInstrumentation().targetContext + poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning) + expectedDetectionResult = + EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv") + } + + @Test + fun testPoseEstimationResultWithImage1() { + val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1) + + // As MoveNet uses the previous frame to optimize detection results, we run it multiple + // times on the same image to improve the result. + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + val person = poseDetector.estimatePoses(input)[0] + EvaluationUtils.assertPoseDetectionResult( + person, + expectedDetectionResult[0], + ACCEPTABLE_ERROR + ) + } + + @Test + fun testPoseEstimationResultWithImage2() { + val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2) + + // As MoveNet uses the previous frame to optimize detection results, we run it multiple + // times on the same image to improve the result. + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + val person = poseDetector.estimatePoses(input)[0] + EvaluationUtils.assertPoseDetectionResult( + person, + expectedDetectionResult[1], + ACCEPTABLE_ERROR + ) + } +} \ No newline at end of file diff --git a/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/MovenetMultiPoseTest.kt b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/MovenetMultiPoseTest.kt new file mode 100644 index 0000000..af30acd --- /dev/null +++ b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/MovenetMultiPoseTest.kt @@ -0,0 +1,81 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
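The tests above call estimatePoses three times before asserting on a fourth call because MoveNet derives its crop region from the previous frame, so repeated runs on the same image let that crop converge. A hedged helper capturing the same warm-up pattern (the helper name is illustrative; PoseDetector and Person are the project's own types):

    // Sketch: warm up MoveNet's frame-to-frame cropping on a static image,
    // then return the final, stabilized detection.
    fun warmedUpEstimate(detector: PoseDetector, bitmap: Bitmap, warmUpRuns: Int = 3): Person {
        repeat(warmUpRuns) { detector.estimatePoses(bitmap) }
        return detector.estimatePoses(bitmap)[0]
    }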
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.PointF +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.tensorflow.lite.examples.poseestimation.data.BodyPart +import org.tensorflow.lite.examples.poseestimation.data.Device +import org.tensorflow.lite.examples.poseestimation.ml.MoveNetMultiPose +import org.tensorflow.lite.examples.poseestimation.ml.Type + +@RunWith(AndroidJUnit4::class) +class MovenetMultiPoseTest { + companion object { + private const val TEST_INPUT_IMAGE1 = "image1.png" + private const val TEST_INPUT_IMAGE2 = "image2.jpg" + private const val ACCEPTABLE_ERROR = 17f + } + + private lateinit var poseDetector: MoveNetMultiPose + private lateinit var appContext: Context + private lateinit var inputFinal: Bitmap + private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>> + + @Before + fun setup() { + appContext = InstrumentationRegistry.getInstrumentation().targetContext + poseDetector = MoveNetMultiPose.create(appContext, Device.CPU, Type.Dynamic) + val input1 = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1) + val input2 = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2) + inputFinal = EvaluationUtils.hConcat(input1, input2) + expectedDetectionResult = + EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv") + + // Update the coordinates from pose_landmark_truth.csv to match the new, concatenated input image + for ((_, value) in expectedDetectionResult[1]) { + value.x = value.x + input1.width + } + } + + @Test + fun testPoseEstimateResult() { + val persons = poseDetector.estimatePoses(inputFinal) + assert(persons.size == 2) + + // Sort the results so that the person on the left side comes first. + val sortedPersons = persons.sortedBy { it.boundingBox?.left } + + EvaluationUtils.assertPoseDetectionResult( + sortedPersons[0], + expectedDetectionResult[0], + ACCEPTABLE_ERROR + ) + + EvaluationUtils.assertPoseDetectionResult( + sortedPersons[1], + expectedDetectionResult[1], + ACCEPTABLE_ERROR + ) + } +} diff --git a/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/MovenetThunderTest.kt b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/MovenetThunderTest.kt new file mode 100644 index 0000000..758e7e3 --- /dev/null +++ b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/MovenetThunderTest.kt @@ -0,0 +1,83 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
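In the MovenetMultiPoseTest setup above, the two test images are concatenated side by side, so the ground truth for the second image must be translated right by the first image's width. A hedged, non-mutating sketch of that coordinate shift (the test instead mutates the PointF values in place):

    import android.graphics.PointF

    // Sketch: translate expected keypoints by `dx` pixels to account for an
    // image pasted at that horizontal offset, as EvaluationUtils.hConcat does.
    fun shiftRight(expected: Map<BodyPart, PointF>, dx: Float): Map<BodyPart, PointF> =
        expected.mapValues { (_, p) -> PointF(p.x + dx, p.y) }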
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import android.graphics.PointF +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.tensorflow.lite.examples.poseestimation.data.BodyPart +import org.tensorflow.lite.examples.poseestimation.data.Device + +@RunWith(AndroidJUnit4::class) +class MovenetThunderTest { + + companion object { + private const val TEST_INPUT_IMAGE1 = "image1.png" + private const val TEST_INPUT_IMAGE2 = "image2.jpg" + private const val ACCEPTABLE_ERROR = 15f + } + + private lateinit var poseDetector: PoseDetector + private lateinit var appContext: Context + private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>> + + @Before + fun setup() { + appContext = InstrumentationRegistry.getInstrumentation().targetContext + poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Thunder) + expectedDetectionResult = + EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv") + } + + @Test + fun testPoseEstimationResultWithImage1() { + val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1) + + // As MoveNet uses the previous frame to optimize detection results, we run it multiple + // times on the same image to improve the result. + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + val person = poseDetector.estimatePoses(input)[0] + EvaluationUtils.assertPoseDetectionResult( + person, + expectedDetectionResult[0], + ACCEPTABLE_ERROR + ) + } + + @Test + fun testPoseEstimationResultWithImage2() { + val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2) + + // As MoveNet uses the previous frame to optimize detection results, we run it multiple + // times on the same image to improve the result. + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + val person = poseDetector.estimatePoses(input)[0] + EvaluationUtils.assertPoseDetectionResult( + person, + expectedDetectionResult[1], + ACCEPTABLE_ERROR + ) + } +} \ No newline at end of file diff --git a/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/PoseClassifierTest.kt b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/PoseClassifierTest.kt new file mode 100644 index 0000000..3b8cea3 --- /dev/null +++ b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/PoseClassifierTest.kt @@ -0,0 +1,64 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import junit.framework.TestCase +import junit.framework.TestCase.assertEquals +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.tensorflow.lite.examples.poseestimation.data.Device + +@RunWith(AndroidJUnit4::class) +class PoseClassifierTest { + + companion object { + private const val TEST_INPUT_IMAGE = "image3.jpeg" + } + + private lateinit var appContext: Context + private lateinit var poseDetector: PoseDetector + private lateinit var poseClassifier: PoseClassifier + + @Before + fun setup() { + appContext = InstrumentationRegistry.getInstrumentation().targetContext + poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning) + poseClassifier = PoseClassifier.create(appContext) + } + + @Test + fun testPoseClassifier() { + val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE) + // As MoveNet uses the previous frame to optimize detection results, we run it multiple + // times on the same image to improve the result. + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + poseDetector.estimatePoses(input) + val person = poseDetector.estimatePoses(input)[0] + val classificationResult = poseClassifier.classify(person) + val predictedPose = classificationResult.maxByOrNull { it.second }?.first ?: "n/a" + assertEquals( + "Predicted pose is different from ground truth.", + "tree", + predictedPose + ) + } +} \ No newline at end of file diff --git a/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/PosenetTest.kt b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/PosenetTest.kt new file mode 100644 index 0000000..a8c3fb3 --- /dev/null +++ b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/PosenetTest.kt @@ -0,0 +1,71 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
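PoseClassifier.classify in the test above returns one (label, score) pair per class, and maxByOrNull keeps only the best one. A hedged sketch of the more general top-k selection (MainActivity later displays the top three labels this way):

    // Sketch: pick the k highest-scoring (label, score) pairs from the
    // classifier's per-class output.
    fun topK(results: List<Pair<String, Float>>, k: Int = 3): List<Pair<String, Float>> =
        results.sortedByDescending { it.second }.take(k)

    // topK(listOf("tree" to 0.90f, "chair" to 0.05f, "dog" to 0.03f), 2)
    // == listOf("tree" to 0.90f, "chair" to 0.05f)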
+============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import android.graphics.PointF +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.tensorflow.lite.examples.poseestimation.data.BodyPart +import org.tensorflow.lite.examples.poseestimation.data.Device + +@RunWith(AndroidJUnit4::class) +class PosenetTest { + + companion object { + private const val TEST_INPUT_IMAGE1 = "image1.png" + private const val TEST_INPUT_IMAGE2 = "image2.jpg" + private const val ACCEPTABLE_ERROR = 37f + } + + private lateinit var poseDetector: PoseDetector + private lateinit var appContext: Context + private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>> + + @Before + fun setup() { + appContext = InstrumentationRegistry.getInstrumentation().targetContext + poseDetector = PoseNet.create(appContext, Device.CPU) + expectedDetectionResult = + EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv") + } + + @Test + fun testPoseEstimationResultWithImage1() { + val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1) + val person = poseDetector.estimatePoses(input)[0] + EvaluationUtils.assertPoseDetectionResult( + person, + expectedDetectionResult[0], + ACCEPTABLE_ERROR + ) + } + + @Test + fun testPoseEstimationResultWithImage2() { + val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2) + val person = poseDetector.estimatePoses(input)[0] + EvaluationUtils.assertPoseDetectionResult( + person, + expectedDetectionResult[1], + ACCEPTABLE_ERROR + ) + } +} \ No newline at end of file diff --git a/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/VisualizationTest.kt b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/VisualizationTest.kt new file mode 100644 index 0000000..42511d8 --- /dev/null +++ b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/ml/VisualizationTest.kt @@ -0,0 +1,82 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import android.graphics.Bitmap +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import com.google.common.truth.Truth.assertThat +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.tensorflow.lite.examples.poseestimation.VisualizationUtils +import org.tensorflow.lite.examples.poseestimation.data.Device + +/** + * This test is used to visually verify detection results by the models.
+ * You can put a breakpoint at the end of the method, debug this method, then use the + * "View Bitmap" feature of the debugger to check the visualized detection result. + */ +@RunWith(AndroidJUnit4::class) +class VisualizationTest { + + companion object { + private const val TEST_INPUT_IMAGE = "image2.jpg" + } + + private lateinit var appContext: Context + private lateinit var inputBitmap: Bitmap + + @Before + fun setup() { + appContext = InstrumentationRegistry.getInstrumentation().targetContext + inputBitmap = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE) + } + + @Test + fun testPosenet() { + val poseDetector = PoseNet.create(appContext, Device.CPU) + val person = poseDetector.estimatePoses(inputBitmap)[0] + val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person)) + assertThat(outputBitmap).isNotNull() + } + + @Test + fun testMovenetLightning() { + // Due to MoveNet's cropping logic, we run inference several times with the same input + // image to improve accuracy + val poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning) + poseDetector.estimatePoses(inputBitmap) + poseDetector.estimatePoses(inputBitmap) + val person2 = poseDetector.estimatePoses(inputBitmap)[0] + val outputBitmap2 = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person2)) + assertThat(outputBitmap2).isNotNull() + } + + @Test + fun testMovenetThunder() { + // Due to MoveNet's cropping logic, we run inference several times with the same input + // image to improve accuracy + val poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Thunder) + poseDetector.estimatePoses(inputBitmap) + poseDetector.estimatePoses(inputBitmap) + val person = poseDetector.estimatePoses(inputBitmap)[0] + val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person)) + assertThat(outputBitmap).isNotNull() + } +} \ No newline at end of file diff --git a/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/tracker/BoundingBoxTrackerTest.kt b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/tracker/BoundingBoxTrackerTest.kt new file mode 100644 index 0000000..7372dc4 --- /dev/null +++ b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/tracker/BoundingBoxTrackerTest.kt @@ -0,0 +1,228 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
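Besides the debugger's "View Bitmap" feature mentioned in the VisualizationTest comment above, a hedged alternative is to write the drawn bitmap into the app's cache directory and pull it off the device with adb (the file name below is illustrative):

    import android.content.Context
    import android.graphics.Bitmap
    import java.io.File

    // Sketch: persist a visualization result for inspection via `adb pull`
    // instead of the debugger's bitmap viewer.
    fun saveForInspection(context: Context, bitmap: Bitmap, name: String = "pose_debug.png") {
        File(context.cacheDir, name).outputStream().use { out ->
            bitmap.compress(Bitmap.CompressFormat.PNG, 100, out)
        }
    }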
+============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.tracker + +import android.graphics.RectF +import androidx.test.ext.junit.runners.AndroidJUnit4 +import junit.framework.TestCase.assertEquals +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.tensorflow.lite.examples.poseestimation.data.Person + +@RunWith(AndroidJUnit4::class) +class BoundingBoxTrackerTest { + companion object { + private const val MAX_TRACKS = 4 + private const val MAX_AGE = 1000 // Unit: milliseconds. + private const val MIN_SIMILARITY = 0.5f + } + + private lateinit var boundingBoxTracker: BoundingBoxTracker + + @Before + fun setup() { + val trackerConfig = TrackerConfig(MAX_TRACKS, MAX_AGE, MIN_SIMILARITY) + boundingBoxTracker = BoundingBoxTracker(trackerConfig) + } + + @Test + fun testIoU() { + val persons = Person( + -1, listOf(), RectF( + 0f, + 0f, + 2f / 3, + 1f + ), 1f + ) + + val track = + Track( + Person( + -1, + listOf(), + RectF( + 1 / 3f, + 0.0f, + 1f, + 1f, + ), 1f + ), 1000000 + ) + val computedIoU = boundingBoxTracker.iou(persons, track.person) + assertEquals("Wrong IoU value.", 1f / 3, computedIoU, 0.000001f) + } + + @Test + fun testIoUFullOverlap() { + val persons = Person( + -1, listOf(), + RectF( + 0f, + 0f, + 1f, + 1f + ), 1f + ) + + val track = + Track( + Person( + -1, + listOf(), + RectF( + 0f, + 0f, + 1f, + 1f, + ), 1f + ), 1000000 + ) + val computedIoU = boundingBoxTracker.iou(persons, track.person) + assertEquals("Wrong IoU value.", 1f, computedIoU, 0.000001f) + } + + @Test + fun testIoUNoIntersection() { + val persons = Person( + -1, listOf(), + RectF( + 0f, + 0f, + 0.5f, + 0.5f + ), 1f + ) + + val track = + Track( + Person( + -1, + listOf(), + RectF( + 0.5f, + 0.5f, + 1f, + 1f, + ), 1f + ), 1000000 + ) + val computedIoU = boundingBoxTracker.iou(persons, track.person) + assertEquals("Wrong IoU value.", 0f, computedIoU, 0.000001f) + } + + @Test + fun testBoundingBoxTracking() { + // Timestamp: 0. Poses becomes the first two tracks. + var persons = listOf( + Person( // Becomes track 1. + -1, listOf(), RectF( + 0f, + 0f, + 0.5f, + 0.5f, + ), 1f + ), + Person( // Becomes track 2. + -1, listOf(), RectF( + 0f, + 0f, + 1f, + 1f + ), 1f + ) + ) + persons = boundingBoxTracker.apply(persons, 0) + var track = boundingBoxTracker.tracks + assertEquals(2, persons.size) + assertEquals(1, persons[0].id) + assertEquals(2, persons[1].id) + assertEquals(2, track.size) + assertEquals(1, track[0].person.id) + assertEquals(0, track[0].lastTimestamp) + assertEquals(2, track[1].person.id) + assertEquals(0, track[1].lastTimestamp) + + // Timestamp: 100000. First pose is linked with track 1. Second pose spawns + // a new track (id = 2). + persons = listOf( + Person( // Linked with track 1. + -1, listOf(), RectF( + 0.1f, + 0.1f, + 0.5f, + 0.5f + ), 1f + ), + Person( // Becomes track 3. + -1, listOf(), RectF( + 0.2f, + 0.3f, + 0.9f, + 0.9f + ), 1f + ) + ) + persons = boundingBoxTracker.apply(persons, 100000) + track = boundingBoxTracker.tracks + assertEquals(2, persons.size) + assertEquals(1, persons[0].id) + assertEquals(3, persons[1].id) + assertEquals(3, track.size) + assertEquals(1, track[0].person.id) + assertEquals(100000, track[0].lastTimestamp) + assertEquals(3, track[1].person.id) + assertEquals(100000, track[1].lastTimestamp) + assertEquals(2, track[2].person.id) + assertEquals(0, track[2].lastTimestamp) + + // Timestamp: 1050000. First pose is linked with track 1. 
Second pose is + // identical to track 2, but is not linked because track 2 is deleted due to + // age. Instead it spawns track 4. + persons = listOf( + Person( // Linked with track 1. + -1, listOf(), RectF( + 0.1f, + 0.1f, + 0.55f, + 0.5f + ), 1f + ), + Person( // Becomes track 4. + -1, listOf(), RectF( + 0f, + 0f, + 1f, + 1f + ), 1f + ) + ) + persons = boundingBoxTracker.apply(persons, 1050000) + track = boundingBoxTracker.tracks + assertEquals(2, persons.size) + assertEquals(1, persons[0].id) + assertEquals(4, persons[1].id) + assertEquals(3, track.size) + assertEquals(1, track[0].person.id) + assertEquals(1050000, track[0].lastTimestamp) + assertEquals(4, track[1].person.id) + assertEquals(1050000, track[1].lastTimestamp) + assertEquals(3, track[2].person.id) + assertEquals(100000, track[2].lastTimestamp) + } +} \ No newline at end of file diff --git a/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/tracker/KeyPointsTrackerTest.kt b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/tracker/KeyPointsTrackerTest.kt new file mode 100644 index 0000000..0e9bf2c --- /dev/null +++ b/android1/app/src/androidTest/java/org/tensorflow/lite/examples/poseestimation/tracker/KeyPointsTrackerTest.kt @@ -0,0 +1,308 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
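The two tracker test classes here exercise the two similarity metrics the multi-pose trackers can use: plain intersection-over-union for BoundingBoxTracker, and an OKS-style keypoint similarity for KeyPointsTracker. A hedged, self-contained sketch of both, with formulas matching the expected values the tests compute (the project's real implementations live in the tracker classes themselves):

    import android.graphics.RectF
    import kotlin.math.exp
    import kotlin.math.max
    import kotlin.math.min
    import kotlin.math.pow

    // Sketch: IoU of two normalized boxes. For testIoU's boxes,
    // (0, 0, 2/3, 1) vs (1/3, 0, 1, 1): intersection = 1/3, union = 1,
    // so IoU = 1/3, matching the asserted value.
    fun iou(a: RectF, b: RectF): Float {
        val w = max(0f, min(a.right, b.right) - max(a.left, b.left))
        val h = max(0f, min(a.bottom, b.bottom) - max(a.top, b.top))
        val intersection = w * h
        val union = a.width() * a.height() + b.width() * b.height() - intersection
        return if (union > 0f) intersection / union else 0f
    }

    // Sketch: one per-keypoint OKS term, as in KeyPointsTrackerTest.testOks:
    // exp(-d^2 / (2 * area * (2 * falloff)^2)), averaged over the keypoints
    // that pass the confidence threshold.
    fun oksTerm(d: Float, area: Float, falloff: Float): Float =
        exp(-d.pow(2) / (2f * area * (2f * falloff).pow(2)))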
+============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.tracker + +import android.graphics.PointF +import androidx.test.ext.junit.runners.AndroidJUnit4 +import junit.framework.Assert +import junit.framework.TestCase.assertEquals +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.tensorflow.lite.examples.poseestimation.data.BodyPart +import org.tensorflow.lite.examples.poseestimation.data.KeyPoint +import org.tensorflow.lite.examples.poseestimation.data.Person +import kotlin.math.exp +import kotlin.math.pow + +@RunWith(AndroidJUnit4::class) +class KeyPointsTrackerTest { + companion object { + private const val MAX_TRACKS = 4 + private const val MAX_AGE = 1000 + private const val MIN_SIMILARITY = 0.5f + private const val KEYPOINT_THRESHOLD = 0.2f + private const val MIN_NUM_KEYPOINT = 2 + private val KEYPOINT_FALLOFF = listOf(0.1f, 0.1f, 0.1f, 0.1f) + } + + private lateinit var keyPointsTracker: KeyPointsTracker + + @Before + fun setup() { + val trackerConfig = TrackerConfig( + MAX_TRACKS, MAX_AGE, MIN_SIMILARITY, + KeyPointsTrackerParams(KEYPOINT_THRESHOLD, KEYPOINT_FALLOFF, MIN_NUM_KEYPOINT) + ) + keyPointsTracker = KeyPointsTracker(trackerConfig) + } + + @Test + fun testOks() { + val persons = + Person( + -1, listOf( + KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.1f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.7f), 0.8f), + ), score = 1f + ) + val tracks = + Track( + Person( + 0, listOf( + KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f), + ), score = 1f + ), + 1000000, + ) + + val oks = keyPointsTracker.oks(persons, tracks.person) + val boxArea = (0.8f - 0.2f) * (0.8f - 0.2f) + val x = 2f * KEYPOINT_FALLOFF[3] + val d = 0.1f + val expectedOks: Float = + (1f + 1f + exp(-1f * d.pow(2) / (2f * boxArea * x.pow(2)))) / 3f + assertEquals(expectedOks, oks, 0.000001f) + } + + @Test + fun testOksReturnZero() { + // Compute OKS returns 0.0 with less than 2 valid keypoints + val persons = + Person( + -1, listOf( + KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.1f), // Low confidence. + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f), + ), score = 1f + ) + val tracks = + Track( + Person( + 0, listOf( + KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.1f),// Low confidence. + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.0f),// Low confidence. 
+ ), score = 1f + ), 1000000 + ) + + val oks = keyPointsTracker.oks(persons, tracks.person) + assertEquals(0f, oks, 0.000001f) + } + + @Test + fun testArea() { + val keyPoints = listOf( + KeyPoint(BodyPart.NOSE, PointF(0.1f, 0.2f), 1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.3f, 0.4f), 0.9f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.4f, 0.6f), 0.9f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.7f, 0.8f), 0.1f), + ) + val area = keyPointsTracker.area(keyPoints) + val expectedArea = (0.4f - 0.1f) * (0.6f - 0.2f) + assertEquals(expectedArea, area) + } + + @Test + fun testKeyPointsTracker() { + // Timestamp: 0. Person becomes the only track. + var persons = listOf( + Person( + -1, listOf( + KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.0f), + ), score = 0.9f + ) + ) + persons = keyPointsTracker.apply(persons, 0) + var track = keyPointsTracker.tracks + assertEquals(1, persons.size) + assertEquals(1, persons[0].id) + assertEquals(1, track.size) + assertEquals(1, track[0].person.id) + assertEquals(0, track[0].lastTimestamp) + + // Timestamp: 100000. First person is linked with track 1. Second person spawns + // a new track (id = 2). + persons = listOf( + Person( + -1, + listOf( + // Links with id = 1. + KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f), + ), + score = 1f + ), + Person( + -1, + listOf( + // Becomes id = 2. + KeyPoint(BodyPart.NOSE, PointF(0.8f, 0.8f), 0.8f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.6f, 0.6f), 0.3f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.4f, 0.4f), 0.1f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.2f, 0.2f), 0.8f), + ), + score = 1f + ) + ) + persons = keyPointsTracker.apply(persons, 100000) + track = keyPointsTracker.tracks + assertEquals(2, persons.size) + assertEquals(1, persons[0].id) + assertEquals(2, persons[1].id) + assertEquals(2, track.size) + assertEquals(1, track[0].person.id) + assertEquals(100000, track[0].lastTimestamp) + assertEquals(2, track[1].person.id) + assertEquals(100000, track[1].lastTimestamp) + + // Timestamp: 900000. First person is linked with track 2. Second person spawns + // a new track (id = 3). + persons = listOf( + Person( + -1, + listOf( + // Links with id = 2. + KeyPoint(BodyPart.NOSE, PointF(0.6f, 0.7f), 0.7f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.5f, 0.6f), 0.7f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.0f, 0.0f), 0.1f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.2f, 0.1f), 1f), + ), + score = 1f + ), + Person( + -1, + listOf( + // Becomes id = 3. 
+ KeyPoint(BodyPart.NOSE, PointF(0.5f, 0.1f), 0.6f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.9f, 0.3f), 0.6f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.1f, 1f), 0.9f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.4f, 0.4f), 0.1f), + ), + score = 1f + ) + ) + persons = keyPointsTracker.apply(persons, 900000) + track = keyPointsTracker.tracks + assertEquals(2, persons.size) + assertEquals(2, persons[0].id) + assertEquals(3, persons[1].id) + assertEquals(3, track.size) + assertEquals(2, track[0].person.id) + assertEquals(900000, track[0].lastTimestamp) + assertEquals(3, track[1].person.id) + assertEquals(900000, track[1].lastTimestamp) + assertEquals(1, track[2].person.id) + assertEquals(100000, track[2].lastTimestamp) + + // Timestamp: 1200000. First person spawns a new track (id = 4), even though + // it has the same keypoints as track 1. This is because the age exceeds + // 1000 msec. The second person links with id 2. The third person spawns a new + // track (id = 5). + persons = listOf( + Person( + -1, + listOf( + // Becomes id = 4. + KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f), + ), + score = 1f + ), + Person( + -1, + listOf( + // Links with id = 2. + KeyPoint(BodyPart.NOSE, PointF(0.55f, 0.7f), 0.7f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.5f, 0.6f), 0.9f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(1f, 1f), 0.1f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.1f), 0f), + ), + score = 1f + ), + Person( + -1, + listOf( + // Becomes id = 5. + KeyPoint(BodyPart.NOSE, PointF(0.1f, 0.1f), 0.1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.2f, 0.2f), 0.9f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.3f, 0.3f), 0.7f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.4f, 0.4f), 0.8f), + ), + score = 1f + ) + ) + persons = keyPointsTracker.apply(persons, 1200000) + track = keyPointsTracker.tracks + assertEquals(3, persons.size) + assertEquals(4, persons[0].id) + assertEquals(2, persons[1].id) + assertEquals(4, track.size) + assertEquals(2, track[0].person.id) + assertEquals(1200000, track[0].lastTimestamp) + assertEquals(4, track[1].person.id) + assertEquals(1200000, track[1].lastTimestamp) + assertEquals(5, track[2].person.id) + assertEquals(1200000, track[2].lastTimestamp) + assertEquals(3, track[3].person.id) + assertEquals(900000, track[3].lastTimestamp) + + // Timestamp: 1300000. First person spawns a new track (id = 6). Since + // maxTracks is 4, the oldest track (id = 3) is removed. + persons = listOf( + Person( + -1, + listOf( + // Becomes id = 6. 
+ KeyPoint(BodyPart.NOSE, PointF(0.1f, 0.8f), 1f), + KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.2f, 0.9f), 0.6f), + KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.2f, 0.9f), 0.5f), + KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.2f), 0.4f), + ), + score = 1f + ) + ) + persons = keyPointsTracker.apply(persons, 1300000) + track = keyPointsTracker.tracks + assertEquals(1, persons.size) + assertEquals(6, persons[0].id) + assertEquals(4, track.size) + assertEquals(6, track[0].person.id) + assertEquals(1300000, track[0].lastTimestamp) + assertEquals(2, track[1].person.id) + assertEquals(1200000, track[1].lastTimestamp) + assertEquals(4, track[2].person.id) + assertEquals(1200000, track[2].lastTimestamp) + assertEquals(5, track[3].person.id) + assertEquals(1200000, track[3].lastTimestamp) + } +} \ No newline at end of file diff --git a/android1/app/src/main/AndroidManifest.xml b/android1/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000..910942c --- /dev/null +++ b/android1/app/src/main/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android1/app/src/main/assets/classifier.tflite b/android1/app/src/main/assets/classifier.tflite new file mode 100644 index 0000000..8c0f598 Binary files /dev/null and b/android1/app/src/main/assets/classifier.tflite differ diff --git a/android1/app/src/main/assets/labels.txt b/android1/app/src/main/assets/labels.txt new file mode 100644 index 0000000..983e118 --- /dev/null +++ b/android1/app/src/main/assets/labels.txt @@ -0,0 +1,5 @@ +chair +cobra +dog +tree +warrior \ No newline at end of file diff --git a/android1/app/src/main/assets/movenet_lightning.tflite b/android1/app/src/main/assets/movenet_lightning.tflite new file mode 100644 index 0000000..7e90817 Binary files /dev/null and b/android1/app/src/main/assets/movenet_lightning.tflite differ diff --git a/android1/app/src/main/assets/movenet_multipose_fp16.tflite b/android1/app/src/main/assets/movenet_multipose_fp16.tflite new file mode 100644 index 0000000..13f58ef Binary files /dev/null and b/android1/app/src/main/assets/movenet_multipose_fp16.tflite differ diff --git a/android1/app/src/main/assets/movenet_thunder.tflite b/android1/app/src/main/assets/movenet_thunder.tflite new file mode 100644 index 0000000..1582dc7 Binary files /dev/null and b/android1/app/src/main/assets/movenet_thunder.tflite differ diff --git a/android1/app/src/main/assets/posenet.tflite b/android1/app/src/main/assets/posenet.tflite new file mode 100644 index 0000000..d8b8b32 Binary files /dev/null and b/android1/app/src/main/assets/posenet.tflite differ diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/MainActivity.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/MainActivity.kt new file mode 100644 index 0000000..e3088fe --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/MainActivity.kt @@ -0,0 +1,430 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation + +import android.Manifest +import android.app.AlertDialog +import android.app.Dialog +import android.content.pm.PackageManager +import android.os.Bundle +import android.os.Process +import android.view.SurfaceView +import android.view.View +import android.view.WindowManager +import android.widget.* +import androidx.activity.result.contract.ActivityResultContracts +import androidx.appcompat.app.AppCompatActivity +import androidx.appcompat.widget.SwitchCompat +import androidx.core.content.ContextCompat +import androidx.fragment.app.DialogFragment +import androidx.lifecycle.lifecycleScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import org.tensorflow.lite.examples.poseestimation.camera.CameraSource +import org.tensorflow.lite.examples.poseestimation.data.Device +import org.tensorflow.lite.examples.poseestimation.ml.* + +class MainActivity : AppCompatActivity() { + companion object { + private const val FRAGMENT_DIALOG = "dialog" + } + + /** A [SurfaceView] for camera preview. */ + private lateinit var surfaceView: SurfaceView + + /** Default pose estimation model is 1 (MoveNet Thunder) + * 0 == MoveNet Lightning model + * 1 == MoveNet Thunder model + * 2 == MoveNet MultiPose model + * 3 == PoseNet model + **/ + private var modelPos = 1 + + /** Default device is CPU */ + private var device = Device.CPU + + private lateinit var tvScore: TextView + private lateinit var tvFPS: TextView + private lateinit var spnDevice: Spinner + private lateinit var spnModel: Spinner + private lateinit var spnTracker: Spinner + private lateinit var vTrackerOption: View + private lateinit var tvClassificationValue1: TextView + private lateinit var tvClassificationValue2: TextView + private lateinit var tvClassificationValue3: TextView + private lateinit var swClassification: SwitchCompat + private lateinit var vClassificationOption: View + private var cameraSource: CameraSource? = null + private var isClassifyPose = false + private val requestPermissionLauncher = + registerForActivityResult( + ActivityResultContracts.RequestPermission() + ) { isGranted: Boolean -> + if (isGranted) { + // Permission is granted. Continue the action or workflow in your + // app. + openCamera() + } else { + // Explain to the user that the feature is unavailable because the + // features requires a permission that the user has denied. At the + // same time, respect the user's decision. Don't link to system + // settings in an effort to convince the user to change their + // decision. + ErrorDialog.newInstance(getString(R.string.tfe_pe_request_permission)) + .show(supportFragmentManager, FRAGMENT_DIALOG) + } + } + private var changeModelListener = object : AdapterView.OnItemSelectedListener { + override fun onNothingSelected(parent: AdapterView<*>?) { + // do nothing + } + + override fun onItemSelected( + parent: AdapterView<*>?, + view: View?, + position: Int, + id: Long + ) { + changeModel(position) + } + } + + private var changeDeviceListener = object : AdapterView.OnItemSelectedListener { + override fun onItemSelected(parent: AdapterView<*>?, view: View?, position: Int, id: Long) { + changeDevice(position) + } + + override fun onNothingSelected(parent: AdapterView<*>?) 
{ + // do nothing + } + } + + private var changeTrackerListener = object : AdapterView.OnItemSelectedListener { + override fun onItemSelected(parent: AdapterView<*>?, view: View?, position: Int, id: Long) { + changeTracker(position) + } + + override fun onNothingSelected(parent: AdapterView<*>?) { + // do nothing + } + } + + private var setClassificationListener = + CompoundButton.OnCheckedChangeListener { _, isChecked -> + showClassificationResult(isChecked) + isClassifyPose = isChecked + isPoseClassifier() + } + + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + setContentView(R.layout.activity_main) + // keep screen on while app is running + window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON) + tvScore = findViewById(R.id.tvScore) + tvFPS = findViewById(R.id.tvFps) + spnModel = findViewById(R.id.spnModel) + spnDevice = findViewById(R.id.spnDevice) + spnTracker = findViewById(R.id.spnTracker) + vTrackerOption = findViewById(R.id.vTrackerOption) + surfaceView = findViewById(R.id.surfaceView) + tvClassificationValue1 = findViewById(R.id.tvClassificationValue1) + tvClassificationValue2 = findViewById(R.id.tvClassificationValue2) + tvClassificationValue3 = findViewById(R.id.tvClassificationValue3) + swClassification = findViewById(R.id.swPoseClassification) + vClassificationOption = findViewById(R.id.vClassificationOption) + initSpinner() + spnModel.setSelection(modelPos) + swClassification.setOnCheckedChangeListener(setClassificationListener) + if (!isCameraPermissionGranted()) { + requestPermission() + } + } + + override fun onStart() { + super.onStart() + openCamera() + } + + override fun onResume() { + cameraSource?.resume() + super.onResume() + } + + override fun onPause() { + cameraSource?.close() + cameraSource = null + super.onPause() + } + + // check if permission is granted or not. + private fun isCameraPermissionGranted(): Boolean { + return checkPermission( + Manifest.permission.CAMERA, + Process.myPid(), + Process.myUid() + ) == PackageManager.PERMISSION_GRANTED + } + + // open camera + private fun openCamera() { + if (isCameraPermissionGranted()) { + if (cameraSource == null) { + cameraSource = + CameraSource(surfaceView, object : CameraSource.CameraSourceListener { + override fun onFPSListener(fps: Int) { + tvFPS.text = getString(R.string.tfe_pe_tv_fps, fps) + } + + override fun onDetectedInfo( + personScore: Float?, + poseLabels: List<Pair<String, Float>>?
+ ) { + tvScore.text = getString(R.string.tfe_pe_tv_score, personScore ?: 0f) + poseLabels?.sortedByDescending { it.second }?.let { + tvClassificationValue1.text = getString( + R.string.tfe_pe_tv_classification_value, + convertPoseLabels(if (it.isNotEmpty()) it[0] else null) + ) + tvClassificationValue2.text = getString( + R.string.tfe_pe_tv_classification_value, + convertPoseLabels(if (it.size >= 2) it[1] else null) + ) + tvClassificationValue3.text = getString( + R.string.tfe_pe_tv_classification_value, + convertPoseLabels(if (it.size >= 3) it[2] else null) + ) + } + } + + }).apply { + prepareCamera() + } + isPoseClassifier() + lifecycleScope.launch(Dispatchers.Main) { + cameraSource?.initCamera() + } + } + createPoseEstimator() + } + } + + private fun convertPoseLabels(pair: Pair<String, Float>?): String { + if (pair == null) return "empty" + return "${pair.first} (${String.format("%.2f", pair.second)})" + } + + private fun isPoseClassifier() { + cameraSource?.setClassifier(if (isClassifyPose) PoseClassifier.create(this) else null) + } + + // Initialize spinners to let user select model/accelerator/tracker. + private fun initSpinner() { + ArrayAdapter.createFromResource( + this, + R.array.tfe_pe_models_array, + android.R.layout.simple_spinner_item + ).also { adapter -> + // Specify the layout to use when the list of choices appears + adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item) + // Apply the adapter to the spinner + spnModel.adapter = adapter + spnModel.onItemSelectedListener = changeModelListener + } + + ArrayAdapter.createFromResource( + this, + R.array.tfe_pe_device_name, android.R.layout.simple_spinner_item + ).also { adapter -> + adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item) + + spnDevice.adapter = adapter + spnDevice.onItemSelectedListener = changeDeviceListener + } + + ArrayAdapter.createFromResource( + this, + R.array.tfe_pe_tracker_array, android.R.layout.simple_spinner_item + ).also { adapter -> + adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item) + + spnTracker.adapter = adapter + spnTracker.onItemSelectedListener = changeTrackerListener + } + } + + // Change model when app is running + private fun changeModel(position: Int) { + if (modelPos == position) return + modelPos = position + createPoseEstimator() + } + + // Change device (accelerator) type when app is running + private fun changeDevice(position: Int) { + val targetDevice = when (position) { + 0 -> Device.CPU + 1 -> Device.GPU + else -> Device.NNAPI + } + if (device == targetDevice) return + device = targetDevice + createPoseEstimator() + } + + // Change tracker for MoveNet MultiPose model + private fun changeTracker(position: Int) { + cameraSource?.setTracker( + when (position) { + 1 -> TrackerType.BOUNDING_BOX + 2 -> TrackerType.KEYPOINTS + else -> TrackerType.OFF + } + ) + } + + private fun createPoseEstimator() { + // For MoveNet MultiPose, hide score and disable pose classifier as the model returns + // multiple Person instances.
+ val poseDetector = when (modelPos) { + 0 -> { + // MoveNet Lightning (SinglePose) + showPoseClassifier(true) + showDetectionScore(true) + showTracker(false) + MoveNet.create(this, device, ModelType.Lightning) + } + 1 -> { + // MoveNet Thunder (SinglePose) + showPoseClassifier(true) + showDetectionScore(true) + showTracker(false) + MoveNet.create(this, device, ModelType.Thunder) + } + 2 -> { + // MoveNet (Lightning) MultiPose + showPoseClassifier(false) + showDetectionScore(false) + // Movenet MultiPose Dynamic does not support GPUDelegate + if (device == Device.GPU) { + showToast(getString(R.string.tfe_pe_gpu_error)) + } + showTracker(true) + MoveNetMultiPose.create( + this, + device, + Type.Dynamic + ) + } + 3 -> { + // PoseNet (SinglePose) + showPoseClassifier(true) + showDetectionScore(true) + showTracker(false) + PoseNet.create(this, device) + } + else -> { + null + } + } + poseDetector?.let { detector -> + cameraSource?.setDetector(detector) + } + } + + // Show/hide the pose classification option. + private fun showPoseClassifier(isVisible: Boolean) { + vClassificationOption.visibility = if (isVisible) View.VISIBLE else View.GONE + if (!isVisible) { + swClassification.isChecked = false + } + } + + // Show/hide the detection score. + private fun showDetectionScore(isVisible: Boolean) { + tvScore.visibility = if (isVisible) View.VISIBLE else View.GONE + } + + // Show/hide classification result. + private fun showClassificationResult(isVisible: Boolean) { + val visibility = if (isVisible) View.VISIBLE else View.GONE + tvClassificationValue1.visibility = visibility + tvClassificationValue2.visibility = visibility + tvClassificationValue3.visibility = visibility + } + + // Show/hide the tracking options. + private fun showTracker(isVisible: Boolean) { + if (isVisible) { + // Show tracker options and enable Bounding Box tracker. + vTrackerOption.visibility = View.VISIBLE + spnTracker.setSelection(1) + } else { + // Set tracker type to off and hide tracker option. + vTrackerOption.visibility = View.GONE + spnTracker.setSelection(0) + } + } + + private fun requestPermission() { + when (PackageManager.PERMISSION_GRANTED) { + ContextCompat.checkSelfPermission( + this, + Manifest.permission.CAMERA + ) -> { + // You can use the API that requires the permission. + openCamera() + } + else -> { + // You can directly ask for the permission. + // The registered ActivityResultCallback gets the result of this request. + requestPermissionLauncher.launch( + Manifest.permission.CAMERA + ) + } + } + } + + private fun showToast(message: String) { + Toast.makeText(this, message, Toast.LENGTH_LONG).show() + } + + /** + * Shows an error message dialog. 
+ */ + class ErrorDialog : DialogFragment() { + + override fun onCreateDialog(savedInstanceState: Bundle?): Dialog = + AlertDialog.Builder(activity) + .setMessage(requireArguments().getString(ARG_MESSAGE)) + .setPositiveButton(android.R.string.ok) { _, _ -> + // do nothing + } + .create() + + companion object { + + @JvmStatic + private val ARG_MESSAGE = "message" + + @JvmStatic + fun newInstance(message: String): ErrorDialog = ErrorDialog().apply { + arguments = Bundle().apply { putString(ARG_MESSAGE, message) } + } + } + } +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/VisualizationUtils.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/VisualizationUtils.kt new file mode 100644 index 0000000..a57e185 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/VisualizationUtils.kt @@ -0,0 +1,120 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation + +import android.graphics.Bitmap +import android.graphics.Canvas +import android.graphics.Color +import android.graphics.Paint +import org.tensorflow.lite.examples.poseestimation.data.BodyPart +import org.tensorflow.lite.examples.poseestimation.data.Person +import kotlin.math.max + +object VisualizationUtils { + /** Radius of circle used to draw keypoints. */ + private const val CIRCLE_RADIUS = 6f + + /** Width of line used to connected two keypoints. */ + private const val LINE_WIDTH = 4f + + /** The text size of the person id that will be displayed when the tracker is available. */ + private const val PERSON_ID_TEXT_SIZE = 30f + + /** Distance from person id to the nose keypoint. */ + private const val PERSON_ID_MARGIN = 6f + + /** Pair of keypoints to draw lines between. 
*/ + private val bodyJoints = listOf( + Pair(BodyPart.NOSE, BodyPart.LEFT_EYE), + Pair(BodyPart.NOSE, BodyPart.RIGHT_EYE), + Pair(BodyPart.LEFT_EYE, BodyPart.LEFT_EAR), + Pair(BodyPart.RIGHT_EYE, BodyPart.RIGHT_EAR), + Pair(BodyPart.NOSE, BodyPart.LEFT_SHOULDER), + Pair(BodyPart.NOSE, BodyPart.RIGHT_SHOULDER), + Pair(BodyPart.LEFT_SHOULDER, BodyPart.LEFT_ELBOW), + Pair(BodyPart.LEFT_ELBOW, BodyPart.LEFT_WRIST), + Pair(BodyPart.RIGHT_SHOULDER, BodyPart.RIGHT_ELBOW), + Pair(BodyPart.RIGHT_ELBOW, BodyPart.RIGHT_WRIST), + Pair(BodyPart.LEFT_SHOULDER, BodyPart.RIGHT_SHOULDER), + Pair(BodyPart.LEFT_SHOULDER, BodyPart.LEFT_HIP), + Pair(BodyPart.RIGHT_SHOULDER, BodyPart.RIGHT_HIP), + Pair(BodyPart.LEFT_HIP, BodyPart.RIGHT_HIP), + Pair(BodyPart.LEFT_HIP, BodyPart.LEFT_KNEE), + Pair(BodyPart.LEFT_KNEE, BodyPart.LEFT_ANKLE), + Pair(BodyPart.RIGHT_HIP, BodyPart.RIGHT_KNEE), + Pair(BodyPart.RIGHT_KNEE, BodyPart.RIGHT_ANKLE) + ) + + // Draw lines and points to indicate the body pose + fun drawBodyKeypoints( + input: Bitmap, + persons: List<Person>, + isTrackerEnabled: Boolean = false + ): Bitmap { + val paintCircle = Paint().apply { + strokeWidth = CIRCLE_RADIUS + color = Color.RED + style = Paint.Style.FILL + } + val paintLine = Paint().apply { + strokeWidth = LINE_WIDTH + color = Color.RED + style = Paint.Style.STROKE + } + + val paintText = Paint().apply { + textSize = PERSON_ID_TEXT_SIZE + color = Color.BLUE + textAlign = Paint.Align.LEFT + } + + val output = input.copy(Bitmap.Config.ARGB_8888, true) + val originalSizeCanvas = Canvas(output) + persons.forEach { person -> + // Draw the person id if the tracker is enabled + if (isTrackerEnabled) { + person.boundingBox?.let { + val personIdX = max(0f, it.left) + val personIdY = max(0f, it.top) + + originalSizeCanvas.drawText( + person.id.toString(), + personIdX, + personIdY - PERSON_ID_MARGIN, + paintText + ) + originalSizeCanvas.drawRect(it, paintLine) + } + } + bodyJoints.forEach { + val pointA = person.keyPoints[it.first.position].coordinate + val pointB = person.keyPoints[it.second.position].coordinate + originalSizeCanvas.drawLine(pointA.x, pointA.y, pointB.x, pointB.y, paintLine) + } + + person.keyPoints.forEach { point -> + originalSizeCanvas.drawCircle( + point.coordinate.x, + point.coordinate.y, + CIRCLE_RADIUS, + paintCircle + ) + } + } + return output + } +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/YuvToRgbConverter.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/YuvToRgbConverter.kt new file mode 100644 index 0000000..a6a8a25 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/YuvToRgbConverter.kt @@ -0,0 +1,151 @@ +package org.tensorflow.lite.examples.poseestimation + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.ImageFormat +import android.graphics.Rect +import android.media.Image +import android.renderscript.Allocation +import android.renderscript.Element +import android.renderscript.RenderScript +import android.renderscript.ScriptIntrinsicYuvToRGB +import java.nio.ByteBuffer + +class YuvToRgbConverter(context: Context) { + private val rs = RenderScript.create(context) + private val scriptYuvToRgb = ScriptIntrinsicYuvToRGB.create(rs, Element.U8_4(rs)) + + private var pixelCount: Int = -1 + private lateinit var yuvBuffer: ByteBuffer + private lateinit var inputAllocation: Allocation + private lateinit var outputAllocation: Allocation + + @Synchronized + fun yuvToRgb(image: Image, output: Bitmap) { + + // Ensure
that the intermediate output byte buffer is allocated + if (!::yuvBuffer.isInitialized) { + pixelCount = image.cropRect.width() * image.cropRect.height() + yuvBuffer = ByteBuffer.allocateDirect( + pixelCount * ImageFormat.getBitsPerPixel(ImageFormat.YUV_420_888) / 8) + } + + // Get the YUV data in byte array form + imageToByteBuffer(image, yuvBuffer) + + // Ensure that the RenderScript inputs and outputs are allocated + if (!::inputAllocation.isInitialized) { + inputAllocation = Allocation.createSized(rs, Element.U8(rs), yuvBuffer.array().size) + } + if (!::outputAllocation.isInitialized) { + outputAllocation = Allocation.createFromBitmap(rs, output) + } + + // Convert YUV to RGB + inputAllocation.copyFrom(yuvBuffer.array()) + scriptYuvToRgb.setInput(inputAllocation) + scriptYuvToRgb.forEach(outputAllocation) + outputAllocation.copyTo(output) + } + + private fun imageToByteBuffer(image: Image, outputBuffer: ByteBuffer) { + assert(image.format == ImageFormat.YUV_420_888) + + val imageCrop = image.cropRect + val imagePlanes = image.planes + val rowData = ByteArray(imagePlanes.first().rowStride) + + imagePlanes.forEachIndexed { planeIndex, plane -> + + // How many values are read in input for each output value written + // Only the Y plane has a value for every pixel, U and V have half the resolution i.e. + // + // Y Plane U Plane V Plane + // =============== ======= ======= + // Y Y Y Y Y Y Y Y U U U U V V V V + // Y Y Y Y Y Y Y Y U U U U V V V V + // Y Y Y Y Y Y Y Y U U U U V V V V + // Y Y Y Y Y Y Y Y U U U U V V V V + // Y Y Y Y Y Y Y Y + // Y Y Y Y Y Y Y Y + // Y Y Y Y Y Y Y Y + val outputStride: Int + + // The index in the output buffer the next value will be written at + // For Y it's zero, for U and V we start at the end of Y and interleave them i.e. 
+ // + // First chunk Second chunk + // =============== =============== + // Y Y Y Y Y Y Y Y U V U V U V U V + // Y Y Y Y Y Y Y Y U V U V U V U V + // Y Y Y Y Y Y Y Y U V U V U V U V + // Y Y Y Y Y Y Y Y U V U V U V U V + // Y Y Y Y Y Y Y Y + // Y Y Y Y Y Y Y Y + // Y Y Y Y Y Y Y Y + var outputOffset: Int + + when (planeIndex) { + 0 -> { + outputStride = 1 + outputOffset = 0 + } + 1 -> { + outputStride = 2 + outputOffset = pixelCount + 1 + } + 2 -> { + outputStride = 2 + outputOffset = pixelCount + } + else -> { + // Image contains more than 3 planes, something strange is going on + return@forEachIndexed + } + } + + val buffer = plane.buffer + val rowStride = plane.rowStride + val pixelStride = plane.pixelStride + + // We have to divide the width and height by two if it's not the Y plane + val planeCrop = if (planeIndex == 0) { + imageCrop + } else { + Rect( + imageCrop.left / 2, + imageCrop.top / 2, + imageCrop.right / 2, + imageCrop.bottom / 2 + ) + } + + val planeWidth = planeCrop.width() + val planeHeight = planeCrop.height() + + buffer.position(rowStride * planeCrop.top + pixelStride * planeCrop.left) + for (row in 0 until planeHeight) { + val length: Int + if (pixelStride == 1 && outputStride == 1) { + // When there is a single stride value for pixel and output, we can just copy + // the entire row in a single step + length = planeWidth + buffer.get(outputBuffer.array(), outputOffset, length) + outputOffset += length + } else { + // When either pixel or output have a stride > 1 we must copy pixel by pixel + length = (planeWidth - 1) * pixelStride + 1 + buffer.get(rowData, 0, length) + for (col in 0 until planeWidth) { + outputBuffer.array()[outputOffset] = rowData[col * pixelStride] + outputOffset += outputStride + } + } + + if (row < planeHeight - 1) { + buffer.position(buffer.position() + rowStride - length) + } + } + } + } +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/camera/CameraSource.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/camera/CameraSource.kt new file mode 100644 index 0000000..f6e6da7 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/camera/CameraSource.kt @@ -0,0 +1,327 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.camera + +import android.annotation.SuppressLint +import android.content.Context +import android.graphics.Bitmap +import android.graphics.ImageFormat +import android.graphics.Matrix +import android.graphics.Rect +import android.hardware.camera2.CameraCaptureSession +import android.hardware.camera2.CameraCharacteristics +import android.hardware.camera2.CameraDevice +import android.hardware.camera2.CameraManager +import android.media.ImageReader +import android.os.Handler +import android.os.HandlerThread +import android.util.Log +import android.view.Surface +import android.view.SurfaceView +import kotlinx.coroutines.suspendCancellableCoroutine +import org.tensorflow.lite.examples.poseestimation.VisualizationUtils +import org.tensorflow.lite.examples.poseestimation.YuvToRgbConverter +import org.tensorflow.lite.examples.poseestimation.data.Person +import org.tensorflow.lite.examples.poseestimation.ml.MoveNetMultiPose +import org.tensorflow.lite.examples.poseestimation.ml.PoseClassifier +import org.tensorflow.lite.examples.poseestimation.ml.PoseDetector +import org.tensorflow.lite.examples.poseestimation.ml.TrackerType +import java.util.* +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException + +class CameraSource( + private val surfaceView: SurfaceView, + private val listener: CameraSourceListener? = null +) { + + companion object { + private const val PREVIEW_WIDTH = 640 + private const val PREVIEW_HEIGHT = 480 + + /** Threshold for confidence score. */ + private const val MIN_CONFIDENCE = .2f + private const val TAG = "Camera Source" + } + + private val lock = Any() + private var detector: PoseDetector? = null + private var classifier: PoseClassifier? = null + private var isTrackerEnabled = false + private var yuvConverter: YuvToRgbConverter = YuvToRgbConverter(surfaceView.context) + private lateinit var imageBitmap: Bitmap + + /** Number of frames processed in the current one-second interval, used to calculate FPS. */ + private var fpsTimer: Timer? = null + private var frameProcessedInOneSecondInterval = 0 + private var framesPerSecond = 0 + + /** Detects, characterizes, and connects to a CameraDevice (used for all camera operations) */ + private val cameraManager: CameraManager by lazy { + val context = surfaceView.context + context.getSystemService(Context.CAMERA_SERVICE) as CameraManager + } + + /** Readers used as buffers for camera still shots */ + private var imageReader: ImageReader? = null + + /** The [CameraDevice] that will be opened in this fragment */ + private var camera: CameraDevice? = null + + /** Internal reference to the ongoing [CameraCaptureSession] configured with our parameters */ + private var session: CameraCaptureSession? = null + + /** [HandlerThread] where all buffer reading operations run */ + private var imageReaderThread: HandlerThread? = null + + /** [Handler] corresponding to [imageReaderThread] */ + private var imageReaderHandler: Handler?
= null + private var cameraId: String = "" + + suspend fun initCamera() { + camera = openCamera(cameraManager, cameraId) + imageReader = + ImageReader.newInstance(PREVIEW_WIDTH, PREVIEW_HEIGHT, ImageFormat.YUV_420_888, 3) + imageReader?.setOnImageAvailableListener({ reader -> + val image = reader.acquireLatestImage() + if (image != null) { + if (!::imageBitmap.isInitialized) { + imageBitmap = + Bitmap.createBitmap( + PREVIEW_WIDTH, + PREVIEW_HEIGHT, + Bitmap.Config.ARGB_8888 + ) + } + yuvConverter.yuvToRgb(image, imageBitmap) + // Create rotated version for portrait display + val rotateMatrix = Matrix() + rotateMatrix.postRotate(90.0f) + + val rotatedBitmap = Bitmap.createBitmap( + imageBitmap, 0, 0, PREVIEW_WIDTH, PREVIEW_HEIGHT, + rotateMatrix, false + ) + processImage(rotatedBitmap) + image.close() + } + }, imageReaderHandler) + + imageReader?.surface?.let { surface -> + session = createSession(listOf(surface)) + val cameraRequest = camera?.createCaptureRequest( + CameraDevice.TEMPLATE_PREVIEW + )?.apply { + addTarget(surface) + } + cameraRequest?.build()?.let { + session?.setRepeatingRequest(it, null, null) + } + } + } + + private suspend fun createSession(targets: List<Surface>): CameraCaptureSession = + suspendCancellableCoroutine { cont -> + camera?.createCaptureSession(targets, object : CameraCaptureSession.StateCallback() { + override fun onConfigured(captureSession: CameraCaptureSession) = + cont.resume(captureSession) + + override fun onConfigureFailed(session: CameraCaptureSession) { + cont.resumeWithException(Exception("Session error")) + } + }, null) + } + + @SuppressLint("MissingPermission") + private suspend fun openCamera(manager: CameraManager, cameraId: String): CameraDevice = + suspendCancellableCoroutine { cont -> + manager.openCamera(cameraId, object : CameraDevice.StateCallback() { + override fun onOpened(camera: CameraDevice) = cont.resume(camera) + + override fun onDisconnected(camera: CameraDevice) { + camera.close() + } + + override fun onError(camera: CameraDevice, error: Int) { + if (cont.isActive) cont.resumeWithException(Exception("Camera error")) + } + }, imageReaderHandler) + } + + fun prepareCamera() { + for (cameraId in cameraManager.cameraIdList) { + val characteristics = cameraManager.getCameraCharacteristics(cameraId) + + // We don't use a front facing camera in this sample. + val cameraDirection = characteristics.get(CameraCharacteristics.LENS_FACING) + if (cameraDirection != null && + cameraDirection == CameraCharacteristics.LENS_FACING_FRONT + ) { + continue + } + this.cameraId = cameraId + } + } + + fun setDetector(detector: PoseDetector) { + synchronized(lock) { + if (this.detector != null) { + this.detector?.close() + this.detector = null + } + this.detector = detector + } + } + + fun setClassifier(classifier: PoseClassifier?) { + synchronized(lock) { + if (this.classifier != null) { + this.classifier?.close() + this.classifier = null + } + this.classifier = classifier + } + } + + /** + * Set the tracker for the MoveNet MultiPose model. + */ + fun setTracker(trackerType: TrackerType) { + isTrackerEnabled = trackerType != TrackerType.OFF + (this.detector as?
MoveNetMultiPose)?.setTracker(trackerType) + } + + fun resume() { + imageReaderThread = HandlerThread("imageReaderThread").apply { start() } + imageReaderHandler = Handler(imageReaderThread!!.looper) + fpsTimer = Timer() + fpsTimer?.scheduleAtFixedRate( + object : TimerTask() { + override fun run() { + framesPerSecond = frameProcessedInOneSecondInterval + frameProcessedInOneSecondInterval = 0 + } + }, + 0, + 1000 + ) + } + + fun close() { + session?.close() + session = null + camera?.close() + camera = null + imageReader?.close() + imageReader = null + stopImageReaderThread() + detector?.close() + detector = null + classifier?.close() + classifier = null + fpsTimer?.cancel() + fpsTimer = null + frameProcessedInOneSecondInterval = 0 + framesPerSecond = 0 + } + + // process image + private fun processImage(bitmap: Bitmap) { + val persons = mutableListOf<Person>() + var classificationResult: List<Pair<String, Float>>? = null + + synchronized(lock) { + detector?.estimatePoses(bitmap)?.let { + persons.addAll(it) + + // if the model only returns one item, allow running the Pose classifier. + if (persons.isNotEmpty()) { + classifier?.run { + classificationResult = classify(persons[0]) + } + } + } + } + frameProcessedInOneSecondInterval++ + if (frameProcessedInOneSecondInterval == 1) { + // send fps to view + listener?.onFPSListener(framesPerSecond) + } + + // if the model returns only one item, show that item's score. + if (persons.isNotEmpty()) { + listener?.onDetectedInfo(persons[0].score, classificationResult) + } + visualize(persons, bitmap) + } + + private fun visualize(persons: List<Person>, bitmap: Bitmap) { + + val outputBitmap = VisualizationUtils.drawBodyKeypoints( + bitmap, + persons.filter { it.score > MIN_CONFIDENCE }, isTrackerEnabled + ) + + val holder = surfaceView.holder + val surfaceCanvas = holder.lockCanvas() + surfaceCanvas?.let { canvas -> + val screenWidth: Int + val screenHeight: Int + val left: Int + val top: Int + + if (canvas.height > canvas.width) { + val ratio = outputBitmap.height.toFloat() / outputBitmap.width + screenWidth = canvas.width + left = 0 + screenHeight = (canvas.width * ratio).toInt() + top = (canvas.height - screenHeight) / 2 + } else { + val ratio = outputBitmap.width.toFloat() / outputBitmap.height + screenHeight = canvas.height + top = 0 + screenWidth = (canvas.height * ratio).toInt() + left = (canvas.width - screenWidth) / 2 + } + val right: Int = left + screenWidth + val bottom: Int = top + screenHeight + + canvas.drawBitmap( + outputBitmap, Rect(0, 0, outputBitmap.width, outputBitmap.height), + Rect(left, top, right, bottom), null + ) + surfaceView.holder.unlockCanvasAndPost(canvas) + } + } + + private fun stopImageReaderThread() { + imageReaderThread?.quitSafely() + try { + imageReaderThread?.join() + imageReaderThread = null + imageReaderHandler = null + } catch (e: InterruptedException) { + Log.d(TAG, e.message.toString()) + } + } + + interface CameraSourceListener { + fun onFPSListener(fps: Int) + + fun onDetectedInfo(personScore: Float?, poseLabels: List<Pair<String, Float>>?) + } +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/BodyPart.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/BodyPart.kt new file mode 100644 index 0000000..9e105f1 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/BodyPart.kt @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.data + +enum class BodyPart(val position: Int) { + NOSE(0), + LEFT_EYE(1), + RIGHT_EYE(2), + LEFT_EAR(3), + RIGHT_EAR(4), + LEFT_SHOULDER(5), + RIGHT_SHOULDER(6), + LEFT_ELBOW(7), + RIGHT_ELBOW(8), + LEFT_WRIST(9), + RIGHT_WRIST(10), + LEFT_HIP(11), + RIGHT_HIP(12), + LEFT_KNEE(13), + RIGHT_KNEE(14), + LEFT_ANKLE(15), + RIGHT_ANKLE(16); + companion object{ + private val map = values().associateBy(BodyPart::position) + fun fromInt(position: Int): BodyPart = map.getValue(position) + } +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/Device.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/Device.kt new file mode 100644 index 0000000..612256f --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/Device.kt @@ -0,0 +1,23 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.data + +enum class Device { + CPU, + NNAPI, + GPU +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/KeyPoint.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/KeyPoint.kt new file mode 100644 index 0000000..99a89cd --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/KeyPoint.kt @@ -0,0 +1,21 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.data + +import android.graphics.PointF + +data class KeyPoint(val bodyPart: BodyPart, var coordinate: PointF, val score: Float) diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/Person.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/Person.kt new file mode 100644 index 0000000..07133ff --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/Person.kt @@ -0,0 +1,26 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.data + +import android.graphics.RectF + +data class Person( + var id: Int = -1, // default id is -1 + val keyPoints: List<KeyPoint>, + val boundingBox: RectF? = null, // Only MoveNet MultiPose returns a bounding box. + val score: Float +) diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/TorsoAndBodyDistance.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/TorsoAndBodyDistance.kt new file mode 100644 index 0000000..644069c --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/data/TorsoAndBodyDistance.kt @@ -0,0 +1,24 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.data + +data class TorsoAndBodyDistance( + val maxTorsoYDistance: Float, + val maxTorsoXDistance: Float, + val maxBodyYDistance: Float, + val maxBodyXDistance: Float +) diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/MoveNet.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/MoveNet.kt new file mode 100644 index 0000000..9e7722d --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/MoveNet.kt @@ -0,0 +1,353 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import android.graphics.* +import android.os.SystemClock +import org.tensorflow.lite.DataType +import org.tensorflow.lite.Interpreter +import org.tensorflow.lite.examples.poseestimation.data.* +import org.tensorflow.lite.gpu.GpuDelegate +import org.tensorflow.lite.support.common.FileUtil +import org.tensorflow.lite.support.image.ImageProcessor +import org.tensorflow.lite.support.image.TensorImage +import org.tensorflow.lite.support.image.ops.ResizeOp +import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer +import kotlin.math.abs +import kotlin.math.max +import kotlin.math.min + +enum class ModelType { + Lightning, + Thunder +} + +class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) : + PoseDetector { + + companion object { + private const val MIN_CROP_KEYPOINT_SCORE = .2f + private const val CPU_NUM_THREADS = 4 + + // Parameters that control how large crop region should be expanded from previous frames' + // body keypoints. + private const val TORSO_EXPANSION_RATIO = 1.9f + private const val BODY_EXPANSION_RATIO = 1.2f + + // TFLite file names. + private const val LIGHTNING_FILENAME = "movenet_lightning.tflite" + private const val THUNDER_FILENAME = "movenet_thunder.tflite" + + // allow specifying model type. + fun create(context: Context, device: Device, modelType: ModelType): MoveNet { + val options = Interpreter.Options() + var gpuDelegate: GpuDelegate? = null + options.setNumThreads(CPU_NUM_THREADS) + when (device) { + Device.CPU -> { + } + Device.GPU -> { + gpuDelegate = GpuDelegate() + options.addDelegate(gpuDelegate) + } + Device.NNAPI -> options.setUseNNAPI(true) + } + return MoveNet( + Interpreter( + FileUtil.loadMappedFile( + context, + if (modelType == ModelType.Lightning) LIGHTNING_FILENAME + else THUNDER_FILENAME + ), options + ), + gpuDelegate + ) + } + + // default to lightning. + fun create(context: Context, device: Device): MoveNet = + create(context, device, ModelType.Lightning) + } + + private var cropRegion: RectF? 
= null + private var lastInferenceTimeNanos: Long = -1 + private val inputWidth = interpreter.getInputTensor(0).shape()[1] + private val inputHeight = interpreter.getInputTensor(0).shape()[2] + private var outputShape: IntArray = interpreter.getOutputTensor(0).shape() + + override fun estimatePoses(bitmap: Bitmap): List<Person> { + val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos() + if (cropRegion == null) { + cropRegion = initRectF(bitmap.width, bitmap.height) + } + var totalScore = 0f + + val numKeyPoints = outputShape[2] + val keyPoints = mutableListOf<KeyPoint>() + + cropRegion?.run { + val rect = RectF( + (left * bitmap.width), + (top * bitmap.height), + (right * bitmap.width), + (bottom * bitmap.height) + ) + val detectBitmap = Bitmap.createBitmap( + rect.width().toInt(), + rect.height().toInt(), + Bitmap.Config.ARGB_8888 + ) + Canvas(detectBitmap).drawBitmap( + bitmap, + -rect.left, + -rect.top, + null + ) + val inputTensor = processInputImage(detectBitmap, inputWidth, inputHeight) + val outputTensor = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32) + val widthRatio = detectBitmap.width.toFloat() / inputWidth + val heightRatio = detectBitmap.height.toFloat() / inputHeight + + val positions = mutableListOf<Float>() + + inputTensor?.let { input -> + interpreter.run(input.buffer, outputTensor.buffer.rewind()) + val output = outputTensor.floatArray + for (idx in 0 until numKeyPoints) { + val x = output[idx * 3 + 1] * inputWidth * widthRatio + val y = output[idx * 3 + 0] * inputHeight * heightRatio + + positions.add(x) + positions.add(y) + val score = output[idx * 3 + 2] + keyPoints.add( + KeyPoint( + BodyPart.fromInt(idx), + PointF( + x, + y + ), + score + ) + ) + totalScore += score + } + } + val matrix = Matrix() + val points = positions.toFloatArray() + + matrix.postTranslate(rect.left, rect.top) + matrix.mapPoints(points) + keyPoints.forEachIndexed { index, keyPoint -> + keyPoint.coordinate = + PointF( + points[index * 2], + points[index * 2 + 1] + ) + } + // new crop region + cropRegion = determineRectF(keyPoints, bitmap.width, bitmap.height) + } + lastInferenceTimeNanos = + SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos + return listOf(Person(keyPoints = keyPoints, score = totalScore / numKeyPoints)) + } + + override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos + + override fun close() { + gpuDelegate?.close() + interpreter.close() + cropRegion = null + } + + /** + * Prepare input image for detection + */ + private fun processInputImage(bitmap: Bitmap, inputWidth: Int, inputHeight: Int): TensorImage? { + val width: Int = bitmap.width + val height: Int = bitmap.height + + val size = if (height > width) width else height + val imageProcessor = ImageProcessor.Builder().apply { + add(ResizeWithCropOrPadOp(size, size)) + add(ResizeOp(inputWidth, inputHeight, ResizeOp.ResizeMethod.BILINEAR)) + }.build() + val tensorImage = TensorImage(DataType.UINT8) + tensorImage.load(bitmap) + return imageProcessor.process(tensorImage) + } + + /** + * Defines the default crop region. + * The function provides the initial crop region (pads the full image from both + * sides to make it a square image) when the algorithm cannot reliably determine + * the crop region from the previous frame.
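+ * Worked example (editorial illustration): for a 640x480 landscape frame this yields RectF(0f, -0.167f, 1f, 1.167f) in normalized coordinates, since width = 1, height = 640/480 = 1.333, and yMin = (240 - 320) / 480 = -0.167; x spans the full width while y is padded symmetrically past the image borders so the crop is square.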
*/ + private fun initRectF(imageWidth: Int, imageHeight: Int): RectF { + val xMin: Float + val yMin: Float + val width: Float + val height: Float + if (imageWidth > imageHeight) { + width = 1f + height = imageWidth.toFloat() / imageHeight + xMin = 0f + yMin = (imageHeight / 2f - imageWidth / 2f) / imageHeight + } else { + height = 1f + width = imageHeight.toFloat() / imageWidth + yMin = 0f + xMin = (imageWidth / 2f - imageHeight / 2f) / imageWidth + } + return RectF( + xMin, + yMin, + xMin + width, + yMin + height + ) + } + + /** + * Checks whether there are enough torso keypoints. + * This function checks whether the model is confident at predicting one of the + * shoulders/hips which is required to determine a good crop region. + */ + private fun torsoVisible(keyPoints: List<KeyPoint>): Boolean { + return ((keyPoints[BodyPart.LEFT_HIP.position].score > MIN_CROP_KEYPOINT_SCORE).or( + keyPoints[BodyPart.RIGHT_HIP.position].score > MIN_CROP_KEYPOINT_SCORE + )).and( + (keyPoints[BodyPart.LEFT_SHOULDER.position].score > MIN_CROP_KEYPOINT_SCORE).or( + keyPoints[BodyPart.RIGHT_SHOULDER.position].score > MIN_CROP_KEYPOINT_SCORE + ) + ) + } + + /** + * Determines the region to crop the image for the model to run inference on. + * The algorithm uses the detected joints from the previous frame to estimate + * the square region that encloses the full body of the target person and + * centers at the midpoint of two hip joints. The crop size is determined by + * the distances between each joint and the center point. + * When the model is not confident with the four torso joint predictions, the + * function returns a default crop which is the full image padded to square. + */ + private fun determineRectF( + keyPoints: List<KeyPoint>, + imageWidth: Int, + imageHeight: Int + ): RectF { + val targetKeyPoints = mutableListOf<KeyPoint>() + keyPoints.forEach { + targetKeyPoints.add( + KeyPoint( + it.bodyPart, + PointF( + it.coordinate.x, + it.coordinate.y + ), + it.score + ) + ) + } + if (torsoVisible(keyPoints)) { + val centerX = + (targetKeyPoints[BodyPart.LEFT_HIP.position].coordinate.x + + targetKeyPoints[BodyPart.RIGHT_HIP.position].coordinate.x) / 2f + val centerY = + (targetKeyPoints[BodyPart.LEFT_HIP.position].coordinate.y + + targetKeyPoints[BodyPart.RIGHT_HIP.position].coordinate.y) / 2f + + val torsoAndBodyDistances = + determineTorsoAndBodyDistances(keyPoints, targetKeyPoints, centerX, centerY) + + val list = listOf( + torsoAndBodyDistances.maxTorsoXDistance * TORSO_EXPANSION_RATIO, + torsoAndBodyDistances.maxTorsoYDistance * TORSO_EXPANSION_RATIO, + torsoAndBodyDistances.maxBodyXDistance * BODY_EXPANSION_RATIO, + torsoAndBodyDistances.maxBodyYDistance * BODY_EXPANSION_RATIO + ) + + var cropLengthHalf = list.maxOrNull() ?: 0f + val tmp = listOf(centerX, imageWidth - centerX, centerY, imageHeight - centerY) + cropLengthHalf = min(cropLengthHalf, tmp.maxOrNull() ?: 0f) + val cropCorner = Pair(centerY - cropLengthHalf, centerX - cropLengthHalf) + + return if (cropLengthHalf > max(imageWidth, imageHeight) / 2f) { + initRectF(imageWidth, imageHeight) + } else { + val cropLength = cropLengthHalf * 2 + RectF( + cropCorner.second / imageWidth, + cropCorner.first / imageHeight, + (cropCorner.second + cropLength) / imageWidth, + (cropCorner.first + cropLength) / imageHeight, + ) + } + } else { + return initRectF(imageWidth, imageHeight) + } + } + + /** + * Calculates the maximum distance from each keypoint to the center location.
+ * The function returns the maximum distances from the two sets of keypoints: + * full 17 keypoints and 4 torso keypoints. The returned information will be + * used to determine the crop size. See determineRectF for more detail. + */ + private fun determineTorsoAndBodyDistances( + keyPoints: List<KeyPoint>, + targetKeyPoints: List<KeyPoint>, + centerX: Float, + centerY: Float + ): TorsoAndBodyDistance { + val torsoJoints = listOf( + BodyPart.LEFT_SHOULDER.position, + BodyPart.RIGHT_SHOULDER.position, + BodyPart.LEFT_HIP.position, + BodyPart.RIGHT_HIP.position + ) + + var maxTorsoYRange = 0f + var maxTorsoXRange = 0f + torsoJoints.forEach { joint -> + val distY = abs(centerY - targetKeyPoints[joint].coordinate.y) + val distX = abs(centerX - targetKeyPoints[joint].coordinate.x) + if (distY > maxTorsoYRange) maxTorsoYRange = distY + if (distX > maxTorsoXRange) maxTorsoXRange = distX + } + + var maxBodyYRange = 0f + var maxBodyXRange = 0f + for (joint in keyPoints.indices) { + if (keyPoints[joint].score < MIN_CROP_KEYPOINT_SCORE) continue + val distY = abs(centerY - keyPoints[joint].coordinate.y) + val distX = abs(centerX - keyPoints[joint].coordinate.x) + + if (distY > maxBodyYRange) maxBodyYRange = distY + if (distX > maxBodyXRange) maxBodyXRange = distX + } + return TorsoAndBodyDistance( + maxTorsoYRange, + maxTorsoXRange, + maxBodyYRange, + maxBodyXRange + ) + } +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/MoveNetMultiPose.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/MoveNetMultiPose.kt new file mode 100644 index 0000000..62f61ca --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/MoveNetMultiPose.kt @@ -0,0 +1,309 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.PointF +import android.graphics.RectF +import android.os.SystemClock +import org.tensorflow.lite.DataType +import org.tensorflow.lite.Interpreter +import org.tensorflow.lite.examples.poseestimation.data.BodyPart +import org.tensorflow.lite.examples.poseestimation.data.Device +import org.tensorflow.lite.examples.poseestimation.data.KeyPoint +import org.tensorflow.lite.examples.poseestimation.data.Person +import org.tensorflow.lite.examples.poseestimation.tracker.* +import org.tensorflow.lite.gpu.GpuDelegate +import org.tensorflow.lite.support.common.FileUtil +import org.tensorflow.lite.support.image.ImageOperator +import org.tensorflow.lite.support.image.ImageProcessor +import org.tensorflow.lite.support.image.TensorImage +import org.tensorflow.lite.support.image.ops.ResizeOp +import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer +import kotlin.math.ceil + +class MoveNetMultiPose( + private val interpreter: Interpreter, + private val type: Type, + private val gpuDelegate: GpuDelegate?, +) : PoseDetector { + private val outputShape = interpreter.getOutputTensor(0).shape() + private val inputShape = interpreter.getInputTensor(0).shape() + private var imageWidth: Int = 0 + private var imageHeight: Int = 0 + private var targetWidth: Int = 0 + private var targetHeight: Int = 0 + private var scaleHeight: Int = 0 + private var scaleWidth: Int = 0 + private var lastInferenceTimeNanos: Long = -1 + private var tracker: AbstractTracker? = null + + companion object { + private const val DYNAMIC_MODEL_TARGET_INPUT_SIZE = 256 + private const val SHAPE_MULTIPLE = 32.0 + private const val DETECTION_THRESHOLD = 0.11 + private const val DETECTION_SCORE_INDEX = 55 + private const val BOUNDING_BOX_Y_MIN_INDEX = 51 + private const val BOUNDING_BOX_X_MIN_INDEX = 52 + private const val BOUNDING_BOX_Y_MAX_INDEX = 53 + private const val BOUNDING_BOX_X_MAX_INDEX = 54 + private const val KEYPOINT_COUNT = 17 + private const val OUTPUTS_COUNT_PER_KEYPOINT = 3 + private const val CPU_NUM_THREADS = 4 + + // allow specifying model type. + fun create( + context: Context, + device: Device, + type: Type, + ): MoveNetMultiPose { + val options = Interpreter.Options() + var gpuDelegate: GpuDelegate? = null + when (device) { + Device.CPU -> { + options.setNumThreads(CPU_NUM_THREADS) + } + Device.GPU -> { + // Only the fixed-shape model supports the GPU delegate option. + if (type == Type.Fixed) { + gpuDelegate = GpuDelegate() + options.addDelegate(gpuDelegate) + } + } + else -> { + // nothing to do + } + } + return MoveNetMultiPose( + Interpreter( + FileUtil.loadMappedFile( + context, + if (type == Type.Dynamic) + "movenet_multipose_fp16.tflite" else "" + //@TODO: (khanhlvg) Add support for fixed shape model if it's released. + ), options + ), type, gpuDelegate + ) + } + } + + /** + * Convert x and y coordinates ([0-1]) returned from the TFLite model + * to the coordinates corresponding to the input image.
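+ * (Editorial illustration.) On a portrait input the detection space is letterboxed horizontally, so resizeX first scales a normalized x to detection-space pixels (x * detectedWidth), subtracts half of the horizontal padding, and then multiplies by imageWidth / scaleWidth to land in source-image pixels; resizeY mirrors the same steps on the vertical axis for landscape inputs.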
*/ + private fun resizeKeypoint(x: Float, y: Float): PointF { + return PointF(resizeX(x), resizeY(y)) + } + + private fun resizeX(x: Float): Float { + return if (imageWidth > imageHeight) { + val ratioWidth = imageWidth.toFloat() / targetWidth + x * targetWidth * ratioWidth + } else { + val detectedWidth = + if (type == Type.Dynamic) targetWidth else inputShape[2] + val paddingWidth = detectedWidth - scaleWidth + val ratioWidth = imageWidth.toFloat() / scaleWidth + (x * detectedWidth - paddingWidth / 2f) * ratioWidth + } + } + + private fun resizeY(y: Float): Float { + return if (imageWidth > imageHeight) { + val detectedHeight = + if (type == Type.Dynamic) targetHeight else inputShape[1] + val paddingHeight = detectedHeight - scaleHeight + val ratioHeight = imageHeight.toFloat() / scaleHeight + (y * detectedHeight - paddingHeight / 2f) * ratioHeight + } else { + val ratioHeight = imageHeight.toFloat() / targetHeight + y * targetHeight * ratioHeight + } + } + + /** + * Prepare input image for detection + */ + private fun processInputTensor(bitmap: Bitmap): TensorImage { + imageWidth = bitmap.width + imageHeight = bitmap.height + + // If the model type is Fixed, get the input size from the input shape. + val inputSizeHeight = + if (type == Type.Dynamic) DYNAMIC_MODEL_TARGET_INPUT_SIZE else inputShape[1] + val inputSizeWidth = + if (type == Type.Dynamic) DYNAMIC_MODEL_TARGET_INPUT_SIZE else inputShape[2] + + val resizeOp: ImageOperator + if (imageWidth > imageHeight) { + val scale = inputSizeWidth / imageWidth.toFloat() + targetWidth = inputSizeWidth + scaleHeight = ceil(imageHeight * scale).toInt() + targetHeight = (ceil((scaleHeight / SHAPE_MULTIPLE)) * SHAPE_MULTIPLE).toInt() + resizeOp = ResizeOp(scaleHeight, targetWidth, ResizeOp.ResizeMethod.BILINEAR) + } else { + val scale = inputSizeHeight / imageHeight.toFloat() + targetHeight = inputSizeHeight + scaleWidth = ceil(imageWidth * scale).toInt() + targetWidth = (ceil((scaleWidth / SHAPE_MULTIPLE)) * SHAPE_MULTIPLE).toInt() + resizeOp = ResizeOp(targetHeight, scaleWidth, ResizeOp.ResizeMethod.BILINEAR) + } + + val resizeWithCropOrPad = if (type == Type.Dynamic) ResizeWithCropOrPadOp( + targetHeight, + targetWidth + ) else ResizeWithCropOrPadOp( + inputSizeHeight, + inputSizeWidth + ) + val imageProcessor = ImageProcessor.Builder().apply { + add(resizeOp) + add(resizeWithCropOrPad) + }.build() + val tensorImage = TensorImage(DataType.UINT8) + tensorImage.load(bitmap) + return imageProcessor.process(tensorImage) + } + + /** + * Run tracker (if available) and process the output.
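+ * Layout note (editorial, derived from the index constants above): each detection occupies outputShape[2] consecutive floats, i.e. 17 keypoints as (y, x, score) triples at offsets 0..50, the bounding box as (yMin, xMin, yMax, xMax) at offsets 51..54, and the overall detection score at offset 55.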
+ */ + private fun postProcess(modelOutput: FloatArray): List<Person> { + val persons = mutableListOf<Person>() + for (idx in modelOutput.indices step outputShape[2]) { + val personScore = modelOutput[idx + DETECTION_SCORE_INDEX] + if (personScore < DETECTION_THRESHOLD) continue + val positions = modelOutput.copyOfRange(idx, idx + KEYPOINT_COUNT * OUTPUTS_COUNT_PER_KEYPOINT) + val keyPoints = mutableListOf<KeyPoint>() + for (i in 0 until KEYPOINT_COUNT) { + val y = positions[i * OUTPUTS_COUNT_PER_KEYPOINT] + val x = positions[i * OUTPUTS_COUNT_PER_KEYPOINT + 1] + val score = positions[i * OUTPUTS_COUNT_PER_KEYPOINT + 2] + keyPoints.add(KeyPoint(BodyPart.fromInt(i), PointF(x, y), score)) + } + val yMin = modelOutput[idx + BOUNDING_BOX_Y_MIN_INDEX] + val xMin = modelOutput[idx + BOUNDING_BOX_X_MIN_INDEX] + val yMax = modelOutput[idx + BOUNDING_BOX_Y_MAX_INDEX] + val xMax = modelOutput[idx + BOUNDING_BOX_X_MAX_INDEX] + val boundingBox = RectF(xMin, yMin, xMax, yMax) + persons.add( + Person( + keyPoints = keyPoints, + boundingBox = boundingBox, + score = personScore + ) + ) + } + + if (persons.isEmpty()) return emptyList() + + if (tracker == null) { + persons.forEach { + it.keyPoints.forEach { key -> + key.coordinate = resizeKeypoint(key.coordinate.x, key.coordinate.y) + } + } + return persons + } else { + val trackPersons = mutableListOf<Person>() + tracker?.apply(persons, System.currentTimeMillis() * 1000)?.forEach { + val resizeKeyPoint = mutableListOf<KeyPoint>() + it.keyPoints.forEach { key -> + resizeKeyPoint.add( + KeyPoint( + key.bodyPart, + resizeKeypoint(key.coordinate.x, key.coordinate.y), + key.score + ) + ) + } + + var resizeBoundingBox: RectF? = null + it.boundingBox?.let { boundingBox -> + resizeBoundingBox = RectF( + resizeX(boundingBox.left), + resizeY(boundingBox.top), + resizeX(boundingBox.right), + resizeY(boundingBox.bottom) + ) + } + trackPersons.add(Person(it.id, resizeKeyPoint, resizeBoundingBox, it.score)) + } + return trackPersons + } + } + + /** + * Create and set tracker. + */ + fun setTracker(trackerType: TrackerType) { + tracker = when (trackerType) { + TrackerType.BOUNDING_BOX -> { + BoundingBoxTracker() + } + TrackerType.KEYPOINTS -> { + KeyPointsTracker() + } + TrackerType.OFF -> { + null + } + } + } + + /** + * Runs the TFLite model and returns a list of Person corresponding to the input image. + */ + override fun estimatePoses(bitmap: Bitmap): List<Person> { + val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos() + val inputTensor = processInputTensor(bitmap) + val outputTensor = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32) + + // If the model is dynamic, resize the input before running the interpreter. + if (type == Type.Dynamic) { + val inputShape = intArrayOf(1).plus(inputTensor.tensorBuffer.shape) + interpreter.resizeInput(0, inputShape, true) + interpreter.allocateTensors() + } + interpreter.run(inputTensor.buffer, outputTensor.buffer.rewind()) + + val processedPerson = postProcess(outputTensor.floatArray) + lastInferenceTimeNanos = + SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos + return processedPerson + } + + override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos + + /** + * Close all resources when not in use.
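+ * (Editorial usage note.) Typically called from CameraSource.close() when the preview is torn down; closing releases the GPU delegate (if any) and the interpreter, and drops the tracker reference.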
+ */ + override fun close() { + gpuDelegate?.close() + interpreter.close() + tracker = null + } +} + +enum class Type { + Dynamic, Fixed +} + +enum class TrackerType { + OFF, BOUNDING_BOX, KEYPOINTS +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/PoseClassifier.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/PoseClassifier.kt new file mode 100644 index 0000000..0a67189 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/PoseClassifier.kt @@ -0,0 +1,73 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import org.tensorflow.lite.Interpreter +import org.tensorflow.lite.examples.poseestimation.data.Person +import org.tensorflow.lite.support.common.FileUtil + +class PoseClassifier( + private val interpreter: Interpreter, + private val labels: List<String> +) { + private val input = interpreter.getInputTensor(0).shape() + private val output = interpreter.getOutputTensor(0).shape() + + companion object { + private const val MODEL_FILENAME = "classifier.tflite" + private const val LABELS_FILENAME = "labels.txt" + private const val CPU_NUM_THREADS = 4 + + fun create(context: Context): PoseClassifier { + val options = Interpreter.Options().apply { + setNumThreads(CPU_NUM_THREADS) + } + return PoseClassifier( + Interpreter( + FileUtil.loadMappedFile( + context, MODEL_FILENAME + ), options + ), + FileUtil.loadLabels(context, LABELS_FILENAME) + ) + } + } + + fun classify(person: Person?): List<Pair<String, Float>> { + // Preprocess the pose estimation result to a flat array + val inputVector = FloatArray(input[1]) + person?.keyPoints?.forEachIndexed { index, keyPoint -> + inputVector[index * 3] = keyPoint.coordinate.y + inputVector[index * 3 + 1] = keyPoint.coordinate.x + inputVector[index * 3 + 2] = keyPoint.score + } + + // Postprocess the model output to human readable class names + val outputTensor = FloatArray(output[1]) + interpreter.run(arrayOf(inputVector), arrayOf(outputTensor)) + val result = mutableListOf<Pair<String, Float>>() + outputTensor.forEachIndexed { index, score -> + result.add(Pair(labels[index], score)) + } + return result + } + + fun close() { + interpreter.close() + } +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/PoseDetector.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/PoseDetector.kt new file mode 100644 index 0000000..f8e81a8 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/PoseDetector.kt @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.graphics.Bitmap +import org.tensorflow.lite.examples.poseestimation.data.Person + +interface PoseDetector : AutoCloseable { + + fun estimatePoses(bitmap: Bitmap): List<Person> + + fun lastInferenceTimeNanos(): Long +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/PoseNet.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/PoseNet.kt new file mode 100644 index 0000000..fff4470 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/ml/PoseNet.kt @@ -0,0 +1,261 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.ml + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.PointF +import android.os.SystemClock +import android.util.Log +import org.tensorflow.lite.DataType +import org.tensorflow.lite.Interpreter +import org.tensorflow.lite.examples.poseestimation.data.BodyPart +import org.tensorflow.lite.examples.poseestimation.data.Device +import org.tensorflow.lite.examples.poseestimation.data.KeyPoint +import org.tensorflow.lite.examples.poseestimation.data.Person +import org.tensorflow.lite.gpu.GpuDelegate +import org.tensorflow.lite.support.common.FileUtil +import org.tensorflow.lite.support.common.ops.NormalizeOp +import org.tensorflow.lite.support.image.ImageProcessor +import org.tensorflow.lite.support.image.TensorImage +import org.tensorflow.lite.support.image.ops.ResizeOp +import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp +import kotlin.math.exp + +class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) : + PoseDetector { + + companion object { + private const val CPU_NUM_THREADS = 4 + private const val MEAN = 127.5f + private const val STD = 127.5f + private const val TAG = "Posenet" + private const val MODEL_FILENAME = "posenet.tflite" + + fun create(context: Context, device: Device): PoseNet { + val options = Interpreter.Options() + var gpuDelegate: GpuDelegate?
= null + options.setNumThreads(CPU_NUM_THREADS) + when (device) { + Device.CPU -> { + } + Device.GPU -> { + gpuDelegate = GpuDelegate() + options.addDelegate(gpuDelegate) + } + Device.NNAPI -> options.setUseNNAPI(true) + } + return PoseNet( + Interpreter( + FileUtil.loadMappedFile( + context, + MODEL_FILENAME + ), options + ), + gpuDelegate + ) + } + } + + private var lastInferenceTimeNanos: Long = -1 + private val inputWidth = interpreter.getInputTensor(0).shape()[1] + private val inputHeight = interpreter.getInputTensor(0).shape()[2] + private var cropHeight = 0f + private var cropWidth = 0f + private var cropSize = 0 + + @Suppress("UNCHECKED_CAST") + override fun estimatePoses(bitmap: Bitmap): List<Person> { + val estimationStartTimeNanos = SystemClock.elapsedRealtimeNanos() + val inputArray = arrayOf(processInputImage(bitmap).tensorBuffer.buffer) + Log.i( + TAG, + String.format( + "Scaling to [-1,1] took %.2f ms", + (SystemClock.elapsedRealtimeNanos() - estimationStartTimeNanos) / 1_000_000f + ) + ) + + val outputMap = initOutputMap(interpreter) + + val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos() + interpreter.runForMultipleInputsOutputs(inputArray, outputMap) + lastInferenceTimeNanos = SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos + Log.i( + TAG, + String.format("Interpreter took %.2f ms", 1.0f * lastInferenceTimeNanos / 1_000_000) + ) + + val heatmaps = outputMap[0] as Array<Array<Array<FloatArray>>> + val offsets = outputMap[1] as Array<Array<Array<FloatArray>>> + + val postProcessingStartTimeNanos = SystemClock.elapsedRealtimeNanos() + val person = postProcessModelOutputs(heatmaps, offsets) + Log.i( + TAG, + String.format( + "Postprocessing took %.2f ms", + (SystemClock.elapsedRealtimeNanos() - postProcessingStartTimeNanos) / 1_000_000f + ) + ) + + return listOf(person) + } + + /** + * Convert heatmaps and offsets output of PoseNet into a list of keypoints + */ + private fun postProcessModelOutputs( + heatmaps: Array<Array<Array<FloatArray>>>, + offsets: Array<Array<Array<FloatArray>>> + ): Person { + val height = heatmaps[0].size + val width = heatmaps[0][0].size + val numKeypoints = heatmaps[0][0][0].size + + // Finds the (row, col) locations of where the keypoints are most likely to be. + val keypointPositions = Array(numKeypoints) { Pair(0, 0) } + for (keypoint in 0 until numKeypoints) { + var maxVal = heatmaps[0][0][0][keypoint] + var maxRow = 0 + var maxCol = 0 + for (row in 0 until height) { + for (col in 0 until width) { + if (heatmaps[0][row][col][keypoint] > maxVal) { + maxVal = heatmaps[0][row][col][keypoint] + maxRow = row + maxCol = col + } + } + } + keypointPositions[keypoint] = Pair(maxRow, maxCol) + } + + // Calculating the x and y coordinates of the keypoints with offset adjustment.
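+ // Editorial note on the math below: each heatmap cell (row, col) is mapped to + // input-tensor coordinates as cell / (gridSize - 1) * inputSize + offset, then + // scaled back to the original image by cropSize / inputSize, and finally shifted + // by half of the crop padding that processInputImage added along one axis.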
+ val xCoords = IntArray(numKeypoints) + val yCoords = IntArray(numKeypoints) + val confidenceScores = FloatArray(numKeypoints) + keypointPositions.forEachIndexed { idx, position -> + val positionY = keypointPositions[idx].first + val positionX = keypointPositions[idx].second + + val inputImageCoordinateY = + position.first / (height - 1).toFloat() * inputHeight + offsets[0][positionY][positionX][idx] + val ratioHeight = cropSize.toFloat() / inputHeight + val paddingHeight = cropHeight / 2 + yCoords[idx] = (inputImageCoordinateY * ratioHeight - paddingHeight).toInt() + + val inputImageCoordinateX = + position.second / (width - 1).toFloat() * inputWidth + offsets[0][positionY][positionX][idx + numKeypoints] + val ratioWidth = cropSize.toFloat() / inputWidth + val paddingWidth = cropWidth / 2 + xCoords[idx] = (inputImageCoordinateX * ratioWidth - paddingWidth).toInt() + + confidenceScores[idx] = sigmoid(heatmaps[0][positionY][positionX][idx]) + } + + val keypointList = mutableListOf<KeyPoint>() + var totalScore = 0.0f + enumValues<BodyPart>().forEachIndexed { idx, it -> + keypointList.add( + KeyPoint( + it, + PointF(xCoords[idx].toFloat(), yCoords[idx].toFloat()), + confidenceScores[idx] + ) + ) + totalScore += confidenceScores[idx] + } + return Person(keyPoints = keypointList.toList(), score = totalScore / numKeypoints) + } + + override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos + + override fun close() { + gpuDelegate?.close() + interpreter.close() + } + + /** + * Scale and crop the input image to a TensorImage. + */ + private fun processInputImage(bitmap: Bitmap): TensorImage { + // reset crop width and height + cropWidth = 0f + cropHeight = 0f + cropSize = if (bitmap.width > bitmap.height) { + cropHeight = (bitmap.width - bitmap.height).toFloat() + bitmap.width + } else { + cropWidth = (bitmap.height - bitmap.width).toFloat() + bitmap.height + } + + val imageProcessor = ImageProcessor.Builder().apply { + add(ResizeWithCropOrPadOp(cropSize, cropSize)) + add(ResizeOp(inputWidth, inputHeight, ResizeOp.ResizeMethod.BILINEAR)) + add(NormalizeOp(MEAN, STD)) + }.build() + val tensorImage = TensorImage(DataType.FLOAT32) + tensorImage.load(bitmap) + return imageProcessor.process(tensorImage) + } + + /** + * Initializes an outputMap of 1 * x * y * z FloatArrays for the model processing to populate. + */ + private fun initOutputMap(interpreter: Interpreter): HashMap<Int, Any> { + val outputMap = HashMap<Int, Any>() + + // 1 * 9 * 9 * 17 contains heatmaps + val heatmapsShape = interpreter.getOutputTensor(0).shape() + outputMap[0] = Array(heatmapsShape[0]) { + Array(heatmapsShape[1]) { + Array(heatmapsShape[2]) { FloatArray(heatmapsShape[3]) } + } + } + + // 1 * 9 * 9 * 34 contains offsets + val offsetsShape = interpreter.getOutputTensor(1).shape() + outputMap[1] = Array(offsetsShape[0]) { + Array(offsetsShape[1]) { Array(offsetsShape[2]) { FloatArray(offsetsShape[3]) } } + } + + // 1 * 9 * 9 * 32 contains forward displacements + val displacementsFwdShape = interpreter.getOutputTensor(2).shape() + outputMap[2] = Array(displacementsFwdShape[0]) { + Array(displacementsFwdShape[1]) { + Array(displacementsFwdShape[2]) { FloatArray(displacementsFwdShape[3]) } + } + } + + // 1 * 9 * 9 * 32 contains backward displacements + val displacementsBwdShape = interpreter.getOutputTensor(3).shape() + outputMap[3] = Array(displacementsBwdShape[0]) { + Array(displacementsBwdShape[1]) { + Array(displacementsBwdShape[2]) { FloatArray(displacementsBwdShape[3]) } + } + } + + return outputMap + } + + /** Returns value within [0,1].
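+ * For example, sigmoid(0f) == 0.5f and large positive raw heatmap values approach 1f, which is how raw model scores become keypoint confidences here.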
*/ + private fun sigmoid(x: Float): Float { + return (1.0f / (1.0f + exp(-x))) + } +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/AbstractTracker.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/AbstractTracker.kt new file mode 100644 index 0000000..d793016 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/AbstractTracker.kt @@ -0,0 +1,158 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.tracker + +import org.tensorflow.lite.examples.poseestimation.data.Person + +abstract class AbstractTracker(val config: TrackerConfig) { + + private val maxAge = config.maxAge * 1000 // convert milliseconds to microseconds + private var nextTrackId = 0 + var tracks = mutableListOf<Track>() + private set + + /** + * Computes pairwise similarity scores between detections and tracks, based + * on detected features. + * @param persons A list of detected persons. + * @returns A list of shape [num_det, num_tracks] with pairwise + * similarity scores between detections and tracks. + */ + abstract fun computeSimilarity(persons: List<Person>): List<List<Float>> + + /** + * Tracks person instances across frames based on detections. + * @param persons A list of persons + * @param timestamp The current timestamp in microseconds + * @return An updated list of persons with tracking id. + */ + fun apply(persons: List<Person>, timestamp: Long): List<Person> { + tracks = filterOldTrack(timestamp).toMutableList() + val simMatrix = computeSimilarity(persons) + assignTrack(persons, simMatrix, timestamp) + tracks = updateTrack().toMutableList() + return persons + } + + /** + * Clears all tracks in the track list. + */ + fun reset() { + tracks.clear() + } + + /** + * Returns the next track id + */ + private fun nextTrackID() = ++nextTrackId + + /** + * Performs a greedy optimization to link detections with tracks. The person + * list is updated in place by providing an `id` property. If incoming + * detections are not linked with existing tracks, new tracks will be created. + * @param persons A list of detected persons. It's assumed that persons are + * sorted from most confident to least confident. + * @param simMatrix A list of shape [num_det, num_tracks] with pairwise + * similarity scores between detections and tracks. + * @param timestamp The current timestamp in microseconds.
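+ * (Illustrative example.) With two detections, similarity rows [0.9, 0.2] and [0.1, 0.8] against two tracks, and minSimilarity 0.4, detection 0 links to track 0 and detection 1 links to track 1; a detection whose best remaining score falls below minSimilarity spawns a new track instead.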
+ */ + private fun assignTrack(persons: List<Person>, simMatrix: List<List<Float>>, timestamp: Long) { + if (simMatrix.size != persons.size || simMatrix.any { it.size != tracks.size }) { + throw IllegalArgumentException( + "Size of person array and similarity matrix does not match.") + } + + val unmatchedTrackIndices = MutableList(tracks.size) { it } + val unmatchedDetectionIndices = mutableListOf<Int>() + + for (detectionIndex in persons.indices) { + // If the track list is empty, add the person's index + // to unmatched detections to create a new track later. + if (unmatchedTrackIndices.isEmpty()) { + unmatchedDetectionIndices.add(detectionIndex) + continue + } + + // Assign the detection to the track which produces the highest pairwise + // similarity score, assuming the score exceeds the minimum similarity + // threshold. + var maxTrackIndex = -1 + var maxSimilarity = -1f + unmatchedTrackIndices.forEach { trackIndex -> + val similarity = simMatrix[detectionIndex][trackIndex] + if (similarity >= config.minSimilarity && similarity > maxSimilarity) { + maxTrackIndex = trackIndex + maxSimilarity = similarity + } + } + if (maxTrackIndex >= 0) { + val linkedTrack = tracks[maxTrackIndex] + tracks[maxTrackIndex] = + createTrack(persons[detectionIndex], linkedTrack.person.id, timestamp) + persons[detectionIndex].id = linkedTrack.person.id + val index = unmatchedTrackIndices.indexOf(maxTrackIndex) + unmatchedTrackIndices.removeAt(index) + } else { + unmatchedDetectionIndices.add(detectionIndex) + } + } + + // Spawn new tracks for all unmatched detections. + unmatchedDetectionIndices.forEach { detectionIndex -> + val newTrack = createTrack(persons[detectionIndex], timestamp = timestamp) + tracks.add(newTrack) + persons[detectionIndex].id = newTrack.person.id + } + } + + /** + * Filters out tracks whose most recent update is older than maxAge. + * @param timestamp The current timestamp in microseconds. + */ + private fun filterOldTrack(timestamp: Long): List<Track> { + return tracks.filter { + timestamp - it.lastTimestamp <= maxAge + } + } + + /** + * Sorts the track list by timestamp (newest first) and returns at most + * config.maxTracks tracks. + */ + private fun updateTrack(): List<Track> { + tracks.sortByDescending { it.lastTimestamp } + return tracks.take(config.maxTracks) + } + + /** + * Creates a new track from a person's information. + * @param person A person. + * @param id The id assigned to the new track. If null, the next track id is used. + * @param timestamp The timestamp in microseconds. + */ + private fun createTrack(person: Person, id: Int? = null, timestamp: Long): Track { + return Track( + person = Person( + id = id ?: nextTrackID(), + keyPoints = person.keyPoints, + boundingBox = person.boundingBox, + score = person.score + ), + lastTimestamp = timestamp + ) + } +}
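To see the greedy linking in context: apply() first drops tracks older than maxAge, then links each detection (most confident first) to the unmatched track with the highest similarity at or above minSimilarity, and finally spawns fresh tracks for anything left over, writing ids back into the Person objects. A hedged sketch of per-frame usage, assuming the data.Person type from this patch (`onFrame` and `frameTimestampUs` are illustrative names, not part of the patch):

// Illustrative only: driving a tracker frame by frame.
val tracker = BoundingBoxTracker() // any AbstractTracker subclass
fun onFrame(persons: List<Person>, frameTimestampUs: Long) {
    val tracked = tracker.apply(persons, frameTimestampUs)
    // Ids are stable across frames for as long as a track stays matched.
    tracked.forEach { println("id=${it.id} score=${it.score}") }
}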
diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/BoundingBoxTracker.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/BoundingBoxTracker.kt new file mode 100644 index 0000000..6a2074a --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/BoundingBoxTracker.kt @@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.tracker + +import androidx.annotation.VisibleForTesting +import org.tensorflow.lite.examples.poseestimation.data.Person +import kotlin.math.max +import kotlin.math.min + +/** + * BoundingBoxTracker, which tracks objects based on bounding box similarity, + * currently defined as intersection-over-union (IoU). + */ +class BoundingBoxTracker(config: TrackerConfig = TrackerConfig()) : AbstractTracker(config) { + + /** + * Computes similarity based on intersection-over-union (IoU). See `AbstractTracker` + * for more details. + */ + override fun computeSimilarity(persons: List<Person>): List<List<Float>> { + if (persons.isEmpty() && tracks.isEmpty()) { + return emptyList() + } + return persons.map { person -> tracks.map { track -> iou(person, track.person) } } + } + + /** + * Computes the intersection-over-union (IoU) between a person and a track person. + * @param person1 A person. + * @param person2 A track person. + * @return The IoU between the person and the track person. This number is + * between 0 and 1, and larger values indicate more box similarity. + */ + @VisibleForTesting(otherwise = VisibleForTesting.PRIVATE) + fun iou(person1: Person, person2: Person): Float { + if (person1.boundingBox != null && person2.boundingBox != null) { + val xMin = max(person1.boundingBox.left, person2.boundingBox.left) + val yMin = max(person1.boundingBox.top, person2.boundingBox.top) + val xMax = min(person1.boundingBox.right, person2.boundingBox.right) + val yMax = min(person1.boundingBox.bottom, person2.boundingBox.bottom) + if (xMin >= xMax || yMin >= yMax) return 0f + val intersection = (xMax - xMin) * (yMax - yMin) + val areaPerson = person1.boundingBox.width() * person1.boundingBox.height() + val areaTrack = person2.boundingBox.width() * person2.boundingBox.height() + return intersection / (areaPerson + areaTrack - intersection) + } + return 0f + } +}
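As a quick sanity check of the IoU arithmetic (made-up boxes, not from the patch): take box A = (0,0)-(2,2) and box B = (1,0)-(3,2), each with area 4:

// Illustrative only: mirrors the iou() computation above.
fun main() {
    val inter = (minOf(2f, 3f) - maxOf(0f, 1f)) * (minOf(2f, 2f) - maxOf(0f, 0f)) // 1 * 2 = 2
    val iou = inter / (4f + 4f - inter) // 2 / 6 ≈ 0.33
    println(iou) // clears the default minSimilarity of 0.15, so the detection links
}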
diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/KeyPointsTracker.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/KeyPointsTracker.kt new file mode 100644 index 0000000..388a874 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/KeyPointsTracker.kt @@ -0,0 +1,125 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.tracker + +import androidx.annotation.VisibleForTesting +import org.tensorflow.lite.examples.poseestimation.data.KeyPoint +import org.tensorflow.lite.examples.poseestimation.data.Person +import kotlin.math.exp +import kotlin.math.max +import kotlin.math.min +import kotlin.math.pow + +/** + * KeyPointsTracker, which tracks poses based on keypoint similarity. This + * tracker assumes that keypoints are provided in normalized image + * coordinates. + */ +class KeyPointsTracker( + trackerConfig: TrackerConfig = TrackerConfig( + keyPointsTrackerParams = KeyPointsTrackerParams() + ) +) : AbstractTracker(trackerConfig) { + /** + * Computes similarity based on Object Keypoint Similarity (OKS). It's + * assumed that the keypoints within each person are in normalized image + * coordinates. See `AbstractTracker` for more details. + */ + override fun computeSimilarity(persons: List<Person>): List<List<Float>> { + if (persons.isEmpty() && tracks.isEmpty()) { + return emptyList() + } + val simMatrix = mutableListOf<MutableList<Float>>() + persons.forEach { person -> + val row = mutableListOf<Float>() + tracks.forEach { track -> + val oksValue = oks(person, track.person) + row.add(oksValue) + } + simMatrix.add(row) + } + return simMatrix + } + + /** + * Computes the Object Keypoint Similarity (OKS) between a person and track person. + * This is similar in spirit to the calculation used by COCO keypoint eval: + * https://cocodataset.org/#keypoints-eval + * In this case, OKS is calculated as: + * (1/sum_i d(c_i, c_ti)) * sum_i exp(-d_i^2/(2*a_ti*x_i^2))*d(c_i, c_ti) + * where + * d(x, y) is an indicator function which only produces 1 if x and y + * exceed a given threshold (i.e. keypointThreshold), otherwise 0. + * c_i is the confidence of keypoint i from the new person + * c_ti is the confidence of keypoint i from the track person + * d_i is the Euclidean distance between the person and track person keypoint + * a_ti is the area of the track object (the box covering the keypoints) + * x_i is a constant that controls falloff in a Gaussian distribution, + * computed as 2*keypointFalloff[i]. + * @param person1 A person. + * @param person2 A track person. + * @return The OKS score between the person and the track person. This number is + * between 0 and 1, and larger values indicate more keypoint similarity. + */ + @VisibleForTesting(otherwise = VisibleForTesting.PRIVATE) + fun oks(person1: Person, person2: Person): Float { + if (config.keyPointsTrackerParams == null) return 0f + person2.keyPoints.let { keyPoints -> + val boxArea = area(keyPoints) + 1e-6 + var oksTotal = 0f + var numValidKeyPoints = 0 + + person1.keyPoints.forEachIndexed { index, _ -> + val poseKpt = person1.keyPoints[index] + val trackKpt = person2.keyPoints[index] + val threshold = config.keyPointsTrackerParams.keypointThreshold + if (poseKpt.score < threshold || trackKpt.score < threshold) { + return@forEachIndexed + } + numValidKeyPoints += 1 + val dSquared: Float = + (poseKpt.coordinate.x - trackKpt.coordinate.x).pow(2) + (poseKpt.coordinate.y - trackKpt.coordinate.y).pow( + 2 + ) + val x = 2f * config.keyPointsTrackerParams.keypointFalloff[index] + oksTotal += exp(-1f * dSquared / (2f * boxArea * x.pow(2))).toFloat() + } + if (numValidKeyPoints < config.keyPointsTrackerParams.minNumKeyPoints) { + return 0f + } + return oksTotal / numValidKeyPoints + } + } + + /** + * Computes the area of a bounding box that tightly covers keypoints.
+ * @param keyPoints A list of KeyPoint. + * @return The area of the bounding box covering the valid keypoints. + */ + @VisibleForTesting(otherwise = VisibleForTesting.PRIVATE) + fun area(keyPoints: List<KeyPoint>): Float { + val validKeypoint = keyPoints.filter { + it.score > (config.keyPointsTrackerParams?.keypointThreshold ?: 0f) + } + if (validKeypoint.isEmpty()) return 0f + val minX = min(1f, validKeypoint.minOf { it.coordinate.x }) + val maxX = max(0f, validKeypoint.maxOf { it.coordinate.x }) + val minY = min(1f, validKeypoint.minOf { it.coordinate.y }) + val maxY = max(0f, validKeypoint.maxOf { it.coordinate.y }) + return (maxX - minX) * (maxY - minY) + } +} diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/Track.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/Track.kt new file mode 100644 index 0000000..71cf895 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/Track.kt @@ -0,0 +1,24 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.tracker + +import org.tensorflow.lite.examples.poseestimation.data.Person + +data class Track( + val person: Person, + val lastTimestamp: Long +)
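For a feel of the OKS falloff (made-up numbers in normalized coordinates, not from the patch): a nose keypoint (falloff 0.026) displaced by 0.01 against a track whose keypoint box covers 5% of the image contributes roughly 0.69:

// Illustrative only: one keypoint's contribution inside oks() above.
import kotlin.math.exp

fun main() {
    val boxArea = 0.05f          // area of the box covering the track's keypoints
    val x = 2f * 0.026f          // 2 * keypointFalloff for the nose
    val dSquared = 0.01f * 0.01f // squared keypoint displacement
    println(exp(-dSquared / (2f * boxArea * x * x))) // ≈ 0.69
}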
diff --git a/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/TrackerConfig.kt b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/TrackerConfig.kt new file mode 100644 index 0000000..e8cc998 --- /dev/null +++ b/android1/app/src/main/java/org/tensorflow/lite/examples/poseestimation/tracker/TrackerConfig.kt @@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +package org.tensorflow.lite.examples.poseestimation.tracker + +data class TrackerConfig( + val maxTracks: Int = MAX_TRACKS, + val maxAge: Int = MAX_AGE, + val minSimilarity: Float = MIN_SIMILARITY, + val keyPointsTrackerParams: KeyPointsTrackerParams? = null +) { + companion object { + private const val MAX_TRACKS = 18 + private const val MAX_AGE = 1000 // milliseconds + private const val MIN_SIMILARITY = 0.15f + } +} + +data class KeyPointsTrackerParams( + val keypointThreshold: Float = KEYPOINT_THRESHOLD, + // List of per-keypoint standard deviations `σ`; keypoints on a person's body (shoulders, knees, hips, etc.) + // tend to have a much larger `σ` than those on the head (eyes, nose, ears). + // Read more at: https://cocodataset.org/#keypoints-eval + val keypointFalloff: List<Float> = KEYPOINT_FALLOFF, + val minNumKeyPoints: Int = MIN_NUM_KEYPOINT +) { + companion object { + // From COCO: + // https://cocodataset.org/#keypoints-eval + private val KEYPOINT_FALLOFF: List<Float> = listOf( + 0.026f, 0.025f, 0.025f, 0.035f, 0.035f, 0.079f, 0.079f, 0.072f, 0.072f, 0.062f, + 0.062f, 0.107f, 0.107f, 0.087f, 0.087f, 0.089f, 0.089f + ) + private const val KEYPOINT_THRESHOLD = 0.3f + private const val MIN_NUM_KEYPOINT = 4 + } +}
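Since both trackers take a TrackerConfig, tuning is just constructor arguments; a hedged sketch (the values are illustrative, not recommendations from the patch):

// Illustrative only: custom configurations for the two trackers added above.
val iouTracker = BoundingBoxTracker(
    TrackerConfig(maxTracks = 10, maxAge = 500, minSimilarity = 0.3f)
)
val oksTracker = KeyPointsTracker(
    TrackerConfig(
        minSimilarity = 0.2f,
        keyPointsTrackerParams = KeyPointsTrackerParams(minNumKeyPoints = 6)
    )
)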
diff --git a/android1/app/src/main/res/drawable-hdpi/ic_launcher.png b/android1/app/src/main/res/drawable-hdpi/ic_launcher.png new file mode 100644 index 0000000..de511b0 Binary files /dev/null and b/android1/app/src/main/res/drawable-hdpi/ic_launcher.png differ diff --git a/android1/app/src/main/res/drawable-mdpi/ic_launcher.png b/android1/app/src/main/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 0000000..5d50224 Binary files /dev/null and b/android1/app/src/main/res/drawable-mdpi/ic_launcher.png differ diff --git a/android1/app/src/main/res/drawable-xxhdpi/icn_chevron_up.png b/android1/app/src/main/res/drawable-xxhdpi/icn_chevron_up.png new file mode 100644 index 0000000..1ec6a07 Binary files /dev/null and b/android1/app/src/main/res/drawable-xxhdpi/icn_chevron_up.png differ diff --git a/android1/app/src/main/res/drawable/rounded_edge.xml b/android1/app/src/main/res/drawable/rounded_edge.xml new file mode 100644 index 0000000..83f90ff --- /dev/null +++ b/android1/app/src/main/res/drawable/rounded_edge.xml @@ -0,0 +1,18 @@ + + + + + + + + + + diff --git a/android1/app/src/main/res/drawable/tfl2_logo.png b/android1/app/src/main/res/drawable/tfl2_logo.png new file mode 100644 index 0000000..48c5f33 Binary files /dev/null and b/android1/app/src/main/res/drawable/tfl2_logo.png differ diff --git a/android1/app/src/main/res/layout/activity_main.xml b/android1/app/src/main/res/layout/activity_main.xml new file mode 100644 index 0000000..bef90cf --- /dev/null +++ b/android1/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,27 @@ + + + + + + + + + + + + diff --git a/android1/app/src/main/res/layout/bottom_sheet_layout.xml b/android1/app/src/main/res/layout/bottom_sheet_layout.xml new file mode 100644 index 0000000..9578722 --- /dev/null +++ b/android1/app/src/main/res/layout/bottom_sheet_layout.xml @@ -0,0 +1,124 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/android1/app/src/main/res/values/colors.xml b/android1/app/src/main/res/values/colors.xml new file mode 100644 index 0000000..648cfde --- /dev/null +++ b/android1/app/src/main/res/values/colors.xml @@ -0,0 +1,9 @@ + + #FF6200EE + #FF3700B3 + #FF03DAC5 + #FF018786 + #FF000000 + #FFFFFFFF + diff --git a/android1/app/src/main/res/values/strings.xml b/android1/app/src/main/res/values/strings.xml new file mode 100644 index 0000000..79bfc21 --- /dev/null +++ b/android1/app/src/main/res/values/strings.xml @@ -0,0 +1,30 @@ + + TFL Pose Estimation + This app needs camera permission. + Score: %.2f + Model: + Fps: %d + Device: + - %s + Pose Classification + Tracker: + Movenet MultiPose does not support GPU. Falling back to CPU. + + Movenet Lightning + Movenet Thunder + Movenet MultiPose + Posenet + + + + CPU + GPU + NNAPI + + + + Off + BoundingBox + Keypoint + + diff --git a/android1/app/src/main/res/values/themes.xml b/android1/app/src/main/res/values/themes.xml new file mode 100644 index 0000000..0414b39 --- /dev/null +++ b/android1/app/src/main/res/values/themes.xml @@ -0,0 +1,18 @@ + + + + diff --git a/android1/build.gradle b/android1/build.gradle new file mode 100644 index 0000000..51d10a1 --- /dev/null +++ b/android1/build.gradle @@ -0,0 +1,30 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. +buildscript { + ext.kotlin_version = "1.9.21" + repositories { + google() + mavenCentral() + } + dependencies { + classpath "com.android.tools.build:gradle:8.0.2" + classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + mavenCentral() + maven { + name 'ossrh-snapshot' + url 'https://oss.sonatype.org/content/repositories/snapshots' + } + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} diff --git a/android1/gradle.properties b/android1/gradle.properties new file mode 100644 index 0000000..cac7c68 --- /dev/null +++ b/android1/gradle.properties @@ -0,0 +1,21 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects.
More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +# AndroidX package structure to make it clearer which packages are bundled with the +# Android operating system, and which are packaged with your app's APK +# https://developer.android.com/topic/libraries/support-library/androidx-rn +android.useAndroidX=true +# Automatically convert third-party libraries to use AndroidX +android.enableJetifier=true +# Kotlin code style for this project: "official" or "obsolete": +kotlin.code.style=official diff --git a/android1/gradle/wrapper/gradle-wrapper.jar b/android1/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000..d64cd49 Binary files /dev/null and b/android1/gradle/wrapper/gradle-wrapper.jar differ diff --git a/android1/gradle/wrapper/gradle-wrapper.properties b/android1/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000..1af9e09 --- /dev/null +++ b/android1/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/android1/gradlew b/android1/gradlew new file mode 100755 index 0000000..1aa94a4 --- /dev/null +++ b/android1/gradlew @@ -0,0 +1,249 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java.
+# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. 
+ +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/android1/gradlew.bat b/android1/gradlew.bat new file mode 100755 index 0000000..93e3f59 --- /dev/null +++ b/android1/gradlew.bat @@ -0,0 +1,92 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. 
+@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/android1/local.properties b/android1/local.properties new file mode 100644 index 0000000..5112cf9 --- /dev/null +++ b/android1/local.properties @@ -0,0 +1,8 @@ +## This file must *NOT* be checked into Version Control Systems, +# as it contains information specific to your local configuration. +# +# Location of the SDK. This is only used by Gradle. +# For customization when using a Version Control System, please read the +# header note. +#Mon Apr 21 08:12:48 CST 2025 +sdk.dir=/Users/ziyue/Library/Android/sdk diff --git a/android1/posenetimage.png b/android1/posenetimage.png new file mode 100644 index 0000000..3f07229 Binary files /dev/null and b/android1/posenetimage.png differ diff --git a/android1/settings.gradle b/android1/settings.gradle new file mode 100644 index 0000000..2aebe71 --- /dev/null +++ b/android1/settings.gradle @@ -0,0 +1,2 @@ +include ':app' +rootProject.name = "TFLite Pose Estimation"