機械学習図鑑をよんでる

(理解できていないことが多いので、いくつか思い違いをしていたことなど随時追記修正します。すみません。)

ようやくPythonのサンプルコードをすこし扱えるようになってきたので、メモがてら足跡として残しておこうかと思った。

多項式回帰と正則化の話で。サンプルコードを理解するためにコメントを加えながら勉強中。

import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt


# トレーニングデータとテストデータをつくる。
train_size = 20
test_size = 12
train_X = np.random.uniform(low=0, high=1.2, size=train_size)
test_X = np.random.uniform(low=0.1, high=1.3, size=test_size)
train_y = np.sin(train_X * 2 * np.pi) + np.random.normal(0, 0.2, train_size)
test_y = np.sin(test_X * 2 * np.pi) + np.random.normal(0, 0.2, test_size)

# 多項式回帰モデルのインスタンスを作成。
poly = PolynomialFeatures(6) # 第一引数は次数

# reshapeで、1行train_X列からtrain_X行1列に。
# fit_transformは。イメージとしては学習した結果のパラメータをかえしてくる。
# PolynomialFeaturesだとxに何がしか与えたときの、1, x, x^2, ...が戻り値になる。
# この時点では、Ridge回帰で利用する損失関数計算用の特徴量を計算しているだけ。
train_poly_X = poly.fit_transform(train_X.reshape(train_size, 1))
test_poly_X = poly.fit_transform(test_X.reshape(test_size, 1))

# ここでRidge回帰を使って学習。(正則化つき)
model = Ridge(alpha=1.0)
model.fit(train_poly_X, train_y)

# モデルに沿って計算
train_pred_y = model.predict(train_poly_X)
test_pred_y = model.predict(test_poly_X)

# 誤差を計算。
print(mean_squared_error(train_pred_y, train_y))
print(mean_squared_error(test_pred_y, test_y))

# 学習データのプロット
plt.scatter(train_X, train_y)

# 多項式回帰の概形
graph_x = np.linspace(0, 1.0)
graph_poly_x = poly.fit_transform(graph_x.reshape(graph_x.size, 1))
graph_y = model.predict(graph_poly_x)
plt.plot(graph_x, graph_y)
plt.show()

fitとか、fit_transformとか、何をしているのかいまいちわからなかったけど以下のページに書いてあることを元にこうかなぁ、と思うことをつらつら書いてみる。

fit関数は、与えられたデータにそって多項式を作る。

で、transformは与えられたデータをfitしたパラメータに沿って変換する。少なくともPolynomialFeaturesクラスのtransformにはいわゆる学習の過程はないように思うが。。

fit_transformの挙動は、モデルによって違うんだろうな、とは思いつつ。PolynomialFeaturesでは、以下の関数のx^0, x^1, x^2...を1行として、元データの数だけ行を返す。

この系列に、Ridge回帰を適用することで、予測用のモデルを作る。

イメージとしては、フィルタリングのような、行列の演算だけを抜き出して、そこだけ計算するような感じだと思うが、何せまだわかっていないことが多い。

とりあえず今日はここまで。


この記事が気に入ったらサポートをしてみませんか?