# NOTE: removed web-UI scrape artifacts (repository "topics" hints and a
# line-count/size banner) that were not valid Python source.
import os
import sys
import time

import pandas as pd
import six

# Make the local py_entitymatching checkout importable before importing it.
sys.path.append('/home/w/PycharmProjects/py_entitymatching/py_entitymatching')

import py_entitymatching as em
import py_entitymatching.catalog.catalog_manager as cm
from py_entitymatching.debugmatcher.debug_gui_utils import _get_metric
def load_data(left_path: str, right_path: str, mapping_path: str):
    """Load the left/right entity tables and the gold mapping table.

    Each entity table is read as ISO-8859-1 CSV, its first column is
    registered as the py_entitymatching catalog key, missing values are
    replaced with empty strings, and every column is cast to ``str``.

    Parameters
    ----------
    left_path : str
        Path to the left entity table CSV.
    right_path : str
        Path to the right entity table CSV.
    mapping_path : str
        Path to the perfect-mapping (gold matches) CSV.

    Returns
    -------
    tuple
        ``(left, right, mapping)`` as pandas DataFrames with all values
        converted to strings.
    """

    def _read_table(path: str) -> pd.DataFrame:
        # Shared loading logic for the two entity tables (was duplicated
        # verbatim for left and right): read, register the first column
        # as the catalog key, blank out NaNs, stringify.
        table = pd.read_csv(path, encoding='ISO-8859-1')
        cm.set_key(table, table.columns.values.tolist()[0])
        table.fillna("", inplace=True)
        # NOTE(review): astype returns a *new* frame, so the catalog key
        # registered above may not follow it — this matches the original
        # behavior; confirm against the catalog implementation.
        return table.astype(str)

    left = _read_table(left_path)
    right = _read_table(right_path)

    mapping = pd.read_csv(mapping_path)
    mapping = mapping.astype(str)
    return left, right, mapping
if __name__ == '__main__':
    # Load the public Amazon-GoogleProducts dataset, register catalog
    # keys, and fill missing values.
    path_Amazon = '/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/Amazon.csv'
    path_Google = '/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/GoogleProducts.csv'
    path_Mappings = '/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/Amzon_GoogleProducts_perfectMapping.csv'

    Amazon = pd.read_csv(path_Amazon, encoding='ISO-8859-1')
    cm.set_key(Amazon, 'id')
    Amazon.fillna("", inplace=True)

    Google = pd.read_csv(path_Google, encoding='ISO-8859-1')
    cm.set_key(Google, 'id')
    Google.fillna("", inplace=True)

    Mappings = pd.read_csv(path_Mappings)

    # Cast everything to str so the feature functions compare uniformly.
    Amazon = Amazon.astype(str)
    Google = Google.astype(str)
    Mappings = Mappings.astype(str)

    # Keep only the rows of both tables that appear in the perfect
    # mapping, to raise the proportion of positive pairs.
    l_id_list = Mappings['idAmazon'].tolist()
    r_id_list = Mappings['idGoogleBase'].tolist()

    selected_Amazon = Amazon[Amazon['id'].isin(l_id_list)]
    # Align the Amazon 'title' column with Google's 'name' column.
    selected_Amazon = selected_Amazon.rename(columns={'title': 'name'})
    selected_Google = Google[Google['id'].isin(r_id_list)]
    cm.set_key(selected_Amazon, 'id')
    cm.set_key(selected_Google, 'id')

    #########################################################################
    # Black-box blocker demo: return True (= remove the pair) when the two
    # product names differ, False (= retain) otherwise.
    def match_last_name(ltuple, rtuple):
        if ltuple['name'] != rtuple['name']:
            return True
        else:
            return False

    bb = em.BlackBoxBlocker()
    bb.set_black_box_function(match_last_name)

    # NOTE(review): `Candidate` is never used below; kept so the script's
    # observable behavior (and runtime) matches the original.
    Candidate = bb.block_tables(selected_Amazon, selected_Google, l_output_attrs=['id', 'name', 'description', 'manufacturer', 'price'], r_output_attrs=['id', 'name', 'description', 'manufacturer', 'price'])
    #########################################################################

    # Overlap-block the selected tables (overlap_size=0 keeps essentially
    # every pair) and initialise the gold label to 0.
    blocker = em.OverlapBlocker()
    candidate = blocker.block_tables(selected_Amazon, selected_Google, 'name', 'name',
                                     l_output_attrs=['id', 'name', 'description', 'manufacturer', 'price'],
                                     r_output_attrs=['id', 'name', 'description', 'manufacturer', 'price'],
                                     overlap_size=0, show_progress=False)
    candidate['gold'] = 0

    # NOTE(review): `start` was never read in the original either; kept in
    # case a timing printout is added later.
    start = time.time()

    # Mark candidate pairs present in the perfect mapping as positives.
    # FIX: the original guarded with `if map_row is not None`, but a pandas
    # selection is never None, so the guard was always true; a set of gold
    # (left id, right id) pairs is both correct and O(1) per candidate,
    # instead of re-scanning Mappings for every candidate row.
    gold_pairs = set(zip(Mappings['idAmazon'], Mappings['idGoogleBase']))
    candidate_match_rows = [
        row["_id"]
        for index, row in candidate.iterrows()
        if (row["ltable_id"], row["rtable_id"]) in gold_pairs
    ]
    # NOTE(review): assumes the '_id' column equals the frame index, as the
    # original scalar-loc loop did — confirm if the blocker output changes.
    candidate.loc[candidate_match_rows, 'gold'] = 1

    # Down-sample the negatives so positives and negatives are balanced.
    candidate_mismatch = candidate[candidate['gold'] == 0]
    candidate_match = candidate[candidate['gold'] == 1]
    # FIX: seed the sample for reproducibility, consistent with the
    # random_state=0 used everywhere else in this script.
    candidate_mismatch = candidate_mismatch.sample(n=len(candidate_match), random_state=0)
    # Concatenate positives and negatives into one labelled table and
    # register its metadata with the catalog.
    candidate_for_train_test = pd.concat([candidate_mismatch, candidate_match])
    cm.set_key(candidate_for_train_test, '_id')
    cm.set_fk_ltable(candidate_for_train_test, 'ltable_id')
    cm.set_fk_rtable(candidate_for_train_test, 'rtable_id')
    cm.set_ltable(candidate_for_train_test, selected_Amazon)
    cm.set_rtable(candidate_for_train_test, selected_Google)

    # Split into train/test sets.
    sets = em.split_train_test(candidate_for_train_test, train_proportion=0.7, random_state=0)
    train_set = sets['train']
    test_set = sets['test']

    # Candidate matchers for model selection.
    dt = em.DTMatcher(name='DecisionTree', random_state=0)
    svm = em.SVMMatcher(name='SVM', random_state=0)
    rf = em.RFMatcher(name='RF', random_state=0)
    lg = em.LogRegMatcher(name='LogReg', random_state=0)
    ln = em.LinRegMatcher(name='LinReg')
    nb = em.NBMatcher(name='NaiveBayes')
    feature_table = em.get_features_for_matching(selected_Amazon, selected_Google, validate_inferred_attr_types=False)

    train_feature_vecs = em.extract_feature_vecs(train_set,
                                                 feature_table=feature_table,
                                                 attrs_after='gold',
                                                 show_progress=False)

    # 5-fold cross-validation over all matchers, selecting by F1.
    result = em.select_matcher([dt, rf, svm, ln, lg, nb], table=train_feature_vecs,
                               exclude_attrs=['_id', 'ltable_id', 'rtable_id', 'gold'],
                               k=5,
                               target_attr='gold', metric_to_select_matcher='f1', random_state=0)

    test_feature_vecs = em.extract_feature_vecs(test_set, feature_table=feature_table,
                                                attrs_after=['ltable_name', 'ltable_description', 'ltable_manufacturer',
                                                             'ltable_price', 'rtable_name', 'rtable_description',
                                                             'rtable_manufacturer', 'rtable_price', 'gold'], show_progress=False)

    # Train the random forest and evaluate on the held-out test set.
    rf.fit(table=train_feature_vecs,
           exclude_attrs=['_id', 'ltable_id', 'rtable_id', 'gold'],
           target_attr='gold')
    predictions = rf.predict(table=test_feature_vecs, exclude_attrs=['_id', 'ltable_id', 'rtable_id', 'ltable_name',
                                                                     'ltable_description', 'ltable_manufacturer',
                                                                     'ltable_price', 'rtable_name', 'rtable_description',
                                                                     'rtable_manufacturer', 'rtable_price', 'gold'],
                             append=True, target_attr='predicted', inplace=False)
    eval_result = em.eval_matches(predictions, 'gold', 'predicted')
    em.print_eval_summary(eval_result)

    # Persist the per-metric summary.
    # FIX: create the output directory if absent (open() would fail).
    os.makedirs("output", exist_ok=True)
    output_path = "output/eval_result" + str(time.time()) + ".txt"
    with open(output_path, 'w') as f:
        for key, value in six.iteritems(_get_metric(eval_result)):
            # str() guards against non-string metric values.
            f.write(key + " : " + str(value))
            f.write('\n')