diff --git a/src/main/java/com/yuxue/train/ANNTrain.java b/src/main/java/com/yuxue/train/ANNTrain.java index 315c3f43..9a847d0c 100644 --- a/src/main/java/com/yuxue/train/ANNTrain.java +++ b/src/main/java/com/yuxue/train/ANNTrain.java @@ -1,20 +1,20 @@ package com.yuxue.train; -import static org.bytedeco.javacpp.opencv_core.CV_32F; -import static org.bytedeco.javacpp.opencv_ml.ROW_SAMPLE; - import java.util.Vector; -import org.bytedeco.javacpp.opencv_core.FileStorage; -import org.bytedeco.javacpp.opencv_core.Mat; -import org.bytedeco.javacpp.opencv_core.TermCriteria; -import org.bytedeco.javacpp.opencv_core; -import org.bytedeco.javacpp.opencv_imgcodecs; -import org.bytedeco.javacpp.opencv_ml.ANN_MLP; -import org.bytedeco.javacpp.opencv_ml.TrainData; -import com.yuxue.easypr.core.CoreFunc; -import com.yuxue.util.Convert; +import org.opencv.core.Core; +import org.opencv.core.CvType; +import org.opencv.core.Mat; +import org.opencv.core.Size; +import org.opencv.core.TermCriteria; +import org.opencv.imgcodecs.Imgcodecs; +import org.opencv.imgproc.Imgproc; +import org.opencv.ml.ANN_MLP; +import org.opencv.ml.Ml; +import org.opencv.ml.TrainData; + +import com.yuxue.enumtype.Direction; import com.yuxue.util.FileUtil; @@ -36,6 +36,9 @@ public class ANNTrain { private ANN_MLP ann = ANN_MLP.create(); + static { + System.loadLibrary(Core.NATIVE_LIBRARY_NAME); + } // 中国车牌; 34个字符; 没有 字母I、字母O private final char strCharacters[] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', @@ -85,30 +88,119 @@ public class ANNTrain { private static final String DEFAULT_PATH = "D:/PlateDetect/train/chars_recognise_ann/"; // 训练模型文件保存位置 - private static final String DATA_PATH = DEFAULT_PATH + "ann_data.xml"; + // private static final String DATA_PATH = DEFAULT_PATH + "ann_data.xml"; private static final String MODEL_PATH = DEFAULT_PATH + "ann.xml"; + + public static float[] projectedHistogram(final Mat img, Direction direction) { + int sz = 0; + switch (direction) { + case HORIZONTAL: + sz = img.rows(); + break; + + case VERTICAL: + sz = img.cols(); + break; + + default: + break; + } + + // 统计这一行或一列中,非零元素的个数,并保存到nonZeroMat中 + float[] nonZeroMat = new float[sz]; + Core.extractChannel(img, img, 0); + for (int j = 0; j < sz; j++) { + Mat data = (direction == Direction.HORIZONTAL) ? img.row(j) : img.col(j); + int count = Core.countNonZero(data); + nonZeroMat[j] = count; + } + + // Normalize histogram + float max = 0; + for (int j = 0; j < nonZeroMat.length; ++j) { + max = Math.max(max, nonZeroMat[j]); + } + + if (max > 0) { + for (int j = 0; j < nonZeroMat.length; ++j) { + nonZeroMat[j] /= max; + } + } + + return nonZeroMat; + } + - public int saveTrainData(int _predictsize) { + public Mat features(Mat in, int sizeData) { + + float[] vhist = projectedHistogram(in, Direction.VERTICAL); + float[] hhist = projectedHistogram(in, Direction.HORIZONTAL); + + Mat lowData = new Mat(); + if (sizeData > 0) { + Imgproc.resize(in, lowData, new Size(sizeData, sizeData)); + } - Mat classes = new Mat(); - Mat trainingDataf = new Mat(); + int numCols = vhist.length + hhist.length + lowData.cols() * lowData.rows(); + Mat out = new Mat(1, numCols, CvType.CV_32F); - Vector trainingLabels = new Vector(); + int j = 0; + for (int i = 0; i < vhist.length; ++i, ++j) { + out.put(0, j, vhist[i]); + } + for (int i = 0; i < hhist.length; ++i, ++j) { + out.put(0, j, hhist[i]); + } + + for (int x = 0; x < lowData.cols(); x++) { + for (int y = 0; y < lowData.rows(); y++, ++j) { + // float val = lowData.ptr(x, y).get(0) & 0xFF; + double[] val = lowData.get(x, y); + out.put(0, j, val[0]); + } + } + return out; + } + + public void train(int _predictsize, int _neurons) { + // 读取样本文件数据 + /*FileStorage fs = new FileStorage(DATA_PATH, FileStorage.READ); + Mat samples = new Mat(fs.get("TrainingDataF" + _predictsize)); + Mat classes = new Mat(fs.get("classes")); + + Mat trainClasses = new Mat(samples.rows(), numAll, CV_32F); + for (int i = 0; i < trainClasses.rows(); i++) { + for (int k = 0; k < trainClasses.cols(); k++) { + // If class of data i is same than a k class + if (k == Convert.toInt(classes.ptr(i))) { + trainClasses.ptr(i, k).put(Convert.getBytes(1f)); + + } else { + trainClasses.ptr(i, k).put(Convert.getBytes(0f)); + } + } + } + samples.convertTo(samples, CV_32F); + System.out.println(samples.type());*/ + + + Mat samples = new Mat(); // 使用push_back,行数列数不能赋初始值 + + Vector trainingLabels = new Vector(); // 加载数字及字母字符 for (int i = 0; i < numCharacter; i++) { String str = DEFAULT_PATH + strCharacters[i]; - System.err.println(str); Vector files = new Vector(); FileUtil.getFiles(str, files); int size = (int) files.size(); for (int j = 0; j < size; j++) { - Mat img = opencv_imgcodecs.imread(files.get(j), 0); + Mat img = Imgcodecs.imread(files.get(j), 0); // System.err.println(files.get(j)); // 文件名不能包含中文 - Mat f = CoreFunc.features(img, _predictsize); - trainingDataf.push_back(f); + Mat f = features(img, _predictsize); + samples.push_back(f); trainingLabels.add(i); // 每一幅字符图片所对应的字符类别索引下标 } } @@ -121,84 +213,67 @@ public class ANNTrain { int size = (int) files.size(); for (int j = 0; j < size; j++) { - Mat img = opencv_imgcodecs.imread(files.get(j), 0); + Mat img = Imgcodecs.imread(files.get(j), 0); // System.err.println(files.get(j)); // 文件名不能包含中文 - Mat f = CoreFunc.features(img, _predictsize); - trainingDataf.push_back(f); + Mat f = features(img, _predictsize); + samples.push_back(f); trainingLabels.add(i + numCharacter); } } // CV_32FC1 CV_32SC1 CV_32F - trainingDataf.convertTo(trainingDataf, opencv_core.CV_32F); + // samples.convertTo(samples, CvType.CV_32F); - int[] labels = new int[trainingLabels.size()]; + float[] labels = new float[trainingLabels.size()]; for (int i = 0; i < labels.length; ++i) { labels[i] = trainingLabels.get(i).intValue(); } - new Mat(labels).copyTo(classes); - - FileStorage fs = new FileStorage(DATA_PATH, FileStorage.WRITE); - fs.write("TrainingDataF" + _predictsize, trainingDataf); - fs.write("classes", classes); - fs.release(); + Mat classes = new Mat(labels.length, 440, CvType.CV_32F); + classes.put(0, 0, labels); - System.out.println("End saveTrainData"); - return 0; - } - - public void train(int _predictsize, int _neurons) { - - // 读取样本文件数据 - FileStorage fs = new FileStorage(DATA_PATH, FileStorage.READ); - Mat samples = new Mat(fs.get("TrainingDataF" + _predictsize)); - Mat classes = new Mat(fs.get("classes")); - - Mat trainClasses = new Mat(samples.rows(), numAll, CV_32F); - for (int i = 0; i < trainClasses.rows(); i++) { - for (int k = 0; k < trainClasses.cols(); k++) { - // If class of data i is same than a k class - if (k == Convert.toInt(classes.ptr(i))) { - trainClasses.ptr(i, k).put(Convert.getBytes(1f)); - - } else { - trainClasses.ptr(i, k).put(Convert.getBytes(0f)); - } - } - } + System.out.println(samples.rows()); + System.out.println(samples.cols()); + System.out.println(samples.type()); - samples.convertTo(samples, CV_32F); + System.out.println(classes.rows()); + System.out.println(classes.cols()); + System.out.println(classes.type()); - System.out.println(samples.type()); - // samples.type() == CV_32F || samples.type() == CV_32S - TrainData train_data = TrainData.create(samples, ROW_SAMPLE, trainClasses); + TrainData train_data = TrainData.create(samples, Ml.ROW_SAMPLE, classes); + // //l_count为相量_layer_sizes的维数,即MLP的层数L + // l_count = _layer_sizes->rows + _layer_sizes->cols - 1; ann.clear(); - Mat layers = new Mat(1, 3, CV_32F); - layers.ptr(0).put(Convert.getBytes(samples.cols())); - layers.ptr(1).put(Convert.getBytes(_neurons)); - layers.ptr(2).put(Convert.getBytes(numAll)); + Mat layers = new Mat(1, 3, CvType.CV_32F); + layers.put(0, 0, samples.cols()); + layers.put(0, 1, _neurons); + layers.put(0, 2, classes.cols()); + + /*layers.ptr(0,0).put(Convert.getBytes(samples.cols())); //440 vhist.length + hhist.length + lowData.cols() * lowData.rows(); + layers.ptr(0,1).put(Convert.getBytes(_predictsize)); + layers.ptr(0,2).put(Convert.getBytes(numAll));*/ ann.setLayerSizes(layers); ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1); ann.setTrainMethod(ANN_MLP.BACKPROP); - TermCriteria criteria = new TermCriteria(TermCriteria.MAX_ITER, 30000, 0.0001); + TermCriteria criteria = new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 30000, 0.0001); ann.setTermCriteria(criteria); ann.setBackpropWeightScale(0.1); ann.setBackpropMomentumScale(0.1); ann.train(train_data); System.err.println("完成 "); - FileStorage fsto = new FileStorage(MODEL_PATH, FileStorage.WRITE); - ann.write(fsto, "ann"); + // FileStorage fsto = new FileStorage(MODEL_PATH, FileStorage.WRITE); + // ann.write(fsto, "ann"); + ann.save(MODEL_PATH); } public static void main(String[] args) { ANNTrain annT = new ANNTrain(); // 可根据需要训练不同的predictSize或者neurons的ANN模型 - int _predictsize = 10; + int _predictsize = 20; int _neurons = 40; // annT.saveTrainData(_predictsize);