見出し画像

画像分類(MNIST) - Haiku / Flax

最近、NumPyro 関係のネタをいろいろと探しているうちに、Haiku や Flax のコードを読む機会がありました。

JAX を使った深層学習系のライブラリとしては、Haiku / Flax / Stax / Trax など幾つかあるかと思うのですが、その中でも特にメジャーなものは Haiku や Flax あたりではないかと思います。

GitHub の星の数でいうと Haiku は 1.8k で、Flax は 2.7k なので、単純に考えれば Flax の方が圧勝なのですが、Haiku を開発した DeepMind 社のサイトを見ていると、「JAX で攻めるぞ!」といった DeepMind 社の気合が感じられて、星の数だけで「Flax が正義!」としてしまうのもどうかな…という気がしています。

特に、強化学習に関しては DeepMind 社が Haiku と同じタイミングで RLax と呼ばれる強化学習のライブラリを発表したことを考えると、Haiku の方もなかなか捨てがたい魅力があるのではないかと思っています。

一方で、 画像分野では Pretrained Model の充実度は Flax の方が上のような気がしますし、Hugging Face の Transformers が Flax をサポートし始めたことを考えると、やっぱり Flax なのかな…と思ったりもします。

そんな訳で、「Haiku と Flax だったら、どっちを学ぶべき?」と悩んでいた訳なのですが、両方のコードをいろいろいじっているうちに「これらはそれほど違いがないのでは?」という気がしてきました。

初見では Haiku と Flax のコードは結構違うように見えるのですが、結局のところ同じニューラルネットのライブラリですし、最適化には同じ Optax を使うことができますので、その部分は共通化できそうです。

そんな訳で、今回何をやってみたか?というと、画像分類(MNIST)のコードを Haiku と Flax でなるべく寄せて書いてみたらどうなるか?…ということをやってみました。

Haiku も Flax も開発者が違うので、例題などもそれぞれの開発者が好きなようにコードを書いているかと思うのですが、それを第三者の視点でなるべく近づけるようにコードを書いてみたら、どうなるか?というのを試してみています。

で、結論から言うと、結構近づけることができました。少なくとも簡単な画像分類みたいなコードであればネットワークの定義の部分以外はほぼ共通化できるみたいな感じです。

ネットワークの定義のやり方は、TensorFlow や PyTorch に慣れている方なら、ほぼ同じような感覚で扱えますので、違和感も少ないかと思いますし、Optax での最適化のやり方も非常にすっきりしていてわかり易いです。

つまり、両方覚えちゃうのもそれほど難しくはなさそうだ!というのがこの記事の主張です。

コードの方はこちらにありますので、もしよければ参考にしてみて下さい。

ちなみに、Flax の本家サイトでは TrainState というクラスを使ったやり方が紹介されているのですが、今回ご紹介するコードではこのクラスを使わずにコードを書いています。

ニューラルネットを普段からバリバリ使っている方にとってはこうしたパラメータの管理の仕方の方が利点が多いのかもしれないですが、初めて Haiku / Flax を使う方にとっては、むしろ TrainState クラスを使わない方がわかり易いのではないか?と考えて、今回こういうコードの書き方をしてみました。

少しでも皆様の参考になる部分がありましたら幸いです。


この記事が気に入ったらサポートをしてみませんか?