From 48ad5116a878e1cdd84dfe54f39163680275a6f9 Mon Sep 17 00:00:00 2001 From: sztu202200202053 Date: Mon, 24 Jun 2024 15:14:57 +0800 Subject: [PATCH] ADD file via upload --- model_core.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 model_core.py diff --git a/model_core.py b/model_core.py new file mode 100644 index 0000000..acf6e16 --- /dev/null +++ b/model_core.py @@ -0,0 +1,66 @@ +import tensorflow as tf +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, MaxPooling2D, Flatten, Dense, Dropout +from tensorflow.keras.optimizers import Adam +import numpy as np +from PIL import Image + + +class LeNet5Custom: + def __init__(self, dropout_rate): + self.dropout_rate = dropout_rate + self.model = self._create_model() + + + def _create_model(self): + model = Sequential([ + Input(shape=(28, 28, 1)), # 明确指定输入形状 + Conv2D(32, (3, 3), activation='relu'), + BatchNormalization(), + MaxPooling2D((2, 2)), + + Conv2D(64, (3, 3), activation='relu'), + BatchNormalization(), + MaxPooling2D((2, 2)), + + Flatten(), + + Dense(128, activation='relu'), + BatchNormalization(), + Dropout(self.dropout_rate), + + Dense(10, activation='softmax') + ]) + return model + + + def compile_model(self): + self.model.compile(optimizer=Adam(), + loss='categorical_crossentropy', + metrics=['accuracy']) + return self.model + + + def load_model(self, model_path): + self.model = tf.keras.models.load_model(model_path) + for layer in self.model.layers: + if isinstance(layer, Dropout): + layer.rate = self.dropout_rate + return self.model + + + def preprocess_image(self, image_path): + img = Image.open(image_path).convert('L') + img = img.resize((28, 28)) + img_array = np.array(img) + img_array[img_array != 255] = 0 # 将像素值非255的全部设为0 + img_array = img_array.astype("float32") / 255 + img_array = np.expand_dims(img_array, axis=-1) + img_array = np.expand_dims(img_array, axis=0) + return img_array + + + def predict(self, image_path): + preprocessed_image = self.preprocess_image(image_path) + predictions = self.model.predict(preprocessed_image) + return predictions