You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
2.0 KiB
2.0 KiB
4.3 动手实现k近邻
k近邻算法非常简单,可以很方便的使用python
实现,实现代码如下:
#encoding=utf8
import numpy as np
def knn_clf(k,train_feature,train_label,test_feature):
'''
input:
k(int):最近邻样本个数
train_feature(ndarray):训练样本特征
train_label(ndarray):训练样本标签
test_feature(ndarray):测试样本特征
output:
predict(ndarray):测试样本预测标签
'''
#初始化预测结果
predict = np.zeros(test_feature.shape[0],).astype('int')
#对测试集每一个样本进行遍历
for i in range(test_feature.shape[0]):
#测试集第i个样本到训练集每一个样本的距离
distance = np.sqrt(np.power(np.tile(test_feature[i],(train_feature.shape[0],1))-train_feature,2).sum(axis=1))
#最近的k个样本的距离
distance_k = np.sort(distance)[:k]
#最近的k个样本的索引
nearest = np.argsort(distance)[:k]
#最近的k个样本的标签
topK = [train_label[i] for i in nearest]
#初始化进行投票的字典,字典的键为标签,值为投票分数
votes = {}
#初始化最大票数
max_count = 0
#进行投票
for j,label in enumerate(topK):
#如果标签在字典的键中则投票计分
if label in votes.keys():
votes[label] += 1/(distance_k[j]+1e-10)#防止分母为0
#如果评分最高则将预测值更新为对应标签
if votes[label] > max_count:
max_count = votes[label]
predict[i] = label
#如果标签不在字典中则将标签加入字典的键,同时计入相应的分数
else:
votes[label] = 1/(distance_k[j]+1e-10)
if votes[label] > max_count:
max_count = votes[label]
predict[i] = label
return predict