LightGBMで二値分類モデルを構築する

LightGBMとは?

LightGBM(Light Gradient Boosting Machine)は、Microsoft社が開発した勾配ブースティング(Gradient Boosting)アルゴリズムをベースとした機械学習ライブラリです。

この記事では、LightGBMで二値分類モデルを構築してみます。

データセットの用意

今回、KaggleにあるTitanicのデータセットを利用したいと思います。このデータセットは、1912年に沈没したタイタニック号の乗客の情報を含んでおり、生存者を予測するためのモデルを構築することが目的となっています。

import pandas as pd
import numpy as np

df = pd.read_csv("./titanic.csv")

特徴量エンジニアリング

学習を行っていくために、数値特徴量やカテゴリカル特徴量に対して前処理を実施しています。

from sklearn.preprocessing import StandardScaler

def extract_cabin_number(cabin):
    if isinstance(cabin, str):
        cabin_number = "".join(filter(str.isdigit, cabin))
        if cabin_number:
            return int(cabin_number)
    return np.nan


df["Cabin"] = df["Cabin"].apply(extract_cabin_number)
df["Cabin"] = df["Cabin"].fillna(df["Cabin"].mean())
df["Age"] = df["Age"].fillna(df["Age"].mean())
df["Sex"] = df["Sex"].replace({"male": 1, "female": 0})
df["Embarked"] = df["Embarked"].replace({"S": 0, "C": 1, "Q": 2})
df["Embarked"] = df["Embarked"].fillna(df["Embarked"].mean())

col = ["Age", "Fare", "Cabin"]
scaler = StandardScaler()
df[col] = scaler.fit_transform(df[col])

df.drop(columns=["Name", "PassengerId", "Ticket"], axis=1, inplace=True)

目的変数設定

X = df.drop(columns=['Survived'])
y = df['Survived']

LightGBMモデルの構築

学習用にデータを分割し、LightGBM用データセットの作成した後にモデルの学習を行います。

lgb_paramsはLightGBMのハイパーパラメータを設定しています。

  • "objective": "binary": 二値分類を目的として学習します。

  • "metric": "binary_logloss": 評価指標の指定

  • "verbosity": -1: ログの出力を抑制します。

import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

lgb_params = {"objective": "binary", "metric": "binary_logloss", "verbosity": -1}
lgb_train = lgb.Dataset(X_train, y_train)

lgb = lgb.train(lgb_params, lgb_train)
y_pred = lgb.predict(X_test)

評価

予測結果を二値分類として評価するために、閾値を用いて生存したかどうかのクラスに変換し、正解率を算出します。

y_pred = [1 if pred > 0.5 else 0 for pred in y_pred]
accuracy_score(y_test, y_pred)

結果は下記の通りになりました。

0.8100558659217877