Compare commits

...

127 Commits
dev ... master

Author SHA1 Message Date
pvqf6mep3 388b25d6b0 Merge pull request 'Refine' (#55) from majian_branch into master
2 years ago
大耳刮子 27b951823f Refine
2 years ago
pvqf6mep3 e09dae75b1 Merge pull request 'final version' (#54) from majian_branch into master
2 years ago
大耳刮子 725a4b28fd final
2 years ago
pvqf6mep3 4ff5503da1 Merge pull request 'final' (#53) from majian_branch into master
2 years ago
大耳刮子 eb572bfaf3 final
2 years ago
大耳刮子 094188df48 Refine
2 years ago
pvqf6mep3 ca6c0289ea Approve
2 years ago
大耳刮子 86b1bef153 Self-evaluation report
2 years ago
大耳刮子 94fb3adabc doc
2 years ago
pvqf6mep3 3ce5f68bdc Approve
2 years ago
Thinner123 e15e21206e Documentation changes
2 years ago
pvqf6mep3 bcc8ea078d Merge
2 years ago
大耳刮子 729e5aaac6 ppt
2 years ago
pvqf6mep3 6370eaabf9 Approve
2 years ago
大耳刮子 2a6711d22d vedio show
2 years ago
大耳刮子 9529ed5ca0 merge
2 years ago
大耳刮子 faa391adba doc
2 years ago
pvqf6mep3 470e05f375 Approve
2 years ago
JackyMa 8b8b8c2094 merge
2 years ago
JackyMa 9746d85d34 final code
2 years ago
大耳刮子 2584585d51 Requirements analysis
2 years ago
大耳刮子 bf52c37a04 Use case diagram
2 years ago
pvqf6mep3 db4c8f7a7c Approve
2 years ago
大耳刮子 880485a099 Detailed design
2 years ago
pvqf6mep3 e796e5405b Approve
2 years ago
大耳刮子 1b5b8929a0 Refined design classes
2 years ago
pvqf6mep3 e27a25a43b Approve
2 years ago
大耳刮子 2db888ffef doc
2 years ago
pvqf6mep3 b6e6b884ca ac
2 years ago
大耳刮子 6c11103cbf uml
2 years ago
pvqf6mep3 37e6653753 accept
2 years ago
Thinner123 1c1221f89d Astar
2 years ago
pvqf6mep3 182b92fb7d Approve
2 years ago
pvqf6mep3 e9904b96bb Approve
2 years ago
hackii 0e7846f5e8 add
2 years ago
JackyMa 6ce3f340c6 test yolo success
2 years ago
pvqf6mep3 16f7cec4dd Approve
2 years ago
JackyMa 570867305f solve crash
2 years ago
JackyMa 7455da0ae8 Merge branch 'master' into majian_branch
2 years ago
JackyMa 6f6d8f374d add
2 years ago
pvqf6mep3 cc3622a0c1 Approve
2 years ago
JackyMa 333f28d634 test yolo
2 years ago
JackyMa b7ba593ee9 yolo
2 years ago
hackii f0b59ca147 new
2 years ago
JackyMa 5bd4af7fda yolov6
2 years ago
pvqf6mep3 de5e6c02f5 Approve
3 years ago
大耳刮子 9304605ba9 success test vedio
3 years ago
大耳刮子 4cd6a736f2 Merge branch 'master' into majian_branch
3 years ago
大耳刮子 4cc1d21c74 add
3 years ago
pvqf6mep3 a31e97dcca Approve
3 years ago
13195980010 dbc640fdb5 re
3 years ago
pvqf6mep3 a4e554efd7 Approve
3 years ago
大耳刮子 e39ee0bc6c add login ui
3 years ago
pvqf6mep3 f9f7c210ee Approve
3 years ago
大耳刮子 73d0367db7 Merge branch 'gehongbo_branch' into majian_branch
3 years ago
大耳刮子 897269e030 rebuild
3 years ago
大耳刮子 23c319bd40 add
3 years ago
JackyMa a178489f07 replace user.name
3 years ago
pvqf6mep3 f9615030f2 Merge pull request 'add vedio' (#33) from majian_branch into master
3 years ago
大耳刮子 4cb93f6e3b add vedio
3 years ago
13195980010 270b5519ae ui
3 years ago
大耳刮子 5d165fae06 Merge branch 'master' into majian_branch
3 years ago
pvqf6mep3 6c204cbc6f tongyi
3 years ago
大耳刮子 b50292b3cd change
3 years ago
13195980010 d2a1fa26e3 telloui
3 years ago
pvqf6mep3 deca08193c Approve
3 years ago
大耳刮子 aaf38a9a28 517
3 years ago
pvqf6mep3 96c5ec4aef Approve
3 years ago
大耳刮子 062a36cccb Improve UI
3 years ago
大耳刮子 45153ce6a6 add AutoMove v1.0
3 years ago
pvqf6mep3 10b0e7928d agree
3 years ago
大耳刮子 62218ad955 add vedio laber
3 years ago
pvqf6mep3 263828dcda Approve
3 years ago
大耳刮子 0b4d862d42 add dashboard
3 years ago
pvqf6mep3 3fb4b8a161 Approve
3 years ago
大耳刮子 4e494534f6 Improve layout
3 years ago
pvqf6mep3 402a15165d Approve
3 years ago
大耳刮子 ecd8512dc7 Improve layout
3 years ago
pvqf6mep3 2db4600cbb Approve
3 years ago
大耳刮子 e4f867170f Add UI mode switching
3 years ago
pvqf6mep3 4d814fb263 Approve
3 years ago
大耳刮子 0f4d911bb9 Merge branch 'majian_branch' of https://bdgit.educoder.net/pvqf6mep3/Air_Ground_CEC into majian_branch
3 years ago
大耳刮子 ebbcbb594c add joystick
3 years ago
pvqf6mep3 39475efbd2 Approve
3 years ago
pvqf6mep3 ea9e707942 Merge pull request 'Fix typos' (#20) from majian_branch into master
3 years ago
pvqf6mep3 a1726087de Approve
3 years ago
大耳刮子 d67d479ef1 Improve UI
3 years ago
大耳刮子 70dd42f080 Rename files
3 years ago
大耳刮子 ad15ac7150 Fix typos
3 years ago
pvqf6mep3 a68f2bd6a3 Merge pull request 'Update doc' (#18) from majian_branch into master
3 years ago
大耳刮子 7c5c347e6b add doc
3 years ago
pvqf6mep3 4d29148277 Approve
3 years ago
大耳刮子 c6dc952dcf Update
3 years ago
pvqf6mep3 3031461e45 Approve
3 years ago
大耳刮子 9ecaa01651 Delete
3 years ago
pvqf6mep3 a248932619 Approve
3 years ago
大耳刮子 4b2fddc20e Refine the docs
3 years ago
pvqf6mep3 2af3af8948 Approve
3 years ago
13195980010 bd38d9d3ef delete txt in model
3 years ago
13195980010 fdb5bed0d9 doc and vs
3 years ago
13195980010 fa92e48b6b doc upload
3 years ago
pvqf6mep3 9af234741f Approve
3 years ago
pvqf6mep3 aa1423744b Approve
3 years ago
大耳刮子 f9de4ae955 Improve keyboard control UI
3 years ago
13195980010 cf0b3aaaf5 keyboardupdate
3 years ago
13195980010 9f032f7f67 delete debug
3 years ago
13195980010 67f288898d flip
3 years ago
pvqf6mep3 22bda90868 Approve merge
3 years ago
大耳刮子 545537255b add twist
3 years ago
pvqf6mep3 54a70739bc Approve
3 years ago
大耳刮子 01c3158d4e Keyboard control of ROS, implemented in the GUI
3 years ago
pvqf6mep3 0f1af0ed0a Approve
3 years ago
大耳刮子 72649e60fb keyboard teleop
3 years ago
大耳刮子 8a4edf9eae Merge branch 'majian_branch' of https://bdgit.educoder.net/pvqf6mep3/Air_Ground_CEC into majian_branch
3 years ago
大耳刮子 b89f13c49f Merge branch 'master' into majian_branch
3 years ago
pvqf6mep3 e597d7d35e Approve merge
3 years ago
gehongbo 40b49fdf68 uiqt
3 years ago
大耳刮子 cc4aa84ce7 test
3 years ago
大耳刮子 1df3ed7ed5 test
3 years ago
大耳刮子 15e612ae1e Sync with master
3 years ago
pvqf6mep3 929b292e59 Approve
3 years ago
gehongbo f8bbafdc2d add
3 years ago
gehongbo 4aa6de0b28 Merge branch 'master' into gehongbo_branch
3 years ago
gehongbo 68e9ccf753 delete
3 years ago
pvqf6mep3 9ad4bec99e Merge
3 years ago
gehongbo 947f7d5d0c cv
3 years ago

@@ -4,130 +4,42 @@ project(Air_Ground_CEC)
## Compile as C++11, supported in ROS Kinetic and newer
add_compile_options(-std=c++11)
## Find catkin macros and libraries
## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
## is used, also find other catkin packages
find_package(catkin REQUIRED COMPONENTS
roscpp
std_msgs
cv_bridge
sensor_msgs
image_transport
)
## System dependencies are found with CMake's conventions
# find_package(Boost REQUIRED COMPONENTS system)
## Uncomment this if the package has a setup.py. This macro ensures
## modules and global scripts declared therein get installed
# Note: sudo sh -c 'echo "/usr/local/lib" > /etc/ld.so.conf.d/opencv.conf'
## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html
# catkin_python_setup()
################################################
## Declare ROS messages, services and actions ##
################################################
## To declare and build messages, services or actions from within this
## package, follow these steps:
## * Let MSG_DEP_SET be the set of packages whose message types you use in
## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...).
## * In the file package.xml:
## * add a build_depend tag for "message_generation"
## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET
## * If MSG_DEP_SET isn't empty the following dependency has been pulled in
## but can be declared for certainty nonetheless:
## * add a exec_depend tag for "message_runtime"
## * In this file (CMakeLists.txt):
## * add "message_generation" and every package in MSG_DEP_SET to
## find_package(catkin REQUIRED COMPONENTS ...)
## * add "message_runtime" and every package in MSG_DEP_SET to
## catkin_package(CATKIN_DEPENDS ...)
## * uncomment the add_*_files sections below as needed
## and list every .msg/.srv/.action file to be processed
## * uncomment the generate_messages entry below
## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...)
## Generate messages in the 'msg' folder
# add_message_files(
# FILES
# Message1.msg
# Message2.msg
# )
## Generate services in the 'srv' folder
# add_service_files(
# FILES
# Service1.srv
# Service2.srv
# )
## Generate actions in the 'action' folder
# add_action_files(
# FILES
# Action1.action
# Action2.action
# )
## Generate added messages and services with any dependencies listed here
# generate_messages(
# DEPENDENCIES
# std_msgs # Or other packages containing msgs
# )
################################################
## Declare ROS dynamic reconfigure parameters ##
################################################
## To declare and build dynamic reconfigure parameters within this
## package, follow these steps:
## * In the file package.xml:
## * add a build_depend and a exec_depend tag for "dynamic_reconfigure"
## * In this file (CMakeLists.txt):
## * add "dynamic_reconfigure" to
## find_package(catkin REQUIRED COMPONENTS ...)
## * uncomment the "generate_dynamic_reconfigure_options" section below
## and list every .cfg file to be processed
## Generate dynamic reconfigure parameters in the 'cfg' folder
# generate_dynamic_reconfigure_options(
# cfg/DynReconf1.cfg
# cfg/DynReconf2.cfg
# )
###################################
## catkin specific configuration ##
###################################
## The catkin_package macro generates cmake config files for your package
## Declare things to be passed to dependent projects
## INCLUDE_DIRS: uncomment this if your package contains header files
## LIBRARIES: libraries you create in this project that dependent projects also need
## CATKIN_DEPENDS: catkin_packages dependent projects also need
## DEPENDS: system dependencies of this project that dependent projects also need
catkin_package(
# INCLUDE_DIRS include
# LIBRARIES Air_Ground_CEC
# CATKIN_DEPENDS roscpp
# DEPENDS system_lib
)
###########
## Build ##
###########
#set(OpenCV_DIR /usr/local/share/opencv4)
find_package(OpenCV REQUIRED)
find_package(Qt5 REQUIRED COMPONENTS Widgets )
set(CMAKE_AUTOMOC ON)
set(CMAKE_AUTOUIC ON)
set(CMAKE_INCLUDE_CURRENT_DIR ON)
set(SOURCES
src/main.cpp
src/mainwindow.cpp
include/Air_Ground_CEC/mainwindow.hpp
src/mainwindow.ui
src/qnode.cpp
include/Air_Ground_CEC/qnode.hpp
)
file(GLOB QT_FORMS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} src/ui/*.ui)
file(GLOB QT_RESOURCES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} src/resources/*.qrc)
file(GLOB_RECURSE QT_MOC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
FOLLOW_SYMLINKS src/include/Air_Ground_CEC/*.hpp *.h)
file(GLOB_RECURSE QT_SOURCES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
FOLLOW_SYMLINKS src/*.cpp)
QT5_ADD_RESOURCES(QT_RESOURCES_CPP ${QT_RESOURCES})
QT5_WRAP_UI(QT_FORMS_HPP ${QT_FORMS})
QT5_WRAP_CPP(QT_MOC_HPP ${QT_MOC})
## Specify additional locations of header files
## Your package locations should be listed before other locations
include_directories(
@@ -136,94 +48,14 @@ include_directories(
${OpenCV_INCLUDE_DIRS}
)
add_executable(Air_Ground_CEC ${SOURCES})
add_executable(Air_Ground_CEC ${QRC_FILES} ${QT_SOURCES}
${QT_RESOURCES_CPP} ${QT_FORMS_HPP} ${QT_MOC_HPP})
target_link_libraries(Air_Ground_CEC
Qt5::Widgets
${catkin_LIBRARIES}
${OpenCV_LIBRARIES}
)
## Declare a C++ library
# add_library(${PROJECT_NAME}
# src/${PROJECT_NAME}/Air_Ground_CEC.cpp
# )
## Add cmake target dependencies of the library
## as an example, code may need to be generated before libraries
## either from message generation or dynamic reconfigure
# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
## Declare a C++ executable
## With catkin_make all packages are built within a single CMake context
## The recommended prefix ensures that target names across packages don't collide
# add_executable(${PROJECT_NAME}_node src/Air_Ground_CEC_node.cpp)
## Rename C++ executable without prefix
## The above recommended prefix causes long target names, the following renames the
## target back to the shorter version for ease of user use
## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node"
# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "")
## Add cmake target dependencies of the executable
## same as for the library above
# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
## Specify libraries to link a library or executable target against
# target_link_libraries(${PROJECT_NAME}_node
# ${catkin_LIBRARIES}
# )
#############
## Install ##
#############
# all install targets should use catkin DESTINATION variables
# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html
## Mark executable scripts (Python etc.) for installation
## in contrast to setup.py, you can choose the destination
# catkin_install_python(PROGRAMS
# scripts/my_python_script
# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
# )
## Mark executables for installation
## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html
# install(TARGETS ${PROJECT_NAME}_node
# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
# )
## Mark libraries for installation
## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html
# install(TARGETS ${PROJECT_NAME}
# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION}
# )
## Mark cpp header files for installation
# install(DIRECTORY include/${PROJECT_NAME}/
# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
# FILES_MATCHING PATTERN "*.h"
# PATTERN ".svn" EXCLUDE
# )
## Mark other files for installation (e.g. launch and bag files, etc.)
# install(FILES
# # myfile1
# # myfile2
# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
# )
#############
## Testing ##
#############
## Add gtest based cpp test target and link libraries
# catkin_add_gtest(${PROJECT_NAME}-test test/test_Air_Ground_CEC.cpp)
# if(TARGET ${PROJECT_NAME}-test)
# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME})
# endif()
## Add folders to be run by python nosetests
# catkin_add_nosetests(test)
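
The SOURCES list above builds a single Qt Widgets executable that also hosts a ROS node (main.cpp, mainwindow.cpp, qnode.cpp, plus the wrapped .ui forms). The project's real entry point is not included in this diff; the following is only a minimal sketch of how such a main.cpp is commonly wired up, and the node name, constructor arguments and include path are assumptions rather than repository code.

// Hypothetical sketch -- not the repository's src/main.cpp.
#include <QApplication>
#include <ros/ros.h>
#include "Air_Ground_CEC/mainwindow.hpp"

int main(int argc, char **argv)
{
    ros::init(argc, argv, "air_ground_cec_gui"); // register the ROS node (name assumed)
    QApplication app(argc, argv);                // owns the Qt event loop
    MainWindow w;                                // window generated from mainwindow.ui (constructor args assumed)
    w.show();
    return app.exec();                           // ROS spinning typically runs in a worker thread (e.g. the qnode.cpp class)
}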

Binary file not shown.

Image file not shown  |  Size: 3.1 MiB

Binary file not shown.

Binary file not shown.

@@ -1,28 +0,0 @@
#ifndef MAINWINDOW_H
#define MAINWINDOW_H
#include <QMainWindow>
#include <QImage>
#include <QMutex>
namespace Ui {
class MainWindow;
}
class MainWindow : public QMainWindow
{
Q_OBJECT
public:
explicit MainWindow(QWidget *parent = nullptr);
~MainWindow();
void updateLogcamera();
void displayCamera(const QImage& image);
private:
Ui::MainWindow *ui;
//QNode qnode;
QImage qimage_;
mutable QMutex qinmage_mutex_;
};
#endif // MAINWINDOW_H

@@ -1,8 +0,0 @@
#include <opencv4/opencv2/core.hpp>
#include <opencv4/opencv2/highgui.hpp>
#include <opencv4/opencv2/imgproc.hpp>
#include <image_transport/image_transport.h>
#include <cv_bridge/cv_bridge.h>
#include <QImage>
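
The removed header above pulled in OpenCV, image_transport, cv_bridge and QImage, which is the usual toolchain for showing a ROS image topic inside a Qt widget (compare the displayCamera/QMutex members in the removed mainwindow.h). Purely as an illustrative sketch, with every identifier assumed and none of it taken from the repository, the conversion step typically looks like this:

// Generic sketch (assumed names): convert a ROS image message into a QImage for display.
#include <cv_bridge/cv_bridge.h>
#include <sensor_msgs/image_encodings.h>
#include <QImage>

QImage rosImageToQImage( const sensor_msgs::ImageConstPtr &msg )
{
    // Let cv_bridge normalise the encoding to 8-bit RGB.
    cv_bridge::CvImagePtr cv_ptr =
        cv_bridge::toCvCopy( msg, sensor_msgs::image_encodings::RGB8 );
    const cv::Mat &frame = cv_ptr->image;

    // Wrap the pixel buffer, then deep-copy so the QImage outlives cv_ptr.
    return QImage( frame.data, frame.cols, frame.rows,
                   static_cast<int>( frame.step ), QImage::Format_RGB888 ).copy();
}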

@@ -1 +0,0 @@
create

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

@@ -0,0 +1,803 @@
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// STL A* Search implementation
// (C)2001 Justin Heyes-Jones
//
// This uses my A* code to solve the 8-puzzle
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <stdio.h>
#include <string.h> // memcpy (used by the PuzzleState constructors)
#include <assert.h>
#include <new>
#include <ctype.h>
using namespace std;
// Configuration
#define NUM_TIMES_TO_RUN_SEARCH 1
#define DISPLAY_SOLUTION_FORWARDS 1
#define DISPLAY_SOLUTION_BACKWARDS 0
#define DISPLAY_SOLUTION_INFO 1
#define DEBUG_LISTS 0
// AStar search class
#include "stlastar.h" // See header for copyright and usage information
// Global data
#define BOARD_WIDTH (3)
#define BOARD_HEIGHT (3)
#define GM_TILE (-1)
#define GM_SPACE (0)
#define GM_OFF_BOARD (1)
// Definitions
// To use the search class you must define the following calls...
// Data
// Your own state space information
// Functions
// (Optional) Constructor.
// Nodes are created by the user, so whether you use a
// constructor with parameters as below, or just set the object up after the
// constructor, is up to you.
//
// (Optional) Destructor.
// The destructor will be called if you create one. You
// can rely on the default constructor unless you dynamically allocate something in
// your data
//
// float GoalDistanceEstimate( PuzzleState &nodeGoal );
// Return the estimated cost to goal from this node (pass reference to goal node)
//
// bool IsGoal( PuzzleState &nodeGoal );
// Return true if this node is the goal.
//
// bool GetSuccessors( AStarSearch<PuzzleState> *astarsearch );
// For each successor to this state call the AStarSearch's AddSuccessor call to
// add each one to the current search - return false if you are out of memory and the search
// will fail
//
// float GetCost( PuzzleState *successor );
// Return the cost moving from this state to the state of successor
//
// bool IsSameState( PuzzleState &rhs );
// Return true if the provided state is the same as this state
// Here the example is the 8-puzzle state ...
class PuzzleState
{
public:
// defs
typedef enum
{
TL_SPACE,
TL_1,
TL_2,
TL_3,
TL_4,
TL_5,
TL_6,
TL_7,
TL_8
} TILE;
// data
static TILE g_goal[ BOARD_WIDTH*BOARD_HEIGHT];
static TILE g_start[ BOARD_WIDTH*BOARD_HEIGHT];
// the tile data for the 8-puzzle
TILE tiles[ BOARD_WIDTH*BOARD_HEIGHT ];
// member functions
PuzzleState() {
memcpy( tiles, g_goal, sizeof( TILE ) * BOARD_WIDTH * BOARD_HEIGHT );
}
PuzzleState( TILE *param_tiles )
{
memcpy( tiles, param_tiles, sizeof( TILE ) * BOARD_WIDTH * BOARD_HEIGHT );
}
float GoalDistanceEstimate( PuzzleState &nodeGoal );
bool IsGoal( PuzzleState &nodeGoal );
bool GetSuccessors( AStarSearch<PuzzleState> *astarsearch, PuzzleState *parent_node );
float GetCost( PuzzleState &successor );
bool IsSameState( PuzzleState &rhs );
void PrintNodeInfo();
private:
// User stuff - Just add what you need to help you write the above functions...
void GetSpacePosition( PuzzleState *pn, int *rx, int *ry );
bool LegalMove( TILE *StartTiles, TILE *TargetTiles, int spx, int spy, int tx, int ty );
int GetMap( int x, int y, TILE *tiles );
};
// Goal state
PuzzleState::TILE PuzzleState::g_goal[] =
{
TL_1,
TL_2,
TL_3,
TL_8,
TL_SPACE,
TL_4,
TL_7,
TL_6,
TL_5,
};
// Some nice Start states
PuzzleState::TILE PuzzleState::g_start[] =
{
// Three example start states from Bratko's Prolog Programming for Artificial Intelligence
#if 1
// ex a - 4 steps
TL_1 ,
TL_3 ,
TL_4 ,
TL_8 ,
TL_SPACE ,
TL_2 ,
TL_7 ,
TL_6 ,
TL_5 ,
#elif 0
// ex b - 5 steps
TL_2 ,
TL_8 ,
TL_3 ,
TL_1 ,
TL_6 ,
TL_4 ,
TL_7 ,
TL_SPACE ,
TL_5 ,
#elif 0
// ex c - 18 steps
TL_2 ,
TL_1 ,
TL_6 ,
TL_4 ,
TL_SPACE ,
TL_8 ,
TL_7 ,
TL_5 ,
TL_3 ,
#elif 0
// nasty one - doesn't solve
TL_6 ,
TL_3 ,
TL_SPACE ,
TL_4 ,
TL_8 ,
TL_5 ,
TL_7 ,
TL_2 ,
TL_1 ,
#elif 0
// sent by email - does work though
TL_1 , TL_2 , TL_3 ,
TL_4 , TL_5 , TL_6 ,
TL_8 , TL_7 , TL_SPACE ,
// from http://www.cs.utexas.edu/users/novak/asg-8p.html
//Goal: Easy: Medium: Hard: Worst:
//1 2 3 1 3 4 2 8 1 2 8 1 5 6 7
//8 4 8 6 2 4 3 4 6 3 4 8
//7 6 5 7 5 7 6 5 7 5 3 2 1
#elif 0
// easy 5
TL_1 ,
TL_3 ,
TL_4 ,
TL_8 ,
TL_6 ,
TL_2 ,
TL_7 ,
TL_SPACE ,
TL_5 ,
#elif 0
// medium 9
TL_2 ,
TL_8 ,
TL_1 ,
TL_SPACE ,
TL_4 ,
TL_3 ,
TL_7 ,
TL_6 ,
TL_5 ,
#elif 0
// hard 12
TL_2 ,
TL_8 ,
TL_1 ,
TL_4 ,
TL_6 ,
TL_3 ,
TL_SPACE ,
TL_7 ,
TL_5 ,
#elif 0
// worst 30
TL_5 ,
TL_6 ,
TL_7 ,
TL_4 ,
TL_SPACE ,
TL_8 ,
TL_3 ,
TL_2 ,
TL_1 ,
#elif 0
// 123
// 784
// 65
// two move simple board
TL_1 ,
TL_2 ,
TL_3 ,
TL_7 ,
TL_8 ,
TL_4 ,
TL_SPACE ,
TL_6 ,
TL_5 ,
#elif 0
// a1 b2 c3 d4 e5 f6 g7 h8
//C3,Blank,H8,A1,G8,F6,E5,D4,B2
TL_3 ,
TL_SPACE ,
TL_8 ,
TL_1 ,
TL_8 ,
TL_6 ,
TL_5 ,
TL_4 ,
TL_2 ,
#endif
};
bool PuzzleState::IsSameState( PuzzleState &rhs )
{
for( int i=0; i<(BOARD_HEIGHT*BOARD_WIDTH); i++ )
{
if( tiles[i] != rhs.tiles[i] )
{
return false;
}
}
return true;
}
void PuzzleState::PrintNodeInfo()
{
char str[100];
sprintf( str, "%c %c %c\n%c %c %c\n%c %c %c\n",
tiles[0] + '0',
tiles[1] + '0',
tiles[2] + '0',
tiles[3] + '0',
tiles[4] + '0',
tiles[5] + '0',
tiles[6] + '0',
tiles[7] + '0',
tiles[8] + '0'
);
cout << str;
}
// Here's the heuristic function that estimates the distance from a PuzzleState
// to the Goal.
float PuzzleState::GoalDistanceEstimate( PuzzleState &nodeGoal )
{
// Nilsson's sequence score
int i, cx, cy, ax, ay, h = 0, s, t;
// given a tile this returns the tile that should be clockwise
TILE correct_follower_to[ BOARD_WIDTH * BOARD_HEIGHT ] =
{
TL_SPACE, // always wrong
TL_2,
TL_3,
TL_4,
TL_5,
TL_6,
TL_7,
TL_8,
TL_1,
};
// given a table index returns the index of the tile that is clockwise to it 3*3 only
int clockwise_tile_of[ BOARD_WIDTH * BOARD_HEIGHT ] =
{
1,
2, // 012
5, // 345
0, // 678
-1, // never called with center square
8,
3,
6,
7
};
int tile_x[ BOARD_WIDTH * BOARD_HEIGHT ] =
{
/* TL_SPACE */ 1,
/* TL_1 */ 0,
/* TL_2 */ 1,
/* TL_3 */ 2,
/* TL_4 */ 2,
/* TL_5 */ 2,
/* TL_6 */ 1,
/* TL_7 */ 0,
/* TL_8 */ 0,
};
int tile_y[ BOARD_WIDTH * BOARD_HEIGHT ] =
{
/* TL_SPACE */ 1,
/* TL_1 */ 0,
/* TL_2 */ 0,
/* TL_3 */ 0,
/* TL_4 */ 1,
/* TL_5 */ 2,
/* TL_6 */ 2,
/* TL_7 */ 2,
/* TL_8 */ 1,
};
s=0;
// score 1 point if centre is not correct
if( tiles[(BOARD_HEIGHT*BOARD_WIDTH)/2] != nodeGoal.tiles[(BOARD_HEIGHT*BOARD_WIDTH)/2] )
{
s = 1;
}
for( i=0; i<(BOARD_HEIGHT*BOARD_WIDTH); i++ )
{
// this loop adds up the totaldist element in h and
// the sequence score in s
// the space does not count
if( tiles[i] == TL_SPACE )
{
continue;
}
// get correct x and y of this tile
cx = tile_x[tiles[i]];
cy = tile_y[tiles[i]];
// get actual
ax = i % BOARD_WIDTH;
ay = i / BOARD_WIDTH;
// add manhatten distance to h
h += abs( cx-ax );
h += abs( cy-ay );
// no s score for center tile
if( (ax == (BOARD_WIDTH/2)) && (ay == (BOARD_HEIGHT/2)) )
{
continue;
}
// score 2 points if not followed by successor
if( correct_follower_to[ tiles[i] ] != tiles[ clockwise_tile_of[ i ] ] )
{
s += 2;
}
}
// mult by 3 and add to h
t = h + (3*s);
return (float) t;
}
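// Summary of the estimate above (Nilsson's sequence score): the value returned is
// h + 3*s, where h is the summed Manhattan distance of every tile (space excluded)
// from its goal square, and s is the sequence score -- 1 if the centre tile is wrong,
// plus 2 for each non-centre tile whose clockwise neighbour is not its correct
// successor. The score is informative but can overestimate, so it finds solutions
// quickly rather than guaranteeing the shortest one.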
bool PuzzleState::IsGoal( PuzzleState &nodeGoal )
{
return IsSameState( nodeGoal );
}
// Helper
// Return the x and y position of the space tile
void PuzzleState::GetSpacePosition( PuzzleState *pn, int *rx, int *ry )
{
int x,y;
for( y=0; y<BOARD_HEIGHT; y++ )
{
for( x=0; x<BOARD_WIDTH; x++ )
{
if( pn->tiles[(y*BOARD_WIDTH)+x] == TL_SPACE )
{
*rx = x;
*ry = y;
return;
}
}
}
assert( false && "Something went wrong. There's no space on the board" );
}
int PuzzleState::GetMap( int x, int y, TILE *tiles )
{
if( x < 0 ||
x >= BOARD_WIDTH ||
y < 0 ||
y >= BOARD_HEIGHT
)
return GM_OFF_BOARD;
if( tiles[(y*BOARD_WIDTH)+x] == TL_SPACE )
{
return GM_SPACE;
}
return GM_TILE;
}
// Given a node set of tiles and a set of tiles to move them into, do the move as if it was on a tile board
// note : returns false if the board wasn't changed, and simply returns the tiles as they were in the target
// spx and spy is the space position while tx and ty is the target move from position
bool PuzzleState::LegalMove( TILE *StartTiles, TILE *TargetTiles, int spx, int spy, int tx, int ty )
{
int t;
if( GetMap( spx, spy, StartTiles ) == GM_SPACE )
{
if( GetMap( tx, ty, StartTiles ) == GM_TILE )
{
// copy tiles
for( t=0; t<(BOARD_HEIGHT*BOARD_WIDTH); t++ )
{
TargetTiles[t] = StartTiles[t];
}
TargetTiles[ (ty*BOARD_WIDTH)+tx ] = StartTiles[ (spy*BOARD_WIDTH)+spx ];
TargetTiles[ (spy*BOARD_WIDTH)+spx ] = StartTiles[ (ty*BOARD_WIDTH)+tx ];
return true;
}
}
return false;
}
// This generates the successors to the given PuzzleState. It uses a helper function called
// AddSuccessor to give the successors to the AStar class. The A* specific initialisation
// is done for each node internally, so here you just set the state information that
// is specific to the application
bool PuzzleState::GetSuccessors( AStarSearch<PuzzleState> *astarsearch, PuzzleState *parent_node )
{
PuzzleState NewNode;
int sp_x,sp_y;
GetSpacePosition( this, &sp_x, &sp_y );
bool ret;
if( LegalMove( tiles, NewNode.tiles, sp_x, sp_y, sp_x, sp_y-1 ) == true )
{
ret = astarsearch->AddSuccessor( NewNode );
if( !ret ) return false;
}
if( LegalMove( tiles, NewNode.tiles, sp_x, sp_y, sp_x, sp_y+1 ) == true )
{
ret = astarsearch->AddSuccessor( NewNode );
if( !ret ) return false;
}
if( LegalMove( tiles, NewNode.tiles, sp_x, sp_y, sp_x-1, sp_y ) == true )
{
ret = astarsearch->AddSuccessor( NewNode );
if( !ret ) return false;
}
if( LegalMove( tiles, NewNode.tiles, sp_x, sp_y, sp_x+1, sp_y ) == true )
{
ret = astarsearch->AddSuccessor( NewNode );
if( !ret ) return false;
}
return true;
}
// given this node, what does it cost to move to successor. In the case
// of our map the answer is the map terrain value at this node since that is
// conceptually where we're moving
float PuzzleState::GetCost( PuzzleState &successor )
{
return 1.0f; // I love it when life is simple
}
// Main
int puzzle( int argc, char *argv[] )
{
cout << "STL A* 8-puzzle solver implementation\n(C)2001 Justin Heyes-Jones\n";
if( argc > 1 )
{
int i = 0;
int c;
while( (c = argv[1][i]) && i < BOARD_WIDTH*BOARD_HEIGHT ) // stop after 9 tiles to avoid overrunning g_start
{
if( isdigit( c ) )
{
int num = (c - '0');
PuzzleState::g_start[i] = static_cast<PuzzleState::TILE>(num);
}
i++;
}
}
// Create an instance of the search class...
AStarSearch<PuzzleState> astarsearch;
int NumTimesToSearch = NUM_TIMES_TO_RUN_SEARCH;
while( NumTimesToSearch-- )
{
// Create a start state
PuzzleState nodeStart( PuzzleState::g_start );
// Define the goal state
PuzzleState nodeEnd( PuzzleState::g_goal );
// Set Start and goal states
astarsearch.SetStartAndGoalStates( nodeStart, nodeEnd );
unsigned int SearchState;
unsigned int SearchSteps = 0;
do
{
SearchState = astarsearch.SearchStep();
#if DEBUG_LISTS
float f,g,h;
cout << "Search step " << SearchSteps << endl;
cout << "Open:\n";
PuzzleState *p = astarsearch.GetOpenListStart( f,g,h );
while( p )
{
((PuzzleState *)p)->PrintNodeInfo();
cout << "f: " << f << " g: " << g << " h: " << h << "\n\n";
p = astarsearch.GetOpenListNext( f,g,h );
}
cout << "Closed:\n";
p = astarsearch.GetClosedListStart( f,g,h );
while( p )
{
p->PrintNodeInfo();
cout << "f: " << f << " g: " << g << " h: " << h << "\n\n";
p = astarsearch.GetClosedListNext( f,g,h );
}
#endif
// Test cancel search
#if 0
int StepCount = astarsearch.GetStepCount();
if( StepCount == 10 )
{
astarsearch.CancelSearch();
}
#endif
SearchSteps++;
}
while( SearchState == AStarSearch<PuzzleState>::SEARCH_STATE_SEARCHING );
if( SearchState == AStarSearch<PuzzleState>::SEARCH_STATE_SUCCEEDED )
{
#if DISPLAY_SOLUTION_FORWARDS
cout << "Search found goal state\n";
#endif
PuzzleState *node = astarsearch.GetSolutionStart();
#if DISPLAY_SOLUTION_FORWARDS
cout << "Displaying solution\n";
#endif
int steps = 0;
#if DISPLAY_SOLUTION_FORWARDS
node->PrintNodeInfo();
cout << endl;
#endif
for( ;; )
{
node = astarsearch.GetSolutionNext();
if( !node )
{
break;
}
#if DISPLAY_SOLUTION_FORWARDS
node->PrintNodeInfo();
cout << endl;
#endif
steps ++;
};
#if DISPLAY_SOLUTION_FORWARDS
// todo move step count into main algorithm
cout << "Solution steps " << steps << endl;
#endif
////////////
node = astarsearch.GetSolutionEnd();
#if DISPLAY_SOLUTION_BACKWARDS
cout << "Displaying reverse solution\n";
#endif
steps = 0;
node->PrintNodeInfo();
cout << endl;
for( ;; )
{
node = astarsearch.GetSolutionPrev();
if( !node )
{
break;
}
#if DISPLAY_SOLUTION_BACKWARDS
node->PrintNodeInfo();
cout << endl;
#endif
steps ++;
};
#if DISPLAY_SOLUTION_BACKWARDS
cout << "Solution steps " << steps << endl;
#endif
//////////////
// Once you're done with the solution you can free the nodes up
astarsearch.FreeSolutionNodes();
}
else if( SearchState == AStarSearch<PuzzleState>::SEARCH_STATE_FAILED )
{
#if DISPLAY_SOLUTION_INFO
cout << "Search terminated. Did not find goal state\n";
#endif
}
else if( SearchState == AStarSearch<PuzzleState>::SEARCH_STATE_OUT_OF_MEMORY )
{
#if DISPLAY_SOLUTION_INFO
cout << "Search terminated. Out of memory\n";
#endif
}
// Display the number of loops the search went through
#if DISPLAY_SOLUTION_INFO
cout << "SearchSteps : " << astarsearch.GetStepCount() << endl;
#endif
}
return 0;
}
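
The comment block at the top of this file lists the functions a user state must supply (GoalDistanceEstimate, IsGoal, GetSuccessors, GetCost, IsSameState). As a compact illustration of that contract only -- this class is hypothetical and not part of the repository -- a state that walks along a number line would look like:

// Hypothetical minimal user state for stlastar.h (not repository code).
#include "stlastar.h"
#include <stdlib.h> // abs

class LineState
{
public:
    int pos;
    LineState() : pos( 0 ) {}
    LineState( int p ) : pos( p ) {}

    float GoalDistanceEstimate( LineState &nodeGoal ) { return (float) abs( pos - nodeGoal.pos ); }
    bool IsGoal( LineState &nodeGoal )                { return pos == nodeGoal.pos; }
    bool IsSameState( LineState &rhs )                { return pos == rhs.pos; }
    float GetCost( LineState &successor )             { return 1.0f; } // every step costs 1
    void PrintNodeInfo()                              { printf( "pos %d\n", pos ); }

    // Offer the two neighbouring positions; report failure if the search ran out of memory.
    bool GetSuccessors( AStarSearch<LineState> *astarsearch, LineState *parent_node )
    {
        LineState left( pos - 1 ), right( pos + 1 );
        return astarsearch->AddSuccessor( left ) && astarsearch->AddSuccessor( right );
    }
};

Plugged into AStarSearch<LineState> and driven with SetStartAndGoalStates plus a SearchStep loop, exactly as puzzle() does above, this finds the shortest walk between two integers.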

@@ -0,0 +1,343 @@
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// STL A* Search implementation
// (C)2001 Justin Heyes-Jones
//
// Finding a path on a simple grid maze
// This shows how to do shortest path finding using A*
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#include "stlastar.h" // See header for copyright and usage information
#include <iostream>
#include <stdio.h>
#include <math.h>
#define DEBUG_LISTS 0
#define DEBUG_LIST_LENGTHS_ONLY 0
using namespace std;
// Global data
// The world map
const int MAP_WIDTH = 20;
const int MAP_HEIGHT = 20;
int world_map[ MAP_WIDTH * MAP_HEIGHT ] =
{
// 0001020304050607080910111213141516171819
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, // 00
1,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,1, // 01
1,9,9,1,1,9,9,9,1,9,1,9,1,9,1,9,9,9,1,1, // 02
1,9,9,1,1,9,9,9,1,9,1,9,1,9,1,9,9,9,1,1, // 03
1,9,1,1,1,1,9,9,1,9,1,9,1,1,1,1,9,9,1,1, // 04
1,9,1,1,9,1,1,1,1,9,1,1,1,1,9,1,1,1,1,1, // 05
1,9,9,9,9,1,1,1,1,1,1,9,9,9,9,1,1,1,1,1, // 06
1,9,9,9,9,9,9,9,9,1,1,1,9,9,9,9,9,9,9,1, // 07
1,9,1,1,1,1,1,1,1,1,1,9,1,1,1,1,1,1,1,1, // 08
1,9,1,9,9,9,9,9,9,9,1,1,9,9,9,9,9,9,9,1, // 09
1,9,1,1,1,1,9,1,1,9,1,1,1,1,1,1,1,1,1,1, // 10
1,9,9,9,9,9,1,9,1,9,1,9,9,9,9,9,1,1,1,1, // 11
1,9,1,9,1,9,9,9,1,9,1,9,1,9,1,9,9,9,1,1, // 12
1,9,1,9,1,9,9,9,1,9,1,9,1,9,1,9,9,9,1,1, // 13
1,9,1,1,1,1,9,9,1,9,1,9,1,1,1,1,9,9,1,1, // 14
1,9,1,1,9,1,1,1,1,9,1,1,1,1,9,1,1,1,1,1, // 15
1,9,9,9,9,1,1,1,1,1,1,9,9,9,9,1,1,1,1,1, // 16
1,1,9,9,9,9,9,9,9,1,1,1,9,9,9,1,9,9,9,9, // 17
1,9,1,1,1,1,1,1,1,1,1,9,1,1,1,1,1,1,1,1, // 18
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, // 19
};
// map helper functions
int GetMap( int x, int y )
{
if( x < 0 ||
x >= MAP_WIDTH ||
y < 0 ||
y >= MAP_HEIGHT
)
{
return 9;
}
return world_map[(y*MAP_WIDTH)+x];
}
// Definitions
class MapSearchNode
{
public:
int x; // the (x,y) positions of the node
int y;
MapSearchNode() { x = y = 0; }
MapSearchNode( int px, int py ) { x=px; y=py; }
float GoalDistanceEstimate( MapSearchNode &nodeGoal );
bool IsGoal( MapSearchNode &nodeGoal );
bool GetSuccessors( AStarSearch<MapSearchNode> *astarsearch, MapSearchNode *parent_node );
float GetCost( MapSearchNode &successor );
bool IsSameState( MapSearchNode &rhs );
void PrintNodeInfo();
};
bool MapSearchNode::IsSameState( MapSearchNode &rhs )
{
// same state in a maze search is simply when (x,y) are the same
if( (x == rhs.x) &&
(y == rhs.y) )
{
return true;
}
else
{
return false;
}
}
void MapSearchNode::PrintNodeInfo()
{
char str[100];
sprintf( str, "Node position : (%d,%d)\n", x,y );
cout << str;
}
// Here's the heuristic function that estimates the distance from a Node
// to the Goal.
float MapSearchNode::GoalDistanceEstimate( MapSearchNode &nodeGoal )
{
return abs(x - nodeGoal.x) + abs(y - nodeGoal.y);
}
bool MapSearchNode::IsGoal( MapSearchNode &nodeGoal )
{
if( (x == nodeGoal.x) &&
(y == nodeGoal.y) )
{
return true;
}
return false;
}
// This generates the successors to the given Node. It uses a helper function called
// AddSuccessor to give the successors to the AStar class. The A* specific initialisation
// is done for each node internally, so here you just set the state information that
// is specific to the application
bool MapSearchNode::GetSuccessors( AStarSearch<MapSearchNode> *astarsearch, MapSearchNode *parent_node )
{
int parent_x = -1;
int parent_y = -1;
if( parent_node )
{
parent_x = parent_node->x;
parent_y = parent_node->y;
}
MapSearchNode NewNode;
// push each possible move except allowing the search to go backwards
if( (GetMap( x-1, y ) < 9)
&& !((parent_x == x-1) && (parent_y == y))
)
{
NewNode = MapSearchNode( x-1, y );
astarsearch->AddSuccessor( NewNode );
}
if( (GetMap( x, y-1 ) < 9)
&& !((parent_x == x) && (parent_y == y-1))
)
{
NewNode = MapSearchNode( x, y-1 );
astarsearch->AddSuccessor( NewNode );
}
if( (GetMap( x+1, y ) < 9)
&& !((parent_x == x+1) && (parent_y == y))
)
{
NewNode = MapSearchNode( x+1, y );
astarsearch->AddSuccessor( NewNode );
}
if( (GetMap( x, y+1 ) < 9)
&& !((parent_x == x) && (parent_y == y+1))
)
{
NewNode = MapSearchNode( x, y+1 );
astarsearch->AddSuccessor( NewNode );
}
return true;
}
// given this node, what does it cost to move to successor. In the case
// of our map the answer is the map terrain value at this node since that is
// conceptually where we're moving
float MapSearchNode::GetCost( MapSearchNode &successor )
{
return (float) GetMap( x, y );
}
// Main
int findpath( int argc, char *argv[] )
{
cout << "STL A* Search implementation\n(C)2001 Justin Heyes-Jones\n";
// Our sample problem defines the world as a 2d array representing a terrain
// Each element contains an integer from 0 to 5 which indicates the cost
// of travel across the terrain. Zero means the least possible difficulty
// in travelling (think ice rink if you can skate) whilst 5 represents the
// most difficult. 9 indicates that we cannot pass.
// Create an instance of the search class...
AStarSearch<MapSearchNode> astarsearch;
unsigned int SearchCount = 0;
const unsigned int NumSearches = 1;
while(SearchCount < NumSearches)
{
// Create a start state
MapSearchNode nodeStart;
nodeStart.x = rand()%MAP_WIDTH;
nodeStart.y = rand()%MAP_HEIGHT;
// Define the goal state
MapSearchNode nodeEnd;
nodeEnd.x = rand()%MAP_WIDTH;
nodeEnd.y = rand()%MAP_HEIGHT;
// Set Start and goal states
astarsearch.SetStartAndGoalStates( nodeStart, nodeEnd );
unsigned int SearchState;
unsigned int SearchSteps = 0;
do
{
SearchState = astarsearch.SearchStep();
SearchSteps++;
#if DEBUG_LISTS
cout << "Steps:" << SearchSteps << "\n";
int len = 0;
cout << "Open:\n";
MapSearchNode *p = astarsearch.GetOpenListStart();
while( p )
{
len++;
#if !DEBUG_LIST_LENGTHS_ONLY
((MapSearchNode *)p)->PrintNodeInfo();
#endif
p = astarsearch.GetOpenListNext();
}
cout << "Open list has " << len << " nodes\n";
len = 0;
cout << "Closed:\n";
p = astarsearch.GetClosedListStart();
while( p )
{
len++;
#if !DEBUG_LIST_LENGTHS_ONLY
p->PrintNodeInfo();
#endif
p = astarsearch.GetClosedListNext();
}
cout << "Closed list has " << len << " nodes\n";
#endif
}
while( SearchState == AStarSearch<MapSearchNode>::SEARCH_STATE_SEARCHING );
if( SearchState == AStarSearch<MapSearchNode>::SEARCH_STATE_SUCCEEDED )
{
cout << "Search found goal state\n";
MapSearchNode *node = astarsearch.GetSolutionStart();
#if DISPLAY_SOLUTION
cout << "Displaying solution\n";
#endif
int steps = 0;
node->PrintNodeInfo();
for( ;; )
{
node = astarsearch.GetSolutionNext();
if( !node )
{
break;
}
node->PrintNodeInfo();
steps ++;
};
cout << "Solution steps " << steps << endl;
// Once you're done with the solution you can free the nodes up
astarsearch.FreeSolutionNodes();
}
else if( SearchState == AStarSearch<MapSearchNode>::SEARCH_STATE_FAILED )
{
cout << "Search terminated. Did not find goal state\n";
}
// Display the number of loops the search went through
cout << "SearchSteps : " << SearchSteps << "\n";
SearchCount ++;
astarsearch.EnsureMemoryFreed();
}
return 0;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -0,0 +1,252 @@
/*
A* Algorithm Implementation using STL is
Copyright (C)2001-2005 Justin Heyes-Jones
Permission is given by the author to freely redistribute and
include this code in any program as long as this credit is
given where due.
COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS" BASIS,
WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED,
INCLUDING, WITHOUT LIMITATION, WARRANTIES THAT THE COVERED CODE
IS FREE OF DEFECTS, MERCHANTABLE, FIT FOR A PARTICULAR PURPOSE
OR NON-INFRINGING. THE ENTIRE RISK AS TO THE QUALITY AND
PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED
CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL
DEVELOPER OR ANY OTHER CONTRIBUTOR) ASSUME THE COST OF ANY
NECESSARY SERVICING, REPAIR OR CORRECTION. THIS DISCLAIMER OF
WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS LICENSE. NO USE
OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER
THIS DISCLAIMER.
Use at your own risk!
FixedSizeAllocator class
Copyright 2001 Justin Heyes-Jones
This class is a constant time O(1) memory manager for objects of
a specified type. The type is specified using a template class.
Memory is allocated from a fixed size buffer which you can specify in the
class constructor or use the default.
Using GetFirst and GetNext it is possible to iterate through the elements
one by one, and this would be the most common use for the class.
I would suggest using this class when you want O(1) add and delete
and you don't do much searching, which would be O(n). Structures such as binary
trees can be used instead to get O(logn) access time.
*/
#ifndef FSA_H
#define FSA_H
#include <string.h>
#include <stdio.h>
template <class USER_TYPE> class FixedSizeAllocator
{
public:
// Constants
enum
{
FSA_DEFAULT_SIZE = 100
};
// This class enables us to transparently manage the extra data
// needed to enable the user class to form part of the double-linked
// list class
struct FSA_ELEMENT
{
USER_TYPE UserType;
FSA_ELEMENT *pPrev;
FSA_ELEMENT *pNext;
};
public: // methods
FixedSizeAllocator( unsigned int MaxElements = FSA_DEFAULT_SIZE ) :
m_pFirstUsed( NULL ),
m_MaxElements( MaxElements )
{
// Allocate enough memory for the maximum number of elements
char *pMem = new char[ m_MaxElements * sizeof(FSA_ELEMENT) ];
m_pMemory = (FSA_ELEMENT *) pMem;
// Set the free list first pointer
m_pFirstFree = m_pMemory;
// Clear the memory
memset( m_pMemory, 0, sizeof( FSA_ELEMENT ) * m_MaxElements );
// Point at first element
FSA_ELEMENT *pElement = m_pFirstFree;
// Set the double linked free list
for( unsigned int i=0; i<m_MaxElements; i++ )
{
pElement->pPrev = pElement-1;
pElement->pNext = pElement+1;
pElement++;
}
// first element should have a null prev
m_pFirstFree->pPrev = NULL;
// last element should have a null next
(pElement-1)->pNext = NULL;
}
~FixedSizeAllocator()
{
// Free up the memory
delete [] (char *) m_pMemory;
}
// Allocate a new USER_TYPE and return a pointer to it
USER_TYPE *alloc()
{
FSA_ELEMENT *pNewNode = NULL;
if( !m_pFirstFree )
{
return NULL;
}
else
{
pNewNode = m_pFirstFree;
m_pFirstFree = pNewNode->pNext;
// if the new node points to another free node then
// change that nodes prev free pointer...
if( pNewNode->pNext )
{
pNewNode->pNext->pPrev = NULL;
}
// node is now on the used list
pNewNode->pPrev = NULL; // the allocated node is always first in the list
if( m_pFirstUsed == NULL )
{
pNewNode->pNext = NULL; // no other nodes
}
else
{
m_pFirstUsed->pPrev = pNewNode; // insert this at the head of the used list
pNewNode->pNext = m_pFirstUsed;
}
m_pFirstUsed = pNewNode;
}
return reinterpret_cast<USER_TYPE*>(pNewNode);
}
// Free the given user type
// For efficiency I don't check whether the user_data is a valid
// pointer that was allocated. I may add some debug only checking
// (To add the debug check you'd need to make sure the pointer is in
// the m_pMemory area and is pointing at the start of a node)
void free( USER_TYPE *user_data )
{
FSA_ELEMENT *pNode = reinterpret_cast<FSA_ELEMENT*>(user_data);
// manage used list, remove this node from it
if( pNode->pPrev )
{
pNode->pPrev->pNext = pNode->pNext;
}
else
{
// this handles the case that we delete the first node in the used list
m_pFirstUsed = pNode->pNext;
}
if( pNode->pNext )
{
pNode->pNext->pPrev = pNode->pPrev;
}
// add to free list
if( m_pFirstFree == NULL )
{
// free list was empty
m_pFirstFree = pNode;
pNode->pPrev = NULL;
pNode->pNext = NULL;
}
else
{
// Add this node at the start of the free list
m_pFirstFree->pPrev = pNode;
pNode->pNext = m_pFirstFree;
m_pFirstFree = pNode;
}
}
// For debugging this displays both lists (using the prev/next list pointers)
void Debug()
{
printf( "free list " );
FSA_ELEMENT *p = m_pFirstFree;
while( p )
{
printf( "%x!%x ", p->pPrev, p->pNext );
p = p->pNext;
}
printf( "\n" );
printf( "used list " );
p = m_pFirstUsed;
while( p )
{
printf( "%x!%x ", p->pPrev, p->pNext );
p = p->pNext;
}
printf( "\n" );
}
// Iterators
USER_TYPE *GetFirst()
{
return reinterpret_cast<USER_TYPE *>(m_pFirstUsed);
}
USER_TYPE *GetNext( USER_TYPE *node )
{
return reinterpret_cast<USER_TYPE *>
(
(reinterpret_cast<FSA_ELEMENT *>(node))->pNext
);
}
public: // data
private: // methods
private: // data
FSA_ELEMENT *m_pFirstFree;
FSA_ELEMENT *m_pFirstUsed;
unsigned int m_MaxElements;
FSA_ELEMENT *m_pMemory;
};
#endif // defined FSA_H
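
The header comment above describes constant-time alloc/free plus GetFirst/GetNext iteration over the used list. The snippet below is only an illustrative usage sketch (the element type and function name are assumed, not repository code):

// Illustrative use of FixedSizeAllocator -- not part of the repository.
#include "fsa.h"

int fsa_demo()
{
    FixedSizeAllocator<int> pool( 4 ); // space for four ints, allocated up front

    int *a = pool.alloc();             // O(1) allocation from the free list
    int *b = pool.alloc();
    *a = 1; *b = 2;

    // Walk the used list; the most recently allocated element comes first.
    for( int *p = pool.GetFirst(); p; p = pool.GetNext( p ) )
        printf( "%d\n", *p );

    pool.free( a );                    // O(1) return to the free list
    pool.free( b );
    return 0;
}

stlastar.h uses this allocator for its Node objects when USE_FSA_MEMORY is enabled, which is why node allocation during the search is constant time.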

@@ -0,0 +1,315 @@
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// This example code illustrate how to use STL A* Search implementation to find the minimum path between two
// cities given a map. The example is taken from the book AI: A Modern Approach, 3rd Ed., by Russell, where a map
// of Romania is given. The target node is Bucharest, and the user can specify the initial city from which the
// search algorithm will start looking for the minimum path to Bucharest.
//
// Usage: min_path_to_Bucharest <start city name>
// Example:
// min_path_to_Bucharest Arad
//
// Thanks to Rasoul Mojtahedzadeh for this contribution
// Mojtahedzadeh _atsign_ gmail com
//
// Please note that this exercise is academic in nature and that the distances between the cities may not be
// correct compared to the latitude and longitude differences in real life. Thanks to parthi2929 for noticing
// this issue. In a real application you would use some kind of exact x,y position of each city in order to get
// accurate heuristics. In fact, that is part of the point of this example: in the book you will see Norvig
// mention that the algorithm does some backtracking because the heuristic is not accurate (yet still admissible).
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#include "stlastar.h"
#include <iostream>
#include <string>
#include <vector>
#include <stdio.h>
#define DEBUG_LISTS 0
#define DEBUG_LIST_LENGTHS_ONLY 0
using namespace std;
const int MAX_CITIES = 20;
enum ENUM_CITIES{Arad=0, Bucharest, Craiova, Drobeta, Eforie, Fagaras, Giurgiu, Hirsova, Iasi, Lugoj, Mehadia, Neamt, Oradea, Pitesti, RimnicuVilcea, Sibiu, Timisoara, Urziceni, Vaslui, Zerind};
vector<string> CityNames(MAX_CITIES);
float RomaniaMap[MAX_CITIES][MAX_CITIES];
class PathSearchNode
{
public:
ENUM_CITIES city;
PathSearchNode() { city = Arad; }
PathSearchNode( ENUM_CITIES in ) { city = in; }
float GoalDistanceEstimate( PathSearchNode &nodeGoal );
bool IsGoal( PathSearchNode &nodeGoal );
bool GetSuccessors( AStarSearch<PathSearchNode> *astarsearch, PathSearchNode *parent_node );
float GetCost( PathSearchNode &successor );
bool IsSameState( PathSearchNode &rhs );
void PrintNodeInfo();
};
// check if "this" node is the same as "rhs" node
bool PathSearchNode::IsSameState( PathSearchNode &rhs )
{
if(city == rhs.city) return(true);
return(false);
}
// Euclidean distance between "this" node city and Bucharest
// Note: Numbers are taken from the book
float PathSearchNode::GoalDistanceEstimate( PathSearchNode &nodeGoal )
{
// goal is always Bucharest!
switch(city)
{
case Arad: return 366; break;
case Bucharest: return 0; break;
case Craiova: return 160; break;
case Drobeta: return 242; break;
case Eforie: return 161; break;
case Fagaras: return 176; break;
case Giurgiu: return 77; break;
case Hirsova: return 151; break;
case Iasi: return 226; break;
case Lugoj: return 244; break;
case Mehadia: return 241; break;
case Neamt: return 234; break;
case Oradea: return 380; break;
case Pitesti: return 100; break;
case RimnicuVilcea: return 193; break;
case Sibiu: return 253; break;
case Timisoara: return 329; break;
case Urziceni: return 80; break;
case Vaslui: return 199; break;
case Zerind: return 374; break;
}
cerr << "ASSERT: city = " << CityNames[city] << endl;
return 0;
}
// check if "this" node is the goal node
bool PathSearchNode::IsGoal( PathSearchNode &nodeGoal )
{
if( city == Bucharest ) return(true);
return(false);
}
// generates the successor nodes of "this" node
bool PathSearchNode::GetSuccessors( AStarSearch<PathSearchNode> *astarsearch, PathSearchNode *parent_node )
{
PathSearchNode NewNode;
for(int c=0; c<MAX_CITIES; c++)
{
if(RomaniaMap[city][c] < 0) continue;
NewNode = PathSearchNode((ENUM_CITIES)c);
astarsearch->AddSuccessor( NewNode );
}
return true;
}
// the cost of going from "this" node to the "successor" node
float PathSearchNode::GetCost( PathSearchNode &successor )
{
return RomaniaMap[city][successor.city];
}
// prints out information about the node
void PathSearchNode::PrintNodeInfo()
{
cout << " " << CityNames[city] << "\n";
}
// Main
int min_path_to_Bucharest( int argc, char *argv[] )
{
// creating map of Romania
for(int i=0; i<MAX_CITIES; i++)
for(int j=0; j<MAX_CITIES; j++)
RomaniaMap[i][j]=-1.0;
RomaniaMap[Arad][Sibiu]=140;
RomaniaMap[Arad][Zerind]=75;
RomaniaMap[Arad][Timisoara]=118;
RomaniaMap[Bucharest][Giurgiu]=90;
RomaniaMap[Bucharest][Urziceni]=85;
RomaniaMap[Bucharest][Fagaras]=211;
RomaniaMap[Bucharest][Pitesti]=101;
RomaniaMap[Craiova][Drobeta]=120;
RomaniaMap[Craiova][RimnicuVilcea]=146;
RomaniaMap[Craiova][Pitesti]=138;
RomaniaMap[Drobeta][Craiova]=120;
RomaniaMap[Drobeta][Mehadia]=75;
RomaniaMap[Eforie][Hirsova]=75;
RomaniaMap[Fagaras][Bucharest]=211;
RomaniaMap[Fagaras][Sibiu]=99;
RomaniaMap[Giurgiu][Bucharest]=90;
RomaniaMap[Hirsova][Eforie]=86;
RomaniaMap[Hirsova][Urziceni]=98;
RomaniaMap[Iasi][Vaslui]=92;
RomaniaMap[Iasi][Neamt]=87;
RomaniaMap[Lugoj][Timisoara]=111;
RomaniaMap[Lugoj][Mehadia]=70;
RomaniaMap[Mehadia][Lugoj]=70;
RomaniaMap[Mehadia][Drobeta]=75;
RomaniaMap[Neamt][Iasi]=87;
RomaniaMap[Oradea][Zerind]=71;
RomaniaMap[Oradea][Sibiu]=151;
RomaniaMap[Pitesti][Bucharest]=101;
RomaniaMap[Pitesti][RimnicuVilcea]=97;
RomaniaMap[Pitesti][Craiova]=138;
RomaniaMap[RimnicuVilcea][Pitesti]=97;
RomaniaMap[RimnicuVilcea][Craiova]=146;
RomaniaMap[RimnicuVilcea][Sibiu]=80;
RomaniaMap[Sibiu][RimnicuVilcea]=80;
RomaniaMap[Sibiu][Fagaras]=99;
RomaniaMap[Sibiu][Oradea]=151;
RomaniaMap[Sibiu][Arad]=140;
RomaniaMap[Timisoara][Arad]=118;
RomaniaMap[Timisoara][Lugoj]=111;
RomaniaMap[Urziceni][Bucharest]=85;
RomaniaMap[Urziceni][Hirsova]=98;
RomaniaMap[Urziceni][Vaslui]=142;
RomaniaMap[Vaslui][Urziceni]=142;
RomaniaMap[Vaslui][Iasi]=92;
RomaniaMap[Zerind][Arad]=75;
RomaniaMap[Zerind][Oradea]=71;
// City names
CityNames[Arad].assign("Arad");
CityNames[Bucharest].assign("Bucharest");
CityNames[Craiova].assign("Craiova");
CityNames[Drobeta].assign("Drobeta");
CityNames[Eforie].assign("Eforie");
CityNames[Fagaras].assign("Fagaras");
CityNames[Giurgiu].assign("Giurgiu");
CityNames[Hirsova].assign("Hirsova");
CityNames[Iasi].assign("Iasi");
CityNames[Lugoj].assign("Lugoj");
CityNames[Mehadia].assign("Mehadia");
CityNames[Neamt].assign("Neamt");
CityNames[Oradea].assign("Oradea");
CityNames[Pitesti].assign("Pitesti");
CityNames[RimnicuVilcea].assign("RimnicuVilcea");
CityNames[Sibiu].assign("Sibiu");
CityNames[Timisoara].assign("Timisoara");
CityNames[Urziceni].assign("Urziceni");
CityNames[Vaslui].assign("Vaslui");
CityNames[Zerind].assign("Zerind");
ENUM_CITIES initCity = Arad;
if(argc == 2)
{
bool found = false;
for(size_t i=0; i<CityNames.size(); i++)
if(CityNames[i].compare(argv[1])==0)
{
initCity = (ENUM_CITIES)i;
found = true;
break;
}
if(not found)
{
cout << "There is no city named "<<argv[1]<<" in the map!\n";
return(1);
}
}
// An instance of A* search class
AStarSearch<PathSearchNode> astarsearch;
unsigned int SearchCount = 0;
const unsigned int NumSearches = 1;
while(SearchCount < NumSearches)
{
// Create a start state
PathSearchNode nodeStart;
nodeStart.city = initCity;
// Define the goal state, always Bucharest!
PathSearchNode nodeEnd;
nodeEnd.city = Bucharest;
// Set Start and goal states
astarsearch.SetStartAndGoalStates( nodeStart, nodeEnd );
unsigned int SearchState;
unsigned int SearchSteps = 0;
do
{
SearchState = astarsearch.SearchStep();
SearchSteps++;
#if DEBUG_LISTS
cout << "Steps:" << SearchSteps << "\n";
int len = 0;
cout << "Open:\n";
PathSearchNode *p = astarsearch.GetOpenListStart();
while( p )
{
len++;
#if !DEBUG_LIST_LENGTHS_ONLY
((PathSearchNode *)p)->PrintNodeInfo();
#endif
p = astarsearch.GetOpenListNext();
}
cout << "Open list has " << len << " nodes\n";
len = 0;
cout << "Closed:\n";
p = astarsearch.GetClosedListStart();
while( p )
{
len++;
#if !DEBUG_LIST_LENGTHS_ONLY
p->PrintNodeInfo();
#endif
p = astarsearch.GetClosedListNext();
}
cout << "Closed list has " << len << " nodes\n";
#endif
}
while( SearchState == AStarSearch<PathSearchNode>::SEARCH_STATE_SEARCHING );
if( SearchState == AStarSearch<PathSearchNode>::SEARCH_STATE_SUCCEEDED )
{
cout << "Search found the goal state\n";
PathSearchNode *node = astarsearch.GetSolutionStart();
cout << "Displaying solution\n";
int steps = 0;
node->PrintNodeInfo();
for( ;; )
{
node = astarsearch.GetSolutionNext();
if( !node ) break;
node->PrintNodeInfo();
steps ++;
};
cout << "Solution steps " << steps << endl;
// Once you're done with the solution you can free the nodes up
astarsearch.FreeSolutionNodes();
}
else if( SearchState == AStarSearch<PathSearchNode>::SEARCH_STATE_FAILED )
{
cout << "Search terminated. Did not find goal state\n";
}
// Display the number of loops the search went through
cout << "SearchSteps : " << SearchSteps << "\n";
SearchCount ++;
astarsearch.EnsureMemoryFreed();
}
return 0;
}

@@ -0,0 +1,833 @@
/*
A* Algorithm Implementation using STL is
Copyright (C)2001-2005 Justin Heyes-Jones
Permission is given by the author to freely redistribute and
include this code in any program as long as this credit is
given where due.
COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS" BASIS,
WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED,
INCLUDING, WITHOUT LIMITATION, WARRANTIES THAT THE COVERED CODE
IS FREE OF DEFECTS, MERCHANTABLE, FIT FOR A PARTICULAR PURPOSE
OR NON-INFRINGING. THE ENTIRE RISK AS TO THE QUALITY AND
PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED
CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL
DEVELOPER OR ANY OTHER CONTRIBUTOR) ASSUME THE COST OF ANY
NECESSARY SERVICING, REPAIR OR CORRECTION. THIS DISCLAIMER OF
WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS LICENSE. NO USE
OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER
THIS DISCLAIMER.
Use at your own risk!
*/
#ifndef STLASTAR_H
#define STLASTAR_H
// used for text debugging
#include <iostream>
#include <stdio.h>
//#include <conio.h>
#include <assert.h>
// stl includes
#include <algorithm>
#include <set>
#include <vector>
#include <cfloat>
using namespace std;
// fast fixed size memory allocator, used for fast node memory management
#include "fsa.h"
// Fixed size memory allocator can be disabled to compare performance
// Uses std new and delete instead if you turn it off
#define USE_FSA_MEMORY 1
// disable warning that debugging information has lines that are truncated
// occurs in stl headers
#if defined(WIN32) && defined(_WINDOWS)
#pragma warning( disable : 4786 )
#endif
template <class T> class AStarState;
// The AStar search class. UserState is the users state space type
template <class UserState> class AStarSearch
{
public: // data
enum
{
SEARCH_STATE_NOT_INITIALISED,
SEARCH_STATE_SEARCHING,
SEARCH_STATE_SUCCEEDED,
SEARCH_STATE_FAILED,
SEARCH_STATE_OUT_OF_MEMORY,
SEARCH_STATE_INVALID
};
// A node represents a possible state in the search
// The user provided state type is included inside this type
public:
class Node
{
public:
Node *parent; // used during the search to record the parent of successor nodes
Node *child; // used after the search for the application to view the search in reverse
float g; // cost of this node + its predecessors
float h; // heuristic estimate of distance to goal
float f; // sum of cumulative cost of predecessors and self and heuristic
Node() :
parent( 0 ),
child( 0 ),
g( 0.0f ),
h( 0.0f ),
f( 0.0f )
{
}
UserState m_UserState;
};
// For sorting the heap the STL needs compare function that lets us compare
// the f value of two nodes
class HeapCompare_f
{
public:
bool operator() ( const Node *x, const Node *y ) const
{
return x->f > y->f;
}
};
public: // methods
// constructor just initialises private data
AStarSearch() :
m_State( SEARCH_STATE_NOT_INITIALISED ),
m_CurrentSolutionNode( NULL ),
#if USE_FSA_MEMORY
m_FixedSizeAllocator( 1000 ),
#endif
m_AllocateNodeCount(0),
m_CancelRequest( false )
{
}
AStarSearch( int MaxNodes ) :
m_State( SEARCH_STATE_NOT_INITIALISED ),
m_CurrentSolutionNode( NULL ),
#if USE_FSA_MEMORY
m_FixedSizeAllocator( MaxNodes ),
#endif
m_AllocateNodeCount(0),
m_CancelRequest( false )
{
}
// call at any time to cancel the search and free up all the memory
void CancelSearch()
{
m_CancelRequest = true;
}
// Set Start and goal states
void SetStartAndGoalStates( UserState &Start, UserState &Goal )
{
m_CancelRequest = false;
m_Start = AllocateNode();
m_Goal = AllocateNode();
assert((m_Start != NULL && m_Goal != NULL));
m_Start->m_UserState = Start;
m_Goal->m_UserState = Goal;
m_State = SEARCH_STATE_SEARCHING;
// Initialise the AStar specific parts of the Start Node
// The user only needs to fill out the state information
m_Start->g = 0;
m_Start->h = m_Start->m_UserState.GoalDistanceEstimate( m_Goal->m_UserState );
m_Start->f = m_Start->g + m_Start->h;
m_Start->parent = 0;
// Push the start node on the Open list
m_OpenList.push_back( m_Start ); // heap now unsorted
// Sort back element into heap
push_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
// Initialise counter for search steps
m_Steps = 0;
}
// Advances search one step
unsigned int SearchStep()
{
// Firstly break if the user has not initialised the search
assert( (m_State > SEARCH_STATE_NOT_INITIALISED) &&
(m_State < SEARCH_STATE_INVALID) );
// Next I want it to be safe to do a searchstep once the search has succeeded...
if( (m_State == SEARCH_STATE_SUCCEEDED) ||
(m_State == SEARCH_STATE_FAILED)
)
{
return m_State;
}
// Failure is defined as emptying the open list as there is nothing left to
// search...
// New: Allow user abort
if( m_OpenList.empty() || m_CancelRequest )
{
FreeAllNodes();
m_State = SEARCH_STATE_FAILED;
return m_State;
}
// Increment step count
m_Steps ++;
// Pop the best node (the one with the lowest f)
Node *n = m_OpenList.front(); // get pointer to the node
pop_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
m_OpenList.pop_back();
// Check for the goal, once we pop that we're done
if( n->m_UserState.IsGoal( m_Goal->m_UserState ) )
{
// The user is going to use the Goal Node he passed in
// so copy the parent pointer of n
m_Goal->parent = n->parent;
m_Goal->g = n->g;
// A special case is that the goal was passed in as the start state
// so handle that here
if( false == n->m_UserState.IsSameState( m_Start->m_UserState ) )
{
FreeNode( n );
// set the child pointers in each node (except Goal which has no child)
Node *nodeChild = m_Goal;
Node *nodeParent = m_Goal->parent;
do
{
nodeParent->child = nodeChild;
nodeChild = nodeParent;
nodeParent = nodeParent->parent;
}
while( nodeChild != m_Start ); // Start is always the first node by definition
}
// delete nodes that aren't needed for the solution
FreeUnusedNodes();
m_State = SEARCH_STATE_SUCCEEDED;
return m_State;
}
else // not goal
{
// We now need to generate the successors of this node
// The user helps us to do this, and we keep the new nodes in
// m_Successors ...
m_Successors.clear(); // empty vector of successor nodes to n
// The user provides this function and uses AddSuccessor to add each successor of
// node 'n' to m_Successors
bool ret = n->m_UserState.GetSuccessors( this, n->parent ? &n->parent->m_UserState : NULL );
if( !ret )
{
typename vector< Node * >::iterator successor;
// free the nodes that may previously have been added
for( successor = m_Successors.begin(); successor != m_Successors.end(); successor ++ )
{
FreeNode( (*successor) );
}
m_Successors.clear(); // empty vector of successor nodes to n
// free up everything else we allocated
FreeNode( (n) );
FreeAllNodes();
m_State = SEARCH_STATE_OUT_OF_MEMORY;
return m_State;
}
// Now handle each successor to the current node ...
for( typename vector< Node * >::iterator successor = m_Successors.begin(); successor != m_Successors.end(); successor ++ )
{
// The g value for this successor ...
float newg = n->g + n->m_UserState.GetCost( (*successor)->m_UserState );
// Now we need to find whether the node is on the open or closed lists
// If it is but the node that is already on them is better (lower g)
// then we can forget about this successor
// First linear search of open list to find node
typename vector< Node * >::iterator openlist_result;
for( openlist_result = m_OpenList.begin(); openlist_result != m_OpenList.end(); openlist_result ++ )
{
if( (*openlist_result)->m_UserState.IsSameState( (*successor)->m_UserState ) )
{
break;
}
}
if( openlist_result != m_OpenList.end() )
{
// we found this state on open
if( (*openlist_result)->g <= newg )
{
FreeNode( (*successor) );
// the one on Open is cheaper than this one
continue;
}
}
typename vector< Node * >::iterator closedlist_result;
for( closedlist_result = m_ClosedList.begin(); closedlist_result != m_ClosedList.end(); closedlist_result ++ )
{
if( (*closedlist_result)->m_UserState.IsSameState( (*successor)->m_UserState ) )
{
break;
}
}
if( closedlist_result != m_ClosedList.end() )
{
// we found this state on closed
if( (*closedlist_result)->g <= newg )
{
// the one on Closed is cheaper than this one
FreeNode( (*successor) );
continue;
}
}
// This node is the best node so far with this particular state
// so lets keep it and set up its AStar specific data ...
(*successor)->parent = n;
(*successor)->g = newg;
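// Note: A* is only guaranteed to return an optimal path if GoalDistanceEstimate()
// never overestimates the true remaining cost (i.e. the heuristic is admissible).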
(*successor)->h = (*successor)->m_UserState.GoalDistanceEstimate( m_Goal->m_UserState );
(*successor)->f = (*successor)->g + (*successor)->h;
// Successor in closed list
// 1 - Update old version of this node in closed list
// 2 - Move it from closed to open list
// 3 - Sort heap again in open list
if( closedlist_result != m_ClosedList.end() )
{
// Update closed node with successor node AStar data
//*(*closedlist_result) = *(*successor);
(*closedlist_result)->parent = (*successor)->parent;
(*closedlist_result)->g = (*successor)->g;
(*closedlist_result)->h = (*successor)->h;
(*closedlist_result)->f = (*successor)->f;
// Free successor node
FreeNode( (*successor) );
// Push closed node into open list
m_OpenList.push_back( (*closedlist_result) );
// Remove closed node from closed list
m_ClosedList.erase( closedlist_result );
// Sort back element into heap
push_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
// Fix thanks to ...
// Greg Douglas <gregdouglasmail@gmail.com>
// who noticed that this code path was incorrect
// Here we have found a new state which is already CLOSED
}
// Successor in open list
// 1 - Update old version of this node in open list
// 2 - sort heap again in open list
else if( openlist_result != m_OpenList.end() )
{
// Update open node with successor node AStar data
//*(*openlist_result) = *(*successor);
(*openlist_result)->parent = (*successor)->parent;
(*openlist_result)->g = (*successor)->g;
(*openlist_result)->h = (*successor)->h;
(*openlist_result)->f = (*successor)->f;
// Free successor node
FreeNode( (*successor) );
// re-make the heap
// make_heap rather than sort_heap is an essential bug fix
// thanks to Mike Ryynanen for pointing this out and then explaining
// it in detail. sort_heap called on an invalid heap does not work
make_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
}
// New successor
// 1 - Move it from successors to open list
// 2 - sort heap again in open list
else
{
// Push successor node into open list
m_OpenList.push_back( (*successor) );
// Sort back element into heap
push_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
}
}
// push n onto Closed, as we have expanded it now
m_ClosedList.push_back( n );
} // end else (not goal so expand)
return m_State; // still SEARCH_STATE_SEARCHING at this point
}
// User calls this to add a successor to a list of successors
// when expanding the search frontier
bool AddSuccessor( UserState &State )
{
Node *node = AllocateNode();
if( node )
{
node->m_UserState = State;
m_Successors.push_back( node );
return true;
}
return false;
}
// Free the solution nodes
// This is done to clean up all used Node memory when you are done with the
// search
void FreeSolutionNodes()
{
Node *n = m_Start;
if( m_Start->child )
{
do
{
Node *del = n;
n = n->child;
FreeNode( del );
del = NULL;
} while( n != m_Goal );
FreeNode( n ); // Delete the goal
}
else
{
// if the start node is the solution we need to just delete the start and goal
// nodes
FreeNode( m_Start );
FreeNode( m_Goal );
}
}
// Functions for traversing the solution
// Get start node
UserState *GetSolutionStart()
{
m_CurrentSolutionNode = m_Start;
if( m_Start )
{
return &m_Start->m_UserState;
}
else
{
return NULL;
}
}
// Get next node
UserState *GetSolutionNext()
{
if( m_CurrentSolutionNode )
{
if( m_CurrentSolutionNode->child )
{
Node *child = m_CurrentSolutionNode->child;
m_CurrentSolutionNode = m_CurrentSolutionNode->child;
return &child->m_UserState;
}
}
return NULL;
}
// Get end node
UserState *GetSolutionEnd()
{
m_CurrentSolutionNode = m_Goal;
if( m_Goal )
{
return &m_Goal->m_UserState;
}
else
{
return NULL;
}
}
// Step solution iterator backwards
UserState *GetSolutionPrev()
{
if( m_CurrentSolutionNode )
{
if( m_CurrentSolutionNode->parent )
{
Node *parent = m_CurrentSolutionNode->parent;
m_CurrentSolutionNode = m_CurrentSolutionNode->parent;
return &parent->m_UserState;
}
}
return NULL;
}
// Get final cost of solution
// Returns FLT_MAX if goal is not defined or there is no solution
float GetSolutionCost()
{
if( m_Goal && m_State == SEARCH_STATE_SUCCEEDED )
{
return m_Goal->g;
}
else
{
return FLT_MAX;
}
}
// For educational use and debugging it is useful to be able to view
// the open and closed lists at each step; the following functions allow that.
UserState *GetOpenListStart()
{
float f,g,h;
return GetOpenListStart( f,g,h );
}
UserState *GetOpenListStart( float &f, float &g, float &h )
{
iterDbgOpen = m_OpenList.begin();
if( iterDbgOpen != m_OpenList.end() )
{
f = (*iterDbgOpen)->f;
g = (*iterDbgOpen)->g;
h = (*iterDbgOpen)->h;
return &(*iterDbgOpen)->m_UserState;
}
return NULL;
}
UserState *GetOpenListNext()
{
float f,g,h;
return GetOpenListNext( f,g,h );
}
UserState *GetOpenListNext( float &f, float &g, float &h )
{
iterDbgOpen++;
if( iterDbgOpen != m_OpenList.end() )
{
f = (*iterDbgOpen)->f;
g = (*iterDbgOpen)->g;
h = (*iterDbgOpen)->h;
return &(*iterDbgOpen)->m_UserState;
}
return NULL;
}
UserState *GetClosedListStart()
{
float f,g,h;
return GetClosedListStart( f,g,h );
}
UserState *GetClosedListStart( float &f, float &g, float &h )
{
iterDbgClosed = m_ClosedList.begin();
if( iterDbgClosed != m_ClosedList.end() )
{
f = (*iterDbgClosed)->f;
g = (*iterDbgClosed)->g;
h = (*iterDbgClosed)->h;
return &(*iterDbgClosed)->m_UserState;
}
return NULL;
}
UserState *GetClosedListNext()
{
float f,g,h;
return GetClosedListNext( f,g,h );
}
UserState *GetClosedListNext( float &f, float &g, float &h )
{
iterDbgClosed++;
if( iterDbgClosed != m_ClosedList.end() )
{
f = (*iterDbgClosed)->f;
g = (*iterDbgClosed)->g;
h = (*iterDbgClosed)->h;
return &(*iterDbgClosed)->m_UserState;
}
return NULL;
}
// Get the number of steps
int GetStepCount() { return m_Steps; }
void EnsureMemoryFreed()
{
#if USE_FSA_MEMORY
assert(m_AllocateNodeCount == 0);
#endif
}
private: // methods
// This is called when a search fails or is cancelled to free all used
// memory
void FreeAllNodes()
{
// iterate open list and delete all nodes
typename vector< Node * >::iterator iterOpen = m_OpenList.begin();
while( iterOpen != m_OpenList.end() )
{
Node *n = (*iterOpen);
FreeNode( n );
iterOpen ++;
}
m_OpenList.clear();
// iterate closed list and delete all nodes
typename vector< Node * >::iterator iterClosed;
for( iterClosed = m_ClosedList.begin(); iterClosed != m_ClosedList.end(); iterClosed ++ )
{
Node *n = (*iterClosed);
FreeNode( n );
}
m_ClosedList.clear();
// delete the goal
FreeNode(m_Goal);
}
// This is called by the search class when the search ends successfully. Many nodes
// may have been created during the search that are not part of the solution path;
// this routine deletes them.
void FreeUnusedNodes()
{
// iterate open list and delete unused nodes
typename vector< Node * >::iterator iterOpen = m_OpenList.begin();
while( iterOpen != m_OpenList.end() )
{
Node *n = (*iterOpen);
if( !n->child )
{
FreeNode( n );
n = NULL;
}
iterOpen ++;
}
m_OpenList.clear();
// iterate closed list and delete unused nodes
typename vector< Node * >::iterator iterClosed;
for( iterClosed = m_ClosedList.begin(); iterClosed != m_ClosedList.end(); iterClosed ++ )
{
Node *n = (*iterClosed);
if( !n->child )
{
FreeNode( n );
n = NULL;
}
}
m_ClosedList.clear();
}
// Node memory management
Node *AllocateNode()
{
#if !USE_FSA_MEMORY
m_AllocateNodeCount ++;
Node *p = new Node;
return p;
#else
Node *address = m_FixedSizeAllocator.alloc();
if( !address )
{
return NULL;
}
m_AllocateNodeCount ++;
Node *p = new (address) Node;
return p;
#endif
}
void FreeNode( Node *node )
{
m_AllocateNodeCount --;
#if !USE_FSA_MEMORY
delete node;
#else
node->~Node();
m_FixedSizeAllocator.free( node );
#endif
}
private: // data
// Heap (simple vector but used as a heap, cf. Steve Rabin's game gems article)
vector< Node *> m_OpenList;
// Closed list is a vector.
vector< Node * > m_ClosedList;
// Successors is a vector filled out by the user each time successors to a node
// are generated
vector< Node * > m_Successors;
// State
unsigned int m_State;
// Counts steps
int m_Steps;
// Start and goal state pointers
Node *m_Start;
Node *m_Goal;
Node *m_CurrentSolutionNode;
#if USE_FSA_MEMORY
// Memory
FixedSizeAllocator<Node> m_FixedSizeAllocator;
#endif
//Debug : need to keep these two iterators around
// for the user Dbg functions
typename vector< Node * >::iterator iterDbgOpen;
typename vector< Node * >::iterator iterDbgClosed;
// debugging : count memory allocation and free's
int m_AllocateNodeCount;
bool m_CancelRequest;
};
template <class T> class AStarState
{
public:
virtual ~AStarState() {}
virtual float GoalDistanceEstimate( T &nodeGoal ) = 0; // Heuristic function which computes the estimated cost to the goal node
virtual bool IsGoal( T &nodeGoal ) = 0; // Returns true if this node is the goal node
virtual bool GetSuccessors( AStarSearch<T> *astarsearch, T *parent_node ) = 0; // Retrieves all successors to this node and adds them via astarsearch->AddSuccessor()
virtual float GetCost( T &successor ) = 0; // Computes the cost of travelling from this node to the successor node
virtual bool IsSameState( T &rhs ) = 0; // Returns true if this node is the same as the rhs node
};
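// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original header, kept inside
// "#if 0" so it is never compiled): a minimal 4-connected grid state showing
// how a user class typically implements the AStarState interface and drives
// AStarSearch. GRID_W, GRID_H and the World() cost lookup are assumptions
// invented for this sketch.
#if 0
static const int GRID_W = 20;
static const int GRID_H = 20;
int World( int x, int y ); // terrain cost of a cell; >= 9 is treated as impassable

class GridState : public AStarState<GridState>
{
public:
	int x, y;
	GridState() : x(0), y(0) {}
	GridState( int px, int py ) : x(px), y(py) {}
	// Manhattan distance: admissible for 4-connected movement with step cost >= 1
	float GoalDistanceEstimate( GridState &goal )
	{
		int dx = x - goal.x; if( dx < 0 ) dx = -dx;
		int dy = y - goal.y; if( dy < 0 ) dy = -dy;
		return (float)( dx + dy );
	}
	bool IsGoal( GridState &goal ) { return IsSameState( goal ); }
	bool IsSameState( GridState &rhs ) { return x == rhs.x && y == rhs.y; }
	float GetCost( GridState &successor ) { return (float) World( successor.x, successor.y ); }
	bool GetSuccessors( AStarSearch<GridState> *search, GridState *parent )
	{
		const int dx[4] = { 1, -1, 0, 0 };
		const int dy[4] = { 0, 0, 1, -1 };
		for( int i = 0; i < 4; i++ )
		{
			int nx = x + dx[i], ny = y + dy[i];
			if( nx < 0 || ny < 0 || nx >= GRID_W || ny >= GRID_H ) continue;
			if( World( nx, ny ) >= 9 ) continue;                          // blocked cell
			if( parent && parent->x == nx && parent->y == ny ) continue;  // don't step straight back
			GridState s( nx, ny );
			if( !search->AddSuccessor( s ) ) return false;                // out of memory
		}
		return true;
	}
};

// Typical driving loop:
//   AStarSearch<GridState> search;
//   GridState start( 0, 0 ), goal( 10, 10 );
//   search.SetStartAndGoalStates( start, goal );
//   unsigned int state;
//   do { state = search.SearchStep(); }
//   while( state == AStarSearch<GridState>::SEARCH_STATE_SEARCHING );
//   if( state == AStarSearch<GridState>::SEARCH_STATE_SUCCEEDED )
//   {
//       for( GridState *s = search.GetSolutionStart(); s; s = search.GetSolutionNext() )
//       { /* visit s->x, s->y */ }
//       search.FreeSolutionNodes();
//   }
//   search.EnsureMemoryFreed();
#endif
// ---------------------------------------------------------------------------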
#endif

@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

@ -0,0 +1,105 @@
# MT-YOLOv6 [About Naming YOLOv6](./docs/About_naming_yolov6.md)
## Introduction
YOLOv6 is a single-stage object detection framework dedicated to industrial applications, with a hardware-friendly, efficient design and high performance.
<img src="assets/picture.png" width="800">
YOLOv6-nano achieves 35.0 mAP on the COCO val2017 dataset at 1242 FPS, and YOLOv6-s achieves 43.1 mAP at 520 FPS, both measured on a T4 with TensorRT FP16 batch-32 inference.
YOLOv6 is composed of the following methods:
- Hardware-friendly Design for Backbone and Neck
- Efficient Decoupled Head with SIoU Loss
## Coming soon
- [ ] YOLOv6 m/l/x models.
- [ ] Deployment for MNN/TNN/NCNN/CoreML...
- [ ] Quantization tools
## Quick Start
### Install
```shell
git clone https://github.com/meituan/YOLOv6
cd YOLOv6
pip install -r requirements.txt
```
### Inference
First, download a pretrained model from the YOLOv6 [release](https://github.com/meituan/YOLOv6/releases/tag/0.1.0) page.
Second, run inference with `tools/infer.py`:
```shell
python tools/infer.py --weights yolov6s.pt --source img.jpg / imgdir
                                yolov6n.pt
```
### Training
Single GPU
```shell
python tools/train.py --batch 32 --conf configs/yolov6s.py --data data/coco.yaml --device 0
                                        configs/yolov6n.py
```
Multi GPUs (DDP mode recommended)
```shell
python -m torch.distributed.launch --nproc_per_node 8 tools/train.py --batch 256 --conf configs/yolov6s.py --data data/coco.yaml --device 0,1,2,3,4,5,6,7
                                                                                        configs/yolov6n.py
```
- conf: select the config file that specifies the network, optimizer and hyperparameters
- data: prepare the [COCO](http://cocodataset.org) dataset and specify the dataset paths in data.yaml
### Evaluation
Reproduce mAP on the COCO val2017 dataset:
```shell
python tools/eval.py --data data/coco.yaml --batch 32 --weights yolov6s.pt --task val
                                                                yolov6n.pt
```
### Deployment
* [ONNX](./deploy/ONNX)
* [OpenVINO](./deploy/OpenVINO)
### Tutorials
* [Train custom data](./docs/Train_custom_data.md)
* [Test speed](./docs/Test_speed.md)
## Benchmark
| Model | Size | mAP<sup>val<br/>0.5:0.95 | Speed<sup>V100<br/>fp16 b32 <br/>(ms) | Speed<sup>V100<br/>fp32 b32 <br/>(ms) | Speed<sup>T4<br/>trt fp16 b1 <br/>(fps) | Speed<sup>T4<br/>trt fp16 b32 <br/>(fps) | Params<br/><sup> (M) | Flops<br/><sup> (G) |
| :-------------- | ----------- | :----------------------- | :------------------------------------ | :------------------------------------ | ---------------------------------------- | ----------------------------------------- | --------------- | -------------- |
| [**YOLOv6-n**](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n.pt) | 416<br/>640 | 30.8<br/>35.0 | 0.3<br/>0.5 | 0.4<br/>0.7 | 1100<br/>788 | 2716<br/>1242 | 4.3<br/>4.3 | 4.7<br/>11.1 |
| [**YOLOv6-tiny**](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6t.pt) | 640 | 41.3 | 0.9 | 1.5 | 425 | 602 | 15.0 | 36.7 |
| [**YOLOv6-s**](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6s.pt) | 640 | 43.1 | 1.0 | 1.7 | 373 | 520 | 17.2 | 44.2 |
- mAP and speed comparisons of different object detectors are measured on the [COCO val2017](https://cocodataset.org/#download) dataset.
- Refer to the [Test speed](./docs/Test_speed.md) tutorial to reproduce the speed results of YOLOv6.
- Params and Flops of YOLOv6 are estimated on the deployed models.
- Speed results of other methods are tested in our environment using the official codebase and models when they are not reported in the corresponding official release.
## Third-party resources
* YOLOv6 NCNN Android app demo: [ncnn-android-yolov6](https://github.com/FeiGeChuanShu/ncnn-android-yolov6) from [FeiGeChuanShu](https://github.com/FeiGeChuanShu)
* YOLOv6 ONNXRuntime/MNN/TNN C++: [YOLOv6-ORT](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/ort/cv/yolov6.cpp), [YOLOv6-MNN](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_yolov6.cpp) and [YOLOv6-TNN](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_yolov6.cpp) from [DefTruth](https://github.com/DefTruth)
* YOLOv6 TensorRT Python: [yolov6-tensorrt-python](https://github.com/Linaom1214/tensorrt-python/blob/main/yolov6/trt.py) from [Linaom1214](https://github.com/Linaom1214)
* YOLOv6 TensorRT Windows C++: [yolort](https://github.com/zhiqwang/yolov5-rt-stack/tree/main/deployment/tensorrt-yolov6) from [Wei Zeng](https://github.com/Wulingtian)

Binary image file added (not shown); size 517 KiB.
@ -0,0 +1,53 @@
# YOLOv6t model
model = dict(
type='YOLOv6t',
pretrained=None,
depth_multiple=0.25,
width_multiple=0.50,
backbone=dict(
type='EfficientRep',
num_repeats=[1, 6, 12, 18, 6],
out_channels=[64, 128, 256, 512, 1024],
),
neck=dict(
type='RepPAN',
num_repeats=[12, 12, 12, 12],
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
anchors=1,
out_indices=[17, 20, 23],
strides=[8, 16, 32],
iou_type='ciou'
)
)
solver = dict(
optim='SGD',
lr_scheduler='Cosine',
lr0=0.01,
lrf=0.01,
momentum=0.937,
weight_decay=0.0005,
warmup_epochs=3.0,
warmup_momentum=0.8,
warmup_bias_lr=0.1
)
data_aug = dict(
hsv_h=0.015,
hsv_s=0.7,
hsv_v=0.4,
degrees=0.0,
translate=0.1,
scale=0.5,
shear=0.0,
flipud=0.0,
fliplr=0.5,
mosaic=1.0,
mixup=0.0,
)
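
# --- Illustrative sketch (not part of the original config) ------------------
# depth_multiple / width_multiple above scale the per-stage num_repeats and
# out_channels when the model is built. The exact rounding rules live in
# YOLOv6's model builder; scale_config below only assumes the conventional
# YOLO-style scaling (round repeats, snap channels to a multiple of 8) to show
# roughly what the 0.25 / 0.50 multipliers of YOLOv6t imply.
import math

def scale_config(num_repeats, out_channels, depth_multiple, width_multiple):
    # Stages with a single block are left alone; others are rounded, minimum 1.
    repeats = [max(round(r * depth_multiple), 1) if r > 1 else r for r in num_repeats]
    # Channel counts are snapped up to the nearest multiple of 8.
    channels = [math.ceil(c * width_multiple / 8) * 8 for c in out_channels]
    return repeats, channels

# Example with the backbone values of this config:
# scale_config([1, 6, 12, 18, 6], [64, 128, 256, 512, 1024], 0.25, 0.50)
# -> ([1, 2, 3, 4, 2], [32, 64, 128, 256, 512])
# ----------------------------------------------------------------------------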

@ -0,0 +1,53 @@
# YOLOv6t model
model = dict(
type='YOLOv6t',
pretrained='./weights/yolov6t.pt',
depth_multiple=0.25,
width_multiple=0.50,
backbone=dict(
type='EfficientRep',
num_repeats=[1, 6, 12, 18, 6],
out_channels=[64, 128, 256, 512, 1024],
),
neck=dict(
type='RepPAN',
num_repeats=[12, 12, 12, 12],
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
anchors=1,
out_indices=[17, 20, 23],
strides=[8, 16, 32],
iou_type='ciou'
)
)
solver = dict(
optim='SGD',
lr_scheduler='Cosine',
lr0=0.0032,
lrf=0.12,
momentum=0.843,
weight_decay=0.00036,
warmup_epochs=2.0,
warmup_momentum=0.5,
warmup_bias_lr=0.05
)
data_aug = dict(
hsv_h=0.0138,
hsv_s=0.664,
hsv_v=0.464,
degrees=0.373,
translate=0.245,
scale=0.898,
shear=0.602,
flipud=0.00856,
fliplr=0.5,
mosaic=1.0,
mixup=0.243,
)

@ -0,0 +1,53 @@
# YOLOv6n model
model = dict(
type='YOLOv6n',
pretrained=None,
depth_multiple=0.33,
width_multiple=0.25,
backbone=dict(
type='EfficientRep',
num_repeats=[1, 6, 12, 18, 6],
out_channels=[64, 128, 256, 512, 1024],
),
neck=dict(
type='RepPAN',
num_repeats=[12, 12, 12, 12],
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
anchors=1,
out_indices=[17, 20, 23],
strides=[8, 16, 32],
iou_type='ciou'
)
)
solver = dict(
optim='SGD',
lr_scheduler='Cosine',
lr0=0.01,
lrf=0.01,
momentum=0.937,
weight_decay=0.0005,
warmup_epochs=3.0,
warmup_momentum=0.8,
warmup_bias_lr=0.1
)
data_aug = dict(
hsv_h=0.015,
hsv_s=0.7,
hsv_v=0.4,
degrees=0.0,
translate=0.1,
scale=0.5,
shear=0.0,
flipud=0.0,
fliplr=0.5,
mosaic=1.0,
mixup=0.0,
)

@ -0,0 +1,53 @@
# YOLOv6n model
model = dict(
type='YOLOv6n',
pretrained='./weights/yolov6n.pt',
depth_multiple=0.33,
width_multiple=0.25,
backbone=dict(
type='EfficientRep',
num_repeats=[1, 6, 12, 18, 6],
out_channels=[64, 128, 256, 512, 1024],
),
neck=dict(
type='RepPAN',
num_repeats=[12, 12, 12, 12],
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
anchors=1,
out_indices=[17, 20, 23],
strides=[8, 16, 32],
iou_type='ciou'
)
)
solver = dict(
optim='SGD',
lr_scheduler='Cosine',
lr0=0.0032,
lrf=0.12,
momentum=0.843,
weight_decay=0.00036,
warmup_epochs=2.0,
warmup_momentum=0.5,
warmup_bias_lr=0.05
)
data_aug = dict(
hsv_h=0.0138,
hsv_s=0.664,
hsv_v=0.464,
degrees=0.373,
translate=0.245,
scale=0.898,
shear=0.602,
flipud=0.00856,
fliplr=0.5,
mosaic=1.0,
mixup=0.243
)

@ -0,0 +1,53 @@
# YOLOv6s model
model = dict(
type='YOLOv6s',
pretrained=None,
depth_multiple=0.33,
width_multiple=0.50,
backbone=dict(
type='EfficientRep',
num_repeats=[1, 6, 12, 18, 6],
out_channels=[64, 128, 256, 512, 1024],
),
neck=dict(
type='RepPAN',
num_repeats=[12, 12, 12, 12],
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
anchors=1,
out_indices=[17, 20, 23],
strides=[8, 16, 32],
iou_type='siou'
)
)
solver = dict(
optim='SGD',
lr_scheduler='Cosine',
lr0=0.01,
lrf=0.01,
momentum=0.937,
weight_decay=0.0005,
warmup_epochs=3.0,
warmup_momentum=0.8,
warmup_bias_lr=0.1
)
data_aug = dict(
hsv_h=0.015,
hsv_s=0.7,
hsv_v=0.4,
degrees=0.0,
translate=0.1,
scale=0.5,
shear=0.0,
flipud=0.0,
fliplr=0.5,
mosaic=1.0,
mixup=0.0,
)

@ -0,0 +1,53 @@
# YOLOv6s model
model = dict(
type='YOLOv6s',
pretrained='./weights/yolov6s.pt',
depth_multiple=0.33,
width_multiple=0.50,
backbone=dict(
type='EfficientRep',
num_repeats=[1, 6, 12, 18, 6],
out_channels=[64, 128, 256, 512, 1024],
),
neck=dict(
type='RepPAN',
num_repeats=[12, 12, 12, 12],
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
anchors=1,
out_indices=[17, 20, 23],
strides=[8, 16, 32],
iou_type='siou'
)
)
solver = dict(
optim='SGD',
lr_scheduler='Cosine',
lr0=0.0032,
lrf=0.12,
momentum=0.843,
weight_decay=0.00036,
warmup_epochs=2.0,
warmup_momentum=0.5,
warmup_bias_lr=0.05
)
data_aug = dict(
hsv_h=0.0138,
hsv_s=0.664,
hsv_v=0.464,
degrees=0.373,
translate=0.245,
scale=0.898,
shear=0.602,
flipud=0.00856,
fliplr=0.5,
mosaic=1.0,
mixup=0.243,
)

@ -0,0 +1,20 @@
# COCO 2017 dataset http://cocodataset.org
train: ../coco/images/train2017 # 118287 images
val: ../coco/images/val2017 # 5000 images
test: ../coco/images/test2017
anno_path: ../coco/annotations/instances_val2017.json
# number of classes
nc: 80
# Whether this is the COCO dataset; set to True only for the COCO dataset.
is_coco: True
# class names
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush' ]

@ -0,0 +1,11 @@
# Please ensure that your custom_dataset directory is placed in the same parent directory as YOLOv6_DIR
train: ../custom_dataset/images/train # train images
val: ../custom_dataset/images/val # val images
test: ../custom_dataset/images/test # test images (optional)
# Whether this is the COCO dataset; set to True only for the COCO dataset.
is_coco: False
# Classes
nc: 20 # number of classes
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] # class names
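
A quick sanity check of a dataset yaml like the one above before launching training (an illustrative sketch, not part of the repo; the file path is an assumption):

```python
# Minimal check for a YOLOv6-style dataset yaml (illustrative sketch).
import os
import yaml  # pip install pyyaml

with open("../custom_dataset/custom.yaml") as f:  # path is an assumption
    cfg = yaml.safe_load(f)

# nc must agree with the number of class names, and the split dirs should exist.
assert cfg["nc"] == len(cfg["names"]), "nc does not match len(names)"
for split in ("train", "val", "test"):
    path = cfg.get(split)
    print(split, path, "ok" if path and os.path.exists(path) else "missing")
```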

Binary image file added (not shown); size 79 KiB.
Binary image file added (not shown); size 140 KiB.
Binary image file added (not shown); size 115 KiB.
@ -0,0 +1,98 @@
# Export ONNX Model
## Check requirements
```shell
pip install "onnx>=1.10.0"
```
## Export script
```shell
python ./deploy/ONNX/export_onnx.py \
--weights yolov6s.pt \
--img 640 \
--batch 1
```
#### Description of all arguments
- `--weights` : The path of yolov6 model weights.
- `--img` : Image size of model inputs.
- `--batch` : Batch size of model inputs.
- `--half` : Whether to export half-precision model.
- `--inplace` : Whether to set Detect() inplace.
- `--simplify` : Whether to simplify the ONNX model. Not supported in end-to-end export.
- `--end2end` : Whether to export an end-to-end ONNX model. Only onnxruntime and TensorRT >= 8.0.0 are supported.
- `--max-wh` : Default is None for the TensorRT backend; set an integer for the onnxruntime backend.
- `--topk-all` : Top-k objects to keep for every image.
- `--iou-thres` : IoU threshold for the NMS algorithm.
- `--conf-thres` : Confidence threshold for the NMS algorithm.
- `--device` : Export device. CUDA device: 0 or 0,1,2,3 etc.; CPU: cpu.
## Download
* [YOLOv6-nano](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n.onnx)
* [YOLOv6-tiny](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6t.onnx)
* [YOLOv6-s](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6s.onnx)
## End2End export
YOLOv6 now supports end-to-end detection for onnxruntime and TensorRT!
If you want to deploy with TensorRT, make sure you have installed TensorRT >= 8.0.0.
### onnxruntime backend
#### Usage
```bash
python ./deploy/ONNX/export_onnx.py \
--weights yolov6s.pt \
--img 640 \
--batch 1 \
--end2end \
--max-wh 7680
```
You will get an ONNX model with a **NonMaxSuppression** operator.
The ONNX output shape is `nums x 7`.
`nums` is the total number of detected objects.
`7` means [`batch_index`, `x0`, `y0`, `x1`, `y1`, `classid`, `score`].
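For reference, a minimal sketch of reading this output with onnxruntime (assuming the exported file is `yolov6s.onnx`, an input already pre-processed to `1x3x640x640`, and the input name `image_arrays` used by `export_onnx.py`):
```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("yolov6s.onnx", providers=["CPUExecutionProvider"])
img = np.zeros((1, 3, 640, 640), dtype=np.float32)  # replace with a real pre-processed image
dets = sess.run(None, {"image_arrays": img})[0]      # shape: nums x 7
for batch_index, x0, y0, x1, y1, classid, score in dets:
    print(int(batch_index), int(classid), float(score), (x0, y0, x1, y1))
```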
### TensorRT backend (TensorRT version>= 8.0.0)
#### Usage
```bash
python ./deploy/ONNX/export_onnx.py \
--weights yolov6s.pt \
--img 640 \
--batch 1 \
--end2end
```
You will get an ONNX model with an **[EfficientNMS_TRT](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin)** plugin.
The ONNX outputs are as shown:
<img src="https://user-images.githubusercontent.com/92794867/176650971-a4fa3d65-10d4-4b65-b8ef-00a2ff13406c.png" height="300px" />
`num_dets` is the number of objects in each image of the batch.
`det_boxes` holds the [`x0`, `y0`, `x1`, `y1`] locations of the top-k (100) objects.
`det_scores` holds the confidence scores of the top-k (100) objects.
`det_classes` holds the class ids of the top-k (100) objects.
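Once these four outputs are copied back to the host as arrays, the valid detections of image `i` can be sliced out as in the following sketch (illustrative only, independent of any particular TensorRT runtime wrapper):
```python
def gather_detections(num_dets, det_boxes, det_scores, det_classes, i=0):
    # Shapes: num_dets (batch, 1), det_boxes (batch, 100, 4),
    # det_scores (batch, 100), det_classes (batch, 100).
    n = int(num_dets[i][0])
    return det_boxes[i, :n], det_scores[i, :n], det_classes[i, :n]
```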
You can export a TensorRT engine using the [trtexec](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#trtexec-ovr) tool.
#### Usage
``` shell
/path/to/trtexec \
--onnx=yolov6s.onnx \
--saveEngine=yolov6s.engine \
--fp16 # if export TensorRT fp16 model
```

@ -0,0 +1,112 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import time
import sys
import os
import torch
import torch.nn as nn
import onnx
ROOT = os.getcwd()
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
from yolov6.models.yolo import *
from yolov6.models.effidehead import Detect
from yolov6.layers.common import *
from yolov6.utils.events import LOGGER
from yolov6.utils.checkpoint import load_checkpoint
from io import BytesIO
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolov6s.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
parser.add_argument('--end2end', action='store_true', help='export end2end onnx')
parser.add_argument('--max-wh', type=int, default=None, help='None for trt int for ort')
parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images')
parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS')
parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS')
parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
args = parser.parse_args()
args.img_size *= 2 if len(args.img_size) == 1 else 1 # expand
print(args)
t = time.time()
# Check device
cuda = args.device != 'cpu' and torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'
# Load PyTorch model
model = load_checkpoint(args.weights, map_location=device, inplace=True, fuse=True) # load FP32 model
for layer in model.modules():
if isinstance(layer, RepVGGBlock):
layer.switch_to_deploy()
# Input
img = torch.zeros(args.batch_size, 3, *args.img_size).to(device) # image size(1,3,320,192) iDetection
# Update model
if args.half:
img, model = img.half(), model.half() # to FP16
model.eval()
for k, m in model.named_modules():
if isinstance(m, Conv): # assign export-friendly activations
if isinstance(m.act, nn.SiLU):
m.act = SiLU()
elif isinstance(m, Detect):
m.inplace = args.inplace
if args.end2end:
from yolov6.models.end2end import End2End
model = End2End(model, max_obj=args.topk_all, iou_thres=args.iou_thres,
score_thres=args.conf_thres, max_wh=args.max_wh, device=device)
y = model(img) # dry run
# ONNX export
try:
LOGGER.info('\nStarting to export ONNX...')
export_file = args.weights.replace('.pt', '.onnx') # filename
with BytesIO() as f:
torch.onnx.export(model, img, f, verbose=False, opset_version=12,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
input_names=['image_arrays'],
output_names=['num_dets', 'det_boxes', 'det_scores', 'det_classes']
if args.end2end and args.max_wh is None else ['outputs'],)
f.seek(0)
# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
# Fix output shape
if args.end2end and args.max_wh is None:
shapes = [args.batch_size, 1, args.batch_size, args.topk_all, 4,
args.batch_size, args.topk_all, args.batch_size, args.topk_all]
for i in onnx_model.graph.output:
for j in i.type.tensor_type.shape.dim:
j.dim_param = str(shapes.pop(0))
if args.simplify:
try:
import onnxsim
LOGGER.info('\nStarting to simplify ONNX...')
onnx_model, check = onnxsim.simplify(onnx_model)
assert check, 'assert check failed'
except Exception as e:
LOGGER.info(f'Simplifier failure: {e}')
onnx.save(onnx_model, export_file)
LOGGER.info(f'ONNX export success, saved as {export_file}')
except Exception as e:
LOGGER.info(f'ONNX export failure: {e}')
# Finish
LOGGER.info('\nExport complete (%.2fs)' % (time.time() - t))
if args.end2end:
if args.max_wh is None:
LOGGER.info('\nYou can export tensorrt engine use trtexec tools.\nCommand is:')
LOGGER.info(f'trtexec --onnx={export_file} --saveEngine={export_file.replace(".onnx",".engine")}')

@ -0,0 +1,24 @@
## Export OpenVINO Model
### Check requirements
```shell
pip install --upgrade pip
pip install openvino-dev
```
### Export script
```shell
python deploy/OpenVINO/export_openvino.py --weights yolov6s.pt --img 640 --batch 1
```
### Download
* [YOLOv6-nano](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n_openvino.tar.gz)
* [YOLOv6-tiny](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n_openvino.tar.gz)
* [YOLOv6-s](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n_openvino.tar.gz)
### Speed test
```shell
benchmark_app -m yolov6s_openvino/yolov6s.xml -i data/images/image1.jpg -d CPU -niter 100 -progress
```
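To sanity-check the exported IR outside of `benchmark_app`, here is a minimal sketch using the OpenVINO Python runtime (assuming the export above produced `yolov6s_openvino/yolov6s.xml`; the raw outputs still require NMS post-processing):
```python
import numpy as np
from openvino.runtime import Core  # installed with the openvino-dev package

core = Core()
compiled = core.compile_model(core.read_model("yolov6s_openvino/yolov6s.xml"), "CPU")
img = np.zeros((1, 3, 640, 640), dtype=np.float32)  # replace with a real pre-processed image
outputs = compiled([img])
print({out.get_any_name(): outputs[out].shape for out in compiled.outputs})
```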

@ -0,0 +1,92 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import time
import sys
import os
import torch
import torch.nn as nn
import onnx
import subprocess
ROOT = os.getcwd()
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
from yolov6.models.yolo import *
from yolov6.models.effidehead import Detect
from yolov6.layers.common import *
from yolov6.utils.events import LOGGER
from yolov6.utils.checkpoint import load_checkpoint
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolov6s.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
args = parser.parse_args()
args.img_size *= 2 if len(args.img_size) == 1 else 1 # expand
print(args)
t = time.time()
# Check device
cuda = args.device != 'cpu' and torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'
# Load PyTorch model
model = load_checkpoint(args.weights, map_location=device, inplace=True, fuse=True) # load FP32 model
for layer in model.modules():
if isinstance(layer, RepVGGBlock):
layer.switch_to_deploy()
# Input
img = torch.zeros(args.batch_size, 3, *args.img_size).to(device) # image size(1,3,320,192) iDetection
# Update model
if args.half:
img, model = img.half(), model.half() # to FP16
model.eval()
for k, m in model.named_modules():
if isinstance(m, Conv): # assign export-friendly activations
if isinstance(m.act, nn.SiLU):
m.act = SiLU()
elif isinstance(m, Detect):
m.inplace = args.inplace
y = model(img) # dry run
# ONNX export
try:
LOGGER.info('\nStarting to export ONNX...')
export_file = args.weights.replace('.pt', '.onnx') # filename
torch.onnx.export(model, img, export_file, verbose=False, opset_version=12,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
input_names=['image_arrays'],
output_names=['outputs'],
)
# Checks
onnx_model = onnx.load(export_file) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
LOGGER.info(f'ONNX export success, saved as {export_file}')
except Exception as e:
LOGGER.info(f'ONNX export failure: {e}')
# OpenVINO export
try:
LOGGER.info('\nStarting to export OpenVINO...')
import_file = args.weights.replace('.pt', '.onnx')
export_dir = str(import_file).replace('.onnx', '_openvino')
cmd = f"mo --input_model {import_file} --output_dir {export_dir} --data_type {'FP16' if args.half else 'FP32'}"
subprocess.check_output(cmd.split())
LOGGER.info(f'OpenVINO export success, saved as {export_dir}')
except Exception as e:
LOGGER.info(f'OpenVINO export failure: {e}')
# Finish
LOGGER.info('\nExport complete (%.2fs)' % (time.time() - t))

@ -0,0 +1,9 @@
# About the naming of YOLOv6
### Why is it named YOLOv6?
The full name is actually MT-YOLOv6, which is shortened to YOLOv6 for brevity. Our work is mainly inspired by the original idea of the one-stage YOLO detection algorithm, and the implementation leverages various techniques and tricks from earlier relevant work. We therefore named the project YOLOv6 to pay tribute to the YOLO series. Furthermore, we have indeed adopted some novel methods and made solid engineering improvements to dedicate the algorithm to industrial applications.
As for the project, we'll continue to improve and maintain it, contributing more value to industrial applications.
P.S. We are contacting the authors of the YOLO series about the naming of YOLOv6.
Thanks for your attention.

@ -0,0 +1,41 @@
# Test speed
This guidance explains how to reproduce the speed results of YOLOv6. For a fair comparison, the reported speed does not include the time cost of data pre-processing and NMS post-processing.
## 0. Prepare model
Download the models you want to test from the latest release.
## 1. Prepare testing environment
Refer to the README and install the packages corresponding to your CUDA, cuDNN and TensorRT versions.
Here, we use Torch 1.8.0 for inference on V100 and TensorRT 7.2 on T4.
## 2. Reproduce speed
#### 2.1 Torch Inference on V100
To get inference speed without TensorRT on V100, you can run the following command:
```shell
python tools/eval.py --data data/coco.yaml --batch 32 --weights yolov6n.pt --task speed [--half]
```
- Speed results with batch size = 1 are unstable across multiple runs, so we do not provide bs=1 speed results.
#### 2.2 TensorRT Inference on T4
To get inference speed with TensorRT in FP16 mode on T4, you can follow the steps below:
First, export the PyTorch model to ONNX format using the following command:
```shell
python deploy/ONNX/export_onnx.py --weights yolov6n.pt --device 0 --batch [1 or 32]
```
Second, generate a TensorRT inference engine and test its speed using `trtexec`:
```
trtexec --onnx=yolov6n.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw
```

@ -0,0 +1,143 @@
# Train Custom Data
This guidance explains how to train your own custom data with YOLOv6 (taking fine-tuning the YOLOv6-s model as an example).
## 0. Before you start
Clone this repo and follow README.md to install the requirements in a Python 3.8 environment.
## 1. Prepare your own dataset
**Step 1** Prepare your own dataset with images. For labeling images, you can use tools like [Labelme](https://github.com/wkentaro/labelme).
**Step 2** Generate label files in YOLO format.
One image corresponds to one label file, and an example of the label format is shown below.
```
# class_id center_x center_y bbox_width bbox_height
0 0.300926 0.617063 0.601852 0.765873
1 0.575 0.319531 0.4 0.551562
```
- Each row represents one object.
- Class id starts from `0`.
- Bounding box coordinates must be in normalized `xywh` format (from 0 to 1). If your boxes are in pixels, divide `center_x` and `bbox_width` by the image width, and `center_y` and `bbox_height` by the image height (see the conversion sketch below).
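A minimal conversion sketch (a hypothetical helper, assuming your pixel boxes are given as `(x_min, y_min, box_width, box_height)`):
```python
def to_yolo_xywh(x_min, y_min, box_w, box_h, img_w, img_h):
    # Convert a pixel-space box to normalized YOLO (center_x, center_y, width, height).
    center_x = (x_min + box_w / 2) / img_w
    center_y = (y_min + box_h / 2) / img_h
    return center_x, center_y, box_w / img_w, box_h / img_h

# e.g. to_yolo_xywh(100, 200, 300, 150, 1920, 1080) -> (0.1302, 0.2546, 0.1563, 0.1389), rounded
```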
**Step 3** Organize directories.
Organize your custom dataset directory as follows:
```shell
custom_dataset
├── images
│   ├── train
│   │   ├── train0.jpg
│   │   └── train1.jpg
│   ├── val
│   │   ├── val0.jpg
│   │   └── val1.jpg
│   └── test
│   ├── test0.jpg
│   └── test1.jpg
└── labels
├── train
│   ├── train0.txt
│   └── train1.txt
├── val
│   ├── val0.txt
│   └── val1.txt
└── test
├── test0.txt
└── test1.txt
```
**Step 4** Create `dataset.yaml` in `$YOLOv6_DIR/data`.
```yaml
# Please ensure that your custom_dataset is placed in the same parent dir as YOLOv6_DIR
train: ../custom_dataset/images/train # train images
val: ../custom_dataset/images/val # val images
test: ../custom_dataset/images/test # test images (optional)
# whether it is the COCO dataset; set to True only for the COCO dataset.
is_coco: False
# Classes
nc: 20 # number of classes
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] # class names
```
## 2. Create a config file
We use a config file to specify the network structure and training settings, including the optimizer and data augmentation hyperparameters.
If you create a new config file, please put it under the configs directory.
Or just use the provided config file in `$YOLOV6_HOME/configs/*_finetune.py`.
```python
## YOLOv6s Model config file
model = dict(
type='YOLOv6s',
pretrained='./weights/yolov6s.pt', # download the pretrained model from the YOLOv6 GitHub releases if using one
depth_multiple = 0.33,
width_multiple = 0.50,
...
)
solver=dict(
optim='SGD',
lr_scheduler='Cosine',
...
)
data_aug = dict(
hsv_h=0.015,
hsv_s=0.7,
hsv_v=0.4,
...
)
```
## 3. Train
Single GPU
```shell
python tools/train.py --batch 256 --conf configs/yolov6s_finetune.py --data data/data.yaml --device 0
```
Multiple GPUs (DDP mode recommended)
```shell
python -m torch.distributed.launch --nproc_per_node 4 tools/train.py --batch 256 --conf configs/yolov6s_finetune.py --data data/data.yaml --device 0,1,2,3
```
## 4. Evaluation
```shell
python tools/eval.py --data data/data.yaml --weights output_dir/name/weights/best_ckpt.pt --device 0
```
## 5. Inference
```shell
python tools/infer.py --weights output_dir/name/weights/best_ckpt.pt --source img.jpg --device 0
```
## 6. Deployment
Export as ONNX Format
```shell
python deploy/ONNX/export_onnx.py --weights output_dir/name/weights/best_ckpt.pt --device 0
```

Binary file not shown. (image added; 114 KiB)

Binary file not shown. (image added; 132 KiB)

@ -0,0 +1,16 @@
# pip install -r requirements.txt
# python3.8 environment
torch>=1.8.0
torchvision>=0.9.0
numpy>=1.18.5
opencv-python>=4.1.2
PyYAML>=5.3.1
scipy>=1.4.1
tqdm>=4.41.0
addict>=2.4.0
tensorboard>=2.7.0
pycocotools>=2.0
onnx>=1.10.0 # ONNX export
onnx-simplifier>=0.3.6 # ONNX simplifier
thop # FLOPs computation

Binary file not shown. (image added; 266 KiB)

Binary file not shown. (image added; 420 KiB)

@ -0,0 +1,93 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import os
import os.path as osp
import sys
import torch
ROOT = os.getcwd()
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
from yolov6.core.evaler import Evaler
from yolov6.utils.events import LOGGER
from yolov6.utils.general import increment_name
def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Evaluating', add_help=add_help)
parser.add_argument('--data', type=str, default='./data/coco.yaml', help='dataset.yaml path')
parser.add_argument('--weights', type=str, default='./weights/yolov6s.pt', help='model.pt path(s)')
parser.add_argument('--batch-size', type=int, default=32, help='batch size')
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.65, help='NMS IoU threshold')
parser.add_argument('--task', default='val', help='val, or speed')
parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--half', default=False, action='store_true', help='whether to use fp16 infer')
parser.add_argument('--save_dir', type=str, default='runs/val/', help='evaluation save dir')
parser.add_argument('--name', type=str, default='exp', help='save evaluation results to save_dir/name')
args = parser.parse_args()
LOGGER.info(args)
return args
@torch.no_grad()
def run(data,
weights=None,
batch_size=32,
img_size=640,
conf_thres=0.001,
iou_thres=0.65,
task='val',
device='',
half=False,
model=None,
dataloader=None,
save_dir='',
name = ''
):
""" Run the evaluation process
This function is the main process of evaluation, supporting image files and dirs containing images.
It has tasks of 'val', 'train' and 'speed'. Task 'train' processes the evaluation during the training phase.
Task 'val' evaluates the model purely and returns its mAP. Task 'speed' processes the
evaluation of the inference speed of the model.
"""
# task
Evaler.check_task(task)
if task == 'train':
save_dir = save_dir
else:
save_dir = str(increment_name(osp.join(save_dir, name)))
os.makedirs(save_dir, exist_ok=True)
# reload thres/device/half/data according task
conf_thres, iou_thres = Evaler.reload_thres(conf_thres, iou_thres, task)
device = Evaler.reload_device(device, model, task)
half = device.type != 'cpu' and half
data = Evaler.reload_dataset(data) if isinstance(data, str) else data
# init
val = Evaler(data, batch_size, img_size, conf_thres, \
iou_thres, device, half, save_dir)
model = val.init_model(model, weights, task)
dataloader = val.init_data(dataloader, task)
# eval
model.eval()
pred_result = val.predict_model(model, dataloader, task)
eval_result = val.eval_model(pred_result, model, dataloader, task)
return eval_result
def main(args):
run(**vars(args))
if __name__ == "__main__":
args = get_args_parser()
main(args)

@ -0,0 +1,108 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import os
import sys
import os.path as osp
import torch
ROOT = os.getcwd()
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
from yolov6.utils.events import LOGGER
from yolov6.core.inferer import Inferer
def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Inference.', add_help=add_help)
parser.add_argument('--weights', type=str, default='weights/yolov6s.pt', help='model path(s) for inference.')
parser.add_argument('--source', type=str, default='data/images', help='the source path, e.g. image-file/dir.')
parser.add_argument('--yaml', type=str, default='data/coco.yaml', help='data yaml file.')
parser.add_argument('--img-size', type=int, default=640, help='the image size (h,w) for inference.')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold for inference.')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold for inference.')
parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image.')
parser.add_argument('--device', default='0', help='device to run our model i.e. 0 or 0,1,2,3 or cpu.')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt.')
parser.add_argument('--save-img', action='store_false', help='save visualized inference results.')
parser.add_argument('--classes', nargs='+', type=int, help='filter by classes, e.g. --classes 0, or --classes 0 2 3.')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS.')
parser.add_argument('--project', default='runs/inference', help='save inference results to project/name.')
parser.add_argument('--name', default='exp', help='save inference results to project/name.')
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels.')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences.')
parser.add_argument('--half', action='store_true', help='whether to use FP16 half-precision inference.')
args = parser.parse_args()
LOGGER.info(args)
return args
@torch.no_grad()
def run(weights=osp.join(ROOT, 'yolov6s.pt'),
source=osp.join(ROOT, 'data/images'),
yaml=None,
img_size=640,
conf_thres=0.25,
iou_thres=0.45,
max_det=1000,
device='',
save_txt=False,
save_img=True,
classes=None,
agnostic_nms=False,
project=osp.join(ROOT, 'runs/inference'),
name='exp',
hide_labels=False,
hide_conf=False,
half=False,
):
""" Inference process
This function is the main process of inference, supporting image files or dirs containing images.
Args:
weights: The path of model.pt, e.g. yolov6s.pt
source: Source path, supporting image files or dirs containing images.
yaml: Data yaml file.
img_size: Inference image-size, e.g. 640
conf_thres: Confidence threshold in inference, e.g. 0.25
iou_thres: NMS IOU threshold in inference, e.g. 0.45
max_det: Maximal detections per image, e.g. 1000
device: CUDA device, e.g. 0, or 0,1,2,3, or cpu
save_txt: Save results to *.txt
save_img: Save visualized inference results
classes: Filter by class: --class 0, or --class 0 2 3
agnostic_nms: Class-agnostic NMS
project: Save results to project/name
name: Save results to project/name, e.g. 'exp'
line_thickness: Bounding box thickness (pixels), e.g. 3
hide_labels: Hide labels, e.g. False
hide_conf: Hide confidences
half: Use FP16 half-precision inference, e.g. False
"""
# create save dir
save_dir = osp.join(project, name)
if (save_img or save_txt) and not osp.exists(save_dir):
os.makedirs(save_dir)
else:
LOGGER.warning('Save directory already exists')
if save_txt:
os.mkdir(osp.join(save_dir, 'labels'))
# Inference
inferer = Inferer(source, weights, device, yaml, img_size, half)
inferer.infer(conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf)
if save_txt or save_img:
LOGGER.info(f"Results saved to {save_dir}")
def main(args):
run(**vars(args))
if __name__ == "__main__":
args = get_args_parser()
main(args)

@ -0,0 +1,210 @@
#
# Modified by Meituan
# 2022.6.24
#
# Copyright 2019 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import glob
import random
import logging
import cv2
import numpy as np
from PIL import Image
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
logging.basicConfig(level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger(__name__)
def preprocess_yolov6(image, channels=3, height=224, width=224):
"""Pre-processing for YOLOv6-based Object Detection Models
Parameters
----------
image: PIL.Image
The image resulting from PIL.Image.open(filename) to preprocess
channels: int
The number of channels the image has (Usually 1 or 3)
height: int
The desired height of the image (usually 640)
width: int
The desired width of the image (usually 640)
Returns
-------
img_data: numpy array
The preprocessed image data in the form of a numpy array
"""
# Get the image in CHW format
resized_image = image.resize((width, height), Image.BILINEAR)
img_data = np.asarray(resized_image).astype(np.float32)
if len(img_data.shape) == 2:
# For images without a channel dimension, we stack
img_data = np.stack([img_data] * 3)
logger.debug("Received grayscale image. Reshaped to {:}".format(img_data.shape))
else:
img_data = img_data.transpose([2, 0, 1])
mean_vec = np.array([0.0, 0.0, 0.0])
stddev_vec = np.array([1.0, 1.0, 1.0])
assert img_data.shape[0] == channels
for i in range(img_data.shape[0]):
# Scale each pixel to [0, 1] and normalize per channel.
img_data[i, :, :] = (img_data[i, :, :] / 255.0 - mean_vec[i]) / stddev_vec[i]
return img_data
def get_int8_calibrator(calib_cache, calib_data, max_calib_size, calib_batch_size):
# Use calibration cache if it exists
if os.path.exists(calib_cache):
logger.info("Skipping calibration files, using calibration cache: {:}".format(calib_cache))
calib_files = []
# Use calibration files from validation dataset if no cache exists
else:
if not calib_data:
raise ValueError("ERROR: Int8 mode requested, but no calibration data provided. Please provide --calibration-data /path/to/calibration/files")
calib_files = get_calibration_files(calib_data, max_calib_size)
# Choose pre-processing function for INT8 calibration
preprocess_func = preprocess_yolov6
int8_calibrator = ImageCalibrator(calibration_files=calib_files,
batch_size=calib_batch_size,
cache_file=calib_cache)
return int8_calibrator
def get_calibration_files(calibration_data, max_calibration_size=None, allowed_extensions=(".jpeg", ".jpg", ".png")):
"""Returns a list of all filenames ending with `allowed_extensions` found in the `calibration_data` directory.
Parameters
----------
calibration_data: str
Path to directory containing desired files.
max_calibration_size: int
Max number of files to use for calibration. If calibration_data contains more than this number,
a random sample of size max_calibration_size will be returned instead. If None, all samples will be used.
Returns
-------
calibration_files: List[str]
List of filenames contained in the `calibration_data` directory ending with `allowed_extensions`.
"""
logger.info("Collecting calibration files from: {:}".format(calibration_data))
calibration_files = [path for path in glob.iglob(os.path.join(calibration_data, "**"), recursive=True)
if os.path.isfile(path) and path.lower().endswith(allowed_extensions)]
logger.info("Number of Calibration Files found: {:}".format(len(calibration_files)))
if len(calibration_files) == 0:
raise Exception("ERROR: Calibration data path [{:}] contains no files!".format(calibration_data))
if max_calibration_size:
if len(calibration_files) > max_calibration_size:
logger.warning("Capping number of calibration images to max_calibration_size: {:}".format(max_calibration_size))
random.seed(42) # Set seed for reproducibility
calibration_files = random.sample(calibration_files, max_calibration_size)
return calibration_files
# https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/infer/Int8/EntropyCalibrator2.html
class ImageCalibrator(trt.IInt8EntropyCalibrator2):
"""INT8 Calibrator Class for Imagenet-based Image Classification Models.
Parameters
----------
calibration_files: List[str]
List of image filenames to use for INT8 Calibration
batch_size: int
Number of images to pass through in one batch during calibration
input_shape: Tuple[int]
Tuple of integers defining the shape of input to the model (Default: (3, 224, 224))
cache_file: str
Name of file to read/write calibration cache from/to.
preprocess_func: function -> numpy.ndarray
Pre-processing function to run on calibration data. This should match the pre-processing
done at inference time. In general, this function should return a numpy array of
shape `input_shape`.
"""
def __init__(self, calibration_files=[], batch_size=32, input_shape=(3, 224, 224),
cache_file="calibration.cache", use_cv2=False):
super().__init__()
self.input_shape = input_shape
self.cache_file = cache_file
self.batch_size = batch_size
self.batch = np.zeros((self.batch_size, *self.input_shape), dtype=np.float32)
self.device_input = cuda.mem_alloc(self.batch.nbytes)
self.files = calibration_files
self.use_cv2 = use_cv2
# Pad the list so it is a multiple of batch_size
if len(self.files) % self.batch_size != 0:
logger.info("Padding # calibration files to be a multiple of batch_size {:}".format(self.batch_size))
self.files += calibration_files[(len(calibration_files) % self.batch_size):self.batch_size]
self.batches = self.load_batches()
self.preprocess_func = preprocess_yolov6
def load_batches(self):
# Populates a persistent self.batch buffer with images.
for index in range(0, len(self.files), self.batch_size):
for offset in range(self.batch_size):
if self.use_cv2:
image = cv2.imread(self.files[index + offset])
else:
image = Image.open(self.files[index + offset])
self.batch[offset] = self.preprocess_func(image, *self.input_shape)
logger.info("Calibration images pre-processed: {:}/{:}".format(index+self.batch_size, len(self.files)))
yield self.batch
def get_batch_size(self):
return self.batch_size
def get_batch(self, names):
try:
# Assume self.batches is a generator that provides batch data.
batch = next(self.batches)
# Assume that self.device_input is a device buffer allocated by the constructor.
cuda.memcpy_htod(self.device_input, batch)
return [int(self.device_input)]
except StopIteration:
# When we're out of batches, we return either [] or None.
# This signals to TensorRT that there is no calibration data remaining.
return None
def read_calibration_cache(self):
# If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
if os.path.exists(self.cache_file):
with open(self.cache_file, "rb") as f:
logger.info("Using calibration cache to save time: {:}".format(self.cache_file))
return f.read()
def write_calibration_cache(self, cache):
with open(self.cache_file, "wb") as f:
logger.info("Caching calibration data for future use: {:}".format(self.cache_file))
f.write(cache)

@ -0,0 +1,191 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Copyright 2020 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@ -0,0 +1,83 @@
# ONNX -> TensorRT INT8
These scripts were last tested using the
[NGC TensorRT Container Version 20.06-py3](https://ngc.nvidia.com/catalog/containers/nvidia:tensorrt).
You can see the corresponding framework versions for this container [here](https://docs.nvidia.com/deeplearning/sdk/tensorrt-container-release-notes/rel_20.06.html#rel_20.06).
## Quickstart
> **NOTE**: This INT8 example is only valid for **fixed-shape** ONNX models at the moment.
>
> INT8 Calibration on **dynamic-shape** models is now supported; however, this example has not been updated
> to reflect that yet. For more details on INT8 Calibration for **dynamic-shape** models, please
> see the [documentation](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#int8-calib-dynamic-shapes).
### 1. Convert ONNX model to TensorRT INT8
See `./onnx_to_tensorrt.py -h` for full list of command line arguments.
```bash
./onnx_to_tensorrt.py --explicit-batch \
--onnx resnet50/model.onnx \
--fp16 \
--int8 \
--calibration-cache="caches/yolov6.cache" \
-o resnet50.int8.engine
```
See the [INT8 Calibration](#int8-calibration) section below for details on calibration
using your own model or different data, where you don't have an existing calibration cache
or want to create a new one.
## INT8 Calibration
See [ImagenetCalibrator.py](ImagenetCalibrator.py) for a reference implementation
of TensorRT's [IInt8EntropyCalibrator2](https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/infer/Int8/EntropyCalibrator2.html).
This class can be tweaked to work for other kinds of models, inputs, etc.
In the [Quickstart](#quickstart) section above, we made use of a pre-existing cache,
[caches/yolov6.cache](caches/yolov6.cache), to save time for the sake of an example.
However, to calibrate using different data or a different model, you can do so with the `--calibration-data` argument.
* This requires that you've mounted a dataset, such as Imagenet, to use for calibration.
* Add something like `-v /imagenet:/imagenet` to your Docker command in Step (1)
to mount a dataset found locally at `/imagenet`.
* You can specify your own `preprocess_func` by defining it inside of `ImageCalibrator.py`
```bash
# Path to dataset to use for calibration.
# **Not necessary if you already have a calibration cache from a previous run.
CALIBRATION_DATA="/imagenet"
# Truncate calibration images to a random sample of this amount if more are found.
# **Not necessary if you already have a calibration cache from a previous run.
MAX_CALIBRATION_SIZE=512
# Calibration cache to be used instead of calibration data if it already exists,
# or the cache will be created from the calibration data if it doesn't exist.
CACHE_FILENAME="caches/yolov6.cache"
# Path to ONNX model
ONNX_MODEL="model/yolov6.onnx"
# Path to write TensorRT engine to
OUTPUT="yolov6.int8.engine"
# Creates an int8 engine from your ONNX model, creating ${CACHE_FILENAME} based
# on your ${CALIBRATION_DATA}, unless ${CACHE_FILENAME} already exists, then
# it will simply use that instead.
python3 onnx_to_tensorrt.py --fp16 --int8 -v \
--max_calibration_size=${MAX_CALIBRATION_SIZE} \
--calibration-data=${CALIBRATION_DATA} \
--calibration-cache=${CACHE_FILENAME} \
--preprocess_func=${PREPROCESS_FUNC} \
--explicit-batch \
--onnx ${ONNX_MODEL} -o ${OUTPUT}
```
### Pre-processing
In order to calibrate your model correctly, you should `pre-process` your data the same way
that you would during inference.

@ -0,0 +1,220 @@
#!/usr/bin/env python3
#
# Modified by Meituan
# 2022.6.24
#
# Copyright 2019 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import glob
import math
import logging
import argparse
import tensorrt as trt
#sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
TRT_LOGGER = trt.Logger()
logging.basicConfig(level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger(__name__)
def add_profiles(config, inputs, opt_profiles):
logger.debug("=== Optimization Profiles ===")
for i, profile in enumerate(opt_profiles):
for inp in inputs:
_min, _opt, _max = profile.get_shape(inp.name)
logger.debug("{} - OptProfile {} - Min {} Opt {} Max {}".format(inp.name, i, _min, _opt, _max))
config.add_optimization_profile(profile)
def mark_outputs(network):
# Mark last layer's outputs if not already marked
# NOTE: This may not be correct in all cases
last_layer = network.get_layer(network.num_layers-1)
if not last_layer.num_outputs:
logger.error("Last layer contains no outputs.")
return
for i in range(last_layer.num_outputs):
network.mark_output(last_layer.get_output(i))
def check_network(network):
if not network.num_outputs:
logger.warning("No output nodes found, marking last layer's outputs as network outputs. Correct this if wrong.")
mark_outputs(network)
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
max_len = max([len(inp.name) for inp in inputs] + [len(out.name) for out in outputs])
logger.debug("=== Network Description ===")
for i, inp in enumerate(inputs):
logger.debug("Input {0} | Name: {1:{2}} | Shape: {3}".format(i, inp.name, max_len, inp.shape))
for i, out in enumerate(outputs):
logger.debug("Output {0} | Name: {1:{2}} | Shape: {3}".format(i, out.name, max_len, out.shape))
def get_batch_sizes(max_batch_size):
# Returns powers of 2, up to and including max_batch_size
max_exponent = math.log2(max_batch_size)
for i in range(int(max_exponent)+1):
batch_size = 2**i
yield batch_size
if max_batch_size != batch_size:
yield max_batch_size
# TODO: This only covers dynamic shape for batch size, not dynamic shape for other dimensions
def create_optimization_profiles(builder, inputs, batch_sizes=[1,8,16,32,64]):
# Check if all inputs are fixed explicit batch to create a single profile and avoid duplicates
if all([inp.shape[0] > -1 for inp in inputs]):
profile = builder.create_optimization_profile()
for inp in inputs:
fbs, shape = inp.shape[0], inp.shape[1:]
profile.set_shape(inp.name, min=(fbs, *shape), opt=(fbs, *shape), max=(fbs, *shape))
return [profile]
# Otherwise for mixed fixed+dynamic explicit batch inputs, create several profiles
profiles = {}
for bs in batch_sizes:
if not profiles.get(bs):
profiles[bs] = builder.create_optimization_profile()
for inp in inputs:
shape = inp.shape[1:]
# Check if fixed explicit batch
if inp.shape[0] > -1:
bs = inp.shape[0]
profiles[bs].set_shape(inp.name, min=(bs, *shape), opt=(bs, *shape), max=(bs, *shape))
return list(profiles.values())
def main():
parser = argparse.ArgumentParser(description="Creates a TensorRT engine from the provided ONNX file.\n")
parser.add_argument("--onnx", required=True, help="The ONNX model file to convert to TensorRT")
parser.add_argument("-o", "--output", type=str, default="model.engine", help="The path at which to write the engine")
parser.add_argument("-b", "--max-batch-size", type=int, help="The max batch size for the TensorRT engine input")
parser.add_argument("-v", "--verbosity", action="count", help="Verbosity for logging. (None) for ERROR, (-v) for INFO/WARNING/ERROR, (-vv) for VERBOSE.")
parser.add_argument("--explicit-batch", action='store_true', help="Set trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH.")
parser.add_argument("--explicit-precision", action='store_true', help="Set trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION.")
parser.add_argument("--gpu-fallback", action='store_true', help="Set trt.BuilderFlag.GPU_FALLBACK.")
parser.add_argument("--refittable", action='store_true', help="Set trt.BuilderFlag.REFIT.")
parser.add_argument("--debug", action='store_true', help="Set trt.BuilderFlag.DEBUG.")
parser.add_argument("--strict-types", action='store_true', help="Set trt.BuilderFlag.STRICT_TYPES.")
parser.add_argument("--fp16", action="store_true", help="Attempt to use FP16 kernels when possible.")
parser.add_argument("--int8", action="store_true", help="Attempt to use INT8 kernels when possible. This should generally be used in addition to the --fp16 flag. \
ONLY SUPPORTS RESNET-LIKE MODELS SUCH AS RESNET50/VGG16/INCEPTION/etc.")
parser.add_argument("--calibration-cache", help="(INT8 ONLY) The path to read/write from calibration cache.", default="calibration.cache")
parser.add_argument("--calibration-data", help="(INT8 ONLY) The directory containing {*.jpg, *.jpeg, *.png} files to use for calibration. (ex: Imagenet Validation Set)", default=None)
parser.add_argument("--calibration-batch-size", help="(INT8 ONLY) The batch size to use during calibration.", type=int, default=128)
parser.add_argument("--max-calibration-size", help="(INT8 ONLY) The max number of data to calibrate on from --calibration-data.", type=int, default=2048)
parser.add_argument("-s", "--simple", action="store_true", help="Use SimpleCalibrator with random data instead of ImagenetCalibrator for INT8 calibration.")
args, _ = parser.parse_known_args()
print(args)
# Adjust logging verbosity
if args.verbosity is None:
TRT_LOGGER.min_severity = trt.Logger.Severity.ERROR
# -v
elif args.verbosity == 1:
TRT_LOGGER.min_severity = trt.Logger.Severity.INFO
# -vv
else:
TRT_LOGGER.min_severity = trt.Logger.Severity.VERBOSE
logger.info("TRT_LOGGER Verbosity: {:}".format(TRT_LOGGER.min_severity))
# Network flags
network_flags = 0
if args.explicit_batch:
network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
if args.explicit_precision:
network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)
builder_flag_map = {
'gpu_fallback': trt.BuilderFlag.GPU_FALLBACK,
'refittable': trt.BuilderFlag.REFIT,
'debug': trt.BuilderFlag.DEBUG,
'strict_types': trt.BuilderFlag.STRICT_TYPES,
'fp16': trt.BuilderFlag.FP16,
'int8': trt.BuilderFlag.INT8,
}
# Building engine
with trt.Builder(TRT_LOGGER) as builder, \
builder.create_network(network_flags) as network, \
builder.create_builder_config() as config, \
trt.OnnxParser(network, TRT_LOGGER) as parser:
config.max_workspace_size = 2**30 # 1GiB
# Set Builder Config Flags
for flag in builder_flag_map:
if getattr(args, flag):
logger.info("Setting {}".format(builder_flag_map[flag]))
config.set_flag(builder_flag_map[flag])
# Fill network attributes with information by parsing model
with open(args.onnx, "rb") as f:
if not parser.parse(f.read()):
print('ERROR: Failed to parse the ONNX file: {}'.format(args.onnx))
for error in range(parser.num_errors):
print(parser.get_error(error))
sys.exit(1)
# Display network info and check certain properties
check_network(network)
if args.explicit_batch:
# Add optimization profiles
batch_sizes = [1, 8, 16, 32, 64]
inputs = [network.get_input(i) for i in range(network.num_inputs)]
opt_profiles = create_optimization_profiles(builder, inputs, batch_sizes)
add_profiles(config, inputs, opt_profiles)
# Implicit Batch Network
else:
builder.max_batch_size = args.max_batch_size
opt_profiles = []
# Precision flags
if args.fp16 and not builder.platform_has_fast_fp16:
logger.warning("FP16 not supported on this platform.")
if args.int8 and not builder.platform_has_fast_int8:
logger.warning("INT8 not supported on this platform.")
if args.int8:
from Calibrator import ImageCalibrator, get_int8_calibrator # local module
config.int8_calibrator = get_int8_calibrator(args.calibration_cache,
args.calibration_data,
args.max_calibration_size,
args.calibration_batch_size)
logger.info("Building Engine...")
with builder.build_engine(network, config) as engine, open(args.output, "wb") as f:
logger.info("Serializing engine to file: {:}".format(args.output))
f.write(engine.serialize())
if __name__ == "__main__":
main()

@ -0,0 +1,23 @@
# Path to ONNX model
# ex: ../yolov6.onnx
ONNX_MODEL=$1
# Path to dataset to use for calibration.
# **Not necessary if you already have a calibration cache from a previous run.
CALIBRATION_DATA=$2
# Path to the calibration cache file to read/write
# ex: ./caches/demo.cache
CACHE_FILENAME=$3
# Path to write TensorRT engine to
OUTPUT=$4
# Creates an int8 engine from your ONNX model, creating ${CACHE_FILENAME} based
# on your ${CALIBRATION_DATA}, unless ${CACHE_FILENAME} already exists, then
# it will simply use that instead.
python3 onnx_to_tensorrt.py --fp16 --int8 -v \
--calibration-data=${CALIBRATION_DATA} \
--calibration-cache=${CACHE_FILENAME} \
--explicit-batch \
--onnx ${ONNX_MODEL} -o ${OUTPUT}

@ -0,0 +1,7 @@
# pip install -r requirements.txt
# python3.8 environment
tensorrt # TensorRT 8.0+
pycuda==2020.1 # CUDA 11.0
nvidia-pyindex
pytorch-quantization

@ -0,0 +1,39 @@
#
# QAT_quantizer.py
# YOLOv6
#
# Created by Meituan on 2022/06/24.
# Copyright © 2022
#
from absl import logging
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
# Call this function before defining the model
def tensorrt_official_qat():
# Quantization Aware Training is based on Straight Through Estimator (STE) derivative approximation.
# It is sometimes known as “quantization aware training”.
# PyTorch-Quantization is a toolkit for training and evaluating PyTorch models with simulated quantization.
# Quantization can be added to the model automatically, or manually, allowing the model to be tuned for accuracy and performance.
# Quantization is compatible with NVIDIA's high performance integer kernels which leverage integer Tensor Cores.
# The quantized model can be exported to ONNX and imported by TensorRT 8.0 and later.
# https://github.com/NVIDIA/TensorRT/blob/main/tools/pytorch-quantization/examples/finetune_quant_resnet50.ipynb
# An example of exporting the quantized model to ONNX:
# model.eval()
# quant_nn.TensorQuantizer.use_fb_fake_quant = True # We have to shift to pytorch's fake quant ops before exporting the model to ONNX
# opset_version = 13
# Export ONNX for multiple batch sizes
# print("Creating ONNX file: " + onnx_filename)
# dummy_input = torch.randn(batch_onnx, 3, 224, 224, device='cuda') #TODO: switch input dims by model
# torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=opset_version, enable_onnx_checker=False, do_constant_folding=True)
try:
quant_modules.initialize()
except NameError:
logging.info("initialzation error for quant_modules")
# def QAT_quantizer():
# coming soon
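# --- Illustrative sketch (not part of the original file): the commented export
# --- recipe above rewritten as a callable helper. The names `model`,
# --- `onnx_filename` and `batch_onnx` are placeholders for illustration only.
import torch

def export_qat_onnx(model, onnx_filename, batch_onnx=1, image_size=224):
    model.eval()
    # Shift to PyTorch's fake-quant ops before exporting the model to ONNX.
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    dummy_input = torch.randn(batch_onnx, 3, image_size, image_size, device='cuda')
    torch.onnx.export(model, dummy_input, onnx_filename, verbose=False,
                      opset_version=13, do_constant_folding=True)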

@ -0,0 +1,94 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import os
import os.path as osp
import torch
import torch.distributed as dist
import sys
ROOT = os.getcwd()
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
from yolov6.core.engine import Trainer
from yolov6.utils.config import Config
from yolov6.utils.events import LOGGER, save_yaml
from yolov6.utils.envs import get_envs, select_device, set_random_seed
from yolov6.utils.general import increment_name
def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Training', add_help=add_help)
parser.add_argument('--data-path', default='./data/coco.yaml', type=str, help='path of dataset')
parser.add_argument('--conf-file', default='./configs/yolov6s.py', type=str, help='experiments description file')
parser.add_argument('--img-size', default=640, type=int, help='train, val image size (pixels)')
parser.add_argument('--batch-size', default=32, type=int, help='total batch size for all GPUs')
parser.add_argument('--epochs', default=400, type=int, help='number of total epochs to run')
parser.add_argument('--workers', default=8, type=int, help='number of data loading workers (default: 8)')
parser.add_argument('--device', default='0', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--eval-interval', default=20, type=int, help='evaluate at every interval epochs')
parser.add_argument('--eval-final-only', action='store_true', help='only evaluate at the final epoch')
parser.add_argument('--heavy-eval-range', default=50, type=int,
help='evaluating every epoch for last such epochs (can be jointly used with --eval-interval)')
parser.add_argument('--check-images', action='store_true', help='check images when initializing datasets')
parser.add_argument('--check-labels', action='store_true', help='check label files when initializing datasets')
parser.add_argument('--output-dir', default='./runs/train', type=str, help='path to save outputs')
parser.add_argument('--name', default='exp', type=str, help='experiment name, saved to output_dir/name')
parser.add_argument('--dist_url', default='env://', type=str, help='url used to set up distributed training')
parser.add_argument('--gpu_count', type=int, default=0)
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter')
parser.add_argument('--resume', type=str, default=None, help='resume the corresponding ckpt')
return parser
def check_and_init(args):
'''check config files and device, and initialize '''
# check files
master_process = args.rank == 0 if args.world_size > 1 else args.rank == -1
args.save_dir = str(increment_name(osp.join(args.output_dir, args.name), master_process))
cfg = Config.fromfile(args.conf_file)
# check device
device = select_device(args.device)
# set random seed
set_random_seed(1+args.rank, deterministic=(args.rank == -1))
# save args
if master_process:
os.makedirs(args.save_dir)
save_yaml(vars(args), osp.join(args.save_dir, 'args.yaml'))
return cfg, device
def main(args):
'''main function of training'''
# Setup
args.rank, args.local_rank, args.world_size = get_envs()
LOGGER.info(f'training args are: {args}\n')
cfg, device = check_and_init(args)
if args.local_rank != -1: # if DDP mode
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda', args.local_rank)
LOGGER.info('Initializing process group... ')
dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", \
init_method=args.dist_url, rank=args.local_rank, world_size=args.world_size)
# Start
trainer = Trainer(args, cfg, device)
trainer.train()
# End
if args.world_size > 1 and args.rank == 0:
LOGGER.info('Destroying process group... ')
dist.destroy_process_group()
if __name__ == '__main__':
args = get_args_parser().parse_args()
main(args)
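# Example usage (a sketch; the flag values below are this parser's defaults, and the
# tools/train.py path follows the upstream YOLOv6 layout, so adjust both to your setup):
#   python tools/train.py --data-path ./data/coco.yaml --conf-file ./configs/yolov6s.py \
#       --img-size 640 --batch-size 32 --device 0
# A typical multi-GPU DDP launch (the launcher sets the RANK/WORLD_SIZE/LOCAL_RANK
# environment variables that get_envs() reads):
#   python -m torch.distributed.launch --nproc_per_node 4 tools/train.py \
#       --data-path ./data/coco.yaml --conf-file ./configs/yolov6s.py --device 0,1,2,3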

@ -0,0 +1,276 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
import time
from copy import deepcopy
import os.path as osp
from tqdm import tqdm
import numpy as np
import torch
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
import tools.eval as eval
from yolov6.data.data_load import create_dataloader
from yolov6.models.yolo import build_model
from yolov6.models.loss import ComputeLoss
from yolov6.utils.events import LOGGER, NCOLS, load_yaml, write_tblog
from yolov6.utils.ema import ModelEMA, de_parallel
from yolov6.utils.checkpoint import load_state_dict, save_checkpoint, strip_optimizer
from yolov6.solver.build import build_optimizer, build_lr_scheduler
class Trainer:
def __init__(self, args, cfg, device):
self.args = args
self.cfg = cfg
self.device = device
self.rank = args.rank
self.local_rank = args.local_rank
self.world_size = args.world_size
self.main_process = self.rank in [-1, 0]
self.save_dir = args.save_dir
# get data loader
self.data_dict = load_yaml(args.data_path)
self.num_classes = self.data_dict['nc']
self.train_loader, self.val_loader = self.get_data_loader(args, cfg, self.data_dict)
# get model and optimizer
model = self.get_model(args, cfg, self.num_classes, device)
self.optimizer = self.get_optimizer(args, cfg, model)
self.scheduler, self.lf = self.get_lr_scheduler(args, cfg, self.optimizer)
self.ema = ModelEMA(model) if self.main_process else None
self.model = self.parallel_model(args, model, device)
self.model.nc, self.model.names = self.data_dict['nc'], self.data_dict['names']
# tensorboard
self.tblogger = SummaryWriter(self.save_dir) if self.main_process else None
self.start_epoch = 0
# resume ckpt from user-defined path
if args.resume:
assert os.path.isfile(args.resume), 'ERROR: --resume checkpoint does not exist'
self.ckpt = torch.load(args.resume, map_location='cpu')
self.start_epoch = self.ckpt['epoch'] + 1
self.max_epoch = args.epochs
self.max_stepnum = len(self.train_loader)
self.batch_size = args.batch_size
self.img_size = args.img_size
# Training Process
def train(self):
try:
self.train_before_loop()
for self.epoch in range(self.start_epoch, self.max_epoch):
self.train_in_loop()
except Exception as _:
LOGGER.error('ERROR in training loop or eval/save model.')
raise
finally:
self.train_after_loop()
# Training loop for each epoch
def train_in_loop(self):
try:
self.prepare_for_steps()
for self.step, self.batch_data in self.pbar:
self.train_in_steps()
self.print_details()
except Exception as _:
LOGGER.error('ERROR in training steps.')
raise
try:
self.eval_and_save()
except Exception as _:
LOGGER.error('ERROR in evaluate and save model.')
raise
# Training loop for each batch of data
def train_in_steps(self):
images, targets = self.prepro_data(self.batch_data, self.device)
# forward
with amp.autocast(enabled=self.device != 'cpu'):
preds = self.model(images)
total_loss, loss_items = self.compute_loss(preds, targets)
if self.rank != -1:
total_loss *= self.world_size
# backward
self.scaler.scale(total_loss).backward()
self.loss_items = loss_items
self.update_optimizer()
def eval_and_save(self):
remaining_epochs = self.max_epoch - self.epoch
eval_interval = self.args.eval_interval if remaining_epochs > self.args.heavy_eval_range else 1
is_val_epoch = (not self.args.eval_final_only or (remaining_epochs == 1)) and (self.epoch % eval_interval == 0)
if self.main_process:
self.ema.update_attr(self.model, include=['nc', 'names', 'stride']) # update attributes for ema model
if is_val_epoch:
self.eval_model()
self.ap = self.evaluate_results[0] * 0.1 + self.evaluate_results[1] * 0.9  # fitness = 0.1 * mAP@0.5 + 0.9 * mAP@0.5:0.95
self.best_ap = max(self.ap, self.best_ap)
# save ckpt
ckpt = {
'model': deepcopy(de_parallel(self.model)).half(),
'ema': deepcopy(self.ema.ema).half(),
'updates': self.ema.updates,
'optimizer': self.optimizer.state_dict(),
'epoch': self.epoch,
}
save_ckpt_dir = osp.join(self.save_dir, 'weights')
save_checkpoint(ckpt, (is_val_epoch) and (self.ap == self.best_ap), save_ckpt_dir, model_name='last_ckpt')
del ckpt
# log for tensorboard
write_tblog(self.tblogger, self.epoch, self.evaluate_results, self.mean_loss)
def eval_model(self):
results = eval.run(self.data_dict,
batch_size=self.batch_size // self.world_size * 2,
img_size=self.img_size,
model=self.ema.ema,
dataloader=self.val_loader,
save_dir=self.save_dir,
task='train')
LOGGER.info(f"Epoch: {self.epoch} | mAP@0.5: {results[0]} | mAP@0.50:0.95: {results[1]}")
self.evaluate_results = results[:2]
def train_before_loop(self):
LOGGER.info('Training start...')
self.start_time = time.time()
self.warmup_stepnum = max(round(self.cfg.solver.warmup_epochs * self.max_stepnum), 1000)
self.scheduler.last_epoch = self.start_epoch - 1
self.last_opt_step = -1
self.scaler = amp.GradScaler(enabled=self.device != 'cpu')
self.best_ap, self.ap = 0.0, 0.0
self.evaluate_results = (0, 0) # AP50, AP50_95
self.compute_loss = ComputeLoss(iou_type=self.cfg.model.head.iou_type)
if hasattr(self, "ckpt"):
resume_state_dict = self.ckpt['model'].float().state_dict() # checkpoint's state_dict as FP32
self.model.load_state_dict(resume_state_dict, strict=True) # load model state dict
self.optimizer.load_state_dict(self.ckpt['optimizer']) # load optimizer
self.start_epoch = self.ckpt['epoch'] + 1
self.ema.ema.load_state_dict(self.ckpt['ema'].float().state_dict()) # load ema state dict
self.ema.updates = self.ckpt['updates']
def prepare_for_steps(self):
if self.epoch > self.start_epoch:
self.scheduler.step()
self.model.train()
if self.rank != -1:
self.train_loader.sampler.set_epoch(self.epoch)
self.mean_loss = torch.zeros(4, device=self.device)
self.optimizer.zero_grad()
LOGGER.info(('\n' + '%10s' * 5) % ('Epoch', 'iou_loss', 'l1_loss', 'obj_loss', 'cls_loss'))
self.pbar = enumerate(self.train_loader)
if self.main_process:
self.pbar = tqdm(self.pbar, total=self.max_stepnum, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
# Print loss after each step
def print_details(self):
if self.main_process:
self.mean_loss = (self.mean_loss * self.step + self.loss_items) / (self.step + 1)
self.pbar.set_description(('%10s' + '%10.4g' * 4) % (f'{self.epoch}/{self.max_epoch - 1}', \
*(self.mean_loss)))
# Empty cache if training finished
def train_after_loop(self):
if self.main_process:
LOGGER.info(f'\nTraining completed in {(time.time() - self.start_time) / 3600:.3f} hours.')
save_ckpt_dir = osp.join(self.save_dir, 'weights')
strip_optimizer(save_ckpt_dir, self.epoch) # strip optimizers for saved pt model
if self.device != 'cpu':
torch.cuda.empty_cache()
def update_optimizer(self):
curr_step = self.step + self.max_stepnum * self.epoch
self.accumulate = max(1, round(64 / self.batch_size))  # accumulate gradients until the effective batch size reaches 64
if curr_step <= self.warmup_stepnum:
self.accumulate = max(1, np.interp(curr_step, [0, self.warmup_stepnum], [1, 64 / self.batch_size]).round())
for k, param in enumerate(self.optimizer.param_groups):
warmup_bias_lr = self.cfg.solver.warmup_bias_lr if k == 2 else 0.0
param['lr'] = np.interp(curr_step, [0, self.warmup_stepnum], [warmup_bias_lr, param['initial_lr'] * self.lf(self.epoch)])
if 'momentum' in param:
param['momentum'] = np.interp(curr_step, [0, self.warmup_stepnum], [self.cfg.solver.warmup_momentum, self.cfg.solver.momentum])
if curr_step - self.last_opt_step >= self.accumulate:
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.ema:
self.ema.update(self.model)
self.last_opt_step = curr_step
@staticmethod
def get_data_loader(args, cfg, data_dict):
train_path, val_path = data_dict['train'], data_dict['val']
# check data
nc = int(data_dict['nc'])
class_names = data_dict['names']
assert len(class_names) == nc, f'the length of class names does not match the number of classes defined'
grid_size = max(int(max(cfg.model.head.strides)), 32)
# create train dataloader
train_loader = create_dataloader(train_path, args.img_size, args.batch_size // args.world_size, grid_size,
hyp=dict(cfg.data_aug), augment=True, rect=False, rank=args.local_rank,
workers=args.workers, shuffle=True, check_images=args.check_images,
check_labels=args.check_labels, data_dict=data_dict, task='train')[0]
# create val dataloader
val_loader = None
if args.rank in [-1, 0]:
val_loader = create_dataloader(val_path, args.img_size, args.batch_size // args.world_size * 2, grid_size,
hyp=dict(cfg.data_aug), rect=True, rank=-1, pad=0.5,
workers=args.workers, check_images=args.check_images,
check_labels=args.check_labels, data_dict=data_dict, task='val')[0]
return train_loader, val_loader
@staticmethod
def prepro_data(batch_data, device):
images = batch_data[0].to(device, non_blocking=True).float() / 255
targets = batch_data[1].to(device)
return images, targets
def get_model(self, args, cfg, nc, device):
model = build_model(cfg, nc, device)
weights = cfg.model.pretrained
if weights: # finetune if pretrained model is set
LOGGER.info(f'Loading state_dict from {weights} for fine-tuning...')
model = load_state_dict(weights, model, map_location=device)
LOGGER.info('Model: {}'.format(model))
return model
@staticmethod
def parallel_model(args, model, device):
# If DP mode
dp_mode = device.type != 'cpu' and args.rank == -1
if dp_mode and torch.cuda.device_count() > 1:
LOGGER.warning('WARNING: DP not recommended, use DDP instead.\n')
model = torch.nn.DataParallel(model)
# If DDP mode
ddp_mode = device.type != 'cpu' and args.rank != -1
if ddp_mode:
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
return model
def get_optimizer(self, args, cfg, model):
accumulate = max(1, round(64 / args.batch_size))
cfg.solver.weight_decay *= args.batch_size * accumulate / 64
optimizer = build_optimizer(cfg, model)
return optimizer
@staticmethod
def get_lr_scheduler(args, cfg, optimizer):
epochs = args.epochs
lr_scheduler, lf = build_lr_scheduler(cfg, optimizer, epochs)
return lr_scheduler, lf
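if __name__ == "__main__":
    # Hedged worked example of the warm-up schedule used in Trainer.update_optimizer:
    # with batch_size=32 the target gradient accumulation is 64/32 = 2 optimizer steps,
    # and np.interp ramps the value from 1 at step 0 up to 2 by the end of warm-up.
    warmup_stepnum, batch_size = 1000, 32
    for curr_step in (0, 500, 1000):
        accumulate = max(1, np.interp(curr_step, [0, warmup_stepnum], [1, 64 / batch_size]).round())
        print(curr_step, accumulate)  # accumulate grows from 1 to 2 across the warm-up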

@ -0,0 +1,256 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
from tqdm import tqdm
import numpy as np
import json
import torch
import yaml
from pathlib import Path
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from yolov6.data.data_load import create_dataloader
from yolov6.utils.events import LOGGER, NCOLS
from yolov6.utils.nms import non_max_suppression
from yolov6.utils.checkpoint import load_checkpoint
from yolov6.utils.torch_utils import time_sync, get_model_info
'''
python tools/eval.py --task 'train'/'val'/'speed'
'''
class Evaler:
def __init__(self,
data,
batch_size=32,
img_size=640,
conf_thres=0.001,
iou_thres=0.65,
device='',
half=True,
save_dir=''):
self.data = data
self.batch_size = batch_size
self.img_size = img_size
self.conf_thres = conf_thres
self.iou_thres = iou_thres
self.device = device
self.half = half
self.save_dir = save_dir
def init_model(self, model, weights, task):
if task != 'train':
model = load_checkpoint(weights, map_location=self.device)
self.stride = int(model.stride.max())
if self.device.type != 'cpu':
model(torch.zeros(1, 3, self.img_size, self.img_size).to(self.device).type_as(next(model.parameters())))
# switch to deploy
from yolov6.layers.common import RepVGGBlock
for layer in model.modules():
if isinstance(layer, RepVGGBlock):
layer.switch_to_deploy()
LOGGER.info("Switch model to deploy modality.")
LOGGER.info("Model Summary: {}".format(get_model_info(model, self.img_size)))
model.half() if self.half else model.float()
return model
def init_data(self, dataloader, task):
'''Initialize dataloader.
Returns a dataloader for task val or speed.
'''
self.is_coco = self.data.get("is_coco", False)
self.ids = self.coco80_to_coco91_class() if self.is_coco else list(range(1000))
if task != 'train':
pad = 0.0 if task == 'speed' else 0.5
dataloader = create_dataloader(self.data[task if task in ('train', 'val', 'test') else 'val'],
self.img_size, self.batch_size, self.stride, check_labels=True, pad=pad, rect=True,
data_dict=self.data, task=task)[0]
return dataloader
def predict_model(self, model, dataloader, task):
'''Model prediction
Predicts on the whole dataset and collects the predicted results and inference time.
'''
self.speed_result = torch.zeros(4, device=self.device)
pred_results = []
pbar = tqdm(dataloader, desc="Running inference on the val dataset.", ncols=NCOLS)
for imgs, targets, paths, shapes in pbar:
# pre-process
t1 = time_sync()
imgs = imgs.to(self.device, non_blocking=True)
imgs = imgs.half() if self.half else imgs.float()
imgs /= 255
self.speed_result[1] += time_sync() - t1 # pre-process time
# Inference
t2 = time_sync()
outputs = model(imgs)
self.speed_result[2] += time_sync() - t2 # inference time
# post-process
t3 = time_sync()
outputs = non_max_suppression(outputs, self.conf_thres, self.iou_thres, multi_label=True)
self.speed_result[3] += time_sync() - t3 # post-process time
self.speed_result[0] += len(outputs)
# save result
pred_results.extend(self.convert_to_coco_format(outputs, imgs, paths, shapes, self.ids))
return pred_results
def eval_model(self, pred_results, model, dataloader, task):
'''Evaluate models
For task speed, this function only evaluates the speed of model and outputs inference time.
For task val, this function evaluates the speed and mAP by pycocotools, and returns
inference time and mAP value.
'''
LOGGER.info(f'\nEvaluating speed.')
self.eval_speed(task)
LOGGER.info(f'\nEvaluating mAP by pycocotools.')
if task != 'speed' and len(pred_results):
if 'anno_path' in self.data:
anno_json = self.data['anno_path']
else:
# generated coco format labels in dataset initialization
dataset_root = os.path.dirname(os.path.dirname(self.data['val']))
base_name = os.path.basename(self.data['val'])
anno_json = os.path.join(dataset_root, 'annotations', f'instances_{base_name}.json')
pred_json = os.path.join(self.save_dir, "predictions.json")
LOGGER.info(f'Saving {pred_json}...')
with open(pred_json, 'w') as f:
json.dump(pred_results, f)
anno = COCO(anno_json)
pred = anno.loadRes(pred_json)
cocoEval = COCOeval(anno, pred, 'bbox')
if self.is_coco:
imgIds = [int(os.path.basename(x).split(".")[0])
for x in dataloader.dataset.img_paths]
cocoEval.params.imgIds = imgIds
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
map, map50 = cocoEval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5)
# Return results
model.float() # for training
if task != 'train':
LOGGER.info(f"Results saved to {self.save_dir}")
return (map50, map)
return (0.0, 0.0)
def eval_speed(self, task):
'''Evaluate model inference speed.'''
if task != 'train':
n_samples = self.speed_result[0].item()
pre_time, inf_time, nms_time = 1000 * self.speed_result[1:].cpu().numpy() / n_samples
for n, v in zip(["pre-process", "inference", "NMS"],[pre_time, inf_time, nms_time]):
LOGGER.info("Average {} time: {:.2f} ms".format(n, v))
def box_convert(self, x):
# Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h] where x1y1=top-left, x2y2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
y[:, 2] = x[:, 2] - x[:, 0] # width
y[:, 3] = x[:, 3] - x[:, 1] # height
return y
def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
# Rescale coords (xyxy) from img1_shape to img0_shape
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
coords[:, [0, 2]] -= pad[0] # x padding
coords[:, [1, 3]] -= pad[1] # y padding
coords[:, :4] /= gain
if isinstance(coords, torch.Tensor): # faster individually
coords[:, 0].clamp_(0, img0_shape[1]) # x1
coords[:, 1].clamp_(0, img0_shape[0]) # y1
coords[:, 2].clamp_(0, img0_shape[1]) # x2
coords[:, 3].clamp_(0, img0_shape[0]) # y2
else: # np.array (faster grouped)
coords[:, [0, 2]] = coords[:, [0, 2]].clip(0, img0_shape[1]) # x1, x2
coords[:, [1, 3]] = coords[:, [1, 3]].clip(0, img0_shape[0]) # y1, y2
return coords
def convert_to_coco_format(self, outputs, imgs, paths, shapes, ids):
pred_results = []
for i, pred in enumerate(outputs):
if len(pred) == 0:
continue
path, shape = Path(paths[i]), shapes[i][0]
self.scale_coords(imgs[i].shape[1:], pred[:, :4], shape, shapes[i][1])
image_id = int(path.stem) if path.stem.isnumeric() else path.stem
bboxes = self.box_convert(pred[:, 0:4])
bboxes[:, :2] -= bboxes[:, 2:] / 2
cls = pred[:, 5]
scores = pred[:, 4]
for ind in range(pred.shape[0]):
category_id = ids[int(cls[ind])]
bbox = [round(x, 3) for x in bboxes[ind].tolist()]
score = round(scores[ind].item(), 5)
pred_data = {
"image_id": image_id,
"category_id": category_id,
"bbox": bbox,
"score": score
}
pred_results.append(pred_data)
return pred_results
@staticmethod
def check_task(task):
if task not in ['train','val','speed']:
raise Exception("task argument error: only support 'train' / 'val' / 'speed' task.")
@staticmethod
def reload_thres(conf_thres, iou_thres, task):
'''Sets conf and iou threshold for task val/speed'''
if task != 'train':
if task == 'val':
conf_thres = 0.001
if task == 'speed':
conf_thres = 0.25
iou_thres = 0.45
return conf_thres, iou_thres
@staticmethod
def reload_device(device, model, task):
# device = 'cpu' or '0' or '0,1,2,3'
if task == 'train':
device = next(model.parameters()).device
else:
if device == 'cpu':
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
elif device:
os.environ['CUDA_VISIBLE_DEVICES'] = device
assert torch.cuda.is_available()
cuda = device != 'cpu' and torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
return device
@staticmethod
def reload_dataset(data):
with open(data, errors='ignore') as yaml_file:
data = yaml.safe_load(yaml_file)
val = data.get('val')
if not os.path.exists(val):
raise Exception('Dataset not found.')
return data
@staticmethod
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
return x
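if __name__ == "__main__":
    # Hedged worked example of the box handling in convert_to_coco_format: an xyxy box
    # [10, 20, 50, 60] becomes center-xywh [30, 40, 40, 40] via box_convert, and shifting
    # the center by half the size gives the COCO top-left form [10, 20, 40, 40].
    ids = Evaler.coco80_to_coco91_class()
    print(ids[0])  # class index 0 ('person' in the 80-class list) maps to COCO category_id 1
    box = torch.tensor([[10., 20., 50., 60.]])
    xywh = Evaler(data=None).box_convert(box)  # `data` is unused by this helper
    xywh[:, :2] -= xywh[:, 2:] / 2  # center xy -> top-left xy, as done in convert_to_coco_format
    print(xywh.tolist())  # [[10.0, 20.0, 40.0, 40.0]]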

@ -0,0 +1,193 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
import os.path as osp
import math
from tqdm import tqdm
import numpy as np
import cv2
import torch
from PIL import ImageFont
from yolov6.utils.events import LOGGER, load_yaml
from yolov6.layers.common import DetectBackend
from yolov6.data.data_augment import letterbox
from yolov6.utils.nms import non_max_suppression
class Inferer:
def __init__(self, source, weights, device, yaml, img_size, half):
import glob
from yolov6.data.datasets import IMG_FORMATS
self.__dict__.update(locals())
# Init model
self.device = device
self.img_size = img_size
cuda = self.device != 'cpu' and torch.cuda.is_available()
self.device = torch.device('cuda:0' if cuda else 'cpu')
self.model = DetectBackend(weights, device=self.device)
self.stride = self.model.stride
self.class_names = load_yaml(yaml)['names']
self.img_size = self.check_img_size(self.img_size, s=self.stride) # check image size
# Half precision
if half & (self.device.type != 'cpu'):
self.model.model.half()
else:
self.model.model.float()
half = False
if self.device.type != 'cpu':
self.model(torch.zeros(1, 3, *self.img_size).to(self.device).type_as(next(self.model.model.parameters()))) # warmup
# Load data
if os.path.isdir(source):
img_paths = sorted(glob.glob(os.path.join(source, '*.*'))) # dir
elif os.path.isfile(source):
img_paths = [source] # files
else:
raise Exception(f'Invalid path: {source}')
self.img_paths = [img_path for img_path in img_paths if img_path.split('.')[-1].lower() in IMG_FORMATS]
def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf):
''' Model Inference and results visualization '''
for img_path in tqdm(self.img_paths):
img, img_src = self.precess_image(img_path, self.img_size, self.stride, self.half)
img = img.to(self.device)
if len(img.shape) == 3:
img = img[None]
# expand for batch dim
pred_results = self.model(img)
det = non_max_suppression(pred_results, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
save_path = osp.join(save_dir, osp.basename(img_path)) # im.jpg
txt_path = osp.join(save_dir, 'labels', osp.splitext(osp.basename(img_path))[0])
gn = torch.tensor(img_src.shape)[[1, 0, 1, 0]] # normalization gain whwh
img_ori = img_src
# check image and font
assert img_ori.data.contiguous, 'Image needs to be contiguous. Please apply np.ascontiguousarray(im) to the input image.'
self.font_check()
if len(det):
det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
for *xyxy, conf, cls in reversed(det):
if save_txt: # Write to file
xywh = (self.box_convert(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf)
with open(txt_path + '.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
if save_img:
class_num = int(cls) # integer class
label = None if hide_labels else (self.class_names[class_num] if hide_conf else f'{self.class_names[class_num]} {conf:.2f}')
self.plot_box_and_label(img_ori, max(round(sum(img_ori.shape) / 2 * 0.003), 2), xyxy, label, color=self.generate_colors(class_num, True))
img_src = np.asarray(img_ori)
# Save results (image with detections)
if save_img:
cv2.imwrite(save_path, img_src)
@staticmethod
def precess_image(path, img_size, stride, half):
'''Process image before image inference.'''
try:
img_src = cv2.imread(path)
assert img_src is not None, f'Invalid image: {path}'
except Exception as e:
LOGGER.warning(e)
image = letterbox(img_src, img_size, stride=stride)[0]
# Convert
image = image.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
image = torch.from_numpy(np.ascontiguousarray(image))
image = image.half() if half else image.float() # uint8 to fp16/32
image /= 255 # 0 - 255 to 0.0 - 1.0
return image, img_src
@staticmethod
def rescale(ori_shape, boxes, target_shape):
'''Rescale the output to the original image shape'''
ratio = min(ori_shape[0] / target_shape[0], ori_shape[1] / target_shape[1])
padding = (ori_shape[1] - target_shape[1] * ratio) / 2, (ori_shape[0] - target_shape[0] * ratio) / 2
boxes[:, [0, 2]] -= padding[0]
boxes[:, [1, 3]] -= padding[1]
boxes[:, :4] /= ratio
boxes[:, 0].clamp_(0, target_shape[1]) # x1
boxes[:, 1].clamp_(0, target_shape[0]) # y1
boxes[:, 2].clamp_(0, target_shape[1]) # x2
boxes[:, 3].clamp_(0, target_shape[0]) # y2
return boxes
def check_img_size(self, img_size, s=32, floor=0):
"""Make sure image size is a multiple of stride s in each dimension, and return a new shape list of image."""
if isinstance(img_size, int): # integer i.e. img_size=640
new_size = max(self.make_divisible(img_size, int(s)), floor)
elif isinstance(img_size, list): # list i.e. img_size=[640, 480]
new_size = [max(self.make_divisible(x, int(s)), floor) for x in img_size]
else:
raise Exception(f"Unsupported type of img_size: {type(img_size)}")
if new_size != img_size:
print(f'WARNING: --img-size {img_size} must be multiple of max stride {s}, updating to {new_size}')
return new_size if isinstance(img_size,list) else [new_size]*2
def make_divisible(self, x, divisor):
# Round the value x up to the nearest multiple of the divisor.
return math.ceil(x / divisor) * divisor
@staticmethod
def plot_box_and_label(image, lw, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
# Add one xyxy box to image with label
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
cv2.rectangle(image, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
if label:
tf = max(lw - 1, 1) # font thickness
w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0] # text width, height
outside = p1[1] - h - 3 >= 0 # label fits outside box
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
cv2.putText(image, label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), 0, lw / 3, txt_color,
thickness=tf, lineType=cv2.LINE_AA)
@staticmethod
def font_check(font='./yolov6/utils/Arial.ttf', size=10):
# Return a PIL TrueType font loaded from the given path (the file must already exist; nothing is downloaded here)
assert osp.exists(font), f'font path does not exist: {font}'
return ImageFont.truetype(str(font), size)
@staticmethod
def box_convert(x):
# Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h] where x1y1=top-left, x2y2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
y[:, 2] = x[:, 2] - x[:, 0] # width
y[:, 3] = x[:, 3] - x[:, 1] # height
return y
@staticmethod
def generate_colors(i, bgr=False):
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
palette = []
for iter in hex:
h = '#' + iter
palette.append(tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)))
num = len(palette)
color = palette[int(i) % num]
return (color[2], color[1], color[0]) if bgr else color
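if __name__ == "__main__":
    # Hedged worked example of Inferer.rescale: map a detection from the 384x640 letterboxed
    # input back to the original 720x1280 frame (scale ratio 0.5, 12 px of padding top and bottom).
    boxes = torch.tensor([[100., 112., 300., 212.]])
    print(Inferer.rescale((384, 640), boxes, (720, 1280)).tolist())  # [[200.0, 200.0, 600.0, 400.0]]
    # Related: make_divisible rounds sizes up to a stride multiple, e.g. 641 with stride 32 -> 672.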

@ -0,0 +1,193 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
import math
import random
import cv2
import numpy as np
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
# HSV color-space augmentation
if hgain or sgain or vgain:
r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
dtype = im.dtype # uint8
x = np.arange(0, 256, dtype=r.dtype)
lut_hue = ((x * r[0]) % 180).astype(dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not scaleup: # only scale down, do not scale up (for better val mAP)
r = min(r, 1.0)
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
return im, r, (dw, dh)
def mixup(im, labels, im2, labels2):
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
im = (im * r + im2 * (1 - r)).astype(np.uint8)
labels = np.concatenate((labels, labels2), 0)
return im, labels
def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
def random_affine(img, labels=(), degrees=10, translate=.1, scale=.1, shear=10,
new_shape=(640, 640)):
n = len(labels)
height, width = new_shape
M, s = get_transform_matrix(img.shape[:2], (height, width), degrees, scale, shear, translate)
if (M != np.eye(3)).any(): # image changed
img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
# Transform label coordinates
if n:
new = np.zeros((n, 4))
xy = np.ones((n * 4, 3))
xy[:, :2] = labels[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
xy = xy @ M.T # transform
xy = xy[:, :2].reshape(n, 8) # perspective rescale or affine
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip
new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
# filter candidates
i = box_candidates(box1=labels[:, 1:5].T * s, box2=new.T, area_thr=0.1)
labels = labels[i]
labels[:, 1:5] = new[i]
return img, labels
def get_transform_matrix(img_shape, new_shape, degrees, scale, shear, translate):
new_height, new_width = new_shape
# Center
C = np.eye(3)
C[0, 2] = -img_shape[1] / 2 # x translation (pixels)
C[1, 2] = -img_shape[0] / 2 # y translation (pixels)
# Rotation and Scale
R = np.eye(3)
a = random.uniform(-degrees, degrees)
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
s = random.uniform(1 - scale, 1 + scale)
# s = 2 ** random.uniform(-scale, scale)
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
# Shear
S = np.eye(3)
S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
# Translation
T = np.eye(3)
T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * new_width # x translation (pixels)
T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * new_height # y translation (pixels)
# Combined rotation matrix
M = T @ S @ R @ C # order of operations (right to left) is IMPORTANT
return M, s
def mosaic_augmentation(img_size, imgs, hs, ws, labels, hyp):
assert len(imgs) == 4, "Mosaic augmentation of current version only supports 4 images."
labels4 = []
s = img_size
yc, xc = (int(random.uniform(s//2, 3*s//2)) for _ in range(2)) # mosaic center x, y
for i in range(len(imgs)):
# Load image
img, h, w = imgs[i], hs[i], ws[i]
# place img in img4
if i == 0: # top left
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
elif i == 1: # top right
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
elif i == 2: # bottom left
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
elif i == 3: # bottom right
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
padw = x1a - x1b
padh = y1a - y1b
# Labels
labels_per_img = labels[i].copy()
if labels_per_img.size:
boxes = np.copy(labels_per_img[:, 1:])
boxes[:, 0] = w * (labels_per_img[:, 1] - labels_per_img[:, 3] / 2) + padw # top left x
boxes[:, 1] = h * (labels_per_img[:, 2] - labels_per_img[:, 4] / 2) + padh # top left y
boxes[:, 2] = w * (labels_per_img[:, 1] + labels_per_img[:, 3] / 2) + padw # bottom right x
boxes[:, 3] = h * (labels_per_img[:, 2] + labels_per_img[:, 4] / 2) + padh # bottom right y
labels_per_img[:, 1:] = boxes
labels4.append(labels_per_img)
# Concat/clip labels
labels4 = np.concatenate(labels4, 0)
for x in (labels4[:, 1:]):
np.clip(x, 0, 2 * s, out=x)
# Augment
img4, labels4 = random_affine(img4, labels4,
degrees=hyp['degrees'],
translate=hyp['translate'],
scale=hyp['scale'],
shear=hyp['shear'])
return img4, labels4
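if __name__ == "__main__":
    # Hedged worked example of letterbox: a 720x1280 frame prepared for a 640 model input.
    # The scale ratio is min(640/720, 640/1280) = 0.5; with auto=True the leftover height
    # padding (280 px) is reduced modulo the stride (32) to 24 px and split 12 px top/bottom,
    # so the output is 384x640.
    frame = np.zeros((720, 1280, 3), dtype=np.uint8)
    out, ratio, (dw, dh) = letterbox(frame, new_shape=640, auto=True, stride=32)
    print(out.shape, ratio, (dw, dh))  # (384, 640, 3) 0.5 (0.0, 12.0)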

@ -0,0 +1,113 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
import os
from torch.utils.data import dataloader, distributed
from .datasets import TrainValDataset
from yolov6.utils.events import LOGGER
from yolov6.utils.torch_utils import torch_distributed_zero_first
def create_dataloader(
path,
img_size,
batch_size,
stride,
hyp=None,
augment=False,
check_images=False,
check_labels=False,
pad=0.0,
rect=False,
rank=-1,
workers=8,
shuffle=False,
data_dict=None,
task="Train",
):
"""Create general dataloader.
Returns dataloader and dataset
"""
if rect and shuffle:
LOGGER.warning(
"WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False"
)
shuffle = False
with torch_distributed_zero_first(rank):
dataset = TrainValDataset(
path,
img_size,
batch_size,
augment=augment,
hyp=hyp,
rect=rect,
check_images=check_images,
check_labels=check_labels,
stride=int(stride),
pad=pad,
rank=rank,
data_dict=data_dict,
task=task,
)
batch_size = min(batch_size, len(dataset))
workers = min(
[
os.cpu_count() // int(os.getenv("WORLD_SIZE", 1)),
batch_size if batch_size > 1 else 0,
workers,
]
) # number of workers
sampler = (
None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
)
return (
TrainValDataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=workers,
sampler=sampler,
pin_memory=True,
collate_fn=TrainValDataset.collate_fn,
),
dataset,
)
class TrainValDataLoader(dataloader.DataLoader):
"""Dataloader that reuses workers
Uses same syntax as vanilla DataLoader
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()
def __len__(self):
return len(self.batch_sampler.sampler)
def __iter__(self):
for i in range(len(self)):
yield next(self.iterator)
class _RepeatSampler:
"""Sampler that repeats forever
Args:
sampler (Sampler)
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)
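if __name__ == "__main__":
    # Hedged illustration of why _RepeatSampler exists: it cycles the wrapped sampler forever,
    # which lets TrainValDataLoader keep a single iterator (and its worker processes) alive
    # across epochs instead of re-spawning workers at the start of every epoch.
    demo = iter(_RepeatSampler([0, 1, 2]))  # any iterable works for this sketch
    print([next(demo) for _ in range(7)])  # [0, 1, 2, 0, 1, 2, 0]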

@ -0,0 +1,550 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import glob
import os
import os.path as osp
import random
import json
import time
import hashlib
from multiprocessing.pool import Pool
import cv2
import numpy as np
import torch
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import Dataset
from tqdm import tqdm
from .data_augment import (
augment_hsv,
letterbox,
mixup,
random_affine,
mosaic_augmentation,
)
from yolov6.utils.events import LOGGER
# Parameters
IMG_FORMATS = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", "webp", "mpo"]
# Get orientation exif tag
for k, v in ExifTags.TAGS.items():
if v == "Orientation":
ORIENTATION = k
break
class TrainValDataset(Dataset):
# YOLOv6 train_loader/val_loader, loads images and labels for training and validation
def __init__(
self,
img_dir,
img_size=640,
batch_size=16,
augment=False,
hyp=None,
rect=False,
check_images=False,
check_labels=False,
stride=32,
pad=0.0,
rank=-1,
data_dict=None,
task="train",
):
assert task.lower() in ("train", "val", "speed"), f"Not supported task: {task}"
t1 = time.time()
self.__dict__.update(locals())
self.main_process = self.rank in (-1, 0)
self.task = self.task.capitalize()
self.class_names = data_dict["names"]
self.img_paths, self.labels = self.get_imgs_labels(self.img_dir)
if self.rect:
shapes = [self.img_info[p]["shape"] for p in self.img_paths]
self.shapes = np.array(shapes, dtype=np.float64)
self.batch_indices = np.floor(
np.arange(len(shapes)) / self.batch_size
).astype(
int
) # batch indices of each image
self.sort_files_shapes()
t2 = time.time()
if self.main_process:
LOGGER.info(f"%.1fs for dataset initialization." % (t2 - t1))
def __len__(self):
"""Get the length of dataset"""
return len(self.img_paths)
def __getitem__(self, index):
"""Fetching a data sample for a given key.
This function applies mosaic and mixup augments during training.
During validation, letterbox augment is applied.
"""
# Mosaic Augmentation
if self.augment and random.random() < self.hyp["mosaic"]:
img, labels = self.get_mosaic(index)
shapes = None
# MixUp augmentation
if random.random() < self.hyp["mixup"]:
img_other, labels_other = self.get_mosaic(
random.randint(0, len(self.img_paths) - 1)
)
img, labels = mixup(img, labels, img_other, labels_other)
else:
# Load image
img, (h0, w0), (h, w) = self.load_image(index)
# Letterbox
shape = (
self.batch_shapes[self.batch_indices[index]]
if self.rect
else self.img_size
) # final letterboxed shape
img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
labels = self.labels[index].copy()
if labels.size:
w *= ratio
h *= ratio
# new boxes
boxes = np.copy(labels[:, 1:])
boxes[:, 0] = (
w * (labels[:, 1] - labels[:, 3] / 2) + pad[0]
) # top left x
boxes[:, 1] = (
h * (labels[:, 2] - labels[:, 4] / 2) + pad[1]
) # top left y
boxes[:, 2] = (
w * (labels[:, 1] + labels[:, 3] / 2) + pad[0]
) # bottom right x
boxes[:, 3] = (
h * (labels[:, 2] + labels[:, 4] / 2) + pad[1]
) # bottom right y
labels[:, 1:] = boxes
if self.augment:
img, labels = random_affine(
img,
labels,
degrees=self.hyp["degrees"],
translate=self.hyp["translate"],
scale=self.hyp["scale"],
shear=self.hyp["shear"],
new_shape=(self.img_size, self.img_size),
)
if len(labels):
h, w = img.shape[:2]
labels[:, [1, 3]] = labels[:, [1, 3]].clip(0, w - 1e-3) # x1, x2
labels[:, [2, 4]] = labels[:, [2, 4]].clip(0, h - 1e-3) # y1, y2
boxes = np.copy(labels[:, 1:])
boxes[:, 0] = ((labels[:, 1] + labels[:, 3]) / 2) / w # x center
boxes[:, 1] = ((labels[:, 2] + labels[:, 4]) / 2) / h # y center
boxes[:, 2] = (labels[:, 3] - labels[:, 1]) / w # width
boxes[:, 3] = (labels[:, 4] - labels[:, 2]) / h # height
labels[:, 1:] = boxes
if self.augment:
img, labels = self.general_augment(img, labels)
labels_out = torch.zeros((len(labels), 6))
if len(labels):
labels_out[:, 1:] = torch.from_numpy(labels)
# Convert
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img = np.ascontiguousarray(img)
return torch.from_numpy(img), labels_out, self.img_paths[index], shapes
def load_image(self, index):
"""Load image.
This function loads an image with cv2 and resizes it to the target shape (img_size) while keeping the aspect ratio.
Returns:
Image, original shape of image, resized image shape
"""
path = self.img_paths[index]
im = cv2.imread(path)
assert im is not None, f"Image Not Found {path}, workdir: {os.getcwd()}"
h0, w0 = im.shape[:2] # origin shape
r = self.img_size / max(h0, w0)
if r != 1:
im = cv2.resize(
im,
(int(w0 * r), int(h0 * r)),
interpolation=cv2.INTER_AREA
if r < 1 and not self.augment
else cv2.INTER_LINEAR,
)
return im, (h0, w0), im.shape[:2]
@staticmethod
def collate_fn(batch):
"""Merges a list of samples to form a mini-batch of Tensor(s)"""
img, label, path, shapes = zip(*batch)
for i, l in enumerate(label):
l[:, 0] = i # add target image index for build_targets()
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
def get_imgs_labels(self, img_dir):
assert osp.exists(img_dir), f"{img_dir} is an invalid directory path!"
valid_img_record = osp.join(
osp.dirname(img_dir), "." + osp.basename(img_dir) + ".json"
)
NUM_THREADS = min(8, os.cpu_count())
img_paths = glob.glob(osp.join(img_dir, "*"), recursive=True)
img_paths = sorted(
p for p in img_paths if p.split(".")[-1].lower() in IMG_FORMATS
)
assert img_paths, f"No images found in {img_dir}."
img_hash = self.get_hash(img_paths)
if osp.exists(valid_img_record):
with open(valid_img_record, "r") as f:
cache_info = json.load(f)
if "image_hash" in cache_info and cache_info["image_hash"] == img_hash:
img_info = cache_info["information"]
else:
self.check_images = True
else:
self.check_images = True
# check images
if self.check_images and self.main_process:
img_info = {}
nc, msgs = 0, [] # number corrupt, messages
LOGGER.info(
f"{self.task}: Checking formats of images with {NUM_THREADS} process(es): "
)
with Pool(NUM_THREADS) as pool:
pbar = tqdm(
pool.imap(TrainValDataset.check_image, img_paths),
total=len(img_paths),
)
for img_path, shape_per_img, nc_per_img, msg in pbar:
if nc_per_img == 0: # not corrupted
img_info[img_path] = {"shape": shape_per_img}
nc += nc_per_img
if msg:
msgs.append(msg)
pbar.desc = f"{nc} image(s) corrupted"
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
cache_info = {"information": img_info, "image_hash": img_hash}
# save valid image paths.
with open(valid_img_record, "w") as f:
json.dump(cache_info, f)
# check and load anns
label_dir = osp.join(
osp.dirname(osp.dirname(img_dir)), "labels", osp.basename(img_dir)
)
assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"
img_paths = list(img_info.keys())
label_paths = sorted(
osp.join(label_dir, osp.splitext(osp.basename(p))[0] + ".txt")
for p in img_paths
)
label_hash = self.get_hash(label_paths)
if "label_hash" not in cache_info or cache_info["label_hash"] != label_hash:
self.check_labels = True
if self.check_labels:
cache_info["label_hash"] = label_hash
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number corrupt, messages
LOGGER.info(
f"{self.task}: Checking formats of labels with {NUM_THREADS} process(es): "
)
with Pool(NUM_THREADS) as pool:
pbar = pool.imap(
TrainValDataset.check_label_files, zip(img_paths, label_paths)
)
pbar = tqdm(pbar, total=len(label_paths)) if self.main_process else pbar
for (
img_path,
labels_per_file,
nc_per_file,
nm_per_file,
nf_per_file,
ne_per_file,
msg,
) in pbar:
if nc_per_file == 0:
img_info[img_path]["labels"] = labels_per_file
else:
img_info.pop(img_path)
nc += nc_per_file
nm += nm_per_file
nf += nf_per_file
ne += ne_per_file
if msg:
msgs.append(msg)
if self.main_process:
pbar.desc = f"{nf} label(s) found, {nm} label(s) missing, {ne} label(s) empty, {nc} invalid label files"
if self.main_process:
pbar.close()
with open(valid_img_record, "w") as f:
json.dump(cache_info, f)
if msgs:
LOGGER.info("\n".join(msgs))
if nf == 0:
LOGGER.warning(
f"WARNING: No labels found in {osp.dirname(self.img_paths[0])}. "
)
if self.task.lower() == "val":
if self.data_dict.get("is_coco", False): # use original json file when evaluating on coco dataset.
assert osp.exists(self.data_dict["anno_path"]), "Eval on coco dataset must provide valid path of the annotation file in config file: data/coco.yaml"
else:
assert (
self.class_names
), "Class names is required when converting labels to coco format for evaluating."
save_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "annotations")
if not osp.exists(save_dir):
os.mkdir(save_dir)
save_path = osp.join(
save_dir, "instances_" + osp.basename(img_dir) + ".json"
)
TrainValDataset.generate_coco_format_labels(
img_info, self.class_names, save_path
)
img_paths, labels = list(
zip(
*[
(
img_path,
np.array(info["labels"], dtype=np.float32)
if info["labels"]
else np.zeros((0, 5), dtype=np.float32),
)
for img_path, info in img_info.items()
]
)
)
self.img_info = img_info
LOGGER.info(
f"{self.task}: Final numbers of valid images: {len(img_paths)}/ labels: {len(labels)}. "
)
return img_paths, labels
def get_mosaic(self, index):
"""Gets images and labels after mosaic augments"""
indices = [index] + random.choices(
range(0, len(self.img_paths)), k=3
) # 3 additional image indices
random.shuffle(indices)
imgs, hs, ws, labels = [], [], [], []
for index in indices:
img, _, (h, w) = self.load_image(index)
labels_per_img = self.labels[index]
imgs.append(img)
hs.append(h)
ws.append(w)
labels.append(labels_per_img)
img, labels = mosaic_augmentation(self.img_size, imgs, hs, ws, labels, self.hyp)
return img, labels
def general_augment(self, img, labels):
"""Gets images and labels after general augment
This function applies HSV, random up-down flip and random left-right flip augmentations.
"""
nl = len(labels)
# HSV color-space
augment_hsv(
img,
hgain=self.hyp["hsv_h"],
sgain=self.hyp["hsv_s"],
vgain=self.hyp["hsv_v"],
)
# Flip up-down
if random.random() < self.hyp["flipud"]:
img = np.flipud(img)
if nl:
labels[:, 2] = 1 - labels[:, 2]
# Flip left-right
if random.random() < self.hyp["fliplr"]:
img = np.fliplr(img)
if nl:
labels[:, 1] = 1 - labels[:, 1]
return img, labels
def sort_files_shapes(self):
# Sort by aspect ratio
batch_num = self.batch_indices[-1] + 1
s = self.shapes # wh
ar = s[:, 1] / s[:, 0] # aspect ratio
irect = ar.argsort()
self.img_paths = [self.img_paths[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
self.shapes = s[irect] # wh
ar = ar[irect]
# Set training image shapes
shapes = [[1, 1]] * batch_num
for i in range(batch_num):
ari = ar[self.batch_indices == i]
mini, maxi = ari.min(), ari.max()
if maxi < 1:
shapes[i] = [maxi, 1]
elif mini > 1:
shapes[i] = [1, 1 / mini]
self.batch_shapes = (
np.ceil(np.array(shapes) * self.img_size / self.stride + self.pad).astype(
int
)
* self.stride
)
@staticmethod
def check_image(im_file):
# verify an image.
nc, msg = 0, ""
try:
im = Image.open(im_file)
im.verify() # PIL verify
shape = im.size # (width, height)
im_exif = im._getexif()
if im_exif and ORIENTATION in im_exif:
rotation = im_exif[ORIENTATION]
if rotation in (6, 8):
shape = (shape[1], shape[0])
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
if im.format.lower() in ("jpg", "jpeg"):
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
ImageOps.exif_transpose(Image.open(im_file)).save(
im_file, "JPEG", subsampling=0, quality=100
)
msg += f"WARNING: {im_file}: corrupt JPEG restored and saved"
return im_file, shape, nc, msg
except Exception as e:
nc = 1
msg = f"WARNING: {im_file}: ignoring corrupt image: {e}"
return im_file, None, nc, msg
@staticmethod
def check_label_files(args):
img_path, lb_path = args
nm, nf, ne, nc, msg = 0, 0, 0, 0, "" # number missing, found, empty, corrupt; message
try:
if osp.exists(lb_path):
nf = 1 # label found
with open(lb_path, "r") as f:
labels = [
x.split() for x in f.read().strip().splitlines() if len(x)
]
labels = np.array(labels, dtype=np.float32)
if len(labels):
assert all(
len(l) == 5 for l in labels
), f"{lb_path}: wrong label format."
assert (
labels >= 0
).all(), f"{lb_path}: Label values error: all values in label file must > 0"
assert (
labels[:, 1:] <= 1
).all(), f"{lb_path}: Label values error: all coordinates must be normalized"
_, indices = np.unique(labels, axis=0, return_index=True)
if len(indices) < len(labels): # duplicate row check
msg += f"WARNING: {lb_path}: {len(labels) - len(indices)} duplicate labels removed"
labels = labels[indices] # remove duplicates
labels = labels.tolist()
else:
ne = 1 # label empty
labels = []
else:
nm = 1 # label missing
labels = []
return img_path, labels, nc, nm, nf, ne, msg
except Exception as e:
nc = 1
msg = f"WARNING: {lb_path}: ignoring invalid labels: {e}"
return img_path, None, nc, nm, nf, ne, msg
@staticmethod
def generate_coco_format_labels(img_info, class_names, save_path):
# for evaluation with pycocotools
dataset = {"categories": [], "annotations": [], "images": []}
for i, class_name in enumerate(class_names):
dataset["categories"].append(
{"id": i, "name": class_name, "supercategory": ""}
)
ann_id = 0
LOGGER.info(f"Convert to COCO format")
for i, (img_path, info) in enumerate(tqdm(img_info.items())):
labels = info["labels"] if info["labels"] else []
img_id = osp.splitext(osp.basename(img_path))[0]
img_id = int(img_id) if img_id.isnumeric() else img_id
img_w, img_h = info["shape"]
dataset["images"].append(
{
"file_name": os.path.basename(img_path),
"id": img_id,
"width": img_w,
"height": img_h,
}
)
if labels:
for label in labels:
c, x, y, w, h = label[:5]
# convert x,y,w,h to x1,y1,x2,y2
x1 = (x - w / 2) * img_w
y1 = (y - h / 2) * img_h
x2 = (x + w / 2) * img_w
y2 = (y + h / 2) * img_h
# cls_id starts from 0
cls_id = int(c)
w = max(0, x2 - x1)
h = max(0, y2 - y1)
dataset["annotations"].append(
{
"area": h * w,
"bbox": [x1, y1, w, h],
"category_id": cls_id,
"id": ann_id,
"image_id": img_id,
"iscrowd": 0,
# mask
"segmentation": [],
}
)
ann_id += 1
with open(save_path, "w") as f:
json.dump(dataset, f)
LOGGER.info(
f"Convert to COCO format finished. Resutls saved in {save_path}"
)
@staticmethod
def get_hash(paths):
"""Get the hash value of paths"""
assert isinstance(paths, list), "Only support list currently."
h = hashlib.md5("".join(paths).encode())
return h.hexdigest()
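if __name__ == "__main__":
    # Hedged worked example of the label conversion in generate_coco_format_labels:
    # a normalized YOLO label (x, y, w, h) = (0.5, 0.5, 0.25, 0.5) on a 640x480 image
    # becomes the absolute COCO bbox [x1, y1, w, h] = [240.0, 120.0, 160.0, 240.0].
    img_w, img_h = 640, 480
    x, y, w, h = 0.5, 0.5, 0.25, 0.5
    x1, y1 = (x - w / 2) * img_w, (y - h / 2) * img_h
    x2, y2 = (x + w / 2) * img_w, (y + h / 2) * img_h
    print([x1, y1, max(0, x2 - x1), max(0, y2 - y1)])  # [240.0, 120.0, 160.0, 240.0]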

@ -0,0 +1,501 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import warnings
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from yolov6.layers.dbb_transforms import *
class SiLU(nn.Module):
'''Activation of SiLU'''
@staticmethod
def forward(x):
return x * torch.sigmoid(x)
class Conv(nn.Module):
'''Normal Conv with SiLU activation'''
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
super().__init__()
padding = kernel_size // 2
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=bias,
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
return self.act(self.conv(x))
class SimConv(nn.Module):
'''Normal Conv with ReLU activation'''
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
super().__init__()
padding = kernel_size // 2
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=bias,
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.ReLU()
def forward(self, x):
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
return self.act(self.conv(x))
class SimSPPF(nn.Module):
'''Simplified SPPF with ReLU activation'''
def __init__(self, in_channels, out_channels, kernel_size=5):
super().__init__()
c_ = in_channels // 2 # hidden channels
self.cv1 = SimConv(in_channels, c_, 1, 1)
self.cv2 = SimConv(c_ * 4, out_channels, 1, 1)
self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
class Transpose(nn.Module):
'''Normal Transpose, default for upsampling'''
def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
super().__init__()
self.upsample_transpose = torch.nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
bias=True
)
def forward(self, x):
return self.upsample_transpose(x)
class Concat(nn.Module):
def __init__(self, dimension=1):
super().__init__()
self.d = dimension
def forward(self, x):
return torch.cat(x, self.d)
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
'''Basic cell for rep-style block, including conv and bn'''
result = nn.Sequential()
result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
return result
class RepBlock(nn.Module):
'''
RepBlock is a stage block with rep-style basic block
'''
def __init__(self, in_channels, out_channels, n=1):
super().__init__()
self.conv1 = RepVGGBlock(in_channels, out_channels)
self.block = nn.Sequential(*(RepVGGBlock(out_channels, out_channels) for _ in range(n - 1))) if n > 1 else None
def forward(self, x):
x = self.conv1(x)
if self.block is not None:
x = self.block(x)
return x
class RepVGGBlock(nn.Module):
'''RepVGGBlock is a basic rep-style block, including training and deploy status
This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
'''
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
super(RepVGGBlock, self).__init__()
""" Initialization of the class.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of
the input. Default: 1
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input
channels to output channels. Default: 1
padding_mode (string, optional): Default: 'zeros'
deploy: Whether to be deploy status or training status. Default: False
use_se: Whether to use se. Default: False
"""
self.deploy = deploy
self.groups = groups
self.in_channels = in_channels
self.out_channels = out_channels
assert kernel_size == 3
assert padding == 1
padding_11 = padding - kernel_size // 2
self.nonlinearity = nn.ReLU()
if use_se:
raise NotImplementedError("se block not supported yet")
else:
self.se = nn.Identity()
if deploy:
self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
else:
self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
def forward(self, inputs):
'''Forward process'''
if hasattr(self, 'rbr_reparam'):
return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
if self.rbr_identity is None:
id_out = 0
else:
id_out = self.rbr_identity(inputs)
return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
def _fuse_bn_tensor(self, branch):
if branch is None:
return 0, 0
if isinstance(branch, nn.Sequential):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, 'id_tensor'):
input_dim = self.in_channels // self.groups
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
for i in range(self.in_channels):
kernel_value[i, i % input_dim, 1, 1] = 1
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
def switch_to_deploy(self):
if hasattr(self, 'rbr_reparam'):
return
kernel, bias = self.get_equivalent_kernel_bias()
self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
self.rbr_reparam.weight.data = kernel
self.rbr_reparam.bias.data = bias
for para in self.parameters():
para.detach_()
self.__delattr__('rbr_dense')
self.__delattr__('rbr_1x1')
if hasattr(self, 'rbr_identity'):
self.__delattr__('rbr_identity')
if hasattr(self, 'id_tensor'):
self.__delattr__('id_tensor')
self.deploy = True
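# Illustrative self-check sketch (added for clarity, not part of the original file):
# switch_to_deploy() should fuse the 3x3, 1x1 and identity branches into a single
# convolution whose output matches the training-time block in eval mode.
def _demo_repvgg_reparam_equivalence():
    block = RepVGGBlock(in_channels=32, out_channels=32).eval()  # eval: BN uses running stats
    x = torch.randn(1, 32, 56, 56)
    with torch.no_grad():
        y_multi_branch = block(x)
        block.switch_to_deploy()
        y_fused = block(x)
    # The two outputs should agree up to floating-point error.
    assert torch.allclose(y_multi_branch, y_fused, atol=1e-5)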
def conv_bn_v2(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
padding_mode='zeros'):
conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
stride=stride, padding=padding, dilation=dilation, groups=groups,
bias=False, padding_mode=padding_mode)
bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
se = nn.Sequential()
se.add_module('conv', conv_layer)
se.add_module('bn', bn_layer)
return se
class IdentityBasedConv1x1(nn.Conv2d):
def __init__(self, channels, groups=1):
super(IdentityBasedConv1x1, self).__init__(in_channels=channels, out_channels=channels, kernel_size=1, stride=1, padding=0, groups=groups, bias=False)
assert channels % groups == 0
input_dim = channels // groups
id_value = np.zeros((channels, input_dim, 1, 1))
for i in range(channels):
id_value[i, i % input_dim, 0, 0] = 1
self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
nn.init.zeros_(self.weight)
def forward(self, input):
kernel = self.weight + self.id_tensor.to(self.weight.device)
result = F.conv2d(input, kernel, None, stride=1, padding=0, dilation=self.dilation, groups=self.groups)
return result
def get_actual_kernel(self):
return self.weight + self.id_tensor.to(self.weight.device)
class BNAndPadLayer(nn.Module):
def __init__(self,
pad_pixels,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True):
super(BNAndPadLayer, self).__init__()
self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
self.pad_pixels = pad_pixels
def forward(self, input):
output = self.bn(input)
if self.pad_pixels > 0:
if self.bn.affine:
pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(self.bn.running_var + self.bn.eps)
else:
pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)
output = F.pad(output, [self.pad_pixels] * 4)
pad_values = pad_values.view(1, -1, 1, 1)
output[:, :, 0:self.pad_pixels, :] = pad_values
output[:, :, -self.pad_pixels:, :] = pad_values
output[:, :, :, 0:self.pad_pixels] = pad_values
output[:, :, :, -self.pad_pixels:] = pad_values
return output
@property
def bn_weight(self):
return self.bn.weight
@property
def bn_bias(self):
return self.bn.bias
@property
def running_mean(self):
return self.bn.running_mean
@property
def running_var(self):
return self.bn.running_var
@property
def eps(self):
return self.bn.eps
class DBBBlock(nn.Module):
'''
DBBBlock is a stage block built from DiverseBranchBlock basic blocks
'''
def __init__(self, in_channels, out_channels, n=1):
super().__init__()
self.conv1 = DiverseBranchBlock(in_channels, out_channels)
self.block = nn.Sequential(*(DiverseBranchBlock(out_channels, out_channels) for _ in range(n - 1))) if n > 1 else None
def forward(self, x):
x = self.conv1(x)
if self.block is not None:
x = self.block(x)
return x
class DiverseBranchBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding=1, dilation=1, groups=1,
internal_channels_1x1_3x3=None,
deploy=False, nonlinear=nn.ReLU(), single_init=False):
super(DiverseBranchBlock, self).__init__()
self.deploy = deploy
if nonlinear is None:
self.nonlinear = nn.Identity()
else:
self.nonlinear = nonlinear
self.kernel_size = kernel_size
self.out_channels = out_channels
self.groups = groups
assert padding == kernel_size // 2
if deploy:
self.dbb_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=True)
else:
self.dbb_origin = conv_bn_v2(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups)
self.dbb_avg = nn.Sequential()
if groups < out_channels:
self.dbb_avg.add_module('conv',
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
stride=1, padding=0, groups=groups, bias=False))
self.dbb_avg.add_module('bn', BNAndPadLayer(pad_pixels=padding, num_features=out_channels))
self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
self.dbb_1x1 = conv_bn_v2(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,
padding=0, groups=groups)
else:
self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding))
self.dbb_avg.add_module('avgbn', nn.BatchNorm2d(out_channels))
if internal_channels_1x1_3x3 is None:
internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels
self.dbb_1x1_kxk = nn.Sequential()
if internal_channels_1x1_3x3 == in_channels:
self.dbb_1x1_kxk.add_module('idconv1', IdentityBasedConv1x1(channels=in_channels, groups=groups))
else:
self.dbb_1x1_kxk.add_module('conv1', nn.Conv2d(in_channels=in_channels, out_channels=internal_channels_1x1_3x3,
kernel_size=1, stride=1, padding=0, groups=groups, bias=False))
self.dbb_1x1_kxk.add_module('bn1', BNAndPadLayer(pad_pixels=padding, num_features=internal_channels_1x1_3x3, affine=True))
self.dbb_1x1_kxk.add_module('conv2', nn.Conv2d(in_channels=internal_channels_1x1_3x3, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=False))
self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))
# The experiments reported in the paper used the default initialization of bn.weight (all as 1). But changing the initialization may be useful in some cases.
if single_init:
# Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting.
self.single_init()
def get_equivalent_kernel_bias(self):
k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight, self.dbb_origin.bn)
if hasattr(self, 'dbb_1x1'):
k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn)
k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
else:
k_1x1, b_1x1 = 0, 0
if hasattr(self.dbb_1x1_kxk, 'idconv1'):
k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
else:
k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first, self.dbb_1x1_kxk.bn1)
k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2)
k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first, b_1x1_kxk_first, k_1x1_kxk_second, b_1x1_kxk_second, groups=self.groups)
k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device), self.dbb_avg.avgbn)
if hasattr(self.dbb_avg, 'conv'):
k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight, self.dbb_avg.bn)
k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first, b_1x1_avg_first, k_1x1_avg_second, b_1x1_avg_second, groups=self.groups)
else:
k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged), (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))
def switch_to_deploy(self):
if hasattr(self, 'dbb_reparam'):
return
kernel, bias = self.get_equivalent_kernel_bias()
self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels, out_channels=self.dbb_origin.conv.out_channels,
kernel_size=self.dbb_origin.conv.kernel_size, stride=self.dbb_origin.conv.stride,
padding=self.dbb_origin.conv.padding, dilation=self.dbb_origin.conv.dilation, groups=self.dbb_origin.conv.groups, bias=True)
self.dbb_reparam.weight.data = kernel
self.dbb_reparam.bias.data = bias
for para in self.parameters():
para.detach_()
self.__delattr__('dbb_origin')
self.__delattr__('dbb_avg')
if hasattr(self, 'dbb_1x1'):
self.__delattr__('dbb_1x1')
self.__delattr__('dbb_1x1_kxk')
def forward(self, inputs):
if hasattr(self, 'dbb_reparam'):
return self.nonlinear(self.dbb_reparam(inputs))
out = self.dbb_origin(inputs)
if hasattr(self, 'dbb_1x1'):
out += self.dbb_1x1(inputs)
out += self.dbb_avg(inputs)
out += self.dbb_1x1_kxk(inputs)
return self.nonlinear(out)
def init_gamma(self, gamma_value):
if hasattr(self, "dbb_origin"):
torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value)
if hasattr(self, "dbb_1x1"):
torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value)
if hasattr(self, "dbb_avg"):
torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value)
if hasattr(self, "dbb_1x1_kxk"):
torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value)
def single_init(self):
self.init_gamma(0.0)
if hasattr(self, "dbb_origin"):
torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
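# Illustrative self-check sketch (added for clarity, not part of the original file):
# DiverseBranchBlock.switch_to_deploy() collapses the origin, 1x1, average-pooling
# and 1x1-kxk branches into one convolution via the transI..transVI transforms;
# in eval mode the fused block should reproduce the multi-branch output.
def _demo_dbb_reparam_equivalence():
    dbb = DiverseBranchBlock(in_channels=16, out_channels=16).eval()
    x = torch.randn(1, 16, 32, 32)
    with torch.no_grad():
        y_multi_branch = dbb(x)
        dbb.switch_to_deploy()
        y_fused = dbb(x)
    assert torch.allclose(y_multi_branch, y_fused, atol=1e-4)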
class DetectBackend(nn.Module):
def __init__(self, weights='yolov6s.pt', device=None, dnn=True):
super().__init__()
assert isinstance(weights, str) and Path(weights).suffix == '.pt', f'{Path(weights).suffix} format is not supported.'
from yolov6.utils.checkpoint import load_checkpoint
model = load_checkpoint(weights, map_location=device)
stride = int(model.stride.max())
self.__dict__.update(locals()) # assign all variables to self
def forward(self, im, val=False):
y = self.model(im)
if isinstance(y, np.ndarray):
y = torch.tensor(y, device=self.device)
return y

@ -0,0 +1,50 @@
import torch
import numpy as np
import torch.nn.functional as F
def transI_fusebn(kernel, bn):
gamma = bn.weight
std = (bn.running_var + bn.eps).sqrt()
return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std
def transII_addbranch(kernels, biases):
return sum(kernels), sum(biases)
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
if groups == 1:
k = F.conv2d(k2, k1.permute(1, 0, 2, 3))  # fold the 1x1 kernel into the kxk kernel
b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
else:
k_slices = []
b_slices = []
k1_T = k1.permute(1, 0, 2, 3)
k1_group_width = k1.size(0) // groups
k2_group_width = k2.size(0) // groups
for g in range(groups):
k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
k_slices.append(F.conv2d(k2_slice, k1_T_slice))
b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
k, b_hat = transIV_depthconcat(k_slices, b_slices)
return k, b_hat + b2
def transIV_depthconcat(kernels, biases):
return torch.cat(kernels, dim=0), torch.cat(biases)
def transV_avg(channels, kernel_size, groups):
input_dim = channels // groups
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
return k
# This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
def transVI_multiscale(kernel, target_kernel_size):
H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])
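# Illustrative sketch (not part of the original file): transI_fusebn folds a BatchNorm
# that follows a convolution into the convolution's kernel and bias; with BN in eval
# mode, conv -> BN and the fused conv should produce the same output.
def _demo_fusebn():
    conv = torch.nn.Conv2d(8, 16, 3, padding=1, bias=False)
    bn = torch.nn.BatchNorm2d(16).eval()
    x = torch.randn(1, 8, 20, 20)
    k, b = transI_fusebn(conv.weight, bn)
    with torch.no_grad():
        y_ref = bn(conv(x))
        y_fused = F.conv2d(x, k, b, padding=1)
    assert torch.allclose(y_ref, y_fused, atol=1e-5)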

@ -0,0 +1,102 @@
from torch import nn
from yolov6.layers.common import RepVGGBlock, RepBlock, SimSPPF
class EfficientRep(nn.Module):
'''EfficientRep Backbone
EfficientRep is handcrafted by hardware-aware neural network design.
With its rep-style structure, EfficientRep is friendly to high-computation hardware (e.g. GPUs).
'''
def __init__(
self,
in_channels=3,
channels_list=None,
num_repeats=None,
):
super().__init__()
assert channels_list is not None
assert num_repeats is not None
self.stem = RepVGGBlock(
in_channels=in_channels,
out_channels=channels_list[0],
kernel_size=3,
stride=2
)
self.ERBlock_2 = nn.Sequential(
RepVGGBlock(
in_channels=channels_list[0],
out_channels=channels_list[1],
kernel_size=3,
stride=2
),
RepBlock(
in_channels=channels_list[1],
out_channels=channels_list[1],
n=num_repeats[1]
)
)
self.ERBlock_3 = nn.Sequential(
RepVGGBlock(
in_channels=channels_list[1],
out_channels=channels_list[2],
kernel_size=3,
stride=2
),
RepBlock(
in_channels=channels_list[2],
out_channels=channels_list[2],
n=num_repeats[2]
)
)
self.ERBlock_4 = nn.Sequential(
RepVGGBlock(
in_channels=channels_list[2],
out_channels=channels_list[3],
kernel_size=3,
stride=2
),
RepBlock(
in_channels=channels_list[3],
out_channels=channels_list[3],
n=num_repeats[3]
)
)
self.ERBlock_5 = nn.Sequential(
RepVGGBlock(
in_channels=channels_list[3],
out_channels=channels_list[4],
kernel_size=3,
stride=2,
),
RepBlock(
in_channels=channels_list[4],
out_channels=channels_list[4],
n=num_repeats[4]
),
SimSPPF(
in_channels=channels_list[4],
out_channels=channels_list[4],
kernel_size=5
)
)
def forward(self, x):
outputs = []
x = self.stem(x)
x = self.ERBlock_2(x)
x = self.ERBlock_3(x)
outputs.append(x)
x = self.ERBlock_4(x)
outputs.append(x)
x = self.ERBlock_5(x)
outputs.append(x)
return tuple(outputs)
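# Illustrative usage sketch (not part of the original file): channels_list and
# num_repeats below are made-up example values, not the official YOLOv6 config.
# The backbone returns the stride-8, stride-16 and stride-32 feature maps.
def _demo_efficientrep_shapes():
    import torch
    backbone = EfficientRep(in_channels=3,
                            channels_list=[16, 32, 64, 128, 256],
                            num_repeats=[1, 2, 2, 2, 1])
    x = torch.randn(1, 3, 256, 256)
    p3, p4, p5 = backbone(x)
    # Expected: [1, 64, 32, 32], [1, 128, 16, 16], [1, 256, 8, 8]
    print(p3.shape, p4.shape, p5.shape)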

@ -0,0 +1,211 @@
import torch
import torch.nn as nn
import math
from yolov6.layers.common import *
class Detect(nn.Module):
'''Efficient Decoupled Head
With hardware-aware design, the decoupled head is optimized with
a hybrid-channels method.
'''
def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_layers=None): # detection layer
super().__init__()
assert head_layers is not None
self.nc = num_classes # number of classes
self.no = num_classes + 5 # number of outputs per anchor
self.nl = num_layers # number of detection layers
if isinstance(anchors, (list, tuple)):
self.na = len(anchors[0]) // 2
else:
self.na = anchors
self.anchors = anchors
self.grid = [torch.zeros(1)] * num_layers
self.prior_prob = 1e-2
self.inplace = inplace
stride = [8, 16, 32] # strides computed during build
self.stride = torch.tensor(stride)
# Init decouple head
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
self.obj_preds = nn.ModuleList()
self.stems = nn.ModuleList()
# Efficient decoupled head layers
for i in range(num_layers):
idx = i*6
self.stems.append(head_layers[idx])
self.cls_convs.append(head_layers[idx+1])
self.reg_convs.append(head_layers[idx+2])
self.cls_preds.append(head_layers[idx+3])
self.reg_preds.append(head_layers[idx+4])
self.obj_preds.append(head_layers[idx+5])
def initialize_biases(self):
for conv in self.cls_preds:
b = conv.bias.view(self.na, -1)
b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
for conv in self.obj_preds:
b = conv.bias.view(self.na, -1)
b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
def forward(self, x):
z = []
for i in range(self.nl):
x[i] = self.stems[i](x[i])
cls_x = x[i]
reg_x = x[i]
cls_feat = self.cls_convs[i](cls_x)
cls_output = self.cls_preds[i](cls_feat)
reg_feat = self.reg_convs[i](reg_x)
reg_output = self.reg_preds[i](reg_feat)
obj_output = self.obj_preds[i](reg_feat)
if self.training:
x[i] = torch.cat([reg_output, obj_output, cls_output], 1)
bs, _, ny, nx = x[i].shape
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
else:
y = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
bs, _, ny, nx = y.shape
y = y.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
if self.grid[i].shape[2:4] != y.shape[2:4]:
d = self.stride.device
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
self.grid[i] = torch.stack((xv, yv), 2).view(1, self.na, ny, nx, 2).float()
if self.inplace:
y[..., 0:2] = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = torch.exp(y[..., 2:4]) * self.stride[i] # wh
else:
xy = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy
wh = torch.exp(y[..., 2:4]) * self.stride[i] # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))
return x if self.training else torch.cat(z, 1)
def build_effidehead_layer(channels_list, num_anchors, num_classes):
head_layers = nn.Sequential(
# stem0
Conv(
in_channels=channels_list[6],
out_channels=channels_list[6],
kernel_size=1,
stride=1
),
# cls_conv0
Conv(
in_channels=channels_list[6],
out_channels=channels_list[6],
kernel_size=3,
stride=1
),
# reg_conv0
Conv(
in_channels=channels_list[6],
out_channels=channels_list[6],
kernel_size=3,
stride=1
),
# cls_pred0
nn.Conv2d(
in_channels=channels_list[6],
out_channels=num_classes * num_anchors,
kernel_size=1
),
# reg_pred0
nn.Conv2d(
in_channels=channels_list[6],
out_channels=4 * num_anchors,
kernel_size=1
),
# obj_pred0
nn.Conv2d(
in_channels=channels_list[6],
out_channels=1 * num_anchors,
kernel_size=1
),
# stem1
Conv(
in_channels=channels_list[8],
out_channels=channels_list[8],
kernel_size=1,
stride=1
),
# cls_conv1
Conv(
in_channels=channels_list[8],
out_channels=channels_list[8],
kernel_size=3,
stride=1
),
# reg_conv1
Conv(
in_channels=channels_list[8],
out_channels=channels_list[8],
kernel_size=3,
stride=1
),
# cls_pred1
nn.Conv2d(
in_channels=channels_list[8],
out_channels=num_classes * num_anchors,
kernel_size=1
),
# reg_pred1
nn.Conv2d(
in_channels=channels_list[8],
out_channels=4 * num_anchors,
kernel_size=1
),
# obj_pred1
nn.Conv2d(
in_channels=channels_list[8],
out_channels=1 * num_anchors,
kernel_size=1
),
# stem2
Conv(
in_channels=channels_list[10],
out_channels=channels_list[10],
kernel_size=1,
stride=1
),
# cls_conv2
Conv(
in_channels=channels_list[10],
out_channels=channels_list[10],
kernel_size=3,
stride=1
),
# reg_conv2
Conv(
in_channels=channels_list[10],
out_channels=channels_list[10],
kernel_size=3,
stride=1
),
# cls_pred2
nn.Conv2d(
in_channels=channels_list[10],
out_channels=num_classes * num_anchors,
kernel_size=1
),
# reg_pred2
nn.Conv2d(
in_channels=channels_list[10],
out_channels=4 * num_anchors,
kernel_size=1
),
# obj_pred2
nn.Conv2d(
in_channels=channels_list[10],
out_channels=1 * num_anchors,
kernel_size=1
)
)
return head_layers
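# Illustrative usage sketch (not part of the original file): the head indexes
# channels_list[6], [8] and [10], i.e. the neck output channels of the three
# detection scales. The channel values below are made-up examples, not the
# official config.
def _demo_effidehead():
    channels_list = [16, 32, 64, 128, 256, 128, 64, 64, 128, 128, 256]
    head_layers = build_effidehead_layer(channels_list, num_anchors=1, num_classes=80)
    head = Detect(num_classes=80, anchors=1, num_layers=3, head_layers=head_layers)
    head.initialize_biases()
    feats = [torch.randn(1, channels_list[c], s, s) for c, s in ((6, 32), (8, 16), (10, 8))]
    outs = head(feats)  # training mode by default
    # One tensor per scale with shape [bs, n_anchors, ny, nx, 5 + num_classes].
    print([o.shape for o in outs])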

@ -0,0 +1,151 @@
import torch
import torch.nn as nn
import random
class ORT_NMS(torch.autograd.Function):
@staticmethod
def forward(ctx,
boxes,
scores,
max_output_boxes_per_class=torch.tensor([100]),
iou_threshold=torch.tensor([0.45]),
score_threshold=torch.tensor([0.25])):
device = boxes.device
batch = scores.shape[0]
num_det = random.randint(0, 100)
batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
idxs = torch.arange(100, 100 + num_det).to(device)
zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
selected_indices = selected_indices.to(torch.int64)
return selected_indices
@staticmethod
def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
class TRT_NMS(torch.autograd.Function):
@staticmethod
def forward(
ctx,
boxes,
scores,
background_class=-1,
box_coding=0,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version="1",
score_activation=0,
score_threshold=0.25,
):
batch_size, num_boxes, num_classes = scores.shape
num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
det_scores = torch.randn(batch_size, max_output_boxes)
det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
return num_det, det_boxes, det_scores, det_classes
@staticmethod
def symbolic(g,
boxes,
scores,
background_class=-1,
box_coding=0,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version="1",
score_activation=0,
score_threshold=0.25):
out = g.op("TRT::EfficientNMS_TRT",
boxes,
scores,
background_class_i=background_class,
box_coding_i=box_coding,
iou_threshold_f=iou_threshold,
max_output_boxes_i=max_output_boxes,
plugin_version_s=plugin_version,
score_activation_i=score_activation,
score_threshold_f=score_threshold,
outputs=4)
nums, boxes, scores, classes = out
return nums,boxes,scores,classes
class ONNX_ORT(nn.Module):
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None):
super().__init__()
self.device = device if device else torch.device("cpu")
self.max_obj = torch.tensor([max_obj]).to(device)
self.iou_threshold = torch.tensor([iou_thres]).to(device)
self.score_threshold = torch.tensor([score_thres]).to(device)
self.max_wh = max_wh
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
dtype=torch.float32,
device=self.device)
def forward(self, x):
box = x[:, :, :4]
conf = x[:, :, 4:5]
score = x[:, :, 5:]
score *= conf
box @= self.convert_matrix
objScore, objCls = score.max(2, keepdim=True)
dis = objCls.float() * self.max_wh
nmsbox = box + dis
objScore1 = objScore.transpose(1, 2).contiguous()
selected_indices = ORT_NMS.apply(nmsbox, objScore1, self.max_obj, self.iou_threshold, self.score_threshold)
X, Y = selected_indices[:, 0], selected_indices[:, 2]
resBoxes = box[X, Y, :]
resClasses = objCls[X, Y, :].float()
resScores = objScore[X, Y, :]
X = X.unsqueeze(1).float()
return torch.concat([X, resBoxes, resClasses, resScores], 1)
class ONNX_TRT(nn.Module):
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None):
super().__init__()
assert max_wh is None
self.device = device if device else torch.device('cpu')
self.background_class = -1  # trailing comma removed: these must be ints, not one-element tuples
self.box_coding = 0
self.iou_threshold = iou_thres
self.max_obj = max_obj
self.plugin_version = '1'
self.score_activation = 0
self.score_threshold = score_thres
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
dtype=torch.float32,
device=self.device)
def forward(self, x):
box = x[:, :, :4]
conf = x[:, :, 4:5]
score = x[:, :, 5:]
score *= conf
box @= self.convert_matrix
num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(box, score, self.background_class, self.box_coding,
self.iou_threshold, self.max_obj,
self.plugin_version, self.score_activation,
self.score_threshold)
return num_det, det_boxes, det_scores, det_classes
class End2End(nn.Module):
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
super().__init__()
device = device if device else torch.device('cpu')
self.model = model.to(device)
self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device)
self.end2end.eval()
def forward(self, x):
x = self.model(x)
x = self.end2end(x)
return x
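# Illustrative export sketch (not part of the original file): End2End wraps a detection
# model with an NMS op for ONNX export. With max_wh=None the TensorRT
# EfficientNMS_TRT path is used; with an integer max_wh the ONNX Runtime
# NonMaxSuppression path is used. 'model' and the file name are placeholders.
def _demo_end2end_export(model, onnx_path='yolov6_end2end.onnx'):
    device = torch.device('cpu')
    wrapped = End2End(model, max_obj=100, iou_thres=0.45, score_thres=0.25,
                      max_wh=None, device=device)
    dummy = torch.zeros(1, 3, 640, 640, device=device)
    torch.onnx.export(wrapped, dummy, onnx_path, opset_version=12,
                      input_names=['images'],
                      output_names=['num_dets', 'det_boxes', 'det_scores', 'det_classes'])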

@ -0,0 +1,411 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# The code is based on
# https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
# Copyright (c) Megvii, Inc. and its affiliates.
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from yolov6.utils.figure_iou import IOUloss, pairwise_bbox_iou
class ComputeLoss:
'''Loss computation function.
It contains SimOTA label assignment and SIoU loss.
'''
def __init__(self,
reg_weight=5.0,
iou_weight=3.0,
cls_weight=1.0,
center_radius=2.5,
eps=1e-7,
in_channels=[256, 512, 1024],
strides=[8, 16, 32],
n_anchors=1,
iou_type='ciou'
):
self.reg_weight = reg_weight
self.iou_weight = iou_weight
self.cls_weight = cls_weight
self.center_radius = center_radius
self.eps = eps
self.n_anchors = n_anchors
self.strides = strides
self.grids = [torch.zeros(1)] * len(in_channels)
# Define criteria
self.l1_loss = nn.L1Loss(reduction="none")
self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
self.iou_loss = IOUloss(iou_type=iou_type, reduction="none")
def __call__(
self,
outputs,
targets
):
dtype = outputs[0].type()
device = targets.device
loss_cls, loss_obj, loss_iou, loss_l1 = torch.zeros(1, device=device), torch.zeros(1, device=device), \
torch.zeros(1, device=device), torch.zeros(1, device=device)
num_classes = outputs[0].shape[-1] - 5
outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides = self.get_outputs_and_grids(
outputs, self.strides, dtype, device)
total_num_anchors = outputs.shape[1]
bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4]
bbox_preds_org = outputs_origin[:, :, :4] # [batch, n_anchors_all, 4]
obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1]
cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls]
# targets
batch_size = bbox_preds.shape[0]
targets_list = np.zeros((batch_size, 1, 5)).tolist()
for i, item in enumerate(targets.cpu().numpy().tolist()):
targets_list[int(item[0])].append(item[1:])
max_len = max((len(l) for l in targets_list))
targets = torch.from_numpy(np.array(list(map(lambda l:l + [[-1,0,0,0,0]]*(max_len - len(l)), targets_list)))[:,1:,:]).to(targets.device)
num_targets_list = (targets.sum(dim=2) > 0).sum(dim=1) # number of objects
num_fg, num_gts = 0, 0
cls_targets, reg_targets, l1_targets, obj_targets, fg_masks = [], [], [], [], []
for batch_idx in range(batch_size):
num_gt = int(num_targets_list[batch_idx])
num_gts += num_gt
if num_gt == 0:
cls_target = outputs.new_zeros((0, num_classes))
reg_target = outputs.new_zeros((0, 4))
l1_target = outputs.new_zeros((0, 4))
obj_target = outputs.new_zeros((total_num_anchors, 1))
fg_mask = outputs.new_zeros(total_num_anchors).bool()
else:
gt_bboxes_per_image = targets[batch_idx, :num_gt, 1:5].mul_(gt_bboxes_scale)
gt_classes = targets[batch_idx, :num_gt, 0]
bboxes_preds_per_image = bbox_preds[batch_idx]
cls_preds_per_image = cls_preds[batch_idx]
obj_preds_per_image = obj_preds[batch_idx]
try:
(
gt_matched_classes,
fg_mask,
pred_ious_this_matching,
matched_gt_inds,
num_fg_img,
) = self.get_assignments(
batch_idx,
num_gt,
total_num_anchors,
gt_bboxes_per_image,
gt_classes,
bboxes_preds_per_image,
cls_preds_per_image,
obj_preds_per_image,
expanded_strides,
xy_shifts,
num_classes
)
except RuntimeError:
print(
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
CPU mode is applied in this batch. If you want to avoid this issue, \
try to reduce the batch size or image size."
)
torch.cuda.empty_cache()
print("------------CPU Mode for This Batch-------------")
_gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
_gt_classes = gt_classes.cpu().float()
_bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
_cls_preds_per_image = cls_preds_per_image.cpu().float()
_obj_preds_per_image = obj_preds_per_image.cpu().float()
_expanded_strides = expanded_strides.cpu().float()
_xy_shifts = xy_shifts.cpu()
(
gt_matched_classes,
fg_mask,
pred_ious_this_matching,
matched_gt_inds,
num_fg_img,
) = self.get_assignments(
batch_idx,
num_gt,
total_num_anchors,
_gt_bboxes_per_image,
_gt_classes,
_bboxes_preds_per_image,
_cls_preds_per_image,
_obj_preds_per_image,
_expanded_strides,
_xy_shifts,
num_classes
)
gt_matched_classes = gt_matched_classes.cuda()
fg_mask = fg_mask.cuda()
pred_ious_this_matching = pred_ious_this_matching.cuda()
matched_gt_inds = matched_gt_inds.cuda()
torch.cuda.empty_cache()
num_fg += num_fg_img
if num_fg_img > 0:
cls_target = F.one_hot(
gt_matched_classes.to(torch.int64), num_classes
) * pred_ious_this_matching.unsqueeze(-1)
obj_target = fg_mask.unsqueeze(-1)
reg_target = gt_bboxes_per_image[matched_gt_inds]
l1_target = self.get_l1_target(
outputs.new_zeros((num_fg_img, 4)),
gt_bboxes_per_image[matched_gt_inds],
expanded_strides[0][fg_mask],
xy_shifts=xy_shifts[0][fg_mask],
)
cls_targets.append(cls_target)
reg_targets.append(reg_target)
obj_targets.append(obj_target)
l1_targets.append(l1_target)
fg_masks.append(fg_mask)
cls_targets = torch.cat(cls_targets, 0)
reg_targets = torch.cat(reg_targets, 0)
obj_targets = torch.cat(obj_targets, 0)
l1_targets = torch.cat(l1_targets, 0)
fg_masks = torch.cat(fg_masks, 0)
num_fg = max(num_fg, 1)
# loss
loss_iou += (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks].T, reg_targets)).sum() / num_fg
loss_l1 += (self.l1_loss(bbox_preds_org.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
loss_obj += (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets*1.0)).sum() / num_fg
loss_cls += (self.bcewithlog_loss(cls_preds.view(-1, num_classes)[fg_masks], cls_targets)).sum() / num_fg
total_losses = self.reg_weight * loss_iou + loss_l1 + loss_obj + loss_cls
return total_losses, torch.cat((self.reg_weight * loss_iou, loss_l1, loss_obj, loss_cls)).detach()
def decode_output(self, output, k, stride, dtype, device):
grid = self.grids[k].to(device)
batch_size = output.shape[0]
hsize, wsize = output.shape[2:4]
if grid.shape[2:4] != output.shape[2:4]:
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype).to(device)
self.grids[k] = grid
output = output.reshape(batch_size, self.n_anchors * hsize * wsize, -1)
output_origin = output.clone()
grid = grid.view(1, -1, 2)
output[..., :2] = (output[..., :2] + grid) * stride
output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
return output, output_origin, grid, hsize, wsize
def get_outputs_and_grids(self, outputs, strides, dtype, device):
xy_shifts = []
expanded_strides = []
outputs_new = []
outputs_origin = []
for k, output in enumerate(outputs):
output, output_origin, grid, feat_h, feat_w = self.decode_output(
output, k, strides[k], dtype, device)
xy_shift = grid
expanded_stride = torch.full((1, grid.shape[1], 1), strides[k], dtype=grid.dtype, device=grid.device)
xy_shifts.append(xy_shift)
expanded_strides.append(expanded_stride)
outputs_new.append(output)
outputs_origin.append(output_origin)
xy_shifts = torch.cat(xy_shifts, 1) # [1, n_anchors_all, 2]
expanded_strides = torch.cat(expanded_strides, 1) # [1, n_anchors_all, 1]
outputs_origin = torch.cat(outputs_origin, 1)
outputs = torch.cat(outputs_new, 1)
feat_h *= strides[-1]
feat_w *= strides[-1]
gt_bboxes_scale = torch.Tensor([[feat_w, feat_h, feat_w, feat_h]]).type_as(outputs)
return outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides
def get_l1_target(self, l1_target, gt, stride, xy_shifts, eps=1e-8):
l1_target[:, 0:2] = gt[:, 0:2] / stride - xy_shifts
l1_target[:, 2:4] = torch.log(gt[:, 2:4] / stride + eps)
return l1_target
@torch.no_grad()
def get_assignments(
self,
batch_idx,
num_gt,
total_num_anchors,
gt_bboxes_per_image,
gt_classes,
bboxes_preds_per_image,
cls_preds_per_image,
obj_preds_per_image,
expanded_strides,
xy_shifts,
num_classes
):
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
gt_bboxes_per_image,
expanded_strides,
xy_shifts,
total_num_anchors,
num_gt,
)
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
cls_preds_ = cls_preds_per_image[fg_mask]
obj_preds_ = obj_preds_per_image[fg_mask]
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
# cost
pair_wise_ious = pairwise_bbox_iou(gt_bboxes_per_image, bboxes_preds_per_image, box_format='xywh')
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
gt_cls_per_image = (
F.one_hot(gt_classes.to(torch.int64), num_classes)
.float()
.unsqueeze(1)
.repeat(1, num_in_boxes_anchor, 1)
)
with torch.cuda.amp.autocast(enabled=False):
cls_preds_ = (
cls_preds_.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
* obj_preds_.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
)
pair_wise_cls_loss = F.binary_cross_entropy(
cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
).sum(-1)
del cls_preds_, obj_preds_
cost = (
self.cls_weight * pair_wise_cls_loss
+ self.iou_weight * pair_wise_ious_loss
+ 100000.0 * (~is_in_boxes_and_center)
)
(
num_fg,
gt_matched_classes,
pred_ious_this_matching,
matched_gt_inds,
) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
return (
gt_matched_classes,
fg_mask,
pred_ious_this_matching,
matched_gt_inds,
num_fg,
)
def get_in_boxes_info(
self,
gt_bboxes_per_image,
expanded_strides,
xy_shifts,
total_num_anchors,
num_gt,
):
expanded_strides_per_image = expanded_strides[0]
xy_shifts_per_image = xy_shifts[0] * expanded_strides_per_image
xy_centers_per_image = (
(xy_shifts_per_image + 0.5 * expanded_strides_per_image)
.unsqueeze(0)
.repeat(num_gt, 1, 1)
) # [n_anchor, 2] -> [n_gt, n_anchor, 2]
gt_bboxes_per_image_lt = (
(gt_bboxes_per_image[:, 0:2] - 0.5 * gt_bboxes_per_image[:, 2:4])
.unsqueeze(1)
.repeat(1, total_num_anchors, 1)
)
gt_bboxes_per_image_rb = (
(gt_bboxes_per_image[:, 0:2] + 0.5 * gt_bboxes_per_image[:, 2:4])
.unsqueeze(1)
.repeat(1, total_num_anchors, 1)
) # [n_gt, 2] -> [n_gt, n_anchor, 2]
b_lt = xy_centers_per_image - gt_bboxes_per_image_lt
b_rb = gt_bboxes_per_image_rb - xy_centers_per_image
bbox_deltas = torch.cat([b_lt, b_rb], 2)
is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
# in fixed center
gt_bboxes_per_image_lt = (gt_bboxes_per_image[:, 0:2]).unsqueeze(1).repeat(
1, total_num_anchors, 1
) - self.center_radius * expanded_strides_per_image.unsqueeze(0)
gt_bboxes_per_image_rb = (gt_bboxes_per_image[:, 0:2]).unsqueeze(1).repeat(
1, total_num_anchors, 1
) + self.center_radius * expanded_strides_per_image.unsqueeze(0)
c_lt = xy_centers_per_image - gt_bboxes_per_image_lt
c_rb = gt_bboxes_per_image_rb - xy_centers_per_image
center_deltas = torch.cat([c_lt, c_rb], 2)
is_in_centers = center_deltas.min(dim=-1).values > 0.0
is_in_centers_all = is_in_centers.sum(dim=0) > 0
# in boxes and in centers
is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
is_in_boxes_and_center = (
is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
)
return is_in_boxes_anchor, is_in_boxes_and_center
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
ious_in_boxes_matrix = pair_wise_ious
n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
dynamic_ks = dynamic_ks.tolist()
for gt_idx in range(num_gt):
_, pos_idx = torch.topk(
cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
)
matching_matrix[gt_idx][pos_idx] = 1
del topk_ious, dynamic_ks, pos_idx
anchor_matching_gt = matching_matrix.sum(0)
if (anchor_matching_gt > 1).sum() > 0:
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
matching_matrix[:, anchor_matching_gt > 1] *= 0
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
fg_mask_inboxes = matching_matrix.sum(0) > 0
num_fg = fg_mask_inboxes.sum().item()
fg_mask[fg_mask.clone()] = fg_mask_inboxes
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
gt_matched_classes = gt_classes[matched_gt_inds]
pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
fg_mask_inboxes
]
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
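# Illustrative usage sketch (not part of the original file): ComputeLoss expects the
# training-mode head outputs (one [bs, n_anchors, ny, nx, 5 + num_classes] tensor per
# scale) and targets as a [num_labels, 6] tensor of (image_index, class, x, y, w, h)
# with boxes normalized to [0, 1]. Shapes and values below are made-up examples.
def _demo_compute_loss():
    criterion = ComputeLoss(iou_type='siou')
    outputs = [torch.randn(2, 1, s, s, 13) for s in (32, 16, 8)]  # 8-class model, 256x256 input
    targets = torch.tensor([[0., 3., 0.50, 0.50, 0.20, 0.30],
                            [1., 1., 0.25, 0.40, 0.10, 0.10]])
    total_loss, loss_items = criterion(outputs, targets)
    print(total_loss, loss_items)  # loss_items: iou, l1, obj, cls components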

@ -0,0 +1,108 @@
import torch
from torch import nn
from yolov6.layers.common import RepBlock, SimConv, Transpose
class RepPANNeck(nn.Module):
"""RepPANNeck Module
EfficientRep is the default backbone of this model.
RepPANNeck balances feature-fusion ability and hardware efficiency.
"""
def __init__(
self,
channels_list=None,
num_repeats=None
):
super().__init__()
assert channels_list is not None
assert num_repeats is not None
self.Rep_p4 = RepBlock(
in_channels=channels_list[3] + channels_list[5],
out_channels=channels_list[5],
n=num_repeats[5],
)
self.Rep_p3 = RepBlock(
in_channels=channels_list[2] + channels_list[6],
out_channels=channels_list[6],
n=num_repeats[6]
)
self.Rep_n3 = RepBlock(
in_channels=channels_list[6] + channels_list[7],
out_channels=channels_list[8],
n=num_repeats[7],
)
self.Rep_n4 = RepBlock(
in_channels=channels_list[5] + channels_list[9],
out_channels=channels_list[10],
n=num_repeats[8]
)
self.reduce_layer0 = SimConv(
in_channels=channels_list[4],
out_channels=channels_list[5],
kernel_size=1,
stride=1
)
self.upsample0 = Transpose(
in_channels=channels_list[5],
out_channels=channels_list[5],
)
self.reduce_layer1 = SimConv(
in_channels=channels_list[5],
out_channels=channels_list[6],
kernel_size=1,
stride=1
)
self.upsample1 = Transpose(
in_channels=channels_list[6],
out_channels=channels_list[6]
)
self.downsample2 = SimConv(
in_channels=channels_list[6],
out_channels=channels_list[7],
kernel_size=3,
stride=2
)
self.downsample1 = SimConv(
in_channels=channels_list[8],
out_channels=channels_list[9],
kernel_size=3,
stride=2
)
def forward(self, input):
(x2, x1, x0) = input
fpn_out0 = self.reduce_layer0(x0)
upsample_feat0 = self.upsample0(fpn_out0)
f_concat_layer0 = torch.cat([upsample_feat0, x1], 1)
f_out0 = self.Rep_p4(f_concat_layer0)
fpn_out1 = self.reduce_layer1(f_out0)
upsample_feat1 = self.upsample1(fpn_out1)
f_concat_layer1 = torch.cat([upsample_feat1, x2], 1)
pan_out2 = self.Rep_p3(f_concat_layer1)
down_feat1 = self.downsample2(pan_out2)
p_concat_layer1 = torch.cat([down_feat1, fpn_out1], 1)
pan_out1 = self.Rep_n3(p_concat_layer1)
down_feat0 = self.downsample1(pan_out1)
p_concat_layer2 = torch.cat([down_feat0, fpn_out0], 1)
pan_out0 = self.Rep_n4(p_concat_layer2)
outputs = [pan_out2, pan_out1, pan_out0]
return outputs
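# Illustrative usage sketch (not part of the original file): the neck consumes the
# (stride-8, stride-16, stride-32) backbone features and returns three fused maps
# with channels channels_list[6], [8] and [10]. Channel and repeat values below are
# made-up examples, not the official config.
def _demo_reppan_shapes():
    channels_list = [16, 32, 64, 128, 256, 128, 64, 64, 128, 128, 256]
    num_repeats = [1, 2, 2, 2, 1, 2, 2, 2, 2]
    neck = RepPANNeck(channels_list=channels_list, num_repeats=num_repeats)
    x2 = torch.randn(1, channels_list[2], 32, 32)  # stride 8
    x1 = torch.randn(1, channels_list[3], 16, 16)  # stride 16
    x0 = torch.randn(1, channels_list[4], 8, 8)    # stride 32
    p3, p4, p5 = neck((x2, x1, x0))
    print(p3.shape, p4.shape, p5.shape)  # 64, 128 and 256 channels at strides 8, 16, 32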

@ -0,0 +1,83 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import math
import torch.nn as nn
from yolov6.layers.common import *
from yolov6.utils.torch_utils import initialize_weights
from yolov6.models.efficientrep import EfficientRep
from yolov6.models.reppan import RepPANNeck
from yolov6.models.effidehead import Detect, build_effidehead_layer
class Model(nn.Module):
'''YOLOv6 model with backbone, neck and head.
The default parts are EfficientRep Backbone, Rep-PAN and
Efficient Decoupled Head.
'''
def __init__(self, config, channels=3, num_classes=None, anchors=None): # model, input channels, number of classes
super().__init__()
# Build network
num_layers = config.model.head.num_layers
self.backbone, self.neck, self.detect = build_network(config, channels, num_classes, anchors, num_layers)
# Init Detect head
begin_indices = config.model.head.begin_indices
out_indices_head = config.model.head.out_indices
self.stride = self.detect.stride
self.detect.i = begin_indices
self.detect.f = out_indices_head
self.detect.initialize_biases()
# Init weights
initialize_weights(self)
def forward(self, x):
x = self.backbone(x)
x = self.neck(x)
x = self.detect(x)
return x
def _apply(self, fn):
self = super()._apply(fn)
self.detect.stride = fn(self.detect.stride)
self.detect.grid = list(map(fn, self.detect.grid))
return self
def make_divisible(x, divisor):
# Round the value x up so that it is evenly divisible by the divisor.
return math.ceil(x / divisor) * divisor
def build_network(config, channels, num_classes, anchors, num_layers):
depth_mul = config.model.depth_multiple
width_mul = config.model.width_multiple
num_repeat_backbone = config.model.backbone.num_repeats
channels_list_backbone = config.model.backbone.out_channels
num_repeat_neck = config.model.neck.num_repeats
channels_list_neck = config.model.neck.out_channels
num_anchors = config.model.head.anchors
num_repeat = [(max(round(i * depth_mul), 1) if i > 1 else i) for i in (num_repeat_backbone + num_repeat_neck)]
channels_list = [make_divisible(i * width_mul, 8) for i in (channels_list_backbone + channels_list_neck)]
backbone = EfficientRep(
in_channels=channels,
channels_list=channels_list,
num_repeats=num_repeat
)
neck = RepPANNeck(
channels_list=channels_list,
num_repeats=num_repeat
)
head_layers = build_effidehead_layer(channels_list, num_anchors, num_classes)
head = Detect(num_classes, anchors, num_layers, head_layers=head_layers)
return backbone, neck, head
def build_model(cfg, num_classes, device):
model = Model(cfg, channels=3, num_classes=num_classes, anchors=cfg.model.head.anchors).to(device)
return model
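# Illustrative usage sketch (not part of the original file): a model is normally built
# from a python config file via Config.fromfile (e.g. configs/yolov6s.py in the YOLOv6
# repository); the config path and class count below are placeholders.
def _demo_build_model():
    import torch
    from yolov6.utils.config import Config
    cfg = Config.fromfile('configs/yolov6s.py')  # placeholder path
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = build_model(cfg, num_classes=80, device=device).eval()
    with torch.no_grad():
        preds = model(torch.zeros(1, 3, 640, 640, device=device))
    return preds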

@ -0,0 +1,42 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
import math
import torch
import torch.nn as nn
from yolov6.utils.events import LOGGER  # needed by build_lr_scheduler's error path
def build_optimizer(cfg, model):
""" Build optimizer from cfg file."""
g_bnw, g_w, g_b = [], [], []
for v in model.modules():
if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
g_b.append(v.bias)
if isinstance(v, nn.BatchNorm2d):
g_bnw.append(v.weight)
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
g_w.append(v.weight)
assert cfg.solver.optim in ('SGD', 'Adam'), 'ERROR: unknown optimizer, only SGD and Adam are supported'
if cfg.solver.optim == 'SGD':
optimizer = torch.optim.SGD(g_bnw, lr=cfg.solver.lr0, momentum=cfg.solver.momentum, nesterov=True)
elif cfg.solver.optim == 'Adam':
optimizer = torch.optim.Adam(g_bnw, lr=cfg.solver.lr0, betas=(cfg.solver.momentum, 0.999))
optimizer.add_param_group({'params': g_w, 'weight_decay': cfg.solver.weight_decay})
optimizer.add_param_group({'params': g_b})
del g_bnw, g_w, g_b
return optimizer
def build_lr_scheduler(cfg, optimizer, epochs):
"""Build learning rate scheduler from cfg file."""
if cfg.solver.lr_scheduler == 'Cosine':
lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg.solver.lrf - 1) + 1
else:
    LOGGER.error('unknown lr scheduler, use Cosine defaulted')
    lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg.solver.lrf - 1) + 1
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
return scheduler, lf
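# Illustrative usage sketch (not part of the original file): the cfg object only needs
# the solver fields read above; SimpleNamespace stands in for the real Config object
# and the hyper-parameter values are placeholders.
def _demo_build_solver(model, epochs=100):
    from types import SimpleNamespace
    cfg = SimpleNamespace(solver=SimpleNamespace(
        optim='SGD', lr0=0.01, lrf=0.01, momentum=0.937,
        weight_decay=5e-4, lr_scheduler='Cosine'))
    optimizer = build_optimizer(cfg, model)          # model should contain BN layers
    scheduler, lf = build_lr_scheduler(cfg, optimizer, epochs)
    for epoch in range(epochs):
        # ... forward/backward passes and optimizer.step() go here ...
        scheduler.step()                             # one scheduler step per epoch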

Binary file not shown.

@ -0,0 +1,60 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
import shutil
import torch
import os.path as osp
from yolov6.utils.events import LOGGER
from yolov6.utils.torch_utils import fuse_model
def load_state_dict(weights, model, map_location=None):
"""Load weights from checkpoint file, only assign weights those layers' name and shape are match."""
ckpt = torch.load(weights, map_location=map_location)
state_dict = ckpt['model'].float().state_dict()
model_state_dict = model.state_dict()
state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
model.load_state_dict(state_dict, strict=False)
del ckpt, state_dict, model_state_dict
return model
def load_checkpoint(weights, map_location=None, inplace=True, fuse=True):
"""Load model from checkpoint file."""
LOGGER.info("Loading checkpoint from {}".format(weights))
ckpt = torch.load(weights, map_location=map_location) # load
model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
if fuse:
LOGGER.info("\nFusing model...")
model = fuse_model(model).eval()
else:
model = model.eval()
return model
def save_checkpoint(ckpt, is_best, save_dir, model_name=""):
""" Save checkpoint to the disk."""
if not osp.exists(save_dir):
os.makedirs(save_dir)
filename = osp.join(save_dir, model_name + '.pt')
torch.save(ckpt, filename)
if is_best:
best_filename = osp.join(save_dir, 'best_ckpt.pt')
shutil.copyfile(filename, best_filename)
def strip_optimizer(ckpt_dir, epoch):
for s in ['best', 'last']:
ckpt_path = osp.join(ckpt_dir, '{}_ckpt.pt'.format(s))
if not osp.exists(ckpt_path):
continue
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
if ckpt.get('ema'):
ckpt['model'] = ckpt['ema'] # replace model with ema
for k in ['optimizer', 'ema', 'updates']: # keys
ckpt[k] = None
ckpt['epoch'] = epoch
ckpt['model'].half() # to FP16
for p in ckpt['model'].parameters():
p.requires_grad = False
torch.save(ckpt, ckpt_path)
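# Illustrative usage sketch (not part of the original file): a checkpoint is a dict
# holding the model (and optionally its EMA copy) plus training state; the keys below
# mirror the ones read by load_checkpoint and strip_optimizer, and the arguments are
# placeholders.
def _demo_checkpoint_roundtrip(model, ema, optimizer, epoch, save_dir='runs/train/exp'):
    from copy import deepcopy
    ckpt = {
        'model': deepcopy(model).half(),
        'ema': deepcopy(ema.ema).half() if ema else None,
        'updates': ema.updates if ema else None,
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    save_checkpoint(ckpt, is_best=False, save_dir=save_dir, model_name='last_ckpt')
    # Later, reload for inference (fuse=True folds conv+BN pairs where supported).
    return load_checkpoint(osp.join(save_dir, 'last_ckpt.pt'), map_location='cpu')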

@ -0,0 +1,101 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# The code is based on
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
# Copyright (c) OpenMMLab.
import os.path as osp
import shutil
import sys
import tempfile
from importlib import import_module
from addict import Dict
class ConfigDict(Dict):
def __missing__(self, name):
raise KeyError(name)
def __getattr__(self, name):
try:
value = super(ConfigDict, self).__getattr__(name)
except KeyError:
ex = AttributeError("'{}' object has no attribute '{}'".format(
self.__class__.__name__, name))
except Exception as e:
ex = e
else:
return value
raise ex
class Config(object):
@staticmethod
def _file2dict(filename):
filename = str(filename)
if filename.endswith('.py'):
with tempfile.TemporaryDirectory() as temp_config_dir:
shutil.copyfile(filename,
osp.join(temp_config_dir, '_tempconfig.py'))
sys.path.insert(0, temp_config_dir)
mod = import_module('_tempconfig')
sys.path.pop(0)
cfg_dict = {
name: value
for name, value in mod.__dict__.items()
if not name.startswith('__')
}
# delete imported module
del sys.modules['_tempconfig']
else:
raise IOError('Only .py type are supported now!')
cfg_text = filename + '\n'
with open(filename, 'r') as f:
cfg_text += f.read()
return cfg_dict, cfg_text
@staticmethod
def fromfile(filename):
cfg_dict, cfg_text = Config._file2dict(filename)
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
raise TypeError('cfg_dict must be a dict, but got {}'.format(
type(cfg_dict)))
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super(Config, self).__setattr__('_filename', filename)
if cfg_text:
text = cfg_text
elif filename:
with open(filename, 'r') as f:
text = f.read()
else:
text = ''
super(Config, self).__setattr__('_text', text)
@property
def filename(self):
return self._filename
@property
def text(self):
return self._text
def __repr__(self):
return 'Config (path: {}): {}'.format(self.filename,
self._cfg_dict.__repr__())
def __getattr__(self, name):
return getattr(self._cfg_dict, name)
def __setattr__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setattr__(name, value)

@ -0,0 +1,59 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# The code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py
import math
from copy import deepcopy
import torch
import torch.nn as nn
class ModelEMA:
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
"""
def __init__(self, model, decay=0.9999, updates=0):
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
self.updates = updates
self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
for param in self.ema.parameters():
param.requires_grad_(False)
def update(self, model):
with torch.no_grad():
self.updates += 1
decay = self.decay(self.updates)
state_dict = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
for k, item in self.ema.state_dict().items():
if item.dtype.is_floating_point:
item *= decay
item += (1 - decay) * state_dict[k].detach()
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
copy_attr(self.ema, model, include, exclude)
def copy_attr(a, b, include=(), exclude=()):
"""Copy attributes from one instance and set them to another instance."""
for k, item in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
continue
else:
setattr(a, k, item)
def is_parallel(model):
# Return True if model's type is DP or DDP, else False.
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
def de_parallel(model):
# De-parallelize a model. Return single-GPU model if model's type is DP or DDP.
return model.module if is_parallel(model) else model
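# Illustrative usage sketch (not part of the original file): the EMA copy is created
# once after the model is on its device and updated after every optimizer step; the
# smoothed weights are the ones usually evaluated and checkpointed. The attribute
# names passed to update_attr are typical detection-model attributes, used here as
# placeholders.
def _demo_ema_usage(model, optimizer, dataloader, compute_loss):
    ema = ModelEMA(model)
    for imgs, targets in dataloader:
        preds = model(imgs)
        loss, _ = compute_loss(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)                                        # track smoothed weights
    ema.update_attr(model, include=['nc', 'names', 'stride'])    # copy plain attributes
    return ema.ema                                               # model used for eval/saving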

@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from yolov6.utils.events import LOGGER
def get_envs():
"""Get PyTorch needed environments from system envirionments."""
local_rank = int(os.getenv('LOCAL_RANK', -1))
rank = int(os.getenv('RANK', -1))
world_size = int(os.getenv('WORLD_SIZE', 1))
return local_rank, rank, world_size
def select_device(device):
"""Set devices' information to the program.
Args:
device: a string, like 'cpu' or '1,2,3,4'
Returns:
torch.device
"""
if device == 'cpu':
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
LOGGER.info('Using CPU for training... ')
elif device:
os.environ['CUDA_VISIBLE_DEVICES'] = device
assert torch.cuda.is_available()
nd = len(device.strip().split(','))
LOGGER.info(f'Using {nd} GPU(s) for training... ')
cuda = device != 'cpu' and torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
return device
def set_random_seed(seed, deterministic=False):
""" Set random state to random libray, numpy, torch and cudnn.
Args:
seed: int value.
deterministic: bool value.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if deterministic:
cudnn.deterministic = True
cudnn.benchmark = False
else:
cudnn.deterministic = False
cudnn.benchmark = True

@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import yaml
import logging
import shutil
def set_logging(name=None):
rank = int(os.getenv('RANK', -1))
logging.basicConfig(format="%(message)s", level=logging.INFO if (rank in (-1, 0)) else logging.WARNING)
return logging.getLogger(name)
LOGGER = set_logging(__name__)
NCOLS = shutil.get_terminal_size().columns
def load_yaml(file_path):
"""Load data from yaml file."""
if isinstance(file_path, str):
with open(file_path, errors='ignore') as f:
data_dict = yaml.safe_load(f)
return data_dict
def save_yaml(data_dict, save_path):
"""Save data to yaml file"""
with open(save_path, 'w') as f:
yaml.safe_dump(data_dict, f, sort_keys=False)
def write_tblog(tblogger, epoch, results, losses):
"""Display mAP and loss information to log."""
tblogger.add_scalar("val/mAP@0.5", results[0], epoch + 1)
tblogger.add_scalar("val/mAP@0.50:0.95", results[1], epoch + 1)
tblogger.add_scalar("train/iou_loss", losses[0], epoch + 1)
tblogger.add_scalar("train/l1_loss", losses[1], epoch + 1)
tblogger.add_scalar("train/obj_loss", losses[2], epoch + 1)
tblogger.add_scalar("train/cls_loss", losses[3], epoch + 1)

@ -0,0 +1,114 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import math
import torch
class IOUloss:
""" Calculate IoU loss.
"""
def __init__(self, box_format='xywh', iou_type='ciou', reduction='none', eps=1e-7):
""" Setting of the class.
Args:
box_format: (string), must be one of 'xywh' or 'xyxy'.
iou_type: (string), can be one of 'ciou', 'diou', 'giou' or 'siou'
reduction: (string), specifies the reduction to apply to the output, must be one of 'none', 'mean','sum'.
eps: (float), a value to avoid divide by zero error.
"""
self.box_format = box_format
self.iou_type = iou_type.lower()
self.reduction = reduction
self.eps = eps
def __call__(self, box1, box2):
""" calculate iou. box1 and box2 are torch tensor with shape [M, 4] and [Nm 4].
"""
box2 = box2.T
if self.box_format == 'xyxy':
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
elif self.box_format == 'xywh':
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
# Intersection area
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
# Union Area
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + self.eps
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + self.eps
union = w1 * h1 + w2 * h2 - inter + self.eps
iou = inter / union
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex width
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
if self.iou_type == 'giou':
c_area = cw * ch + self.eps # convex area
iou = iou - (c_area - union) / c_area
elif self.iou_type in ['diou', 'ciou']:
c2 = cw ** 2 + ch ** 2 + self.eps # convex diagonal squared
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
(b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
if self.iou_type == 'diou':
iou = iou - rho2 / c2
elif self.iou_type == 'ciou':
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
with torch.no_grad():
alpha = v / (v - iou + (1 + self.eps))
iou = iou - (rho2 / c2 + v * alpha)
elif self.iou_type == 'siou':
# SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5
s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5
sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
sin_alpha_1 = torch.abs(s_cw) / sigma
sin_alpha_2 = torch.abs(s_ch) / sigma
threshold = pow(2, 0.5) / 2
sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
rho_x = (s_cw / cw) ** 2
rho_y = (s_ch / ch) ** 2
gamma = angle_cost - 2
distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
iou = iou - 0.5 * (distance_cost + shape_cost)
loss = 1.0 - iou
if self.reduction == 'sum':
loss = loss.sum()
elif self.reduction == 'mean':
loss = loss.mean()
return loss
def pairwise_bbox_iou(box1, box2, box_format='xywh'):
"""Calculate iou.
This code is based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/boxes.py
"""
if box_format == 'xyxy':
lt = torch.max(box1[:, None, :2], box2[:, :2])
rb = torch.min(box1[:, None, 2:], box2[:, 2:])
area_1 = torch.prod(box1[:, 2:] - box1[:, :2], 1)
area_2 = torch.prod(box2[:, 2:] - box2[:, :2], 1)
elif box_format == 'xywh':
lt = torch.max(
(box1[:, None, :2] - box1[:, None, 2:] / 2),
(box2[:, :2] - box2[:, 2:] / 2),
)
rb = torch.min(
(box1[:, None, :2] + box1[:, None, 2:] / 2),
(box2[:, :2] + box2[:, 2:] / 2),
)
area_1 = torch.prod(box1[:, 2:], 1)
area_2 = torch.prod(box2[:, 2:], 1)
valid = (lt < rb).type(lt.type()).prod(dim=2)
inter = torch.prod(rb - lt, 2) * valid
return inter / (area_1[:, None] + area_2 - inter)
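# Illustrative usage sketch (not part of the original file): IOUloss returns the
# element-wise 1 - IoU-variant per matched box pair (note that box1 is passed
# transposed, as ComputeLoss does), while pairwise_bbox_iou returns the full M x N
# IoU matrix used by the SimOTA assigner. The box values are made-up examples.
def _demo_iou():
    pred = torch.tensor([[50., 50., 20., 20.],
                         [30., 30., 10., 10.]])   # xywh
    gt = torch.tensor([[52., 49., 22., 18.],
                       [100., 100., 10., 10.]])   # xywh
    loss = IOUloss(box_format='xywh', iou_type='siou', reduction='none')(pred.T, gt)
    iou_matrix = pairwise_bbox_iou(gt, pred, box_format='xywh')
    print(loss, iou_matrix)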

@ -0,0 +1,17 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
from pathlib import Path
def increment_name(path, master_process):
"increase save directory's id"
path = Path(path)
sep = ''
if path.exists() and master_process:
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
for n in range(1, 9999):
p = f'{path}{sep}{n}{suffix}'
if not os.path.exists(p):
break
path = Path(p)
return path

@ -0,0 +1,106 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# The code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/general.py
import os
import time
import numpy as np
import cv2
import torch
import torchvision
# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
def xywh2xyxy(x):
# Convert boxes with shape [n, 4] from [x, y, w, h] to [x1, y1, x2, y2] where x1y1 is top-left, x2y2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
return y
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results.
    This code is borrowed from: https://github.com/ultralytics/yolov5/blob/47233e1698b89fc437a4fb9463c815e9171be955/utils/general.py#L775
    Args:
        prediction: (tensor), with shape [N, 5 + num_classes], N is the number of bboxes.
        conf_thres: (float) confidence threshold.
        iou_thres: (float) iou threshold.
        classes: (None or list[int]), if a list is provided, nms only keeps the classes you provide.
        agnostic: (bool), when set to True, we do class-independent nms; otherwise, each class does nms separately.
        multi_label: (bool), when set to True, one box can have multiple labels; otherwise, one box only has one label.
        max_det: (int), max number of output bboxes.
    Returns:
        list of detections, each item is one tensor with shape (num_boxes, 6), 6 is for [xyxy, conf, cls].
    """
    num_classes = prediction.shape[2] - 5  # number of classes
    pred_candidates = prediction[..., 4] > conf_thres  # candidates
    # Check the parameters.
    assert 0 <= conf_thres <= 1, f'conf_thresh must be in 0.0 to 1.0, however {conf_thres} is provided.'
    assert 0 <= iou_thres <= 1, f'iou_thres must be in 0.0 to 1.0, however {iou_thres} is provided.'
    # Function settings.
    max_wh = 4096  # maximum box width and height
    max_nms = 30000  # maximum number of boxes put into torchvision.ops.nms()
    time_limit = 10.0  # quit the function when nms cost time exceeds the limit.
    multi_label &= num_classes > 1  # multiple labels per box
    tik = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for img_idx, x in enumerate(prediction):  # image index, image inference
        x = x[pred_candidates[img_idx]]  # confidence
        # If no box remains, skip the next process.
        if not x.shape[0]:
            continue
        # confidence multiply the objectness
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
        # (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])
        # Detections matrix's shape is (n, 6), each row represents (xyxy, conf, cls)
        if multi_label:
            box_idx, class_idx = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[box_idx], x[box_idx, class_idx + 5, None], class_idx[:, None].float()), 1)
        else:  # Only keep the class with the highest score.
            conf, class_idx = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, class_idx.float()), 1)[conf.view(-1) > conf_thres]
        # Filter by class, only keep boxes whose category is in classes.
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
        # Check shape
        num_box = x.shape[0]  # number of boxes
        if not num_box:  # no boxes kept.
            continue
        elif num_box > max_nms:  # exceeds the max number of boxes.
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
        # Batched NMS
        class_offset = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + class_offset, x[:, 4]  # boxes (offset by class), scores
        keep_box_idx = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if keep_box_idx.shape[0] > max_det:  # limit detections
            keep_box_idx = keep_box_idx[:max_det]
        output[img_idx] = x[keep_box_idx]
        if (time.time() - tik) > time_limit:
            print(f'WARNING: NMS cost time exceeds the limit of {time_limit}s.')
            break  # time limit exceeded
    return output
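# Hypothetical usage sketch: run NMS on a dummy prediction for one image with
# 100 candidate boxes and 80 classes (shape [1, 100, 85] = [x, y, w, h, obj, cls...]).
import torch
dummy_pred = torch.rand(1, 100, 85)
detections = non_max_suppression(dummy_pred, conf_thres=0.25, iou_thres=0.45, max_det=300)
print(len(detections), detections[0].shape)  # 1 image, (num_kept, 6) rows of [xyxy, conf, cls]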

@@ -0,0 +1,109 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import time
from contextlib import contextmanager
from copy import deepcopy
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from yolov6.utils.events import LOGGER
try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training wait for each local master to do something.
    """
    if local_rank not in [-1, 0]:
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])
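# Hypothetical usage sketch: in DDP training, let local rank 0 run a one-off step
# (e.g. caching labels) while the other ranks wait at the barrier; with a single
# process (local_rank == -1) the context manager is a no-op.
import os
local_rank = int(os.getenv('LOCAL_RANK', -1))  # normally set by the launcher, e.g. torchrun
with torch_distributed_zero_first(local_rank):
    pass  # rank 0 would do the cache/preprocessing work here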
def time_sync():
    # Waits for all kernels in all streams on a CUDA device to complete if cuda is available.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
def initialize_weights(model):
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True
def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )
    # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
    # prepare spatial bias
    b_conv = (
        torch.zeros(conv.weight.size(0), device=conv.weight.device)
        if conv.bias is None
        else conv.bias
    )
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
        torch.sqrt(bn.running_var + bn.eps)
    )
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
    return fusedconv
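# Quick equivalence-check sketch with hypothetical layer sizes: in eval mode the
# fused convolution should reproduce Conv2d -> BatchNorm2d up to float precision.
import torch
import torch.nn as nn
conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(16).eval()
x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    fused = fuse_conv_and_bn(conv, bn)
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # expected: True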
def fuse_model(model):
    from yolov6.layers.common import Conv
    for m in model.modules():
        if type(m) is Conv and hasattr(m, "bn"):
            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
            delattr(m, "bn")  # remove batchnorm
            m.forward = m.forward_fuse  # update forward
    return model
def get_model_info(model, img_size=640):
    """Get model Params and GFlops.
    Code based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/model_utils.py
    """
    from thop import profile
    stride = 32
    img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
    params /= 1e6
    flops /= 1e9
    img_size = img_size if isinstance(img_size, list) else [img_size, img_size]
    flops *= img_size[0] * img_size[1] / stride / stride * 2  # Gflops
    info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
    return info
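# Hypothetical usage sketch (needs the optional 'thop' package installed): report
# parameter count and GFlops for a small stand-in model at a 640x640 input.
import torch.nn as nn
toy_model = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.SiLU())
print(get_model_info(toy_model, img_size=640))  # e.g. 'Params: 0.00M, Gflops: ...'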

Binary file not shown.

Some files were not shown because too many files have changed in this diff.
