You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
270 lines
7.7 KiB
270 lines
7.7 KiB
#include <algorithm>
#include <cmath>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
|
|
|
|
// Image stored as a row-major grid of integer pixel values.
class Image {
private:
    std::vector<std::vector<int>> pixels; // pixel grid, indexed pixels[y][x] (row, then column)
    int width;                            // number of columns
    int height;                           // number of rows

public:
    /// @brief Construct an image from a row-major pixel grid.
    /// @param pixels  rows of pixel values; pixels[y][x] is the pixel at column x of row y
    /// @param width   number of columns (expected to match each row's length)
    /// @param height  number of rows (expected to match pixels.size())
    Image(std::vector<std::vector<int>> pixels, int width, int height)
        : pixels(std::move(pixels)), width(width), height(height) {}

    /// @brief Value of the pixel at column @p x, row @p y.
    /// @throws std::out_of_range if (x, y) lies outside the stored grid.
    int getPixelValue(int x, int y) const {
        // Rows are the outer index, so the row coordinate y must come first.
        // at() reports a bad coordinate (negative values wrap to huge indices)
        // as an exception instead of undefined behaviour.
        return pixels.at(y).at(x);
    }

    /// @brief Number of columns.
    int getWidth() const {
        return width;
    }

    /// @brief Number of rows.
    int getHeight() const {
        return height;
    }
};
|
|
|
|
class ImageProcessor {
|
|
public:
|
|
// 加载图像并进行预处理
|
|
Image preprocessImage(const std::string& imagePath) {
|
|
return loadTXT(imagePath);
|
|
}
|
|
|
|
private:
|
|
// 加载TXT图像
|
|
Image loadTXT(const std::string& imagePath) {
|
|
std::ifstream file(imagePath);
|
|
if (!file.is_open()) {
|
|
throw std::runtime_error("Failed to open TXT image file: " + imagePath);
|
|
}
|
|
|
|
std::vector<std::vector<int>> pixels;
|
|
std::string line;
|
|
int width = 0;
|
|
while (std::getline(file, line)) {
|
|
std::istringstream iss(line);
|
|
std::vector<int> row;
|
|
int value;
|
|
while (iss >> value) {
|
|
row.push_back(value);
|
|
}
|
|
pixels.push_back(row);
|
|
if (width == 0) {
|
|
width = row.size();
|
|
}
|
|
else if (row.size() != width) {
|
|
throw std::runtime_error("Invalid TXT image format: inconsistent row lengths");
|
|
}
|
|
}
|
|
|
|
// Create and return Image object
|
|
return Image(pixels, width, pixels.size());
|
|
}
|
|
};
|
|
|
|
// 特征类
|
|
class Feature {
|
|
private:
|
|
std::vector<double> values; // 特征值
|
|
|
|
public:
|
|
Feature(std::vector<double> values) : values(values) {}
|
|
|
|
// 获取特征值
|
|
double getValue(int index) const {
|
|
return values[index];
|
|
}
|
|
|
|
// 获取特征值数量
|
|
size_t getSize() const {
|
|
return values.size();
|
|
}
|
|
};
|
|
|
|
// 特征提取器
|
|
class FeatureExtractor {
|
|
public:
|
|
// 从图像中提取特征
|
|
Feature extractFeature(const Image& image) {
|
|
std::vector<double> features;
|
|
for (int y = 0; y < image.getHeight(); ++y) {
|
|
for (int x = 0; x < image.getWidth(); ++x) {
|
|
features.push_back(static_cast<double>(image.getPixelValue(x, y)));
|
|
}
|
|
}
|
|
return Feature(features);
|
|
}
|
|
};
|
|
|
|
// 距离度量类
|
|
class DistanceMetric {
|
|
public:
|
|
// 计算两个特征之间的距离(欧氏距离)
|
|
double calculateDistance(const Feature& feature1, const Feature& feature2) const {
|
|
if (feature1.getSize() != feature2.getSize()) {
|
|
throw std::runtime_error("Feature vectors must be of the same size");
|
|
}
|
|
double distance = 0.0;
|
|
for (size_t i = 0; i < feature1.getSize(); ++i) {
|
|
distance += pow(feature1.getValue(i) - feature2.getValue(i), 2);
|
|
}
|
|
return sqrt(distance);
|
|
}
|
|
};
|
|
|
|
// KNN 模型
|
|
class KNNModel {
|
|
private:
|
|
std::vector<Feature> features; // 特征向量
|
|
std::vector<int> labels; // 标签
|
|
DistanceMetric distanceMetric; // 距离度量
|
|
|
|
public:
|
|
// 训练 KNN 模型
|
|
void train(const std::vector<Feature>& features, const std::vector<int>& labels) {
|
|
this->features = features;
|
|
this->labels = labels;
|
|
}
|
|
|
|
// 使用 KNN 模型进行预测
|
|
int predict(const Feature& feature, int k) {
|
|
// 计算待预测特征与训练集中所有特征的距离,并记录索引
|
|
std::vector<std::pair<double, int>> distances;
|
|
for (size_t i = 0; i < features.size(); ++i) {
|
|
double distance = distanceMetric.calculateDistance(features[i], feature);
|
|
distances.push_back(std::make_pair(distance, i));
|
|
}
|
|
|
|
// 根据距离排序,选择最近的 K 个样本
|
|
std::sort(distances.begin(), distances.end());
|
|
|
|
// 统计最近的 K 个样本中各类别出现的次数
|
|
std::unordered_map<int, int> classCount;
|
|
for (int i = 0; i < k; ++i) {
|
|
int index = distances[i].second;
|
|
int label = labels[index];
|
|
classCount[label]++;
|
|
}
|
|
|
|
// 找出最频繁出现的类别
|
|
int maxCount = 0;
|
|
int predictedLabel = -1;
|
|
for (auto& pair : classCount) {
|
|
if (pair.second > maxCount) {
|
|
maxCount = pair.second;
|
|
predictedLabel = pair.first;
|
|
}
|
|
}
|
|
|
|
return predictedLabel;
|
|
}
|
|
};
|
|
|
|
// 数字识别器
|
|
class DigitRecognizer {
|
|
private:
|
|
ImageProcessor imageProcessor;
|
|
FeatureExtractor featureExtractor;
|
|
KNNModel knnModel;
|
|
|
|
public:
|
|
// 构造函数
|
|
DigitRecognizer(const std::vector<Feature>& trainingFeatures, const std::vector<int>& trainingLabels) {
|
|
knnModel.train(trainingFeatures, trainingLabels);
|
|
}
|
|
|
|
// 执行数字识别
|
|
int recognizeDigit(const std::string& imagePath) {
|
|
// 加载图像并进行预处理
|
|
Image image = imageProcessor.preprocessImage(imagePath);
|
|
|
|
// 提取图像特征
|
|
Feature feature = featureExtractor.extractFeature(image);
|
|
|
|
// 使用 KNN 算法进行预测
|
|
int predictedDigit = knnModel.predict(feature, 5 /* K 值 */);
|
|
|
|
return predictedDigit;
|
|
}
|
|
};
|
|
|
|
// 加载训练数据
|
|
void loadTrainingData(const std::string& baseDirTrain, std::vector<Feature>& trainingFeatures, std::vector<int>& trainingLabels) {
|
|
ImageProcessor imageProcessor;
|
|
FeatureExtractor featureExtractor;
|
|
|
|
for (int label = 0; label <= 9; ++label) {
|
|
std::string dirPath = baseDirTrain + "\\" + std::to_string(label);
|
|
for (int i = 1; i <= 10; ++i) {
|
|
std::stringstream ss;
|
|
ss << dirPath << "\\img" << label << i << ".txt";
|
|
std::string filePath = ss.str();
|
|
Image image = imageProcessor.preprocessImage(filePath);
|
|
Feature feature = featureExtractor.extractFeature(image);
|
|
trainingFeatures.push_back(feature);
|
|
trainingLabels.push_back(label);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// @brief Append the paths of all regular ".txt" files directly inside @p baseDirTest.
///
/// Subdirectories are not searched. The newly found paths are sorted, because
/// directory_iterator yields entries in an unspecified order and runs should be reproducible.
/// @param baseDirTest   directory to scan
/// @param testFilePaths receives the matching file paths; existing contents are kept in front
/// @throws std::filesystem::filesystem_error if the directory cannot be read
void loadTestData(const std::string& baseDirTest, std::vector<std::string>& testFilePaths) {
    const auto firstNew = static_cast<std::vector<std::string>::difference_type>(testFilePaths.size());

    for (const auto& entry : std::filesystem::directory_iterator(baseDirTest)) {
        // Skip directories (or other non-files) whose names merely end in ".txt".
        if (entry.is_regular_file() && entry.path().extension() == ".txt") {
            testFilePaths.push_back(entry.path().string());
        }
    }

    std::sort(testFilePaths.begin() + firstNew, testFilePaths.end());
}
|
|
|
|
int main() {
|
|
// 创建训练数据
|
|
std::vector<Feature> trainingFeatures;
|
|
std::vector<int> trainingLabels;
|
|
|
|
// 训练数据路径
|
|
std::string baseDirTrain = "C:\\Users\\DELL\\Desktop\\期末三期\\10_10";
|
|
|
|
// 加载训练数据
|
|
loadTrainingData(baseDirTrain, trainingFeatures, trainingLabels);
|
|
|
|
// 创建数字识别器对象
|
|
DigitRecognizer digitRecognizer(trainingFeatures, trainingLabels);
|
|
|
|
// 训练数据路径
|
|
std::string baseDirTest = "C:\\Users\\DELL\\Desktop\\期末三期\\txt\\2";
|
|
|
|
// 加载测试数据
|
|
std::vector<std::string> testFilePaths;
|
|
loadTestData(baseDirTest, testFilePaths);
|
|
|
|
// 执行数字识别并计算正确率
|
|
int correctCount = 0;
|
|
for (size_t i = 0; i < testFilePaths.size(); ++i) {
|
|
int recognizedDigit = digitRecognizer.recognizeDigit(testFilePaths[i]);
|
|
if (recognizedDigit == 2) {
|
|
correctCount++;
|
|
}
|
|
}
|
|
|
|
double accuracy = static_cast<double>(correctCount) / testFilePaths.size();
|
|
|
|
// 输出识别结果
|
|
std::cout << "2 Recognition Accuracy: " << accuracy * 100 << "%" << std::endl;
|
|
|
|
return 0;
|
|
}
|