畳み込み層と全結合層の関係とLoRAの畳み込みへの拡張について

 深層学習の学習資料ではだいたい全結合型のニューラルネットワークの話から、畳み込みニューラルネットワークへと移っていくのが主流ですが、よく考えてみると畳み込みニューラルネットワークは本当にニューラルネットワークなの?という疑問を抱くようなものがほとんどだと思います。そのような話を分かりやすく解説している記事がありました。

 私の記事読むより上の記事を読んだ方が良いと思いますが、読まない人向けにまとめると、畳み込みニューラルネットワークはニューロン同士のつながりに制約をつけたものになります。具体的にはフィルターが届く範囲のニューロンしかつながらないネットワークです。(追記:重み共有もあります)
 今回は畳み込み層が全結合層に制約をつけたものなら、畳み込み層は全結合層で再現できる!ということを確認していきます。

畳み込み層の重みを全結合層の重みに変換する

 畳み込み層(2D)の入出力はチャンネル、縦、横の3次元ですが、全結合層はそれらを全部フラットにして1次元ベクトルとして扱います。
 畳み込み層の重みを全結合層の重みに変換するのは割と簡単に考えられて、全結合層(biasなし)に単位行列を入力すると、全結合層の重みがそのまま出力されるので、畳み込み層に単位行列を入力するだけです。biasはまあbiasなので簡単です。

 数式的には、$${Conv(I) = WI = W}$$です。ただし単位行列は列ごとに画像の形にする必要があります。

 もうちょっというと行列とは、標準基底の像なので、何らかの線形写像に標準基底を入力すればその写像に対応する行列(表現行列というらしい)が得られるという自然な発想です。

実装は多分以下のような感じです。

from torch.nn.parameter import Parameter
import torch
import numpy as np

def conv2linear(conv, in_channels, in_height, in_width):
    eye_matrix = torch.eye(in_channels * in_height * in_width).reshape(-1, in_channels, in_height, in_width)
    conv_output = torch.nn.functional.conv2d(eye_matrix, conv.weight, bias=None, stride=conv.stride, padding=conv.padding)
    _, out_channels, out_height, out_width = conv_output.shape
    fc_weights = conv_output.reshape(-1, out_channels * out_height * out_width).t()

    linear = torch.nn.Linear(in_channels * in_height * in_width, out_channels * out_height * out_width)
    linear.weight = Parameter(fc_weights)

    bias_repeated = torch.tensor(np.repeat(conv.bias.detach().numpy(), out_height * out_width))
    linear.bias = Parameter(bias_repeated)
    return linear

実際試してみると、うまくいってました。

import torch
in_channels = 3
in_height = in_width = 6

x = torch.randn(1,in_channels,in_height,in_width)

conv = torch.nn.Conv2d(3,1,3,1,1)
conv_out = conv(x)
print(conv_out)

linear = conv2linear(conv, in_channels, in_height, in_width)
print(linear(x.reshape(1,-1)).reshape(conv_out.shape))

VGG16でやってみ・・・無理やん!!!

 最初は事前学習済みVGG16の畳み込み層を全結合層にしてみようとしてみました。しかしGPT4に定義させてロードしようとしたらメモリエラーになってしまいます。そこでパラメータ数いくつあんねん!とGPT4に質問したら、計算式を返してくれました。

def compute_params(input_sizes, output_sizes):
    params = [in_size * out_size for in_size, out_size in zip(input_sizes, output_sizes)]
    return sum(params)

# 線形層の入力サイズ
input_sizes = [3 * 224 * 224, 64 * 224 * 224, 128 * 112 * 112, 128 * 112 * 112, 256 * 56 * 56, 256 * 56 * 56, 512 * 28 * 28, 512 * 28 * 28, 512 * 14 * 14, 512 * 14 * 14]

# 線形層の出力サイズ
output_sizes = [64 * 224 * 224, 64 * 224 * 224, 128 * 112 * 112, 128 * 112 * 112, 256 * 56 * 56, 256 * 56 * 56, 512 * 28 * 28, 512 * 28 * 28, 512 * 14 * 14, 512 * 14 * 14]

# 各線形層のパラメータ数を計算
params_features = compute_params(input_sizes, output_sizes)

# 分類器のパラメータ数を計算
params_classifier = (512 * 7 * 7) * 4096 + 4096 * 4096 + 4096 * 1000

# 合計パラメータ数
total_params = params_features + params_classifier

 biasがないけどほとんどあってそうですね。計算してみると、17兆5832億5837万48でした・・・
ChatGPT(GPT-3.x)が1750億パラメータなのでちょうどその100倍かあ。VGG16自体は1億3,800万パラメータらしいです。10万倍以上に膨れ上がるわけですね。畳み込み層の制限というのはそれほど大きいものなんですねえ。

LoCon(LoRA-C3liar)  

 LoRAは線形層を2つに分解してパラメータ数を減らす方法です。学習時はモデルを凍結してLoRAだけ学習し、生成時にマージします(しなくてもいいけど)。

import torch.nn as nn
in_features = 10
out_features = 10
rank = 2

linear = nn.Linear(in_features, out_features, bias = False)

lora_down = nn.Linear(in_features, rank, bias = False)
lora_up = nn.Linear(rank, out_features, bias = False)

print(linear.weight.numel()) # 100
print(lora_down.weight.numel(),lora_up.weight.numel()) # 20,20

#推論
x = torch.randn(1,10)
y = linear(x) + lora_up(lora_down(x))
print(y)

#元の重みへのマージ
merge_weight = lora_up.weight @ lora_down.weight
linear.weight = Parameter(linear.weight + merge_weight)
print(linear(x))

 LoRAを畳み込み層にも拡張したものがLoConになります。LoRAの強みとして、元の重みへマージできるというものがあるので、畳み込み層に拡張する場合もその性質を維持していないとLoRAの拡張とは言えないでしょう。
 最初に述べた通り畳み込み層はフィルターが届く範囲にしかニューロン同士がつながらない(追記:重みが共有されるという条件もある)という制約がついたネットワークです。そのためLoRAのように二層に分けるとき、その制約を守らないと元の重みにマージできなくなります(逆に守っていればマージできる)。Stable Diffsionの畳み込み層のフィルターは1×1か3×3です。ここで1×1の畳み込み層は実質各ピクセルに全結合層を適用しているだけなので、元のLoRAと同じように実装できます(Kohya氏によるLoRAは最初からそうしていました)。そのため3×3フィルターについて考えていきましょう。
 まず、down層とup層を両方3×3フィルターの畳み込み層にすることはできないというのがすぐわかると思います。3×3フィルターを二層適用すると、ニューロンがつながる範囲が5×5までに拡大されてしまうからです。

cloneofsimo氏による畳み込み層への拡張

 畳み込み層への拡張を最初に試したのはLoRAをStable Diffusionに適用したcloneofsimo氏です。

 cloneofsimo氏はdown層のフィルターを3×3、up層のフィルターを1×1にしました。この二層を適用しても、ニューロンがつながる範囲は3×3で変わらないため、制約を守れていますね。以下が実装例です。マージ部分は後程紹介するLyCORISの実装を参考にしています。

import torch.nn as nn
in_features = 10
out_features = 10
rank = 2

padding = 1
stride = 1

conv = nn.Conv2d(in_features, out_features, bias = False, kernel_size=3, padding = padding, stride = stride)

lora_down = nn.Conv2d(in_features, rank, bias = False, kernel_size=3, padding = padding, stride = stride)
lora_up = nn.Conv2d(rank, out_features, bias = False, kernel_size=1, padding = 0)

print(conv.weight.numel()) # 900
print(lora_down.weight.numel(),lora_up.weight.numel()) # 180 20

#推論
x = torch.randn(1,10,12,12)
y = conv(x) + lora_up(lora_down(x))
print(y[0][0])

#元の重みへのマージ
# (in, rank) @ (rank, out * kernel_size * kernel_size) -> (in, out * kernel_size * kernel_size)
merge_weight = lora_up.weight.squeeze((-1,-2)) @ lora_down.weight.reshape(rank, -1)
merge_weight = merge_weight.reshape(conv.weight.shape)
conv.weight = Parameter(conv.weight + merge_weight)
print(conv(x)[0][0])

 マージの部分で何してるか分かりづらいですね。簡単にいっちゃうとフィルターに対してup層を適用しています。フィルターを通してからup層を適用するのと、フィルターにup層を適用したものを通すのが同じであるという発想です。

LyCORISのCP分解による実装

 LyCORISではCP分解なるものを使った畳み込みへの拡張が存在します。3層に分けます。

 1×1Convでチャンネルをrankに圧縮した後、チャンネル数を保って3×3Convを適用し、1×1Convで出力チャンネルに拡張する方法です。これも制約を保っていますね。以下が実装例です。

in_features = 10
out_features = 10
rank = 2

padding = 1
stride = 1

conv = nn.Conv2d(in_features, out_features, bias = False, kernel_size=3, padding = padding, stride = stride)

lora_down = nn.Conv2d(in_features, rank, bias = False, kernel_size=1)
lora_mid = nn.Conv2d(rank, rank, bias = False, kernel_size=3, padding = padding, stride = stride)
lora_up = nn.Conv2d(rank, out_features, bias = False, kernel_size=1)

print(conv.weight.numel()) # 900
print(lora_down.weight.numel(),lora_mid.weight.numel(),lora_up.weight.numel()) # 20 36 20

#推論
x = torch.randn(1,10,12,12)
y = conv(x) + lora_up(lora_mid(lora_down(x)))
print(y[0][0])

#元の重みへのマージ
# (in, rank) @ (rank, rank * kernel_size * kernel_size) -> (in, rank * kernel_size * kernel_size)
merge_weight = lora_down.weight.squeeze((-1,-2)).t() @ lora_mid.weight.permute(1,0,2,3).reshape(rank, -1)
merge_weight = merge_weight.reshape(in_features,rank,3,3).permute(1,0,2,3).reshape(rank,in_features,3,3)

# (out, rank) @ (rank, in * kernel_size * kernel_size) -> (out, in * kernel_size * kernel_size)
merge_weight = lora_up.weight.squeeze((-1,-2)) @ merge_weight.reshape(rank, -1)
merge_weight = merge_weight.reshape(conv.weight.shape)
conv.weight = Parameter(conv.weight + merge_weight)
print(conv(x)[0][0])

 マージはさっきみたいなのを二回やるだけですね。3×3Convのチャンネル数が減っている分パラメータ数が低いです。

二層の2×2Convによる実装?

 2×2フィルター二層でも、制約を満たしているのでは、と考えたので実装してみます。

in_features = 10
out_features = 10
rank = 2
padding = 2
stride = 1
conv = nn.Conv2d(in_features, out_features, bias = False, kernel_size=3, padding = padding,stride=stride)

lora_down = nn.Conv2d(in_features, rank, bias = False, kernel_size=2, padding=padding)
lora_up = nn.Conv2d(rank, out_features, bias = False, kernel_size=2, stride=stride)

print(conv.weight.numel()) # 900
print(lora_down.weight.numel(),lora_up.weight.numel()) # 80 80

#推論
x = torch.randn(1,10,12,12)
y = conv(x) + lora_up(lora_down(x))
print(y[0][0])

#元の重みへのマージ
merge_weight = torch.nn.functional.conv_transpose2d(lora_up.weight,lora_down.weight,padding=0)
conv.weight = Parameter(conv.weight + merge_weight)
print(conv(x)[0][0])

 なんかできそうですね。マージはup層のフィルターをdown層のフィルターで転置畳み込みをかけて3×3に拡張することによって実装できます。paddingをdown層で行うのは通常の方法と同じですが、strideはup層で行います。
 偶数サイズのフィルターは真ん中がないから避けられるといった話を聞きますが、マージできることからわかる通り二層あれば真ん中ありますし、面白そうではありますね。

各実装の性能

 3×3の範囲のニューロンしかつながらないという制約さえ守れば様々な実装があることが分かりました。実装しませんでしたが、down層が1×1でup層で3×3にするという方法も素朴に考えられますよね。ではどれが一番性能がいいんでしょうか?rankが十分低ければパラメータ数が最小なのはCP分解を使ったやり方なんでしょうね。性能はなにか実験結果とかあるんですかね?

まとめ

  • 畳み込みニューラルネットワークは全結合ニューラルネットワークで表すことができる。

  • LoRAの畳み込み層への拡張はいろいろ考えられそうだが、どれがいいのかよくわかんない

  • 畳み込みとかいう図がないとイメージしづらい話なのに図が一切ないのはよくない