【机器学习】knn算法应用(代码)

152
0
2020年12月14日 09时17分

一、什么是knn算法

 

本章着重对算法部分进行讲解,原理部分不过多叙述,有兴趣的小伙伴可以自行查阅其他文献/文章

 

(一)、介绍

 

  • 邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

 

  • kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。

 

  • 由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。

 

20191231142238515

 

右图中,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

 

(二)、算法流程

1.准备数据,对数据进行预处理

 

2.选用合适的数据结构存储训练数据和测试元组

 

3.设定参数,如k

 

4.维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列

 

5.遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L 与优先级队列中的最大距离Lmax

 

6.进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。

 

7.遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。

 

8.测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。

 

 

二、代码

例子——sklearn鸢尾花

 

由Fisher在1936年整理,包含4个特征(Sepal.Length(花萼长度)、Sepal.Width(花萼宽度)、Petal.Length(花瓣长度)、Petal.Width(花瓣宽度)),特征值都为正浮点数,单位为厘米。目标值为鸢尾花的分类(Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),Iris Virginica(维吉尼亚鸢尾))。

 

sklearn例子

 

sklearn提供了一个非常简单快捷的knn调用方式我们直接调用相关函数就可以实现knn预测

 

from sklearn import neighbors

from sklearn import datasets

knn = neighbors.KNeighborsClassifier() # 创建分类器对象

iris = datasets.load_iris() # 加载数据集

print(iris) # 打印数据集

knn.fit(iris.data, iris.target) # 分类

predictedLabel = knn.predict([[0.1, 0.2, 0.3, 0.4]])  # 预测

print(predictedLabel) # 打印预测结果

 

 

这个例子的运行效果就不展示出来了,大家可以复制代码自行查看

 

具体代码

 

如果你安装的是Anaconda,csv文件可在 安装目录\Anaconda3\pkgs\scikit-learn-0.19.0-py36h294a771_2\Lib\site-packages\sklearn\datasets\data里面找到

 


# 3.KNN实现Implementation:

import csv

import random

import math

import operator

trainingSet = []

testSet = []

split = 0.67 # 划分训练集和测试机的一个参数


def loadDataset(filename):
	# 加载数据集,输入参数为数据集所在路径
    with open(filename, 'r') as csvfile:
        lines = csv.reader(csvfile)
        dataSet = list(lines)
        for x in range(1, len(dataSet) - 1):
            for y in range(4):
                dataSet[x][y] = float(dataSet[x][y])
            if random.random() < split:
                trainingSet.append(dataSet[x])
            else:
                testSet.append(dataSet[x])


def euclideanDistance(instance1, instance2, length):
    distance = 0

    for x in range(length):
        distance += pow((instance1[x] - instance2[x]), 2)

    return math.sqrt(distance)


def getNeighbors(trainingSet, testInstance, k):
    distances = []

    length = len(testInstance) - 1

    for x in range(len(trainingSet)):
        dist = euclideanDistance(testInstance, trainingSet[x], length)

        distances.append((trainingSet[x], dist))

    distances.sort(key=operator.itemgetter(1))

    neighbors = []

    for x in range(k):
        neighbors.append(distances[x][0])

    return neighbors


def getResponse(neighbors):
    classVotes = {}

    for x in range(len(neighbors)):

        response = neighbors[x][-1]

        if response in classVotes:

            classVotes[response] += 1

        else:

            classVotes[response] = 1

    sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)

    return sortedVotes[0][0]


def getAccuracy(testSet, predictions):
    correct = 0

    for x in range(len(testSet)):

        if testSet[x][-1] == predictions[x]:
            correct += 1

    return (correct / float(len(testSet))) * 100.0


def main():
    # 预处理

    loadDataset(r'iris.csv')

    print('Train set: ' + repr(len(trainingSet)))

    print('Test set: ' + repr(len(testSet)))

    # 生成预测模型

    predictions = []

    k = 3

    for x in range(len(testSet)):
        neighbors = getNeighbors(trainingSet, testSet[x], k)

        result = getResponse(neighbors)

        predictions.append(result)

        print('> predicted=' + repr(result) + ', actual=' + repr(testSet[x][-1]))
	
	# 计算准确率
    accuracy = getAccuracy(testSet, predictions)

    print('Accuracy: ' + repr(accuracy) + '%')


main()


 

结果:

 

Train set: 96
Test set: 53
> predicted='0', actual='0'
> predicted='0', actual='0'
> predicted='0', actual='0'
> predicted='0', actual='0'
> predicted='0', actual='0'
> predicted='0', actual='0'
> predicted='0', actual='0'
.
.
.
.
Accuracy: 98.11320754716981%

 

发表评论

后才能评论