From f5774cc35e6dd5bbea7c38797be7b382a32ec4c0 Mon Sep 17 00:00:00 2001 From: yuxue Date: Wed, 13 May 2020 22:17:15 +0800 Subject: [PATCH] no commit message --- .../com/yuxue/easypr/core/PlateJudge.java | 25 +- src/test/java/com/yuxue/test/ANNTrain.java | 239 +++++++++++ .../com/yuxue/test/PlateDetectTrainTest.java | 7 +- .../java/com/yuxue/test/PlatePridectTest.java | 27 +- src/test/java/com/yuxue/test/SVMTrain.java | 382 ++++++++++++++++++ src/test/java/com/yuxue/test/trainsvm.java | 6 +- 6 files changed, 655 insertions(+), 31 deletions(-) create mode 100644 src/test/java/com/yuxue/test/ANNTrain.java create mode 100644 src/test/java/com/yuxue/test/SVMTrain.java diff --git a/src/main/java/com/yuxue/easypr/core/PlateJudge.java b/src/main/java/com/yuxue/easypr/core/PlateJudge.java index 17483796..60ca6e38 100644 --- a/src/main/java/com/yuxue/easypr/core/PlateJudge.java +++ b/src/main/java/com/yuxue/easypr/core/PlateJudge.java @@ -9,9 +9,6 @@ import org.bytedeco.javacpp.opencv_core.Mat; import org.bytedeco.javacpp.opencv_core.Rect; import org.bytedeco.javacpp.opencv_core.Size; import org.bytedeco.javacpp.opencv_ml.SVM; -import org.opencv.core.CvType; -import org.opencv.imgcodecs.Imgcodecs; -import org.opencv.imgproc.Imgproc; /** @@ -52,16 +49,20 @@ public class PlateJudge { p.convertTo(p, opencv_core.CV_32FC1); float ret = svm.predict(features); return (int) ret; - - /*opencv_imgproc.cvtColor(inMat, inMat, Imgproc.COLOR_BGR2GRAY); - Mat features = new Mat(); - opencv_imgproc.Canny(inMat, features, 130, 250); - - Mat p = features.reshape(1, 1); - p.convertTo(p, opencv_core.CV_32FC1); - - float ret = svm.predict(p); + + /*// 使用com.yuxue.test.PlateDetectTrainTest 生成的训练库文件 + // 在使用的过程中,传入的样本切图要跟训练的时候处理切图的方法一致 + Mat grayImage = new Mat(); + opencv_imgproc.cvtColor(inMat, grayImage, opencv_imgproc.CV_RGB2GRAY); + Mat dst = new Mat(); + opencv_imgproc.Canny(grayImage, dst, 130, 250); + Mat samples = dst.reshape(1, 1); + samples.convertTo(samples, opencv_core.CV_32FC1); + + // 如果训练时使用这个标识,那么符合的图像会返回9.0 + float ret = svm.predict(samples); return (int) ret;*/ + } /** diff --git a/src/test/java/com/yuxue/test/ANNTrain.java b/src/test/java/com/yuxue/test/ANNTrain.java new file mode 100644 index 00000000..a91a0447 --- /dev/null +++ b/src/test/java/com/yuxue/test/ANNTrain.java @@ -0,0 +1,239 @@ +package com.yuxue.test; + +import static org.bytedeco.javacpp.opencv_core.CV_32F; +import static org.bytedeco.javacpp.opencv_core.CV_32FC1; +import static org.bytedeco.javacpp.opencv_core.CV_32SC1; +import static org.bytedeco.javacpp.opencv_core.getTickCount; +import static org.bytedeco.javacpp.opencv_imgproc.resize; + +import java.util.Vector; + +import org.bytedeco.javacpp.opencv_core.CvMemStorage; +import org.bytedeco.javacpp.opencv_core.FileStorage; +import org.bytedeco.javacpp.opencv_core.Mat; +import org.bytedeco.javacpp.opencv_core.Scalar; +import org.bytedeco.javacpp.opencv_core.Size; +import org.bytedeco.javacpp.opencv_ml.ANN_MLP; + +import com.yuxue.easypr.core.CoreFunc; +import com.yuxue.enumtype.Direction; +import com.yuxue.util.Convert; +import com.yuxue.util.FileUtil; + +import ch.qos.logback.classic.pattern.Util; + + +/* + * + */ +public class ANNTrain { + + /*private ANN_MLP ann=ANN_MLP.create(); + + // 中国车牌 + private final char strCharacters[] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', + 'F', 'G', 'H', 没有I + 'J', 'K', 'L', 'M', 'N', 没有O 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' }; + private final int numCharacter = 34; 没有I和0,10个数字与24个英文字符之和 + + // 以下都是我训练时用到的中文字符数据,并不全面,有些省份没有训练数据所以没有字符 + // 有些后面加数字2的表示在训练时常看到字符的一种变形,也作为训练数据存储 + private final String strChinese[] = { "zh_cuan" 川 , "zh_e" 鄂 , "zh_gan" 赣 , "zh_hei" 黑 , + "zh_hu" 沪 , "zh_ji" 冀 , "zh_jl" 吉 , "zh_jin" 津 , "zh_jing" 京 , "zh_shan" 陕 , + "zh_liao" 辽 , "zh_lu" 鲁 , "zh_min" 闽 , "zh_ning" 宁 , "zh_su" 苏 , "zh_sx" 晋 , + "zh_wan" 皖 , "zh_yu" 豫 , "zh_yue" 粤 , "zh_zhe" 浙 }; + + private final int numAll = 54; 34+20=54 + + public Mat features(Mat in, int sizeData) { + // Histogram features + float[] vhist = CoreFunc.projectedHistogram(in, Direction.VERTICAL); + float[] hhist = CoreFunc.projectedHistogram(in, Direction.HORIZONTAL); + + // Low data feature + Mat lowData = new Mat(); + resize(in, lowData, new Size(sizeData, sizeData)); + + // Last 10 is the number of moments components + int numCols = vhist.length + hhist.length + lowData.cols() * lowData.cols(); + + Mat out = Mat.zeros(1, numCols, CV_32F).asMat(); + // Asign values to feature,ANN的样本特征为水平、垂直直方图和低分辨率图像所组成的矢量 + int j = 0; + for (int i = 0; i < vhist.length; i++, ++j) { + out.ptr(j).put(Convert.getBytes(vhist[i])); + } + for (int i = 0; i < hhist.length; i++, ++j) { + out.ptr(j).put(Convert.getBytes(hhist[i])); + } + for (int x = 0; x < lowData.cols(); x++) { + for (int y = 0; y < lowData.rows(); y++, ++j) { + float val = lowData.ptr(x, y).get() & 0xFF; + out.ptr(j).put(Convert.getBytes(val)); + } + } + // if(DEBUG) + // cout << out << "\n===========================================\n"; + return out; + } + + public void annTrain(Mat TrainData, Mat classes, int nNeruns) { + ann.clear(); + Mat layers = new Mat(1, 3, CV_32SC1); + layers.ptr(0).put(Convert.getBytes(TrainData.cols())); + layers.ptr(1).put(Convert.getBytes(nNeruns)); + layers.ptr(2).put(Convert.getBytes(numAll)); + ann.create(layers, ANN_MLP.SIGMOID_SYM, 1, 1); + + // Prepare trainClases + // Create a mat with n trained data by m classes + Mat trainClasses = new Mat(); + trainClasses.create(TrainData.rows(), numAll, CV_32FC1); + for (int i = 0; i < trainClasses.rows(); i++) { + for (int k = 0; k < trainClasses.cols(); k++) { + // If class of data i is same than a k class + if (k == Convert.toInt(classes.ptr(i))) + trainClasses.ptr(i, k).put(Convert.getBytes(1f)); + else + trainClasses.ptr(i, k).put(Convert.getBytes(0f)); + } + } + Mat weights = new Mat(1, TrainData.rows(), CV_32FC1, Scalar.all(1)); + // Learn classifier + ann.train(TrainData, trainClasses, weights); + } + + public int saveTrainData() { + System.out.println("Begin saveTrainData"); + Mat classes = new Mat(); + Mat trainingDataf5 = new Mat(); + Mat trainingDataf10 = new Mat(); + Mat trainingDataf15 = new Mat(); + Mat trainingDataf20 = new Mat(); + + Vector trainingLabels = new Vector(); + String path = "res/train/data/chars_recognise_ann/chars2/chars2"; + + for (int i = 0; i < numCharacter; i++) { + System.out.println("Character: " + strCharacters[i]); + String str = path + '/' + strCharacters[i]; + Vector files = new Vector(); + FileUtil.getFiles(str, files); + + int size = (int) files.size(); + for (int j = 0; j < size; j++) { + System.out.println(files.get(j)); + Mat img = imread(files.get(j), 0); + Mat f5 = features(img, 5); + Mat f10 = features(img, 10); + Mat f15 = features(img, 15); + Mat f20 = features(img, 20); + + trainingDataf5.push_back(f5); + trainingDataf10.push_back(f10); + trainingDataf15.push_back(f15); + trainingDataf20.push_back(f20); + trainingLabels.add(i); // 每一幅字符图片所对应的字符类别索引下标 + } + } + + path = "res/train/data/chars_recognise_ann/charsChinese/charsChinese"; + + for (int i = 0; i < strChinese.length; i++) { + System.out.println("Character: " + strChinese[i]); + String str = path + '/' + strChinese[i]; + Vector files = new Vector(); + Util.getFiles(str, files); + + int size = (int) files.size(); + for (int j = 0; j < size; j++) { + System.out.println(files.get(j)); + Mat img = imread(files.get(j), 0); + Mat f5 = features(img, 5); + Mat f10 = features(img, 10); + Mat f15 = features(img, 15); + Mat f20 = features(img, 20); + + trainingDataf5.push_back(f5); + trainingDataf10.push_back(f10); + trainingDataf15.push_back(f15); + trainingDataf20.push_back(f20); + trainingLabels.add(i + numCharacter); + } + } + + trainingDataf5.convertTo(trainingDataf5, CV_32FC1); + trainingDataf10.convertTo(trainingDataf10, CV_32FC1); + trainingDataf15.convertTo(trainingDataf15, CV_32FC1); + trainingDataf20.convertTo(trainingDataf20, CV_32FC1); + int[] labels = new int[trainingLabels.size()]; + for (int i = 0; i < labels.length; ++i) + labels[i] = trainingLabels.get(i).intValue(); + new Mat(labels).copyTo(classes); + + FileStorage fs = new FileStorage("res/train/ann_data.xml", FileStorage.WRITE); + fs.writeObj("TrainingDataF5", trainingDataf5.data()); + fs.writeObj("TrainingDataF10", trainingDataf10.data()); + fs.writeObj("TrainingDataF15", trainingDataf15.data()); + fs.writeObj("TrainingDataF20", trainingDataf20.data()); + fs.writeObj("classes", classes.data()); + fs.release(); + + System.out.println("End saveTrainData"); + return 0; + } + + public void saveModel(int _predictsize, int _neurons) { + FileStorage fs = new FileStorage("res/train/ann_data.xml", FileStorage.READ); + String training = "TrainingDataF" + _predictsize; + Mat TrainingData = new Mat(fs.get(training).readObj()); + Mat Classes = new Mat(fs.get("classes")); + + // train the Ann + System.out.println("Begin to saveModelChar predictSize:" + Integer.valueOf(_predictsize).toString()); + System.out.println(" neurons:" + Integer.valueOf(_neurons).toString()); + + long start = getTickCount(); + annTrain(TrainingData, Classes, _neurons); + long end = getTickCount(); + System.out.println("GetTickCount:" + Long.valueOf((end - start) / 1000).toString()); + + System.out.println("End the saveModelChar"); + + String model_name = "res/train/ann.xml"; + + // if(1) + // { + // String str = + // String.format("ann_prd:%d\tneu:%d",_predictsize,_neurons); + // model_name = str; + // } + + CvFileStorage fsto = CvFileStorage.open(model_name, CvMemStorage.create(), CV_STORAGE_WRITE); + ann.write(fsto, "ann"); + } + + public int annMain() { + System.out.println("To be begin."); + + saveTrainData(); + + // 可根据需要训练不同的predictSize或者neurons的ANN模型 + // for (int i = 2; i <= 2; i ++) + // { + // int size = i * 5; + // for (int j = 5; j <= 10; j++) + // { + // int neurons = j * 10; + // saveModel(size, neurons); + // } + // } + + // 这里演示只训练model文件夹下的ann.xml,此模型是一个predictSize=10,neurons=40的ANN模型。 + // 根据机器的不同,训练时间不一样,但一般需要10分钟左右,所以慢慢等一会吧。 + saveModel(10, 40); + + System.out.println("To be end."); + return 0; + }*/ +} diff --git a/src/test/java/com/yuxue/test/PlateDetectTrainTest.java b/src/test/java/com/yuxue/test/PlateDetectTrainTest.java index cc9a5b68..0866ce15 100644 --- a/src/test/java/com/yuxue/test/PlateDetectTrainTest.java +++ b/src/test/java/com/yuxue/test/PlateDetectTrainTest.java @@ -39,7 +39,6 @@ public class PlateDetectTrainTest { public static void main(String[] arg) { - // 正样本 // 136 × 36 像素 训练的源图像文件要相同大小 List imgList1 = FileUtil.listFile(new File(DEFAULT_PATH + "/learn/HasPlate"), Constant.DEFAULT_TYPE, false); @@ -84,14 +83,14 @@ public class PlateDetectTrainTest { // 失败案例:这里我试图用 get(row,col,data)方法获取数组,但是结果和这个结果不一样,原因未知。 float[] arr = new float[dst.rows() * dst.cols()]; int l = 0; - for (int j = 0; j < dst.rows(); j++) { - for (int k = 0; k < dst.cols(); k++) { + for (int j = 0; j < dst.rows(); j++) { // 遍历行 + for (int k = 0; k < dst.cols(); k++) { // 遍历列 double[] a = dst.get(j, k); arr[l] = (float) a[0]; l++; } } - trainingDataMat.put(i, 0, arr); + trainingDataMat.put(i, 0, arr); // 多张图合并到一张 } String module = DEFAULT_PATH + "svm.xml"; diff --git a/src/test/java/com/yuxue/test/PlatePridectTest.java b/src/test/java/com/yuxue/test/PlatePridectTest.java index 5e5bfd0f..d75632ae 100644 --- a/src/test/java/com/yuxue/test/PlatePridectTest.java +++ b/src/test/java/com/yuxue/test/PlatePridectTest.java @@ -1,5 +1,6 @@ package com.yuxue.test; + import org.opencv.core.Core; import org.opencv.core.CvType; import org.opencv.core.Mat; @@ -50,21 +51,23 @@ public class PlatePridectTest { Imgproc.cvtColor(src, src, Imgproc.COLOR_BGR2GRAY); Mat dst = new Mat(); Imgproc.Canny(src, dst, 130, 250); - - Mat samples = new Mat(1, dst.cols() * dst.rows(), CvType.CV_32FC1); - - // 转换 src 图像的 cvtype - // 失败案例:我试图用 dst.convertTo(src, CvType.CV_32FC1); 转换,但是失败了,原因未知。猜测: 内部的数据类型没有转换? - float[] dataArr = new float[dst.cols() * dst.rows()]; - for (int i = 0, f = 0; i < dst.rows(); i++) { - for (int j = 0; j < dst.cols(); j++) { - double pixel = dst.get(i, j)[0]; - dataArr[f] = (float) pixel; - f++; + + Mat samples = dst.reshape(1, 1); + samples.convertTo(samples, CvType.CV_32FC1); + // 等价于上面两行代码 + /*Mat samples = new Mat(1, dst.cols() * dst.rows(), CvType.CV_32FC1); + float[] arr = new float[dst.cols() * dst.rows()]; + int l = 0; + for (int j = 0; j < dst.rows(); j++) { // 遍历行 + for (int k = 0; k < dst.cols(); k++) { // 遍历列 + double[] a = dst.get(j, k); + arr[l] = (float) a[0]; + l++; } } + samples.put(0, 0, arr);*/ + Imgcodecs.imwrite(DEFAULT_PATH + "test_1.jpg", samples); - samples.put(0, 0, dataArr); // 如果训练时使用这个标识,那么符合的图像会返回9.0 float flag = svm.predict(samples); diff --git a/src/test/java/com/yuxue/test/SVMTrain.java b/src/test/java/com/yuxue/test/SVMTrain.java new file mode 100644 index 00000000..03e22f47 --- /dev/null +++ b/src/test/java/com/yuxue/test/SVMTrain.java @@ -0,0 +1,382 @@ +package com.yuxue.test; + + +import static org.bytedeco.javacpp.opencv_core.*; +import static org.bytedeco.javacpp.opencv_highgui.*; +import static org.bytedeco.javacpp.opencv_ml.*; + + +import java.util.*; + +/* + * Created by fanwenjie + * @version 1.1 + */ +public class SVMTrain { + + /*private SVMCallback callback = new Features(); + private static final String hasPlate = "HasPlate"; + private static final String noPlate = "NoPlate"; + + public SVMTrain(SVMCallback callback){ + this.callback = callback; + } + + public SVMTrain(){} + + + private void learn2Plate(float bound, final String name) { + final String filePath = "res/train/data/plate_detect_svm/learn/" + name; + Vector files = new Vector(); + ////获取该路径下的所有文件 + Util.getFiles(filePath, files); + int size = files.size(); + if (0 == size) { + System.out.println("File not found in " + filePath); + return; + } + Collections.shuffle(files, new Random(new Date().getTime())); + ////随机选取70%作为训练数据,30%作为测试数据 + int boundry = (int) (bound * size); + + Util.recreateDir("res/train/data/plate_detect_svm/train/" + name); + Util.recreateDir("res/train/data/plate_detect_svm/test/" + name); + + System.out.println("Save " + name + " train!"); + for (int i = 0; i < boundry; i++) { + System.out.println(files.get(i)); + Mat img = imread(files.get(i)); + String str = "res/train/data/plate_detect_svm/train/" + name + "/" + name + "_" + Integer.valueOf(i).toString() + ".jpg"; + imwrite(str, img); + } + + System.out.println("Save " + name + " test!"); + for (int i = boundry; i < size; i++) { + System.out.println(files.get(i)); + Mat img = imread(files.get(i)); + String str = "res/train/data/plate_detect_svm/test/" + name + "/" + name + "_" + Integer.valueOf(i).toString() + ".jpg"; + imwrite(str, img); + } + } + + private void getPlateTrain(Mat trainingImages, Vector trainingLabels, final String name) { + int label = 1; + final String filePath = "res/train/data/plate_detect_svm/train/" + name; + Vector files = new Vector(); + + ////获取该路径下的所有文件 + Util.getFiles(filePath, files); + + int size = files.size(); + if (0 == size) { + System.out.println("File not found in " + filePath); + return; + } + System.out.println("get " + name + " train!"); + for (int i = 0; i < size; i++) { + //System.out.println(files[i].c_str()).toString()); + Mat img = imread(files.get(i)); + + //调用回调函数决定特征 + Mat features = this.callback.getHisteqFeatures(img); + features = features.reshape(1, 1); + trainingImages.push_back(features); + trainingLabels.add(label); + } + } + + private void getPlateTest(MatVector testingImages,Vector testingLabels,final String name){ + int label = 1; + final String filePath = "res/train/data/plate_detect_svm/test/"+name; + Vector files = new Vector(); + Util.getFiles(filePath, files); + + int size = files.size(); + if (0 == size) { + System.out.println("File not found in " + filePath); + return; + } + System.out.println("get "+name+" test!"); + for (int i = 0; i < size; i++) + { + Mat img = imread(files.get(i)); + testingImages.put(img); + testingLabels.add(label); + } + } + + public void learn2HasPlate() { + learn2HasPlate(0.7f); + } + + public void learn2HasPlate(float bound) { + learn2Plate(bound, hasPlate); + } + + public void learn2NoPlate() { + learn2NoPlate(0.7f); + } + + public void learn2NoPlate(float bound) { + learn2Plate(bound, noPlate); + } + + + public void getNoPlateTrain(Mat trainingImages, Vector trainingLabels) { + getPlateTrain(trainingImages, trainingLabels, noPlate); + } + + public void getHasPlateTrain(Mat trainingImages, Vector trainingLabels) { + getPlateTrain(trainingImages, trainingLabels, hasPlate); + } + + + public void getHasPlateTest(MatVector testingImages,Vector testingLabels) + { + getPlateTest(testingImages,testingLabels,hasPlate); + } + + public void getNoPlateTest(MatVector testingImages,Vector testingLabels) + { + getPlateTest(testingImages,testingLabels,noPlate); + } + + + + //! 测试SVM的准确率,回归率以及FScore + public void getAccuracy(Mat testingclasses_preditc, Mat testingclasses_real) + { + int channels = testingclasses_preditc.channels(); + System.out.println("channels: "+Integer.valueOf(channels).toString()); + int nRows = testingclasses_preditc.rows(); + System.out.println("nRows: "+Integer.valueOf(nRows).toString()); + int nCols = testingclasses_preditc.cols() * channels; + System.out.println("nCols: "+Integer.valueOf(nCols).toString()); + int channels_real = testingclasses_real.channels(); + System.out.println("channels_real: "+Integer.valueOf(channels_real).toString()); + int nRows_real = testingclasses_real.rows(); + System.out.println("nRows_real: " + Integer.valueOf(nRows_real).toString()); + int nCols_real = testingclasses_real.cols() * channels; + System.out.println("nCols_real: "+Integer.valueOf(nCols_real).toString()); + + double count_all = 0; + double ptrue_rtrue = 0; + double ptrue_rfalse = 0; + double pfalse_rtrue = 0; + double pfalse_rfalse = 0; + + for (int i = 0; i < nRows; i++) + { + + final float predict = Convert.toFloat(testingclasses_preditc.ptr(i)); + final float real = Convert.toFloat(testingclasses_real.ptr(i)); + + count_all ++; + + //System.out.println("predict:" << predict).toString()); + //System.out.println("real:" << real).toString()); + + if (predict == 1.0 && real == 1.0) + ptrue_rtrue ++; + if (predict == 1.0 && real == 0) + ptrue_rfalse ++; + if (predict == 0 && real == 1.0) + pfalse_rtrue ++; + if (predict == 0 && real == 0) + pfalse_rfalse ++; + } + + System.out.println("count_all: "+Double.valueOf(count_all).toString()); + System.out.println("ptrue_rtrue: "+Double.valueOf(ptrue_rtrue).toString()); + System.out.println("ptrue_rfalse: "+Double.valueOf(ptrue_rfalse).toString()); + System.out.println("pfalse_rtrue: "+Double.valueOf(pfalse_rtrue).toString()); + System.out.println("pfalse_rfalse: "+Double.valueOf(pfalse_rfalse).toString()); + + double precise = 0; + if (ptrue_rtrue + ptrue_rfalse != 0) + { + precise = ptrue_rtrue/(ptrue_rtrue + ptrue_rfalse); + System.out.println("precise: "+Double.valueOf(precise).toString()); + } + else + { + System.out.println("precise: NA"); + } + + double recall = 0; + if (ptrue_rtrue + pfalse_rtrue != 0) + { + recall = ptrue_rtrue/(ptrue_rtrue + pfalse_rtrue); + System.out.println("recall: "+Double.valueOf(recall).toString()); + } + else + { + System.out.println("recall: NA"); + } + + if (precise + recall != 0) + { + double F = (precise * recall)/(precise + recall); + System.out.println("F: "+Double.valueOf(F).toString()); + } + else + { + System.out.println("F: NA"); + } + } + + + public int svmTrain(boolean dividePrepared, boolean trainPrepared) + { + + Mat classes = new Mat(); + Mat trainingData = new Mat(); + + Mat trainingImages = new Mat(); + Vector trainingLabels = new Vector(); + + + if (!dividePrepared) + { + //分割learn里的数据到train和test里 + System.out.println("Divide learn to train and test"); + learn2HasPlate(); + learn2NoPlate(); + } + + //将训练数据加载入内存 + if (!trainPrepared) + { + System.out.print("Begin to get train data to memory"); + getHasPlateTrain(trainingImages, trainingLabels); + getNoPlateTrain(trainingImages, trainingLabels); + + + trainingImages.copyTo(trainingData); + trainingData.convertTo(trainingData, CV_32FC1); + + int []labels = new int[trainingLabels.size()]; + for(int i=0;i testingLabels_real = new Vector(); + + //将测试数据加载入内存 + System.out.println("Begin to get test data to memory"); + getHasPlateTest(testingImages, testingLabels_real); + getNoPlateTest(testingImages, testingLabels_real); + + CvSVM svm = new CvSVM(); + if (!trainPrepared && !classes.empty() && !trainingData.empty()) + { + CvSVMParams SVM_params = new CvSVMParams(CvSVM.C_SVC,CvSVM.RBF,0.1,1,0.1,1,0.1,0.1, + new CvMat(),new CvTermCriteria().type(CV_TERMCRIT_ITER).max_iter(100000).epsilon(0.0001)); + + //Train SVM + System.out.println("Begin to generate svm"); + + try { + //CvSVM svm(trainingData, classes, Mat(), Mat(), SVM_params); + svm.train_auto(trainingData, classes, new Mat(), new Mat(), SVM_params, 10, + CvSVM.get_default_grid(CvSVM.C), + CvSVM.get_default_grid(CvSVM.GAMMA), + CvSVM.get_default_grid(CvSVM.P), + CvSVM.get_default_grid(CvSVM.NU), + CvSVM.get_default_grid(CvSVM.COEF), + CvSVM.get_default_grid(CvSVM.DEGREE), + true); + } catch (Exception err) { + System.out.println(err.getMessage()); + } + + System.out.println("Svm generate done!"); + + CvFileStorage fsTo = CvFileStorage.open("res/rain/svm.xml", CvMemStorage.create(),CV_STORAGE_WRITE); + svm.write(fsTo, "svm"); + } + else + { + try { + String path = "res/train/svm.xml"; + svm.load(path, "svm"); + } catch (Exception err) { + System.out.println(err.getMessage()); + return 0; //next predict requires svm + } + } + + System.out.println("Begin to predict"); + + double count_all = 0; + double ptrue_rtrue = 0; + double ptrue_rfalse = 0; + double pfalse_rtrue = 0; + double pfalse_rfalse = 0; + + int size = (int)testingImages.size(); + for (int i = 0; i < size; i++) + { + //System.out.println(files[i].c_str()); + Mat p = testingImages.get(i); + + //调用回调函数决定特征 + Mat features = callback.getHistogramFeatures(p); + features = features.reshape(1, 1); + features.convertTo(features, CV_32FC1); + + int predict = (int)svm.predict(features); + int real = testingLabels_real.get(i); + + if (predict == 1 && real == 1) + ptrue_rtrue ++; + if (predict == 1 && real == 0) + ptrue_rfalse ++; + if (predict == 0 && real == 1) + pfalse_rtrue ++; + if (predict == 0 && real == 0) + pfalse_rfalse ++; + } + + count_all = size; + + System.out.println("Get the Accuracy!"); + + System.out.println("count_all: "+Double.valueOf(count_all).toString()); + System.out.println("ptrue_rtrue: "+Double.valueOf(ptrue_rtrue).toString()); + System.out.println("ptrue_rfalse: "+Double.valueOf(ptrue_rfalse).toString()); + System.out.println("pfalse_rtrue: "+Double.valueOf(pfalse_rtrue).toString()); + System.out.println("pfalse_rfalse: "+Double.valueOf(pfalse_rfalse).toString()); + + double precise = 0; + if (ptrue_rtrue + ptrue_rfalse != 0) + { + precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse); + System.out.println("precise: "+Double.valueOf(precise).toString()); + } + else + System.out.println("precise: NA"); + + double recall = 0; + if (ptrue_rtrue + pfalse_rtrue != 0) + { + recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue); + System.out.println("recall: "+Double.valueOf(recall).toString()); + } + else + System.out.println("recall: NA"); + + double Fsocre = 0; + if (precise + recall != 0) + { + Fsocre = 2 * (precise * recall) / (precise + recall); + System.out.println("Fsocre: "+Double.valueOf(Fsocre).toString()); + } + else + System.out.println("Fsocre: NA"); + return 0; + }*/ +} diff --git a/src/test/java/com/yuxue/test/trainsvm.java b/src/test/java/com/yuxue/test/trainsvm.java index 873d8083..91352816 100644 --- a/src/test/java/com/yuxue/test/trainsvm.java +++ b/src/test/java/com/yuxue/test/trainsvm.java @@ -37,13 +37,13 @@ public class trainsvm { openFile(1, DEFAULT_PATH + "/learn/HasPlate"); openFile(0, DEFAULT_PATH + "/learn/NoPlate"); Mat srcImgs = new Mat(); - Mat flags = new Mat(trainingLabels.size(), 1, CvType.CV_32SC1); + Mat labelsMat = new Mat(trainingLabels.size(), 1, CvType.CV_32SC1); Core.vconcat(trainingImages, srcImgs); // 样本数量不能太大,trainingImages.size有限制 for (int i = 0; i < trainingLabels.size(); i++) { int[] val = { trainingLabels.get(i) }; - flags.put(i, 0, val); + labelsMat.put(i, 0, val); } SVM svm = SVM.create(); svm.setKernel(SVM.LINEAR); @@ -54,7 +54,7 @@ public class trainsvm { svm.setNu(0); svm.setP(0); svm.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER, 20000, 0.0001)); - TrainData trainData = TrainData.create(srcImgs, Ml.ROW_SAMPLE, flags); + TrainData trainData = TrainData.create(srcImgs, Ml.ROW_SAMPLE, labelsMat); boolean success = svm.train(trainData); System.out.println(success); svm.save( DEFAULT_PATH + "svm.xml");