knn演算法的原理與實現

knn演算法的原理與實現

來自專欄機器學習知識點16 人贊了文章

機器學習基礎演算法python代碼實現可參考:zlxy9892/ml_code

1 原理

knn 是機器學習領域非常基礎的一種演算法,可解決分類或者回歸問題,如果是剛開始入門學習機器學習,knn是一個非常好的入門選擇,它有著便於理解,實現簡單的特點,那麼下面就開始介紹其演算法的原理。

首先,knn演算法的基本法則是:相同類別的樣本之間在特徵空間中應當聚集在一起。

如下圖所示,假設我們現在紅、綠、藍三種顏色的點,分布在二維空間中,這就對應了分類任務中的訓練樣點包含了三個類別,且特徵數量為2。如果現在我們希望推測圖中空心圓的那個點是屬於那個類別,那麼knn演算法將會計算該待推測點與所有訓練樣點之間的距離,並且挑選出距離最小的k個樣點(此處設定k=4),則圖中與連接的4個點將被視為推測空心點(待推測點)類別的參考依據。顯然,由於這4個點均為紅色類別,則該待推測點即被推測為紅色類別。

knn圖解1

再看另一種情況,如果待推測點在中間的某個位置(如下圖所示),則同樣也計算出與其最鄰近的4個樣本點,而此時這4個樣本點包含了3個類別(1紅、1藍、2綠),針對這樣的情況,knn演算法通常採用投票法來進行類別推測,即找出k個樣本點中類別出現次數最多的那個類別,因此該待推測點的類型值即被推測為綠色類別。

knn圖解2

knn的原理就是這麼樸素而簡單,那麼有人會說,這麼簡單的演算法,相比於現在流行的其他較為複雜的機器學習演算法,比如神經網路、隨機森林等,有什麼存在的價值呢?

回答這個問題,我們不妨回想一下機器學習中著名的沒有免費午餐定理(no free lunch theorem),對於所有的任意一個問題(學習任務)來說,並不存在最好的模型,反之恰恰提醒了我們,對於特定的學習任務,我們需要去考慮最適合該問題學習器,也就是具體問題具體分析的哲學道理。

那麼,knn也必然有其存在的價值,比如說,我們現在拿到一個學習任務,需要去選擇一個學習器去解決該問題,而且也沒有任何對該問題研究的前車之鑒,那麼,從何下手?通常,我們不需要上來就用一個神經網路模型或者強大的集成學習模型去做,而是可以先用一用簡單模型做一下「試探」,比如knn就是一個很好的選擇,這樣的「試探」的好處在哪裡呢?我們知道,knn本質上屬於懶惰學習的代表,也就是說它根本就沒有用訓練數據去擬合一個什麼模型,而是直接用top-k個近鄰的樣本點做了個投票就完成了分類任務,那麼如果這樣一個懶惰模型在當前的問題上就已經能夠得到一個較高的精度,則我們可以認為當前的學習任務是比較簡單的,不同類別的樣本點在特徵空間中的分布較為清晰,無需採用複雜模型,反之,若knn得到的精度很低,則傳達給我們的信息是:該學習任務有點複雜,往往伴隨著的消息就是,當前問題中不同類別樣本點在特徵空間中分布不是清晰,通常是非線性可分的,需要我們去調用更強大的學習器。從而,一個簡單的knn機器學習演算法恰恰可以幫助建模者對問題的複雜度有一個大概的判斷,協助我們接下來如何展開進一步的工作:繼續挖掘特徵工程、或者是更換複雜模型等。

2 演算法實現

下面使用python自己親手實現一個knn分類器。

代碼鏈接 : github.com/zlxy9892/ml_

首先創建一個名為knn.py的文件,用於構建knn實現的類,代碼如下。

# -*- coding: utf-8 -*-import numpy as npimport operatorclass KNN(object): def __init__(self, k=3): self.k = k def fit(self, x, y): self.x = x self.y = y def _square_distance(self, v1, v2): return np.sum(np.square(v1-v2)) def _vote(self, ys): ys_unique = np.unique(ys) vote_dict = {} for y in ys: if y not in vote_dict.keys(): vote_dict[y] = 1 else: vote_dict[y] += 1 sorted_vote_dict = sorted(vote_dict.items(), key=operator.itemgetter(1), reverse=True) return sorted_vote_dict[0][0] def predict(self, x): y_pred = [] for i in range(len(x)): dist_arr = [self._square_distance(x[i], self.x[j]) for j in range(len(self.x))] sorted_index = np.argsort(dist_arr) top_k_index = sorted_index[:self.k] y_pred.append(self._vote(ys=self.y[top_k_index])) return np.array(y_pred) def score(self, y_true=None, y_pred=None): if y_true is None and y_pred is None: y_pred = self.predict(self.x) y_true = self.y score = 0.0 for i in range(len(y_true)): if y_true[i] == y_pred[i]: score += 1 score /= len(y_true) return score

接下來創建train.py文件,用於生成隨機的訓練樣本數據,並調用knn類完成分類任務,並計算推測精度。

# -*- coding: utf-8 -*-import numpy as npimport matplotlib.pyplot as pltfrom knn import *# data generationnp.random.seed(314)data_size_1 = 300x1_1 = np.random.normal(loc=5.0, scale=1.0, size=data_size_1)x2_1 = np.random.normal(loc=4.0, scale=1.0, size=data_size_1)y_1 = [0 for _ in range(data_size_1)]data_size_2 = 400x1_2 = np.random.normal(loc=10.0, scale=2.0, size=data_size_2)x2_2 = np.random.normal(loc=8.0, scale=2.0, size=data_size_2)y_2 = [1 for _ in range(data_size_2)]x1 = np.concatenate((x1_1, x1_2), axis=0)x2 = np.concatenate((x2_1, x2_2), axis=0)x = np.hstack((x1.reshape(-1,1), x2.reshape(-1,1)))y = np.concatenate((y_1, y_2), axis=0)data_size_all = data_size_1+data_size_2shuffled_index = np.random.permutation(data_size_all)x = x[shuffled_index]y = y[shuffled_index]split_index = int(data_size_all*0.7)x_train = x[:split_index]y_train = y[:split_index]x_test = x[split_index:]y_test = y[split_index:]# visualize dataplt.scatter(x_train[:,0], x_train[:,1], c=y_train, marker=.)plt.show()plt.scatter(x_test[:,0], x_test[:,1], c=y_test, marker=.)plt.show()# data preprocessingx_train = (x_train - np.min(x_train, axis=0)) / (np.max(x_train, axis=0) - np.min(x_train, axis=0))x_test = (x_test - np.min(x_test, axis=0)) / (np.max(x_test, axis=0) - np.min(x_test, axis=0))# knn classifierclf = KNN(k=3)clf.fit(x_train, y_train)print(train accuracy: {:.3}.format(clf.score()))y_test_pred = clf.predict(x_test)print(test accuracy: {:.3}.format(clf.score(y_test, y_test_pred)))

至此,knn演算法原理闡述與代碼實現介紹完畢。

推薦閱讀:

Excel x Python的奇妙反應
黃哥Python:回答知乎上二個有關for 循環問題
入門:用Python抓取網頁上的免費賬號(六)
入門:用Python抓取網頁上的免費賬號(五)
TMDb電影數據可視化分析

TAG:機器學習 | Python |