From d03d300f8c65de5332cad821a63a423805f287d9 Mon Sep 17 00:00:00 2001 From: HuangJintao <1447537163@qq.com> Date: Sat, 6 Jan 2024 15:49:49 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86=E4=B8=8D=E5=8C=B9=E9=85=8D=E7=9A=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E7=82=B9=E4=B9=9F=E5=8A=A0=E4=B8=8A=E8=81=9A?= =?UTF-8?q?=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- draw/draw_md_cluster_with_data_point.py | 37 +++++++++++++++---------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/draw/draw_md_cluster_with_data_point.py b/draw/draw_md_cluster_with_data_point.py index 5b53606..f373c88 100644 --- a/draw/draw_md_cluster_with_data_point.py +++ b/draw/draw_md_cluster_with_data_point.py @@ -8,16 +8,18 @@ from draw_md_cluster import DBSCAN from ml_er.ml_entity_resolver import build_col_pairs_sim_tensor_dict -def plot(md_keys_, data_, data_points_, labels_, output_path_): +def plot(md_keys_, md_data_, pre_match_points_, pre_mismatch_points_, labels_, output_path_): clusterNum = len(set(labels_)) fig = plt.figure() scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown'] ax = fig.add_subplot(111, projection='3d') for i in range(-1, clusterNum): colorStyle = scatterColors[i % len(scatterColors)] - subCluster = data_[np.where(labels_ == i)] + subCluster = md_data_[np.where(labels_ == i)] ax.scatter(subCluster[:, 0], subCluster[:, 1], subCluster[:, 2], c=colorStyle, s=12) - ax.scatter(data_points_[:, 0], data_points_[:, 1], data_points_[:, 2], c='#66CCFF', s=12, marker='x') + ax.scatter(pre_match_points_[:, 0], pre_match_points_[:, 1], pre_match_points_[:, 2], c='#66CCFF', s=12, marker='x') + if pre_mismatch_points_.shape[0] > 0: + ax.scatter(pre_mismatch_points_[:, 0], pre_mismatch_points_[:, 1], pre_mismatch_points_[:, 2], c='#006666', s=12, marker='x') ax.set_xlabel(md_keys_[0], rotation=0) # 设置标签角度 ax.set_ylabel(md_keys_[1], rotation=-45) ax.set_zlabel(md_keys_[2], rotation=0) @@ -35,32 +37,37 @@ if __name__ == '__main__': predictions = outcome_path + rf'\{dataset_name}' + config_dir + r'\predictions.csv' # prediction路径 pred = pd.read_csv(predictions) pred = pred.astype(str) - pred = pred[pred['predicted'] == str(1)] + # pred = pred[pred['predicted'] == str(1)] sim_tensor_dict = build_col_pairs_sim_tensor_dict(pred) # 选取的三个字段 md_keys = [] with open(absolute_path, 'r') as f: # 读取每一行的md,加入该文件的md列表 - data = [] + md_data = [] for line in f.readlines(): md_metadata = line.strip().split('\t') md_tuple = eval(md_metadata[1]) md_keys = list(md_tuple[0].keys())[1:4] md_values = list(md_tuple[0].values()) - data.append(md_values[1:4]) - if len(data) == 10000: + md_data.append(md_values[1:4]) + if len(md_data) == 10000: break - data_points = [] - for _ in range(len(pred)): + pre_match_points = [] + pre_mismatch_points = [] + for _ in pred.itertuples(): data_point_value = [] for key in md_keys: sim_tensor = sim_tensor_dict[key] - data_point_value.append(round(float(sim_tensor[_]), 4)) - data_points.append(data_point_value) + data_point_value.append(round(float(sim_tensor[_[0]]), 4)) + if getattr(_, 'predicted') == str(1): + pre_match_points.append(data_point_value) + elif getattr(_, 'predicted') == str(0): + pre_mismatch_points.append(data_point_value) - data = np.array(data, dtype=np.float32) - data_points = np.array(data_points, dtype=np.float32) - labels = DBSCAN(data, 0.5, 30) + md_data = np.array(md_data, dtype=np.float32) + pre_match_points = np.array(pre_match_points, dtype=np.float32) + pre_mismatch_points = np.array(pre_mismatch_points, dtype=np.float32) + labels = DBSCAN(md_data, 0.5, 30) output_path = outcome_path + rf'\{dataset_name}_MD&data.png' - plot(md_keys, data, data_points, labels, output_path) + plot(md_keys, md_data, pre_match_points, pre_mismatch_points, labels, output_path)