Kerasでカスタムレイヤーを実装する
この記事では、Kerasでカスタムレイヤーを実装してみます。
下記のコードサンプルを参考にしています。
環境はGoogle Corabです。
インストール
pip install --upgrade keras
Google CorabにはKerasがプリインストールされていますが、バージョンが3以上である必要があるためアップグレードします。
ライブラリのインポート
import keras
from keras import layers
from keras import ops
Antirectifierレイヤーの実装
カスタムレイヤーの実装を行います。このサンプルではカスタムレイヤーとして、Antirectifierレイヤーを実装しています。
ReLUとは異なり、Antirectifierレイヤーは入力の正の部分と負の部分を分離し、負の部分を絶対値に変換します。その後、両方の絶対値を連結して返します。これにより、情報をより多く保持することで、モデルの学習能力の向上を目的としたレイヤーです。
class Antirectifier(layers.Layer):
def __init__(self, initializer="he_normal", **kwargs):
super().__init__(**kwargs)
self.initializer = keras.initializers.get(initializer)
def build(self, input_shape):
output_dim = input_shape[-1]
self.kernel = self.add_weight(
shape=(output_dim * 2, output_dim),
initializer=self.initializer,
name="kernel",
trainable=True,
)
def call(self, inputs):
inputs -= ops.mean(inputs, axis=-1, keepdims=True)
pos = ops.relu(inputs)
neg = ops.relu(-inputs)
concatenated = ops.concatenate([pos, neg], axis=-1)
mixed = ops.matmul(concatenated, self.kernel)
return mixed
def get_config(self):
# Implement get_config to enable serialization. This is optional.
base_config = super().get_config()
config = {"initializer": keras.initializers.serialize(self.initializer)}
return dict(list(base_config.items()) + list(config.items()))
カスタムレイヤーの実装には下記が必要です。
状態変数を__init__またはbuild()メソッド内でadd_weight()を使って作成します。
call()メソッドを実装し、レイヤーの入力テンソルを受け取り、出力テンソルを返します。
必須でないですが、get_config()メソッドを実装することで、カスタムレイヤーのシリアライズが可能になります。
データセットを用意
batch_size = 128
num_classes = 10
epochs = 20
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")
コードサンプルでは、KerasにビルトインされているMNISTデータセットを利用しています。
60000 train samples
10000 test samples
モデルの構築
model = keras.Sequential(
[
keras.Input(shape=(784,)),
layers.Dense(256),
Antirectifier(),
layers.Dense(256),
Antirectifier(),
layers.Dropout(0.5),
layers.Dense(10),
]
)
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.RMSprop(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.15)
model.evaluate(x_test, y_test)
モデルの構築を行っています。
先ほど実装したAntirectifier()レイヤーが追加されています。
![](https://assets.st-note.com/img/1716476844469-BqP9xvJ5Qm.png?width=800)