调整聚类作图结果

MD-metrics-HPO
HuangJintao 11 months ago
parent cada3863bf
commit a2b297bf01

@ -1,3 +1,4 @@
# https://github.com/GISerWang/Spatio-temporal-Clustering.git
import os import os
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -43,7 +44,7 @@ def DBSCAN(data_, eps, minPts):
return labels return labels
def plotFeature(data_, labels_, output_path_): def plotFeature(md_keys_, data_, labels_, output_path_):
clusterNum = len(set(labels_)) clusterNum = len(set(labels_))
fig = plt.figure() fig = plt.figure()
scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown'] scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown']
@ -52,6 +53,10 @@ def plotFeature(data_, labels_, output_path_):
colorStyle = scatterColors[i % len(scatterColors)] colorStyle = scatterColors[i % len(scatterColors)]
subCluster = data_[np.where(labels_ == i)] subCluster = data_[np.where(labels_ == i)]
ax.scatter(subCluster[:, 0], subCluster[:, 1], subCluster[:, 2], c=colorStyle, s=12) ax.scatter(subCluster[:, 0], subCluster[:, 1], subCluster[:, 2], c=colorStyle, s=12)
ax.set_xlabel(md_keys_[0], rotation=0) # 设置标签角度
ax.set_ylabel(md_keys_[1], rotation=-45)
ax.set_zlabel(md_keys_[2], rotation=0)
plt.title(output_path_.split('\\')[-1].split('.')[0])
plt.savefig(output_path_, dpi=500) plt.savefig(output_path_, dpi=500)
plt.show() plt.show()
@ -62,12 +67,14 @@ if __name__ == '__main__':
dataset_name_list = [f.name for f in os.scandir(outcome_path) if f.is_dir()] dataset_name_list = [f.name for f in os.scandir(outcome_path) if f.is_dir()]
for dataset_name in dataset_name_list: for dataset_name in dataset_name_list:
absolute_path = outcome_path + rf'\{dataset_name}' + config_dir + r'\mds.txt' absolute_path = outcome_path + rf'\{dataset_name}' + config_dir + r'\mds.txt'
md_keys = []
with open(absolute_path, 'r') as f: with open(absolute_path, 'r') as f:
# 读取每一行的md加入该文件的md列表 # 读取每一行的md加入该文件的md列表
data = [] data = []
for line in f.readlines(): for line in f.readlines():
md_metadata = line.strip().split('\t') md_metadata = line.strip().split('\t')
md_tuple = eval(md_metadata[1]) md_tuple = eval(md_metadata[1])
md_keys = list(md_tuple[0].keys())[1:4]
md_values = list(md_tuple[0].values()) md_values = list(md_tuple[0].values())
data.append(md_values[1:4]) data.append(md_values[1:4])
if len(data) == 10000: if len(data) == 10000:
@ -76,4 +83,4 @@ if __name__ == '__main__':
data = np.array(data, dtype=np.float32) data = np.array(data, dtype=np.float32)
labels = DBSCAN(data, 0.5, 30) labels = DBSCAN(data, 0.5, 30)
output_path = outcome_path + rf'\{dataset_name}.png' output_path = outcome_path + rf'\{dataset_name}.png'
plotFeature(data, labels, output_path) plotFeature(md_keys, data, labels, output_path)

Loading…
Cancel
Save