You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
matching_dependency/draw/draw_md_cluster_with_data_p...

74 lines
3.5 KiB

# # Cluster the MDs and plot them together with the data points
# import os
# import numpy as np
# import pandas as pd
# from matplotlib import pyplot as plt
#
# from draw_md_cluster import DBSCAN
# from ml_er.ml_entity_resolver import build_col_pairs_sim_tensor_dict
#
#
# NOTE(review): `plot` below is disabled (commented-out) code and never runs.
# The indentation inside these commented lines is flattened in this copy, so the
# extent of the `for`/`if` bodies cannot be read off reliably; restore it before
# uncommenting.
#
# plot: draw MD points and prediction points together in one 3-D scatter plot,
# then save the figure to `output_path_` (dpi=500) and display it.
#   md_keys_             - attribute names; items 0-2 become the x/y/z axis labels.
#   md_data_             - MD points; columns 0-2 are plotted, one colour per
#                          cluster label from an 8-colour palette that is reused
#                          cyclically (label -1 maps to 'brown', since -1 % 8 == 7).
#   pre_match_points_    - drawn as 'x' markers in '#66CCFF'; indexed with [:, k]
#                          unconditionally, so an empty 1-D array would raise
#                          IndexError here.
#   pre_mismatch_points_ - drawn as 'x' markers in '#006666'; skipped when it has
#                          no rows.
#   labels_              - per-MD cluster labels, compared element-wise via
#                          `labels_ == i`; presumably a NumPy array (with a plain
#                          list the comparison is a single bool and every cluster
#                          would come out empty) - TODO confirm against DBSCAN.
#   output_path_         - image path; the title is its last backslash-separated
#                          component up to the first '.', i.e. this assumes a
#                          Windows-style path.
# NOTE(review): the loop visits labels -1 .. len(set(labels_)) - 1, which assumes
# cluster ids are consecutive integers starting at 0 (plus -1, presumably the
# DBSCAN noise label).
# def plot(md_keys_, md_data_, pre_match_points_, pre_mismatch_points_, labels_, output_path_):
# clusterNum = len(set(labels_))
# fig = plt.figure()
# scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown']
# ax = fig.add_subplot(111, projection='3d')
# for i in range(-1, clusterNum):
# colorStyle = scatterColors[i % len(scatterColors)]
# subCluster = md_data_[np.where(labels_ == i)]
# ax.scatter(subCluster[:, 0], subCluster[:, 1], subCluster[:, 2], c=colorStyle, s=12)
# ax.scatter(pre_match_points_[:, 0], pre_match_points_[:, 1], pre_match_points_[:, 2], c='#66CCFF', s=12, marker='x')
# if pre_mismatch_points_.shape[0] > 0:
# ax.scatter(pre_mismatch_points_[:, 0], pre_mismatch_points_[:, 1], pre_mismatch_points_[:, 2], c='#006666', s=12, marker='x')
# ax.set_xlabel(md_keys_[0], rotation=0) # set the label rotation angle
# ax.set_ylabel(md_keys_[1], rotation=-45)
# ax.set_zlabel(md_keys_[2], rotation=0)
# plt.title(output_path_.split('\\')[-1].split('.')[0])
# plt.savefig(output_path_, dpi=500)
# plt.show()
#
#
# NOTE(review): the script entry point below is disabled (commented-out) code and
# never runs; its indentation is flattened in this copy as well.
# For every sub-directory (dataset) of `outcome_path` it:
#   1. reads `outcome_path\<dataset><config_dir>\mds.txt` (tab-separated; field 1
#      is eval()'d into a tuple whose element 0 is a dict) and keeps the 2nd-4th
#      dict values of at most the first 10000 lines as 3-D MD points;
#   2. reads `predictions.csv`, converts every column to str and, for each row,
#      looks up that row's entry in `sim_tensor_dict[key]` for the three MD keys
#      (rounded to 4 decimals), splitting the points by predicted == '1' / '0';
#   3. clusters the MD points with the project's DBSCAN(md_data, 0.5, 30)
#      (presumably eps and minimum samples - confirm in draw_md_cluster) and
#      saves the plot as `outcome_path\<dataset>_MD&data.png`.
# NOTE(review) before re-enabling:
#   - eval() executes whatever is in mds.txt; if field 1 is always a plain Python
#     literal, ast.literal_eval is the safer choice - confirm the file format.
#   - md_keys is overwritten on every line, so the keys of the last MD read are
#     used for all points; presumably every MD shares the same key order.
#   - rows are matched to similarity values by their DataFrame index label (_[0]),
#     which relies on read_csv's default RangeIndex lining up with the tensors.
#   - rows whose `predicted` string is neither '1' nor '0' (e.g. '1.0') are dropped.
#   - if no row is predicted '1', pre_match_points becomes an empty 1-D array and
#     plot() fails on its [:, k] indexing.
#   - paths are hard-coded Windows paths built by string concatenation.
# if __name__ == '__main__':
# outcome_path = r'E:\Data\Research\Outcome'
# config_dir = r'\Magellan+Smac+roberta-large-nli-stsb-mean-tokens+inter-0.5'
# dataset_name_list = [f.name for f in os.scandir(outcome_path) if f.is_dir()]
# for dataset_name in dataset_name_list:
# absolute_path = outcome_path + rf'\{dataset_name}' + config_dir + r'\mds.txt' # path to the MD file
# predictions = outcome_path + rf'\{dataset_name}' + config_dir + r'\predictions.csv' # path to the predictions file
# pred = pd.read_csv(predictions)
# pred = pred.astype(str)
# # pred = pred[pred['predicted'] == str(1)]
# sim_tensor_dict = build_col_pairs_sim_tensor_dict(pred)
# # the three selected attributes (MD keys)
# md_keys = []
# with open(absolute_path, 'r') as f:
# # read the MD on each line and append it to this file's MD list
# md_data = []
# for line in f.readlines():
# md_metadata = line.strip().split('\t')
# md_tuple = eval(md_metadata[1])
# md_keys = list(md_tuple[0].keys())[1:4]
# md_values = list(md_tuple[0].values())
# md_data.append(md_values[1:4])
# if len(md_data) == 10000:
# break
#
# pre_match_points = []
# pre_mismatch_points = []
# for _ in pred.itertuples():
# data_point_value = []
# for key in md_keys:
# sim_tensor = sim_tensor_dict[key]
# data_point_value.append(round(float(sim_tensor[_[0]]), 4))
# if getattr(_, 'predicted') == str(1):
# pre_match_points.append(data_point_value)
# elif getattr(_, 'predicted') == str(0):
# pre_mismatch_points.append(data_point_value)
#
# md_data = np.array(md_data, dtype=np.float32)
# pre_match_points = np.array(pre_match_points, dtype=np.float32)
# pre_mismatch_points = np.array(pre_mismatch_points, dtype=np.float32)
# labels = DBSCAN(md_data, 0.5, 30)
# output_path = outcome_path + rf'\{dataset_name}_MD&data.png'
# plot(md_keys, md_data, pre_match_points, pre_mismatch_points, labels, output_path)