PyTorchが計算グラフを作成する仕組み
小さな小さなPyTorchクローンを作ろうと思ったので、PyTorchの計算グラフの仕組みを調べました。
公式ブログに分かりやすいGIFアニメがあったので、メモに残します。
1.requires_gradなテンソルを関数に入力
![](https://assets.st-note.com/img/1680961539327-jHAiYJjnhN.png?width=800)
2.grad_fn(勾配関数)ノードが作成される
![](https://assets.st-note.com/img/1680961539468-YfIKvuZ6F3.png?width=800)
3.「collect_next_edges」が入力関数からgrad_fnにリンクするエッジを作成
![](https://assets.st-note.com/img/1680961539647-r7pVeZs3Sw.png?width=800)
![](https://assets.st-note.com/img/1680961539794-7O7CKQhDHn.png?width=800)
![](https://assets.st-note.com/img/1680961539956-BD82ks1nT3.png?width=800)
6.同様にして関数Bを作成
![](https://assets.st-note.com/img/1680961540125-E19vKy3fJA.png?width=800)
![](https://assets.st-note.com/img/1680961540255-xOV40gh8FI.png?width=800)
![](https://assets.st-note.com/img/1680961540451-FfVIPs0Nuu.png?width=800)
![](https://assets.st-note.com/img/1680961540551-MkFgdRhw8y.png?width=800)
7.3つのテンソルを関数Cへ入力
![](https://assets.st-note.com/img/1680961540710-iDKol9pEiL.png?width=800)
![](https://assets.st-note.com/img/1680961540872-ktepkYxuty.png?width=800)
9.「collect_next_edges」が入力の関数を調べて、grad_fnへのエッジを作成する
![](https://assets.st-note.com/img/1680961541185-FJStu1DiQa.png?width=800)
![](https://assets.st-note.com/img/1680961541328-EvcYkPlJZ9.png?width=800)
![](https://assets.st-note.com/img/1680961541481-w7RamqTRHk.png?width=800)
10.終了。
![](https://assets.st-note.com/img/1680961541632-TqNbEK8VdZ.png?width=800)
この記事が気に入ったらサポートをしてみませんか?