优化ann训练算法

devA
yuxue 5 years ago
parent f0634d82f8
commit 9b272d5e94

@ -1,10 +1,12 @@
package com.yuxue.train;
import java.util.Random;
import java.util.Vector;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Size;
import org.opencv.core.TermCriteria;
import org.opencv.imgcodecs.Imgcodecs;
@ -85,6 +87,7 @@ public class ANNTrain {
public Mat features(Mat in, int sizeData) {
float[] vhist = projectedHistogram(in, Direction.VERTICAL);
float[] hhist = projectedHistogram(in, Direction.HORIZONTAL);
@ -106,7 +109,6 @@ public class ANNTrain {
for (int x = 0; x < lowData.cols(); x++) {
for (int y = 0; y < lowData.rows(); y++, ++j) {
// float val = lowData.ptr(x, y).get(0) & 0xFF;
double[] val = lowData.get(x, y);
out.put(0, j, val[0]);
}
@ -114,6 +116,64 @@ public class ANNTrain {
return out;
}
/**
*
* @param inMat
* @return
*/
public Mat getSyntheticImage(Mat inMat) {
Random rand = new Random();
int rand_type = rand.nextInt(10000);
Mat result = inMat.clone();
// if (rand_type % 2 == 0) {
int ran_x = rand.nextInt(10000) % 5 - 2; // 控制在0-3个像素范围内
int ran_y = rand.nextInt(10000) % 5 - 2;
result = translateImg(result, ran_x, ran_y); // 平移
/*} else if (rand_type % 2 != 0) {
float angle = (float) (rand.nextInt(10000) % 15 - 7); // 旋转角度控制在0-7°范围内
result = rotateImg(result, angle); // 旋转
}*/
return result;
}
/**
*
* @param img
* @param offsetx
* @param offsety
* @return
*/
public Mat translateImg(Mat img, int offsetx, int offsety){
Mat dst = new Mat();
//定义平移矩阵
Mat trans_mat = Mat.zeros(2, 3, CvType.CV_32FC1);
trans_mat.put(0, 0, 1);
trans_mat.put(0, 2, offsetx);
trans_mat.put(1, 1, 1);
trans_mat.put(1, 2, offsety);
Imgproc.warpAffine(img, dst, trans_mat, img.size()); // 仿射变换
return dst;
}
/**
*
* @param source
* @param angle
* @return
*/
public Mat rotateImg(Mat source, float angle){
Point src_center = new Point(source.cols() / 2.0F, source.rows() / 2.0F);
Mat rot_mat = Imgproc.getRotationMatrix2D(src_center, angle, 1);
Mat dst = new Mat();
// 仿射变换 可以考虑使用投影变换; 这里使用放射变换进行旋转,对于实际效果来说感觉意义不大,反而会干扰结果预测
Imgproc.warpAffine(source, dst, rot_mat, source.size());
return dst;
}
public void train(int _predictsize, int _neurons) {
Mat samples = new Mat(); // 使用push_back行数列数不能赋初始值
Vector<Integer> trainingLabels = new Vector<Integer>();
@ -122,12 +182,27 @@ public class ANNTrain {
String str = DEFAULT_PATH + "learn/" + Constant.strCharacters[i];
Vector<String> files = new Vector<String>();
FileUtil.getFiles(str, files); // 文件名不能包含中文
int count = 200; // 控制每个字符最多只允许有200个样本文件
int k = 0;
// System.out.println("数字+字母:\t" + files.size());
for (String filePath : files) {
Mat img = Imgcodecs.imread(filePath);
Mat img = Imgcodecs.imread(filePath, 0);
Mat f = features(img, _predictsize);
samples.push_back(f);
trainingLabels.add(i); // 每一幅字符图片所对应的字符类别索引下标
// 抽取1/3样本文件平移或者旋转变换后加入训练样本
if (k % 3 == 0) {
samples.push_back(features(getSyntheticImage(img), _predictsize));
trainingLabels.add(i); // 每一幅字符图片所对应的字符类别索引下标
}
k++;
if(count <= 0) {
break;
}
count--;
}
}
@ -136,17 +211,35 @@ public class ANNTrain {
String str = DEFAULT_PATH + "learn/" + Constant.strChinese[i];
Vector<String> files = new Vector<String>();
FileUtil.getFiles(str, files);
int count = 50; // 控制每个字符最多只允许有100个样本文件
int k = 0;
// System.out.println("汉字:\t" + files.size());
for (String filePath : files) {
Mat img = Imgcodecs.imread(filePath);
Mat img = Imgcodecs.imread(filePath, 0);
Mat f = features(img, _predictsize);
samples.push_back(f);
trainingLabels.add(i + Constant.numCharacter); // 每一幅字符图片所对应的字符类别索引下标
// 抽取1/3样本文件平移或者旋转变换后加入训练样本
if (k % 3 == 0) {
samples.push_back(features(getSyntheticImage(img), _predictsize));
trainingLabels.add(i); // 每一幅字符图片所对应的字符类别索引下标
}
k++;
if(count <= 0) {
break;
}
count--;
}
}
samples.convertTo(samples, CvType.CV_32F);
//440 vhist.length + hhist.length + lowData.cols() * lowData.rows();
// CV_32FC1 CV_32SC1 CV_32F
Mat classes = new Mat(trainingLabels.size(), Constant.numAll, CvType.CV_32F);
Mat classes = Mat.zeros(trainingLabels.size(), Constant.numAll, CvType.CV_32F);
float[] labels = new float[trainingLabels.size()];
for (int i = 0; i < labels.length; ++i) {
@ -158,11 +251,9 @@ public class ANNTrain {
ann.clear();
Mat layers = new Mat(1, 3, CvType.CV_32F);
layers.put(0, 0, samples.cols());
layers.put(0, 1, _neurons);
layers.put(0, 2, classes.cols());
// System.out.println(layers);
layers.put(0, 0, samples.cols()); // 样本数量
layers.put(0, 1, _neurons); //
layers.put(0, 2, classes.cols()); // 字符数
ann.setLayerSizes(layers);
ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1);
@ -190,7 +281,7 @@ public class ANNTrain {
Mat img = Imgcodecs.imread(string, 0);
Mat f = features(img, Constant.predictSize);
// 140 predictSize = 10; vhist.length + hhist.length + lowData.cols() * lowData.rows();
/*// 140 predictSize = 10; vhist.length + hhist.length + lowData.cols() * lowData.rows();
// 440 predictSize = 20;
Mat output = new Mat(1, 140, CvType.CV_32F);
//ann.predict(f, output, 0); // 预测结果
@ -201,6 +292,25 @@ public class ANNTrain {
int index = (int) ann.predict(f, output, 0);
System.err.println(string + "===>" + index);
if (index < Constant.numCharacter) {
plate += String.valueOf(Constant.strCharacters[index]);
} else {
String s = Constant.strChinese[index - Constant.numCharacter];
plate += Constant.KEY_CHINESE_MAP.get(s);
}*/
int index = 0;
double maxVal = -2;
Mat output = new Mat(1, Constant.numAll, CvType.CV_32F);
ann.predict(f, output); // 预测结果
for (int j = 0; j < Constant.numAll; j++) {
double val = output.get(0, j)[0];
if (val > maxVal) {
maxVal = val;
index = j;
}
}
System.err.println(index);
if (index < Constant.numCharacter) {
plate += String.valueOf(Constant.strCharacters[index]);
} else {
@ -209,6 +319,7 @@ public class ANNTrain {
}
}
System.err.println("===>" + plate);
return;
}
public static void main(String[] args) {
@ -217,11 +328,12 @@ public class ANNTrain {
// 这里演示只训练model文件夹下的ann.xml此模型是一个predictSize=10,neurons=40的ANN模型
// 可根据需要训练不同的predictSize或者neurons的ANN模型
// 根据机器的不同训练时间不一样但一般需要10分钟左右所以慢慢等一会吧。
// annT.train(Constant.predictSize, Constant.neurons);
annT.train(Constant.predictSize, Constant.neurons);
annT.predict();
System.out.println("The end.");
return;
}

Loading…
Cancel
Save