PyTorch Moduleに紐づく定数のtensorを定義する

環境

python 3.6.6
IPython 7.1.1
torch==1.7.0

目的

PyTorchで`torch.nn.Module`に紐づく定数を定義したい場合があるとします.`torch.nn.parameter.Parameter`で定義せず`torch.tensor`に`require_grad=False`にすれば,`optimizer`のstepメソッドでは学習はされませんが,toメソッドでGPUに乗せたり,data typeを変えることができません.

解決方法

`torch.nn.Module.register_buffer`メソッドを用いれば良い.

例えば,指数関数e^xがあり,底の部分 (e) は定数として扱い,指数の部分 (x) のみを変数として扱いたい場合は,以下のように書くことができます.

import torch
import numpy as np

class MyModule(torch.nn.Module):
   def __init__(self):
       super().__init__()
       self.exponent = torch.nn.parameter.Parameter(torch.Tensor([1]))
       self.e = torch.tensor([np.e])
       self.register_buffer('e_const', self.e)
       
   def forward(self, x):
       x = x * (self.e_const).pow(self.exponent)
       return x

​実際にこのモデルをインスタンス化させて,`parameters`メソッドを読んでも,`e_const`が呼ばれないが,`to`メソッドで`e_const`のdtypeが変更されることを見てみます(localでGPU環境がないのでdeviceの移動は確認できませんでした).

parameters()

m = MyModule()
print(list(m.parameters()))

結果は,

[Parameter containing:
tensor([1.], requires_grad=True)]

確かに`self.exponent`のみです.

to()

次に`to`メソッドで`e_const`が変更されるか見てみます.

print(f"e_const.dtype is {m.e_const.dtype}")
print(f"e.dtype is {m.e.dtype}")
print(f"exponent.dtype is {m.exponent.dtype}")

出力は,

e_const.dtype is torch.float32
e.dtype is torch.float32
exponent.dtype is torch.float32

最初は,`torch.float32`になっています.

ここで,`to`メソッドで,`torch.float64`にしてみます.

m.to(dtype=torch.float64)

もう一度,各パラメタのdtypeを出力してみると,

e_const.dtype is torch.float64
e.dtype is torch.float32
exponent.dtype is torch.float64

`parameters()`には入っていないが,`buffer_register`した`e_const`はちゃんと`torch.float64`に変更されています.

この記事が気に入ったらサポートをしてみませんか?