PyTorchとJAXに対応したKeras3でMNISTを試す
バックボーンのフレームワークを、従来のTensorFlowから、デファクトスタンダードになりつつあるPyTorchと、実行効率に優れたJAXも選べるようになったKeras3.0が公開されていたので、さっそくバックボーンをPyTorchやJAXに設定して、手書きアルファベット画像のクラス分け課題のMNISTを試してみました。
23.11.29追記
公式の紹介ページも公開されていました。
https://keras.io/keras_3/
Keras3のインストール、インポート
今回はGoogle Colabで試してみます。Keras3は現時点ではPyPI上では、プレビューリリースとしてkeras-coreの名前でインストールできます。
!pip install keras-core
バックエンドの設定(torch, jax, tensorflow)
import os
os.environ["KERAS_BACKEND"] = "torch"
keras_coreライブラリのインポート
from keras_core import datasets, layers, models
ライブラリのインポート時に選択されているバックエンドが表示されます。
JAXをバックエンドにする場合
import os
os.environ["KERAS_BACKEND"] = "jax"
from keras_core import datasets, layers, models
MNISTの学習
学習データの読み込み
# MNISTデータセットのダウンロードと準備
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))
# ピクセルの値を 0~1 の間に正規化
train_images, test_images = train_images / 255.0, test_images / 255.0
モデル定義(CNN)
教科書的なCNNモデルを定義してみます
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.summary()
モデルサマリーを表示。見やすい表形式でレイヤー毎の種類、Shape、パラメタ数が確認できます。
モデルの学習
オプティマイザーほかを設定して、学習を走らせます。
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5)
モデルの評価
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(test_acc)
フレームワーク毎の実行時間の比較
Google Colab(T4)での実行時間は以下の通り。バックエンドをJAXにするとTensorFlowと比べてざっくり2倍程度、学習および推論が高速化しています。PyTorchを使った場合の最適化は今後の課題でしょうか。
実行時間の比較
JAX: 28秒
TensorFlow: 57秒
PyTorch: 88秒
まとめ
バックエンドの設定だけ行えば、あとは従来のKeras2の感覚で普通に使えるようです。
ディープラーニングの内部動作を学ぶとか、研究目的には素のPyTorchのほうが良い気がしますが、さくっと試すのにはやっぱりKerasがラクだとおもいます。ドキュメントの充実など、今後の展開に期待したいです。ここから巻き返して流行ってほしいなぁ。
この記事が気に入ったらサポートをしてみませんか?