見出し画像

SciKit-Learn,Keras,PyTorchの違いってなに?- Scikit-learn

Pythonを使って機械学習、ディープラーニングを行うときに使うものとして、SciKit-Learn,Keras,PyTorchがよく出てきます。

何が違うかわかりにくいのでちょっと整理してみます。

scikit-learnは、機械学習ライブラリ。サポートベクターマシン、ランダムフォレストなどのモデル(機械学習アルゴリズム)を使って実行します。scikit-learnのモデルについては以下参考になります。

それに対して、Keras,PyTorchはディープラーニングのフレームワークで、フレームワーク自身でモデルを最適化して実行することができるものです。

scikit-learnが人が決めたアルゴニズムで動くことに対してKeras,PyTorchはフレームワーク自身で試行錯誤して最適なパラメーターを構築して実行することができるものとなっています。

SciKit-Learn。公式サイトと参考サイトです。

以下をGoogleのColabで実行してみます。参考サイトのコードに追加で予測するコードを追加しています。

RandomForestClassifierを使っています。

from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split

# irisデータの読み込み
iris = datasets.load_iris()

# 特徴量とターゲットの取得
data       = iris['data']
target     = iris['target']

#学習データとテストデータを分割
train_data,test_data,train_target,test_target = train_test_split(data,target,test_size=0.5)

#モデル学習
model = RandomForestClassifier(n_estimators=100)
model.fit(train_data, train_target)

# 正解率
score = model.score(test_data, test_target)
print(score)

#予測
predict = model.predict(test_data)
print(predict)

学習した精度の結果が

score = model.score(test_data, test_target)
print(score)

で計算され、0.9733333333333334と出てきます。97%の精度ということです、

予測は

predict = model.predict(test_data)
print(predict)

で計算でき、

[2 2 0 2 1 2 2 2 2 1 0 2 2 1 2 1 2 2 0 2 0 1 2 1 1 0 2 1 1 1 1 0 0 1 0 0 2
1 2 0 0 1 0 2 1 2 0 0 0 2 1 2 2 0 0 1 0 0 0 0 2 2 2 2 0 0 2 2 2 1 1 2 0 1
2]

0,1,2ということで3つの種類の"iris"「あやめ」が数字で表されています。テスト結果が97%ぐらいの正解率ということでかなり良い結果ですね。

ここでのポイントは以下の部分で

# 特徴量とターゲットの取得
data = iris['data']
target = iris['target']

#学習データとテストデータを分割
train_data,test_data,train_target,test_target = train_test_split(data,target,test_size=0.5)

train_test_split()という命令で、トレーニングデータと、正解データ、問題のデータを分割して使えるように準備し、以下でモデルを作成、そしてモデルにデータを当てて学習させています。

#モデル学習
model = RandomForestClassifier(n_estimators=100)
model.fit(train_data, train_target)

あとはこの"model"を使って正答率と、あやめの種類を予測させています。

# 正解率
score = model.score(test_data, test_target)
print(score)

#予測
predict = model.predict(test_data)
print(predict)


この記事が気に入ったらサポートをしてみませんか?