LDA(線形判別分析)
LDA(線形判別分析)とは?
LDA(線形判別分析)は、統計学および機械学習において使用される手法で、クラス間の分離を最大化するような方法でデータを低次元に射影することによって、クラス識別や次元削減を行うものです。
この記事ではscikit-learnライブラリに実装されているLinearDiscriminantAnalysisクラスを利用して線形判別分析を行ってみたいと思います。
データセットの読み込み
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
今回はscikit-learnのデータセットにあるiris(アヤメ)データセットを利用します。特徴量としては4個、データポイントは150個のデータです。また目的変数として3種の花の種類があります。
データセットの分割
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
訓練用とテスト用にデータを分割します。
LDAの学習
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
lda = LinearDiscriminantAnalysis(n_components=2)
lda.fit(X_train, y_train)
LinearDiscriminantAnalysisクラスを出力する次元数を2に設定して初期化しています。LDAモデルの学習を行います。
学習したLDAモデルでデータを射影
X_lda = lda.transform(X)
transformメソッドは、データをLDAによって見つけた最適な射影軸に射影します。
射影した結果を視覚化
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 6))
marker_shapes = 'so^'
colors = ['red', 'green', 'blue']
labels = iris.target_names
for i, label in enumerate(labels):
plt.scatter(X_lda[y == i, 0], X_lda[y == i, 1], alpha=0.8, c=colors[i], marker=marker_shapes[i], label=label)
plt.xlabel('LD1')
plt.ylabel('LD2')
plt.title('LDA: Iris dataset')
plt.legend(loc='best')
plt.show()