You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

270 lines
7.7 KiB

#include <algorithm>
#include <cmath>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
// Image: a 2-D grid of integer pixel values, stored row by row.
class Image {
private:
    std::vector<std::vector<int>> pixels; // pixels[y][x]: outer index = row (y), inner index = column (x)
    int width;                            // number of columns
    int height;                           // number of rows
public:
    /// @param pixels  pixel rows (outer vector = rows, inner vector = columns)
    /// @param width   number of columns in each row
    /// @param height  number of rows
    Image(std::vector<std::vector<int>> pixels, int width, int height)
        : pixels(std::move(pixels)), width(width), height(height) {}

    /// Returns the value of the pixel at column @p x, row @p y.
    /// Rows are stored first, so the row index y selects the inner vector and the
    /// column index x selects the element within it.
    /// @pre 0 <= x < getWidth() and 0 <= y < getHeight()
    int getPixelValue(int x, int y) const {
        return pixels[y][x];
    }

    /// Returns the number of columns.
    int getWidth() const {
        return width;
    }

    /// Returns the number of rows.
    int getHeight() const {
        return height;
    }
};
// ImageProcessor: loads digit images stored as whitespace-separated text files.
class ImageProcessor {
public:
    /// Loads the image at @p imagePath and prepares it for feature extraction.
    /// Currently the only preprocessing step is parsing the text file.
    /// @throws std::runtime_error if the file cannot be opened or is malformed.
    Image preprocessImage(const std::string& imagePath) const {
        return loadTXT(imagePath);
    }
private:
    /// Parses a text image: one pixel row per line, integer values separated by
    /// whitespace. Lines that contain no values (blank or whitespace-only, e.g. a
    /// trailing empty line) are skipped instead of being treated as empty rows.
    /// NOTE(review): parsing of a line stops at the first token that is not a valid
    /// int; if a data file stores a row as one unseparated digit string this format
    /// assumption does not hold — confirm against the data files.
    /// @throws std::runtime_error if the file cannot be opened, rows differ in
    ///         length, or the file contains no pixel values at all.
    static Image loadTXT(const std::string& imagePath) {
        std::ifstream file(imagePath);
        if (!file.is_open()) {
            throw std::runtime_error("Failed to open TXT image file: " + imagePath);
        }
        std::vector<std::vector<int>> pixels;
        std::size_t width = 0; // length of the first non-empty row; every row must match it
        std::string line;
        while (std::getline(file, line)) {
            std::istringstream iss(line);
            std::vector<int> row;
            int value = 0;
            while (iss >> value) {
                row.push_back(value);
            }
            if (row.empty()) {
                continue; // blank / whitespace-only line
            }
            if (pixels.empty()) {
                width = row.size();
            } else if (row.size() != width) {
                throw std::runtime_error("Invalid TXT image format: inconsistent row lengths");
            }
            pixels.push_back(std::move(row));
        }
        if (pixels.empty()) {
            throw std::runtime_error("Invalid TXT image format: no pixel data in " + imagePath);
        }
        // Compute the height before pixels is moved into the Image.
        const int height = static_cast<int>(pixels.size());
        return Image(std::move(pixels), static_cast<int>(width), height);
    }
};
// Feature: a vector of numeric feature values describing one sample.
class Feature {
private:
    std::vector<double> values; // feature values
public:
    /// Takes ownership of @p values (moved in, not copied a second time).
    Feature(std::vector<double> values) : values(std::move(values)) {}

    /// Returns the feature value at @p index.
    /// @pre index < getSize()
    double getValue(size_t index) const {
        return values[index];
    }

    /// Returns the number of feature values.
    size_t getSize() const {
        return values.size();
    }
};
// FeatureExtractor: turns an Image into a Feature vector.
class FeatureExtractor {
public:
    /// Flattens @p image into a feature vector: for each row y in order, every
    /// column x of that row in order, each pixel value converted to double.
    Feature extractFeature(const Image& image) const {
        const int width = image.getWidth();
        const int height = image.getHeight();
        std::vector<double> features;
        if (width > 0 && height > 0) {
            // One allocation instead of repeated growth.
            features.reserve(static_cast<size_t>(width) * static_cast<size_t>(height));
        }
        for (int y = 0; y < height; ++y) {
            for (int x = 0; x < width; ++x) {
                features.push_back(static_cast<double>(image.getPixelValue(x, y)));
            }
        }
        // Move into Feature's by-value parameter to avoid copying the vector.
        return Feature(std::move(features));
    }
};
// DistanceMetric: Euclidean distance between two feature vectors.
class DistanceMetric {
public:
    /// Returns the Euclidean (L2) distance between @p feature1 and @p feature2.
    /// @throws std::runtime_error if the two vectors differ in length.
    double calculateDistance(const Feature& feature1, const Feature& feature2) const {
        if (feature1.getSize() != feature2.getSize()) {
            throw std::runtime_error("Feature vectors must be of the same size");
        }
        double sumOfSquares = 0.0;
        for (size_t i = 0; i < feature1.getSize(); ++i) {
            const double diff = feature1.getValue(i) - feature2.getValue(i);
            sumOfSquares += diff * diff; // plain multiply: far cheaper than pow(diff, 2)
        }
        return std::sqrt(sumOfSquares);
    }
};
// KNN ģ<><C4A3>
class KNNModel {
private:
std::vector<Feature> features; // <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
std::vector<int> labels; // <20><>ǩ
DistanceMetric distanceMetric; // <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
public:
// ѵ<><D1B5> KNN ģ<><C4A3>
void train(const std::vector<Feature>& features, const std::vector<int>& labels) {
this->features = features;
this->labels = labels;
}
// ʹ<><CAB9> KNN ģ<>ͽ<EFBFBD><CDBD><EFBFBD>Ԥ<EFBFBD><D4A4>
int predict(const Feature& feature, int k) {
// <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ԥ<EFBFBD><D4A4><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ѵ<EFBFBD><D1B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ľ<EFBFBD><C4BE><EFBFBD><EBA3AC><EFBFBD><EFBFBD>¼<EFBFBD><C2BC><EFBFBD><EFBFBD>
std::vector<std::pair<double, int>> distances;
for (size_t i = 0; i < features.size(); ++i) {
double distance = distanceMetric.calculateDistance(features[i], feature);
distances.push_back(std::make_pair(distance, i));
}
// <20><><EFBFBD>ݾ<EFBFBD><DDBE><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ѡ<EFBFBD><D1A1><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> K <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
std::sort(distances.begin(), distances.end());
// ͳ<><CDB3><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> K <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>и<EFBFBD><D0B8><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵĴ<D6B5><C4B4><EFBFBD>
std::unordered_map<int, int> classCount;
for (int i = 0; i < k; ++i) {
int index = distances[i].second;
int label = labels[index];
classCount[label]++;
}
// <20>ҳ<EFBFBD><D2B3><EFBFBD>Ƶ<EFBFBD><C6B5><EFBFBD><EFBFBD><EFBFBD>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD><EFBFBD>
int maxCount = 0;
int predictedLabel = -1;
for (auto& pair : classCount) {
if (pair.second > maxCount) {
maxCount = pair.second;
predictedLabel = pair.first;
}
}
return predictedLabel;
}
};
// <20><><EFBFBD><EFBFBD>ʶ<EFBFBD><CAB6><EFBFBD><EFBFBD>
class DigitRecognizer {
private:
ImageProcessor imageProcessor;
FeatureExtractor featureExtractor;
KNNModel knnModel;
public:
// <20><><EFBFBD><EFBFBD><ECBAAF>
DigitRecognizer(const std::vector<Feature>& trainingFeatures, const std::vector<int>& trainingLabels) {
knnModel.train(trainingFeatures, trainingLabels);
}
// ִ<><D6B4><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ʶ<EFBFBD><CAB6>
int recognizeDigit(const std::string& imagePath) {
// <20><><EFBFBD><EFBFBD>ͼ<EFBFBD>񲢽<EFBFBD><F1B2A2BD><EFBFBD>Ԥ<EFBFBD><D4A4><EFBFBD><EFBFBD>
Image image = imageProcessor.preprocessImage(imagePath);
// <20><>ȡͼ<C8A1><CDBC><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
Feature feature = featureExtractor.extractFeature(image);
// ʹ<><CAB9> KNN <20><EFBFBD><E3B7A8><EFBFBD><EFBFBD>Ԥ<EFBFBD><D4A4>
int predictedDigit = knnModel.predict(feature, 5 /* K ֵ */);
return predictedDigit;
}
};
// <20><><EFBFBD><EFBFBD>ѵ<EFBFBD><D1B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
void loadTrainingData(const std::string& baseDirTrain, std::vector<Feature>& trainingFeatures, std::vector<int>& trainingLabels) {
ImageProcessor imageProcessor;
FeatureExtractor featureExtractor;
for (int label = 0; label <= 9; ++label) {
std::string dirPath = baseDirTrain + "\\" + std::to_string(label);
for (int i = 1; i <= 10; ++i) {
std::stringstream ss;
ss << dirPath << "\\img" << label << i << ".txt";
std::string filePath = ss.str();
Image image = imageProcessor.preprocessImage(filePath);
Feature feature = featureExtractor.extractFeature(image);
trainingFeatures.push_back(feature);
trainingLabels.push_back(label);
}
}
}
// Appends the paths of all regular ".txt" files located directly inside
// baseDirTest (non-recursive) to testFilePaths, in lexicographic order.
// Entries already present in testFilePaths are kept.
// @throws std::filesystem::filesystem_error if the directory cannot be read.
void loadTestData(const std::string& baseDirTest, std::vector<std::string>& testFilePaths) {
    std::vector<std::string> found;
    for (const auto& entry : std::filesystem::directory_iterator(baseDirTest)) {
        // Skip directories and other non-files whose names merely end in ".txt".
        if (entry.is_regular_file() && entry.path().extension() == ".txt") {
            found.push_back(entry.path().string());
        }
    }
    // directory_iterator order is unspecified; sort so runs are reproducible.
    std::sort(found.begin(), found.end());
    for (std::string& path : found) {
        testFilePaths.push_back(std::move(path));
    }
}
// Entry point: trains the recognizer on the labelled training set, then reports
// the percentage of images in the test directory that are recognised as digit 2.
// Returns 1 (with a message on stderr) if no test images are found or if any
// loading / recognition step throws.
int main() {
    // Every image in the test directory is expected to show this digit.
    const int expectedDigit = 2;
    try {
        std::vector<Feature> trainingFeatures;
        std::vector<int> trainingLabels;
        // Training data directory
        const std::string baseDirTrain = "C:\\Users\\DELL\\Desktop\\<EFBFBD><EFBFBD>ĩ<EFBFBD><EFBFBD><EFBFBD><EFBFBD>\\10_10";
        loadTrainingData(baseDirTrain, trainingFeatures, trainingLabels);
        DigitRecognizer digitRecognizer(trainingFeatures, trainingLabels);

        // Test data directory
        const std::string baseDirTest = "C:\\Users\\DELL\\Desktop\\<EFBFBD><EFBFBD>ĩ<EFBFBD><EFBFBD><EFBFBD><EFBFBD>\\txt\\2";
        std::vector<std::string> testFilePaths;
        loadTestData(baseDirTest, testFilePaths);
        if (testFilePaths.empty()) {
            // Avoid reporting 0/0 (NaN) as an accuracy.
            std::cerr << "No test images found in " << baseDirTest << '\n';
            return 1;
        }

        int correctCount = 0;
        for (const std::string& path : testFilePaths) {
            if (digitRecognizer.recognizeDigit(path) == expectedDigit) {
                ++correctCount;
            }
        }
        const double accuracy = static_cast<double>(correctCount) / testFilePaths.size();
        std::cout << "2 Recognition Accuracy: " << accuracy * 100 << "%" << std::endl;
    } catch (const std::exception& e) {
        // Report file / format errors instead of terminating via an uncaught exception.
        std::cerr << "Error: " << e.what() << '\n';
        return 1;
    }
    return 0;
}