修复线程池bug

HuangJintao
HuangJintao 1 year ago
parent 59dc97d2e2
commit 34b2c72646

1
.gitignore vendored

@ -1 +1,2 @@
/deprecated/ /deprecated/
/datasets/

@ -23,7 +23,7 @@ if __name__ == '__main__':
# todo # todo
# 距离度量用户可设置? # 距离度量用户可设置?
# 使用drop删除特征向量中的列(如删除id相关特征) # 使用drop删除特征向量中的列(如删除id相关特征)
run(3) # 迭代3轮 run(1) # 迭代3轮
# ml_er(1) # ml_er(1)
# todo 将优化结果与参数输出到文件中 # todo 将优化结果与参数输出到文件中
# 通过ml_entity_resolver.ml_er()输出,同时输出参数配置信息 # 通过ml_entity_resolver.ml_er()输出,同时输出参数配置信息

@ -116,7 +116,8 @@ def inference_from_record_pairs(path, threshold, target_col):
lock = manager.Lock() lock = manager.Lock()
if len(minimal_vio) == 0: if len(minimal_vio) == 0:
return md_list, [] return md_list, []
pool = multiprocessing.Pool(len(minimal_vio)) pool_size = len(minimal_vio) if len(minimal_vio) < 61 else 60
pool = multiprocessing.Pool(pool_size)
# tmp = copy.deepcopy(minimal_vio) # tmp = copy.deepcopy(minimal_vio)
with manager: with manager:
proxy_minimal_vio = manager.list(minimal_vio) proxy_minimal_vio = manager.list(minimal_vio)
@ -145,7 +146,8 @@ def get_mds_metadata(md_list, dataset_path, target_col):
manager = multiprocessing.Manager() manager = multiprocessing.Manager()
if len(md_list) == 0: if len(md_list) == 0:
return [] return []
pool = multiprocessing.Pool(len(md_list)) pool_size = len(md_list) if len(md_list) < 61 else 60
pool = multiprocessing.Pool(pool_size)
result = [] result = []
with manager: with manager:
for _ in md_list: for _ in md_list:

@ -198,7 +198,7 @@ def ml_er(iter_round: int, config: Configuration = None, ):
elif config["ml_blocker"] == "attr_equiv": elif config["ml_blocker"] == "attr_equiv":
blocker = em.AttrEquivalenceBlocker() blocker = em.AttrEquivalenceBlocker()
candidate = blocker.block_tables(selected_ltable, selected_rtable, config["block_attr"], config["block_attr"], candidate = blocker.block_tables(selected_ltable, selected_rtable, config["block_attr"], config["block_attr"],
l_output_attrs=selected_attrs, r_output_attrs=selected_attrs, n_jobs=-1) l_output_attrs=selected_attrs, r_output_attrs=selected_attrs)
else: else:
matcher = em.RFMatcher(name='RF', random_state=0) matcher = em.RFMatcher(name='RF', random_state=0)
blocker = em.OverlapBlocker() blocker = em.OverlapBlocker()

@ -1,6 +1,6 @@
ltable_path = '/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/Amazon.csv' ltable_path = 'datasets\\Amazon.csv'
rtable_path = '/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/GoogleProducts.csv' rtable_path = 'datasets\\GoogleProducts.csv'
mapping_path = '/home/w/PycharmProjects/py_entitymatching/py_entitymatching/datasets/end-to-end/Amazon-GoogleProducts/Amzon_GoogleProducts_perfectMapping.csv' mapping_path = 'datasets\\Amzon_GoogleProducts_perfectMapping.csv'
mapping_lid = 'idAmazon' # mapping表中左表id名 mapping_lid = 'idAmazon' # mapping表中左表id名
mapping_rid = 'idGoogleBase' # mapping表中右表id名 mapping_rid = 'idGoogleBase' # mapping表中右表id名
ltable_id = 'id' # 左表id字段名称 ltable_id = 'id' # 左表id字段名称

Loading…
Cancel
Save