diff --git a/newapp.py b/newapp.py new file mode 100644 index 0000000..526279c --- /dev/null +++ b/newapp.py @@ -0,0 +1,115 @@ +import sys +import streamlit as st +from PIL import Image, ImageOps +import tensorflow +import numpy as np +import base64 +from io import BytesIO +import joblib +import os +from streamlit_drawable_canvas import st_canvas + + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + +# 页面布局 +st.set_page_config(page_title="图像分类平台", page_icon="🔬", layout="wide") + +# 页面布局 +st.title('图像分类平台') +st.header('手写数字识别') +# 创建画布 +canvas = st_canvas( + fill_color="#FFFFFF", # 画布背景色 + stroke_color="#000000", # 笔触颜色 + height=300, # 画布高度 + width=300, # 画布宽度 + drawing_mode="freedraw", # 绘制模式 + key='canvas' +) + +# 添加提交按钮 +user_drew = st.button("提交并预测数字") + +# 加载模型 +model_path = os.path.join(BASE_DIR, "model/number_model.h5") +if os.path.isfile(model_path): + try: + num_model = tensorflow.keras.models.load_model(model_path, compile=True) + except Exception as e: + st.error(f"加载模型时发生错误: {e}") + num_model = None +else: + st.error(f"模型文件不存在: {model_path}") + +num_flag = 1 + +# 执行预测 +if user_drew: + if canvas is not None and canvas.image_data is not None: + try: + # 将 NumPy 数组转换为 PIL 图像 + image = canvas.image_data + # 检查 canvas.image_data 是否是有效的图像数据 + print("Canvas image data shape:", canvas.image_data.shape) + print("Canvas image data dtype:", canvas.image_data.dtype) + + if canvas.image_data.shape[-1] == 4: + my_image = canvas.image_data[..., 3:] + else: + my_image = canvas.image_data + + # 显示用户绘制的图像 + st.image(my_image, caption='您绘制的数字') + + print("my_image shape:", my_image.shape) + + # 创建一个新的 PIL 图像,模式设置为 'L'(灰度) + pil_image = Image.new('L', (my_image.shape[1], my_image.shape[0])) + + # 将 my_image 的数据复制到 PIL 图像中 + pil_image.putdata(my_image.reshape(-1)) + + image = pil_image.resize((28, 28), Image.Resampling.LANCZOS) # 调整大小 + + st.image(image, caption='调整大小后') + + # 归一化图像数据 + image_array = np.array(image) / 255.0 + image_array = np.expand_dims(image_array, axis=-1) # 添加通道维度 + image_array = np.expand_dims(image_array, axis=0) # 添加批次维度 + + # 显示用户绘制的图像 + st.image(image_array[0, :, :, 0], caption='处理后的图像') + + # 打印图像数组的形状和数据类型 + print("Image array shape:", image_array.shape) + print("Image array dtype:", image_array.dtype) + + # 打印最小和最大像素值 + print("Min pixel value:", image_array.min()) + print("Max pixel value:", image_array.max()) + + # 使用模型进行预测 + num_class_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + predictions = num_model.predict(image_array)[0] + st.write(f"predictions:{predictions}") + predicted_class_index = np.argmax(predictions) + st.write(f"predicted_class_index:{np.argmax(predictions)}") + predicted_class = num_class_labels[predicted_class_index] + st.write(f"predicted_class:{num_class_labels[predicted_class_index]}") + + # 获取预测的概率值 + predicted_probabilities = predictions * 100 + + st.write(f"对于您绘制的数字的预测结果是:") + st.write(f"类别:'{predicted_class}' 概率:{predicted_probabilities[predicted_class_index]:.2f}") + + except Exception as e: + # 显示错误信息 + st.error("图像处理出错") + st.exception(e) + num_flag = 0 + else: + st.warning("没有检测到图像数据。请在画布上绘制数字。") + num_flag = 0 \ No newline at end of file