diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/ML.iml b/.idea/ML.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/ML.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..f0e9c8a --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..f7d5bab --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..5a41e01 --- /dev/null +++ b/main.py @@ -0,0 +1,61 @@ +import os + +import tensorflow as tf +from tensorflow import keras +import matplotlib.pyplot as plt + +print(tf.version.VERSION) + +(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() + +plt.imshow(train_images[0]) + +train_labels = train_labels[:1000] +test_labels = test_labels[:1000] + +train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0 +test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0 + +# 定义一个简单的序列模型 +def create_model(): + model = tf.keras.models.Sequential([ + keras.layers.Dense(512, activation='relu', input_shape=(784,)), + keras.layers.Dropout(0.2), + keras.layers.Dense(10) + ]) + + model.compile(optimizer='adam', + loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy']) + + return model + +# 创建一个基本的模型实例 +model = create_model() + +# 显示模型的结构 +model.summary() + +checkpoint_path = "training_1/cp.ckpt" +checkpoint_dir = os.path.dirname(checkpoint_path) + +# 创建一个保存模型权重的回调 +cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, + save_weights_only=True, + verbose=1) + +# 使用新的回调训练模型 +model.fit(train_images, + train_labels, + epochs=10, + batch_size=8, + validation_data=(test_images,test_labels), + callbacks=[cp_callback]) # 通过回调训练 + +# 这可能会生成与保存优化程序状态相关的警告。 +# 这些警告(以及整个笔记本中的类似警告) +# 是防止过时使用,可以忽略。 + +results = model.evaluate(test_images, test_labels, verbose=2) + +print(results) \ No newline at end of file diff --git a/training_1/checkpoint b/training_1/checkpoint new file mode 100644 index 0000000..c4fa61d --- /dev/null +++ b/training_1/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "cp.ckpt" +all_model_checkpoint_paths: "cp.ckpt" diff --git a/training_1/cp.ckpt.data-00000-of-00001 b/training_1/cp.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..5b130f8 Binary files /dev/null and b/training_1/cp.ckpt.data-00000-of-00001 differ diff --git a/training_1/cp.ckpt.index b/training_1/cp.ckpt.index new file mode 100644 index 0000000..c26620c Binary files /dev/null and b/training_1/cp.ckpt.index differ