You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

57 lines
2.6 KiB

6 years ago
import pandas as pd
from sortedcontainers import SortedSet
import numpy as np
from sklearn.model_selection import train_test_split
from keras.layers import Dense, Embedding, Input, Flatten
from keras.layers import LSTM, GRU, Dropout
from keras.models import Model
import keras
from keras.utils import plot_model
import utils
import time
def _encode_branch(inputs, vocab_size, prefix, *, embedding_dim, rnn_units,
                   rnn_dropout, l2_factor):
    """Embed one input, run it through a GRU, and flatten the per-step outputs.

    Args:
        inputs: Keras input tensor for this branch.
        vocab_size: ``input_dim`` of the Embedding layer.
        prefix: Name prefix; the Embedding layer is named ``<prefix>_embedding``.
        embedding_dim: Embedding output width.
        rnn_units: Number of GRU units.
        rnn_dropout: Input dropout rate applied inside the GRU.
        l2_factor: L2 penalty applied to the embedding and recurrent weights.

    Returns:
        A 2-D tensor ``(batch, steps * rnn_units)``.
    """
    x = Embedding(vocab_size, embedding_dim,
                  name=prefix + '_embedding',
                  embeddings_initializer='he_normal',
                  embeddings_regularizer=keras.regularizers.l2(l2_factor))(inputs)
    # return_sequences=True keeps one output vector per time step; they are
    # flattened below rather than reduced to the final hidden state.
    x = GRU(rnn_units, dropout=rnn_dropout, return_sequences=True,
            recurrent_initializer='he_normal',
            recurrent_regularizer=keras.regularizers.l2(l2_factor))(x)
    return Flatten()(x)


def build_model(want_answer_size, infact_answer_size, *, embedding_dim=128,
                rnn_units=128, dense_units=64, rnn_dropout=0.2,
                head_dropout=0.4, l2_factor=0.01):
    """Build a two-branch GRU classifier over an expected and an actual answer.

    Each input goes through its own Embedding -> GRU -> Flatten branch. The
    two flattened branches are concatenated and fed to a small dense head
    that ends in a 2-way softmax.

    Args:
        want_answer_size: Length of the expected-answer input vector. It is
            also passed as the Embedding ``input_dim``, so values fed to this
            input must be integers in ``[0, want_answer_size)``.
            NOTE(review): using the sequence length as the vocabulary size
            looks unusual; confirm against ``utils.build_corpus``.
        infact_answer_size: Same as above, for the actual-answer input.
        embedding_dim: Embedding width for both branches (default 128).
        rnn_units: GRU units for both branches (default 128).
        dense_units: Width of the hidden ReLU layer (default 64).
        rnn_dropout: Input dropout inside each GRU (default 0.2).
        head_dropout: Dropout before the dense head (default 0.4).
        l2_factor: L2 penalty on embedding and recurrent weights (default 0.01).

    Returns:
        An uncompiled ``keras.models.Model`` taking
        ``[want_answer_input, infact_answer_input]`` and producing a
        ``(batch, 2)`` softmax output.
    """
    inputs_want_answer = Input(shape=(want_answer_size, ), name='want_answer_input')
    inputs_infact_answer = Input(shape=(infact_answer_size, ), name='infact_answer_input')

    branch_kwargs = dict(embedding_dim=embedding_dim, rnn_units=rnn_units,
                         rnn_dropout=rnn_dropout, l2_factor=l2_factor)
    x_1 = _encode_branch(inputs_want_answer, want_answer_size, 'want_answer',
                         **branch_kwargs)
    x_2 = _encode_branch(inputs_infact_answer, infact_answer_size, 'infact_answer',
                         **branch_kwargs)

    # Concatenate AFTER flattening. Concatenating the raw (batch, steps, units)
    # sequences on the last axis only works when both inputs have the same
    # number of steps, which this signature does not guarantee.
    x = keras.layers.concatenate([x_1, x_2])
    x = Dropout(head_dropout)(x)
    x = Dense(dense_units, activation='relu')(x)
    predictions = Dense(2, activation='softmax')(x)
    model = Model(inputs=[inputs_want_answer, inputs_infact_answer], outputs=predictions)
    return model
if __name__ == '__main__':
    # Spreadsheet of expected vs. actual outputs ("预期输出与实际输出数据表");
    # the label column "是否正确" means "is correct".
    df = pd.read_excel('./预期输出与实际输出数据表.xlsx')
    want_answer_corpus, infact_answer_corpus = utils.build_corpus(df)
    onehot = utils.label2onehot(df['是否正确'])

    # Split both inputs and the labels in ONE call so their rows are
    # guaranteed to stay aligned. Two separate calls only lined up because
    # they happened to share random_state and sample count.
    (x_train_1, x_test_1,
     x_train_2, x_test_2,
     y_train, y_test) = train_test_split(want_answer_corpus, infact_answer_corpus,
                                         onehot, random_state=2333)

    want_answer_corpus_size = len(want_answer_corpus[0])
    infact_answer_corpus_size = len(infact_answer_corpus[0])
    model = build_model(want_answer_corpus_size, infact_answer_corpus_size)
    # Uncomment to render the architecture diagram (needs pydot + graphviz).
    # plot_model(model, to_file='model.png')

    # `lr` is the Keras 2.x keyword; newer Keras releases spell it
    # `learning_rate`.
    model.compile(loss='categorical_crossentropy',
                  optimizer=keras.optimizers.Adam(lr=1e-4),
                  metrics=['accuracy'])
    print('Train...')
    model.fit([x_train_1, x_train_2], y_train,
              batch_size=16,
              epochs=50)
    score, acc = model.evaluate([x_test_1, x_test_2], y_test,
                                batch_size=8, verbose=0)
    print('Test score:', score)
    print('Test accuracy:', acc)