@@ -0,0 +1,2 @@
#Mon Apr 21 08:14:16 CST 2025
gradle.version=8.5
@@ -0,0 +1,2 @@
#Mon Apr 21 08:12:48 CST 2025
java.home=/Applications/Android Studio.app/Contents/jbr/Contents/Home
@@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml
@@ -0,0 +1 @@
TFLite Pose Estimation
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="AndroidProjectSystem">
    <option name="providerId" value="com.android.tools.idea.GradleProjectSystem" />
  </component>
</project>
@@ -0,0 +1,607 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="DeviceStreaming">
    <option name="deviceSelectionList">
      <list>
        <PersistentDeviceSelectionData>
          <option name="api" value="27" />
          <option name="brand" value="DOCOMO" />
          <option name="codename" value="F01L" />
          <option name="id" value="F01L" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="FUJITSU" />
          <option name="name" value="F-01L" />
          <option name="screenDensity" value="360" />
          <option name="screenX" value="720" />
          <option name="screenY" value="1280" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="OnePlus" />
          <option name="codename" value="OP5552L1" />
          <option name="id" value="OP5552L1" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="OnePlus" />
          <option name="name" value="CPH2415" />
          <option name="screenDensity" value="480" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2412" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="OPPO" />
          <option name="codename" value="OP573DL1" />
          <option name="id" value="OP573DL1" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="OPPO" />
          <option name="name" value="CPH2557" />
          <option name="screenDensity" value="480" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="28" />
          <option name="brand" value="DOCOMO" />
          <option name="codename" value="SH-01L" />
          <option name="id" value="SH-01L" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="SHARP" />
          <option name="name" value="AQUOS sense2 SH-01L" />
          <option name="screenDensity" value="480" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2160" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="Lenovo" />
          <option name="codename" value="TB370FU" />
          <option name="formFactor" value="Tablet" />
          <option name="id" value="TB370FU" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Lenovo" />
          <option name="name" value="Tab P12" />
          <option name="screenDensity" value="340" />
          <option name="screenX" value="1840" />
          <option name="screenY" value="2944" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="a15" />
          <option name="id" value="a15" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="A15" />
          <option name="screenDensity" value="450" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2340" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="a35x" />
          <option name="id" value="a35x" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="A35" />
          <option name="screenDensity" value="450" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2340" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="31" />
          <option name="brand" value="samsung" />
          <option name="codename" value="a51" />
          <option name="id" value="a51" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy A51" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="google" />
          <option name="codename" value="akita" />
          <option name="id" value="akita" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 8a" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="motorola" />
          <option name="codename" value="arcfox" />
          <option name="id" value="arcfox" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Motorola" />
          <option name="name" value="razr plus 2024" />
          <option name="screenDensity" value="360" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="1272" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="motorola" />
          <option name="codename" value="austin" />
          <option name="id" value="austin" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Motorola" />
          <option name="name" value="moto g 5G (2022)" />
          <option name="screenDensity" value="280" />
          <option name="screenX" value="720" />
          <option name="screenY" value="1600" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="samsung" />
          <option name="codename" value="b0q" />
          <option name="id" value="b0q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy S22 Ultra" />
          <option name="screenDensity" value="600" />
          <option name="screenX" value="1440" />
          <option name="screenY" value="3088" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="32" />
          <option name="brand" value="google" />
          <option name="codename" value="bluejay" />
          <option name="id" value="bluejay" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 6a" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="google" />
          <option name="codename" value="caiman" />
          <option name="id" value="caiman" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 9 Pro" />
          <option name="screenDensity" value="360" />
          <option name="screenX" value="960" />
          <option name="screenY" value="2142" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="google" />
          <option name="codename" value="comet" />
          <option name="default" value="true" />
          <option name="id" value="comet" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 9 Pro Fold" />
          <option name="screenDensity" value="390" />
          <option name="screenX" value="2076" />
          <option name="screenY" value="2152" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="29" />
          <option name="brand" value="samsung" />
          <option name="codename" value="crownqlteue" />
          <option name="id" value="crownqlteue" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy Note9" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="2220" />
          <option name="screenY" value="1080" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="dm2q" />
          <option name="id" value="dm2q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="S23 Plus" />
          <option name="screenDensity" value="450" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2340" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="dm3q" />
          <option name="id" value="dm3q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy S23 Ultra" />
          <option name="screenDensity" value="600" />
          <option name="screenX" value="1440" />
          <option name="screenY" value="3088" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="e1q" />
          <option name="default" value="true" />
          <option name="id" value="e1q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy S24" />
          <option name="screenDensity" value="480" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2340" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="e3q" />
          <option name="id" value="e3q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy S24 Ultra" />
          <option name="screenDensity" value="450" />
          <option name="screenX" value="1440" />
          <option name="screenY" value="3120" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="google" />
          <option name="codename" value="eos" />
          <option name="id" value="eos" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Eos" />
          <option name="screenDensity" value="320" />
          <option name="screenX" value="384" />
          <option name="screenY" value="384" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="google" />
          <option name="codename" value="felix" />
          <option name="id" value="felix" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel Fold" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="2208" />
          <option name="screenY" value="1840" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="google" />
          <option name="codename" value="felix" />
          <option name="id" value="felix" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel Fold" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="2208" />
          <option name="screenY" value="1840" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="google" />
          <option name="codename" value="felix_camera" />
          <option name="id" value="felix_camera" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel Fold (Camera-enabled)" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="2208" />
          <option name="screenY" value="1840" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="motorola" />
          <option name="codename" value="fogona" />
          <option name="id" value="fogona" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Motorola" />
          <option name="name" value="moto g play - 2024" />
          <option name="screenDensity" value="280" />
          <option name="screenX" value="720" />
          <option name="screenY" value="1600" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="g0q" />
          <option name="id" value="g0q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="SM-S906U1" />
          <option name="screenDensity" value="450" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2340" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="gta9pwifi" />
          <option name="id" value="gta9pwifi" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="SM-X210" />
          <option name="screenDensity" value="240" />
          <option name="screenX" value="1200" />
          <option name="screenY" value="1920" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="gts7xllite" />
          <option name="id" value="gts7xllite" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="SM-T738U" />
          <option name="screenDensity" value="340" />
          <option name="screenX" value="1600" />
          <option name="screenY" value="2560" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="samsung" />
          <option name="codename" value="gts8uwifi" />
          <option name="formFactor" value="Tablet" />
          <option name="id" value="gts8uwifi" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy Tab S8 Ultra" />
          <option name="screenDensity" value="320" />
          <option name="screenX" value="1848" />
          <option name="screenY" value="2960" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="gts8wifi" />
          <option name="formFactor" value="Tablet" />
          <option name="id" value="gts8wifi" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy Tab S8" />
          <option name="screenDensity" value="274" />
          <option name="screenX" value="1600" />
          <option name="screenY" value="2560" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="gts9fe" />
          <option name="id" value="gts9fe" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy Tab S9 FE 5G" />
          <option name="screenDensity" value="280" />
          <option name="screenX" value="1440" />
          <option name="screenY" value="2304" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="google" />
          <option name="codename" value="husky" />
          <option name="id" value="husky" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 8 Pro" />
          <option name="screenDensity" value="390" />
          <option name="screenX" value="1008" />
          <option name="screenY" value="2244" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="30" />
          <option name="brand" value="motorola" />
          <option name="codename" value="java" />
          <option name="id" value="java" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Motorola" />
          <option name="name" value="G20" />
          <option name="screenDensity" value="280" />
          <option name="screenX" value="720" />
          <option name="screenY" value="1600" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="google" />
          <option name="codename" value="komodo" />
          <option name="id" value="komodo" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 9 Pro XL" />
          <option name="screenDensity" value="360" />
          <option name="screenX" value="1008" />
          <option name="screenY" value="2244" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="google" />
          <option name="codename" value="lynx" />
          <option name="id" value="lynx" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 7a" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="motorola" />
          <option name="codename" value="maui" />
          <option name="id" value="maui" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Motorola" />
          <option name="name" value="moto g play - 2023" />
          <option name="screenDensity" value="280" />
          <option name="screenX" value="720" />
          <option name="screenY" value="1600" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="o1q" />
          <option name="id" value="o1q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy S21" />
          <option name="screenDensity" value="421" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="31" />
          <option name="brand" value="google" />
          <option name="codename" value="oriole" />
          <option name="id" value="oriole" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 6" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="google" />
          <option name="codename" value="panther" />
          <option name="id" value="panther" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 7" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="q5q" />
          <option name="id" value="q5q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy Z Fold5" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1812" />
          <option name="screenY" value="2176" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="q6q" />
          <option name="id" value="q6q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy Z Fold6" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1856" />
          <option name="screenY" value="2160" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="30" />
          <option name="brand" value="google" />
          <option name="codename" value="r11" />
          <option name="formFactor" value="Wear OS" />
          <option name="id" value="r11" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel Watch" />
          <option name="screenDensity" value="320" />
          <option name="screenX" value="384" />
          <option name="screenY" value="384" />
          <option name="type" value="WEAR_OS" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="r11q" />
          <option name="id" value="r11q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="SM-S711U" />
          <option name="screenDensity" value="450" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2340" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="30" />
          <option name="brand" value="google" />
          <option name="codename" value="redfin" />
          <option name="id" value="redfin" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 5" />
          <option name="screenDensity" value="440" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2340" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="google" />
          <option name="codename" value="shiba" />
          <option name="id" value="shiba" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 8" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="samsung" />
          <option name="codename" value="t2q" />
          <option name="id" value="t2q" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Samsung" />
          <option name="name" value="Galaxy S21 Plus" />
          <option name="screenDensity" value="394" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2400" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="33" />
          <option name="brand" value="google" />
          <option name="codename" value="tangorpro" />
          <option name="formFactor" value="Tablet" />
          <option name="id" value="tangorpro" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel Tablet" />
          <option name="screenDensity" value="320" />
          <option name="screenX" value="1600" />
          <option name="screenY" value="2560" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="34" />
          <option name="brand" value="google" />
          <option name="codename" value="tokay" />
          <option name="default" value="true" />
          <option name="id" value="tokay" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 9" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2424" />
        </PersistentDeviceSelectionData>
        <PersistentDeviceSelectionData>
          <option name="api" value="35" />
          <option name="brand" value="google" />
          <option name="codename" value="tokay" />
          <option name="default" value="true" />
          <option name="id" value="tokay" />
          <option name="labId" value="google" />
          <option name="manufacturer" value="Google" />
          <option name="name" value="Pixel 9" />
          <option name="screenDensity" value="420" />
          <option name="screenX" value="1080" />
          <option name="screenY" value="2424" />
        </PersistentDeviceSelectionData>
      </list>
    </option>
  </component>
</project>
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="CompilerConfiguration">
    <bytecodeTargetLevel target="21" />
  </component>
</project>
@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="deploymentTargetSelector">
    <selectionStates>
      <SelectionState runConfigName="app">
        <option name="selectionMode" value="DROPDOWN" />
      </SelectionState>
    </selectionStates>
  </component>
</project>
@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="GradleMigrationSettings" migrationVersion="1" />
  <component name="GradleSettings">
    <option name="linkedExternalProjectsSettings">
      <GradleProjectSettings>
        <option name="testRunner" value="CHOOSE_PER_TEST" />
        <option name="externalProjectPath" value="$PROJECT_DIR$" />
        <option name="gradleJvm" value="#GRADLE_LOCAL_JAVA_HOME" />
        <option name="modules">
          <set>
            <option value="$PROJECT_DIR$" />
            <option value="$PROJECT_DIR$/app" />
          </set>
        </option>
      </GradleProjectSettings>
    </option>
  </component>
</project>
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="KotlinJpsPluginSettings">
    <option name="version" value="1.9.21" />
  </component>
</project>
@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectMigrations">
    <option name="MigrateToGradleLocalJavaHome">
      <set>
        <option value="$PROJECT_DIR$" />
      </set>
    </option>
  </component>
</project>
@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ExternalStorageConfigurationManager" enabled="true" />
  <component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="jbr-21" project-jdk-type="JavaSDK">
    <output url="file://$PROJECT_DIR$/build/classes" />
  </component>
  <component name="ProjectType">
    <option name="id" value="Android" />
  </component>
</project>
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="RunConfigurationProducerService">
    <option name="ignoredProducers">
      <set>
        <option value="com.intellij.execution.junit.AbstractAllInDirectoryConfigurationProducer" />
        <option value="com.intellij.execution.junit.AllInPackageConfigurationProducer" />
        <option value="com.intellij.execution.junit.PatternConfigurationProducer" />
        <option value="com.intellij.execution.junit.TestInClassConfigurationProducer" />
        <option value="com.intellij.execution.junit.UniqueIdConfigurationProducer" />
        <option value="com.intellij.execution.junit.testDiscovery.JUnitTestDiscoveryConfigurationProducer" />
        <option value="org.jetbrains.kotlin.idea.junit.KotlinJUnitRunConfigurationProducer" />
        <option value="org.jetbrains.kotlin.idea.junit.KotlinPatternConfigurationProducer" />
      </set>
    </option>
  </component>
</project>
@@ -0,0 +1,73 @@
# TensorFlow Lite Pose Estimation Android Demo

### Overview
This is an app that continuously detects the body parts in the frames seen by
your device's camera. These instructions walk you through building and running
the demo on an Android device. Camera captures are discarded immediately after
use; nothing is stored or saved.

The app demonstrates how to use four models:

* Single-pose models: the model can estimate the pose of only one person in the
  input image. If the input image contains multiple persons, the detection
  result can be largely incorrect.
  * PoseNet
  * MoveNet Lightning
  * MoveNet Thunder
* Multi-pose models: the model can estimate the poses of multiple persons in the
  input image.
  * MoveNet MultiPose: supports up to 6 persons.

A short code sketch showing how each detector is created appears after the demo
animation below.

See this [blog post](https://blog.tensorflow.org/2021/05/next-generation-pose-detection-with-movenet-and-tensorflowjs.html)
for a comparison between these models.

![Demo image](screenshot.gif)
## Build the demo using Android Studio

### Prerequisites

* If you don't have it already, install **[Android Studio](
  https://developer.android.com/studio/index.html)** 4.2 or
  above, following the instructions on the website.

* An Android device and an Android development environment with minimum API
  level 23 (the `minSdkVersion` declared in `app/build.gradle`).

### Building
* Open Android Studio, and from the `Welcome` screen, select
  `Open an existing Android Studio project`.

* From the `Open File or Project` window that appears, navigate to and select
  the `lite/examples/pose_estimation/android` directory from wherever you
  cloned the `tensorflow/examples` GitHub repo. Click `OK`.

* If it asks you to do a `Gradle Sync`, click `OK`.

* You may also need to install various platforms and tools if you get errors
  like `Failed to find target with hash string 'android-21'` and similar.

* Click the `Run` button (the green arrow) or select `Run` > `Run 'android'`
  from the top menu. You may need to rebuild the project using `Build` >
  `Rebuild Project`.

* If it asks you to use `Instant Run`, click `Proceed Without Instant Run`.

* Also, you need to have an Android device plugged in with developer options
  enabled at this point. See **[here](
  https://developer.android.com/studio/run/device)** for more details
  on setting up developer devices.


### Model used
Downloading, extracting, and placing the models into the assets folder is
handled automatically by `download.gradle` (included later in this diff).

If you explicitly want to download the models, you can download them from here:

* [PoseNet](https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite)
* [MoveNet Lightning](https://kaggle.com/models/google/movenet/frameworks/tfLite/variations/singlepose-lightning)
* [MoveNet Thunder](https://www.kaggle.com/models/google/movenet/frameworks/tfLite/variations/singlepose-thunder)
* [MoveNet MultiPose](https://www.kaggle.com/models/google/movenet/frameworks/tfLite/variations/multipose-lightning-tflite-float16)

### Additional Note
_Please do not delete the assets folder content_. If you explicitly deleted the
files, choose `Build` > `Rebuild Project` from the menu to re-download the
deleted model files into the assets folder.
@@ -0,0 +1 @@
/build
@@ -0,0 +1,56 @@
plugins {
    id 'com.android.application'
    id 'kotlin-android'
}

android {
    compileSdkVersion 30
    buildToolsVersion "30.0.3"

    defaultConfig {
        applicationId "org.tensorflow.lite.examples.poseestimation"
        minSdkVersion 23
        targetSdkVersion 30
        versionCode 1
        versionName "1.0"

        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
    }

    namespace "org.tensorflow.lite.examples.poseestimation"

    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
    kotlinOptions {
        jvmTarget = '1.8'
    }
}

// Download tflite model
apply from: "download.gradle"

dependencies {

    implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
    implementation 'androidx.core:core-ktx:1.5.0'
    implementation 'androidx.appcompat:appcompat:1.3.0'
    implementation 'com.google.android.material:material:1.3.0'
    implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
    implementation "androidx.activity:activity-ktx:1.2.3"
    implementation 'androidx.fragment:fragment-ktx:1.3.5'
    implementation 'org.tensorflow:tensorflow-lite:2.14.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.5.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.3.0'

    androidTestImplementation 'androidx.test.ext:junit:1.1.2'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
    androidTestImplementation "com.google.truth:truth:1.1.3"
}
@@ -0,0 +1,67 @@
task downloadPosenetModel(type: DownloadUrlTask) {
    def modelPosenetDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite"
    doFirst {
        println "Downloading ${modelPosenetDownloadUrl}"
    }
    sourceUrl = "${modelPosenetDownloadUrl}"
    target = file("src/main/assets/posenet.tflite")
}

task downloadMovenetLightningModel(type: DownloadUrlTask) {
    def modelMovenetLightningDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/float16/4?lite-format=tflite"
    doFirst {
        println "Downloading ${modelMovenetLightningDownloadUrl}"
    }
    sourceUrl = "${modelMovenetLightningDownloadUrl}"
    target = file("src/main/assets/movenet_lightning.tflite")
}

task downloadMovenetThunderModel(type: DownloadUrlTask) {
    def modelMovenetThunderDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/singlepose/thunder/tflite/float16/4?lite-format=tflite"
    doFirst {
        println "Downloading ${modelMovenetThunderDownloadUrl}"
    }
    sourceUrl = "${modelMovenetThunderDownloadUrl}"
    target = file("src/main/assets/movenet_thunder.tflite")
}

task downloadMovenetMultiPoseModel(type: DownloadUrlTask) {
    def modelMovenetMultiPoseDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/multipose/lightning/tflite/float16/1?lite-format=tflite"
    doFirst {
        println "Downloading ${modelMovenetMultiPoseDownloadUrl}"
    }
    sourceUrl = "${modelMovenetMultiPoseDownloadUrl}"
    target = file("src/main/assets/movenet_multipose_fp16.tflite")
}

task downloadPoseClassifierModel(type: DownloadUrlTask) {
    def modelPoseClassifierDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/pose_classifier/yoga_classifier.tflite"
    doFirst {
        println "Downloading ${modelPoseClassifierDownloadUrl}"
    }
    sourceUrl = "${modelPoseClassifierDownloadUrl}"
    target = file("src/main/assets/classifier.tflite")
}

// Aggregate task: fetches every model the app expects in src/main/assets.
task downloadModel {
    dependsOn downloadPosenetModel
    dependsOn downloadMovenetLightningModel
    dependsOn downloadMovenetThunderModel
    dependsOn downloadPoseClassifierModel
    dependsOn downloadMovenetMultiPoseModel
}

// Simple download task backed by Ant's <get>. The @Input/@OutputFile
// annotations let Gradle skip the task when the target file is up to date.
class DownloadUrlTask extends DefaultTask {
    @Input
    String sourceUrl

    @OutputFile
    File target

    @TaskAction
    void download() {
        ant.get(src: sourceUrl, dest: target)
    }
}

preBuild.dependsOn downloadModel
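
// Usage note (assuming the project's standard Gradle wrapper setup): the models
// can also be fetched by hand with `./gradlew :app:downloadModel`. Because of
// the preBuild.dependsOn wiring above, a normal build or
// `Build > Rebuild Project` re-downloads any model deleted from
// src/main/assets — which is what the README's "Additional Note" relies on.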
@@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
#   http://developer.android.com/guide/developing/tools/proguard.html

# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
#   public *;
#}

# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable

# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
(Three binary image files added: 175 KiB, 23 KiB, 90 KiB.)
@@ -0,0 +1,3 @@
Image1: https://pixabay.com/illustrations/woman-stand-wait-person-shoes-1427073/
Image2: https://pixabay.com/photos/businessman-suit-germany-black-1146791/
Image3: https://pixabay.com/photos/tree-pose-yoga-yogini-lifestyle-4823155/
@@ -0,0 +1,113 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Canvas
import android.graphics.PointF
import androidx.test.platform.app.InstrumentationRegistry
import com.google.common.truth.Truth.assertThat
import com.google.common.truth.Truth.assertWithMessage
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Person
import java.io.BufferedReader
import java.io.InputStreamReader
import kotlin.math.pow

object EvaluationUtils {

    /**
     * Assert whether the person detected in the image matches the expected result.
     * A detection is accepted as correct if every keypoint lies within
     * [acceptableError] pixels of its expected position. For example, a keypoint
     * detected at (103, 204) against an expected (100, 200) is 5 px away, so it
     * passes any acceptableError of 5 or more.
     */
    fun assertPoseDetectionResult(
        person: Person,
        expectedResult: Map<BodyPart, PointF>,
        acceptableError: Float
    ) {
        // Check if the model is confident enough in detecting the person
        assertThat(person.score).isGreaterThan(0.5f)

        for ((bodyPart, expectedPointF) in expectedResult) {
            val keypoint = person.keyPoints.firstOrNull { it.bodyPart == bodyPart }
            assertWithMessage("$bodyPart must exist").that(keypoint).isNotNull()

            val detectedPointF = keypoint!!.coordinate
            val distanceFromExpectedPointF = distance(detectedPointF, expectedPointF)
            assertWithMessage("Detected $bodyPart must be close to expected result")
                .that(distanceFromExpectedPointF).isAtMost(acceptableError)
        }
    }

    /**
     * Load an image from the assets folder using its name.
     */
    fun loadBitmapAssetByName(name: String): Bitmap {
        val testContext = InstrumentationRegistry.getInstrumentation().context
        val testInput = testContext.assets.open(name)
        return BitmapFactory.decodeStream(testInput)
    }

    /**
     * Load the ground-truth CSV from the assets folder. After a header line,
     * each row holds one x,y pair per body part, flattened in [BodyPart] order.
     */
    fun loadCSVAsset(name: String): List<Map<BodyPart, PointF>> {
        val data = mutableListOf<Map<BodyPart, PointF>>()
        val testContext = InstrumentationRegistry.getInstrumentation().context
        val testInput = testContext.assets.open(name)
        val inputStreamReader = InputStreamReader(testInput)
        val reader = BufferedReader(inputStreamReader)
        // Skip the header line
        reader.readLine()

        // Read the expected coordinates from each following line
        reader.forEachLine {
            val listPoint = it.split(",")
            val map = mutableMapOf<BodyPart, PointF>()
            for (i in listPoint.indices step 2) {
                map[BodyPart.fromInt(i / 2)] =
                    PointF(listPoint[i].toFloat(), listPoint[i + 1].toFloat())
            }
            data.add(map)
        }
        return data
    }

    /**
     * Calculate the Euclidean distance between two points.
     */
    private fun distance(point1: PointF, point2: PointF): Float {
        return ((point1.x - point2.x).pow(2) + (point1.y - point2.y).pow(2)).pow(0.5f)
    }

    /**
     * Concatenate two images of the same height horizontally.
     */
    fun hConcat(image1: Bitmap, image2: Bitmap): Bitmap {
        if (image1.height != image2.height) {
            throw Exception("Input images do not have the same height.")
        }
        val finalBitmap =
            Bitmap.createBitmap(image1.width + image2.width, image1.height, Bitmap.Config.ARGB_8888)
        val canvas = Canvas(finalBitmap)
        canvas.drawBitmap(image1, 0f, 0f, null)
        canvas.drawBitmap(image2, image1.width.toFloat(), 0f, null)
        return finalBitmap
    }
}
@@ -0,0 +1,83 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device

@RunWith(AndroidJUnit4::class)
class MovenetLightningTest {

    companion object {
        private const val TEST_INPUT_IMAGE1 = "image1.png"
        private const val TEST_INPUT_IMAGE2 = "image2.jpg"
        private const val ACCEPTABLE_ERROR = 21f
    }

    private lateinit var poseDetector: PoseDetector
    private lateinit var appContext: Context
    private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>

    @Before
    fun setup() {
        appContext = InstrumentationRegistry.getInstrumentation().targetContext
        poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning)
        expectedDetectionResult =
            EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")
    }

    @Test
    fun testPoseEstimationResultWithImage1() {
        val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)

        // MoveNet uses the previous frame to optimize its detection result, so we
        // run it multiple times on the same image to improve the result.
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        val person = poseDetector.estimatePoses(input)[0]
        EvaluationUtils.assertPoseDetectionResult(
            person,
            expectedDetectionResult[0],
            ACCEPTABLE_ERROR
        )
    }

    @Test
    fun testPoseEstimationResultWithImage2() {
        val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)

        // MoveNet uses the previous frame to optimize its detection result, so we
        // run it multiple times on the same image to improve the result.
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        val person = poseDetector.estimatePoses(input)[0]
        EvaluationUtils.assertPoseDetectionResult(
            person,
            expectedDetectionResult[1],
            ACCEPTABLE_ERROR
        )
    }
}
@@ -0,0 +1,81 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import android.graphics.Bitmap
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device

@RunWith(AndroidJUnit4::class)
class MovenetMultiPoseTest {
    companion object {
        private const val TEST_INPUT_IMAGE1 = "image1.png"
        private const val TEST_INPUT_IMAGE2 = "image2.jpg"
        private const val ACCEPTABLE_ERROR = 17f
    }

    private lateinit var poseDetector: MoveNetMultiPose
    private lateinit var appContext: Context
    private lateinit var inputFinal: Bitmap
    private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>

    @Before
    fun setup() {
        appContext = InstrumentationRegistry.getInstrumentation().targetContext
        poseDetector = MoveNetMultiPose.create(appContext, Device.CPU, Type.Dynamic)
        val input1 = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)
        val input2 = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)
        inputFinal = EvaluationUtils.hConcat(input1, input2)
        expectedDetectionResult =
            EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")

        // Shift the expected x coordinates of the second image so that they
        // match its position in the horizontally concatenated input image.
        for ((_, value) in expectedDetectionResult[1]) {
            value.x = value.x + input1.width
        }
    }

    @Test
    fun testPoseEstimateResult() {
        val persons = poseDetector.estimatePoses(inputFinal)
        assert(persons.size == 2)

        // Sort the results so that the person on the left side comes first.
        val sortedPersons = persons.sortedBy { it.boundingBox?.left }

        EvaluationUtils.assertPoseDetectionResult(
            sortedPersons[0],
            expectedDetectionResult[0],
            ACCEPTABLE_ERROR
        )

        EvaluationUtils.assertPoseDetectionResult(
            sortedPersons[1],
            expectedDetectionResult[1],
            ACCEPTABLE_ERROR
        )
    }
}
@@ -0,0 +1,83 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device

@RunWith(AndroidJUnit4::class)
class MovenetThunderTest {

    companion object {
        private const val TEST_INPUT_IMAGE1 = "image1.png"
        private const val TEST_INPUT_IMAGE2 = "image2.jpg"
        private const val ACCEPTABLE_ERROR = 15f
    }

    private lateinit var poseDetector: PoseDetector
    private lateinit var appContext: Context
    private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>

    @Before
    fun setup() {
        appContext = InstrumentationRegistry.getInstrumentation().targetContext
        poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Thunder)
        expectedDetectionResult =
            EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")
    }

    @Test
    fun testPoseEstimationResultWithImage1() {
        val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)

        // MoveNet uses the previous frame to optimize its detection result, so we
        // run it multiple times on the same image to improve the result.
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        val person = poseDetector.estimatePoses(input)[0]
        EvaluationUtils.assertPoseDetectionResult(
            person,
            expectedDetectionResult[0],
            ACCEPTABLE_ERROR
        )
    }

    @Test
    fun testPoseEstimationResultWithImage2() {
        val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)

        // MoveNet uses the previous frame to optimize its detection result, so we
        // run it multiple times on the same image to improve the result.
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        val person = poseDetector.estimatePoses(input)[0]
        EvaluationUtils.assertPoseDetectionResult(
            person,
            expectedDetectionResult[1],
            ACCEPTABLE_ERROR
        )
    }
}
@ -0,0 +1,64 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import junit.framework.TestCase.assertEquals
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.Device

@RunWith(AndroidJUnit4::class)
class PoseClassifierTest {

    companion object {
        private const val TEST_INPUT_IMAGE = "image3.jpeg"
    }

    private lateinit var appContext: Context
    private lateinit var poseDetector: PoseDetector
    private lateinit var poseClassifier: PoseClassifier

    @Before
    fun setup() {
        appContext = InstrumentationRegistry.getInstrumentation().targetContext
        poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning)
        poseClassifier = PoseClassifier.create(appContext)
    }

    @Test
    fun testPoseClassifier() {
        val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE)
        // As MoveNet uses the previous frame to optimize its detection result, we run it
        // multiple times on the same image so the result stabilizes before classifying.
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        poseDetector.estimatePoses(input)
        val person = poseDetector.estimatePoses(input)[0]
        val classificationResult = poseClassifier.classify(person)
        val predictedPose = classificationResult.maxByOrNull { it.second }?.first ?: "n/a"
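        // classify() yields one (label, score) pair per line of the labels file
        // (chair, cobra, dog, tree, warrior), and maxByOrNull picks the top-scoring
        // label. A sketch of taking the top three instead, mirroring what
        // MainActivity displays:
        //
        //     val top3 = classificationResult.sortedByDescending { it.second }.take(3)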
        assertEquals(
            "Predicted pose is different from ground truth.",
            "tree",
            predictedPose
        )
    }
}
@ -0,0 +1,71 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device

@RunWith(AndroidJUnit4::class)
class PosenetTest {

    companion object {
        private const val TEST_INPUT_IMAGE1 = "image1.png"
        private const val TEST_INPUT_IMAGE2 = "image2.jpg"
        private const val ACCEPTABLE_ERROR = 37f
    }

    private lateinit var poseDetector: PoseDetector
    private lateinit var appContext: Context
    private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>

    @Before
    fun setup() {
        appContext = InstrumentationRegistry.getInstrumentation().targetContext
        poseDetector = PoseNet.create(appContext, Device.CPU)
        expectedDetectionResult =
            EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")
    }
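    // Unlike the MoveNet tests, each test below asserts on a single inference pass:
    // PoseNet presumably keeps no cross-frame cropping state, so no warm-up runs are
    // needed. The looser ACCEPTABLE_ERROR (37 vs. 15 for Thunder) reflects the older
    // model's lower landmark accuracy.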
    @Test
    fun testPoseEstimationResultWithImage1() {
        val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)
        val person = poseDetector.estimatePoses(input)[0]
        EvaluationUtils.assertPoseDetectionResult(
            person,
            expectedDetectionResult[0],
            ACCEPTABLE_ERROR
        )
    }

    @Test
    fun testPoseEstimationResultWithImage2() {
        val input = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)
        val person = poseDetector.estimatePoses(input)[0]
        EvaluationUtils.assertPoseDetectionResult(
            person,
            expectedDetectionResult[1],
            ACCEPTABLE_ERROR
        )
    }
}
@ -0,0 +1,82 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import android.graphics.Bitmap
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import com.google.common.truth.Truth.assertThat
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.VisualizationUtils
import org.tensorflow.lite.examples.poseestimation.data.Device

/**
 * This test is used to visually verify detection results by the models.
 * You can put a breakpoint at the end of a test method, debug it, then use the
 * "View Bitmap" feature of the debugger to check the visualized detection result.
 */
@RunWith(AndroidJUnit4::class)
class VisualizationTest {

    companion object {
        private const val TEST_INPUT_IMAGE = "image2.jpg"
    }

    private lateinit var appContext: Context
    private lateinit var inputBitmap: Bitmap

    @Before
    fun setup() {
        appContext = InstrumentationRegistry.getInstrumentation().targetContext
        inputBitmap = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE)
    }
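    // An alternative to the debugger's "View Bitmap": write the output to the app's
    // private storage for manual inspection (a sketch; this helper is hypothetical
    // and unused by the tests below):
    //
    //     private fun saveForInspection(bitmap: Bitmap, name: String) {
    //         appContext.openFileOutput(name, Context.MODE_PRIVATE).use { out ->
    //             bitmap.compress(Bitmap.CompressFormat.PNG, 100, out)
    //         }
    //     }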
    @Test
    fun testPosenet() {
        val poseDetector = PoseNet.create(appContext, Device.CPU)
        val person = poseDetector.estimatePoses(inputBitmap)[0]
        val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person))
        assertThat(outputBitmap).isNotNull()
    }

    @Test
    fun testMovenetLightning() {
        // Due to MoveNet's cropping logic, we run inference several times with the
        // same input image so the detection stabilizes before drawing.
        val poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Lightning)
        poseDetector.estimatePoses(inputBitmap)
        poseDetector.estimatePoses(inputBitmap)
        val person2 = poseDetector.estimatePoses(inputBitmap)[0]
        val outputBitmap2 = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person2))
        assertThat(outputBitmap2).isNotNull()
    }

    @Test
    fun testMovenetThunder() {
        // Due to MoveNet's cropping logic, we run inference several times with the
        // same input image so the detection stabilizes before drawing.
        val poseDetector = MoveNet.create(appContext, Device.CPU, ModelType.Thunder)
        poseDetector.estimatePoses(inputBitmap)
        poseDetector.estimatePoses(inputBitmap)
        val person = poseDetector.estimatePoses(inputBitmap)[0]
        val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person))
        assertThat(outputBitmap).isNotNull()
    }
}
@ -0,0 +1,228 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.tracker

import android.graphics.RectF
import androidx.test.ext.junit.runners.AndroidJUnit4
import junit.framework.TestCase.assertEquals
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.Person

@RunWith(AndroidJUnit4::class)
class BoundingBoxTrackerTest {
    companion object {
        private const val MAX_TRACKS = 4
        private const val MAX_AGE = 1000 // Unit: milliseconds.
        private const val MIN_SIMILARITY = 0.5f
    }

    private lateinit var boundingBoxTracker: BoundingBoxTracker

    @Before
    fun setup() {
        val trackerConfig = TrackerConfig(MAX_TRACKS, MAX_AGE, MIN_SIMILARITY)
        boundingBoxTracker = BoundingBoxTracker(trackerConfig)
    }

    @Test
    fun testIoU() {
        val persons = Person(-1, listOf(), RectF(0f, 0f, 2f / 3, 1f), 1f)

        val track = Track(
            Person(-1, listOf(), RectF(1 / 3f, 0.0f, 1f, 1f), 1f),
            1000000
        )
        val computedIoU = boundingBoxTracker.iou(persons, track.person)
        assertEquals("Wrong IoU value.", 1f / 3, computedIoU, 0.000001f)
    }
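    // Worked example for testIoU: the boxes are [0, 2/3] x [0, 1] and [1/3, 1] x [0, 1].
    // Intersection = (2/3 - 1/3) * 1 = 1/3; union = 2/3 + 2/3 - 1/3 = 1; IoU = 1/3.
    // For reference, a straightforward IoU over RectF (a sketch; the production
    // implementation lives in BoundingBoxTracker.iou()):
    //
    //     fun iouOf(a: RectF, b: RectF): Float {
    //         val ix = maxOf(0f, minOf(a.right, b.right) - maxOf(a.left, b.left))
    //         val iy = maxOf(0f, minOf(a.bottom, b.bottom) - maxOf(a.top, b.top))
    //         val intersection = ix * iy
    //         val union = a.width() * a.height() + b.width() * b.height() - intersection
    //         return if (union > 0f) intersection / union else 0f
    //     }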
    @Test
    fun testIoUFullOverlap() {
        val persons = Person(-1, listOf(), RectF(0f, 0f, 1f, 1f), 1f)

        val track = Track(
            Person(-1, listOf(), RectF(0f, 0f, 1f, 1f), 1f),
            1000000
        )
        val computedIoU = boundingBoxTracker.iou(persons, track.person)
        assertEquals("Wrong IoU value.", 1f, computedIoU, 0.000001f)
    }

    @Test
    fun testIoUNoIntersection() {
        val persons = Person(-1, listOf(), RectF(0f, 0f, 0.5f, 0.5f), 1f)

        val track = Track(
            Person(-1, listOf(), RectF(0.5f, 0.5f, 1f, 1f), 1f),
            1000000
        )
        val computedIoU = boundingBoxTracker.iou(persons, track.person)
        assertEquals("Wrong IoU value.", 0f, computedIoU, 0.000001f)
    }

    @Test
    fun testBoundingBoxTracking() {
        // Timestamp: 0. The two poses become the first two tracks.
        var persons = listOf(
            Person( // Becomes track 1.
                -1, listOf(), RectF(0f, 0f, 0.5f, 0.5f), 1f
            ),
            Person( // Becomes track 2.
                -1, listOf(), RectF(0f, 0f, 1f, 1f), 1f
            )
        )
        persons = boundingBoxTracker.apply(persons, 0)
        var track = boundingBoxTracker.tracks
        assertEquals(2, persons.size)
        assertEquals(1, persons[0].id)
        assertEquals(2, persons[1].id)
        assertEquals(2, track.size)
        assertEquals(1, track[0].person.id)
        assertEquals(0, track[0].lastTimestamp)
        assertEquals(2, track[1].person.id)
        assertEquals(0, track[1].lastTimestamp)

        // Timestamp: 100000. First pose is linked with track 1. Second pose spawns
        // a new track (id = 3).
        persons = listOf(
            Person( // Linked with track 1.
                -1, listOf(), RectF(0.1f, 0.1f, 0.5f, 0.5f), 1f
            ),
            Person( // Becomes track 3.
                -1, listOf(), RectF(0.2f, 0.3f, 0.9f, 0.9f), 1f
            )
        )
        persons = boundingBoxTracker.apply(persons, 100000)
        track = boundingBoxTracker.tracks
        assertEquals(2, persons.size)
        assertEquals(1, persons[0].id)
        assertEquals(3, persons[1].id)
        assertEquals(3, track.size)
        assertEquals(1, track[0].person.id)
        assertEquals(100000, track[0].lastTimestamp)
        assertEquals(3, track[1].person.id)
        assertEquals(100000, track[1].lastTimestamp)
        assertEquals(2, track[2].person.id)
        assertEquals(0, track[2].lastTimestamp)

        // Timestamp: 1050000. First pose is linked with track 1. Second pose is
        // identical to track 2, but is not linked because track 2 has been deleted
        // due to age. Instead it spawns track 4.
        persons = listOf(
            Person( // Linked with track 1.
                -1, listOf(), RectF(0.1f, 0.1f, 0.55f, 0.5f), 1f
            ),
            Person( // Becomes track 4.
                -1, listOf(), RectF(0f, 0f, 1f, 1f), 1f
            )
        )
        persons = boundingBoxTracker.apply(persons, 1050000)
        track = boundingBoxTracker.tracks
        assertEquals(2, persons.size)
        assertEquals(1, persons[0].id)
        assertEquals(4, persons[1].id)
        assertEquals(3, track.size)
        assertEquals(1, track[0].person.id)
        assertEquals(1050000, track[0].lastTimestamp)
        assertEquals(4, track[1].person.id)
        assertEquals(1050000, track[1].lastTimestamp)
        assertEquals(3, track[2].person.id)
        assertEquals(100000, track[2].lastTimestamp)
    }
}
@ -0,0 +1,308 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.tracker

import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import junit.framework.TestCase.assertEquals
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.KeyPoint
import org.tensorflow.lite.examples.poseestimation.data.Person
import kotlin.math.exp
import kotlin.math.pow

@RunWith(AndroidJUnit4::class)
class KeyPointsTrackerTest {
    companion object {
        private const val MAX_TRACKS = 4
        private const val MAX_AGE = 1000 // Unit: milliseconds.
        private const val MIN_SIMILARITY = 0.5f
        private const val KEYPOINT_THRESHOLD = 0.2f
        private const val MIN_NUM_KEYPOINT = 2
        private val KEYPOINT_FALLOFF = listOf(0.1f, 0.1f, 0.1f, 0.1f)
    }

    private lateinit var keyPointsTracker: KeyPointsTracker

    @Before
    fun setup() {
        val trackerConfig = TrackerConfig(
            MAX_TRACKS, MAX_AGE, MIN_SIMILARITY,
            KeyPointsTrackerParams(KEYPOINT_THRESHOLD, KEYPOINT_FALLOFF, MIN_NUM_KEYPOINT)
        )
        keyPointsTracker = KeyPointsTracker(trackerConfig)
    }

    @Test
    fun testOks() {
        val persons = Person(
            -1,
            listOf(
                KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
                KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
                KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.1f),
                KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.7f), 0.8f),
            ),
            score = 1f
        )
        val tracks = Track(
            Person(
                0,
                listOf(
                    KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f),
                ),
                score = 1f
            ),
            1000000,
        )

        val oks = keyPointsTracker.oks(persons, tracks.person)
        val boxArea = (0.8f - 0.2f) * (0.8f - 0.2f)
        val x = 2f * KEYPOINT_FALLOFF[3]
        val d = 0.1f
        val expectedOks: Float =
            (1f + 1f + exp(-1f * d.pow(2) / (2f * boxArea * x.pow(2)))) / 3f
        assertEquals(expectedOks, oks, 0.000001f)
    }
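    // Worked example for testOks: OKS (Object Keypoint Similarity) averages, over
    // keypoint pairs that clear KEYPOINT_THRESHOLD in both poses,
    //     exp(-d^2 / (2 * area * k^2))
    // where d is the distance between the two detections of a keypoint, area is the
    // bounding-box area of the track's valid keypoints ((0.8 - 0.2)^2 here), and
    // k = 2 * KEYPOINT_FALLOFF[i] (the COCO convention k = 2 * sigma). The knee pair
    // is dropped (person confidence 0.1 < 0.2), the nose and elbow match exactly
    // (term = 1), and the ankle differs by d = 0.1 in y, giving the expected value.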
    @Test
    fun testOksReturnZero() {
        // oks() returns 0 when fewer than MIN_NUM_KEYPOINT (2) keypoint pairs are
        // valid in both the person and the track.
        val persons = Person(
            -1,
            listOf(
                KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
                KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.1f), // Low confidence.
                KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
                KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f),
            ),
            score = 1f
        )
        val tracks = Track(
            Person(
                0,
                listOf(
                    KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.1f), // Low confidence.
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.0f), // Low confidence.
                ),
                score = 1f
            ),
            1000000
        )

        val oks = keyPointsTracker.oks(persons, tracks.person)
        assertEquals(0f, oks, 0.000001f)
    }

    @Test
    fun testArea() {
        val keyPoints = listOf(
            KeyPoint(BodyPart.NOSE, PointF(0.1f, 0.2f), 1f),
            KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.3f, 0.4f), 0.9f),
            KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.4f, 0.6f), 0.9f),
            KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.7f, 0.8f), 0.1f),
        )
        val area = keyPointsTracker.area(keyPoints)
        val expectedArea = (0.4f - 0.1f) * (0.6f - 0.2f)
        assertEquals(expectedArea, area)
    }
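    // area() above is the axis-aligned bounding-box area over keypoints whose
    // confidence clears KEYPOINT_THRESHOLD: the 0.1-score ankle is excluded,
    // leaving x in [0.1, 0.4] and y in [0.2, 0.6], i.e. 0.3 * 0.4 = 0.12.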
    @Test
    fun testKeyPointsTracker() {
        // Timestamp: 0. Person becomes the only track.
        var persons = listOf(
            Person(
                -1,
                listOf(
                    KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.0f),
                ),
                score = 0.9f
            )
        )
        persons = keyPointsTracker.apply(persons, 0)
        var track = keyPointsTracker.tracks
        assertEquals(1, persons.size)
        assertEquals(1, persons[0].id)
        assertEquals(1, track.size)
        assertEquals(1, track[0].person.id)
        assertEquals(0, track[0].lastTimestamp)

        // Timestamp: 100000. First person is linked with track 1. Second person spawns
        // a new track (id = 2).
        persons = listOf(
            Person(
                -1,
                listOf(
                    // Links with id = 1.
                    KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f),
                ),
                score = 1f
            ),
            Person(
                -1,
                listOf(
                    // Becomes id = 2.
                    KeyPoint(BodyPart.NOSE, PointF(0.8f, 0.8f), 0.8f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.6f, 0.6f), 0.3f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.4f, 0.4f), 0.1f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.2f, 0.2f), 0.8f),
                ),
                score = 1f
            )
        )
        persons = keyPointsTracker.apply(persons, 100000)
        track = keyPointsTracker.tracks
        assertEquals(2, persons.size)
        assertEquals(1, persons[0].id)
        assertEquals(2, persons[1].id)
        assertEquals(2, track.size)
        assertEquals(1, track[0].person.id)
        assertEquals(100000, track[0].lastTimestamp)
        assertEquals(2, track[1].person.id)
        assertEquals(100000, track[1].lastTimestamp)

        // Timestamp: 900000. First person is linked with track 2. Second person spawns
        // a new track (id = 3).
        persons = listOf(
            Person(
                -1,
                listOf(
                    // Links with id = 2.
                    KeyPoint(BodyPart.NOSE, PointF(0.6f, 0.7f), 0.7f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.5f, 0.6f), 0.7f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.0f, 0.0f), 0.1f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.2f, 0.1f), 1f),
                ),
                score = 1f
            ),
            Person(
                -1,
                listOf(
                    // Becomes id = 3.
                    KeyPoint(BodyPart.NOSE, PointF(0.5f, 0.1f), 0.6f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.9f, 0.3f), 0.6f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.1f, 1f), 0.9f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.4f, 0.4f), 0.1f),
                ),
                score = 1f
            )
        )
        persons = keyPointsTracker.apply(persons, 900000)
        track = keyPointsTracker.tracks
        assertEquals(2, persons.size)
        assertEquals(2, persons[0].id)
        assertEquals(3, persons[1].id)
        assertEquals(3, track.size)
        assertEquals(2, track[0].person.id)
        assertEquals(900000, track[0].lastTimestamp)
        assertEquals(3, track[1].person.id)
        assertEquals(900000, track[1].lastTimestamp)
        assertEquals(1, track[2].person.id)
        assertEquals(100000, track[2].lastTimestamp)

        // Timestamp: 1200000. First person spawns a new track (id = 4), even though
        // it has the same keypoints as track 1. This is because the age exceeds
        // 1000 msec. The second person links with id 2. The third person spawns a new
        // track (id = 5).
        persons = listOf(
            Person(
                -1,
                listOf(
                    // Becomes id = 4.
                    KeyPoint(BodyPart.NOSE, PointF(0.2f, 0.2f), 1f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.4f, 0.4f), 0.8f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.6f, 0.6f), 0.9f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.8f), 0.8f),
                ),
                score = 1f
            ),
            Person(
                -1,
                listOf(
                    // Links with id = 2.
                    KeyPoint(BodyPart.NOSE, PointF(0.55f, 0.7f), 0.7f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.5f, 0.6f), 0.9f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(1f, 1f), 0.1f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.1f), 0f),
                ),
                score = 1f
            ),
            Person(
                -1,
                listOf(
                    // Becomes id = 5.
                    KeyPoint(BodyPart.NOSE, PointF(0.1f, 0.1f), 0.1f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.2f, 0.2f), 0.9f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.3f, 0.3f), 0.7f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.4f, 0.4f), 0.8f),
                ),
                score = 1f
            )
        )
        persons = keyPointsTracker.apply(persons, 1200000)
        track = keyPointsTracker.tracks
        assertEquals(3, persons.size)
        assertEquals(4, persons[0].id)
        assertEquals(2, persons[1].id)
        assertEquals(4, track.size)
        assertEquals(2, track[0].person.id)
        assertEquals(1200000, track[0].lastTimestamp)
        assertEquals(4, track[1].person.id)
        assertEquals(1200000, track[1].lastTimestamp)
        assertEquals(5, track[2].person.id)
        assertEquals(1200000, track[2].lastTimestamp)
        assertEquals(3, track[3].person.id)
        assertEquals(900000, track[3].lastTimestamp)

        // Timestamp: 1300000. First person spawns a new track (id = 6). Since
        // maxTracks is 4, the oldest track (id = 3) is removed.
        persons = listOf(
            Person(
                -1,
                listOf(
                    // Becomes id = 6.
                    KeyPoint(BodyPart.NOSE, PointF(0.1f, 0.8f), 1f),
                    KeyPoint(BodyPart.RIGHT_ELBOW, PointF(0.2f, 0.9f), 0.6f),
                    KeyPoint(BodyPart.RIGHT_KNEE, PointF(0.2f, 0.9f), 0.5f),
                    KeyPoint(BodyPart.RIGHT_ANKLE, PointF(0.8f, 0.2f), 0.4f),
                ),
                score = 1f
            )
        )
        persons = keyPointsTracker.apply(persons, 1300000)
        track = keyPointsTracker.tracks
        assertEquals(1, persons.size)
        assertEquals(6, persons[0].id)
        assertEquals(4, track.size)
        assertEquals(6, track[0].person.id)
        assertEquals(1300000, track[0].lastTimestamp)
        assertEquals(2, track[1].person.id)
        assertEquals(1200000, track[1].lastTimestamp)
        assertEquals(4, track[2].person.id)
        assertEquals(1200000, track[2].lastTimestamp)
        assertEquals(5, track[3].person.id)
        assertEquals(1200000, track[3].lastTimestamp)
    }
}
@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="org.tensorflow.lite.examples.poseestimation">

    <uses-permission android:name="android.permission.CAMERA" />
    <uses-feature
        android:name="android.hardware.camera"
        android:required="true" />
    <application
        android:allowBackup="true"
        android:icon="@drawable/ic_launcher"
        android:label="@string/tfe_pe_app_name"
        android:roundIcon="@drawable/ic_launcher"
        android:supportsRtl="true"
        android:theme="@style/Theme.PoseEstimation">
        <activity android:name=".MainActivity">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
                <category android:name="android.intent.category.LAUNCHER" />
            </intent-filter>
        </activity>
    </application>

</manifest>
@ -0,0 +1,5 @@
chair
cobra
dog
tree
warrior
@ -0,0 +1,430 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation

import android.Manifest
import android.app.AlertDialog
import android.app.Dialog
import android.content.pm.PackageManager
import android.os.Bundle
import android.os.Process
import android.view.SurfaceView
import android.view.View
import android.view.WindowManager
import android.widget.*
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.appcompat.widget.SwitchCompat
import androidx.core.content.ContextCompat
import androidx.fragment.app.DialogFragment
import androidx.lifecycle.lifecycleScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import org.tensorflow.lite.examples.poseestimation.camera.CameraSource
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.ml.*

class MainActivity : AppCompatActivity() {
    companion object {
        private const val FRAGMENT_DIALOG = "dialog"
    }

    /** A [SurfaceView] for camera preview. */
    private lateinit var surfaceView: SurfaceView

    /** Default pose estimation model is 1 (MoveNet Thunder).
     * 0 == MoveNet Lightning model
     * 1 == MoveNet Thunder model
     * 2 == MoveNet MultiPose model
     * 3 == PoseNet model
     **/
    private var modelPos = 1

    /** Default device is CPU. */
    private var device = Device.CPU

    private lateinit var tvScore: TextView
    private lateinit var tvFPS: TextView
    private lateinit var spnDevice: Spinner
    private lateinit var spnModel: Spinner
    private lateinit var spnTracker: Spinner
    private lateinit var vTrackerOption: View
    private lateinit var tvClassificationValue1: TextView
    private lateinit var tvClassificationValue2: TextView
    private lateinit var tvClassificationValue3: TextView
    private lateinit var swClassification: SwitchCompat
    private lateinit var vClassificationOption: View
    private var cameraSource: CameraSource? = null
    private var isClassifyPose = false
    private val requestPermissionLauncher =
        registerForActivityResult(
            ActivityResultContracts.RequestPermission()
        ) { isGranted: Boolean ->
            if (isGranted) {
                // Permission is granted. Continue the action or workflow in your
                // app.
                openCamera()
            } else {
                // Explain to the user that the feature is unavailable because the
                // feature requires a permission that the user has denied. At the
                // same time, respect the user's decision. Don't link to system
                // settings in an effort to convince the user to change their
                // decision.
                ErrorDialog.newInstance(getString(R.string.tfe_pe_request_permission))
                    .show(supportFragmentManager, FRAGMENT_DIALOG)
            }
        }

    private var changeModelListener = object : AdapterView.OnItemSelectedListener {
        override fun onNothingSelected(parent: AdapterView<*>?) {
            // do nothing
        }

        override fun onItemSelected(
            parent: AdapterView<*>?,
            view: View?,
            position: Int,
            id: Long
        ) {
            changeModel(position)
        }
    }

    private var changeDeviceListener = object : AdapterView.OnItemSelectedListener {
        override fun onItemSelected(parent: AdapterView<*>?, view: View?, position: Int, id: Long) {
            changeDevice(position)
        }

        override fun onNothingSelected(parent: AdapterView<*>?) {
            // do nothing
        }
    }

    private var changeTrackerListener = object : AdapterView.OnItemSelectedListener {
        override fun onItemSelected(parent: AdapterView<*>?, view: View?, position: Int, id: Long) {
            changeTracker(position)
        }

        override fun onNothingSelected(parent: AdapterView<*>?) {
            // do nothing
        }
    }

    private var setClassificationListener =
        CompoundButton.OnCheckedChangeListener { _, isChecked ->
            showClassificationResult(isChecked)
            isClassifyPose = isChecked
            isPoseClassifier()
        }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        // Keep the screen on while the app is running.
        window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
        tvScore = findViewById(R.id.tvScore)
        tvFPS = findViewById(R.id.tvFps)
        spnModel = findViewById(R.id.spnModel)
        spnDevice = findViewById(R.id.spnDevice)
        spnTracker = findViewById(R.id.spnTracker)
        vTrackerOption = findViewById(R.id.vTrackerOption)
        surfaceView = findViewById(R.id.surfaceView)
        tvClassificationValue1 = findViewById(R.id.tvClassificationValue1)
        tvClassificationValue2 = findViewById(R.id.tvClassificationValue2)
        tvClassificationValue3 = findViewById(R.id.tvClassificationValue3)
        swClassification = findViewById(R.id.swPoseClassification)
        vClassificationOption = findViewById(R.id.vClassificationOption)
        initSpinner()
        spnModel.setSelection(modelPos)
        swClassification.setOnCheckedChangeListener(setClassificationListener)
        if (!isCameraPermissionGranted()) {
            requestPermission()
        }
    }

    override fun onStart() {
        super.onStart()
        openCamera()
    }

    override fun onResume() {
        cameraSource?.resume()
        super.onResume()
    }

    override fun onPause() {
        cameraSource?.close()
        cameraSource = null
        super.onPause()
    }
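    // Camera lifecycle note: the CameraSource is fully torn down in onPause() and
    // rebuilt through openCamera() on the next onStart() (or once the permission
    // request succeeds), so no camera resources are held while in the background.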
    // Check whether the camera permission has been granted.
    private fun isCameraPermissionGranted(): Boolean {
        return checkPermission(
            Manifest.permission.CAMERA,
            Process.myPid(),
            Process.myUid()
        ) == PackageManager.PERMISSION_GRANTED
    }

    // Open the camera and wire up the FPS/score/classification callbacks.
    private fun openCamera() {
        if (isCameraPermissionGranted()) {
            if (cameraSource == null) {
                cameraSource =
                    CameraSource(surfaceView, object : CameraSource.CameraSourceListener {
                        override fun onFPSListener(fps: Int) {
                            tvFPS.text = getString(R.string.tfe_pe_tv_fps, fps)
                        }

                        override fun onDetectedInfo(
                            personScore: Float?,
                            poseLabels: List<Pair<String, Float>>?
                        ) {
                            tvScore.text = getString(R.string.tfe_pe_tv_score, personScore ?: 0f)
                            poseLabels?.sortedByDescending { it.second }?.let {
                                tvClassificationValue1.text = getString(
                                    R.string.tfe_pe_tv_classification_value,
                                    convertPoseLabels(if (it.isNotEmpty()) it[0] else null)
                                )
                                tvClassificationValue2.text = getString(
                                    R.string.tfe_pe_tv_classification_value,
                                    convertPoseLabels(if (it.size >= 2) it[1] else null)
                                )
                                tvClassificationValue3.text = getString(
                                    R.string.tfe_pe_tv_classification_value,
                                    convertPoseLabels(if (it.size >= 3) it[2] else null)
                                )
                            }
                        }
                    }).apply {
                        prepareCamera()
                    }
                isPoseClassifier()
                lifecycleScope.launch(Dispatchers.Main) {
                    cameraSource?.initCamera()
                }
            }
            createPoseEstimator()
        }
    }

    private fun convertPoseLabels(pair: Pair<String, Float>?): String {
        if (pair == null) return "empty"
        return "${pair.first} (${String.format("%.2f", pair.second)})"
    }
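    // Example: convertPoseLabels(Pair("tree", 0.9234f)) returns "tree (0.92)";
    // a null pair (fewer labels returned than slots displayed) renders as "empty".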
    private fun isPoseClassifier() {
        cameraSource?.setClassifier(if (isClassifyPose) PoseClassifier.create(this) else null)
    }

    // Initialize spinners to let the user select model/accelerator/tracker.
    private fun initSpinner() {
        ArrayAdapter.createFromResource(
            this,
            R.array.tfe_pe_models_array,
            android.R.layout.simple_spinner_item
        ).also { adapter ->
            // Specify the layout to use when the list of choices appears.
            adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item)
            // Apply the adapter to the spinner.
            spnModel.adapter = adapter
            spnModel.onItemSelectedListener = changeModelListener
        }

        ArrayAdapter.createFromResource(
            this,
            R.array.tfe_pe_device_name, android.R.layout.simple_spinner_item
        ).also { adapter ->
            adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item)

            spnDevice.adapter = adapter
            spnDevice.onItemSelectedListener = changeDeviceListener
        }

        ArrayAdapter.createFromResource(
            this,
            R.array.tfe_pe_tracker_array, android.R.layout.simple_spinner_item
        ).also { adapter ->
            adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item)

            spnTracker.adapter = adapter
            spnTracker.onItemSelectedListener = changeTrackerListener
        }
    }

    // Change the model while the app is running.
    private fun changeModel(position: Int) {
        if (modelPos == position) return
        modelPos = position
        createPoseEstimator()
    }

    // Change the device (accelerator) type while the app is running.
    private fun changeDevice(position: Int) {
        val targetDevice = when (position) {
            0 -> Device.CPU
            1 -> Device.GPU
            else -> Device.NNAPI
        }
        if (device == targetDevice) return
        device = targetDevice
        createPoseEstimator()
    }

    // Change the tracker for the MoveNet MultiPose model.
    private fun changeTracker(position: Int) {
        cameraSource?.setTracker(
            when (position) {
                1 -> TrackerType.BOUNDING_BOX
                2 -> TrackerType.KEYPOINTS
                else -> TrackerType.OFF
            }
        )
    }

    private fun createPoseEstimator() {
        // For MoveNet MultiPose, hide the score and disable the pose classifier as
        // the model returns multiple Person instances.
        val poseDetector = when (modelPos) {
            0 -> {
                // MoveNet Lightning (SinglePose)
                showPoseClassifier(true)
                showDetectionScore(true)
                showTracker(false)
                MoveNet.create(this, device, ModelType.Lightning)
            }
            1 -> {
                // MoveNet Thunder (SinglePose)
                showPoseClassifier(true)
                showDetectionScore(true)
                showTracker(false)
                MoveNet.create(this, device, ModelType.Thunder)
            }
            2 -> {
                // MoveNet (Lightning) MultiPose
                showPoseClassifier(false)
                showDetectionScore(false)
                // MoveNet MultiPose (Dynamic) does not support the GPU delegate.
                if (device == Device.GPU) {
                    showToast(getString(R.string.tfe_pe_gpu_error))
                }
                showTracker(true)
                MoveNetMultiPose.create(
                    this,
                    device,
                    Type.Dynamic
                )
            }
            3 -> {
                // PoseNet (SinglePose)
                showPoseClassifier(true)
                showDetectionScore(true)
                showTracker(false)
                PoseNet.create(this, device)
            }
            else -> {
                null
            }
        }
        poseDetector?.let { detector ->
            cameraSource?.setDetector(detector)
        }
    }

    // Show/hide the pose classification option.
    private fun showPoseClassifier(isVisible: Boolean) {
        vClassificationOption.visibility = if (isVisible) View.VISIBLE else View.GONE
        if (!isVisible) {
            swClassification.isChecked = false
        }
    }

    // Show/hide the detection score.
    private fun showDetectionScore(isVisible: Boolean) {
        tvScore.visibility = if (isVisible) View.VISIBLE else View.GONE
    }

    // Show/hide the classification result.
    private fun showClassificationResult(isVisible: Boolean) {
        val visibility = if (isVisible) View.VISIBLE else View.GONE
        tvClassificationValue1.visibility = visibility
        tvClassificationValue2.visibility = visibility
        tvClassificationValue3.visibility = visibility
    }

    // Show/hide the tracking options.
    private fun showTracker(isVisible: Boolean) {
        if (isVisible) {
            // Show tracker options and enable the bounding box tracker.
            vTrackerOption.visibility = View.VISIBLE
            spnTracker.setSelection(1)
        } else {
            // Set the tracker type to off and hide the tracker option.
            vTrackerOption.visibility = View.GONE
            spnTracker.setSelection(0)
        }
    }

    private fun requestPermission() {
        when (PackageManager.PERMISSION_GRANTED) {
            ContextCompat.checkSelfPermission(
                this,
                Manifest.permission.CAMERA
            ) -> {
                // You can use the API that requires the permission.
                openCamera()
            }
            else -> {
                // You can directly ask for the permission.
                // The registered ActivityResultCallback gets the result of this request.
                requestPermissionLauncher.launch(
                    Manifest.permission.CAMERA
                )
            }
        }
    }

    private fun showToast(message: String) {
        Toast.makeText(this, message, Toast.LENGTH_LONG).show()
    }

    /**
     * Shows an error message dialog.
     */
    class ErrorDialog : DialogFragment() {

        override fun onCreateDialog(savedInstanceState: Bundle?): Dialog =
            AlertDialog.Builder(activity)
                .setMessage(requireArguments().getString(ARG_MESSAGE))
                .setPositiveButton(android.R.string.ok) { _, _ ->
                    // do nothing
                }
                .create()

        companion object {

            @JvmStatic
            private val ARG_MESSAGE = "message"

            @JvmStatic
            fun newInstance(message: String): ErrorDialog = ErrorDialog().apply {
                arguments = Bundle().apply { putString(ARG_MESSAGE, message) }
            }
        }
    }
}
@ -0,0 +1,120 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation

import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Person
import kotlin.math.max

object VisualizationUtils {
    /** Radius of the circles used to draw keypoints. */
    private const val CIRCLE_RADIUS = 6f

    /** Width of the lines used to connect two keypoints. */
    private const val LINE_WIDTH = 4f

    /** Text size of the person id displayed when the tracker is enabled. */
    private const val PERSON_ID_TEXT_SIZE = 30f

    /** Distance from the person id to the nose keypoint. */
    private const val PERSON_ID_MARGIN = 6f

    /** Pairs of keypoints to draw lines between. */
    private val bodyJoints = listOf(
        Pair(BodyPart.NOSE, BodyPart.LEFT_EYE),
        Pair(BodyPart.NOSE, BodyPart.RIGHT_EYE),
        Pair(BodyPart.LEFT_EYE, BodyPart.LEFT_EAR),
        Pair(BodyPart.RIGHT_EYE, BodyPart.RIGHT_EAR),
        Pair(BodyPart.NOSE, BodyPart.LEFT_SHOULDER),
        Pair(BodyPart.NOSE, BodyPart.RIGHT_SHOULDER),
        Pair(BodyPart.LEFT_SHOULDER, BodyPart.LEFT_ELBOW),
        Pair(BodyPart.LEFT_ELBOW, BodyPart.LEFT_WRIST),
        Pair(BodyPart.RIGHT_SHOULDER, BodyPart.RIGHT_ELBOW),
        Pair(BodyPart.RIGHT_ELBOW, BodyPart.RIGHT_WRIST),
        Pair(BodyPart.LEFT_SHOULDER, BodyPart.RIGHT_SHOULDER),
        Pair(BodyPart.LEFT_SHOULDER, BodyPart.LEFT_HIP),
        Pair(BodyPart.RIGHT_SHOULDER, BodyPart.RIGHT_HIP),
        Pair(BodyPart.LEFT_HIP, BodyPart.RIGHT_HIP),
        Pair(BodyPart.LEFT_HIP, BodyPart.LEFT_KNEE),
        Pair(BodyPart.LEFT_KNEE, BodyPart.LEFT_ANKLE),
        Pair(BodyPart.RIGHT_HIP, BodyPart.RIGHT_KNEE),
        Pair(BodyPart.RIGHT_KNEE, BodyPart.RIGHT_ANKLE)
    )
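    // The 18 pairs above connect the 17 COCO-style keypoints produced by MoveNet and
    // PoseNet (nose, eyes, ears, shoulders, elbows, wrists, hips, knees, ankles),
    // forming the familiar stick-figure skeleton.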
    // Draw lines and points to indicate the body pose.
    fun drawBodyKeypoints(
        input: Bitmap,
        persons: List<Person>,
        isTrackerEnabled: Boolean = false
    ): Bitmap {
        val paintCircle = Paint().apply {
            strokeWidth = CIRCLE_RADIUS
            color = Color.RED
            style = Paint.Style.FILL
        }
        val paintLine = Paint().apply {
            strokeWidth = LINE_WIDTH
            color = Color.RED
            style = Paint.Style.STROKE
        }

        val paintText = Paint().apply {
            textSize = PERSON_ID_TEXT_SIZE
            color = Color.BLUE
            textAlign = Paint.Align.LEFT
        }

        val output = input.copy(Bitmap.Config.ARGB_8888, true)
        val originalSizeCanvas = Canvas(output)
        persons.forEach { person ->
            // Draw the person id and bounding box if the tracker is enabled.
            if (isTrackerEnabled) {
                person.boundingBox?.let {
                    val personIdX = max(0f, it.left)
                    val personIdY = max(0f, it.top)

                    originalSizeCanvas.drawText(
                        person.id.toString(),
                        personIdX,
                        personIdY - PERSON_ID_MARGIN,
                        paintText
                    )
                    originalSizeCanvas.drawRect(it, paintLine)
                }
            }
            bodyJoints.forEach {
                val pointA = person.keyPoints[it.first.position].coordinate
                val pointB = person.keyPoints[it.second.position].coordinate
                originalSizeCanvas.drawLine(pointA.x, pointA.y, pointB.x, pointB.y, paintLine)
            }

            person.keyPoints.forEach { point ->
                originalSizeCanvas.drawCircle(
                    point.coordinate.x,
                    point.coordinate.y,
                    CIRCLE_RADIUS,
                    paintCircle
                )
            }
        }
        return output
    }
}
@ -0,0 +1,151 @@
package org.tensorflow.lite.examples.poseestimation

import android.content.Context
import android.graphics.Bitmap
import android.graphics.ImageFormat
import android.graphics.Rect
import android.media.Image
import android.renderscript.Allocation
import android.renderscript.Element
import android.renderscript.RenderScript
import android.renderscript.ScriptIntrinsicYuvToRGB
import java.nio.ByteBuffer
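/**
 * Converts a YUV_420_888 camera [Image] into an RGB [Bitmap] using the
 * ScriptIntrinsicYuvToRGB intrinsic. Note: RenderScript is deprecated as of
 * Android 12 (API 31), so a modern port would typically swap this for a
 * libyuv-, OpenGL-, or YuvImage-based conversion.
 */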
class YuvToRgbConverter(context: Context) {
    private val rs = RenderScript.create(context)
    private val scriptYuvToRgb = ScriptIntrinsicYuvToRGB.create(rs, Element.U8_4(rs))

    private var pixelCount: Int = -1
    private lateinit var yuvBuffer: ByteBuffer
    private lateinit var inputAllocation: Allocation
    private lateinit var outputAllocation: Allocation

    @Synchronized
    fun yuvToRgb(image: Image, output: Bitmap) {

        // Ensure that the intermediate output byte buffer is allocated
        if (!::yuvBuffer.isInitialized) {
            pixelCount = image.cropRect.width() * image.cropRect.height()
            yuvBuffer = ByteBuffer.allocateDirect(
                pixelCount * ImageFormat.getBitsPerPixel(ImageFormat.YUV_420_888) / 8
            )
        }

        // Get the YUV data in byte array form
        imageToByteBuffer(image, yuvBuffer)

        // Ensure that the RenderScript inputs and outputs are allocated
        if (!::inputAllocation.isInitialized) {
            inputAllocation = Allocation.createSized(rs, Element.U8(rs), yuvBuffer.array().size)
        }
        if (!::outputAllocation.isInitialized) {
            outputAllocation = Allocation.createFromBitmap(rs, output)
        }

        // Convert YUV to RGB
        inputAllocation.copyFrom(yuvBuffer.array())
        scriptYuvToRgb.setInput(inputAllocation)
        scriptYuvToRgb.forEach(outputAllocation)
        outputAllocation.copyTo(output)
    }

    private fun imageToByteBuffer(image: Image, outputBuffer: ByteBuffer) {
        assert(image.format == ImageFormat.YUV_420_888)

        val imageCrop = image.cropRect
        val imagePlanes = image.planes
        val rowData = ByteArray(imagePlanes.first().rowStride)

        imagePlanes.forEachIndexed { planeIndex, plane ->

            // How many values are read in input for each output value written
            // Only the Y plane has a value for every pixel, U and V have half the resolution i.e.
            //
            // Y Plane            U Plane    V Plane
            // ===============    =======    =======
            // Y Y Y Y Y Y Y Y    U U U U    V V V V
            // Y Y Y Y Y Y Y Y    U U U U    V V V V
            // Y Y Y Y Y Y Y Y    U U U U    V V V V
            // Y Y Y Y Y Y Y Y    U U U U    V V V V
            // Y Y Y Y Y Y Y Y
            // Y Y Y Y Y Y Y Y
            // Y Y Y Y Y Y Y Y
            val outputStride: Int

            // The index in the output buffer the next value will be written at
            // For Y it's zero, for U and V we start at the end of Y and interleave them i.e.
            //
            // First chunk        Second chunk
            // ===============    ===============
            // Y Y Y Y Y Y Y Y    U V U V U V U V
            // Y Y Y Y Y Y Y Y    U V U V U V U V
            // Y Y Y Y Y Y Y Y    U V U V U V U V
            // Y Y Y Y Y Y Y Y    U V U V U V U V
            // Y Y Y Y Y Y Y Y
            // Y Y Y Y Y Y Y Y
            // Y Y Y Y Y Y Y Y
            var outputOffset: Int

            when (planeIndex) {
                0 -> {
                    outputStride = 1
                    outputOffset = 0
                }
                1 -> {
                    outputStride = 2
                    outputOffset = pixelCount + 1
                }
                2 -> {
                    outputStride = 2
                    outputOffset = pixelCount
                }
                else -> {
                    // Image contains more than 3 planes, something strange is going on
                    return@forEachIndexed
                }
            }
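            // The offsets above produce NV21 ordering in outputBuffer: all Y bytes
            // first, then V and U interleaved (V at pixelCount, U at pixelCount + 1).
            // This presumably matches what the ScriptIntrinsicYuvToRGB instance is
            // fed, since NV21 is the intrinsic's default input format.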
            val buffer = plane.buffer
            val rowStride = plane.rowStride
            val pixelStride = plane.pixelStride

            // We have to divide the width and height by two if it's not the Y plane
            val planeCrop = if (planeIndex == 0) {
                imageCrop
            } else {
                Rect(
                    imageCrop.left / 2,
                    imageCrop.top / 2,
                    imageCrop.right / 2,
                    imageCrop.bottom / 2
                )
            }

            val planeWidth = planeCrop.width()
            val planeHeight = planeCrop.height()

            buffer.position(rowStride * planeCrop.top + pixelStride * planeCrop.left)
            for (row in 0 until planeHeight) {
                val length: Int
                if (pixelStride == 1 && outputStride == 1) {
                    // When there is a single stride value for pixel and output, we can just copy
                    // the entire row in a single step
                    length = planeWidth
                    buffer.get(outputBuffer.array(), outputOffset, length)
                    outputOffset += length
                } else {
                    // When either pixel or output have a stride > 1 we must copy pixel by pixel
                    length = (planeWidth - 1) * pixelStride + 1
                    buffer.get(rowData, 0, length)
                    for (col in 0 until planeWidth) {
                        outputBuffer.array()[outputOffset] = rowData[col * pixelStride]
                        outputOffset += outputStride
                    }
                }

                if (row < planeHeight - 1) {
                    buffer.position(buffer.position() + rowStride - length)
                }
            }
        }
    }
}
@ -0,0 +1,327 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.camera

import android.annotation.SuppressLint
import android.content.Context
import android.graphics.Bitmap
import android.graphics.ImageFormat
import android.graphics.Matrix
import android.graphics.Rect
import android.hardware.camera2.CameraCaptureSession
import android.hardware.camera2.CameraCharacteristics
import android.hardware.camera2.CameraDevice
import android.hardware.camera2.CameraManager
import android.media.ImageReader
import android.os.Handler
import android.os.HandlerThread
import android.util.Log
import android.view.Surface
import android.view.SurfaceView
import kotlinx.coroutines.suspendCancellableCoroutine
import org.tensorflow.lite.examples.poseestimation.VisualizationUtils
import org.tensorflow.lite.examples.poseestimation.YuvToRgbConverter
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.examples.poseestimation.ml.MoveNetMultiPose
import org.tensorflow.lite.examples.poseestimation.ml.PoseClassifier
import org.tensorflow.lite.examples.poseestimation.ml.PoseDetector
import org.tensorflow.lite.examples.poseestimation.ml.TrackerType
import java.util.*
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException

class CameraSource(
    private val surfaceView: SurfaceView,
    private val listener: CameraSourceListener? = null
) {

    companion object {
        private const val PREVIEW_WIDTH = 640
        private const val PREVIEW_HEIGHT = 480

        /** Threshold for confidence score. */
        private const val MIN_CONFIDENCE = .2f
        private const val TAG = "Camera Source"
    }

    private val lock = Any()
    private var detector: PoseDetector? = null
    private var classifier: PoseClassifier? = null
    private var isTrackerEnabled = false
    private var yuvConverter: YuvToRgbConverter = YuvToRgbConverter(surfaceView.context)
    private lateinit var imageBitmap: Bitmap

    private var fpsTimer: Timer? = null

    /** Number of frames processed in the current one-second interval, used to calculate FPS. */
    private var frameProcessedInOneSecondInterval = 0
    private var framesPerSecond = 0

    /** Detects, characterizes, and connects to a CameraDevice (used for all camera operations). */
    private val cameraManager: CameraManager by lazy {
        val context = surfaceView.context
        context.getSystemService(Context.CAMERA_SERVICE) as CameraManager
    }

    /** Readers used as buffers for camera still shots. */
    private var imageReader: ImageReader? = null

    /** The [CameraDevice] that will be opened by this camera source. */
    private var camera: CameraDevice? = null

    /** Internal reference to the ongoing [CameraCaptureSession] configured with our parameters. */
    private var session: CameraCaptureSession? = null

    /** [HandlerThread] where all buffer reading operations run. */
    private var imageReaderThread: HandlerThread? = null

    /** [Handler] corresponding to [imageReaderThread]. */
    private var imageReaderHandler: Handler? = null
    private var cameraId: String = ""

    suspend fun initCamera() {
        camera = openCamera(cameraManager, cameraId)
        imageReader =
            ImageReader.newInstance(PREVIEW_WIDTH, PREVIEW_HEIGHT, ImageFormat.YUV_420_888, 3)
        imageReader?.setOnImageAvailableListener({ reader ->
            val image = reader.acquireLatestImage()
            if (image != null) {
                if (!::imageBitmap.isInitialized) {
                    imageBitmap =
                        Bitmap.createBitmap(
                            PREVIEW_WIDTH,
                            PREVIEW_HEIGHT,
                            Bitmap.Config.ARGB_8888
                        )
                }
                yuvConverter.yuvToRgb(image, imageBitmap)
                // Create a rotated version for portrait display.
                val rotateMatrix = Matrix()
                rotateMatrix.postRotate(90.0f)

                val rotatedBitmap = Bitmap.createBitmap(
                    imageBitmap, 0, 0, PREVIEW_WIDTH, PREVIEW_HEIGHT,
                    rotateMatrix, false
                )
                processImage(rotatedBitmap)
                image.close()
            }
        }, imageReaderHandler)

        imageReader?.surface?.let { surface ->
            session = createSession(listOf(surface))
            val cameraRequest = camera?.createCaptureRequest(
                CameraDevice.TEMPLATE_PREVIEW
            )?.apply {
                addTarget(surface)
            }
            cameraRequest?.build()?.let {
                session?.setRepeatingRequest(it, null, null)
            }
        }
    }

    private suspend fun createSession(targets: List<Surface>): CameraCaptureSession =
        suspendCancellableCoroutine { cont ->
            camera?.createCaptureSession(targets, object : CameraCaptureSession.StateCallback() {
                override fun onConfigured(captureSession: CameraCaptureSession) =
                    cont.resume(captureSession)

                override fun onConfigureFailed(session: CameraCaptureSession) {
                    cont.resumeWithException(Exception("Session error"))
                }
            }, null)
        }

    @SuppressLint("MissingPermission")
    private suspend fun openCamera(manager: CameraManager, cameraId: String): CameraDevice =
        suspendCancellableCoroutine { cont ->
            manager.openCamera(cameraId, object : CameraDevice.StateCallback() {
                override fun onOpened(camera: CameraDevice) = cont.resume(camera)

                override fun onDisconnected(camera: CameraDevice) {
                    camera.close()
                }

                override fun onError(camera: CameraDevice, error: Int) {
                    if (cont.isActive) cont.resumeWithException(Exception("Camera error"))
                }
            }, imageReaderHandler)
        }
fun prepareCamera() {
|
||||
for (cameraId in cameraManager.cameraIdList) {
|
||||
val characteristics = cameraManager.getCameraCharacteristics(cameraId)
|
||||
|
||||
// We don't use a front facing camera in this sample.
|
||||
val cameraDirection = characteristics.get(CameraCharacteristics.LENS_FACING)
|
||||
if (cameraDirection != null &&
|
||||
cameraDirection == CameraCharacteristics.LENS_FACING_FRONT
|
||||
) {
|
||||
continue
|
||||
}
|
||||
this.cameraId = cameraId
|
||||
}
|
||||
}
|
||||
|
||||
fun setDetector(detector: PoseDetector) {
|
||||
synchronized(lock) {
|
||||
if (this.detector != null) {
|
||||
this.detector?.close()
|
||||
this.detector = null
|
||||
}
|
||||
this.detector = detector
|
||||
}
|
||||
}
|
||||
|
||||
fun setClassifier(classifier: PoseClassifier?) {
|
||||
synchronized(lock) {
|
||||
if (this.classifier != null) {
|
||||
this.classifier?.close()
|
||||
this.classifier = null
|
||||
}
|
||||
this.classifier = classifier
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set Tracker for Movenet MuiltiPose model.
|
||||
*/
|
||||
fun setTracker(trackerType: TrackerType) {
|
||||
isTrackerEnabled = trackerType != TrackerType.OFF
|
||||
(this.detector as? MoveNetMultiPose)?.setTracker(trackerType)
|
||||
}
|
||||
|
||||
fun resume() {
|
||||
imageReaderThread = HandlerThread("imageReaderThread").apply { start() }
|
||||
imageReaderHandler = Handler(imageReaderThread!!.looper)
|
||||
fpsTimer = Timer()
|
||||
fpsTimer?.scheduleAtFixedRate(
|
||||
object : TimerTask() {
|
||||
override fun run() {
|
||||
framesPerSecond = frameProcessedInOneSecondInterval
|
||||
frameProcessedInOneSecondInterval = 0
|
||||
}
|
||||
},
|
||||
0,
|
||||
1000
|
||||
)
|
||||
}
|
||||
|
||||
fun close() {
|
||||
session?.close()
|
||||
session = null
|
||||
camera?.close()
|
||||
camera = null
|
||||
imageReader?.close()
|
||||
imageReader = null
|
||||
stopImageReaderThread()
|
||||
detector?.close()
|
||||
detector = null
|
||||
classifier?.close()
|
||||
classifier = null
|
||||
fpsTimer?.cancel()
|
||||
fpsTimer = null
|
||||
frameProcessedInOneSecondInterval = 0
|
||||
framesPerSecond = 0
|
||||
}
|
||||
|
||||
// process image
|
||||
private fun processImage(bitmap: Bitmap) {
|
||||
val persons = mutableListOf<Person>()
|
||||
var classificationResult: List<Pair<String, Float>>? = null
|
||||
|
||||
synchronized(lock) {
|
||||
detector?.estimatePoses(bitmap)?.let {
|
||||
persons.addAll(it)
|
||||
|
||||
// if the model only returns one item, allow running the Pose classifier.
|
||||
if (persons.isNotEmpty()) {
|
||||
classifier?.run {
|
||||
classificationResult = classify(persons[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
frameProcessedInOneSecondInterval++
|
||||
if (frameProcessedInOneSecondInterval == 1) {
|
||||
// send fps to view
|
||||
listener?.onFPSListener(framesPerSecond)
|
||||
}
|
||||
|
||||
// if the model returns only one item, show that item's score.
|
||||
if (persons.isNotEmpty()) {
|
||||
listener?.onDetectedInfo(persons[0].score, classificationResult)
|
||||
}
|
||||
visualize(persons, bitmap)
|
||||
}
|
||||
|
||||
private fun visualize(persons: List<Person>, bitmap: Bitmap) {
|
||||
|
||||
val outputBitmap = VisualizationUtils.drawBodyKeypoints(
|
||||
bitmap,
|
||||
persons.filter { it.score > MIN_CONFIDENCE }, isTrackerEnabled
|
||||
)
|
||||
|
||||
val holder = surfaceView.holder
|
||||
val surfaceCanvas = holder.lockCanvas()
|
||||
surfaceCanvas?.let { canvas ->
|
||||
val screenWidth: Int
|
||||
val screenHeight: Int
|
||||
val left: Int
|
||||
val top: Int
|
||||
|
||||
if (canvas.height > canvas.width) {
|
||||
val ratio = outputBitmap.height.toFloat() / outputBitmap.width
|
||||
screenWidth = canvas.width
|
||||
left = 0
|
||||
screenHeight = (canvas.width * ratio).toInt()
|
||||
top = (canvas.height - screenHeight) / 2
|
||||
} else {
|
||||
val ratio = outputBitmap.width.toFloat() / outputBitmap.height
|
||||
screenHeight = canvas.height
|
||||
top = 0
|
||||
screenWidth = (canvas.height * ratio).toInt()
|
||||
left = (canvas.width - screenWidth) / 2
|
||||
}
|
||||
val right: Int = left + screenWidth
|
||||
val bottom: Int = top + screenHeight
|
||||
|
||||
canvas.drawBitmap(
|
||||
outputBitmap, Rect(0, 0, outputBitmap.width, outputBitmap.height),
|
||||
Rect(left, top, right, bottom), null
|
||||
)
|
||||
surfaceView.holder.unlockCanvasAndPost(canvas)
|
||||
}
|
||||
}
|
||||
|
||||
private fun stopImageReaderThread() {
|
||||
imageReaderThread?.quitSafely()
|
||||
try {
|
||||
imageReaderThread?.join()
|
||||
imageReaderThread = null
|
||||
imageReaderHandler = null
|
||||
} catch (e: InterruptedException) {
|
||||
Log.d(TAG, e.message.toString())
|
||||
}
|
||||
}
|
||||
|
||||
interface CameraSourceListener {
|
||||
fun onFPSListener(fps: Int)
|
||||
|
||||
fun onDetectedInfo(personScore: Float?, poseLabels: List<Pair<String, Float>>?)
|
||||
}
|
||||
}
|
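// Illustrative sketch (not part of the original sample): a typical way to wire
// CameraSource into a host Activity. `lifecycleScope`, `surfaceView`, the listener,
// and CAMERA-permission handling are assumed to exist in the caller.
//
//   val cameraSource = CameraSource(surfaceView, listener).apply {
//       prepareCamera()                        // pick a back-facing camera id
//       setDetector(MoveNet.create(context, Device.CPU))
//   }
//   lifecycleScope.launch(Dispatchers.Main) {
//       cameraSource.resume()                  // start reader thread + FPS timer
//       cameraSource.initCamera()              // open the camera and start streaming
//   }
//   // ...and call cameraSource.close() when the Activity is destroyed.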
@ -0,0 +1,41 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.data

enum class BodyPart(val position: Int) {
    NOSE(0),
    LEFT_EYE(1),
    RIGHT_EYE(2),
    LEFT_EAR(3),
    RIGHT_EAR(4),
    LEFT_SHOULDER(5),
    RIGHT_SHOULDER(6),
    LEFT_ELBOW(7),
    RIGHT_ELBOW(8),
    LEFT_WRIST(9),
    RIGHT_WRIST(10),
    LEFT_HIP(11),
    RIGHT_HIP(12),
    LEFT_KNEE(13),
    RIGHT_KNEE(14),
    LEFT_ANKLE(15),
    RIGHT_ANKLE(16);

    companion object {
        private val map = values().associateBy(BodyPart::position)
        fun fromInt(position: Int): BodyPart = map.getValue(position)
    }
}
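// Example (illustrative): the enum mirrors the 17-keypoint COCO ordering used by
// the MoveNet and PoseNet models, so a flat model output maps back to named joints:
//   BodyPart.fromInt(5)   // LEFT_SHOULDER
//   BodyPart.fromInt(16)  // RIGHT_ANKLE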
@ -0,0 +1,23 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.data

enum class Device {
    CPU,
    NNAPI,
    GPU
}
@ -0,0 +1,21 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.data

import android.graphics.PointF

data class KeyPoint(val bodyPart: BodyPart, var coordinate: PointF, val score: Float)
@ -0,0 +1,26 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.data

import android.graphics.RectF

data class Person(
    var id: Int = -1, // default id is -1
    val keyPoints: List<KeyPoint>,
    val boundingBox: RectF? = null, // Only MoveNet MultiPose returns a bounding box.
    val score: Float
)
@ -0,0 +1,24 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.data

data class TorsoAndBodyDistance(
    val maxTorsoYDistance: Float,
    val maxTorsoXDistance: Float,
    val maxBodyYDistance: Float,
    val maxBodyXDistance: Float
)
@ -0,0 +1,353 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import android.graphics.*
import android.os.SystemClock
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.examples.poseestimation.data.*
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import kotlin.math.abs
import kotlin.math.max
import kotlin.math.min

enum class ModelType {
    Lightning,
    Thunder
}

class MoveNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) :
    PoseDetector {

    companion object {
        private const val MIN_CROP_KEYPOINT_SCORE = .2f
        private const val CPU_NUM_THREADS = 4

        // Parameters that control how far the crop region is expanded from the
        // previous frame's body keypoints.
        private const val TORSO_EXPANSION_RATIO = 1.9f
        private const val BODY_EXPANSION_RATIO = 1.2f

        // TFLite file names.
        private const val LIGHTNING_FILENAME = "movenet_lightning.tflite"
        private const val THUNDER_FILENAME = "movenet_thunder.tflite"

        // Allows specifying the model type.
        fun create(context: Context, device: Device, modelType: ModelType): MoveNet {
            val options = Interpreter.Options()
            var gpuDelegate: GpuDelegate? = null
            options.setNumThreads(CPU_NUM_THREADS)
            when (device) {
                Device.CPU -> {
                }
                Device.GPU -> {
                    gpuDelegate = GpuDelegate()
                    options.addDelegate(gpuDelegate)
                }
                Device.NNAPI -> options.setUseNNAPI(true)
            }
            return MoveNet(
                Interpreter(
                    FileUtil.loadMappedFile(
                        context,
                        if (modelType == ModelType.Lightning) LIGHTNING_FILENAME
                        else THUNDER_FILENAME
                    ), options
                ),
                gpuDelegate
            )
        }

        // Defaults to Lightning.
        fun create(context: Context, device: Device): MoveNet =
            create(context, device, ModelType.Lightning)
    }

    private var cropRegion: RectF? = null
    private var lastInferenceTimeNanos: Long = -1
    private val inputWidth = interpreter.getInputTensor(0).shape()[1]
    private val inputHeight = interpreter.getInputTensor(0).shape()[2]
    private var outputShape: IntArray = interpreter.getOutputTensor(0).shape()

    override fun estimatePoses(bitmap: Bitmap): List<Person> {
        val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        if (cropRegion == null) {
            cropRegion = initRectF(bitmap.width, bitmap.height)
        }
        var totalScore = 0f

        val numKeyPoints = outputShape[2]
        val keyPoints = mutableListOf<KeyPoint>()

        cropRegion?.run {
            val rect = RectF(
                (left * bitmap.width),
                (top * bitmap.height),
                (right * bitmap.width),
                (bottom * bitmap.height)
            )
            val detectBitmap = Bitmap.createBitmap(
                rect.width().toInt(),
                rect.height().toInt(),
                Bitmap.Config.ARGB_8888
            )
            Canvas(detectBitmap).drawBitmap(
                bitmap,
                -rect.left,
                -rect.top,
                null
            )
            val inputTensor = processInputImage(detectBitmap, inputWidth, inputHeight)
            val outputTensor = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32)
            val widthRatio = detectBitmap.width.toFloat() / inputWidth
            val heightRatio = detectBitmap.height.toFloat() / inputHeight

            val positions = mutableListOf<Float>()

            inputTensor?.let { input ->
                interpreter.run(input.buffer, outputTensor.buffer.rewind())
                val output = outputTensor.floatArray
                for (idx in 0 until numKeyPoints) {
                    val x = output[idx * 3 + 1] * inputWidth * widthRatio
                    val y = output[idx * 3 + 0] * inputHeight * heightRatio

                    positions.add(x)
                    positions.add(y)
                    val score = output[idx * 3 + 2]
                    keyPoints.add(
                        KeyPoint(
                            BodyPart.fromInt(idx),
                            PointF(x, y),
                            score
                        )
                    )
                    totalScore += score
                }
            }
            // Map the keypoints from crop coordinates back to image coordinates.
            val matrix = Matrix()
            val points = positions.toFloatArray()

            matrix.postTranslate(rect.left, rect.top)
            matrix.mapPoints(points)
            keyPoints.forEachIndexed { index, keyPoint ->
                keyPoint.coordinate =
                    PointF(
                        points[index * 2],
                        points[index * 2 + 1]
                    )
            }
            // Compute the crop region for the next frame.
            cropRegion = determineRectF(keyPoints, bitmap.width, bitmap.height)
        }
        lastInferenceTimeNanos =
            SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
        return listOf(Person(keyPoints = keyPoints, score = totalScore / numKeyPoints))
    }

    override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos

    override fun close() {
        gpuDelegate?.close()
        interpreter.close()
        cropRegion = null
    }

    /**
     * Prepares the input image for detection.
     */
    private fun processInputImage(bitmap: Bitmap, inputWidth: Int, inputHeight: Int): TensorImage? {
        val width: Int = bitmap.width
        val height: Int = bitmap.height

        val size = if (height > width) width else height
        val imageProcessor = ImageProcessor.Builder().apply {
            add(ResizeWithCropOrPadOp(size, size))
            add(ResizeOp(inputWidth, inputHeight, ResizeOp.ResizeMethod.BILINEAR))
        }.build()
        val tensorImage = TensorImage(DataType.UINT8)
        tensorImage.load(bitmap)
        return imageProcessor.process(tensorImage)
    }

    /**
     * Defines the default crop region.
     * The function provides the initial crop region (pads the full image from both
     * sides to make it a square image) when the algorithm cannot reliably determine
     * the crop region from the previous frame.
     */
    private fun initRectF(imageWidth: Int, imageHeight: Int): RectF {
        val xMin: Float
        val yMin: Float
        val width: Float
        val height: Float
        if (imageWidth > imageHeight) {
            width = 1f
            height = imageWidth.toFloat() / imageHeight
            xMin = 0f
            yMin = (imageHeight / 2f - imageWidth / 2f) / imageHeight
        } else {
            height = 1f
            width = imageHeight.toFloat() / imageWidth
            yMin = 0f
            xMin = (imageWidth / 2f - imageHeight / 2f) / imageWidth
        }
        return RectF(
            xMin,
            yMin,
            xMin + width,
            yMin + height
        )
    }

    /**
     * Checks whether there are enough torso keypoints.
     * This function checks whether the model is confident at predicting one of the
     * shoulders/hips, which is required to determine a good crop region.
     */
    private fun torsoVisible(keyPoints: List<KeyPoint>): Boolean {
        return (keyPoints[BodyPart.LEFT_HIP.position].score > MIN_CROP_KEYPOINT_SCORE ||
                keyPoints[BodyPart.RIGHT_HIP.position].score > MIN_CROP_KEYPOINT_SCORE) &&
                (keyPoints[BodyPart.LEFT_SHOULDER.position].score > MIN_CROP_KEYPOINT_SCORE ||
                        keyPoints[BodyPart.RIGHT_SHOULDER.position].score > MIN_CROP_KEYPOINT_SCORE)
    }

    /**
     * Determines the region to crop the image for the model to run inference on.
     * The algorithm uses the detected joints from the previous frame to estimate
     * the square region that encloses the full body of the target person and
     * centers at the midpoint of the two hip joints. The crop size is determined by
     * the distances between each joint and the center point.
     * When the model is not confident with the four torso joint predictions, the
     * function returns a default crop, which is the full image padded to square.
     */
    private fun determineRectF(
        keyPoints: List<KeyPoint>,
        imageWidth: Int,
        imageHeight: Int
    ): RectF {
        val targetKeyPoints = mutableListOf<KeyPoint>()
        keyPoints.forEach {
            targetKeyPoints.add(
                KeyPoint(
                    it.bodyPart,
                    PointF(
                        it.coordinate.x,
                        it.coordinate.y
                    ),
                    it.score
                )
            )
        }
        if (torsoVisible(keyPoints)) {
            val centerX =
                (targetKeyPoints[BodyPart.LEFT_HIP.position].coordinate.x +
                        targetKeyPoints[BodyPart.RIGHT_HIP.position].coordinate.x) / 2f
            val centerY =
                (targetKeyPoints[BodyPart.LEFT_HIP.position].coordinate.y +
                        targetKeyPoints[BodyPart.RIGHT_HIP.position].coordinate.y) / 2f

            val torsoAndBodyDistances =
                determineTorsoAndBodyDistances(keyPoints, targetKeyPoints, centerX, centerY)

            val list = listOf(
                torsoAndBodyDistances.maxTorsoXDistance * TORSO_EXPANSION_RATIO,
                torsoAndBodyDistances.maxTorsoYDistance * TORSO_EXPANSION_RATIO,
                torsoAndBodyDistances.maxBodyXDistance * BODY_EXPANSION_RATIO,
                torsoAndBodyDistances.maxBodyYDistance * BODY_EXPANSION_RATIO
            )

            var cropLengthHalf = list.maxOrNull() ?: 0f
            val tmp = listOf(centerX, imageWidth - centerX, centerY, imageHeight - centerY)
            cropLengthHalf = min(cropLengthHalf, tmp.maxOrNull() ?: 0f)
            val cropCorner = Pair(centerY - cropLengthHalf, centerX - cropLengthHalf)

            return if (cropLengthHalf > max(imageWidth, imageHeight) / 2f) {
                initRectF(imageWidth, imageHeight)
            } else {
                val cropLength = cropLengthHalf * 2
                RectF(
                    cropCorner.second / imageWidth,
                    cropCorner.first / imageHeight,
                    (cropCorner.second + cropLength) / imageWidth,
                    (cropCorner.first + cropLength) / imageHeight
                )
            }
        } else {
            return initRectF(imageWidth, imageHeight)
        }
    }

    /**
     * Calculates the maximum distance from each keypoint to the center location.
     * The function returns the maximum distances from the two sets of keypoints:
     * full 17 keypoints and 4 torso keypoints. The returned information will be
     * used to determine the crop size. See determineRectF for more details.
     */
    private fun determineTorsoAndBodyDistances(
        keyPoints: List<KeyPoint>,
        targetKeyPoints: List<KeyPoint>,
        centerX: Float,
        centerY: Float
    ): TorsoAndBodyDistance {
        val torsoJoints = listOf(
            BodyPart.LEFT_SHOULDER.position,
            BodyPart.RIGHT_SHOULDER.position,
            BodyPart.LEFT_HIP.position,
            BodyPart.RIGHT_HIP.position
        )

        var maxTorsoYRange = 0f
        var maxTorsoXRange = 0f
        torsoJoints.forEach { joint ->
            val distY = abs(centerY - targetKeyPoints[joint].coordinate.y)
            val distX = abs(centerX - targetKeyPoints[joint].coordinate.x)
            if (distY > maxTorsoYRange) maxTorsoYRange = distY
            if (distX > maxTorsoXRange) maxTorsoXRange = distX
        }

        var maxBodyYRange = 0f
        var maxBodyXRange = 0f
        for (joint in keyPoints.indices) {
            if (keyPoints[joint].score < MIN_CROP_KEYPOINT_SCORE) continue
            val distY = abs(centerY - keyPoints[joint].coordinate.y)
            val distX = abs(centerX - keyPoints[joint].coordinate.x)

            if (distY > maxBodyYRange) maxBodyYRange = distY
            if (distX > maxBodyXRange) maxBodyXRange = distX
        }
        return TorsoAndBodyDistance(
            maxTorsoYRange,
            maxTorsoXRange,
            maxBodyYRange,
            maxBodyXRange
        )
    }
}
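// Illustrative sketch (not part of the original sample): single-person inference
// with MoveNet. Assumes a valid Context and an ARGB_8888 Bitmap named `frame`.
//
//   val detector = MoveNet.create(context, Device.CPU, ModelType.Lightning)
//   val person = detector.estimatePoses(frame).first()   // MoveNet returns one person
//   val nose = person.keyPoints[BodyPart.NOSE.position]
//   Log.d("MoveNet", "nose=(${nose.coordinate.x}, ${nose.coordinate.y}) " +
//       "score=${nose.score}, inference=${detector.lastInferenceTimeNanos() / 1_000_000} ms")
//   detector.close()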
@ -0,0 +1,309 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import android.graphics.Bitmap
import android.graphics.PointF
import android.graphics.RectF
import android.os.SystemClock
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.data.KeyPoint
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.examples.poseestimation.tracker.*
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.image.ImageOperator
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import kotlin.math.ceil

class MoveNetMultiPose(
    private val interpreter: Interpreter,
    private val type: Type,
    private val gpuDelegate: GpuDelegate?
) : PoseDetector {
    private val outputShape = interpreter.getOutputTensor(0).shape()
    private val inputShape = interpreter.getInputTensor(0).shape()
    private var imageWidth: Int = 0
    private var imageHeight: Int = 0
    private var targetWidth: Int = 0
    private var targetHeight: Int = 0
    private var scaleHeight: Int = 0
    private var scaleWidth: Int = 0
    private var lastInferenceTimeNanos: Long = -1
    private var tracker: AbstractTracker? = null

    companion object {
        private const val DYNAMIC_MODEL_TARGET_INPUT_SIZE = 256
        private const val SHAPE_MULTIPLE = 32.0
        private const val DETECTION_THRESHOLD = 0.11
        private const val DETECTION_SCORE_INDEX = 55
        private const val BOUNDING_BOX_Y_MIN_INDEX = 51
        private const val BOUNDING_BOX_X_MIN_INDEX = 52
        private const val BOUNDING_BOX_Y_MAX_INDEX = 53
        private const val BOUNDING_BOX_X_MAX_INDEX = 54
        private const val KEYPOINT_COUNT = 17
        private const val OUTPUTS_COUNT_PER_KEYPOINT = 3
        private const val CPU_NUM_THREADS = 4

        // Allows specifying the model type.
        fun create(
            context: Context,
            device: Device,
            type: Type
        ): MoveNetMultiPose {
            val options = Interpreter.Options()
            var gpuDelegate: GpuDelegate? = null
            when (device) {
                Device.CPU -> {
                    options.setNumThreads(CPU_NUM_THREADS)
                }
                Device.GPU -> {
                    // Only the fixed-shape model supports the GPU delegate.
                    if (type == Type.Fixed) {
                        gpuDelegate = GpuDelegate()
                        options.addDelegate(gpuDelegate)
                    }
                }
                else -> {
                    // nothing to do
                }
            }
            return MoveNetMultiPose(
                Interpreter(
                    FileUtil.loadMappedFile(
                        context,
                        if (type == Type.Dynamic)
                            "movenet_multipose_fp16.tflite" else ""
                        //@TODO: (khanhlvg) Add support for fixed shape model if it's released.
                    ), options
                ), type, gpuDelegate
            )
        }
    }

    /**
     * Converts x and y coordinates (in [0, 1]) returned by the TFLite model
     * to the coordinates corresponding to the input image.
     */
    private fun resizeKeypoint(x: Float, y: Float): PointF {
        return PointF(resizeX(x), resizeY(y))
    }

    private fun resizeX(x: Float): Float {
        return if (imageWidth > imageHeight) {
            val ratioWidth = imageWidth.toFloat() / targetWidth
            x * targetWidth * ratioWidth
        } else {
            val detectedWidth =
                if (type == Type.Dynamic) targetWidth else inputShape[2]
            val paddingWidth = detectedWidth - scaleWidth
            val ratioWidth = imageWidth.toFloat() / scaleWidth
            (x * detectedWidth - paddingWidth / 2f) * ratioWidth
        }
    }

    private fun resizeY(y: Float): Float {
        return if (imageWidth > imageHeight) {
            val detectedHeight =
                if (type == Type.Dynamic) targetHeight else inputShape[1]
            val paddingHeight = detectedHeight - scaleHeight
            val ratioHeight = imageHeight.toFloat() / scaleHeight
            (y * detectedHeight - paddingHeight / 2f) * ratioHeight
        } else {
            val ratioHeight = imageHeight.toFloat() / targetHeight
            y * targetHeight * ratioHeight
        }
    }

    /**
     * Prepares the input image for detection.
     */
    private fun processInputTensor(bitmap: Bitmap): TensorImage {
        imageWidth = bitmap.width
        imageHeight = bitmap.height

        // If the model type is fixed, get the input size from the input shape.
        val inputSizeHeight =
            if (type == Type.Dynamic) DYNAMIC_MODEL_TARGET_INPUT_SIZE else inputShape[1]
        val inputSizeWidth =
            if (type == Type.Dynamic) DYNAMIC_MODEL_TARGET_INPUT_SIZE else inputShape[2]

        val resizeOp: ImageOperator
        if (imageWidth > imageHeight) {
            val scale = inputSizeWidth / imageWidth.toFloat()
            targetWidth = inputSizeWidth
            scaleHeight = ceil(imageHeight * scale).toInt()
            targetHeight = (ceil((scaleHeight / SHAPE_MULTIPLE)) * SHAPE_MULTIPLE).toInt()
            resizeOp = ResizeOp(scaleHeight, targetWidth, ResizeOp.ResizeMethod.BILINEAR)
        } else {
            val scale = inputSizeHeight / imageHeight.toFloat()
            targetHeight = inputSizeHeight
            scaleWidth = ceil(imageWidth * scale).toInt()
            targetWidth = (ceil((scaleWidth / SHAPE_MULTIPLE)) * SHAPE_MULTIPLE).toInt()
            resizeOp = ResizeOp(targetHeight, scaleWidth, ResizeOp.ResizeMethod.BILINEAR)
        }

        val resizeWithCropOrPad = if (type == Type.Dynamic) ResizeWithCropOrPadOp(
            targetHeight,
            targetWidth
        ) else ResizeWithCropOrPadOp(
            inputSizeHeight,
            inputSizeWidth
        )
        val imageProcessor = ImageProcessor.Builder().apply {
            add(resizeOp)
            add(resizeWithCropOrPad)
        }.build()
        val tensorImage = TensorImage(DataType.UINT8)
        tensorImage.load(bitmap)
        return imageProcessor.process(tensorImage)
    }

    /**
     * Runs the tracker (if enabled) and processes the raw model output.
     */
    private fun postProcess(modelOutput: FloatArray): List<Person> {
        val persons = mutableListOf<Person>()
        for (idx in modelOutput.indices step outputShape[2]) {
            val personScore = modelOutput[idx + DETECTION_SCORE_INDEX]
            if (personScore < DETECTION_THRESHOLD) continue
            val positions = modelOutput.copyOfRange(
                idx, idx + KEYPOINT_COUNT * OUTPUTS_COUNT_PER_KEYPOINT
            )
            val keyPoints = mutableListOf<KeyPoint>()
            for (i in 0 until KEYPOINT_COUNT) {
                val y = positions[i * OUTPUTS_COUNT_PER_KEYPOINT]
                val x = positions[i * OUTPUTS_COUNT_PER_KEYPOINT + 1]
                val score = positions[i * OUTPUTS_COUNT_PER_KEYPOINT + 2]
                keyPoints.add(KeyPoint(BodyPart.fromInt(i), PointF(x, y), score))
            }
            val yMin = modelOutput[idx + BOUNDING_BOX_Y_MIN_INDEX]
            val xMin = modelOutput[idx + BOUNDING_BOX_X_MIN_INDEX]
            val yMax = modelOutput[idx + BOUNDING_BOX_Y_MAX_INDEX]
            val xMax = modelOutput[idx + BOUNDING_BOX_X_MAX_INDEX]
            val boundingBox = RectF(xMin, yMin, xMax, yMax)
            persons.add(
                Person(
                    keyPoints = keyPoints,
                    boundingBox = boundingBox,
                    score = personScore
                )
            )
        }

        if (persons.isEmpty()) return emptyList()

        if (tracker == null) {
            persons.forEach {
                it.keyPoints.forEach { key ->
                    key.coordinate = resizeKeypoint(key.coordinate.x, key.coordinate.y)
                }
            }
            return persons
        } else {
            val trackPersons = mutableListOf<Person>()
            tracker?.apply(persons, System.currentTimeMillis() * 1000)?.forEach {
                val resizeKeyPoint = mutableListOf<KeyPoint>()
                it.keyPoints.forEach { key ->
                    resizeKeyPoint.add(
                        KeyPoint(
                            key.bodyPart,
                            resizeKeypoint(key.coordinate.x, key.coordinate.y),
                            key.score
                        )
                    )
                }

                var resizeBoundingBox: RectF? = null
                it.boundingBox?.let { boundingBox ->
                    resizeBoundingBox = RectF(
                        resizeX(boundingBox.left),
                        resizeY(boundingBox.top),
                        resizeX(boundingBox.right),
                        resizeY(boundingBox.bottom)
                    )
                }
                trackPersons.add(Person(it.id, resizeKeyPoint, resizeBoundingBox, it.score))
            }
            return trackPersons
        }
    }

    /**
     * Creates and sets the tracker.
     */
    fun setTracker(trackerType: TrackerType) {
        tracker = when (trackerType) {
            TrackerType.BOUNDING_BOX -> {
                BoundingBoxTracker()
            }
            TrackerType.KEYPOINTS -> {
                KeyPointsTracker()
            }
            TrackerType.OFF -> {
                null
            }
        }
    }

    /**
     * Runs the TFLite model and returns a list of [Person] corresponding to the input image.
     */
    override fun estimatePoses(bitmap: Bitmap): List<Person> {
        val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        val inputTensor = processInputTensor(bitmap)
        val outputTensor = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32)

        // If the model is dynamic, resize the input tensor before running the interpreter.
        if (type == Type.Dynamic) {
            val newInputShape = intArrayOf(1).plus(inputTensor.tensorBuffer.shape)
            interpreter.resizeInput(0, newInputShape, true)
            interpreter.allocateTensors()
        }
        interpreter.run(inputTensor.buffer, outputTensor.buffer.rewind())

        val processedPersons = postProcess(outputTensor.floatArray)
        lastInferenceTimeNanos =
            SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
        return processedPersons
    }

    override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos

    /**
     * Closes all resources when not in use.
     */
    override fun close() {
        gpuDelegate?.close()
        interpreter.close()
        tracker = null
    }
}

enum class Type {
    Dynamic, Fixed
}

enum class TrackerType {
    OFF, BOUNDING_BOX, KEYPOINTS
}
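// Illustrative sketch (not part of the original sample): multi-person inference
// with bounding-box tracking enabled. Assumes a Context and a Bitmap `frame`.
//
//   val detector = MoveNetMultiPose.create(context, Device.CPU, Type.Dynamic)
//   detector.setTracker(TrackerType.BOUNDING_BOX)
//   detector.estimatePoses(frame).forEach { person ->
//       // `id` stays stable across frames while the tracker keeps the person alive.
//       Log.d("MultiPose", "person ${person.id} score=${person.score}")
//   }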
@ -0,0 +1,73 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.support.common.FileUtil

class PoseClassifier(
    private val interpreter: Interpreter,
    private val labels: List<String>
) {
    private val input = interpreter.getInputTensor(0).shape()
    private val output = interpreter.getOutputTensor(0).shape()

    companion object {
        private const val MODEL_FILENAME = "classifier.tflite"
        private const val LABELS_FILENAME = "labels.txt"
        private const val CPU_NUM_THREADS = 4

        fun create(context: Context): PoseClassifier {
            val options = Interpreter.Options().apply {
                setNumThreads(CPU_NUM_THREADS)
            }
            return PoseClassifier(
                Interpreter(
                    FileUtil.loadMappedFile(
                        context, MODEL_FILENAME
                    ), options
                ),
                FileUtil.loadLabels(context, LABELS_FILENAME)
            )
        }
    }

    fun classify(person: Person?): List<Pair<String, Float>> {
        // Preprocess the pose estimation result into a flat array.
        val inputVector = FloatArray(input[1])
        person?.keyPoints?.forEachIndexed { index, keyPoint ->
            inputVector[index * 3] = keyPoint.coordinate.y
            inputVector[index * 3 + 1] = keyPoint.coordinate.x
            inputVector[index * 3 + 2] = keyPoint.score
        }

        // Postprocess the model output into human-readable class names.
        val outputTensor = FloatArray(output[1])
        interpreter.run(arrayOf(inputVector), arrayOf(outputTensor))
        val result = mutableListOf<Pair<String, Float>>()
        outputTensor.forEachIndexed { index, score ->
            result.add(Pair(labels[index], score))
        }
        return result
    }

    fun close() {
        interpreter.close()
    }
}
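// Illustrative sketch (not part of the original sample): classifying the pose of a
// detected person. Assumes `classifier.tflite` and `labels.txt` are present in the
// app's assets, as the create() factory above expects.
//
//   val classifier = PoseClassifier.create(context)
//   val scores = classifier.classify(person)        // one (label, score) per class
//   val best = scores.maxByOrNull { it.second }
//   Log.d("PoseClassifier", "pose=${best?.first} score=${best?.second}")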
@ -0,0 +1,27 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.graphics.Bitmap
import org.tensorflow.lite.examples.poseestimation.data.Person

interface PoseDetector : AutoCloseable {

    fun estimatePoses(bitmap: Bitmap): List<Person>

    fun lastInferenceTimeNanos(): Long
}
@ -0,0 +1,261 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import android.graphics.Bitmap
import android.graphics.PointF
import android.os.SystemClock
import android.util.Log
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.data.KeyPoint
import org.tensorflow.lite.examples.poseestimation.data.Person
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
import kotlin.math.exp

class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) :
    PoseDetector {

    companion object {
        private const val CPU_NUM_THREADS = 4
        private const val MEAN = 127.5f
        private const val STD = 127.5f
        private const val TAG = "Posenet"
        private const val MODEL_FILENAME = "posenet.tflite"

        fun create(context: Context, device: Device): PoseNet {
            val options = Interpreter.Options()
            var gpuDelegate: GpuDelegate? = null
            options.setNumThreads(CPU_NUM_THREADS)
            when (device) {
                Device.CPU -> {
                }
                Device.GPU -> {
                    gpuDelegate = GpuDelegate()
                    options.addDelegate(gpuDelegate)
                }
                Device.NNAPI -> options.setUseNNAPI(true)
            }
            return PoseNet(
                Interpreter(
                    FileUtil.loadMappedFile(
                        context,
                        MODEL_FILENAME
                    ), options
                ),
                gpuDelegate
            )
        }
    }

    private var lastInferenceTimeNanos: Long = -1
    private val inputWidth = interpreter.getInputTensor(0).shape()[1]
    private val inputHeight = interpreter.getInputTensor(0).shape()[2]
    private var cropHeight = 0f
    private var cropWidth = 0f
    private var cropSize = 0

    @Suppress("UNCHECKED_CAST")
    override fun estimatePoses(bitmap: Bitmap): List<Person> {
        val estimationStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        val inputArray = arrayOf(processInputImage(bitmap).tensorBuffer.buffer)
        Log.i(
            TAG,
            String.format(
                "Scaling to [-1,1] took %.2f ms",
                (SystemClock.elapsedRealtimeNanos() - estimationStartTimeNanos) / 1_000_000f
            )
        )

        val outputMap = initOutputMap(interpreter)

        val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        interpreter.runForMultipleInputsOutputs(inputArray, outputMap)
        lastInferenceTimeNanos = SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
        Log.i(
            TAG,
            String.format("Interpreter took %.2f ms", 1.0f * lastInferenceTimeNanos / 1_000_000)
        )

        val heatmaps = outputMap[0] as Array<Array<Array<FloatArray>>>
        val offsets = outputMap[1] as Array<Array<Array<FloatArray>>>

        val postProcessingStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        val person = postProcessModelOutputs(heatmaps, offsets)
        Log.i(
            TAG,
            String.format(
                "Postprocessing took %.2f ms",
                (SystemClock.elapsedRealtimeNanos() - postProcessingStartTimeNanos) / 1_000_000f
            )
        )

        return listOf(person)
    }

    /**
     * Converts the heatmaps and offsets output of PoseNet into a list of keypoints.
     */
    private fun postProcessModelOutputs(
        heatmaps: Array<Array<Array<FloatArray>>>,
        offsets: Array<Array<Array<FloatArray>>>
    ): Person {
        val height = heatmaps[0].size
        val width = heatmaps[0][0].size
        val numKeypoints = heatmaps[0][0][0].size

        // Finds the (row, col) locations of where the keypoints are most likely to be.
        val keypointPositions = Array(numKeypoints) { Pair(0, 0) }
        for (keypoint in 0 until numKeypoints) {
            var maxVal = heatmaps[0][0][0][keypoint]
            var maxRow = 0
            var maxCol = 0
            for (row in 0 until height) {
                for (col in 0 until width) {
                    if (heatmaps[0][row][col][keypoint] > maxVal) {
                        maxVal = heatmaps[0][row][col][keypoint]
                        maxRow = row
                        maxCol = col
                    }
                }
            }
            keypointPositions[keypoint] = Pair(maxRow, maxCol)
        }

        // Calculates the x and y coordinates of the keypoints with offset adjustment.
        val xCoords = IntArray(numKeypoints)
        val yCoords = IntArray(numKeypoints)
        val confidenceScores = FloatArray(numKeypoints)
        keypointPositions.forEachIndexed { idx, position ->
            val positionY = keypointPositions[idx].first
            val positionX = keypointPositions[idx].second

            val inputImageCoordinateY =
                position.first / (height - 1).toFloat() * inputHeight + offsets[0][positionY][positionX][idx]
            val ratioHeight = cropSize.toFloat() / inputHeight
            val paddingHeight = cropHeight / 2
            yCoords[idx] = (inputImageCoordinateY * ratioHeight - paddingHeight).toInt()

            val inputImageCoordinateX =
                position.second / (width - 1).toFloat() * inputWidth + offsets[0][positionY][positionX][idx + numKeypoints]
            val ratioWidth = cropSize.toFloat() / inputWidth
            val paddingWidth = cropWidth / 2
            xCoords[idx] = (inputImageCoordinateX * ratioWidth - paddingWidth).toInt()

            confidenceScores[idx] = sigmoid(heatmaps[0][positionY][positionX][idx])
        }

        val keypointList = mutableListOf<KeyPoint>()
        var totalScore = 0.0f
        enumValues<BodyPart>().forEachIndexed { idx, it ->
            keypointList.add(
                KeyPoint(
                    it,
                    PointF(xCoords[idx].toFloat(), yCoords[idx].toFloat()),
                    confidenceScores[idx]
                )
            )
            totalScore += confidenceScores[idx]
        }
        return Person(keyPoints = keypointList.toList(), score = totalScore / numKeypoints)
    }

    override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos

    override fun close() {
        gpuDelegate?.close()
        interpreter.close()
    }

    /**
     * Scale and crop the input image to a TensorImage.
     */
    private fun processInputImage(bitmap: Bitmap): TensorImage {
        // Reset crop width and height.
        cropWidth = 0f
        cropHeight = 0f
        cropSize = if (bitmap.width > bitmap.height) {
            cropHeight = (bitmap.width - bitmap.height).toFloat()
            bitmap.width
        } else {
            cropWidth = (bitmap.height - bitmap.width).toFloat()
            bitmap.height
        }

        val imageProcessor = ImageProcessor.Builder().apply {
            add(ResizeWithCropOrPadOp(cropSize, cropSize))
            add(ResizeOp(inputWidth, inputHeight, ResizeOp.ResizeMethod.BILINEAR))
            add(NormalizeOp(MEAN, STD))
        }.build()
        val tensorImage = TensorImage(DataType.FLOAT32)
        tensorImage.load(bitmap)
        return imageProcessor.process(tensorImage)
    }

    /**
     * Initializes an outputMap of 1 * x * y * z FloatArrays for the model processing to populate.
     */
    private fun initOutputMap(interpreter: Interpreter): HashMap<Int, Any> {
        val outputMap = HashMap<Int, Any>()

        // 1 * 9 * 9 * 17 contains heatmaps
        val heatmapsShape = interpreter.getOutputTensor(0).shape()
        outputMap[0] = Array(heatmapsShape[0]) {
            Array(heatmapsShape[1]) {
                Array(heatmapsShape[2]) { FloatArray(heatmapsShape[3]) }
            }
        }

        // 1 * 9 * 9 * 34 contains offsets
        val offsetsShape = interpreter.getOutputTensor(1).shape()
        outputMap[1] = Array(offsetsShape[0]) {
            Array(offsetsShape[1]) { Array(offsetsShape[2]) { FloatArray(offsetsShape[3]) } }
        }

        // 1 * 9 * 9 * 32 contains forward displacements
        val displacementsFwdShape = interpreter.getOutputTensor(2).shape()
        outputMap[2] = Array(displacementsFwdShape[0]) {
            Array(displacementsFwdShape[1]) {
                Array(displacementsFwdShape[2]) { FloatArray(displacementsFwdShape[3]) }
            }
        }

        // 1 * 9 * 9 * 32 contains backward displacements
        val displacementsBwdShape = interpreter.getOutputTensor(3).shape()
        outputMap[3] = Array(displacementsBwdShape[0]) {
            Array(displacementsBwdShape[1]) {
                Array(displacementsBwdShape[2]) { FloatArray(displacementsBwdShape[3]) }
            }
        }

        return outputMap
    }

    /** Returns a value within [0,1]. */
    private fun sigmoid(x: Float): Float {
        return (1.0f / (1.0f + exp(-x)))
    }
}
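// Worked example of the heatmap decoding above (values are illustrative): with a
// 9x9 heatmap and a 257x257 input, a keypoint whose heatmap argmax lands at
// (row=4, col=6) maps to y = 4 / 8 * 257 + offsetY and x = 6 / 8 * 257 + offsetX,
// after which the crop/padding applied in processInputImage() is undone to recover
// source-image coordinates.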
@ -0,0 +1,158 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.tracker

import org.tensorflow.lite.examples.poseestimation.data.Person

abstract class AbstractTracker(val config: TrackerConfig) {

    private val maxAge = config.maxAge * 1000 // convert milliseconds to microseconds
    private var nextTrackId = 0
    var tracks = mutableListOf<Track>()
        private set

    /**
     * Computes pairwise similarity scores between detections and tracks, based
     * on detected features.
     * @param persons A list of detected persons.
     * @return A list of shape [num_det, num_tracks] with pairwise
     * similarity scores between detections and tracks.
     */
    abstract fun computeSimilarity(persons: List<Person>): List<List<Float>>

    /**
     * Tracks person instances across frames based on detections.
     * @param persons A list of persons
     * @param timestamp The current timestamp in microseconds
     * @return An updated list of persons with tracking ids.
     */
    fun apply(persons: List<Person>, timestamp: Long): List<Person> {
        tracks = filterOldTrack(timestamp).toMutableList()
        val simMatrix = computeSimilarity(persons)
        assignTrack(persons, simMatrix, timestamp)
        tracks = updateTrack().toMutableList()
        return persons
    }

    /**
     * Clears all tracks in the track list.
     */
    fun reset() {
        tracks.clear()
    }

    /**
     * Returns the next track id.
     */
    private fun nextTrackID() = ++nextTrackId

    /**
     * Performs a greedy optimization to link detections with tracks. The person
     * list is updated in place by providing an `id` property. If incoming
     * detections are not linked with existing tracks, new tracks will be created.
     * @param persons A list of detected persons. It's assumed that persons are
     * sorted from most confident to least confident.
     * @param simMatrix A list of shape [num_det, num_tracks] with pairwise
     * similarity scores between detections and tracks.
     * @param timestamp The current timestamp in microseconds.
     */
    private fun assignTrack(persons: List<Person>, simMatrix: List<List<Float>>, timestamp: Long) {
        if (simMatrix.size != persons.size ||
            (simMatrix.isNotEmpty() && simMatrix[0].size != tracks.size)
        ) {
            throw IllegalArgumentException(
                "Size of person array and similarity matrix does not match.")
        }

        val unmatchedTrackIndices = MutableList(tracks.size) { it }
        val unmatchedDetectionIndices = mutableListOf<Int>()

        for (detectionIndex in persons.indices) {
            // If the track list is empty, add the person's index
            // to unmatched detections to create a new track later.
            if (unmatchedTrackIndices.isEmpty()) {
                unmatchedDetectionIndices.add(detectionIndex)
                continue
            }

            // Assign the detection to the track which produces the highest pairwise
            // similarity score, assuming the score exceeds the minimum similarity
            // threshold.
            var maxTrackIndex = -1
            var maxSimilarity = -1f
            unmatchedTrackIndices.forEach { trackIndex ->
                val similarity = simMatrix[detectionIndex][trackIndex]
                if (similarity >= config.minSimilarity && similarity > maxSimilarity) {
                    maxTrackIndex = trackIndex
                    maxSimilarity = similarity
                }
            }
            if (maxTrackIndex >= 0) {
                val linkedTrack = tracks[maxTrackIndex]
                tracks[maxTrackIndex] =
                    createTrack(persons[detectionIndex], linkedTrack.person.id, timestamp)
                persons[detectionIndex].id = linkedTrack.person.id
                val index = unmatchedTrackIndices.indexOf(maxTrackIndex)
                unmatchedTrackIndices.removeAt(index)
            } else {
                unmatchedDetectionIndices.add(detectionIndex)
            }
        }

        // Spawn new tracks for all unmatched detections.
        unmatchedDetectionIndices.forEach { detectionIndex ->
            val newTrack = createTrack(persons[detectionIndex], timestamp = timestamp)
            tracks.add(newTrack)
            persons[detectionIndex].id = newTrack.person.id
        }
    }

    /**
     * Filters tracks based on their age.
     * @param timestamp The timestamp in microseconds
     */
    private fun filterOldTrack(timestamp: Long): List<Track> {
        return tracks.filter {
            timestamp - it.lastTimestamp <= maxAge
        }
    }

    /**
     * Sorts the track list by timestamp (newer first)
     * and returns a track list with size capped at config.maxTracks.
     */
    private fun updateTrack(): List<Track> {
        tracks.sortByDescending { it.lastTimestamp }
        return tracks.take(config.maxTracks)
    }

    /**
     * Creates a new track from a person's information.
     * @param person A person
     * @param id The id assigned to the new track. If null, the next track id is used.
     * @param timestamp The timestamp in microseconds
     */
    private fun createTrack(person: Person, id: Int? = null, timestamp: Long): Track {
        return Track(
            person = Person(
                id = id ?: nextTrackID(),
                keyPoints = person.keyPoints,
                boundingBox = person.boundingBox,
                score = person.score
            ),
            lastTimestamp = timestamp
        )
    }
}
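// Worked example of assignTrack() (illustrative): with two detections, two tracks,
// similarity matrix [[0.9, 0.1], [0.6, 0.7]], and minSimilarity = 0.4, detection 0
// links to track 0 (0.9), track 0 is then removed from the unmatched pool, and
// detection 1 links to track 1 (0.7). A detection whose best score fell below
// minSimilarity would instead spawn a new track with a fresh id.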
@ -0,0 +1,63 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.tracker

import androidx.annotation.VisibleForTesting
import org.tensorflow.lite.examples.poseestimation.data.Person
import kotlin.math.max
import kotlin.math.min

/**
 * BoundingBoxTracker, which tracks objects based on bounding box similarity,
 * currently defined as intersection-over-union (IoU).
 */
class BoundingBoxTracker(config: TrackerConfig = TrackerConfig()) : AbstractTracker(config) {

    /**
     * Computes similarity based on intersection-over-union (IoU). See `AbstractTracker`
     * for more details.
     */
    override fun computeSimilarity(persons: List<Person>): List<List<Float>> {
        if (persons.isEmpty() && tracks.isEmpty()) {
            return emptyList()
        }
        return persons.map { person -> tracks.map { track -> iou(person, track.person) } }
    }

    /**
     * Computes the intersection-over-union (IoU) between a person and a track person.
     * @param person1 A person
     * @param person2 A track person
     * @return The IoU between the person and the track person. This number is
     * between 0 and 1, and larger values indicate more box similarity.
     */
    @VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
    fun iou(person1: Person, person2: Person): Float {
        if (person1.boundingBox != null && person2.boundingBox != null) {
            val xMin = max(person1.boundingBox.left, person2.boundingBox.left)
            val yMin = max(person1.boundingBox.top, person2.boundingBox.top)
            val xMax = min(person1.boundingBox.right, person2.boundingBox.right)
            val yMax = min(person1.boundingBox.bottom, person2.boundingBox.bottom)
            if (xMin >= xMax || yMin >= yMax) return 0f
            val intersection = (xMax - xMin) * (yMax - yMin)
            val areaPerson = person1.boundingBox.width() * person1.boundingBox.height()
            val areaTrack = person2.boundingBox.width() * person2.boundingBox.height()
            return intersection / (areaPerson + areaTrack - intersection)
        }
        return 0f
    }
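    // Worked example (added, illustrative): for normalized boxes
    // (left, top, right, bottom) = (0.0, 0.0, 0.5, 0.5) and (0.25, 0.25, 0.75, 0.75),
    // the overlap is 0.25 * 0.25 = 0.0625 and each box area is 0.25, so iou()
    // returns 0.0625 / (0.25 + 0.25 - 0.0625) ≈ 0.143.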
}
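// Usage sketch (added, illustrative; assumes the per-frame entry point
// apply(persons, timestampUs) defined on AbstractTracker earlier in this change):
//   val tracker = BoundingBoxTracker()
//   val trackedPersons = tracker.apply(detectedPersons, frameTimestampUs)
// Each returned Person then keeps a stable id for as long as its bounding box
// keeps overlapping the same track across frames.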
@ -0,0 +1,125 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.tracker

import androidx.annotation.VisibleForTesting
import org.tensorflow.lite.examples.poseestimation.data.KeyPoint
import org.tensorflow.lite.examples.poseestimation.data.Person
import kotlin.math.exp
import kotlin.math.max
import kotlin.math.min
import kotlin.math.pow

/**
 * KeyPointsTracker, which tracks poses based on keypoint similarity. This
 * tracker assumes that keypoints are provided in normalized image
 * coordinates.
 */
class KeyPointsTracker(
    trackerConfig: TrackerConfig = TrackerConfig(
        keyPointsTrackerParams = KeyPointsTrackerParams()
    )
) : AbstractTracker(trackerConfig) {
    /**
     * Computes similarity based on Object Keypoint Similarity (OKS). It's
     * assumed that the keypoints within each person are in normalized image
     * coordinates. See `AbstractTracker` for more details.
     */
    override fun computeSimilarity(persons: List<Person>): List<List<Float>> {
        if (persons.isEmpty() && tracks.isEmpty()) {
            return emptyList()
        }
        val simMatrix = mutableListOf<MutableList<Float>>()
        persons.forEach { person ->
            val row = mutableListOf<Float>()
            tracks.forEach { track ->
                val oksValue = oks(person, track.person)
                row.add(oksValue)
            }
            simMatrix.add(row)
        }
        return simMatrix
    }

    /**
     * Computes the Object Keypoint Similarity (OKS) between a person and a track person.
     * This is similar in spirit to the calculation used by COCO keypoint eval:
     * https://cocodataset.org/#keypoints-eval
     * In this case, OKS is calculated as:
     * (1 / sum_i d(c_i, c_ti)) * sum_i exp(-d_i^2 / (2 * a_ti * x_i^2)) * d(c_i, c_ti)
     * where
     *   d(x, y) is an indicator function which produces 1 only if both x and y
     *     exceed a given threshold (i.e. keypointThreshold), otherwise 0,
     *   c_i is the confidence of keypoint i from the new person,
     *   c_ti is the confidence of keypoint i from the track person,
     *   d_i is the Euclidean distance between the person and track person keypoint,
     *   a_ti is the area of the track object (the box covering the keypoints),
     *   x_i is a constant that controls falloff in a Gaussian distribution,
     *     computed as 2 * keypointFalloff[i].
     * @param person1 A person.
     * @param person2 A track person.
     * @return The OKS score between the person and the track person. This number is
     * between 0 and 1, and larger values indicate more keypoint similarity.
     */
    @VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
    fun oks(person1: Person, person2: Person): Float {
        if (config.keyPointsTrackerParams == null) return 0f
        person2.keyPoints.let { keyPoints ->
            val boxArea = area(keyPoints) + 1e-6
            var oksTotal = 0f
            var numValidKeyPoints = 0

            person1.keyPoints.forEachIndexed { index, _ ->
                val poseKpt = person1.keyPoints[index]
                val trackKpt = person2.keyPoints[index]
                val threshold = config.keyPointsTrackerParams.keypointThreshold
                if (poseKpt.score < threshold || trackKpt.score < threshold) {
                    return@forEachIndexed
                }
                numValidKeyPoints += 1
                val dSquared: Float =
                    (poseKpt.coordinate.x - trackKpt.coordinate.x).pow(2) +
                            (poseKpt.coordinate.y - trackKpt.coordinate.y).pow(2)
                val x = 2f * config.keyPointsTrackerParams.keypointFalloff[index]
                oksTotal += exp(-1f * dSquared / (2f * boxArea * x.pow(2))).toFloat()
            }
            if (numValidKeyPoints < config.keyPointsTrackerParams.minNumKeyPoints) {
                return 0f
            }
            return oksTotal / numValidKeyPoints
        }
    }
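    // Worked example (added, illustrative): a single valid keypoint pair at
    // squared distance dSquared, with track box area a and falloff k, contributes
    // exp(-dSquared / (2 * a * (2k)^2)); at dSquared = 0 this is exactly 1, so an
    // identical pose scores 1.0 and the score decays toward 0 as keypoints drift.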

    /**
     * Computes the area of a bounding box that tightly covers keypoints.
     * @param keyPoints A list of KeyPoint.
     * @return The area of the object.
     */
    @VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
    fun area(keyPoints: List<KeyPoint>): Float {
        val validKeypoint = keyPoints.filter {
            it.score > (config.keyPointsTrackerParams?.keypointThreshold ?: 0f)
        }
        if (validKeypoint.isEmpty()) return 0f
        val minX = min(1f, validKeypoint.minOf { it.coordinate.x })
        val maxX = max(0f, validKeypoint.maxOf { it.coordinate.x })
        val minY = min(1f, validKeypoint.minOf { it.coordinate.y })
        val maxY = max(0f, validKeypoint.maxOf { it.coordinate.y })
        return (maxX - minX) * (maxY - minY)
    }
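    // Worked example (added, illustrative): confident keypoints at (0.2, 0.3)
    // and (0.6, 0.8) span a tight box of width 0.4 and height 0.5, so area()
    // returns 0.4 * 0.5 = 0.2 in normalized units.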
}
@ -0,0 +1,24 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.tracker

import org.tensorflow.lite.examples.poseestimation.data.Person

data class Track(
    val person: Person,
    val lastTimestamp: Long
)
After Width: | Height: | Size: 1.4 KiB |
After Width: | Height: | Size: 929 B |
After Width: | Height: | Size: 596 B |
@ -0,0 +1,18 @@
<?xml version="1.0" encoding="utf-8"?>
<shape xmlns:android="http://schemas.android.com/apk/res/android"
    android:shape="rectangle">
    <solid android:color="#e4e4e4" />
    <stroke
        android:width="0dp"
        android:color="#424242" />
    <corners
        android:topLeftRadius="30dp"
        android:topRightRadius="30dp"
        android:bottomLeftRadius="0dp"
        android:bottomRightRadius="0dp" />
</shape>
After Width: | Height: | Size: 12 KiB |
@ -0,0 +1,27 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.coordinatorlayout.widget.CoordinatorLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <SurfaceView
        android:id="@+id/surfaceView"
        android:layout_width="match_parent"
        android:layout_height="match_parent" />

    <androidx.appcompat.widget.Toolbar
        android:id="@+id/toolbar"
        android:layout_width="match_parent"
        android:layout_height="?attr/actionBarSize"
        android:background="#66000000">

        <ImageView
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:contentDescription="@null"
            android:src="@drawable/tfl2_logo" />
    </androidx.appcompat.widget.Toolbar>

    <include layout="@layout/bottom_sheet_layout"/>
</androidx.coordinatorlayout.widget.CoordinatorLayout>
@ -0,0 +1,124 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.appcompat.widget.LinearLayoutCompat xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:id="@+id/bottom_sheet"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
    android:layout_margin="8dp"
    android:background="@drawable/rounded_edge"
    android:orientation="vertical"
    android:paddingStart="8dp"
    android:paddingTop="10dp"
    android:paddingEnd="8dp"
    android:paddingBottom="16dp"
    app:behavior_hideable="false"
    app:behavior_peekHeight="36dp"
    app:layout_behavior="com.google.android.material.bottomsheet.BottomSheetBehavior"
    tools:showIn="@layout/activity_main">
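    <!-- Added note: app:behavior_peekHeight is the sheet's collapsed height, so
         only the chevron row below stays visible until the sheet is dragged up. -->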

    <ImageView
        android:id="@+id/bottom_sheet_arrow"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_gravity="center"
        android:contentDescription="@null"
        android:src="@drawable/icn_chevron_up" />

    <TextView
        android:id="@+id/tvFps"
        android:layout_width="match_parent"
        android:layout_height="wrap_content" />

    <TextView
        android:id="@+id/tvScore"
        android:layout_width="match_parent"
        android:layout_height="wrap_content" />

    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal">

        <TextView
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:text="@string/tfe_pe_tv_device" />

        <Spinner
            android:id="@+id/spnDevice"
            android:layout_width="match_parent"
            android:layout_height="wrap_content" />
    </LinearLayout>

    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal">

        <TextView
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:text="@string/tfe_pe_tv_model" />

        <Spinner
            android:id="@+id/spnModel"
            android:layout_width="match_parent"
            android:layout_height="wrap_content" />
    </LinearLayout>

    <LinearLayout
        android:id="@+id/vTrackerOption"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal">

        <TextView
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:text="@string/tfe_pe_tv_tracking" />

        <Spinner
            android:id="@+id/spnTracker"
            android:layout_width="match_parent"
            android:layout_height="wrap_content" />
    </LinearLayout>

    <RelativeLayout
        android:id="@+id/vClassificationOption"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal">

        <TextView
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:layout_centerVertical="true"
            android:layout_toStartOf="@id/swPoseClassification"
            android:text="@string/tfe_pe_tv_pose_classification" />

        <androidx.appcompat.widget.SwitchCompat
            android:id="@+id/swPoseClassification"
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:layout_alignParentEnd="true" />
    </RelativeLayout>

    <TextView
        android:id="@+id/tvClassificationValue1"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:visibility="gone" />

    <TextView
        android:id="@+id/tvClassificationValue2"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:visibility="gone" />

    <TextView
        android:id="@+id/tvClassificationValue3"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:visibility="gone" />
</androidx.appcompat.widget.LinearLayoutCompat>
@ -0,0 +1,9 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
    <color name="purple_500">#FF6200EE</color>
    <color name="purple_700">#FF3700B3</color>
    <color name="teal_200">#FF03DAC5</color>
    <color name="teal_700">#FF018786</color>
    <color name="black">#FF000000</color>
    <color name="white">#FFFFFFFF</color>
</resources>
@ -0,0 +1,30 @@
<resources>
    <string name="tfe_pe_app_name">TFL Pose Estimation</string>
    <string name="tfe_pe_request_permission">This app needs camera permission.</string>
    <string name="tfe_pe_tv_score">Score: %.2f</string>
    <string name="tfe_pe_tv_model">Model: </string>
    <string name="tfe_pe_tv_fps">Fps: %d</string>
    <string name="tfe_pe_tv_device">Device: </string>
    <string name="tfe_pe_tv_classification_value">- %s</string>
    <string name="tfe_pe_tv_pose_classification">Pose Classification</string>
    <string name="tfe_pe_tv_tracking">Tracker: </string>
    <string name="tfe_pe_gpu_error">Movenet MultiPose does not support GPU. Falling back to CPU.</string>
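
    <!-- Added note: the arrays below back the spnModel, spnDevice and spnTracker
         spinners declared in bottom_sheet_layout.xml. -->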
    <string-array name="tfe_pe_models_array">
        <item>Movenet Lightning</item>
        <item>Movenet Thunder</item>
        <item>Movenet MultiPose</item>
        <item>Posenet</item>
    </string-array>

    <string-array name="tfe_pe_device_name">
        <item>CPU</item>
        <item>GPU</item>
        <item>NNAPI</item>
    </string-array>

    <string-array name="tfe_pe_tracker_array">
        <item>Off</item>
        <item>BoundingBox</item>
        <item>Keypoint</item>
    </string-array>
</resources>
@ -0,0 +1,18 @@
<resources xmlns:tools="http://schemas.android.com/tools">
    <!-- Base application theme. -->
    <style name="Theme.PoseEstimation" parent="Theme.MaterialComponents.Light.NoActionBar">
        <!-- Primary brand color. -->
        <item name="colorPrimary">@color/purple_500</item>
        <item name="colorPrimaryVariant">@color/purple_700</item>
        <item name="colorOnPrimary">@color/white</item>
        <!-- Secondary brand color. -->
        <item name="colorSecondary">@color/teal_200</item>
        <item name="colorSecondaryVariant">@color/teal_700</item>
        <item name="colorOnSecondary">@color/black</item>
        <!-- Status bar color. -->
        <item name="android:statusBarColor" tools:targetApi="l">?attr/colorPrimaryVariant</item>
        <!-- Customize your theme here. -->
        <item name="android:textColor">@color/black</item>
        <item name="android:textSize">16sp</item>
    </style>
</resources>
@ -0,0 +1,30 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
buildscript {
    ext.kotlin_version = "1.9.21"
    repositories {
        google()
        mavenCentral()
    }
    dependencies {
        classpath "com.android.tools.build:gradle:8.0.2"
        classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"

        // NOTE: Do not place your application dependencies here; they belong
        // in the individual module build.gradle files
    }
}

allprojects {
    repositories {
        google()
        mavenCentral()
        maven {
            name 'ossrh-snapshot'
            url 'https://oss.sonatype.org/content/repositories/snapshots'
        }
    }
}
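
// Added note: the ossrh-snapshot repository above serves -SNAPSHOT artifacts,
// presumably for nightly TensorFlow Lite builds pulled in by the app module.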

task clean(type: Delete) {
    delete rootProject.buildDir
}
@ -0,0 +1,21 @@
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
# AndroidX package structure to make it clearer which packages are bundled with the
# Android operating system, and which are packaged with your app's APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true
# Automatically convert third-party libraries to use AndroidX
android.enableJetifier=true
# Kotlin code style for this project: "official" or "obsolete":
kotlin.code.style=official
@ -0,0 +1,7 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip
networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
@ -0,0 +1,249 @@
#!/bin/sh

#
# Copyright © 2015-2021 the original authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

##############################################################################
#
#   Gradle start up script for POSIX generated by Gradle.
#
#   Important for running:
#
#   (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is
#       noncompliant, but you have some other compliant shell such as ksh or
#       bash, then to run this script, type that shell name before the whole
#       command line, like:
#
#           ksh Gradle
#
#       Busybox and similar reduced shells will NOT work, because this script
#       requires all of these POSIX shell features:
#         * functions;
#         * expansions «$var», «${var}», «${var:-default}», «${var+SET}»,
#           «${var#prefix}», «${var%suffix}», and «$( cmd )»;
#         * compound commands having a testable exit status, especially «case»;
#         * various built-in commands including «command», «set», and «ulimit».
#
#   Important for patching:
#
#   (2) This script targets any POSIX shell, so it avoids extensions provided
#       by Bash, Ksh, etc; in particular arrays are avoided.
#
#       The "traditional" practice of packing multiple parameters into a
#       space-separated string is a well documented source of bugs and security
#       problems, so this is (mostly) avoided, by progressively accumulating
#       options in "$@", and eventually passing that to Java.
#
#       Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS,
#       and GRADLE_OPTS) rely on word-splitting, this is performed explicitly;
#       see the in-line comments for details.
#
#       There are tweaks for specific operating systems such as AIX, CygWin,
#       Darwin, MinGW, and NonStop.
#
#   (3) This script is generated from the Groovy template
#       https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
#       within the Gradle project.
#
#       You can find Gradle at https://github.com/gradle/gradle/.
#
##############################################################################

# Attempt to set APP_HOME

# Resolve links: $0 may be a link
app_path=$0

# Need this for daisy-chained symlinks.
while
    APP_HOME=${app_path%"${app_path##*/}"}  # leaves a trailing /; empty if no leading path
    [ -h "$app_path" ]
do
    ls=$( ls -ld "$app_path" )
    link=${ls#*' -> '}
    case $link in             #(
      /*)   app_path=$link ;; #(
      *)    app_path=$APP_HOME$link ;;
    esac
done

# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit

# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum

warn () {
    echo "$*"
} >&2

die () {
    echo
    echo "$*"
    echo
    exit 1
} >&2

# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "$( uname )" in                #(
  CYGWIN* )         cygwin=true  ;; #(
  Darwin* )         darwin=true  ;; #(
  MSYS* | MINGW* )  msys=true    ;; #(
  NONSTOP* )        nonstop=true ;;
esac

CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar


# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
        # IBM's JDK on AIX uses strange locations for the executables
        JAVACMD=$JAVA_HOME/jre/sh/java
    else
        JAVACMD=$JAVA_HOME/bin/java
    fi
    if [ ! -x "$JAVACMD" ] ; then
        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME

Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
    fi
else
    JAVACMD=java
    if ! command -v java >/dev/null 2>&1
    then
        die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.

Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
    fi
fi

# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
    case $MAX_FD in #(
      max*)
        # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
        # shellcheck disable=SC2039,SC3045
        MAX_FD=$( ulimit -H -n ) ||
            warn "Could not query maximum file descriptor limit"
    esac
    case $MAX_FD in  #(
      '' | soft) :;; #(
      *)
        # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
        # shellcheck disable=SC2039,SC3045
        ulimit -n "$MAX_FD" ||
            warn "Could not set maximum file descriptor limit to $MAX_FD"
    esac
fi

# Collect all arguments for the java command, stacking in reverse order:
#   * args from the command line
#   * the main class name
#   * -classpath
#   * -D...appname settings
#   * --module-path (only if needed)
#   * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.

# For Cygwin or MSYS, switch paths to Windows format before running java
if "$cygwin" || "$msys" ; then
    APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
    CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )

    JAVACMD=$( cygpath --unix "$JAVACMD" )

    # Now convert the arguments - kludge to limit ourselves to /bin/sh
    for arg do
        if
            case $arg in                                #(
              -*)   false ;;                            # don't mess with options #(
              /?*)  t=${arg#/} t=/${t%%/*}              # looks like a POSIX filepath
                    [ -e "$t" ] ;;                      #(
              *)    false ;;
            esac
        then
            arg=$( cygpath --path --ignore --mixed "$arg" )
        fi
        # Roll the args list around exactly as many times as the number of
        # args, so each arg winds up back in the position where it started, but
        # possibly modified.
        #
        # NB: a `for` loop captures its iteration list before it begins, so
        # changing the positional parameters here affects neither the number of
        # iterations, nor the values presented in `arg`.
        shift                   # remove old arg
        set -- "$@" "$arg"      # push replacement arg
    done
fi


# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'

# Collect all arguments for the java command:
#   * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments,
#     and any embedded shellness will be escaped.
#   * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be
#     treated as '${Hostname}' itself on the command line.

set -- \
        "-Dorg.gradle.appname=$APP_BASE_NAME" \
        -classpath "$CLASSPATH" \
        org.gradle.wrapper.GradleWrapperMain \
        "$@"

# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1
then
    die "xargs is not available"
fi

# Use "xargs" to parse quoted args.
#
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
#
# In Bash we could simply go:
#
#   readarray ARGS < <( xargs -n1 <<<"$var" ) &&
#   set -- "${ARGS[@]}" "$@"
#
# but POSIX shell has neither arrays nor command substitution, so instead we
# post-process each arg (as a line of input to sed) to backslash-escape any
# character that might be a shell metacharacter, then use eval to reverse
# that process (while maintaining the separation between arguments), and wrap
# the whole thing up as a single "set" statement.
#
# This will of course break if any of these variables contains a newline or
# an unmatched quote.
#

eval "set -- $(
        printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
        xargs -n1 |
        sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
        tr '\n' ' '
    )" '"$@"'

exec "$JAVACMD" "$@"
@ -0,0 +1,92 @@
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem      https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem

@if "%DEBUG%"=="" @echo off
@rem ##########################################################################
@rem
@rem  Gradle startup script for Windows
@rem
@rem ##########################################################################

@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal

set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%

@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi

@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"

@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome

set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute

echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.

goto fail

:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe

if exist "%JAVA_EXE%" goto execute

echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.

goto fail

:execute
@rem Setup the command line

set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar


@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*

:end
@rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd

:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL%
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%

:mainEnd
if "%OS%"=="Windows_NT" endlocal

:omega
@ -0,0 +1,8 @@
## This file must *NOT* be checked into Version Control Systems,
# as it contains information specific to your local configuration.
#
# Location of the SDK. This is only used by Gradle.
# For customization when using a Version Control System, please read the
# header note.
#Mon Apr 21 08:12:48 CST 2025
sdk.dir=/Users/ziyue/Library/Android/sdk
After Width: | Height: | Size: 602 KiB |
@ -0,0 +1,2 @@
include ':app'
rootProject.name = "TFLite Pose Estimation"
@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
  </component>
</project>
@ -1,24 +0,0 @@
package com.addd.xingdongli.ui;

import android.os.Bundle;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Button;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.fragment.app.Fragment;

import com.addd.xingdongli.R;

public class HomeFragment extends Fragment {

    @Override
    public View onCreateView(LayoutInflater inflater, ViewGroup container, Bundle savedInstanceState) {
        // Use the fragment_home.xml layout file
        return inflater.inflate(R.layout.fragment_home, container, false);
    }

}