JAXで始めるディープラーニング

こんにちは。今回はGoogleの開発するディープラーニング向けライブラリ⭐️であるJAXの紹介をします。

https://github.com/google/jax

ディープラーニング向けライブラリとしてはTensorflowやPyTorch、最近開発終了が宣言されたChainerなどが有名かと思います。これらは多次元配列と自動微分をサポートした計算ライブラリをコアとしていて、それにニューラルネットの実装を容易にするラッパーなどが付属しています。

GoogleといえばTensorflowが有名ですが、JAXはTensorflowとは何が違うのでしょうか。

JAXを一言で表現すると、高速なautogradです。

もう少し詳しくいうと、多次元配列の計算ライブラリであるnumpyに自動微分とJITがくっついたものです。さらに、GPUやTPUといったアクセラレーター上でも動作します。Tensorflowのように複雑怪奇ではなく、自分で部品から組み立てていけるのが魅力です。

簡単に使い方を見ていきましょう。まずはpipでインストールします。

pip install -U jax jaxlib

JAXとJAX用のnumpyを読み込みます。ほとんど普通のnumpyと同じですが、乱数生成まわりの操作が異なります。

import jax
import jax.numpy as np

適当な函数をつくります。

def f(x):
   return np.tanh(x) ** 2 / x

f(1.0)

# >>> DeviceArray(0.5800257, dtype=float32)

f(np.array([1.0, 2.0]))

# >>> DeviceArray([0.5800257, 0.4646746], dtype=float32)

この函数をxについて微分します。

grad_f = jax.grad(f)
grad_f(1.0)

# >>> DeviceArray(0.05967432, dtype=float32)

さらに微分します。

jax.grad(grad_f)(1.0)

# >>> DeviceArray(-0.7409754, dtype=float32)

さらにさらに微分します。

jax.grad(jax.grad(grad_f))(1.0)

# >>> DeviceArray(1.5578352, dtype=float32)

このように高階微分も容易に求められます。この時、導函数は出力がスカラーであるように入力を考える必要があります。

grad_f(np.array([1.0, 2.0]))

# >>> TypeError: Gradient only defined for scalar-output functions. Output had shape: (2,).

エラーが出てきました。でも、JAXは以下のように並列化が可能なので心配は無用です。


jax.vmap(grad_f)(np.array([1.0, 2.0]))

# >>> DeviceArray([ 0.05967432, -0.16422796], dtype=float32)

vmapを使うとバッチの処理も簡単に実装できそうですね。

さらに、込み入った計算もJITを用いて高速化することができます。


jit_vmap_grad_f = jax.jit(jax.vmap(grad_f))
jit_vmap_grad_f(np.array([1.0, 2.0]))

# >>> DeviceArray([ 0.05967433, -0.16422796], dtype=float32)

手元のCPU環境でもオリジナルのjax.vmap(grad_f)(np.array([1.0, 2.0]))と比較して8倍ほど早く計算ができました。

さらに、ヘッシアンやヤコビアンなど、他のライブラリでは求めにくい値も簡単に求められます。

def f(x):
    x, y = x
    return x ** 2 + 2 * x * y + 2 * y ** 2



jax.hessian(f)(np.array([1.0, 2.0]))

# >>> DeviceArray([[2., 2.],                                                                                                 │
                   [2., 4.]], dtype=float32)

jax.jacobian(f)(np.array([1.0, 2.0]))
# >>> DeviceArray([ 6., 10.], dtype=float32)

以上のことを組み合わせることで任意のテンソル計算に対して自動微分が可能です。一つ一つ組み合わせていくことで、ディープラーニングができますね!次回は実際にニューラルネットワークを作って学習させてみましょう!

いかがでしたか?JAXの今後の発展に期待ですね!

この記事が気に入ったらサポートをしてみませんか?