diff --git a/src/main/java/com/yuxue/train/ANNTrain1.java b/src/main/java/com/yuxue/train/ANNTrain1.java index 88bf3abb..0a0d8ce8 100644 --- a/src/main/java/com/yuxue/train/ANNTrain1.java +++ b/src/main/java/com/yuxue/train/ANNTrain1.java @@ -5,13 +5,11 @@ import java.util.Vector; import static org.bytedeco.javacpp.opencv_core.*; import static org.bytedeco.javacpp.opencv_ml.*; -import org.bytedeco.javacpp.opencv_core; import org.bytedeco.javacpp.opencv_imgcodecs; import org.bytedeco.javacpp.opencv_core.Mat; import com.yuxue.constant.Constant; import com.yuxue.easypr.core.CoreFunc; -import com.yuxue.util.Convert; import com.yuxue.util.FileUtil; /** @@ -150,7 +148,7 @@ public class ANNTrain1 { // 这里演示只训练model文件夹下的ann.xml,此模型是一个predictSize=10,neurons=40的ANN模型 // 可根据需要训练不同的predictSize或者neurons的ANN模型 // 根据机器的不同,训练时间不一样,但一般需要10分钟左右,所以慢慢等一会吧。 - // annT.train(Constant.predictSize, Constant.neurons); + annT.train(Constant.predictSize, Constant.neurons); annT.predict(); diff --git a/src/main/java/com/yuxue/train/PlateRecoTrain.java b/src/main/java/com/yuxue/train/PlateRecoTrain.java deleted file mode 100644 index 969aa4a4..00000000 --- a/src/main/java/com/yuxue/train/PlateRecoTrain.java +++ /dev/null @@ -1,212 +0,0 @@ -package com.yuxue.train; - -import java.io.File; -import java.util.List; - -import org.opencv.core.Core; -import org.opencv.core.CvType; -import org.opencv.core.Mat; -import org.opencv.core.TermCriteria; -import org.opencv.imgcodecs.Imgcodecs; -import org.opencv.imgproc.Imgproc; -import org.opencv.ml.Ml; -import org.opencv.ml.SVM; -import org.opencv.ml.TrainData; - -import com.yuxue.constant.Constant; -import com.yuxue.util.FileUtil; - -/** - * 基于org.opencv官方包实现的训练 - * - * - * windows下环境配置: - * 1、官网下载对应版本的openvp:https://opencv.org/releases/page/2/ 当前使用4.0.1版本 - * 2、双击exe文件安装,将 安装目录下\build\java\x64\opencv_java401.dll 拷贝到\build\x64\vc14\bin\目录下 - * 3、eclipse添加User Libraries - * 4、项目右键build path,添加步骤三新增的lib - * - * 图片识别车牌训练 - * 训练出来的库文件,用于判断切图是否包含车牌 - * - * 训练的svm.xml应用: - * 1、替换res/model/svm.xml文件 - * 2、修改com.yuxue.easypr.core.PlateJudge.plateJudge(Mat) 方法 - * 将样本处理方法切换一下,即将对应被注释掉的模块代码取消注释 - * @author yuxue - * @date 2020-05-13 10:10 - */ -public class PlateRecoTrain { - - // 默认的训练操作的根目录 - private static final String DEFAULT_PATH = "D:/PlateDetect/train/plate_detect_svm/"; - - // 训练模型文件保存位置 - private static final String MODEL_PATH = DEFAULT_PATH + "svm.xml"; - - static { - System.loadLibrary(Core.NATIVE_LIBRARY_NAME); - } - - public static void main(String[] arg) { - // 训练, 生成svm.xml库文件 - // train(); - - // 识别,判断样本文件是否是车牌 - pridect(); - } - - - public static void train() { - - // 正样本 // 136 × 36 像素 训练的源图像文件要相同大小 - List imgList0 = FileUtil.listFile(new File(DEFAULT_PATH + "/learn/HasPlate"), Constant.DEFAULT_TYPE, false); - - // 负样本 // 136 × 36 像素 训练的源图像文件要相同大小 - List imgList1 = FileUtil.listFile(new File(DEFAULT_PATH + "/learn/NoPlate"), Constant.DEFAULT_TYPE, false); - - // 标记:正样本用 0 表示,负样本用 1 表示。 - int labels[] = createLabelArray(imgList0.size(), imgList1.size()); - int sample_num = labels.length; // 图片数量 - - // 用于存放所有样本的矩阵 - Mat trainingDataMat = null; - - // 存放标记的Mat,每个图片都要给一个标记 - Mat labelsMat = new Mat(sample_num, 1, CvType.CV_32SC1); - labelsMat.put(0, 0, labels); - - for (int i = 0; i < sample_num; i++) { // 遍历所有的正负样本,处理样本用于生成训练的库文件 - String path = ""; - if(i < imgList0.size()) { - path = imgList0.get(i).getAbsolutePath(); - } else { - path = imgList1.get(i - imgList0.size()).getAbsolutePath(); - } - - Mat inMat = Imgcodecs.imread(path); // 读取样本文件 - - // 创建一个行数为sample_num, 列数为 rows*cols 的矩阵; 用于存放样本 - if (trainingDataMat == null) { - trainingDataMat = new Mat(sample_num, inMat.rows() * inMat.cols(), CvType.CV_32F); - } - - // 样本文件处理,这里是为了过滤不需要的特征,减少训练时间 // 根据实际情况需要进行处理 - Mat greyMat = new Mat(); - Imgproc.cvtColor(inMat, greyMat, Imgproc.COLOR_BGR2GRAY); // 转成灰度图 - - Mat dst = new Mat(inMat.rows(), inMat.cols(), inMat.type()); - Imgproc.Canny(greyMat, dst, 130, 250); // 边缘检测 - - // 将样本矩阵转换成只有一行的矩阵,保存为float数组 - float[] arr = new float[dst.rows() * dst.cols()]; - int l = 0; - for (int j = 0; j < dst.rows(); j++) { // 遍历行 - for (int k = 0; k < dst.cols(); k++) { // 遍历列 - double[] a = dst.get(j, k); - arr[l] = (float) a[0]; - l++; - } - } - - trainingDataMat.put(i, 0, arr); // 多张图合并到一张 - } - - // Imgcodecs.imwrite(DEFAULT_PATH + "trainingDataMat.jpg", trainingDataMat); - - // 配置SVM训练器参数 - TermCriteria criteria = new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 20000, 0.0001); - SVM svm = SVM.create(); - svm.setTermCriteria(criteria); // 指定 - svm.setKernel(SVM.RBF); // 使用预先定义的内核初始化 - svm.setType(SVM.C_SVC); // SVM的类型,默认是:SVM.C_SVC - svm.setGamma(0.1); // 核函数的参数 - svm.setNu(0.1); // SVM优化问题参数 - svm.setC(1); // SVM优化问题的参数C - svm.setP(0.1); - svm.setDegree(0.1); - svm.setCoef0(0.1); - - TrainData td = TrainData.create(trainingDataMat, Ml.ROW_SAMPLE, labelsMat);// 类封装的训练数据 - boolean success = svm.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());// 训练统计模型 - System.out.println("svm training result: " + success); - svm.save(MODEL_PATH);// 保存模型 - } - - - - public static void pridect() { - // 加载训练得到的 xml 模型文件 - SVM svm = SVM.load(MODEL_PATH); - - // 136 × 36 像素 需要跟训练的源图像文件保持相同大小 - doPridect(svm, DEFAULT_PATH + "test/A01_NMV802_0.jpg"); - doPridect(svm, DEFAULT_PATH + "test/debug_resize_1.jpg"); - doPridect(svm, DEFAULT_PATH + "test/debug_resize_2.jpg"); - doPridect(svm, DEFAULT_PATH + "test/debug_resize_3.jpg"); - doPridect(svm, DEFAULT_PATH + "test/S22_KG2187_3.jpg"); - doPridect(svm, DEFAULT_PATH + "test/S22_KG2187_5.jpg"); - doPridect(svm, DEFAULT_PATH + "test/result_0.png"); - doPridect(svm, DEFAULT_PATH + "test/result_1.png"); - doPridect(svm, DEFAULT_PATH + "test/result_2.png"); - doPridect(svm, DEFAULT_PATH + "test/result_3.png"); - doPridect(svm, DEFAULT_PATH + "test/result_4.png"); - doPridect(svm, DEFAULT_PATH + "test/result_5.png"); - doPridect(svm, DEFAULT_PATH + "test/result_6.png"); - doPridect(svm, DEFAULT_PATH + "test/result_7.png"); - doPridect(svm, DEFAULT_PATH + "test/result_8.png"); - - } - - public static void doPridect(SVM svm, String imgPath) { - - Mat src = Imgcodecs.imread(imgPath);// 图片大小要和样本一致 - Imgproc.cvtColor(src, src, Imgproc.COLOR_BGR2GRAY); - Mat dst = new Mat(); - Imgproc.Canny(src, dst, 130, 250); - - Mat samples = dst.reshape(1, 1); - samples.convertTo(samples, CvType.CV_32F); - - // 等价于上面两行代码 - /*Mat samples = new Mat(1, dst.cols() * dst.rows(), CvType.CV_32F); - float[] arr = new float[dst.cols() * dst.rows()]; - int l = 0; - for (int j = 0; j < dst.rows(); j++) { // 遍历行 - for (int k = 0; k < dst.cols(); k++) { // 遍历列 - double[] a = dst.get(j, k); - arr[l] = (float) a[0]; - l++; - } - } - samples.put(0, 0, arr);*/ - - // Imgcodecs.imwrite(DEFAULT_PATH + "test_1.jpg", samples); - - // 如果训练时使用这个标识,那么符合的图像会返回9.0 - float flag = svm.predict(samples); - - System.err.println(flag); - - if (flag == 0) { - System.err.println(imgPath + ": 目标符合"); - } - if (flag == 1) { - System.out.println(imgPath + ": 目标不符合"); - } - } - - public static int[] createLabelArray(Integer i1, Integer i2) { - int labels[] = new int[i1 + i2]; - - for (int i = 0; i < labels.length; i++) { - if(i < i1) { - labels[i] = 0; - } else { - labels[i] = 1; - } - } - return labels; - } - -} diff --git a/src/main/java/com/yuxue/train/SVMTrain.java b/src/main/java/com/yuxue/train/SVMTrain.java index 8f6b5e17..abdc66e0 100644 --- a/src/main/java/com/yuxue/train/SVMTrain.java +++ b/src/main/java/com/yuxue/train/SVMTrain.java @@ -1,382 +1,212 @@ -package com.yuxue.train; - -import java.util.*; - -import static org.bytedeco.javacpp.opencv_core.*; -import static org.bytedeco.javacpp.opencv_ml.*; - -import org.bytedeco.javacpp.opencv_core; -import org.bytedeco.javacpp.opencv_imgcodecs; - -import com.yuxue.easypr.core.Features; -import com.yuxue.easypr.core.SVMCallback; -import com.yuxue.util.Convert; -import com.yuxue.util.FileUtil; - -/** - * 基于org.bytedeco.javacpp包实现的训练 - * JavaCPP 是一个开源库,它提供了在 Java 中高效访问本地 C++的方法 - * - * 图片识别车牌训练 - * 训练出来的库文件,用于判断切图是否包含车牌 - * - * 训练的svm.xml应用: - * 1、替换res/model/svm.xml文件 - * 2、修改com.yuxue.easypr.core.PlateJudge.plateJudge(Mat) 方法 - * 将样本处理方法切换一下,即将对应被注释掉的模块代码取消注释 - * @author yuxue - * @date 2020-05-14 22:16 - */ -public class SVMTrain { - - private SVMCallback callback = new Features(); - - // 默认的训练操作的根目录 - private static final String DEFAULT_PATH = "D:/PlateDetect/train/plate_detect_svm/"; - - // 训练模型文件保存位置 - private static final String MODEL_PATH = DEFAULT_PATH + "svm.xml"; - - private static final String hasPlate = "HasPlate"; - private static final String noPlate = "NoPlate"; - - public SVMTrain() { - } - - public SVMTrain(SVMCallback callback) { - this.callback = callback; - } - - /** - * 将learn文件夹下的图片,转存到tain test文件夹下,区分hasPalte noPlate - * 随机选取bound%作为训练数据,30%作为测试数据 - * @param bound - * @param name - */ - private void learn2Plate(float bound, final String name) { - final String filePath = DEFAULT_PATH + "learn/" + name; - Vector files = new Vector(); - - //// 获取该路径下的所有文件 - FileUtil.getFiles(filePath, files); - int size = files.size(); - if (0 == size) { - System.err.println("当前目录下没有文件: " + filePath); - return; - } - Collections.shuffle(files, new Random(new Date().getTime())); - - //// 随机选取70%作为训练数据,30%作为测试数据 - int boundry = (int) (bound * size); - - // 重新创建目录 - FileUtil.recreateDir(DEFAULT_PATH + "train/" + name); - FileUtil.recreateDir(DEFAULT_PATH + "test/" + name); - - for (int i = 0; i < boundry; i++) { - Mat img = opencv_imgcodecs.imread(files.get(i)); - String str = DEFAULT_PATH + "train/" + name + "/" + name + "_" + Integer.valueOf(i).toString() + ".jpg"; - opencv_imgcodecs.imwrite(str, img); - } - - for (int i = boundry; i < size; i++) { - Mat img = opencv_imgcodecs.imread(files.get(i)); - String str = DEFAULT_PATH + "test/" + name + "/" + name + "_" + Integer.valueOf(i).toString() + ".jpg"; - opencv_imgcodecs.imwrite(str, img); - } - } - - /** - * 获取训练图片 - * @param trainingImages - * @param trainingLabels - * @param name - */ - private void getPlateTrain(Mat trainingImages, Vector trainingLabels, final String name, int label) { - // int label = 1; - final String filePath = DEFAULT_PATH + "train/" + name; - Vector files = new Vector(); - - // 获取该路径下的所有文件 - FileUtil.getFiles(filePath, files); - - int size = files.size(); - if (null == files || size <= 0) { - System.out.println("File not found in " + filePath); - return; - } - for (int i = 0; i < size; i++) { - // System.out.println(files.get(i)); - Mat inMat = opencv_imgcodecs.imread(files.get(i)); - // 调用回调函数决定特征 - // Mat features = this.callback.getHisteqFeatures(inMat); - Mat features = this.callback.getHistogramFeatures(inMat); - // 通过直方图均衡化后的彩色图进行预测 - Mat p = features.reshape(1, 1); - p.convertTo(p, opencv_core.CV_32F); - - // 136 36 14688 1 变换尺寸 - // System.err.println(inMat.cols() + "\t" + inMat.rows() + "\t" + p.cols() + "\t" + p.rows()); - - trainingImages.push_back(p); // 合并成一张图片 - trainingLabels.add(label); - } - } - - private void getPlateTest(MatVector testingImages, Vector testingLabels, final String name, int label) { - // int label = 1; - final String filePath = DEFAULT_PATH + "test/" + name; - Vector files = new Vector(); - FileUtil.getFiles(filePath, files); - - int size = files.size(); - if (0 == size) { - System.out.println("File not found in " + filePath); - return; - } - System.out.println("get " + name + " test!"); - for (int i = 0; i < size; i++) { - Mat inMat = opencv_imgcodecs.imread(files.get(i)); - testingImages.push_back(inMat); - testingLabels.add(label); - } - } - - // ! 测试SVM的准确率,回归率以及FScore - public void getAccuracy(Mat testingclasses_preditc, Mat testingclasses_real) { - int channels = testingclasses_preditc.channels(); - System.out.println("channels: " + Integer.valueOf(channels).toString()); - int nRows = testingclasses_preditc.rows(); - System.out.println("nRows: " + Integer.valueOf(nRows).toString()); - int nCols = testingclasses_preditc.cols() * channels; - System.out.println("nCols: " + Integer.valueOf(nCols).toString()); - int channels_real = testingclasses_real.channels(); - System.out.println("channels_real: " + Integer.valueOf(channels_real).toString()); - int nRows_real = testingclasses_real.rows(); - System.out.println("nRows_real: " + Integer.valueOf(nRows_real).toString()); - int nCols_real = testingclasses_real.cols() * channels; - System.out.println("nCols_real: " + Integer.valueOf(nCols_real).toString()); - - double count_all = 0; - double ptrue_rtrue = 0; - double ptrue_rfalse = 0; - double pfalse_rtrue = 0; - double pfalse_rfalse = 0; - - for (int i = 0; i < nRows; i++) { - - final float predict = Convert.toFloat(testingclasses_preditc.ptr(i)); - final float real = Convert.toFloat(testingclasses_real.ptr(i)); - - count_all++; - - // System.out.println("predict:" << predict).toString()); - // System.out.println("real:" << real).toString()); - - if (predict == 1.0 && real == 1.0) - ptrue_rtrue++; - if (predict == 1.0 && real == 0) - ptrue_rfalse++; - if (predict == 0 && real == 1.0) - pfalse_rtrue++; - if (predict == 0 && real == 0) - pfalse_rfalse++; - } - - System.out.println("count_all: " + Double.valueOf(count_all).toString()); - System.out.println("ptrue_rtrue: " + Double.valueOf(ptrue_rtrue).toString()); - System.out.println("ptrue_rfalse: " + Double.valueOf(ptrue_rfalse).toString()); - System.out.println("pfalse_rtrue: " + Double.valueOf(pfalse_rtrue).toString()); - System.out.println("pfalse_rfalse: " + Double.valueOf(pfalse_rfalse).toString()); - - double precise = 0; - if (ptrue_rtrue + ptrue_rfalse != 0) { - precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse); - System.out.println("precise: " + Double.valueOf(precise).toString()); - } else { - System.out.println("precise: NA"); - } - - double recall = 0; - if (ptrue_rtrue + pfalse_rtrue != 0) { - recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue); - System.out.println("recall: " + Double.valueOf(recall).toString()); - } else { - System.out.println("recall: NA"); - } - - if (precise + recall != 0) { - double F = (precise * recall) / (precise + recall); - System.out.println("F: " + Double.valueOf(F).toString()); - } else { - System.out.println("F: NA"); - } - } - - /** - * 训练 - * @param dividePrepared - * @return - */ - public int svmTrain(boolean dividePrepared) { - - Mat classes = new Mat(); - Mat trainingData = new Mat(); - Mat trainingImages = new Mat(); - Vector trainingLabels = new Vector(); - - // 分割learn里的数据到train和test里 // 从库里面选取训练样本 - if (!dividePrepared) { - learn2Plate(0.1f, hasPlate); // 性能不好的机器,最好不要挑选太多的样本,这个方案太消耗资源了。 - learn2Plate(0.1f, noPlate); - } - - // System.err.println("Begin to get train data to memory"); - - getPlateTrain(trainingImages, trainingLabels, hasPlate, 0); - getPlateTrain(trainingImages, trainingLabels, noPlate, 1); - - // System.err.println(trainingImages.cols()); - - trainingImages.copyTo(trainingData); - trainingData.convertTo(trainingData, CV_32F); - - int[] labels = new int[trainingLabels.size()]; - for (int i = 0; i < trainingLabels.size(); ++i) { - labels[i] = trainingLabels.get(i).intValue(); - } - new Mat(labels).copyTo(classes); - - TrainData train_data = TrainData.create(trainingData, ROW_SAMPLE, classes); - - SVM svm = SVM.create(); - - try { - TermCriteria criteria = new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 20000, 0.0001); - svm.setTermCriteria(criteria); // 指定 - svm.setKernel(SVM.RBF); // 使用预先定义的内核初始化 - svm.setType(SVM.C_SVC); // SVM的类型,默认是:SVM.C_SVC - svm.setGamma(0.1); // 核函数的参数 - svm.setNu(0.1); // SVM优化问题参数 - svm.setC(1); // SVM优化问题的参数C - svm.setP(0.1); - svm.setDegree(0.1); - svm.setCoef0(0.1); - - svm.trainAuto(train_data, 10, - SVM.getDefaultGrid(SVM.C), - SVM.getDefaultGrid(SVM.GAMMA), - SVM.getDefaultGrid(SVM.P), - SVM.getDefaultGrid(SVM.NU), - SVM.getDefaultGrid(SVM.COEF), - SVM.getDefaultGrid(SVM.DEGREE), - true); - - } catch (Exception err) { - System.out.println(err.getMessage()); - } - - System.out.println("Svm generate done!"); - - /*FileStorage fsTo = new FileStorage(MODEL_PATH, FileStorage.WRITE); - svm.write(fsTo, "svm");*/ - svm.save(MODEL_PATH); - return 0; - } - - // 测试 - public int svmPredict() { - SVM svm = SVM.create(); - try { - svm.clear(); - // svm = SVM.loadSVM(MODEL_PATH, "svm"); - svm = SVM.load(MODEL_PATH); - } catch (Exception err) { - System.err.println(err.getMessage()); - return 0; // next predict requires svm - } - - System.out.println("Begin to predict"); - // Test SVM - MatVector testingImages = new MatVector(); - Vector testingLabels_real = new Vector(); - - // 将测试数据加载入内存 - getPlateTest(testingImages, testingLabels_real, hasPlate, 0); - getPlateTest(testingImages, testingLabels_real, noPlate, 1); - - double count_all = 0; - double ptrue_rtrue = 0; - double ptrue_rfalse = 0; - double pfalse_rtrue = 0; - double pfalse_rfalse = 0; - - long size = testingImages.size(); - System.err.println(size); - - for (int i = 0; i < size; i++) { - Mat inMat = testingImages.get(i); - - // Mat features = callback.getHisteqFeatures(inMat); - Mat features = callback.getHistogramFeatures(inMat); - Mat p = features.reshape(1, 1); - p.convertTo(p, opencv_core.CV_32F); - - // System.out.println(p.cols() + "\t" + p.rows() + "\t" + p.type()); - - // samples.cols == var_count && samples.type() == CV_32F - // var_count 的值会在svm.xml库文件中有体现 - float predoct = svm.predict(features); - - int predict = (int) predoct; // 预期值 - int real = testingLabels_real.get(i); // 实际值 - - if (predict == 1 && real == 1) - ptrue_rtrue++; - if (predict == 1 && real == 0) - ptrue_rfalse++; - if (predict == 0 && real == 1) - pfalse_rtrue++; - if (predict == 0 && real == 0) - pfalse_rfalse++; - } - - count_all = size; - System.out.println("Get the Accuracy!"); - - System.out.println("count_all: " + Double.valueOf(count_all).toString()); - System.out.println("ptrue_rtrue: " + Double.valueOf(ptrue_rtrue).toString()); - System.out.println("ptrue_rfalse: " + Double.valueOf(ptrue_rfalse).toString()); - System.out.println("pfalse_rtrue: " + Double.valueOf(pfalse_rtrue).toString()); - System.out.println("pfalse_rfalse: " + Double.valueOf(pfalse_rfalse).toString()); - - double precise = 0; - if (ptrue_rtrue + ptrue_rfalse != 0) { - precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse); - System.out.println("precise: " + Double.valueOf(precise).toString()); - } else - System.out.println("precise: NA"); - - double recall = 0; - if (ptrue_rtrue + pfalse_rtrue != 0) { - recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue); - System.out.println("recall: " + Double.valueOf(recall).toString()); - } else - System.out.println("recall: NA"); - - double Fsocre = 0; - if (precise + recall != 0) { - Fsocre = 2 * (precise * recall) / (precise + recall); - System.out.println("Fsocre: " + Double.valueOf(Fsocre).toString()); - } else - System.out.println("Fsocre: NA"); - return 0; - } - - public static void main(String[] args) { - SVMTrain s = new SVMTrain(); - s.svmTrain(true); - s.svmPredict(); - } - -} +package com.yuxue.train; + +import java.io.File; +import java.util.List; + +import org.opencv.core.Core; +import org.opencv.core.CvType; +import org.opencv.core.Mat; +import org.opencv.core.TermCriteria; +import org.opencv.imgcodecs.Imgcodecs; +import org.opencv.imgproc.Imgproc; +import org.opencv.ml.Ml; +import org.opencv.ml.SVM; +import org.opencv.ml.TrainData; + +import com.yuxue.constant.Constant; +import com.yuxue.util.FileUtil; + +/** + * 基于org.opencv官方包实现的训练 + * + * + * windows下环境配置: + * 1、官网下载对应版本的openvp:https://opencv.org/releases/page/2/ 当前使用4.0.1版本 + * 2、双击exe文件安装,将 安装目录下\build\java\x64\opencv_java401.dll 拷贝到\build\x64\vc14\bin\目录下 + * 3、eclipse添加User Libraries + * 4、项目右键build path,添加步骤三新增的lib + * + * 图片识别车牌训练 + * 训练出来的库文件,用于判断切图是否包含车牌 + * + * 训练的svm.xml应用: + * 1、替换res/model/svm.xml文件 + * 2、修改com.yuxue.easypr.core.PlateJudge.plateJudge(Mat) 方法 + * 将样本处理方法切换一下,即将对应被注释掉的模块代码取消注释 + * @author yuxue + * @date 2020-05-13 10:10 + */ +public class SVMTrain { + + // 默认的训练操作的根目录 + private static final String DEFAULT_PATH = "D:/PlateDetect/train/plate_detect_svm/"; + + // 训练模型文件保存位置 + private static final String MODEL_PATH = DEFAULT_PATH + "svm.xml"; + + static { + System.loadLibrary(Core.NATIVE_LIBRARY_NAME); + } + + public static void main(String[] arg) { + // 训练, 生成svm.xml库文件 + train(); + + // 识别,判断样本文件是否是车牌 + pridect(); + } + + + public static void train() { + + // 正样本 // 136 × 36 像素 训练的源图像文件要相同大小 + List imgList0 = FileUtil.listFile(new File(DEFAULT_PATH + "/learn/HasPlate"), Constant.DEFAULT_TYPE, false); + + // 负样本 // 136 × 36 像素 训练的源图像文件要相同大小 + List imgList1 = FileUtil.listFile(new File(DEFAULT_PATH + "/learn/NoPlate"), Constant.DEFAULT_TYPE, false); + + // 标记:正样本用 0 表示,负样本用 1 表示。 + int labels[] = createLabelArray(imgList0.size(), imgList1.size()); + int sample_num = labels.length; // 图片数量 + + // 用于存放所有样本的矩阵 + Mat trainingDataMat = null; + + // 存放标记的Mat,每个图片都要给一个标记 + Mat labelsMat = new Mat(sample_num, 1, CvType.CV_32SC1); + labelsMat.put(0, 0, labels); + + for (int i = 0; i < sample_num; i++) { // 遍历所有的正负样本,处理样本用于生成训练的库文件 + String path = ""; + if(i < imgList0.size()) { + path = imgList0.get(i).getAbsolutePath(); + } else { + path = imgList1.get(i - imgList0.size()).getAbsolutePath(); + } + + Mat inMat = Imgcodecs.imread(path); // 读取样本文件 + + // 创建一个行数为sample_num, 列数为 rows*cols 的矩阵; 用于存放样本 + if (trainingDataMat == null) { + trainingDataMat = new Mat(sample_num, inMat.rows() * inMat.cols(), CvType.CV_32F); + } + + // 样本文件处理,这里是为了过滤不需要的特征,减少训练时间 // 根据实际情况需要进行处理 + Mat greyMat = new Mat(); + Imgproc.cvtColor(inMat, greyMat, Imgproc.COLOR_BGR2GRAY); // 转成灰度图 + + Mat dst = new Mat(inMat.rows(), inMat.cols(), inMat.type()); + Imgproc.Canny(greyMat, dst, 130, 250); // 边缘检测 + + // 将样本矩阵转换成只有一行的矩阵,保存为float数组 + float[] arr = new float[dst.rows() * dst.cols()]; + int l = 0; + for (int j = 0; j < dst.rows(); j++) { // 遍历行 + for (int k = 0; k < dst.cols(); k++) { // 遍历列 + double[] a = dst.get(j, k); + arr[l] = (float) a[0]; + l++; + } + } + + trainingDataMat.put(i, 0, arr); // 多张图合并到一张 + } + + // Imgcodecs.imwrite(DEFAULT_PATH + "trainingDataMat.jpg", trainingDataMat); + + // 配置SVM训练器参数 + TermCriteria criteria = new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 20000, 0.0001); + SVM svm = SVM.create(); + svm.setTermCriteria(criteria); // 指定 + svm.setKernel(SVM.RBF); // 使用预先定义的内核初始化 + svm.setType(SVM.C_SVC); // SVM的类型,默认是:SVM.C_SVC + svm.setGamma(0.1); // 核函数的参数 + svm.setNu(0.1); // SVM优化问题参数 + svm.setC(1); // SVM优化问题的参数C + svm.setP(0.1); + svm.setDegree(0.1); + svm.setCoef0(0.1); + + TrainData td = TrainData.create(trainingDataMat, Ml.ROW_SAMPLE, labelsMat);// 类封装的训练数据 + boolean success = svm.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());// 训练统计模型 + System.out.println("svm training result: " + success); + svm.save(MODEL_PATH);// 保存模型 + } + + + + public static void pridect() { + // 加载训练得到的 xml 模型文件 + SVM svm = SVM.load(MODEL_PATH); + + // 136 × 36 像素 需要跟训练的源图像文件保持相同大小 + doPridect(svm, DEFAULT_PATH + "test/A01_NMV802_0.jpg"); + doPridect(svm, DEFAULT_PATH + "test/debug_resize_1.jpg"); + doPridect(svm, DEFAULT_PATH + "test/debug_resize_2.jpg"); + doPridect(svm, DEFAULT_PATH + "test/debug_resize_3.jpg"); + doPridect(svm, DEFAULT_PATH + "test/S22_KG2187_3.jpg"); + doPridect(svm, DEFAULT_PATH + "test/S22_KG2187_5.jpg"); + doPridect(svm, DEFAULT_PATH + "test/result_0.png"); + doPridect(svm, DEFAULT_PATH + "test/result_1.png"); + doPridect(svm, DEFAULT_PATH + "test/result_2.png"); + doPridect(svm, DEFAULT_PATH + "test/result_3.png"); + doPridect(svm, DEFAULT_PATH + "test/result_4.png"); + doPridect(svm, DEFAULT_PATH + "test/result_5.png"); + doPridect(svm, DEFAULT_PATH + "test/result_6.png"); + doPridect(svm, DEFAULT_PATH + "test/result_7.png"); + doPridect(svm, DEFAULT_PATH + "test/result_8.png"); + + } + + public static void doPridect(SVM svm, String imgPath) { + + Mat src = Imgcodecs.imread(imgPath);// 图片大小要和样本一致 + Imgproc.cvtColor(src, src, Imgproc.COLOR_BGR2GRAY); + Mat dst = new Mat(); + Imgproc.Canny(src, dst, 130, 250); + + Mat samples = dst.reshape(1, 1); + samples.convertTo(samples, CvType.CV_32F); + + // 等价于上面两行代码 + /*Mat samples = new Mat(1, dst.cols() * dst.rows(), CvType.CV_32F); + float[] arr = new float[dst.cols() * dst.rows()]; + int l = 0; + for (int j = 0; j < dst.rows(); j++) { // 遍历行 + for (int k = 0; k < dst.cols(); k++) { // 遍历列 + double[] a = dst.get(j, k); + arr[l] = (float) a[0]; + l++; + } + } + samples.put(0, 0, arr);*/ + + // Imgcodecs.imwrite(DEFAULT_PATH + "test_1.jpg", samples); + + // 如果训练时使用这个标识,那么符合的图像会返回9.0 + float flag = svm.predict(samples); + + System.err.println(flag); + + if (flag == 0) { + System.err.println(imgPath + ": 目标符合"); + } + if (flag == 1) { + System.out.println(imgPath + ": 目标不符合"); + } + } + + public static int[] createLabelArray(Integer i1, Integer i2) { + int labels[] = new int[i1 + i2]; + + for (int i = 0; i < labels.length; i++) { + if(i < i1) { + labels[i] = 0; + } else { + labels[i] = 1; + } + } + return labels; + } + +} diff --git a/src/main/java/com/yuxue/train/SVMTrain1.java b/src/main/java/com/yuxue/train/SVMTrain1.java new file mode 100644 index 00000000..23d653c2 --- /dev/null +++ b/src/main/java/com/yuxue/train/SVMTrain1.java @@ -0,0 +1,382 @@ +package com.yuxue.train; + +import java.util.*; + +import static org.bytedeco.javacpp.opencv_core.*; +import static org.bytedeco.javacpp.opencv_ml.*; + +import org.bytedeco.javacpp.opencv_core; +import org.bytedeco.javacpp.opencv_imgcodecs; + +import com.yuxue.easypr.core.Features; +import com.yuxue.easypr.core.SVMCallback; +import com.yuxue.util.Convert; +import com.yuxue.util.FileUtil; + +/** + * 基于org.bytedeco.javacpp包实现的训练 + * JavaCPP 是一个开源库,它提供了在 Java 中高效访问本地 C++的方法 + * + * 图片识别车牌训练 + * 训练出来的库文件,用于判断切图是否包含车牌 + * + * 训练的svm.xml应用: + * 1、替换res/model/svm.xml文件 + * 2、修改com.yuxue.easypr.core.PlateJudge.plateJudge(Mat) 方法 + * 将样本处理方法切换一下,即将对应被注释掉的模块代码取消注释 + * @author yuxue + * @date 2020-05-14 22:16 + */ +public class SVMTrain1 { + + private SVMCallback callback = new Features(); + + // 默认的训练操作的根目录 + private static final String DEFAULT_PATH = "D:/PlateDetect/train/plate_detect_svm/"; + + // 训练模型文件保存位置 + private static final String MODEL_PATH = DEFAULT_PATH + "svm.xml"; + + private static final String hasPlate = "HasPlate"; + private static final String noPlate = "NoPlate"; + + public SVMTrain() { + } + + public SVMTrain(SVMCallback callback) { + this.callback = callback; + } + + /** + * 将learn文件夹下的图片,转存到tain test文件夹下,区分hasPalte noPlate + * 随机选取bound%作为训练数据,30%作为测试数据 + * @param bound + * @param name + */ + private void learn2Plate(float bound, final String name) { + final String filePath = DEFAULT_PATH + "learn/" + name; + Vector files = new Vector(); + + //// 获取该路径下的所有文件 + FileUtil.getFiles(filePath, files); + int size = files.size(); + if (0 == size) { + System.err.println("当前目录下没有文件: " + filePath); + return; + } + Collections.shuffle(files, new Random(new Date().getTime())); + + //// 随机选取70%作为训练数据,30%作为测试数据 + int boundry = (int) (bound * size); + + // 重新创建目录 + FileUtil.recreateDir(DEFAULT_PATH + "train/" + name); + FileUtil.recreateDir(DEFAULT_PATH + "test/" + name); + + for (int i = 0; i < boundry; i++) { + Mat img = opencv_imgcodecs.imread(files.get(i)); + String str = DEFAULT_PATH + "train/" + name + "/" + name + "_" + Integer.valueOf(i).toString() + ".jpg"; + opencv_imgcodecs.imwrite(str, img); + } + + for (int i = boundry; i < size; i++) { + Mat img = opencv_imgcodecs.imread(files.get(i)); + String str = DEFAULT_PATH + "test/" + name + "/" + name + "_" + Integer.valueOf(i).toString() + ".jpg"; + opencv_imgcodecs.imwrite(str, img); + } + } + + /** + * 获取训练图片 + * @param trainingImages + * @param trainingLabels + * @param name + */ + private void getPlateTrain(Mat trainingImages, Vector trainingLabels, final String name, int label) { + // int label = 1; + final String filePath = DEFAULT_PATH + "train/" + name; + Vector files = new Vector(); + + // 获取该路径下的所有文件 + FileUtil.getFiles(filePath, files); + + int size = files.size(); + if (null == files || size <= 0) { + System.out.println("File not found in " + filePath); + return; + } + for (int i = 0; i < size; i++) { + // System.out.println(files.get(i)); + Mat inMat = opencv_imgcodecs.imread(files.get(i)); + // 调用回调函数决定特征 + // Mat features = this.callback.getHisteqFeatures(inMat); + Mat features = this.callback.getHistogramFeatures(inMat); + // 通过直方图均衡化后的彩色图进行预测 + Mat p = features.reshape(1, 1); + p.convertTo(p, opencv_core.CV_32F); + + // 136 36 14688 1 变换尺寸 + // System.err.println(inMat.cols() + "\t" + inMat.rows() + "\t" + p.cols() + "\t" + p.rows()); + + trainingImages.push_back(p); // 合并成一张图片 + trainingLabels.add(label); + } + } + + private void getPlateTest(MatVector testingImages, Vector testingLabels, final String name, int label) { + // int label = 1; + final String filePath = DEFAULT_PATH + "test/" + name; + Vector files = new Vector(); + FileUtil.getFiles(filePath, files); + + int size = files.size(); + if (0 == size) { + System.out.println("File not found in " + filePath); + return; + } + System.out.println("get " + name + " test!"); + for (int i = 0; i < size; i++) { + Mat inMat = opencv_imgcodecs.imread(files.get(i)); + testingImages.push_back(inMat); + testingLabels.add(label); + } + } + + // ! 测试SVM的准确率,回归率以及FScore + public void getAccuracy(Mat testingclasses_preditc, Mat testingclasses_real) { + int channels = testingclasses_preditc.channels(); + System.out.println("channels: " + Integer.valueOf(channels).toString()); + int nRows = testingclasses_preditc.rows(); + System.out.println("nRows: " + Integer.valueOf(nRows).toString()); + int nCols = testingclasses_preditc.cols() * channels; + System.out.println("nCols: " + Integer.valueOf(nCols).toString()); + int channels_real = testingclasses_real.channels(); + System.out.println("channels_real: " + Integer.valueOf(channels_real).toString()); + int nRows_real = testingclasses_real.rows(); + System.out.println("nRows_real: " + Integer.valueOf(nRows_real).toString()); + int nCols_real = testingclasses_real.cols() * channels; + System.out.println("nCols_real: " + Integer.valueOf(nCols_real).toString()); + + double count_all = 0; + double ptrue_rtrue = 0; + double ptrue_rfalse = 0; + double pfalse_rtrue = 0; + double pfalse_rfalse = 0; + + for (int i = 0; i < nRows; i++) { + + final float predict = Convert.toFloat(testingclasses_preditc.ptr(i)); + final float real = Convert.toFloat(testingclasses_real.ptr(i)); + + count_all++; + + // System.out.println("predict:" << predict).toString()); + // System.out.println("real:" << real).toString()); + + if (predict == 1.0 && real == 1.0) + ptrue_rtrue++; + if (predict == 1.0 && real == 0) + ptrue_rfalse++; + if (predict == 0 && real == 1.0) + pfalse_rtrue++; + if (predict == 0 && real == 0) + pfalse_rfalse++; + } + + System.out.println("count_all: " + Double.valueOf(count_all).toString()); + System.out.println("ptrue_rtrue: " + Double.valueOf(ptrue_rtrue).toString()); + System.out.println("ptrue_rfalse: " + Double.valueOf(ptrue_rfalse).toString()); + System.out.println("pfalse_rtrue: " + Double.valueOf(pfalse_rtrue).toString()); + System.out.println("pfalse_rfalse: " + Double.valueOf(pfalse_rfalse).toString()); + + double precise = 0; + if (ptrue_rtrue + ptrue_rfalse != 0) { + precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse); + System.out.println("precise: " + Double.valueOf(precise).toString()); + } else { + System.out.println("precise: NA"); + } + + double recall = 0; + if (ptrue_rtrue + pfalse_rtrue != 0) { + recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue); + System.out.println("recall: " + Double.valueOf(recall).toString()); + } else { + System.out.println("recall: NA"); + } + + if (precise + recall != 0) { + double F = (precise * recall) / (precise + recall); + System.out.println("F: " + Double.valueOf(F).toString()); + } else { + System.out.println("F: NA"); + } + } + + /** + * 训练 + * @param dividePrepared + * @return + */ + public int svmTrain(boolean dividePrepared) { + + Mat classes = new Mat(); + Mat trainingData = new Mat(); + Mat trainingImages = new Mat(); + Vector trainingLabels = new Vector(); + + // 分割learn里的数据到train和test里 // 从库里面选取训练样本 + if (!dividePrepared) { + learn2Plate(0.1f, hasPlate); // 性能不好的机器,最好不要挑选太多的样本,这个方案太消耗资源了。 + learn2Plate(0.1f, noPlate); + } + + // System.err.println("Begin to get train data to memory"); + + getPlateTrain(trainingImages, trainingLabels, hasPlate, 0); + getPlateTrain(trainingImages, trainingLabels, noPlate, 1); + + // System.err.println(trainingImages.cols()); + + trainingImages.copyTo(trainingData); + trainingData.convertTo(trainingData, CV_32F); + + int[] labels = new int[trainingLabels.size()]; + for (int i = 0; i < trainingLabels.size(); ++i) { + labels[i] = trainingLabels.get(i).intValue(); + } + new Mat(labels).copyTo(classes); + + TrainData train_data = TrainData.create(trainingData, ROW_SAMPLE, classes); + + SVM svm = SVM.create(); + + try { + TermCriteria criteria = new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 20000, 0.0001); + svm.setTermCriteria(criteria); // 指定 + svm.setKernel(SVM.RBF); // 使用预先定义的内核初始化 + svm.setType(SVM.C_SVC); // SVM的类型,默认是:SVM.C_SVC + svm.setGamma(0.1); // 核函数的参数 + svm.setNu(0.1); // SVM优化问题参数 + svm.setC(1); // SVM优化问题的参数C + svm.setP(0.1); + svm.setDegree(0.1); + svm.setCoef0(0.1); + + svm.trainAuto(train_data, 10, + SVM.getDefaultGrid(SVM.C), + SVM.getDefaultGrid(SVM.GAMMA), + SVM.getDefaultGrid(SVM.P), + SVM.getDefaultGrid(SVM.NU), + SVM.getDefaultGrid(SVM.COEF), + SVM.getDefaultGrid(SVM.DEGREE), + true); + + } catch (Exception err) { + System.out.println(err.getMessage()); + } + + System.out.println("Svm generate done!"); + + /*FileStorage fsTo = new FileStorage(MODEL_PATH, FileStorage.WRITE); + svm.write(fsTo, "svm");*/ + svm.save(MODEL_PATH); + return 0; + } + + // 测试 + public int svmPredict() { + SVM svm = SVM.create(); + try { + svm.clear(); + // svm = SVM.loadSVM(MODEL_PATH, "svm"); + svm = SVM.load(MODEL_PATH); + } catch (Exception err) { + System.err.println(err.getMessage()); + return 0; // next predict requires svm + } + + System.out.println("Begin to predict"); + // Test SVM + MatVector testingImages = new MatVector(); + Vector testingLabels_real = new Vector(); + + // 将测试数据加载入内存 + getPlateTest(testingImages, testingLabels_real, hasPlate, 0); + getPlateTest(testingImages, testingLabels_real, noPlate, 1); + + double count_all = 0; + double ptrue_rtrue = 0; + double ptrue_rfalse = 0; + double pfalse_rtrue = 0; + double pfalse_rfalse = 0; + + long size = testingImages.size(); + System.err.println(size); + + for (int i = 0; i < size; i++) { + Mat inMat = testingImages.get(i); + + // Mat features = callback.getHisteqFeatures(inMat); + Mat features = callback.getHistogramFeatures(inMat); + Mat p = features.reshape(1, 1); + p.convertTo(p, opencv_core.CV_32F); + + // System.out.println(p.cols() + "\t" + p.rows() + "\t" + p.type()); + + // samples.cols == var_count && samples.type() == CV_32F + // var_count 的值会在svm.xml库文件中有体现 + float predoct = svm.predict(features); + + int predict = (int) predoct; // 预期值 + int real = testingLabels_real.get(i); // 实际值 + + if (predict == 1 && real == 1) + ptrue_rtrue++; + if (predict == 1 && real == 0) + ptrue_rfalse++; + if (predict == 0 && real == 1) + pfalse_rtrue++; + if (predict == 0 && real == 0) + pfalse_rfalse++; + } + + count_all = size; + System.out.println("Get the Accuracy!"); + + System.out.println("count_all: " + Double.valueOf(count_all).toString()); + System.out.println("ptrue_rtrue: " + Double.valueOf(ptrue_rtrue).toString()); + System.out.println("ptrue_rfalse: " + Double.valueOf(ptrue_rfalse).toString()); + System.out.println("pfalse_rtrue: " + Double.valueOf(pfalse_rtrue).toString()); + System.out.println("pfalse_rfalse: " + Double.valueOf(pfalse_rfalse).toString()); + + double precise = 0; + if (ptrue_rtrue + ptrue_rfalse != 0) { + precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse); + System.out.println("precise: " + Double.valueOf(precise).toString()); + } else + System.out.println("precise: NA"); + + double recall = 0; + if (ptrue_rtrue + pfalse_rtrue != 0) { + recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue); + System.out.println("recall: " + Double.valueOf(recall).toString()); + } else + System.out.println("recall: NA"); + + double Fsocre = 0; + if (precise + recall != 0) { + Fsocre = 2 * (precise * recall) / (precise + recall); + System.out.println("Fsocre: " + Double.valueOf(Fsocre).toString()); + } else + System.out.println("Fsocre: NA"); + return 0; + } + + public static void main(String[] args) { + SVMTrain s = new SVMTrain(); + s.svmTrain(true); + s.svmPredict(); + } + +}