PyTorchが計算グラフを作成する仕組み

小さな小さなPyTorchクローンを作ろうと思ったので、PyTorchの計算グラフの仕組みを調べました。

公式ブログに分かりやすいGIFアニメがあったので、メモに残します。


1.requires_gradなテンソルを関数に入力

2.grad_fn(勾配関数)ノードが作成される

3.「collect_next_edges」が入力関数からgrad_fnにリンクするエッジを作成

6.同様にして関数Bを作成

7.3つのテンソルを関数Cへ入力

9.「collect_next_edges」が入力の関数を調べて、grad_fnへのエッジを作成する

10.終了。

この記事が気に入ったらサポートをしてみませんか?