見出し画像

埋め込み層の描画

fastaiの表データを扱うtabularモデルは強力な予測手法だ.特にカテゴリーデータを埋め込み層で処理することが重要だが,結果をもとに洞察も得られる.

fast.aiの講義では協調フィルタリングの結果を主成分分析して2次元に落とし込んで可視化していたが,ここでは表データに対して適用してみる.

例として使用するのは以前紹介したadultデータだ.学習器のsummaryを見てみると先頭は埋め込み層になっていることが確認できる.

Path: /tmp/.fastai/data/adult_sample, model=TabularModel(
(embeds): ModuleList(
(0): Embedding(10, 6)
(1): Embedding(17, 8)
(2): Embedding(8, 5)
(3): Embedding(16, 8)
(4): Embedding(7, 5)
(5): Embedding(6, 4)
(6): Embedding(3, 3)
)

data.cat_namesでカテゴリーデータの名前を確認する.

['workclass',
'education',
'marital-status',
'occupation',
'relationship',
'race',
'education-num_na']

educationデータにNaNが入っているので,自動的に最後の行(educationがNaNならTrueの列)が追加されていることが確認できる.

各カテゴリーに含まれているクラス数を
     [(len(data.classes[n]),n) for n in data.cat_names]
で確認しておく.

[(10, 'workclass'),
(17, 'education'),
(8, 'marital-status'),
(16, 'occupation'),
(7, 'relationship'),
(6, 'race'),
(3, 'education-num_na')]

例としてeducationを可視化してみる.まずは主成分分析(pca)を行う.

X,Y = learn.model.embeds[1].weight.pca(2).t()

次に,education列に入っているクラス名を所得する.

text = data.classes["education"]

最後に描画する.

plt.figure(figsize=(15,15))
for i, x, y in zip(text, X, Y):
       plt.text(x,y,i, fontsize=15)
plt.show()

とまあこれだけだ.冒頭に示したような図ができる.どうやら博士と教授だけが仲間はずれのようだ:-) 


この記事が気に入ったらサポートをしてみませんか?