diff --git a/src/test/java/com/yuxue/test/ANNTrain.java b/src/main/java/com/yuxue/train/ANNTrain.java similarity index 66% rename from src/test/java/com/yuxue/test/ANNTrain.java rename to src/main/java/com/yuxue/train/ANNTrain.java index a91a0447..d10a9502 100644 --- a/src/test/java/com/yuxue/test/ANNTrain.java +++ b/src/main/java/com/yuxue/train/ANNTrain.java @@ -1,49 +1,50 @@ -package com.yuxue.test; +package com.yuxue.train; import static org.bytedeco.javacpp.opencv_core.CV_32F; import static org.bytedeco.javacpp.opencv_core.CV_32FC1; import static org.bytedeco.javacpp.opencv_core.CV_32SC1; import static org.bytedeco.javacpp.opencv_core.getTickCount; import static org.bytedeco.javacpp.opencv_imgproc.resize; +import static org.bytedeco.javacpp.opencv_ml.ROW_SAMPLE; import java.util.Vector; -import org.bytedeco.javacpp.opencv_core.CvMemStorage; import org.bytedeco.javacpp.opencv_core.FileStorage; import org.bytedeco.javacpp.opencv_core.Mat; import org.bytedeco.javacpp.opencv_core.Scalar; import org.bytedeco.javacpp.opencv_core.Size; +import org.bytedeco.javacpp.opencv_core.TermCriteria; +import org.bytedeco.javacpp.opencv_imgcodecs; import org.bytedeco.javacpp.opencv_ml.ANN_MLP; +import org.bytedeco.javacpp.opencv_ml.TrainData; import com.yuxue.easypr.core.CoreFunc; import com.yuxue.enumtype.Direction; import com.yuxue.util.Convert; import com.yuxue.util.FileUtil; -import ch.qos.logback.classic.pattern.Util; - -/* +/** * + * @author yuxue + * @date 2020-05-14 22:16 */ public class ANNTrain { - - /*private ANN_MLP ann=ANN_MLP.create(); + + private ANN_MLP ann = ANN_MLP.create(); // 中国车牌 - private final char strCharacters[] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', - 'F', 'G', 'H', 没有I - 'J', 'K', 'L', 'M', 'N', 没有O 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' }; - private final int numCharacter = 34; 没有I和0,10个数字与24个英文字符之和 + private final char strCharacters[] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', /* 没有I */ + 'J', 'K', 'L', 'M', 'N', /* 没有O */'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' }; + private final int numCharacter = 34; /* 没有I和0,10个数字与24个英文字符之和 */ // 以下都是我训练时用到的中文字符数据,并不全面,有些省份没有训练数据所以没有字符 // 有些后面加数字2的表示在训练时常看到字符的一种变形,也作为训练数据存储 - private final String strChinese[] = { "zh_cuan" 川 , "zh_e" 鄂 , "zh_gan" 赣 , "zh_hei" 黑 , - "zh_hu" 沪 , "zh_ji" 冀 , "zh_jl" 吉 , "zh_jin" 津 , "zh_jing" 京 , "zh_shan" 陕 , - "zh_liao" 辽 , "zh_lu" 鲁 , "zh_min" 闽 , "zh_ning" 宁 , "zh_su" 苏 , "zh_sx" 晋 , - "zh_wan" 皖 , "zh_yu" 豫 , "zh_yue" 粤 , "zh_zhe" 浙 }; + private final String strChinese[] = { "zh_cuan" /* 川 */, "zh_e" /* 鄂 */, "zh_gan" /* 赣 */, "zh_hei" /* 黑 */, "zh_hu" /* 沪 */, "zh_ji" /* 冀 */, "zh_jl" /* 吉 */, "zh_jin" /* 津 */, "zh_jing" /* 京 */, + "zh_shan" /* 陕 */, "zh_liao" /* 辽 */, "zh_lu" /* 鲁 */, "zh_min" /* 闽 */, "zh_ning" /* 宁 */, "zh_su" /* 苏 */, "zh_sx" /* 晋 */, "zh_wan" /* 皖 */, "zh_yu" /* 豫 */, "zh_yue" /* 粤 */, + "zh_zhe" /* 浙 */ }; - private final int numAll = 54; 34+20=54 + private final int numAll = 54; /* 34+20=54 */ public Mat features(Mat in, int sizeData) { // Histogram features @@ -72,23 +73,27 @@ public class ANNTrain { out.ptr(j).put(Convert.getBytes(val)); } } - // if(DEBUG) - // cout << out << "\n===========================================\n"; return out; } - public void annTrain(Mat TrainData, Mat classes, int nNeruns) { + + /** + * + * @param TrainData 训练样本数据 + * @param classes 数据对应的标签 + * @param nNeruns + */ + public void annTrain(Mat trainingData, Mat classes, int nNeruns) { ann.clear(); Mat layers = new Mat(1, 3, CV_32SC1); - layers.ptr(0).put(Convert.getBytes(TrainData.cols())); + layers.ptr(0).put(Convert.getBytes(trainingData.cols())); layers.ptr(1).put(Convert.getBytes(nNeruns)); layers.ptr(2).put(Convert.getBytes(numAll)); - ann.create(layers, ANN_MLP.SIGMOID_SYM, 1, 1); // Prepare trainClases // Create a mat with n trained data by m classes Mat trainClasses = new Mat(); - trainClasses.create(TrainData.rows(), numAll, CV_32FC1); + trainClasses.create(trainingData.rows(), numAll, CV_32FC1); for (int i = 0; i < trainClasses.rows(); i++) { for (int k = 0; k < trainClasses.cols(); k++) { // If class of data i is same than a k class @@ -98,12 +103,29 @@ public class ANNTrain { trainClasses.ptr(i, k).put(Convert.getBytes(0f)); } } - Mat weights = new Mat(1, TrainData.rows(), CV_32FC1, Scalar.all(1)); + Mat weights = new Mat(1, trainingData.rows(), CV_32FC1, Scalar.all(1)); // Learn classifier - ann.train(TrainData, trainClasses, weights); + + TrainData train_data = TrainData.create(trainingData, ROW_SAMPLE, trainClasses); + + /* + ann_->setTrainMethod(cv::ml::ANN_MLP::TrainingMethods::BACKPROP); + ann_->setTermCriteria(cvTermCriteria(CV_TERMCRIT_ITER, 30000, 0.0001)); + ann_->setBackpropWeightScale(0.1); + ann_->setBackpropMomentumScale(0.1);*/ + + ann.setLayerSizes(layers); + ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1); + ann.setTrainMethod(ANN_MLP.BACKPROP); + TermCriteria criteria = new TermCriteria(TermCriteria.MAX_ITER, 30000, 0.0001); + ann.setTermCriteria(criteria); + ann.setBackpropWeightScale(0.1); + ann.setBackpropMomentumScale(0.1); + ann.train(train_data); } public int saveTrainData() { + System.out.println("Begin saveTrainData"); Mat classes = new Mat(); Mat trainingDataf5 = new Mat(); @@ -122,8 +144,8 @@ public class ANNTrain { int size = (int) files.size(); for (int j = 0; j < size; j++) { - System.out.println(files.get(j)); - Mat img = imread(files.get(j), 0); + // System.out.println(files.get(j)); + Mat img = opencv_imgcodecs.imread(files.get(j), 0); Mat f5 = features(img, 5); Mat f10 = features(img, 10); Mat f15 = features(img, 15); @@ -143,12 +165,12 @@ public class ANNTrain { System.out.println("Character: " + strChinese[i]); String str = path + '/' + strChinese[i]; Vector files = new Vector(); - Util.getFiles(str, files); + FileUtil.getFiles(str, files); int size = (int) files.size(); for (int j = 0; j < size; j++) { - System.out.println(files.get(j)); - Mat img = imread(files.get(j), 0); + // System.out.println(files.get(j)); + Mat img = opencv_imgcodecs.imread(files.get(j), 0); Mat f5 = features(img, 5); Mat f10 = features(img, 10); Mat f15 = features(img, 15); @@ -172,11 +194,11 @@ public class ANNTrain { new Mat(labels).copyTo(classes); FileStorage fs = new FileStorage("res/train/ann_data.xml", FileStorage.WRITE); - fs.writeObj("TrainingDataF5", trainingDataf5.data()); - fs.writeObj("TrainingDataF10", trainingDataf10.data()); - fs.writeObj("TrainingDataF15", trainingDataf15.data()); - fs.writeObj("TrainingDataF20", trainingDataf20.data()); - fs.writeObj("classes", classes.data()); + fs.write("TrainingDataF5", trainingDataf5); + fs.write("TrainingDataF10", trainingDataf10); + fs.write("TrainingDataF15", trainingDataf15); + fs.write("TrainingDataF20", trainingDataf20); + fs.write("classes", classes); fs.release(); System.out.println("End saveTrainData"); @@ -184,32 +206,21 @@ public class ANNTrain { } public void saveModel(int _predictsize, int _neurons) { + + // 样本文件数据已经保存到xml FileStorage fs = new FileStorage("res/train/ann_data.xml", FileStorage.READ); String training = "TrainingDataF" + _predictsize; - Mat TrainingData = new Mat(fs.get(training).readObj()); + Mat TrainingData = new Mat(fs.get(training)); Mat Classes = new Mat(fs.get("classes")); - // train the Ann - System.out.println("Begin to saveModelChar predictSize:" + Integer.valueOf(_predictsize).toString()); - System.out.println(" neurons:" + Integer.valueOf(_neurons).toString()); - long start = getTickCount(); annTrain(TrainingData, Classes, _neurons); long end = getTickCount(); - System.out.println("GetTickCount:" + Long.valueOf((end - start) / 1000).toString()); - - System.out.println("End the saveModelChar"); - + System.out.println("完成耗时: " + Long.valueOf((end - start) / 1000).toString()); + String model_name = "res/train/ann.xml"; - - // if(1) - // { - // String str = - // String.format("ann_prd:%d\tneu:%d",_predictsize,_neurons); - // model_name = str; - // } - - CvFileStorage fsto = CvFileStorage.open(model_name, CvMemStorage.create(), CV_STORAGE_WRITE); + + FileStorage fsto = new FileStorage(model_name, FileStorage.WRITE); ann.write(fsto, "ann"); } @@ -219,15 +230,13 @@ public class ANNTrain { saveTrainData(); // 可根据需要训练不同的predictSize或者neurons的ANN模型 - // for (int i = 2; i <= 2; i ++) - // { - // int size = i * 5; - // for (int j = 5; j <= 10; j++) - // { - // int neurons = j * 10; - // saveModel(size, neurons); - // } - // } + /*for (int i = 2; i <= 2; i++) { + int size = i * 5; + for (int j = 5; j <= 10; j++) { + int neurons = j * 10; + saveModel(size, neurons); + } + }*/ // 这里演示只训练model文件夹下的ann.xml,此模型是一个predictSize=10,neurons=40的ANN模型。 // 根据机器的不同,训练时间不一样,但一般需要10分钟左右,所以慢慢等一会吧。 @@ -235,5 +244,5 @@ public class ANNTrain { System.out.println("To be end."); return 0; - }*/ -} + } +} \ No newline at end of file diff --git a/src/main/java/com/yuxue/train/SVMTrain.java b/src/main/java/com/yuxue/train/SVMTrain.java index d746de80..5c6ac92d 100644 --- a/src/main/java/com/yuxue/train/SVMTrain.java +++ b/src/main/java/com/yuxue/train/SVMTrain.java @@ -16,7 +16,7 @@ import com.yuxue.util.FileUtil; /** * * @author yuxue - * @date 2020-05-14 11:37 + * @date 2020-05-14 22:16 */ public class SVMTrain {