You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

159 lines
5.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.yuxue.train;
import java.util.Vector;
import static org.bytedeco.javacpp.opencv_core.*;
import static org.bytedeco.javacpp.opencv_ml.*;
import org.bytedeco.javacpp.opencv_imgcodecs;
import org.bytedeco.javacpp.opencv_core.Mat;
import com.yuxue.constant.Constant;
import com.yuxue.easypr.core.CoreFunc;
import com.yuxue.util.FileUtil;
/**
 * ANN (multi-layer perceptron) training for plate character recognition,
 * implemented on top of the org.bytedeco.javacpp OpenCV bindings.
 *
 * The trained model file is used to recognise the characters cut out of plate images.
 *
 * To use the trained ann.xml:
 * 1. replace the res/model/ann.xml file;
 * 2. adjust com.yuxue.easypr.core.CharsIdentify.charsIdentify(Mat, Boolean, Boolean) accordingly.
 *
 * @author yuxue
 * @date 2020-05-14 22:16
 */
public class ANNTrain1 {

    private ANN_MLP ann = ANN_MLP.create();

    /** Root directory of the training images ("learn/" sub-folders) and test images ("test/"). */
    private static final String DEFAULT_PATH = "D:/PlateDetect/train/chars_recognise_ann/";

    /** Location the trained model is saved to and loaded from. */
    private static final String MODEL_PATH = "res/model/ann.xml";

    /**
     * Trains the ANN on every image under {@code DEFAULT_PATH/learn/<class>/} and saves the
     * model to {@link #MODEL_PATH}.
     *
     * Labels 0..numCharacter-1 are the entries of {@code Constant.strCharacters}; the
     * entries of {@code Constant.strChinese} follow, offset by {@code Constant.numCharacter}.
     *
     * @param _predictsize feature size passed to {@code CoreFunc.features}; the same value
     *                     must be used at prediction time
     * @param _neurons     number of neurons in the single hidden layer
     * @throws IllegalStateException if no readable training image was found
     */
    public void train(int _predictsize, int _neurons) {
        // Rows are appended with push_back, so the Mat must start without rows/cols.
        Mat samples = new Mat();
        Vector<Integer> trainingLabels = new Vector<Integer>();

        // Digits and letters.
        for (int i = 0; i < Constant.numCharacter; i++) {
            loadSamples(DEFAULT_PATH + "learn/" + Constant.strCharacters[i], i,
                    _predictsize, samples, trainingLabels);
        }
        // Chinese characters: their labels follow the digits/letters.
        for (int i = 0; i < Constant.strChinese.length; i++) {
            loadSamples(DEFAULT_PATH + "learn/" + Constant.strChinese[i], i + Constant.numCharacter,
                    _predictsize, samples, trainingLabels);
        }
        if (samples.rows() == 0) {
            throw new IllegalStateException("No readable training images found under "
                    + DEFAULT_PATH + "learn/");
        }

        // One-hot response matrix (one row per sample, one column per class).
        // It must be zero-initialised: new Mat(rows, cols, type) leaves its memory uninitialised,
        // which would leave garbage in every non-target column.
        int sampleCount = trainingLabels.size();
        Mat classes = Mat.zeros(sampleCount, Constant.numAll, CV_32F).asMat();
        for (int row = 0; row < sampleCount; row++) {
            classes.ptr(row, trainingLabels.get(row)).putFloat(1.f);
        }

        // TrainData requires samples of type CV_32F or CV_32S.
        TrainData trainData = TrainData.create(samples, ROW_SAMPLE, classes);

        ann.clear();
        Mat layers = new Mat(1, 3, CV_32SC1);
        layers.ptr(0, 0).putInt(samples.cols());   // input layer: feature vector length
        layers.ptr(0, 1).putInt(_neurons);         // hidden layer
        layers.ptr(0, 2).putInt(classes.cols());   // output layer: one neuron per class
        System.out.println("ANN layers: " + samples.cols() + " - " + _neurons + " - " + classes.cols());
        ann.setLayerSizes(layers);
        ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1);
        ann.setTrainMethod(ANN_MLP.BACKPROP);
        ann.setTermCriteria(new TermCriteria(TermCriteria.EPS | TermCriteria.MAX_ITER, 30000, 0.0001));
        ann.setBackpropWeightScale(0.1);
        ann.setBackpropMomentumScale(0.1);
        ann.train(trainData);
        ann.save(MODEL_PATH);
    }

    /**
     * Extracts the features of every readable image in {@code dir}. For each image it appends
     * one row to {@code samples} and one {@code label} entry to {@code labels}.
     */
    private static void loadSamples(String dir, int label, int predictSize,
                                    Mat samples, Vector<Integer> labels) {
        Vector<String> files = new Vector<String>();
        FileUtil.getFiles(dir, files);
        for (String file : files) {
            // Flag 0 loads the image as single-channel grayscale.
            // Per the original note, file names must not contain Chinese characters.
            Mat img = opencv_imgcodecs.imread(file, 0);
            if (img.empty()) {
                System.err.println("Skipping unreadable image: " + file);
                continue;
            }
            samples.push_back(CoreFunc.features(img, predictSize));
            labels.add(label);
        }
    }

    /**
     * Loads the model from {@link #MODEL_PATH} and prints the prediction for every image in
     * {@code DEFAULT_PATH/test/}, using {@code Constant.predictSize} as the feature size.
     */
    public void predict() {
        predict(Constant.predictSize);
    }

    /**
     * Loads the model from {@link #MODEL_PATH} and prints the prediction for every image in
     * {@code DEFAULT_PATH/test/}.
     *
     * @param _predictsize feature size; must equal the value the model was trained with
     */
    public void predict(int _predictsize) {
        ann.clear();
        ann = ANN_MLP.load(MODEL_PATH);
        Vector<String> files = new Vector<String>();
        FileUtil.getFiles(DEFAULT_PATH + "test/", files);
        for (String file : files) {
            // Grayscale, exactly like train(): the model was trained on single-channel features.
            Mat img = opencv_imgcodecs.imread(file, 0);
            if (img.empty()) {
                System.err.println("Skipping unreadable image: " + file);
                continue;
            }
            Mat f = CoreFunc.features(img, _predictsize);
            // Left empty: predict() allocates it with one column per output neuron.
            Mat output = new Mat();
            int index = (int) ann.predict(f, output, 0);
            String result;
            if (index < Constant.numCharacter) {
                result = String.valueOf(Constant.strCharacters[index]);
            } else {
                String key = Constant.strChinese[index - Constant.numCharacter];
                result = Constant.KEY_CHINESE_MAP.get(key); // map the encoded key to the Chinese character
            }
            System.err.println(file + "===>" + result);
        }
    }

    public static void main(String[] args) {
        ANNTrain1 annT = new ANNTrain1();
        // Per the original note, the bundled ann.xml is an ANN with predictSize=10 and neurons=40.
        // Models with a different predictSize or neuron count can be trained as needed.
        // Training time depends on the machine; the original note estimates about 10 minutes.
        annT.train(Constant.predictSize, Constant.neurons);
        annT.predict();
        System.out.println("The end.");
    }
}