添加ann训练

devA
yuxue 5 years ago
parent 591dc5e9a1
commit 2e840a3663

@ -1,20 +1,20 @@
package com.yuxue.train;
import static org.bytedeco.javacpp.opencv_core.CV_32F;
import static org.bytedeco.javacpp.opencv_ml.ROW_SAMPLE;
import java.util.Vector;
import org.bytedeco.javacpp.opencv_core.FileStorage;
import org.bytedeco.javacpp.opencv_core.Mat;
import org.bytedeco.javacpp.opencv_core.TermCriteria;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_imgcodecs;
import org.bytedeco.javacpp.opencv_ml.ANN_MLP;
import org.bytedeco.javacpp.opencv_ml.TrainData;
import com.yuxue.easypr.core.CoreFunc;
import com.yuxue.util.Convert;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.core.TermCriteria;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import org.opencv.ml.ANN_MLP;
import org.opencv.ml.Ml;
import org.opencv.ml.TrainData;
import com.yuxue.enumtype.Direction;
import com.yuxue.util.FileUtil;
@ -36,6 +36,9 @@ public class ANNTrain {
private ANN_MLP ann = ANN_MLP.create();
static {
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
}
// 中国车牌; 34个字符; 没有 字母I、字母O
private final char strCharacters[] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
@ -85,30 +88,119 @@ public class ANNTrain {
private static final String DEFAULT_PATH = "D:/PlateDetect/train/chars_recognise_ann/";
// 训练模型文件保存位置
private static final String DATA_PATH = DEFAULT_PATH + "ann_data.xml";
// private static final String DATA_PATH = DEFAULT_PATH + "ann_data.xml";
private static final String MODEL_PATH = DEFAULT_PATH + "ann.xml";
public static float[] projectedHistogram(final Mat img, Direction direction) {
int sz = 0;
switch (direction) {
case HORIZONTAL:
sz = img.rows();
break;
case VERTICAL:
sz = img.cols();
break;
default:
break;
}
// 统计这一行或一列中非零元素的个数并保存到nonZeroMat中
float[] nonZeroMat = new float[sz];
Core.extractChannel(img, img, 0);
for (int j = 0; j < sz; j++) {
Mat data = (direction == Direction.HORIZONTAL) ? img.row(j) : img.col(j);
int count = Core.countNonZero(data);
nonZeroMat[j] = count;
}
// Normalize histogram
float max = 0;
for (int j = 0; j < nonZeroMat.length; ++j) {
max = Math.max(max, nonZeroMat[j]);
}
if (max > 0) {
for (int j = 0; j < nonZeroMat.length; ++j) {
nonZeroMat[j] /= max;
}
}
return nonZeroMat;
}
public int saveTrainData(int _predictsize) {
public Mat features(Mat in, int sizeData) {
float[] vhist = projectedHistogram(in, Direction.VERTICAL);
float[] hhist = projectedHistogram(in, Direction.HORIZONTAL);
Mat lowData = new Mat();
if (sizeData > 0) {
Imgproc.resize(in, lowData, new Size(sizeData, sizeData));
}
Mat classes = new Mat();
Mat trainingDataf = new Mat();
int numCols = vhist.length + hhist.length + lowData.cols() * lowData.rows();
Mat out = new Mat(1, numCols, CvType.CV_32F);
Vector<Integer> trainingLabels = new Vector<Integer>();
int j = 0;
for (int i = 0; i < vhist.length; ++i, ++j) {
out.put(0, j, vhist[i]);
}
for (int i = 0; i < hhist.length; ++i, ++j) {
out.put(0, j, hhist[i]);
}
for (int x = 0; x < lowData.cols(); x++) {
for (int y = 0; y < lowData.rows(); y++, ++j) {
// float val = lowData.ptr(x, y).get(0) & 0xFF;
double[] val = lowData.get(x, y);
out.put(0, j, val[0]);
}
}
return out;
}
public void train(int _predictsize, int _neurons) {
// 读取样本文件数据
/*FileStorage fs = new FileStorage(DATA_PATH, FileStorage.READ);
Mat samples = new Mat(fs.get("TrainingDataF" + _predictsize));
Mat classes = new Mat(fs.get("classes"));
Mat trainClasses = new Mat(samples.rows(), numAll, CV_32F);
for (int i = 0; i < trainClasses.rows(); i++) {
for (int k = 0; k < trainClasses.cols(); k++) {
// If class of data i is same than a k class
if (k == Convert.toInt(classes.ptr(i))) {
trainClasses.ptr(i, k).put(Convert.getBytes(1f));
} else {
trainClasses.ptr(i, k).put(Convert.getBytes(0f));
}
}
}
samples.convertTo(samples, CV_32F);
System.out.println(samples.type());*/
Mat samples = new Mat(); // 使用push_back行数列数不能赋初始值
Vector<Integer> trainingLabels = new Vector<Integer>();
// 加载数字及字母字符
for (int i = 0; i < numCharacter; i++) {
String str = DEFAULT_PATH + strCharacters[i];
System.err.println(str);
Vector<String> files = new Vector<String>();
FileUtil.getFiles(str, files);
int size = (int) files.size();
for (int j = 0; j < size; j++) {
Mat img = opencv_imgcodecs.imread(files.get(j), 0);
Mat img = Imgcodecs.imread(files.get(j), 0);
// System.err.println(files.get(j)); // 文件名不能包含中文
Mat f = CoreFunc.features(img, _predictsize);
trainingDataf.push_back(f);
Mat f = features(img, _predictsize);
samples.push_back(f);
trainingLabels.add(i); // 每一幅字符图片所对应的字符类别索引下标
}
}
@ -121,84 +213,67 @@ public class ANNTrain {
int size = (int) files.size();
for (int j = 0; j < size; j++) {
Mat img = opencv_imgcodecs.imread(files.get(j), 0);
Mat img = Imgcodecs.imread(files.get(j), 0);
// System.err.println(files.get(j)); // 文件名不能包含中文
Mat f = CoreFunc.features(img, _predictsize);
trainingDataf.push_back(f);
Mat f = features(img, _predictsize);
samples.push_back(f);
trainingLabels.add(i + numCharacter);
}
}
// CV_32FC1 CV_32SC1 CV_32F
trainingDataf.convertTo(trainingDataf, opencv_core.CV_32F);
// samples.convertTo(samples, CvType.CV_32F);
int[] labels = new int[trainingLabels.size()];
float[] labels = new float[trainingLabels.size()];
for (int i = 0; i < labels.length; ++i) {
labels[i] = trainingLabels.get(i).intValue();
}
new Mat(labels).copyTo(classes);
FileStorage fs = new FileStorage(DATA_PATH, FileStorage.WRITE);
fs.write("TrainingDataF" + _predictsize, trainingDataf);
fs.write("classes", classes);
fs.release();
Mat classes = new Mat(labels.length, 440, CvType.CV_32F);
classes.put(0, 0, labels);
System.out.println("End saveTrainData");
return 0;
}
public void train(int _predictsize, int _neurons) {
// 读取样本文件数据
FileStorage fs = new FileStorage(DATA_PATH, FileStorage.READ);
Mat samples = new Mat(fs.get("TrainingDataF" + _predictsize));
Mat classes = new Mat(fs.get("classes"));
Mat trainClasses = new Mat(samples.rows(), numAll, CV_32F);
for (int i = 0; i < trainClasses.rows(); i++) {
for (int k = 0; k < trainClasses.cols(); k++) {
// If class of data i is same than a k class
if (k == Convert.toInt(classes.ptr(i))) {
trainClasses.ptr(i, k).put(Convert.getBytes(1f));
} else {
trainClasses.ptr(i, k).put(Convert.getBytes(0f));
}
}
}
System.out.println(samples.rows());
System.out.println(samples.cols());
System.out.println(samples.type());
samples.convertTo(samples, CV_32F);
System.out.println(classes.rows());
System.out.println(classes.cols());
System.out.println(classes.type());
System.out.println(samples.type());
// samples.type() == CV_32F || samples.type() == CV_32S
TrainData train_data = TrainData.create(samples, ROW_SAMPLE, trainClasses);
TrainData train_data = TrainData.create(samples, Ml.ROW_SAMPLE, classes);
// //l_count为相量_layer_sizes的维数即MLP的层数L
// l_count = _layer_sizes->rows + _layer_sizes->cols - 1;
ann.clear();
Mat layers = new Mat(1, 3, CV_32F);
layers.ptr(0).put(Convert.getBytes(samples.cols()));
layers.ptr(1).put(Convert.getBytes(_neurons));
layers.ptr(2).put(Convert.getBytes(numAll));
Mat layers = new Mat(1, 3, CvType.CV_32F);
layers.put(0, 0, samples.cols());
layers.put(0, 1, _neurons);
layers.put(0, 2, classes.cols());
/*layers.ptr(0,0).put(Convert.getBytes(samples.cols())); //440 vhist.length + hhist.length + lowData.cols() * lowData.rows();
layers.ptr(0,1).put(Convert.getBytes(_predictsize));
layers.ptr(0,2).put(Convert.getBytes(numAll));*/
ann.setLayerSizes(layers);
ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1);
ann.setTrainMethod(ANN_MLP.BACKPROP);
TermCriteria criteria = new TermCriteria(TermCriteria.MAX_ITER, 30000, 0.0001);
TermCriteria criteria = new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 30000, 0.0001);
ann.setTermCriteria(criteria);
ann.setBackpropWeightScale(0.1);
ann.setBackpropMomentumScale(0.1);
ann.train(train_data);
System.err.println("完成 ");
FileStorage fsto = new FileStorage(MODEL_PATH, FileStorage.WRITE);
ann.write(fsto, "ann");
// FileStorage fsto = new FileStorage(MODEL_PATH, FileStorage.WRITE);
// ann.write(fsto, "ann");
ann.save(MODEL_PATH);
}
public static void main(String[] args) {
ANNTrain annT = new ANNTrain();
// 可根据需要训练不同的predictSize或者neurons的ANN模型
int _predictsize = 10;
int _predictsize = 20;
int _neurons = 40;
// annT.saveTrainData(_predictsize);

Loading…
Cancel
Save