"""Entity matching experiment on the Amazon-GoogleProducts benchmark.

Pipeline: load the two product tables and the perfect mapping, block
candidate pairs, label them against the mapping, balance positives and
negatives, train several matchers, and write the random-forest evaluation
metrics to a timestamped file under ``output/``.
"""
import sys

from py_entitymatching.debugmatcher.debug_gui_utils import _get_metric

sys.path.append('/home/w/PycharmProjects/py_entitymatching/py_entitymatching')

import py_entitymatching as em
import py_entitymatching.catalog.catalog_manager as cm
import pandas as pd
import os
import time
import six


def load_data(left_path: str, right_path: str, mapping_path: str):
    """Read left/right tables and the gold mapping as all-string DataFrames.

    The first column of each table is registered as its key in the
    py_entitymatching catalog and NaNs are replaced by empty strings.

    Returns:
        (left, right, mapping) DataFrames, all values coerced to ``str``.
    """
    left = pd.read_csv(left_path, encoding='ISO-8859-1')
    cm.set_key(left, left.columns.values.tolist()[0])
    left.fillna("", inplace=True)
    left = left.astype(str)

    right = pd.read_csv(right_path, encoding='ISO-8859-1')
    cm.set_key(right, right.columns.values.tolist()[0])
    right.fillna("", inplace=True)
    right = right.astype(str)

    mapping = pd.read_csv(mapping_path)
    mapping = mapping.astype(str)
    return left, right, mapping


if __name__ == '__main__':
    # Read the public dataset, register keys, and fill missing values.
    path_Amazon = '/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/Amazon.csv'
    path_Google = '/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/GoogleProducts.csv'
    path_Mappings = '/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/Amzon_GoogleProducts_perfectMapping.csv'

    Amazon = pd.read_csv(path_Amazon, encoding='ISO-8859-1')
    cm.set_key(Amazon, 'id')
    Amazon.fillna("", inplace=True)
    Google = pd.read_csv(path_Google, encoding='ISO-8859-1')
    cm.set_key(Google, 'id')
    Google.fillna("", inplace=True)
    Mappings = pd.read_csv(path_Mappings)

    # Work entirely on string values so id comparisons are reliable.
    Amazon = Amazon.astype(str)
    Google = Google.astype(str)
    Mappings = Mappings.astype(str)

    # Keep only rows that appear in the mapping, to raise the share of
    # positive pairs.
    l_id_list = Mappings['idAmazon'].tolist()
    r_id_list = Mappings['idGoogleBase'].tolist()
    selected_Amazon = Amazon[Amazon['id'].isin(l_id_list)]
    selected_Amazon = selected_Amazon.rename(columns={'title': 'name'})
    selected_Google = Google[Google['id'].isin(r_id_list)]
    cm.set_key(selected_Amazon, 'id')
    cm.set_key(selected_Google, 'id')

    #########################################################################
    # Black-box blocker demo.  The predicate returns True to REMOVE a pair
    # (False retains, True removes).
    # NOTE(review): `Candidate` is never used afterwards; kept unchanged for
    # parity with the original experiment.
    def match_last_name(ltuple, rtuple):
        """Drop a candidate pair unless the two 'name' values are equal."""
        return ltuple['name'] != rtuple['name']

    bb = em.BlackBoxBlocker()
    bb.set_black_box_function(match_last_name)
    Candidate = bb.block_tables(
        selected_Amazon, selected_Google,
        l_output_attrs=['id', 'name', 'description', 'manufacturer', 'price'],
        r_output_attrs=['id', 'name', 'description', 'manufacturer', 'price'])

    #########################################################################
    # Overlap blocking, then label every candidate pair against the perfect
    # mapping: gold = 1 for a true match, 0 otherwise.
    blocker = em.OverlapBlocker()
    candidate = blocker.block_tables(
        selected_Amazon, selected_Google, 'name', 'name',
        l_output_attrs=['id', 'name', 'description', 'manufacturer', 'price'],
        r_output_attrs=['id', 'name', 'description', 'manufacturer', 'price'],
        overlap_size=0, show_progress=False)

    start = time.time()  # labeling timer (value currently only kept, not printed)
    # O(1) set membership per candidate instead of re-scanning Mappings for
    # every row.  (The original guard `if map_row is not None:` was always
    # true — a DataFrame selection is never None; emptiness is what matters.)
    gold_pairs = set(zip(Mappings['idAmazon'], Mappings['idGoogleBase']))
    candidate['gold'] = [
        1 if pair in gold_pairs else 0
        for pair in zip(candidate['ltable_id'], candidate['rtable_id'])
    ]

    # Down-sample the negatives so positive and negative counts match.
    candidate_mismatch = candidate[candidate['gold'] == 0]
    candidate_match = candidate[candidate['gold'] == 1]
    # random_state pinned so the experiment is reproducible (every other
    # stochastic step below already uses random_state=0).
    candidate_mismatch = candidate_mismatch.sample(n=len(candidate_match),
                                                   random_state=0)

    # Concatenate the balanced sample and register catalog metadata.
    candidate_for_train_test = pd.concat([candidate_mismatch, candidate_match])
    cm.set_key(candidate_for_train_test, '_id')
    cm.set_fk_ltable(candidate_for_train_test, 'ltable_id')
    cm.set_fk_rtable(candidate_for_train_test, 'rtable_id')
    cm.set_ltable(candidate_for_train_test, selected_Amazon)
    cm.set_rtable(candidate_for_train_test, selected_Google)

    # Split into train and test sets.
    sets = em.split_train_test(candidate_for_train_test, train_proportion=0.7,
                               random_state=0)
    train_set = sets['train']
    test_set = sets['test']

    dt = em.DTMatcher(name='DecisionTree', random_state=0)
    svm = em.SVMMatcher(name='SVM', random_state=0)
    rf = em.RFMatcher(name='RF', random_state=0)
    lg = em.LogRegMatcher(name='LogReg', random_state=0)
    ln = em.LinRegMatcher(name='LinReg')
    nb = em.NBMatcher(name='NaiveBayes')

    feature_table = em.get_features_for_matching(
        selected_Amazon, selected_Google, validate_inferred_attr_types=False)

    train_feature_vecs = em.extract_feature_vecs(train_set,
                                                 feature_table=feature_table,
                                                 attrs_after='gold',
                                                 show_progress=False)

    # Cross-validated matcher selection.  NOTE(review): `result` is not
    # consulted below — rf is trained and evaluated regardless, as in the
    # original experiment.
    result = em.select_matcher(
        [dt, rf, svm, ln, lg, nb], table=train_feature_vecs,
        exclude_attrs=['_id', 'ltable_id', 'rtable_id', 'gold'],
        k=5, target_attr='gold', metric_to_select_matcher='f1',
        random_state=0)

    test_feature_vecs = em.extract_feature_vecs(
        test_set, feature_table=feature_table,
        attrs_after=['ltable_name', 'ltable_description',
                     'ltable_manufacturer', 'ltable_price',
                     'rtable_name', 'rtable_description',
                     'rtable_manufacturer', 'rtable_price', 'gold'],
        show_progress=False)

    rf.fit(table=train_feature_vecs,
           exclude_attrs=['_id', 'ltable_id', 'rtable_id', 'gold'],
           target_attr='gold')

    predictions = rf.predict(
        table=test_feature_vecs,
        exclude_attrs=['_id', 'ltable_id', 'rtable_id',
                       'ltable_name', 'ltable_description',
                       'ltable_manufacturer', 'ltable_price',
                       'rtable_name', 'rtable_description',
                       'rtable_manufacturer', 'rtable_price', 'gold'],
        append=True, target_attr='predicted', inplace=False)

    eval_result = em.eval_matches(predictions, 'gold', 'predicted')
    em.print_eval_summary(eval_result)

    # Persist the metrics.  Create the output directory if it is missing
    # (open() would otherwise raise FileNotFoundError) and stringify the
    # values explicitly — the original `key + " : " + value` raised
    # TypeError whenever a metric value was not a string.
    os.makedirs("output", exist_ok=True)
    output_path = "output/eval_result" + str(time.time()) + ".txt"
    with open(output_path, 'w') as f:
        for key, value in six.iteritems(_get_metric(eval_result)):
            f.write("{} : {}\n".format(key, value))