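// NOTE: the lines below are residue from the web page this file was copied from
// (repository UI text and file statistics); they are not part of the program.
// You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
// 270 lines
// 7.7 KiB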

#include <algorithm>
#include <cmath>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
// Image class: a grid of integer pixel values.
class Image {
private:
    // Outer index is the row (y), inner index the column (x), which is how
    // ImageProcessor::loadTXT builds it (one text line per row).
    std::vector<std::vector<int>> pixels;
    int width;  // number of columns
    int height; // number of rows
public:
    // width/height are stored as given and are not checked against `pixels`;
    // callers must pass values consistent with the grid.
    Image(std::vector<std::vector<int>> pixels, int width, int height)
        : pixels(std::move(pixels)), width(width), height(height) {}

    // Returns the pixel at column x, row y.
    // No bounds checking: requires 0 <= x < width and 0 <= y < height.
    int getPixelValue(int x, int y) const {
        return pixels[y][x];
    }

    // Returns the image width (number of columns).
    int getWidth() const {
        return width;
    }

    // Returns the image height (number of rows).
    int getHeight() const {
        return height;
    }
};
// Loads images from disk (currently only the TXT format).
class ImageProcessor {
public:
    // Loads the image at imagePath. At present the only "preprocessing" is parsing
    // the TXT file; no resizing or normalisation is applied.
    // Throws std::runtime_error if the file cannot be opened or is malformed.
    Image preprocessImage(const std::string& imagePath) const {
        return loadTXT(imagePath);
    }
private:
    // Parses a TXT image: one pixel row per line, pixels written as
    // whitespace-separated integers. Blank or whitespace-only lines (e.g. a
    // trailing empty line) are skipped.
    // Throws std::runtime_error if the file cannot be opened, a token is not a
    // valid int, rows differ in length, or the file contains no pixels at all.
    static Image loadTXT(const std::string& imagePath) {
        std::ifstream file(imagePath);
        if (!file.is_open()) {
            throw std::runtime_error("Failed to open TXT image file: " + imagePath);
        }
        std::vector<std::vector<int>> pixels;
        size_t width = 0;
        std::string line;
        while (std::getline(file, line)) {
            std::istringstream iss(line);
            std::vector<int> row;
            // Skip whitespace before each token and stop at end of line. Reading
            // token-by-token (rather than `while (iss >> value)`) turns a bad token
            // into an error instead of silently truncating the row.
            while (iss >> std::ws && !iss.eof()) {
                int value = 0;
                if (!(iss >> value)) {
                    throw std::runtime_error("Invalid TXT image format: non-integer pixel value in " + imagePath);
                }
                row.push_back(value);
            }
            if (row.empty()) {
                continue; // blank line: not a pixel row
            }
            if (pixels.empty()) {
                width = row.size();
            } else if (row.size() != width) {
                throw std::runtime_error("Invalid TXT image format: inconsistent row lengths");
            }
            pixels.push_back(std::move(row));
        }
        if (pixels.empty()) {
            throw std::runtime_error("Invalid TXT image format: no pixel data in " + imagePath);
        }
        const int height = static_cast<int>(pixels.size());
        return Image(std::move(pixels), static_cast<int>(width), height);
    }
};
// Feature class: an immutable vector of feature values.
class Feature {
private:
    std::vector<double> values; // feature values
public:
    // Takes the values by value and moves them in, so passing a temporary or a
    // std::move'd vector costs no copy.
    Feature(std::vector<double> values) : values(std::move(values)) {}

    // Returns the value at `index`. No bounds checking: requires 0 <= index < getSize().
    double getValue(int index) const {
        return values[index];
    }

    // Returns the number of feature values.
    size_t getSize() const {
        return values.size();
    }
};
// Feature extractor: turns an image into a feature vector of raw pixel values.
class FeatureExtractor {
public:
    // Returns one feature per pixel, cast to double. The outer loop runs y over
    // [0, height) and the inner loop runs x over [0, width), reading
    // image.getPixelValue(x, y). No scaling or normalisation is applied.
    Feature extractFeature(const Image& image) const {
        const int width = image.getWidth();
        const int height = image.getHeight();
        std::vector<double> features;
        if (width > 0 && height > 0) {
            // Allocate once instead of regrowing while pushing every pixel.
            features.reserve(static_cast<size_t>(width) * static_cast<size_t>(height));
        }
        for (int y = 0; y < height; ++y) {
            for (int x = 0; x < width; ++x) {
                features.push_back(static_cast<double>(image.getPixelValue(x, y)));
            }
        }
        return Feature(std::move(features));
    }
};
// Distance metric between two features (Euclidean).
class DistanceMetric {
public:
    // Returns the Euclidean distance sqrt(sum_i (a_i - b_i)^2).
    // Throws std::runtime_error if the two features differ in length.
    double calculateDistance(const Feature& feature1, const Feature& feature2) const {
        if (feature1.getSize() != feature2.getSize()) {
            throw std::runtime_error("Feature vectors must be of the same size");
        }
        double sumOfSquares = 0.0;
        for (size_t i = 0; i < feature1.getSize(); ++i) {
            const double diff = feature1.getValue(i) - feature2.getValue(i);
            sumOfSquares += diff * diff; // plain multiply: pow(diff, 2) is a general (slow) transcendental
        }
        return std::sqrt(sumOfSquares);
    }
};
// KNN 模型
class KNNModel {
private:
std::vector<Feature> features; // 特征向量
std::vector<int> labels; // 标签
DistanceMetric distanceMetric; // 距离度量
public:
// 训练 KNN 模型
void train(const std::vector<Feature>& features, const std::vector<int>& labels) {
this->features = features;
this->labels = labels;
}
// 使用 KNN 模型进行预测
int predict(const Feature& feature, int k) {
// 计算待预测特征与训练集中所有特征的距离,并记录索引
std::vector<std::pair<double, int>> distances;
for (size_t i = 0; i < features.size(); ++i) {
double distance = distanceMetric.calculateDistance(features[i], feature);
distances.push_back(std::make_pair(distance, i));
}
// 根据距离排序,选择最近的 K 个样本
std::sort(distances.begin(), distances.end());
// 统计最近的 K 个样本中各类别出现的次数
std::unordered_map<int, int> classCount;
for (int i = 0; i < k; ++i) {
int index = distances[i].second;
int label = labels[index];
classCount[label]++;
}
// 找出最频繁出现的类别
int maxCount = 0;
int predictedLabel = -1;
for (auto& pair : classCount) {
if (pair.second > maxCount) {
maxCount = pair.second;
predictedLabel = pair.first;
}
}
return predictedLabel;
}
};
// 数字识别器
class DigitRecognizer {
private:
ImageProcessor imageProcessor;
FeatureExtractor featureExtractor;
KNNModel knnModel;
public:
// 构造函数
DigitRecognizer(const std::vector<Feature>& trainingFeatures, const std::vector<int>& trainingLabels) {
knnModel.train(trainingFeatures, trainingLabels);
}
// 执行数字识别
int recognizeDigit(const std::string& imagePath) {
// 加载图像并进行预处理
Image image = imageProcessor.preprocessImage(imagePath);
// 提取图像特征
Feature feature = featureExtractor.extractFeature(image);
// 使用 KNN 算法进行预测
int predictedDigit = knnModel.predict(feature, 5 /* K 值 */);
return predictedDigit;
}
};
// 加载训练数据
void loadTrainingData(const std::string& baseDirTrain, std::vector<Feature>& trainingFeatures, std::vector<int>& trainingLabels) {
ImageProcessor imageProcessor;
FeatureExtractor featureExtractor;
for (int label = 0; label <= 9; ++label) {
std::string dirPath = baseDirTrain + "\\" + std::to_string(label);
for (int i = 1; i <= 10; ++i) {
std::stringstream ss;
ss << dirPath << "\\img" << label << i << ".txt";
std::string filePath = ss.str();
Image image = imageProcessor.preprocessImage(filePath);
Feature feature = featureExtractor.extractFeature(image);
trainingFeatures.push_back(feature);
trainingLabels.push_back(label);
}
}
}
// Collects the paths of all regular ".txt" files directly inside baseDirTest
// (no recursion) and appends them to testFilePaths in sorted order.
// Sorting makes runs reproducible, because directory_iterator's order is unspecified.
// Existing entries of testFilePaths are kept and left in place.
// Throws std::filesystem::filesystem_error if the directory cannot be iterated.
void loadTestData(const std::string& baseDirTest, std::vector<std::string>& testFilePaths) {
    const auto firstNew = static_cast<std::vector<std::string>::difference_type>(testFilePaths.size());
    for (const auto& entry : std::filesystem::directory_iterator(baseDirTest)) {
        // Skip directories or other non-files that merely end in ".txt".
        if (entry.is_regular_file() && entry.path().extension() == ".txt") {
            testFilePaths.push_back(entry.path().string());
        }
    }
    std::sort(testFilePaths.begin() + firstNew, testFilePaths.end());
}
// Entry point: trains the recognizer on the data under the training directory,
// then reports the fraction of test images recognized as the expected digit.
// Optional arguments (defaults are the original hard-coded values):
//   argv[1] = training directory, argv[2] = test directory, argv[3] = expected digit.
// Returns 0 on success, 1 if no test files are found or any step throws.
int main(int argc, char* argv[]) {
    std::string baseDirTrain = "C:\\Users\\DELL\\Desktop\\期末三期\\10_10";
    std::string baseDirTest = "C:\\Users\\DELL\\Desktop\\期末三期\\txt\\2";
    int expectedDigit = 2; // every test image in baseDirTest is assumed to show this digit
    try {
        if (argc > 1) {
            baseDirTrain = argv[1];
        }
        if (argc > 2) {
            baseDirTest = argv[2];
        }
        if (argc > 3) {
            expectedDigit = std::stoi(argv[3]);
        }

        // Load the training data and build the recognizer.
        std::vector<Feature> trainingFeatures;
        std::vector<int> trainingLabels;
        loadTrainingData(baseDirTrain, trainingFeatures, trainingLabels);
        DigitRecognizer digitRecognizer(trainingFeatures, trainingLabels);

        // Load the test data paths.
        std::vector<std::string> testFilePaths;
        loadTestData(baseDirTest, testFilePaths);
        if (testFilePaths.empty()) {
            // Avoid reporting 0/0 (NaN) accuracy.
            std::cerr << "No .txt test files found in " << baseDirTest << std::endl;
            return 1;
        }

        // Recognize every test image and measure accuracy.
        int correctCount = 0;
        for (const std::string& path : testFilePaths) {
            if (digitRecognizer.recognizeDigit(path) == expectedDigit) {
                ++correctCount;
            }
        }
        const double accuracy = static_cast<double>(correctCount) / testFilePaths.size();
        std::cout << expectedDigit << " Recognition Accuracy: " << accuracy * 100 << "%" << std::endl;
    } catch (const std::exception& e) {
        // Report load/parse failures instead of terminating via an uncaught exception.
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}