parent
14297bdaaf
commit
c9032d071f
@ -0,0 +1,3 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
@ -0,0 +1,4 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (tf2.0)" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/ML.iml" filepath="$PROJECT_DIR$/.idea/ML.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
@ -0,0 +1,61 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow import keras
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
print(tf.version.VERSION)
|
||||||
|
|
||||||
|
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
|
||||||
|
|
||||||
|
plt.imshow(train_images[0])
|
||||||
|
|
||||||
|
train_labels = train_labels[:1000]
|
||||||
|
test_labels = test_labels[:1000]
|
||||||
|
|
||||||
|
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
|
||||||
|
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
|
||||||
|
|
||||||
|
# 定义一个简单的序列模型
|
||||||
|
def create_model():
|
||||||
|
model = tf.keras.models.Sequential([
|
||||||
|
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
|
||||||
|
keras.layers.Dropout(0.2),
|
||||||
|
keras.layers.Dense(10)
|
||||||
|
])
|
||||||
|
|
||||||
|
model.compile(optimizer='adam',
|
||||||
|
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
|
||||||
|
metrics=['accuracy'])
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# 创建一个基本的模型实例
|
||||||
|
model = create_model()
|
||||||
|
|
||||||
|
# 显示模型的结构
|
||||||
|
model.summary()
|
||||||
|
|
||||||
|
checkpoint_path = "training_1/cp.ckpt"
|
||||||
|
checkpoint_dir = os.path.dirname(checkpoint_path)
|
||||||
|
|
||||||
|
# 创建一个保存模型权重的回调
|
||||||
|
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
|
||||||
|
save_weights_only=True,
|
||||||
|
verbose=1)
|
||||||
|
|
||||||
|
# 使用新的回调训练模型
|
||||||
|
model.fit(train_images,
|
||||||
|
train_labels,
|
||||||
|
epochs=10,
|
||||||
|
batch_size=8,
|
||||||
|
validation_data=(test_images,test_labels),
|
||||||
|
callbacks=[cp_callback]) # 通过回调训练
|
||||||
|
|
||||||
|
# 这可能会生成与保存优化程序状态相关的警告。
|
||||||
|
# 这些警告(以及整个笔记本中的类似警告)
|
||||||
|
# 是防止过时使用,可以忽略。
|
||||||
|
|
||||||
|
results = model.evaluate(test_images, test_labels, verbose=2)
|
||||||
|
|
||||||
|
print(results)
|
@ -0,0 +1,2 @@
|
|||||||
|
model_checkpoint_path: "cp.ckpt"
|
||||||
|
all_model_checkpoint_paths: "cp.ckpt"
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in new issue