From a2b297bf016bdd3f5879045ca4bba86628a33bc5 Mon Sep 17 00:00:00 2001 From: HuangJintao <1447537163@qq.com> Date: Sun, 31 Dec 2023 15:36:50 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E8=81=9A=E7=B1=BB=E4=BD=9C?= =?UTF-8?q?=E5=9B=BE=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- draw/draw_md_cluster.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/draw/draw_md_cluster.py b/draw/draw_md_cluster.py index 7880a16..d106a81 100644 --- a/draw/draw_md_cluster.py +++ b/draw/draw_md_cluster.py @@ -1,3 +1,4 @@ +# https://github.com/GISerWang/Spatio-temporal-Clustering.git import os import numpy as np import matplotlib.pyplot as plt @@ -43,7 +44,7 @@ def DBSCAN(data_, eps, minPts): return labels -def plotFeature(data_, labels_, output_path_): +def plotFeature(md_keys_, data_, labels_, output_path_): clusterNum = len(set(labels_)) fig = plt.figure() scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown'] @@ -52,6 +53,10 @@ def plotFeature(data_, labels_, output_path_): colorStyle = scatterColors[i % len(scatterColors)] subCluster = data_[np.where(labels_ == i)] ax.scatter(subCluster[:, 0], subCluster[:, 1], subCluster[:, 2], c=colorStyle, s=12) + ax.set_xlabel(md_keys_[0], rotation=0) # 设置标签角度 + ax.set_ylabel(md_keys_[1], rotation=-45) + ax.set_zlabel(md_keys_[2], rotation=0) + plt.title(output_path_.split('\\')[-1].split('.')[0]) plt.savefig(output_path_, dpi=500) plt.show() @@ -62,12 +67,14 @@ if __name__ == '__main__': dataset_name_list = [f.name for f in os.scandir(outcome_path) if f.is_dir()] for dataset_name in dataset_name_list: absolute_path = outcome_path + rf'\{dataset_name}' + config_dir + r'\mds.txt' + md_keys = [] with open(absolute_path, 'r') as f: # 读取每一行的md,加入该文件的md列表 data = [] for line in f.readlines(): md_metadata = line.strip().split('\t') md_tuple = eval(md_metadata[1]) + md_keys = list(md_tuple[0].keys())[1:4] md_values = list(md_tuple[0].values()) data.append(md_values[1:4]) if len(data) == 10000: @@ -76,4 +83,4 @@ if __name__ == '__main__': data = np.array(data, dtype=np.float32) labels = DBSCAN(data, 0.5, 30) output_path = outcome_path + rf'\{dataset_name}.png' - plotFeature(data, labels, output_path) + plotFeature(md_keys, data, labels, output_path)