紙媒体で管理するとなくなりがちなのでブログで進捗などを管理することにしました
※殆どの記事は自分自身のためだけにかいています.他人に見せられるレベルには至っていません...

【scikit-learn】線形回帰

研究で識別/分類問題について実装することになるだろうので,scikit-learnの公式ドキュメント(http://scikit-learn.org/stable/index.html)のうち
classificationを勉強していこうと思う.
今回は線形回帰について(なぜ回帰?)

scikit-learnは

#modelの定義
regression = sklearn.linear_model.LinearRegression()

#学習
#定義したモデル.fit(トレーニングデータ, トレーニングデータのラベル)
regression.fit(X_train,y_train)

#予測
#予測ラベル = 定義したモデル.predict(テストデータ)
y_pred = regression.predict(X_test)

という感じでやるらしい.
線形回帰はMSEを最小化するという方法である.

糖尿病患者

diabetesデータセット
age, sex, bmi, map, tc, ldl, hdl, tch, ltg, gluという10個の属性が入っている.
目的変数は1年後の疾患の進行状況らしい.
これらを線形回帰によって当てはめてみる.

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

# load the diabete dataset
diabetes = datasets.load_diabetes()
categories = {0:"age",1:"sex",2:"bmi",3:"map",4:"tc",5:"ldl",6:"hdl",7:"tch",8:"ltg",9:"glu"}

#全体の図を定義
plt.figure()

for key,val in categories.items():
    print(key)
  #データをある属性のみに絞る
    diabetes_X = diabetes.data[:, np.newaxis, key]
    #データをトレーニングとテストに分割
    diabetes_X_train = diabetes_X[:-20]
    diabetes_X_test  = diabetes_X[-20:]
    #目的変数をトレーニングとテストに分割
    diabetes_y_train = diabetes.target[:-20]
    diabetes_y_test  = diabetes.target[-20:]
    #線形回帰モデルの宣言
    regr = linear_model.LinearRegression()
    #学習
    regr.fit(diabetes_X_train,diabetes_y_train)
    #予測
    diabetes_y_pred = regr.predict(diabetes_X_test)
    #係数
    print("Cofficients:\n",regr.coef_)
    #MSE計算
    print("mean squared error:%.2f"%(mean_squared_error(diabetes_y_test,diabetes_y_pred)))
    #分散
    print("Variance score:%.2f"%r2_score(diabetes_y_test,diabetes_y_pred))
    #subplot
    fig_loc = "25" + str(key)
    plt.subplot(fig_loc)
    #plot outputs
    plt.scatter(diabetes_X_test,diabetes_y_test,color="black")
    plt.plot(diabetes_X_test,diabetes_y_pred,color="blue",linewidth=3)
    #plt.xticks(())
    #plt.yticks(())
    plt.title(categories[key])
plt.show()

結果はこのようになった.
sexのような2値に対しては当然ながら線形回帰は使えない.(そもそも使う意味がない.)
bmiやhdl, ltgあたりは割りと線形回帰で良い結果が出た方だと思う.
f:id:umashika5555:20170924230252p:plain