Kerasでカスタムレイヤーを実装する

この記事では、Kerasでカスタムレイヤーを実装してみます。
下記のコードサンプルを参考にしています。

環境はGoogle Corabです。

インストール

pip install --upgrade keras

Google CorabにはKerasがプリインストールされていますが、バージョンが3以上である必要があるためアップグレードします。

ライブラリのインポート

import keras
from keras import layers
from keras import ops

Antirectifierレイヤーの実装

カスタムレイヤーの実装を行います。このサンプルではカスタムレイヤーとして、Antirectifierレイヤーを実装しています。

ReLUとは異なり、Antirectifierレイヤーは入力の正の部分と負の部分を分離し、負の部分を絶対値に変換します。その後、両方の絶対値を連結して返します。これにより、情報をより多く保持することで、モデルの学習能力の向上を目的としたレイヤーです。

class Antirectifier(layers.Layer):
    def __init__(self, initializer="he_normal", **kwargs):
        super().__init__(**kwargs)
        self.initializer = keras.initializers.get(initializer)

    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer=self.initializer,
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        inputs -= ops.mean(inputs, axis=-1, keepdims=True)
        pos = ops.relu(inputs)
        neg = ops.relu(-inputs)
        concatenated = ops.concatenate([pos, neg], axis=-1)
        mixed = ops.matmul(concatenated, self.kernel)
        return mixed

    def get_config(self):
        # Implement get_config to enable serialization. This is optional.
        base_config = super().get_config()
        config = {"initializer": keras.initializers.serialize(self.initializer)}
        return dict(list(base_config.items()) + list(config.items()))

カスタムレイヤーの実装には下記が必要です。

  • 状態変数を__init__またはbuild()メソッド内でadd_weight()を使って作成します。

  • call()メソッドを実装し、レイヤーの入力テンソルを受け取り、出力テンソルを返します。

  • 必須でないですが、get_config()メソッドを実装することで、カスタムレイヤーのシリアライズが可能になります。

データセットを用意

batch_size = 128
num_classes = 10
epochs = 20

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255

print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

コードサンプルでは、KerasにビルトインされているMNISTデータセットを利用しています。

60000 train samples
10000 test samples

モデルの構築

model = keras.Sequential(
    [
        keras.Input(shape=(784,)),
        layers.Dense(256),
        Antirectifier(),
        layers.Dense(256),
        Antirectifier(),
        layers.Dropout(0.5),
        layers.Dense(10),
    ]
)

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.RMSprop(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.15)

model.evaluate(x_test, y_test)

モデルの構築を行っています。
先ほど実装したAntirectifier()レイヤーが追加されています。