You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

197 lines
5.7 KiB

5 years ago
#include "easypr/train/svm_train.h"
#include "easypr/util/util.h"
#include "easypr/config.h"
#include <string>
#ifdef OS_WINDOWS
#include <ctime>
#endif
using namespace cv;
using namespace cv::ml;
// Original C++ training code
namespace easypr {
// Constructs a trainer.
//   plates_folder: folder holding the training/test images; must not be null.
//   xml:           path of the SVM model file; must not be null.
// getHistomPlusColoFeatures is installed as the feature extractor.
SvmTrain::SvmTrain(const char* plates_folder, const char* xml)
    : plates_folder_(plates_folder),
      svm_xml_(xml) {
  // Both arguments are required; a null pointer is a programming error.
  assert(plates_folder != nullptr);
  assert(xml != nullptr);

  extractFeature = getHistomPlusColoFeatures;
}
// Trains an RBF C-SVC model on the files collected by prepare(), saves it to
// svm_xml_, and then runs test() on the result.
//
// Returns early (after printing a hint) when no training files were found,
// and returns without saving when trainAuto() reports failure.
void SvmTrain::train() {
  svm_ = cv::ml::SVM::create();
  svm_->setType(cv::ml::SVM::C_SVC);
  svm_->setKernel(cv::ml::SVM::RBF);
  svm_->setDegree(0.1);
  // 1.4 bug fix: old 1.4 ver gamma is 1
  svm_->setGamma(0.1);
  svm_->setCoef0(0.1);
  svm_->setC(1);
  svm_->setNu(0.1);
  svm_->setP(0.1);
  // Use the C++ cv::TermCriteria type instead of the legacy C helper
  // cvTermCriteria()/CV_TERMCRIT_ITER; the values are unchanged.
  svm_->setTermCriteria(
      cv::TermCriteria(cv::TermCriteria::MAX_ITER, 20000, 0.0001));

  this->prepare();

  if (train_file_list_.size() == 0) {
    fprintf(stdout, "No file found in the train folder!\n");
    fprintf(stdout, "You should create a folder named \"tmp\" in EasyPR main folder.\n");
    fprintf(stdout, "Copy train data folder(like \"SVM\") under \"tmp\". \n");
    return;
  }

  auto train_data = tdata();

  fprintf(stdout, ">> Training SVM model, please wait...\n");
  long start = utils::getTimestamp();
  // 10-fold search over the default parameter grids.
  const bool trained = svm_->trainAuto(
      train_data, 10, SVM::getDefaultGrid(SVM::C),
      SVM::getDefaultGrid(SVM::GAMMA), SVM::getDefaultGrid(SVM::P),
      SVM::getDefaultGrid(SVM::NU), SVM::getDefaultGrid(SVM::COEF),
      SVM::getDefaultGrid(SVM::DEGREE), true);
  //svm_->train(train_data);
  long end = utils::getTimestamp();

  // Do not save or evaluate a model that failed to train.
  if (!trained) {
    fprintf(stdout, ">> Training failed; no model was saved.\n");
    return;
  }

  fprintf(stdout, ">> Training done. Time elapse: %ldms\n", end - start);
  fprintf(stdout, ">> Saving model file...\n");
  svm_->save(svm_xml_);
  fprintf(stdout, ">> Your SVM Model was saved to %s\n", svm_xml_);

  fprintf(stdout, ">> Testing...\n");
  this->test();
}
void SvmTrain::test() {
// 1.4 bug fix: old 1.4 ver there is no null judge
// if (NULL == svm_)
LOAD_SVM_MODEL(svm_, svm_xml_);
if (test_file_list_.empty()) {
this->prepare();
}
double count_all = test_file_list_.size();
double ptrue_rtrue = 0;
double ptrue_rfalse = 0;
double pfalse_rtrue = 0;
double pfalse_rfalse = 0;
for (auto item : test_file_list_) {
auto image = cv::imread(item.file);
if (!image.data) {
std::cout << "no" << std::endl;
continue;
}
cv::Mat feature;
extractFeature(image, feature);
auto predict = int(svm_->predict(feature));
//std::cout << "predict: " << predict << std::endl;
auto real = item.label;
if (predict == kForward && real == kForward) ptrue_rtrue++;
if (predict == kForward && real == kInverse) ptrue_rfalse++;
if (predict == kInverse && real == kForward) pfalse_rtrue++;
if (predict == kInverse && real == kInverse) pfalse_rfalse++;
}
std::cout << "count_all: " << count_all << std::endl;
std::cout << "ptrue_rtrue: " << ptrue_rtrue << std::endl;
std::cout << "ptrue_rfalse: " << ptrue_rfalse << std::endl;
std::cout << "pfalse_rtrue: " << pfalse_rtrue << std::endl;
std::cout << "pfalse_rfalse: " << pfalse_rfalse << std::endl;
double precise = 0;
if (ptrue_rtrue + ptrue_rfalse != 0) {
precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse);
std::cout << "precise: " << precise << std::endl;
} else {
std::cout << "precise: "
<< "NA" << std::endl;
}
double recall = 0;
if (ptrue_rtrue + pfalse_rtrue != 0) {
recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue);
std::cout << "recall: " << recall << std::endl;
} else {
std::cout << "recall: "
<< "NA" << std::endl;
}
double Fsocre = 0;
if (precise + recall != 0) {
Fsocre = 2 * (precise * recall) / (precise + recall);
std::cout << "Fsocre: " << Fsocre << std::endl;
} else {
std::cout << "Fsocre: "
<< "NA" << std::endl;
}
}
// Fills train_file_list_ and test_file_list_ from four sub-folders of
// plates_folder_: has/train, has/test, no/train and no/test. Files found
// under "has" are labelled kForward; files found under "no" are labelled
// kInverse. Each folder's file list is shuffled before being appended.
// Entries are appended, so anything already in the lists is kept.
void SvmTrain::prepare() {
  srand(unsigned(time(NULL)));

  // Paths are assembled with std::string rather than sprintf() into a fixed
  // 260-byte buffer, which overflowed when plates_folder_ was long.
  const std::string root(plates_folder_);

  auto has_file_train_list = utils::getFiles((root + "/has/train").c_str());
  std::random_shuffle(has_file_train_list.begin(), has_file_train_list.end());

  auto has_file_test_list = utils::getFiles((root + "/has/test").c_str());
  std::random_shuffle(has_file_test_list.begin(), has_file_test_list.end());

  auto no_file_train_list = utils::getFiles((root + "/no/train").c_str());
  std::random_shuffle(no_file_train_list.begin(), no_file_train_list.end());

  auto no_file_test_list = utils::getFiles((root + "/no/test").c_str());
  std::random_shuffle(no_file_test_list.begin(), no_file_test_list.end());

  fprintf(stdout, ">> Collecting train data...\n");

  // Iterate by const reference: each path is copied once, into the list.
  for (const auto& file : has_file_train_list)
    train_file_list_.push_back({ file, kForward });

  for (const auto& file : no_file_train_list)
    train_file_list_.push_back({ file, kInverse });

  fprintf(stdout, ">> Collecting test data...\n");

  for (const auto& file : has_file_test_list)
    test_file_list_.push_back({ file, kForward });

  for (const auto& file : no_file_test_list)
    test_file_list_.push_back({ file, kInverse });
}
// Builds the OpenCV training set from train_file_list_.
// Each readable image contributes one feature row (via extractFeature) and
// its label as the response; unreadable images are reported and skipped.
// Returns row-sample TrainData with the samples converted to CV_32FC1.
cv::Ptr<cv::ml::TrainData> SvmTrain::tdata() {
  cv::Mat feature_rows;
  std::vector<int> labels;

  for (const auto& item : train_file_list_) {
    cv::Mat image = cv::imread(item.file);
    if (image.data == nullptr) {
      fprintf(stdout, ">> Invalid image: %s ignore.\n", item.file.c_str());
      continue;
    }

    cv::Mat feature;
    extractFeature(image, feature);

    // Flatten to a single-channel, single-row matrix before appending.
    const cv::Mat row = feature.reshape(1, 1);
    feature_rows.push_back(row);
    labels.push_back(int(item.label));
  }

  cv::Mat samples_f32;
  feature_rows.convertTo(samples_f32, CV_32FC1);

  // Copy the labels into a Mat that owns its data.
  cv::Mat responses_mat;
  cv::Mat(labels).copyTo(responses_mat);

  return cv::ml::TrainData::create(samples_f32,
                                   cv::ml::SampleTypes::ROW_SAMPLE,
                                   responses_mat);
}
} // namespace easypr