見出し画像

PyTorchとJAXに対応したKeras3でMNISTを試す

バックボーンのフレームワークを、従来のTensorFlowから、デファクトスタンダードになりつつあるPyTorchと、実行効率に優れたJAXも選べるようになったKeras3.0が公開されていたので、さっそくバックボーンをPyTorchやJAXに設定して、手書きアルファベット画像のクラス分け課題のMNISTを試してみました。


23.11.29追記
公式の紹介ページも公開されていました。
https://keras.io/keras_3/

Keras3のインストール、インポート

今回はGoogle Colabで試してみます。Keras3は現時点ではPyPI上では、プレビューリリースとしてkeras-coreの名前でインストールできます。

!pip install keras-core

バックエンドの設定(torch, jax, tensorflow)

import os
os.environ["KERAS_BACKEND"] = "torch"

keras_coreライブラリのインポート

from keras_core import datasets, layers, models

ライブラリのインポート時に選択されているバックエンドが表示されます。

Using PyTorch backend.

JAXをバックエンドにする場合

import os
os.environ["KERAS_BACKEND"] = "jax"

from keras_core import datasets, layers, models

Using JAX backend.

MNISTの学習

学習データの読み込み

# MNISTデータセットのダウンロードと準備
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))

# ピクセルの値を 0~1 の間に正規化
train_images, test_images = train_images / 255.0, test_images / 255.0

モデル定義(CNN)

教科書的なCNNモデルを定義してみます

model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.summary()

モデルサマリーを表示。見やすい表形式でレイヤー毎の種類、Shape、パラメタ数が確認できます。

モデルの学習

オプティマイザーほかを設定して、学習を走らせます。

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=5)

モデルの評価

test_loss, test_acc = model.evaluate(test_images, test_labels)
print(test_acc)

0.991100013256073

フレームワーク毎の実行時間の比較

Google Colab(T4)での実行時間は以下の通り。バックエンドをJAXにするとTensorFlowと比べてざっくり2倍程度、学習および推論が高速化しています。PyTorchを使った場合の最適化は今後の課題でしょうか。

実行時間の比較

  • JAX: 28秒

  • TensorFlow: 57秒

  • PyTorch: 88秒

まとめ

バックエンドの設定だけ行えば、あとは従来のKeras2の感覚で普通に使えるようです。
ディープラーニングの内部動作を学ぶとか、研究目的には素のPyTorchのほうが良い気がしますが、さくっと試すのにはやっぱりKerasがラクだとおもいます。ドキュメントの充実など、今後の展開に期待したいです。ここから巻き返して流行ってほしいなぁ。

この記事が気に入ったらサポートをしてみませんか?