見出し画像

ニューラルネットワークの連鎖律を解く(偏微分は横に置いておく)

これはなに?

誤差逆伝播法のベースになっている連鎖律を解くと、ニューラルネットワークの理解が捗ります。そこで、連鎖律を解くための教材として「ネットワーク図」と「連鎖律の図」、そしていい感じの「お題」を用意しました。まずはお題から。

【お題】重み$${w_{32}^{(2)}}$$を更新したい。どうやるの?

このお題はシンプルですがニューラルネットワークの連鎖律を解く上で必要にして十分な要素が詰まっています。このお題では、ネットワーク全体の中にあるたった1個の重みの更新について考えますが、$${w_{ij}^{(l)}}$$のように英字の添字を使って考えるとピンと来ないので添字を展開して具体的に考えることにします。具体的な重み$${w_{32}^{(2)}}$$の位置についてはネットワーク図に示します。

この後、偏微分記号$${(∂)}$$のついた数式が登場しますが、一切の計算を行わずにパズルアドベンチャーの感覚で連鎖律を解いていきます。


表記体系

とりあえずは、3つの添字$${i,\,\,j,\,\,l}$$を押さえておくと良いです。

テンソルの表記

$$
\def\arraystretch{1.5}
\begin{array}{l:ll}
テンソル&表記方法&表記例 \\[1pt]
\hline \\[-15pt]
スカラー&小文字 & w_{ij}^{(l)}\,\,\,x_{i}^{(l)}\,\,\,δ_{j}^{(l)} \\
ベクトル&\overrightarrow{小文字} &\vec{b}^{(l)}\,\,\,\vec{x}^{(l)}\,\,\,\vec{δ}^{(l)} \\
行列&大文字 & W^{(l)}\,\,X^{(l)}\,\,Δ^{(l)}
 \end{array}                                                                                                             \\
$$

記号の定義

$$
\def\arraystretch{1.3}
\begin{array}{c:ll}
記号 & 読み方 & 意味 &\\[1pt]
\hline \\[-12pt]
η & イータ & 学習率 \\
δ & デルタ & 誤差(スカラー) \\
Δ & デルタ & 誤差(行列) \\
%∇ & ナブラ & 勾配(ベクトル) \\
{grad} & グラディエント & 勾配(ベクトル、行列) \\[2pt]
\hdashline \\[-12pt]
Σ & シグマ & 総和 \\
∂ & ラウンドディー & 偏微分 \\
⋅ & ドット & ドット積(内積)\\
⊙ & 丸ドット & アダマール積(外積)\\
\mathrm{T} & トランスポーズ & 転置 \\
{:=} & コロンイコール & 代入
 \end{array}                                                            \\
$$

添字の定義

$$
\def\arraystretch{1.5}
\begin{array}{c:ll}
添字&意味&表記例 \\
\hline \\[-15pt]
i & 入力番号 & w_{ij}^{(l)}\,\,\,x_{i}^{(l)} \\
j & ノード番号 & w_{ij}^{(l)}\,\,\,b_{j}^{(l)}\,\,\,u_{j}^{(l)}\,\,\,y_{j}^{(l)}\,\,\,t_{j}^{(l)}\,\,\,δ_{j}^{(l)} \\
l & レイヤ番号 & w_{ij}^{(l)} \\
r & 行列の行番号 & \sum_{r=1}^{m} \\[2pt]
\hdashline \\[-15pt]
q & 入力数 & (\,i=1,2,\ldots,\,q\,) \\
p & ノード数 & (\,j=1,2,\ldots,\,p\,) \\
k & レイヤ数 & l=k,\,\,l\lt k \\
m & バッチサイズ & \sum_{r=1}^{m} \\
\end{array}                                                                                \\
$$

表記例

$$
\def\arraystretch{1.6}
\begin{array}{l:cc}
&スカラー&ベクトル&行列 \\[2pt]
\hline \\[-17pt]
特徴量 & {x_{i}} & {\vec{x}} & {X} \\
予測値 & {y_{j}} & {\vec{y}} & {Y} \\
正解値 & {t_{j}} & {\vec{t}} & {T} \\[2pt]
\hdashline \\[-16pt]
入力値 & {x_{i}^{(l)}} & {\vec{x}^{(l)}} & {X^{(l)}} \\
入力の勾配 & {{∂L}/{x_{i}^{(l)}}} & ― & {X_{grad}^{(l)}} \\
重み & {w_{ij}^{(l)}} & {\vec{w}_{j}^{(l)}} & {W^{(l)}} \\
重みの勾配 & {{∂L}/{∂w_{ij}^{(l)}}} & ― & {W_{grad}^{(l)}} \\
バイアス & {b_{j}^{(l)}} & {\vec{b}^{(l)}} & ― \\
バイアスの勾配 & {{∂L}/{∂b_{j}^{(l)}}} & {\vec{b}_{grad}^{(l)}} & ― \\
出力値\,\,u & {u_{j}^{(l)}} & {\vec{u}^{(l)}} & {U^{(l)}} \\
出力値\,\,y & {y_{j}^{(l)}} & {\vec{y}^{(l)}} & {Y^{(l)}} \\
誤差 & {δ_{j}^{(l)}} & {\vec{δ}^{(l)}} & {Δ^{(l)}} \\
\end{array}                                                                                \\
$$

※ 勾配を含めて、$${1×1}$$の形状をしているものはすべてスカラーに分類しています。

図を用意する

ネットワーク図

入力層、Layer 1、Layer 2、Layer 3のそれぞれのノード数を 5, 4, 3, 2 としています。お題の重み$${w_{32}^{(2)}}$$は図の真ん中辺に🅦で示してあります。

ネットワークの構成要素を参照しやすくするために行列とベクトルで表現しておきます。

連鎖律の図

付録「データフローダイアグラム」の順伝播部分を利用して連鎖律の図を作成しました。

連鎖律を解く

お題を再確認する

準備ができたのでもう一度、お題を確認しておきましょう。

【お題】重み$${w_{32}^{(2)}}$$を更新したい。どうやるの?

まずは連鎖律を使って重みの勾配$${∂L/∂w_{32}^{(2)}}$$を求めます。勾配さえ求めれば、あとは勾配降下法の式を使って重みを更新できます。

偏微分項を作る

連鎖律の図にあるように8つの偏微分項(🅻 🅼 🅰 🅱 🅲 🅳 🅴 🅵)を作るところから始めます。例として偏微分項$${∂u_{2}^{(2)}/∂w_{32}^{(2)}}$$(🅻)を作ってみます。作業に当たってはゴールである重みの勾配$${∂L/∂w_{32}^{(2)}}$$(🆆)を意識しておくと良いです。

偏微分項を作るには連鎖律の図とネットワーク図を突き合わせます。先に連鎖律の図を使ってザックリと偏微分項$${{∂u}/{∂w}}$$を作ります。これを作るには図の 🅻 の上にあるアンダーブレス($${\underbrace{   }}$$)の左側を分母に、そして右側を分子に見立てます。図からは分母として$${{∂x},\,\,{∂w},\,\,{∂b}}$$の3つのうちのいずれかを選べることが分かります。今回はこの中から$${∂w}$$を選んだわけです。また、図からは分子にできるのは$${∂u}$$だけであることも分かります。

続いて作りたての偏微分項$${{∂u}/{∂w}}$$をネットワーク図の$${w_{32}^{(2)}}$$(🅦)と突き合わせます。そうすると偏微分項$${∂u_{2}^{(2)}/∂w_{32}^{(2)}}$$(🅻)を得ることができます。他の7つの偏微分項についてもネットワーク図の逆伝播を遡りながら同じ要領で作っていきます。

今回作った8つの偏微分項を$${式(1.1)〜式(1.8)}$$に示します。これとは別に機能単位で作った偏微分項を$${式(1.9)〜式(1.13)}$$に示します。 後者についてもアンダーブレス($${\underbrace{   }}$$)を頼りに作ることができます。また、連鎖律の図には載せていませんがバイアスの更新に関連する偏微分項を$${式(1.14)、式(1.15)}$$に示します。バイアスの更新についてはしばらく忘れて、まずは重みの更新に意識を集中します。
【連鎖律の図】🅻 🅼 🅰 🅱 🅲 🅳 🅴 🅵,🅶 🅷 🅹 🅺 🆆

$$
\def\arraystretch{3.2}
\begin{array}{lr}
\cfrac{∂u_{2}^{(2)}}{∂w_{32}^{(2)}} &(1.1)\,🅻 \\
\cfrac{∂y_{2}^{(2)}}{∂u_{2}^{(2)}} &(1.2)\,🅼 \\
\cfrac{∂u_{1}^{(3)}}{∂x_{2}^{(3)}} &(1.3)\,🅰 \\
\cfrac{∂u_{2}^{(3)}}{∂x_{2}^{(3)}} &(1.4)\,🅱 \\
\cfrac{∂y_{1}^{(3)}}{∂u_{1}^{(3)}} &(1.5)\,🅲 \\
\end{array}   
\begin{array}{lr}
\cfrac{∂y_{2}^{(3)}}{∂u_{2}^{(3)}} &(1.6)\,🅳 \\
\cfrac{∂L}{∂y_{1}^{(3)}} &(1.7)\,🅴 \\
\cfrac{∂L}{∂y_{2}^{(3)}} &(1.8)\,🅵 \\
\cfrac{∂L}{∂u_{1}^{(3)}} &(1.9)\,🅶 \\
\cfrac{∂L}{∂u_{2}^{(3)}} &(1.10)\,🅷 \\
\end{array}   
\begin{array}{lr}
\cfrac{∂L}{∂x_{2}^{(3)}} &(1.11)\,🅹 \\
\cfrac{∂L}{∂u_{2}^{(2)}} &(1.12)\,🅺 \\
\cfrac{∂L}{∂w_{32}^{(2)}} &(1.13)\,🆆 \\
\cfrac{∂u_{2}^{(2)}}{∂b_{2}^{(2)}} &(1.14)\, \\
\cfrac{∂L}{∂b_{2}^{(2)}} &(1.15)\, \\
\end{array}
$$

連鎖律を使って式を立てる

連鎖律の図には載せていませんが、重みの勾配$${∂L/∂w_{32}^{(2)}}$$を連鎖律を使って表現するには$${式(1.13)、式(1.1)〜式(1.8)}$$を使います。
🆆 =(🅴🅲🅰 + 🅵🅳🅱)🅼🅻

$$
\tag{2.1}\begin{aligned}
{\cfrac{∂L}{∂w_{32}^{(2)}}}&=(\,
{\cfrac{∂L}{∂y_{1}^{(3)}}}\,
{\cfrac{∂y_{1}^{(3)}}{∂u_{1}^{(3)}}}\,
{\cfrac{∂u_{1}^{(3)}}{∂x_{2}^{(3)}}}+
{\cfrac{∂L}{∂y_{2}^{(3)}}}\,
{\cfrac{∂y_{2}^{(3)}}{∂u_{2}^{(3)}}}\,
{\cfrac{∂u_{2}^{(3)}}{∂x_{2}^{(3)}}}\,)\,
{\cfrac{∂y_{2}^{(2)}}{∂u_{2}^{(2)}}}\,
{\cfrac{∂u_{2}^{(2)}}{∂w_{32}^{(2)}}} \\[1pt]
\end{aligned}
$$

この式の乗算部分は連鎖律(chain rule)によるもの、そして括弧内の加算部分は多変数連鎖律(multi-variable chain rule)によるものです。この式は最終的には次のように簡略化されます。
【連鎖律の図】🆆=🅺🅻

$$
\tag{2.2}\cfrac{∂L}{∂w_{32}^{(2)}}={x_{3}^{(2)}δ_{2}^{(2)}}
$$

今回はお題へのチャレンジとは別に、実装に使える数式も作ってみたいと思います。そのため、ここから先は式$${(2.1)}$$から離れて連鎖律の図にあるように機能単位で連鎖律を解いていきます。機能単位は次の4つです。

$$
\def\arraystretch{1.5}
\begin{array}{c:lll}
\# & 機能単位 & & 位置 \\
\hline \\[-15pt]
1. & 出力層の誤差 &{δ_{1}^{(3)},\,\,δ_{2}^{(3)}} & \text{Layer 3} \\
2. & 入力の勾配 &{∂L/∂x_{2}^{(3)}} & \text{Layer 3} \\
3. & 中間層の誤差 &{δ_{2}^{(2)}} & \text{Layer 2, Node 2} \\
4. & 重みの勾配 &{∂L/∂w_{32}^{(2)}} & \text{Layer 2, Node 2}
\end{array}                                                                                \\
$$

偏微分のスニペットを定義する

連鎖律を解くにあたって偏微分項の展開も一緒くたにしてしまうとこれがノイズとなって焦点がぼやけてしまいます。そこで偏微分項の展開については再利用可能なスニペットを定義して利用することにします。

$$
\begin{array}{lr}
\begin{aligned}
\cfrac{∂u_{j}^{(l)}}{∂x_{i}^{(l)}}&={w_{ij}^{(l)}} \\[22pt]
\cfrac{∂u_{j}^{(l)}}{∂w_{ij}^{(l)}}&={x_{i}^{(l)}} \\[22pt]
\cfrac{∂u_{j}^{(l)}}{∂b_{j}^{(l)}}&=1 \\[22pt]
\cfrac{∂y_{j}^{(l)}}{∂u_{j}^{(l)}}&=f'^{(l)}(u_{j}^{(l)})
\end{aligned}&
\begin{aligned}
  (3.1) \\[32pt]
  (3.2) \\[32pt]
  (3.3) \\[32pt]
  (3.4)
\end{aligned}
\end{array}
$$

#1. 出力層の誤差(Layer 3)

ネットワーク図を見るとLayer 3からLayer 2, Node 2に向かっている信号は2本あります(🅙)。この信号を上流まで遡ったところにある誤差$${δ_{1}^{(3)}}$$(🅖)と誤差$${δ_{2}^{(3)}}$$(🅗)を求めるところから始めます。

出力層の誤差$${δ_{1}^{(3)}}$$を連鎖律を使って表現するには$${式(1.9)、式(1.7)、式(1.5)}$$を使います。
【連鎖律の図】🅶 = 🅴🅲

$$
\tag{4.1}
δ_{1}^{(3)}=\cfrac{∂L}{∂u_{1}^{(3)}}=\cfrac{∂L}{∂y_{1}^{(3)}}\cfrac{∂y_{1}^{(3)}}{∂u_{1}^{(3)}}
$$

出力層の誤差$${δ_{2}^{(3)}}$$を連鎖律を使って表現するには$${式(1.10)、式(1.8)、式(1.6)}$$を使います。
【連鎖律の図】🅷 = 🅵🅳

$$
\tag{4.2}
δ_{2}^{(3)}=\cfrac{∂L}{∂u_{2}^{(3)}}=\cfrac{∂L}{∂y_{2}^{(3)}}\cfrac{∂y_{2}^{(3)}}{∂u_{2}^{(3)}}
$$

$${式(4.1)}$$と$${式(4.2)}$$は、損失関数の偏微分と活性化関数の偏微分を掛けた形になっています。いきなり厄介なことになりましたが、そこはうまいこと出来ていて、少なくとも次の条件のいずれかを満足していれば$${式(4.3)}$$と$${式(4.4)}$$を得ることが出来ます。

$$
\def\arraystretch{1.5}
\begin{array}{l:ll}
&出力層の活性化関数&損失関数 \\
\hline \\[-15pt]
回帰の場合 & 恒等関数 & 二乗和誤差関数 \\
分類の場合 & ソフトマックス関数 & 交差エントロピー誤差関数 \\
\end{array}                    \\
$$

$$
\tag{4.3}δ_{1}^{(3)}=\cfrac{∂L}{∂u_{1}^{(3)}}=y_{1}-t_{1}
$$

$$
\tag{4.4}δ_{2}^{(3)}=\cfrac{∂L}{∂u_{2}^{(3)}}=y_{2}-t_{2}
$$

#2. 入力の勾配(Layer 3)

ネットワーク図を見るとLayer 3からLayer 2, Node 2に向かっている信号はLayer 2, Node 2に入る前に合流しています(🅙)。図では2本しか合流していませんが、もしLayer 3のノード数が10個であれば10本の信号が合流します。入力の勾配を求めるには、これらの信号を多変数連鎖律に従ってすべて足し合わせます。

入力の勾配$${{∂L}/{∂x_{2}^{(3)}}}$$を連鎖律(と多変数連鎖律)を使って表現するには$${式(1.11)、式(1.9)、式(1.3)、式(1.10)、式(1.4)}$$を使います。
【連鎖律の図】🅹 = 🅶🅰 + 🅷🅱

$$
\tag{4.5}
\cfrac{∂L}{∂x_{2}^{(3)}}
=\cfrac{∂L}{∂u_{1}^{(3)}}\cfrac{∂u_{1}^{(3)}}{∂x_{2}^{(3)}}
+\cfrac{∂L}{∂u_{2}^{(3)}}\cfrac{∂u_{2}^{(3)}}{∂x_{2}^{(3)}}
$$

上の2つの偏微分項$${{∂u_{1}^{(3)}}/{∂x_{2}^{(3)}}}$$と$${{∂u_{2}^{(3)}}/{∂x_{2}^{(3)}}}$$を$${式(3.1)}$$に当てはめます。
【連鎖律の図】🅰,🅱

$$
\tag{4.6}
\cfrac{∂L}{∂x_{2}^{(3)}}
=\cfrac{∂L}{∂u_{1}^{(3)}}\,\,w_{21}^{(3)}
+\cfrac{∂L}{∂u_{2}^{(3)}}\,\,w_{22}^{(3)}
$$

上の2つの偏微分項$${{∂L}/{∂u_{1}^{(3)}}}$$と$${{∂L}/{∂u_{2}^{(3)}}}$$のそれぞれに「#1. 出力層の誤差」で求めた$${式(4.3)}$$と$${式(4.4)}$$を当てはめます。
【連鎖律の図】🅹 = 🅶🅰 + 🅷🅱

$$
\tag{4.7}
\cfrac{∂L}{∂x_{2}^{(3)}} 
=δ_{1}^{(3)}w_{21}^{(3)}+δ_{2}^{(3)}w_{22}^{(3)}
=\displaystyle\sum_{j=1}^{2}δ_{j}^{(3)}w_{2j}^{(3)}
$$

以上で、入力の勾配$${{∂L}/{∂x_{2}^{(3)}}}$$を求めることができました。

#3. 中間層の誤差(Layer 2, Node 2)

中間層の誤差$${δ_{2}^{(2)}}$$を連鎖律を使って表現するには$${式(1.12)、式(1.11)、式(1.2)}$$を使います。
【連鎖律の図】🅺 = 🅹🅼

$$
\tag{4.8}
δ_{2}^{(2)}=\cfrac{∂L}{∂u_{2}^{(2)}}=\cfrac{∂L}{∂x_{2}^{(3)}}\cfrac{∂y_{2}^{(2)}}{∂u_{2}^{(2)}}
$$

上の偏微分項$${{∂y_{2}^{(2)}}/{∂u_{2}^{(2)}}}$$を$${式(3.4)}$$に当てはめます。
【連鎖律の図】🅼

$$
\tag{4.9}
δ_{2}^{(2)}=\cfrac{∂L}{∂u_{2}^{(2)}}=\cfrac{∂L}{∂x_{2}^{(3)}}\,\,f'^{(2)}(u_{2}^{(2)})
$$

上の偏微分項$${{∂L}/{∂x_{2}^{(3)}}}$$に「#2. 入力の勾配」で求めた$${式(4.7)}$$を当てはめます。
【連鎖律の図】🅺 = 🅹🅼

$$
\tag{4.10}
δ_{2}^{(2)}=\cfrac{∂L}{∂u_{2}^{(2)}}=\displaystyle\sum_{j=1}^{2}δ_{j}^{(3)}w_{2j}^{(3)}\,f'^{(2)}(u_{2}^{(2)})
$$

以上で、中間層の誤差$${δ_{2}^{(2)}}$$を求めることができました。

#4. 重みの勾配(Layer 2, Node 2)

重みの勾配$${{∂L}/{∂w_{32}^{(2)}}}$$を連鎖律を使って表現するには$${式(1.13)、式(1.12)、式(1.1)}$$を使います。
【連鎖律の図】🆆=🅺🅻

$$
\tag{4.11}
\cfrac{∂L}{∂w_{32}^{(2)}}
=\cfrac{∂L}{∂u_{2}^{(2)}}\cfrac{∂u_{2}^{(2)}}{∂w_{32}^{(2)}}
$$

上の偏微分項$${{∂u_{2}^{(2)}}/{∂w_{32}^{(2)}}}$$を$${式(3.2)}$$に当てはめます。
【連鎖律の図】🅻

$$
\tag{4.12}
\cfrac{∂L}{∂w_{32}^{(2)}}
=\cfrac{∂L}{∂u_{2}^{(2)}}\,\,{x_{3}^{(2)}}
$$

上の偏微分項$${{∂L}/{∂u_{2}^{(2)}}}$$に「#3. 中間層の誤差」で求めた$${式(4.10)}$$を当てはめます。
【連鎖律の図】🆆=🅺🅻

$$
\tag{4.13}
\cfrac{∂L}{∂w_{32}^{(2)}}=x_{3}^{(2)}δ_{2}^{(2)}
$$

以上で、重みの勾配$${{∂L}/{∂w_{32}^{(2)}}}$$を求めることができました。

重みを更新する

重みを更新する場合の勾配降下法の式は次のとおりです。

$$
\tag{4.14}
w_{ij}^{(l)}:=w_{ij}^{(l)}-η\cfrac{∂L}{∂w_{ij}^{(l)}}
$$

この式に重み$${w_{32}^{(2)}}$$と$${式(4.13)}$$の重みの勾配$${{∂L}/{∂w_{32}^{(2)}}}$$を当てはめます。

$$
\def\arraystretch{2}
\tag{4.15}
\begin{aligned}
w_{32}^{(2)}&:=w_{32}^{(2)}-η\cfrac{∂L}{∂w_{32}^{(2)}} \\
w_{32}^{(2)}&:=w_{32}^{(2)}-η\,\,x_{3}^{(2)}δ_{2}^{(2)}
\end{aligned}
$$

完成です!これで重み$${w_{32}^{(2)}}$$を更新できます。

バイアスを更新する

どうせなら1つだけ残っているバイアスの更新も片付けて連鎖律をコンプリートしましょう。

【お題】バイアス$${b_{2}^{(2)}}$$を更新したい。どうやるの?

まずは連鎖律を使ってバイアスの勾配$${∂L/b_{2}^{(2)}}$$を求めます。勾配さえ求めれば、あとは勾配降下法の式を使ってバイアスを更新できます。

さっそく、重みと同じ要領で偏微分項を作るところからはじめます(「偏微分項を作る」を参照)。$${式(1.14)}$$を再掲します。

$$
\tag{4.16}
\cfrac{∂u_{2}^{(2)}}{∂b_{2}^{(2)}}
$$

バイアスの勾配$${{∂L}/{∂b_{2}^{(2)}}}$$を連鎖律を使って表現するには$${式(1.15)、式(1.12)、式(1.14)}$$を使います。

$$
\tag{4.17}
\cfrac{∂L}{∂b_{2}^{(2)}}
=\cfrac{∂L}{∂u_{2}^{(2)}}\cfrac{∂u_{2}^{(2)}}{∂b_{2}^{(2)}}
$$

上の偏微分項$${{∂u_{2}^{(2)}}/{∂b_{2}^{(2)}}}$$を$${式(3.3)}$$に当てはめます。

$$
\tag{4.18}
\cfrac{∂L}{∂b_{2}^{(2)}}=\cfrac{∂L}{∂u_{2}^{(2)}}×1
$$

上の偏微分項$${{∂L}/{∂u_{2}^{(2)}}}$$に「#3. 中間層の誤差」で求めた$${式(4.10)}$$を当てはめます。

$$
\tag{4.19}
\cfrac{∂L}{∂b_{2}^{(2)}}=δ_{2}^{(2)}×1=δ_{2}^{(2)}
$$

バイアスを更新する場合の勾配降下法の式は次のとおりです。

$$
\tag{4.20}
b_{j}^{(l)}:=b_{j}^{(l)}-η\cfrac{∂L}{∂b_{j}^{(l)}}
$$

この式にバイアス$${b_{2}^{(2)}}$$と$${式(4.19)}$$のバイアスの勾配$${{∂L}/{∂b_{2}^{(2)}}}$$を当てはめます。

$$
\def\arraystretch{2}
\tag{4.21}
\begin{aligned}
b_{2}^{(2)}&:=b_{2}^{(2)}-η\cfrac{∂L}{∂b_{2}^{(2)}} \\
b_{2}^{(2)}&:=b_{2}^{(2)}-η\,δ_{2}^{(2)}
\end{aligned}
$$

これでバイアス$${b_{2}^{(2)}}$$を更新できます。

数式を一般化する

ここまでに5つの機能単位につて連鎖律を解きました(下図の3〜7)。また、勾配降下法の式を使ってパラメータを更新しました(下図の8と9)。これらは逆伝播についての式ですが、これに順伝播の式を2つ加えれば(下図の1と2)ニューラルネットワークの骨組みとなる数式が揃います。

$$
\def\arraystretch{1.5}
\begin{array}{c:lll}
\#&骨組みとなる数式&&式番号\\
\hline \\[-15pt]
1. &出力値\,\,u & &―\\
2. &出力値\,\,y &  &―\\
3. &出力層の誤差 &{δ_{1}^{(3)},\,\,δ_{2}^{(3)}} &(4.3),\,(4.4) \\
4. &中間層の誤差 &{δ_{2}^{(2)}} &(4.10) \\
5. &入力の勾配 &{∂L/∂x_{2}^{(3)}} &(4.7) \\
6. &重みの勾配 &{∂L/∂w_{32}^{(2)}} &(4.13) \\
7. &バイアスの勾配 &{∂L}/{∂b_{2}^{(2)}} &(4.19) \\
8. &重みの更新 & &(4.15) \\
9. &バイアスの更新 & &(4.21) \\
\end{array}                                                                                \\
$$

ここからは骨組みとなる数式を一般化して使い勝手を良くします。

出力値の計算

連鎖律には含まれませんが、順伝播の式も載せておきます。

出力値$${u}$$について、次の通りです。

$$
\def\arraystretch{1.5}
\tag{5.1}
\begin{aligned}
u_{j}^{(l)}&=\displaystyle\sum_{i=1}^{q}x_{i}^{(l)}w_{ij}^{(l)}+b_{j}^{(l)} \\
&=\vec{x}^{(l)}\cdot\vec{w}_{j}^{(l)}+b_{j}^{(l)}
\end{aligned}
$$

出力値$${y}$$について、次の通りです。

$$
\tag{5.2}
y_{j}^{(l)}=f^{(l)}(u_{j}^{(l)})
$$

誤差の計算

出力層の誤差について、$${式(4.3)}$$と$${式(4.4)}$$を一般化します。

$$
\tag{5.3}
δ_{j}^{(l)}=y_{j}-t_{j}
$$

$$
(l=k)
$$

中間層の誤差について、$${式(4.10)}$$を一般化します。

$$
\tag{5.4}
δ_{j}^{(l)}=
\Big[\displaystyle\sum_{n=1}^{p}δ_{n}w_{jn}\Big]^{(l+1)}
\Big[f'(u_{j})\Big]^{(l)}
$$

$$
(l \lt k)
$$

$${式(5.4)}$$では2つのレイヤの式が合流しますが、これによる添字の衝突などを解消するために2つの工夫をしています。添字$${p}$$はレイヤ$${^{(l+1)}}$$のノード数です。

  1. レイヤ$${^{(l+1)}}$$の入力番号は$${i}$$の代わりにレイヤ$${^{(l)}}$$のノード番号$${j}$$を使用しています(両者は必ず同じくなる)。

  2. レイヤ$${^{(l+1)}}$$のノード番号には$${j}$$の代わりに$${n}$$を使用しています。

勾配の計算

入力の勾配について、$${式(4.7)}$$を一般化します。

$$
 \tag{5.5}
\cfrac{∂L}{∂x_{i}^{(l)}}=\displaystyle\sum_{j=1}^{p}δ_{j}^{(l)}w_{ij}^{(l)}
$$

重みの勾配について、$${式(4.13)}$$を一般化します。

$$
 \tag{5.6}
\cfrac{∂L}{∂w_{ij}^{(l)}}=x_{i}^{(l)}δ_{j}^{(l)}
$$

バイアスの勾配について、$${式(4.19)}$$を一般化します。

$$
 \tag{5.7}
\cfrac{∂L}{∂b_{j}^{(l)}}=δ_{j}^{(l)}
$$

パラメータの更新

重みの更新について、$${式(4.15)}$$を一般化します。

$$
\tag{5.8}
w_{ij}^{(l)}:=w_{ij}^{(l)}-η\,\,x_{i}^{(l)}δ_{j}^{(l)}
$$

バイアスの更新について、$${式(4.21)}$$を一般化します。

$$
\tag{5.9}
b_{j}^{(l)}:=b_{j}^{(l)}-η\,\,δ_{j}^{(l)}
$$

上の記事では$${式(5.1)〜式(5.9)}$$をベースにしてニューラルネットワーク(ニューロン版)を実装しています。ただし、スカラーの数式を忠実に実装すると無駄にforループを使う羽目になるので、そういう無駄を避けるためにベクトルで実装しています。

行列版の数式を考える

ここまでは「ネットワーク全体の中にあるたった1個の重みの更新について考える」ところから数式を作りあげました。これについては前述の通りニューロン版として実装しましたが、これは実用性よりも探究心を満足させるためのものでした。ここでは、より実用的な実装を可能にするために行列を使った数式を考えることにします。

重みを更新する

まずは、お馴染みの重み$${w_{32}^{(2)}}$$を更新することを考えます。ただし、今度は行列の要素として更新するので、重み$${w_{32}^{(2)}}$$を所有している行列$${W^{(2)}}$$を俎上に載せます。

$$
\tag{6.1}
\def\arraystretch{1.6}
W^{(2)}=\begin{pmatrix}
 w_{11}^{(2)} & w_{12}^{(2)} & w_{13}^{(2)}  \\
 w_{21}^{(2)} & w_{22}^{(2)} & w_{23}^{(2)}  \\
 w_{31}^{(2)} & \boxed{w_{32}^{(2)}} & w_{33}^{(2)}  \\
 w_{41}^{(2)} & w_{42}^{(2)} & w_{43}^{(2)} 
\end{pmatrix}
$$

次は、重み$${w_{32}^{(2)}}$$を更新するのに必要な重みの勾配です。重みの勾配は$${式(4.13)}$$にある通り$${{∂L}/{∂w_{32}^{(2)}}=x_{3}^{(2)}δ_{2}^{(2)}}$$なので、$${x_{3}^{(2)}}$$と$${δ_{2}^{(2)}}$$を所有しているベクトルをそれぞれ俎上に載せます。

$$
\tag{6.2}
\vec{x}^{(2)}=(x_{1}^{(2)},\,x_{2}^{(2)},\,\boxed{x_{3}^{(2)}},\,x_{4}^{(2)})
$$

$$
\tag{6.3}
\vec{δ}^{(2)}=(δ_{1}^{(2)},\,\boxed{δ_{2}^{(2)}},\,δ_{3}^{(2)})
$$

ここでベクトル$${\vec{x}^{(2)}}$$とベクトル$${\vec{δ}^{(2)}}$$のそれぞれを 4×1 と 1×3 の行列に見立ててドット積(内積)を計算します(ベクトル$${\vec{x}^{(2)}}$$を転置することに注意)。こうすることにより$${W^{(2)}}$$と同じ形状(4×3)の行列のどこかに$${x_{3}^{(2)}δ_{2}^{(2)}}$$が現れるはずです。 

$$
\def\arraystretch{1.6}
\tag{6.4}
\begin{aligned}
W_{grad}^{(2)}&=\vec{x}^{(2)\mathrm{T}}\cdot\vec{δ}^{(2)} \\[5pt]
&=\begin{pmatrix}
x_{1}^{(2)}\\x_{2}^{(2)}\\\boxed{x_{3}^{(2)}}\\x_{4}^{(2)}
\end{pmatrix}\cdot (δ_{1}^{(2)},\,\boxed{δ_{2}^{(2)}},\,δ_{3}^{(2)}) \\[40pt]
&=\begin{pmatrix}
 {x_{1}^{(2)}}{δ_{1}^{(2)}} & {x_{1}^{(2)}}{δ_{2}^{(2)}} & {x_{1}^{(2)}}{δ_{3}^{(2)}}  \\
 {x_{2}^{(2)}}{δ_{1}^{(2)}} & {x_{2}^{(2)}}{δ_{2}^{(2)}} & {x_{2}^{(2)}}{δ_{3}^{(2)}}  \\
 {x_{3}^{(2)}}{δ_{1}^{(2)}} & \boxed{{x_{3}^{(2)}}{δ_{2}^{(2)}}} & {x_{3}^{(2)}}{δ_{3}^{(2)}}  \\
 {x_{4}^{(2)}}{δ_{1}^{(2)}} & {x_{4}^{(2)}}{δ_{2}^{(2)}} & {x_{4}^{(2)}}{δ_{3}^{(2)}} 
\end{pmatrix}
\end{aligned}
$$

念のため Python(SymPy)でも確認しておきます。

何と都合の良いことでしょう。$${x_{3}^{(2)}δ_{2}^{(2)}}$$(Pythonでは$${d_2x_3}$$)が重み$${w_{32}^{(2)}}$$と同じ位置$${(3, 2)}$$に現れました。これは何の細工もせずに勾配降下法の式に当てはめるだけで重みを更新できるということです。

$${式(6.4)}$$は実質的には行列による演算なので記号をベクトルから行列に置き換えます。

$$
\tag{6.5}
W_{grad}^{(2)}=X^{(2)\mathrm{T}}\cdot Δ^{(2)}
$$

勾配降下法の式に当てはめます。

$$
\tag{6.6}
W^{(2)}:=W^{(2)}-η\,W_{grad}^{(2)}
$$

これで行列$${W^{(2)}}$$の要素としての重み$${w_{32}^{(2)}}$$も更新できます。

バイアスを更新する

重みの場合は行列が更新対象でしたが、バイアスの場合はベクトルが更新対象になります。それでは、再びバイアス$${b_{2}^{(2)}}$$を更新することを考えることにしましょう。まずは、バイアス$${b_{2}^{(2)}}$$を所有しているベクトル$${\vec{b}^{(2)}}$$を俎上に載せます。

$$
\tag{6.7}
\vec{b}^{(2)}=(b_{1}^{(2)},\,\boxed{b_{2}^{(2)}},\,b_{3}^{(2)})
$$

次は、バイアス$${b_{2}^{(2)}}$$を更新するのに必要なバイアスの勾配です。バイアスの勾配は$${式(4.19)}$$にある通り$${{∂L}/{∂b_{2}^{(2)}}=δ_{2}^{(2)}}$$なので、$${δ_{2}^{(2)}}$$を所有しているベクトル$${\vec{δ}^{(2)}}$$を俎上に載せます。これは$${式(6.3)}$$と同じです。

$$
\tag{6.8}
\vec{δ}^{(2)}=(δ_{1}^{(2)},\,\boxed{δ_{2}^{(2)}},\,δ_{3}^{(2)})
$$

ベクトルの形状が同じなので一見すると簡単に更新できそうです。しかし、行列版では更新対象のバイアス$${\vec{b}^{(2)}}$$が常にベクトルであるのに対して誤差は常に行列なのでした。危うく落とし穴にハマるところでした。例として$${式(6.8)}$$をバッチサイズ$${=4}$$の行列に書き直します。

$$
\def\arraystretch{1.6}
\tag{6.9}
Δ^{(2)}=\begin{pmatrix}
 δ_{11}^{(2)} & \boxed{δ_{12}^{(2)}} & δ_{13}^{(2)}  \\
 δ_{21}^{(2)} & \boxed{δ_{22}^{(2)}} & δ_{23}^{(2)}  \\
 δ_{31}^{(2)} & \boxed{δ_{32}^{(2)}} & δ_{33}^{(2)}  \\
 δ_{41}^{(2)} & \boxed{δ_{42}^{(2)}} & δ_{43}^{(2)} 
\end{pmatrix}
$$

見ての通り、単純に誤差$${Δ^{(2)}}$$の縦軸を足し合わせると更新対象の$${式(6.7)}$$と同じ形状になることが分かります。従ってバイアスの勾配$${\vec{b}_{grad}^{(2)}}$$は次のようになります。

$$
\tag{6.10}
\vec{b}_{grad}^{(2)}=\displaystyle\sum_{r=1}^{m}Δ^{(2)}[r,:]
$$

勾配降下法の式に当てはめます。

$$
\tag{6.11}
\vec{b}^{(2)}:=\vec{b}^{(2)}-η\,\,\vec{b}_{grad}^{(2)}
$$

これでベクトル$${\vec{b}^{(2)}}$$の要素としてのバイアス$${b_{2}^{(2)}}$$も更新できます。

行列版の数式

ここまでの成果を仕上げます。

出力値の計算

出力値$${u}$$について、$${式(5.1)}$$を行列に対応させます。

$$
\tag{7.1}
U^{(l)}=X^{(l)}\cdot W^{(l)}+\vec{b}^{(l)}
$$

出力値$${y}$$について、$${式(5.2)}$$を行列に対応させます。

$$
\tag{7.2}
Y^{(l)}=f^{(l)}(U^{(l)})
$$

誤差の計算

出力層の誤差について、$${式(5.3)}$$を行列に対応させます。

$$
\tag{7.3}
Δ^{(l)}=Y-T
$$

$$
(l=k)
$$

中間層の誤差について、$${式(5.4)}$$を行列に対応させます。

$$
\tag{7.4}
Δ^{(l)}=X_{grad}^{(l+1)}\odot f'^{(l)}(U^{(l)})
$$

$$
(l\lt k)
$$

勾配の計算

入力の勾配について、$${式(5.5)}$$を行列に対応させます。

$$
\tag{7.5}
X_{grad}^{(l)}=Δ^{(l)}\cdot W^{(l)\mathrm{T}}
$$

重みの勾配について、$${式(6.5)}$$を一般化します。

$$
\tag{7.6}
W_{grad}^{(l)}=X^{(l)\mathrm{T}}\cdot Δ^{(l)}
$$

バイアスの勾配について、$${式(6.10)}$$を一般化します。

$$
\tag{7.7}
\vec{b}_{grad}^{(l)}=\displaystyle\sum_{r=1}^{m}Δ^{(l)}[r,:]
$$

パラメータの更新

重みの更新について、$${式(6.6)}$$を一般化します。

$$
\tag{7.8}
W^{(l)}:=W^{(l)}-η\,W_{grad}^{(l)}
$$

バイアスの更新について、$${式(6.11)}$$を一般化します。

$$
\tag{7.9}
\vec{b}^{(l)}:= \vec{b}^{(l)}-η\,\vec{b}_{grad}^{(l)}
$$

上の記事では$${式(7.1)〜式(7.9)}$$を使ってニューラルネットワーク(行列版)を実装しています。この実装は数式と完全に対応しています。

おわりに

連鎖律を解くところから始めて、最終的には骨組みとなる数式を手に入れることができました。その過程で、普通であれば連鎖律とセットで語られる合成関数については一切触れませんでした。

ニューラルネットワークに価値のある仕事をさせるには、骨組みとなる数式とは別に、初期化関数、活性化関数、活性化関数の導関数、損失関数、評価関数などの数式も必要です。これまでに実装したニューロン版と行列版ではこれらの関数も実装した上で、よく知られたデータセットを使って演習を行っています。

付録

データフローダイアグラム(DFD)

ニューロン版の実装で使った図を再掲します。

チートシート

行列版の実装で使った図を再掲します。

この記事が気に入ったらサポートをしてみませんか?