main
田子悦 3 weeks ago
parent fec3b0f056
commit 46c49dd1d9

BIN
.DS_Store vendored

Binary file not shown.

Binary file not shown.

@ -0,0 +1,2 @@
#Mon Apr 21 08:14:16 CST 2025
gradle.version=8.5

@ -0,0 +1,2 @@
#Mon Apr 21 08:12:48 CST 2025
java.home=/Applications/Android Studio.app/Contents/jbr/Contents/Home

Binary file not shown.

@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml

@ -0,0 +1 @@
TFLite Pose Estimation

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AndroidProjectSystem">
<option name="providerId" value="com.android.tools.idea.GradleProjectSystem" />
</component>
</project>

@ -0,0 +1,607 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DeviceStreaming">
<option name="deviceSelectionList">
<list>
<PersistentDeviceSelectionData>
<option name="api" value="27" />
<option name="brand" value="DOCOMO" />
<option name="codename" value="F01L" />
<option name="id" value="F01L" />
<option name="labId" value="google" />
<option name="manufacturer" value="FUJITSU" />
<option name="name" value="F-01L" />
<option name="screenDensity" value="360" />
<option name="screenX" value="720" />
<option name="screenY" value="1280" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="OnePlus" />
<option name="codename" value="OP5552L1" />
<option name="id" value="OP5552L1" />
<option name="labId" value="google" />
<option name="manufacturer" value="OnePlus" />
<option name="name" value="CPH2415" />
<option name="screenDensity" value="480" />
<option name="screenX" value="1080" />
<option name="screenY" value="2412" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="OPPO" />
<option name="codename" value="OP573DL1" />
<option name="id" value="OP573DL1" />
<option name="labId" value="google" />
<option name="manufacturer" value="OPPO" />
<option name="name" value="CPH2557" />
<option name="screenDensity" value="480" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="28" />
<option name="brand" value="DOCOMO" />
<option name="codename" value="SH-01L" />
<option name="id" value="SH-01L" />
<option name="labId" value="google" />
<option name="manufacturer" value="SHARP" />
<option name="name" value="AQUOS sense2 SH-01L" />
<option name="screenDensity" value="480" />
<option name="screenX" value="1080" />
<option name="screenY" value="2160" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="Lenovo" />
<option name="codename" value="TB370FU" />
<option name="formFactor" value="Tablet" />
<option name="id" value="TB370FU" />
<option name="labId" value="google" />
<option name="manufacturer" value="Lenovo" />
<option name="name" value="Tab P12" />
<option name="screenDensity" value="340" />
<option name="screenX" value="1840" />
<option name="screenY" value="2944" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="a15" />
<option name="id" value="a15" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="A15" />
<option name="screenDensity" value="450" />
<option name="screenX" value="1080" />
<option name="screenY" value="2340" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="a35x" />
<option name="id" value="a35x" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="A35" />
<option name="screenDensity" value="450" />
<option name="screenX" value="1080" />
<option name="screenY" value="2340" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="31" />
<option name="brand" value="samsung" />
<option name="codename" value="a51" />
<option name="id" value="a51" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy A51" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="google" />
<option name="codename" value="akita" />
<option name="id" value="akita" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 8a" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="motorola" />
<option name="codename" value="arcfox" />
<option name="id" value="arcfox" />
<option name="labId" value="google" />
<option name="manufacturer" value="Motorola" />
<option name="name" value="razr plus 2024" />
<option name="screenDensity" value="360" />
<option name="screenX" value="1080" />
<option name="screenY" value="1272" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="motorola" />
<option name="codename" value="austin" />
<option name="id" value="austin" />
<option name="labId" value="google" />
<option name="manufacturer" value="Motorola" />
<option name="name" value="moto g 5G (2022)" />
<option name="screenDensity" value="280" />
<option name="screenX" value="720" />
<option name="screenY" value="1600" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="samsung" />
<option name="codename" value="b0q" />
<option name="id" value="b0q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy S22 Ultra" />
<option name="screenDensity" value="600" />
<option name="screenX" value="1440" />
<option name="screenY" value="3088" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="32" />
<option name="brand" value="google" />
<option name="codename" value="bluejay" />
<option name="id" value="bluejay" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 6a" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="google" />
<option name="codename" value="caiman" />
<option name="id" value="caiman" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 9 Pro" />
<option name="screenDensity" value="360" />
<option name="screenX" value="960" />
<option name="screenY" value="2142" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="google" />
<option name="codename" value="comet" />
<option name="default" value="true" />
<option name="id" value="comet" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 9 Pro Fold" />
<option name="screenDensity" value="390" />
<option name="screenX" value="2076" />
<option name="screenY" value="2152" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="29" />
<option name="brand" value="samsung" />
<option name="codename" value="crownqlteue" />
<option name="id" value="crownqlteue" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy Note9" />
<option name="screenDensity" value="420" />
<option name="screenX" value="2220" />
<option name="screenY" value="1080" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="dm2q" />
<option name="id" value="dm2q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="S23 Plus" />
<option name="screenDensity" value="450" />
<option name="screenX" value="1080" />
<option name="screenY" value="2340" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="dm3q" />
<option name="id" value="dm3q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy S23 Ultra" />
<option name="screenDensity" value="600" />
<option name="screenX" value="1440" />
<option name="screenY" value="3088" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="e1q" />
<option name="default" value="true" />
<option name="id" value="e1q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy S24" />
<option name="screenDensity" value="480" />
<option name="screenX" value="1080" />
<option name="screenY" value="2340" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="e3q" />
<option name="id" value="e3q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy S24 Ultra" />
<option name="screenDensity" value="450" />
<option name="screenX" value="1440" />
<option name="screenY" value="3120" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="google" />
<option name="codename" value="eos" />
<option name="id" value="eos" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Eos" />
<option name="screenDensity" value="320" />
<option name="screenX" value="384" />
<option name="screenY" value="384" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="google" />
<option name="codename" value="felix" />
<option name="id" value="felix" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel Fold" />
<option name="screenDensity" value="420" />
<option name="screenX" value="2208" />
<option name="screenY" value="1840" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="google" />
<option name="codename" value="felix" />
<option name="id" value="felix" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel Fold" />
<option name="screenDensity" value="420" />
<option name="screenX" value="2208" />
<option name="screenY" value="1840" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="google" />
<option name="codename" value="felix_camera" />
<option name="id" value="felix_camera" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel Fold (Camera-enabled)" />
<option name="screenDensity" value="420" />
<option name="screenX" value="2208" />
<option name="screenY" value="1840" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="motorola" />
<option name="codename" value="fogona" />
<option name="id" value="fogona" />
<option name="labId" value="google" />
<option name="manufacturer" value="Motorola" />
<option name="name" value="moto g play - 2024" />
<option name="screenDensity" value="280" />
<option name="screenX" value="720" />
<option name="screenY" value="1600" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="g0q" />
<option name="id" value="g0q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="SM-S906U1" />
<option name="screenDensity" value="450" />
<option name="screenX" value="1080" />
<option name="screenY" value="2340" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="gta9pwifi" />
<option name="id" value="gta9pwifi" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="SM-X210" />
<option name="screenDensity" value="240" />
<option name="screenX" value="1200" />
<option name="screenY" value="1920" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="gts7xllite" />
<option name="id" value="gts7xllite" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="SM-T738U" />
<option name="screenDensity" value="340" />
<option name="screenX" value="1600" />
<option name="screenY" value="2560" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="samsung" />
<option name="codename" value="gts8uwifi" />
<option name="formFactor" value="Tablet" />
<option name="id" value="gts8uwifi" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy Tab S8 Ultra" />
<option name="screenDensity" value="320" />
<option name="screenX" value="1848" />
<option name="screenY" value="2960" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="gts8wifi" />
<option name="formFactor" value="Tablet" />
<option name="id" value="gts8wifi" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy Tab S8" />
<option name="screenDensity" value="274" />
<option name="screenX" value="1600" />
<option name="screenY" value="2560" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="gts9fe" />
<option name="id" value="gts9fe" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy Tab S9 FE 5G" />
<option name="screenDensity" value="280" />
<option name="screenX" value="1440" />
<option name="screenY" value="2304" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="google" />
<option name="codename" value="husky" />
<option name="id" value="husky" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 8 Pro" />
<option name="screenDensity" value="390" />
<option name="screenX" value="1008" />
<option name="screenY" value="2244" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="30" />
<option name="brand" value="motorola" />
<option name="codename" value="java" />
<option name="id" value="java" />
<option name="labId" value="google" />
<option name="manufacturer" value="Motorola" />
<option name="name" value="G20" />
<option name="screenDensity" value="280" />
<option name="screenX" value="720" />
<option name="screenY" value="1600" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="google" />
<option name="codename" value="komodo" />
<option name="id" value="komodo" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 9 Pro XL" />
<option name="screenDensity" value="360" />
<option name="screenX" value="1008" />
<option name="screenY" value="2244" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="google" />
<option name="codename" value="lynx" />
<option name="id" value="lynx" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 7a" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="motorola" />
<option name="codename" value="maui" />
<option name="id" value="maui" />
<option name="labId" value="google" />
<option name="manufacturer" value="Motorola" />
<option name="name" value="moto g play - 2023" />
<option name="screenDensity" value="280" />
<option name="screenX" value="720" />
<option name="screenY" value="1600" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="o1q" />
<option name="id" value="o1q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy S21" />
<option name="screenDensity" value="421" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="31" />
<option name="brand" value="google" />
<option name="codename" value="oriole" />
<option name="id" value="oriole" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 6" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="google" />
<option name="codename" value="panther" />
<option name="id" value="panther" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 7" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="q5q" />
<option name="id" value="q5q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy Z Fold5" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1812" />
<option name="screenY" value="2176" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="q6q" />
<option name="id" value="q6q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy Z Fold6" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1856" />
<option name="screenY" value="2160" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="30" />
<option name="brand" value="google" />
<option name="codename" value="r11" />
<option name="formFactor" value="Wear OS" />
<option name="id" value="r11" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel Watch" />
<option name="screenDensity" value="320" />
<option name="screenX" value="384" />
<option name="screenY" value="384" />
<option name="type" value="WEAR_OS" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="r11q" />
<option name="id" value="r11q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="SM-S711U" />
<option name="screenDensity" value="450" />
<option name="screenX" value="1080" />
<option name="screenY" value="2340" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="30" />
<option name="brand" value="google" />
<option name="codename" value="redfin" />
<option name="id" value="redfin" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 5" />
<option name="screenDensity" value="440" />
<option name="screenX" value="1080" />
<option name="screenY" value="2340" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="google" />
<option name="codename" value="shiba" />
<option name="id" value="shiba" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 8" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="samsung" />
<option name="codename" value="t2q" />
<option name="id" value="t2q" />
<option name="labId" value="google" />
<option name="manufacturer" value="Samsung" />
<option name="name" value="Galaxy S21 Plus" />
<option name="screenDensity" value="394" />
<option name="screenX" value="1080" />
<option name="screenY" value="2400" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="33" />
<option name="brand" value="google" />
<option name="codename" value="tangorpro" />
<option name="formFactor" value="Tablet" />
<option name="id" value="tangorpro" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel Tablet" />
<option name="screenDensity" value="320" />
<option name="screenX" value="1600" />
<option name="screenY" value="2560" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="34" />
<option name="brand" value="google" />
<option name="codename" value="tokay" />
<option name="default" value="true" />
<option name="id" value="tokay" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 9" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1080" />
<option name="screenY" value="2424" />
</PersistentDeviceSelectionData>
<PersistentDeviceSelectionData>
<option name="api" value="35" />
<option name="brand" value="google" />
<option name="codename" value="tokay" />
<option name="default" value="true" />
<option name="id" value="tokay" />
<option name="labId" value="google" />
<option name="manufacturer" value="Google" />
<option name="name" value="Pixel 9" />
<option name="screenDensity" value="420" />
<option name="screenX" value="1080" />
<option name="screenY" value="2424" />
</PersistentDeviceSelectionData>
</list>
</option>
</component>
</project>

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="CompilerConfiguration">
<bytecodeTargetLevel target="21" />
</component>
</project>

@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="deploymentTargetSelector">
<selectionStates>
<SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" />
</SelectionState>
</selectionStates>
</component>
</project>

@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GradleMigrationSettings" migrationVersion="1" />
<component name="GradleSettings">
<option name="linkedExternalProjectsSettings">
<GradleProjectSettings>
<option name="testRunner" value="CHOOSE_PER_TEST" />
<option name="externalProjectPath" value="$PROJECT_DIR$" />
<option name="gradleJvm" value="#GRADLE_LOCAL_JAVA_HOME" />
<option name="modules">
<set>
<option value="$PROJECT_DIR$" />
<option value="$PROJECT_DIR$/app" />
</set>
</option>
</GradleProjectSettings>
</option>
</component>
</project>

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="KotlinJpsPluginSettings">
<option name="version" value="1.9.21" />
</component>
</project>

@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectMigrations">
<option name="MigrateToGradleLocalJavaHome">
<set>
<option value="$PROJECT_DIR$" />
</set>
</option>
</component>
</project>

@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="jbr-21" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/build/classes" />
</component>
<component name="ProjectType">
<option name="id" value="Android" />
</component>
</project>

@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="RunConfigurationProducerService">
<option name="ignoredProducers">
<set>
<option value="com.intellij.execution.junit.AbstractAllInDirectoryConfigurationProducer" />
<option value="com.intellij.execution.junit.AllInPackageConfigurationProducer" />
<option value="com.intellij.execution.junit.PatternConfigurationProducer" />
<option value="com.intellij.execution.junit.TestInClassConfigurationProducer" />
<option value="com.intellij.execution.junit.UniqueIdConfigurationProducer" />
<option value="com.intellij.execution.junit.testDiscovery.JUnitTestDiscoveryConfigurationProducer" />
<option value="org.jetbrains.kotlin.idea.junit.KotlinJUnitRunConfigurationProducer" />
<option value="org.jetbrains.kotlin.idea.junit.KotlinPatternConfigurationProducer" />
</set>
</option>
</component>
</project>

@ -0,0 +1,73 @@
# TensorFlow Lite Pose Estimation Android Demo
### Overview
This is an app that continuously detects the body parts in the frames seen by
your device's camera. These instructions walk you through building and running
the demo on an Android device. Camera captures are discarded immediately after
use; nothing is stored or saved.
The app demonstrates how to use 4 models:
* Single pose models: The model can estimate the pose of only one person in the
input image. If the input image contains multiple persons, the detection result
can be largely incorrect.
  * PoseNet
  * MoveNet Lightning
  * MoveNet Thunder
* Multi pose models: The model can estimate the poses of multiple persons in the
input image.
  * MoveNet MultiPose: Supports up to 6 persons.
See this [blog post](https://blog.tensorflow.org/2021/05/next-generation-pose-detection-with-movenet-and-tensorflowjs.html)
for a comparison between these models.
![Demo Image](posenetimage.png)
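For reference, the instrumented tests added in this commit create the detectors through small factory methods. The Kotlin sketch below is illustrative only (it is not part of the upstream README; the `Context` and `Bitmap` arguments are placeholders supplied by the caller), but the class and method names match the test sources in this change:

```kotlin
import android.content.Context
import android.graphics.Bitmap
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.examples.poseestimation.ml.ModelType
import org.tensorflow.lite.examples.poseestimation.ml.MoveNet
import org.tensorflow.lite.examples.poseestimation.ml.MoveNetMultiPose
import org.tensorflow.lite.examples.poseestimation.ml.Type

// Single-pose path: MoveNet Lightning on the CPU. ModelType.Thunder, or
// PoseNet.create(context, Device.CPU), can be substituted for the other single-pose models.
fun detectSinglePose(context: Context, frame: Bitmap): List<Person> {
    val detector = MoveNet.create(context, Device.CPU, ModelType.Lightning)
    return detector.estimatePoses(frame)
}

// Multi-pose path: MoveNet MultiPose can return up to 6 persons per frame.
fun detectMultiplePoses(context: Context, frame: Bitmap): List<Person> {
    val detector = MoveNetMultiPose.create(context, Device.CPU, Type.Dynamic)
    return detector.estimatePoses(frame)
}
```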
## Build the demo using Android Studio
### Prerequisites
* If you don't have it already, install **[Android Studio](
https://developer.android.com/studio/index.html)** 4.2 or
above, following the instructions on the website.
* An Android device and an Android development environment with minimum API 21.
### Building
* Open Android Studio, and from the `Welcome` screen, select
`Open an existing Android Studio project`.
* From the `Open File or Project` window that appears, navigate to and select
the `lite/examples/pose_estimation/android` directory from wherever you
cloned the `tensorflow/examples` GitHub repo. Click `OK`.
* If it asks you to do a `Gradle Sync`, click `OK`.
* You may also need to install various platforms and tools, if you get errors
like `Failed to find target with hash string 'android-21'` and similar.
* Click the `Run` button (the green arrow) or select `Run` > `Run 'android'` from the
top menu. You may need to rebuild the project using `Build` > `Rebuild Project`.
* If it asks you to use `Instant Run`, click `Proceed Without Instant Run`.
* Also, you need to have an Android device plugged in with developer options
enabled at this point. See **[here](
https://developer.android.com/studio/run/device)** for more details
on setting up developer devices.
### Models used
Downloading, extracting, and placing the models in the assets folder is handled
automatically by `download.gradle`.
If you want to download the models explicitly, you can download them from here:
* [Posenet](https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite)
* [Movenet Lightning](https://kaggle.com/models/google/movenet/frameworks/tfLite/variations/singlepose-lightning)
* [Movenet Thunder](https://www.kaggle.com/models/google/movenet/frameworks/tfLite/variations/singlepose-thunder)
* [Movenet MultiPose](https://www.kaggle.com/models/google/movenet/frameworks/tfLite/variations/multipose-lightning-tflite-float16)
### Additional Note
_Please do not delete the assets folder content_. If you have explicitly deleted the
files, choose `Build` > `Rebuild` from the menu to re-download the
deleted model files into the assets folder.

@ -0,0 +1 @@
/build

@ -0,0 +1,56 @@
plugins {
id 'com.android.application'
id 'kotlin-android'
}
android {
compileSdkVersion 30
buildToolsVersion "30.0.3"
defaultConfig {
applicationId "org.tensorflow.lite.examples.poseestimation"
minSdkVersion 23
targetSdkVersion 30
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
namespace "org.tensorflow.lite.examples.poseestimation"
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = '1.8'
}
}
// Download tflite model
apply from:"download.gradle"
dependencies {
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
implementation 'androidx.core:core-ktx:1.5.0'
implementation 'androidx.appcompat:appcompat:1.3.0'
implementation 'com.google.android.material:material:1.3.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
implementation "androidx.activity:activity-ktx:1.2.3"
implementation 'androidx.fragment:fragment-ktx:1.3.5'
implementation 'org.tensorflow:tensorflow-lite:2.14.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.5.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.3.0'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
androidTestImplementation "com.google.truth:truth:1.1.3"
}

@ -0,0 +1,67 @@
task downloadPosenetModel(type: DownloadUrlTask) {
def modelPosenetDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite"
doFirst {
println "Downloading ${modelPosenetDownloadUrl}"
}
sourceUrl = "${modelPosenetDownloadUrl}"
target = file("src/main/assets/posenet.tflite")
}
task downloadMovenetLightningModel(type: DownloadUrlTask) {
def modelMovenetLightningDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/float16/4?lite-format=tflite"
doFirst {
println "Downloading ${modelMovenetLightningDownloadUrl}"
}
sourceUrl = "${modelMovenetLightningDownloadUrl}"
target = file("src/main/assets/movenet_lightning.tflite")
}
task downloadMovenetThunderModel(type: DownloadUrlTask) {
def modelMovenetThunderDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/singlepose/thunder/tflite/float16/4?lite-format=tflite"
doFirst {
println "Downloading ${modelMovenetThunderDownloadUrl}"
}
sourceUrl = "${modelMovenetThunderDownloadUrl}"
target = file("src/main/assets/movenet_thunder.tflite")
}
task downloadMovenetMultiPoseModel(type: DownloadUrlTask) {
def modelMovenetThunderDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/multipose/lightning/tflite/float16/1?lite-format=tflite"
doFirst {
println "Downloading ${modelMovenetThunderDownloadUrl}"
}
sourceUrl = "${modelMovenetThunderDownloadUrl}"
target = file("src/main/assets/movenet_multipose_fp16.tflite")
}
task downloadPoseClassifierModel(type: DownloadUrlTask) {
def modelPoseClassifierDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/pose_classifier/yoga_classifier.tflite"
doFirst {
println "Downloading ${modelPoseClassifierDownloadUrl}"
}
sourceUrl = "${modelPoseClassifierDownloadUrl}"
target = file("src/main/assets/classifier.tflite")
}
task downloadModel {
dependsOn downloadPosenetModel
dependsOn downloadMovenetLightningModel
dependsOn downloadMovenetThunderModel
dependsOn downloadPoseClassifierModel
dependsOn downloadMovenetMultiPoseModel
}
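// DownloadUrlTask below is a minimal custom task type: it takes a source URL and a target file
// and fetches the file with Ant's built-in 'get' task. Marking sourceUrl as @Input and target as
// @OutputFile lets Gradle skip the download when the target file is already up to date.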
class DownloadUrlTask extends DefaultTask {
@Input
String sourceUrl
@OutputFile
File target
@TaskAction
void download() {
ant.get(src: sourceUrl, dest: target)
}
}
preBuild.dependsOn downloadModel

@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

Binary file not shown.


Binary file not shown.


Binary file not shown.


@ -0,0 +1,3 @@
Image1: https://pixabay.com/illustrations/woman-stand-wait-person-shoes-1427073/
Image2: https://pixabay.com/photos/businessman-suit-germany-black-1146791/
Image3: https://pixabay.com/photos/tree-pose-yoga-yogini-lifestyle-4823155/

@ -0,0 +1,3 @@
nose_x,nose_y,left_eye_x,left_eye_y,right_eye_x,right_eye_y,left_ear_x,left_ear_y,right_ear_x,right_ear_y,left_shoulder_x,left_shoulder_y,right_shoulder_x,right_shoulder_y,left_elbow_x,left_elbow_y,right_elbow_x,right_elbow_y,left_wrist_x,left_wrist_y,right_wrist_x,right_wrist_y,left_hip_x,left_hip_y,right_hip_x,right_hip_y,left_knee_x,left_knee_y,right_knee_x,right_knee_y,left_ankle_x,left_ankle_y,right_ankle_x,right_ankle_y
186,89,200,77,177,78,224,86,167,85,244,158,154,154,258,248,143,239,265,327,136,313,234,311,170,311,247,446,134,445,262,561,92,571
182,84,191,73,171,74,202,75,157,77,220,119,139,136,260,192,185,230,268,209,246,217,221,288,176,294,205,421,174,421,186,538,155,564

@ -0,0 +1,113 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Canvas
import android.graphics.PointF
import androidx.test.platform.app.InstrumentationRegistry
import com.google.common.truth.Truth.assertThat
import com.google.common.truth.Truth.assertWithMessage
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Person
import java.io.BufferedReader
import java.io.InputStreamReader
import kotlin.math.pow
object EvaluationUtils {
/**
* Assert whether the detected person from the image matches the expected result.
* The detection result is accepted as correct if it is within the acceptableError range of the
* expected result.
*/
fun assertPoseDetectionResult(
person: Person,
expectedResult: Map<BodyPart, PointF>,
acceptableError: Float
) {
// Check if the model is confident enough in detecting the person
assertThat(person.score).isGreaterThan(0.5f)
for ((bodyPart, expectedPointF) in expectedResult) {
val keypoint = person.keyPoints.firstOrNull { it.bodyPart == bodyPart }
assertWithMessage("$bodyPart must exist").that(keypoint).isNotNull()
val detectedPointF = keypoint!!.coordinate
val distanceFromExpectedPointF = distance(detectedPointF, expectedPointF)
assertWithMessage("Detected $bodyPart must be close to expected result")
.that(distanceFromExpectedPointF).isAtMost(acceptableError)
}
}
/**
* Load an image from the assets folder by its name.
*/
fun loadBitmapAssetByName(name: String): Bitmap {
val testContext = InstrumentationRegistry.getInstrumentation().context
val testInput = testContext.assets.open(name)
return BitmapFactory.decodeStream(testInput)
}
/**
* Load a CSV of expected keypoint coordinates from the assets folder.
*/
fun loadCSVAsset(name: String): List<Map<BodyPart, PointF>> {
val data = mutableListOf<Map<BodyPart, PointF>>()
val testContext = InstrumentationRegistry.getInstrumentation().context
val testInput = testContext.assets.open(name)
val inputStreamReader = InputStreamReader(testInput)
val reader = BufferedReader(inputStreamReader)
// Skip header line
reader.readLine()
// Read expected coordinates from each following line
reader.forEachLine {
val listPoint = it.split(",")
val map = mutableMapOf<BodyPart, PointF>()
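// Columns come in (x, y) pairs ordered by BodyPart index (nose, left_eye, right_eye, ...),
// so column 2*k holds the x coordinate and column 2*k + 1 the y coordinate of body part k.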
for (i in listPoint.indices step 2) {
map[BodyPart.fromInt(i / 2)] =
PointF(listPoint[i].toFloat(), listPoint[i + 1].toFloat())
}
data.add(map)
}
return data
}
/**
* Calculate the distance between two points
*/
private fun distance(point1: PointF, point2: PointF): Float {
return ((point1.x - point2.x).pow(2) + (point1.y - point2.y).pow(2)).pow(0.5f)
}
/**
* Concatenate two images of the same height horizontally
*/
fun hConcat(image1: Bitmap, image2: Bitmap): Bitmap {
if (image1.height != image2.height) {
throw Exception("Input images do not have the same height.")
}
val finalBitmap =
Bitmap.createBitmap(image1.width + image2.width, image1.height, Bitmap.Config.ARGB_8888)
val canvas = Canvas(finalBitmap)
canvas.drawBitmap(image1, 0f, 0f, null)
canvas.drawBitmap(image2, image1.width.toFloat(), 0f, null)
return finalBitmap
}
}

@ -0,0 +1,83 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
@RunWith(AndroidJUnit4::class)
class MovenetLightningTest {
companion object {
private const val TEST_INPUT_IMAGE1 = "image1.png"
private const val TEST_INPUT_IMAGE2 = "image2.jpg"
private const val ACCEPTABLE_ERROR = 21f
}
private lateinit var poseDetector: PoseDetector
private lateinit var appContext: Context
private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>
@Before
fun setup() {
appContext = InstrumentationRegistry.getInstrumentation().targetContext
poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning)
expectedDetectionResult =
EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")
}
@Test
fun testPoseEstimationResultWithImage1() {
val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)
// As MoveNet uses the previous frame to optimize its detection result, we run it multiple
// times on the same image to improve the result.
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
val person = poseDetector.estimatePoses(input)[0]
EvaluationUtils.assertPoseDetectionResult(
person,
expectedDetectionResult[0],
ACCEPTABLE_ERROR
)
}
@Test
fun testPoseEstimationResultWithImage2() {
val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)
// As MoveNet uses the previous frame to optimize its detection result, we run it multiple
// times on the same image to improve the result.
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
val person = poseDetector.estimatePoses(input)[0]
EvaluationUtils.assertPoseDetectionResult(
person,
expectedDetectionResult[1],
ACCEPTABLE_ERROR
)
}
}

@ -0,0 +1,81 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import android.graphics.Bitmap
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.ml.MoveNetMultiPose
import org.tensorflow.lite.examples.poseestimation.ml.Type
@RunWith(AndroidJUnit4::class)
class MovenetMultiPoseTest {
companion object {
private const val TEST_INPUT_IMAGE1 = "image1.png"
private const val TEST_INPUT_IMAGE2 = "image2.jpg"
private const val ACCEPTABLE_ERROR = 17f
}
private lateinit var poseDetector: MoveNetMultiPose
private lateinit var appContext: Context
private lateinit var inputFinal: Bitmap
private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>
@Before
fun setup() {
appContext = InstrumentationRegistry.getInstrumentation().targetContext
poseDetector = MoveNetMultiPose.create(appContext, Device.CPU, Type.Dynamic)
val input1 = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)
val input2 = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)
inputFinal = EvaluationUtils.hConcat(input1, input2)
expectedDetectionResult =
EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")
// Update the coordinates from pose_landmark_truth.csv so they correspond to the new concatenated input image
for ((_, value) in expectedDetectionResult[1]) {
value.x = value.x + input1.width
}
}
@Test
fun testPoseEstimateResult() {
val persons = poseDetector.estimatePoses(inputFinal)
assert(persons.size == 2)
// Sort the results so that the person on the left side of the image comes first.
val sortedPersons = persons.sortedBy { it.boundingBox?.left }
EvaluationUtils.assertPoseDetectionResult(
sortedPersons[0],
expectedDetectionResult[0],
ACCEPTABLE_ERROR
)
EvaluationUtils.assertPoseDetectionResult(
sortedPersons[1],
expectedDetectionResult[1],
ACCEPTABLE_ERROR
)
}
}

@ -0,0 +1,83 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
@RunWith(AndroidJUnit4::class)
class MovenetThunderTest {
companion object {
private const val TEST_INPUT_IMAGE1 = "image1.png"
private const val TEST_INPUT_IMAGE2 = "image2.jpg"
private const val ACCEPTABLE_ERROR = 15f
}
private lateinit var poseDetector: PoseDetector
private lateinit var appContext: Context
private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>
@Before
fun setup() {
appContext = InstrumentationRegistry.getInstrumentation().targetContext
poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Thunder)
expectedDetectionResult =
EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")
}
@Test
fun testPoseEstimationResultWithImage1() {
val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)
// As MoveNet uses the previous frame to optimize its detection result, we run it multiple
// times on the same image to improve the result.
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
val person = poseDetector.estimatePoses(input)[0]
EvaluationUtils.assertPoseDetectionResult(
person,
expectedDetectionResult[0],
ACCEPTABLE_ERROR
)
}
@Test
fun testPoseEstimationResultWithImage2() {
val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)
// As MoveNet uses the previous frame to optimize its detection result, we run it multiple
// times on the same image to improve the result.
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
val person = poseDetector.estimatePoses(input)[0]
EvaluationUtils.assertPoseDetectionResult(
person,
expectedDetectionResult[1],
ACCEPTABLE_ERROR
)
}
}

@ -0,0 +1,64 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import junit.framework.TestCase
import junit.framework.TestCase.assertEquals
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.Device
@RunWith(AndroidJUnit4::class)
class PoseClassifierTest {
companion object {
private const val TEST_INPUT_IMAGE = "image3.jpeg"
}
private lateinit var appContext: Context
private lateinit var poseDetector: PoseDetector
private lateinit var poseClassifier: PoseClassifier
@Before
fun setup() {
appContext = InstrumentationRegistry.getInstrumentation().targetContext
poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning)
poseClassifier = PoseClassifier.create(appContext)
}
@Test
fun testPoseClassifier() {
val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE)
// As MoveNet uses the previous frame to optimize its detection result, we run it multiple
// times on the same image to improve the result.
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
poseDetector.estimatePoses(input)
val person = poseDetector.estimatePoses(input)[0]
val classificationResult = poseClassifier.classify(person)
val predictedPose = classificationResult.maxByOrNull { it.second }?.first ?: "n/a"
assertEquals(
"Predicted pose is different from ground truth.",
"tree",
predictedPose
)
}
}

@ -0,0 +1,71 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
@RunWith(AndroidJUnit4::class)
class PosenetTest {
companion object {
private const val TEST_INPUT_IMAGE1 = "image1.png"
private const val TEST_INPUT_IMAGE2 = "image2.jpg"
private const val ACCEPTABLE_ERROR = 37f
}
private lateinit var poseDetector: PoseDetector
private lateinit var appContext: Context
private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>
@Before
fun setup() {
appContext = InstrumentationRegistry.getInstrumentation().targetContext
poseDetector = PoseNet.create(appContext, Device.CPU)
expectedDetectionResult =
EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")
}
@Test
fun testPoseEstimationResultWithImage1() {
val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)
val person = poseDetector.estimatePoses(input)[0]
EvaluationUtils.assertPoseDetectionResult(
person,
expectedDetectionResult[0],
ACCEPTABLE_ERROR
)
}
@Test
fun testPoseEstimationResultWithImage2() {
val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)
val person = poseDetector.estimatePoses(input)[0]
EvaluationUtils.assertPoseDetectionResult(
person,
expectedDetectionResult[1],
ACCEPTABLE_ERROR
)
}
}

@ -0,0 +1,82 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import android.graphics.Bitmap
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import com.google.common.truth.Truth.assertThat
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.VisualizationUtils
import org.tensorflow.lite.examples.poseestimation.data.Device
/**
* This test is used to visually verify the detection results of the models.
* You can put a breakpoint at the end of the method, debug this method, then use the
* "View Bitmap" feature of the debugger to check the visualized detection result.
*/
@RunWith(AndroidJUnit4::class)
class VisualizationTest {
companion object {
private const val TEST_INPUT_IMAGE = "image2.jpg"
}
private lateinit var appContext: Context
private lateinit var inputBitmap: Bitmap
@Before
fun setup() {
appContext = InstrumentationRegistry.getInstrumentation().targetContext
inputBitmap = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE)
}
@Test
fun testPosenet() {
val poseDetector = PoseNet.create(appContext, Device.CPU)
val person = poseDetector.estimatePoses(inputBitmap)[0]
val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person))
assertThat(outputBitmap).isNotNull()
}
@Test
fun testMovenetLightning() {
// Due to Movenet's cropping logic, we run inference several times with the same input
// image to improve accuracy
val poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning)
poseDetector.estimatePoses(inputBitmap)
poseDetector.estimatePoses(inputBitmap)
val person2 = poseDetector.estimatePoses(inputBitmap)[0]
val outputBitmap2 = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person2))
assertThat(outputBitmap2).isNotNull()
}
@Test
fun testMovenetThunder() {
// Due to Movenet's cropping logic, we run inference several times with the same input
// image to improve accuracy
val poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Thunder)
poseDetector.estimatePoses(inputBitmap)
poseDetector.estimatePoses(inputBitmap)
val person = poseDetector.estimatePoses(inputBitmap)[0]
val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person))
assertThat(outputBitmap).isNotNull()
}
}

@ -0,0 +1,228 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.tracker
import android.graphics.RectF
import androidx.test.ext.junit.runners.AndroidJUnit4
import junit.framework.TestCase.assertEquals
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.Person
@RunWith(AndroidJUnit4::class)
class BoundingBoxTrackerTest {
companion object {
private const val MAX_TRACKS = 4
private const val MAX_AGE = 1000 // Unit: milliseconds.
private const val MIN_SIMILARITY = 0.5f
}
private lateinit var boundingBoxTracker: BoundingBoxTracker
@Before
fun setup() {
val trackerConfig = TrackerConfig(MAX_TRACKS, MAX_AGE, MIN_SIMILARITY)
boundingBoxTracker = BoundingBoxTracker(trackerConfig)
}
@Test
fun testIoU() {
val persons = Person(
-1, listOf(), RectF(
0f,
0f,
2f / 3,
1f
), 1f
)
val track =
Track(
Person(
-1,
listOf(),
RectF(
1 / 3f,
0.0f,
1f,
1f,
), 1f
), 1000000
)
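// The two boxes overlap on [1/3, 2/3] x [0, 1]: intersection = 1/3, union = 2/3 + 2/3 - 1/3 = 1,
// so the expected IoU is 1/3.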
val computedIoU = boundingBoxTracker.iou(persons, track.person)
assertEquals("Wrong IoU value.", 1f / 3, computedIoU, 0.000001f)
}
@Test
fun testIoUFullOverlap() {
val persons = Person(
-1, listOf(),
RectF(
0f,
0f,
1f,
1f
), 1f
)
val track =
Track(
Person(
-1,
listOf(),
RectF(
0f,
0f,
1f,
1f,
), 1f
), 1000000
)
val computedIoU = boundingBoxTracker.iou(persons, track.person)
assertEquals("Wrong IoU value.", 1f, computedIoU, 0.000001f)
}
@Test
fun testIoUNoIntersection() {
val persons = Person(
-1, listOf(),
RectF(
0f,
0f,
0.5f,
0.5f
), 1f
)
val track =
Track(
Person(
-1,
listOf(),
RectF(
0.5f,
0.5f,
1f,
1f,
), 1f
), 1000000
)
val computedIoU = boundingBoxTracker.iou(persons, track.person)
assertEquals("Wrong IoU value.", 0f, computedIoU, 0.000001f)
}
@Test
fun testBoundingBoxTracking() {
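// Note: the timestamps passed to apply() below are in microseconds, so 1050000 corresponds
// to 1050 ms, which exceeds MAX_AGE (1000 ms).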
// Timestamp: 0. Poses become the first two tracks.
var persons = listOf(
Person( // Becomes track 1.
-1, listOf(), RectF(
0f,
0f,
0.5f,
0.5f,
), 1f
),
Person( // Becomes track 2.
-1, listOf(), RectF(
0f,
0f,
1f,
1f
), 1f
)
)
persons = boundingBoxTracker.apply(persons, 0)
var track = boundingBoxTracker.tracks
assertEquals(2, persons.size)
assertEquals(1, persons[0].id)
assertEquals(2, persons[1].id)
assertEquals(2, track.size)
assertEquals(1, track[0].person.id)
assertEquals(0, track[0].lastTimestamp)
assertEquals(2, track[1].person.id)
assertEquals(0, track[1].lastTimestamp)
// Timestamp: 100000. First pose is linked with track 1. Second pose spawns
// a new track (id = 3).
persons = listOf(
Person( // Linked with track 1.
-1, listOf(), RectF(
0.1f,
0.1f,
0.5f,
0.5f
), 1f
),
Person( // Becomes track 3.
-1, listOf(), RectF(
0.2f,
0.3f,
0.9f,
0.9f
), 1f
)
)
persons = boundingBoxTracker.apply(persons, 100000)
track = boundingBoxTracker.tracks
assertEquals(2, persons.size)
assertEquals(1, persons[0].id)
assertEquals(3, persons[1].id)
assertEquals(3, track.size)
assertEquals(1, track[0].person.id)
assertEquals(100000, track[0].lastTimestamp)
assertEquals(3, track[1].person.id)
assertEquals(100000, track[1].lastTimestamp)
assertEquals(2, track[2].person.id)
assertEquals(0, track[2].lastTimestamp)
// Timestamp: 1050000. First pose is linked with track 1. Second pose is
// identical to track 2, but is not linked because track 2 is deleted due to
// age. Instead it spawns track 4.
persons = listOf(
Person( // Linked with track 1.
-1, listOf(), RectF(
0.1f,
0.1f,
0.55f,
0.5f
), 1f
),
Person( // Becomes track 4.
-1, listOf(), RectF(
0f,
0f,
1f,
1f
), 1f
)
)
persons = boundingBoxTracker.apply(persons, 1050000)
track = boundingBoxTracker.tracks
assertEquals(2, persons.size)
assertEquals(1, persons[0].id)
assertEquals(4, persons[1].id)
assertEquals(3, track.size)
assertEquals(1, track[0].person.id)
assertEquals(1050000, track[0].lastTimestamp)
assertEquals(4, track[1].person.id)
assertEquals(1050000, track[1].lastTimestamp)
assertEquals(3, track[2].person.id)
assertEquals(100000, track[2].lastTimestamp)
}
}

@ -0,0 +1,308 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.tracker
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import junit.framework.Assert
import junit.framework.TestCase.assertEquals
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.KeyPoint
import org.tensorflow.lite.examples.poseestimation.data.Person
import kotlin.math.exp
import kotlin.math.pow
@RunWith(AndroidJUnit4::class)
class KeyPointsTrackerTest {
companion object {
private const val MAX_TRACKS = 4
private const val MAX_AGE = 1000
private const val MIN_SIMILARITY = 0.5f
private const val KEYPOINT_THRESHOLD = 0.2f
private const val MIN_NUM_KEYPOINT = 2
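// Per-keypoint OKS falloff, one value for each of the four keypoints used in these tests.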
private val KEYPOINT_FALLOFF = listOf(0.1f, 0.1f, 0.1f, 0.1f)
}
private lateinit var keyPointsTracker: KeyPointsTracker
@Before
fun setup() {
val trackerConfig = TrackerConfig(
MAX_TRACKS, MAX_AGE, MIN_SIMILARITY,
KeyPointsTrackerParams(KEYPOINT_THRESHOLD, KEYPOINT_FALLOFF, MIN_NUM_KEYPOINT)
)
keyPointsTracker = KeyPointsTracker(trackerConfig)
}
@Test
fun testOks() {
val persons =
Person(
-1, listOf(
KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.1f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.7f), 0.8f),
), score = 1f
)
val tracks =
Track(
Person(
0, listOf(
KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f),
), score = 1f
),
1000000,
)
val oks = keyPointsTracker.oks(persons, tracks.person)
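// Expected OKS: the nose and elbow match the track exactly (similarity 1 each), the knee is
// excluded because its detection score (0.1) is below KEYPOINT_THRESHOLD, and the ankle
// differs by d = 0.1 in y, contributing exp(-d^2 / (2 * boxArea * x^2)).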
val boxArea = (0.8f - 0.2f) * (0.8f - 0.2f)
val x = 2f * KEYPOINT_FALLOFF[3]
val d = 0.1f
val expectedOks: Float =
(1f + 1f + exp(-1f * d.pow(2) / (2f * boxArea * x.pow(2)))) / 3f
assertEquals(expectedOks, oks, 0.000001f)
}
@Test
fun testOksReturnZero() {
// OKS computation returns 0.0 when there are fewer than 2 valid keypoints.
val persons =
Person(
-1, listOf(
KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.1f), // Low confidence.
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f),
), score = 1f
)
val tracks =
Track(
Person(
0, listOf(
KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.1f),// Low confidence.
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.0f),// Low confidence.
), score = 1f
), 1000000
)
val oks = keyPointsTracker.oks(persons, tracks.person)
assertEquals(0f, oks, 0.000001f)
}
@Test
fun testArea() {
val keyPoints = listOf(
KeyPoint(BodyPart.NOSE, PointF(0.1f, 0.2f), 1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.3f, 0.4f), 0.9f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.4f, 0.6f), 0.9f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.7f, 0.8f), 0.1f),
)
val area = keyPointsTracker.area(keyPoints)
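// Only keypoints with a score above KEYPOINT_THRESHOLD contribute; the ankle (score 0.1) is
// excluded, so the bounding box spans x in [0.1, 0.4] and y in [0.2, 0.6].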
val expectedArea = (0.4f - 0.1f) * (0.6f - 0.2f)
assertEquals(expectedArea, area)
}
@Test
fun testKeyPointsTracker() {
// Timestamp: 0. Person becomes the only track.
var persons = listOf(
Person(
-1, listOf(
KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.0f),
), score = 0.9f
)
)
persons = keyPointsTracker.apply(persons, 0)
var track = keyPointsTracker.tracks
assertEquals(1, persons.size)
assertEquals(1, persons[0].id)
assertEquals(1, track.size)
assertEquals(1, track[0].person.id)
assertEquals(0, track[0].lastTimestamp)
// Timestamp: 100000. First person is linked with track 1. Second person spawns
// a new track (id = 2).
persons = listOf(
Person(
-1,
listOf(
// Links with id = 1.
KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f),
),
score = 1f
),
Person(
-1,
listOf(
// Becomes id = 2.
KeyPoint(BodyPart.NOSE, PointF(0.8f, 0.8f), 0.8f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.6f, 0.6f), 0.3f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.4f, 0.4f), 0.1f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.2f, 0.2f), 0.8f),
),
score = 1f
)
)
persons = keyPointsTracker.apply(persons, 100000)
track = keyPointsTracker.tracks
assertEquals(2, persons.size)
assertEquals(1, persons[0].id)
assertEquals(2, persons[1].id)
assertEquals(2, track.size)
assertEquals(1, track[0].person.id)
assertEquals(100000, track[0].lastTimestamp)
assertEquals(2, track[1].person.id)
assertEquals(100000, track[1].lastTimestamp)
// Timestamp: 900000. First person is linked with track 2. Second person spawns
// a new track (id = 3).
persons = listOf(
Person(
-1,
listOf(
// Links with id = 2.
KeyPoint(BodyPart.NOSE, PointF(0.6f, 0.7f), 0.7f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.5f, 0.6f), 0.7f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.0f, 0.0f), 0.1f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.2f, 0.1f), 1f),
),
score = 1f
),
Person(
-1,
listOf(
// Becomes id = 3.
KeyPoint(BodyPart.NOSE, PointF(0.5f, 0.1f), 0.6f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.9f, 0.3f), 0.6f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.1f, 1f), 0.9f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.4f, 0.4f), 0.1f),
),
score = 1f
)
)
persons = keyPointsTracker.apply(persons, 900000)
track = keyPointsTracker.tracks
assertEquals(2, persons.size)
assertEquals(2, persons[0].id)
assertEquals(3, persons[1].id)
assertEquals(3, track.size)
assertEquals(2, track[0].person.id)
assertEquals(900000, track[0].lastTimestamp)
assertEquals(3, track[1].person.id)
assertEquals(900000, track[1].lastTimestamp)
assertEquals(1, track[2].person.id)
assertEquals(100000, track[2].lastTimestamp)
// Timestamp: 1200000. First person spawns a new track (id = 4), even though
// it has the same keypoints as track 1. This is because the age exceeds
// 1000 msec. The second person links with id 2. The third person spawns a new
// track (id = 5).
persons = listOf(
Person(
-1,
listOf(
// Becomes id = 4.
KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f),
),
score = 1f
),
Person(
-1,
listOf(
// Links with id = 2.
KeyPoint(BodyPart.NOSE, PointF(0.55f, 0.7f), 0.7f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.5f, 0.6f), 0.9f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(1f, 1f), 0.1f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.1f), 0f),
),
score = 1f
),
Person(
-1,
listOf(
// Becomes id = 5.
KeyPoint(BodyPart.NOSE, PointF(0.1f, 0.1f), 0.1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.2f, 0.2f), 0.9f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.3f, 0.3f), 0.7f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.4f, 0.4f), 0.8f),
),
score = 1f
)
)
persons = keyPointsTracker.apply(persons, 1200000)
track = keyPointsTracker.tracks
assertEquals(3, persons.size)
assertEquals(4, persons[0].id)
assertEquals(2, persons[1].id)
assertEquals(4, track.size)
assertEquals(2, track[0].person.id)
assertEquals(1200000, track[0].lastTimestamp)
assertEquals(4, track[1].person.id)
assertEquals(1200000, track[1].lastTimestamp)
assertEquals(5, track[2].person.id)
assertEquals(1200000, track[2].lastTimestamp)
assertEquals(3, track[3].person.id)
assertEquals(900000, track[3].lastTimestamp)
// Timestamp: 1300000. First person spawns a new track (id = 6). Since
// maxTracks is 4, the oldest track (id = 3) is removed.
persons = listOf(
Person(
-1,
listOf(
// Becomes id = 6.
KeyPoint(BodyPart.NOSE, PointF(0.1f, 0.8f), 1f),
KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.2f, 0.9f), 0.6f),
KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.2f, 0.9f), 0.5f),
KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.2f), 0.4f),
),
score = 1f
)
)
persons = keyPointsTracker.apply(persons, 1300000)
track = keyPointsTracker.tracks
assertEquals(1, persons.size)
assertEquals(6, persons[0].id)
assertEquals(4, track.size)
assertEquals(6, track[0].person.id)
assertEquals(1300000, track[0].lastTimestamp)
assertEquals(2, track[1].person.id)
assertEquals(1200000, track[1].lastTimestamp)
assertEquals(4, track[2].person.id)
assertEquals(1200000, track[2].lastTimestamp)
assertEquals(5, track[3].person.id)
assertEquals(1200000, track[3].lastTimestamp)
}
}

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.lite.examples.poseestimation">
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature
android:name="android.hardware.camera"
android:required="true" />
<application
android:allowBackup="true"
android:icon="@drawable/ic_launcher"
android:label="@string/tfe_pe_app_name"
android:roundIcon="@drawable/ic_launcher"
android:supportsRtl="true"
android:theme="@style/Theme.PoseEstimation">
<activity android:name=".MainActivity">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

@ -0,0 +1,5 @@
chair
cobra
dog
tree
warrior

@ -0,0 +1,430 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation
import android.Manifest
import android.app.AlertDialog
import android.app.Dialog
import android.content.pm.PackageManager
import android.os.Bundle
import android.os.Process
import android.view.SurfaceView
import android.view.View
import android.view.WindowManager
import android.widget.*
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.appcompat.widget.SwitchCompat
import androidx.core.content.ContextCompat
import androidx.fragment.app.DialogFragment
import androidx.lifecycle.lifecycleScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import org.tensorflow.lite.examples.poseestimation.camera.CameraSource
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.ml.*
class MainActivity : AppCompatActivity() {
companion object {
private const val FRAGMENT_DIALOG = "dialog"
}
/** A [SurfaceView] for camera preview. */
private lateinit var surfaceView: SurfaceView
/** Default pose estimation model is 1 (MoveNet Thunder)
* 0 == MoveNet Lightning model
* 1 == MoveNet Thunder model
* 2 == MoveNet MultiPose model
* 3 == PoseNet model
**/
private var modelPos = 1
/** Default device is CPU */
private var device = Device.CPU
private lateinit var tvScore: TextView
private lateinit var tvFPS: TextView
private lateinit var spnDevice: Spinner
private lateinit var spnModel: Spinner
private lateinit var spnTracker: Spinner
private lateinit var vTrackerOption: View
private lateinit var tvClassificationValue1: TextView
private lateinit var tvClassificationValue2: TextView
private lateinit var tvClassificationValue3: TextView
private lateinit var swClassification: SwitchCompat
private lateinit var vClassificationOption: View
private var cameraSource: CameraSource? = null
private var isClassifyPose = false
private val requestPermissionLauncher =
registerForActivityResult(
ActivityResultContracts.RequestPermission()
) { isGranted: Boolean ->
if (isGranted) {
// Permission is granted. Continue the action or workflow in your
// app.
openCamera()
} else {
// Explain to the user that the feature is unavailable because the
// feature requires a permission that the user has denied. At the
// same time, respect the user's decision. Don't link to system
// settings in an effort to convince the user to change their
// decision.
ErrorDialog.newInstance(getString(R.string.tfe_pe_request_permission))
.show(supportFragmentManager, FRAGMENT_DIALOG)
}
}
private var changeModelListener = object : AdapterView.OnItemSelectedListener {
override fun onNothingSelected(parent: AdapterView<*>?) {
// do nothing
}
override fun onItemSelected(
parent: AdapterView<*>?,
view: View?,
position: Int,
id: Long
) {
changeModel(position)
}
}
private var changeDeviceListener = object : AdapterView.OnItemSelectedListener {
override fun onItemSelected(parent: AdapterView<*>?, view: View?, position: Int, id: Long) {
changeDevice(position)
}
override fun onNothingSelected(parent: AdapterView<*>?) {
// do nothing
}
}
private var changeTrackerListener = object : AdapterView.OnItemSelectedListener {
override fun onItemSelected(parent: AdapterView<*>?, view: View?, position: Int, id: Long) {
changeTracker(position)
}
override fun onNothingSelected(parent: AdapterView<*>?) {
// do nothing
}
}
private var setClassificationListener =
CompoundButton.OnCheckedChangeListener { _, isChecked ->
showClassificationResult(isChecked)
isClassifyPose = isChecked
isPoseClassifier()
}
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
// keep screen on while app is running
window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
tvScore = findViewById(R.id.tvScore)
tvFPS = findViewById(R.id.tvFps)
spnModel = findViewById(R.id.spnModel)
spnDevice = findViewById(R.id.spnDevice)
spnTracker = findViewById(R.id.spnTracker)
vTrackerOption = findViewById(R.id.vTrackerOption)
surfaceView = findViewById(R.id.surfaceView)
tvClassificationValue1 = findViewById(R.id.tvClassificationValue1)
tvClassificationValue2 = findViewById(R.id.tvClassificationValue2)
tvClassificationValue3 = findViewById(R.id.tvClassificationValue3)
swClassification = findViewById(R.id.swPoseClassification)
vClassificationOption = findViewById(R.id.vClassificationOption)
initSpinner()
spnModel.setSelection(modelPos)
swClassification.setOnCheckedChangeListener(setClassificationListener)
if (!isCameraPermissionGranted()) {
requestPermission()
}
}
override fun onStart() {
super.onStart()
openCamera()
}
override fun onResume() {
cameraSource?.resume()
super.onResume()
}
override fun onPause() {
cameraSource?.close()
cameraSource = null
super.onPause()
}
// check if permission is granted or not.
private fun isCameraPermissionGranted(): Boolean {
return checkPermission(
Manifest.permission.CAMERA,
Process.myPid(),
Process.myUid()
) == PackageManager.PERMISSION_GRANTED
}
// open camera
private fun openCamera() {
if (isCameraPermissionGranted()) {
if (cameraSource == null) {
cameraSource =
CameraSource(surfaceView, object : CameraSource.CameraSourceListener {
override fun onFPSListener(fps: Int) {
tvFPS.text = getString(R.string.tfe_pe_tv_fps, fps)
}
override fun onDetectedInfo(
personScore: Float?,
poseLabels: List<Pair<String, Float>>?
) {
tvScore.text = getString(R.string.tfe_pe_tv_score, personScore ?: 0f)
poseLabels?.sortedByDescending { it.second }?.let {
tvClassificationValue1.text = getString(
R.string.tfe_pe_tv_classification_value,
convertPoseLabels(if (it.isNotEmpty()) it[0] else null)
)
tvClassificationValue2.text = getString(
R.string.tfe_pe_tv_classification_value,
convertPoseLabels(if (it.size >= 2) it[1] else null)
)
tvClassificationValue3.text = getString(
R.string.tfe_pe_tv_classification_value,
convertPoseLabels(if (it.size >= 3) it[2] else null)
)
}
}
}).apply {
prepareCamera()
}
isPoseClassifier()
lifecycleScope.launch(Dispatchers.Main) {
cameraSource?.initCamera()
}
}
createPoseEstimator()
}
}
private fun convertPoseLabels(pair: Pair<String, Float>?): String {
if (pair == null) return "empty"
return "${pair.first} (${String.format("%.2f", pair.second)})"
}
private fun isPoseClassifier() {
cameraSource?.setClassifier(if (isClassifyPose) PoseClassifier.create(this) else null)
}
// Initialize spinners to let user select model/accelerator/tracker.
private fun initSpinner() {
ArrayAdapter.createFromResource(
this,
R.array.tfe_pe_models_array,
android.R.layout.simple_spinner_item
).also { adapter ->
// Specify the layout to use when the list of choices appears
adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item)
// Apply the adapter to the spinner
spnModel.adapter = adapter
spnModel.onItemSelectedListener = changeModelListener
}
ArrayAdapter.createFromResource(
this,
R.array.tfe_pe_device_name, android.R.layout.simple_spinner_item
).also { adapter ->
adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item)
spnDevice.adapter = adapter
spnDevice.onItemSelectedListener = changeDeviceListener
}
ArrayAdapter.createFromResource(
this,
R.array.tfe_pe_tracker_array, android.R.layout.simple_spinner_item
).also { adapter ->
adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item)
spnTracker.adapter = adapter
spnTracker.onItemSelectedListener = changeTrackerListener
}
}
// Change model when app is running
private fun changeModel(position: Int) {
if (modelPos == position) return
modelPos = position
createPoseEstimator()
}
// Change device (accelerator) type when app is running
private fun changeDevice(position: Int) {
val targetDevice = when (position) {
0 -> Device.CPU
1 -> Device.GPU
else -> Device.NNAPI
}
if (device == targetDevice) return
device = targetDevice
createPoseEstimator()
}
// Change tracker for Movenet MultiPose model
private fun changeTracker(position: Int) {
cameraSource?.setTracker(
when (position) {
1 -> TrackerType.BOUNDING_BOX
2 -> TrackerType.KEYPOINTS
else -> TrackerType.OFF
}
)
}
private fun createPoseEstimator() {
// For MoveNet MultiPose, hide score and disable pose classifier as the model returns
// multiple Person instances.
val poseDetector = when (modelPos) {
0 -> {
// MoveNet Lightning (SinglePose)
showPoseClassifier(true)
showDetectionScore(true)
showTracker(false)
MoveNet.create(this, device, ModelType.Lightning)
}
1 -> {
// MoveNet Thunder (SinglePose)
showPoseClassifier(true)
showDetectionScore(true)
showTracker(false)
MoveNet.create(this, device, ModelType.Thunder)
}
2 -> {
// MoveNet (Lightning) MultiPose
showPoseClassifier(false)
showDetectionScore(false)
// Movenet MultiPose Dynamic does not support GPUDelegate
if (device == Device.GPU) {
showToast(getString(R.string.tfe_pe_gpu_error))
}
showTracker(true)
MoveNetMultiPose.create(
this,
device,
Type.Dynamic
)
}
3 -> {
// PoseNet (SinglePose)
showPoseClassifier(true)
showDetectionScore(true)
showTracker(false)
PoseNet.create(this, device)
}
else -> {
null
}
}
poseDetector?.let { detector ->
cameraSource?.setDetector(detector)
}
}
// Show/hide the pose classification option.
private fun showPoseClassifier(isVisible: Boolean) {
vClassificationOption.visibility = if (isVisible) View.VISIBLE else View.GONE
if (!isVisible) {
swClassification.isChecked = false
}
}
// Show/hide the detection score.
private fun showDetectionScore(isVisible: Boolean) {
tvScore.visibility = if (isVisible) View.VISIBLE else View.GONE
}
// Show/hide classification result.
private fun showClassificationResult(isVisible: Boolean) {
val visibility = if (isVisible) View.VISIBLE else View.GONE
tvClassificationValue1.visibility = visibility
tvClassificationValue2.visibility = visibility
tvClassificationValue3.visibility = visibility
}
// Show/hide the tracking options.
private fun showTracker(isVisible: Boolean) {
if (isVisible) {
// Show tracker options and enable Bounding Box tracker.
vTrackerOption.visibility = View.VISIBLE
spnTracker.setSelection(1)
} else {
// Set tracker type to off and hide tracker option.
vTrackerOption.visibility = View.GONE
spnTracker.setSelection(0)
}
}
private fun requestPermission() {
when (PackageManager.PERMISSION_GRANTED) {
ContextCompat.checkSelfPermission(
this,
Manifest.permission.CAMERA
) -> {
// You can use the API that requires the permission.
openCamera()
}
else -> {
// You can directly ask for the permission.
// The registered ActivityResultCallback gets the result of this request.
requestPermissionLauncher.launch(
Manifest.permission.CAMERA
)
}
}
}
private fun showToast(message: String) {
Toast.makeText(this, message, Toast.LENGTH_LONG).show()
}
/**
* Shows an error message dialog.
*/
class ErrorDialog : DialogFragment() {
override fun onCreateDialog(savedInstanceState: Bundle?): Dialog =
AlertDialog.Builder(activity)
.setMessage(requireArguments().getString(ARG_MESSAGE))
.setPositiveButton(android.R.string.ok) { _, _ ->
// do nothing
}
.create()
companion object {
@JvmStatic
private val ARG_MESSAGE = "message"
@JvmStatic
fun newInstance(message: String): ErrorDialog = ErrorDialog().apply {
arguments = Bundle().apply { putString(ARG_MESSAGE, message) }
}
}
}
}

@ -0,0 +1,120 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation
import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Person
import kotlin.math.max
object VisualizationUtils {
/** Radius of circle used to draw keypoints. */
private const val CIRCLE_RADIUS = 6f
/** Width of the line used to connect two keypoints. */
private const val LINE_WIDTH = 4f
/** The text size of the person id that will be displayed when the tracker is available. */
private const val PERSON_ID_TEXT_SIZE = 30f
/** Distance from person id to the nose keypoint. */
private const val PERSON_ID_MARGIN = 6f
/** Pairs of keypoints to draw lines between. */
private val bodyJoints = listOf(
Pair(BodyPart.NOSE, BodyPart.LEFT_EYE),
Pair(BodyPart.NOSE, BodyPart.RIGHT_EYE),
Pair(BodyPart.LEFT_EYE, BodyPart.LEFT_EAR),
Pair(BodyPart.RIGHT_EYE, BodyPart.RIGHT_EAR),
Pair(BodyPart.NOSE, BodyPart.LEFT_SHOULDER),
Pair(BodyPart.NOSE, BodyPart.RIGHT_SHOULDER),
Pair(BodyPart.LEFT_SHOULDER, BodyPart.LEFT_ELBOW),
Pair(BodyPart.LEFT_ELBOW, BodyPart.LEFT_WRIST),
Pair(BodyPart.RIGHT_SHOULDER, BodyPart.RIGHT_ELBOW),
Pair(BodyPart.RIGHT_ELBOW, BodyPart.RIGHT_WRIST),
Pair(BodyPart.LEFT_SHOULDER, BodyPart.RIGHT_SHOULDER),
Pair(BodyPart.LEFT_SHOULDER, BodyPart.LEFT_HIP),
Pair(BodyPart.RIGHT_SHOULDER, BodyPart.RIGHT_HIP),
Pair(BodyPart.LEFT_HIP, BodyPart.RIGHT_HIP),
Pair(BodyPart.LEFT_HIP, BodyPart.LEFT_KNEE),
Pair(BodyPart.LEFT_KNEE, BodyPart.LEFT_ANKLE),
Pair(BodyPart.RIGHT_HIP, BodyPart.RIGHT_KNEE),
Pair(BodyPart.RIGHT_KNEE, BodyPart.RIGHT_ANKLE)
)
// Draw lines and points to indicate the body pose
fun drawBodyKeypoints(
input: Bitmap,
persons: List<Person>,
isTrackerEnabled: Boolean = false
): Bitmap {
val paintCircle = Paint().apply {
strokeWidth = CIRCLE_RADIUS
color = Color.RED
style = Paint.Style.FILL
}
val paintLine = Paint().apply {
strokeWidth = LINE_WIDTH
color = Color.RED
style = Paint.Style.STROKE
}
val paintText = Paint().apply {
textSize = PERSON_ID_TEXT_SIZE
color = Color.BLUE
textAlign = Paint.Align.LEFT
}
val output = input.copy(Bitmap.Config.ARGB_8888, true)
val originalSizeCanvas = Canvas(output)
persons.forEach { person ->
// draw person id if tracker is enabled
if (isTrackerEnabled) {
person.boundingBox?.let {
val personIdX = max(0f, it.left)
val personIdY = max(0f, it.top)
originalSizeCanvas.drawText(
person.id.toString(),
personIdX,
personIdY - PERSON_ID_MARGIN,
paintText
)
originalSizeCanvas.drawRect(it, paintLine)
}
}
bodyJoints.forEach {
val pointA = person.keyPoints[it.first.position].coordinate
val pointB = person.keyPoints[it.second.position].coordinate
originalSizeCanvas.drawLine(pointA.x, pointA.y, pointB.x, pointB.y, paintLine)
}
person.keyPoints.forEach { point ->
originalSizeCanvas.drawCircle(
point.coordinate.x,
point.coordinate.y,
CIRCLE_RADIUS,
paintCircle
)
}
}
return output
}
}

@ -0,0 +1,151 @@
package org.tensorflow.lite.examples.poseestimation
import android.content.Context
import android.graphics.Bitmap
import android.graphics.ImageFormat
import android.graphics.Rect
import android.media.Image
import android.renderscript.Allocation
import android.renderscript.Element
import android.renderscript.RenderScript
import android.renderscript.ScriptIntrinsicYuvToRGB
import java.nio.ByteBuffer
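/**
* Converts a YUV_420_888 camera [Image] into an RGB [Bitmap] using the RenderScript
* YUV-to-RGB intrinsic. Buffers and allocations are created lazily on first use and reused
* across frames.
*/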
class YuvToRgbConverter(context: Context) {
private val rs = RenderScript.create(context)
private val scriptYuvToRgb = ScriptIntrinsicYuvToRGB.create(rs, Element.U8_4(rs))
private var pixelCount: Int = -1
private lateinit var yuvBuffer: ByteBuffer
private lateinit var inputAllocation: Allocation
private lateinit var outputAllocation: Allocation
@Synchronized
fun yuvToRgb(image: Image, output: Bitmap) {
// Ensure that the intermediate output byte buffer is allocated
if (!::yuvBuffer.isInitialized) {
pixelCount = image.cropRect.width() * image.cropRect.height()
yuvBuffer = ByteBuffer.allocateDirect(
pixelCount * ImageFormat.getBitsPerPixel(ImageFormat.YUV_420_888) / 8)
}
// Get the YUV data in byte array form
imageToByteBuffer(image, yuvBuffer)
// Ensure that the RenderScript inputs and outputs are allocated
if (!::inputAllocation.isInitialized) {
inputAllocation = Allocation.createSized(rs, Element.U8(rs), yuvBuffer.array().size)
}
if (!::outputAllocation.isInitialized) {
outputAllocation = Allocation.createFromBitmap(rs, output)
}
// Convert YUV to RGB
inputAllocation.copyFrom(yuvBuffer.array())
scriptYuvToRgb.setInput(inputAllocation)
scriptYuvToRgb.forEach(outputAllocation)
outputAllocation.copyTo(output)
}
private fun imageToByteBuffer(image: Image, outputBuffer: ByteBuffer) {
assert(image.format == ImageFormat.YUV_420_888)
val imageCrop = image.cropRect
val imagePlanes = image.planes
val rowData = ByteArray(imagePlanes.first().rowStride)
imagePlanes.forEachIndexed { planeIndex, plane ->
// How many values are read in input for each output value written
// Only the Y plane has a value for every pixel, U and V have half the resolution i.e.
//
// Y Plane U Plane V Plane
// =============== ======= =======
// Y Y Y Y Y Y Y Y U U U U V V V V
// Y Y Y Y Y Y Y Y U U U U V V V V
// Y Y Y Y Y Y Y Y U U U U V V V V
// Y Y Y Y Y Y Y Y U U U U V V V V
// Y Y Y Y Y Y Y Y
// Y Y Y Y Y Y Y Y
// Y Y Y Y Y Y Y Y
val outputStride: Int
// The index in the output buffer the next value will be written at
// For Y it's zero, for U and V we start at the end of Y and interleave them i.e.
//
// First chunk Second chunk
// =============== ===============
// Y Y Y Y Y Y Y Y U V U V U V U V
// Y Y Y Y Y Y Y Y U V U V U V U V
// Y Y Y Y Y Y Y Y U V U V U V U V
// Y Y Y Y Y Y Y Y U V U V U V U V
// Y Y Y Y Y Y Y Y
// Y Y Y Y Y Y Y Y
// Y Y Y Y Y Y Y Y
var outputOffset: Int
when (planeIndex) {
0 -> {
outputStride = 1
outputOffset = 0
}
1 -> {
outputStride = 2
outputOffset = pixelCount + 1
}
2 -> {
outputStride = 2
outputOffset = pixelCount
}
else -> {
// Image contains more than 3 planes, something strange is going on
return@forEachIndexed
}
}
val buffer = plane.buffer
val rowStride = plane.rowStride
val pixelStride = plane.pixelStride
// We have to divide the width and height by two if it's not the Y plane
val planeCrop = if (planeIndex == 0) {
imageCrop
} else {
Rect(
imageCrop.left / 2,
imageCrop.top / 2,
imageCrop.right / 2,
imageCrop.bottom / 2
)
}
val planeWidth = planeCrop.width()
val planeHeight = planeCrop.height()
buffer.position(rowStride * planeCrop.top + pixelStride * planeCrop.left)
for (row in 0 until planeHeight) {
val length: Int
if (pixelStride == 1 && outputStride == 1) {
// When there is a single stride value for pixel and output, we can just copy
// the entire row in a single step
length = planeWidth
buffer.get(outputBuffer.array(), outputOffset, length)
outputOffset += length
} else {
// When either pixel or output have a stride > 1 we must copy pixel by pixel
length = (planeWidth - 1) * pixelStride + 1
buffer.get(rowData, 0, length)
for (col in 0 until planeWidth) {
outputBuffer.array()[outputOffset] = rowData[col * pixelStride]
outputOffset += outputStride
}
}
if (row < planeHeight - 1) {
buffer.position(buffer.position() + rowStride - length)
}
}
}
}
}

@ -0,0 +1,327 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.camera
import android.annotation.SuppressLint
import android.content.Context
import android.graphics.Bitmap
import android.graphics.ImageFormat
import android.graphics.Matrix
import android.graphics.Rect
import android.hardware.camera2.CameraCaptureSession
import android.hardware.camera2.CameraCharacteristics
import android.hardware.camera2.CameraDevice
import android.hardware.camera2.CameraManager
import android.media.ImageReader
import android.os.Handler
import android.os.HandlerThread
import android.util.Log
import android.view.Surface
import android.view.SurfaceView
import kotlinx.coroutines.suspendCancellableCoroutine
import org.tensorflow.lite.examples.poseestimation.VisualizationUtils
import org.tensorflow.lite.examples.poseestimation.YuvToRgbConverter
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.examples.poseestimation.ml.MoveNetMultiPose
import org.tensorflow.lite.examples.poseestimation.ml.PoseClassifier
import org.tensorflow.lite.examples.poseestimation.ml.PoseDetector
import org.tensorflow.lite.examples.poseestimation.ml.TrackerType
import java.util.*
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
class CameraSource(
private val surfaceView: SurfaceView,
private val listener: CameraSourceListener? = null
) {
companion object {
private const val PREVIEW_WIDTH = 640
private const val PREVIEW_HEIGHT = 480
/** Threshold for confidence score. */
private const val MIN_CONFIDENCE = .2f
private const val TAG = "Camera Source"
}
private val lock = Any()
private var detector: PoseDetector? = null
private var classifier: PoseClassifier? = null
private var isTrackerEnabled = false
private var yuvConverter: YuvToRgbConverter = YuvToRgbConverter(surfaceView.context)
private lateinit var imageBitmap: Bitmap
/** Frame count processed so far within the current one-second interval, used to calculate FPS. */
private var fpsTimer: Timer? = null
private var frameProcessedInOneSecondInterval = 0
private var framesPerSecond = 0
/** Detects, characterizes, and connects to a CameraDevice (used for all camera operations) */
private val cameraManager: CameraManager by lazy {
val context = surfaceView.context
context.getSystemService(Context.CAMERA_SERVICE) as CameraManager
}
/** Readers used as buffers for camera still shots */
private var imageReader: ImageReader? = null
/** The [CameraDevice] that will be opened by this camera source */
private var camera: CameraDevice? = null
/** Internal reference to the ongoing [CameraCaptureSession] configured with our parameters */
private var session: CameraCaptureSession? = null
/** [HandlerThread] where all buffer reading operations run */
private var imageReaderThread: HandlerThread? = null
/** [Handler] corresponding to [imageReaderThread] */
private var imageReaderHandler: Handler? = null
private var cameraId: String = ""
suspend fun initCamera() {
camera = openCamera(cameraManager, cameraId)
imageReader =
ImageReader.newInstance(PREVIEW_WIDTH, PREVIEW_HEIGHT, ImageFormat.YUV_420_888, 3)
imageReader?.setOnImageAvailableListener({ reader ->
val image = reader.acquireLatestImage()
if (image != null) {
if (!::imageBitmap.isInitialized) {
imageBitmap =
Bitmap.createBitmap(
PREVIEW_WIDTH,
PREVIEW_HEIGHT,
Bitmap.Config.ARGB_8888
)
}
yuvConverter.yuvToRgb(image, imageBitmap)
// Create rotated version for portrait display
val rotateMatrix = Matrix()
rotateMatrix.postRotate(90.0f)
val rotatedBitmap = Bitmap.createBitmap(
imageBitmap, 0, 0, PREVIEW_WIDTH, PREVIEW_HEIGHT,
rotateMatrix, false
)
processImage(rotatedBitmap)
image.close()
}
}, imageReaderHandler)
imageReader?.surface?.let { surface ->
session = createSession(listOf(surface))
val cameraRequest = camera?.createCaptureRequest(
CameraDevice.TEMPLATE_PREVIEW
)?.apply {
addTarget(surface)
}
cameraRequest?.build()?.let {
session?.setRepeatingRequest(it, null, null)
}
}
}
private suspend fun createSession(targets: List<Surface>): CameraCaptureSession =
suspendCancellableCoroutine { cont ->
camera?.createCaptureSession(targets, object : CameraCaptureSession.StateCallback() {
override fun onConfigured(captureSession: CameraCaptureSession) =
cont.resume(captureSession)
override fun onConfigureFailed(session: CameraCaptureSession) {
cont.resumeWithException(Exception("Session error"))
}
}, null)
}
@SuppressLint("MissingPermission")
private suspend fun openCamera(manager: CameraManager, cameraId: String): CameraDevice =
suspendCancellableCoroutine { cont ->
manager.openCamera(cameraId, object : CameraDevice.StateCallback() {
override fun onOpened(camera: CameraDevice) = cont.resume(camera)
override fun onDisconnected(camera: CameraDevice) {
camera.close()
}
override fun onError(camera: CameraDevice, error: Int) {
if (cont.isActive) cont.resumeWithException(Exception("Camera error"))
}
}, imageReaderHandler)
}
fun prepareCamera() {
for (cameraId in cameraManager.cameraIdList) {
val characteristics = cameraManager.getCameraCharacteristics(cameraId)
// We don't use a front facing camera in this sample.
val cameraDirection = characteristics.get(CameraCharacteristics.LENS_FACING)
if (cameraDirection != null &&
cameraDirection == CameraCharacteristics.LENS_FACING_FRONT
) {
continue
}
this.cameraId = cameraId
}
}
fun setDetector(detector: PoseDetector) {
synchronized(lock) {
if (this.detector != null) {
this.detector?.close()
this.detector = null
}
this.detector = detector
}
}
fun setClassifier(classifier: PoseClassifier?) {
synchronized(lock) {
if (this.classifier != null) {
this.classifier?.close()
this.classifier = null
}
this.classifier = classifier
}
}
/**
* Set the tracker for the MoveNet MultiPose model.
*/
fun setTracker(trackerType: TrackerType) {
isTrackerEnabled = trackerType != TrackerType.OFF
(this.detector as? MoveNetMultiPose)?.setTracker(trackerType)
}
fun resume() {
imageReaderThread = HandlerThread("imageReaderThread").apply { start() }
imageReaderHandler = Handler(imageReaderThread!!.looper)
fpsTimer = Timer()
fpsTimer?.scheduleAtFixedRate(
object : TimerTask() {
override fun run() {
framesPerSecond = frameProcessedInOneSecondInterval
frameProcessedInOneSecondInterval = 0
}
},
0,
1000
)
}
fun close() {
session?.close()
session = null
camera?.close()
camera = null
imageReader?.close()
imageReader = null
stopImageReaderThread()
detector?.close()
detector = null
classifier?.close()
classifier = null
fpsTimer?.cancel()
fpsTimer = null
frameProcessedInOneSecondInterval = 0
framesPerSecond = 0
}
// process image
private fun processImage(bitmap: Bitmap) {
val persons = mutableListOf<Person>()
var classificationResult: List<Pair<String, Float>>? = null
synchronized(lock) {
detector?.estimatePoses(bitmap)?.let {
persons.addAll(it)
// if the model only returns one item, allow running the Pose classifier.
if (persons.isNotEmpty()) {
classifier?.run {
classificationResult = classify(persons[0])
}
}
}
}
frameProcessedInOneSecondInterval++
if (frameProcessedInOneSecondInterval == 1) {
// send fps to view
listener?.onFPSListener(framesPerSecond)
}
// if the model returns only one item, show that item's score.
if (persons.isNotEmpty()) {
listener?.onDetectedInfo(persons[0].score, classificationResult)
}
visualize(persons, bitmap)
}
private fun visualize(persons: List<Person>, bitmap: Bitmap) {
val outputBitmap = VisualizationUtils.drawBodyKeypoints(
bitmap,
persons.filter { it.score > MIN_CONFIDENCE }, isTrackerEnabled
)
val holder = surfaceView.holder
val surfaceCanvas = holder.lockCanvas()
surfaceCanvas?.let { canvas ->
val screenWidth: Int
val screenHeight: Int
val left: Int
val top: Int
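// Scale the output bitmap to fit the surface while preserving its aspect ratio,
// centering it on the canvas.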
if (canvas.height > canvas.width) {
val ratio = outputBitmap.height.toFloat() / outputBitmap.width
screenWidth = canvas.width
left = 0
screenHeight = (canvas.width * ratio).toInt()
top = (canvas.height - screenHeight) / 2
} else {
val ratio = outputBitmap.width.toFloat() / outputBitmap.height
screenHeight = canvas.height
top = 0
screenWidth = (canvas.height * ratio).toInt()
left = (canvas.width - screenWidth) / 2
}
val right: Int = left + screenWidth
val bottom: Int = top + screenHeight
canvas.drawBitmap(
outputBitmap, Rect(0, 0, outputBitmap.width, outputBitmap.height),
Rect(left, top, right, bottom), null
)
surfaceView.holder.unlockCanvasAndPost(canvas)
}
}
private fun stopImageReaderThread() {
imageReaderThread?.quitSafely()
try {
imageReaderThread?.join()
imageReaderThread = null
imageReaderHandler = null
} catch (e: InterruptedException) {
Log.d(TAG, e.message.toString())
}
}
interface CameraSourceListener {
fun onFPSListener(fps: Int)
fun onDetectedInfo(personScore: Float?, poseLabels: List<Pair<String, Float>>?)
}
}

@ -0,0 +1,41 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.data
enum class BodyPart(val position: Int) {
NOSE(0),
LEFT_EYE(1),
RIGHT_EYE(2),
LEFT_EAR(3),
RIGHT_EAR(4),
LEFT_SHOULDER(5),
RIGHT_SHOULDER(6),
LEFT_ELBOW(7),
RIGHT_ELBOW(8),
LEFT_WRIST(9),
RIGHT_WRIST(10),
LEFT_HIP(11),
RIGHT_HIP(12),
LEFT_KNEE(13),
RIGHT_KNEE(14),
LEFT_ANKLE(15),
RIGHT_ANKLE(16);
companion object{
private val map = values().associateBy(BodyPart::position)
fun fromInt(position: Int): BodyPart = map.getValue(position)
}
}

@ -0,0 +1,23 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.data
enum class Device {
CPU,
NNAPI,
GPU
}

@ -0,0 +1,21 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.data
import android.graphics.PointF
data class KeyPoint(val bodyPart: BodyPart, var coordinate: PointF, val score: Float)

@ -0,0 +1,26 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.data
import android.graphics.RectF
data class Person(
var id: Int = -1, // default id is -1
val keyPoints: List<KeyPoint>,
val boundingBox: RectF? = null, // Only MoveNet MultiPose return bounding box.
val score: Float
)

@ -0,0 +1,24 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.data
data class TorsoAndBodyDistance(
val maxTorsoYDistance: Float,
val maxTorsoXDistance: Float,
val maxBodyYDistance: Float,
val maxBodyXDistance: Float
)

@ -0,0 +1,353 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import android.graphics.*
import android.os.SystemClock
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.examples.poseestimation.data.*
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import kotlin.math.abs
import kotlin.math.max
import kotlin.math.min
enum class ModelType {
Lightning,
Thunder
}
class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) :
PoseDetector {
companion object {
private const val MIN_CROP_KEYPOINT_SCORE = .2f
private const val CPU_NUM_THREADS = 4
// Parameters that control how much the crop region should be expanded based on the
// previous frame's body keypoints.
private const val TORSO_EXPANSION_RATIO = 1.9f
private const val BODY_EXPANSION_RATIO = 1.2f
// TFLite file names.
private const val LIGHTNING_FILENAME = "movenet_lightning.tflite"
private const val THUNDER_FILENAME = "movenet_thunder.tflite"
// allow specifying model type.
fun create(context: Context, device: Device, modelType: ModelType): MoveNet {
val options = Interpreter.Options()
var gpuDelegate: GpuDelegate? = null
options.setNumThreads(CPU_NUM_THREADS)
when (device) {
Device.CPU -> {
}
Device.GPU -> {
gpuDelegate = GpuDelegate()
options.addDelegate(gpuDelegate)
}
Device.NNAPI -> options.setUseNNAPI(true)
}
return MoveNet(
Interpreter(
FileUtil.loadMappedFile(
context,
if (modelType == ModelType.Lightning) LIGHTNING_FILENAME
else THUNDER_FILENAME
), options
),
gpuDelegate
)
}
// default to lightning.
fun create(context: Context, device: Device): MoveNet =
create(context, device, ModelType.Lightning)
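// A minimal usage sketch (assumes a Bitmap is already available, e.g. from the camera):
//   val detector = MoveNet.create(context, Device.CPU, ModelType.Thunder)
//   val persons = detector.estimatePoses(bitmap)
//   detector.close()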
}
private var cropRegion: RectF? = null
private var lastInferenceTimeNanos: Long = -1
private val inputWidth = interpreter.getInputTensor(0).shape()[1]
private val inputHeight = interpreter.getInputTensor(0).shape()[2]
private var outputShape: IntArray = interpreter.getOutputTensor(0).shape()
override fun estimatePoses(bitmap: Bitmap): List<Person> {
val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
if (cropRegion == null) {
cropRegion = initRectF(bitmap.width, bitmap.height)
}
var totalScore = 0f
val numKeyPoints = outputShape[2]
val keyPoints = mutableListOf<KeyPoint>()
cropRegion?.run {
val rect = RectF(
(left * bitmap.width),
(top * bitmap.height),
(right * bitmap.width),
(bottom * bitmap.height)
)
val detectBitmap = Bitmap.createBitmap(
rect.width().toInt(),
rect.height().toInt(),
Bitmap.Config.ARGB_8888
)
Canvas(detectBitmap).drawBitmap(
bitmap,
-rect.left,
-rect.top,
null
)
val inputTensor = processInputImage(detectBitmap, inputWidth, inputHeight)
val outputTensor = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32)
val widthRatio = detectBitmap.width.toFloat() / inputWidth
val heightRatio = detectBitmap.height.toFloat() / inputHeight
val positions = mutableListOf<Float>()
inputTensor?.let { input ->
interpreter.run(input.buffer, outputTensor.buffer.rewind())
val output = outputTensor.floatArray
for (idx in 0 until numKeyPoints) {
val x = output[idx * 3 + 1] * inputWidth * widthRatio
val y = output[idx * 3 + 0] * inputHeight * heightRatio
positions.add(x)
positions.add(y)
val score = output[idx * 3 + 2]
keyPoints.add(
KeyPoint(
BodyPart.fromInt(idx),
PointF(
x,
y
),
score
)
)
totalScore += score
}
}
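// Map the keypoints detected in the cropped image back into the original image's
// coordinate space by translating them by the crop region's offset.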
val matrix = Matrix()
val points = positions.toFloatArray()
matrix.postTranslate(rect.left, rect.top)
matrix.mapPoints(points)
keyPoints.forEachIndexed { index, keyPoint ->
keyPoint.coordinate =
PointF(
points[index * 2],
points[index * 2 + 1]
)
}
// new crop region
cropRegion = determineRectF(keyPoints, bitmap.width, bitmap.height)
}
lastInferenceTimeNanos =
SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
return listOf(Person(keyPoints = keyPoints, score = totalScore / numKeyPoints))
}
override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos
override fun close() {
gpuDelegate?.close()
interpreter.close()
cropRegion = null
}
/**
* Prepare input image for detection
*/
private fun processInputImage(bitmap: Bitmap, inputWidth: Int, inputHeight: Int): TensorImage? {
val width: Int = bitmap.width
val height: Int = bitmap.height
val size = if (height > width) width else height
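// Crop to a centered square of the shorter side, then resize to the model's input size.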
val imageProcessor = ImageProcessor.Builder().apply {
add(ResizeWithCropOrPadOp(size, size))
add(ResizeOp(inputWidth, inputHeight, ResizeOp.ResizeMethod.BILINEAR))
}.build()
val tensorImage = TensorImage(DataType.UINT8)
tensorImage.load(bitmap)
return imageProcessor.process(tensorImage)
}
/**
* Defines the default crop region.
* The function provides the initial crop region (pads the full image from both
* sides to make it a square image) when the algorithm cannot reliably determine
* the crop region from the previous frame.
*/
private fun initRectF(imageWidth: Int, imageHeight: Int): RectF {
val xMin: Float
val yMin: Float
val width: Float
val height: Float
if (imageWidth > imageHeight) {
width = 1f
height = imageWidth.toFloat() / imageHeight
xMin = 0f
yMin = (imageHeight / 2f - imageWidth / 2f) / imageHeight
} else {
height = 1f
width = imageHeight.toFloat() / imageWidth
yMin = 0f
xMin = (imageWidth / 2f - imageHeight / 2f) / imageWidth
}
return RectF(
xMin,
yMin,
xMin + width,
yMin + height
)
}
/**
* Checks whether there are enough torso keypoints.
* This function checks whether the model is confident in predicting one of the
* shoulders/hips, which is required to determine a good crop region.
*/
private fun torsoVisible(keyPoints: List<KeyPoint>): Boolean {
return ((keyPoints[BodyPart.LEFT_HIP.position].score > MIN_CROP_KEYPOINT_SCORE).or(
keyPoints[BodyPart.RIGHT_HIP.position].score > MIN_CROP_KEYPOINT_SCORE
)).and(
(keyPoints[BodyPart.LEFT_SHOULDER.position].score > MIN_CROP_KEYPOINT_SCORE).or(
keyPoints[BodyPart.RIGHT_SHOULDER.position].score > MIN_CROP_KEYPOINT_SCORE
)
)
}
/**
* Determines the region to crop the image for the model to run inference on.
* The algorithm uses the detected joints from the previous frame to estimate
* the square region that encloses the full body of the target person and
* centers at the midpoint of two hip joints. The crop size is determined by
* the distances between each joints and the center point.
* When the model is not confident with the four torso joint predictions, the
* function returns a default crop which is the full image padded to square.
*/
private fun determineRectF(
keyPoints: List<KeyPoint>,
imageWidth: Int,
imageHeight: Int
): RectF {
val targetKeyPoints = mutableListOf<KeyPoint>()
keyPoints.forEach {
targetKeyPoints.add(
KeyPoint(
it.bodyPart,
PointF(
it.coordinate.x,
it.coordinate.y
),
it.score
)
)
}
if (torsoVisible(keyPoints)) {
val centerX =
(targetKeyPoints[BodyPart.LEFT_HIP.position].coordinate.x +
targetKeyPoints[BodyPart.RIGHT_HIP.position].coordinate.x) / 2f
val centerY =
(targetKeyPoints[BodyPart.LEFT_HIP.position].coordinate.y +
targetKeyPoints[BodyPart.RIGHT_HIP.position].coordinate.y) / 2f
val torsoAndBodyDistances =
determineTorsoAndBodyDistances(keyPoints, targetKeyPoints, centerX, centerY)
val list = listOf(
torsoAndBodyDistances.maxTorsoXDistance * TORSO_EXPANSION_RATIO,
torsoAndBodyDistances.maxTorsoYDistance * TORSO_EXPANSION_RATIO,
torsoAndBodyDistances.maxBodyXDistance * BODY_EXPANSION_RATIO,
torsoAndBodyDistances.maxBodyYDistance * BODY_EXPANSION_RATIO
)
var cropLengthHalf = list.maxOrNull() ?: 0f
val tmp = listOf(centerX, imageWidth - centerX, centerY, imageHeight - centerY)
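// Clamp the half crop length to the largest distance from the center to any image border.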
cropLengthHalf = min(cropLengthHalf, tmp.maxOrNull() ?: 0f)
val cropCorner = Pair(centerY - cropLengthHalf, centerX - cropLengthHalf)
return if (cropLengthHalf > max(imageWidth, imageHeight) / 2f) {
initRectF(imageWidth, imageHeight)
} else {
val cropLength = cropLengthHalf * 2
RectF(
cropCorner.second / imageWidth,
cropCorner.first / imageHeight,
(cropCorner.second + cropLength) / imageWidth,
(cropCorner.first + cropLength) / imageHeight,
)
}
} else {
return initRectF(imageWidth, imageHeight)
}
}
/**
* Calculates the maximum distance from each keypoint to the center location.
* The function returns the maximum distances from the two sets of keypoints:
* full 17 keypoints and 4 torso keypoints. The returned information will be
* used to determine the crop size. See determineRectF for more detail.
*/
private fun determineTorsoAndBodyDistances(
keyPoints: List<KeyPoint>,
targetKeyPoints: List<KeyPoint>,
centerX: Float,
centerY: Float
): TorsoAndBodyDistance {
val torsoJoints = listOf(
BodyPart.LEFT_SHOULDER.position,
BodyPart.RIGHT_SHOULDER.position,
BodyPart.LEFT_HIP.position,
BodyPart.RIGHT_HIP.position
)
var maxTorsoYRange = 0f
var maxTorsoXRange = 0f
torsoJoints.forEach { joint ->
val distY = abs(centerY - targetKeyPoints[joint].coordinate.y)
val distX = abs(centerX - targetKeyPoints[joint].coordinate.x)
if (distY > maxTorsoYRange) maxTorsoYRange = distY
if (distX > maxTorsoXRange) maxTorsoXRange = distX
}
var maxBodyYRange = 0f
var maxBodyXRange = 0f
for (joint in keyPoints.indices) {
if (keyPoints[joint].score < MIN_CROP_KEYPOINT_SCORE) continue
val distY = abs(centerY - keyPoints[joint].coordinate.y)
val distX = abs(centerX - keyPoints[joint].coordinate.x)
if (distY > maxBodyYRange) maxBodyYRange = distY
if (distX > maxBodyXRange) maxBodyXRange = distX
}
return TorsoAndBodyDistance(
maxTorsoYRange,
maxTorsoXRange,
maxBodyYRange,
maxBodyXRange
)
}
}
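
/*
 * Illustrative sketch, not part of the original example code: a worked instance of
 * the default square crop produced by initRectF when the torso is not visible.
 * The image size below is an assumption chosen only for illustration.
 *
 *   imageWidth = 480, imageHeight = 640 (portrait)
 *   width = 640f / 480 = 1.333, height = 1f
 *   xMin  = (480 / 2f - 640 / 2f) / 480 = -0.167, yMin = 0f
 *   crop  = RectF(-0.167, 0.0, 1.167, 1.0)   // the full image padded to a square
 */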

@ -0,0 +1,309 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import android.graphics.Bitmap
import android.graphics.PointF
import android.graphics.RectF
import android.os.SystemClock
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.data.KeyPoint
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.examples.poseestimation.tracker.*
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.image.ImageOperator
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import kotlin.math.ceil
class MoveNetMultiPose(
private val interpreter: Interpreter,
private val type: Type,
private val gpuDelegate: GpuDelegate?,
) : PoseDetector {
private val outputShape = interpreter.getOutputTensor(0).shape()
private val inputShape = interpreter.getInputTensor(0).shape()
private var imageWidth: Int = 0
private var imageHeight: Int = 0
private var targetWidth: Int = 0
private var targetHeight: Int = 0
private var scaleHeight: Int = 0
private var scaleWidth: Int = 0
private var lastInferenceTimeNanos: Long = -1
private var tracker: AbstractTracker? = null
companion object {
private const val DYNAMIC_MODEL_TARGET_INPUT_SIZE = 256
private const val SHAPE_MULTIPLE = 32.0
private const val DETECTION_THRESHOLD = 0.11
private const val DETECTION_SCORE_INDEX = 55
private const val BOUNDING_BOX_Y_MIN_INDEX = 51
private const val BOUNDING_BOX_X_MIN_INDEX = 52
private const val BOUNDING_BOX_Y_MAX_INDEX = 53
private const val BOUNDING_BOX_X_MAX_INDEX = 54
private const val KEYPOINT_COUNT = 17
private const val OUTPUTS_COUNT_PER_KEYPOINT = 3
private const val CPU_NUM_THREADS = 4
// allow specifying model type.
fun create(
context: Context,
device: Device,
type: Type,
): MoveNetMultiPose {
val options = Interpreter.Options()
var gpuDelegate: GpuDelegate? = null
when (device) {
Device.CPU -> {
options.setNumThreads(CPU_NUM_THREADS)
}
Device.GPU -> {
                    // Only the fixed-shape model supports the GPU delegate option.
if (type == Type.Fixed) {
gpuDelegate = GpuDelegate()
options.addDelegate(gpuDelegate)
}
}
else -> {
// nothing to do
}
}
return MoveNetMultiPose(
Interpreter(
FileUtil.loadMappedFile(
context,
if (type == Type.Dynamic)
"movenet_multipose_fp16.tflite" else ""
//@TODO: (khanhlvg) Add support for fixed shape model if it's released.
), options
), type, gpuDelegate
)
}
}
/**
     * Converts x and y coordinates ([0-1]) returned from the TFLite model
     * to the coordinates corresponding to the input image.
*/
private fun resizeKeypoint(x: Float, y: Float): PointF {
return PointF(resizeX(x), resizeY(y))
}
private fun resizeX(x: Float): Float {
return if (imageWidth > imageHeight) {
val ratioWidth = imageWidth.toFloat() / targetWidth
x * targetWidth * ratioWidth
} else {
val detectedWidth =
if (type == Type.Dynamic) targetWidth else inputShape[2]
val paddingWidth = detectedWidth - scaleWidth
val ratioWidth = imageWidth.toFloat() / scaleWidth
(x * detectedWidth - paddingWidth / 2f) * ratioWidth
}
}
private fun resizeY(y: Float): Float {
return if (imageWidth > imageHeight) {
val detectedHeight =
if (type == Type.Dynamic) targetHeight else inputShape[1]
val paddingHeight = detectedHeight - scaleHeight
val ratioHeight = imageHeight.toFloat() / scaleHeight
(y * detectedHeight - paddingHeight / 2f) * ratioHeight
} else {
val ratioHeight = imageHeight.toFloat() / targetHeight
y * targetHeight * ratioHeight
}
}
/**
* Prepare input image for detection
*/
private fun processInputTensor(bitmap: Bitmap): TensorImage {
imageWidth = bitmap.width
imageHeight = bitmap.height
        // If the model type is fixed, get the input size from the input shape.
val inputSizeHeight =
if (type == Type.Dynamic) DYNAMIC_MODEL_TARGET_INPUT_SIZE else inputShape[1]
val inputSizeWidth =
if (type == Type.Dynamic) DYNAMIC_MODEL_TARGET_INPUT_SIZE else inputShape[2]
val resizeOp: ImageOperator
if (imageWidth > imageHeight) {
val scale = inputSizeWidth / imageWidth.toFloat()
targetWidth = inputSizeWidth
scaleHeight = ceil(imageHeight * scale).toInt()
targetHeight = (ceil((scaleHeight / SHAPE_MULTIPLE)) * SHAPE_MULTIPLE).toInt()
resizeOp = ResizeOp(scaleHeight, targetWidth, ResizeOp.ResizeMethod.BILINEAR)
} else {
val scale = inputSizeHeight / imageHeight.toFloat()
targetHeight = inputSizeHeight
scaleWidth = ceil(imageWidth * scale).toInt()
targetWidth = (ceil((scaleWidth / SHAPE_MULTIPLE)) * SHAPE_MULTIPLE).toInt()
resizeOp = ResizeOp(targetHeight, scaleWidth, ResizeOp.ResizeMethod.BILINEAR)
}
val resizeWithCropOrPad = if (type == Type.Dynamic) ResizeWithCropOrPadOp(
targetHeight,
targetWidth
) else ResizeWithCropOrPadOp(
inputSizeHeight,
inputSizeWidth
)
val imageProcessor = ImageProcessor.Builder().apply {
add(resizeOp)
add(resizeWithCropOrPad)
}.build()
val tensorImage = TensorImage(DataType.UINT8)
tensorImage.load(bitmap)
return imageProcessor.process(tensorImage)
}
/**
* Run tracker (if available) and process the output.
*/
private fun postProcess(modelOutput: FloatArray): List<Person> {
val persons = mutableListOf<Person>()
for (idx in modelOutput.indices step outputShape[2]) {
val personScore = modelOutput[idx + DETECTION_SCORE_INDEX]
if (personScore < DETECTION_THRESHOLD) continue
val positions = modelOutput.copyOfRange(idx, idx + 51)
val keyPoints = mutableListOf<KeyPoint>()
for (i in 0 until KEYPOINT_COUNT) {
val y = positions[i * OUTPUTS_COUNT_PER_KEYPOINT]
val x = positions[i * OUTPUTS_COUNT_PER_KEYPOINT + 1]
val score = positions[i * OUTPUTS_COUNT_PER_KEYPOINT + 2]
keyPoints.add(KeyPoint(BodyPart.fromInt(i), PointF(x, y), score))
}
val yMin = modelOutput[idx + BOUNDING_BOX_Y_MIN_INDEX]
val xMin = modelOutput[idx + BOUNDING_BOX_X_MIN_INDEX]
val yMax = modelOutput[idx + BOUNDING_BOX_Y_MAX_INDEX]
val xMax = modelOutput[idx + BOUNDING_BOX_X_MAX_INDEX]
val boundingBox = RectF(xMin, yMin, xMax, yMax)
persons.add(
Person(
keyPoints = keyPoints,
boundingBox = boundingBox,
score = personScore
)
)
}
if (persons.isEmpty()) return emptyList()
if (tracker == null) {
persons.forEach {
it.keyPoints.forEach { key ->
key.coordinate = resizeKeypoint(key.coordinate.x, key.coordinate.y)
}
}
return persons
} else {
val trackPersons = mutableListOf<Person>()
tracker?.apply(persons, System.currentTimeMillis() * 1000)?.forEach {
val resizeKeyPoint = mutableListOf<KeyPoint>()
it.keyPoints.forEach { key ->
resizeKeyPoint.add(
KeyPoint(
key.bodyPart,
resizeKeypoint(key.coordinate.x, key.coordinate.y),
key.score
)
)
}
var resizeBoundingBox: RectF? = null
it.boundingBox?.let { boundingBox ->
resizeBoundingBox = RectF(
resizeX(boundingBox.left),
resizeY(boundingBox.top),
resizeX(boundingBox.right),
resizeY(boundingBox.bottom)
)
}
trackPersons.add(Person(it.id, resizeKeyPoint, resizeBoundingBox, it.score))
}
return trackPersons
}
}
/**
* Create and set tracker.
*/
fun setTracker(trackerType: TrackerType) {
tracker = when (trackerType) {
TrackerType.BOUNDING_BOX -> {
BoundingBoxTracker()
}
TrackerType.KEYPOINTS -> {
KeyPointsTracker()
}
TrackerType.OFF -> {
null
}
}
}
/**
     * Runs the TFLite model and returns a list of "Person" corresponding to the input image.
*/
override fun estimatePoses(bitmap: Bitmap): List<Person> {
val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
val inputTensor = processInputTensor(bitmap)
val outputTensor = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32)
        // If the model is dynamic, resize the input before running the interpreter.
if (type == Type.Dynamic) {
val inputShape = intArrayOf(1).plus(inputTensor.tensorBuffer.shape)
interpreter.resizeInput(0, inputShape, true)
interpreter.allocateTensors()
}
interpreter.run(inputTensor.buffer, outputTensor.buffer.rewind())
val processedPerson = postProcess(outputTensor.floatArray)
lastInferenceTimeNanos =
SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
return processedPerson
}
override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos
/**
* Close all resources when not in use.
*/
override fun close() {
gpuDelegate?.close()
interpreter.close()
tracker = null
}
}
enum class Type {
Dynamic, Fixed
}
enum class TrackerType {
OFF, BOUNDING_BOX, KEYPOINTS
}
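
// Illustrative usage sketch, not part of the original example code: how a caller
// might wire MoveNetMultiPose together with a tracker. `context` and `bitmap` are
// assumed to be provided by the caller (e.g. an Activity and a camera frame).
fun runMultiPoseExample(context: Context, bitmap: Bitmap): List<Person> {
    val detector = MoveNetMultiPose.create(context, Device.CPU, Type.Dynamic)
    detector.setTracker(TrackerType.BOUNDING_BOX)
    val persons = detector.estimatePoses(bitmap)
    detector.close()
    return persons
}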

@ -0,0 +1,73 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.support.common.FileUtil
class PoseClassifier(
private val interpreter: Interpreter,
private val labels: List<String>
) {
private val input = interpreter.getInputTensor(0).shape()
private val output = interpreter.getOutputTensor(0).shape()
companion object {
private const val MODEL_FILENAME = "classifier.tflite"
private const val LABELS_FILENAME = "labels.txt"
private const val CPU_NUM_THREADS = 4
fun create(context: Context): PoseClassifier {
val options = Interpreter.Options().apply {
setNumThreads(CPU_NUM_THREADS)
}
return PoseClassifier(
Interpreter(
FileUtil.loadMappedFile(
context, MODEL_FILENAME
), options
),
FileUtil.loadLabels(context, LABELS_FILENAME)
)
}
}
fun classify(person: Person?): List<Pair<String, Float>> {
// Preprocess the pose estimation result to a flat array
val inputVector = FloatArray(input[1])
person?.keyPoints?.forEachIndexed { index, keyPoint ->
inputVector[index * 3] = keyPoint.coordinate.y
inputVector[index * 3 + 1] = keyPoint.coordinate.x
inputVector[index * 3 + 2] = keyPoint.score
}
// Postprocess the model output to human readable class names
val outputTensor = FloatArray(output[1])
interpreter.run(arrayOf(inputVector), arrayOf(outputTensor))
        val results = mutableListOf<Pair<String, Float>>()
        outputTensor.forEachIndexed { index, score ->
            results.add(Pair(labels[index], score))
        }
        return results
}
fun close() {
interpreter.close()
}
}
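
// Illustrative usage sketch, not part of the original example code: classify a
// detected pose and keep the highest-scoring label. `context` and `person` are
// assumed to come from the caller (e.g. the first Person returned by a PoseDetector).
fun classifyPoseExample(context: Context, person: Person): Pair<String, Float>? {
    val classifier = PoseClassifier.create(context)
    val best = classifier.classify(person).maxByOrNull { it.second }
    classifier.close()
    return best
}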

@ -0,0 +1,27 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.graphics.Bitmap
import org.tensorflow.lite.examples.poseestimation.data.Person
interface PoseDetector : AutoCloseable {
fun estimatePoses(bitmap: Bitmap): List<Person>
fun lastInferenceTimeNanos(): Long
}

@ -0,0 +1,261 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.ml
import android.content.Context
import android.graphics.Bitmap
import android.graphics.PointF
import android.os.SystemClock
import android.util.Log
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.data.KeyPoint
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
import kotlin.math.exp
class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) :
PoseDetector {
companion object {
private const val CPU_NUM_THREADS = 4
private const val MEAN = 127.5f
private const val STD = 127.5f
private const val TAG = "Posenet"
private const val MODEL_FILENAME = "posenet.tflite"
fun create(context: Context, device: Device): PoseNet {
val options = Interpreter.Options()
var gpuDelegate: GpuDelegate? = null
options.setNumThreads(CPU_NUM_THREADS)
when (device) {
Device.CPU -> {
}
Device.GPU -> {
gpuDelegate = GpuDelegate()
options.addDelegate(gpuDelegate)
}
Device.NNAPI -> options.setUseNNAPI(true)
}
return PoseNet(
Interpreter(
FileUtil.loadMappedFile(
context,
MODEL_FILENAME
), options
),
gpuDelegate
)
}
}
private var lastInferenceTimeNanos: Long = -1
private val inputWidth = interpreter.getInputTensor(0).shape()[1]
private val inputHeight = interpreter.getInputTensor(0).shape()[2]
private var cropHeight = 0f
private var cropWidth = 0f
private var cropSize = 0
@Suppress("UNCHECKED_CAST")
override fun estimatePoses(bitmap: Bitmap): List<Person> {
val estimationStartTimeNanos = SystemClock.elapsedRealtimeNanos()
val inputArray = arrayOf(processInputImage(bitmap).tensorBuffer.buffer)
Log.i(
TAG,
String.format(
"Scaling to [-1,1] took %.2f ms",
(SystemClock.elapsedRealtimeNanos() - estimationStartTimeNanos) / 1_000_000f
)
)
val outputMap = initOutputMap(interpreter)
val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
interpreter.runForMultipleInputsOutputs(inputArray, outputMap)
lastInferenceTimeNanos = SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
Log.i(
TAG,
String.format("Interpreter took %.2f ms", 1.0f * lastInferenceTimeNanos / 1_000_000)
)
val heatmaps = outputMap[0] as Array<Array<Array<FloatArray>>>
val offsets = outputMap[1] as Array<Array<Array<FloatArray>>>
val postProcessingStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        val person = postProcessModelOutputs(heatmaps, offsets)
Log.i(
TAG,
String.format(
"Postprocessing took %.2f ms",
(SystemClock.elapsedRealtimeNanos() - postProcessingStartTimeNanos) / 1_000_000f
)
)
return listOf(person)
}
/**
     * Converts the heatmaps and offsets output of PoseNet into a list of keypoints.
     */
    private fun postProcessModelOutputs(
heatmaps: Array<Array<Array<FloatArray>>>,
offsets: Array<Array<Array<FloatArray>>>
): Person {
val height = heatmaps[0].size
val width = heatmaps[0][0].size
val numKeypoints = heatmaps[0][0][0].size
        // Finds the (row, col) locations where the keypoints are most likely to be.
val keypointPositions = Array(numKeypoints) { Pair(0, 0) }
for (keypoint in 0 until numKeypoints) {
var maxVal = heatmaps[0][0][0][keypoint]
var maxRow = 0
var maxCol = 0
for (row in 0 until height) {
for (col in 0 until width) {
if (heatmaps[0][row][col][keypoint] > maxVal) {
maxVal = heatmaps[0][row][col][keypoint]
maxRow = row
maxCol = col
}
}
}
keypointPositions[keypoint] = Pair(maxRow, maxCol)
}
// Calculating the x and y coordinates of the keypoints with offset adjustment.
val xCoords = IntArray(numKeypoints)
val yCoords = IntArray(numKeypoints)
val confidenceScores = FloatArray(numKeypoints)
keypointPositions.forEachIndexed { idx, position ->
val positionY = keypointPositions[idx].first
val positionX = keypointPositions[idx].second
val inputImageCoordinateY =
position.first / (height - 1).toFloat() * inputHeight + offsets[0][positionY][positionX][idx]
val ratioHeight = cropSize.toFloat() / inputHeight
val paddingHeight = cropHeight / 2
yCoords[idx] = (inputImageCoordinateY * ratioHeight - paddingHeight).toInt()
val inputImageCoordinateX =
position.second / (width - 1).toFloat() * inputWidth + offsets[0][positionY][positionX][idx + numKeypoints]
val ratioWidth = cropSize.toFloat() / inputWidth
val paddingWidth = cropWidth / 2
xCoords[idx] = (inputImageCoordinateX * ratioWidth - paddingWidth).toInt()
confidenceScores[idx] = sigmoid(heatmaps[0][positionY][positionX][idx])
}
val keypointList = mutableListOf<KeyPoint>()
var totalScore = 0.0f
enumValues<BodyPart>().forEachIndexed { idx, it ->
keypointList.add(
KeyPoint(
it,
PointF(xCoords[idx].toFloat(), yCoords[idx].toFloat()),
confidenceScores[idx]
)
)
totalScore += confidenceScores[idx]
}
return Person(keyPoints = keypointList.toList(), score = totalScore / numKeypoints)
}
override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos
override fun close() {
gpuDelegate?.close()
interpreter.close()
}
/**
* Scale and crop the input image to a TensorImage.
*/
private fun processInputImage(bitmap: Bitmap): TensorImage {
// reset crop width and height
cropWidth = 0f
cropHeight = 0f
cropSize = if (bitmap.width > bitmap.height) {
cropHeight = (bitmap.width - bitmap.height).toFloat()
bitmap.width
} else {
cropWidth = (bitmap.height - bitmap.width).toFloat()
bitmap.height
}
val imageProcessor = ImageProcessor.Builder().apply {
add(ResizeWithCropOrPadOp(cropSize, cropSize))
add(ResizeOp(inputWidth, inputHeight, ResizeOp.ResizeMethod.BILINEAR))
add(NormalizeOp(MEAN, STD))
}.build()
val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(bitmap)
return imageProcessor.process(tensorImage)
}
/**
* Initializes an outputMap of 1 * x * y * z FloatArrays for the model processing to populate.
*/
private fun initOutputMap(interpreter: Interpreter): HashMap<Int, Any> {
val outputMap = HashMap<Int, Any>()
// 1 * 9 * 9 * 17 contains heatmaps
val heatmapsShape = interpreter.getOutputTensor(0).shape()
outputMap[0] = Array(heatmapsShape[0]) {
Array(heatmapsShape[1]) {
Array(heatmapsShape[2]) { FloatArray(heatmapsShape[3]) }
}
}
// 1 * 9 * 9 * 34 contains offsets
val offsetsShape = interpreter.getOutputTensor(1).shape()
outputMap[1] = Array(offsetsShape[0]) {
Array(offsetsShape[1]) { Array(offsetsShape[2]) { FloatArray(offsetsShape[3]) } }
}
// 1 * 9 * 9 * 32 contains forward displacements
val displacementsFwdShape = interpreter.getOutputTensor(2).shape()
        outputMap[2] = Array(displacementsFwdShape[0]) {
Array(displacementsFwdShape[1]) {
Array(displacementsFwdShape[2]) { FloatArray(displacementsFwdShape[3]) }
}
}
// 1 * 9 * 9 * 32 contains backward displacements
val displacementsBwdShape = interpreter.getOutputTensor(3).shape()
outputMap[3] = Array(displacementsBwdShape[0]) {
Array(displacementsBwdShape[1]) {
Array(displacementsBwdShape[2]) { FloatArray(displacementsBwdShape[3]) }
}
}
return outputMap
}
/** Returns value within [0,1]. */
private fun sigmoid(x: Float): Float {
return (1.0f / (1.0f + exp(-x)))
}
}
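
/*
 * Illustrative sketch, not part of the original example code: a worked instance of
 * the keypoint decoding in postProcessModelOutputs. All numbers are assumptions
 * chosen only for illustration (9x9 heatmap, 257x257 input, 640-pixel square crop).
 *
 *   peak at (row = 4, col = 6), offset = 3.5
 *   inputImageCoordinateY = 4 / 8f * 257 + 3.5 = 132.0
 *   y = 132.0 * (640f / 257) - cropHeight / 2
 *
 * The raw heatmap value at the peak is then passed through sigmoid() to obtain the
 * keypoint confidence in [0, 1].
 */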

@ -0,0 +1,158 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.tracker
import org.tensorflow.lite.examples.poseestimation.data.Person
abstract class AbstractTracker(val config: TrackerConfig) {
private val maxAge = config.maxAge * 1000 // convert milliseconds to microseconds
private var nextTrackId = 0
var tracks = mutableListOf<Track>()
private set
/**
* Computes pairwise similarity scores between detections and tracks, based
* on detected features.
     * @param persons A list of detected persons.
     * @return A list of shape [num_det, num_tracks] with pairwise
* similarity scores between detections and tracks.
*/
abstract fun computeSimilarity(persons: List<Person>): List<List<Float>>
/**
* Tracks person instances across frames based on detections.
     * @param persons A list of persons.
     * @param timestamp The current timestamp in microseconds.
     * @return An updated list of persons with tracking ids.
*/
fun apply(persons: List<Person>, timestamp: Long): List<Person> {
tracks = filterOldTrack(timestamp).toMutableList()
val simMatrix = computeSimilarity(persons)
assignTrack(persons, simMatrix, timestamp)
tracks = updateTrack().toMutableList()
return persons
}
/**
     * Clears all tracks in the track list.
*/
fun reset() {
tracks.clear()
}
/**
* Return the next track id
*/
private fun nextTrackID() = ++nextTrackId
/**
* Performs a greedy optimization to link detections with tracks. The person
* list is updated in place by providing an `id` property. If incoming
* detections are not linked with existing tracks, new tracks will be created.
     * @param persons A list of detected persons. It's assumed that persons are
* sorted from most confident to least confident.
* @param simMatrix A list of shape [num_det, num_tracks] with pairwise
* similarity scores between detections and tracks.
* @param timestamp The current timestamp in microseconds.
*/
private fun assignTrack(persons: List<Person>, simMatrix: List<List<Float>>, timestamp: Long) {
        if (simMatrix.size != persons.size ||
            (simMatrix.isNotEmpty() && simMatrix[0].size != tracks.size)
        ) {
            throw IllegalArgumentException(
                "Size of person array and similarity matrix does not match.")
        }
val unmatchedTrackIndices = MutableList(tracks.size) { it }
val unmatchedDetectionIndices = mutableListOf<Int>()
for (detectionIndex in persons.indices) {
// If the track list is empty, add the person's index
// to unmatched detections to create a new track later.
if (unmatchedTrackIndices.isEmpty()) {
unmatchedDetectionIndices.add(detectionIndex)
continue
}
// Assign the detection to the track which produces the highest pairwise
// similarity score, assuming the score exceeds the minimum similarity
// threshold.
var maxTrackIndex = -1
var maxSimilarity = -1f
unmatchedTrackIndices.forEach { trackIndex ->
val similarity = simMatrix[detectionIndex][trackIndex]
if (similarity >= config.minSimilarity && similarity > maxSimilarity) {
maxTrackIndex = trackIndex
maxSimilarity = similarity
}
}
if (maxTrackIndex >= 0) {
val linkedTrack = tracks[maxTrackIndex]
tracks[maxTrackIndex] =
createTrack(persons[detectionIndex], linkedTrack.person.id, timestamp)
persons[detectionIndex].id = linkedTrack.person.id
val index = unmatchedTrackIndices.indexOf(maxTrackIndex)
unmatchedTrackIndices.removeAt(index)
} else {
unmatchedDetectionIndices.add(detectionIndex)
}
}
// Spawn new tracks for all unmatched detections.
unmatchedDetectionIndices.forEach { detectionIndex ->
val newTrack = createTrack(persons[detectionIndex], timestamp = timestamp)
tracks.add(newTrack)
persons[detectionIndex].id = newTrack.person.id
}
}
/**
* Filters tracks based on their age.
* @param timestamp The timestamp in microseconds
*/
private fun filterOldTrack(timestamp: Long): List<Track> {
return tracks.filter {
timestamp - it.lastTimestamp <= maxAge
}
}
/**
     * Sorts the track list by timestamp (newest first)
     * and returns at most config.maxTracks tracks.
*/
private fun updateTrack(): List<Track> {
tracks.sortByDescending { it.lastTimestamp }
return tracks.take(config.maxTracks)
}
/**
     * Creates a new track from the person's information.
     * @param person A person.
     * @param id The id assigned to the new track. If null, the next track id is used.
     * @param timestamp The timestamp in microseconds.
*/
private fun createTrack(person: Person, id: Int? = null, timestamp: Long): Track {
return Track(
person = Person(
id = id ?: nextTrackID(),
keyPoints = person.keyPoints,
boundingBox = person.boundingBox,
score = person.score
),
lastTimestamp = timestamp
)
}
}
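
/*
 * Illustrative sketch, not part of the original example code: a worked instance of
 * the greedy assignment in assignTrack, with an assumed 2-detection x 2-track
 * similarity matrix and the default minSimilarity of 0.15.
 *
 *   simMatrix = [[0.80, 0.10],
 *                [0.40, 0.05]]
 *
 * Detection 0 links to track 0 (0.80) and track 0 leaves the unmatched list;
 * detection 1 then only sees track 1 (0.05 < 0.15), so it spawns a new track with
 * the next track id.
 */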

@ -0,0 +1,63 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.tracker
import androidx.annotation.VisibleForTesting
import org.tensorflow.lite.examples.poseestimation.data.Person
import kotlin.math.max
import kotlin.math.min
/**
* BoundingBoxTracker, which tracks objects based on bounding box similarity,
* currently defined as intersection-over-union (IoU).
*/
class BoundingBoxTracker(config: TrackerConfig = TrackerConfig()) : AbstractTracker(config) {
/**
* Computes similarity based on intersection-over-union (IoU). See `AbstractTracker`
* for more details.
*/
override fun computeSimilarity(persons: List<Person>): List<List<Float>> {
if (persons.isEmpty() && tracks.isEmpty()) {
return emptyList()
}
return persons.map { person -> tracks.map { track -> iou(person, track.person) } }
}
/**
* Computes the intersection-over-union (IoU) between a person and a track person.
* @param person1 A person
* @param person2 A track person
* @return The IoU between the person and the track person. This number is
* between 0 and 1, and larger values indicate more box similarity.
*/
@VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
fun iou(person1: Person, person2: Person): Float {
if (person1.boundingBox != null && person2.boundingBox != null) {
val xMin = max(person1.boundingBox.left, person2.boundingBox.left)
val yMin = max(person1.boundingBox.top, person2.boundingBox.top)
val xMax = min(person1.boundingBox.right, person2.boundingBox.right)
val yMax = min(person1.boundingBox.bottom, person2.boundingBox.bottom)
if (xMin >= xMax || yMin >= yMax) return 0f
val intersection = (xMax - xMin) * (yMax - yMin)
val areaPerson = person1.boundingBox.width() * person1.boundingBox.height()
val areaTrack = person2.boundingBox.width() * person2.boundingBox.height()
return intersection / (areaPerson + areaTrack - intersection)
}
return 0f
}
}
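
/*
 * Illustrative sketch, not part of the original example code: a worked IoU value
 * for two assumed normalized bounding boxes.
 *
 *   person box: RectF(0.00, 0.00, 0.50, 0.50)   area = 0.25
 *   track box:  RectF(0.25, 0.25, 0.75, 0.75)   area = 0.25
 *   intersection = (0.50 - 0.25) * (0.50 - 0.25) = 0.0625
 *   IoU = 0.0625 / (0.25 + 0.25 - 0.0625) ≈ 0.143
 *
 * With the default minSimilarity of 0.15, this pair would not be linked by
 * BoundingBoxTracker.
 */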

@ -0,0 +1,125 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.tracker
import androidx.annotation.VisibleForTesting
import org.tensorflow.lite.examples.poseestimation.data.KeyPoint
import org.tensorflow.lite.examples.poseestimation.data.Person
import kotlin.math.exp
import kotlin.math.max
import kotlin.math.min
import kotlin.math.pow
/**
 * KeyPointsTracker, which tracks poses based on keypoint similarity. This
* tracker assumes that keypoints are provided in normalized image
* coordinates.
*/
class KeyPointsTracker(
trackerConfig: TrackerConfig = TrackerConfig(
keyPointsTrackerParams = KeyPointsTrackerParams()
)
) : AbstractTracker(trackerConfig) {
/**
* Computes similarity based on Object Keypoint Similarity (OKS). It's
* assumed that the keypoints within each person are in normalized image
* coordinates. See `AbstractTracker` for more details.
*/
override fun computeSimilarity(persons: List<Person>): List<List<Float>> {
if (persons.isEmpty() && tracks.isEmpty()) {
return emptyList()
}
val simMatrix = mutableListOf<MutableList<Float>>()
persons.forEach { person ->
val row = mutableListOf<Float>()
tracks.forEach { track ->
val oksValue = oks(person, track.person)
row.add(oksValue)
}
simMatrix.add(row)
}
return simMatrix
}
/**
* Computes the Object Keypoint Similarity (OKS) between a person and track person.
* This is similar in spirit to the calculation used by COCO keypoint eval:
* https://cocodataset.org/#keypoints-eval
* In this case, OKS is calculated as:
* (1/sum_i d(c_i, c_ti)) * \sum_i exp(-d_i^2/(2*a_ti*x_i^2))*d(c_i, c_ti)
* where
* d(x, y) is an indicator function which only produces 1 if x and y
* exceed a given threshold (i.e. keypointThreshold), otherwise 0.
* c_i is the confidence of keypoint i from the new person
* c_ti is the confidence of keypoint i from the track person
* d_i is the Euclidean distance between the person and track person keypoint
* a_ti is the area of the track object (the box covering the keypoints)
* x_i is a constant that controls falloff in a Gaussian distribution,
* computed as 2*keypointFalloff[i].
* @param person1 A person.
* @param person2 A track person.
     * @return The OKS score between the person and the track person. This number is
* between 0 and 1, and larger values indicate more keypoint similarity.
*/
@VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
fun oks(person1: Person, person2: Person): Float {
if (config.keyPointsTrackerParams == null) return 0f
person2.keyPoints.let { keyPoints ->
val boxArea = area(keyPoints) + 1e-6
var oksTotal = 0f
var numValidKeyPoints = 0
person1.keyPoints.forEachIndexed { index, _ ->
val poseKpt = person1.keyPoints[index]
val trackKpt = person2.keyPoints[index]
val threshold = config.keyPointsTrackerParams.keypointThreshold
if (poseKpt.score < threshold || trackKpt.score < threshold) {
return@forEachIndexed
}
numValidKeyPoints += 1
val dSquared: Float =
(poseKpt.coordinate.x - trackKpt.coordinate.x).pow(2) + (poseKpt.coordinate.y - trackKpt.coordinate.y).pow(
2
)
val x = 2f * config.keyPointsTrackerParams.keypointFalloff[index]
oksTotal += exp(-1f * dSquared / (2f * boxArea * x.pow(2))).toFloat()
}
if (numValidKeyPoints < config.keyPointsTrackerParams.minNumKeyPoints) {
return 0f
}
return oksTotal / numValidKeyPoints
}
}
/**
* Computes the area of a bounding box that tightly covers keypoints.
* @param keyPoints A list of KeyPoint.
     * @return The area of the object.
*/
@VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
fun area(keyPoints: List<KeyPoint>): Float {
val validKeypoint = keyPoints.filter {
            it.score > (config.keyPointsTrackerParams?.keypointThreshold ?: 0f)
}
if (validKeypoint.isEmpty()) return 0f
val minX = min(1f, validKeypoint.minOf { it.coordinate.x })
val maxX = max(0f, validKeypoint.maxOf { it.coordinate.x })
val minY = min(1f, validKeypoint.minOf { it.coordinate.y })
val maxY = max(0f, validKeypoint.maxOf { it.coordinate.y })
return (maxX - minX) * (maxY - minY)
}
}
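
/*
 * Illustrative sketch, not part of the original example code: one term of the OKS
 * sum in oks(), with assumed values (box area a = 0.1, squared keypoint distance
 * d^2 = 0.001, shoulder falloff sigma = 0.079, so x = 2 * 0.079 = 0.158):
 *
 *   exp(-d^2 / (2 * a * x^2)) = exp(-0.001 / (2 * 0.1 * 0.158^2)) ≈ 0.82
 *
 * The final score is the mean of such terms over the keypoints whose confidence
 * exceeds keypointThreshold in both the detection and the tracked person.
 */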

@ -0,0 +1,24 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.tracker
import org.tensorflow.lite.examples.poseestimation.data.Person
data class Track(
val person: Person,
val lastTimestamp: Long
)

@ -0,0 +1,50 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.lite.examples.poseestimation.tracker
data class TrackerConfig(
val maxTracks: Int = MAX_TRACKS,
val maxAge: Int = MAX_AGE,
val minSimilarity: Float = MIN_SIMILARITY,
val keyPointsTrackerParams: KeyPointsTrackerParams? = null
) {
companion object {
private const val MAX_TRACKS = 18
private const val MAX_AGE = 1000 // millisecond
private const val MIN_SIMILARITY = 0.15f
}
}
data class KeyPointsTrackerParams(
val keypointThreshold: Float = KEYPOINT_THRESHOLD,
    // List of per-keypoint standard deviations `σ`. Keypoints on a person's body (shoulders, knees, hips, etc.)
    // tend to have a `σ` much larger than those on a person's head (eyes, nose, ears).
// Read more at: https://cocodataset.org/#keypoints-eval
val keypointFalloff: List<Float> = KEYPOINT_FALLOFF,
val minNumKeyPoints: Int = MIN_NUM_KEYPOINT
) {
companion object {
// From COCO:
// https://cocodataset.org/#keypoints-eval
private val KEYPOINT_FALLOFF: List<Float> = listOf(
0.026f, 0.025f, 0.025f, 0.035f, 0.035f, 0.079f, 0.079f, 0.072f, 0.072f, 0.062f,
0.062f, 0.107f, 0.107f, 0.087f, 0.087f, 0.089f, 0.089f
)
private const val KEYPOINT_THRESHOLD = 0.3f
private const val MIN_NUM_KEYPOINT = 4
}
}
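
// Illustrative sketch, not part of the original example code: a KeyPointsTracker
// configured with a stricter similarity threshold. All values below are assumptions
// chosen only for illustration.
val strictKeyPointTrackerExample = KeyPointsTracker(
    TrackerConfig(
        maxTracks = 6,
        maxAge = 500, // milliseconds
        minSimilarity = 0.3f,
        keyPointsTrackerParams = KeyPointsTrackerParams(minNumKeyPoints = 6)
    )
)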

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 929 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 596 B

@ -0,0 +1,18 @@
<?xml version="1.0" encoding="utf-8"?>
<shape xmlns:android="http://schemas.android.com/apk/res/android"
android:shape="rectangle">
<solid
android:color="#e4e4e4">
</solid>
<stroke
android:width="0dp"
android:color="#424242">
</stroke>
<corners
android:topLeftRadius="30dp"
android:topRightRadius="30dp"
android:bottomLeftRadius="0dp"
android:bottomRightRadius="0dp">
</corners>
</shape>

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

@ -0,0 +1,27 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.coordinatorlayout.widget.CoordinatorLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<SurfaceView
android:id="@+id/surfaceView"
android:layout_width="match_parent"
android:layout_height="match_parent" />
<androidx.appcompat.widget.Toolbar
android:id="@+id/toolbar"
android:layout_width="match_parent"
android:layout_height="?attr/actionBarSize"
android:background="#66000000">
<ImageView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:contentDescription="@null"
android:src="@drawable/tfl2_logo" />
</androidx.appcompat.widget.Toolbar>
<include layout="@layout/bottom_sheet_layout"/>
</androidx.coordinatorlayout.widget.CoordinatorLayout>

@ -0,0 +1,124 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.appcompat.widget.LinearLayoutCompat xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:id="@+id/bottom_sheet"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_margin="8dp"
android:background="@drawable/rounded_edge"
android:orientation="vertical"
android:paddingStart="8dp"
android:paddingTop="10dp"
android:paddingEnd="8dp"
android:paddingBottom="16dp"
app:behavior_hideable="false"
app:behavior_peekHeight="36dp"
app:layout_behavior="com.google.android.material.bottomsheet.BottomSheetBehavior"
tools:showIn="@layout/activity_main">
<ImageView
android:id="@+id/bottom_sheet_arrow"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_gravity="center"
android:contentDescription="@null"
android:src="@drawable/icn_chevron_up" />
<TextView
android:id="@+id/tvFps"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
<TextView
android:id="@+id/tvScore"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/tfe_pe_tv_device" />
<Spinner
android:id="@+id/spnDevice"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
</LinearLayout>
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/tfe_pe_tv_model" />
<Spinner
android:id="@+id/spnModel"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
</LinearLayout>
<LinearLayout
android:id="@+id/vTrackerOption"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/tfe_pe_tv_tracking" />
<Spinner
android:id="@+id/spnTracker"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
</LinearLayout>
<RelativeLayout
android:id="@+id/vClassificationOption"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal">
<TextView
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_centerVertical="true"
android:layout_toStartOf="@id/swPoseClassification"
android:text="@string/tfe_pe_tv_pose_classification" />
<androidx.appcompat.widget.SwitchCompat
android:id="@+id/swPoseClassification"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_alignParentEnd="true" />
</RelativeLayout>
<TextView
android:id="@+id/tvClassificationValue1"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:visibility="gone" />
<TextView
android:id="@+id/tvClassificationValue2"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:visibility="gone" />
<TextView
android:id="@+id/tvClassificationValue3"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:visibility="gone" />
</androidx.appcompat.widget.LinearLayoutCompat>

@ -0,0 +1,9 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="purple_500">#FF6200EE</color>
<color name="purple_700">#FF3700B3</color>
<color name="teal_200">#FF03DAC5</color>
<color name="teal_700">#FF018786</color>
<color name="black">#FF000000</color>
<color name="white">#FFFFFFFF</color>
</resources>

@ -0,0 +1,30 @@
<resources>
<string name="tfe_pe_app_name">TFL Pose Estimation</string>
<string name="tfe_pe_request_permission">This app needs camera permission.</string>
<string name="tfe_pe_tv_score">Score: %.2f</string>
<string name="tfe_pe_tv_model">Model: </string>
<string name="tfe_pe_tv_fps">Fps: %d</string>
<string name="tfe_pe_tv_device">Device: </string>
<string name="tfe_pe_tv_classification_value">- %s</string>
<string name="tfe_pe_tv_pose_classification">Pose Classification</string>
<string name="tfe_pe_tv_tracking">Tracker: </string>
    <string name="tfe_pe_gpu_error">Movenet MultiPose does not support GPU. Falling back to CPU.</string>
<string-array name="tfe_pe_models_array">
<item>Movenet Lightning</item>
<item>Movenet Thunder</item>
<item>Movenet MultiPose</item>
<item>Posenet</item>
</string-array>
<string-array name="tfe_pe_device_name">
<item>CPU</item>
<item>GPU</item>
<item>NNAPI</item>
</string-array>
<string-array name="tfe_pe_tracker_array">
<item>Off</item>
<item>BoundingBox</item>
<item>Keypoint</item>
</string-array>
</resources>

@ -0,0 +1,18 @@
<resources xmlns:tools="http://schemas.android.com/tools">
<!-- Base application theme. -->
<style name="Theme.PoseEstimation" parent="Theme.MaterialComponents.Light.NoActionBar">
<!-- Primary brand color. -->
<item name="colorPrimary">@color/purple_500</item>
<item name="colorPrimaryVariant">@color/purple_700</item>
<item name="colorOnPrimary">@color/white</item>
<!-- Secondary brand color. -->
<item name="colorSecondary">@color/teal_200</item>
<item name="colorSecondaryVariant">@color/teal_700</item>
<item name="colorOnSecondary">@color/black</item>
<!-- Status bar color. -->
<item name="android:statusBarColor" tools:targetApi="l">?attr/colorPrimaryVariant</item>
<!-- Customize your theme here. -->
<item name="android:textColor">@color/black</item>
<item name="android:textSize">16sp</item>
</style>
</resources>

@ -0,0 +1,30 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
buildscript {
ext.kotlin_version = "1.9.21"
repositories {
google()
mavenCentral()
}
dependencies {
classpath "com.android.tools.build:gradle:8.0.2"
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
// NOTE: Do not place your application dependencies here; they belong
// in the individual module build.gradle files
}
}
allprojects {
repositories {
google()
mavenCentral()
maven {
name 'ossrh-snapshot'
url 'https://oss.sonatype.org/content/repositories/snapshots'
}
}
}
task clean(type: Delete) {
delete rootProject.buildDir
}

@ -0,0 +1,21 @@
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
# AndroidX package structure to make it clearer which packages are bundled with the
# Android operating system, and which are packaged with your app's APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true
# Automatically convert third-party libraries to use AndroidX
android.enableJetifier=true
# Kotlin code style for this project: "official" or "obsolete":
kotlin.code.style=official

Binary file not shown.

@ -0,0 +1,7 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip
networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

249
android/gradlew vendored

@ -0,0 +1,249 @@
#!/bin/sh
#
# Copyright © 2015-2021 the original authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
#
# Gradle start up script for POSIX generated by Gradle.
#
# Important for running:
#
# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is
# noncompliant, but you have some other compliant shell such as ksh or
# bash, then to run this script, type that shell name before the whole
# command line, like:
#
# ksh Gradle
#
# Busybox and similar reduced shells will NOT work, because this script
# requires all of these POSIX shell features:
# * functions;
# * expansions «$var», «${var}», «${var:-default}», «${var+SET}»,
# «${var#prefix}», «${var%suffix}», and «$( cmd )»;
# * compound commands having a testable exit status, especially «case»;
# * various built-in commands including «command», «set», and «ulimit».
#
# Important for patching:
#
# (2) This script targets any POSIX shell, so it avoids extensions provided
# by Bash, Ksh, etc; in particular arrays are avoided.
#
# The "traditional" practice of packing multiple parameters into a
# space-separated string is a well documented source of bugs and security
# problems, so this is (mostly) avoided, by progressively accumulating
# options in "$@", and eventually passing that to Java.
#
# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS,
# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly;
# see the in-line comments for details.
#
# There are tweaks for specific operating systems such as AIX, CygWin,
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
#
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
app_path=$0
# Need this for daisy-chained symlinks.
while
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
[ -h "$app_path" ]
do
ls=$( ls -ld "$app_path" )
link=${ls#*' -> '}
case $link in #(
/*) app_path=$link ;; #(
*) app_path=$APP_HOME$link ;;
esac
done
# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
warn () {
echo "$*"
} >&2
die () {
echo
echo "$*"
echo
exit 1
} >&2
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "$( uname )" in #(
CYGWIN* ) cygwin=true ;; #(
Darwin* ) darwin=true ;; #(
MSYS* | MINGW* ) msys=true ;; #(
NONSTOP* ) nonstop=true ;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD=$JAVA_HOME/jre/sh/java
else
JAVACMD=$JAVA_HOME/bin/java
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD=java
if ! command -v java >/dev/null 2>&1
then
die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
fi
# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC2039,SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC2039,SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
fi
# Collect all arguments for the java command, stacking in reverse order:
# * args from the command line
# * the main class name
# * -classpath
# * -D...appname settings
# * --module-path (only if needed)
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
# For Cygwin or MSYS, switch paths to Windows format before running java
if "$cygwin" || "$msys" ; then
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
JAVACMD=$( cygpath --unix "$JAVACMD" )
# Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg do
if
case $arg in #(
-*) false ;; # don't mess with options #(
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
[ -e "$t" ] ;; #(
*) false ;;
esac
then
arg=$( cygpath --path --ignore --mixed "$arg" )
fi
# Roll the args list around exactly as many times as the number of
# args, so each arg winds up back in the position where it started, but
# possibly modified.
#
# NB: a `for` loop captures its iteration list before it begins, so
# changing the positional parameters here affects neither the number of
# iterations, nor the values presented in `arg`.
shift # remove old arg
set -- "$@" "$arg" # push replacement arg
done
fi
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Collect all arguments for the java command:
# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments,
# and any embedded shellness will be escaped.
# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be
# treated as '${Hostname}' itself on the command line.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \
org.gradle.wrapper.GradleWrapperMain \
"$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1
then
die "xargs is not available"
fi
# Use "xargs" to parse quoted args.
#
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
#
# In Bash we could simply go:
#
# readarray ARGS < <( xargs -n1 <<<"$var" ) &&
# set -- "${ARGS[@]}" "$@"
#
# but POSIX shell has neither arrays nor command substitution, so instead we
# post-process each arg (as a line of input to sed) to backslash-escape any
# character that might be a shell metacharacter, then use eval to reverse
# that process (while maintaining the separation between arguments), and wrap
# the whole thing up as a single "set" statement.
#
# This will of course break if any of these variables contains a newline or
# an unmatched quote.
#
eval "set -- $(
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
xargs -n1 |
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
tr '\n' ' '
)" '"$@"'
exec "$JAVACMD" "$@"

@ -0,0 +1,92 @@
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%"=="" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL%
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega

@ -0,0 +1,8 @@
## This file must *NOT* be checked into Version Control Systems,
# as it contains information specific to your local configuration.
#
# Location of the SDK. This is only used by Gradle.
# For customization when using a Version Control System, please read the
# header note.
#Mon Apr 21 08:12:48 CST 2025
sdk.dir=/Users/ziyue/Library/Android/sdk

Binary file not shown.

After: image added, 602 KiB

@ -0,0 +1,2 @@
include ':app'
rootProject.name = "TFLite Pose Estimation"

@ -1,4 +1,3 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="jbr-21" project-jdk-type="JavaSDK">

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>

@ -21,6 +21,8 @@
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
<activity android:name=".ui.PlanActivity" />
<activity android:name=".ui.SettingsActivity" />
</application>
</manifest>

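The two activity entries added above register PlanActivity and SettingsActivity with the system; an explicit startActivity() call targeting an undeclared activity fails at runtime with an ActivityNotFoundException. A minimal launch sketch for reference only (it is not part of this commit; the actual wiring is in MainActivity further down, and the surrounding activity/context is assumed):

    // Illustrative only: launch the newly declared SettingsActivity from some activity.
    Intent intent = new Intent(this, SettingsActivity.class);
    startActivity(intent);
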
@ -1,24 +0,0 @@
package com.addd.xingdongli.ui;
import android.os.Bundle;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Button;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.fragment.app.Fragment;
import com.addd.xingdongli.R;
public class HomeFragment extends Fragment {
@Override
public View onCreateView(LayoutInflater inflater, ViewGroup container, Bundle savedInstanceState) {
// Use the fragment_home.xml layout file
return inflater.inflate(R.layout.fragment_home, container, false);
}
}

@ -1,24 +1,15 @@
package com.addd.xingdongli.ui;
import android.content.Intent;
import android.os.Bundle;
import android.util.DisplayMetrics;
import android.view.MenuItem;
import android.widget.LinearLayout;
import android.view.View;
import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.fragment.app.Fragment;
import androidx.fragment.app.FragmentTransaction;
import com.addd.xingdongli.R;
import com.google.android.material.bottomnavigation.BottomNavigationView;
import com.google.android.material.bottomsheet.BottomSheetBehavior;
import com.google.android.material.floatingactionbutton.FloatingActionButton;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
@ -27,44 +18,39 @@ public class MainActivity extends AppCompatActivity {
// Initialize the BottomNavigationView
BottomNavigationView bottomNavigation = findViewById(R.id.bottom_navigation);
// Set the menu item click listener
bottomNavigation.setOnNavigationItemSelectedListener(new BottomNavigationView.OnNavigationItemSelectedListener() {
@Override
public boolean onNavigationItemSelected(@NonNull MenuItem item) {
int itemId = item.getItemId();
if (itemId == R.id.navigation_home) {
loadFragment(new HomeFragment());
// Already on the home page; no need to switch
return true;
} else if (itemId == R.id.navigation_plan) {
loadFragment(new PlanFragment());
} else if (itemId == R.id.navigation_progress) {
Intent intent = new Intent(MainActivity.this, PlanActivity.class);
startActivity(intent);
overridePendingTransition(R.anim.slide_in_right, R.anim.slide_out_left);
finish();
return true;
} else if (itemId == R.id.navigation_settings) {
loadFragment(new SettingsFragment());
Intent intent = new Intent(MainActivity.this, SettingsActivity.class);
startActivity(intent);
overridePendingTransition(R.anim.slide_in_right, R.anim.slide_out_left);
finish();
return true;
}
return false;
}
});
// Load the home page by default
if (savedInstanceState == null) {
loadFragment(new HomeFragment());
}
// Set the default selected item to home
bottomNavigation.setSelectedItemId(R.id.navigation_home);
}
/**
 * Replaces the fragment currently shown in the container.
 * @param fragment the Fragment to display
 */
private void loadFragment(Fragment fragment) {
FragmentTransaction transaction = getSupportFragmentManager().beginTransaction();
transaction.replace(R.id.fragment_container, fragment);
transaction.commit();
}
@Override
public void onBackPressed() {
super.onBackPressed();
overridePendingTransition(R.anim.slide_in_left, R.anim.slide_out_right);
}
}

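For reference, setOnNavigationItemSelectedListener and BottomNavigationView.OnNavigationItemSelectedListener used above are deprecated in newer Material Components releases in favor of NavigationBarView.setOnItemSelectedListener. A minimal sketch of the same wiring with the newer API; class and resource names are taken from the diff above, everything else is illustrative and not part of this commit:

    // Hedged sketch: same navigation logic with the non-deprecated listener API and a lambda.
    bottomNavigation.setOnItemSelectedListener(item -> {
        int itemId = item.getItemId();
        if (itemId == R.id.navigation_home) {
            loadFragment(new HomeFragment()); // assumes a home fragment is still available
            return true;
        } else if (itemId == R.id.navigation_progress) {
            startActivity(new Intent(MainActivity.this, PlanActivity.class));
            overridePendingTransition(R.anim.slide_in_right, R.anim.slide_out_left);
            finish();
            return true;
        } else if (itemId == R.id.navigation_settings) {
            startActivity(new Intent(MainActivity.this, SettingsActivity.class));
            overridePendingTransition(R.anim.slide_in_right, R.anim.slide_out_left);
            finish();
            return true;
        }
        return false;
    });
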
Some files were not shown because too many files have changed in this diff.
