ANN training and validation passed

devA
yuxue 5 years ago
parent e47ec20c68
commit 181e99a213

@@ -1,7 +1,8 @@
package com.yuxue.train;
import java.util.Vector;
import static org.bytedeco.javacpp.opencv_core.CV_32FC1;
import java.util.Vector;
import org.opencv.core.Core;
import org.opencv.core.CvType;
@@ -14,7 +15,9 @@ import org.opencv.ml.ANN_MLP;
import org.opencv.ml.Ml;
import org.opencv.ml.TrainData;
import com.yuxue.constant.Constant;
import com.yuxue.enumtype.Direction;
import com.yuxue.util.Convert;
import com.yuxue.util.FileUtil;
@@ -35,63 +38,19 @@ import com.yuxue.util.FileUtil;
public class ANNTrain {
private ANN_MLP ann = ANN_MLP.create();
static {
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
}
// Chinese license plates: 34 characters; the letters I and O are not used
private final char strCharacters[] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', /* no I */ 'J', 'K', 'L', 'M', 'N', /* no O */ 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' };
private final int numCharacter = strCharacters.length;
// Not exhaustive: some provinces have no training data, so their characters are missing
// Entries with a digit appended to the name are a commonly seen variant form of the character, also stored as training data
private final String strChinese[] = {
"zh_cuan", /*川*/
"zh_e", /*鄂*/
"zh_gan", /*赣*/
"zh_gan1", /*甘*/
"zh_gui", /*贵*/
"zh_gui1", /*桂*/
"zh_hei", /*黑*/
"zh_hu", /*沪*/
"zh_ji", /*冀*/
"zh_jin", /*津*/
"zh_jing", /*京*/
"zh_jl", /*吉*/
"zh_liao", /*辽*/
"zh_lu", /*鲁*/
"zh_meng", /*蒙*/
"zh_min", /*闽*/
"zh_ning", /*宁*/
"zh_qing", /*青*/
"zh_qiong", /*琼*/
"zh_shan", /*陕*/
"zh_su", /*苏*/
"zh_sx", /*晋*/
"zh_wan", /*皖*/
"zh_xiang", /*湘*/
"zh_xin", /*新*/
"zh_yu", /*豫*/
"zh_yu1", /*渝*/
"zh_yue", /*粤*/
"zh_yun", /*云*/
"zh_zang", /*藏*/
"zh_zhe" /*浙*/
};
private final int numAll = strCharacters.length + strChinese.length;
private final int numCharacter = Constant.strCharacters.length;
// Default root directory for training
private static final String DEFAULT_PATH = "D:/PlateDetect/train/chars_recognise_ann/";
// Location where the trained model file is saved
// private static final String DATA_PATH = DEFAULT_PATH + "ann_data.xml";
private static final String MODEL_PATH = DEFAULT_PATH + "ann.xml";
public static float[] projectedHistogram(final Mat img, Direction direction) {
int sz = 0;
switch (direction) {
@@ -115,25 +74,22 @@ public class ANNTrain {
int count = Core.countNonZero(data);
nonZeroMat[j] = count;
}
// Normalize histogram
float max = 0;
for (int j = 0; j < nonZeroMat.length; ++j) {
max = Math.max(max, nonZeroMat[j]);
}
if (max > 0) {
for (int j = 0; j < nonZeroMat.length; ++j) {
nonZeroMat[j] /= max;
}
}
return nonZeroMat;
}
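// For reference: a minimal standalone sketch of the projected-histogram idea implemented above,
// assuming a binary (thresholded) CV_8UC1 input. It counts the non-zero pixels in each column
// (vertical) or row (horizontal) and normalizes the bins by their maximum, mirroring what
// projectedHistogram() does. The method name is illustrative only, not part of the original code.
private static float[] projectedHistogramSketch(Mat bin, boolean vertical) {
    int sz = vertical ? bin.cols() : bin.rows();
    float[] hist = new float[sz];
    for (int j = 0; j < sz; j++) {
        Mat line = vertical ? bin.col(j) : bin.row(j);
        hist[j] = Core.countNonZero(line); // foreground pixels along this column/row
    }
    float max = 0;
    for (float v : hist) {
        max = Math.max(max, v);
    }
    if (max > 0) {
        for (int j = 0; j < sz; j++) {
            hist[j] /= max; // scale all bins into [0, 1]
        }
    }
    return hist;
}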
public Mat features(Mat in, int sizeData) {
float[] vhist = projectedHistogram(in, Direction.VERTICAL);
float[] hhist = projectedHistogram(in, Direction.HORIZONTAL);
@@ -152,7 +108,7 @@ public class ANNTrain {
for (int i = 0; i < hhist.length; ++i, ++j) {
out.put(0, j, hhist[i]);
}
for (int x = 0; x < lowData.cols(); x++) {
for (int y = 0; y < lowData.rows(); y++, ++j) {
// float val = lowData.ptr(x, y).get(0) & 0xFF;
@@ -164,34 +120,11 @@ public class ANNTrain {
}
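// For reference: the row produced by features() is laid out as
// [ vertical histogram | horizontal histogram | flattened sizeData x sizeData image ],
// which is what the "440" comment further down refers to. A tiny illustrative helper
// (name and signature are assumptions, not part of the original code):
private static int featureLength(Mat in, int sizeData) {
    // one bin per column of in, one bin per row of in, plus the resized low-resolution pixels
    return in.cols() + in.rows() + sizeData * sizeData; // a 20x20 input with sizeData=20 gives 440
}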
public void train(int _predictsize, int _neurons) {
// Read the sample file data
/*FileStorage fs = new FileStorage(DATA_PATH, FileStorage.READ);
Mat samples = new Mat(fs.get("TrainingDataF" + _predictsize));
Mat classes = new Mat(fs.get("classes"));
Mat trainClasses = new Mat(samples.rows(), numAll, CV_32F);
for (int i = 0; i < trainClasses.rows(); i++) {
for (int k = 0; k < trainClasses.cols(); k++) {
// If class of data i is same than a k class
if (k == Convert.toInt(classes.ptr(i))) {
trainClasses.ptr(i, k).put(Convert.getBytes(1f));
} else {
trainClasses.ptr(i, k).put(Convert.getBytes(0f));
}
}
}
samples.convertTo(samples, CV_32F);
System.out.println(samples.type());*/
Mat samples = new Mat(); // when using push_back, the row/column counts must not be preset
Vector<Integer> trainingLabels = new Vector<Integer>();
// Load the digit and letter characters
for (int i = 0; i < numCharacter; i++) {
String str = DEFAULT_PATH + strCharacters[i];
String str = DEFAULT_PATH + "learn/" + Constant.strCharacters[i];
Vector<String> files = new Vector<String>();
FileUtil.getFiles(str, files);
@@ -204,10 +137,10 @@ public class ANNTrain {
trainingLabels.add(i); // class index corresponding to each character image
}
}
// Load the Chinese (province) characters
for (int i = 0; i < strChinese.length; i++) {
String str = DEFAULT_PATH + strChinese[i];
for (int i = 0; i < Constant.strChinese.length; i++) {
String str = DEFAULT_PATH + "learn/" + Constant.strChinese[i];
Vector<String> files = new Vector<String>();
FileUtil.getFiles(str, files);
@@ -220,40 +153,29 @@ public class ANNTrain {
trainingLabels.add(i + numCharacter);
}
}
//440 vhist.length + hhist.length + lowData.cols() * lowData.rows();
// CV_32FC1 CV_32SC1 CV_32F
// samples.convertTo(samples, CvType.CV_32F);
Mat classes = new Mat(trainingLabels.size(), 65, CvType.CV_32F);
float[] labels = new float[trainingLabels.size()];
for (int i = 0; i < labels.length; ++i) {
labels[i] = trainingLabels.get(i).intValue();
// labels[i] = trainingLabels.get(i).intValue();
classes.put(i, trainingLabels.get(i), 1.f);
}
Mat classes = new Mat(labels.length, 440, CvType.CV_32F);
classes.put(0, 0, labels);
System.out.println(samples.rows());
System.out.println(samples.cols());
System.out.println(samples.type());
System.out.println(classes.rows());
System.out.println(classes.cols());
System.out.println(classes.type());
// classes.put(0, 0, labels);
// samples.type() == CV_32F || samples.type() == CV_32S
TrainData train_data = TrainData.create(samples, Ml.ROW_SAMPLE, classes);
// l_count is the dimension of the _layer_sizes vector, i.e. the number of MLP layers L
// l_count = _layer_sizes->rows + _layer_sizes->cols - 1;
ann.clear();
Mat layers = new Mat(1, 3, CvType.CV_32F);
layers.put(0, 0, samples.cols());
layers.put(0, 1, _neurons);
layers.put(0, 2, classes.cols());
/*layers.ptr(0,0).put(Convert.getBytes(samples.cols())); //440 vhist.length + hhist.length + lowData.cols() * lowData.rows();
layers.ptr(0,1).put(Convert.getBytes(_predictsize));
layers.ptr(0,2).put(Convert.getBytes(numAll));*/
ann.setLayerSizes(layers);
ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1);
ann.setTrainMethod(ANN_MLP.BACKPROP);
@@ -262,28 +184,47 @@ public class ANNTrain {
ann.setBackpropWeightScale(0.1);
ann.setBackpropMomentumScale(0.1);
ann.train(train_data);
System.err.println("完成 ");
// FileStorage fsto = new FileStorage(MODEL_PATH, FileStorage.WRITE);
// ann.write(fsto, "ann");
ann.save(MODEL_PATH);
}
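// For reference: a condensed sketch of how train() assembles its training data, assuming
// features() returns a 1 x N CV_32F row. Each character image becomes one row appended to
// `samples` via Mat.push_back (which is why `samples` is created without preset dimensions),
// and `classes` is a one-hot matrix with numAll columns holding 1.0 in each sample's class
// column. Method and parameter names here are illustrative only, not part of the original code.
private TrainData buildTrainDataSketch(Vector<String> files, Vector<Integer> labels, int numAll) {
    Mat samples = new Mat();
    for (int i = 0; i < files.size(); i++) {
        Mat img = Imgcodecs.imread(files.get(i), 0);             // grayscale character image
        samples.push_back(features(img, Constant.predictSize));  // append one feature row
    }
    Mat classes = Mat.zeros(labels.size(), numAll, CvType.CV_32F);
    for (int i = 0; i < labels.size(); i++) {
        classes.put(i, labels.get(i), 1.f);                      // one-hot target for sample i
    }
    return TrainData.create(samples, Ml.ROW_SAMPLE, classes);
}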
public void predict() {
ann.clear();
ann = ANN_MLP.load(MODEL_PATH);
Vector<String> files = new Vector<String>();
FileUtil.getFiles(DEFAULT_PATH + "test/", files);
int i = 0;
for (String string : files) {
Mat img = Imgcodecs.imread(string, 0);
Mat f = features(img, Constant.predictSize);
Mat output = new Mat(1, 140, CvType.CV_32F);
//ann.predict(f, output, 0); // prediction result
System.err.println(string + "===>" + (int) ann.predict(f, output, 0));
// ann.predict(f, output, 0);
// System.err.println(string + "===>" + output.get(0, 0)[0]);
}
}
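// For reference: for ANN_MLP, predict() fills `output` with one response per output neuron and
// returns the index of the strongest response, which is why the cast above yields the class
// index. An equivalent way to read it back from the output Mat (illustrative sketch only):
private static int maxResponseIndex(Mat output) {
    Core.MinMaxLocResult mmr = Core.minMaxLoc(output); // locate the largest output neuron
    return (int) mmr.maxLoc.x;                         // its column is the predicted class index
}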
public static void main(String[] args) {
ANNTrain annT = new ANNTrain();
// This demo only trains the ann.xml under the model folder; it is an ANN model with predictSize=10 and neurons=40
// Train ANN models with different predictSize or neurons values as needed
int _predictsize = 20;
int _neurons = 40;
// annT.saveTrainData(_predictsize);
// This demo only trains the ann.xml under the model folder; it is an ANN model with predictSize=10 and neurons=40.
// Training time varies by machine, but it usually takes about 10 minutes, so be patient.
annT.train(_predictsize, _neurons);
annT.train(Constant.predictSize, Constant.neurons);
System.out.println("To be end.");
}
annT.predict();
System.out.println("The end.");
}
}