LLMで同じレイヤーを繰り返したりする

 5月中に記事を書けとNote様に脅されたので、最近やったことを書く。

 LLM界隈では様々なモデルのレイヤーを重ね合わせることで性能を上げる試みをしているようです。なんと自分自身のレイヤーを重ねることで性能を上げる例もあるようです。よくわかっていませんが、ようするに同じレイヤーを何度も計算するだけだと思うので、forwardの内容を変えればVRAM使用量を変えずに同じことできんじゃねと思ったのが今回の発想です。

 たとえばLlama-70Bを120Bまで拡張したというモデルがあるらしいです。

command-r-v01の4bit版を使ってみます。

forward書き換え

 各レイヤーとその重みづけのタプルを処理する順番に並べたリストを作ります。以下の例はは0から39まで重み1という実質元のモデルと同じ計算です。

model.model.idx_weight_list = [(i, 1) for i in range(40)]

以下のような感じでforwardを上書きしちゃいます。

from transformers.models.cohere.modeling_cohere import *

def hook(self):
    def forward(
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for i, (idx, weight) in enumerate(self.idx_weight_list):
            decoder_layer = self.layers[idx]
            decoder_layer.self_attn.layer_idx = i
                    
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states += weight * (layer_outputs[0] - hidden_states)

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    return forward

model.model.forward = hook(model.model)

 ほぼコピペですが、いくつか変更があります。
 まず当たり前ですが、for文を変えて指定した順番でレイヤーを計算します。

for i, (idx, weight) in enumerate(self.idx_weight_list):
    decoder_layer = self.layers[idx]

 次にAttention層のlayer_idxを上書きします。これはkv_cacheの保存や読み込みに使われるので、何番目に計算されたかを教えてあげます。ほかにもなんか落とし穴があるかもしれない・・・。

decoder_layer.self_attn.layer_idx = i

 最後に入力と出力の差分をとってweightをかけて足し合わせます。各レイヤーのforwardを書き換えれば差分取らずに済みますが、めんどくさーい。

hidden_states += weight * (layer_outputs[0] - hidden_states)

 今回の内容はこれで終わりです。以下は何も面白くない実験結果です。

実験

 今回は将棋観戦におけるAI評価表示の是非を聞いてみます。

messages = [{"role": "system", "content": "あなたは優秀なアシスタントです。"}, {"role": "user", "content": "プロ棋士による対局の観戦で、AIによる形勢判断や差し手予測を表示することは良いことだと思う?"}]
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")

普通の設定

model.model.idx_weight_list = [(i, 1) for i in range(40)]
gen_tokens = model.generate(input_ids, max_new_tokens=512, do_sample=False, streamer=streamer)

プロ棋士による対局において、AIが形勢判断や差し手予測を表示することは、良いことでもあるし、悪いことでもあると考える人がいるでしょう。どちらの意見にも一理があります。以下に、それぞれについて説明します。

良い点:

・観戦者の理解を深める:AIの分析は、観戦している人たちにとって、対局の理解を深めるのに役立ちます。特に、アマチュア棋士や初心者の観戦者は、プロの棋士がどのような考えのもとに指し手を選んでいるのかを理解するのに役立ちます。
・棋戦の盛り上がり:AIの予測は、棋戦の盛り上がりや興奮を高めるのに貢献するかもしれません。AIが予測した手が、実際の棋士の指し手と一致した時など、観戦者にとって盛り上がる瞬間があるでしょう。
・研究に役立つ:プロ棋士にとっては、対局後の分析にAIを活用でき、自身の棋力向上に役立つでしょう。

悪い点:

・棋士の思考妨害:対局中のAIの活用は、棋士の集中を妨げる可能性があります。AIの分析が気になって、自身の考えに集中できなくなる恐れもあります。
・AIへの依存:AIの分析に頼るようになり、自身の判断力を鈍らせる恐れがあります。
・棋士の価値低下:AIが強すぎて、プロ棋士の腕前を正確に判断できなくなるかもしれません。AIの分析が正確すぎるがゆえに、棋士の判断や技術を軽視する風潮が生まれてしまうかもしれません。

最終的に、AIの活用は、対局者や主催者側が、観戦者に対する説明責任や、対局をより盛り上げたいという目的に合致していれば良いことだと言えるでしょう。AIの活用方法を慎重に考え、対局を妨げない範囲で導入するのが望ましいのではないでしょうか。

ふつーですね。

レイヤーを重ねてみる

0-10, 5-15, …みたいなことをしてみます。

idx_weight_list = []
for i in range(4):
    idx_weight_list += [(i, 1) for i in range(i*10, i*10+10)]
    if i < 3:
        idx_weight_list += [(i, 1) for i in range(i*10 + 5, i*10+10)]

プロ棋士による対局をより盛り上げ、観戦する人々に理解を深めてもらうという点では、AIによる形勢判断や差し手予測は有益なツールであると言えるでしょう。

観戦者が対局の展開をリアルタイムで理解するのを助けるだけでなく、AIの解説は、プロ棋士の思考を覗き見るような魅力もあり、対局の質や奥深さをより一層際立たせてくれるでしょう。

また、若手棋士やアマチュア棋士にとっては、自分の理解をテストしたり、学習の参考にすることもできます。

一方、気をつけなければならないのは、AIが絶対であるという印象を与えてはいけないことです。AIはあくまでもツールの一つであり、人間のような感性や直感は持ち合わせていません。AIに頼りすぎるあまり、人間ならではの創意や工夫が失われてしまうことも懸念されます。

理想的には、プロ棋士の解説とAIの解説を併用するなどして、双方の強みを生かした対局観戦が提供されると、観戦者にとってより豊かな経験になるのではないでしょうか。

 なんかびみょーになりました。

重み付けしてみる

 オーバーラップ分は重みを減らしてみます。

idx_weight_list = []
for i in range(4):
    idx_weight_list += [(i, 1) for i in range(i*10, i*10+10)]
    if i < 3:
        idx_weight_list += [(i, 0.5) for i in range(i*10 + 5, i*10+10)]

プロ棋士による対局において、AIによる形勢判断や差し手予測を表示することは、議論の的となる話題です。ある側面からは有益な面がある一方で、否定的にとらえる意見もあります。ここでは、複数の視点から考えてみましょう。

良い点:

1. 観戦体験の向上:一般の観戦者にとって、AIは対局の理解を手助けしてくれます。AIの形勢判断や最善手予想は、観戦者の理解を深め、より楽しめるようにしてくれます。

2. 棋士の成長に役立つ可能性:プロ棋士自身もAIを研究に利用し、自身の棋力向上に役立たせることができます。対局後の分析にAIを利用することで、より正確な反省が可能になり、今後の対局に生かせるでしょう。

3. 棋界の活性化:AIの導入は、新たな話題を提供し、将棋や囲碁など棋士の対局に注目を集める効果があるでしょう。新しい技術への関心は、棋界を盛り上げ、若手棋士の育成にもつながる可能性があります。

一方、悪い点も:

1. 棋士への依存と創意工夫の阻害:AIに頼りすぎると、棋士自身の思考が鈍るという懸念があります。AIの判断に頼るあまり、独自の思考や創意工夫が損なわれる恐れがあります。

2. 不公平さの可能性:AIの利用度合いは棋士によって異なるでしょう。全ての棋士が同じレベルのAIアクセスを持つわけではなく、格差が生じると公平な競争とはなりません。

3. AIの正確性への過度な依存:AIの判断が絶対であるかのような印象を与えてしまう恐れがあります。しかし、AIにも限界と不完全さはあり、その予測が常に正しいわけではありません。AIに依存しすぎると、実際の対局で判断力を失う可能性もあります。

結論として、AIの利用は慎重にバランスをとる必要があるでしょう。観戦者や棋士の成長のために役立つ面がある一方、棋士の創意工夫や対局中の判断力を妨げないよう注意が必要です。また、全ての関係者が公平にアクセスできるような環境整備も重要です。AIを補助的に利用することで、将棋や囲碁など棋士の対局をより一層盛り上げ、より深い戦略や思考が披露されるようになれば理想的ではないでしょうか。

 ながいですね。利用できるAIのレベルで棋士間に格差が生じるみたいな視点はなかなか鋭いですが、観戦について質問したので話題が反れているような気がします。

レイヤーをスキップしてみる

 なんとなく5の倍数だけスキップして計算してみます。

model.model.idx_weight_list = [(i, 1) for i in range(40) if i not in [j * 5 for j in range(1, 8)]]

AI技術は、プロ棋士の対局で使われるのは良いことだと思われますが、開発者は、人の知能を尊重し、AIの能力を促進するために、人間とほぼ同じ作業をさせるべきではないと感じているようです。

AIは、プロの試合の形勢判断や予測の提示など、対局に関わる一部の決定を実行することができます。しかし、AIは、人間と同じように、プロの試合を楽しむ能力を持つことができます。AIの分析能力は、プロの試合の理解を深化させ、より多くの選択肢やより良い決定を提案できるように設計されました。

AIの分析能力は、プロの試合の結果を予測する新しい方法を開拓するために使用できますが、開発者は通常、結果や決定を予測するために使用されるものではありません。単に結果を予測するだけでなく、AIは、人間と機械の対局で、より多くの選択肢やより良い決定を提案できるように、プロの試合の理解を深化させるように設計されています。

AIの分析能力は、プロの試合の結果を予測するために使用されるべきではなく、人間と機械の対局で、より良い決定を提案できるように、プロの試合の理解を深化させるように設計されていると考えています。 さらに、AIは、プロの試合の結果を予測する新しい方法を開拓するために使用され、結果や決定を予測するために使用されるべきではなく、開発者は通常、人間と機械の対局です。
要約すると、AIは、プロの試合の結果を予測するよりも、プロの試合の結果を深化させるように設計された有用なツールです。AIの分析能力は、プロの試合の結果を予測するために使用されるべきではなく、より良い決定を提案できるように、人間と機械の対局で、より良い決定を提案できるように設計されていると信じられているのです。

 日本語なのに何言ってるか全くわかりませんね。

Perplexityの計算をしてみる

以下の記事を丸パクリして計算してみました。

 各レイヤーをスキップしたときのPerplexityの変化を追ってみます。wikitextのtrainデータのほうで30000トークンだけ評価しました。

スキップなし:7.58889102935791
0:101913.187500
1:51271.070312
2:9.176620
3:8.538880
4:8.560865
5:8.325228
6:8.286266
7:8.338334
8:8.154686
9:8.185360
10:8.030475
11:8.142550
12:8.100570
13:8.040795
14:8.040329
15:8.059035
16:8.016946
17:7.972836
18:8.097392
19:8.095753
20:8.221308
21:8.105900
22:8.208970
23:8.292737
24:8.237318
25:8.289410
26:8.213966
27:8.243610
28:8.352770
29:8.238995
30:8.263823
31:8.250707
32:8.295854
33:8.350245
34:8.347494
35:8.523391
36:8.589978
37:10.522139
38:11.041298
39:17.561686

 入出力に近い層が特に劣化する感じですね。10~20層くらいが結構役に立っていないのかもしれませんが、どちらにせよスキップなしに比べたら劣化しているようです。ただperplexityじゃinstruct modelの会話能力をあまり測れなそうなので(だからみんな頑張って評価指標つくってる)参考にならないっすね。量子化誤差をとかを把握するのには役立ちますけど。

以下一人雑談

  command rのembedding層と出力層は同じ行列を使っているようです。そうなると、出力層では最後の隠れベクトルと各embeddingの内積を計算して、一番高かったトークンが予測結果になります。ということは中間のレイヤーの役目は現トークンのembeddingを次のトークンのembeddingに変換することになるんでしょうね。実際にレイヤーのリストを空にすると、embeddingが変換されないので前のトークンをそのまま繰り返すぼっとになるようです(normだとか位置埋め込みうんぬんとかはあるだろうけど)。

<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>あなたは優秀なアシスタントです。<|START_OF_TURN_TOKEN|><|USER_TOKEN|>プロ棋士による対局の観戦で、AIによる形勢判断や差し手予測を表示することは良いことだと思う?<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|CHATBOT_TOKEN|><|CHATBOT_TOKEN|><|CHATBOT_TOKEN|><|CHATBOT_TOKEN|><|CHATBOT_TOKEN|><|CHATBOT_TOKEN|><|CHATBOT_TOKEN|>…

  逆にレイヤーを増やしたら、次の次のトークンになっちゃいそうな気がしますが、二回繰り返してもそこそこうまくいってそうです。

 トークンごとに計算するレイヤーを変えてあげるなんてこともできそうです。たとえば予測確率分布のえんとろぴーが高かったらさらにレイヤーを通してみるとか・・・?めんどくさいからやりませんでしたけど。