parent
5e4283231c
commit
9f192e0a02
@ -0,0 +1,242 @@
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <filesystem>
|
||||
|
||||
// 定义图像类
|
||||
class Image {
|
||||
private:
|
||||
std::vector<std::vector<int>> pixels; // 像素数组
|
||||
int width; // 图像宽度
|
||||
int height; // 图像高度
|
||||
|
||||
public:
|
||||
Image(std::vector<std::vector<int>> pixels, int width, int height) : pixels(pixels), width(width), height(height) {}
|
||||
|
||||
// 获取像素值
|
||||
int getPixelValue(int x, int y) const {
|
||||
return pixels[x][y];
|
||||
}
|
||||
|
||||
// 获取图像宽度
|
||||
int getWidth() const {
|
||||
return width;
|
||||
}
|
||||
|
||||
// 获取图像高度
|
||||
int getHeight() const {
|
||||
return height;
|
||||
}
|
||||
};
|
||||
|
||||
class ImageProcessor {
|
||||
public:
|
||||
// 加载图像并进行预处理
|
||||
Image preprocessImage(const std::string& imagePath) {
|
||||
return loadTXT(imagePath);
|
||||
}
|
||||
|
||||
private:
|
||||
// 加载TXT图像
|
||||
Image loadTXT(const std::string& imagePath) {
|
||||
std::ifstream file(imagePath);
|
||||
if (!file.is_open()) {
|
||||
throw std::runtime_error("Failed to open TXT image file: " + imagePath);
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> pixels;
|
||||
std::string line;
|
||||
int width = 0;
|
||||
while (std::getline(file, line)) {
|
||||
std::istringstream iss(line);
|
||||
std::vector<int> row;
|
||||
int value;
|
||||
while (iss >> value) {
|
||||
row.push_back(value);
|
||||
}
|
||||
pixels.push_back(row);
|
||||
if (width == 0) {
|
||||
width = row.size();
|
||||
}
|
||||
else if (row.size() != width) {
|
||||
throw std::runtime_error("Invalid TXT image format: inconsistent row lengths");
|
||||
}
|
||||
}
|
||||
|
||||
// Create and return Image object
|
||||
return Image(pixels, width, pixels.size());
|
||||
}
|
||||
};
|
||||
|
||||
// 特征类
|
||||
class Feature {
|
||||
private:
|
||||
std::vector<double> values; // 特征值
|
||||
|
||||
public:
|
||||
Feature(std::vector<double> values) : values(values) {}
|
||||
|
||||
// 获取特征值
|
||||
double getValue(int index) const {
|
||||
return values[index];
|
||||
}
|
||||
|
||||
// 获取特征值数量
|
||||
size_t getSize() const {
|
||||
return values.size();
|
||||
}
|
||||
};
|
||||
|
||||
// 特征提取器
|
||||
class FeatureExtractor {
|
||||
public:
|
||||
// 从图像中提取特征
|
||||
Feature extractFeature(const Image& image) {
|
||||
std::vector<double> features;
|
||||
for (int y = 0; y < image.getHeight(); ++y) {
|
||||
for (int x = 0; x < image.getWidth(); ++x) {
|
||||
features.push_back(static_cast<double>(image.getPixelValue(x, y)));
|
||||
}
|
||||
}
|
||||
return Feature(features);
|
||||
}
|
||||
};
|
||||
|
||||
// 距离度量类
|
||||
class DistanceMetric {
|
||||
public:
|
||||
// 计算两个特征之间的距离(欧氏距离)
|
||||
double calculateDistance(const Feature& feature1, const Feature& feature2) const {
|
||||
if (feature1.getSize() != feature2.getSize()) {
|
||||
throw std::runtime_error("Feature vectors must be of the same size");
|
||||
}
|
||||
double distance = 0.0;
|
||||
for (size_t i = 0; i < feature1.getSize(); ++i) {
|
||||
distance += pow(feature1.getValue(i) - feature2.getValue(i), 2);
|
||||
}
|
||||
return sqrt(distance);
|
||||
}
|
||||
};
|
||||
|
||||
// KNN 模型
|
||||
class KNNModel {
|
||||
private:
|
||||
std::vector<Feature> features; // 特征向量
|
||||
std::vector<int> labels; // 标签
|
||||
DistanceMetric distanceMetric; // 距离度量
|
||||
|
||||
public:
|
||||
// 训练 KNN 模型
|
||||
void train(const std::vector<Feature>& features, const std::vector<int>& labels) {
|
||||
this->features = features;
|
||||
this->labels = labels;
|
||||
}
|
||||
|
||||
// 使用 KNN 模型进行预测
|
||||
int predict(const Feature& feature, int k) {
|
||||
// 计算待预测特征与训练集中所有特征的距离,并记录索引
|
||||
std::vector<std::pair<double, int>> distances;
|
||||
for (size_t i = 0; i < features.size(); ++i) {
|
||||
double distance = distanceMetric.calculateDistance(features[i], feature);
|
||||
distances.push_back(std::make_pair(distance, i));
|
||||
}
|
||||
|
||||
// 根据距离排序,选择最近的 K 个样本
|
||||
std::sort(distances.begin(), distances.end());
|
||||
|
||||
// 统计最近的 K 个样本中各类别出现的次数
|
||||
std::unordered_map<int, int> classCount;
|
||||
for (int i = 0; i < k; ++i) {
|
||||
int index = distances[i].second;
|
||||
int label = labels[index];
|
||||
classCount[label]++;
|
||||
}
|
||||
|
||||
// 找出最频繁出现的类别
|
||||
int maxCount = 0;
|
||||
int predictedLabel = -1;
|
||||
for (auto& pair : classCount) {
|
||||
if (pair.second > maxCount) {
|
||||
maxCount = pair.second;
|
||||
predictedLabel = pair.first;
|
||||
}
|
||||
}
|
||||
|
||||
return predictedLabel;
|
||||
}
|
||||
};
|
||||
|
||||
// 数字识别器
|
||||
class DigitRecognizer {
|
||||
private:
|
||||
ImageProcessor imageProcessor;
|
||||
FeatureExtractor featureExtractor;
|
||||
KNNModel knnModel;
|
||||
|
||||
public:
|
||||
// 构造函数
|
||||
DigitRecognizer(const std::vector<Feature>& trainingFeatures, const std::vector<int>& trainingLabels) {
|
||||
knnModel.train(trainingFeatures, trainingLabels);
|
||||
}
|
||||
|
||||
// 执行数字识别
|
||||
int recognizeDigit(const std::string& imagePath) {
|
||||
// 加载图像并进行预处理
|
||||
Image image = imageProcessor.preprocessImage(imagePath);
|
||||
|
||||
// 提取图像特征
|
||||
Feature feature = featureExtractor.extractFeature(image);
|
||||
|
||||
// 使用 KNN 算法进行预测
|
||||
int predictedDigit = knnModel.predict(feature, 3 /* K 值 */);
|
||||
|
||||
return predictedDigit;
|
||||
}
|
||||
};
|
||||
|
||||
// 加载训练数据
|
||||
void loadTrainingData(const std::string& baseDir, std::vector<Feature>& trainingFeatures, std::vector<int>& trainingLabels) {
|
||||
ImageProcessor imageProcessor;
|
||||
FeatureExtractor featureExtractor;
|
||||
|
||||
for (int label = 0; label <= 9; ++label) {
|
||||
std::string dirPath = baseDir + "\\" + std::to_string(label);
|
||||
for (int i = 1; i <= 10; ++i) {
|
||||
std::stringstream ss;
|
||||
ss << dirPath << "\\img" << label << i << ".txt";
|
||||
std::string filePath = ss.str();
|
||||
Image image = imageProcessor.preprocessImage(filePath);
|
||||
Feature feature = featureExtractor.extractFeature(image);
|
||||
trainingFeatures.push_back(feature);
|
||||
trainingLabels.push_back(label);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
// 创建训练数据
|
||||
std::vector<Feature> trainingFeatures;
|
||||
std::vector<int> trainingLabels;
|
||||
|
||||
// 训练数据路径
|
||||
std::string baseDir = "C:\\Users\\DELL\\Desktop\\期末三期\\10_10";
|
||||
|
||||
// 加载训练数据
|
||||
loadTrainingData(baseDir, trainingFeatures, trainingLabels);
|
||||
|
||||
// 创建数字识别器对象
|
||||
DigitRecognizer digitRecognizer(trainingFeatures, trainingLabels);
|
||||
|
||||
// 执行数字识别
|
||||
std::string imagePath = "C:\\Users\\DELL\\Desktop\\期末三期\\txt\\9\\img9286.txt"; // 替换为实际的图像路径
|
||||
int recognizedDigit = digitRecognizer.recognizeDigit(imagePath);
|
||||
|
||||
// 输出识别结果
|
||||
std::cout << "Recognized Digit: " << recognizedDigit << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Reference in new issue