优化ann训练算法

devA
yuxue 5 years ago
parent fc4c79b419
commit b2d8e97b7a

@ -217,7 +217,7 @@ public class ANNTrain {
Vector<String> files = new Vector<String>(); Vector<String> files = new Vector<String>();
FileUtil.getFiles(str, files); // 文件名不能包含中文 FileUtil.getFiles(str, files); // 文件名不能包含中文
int count = 100; // 控制从训练样本中,抽取指定数量的样本 int count = 200; // 控制从训练样本中,抽取指定数量的样本
for (int j = 0; j < count; j++) { for (int j = 0; j < count; j++) {
Mat img = Imgcodecs.imread(files.get(rand.nextInt(files.size() - 1)), 0); Mat img = Imgcodecs.imread(files.get(rand.nextInt(files.size() - 1)), 0);
@ -235,11 +235,11 @@ public class ANNTrain {
// 增加膨胀样本 // 增加膨胀样本
/*samples.push_back(features(dilate(img), _predictsize)); /*samples.push_back(features(dilate(img), _predictsize));
trainingLabels.add(i); */ trainingLabels.add(i);*/
// 增加腐蚀样本 // 增加腐蚀样本
samples.push_back(features(erode(img), _predictsize)); /*samples.push_back(features(erode(img), _predictsize));
trainingLabels.add(i); trainingLabels.add(i); */
} }
} }
@ -249,7 +249,7 @@ public class ANNTrain {
Vector<String> files = new Vector<String>(); Vector<String> files = new Vector<String>();
FileUtil.getFiles(str, files); FileUtil.getFiles(str, files);
int count = 100; // 控制从训练样本中,抽取指定数量的样本 int count = 200; // 控制从训练样本中,抽取指定数量的样本
for (int j = 0; j < count; j++) { for (int j = 0; j < count; j++) {
Mat img = Imgcodecs.imread(files.get(rand.nextInt(files.size() - 1)), 0); Mat img = Imgcodecs.imread(files.get(rand.nextInt(files.size() - 1)), 0);
Mat f = features(img, _predictsize); Mat f = features(img, _predictsize);
@ -290,8 +290,8 @@ public class ANNTrain {
ann.clear(); ann.clear();
Mat layers = new Mat(1, 3, CvType.CV_32F); Mat layers = new Mat(1, 3, CvType.CV_32F);
layers.put(0, 0, samples.cols()); // 样本特征数 layers.put(0, 0, samples.cols()); // 样本特征数 140 10*10 + 20+20
layers.put(0, 1, _neurons); // layers.put(0, 1, _neurons); // 神经元个数
layers.put(0, 2, classes.cols()); // 字符数 layers.put(0, 2, classes.cols()); // 字符数
ann.setLayerSizes(layers); ann.setLayerSizes(layers);
@ -393,7 +393,7 @@ public class ANNTrain {
// 这里演示只训练model文件夹下的ann.xml此模型是一个predictSize=10,neurons=40的ANN模型 // 这里演示只训练model文件夹下的ann.xml此模型是一个predictSize=10,neurons=40的ANN模型
// 可根据需要训练不同的predictSize或者neurons的ANN模型 // 可根据需要训练不同的predictSize或者neurons的ANN模型
// 根据机器的不同训练时间不一样但一般需要10分钟左右所以慢慢等一会吧。 // 根据机器的不同训练时间不一样但一般需要10分钟左右所以慢慢等一会吧。
// annT.train(Constant.predictSize, Constant.neurons); annT.train(Constant.predictSize, Constant.neurons);
annT.predict(); annT.predict();

Loading…
Cancel
Save