見出し画像

18章 GNN:NetworkXがグラフを描画できない件

はじめに

シリーズ「Python機械学習プログラミング」の紹介

本シリーズは書籍「Python機械学習プログラミング PyTorch & scikit-learn編」(初版第1刷)に関する記事を取り扱います。
この書籍のよいところは、Pythonのコードを動かしたり、アルゴリズムの説明を読み、ときに数式を確認して、包括的に機械学習を学ぶことができることです。
Pythonで機械学習を学びたい方におすすめです!
この記事では、この書籍のことを「テキスト」と呼びます。

記事の内容

この記事は「第18章 グラフニューラルネットワーク-グラフ構造データでの依存性の捕捉」の「18.2 グラフ畳み込み」のグラフ作成と描画処理で発生するエラーの内容と対処策を紹介します。
今回はGPU非搭載のパソコンで実施しています。

18章のダイジェスト

18章では、グラフニューラルネットワーク(GNN)に挑戦します
次の図のように、グラフデータは丸で示した「ノード」(頂点)と線で示した「エッジ」(辺)によってノード間の関係性を表現するデータ構造です。
路線図、分子構造、PERT図などがグラフデータの例になります。

有向グラフの例

グラフ構造とグラフ畳み込みの概念を理解した後に、PyTorch Geometricを利用して、QM9データセットを学習して等方性分極率を予測します。
グラフ構造、グラフニューラルネットワークの概念やアウトプットのイメージがぼんやりしているので、実践を進めることが楽しみです。

QM9は、量子化学計算に基づいた機械学習用の大規模データセットです。
13万超の低分子化合物とラベルの情報が含まれています。
等方性分極率は、分子の電荷分布が外部電場によって歪められる度合い、だそうです。
次のサイトでデータセットを取得できます。


NetworkXさん、描いておくれよ・・・

Case 1:隣接行列Aの出力にて

エラーの内容
「18.2.2 基本的なグラフ畳み込みを実装する」の最初のコードでエラーが発生しました。

# グラフG、隣接行列Aの定義、Aの出力
import numpy as np
import networkx as nx
G = nx.Graph()
# グラフを描画する場合の色を表す16進数コード
blue, orange, green = "#1f77b4", "#ff7f0e", "#2ca02c"
G.add_nodes_from([(1, {"color": blue}),
                  (2, {"color": orange}),
                  (3, {"color": blue}),
                  (4, {"color": green})])
G.add_edges_from([(1,2), (2,3), (1,3), (3,4)])
A = np.asarray(nx.adjacency_matrix(G).todense())  <--- エラー発生源
print(A)

エラーの発生箇所は、隣接行列$${\boldsymbol{A}}$$に代入するこの部分です。

A = np.asarray(nx.adjacency_matrix(G).todense())
scipy.sparseのエラー

AttributeError: module 'scipy.sparse' has no attribute 'coo_array'
属性エラー:モジュール「scipy.sparse」には属性「coo_array」はありません。

SciPyは高度な科学技術計算を得意とするライブラリです。
でも、どうして今回のコードには関係のないSciPyがエラーを引き起こしているのでしょう?
次のサイトで答えを見つけました。Thank you very much!!!

エラーの原因
NetworkXの内部処理でSciPyを使っていました。
あるバージョンから両者の依存関係ができたようです。
・NetworkXのv2.7からscipy v1.8以上が必要になった。
・SciPyのv1.8からsparse.*_array関数が導入された。

テキストでは以下のバージョンを前提としています。
・原文版:NetworkX 2.6.2
・翻訳版:NetworkX 2.8.5、SciPy 1.8.0

利用環境のバージョンは次のようになっていました。
・NetworkX 2.7.1、SciPy 1.7.3

# NetworkXのインストールバージョン
print(networkx.__version__)

出力イメージ
2.7.1
# SciPyのインストールバージョン
import scipy
print(scipy.__version__)

出力イメージ
1.7.3

エラーの原因は、「NetworkX 2.7以上がインストールされているのでSciPy 1.8以上が必要だが、インストールされているSciPyは1.7.3なのでcoo_array関数が導入されていない。よってSciPyのcoo_array属性が見つからない」ことにありそうです。

対処策
SciPyのバージョンをアップデートすれば済みそうです。
と軽く考えて、次のアップデートコマンドを実行しましたが、どうしてもSciPyを1.7.3以降にバージョンアップすることができませんでした。
謎です。

# SciPyのバージョンアップ (コマンドプロンプトで実行)
conda install scipy=1.9.3

一大決心をしました。
パッケージ全部をアップデートします!
NetworkX、SciPyのバージョンが無事に変わりました。
- NetworkX 2.8.4
- SciPy 1.9.3
なお、pipでインストールできる最新版は、NetworkX3.0、SciPy1.10.0です。

# パッケージ全部のアップデート(Anaconda Promptで実行)
conda update --all

ついでに不要なパッケージやキャッシュを削除しました。

# 不要なパッケージ・キャッシュの削除(Anaconda Promptで実行)
conda clean --all

対処後の動作
隣接行列$${\boldsymbol{A}}$$の値を出力できました。

# 隣接行列Aの出力
print(A)

出力イメージ
[[0 1 1 0]
 [1 0 1 0]
 [1 1 0 1]
 [0 0 1 0]]

行列の値の意味は次のようになります。

  • 1行目:ノード1は、ノード2、ノード3と結合

  • 2行目:ノード2は、ノード1、ノード3と結合

  • 3行目:ノード3は、ノード1、ノード2、ノード4と結合

  • 4行目:ノード4は、ノード3と結合

Case 2:グラフGの描画にて

エラーの内容
「18.2.2 基本的なグラフ畳み込みを実装する」の中程のグラフを描画するコードでエラーが発生しました。

# グラフGの描画
color_map = nx.get_node_attributes(G, 'color').values()
nx.draw(G, with_labels=True, node_color=color_map)      <--- エラー発生源

エラーの発生箇所はグラフを描画するこの部分です。

nx.draw(G, with_labels=True, node_color=color_map)
_AxesStackを呼び出しできないエラー

TypeError: '_AxesStack' object is not callable
型エラー:'_AxesStack'オブジェクトを呼び出しできません。

networkxの処理の中で'_AxesStack'オブジェクトを呼び出しできない、とはどういうことを意味するのでしょう(謎すぎて意味不明です・・・)。
エラーメッセージの後段、特に112行目あたりから推測すると、「NetworkXのdraw関数の引数 ax が指定されず、かつ、_axstack()が無いときに、axにadd_axes((0, 0, 1, 1))を設定する」というような処理の途中で何か起きている感じです。
add_axes()は、描画パッケージであるmatplotlibのメソッドのようにも見えます

NetworkXのサイトでdrawのパラメータ仕様を確認してみます。

エラーの原因
引数axは、matplotlibのAxes(軸)オブジェクトでした。

draw関数のパラメータ(NetworkX公式サイトより)
draw関数のパラメータのGoogle翻訳(NetworkX公式サイトより)

どうやら、次のような問題が起きているような印象です。

  • NetworkXのdraw関数を実行時に、matplotlibのaxを設定しようとする。

  • 今回は、プログラムの内部で「_axstack()」がうまく作られないまま呼び出しされたのでエラーが発生した。

  • また、axにadd_axes((0, 0, 1, 1))を設定できていない。

これ以上の解明はスキル的に厳しそうです。。。

対処策&対処後の動作
ひとまず、draw関数の引数axにmatplotlobのAxesオブジェクトを与えてみます。

# グラフGの描画 :コード変更
import matplotlib.pyplot as plt                            # 追加
fig = plt.figure()                                         # 追加 
ax = fig.add_axes([0, 0, 1, 1])                            # 追加
color_map = nx.get_node_attributes(G, 'color').values()
nx.draw(G, ax=ax, with_labels=True, node_color=color_map)  # 変更 引数axを設定

一応、グラフを描画できました。

NetworkXのdraw関数によるグラフGの描画

でも、このコード変更は少しモヤモヤします。
matplotlibのコードを明示的に書くのは冗長な感じがします。
他の対処方法は無いのでしょうか・・・。

対処策&対処後の動作2
NetworkX公式サイトで、別の描画関数の記載を見つけました。

draw_networks関数(NetworkX公式サイトより)
draw_networks関数のGoogle翻訳(NetworkX公式サイトより)

テキストのコードのdrawの部分をdraw_networkxに置き換えて、実行しました。

# # グラフGの描画 :コード変更2
plt.axis('off')                                             # x・y軸を非表示にする
color_map = nx.get_node_attributes(G, 'color').values()
nx.draw_networkx(G, with_labels=True, node_color=color_map) # 変更 draw_networkx

無事に描画できました。

NetworkXのdraw_networkx関数によるグラフGの描画

こちらのコードの方がスマートです。
でも、draw関数のエラー解消に至ることができず、まだモヤモヤが残っています。。。
解決策をご存じの方、ぜひ教えてください!

追加:NetworkXのグラフ描画のコードサンプル

NetworkXのグラフ描画を理解したところで、「18.3 GNNをPyTorchで一から実装する」の節に登場する4つのグラフ($${G_1, G_2, G_3, G_4}$$)をNetworkXとmatplotlibを使って描画してみましょう。

グラフデータの作成コードはテキストのコードを引用しています。

# NodeNetworkモデル - グラフデータの作成
import networkx as nx
blue, orange, green = "#1f77b4", "#ff7f0e", "#2ca02c"

G1 = nx.Graph()
G1.add_nodes_from([(1,{"color": blue}),
                   (2,{"color": orange}),
                   (3,{"color": blue}),
                   (4,{"color": green})])
G1.add_edges_from([(1, 2), (2, 3), (1, 3), (3, 4)])

G2 = nx.Graph()
G2.add_nodes_from([(1,{"color": green}),
                   (2,{"color": green}),
                   (3,{"color": orange}),
                   (4,{"color": orange}),
                   (5,{"color": blue})])
G2.add_edges_from([(2, 3), (3, 4), (3, 1), (5, 1)])

G3 = nx.Graph()
G3.add_nodes_from([(1,{"color": orange}),
                   (2,{"color": orange}),
                   (3,{"color": green}),
                   (4,{"color": green}),
                   (5,{"color": blue}),
                   (6,{"color": orange})])
G3.add_edges_from([(2, 3), (3, 4), (3, 1), (5, 1), (2, 5), (6, 1)])

G4 = nx.Graph()
G4.add_nodes_from([(1,{"color": blue}),
                   (2,{"color": blue}),
                   (3,{"color": green})])
G4.add_edges_from([(1, 2), (2, 3)])

グラフを描画するサンプルコードは次のようになります。

# 4つのグラフの描画
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10, 8))
for i, g in enumerate([G1, G2, G3, G4]):
    color_map = nx.get_node_attributes(g, 'color').values()
    plt.subplot(2, 2, i+1, title=(f'G{i+1}: {g}'))
    plt.axis('off')
    nx.draw_networkx(g, with_labels=True, node_color=color_map)

4つのグラフデータは次のようなネットワーク構造をもっています。
グラフデータは、可視化することでグッと理解しやすくなりますね!

4つのグラフデータの描画サンプル

まとめ

今回は、NetworkXライブラリとSciPy/matplotlibとの連動がうまく動かないことに起因して発生するエラーの対処に取り組みました。
Pythonはさまざまなサードパーティライブラリの活躍によって、多くの機械学習タスクの遂行を実現しています。
サードパーティライブラリの間の整合性(たとえばバージョンの整合)を保つことが大切になります。
Python環境(Anaconda環境)の維持・保全にも力を入れなければならないことを体感しました。

# 今日の一句
print('統計検定準1級ではグラフィカルモデルを出題しています')

楽しくPython機械学習プログラミングを学びましょう!

ちなみに、次に利用するPyTorch Geometricライブラリをインストールできない件は、次回以降に取り上げます(とほほ)。

おまけ数式

noteでは数式記法を利用できます。
今回は上述のコードで実装したグラフ畳み込みの形式を紹介します。

$$
\boldsymbol{x}^{\prime}_i = \boldsymbol{x}_i \boldsymbol{W}_1 + \displaystyle \sum_{j \in N(i)} \boldsymbol{x}_j \boldsymbol{W}_2 + b
$$

$${\boldsymbol{x}^{\prime}_i}$$はノード$${i}$$の更新後の埋め込み、$${\boldsymbol{W}_1}$$、$${\boldsymbol{W}_2}$$は学習可能な重みからなる$${f_{in} \times f_{out}}$$行列、そして$${b}$$は長さ$${f_{out}}$$の学習可能なバイアスベクトルです。


おわりに

AI・機械学習の学習でおすすめの書籍を紹介いたします。
「AI・データサイエンスのための 図解でわかる数学プログラミング」

ビジネスの現場では今後、数学的知識の必要度が高くなると言われています。
この書籍は、図解によって数学的な考え方を直感的に説明し、Pythonのコードを動かしてみて計算を体感することを目的に書かれています。
カバーする領域は、確率統計、機械学習、数理最適化、数値シミュレーション、深層学習です。
7章では、NetworkXライブラリを利用してネットワークの成長を可視化することに取り組んでいます
ディープラーニング、Python、数学を一体として学習できるチャンスですね!

最後まで読んでくださり、ありがとうございました。

この記事が参加している募集

#この経験に学べ

54,275件

この記事が気に入ったらサポートをしてみませんか?