Compare commits

...

2 Commits

8
.idea/.gitignore vendored

@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="renderExternalDocumentation" value="true" />
</component>
</module>

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (hand)" project-jdk-type="Python SDK" />
</project>

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/hand.iml" filepath="$PROJECT_DIR$/.idea/hand.iml" />
</modules>
</component>
</project>

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PySciProjectComponent">
<option name="PY_SCI_VIEW" value="true" />
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
</component>
</project>

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

BIN
1.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

BIN
2.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 114 KiB

@ -1,2 +0,0 @@
# hand_gesturer_ecognition

Binary file not shown.

@ -1,26 +1,27 @@
import warnings
import threading #导入多线程模块
import threading
import cv2
import mediapipe as mp
import numpy as np
from tensorflow.keras.models import load_model
from tkinter import Tk, Canvas, Button, Label, LEFT, RIGHT, NW
from PIL import Image, ImageTk
# 禁用特定警告
warnings.filterwarnings("ignore", category=UserWarning, message='SymbolDatabase.GetPrototype() is deprecated')
# 初始化 MediaPipe 和 OpenCV
hands = None
mp_draw = mp.solutions.drawing_utils
cap = None
keep_running = False
paused = False
popup_open = False # 用于标记当前是否有弹窗打开
# 加载手势识别模型
model_path = 'D:/hand/hand_gesture_model.h5'
model = load_model(model_path)
# 定义手势类别
gesture_classes = ['00', '01', '02', '03', '04', '05', '06', '07', '08', '09']
def start_recognition(callback=None):
global keep_running, cap, hands
if cap is None or not cap.isOpened():
@ -32,9 +33,9 @@ def start_recognition(callback=None):
keep_running = True
threading.Thread(target=run_recognition, args=(callback,)).start()
def run_recognition(callback=None):
global keep_running
last_gesture = None
global keep_running, paused
while keep_running and cap.isOpened():
ret, img = cap.read()
@ -43,25 +44,30 @@ def run_recognition(callback=None):
img = cv2.flip(img, 1)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
results = hands.process(img_rgb)
total_raised_fingers = 0
if not paused:
results = hands.process(img_rgb)
total_raised_fingers = 0
if results.multi_hand_landmarks:
for handLms in results.multi_hand_landmarks:
mp_draw.draw_landmarks(img, handLms, mp.solutions.hands.HAND_CONNECTIONS)
gesture, raised_fingers = detect_gesture_and_fingers(handLms)
total_raised_fingers += raised_fingers
if gesture == "OK":
handle_ok_gesture()
if results.multi_hand_landmarks:
for handLms in results.multi_hand_landmarks:
mp_draw.draw_landmarks(img, handLms, mp.solutions.hands.HAND_CONNECTIONS)
_, raised_fingers = detect_gesture_and_fingers(handLms)
total_raised_fingers += raised_fingers
if total_raised_fingers > 0:
handle_finger_detection(total_raised_fingers)
cv2.putText(img, f'Total Raised Fingers: {total_raised_fingers}', (10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2, cv2.LINE_AA,)
cv2.putText(img, f'Total Raised Fingers: {total_raised_fingers}', (10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2, cv2.LINE_AA, )
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if callback:
callback(img)
stop_recognition()
#停止识别
def stop_recognition():
global keep_running, cap
keep_running = False
@ -71,30 +77,30 @@ def stop_recognition():
cv2.destroyAllWindows()
#释放摄像头资源
def release_camera():
global cap
if cap is not None and cap.isOpened():
cap.release()
cap = None
def detect_gesture_and_fingers(hand_landmarks):
# 手势识别
gesture_image = get_hand_image(hand_landmarks)
gesture = predict_gesture(gesture_image)
# 手指竖起数量检测
raised_fingers = count_raised_fingers(hand_landmarks)
if is_ok_gesture(hand_landmarks):
gesture = "OK"
return gesture, raised_fingers
def get_hand_image(hand_landmarks):
# 提取手部区域图像
# 示例实现,请根据你的需要进行调整
img = np.zeros((150, 150, 3), dtype=np.uint8) # 示例图像
img = np.zeros((150, 150, 3), dtype=np.uint8)
return img
def predict_gesture(img):
img = cv2.resize(img, (150, 150))
img_array = np.expand_dims(img, axis=0) / 255.0
@ -102,21 +108,19 @@ def predict_gesture(img):
predicted_class = gesture_classes[np.argmax(predictions)]
return predicted_class
def count_raised_fingers(hand_landmarks):
fingers_status = [0, 0, 0, 0, 0]
# 拇指
thumb_tip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.THUMB_TIP]
thumb_ip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.THUMB_IP]
thumb_mcp = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.THUMB_MCP]
thumb_cmc = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.THUMB_CMC]
# 计算拇指的角度
angle_thumb = calculate_angle(thumb_cmc, thumb_mcp, thumb_tip)
if angle_thumb > 160: # 如果拇指的角度大于160度认为拇指竖起
if angle_thumb > 160:
fingers_status[0] = 1
# 其他手指
for i, finger_tip_id in enumerate([mp.solutions.hands.HandLandmark.INDEX_FINGER_TIP,
mp.solutions.hands.HandLandmark.MIDDLE_FINGER_TIP,
mp.solutions.hands.HandLandmark.RING_FINGER_TIP,
@ -125,17 +129,203 @@ def count_raised_fingers(hand_landmarks):
finger_pip = hand_landmarks.landmark[finger_tip_id - 2]
finger_mcp = hand_landmarks.landmark[finger_tip_id - 3]
# 计算手指的角度
angle_finger = calculate_angle(finger_mcp, finger_pip, finger_tip)
if angle_finger > 160: # 如果手指的角度大于160度认为手指竖起
if angle_finger > 160:
fingers_status[i + 1] = 1
return sum(fingers_status)
def calculate_angle(point1, point2, point3):
# 计算三个点之间的角度
angle = np.arctan2(point3.y - point2.y, point3.x - point2.x) - np.arctan2(point1.y - point2.y, point1.x - point2.x)
angle = np.abs(angle)
if angle > np.pi:
angle = 2 * np.pi - angle
return angle * 180 / np.pi
def is_ok_gesture(hand_landmarks):
thumb_tip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.THUMB_TIP]
index_tip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.INDEX_FINGER_TIP]
distance = np.linalg.norm(np.array([thumb_tip.x, thumb_tip.y]) - np.array([index_tip.x, index_tip.y]))
# 检查其他手指是否弯曲
middle_tip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.MIDDLE_FINGER_TIP]
ring_tip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.RING_FINGER_TIP]
pinky_tip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.PINKY_TIP]
middle_pip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.MIDDLE_FINGER_PIP]
ring_pip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.RING_FINGER_PIP]
pinky_pip = hand_landmarks.landmark[mp.solutions.hands.HandLandmark.PINKY_FINGER_PIP]
middle_finger_bent = middle_tip.y > middle_pip.y
ring_finger_bent = ring_tip.y > ring_pip.y
pinky_finger_bent = pinky_tip.y > pinky_pip.y
return distance < 0.05 and middle_finger_bent and ring_finger_bent and pinky_finger_bent # 根据实际情况调整这个阈值
def handle_ok_gesture():
global paused, popup_open
if not popup_open:
paused = True
popup_open = True
show_ok_window()
def show_ok_window():
def on_continue():
global paused, popup_open
paused = False
popup_open = False # 关闭弹窗后将标志设置为False
ok_window.destroy()
start_recognition(show_frame)
ok_window = Tk()
ok_window.title("手势检测")
label = Label(ok_window, text="检测到OK手势", font=('Helvetica', 24, 'bold'))
label.pack(pady=20)
continue_button = Button(ok_window, text="继续识别", command=on_continue)
continue_button.pack(pady=10)
ok_window.protocol("WM_DELETE_WINDOW", on_continue)
ok_window.mainloop()
def handle_finger_detection(finger_count):
global paused, popup_open
if not popup_open: # 只有在没有弹窗打开的情况下才处理手指检测并显示弹窗
if finger_count == 1:
paused = True
popup_open = True
show_finger_window("您竖起了一根手指")
elif finger_count == 2:
paused = True
popup_open = True
show_finger_window("您竖起了两根手指")
elif finger_count == 3:
paused = True
popup_open = True
show_finger_window("您竖起了三根手指")
elif finger_count == 4:
paused = True
popup_open = True
show_finger_window("您竖起了四根手指")
elif finger_count == 5:
paused = True
popup_open = True
show_stop_recognition_window()
elif finger_count == 6:
paused = True
popup_open = True
show_finger_window("您竖起了六根手指")
elif finger_count == 7:
paused = True
popup_open = True
show_finger_window("您竖起了七根手指")
elif finger_count == 8:
paused = True
popup_open = True
show_finger_window("您竖起了八根手指")
elif finger_count == 9:
paused = True
popup_open = True
show_finger_window("您竖起了九根手指")
elif finger_count == 10:
paused = True
popup_open = True
show_finger_window("您竖起了十根手指")
def show_finger_window(message):
def on_continue():
global paused, popup_open
paused = False
popup_open = False # 关闭弹窗后将标志设置为False
finger_window.destroy()
start_recognition(show_frame)
finger_window = Tk()
finger_window.title("手指检测")
label = Label(finger_window, text=message, font=('Helvetica', 24, 'bold'))
label.pack(pady=20)
continue_button = Button(finger_window, text="继续识别", command=on_continue)
continue_button.pack(pady=10)
finger_window.protocol("WM_DELETE_WINDOW", on_continue)
finger_window.mainloop()
def show_stop_recognition_window():
def on_continue():
global paused, popup_open
paused = False
popup_open = False # 关闭弹窗后将标志设置为False
stop_window.destroy()
start_recognition(show_frame)
def on_stop():
global popup_open
stop_recognition()
popup_open = False # 关闭弹窗后将标志设置为False
stop_window.destroy()
stop_window = Tk()
stop_window.title("停止识别")
label = Label(stop_window, text="您竖起了五根手指,是否停止识别?", font=('Helvetica', 24, 'bold'))
label.pack(pady=20)
continue_button = Button(stop_window, text="继续识别", command=on_continue)
continue_button.pack(side=LEFT, padx=10, pady=10)
stop_button = Button(stop_window, text="停止识别", command=on_stop)
stop_button.pack(side=RIGHT, padx=10, pady=10)
stop_window.protocol("WM_DELETE_WINDOW", on_continue)
stop_window.mainloop()
def show_frame(img=None):
global paused
if keep_running and cap.isOpened():
if img is not None:
frame = img
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
else:
ret, frame = cap.read()
if not ret:
return
frame = cv2.flip(frame, 1)
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img = Image.fromarray(frame_rgb)
imgtk = ImageTk.PhotoImage(image=img)
canvas.create_image(0, 0, anchor=NW, image=imgtk)
canvas.image = imgtk
if not paused:
root.after(10, show_frame)
else:
root.update_idletasks()
root.update()
if __name__ == "__main__":
root = Tk()
root.title("手势识别")
canvas = Canvas(root, width=640, height=480)
canvas.pack()
start_button = Button(root, text="开始识别", command=lambda: start_recognition(show_frame))
start_button.pack(side=LEFT, padx=10, pady=10)
stop_button = Button(root, text="停止识别", command=stop_recognition)
stop_button.pack(side=RIGHT, padx=10, pady=10)
root.mainloop()

@ -4,7 +4,6 @@ from tkinter import ttk
from PIL import Image, ImageTk
from gesture_recognition import start_recognition, stop_recognition, release_camera, keep_running
# 设置窗口宽高
WINDOW_WIDTH = 800
WINDOW_HEIGHT = 705
@ -23,7 +22,6 @@ def show_welcome_screen():
welcome.title("欢迎使用")
set_window_position(welcome, WINDOW_WIDTH, WINDOW_HEIGHT)
# 加载背景图片
bg_image = Image.open("1.png")
bg_image = bg_image.resize((WINDOW_WIDTH, WINDOW_HEIGHT), Image.Resampling.LANCZOS)
bg_image = ImageTk.PhotoImage(bg_image)
@ -41,7 +39,7 @@ def show_welcome_screen():
welcome.mainloop()
def show_main_screen():
global window
global window, canvas
window = Tk()
window.title("手势识别")
set_window_position(window, WINDOW_WIDTH, WINDOW_HEIGHT)
@ -69,7 +67,6 @@ def show_main_screen():
btn_exit = ttk.Button(frame_controls, text="退出", command=lambda: exit_program())
btn_exit.pack(side=LEFT, padx=10)
# 创建一个样式对象并设置按钮的样式
style = ttk.Style(window)
style.configure('TButton', font=('Helvetica', 14))

Binary file not shown.

@ -0,0 +1,14 @@
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 设置数据目录
data_dir = 'D:/hand/archive'
train_dir = os.path.join(data_dir, 'leapGestRecog')
validation_dir = os.path.join(data_dir, 'validation')
# 简单的标准化处理
datagen = ImageDataGenerator(rescale=1./255)
train_generator = datagen.flow_from_directory(train_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')
validation_generator = datagen.flow_from_directory(validation_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')

@ -0,0 +1,188 @@
import random
from math import sin, cos, pi, log
from tkinter import *
CANVAS_WIDTH = 500 # 画布的宽
CANVAS_HEIGHT = 400 # 画布的高
CANVAS_CENTER_X = CANVAS_WIDTH / 2 # 画布中心的X轴坐标
CANVAS_CENTER_Y = CANVAS_HEIGHT / 2 # 画布中心的Y轴坐标
IMAGE_ENLARGE = 11 # 放大比例
HEART_COLOR = "#ffb6c1"
def heart_function(t, shrink_ratio: float = IMAGE_ENLARGE):
"""
爱心函数生成器
:param shrink_ratio: 放大比例
:param t: 参数
:return: 坐标
"""
# 基础函数
x = 16 * (sin(t) ** 3)
y = -(13 * cos(t) - 5 * cos(2 * t) - 2 * cos(3 * t) - cos(4 * t))
# 放大
x *= shrink_ratio
y *= shrink_ratio
# 移到画布中央
x += CANVAS_CENTER_X
y += CANVAS_CENTER_Y
return int(x), int(y)
def scatter_inside(x, y, beta=0.15):
"""
随机内部扩散
:param x: 原x
:param y: 原y
:param beta: 强度
:return: 新坐标
"""
ratio_x = -beta * log(random.random())
ratio_y = -beta * log(random.random())
dx = ratio_x * (x - CANVAS_CENTER_X)
dy = ratio_y * (y - CANVAS_CENTER_Y)
return x - dx, y - dy
def shrink(x, y, ratio):
"""
抖动
:param x: 原x
:param y: 原y
:param ratio: 比例
:return: 新坐标
"""
force = -1 / (
((x - CANVAS_CENTER_X) ** 2 + (y - CANVAS_CENTER_Y) ** 2) ** 0.6
) # 这个参数...
dx = ratio * force * (x - CANVAS_CENTER_X)
dy = ratio * force * (y - CANVAS_CENTER_Y)
return x - dx, y - dy
def curve(p):
"""
自定义曲线函数调整跳动周期
:param p: 参数
:return: 正弦
"""
# 可以尝试换其他的动态函数,达到更有力量的效果(贝塞尔?)
return 2 * (2 * sin(4 * p)) / (2 * pi)
class Heart:
"""
爱心类
"""
def __init__(self, generate_frame=20):
self._points = set() # 原始爱心坐标集合
self._edge_diffusion_points = set() # 边缘扩散效果点坐标集合
self._center_diffusion_points = set() # 中心扩散效果点坐标集合
self.all_points = {} # 每帧动态点坐标
self.build(2000)
self.random_halo = 1000
self.generate_frame = generate_frame
for frame in range(generate_frame):
self.calc(frame)
def build(self, number):
# 爱心
for _ in range(number):
t = random.uniform(0, 2 * pi) # 随机不到的地方造成爱心有缺口
x, y = heart_function(t)
self._points.add((x, y))
# 爱心内扩散
for _x, _y in list(self._points):
for _ in range(3):
x, y = scatter_inside(_x, _y, 0.05)
self._edge_diffusion_points.add((x, y))
# 爱心内再次扩散
point_list = list(self._points)
for _ in range(4000):
x, y = random.choice(point_list)
x, y = scatter_inside(x, y, 0.17)
self._center_diffusion_points.add((x, y))
@staticmethod
def calc_position(x, y, ratio):
# 调整缩放比例
force = 1 / (
((x - CANVAS_CENTER_X) ** 2 + (y - CANVAS_CENTER_Y) ** 2) ** 0.520
) # 魔法参数
dx = ratio * force * (x - CANVAS_CENTER_X) + random.randint(-1, 1)
dy = ratio * force * (y - CANVAS_CENTER_Y) + random.randint(-1, 1)
return x - dx, y - dy
def calc(self, generate_frame):
ratio = 10 * curve(generate_frame / 10 * pi) # 圆滑的周期的缩放比例
halo_radius = int(4 + 6 * (1 + curve(generate_frame / 10 * pi)))
halo_number = int(3000 + 4000 * abs(curve(generate_frame / 10 * pi) ** 2))
all_points = []
# 光环
heart_halo_point = set() # 光环的点坐标集合
for _ in range(halo_number):
t = random.uniform(0, 2 * pi) # 随机不到的地方造成爱心有缺口
x, y = heart_function(t, shrink_ratio=11.6) # 魔法参数
x, y = shrink(x, y, halo_radius)
if (x, y) not in heart_halo_point:
# 处理新的点
heart_halo_point.add((x, y))
x += random.randint(-14, 14)
y += random.randint(-14, 14)
size = random.choice((1, 2, 2))
all_points.append((x, y, size))
# 轮廓
for x, y in self._points:
x, y = self.calc_position(x, y, ratio)
size = random.randint(1, 3)
all_points.append((x, y, size))
# 内容
for x, y in self._edge_diffusion_points:
x, y = self.calc_position(x, y, ratio)
size = random.randint(1, 2)
all_points.append((x, y, size))
for x, y in self._center_diffusion_points:
x, y = self.calc_position(x, y, ratio)
size = random.randint(1, 2)
all_points.append((x, y, size))
self.all_points[generate_frame] = all_points
def render(self, render_canvas, render_frame):
for x, y, size in self.all_points[render_frame % self.generate_frame]:
render_canvas.create_rectangle(
x, y, x + size, y + size, width=0, fill=HEART_COLOR
)
def draw(main: Tk, render_canvas: Canvas, render_heart: Heart, render_frame=0):
render_canvas.delete("all")
render_heart.render(render_canvas, render_frame)
main.after(160, draw, main, render_canvas, render_heart, render_frame + 1)
if __name__ == "__main__":
root = Tk() # 一个Tk
canvas = Canvas(root, bg="black", height=CANVAS_HEIGHT, width=CANVAS_WIDTH)
canvas.pack()
heart = Heart() # 心
draw(root, canvas, heart) # 开始画画~
root.mainloop()

@ -0,0 +1,42 @@
import os
import numpy as np
import matplotlib.pyplot as plt #绘画
from tensorflow.keras.preprocessing import image #图片预处理
from tensorflow.keras.models import load_model #加载模型
#生成图像数据
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 设置数据目录
test_dir = 'D:/hand/archive/leapGestRecog'
# 加载模型
model = load_model('hand_gesture_model.h5')
print("模型已从 hand_gesture_model.h5 加载")
# 使用 ImageDataGenerator 加载测试数据
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(test_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')
# 评估模型在测试数据集上的性能
test_loss, test_accuracy = model.evaluate(test_generator)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
# 加载单个图像并进行预测
img_path = 'D:/hand/archive/leapGestRecog/00/01_palm/frame_00_01_0001.png' # 修改为实际图像路径
img = image.load_img(img_path, target_size=(150, 150))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0) / 255.0 # 标准化图像数据
# 进行预测
predictions = model.predict(img_array)
predicted_class = np.argmax(predictions, axis=1)
# 输出预测结果
class_labels = {v: k for k, v in test_generator.class_indices.items()}
print(f"预测类别: {class_labels[predicted_class[0]]}")
# 显示图像和预测结果
plt.imshow(img)
plt.title(f"预测类别: {class_labels[predicted_class[0]]}")
plt.axis('off')
plt.show()

@ -0,0 +1,110 @@
import os
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
import tensorflow as tf
# 设置字体
font_path = 'C:/Windows/Fonts/msyh.ttc'
prop = fm.FontProperties(fname=font_path)
plt.rcParams['font.family'] = prop.get_name()
# 设置数据目录
data_dir = 'D:/hand/archive'
train_dir = os.path.join(data_dir, 'leapGestRecog')
validation_dir = os.path.join(data_dir, 'validation')
# 检查目录路径是否存在
print("训练数据目录存在:", os.path.exists(train_dir))
print("验证数据目录存在:", os.path.exists(validation_dir))
if os.path.exists(train_dir) and os.path.exists(validation_dir):
# 简单的标准化处理
datagen = ImageDataGenerator(rescale=1./255)
train_generator = datagen.flow_from_directory(train_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')
validation_generator = datagen.flow_from_directory(validation_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')
train_dataset = tf.data.Dataset.from_generator(
lambda: train_generator,
output_signature=(
tf.TensorSpec(shape=(None, 150, 150, 3), dtype=tf.float32),
tf.TensorSpec(shape=(None, 10), dtype=tf.float32)
)
).repeat()
validation_dataset = tf.data.Dataset.from_generator(
lambda: validation_generator,
output_signature=(
tf.TensorSpec(shape=(None, 150, 150, 3), dtype=tf.float32),
tf.TensorSpec(shape=(None, 10), dtype=tf.float32)
)
).repeat()
# 检查类别名称和数量
train_classes = train_generator.class_indices
validation_classes = validation_generator.class_indices
print("训练数据集类别:", train_classes)
print("验证数据集类别:", validation_classes)
# 构建模型
model = models.Sequential([
layers.Input(shape=(150, 150, 3)),
layers.Conv2D(32, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(128, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(128, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Flatten(),
layers.Dense(512, activation='relu'),
layers.Dense(train_generator.num_classes, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
steps_per_epoch = train_generator.samples // train_generator.batch_size
validation_steps = validation_generator.samples // validation_generator.batch_size
history = model.fit(train_dataset, steps_per_epoch=steps_per_epoch,
epochs=30, validation_data=validation_dataset,
validation_steps=validation_steps)
validation_loss, validation_accuracy = model.evaluate(validation_dataset, steps=validation_steps)
print(f"Validation Accuracy: {validation_accuracy * 100:.2f}%")
# 保存模型
model.save('hand_gesture_model.h5')
print("模型已保存到 hand_gesture_model.h5")
# 从文件加载模型
loaded_model = load_model('hand_gesture_model.h5')
print("模型已从 hand_gesture_model.h5 加载")
# 可视化训练过程
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
plt.plot(epochs, acc, 'r', label='训练准确度')
plt.plot(epochs, val_acc, 'b', label='验证准确度')
plt.title('训练和验证准确度')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'r', label='训练损失')
plt.plot(epochs, val_loss, 'b', label='验证损失')
plt.title('训练和验证损失')
plt.legend()
plt.show()
else:
print(f"路径错误,请检查以下路径是否正确:\n训练数据目录: {train_dir}\n验证数据目录: {validation_dir}")
Loading…
Cancel
Save