見出し画像

第35話 オリジナルAI失敗の原因を考察してみる

前回オリジナルAIを作って見事失敗に散り、学習に使用したデータセットに問題があるのでは?と原因を推測したところで終わっていました。
そこで今回はこの推測を深掘りしてみようと思います。

ちなみに前回が気になったらコチラからどうぞ。

お手本のデータセットはどうなっている?

オリジナルAIは、scikit-learnの手書き数字のデータセット(digits_dataset)に対して効果的な学習ができた畳み込みニューラルネットワーク(CNN)をお手本にして作成しました。ではこの手書き数字のデータセットの中身がどうなっているか見ていきます。

手書き数字データセットの中身
このデータセットには0~9の手書きの数字が入っているのですが、実際どんなデータが入っているか確認していくことにしました。そのためのコードはこんな感じです。

import matplotlib.pyplot as plt
import pandas as pd
from sklearn import datasets

# scikit-learnの手書き文字データセットの読み込み
digits_data = datasets.load_digits()

# 試しにデータセットの1番目を表示
plt.matshow(digits_data.images[0]) 
plt.show()

# 数字の分布を確認する
digits_data_target = digits_data.target
plt.hist(digits_data_target)
plt.show()

# データフレーム化して要約統計量を表示
df = pd.DataFrame(digits_data_target, columns=["num"])
print(df.describe())

これを実行すると3つの画像が表示されます。

1つ目の画像
データセットの1番目の数字が実際どんなものか表示させました。

スクリーンショット 2020-08-01 14.25.45

これは0なのですが・・・お世辞にもきれいな字とは言えないですね ^^;
これを学習させてるわけですが、AIに感情があったら「もっと綺麗な字にしてくれ」と文句を言われそうですw

2つ目の画像
0〜9の数字がデータセットの中にどのように分布しているかを確認しました。

スクリーンショット 2020-08-01 14.25.51

どうやらそれぞれの数字で175個前後のデータがあるようです。結構な数のデータ量ですね。

それにしてもなぜ175個ずつにしなかったのでしょうね。揃ってた方が綺麗でカッコイイというか、この微妙なバラツキは気持ち悪いというか・・・。
私としては「掃除機で掃除したのに部屋の角はかけてません」と言われているような感覚で、許せないですねぇ。

3つ目の画像
最後におまけとして要約統計量(データセットの特徴を表す代表的な統計学上の値)を表示させました。上から順に説明します。

スクリーンショット 2020-08-01 14.25.56

count:データ数
mean:平均値
std:標準偏差(データの広がり具合)
min:最小値
25%:第一四分位数(データを小さい順に並べて初めから数えて25%の位置にある数)
50%:第二四分位数(データを小さい順に並べて初めから数えて50%の位置にある数=中央値)
75%:第三四分位数(データを小さい順に並べて初めから数えて75%の位置にある数)
max:最大値

2つ目画像でデータセットの分布をみているのですが、定量的に示すとどうなるのかも確認してみました。

オリジナルAIのデータセットはどうなっている?

さていよいよオリジナルAIで自分が作成したデータセットがどうなっているか確認していきます。

オリジナルAIでは1枚の画像を分割して、分割された画像の中に何個細胞が入っているかを数えました。

画像のイメージ

名称未設定

分割後の画像 (一例)

画像6


そして数えた細胞をcsvファイルに次のように記録し、pythonで取り込んでいました。

細胞数を記録したcsvファイル (抜粋)

スクリーンショット 2020-08-01 15.10.40


以上の「分割した画像」+「細胞数を記したcsvファイル」がオリジナルAIのデータセットになります。

オリジナルAI用のデータセットの分布をみてみる
このデータセットの細胞数の分布がどうなっているか、見本と同じようにみていくことにしましょう。csvファイルをpythonで取り込んでヒストグラム表示させるのと、要約統計量を表示させます。

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


# オリジナルAIのデータセットについて
# 正解データをcsvから取り込む
data_list = np.genfromtxt("dataset_csv.csv", delimiter=',')
correct = data_list[1:,2]
n_data = len(correct)
correct_data = np.reshape(correct, (n_data, 1))

# ヒストグラムを表示
plt.hist(correct_data)
plt.show()

# データフレーム化して要約統計量を表示
df = pd.DataFrame(correct_data, columns=["count"])
print(df.describe())

それでは実行結果です。じゃん!

スクリーンショット 2020-08-02 4.18.01

俯瞰してみると山型になっています。分布の中央付近(細胞数5,8,9,10個)のデータは多いですが、分布の両端(0,11個以上)は少ないですね。つまりデータに偏りがあります。
さきほどの手書き数字のデータセットと横並びで比較してみましょう。

名称未設定

データの分布が全く違いますね!
手書きの方は学習がうまくいって、オリジナルの方はうまくいかない。これはデータの偏りが原因かも、言い換えるとデータが満遍なくあれば学習がうまくいくかもしれません。

それを裏付けるかわかりませんが、オリジナルAIの予測精度をみるとこの主張が合っているかもと思わされます。

スクリーンショット 2020-08-02 4.57.27

この図はオリジナルAIの予測精度をプロットしたものです。横軸は正解値、縦軸はAIの予測結果で、斜めの直線にプロットが乗れば正しい予測ができていることを意味します。
プロットが斜めの直線からかなりばらついているので、AIの予測は全くうまくいっていません。しかしそれでもデータ数が多かった中央値(6前後)では直線に乗っているデータがいくつかあります。対してデータ数が少なかった端っこの値は、斜め直線にかすりもしません。
ということはつまり、端っこの値のデータ数が増えてくれば予測精度が高まる可能性があります!オリジナルAIが失敗して落ち込んでましたが、希望の光が見えてきたぞー!

まとめ

お手本とオリジナルのデータセットを比較することにより、データの偏りが学習に影響を与えている可能性があることがわかりました。
なので次なるステップとして、データセットをデータの偏りがないように作り直すことをやってみようと思います。

長い時間かけて勉強して実装できたAIが失敗したときにはガクっと落ち込みましたが、失敗から目を逸らさず向き合うことで新たな光が見えてきました。また失敗するかもしれませんが、いまは新たな光に向かって突き進んで行こうと思います。
こうやって人って成長していくんだなぁ。←若干326を意識

それではまた(^_^)ノシ

参考サイト

四分位数について
https://bellcurve.jp/statistics/course/19277.html

よろしければサポートお願いします!いただいたサポートは書籍代等に活用いたします!