Compare commits
1 Commits
Author | SHA1 | Date |
---|---|---|
pvqf6mep3 | e763ddb3a0 | 3 years ago |
@ -1,3 +1,3 @@
|
|||||||
# Air-ground-CEC
|
# Air-ground-CEC
|
||||||
|
|
||||||
test
|
此分支用于 测试 git 命令 不具有任何其他意义
|
Before Width: | Height: | Size: 3.1 MiB |
@ -0,0 +1,28 @@
|
|||||||
|
#ifndef MAINWINDOW_H
|
||||||
|
#define MAINWINDOW_H
|
||||||
|
|
||||||
|
#include <QMainWindow>
|
||||||
|
#include <QImage>
|
||||||
|
#include <QMutex>
|
||||||
|
|
||||||
|
namespace Ui {
|
||||||
|
class MainWindow;
|
||||||
|
}
|
||||||
|
|
||||||
|
class MainWindow : public QMainWindow
|
||||||
|
{
|
||||||
|
Q_OBJECT
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit MainWindow(QWidget *parent = nullptr);
|
||||||
|
~MainWindow();
|
||||||
|
void updateLogcamera();
|
||||||
|
void displayCamera(const QImage& image);
|
||||||
|
private:
|
||||||
|
Ui::MainWindow *ui;
|
||||||
|
//QNode qnode;
|
||||||
|
QImage qimage_;
|
||||||
|
mutable QMutex qinmage_mutex_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // MAINWINDOW_H
|
@ -0,0 +1,8 @@
|
|||||||
|
#include <opencv4/opencv2/core.hpp>
|
||||||
|
#include <opencv4/opencv2/highgui.hpp>
|
||||||
|
#include <opencv4/opencv2/imgproc.hpp>
|
||||||
|
#include <image_transport/image_transport.h>
|
||||||
|
#include <cv_bridge/cv_bridge.h>
|
||||||
|
#include <QImage>
|
||||||
|
|
||||||
|
|
@ -0,0 +1 @@
|
|||||||
|
create
|
@ -1,803 +0,0 @@
|
|||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
// STL A* Search implementation
|
|
||||||
// (C)2001 Justin Heyes-Jones
|
|
||||||
//
|
|
||||||
// This uses my A* code to solve the 8-puzzle
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <assert.h>
|
|
||||||
#include <new>
|
|
||||||
|
|
||||||
#include <ctype.h>
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
// Configuration
|
|
||||||
|
|
||||||
#define NUM_TIMES_TO_RUN_SEARCH 1
|
|
||||||
#define DISPLAY_SOLUTION_FORWARDS 1
|
|
||||||
#define DISPLAY_SOLUTION_BACKWARDS 0
|
|
||||||
#define DISPLAY_SOLUTION_INFO 1
|
|
||||||
#define DEBUG_LISTS 0
|
|
||||||
|
|
||||||
// AStar search class
|
|
||||||
#include "stlastar.h" // See header for copyright and usage information
|
|
||||||
|
|
||||||
// Global data
|
|
||||||
|
|
||||||
#define BOARD_WIDTH (3)
|
|
||||||
#define BOARD_HEIGHT (3)
|
|
||||||
|
|
||||||
#define GM_TILE (-1)
|
|
||||||
#define GM_SPACE (0)
|
|
||||||
#define GM_OFF_BOARD (1)
|
|
||||||
|
|
||||||
// Definitions
|
|
||||||
|
|
||||||
// To use the search class you must define the following calls...
|
|
||||||
|
|
||||||
// Data
|
|
||||||
// Your own state space information
|
|
||||||
// Functions
|
|
||||||
// (Optional) Constructor.
|
|
||||||
// Nodes are created by the user, so whether you use a
|
|
||||||
// constructor with parameters as below, or just set the object up after the
|
|
||||||
// constructor, is up to you.
|
|
||||||
//
|
|
||||||
// (Optional) Destructor.
|
|
||||||
// The destructor will be called if you create one. You
|
|
||||||
// can rely on the default constructor unless you dynamically allocate something in
|
|
||||||
// your data
|
|
||||||
//
|
|
||||||
// float GoalDistanceEstimate( PuzzleState &nodeGoal );
|
|
||||||
// Return the estimated cost to goal from this node (pass reference to goal node)
|
|
||||||
//
|
|
||||||
// bool IsGoal( PuzzleState &nodeGoal );
|
|
||||||
// Return true if this node is the goal.
|
|
||||||
//
|
|
||||||
// bool GetSuccessors( AStarSearch<PuzzleState> *astarsearch );
|
|
||||||
// For each successor to this state call the AStarSearch's AddSuccessor call to
|
|
||||||
// add each one to the current search - return false if you are out of memory and the search
|
|
||||||
// will fail
|
|
||||||
//
|
|
||||||
// float GetCost( PuzzleState *successor );
|
|
||||||
// Return the cost moving from this state to the state of successor
|
|
||||||
//
|
|
||||||
// bool IsSameState( PuzzleState &rhs );
|
|
||||||
// Return true if the provided state is the same as this state
|
|
||||||
|
|
||||||
// Here the example is the 8-puzzle state ...
|
|
||||||
class PuzzleState
|
|
||||||
{
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
// defs
|
|
||||||
|
|
||||||
typedef enum
|
|
||||||
{
|
|
||||||
TL_SPACE,
|
|
||||||
TL_1,
|
|
||||||
TL_2,
|
|
||||||
TL_3,
|
|
||||||
TL_4,
|
|
||||||
TL_5,
|
|
||||||
TL_6,
|
|
||||||
TL_7,
|
|
||||||
TL_8
|
|
||||||
|
|
||||||
} TILE;
|
|
||||||
|
|
||||||
// data
|
|
||||||
|
|
||||||
static TILE g_goal[ BOARD_WIDTH*BOARD_HEIGHT];
|
|
||||||
static TILE g_start[ BOARD_WIDTH*BOARD_HEIGHT];
|
|
||||||
|
|
||||||
// the tile data for the 8-puzzle
|
|
||||||
TILE tiles[ BOARD_WIDTH*BOARD_HEIGHT ];
|
|
||||||
|
|
||||||
// member functions
|
|
||||||
|
|
||||||
PuzzleState() {
|
|
||||||
memcpy( tiles, g_goal, sizeof( TILE ) * BOARD_WIDTH * BOARD_HEIGHT );
|
|
||||||
}
|
|
||||||
|
|
||||||
PuzzleState( TILE *param_tiles )
|
|
||||||
{
|
|
||||||
memcpy( tiles, param_tiles, sizeof( TILE ) * BOARD_WIDTH * BOARD_HEIGHT );
|
|
||||||
}
|
|
||||||
|
|
||||||
float GoalDistanceEstimate( PuzzleState &nodeGoal );
|
|
||||||
bool IsGoal( PuzzleState &nodeGoal );
|
|
||||||
bool GetSuccessors( AStarSearch<PuzzleState> *astarsearch, PuzzleState *parent_node );
|
|
||||||
float GetCost( PuzzleState &successor );
|
|
||||||
bool IsSameState( PuzzleState &rhs );
|
|
||||||
|
|
||||||
void PrintNodeInfo();
|
|
||||||
|
|
||||||
private:
|
|
||||||
// User stuff - Just add what you need to help you write the above functions...
|
|
||||||
|
|
||||||
void GetSpacePosition( PuzzleState *pn, int *rx, int *ry );
|
|
||||||
bool LegalMove( TILE *StartTiles, TILE *TargetTiles, int spx, int spy, int tx, int ty );
|
|
||||||
int GetMap( int x, int y, TILE *tiles );
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
// Goal state
|
|
||||||
PuzzleState::TILE PuzzleState::g_goal[] =
|
|
||||||
{
|
|
||||||
TL_1,
|
|
||||||
TL_2,
|
|
||||||
TL_3,
|
|
||||||
TL_8,
|
|
||||||
TL_SPACE,
|
|
||||||
TL_4,
|
|
||||||
TL_7,
|
|
||||||
TL_6,
|
|
||||||
TL_5,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Some nice Start states
|
|
||||||
PuzzleState::TILE PuzzleState::g_start[] =
|
|
||||||
{
|
|
||||||
|
|
||||||
// Three example start states from Bratko's Prolog Programming for Artificial Intelligence
|
|
||||||
|
|
||||||
#if 1
|
|
||||||
// ex a - 4 steps
|
|
||||||
TL_1 ,
|
|
||||||
TL_3 ,
|
|
||||||
TL_4 ,
|
|
||||||
TL_8 ,
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_2 ,
|
|
||||||
TL_7 ,
|
|
||||||
TL_6 ,
|
|
||||||
TL_5 ,
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
|
|
||||||
// ex b - 5 steps
|
|
||||||
TL_2 ,
|
|
||||||
TL_8 ,
|
|
||||||
TL_3 ,
|
|
||||||
TL_1 ,
|
|
||||||
TL_6 ,
|
|
||||||
TL_4 ,
|
|
||||||
TL_7 ,
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_5 ,
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
|
|
||||||
// ex c - 18 steps
|
|
||||||
TL_2 ,
|
|
||||||
TL_1 ,
|
|
||||||
TL_6 ,
|
|
||||||
TL_4 ,
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_8 ,
|
|
||||||
TL_7 ,
|
|
||||||
TL_5 ,
|
|
||||||
TL_3 ,
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
|
|
||||||
// nasty one - doesn't solve
|
|
||||||
TL_6 ,
|
|
||||||
TL_3 ,
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_4 ,
|
|
||||||
TL_8 ,
|
|
||||||
TL_5 ,
|
|
||||||
TL_7 ,
|
|
||||||
TL_2 ,
|
|
||||||
TL_1 ,
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
|
|
||||||
// sent by email - does work though
|
|
||||||
|
|
||||||
TL_1 , TL_2 , TL_3 ,
|
|
||||||
TL_4 , TL_5 , TL_6 ,
|
|
||||||
TL_8 , TL_7 , TL_SPACE ,
|
|
||||||
|
|
||||||
|
|
||||||
// from http://www.cs.utexas.edu/users/novak/asg-8p.html
|
|
||||||
|
|
||||||
//Goal: Easy: Medium: Hard: Worst:
|
|
||||||
|
|
||||||
//1 2 3 1 3 4 2 8 1 2 8 1 5 6 7
|
|
||||||
//8 4 8 6 2 4 3 4 6 3 4 8
|
|
||||||
//7 6 5 7 5 7 6 5 7 5 3 2 1
|
|
||||||
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
|
|
||||||
// easy 5
|
|
||||||
TL_1 ,
|
|
||||||
TL_3 ,
|
|
||||||
TL_4 ,
|
|
||||||
|
|
||||||
TL_8 ,
|
|
||||||
TL_6 ,
|
|
||||||
TL_2 ,
|
|
||||||
|
|
||||||
TL_7 ,
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_5 ,
|
|
||||||
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
|
|
||||||
// medium 9
|
|
||||||
TL_2 ,
|
|
||||||
TL_8 ,
|
|
||||||
TL_1 ,
|
|
||||||
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_4 ,
|
|
||||||
TL_3 ,
|
|
||||||
|
|
||||||
TL_7 ,
|
|
||||||
TL_6 ,
|
|
||||||
TL_5 ,
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
|
|
||||||
// hard 12
|
|
||||||
TL_2 ,
|
|
||||||
TL_8 ,
|
|
||||||
TL_1 ,
|
|
||||||
|
|
||||||
TL_4 ,
|
|
||||||
TL_6 ,
|
|
||||||
TL_3 ,
|
|
||||||
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_7 ,
|
|
||||||
TL_5 ,
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
|
|
||||||
// worst 30
|
|
||||||
TL_5 ,
|
|
||||||
TL_6 ,
|
|
||||||
TL_7 ,
|
|
||||||
|
|
||||||
TL_4 ,
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_8 ,
|
|
||||||
|
|
||||||
TL_3 ,
|
|
||||||
TL_2 ,
|
|
||||||
TL_1 ,
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
|
|
||||||
// 123
|
|
||||||
// 784
|
|
||||||
// 65
|
|
||||||
|
|
||||||
// two move simple board
|
|
||||||
TL_1 ,
|
|
||||||
TL_2 ,
|
|
||||||
TL_3 ,
|
|
||||||
|
|
||||||
TL_7 ,
|
|
||||||
TL_8 ,
|
|
||||||
TL_4 ,
|
|
||||||
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_6 ,
|
|
||||||
TL_5 ,
|
|
||||||
|
|
||||||
#elif 0
|
|
||||||
// a1 b2 c3 d4 e5 f6 g7 h8
|
|
||||||
//C3,Blank,H8,A1,G8,F6,E5,D4,B2
|
|
||||||
|
|
||||||
TL_3 ,
|
|
||||||
TL_SPACE ,
|
|
||||||
TL_8 ,
|
|
||||||
|
|
||||||
TL_1 ,
|
|
||||||
TL_8 ,
|
|
||||||
TL_6 ,
|
|
||||||
|
|
||||||
TL_5 ,
|
|
||||||
TL_4 ,
|
|
||||||
TL_2 ,
|
|
||||||
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
bool PuzzleState::IsSameState( PuzzleState &rhs )
|
|
||||||
{
|
|
||||||
|
|
||||||
for( int i=0; i<(BOARD_HEIGHT*BOARD_WIDTH); i++ )
|
|
||||||
{
|
|
||||||
if( tiles[i] != rhs.tiles[i] )
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
void PuzzleState::PrintNodeInfo()
|
|
||||||
{
|
|
||||||
char str[100];
|
|
||||||
sprintf( str, "%c %c %c\n%c %c %c\n%c %c %c\n",
|
|
||||||
tiles[0] + '0',
|
|
||||||
tiles[1] + '0',
|
|
||||||
tiles[2] + '0',
|
|
||||||
tiles[3] + '0',
|
|
||||||
tiles[4] + '0',
|
|
||||||
tiles[5] + '0',
|
|
||||||
tiles[6] + '0',
|
|
||||||
tiles[7] + '0',
|
|
||||||
tiles[8] + '0'
|
|
||||||
);
|
|
||||||
|
|
||||||
cout << str;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Here's the heuristic function that estimates the distance from a PuzzleState
|
|
||||||
// to the Goal.
|
|
||||||
|
|
||||||
float PuzzleState::GoalDistanceEstimate( PuzzleState &nodeGoal )
|
|
||||||
{
|
|
||||||
|
|
||||||
// Nilsson's sequence score
|
|
||||||
|
|
||||||
int i, cx, cy, ax, ay, h = 0, s, t;
|
|
||||||
|
|
||||||
// given a tile this returns the tile that should be clockwise
|
|
||||||
TILE correct_follower_to[ BOARD_WIDTH * BOARD_HEIGHT ] =
|
|
||||||
{
|
|
||||||
TL_SPACE, // always wrong
|
|
||||||
TL_2,
|
|
||||||
TL_3,
|
|
||||||
TL_4,
|
|
||||||
TL_5,
|
|
||||||
TL_6,
|
|
||||||
TL_7,
|
|
||||||
TL_8,
|
|
||||||
TL_1,
|
|
||||||
};
|
|
||||||
|
|
||||||
// given a table index returns the index of the tile that is clockwise to it 3*3 only
|
|
||||||
int clockwise_tile_of[ BOARD_WIDTH * BOARD_HEIGHT ] =
|
|
||||||
{
|
|
||||||
1,
|
|
||||||
2, // 012
|
|
||||||
5, // 345
|
|
||||||
0, // 678
|
|
||||||
-1, // never called with center square
|
|
||||||
8,
|
|
||||||
3,
|
|
||||||
6,
|
|
||||||
7
|
|
||||||
};
|
|
||||||
|
|
||||||
int tile_x[ BOARD_WIDTH * BOARD_HEIGHT ] =
|
|
||||||
{
|
|
||||||
/* TL_SPACE */ 1,
|
|
||||||
/* TL_1 */ 0,
|
|
||||||
/* TL_2 */ 1,
|
|
||||||
/* TL_3 */ 2,
|
|
||||||
/* TL_4 */ 2,
|
|
||||||
/* TL_5 */ 2,
|
|
||||||
/* TL_6 */ 1,
|
|
||||||
/* TL_7 */ 0,
|
|
||||||
/* TL_8 */ 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
int tile_y[ BOARD_WIDTH * BOARD_HEIGHT ] =
|
|
||||||
{
|
|
||||||
/* TL_SPACE */ 1,
|
|
||||||
/* TL_1 */ 0,
|
|
||||||
/* TL_2 */ 0,
|
|
||||||
/* TL_3 */ 0,
|
|
||||||
/* TL_4 */ 1,
|
|
||||||
/* TL_5 */ 2,
|
|
||||||
/* TL_6 */ 2,
|
|
||||||
/* TL_7 */ 2,
|
|
||||||
/* TL_8 */ 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
s=0;
|
|
||||||
|
|
||||||
// score 1 point if centre is not correct
|
|
||||||
if( tiles[(BOARD_HEIGHT*BOARD_WIDTH)/2] != nodeGoal.tiles[(BOARD_HEIGHT*BOARD_WIDTH)/2] )
|
|
||||||
{
|
|
||||||
s = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
for( i=0; i<(BOARD_HEIGHT*BOARD_WIDTH); i++ )
|
|
||||||
{
|
|
||||||
// this loop adds up the totaldist element in h and
|
|
||||||
// the sequence score in s
|
|
||||||
|
|
||||||
// the space does not count
|
|
||||||
if( tiles[i] == TL_SPACE )
|
|
||||||
{
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// get correct x and y of this tile
|
|
||||||
cx = tile_x[tiles[i]];
|
|
||||||
cy = tile_y[tiles[i]];
|
|
||||||
|
|
||||||
// get actual
|
|
||||||
ax = i % BOARD_WIDTH;
|
|
||||||
ay = i / BOARD_WIDTH;
|
|
||||||
|
|
||||||
// add manhatten distance to h
|
|
||||||
h += abs( cx-ax );
|
|
||||||
h += abs( cy-ay );
|
|
||||||
|
|
||||||
// no s score for center tile
|
|
||||||
if( (ax == (BOARD_WIDTH/2)) && (ay == (BOARD_HEIGHT/2)) )
|
|
||||||
{
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// score 2 points if not followed by successor
|
|
||||||
if( correct_follower_to[ tiles[i] ] != tiles[ clockwise_tile_of[ i ] ] )
|
|
||||||
{
|
|
||||||
s += 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// mult by 3 and add to h
|
|
||||||
t = h + (3*s);
|
|
||||||
|
|
||||||
return (float) t;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
bool PuzzleState::IsGoal( PuzzleState &nodeGoal )
|
|
||||||
{
|
|
||||||
return IsSameState( nodeGoal );
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper
|
|
||||||
// Return the x and y position of the space tile
|
|
||||||
void PuzzleState::GetSpacePosition( PuzzleState *pn, int *rx, int *ry )
|
|
||||||
{
|
|
||||||
int x,y;
|
|
||||||
|
|
||||||
for( y=0; y<BOARD_HEIGHT; y++ )
|
|
||||||
{
|
|
||||||
for( x=0; x<BOARD_WIDTH; x++ )
|
|
||||||
{
|
|
||||||
if( pn->tiles[(y*BOARD_WIDTH)+x] == TL_SPACE )
|
|
||||||
{
|
|
||||||
*rx = x;
|
|
||||||
*ry = y;
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert( false && "Something went wrong. There's no space on the board" );
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
int PuzzleState::GetMap( int x, int y, TILE *tiles )
|
|
||||||
{
|
|
||||||
|
|
||||||
if( x < 0 ||
|
|
||||||
x >= BOARD_WIDTH ||
|
|
||||||
y < 0 ||
|
|
||||||
y >= BOARD_HEIGHT
|
|
||||||
)
|
|
||||||
return GM_OFF_BOARD;
|
|
||||||
|
|
||||||
if( tiles[(y*BOARD_WIDTH)+x] == TL_SPACE )
|
|
||||||
{
|
|
||||||
return GM_SPACE;
|
|
||||||
}
|
|
||||||
|
|
||||||
return GM_TILE;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Given a node set of tiles and a set of tiles to move them into, do the move as if it was on a tile board
|
|
||||||
// note : returns false if the board wasn't changed, and simply returns the tiles as they were in the target
|
|
||||||
// spx and spy is the space position while tx and ty is the target move from position
|
|
||||||
|
|
||||||
bool PuzzleState::LegalMove( TILE *StartTiles, TILE *TargetTiles, int spx, int spy, int tx, int ty )
|
|
||||||
{
|
|
||||||
|
|
||||||
int t;
|
|
||||||
|
|
||||||
if( GetMap( spx, spy, StartTiles ) == GM_SPACE )
|
|
||||||
{
|
|
||||||
if( GetMap( tx, ty, StartTiles ) == GM_TILE )
|
|
||||||
{
|
|
||||||
|
|
||||||
// copy tiles
|
|
||||||
for( t=0; t<(BOARD_HEIGHT*BOARD_WIDTH); t++ )
|
|
||||||
{
|
|
||||||
TargetTiles[t] = StartTiles[t];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TargetTiles[ (ty*BOARD_WIDTH)+tx ] = StartTiles[ (spy*BOARD_WIDTH)+spx ];
|
|
||||||
TargetTiles[ (spy*BOARD_WIDTH)+spx ] = StartTiles[ (ty*BOARD_WIDTH)+tx ];
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
return false;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// This generates the successors to the given PuzzleState. It uses a helper function called
|
|
||||||
// AddSuccessor to give the successors to the AStar class. The A* specific initialisation
|
|
||||||
// is done for each node internally, so here you just set the state information that
|
|
||||||
// is specific to the application
|
|
||||||
bool PuzzleState::GetSuccessors( AStarSearch<PuzzleState> *astarsearch, PuzzleState *parent_node )
|
|
||||||
{
|
|
||||||
PuzzleState NewNode;
|
|
||||||
|
|
||||||
int sp_x,sp_y;
|
|
||||||
|
|
||||||
GetSpacePosition( this, &sp_x, &sp_y );
|
|
||||||
|
|
||||||
bool ret;
|
|
||||||
|
|
||||||
if( LegalMove( tiles, NewNode.tiles, sp_x, sp_y, sp_x, sp_y-1 ) == true )
|
|
||||||
{
|
|
||||||
ret = astarsearch->AddSuccessor( NewNode );
|
|
||||||
|
|
||||||
if( !ret ) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if( LegalMove( tiles, NewNode.tiles, sp_x, sp_y, sp_x, sp_y+1 ) == true )
|
|
||||||
{
|
|
||||||
ret = astarsearch->AddSuccessor( NewNode );
|
|
||||||
|
|
||||||
if( !ret ) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if( LegalMove( tiles, NewNode.tiles, sp_x, sp_y, sp_x-1, sp_y ) == true )
|
|
||||||
{
|
|
||||||
ret = astarsearch->AddSuccessor( NewNode );
|
|
||||||
|
|
||||||
if( !ret ) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if( LegalMove( tiles, NewNode.tiles, sp_x, sp_y, sp_x+1, sp_y ) == true )
|
|
||||||
{
|
|
||||||
ret = astarsearch->AddSuccessor( NewNode );
|
|
||||||
|
|
||||||
if( !ret ) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// given this node, what does it cost to move to successor. In the case
|
|
||||||
// of our map the answer is the map terrain value at this node since that is
|
|
||||||
// conceptually where we're moving
|
|
||||||
|
|
||||||
float PuzzleState::GetCost( PuzzleState &successor )
|
|
||||||
{
|
|
||||||
return 1.0f; // I love it when life is simple
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// Main
|
|
||||||
|
|
||||||
int puzzle( int argc, char *argv[] )
|
|
||||||
{
|
|
||||||
|
|
||||||
cout << "STL A* 8-puzzle solver implementation\n(C)2001 Justin Heyes-Jones\n";
|
|
||||||
|
|
||||||
if( argc > 1 )
|
|
||||||
{
|
|
||||||
int i = 0;
|
|
||||||
int c;
|
|
||||||
|
|
||||||
while( (c = argv[1][i]) )
|
|
||||||
{
|
|
||||||
if( isdigit( c ) )
|
|
||||||
{
|
|
||||||
int num = (c - '0');
|
|
||||||
|
|
||||||
PuzzleState::g_start[i] = static_cast<PuzzleState::TILE>(num);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create an instance of the search class...
|
|
||||||
|
|
||||||
AStarSearch<PuzzleState> astarsearch;
|
|
||||||
|
|
||||||
int NumTimesToSearch = NUM_TIMES_TO_RUN_SEARCH;
|
|
||||||
|
|
||||||
while( NumTimesToSearch-- )
|
|
||||||
{
|
|
||||||
|
|
||||||
// Create a start state
|
|
||||||
PuzzleState nodeStart( PuzzleState::g_start );
|
|
||||||
|
|
||||||
// Define the goal state
|
|
||||||
PuzzleState nodeEnd( PuzzleState::g_goal );
|
|
||||||
|
|
||||||
// Set Start and goal states
|
|
||||||
astarsearch.SetStartAndGoalStates( nodeStart, nodeEnd );
|
|
||||||
|
|
||||||
unsigned int SearchState;
|
|
||||||
|
|
||||||
unsigned int SearchSteps = 0;
|
|
||||||
|
|
||||||
do
|
|
||||||
{
|
|
||||||
SearchState = astarsearch.SearchStep();
|
|
||||||
|
|
||||||
#if DEBUG_LISTS
|
|
||||||
|
|
||||||
float f,g,h;
|
|
||||||
|
|
||||||
cout << "Search step " << SearchSteps << endl;
|
|
||||||
|
|
||||||
cout << "Open:\n";
|
|
||||||
PuzzleState *p = astarsearch.GetOpenListStart( f,g,h );
|
|
||||||
while( p )
|
|
||||||
{
|
|
||||||
((PuzzleState *)p)->PrintNodeInfo();
|
|
||||||
cout << "f: " << f << " g: " << g << " h: " << h << "\n\n";
|
|
||||||
|
|
||||||
p = astarsearch.GetOpenListNext( f,g,h );
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
cout << "Closed:\n";
|
|
||||||
p = astarsearch.GetClosedListStart( f,g,h );
|
|
||||||
while( p )
|
|
||||||
{
|
|
||||||
p->PrintNodeInfo();
|
|
||||||
cout << "f: " << f << " g: " << g << " h: " << h << "\n\n";
|
|
||||||
|
|
||||||
p = astarsearch.GetClosedListNext( f,g,h );
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Test cancel search
|
|
||||||
#if 0
|
|
||||||
int StepCount = astarsearch.GetStepCount();
|
|
||||||
if( StepCount == 10 )
|
|
||||||
{
|
|
||||||
astarsearch.CancelSearch();
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
SearchSteps++;
|
|
||||||
}
|
|
||||||
while( SearchState == AStarSearch<PuzzleState>::SEARCH_STATE_SEARCHING );
|
|
||||||
|
|
||||||
if( SearchState == AStarSearch<PuzzleState>::SEARCH_STATE_SUCCEEDED )
|
|
||||||
{
|
|
||||||
#if DISPLAY_SOLUTION_FORWARDS
|
|
||||||
cout << "Search found goal state\n";
|
|
||||||
#endif
|
|
||||||
PuzzleState *node = astarsearch.GetSolutionStart();
|
|
||||||
|
|
||||||
#if DISPLAY_SOLUTION_FORWARDS
|
|
||||||
cout << "Displaying solution\n";
|
|
||||||
#endif
|
|
||||||
int steps = 0;
|
|
||||||
|
|
||||||
#if DISPLAY_SOLUTION_FORWARDS
|
|
||||||
node->PrintNodeInfo();
|
|
||||||
cout << endl;
|
|
||||||
#endif
|
|
||||||
for( ;; )
|
|
||||||
{
|
|
||||||
node = astarsearch.GetSolutionNext();
|
|
||||||
|
|
||||||
if( !node )
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
#if DISPLAY_SOLUTION_FORWARDS
|
|
||||||
node->PrintNodeInfo();
|
|
||||||
cout << endl;
|
|
||||||
#endif
|
|
||||||
steps ++;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#if DISPLAY_SOLUTION_FORWARDS
|
|
||||||
// todo move step count into main algorithm
|
|
||||||
cout << "Solution steps " << steps << endl;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
////////////
|
|
||||||
|
|
||||||
node = astarsearch.GetSolutionEnd();
|
|
||||||
|
|
||||||
#if DISPLAY_SOLUTION_BACKWARDS
|
|
||||||
cout << "Displaying reverse solution\n";
|
|
||||||
#endif
|
|
||||||
steps = 0;
|
|
||||||
|
|
||||||
node->PrintNodeInfo();
|
|
||||||
cout << endl;
|
|
||||||
for( ;; )
|
|
||||||
{
|
|
||||||
node = astarsearch.GetSolutionPrev();
|
|
||||||
|
|
||||||
if( !node )
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
#if DISPLAY_SOLUTION_BACKWARDS
|
|
||||||
node->PrintNodeInfo();
|
|
||||||
cout << endl;
|
|
||||||
#endif
|
|
||||||
steps ++;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#if DISPLAY_SOLUTION_BACKWARDS
|
|
||||||
cout << "Solution steps " << steps << endl;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
//////////////
|
|
||||||
|
|
||||||
// Once you're done with the solution you can free the nodes up
|
|
||||||
astarsearch.FreeSolutionNodes();
|
|
||||||
|
|
||||||
}
|
|
||||||
else if( SearchState == AStarSearch<PuzzleState>::SEARCH_STATE_FAILED )
|
|
||||||
{
|
|
||||||
#if DISPLAY_SOLUTION_INFO
|
|
||||||
cout << "Search terminated. Did not find goal state\n";
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
else if( SearchState == AStarSearch<PuzzleState>::SEARCH_STATE_OUT_OF_MEMORY )
|
|
||||||
{
|
|
||||||
#if DISPLAY_SOLUTION_INFO
|
|
||||||
cout << "Search terminated. Out of memory\n";
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Display the number of loops the search went through
|
|
||||||
#if DISPLAY_SOLUTION_INFO
|
|
||||||
cout << "SearchSteps : " << astarsearch.GetStepCount() << endl;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,343 +0,0 @@
|
|||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// STL A* Search implementation
|
|
||||||
// (C)2001 Justin Heyes-Jones
|
|
||||||
//
|
|
||||||
// Finding a path on a simple grid maze
|
|
||||||
// This shows how to do shortest path finding using A*
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#include "stlastar.h" // See header for copyright and usage information
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <math.h>
|
|
||||||
|
|
||||||
#define DEBUG_LISTS 0
|
|
||||||
#define DEBUG_LIST_LENGTHS_ONLY 0
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
// Global data
|
|
||||||
|
|
||||||
// The world map
|
|
||||||
|
|
||||||
const int MAP_WIDTH = 20;
|
|
||||||
const int MAP_HEIGHT = 20;
|
|
||||||
|
|
||||||
int world_map[ MAP_WIDTH * MAP_HEIGHT ] =
|
|
||||||
{
|
|
||||||
|
|
||||||
// 0001020304050607080910111213141516171819
|
|
||||||
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, // 00
|
|
||||||
1,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,1, // 01
|
|
||||||
1,9,9,1,1,9,9,9,1,9,1,9,1,9,1,9,9,9,1,1, // 02
|
|
||||||
1,9,9,1,1,9,9,9,1,9,1,9,1,9,1,9,9,9,1,1, // 03
|
|
||||||
1,9,1,1,1,1,9,9,1,9,1,9,1,1,1,1,9,9,1,1, // 04
|
|
||||||
1,9,1,1,9,1,1,1,1,9,1,1,1,1,9,1,1,1,1,1, // 05
|
|
||||||
1,9,9,9,9,1,1,1,1,1,1,9,9,9,9,1,1,1,1,1, // 06
|
|
||||||
1,9,9,9,9,9,9,9,9,1,1,1,9,9,9,9,9,9,9,1, // 07
|
|
||||||
1,9,1,1,1,1,1,1,1,1,1,9,1,1,1,1,1,1,1,1, // 08
|
|
||||||
1,9,1,9,9,9,9,9,9,9,1,1,9,9,9,9,9,9,9,1, // 09
|
|
||||||
1,9,1,1,1,1,9,1,1,9,1,1,1,1,1,1,1,1,1,1, // 10
|
|
||||||
1,9,9,9,9,9,1,9,1,9,1,9,9,9,9,9,1,1,1,1, // 11
|
|
||||||
1,9,1,9,1,9,9,9,1,9,1,9,1,9,1,9,9,9,1,1, // 12
|
|
||||||
1,9,1,9,1,9,9,9,1,9,1,9,1,9,1,9,9,9,1,1, // 13
|
|
||||||
1,9,1,1,1,1,9,9,1,9,1,9,1,1,1,1,9,9,1,1, // 14
|
|
||||||
1,9,1,1,9,1,1,1,1,9,1,1,1,1,9,1,1,1,1,1, // 15
|
|
||||||
1,9,9,9,9,1,1,1,1,1,1,9,9,9,9,1,1,1,1,1, // 16
|
|
||||||
1,1,9,9,9,9,9,9,9,1,1,1,9,9,9,1,9,9,9,9, // 17
|
|
||||||
1,9,1,1,1,1,1,1,1,1,1,9,1,1,1,1,1,1,1,1, // 18
|
|
||||||
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, // 19
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
// map helper functions
|
|
||||||
|
|
||||||
int GetMap( int x, int y )
|
|
||||||
{
|
|
||||||
if( x < 0 ||
|
|
||||||
x >= MAP_WIDTH ||
|
|
||||||
y < 0 ||
|
|
||||||
y >= MAP_HEIGHT
|
|
||||||
)
|
|
||||||
{
|
|
||||||
return 9;
|
|
||||||
}
|
|
||||||
|
|
||||||
return world_map[(y*MAP_WIDTH)+x];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Definitions
|
|
||||||
|
|
||||||
class MapSearchNode
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
int x; // the (x,y) positions of the node
|
|
||||||
int y;
|
|
||||||
|
|
||||||
MapSearchNode() { x = y = 0; }
|
|
||||||
MapSearchNode( int px, int py ) { x=px; y=py; }
|
|
||||||
|
|
||||||
float GoalDistanceEstimate( MapSearchNode &nodeGoal );
|
|
||||||
bool IsGoal( MapSearchNode &nodeGoal );
|
|
||||||
bool GetSuccessors( AStarSearch<MapSearchNode> *astarsearch, MapSearchNode *parent_node );
|
|
||||||
float GetCost( MapSearchNode &successor );
|
|
||||||
bool IsSameState( MapSearchNode &rhs );
|
|
||||||
|
|
||||||
void PrintNodeInfo();
|
|
||||||
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
bool MapSearchNode::IsSameState( MapSearchNode &rhs )
|
|
||||||
{
|
|
||||||
|
|
||||||
// same state in a maze search is simply when (x,y) are the same
|
|
||||||
if( (x == rhs.x) &&
|
|
||||||
(y == rhs.y) )
|
|
||||||
{
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
void MapSearchNode::PrintNodeInfo()
|
|
||||||
{
|
|
||||||
char str[100];
|
|
||||||
sprintf( str, "Node position : (%d,%d)\n", x,y );
|
|
||||||
|
|
||||||
cout << str;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Here's the heuristic function that estimates the distance from a Node
|
|
||||||
// to the Goal.
|
|
||||||
|
|
||||||
float MapSearchNode::GoalDistanceEstimate( MapSearchNode &nodeGoal )
|
|
||||||
{
|
|
||||||
return abs(x - nodeGoal.x) + abs(y - nodeGoal.y);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool MapSearchNode::IsGoal( MapSearchNode &nodeGoal )
|
|
||||||
{
|
|
||||||
|
|
||||||
if( (x == nodeGoal.x) &&
|
|
||||||
(y == nodeGoal.y) )
|
|
||||||
{
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// This generates the successors to the given Node. It uses a helper function called
|
|
||||||
// AddSuccessor to give the successors to the AStar class. The A* specific initialisation
|
|
||||||
// is done for each node internally, so here you just set the state information that
|
|
||||||
// is specific to the application
|
|
||||||
bool MapSearchNode::GetSuccessors( AStarSearch<MapSearchNode> *astarsearch, MapSearchNode *parent_node )
|
|
||||||
{
|
|
||||||
|
|
||||||
int parent_x = -1;
|
|
||||||
int parent_y = -1;
|
|
||||||
|
|
||||||
if( parent_node )
|
|
||||||
{
|
|
||||||
parent_x = parent_node->x;
|
|
||||||
parent_y = parent_node->y;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
MapSearchNode NewNode;
|
|
||||||
|
|
||||||
// push each possible move except allowing the search to go backwards
|
|
||||||
|
|
||||||
if( (GetMap( x-1, y ) < 9)
|
|
||||||
&& !((parent_x == x-1) && (parent_y == y))
|
|
||||||
)
|
|
||||||
{
|
|
||||||
NewNode = MapSearchNode( x-1, y );
|
|
||||||
astarsearch->AddSuccessor( NewNode );
|
|
||||||
}
|
|
||||||
|
|
||||||
if( (GetMap( x, y-1 ) < 9)
|
|
||||||
&& !((parent_x == x) && (parent_y == y-1))
|
|
||||||
)
|
|
||||||
{
|
|
||||||
NewNode = MapSearchNode( x, y-1 );
|
|
||||||
astarsearch->AddSuccessor( NewNode );
|
|
||||||
}
|
|
||||||
|
|
||||||
if( (GetMap( x+1, y ) < 9)
|
|
||||||
&& !((parent_x == x+1) && (parent_y == y))
|
|
||||||
)
|
|
||||||
{
|
|
||||||
NewNode = MapSearchNode( x+1, y );
|
|
||||||
astarsearch->AddSuccessor( NewNode );
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
if( (GetMap( x, y+1 ) < 9)
|
|
||||||
&& !((parent_x == x) && (parent_y == y+1))
|
|
||||||
)
|
|
||||||
{
|
|
||||||
NewNode = MapSearchNode( x, y+1 );
|
|
||||||
astarsearch->AddSuccessor( NewNode );
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// given this node, what does it cost to move to successor. In the case
|
|
||||||
// of our map the answer is the map terrain value at this node since that is
|
|
||||||
// conceptually where we're moving
|
|
||||||
|
|
||||||
float MapSearchNode::GetCost( MapSearchNode &successor )
|
|
||||||
{
|
|
||||||
return (float) GetMap( x, y );
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// Main
|
|
||||||
|
|
||||||
int findpath( int argc, char *argv[] )
|
|
||||||
{
|
|
||||||
|
|
||||||
cout << "STL A* Search implementation\n(C)2001 Justin Heyes-Jones\n";
|
|
||||||
|
|
||||||
// Our sample problem defines the world as a 2d array representing a terrain
|
|
||||||
// Each element contains an integer from 0 to 5 which indicates the cost
|
|
||||||
// of travel across the terrain. Zero means the least possible difficulty
|
|
||||||
// in travelling (think ice rink if you can skate) whilst 5 represents the
|
|
||||||
// most difficult. 9 indicates that we cannot pass.
|
|
||||||
|
|
||||||
// Create an instance of the search class...
|
|
||||||
|
|
||||||
AStarSearch<MapSearchNode> astarsearch;
|
|
||||||
|
|
||||||
unsigned int SearchCount = 0;
|
|
||||||
|
|
||||||
const unsigned int NumSearches = 1;
|
|
||||||
|
|
||||||
while(SearchCount < NumSearches)
|
|
||||||
{
|
|
||||||
|
|
||||||
// Create a start state
|
|
||||||
MapSearchNode nodeStart;
|
|
||||||
nodeStart.x = rand()%MAP_WIDTH;
|
|
||||||
nodeStart.y = rand()%MAP_HEIGHT;
|
|
||||||
|
|
||||||
// Define the goal state
|
|
||||||
MapSearchNode nodeEnd;
|
|
||||||
nodeEnd.x = rand()%MAP_WIDTH;
|
|
||||||
nodeEnd.y = rand()%MAP_HEIGHT;
|
|
||||||
|
|
||||||
// Set Start and goal states
|
|
||||||
|
|
||||||
astarsearch.SetStartAndGoalStates( nodeStart, nodeEnd );
|
|
||||||
|
|
||||||
unsigned int SearchState;
|
|
||||||
unsigned int SearchSteps = 0;
|
|
||||||
|
|
||||||
do
|
|
||||||
{
|
|
||||||
SearchState = astarsearch.SearchStep();
|
|
||||||
|
|
||||||
SearchSteps++;
|
|
||||||
|
|
||||||
#if DEBUG_LISTS
|
|
||||||
|
|
||||||
cout << "Steps:" << SearchSteps << "\n";
|
|
||||||
|
|
||||||
int len = 0;
|
|
||||||
|
|
||||||
cout << "Open:\n";
|
|
||||||
MapSearchNode *p = astarsearch.GetOpenListStart();
|
|
||||||
while( p )
|
|
||||||
{
|
|
||||||
len++;
|
|
||||||
#if !DEBUG_LIST_LENGTHS_ONLY
|
|
||||||
((MapSearchNode *)p)->PrintNodeInfo();
|
|
||||||
#endif
|
|
||||||
p = astarsearch.GetOpenListNext();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
cout << "Open list has " << len << " nodes\n";
|
|
||||||
|
|
||||||
len = 0;
|
|
||||||
|
|
||||||
cout << "Closed:\n";
|
|
||||||
p = astarsearch.GetClosedListStart();
|
|
||||||
while( p )
|
|
||||||
{
|
|
||||||
len++;
|
|
||||||
#if !DEBUG_LIST_LENGTHS_ONLY
|
|
||||||
p->PrintNodeInfo();
|
|
||||||
#endif
|
|
||||||
p = astarsearch.GetClosedListNext();
|
|
||||||
}
|
|
||||||
|
|
||||||
cout << "Closed list has " << len << " nodes\n";
|
|
||||||
#endif
|
|
||||||
|
|
||||||
}
|
|
||||||
while( SearchState == AStarSearch<MapSearchNode>::SEARCH_STATE_SEARCHING );
|
|
||||||
|
|
||||||
if( SearchState == AStarSearch<MapSearchNode>::SEARCH_STATE_SUCCEEDED )
|
|
||||||
{
|
|
||||||
cout << "Search found goal state\n";
|
|
||||||
|
|
||||||
MapSearchNode *node = astarsearch.GetSolutionStart();
|
|
||||||
|
|
||||||
#if DISPLAY_SOLUTION
|
|
||||||
cout << "Displaying solution\n";
|
|
||||||
#endif
|
|
||||||
int steps = 0;
|
|
||||||
|
|
||||||
node->PrintNodeInfo();
|
|
||||||
for( ;; )
|
|
||||||
{
|
|
||||||
node = astarsearch.GetSolutionNext();
|
|
||||||
|
|
||||||
if( !node )
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
node->PrintNodeInfo();
|
|
||||||
steps ++;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
cout << "Solution steps " << steps << endl;
|
|
||||||
|
|
||||||
// Once you're done with the solution you can free the nodes up
|
|
||||||
astarsearch.FreeSolutionNodes();
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
else if( SearchState == AStarSearch<MapSearchNode>::SEARCH_STATE_FAILED )
|
|
||||||
{
|
|
||||||
cout << "Search terminated. Did not find goal state\n";
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Display the number of loops the search went through
|
|
||||||
cout << "SearchSteps : " << SearchSteps << "\n";
|
|
||||||
|
|
||||||
SearchCount ++;
|
|
||||||
|
|
||||||
astarsearch.EnsureMemoryFreed();
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
@ -1,252 +0,0 @@
|
|||||||
/*
|
|
||||||
|
|
||||||
A* Algorithm Implementation using STL is
|
|
||||||
Copyright (C)2001-2005 Justin Heyes-Jones
|
|
||||||
|
|
||||||
Permission is given by the author to freely redistribute and
|
|
||||||
include this code in any program as long as this credit is
|
|
||||||
given where due.
|
|
||||||
|
|
||||||
COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED,
|
|
||||||
INCLUDING, WITHOUT LIMITATION, WARRANTIES THAT THE COVERED CODE
|
|
||||||
IS FREE OF DEFECTS, MERCHANTABLE, FIT FOR A PARTICULAR PURPOSE
|
|
||||||
OR NON-INFRINGING. THE ENTIRE RISK AS TO THE QUALITY AND
|
|
||||||
PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED
|
|
||||||
CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL
|
|
||||||
DEVELOPER OR ANY OTHER CONTRIBUTOR) ASSUME THE COST OF ANY
|
|
||||||
NECESSARY SERVICING, REPAIR OR CORRECTION. THIS DISCLAIMER OF
|
|
||||||
WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS LICENSE. NO USE
|
|
||||||
OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER
|
|
||||||
THIS DISCLAIMER.
|
|
||||||
|
|
||||||
Use at your own risk!
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FixedSizeAllocator class
|
|
||||||
Copyright 2001 Justin Heyes-Jones
|
|
||||||
|
|
||||||
This class is a constant time O(1) memory manager for objects of
|
|
||||||
a specified type. The type is specified using a template class.
|
|
||||||
|
|
||||||
Memory is allocated from a fixed size buffer which you can specify in the
|
|
||||||
class constructor or use the default.
|
|
||||||
|
|
||||||
Using GetFirst and GetNext it is possible to iterate through the elements
|
|
||||||
one by one, and this would be the most common use for the class.
|
|
||||||
|
|
||||||
I would suggest using this class when you want O(1) add and delete
|
|
||||||
and you don't do much searching, which would be O(n). Structures such as binary
|
|
||||||
trees can be used instead to get O(logn) access time.
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef FSA_H
|
|
||||||
#define FSA_H
|
|
||||||
|
|
||||||
#include <string.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
template <class USER_TYPE> class FixedSizeAllocator
|
|
||||||
{
|
|
||||||
|
|
||||||
public:
|
|
||||||
// Constants
|
|
||||||
enum
|
|
||||||
{
|
|
||||||
FSA_DEFAULT_SIZE = 100
|
|
||||||
};
|
|
||||||
|
|
||||||
// This class enables us to transparently manage the extra data
|
|
||||||
// needed to enable the user class to form part of the double-linked
|
|
||||||
// list class
|
|
||||||
struct FSA_ELEMENT
|
|
||||||
{
|
|
||||||
USER_TYPE UserType;
|
|
||||||
|
|
||||||
FSA_ELEMENT *pPrev;
|
|
||||||
FSA_ELEMENT *pNext;
|
|
||||||
};
|
|
||||||
|
|
||||||
public: // methods
|
|
||||||
FixedSizeAllocator( unsigned int MaxElements = FSA_DEFAULT_SIZE ) :
|
|
||||||
m_pFirstUsed( NULL ),
|
|
||||||
m_MaxElements( MaxElements )
|
|
||||||
{
|
|
||||||
// Allocate enough memory for the maximum number of elements
|
|
||||||
|
|
||||||
char *pMem = new char[ m_MaxElements * sizeof(FSA_ELEMENT) ];
|
|
||||||
|
|
||||||
m_pMemory = (FSA_ELEMENT *) pMem;
|
|
||||||
|
|
||||||
// Set the free list first pointer
|
|
||||||
m_pFirstFree = m_pMemory;
|
|
||||||
|
|
||||||
// Clear the memory
|
|
||||||
memset( m_pMemory, 0, sizeof( FSA_ELEMENT ) * m_MaxElements );
|
|
||||||
|
|
||||||
// Point at first element
|
|
||||||
FSA_ELEMENT *pElement = m_pFirstFree;
|
|
||||||
|
|
||||||
// Set the double linked free list
|
|
||||||
for( unsigned int i=0; i<m_MaxElements; i++ )
|
|
||||||
{
|
|
||||||
pElement->pPrev = pElement-1;
|
|
||||||
pElement->pNext = pElement+1;
|
|
||||||
|
|
||||||
pElement++;
|
|
||||||
}
|
|
||||||
|
|
||||||
// first element should have a null prev
|
|
||||||
m_pFirstFree->pPrev = NULL;
|
|
||||||
// last element should have a null next
|
|
||||||
(pElement-1)->pNext = NULL;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
~FixedSizeAllocator()
|
|
||||||
{
|
|
||||||
// Free up the memory
|
|
||||||
delete [] (char *) m_pMemory;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocate a new USER_TYPE and return a pointer to it
|
|
||||||
USER_TYPE *alloc()
|
|
||||||
{
|
|
||||||
|
|
||||||
FSA_ELEMENT *pNewNode = NULL;
|
|
||||||
|
|
||||||
if( !m_pFirstFree )
|
|
||||||
{
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
pNewNode = m_pFirstFree;
|
|
||||||
m_pFirstFree = pNewNode->pNext;
|
|
||||||
|
|
||||||
// if the new node points to another free node then
|
|
||||||
// change that nodes prev free pointer...
|
|
||||||
if( pNewNode->pNext )
|
|
||||||
{
|
|
||||||
pNewNode->pNext->pPrev = NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
// node is now on the used list
|
|
||||||
|
|
||||||
pNewNode->pPrev = NULL; // the allocated node is always first in the list
|
|
||||||
|
|
||||||
if( m_pFirstUsed == NULL )
|
|
||||||
{
|
|
||||||
pNewNode->pNext = NULL; // no other nodes
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
m_pFirstUsed->pPrev = pNewNode; // insert this at the head of the used list
|
|
||||||
pNewNode->pNext = m_pFirstUsed;
|
|
||||||
}
|
|
||||||
|
|
||||||
m_pFirstUsed = pNewNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
return reinterpret_cast<USER_TYPE*>(pNewNode);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free the given user type
|
|
||||||
// For efficiency I don't check whether the user_data is a valid
|
|
||||||
// pointer that was allocated. I may add some debug only checking
|
|
||||||
// (To add the debug check you'd need to make sure the pointer is in
|
|
||||||
// the m_pMemory area and is pointing at the start of a node)
|
|
||||||
void free( USER_TYPE *user_data )
|
|
||||||
{
|
|
||||||
FSA_ELEMENT *pNode = reinterpret_cast<FSA_ELEMENT*>(user_data);
|
|
||||||
|
|
||||||
// manage used list, remove this node from it
|
|
||||||
if( pNode->pPrev )
|
|
||||||
{
|
|
||||||
pNode->pPrev->pNext = pNode->pNext;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// this handles the case that we delete the first node in the used list
|
|
||||||
m_pFirstUsed = pNode->pNext;
|
|
||||||
}
|
|
||||||
|
|
||||||
if( pNode->pNext )
|
|
||||||
{
|
|
||||||
pNode->pNext->pPrev = pNode->pPrev;
|
|
||||||
}
|
|
||||||
|
|
||||||
// add to free list
|
|
||||||
if( m_pFirstFree == NULL )
|
|
||||||
{
|
|
||||||
// free list was empty
|
|
||||||
m_pFirstFree = pNode;
|
|
||||||
pNode->pPrev = NULL;
|
|
||||||
pNode->pNext = NULL;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// Add this node at the start of the free list
|
|
||||||
m_pFirstFree->pPrev = pNode;
|
|
||||||
pNode->pNext = m_pFirstFree;
|
|
||||||
m_pFirstFree = pNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// For debugging this displays both lists (using the prev/next list pointers)
|
|
||||||
void Debug()
|
|
||||||
{
|
|
||||||
printf( "free list " );
|
|
||||||
|
|
||||||
FSA_ELEMENT *p = m_pFirstFree;
|
|
||||||
while( p )
|
|
||||||
{
|
|
||||||
printf( "%x!%x ", p->pPrev, p->pNext );
|
|
||||||
p = p->pNext;
|
|
||||||
}
|
|
||||||
printf( "\n" );
|
|
||||||
|
|
||||||
printf( "used list " );
|
|
||||||
|
|
||||||
p = m_pFirstUsed;
|
|
||||||
while( p )
|
|
||||||
{
|
|
||||||
printf( "%x!%x ", p->pPrev, p->pNext );
|
|
||||||
p = p->pNext;
|
|
||||||
}
|
|
||||||
printf( "\n" );
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterators
|
|
||||||
|
|
||||||
USER_TYPE *GetFirst()
|
|
||||||
{
|
|
||||||
return reinterpret_cast<USER_TYPE *>(m_pFirstUsed);
|
|
||||||
}
|
|
||||||
|
|
||||||
USER_TYPE *GetNext( USER_TYPE *node )
|
|
||||||
{
|
|
||||||
return reinterpret_cast<USER_TYPE *>
|
|
||||||
(
|
|
||||||
(reinterpret_cast<FSA_ELEMENT *>(node))->pNext
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
public: // data
|
|
||||||
|
|
||||||
private: // methods
|
|
||||||
|
|
||||||
private: // data
|
|
||||||
|
|
||||||
FSA_ELEMENT *m_pFirstFree;
|
|
||||||
FSA_ELEMENT *m_pFirstUsed;
|
|
||||||
unsigned int m_MaxElements;
|
|
||||||
FSA_ELEMENT *m_pMemory;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // defined FSA_H
|
|
@ -1,315 +0,0 @@
|
|||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
// This example code illustrate how to use STL A* Search implementation to find the minimum path between two
|
|
||||||
// cities given a map. The example is taken from the book AI: A Modern Approach, 3rd Ed., by Russel, where a map
|
|
||||||
// of Romania is given. The target node is Bucharest, and the user can specify the initial city from which the
|
|
||||||
// search algorithm will start looking for the minimum path to Bucharest.
|
|
||||||
//
|
|
||||||
// Usage: min_path_to_Bucharest <start city name>
|
|
||||||
// Example:
|
|
||||||
// min_path_to_Bucharest Arad
|
|
||||||
//
|
|
||||||
// Thanks to Rasoul Mojtahedzadeh for this contribution
|
|
||||||
// Mojtahedzadeh _atsign_ gmail com
|
|
||||||
//
|
|
||||||
// Please note that this exercise is academic in nature and that the distances between the cities may not be
|
|
||||||
// correct compared to the latitude and longnitude differences in real life. Thanks to parthi2929 for noticing
|
|
||||||
// this issue. In a real application you would use some kind of exact x,y position of each city in order to get
|
|
||||||
// accurate heuristics. In fact, that is part of the point of this example, in the book you will see Norvig
|
|
||||||
// mention that the algorithm does some backtracking because the heuristic is not accurate (yet still admissable).
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
#include "stlastar.h"
|
|
||||||
#include <iostream>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
#define DEBUG_LISTS 0
|
|
||||||
#define DEBUG_LIST_LENGTHS_ONLY 0
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
const int MAX_CITIES = 20;
|
|
||||||
|
|
||||||
enum ENUM_CITIES{Arad=0, Bucharest, Craiova, Drobeta, Eforie, Fagaras, Giurgiu, Hirsova, Iasi, Lugoj, Mehadia, Neamt, Oradea, Pitesti, RimnicuVilcea, Sibiu, Timisoara, Urziceni, Vaslui, Zerind};
|
|
||||||
vector<string> CityNames(MAX_CITIES);
|
|
||||||
float RomaniaMap[MAX_CITIES][MAX_CITIES];
|
|
||||||
|
|
||||||
class PathSearchNode
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
|
|
||||||
ENUM_CITIES city;
|
|
||||||
|
|
||||||
PathSearchNode() { city = Arad; }
|
|
||||||
PathSearchNode( ENUM_CITIES in ) { city = in; }
|
|
||||||
|
|
||||||
float GoalDistanceEstimate( PathSearchNode &nodeGoal );
|
|
||||||
bool IsGoal( PathSearchNode &nodeGoal );
|
|
||||||
bool GetSuccessors( AStarSearch<PathSearchNode> *astarsearch, PathSearchNode *parent_node );
|
|
||||||
float GetCost( PathSearchNode &successor );
|
|
||||||
bool IsSameState( PathSearchNode &rhs );
|
|
||||||
|
|
||||||
void PrintNodeInfo();
|
|
||||||
};
|
|
||||||
|
|
||||||
// check if "this" node is the same as "rhs" node
|
|
||||||
bool PathSearchNode::IsSameState( PathSearchNode &rhs )
|
|
||||||
{
|
|
||||||
if(city == rhs.city) return(true);
|
|
||||||
return(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Euclidean distance between "this" node city and Bucharest
|
|
||||||
// Note: Numbers are taken from the book
|
|
||||||
float PathSearchNode::GoalDistanceEstimate( PathSearchNode &nodeGoal )
|
|
||||||
{
|
|
||||||
// goal is always Bucharest!
|
|
||||||
switch(city)
|
|
||||||
{
|
|
||||||
case Arad: return 366; break;
|
|
||||||
case Bucharest: return 0; break;
|
|
||||||
case Craiova: return 160; break;
|
|
||||||
case Drobeta: return 242; break;
|
|
||||||
case Eforie: return 161; break;
|
|
||||||
case Fagaras: return 176; break;
|
|
||||||
case Giurgiu: return 77; break;
|
|
||||||
case Hirsova: return 151; break;
|
|
||||||
case Iasi: return 226; break;
|
|
||||||
case Lugoj: return 244; break;
|
|
||||||
case Mehadia: return 241; break;
|
|
||||||
case Neamt: return 234; break;
|
|
||||||
case Oradea: return 380; break;
|
|
||||||
case Pitesti: return 100; break;
|
|
||||||
case RimnicuVilcea: return 193; break;
|
|
||||||
case Sibiu: return 253; break;
|
|
||||||
case Timisoara: return 329; break;
|
|
||||||
case Urziceni: return 80; break;
|
|
||||||
case Vaslui: return 199; break;
|
|
||||||
case Zerind: return 374; break;
|
|
||||||
}
|
|
||||||
cerr << "ASSERT: city = " << CityNames[city] << endl;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if "this" node is the goal node
|
|
||||||
bool PathSearchNode::IsGoal( PathSearchNode &nodeGoal )
|
|
||||||
{
|
|
||||||
if( city == Bucharest ) return(true);
|
|
||||||
return(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
// generates the successor nodes of "this" node
|
|
||||||
bool PathSearchNode::GetSuccessors( AStarSearch<PathSearchNode> *astarsearch, PathSearchNode *parent_node )
|
|
||||||
{
|
|
||||||
PathSearchNode NewNode;
|
|
||||||
for(int c=0; c<MAX_CITIES; c++)
|
|
||||||
{
|
|
||||||
if(RomaniaMap[city][c] < 0) continue;
|
|
||||||
NewNode = PathSearchNode((ENUM_CITIES)c);
|
|
||||||
astarsearch->AddSuccessor( NewNode );
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// the cost of going from "this" node to the "successor" node
|
|
||||||
float PathSearchNode::GetCost( PathSearchNode &successor )
|
|
||||||
{
|
|
||||||
return RomaniaMap[city][successor.city];
|
|
||||||
}
|
|
||||||
|
|
||||||
// prints out information about the node
|
|
||||||
void PathSearchNode::PrintNodeInfo()
|
|
||||||
{
|
|
||||||
cout << " " << CityNames[city] << "\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Main
|
|
||||||
int min_path_to_Bucharest( int argc, char *argv[] )
|
|
||||||
{
|
|
||||||
// creating map of Romania
|
|
||||||
for(int i=0; i<MAX_CITIES; i++)
|
|
||||||
for(int j=0; j<MAX_CITIES; j++)
|
|
||||||
RomaniaMap[i][j]=-1.0;
|
|
||||||
|
|
||||||
RomaniaMap[Arad][Sibiu]=140;
|
|
||||||
RomaniaMap[Arad][Zerind]=75;
|
|
||||||
RomaniaMap[Arad][Timisoara]=118;
|
|
||||||
RomaniaMap[Bucharest][Giurgiu]=90;
|
|
||||||
RomaniaMap[Bucharest][Urziceni]=85;
|
|
||||||
RomaniaMap[Bucharest][Fagaras]=211;
|
|
||||||
RomaniaMap[Bucharest][Pitesti]=101;
|
|
||||||
RomaniaMap[Craiova][Drobeta]=120;
|
|
||||||
RomaniaMap[Craiova][RimnicuVilcea]=146;
|
|
||||||
RomaniaMap[Craiova][Pitesti]=138;
|
|
||||||
RomaniaMap[Drobeta][Craiova]=120;
|
|
||||||
RomaniaMap[Drobeta][Mehadia]=75;
|
|
||||||
RomaniaMap[Eforie][Hirsova]=75;
|
|
||||||
RomaniaMap[Fagaras][Bucharest]=211;
|
|
||||||
RomaniaMap[Fagaras][Sibiu]=99;
|
|
||||||
RomaniaMap[Giurgiu][Bucharest]=90;
|
|
||||||
RomaniaMap[Hirsova][Eforie]=86;
|
|
||||||
RomaniaMap[Hirsova][Urziceni]=98;
|
|
||||||
RomaniaMap[Iasi][Vaslui]=92;
|
|
||||||
RomaniaMap[Iasi][Neamt]=87;
|
|
||||||
RomaniaMap[Lugoj][Timisoara]=111;
|
|
||||||
RomaniaMap[Lugoj][Mehadia]=70;
|
|
||||||
RomaniaMap[Mehadia][Lugoj]=70;
|
|
||||||
RomaniaMap[Mehadia][Drobeta]=75;
|
|
||||||
RomaniaMap[Neamt][Iasi]=87;
|
|
||||||
RomaniaMap[Oradea][Zerind]=71;
|
|
||||||
RomaniaMap[Oradea][Sibiu]=151;
|
|
||||||
RomaniaMap[Pitesti][Bucharest]=101;
|
|
||||||
RomaniaMap[Pitesti][RimnicuVilcea]=97;
|
|
||||||
RomaniaMap[Pitesti][Craiova]=138;
|
|
||||||
RomaniaMap[RimnicuVilcea][Pitesti]=97;
|
|
||||||
RomaniaMap[RimnicuVilcea][Craiova]=146;
|
|
||||||
RomaniaMap[RimnicuVilcea][Sibiu]=80;
|
|
||||||
RomaniaMap[Sibiu][RimnicuVilcea]=80;
|
|
||||||
RomaniaMap[Sibiu][Fagaras]=99;
|
|
||||||
RomaniaMap[Sibiu][Oradea]=151;
|
|
||||||
RomaniaMap[Sibiu][Arad]=140;
|
|
||||||
RomaniaMap[Timisoara][Arad]=118;
|
|
||||||
RomaniaMap[Timisoara][Lugoj]=111;
|
|
||||||
RomaniaMap[Urziceni][Bucharest]=85;
|
|
||||||
RomaniaMap[Urziceni][Hirsova]=98;
|
|
||||||
RomaniaMap[Urziceni][Vaslui]=142;
|
|
||||||
RomaniaMap[Vaslui][Urziceni]=142;
|
|
||||||
RomaniaMap[Vaslui][Iasi]=92;
|
|
||||||
RomaniaMap[Zerind][Arad]=75;
|
|
||||||
RomaniaMap[Zerind][Oradea]=71;
|
|
||||||
|
|
||||||
// City names
|
|
||||||
CityNames[Arad].assign("Arad");
|
|
||||||
CityNames[Bucharest].assign("Bucharest");
|
|
||||||
CityNames[Craiova].assign("Craiova");
|
|
||||||
CityNames[Drobeta].assign("Drobeta");
|
|
||||||
CityNames[Eforie].assign("Eforie");
|
|
||||||
CityNames[Fagaras].assign("Fagaras");
|
|
||||||
CityNames[Giurgiu].assign("Giurgiu");
|
|
||||||
CityNames[Hirsova].assign("Hirsova");
|
|
||||||
CityNames[Iasi].assign("Iasi");
|
|
||||||
CityNames[Lugoj].assign("Lugoj");
|
|
||||||
CityNames[Mehadia].assign("Mehadia");
|
|
||||||
CityNames[Neamt].assign("Neamt");
|
|
||||||
CityNames[Oradea].assign("Oradea");
|
|
||||||
CityNames[Pitesti].assign("Pitesti");
|
|
||||||
CityNames[RimnicuVilcea].assign("RimnicuVilcea");
|
|
||||||
CityNames[Sibiu].assign("Sibiu");
|
|
||||||
CityNames[Timisoara].assign("Timisoara");
|
|
||||||
CityNames[Urziceni].assign("Urziceni");
|
|
||||||
CityNames[Vaslui].assign("Vaslui");
|
|
||||||
CityNames[Zerind].assign("Zerind");
|
|
||||||
|
|
||||||
ENUM_CITIES initCity = Arad;
|
|
||||||
if(argc == 2)
|
|
||||||
{
|
|
||||||
bool found = false;
|
|
||||||
for(size_t i=0; i<CityNames.size(); i++)
|
|
||||||
if(CityNames[i].compare(argv[1])==0)
|
|
||||||
{
|
|
||||||
initCity = (ENUM_CITIES)i;
|
|
||||||
found = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if(not found)
|
|
||||||
{
|
|
||||||
cout << "There is no city named "<<argv[1]<<" in the map!\n";
|
|
||||||
return(1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// An instance of A* search class
|
|
||||||
AStarSearch<PathSearchNode> astarsearch;
|
|
||||||
|
|
||||||
unsigned int SearchCount = 0;
|
|
||||||
const unsigned int NumSearches = 1;
|
|
||||||
|
|
||||||
while(SearchCount < NumSearches)
|
|
||||||
{
|
|
||||||
// Create a start state
|
|
||||||
PathSearchNode nodeStart;
|
|
||||||
nodeStart.city = initCity;
|
|
||||||
|
|
||||||
// Define the goal state, always Bucharest!
|
|
||||||
PathSearchNode nodeEnd;
|
|
||||||
nodeEnd.city = Bucharest;
|
|
||||||
|
|
||||||
// Set Start and goal states
|
|
||||||
astarsearch.SetStartAndGoalStates( nodeStart, nodeEnd );
|
|
||||||
|
|
||||||
unsigned int SearchState;
|
|
||||||
unsigned int SearchSteps = 0;
|
|
||||||
|
|
||||||
do
|
|
||||||
{
|
|
||||||
SearchState = astarsearch.SearchStep();
|
|
||||||
SearchSteps++;
|
|
||||||
|
|
||||||
#if DEBUG_LISTS
|
|
||||||
|
|
||||||
cout << "Steps:" << SearchSteps << "\n";
|
|
||||||
|
|
||||||
int len = 0;
|
|
||||||
|
|
||||||
cout << "Open:\n";
|
|
||||||
PathSearchNode *p = astarsearch.GetOpenListStart();
|
|
||||||
while( p )
|
|
||||||
{
|
|
||||||
len++;
|
|
||||||
#if !DEBUG_LIST_LENGTHS_ONLY
|
|
||||||
((PathSearchNode *)p)->PrintNodeInfo();
|
|
||||||
#endif
|
|
||||||
p = astarsearch.GetOpenListNext();
|
|
||||||
|
|
||||||
}
|
|
||||||
cout << "Open list has " << len << " nodes\n";
|
|
||||||
|
|
||||||
len = 0;
|
|
||||||
|
|
||||||
cout << "Closed:\n";
|
|
||||||
p = astarsearch.GetClosedListStart();
|
|
||||||
while( p )
|
|
||||||
{
|
|
||||||
len++;
|
|
||||||
#if !DEBUG_LIST_LENGTHS_ONLY
|
|
||||||
p->PrintNodeInfo();
|
|
||||||
#endif
|
|
||||||
p = astarsearch.GetClosedListNext();
|
|
||||||
}
|
|
||||||
|
|
||||||
cout << "Closed list has " << len << " nodes\n";
|
|
||||||
#endif
|
|
||||||
|
|
||||||
}
|
|
||||||
while( SearchState == AStarSearch<PathSearchNode>::SEARCH_STATE_SEARCHING );
|
|
||||||
|
|
||||||
if( SearchState == AStarSearch<PathSearchNode>::SEARCH_STATE_SUCCEEDED )
|
|
||||||
{
|
|
||||||
cout << "Search found the goal state\n";
|
|
||||||
PathSearchNode *node = astarsearch.GetSolutionStart();
|
|
||||||
cout << "Displaying solution\n";
|
|
||||||
int steps = 0;
|
|
||||||
node->PrintNodeInfo();
|
|
||||||
for( ;; )
|
|
||||||
{
|
|
||||||
node = astarsearch.GetSolutionNext();
|
|
||||||
if( !node ) break;
|
|
||||||
node->PrintNodeInfo();
|
|
||||||
steps ++;
|
|
||||||
};
|
|
||||||
cout << "Solution steps " << steps << endl;
|
|
||||||
// Once you're done with the solution you can free the nodes up
|
|
||||||
astarsearch.FreeSolutionNodes();
|
|
||||||
}
|
|
||||||
else if( SearchState == AStarSearch<PathSearchNode>::SEARCH_STATE_FAILED )
|
|
||||||
{
|
|
||||||
cout << "Search terminated. Did not find goal state\n";
|
|
||||||
}
|
|
||||||
// Display the number of loops the search went through
|
|
||||||
cout << "SearchSteps : " << SearchSteps << "\n";
|
|
||||||
SearchCount ++;
|
|
||||||
astarsearch.EnsureMemoryFreed();
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
@ -1,833 +0,0 @@
|
|||||||
/*
|
|
||||||
A* Algorithm Implementation using STL is
|
|
||||||
Copyright (C)2001-2005 Justin Heyes-Jones
|
|
||||||
|
|
||||||
Permission is given by the author to freely redistribute and
|
|
||||||
include this code in any program as long as this credit is
|
|
||||||
given where due.
|
|
||||||
|
|
||||||
COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED,
|
|
||||||
INCLUDING, WITHOUT LIMITATION, WARRANTIES THAT THE COVERED CODE
|
|
||||||
IS FREE OF DEFECTS, MERCHANTABLE, FIT FOR A PARTICULAR PURPOSE
|
|
||||||
OR NON-INFRINGING. THE ENTIRE RISK AS TO THE QUALITY AND
|
|
||||||
PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED
|
|
||||||
CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL
|
|
||||||
DEVELOPER OR ANY OTHER CONTRIBUTOR) ASSUME THE COST OF ANY
|
|
||||||
NECESSARY SERVICING, REPAIR OR CORRECTION. THIS DISCLAIMER OF
|
|
||||||
WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS LICENSE. NO USE
|
|
||||||
OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER
|
|
||||||
THIS DISCLAIMER.
|
|
||||||
|
|
||||||
Use at your own risk!
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef STLASTAR_H
|
|
||||||
#define STLASTAR_H
|
|
||||||
// used for text debugging
|
|
||||||
#include <iostream>
|
|
||||||
#include <stdio.h>
|
|
||||||
//#include <conio.h>
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
// stl includes
|
|
||||||
#include <algorithm>
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include <cfloat>
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
// fast fixed size memory allocator, used for fast node memory management
|
|
||||||
#include "fsa.h"
|
|
||||||
|
|
||||||
// Fixed size memory allocator can be disabled to compare performance
|
|
||||||
// Uses std new and delete instead if you turn it off
|
|
||||||
#define USE_FSA_MEMORY 1
|
|
||||||
|
|
||||||
// disable warning that debugging information has lines that are truncated
|
|
||||||
// occurs in stl headers
|
|
||||||
#if defined(WIN32) && defined(_WINDOWS)
|
|
||||||
#pragma warning( disable : 4786 )
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <class T> class AStarState;
|
|
||||||
|
|
||||||
// The AStar search class. UserState is the users state space type
|
|
||||||
template <class UserState> class AStarSearch
|
|
||||||
{
|
|
||||||
|
|
||||||
public: // data
|
|
||||||
|
|
||||||
enum
|
|
||||||
{
|
|
||||||
SEARCH_STATE_NOT_INITIALISED,
|
|
||||||
SEARCH_STATE_SEARCHING,
|
|
||||||
SEARCH_STATE_SUCCEEDED,
|
|
||||||
SEARCH_STATE_FAILED,
|
|
||||||
SEARCH_STATE_OUT_OF_MEMORY,
|
|
||||||
SEARCH_STATE_INVALID
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
// A node represents a possible state in the search
|
|
||||||
// The user provided state type is included inside this type
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
class Node
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
|
|
||||||
Node *parent; // used during the search to record the parent of successor nodes
|
|
||||||
Node *child; // used after the search for the application to view the search in reverse
|
|
||||||
|
|
||||||
float g; // cost of this node + it's predecessors
|
|
||||||
float h; // heuristic estimate of distance to goal
|
|
||||||
float f; // sum of cumulative cost of predecessors and self and heuristic
|
|
||||||
|
|
||||||
Node() :
|
|
||||||
parent( 0 ),
|
|
||||||
child( 0 ),
|
|
||||||
g( 0.0f ),
|
|
||||||
h( 0.0f ),
|
|
||||||
f( 0.0f )
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
UserState m_UserState;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
// For sorting the heap the STL needs compare function that lets us compare
|
|
||||||
// the f value of two nodes
|
|
||||||
|
|
||||||
class HeapCompare_f
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
|
|
||||||
bool operator() ( const Node *x, const Node *y ) const
|
|
||||||
{
|
|
||||||
return x->f > y->f;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
public: // methods
|
|
||||||
|
|
||||||
|
|
||||||
// constructor just initialises private data
|
|
||||||
AStarSearch() :
|
|
||||||
m_State( SEARCH_STATE_NOT_INITIALISED ),
|
|
||||||
m_CurrentSolutionNode( NULL ),
|
|
||||||
#if USE_FSA_MEMORY
|
|
||||||
m_FixedSizeAllocator( 1000 ),
|
|
||||||
#endif
|
|
||||||
m_AllocateNodeCount(0),
|
|
||||||
m_CancelRequest( false )
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
AStarSearch( int MaxNodes ) :
|
|
||||||
m_State( SEARCH_STATE_NOT_INITIALISED ),
|
|
||||||
m_CurrentSolutionNode( NULL ),
|
|
||||||
#if USE_FSA_MEMORY
|
|
||||||
m_FixedSizeAllocator( MaxNodes ),
|
|
||||||
#endif
|
|
||||||
m_AllocateNodeCount(0),
|
|
||||||
m_CancelRequest( false )
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
// call at any time to cancel the search and free up all the memory
|
|
||||||
void CancelSearch()
|
|
||||||
{
|
|
||||||
m_CancelRequest = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set Start and goal states
|
|
||||||
void SetStartAndGoalStates( UserState &Start, UserState &Goal )
|
|
||||||
{
|
|
||||||
m_CancelRequest = false;
|
|
||||||
|
|
||||||
m_Start = AllocateNode();
|
|
||||||
m_Goal = AllocateNode();
|
|
||||||
|
|
||||||
assert((m_Start != NULL && m_Goal != NULL));
|
|
||||||
|
|
||||||
m_Start->m_UserState = Start;
|
|
||||||
m_Goal->m_UserState = Goal;
|
|
||||||
|
|
||||||
m_State = SEARCH_STATE_SEARCHING;
|
|
||||||
|
|
||||||
// Initialise the AStar specific parts of the Start Node
|
|
||||||
// The user only needs fill out the state information
|
|
||||||
|
|
||||||
m_Start->g = 0;
|
|
||||||
m_Start->h = m_Start->m_UserState.GoalDistanceEstimate( m_Goal->m_UserState );
|
|
||||||
m_Start->f = m_Start->g + m_Start->h;
|
|
||||||
m_Start->parent = 0;
|
|
||||||
|
|
||||||
// Push the start node on the Open list
|
|
||||||
|
|
||||||
m_OpenList.push_back( m_Start ); // heap now unsorted
|
|
||||||
|
|
||||||
// Sort back element into heap
|
|
||||||
push_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
|
|
||||||
|
|
||||||
// Initialise counter for search steps
|
|
||||||
m_Steps = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Advances search one step
|
|
||||||
unsigned int SearchStep()
|
|
||||||
{
|
|
||||||
// Firstly break if the user has not initialised the search
|
|
||||||
assert( (m_State > SEARCH_STATE_NOT_INITIALISED) &&
|
|
||||||
(m_State < SEARCH_STATE_INVALID) );
|
|
||||||
|
|
||||||
// Next I want it to be safe to do a searchstep once the search has succeeded...
|
|
||||||
if( (m_State == SEARCH_STATE_SUCCEEDED) ||
|
|
||||||
(m_State == SEARCH_STATE_FAILED)
|
|
||||||
)
|
|
||||||
{
|
|
||||||
return m_State;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Failure is defined as emptying the open list as there is nothing left to
|
|
||||||
// search...
|
|
||||||
// New: Allow user abort
|
|
||||||
if( m_OpenList.empty() || m_CancelRequest )
|
|
||||||
{
|
|
||||||
FreeAllNodes();
|
|
||||||
m_State = SEARCH_STATE_FAILED;
|
|
||||||
return m_State;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Incremement step count
|
|
||||||
m_Steps ++;
|
|
||||||
|
|
||||||
// Pop the best node (the one with the lowest f)
|
|
||||||
Node *n = m_OpenList.front(); // get pointer to the node
|
|
||||||
pop_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
|
|
||||||
m_OpenList.pop_back();
|
|
||||||
|
|
||||||
// Check for the goal, once we pop that we're done
|
|
||||||
if( n->m_UserState.IsGoal( m_Goal->m_UserState ) )
|
|
||||||
{
|
|
||||||
// The user is going to use the Goal Node he passed in
|
|
||||||
// so copy the parent pointer of n
|
|
||||||
m_Goal->parent = n->parent;
|
|
||||||
m_Goal->g = n->g;
|
|
||||||
|
|
||||||
// A special case is that the goal was passed in as the start state
|
|
||||||
// so handle that here
|
|
||||||
if( false == n->m_UserState.IsSameState( m_Start->m_UserState ) )
|
|
||||||
{
|
|
||||||
FreeNode( n );
|
|
||||||
|
|
||||||
// set the child pointers in each node (except Goal which has no child)
|
|
||||||
Node *nodeChild = m_Goal;
|
|
||||||
Node *nodeParent = m_Goal->parent;
|
|
||||||
|
|
||||||
do
|
|
||||||
{
|
|
||||||
nodeParent->child = nodeChild;
|
|
||||||
|
|
||||||
nodeChild = nodeParent;
|
|
||||||
nodeParent = nodeParent->parent;
|
|
||||||
|
|
||||||
}
|
|
||||||
while( nodeChild != m_Start ); // Start is always the first node by definition
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// delete nodes that aren't needed for the solution
|
|
||||||
FreeUnusedNodes();
|
|
||||||
|
|
||||||
m_State = SEARCH_STATE_SUCCEEDED;
|
|
||||||
|
|
||||||
return m_State;
|
|
||||||
}
|
|
||||||
else // not goal
|
|
||||||
{
|
|
||||||
|
|
||||||
// We now need to generate the successors of this node
|
|
||||||
// The user helps us to do this, and we keep the new nodes in
|
|
||||||
// m_Successors ...
|
|
||||||
|
|
||||||
m_Successors.clear(); // empty vector of successor nodes to n
|
|
||||||
|
|
||||||
// User provides this functions and uses AddSuccessor to add each successor of
|
|
||||||
// node 'n' to m_Successors
|
|
||||||
bool ret = n->m_UserState.GetSuccessors( this, n->parent ? &n->parent->m_UserState : NULL );
|
|
||||||
|
|
||||||
if( !ret )
|
|
||||||
{
|
|
||||||
|
|
||||||
typename vector< Node * >::iterator successor;
|
|
||||||
|
|
||||||
// free the nodes that may previously have been added
|
|
||||||
for( successor = m_Successors.begin(); successor != m_Successors.end(); successor ++ )
|
|
||||||
{
|
|
||||||
FreeNode( (*successor) );
|
|
||||||
}
|
|
||||||
|
|
||||||
m_Successors.clear(); // empty vector of successor nodes to n
|
|
||||||
|
|
||||||
// free up everything else we allocated
|
|
||||||
FreeNode( (n) );
|
|
||||||
FreeAllNodes();
|
|
||||||
|
|
||||||
m_State = SEARCH_STATE_OUT_OF_MEMORY;
|
|
||||||
return m_State;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now handle each successor to the current node ...
|
|
||||||
for( typename vector< Node * >::iterator successor = m_Successors.begin(); successor != m_Successors.end(); successor ++ )
|
|
||||||
{
|
|
||||||
|
|
||||||
// The g value for this successor ...
|
|
||||||
float newg = n->g + n->m_UserState.GetCost( (*successor)->m_UserState );
|
|
||||||
|
|
||||||
// Now we need to find whether the node is on the open or closed lists
|
|
||||||
// If it is but the node that is already on them is better (lower g)
|
|
||||||
// then we can forget about this successor
|
|
||||||
|
|
||||||
// First linear search of open list to find node
|
|
||||||
|
|
||||||
typename vector< Node * >::iterator openlist_result;
|
|
||||||
|
|
||||||
for( openlist_result = m_OpenList.begin(); openlist_result != m_OpenList.end(); openlist_result ++ )
|
|
||||||
{
|
|
||||||
if( (*openlist_result)->m_UserState.IsSameState( (*successor)->m_UserState ) )
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if( openlist_result != m_OpenList.end() )
|
|
||||||
{
|
|
||||||
|
|
||||||
// we found this state on open
|
|
||||||
|
|
||||||
if( (*openlist_result)->g <= newg )
|
|
||||||
{
|
|
||||||
FreeNode( (*successor) );
|
|
||||||
|
|
||||||
// the one on Open is cheaper than this one
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
typename vector< Node * >::iterator closedlist_result;
|
|
||||||
|
|
||||||
for( closedlist_result = m_ClosedList.begin(); closedlist_result != m_ClosedList.end(); closedlist_result ++ )
|
|
||||||
{
|
|
||||||
if( (*closedlist_result)->m_UserState.IsSameState( (*successor)->m_UserState ) )
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if( closedlist_result != m_ClosedList.end() )
|
|
||||||
{
|
|
||||||
|
|
||||||
// we found this state on closed
|
|
||||||
|
|
||||||
if( (*closedlist_result)->g <= newg )
|
|
||||||
{
|
|
||||||
// the one on Closed is cheaper than this one
|
|
||||||
FreeNode( (*successor) );
|
|
||||||
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This node is the best node so far with this particular state
|
|
||||||
// so lets keep it and set up its AStar specific data ...
|
|
||||||
|
|
||||||
(*successor)->parent = n;
|
|
||||||
(*successor)->g = newg;
|
|
||||||
(*successor)->h = (*successor)->m_UserState.GoalDistanceEstimate( m_Goal->m_UserState );
|
|
||||||
(*successor)->f = (*successor)->g + (*successor)->h;
|
|
||||||
|
|
||||||
// Successor in closed list
|
|
||||||
// 1 - Update old version of this node in closed list
|
|
||||||
// 2 - Move it from closed to open list
|
|
||||||
// 3 - Sort heap again in open list
|
|
||||||
|
|
||||||
if( closedlist_result != m_ClosedList.end() )
|
|
||||||
{
|
|
||||||
// Update closed node with successor node AStar data
|
|
||||||
//*(*closedlist_result) = *(*successor);
|
|
||||||
(*closedlist_result)->parent = (*successor)->parent;
|
|
||||||
(*closedlist_result)->g = (*successor)->g;
|
|
||||||
(*closedlist_result)->h = (*successor)->h;
|
|
||||||
(*closedlist_result)->f = (*successor)->f;
|
|
||||||
|
|
||||||
// Free successor node
|
|
||||||
FreeNode( (*successor) );
|
|
||||||
|
|
||||||
// Push closed node into open list
|
|
||||||
m_OpenList.push_back( (*closedlist_result) );
|
|
||||||
|
|
||||||
// Remove closed node from closed list
|
|
||||||
m_ClosedList.erase( closedlist_result );
|
|
||||||
|
|
||||||
// Sort back element into heap
|
|
||||||
push_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
|
|
||||||
|
|
||||||
// Fix thanks to ...
|
|
||||||
// Greg Douglas <gregdouglasmail@gmail.com>
|
|
||||||
// who noticed that this code path was incorrect
|
|
||||||
// Here we have found a new state which is already CLOSED
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Successor in open list
|
|
||||||
// 1 - Update old version of this node in open list
|
|
||||||
// 2 - sort heap again in open list
|
|
||||||
|
|
||||||
else if( openlist_result != m_OpenList.end() )
|
|
||||||
{
|
|
||||||
// Update open node with successor node AStar data
|
|
||||||
//*(*openlist_result) = *(*successor);
|
|
||||||
(*openlist_result)->parent = (*successor)->parent;
|
|
||||||
(*openlist_result)->g = (*successor)->g;
|
|
||||||
(*openlist_result)->h = (*successor)->h;
|
|
||||||
(*openlist_result)->f = (*successor)->f;
|
|
||||||
|
|
||||||
// Free successor node
|
|
||||||
FreeNode( (*successor) );
|
|
||||||
|
|
||||||
// re-make the heap
|
|
||||||
// make_heap rather than sort_heap is an essential bug fix
|
|
||||||
// thanks to Mike Ryynanen for pointing this out and then explaining
|
|
||||||
// it in detail. sort_heap called on an invalid heap does not work
|
|
||||||
make_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
|
|
||||||
}
|
|
||||||
|
|
||||||
// New successor
|
|
||||||
// 1 - Move it from successors to open list
|
|
||||||
// 2 - sort heap again in open list
|
|
||||||
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// Push successor node into open list
|
|
||||||
m_OpenList.push_back( (*successor) );
|
|
||||||
|
|
||||||
// Sort back element into heap
|
|
||||||
push_heap( m_OpenList.begin(), m_OpenList.end(), HeapCompare_f() );
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// push n onto Closed, as we have expanded it now
|
|
||||||
|
|
||||||
m_ClosedList.push_back( n );
|
|
||||||
|
|
||||||
} // end else (not goal so expand)
|
|
||||||
|
|
||||||
return m_State; // Succeeded bool is false at this point.
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// User calls this to add a successor to a list of successors
|
|
||||||
// when expanding the search frontier
|
|
||||||
bool AddSuccessor( UserState &State )
|
|
||||||
{
|
|
||||||
Node *node = AllocateNode();
|
|
||||||
|
|
||||||
if( node )
|
|
||||||
{
|
|
||||||
node->m_UserState = State;
|
|
||||||
|
|
||||||
m_Successors.push_back( node );
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free the solution nodes
|
|
||||||
// This is done to clean up all used Node memory when you are done with the
|
|
||||||
// search
|
|
||||||
void FreeSolutionNodes()
|
|
||||||
{
|
|
||||||
Node *n = m_Start;
|
|
||||||
|
|
||||||
if( m_Start->child )
|
|
||||||
{
|
|
||||||
do
|
|
||||||
{
|
|
||||||
Node *del = n;
|
|
||||||
n = n->child;
|
|
||||||
FreeNode( del );
|
|
||||||
|
|
||||||
del = NULL;
|
|
||||||
|
|
||||||
} while( n != m_Goal );
|
|
||||||
|
|
||||||
FreeNode( n ); // Delete the goal
|
|
||||||
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// if the start node is the solution we need to just delete the start and goal
|
|
||||||
// nodes
|
|
||||||
FreeNode( m_Start );
|
|
||||||
FreeNode( m_Goal );
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Functions for traversing the solution
|
|
||||||
|
|
||||||
// Get start node
|
|
||||||
UserState *GetSolutionStart()
|
|
||||||
{
|
|
||||||
m_CurrentSolutionNode = m_Start;
|
|
||||||
if( m_Start )
|
|
||||||
{
|
|
||||||
return &m_Start->m_UserState;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get next node
|
|
||||||
UserState *GetSolutionNext()
|
|
||||||
{
|
|
||||||
if( m_CurrentSolutionNode )
|
|
||||||
{
|
|
||||||
if( m_CurrentSolutionNode->child )
|
|
||||||
{
|
|
||||||
|
|
||||||
Node *child = m_CurrentSolutionNode->child;
|
|
||||||
|
|
||||||
m_CurrentSolutionNode = m_CurrentSolutionNode->child;
|
|
||||||
|
|
||||||
return &child->m_UserState;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get end node
|
|
||||||
UserState *GetSolutionEnd()
|
|
||||||
{
|
|
||||||
m_CurrentSolutionNode = m_Goal;
|
|
||||||
if( m_Goal )
|
|
||||||
{
|
|
||||||
return &m_Goal->m_UserState;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step solution iterator backwards
|
|
||||||
UserState *GetSolutionPrev()
|
|
||||||
{
|
|
||||||
if( m_CurrentSolutionNode )
|
|
||||||
{
|
|
||||||
if( m_CurrentSolutionNode->parent )
|
|
||||||
{
|
|
||||||
|
|
||||||
Node *parent = m_CurrentSolutionNode->parent;
|
|
||||||
|
|
||||||
m_CurrentSolutionNode = m_CurrentSolutionNode->parent;
|
|
||||||
|
|
||||||
return &parent->m_UserState;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get final cost of solution
|
|
||||||
// Returns FLT_MAX if goal is not defined or there is no solution
|
|
||||||
float GetSolutionCost()
|
|
||||||
{
|
|
||||||
if( m_Goal && m_State == SEARCH_STATE_SUCCEEDED )
|
|
||||||
{
|
|
||||||
return m_Goal->g;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
return FLT_MAX;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For educational use and debugging it is useful to be able to view
|
|
||||||
// the open and closed list at each step, here are two functions to allow that.
|
|
||||||
|
|
||||||
UserState *GetOpenListStart()
|
|
||||||
{
|
|
||||||
float f,g,h;
|
|
||||||
return GetOpenListStart( f,g,h );
|
|
||||||
}
|
|
||||||
|
|
||||||
UserState *GetOpenListStart( float &f, float &g, float &h )
|
|
||||||
{
|
|
||||||
iterDbgOpen = m_OpenList.begin();
|
|
||||||
if( iterDbgOpen != m_OpenList.end() )
|
|
||||||
{
|
|
||||||
f = (*iterDbgOpen)->f;
|
|
||||||
g = (*iterDbgOpen)->g;
|
|
||||||
h = (*iterDbgOpen)->h;
|
|
||||||
return &(*iterDbgOpen)->m_UserState;
|
|
||||||
}
|
|
||||||
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
UserState *GetOpenListNext()
|
|
||||||
{
|
|
||||||
float f,g,h;
|
|
||||||
return GetOpenListNext( f,g,h );
|
|
||||||
}
|
|
||||||
|
|
||||||
UserState *GetOpenListNext( float &f, float &g, float &h )
|
|
||||||
{
|
|
||||||
iterDbgOpen++;
|
|
||||||
if( iterDbgOpen != m_OpenList.end() )
|
|
||||||
{
|
|
||||||
f = (*iterDbgOpen)->f;
|
|
||||||
g = (*iterDbgOpen)->g;
|
|
||||||
h = (*iterDbgOpen)->h;
|
|
||||||
return &(*iterDbgOpen)->m_UserState;
|
|
||||||
}
|
|
||||||
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
UserState *GetClosedListStart()
|
|
||||||
{
|
|
||||||
float f,g,h;
|
|
||||||
return GetClosedListStart( f,g,h );
|
|
||||||
}
|
|
||||||
|
|
||||||
UserState *GetClosedListStart( float &f, float &g, float &h )
|
|
||||||
{
|
|
||||||
iterDbgClosed = m_ClosedList.begin();
|
|
||||||
if( iterDbgClosed != m_ClosedList.end() )
|
|
||||||
{
|
|
||||||
f = (*iterDbgClosed)->f;
|
|
||||||
g = (*iterDbgClosed)->g;
|
|
||||||
h = (*iterDbgClosed)->h;
|
|
||||||
|
|
||||||
return &(*iterDbgClosed)->m_UserState;
|
|
||||||
}
|
|
||||||
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
UserState *GetClosedListNext()
|
|
||||||
{
|
|
||||||
float f,g,h;
|
|
||||||
return GetClosedListNext( f,g,h );
|
|
||||||
}
|
|
||||||
|
|
||||||
UserState *GetClosedListNext( float &f, float &g, float &h )
|
|
||||||
{
|
|
||||||
iterDbgClosed++;
|
|
||||||
if( iterDbgClosed != m_ClosedList.end() )
|
|
||||||
{
|
|
||||||
f = (*iterDbgClosed)->f;
|
|
||||||
g = (*iterDbgClosed)->g;
|
|
||||||
h = (*iterDbgClosed)->h;
|
|
||||||
|
|
||||||
return &(*iterDbgClosed)->m_UserState;
|
|
||||||
}
|
|
||||||
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the number of steps
|
|
||||||
|
|
||||||
int GetStepCount() { return m_Steps; }
|
|
||||||
|
|
||||||
void EnsureMemoryFreed()
|
|
||||||
{
|
|
||||||
#if USE_FSA_MEMORY
|
|
||||||
assert(m_AllocateNodeCount == 0);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private: // methods
|
|
||||||
|
|
||||||
// This is called when a search fails or is cancelled to free all used
|
|
||||||
// memory
|
|
||||||
void FreeAllNodes()
|
|
||||||
{
|
|
||||||
// iterate open list and delete all nodes
|
|
||||||
typename vector< Node * >::iterator iterOpen = m_OpenList.begin();
|
|
||||||
|
|
||||||
while( iterOpen != m_OpenList.end() )
|
|
||||||
{
|
|
||||||
Node *n = (*iterOpen);
|
|
||||||
FreeNode( n );
|
|
||||||
|
|
||||||
iterOpen ++;
|
|
||||||
}
|
|
||||||
|
|
||||||
m_OpenList.clear();
|
|
||||||
|
|
||||||
// iterate closed list and delete unused nodes
|
|
||||||
typename vector< Node * >::iterator iterClosed;
|
|
||||||
|
|
||||||
for( iterClosed = m_ClosedList.begin(); iterClosed != m_ClosedList.end(); iterClosed ++ )
|
|
||||||
{
|
|
||||||
Node *n = (*iterClosed);
|
|
||||||
FreeNode( n );
|
|
||||||
}
|
|
||||||
|
|
||||||
m_ClosedList.clear();
|
|
||||||
|
|
||||||
// delete the goal
|
|
||||||
|
|
||||||
FreeNode(m_Goal);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// This call is made by the search class when the search ends. A lot of nodes may be
|
|
||||||
// created that are still present when the search ends. They will be deleted by this
|
|
||||||
// routine once the search ends
|
|
||||||
void FreeUnusedNodes()
|
|
||||||
{
|
|
||||||
// iterate open list and delete unused nodes
|
|
||||||
typename vector< Node * >::iterator iterOpen = m_OpenList.begin();
|
|
||||||
|
|
||||||
while( iterOpen != m_OpenList.end() )
|
|
||||||
{
|
|
||||||
Node *n = (*iterOpen);
|
|
||||||
|
|
||||||
if( !n->child )
|
|
||||||
{
|
|
||||||
FreeNode( n );
|
|
||||||
|
|
||||||
n = NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
iterOpen ++;
|
|
||||||
}
|
|
||||||
|
|
||||||
m_OpenList.clear();
|
|
||||||
|
|
||||||
// iterate closed list and delete unused nodes
|
|
||||||
typename vector< Node * >::iterator iterClosed;
|
|
||||||
|
|
||||||
for( iterClosed = m_ClosedList.begin(); iterClosed != m_ClosedList.end(); iterClosed ++ )
|
|
||||||
{
|
|
||||||
Node *n = (*iterClosed);
|
|
||||||
|
|
||||||
if( !n->child )
|
|
||||||
{
|
|
||||||
FreeNode( n );
|
|
||||||
n = NULL;
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m_ClosedList.clear();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Node memory management
|
|
||||||
Node *AllocateNode()
|
|
||||||
{
|
|
||||||
|
|
||||||
#if !USE_FSA_MEMORY
|
|
||||||
m_AllocateNodeCount ++;
|
|
||||||
Node *p = new Node;
|
|
||||||
return p;
|
|
||||||
#else
|
|
||||||
Node *address = m_FixedSizeAllocator.alloc();
|
|
||||||
|
|
||||||
if( !address )
|
|
||||||
{
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
m_AllocateNodeCount ++;
|
|
||||||
Node *p = new (address) Node;
|
|
||||||
return p;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
void FreeNode( Node *node )
|
|
||||||
{
|
|
||||||
|
|
||||||
m_AllocateNodeCount --;
|
|
||||||
|
|
||||||
#if !USE_FSA_MEMORY
|
|
||||||
delete node;
|
|
||||||
#else
|
|
||||||
node->~Node();
|
|
||||||
m_FixedSizeAllocator.free( node );
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
private: // data
|
|
||||||
|
|
||||||
// Heap (simple vector but used as a heap, cf. Steve Rabin's game gems article)
|
|
||||||
vector< Node *> m_OpenList;
|
|
||||||
|
|
||||||
// Closed list is a vector.
|
|
||||||
vector< Node * > m_ClosedList;
|
|
||||||
|
|
||||||
// Successors is a vector filled out by the user each type successors to a node
|
|
||||||
// are generated
|
|
||||||
vector< Node * > m_Successors;
|
|
||||||
|
|
||||||
// State
|
|
||||||
unsigned int m_State;
|
|
||||||
|
|
||||||
// Counts steps
|
|
||||||
int m_Steps;
|
|
||||||
|
|
||||||
// Start and goal state pointers
|
|
||||||
Node *m_Start;
|
|
||||||
Node *m_Goal;
|
|
||||||
|
|
||||||
Node *m_CurrentSolutionNode;
|
|
||||||
|
|
||||||
#if USE_FSA_MEMORY
|
|
||||||
// Memory
|
|
||||||
FixedSizeAllocator<Node> m_FixedSizeAllocator;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
//Debug : need to keep these two iterators around
|
|
||||||
// for the user Dbg functions
|
|
||||||
typename vector< Node * >::iterator iterDbgOpen;
|
|
||||||
typename vector< Node * >::iterator iterDbgClosed;
|
|
||||||
|
|
||||||
// debugging : count memory allocation and free's
|
|
||||||
int m_AllocateNodeCount;
|
|
||||||
|
|
||||||
bool m_CancelRequest;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
template <class T> class AStarState
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual ~AStarState() {}
|
|
||||||
virtual float GoalDistanceEstimate( T &nodeGoal ) = 0; // Heuristic function which computes the estimated cost to the goal node
|
|
||||||
virtual bool IsGoal( T &nodeGoal ) = 0; // Returns true if this node is the goal node
|
|
||||||
virtual bool GetSuccessors( AStarSearch<T> *astarsearch, T *parent_node ) = 0; // Retrieves all successors to this node and adds them via astarsearch.addSuccessor()
|
|
||||||
virtual float GetCost( T &successor ) = 0; // Computes the cost of travelling from this node to the successor node
|
|
||||||
virtual bool IsSameState( T &rhs ) = 0; // Returns true if this node is the same as the rhs node
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
|
@ -1,674 +0,0 @@
|
|||||||
GNU GENERAL PUBLIC LICENSE
|
|
||||||
Version 3, 29 June 2007
|
|
||||||
|
|
||||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
|
||||||
Everyone is permitted to copy and distribute verbatim copies
|
|
||||||
of this license document, but changing it is not allowed.
|
|
||||||
|
|
||||||
Preamble
|
|
||||||
|
|
||||||
The GNU General Public License is a free, copyleft license for
|
|
||||||
software and other kinds of works.
|
|
||||||
|
|
||||||
The licenses for most software and other practical works are designed
|
|
||||||
to take away your freedom to share and change the works. By contrast,
|
|
||||||
the GNU General Public License is intended to guarantee your freedom to
|
|
||||||
share and change all versions of a program--to make sure it remains free
|
|
||||||
software for all its users. We, the Free Software Foundation, use the
|
|
||||||
GNU General Public License for most of our software; it applies also to
|
|
||||||
any other work released this way by its authors. You can apply it to
|
|
||||||
your programs, too.
|
|
||||||
|
|
||||||
When we speak of free software, we are referring to freedom, not
|
|
||||||
price. Our General Public Licenses are designed to make sure that you
|
|
||||||
have the freedom to distribute copies of free software (and charge for
|
|
||||||
them if you wish), that you receive source code or can get it if you
|
|
||||||
want it, that you can change the software or use pieces of it in new
|
|
||||||
free programs, and that you know you can do these things.
|
|
||||||
|
|
||||||
To protect your rights, we need to prevent others from denying you
|
|
||||||
these rights or asking you to surrender the rights. Therefore, you have
|
|
||||||
certain responsibilities if you distribute copies of the software, or if
|
|
||||||
you modify it: responsibilities to respect the freedom of others.
|
|
||||||
|
|
||||||
For example, if you distribute copies of such a program, whether
|
|
||||||
gratis or for a fee, you must pass on to the recipients the same
|
|
||||||
freedoms that you received. You must make sure that they, too, receive
|
|
||||||
or can get the source code. And you must show them these terms so they
|
|
||||||
know their rights.
|
|
||||||
|
|
||||||
Developers that use the GNU GPL protect your rights with two steps:
|
|
||||||
(1) assert copyright on the software, and (2) offer you this License
|
|
||||||
giving you legal permission to copy, distribute and/or modify it.
|
|
||||||
|
|
||||||
For the developers' and authors' protection, the GPL clearly explains
|
|
||||||
that there is no warranty for this free software. For both users' and
|
|
||||||
authors' sake, the GPL requires that modified versions be marked as
|
|
||||||
changed, so that their problems will not be attributed erroneously to
|
|
||||||
authors of previous versions.
|
|
||||||
|
|
||||||
Some devices are designed to deny users access to install or run
|
|
||||||
modified versions of the software inside them, although the manufacturer
|
|
||||||
can do so. This is fundamentally incompatible with the aim of
|
|
||||||
protecting users' freedom to change the software. The systematic
|
|
||||||
pattern of such abuse occurs in the area of products for individuals to
|
|
||||||
use, which is precisely where it is most unacceptable. Therefore, we
|
|
||||||
have designed this version of the GPL to prohibit the practice for those
|
|
||||||
products. If such problems arise substantially in other domains, we
|
|
||||||
stand ready to extend this provision to those domains in future versions
|
|
||||||
of the GPL, as needed to protect the freedom of users.
|
|
||||||
|
|
||||||
Finally, every program is threatened constantly by software patents.
|
|
||||||
States should not allow patents to restrict development and use of
|
|
||||||
software on general-purpose computers, but in those that do, we wish to
|
|
||||||
avoid the special danger that patents applied to a free program could
|
|
||||||
make it effectively proprietary. To prevent this, the GPL assures that
|
|
||||||
patents cannot be used to render the program non-free.
|
|
||||||
|
|
||||||
The precise terms and conditions for copying, distribution and
|
|
||||||
modification follow.
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
0. Definitions.
|
|
||||||
|
|
||||||
"This License" refers to version 3 of the GNU General Public License.
|
|
||||||
|
|
||||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
|
||||||
works, such as semiconductor masks.
|
|
||||||
|
|
||||||
"The Program" refers to any copyrightable work licensed under this
|
|
||||||
License. Each licensee is addressed as "you". "Licensees" and
|
|
||||||
"recipients" may be individuals or organizations.
|
|
||||||
|
|
||||||
To "modify" a work means to copy from or adapt all or part of the work
|
|
||||||
in a fashion requiring copyright permission, other than the making of an
|
|
||||||
exact copy. The resulting work is called a "modified version" of the
|
|
||||||
earlier work or a work "based on" the earlier work.
|
|
||||||
|
|
||||||
A "covered work" means either the unmodified Program or a work based
|
|
||||||
on the Program.
|
|
||||||
|
|
||||||
To "propagate" a work means to do anything with it that, without
|
|
||||||
permission, would make you directly or secondarily liable for
|
|
||||||
infringement under applicable copyright law, except executing it on a
|
|
||||||
computer or modifying a private copy. Propagation includes copying,
|
|
||||||
distribution (with or without modification), making available to the
|
|
||||||
public, and in some countries other activities as well.
|
|
||||||
|
|
||||||
To "convey" a work means any kind of propagation that enables other
|
|
||||||
parties to make or receive copies. Mere interaction with a user through
|
|
||||||
a computer network, with no transfer of a copy, is not conveying.
|
|
||||||
|
|
||||||
An interactive user interface displays "Appropriate Legal Notices"
|
|
||||||
to the extent that it includes a convenient and prominently visible
|
|
||||||
feature that (1) displays an appropriate copyright notice, and (2)
|
|
||||||
tells the user that there is no warranty for the work (except to the
|
|
||||||
extent that warranties are provided), that licensees may convey the
|
|
||||||
work under this License, and how to view a copy of this License. If
|
|
||||||
the interface presents a list of user commands or options, such as a
|
|
||||||
menu, a prominent item in the list meets this criterion.
|
|
||||||
|
|
||||||
1. Source Code.
|
|
||||||
|
|
||||||
The "source code" for a work means the preferred form of the work
|
|
||||||
for making modifications to it. "Object code" means any non-source
|
|
||||||
form of a work.
|
|
||||||
|
|
||||||
A "Standard Interface" means an interface that either is an official
|
|
||||||
standard defined by a recognized standards body, or, in the case of
|
|
||||||
interfaces specified for a particular programming language, one that
|
|
||||||
is widely used among developers working in that language.
|
|
||||||
|
|
||||||
The "System Libraries" of an executable work include anything, other
|
|
||||||
than the work as a whole, that (a) is included in the normal form of
|
|
||||||
packaging a Major Component, but which is not part of that Major
|
|
||||||
Component, and (b) serves only to enable use of the work with that
|
|
||||||
Major Component, or to implement a Standard Interface for which an
|
|
||||||
implementation is available to the public in source code form. A
|
|
||||||
"Major Component", in this context, means a major essential component
|
|
||||||
(kernel, window system, and so on) of the specific operating system
|
|
||||||
(if any) on which the executable work runs, or a compiler used to
|
|
||||||
produce the work, or an object code interpreter used to run it.
|
|
||||||
|
|
||||||
The "Corresponding Source" for a work in object code form means all
|
|
||||||
the source code needed to generate, install, and (for an executable
|
|
||||||
work) run the object code and to modify the work, including scripts to
|
|
||||||
control those activities. However, it does not include the work's
|
|
||||||
System Libraries, or general-purpose tools or generally available free
|
|
||||||
programs which are used unmodified in performing those activities but
|
|
||||||
which are not part of the work. For example, Corresponding Source
|
|
||||||
includes interface definition files associated with source files for
|
|
||||||
the work, and the source code for shared libraries and dynamically
|
|
||||||
linked subprograms that the work is specifically designed to require,
|
|
||||||
such as by intimate data communication or control flow between those
|
|
||||||
subprograms and other parts of the work.
|
|
||||||
|
|
||||||
The Corresponding Source need not include anything that users
|
|
||||||
can regenerate automatically from other parts of the Corresponding
|
|
||||||
Source.
|
|
||||||
|
|
||||||
The Corresponding Source for a work in source code form is that
|
|
||||||
same work.
|
|
||||||
|
|
||||||
2. Basic Permissions.
|
|
||||||
|
|
||||||
All rights granted under this License are granted for the term of
|
|
||||||
copyright on the Program, and are irrevocable provided the stated
|
|
||||||
conditions are met. This License explicitly affirms your unlimited
|
|
||||||
permission to run the unmodified Program. The output from running a
|
|
||||||
covered work is covered by this License only if the output, given its
|
|
||||||
content, constitutes a covered work. This License acknowledges your
|
|
||||||
rights of fair use or other equivalent, as provided by copyright law.
|
|
||||||
|
|
||||||
You may make, run and propagate covered works that you do not
|
|
||||||
convey, without conditions so long as your license otherwise remains
|
|
||||||
in force. You may convey covered works to others for the sole purpose
|
|
||||||
of having them make modifications exclusively for you, or provide you
|
|
||||||
with facilities for running those works, provided that you comply with
|
|
||||||
the terms of this License in conveying all material for which you do
|
|
||||||
not control copyright. Those thus making or running the covered works
|
|
||||||
for you must do so exclusively on your behalf, under your direction
|
|
||||||
and control, on terms that prohibit them from making any copies of
|
|
||||||
your copyrighted material outside their relationship with you.
|
|
||||||
|
|
||||||
Conveying under any other circumstances is permitted solely under
|
|
||||||
the conditions stated below. Sublicensing is not allowed; section 10
|
|
||||||
makes it unnecessary.
|
|
||||||
|
|
||||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
|
||||||
|
|
||||||
No covered work shall be deemed part of an effective technological
|
|
||||||
measure under any applicable law fulfilling obligations under article
|
|
||||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
|
||||||
similar laws prohibiting or restricting circumvention of such
|
|
||||||
measures.
|
|
||||||
|
|
||||||
When you convey a covered work, you waive any legal power to forbid
|
|
||||||
circumvention of technological measures to the extent such circumvention
|
|
||||||
is effected by exercising rights under this License with respect to
|
|
||||||
the covered work, and you disclaim any intention to limit operation or
|
|
||||||
modification of the work as a means of enforcing, against the work's
|
|
||||||
users, your or third parties' legal rights to forbid circumvention of
|
|
||||||
technological measures.
|
|
||||||
|
|
||||||
4. Conveying Verbatim Copies.
|
|
||||||
|
|
||||||
You may convey verbatim copies of the Program's source code as you
|
|
||||||
receive it, in any medium, provided that you conspicuously and
|
|
||||||
appropriately publish on each copy an appropriate copyright notice;
|
|
||||||
keep intact all notices stating that this License and any
|
|
||||||
non-permissive terms added in accord with section 7 apply to the code;
|
|
||||||
keep intact all notices of the absence of any warranty; and give all
|
|
||||||
recipients a copy of this License along with the Program.
|
|
||||||
|
|
||||||
You may charge any price or no price for each copy that you convey,
|
|
||||||
and you may offer support or warranty protection for a fee.
|
|
||||||
|
|
||||||
5. Conveying Modified Source Versions.
|
|
||||||
|
|
||||||
You may convey a work based on the Program, or the modifications to
|
|
||||||
produce it from the Program, in the form of source code under the
|
|
||||||
terms of section 4, provided that you also meet all of these conditions:
|
|
||||||
|
|
||||||
a) The work must carry prominent notices stating that you modified
|
|
||||||
it, and giving a relevant date.
|
|
||||||
|
|
||||||
b) The work must carry prominent notices stating that it is
|
|
||||||
released under this License and any conditions added under section
|
|
||||||
7. This requirement modifies the requirement in section 4 to
|
|
||||||
"keep intact all notices".
|
|
||||||
|
|
||||||
c) You must license the entire work, as a whole, under this
|
|
||||||
License to anyone who comes into possession of a copy. This
|
|
||||||
License will therefore apply, along with any applicable section 7
|
|
||||||
additional terms, to the whole of the work, and all its parts,
|
|
||||||
regardless of how they are packaged. This License gives no
|
|
||||||
permission to license the work in any other way, but it does not
|
|
||||||
invalidate such permission if you have separately received it.
|
|
||||||
|
|
||||||
d) If the work has interactive user interfaces, each must display
|
|
||||||
Appropriate Legal Notices; however, if the Program has interactive
|
|
||||||
interfaces that do not display Appropriate Legal Notices, your
|
|
||||||
work need not make them do so.
|
|
||||||
|
|
||||||
A compilation of a covered work with other separate and independent
|
|
||||||
works, which are not by their nature extensions of the covered work,
|
|
||||||
and which are not combined with it such as to form a larger program,
|
|
||||||
in or on a volume of a storage or distribution medium, is called an
|
|
||||||
"aggregate" if the compilation and its resulting copyright are not
|
|
||||||
used to limit the access or legal rights of the compilation's users
|
|
||||||
beyond what the individual works permit. Inclusion of a covered work
|
|
||||||
in an aggregate does not cause this License to apply to the other
|
|
||||||
parts of the aggregate.
|
|
||||||
|
|
||||||
6. Conveying Non-Source Forms.
|
|
||||||
|
|
||||||
You may convey a covered work in object code form under the terms
|
|
||||||
of sections 4 and 5, provided that you also convey the
|
|
||||||
machine-readable Corresponding Source under the terms of this License,
|
|
||||||
in one of these ways:
|
|
||||||
|
|
||||||
a) Convey the object code in, or embodied in, a physical product
|
|
||||||
(including a physical distribution medium), accompanied by the
|
|
||||||
Corresponding Source fixed on a durable physical medium
|
|
||||||
customarily used for software interchange.
|
|
||||||
|
|
||||||
b) Convey the object code in, or embodied in, a physical product
|
|
||||||
(including a physical distribution medium), accompanied by a
|
|
||||||
written offer, valid for at least three years and valid for as
|
|
||||||
long as you offer spare parts or customer support for that product
|
|
||||||
model, to give anyone who possesses the object code either (1) a
|
|
||||||
copy of the Corresponding Source for all the software in the
|
|
||||||
product that is covered by this License, on a durable physical
|
|
||||||
medium customarily used for software interchange, for a price no
|
|
||||||
more than your reasonable cost of physically performing this
|
|
||||||
conveying of source, or (2) access to copy the
|
|
||||||
Corresponding Source from a network server at no charge.
|
|
||||||
|
|
||||||
c) Convey individual copies of the object code with a copy of the
|
|
||||||
written offer to provide the Corresponding Source. This
|
|
||||||
alternative is allowed only occasionally and noncommercially, and
|
|
||||||
only if you received the object code with such an offer, in accord
|
|
||||||
with subsection 6b.
|
|
||||||
|
|
||||||
d) Convey the object code by offering access from a designated
|
|
||||||
place (gratis or for a charge), and offer equivalent access to the
|
|
||||||
Corresponding Source in the same way through the same place at no
|
|
||||||
further charge. You need not require recipients to copy the
|
|
||||||
Corresponding Source along with the object code. If the place to
|
|
||||||
copy the object code is a network server, the Corresponding Source
|
|
||||||
may be on a different server (operated by you or a third party)
|
|
||||||
that supports equivalent copying facilities, provided you maintain
|
|
||||||
clear directions next to the object code saying where to find the
|
|
||||||
Corresponding Source. Regardless of what server hosts the
|
|
||||||
Corresponding Source, you remain obligated to ensure that it is
|
|
||||||
available for as long as needed to satisfy these requirements.
|
|
||||||
|
|
||||||
e) Convey the object code using peer-to-peer transmission, provided
|
|
||||||
you inform other peers where the object code and Corresponding
|
|
||||||
Source of the work are being offered to the general public at no
|
|
||||||
charge under subsection 6d.
|
|
||||||
|
|
||||||
A separable portion of the object code, whose source code is excluded
|
|
||||||
from the Corresponding Source as a System Library, need not be
|
|
||||||
included in conveying the object code work.
|
|
||||||
|
|
||||||
A "User Product" is either (1) a "consumer product", which means any
|
|
||||||
tangible personal property which is normally used for personal, family,
|
|
||||||
or household purposes, or (2) anything designed or sold for incorporation
|
|
||||||
into a dwelling. In determining whether a product is a consumer product,
|
|
||||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
|
||||||
product received by a particular user, "normally used" refers to a
|
|
||||||
typical or common use of that class of product, regardless of the status
|
|
||||||
of the particular user or of the way in which the particular user
|
|
||||||
actually uses, or expects or is expected to use, the product. A product
|
|
||||||
is a consumer product regardless of whether the product has substantial
|
|
||||||
commercial, industrial or non-consumer uses, unless such uses represent
|
|
||||||
the only significant mode of use of the product.
|
|
||||||
|
|
||||||
"Installation Information" for a User Product means any methods,
|
|
||||||
procedures, authorization keys, or other information required to install
|
|
||||||
and execute modified versions of a covered work in that User Product from
|
|
||||||
a modified version of its Corresponding Source. The information must
|
|
||||||
suffice to ensure that the continued functioning of the modified object
|
|
||||||
code is in no case prevented or interfered with solely because
|
|
||||||
modification has been made.
|
|
||||||
|
|
||||||
If you convey an object code work under this section in, or with, or
|
|
||||||
specifically for use in, a User Product, and the conveying occurs as
|
|
||||||
part of a transaction in which the right of possession and use of the
|
|
||||||
User Product is transferred to the recipient in perpetuity or for a
|
|
||||||
fixed term (regardless of how the transaction is characterized), the
|
|
||||||
Corresponding Source conveyed under this section must be accompanied
|
|
||||||
by the Installation Information. But this requirement does not apply
|
|
||||||
if neither you nor any third party retains the ability to install
|
|
||||||
modified object code on the User Product (for example, the work has
|
|
||||||
been installed in ROM).
|
|
||||||
|
|
||||||
The requirement to provide Installation Information does not include a
|
|
||||||
requirement to continue to provide support service, warranty, or updates
|
|
||||||
for a work that has been modified or installed by the recipient, or for
|
|
||||||
the User Product in which it has been modified or installed. Access to a
|
|
||||||
network may be denied when the modification itself materially and
|
|
||||||
adversely affects the operation of the network or violates the rules and
|
|
||||||
protocols for communication across the network.
|
|
||||||
|
|
||||||
Corresponding Source conveyed, and Installation Information provided,
|
|
||||||
in accord with this section must be in a format that is publicly
|
|
||||||
documented (and with an implementation available to the public in
|
|
||||||
source code form), and must require no special password or key for
|
|
||||||
unpacking, reading or copying.
|
|
||||||
|
|
||||||
7. Additional Terms.
|
|
||||||
|
|
||||||
"Additional permissions" are terms that supplement the terms of this
|
|
||||||
License by making exceptions from one or more of its conditions.
|
|
||||||
Additional permissions that are applicable to the entire Program shall
|
|
||||||
be treated as though they were included in this License, to the extent
|
|
||||||
that they are valid under applicable law. If additional permissions
|
|
||||||
apply only to part of the Program, that part may be used separately
|
|
||||||
under those permissions, but the entire Program remains governed by
|
|
||||||
this License without regard to the additional permissions.
|
|
||||||
|
|
||||||
When you convey a copy of a covered work, you may at your option
|
|
||||||
remove any additional permissions from that copy, or from any part of
|
|
||||||
it. (Additional permissions may be written to require their own
|
|
||||||
removal in certain cases when you modify the work.) You may place
|
|
||||||
additional permissions on material, added by you to a covered work,
|
|
||||||
for which you have or can give appropriate copyright permission.
|
|
||||||
|
|
||||||
Notwithstanding any other provision of this License, for material you
|
|
||||||
add to a covered work, you may (if authorized by the copyright holders of
|
|
||||||
that material) supplement the terms of this License with terms:
|
|
||||||
|
|
||||||
a) Disclaiming warranty or limiting liability differently from the
|
|
||||||
terms of sections 15 and 16 of this License; or
|
|
||||||
|
|
||||||
b) Requiring preservation of specified reasonable legal notices or
|
|
||||||
author attributions in that material or in the Appropriate Legal
|
|
||||||
Notices displayed by works containing it; or
|
|
||||||
|
|
||||||
c) Prohibiting misrepresentation of the origin of that material, or
|
|
||||||
requiring that modified versions of such material be marked in
|
|
||||||
reasonable ways as different from the original version; or
|
|
||||||
|
|
||||||
d) Limiting the use for publicity purposes of names of licensors or
|
|
||||||
authors of the material; or
|
|
||||||
|
|
||||||
e) Declining to grant rights under trademark law for use of some
|
|
||||||
trade names, trademarks, or service marks; or
|
|
||||||
|
|
||||||
f) Requiring indemnification of licensors and authors of that
|
|
||||||
material by anyone who conveys the material (or modified versions of
|
|
||||||
it) with contractual assumptions of liability to the recipient, for
|
|
||||||
any liability that these contractual assumptions directly impose on
|
|
||||||
those licensors and authors.
|
|
||||||
|
|
||||||
All other non-permissive additional terms are considered "further
|
|
||||||
restrictions" within the meaning of section 10. If the Program as you
|
|
||||||
received it, or any part of it, contains a notice stating that it is
|
|
||||||
governed by this License along with a term that is a further
|
|
||||||
restriction, you may remove that term. If a license document contains
|
|
||||||
a further restriction but permits relicensing or conveying under this
|
|
||||||
License, you may add to a covered work material governed by the terms
|
|
||||||
of that license document, provided that the further restriction does
|
|
||||||
not survive such relicensing or conveying.
|
|
||||||
|
|
||||||
If you add terms to a covered work in accord with this section, you
|
|
||||||
must place, in the relevant source files, a statement of the
|
|
||||||
additional terms that apply to those files, or a notice indicating
|
|
||||||
where to find the applicable terms.
|
|
||||||
|
|
||||||
Additional terms, permissive or non-permissive, may be stated in the
|
|
||||||
form of a separately written license, or stated as exceptions;
|
|
||||||
the above requirements apply either way.
|
|
||||||
|
|
||||||
8. Termination.
|
|
||||||
|
|
||||||
You may not propagate or modify a covered work except as expressly
|
|
||||||
provided under this License. Any attempt otherwise to propagate or
|
|
||||||
modify it is void, and will automatically terminate your rights under
|
|
||||||
this License (including any patent licenses granted under the third
|
|
||||||
paragraph of section 11).
|
|
||||||
|
|
||||||
However, if you cease all violation of this License, then your
|
|
||||||
license from a particular copyright holder is reinstated (a)
|
|
||||||
provisionally, unless and until the copyright holder explicitly and
|
|
||||||
finally terminates your license, and (b) permanently, if the copyright
|
|
||||||
holder fails to notify you of the violation by some reasonable means
|
|
||||||
prior to 60 days after the cessation.
|
|
||||||
|
|
||||||
Moreover, your license from a particular copyright holder is
|
|
||||||
reinstated permanently if the copyright holder notifies you of the
|
|
||||||
violation by some reasonable means, this is the first time you have
|
|
||||||
received notice of violation of this License (for any work) from that
|
|
||||||
copyright holder, and you cure the violation prior to 30 days after
|
|
||||||
your receipt of the notice.
|
|
||||||
|
|
||||||
Termination of your rights under this section does not terminate the
|
|
||||||
licenses of parties who have received copies or rights from you under
|
|
||||||
this License. If your rights have been terminated and not permanently
|
|
||||||
reinstated, you do not qualify to receive new licenses for the same
|
|
||||||
material under section 10.
|
|
||||||
|
|
||||||
9. Acceptance Not Required for Having Copies.
|
|
||||||
|
|
||||||
You are not required to accept this License in order to receive or
|
|
||||||
run a copy of the Program. Ancillary propagation of a covered work
|
|
||||||
occurring solely as a consequence of using peer-to-peer transmission
|
|
||||||
to receive a copy likewise does not require acceptance. However,
|
|
||||||
nothing other than this License grants you permission to propagate or
|
|
||||||
modify any covered work. These actions infringe copyright if you do
|
|
||||||
not accept this License. Therefore, by modifying or propagating a
|
|
||||||
covered work, you indicate your acceptance of this License to do so.
|
|
||||||
|
|
||||||
10. Automatic Licensing of Downstream Recipients.
|
|
||||||
|
|
||||||
Each time you convey a covered work, the recipient automatically
|
|
||||||
receives a license from the original licensors, to run, modify and
|
|
||||||
propagate that work, subject to this License. You are not responsible
|
|
||||||
for enforcing compliance by third parties with this License.
|
|
||||||
|
|
||||||
An "entity transaction" is a transaction transferring control of an
|
|
||||||
organization, or substantially all assets of one, or subdividing an
|
|
||||||
organization, or merging organizations. If propagation of a covered
|
|
||||||
work results from an entity transaction, each party to that
|
|
||||||
transaction who receives a copy of the work also receives whatever
|
|
||||||
licenses to the work the party's predecessor in interest had or could
|
|
||||||
give under the previous paragraph, plus a right to possession of the
|
|
||||||
Corresponding Source of the work from the predecessor in interest, if
|
|
||||||
the predecessor has it or can get it with reasonable efforts.
|
|
||||||
|
|
||||||
You may not impose any further restrictions on the exercise of the
|
|
||||||
rights granted or affirmed under this License. For example, you may
|
|
||||||
not impose a license fee, royalty, or other charge for exercise of
|
|
||||||
rights granted under this License, and you may not initiate litigation
|
|
||||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
|
||||||
any patent claim is infringed by making, using, selling, offering for
|
|
||||||
sale, or importing the Program or any portion of it.
|
|
||||||
|
|
||||||
11. Patents.
|
|
||||||
|
|
||||||
A "contributor" is a copyright holder who authorizes use under this
|
|
||||||
License of the Program or a work on which the Program is based. The
|
|
||||||
work thus licensed is called the contributor's "contributor version".
|
|
||||||
|
|
||||||
A contributor's "essential patent claims" are all patent claims
|
|
||||||
owned or controlled by the contributor, whether already acquired or
|
|
||||||
hereafter acquired, that would be infringed by some manner, permitted
|
|
||||||
by this License, of making, using, or selling its contributor version,
|
|
||||||
but do not include claims that would be infringed only as a
|
|
||||||
consequence of further modification of the contributor version. For
|
|
||||||
purposes of this definition, "control" includes the right to grant
|
|
||||||
patent sublicenses in a manner consistent with the requirements of
|
|
||||||
this License.
|
|
||||||
|
|
||||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
|
||||||
patent license under the contributor's essential patent claims, to
|
|
||||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
|
||||||
propagate the contents of its contributor version.
|
|
||||||
|
|
||||||
In the following three paragraphs, a "patent license" is any express
|
|
||||||
agreement or commitment, however denominated, not to enforce a patent
|
|
||||||
(such as an express permission to practice a patent or covenant not to
|
|
||||||
sue for patent infringement). To "grant" such a patent license to a
|
|
||||||
party means to make such an agreement or commitment not to enforce a
|
|
||||||
patent against the party.
|
|
||||||
|
|
||||||
If you convey a covered work, knowingly relying on a patent license,
|
|
||||||
and the Corresponding Source of the work is not available for anyone
|
|
||||||
to copy, free of charge and under the terms of this License, through a
|
|
||||||
publicly available network server or other readily accessible means,
|
|
||||||
then you must either (1) cause the Corresponding Source to be so
|
|
||||||
available, or (2) arrange to deprive yourself of the benefit of the
|
|
||||||
patent license for this particular work, or (3) arrange, in a manner
|
|
||||||
consistent with the requirements of this License, to extend the patent
|
|
||||||
license to downstream recipients. "Knowingly relying" means you have
|
|
||||||
actual knowledge that, but for the patent license, your conveying the
|
|
||||||
covered work in a country, or your recipient's use of the covered work
|
|
||||||
in a country, would infringe one or more identifiable patents in that
|
|
||||||
country that you have reason to believe are valid.
|
|
||||||
|
|
||||||
If, pursuant to or in connection with a single transaction or
|
|
||||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
|
||||||
covered work, and grant a patent license to some of the parties
|
|
||||||
receiving the covered work authorizing them to use, propagate, modify
|
|
||||||
or convey a specific copy of the covered work, then the patent license
|
|
||||||
you grant is automatically extended to all recipients of the covered
|
|
||||||
work and works based on it.
|
|
||||||
|
|
||||||
A patent license is "discriminatory" if it does not include within
|
|
||||||
the scope of its coverage, prohibits the exercise of, or is
|
|
||||||
conditioned on the non-exercise of one or more of the rights that are
|
|
||||||
specifically granted under this License. You may not convey a covered
|
|
||||||
work if you are a party to an arrangement with a third party that is
|
|
||||||
in the business of distributing software, under which you make payment
|
|
||||||
to the third party based on the extent of your activity of conveying
|
|
||||||
the work, and under which the third party grants, to any of the
|
|
||||||
parties who would receive the covered work from you, a discriminatory
|
|
||||||
patent license (a) in connection with copies of the covered work
|
|
||||||
conveyed by you (or copies made from those copies), or (b) primarily
|
|
||||||
for and in connection with specific products or compilations that
|
|
||||||
contain the covered work, unless you entered into that arrangement,
|
|
||||||
or that patent license was granted, prior to 28 March 2007.
|
|
||||||
|
|
||||||
Nothing in this License shall be construed as excluding or limiting
|
|
||||||
any implied license or other defenses to infringement that may
|
|
||||||
otherwise be available to you under applicable patent law.
|
|
||||||
|
|
||||||
12. No Surrender of Others' Freedom.
|
|
||||||
|
|
||||||
If conditions are imposed on you (whether by court order, agreement or
|
|
||||||
otherwise) that contradict the conditions of this License, they do not
|
|
||||||
excuse you from the conditions of this License. If you cannot convey a
|
|
||||||
covered work so as to satisfy simultaneously your obligations under this
|
|
||||||
License and any other pertinent obligations, then as a consequence you may
|
|
||||||
not convey it at all. For example, if you agree to terms that obligate you
|
|
||||||
to collect a royalty for further conveying from those to whom you convey
|
|
||||||
the Program, the only way you could satisfy both those terms and this
|
|
||||||
License would be to refrain entirely from conveying the Program.
|
|
||||||
|
|
||||||
13. Use with the GNU Affero General Public License.
|
|
||||||
|
|
||||||
Notwithstanding any other provision of this License, you have
|
|
||||||
permission to link or combine any covered work with a work licensed
|
|
||||||
under version 3 of the GNU Affero General Public License into a single
|
|
||||||
combined work, and to convey the resulting work. The terms of this
|
|
||||||
License will continue to apply to the part which is the covered work,
|
|
||||||
but the special requirements of the GNU Affero General Public License,
|
|
||||||
section 13, concerning interaction through a network will apply to the
|
|
||||||
combination as such.
|
|
||||||
|
|
||||||
14. Revised Versions of this License.
|
|
||||||
|
|
||||||
The Free Software Foundation may publish revised and/or new versions of
|
|
||||||
the GNU General Public License from time to time. Such new versions will
|
|
||||||
be similar in spirit to the present version, but may differ in detail to
|
|
||||||
address new problems or concerns.
|
|
||||||
|
|
||||||
Each version is given a distinguishing version number. If the
|
|
||||||
Program specifies that a certain numbered version of the GNU General
|
|
||||||
Public License "or any later version" applies to it, you have the
|
|
||||||
option of following the terms and conditions either of that numbered
|
|
||||||
version or of any later version published by the Free Software
|
|
||||||
Foundation. If the Program does not specify a version number of the
|
|
||||||
GNU General Public License, you may choose any version ever published
|
|
||||||
by the Free Software Foundation.
|
|
||||||
|
|
||||||
If the Program specifies that a proxy can decide which future
|
|
||||||
versions of the GNU General Public License can be used, that proxy's
|
|
||||||
public statement of acceptance of a version permanently authorizes you
|
|
||||||
to choose that version for the Program.
|
|
||||||
|
|
||||||
Later license versions may give you additional or different
|
|
||||||
permissions. However, no additional obligations are imposed on any
|
|
||||||
author or copyright holder as a result of your choosing to follow a
|
|
||||||
later version.
|
|
||||||
|
|
||||||
15. Disclaimer of Warranty.
|
|
||||||
|
|
||||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
|
||||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
|
||||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
|
||||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
|
||||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
|
||||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
|
||||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
|
||||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
|
||||||
|
|
||||||
16. Limitation of Liability.
|
|
||||||
|
|
||||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
|
||||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
|
||||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
|
||||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
|
||||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
|
||||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
|
||||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
|
||||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
|
||||||
SUCH DAMAGES.
|
|
||||||
|
|
||||||
17. Interpretation of Sections 15 and 16.
|
|
||||||
|
|
||||||
If the disclaimer of warranty and limitation of liability provided
|
|
||||||
above cannot be given local legal effect according to their terms,
|
|
||||||
reviewing courts shall apply local law that most closely approximates
|
|
||||||
an absolute waiver of all civil liability in connection with the
|
|
||||||
Program, unless a warranty or assumption of liability accompanies a
|
|
||||||
copy of the Program in return for a fee.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
How to Apply These Terms to Your New Programs
|
|
||||||
|
|
||||||
If you develop a new program, and you want it to be of the greatest
|
|
||||||
possible use to the public, the best way to achieve this is to make it
|
|
||||||
free software which everyone can redistribute and change under these terms.
|
|
||||||
|
|
||||||
To do so, attach the following notices to the program. It is safest
|
|
||||||
to attach them to the start of each source file to most effectively
|
|
||||||
state the exclusion of warranty; and each file should have at least
|
|
||||||
the "copyright" line and a pointer to where the full notice is found.
|
|
||||||
|
|
||||||
<one line to give the program's name and a brief idea of what it does.>
|
|
||||||
Copyright (C) <year> <name of author>
|
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
|
||||||
it under the terms of the GNU General Public License as published by
|
|
||||||
the Free Software Foundation, either version 3 of the License, or
|
|
||||||
(at your option) any later version.
|
|
||||||
|
|
||||||
This program is distributed in the hope that it will be useful,
|
|
||||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
GNU General Public License for more details.
|
|
||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
Also add information on how to contact you by electronic and paper mail.
|
|
||||||
|
|
||||||
If the program does terminal interaction, make it output a short
|
|
||||||
notice like this when it starts in an interactive mode:
|
|
||||||
|
|
||||||
<program> Copyright (C) <year> <name of author>
|
|
||||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
|
||||||
This is free software, and you are welcome to redistribute it
|
|
||||||
under certain conditions; type `show c' for details.
|
|
||||||
|
|
||||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
|
||||||
parts of the General Public License. Of course, your program's commands
|
|
||||||
might be different; for a GUI interface, you would use an "about box".
|
|
||||||
|
|
||||||
You should also get your employer (if you work as a programmer) or school,
|
|
||||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
|
||||||
For more information on this, and how to apply and follow the GNU GPL, see
|
|
||||||
<https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
The GNU General Public License does not permit incorporating your program
|
|
||||||
into proprietary programs. If your program is a subroutine library, you
|
|
||||||
may consider it more useful to permit linking proprietary applications with
|
|
||||||
the library. If this is what you want to do, use the GNU Lesser General
|
|
||||||
Public License instead of this License. But first, please read
|
|
||||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
|
@ -1,105 +0,0 @@
|
|||||||
# MT-YOLOv6 [About Naming YOLOv6](./docs/About_naming_yolov6.md)
|
|
||||||
|
|
||||||
## Introduction
|
|
||||||
|
|
||||||
YOLOv6 is a single-stage object detection framework dedicated to industrial applications, with hardware-friendly efficient design and high performance.
|
|
||||||
|
|
||||||
<img src="assets/picture.png" width="800">
|
|
||||||
|
|
||||||
YOLOv6-nano achieves 35.0 mAP on COCO val2017 dataset with 1242 FPS on T4 using TensorRT FP16 for bs32 inference, and YOLOv6-s achieves 43.1 mAP on COCO val2017 dataset with 520 FPS on T4 using TensorRT FP16 for bs32 inference.
|
|
||||||
|
|
||||||
YOLOv6 is composed of the following methods:
|
|
||||||
|
|
||||||
- Hardware-friendly Design for Backbone and Neck
|
|
||||||
- Efficient Decoupled Head with SIoU Loss
|
|
||||||
|
|
||||||
|
|
||||||
## Coming soon
|
|
||||||
|
|
||||||
- [ ] YOLOv6 m/l/x model.
|
|
||||||
- [ ] Deployment for MNN/TNN/NCNN/CoreML...
|
|
||||||
- [ ] Quantization tools
|
|
||||||
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
### Install
|
|
||||||
|
|
||||||
```shell
|
|
||||||
git clone https://github.com/meituan/YOLOv6
|
|
||||||
cd YOLOv6
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
### Inference
|
|
||||||
|
|
||||||
First, download a pretrained model from the YOLOv6 [release](https://github.com/meituan/YOLOv6/releases/tag/0.1.0)
|
|
||||||
|
|
||||||
Second, run inference with `tools/infer.py`
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python tools/infer.py --weights yolov6s.pt --source img.jpg / imgdir
|
|
||||||
yolov6n.pt
|
|
||||||
```
|
|
||||||
|
|
||||||
### Training
|
|
||||||
|
|
||||||
Single GPU
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python tools/train.py --batch 32 --conf configs/yolov6s.py --data data/coco.yaml --device 0
|
|
||||||
configs/yolov6n.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Multi GPUs (DDP mode recommended)
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python -m torch.distributed.launch --nproc_per_node 8 tools/train.py --batch 256 --conf configs/yolov6s.py --data data/coco.yaml --device 0,1,2,3,4,5,6,7
|
|
||||||
configs/yolov6n.py
|
|
||||||
```
|
|
||||||
|
|
||||||
- conf: select config file to specify network/optimizer/hyperparameters
|
|
||||||
- data: prepare [COCO](http://cocodataset.org) dataset and specify dataset paths in data.yaml
|
|
||||||
|
|
||||||
|
|
||||||
### Evaluation
|
|
||||||
|
|
||||||
Reproduce mAP on COCO val2017 dataset
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python tools/eval.py --data data/coco.yaml --batch 32 --weights yolov6s.pt --task val
|
|
||||||
yolov6n.pt
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### Deployment
|
|
||||||
|
|
||||||
* [ONNX](./deploy/ONNX)
|
|
||||||
* [OpenVINO](./deploy/OpenVINO)
|
|
||||||
|
|
||||||
### Tutorials
|
|
||||||
|
|
||||||
* [Train custom data](./docs/Train_custom_data.md)
|
|
||||||
* [Test speed](./docs/Test_speed.md)
|
|
||||||
|
|
||||||
|
|
||||||
## Benchmark
|
|
||||||
|
|
||||||
|
|
||||||
| Model | Size | mAP<sup>val<br/>0.5:0.95 | Speed<sup>V100<br/>fp16 b32 <br/>(ms) | Speed<sup>V100<br/>fp32 b32 <br/>(ms) | Speed<sup>T4<br/>trt fp16 b1 <br/>(fps) | Speed<sup>T4<br/>trt fp16 b32 <br/>(fps) | Params<br/><sup> (M) | Flops<br/><sup> (G) |
|
|
||||||
| :-------------- | ----------- | :----------------------- | :------------------------------------ | :------------------------------------ | ---------------------------------------- | ----------------------------------------- | --------------- | -------------- |
|
|
||||||
| [**YOLOv6-n**](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n.pt) | 416<br/>640 | 30.8<br/>35.0 | 0.3<br/>0.5 | 0.4<br/>0.7 | 1100<br/>788 | 2716<br/>1242 | 4.3<br/>4.3 | 4.7<br/>11.1 |
|
|
||||||
| [**YOLOv6-tiny**](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6t.pt) | 640 | 41.3 | 0.9 | 1.5 | 425 | 602 | 15.0 | 36.7 |
|
|
||||||
| [**YOLOv6-s**](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6s.pt) | 640 | 43.1 | 1.0 | 1.7 | 373 | 520 | 17.2 | 44.2 |
|
|
||||||
|
|
||||||
|
|
||||||
- Comparisons of the mAP and speed of different object detectors are tested on [COCO val2017](https://cocodataset.org/#download) dataset.
|
|
||||||
- Refer to [Test speed](./docs/Test_speed.md) tutorial to reproduce the speed results of YOLOv6.
|
|
||||||
- Params and Flops of YOLOv6 are estimated on deployed model.
|
|
||||||
- Speed results of other methods are tested in our environment using official codebase and model if not found from the corresponding official release.
|
|
||||||
|
|
||||||
## Third-party resources
|
|
||||||
* YOLOv6 NCNN Android app demo: [ncnn-android-yolov6](https://github.com/FeiGeChuanShu/ncnn-android-yolov6) from [FeiGeChuanShu](https://github.com/FeiGeChuanShu)
|
|
||||||
* YOLOv6 ONNXRuntime/MNN/TNN C++: [YOLOv6-ORT](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/ort/cv/yolov6.cpp), [YOLOv6-MNN](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_yolov6.cpp) and [YOLOv6-TNN](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_yolov6.cpp) from [DefTruth](https://github.com/DefTruth)
|
|
||||||
* YOLOv6 TensorRT Python: [yolov6-tensorrt-python](https://github.com/Linaom1214/tensorrt-python/blob/main/yolov6/trt.py) from [Linaom1214](https://github.com/Linaom1214)
|
|
||||||
* YOLOv6 TensorRT Windows C++: [yolort](https://github.com/zhiqwang/yolov5-rt-stack/tree/main/deployment/tensorrt-yolov6) from [Wei Zeng](https://github.com/Wulingtian)
|
|
Before Width: | Height: | Size: 517 KiB |
@ -1,53 +0,0 @@
|
|||||||
# YOLOv6t model
|
|
||||||
model = dict(
|
|
||||||
type='YOLOv6t',
|
|
||||||
pretrained=None,
|
|
||||||
depth_multiple=0.25,
|
|
||||||
width_multiple=0.50,
|
|
||||||
backbone=dict(
|
|
||||||
type='EfficientRep',
|
|
||||||
num_repeats=[1, 6, 12, 18, 6],
|
|
||||||
out_channels=[64, 128, 256, 512, 1024],
|
|
||||||
),
|
|
||||||
neck=dict(
|
|
||||||
type='RepPAN',
|
|
||||||
num_repeats=[12, 12, 12, 12],
|
|
||||||
out_channels=[256, 128, 128, 256, 256, 512],
|
|
||||||
),
|
|
||||||
head=dict(
|
|
||||||
type='EffiDeHead',
|
|
||||||
in_channels=[128, 256, 512],
|
|
||||||
num_layers=3,
|
|
||||||
begin_indices=24,
|
|
||||||
anchors=1,
|
|
||||||
out_indices=[17, 20, 23],
|
|
||||||
strides=[8, 16, 32],
|
|
||||||
iou_type='ciou'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
solver = dict(
|
|
||||||
optim='SGD',
|
|
||||||
lr_scheduler='Cosine',
|
|
||||||
lr0=0.01,
|
|
||||||
lrf=0.01,
|
|
||||||
momentum=0.937,
|
|
||||||
weight_decay=0.0005,
|
|
||||||
warmup_epochs=3.0,
|
|
||||||
warmup_momentum=0.8,
|
|
||||||
warmup_bias_lr=0.1
|
|
||||||
)
|
|
||||||
|
|
||||||
data_aug = dict(
|
|
||||||
hsv_h=0.015,
|
|
||||||
hsv_s=0.7,
|
|
||||||
hsv_v=0.4,
|
|
||||||
degrees=0.0,
|
|
||||||
translate=0.1,
|
|
||||||
scale=0.5,
|
|
||||||
shear=0.0,
|
|
||||||
flipud=0.0,
|
|
||||||
fliplr=0.5,
|
|
||||||
mosaic=1.0,
|
|
||||||
mixup=0.0,
|
|
||||||
)
|
|
@ -1,53 +0,0 @@
|
|||||||
# YOLOv6t model
|
|
||||||
model = dict(
|
|
||||||
type='YOLOv6t',
|
|
||||||
pretrained='./weights/yolov6t.pt',
|
|
||||||
depth_multiple=0.25,
|
|
||||||
width_multiple=0.50,
|
|
||||||
backbone=dict(
|
|
||||||
type='EfficientRep',
|
|
||||||
num_repeats=[1, 6, 12, 18, 6],
|
|
||||||
out_channels=[64, 128, 256, 512, 1024],
|
|
||||||
),
|
|
||||||
neck=dict(
|
|
||||||
type='RepPAN',
|
|
||||||
num_repeats=[12, 12, 12, 12],
|
|
||||||
out_channels=[256, 128, 128, 256, 256, 512],
|
|
||||||
),
|
|
||||||
head=dict(
|
|
||||||
type='EffiDeHead',
|
|
||||||
in_channels=[128, 256, 512],
|
|
||||||
num_layers=3,
|
|
||||||
begin_indices=24,
|
|
||||||
anchors=1,
|
|
||||||
out_indices=[17, 20, 23],
|
|
||||||
strides=[8, 16, 32],
|
|
||||||
iou_type='ciou'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
solver = dict(
|
|
||||||
optim='SGD',
|
|
||||||
lr_scheduler='Cosine',
|
|
||||||
lr0=0.0032,
|
|
||||||
lrf=0.12,
|
|
||||||
momentum=0.843,
|
|
||||||
weight_decay=0.00036,
|
|
||||||
warmup_epochs=2.0,
|
|
||||||
warmup_momentum=0.5,
|
|
||||||
warmup_bias_lr=0.05
|
|
||||||
)
|
|
||||||
|
|
||||||
data_aug = dict(
|
|
||||||
hsv_h=0.0138,
|
|
||||||
hsv_s=0.664,
|
|
||||||
hsv_v=0.464,
|
|
||||||
degrees=0.373,
|
|
||||||
translate=0.245,
|
|
||||||
scale=0.898,
|
|
||||||
shear=0.602,
|
|
||||||
flipud=0.00856,
|
|
||||||
fliplr=0.5,
|
|
||||||
mosaic=1.0,
|
|
||||||
mixup=0.243,
|
|
||||||
)
|
|
@ -1,53 +0,0 @@
|
|||||||
# YOLOv6n model
|
|
||||||
model = dict(
|
|
||||||
type='YOLOv6n',
|
|
||||||
pretrained=None,
|
|
||||||
depth_multiple=0.33,
|
|
||||||
width_multiple=0.25,
|
|
||||||
backbone=dict(
|
|
||||||
type='EfficientRep',
|
|
||||||
num_repeats=[1, 6, 12, 18, 6],
|
|
||||||
out_channels=[64, 128, 256, 512, 1024],
|
|
||||||
),
|
|
||||||
neck=dict(
|
|
||||||
type='RepPAN',
|
|
||||||
num_repeats=[12, 12, 12, 12],
|
|
||||||
out_channels=[256, 128, 128, 256, 256, 512],
|
|
||||||
),
|
|
||||||
head=dict(
|
|
||||||
type='EffiDeHead',
|
|
||||||
in_channels=[128, 256, 512],
|
|
||||||
num_layers=3,
|
|
||||||
begin_indices=24,
|
|
||||||
anchors=1,
|
|
||||||
out_indices=[17, 20, 23],
|
|
||||||
strides=[8, 16, 32],
|
|
||||||
iou_type='ciou'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
solver = dict(
|
|
||||||
optim='SGD',
|
|
||||||
lr_scheduler='Cosine',
|
|
||||||
lr0=0.01,
|
|
||||||
lrf=0.01,
|
|
||||||
momentum=0.937,
|
|
||||||
weight_decay=0.0005,
|
|
||||||
warmup_epochs=3.0,
|
|
||||||
warmup_momentum=0.8,
|
|
||||||
warmup_bias_lr=0.1
|
|
||||||
)
|
|
||||||
|
|
||||||
data_aug = dict(
|
|
||||||
hsv_h=0.015,
|
|
||||||
hsv_s=0.7,
|
|
||||||
hsv_v=0.4,
|
|
||||||
degrees=0.0,
|
|
||||||
translate=0.1,
|
|
||||||
scale=0.5,
|
|
||||||
shear=0.0,
|
|
||||||
flipud=0.0,
|
|
||||||
fliplr=0.5,
|
|
||||||
mosaic=1.0,
|
|
||||||
mixup=0.0,
|
|
||||||
)
|
|
@ -1,53 +0,0 @@
|
|||||||
# YOLOv6n model
|
|
||||||
model = dict(
|
|
||||||
type='YOLOv6n',
|
|
||||||
pretrained='./weights/yolov6n.pt',
|
|
||||||
depth_multiple=0.33,
|
|
||||||
width_multiple=0.25,
|
|
||||||
backbone=dict(
|
|
||||||
type='EfficientRep',
|
|
||||||
num_repeats=[1, 6, 12, 18, 6],
|
|
||||||
out_channels=[64, 128, 256, 512, 1024],
|
|
||||||
),
|
|
||||||
neck=dict(
|
|
||||||
type='RepPAN',
|
|
||||||
num_repeats=[12, 12, 12, 12],
|
|
||||||
out_channels=[256, 128, 128, 256, 256, 512],
|
|
||||||
),
|
|
||||||
head=dict(
|
|
||||||
type='EffiDeHead',
|
|
||||||
in_channels=[128, 256, 512],
|
|
||||||
num_layers=3,
|
|
||||||
begin_indices=24,
|
|
||||||
anchors=1,
|
|
||||||
out_indices=[17, 20, 23],
|
|
||||||
strides=[8, 16, 32],
|
|
||||||
iou_type='ciou'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
solver = dict(
|
|
||||||
optim='SGD',
|
|
||||||
lr_scheduler='Cosine',
|
|
||||||
lr0=0.0032,
|
|
||||||
lrf=0.12,
|
|
||||||
momentum=0.843,
|
|
||||||
weight_decay=0.00036,
|
|
||||||
warmup_epochs=2.0,
|
|
||||||
warmup_momentum=0.5,
|
|
||||||
warmup_bias_lr=0.05
|
|
||||||
)
|
|
||||||
|
|
||||||
data_aug = dict(
|
|
||||||
hsv_h=0.0138,
|
|
||||||
hsv_s=0.664,
|
|
||||||
hsv_v=0.464,
|
|
||||||
degrees=0.373,
|
|
||||||
translate=0.245,
|
|
||||||
scale=0.898,
|
|
||||||
shear=0.602,
|
|
||||||
flipud=0.00856,
|
|
||||||
fliplr=0.5,
|
|
||||||
mosaic=1.0,
|
|
||||||
mixup=0.243
|
|
||||||
)
|
|
@ -1,53 +0,0 @@
|
|||||||
# YOLOv6s model
|
|
||||||
model = dict(
|
|
||||||
type='YOLOv6s',
|
|
||||||
pretrained=None,
|
|
||||||
depth_multiple=0.33,
|
|
||||||
width_multiple=0.50,
|
|
||||||
backbone=dict(
|
|
||||||
type='EfficientRep',
|
|
||||||
num_repeats=[1, 6, 12, 18, 6],
|
|
||||||
out_channels=[64, 128, 256, 512, 1024],
|
|
||||||
),
|
|
||||||
neck=dict(
|
|
||||||
type='RepPAN',
|
|
||||||
num_repeats=[12, 12, 12, 12],
|
|
||||||
out_channels=[256, 128, 128, 256, 256, 512],
|
|
||||||
),
|
|
||||||
head=dict(
|
|
||||||
type='EffiDeHead',
|
|
||||||
in_channels=[128, 256, 512],
|
|
||||||
num_layers=3,
|
|
||||||
begin_indices=24,
|
|
||||||
anchors=1,
|
|
||||||
out_indices=[17, 20, 23],
|
|
||||||
strides=[8, 16, 32],
|
|
||||||
iou_type='siou'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
solver = dict(
|
|
||||||
optim='SGD',
|
|
||||||
lr_scheduler='Cosine',
|
|
||||||
lr0=0.01,
|
|
||||||
lrf=0.01,
|
|
||||||
momentum=0.937,
|
|
||||||
weight_decay=0.0005,
|
|
||||||
warmup_epochs=3.0,
|
|
||||||
warmup_momentum=0.8,
|
|
||||||
warmup_bias_lr=0.1
|
|
||||||
)
|
|
||||||
|
|
||||||
data_aug = dict(
|
|
||||||
hsv_h=0.015,
|
|
||||||
hsv_s=0.7,
|
|
||||||
hsv_v=0.4,
|
|
||||||
degrees=0.0,
|
|
||||||
translate=0.1,
|
|
||||||
scale=0.5,
|
|
||||||
shear=0.0,
|
|
||||||
flipud=0.0,
|
|
||||||
fliplr=0.5,
|
|
||||||
mosaic=1.0,
|
|
||||||
mixup=0.0,
|
|
||||||
)
|
|
@ -1,53 +0,0 @@
|
|||||||
# YOLOv6s model
|
|
||||||
model = dict(
|
|
||||||
type='YOLOv6s',
|
|
||||||
pretrained='./weights/yolov6s.pt',
|
|
||||||
depth_multiple=0.33,
|
|
||||||
width_multiple=0.50,
|
|
||||||
backbone=dict(
|
|
||||||
type='EfficientRep',
|
|
||||||
num_repeats=[1, 6, 12, 18, 6],
|
|
||||||
out_channels=[64, 128, 256, 512, 1024],
|
|
||||||
),
|
|
||||||
neck=dict(
|
|
||||||
type='RepPAN',
|
|
||||||
num_repeats=[12, 12, 12, 12],
|
|
||||||
out_channels=[256, 128, 128, 256, 256, 512],
|
|
||||||
),
|
|
||||||
head=dict(
|
|
||||||
type='EffiDeHead',
|
|
||||||
in_channels=[128, 256, 512],
|
|
||||||
num_layers=3,
|
|
||||||
begin_indices=24,
|
|
||||||
anchors=1,
|
|
||||||
out_indices=[17, 20, 23],
|
|
||||||
strides=[8, 16, 32],
|
|
||||||
iou_type='siou'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
solver = dict(
|
|
||||||
optim='SGD',
|
|
||||||
lr_scheduler='Cosine',
|
|
||||||
lr0=0.0032,
|
|
||||||
lrf=0.12,
|
|
||||||
momentum=0.843,
|
|
||||||
weight_decay=0.00036,
|
|
||||||
warmup_epochs=2.0,
|
|
||||||
warmup_momentum=0.5,
|
|
||||||
warmup_bias_lr=0.05
|
|
||||||
)
|
|
||||||
|
|
||||||
data_aug = dict(
|
|
||||||
hsv_h=0.0138,
|
|
||||||
hsv_s=0.664,
|
|
||||||
hsv_v=0.464,
|
|
||||||
degrees=0.373,
|
|
||||||
translate=0.245,
|
|
||||||
scale=0.898,
|
|
||||||
shear=0.602,
|
|
||||||
flipud=0.00856,
|
|
||||||
fliplr=0.5,
|
|
||||||
mosaic=1.0,
|
|
||||||
mixup=0.243,
|
|
||||||
)
|
|
@ -1,20 +0,0 @@
|
|||||||
# COCO 2017 dataset http://cocodataset.org
|
|
||||||
train: ../coco/images/train2017 # 118287 images
|
|
||||||
val: ../coco/images/val2017 # 5000 images
|
|
||||||
test: ../coco/images/test2017
|
|
||||||
anno_path: ../coco/annotations/instances_val2017.json
|
|
||||||
# number of classes
|
|
||||||
nc: 80
|
|
||||||
# whether it is coco dataset, only coco dataset should be set to True.
|
|
||||||
is_coco: True
|
|
||||||
|
|
||||||
# class names
|
|
||||||
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
|
|
||||||
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
|
||||||
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
|
|
||||||
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
|
||||||
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
|
||||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
|
||||||
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
|
||||||
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
|
|
||||||
'hair drier', 'toothbrush' ]
|
|
@ -1,11 +0,0 @@
|
|||||||
# Please insure that your custom_dataset are put in same parent dir with YOLOv6_DIR
|
|
||||||
train: ../custom_dataset/images/train # train images
|
|
||||||
val: ../custom_dataset/images/val # val images
|
|
||||||
test: ../custom_dataset/images/test # test images (optional)
|
|
||||||
|
|
||||||
# whether it is coco dataset, only coco dataset should be set to True.
|
|
||||||
is_coco: False
|
|
||||||
# Classes
|
|
||||||
nc: 20 # number of classes
|
|
||||||
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
|
|
||||||
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] # class names
|
|
Before Width: | Height: | Size: 79 KiB |
Before Width: | Height: | Size: 140 KiB |
Before Width: | Height: | Size: 115 KiB |
@ -1,98 +0,0 @@
|
|||||||
# Export ONNX Model
|
|
||||||
|
|
||||||
## Check requirements
|
|
||||||
```shell
|
|
||||||
pip install onnx>=1.10.0
|
|
||||||
```
|
|
||||||
|
|
||||||
## Export script
|
|
||||||
```shell
|
|
||||||
python ./deploy/ONNX/export_onnx.py \
|
|
||||||
--weights yolov6s.pt \
|
|
||||||
--img 640 \
|
|
||||||
--batch 1
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#### Description of all arguments
|
|
||||||
|
|
||||||
- `--weights` : The path of yolov6 model weights.
|
|
||||||
- `--img` : Image size of model inputs.
|
|
||||||
- `--batch` : Batch size of model inputs.
|
|
||||||
- `--half` : Whether to export half-precision model.
|
|
||||||
- `--inplace` : Whether to set Detect() inplace.
|
|
||||||
- `--simplify` : Whether to simplify onnx. Not support in end to end export.
|
|
||||||
- `--end2end` : Whether to export end to end onnx model. Only support onnxruntime and TensorRT >= 8.0.0 .
|
|
||||||
- `--max-wh` : Default is None for TensorRT backend. Set int for onnxruntime backend.
|
|
||||||
- `--topk-all` : Topk objects for every image.
|
|
||||||
- `--iou-thres` : IoU threshold for NMS algorithm.
|
|
||||||
- `--conf-thres` : Confidence threshold for NMS algorithm.
|
|
||||||
- `--device` : Export device. Cuda device : 0 or 0,1,2,3 ... , CPU : cpu .
|
|
||||||
|
|
||||||
## Download
|
|
||||||
|
|
||||||
* [YOLOv6-nano](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n.onnx)
|
|
||||||
* [YOLOv6-tiny](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6t.onnx)
|
|
||||||
* [YOLOv6-s](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6s.onnx)
|
|
||||||
|
|
||||||
## End2End export
|
|
||||||
|
|
||||||
Now YOLOv6 supports end to end detect for onnxruntime and TensorRT !
|
|
||||||
|
|
||||||
If you want to deploy in TensorRT, make sure you have installed TensorRT >= 8.0.0 !
|
|
||||||
|
|
||||||
### onnxruntime backend
|
|
||||||
#### Usage
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python ./deploy/ONNX/export_onnx.py \
|
|
||||||
--weights yolov6s.pt \
|
|
||||||
--img 640 \
|
|
||||||
--batch 1 \
|
|
||||||
--end2end \
|
|
||||||
--max-wh 7680
|
|
||||||
```
|
|
||||||
|
|
||||||
You will get an onnx with **NonMaxSuppression** operater .
|
|
||||||
|
|
||||||
The onnx outputs shape is ```nums x 7```.
|
|
||||||
|
|
||||||
```nums``` means the number of all objects which were detected.
|
|
||||||
|
|
||||||
```7``` means [`batch_index`,`x0`,`y0`,`x1`,` y1`,`classid`,`score`]
|
|
||||||
|
|
||||||
### TensorRT backend (TensorRT version>= 8.0.0)
|
|
||||||
|
|
||||||
#### Usage
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python ./deploy/ONNX/export_onnx.py \
|
|
||||||
--weights yolov6s.pt \
|
|
||||||
--img 640 \
|
|
||||||
--batch 1 \
|
|
||||||
--end2end
|
|
||||||
```
|
|
||||||
|
|
||||||
You will get an onnx with **[EfficientNMS_TRT](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin)** plugin .
|
|
||||||
The onnx outputs are as shown :
|
|
||||||
|
|
||||||
<img src="https://user-images.githubusercontent.com/92794867/176650971-a4fa3d65-10d4-4b65-b8ef-00a2ff13406c.png" height="300px" />
|
|
||||||
|
|
||||||
```num_dets``` means the number of object in every image in its batch .
|
|
||||||
|
|
||||||
```det_boxes``` means topk(100) object's location about [`x0`,`y0`,`x1`,` y1`] .
|
|
||||||
|
|
||||||
```det_scores``` means the confidence score of every topk(100) objects .
|
|
||||||
|
|
||||||
```det_classes``` means the category of every topk(100) objects .
|
|
||||||
|
|
||||||
|
|
||||||
You can export TensorRT engine use [trtexec](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#trtexec-ovr) tools.
|
|
||||||
#### Usage
|
|
||||||
``` shell
|
|
||||||
/path/to/trtexec \
|
|
||||||
--onnx=yolov6s.onnx \
|
|
||||||
--saveEngine=yolov6s.engine \
|
|
||||||
--fp16 # if export TensorRT fp16 model
|
|
||||||
```
|
|
@ -1,112 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import onnx
|
|
||||||
|
|
||||||
ROOT = os.getcwd()
|
|
||||||
if str(ROOT) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT))
|
|
||||||
|
|
||||||
from yolov6.models.yolo import *
|
|
||||||
from yolov6.models.effidehead import Detect
|
|
||||||
from yolov6.layers.common import *
|
|
||||||
from yolov6.utils.events import LOGGER
|
|
||||||
from yolov6.utils.checkpoint import load_checkpoint
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--weights', type=str, default='./yolov6s.pt', help='weights path')
|
|
||||||
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
|
|
||||||
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
|
||||||
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
|
|
||||||
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
|
|
||||||
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
|
|
||||||
parser.add_argument('--end2end', action='store_true', help='export end2end onnx')
|
|
||||||
parser.add_argument('--max-wh', type=int, default=None, help='None for trt int for ort')
|
|
||||||
parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images')
|
|
||||||
parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS')
|
|
||||||
parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS')
|
|
||||||
parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.img_size *= 2 if len(args.img_size) == 1 else 1 # expand
|
|
||||||
print(args)
|
|
||||||
t = time.time()
|
|
||||||
|
|
||||||
# Check device
|
|
||||||
cuda = args.device != 'cpu' and torch.cuda.is_available()
|
|
||||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
|
||||||
assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'
|
|
||||||
# Load PyTorch model
|
|
||||||
model = load_checkpoint(args.weights, map_location=device, inplace=True, fuse=True) # load FP32 model
|
|
||||||
for layer in model.modules():
|
|
||||||
if isinstance(layer, RepVGGBlock):
|
|
||||||
layer.switch_to_deploy()
|
|
||||||
|
|
||||||
# Input
|
|
||||||
img = torch.zeros(args.batch_size, 3, *args.img_size).to(device) # image size(1,3,320,192) iDetection
|
|
||||||
|
|
||||||
# Update model
|
|
||||||
if args.half:
|
|
||||||
img, model = img.half(), model.half() # to FP16
|
|
||||||
model.eval()
|
|
||||||
for k, m in model.named_modules():
|
|
||||||
if isinstance(m, Conv): # assign export-friendly activations
|
|
||||||
if isinstance(m.act, nn.SiLU):
|
|
||||||
m.act = SiLU()
|
|
||||||
elif isinstance(m, Detect):
|
|
||||||
m.inplace = args.inplace
|
|
||||||
if args.end2end:
|
|
||||||
from yolov6.models.end2end import End2End
|
|
||||||
model = End2End(model, max_obj=args.topk_all, iou_thres=args.iou_thres,
|
|
||||||
score_thres=args.conf_thres, max_wh=args.max_wh, device=device)
|
|
||||||
|
|
||||||
y = model(img) # dry run
|
|
||||||
|
|
||||||
# ONNX export
|
|
||||||
try:
|
|
||||||
LOGGER.info('\nStarting to export ONNX...')
|
|
||||||
export_file = args.weights.replace('.pt', '.onnx') # filename
|
|
||||||
with BytesIO() as f:
|
|
||||||
torch.onnx.export(model, img, f, verbose=False, opset_version=12,
|
|
||||||
training=torch.onnx.TrainingMode.EVAL,
|
|
||||||
do_constant_folding=True,
|
|
||||||
input_names=['image_arrays'],
|
|
||||||
output_names=['num_dets', 'det_boxes', 'det_scores', 'det_classes']
|
|
||||||
if args.end2end and args.max_wh is None else ['outputs'],)
|
|
||||||
f.seek(0)
|
|
||||||
# Checks
|
|
||||||
onnx_model = onnx.load(f) # load onnx model
|
|
||||||
onnx.checker.check_model(onnx_model) # check onnx model
|
|
||||||
# Fix output shape
|
|
||||||
if args.end2end and args.max_wh is None:
|
|
||||||
shapes = [args.batch_size, 1, args.batch_size, args.topk_all, 4,
|
|
||||||
args.batch_size, args.topk_all, args.batch_size, args.topk_all]
|
|
||||||
for i in onnx_model.graph.output:
|
|
||||||
for j in i.type.tensor_type.shape.dim:
|
|
||||||
j.dim_param = str(shapes.pop(0))
|
|
||||||
if args.simplify:
|
|
||||||
try:
|
|
||||||
import onnxsim
|
|
||||||
LOGGER.info('\nStarting to simplify ONNX...')
|
|
||||||
onnx_model, check = onnxsim.simplify(onnx_model)
|
|
||||||
assert check, 'assert check failed'
|
|
||||||
except Exception as e:
|
|
||||||
LOGGER.info(f'Simplifier failure: {e}')
|
|
||||||
onnx.save(onnx_model, export_file)
|
|
||||||
LOGGER.info(f'ONNX export success, saved as {export_file}')
|
|
||||||
except Exception as e:
|
|
||||||
LOGGER.info(f'ONNX export failure: {e}')
|
|
||||||
|
|
||||||
# Finish
|
|
||||||
LOGGER.info('\nExport complete (%.2fs)' % (time.time() - t))
|
|
||||||
if args.end2end:
|
|
||||||
if args.max_wh is None:
|
|
||||||
LOGGER.info('\nYou can export tensorrt engine use trtexec tools.\nCommand is:')
|
|
||||||
LOGGER.info(f'trtexec --onnx={export_file} --saveEngine={export_file.replace(".onnx",".engine")}')
|
|
@ -1,24 +0,0 @@
|
|||||||
## Export OpenVINO Model
|
|
||||||
|
|
||||||
### Check requirements
|
|
||||||
```shell
|
|
||||||
pip install --upgrade pip
|
|
||||||
pip install openvino-dev
|
|
||||||
```
|
|
||||||
|
|
||||||
### Export script
|
|
||||||
```shell
|
|
||||||
python deploy/OpenVINO/export_openvino.py --weights yolov6s.pt --img 640 --batch 1
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
### Download
|
|
||||||
* [YOLOv6-nano](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n_openvino.tar.gz)
|
|
||||||
* [YOLOv6-tiny](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n_openvino.tar.gz)
|
|
||||||
* [YOLOv6-s](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n_openvino.tar.gz)
|
|
||||||
|
|
||||||
### Speed test
|
|
||||||
```shell
|
|
||||||
benchmark_app -m yolov6s_openvino/yolov6s.xml -i data/images/image1.jpg -d CPU -niter 100 -progress
|
|
||||||
|
|
||||||
```
|
|
@ -1,92 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import onnx
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
ROOT = os.getcwd()
|
|
||||||
if str(ROOT) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT))
|
|
||||||
|
|
||||||
from yolov6.models.yolo import *
|
|
||||||
from yolov6.models.effidehead import Detect
|
|
||||||
from yolov6.layers.common import *
|
|
||||||
from yolov6.utils.events import LOGGER
|
|
||||||
from yolov6.utils.checkpoint import load_checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--weights', type=str, default='./yolov6s.pt', help='weights path')
|
|
||||||
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
|
|
||||||
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
|
||||||
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
|
|
||||||
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
|
|
||||||
parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.img_size *= 2 if len(args.img_size) == 1 else 1 # expand
|
|
||||||
print(args)
|
|
||||||
t = time.time()
|
|
||||||
|
|
||||||
# Check device
|
|
||||||
cuda = args.device != 'cpu' and torch.cuda.is_available()
|
|
||||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
|
||||||
assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'
|
|
||||||
# Load PyTorch model
|
|
||||||
model = load_checkpoint(args.weights, map_location=device, inplace=True, fuse=True) # load FP32 model
|
|
||||||
for layer in model.modules():
|
|
||||||
if isinstance(layer, RepVGGBlock):
|
|
||||||
layer.switch_to_deploy()
|
|
||||||
|
|
||||||
# Input
|
|
||||||
img = torch.zeros(args.batch_size, 3, *args.img_size).to(device) # image size(1,3,320,192) iDetection
|
|
||||||
|
|
||||||
# Update model
|
|
||||||
if args.half:
|
|
||||||
img, model = img.half(), model.half() # to FP16
|
|
||||||
model.eval()
|
|
||||||
for k, m in model.named_modules():
|
|
||||||
if isinstance(m, Conv): # assign export-friendly activations
|
|
||||||
if isinstance(m.act, nn.SiLU):
|
|
||||||
m.act = SiLU()
|
|
||||||
elif isinstance(m, Detect):
|
|
||||||
m.inplace = args.inplace
|
|
||||||
|
|
||||||
y = model(img) # dry run
|
|
||||||
|
|
||||||
# ONNX export
|
|
||||||
try:
|
|
||||||
LOGGER.info('\nStarting to export ONNX...')
|
|
||||||
export_file = args.weights.replace('.pt', '.onnx') # filename
|
|
||||||
torch.onnx.export(model, img, export_file, verbose=False, opset_version=12,
|
|
||||||
training=torch.onnx.TrainingMode.EVAL,
|
|
||||||
do_constant_folding=True,
|
|
||||||
input_names=['image_arrays'],
|
|
||||||
output_names=['outputs'],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Checks
|
|
||||||
onnx_model = onnx.load(export_file) # load onnx model
|
|
||||||
onnx.checker.check_model(onnx_model) # check onnx model
|
|
||||||
LOGGER.info(f'ONNX export success, saved as {export_file}')
|
|
||||||
except Exception as e:
|
|
||||||
LOGGER.info(f'ONNX export failure: {e}')
|
|
||||||
|
|
||||||
# OpenVINO export
|
|
||||||
try:
|
|
||||||
LOGGER.info('\nStarting to export OpenVINO...')
|
|
||||||
import_file = args.weights.replace('.pt', '.onnx')
|
|
||||||
export_dir = str(import_file).replace('.onnx', '_openvino')
|
|
||||||
cmd = f"mo --input_model {import_file} --output_dir {export_dir} --data_type {'FP16' if args.half else 'FP32'}"
|
|
||||||
subprocess.check_output(cmd.split())
|
|
||||||
LOGGER.info(f'OpenVINO export success, saved as {export_dir}')
|
|
||||||
except Exception as e:
|
|
||||||
LOGGER.info(f'OpenVINO export failure: {e}')
|
|
||||||
|
|
||||||
# Finish
|
|
||||||
LOGGER.info('\nExport complete (%.2fs)' % (time.time() - t))
|
|
@ -1,41 +0,0 @@
|
|||||||
# Test speed
|
|
||||||
|
|
||||||
This guidence explains how to reproduce speed results of YOLOv6. For fair comparison, the speed results do not contain the time cost of data pre-processing and NMS post-processing.
|
|
||||||
|
|
||||||
## 0. Prepare model
|
|
||||||
|
|
||||||
Download the models you want to test from the latest release.
|
|
||||||
|
|
||||||
## 1. Prepare testing environment
|
|
||||||
|
|
||||||
Refer to README, install packages corresponding to CUDA, CUDNN and TensorRT version.
|
|
||||||
|
|
||||||
Here, we use Torch1.8.0 inference on V100 and TensorRT 7.2 on T4.
|
|
||||||
|
|
||||||
## 2. Reproduce speed
|
|
||||||
|
|
||||||
#### 2.1 Torch Inference on V100
|
|
||||||
|
|
||||||
To get inference speed without TensorRT on V100, you can run the following command:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python tools/eval.py --data data/coco.yaml --batch 32 --weights yolov6n.pt --task speed [--half]
|
|
||||||
```
|
|
||||||
|
|
||||||
- Speed results with batchsize = 1 are unstable in multiple runs, thus we do not provide the bs1 speed results.
|
|
||||||
|
|
||||||
#### 2.2 TensorRT Inference on T4
|
|
||||||
|
|
||||||
To get inference speed with TensorRT in FP16 mode on T4, you can follow the steps below:
|
|
||||||
|
|
||||||
First, export pytorch model as onnx format using the following command:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python deploy/ONNX/export_onnx.py --weights yolov6n.pt --device 0 --batch [1 or 32]
|
|
||||||
```
|
|
||||||
|
|
||||||
Second, generate an inference trt engine and test speed using `trtexec`:
|
|
||||||
|
|
||||||
```
|
|
||||||
trtexec --onnx=yolov6n.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw
|
|
||||||
```
|
|
@ -1,143 +0,0 @@
|
|||||||
# Train Custom Data
|
|
||||||
|
|
||||||
This guidence explains how to train your own custom data with YOLOv6 (take fine-tuning YOLOv6-s model for example).
|
|
||||||
|
|
||||||
## 0. Before you start
|
|
||||||
|
|
||||||
Clone this repo and follow README.md to install requirements in a Python3.8 environment.
|
|
||||||
|
|
||||||
|
|
||||||
## 1. Prepare your own dataset
|
|
||||||
|
|
||||||
**Step 1** Prepare your own dataset with images. For labeling images, you can use tools like [Labelme](https://github.com/wkentaro/labelme).
|
|
||||||
|
|
||||||
**Step 2** Generate label files in YOLO format.
|
|
||||||
|
|
||||||
One image corresponds to one label file, and the label format example is presented as below.
|
|
||||||
|
|
||||||
```json
|
|
||||||
# class_id center_x center_y bbox_width bbox_height
|
|
||||||
0 0.300926 0.617063 0.601852 0.765873
|
|
||||||
1 0.575 0.319531 0.4 0.551562
|
|
||||||
```
|
|
||||||
|
|
||||||
- Each row represents one object.
|
|
||||||
- Class id starts from `0`.
|
|
||||||
- Boundingbox coordinates must be in normalized `xywh` format (from 0 - 1). If your boxes are in pixels, divide `center_x` and `bbox_width` by image width, and `center_y` and `bbox_height` by image height.
|
|
||||||
|
|
||||||
**Step 3** Organize directories.
|
|
||||||
|
|
||||||
Organize your directory of custom dataset as follows:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
custom_dataset
|
|
||||||
├── images
|
|
||||||
│ ├── train
|
|
||||||
│ │ ├── train0.jpg
|
|
||||||
│ │ └── train1.jpg
|
|
||||||
│ ├── val
|
|
||||||
│ │ ├── val0.jpg
|
|
||||||
│ │ └── val1.jpg
|
|
||||||
│ └── test
|
|
||||||
│ ├── test0.jpg
|
|
||||||
│ └── test1.jpg
|
|
||||||
└── labels
|
|
||||||
├── train
|
|
||||||
│ ├── train0.txt
|
|
||||||
│ └── train1.txt
|
|
||||||
├── val
|
|
||||||
│ ├── val0.txt
|
|
||||||
│ └── val1.txt
|
|
||||||
└── test
|
|
||||||
├── test0.txt
|
|
||||||
└── test1.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4** Create `dataset.yaml` in `$YOLOv6_DIR/data`.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Please insure that your custom_dataset are put in same parent dir with YOLOv6_DIR
|
|
||||||
train: ../custom_dataset/images/train # train images
|
|
||||||
val: ../custom_dataset/images/val # val images
|
|
||||||
test: ../custom_dataset/images/test # test images (optional)
|
|
||||||
|
|
||||||
# whether it is coco dataset, only coco dataset should be set to True.
|
|
||||||
is_coco: False
|
|
||||||
|
|
||||||
# Classes
|
|
||||||
nc: 20 # number of classes
|
|
||||||
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
|
|
||||||
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] # class names
|
|
||||||
```
|
|
||||||
|
|
||||||
## 2. Create a config file
|
|
||||||
|
|
||||||
We use a config file to specify the network structure and training setting, including optimizer and data augmentation hyperparameters.
|
|
||||||
|
|
||||||
If you create a new config file, please put it under the configs directory.
|
|
||||||
Or just use the provided config file in `$YOLOV6_HOME/configs/*_finetune.py`.
|
|
||||||
|
|
||||||
```python
|
|
||||||
## YOLOv6s Model config file
|
|
||||||
model = dict(
|
|
||||||
type='YOLOv6s',
|
|
||||||
pretrained='./weights/yolov6s.pt', # download pretrain model from YOLOv6 github if use pretrained model
|
|
||||||
depth_multiple = 0.33,
|
|
||||||
width_multiple = 0.50,
|
|
||||||
...
|
|
||||||
)
|
|
||||||
solver=dict(
|
|
||||||
optim='SGD',
|
|
||||||
lr_scheduler='Cosine',
|
|
||||||
...
|
|
||||||
)
|
|
||||||
|
|
||||||
data_aug = dict(
|
|
||||||
hsv_h=0.015,
|
|
||||||
hsv_s=0.7,
|
|
||||||
hsv_v=0.4,
|
|
||||||
...
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## 3. Train
|
|
||||||
|
|
||||||
Single GPU
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python tools/train.py --batch 256 --conf configs/yolov6s_finetune.py --data data/data.yaml --device 0
|
|
||||||
```
|
|
||||||
|
|
||||||
Multi GPUs (DDP mode recommended)
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python -m torch.distributed.launch --nproc_per_node 4 tools/train.py --batch 256 --conf configs/yolov6s_finetune.py --data data/data.yaml --device 0,1,2,3
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## 4. Evaluation
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python tools/eval.py --data data/data.yaml --weights output_dir/name/weights/best_ckpt.pt --device 0
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## 5. Inference
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python tools/infer.py --weights output_dir/name/weights/best_ckpt.pt --source img.jpg --device 0
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## 6. Deployment
|
|
||||||
|
|
||||||
Export as ONNX Format
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python deploy/ONNX/export_onnx.py --weights output_dir/name/weights/best_ckpt.pt --device 0
|
|
||||||
```
|
|
Before Width: | Height: | Size: 114 KiB |
Before Width: | Height: | Size: 132 KiB |
@ -1,16 +0,0 @@
|
|||||||
# pip install -r requirements.txt
|
|
||||||
# python3.8 environment
|
|
||||||
|
|
||||||
torch>=1.8.0
|
|
||||||
torchvision>=0.9.0
|
|
||||||
numpy>=1.18.5
|
|
||||||
opencv-python>=4.1.2
|
|
||||||
PyYAML>=5.3.1
|
|
||||||
scipy>=1.4.1
|
|
||||||
tqdm>=4.41.0
|
|
||||||
addict>=2.4.0
|
|
||||||
tensorboard>=2.7.0
|
|
||||||
pycocotools>=2.0
|
|
||||||
onnx>=1.10.0 # ONNX export
|
|
||||||
onnx-simplifier>=0.3.6 # ONNX simplifier
|
|
||||||
thop # FLOPs computation
|
|
Before Width: | Height: | Size: 266 KiB |
Before Width: | Height: | Size: 420 KiB |
@ -1,93 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import sys
|
|
||||||
import torch
|
|
||||||
|
|
||||||
ROOT = os.getcwd()
|
|
||||||
if str(ROOT) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT))
|
|
||||||
|
|
||||||
from yolov6.core.evaler import Evaler
|
|
||||||
from yolov6.utils.events import LOGGER
|
|
||||||
from yolov6.utils.general import increment_name
|
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser(add_help=True):
|
|
||||||
parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Evalating', add_help=add_help)
|
|
||||||
parser.add_argument('--data', type=str, default='./data/coco.yaml', help='dataset.yaml path')
|
|
||||||
parser.add_argument('--weights', type=str, default='./weights/yolov6s.pt', help='model.pt path(s)')
|
|
||||||
parser.add_argument('--batch-size', type=int, default=32, help='batch size')
|
|
||||||
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
|
|
||||||
parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
|
|
||||||
parser.add_argument('--iou-thres', type=float, default=0.65, help='NMS IoU threshold')
|
|
||||||
parser.add_argument('--task', default='val', help='val, or speed')
|
|
||||||
parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
|
||||||
parser.add_argument('--half', default=False, action='store_true', help='whether to use fp16 infer')
|
|
||||||
parser.add_argument('--save_dir', type=str, default='runs/val/', help='evaluation save dir')
|
|
||||||
parser.add_argument('--name', type=str, default='exp', help='save evaluation results to save_dir/name')
|
|
||||||
args = parser.parse_args()
|
|
||||||
LOGGER.info(args)
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def run(data,
|
|
||||||
weights=None,
|
|
||||||
batch_size=32,
|
|
||||||
img_size=640,
|
|
||||||
conf_thres=0.001,
|
|
||||||
iou_thres=0.65,
|
|
||||||
task='val',
|
|
||||||
device='',
|
|
||||||
half=False,
|
|
||||||
model=None,
|
|
||||||
dataloader=None,
|
|
||||||
save_dir='',
|
|
||||||
name = ''
|
|
||||||
):
|
|
||||||
""" Run the evaluation process
|
|
||||||
|
|
||||||
This function is the main process of evaluataion, supporting image file and dir containing images.
|
|
||||||
It has tasks of 'val', 'train' and 'speed'. Task 'train' processes the evaluation during training phase.
|
|
||||||
Task 'val' processes the evaluation purely and return the mAP of model.pt. Task 'speed' precesses the
|
|
||||||
evaluation of inference speed of model.pt.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
# task
|
|
||||||
Evaler.check_task(task)
|
|
||||||
if task == 'train':
|
|
||||||
save_dir = save_dir
|
|
||||||
else:
|
|
||||||
save_dir = str(increment_name(osp.join(save_dir, name)))
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# reload thres/device/half/data according task
|
|
||||||
conf_thres, iou_thres = Evaler.reload_thres(conf_thres, iou_thres, task)
|
|
||||||
device = Evaler.reload_device(device, model, task)
|
|
||||||
half = device.type != 'cpu' and half
|
|
||||||
data = Evaler.reload_dataset(data) if isinstance(data, str) else data
|
|
||||||
|
|
||||||
# init
|
|
||||||
val = Evaler(data, batch_size, img_size, conf_thres, \
|
|
||||||
iou_thres, device, half, save_dir)
|
|
||||||
model = val.init_model(model, weights, task)
|
|
||||||
dataloader = val.init_data(dataloader, task)
|
|
||||||
|
|
||||||
# eval
|
|
||||||
model.eval()
|
|
||||||
pred_result = val.predict_model(model, dataloader, task)
|
|
||||||
eval_result = val.eval_model(pred_result, model, dataloader, task)
|
|
||||||
return eval_result
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
run(**vars(args))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
args = get_args_parser()
|
|
||||||
main(args)
|
|
@ -1,108 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
ROOT = os.getcwd()
|
|
||||||
if str(ROOT) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT))
|
|
||||||
|
|
||||||
from yolov6.utils.events import LOGGER
|
|
||||||
from yolov6.core.inferer import Inferer
|
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser(add_help=True):
|
|
||||||
parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Inference.', add_help=add_help)
|
|
||||||
parser.add_argument('--weights', type=str, default='weights/yolov6s.pt', help='model path(s) for inference.')
|
|
||||||
parser.add_argument('--source', type=str, default='data/images', help='the source path, e.g. image-file/dir.')
|
|
||||||
parser.add_argument('--yaml', type=str, default='data/coco.yaml', help='data yaml file.')
|
|
||||||
parser.add_argument('--img-size', type=int, default=640, help='the image-size(h,w) in inference size.')
|
|
||||||
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold for inference.')
|
|
||||||
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold for inference.')
|
|
||||||
parser.add_argument('--max-det', type=int, default=1000, help='maximal inferences per image.')
|
|
||||||
parser.add_argument('--device', default='0', help='device to run our model i.e. 0 or 0,1,2,3 or cpu.')
|
|
||||||
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt.')
|
|
||||||
parser.add_argument('--save-img', action='store_false', help='save visuallized inference results.')
|
|
||||||
parser.add_argument('--classes', nargs='+', type=int, help='filter by classes, e.g. --classes 0, or --classes 0 2 3.')
|
|
||||||
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS.')
|
|
||||||
parser.add_argument('--project', default='runs/inference', help='save inference results to project/name.')
|
|
||||||
parser.add_argument('--name', default='exp', help='save inference results to project/name.')
|
|
||||||
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels.')
|
|
||||||
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences.')
|
|
||||||
parser.add_argument('--half', action='store_true', help='whether to use FP16 half-precision inference.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
LOGGER.info(args)
|
|
||||||
return args
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def run(weights=osp.join(ROOT, 'yolov6s.pt'),
|
|
||||||
source=osp.join(ROOT, 'data/images'),
|
|
||||||
yaml=None,
|
|
||||||
img_size=640,
|
|
||||||
conf_thres=0.25,
|
|
||||||
iou_thres=0.45,
|
|
||||||
max_det=1000,
|
|
||||||
device='',
|
|
||||||
save_txt=False,
|
|
||||||
save_img=True,
|
|
||||||
classes=None,
|
|
||||||
agnostic_nms=False,
|
|
||||||
project=osp.join(ROOT, 'runs/inference'),
|
|
||||||
name='exp',
|
|
||||||
hide_labels=False,
|
|
||||||
hide_conf=False,
|
|
||||||
half=False,
|
|
||||||
):
|
|
||||||
""" Inference process
|
|
||||||
|
|
||||||
This function is the main process of inference, supporting image files or dirs containing images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
weights: The path of model.pt, e.g. yolov6s.pt
|
|
||||||
source: Source path, supporting image files or dirs containing images.
|
|
||||||
yaml: Data yaml file, .
|
|
||||||
img_size: Inference image-size, e.g. 640
|
|
||||||
conf_thres: Confidence threshold in inference, e.g. 0.25
|
|
||||||
iou_thres: NMS IOU threshold in inference, e.g. 0.45
|
|
||||||
max_det: Maximal detections per image, e.g. 1000
|
|
||||||
device: Cuda device, e.e. 0, or 0,1,2,3 or cpu
|
|
||||||
save_txt: Save results to *.txt
|
|
||||||
save_img: Save visualized inference results
|
|
||||||
classes: Filter by class: --class 0, or --class 0 2 3
|
|
||||||
agnostic_nms: Class-agnostic NMS
|
|
||||||
project: Save results to project/name
|
|
||||||
name: Save results to project/name, e.g. 'exp'
|
|
||||||
line_thickness: Bounding box thickness (pixels), e.g. 3
|
|
||||||
hide_labels: Hide labels, e.g. False
|
|
||||||
hide_conf: Hide confidences
|
|
||||||
half: Use FP16 half-precision inference, e.g. False
|
|
||||||
"""
|
|
||||||
# create save dir
|
|
||||||
save_dir = osp.join(project, name)
|
|
||||||
if (save_img or save_txt) and not osp.exists(save_dir):
|
|
||||||
os.makedirs(save_dir)
|
|
||||||
else:
|
|
||||||
LOGGER.warning('Save directory already existed')
|
|
||||||
if save_txt:
|
|
||||||
os.mkdir(osp.join(save_dir, 'labels'))
|
|
||||||
|
|
||||||
# Inference
|
|
||||||
inferer = Inferer(source, weights, device, yaml, img_size, half)
|
|
||||||
inferer.infer(conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf)
|
|
||||||
|
|
||||||
if save_txt or save_img:
|
|
||||||
LOGGER.info(f"Results saved to {save_dir}")
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
run(**vars(args))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
args = get_args_parser()
|
|
||||||
main(args)
|
|
@ -1 +0,0 @@
|
|||||||
# Coming soon
|
|
@ -1,210 +0,0 @@
|
|||||||
#
|
|
||||||
# Modified by Meituan
|
|
||||||
# 2022.6.24
|
|
||||||
#
|
|
||||||
|
|
||||||
# Copyright 2019 NVIDIA Corporation
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import glob
|
|
||||||
import random
|
|
||||||
import logging
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import tensorrt as trt
|
|
||||||
import pycuda.driver as cuda
|
|
||||||
import pycuda.autoinit
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG,
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S")
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def preprocess_yolov6(image, channels=3, height=224, width=224):
|
|
||||||
"""Pre-processing for YOLOv6-based Object Detection Models
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
image: PIL.Image
|
|
||||||
The image resulting from PIL.Image.open(filename) to preprocess
|
|
||||||
channels: int
|
|
||||||
The number of channels the image has (Usually 1 or 3)
|
|
||||||
height: int
|
|
||||||
The desired height of the image (usually 640)
|
|
||||||
width: int
|
|
||||||
The desired width of the image (usually 640)
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img_data: numpy array
|
|
||||||
The preprocessed image data in the form of a numpy array
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Get the image in CHW format
|
|
||||||
resized_image = image.resize((width, height), Image.BILINEAR)
|
|
||||||
img_data = np.asarray(resized_image).astype(np.float32)
|
|
||||||
|
|
||||||
if len(img_data.shape) == 2:
|
|
||||||
# For images without a channel dimension, we stack
|
|
||||||
img_data = np.stack([img_data] * 3)
|
|
||||||
logger.debug("Received grayscale image. Reshaped to {:}".format(img_data.shape))
|
|
||||||
else:
|
|
||||||
img_data = img_data.transpose([2, 0, 1])
|
|
||||||
|
|
||||||
mean_vec = np.array([0.0, 0.0, 0.0])
|
|
||||||
stddev_vec = np.array([1.0, 1.0, 1.0])
|
|
||||||
assert img_data.shape[0] == channels
|
|
||||||
|
|
||||||
for i in range(img_data.shape[0]):
|
|
||||||
# Scale each pixel to [0, 1] and normalize per channel.
|
|
||||||
img_data[i, :, :] = (img_data[i, :, :] / 255.0 - mean_vec[i]) / stddev_vec[i]
|
|
||||||
|
|
||||||
return img_data
|
|
||||||
|
|
||||||
def get_int8_calibrator(calib_cache, calib_data, max_calib_size, calib_batch_size):
|
|
||||||
# Use calibration cache if it exists
|
|
||||||
if os.path.exists(calib_cache):
|
|
||||||
logger.info("Skipping calibration files, using calibration cache: {:}".format(calib_cache))
|
|
||||||
calib_files = []
|
|
||||||
# Use calibration files from validation dataset if no cache exists
|
|
||||||
else:
|
|
||||||
if not calib_data:
|
|
||||||
raise ValueError("ERROR: Int8 mode requested, but no calibration data provided. Please provide --calibration-data /path/to/calibration/files")
|
|
||||||
|
|
||||||
calib_files = get_calibration_files(calib_data, max_calib_size)
|
|
||||||
|
|
||||||
# Choose pre-processing function for INT8 calibration
|
|
||||||
preprocess_func = preprocess_yolov6
|
|
||||||
|
|
||||||
int8_calibrator = ImageCalibrator(calibration_files=calib_files,
|
|
||||||
batch_size=calib_batch_size,
|
|
||||||
cache_file=calib_cache)
|
|
||||||
return int8_calibrator
|
|
||||||
|
|
||||||
|
|
||||||
def get_calibration_files(calibration_data, max_calibration_size=None, allowed_extensions=(".jpeg", ".jpg", ".png")):
|
|
||||||
"""Returns a list of all filenames ending with `allowed_extensions` found in the `calibration_data` directory.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
calibration_data: str
|
|
||||||
Path to directory containing desired files.
|
|
||||||
max_calibration_size: int
|
|
||||||
Max number of files to use for calibration. If calibration_data contains more than this number,
|
|
||||||
a random sample of size max_calibration_size will be returned instead. If None, all samples will be used.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
calibration_files: List[str]
|
|
||||||
List of filenames contained in the `calibration_data` directory ending with `allowed_extensions`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
logger.info("Collecting calibration files from: {:}".format(calibration_data))
|
|
||||||
calibration_files = [path for path in glob.iglob(os.path.join(calibration_data, "**"), recursive=True)
|
|
||||||
if os.path.isfile(path) and path.lower().endswith(allowed_extensions)]
|
|
||||||
logger.info("Number of Calibration Files found: {:}".format(len(calibration_files)))
|
|
||||||
|
|
||||||
if len(calibration_files) == 0:
|
|
||||||
raise Exception("ERROR: Calibration data path [{:}] contains no files!".format(calibration_data))
|
|
||||||
|
|
||||||
if max_calibration_size:
|
|
||||||
if len(calibration_files) > max_calibration_size:
|
|
||||||
logger.warning("Capping number of calibration images to max_calibration_size: {:}".format(max_calibration_size))
|
|
||||||
random.seed(42) # Set seed for reproducibility
|
|
||||||
calibration_files = random.sample(calibration_files, max_calibration_size)
|
|
||||||
|
|
||||||
return calibration_files
|
|
||||||
|
|
||||||
|
|
||||||
# https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/infer/Int8/EntropyCalibrator2.html
|
|
||||||
class ImageCalibrator(trt.IInt8EntropyCalibrator2):
|
|
||||||
"""INT8 Calibrator Class for Imagenet-based Image Classification Models.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
calibration_files: List[str]
|
|
||||||
List of image filenames to use for INT8 Calibration
|
|
||||||
batch_size: int
|
|
||||||
Number of images to pass through in one batch during calibration
|
|
||||||
input_shape: Tuple[int]
|
|
||||||
Tuple of integers defining the shape of input to the model (Default: (3, 224, 224))
|
|
||||||
cache_file: str
|
|
||||||
Name of file to read/write calibration cache from/to.
|
|
||||||
preprocess_func: function -> numpy.ndarray
|
|
||||||
Pre-processing function to run on calibration data. This should match the pre-processing
|
|
||||||
done at inference time. In general, this function should return a numpy array of
|
|
||||||
shape `input_shape`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, calibration_files=[], batch_size=32, input_shape=(3, 224, 224),
|
|
||||||
cache_file="calibration.cache", use_cv2=False):
|
|
||||||
super().__init__()
|
|
||||||
self.input_shape = input_shape
|
|
||||||
self.cache_file = cache_file
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.batch = np.zeros((self.batch_size, *self.input_shape), dtype=np.float32)
|
|
||||||
self.device_input = cuda.mem_alloc(self.batch.nbytes)
|
|
||||||
|
|
||||||
self.files = calibration_files
|
|
||||||
self.use_cv2 = use_cv2
|
|
||||||
# Pad the list so it is a multiple of batch_size
|
|
||||||
if len(self.files) % self.batch_size != 0:
|
|
||||||
logger.info("Padding # calibration files to be a multiple of batch_size {:}".format(self.batch_size))
|
|
||||||
self.files += calibration_files[(len(calibration_files) % self.batch_size):self.batch_size]
|
|
||||||
|
|
||||||
self.batches = self.load_batches()
|
|
||||||
self.preprocess_func = preprocess_yolov6
|
|
||||||
|
|
||||||
def load_batches(self):
|
|
||||||
# Populates a persistent self.batch buffer with images.
|
|
||||||
for index in range(0, len(self.files), self.batch_size):
|
|
||||||
for offset in range(self.batch_size):
|
|
||||||
if self.use_cv2:
|
|
||||||
image = cv2.imread(self.files[index + offset])
|
|
||||||
else:
|
|
||||||
image = Image.open(self.files[index + offset])
|
|
||||||
self.batch[offset] = self.preprocess_func(image, *self.input_shape)
|
|
||||||
logger.info("Calibration images pre-processed: {:}/{:}".format(index+self.batch_size, len(self.files)))
|
|
||||||
yield self.batch
|
|
||||||
|
|
||||||
def get_batch_size(self):
|
|
||||||
return self.batch_size
|
|
||||||
|
|
||||||
def get_batch(self, names):
|
|
||||||
try:
|
|
||||||
# Assume self.batches is a generator that provides batch data.
|
|
||||||
batch = next(self.batches)
|
|
||||||
# Assume that self.device_input is a device buffer allocated by the constructor.
|
|
||||||
cuda.memcpy_htod(self.device_input, batch)
|
|
||||||
return [int(self.device_input)]
|
|
||||||
except StopIteration:
|
|
||||||
# When we're out of batches, we return either [] or None.
|
|
||||||
# This signals to TensorRT that there is no calibration data remaining.
|
|
||||||
return None
|
|
||||||
|
|
||||||
def read_calibration_cache(self):
|
|
||||||
# If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
|
|
||||||
if os.path.exists(self.cache_file):
|
|
||||||
with open(self.cache_file, "rb") as f:
|
|
||||||
logger.info("Using calibration cache to save time: {:}".format(self.cache_file))
|
|
||||||
return f.read()
|
|
||||||
|
|
||||||
def write_calibration_cache(self, cache):
|
|
||||||
with open(self.cache_file, "wb") as f:
|
|
||||||
logger.info("Caching calibration data for future use: {:}".format(self.cache_file))
|
|
||||||
f.write(cache)
|
|
@ -1,191 +0,0 @@
|
|||||||
|
|
||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
Copyright 2020 NVIDIA Corporation
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
@ -1,83 +0,0 @@
|
|||||||
# ONNX -> TensorRT INT8
|
|
||||||
These scripts were last tested using the
|
|
||||||
[NGC TensorRT Container Version 20.06-py3](https://ngc.nvidia.com/catalog/containers/nvidia:tensorrt).
|
|
||||||
You can see the corresponding framework versions for this container [here](https://docs.nvidia.com/deeplearning/sdk/tensorrt-container-release-notes/rel_20.06.html#rel_20.06).
|
|
||||||
|
|
||||||
## Quickstart
|
|
||||||
|
|
||||||
> **NOTE**: This INT8 example is only valid for **fixed-shape** ONNX models at the moment.
|
|
||||||
>
|
|
||||||
INT8 Calibration on **dynamic-shape** models is now supported, however this example has not been updated
|
|
||||||
to reflect that yet. For more details on INT8 Calibration for **dynamic-shape** models, please
|
|
||||||
see the [documentation](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#int8-calib-dynamic-shapes).
|
|
||||||
|
|
||||||
### 1. Convert ONNX model to TensorRT INT8
|
|
||||||
|
|
||||||
See `./onnx_to_tensorrt.py -h` for full list of command line arguments.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./onnx_to_tensorrt.py --explicit-batch \
|
|
||||||
--onnx resnet50/model.onnx \
|
|
||||||
--fp16 \
|
|
||||||
--int8 \
|
|
||||||
--calibration-cache="caches/yolov6.cache" \
|
|
||||||
-o resnet50.int8.engine
|
|
||||||
```
|
|
||||||
|
|
||||||
See the [INT8 Calibration](#int8-calibration) section below for details on calibration
|
|
||||||
using your own model or different data, where you don't have an existing calibration cache
|
|
||||||
or want to create a new one.
|
|
||||||
|
|
||||||
## INT8 Calibration
|
|
||||||
|
|
||||||
See [ImagenetCalibrator.py](ImagenetCalibrator.py) for a reference implementation
|
|
||||||
of TensorRT's [IInt8EntropyCalibrator2](https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/infer/Int8/EntropyCalibrator2.html).
|
|
||||||
|
|
||||||
This class can be tweaked to work for other kinds of models, inputs, etc.
|
|
||||||
|
|
||||||
In the [Quickstart](#quickstart) section above, we made use of a pre-existing cache,
|
|
||||||
[caches/yolov6.cache](caches/yolov6.cache), to save time for the sake of an example.
|
|
||||||
|
|
||||||
However, to calibrate using different data or a different model, you can do so with the `--calibration-data` argument.
|
|
||||||
|
|
||||||
* This requires that you've mounted a dataset, such as Imagenet, to use for calibration.
|
|
||||||
* Add something like `-v /imagenet:/imagenet` to your Docker command in Step (1)
|
|
||||||
to mount a dataset found locally at `/imagenet`.
|
|
||||||
* You can specify your own `preprocess_func` by defining it inside of `ImageCalibrator.py`
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Path to dataset to use for calibration.
|
|
||||||
# **Not necessary if you already have a calibration cache from a previous run.
|
|
||||||
CALIBRATION_DATA="/imagenet"
|
|
||||||
|
|
||||||
# Truncate calibration images to a random sample of this amount if more are found.
|
|
||||||
# **Not necessary if you already have a calibration cache from a previous run.
|
|
||||||
MAX_CALIBRATION_SIZE=512
|
|
||||||
|
|
||||||
# Calibration cache to be used instead of calibration data if it already exists,
|
|
||||||
# or the cache will be created from the calibration data if it doesn't exist.
|
|
||||||
CACHE_FILENAME="caches/yolov6.cache"
|
|
||||||
|
|
||||||
# Path to ONNX model
|
|
||||||
ONNX_MODEL="model/yolov6.onnx"
|
|
||||||
|
|
||||||
# Path to write TensorRT engine to
|
|
||||||
OUTPUT="yolov6.int8.engine"
|
|
||||||
|
|
||||||
# Creates an int8 engine from your ONNX model, creating ${CACHE_FILENAME} based
|
|
||||||
# on your ${CALIBRATION_DATA}, unless ${CACHE_FILENAME} already exists, then
|
|
||||||
# it will use simply use that instead.
|
|
||||||
python3 onnx_to_tensorrt.py --fp16 --int8 -v \
|
|
||||||
--max_calibration_size=${MAX_CALIBRATION_SIZE} \
|
|
||||||
--calibration-data=${CALIBRATION_DATA} \
|
|
||||||
--calibration-cache=${CACHE_FILENAME} \
|
|
||||||
--preprocess_func=${PREPROCESS_FUNC} \
|
|
||||||
--explicit-batch \
|
|
||||||
--onnx ${ONNX_MODEL} -o ${OUTPUT}
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
### Pre-processing
|
|
||||||
|
|
||||||
In order to calibrate your model correctly, you should `pre-process` your data the same way
|
|
||||||
that you would during inference.
|
|
@ -1,220 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
#
|
|
||||||
# Modified by Meituan
|
|
||||||
# 2022.6.24
|
|
||||||
#
|
|
||||||
|
|
||||||
# Copyright 2019 NVIDIA Corporation
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import glob
|
|
||||||
import math
|
|
||||||
import logging
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import tensorrt as trt
|
|
||||||
#sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
|
|
||||||
|
|
||||||
TRT_LOGGER = trt.Logger()
|
|
||||||
logging.basicConfig(level=logging.DEBUG,
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S")
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def add_profiles(config, inputs, opt_profiles):
|
|
||||||
logger.debug("=== Optimization Profiles ===")
|
|
||||||
for i, profile in enumerate(opt_profiles):
|
|
||||||
for inp in inputs:
|
|
||||||
_min, _opt, _max = profile.get_shape(inp.name)
|
|
||||||
logger.debug("{} - OptProfile {} - Min {} Opt {} Max {}".format(inp.name, i, _min, _opt, _max))
|
|
||||||
config.add_optimization_profile(profile)
|
|
||||||
|
|
||||||
|
|
||||||
def mark_outputs(network):
|
|
||||||
# Mark last layer's outputs if not already marked
|
|
||||||
# NOTE: This may not be correct in all cases
|
|
||||||
last_layer = network.get_layer(network.num_layers-1)
|
|
||||||
if not last_layer.num_outputs:
|
|
||||||
logger.error("Last layer contains no outputs.")
|
|
||||||
return
|
|
||||||
|
|
||||||
for i in range(last_layer.num_outputs):
|
|
||||||
network.mark_output(last_layer.get_output(i))
|
|
||||||
|
|
||||||
|
|
||||||
def check_network(network):
|
|
||||||
if not network.num_outputs:
|
|
||||||
logger.warning("No output nodes found, marking last layer's outputs as network outputs. Correct this if wrong.")
|
|
||||||
mark_outputs(network)
|
|
||||||
|
|
||||||
inputs = [network.get_input(i) for i in range(network.num_inputs)]
|
|
||||||
outputs = [network.get_output(i) for i in range(network.num_outputs)]
|
|
||||||
max_len = max([len(inp.name) for inp in inputs] + [len(out.name) for out in outputs])
|
|
||||||
|
|
||||||
logger.debug("=== Network Description ===")
|
|
||||||
for i, inp in enumerate(inputs):
|
|
||||||
logger.debug("Input {0} | Name: {1:{2}} | Shape: {3}".format(i, inp.name, max_len, inp.shape))
|
|
||||||
for i, out in enumerate(outputs):
|
|
||||||
logger.debug("Output {0} | Name: {1:{2}} | Shape: {3}".format(i, out.name, max_len, out.shape))
|
|
||||||
|
|
||||||
|
|
||||||
def get_batch_sizes(max_batch_size):
|
|
||||||
# Returns powers of 2, up to and including max_batch_size
|
|
||||||
max_exponent = math.log2(max_batch_size)
|
|
||||||
for i in range(int(max_exponent)+1):
|
|
||||||
batch_size = 2**i
|
|
||||||
yield batch_size
|
|
||||||
|
|
||||||
if max_batch_size != batch_size:
|
|
||||||
yield max_batch_size
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: This only covers dynamic shape for batch size, not dynamic shape for other dimensions
|
|
||||||
def create_optimization_profiles(builder, inputs, batch_sizes=[1,8,16,32,64]):
|
|
||||||
# Check if all inputs are fixed explicit batch to create a single profile and avoid duplicates
|
|
||||||
if all([inp.shape[0] > -1 for inp in inputs]):
|
|
||||||
profile = builder.create_optimization_profile()
|
|
||||||
for inp in inputs:
|
|
||||||
fbs, shape = inp.shape[0], inp.shape[1:]
|
|
||||||
profile.set_shape(inp.name, min=(fbs, *shape), opt=(fbs, *shape), max=(fbs, *shape))
|
|
||||||
return [profile]
|
|
||||||
|
|
||||||
# Otherwise for mixed fixed+dynamic explicit batch inputs, create several profiles
|
|
||||||
profiles = {}
|
|
||||||
for bs in batch_sizes:
|
|
||||||
if not profiles.get(bs):
|
|
||||||
profiles[bs] = builder.create_optimization_profile()
|
|
||||||
|
|
||||||
for inp in inputs:
|
|
||||||
shape = inp.shape[1:]
|
|
||||||
# Check if fixed explicit batch
|
|
||||||
if inp.shape[0] > -1:
|
|
||||||
bs = inp.shape[0]
|
|
||||||
|
|
||||||
profiles[bs].set_shape(inp.name, min=(bs, *shape), opt=(bs, *shape), max=(bs, *shape))
|
|
||||||
|
|
||||||
return list(profiles.values())
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Creates a TensorRT engine from the provided ONNX file.\n")
|
|
||||||
parser.add_argument("--onnx", required=True, help="The ONNX model file to convert to TensorRT")
|
|
||||||
parser.add_argument("-o", "--output", type=str, default="model.engine", help="The path at which to write the engine")
|
|
||||||
parser.add_argument("-b", "--max-batch-size", type=int, help="The max batch size for the TensorRT engine input")
|
|
||||||
parser.add_argument("-v", "--verbosity", action="count", help="Verbosity for logging. (None) for ERROR, (-v) for INFO/WARNING/ERROR, (-vv) for VERBOSE.")
|
|
||||||
parser.add_argument("--explicit-batch", action='store_true', help="Set trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH.")
|
|
||||||
parser.add_argument("--explicit-precision", action='store_true', help="Set trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION.")
|
|
||||||
parser.add_argument("--gpu-fallback", action='store_true', help="Set trt.BuilderFlag.GPU_FALLBACK.")
|
|
||||||
parser.add_argument("--refittable", action='store_true', help="Set trt.BuilderFlag.REFIT.")
|
|
||||||
parser.add_argument("--debug", action='store_true', help="Set trt.BuilderFlag.DEBUG.")
|
|
||||||
parser.add_argument("--strict-types", action='store_true', help="Set trt.BuilderFlag.STRICT_TYPES.")
|
|
||||||
parser.add_argument("--fp16", action="store_true", help="Attempt to use FP16 kernels when possible.")
|
|
||||||
parser.add_argument("--int8", action="store_true", help="Attempt to use INT8 kernels when possible. This should generally be used in addition to the --fp16 flag. \
|
|
||||||
ONLY SUPPORTS RESNET-LIKE MODELS SUCH AS RESNET50/VGG16/INCEPTION/etc.")
|
|
||||||
parser.add_argument("--calibration-cache", help="(INT8 ONLY) The path to read/write from calibration cache.", default="calibration.cache")
|
|
||||||
parser.add_argument("--calibration-data", help="(INT8 ONLY) The directory containing {*.jpg, *.jpeg, *.png} files to use for calibration. (ex: Imagenet Validation Set)", default=None)
|
|
||||||
parser.add_argument("--calibration-batch-size", help="(INT8 ONLY) The batch size to use during calibration.", type=int, default=128)
|
|
||||||
parser.add_argument("--max-calibration-size", help="(INT8 ONLY) The max number of data to calibrate on from --calibration-data.", type=int, default=2048)
|
|
||||||
parser.add_argument("-s", "--simple", action="store_true", help="Use SimpleCalibrator with random data instead of ImagenetCalibrator for INT8 calibration.")
|
|
||||||
args, _ = parser.parse_known_args()
|
|
||||||
|
|
||||||
print(args)
|
|
||||||
|
|
||||||
# Adjust logging verbosity
|
|
||||||
if args.verbosity is None:
|
|
||||||
TRT_LOGGER.min_severity = trt.Logger.Severity.ERROR
|
|
||||||
# -v
|
|
||||||
elif args.verbosity == 1:
|
|
||||||
TRT_LOGGER.min_severity = trt.Logger.Severity.INFO
|
|
||||||
# -vv
|
|
||||||
else:
|
|
||||||
TRT_LOGGER.min_severity = trt.Logger.Severity.VERBOSE
|
|
||||||
logger.info("TRT_LOGGER Verbosity: {:}".format(TRT_LOGGER.min_severity))
|
|
||||||
|
|
||||||
# Network flags
|
|
||||||
network_flags = 0
|
|
||||||
if args.explicit_batch:
|
|
||||||
network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
|
||||||
if args.explicit_precision:
|
|
||||||
network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)
|
|
||||||
|
|
||||||
builder_flag_map = {
|
|
||||||
'gpu_fallback': trt.BuilderFlag.GPU_FALLBACK,
|
|
||||||
'refittable': trt.BuilderFlag.REFIT,
|
|
||||||
'debug': trt.BuilderFlag.DEBUG,
|
|
||||||
'strict_types': trt.BuilderFlag.STRICT_TYPES,
|
|
||||||
'fp16': trt.BuilderFlag.FP16,
|
|
||||||
'int8': trt.BuilderFlag.INT8,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Building engine
|
|
||||||
with trt.Builder(TRT_LOGGER) as builder, \
|
|
||||||
builder.create_network(network_flags) as network, \
|
|
||||||
builder.create_builder_config() as config, \
|
|
||||||
trt.OnnxParser(network, TRT_LOGGER) as parser:
|
|
||||||
|
|
||||||
config.max_workspace_size = 2**30 # 1GiB
|
|
||||||
|
|
||||||
# Set Builder Config Flags
|
|
||||||
for flag in builder_flag_map:
|
|
||||||
if getattr(args, flag):
|
|
||||||
logger.info("Setting {}".format(builder_flag_map[flag]))
|
|
||||||
config.set_flag(builder_flag_map[flag])
|
|
||||||
|
|
||||||
# Fill network atrributes with information by parsing model
|
|
||||||
with open(args.onnx, "rb") as f:
|
|
||||||
if not parser.parse(f.read()):
|
|
||||||
print('ERROR: Failed to parse the ONNX file: {}'.format(args.onnx))
|
|
||||||
for error in range(parser.num_errors):
|
|
||||||
print(parser.get_error(error))
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Display network info and check certain properties
|
|
||||||
check_network(network)
|
|
||||||
|
|
||||||
if args.explicit_batch:
|
|
||||||
# Add optimization profiles
|
|
||||||
batch_sizes = [1, 8, 16, 32, 64]
|
|
||||||
inputs = [network.get_input(i) for i in range(network.num_inputs)]
|
|
||||||
opt_profiles = create_optimization_profiles(builder, inputs, batch_sizes)
|
|
||||||
add_profiles(config, inputs, opt_profiles)
|
|
||||||
# Implicit Batch Network
|
|
||||||
else:
|
|
||||||
builder.max_batch_size = args.max_batch_size
|
|
||||||
opt_profiles = []
|
|
||||||
|
|
||||||
# Precision flags
|
|
||||||
if args.fp16 and not builder.platform_has_fast_fp16:
|
|
||||||
logger.warning("FP16 not supported on this platform.")
|
|
||||||
|
|
||||||
if args.int8 and not builder.platform_has_fast_int8:
|
|
||||||
logger.warning("INT8 not supported on this platform.")
|
|
||||||
|
|
||||||
if args.int8:
|
|
||||||
from Calibrator import ImageCalibrator, get_int8_calibrator # local module
|
|
||||||
config.int8_calibrator = get_int8_calibrator(args.calibration_cache,
|
|
||||||
args.calibration_data,
|
|
||||||
args.max_calibration_size,
|
|
||||||
args.calibration_batch_size)
|
|
||||||
|
|
||||||
logger.info("Building Engine...")
|
|
||||||
with builder.build_engine(network, config) as engine, open(args.output, "wb") as f:
|
|
||||||
logger.info("Serializing engine to file: {:}".format(args.output))
|
|
||||||
f.write(engine.serialize())
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,23 +0,0 @@
|
|||||||
# Path to ONNX model
|
|
||||||
# ex: ../yolov6.onnx
|
|
||||||
ONNX_MODEL=$1
|
|
||||||
|
|
||||||
# Path to dataset to use for calibration.
|
|
||||||
# **Not necessary if you already have a calibration cache from a previous run.
|
|
||||||
CALIBRATION_DATA=$2
|
|
||||||
|
|
||||||
# Path to Cache file to Serving
|
|
||||||
# ex: ./caches/demo.cache
|
|
||||||
CACHE_FILENAME=$3
|
|
||||||
|
|
||||||
# Path to write TensorRT engine to
|
|
||||||
OUTPUT=$4
|
|
||||||
|
|
||||||
# Creates an int8 engine from your ONNX model, creating ${CACHE_FILENAME} based
|
|
||||||
# on your ${CALIBRATION_DATA}, unless ${CACHE_FILENAME} already exists, then
|
|
||||||
# it will use simply use that instead.
|
|
||||||
python3 onnx_to_tensorrt.py --fp16 --int8 -v \
|
|
||||||
--calibration-data=${CALIBRATION_DATA} \
|
|
||||||
--calibration-cache=${CACHE_FILENAME} \
|
|
||||||
--explicit-batch \
|
|
||||||
--onnx ${ONNX_MODEL} -o ${OUTPUT}
|
|
@ -1,7 +0,0 @@
|
|||||||
# pip install -r requirements.txt
|
|
||||||
# python3.8 environment
|
|
||||||
|
|
||||||
tensorrt # TensorRT 8.0+
|
|
||||||
pycuda==2020.1 # CUDA 11.0
|
|
||||||
nvidia-pyindex
|
|
||||||
pytorch-quantization
|
|
@ -1,39 +0,0 @@
|
|||||||
#
|
|
||||||
# QAT_quantizer.py
|
|
||||||
# YOLOv6
|
|
||||||
#
|
|
||||||
# Created by Meituan on 2022/06/24.
|
|
||||||
# Copyright © 2022
|
|
||||||
#
|
|
||||||
|
|
||||||
from absl import logging
|
|
||||||
from pytorch_quantization import nn as quant_nn
|
|
||||||
from pytorch_quantization import quant_modules
|
|
||||||
|
|
||||||
# Call this function before defining the model
|
|
||||||
def tensorrt_official_qat():
|
|
||||||
# Quantization Aware Training is based on Straight Through Estimator (STE) derivative approximation.
|
|
||||||
# It is some time known as “quantization aware training”.
|
|
||||||
|
|
||||||
# PyTorch-Quantization is a toolkit for training and evaluating PyTorch models with simulated quantization.
|
|
||||||
# Quantization can be added to the model automatically, or manually, allowing the model to be tuned for accuracy and performance.
|
|
||||||
# Quantization is compatible with NVIDIAs high performance integer kernels which leverage integer Tensor Cores.
|
|
||||||
# The quantized model can be exported to ONNX and imported by TensorRT 8.0 and later.
|
|
||||||
# https://github.com/NVIDIA/TensorRT/blob/main/tools/pytorch-quantization/examples/finetune_quant_resnet50.ipynb
|
|
||||||
|
|
||||||
# The example to export the
|
|
||||||
# model.eval()
|
|
||||||
# quant_nn.TensorQuantizer.use_fb_fake_quant = True # We have to shift to pytorch's fake quant ops before exporting the model to ONNX
|
|
||||||
# opset_version = 13
|
|
||||||
|
|
||||||
# Export ONNX for multiple batch sizes
|
|
||||||
# print("Creating ONNX file: " + onnx_filename)
|
|
||||||
# dummy_input = torch.randn(batch_onnx, 3, 224, 224, device='cuda') #TODO: switch input dims by model
|
|
||||||
# torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=opset_version, enable_onnx_checker=False, do_constant_folding=True)
|
|
||||||
try:
|
|
||||||
quant_modules.initialize()
|
|
||||||
except NameError:
|
|
||||||
logging.info("initialzation error for quant_modules")
|
|
||||||
|
|
||||||
# def QAT_quantizer():
|
|
||||||
# coming soon
|
|
@ -1,94 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import sys
|
|
||||||
|
|
||||||
ROOT = os.getcwd()
|
|
||||||
if str(ROOT) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT))
|
|
||||||
|
|
||||||
from yolov6.core.engine import Trainer
|
|
||||||
from yolov6.utils.config import Config
|
|
||||||
from yolov6.utils.events import LOGGER, save_yaml
|
|
||||||
from yolov6.utils.envs import get_envs, select_device, set_random_seed
|
|
||||||
from yolov6.utils.general import increment_name
|
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser(add_help=True):
|
|
||||||
parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Training', add_help=add_help)
|
|
||||||
parser.add_argument('--data-path', default='./data/coco.yaml', type=str, help='path of dataset')
|
|
||||||
parser.add_argument('--conf-file', default='./configs/yolov6s.py', type=str, help='experiments description file')
|
|
||||||
parser.add_argument('--img-size', default=640, type=int, help='train, val image size (pixels)')
|
|
||||||
parser.add_argument('--batch-size', default=32, type=int, help='total batch size for all GPUs')
|
|
||||||
parser.add_argument('--epochs', default=400, type=int, help='number of total epochs to run')
|
|
||||||
parser.add_argument('--workers', default=8, type=int, help='number of data loading workers (default: 8)')
|
|
||||||
parser.add_argument('--device', default='0', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
|
||||||
parser.add_argument('--eval-interval', default=20, type=int, help='evaluate at every interval epochs')
|
|
||||||
parser.add_argument('--eval-final-only', action='store_true', help='only evaluate at the final epoch')
|
|
||||||
parser.add_argument('--heavy-eval-range', default=50, type=int,
|
|
||||||
help='evaluating every epoch for last such epochs (can be jointly used with --eval-interval)')
|
|
||||||
parser.add_argument('--check-images', action='store_true', help='check images when initializing datasets')
|
|
||||||
parser.add_argument('--check-labels', action='store_true', help='check label files when initializing datasets')
|
|
||||||
parser.add_argument('--output-dir', default='./runs/train', type=str, help='path to save outputs')
|
|
||||||
parser.add_argument('--name', default='exp', type=str, help='experiment name, saved to output_dir/name')
|
|
||||||
parser.add_argument('--dist_url', default='env://', type=str, help='url used to set up distributed training')
|
|
||||||
parser.add_argument('--gpu_count', type=int, default=0)
|
|
||||||
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter')
|
|
||||||
parser.add_argument('--resume', type=str, default=None, help='resume the corresponding ckpt')
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def check_and_init(args):
|
|
||||||
'''check config files and device, and initialize '''
|
|
||||||
|
|
||||||
# check files
|
|
||||||
master_process = args.rank == 0 if args.world_size > 1 else args.rank == -1
|
|
||||||
args.save_dir = str(increment_name(osp.join(args.output_dir, args.name), master_process))
|
|
||||||
cfg = Config.fromfile(args.conf_file)
|
|
||||||
|
|
||||||
# check device
|
|
||||||
device = select_device(args.device)
|
|
||||||
|
|
||||||
# set random seed
|
|
||||||
set_random_seed(1+args.rank, deterministic=(args.rank == -1))
|
|
||||||
|
|
||||||
# save args
|
|
||||||
if master_process:
|
|
||||||
os.makedirs(args.save_dir)
|
|
||||||
save_yaml(vars(args), osp.join(args.save_dir, 'args.yaml'))
|
|
||||||
|
|
||||||
return cfg, device
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
'''main function of training'''
|
|
||||||
# Setup
|
|
||||||
args.rank, args.local_rank, args.world_size = get_envs()
|
|
||||||
LOGGER.info(f'training args are: {args}\n')
|
|
||||||
cfg, device = check_and_init(args)
|
|
||||||
|
|
||||||
if args.local_rank != -1: # if DDP mode
|
|
||||||
torch.cuda.set_device(args.local_rank)
|
|
||||||
device = torch.device('cuda', args.local_rank)
|
|
||||||
LOGGER.info('Initializing process group... ')
|
|
||||||
dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", \
|
|
||||||
init_method=args.dist_url, rank=args.local_rank, world_size=args.world_size)
|
|
||||||
|
|
||||||
# Start
|
|
||||||
trainer = Trainer(args, cfg, device)
|
|
||||||
trainer.train()
|
|
||||||
|
|
||||||
# End
|
|
||||||
if args.world_size > 1 and args.rank == 0:
|
|
||||||
LOGGER.info('Destroying process group... ')
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
args = get_args_parser().parse_args()
|
|
||||||
main(args)
|
|
@ -1,276 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from copy import deepcopy
|
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.cuda import amp
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
|
|
||||||
import tools.eval as eval
|
|
||||||
from yolov6.data.data_load import create_dataloader
|
|
||||||
from yolov6.models.yolo import build_model
|
|
||||||
from yolov6.models.loss import ComputeLoss
|
|
||||||
from yolov6.utils.events import LOGGER, NCOLS, load_yaml, write_tblog
|
|
||||||
from yolov6.utils.ema import ModelEMA, de_parallel
|
|
||||||
from yolov6.utils.checkpoint import load_state_dict, save_checkpoint, strip_optimizer
|
|
||||||
from yolov6.solver.build import build_optimizer, build_lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
|
||||||
def __init__(self, args, cfg, device):
|
|
||||||
self.args = args
|
|
||||||
self.cfg = cfg
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
self.rank = args.rank
|
|
||||||
self.local_rank = args.local_rank
|
|
||||||
self.world_size = args.world_size
|
|
||||||
self.main_process = self.rank in [-1, 0]
|
|
||||||
self.save_dir = args.save_dir
|
|
||||||
# get data loader
|
|
||||||
self.data_dict = load_yaml(args.data_path)
|
|
||||||
self.num_classes = self.data_dict['nc']
|
|
||||||
self.train_loader, self.val_loader = self.get_data_loader(args, cfg, self.data_dict)
|
|
||||||
# get model and optimizer
|
|
||||||
model = self.get_model(args, cfg, self.num_classes, device)
|
|
||||||
self.optimizer = self.get_optimizer(args, cfg, model)
|
|
||||||
self.scheduler, self.lf = self.get_lr_scheduler(args, cfg, self.optimizer)
|
|
||||||
self.ema = ModelEMA(model) if self.main_process else None
|
|
||||||
self.model = self.parallel_model(args, model, device)
|
|
||||||
self.model.nc, self.model.names = self.data_dict['nc'], self.data_dict['names']
|
|
||||||
# tensorboard
|
|
||||||
self.tblogger = SummaryWriter(self.save_dir) if self.main_process else None
|
|
||||||
|
|
||||||
self.start_epoch = 0
|
|
||||||
|
|
||||||
# resume ckpt from user-defined path
|
|
||||||
if args.resume:
|
|
||||||
assert os.path.isfile(args.resume), 'ERROR: --resume checkpoint does not exists'
|
|
||||||
self.ckpt = torch.load(args.resume, map_location='cpu')
|
|
||||||
self.start_epoch = self.ckpt['epoch'] + 1
|
|
||||||
|
|
||||||
self.max_epoch = args.epochs
|
|
||||||
self.max_stepnum = len(self.train_loader)
|
|
||||||
self.batch_size = args.batch_size
|
|
||||||
self.img_size = args.img_size
|
|
||||||
|
|
||||||
# Training Process
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
try:
|
|
||||||
self.train_before_loop()
|
|
||||||
for self.epoch in range(self.start_epoch, self.max_epoch):
|
|
||||||
self.train_in_loop()
|
|
||||||
|
|
||||||
except Exception as _:
|
|
||||||
LOGGER.error('ERROR in training loop or eval/save model.')
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self.train_after_loop()
|
|
||||||
|
|
||||||
# Training loop for each epoch
|
|
||||||
def train_in_loop(self):
|
|
||||||
try:
|
|
||||||
self.prepare_for_steps()
|
|
||||||
for self.step, self.batch_data in self.pbar:
|
|
||||||
self.train_in_steps()
|
|
||||||
self.print_details()
|
|
||||||
except Exception as _:
|
|
||||||
LOGGER.error('ERROR in training steps.')
|
|
||||||
raise
|
|
||||||
try:
|
|
||||||
self.eval_and_save()
|
|
||||||
except Exception as _:
|
|
||||||
LOGGER.error('ERROR in evaluate and save model.')
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Training loop for batchdata
|
|
||||||
def train_in_steps(self):
|
|
||||||
images, targets = self.prepro_data(self.batch_data, self.device)
|
|
||||||
# forward
|
|
||||||
with amp.autocast(enabled=self.device != 'cpu'):
|
|
||||||
preds = self.model(images)
|
|
||||||
total_loss, loss_items = self.compute_loss(preds, targets)
|
|
||||||
if self.rank != -1:
|
|
||||||
total_loss *= self.world_size
|
|
||||||
# backward
|
|
||||||
self.scaler.scale(total_loss).backward()
|
|
||||||
self.loss_items = loss_items
|
|
||||||
self.update_optimizer()
|
|
||||||
|
|
||||||
def eval_and_save(self):
|
|
||||||
remaining_epochs = self.max_epoch - self.epoch
|
|
||||||
eval_interval = self.args.eval_interval if remaining_epochs > self.args.heavy_eval_range else 1
|
|
||||||
is_val_epoch = (not self.args.eval_final_only or (remaining_epochs == 1)) and (self.epoch % eval_interval == 0)
|
|
||||||
if self.main_process:
|
|
||||||
self.ema.update_attr(self.model, include=['nc', 'names', 'stride']) # update attributes for ema model
|
|
||||||
if is_val_epoch:
|
|
||||||
self.eval_model()
|
|
||||||
self.ap = self.evaluate_results[0] * 0.1 + self.evaluate_results[1] * 0.9
|
|
||||||
self.best_ap = max(self.ap, self.best_ap)
|
|
||||||
# save ckpt
|
|
||||||
ckpt = {
|
|
||||||
'model': deepcopy(de_parallel(self.model)).half(),
|
|
||||||
'ema': deepcopy(self.ema.ema).half(),
|
|
||||||
'updates': self.ema.updates,
|
|
||||||
'optimizer': self.optimizer.state_dict(),
|
|
||||||
'epoch': self.epoch,
|
|
||||||
}
|
|
||||||
|
|
||||||
save_ckpt_dir = osp.join(self.save_dir, 'weights')
|
|
||||||
save_checkpoint(ckpt, (is_val_epoch) and (self.ap == self.best_ap), save_ckpt_dir, model_name='last_ckpt')
|
|
||||||
del ckpt
|
|
||||||
# log for tensorboard
|
|
||||||
write_tblog(self.tblogger, self.epoch, self.evaluate_results, self.mean_loss)
|
|
||||||
|
|
||||||
def eval_model(self):
|
|
||||||
results = eval.run(self.data_dict,
|
|
||||||
batch_size=self.batch_size // self.world_size * 2,
|
|
||||||
img_size=self.img_size,
|
|
||||||
model=self.ema.ema,
|
|
||||||
dataloader=self.val_loader,
|
|
||||||
save_dir=self.save_dir,
|
|
||||||
task='train')
|
|
||||||
|
|
||||||
LOGGER.info(f"Epoch: {self.epoch} | mAP@0.5: {results[0]} | mAP@0.50:0.95: {results[1]}")
|
|
||||||
self.evaluate_results = results[:2]
|
|
||||||
|
|
||||||
def train_before_loop(self):
|
|
||||||
LOGGER.info('Training start...')
|
|
||||||
self.start_time = time.time()
|
|
||||||
self.warmup_stepnum = max(round(self.cfg.solver.warmup_epochs * self.max_stepnum), 1000)
|
|
||||||
self.scheduler.last_epoch = self.start_epoch - 1
|
|
||||||
self.last_opt_step = -1
|
|
||||||
self.scaler = amp.GradScaler(enabled=self.device != 'cpu')
|
|
||||||
|
|
||||||
self.best_ap, self.ap = 0.0, 0.0
|
|
||||||
self.evaluate_results = (0, 0) # AP50, AP50_95
|
|
||||||
self.compute_loss = ComputeLoss(iou_type=self.cfg.model.head.iou_type)
|
|
||||||
|
|
||||||
if hasattr(self, "ckpt"):
|
|
||||||
resume_state_dict = self.ckpt['model'].float().state_dict() # checkpoint's state_dict as FP32
|
|
||||||
self.model.load_state_dict(resume_state_dict, strict=True) # load model state dict
|
|
||||||
self.optimizer.load_state_dict(self.ckpt['optimizer']) # load optimizer
|
|
||||||
self.start_epoch = self.ckpt['epoch'] + 1
|
|
||||||
self.ema.ema.load_state_dict(self.ckpt['ema'].float().state_dict()) # load ema state dict
|
|
||||||
self.ema.updates = self.ckpt['updates']
|
|
||||||
|
|
||||||
def prepare_for_steps(self):
|
|
||||||
if self.epoch > self.start_epoch:
|
|
||||||
self.scheduler.step()
|
|
||||||
self.model.train()
|
|
||||||
if self.rank != -1:
|
|
||||||
self.train_loader.sampler.set_epoch(self.epoch)
|
|
||||||
self.mean_loss = torch.zeros(4, device=self.device)
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
LOGGER.info(('\n' + '%10s' * 5) % ('Epoch', 'iou_loss', 'l1_loss', 'obj_loss', 'cls_loss'))
|
|
||||||
self.pbar = enumerate(self.train_loader)
|
|
||||||
if self.main_process:
|
|
||||||
self.pbar = tqdm(self.pbar, total=self.max_stepnum, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
|
||||||
|
|
||||||
# Print loss after each steps
|
|
||||||
def print_details(self):
|
|
||||||
if self.main_process:
|
|
||||||
self.mean_loss = (self.mean_loss * self.step + self.loss_items) / (self.step + 1)
|
|
||||||
self.pbar.set_description(('%10s' + '%10.4g' * 4) % (f'{self.epoch}/{self.max_epoch - 1}', \
|
|
||||||
*(self.mean_loss)))
|
|
||||||
|
|
||||||
# Empty cache if training finished
|
|
||||||
def train_after_loop(self):
|
|
||||||
if self.main_process:
|
|
||||||
LOGGER.info(f'\nTraining completed in {(time.time() - self.start_time) / 3600:.3f} hours.')
|
|
||||||
save_ckpt_dir = osp.join(self.save_dir, 'weights')
|
|
||||||
strip_optimizer(save_ckpt_dir, self.epoch) # strip optimizers for saved pt model
|
|
||||||
if self.device != 'cpu':
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def update_optimizer(self):
|
|
||||||
curr_step = self.step + self.max_stepnum * self.epoch
|
|
||||||
self.accumulate = max(1, round(64 / self.batch_size))
|
|
||||||
if curr_step <= self.warmup_stepnum:
|
|
||||||
self.accumulate = max(1, np.interp(curr_step, [0, self.warmup_stepnum], [1, 64 / self.batch_size]).round())
|
|
||||||
for k, param in enumerate(self.optimizer.param_groups):
|
|
||||||
warmup_bias_lr = self.cfg.solver.warmup_bias_lr if k == 2 else 0.0
|
|
||||||
param['lr'] = np.interp(curr_step, [0, self.warmup_stepnum], [warmup_bias_lr, param['initial_lr'] * self.lf(self.epoch)])
|
|
||||||
if 'momentum' in param:
|
|
||||||
param['momentum'] = np.interp(curr_step, [0, self.warmup_stepnum], [self.cfg.solver.warmup_momentum, self.cfg.solver.momentum])
|
|
||||||
if curr_step - self.last_opt_step >= self.accumulate:
|
|
||||||
self.scaler.step(self.optimizer)
|
|
||||||
self.scaler.update()
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
if self.ema:
|
|
||||||
self.ema.update(self.model)
|
|
||||||
self.last_opt_step = curr_step
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_data_loader(args, cfg, data_dict):
|
|
||||||
train_path, val_path = data_dict['train'], data_dict['val']
|
|
||||||
# check data
|
|
||||||
nc = int(data_dict['nc'])
|
|
||||||
class_names = data_dict['names']
|
|
||||||
assert len(class_names) == nc, f'the length of class names does not match the number of classes defined'
|
|
||||||
grid_size = max(int(max(cfg.model.head.strides)), 32)
|
|
||||||
# create train dataloader
|
|
||||||
train_loader = create_dataloader(train_path, args.img_size, args.batch_size // args.world_size, grid_size,
|
|
||||||
hyp=dict(cfg.data_aug), augment=True, rect=False, rank=args.local_rank,
|
|
||||||
workers=args.workers, shuffle=True, check_images=args.check_images,
|
|
||||||
check_labels=args.check_labels, data_dict=data_dict, task='train')[0]
|
|
||||||
# create val dataloader
|
|
||||||
val_loader = None
|
|
||||||
if args.rank in [-1, 0]:
|
|
||||||
val_loader = create_dataloader(val_path, args.img_size, args.batch_size // args.world_size * 2, grid_size,
|
|
||||||
hyp=dict(cfg.data_aug), rect=True, rank=-1, pad=0.5,
|
|
||||||
workers=args.workers, check_images=args.check_images,
|
|
||||||
check_labels=args.check_labels, data_dict=data_dict, task='val')[0]
|
|
||||||
|
|
||||||
return train_loader, val_loader
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def prepro_data(batch_data, device):
|
|
||||||
images = batch_data[0].to(device, non_blocking=True).float() / 255
|
|
||||||
targets = batch_data[1].to(device)
|
|
||||||
return images, targets
|
|
||||||
|
|
||||||
def get_model(self, args, cfg, nc, device):
|
|
||||||
model = build_model(cfg, nc, device)
|
|
||||||
weights = cfg.model.pretrained
|
|
||||||
if weights: # finetune if pretrained model is set
|
|
||||||
LOGGER.info(f'Loading state_dict from {weights} for fine-tuning...')
|
|
||||||
model = load_state_dict(weights, model, map_location=device)
|
|
||||||
LOGGER.info('Model: {}'.format(model))
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parallel_model(args, model, device):
|
|
||||||
# If DP mode
|
|
||||||
dp_mode = device.type != 'cpu' and args.rank == -1
|
|
||||||
if dp_mode and torch.cuda.device_count() > 1:
|
|
||||||
LOGGER.warning('WARNING: DP not recommended, use DDP instead.\n')
|
|
||||||
model = torch.nn.DataParallel(model)
|
|
||||||
|
|
||||||
# If DDP mode
|
|
||||||
ddp_mode = device.type != 'cpu' and args.rank != -1
|
|
||||||
if ddp_mode:
|
|
||||||
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def get_optimizer(self, args, cfg, model):
|
|
||||||
accumulate = max(1, round(64 / args.batch_size))
|
|
||||||
cfg.solver.weight_decay *= args.batch_size * accumulate / 64
|
|
||||||
optimizer = build_optimizer(cfg, model)
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_lr_scheduler(args, cfg, optimizer):
|
|
||||||
epochs = args.epochs
|
|
||||||
lr_scheduler, lf = build_lr_scheduler(cfg, optimizer, epochs)
|
|
||||||
return lr_scheduler, lf
|
|
@ -1,256 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import os
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from pycocotools.coco import COCO
|
|
||||||
from pycocotools.cocoeval import COCOeval
|
|
||||||
|
|
||||||
from yolov6.data.data_load import create_dataloader
|
|
||||||
from yolov6.utils.events import LOGGER, NCOLS
|
|
||||||
from yolov6.utils.nms import non_max_suppression
|
|
||||||
from yolov6.utils.checkpoint import load_checkpoint
|
|
||||||
from yolov6.utils.torch_utils import time_sync, get_model_info
|
|
||||||
|
|
||||||
'''
|
|
||||||
python tools/eval.py --task 'train'/'val'/'speed'
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
class Evaler:
|
|
||||||
def __init__(self,
|
|
||||||
data,
|
|
||||||
batch_size=32,
|
|
||||||
img_size=640,
|
|
||||||
conf_thres=0.001,
|
|
||||||
iou_thres=0.65,
|
|
||||||
device='',
|
|
||||||
half=True,
|
|
||||||
save_dir=''):
|
|
||||||
self.data = data
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.img_size = img_size
|
|
||||||
self.conf_thres = conf_thres
|
|
||||||
self.iou_thres = iou_thres
|
|
||||||
self.device = device
|
|
||||||
self.half = half
|
|
||||||
self.save_dir = save_dir
|
|
||||||
|
|
||||||
def init_model(self, model, weights, task):
|
|
||||||
if task != 'train':
|
|
||||||
model = load_checkpoint(weights, map_location=self.device)
|
|
||||||
self.stride = int(model.stride.max())
|
|
||||||
if self.device.type != 'cpu':
|
|
||||||
model(torch.zeros(1, 3, self.img_size, self.img_size).to(self.device).type_as(next(model.parameters())))
|
|
||||||
# switch to deploy
|
|
||||||
from yolov6.layers.common import RepVGGBlock
|
|
||||||
for layer in model.modules():
|
|
||||||
if isinstance(layer, RepVGGBlock):
|
|
||||||
layer.switch_to_deploy()
|
|
||||||
LOGGER.info("Switch model to deploy modality.")
|
|
||||||
LOGGER.info("Model Summary: {}".format(get_model_info(model, self.img_size)))
|
|
||||||
model.half() if self.half else model.float()
|
|
||||||
return model
|
|
||||||
|
|
||||||
def init_data(self, dataloader, task):
|
|
||||||
'''Initialize dataloader.
|
|
||||||
Returns a dataloader for task val or speed.
|
|
||||||
'''
|
|
||||||
self.is_coco = self.data.get("is_coco", False)
|
|
||||||
self.ids = self.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
|
||||||
if task != 'train':
|
|
||||||
pad = 0.0 if task == 'speed' else 0.5
|
|
||||||
dataloader = create_dataloader(self.data[task if task in ('train', 'val', 'test') else 'val'],
|
|
||||||
self.img_size, self.batch_size, self.stride, check_labels=True, pad=pad, rect=True,
|
|
||||||
data_dict=self.data, task=task)[0]
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
def predict_model(self, model, dataloader, task):
|
|
||||||
'''Model prediction
|
|
||||||
Predicts the whole dataset and gets the prediced results and inference time.
|
|
||||||
'''
|
|
||||||
self.speed_result = torch.zeros(4, device=self.device)
|
|
||||||
pred_results = []
|
|
||||||
pbar = tqdm(dataloader, desc="Inferencing model in val datasets.", ncols=NCOLS)
|
|
||||||
for imgs, targets, paths, shapes in pbar:
|
|
||||||
# pre-process
|
|
||||||
t1 = time_sync()
|
|
||||||
imgs = imgs.to(self.device, non_blocking=True)
|
|
||||||
imgs = imgs.half() if self.half else imgs.float()
|
|
||||||
imgs /= 255
|
|
||||||
self.speed_result[1] += time_sync() - t1 # pre-process time
|
|
||||||
|
|
||||||
# Inference
|
|
||||||
t2 = time_sync()
|
|
||||||
outputs = model(imgs)
|
|
||||||
self.speed_result[2] += time_sync() - t2 # inference time
|
|
||||||
|
|
||||||
# post-process
|
|
||||||
t3 = time_sync()
|
|
||||||
outputs = non_max_suppression(outputs, self.conf_thres, self.iou_thres, multi_label=True)
|
|
||||||
self.speed_result[3] += time_sync() - t3 # post-process time
|
|
||||||
self.speed_result[0] += len(outputs)
|
|
||||||
|
|
||||||
# save result
|
|
||||||
pred_results.extend(self.convert_to_coco_format(outputs, imgs, paths, shapes, self.ids))
|
|
||||||
return pred_results
|
|
||||||
|
|
||||||
def eval_model(self, pred_results, model, dataloader, task):
|
|
||||||
'''Evaluate models
|
|
||||||
For task speed, this function only evaluates the speed of model and outputs inference time.
|
|
||||||
For task val, this function evaluates the speed and mAP by pycocotools, and returns
|
|
||||||
inference time and mAP value.
|
|
||||||
'''
|
|
||||||
LOGGER.info(f'\nEvaluating speed.')
|
|
||||||
self.eval_speed(task)
|
|
||||||
|
|
||||||
LOGGER.info(f'\nEvaluating mAP by pycocotools.')
|
|
||||||
if task != 'speed' and len(pred_results):
|
|
||||||
if 'anno_path' in self.data:
|
|
||||||
anno_json = self.data['anno_path']
|
|
||||||
else:
|
|
||||||
# generated coco format labels in dataset initialization
|
|
||||||
dataset_root = os.path.dirname(os.path.dirname(self.data['val']))
|
|
||||||
base_name = os.path.basename(self.data['val'])
|
|
||||||
anno_json = os.path.join(dataset_root, 'annotations', f'instances_{base_name}.json')
|
|
||||||
pred_json = os.path.join(self.save_dir, "predictions.json")
|
|
||||||
LOGGER.info(f'Saving {pred_json}...')
|
|
||||||
with open(pred_json, 'w') as f:
|
|
||||||
json.dump(pred_results, f)
|
|
||||||
|
|
||||||
anno = COCO(anno_json)
|
|
||||||
pred = anno.loadRes(pred_json)
|
|
||||||
cocoEval = COCOeval(anno, pred, 'bbox')
|
|
||||||
if self.is_coco:
|
|
||||||
imgIds = [int(os.path.basename(x).split(".")[0])
|
|
||||||
for x in dataloader.dataset.img_paths]
|
|
||||||
cocoEval.params.imgIds = imgIds
|
|
||||||
cocoEval.evaluate()
|
|
||||||
cocoEval.accumulate()
|
|
||||||
cocoEval.summarize()
|
|
||||||
map, map50 = cocoEval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5)
|
|
||||||
# Return results
|
|
||||||
model.float() # for training
|
|
||||||
if task != 'train':
|
|
||||||
LOGGER.info(f"Results saved to {self.save_dir}")
|
|
||||||
return (map50, map)
|
|
||||||
return (0.0, 0.0)
|
|
||||||
|
|
||||||
def eval_speed(self, task):
|
|
||||||
'''Evaluate model inference speed.'''
|
|
||||||
if task != 'train':
|
|
||||||
n_samples = self.speed_result[0].item()
|
|
||||||
pre_time, inf_time, nms_time = 1000 * self.speed_result[1:].cpu().numpy() / n_samples
|
|
||||||
for n, v in zip(["pre-process", "inference", "NMS"],[pre_time, inf_time, nms_time]):
|
|
||||||
LOGGER.info("Average {} time: {:.2f} ms".format(n, v))
|
|
||||||
|
|
||||||
def box_convert(self, x):
|
|
||||||
# Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h] where x1y1=top-left, x2y2=bottom-right
|
|
||||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
|
||||||
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
|
|
||||||
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
|
|
||||||
y[:, 2] = x[:, 2] - x[:, 0] # width
|
|
||||||
y[:, 3] = x[:, 3] - x[:, 1] # height
|
|
||||||
return y
|
|
||||||
|
|
||||||
def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
|
|
||||||
# Rescale coords (xyxy) from img1_shape to img0_shape
|
|
||||||
if ratio_pad is None: # calculate from img0_shape
|
|
||||||
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
|
||||||
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
|
||||||
else:
|
|
||||||
gain = ratio_pad[0][0]
|
|
||||||
pad = ratio_pad[1]
|
|
||||||
|
|
||||||
coords[:, [0, 2]] -= pad[0] # x padding
|
|
||||||
coords[:, [1, 3]] -= pad[1] # y padding
|
|
||||||
coords[:, :4] /= gain
|
|
||||||
if isinstance(coords, torch.Tensor): # faster individually
|
|
||||||
coords[:, 0].clamp_(0, img0_shape[1]) # x1
|
|
||||||
coords[:, 1].clamp_(0, img0_shape[0]) # y1
|
|
||||||
coords[:, 2].clamp_(0, img0_shape[1]) # x2
|
|
||||||
coords[:, 3].clamp_(0, img0_shape[0]) # y2
|
|
||||||
else: # np.array (faster grouped)
|
|
||||||
coords[:, [0, 2]] = coords[:, [0, 2]].clip(0, img0_shape[1]) # x1, x2
|
|
||||||
coords[:, [1, 3]] = coords[:, [1, 3]].clip(0, img0_shape[0]) # y1, y2
|
|
||||||
return coords
|
|
||||||
|
|
||||||
def convert_to_coco_format(self, outputs, imgs, paths, shapes, ids):
|
|
||||||
pred_results = []
|
|
||||||
for i, pred in enumerate(outputs):
|
|
||||||
if len(pred) == 0:
|
|
||||||
continue
|
|
||||||
path, shape = Path(paths[i]), shapes[i][0]
|
|
||||||
self.scale_coords(imgs[i].shape[1:], pred[:, :4], shape, shapes[i][1])
|
|
||||||
image_id = int(path.stem) if path.stem.isnumeric() else path.stem
|
|
||||||
bboxes = self.box_convert(pred[:, 0:4])
|
|
||||||
bboxes[:, :2] -= bboxes[:, 2:] / 2
|
|
||||||
cls = pred[:, 5]
|
|
||||||
scores = pred[:, 4]
|
|
||||||
for ind in range(pred.shape[0]):
|
|
||||||
category_id = ids[int(cls[ind])]
|
|
||||||
bbox = [round(x, 3) for x in bboxes[ind].tolist()]
|
|
||||||
score = round(scores[ind].item(), 5)
|
|
||||||
pred_data = {
|
|
||||||
"image_id": image_id,
|
|
||||||
"category_id": category_id,
|
|
||||||
"bbox": bbox,
|
|
||||||
"score": score
|
|
||||||
}
|
|
||||||
pred_results.append(pred_data)
|
|
||||||
return pred_results
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_task(task):
|
|
||||||
if task not in ['train','val','speed']:
|
|
||||||
raise Exception("task argument error: only support 'train' / 'val' / 'speed' task.")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def reload_thres(conf_thres, iou_thres, task):
|
|
||||||
'''Sets conf and iou threshold for task val/speed'''
|
|
||||||
if task != 'train':
|
|
||||||
if task == 'val':
|
|
||||||
conf_thres = 0.001
|
|
||||||
if task == 'speed':
|
|
||||||
conf_thres = 0.25
|
|
||||||
iou_thres = 0.45
|
|
||||||
return conf_thres, iou_thres
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def reload_device(device, model, task):
|
|
||||||
# device = 'cpu' or '0' or '0,1,2,3'
|
|
||||||
if task == 'train':
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
else:
|
|
||||||
if device == 'cpu':
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
|
||||||
elif device:
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = device
|
|
||||||
assert torch.cuda.is_available()
|
|
||||||
cuda = device != 'cpu' and torch.cuda.is_available()
|
|
||||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
|
||||||
return device
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def reload_dataset(data):
|
|
||||||
with open(data, errors='ignore') as yaml_file:
|
|
||||||
data = yaml.safe_load(yaml_file)
|
|
||||||
val = data.get('val')
|
|
||||||
if not os.path.exists(val):
|
|
||||||
raise Exception('Dataset not found.')
|
|
||||||
return data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
|
|
||||||
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
|
|
||||||
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
|
|
||||||
21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
|
|
||||||
41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
|
|
||||||
59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
|
|
||||||
80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
|
|
||||||
return x
|
|
@ -1,193 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import math
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
from PIL import ImageFont
|
|
||||||
|
|
||||||
from yolov6.utils.events import LOGGER, load_yaml
|
|
||||||
from yolov6.layers.common import DetectBackend
|
|
||||||
from yolov6.data.data_augment import letterbox
|
|
||||||
from yolov6.utils.nms import non_max_suppression
|
|
||||||
|
|
||||||
|
|
||||||
class Inferer:
|
|
||||||
def __init__(self, source, weights, device, yaml, img_size, half):
|
|
||||||
import glob
|
|
||||||
from yolov6.data.datasets import IMG_FORMATS
|
|
||||||
|
|
||||||
self.__dict__.update(locals())
|
|
||||||
|
|
||||||
# Init model
|
|
||||||
self.device = device
|
|
||||||
self.img_size = img_size
|
|
||||||
cuda = self.device != 'cpu' and torch.cuda.is_available()
|
|
||||||
self.device = torch.device('cuda:0' if cuda else 'cpu')
|
|
||||||
self.model = DetectBackend(weights, device=self.device)
|
|
||||||
self.stride = self.model.stride
|
|
||||||
self.class_names = load_yaml(yaml)['names']
|
|
||||||
self.img_size = self.check_img_size(self.img_size, s=self.stride) # check image size
|
|
||||||
|
|
||||||
# Half precision
|
|
||||||
if half & (self.device.type != 'cpu'):
|
|
||||||
self.model.model.half()
|
|
||||||
else:
|
|
||||||
self.model.model.float()
|
|
||||||
half = False
|
|
||||||
|
|
||||||
if self.device.type != 'cpu':
|
|
||||||
self.model(torch.zeros(1, 3, *self.img_size).to(self.device).type_as(next(self.model.model.parameters()))) # warmup
|
|
||||||
|
|
||||||
# Load data
|
|
||||||
if os.path.isdir(source):
|
|
||||||
img_paths = sorted(glob.glob(os.path.join(source, '*.*'))) # dir
|
|
||||||
elif os.path.isfile(source):
|
|
||||||
img_paths = [source] # files
|
|
||||||
else:
|
|
||||||
raise Exception(f'Invalid path: {source}')
|
|
||||||
self.img_paths = [img_path for img_path in img_paths if img_path.split('.')[-1].lower() in IMG_FORMATS]
|
|
||||||
|
|
||||||
def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf):
|
|
||||||
''' Model Inference and results visualization '''
|
|
||||||
|
|
||||||
for img_path in tqdm(self.img_paths):
|
|
||||||
img, img_src = self.precess_image(img_path, self.img_size, self.stride, self.half)
|
|
||||||
img = img.to(self.device)
|
|
||||||
if len(img.shape) == 3:
|
|
||||||
img = img[None]
|
|
||||||
# expand for batch dim
|
|
||||||
pred_results = self.model(img)
|
|
||||||
det = non_max_suppression(pred_results, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
|
|
||||||
|
|
||||||
save_path = osp.join(save_dir, osp.basename(img_path)) # im.jpg
|
|
||||||
txt_path = osp.join(save_dir, 'labels', osp.splitext(osp.basename(img_path))[0])
|
|
||||||
|
|
||||||
gn = torch.tensor(img_src.shape)[[1, 0, 1, 0]] # normalization gain whwh
|
|
||||||
img_ori = img_src
|
|
||||||
|
|
||||||
# check image and font
|
|
||||||
assert img_ori.data.contiguous, 'Image needs to be contiguous. Please apply to input images with np.ascontiguousarray(im).'
|
|
||||||
self.font_check()
|
|
||||||
|
|
||||||
if len(det):
|
|
||||||
det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
|
|
||||||
|
|
||||||
for *xyxy, conf, cls in reversed(det):
|
|
||||||
if save_txt: # Write to file
|
|
||||||
xywh = (self.box_convert(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
|
|
||||||
line = (cls, *xywh, conf)
|
|
||||||
with open(txt_path + '.txt', 'a') as f:
|
|
||||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
|
||||||
|
|
||||||
if save_img:
|
|
||||||
class_num = int(cls) # integer class
|
|
||||||
label = None if hide_labels else (self.class_names[class_num] if hide_conf else f'{self.class_names[class_num]} {conf:.2f}')
|
|
||||||
|
|
||||||
self.plot_box_and_label(img_ori, max(round(sum(img_ori.shape) / 2 * 0.003), 2), xyxy, label, color=self.generate_colors(class_num, True))
|
|
||||||
|
|
||||||
img_src = np.asarray(img_ori)
|
|
||||||
|
|
||||||
# Save results (image with detections)
|
|
||||||
if save_img:
|
|
||||||
cv2.imwrite(save_path, img_src)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def precess_image(path, img_size, stride, half):
|
|
||||||
'''Process image before image inference.'''
|
|
||||||
try:
|
|
||||||
img_src = cv2.imread(path)
|
|
||||||
assert img_src is not None, f'Invalid image: {path}'
|
|
||||||
except Exception as e:
|
|
||||||
LOGGER.Warning(e)
|
|
||||||
image = letterbox(img_src, img_size, stride=stride)[0]
|
|
||||||
|
|
||||||
# Convert
|
|
||||||
image = image.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
|
||||||
image = torch.from_numpy(np.ascontiguousarray(image))
|
|
||||||
image = image.half() if half else image.float() # uint8 to fp16/32
|
|
||||||
image /= 255 # 0 - 255 to 0.0 - 1.0
|
|
||||||
|
|
||||||
return image, img_src
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def rescale(ori_shape, boxes, target_shape):
|
|
||||||
'''Rescale the output to the original image shape'''
|
|
||||||
ratio = min(ori_shape[0] / target_shape[0], ori_shape[1] / target_shape[1])
|
|
||||||
padding = (ori_shape[1] - target_shape[1] * ratio) / 2, (ori_shape[0] - target_shape[0] * ratio) / 2
|
|
||||||
|
|
||||||
boxes[:, [0, 2]] -= padding[0]
|
|
||||||
boxes[:, [1, 3]] -= padding[1]
|
|
||||||
boxes[:, :4] /= ratio
|
|
||||||
|
|
||||||
boxes[:, 0].clamp_(0, target_shape[1]) # x1
|
|
||||||
boxes[:, 1].clamp_(0, target_shape[0]) # y1
|
|
||||||
boxes[:, 2].clamp_(0, target_shape[1]) # x2
|
|
||||||
boxes[:, 3].clamp_(0, target_shape[0]) # y2
|
|
||||||
|
|
||||||
return boxes
|
|
||||||
|
|
||||||
def check_img_size(self, img_size, s=32, floor=0):
|
|
||||||
"""Make sure image size is a multiple of stride s in each dimension, and return a new shape list of image."""
|
|
||||||
if isinstance(img_size, int): # integer i.e. img_size=640
|
|
||||||
new_size = max(self.make_divisible(img_size, int(s)), floor)
|
|
||||||
elif isinstance(img_size, list): # list i.e. img_size=[640, 480]
|
|
||||||
new_size = [max(self.make_divisible(x, int(s)), floor) for x in img_size]
|
|
||||||
else:
|
|
||||||
raise Exception(f"Unsupported type of img_size: {type(img_size)}")
|
|
||||||
|
|
||||||
if new_size != img_size:
|
|
||||||
print(f'WARNING: --img-size {img_size} must be multiple of max stride {s}, updating to {new_size}')
|
|
||||||
return new_size if isinstance(img_size,list) else [new_size]*2
|
|
||||||
|
|
||||||
def make_divisible(self, x, divisor):
|
|
||||||
# Upward revision the value x to make it evenly divisible by the divisor.
|
|
||||||
return math.ceil(x / divisor) * divisor
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def plot_box_and_label(image, lw, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
|
||||||
# Add one xyxy box to image with label
|
|
||||||
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
|
||||||
cv2.rectangle(image, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
|
|
||||||
if label:
|
|
||||||
tf = max(lw - 1, 1) # font thickness
|
|
||||||
w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0] # text width, height
|
|
||||||
outside = p1[1] - h - 3 >= 0 # label fits outside box
|
|
||||||
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
|
||||||
cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
|
|
||||||
cv2.putText(image, label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), 0, lw / 3, txt_color,
|
|
||||||
thickness=tf, lineType=cv2.LINE_AA)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def font_check(font='./yolov6/utils/Arial.ttf', size=10):
|
|
||||||
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
|
|
||||||
assert osp.exists(font), f'font path not exists: {font}'
|
|
||||||
try:
|
|
||||||
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
|
|
||||||
except Exception as e: # download if missing
|
|
||||||
return ImageFont.truetype(str(font), size)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def box_convert(x):
|
|
||||||
# Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h] where x1y1=top-left, x2y2=bottom-right
|
|
||||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
|
||||||
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
|
|
||||||
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
|
|
||||||
y[:, 2] = x[:, 2] - x[:, 0] # width
|
|
||||||
y[:, 3] = x[:, 3] - x[:, 1] # height
|
|
||||||
return y
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_colors(i, bgr=False):
|
|
||||||
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
|
||||||
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
|
||||||
palette = []
|
|
||||||
for iter in hex:
|
|
||||||
h = '#' + iter
|
|
||||||
palette.append(tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)))
|
|
||||||
num = len(palette)
|
|
||||||
color = palette[int(i) % num]
|
|
||||||
return (color[2], color[1], color[0]) if bgr else color
|
|
@ -1,193 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
# This code is based on
|
|
||||||
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
|
|
||||||
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
|
|
||||||
# HSV color-space augmentation
|
|
||||||
if hgain or sgain or vgain:
|
|
||||||
r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
|
|
||||||
hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
|
|
||||||
dtype = im.dtype # uint8
|
|
||||||
|
|
||||||
x = np.arange(0, 256, dtype=r.dtype)
|
|
||||||
lut_hue = ((x * r[0]) % 180).astype(dtype)
|
|
||||||
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
|
|
||||||
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
|
|
||||||
|
|
||||||
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
|
|
||||||
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed
|
|
||||||
|
|
||||||
|
|
||||||
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
|
|
||||||
# Resize and pad image while meeting stride-multiple constraints
|
|
||||||
shape = im.shape[:2] # current shape [height, width]
|
|
||||||
if isinstance(new_shape, int):
|
|
||||||
new_shape = (new_shape, new_shape)
|
|
||||||
|
|
||||||
# Scale ratio (new / old)
|
|
||||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
|
||||||
if not scaleup: # only scale down, do not scale up (for better val mAP)
|
|
||||||
r = min(r, 1.0)
|
|
||||||
|
|
||||||
# Compute padding
|
|
||||||
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
|
||||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
|
||||||
|
|
||||||
if auto: # minimum rectangle
|
|
||||||
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
|
|
||||||
|
|
||||||
dw /= 2 # divide padding into 2 sides
|
|
||||||
dh /= 2
|
|
||||||
|
|
||||||
if shape[::-1] != new_unpad: # resize
|
|
||||||
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
|
|
||||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
|
||||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
|
||||||
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
|
|
||||||
return im, r, (dw, dh)
|
|
||||||
|
|
||||||
|
|
||||||
def mixup(im, labels, im2, labels2):
|
|
||||||
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
|
|
||||||
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
|
||||||
im = (im * r + im2 * (1 - r)).astype(np.uint8)
|
|
||||||
labels = np.concatenate((labels, labels2), 0)
|
|
||||||
return im, labels
|
|
||||||
|
|
||||||
|
|
||||||
def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
|
||||||
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
|
|
||||||
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
|
|
||||||
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
|
|
||||||
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
|
|
||||||
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
|
|
||||||
|
|
||||||
|
|
||||||
def random_affine(img, labels=(), degrees=10, translate=.1, scale=.1, shear=10,
|
|
||||||
new_shape=(640, 640)):
|
|
||||||
|
|
||||||
n = len(labels)
|
|
||||||
height, width = new_shape
|
|
||||||
|
|
||||||
M, s = get_transform_matrix(img.shape[:2], (height, width), degrees, scale, shear, translate)
|
|
||||||
if (M != np.eye(3)).any(): # image changed
|
|
||||||
img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
|
|
||||||
|
|
||||||
# Transform label coordinates
|
|
||||||
if n:
|
|
||||||
new = np.zeros((n, 4))
|
|
||||||
|
|
||||||
xy = np.ones((n * 4, 3))
|
|
||||||
xy[:, :2] = labels[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
|
|
||||||
xy = xy @ M.T # transform
|
|
||||||
xy = xy[:, :2].reshape(n, 8) # perspective rescale or affine
|
|
||||||
|
|
||||||
# create new boxes
|
|
||||||
x = xy[:, [0, 2, 4, 6]]
|
|
||||||
y = xy[:, [1, 3, 5, 7]]
|
|
||||||
new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
|
|
||||||
|
|
||||||
# clip
|
|
||||||
new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
|
|
||||||
new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
|
|
||||||
|
|
||||||
# filter candidates
|
|
||||||
i = box_candidates(box1=labels[:, 1:5].T * s, box2=new.T, area_thr=0.1)
|
|
||||||
labels = labels[i]
|
|
||||||
labels[:, 1:5] = new[i]
|
|
||||||
|
|
||||||
return img, labels
|
|
||||||
|
|
||||||
|
|
||||||
def get_transform_matrix(img_shape, new_shape, degrees, scale, shear, translate):
|
|
||||||
new_height, new_width = new_shape
|
|
||||||
# Center
|
|
||||||
C = np.eye(3)
|
|
||||||
C[0, 2] = -img_shape[1] / 2 # x translation (pixels)
|
|
||||||
C[1, 2] = -img_shape[0] / 2 # y translation (pixels)
|
|
||||||
|
|
||||||
# Rotation and Scale
|
|
||||||
R = np.eye(3)
|
|
||||||
a = random.uniform(-degrees, degrees)
|
|
||||||
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
|
|
||||||
s = random.uniform(1 - scale, 1 + scale)
|
|
||||||
# s = 2 ** random.uniform(-scale, scale)
|
|
||||||
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
|
|
||||||
|
|
||||||
# Shear
|
|
||||||
S = np.eye(3)
|
|
||||||
S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
|
|
||||||
S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
|
|
||||||
|
|
||||||
# Translation
|
|
||||||
T = np.eye(3)
|
|
||||||
T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * new_width # x translation (pixels)
|
|
||||||
T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * new_height # y transla ion (pixels)
|
|
||||||
|
|
||||||
# Combined rotation matrix
|
|
||||||
M = T @ S @ R @ C # order of operations (right to left) is IMPORTANT
|
|
||||||
return M, s
|
|
||||||
|
|
||||||
|
|
||||||
def mosaic_augmentation(img_size, imgs, hs, ws, labels, hyp):
|
|
||||||
|
|
||||||
assert len(imgs) == 4, "Mosaic augmentation of current version only supports 4 images."
|
|
||||||
|
|
||||||
labels4 = []
|
|
||||||
s = img_size
|
|
||||||
yc, xc = (int(random.uniform(s//2, 3*s//2)) for _ in range(2)) # mosaic center x, y
|
|
||||||
for i in range(len(imgs)):
|
|
||||||
# Load image
|
|
||||||
img, h, w = imgs[i], hs[i], ws[i]
|
|
||||||
# place img in img4
|
|
||||||
if i == 0: # top left
|
|
||||||
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
|
|
||||||
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
|
|
||||||
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
|
|
||||||
elif i == 1: # top right
|
|
||||||
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
|
|
||||||
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
|
|
||||||
elif i == 2: # bottom left
|
|
||||||
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
|
|
||||||
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
|
|
||||||
elif i == 3: # bottom right
|
|
||||||
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
|
|
||||||
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
|
|
||||||
|
|
||||||
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
|
|
||||||
padw = x1a - x1b
|
|
||||||
padh = y1a - y1b
|
|
||||||
|
|
||||||
# Labels
|
|
||||||
labels_per_img = labels[i].copy()
|
|
||||||
if labels_per_img.size:
|
|
||||||
boxes = np.copy(labels_per_img[:, 1:])
|
|
||||||
boxes[:, 0] = w * (labels_per_img[:, 1] - labels_per_img[:, 3] / 2) + padw # top left x
|
|
||||||
boxes[:, 1] = h * (labels_per_img[:, 2] - labels_per_img[:, 4] / 2) + padh # top left y
|
|
||||||
boxes[:, 2] = w * (labels_per_img[:, 1] + labels_per_img[:, 3] / 2) + padw # bottom right x
|
|
||||||
boxes[:, 3] = h * (labels_per_img[:, 2] + labels_per_img[:, 4] / 2) + padh # bottom right y
|
|
||||||
labels_per_img[:, 1:] = boxes
|
|
||||||
|
|
||||||
labels4.append(labels_per_img)
|
|
||||||
|
|
||||||
# Concat/clip labels
|
|
||||||
labels4 = np.concatenate(labels4, 0)
|
|
||||||
for x in (labels4[:, 1:]):
|
|
||||||
np.clip(x, 0, 2 * s, out=x)
|
|
||||||
|
|
||||||
# Augment
|
|
||||||
img4, labels4 = random_affine(img4, labels4,
|
|
||||||
degrees=hyp['degrees'],
|
|
||||||
translate=hyp['translate'],
|
|
||||||
scale=hyp['scale'],
|
|
||||||
shear=hyp['shear'])
|
|
||||||
|
|
||||||
return img4, labels4
|
|
@ -1,113 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
# This code is based on
|
|
||||||
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
|
|
||||||
|
|
||||||
import os
|
|
||||||
from torch.utils.data import dataloader, distributed
|
|
||||||
|
|
||||||
from .datasets import TrainValDataset
|
|
||||||
from yolov6.utils.events import LOGGER
|
|
||||||
from yolov6.utils.torch_utils import torch_distributed_zero_first
|
|
||||||
|
|
||||||
|
|
||||||
def create_dataloader(
|
|
||||||
path,
|
|
||||||
img_size,
|
|
||||||
batch_size,
|
|
||||||
stride,
|
|
||||||
hyp=None,
|
|
||||||
augment=False,
|
|
||||||
check_images=False,
|
|
||||||
check_labels=False,
|
|
||||||
pad=0.0,
|
|
||||||
rect=False,
|
|
||||||
rank=-1,
|
|
||||||
workers=8,
|
|
||||||
shuffle=False,
|
|
||||||
data_dict=None,
|
|
||||||
task="Train",
|
|
||||||
):
|
|
||||||
"""Create general dataloader.
|
|
||||||
|
|
||||||
Returns dataloader and dataset
|
|
||||||
"""
|
|
||||||
if rect and shuffle:
|
|
||||||
LOGGER.warning(
|
|
||||||
"WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False"
|
|
||||||
)
|
|
||||||
shuffle = False
|
|
||||||
with torch_distributed_zero_first(rank):
|
|
||||||
dataset = TrainValDataset(
|
|
||||||
path,
|
|
||||||
img_size,
|
|
||||||
batch_size,
|
|
||||||
augment=augment,
|
|
||||||
hyp=hyp,
|
|
||||||
rect=rect,
|
|
||||||
check_images=check_images,
|
|
||||||
check_labels=check_labels,
|
|
||||||
stride=int(stride),
|
|
||||||
pad=pad,
|
|
||||||
rank=rank,
|
|
||||||
data_dict=data_dict,
|
|
||||||
task=task,
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size = min(batch_size, len(dataset))
|
|
||||||
workers = min(
|
|
||||||
[
|
|
||||||
os.cpu_count() // int(os.getenv("WORLD_SIZE", 1)),
|
|
||||||
batch_size if batch_size > 1 else 0,
|
|
||||||
workers,
|
|
||||||
]
|
|
||||||
) # number of workers
|
|
||||||
sampler = (
|
|
||||||
None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
TrainValDataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=shuffle and sampler is None,
|
|
||||||
num_workers=workers,
|
|
||||||
sampler=sampler,
|
|
||||||
pin_memory=True,
|
|
||||||
collate_fn=TrainValDataset.collate_fn,
|
|
||||||
),
|
|
||||||
dataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainValDataLoader(dataloader.DataLoader):
|
|
||||||
"""Dataloader that reuses workers
|
|
||||||
|
|
||||||
Uses same syntax as vanilla DataLoader
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
|
|
||||||
self.iterator = super().__iter__()
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.batch_sampler.sampler)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
for i in range(len(self)):
|
|
||||||
yield next(self.iterator)
|
|
||||||
|
|
||||||
|
|
||||||
class _RepeatSampler:
|
|
||||||
"""Sampler that repeats forever
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sampler (Sampler)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, sampler):
|
|
||||||
self.sampler = sampler
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
while True:
|
|
||||||
yield from iter(self.sampler)
|
|
@ -1,550 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import random
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from multiprocessing.pool import Pool
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import ExifTags, Image, ImageOps
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .data_augment import (
|
|
||||||
augment_hsv,
|
|
||||||
letterbox,
|
|
||||||
mixup,
|
|
||||||
random_affine,
|
|
||||||
mosaic_augmentation,
|
|
||||||
)
|
|
||||||
from yolov6.utils.events import LOGGER
|
|
||||||
|
|
||||||
# Parameters
|
|
||||||
IMG_FORMATS = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", "webp", "mpo"]
|
|
||||||
# Get orientation exif tag
|
|
||||||
for k, v in ExifTags.TAGS.items():
|
|
||||||
if v == "Orientation":
|
|
||||||
ORIENTATION = k
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
class TrainValDataset(Dataset):
|
|
||||||
# YOLOv6 train_loader/val_loader, loads images and labels for training and validation
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
img_dir,
|
|
||||||
img_size=640,
|
|
||||||
batch_size=16,
|
|
||||||
augment=False,
|
|
||||||
hyp=None,
|
|
||||||
rect=False,
|
|
||||||
check_images=False,
|
|
||||||
check_labels=False,
|
|
||||||
stride=32,
|
|
||||||
pad=0.0,
|
|
||||||
rank=-1,
|
|
||||||
data_dict=None,
|
|
||||||
task="train",
|
|
||||||
):
|
|
||||||
assert task.lower() in ("train", "val", "speed"), f"Not supported task: {task}"
|
|
||||||
t1 = time.time()
|
|
||||||
self.__dict__.update(locals())
|
|
||||||
self.main_process = self.rank in (-1, 0)
|
|
||||||
self.task = self.task.capitalize()
|
|
||||||
self.class_names = data_dict["names"]
|
|
||||||
self.img_paths, self.labels = self.get_imgs_labels(self.img_dir)
|
|
||||||
if self.rect:
|
|
||||||
shapes = [self.img_info[p]["shape"] for p in self.img_paths]
|
|
||||||
self.shapes = np.array(shapes, dtype=np.float64)
|
|
||||||
self.batch_indices = np.floor(
|
|
||||||
np.arange(len(shapes)) / self.batch_size
|
|
||||||
).astype(
|
|
||||||
np.int
|
|
||||||
) # batch indices of each image
|
|
||||||
self.sort_files_shapes()
|
|
||||||
t2 = time.time()
|
|
||||||
if self.main_process:
|
|
||||||
LOGGER.info(f"%.1fs for dataset initialization." % (t2 - t1))
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
"""Get the length of dataset"""
|
|
||||||
return len(self.img_paths)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
"""Fetching a data sample for a given key.
|
|
||||||
This function applies mosaic and mixup augments during training.
|
|
||||||
During validation, letterbox augment is applied.
|
|
||||||
"""
|
|
||||||
# Mosaic Augmentation
|
|
||||||
if self.augment and random.random() < self.hyp["mosaic"]:
|
|
||||||
img, labels = self.get_mosaic(index)
|
|
||||||
shapes = None
|
|
||||||
|
|
||||||
# MixUp augmentation
|
|
||||||
if random.random() < self.hyp["mixup"]:
|
|
||||||
img_other, labels_other = self.get_mosaic(
|
|
||||||
random.randint(0, len(self.img_paths) - 1)
|
|
||||||
)
|
|
||||||
img, labels = mixup(img, labels, img_other, labels_other)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Load image
|
|
||||||
img, (h0, w0), (h, w) = self.load_image(index)
|
|
||||||
|
|
||||||
# Letterbox
|
|
||||||
shape = (
|
|
||||||
self.batch_shapes[self.batch_indices[index]]
|
|
||||||
if self.rect
|
|
||||||
else self.img_size
|
|
||||||
) # final letterboxed shape
|
|
||||||
img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
|
|
||||||
shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
|
|
||||||
|
|
||||||
labels = self.labels[index].copy()
|
|
||||||
if labels.size:
|
|
||||||
w *= ratio
|
|
||||||
h *= ratio
|
|
||||||
# new boxes
|
|
||||||
boxes = np.copy(labels[:, 1:])
|
|
||||||
boxes[:, 0] = (
|
|
||||||
w * (labels[:, 1] - labels[:, 3] / 2) + pad[0]
|
|
||||||
) # top left x
|
|
||||||
boxes[:, 1] = (
|
|
||||||
h * (labels[:, 2] - labels[:, 4] / 2) + pad[1]
|
|
||||||
) # top left y
|
|
||||||
boxes[:, 2] = (
|
|
||||||
w * (labels[:, 1] + labels[:, 3] / 2) + pad[0]
|
|
||||||
) # bottom right x
|
|
||||||
boxes[:, 3] = (
|
|
||||||
h * (labels[:, 2] + labels[:, 4] / 2) + pad[1]
|
|
||||||
) # bottom right y
|
|
||||||
labels[:, 1:] = boxes
|
|
||||||
|
|
||||||
if self.augment:
|
|
||||||
img, labels = random_affine(
|
|
||||||
img,
|
|
||||||
labels,
|
|
||||||
degrees=self.hyp["degrees"],
|
|
||||||
translate=self.hyp["translate"],
|
|
||||||
scale=self.hyp["scale"],
|
|
||||||
shear=self.hyp["shear"],
|
|
||||||
new_shape=(self.img_size, self.img_size),
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(labels):
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
|
|
||||||
labels[:, [1, 3]] = labels[:, [1, 3]].clip(0, w - 1e-3) # x1, x2
|
|
||||||
labels[:, [2, 4]] = labels[:, [2, 4]].clip(0, h - 1e-3) # y1, y2
|
|
||||||
|
|
||||||
boxes = np.copy(labels[:, 1:])
|
|
||||||
boxes[:, 0] = ((labels[:, 1] + labels[:, 3]) / 2) / w # x center
|
|
||||||
boxes[:, 1] = ((labels[:, 2] + labels[:, 4]) / 2) / h # y center
|
|
||||||
boxes[:, 2] = (labels[:, 3] - labels[:, 1]) / w # width
|
|
||||||
boxes[:, 3] = (labels[:, 4] - labels[:, 2]) / h # height
|
|
||||||
labels[:, 1:] = boxes
|
|
||||||
|
|
||||||
if self.augment:
|
|
||||||
img, labels = self.general_augment(img, labels)
|
|
||||||
|
|
||||||
labels_out = torch.zeros((len(labels), 6))
|
|
||||||
if len(labels):
|
|
||||||
labels_out[:, 1:] = torch.from_numpy(labels)
|
|
||||||
|
|
||||||
# Convert
|
|
||||||
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
|
||||||
img = np.ascontiguousarray(img)
|
|
||||||
|
|
||||||
return torch.from_numpy(img), labels_out, self.img_paths[index], shapes
|
|
||||||
|
|
||||||
def load_image(self, index):
|
|
||||||
"""Load image.
|
|
||||||
This function loads image by cv2, resize original image to target shape(img_size) with keeping ratio.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image, original shape of image, resized image shape
|
|
||||||
"""
|
|
||||||
path = self.img_paths[index]
|
|
||||||
im = cv2.imread(path)
|
|
||||||
assert im is not None, f"Image Not Found {path}, workdir: {os.getcwd()}"
|
|
||||||
|
|
||||||
h0, w0 = im.shape[:2] # origin shape
|
|
||||||
r = self.img_size / max(h0, w0)
|
|
||||||
if r != 1:
|
|
||||||
im = cv2.resize(
|
|
||||||
im,
|
|
||||||
(int(w0 * r), int(h0 * r)),
|
|
||||||
interpolation=cv2.INTER_AREA
|
|
||||||
if r < 1 and not self.augment
|
|
||||||
else cv2.INTER_LINEAR,
|
|
||||||
)
|
|
||||||
return im, (h0, w0), im.shape[:2]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def collate_fn(batch):
|
|
||||||
"""Merges a list of samples to form a mini-batch of Tensor(s)"""
|
|
||||||
img, label, path, shapes = zip(*batch)
|
|
||||||
for i, l in enumerate(label):
|
|
||||||
l[:, 0] = i # add target image index for build_targets()
|
|
||||||
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
|
|
||||||
|
|
||||||
def get_imgs_labels(self, img_dir):
|
|
||||||
|
|
||||||
assert osp.exists(img_dir), f"{img_dir} is an invalid directory path!"
|
|
||||||
valid_img_record = osp.join(
|
|
||||||
osp.dirname(img_dir), "." + osp.basename(img_dir) + ".json"
|
|
||||||
)
|
|
||||||
NUM_THREADS = min(8, os.cpu_count())
|
|
||||||
|
|
||||||
img_paths = glob.glob(osp.join(img_dir, "*"), recursive=True)
|
|
||||||
img_paths = sorted(
|
|
||||||
p for p in img_paths if p.split(".")[-1].lower() in IMG_FORMATS
|
|
||||||
)
|
|
||||||
assert img_paths, f"No images found in {img_dir}."
|
|
||||||
|
|
||||||
img_hash = self.get_hash(img_paths)
|
|
||||||
if osp.exists(valid_img_record):
|
|
||||||
with open(valid_img_record, "r") as f:
|
|
||||||
cache_info = json.load(f)
|
|
||||||
if "image_hash" in cache_info and cache_info["image_hash"] == img_hash:
|
|
||||||
img_info = cache_info["information"]
|
|
||||||
else:
|
|
||||||
self.check_images = True
|
|
||||||
else:
|
|
||||||
self.check_images = True
|
|
||||||
|
|
||||||
# check images
|
|
||||||
if self.check_images and self.main_process:
|
|
||||||
img_info = {}
|
|
||||||
nc, msgs = 0, [] # number corrupt, messages
|
|
||||||
LOGGER.info(
|
|
||||||
f"{self.task}: Checking formats of images with {NUM_THREADS} process(es): "
|
|
||||||
)
|
|
||||||
with Pool(NUM_THREADS) as pool:
|
|
||||||
pbar = tqdm(
|
|
||||||
pool.imap(TrainValDataset.check_image, img_paths),
|
|
||||||
total=len(img_paths),
|
|
||||||
)
|
|
||||||
for img_path, shape_per_img, nc_per_img, msg in pbar:
|
|
||||||
if nc_per_img == 0: # not corrupted
|
|
||||||
img_info[img_path] = {"shape": shape_per_img}
|
|
||||||
nc += nc_per_img
|
|
||||||
if msg:
|
|
||||||
msgs.append(msg)
|
|
||||||
pbar.desc = f"{nc} image(s) corrupted"
|
|
||||||
pbar.close()
|
|
||||||
if msgs:
|
|
||||||
LOGGER.info("\n".join(msgs))
|
|
||||||
|
|
||||||
cache_info = {"information": img_info, "image_hash": img_hash}
|
|
||||||
# save valid image paths.
|
|
||||||
with open(valid_img_record, "w") as f:
|
|
||||||
json.dump(cache_info, f)
|
|
||||||
|
|
||||||
# check and load anns
|
|
||||||
label_dir = osp.join(
|
|
||||||
osp.dirname(osp.dirname(img_dir)), "labels", osp.basename(img_dir)
|
|
||||||
)
|
|
||||||
assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"
|
|
||||||
|
|
||||||
img_paths = list(img_info.keys())
|
|
||||||
label_paths = sorted(
|
|
||||||
osp.join(label_dir, osp.splitext(osp.basename(p))[0] + ".txt")
|
|
||||||
for p in img_paths
|
|
||||||
)
|
|
||||||
label_hash = self.get_hash(label_paths)
|
|
||||||
if "label_hash" not in cache_info or cache_info["label_hash"] != label_hash:
|
|
||||||
self.check_labels = True
|
|
||||||
|
|
||||||
if self.check_labels:
|
|
||||||
cache_info["label_hash"] = label_hash
|
|
||||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number corrupt, messages
|
|
||||||
LOGGER.info(
|
|
||||||
f"{self.task}: Checking formats of labels with {NUM_THREADS} process(es): "
|
|
||||||
)
|
|
||||||
with Pool(NUM_THREADS) as pool:
|
|
||||||
pbar = pool.imap(
|
|
||||||
TrainValDataset.check_label_files, zip(img_paths, label_paths)
|
|
||||||
)
|
|
||||||
pbar = tqdm(pbar, total=len(label_paths)) if self.main_process else pbar
|
|
||||||
for (
|
|
||||||
img_path,
|
|
||||||
labels_per_file,
|
|
||||||
nc_per_file,
|
|
||||||
nm_per_file,
|
|
||||||
nf_per_file,
|
|
||||||
ne_per_file,
|
|
||||||
msg,
|
|
||||||
) in pbar:
|
|
||||||
if nc_per_file == 0:
|
|
||||||
img_info[img_path]["labels"] = labels_per_file
|
|
||||||
else:
|
|
||||||
img_info.pop(img_path)
|
|
||||||
nc += nc_per_file
|
|
||||||
nm += nm_per_file
|
|
||||||
nf += nf_per_file
|
|
||||||
ne += ne_per_file
|
|
||||||
if msg:
|
|
||||||
msgs.append(msg)
|
|
||||||
if self.main_process:
|
|
||||||
pbar.desc = f"{nf} label(s) found, {nm} label(s) missing, {ne} label(s) empty, {nc} invalid label files"
|
|
||||||
if self.main_process:
|
|
||||||
pbar.close()
|
|
||||||
with open(valid_img_record, "w") as f:
|
|
||||||
json.dump(cache_info, f)
|
|
||||||
if msgs:
|
|
||||||
LOGGER.info("\n".join(msgs))
|
|
||||||
if nf == 0:
|
|
||||||
LOGGER.warning(
|
|
||||||
f"WARNING: No labels found in {osp.dirname(self.img_paths[0])}. "
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.task.lower() == "val":
|
|
||||||
if self.data_dict.get("is_coco", False): # use original json file when evaluating on coco dataset.
|
|
||||||
assert osp.exists(self.data_dict["anno_path"]), "Eval on coco dataset must provide valid path of the annotation file in config file: data/coco.yaml"
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
self.class_names
|
|
||||||
), "Class names is required when converting labels to coco format for evaluating."
|
|
||||||
save_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "annotations")
|
|
||||||
if not osp.exists(save_dir):
|
|
||||||
os.mkdir(save_dir)
|
|
||||||
save_path = osp.join(
|
|
||||||
save_dir, "instances_" + osp.basename(img_dir) + ".json"
|
|
||||||
)
|
|
||||||
TrainValDataset.generate_coco_format_labels(
|
|
||||||
img_info, self.class_names, save_path
|
|
||||||
)
|
|
||||||
|
|
||||||
img_paths, labels = list(
|
|
||||||
zip(
|
|
||||||
*[
|
|
||||||
(
|
|
||||||
img_path,
|
|
||||||
np.array(info["labels"], dtype=np.float32)
|
|
||||||
if info["labels"]
|
|
||||||
else np.zeros((0, 5), dtype=np.float32),
|
|
||||||
)
|
|
||||||
for img_path, info in img_info.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.img_info = img_info
|
|
||||||
LOGGER.info(
|
|
||||||
f"{self.task}: Final numbers of valid images: {len(img_paths)}/ labels: {len(labels)}. "
|
|
||||||
)
|
|
||||||
return img_paths, labels
|
|
||||||
|
|
||||||
def get_mosaic(self, index):
|
|
||||||
"""Gets images and labels after mosaic augments"""
|
|
||||||
indices = [index] + random.choices(
|
|
||||||
range(0, len(self.img_paths)), k=3
|
|
||||||
) # 3 additional image indices
|
|
||||||
random.shuffle(indices)
|
|
||||||
imgs, hs, ws, labels = [], [], [], []
|
|
||||||
for index in indices:
|
|
||||||
img, _, (h, w) = self.load_image(index)
|
|
||||||
labels_per_img = self.labels[index]
|
|
||||||
imgs.append(img)
|
|
||||||
hs.append(h)
|
|
||||||
ws.append(w)
|
|
||||||
labels.append(labels_per_img)
|
|
||||||
img, labels = mosaic_augmentation(self.img_size, imgs, hs, ws, labels, self.hyp)
|
|
||||||
return img, labels
|
|
||||||
|
|
||||||
def general_augment(self, img, labels):
|
|
||||||
"""Gets images and labels after general augment
|
|
||||||
This function applies hsv, random ud-flip and random lr-flips augments.
|
|
||||||
"""
|
|
||||||
nl = len(labels)
|
|
||||||
|
|
||||||
# HSV color-space
|
|
||||||
augment_hsv(
|
|
||||||
img,
|
|
||||||
hgain=self.hyp["hsv_h"],
|
|
||||||
sgain=self.hyp["hsv_s"],
|
|
||||||
vgain=self.hyp["hsv_v"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flip up-down
|
|
||||||
if random.random() < self.hyp["flipud"]:
|
|
||||||
img = np.flipud(img)
|
|
||||||
if nl:
|
|
||||||
labels[:, 2] = 1 - labels[:, 2]
|
|
||||||
|
|
||||||
# Flip left-right
|
|
||||||
if random.random() < self.hyp["fliplr"]:
|
|
||||||
img = np.fliplr(img)
|
|
||||||
if nl:
|
|
||||||
labels[:, 1] = 1 - labels[:, 1]
|
|
||||||
|
|
||||||
return img, labels
|
|
||||||
|
|
||||||
def sort_files_shapes(self):
|
|
||||||
# Sort by aspect ratio
|
|
||||||
batch_num = self.batch_indices[-1] + 1
|
|
||||||
s = self.shapes # wh
|
|
||||||
ar = s[:, 1] / s[:, 0] # aspect ratio
|
|
||||||
irect = ar.argsort()
|
|
||||||
self.img_paths = [self.img_paths[i] for i in irect]
|
|
||||||
self.labels = [self.labels[i] for i in irect]
|
|
||||||
self.shapes = s[irect] # wh
|
|
||||||
ar = ar[irect]
|
|
||||||
|
|
||||||
# Set training image shapes
|
|
||||||
shapes = [[1, 1]] * batch_num
|
|
||||||
for i in range(batch_num):
|
|
||||||
ari = ar[self.batch_indices == i]
|
|
||||||
mini, maxi = ari.min(), ari.max()
|
|
||||||
if maxi < 1:
|
|
||||||
shapes[i] = [maxi, 1]
|
|
||||||
elif mini > 1:
|
|
||||||
shapes[i] = [1, 1 / mini]
|
|
||||||
self.batch_shapes = (
|
|
||||||
np.ceil(np.array(shapes) * self.img_size / self.stride + self.pad).astype(
|
|
||||||
np.int
|
|
||||||
)
|
|
||||||
* self.stride
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_image(im_file):
|
|
||||||
# verify an image.
|
|
||||||
nc, msg = 0, ""
|
|
||||||
try:
|
|
||||||
im = Image.open(im_file)
|
|
||||||
im.verify() # PIL verify
|
|
||||||
shape = im.size # (width, height)
|
|
||||||
im_exif = im._getexif()
|
|
||||||
if im_exif and ORIENTATION in im_exif:
|
|
||||||
rotation = im_exif[ORIENTATION]
|
|
||||||
if rotation in (6, 8):
|
|
||||||
shape = (shape[1], shape[0])
|
|
||||||
|
|
||||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
|
||||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
|
||||||
if im.format.lower() in ("jpg", "jpeg"):
|
|
||||||
with open(im_file, "rb") as f:
|
|
||||||
f.seek(-2, 2)
|
|
||||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
|
||||||
ImageOps.exif_transpose(Image.open(im_file)).save(
|
|
||||||
im_file, "JPEG", subsampling=0, quality=100
|
|
||||||
)
|
|
||||||
msg += f"WARNING: {im_file}: corrupt JPEG restored and saved"
|
|
||||||
return im_file, shape, nc, msg
|
|
||||||
except Exception as e:
|
|
||||||
nc = 1
|
|
||||||
msg = f"WARNING: {im_file}: ignoring corrupt image: {e}"
|
|
||||||
return im_file, None, nc, msg
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_label_files(args):
|
|
||||||
img_path, lb_path = args
|
|
||||||
nm, nf, ne, nc, msg = 0, 0, 0, 0, "" # number (missing, found, empty, message
|
|
||||||
try:
|
|
||||||
if osp.exists(lb_path):
|
|
||||||
nf = 1 # label found
|
|
||||||
with open(lb_path, "r") as f:
|
|
||||||
labels = [
|
|
||||||
x.split() for x in f.read().strip().splitlines() if len(x)
|
|
||||||
]
|
|
||||||
labels = np.array(labels, dtype=np.float32)
|
|
||||||
if len(labels):
|
|
||||||
assert all(
|
|
||||||
len(l) == 5 for l in labels
|
|
||||||
), f"{lb_path}: wrong label format."
|
|
||||||
assert (
|
|
||||||
labels >= 0
|
|
||||||
).all(), f"{lb_path}: Label values error: all values in label file must > 0"
|
|
||||||
assert (
|
|
||||||
labels[:, 1:] <= 1
|
|
||||||
).all(), f"{lb_path}: Label values error: all coordinates must be normalized"
|
|
||||||
|
|
||||||
_, indices = np.unique(labels, axis=0, return_index=True)
|
|
||||||
if len(indices) < len(labels): # duplicate row check
|
|
||||||
labels = labels[indices] # remove duplicates
|
|
||||||
msg += f"WARNING: {lb_path}: {len(labels) - len(indices)} duplicate labels removed"
|
|
||||||
labels = labels.tolist()
|
|
||||||
else:
|
|
||||||
ne = 1 # label empty
|
|
||||||
labels = []
|
|
||||||
else:
|
|
||||||
nm = 1 # label missing
|
|
||||||
labels = []
|
|
||||||
|
|
||||||
return img_path, labels, nc, nm, nf, ne, msg
|
|
||||||
except Exception as e:
|
|
||||||
nc = 1
|
|
||||||
msg = f"WARNING: {lb_path}: ignoring invalid labels: {e}"
|
|
||||||
return img_path, None, nc, nm, nf, ne, msg
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_coco_format_labels(img_info, class_names, save_path):
|
|
||||||
# for evaluation with pycocotools
|
|
||||||
dataset = {"categories": [], "annotations": [], "images": []}
|
|
||||||
for i, class_name in enumerate(class_names):
|
|
||||||
dataset["categories"].append(
|
|
||||||
{"id": i, "name": class_name, "supercategory": ""}
|
|
||||||
)
|
|
||||||
|
|
||||||
ann_id = 0
|
|
||||||
LOGGER.info(f"Convert to COCO format")
|
|
||||||
for i, (img_path, info) in enumerate(tqdm(img_info.items())):
|
|
||||||
labels = info["labels"] if info["labels"] else []
|
|
||||||
img_id = osp.splitext(osp.basename(img_path))[0]
|
|
||||||
img_id = int(img_id) if img_id.isnumeric() else img_id
|
|
||||||
img_w, img_h = info["shape"]
|
|
||||||
dataset["images"].append(
|
|
||||||
{
|
|
||||||
"file_name": os.path.basename(img_path),
|
|
||||||
"id": img_id,
|
|
||||||
"width": img_w,
|
|
||||||
"height": img_h,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if labels:
|
|
||||||
for label in labels:
|
|
||||||
c, x, y, w, h = label[:5]
|
|
||||||
# convert x,y,w,h to x1,y1,x2,y2
|
|
||||||
x1 = (x - w / 2) * img_w
|
|
||||||
y1 = (y - h / 2) * img_h
|
|
||||||
x2 = (x + w / 2) * img_w
|
|
||||||
y2 = (y + h / 2) * img_h
|
|
||||||
# cls_id starts from 0
|
|
||||||
cls_id = int(c)
|
|
||||||
w = max(0, x2 - x1)
|
|
||||||
h = max(0, y2 - y1)
|
|
||||||
dataset["annotations"].append(
|
|
||||||
{
|
|
||||||
"area": h * w,
|
|
||||||
"bbox": [x1, y1, w, h],
|
|
||||||
"category_id": cls_id,
|
|
||||||
"id": ann_id,
|
|
||||||
"image_id": img_id,
|
|
||||||
"iscrowd": 0,
|
|
||||||
# mask
|
|
||||||
"segmentation": [],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
ann_id += 1
|
|
||||||
|
|
||||||
with open(save_path, "w") as f:
|
|
||||||
json.dump(dataset, f)
|
|
||||||
LOGGER.info(
|
|
||||||
f"Convert to COCO format finished. Resutls saved in {save_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_hash(paths):
|
|
||||||
"""Get the hash value of paths"""
|
|
||||||
assert isinstance(paths, list), "Only support list currently."
|
|
||||||
h = hashlib.md5("".join(paths).encode())
|
|
||||||
return h.hexdigest()
|
|
@ -1,501 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from yolov6.layers.dbb_transforms import *
|
|
||||||
|
|
||||||
|
|
||||||
class SiLU(nn.Module):
|
|
||||||
'''Activation of SiLU'''
|
|
||||||
@staticmethod
|
|
||||||
def forward(x):
|
|
||||||
return x * torch.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Conv(nn.Module):
|
|
||||||
'''Normal Conv with SiLU activation'''
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
|
|
||||||
super().__init__()
|
|
||||||
padding = kernel_size // 2
|
|
||||||
self.conv = nn.Conv2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
groups=groups,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.bn = nn.BatchNorm2d(out_channels)
|
|
||||||
self.act = nn.SiLU()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.act(self.bn(self.conv(x)))
|
|
||||||
|
|
||||||
def forward_fuse(self, x):
|
|
||||||
return self.act(self.conv(x))
|
|
||||||
|
|
||||||
|
|
||||||
class SimConv(nn.Module):
|
|
||||||
'''Normal Conv with ReLU activation'''
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
|
|
||||||
super().__init__()
|
|
||||||
padding = kernel_size // 2
|
|
||||||
self.conv = nn.Conv2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
groups=groups,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.bn = nn.BatchNorm2d(out_channels)
|
|
||||||
self.act = nn.ReLU()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.act(self.bn(self.conv(x)))
|
|
||||||
|
|
||||||
def forward_fuse(self, x):
|
|
||||||
return self.act(self.conv(x))
|
|
||||||
|
|
||||||
|
|
||||||
class SimSPPF(nn.Module):
|
|
||||||
'''Simplified SPPF with ReLU activation'''
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=5):
|
|
||||||
super().__init__()
|
|
||||||
c_ = in_channels // 2 # hidden channels
|
|
||||||
self.cv1 = SimConv(in_channels, c_, 1, 1)
|
|
||||||
self.cv2 = SimConv(c_ * 4, out_channels, 1, 1)
|
|
||||||
self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.cv1(x)
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter('ignore')
|
|
||||||
y1 = self.m(x)
|
|
||||||
y2 = self.m(y1)
|
|
||||||
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
|
|
||||||
|
|
||||||
|
|
||||||
class Transpose(nn.Module):
|
|
||||||
'''Normal Transpose, default for upsampling'''
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
|
|
||||||
super().__init__()
|
|
||||||
self.upsample_transpose = torch.nn.ConvTranspose2d(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
bias=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.upsample_transpose(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Concat(nn.Module):
|
|
||||||
def __init__(self, dimension=1):
|
|
||||||
super().__init__()
|
|
||||||
self.d = dimension
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return torch.cat(x, self.d)
|
|
||||||
|
|
||||||
|
|
||||||
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
|
|
||||||
'''Basic cell for rep-style block, including conv and bn'''
|
|
||||||
result = nn.Sequential()
|
|
||||||
result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
|
|
||||||
kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
|
|
||||||
result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class RepBlock(nn.Module):
|
|
||||||
'''
|
|
||||||
RepBlock is a stage block with rep-style basic block
|
|
||||||
'''
|
|
||||||
def __init__(self, in_channels, out_channels, n=1):
|
|
||||||
super().__init__()
|
|
||||||
self.conv1 = RepVGGBlock(in_channels, out_channels)
|
|
||||||
self.block = nn.Sequential(*(RepVGGBlock(out_channels, out_channels) for _ in range(n - 1))) if n > 1 else None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv1(x)
|
|
||||||
if self.block is not None:
|
|
||||||
x = self.block(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class RepVGGBlock(nn.Module):
|
|
||||||
'''RepVGGBlock is a basic rep-style block, including training and deploy status
|
|
||||||
This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
|
|
||||||
'''
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
|
||||||
stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
|
|
||||||
super(RepVGGBlock, self).__init__()
|
|
||||||
""" Initialization of the class.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Number of channels in the input image
|
|
||||||
out_channels (int): Number of channels produced by the convolution
|
|
||||||
kernel_size (int or tuple): Size of the convolving kernel
|
|
||||||
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
|
||||||
padding (int or tuple, optional): Zero-padding added to both sides of
|
|
||||||
the input. Default: 1
|
|
||||||
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
|
|
||||||
groups (int, optional): Number of blocked connections from input
|
|
||||||
channels to output channels. Default: 1
|
|
||||||
padding_mode (string, optional): Default: 'zeros'
|
|
||||||
deploy: Whether to be deploy status or training status. Default: False
|
|
||||||
use_se: Whether to use se. Default: False
|
|
||||||
"""
|
|
||||||
self.deploy = deploy
|
|
||||||
self.groups = groups
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
|
|
||||||
assert kernel_size == 3
|
|
||||||
assert padding == 1
|
|
||||||
|
|
||||||
padding_11 = padding - kernel_size // 2
|
|
||||||
|
|
||||||
self.nonlinearity = nn.ReLU()
|
|
||||||
|
|
||||||
if use_se:
|
|
||||||
raise NotImplementedError("se block not supported yet")
|
|
||||||
else:
|
|
||||||
self.se = nn.Identity()
|
|
||||||
|
|
||||||
if deploy:
|
|
||||||
self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
|
|
||||||
padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
|
|
||||||
self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
|
|
||||||
self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
'''Forward process'''
|
|
||||||
if hasattr(self, 'rbr_reparam'):
|
|
||||||
return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
|
|
||||||
|
|
||||||
if self.rbr_identity is None:
|
|
||||||
id_out = 0
|
|
||||||
else:
|
|
||||||
id_out = self.rbr_identity(inputs)
|
|
||||||
|
|
||||||
return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
|
|
||||||
|
|
||||||
def get_equivalent_kernel_bias(self):
|
|
||||||
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
|
|
||||||
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
|
|
||||||
kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
|
|
||||||
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
|
|
||||||
|
|
||||||
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
|
||||||
if kernel1x1 is None:
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
|
|
||||||
|
|
||||||
def _fuse_bn_tensor(self, branch):
|
|
||||||
if branch is None:
|
|
||||||
return 0, 0
|
|
||||||
if isinstance(branch, nn.Sequential):
|
|
||||||
kernel = branch.conv.weight
|
|
||||||
running_mean = branch.bn.running_mean
|
|
||||||
running_var = branch.bn.running_var
|
|
||||||
gamma = branch.bn.weight
|
|
||||||
beta = branch.bn.bias
|
|
||||||
eps = branch.bn.eps
|
|
||||||
else:
|
|
||||||
assert isinstance(branch, nn.BatchNorm2d)
|
|
||||||
if not hasattr(self, 'id_tensor'):
|
|
||||||
input_dim = self.in_channels // self.groups
|
|
||||||
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
|
|
||||||
for i in range(self.in_channels):
|
|
||||||
kernel_value[i, i % input_dim, 1, 1] = 1
|
|
||||||
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
|
|
||||||
kernel = self.id_tensor
|
|
||||||
running_mean = branch.running_mean
|
|
||||||
running_var = branch.running_var
|
|
||||||
gamma = branch.weight
|
|
||||||
beta = branch.bias
|
|
||||||
eps = branch.eps
|
|
||||||
std = (running_var + eps).sqrt()
|
|
||||||
t = (gamma / std).reshape(-1, 1, 1, 1)
|
|
||||||
return kernel * t, beta - running_mean * gamma / std
|
|
||||||
|
|
||||||
def switch_to_deploy(self):
|
|
||||||
if hasattr(self, 'rbr_reparam'):
|
|
||||||
return
|
|
||||||
kernel, bias = self.get_equivalent_kernel_bias()
|
|
||||||
self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
|
|
||||||
kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
|
|
||||||
padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
|
|
||||||
self.rbr_reparam.weight.data = kernel
|
|
||||||
self.rbr_reparam.bias.data = bias
|
|
||||||
for para in self.parameters():
|
|
||||||
para.detach_()
|
|
||||||
self.__delattr__('rbr_dense')
|
|
||||||
self.__delattr__('rbr_1x1')
|
|
||||||
if hasattr(self, 'rbr_identity'):
|
|
||||||
self.__delattr__('rbr_identity')
|
|
||||||
if hasattr(self, 'id_tensor'):
|
|
||||||
self.__delattr__('id_tensor')
|
|
||||||
self.deploy = True
|
|
||||||
|
|
||||||
|
|
||||||
def conv_bn_v2(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
|
|
||||||
padding_mode='zeros'):
|
|
||||||
conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
|
|
||||||
stride=stride, padding=padding, dilation=dilation, groups=groups,
|
|
||||||
bias=False, padding_mode=padding_mode)
|
|
||||||
bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
|
|
||||||
se = nn.Sequential()
|
|
||||||
se.add_module('conv', conv_layer)
|
|
||||||
se.add_module('bn', bn_layer)
|
|
||||||
return se
|
|
||||||
|
|
||||||
|
|
||||||
class IdentityBasedConv1x1(nn.Conv2d):
|
|
||||||
|
|
||||||
def __init__(self, channels, groups=1):
|
|
||||||
super(IdentityBasedConv1x1, self).__init__(in_channels=channels, out_channels=channels, kernel_size=1, stride=1, padding=0, groups=groups, bias=False)
|
|
||||||
|
|
||||||
assert channels % groups == 0
|
|
||||||
input_dim = channels // groups
|
|
||||||
id_value = np.zeros((channels, input_dim, 1, 1))
|
|
||||||
for i in range(channels):
|
|
||||||
id_value[i, i % input_dim, 0, 0] = 1
|
|
||||||
self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
|
|
||||||
nn.init.zeros_(self.weight)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
kernel = self.weight + self.id_tensor.to(self.weight.device)
|
|
||||||
result = F.conv2d(input, kernel, None, stride=1, padding=0, dilation=self.dilation, groups=self.groups)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_actual_kernel(self):
|
|
||||||
return self.weight + self.id_tensor.to(self.weight.device)
|
|
||||||
|
|
||||||
|
|
||||||
class BNAndPadLayer(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
pad_pixels,
|
|
||||||
num_features,
|
|
||||||
eps=1e-5,
|
|
||||||
momentum=0.1,
|
|
||||||
affine=True,
|
|
||||||
track_running_stats=True):
|
|
||||||
super(BNAndPadLayer, self).__init__()
|
|
||||||
self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
|
|
||||||
self.pad_pixels = pad_pixels
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
output = self.bn(input)
|
|
||||||
if self.pad_pixels > 0:
|
|
||||||
if self.bn.affine:
|
|
||||||
pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(self.bn.running_var + self.bn.eps)
|
|
||||||
else:
|
|
||||||
pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)
|
|
||||||
output = F.pad(output, [self.pad_pixels] * 4)
|
|
||||||
pad_values = pad_values.view(1, -1, 1, 1)
|
|
||||||
output[:, :, 0:self.pad_pixels, :] = pad_values
|
|
||||||
output[:, :, -self.pad_pixels:, :] = pad_values
|
|
||||||
output[:, :, :, 0:self.pad_pixels] = pad_values
|
|
||||||
output[:, :, :, -self.pad_pixels:] = pad_values
|
|
||||||
return output
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bn_weight(self):
|
|
||||||
return self.bn.weight
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bn_bias(self):
|
|
||||||
return self.bn.bias
|
|
||||||
|
|
||||||
@property
|
|
||||||
def running_mean(self):
|
|
||||||
return self.bn.running_mean
|
|
||||||
|
|
||||||
@property
|
|
||||||
def running_var(self):
|
|
||||||
return self.bn.running_var
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eps(self):
|
|
||||||
return self.bn.eps
|
|
||||||
|
|
||||||
|
|
||||||
class DBBBlock(nn.Module):
|
|
||||||
'''
|
|
||||||
RepBlock is a stage block with rep-style basic block
|
|
||||||
'''
|
|
||||||
def __init__(self, in_channels, out_channels, n=1):
|
|
||||||
super().__init__()
|
|
||||||
self.conv1 = DiverseBranchBlock(in_channels, out_channels)
|
|
||||||
self.block = nn.Sequential(*(DiverseBranchBlock(out_channels, out_channels) for _ in range(n - 1))) if n > 1 else None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv1(x)
|
|
||||||
if self.block is not None:
|
|
||||||
x = self.block(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DiverseBranchBlock(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
|
||||||
stride=1, padding=1, dilation=1, groups=1,
|
|
||||||
internal_channels_1x1_3x3=None,
|
|
||||||
deploy=False, nonlinear=nn.ReLU(), single_init=False):
|
|
||||||
super(DiverseBranchBlock, self).__init__()
|
|
||||||
self.deploy = deploy
|
|
||||||
|
|
||||||
if nonlinear is None:
|
|
||||||
self.nonlinear = nn.Identity()
|
|
||||||
else:
|
|
||||||
self.nonlinear = nonlinear
|
|
||||||
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.groups = groups
|
|
||||||
assert padding == kernel_size // 2
|
|
||||||
|
|
||||||
if deploy:
|
|
||||||
self.dbb_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
|
|
||||||
padding=padding, dilation=dilation, groups=groups, bias=True)
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
self.dbb_origin = conv_bn_v2(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
|
||||||
|
|
||||||
self.dbb_avg = nn.Sequential()
|
|
||||||
if groups < out_channels:
|
|
||||||
self.dbb_avg.add_module('conv',
|
|
||||||
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
|
|
||||||
stride=1, padding=0, groups=groups, bias=False))
|
|
||||||
self.dbb_avg.add_module('bn', BNAndPadLayer(pad_pixels=padding, num_features=out_channels))
|
|
||||||
self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
|
|
||||||
self.dbb_1x1 = conv_bn_v2(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,
|
|
||||||
padding=0, groups=groups)
|
|
||||||
else:
|
|
||||||
self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding))
|
|
||||||
|
|
||||||
self.dbb_avg.add_module('avgbn', nn.BatchNorm2d(out_channels))
|
|
||||||
|
|
||||||
if internal_channels_1x1_3x3 is None:
|
|
||||||
internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels
|
|
||||||
|
|
||||||
self.dbb_1x1_kxk = nn.Sequential()
|
|
||||||
if internal_channels_1x1_3x3 == in_channels:
|
|
||||||
self.dbb_1x1_kxk.add_module('idconv1', IdentityBasedConv1x1(channels=in_channels, groups=groups))
|
|
||||||
else:
|
|
||||||
self.dbb_1x1_kxk.add_module('conv1', nn.Conv2d(in_channels=in_channels, out_channels=internal_channels_1x1_3x3,
|
|
||||||
kernel_size=1, stride=1, padding=0, groups=groups, bias=False))
|
|
||||||
self.dbb_1x1_kxk.add_module('bn1', BNAndPadLayer(pad_pixels=padding, num_features=internal_channels_1x1_3x3, affine=True))
|
|
||||||
self.dbb_1x1_kxk.add_module('conv2', nn.Conv2d(in_channels=internal_channels_1x1_3x3, out_channels=out_channels,
|
|
||||||
kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=False))
|
|
||||||
self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))
|
|
||||||
|
|
||||||
# The experiments reported in the paper used the default initialization of bn.weight (all as 1). But changing the initialization may be useful in some cases.
|
|
||||||
if single_init:
|
|
||||||
# Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting.
|
|
||||||
self.single_init()
|
|
||||||
|
|
||||||
def get_equivalent_kernel_bias(self):
|
|
||||||
k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight, self.dbb_origin.bn)
|
|
||||||
|
|
||||||
if hasattr(self, 'dbb_1x1'):
|
|
||||||
k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn)
|
|
||||||
k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
|
|
||||||
else:
|
|
||||||
k_1x1, b_1x1 = 0, 0
|
|
||||||
|
|
||||||
if hasattr(self.dbb_1x1_kxk, 'idconv1'):
|
|
||||||
k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
|
|
||||||
else:
|
|
||||||
k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
|
|
||||||
k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first, self.dbb_1x1_kxk.bn1)
|
|
||||||
k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2)
|
|
||||||
k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first, b_1x1_kxk_first, k_1x1_kxk_second, b_1x1_kxk_second, groups=self.groups)
|
|
||||||
|
|
||||||
k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
|
|
||||||
k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device), self.dbb_avg.avgbn)
|
|
||||||
if hasattr(self.dbb_avg, 'conv'):
|
|
||||||
k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight, self.dbb_avg.bn)
|
|
||||||
k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first, b_1x1_avg_first, k_1x1_avg_second, b_1x1_avg_second, groups=self.groups)
|
|
||||||
else:
|
|
||||||
k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
|
|
||||||
|
|
||||||
return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged), (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))
|
|
||||||
|
|
||||||
def switch_to_deploy(self):
|
|
||||||
if hasattr(self, 'dbb_reparam'):
|
|
||||||
return
|
|
||||||
kernel, bias = self.get_equivalent_kernel_bias()
|
|
||||||
self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels, out_channels=self.dbb_origin.conv.out_channels,
|
|
||||||
kernel_size=self.dbb_origin.conv.kernel_size, stride=self.dbb_origin.conv.stride,
|
|
||||||
padding=self.dbb_origin.conv.padding, dilation=self.dbb_origin.conv.dilation, groups=self.dbb_origin.conv.groups, bias=True)
|
|
||||||
self.dbb_reparam.weight.data = kernel
|
|
||||||
self.dbb_reparam.bias.data = bias
|
|
||||||
for para in self.parameters():
|
|
||||||
para.detach_()
|
|
||||||
self.__delattr__('dbb_origin')
|
|
||||||
self.__delattr__('dbb_avg')
|
|
||||||
if hasattr(self, 'dbb_1x1'):
|
|
||||||
self.__delattr__('dbb_1x1')
|
|
||||||
self.__delattr__('dbb_1x1_kxk')
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
|
|
||||||
if hasattr(self, 'dbb_reparam'):
|
|
||||||
return self.nonlinear(self.dbb_reparam(inputs))
|
|
||||||
|
|
||||||
out = self.dbb_origin(inputs)
|
|
||||||
if hasattr(self, 'dbb_1x1'):
|
|
||||||
out += self.dbb_1x1(inputs)
|
|
||||||
out += self.dbb_avg(inputs)
|
|
||||||
out += self.dbb_1x1_kxk(inputs)
|
|
||||||
return self.nonlinear(out)
|
|
||||||
|
|
||||||
def init_gamma(self, gamma_value):
|
|
||||||
if hasattr(self, "dbb_origin"):
|
|
||||||
torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value)
|
|
||||||
if hasattr(self, "dbb_1x1"):
|
|
||||||
torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value)
|
|
||||||
if hasattr(self, "dbb_avg"):
|
|
||||||
torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value)
|
|
||||||
if hasattr(self, "dbb_1x1_kxk"):
|
|
||||||
torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value)
|
|
||||||
|
|
||||||
def single_init(self):
|
|
||||||
self.init_gamma(0.0)
|
|
||||||
if hasattr(self, "dbb_origin"):
|
|
||||||
torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
|
|
||||||
|
|
||||||
|
|
||||||
class DetectBackend(nn.Module):
|
|
||||||
def __init__(self, weights='yolov6s.pt', device=None, dnn=True):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
assert isinstance(weights, str) and Path(weights).suffix == '.pt', f'{Path(weights).suffix} format is not supported.'
|
|
||||||
from yolov6.utils.checkpoint import load_checkpoint
|
|
||||||
model = load_checkpoint(weights, map_location=device)
|
|
||||||
stride = int(model.stride.max())
|
|
||||||
self.__dict__.update(locals()) # assign all variables to self
|
|
||||||
|
|
||||||
def forward(self, im, val=False):
|
|
||||||
y = self.model(im)
|
|
||||||
if isinstance(y, np.ndarray):
|
|
||||||
y = torch.tensor(y, device=self.device)
|
|
||||||
return y
|
|
@ -1,50 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def transI_fusebn(kernel, bn):
|
|
||||||
gamma = bn.weight
|
|
||||||
std = (bn.running_var + bn.eps).sqrt()
|
|
||||||
return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std
|
|
||||||
|
|
||||||
|
|
||||||
def transII_addbranch(kernels, biases):
|
|
||||||
return sum(kernels), sum(biases)
|
|
||||||
|
|
||||||
|
|
||||||
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
|
|
||||||
if groups == 1:
|
|
||||||
k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) #
|
|
||||||
b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
|
|
||||||
else:
|
|
||||||
k_slices = []
|
|
||||||
b_slices = []
|
|
||||||
k1_T = k1.permute(1, 0, 2, 3)
|
|
||||||
k1_group_width = k1.size(0) // groups
|
|
||||||
k2_group_width = k2.size(0) // groups
|
|
||||||
for g in range(groups):
|
|
||||||
k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
|
|
||||||
k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
|
|
||||||
k_slices.append(F.conv2d(k2_slice, k1_T_slice))
|
|
||||||
b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
|
|
||||||
k, b_hat = transIV_depthconcat(k_slices, b_slices)
|
|
||||||
return k, b_hat + b2
|
|
||||||
|
|
||||||
|
|
||||||
def transIV_depthconcat(kernels, biases):
|
|
||||||
return torch.cat(kernels, dim=0), torch.cat(biases)
|
|
||||||
|
|
||||||
|
|
||||||
def transV_avg(channels, kernel_size, groups):
|
|
||||||
input_dim = channels // groups
|
|
||||||
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
|
|
||||||
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
# This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
|
|
||||||
def transVI_multiscale(kernel, target_kernel_size):
|
|
||||||
H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
|
|
||||||
W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
|
|
||||||
return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])
|
|
@ -1,102 +0,0 @@
|
|||||||
from torch import nn
|
|
||||||
from yolov6.layers.common import RepVGGBlock, RepBlock, SimSPPF
|
|
||||||
|
|
||||||
|
|
||||||
class EfficientRep(nn.Module):
|
|
||||||
'''EfficientRep Backbone
|
|
||||||
EfficientRep is handcrafted by hardware-aware neural network design.
|
|
||||||
With rep-style struct, EfficientRep is friendly to high-computation hardware(e.g. GPU).
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels=3,
|
|
||||||
channels_list=None,
|
|
||||||
num_repeats=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
assert channels_list is not None
|
|
||||||
assert num_repeats is not None
|
|
||||||
|
|
||||||
self.stem = RepVGGBlock(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=channels_list[0],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ERBlock_2 = nn.Sequential(
|
|
||||||
RepVGGBlock(
|
|
||||||
in_channels=channels_list[0],
|
|
||||||
out_channels=channels_list[1],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2
|
|
||||||
),
|
|
||||||
RepBlock(
|
|
||||||
in_channels=channels_list[1],
|
|
||||||
out_channels=channels_list[1],
|
|
||||||
n=num_repeats[1]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ERBlock_3 = nn.Sequential(
|
|
||||||
RepVGGBlock(
|
|
||||||
in_channels=channels_list[1],
|
|
||||||
out_channels=channels_list[2],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2
|
|
||||||
),
|
|
||||||
RepBlock(
|
|
||||||
in_channels=channels_list[2],
|
|
||||||
out_channels=channels_list[2],
|
|
||||||
n=num_repeats[2]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ERBlock_4 = nn.Sequential(
|
|
||||||
RepVGGBlock(
|
|
||||||
in_channels=channels_list[2],
|
|
||||||
out_channels=channels_list[3],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2
|
|
||||||
),
|
|
||||||
RepBlock(
|
|
||||||
in_channels=channels_list[3],
|
|
||||||
out_channels=channels_list[3],
|
|
||||||
n=num_repeats[3]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ERBlock_5 = nn.Sequential(
|
|
||||||
RepVGGBlock(
|
|
||||||
in_channels=channels_list[3],
|
|
||||||
out_channels=channels_list[4],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2,
|
|
||||||
),
|
|
||||||
RepBlock(
|
|
||||||
in_channels=channels_list[4],
|
|
||||||
out_channels=channels_list[4],
|
|
||||||
n=num_repeats[4]
|
|
||||||
),
|
|
||||||
SimSPPF(
|
|
||||||
in_channels=channels_list[4],
|
|
||||||
out_channels=channels_list[4],
|
|
||||||
kernel_size=5
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
x = self.stem(x)
|
|
||||||
x = self.ERBlock_2(x)
|
|
||||||
x = self.ERBlock_3(x)
|
|
||||||
outputs.append(x)
|
|
||||||
x = self.ERBlock_4(x)
|
|
||||||
outputs.append(x)
|
|
||||||
x = self.ERBlock_5(x)
|
|
||||||
outputs.append(x)
|
|
||||||
|
|
||||||
return tuple(outputs)
|
|
@ -1,211 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import math
|
|
||||||
from yolov6.layers.common import *
|
|
||||||
|
|
||||||
|
|
||||||
class Detect(nn.Module):
|
|
||||||
'''Efficient Decoupled Head
|
|
||||||
With hardware-aware degisn, the decoupled head is optimized with
|
|
||||||
hybridchannels methods.
|
|
||||||
'''
|
|
||||||
def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_layers=None): # detection layer
|
|
||||||
super().__init__()
|
|
||||||
assert head_layers is not None
|
|
||||||
self.nc = num_classes # number of classes
|
|
||||||
self.no = num_classes + 5 # number of outputs per anchor
|
|
||||||
self.nl = num_layers # number of detection layers
|
|
||||||
if isinstance(anchors, (list, tuple)):
|
|
||||||
self.na = len(anchors[0]) // 2
|
|
||||||
else:
|
|
||||||
self.na = anchors
|
|
||||||
self.anchors = anchors
|
|
||||||
self.grid = [torch.zeros(1)] * num_layers
|
|
||||||
self.prior_prob = 1e-2
|
|
||||||
self.inplace = inplace
|
|
||||||
stride = [8, 16, 32] # strides computed during build
|
|
||||||
self.stride = torch.tensor(stride)
|
|
||||||
|
|
||||||
# Init decouple head
|
|
||||||
self.cls_convs = nn.ModuleList()
|
|
||||||
self.reg_convs = nn.ModuleList()
|
|
||||||
self.cls_preds = nn.ModuleList()
|
|
||||||
self.reg_preds = nn.ModuleList()
|
|
||||||
self.obj_preds = nn.ModuleList()
|
|
||||||
self.stems = nn.ModuleList()
|
|
||||||
|
|
||||||
# Efficient decoupled head layers
|
|
||||||
for i in range(num_layers):
|
|
||||||
idx = i*6
|
|
||||||
self.stems.append(head_layers[idx])
|
|
||||||
self.cls_convs.append(head_layers[idx+1])
|
|
||||||
self.reg_convs.append(head_layers[idx+2])
|
|
||||||
self.cls_preds.append(head_layers[idx+3])
|
|
||||||
self.reg_preds.append(head_layers[idx+4])
|
|
||||||
self.obj_preds.append(head_layers[idx+5])
|
|
||||||
|
|
||||||
def initialize_biases(self):
|
|
||||||
for conv in self.cls_preds:
|
|
||||||
b = conv.bias.view(self.na, -1)
|
|
||||||
b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))
|
|
||||||
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
|
||||||
for conv in self.obj_preds:
|
|
||||||
b = conv.bias.view(self.na, -1)
|
|
||||||
b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))
|
|
||||||
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
z = []
|
|
||||||
for i in range(self.nl):
|
|
||||||
x[i] = self.stems[i](x[i])
|
|
||||||
cls_x = x[i]
|
|
||||||
reg_x = x[i]
|
|
||||||
cls_feat = self.cls_convs[i](cls_x)
|
|
||||||
cls_output = self.cls_preds[i](cls_feat)
|
|
||||||
reg_feat = self.reg_convs[i](reg_x)
|
|
||||||
reg_output = self.reg_preds[i](reg_feat)
|
|
||||||
obj_output = self.obj_preds[i](reg_feat)
|
|
||||||
if self.training:
|
|
||||||
x[i] = torch.cat([reg_output, obj_output, cls_output], 1)
|
|
||||||
bs, _, ny, nx = x[i].shape
|
|
||||||
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
|
||||||
else:
|
|
||||||
y = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
|
|
||||||
bs, _, ny, nx = y.shape
|
|
||||||
y = y.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
|
||||||
if self.grid[i].shape[2:4] != y.shape[2:4]:
|
|
||||||
d = self.stride.device
|
|
||||||
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
|
|
||||||
self.grid[i] = torch.stack((xv, yv), 2).view(1, self.na, ny, nx, 2).float()
|
|
||||||
if self.inplace:
|
|
||||||
y[..., 0:2] = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy
|
|
||||||
y[..., 2:4] = torch.exp(y[..., 2:4]) * self.stride[i] # wh
|
|
||||||
else:
|
|
||||||
xy = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy
|
|
||||||
wh = torch.exp(y[..., 2:4]) * self.stride[i] # wh
|
|
||||||
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
|
||||||
z.append(y.view(bs, -1, self.no))
|
|
||||||
return x if self.training else torch.cat(z, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def build_effidehead_layer(channels_list, num_anchors, num_classes):
|
|
||||||
head_layers = nn.Sequential(
|
|
||||||
# stem0
|
|
||||||
Conv(
|
|
||||||
in_channels=channels_list[6],
|
|
||||||
out_channels=channels_list[6],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1
|
|
||||||
),
|
|
||||||
# cls_conv0
|
|
||||||
Conv(
|
|
||||||
in_channels=channels_list[6],
|
|
||||||
out_channels=channels_list[6],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1
|
|
||||||
),
|
|
||||||
# reg_conv0
|
|
||||||
Conv(
|
|
||||||
in_channels=channels_list[6],
|
|
||||||
out_channels=channels_list[6],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1
|
|
||||||
),
|
|
||||||
# cls_pred0
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=channels_list[6],
|
|
||||||
out_channels=num_classes * num_anchors,
|
|
||||||
kernel_size=1
|
|
||||||
),
|
|
||||||
# reg_pred0
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=channels_list[6],
|
|
||||||
out_channels=4 * num_anchors,
|
|
||||||
kernel_size=1
|
|
||||||
),
|
|
||||||
# obj_pred0
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=channels_list[6],
|
|
||||||
out_channels=1 * num_anchors,
|
|
||||||
kernel_size=1
|
|
||||||
),
|
|
||||||
# stem1
|
|
||||||
Conv(
|
|
||||||
in_channels=channels_list[8],
|
|
||||||
out_channels=channels_list[8],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1
|
|
||||||
),
|
|
||||||
# cls_conv1
|
|
||||||
Conv(
|
|
||||||
in_channels=channels_list[8],
|
|
||||||
out_channels=channels_list[8],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1
|
|
||||||
),
|
|
||||||
# reg_conv1
|
|
||||||
Conv(
|
|
||||||
in_channels=channels_list[8],
|
|
||||||
out_channels=channels_list[8],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1
|
|
||||||
),
|
|
||||||
# cls_pred1
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=channels_list[8],
|
|
||||||
out_channels=num_classes * num_anchors,
|
|
||||||
kernel_size=1
|
|
||||||
),
|
|
||||||
# reg_pred1
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=channels_list[8],
|
|
||||||
out_channels=4 * num_anchors,
|
|
||||||
kernel_size=1
|
|
||||||
),
|
|
||||||
# obj_pred1
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=channels_list[8],
|
|
||||||
out_channels=1 * num_anchors,
|
|
||||||
kernel_size=1
|
|
||||||
),
|
|
||||||
# stem2
|
|
||||||
Conv(
|
|
||||||
in_channels=channels_list[10],
|
|
||||||
out_channels=channels_list[10],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1
|
|
||||||
),
|
|
||||||
# cls_conv2
|
|
||||||
Conv(
|
|
||||||
in_channels=channels_list[10],
|
|
||||||
out_channels=channels_list[10],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1
|
|
||||||
),
|
|
||||||
# reg_conv2
|
|
||||||
Conv(
|
|
||||||
in_channels=channels_list[10],
|
|
||||||
out_channels=channels_list[10],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1
|
|
||||||
),
|
|
||||||
# cls_pred2
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=channels_list[10],
|
|
||||||
out_channels=num_classes * num_anchors,
|
|
||||||
kernel_size=1
|
|
||||||
),
|
|
||||||
# reg_pred2
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=channels_list[10],
|
|
||||||
out_channels=4 * num_anchors,
|
|
||||||
kernel_size=1
|
|
||||||
),
|
|
||||||
# obj_pred2
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=channels_list[10],
|
|
||||||
out_channels=1 * num_anchors,
|
|
||||||
kernel_size=1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return head_layers
|
|
@ -1,151 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import random
|
|
||||||
|
|
||||||
class ORT_NMS(torch.autograd.Function):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx,
|
|
||||||
boxes,
|
|
||||||
scores,
|
|
||||||
max_output_boxes_per_class=torch.tensor([100]),
|
|
||||||
iou_threshold=torch.tensor([0.45]),
|
|
||||||
score_threshold=torch.tensor([0.25])):
|
|
||||||
device = boxes.device
|
|
||||||
batch = scores.shape[0]
|
|
||||||
num_det = random.randint(0, 100)
|
|
||||||
batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
|
|
||||||
idxs = torch.arange(100, 100 + num_det).to(device)
|
|
||||||
zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
|
|
||||||
selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
|
|
||||||
selected_indices = selected_indices.to(torch.int64)
|
|
||||||
return selected_indices
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
|
|
||||||
return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
|
|
||||||
|
|
||||||
class TRT_NMS(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(
|
|
||||||
ctx,
|
|
||||||
boxes,
|
|
||||||
scores,
|
|
||||||
background_class=-1,
|
|
||||||
box_coding=0,
|
|
||||||
iou_threshold=0.45,
|
|
||||||
max_output_boxes=100,
|
|
||||||
plugin_version="1",
|
|
||||||
score_activation=0,
|
|
||||||
score_threshold=0.25,
|
|
||||||
):
|
|
||||||
batch_size, num_boxes, num_classes = scores.shape
|
|
||||||
num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
|
|
||||||
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
|
|
||||||
det_scores = torch.randn(batch_size, max_output_boxes)
|
|
||||||
det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
|
|
||||||
|
|
||||||
return num_det, det_boxes, det_scores, det_classes
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def symbolic(g,
|
|
||||||
boxes,
|
|
||||||
scores,
|
|
||||||
background_class=-1,
|
|
||||||
box_coding=0,
|
|
||||||
iou_threshold=0.45,
|
|
||||||
max_output_boxes=100,
|
|
||||||
plugin_version="1",
|
|
||||||
score_activation=0,
|
|
||||||
score_threshold=0.25):
|
|
||||||
out = g.op("TRT::EfficientNMS_TRT",
|
|
||||||
boxes,
|
|
||||||
scores,
|
|
||||||
background_class_i=background_class,
|
|
||||||
box_coding_i=box_coding,
|
|
||||||
iou_threshold_f=iou_threshold,
|
|
||||||
max_output_boxes_i=max_output_boxes,
|
|
||||||
plugin_version_s=plugin_version,
|
|
||||||
score_activation_i=score_activation,
|
|
||||||
score_threshold_f=score_threshold,
|
|
||||||
outputs=4)
|
|
||||||
nums, boxes, scores, classes = out
|
|
||||||
return nums,boxes,scores,classes
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ONNX_ORT(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None):
|
|
||||||
super().__init__()
|
|
||||||
self.device = device if device else torch.device("cpu")
|
|
||||||
self.max_obj = torch.tensor([max_obj]).to(device)
|
|
||||||
self.iou_threshold = torch.tensor([iou_thres]).to(device)
|
|
||||||
self.score_threshold = torch.tensor([score_thres]).to(device)
|
|
||||||
self.max_wh = max_wh
|
|
||||||
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
box = x[:, :, :4]
|
|
||||||
conf = x[:, :, 4:5]
|
|
||||||
score = x[:, :, 5:]
|
|
||||||
score *= conf
|
|
||||||
box @= self.convert_matrix
|
|
||||||
objScore, objCls = score.max(2, keepdim=True)
|
|
||||||
dis = objCls.float() * self.max_wh
|
|
||||||
nmsbox = box + dis
|
|
||||||
objScore1 = objScore.transpose(1, 2).contiguous()
|
|
||||||
selected_indices = ORT_NMS.apply(nmsbox, objScore1, self.max_obj, self.iou_threshold, self.score_threshold)
|
|
||||||
X, Y = selected_indices[:, 0], selected_indices[:, 2]
|
|
||||||
resBoxes = box[X, Y, :]
|
|
||||||
resClasses = objCls[X, Y, :].float()
|
|
||||||
resScores = objScore[X, Y, :]
|
|
||||||
X = X.unsqueeze(1).float()
|
|
||||||
return torch.concat([X, resBoxes, resClasses, resScores], 1)
|
|
||||||
|
|
||||||
class ONNX_TRT(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None):
|
|
||||||
super().__init__()
|
|
||||||
assert max_wh is None
|
|
||||||
self.device = device if device else torch.device('cpu')
|
|
||||||
self.background_class = -1,
|
|
||||||
self.box_coding = 0,
|
|
||||||
self.iou_threshold = iou_thres
|
|
||||||
self.max_obj = max_obj
|
|
||||||
self.plugin_version = '1'
|
|
||||||
self.score_activation = 0
|
|
||||||
self.score_threshold = score_thres
|
|
||||||
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
box = x[:, :, :4]
|
|
||||||
conf = x[:, :, 4:5]
|
|
||||||
score = x[:, :, 5:]
|
|
||||||
score *= conf
|
|
||||||
box @= self.convert_matrix
|
|
||||||
num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(box, score, self.background_class, self.box_coding,
|
|
||||||
self.iou_threshold, self.max_obj,
|
|
||||||
self.plugin_version, self.score_activation,
|
|
||||||
self.score_threshold)
|
|
||||||
return num_det, det_boxes, det_scores, det_classes
|
|
||||||
|
|
||||||
|
|
||||||
class End2End(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
|
|
||||||
super().__init__()
|
|
||||||
device = device if device else torch.device('cpu')
|
|
||||||
self.model = model.to(device)
|
|
||||||
self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
|
|
||||||
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device)
|
|
||||||
self.end2end.eval()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.model(x)
|
|
||||||
x = self.end2end(x)
|
|
||||||
return x
|
|
@ -1,411 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
|
|
||||||
# The code is based on
|
|
||||||
# https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
|
|
||||||
# Copyright (c) Megvii, Inc. and its affiliates.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from yolov6.utils.figure_iou import IOUloss, pairwise_bbox_iou
|
|
||||||
|
|
||||||
|
|
||||||
class ComputeLoss:
|
|
||||||
'''Loss computation func.
|
|
||||||
This func contains SimOTA and siou loss.
|
|
||||||
'''
|
|
||||||
def __init__(self,
|
|
||||||
reg_weight=5.0,
|
|
||||||
iou_weight=3.0,
|
|
||||||
cls_weight=1.0,
|
|
||||||
center_radius=2.5,
|
|
||||||
eps=1e-7,
|
|
||||||
in_channels=[256, 512, 1024],
|
|
||||||
strides=[8, 16, 32],
|
|
||||||
n_anchors=1,
|
|
||||||
iou_type='ciou'
|
|
||||||
):
|
|
||||||
|
|
||||||
self.reg_weight = reg_weight
|
|
||||||
self.iou_weight = iou_weight
|
|
||||||
self.cls_weight = cls_weight
|
|
||||||
|
|
||||||
self.center_radius = center_radius
|
|
||||||
self.eps = eps
|
|
||||||
self.n_anchors = n_anchors
|
|
||||||
self.strides = strides
|
|
||||||
self.grids = [torch.zeros(1)] * len(in_channels)
|
|
||||||
|
|
||||||
# Define criteria
|
|
||||||
self.l1_loss = nn.L1Loss(reduction="none")
|
|
||||||
self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
|
|
||||||
self.iou_loss = IOUloss(iou_type=iou_type, reduction="none")
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
outputs,
|
|
||||||
targets
|
|
||||||
):
|
|
||||||
dtype = outputs[0].type()
|
|
||||||
device = targets.device
|
|
||||||
loss_cls, loss_obj, loss_iou, loss_l1 = torch.zeros(1, device=device), torch.zeros(1, device=device), \
|
|
||||||
torch.zeros(1, device=device), torch.zeros(1, device=device)
|
|
||||||
num_classes = outputs[0].shape[-1] - 5
|
|
||||||
|
|
||||||
outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides = self.get_outputs_and_grids(
|
|
||||||
outputs, self.strides, dtype, device)
|
|
||||||
|
|
||||||
total_num_anchors = outputs.shape[1]
|
|
||||||
bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4]
|
|
||||||
bbox_preds_org = outputs_origin[:, :, :4] # [batch, n_anchors_all, 4]
|
|
||||||
obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1]
|
|
||||||
cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls]
|
|
||||||
|
|
||||||
# targets
|
|
||||||
batch_size = bbox_preds.shape[0]
|
|
||||||
targets_list = np.zeros((batch_size, 1, 5)).tolist()
|
|
||||||
for i, item in enumerate(targets.cpu().numpy().tolist()):
|
|
||||||
targets_list[int(item[0])].append(item[1:])
|
|
||||||
max_len = max((len(l) for l in targets_list))
|
|
||||||
|
|
||||||
targets = torch.from_numpy(np.array(list(map(lambda l:l + [[-1,0,0,0,0]]*(max_len - len(l)), targets_list)))[:,1:,:]).to(targets.device)
|
|
||||||
num_targets_list = (targets.sum(dim=2) > 0).sum(dim=1) # number of objects
|
|
||||||
|
|
||||||
num_fg, num_gts = 0, 0
|
|
||||||
cls_targets, reg_targets, l1_targets, obj_targets, fg_masks = [], [], [], [], []
|
|
||||||
|
|
||||||
for batch_idx in range(batch_size):
|
|
||||||
num_gt = int(num_targets_list[batch_idx])
|
|
||||||
num_gts += num_gt
|
|
||||||
if num_gt == 0:
|
|
||||||
cls_target = outputs.new_zeros((0, num_classes))
|
|
||||||
reg_target = outputs.new_zeros((0, 4))
|
|
||||||
l1_target = outputs.new_zeros((0, 4))
|
|
||||||
obj_target = outputs.new_zeros((total_num_anchors, 1))
|
|
||||||
fg_mask = outputs.new_zeros(total_num_anchors).bool()
|
|
||||||
else:
|
|
||||||
|
|
||||||
gt_bboxes_per_image = targets[batch_idx, :num_gt, 1:5].mul_(gt_bboxes_scale)
|
|
||||||
gt_classes = targets[batch_idx, :num_gt, 0]
|
|
||||||
bboxes_preds_per_image = bbox_preds[batch_idx]
|
|
||||||
cls_preds_per_image = cls_preds[batch_idx]
|
|
||||||
obj_preds_per_image = obj_preds[batch_idx]
|
|
||||||
|
|
||||||
try:
|
|
||||||
(
|
|
||||||
gt_matched_classes,
|
|
||||||
fg_mask,
|
|
||||||
pred_ious_this_matching,
|
|
||||||
matched_gt_inds,
|
|
||||||
num_fg_img,
|
|
||||||
) = self.get_assignments(
|
|
||||||
batch_idx,
|
|
||||||
num_gt,
|
|
||||||
total_num_anchors,
|
|
||||||
gt_bboxes_per_image,
|
|
||||||
gt_classes,
|
|
||||||
bboxes_preds_per_image,
|
|
||||||
cls_preds_per_image,
|
|
||||||
obj_preds_per_image,
|
|
||||||
expanded_strides,
|
|
||||||
xy_shifts,
|
|
||||||
num_classes
|
|
||||||
)
|
|
||||||
|
|
||||||
except RuntimeError:
|
|
||||||
print(
|
|
||||||
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
|
|
||||||
CPU mode is applied in this batch. If you want to avoid this issue, \
|
|
||||||
try to reduce the batch size or image size."
|
|
||||||
)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
print("------------CPU Mode for This Batch-------------")
|
|
||||||
|
|
||||||
_gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
|
|
||||||
_gt_classes = gt_classes.cpu().float()
|
|
||||||
_bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
|
|
||||||
_cls_preds_per_image = cls_preds_per_image.cpu().float()
|
|
||||||
_obj_preds_per_image = obj_preds_per_image.cpu().float()
|
|
||||||
|
|
||||||
_expanded_strides = expanded_strides.cpu().float()
|
|
||||||
_xy_shifts = xy_shifts.cpu()
|
|
||||||
|
|
||||||
(
|
|
||||||
gt_matched_classes,
|
|
||||||
fg_mask,
|
|
||||||
pred_ious_this_matching,
|
|
||||||
matched_gt_inds,
|
|
||||||
num_fg_img,
|
|
||||||
) = self.get_assignments(
|
|
||||||
batch_idx,
|
|
||||||
num_gt,
|
|
||||||
total_num_anchors,
|
|
||||||
_gt_bboxes_per_image,
|
|
||||||
_gt_classes,
|
|
||||||
_bboxes_preds_per_image,
|
|
||||||
_cls_preds_per_image,
|
|
||||||
_obj_preds_per_image,
|
|
||||||
_expanded_strides,
|
|
||||||
_xy_shifts,
|
|
||||||
num_classes
|
|
||||||
)
|
|
||||||
|
|
||||||
gt_matched_classes = gt_matched_classes.cuda()
|
|
||||||
fg_mask = fg_mask.cuda()
|
|
||||||
pred_ious_this_matching = pred_ious_this_matching.cuda()
|
|
||||||
matched_gt_inds = matched_gt_inds.cuda()
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
num_fg += num_fg_img
|
|
||||||
if num_fg_img > 0:
|
|
||||||
cls_target = F.one_hot(
|
|
||||||
gt_matched_classes.to(torch.int64), num_classes
|
|
||||||
) * pred_ious_this_matching.unsqueeze(-1)
|
|
||||||
obj_target = fg_mask.unsqueeze(-1)
|
|
||||||
reg_target = gt_bboxes_per_image[matched_gt_inds]
|
|
||||||
|
|
||||||
l1_target = self.get_l1_target(
|
|
||||||
outputs.new_zeros((num_fg_img, 4)),
|
|
||||||
gt_bboxes_per_image[matched_gt_inds],
|
|
||||||
expanded_strides[0][fg_mask],
|
|
||||||
xy_shifts=xy_shifts[0][fg_mask],
|
|
||||||
)
|
|
||||||
|
|
||||||
cls_targets.append(cls_target)
|
|
||||||
reg_targets.append(reg_target)
|
|
||||||
obj_targets.append(obj_target)
|
|
||||||
l1_targets.append(l1_target)
|
|
||||||
fg_masks.append(fg_mask)
|
|
||||||
|
|
||||||
cls_targets = torch.cat(cls_targets, 0)
|
|
||||||
reg_targets = torch.cat(reg_targets, 0)
|
|
||||||
obj_targets = torch.cat(obj_targets, 0)
|
|
||||||
l1_targets = torch.cat(l1_targets, 0)
|
|
||||||
fg_masks = torch.cat(fg_masks, 0)
|
|
||||||
|
|
||||||
num_fg = max(num_fg, 1)
|
|
||||||
# loss
|
|
||||||
loss_iou += (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks].T, reg_targets)).sum() / num_fg
|
|
||||||
loss_l1 += (self.l1_loss(bbox_preds_org.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
|
|
||||||
|
|
||||||
loss_obj += (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets*1.0)).sum() / num_fg
|
|
||||||
loss_cls += (self.bcewithlog_loss(cls_preds.view(-1, num_classes)[fg_masks], cls_targets)).sum() / num_fg
|
|
||||||
|
|
||||||
total_losses = self.reg_weight * loss_iou + loss_l1 + loss_obj + loss_cls
|
|
||||||
return total_losses, torch.cat((self.reg_weight * loss_iou, loss_l1, loss_obj, loss_cls)).detach()
|
|
||||||
|
|
||||||
def decode_output(self, output, k, stride, dtype, device):
|
|
||||||
grid = self.grids[k].to(device)
|
|
||||||
batch_size = output.shape[0]
|
|
||||||
hsize, wsize = output.shape[2:4]
|
|
||||||
if grid.shape[2:4] != output.shape[2:4]:
|
|
||||||
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
|
|
||||||
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype).to(device)
|
|
||||||
self.grids[k] = grid
|
|
||||||
|
|
||||||
output = output.reshape(batch_size, self.n_anchors * hsize * wsize, -1)
|
|
||||||
output_origin = output.clone()
|
|
||||||
grid = grid.view(1, -1, 2)
|
|
||||||
|
|
||||||
output[..., :2] = (output[..., :2] + grid) * stride
|
|
||||||
output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
|
|
||||||
|
|
||||||
return output, output_origin, grid, hsize, wsize
|
|
||||||
|
|
||||||
def get_outputs_and_grids(self, outputs, strides, dtype, device):
|
|
||||||
xy_shifts = []
|
|
||||||
expanded_strides = []
|
|
||||||
outputs_new = []
|
|
||||||
outputs_origin = []
|
|
||||||
|
|
||||||
for k, output in enumerate(outputs):
|
|
||||||
output, output_origin, grid, feat_h, feat_w = self.decode_output(
|
|
||||||
output, k, strides[k], dtype, device)
|
|
||||||
|
|
||||||
xy_shift = grid
|
|
||||||
expanded_stride = torch.full((1, grid.shape[1], 1), strides[k], dtype=grid.dtype, device=grid.device)
|
|
||||||
|
|
||||||
xy_shifts.append(xy_shift)
|
|
||||||
expanded_strides.append(expanded_stride)
|
|
||||||
outputs_new.append(output)
|
|
||||||
outputs_origin.append(output_origin)
|
|
||||||
|
|
||||||
xy_shifts = torch.cat(xy_shifts, 1) # [1, n_anchors_all, 2]
|
|
||||||
expanded_strides = torch.cat(expanded_strides, 1) # [1, n_anchors_all, 1]
|
|
||||||
outputs_origin = torch.cat(outputs_origin, 1)
|
|
||||||
outputs = torch.cat(outputs_new, 1)
|
|
||||||
|
|
||||||
feat_h *= strides[-1]
|
|
||||||
feat_w *= strides[-1]
|
|
||||||
gt_bboxes_scale = torch.Tensor([[feat_w, feat_h, feat_w, feat_h]]).type_as(outputs)
|
|
||||||
|
|
||||||
return outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides
|
|
||||||
|
|
||||||
def get_l1_target(self, l1_target, gt, stride, xy_shifts, eps=1e-8):
|
|
||||||
|
|
||||||
l1_target[:, 0:2] = gt[:, 0:2] / stride - xy_shifts
|
|
||||||
l1_target[:, 2:4] = torch.log(gt[:, 2:4] / stride + eps)
|
|
||||||
return l1_target
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def get_assignments(
|
|
||||||
self,
|
|
||||||
batch_idx,
|
|
||||||
num_gt,
|
|
||||||
total_num_anchors,
|
|
||||||
gt_bboxes_per_image,
|
|
||||||
gt_classes,
|
|
||||||
bboxes_preds_per_image,
|
|
||||||
cls_preds_per_image,
|
|
||||||
obj_preds_per_image,
|
|
||||||
expanded_strides,
|
|
||||||
xy_shifts,
|
|
||||||
num_classes
|
|
||||||
):
|
|
||||||
|
|
||||||
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
|
|
||||||
gt_bboxes_per_image,
|
|
||||||
expanded_strides,
|
|
||||||
xy_shifts,
|
|
||||||
total_num_anchors,
|
|
||||||
num_gt,
|
|
||||||
)
|
|
||||||
|
|
||||||
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
|
|
||||||
cls_preds_ = cls_preds_per_image[fg_mask]
|
|
||||||
obj_preds_ = obj_preds_per_image[fg_mask]
|
|
||||||
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
|
|
||||||
|
|
||||||
# cost
|
|
||||||
pair_wise_ious = pairwise_bbox_iou(gt_bboxes_per_image, bboxes_preds_per_image, box_format='xywh')
|
|
||||||
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
|
|
||||||
|
|
||||||
gt_cls_per_image = (
|
|
||||||
F.one_hot(gt_classes.to(torch.int64), num_classes)
|
|
||||||
.float()
|
|
||||||
.unsqueeze(1)
|
|
||||||
.repeat(1, num_in_boxes_anchor, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
|
||||||
cls_preds_ = (
|
|
||||||
cls_preds_.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
|
|
||||||
* obj_preds_.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
|
|
||||||
)
|
|
||||||
pair_wise_cls_loss = F.binary_cross_entropy(
|
|
||||||
cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
|
|
||||||
).sum(-1)
|
|
||||||
del cls_preds_, obj_preds_
|
|
||||||
|
|
||||||
cost = (
|
|
||||||
self.cls_weight * pair_wise_cls_loss
|
|
||||||
+ self.iou_weight * pair_wise_ious_loss
|
|
||||||
+ 100000.0 * (~is_in_boxes_and_center)
|
|
||||||
)
|
|
||||||
|
|
||||||
(
|
|
||||||
num_fg,
|
|
||||||
gt_matched_classes,
|
|
||||||
pred_ious_this_matching,
|
|
||||||
matched_gt_inds,
|
|
||||||
) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
|
|
||||||
|
|
||||||
del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
|
|
||||||
|
|
||||||
return (
|
|
||||||
gt_matched_classes,
|
|
||||||
fg_mask,
|
|
||||||
pred_ious_this_matching,
|
|
||||||
matched_gt_inds,
|
|
||||||
num_fg,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_in_boxes_info(
|
|
||||||
self,
|
|
||||||
gt_bboxes_per_image,
|
|
||||||
expanded_strides,
|
|
||||||
xy_shifts,
|
|
||||||
total_num_anchors,
|
|
||||||
num_gt,
|
|
||||||
):
|
|
||||||
expanded_strides_per_image = expanded_strides[0]
|
|
||||||
xy_shifts_per_image = xy_shifts[0] * expanded_strides_per_image
|
|
||||||
xy_centers_per_image = (
|
|
||||||
(xy_shifts_per_image + 0.5 * expanded_strides_per_image)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.repeat(num_gt, 1, 1)
|
|
||||||
) # [n_anchor, 2] -> [n_gt, n_anchor, 2]
|
|
||||||
|
|
||||||
gt_bboxes_per_image_lt = (
|
|
||||||
(gt_bboxes_per_image[:, 0:2] - 0.5 * gt_bboxes_per_image[:, 2:4])
|
|
||||||
.unsqueeze(1)
|
|
||||||
.repeat(1, total_num_anchors, 1)
|
|
||||||
)
|
|
||||||
gt_bboxes_per_image_rb = (
|
|
||||||
(gt_bboxes_per_image[:, 0:2] + 0.5 * gt_bboxes_per_image[:, 2:4])
|
|
||||||
.unsqueeze(1)
|
|
||||||
.repeat(1, total_num_anchors, 1)
|
|
||||||
) # [n_gt, 2] -> [n_gt, n_anchor, 2]
|
|
||||||
|
|
||||||
b_lt = xy_centers_per_image - gt_bboxes_per_image_lt
|
|
||||||
b_rb = gt_bboxes_per_image_rb - xy_centers_per_image
|
|
||||||
bbox_deltas = torch.cat([b_lt, b_rb], 2)
|
|
||||||
|
|
||||||
is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
|
|
||||||
is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
|
|
||||||
|
|
||||||
# in fixed center
|
|
||||||
gt_bboxes_per_image_lt = (gt_bboxes_per_image[:, 0:2]).unsqueeze(1).repeat(
|
|
||||||
1, total_num_anchors, 1
|
|
||||||
) - self.center_radius * expanded_strides_per_image.unsqueeze(0)
|
|
||||||
gt_bboxes_per_image_rb = (gt_bboxes_per_image[:, 0:2]).unsqueeze(1).repeat(
|
|
||||||
1, total_num_anchors, 1
|
|
||||||
) + self.center_radius * expanded_strides_per_image.unsqueeze(0)
|
|
||||||
|
|
||||||
c_lt = xy_centers_per_image - gt_bboxes_per_image_lt
|
|
||||||
c_rb = gt_bboxes_per_image_rb - xy_centers_per_image
|
|
||||||
center_deltas = torch.cat([c_lt, c_rb], 2)
|
|
||||||
is_in_centers = center_deltas.min(dim=-1).values > 0.0
|
|
||||||
is_in_centers_all = is_in_centers.sum(dim=0) > 0
|
|
||||||
|
|
||||||
# in boxes and in centers
|
|
||||||
is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
|
|
||||||
|
|
||||||
is_in_boxes_and_center = (
|
|
||||||
is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
|
|
||||||
)
|
|
||||||
return is_in_boxes_anchor, is_in_boxes_and_center
|
|
||||||
|
|
||||||
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
|
|
||||||
matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
|
|
||||||
ious_in_boxes_matrix = pair_wise_ious
|
|
||||||
n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
|
|
||||||
topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
|
|
||||||
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
|
|
||||||
dynamic_ks = dynamic_ks.tolist()
|
|
||||||
|
|
||||||
for gt_idx in range(num_gt):
|
|
||||||
_, pos_idx = torch.topk(
|
|
||||||
cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
|
|
||||||
)
|
|
||||||
matching_matrix[gt_idx][pos_idx] = 1
|
|
||||||
del topk_ious, dynamic_ks, pos_idx
|
|
||||||
|
|
||||||
anchor_matching_gt = matching_matrix.sum(0)
|
|
||||||
if (anchor_matching_gt > 1).sum() > 0:
|
|
||||||
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
|
|
||||||
matching_matrix[:, anchor_matching_gt > 1] *= 0
|
|
||||||
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
|
|
||||||
fg_mask_inboxes = matching_matrix.sum(0) > 0
|
|
||||||
num_fg = fg_mask_inboxes.sum().item()
|
|
||||||
fg_mask[fg_mask.clone()] = fg_mask_inboxes
|
|
||||||
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
|
|
||||||
gt_matched_classes = gt_classes[matched_gt_inds]
|
|
||||||
|
|
||||||
pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
|
|
||||||
fg_mask_inboxes
|
|
||||||
]
|
|
||||||
|
|
||||||
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
|
|
@ -1,108 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from yolov6.layers.common import RepBlock, SimConv, Transpose
|
|
||||||
|
|
||||||
|
|
||||||
class RepPANNeck(nn.Module):
|
|
||||||
"""RepPANNeck Module
|
|
||||||
EfficientRep is the default backbone of this model.
|
|
||||||
RepPANNeck has the balance of feature fusion ability and hardware efficiency.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channels_list=None,
|
|
||||||
num_repeats=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
assert channels_list is not None
|
|
||||||
assert num_repeats is not None
|
|
||||||
|
|
||||||
self.Rep_p4 = RepBlock(
|
|
||||||
in_channels=channels_list[3] + channels_list[5],
|
|
||||||
out_channels=channels_list[5],
|
|
||||||
n=num_repeats[5],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.Rep_p3 = RepBlock(
|
|
||||||
in_channels=channels_list[2] + channels_list[6],
|
|
||||||
out_channels=channels_list[6],
|
|
||||||
n=num_repeats[6]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.Rep_n3 = RepBlock(
|
|
||||||
in_channels=channels_list[6] + channels_list[7],
|
|
||||||
out_channels=channels_list[8],
|
|
||||||
n=num_repeats[7],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.Rep_n4 = RepBlock(
|
|
||||||
in_channels=channels_list[5] + channels_list[9],
|
|
||||||
out_channels=channels_list[10],
|
|
||||||
n=num_repeats[8]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.reduce_layer0 = SimConv(
|
|
||||||
in_channels=channels_list[4],
|
|
||||||
out_channels=channels_list[5],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1
|
|
||||||
)
|
|
||||||
|
|
||||||
self.upsample0 = Transpose(
|
|
||||||
in_channels=channels_list[5],
|
|
||||||
out_channels=channels_list[5],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.reduce_layer1 = SimConv(
|
|
||||||
in_channels=channels_list[5],
|
|
||||||
out_channels=channels_list[6],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1
|
|
||||||
)
|
|
||||||
|
|
||||||
self.upsample1 = Transpose(
|
|
||||||
in_channels=channels_list[6],
|
|
||||||
out_channels=channels_list[6]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.downsample2 = SimConv(
|
|
||||||
in_channels=channels_list[6],
|
|
||||||
out_channels=channels_list[7],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2
|
|
||||||
)
|
|
||||||
|
|
||||||
self.downsample1 = SimConv(
|
|
||||||
in_channels=channels_list[8],
|
|
||||||
out_channels=channels_list[9],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
|
|
||||||
(x2, x1, x0) = input
|
|
||||||
|
|
||||||
fpn_out0 = self.reduce_layer0(x0)
|
|
||||||
upsample_feat0 = self.upsample0(fpn_out0)
|
|
||||||
f_concat_layer0 = torch.cat([upsample_feat0, x1], 1)
|
|
||||||
f_out0 = self.Rep_p4(f_concat_layer0)
|
|
||||||
|
|
||||||
fpn_out1 = self.reduce_layer1(f_out0)
|
|
||||||
upsample_feat1 = self.upsample1(fpn_out1)
|
|
||||||
f_concat_layer1 = torch.cat([upsample_feat1, x2], 1)
|
|
||||||
pan_out2 = self.Rep_p3(f_concat_layer1)
|
|
||||||
|
|
||||||
down_feat1 = self.downsample2(pan_out2)
|
|
||||||
p_concat_layer1 = torch.cat([down_feat1, fpn_out1], 1)
|
|
||||||
pan_out1 = self.Rep_n3(p_concat_layer1)
|
|
||||||
|
|
||||||
down_feat0 = self.downsample1(pan_out1)
|
|
||||||
p_concat_layer2 = torch.cat([down_feat0, fpn_out0], 1)
|
|
||||||
pan_out0 = self.Rep_n4(p_concat_layer2)
|
|
||||||
|
|
||||||
outputs = [pan_out2, pan_out1, pan_out0]
|
|
||||||
|
|
||||||
return outputs
|
|
@ -1,83 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import math
|
|
||||||
import torch.nn as nn
|
|
||||||
from yolov6.layers.common import *
|
|
||||||
from yolov6.utils.torch_utils import initialize_weights
|
|
||||||
from yolov6.models.efficientrep import EfficientRep
|
|
||||||
from yolov6.models.reppan import RepPANNeck
|
|
||||||
from yolov6.models.effidehead import Detect, build_effidehead_layer
|
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
|
||||||
'''YOLOv6 model with backbone, neck and head.
|
|
||||||
The default parts are EfficientRep Backbone, Rep-PAN and
|
|
||||||
Efficient Decoupled Head.
|
|
||||||
'''
|
|
||||||
def __init__(self, config, channels=3, num_classes=None, anchors=None): # model, input channels, number of classes
|
|
||||||
super().__init__()
|
|
||||||
# Build network
|
|
||||||
num_layers = config.model.head.num_layers
|
|
||||||
self.backbone, self.neck, self.detect = build_network(config, channels, num_classes, anchors, num_layers)
|
|
||||||
|
|
||||||
# Init Detect head
|
|
||||||
begin_indices = config.model.head.begin_indices
|
|
||||||
out_indices_head = config.model.head.out_indices
|
|
||||||
self.stride = self.detect.stride
|
|
||||||
self.detect.i = begin_indices
|
|
||||||
self.detect.f = out_indices_head
|
|
||||||
self.detect.initialize_biases()
|
|
||||||
|
|
||||||
# Init weights
|
|
||||||
initialize_weights(self)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.backbone(x)
|
|
||||||
x = self.neck(x)
|
|
||||||
x = self.detect(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def _apply(self, fn):
|
|
||||||
self = super()._apply(fn)
|
|
||||||
self.detect.stride = fn(self.detect.stride)
|
|
||||||
self.detect.grid = list(map(fn, self.detect.grid))
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def make_divisible(x, divisor):
|
|
||||||
# Upward revision the value x to make it evenly divisible by the divisor.
|
|
||||||
return math.ceil(x / divisor) * divisor
|
|
||||||
|
|
||||||
|
|
||||||
def build_network(config, channels, num_classes, anchors, num_layers):
|
|
||||||
depth_mul = config.model.depth_multiple
|
|
||||||
width_mul = config.model.width_multiple
|
|
||||||
num_repeat_backbone = config.model.backbone.num_repeats
|
|
||||||
channels_list_backbone = config.model.backbone.out_channels
|
|
||||||
num_repeat_neck = config.model.neck.num_repeats
|
|
||||||
channels_list_neck = config.model.neck.out_channels
|
|
||||||
num_anchors = config.model.head.anchors
|
|
||||||
num_repeat = [(max(round(i * depth_mul), 1) if i > 1 else i) for i in (num_repeat_backbone + num_repeat_neck)]
|
|
||||||
channels_list = [make_divisible(i * width_mul, 8) for i in (channels_list_backbone + channels_list_neck)]
|
|
||||||
|
|
||||||
backbone = EfficientRep(
|
|
||||||
in_channels=channels,
|
|
||||||
channels_list=channels_list,
|
|
||||||
num_repeats=num_repeat
|
|
||||||
)
|
|
||||||
|
|
||||||
neck = RepPANNeck(
|
|
||||||
channels_list=channels_list,
|
|
||||||
num_repeats=num_repeat
|
|
||||||
)
|
|
||||||
|
|
||||||
head_layers = build_effidehead_layer(channels_list, num_anchors, num_classes)
|
|
||||||
|
|
||||||
head = Detect(num_classes, anchors, num_layers, head_layers=head_layers)
|
|
||||||
|
|
||||||
return backbone, neck, head
|
|
||||||
|
|
||||||
|
|
||||||
def build_model(cfg, num_classes, device):
|
|
||||||
model = Model(cfg, channels=3, num_classes=num_classes, anchors=cfg.model.head.anchors).to(device)
|
|
||||||
return model
|
|
@ -1,42 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import os
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def build_optimizer(cfg, model):
|
|
||||||
""" Build optimizer from cfg file."""
|
|
||||||
g_bnw, g_w, g_b = [], [], []
|
|
||||||
for v in model.modules():
|
|
||||||
if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
|
|
||||||
g_b.append(v.bias)
|
|
||||||
if isinstance(v, nn.BatchNorm2d):
|
|
||||||
g_bnw.append(v.weight)
|
|
||||||
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
|
|
||||||
g_w.append(v.weight)
|
|
||||||
|
|
||||||
assert cfg.solver.optim == 'SGD' or 'Adam', 'ERROR: unknown optimizer, use SGD defaulted'
|
|
||||||
if cfg.solver.optim == 'SGD':
|
|
||||||
optimizer = torch.optim.SGD(g_bnw, lr=cfg.solver.lr0, momentum=cfg.solver.momentum, nesterov=True)
|
|
||||||
elif cfg.solver.optim == 'Adam':
|
|
||||||
optimizer = torch.optim.Adam(g_bnw, lr=cfg.solver.lr0, betas=(cfg.solver.momentum, 0.999))
|
|
||||||
|
|
||||||
optimizer.add_param_group({'params': g_w, 'weight_decay': cfg.solver.weight_decay})
|
|
||||||
optimizer.add_param_group({'params': g_b})
|
|
||||||
|
|
||||||
del g_bnw, g_w, g_b
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
def build_lr_scheduler(cfg, optimizer, epochs):
|
|
||||||
"""Build learning rate scheduler from cfg file."""
|
|
||||||
if cfg.solver.lr_scheduler == 'Cosine':
|
|
||||||
lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg.solver.lrf - 1) + 1
|
|
||||||
else:
|
|
||||||
LOGGER.error('unknown lr scheduler, use Cosine defaulted')
|
|
||||||
|
|
||||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
|
||||||
return scheduler, lf
|
|
@ -1,60 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import torch
|
|
||||||
import os.path as osp
|
|
||||||
from yolov6.utils.events import LOGGER
|
|
||||||
from yolov6.utils.torch_utils import fuse_model
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(weights, model, map_location=None):
|
|
||||||
"""Load weights from checkpoint file, only assign weights those layers' name and shape are match."""
|
|
||||||
ckpt = torch.load(weights, map_location=map_location)
|
|
||||||
state_dict = ckpt['model'].float().state_dict()
|
|
||||||
model_state_dict = model.state_dict()
|
|
||||||
state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
|
||||||
del ckpt, state_dict, model_state_dict
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(weights, map_location=None, inplace=True, fuse=True):
|
|
||||||
"""Load model from checkpoint file."""
|
|
||||||
LOGGER.info("Loading checkpoint from {}".format(weights))
|
|
||||||
ckpt = torch.load(weights, map_location=map_location) # load
|
|
||||||
model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
|
|
||||||
if fuse:
|
|
||||||
LOGGER.info("\nFusing model...")
|
|
||||||
model = fuse_model(model).eval()
|
|
||||||
else:
|
|
||||||
model = model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(ckpt, is_best, save_dir, model_name=""):
|
|
||||||
""" Save checkpoint to the disk."""
|
|
||||||
if not osp.exists(save_dir):
|
|
||||||
os.makedirs(save_dir)
|
|
||||||
filename = osp.join(save_dir, model_name + '.pt')
|
|
||||||
torch.save(ckpt, filename)
|
|
||||||
if is_best:
|
|
||||||
best_filename = osp.join(save_dir, 'best_ckpt.pt')
|
|
||||||
shutil.copyfile(filename, best_filename)
|
|
||||||
|
|
||||||
|
|
||||||
def strip_optimizer(ckpt_dir, epoch):
|
|
||||||
for s in ['best', 'last']:
|
|
||||||
ckpt_path = osp.join(ckpt_dir, '{}_ckpt.pt'.format(s))
|
|
||||||
if not osp.exists(ckpt_path):
|
|
||||||
continue
|
|
||||||
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
|
|
||||||
if ckpt.get('ema'):
|
|
||||||
ckpt['model'] = ckpt['ema'] # replace model with ema
|
|
||||||
for k in ['optimizer', 'ema', 'updates']: # keys
|
|
||||||
ckpt[k] = None
|
|
||||||
ckpt['epoch'] = epoch
|
|
||||||
ckpt['model'].half() # to FP16
|
|
||||||
for p in ckpt['model'].parameters():
|
|
||||||
p.requires_grad = False
|
|
||||||
torch.save(ckpt, ckpt_path)
|
|
@ -1,101 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# The code is based on
|
|
||||||
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
|
|
||||||
# Copyright (c) OpenMMLab.
|
|
||||||
|
|
||||||
import os.path as osp
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
from importlib import import_module
|
|
||||||
from addict import Dict
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigDict(Dict):
|
|
||||||
|
|
||||||
def __missing__(self, name):
|
|
||||||
raise KeyError(name)
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
try:
|
|
||||||
value = super(ConfigDict, self).__getattr__(name)
|
|
||||||
except KeyError:
|
|
||||||
ex = AttributeError("'{}' object has no attribute '{}'".format(
|
|
||||||
self.__class__.__name__, name))
|
|
||||||
except Exception as e:
|
|
||||||
ex = e
|
|
||||||
else:
|
|
||||||
return value
|
|
||||||
raise ex
|
|
||||||
|
|
||||||
|
|
||||||
class Config(object):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _file2dict(filename):
|
|
||||||
filename = str(filename)
|
|
||||||
if filename.endswith('.py'):
|
|
||||||
with tempfile.TemporaryDirectory() as temp_config_dir:
|
|
||||||
shutil.copyfile(filename,
|
|
||||||
osp.join(temp_config_dir, '_tempconfig.py'))
|
|
||||||
sys.path.insert(0, temp_config_dir)
|
|
||||||
mod = import_module('_tempconfig')
|
|
||||||
sys.path.pop(0)
|
|
||||||
cfg_dict = {
|
|
||||||
name: value
|
|
||||||
for name, value in mod.__dict__.items()
|
|
||||||
if not name.startswith('__')
|
|
||||||
}
|
|
||||||
# delete imported module
|
|
||||||
del sys.modules['_tempconfig']
|
|
||||||
else:
|
|
||||||
raise IOError('Only .py type are supported now!')
|
|
||||||
cfg_text = filename + '\n'
|
|
||||||
with open(filename, 'r') as f:
|
|
||||||
cfg_text += f.read()
|
|
||||||
|
|
||||||
return cfg_dict, cfg_text
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def fromfile(filename):
|
|
||||||
cfg_dict, cfg_text = Config._file2dict(filename)
|
|
||||||
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
|
|
||||||
|
|
||||||
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
|
|
||||||
if cfg_dict is None:
|
|
||||||
cfg_dict = dict()
|
|
||||||
elif not isinstance(cfg_dict, dict):
|
|
||||||
raise TypeError('cfg_dict must be a dict, but got {}'.format(
|
|
||||||
type(cfg_dict)))
|
|
||||||
|
|
||||||
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
|
|
||||||
super(Config, self).__setattr__('_filename', filename)
|
|
||||||
if cfg_text:
|
|
||||||
text = cfg_text
|
|
||||||
elif filename:
|
|
||||||
with open(filename, 'r') as f:
|
|
||||||
text = f.read()
|
|
||||||
else:
|
|
||||||
text = ''
|
|
||||||
super(Config, self).__setattr__('_text', text)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def filename(self):
|
|
||||||
return self._filename
|
|
||||||
|
|
||||||
@property
|
|
||||||
def text(self):
|
|
||||||
return self._text
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return 'Config (path: {}): {}'.format(self.filename,
|
|
||||||
self._cfg_dict.__repr__())
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._cfg_dict, name)
|
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
|
||||||
if isinstance(value, dict):
|
|
||||||
value = ConfigDict(value)
|
|
||||||
self._cfg_dict.__setattr__(name, value)
|
|
@ -1,59 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
# The code is based on
|
|
||||||
# https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py
|
|
||||||
import math
|
|
||||||
from copy import deepcopy
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class ModelEMA:
|
|
||||||
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
|
|
||||||
Keep a moving average of everything in the model state_dict (parameters and buffers).
|
|
||||||
This is intended to allow functionality like
|
|
||||||
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
|
||||||
A smoothed version of the weights is necessary for some training schemes to perform well.
|
|
||||||
This class is sensitive where it is initialized in the sequence of model init,
|
|
||||||
GPU assignment and distributed training wrappers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model, decay=0.9999, updates=0):
|
|
||||||
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
|
|
||||||
self.updates = updates
|
|
||||||
self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
|
|
||||||
for param in self.ema.parameters():
|
|
||||||
param.requires_grad_(False)
|
|
||||||
|
|
||||||
def update(self, model):
|
|
||||||
with torch.no_grad():
|
|
||||||
self.updates += 1
|
|
||||||
decay = self.decay(self.updates)
|
|
||||||
|
|
||||||
state_dict = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
|
|
||||||
for k, item in self.ema.state_dict().items():
|
|
||||||
if item.dtype.is_floating_point:
|
|
||||||
item *= decay
|
|
||||||
item += (1 - decay) * state_dict[k].detach()
|
|
||||||
|
|
||||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
|
||||||
copy_attr(self.ema, model, include, exclude)
|
|
||||||
|
|
||||||
|
|
||||||
def copy_attr(a, b, include=(), exclude=()):
|
|
||||||
"""Copy attributes from one instance and set them to another instance."""
|
|
||||||
for k, item in b.__dict__.items():
|
|
||||||
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
setattr(a, k, item)
|
|
||||||
|
|
||||||
|
|
||||||
def is_parallel(model):
|
|
||||||
# Return True if model's type is DP or DDP, else False.
|
|
||||||
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
|
||||||
|
|
||||||
|
|
||||||
def de_parallel(model):
|
|
||||||
# De-parallelize a model. Return single-GPU model if model's type is DP or DDP.
|
|
||||||
return model.module if is_parallel(model) else model
|
|
@ -1,54 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.backends.cudnn as cudnn
|
|
||||||
from yolov6.utils.events import LOGGER
|
|
||||||
|
|
||||||
|
|
||||||
def get_envs():
|
|
||||||
"""Get PyTorch needed environments from system envirionments."""
|
|
||||||
local_rank = int(os.getenv('LOCAL_RANK', -1))
|
|
||||||
rank = int(os.getenv('RANK', -1))
|
|
||||||
world_size = int(os.getenv('WORLD_SIZE', 1))
|
|
||||||
return local_rank, rank, world_size
|
|
||||||
|
|
||||||
|
|
||||||
def select_device(device):
|
|
||||||
"""Set devices' information to the program.
|
|
||||||
Args:
|
|
||||||
device: a string, like 'cpu' or '1,2,3,4'
|
|
||||||
Returns:
|
|
||||||
torch.device
|
|
||||||
"""
|
|
||||||
if device == 'cpu':
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
|
||||||
LOGGER.info('Using CPU for training... ')
|
|
||||||
elif device:
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = device
|
|
||||||
assert torch.cuda.is_available()
|
|
||||||
nd = len(device.strip().split(','))
|
|
||||||
LOGGER.info(f'Using {nd} GPU for training... ')
|
|
||||||
cuda = device != 'cpu' and torch.cuda.is_available()
|
|
||||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
|
||||||
return device
|
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed, deterministic=False):
|
|
||||||
""" Set random state to random libray, numpy, torch and cudnn.
|
|
||||||
Args:
|
|
||||||
seed: int value.
|
|
||||||
deterministic: bool value.
|
|
||||||
"""
|
|
||||||
random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
if deterministic:
|
|
||||||
cudnn.deterministic = True
|
|
||||||
cudnn.benchmark = False
|
|
||||||
else:
|
|
||||||
cudnn.deterministic = False
|
|
||||||
cudnn.benchmark = True
|
|
@ -1,41 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import os
|
|
||||||
import yaml
|
|
||||||
import logging
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
|
|
||||||
def set_logging(name=None):
|
|
||||||
rank = int(os.getenv('RANK', -1))
|
|
||||||
logging.basicConfig(format="%(message)s", level=logging.INFO if (rank in (-1, 0)) else logging.WARNING)
|
|
||||||
return logging.getLogger(name)
|
|
||||||
|
|
||||||
|
|
||||||
LOGGER = set_logging(__name__)
|
|
||||||
NCOLS = shutil.get_terminal_size().columns
|
|
||||||
|
|
||||||
|
|
||||||
def load_yaml(file_path):
|
|
||||||
"""Load data from yaml file."""
|
|
||||||
if isinstance(file_path, str):
|
|
||||||
with open(file_path, errors='ignore') as f:
|
|
||||||
data_dict = yaml.safe_load(f)
|
|
||||||
return data_dict
|
|
||||||
|
|
||||||
|
|
||||||
def save_yaml(data_dict, save_path):
|
|
||||||
"""Save data to yaml file"""
|
|
||||||
with open(save_path, 'w') as f:
|
|
||||||
yaml.safe_dump(data_dict, f, sort_keys=False)
|
|
||||||
|
|
||||||
|
|
||||||
def write_tblog(tblogger, epoch, results, losses):
|
|
||||||
"""Display mAP and loss information to log."""
|
|
||||||
tblogger.add_scalar("val/mAP@0.5", results[0], epoch + 1)
|
|
||||||
tblogger.add_scalar("val/mAP@0.50:0.95", results[1], epoch + 1)
|
|
||||||
|
|
||||||
tblogger.add_scalar("train/iou_loss", losses[0], epoch + 1)
|
|
||||||
tblogger.add_scalar("train/l1_loss", losses[1], epoch + 1)
|
|
||||||
tblogger.add_scalar("train/obj_loss", losses[2], epoch + 1)
|
|
||||||
tblogger.add_scalar("train/cls_loss", losses[3], epoch + 1)
|
|
@ -1,114 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class IOUloss:
|
|
||||||
""" Calculate IoU loss.
|
|
||||||
"""
|
|
||||||
def __init__(self, box_format='xywh', iou_type='ciou', reduction='none', eps=1e-7):
|
|
||||||
""" Setting of the class.
|
|
||||||
Args:
|
|
||||||
box_format: (string), must be one of 'xywh' or 'xyxy'.
|
|
||||||
iou_type: (string), can be one of 'ciou', 'diou', 'giou' or 'siou'
|
|
||||||
reduction: (string), specifies the reduction to apply to the output, must be one of 'none', 'mean','sum'.
|
|
||||||
eps: (float), a value to avoid divide by zero error.
|
|
||||||
"""
|
|
||||||
self.box_format = box_format
|
|
||||||
self.iou_type = iou_type.lower()
|
|
||||||
self.reduction = reduction
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
def __call__(self, box1, box2):
|
|
||||||
""" calculate iou. box1 and box2 are torch tensor with shape [M, 4] and [Nm 4].
|
|
||||||
"""
|
|
||||||
box2 = box2.T
|
|
||||||
if self.box_format == 'xyxy':
|
|
||||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
|
|
||||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
|
|
||||||
elif self.box_format == 'xywh':
|
|
||||||
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
|
|
||||||
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
|
|
||||||
b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
|
|
||||||
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
|
|
||||||
|
|
||||||
# Intersection area
|
|
||||||
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
|
|
||||||
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
|
|
||||||
|
|
||||||
# Union Area
|
|
||||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + self.eps
|
|
||||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + self.eps
|
|
||||||
union = w1 * h1 + w2 * h2 - inter + self.eps
|
|
||||||
iou = inter / union
|
|
||||||
|
|
||||||
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex width
|
|
||||||
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
|
|
||||||
if self.iou_type == 'giou':
|
|
||||||
c_area = cw * ch + self.eps # convex area
|
|
||||||
iou = iou - (c_area - union) / c_area
|
|
||||||
elif self.iou_type in ['diou', 'ciou']:
|
|
||||||
c2 = cw ** 2 + ch ** 2 + self.eps # convex diagonal squared
|
|
||||||
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
|
|
||||||
(b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
|
|
||||||
if self.iou_type == 'diou':
|
|
||||||
iou = iou - rho2 / c2
|
|
||||||
elif self.iou_type == 'ciou':
|
|
||||||
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
|
||||||
with torch.no_grad():
|
|
||||||
alpha = v / (v - iou + (1 + self.eps))
|
|
||||||
iou = iou - (rho2 / c2 + v * alpha)
|
|
||||||
elif self.iou_type == 'siou':
|
|
||||||
# SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
|
|
||||||
s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5
|
|
||||||
s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5
|
|
||||||
sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
|
|
||||||
sin_alpha_1 = torch.abs(s_cw) / sigma
|
|
||||||
sin_alpha_2 = torch.abs(s_ch) / sigma
|
|
||||||
threshold = pow(2, 0.5) / 2
|
|
||||||
sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
|
|
||||||
angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
|
|
||||||
rho_x = (s_cw / cw) ** 2
|
|
||||||
rho_y = (s_ch / ch) ** 2
|
|
||||||
gamma = angle_cost - 2
|
|
||||||
distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
|
|
||||||
omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
|
|
||||||
omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
|
|
||||||
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
|
|
||||||
iou = iou - 0.5 * (distance_cost + shape_cost)
|
|
||||||
loss = 1.0 - iou
|
|
||||||
|
|
||||||
if self.reduction == 'sum':
|
|
||||||
loss = loss.sum()
|
|
||||||
elif self.reduction == 'mean':
|
|
||||||
loss = loss.mean()
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
def pairwise_bbox_iou(box1, box2, box_format='xywh'):
|
|
||||||
"""Calculate iou.
|
|
||||||
This code is based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/boxes.py
|
|
||||||
"""
|
|
||||||
if box_format == 'xyxy':
|
|
||||||
lt = torch.max(box1[:, None, :2], box2[:, :2])
|
|
||||||
rb = torch.min(box1[:, None, 2:], box2[:, 2:])
|
|
||||||
area_1 = torch.prod(box1[:, 2:] - box1[:, :2], 1)
|
|
||||||
area_2 = torch.prod(box2[:, 2:] - box2[:, :2], 1)
|
|
||||||
|
|
||||||
elif box_format == 'xywh':
|
|
||||||
lt = torch.max(
|
|
||||||
(box1[:, None, :2] - box1[:, None, 2:] / 2),
|
|
||||||
(box2[:, :2] - box2[:, 2:] / 2),
|
|
||||||
)
|
|
||||||
rb = torch.min(
|
|
||||||
(box1[:, None, :2] + box1[:, None, 2:] / 2),
|
|
||||||
(box2[:, :2] + box2[:, 2:] / 2),
|
|
||||||
)
|
|
||||||
|
|
||||||
area_1 = torch.prod(box1[:, 2:], 1)
|
|
||||||
area_2 = torch.prod(box2[:, 2:], 1)
|
|
||||||
valid = (lt < rb).type(lt.type()).prod(dim=2)
|
|
||||||
inter = torch.prod(rb - lt, 2) * valid
|
|
||||||
return inter / (area_1[:, None] + area_2 - inter)
|
|
@ -1,17 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
def increment_name(path, master_process):
|
|
||||||
"increase save directory's id"
|
|
||||||
path = Path(path)
|
|
||||||
sep = ''
|
|
||||||
if path.exists() and master_process:
|
|
||||||
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
|
|
||||||
for n in range(1, 9999):
|
|
||||||
p = f'{path}{sep}{n}{suffix}'
|
|
||||||
if not os.path.exists(p):
|
|
||||||
break
|
|
||||||
path = Path(p)
|
|
||||||
return path
|
|
@ -1,106 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
# The code is based on
|
|
||||||
# https://github.com/ultralytics/yolov5/blob/master/utils/general.py
|
|
||||||
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
|
|
||||||
|
|
||||||
# Settings
|
|
||||||
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
|
||||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
|
||||||
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
|
||||||
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
|
|
||||||
|
|
||||||
|
|
||||||
def xywh2xyxy(x):
|
|
||||||
# Convert boxes with shape [n, 4] from [x, y, w, h] to [x1, y1, x2, y2] where x1y1 is top-left, x2y2=bottom-right
|
|
||||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
|
||||||
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
|
||||||
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
|
||||||
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
|
||||||
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, max_det=300):
|
|
||||||
"""Runs Non-Maximum Suppression (NMS) on inference results.
|
|
||||||
This code is borrowed from: https://github.com/ultralytics/yolov5/blob/47233e1698b89fc437a4fb9463c815e9171be955/utils/general.py#L775
|
|
||||||
Args:
|
|
||||||
prediction: (tensor), with shape [N, 5 + num_classes], N is the number of bboxes.
|
|
||||||
conf_thres: (float) confidence threshold.
|
|
||||||
iou_thres: (float) iou threshold.
|
|
||||||
classes: (None or list[int]), if a list is provided, nms only keep the classes you provide.
|
|
||||||
agnostic: (bool), when it is set to True, we do class-independent nms, otherwise, different class would do nms respectively.
|
|
||||||
multi_label: (bool), when it is set to True, one box can have multi labels, otherwise, one box only huave one label.
|
|
||||||
max_det:(int), max number of output bboxes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list of detections, echo item is one tensor with shape (num_boxes, 6), 6 is for [xyxy, conf, cls].
|
|
||||||
"""
|
|
||||||
|
|
||||||
num_classes = prediction.shape[2] - 5 # number of classes
|
|
||||||
pred_candidates = prediction[..., 4] > conf_thres # candidates
|
|
||||||
|
|
||||||
# Check the parameters.
|
|
||||||
assert 0 <= conf_thres <= 1, f'conf_thresh must be in 0.0 to 1.0, however {conf_thres} is provided.'
|
|
||||||
assert 0 <= iou_thres <= 1, f'iou_thres must be in 0.0 to 1.0, however {iou_thres} is provided.'
|
|
||||||
|
|
||||||
# Function settings.
|
|
||||||
max_wh = 4096 # maximum box width and height
|
|
||||||
max_nms = 30000 # maximum number of boxes put into torchvision.ops.nms()
|
|
||||||
time_limit = 10.0 # quit the function when nms cost time exceed the limit time.
|
|
||||||
multi_label &= num_classes > 1 # multiple labels per box
|
|
||||||
|
|
||||||
tik = time.time()
|
|
||||||
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
|
|
||||||
for img_idx, x in enumerate(prediction): # image index, image inference
|
|
||||||
x = x[pred_candidates[img_idx]] # confidence
|
|
||||||
|
|
||||||
# If no box remains, skip the next process.
|
|
||||||
if not x.shape[0]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# confidence multiply the objectness
|
|
||||||
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
|
|
||||||
|
|
||||||
# (center x, center y, width, height) to (x1, y1, x2, y2)
|
|
||||||
box = xywh2xyxy(x[:, :4])
|
|
||||||
|
|
||||||
# Detections matrix's shape is (n,6), each row represents (xyxy, conf, cls)
|
|
||||||
if multi_label:
|
|
||||||
box_idx, class_idx = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
|
|
||||||
x = torch.cat((box[box_idx], x[box_idx, class_idx + 5, None], class_idx[:, None].float()), 1)
|
|
||||||
else: # Only keep the class with highest scores.
|
|
||||||
conf, class_idx = x[:, 5:].max(1, keepdim=True)
|
|
||||||
x = torch.cat((box, conf, class_idx.float()), 1)[conf.view(-1) > conf_thres]
|
|
||||||
|
|
||||||
# Filter by class, only keep boxes whose category is in classes.
|
|
||||||
if classes is not None:
|
|
||||||
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
|
||||||
|
|
||||||
# Check shape
|
|
||||||
num_box = x.shape[0] # number of boxes
|
|
||||||
if not num_box: # no boxes kept.
|
|
||||||
continue
|
|
||||||
elif num_box > max_nms: # excess max boxes' number.
|
|
||||||
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
|
|
||||||
|
|
||||||
# Batched NMS
|
|
||||||
class_offset = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
|
||||||
boxes, scores = x[:, :4] + class_offset, x[:, 4] # boxes (offset by class), scores
|
|
||||||
keep_box_idx = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
|
||||||
if keep_box_idx.shape[0] > max_det: # limit detections
|
|
||||||
keep_box_idx = keep_box_idx[:max_det]
|
|
||||||
|
|
||||||
output[img_idx] = x[keep_box_idx]
|
|
||||||
if (time.time() - tik) > time_limit:
|
|
||||||
print(f'WARNING: NMS cost time exceed the limited {time_limit}s.')
|
|
||||||
break # time limit exceeded
|
|
||||||
|
|
||||||
return output
|
|
@ -1,109 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
|
|
||||||
import time
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from copy import deepcopy
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from yolov6.utils.events import LOGGER
|
|
||||||
|
|
||||||
try:
|
|
||||||
import thop # for FLOPs computation
|
|
||||||
except ImportError:
|
|
||||||
thop = None
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def torch_distributed_zero_first(local_rank: int):
|
|
||||||
"""
|
|
||||||
Decorator to make all processes in distributed training wait for each local_master to do something.
|
|
||||||
"""
|
|
||||||
if local_rank not in [-1, 0]:
|
|
||||||
dist.barrier(device_ids=[local_rank])
|
|
||||||
yield
|
|
||||||
if local_rank == 0:
|
|
||||||
dist.barrier(device_ids=[0])
|
|
||||||
|
|
||||||
|
|
||||||
def time_sync():
|
|
||||||
# Waits for all kernels in all streams on a CUDA device to complete if cuda is available.
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
return time.time()
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_weights(model):
|
|
||||||
for m in model.modules():
|
|
||||||
t = type(m)
|
|
||||||
if t is nn.Conv2d:
|
|
||||||
pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
|
||||||
elif t is nn.BatchNorm2d:
|
|
||||||
m.eps = 1e-3
|
|
||||||
m.momentum = 0.03
|
|
||||||
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
|
|
||||||
m.inplace = True
|
|
||||||
|
|
||||||
|
|
||||||
def fuse_conv_and_bn(conv, bn):
|
|
||||||
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
|
||||||
fusedconv = (
|
|
||||||
nn.Conv2d(
|
|
||||||
conv.in_channels,
|
|
||||||
conv.out_channels,
|
|
||||||
kernel_size=conv.kernel_size,
|
|
||||||
stride=conv.stride,
|
|
||||||
padding=conv.padding,
|
|
||||||
groups=conv.groups,
|
|
||||||
bias=True,
|
|
||||||
)
|
|
||||||
.requires_grad_(False)
|
|
||||||
.to(conv.weight.device)
|
|
||||||
)
|
|
||||||
|
|
||||||
# prepare filters
|
|
||||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
|
||||||
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
|
||||||
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
|
||||||
|
|
||||||
# prepare spatial bias
|
|
||||||
b_conv = (
|
|
||||||
torch.zeros(conv.weight.size(0), device=conv.weight.device)
|
|
||||||
if conv.bias is None
|
|
||||||
else conv.bias
|
|
||||||
)
|
|
||||||
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
|
|
||||||
torch.sqrt(bn.running_var + bn.eps)
|
|
||||||
)
|
|
||||||
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
|
||||||
|
|
||||||
return fusedconv
|
|
||||||
|
|
||||||
|
|
||||||
def fuse_model(model):
|
|
||||||
from yolov6.layers.common import Conv
|
|
||||||
|
|
||||||
for m in model.modules():
|
|
||||||
if type(m) is Conv and hasattr(m, "bn"):
|
|
||||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
|
||||||
delattr(m, "bn") # remove batchnorm
|
|
||||||
m.forward = m.forward_fuse # update forward
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_info(model, img_size=640):
|
|
||||||
"""Get model Params and GFlops.
|
|
||||||
Code base on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/model_utils.py
|
|
||||||
"""
|
|
||||||
from thop import profile
|
|
||||||
stride = 32
|
|
||||||
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
|
|
||||||
flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
|
|
||||||
params /= 1e6
|
|
||||||
flops /= 1e9
|
|
||||||
img_size = img_size if isinstance(img_size, list) else [img_size, img_size]
|
|
||||||
flops *= img_size[0] * img_size[1] / stride / stride * 2 # Gflops
|
|
||||||
info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
|
|
||||||
return info
|
|