You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
matching_dependency/inference_from_record_pairs.py

152 lines
5.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import copy
import itertools
import time

import Levenshtein
import numpy as np
import pandas as pd
def my_Levenshtein_ratio(str1, str2):
    """Return the normalized Levenshtein similarity of two strings.

    The similarity is ``1 - distance / max(len(str1), len(str2))``. Because
    the Levenshtein distance never exceeds the longer length, the result lies
    in [0, 1], and it is 1 exactly when the strings are identical.

    Two empty strings are identical, so they score 1.0. The bare formula
    would divide by zero in that case.
    """
    longest = max(len(str1), len(str2))
    if longest == 0:
        return 1.0
    return 1 - Levenshtein.distance(str1, str2) / longest
def if_minimal(md, md_list, target_col):
    """Tell whether ``md`` is not dominated by any other MD in ``md_list``.

    Another MD ``other`` dominates ``md`` when none of its left-hand-side
    thresholds (every key except ``target_col``) is greater than the matching
    threshold in ``md``, and its right-hand-side threshold on ``target_col``
    is not smaller than ``md``'s. Entries equal to ``md`` are ignored.

    Returns True when no dominating MD exists, otherwise False.
    """
    for other in md_list:
        if other == md:
            continue
        # Every LHS comparison is evaluated (the list is built in full before
        # all() looks at it), mirroring a plain loop without an early break.
        lhs_flags = [
            not (other[col] > md[col])
            for col in set(other.keys()) - {target_col}
        ]
        rhs_ok = not (other[target_col] < md[target_col])
        if all(lhs_flags) and rhs_ok:
            # ``other`` is at most as strict on the LHS and at least as
            # strong on the RHS, so ``md`` is redundant.
            return False
    return True
def satisfy_confidence(md, df, conf_thresh, target_col):
    """Check whether ``md`` reaches confidence ``conf_thresh`` on ``df``.

    Every unordered pair of rows in ``df`` is compared column by column with
    ``my_Levenshtein_ratio``. A pair *supports* the MD when, for every column
    other than ``target_col``, its similarity is at least the MD's threshold
    for that column. A supporting pair also *satisfies* the MD when its
    ``target_col`` similarity is 1, which means the values are identical.

        confidence = satisfying pairs / supporting pairs

    Args:
        md: mapping from column name to similarity threshold. It must contain
            every column of ``df`` other than ``target_col``.
        df: the records. Cell values are passed to ``my_Levenshtein_ratio``.
        conf_thresh: minimum confidence required.
        target_col: the MD's right-hand-side column.

    Returns:
        True if the confidence is at least ``conf_thresh``. When no pair
        supports the MD, the confidence is undefined (0/0) and False is
        returned instead of raising ZeroDivisionError.
    """
    columns = df.columns.values.tolist()
    lhs_positions = [
        (pos, col) for pos, col in enumerate(columns) if col != target_col
    ]
    target_positions = [
        pos for pos, col in enumerate(columns) if col == target_col
    ]
    # Plain tuples (name=None) are accessed by position. Namedtuples would
    # rename columns whose names are not valid identifiers, which breaks
    # getattr-by-column-name.
    rows = df.itertuples(index=False, name=None)

    support = 0
    support_plus = 0
    for row1, row2 in itertools.combinations(rows, 2):
        # Stop at the first LHS threshold the pair misses: such a pair
        # neither supports nor satisfies the MD.
        if any(
            my_Levenshtein_ratio(row1[pos], row2[pos]) < md[col]
            for pos, col in lhs_positions
        ):
            continue
        support += 1
        if not any(
            my_Levenshtein_ratio(row1[pos], row2[pos]) < 1
            for pos in target_positions
        ):
            support_plus += 1

    if support == 0:
        return False
    return support_plus / support >= conf_thresh
def _pair_violates(md, sims, lhs_cols, target_col):
    """Return True if a record pair with column similarities ``sims`` violates ``md``.

    A pair violates an MD when it meets every LHS threshold but its
    ``target_col`` similarity falls below the MD's RHS threshold.
    """
    if any(sims[col] < md[col] for col in lhs_cols):
        return False
    return sims[target_col] < md[target_col]


def inference_from_record_pairs(path, threshold, target_col, conf_thresh=0.8):
    """Infer matching dependencies (MDs) for ``target_col`` from record pairs.

    An MD is a dict mapping every column to a similarity threshold. The
    search proceeds as follows:

    1. Start from the most general MD: LHS thresholds 0, RHS 1 on
       ``target_col``.
    2. Compare every unordered pair of rows.
    3. When a pair violates an MD, move that MD to the violated list.
    4. Replace it with specialisations. Each one raises a single LHS
       threshold just above the pair's similarity (never below
       ``threshold``), so the pair no longer meets the LHS.
    5. Keep a specialisation only if it is new and not dominated by an MD
       already in the list (see ``if_minimal``).

    Args:
        path: CSV file to read (ISO-8859-1). Every cell is converted to str.
        threshold: smallest threshold assigned to an LHS column when an MD
            is specialised.
        target_col: column on the MD's right-hand side.
        conf_thresh: minimum confidence (see ``satisfy_confidence``) for a
            violated MD to be returned. Defaults to 0.8, the value that used
            to be hard-coded.

    Returns:
        ``(md_list, minimal_vio)``:

        - ``md_list`` holds the MDs that no record pair violates.
        - ``minimal_vio`` holds the violated MDs that meet three conditions:
          no MD in ``md_list`` dominates them, their confidence reaches
          ``conf_thresh``, and they are minimal among themselves.
    """
    data = pd.read_csv(path, low_memory=False, encoding='ISO-8859-1')
    data = data.astype(str)
    columns = data.columns.values.tolist()
    # Column order rather than set order keeps the results deterministic.
    lhs_cols = [col for col in columns if col != target_col]
    # Step by which a specialised LHS threshold exceeds the pair's similarity.
    step = 0.001

    md_list = [{col: (1 if col == target_col else 0) for col in columns}]
    minimal_vio = []

    # Plain positional tuples work for any column name. Namedtuples rename
    # non-identifier names, so getattr(row, col) would fail for them.
    rows = data.itertuples(index=False, name=None)
    for row1, row2 in itertools.combinations(rows, 2):
        sims = {
            col: my_Levenshtein_ratio(val1, val2)
            for col, val1, val2 in zip(columns, row1, row2)
        }
        # Partition into kept and violated MDs. Calling md_list.remove()
        # while iterating md_list would skip the MD after each removed one,
        # leaving violated MDs in the list.
        kept = []
        violated_mds = []
        for md in md_list:
            if _pair_violates(md, sims, lhs_cols, target_col):
                violated_mds.append(md)
            else:
                kept.append(md)
        if not violated_mds:
            continue
        md_list = kept
        minimal_vio.extend(violated_mds)

        # Only the LHS is specialised. The RHS must stay an exact (100%)
        # match, so lowering its threshold is never needed.
        for vio_md in violated_mds:
            for col in lhs_cols:
                if sims[col] + step <= 1:
                    spec_l_md = dict(vio_md)  # values are plain numbers
                    spec_l_md[col] = (
                        threshold if sims[col] < threshold else sims[col] + step
                    )
                    # if_minimal ignores equal entries, so reject duplicates
                    # explicitly.
                    if (spec_l_md not in md_list
                            and if_minimal(spec_l_md, md_list, target_col)):
                        md_list.append(spec_l_md)

    # Drop violated MDs dominated by an MD that no pair violated.
    minimal_vio = [
        vio for vio in minimal_vio if if_minimal(vio, md_list, target_col)
    ]
    # Keep only the violated MDs whose confidence reaches conf_thresh.
    minimal_vio = [
        vio for vio in minimal_vio
        if satisfy_confidence(vio, data, conf_thresh, target_col)
    ]
    # Finally drop violated MDs dominated by another remaining one. Iterate
    # over a snapshot so removals do not disturb the loop.
    for vio in list(minimal_vio):
        if not if_minimal(vio, minimal_vio, target_col):
            minimal_vio.remove(vio)
    return md_list, minimal_vio
if __name__ == '__main__':
    # Example driver: copy and adapt this block to run your own inference.
    path = "/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/output/8.14/TP_single_tuple.csv"
    # perf_counter is monotonic, so the elapsed time is not affected by
    # wall-clock adjustments.
    start = time.perf_counter()
    # Inputs: the CSV file path, the similarity threshold for the MD's
    # left-hand side, and the target column on the MD's right-hand side.
    # Outputs two MD lists:
    #   list 1: MDs with no violations;
    #   list 2: MDs that are violated but whose confidence meets the
    #           threshold (0.8).
    # Here the LHS similarities must be at least 0.7 and the RHS is the
    # 'id' column.
    mds, mds_vio = inference_from_record_pairs(path, 0.7, 'id')
    # Output locations for list 1 and list 2; adjust these paths yourself.
    md_path = '/home/w/A-New Folder/8.14/Goods Dataset/TP_md_list.txt'
    vio_path = '/home/w/A-New Folder/8.14/Goods Dataset/TP_vio_list.txt'
    for out_path, md_group in ((md_path, mds), (vio_path, mds_vio)):
        # Explicit UTF-8 so column names outside ASCII are written the same
        # way regardless of the platform's default encoding.
        with open(out_path, 'w', encoding='utf-8') as f:
            f.writelines(str(md) + '\n' for md in md_group)
    print(time.perf_counter() - start)