#include <algorithm>
#include <cmath>
#include <cstddef>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

/// A grayscale image stored as a row-major grid of integer pixel values.
class Image {
private:
    std::vector<std::vector<int>> pixels;  // pixels[row][column]
    int width;                             // number of columns
    int height;                            // number of rows

public:
    /// @param pixels  Row-major pixel grid (one inner vector per image row).
    /// @param width   Number of columns in every row.
    /// @param height  Number of rows.
    Image(std::vector<std::vector<int>> pixels, int width, int height)
        : pixels(std::move(pixels)), width(width), height(height) {}

    /// Returns the pixel at column @p x, row @p y.
    /// The caller must ensure 0 <= x < getWidth() and 0 <= y < getHeight().
    int getPixelValue(int x, int y) const {
        // Rows are the outer dimension, so the row index (y) comes first.
        return pixels[y][x];
    }

    /// Returns the image width (number of columns).
    int getWidth() const { return width; }

    /// Returns the image height (number of rows).
    int getHeight() const { return height; }
};

/// Loads images from disk.
class ImageProcessor {
public:
    /// Loads the image stored at @p imagePath.
    /// @throws std::runtime_error if the file cannot be opened or is malformed.
    Image preprocessImage(const std::string& imagePath) {
        return loadTXT(imagePath);
    }

private:
    /// Parses a text image: one row per line, whitespace-separated integers.
    /// Lines that contain no integers (e.g. a trailing blank line) are skipped.
    /// @throws std::runtime_error if the file cannot be opened or if two
    ///         non-blank lines contain a different number of values.
    Image loadTXT(const std::string& imagePath) {
        std::ifstream file(imagePath);
        if (!file.is_open()) {
            throw std::runtime_error("Failed to open TXT image file: " + imagePath);
        }

        std::vector<std::vector<int>> pixels;
        std::string line;
        std::size_t width = 0;

        while (std::getline(file, line)) {
            std::istringstream iss(line);
            std::vector<int> row;
            int value;
            while (iss >> value) {
                row.push_back(value);
            }

            // Blank lines carry no pixels; storing them would create empty rows.
            if (row.empty()) {
                continue;
            }

            if (pixels.empty()) {
                width = row.size();  // the first data row fixes the width
            } else if (row.size() != width) {
                throw std::runtime_error("Invalid TXT image format: inconsistent row lengths");
            }
            pixels.push_back(std::move(row));
        }

        // Read the height before moving `pixels` into the Image.
        const int height = static_cast<int>(pixels.size());
        return Image(std::move(pixels), static_cast<int>(width), height);
    }
};

/// A fixed-length vector of feature values.
class Feature {
private:
    std::vector<double> values;  // feature values

public:
    Feature(std::vector<double> values) : values(std::move(values)) {}

    /// Returns the value at @p index; the caller must ensure it is in range.
    double getValue(int index) const { return values[index]; }

    /// Returns the number of values in the feature vector.
    size_t getSize() const { return values.size(); }
};

/// Turns an image into a feature vector.
class FeatureExtractor {
public:
    /// Flattens the image row by row into one value per pixel.
    Feature extractFeature(const Image& image) {
        std::vector<double> features;
        features.reserve(static_cast<std::size_t>(image.getWidth()) *
                         static_cast<std::size_t>(image.getHeight()));
        for (int y = 0; y < image.getHeight(); ++y) {
            for (int x = 0; x < image.getWidth(); ++x) {
                features.push_back(static_cast<double>(image.getPixelValue(x, y)));
            }
        }
        return Feature(features);
    }
};

/// Distance measure between two feature vectors.
class DistanceMetric {
public:
    /// Computes the Euclidean distance between two features.
    /// @throws std::runtime_error if the features differ in size.
    double calculateDistance(const Feature& feature1, const Feature& feature2) const {
        if (feature1.getSize() != feature2.getSize()) {
            throw std::runtime_error("Feature vectors must be of the same size");
        }

        double distance = 0.0;
        for (size_t i = 0; i < feature1.getSize(); ++i) {
            const double diff = feature1.getValue(i) - feature2.getValue(i);
            distance += diff * diff;
        }
        return std::sqrt(distance);
    }
};

/// k-nearest-neighbours classifier over Feature vectors.
class KNNModel {
private:
    std::vector<Feature> features;  // training feature vectors
    std::vector<int> labels;        // labels[i] is the class of features[i]
    DistanceMetric distanceMetric;  // distance measure used for ranking

public:
    /// Stores the training set.
    /// @throws std::invalid_argument if the two vectors differ in length.
    void train(const std::vector<Feature>& features, const std::vector<int>& labels) {
        if (features.size() != labels.size()) {
            throw std::invalid_argument("Training features and labels must have the same length");
        }
        this->features = features;
        this->labels = labels;
    }

    /// Predicts the label of @p feature by majority vote among its @p k
    /// nearest training samples.
    ///
    /// @p k is clamped to the number of training samples. Vote ties go to the
    /// class whose closest member is nearest to @p feature.
    /// @return The predicted label, or -1 when k <= 0 or the model is empty.
    /// @throws std::runtime_error if @p feature differs in size from a
    ///         training feature.
    int predict(const Feature& feature, int k) {
        // Distance from the query to every training sample, paired with its index.
        std::vector<std::pair<double, std::size_t>> distances;
        distances.reserve(features.size());
        for (std::size_t i = 0; i < features.size(); ++i) {
            distances.emplace_back(distanceMetric.calculateDistance(features[i], feature), i);
        }

        // Never look at more neighbours than there are training samples.
        const std::size_t requested = k > 0 ? static_cast<std::size_t>(k) : 0;
        const std::size_t neighbours = std::min(requested, distances.size());

        // Only the closest `neighbours` entries need to be in order.
        std::partial_sort(distances.begin(),
                          distances.begin() + static_cast<std::ptrdiff_t>(neighbours),
                          distances.end());

        // Count the votes for each class among the nearest neighbours.
        std::unordered_map<int, int> classCount;
        for (std::size_t i = 0; i < neighbours; ++i) {
            ++classCount[labels[distances[i].second]];
        }

        // Walk the neighbours nearest-first so that ties resolve
        // deterministically in favour of the closest class.
        int maxCount = 0;
        int predictedLabel = -1;
        for (std::size_t i = 0; i < neighbours; ++i) {
            const int label = labels[distances[i].second];
            const int count = classCount[label];
            if (count > maxCount) {
                maxCount = count;
                predictedLabel = label;
            }
        }
        return predictedLabel;
    }
};

/// End-to-end digit recognizer: load image -> extract feature -> KNN predict.
class DigitRecognizer {
private:
    ImageProcessor imageProcessor;
    FeatureExtractor featureExtractor;
    KNNModel knnModel;
    int k;  // number of neighbours consulted per prediction

public:
    /// Trains the underlying KNN model.
    /// @param trainingFeatures  Training feature vectors.
    /// @param trainingLabels    Label of each training feature (same length).
    /// @param k                 Number of neighbours to vote (defaults to 5).
    DigitRecognizer(const std::vector<Feature>& trainingFeatures,
                    const std::vector<int>& trainingLabels,
                    int k = 5)
        : k(k) {
        knnModel.train(trainingFeatures, trainingLabels);
    }

    /// Recognizes the digit stored in the TXT image at @p imagePath.
    /// @return The predicted digit, or -1 if no prediction could be made.
    int recognizeDigit(const std::string& imagePath) {
        // Load the image.
        Image image = imageProcessor.preprocessImage(imagePath);

        // Extract its feature vector.
        Feature feature = featureExtractor.extractFeature(image);

        // Classify with KNN.
        return knnModel.predict(feature, k);
    }
};

/// Loads the training set from `<baseDirTrain>/<label>/img<label><i>.txt`
/// for label in [0, numClasses) and i in [1, samplesPerClass].
///
/// @param baseDirTrain      Root directory of the training data.
/// @param trainingFeatures  Receives one feature per loaded image (appended).
/// @param trainingLabels    Receives the matching labels (appended).
/// @param numClasses        Number of label sub-directories (defaults to 10).
/// @param samplesPerClass   Number of images per label (defaults to 10).
/// @throws std::runtime_error if an image is missing or malformed.
void loadTrainingData(const std::string& baseDirTrain,
                      std::vector<Feature>& trainingFeatures,
                      std::vector<int>& trainingLabels,
                      int numClasses = 10,
                      int samplesPerClass = 10) {
    ImageProcessor imageProcessor;
    FeatureExtractor featureExtractor;

    // Use the platform's separator rather than a hard-coded backslash.
    const char separator = static_cast<char>(std::filesystem::path::preferred_separator);

    for (int label = 0; label < numClasses; ++label) {
        const std::string dirPath = baseDirTrain + separator + std::to_string(label);
        for (int i = 1; i <= samplesPerClass; ++i) {
            std::stringstream ss;
            ss << dirPath << separator << "img" << label << i << ".txt";
            const std::string filePath = ss.str();

            Image image = imageProcessor.preprocessImage(filePath);
            Feature feature = featureExtractor.extractFeature(image);
            trainingFeatures.push_back(feature);
            trainingLabels.push_back(label);
        }
    }
}

/// Collects the paths of all regular `.txt` files directly inside
/// @p baseDirTest and appends them to @p testFilePaths.
/// @throws std::filesystem::filesystem_error if the directory cannot be read.
void loadTestData(const std::string& baseDirTest, std::vector<std::string>& testFilePaths) {
    for (const auto& entry : std::filesystem::directory_iterator(baseDirTest)) {
        if (entry.is_regular_file() && entry.path().extension() == ".txt") {
            testFilePaths.push_back(entry.path().string());
        }
    }
}

/// Trains on the 10x10 sample set and reports the recognition accuracy on a
/// directory of test images that are all expected to be the digit 2.
int main() {
    try {
        // Training data containers.
        std::vector<Feature> trainingFeatures;
        std::vector<int> trainingLabels;

        // Training data path.
        std::string baseDirTrain = "C:\\Users\\DELL\\Desktop\\期末三期\\10_10";

        // Load the training data.
        loadTrainingData(baseDirTrain, trainingFeatures, trainingLabels);

        // Build the recognizer.
        DigitRecognizer digitRecognizer(trainingFeatures, trainingLabels);

        // Test data path.
        std::string baseDirTest = "C:\\Users\\DELL\\Desktop\\期末三期\\txt\\2";

        // Load the test data.
        std::vector<std::string> testFilePaths;
        loadTestData(baseDirTest, testFilePaths);

        // Without test files the accuracy would be 0/0.
        if (testFilePaths.empty()) {
            std::cerr << "No .txt test files found in: " << baseDirTest << '\n';
            return 1;
        }

        // Every file in the test directory is expected to be this digit.
        const int expectedDigit = 2;

        // Run recognition and count the correct predictions.
        int correctCount = 0;
        for (const std::string& testFilePath : testFilePaths) {
            if (digitRecognizer.recognizeDigit(testFilePath) == expectedDigit) {
                ++correctCount;
            }
        }

        const double accuracy = static_cast<double>(correctCount) / testFilePaths.size();

        // Report the result.
        std::cout << "2 Recognition Accuracy: " << accuracy * 100 << "%" << '\n';
        return 0;
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << '\n';
        return 1;
    }
}