You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
3.8 KiB
97 lines
3.8 KiB
import os

import tensorflow as tf

from process_data import DataProcessor

# Training hyper-parameters.
batch_size, learning_rate = 128, 1e-4

# Label aspect handed to DataProcessor; it is also reused as the checkpoint
# file name (model_name) below.
aspect = 'province'
data_processor = DataProcessor(aspect)
model_name = aspect

# Checkpoint directory. NOTE: ``dir`` shadows the built-in of the same name;
# it is kept because load_model/save_model callers below refer to it by name.
dir = './'

# Graph inputs: a batch of 20x20 float grids, plus one 34-wide label vector per
# sample (presumably one-hot -- it is fed to softmax cross-entropy and argmax'd).
input_x, input_y = (
    tf.placeholder(dtype=tf.float32, shape=[None, 20, 20], name='input_x'),
    tf.placeholder(dtype=tf.float32, shape=[None, 34], name='input_y'),
)

def _conv_params(kernel_shape):
    """Create the (kernel, bias) variables for one convolution layer.

    The kernel is drawn from a truncated normal with stddev 0.1. The bias has
    one entry per output channel (``kernel_shape[-1]``) and is filled with 0.1.
    The kernel is created before the bias so that the auto-generated variable
    names, and therefore the checkpoint keys, stay the same.
    """
    kernel = tf.Variable(tf.truncated_normal(shape=kernel_shape, stddev=0.1))
    bias = tf.Variable(tf.constant(0.1, dtype=tf.float32, shape=[kernel_shape[-1]]))
    return kernel, bias


def _same_conv_relu(inputs, kernel, bias):
    """Return relu(conv2d(inputs, kernel) + bias) using stride 1 and SAME padding."""
    convolved = tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(convolved + bias)


def _max_pool_2x2(inputs):
    """Apply a 2x2 max-pool with stride 2 and VALID padding (halves even H/W)."""
    return tf.nn.max_pool(inputs, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')


with tf.name_scope('conv1'):
    W_C1, b_C1 = _conv_params([3, 3, 1, 32])
    # Add the channel axis conv2d needs: (batch, 20, 20) -> (batch, 20, 20, 1).
    X = tf.reshape(input_x, [-1, 20, 20, 1])
    featureMap_C1 = _same_conv_relu(X, W_C1, b_C1)

with tf.name_scope('conv2'):
    W_C2, b_C2 = _conv_params([3, 3, 32, 64])
    featureMap_C2 = _same_conv_relu(featureMap_C1, W_C2, b_C2)

with tf.name_scope('pooling2'):
    featureMap_S2 = _max_pool_2x2(featureMap_C2)  # 20x20 -> 10x10

with tf.name_scope('conv3'):
    W_C3, b_C3 = _conv_params([3, 3, 64, 8])
    featureMap_C3 = _same_conv_relu(featureMap_S2, W_C3, b_C3)

with tf.name_scope('pooling3'):
    featureMap_S3 = _max_pool_2x2(featureMap_C3)  # 10x10 -> 5x5, 8 channels

# Probability of keeping a unit of the 512-wide dense layer during training.
dropout_keep_prob = 0.5

# Class count follows the label placeholder, so the output layer cannot drift
# out of sync with input_y (34 for the current placeholder).
num_classes = input_y.get_shape().as_list()[-1]

with tf.name_scope('fulnet'):
    # Flattened width derived from the pooled feature map's static shape
    # (5 * 5 * 8 = 200 for 20x20 inputs) instead of being hard-coded.
    flat_dim = featureMap_S3.get_shape()[1:].num_elements()
    featureMap_flatten = tf.reshape(featureMap_S3, [-1, flat_dim])
    W_F4 = tf.Variable(tf.truncated_normal(shape=[flat_dim, 512], stddev=0.1))
    b_F4 = tf.Variable(tf.constant(0.1, shape=[512], dtype=tf.float32))
    # Deterministic activations (no dropout), used for evaluation.
    out_F4_eval = tf.nn.relu(tf.matmul(featureMap_flatten, W_F4) + b_F4)
    # Training activations. Dropout is applied on this path only; previously
    # it was also applied when measuring validation/test accuracy.
    out_F4 = tf.nn.dropout(out_F4_eval, keep_prob=dropout_keep_prob)

with tf.name_scope('output'):
    W_OUTPUT = tf.Variable(tf.truncated_normal(shape=[512, num_classes], stddev=0.1))
    b_OUTPUT = tf.Variable(tf.constant(0.1, shape=[num_classes], dtype=tf.float32))
    # Training logits (through dropout); the optimizer minimizes loss on these.
    logits = tf.matmul(out_F4, W_OUTPUT) + b_OUTPUT
    # Same weights without dropout: used for predictions and accuracy.
    eval_logits = tf.matmul(out_F4_eval, W_OUTPUT) + b_OUTPUT

# Summed (not averaged) softmax cross-entropy over the batch.
loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=input_y, logits=logits))
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

# Evaluation ops use the dropout-free path so the reported accuracy is deterministic.
predictY = tf.nn.softmax(eval_logits)
y_pred = tf.argmax(predictY, 1)
bool_pred = tf.equal(tf.argmax(input_y, 1), y_pred)
# Fraction of the fed batch whose argmax prediction matches the label's argmax.
right_rate = tf.reduce_mean(tf.cast(bool_pred, tf.float32))

# Created after the optimizer so it also tracks Adam's slot variables.
saver = tf.train.Saver()

def load_model(sess, dir, model_name):
    """Restore checkpoint ``model_name`` from directory ``dir`` into ``sess``.

    Nothing is restored unless ``dir`` holds TensorFlow checkpoint state (the
    ``checkpoint`` file read by ``tf.train.get_checkpoint_state``) that names a
    model checkpoint path.

    NOTE(review): that gate only proves *some* checkpoint exists in ``dir``,
    not ``model_name`` specifically. If ``dir`` holds another model's
    checkpoint only, ``saver.restore`` will raise.

    Args:
        sess: Session whose variables are overwritten by the restore.
        dir: Checkpoint directory (the same value passed to ``save_model``).
        model_name: Checkpoint prefix inside ``dir``.
    """
    ckpt = tf.train.get_checkpoint_state(dir)
    if ckpt and ckpt.model_checkpoint_path:
        print('*' * 30)
        print('load latest model......')
        # Join the path portably. The previous hard-coded Windows separator
        # (dir + '.\\' + model_name) never matched the prefix save_model writes
        # on POSIX, where a backslash is an ordinary file-name character.
        saver.restore(sess, os.path.join(dir, model_name))
        print('*' * 30)

def save_model(sess, dir, model_name):
    """Write the variables tracked by the module-level ``saver`` to disk.

    Args:
        sess: Session holding the variable values to save.
        dir: Checkpoint directory.
        model_name: Checkpoint prefix inside ``dir``.

    Returns:
        The checkpoint path prefix reported by ``Saver.save`` (callers that
        ignored the previous ``None`` result are unaffected).
    """
    # os.path.join works whether or not ``dir`` ends with a separator. Plain
    # concatenation glued the two names together when the separator was missing.
    return saver.save(sess, os.path.join(dir, model_name))

def _run_training(sess, max_epoch=500, display_interval=200):
    """Run train steps until the training split's epoch counter exceeds ``max_epoch``.

    Every ``display_interval`` steps, one validation batch is scored and
    ``{'<epoch>:<step>': accuracy}`` is printed. ``<epoch>`` is the value
    returned by ``next_valid_batch``.
    """
    step = 1
    while True:
        if step % display_interval == 0:
            image_batch, label_batch, epoch = data_processor.next_valid_batch(batch_size)
            acc = sess.run(right_rate, feed_dict={input_x: image_batch, input_y: label_batch})
            print({str(epoch) + ':' + str(step): acc})
        image_batch, label_batch, epoch = data_processor.next_train_batch(batch_size)
        sess.run(train_op, feed_dict={input_x: image_batch, input_y: label_batch})
        if epoch > max_epoch:
            break
        step += 1


def _evaluate_test_set(sess):
    """Return the sample-weighted mean accuracy over one pass of the test split.

    This replaces an exponential moving average (0.8 * batch + 0.2 * previous)
    that was seeded with the last validation accuracy and therefore did not
    report a test-set accuracy.

    NOTE(review): this assumes the epoch value returned by ``next_test_batch``
    changes once the test split has been fully consumed. Confirm against
    DataProcessor. As in the original loop, the batch on which the change is
    observed is still scored.
    """
    n_correct = 0.0
    n_seen = 0
    first_epoch = None
    while True:
        test_img, test_lab, test_epoch = data_processor.next_test_batch(batch_size)
        batch_acc = sess.run(right_rate, feed_dict={input_x: test_img, input_y: test_lab})
        batch_len = len(test_lab)
        n_correct += batch_acc * batch_len
        n_seen += batch_len
        # Stop on a change relative to the test split's own counter. The old
        # check compared against the *training* epoch counter instead.
        if first_epoch is None:
            first_epoch = test_epoch
        elif test_epoch != first_epoch:
            break
    return n_correct / n_seen if n_seen else 0.0


with tf.Session() as sess:
    # Replaces the deprecated tf.initialize_all_variables().
    sess.run(tf.global_variables_initializer())
    # Resume from an existing checkpoint, if there is one, before training.
    load_model(sess, dir=dir, model_name=model_name)
    _run_training(sess)
    print({'Test Over..... acc:': _evaluate_test_set(sess)})
    save_model(sess, dir, model_name)