From 28ed2a5565648488fee6f15463698aacbffb778c Mon Sep 17 00:00:00 2001 From: yuxue Date: Tue, 2 Jun 2020 16:54:57 +0800 Subject: [PATCH] no commit message --- src/main/java/com/yuxue/train/ANNTrain.java | 56 +++++++++++--------- src/main/java/com/yuxue/train/ANNTrain1.java | 2 +- src/main/java/com/yuxue/util/ImageUtil.java | 2 +- src/main/java/com/yuxue/util/PlateUtil.java | 28 ++++++++-- 4 files changed, 56 insertions(+), 32 deletions(-) diff --git a/src/main/java/com/yuxue/train/ANNTrain.java b/src/main/java/com/yuxue/train/ANNTrain.java index eaff97e0..c228ceb0 100644 --- a/src/main/java/com/yuxue/train/ANNTrain.java +++ b/src/main/java/com/yuxue/train/ANNTrain.java @@ -45,7 +45,7 @@ public class ANNTrain { // 训练模型文件保存位置 private static final String MODEL_PATH = DEFAULT_PATH + "ann.xml"; - + public static float[] projectedHistogram(final Mat img, Direction direction) { int sz = 0; switch (direction) { @@ -121,12 +121,10 @@ public class ANNTrain { for (int i = 0; i < Constant.numCharacter; i++) { String str = DEFAULT_PATH + "learn/" + Constant.strCharacters[i]; Vector files = new Vector(); - FileUtil.getFiles(str, files); - - int size = (int) files.size(); - for (int j = 0; j < size; j++) { - Mat img = Imgcodecs.imread(files.get(j), 0); - // System.err.println(files.get(j)); // 文件名不能包含中文 + FileUtil.getFiles(str, files); // 文件名不能包含中文 + + for (String filePath : files) { + Mat img = Imgcodecs.imread(filePath); Mat f = features(img, _predictsize); samples.push_back(f); trainingLabels.add(i); // 每一幅字符图片所对应的字符类别索引下标 @@ -138,22 +136,18 @@ public class ANNTrain { String str = DEFAULT_PATH + "learn/" + Constant.strChinese[i]; Vector files = new Vector(); FileUtil.getFiles(str, files); - - int size = (int) files.size(); - for (int j = 0; j < size; j++) { - Mat img = Imgcodecs.imread(files.get(j), 0); - // System.err.println(files.get(j)); // 文件名不能包含中文 + for (String filePath : files) { + Mat img = Imgcodecs.imread(filePath); Mat f = features(img, _predictsize); samples.push_back(f); - trainingLabels.add(i + Constant.numCharacter); + trainingLabels.add(i + Constant.numCharacter); // 每一幅字符图片所对应的字符类别索引下标 } } - //440 vhist.length + hhist.length + lowData.cols() * lowData.rows(); // CV_32FC1 CV_32SC1 CV_32F Mat classes = new Mat(trainingLabels.size(), Constant.numAll, CvType.CV_32F); - + float[] labels = new float[trainingLabels.size()]; for (int i = 0; i < labels.length; ++i) { classes.put(i, trainingLabels.get(i), 1.f); @@ -167,8 +161,8 @@ public class ANNTrain { layers.put(0, 0, samples.cols()); layers.put(0, 1, _neurons); layers.put(0, 2, classes.cols()); - - System.out.println(layers); + + // System.out.println(layers); ann.setLayerSizes(layers); ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1); @@ -183,32 +177,42 @@ public class ANNTrain { // ann.write(fsto, "ann"); ann.save(MODEL_PATH); } - - + + public void predict() { ann.clear(); ann = ANN_MLP.load(MODEL_PATH); Vector files = new Vector(); FileUtil.getFiles(DEFAULT_PATH + "test/", files); - + + String plate = ""; for (String string : files) { Mat img = Imgcodecs.imread(string, 0); Mat f = features(img, Constant.predictSize); - + // 140 predictSize = 10; vhist.length + hhist.length + lowData.cols() * lowData.rows(); // 440 predictSize = 20; Mat output = new Mat(1, 140, CvType.CV_32F); //ann.predict(f, output, 0); // 预测结果 - System.err.println(string + "===>" + (int) ann.predict(f, output, 0)); - + // ann.predict(f, output, 0); // System.err.println(string + "===>" + output.get(0, 0)[0]); - + + int index = (int) ann.predict(f, output, 0); + System.err.println(string + "===>" + index); + + if (index < Constant.numCharacter) { + plate += String.valueOf(Constant.strCharacters[index]); + } else { + String s = Constant.strChinese[index - Constant.numCharacter]; + plate += Constant.KEY_CHINESE_MAP.get(s); + } } + System.err.println("===>" + plate); } public static void main(String[] args) { - + ANNTrain annT = new ANNTrain(); // 这里演示只训练model文件夹下的ann.xml,此模型是一个predictSize=10,neurons=40的ANN模型 // 可根据需要训练不同的predictSize或者neurons的ANN模型 @@ -216,7 +220,7 @@ public class ANNTrain { // annT.train(Constant.predictSize, Constant.neurons); annT.predict(); - + System.out.println("The end."); } diff --git a/src/main/java/com/yuxue/train/ANNTrain1.java b/src/main/java/com/yuxue/train/ANNTrain1.java index 0a0d8ce8..a177ac24 100644 --- a/src/main/java/com/yuxue/train/ANNTrain1.java +++ b/src/main/java/com/yuxue/train/ANNTrain1.java @@ -128,7 +128,7 @@ public class ANNTrain1 { int index = (int) ann.predict(f, output, 0); String result = ""; - if (index <= Constant.numCharacter) { + if (index < Constant.numCharacter) { result = String.valueOf(Constant.strCharacters[index]); } else { String s = Constant.strChinese[index - Constant.numCharacter]; diff --git a/src/main/java/com/yuxue/util/ImageUtil.java b/src/main/java/com/yuxue/util/ImageUtil.java index 31a519e3..65b29d09 100644 --- a/src/main/java/com/yuxue/util/ImageUtil.java +++ b/src/main/java/com/yuxue/util/ImageUtil.java @@ -71,7 +71,7 @@ public class ImageUtil { Instant start = Instant.now(); String tempPath = DEFAULT_BASE_TEST_PATH + "test/"; String filename = tempPath + "/100_yuantu.jpg"; - filename = tempPath + "/100_yuantu2.jpg"; + filename = tempPath + "/100_yuantu1.jpg"; // filename = tempPath + "/109_crop_0.png"; Mat src = Imgcodecs.imread(filename); diff --git a/src/main/java/com/yuxue/util/PlateUtil.java b/src/main/java/com/yuxue/util/PlateUtil.java index 5bf74109..faaf90f2 100644 --- a/src/main/java/com/yuxue/util/PlateUtil.java +++ b/src/main/java/com/yuxue/util/PlateUtil.java @@ -24,6 +24,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.yuxue.constant.Constant; import com.yuxue.enumtype.PlateColor; +import com.yuxue.train.ANNTrain; import com.yuxue.train.SVMTrain; @@ -58,7 +59,8 @@ public class PlateUtil { entry.setValue(index); index ++; } - + + // 这个位置加载模型文件会报错,暂时没时间定位啥问题报错 /*loadSvmModel("D:/PlateDetect/train/plate_detect_svm/svm2.xml"); loadAnnModel("D:/PlateDetect/train/chars_recognise_ann/ann.xml");*/ } @@ -288,16 +290,34 @@ public class PlateUtil { Vector sorted = new Vector(); sortRect(rt, sorted); + String plate = ""; Vector dst = new Vector(); + + ANNTrain annT = new ANNTrain(); for (int i = 0; i < sorted.size(); i++) { Mat img_crop = new Mat(threshold, sorted.get(i)); img_crop = preprocessChar(img_crop); dst.add(img_crop); - Imgcodecs.imwrite(tempPath + debugMap.get("plateCrop") + "_plateCrop_" + i + ".jpg", img_crop); + if(debug) { + Imgcodecs.imwrite(tempPath + debugMap.get("plateCrop") + "_plateCrop_" + i + ".jpg", img_crop); + } + + + Mat f = annT.features(img_crop, Constant.predictSize); + + // 字符预测 + Mat output = new Mat(1, 140, CvType.CV_32F); + int index = (int) ann.predict(f, output, 0); + + if (index < Constant.numCharacter) { + plate += String.valueOf(Constant.strCharacters[index]); + } else { + String s = Constant.strChinese[index - Constant.numCharacter]; + plate += Constant.KEY_CHINESE_MAP.get(s); + } } - + System.err.println("===>" + plate); - return; }