diff --git a/mnistset.py b/mnistset.py
new file mode 100644
index 0000000..a656f65
--- /dev/null
+++ b/mnistset.py
@@ -0,0 +1,302 @@
+########## MNIST handwritten digit dataset ##########
+########## Train, save and reuse the model ##########
+########## One hidden (fully connected) layer ##########
+# 60,000 training samples and 10,000 test samples, 28x28 grayscale images
+# Hidden layer activation: ReLU
+# Output layer activation: softmax (multi-class classification)
+# Loss function: sparse categorical cross-entropy
+# 784 input nodes, 128 hidden neurons, 10 output nodes
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import numpy as np
+import tkinter as tk
+from tkinter import filedialog
+import cv2
+import argparse
+import imutils
+from imutils import contours
+import time
+
+print('--------------')
+nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
+print(nowtime)
+
+# Pin the job to a specific GPU (optional)
+#import os
+#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+#gpus = tf.config.experimental.list_physical_devices('GPU')
+#tf.config.experimental.set_memory_growth(gpus[0],True)
+
+# Matplotlib font that can render CJK labels
+plt.rcParams['font.sans-serif'] = ['SimHei']
+
+# Load the data
+mnist = tf.keras.datasets.mnist
+(train_x,train_y),(test_x,test_y) = mnist.load_data()
+print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))
+
+# Data preprocessing
+#X_train = train_x.reshape((60000,28*28))   # not needed: tf.keras.layers.Flatten() reshapes the input later
+X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)   # scale pixels to [0, 1]
+y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)
+
+# Build the model
+model = tf.keras.Sequential()
+model.add(tf.keras.layers.Flatten(input_shape=(28,28)))    # Flatten layer declares the input shape
+model.add(tf.keras.layers.Dense(128,activation='relu'))    # hidden fully connected layer, 128 nodes, ReLU
+model.add(tf.keras.layers.Dense(10,activation='softmax'))  # output fully connected layer, 10 nodes, softmax
+model.summary()   # print the network structure and parameter counts
+
+# Configure the training method
+# Adam optimizer with the Keras defaults, sparse categorical cross-entropy loss, sparse categorical accuracy metric
+model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])
+
+# Train the model
+# Batch size 64, 5 epochs, validation split 0.2 (48,000 training samples, 12,000 validation samples)
+print('--------------')
+nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
+print('Time before training: '+str(nowtime))
+
+history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)
+
+print('--------------')
+nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
+print('Time after training: '+str(nowtime))
+
+# Evaluate the model on the held-out test set to check generalization
+model.evaluate(X_test,y_test,verbose=2)
+
+# Save only the model weights
+#model.save_weights('C:\\Users\\xuyansong\\Desktop\\深度学习\\python\\MNIST\\模型参数\\mnist_weights.h5')
+# Save the whole model
+model.save('mnist_weights.h5')
+
+# Visualize the training curves
+print(history.history)
+loss = history.history['loss']          # training loss
+val_loss = history.history['val_loss']  # validation loss
+acc = history.history['sparse_categorical_accuracy']          # training accuracy
+val_acc = history.history['val_sparse_categorical_accuracy']  # validation accuracy
+
+plt.figure(figsize=(10,3))
+
+plt.subplot(121)
+plt.plot(loss,color='b',label='train')
+plt.plot(val_loss,color='r',label='validation')
+plt.ylabel('loss')
+plt.legend()
+
+plt.subplot(122)
+plt.plot(acc,color='b',label='train')
+plt.plot(val_acc,color='r',label='validation')
+plt.ylabel('Accuracy')
+plt.legend()
+
+# Close the figure after 5 seconds; an open figure keeps holding GPU memory
+# Enable as needed
+#plt.ion()       # interactive mode
+#plt.show()
+#plt.pause(5)
+#plt.close()
+
+# Use the model on a few random test images (optional demo)
+#plt.figure()
+#for i in range(10):
+#    num = np.random.randint(1,10000)
+
+#    plt.subplot(2,5,i+1)
+#    plt.axis('off')
+#    plt.imshow(test_x[num],cmap='gray')
+#    demo = tf.reshape(X_test[num],(1,28,28))
+#    y_pred = np.argmax(model.predict(demo))
+#    plt.title('label: '+str(test_y[num])+'\nprediction: '+str(y_pred))
+#y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
+#print('X_test[0:5]: %s'%(X_test[0:5].shape))
+#print('y_pred: %s'%(y_pred))
+
+#plt.ion()       # interactive mode
+#plt.show()
+#plt.pause(5)
+#plt.close()
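+
+# A minimal sketch (commented out) of reloading the model saved above in a
+# separate run instead of reusing the in-memory `model`; the file name matches
+# the model.save() call and `reloaded` is just a hypothetical variable name.
+#reloaded = tf.keras.models.load_model('mnist_weights.h5')
+#print(np.argmax(reloaded.predict(X_test[:1]), axis=1))   # predicted class for one test image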
+
+
+# Create the tkinter root window and hide it immediately
+root = tk.Tk()
+root.withdraw()
+
+# Pop up a file dialog so the user can pick the template file
+#template_file_path = filedialog.askopenfilename(title="Select template file", filetypes=[("PNG files", "*.png"), ("JPEG files", "*.jpg"), ("All files", "*.*")])
+# Pop up a file dialog so the user can pick the credit card image
+image_file_path = filedialog.askopenfilename(title="Select credit card image", filetypes=[("PNG files", "*.png"), ("JPEG files", "*.jpg"), ("All files", "*.*")])
+
+# Read the template file and the credit card image from the selected paths
+#img = cv2.imread(template_file_path)
+image = cv2.imread(image_file_path)
+
+# Map the first digit of the card number to the card type
+FIRST_NUMBER={
+    "3":"American Express",
+    "4":"Visa",
+    "5":"MasterCard",
+    "6":"Discover Card"
+}
+
+# Show an image and wait for a key press
+def cv_show(name,img):
+    cv2.imshow(name,img)
+    cv2.waitKey(0)
+    cv2.destroyAllWindows()
+
+def preprocess_image(roi):
+    # Resize to the MNIST input size and scale pixel values to [0, 1]
+    roi = cv2.resize(roi, (28, 28))
+    roi = roi / 255.0
+    roi = roi.reshape(1, 28, 28)  # add a batch dimension to match the model's (28, 28) input shape
+    return roi
+
+def predict_digit(roi, model):
+    # Classify a single digit crop with the MNIST model
+    roi = preprocess_image(roi)
+    prediction = model.predict(roi)
+    digit = np.argmax(prediction)
+    return str(digit)
+
+# Read a template file (optional, unused below)
+#img=cv2.imread(template_file_path)
+#cv_show("img",img)
+# Grayscale
+#ref=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
+#cv_show('ref',ref)
+# Binary image
+#ref=cv2.threshold(ref,10,255,cv2.THRESH_BINARY_INV)[1]
+#cv_show('ref',ref)
+
+#model = tf.keras.models.load_model('mnist_weights.h5')
+
+# Find contours
+# cv2.findContours() expects a binary (black and white) image, not a grayscale one
+# cv2.RETR_EXTERNAL detects only the outer contours, cv2.CHAIN_APPROX_SIMPLE keeps only the end points
+# Each element of the returned list is one contour of the image
+
+#refCnts,hierarchy=cv2.findContours(ref.copy(),cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
+#cv2.drawContours(img,refCnts,-1,(0,0,255),3)
+#cv_show('img',img)
+#refCnts=contours.sort_contours(refCnts,method="left-to-right")[0]   # sort left to right, top to bottom
+#digits={}
+
+'''
+cv2.rectangle() arguments:
+  1st: img, the source image
+  2nd: (x, y), the top-left corner of the rectangle
+  3rd: (x+w, y+h), the bottom-right corner of the rectangle
+  4th: (0, 255, 0), the BGR color of the line
+'''
+# Loop over every template contour (optional, unused below)
+#for (i,c) in enumerate(refCnts):
+    # Compute the bounding box and resize it to a suitable size
+#    (x,y,w,h)=cv2.boundingRect(c)
+#    roi=ref[y:y+h,x:x+w]
+#    roi=cv2.resize(roi,(57,58))
+
+    # One template per digit
+#    digits[i]=roi
+
+# Build the morphological kernels
+rectKernel=cv2.getStructuringElement(cv2.MORPH_RECT,(9,3))
+sqKernel=cv2.getStructuringElement(cv2.MORPH_RECT,(5,5))
+
+# Read the input image and preprocess it
+image=cv2.imread(image_file_path)
+cv_show('image',image)
+image=imutils.resize(image,width=300)
+gray=cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
+cv_show('gray',gray)
+
+# Top-hat operation to highlight the brighter regions
+tophat=cv2.morphologyEx(gray,cv2.MORPH_TOPHAT,rectKernel)
+cv_show('tophat',tophat)
+
+# Compute the x-direction gradient and rescale it to 0-255
+gradX=cv2.Sobel(tophat,ddepth=cv2.CV_32F,dx=1,dy=0,ksize=1)
+gradX=np.absolute(gradX)
+(minVal,maxVal)=(np.min(gradX),np.max(gradX))
+gradX=(255*((gradX-minVal)/(maxVal-minVal)))
+gradX=gradX.astype("uint8")
+
+print(np.array(gradX).shape)
+cv_show('gradX',gradX)
+
+# Closing operation (dilate, then erode) to join the digits into blocks
+gradX=cv2.morphologyEx(gradX,cv2.MORPH_CLOSE,rectKernel)
+cv_show('gradX',gradX)
+# THRESH_OTSU finds a suitable threshold automatically (good for bimodal histograms); the threshold argument must be 0
+thresh=cv2.threshold(gradX,0,255,cv2.THRESH_BINARY|cv2.THRESH_OTSU)[1]
+cv_show('thresh',thresh)
+
+# Another closing operation
+thresh=cv2.morphologyEx(thresh,cv2.MORPH_CLOSE,sqKernel)
+cv_show('thresh',thresh)
+
+thresh=cv2.morphologyEx(thresh,cv2.MORPH_CLOSE,sqKernel)
+cv_show('thresh',thresh)
+
+# Find the contours of the candidate regions and draw them for inspection
+threshCnts,hierarchy=cv2.findContours(thresh.copy(),cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
+cnts=threshCnts
+cur_img=image.copy()
+cv2.drawContours(cur_img,cnts,-1,(0,0,255),3)
+cv_show('img',cur_img)
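+
+# Optional debugging sketch (commented out): print each candidate contour's
+# bounding box and aspect ratio. The width/height/aspect thresholds in the
+# filtering loop below assume the card was resized to width=300, so printing
+# these values helps re-tune the filter for other inputs.
+#for c in cnts:
+#    (x, y, w, h) = cv2.boundingRect(c)
+#    print('candidate box:', (x, y, w, h), 'aspect ratio:', w / float(h))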
+
+locs=[]
+
+# Loop over the contours and keep the ones that look like digit groups
+for (i,c) in enumerate(cnts):
+    # Compute the bounding rectangle
+    (x,y,w,h)=cv2.boundingRect(c)
+    ar=w/float(h)
+
+    # Keep only regions of a suitable size and aspect ratio; the values depend on the task, here each group holds four digits
+    if ar>2.5 and ar<4:
+        if (w>40 and w<55) and (h>10 and h<20):
+            # This region qualifies
+            locs.append((x,y,w,h))
+
+# Sort the matching regions from left to right
+locs=sorted(locs,key=lambda x:x[0])
+output=[]
+
+# Loop over the digits in each group
+for (i,(gX,gY,gW,gH)) in enumerate(locs):
+    # Initialize the list of digits for this group
+    groupOutput=[]
+
+    # Extract the group from the grayscale image with a small margin
+    group=gray[gY-5:gY+gH+5,gX-5:gX+gW+5]
+    cv_show('group',group)
+    # Preprocess: Otsu binarization
+    group=cv2.threshold(group,0,255,cv2.THRESH_BINARY|cv2.THRESH_OTSU)[1]
+    cv_show('group',group)
+    # Find the contour of each digit in the group and sort them left to right
+    digitCnts,hierarchy=cv2.findContours(group.copy(),cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
+    digitCnts=contours.sort_contours(digitCnts,method="left-to-right")[0]
+    # Classify every digit in the group
+    for c in digitCnts:
+        (x, y, w, h) = cv2.boundingRect(c)
+        roi = group[y:y + h, x:x + w]
+        roi = cv2.resize(roi, (28, 28))
+        cv_show('roi', roi)
+
+        digit = predict_digit(roi, model)
+        groupOutput.append(digit)
+
+    print("Recognized digits:", groupOutput)
+
+    # Draw the result on the image
+    cv2.rectangle(image,(gX-5,gY-5),(gX+gW+5,gY+gH+5),(0,0,255),1)
+    cv2.putText(image,"".join(groupOutput),(gX,gY-15),cv2.FONT_HERSHEY_SIMPLEX,0.65,(0,0,255),2)
+    # Collect the digits
+    output.extend(groupOutput)
+
+# Print the result
+#print("Credit Card Type: {}".format(FIRST_NUMBER[output[0]]))
+print("Credit Card #: {}".format("".join(output)))
+cv2.imshow("Image", image)
+cv2.waitKey()
+cv2.destroyAllWindows()
\ No newline at end of file