From aae605be809455fd9391bd89de75c5c37a238004 Mon Sep 17 00:00:00 2001 From: yuxue Date: Thu, 2 Jul 2020 17:25:50 +0800 Subject: [PATCH] no commit message --- src/main/java/com/yuxue/train/ANNTrain.java | 271 ++++++-------------- 1 file changed, 72 insertions(+), 199 deletions(-) diff --git a/src/main/java/com/yuxue/train/ANNTrain.java b/src/main/java/com/yuxue/train/ANNTrain.java index a9aa0800..b7d004be 100644 --- a/src/main/java/com/yuxue/train/ANNTrain.java +++ b/src/main/java/com/yuxue/train/ANNTrain.java @@ -6,11 +6,8 @@ import java.util.Vector; import org.opencv.core.Core; import org.opencv.core.CvType; import org.opencv.core.Mat; -import org.opencv.core.Point; -import org.opencv.core.Size; import org.opencv.core.TermCriteria; import org.opencv.imgcodecs.Imgcodecs; -import org.opencv.imgproc.Imgproc; import org.opencv.ml.ANN_MLP; import org.opencv.ml.Ml; import org.opencv.ml.TrainData; @@ -49,96 +46,6 @@ public class ANNTrain { // 训练模型文件保存位置 private static final String MODEL_PATH = DEFAULT_PATH + "ann.xml"; - - /** - * 进行膨胀操作 - * @param inMat - * @return - */ - public Mat dilate(Mat inMat) { - Mat result = inMat.clone(); - Mat element = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, new Size(2, 2)); - Imgproc.dilate(inMat, result, element); - return result; - } - - /** - * 进行腐蚀操作 - * @param inMat - * @return - */ - public Mat erode(Mat inMat) { - Mat result = inMat.clone(); - Mat element = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, new Size(2, 2)); - Imgproc.erode(inMat, result, element); - return result; - } - - - /** - * 随机数平移 - * @param inMat - * @return - */ - public Mat randTranslate(Mat inMat) { - Random rand = new Random(); - Mat result = inMat.clone(); - int ran_x = rand.nextInt(10000) % 5 - 2; // 控制在-2~3个像素范围内 - int ran_y = rand.nextInt(10000) % 5 - 2; - return translateImg(result, ran_x, ran_y); - } - - - /** - * 随机数旋转 - * @param inMat - * @return - */ - public Mat randRotate(Mat inMat) { - Random rand = new Random(); - Mat result = inMat.clone(); - float angle = (float) (rand.nextInt(10000) % 15 - 7); // 旋转角度控制在-7~8°范围内 - return rotateImg(result, angle); - } - - - - /** - * 平移 - * @param img - * @param offsetx - * @param offsety - * @return - */ - public Mat translateImg(Mat img, int offsetx, int offsety){ - Mat dst = new Mat(); - //定义平移矩阵 - Mat trans_mat = Mat.zeros(2, 3, CvType.CV_32FC1); - trans_mat.put(0, 0, 1); - trans_mat.put(0, 2, offsetx); - trans_mat.put(1, 1, 1); - trans_mat.put(1, 2, offsety); - Imgproc.warpAffine(img, dst, trans_mat, img.size()); // 仿射变换 - return dst; - } - - - /** - * 旋转角度 - * @param source - * @param angle - * @return - */ - public Mat rotateImg(Mat source, float angle){ - Point src_center = new Point(source.cols() / 2.0F, source.rows() / 2.0F); - Mat rot_mat = Imgproc.getRotationMatrix2D(src_center, angle, 1); - Mat dst = new Mat(); - // 仿射变换 可以考虑使用投影变换; 这里使用放射变换进行旋转,对于实际效果来说感觉意义不大,反而会干扰结果预测 - Imgproc.warpAffine(source, dst, rot_mat, source.size()); - return dst; - } - - public void train(int _predictsize, int _neurons) { Mat samples = new Mat(); // 使用push_back,行数列数不能赋初始值 Vector trainingLabels = new Vector(); @@ -149,68 +56,46 @@ public class ANNTrain { Vector files = new Vector(); FileUtil.getFiles(str, files); // 文件名不能包含中文 - int count = 200; // 控制从训练样本中,抽取指定数量的样本 + // int count = 200; // 控制从训练样本中,抽取指定数量的样本 + int count = files.size(); // 控制从训练样本中,抽取指定数量的样本 for (int j = 0; j < count; j++) { - Mat img = Imgcodecs.imread(files.get(rand.nextInt(files.size() - 1)), 0); + String filename = ""; + if(j < files.size()) { + filename = files.get(j); + } else { + filename = files.get(rand.nextInt(files.size() - 1)); // 样本不足,随机重复提取已有的样本 + } + + Mat img = Imgcodecs.imread(filename, 0); + Mat f = PlateUtil.features(img, _predictsize); samples.push_back(f); trainingLabels.add(i); // 每一幅字符图片所对应的字符类别索引下标 // 增加随机平移样本 - samples.push_back(PlateUtil.features(randTranslate(img), _predictsize)); + samples.push_back(PlateUtil.features(PlateUtil.randTranslate(img), _predictsize)); trainingLabels.add(i); // 增加随机旋转样本 - samples.push_back(PlateUtil.features(randRotate(img), _predictsize)); + samples.push_back(PlateUtil.features(PlateUtil.randRotate(img), _predictsize)); trainingLabels.add(i); // 增加膨胀样本 - /*samples.push_back(PlateUtil.features(dilate(img), _predictsize)); - trainingLabels.add(i);*/ + samples.push_back(PlateUtil.features(PlateUtil.dilate(img), _predictsize)); + trainingLabels.add(i); // 增加腐蚀样本 - /*samples.push_back(PlateUtil.features(erode(img), _predictsize)); + /*samples.push_back(PlateUtil.features(PlateUtil.erode(img), _predictsize)); trainingLabels.add(i); */ } } - // 加载汉字字符 - for (int i = 0; i < Constant.strChinese.length; i++) { - String str = DEFAULT_PATH + "learn/" + Constant.strChinese[i]; - Vector files = new Vector(); - FileUtil.getFiles(str, files); - - int count = 200; // 控制从训练样本中,抽取指定数量的样本 - for (int j = 0; j < count; j++) { - Mat img = Imgcodecs.imread(files.get(rand.nextInt(files.size() - 1)), 0); - Mat f = PlateUtil.features(img, _predictsize); - samples.push_back(f); - trainingLabels.add(i + Constant.numCharacter); - - // 增加随机平移样本 - samples.push_back(PlateUtil.features(randTranslate(img), _predictsize)); - trainingLabels.add(i + Constant.numCharacter); - - // 增加随机旋转样本 - samples.push_back(PlateUtil.features(randRotate(img), _predictsize)); - trainingLabels.add(i + Constant.numCharacter); - - // 增加膨胀样本 - /*samples.push_back(PlateUtil.features(dilate(img), _predictsize)); - trainingLabels.add(i + Constant.numCharacter);*/ - - // 增加腐蚀样本 - samples.push_back(PlateUtil.features(erode(img), _predictsize)); - trainingLabels.add(i + Constant.numCharacter); - } - } - samples.convertTo(samples, CvType.CV_32F); //440 vhist.length + hhist.length + lowData.cols() * lowData.rows(); // CV_32FC1 CV_32SC1 CV_32F - Mat classes = Mat.zeros(trainingLabels.size(), Constant.numAll, CvType.CV_32F); + Mat classes = Mat.zeros(trainingLabels.size(), Constant.strCharacters.length, CvType.CV_32F); float[] labels = new float[trainingLabels.size()]; for (int i = 0; i < labels.length; ++i) { @@ -244,78 +129,66 @@ public class ANNTrain { public void predict() { ann.clear(); ann = ANN_MLP.load(MODEL_PATH); - Vector files = new Vector(); - FileUtil.getFiles(DEFAULT_PATH + "test/", files); // 获取测试文件 - - String plate = ""; - for (String string : files) { - Mat img = Imgcodecs.imread(string, 0); - Mat f = PlateUtil.features(img, Constant.predictSize); - - int index = 0; - double maxVal = -2; - Mat output = new Mat(1, Constant.numAll, CvType.CV_32F); - ann.predict(f, output); // 预测结果 - for (int j = 0; j < Constant.numAll; j++) { - double val = output.get(0, j)[0]; - if (val > maxVal) { - maxVal = val; - index = j; - } - } - - // 随机平移 - /*f = PlateUtil.features(randTranslate(img), Constant.predictSize); - ann.predict(f, output); // 预测结果 - for (int j = 0; j < Constant.numAll; j++) { - double val = output.get(0, j)[0]; - if (val > maxVal) { - maxVal = val; - index = j; - } - }*/ - - // 随机旋转 - /*f = PlateUtil.features(randRotate(img), Constant.predictSize); - ann.predict(f, output); // 预测结果 - for (int j = 0; j < Constant.numAll; j++) { - double val = output.get(0, j)[0]; - if (val > maxVal) { - maxVal = val; - index = j; - } - }*/ - - // 膨胀 - /*f = PlateUtil.features(dilate(img), Constant.predictSize); - ann.predict(f, output); // 预测结果 - for (int j = 0; j < Constant.numAll; j++) { - double val = output.get(0, j)[0]; - if (val > maxVal) { - maxVal = val; - index = j; + + int total = 0; + int correct = 0; + + // 遍历测试样本下的所有文件,计算预测准确率 + for (int i = 0; i < Constant.strCharacters.length; i++) { + + char c = Constant.strCharacters[i]; + String path = DEFAULT_PATH + "learn/" + c; + + Vector files = new Vector(); + FileUtil.getFiles(path, files); + + for (String filePath : files) { + + Mat img = Imgcodecs.imread(filePath, 0); + Mat f = PlateUtil.features(img, Constant.predictSize); + + int index = 0; + double maxVal = -2; + Mat output = new Mat(1, Constant.strCharacters.length, CvType.CV_32F); + ann.predict(f, output); // 预测结果 + for (int j = 0; j < Constant.strCharacters.length; j++) { + double val = output.get(0, j)[0]; + if (val > maxVal) { + maxVal = val; + index = j; + } } - }*/ - - // 腐蚀 -- 识别中文字符效果会好一点,识别数字及字母效果会更差 - /*f = PlateUtil.features(erode(img), Constant.predictSize); - ann.predict(f, output); // 预测结果 - for (int j = 0; j < Constant.numAll; j++) { - double val = output.get(0, j)[0]; - if (val > maxVal) { - maxVal = val; - index = j; + + // 膨胀 + f = PlateUtil.features(PlateUtil.dilate(img), Constant.predictSize); + ann.predict(f, output); // 预测结果 + for (int j = 0; j < Constant.strCharacters.length; j++) { + double val = output.get(0, j)[0]; + if (val > maxVal) { + maxVal = val; + index = j; + } } - }*/ - if (index < Constant.numCharacter) { - plate += String.valueOf(Constant.strCharacters[index]); - } else { - String s = Constant.strChinese[index - Constant.numCharacter]; - plate += Constant.KEY_CHINESE_MAP.get(s); + String result = String.valueOf(Constant.strCharacters[index]); + if(result.equals(String.valueOf(c))) { + correct++; + } else { + System.err.print(filePath); + System.err.println("\t预测结果:" + result); + } + total++; } + } - System.err.println("===>" + plate); + + System.out.print("total:" + total); + System.out.print("\tcorrect:" + correct); + System.out.print("\terror:" + (total - correct)); + System.out.println("\t计算准确率为:" + correct / (total * 1.0)); + + //牛逼,我操 total:13178 correct:13139 error:39 计算准确率为:0.9970405220822584 + return; }