You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
57 lines
2.6 KiB
57 lines
2.6 KiB
import pandas as pd
|
|
from sortedcontainers import SortedSet
|
|
import numpy as np
|
|
from sklearn.model_selection import train_test_split
|
|
from keras.layers import Dense, Embedding, Input, Flatten
|
|
from keras.layers import LSTM, GRU, Dropout
|
|
from keras.models import Model
|
|
import keras
|
|
from keras.utils import plot_model
|
|
import utils
|
|
import time
|
|
|
|
|
|
def build_model(want_answer_size, infact_answer_size, *,
                want_vocab_size=None, infact_vocab_size=None,
                embedding_dim=128, rnn_units=128, gru_dropout=0.2,
                dense_dropout=0.4, dense_units=64, num_classes=2,
                l2_weight=0.01):
    """Build a two-branch GRU classifier over an expected and an actual answer.

    Each branch embeds one integer-id input row and runs it through its own
    GRU that returns the full output sequence. The two sequences are joined
    feature-wise, flattened, and classified by a small dense softmax head.

    Every keyword-only argument defaults to the value that was previously
    hard-coded, so ``build_model(a, b)`` builds the same architecture as
    before.

    Args:
        want_answer_size: Length of each expected-answer input row.
        infact_answer_size: Length of each actual-answer input row. Must
            equal ``want_answer_size``, because the two GRU output sequences
            are concatenated along the last (feature) axis.
        want_vocab_size: ``input_dim`` of the expected-answer ``Embedding``,
            i.e. one more than the largest token id that may appear in that
            input. Defaults to ``want_answer_size`` (the previous behaviour);
            pass the real vocabulary size if ids can reach the row length.
        infact_vocab_size: Same as ``want_vocab_size`` for the actual-answer
            branch. Defaults to ``infact_answer_size``.
        embedding_dim: Output dimension of both embeddings.
        rnn_units: Number of units in each GRU.
        gru_dropout: Input dropout rate of each GRU.
        dense_dropout: Dropout rate applied after flattening.
        dense_units: Width of the hidden ReLU layer.
        num_classes: Number of softmax outputs.
        l2_weight: L2 penalty on the embedding and recurrent weights.

    Returns:
        An uncompiled ``keras.models.Model`` with inputs
        ``[want_answer_input, infact_answer_input]`` and a
        ``num_classes``-way softmax output.

    Raises:
        ValueError: If the two input lengths differ.
    """
    if want_answer_size != infact_answer_size:
        # concatenate() below would reject this anyway; fail early and clearly.
        raise ValueError(
            'want_answer_size (%s) and infact_answer_size (%s) must be equal: '
            'the two GRU output sequences are concatenated feature-wise'
            % (want_answer_size, infact_answer_size))

    if want_vocab_size is None:
        want_vocab_size = want_answer_size
    if infact_vocab_size is None:
        infact_vocab_size = infact_answer_size

    def _l2():
        # A fresh regularizer per layer, matching the original wiring.
        return keras.regularizers.l2(l2_weight)

    inputs_want_answer = Input(shape=(want_answer_size, ), name='want_answer_input')
    inputs_infact_answer = Input(shape=(infact_answer_size, ), name='infact_answer_input')

    # Embedding input_dim is the number of distinct token ids, not the row length.
    x_1 = Embedding(want_vocab_size, embedding_dim, name='want_answer_embedding',
                    embeddings_initializer='he_normal',
                    embeddings_regularizer=_l2())(inputs_want_answer)
    x_2 = Embedding(infact_vocab_size, embedding_dim, name='infact_answer_embedding',
                    embeddings_initializer='he_normal',
                    embeddings_regularizer=_l2())(inputs_infact_answer)

    # return_sequences=True keeps every timestep so Flatten sees the full sequence.
    x_1 = GRU(rnn_units, dropout=gru_dropout, return_sequences=True,
              recurrent_initializer='he_normal',
              recurrent_regularizer=_l2())(x_1)
    x_2 = GRU(rnn_units, dropout=gru_dropout, return_sequences=True,
              recurrent_initializer='he_normal',
              recurrent_regularizer=_l2())(x_2)

    # Default axis=-1: joins the per-timestep features of both branches.
    x = keras.layers.concatenate([x_1, x_2])
    x = Flatten()(x)
    x = Dropout(dense_dropout)(x)
    x = Dense(dense_units, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    return Model(inputs=[inputs_want_answer, inputs_infact_answer], outputs=predictions)
|
|
|
|
|
|
if __name__ == '__main__':
    # Labelled comparison table; the file name reads roughly
    # "expected output vs. actual output data table".
    df = pd.read_excel('./预期输出与实际输出数据表.xlsx')

    # Encode the expected- and actual-answer columns into model inputs.
    # NOTE(review): assumes utils.build_corpus returns two row-aligned arrays
    # of fixed-width integer-id rows -- confirm in utils.
    want_answer_corpus, infact_answer_corpus = utils.build_corpus(df)

    # '是否正确' ("is correct") is the label column; one-hot encode it for the
    # categorical cross-entropy loss used below.
    onehot = utils.label2onehot(df['是否正确'])

    # Split both inputs and the labels in ONE call so their rows are
    # guaranteed to stay aligned between the two model inputs.
    (x_train_1, x_test_1,
     x_train_2, x_test_2,
     y_train, y_test) = train_test_split(want_answer_corpus,
                                         infact_answer_corpus,
                                         onehot,
                                         random_state=2333)

    # Assumes every row in a corpus has the same width, so the first row's
    # length is the input sequence length.
    want_answer_corpus_size = len(want_answer_corpus[0])
    infact_answer_corpus_size = len(infact_answer_corpus[0])

    model = build_model(want_answer_corpus_size, infact_answer_corpus_size)

    # Uncomment to render the architecture (requires pydot and graphviz).
    # plot_model(model, to_file='model.png')

    # `learning_rate` replaces the deprecated `lr` keyword, which current
    # Keras optimizers reject.
    model.compile(loss='categorical_crossentropy',
                  optimizer=keras.optimizers.Adam(learning_rate=1e-4),
                  metrics=['accuracy'])

    print('Train...')
    model.fit([x_train_1, x_train_2], y_train,
              batch_size=16,
              epochs=50)

    # With one compiled metric, evaluate() returns [loss, accuracy].
    score, acc = model.evaluate([x_test_1, x_test_2], y_test,
                                batch_size=8, verbose=0)

    print('Test score:', score)
    print('Test accuracy:', acc)