知識不要で機械学習!?Pycaretでモデルを作る

Pycaretとは

以下公式からの抜粋

PyCaret is an open source, low-code machine learning library in Python that allows you to go from preparing your data to deploying your model within minutes in your choice of notebook environment.​

数行コードを書くだけで、前処理、学習、可視化などを行ってくれるツールになります。

この記事で行う内容

前処理・学習・推論

実行環境

Google Colaboratory

Pycaret 2.0

Pycaretのインストール

Google Colaboratory上で以下のコードを実行します。

今回はver2.0をインストールするのでversionを指定しインストールしています。versionを指定しなかった場合は最新がインストールされます。

!pip install pycaret==2.0

ライブラリのimport

Pycaretをimportします

今回は分類問題を扱うため、「from pycaret.classification import *」をimportします。

回帰問題を扱う場合は「from pycaret.regression import *」をimportしてください。

# 今回は2値分類を行うので、classificationをimport
from pycaret.classification import *
from pycaret.datasets import get_data
# Google Colabを用いる場合、以下のコードでインタラクティブな表示をする
from pycaret.utils import enable_colab
enable_colab()

データの準備

Pycaretで用意されている「クレジット」のデータを使用します。

このデータはバイナリ分類(2値分類)用のデータとなります。

多クラス分類だとワインのデータなどが準備されています。

インストールできるデータセットは公式ページを参照ください!

※インターネットに繋がっていないと、インストールできません。

dataset = get_data('credit')

実行するとデータの中身が表示されます

目的変数は「default」というカラムで0 or 1が入っています。

スクリーンショット 2020-09-19 10.04.08

前処理の実行

1行書くだけで前処理の実行ができます。

 exp1 = setup(dataset, target='default')

前処理では、trainデータ・testデータの分割、欠損値の補完など行ってくれます。詳しくは公式ページをご確認ください。

実行した結果をみて見ます!

trainデータ、testデータ説明変数に関する情報が入っているようです。

(       LIMIT_BAL   AGE  BILL_AMT1  ...  PAY_6_6  PAY_6_7  PAY_6_8
0        20000.0  24.0     3913.0  ...      0.0      0.0      0.0
1        90000.0  34.0    29239.0  ...      0.0      0.0      0.0
2        50000.0  37.0    46990.0  ...      0.0      0.0      0.0
3        50000.0  57.0     8617.0  ...      0.0      0.0      0.0
4        50000.0  37.0    64400.0  ...      0.0      0.0      0.0
...          ...   ...        ...  ...      ...      ...      ...
23995    80000.0  34.0    72557.0  ...      0.0      0.0      0.0
23996   150000.0  43.0     1683.0  ...      0.0      0.0      0.0
23997    30000.0  37.0     3565.0  ...      0.0      0.0      0.0
23998    80000.0  41.0    -1645.0  ...      0.0      0.0      0.0
23999    50000.0  46.0    47929.0  ...      0.0      0.0      0.0

[24000 rows x 91 columns], 0        1
1        0
2        0
3        0
4        0
        ..
23995    1
23996    0
23997    1
23998    1
23999    1
Name: default, Length: 24000, dtype: int64,        LIMIT_BAL   AGE  BILL_AMT1  ...  PAY_6_6  PAY_6_7  PAY_6_8
1776     50000.0  41.0    21910.0  ...      0.0      0.0      0.0
22839    10000.0  33.0     7486.0  ...      0.0      0.0      0.0
4827    500000.0  50.0   265803.0  ...      0.0      0.0      0.0
17937    70000.0  48.0     3838.0  ...      0.0      0.0      0.0
2289     50000.0  27.0    43333.0  ...      0.0      0.0      0.0
...          ...   ...        ...  ...      ...      ...      ...
20299    20000.0  28.0     6064.0  ...      0.0      0.0      0.0
7309    110000.0  23.0    34486.0  ...      0.0      0.0      0.0
14968    50000.0  25.0     9587.0  ...      0.0      0.0      0.0
4922     70000.0  27.0    76062.0  ...      0.0      0.0      0.0
405      30000.0  53.0        0.0  ...      0.0      0.0      0.0

[16799 rows x 91 columns],        LIMIT_BAL   AGE  BILL_AMT1  ...  PAY_6_6  PAY_6_7  PAY_6_8
21592   100000.0  32.0    99573.0  ...      0.0      0.0      0.0
11487    20000.0  25.0     3556.0  ...      0.0      0.0      0.0
15178   330000.0  28.0   221793.0  ...      0.0      0.0      0.0
15583   360000.0  27.0        0.0  ...      0.0      0.0      0.0
17144    20000.0  27.0     5888.0  ...      0.0      0.0      0.0
...          ...   ...        ...  ...      ...      ...      ...
4403     80000.0  31.0    81358.0  ...      0.0      0.0      0.0
18870   150000.0  44.0   168179.0  ...      0.0      0.0      0.0
14294   240000.0  25.0    24640.0  ...      0.0      0.0      0.0
19562   250000.0  45.0    98015.0  ...      0.0      0.0      0.0
19278    30000.0  28.0    29234.0  ...      0.0      0.0      0.0

[7201 rows x 91 columns], 1776     0
22839    0
4827     0
17937    0
2289     1
        ..
20299    0
7309     0
14968    1
4922     1
405      0
Name: default, Length: 16799, dtype: int64, 21592    0
11487    0
15178    0
15583    0
17144    0
        ..
4403     0
18870    1
14294    0
19562    0
19278    0
Name: default, Length: 7201, dtype: int64, 6695, Pipeline(memory=None,
         steps=[('dtypes',
                 DataTypes_Auto_infer(categorical_features=[],
                                      display_types=True, features_todrop=[],
                                      ml_usecase='classification',
                                      numerical_features=[], target='default',
                                      time_features=[])),
                ('imputer',
                 Simple_Imputer(categorical_strategy='not_available',
                                numeric_strategy='mean',
                                target_variable=None)),
                ('new_levels1',
                 New_Catagorical_Le...
                ('group', Empty()), ('nonliner', Empty()), ('scaling', Empty()),
                ('P_transform', Empty()), ('pt_target', Empty()),
                ('binn', Empty()), ('rem_outliers', Empty()),
                ('cluster_all', Empty()), ('dummy', Dummify(target='default')),
                ('fix_perfect', Empty()), ('clean_names', Clean_Colum_Names()),
                ('feature_select', Empty()), ('fix_multi', Empty()),
                ('dfs', Empty()), ('pca', Empty())],
         verbose=False), [('Classification Setup Config',
                        Description        Value
  0                      session_id         6695
  1                     Target Type       Binary
  2                   Label Encoded         None
  3                   Original Data  (24000, 24)
  4                 Missing Values         False
  5               Numeric Features            14
  6           Categorical Features             9
  7               Ordinal Features         False
  8      High Cardinality Features         False
  9        High Cardinality Method          None
  10                   Sampled Data  (24000, 24)
  11          Transformed Train Set  (16799, 91)
  12           Transformed Test Set   (7201, 91)
  13               Numeric Imputer          mean
  14           Categorical Imputer      constant
  15                     Normalize         False
  16              Normalize Method          None
  17                Transformation         False
  18         Transformation Method          None
  19                           PCA         False
  20                    PCA Method          None
  21                PCA Components          None
  22           Ignore Low Variance         False
  23           Combine Rare Levels         False
  24          Rare Level Threshold          None
  25               Numeric Binning         False
  26               Remove Outliers         False
  27            Outliers Threshold          None
  28      Remove Multicollinearity         False
  29   Multicollinearity Threshold          None
  30                    Clustering         False
  31          Clustering Iteration          None
  32           Polynomial Features         False
  33             Polynomial Degree          None
  34          Trignometry Features         False
  35          Polynomial Threshold          None
  36                Group Features         False
  37             Feature Selection         False
  38  Features Selection Threshold          None
  39           Feature Interaction         False
  40                 Feature Ratio         False
  41         Interaction Threshold          None
  42                  Fix Imbalance        False
  43           Fix Imbalance Method        SMOTE),
 ('X_training Set',
         LIMIT_BAL   AGE  BILL_AMT1  ...  PAY_6_6  PAY_6_7  PAY_6_8
  1776     50000.0  41.0    21910.0  ...      0.0      0.0      0.0
  22839    10000.0  33.0     7486.0  ...      0.0      0.0      0.0
  4827    500000.0  50.0   265803.0  ...      0.0      0.0      0.0
  17937    70000.0  48.0     3838.0  ...      0.0      0.0      0.0
  2289     50000.0  27.0    43333.0  ...      0.0      0.0      0.0
  ...          ...   ...        ...  ...      ...      ...      ...
  20299    20000.0  28.0     6064.0  ...      0.0      0.0      0.0
  7309    110000.0  23.0    34486.0  ...      0.0      0.0      0.0
  14968    50000.0  25.0     9587.0  ...      0.0      0.0      0.0
  4922     70000.0  27.0    76062.0  ...      0.0      0.0      0.0
  405      30000.0  53.0        0.0  ...      0.0      0.0      0.0
  
  [16799 rows x 91 columns]),
 ('y_training Set', 1776     0
  22839    0
  4827     0
  17937    0
  2289     1
          ..
  20299    0
  7309     0
  14968    1
  4922     1
  405      0
  Name: default, Length: 16799, dtype: int64),
 ('X_test Set',
         LIMIT_BAL   AGE  BILL_AMT1  ...  PAY_6_6  PAY_6_7  PAY_6_8
  21592   100000.0  32.0    99573.0  ...      0.0      0.0      0.0
  11487    20000.0  25.0     3556.0  ...      0.0      0.0      0.0
  15178   330000.0  28.0   221793.0  ...      0.0      0.0      0.0
  15583   360000.0  27.0        0.0  ...      0.0      0.0      0.0
  17144    20000.0  27.0     5888.0  ...      0.0      0.0      0.0
  ...          ...   ...        ...  ...      ...      ...      ...
  4403     80000.0  31.0    81358.0  ...      0.0      0.0      0.0
  18870   150000.0  44.0   168179.0  ...      0.0      0.0      0.0
  14294   240000.0  25.0    24640.0  ...      0.0      0.0      0.0
  19562   250000.0  45.0    98015.0  ...      0.0      0.0      0.0
  19278    30000.0  28.0    29234.0  ...      0.0      0.0      0.0
  
  [7201 rows x 91 columns]),
 ('y_test Set', 21592    0
  11487    0
  15178    0
  15583    0
  17144    0
          ..
  4403     0
  18870    1
  14294    0
  19562    0
  19278    0
  Name: default, Length: 7201, dtype: int64),
 ('Transformation Pipeline', Pipeline(memory=None,
           steps=[('dtypes',
                   DataTypes_Auto_infer(categorical_features=[],
                                        display_types=True, features_todrop=[],
                                        ml_usecase='classification',
                                        numerical_features=[], target='default',
                                        time_features=[])),
                  ('imputer',
                   Simple_Imputer(categorical_strategy='not_available',
                                  numeric_strategy='mean',
                                  target_variable=None)),
                  ('new_levels1',
                   New_Catagorical_Le...
                  ('group', Empty()), ('nonliner', Empty()), ('scaling', Empty()),
                  ('P_transform', Empty()), ('pt_target', Empty()),
                  ('binn', Empty()), ('rem_outliers', Empty()),
                  ('cluster_all', Empty()), ('dummy', Dummify(target='default')),
                  ('fix_perfect', Empty()), ('clean_names', Clean_Colum_Names()),
                  ('feature_select', Empty()), ('fix_multi', Empty()),
                  ('dfs', Empty()), ('pca', Empty())],
           verbose=False))], False, -1, True, [], [], [], 'no_logging', False, False, '670a', False, None, <Logger logs (DEBUG)>)

学習

事前準備が終わったので学習を行います。今回はlightgbmを指定して学習を行います。

model = create_model("lightgbm")

こんな感じで学習の結果が表示されます。

スクリーンショット 2020-09-19 10.52.33

学習で使用できるロジックは以下になります。(pycaretのソースコードより引用)

ID          Name      
   --------    ----------     
   'lr'        Logistic Regression             
   'knn'       K Nearest Neighbour            
   'nb'        Naive Bayes             
   'dt'        Decision Tree Classifier                   
   'svm'       SVM - Linear Kernel	            
   'rbfsvm'    SVM - Radial Kernel               
   'gpc'       Gaussian Process Classifier                  
   'mlp'       Multi Level Perceptron                  
   'ridge'     Ridge Classifier                
   'rf'        Random Forest Classifier                   
   'qda'       Quadratic Discriminant Analysis                  
   'ada'       Ada Boost Classifier                 
   'gbc'       Gradient Boosting Classifier                  
   'lda'       Linear Discriminant Analysis                  
   'et'        Extra Trees Classifier                   
   'xgboost'   Extreme Gradient Boosting              
   'lightgbm'  Light Gradient Boosting              
   'catboost'  CatBoost Classifier

推論

pycaretでは推論も簡単に行うことができます。

その際に出力されるAUCですがROC曲線におけるAUCになります。

インバランスデータだと正しく評価できないので注意が必要です。

predict_model(model)

上記のコードはscikit-learnの「roc_auc_score」を実行しており、処理で例外が発生した場合は無条件でAUC=0になります。

ちなみに、PR曲線でのAUCを算出したい場合は以下のコードで実行できます。

predict = model.predict(exp1[3])
from sklearn import metrics
precision, recall, thresholds = metrics.precision_recall_curve(exp1[5], predict)
auc = metrics.auc(recall, precision)
print(auc)

# ↑の実行結果(精度低い・・・)
0.5825109588252797


この記事が気に入ったらサポートをしてみませんか?