|
|
|
@ -354,3 +354,22 @@ cv::Ptr<cv::ml::TrainData> AnnTrain::tdata() {
|
|
|
|
|
train_classes);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
/*这段代码是一个开源项目EasyPR中的一个类AnnTrain的实现。
|
|
|
|
|
|
|
|
|
|
AnnTrain类用于训练一个基于人工神经网络(ANN)的字符识别模型。它接受一个包含字符图像的文件夹路径和一个输出的XML文件路径作为参数。
|
|
|
|
|
|
|
|
|
|
在构造函数中,它创建了一个ANN_MLP类对象ann_。type变量用于指定训练类型,0表示所有字符,1表示只有中文字符。kv_是一个Kv类的对象,用于加载一个映射文件。
|
|
|
|
|
|
|
|
|
|
train()函数是训练模型的核心函数。它根据type的值,确定输入层、隐藏层和输出层的节点数量,以及神经网络的层数和结构。然后设置神经网络的各种参数,如激活函数、训练方法、停止条件等。接着,它加载训练数据并开始训练模型。训练完成后,将模型保存到指定的XML文件中,并调用test()函数评估模型的准确率。
|
|
|
|
|
|
|
|
|
|
identifyChinese()函数和identify()函数分别用于识别中文字符和非中文字符。它们接受一个字符图像作为输入,提取特征并使用训练好的模型进行预测,返回一个包含字符和所属省份信息的pair。
|
|
|
|
|
|
|
|
|
|
test()函数用于评估模型的准确率。它遍历每个字符类别,读取所有字符图像文件,使用训练好的模型进行识别,并统计预测正确的数量。最后计算整体的准确率。
|
|
|
|
|
|
|
|
|
|
getSyntheticImage()函数用于生成合成图像,通过随机平移和旋转原字符图像来增加训练数据的多样性。
|
|
|
|
|
|
|
|
|
|
sdata()函数用于生成训练数据,读取字符图像文件夹中的所有图像,对于每个字符,根据指定数量生成一些合成图像,并提取特征。
|
|
|
|
|
|
|
|
|
|
tdata()函数和sdata()函数类似,用于生成训练数据,只是不生成合成图像。
|
|
|
|
|
|
|
|
|
|
该类还依赖其他一些辅助函数和常量定义,没有在这段代码中给出。*/
|