From 8992f08a0a8cd02ee0ab8f1c4abc8e19504ac583 Mon Sep 17 00:00:00 2001 From: py5q8pfbc <2669752843@qq.com> Date: Mon, 17 Jun 2024 19:29:58 +0800 Subject: [PATCH] ADD file via upload --- FullTest(1).cpp | 269 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 269 insertions(+) create mode 100644 FullTest(1).cpp diff --git a/FullTest(1).cpp b/FullTest(1).cpp new file mode 100644 index 0000000..6181b7a --- /dev/null +++ b/FullTest(1).cpp @@ -0,0 +1,269 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +// 定义图像类 +class Image { +private: + std::vector> pixels; // 像素数组 + int width; // 图像宽度 + int height; // 图像高度 + +public: + Image(std::vector> pixels, int width, int height) : pixels(pixels), width(width), height(height) {} + + // 获取像素值 + int getPixelValue(int x, int y) const { + return pixels[x][y]; + } + + // 获取图像宽度 + int getWidth() const { + return width; + } + + // 获取图像高度 + int getHeight() const { + return height; + } +}; + +class ImageProcessor { +public: + // 加载图像并进行预处理 + Image preprocessImage(const std::string& imagePath) { + return loadTXT(imagePath); + } + +private: + // 加载TXT图像 + Image loadTXT(const std::string& imagePath) { + std::ifstream file(imagePath); + if (!file.is_open()) { + throw std::runtime_error("Failed to open TXT image file: " + imagePath); + } + + std::vector> pixels; + std::string line; + int width = 0; + while (std::getline(file, line)) { + std::istringstream iss(line); + std::vector row; + int value; + while (iss >> value) { + row.push_back(value); + } + pixels.push_back(row); + if (width == 0) { + width = row.size(); + } + else if (row.size() != width) { + throw std::runtime_error("Invalid TXT image format: inconsistent row lengths"); + } + } + + // Create and return Image object + return Image(pixels, width, pixels.size()); + } +}; + +// 特征类 +class Feature { +private: + std::vector values; // 特征值 + +public: + Feature(std::vector values) : values(values) {} + + // 获取特征值 + double getValue(int index) const { + return values[index]; + } + + // 获取特征值数量 + size_t getSize() const { + return values.size(); + } +}; + +// 特征提取器 +class FeatureExtractor { +public: + // 从图像中提取特征 + Feature extractFeature(const Image& image) { + std::vector features; + for (int y = 0; y < image.getHeight(); ++y) { + for (int x = 0; x < image.getWidth(); ++x) { + features.push_back(static_cast(image.getPixelValue(x, y))); + } + } + return Feature(features); + } +}; + +// 距离度量类 +class DistanceMetric { +public: + // 计算两个特征之间的距离(欧氏距离) + double calculateDistance(const Feature& feature1, const Feature& feature2) const { + if (feature1.getSize() != feature2.getSize()) { + throw std::runtime_error("Feature vectors must be of the same size"); + } + double distance = 0.0; + for (size_t i = 0; i < feature1.getSize(); ++i) { + distance += pow(feature1.getValue(i) - feature2.getValue(i), 2); + } + return sqrt(distance); + } +}; + +// KNN 模型 +class KNNModel { +private: + std::vector features; // 特征向量 + std::vector labels; // 标签 + DistanceMetric distanceMetric; // 距离度量 + +public: + // 训练 KNN 模型 + void train(const std::vector& features, const std::vector& labels) { + this->features = features; + this->labels = labels; + } + + // 使用 KNN 模型进行预测 + int predict(const Feature& feature, int k) { + // 计算待预测特征与训练集中所有特征的距离,并记录索引 + std::vector> distances; + for (size_t i = 0; i < features.size(); ++i) { + double distance = distanceMetric.calculateDistance(features[i], feature); + distances.push_back(std::make_pair(distance, i)); + } + + // 根据距离排序,选择最近的 K 个样本 + std::sort(distances.begin(), distances.end()); + + // 统计最近的 K 个样本中各类别出现的次数 + std::unordered_map classCount; + for (int i = 0; i < k; ++i) { + int index = distances[i].second; + int label = labels[index]; + classCount[label]++; + } + + // 找出最频繁出现的类别 + int maxCount = 0; + int predictedLabel = -1; + for (auto& pair : classCount) { + if (pair.second > maxCount) { + maxCount = pair.second; + predictedLabel = pair.first; + } + } + + return predictedLabel; + } +}; + +// 数字识别器 +class DigitRecognizer { +private: + ImageProcessor imageProcessor; + FeatureExtractor featureExtractor; + KNNModel knnModel; + +public: + // 构造函数 + DigitRecognizer(const std::vector& trainingFeatures, const std::vector& trainingLabels) { + knnModel.train(trainingFeatures, trainingLabels); + } + + // 执行数字识别 + int recognizeDigit(const std::string& imagePath) { + // 加载图像并进行预处理 + Image image = imageProcessor.preprocessImage(imagePath); + + // 提取图像特征 + Feature feature = featureExtractor.extractFeature(image); + + // 使用 KNN 算法进行预测 + int predictedDigit = knnModel.predict(feature, 5 /* K 值 */); + + return predictedDigit; + } +}; + +// 加载训练数据 +void loadTrainingData(const std::string& baseDirTrain, std::vector& trainingFeatures, std::vector& trainingLabels) { + ImageProcessor imageProcessor; + FeatureExtractor featureExtractor; + + for (int label = 0; label <= 9; ++label) { + std::string dirPath = baseDirTrain + "\\" + std::to_string(label); + for (int i = 1; i <= 10; ++i) { + std::stringstream ss; + ss << dirPath << "\\img" << label << i << ".txt"; + std::string filePath = ss.str(); + Image image = imageProcessor.preprocessImage(filePath); + Feature feature = featureExtractor.extractFeature(image); + trainingFeatures.push_back(feature); + trainingLabels.push_back(label); + } + } +} + +// 加载测试数据 +void loadTestData(const std::string& baseDirTest, std::vector& testFilePaths) { + std::string dirPath = baseDirTest; + + // 遍历指定文件夹中的所有文件 + for (const auto& entry : std::filesystem::directory_iterator(dirPath)) { + if (entry.path().extension() == ".txt") { + testFilePaths.push_back(entry.path().string()); + + } + } +} + +int main() { + // 创建训练数据 + std::vector trainingFeatures; + std::vector trainingLabels; + + // 训练数据路径 + std::string baseDirTrain = "C:\\Users\\DELL\\Desktop\\期末三期\\10_10"; + + // 加载训练数据 + loadTrainingData(baseDirTrain, trainingFeatures, trainingLabels); + + // 创建数字识别器对象 + DigitRecognizer digitRecognizer(trainingFeatures, trainingLabels); + + // 训练数据路径 + std::string baseDirTest = "C:\\Users\\DELL\\Desktop\\期末三期\\txt\\2"; + + // 加载测试数据 + std::vector testFilePaths; + loadTestData(baseDirTest, testFilePaths); + + // 执行数字识别并计算正确率 + int correctCount = 0; + for (size_t i = 0; i < testFilePaths.size(); ++i) { + int recognizedDigit = digitRecognizer.recognizeDigit(testFilePaths[i]); + if (recognizedDigit == 2) { + correctCount++; + } + } + + double accuracy = static_cast(correctCount) / testFilePaths.size(); + + // 输出识别结果 + std::cout << "2 Recognition Accuracy: " << accuracy * 100 << "%" << std::endl; + + return 0; +}