You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
matching_dependency/ml_er/Goods Dataset-8.14.py

154 lines
7.1 KiB

import sys
from py_entitymatching.debugmatcher.debug_gui_utils import _get_metric
sys.path.append('/home/w/PycharmProjects/py_entitymatching/py_entitymatching')
import py_entitymatching as em
import py_entitymatching.catalog.catalog_manager as cm
import pandas as pd
import time
import six
def load_data(left_path: str, right_path: str, mapping_path: str,
              encoding: str = 'ISO-8859-1'):
    """Load two entity tables and their gold mapping as all-string DataFrames.

    Missing values in the entity tables are replaced with "" and every column
    is cast to ``str``. The first column of each entity table is registered as
    that table's key in the py_entitymatching catalog.

    Args:
        left_path: CSV path of the left entity table.
        right_path: CSV path of the right entity table.
        mapping_path: CSV path of the gold mapping (read with pandas' default
            encoding, as before).
        encoding: Text encoding of the two entity CSVs.

    Returns:
        ``(left, right, mapping)``; every column of every frame is ``str``.
    """
    def _read_entity_table(path):
        # Read one entity table, normalise it, then register its key.
        table = pd.read_csv(path, encoding=encoding)
        table = table.fillna("").astype(str)
        # astype() returns a new DataFrame, so the key is registered only after
        # the final conversion; otherwise the returned frame carries no key.
        cm.set_key(table, table.columns[0])
        return table

    left = _read_entity_table(left_path)
    right = _read_entity_table(right_path)
    mapping = pd.read_csv(mapping_path).astype(str)
    return left, right, mapping
if __name__ == '__main__':
    import os

    # Load the public Amazon-GoogleProducts data. load_data() fills missing
    # values with "" and casts every column to str.
    dataset_dir = ('/home/w/PycharmProjects/py_entitymatching/py_entitymatching/'
                   'datasets/end-to-end/Amazon-GoogleProducts')
    path_Amazon = os.path.join(dataset_dir, 'Amazon.csv')
    path_Google = os.path.join(dataset_dir, 'GoogleProducts.csv')
    # NOTE: keep the "Amzon" spelling; it must match the mapping file's name on disk.
    path_Mappings = os.path.join(dataset_dir, 'Amzon_GoogleProducts_perfectMapping.csv')
    Amazon, Google, Mappings = load_data(path_Amazon, path_Google, path_Mappings)

    # Keep only rows that occur in the mapping, to raise the share of positive
    # pairs among the blocked candidates.
    selected_Amazon = Amazon[Amazon['id'].isin(Mappings['idAmazon'])]
    # Rename Amazon's 'title' to 'name' so both tables block/match on 'name'.
    selected_Amazon = selected_Amazon.rename(columns={'title': 'name'})
    selected_Google = Google[Google['id'].isin(Mappings['idGoogleBase'])]
    cm.set_key(selected_Amazon, 'id')
    cm.set_key(selected_Google, 'id')

    # Block the two tables on 'name'.
    output_attrs = ['id', 'name', 'description', 'manufacturer', 'price']
    blocker = em.OverlapBlocker()
    candidate = blocker.block_tables(selected_Amazon, selected_Google, 'name', 'name',
                                     l_output_attrs=list(output_attrs),
                                     r_output_attrs=list(output_attrs),
                                     overlap_size=0, show_progress=False)

    # gold = 1 iff the (Amazon id, Google id) pair appears in the mapping.
    # A set lookup avoids scanning Mappings once per candidate, and assigning
    # the column positionally does not rely on '_id' equalling the index label.
    gold_pairs = set(zip(Mappings['idAmazon'], Mappings['idGoogleBase']))
    candidate['gold'] = [
        int((l_id, r_id) in gold_pairs)
        for l_id, r_id in zip(candidate['ltable_id'], candidate['rtable_id'])
    ]

    # Downsample negatives so positives and negatives are balanced. A fixed
    # seed keeps runs reproducible (like every other random_state=0 below), and
    # min() avoids asking for more negatives than exist.
    candidate_match = candidate[candidate['gold'] == 1]
    candidate_mismatch = candidate[candidate['gold'] == 0]
    candidate_mismatch = candidate_mismatch.sample(
        n=min(len(candidate_match), len(candidate_mismatch)), random_state=0)
    # Concatenate negatives and positives, then restore the catalog metadata.
    candidate_for_train_test = pd.concat([candidate_mismatch, candidate_match])
    cm.set_key(candidate_for_train_test, '_id')
    cm.set_fk_ltable(candidate_for_train_test, 'ltable_id')
    cm.set_fk_rtable(candidate_for_train_test, 'rtable_id')
    cm.set_ltable(candidate_for_train_test, selected_Amazon)
    cm.set_rtable(candidate_for_train_test, selected_Google)

    # Split into training and test sets.
    sets = em.split_train_test(candidate_for_train_test, train_proportion=0.7, random_state=0)
    train_set = sets['train']
    test_set = sets['test']

    dt = em.DTMatcher(name='DecisionTree', random_state=0)
    svm = em.SVMMatcher(name='SVM', random_state=0)
    rf = em.RFMatcher(name='RF', random_state=0)
    lg = em.LogRegMatcher(name='LogReg', random_state=0)
    ln = em.LinRegMatcher(name='LinReg')
    nb = em.NBMatcher(name='NaiveBayes')

    id_attrs = ['_id', 'ltable_id', 'rtable_id']
    test_attrs_after = ['ltable_name', 'ltable_description', 'ltable_manufacturer',
                        'ltable_price', 'rtable_name', 'rtable_description',
                        'rtable_manufacturer', 'rtable_price', 'gold']

    feature_table = em.get_features_for_matching(selected_Amazon, selected_Google,
                                                 validate_inferred_attr_types=False)
    train_feature_vecs = em.extract_feature_vecs(train_set,
                                                 feature_table=feature_table,
                                                 attrs_after='gold',
                                                 show_progress=False)
    # Cross-validate all candidate matchers and report the scores.
    result = em.select_matcher([dt, rf, svm, ln, lg, nb], table=train_feature_vecs,
                               exclude_attrs=id_attrs + ['gold'],
                               k=5,
                               target_attr='gold', metric_to_select_matcher='f1',
                               random_state=0)
    print(result['cv_stats'])

    test_feature_vecs = em.extract_feature_vecs(test_set, feature_table=feature_table,
                                                attrs_after=test_attrs_after,
                                                show_progress=False)
    # Train the random forest and evaluate it on the held-out set.
    rf.fit(table=train_feature_vecs,
           exclude_attrs=id_attrs + ['gold'],
           target_attr='gold')
    predictions = rf.predict(table=test_feature_vecs,
                             exclude_attrs=id_attrs + test_attrs_after,
                             append=True, target_attr='predicted', inplace=False)
    eval_result = em.eval_matches(predictions, 'gold', 'predicted')
    em.print_eval_summary(eval_result)

    # Persist the metrics; create the output directory if it does not exist.
    output_dir = 'output'
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'eval_result' + str(time.time()) + '.txt')
    with open(output_path, 'w', encoding='utf-8') as f:
        for key, value in _get_metric(eval_result).items():
            f.write(f'{key} : {value}\n')