no commit message

devA
yuxue 5 years ago
parent 17c56d0ecb
commit f5774cc35e

@ -9,9 +9,6 @@ import org.bytedeco.javacpp.opencv_core.Mat;
import org.bytedeco.javacpp.opencv_core.Rect;
import org.bytedeco.javacpp.opencv_core.Size;
import org.bytedeco.javacpp.opencv_ml.SVM;
import org.opencv.core.CvType;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
/**
@ -53,15 +50,19 @@ public class PlateJudge {
float ret = svm.predict(features);
return (int) ret;
/*opencv_imgproc.cvtColor(inMat, inMat, Imgproc.COLOR_BGR2GRAY);
Mat features = new Mat();
opencv_imgproc.Canny(inMat, features, 130, 250);
Mat p = features.reshape(1, 1);
p.convertTo(p, opencv_core.CV_32FC1);
float ret = svm.predict(p);
/*// 使用com.yuxue.test.PlateDetectTrainTest 生成的训练库文件
// 在使用的过程中,传入的样本切图要跟训练的时候处理切图的方法一致
Mat grayImage = new Mat();
opencv_imgproc.cvtColor(inMat, grayImage, opencv_imgproc.CV_RGB2GRAY);
Mat dst = new Mat();
opencv_imgproc.Canny(grayImage, dst, 130, 250);
Mat samples = dst.reshape(1, 1);
samples.convertTo(samples, opencv_core.CV_32FC1);
// 如果训练时使用这个标识那么符合的图像会返回9.0
float ret = svm.predict(samples);
return (int) ret;*/
}
/**

@ -0,0 +1,239 @@
package com.yuxue.test;
import static org.bytedeco.javacpp.opencv_core.CV_32F;
import static org.bytedeco.javacpp.opencv_core.CV_32FC1;
import static org.bytedeco.javacpp.opencv_core.CV_32SC1;
import static org.bytedeco.javacpp.opencv_core.getTickCount;
import static org.bytedeco.javacpp.opencv_imgproc.resize;
import java.util.Vector;
import org.bytedeco.javacpp.opencv_core.CvMemStorage;
import org.bytedeco.javacpp.opencv_core.FileStorage;
import org.bytedeco.javacpp.opencv_core.Mat;
import org.bytedeco.javacpp.opencv_core.Scalar;
import org.bytedeco.javacpp.opencv_core.Size;
import org.bytedeco.javacpp.opencv_ml.ANN_MLP;
import com.yuxue.easypr.core.CoreFunc;
import com.yuxue.enumtype.Direction;
import com.yuxue.util.Convert;
import com.yuxue.util.FileUtil;
import ch.qos.logback.classic.pattern.Util;
/*
*
*/
public class ANNTrain {
/*private ANN_MLP ann=ANN_MLP.create();
// 中国车牌
private final char strCharacters[] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E',
'F', 'G', 'H', I
'J', 'K', 'L', 'M', 'N', O 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' };
private final int numCharacter = 34; I0,1024
// 以下都是我训练时用到的中文字符数据,并不全面,有些省份没有训练数据所以没有字符
// 有些后面加数字2的表示在训练时常看到字符的一种变形也作为训练数据存储
private final String strChinese[] = { "zh_cuan" , "zh_e" , "zh_gan" , "zh_hei" ,
"zh_hu" , "zh_ji" , "zh_jl" , "zh_jin" , "zh_jing" , "zh_shan" ,
"zh_liao" , "zh_lu" , "zh_min" , "zh_ning" , "zh_su" , "zh_sx" ,
"zh_wan" , "zh_yu" , "zh_yue" , "zh_zhe" };
private final int numAll = 54; 34+20=54
public Mat features(Mat in, int sizeData) {
// Histogram features
float[] vhist = CoreFunc.projectedHistogram(in, Direction.VERTICAL);
float[] hhist = CoreFunc.projectedHistogram(in, Direction.HORIZONTAL);
// Low data feature
Mat lowData = new Mat();
resize(in, lowData, new Size(sizeData, sizeData));
// Last 10 is the number of moments components
int numCols = vhist.length + hhist.length + lowData.cols() * lowData.cols();
Mat out = Mat.zeros(1, numCols, CV_32F).asMat();
// Asign values to feature,ANN的样本特征为水平、垂直直方图和低分辨率图像所组成的矢量
int j = 0;
for (int i = 0; i < vhist.length; i++, ++j) {
out.ptr(j).put(Convert.getBytes(vhist[i]));
}
for (int i = 0; i < hhist.length; i++, ++j) {
out.ptr(j).put(Convert.getBytes(hhist[i]));
}
for (int x = 0; x < lowData.cols(); x++) {
for (int y = 0; y < lowData.rows(); y++, ++j) {
float val = lowData.ptr(x, y).get() & 0xFF;
out.ptr(j).put(Convert.getBytes(val));
}
}
// if(DEBUG)
// cout << out << "\n===========================================\n";
return out;
}
public void annTrain(Mat TrainData, Mat classes, int nNeruns) {
ann.clear();
Mat layers = new Mat(1, 3, CV_32SC1);
layers.ptr(0).put(Convert.getBytes(TrainData.cols()));
layers.ptr(1).put(Convert.getBytes(nNeruns));
layers.ptr(2).put(Convert.getBytes(numAll));
ann.create(layers, ANN_MLP.SIGMOID_SYM, 1, 1);
// Prepare trainClases
// Create a mat with n trained data by m classes
Mat trainClasses = new Mat();
trainClasses.create(TrainData.rows(), numAll, CV_32FC1);
for (int i = 0; i < trainClasses.rows(); i++) {
for (int k = 0; k < trainClasses.cols(); k++) {
// If class of data i is same than a k class
if (k == Convert.toInt(classes.ptr(i)))
trainClasses.ptr(i, k).put(Convert.getBytes(1f));
else
trainClasses.ptr(i, k).put(Convert.getBytes(0f));
}
}
Mat weights = new Mat(1, TrainData.rows(), CV_32FC1, Scalar.all(1));
// Learn classifier
ann.train(TrainData, trainClasses, weights);
}
public int saveTrainData() {
System.out.println("Begin saveTrainData");
Mat classes = new Mat();
Mat trainingDataf5 = new Mat();
Mat trainingDataf10 = new Mat();
Mat trainingDataf15 = new Mat();
Mat trainingDataf20 = new Mat();
Vector<Integer> trainingLabels = new Vector<Integer>();
String path = "res/train/data/chars_recognise_ann/chars2/chars2";
for (int i = 0; i < numCharacter; i++) {
System.out.println("Character: " + strCharacters[i]);
String str = path + '/' + strCharacters[i];
Vector<String> files = new Vector<String>();
FileUtil.getFiles(str, files);
int size = (int) files.size();
for (int j = 0; j < size; j++) {
System.out.println(files.get(j));
Mat img = imread(files.get(j), 0);
Mat f5 = features(img, 5);
Mat f10 = features(img, 10);
Mat f15 = features(img, 15);
Mat f20 = features(img, 20);
trainingDataf5.push_back(f5);
trainingDataf10.push_back(f10);
trainingDataf15.push_back(f15);
trainingDataf20.push_back(f20);
trainingLabels.add(i); // 每一幅字符图片所对应的字符类别索引下标
}
}
path = "res/train/data/chars_recognise_ann/charsChinese/charsChinese";
for (int i = 0; i < strChinese.length; i++) {
System.out.println("Character: " + strChinese[i]);
String str = path + '/' + strChinese[i];
Vector<String> files = new Vector<String>();
Util.getFiles(str, files);
int size = (int) files.size();
for (int j = 0; j < size; j++) {
System.out.println(files.get(j));
Mat img = imread(files.get(j), 0);
Mat f5 = features(img, 5);
Mat f10 = features(img, 10);
Mat f15 = features(img, 15);
Mat f20 = features(img, 20);
trainingDataf5.push_back(f5);
trainingDataf10.push_back(f10);
trainingDataf15.push_back(f15);
trainingDataf20.push_back(f20);
trainingLabels.add(i + numCharacter);
}
}
trainingDataf5.convertTo(trainingDataf5, CV_32FC1);
trainingDataf10.convertTo(trainingDataf10, CV_32FC1);
trainingDataf15.convertTo(trainingDataf15, CV_32FC1);
trainingDataf20.convertTo(trainingDataf20, CV_32FC1);
int[] labels = new int[trainingLabels.size()];
for (int i = 0; i < labels.length; ++i)
labels[i] = trainingLabels.get(i).intValue();
new Mat(labels).copyTo(classes);
FileStorage fs = new FileStorage("res/train/ann_data.xml", FileStorage.WRITE);
fs.writeObj("TrainingDataF5", trainingDataf5.data());
fs.writeObj("TrainingDataF10", trainingDataf10.data());
fs.writeObj("TrainingDataF15", trainingDataf15.data());
fs.writeObj("TrainingDataF20", trainingDataf20.data());
fs.writeObj("classes", classes.data());
fs.release();
System.out.println("End saveTrainData");
return 0;
}
public void saveModel(int _predictsize, int _neurons) {
FileStorage fs = new FileStorage("res/train/ann_data.xml", FileStorage.READ);
String training = "TrainingDataF" + _predictsize;
Mat TrainingData = new Mat(fs.get(training).readObj());
Mat Classes = new Mat(fs.get("classes"));
// train the Ann
System.out.println("Begin to saveModelChar predictSize:" + Integer.valueOf(_predictsize).toString());
System.out.println(" neurons:" + Integer.valueOf(_neurons).toString());
long start = getTickCount();
annTrain(TrainingData, Classes, _neurons);
long end = getTickCount();
System.out.println("GetTickCount:" + Long.valueOf((end - start) / 1000).toString());
System.out.println("End the saveModelChar");
String model_name = "res/train/ann.xml";
// if(1)
// {
// String str =
// String.format("ann_prd:%d\tneu:%d",_predictsize,_neurons);
// model_name = str;
// }
CvFileStorage fsto = CvFileStorage.open(model_name, CvMemStorage.create(), CV_STORAGE_WRITE);
ann.write(fsto, "ann");
}
public int annMain() {
System.out.println("To be begin.");
saveTrainData();
// 可根据需要训练不同的predictSize或者neurons的ANN模型
// for (int i = 2; i <= 2; i ++)
// {
// int size = i * 5;
// for (int j = 5; j <= 10; j++)
// {
// int neurons = j * 10;
// saveModel(size, neurons);
// }
// }
// 这里演示只训练model文件夹下的ann.xml此模型是一个predictSize=10,neurons=40的ANN模型。
// 根据机器的不同训练时间不一样但一般需要10分钟左右所以慢慢等一会吧。
saveModel(10, 40);
System.out.println("To be end.");
return 0;
}*/
}

@ -39,7 +39,6 @@ public class PlateDetectTrainTest {
public static void main(String[] arg) {
// 正样本 // 136 × 36 像素 训练的源图像文件要相同大小
List<File> imgList1 = FileUtil.listFile(new File(DEFAULT_PATH + "/learn/HasPlate"), Constant.DEFAULT_TYPE, false);
@ -84,14 +83,14 @@ public class PlateDetectTrainTest {
// 失败案例:这里我试图用 get(row,col,data)方法获取数组,但是结果和这个结果不一样,原因未知。
float[] arr = new float[dst.rows() * dst.cols()];
int l = 0;
for (int j = 0; j < dst.rows(); j++) {
for (int k = 0; k < dst.cols(); k++) {
for (int j = 0; j < dst.rows(); j++) { // 遍历行
for (int k = 0; k < dst.cols(); k++) { // 遍历列
double[] a = dst.get(j, k);
arr[l] = (float) a[0];
l++;
}
}
trainingDataMat.put(i, 0, arr);
trainingDataMat.put(i, 0, arr); // 多张图合并到一张
}
String module = DEFAULT_PATH + "svm.xml";

@ -1,5 +1,6 @@
package com.yuxue.test;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
@ -51,20 +52,22 @@ public class PlatePridectTest {
Mat dst = new Mat();
Imgproc.Canny(src, dst, 130, 250);
Mat samples = new Mat(1, dst.cols() * dst.rows(), CvType.CV_32FC1);
// 转换 src 图像的 cvtype
// 失败案例:我试图用 dst.convertTo(src, CvType.CV_32FC1); 转换,但是失败了,原因未知。猜测: 内部的数据类型没有转换?
float[] dataArr = new float[dst.cols() * dst.rows()];
for (int i = 0, f = 0; i < dst.rows(); i++) {
for (int j = 0; j < dst.cols(); j++) {
double pixel = dst.get(i, j)[0];
dataArr[f] = (float) pixel;
f++;
Mat samples = dst.reshape(1, 1);
samples.convertTo(samples, CvType.CV_32FC1);
// 等价于上面两行代码
/*Mat samples = new Mat(1, dst.cols() * dst.rows(), CvType.CV_32FC1);
float[] arr = new float[dst.cols() * dst.rows()];
int l = 0;
for (int j = 0; j < dst.rows(); j++) { // 遍历行
for (int k = 0; k < dst.cols(); k++) { // 遍历列
double[] a = dst.get(j, k);
arr[l] = (float) a[0];
l++;
}
}
samples.put(0, 0, arr);*/
Imgcodecs.imwrite(DEFAULT_PATH + "test_1.jpg", samples);
samples.put(0, 0, dataArr);
// 如果训练时使用这个标识那么符合的图像会返回9.0
float flag = svm.predict(samples);

@ -0,0 +1,382 @@
package com.yuxue.test;
import static org.bytedeco.javacpp.opencv_core.*;
import static org.bytedeco.javacpp.opencv_highgui.*;
import static org.bytedeco.javacpp.opencv_ml.*;
import java.util.*;
/*
* Created by fanwenjie
* @version 1.1
*/
public class SVMTrain {
/*private SVMCallback callback = new Features();
private static final String hasPlate = "HasPlate";
private static final String noPlate = "NoPlate";
public SVMTrain(SVMCallback callback){
this.callback = callback;
}
public SVMTrain(){}
private void learn2Plate(float bound, final String name) {
final String filePath = "res/train/data/plate_detect_svm/learn/" + name;
Vector<String> files = new Vector<String>();
////获取该路径下的所有文件
Util.getFiles(filePath, files);
int size = files.size();
if (0 == size) {
System.out.println("File not found in " + filePath);
return;
}
Collections.shuffle(files, new Random(new Date().getTime()));
////随机选取70%作为训练数据30%作为测试数据
int boundry = (int) (bound * size);
Util.recreateDir("res/train/data/plate_detect_svm/train/" + name);
Util.recreateDir("res/train/data/plate_detect_svm/test/" + name);
System.out.println("Save " + name + " train!");
for (int i = 0; i < boundry; i++) {
System.out.println(files.get(i));
Mat img = imread(files.get(i));
String str = "res/train/data/plate_detect_svm/train/" + name + "/" + name + "_" + Integer.valueOf(i).toString() + ".jpg";
imwrite(str, img);
}
System.out.println("Save " + name + " test!");
for (int i = boundry; i < size; i++) {
System.out.println(files.get(i));
Mat img = imread(files.get(i));
String str = "res/train/data/plate_detect_svm/test/" + name + "/" + name + "_" + Integer.valueOf(i).toString() + ".jpg";
imwrite(str, img);
}
}
private void getPlateTrain(Mat trainingImages, Vector<Integer> trainingLabels, final String name) {
int label = 1;
final String filePath = "res/train/data/plate_detect_svm/train/" + name;
Vector<String> files = new Vector<String>();
////获取该路径下的所有文件
Util.getFiles(filePath, files);
int size = files.size();
if (0 == size) {
System.out.println("File not found in " + filePath);
return;
}
System.out.println("get " + name + " train!");
for (int i = 0; i < size; i++) {
//System.out.println(files[i].c_str()).toString());
Mat img = imread(files.get(i));
//调用回调函数决定特征
Mat features = this.callback.getHisteqFeatures(img);
features = features.reshape(1, 1);
trainingImages.push_back(features);
trainingLabels.add(label);
}
}
private void getPlateTest(MatVector testingImages,Vector<Integer> testingLabels,final String name){
int label = 1;
final String filePath = "res/train/data/plate_detect_svm/test/"+name;
Vector<String> files = new Vector<String>();
Util.getFiles(filePath, files);
int size = files.size();
if (0 == size) {
System.out.println("File not found in " + filePath);
return;
}
System.out.println("get "+name+" test!");
for (int i = 0; i < size; i++)
{
Mat img = imread(files.get(i));
testingImages.put(img);
testingLabels.add(label);
}
}
public void learn2HasPlate() {
learn2HasPlate(0.7f);
}
public void learn2HasPlate(float bound) {
learn2Plate(bound, hasPlate);
}
public void learn2NoPlate() {
learn2NoPlate(0.7f);
}
public void learn2NoPlate(float bound) {
learn2Plate(bound, noPlate);
}
public void getNoPlateTrain(Mat trainingImages, Vector<Integer> trainingLabels) {
getPlateTrain(trainingImages, trainingLabels, noPlate);
}
public void getHasPlateTrain(Mat trainingImages, Vector<Integer> trainingLabels) {
getPlateTrain(trainingImages, trainingLabels, hasPlate);
}
public void getHasPlateTest(MatVector testingImages,Vector<Integer> testingLabels)
{
getPlateTest(testingImages,testingLabels,hasPlate);
}
public void getNoPlateTest(MatVector testingImages,Vector<Integer> testingLabels)
{
getPlateTest(testingImages,testingLabels,noPlate);
}
//! 测试SVM的准确率回归率以及FScore
public void getAccuracy(Mat testingclasses_preditc, Mat testingclasses_real)
{
int channels = testingclasses_preditc.channels();
System.out.println("channels: "+Integer.valueOf(channels).toString());
int nRows = testingclasses_preditc.rows();
System.out.println("nRows: "+Integer.valueOf(nRows).toString());
int nCols = testingclasses_preditc.cols() * channels;
System.out.println("nCols: "+Integer.valueOf(nCols).toString());
int channels_real = testingclasses_real.channels();
System.out.println("channels_real: "+Integer.valueOf(channels_real).toString());
int nRows_real = testingclasses_real.rows();
System.out.println("nRows_real: " + Integer.valueOf(nRows_real).toString());
int nCols_real = testingclasses_real.cols() * channels;
System.out.println("nCols_real: "+Integer.valueOf(nCols_real).toString());
double count_all = 0;
double ptrue_rtrue = 0;
double ptrue_rfalse = 0;
double pfalse_rtrue = 0;
double pfalse_rfalse = 0;
for (int i = 0; i < nRows; i++)
{
final float predict = Convert.toFloat(testingclasses_preditc.ptr(i));
final float real = Convert.toFloat(testingclasses_real.ptr(i));
count_all ++;
//System.out.println("predict:" << predict).toString());
//System.out.println("real:" << real).toString());
if (predict == 1.0 && real == 1.0)
ptrue_rtrue ++;
if (predict == 1.0 && real == 0)
ptrue_rfalse ++;
if (predict == 0 && real == 1.0)
pfalse_rtrue ++;
if (predict == 0 && real == 0)
pfalse_rfalse ++;
}
System.out.println("count_all: "+Double.valueOf(count_all).toString());
System.out.println("ptrue_rtrue: "+Double.valueOf(ptrue_rtrue).toString());
System.out.println("ptrue_rfalse: "+Double.valueOf(ptrue_rfalse).toString());
System.out.println("pfalse_rtrue: "+Double.valueOf(pfalse_rtrue).toString());
System.out.println("pfalse_rfalse: "+Double.valueOf(pfalse_rfalse).toString());
double precise = 0;
if (ptrue_rtrue + ptrue_rfalse != 0)
{
precise = ptrue_rtrue/(ptrue_rtrue + ptrue_rfalse);
System.out.println("precise: "+Double.valueOf(precise).toString());
}
else
{
System.out.println("precise: NA");
}
double recall = 0;
if (ptrue_rtrue + pfalse_rtrue != 0)
{
recall = ptrue_rtrue/(ptrue_rtrue + pfalse_rtrue);
System.out.println("recall: "+Double.valueOf(recall).toString());
}
else
{
System.out.println("recall: NA");
}
if (precise + recall != 0)
{
double F = (precise * recall)/(precise + recall);
System.out.println("F: "+Double.valueOf(F).toString());
}
else
{
System.out.println("F: NA");
}
}
public int svmTrain(boolean dividePrepared, boolean trainPrepared)
{
Mat classes = new Mat();
Mat trainingData = new Mat();
Mat trainingImages = new Mat();
Vector<Integer> trainingLabels = new Vector<Integer>();
if (!dividePrepared)
{
//分割learn里的数据到train和test里
System.out.println("Divide learn to train and test");
learn2HasPlate();
learn2NoPlate();
}
//将训练数据加载入内存
if (!trainPrepared)
{
System.out.print("Begin to get train data to memory");
getHasPlateTrain(trainingImages, trainingLabels);
getNoPlateTrain(trainingImages, trainingLabels);
trainingImages.copyTo(trainingData);
trainingData.convertTo(trainingData, CV_32FC1);
int []labels = new int[trainingLabels.size()];
for(int i=0;i<trainingLabels.size();++i)
labels[i] = trainingLabels.get(i).intValue();
new Mat(labels).copyTo(classes);
}
//Test SVM
MatVector testingImages = new MatVector();
Vector<Integer> testingLabels_real = new Vector<Integer>();
//将测试数据加载入内存
System.out.println("Begin to get test data to memory");
getHasPlateTest(testingImages, testingLabels_real);
getNoPlateTest(testingImages, testingLabels_real);
CvSVM svm = new CvSVM();
if (!trainPrepared && !classes.empty() && !trainingData.empty())
{
CvSVMParams SVM_params = new CvSVMParams(CvSVM.C_SVC,CvSVM.RBF,0.1,1,0.1,1,0.1,0.1,
new CvMat(),new CvTermCriteria().type(CV_TERMCRIT_ITER).max_iter(100000).epsilon(0.0001));
//Train SVM
System.out.println("Begin to generate svm");
try {
//CvSVM svm(trainingData, classes, Mat(), Mat(), SVM_params);
svm.train_auto(trainingData, classes, new Mat(), new Mat(), SVM_params, 10,
CvSVM.get_default_grid(CvSVM.C),
CvSVM.get_default_grid(CvSVM.GAMMA),
CvSVM.get_default_grid(CvSVM.P),
CvSVM.get_default_grid(CvSVM.NU),
CvSVM.get_default_grid(CvSVM.COEF),
CvSVM.get_default_grid(CvSVM.DEGREE),
true);
} catch (Exception err) {
System.out.println(err.getMessage());
}
System.out.println("Svm generate done!");
CvFileStorage fsTo = CvFileStorage.open("res/rain/svm.xml", CvMemStorage.create(),CV_STORAGE_WRITE);
svm.write(fsTo, "svm");
}
else
{
try {
String path = "res/train/svm.xml";
svm.load(path, "svm");
} catch (Exception err) {
System.out.println(err.getMessage());
return 0; //next predict requires svm
}
}
System.out.println("Begin to predict");
double count_all = 0;
double ptrue_rtrue = 0;
double ptrue_rfalse = 0;
double pfalse_rtrue = 0;
double pfalse_rfalse = 0;
int size = (int)testingImages.size();
for (int i = 0; i < size; i++)
{
//System.out.println(files[i].c_str());
Mat p = testingImages.get(i);
//调用回调函数决定特征
Mat features = callback.getHistogramFeatures(p);
features = features.reshape(1, 1);
features.convertTo(features, CV_32FC1);
int predict = (int)svm.predict(features);
int real = testingLabels_real.get(i);
if (predict == 1 && real == 1)
ptrue_rtrue ++;
if (predict == 1 && real == 0)
ptrue_rfalse ++;
if (predict == 0 && real == 1)
pfalse_rtrue ++;
if (predict == 0 && real == 0)
pfalse_rfalse ++;
}
count_all = size;
System.out.println("Get the Accuracy!");
System.out.println("count_all: "+Double.valueOf(count_all).toString());
System.out.println("ptrue_rtrue: "+Double.valueOf(ptrue_rtrue).toString());
System.out.println("ptrue_rfalse: "+Double.valueOf(ptrue_rfalse).toString());
System.out.println("pfalse_rtrue: "+Double.valueOf(pfalse_rtrue).toString());
System.out.println("pfalse_rfalse: "+Double.valueOf(pfalse_rfalse).toString());
double precise = 0;
if (ptrue_rtrue + ptrue_rfalse != 0)
{
precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse);
System.out.println("precise: "+Double.valueOf(precise).toString());
}
else
System.out.println("precise: NA");
double recall = 0;
if (ptrue_rtrue + pfalse_rtrue != 0)
{
recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue);
System.out.println("recall: "+Double.valueOf(recall).toString());
}
else
System.out.println("recall: NA");
double Fsocre = 0;
if (precise + recall != 0)
{
Fsocre = 2 * (precise * recall) / (precise + recall);
System.out.println("Fsocre: "+Double.valueOf(Fsocre).toString());
}
else
System.out.println("Fsocre: NA");
return 0;
}*/
}

@ -37,13 +37,13 @@ public class trainsvm {
openFile(1, DEFAULT_PATH + "/learn/HasPlate");
openFile(0, DEFAULT_PATH + "/learn/NoPlate");
Mat srcImgs = new Mat();
Mat flags = new Mat(trainingLabels.size(), 1, CvType.CV_32SC1);
Mat labelsMat = new Mat(trainingLabels.size(), 1, CvType.CV_32SC1);
Core.vconcat(trainingImages, srcImgs); // 样本数量不能太大trainingImages.size有限制
for (int i = 0; i < trainingLabels.size(); i++) {
int[] val = { trainingLabels.get(i) };
flags.put(i, 0, val);
labelsMat.put(i, 0, val);
}
SVM svm = SVM.create();
svm.setKernel(SVM.LINEAR);
@ -54,7 +54,7 @@ public class trainsvm {
svm.setNu(0);
svm.setP(0);
svm.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER, 20000, 0.0001));
TrainData trainData = TrainData.create(srcImgs, Ml.ROW_SAMPLE, flags);
TrainData trainData = TrainData.create(srcImgs, Ml.ROW_SAMPLE, labelsMat);
boolean success = svm.train(trainData);
System.out.println(success);
svm.save( DEFAULT_PATH + "svm.xml");

Loading…
Cancel
Save