parent
6111a17f48
commit
129fe13a27
@ -0,0 +1,31 @@
|
||||
zh_cuan 川
|
||||
zh_gan1 甘
|
||||
zh_hei 黑
|
||||
zh_jin 津
|
||||
zh_liao 辽
|
||||
zh_min 闽
|
||||
zh_qiong 琼
|
||||
zh_sx 晋
|
||||
zh_xin 新
|
||||
zh_yue 粤
|
||||
zh_zhe 浙
|
||||
zh_e 鄂
|
||||
zh_gui 贵
|
||||
zh_hu 沪
|
||||
zh_jing 京
|
||||
zh_lu 鲁
|
||||
zh_ning 宁
|
||||
zh_shan 陕
|
||||
zh_wan 皖
|
||||
zh_yu 豫
|
||||
zh_yun 云
|
||||
zh_gan 赣
|
||||
zh_gui1 桂
|
||||
zh_ji 冀
|
||||
zh_jl 吉
|
||||
zh_meng 蒙
|
||||
zh_qing 青
|
||||
zh_su 苏
|
||||
zh_xiang 湘
|
||||
zh_yu1 渝
|
||||
zh_zang 藏
|
@ -0,0 +1,196 @@
|
||||
#include "easypr/train/svm_train.h"
|
||||
#include "easypr/util/util.h"
|
||||
#include "easypr/config.h"
|
||||
|
||||
#ifdef OS_WINDOWS
|
||||
#include <ctime>
|
||||
#endif
|
||||
|
||||
using namespace cv;
|
||||
using namespace cv::ml;
|
||||
|
||||
|
||||
// 原版C++语言 训练代码
|
||||
namespace easypr {
|
||||
|
||||
SvmTrain::SvmTrain(const char* plates_folder, const char* xml): plates_folder_(plates_folder), svm_xml_(xml) {
|
||||
assert(plates_folder);
|
||||
assert(xml);
|
||||
extractFeature = getHistomPlusColoFeatures;
|
||||
}
|
||||
|
||||
void SvmTrain::train() {
|
||||
svm_ = cv::ml::SVM::create();
|
||||
svm_->setType(cv::ml::SVM::C_SVC);
|
||||
svm_->setKernel(cv::ml::SVM::RBF);
|
||||
svm_->setDegree(0.1);
|
||||
// 1.4 bug fix: old 1.4 ver gamma is 1
|
||||
svm_->setGamma(0.1);
|
||||
svm_->setCoef0(0.1);
|
||||
svm_->setC(1);
|
||||
svm_->setNu(0.1);
|
||||
svm_->setP(0.1);
|
||||
svm_->setTermCriteria(cvTermCriteria(CV_TERMCRIT_ITER, 20000, 0.0001));
|
||||
|
||||
this->prepare();
|
||||
|
||||
if (train_file_list_.size() == 0) {
|
||||
fprintf(stdout, "No file found in the train folder!\n");
|
||||
fprintf(stdout, "You should create a folder named \"tmp\" in EasyPR main folder.\n");
|
||||
fprintf(stdout, "Copy train data folder(like \"SVM\") under \"tmp\". \n");
|
||||
return;
|
||||
}
|
||||
auto train_data = tdata();
|
||||
|
||||
fprintf(stdout, ">> Training SVM model, please wait...\n");
|
||||
long start = utils::getTimestamp();
|
||||
svm_->trainAuto(train_data, 10, SVM::getDefaultGrid(SVM::C),
|
||||
SVM::getDefaultGrid(SVM::GAMMA), SVM::getDefaultGrid(SVM::P),
|
||||
SVM::getDefaultGrid(SVM::NU), SVM::getDefaultGrid(SVM::COEF),
|
||||
SVM::getDefaultGrid(SVM::DEGREE), true);
|
||||
//svm_->train(train_data);
|
||||
|
||||
long end = utils::getTimestamp();
|
||||
fprintf(stdout, ">> Training done. Time elapse: %ldms\n", end - start);
|
||||
fprintf(stdout, ">> Saving model file...\n");
|
||||
svm_->save(svm_xml_);
|
||||
|
||||
fprintf(stdout, ">> Your SVM Model was saved to %s\n", svm_xml_);
|
||||
fprintf(stdout, ">> Testing...\n");
|
||||
|
||||
this->test();
|
||||
|
||||
}
|
||||
|
||||
void SvmTrain::test() {
|
||||
// 1.4 bug fix: old 1.4 ver there is no null judge
|
||||
// if (NULL == svm_)
|
||||
LOAD_SVM_MODEL(svm_, svm_xml_);
|
||||
|
||||
if (test_file_list_.empty()) {
|
||||
this->prepare();
|
||||
}
|
||||
|
||||
double count_all = test_file_list_.size();
|
||||
double ptrue_rtrue = 0;
|
||||
double ptrue_rfalse = 0;
|
||||
double pfalse_rtrue = 0;
|
||||
double pfalse_rfalse = 0;
|
||||
|
||||
for (auto item : test_file_list_) {
|
||||
auto image = cv::imread(item.file);
|
||||
if (!image.data) {
|
||||
std::cout << "no" << std::endl;
|
||||
continue;
|
||||
}
|
||||
cv::Mat feature;
|
||||
extractFeature(image, feature);
|
||||
|
||||
auto predict = int(svm_->predict(feature));
|
||||
//std::cout << "predict: " << predict << std::endl;
|
||||
|
||||
auto real = item.label;
|
||||
if (predict == kForward && real == kForward) ptrue_rtrue++;
|
||||
if (predict == kForward && real == kInverse) ptrue_rfalse++;
|
||||
if (predict == kInverse && real == kForward) pfalse_rtrue++;
|
||||
if (predict == kInverse && real == kInverse) pfalse_rfalse++;
|
||||
}
|
||||
|
||||
std::cout << "count_all: " << count_all << std::endl;
|
||||
std::cout << "ptrue_rtrue: " << ptrue_rtrue << std::endl;
|
||||
std::cout << "ptrue_rfalse: " << ptrue_rfalse << std::endl;
|
||||
std::cout << "pfalse_rtrue: " << pfalse_rtrue << std::endl;
|
||||
std::cout << "pfalse_rfalse: " << pfalse_rfalse << std::endl;
|
||||
|
||||
double precise = 0;
|
||||
if (ptrue_rtrue + ptrue_rfalse != 0) {
|
||||
precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse);
|
||||
std::cout << "precise: " << precise << std::endl;
|
||||
} else {
|
||||
std::cout << "precise: "
|
||||
<< "NA" << std::endl;
|
||||
}
|
||||
|
||||
double recall = 0;
|
||||
if (ptrue_rtrue + pfalse_rtrue != 0) {
|
||||
recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue);
|
||||
std::cout << "recall: " << recall << std::endl;
|
||||
} else {
|
||||
std::cout << "recall: "
|
||||
<< "NA" << std::endl;
|
||||
}
|
||||
|
||||
double Fsocre = 0;
|
||||
if (precise + recall != 0) {
|
||||
Fsocre = 2 * (precise * recall) / (precise + recall);
|
||||
std::cout << "Fsocre: " << Fsocre << std::endl;
|
||||
} else {
|
||||
std::cout << "Fsocre: "
|
||||
<< "NA" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void SvmTrain::prepare() {
|
||||
srand(unsigned(time(NULL)));
|
||||
|
||||
char buffer[260] = {0};
|
||||
|
||||
sprintf(buffer, "%s/has/train", plates_folder_);
|
||||
auto has_file_train_list = utils::getFiles(buffer);
|
||||
std::random_shuffle(has_file_train_list.begin(), has_file_train_list.end());
|
||||
|
||||
sprintf(buffer, "%s/has/test", plates_folder_);
|
||||
auto has_file_test_list = utils::getFiles(buffer);
|
||||
std::random_shuffle(has_file_test_list.begin(), has_file_test_list.end());
|
||||
|
||||
sprintf(buffer, "%s/no/train", plates_folder_);
|
||||
auto no_file_train_list = utils::getFiles(buffer);
|
||||
std::random_shuffle(no_file_train_list.begin(), no_file_train_list.end());
|
||||
|
||||
sprintf(buffer, "%s/no/test", plates_folder_);
|
||||
auto no_file_test_list = utils::getFiles(buffer);
|
||||
std::random_shuffle(no_file_test_list.begin(), no_file_test_list.end());
|
||||
|
||||
fprintf(stdout, ">> Collecting train data...\n");
|
||||
|
||||
for (auto file : has_file_train_list)
|
||||
train_file_list_.push_back({ file, kForward });
|
||||
|
||||
for (auto file : no_file_train_list)
|
||||
train_file_list_.push_back({ file, kInverse });
|
||||
|
||||
fprintf(stdout, ">> Collecting test data...\n");
|
||||
|
||||
for (auto file : has_file_test_list)
|
||||
test_file_list_.push_back({ file, kForward });
|
||||
|
||||
for (auto file : no_file_test_list)
|
||||
test_file_list_.push_back({ file, kInverse });
|
||||
}
|
||||
|
||||
cv::Ptr<cv::ml::TrainData> SvmTrain::tdata() {
|
||||
cv::Mat samples;
|
||||
std::vector<int> responses;
|
||||
|
||||
for (auto f : train_file_list_) {
|
||||
auto image = cv::imread(f.file);
|
||||
if (!image.data) {
|
||||
fprintf(stdout, ">> Invalid image: %s ignore.\n", f.file.c_str());
|
||||
continue;
|
||||
}
|
||||
cv::Mat feature;
|
||||
extractFeature(image, feature);
|
||||
feature = feature.reshape(1, 1);
|
||||
|
||||
samples.push_back(feature);
|
||||
responses.push_back(int(f.label));
|
||||
}
|
||||
|
||||
cv::Mat samples_, responses_;
|
||||
samples.convertTo(samples_, CV_32FC1);
|
||||
cv::Mat(responses).copyTo(responses_);
|
||||
|
||||
return cv::ml::TrainData::create(samples_, cv::ml::SampleTypes::ROW_SAMPLE, responses_);
|
||||
}
|
||||
|
||||
} // namespace easypr
|
Loading…
Reference in new issue