埋め込み層の描画
fastaiの表データを扱うtabularモデルは強力な予測手法だ.特にカテゴリーデータを埋め込み層で処理することが重要だが,結果をもとに洞察も得られる.
fast.aiの講義では協調フィルタリングの結果を主成分分析して2次元に落とし込んで可視化していたが,ここでは表データに対して適用してみる.
例として使用するのは以前紹介したadultデータだ.学習器のsummaryを見てみると先頭は埋め込み層になっていることが確認できる.
Path: /tmp/.fastai/data/adult_sample, model=TabularModel(
(embeds): ModuleList(
(0): Embedding(10, 6)
(1): Embedding(17, 8)
(2): Embedding(8, 5)
(3): Embedding(16, 8)
(4): Embedding(7, 5)
(5): Embedding(6, 4)
(6): Embedding(3, 3)
)
data.cat_namesでカテゴリーデータの名前を確認する.
['workclass',
'education',
'marital-status',
'occupation',
'relationship',
'race',
'education-num_na']
educationデータにNaNが入っているので,自動的に最後の行(educationがNaNならTrueの列)が追加されていることが確認できる.
各カテゴリーに含まれているクラス数を
[(len(data.classes[n]),n) for n in data.cat_names]
で確認しておく.
[(10, 'workclass'),
(17, 'education'),
(8, 'marital-status'),
(16, 'occupation'),
(7, 'relationship'),
(6, 'race'),
(3, 'education-num_na')]
例としてeducationを可視化してみる.まずは主成分分析(pca)を行う.
X,Y = learn.model.embeds[1].weight.pca(2).t()
次に,education列に入っているクラス名を所得する.
text = data.classes["education"]
最後に描画する.
plt.figure(figsize=(15,15))
for i, x, y in zip(text, X, Y):
plt.text(x,y,i, fontsize=15)
plt.show()
とまあこれだけだ.冒頭に示したような図ができる.どうやら博士と教授だけが仲間はずれのようだ:-)
この記事が気に入ったらサポートをしてみませんか?