k-NN法をフルスクラッチ実装

今日はk-NN法(k-NearestNeighbor)を実装する.

k-NN法の概念

k-NN法は教師データから未知データを予測する教師あり学習の一つで, また, 線形回帰などのようにパラメータを最適化するような手法を取らない.
すなわち, 「教師データから識別境界のようなものを学習してから未知データが境界のどちらにあるか」という手法をとらずに, 「それぞれの未知データに対して教師データとの関係を見てどちらのクラスかを予測する」という方法をとる. これを怠惰学習というらしい.
k-NN法に関しては, 名前の通り, 未知データに対して最も近い教師データk個のうち多い方のクラスラベルを予測クラスとする.
まず下図のようなデータがあるとする. 赤色がクラス1, 緑色がクラス0 とする. 星がクラスラベルのわかっていない未知データである.
この未知データがどちらのクラスに入れた方が良いのかを予測したい.

f:id:umashika5555:20181223013754p:plain
データ


今k=3すると未知データから最も近い3つの教師データを抽出できる.
f:id:umashika5555:20181223020457p:plain
3-NN
赤色(クラス1)のデータが1つで緑色(クラス0)のデータが2つである.
よって多数決をとって, 緑(クラス0)データの方が多いので, 未知データは緑(クラス0)と予想する.
式で表すとこのようになるが, わざわざ式で表さなくてもよい.
f:id:umashika5555:20181223015047g:plain

データの生成

k-NN法を実装するために二次元のデータを生成する.
k-NN法は分布を仮定した方法ではないため, 分布から生成する必要は無いと思うが, ラベル付きのデータを生成したかったので取り敢えずガウス分布を使ってデータを生成した.

f:id:umashika5555:20181223015313p:plain
データ

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# 乱数固定
np.random.seed(0)
# 平均ベクトル
Mu1, Mu2 = np.array([1, 10]), np.array([9,2])
# 分散共分散行列
sigma11, sigma12, sigma21, sigma22 = 4, 5,  6, 3
SIGMA1, SIGMA2 = np.array([[sigma11**2, np.sqrt(sigma11*sigma12)],[np.sqrt(sigma11*sigma12), sigma12**2]]), np.array([[sigma21**2, np.sqrt(sigma21*sigma22)],[np.sqrt(sigma21*sigma22), sigma22**2]])
# データ生成
values1 = np.random.multivariate_normal(Mu1, SIGMA1, 200)
values2 = np.random.multivariate_normal(Mu2, SIGMA2, 200)
# データ保存
np.save("./data/values1.npy", values1)
np.save("./data/values2.npy", values2)
# 散乱図の表示
plt.scatter(values1[:,0], values1[:,1], color="r")
plt.scatter(values2[:,0], values2[:,1], color="g")
plt.xlabel("x")
plt.ylabel("y")
plt.xlim(-10, 30)
plt.ylim(-10, 25)
plt.savefig("./img/data.png")
plt.show()

k-NN法による未知データの分類

格子点を未知データとして考え領域分割を行う.
格子点から各データ点までの距離を計算し, 最も近いデータ点k個のラベルから多数決を行う.
閾値が0.5のとき, kが奇数の場合は多数決が成立するのだが, kが偶数の場合で各ラベルが同じ個数あるとき多数決が成立しない.
よってkは奇数にするほうが良いことが分かる.
またkが大きくなると, 汎用性が高くなることが分かる.

f:id:umashika5555:20181223015636p:plain
k=1
f:id:umashika5555:20181223015651p:plain
k=2
f:id:umashika5555:20181223015705p:plain
k=3
f:id:umashika5555:20181223015720p:plain
k=4
f:id:umashika5555:20181223015818p:plain
k=5

f:id:umashika5555:20181223015836p:plain
k=7

f:id:umashika5555:20181223015854p:plain
k=9

f:id:umashika5555:20181223151116p:plain
k=51

import numpy as np
import matplotlib.pyplot as plt

# k-NN法のパラメータkを設定
k = 200
# データのロード
values1 = np.load("./data/values1.npy")
values2 = np.load("./data/values2.npy")

# データを
# 統合, その際に教師情報もつける
data = np.vstack((values1, values2))
labels = np.hstack((np.ones(len(values1)), np.zeros(len(values2))))

# 未知のデータとして格子点を作成
# 2つのデータから最小, 最大のx, yを探す
x_min, x_max = min(np.min(values1[:,0]), np.min(values2[:,0])), max(np.max(values1[:,0]), np.max(values2[:,0]))
y_min, y_max = min(np.min(values1[:,1]), np.min(values2[:,1])), max(np.max(values1[:,1]), np.max(values2[:,1]))
# 格子点を作成
xx = np.linspace(x_min-1, x_max+1, 300)
yy = np.linspace(y_min-1, y_max+1, 300)
xxx, yyy = np.meshgrid(xx, yy)# 格子点(xxx, yyy)が生成された. これを未知データとみなしてk-NN法を適用する.

class0 = []
class1 = []
thresholds = []

# k-NN法の適用
for i, (xx, yy) in enumerate(zip(xxx, yyy)):
    for j, (x, y) in enumerate(zip(xx, yy)):
        # 未知データ(x,y)から既知データdata各点への距離を計算する
        # 距離はユークリッド距離の二乗, すなわちl2ノルムの二乗を計算する
        tmp = data - (x, y)
        distances = np.linalg.norm(tmp, axis=1)
        # 最も小さいスコア上位k点のラベルを参照
        supervisers_index = np.argsort(distances)[:k]
        # print(supervisers_index)
        supervisers = labels[supervisers_index]
        # print(supervisers)
        res = np.sum(supervisers)/ k 
        # if res > 0.5 then x_res in class1
        if res > 0.5:
            class1.append((x, y))
        elif res == 0.5:
            thresholds.append((x, y))
        else:
            class0.append((x, y))
        
class0, class1, thresholds = np.array(class0), np.array(class1), np.array(thresholds)
plt.scatter(values1[:,0], values1[:,1], color="r", marker="x", s=20)
plt.scatter(values2[:,0], values2[:,1], color="g", marker="x", s=20)
plt.scatter(class1[:,0], class1[:,1], color="r", marker="o", s=1, alpha=0.1)
plt.scatter(class0[:,0], class0[:,1], color="g", marker="o", s=1, alpha=0.1)
try:# 奇数の場合, 必ず多数決が成立するので閾値(0.5)とイコールになるものがないためエラー処理しておく
    plt.scatter(thresholds[:,0], thresholds[:,1], color="b", s=1, alpha=0.1)
except:
    pass

plt.title("./img/{}-NN methods".format(k))
plt.xlabel("x")
plt.ylabel("y")
plt.xlim(-9, 23)
plt.ylim(-7, 23)
plt.savefig("./img/{}-NN.png".format(k))
plt.show()