package com.yuxue.train;

import java.util.Vector;

import static org.bytedeco.javacpp.opencv_core.*;
import static org.bytedeco.javacpp.opencv_ml.*;

import org.bytedeco.javacpp.opencv_imgcodecs;
import org.bytedeco.javacpp.opencv_core.Mat;

import com.yuxue.constant.Constant;
import com.yuxue.easypr.core.CoreFunc;
import com.yuxue.util.FileUtil;

/**
 * ANN training implemented with the org.bytedeco.javacpp OpenCV bindings.
 * 
 * Character-recognition training for images: the resulting model file is
 * used to recognize the characters in an image.
 * 
 * To apply a newly trained ann.xml:
 * 1. Replace the res/model/ann.xml file.
 * 2. Adjust the com.yuxue.easypr.core.CharsIdentify.charsIdentify(Mat, Boolean, Boolean) method accordingly.
 * 
 * @author yuxue
 * @date 2020-05-14 22:16
 */
public class ANNTrain1 {

    private ANN_MLP ann = ANN_MLP.create();

    // Root directory of the training data (adjust to your environment)
    private static final String DEFAULT_PATH = "D:/PlateDetect/train/chars_recognise_ann/";

    // Where the trained model file is saved
    private static final String MODEL_PATH = "res/model/ann.xml";
    
    
    public void train(int _predictsize, int _neurons) {
        Mat samples = new Mat(); // rows/cols must stay unset because push_back is used
        Vector<Integer> trainingLabels = new Vector<Integer>();
        // Load the digit and letter character samples
        for (int i = 0; i < Constant.numCharacter; i++) {
            String str = DEFAULT_PATH + "learn/" + Constant.strCharacters[i];
            Vector<String> files = new Vector<String>();
            FileUtil.getFiles(str, files);

            int size = files.size();
            for (int j = 0; j < size; j++) {
                Mat img = opencv_imgcodecs.imread(files.get(j), 0); // read as grayscale
                // System.err.println(files.get(j)); // file names must not contain Chinese characters
                Mat f = CoreFunc.features(img, _predictsize);
                samples.push_back(f);
                trainingLabels.add(i); // class index of this character image
            }
        }

        // Load the Chinese character samples
        for (int i = 0; i < Constant.strChinese.length; i++) {
            String str = DEFAULT_PATH + "learn/" + Constant.strChinese[i];
            Vector<String> files = new Vector<String>();
            FileUtil.getFiles(str, files);

            int size = files.size();
            for (int j = 0; j < size; j++) {
                Mat img = opencv_imgcodecs.imread(files.get(j), 0); // read as grayscale
                // System.err.println(files.get(j));   // file names must not contain Chinese characters
                Mat f = CoreFunc.features(img, _predictsize);
                samples.push_back(f);
                trainingLabels.add(i + Constant.numCharacter); // offset past the digit/letter classes
            }
        }


        // Feature vector length: vhist.length + hhist.length + lowData.cols() * lowData.rows()
        // (440 when predictSize = 20)
        // One-hot encode the labels: one row per sample, one column per class.
        // The matrix must start zeroed so that only the target class of each row is 1.
        Mat classes = Mat.zeros(trainingLabels.size(), Constant.numAll, CV_32F).asMat();
        for (int i = 0; i < trainingLabels.size(); ++i) {
            classes.ptr(i, trainingLabels.get(i)).putFloat(1.f);
        }
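        // Example: a sample labeled with class index 2 gets the row
        // [0, 0, 1, 0, ..., 0], so output neuron 2 is its training target.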

        // TrainData requires samples.type() == CV_32F || samples.type() == CV_32S
        TrainData train_data = TrainData.create(samples, ROW_SAMPLE, classes);

        ann.clear();
        // Three-layer topology: input = feature length, hidden = _neurons, output = class count
        Mat layers = new Mat(1, 3, CV_32SC1);
        layers.ptr(0, 0).putInt(samples.cols());
        layers.ptr(0, 1).putInt(_neurons);
        layers.ptr(0, 2).putInt(classes.cols());

        System.out.println("layer sizes: " + samples.cols() + " -> " + _neurons + " -> " + classes.cols());
        
        ann.setLayerSizes(layers);
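        // Symmetric sigmoid activation f(x) = beta * (1 - e^(-alpha*x)) / (1 + e^(-alpha*x)),
        // here with alpha = beta = 1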
        ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1);
        ann.setTrainMethod(ANN_MLP.BACKPROP);
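        // Stop training after 30000 iterations or once the change drops below 1e-4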
        TermCriteria criteria = new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 30000, 0.0001);
        ann.setTermCriteria(criteria);
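        // BACKPROP hyper-parameters: the weight scale acts as the learning rate,
        // the momentum scale blends in a fraction of the previous weight update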
        ann.setBackpropWeightScale(0.1);
        ann.setBackpropMomentumScale(0.1);
        ann.train(train_data);

        //FileStorage fsto = new FileStorage(MODEL_PATH, FileStorage.WRITE);
        //ann.write(fsto, "ann");
        ann.save(MODEL_PATH);
    }
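
    /**
     * A minimal sanity-check sketch (a hypothetical helper, not part of the
     * original training flow): reload the model that train() just saved and
     * print its layer sizes, which should read back as
     * input -> hidden -> output. getLayerSizes() is the standard OpenCV
     * ANN_MLP accessor.
     */
    public void verifyModel() {
        ANN_MLP loaded = ANN_MLP.load(MODEL_PATH);
        Mat sizes = loaded.getLayerSizes();
        for (int i = 0; i < (int) sizes.total(); i++) {
            // the Mat may be stored as 1 x N or N x 1 depending on how it was set
            int v = (sizes.rows() == 1) ? sizes.ptr(0, i).getInt() : sizes.ptr(i, 0).getInt();
            System.out.println("layer " + i + " size: " + v);
        }
    }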
    
    
    public void predict() {
        // load() returns a fresh instance, replacing the previously trained network
        ann = ANN_MLP.load(MODEL_PATH);
        Vector<String> files = new Vector<String>();
        FileUtil.getFiles(DEFAULT_PATH + "test/", files);
        
        for (String string : files) {
            Mat img = opencv_imgcodecs.imread(string, 0); // grayscale, matching how training images were read
            Mat f = CoreFunc.features(img, Constant.predictSize);

            // One output score per class; predict() resizes this Mat as needed.
            // (Feature length is 140 when predictSize = 10, 440 when predictSize = 20.)
            Mat output = new Mat(1, Constant.numAll, CV_32F);
            // predict() returns the index of the highest-scoring output neuron
            int index = (int) ann.predict(f, output, 0);
            
            String result = "";
            if (index < Constant.numCharacter) {
                result = String.valueOf(Constant.strCharacters[index]);
            } else {
                String s = Constant.strChinese[index - Constant.numCharacter];
                result = Constant.KEY_CHINESE_MAP.get(s);   // map the code to its Chinese character
            }
            System.err.println(string + "===>" + result);
            
        }
    }
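
    /**
     * A minimal debugging sketch (a hypothetical helper, not part of the
     * original class): dump the raw per-class scores from the output Mat
     * filled in by ann.predict(), instead of only looking at the winning
     * index. Uses the same ptr() access pattern as the rest of this class.
     */
    public void dumpScores(Mat output) {
        for (int j = 0; j < output.cols(); j++) {
            System.out.println("class " + j + " score: " + output.ptr(0, j).getFloat());
        }
    }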

    public static void main(String[] args) {
        
        ANNTrain1 annT = new ANNTrain1();
        // This demo only trains the ann.xml under the model folder: a model with
        // predictSize = 10 and neurons = 40. Models with other predictSize or
        // neurons values can be trained as needed.
        // Training time varies by machine, but expect roughly ten minutes, so be patient.
        annT.train(Constant.predictSize, Constant.neurons);

        annT.predict();
        
        System.out.println("The end.");
    }
}