Gemma 2Bを3種類のタスクで埋め込みモデルとして訓練しマージして比較する

埋め込みモデルのマージを実験してみました。
まず、Gemma-2Bをベースモデルとして、検索、NLI、分類の3タスクを想定した学習データで学習し、それぞれのタスクに特化した3つの埋め込みモデルを作成しました。
これらをマージすることで、タスクに特化したモデルよりも評価が高くなることが確認でき、埋め込みモデルでもマージが有効であることが分かりました。
ただし、特に検索性能は既存のモデルと比べて競える精度とはなっていません。もう少し改善出来たら公開したいです。

背景、モチベーション

取り組みの背景として、関連する最近の話題をいくつか紹介します。

1. LLMを使った埋め込みモデルとモデルマージ

これまでの埋め込みモデルは、BERTのような1B未満のパラメータの小規模なモデルが使われることが多かったのですが、最近ではLLMを使った埋め込みモデルも出てきています。
日本語に対応した小規模な埋め込みモデルとしては、intfloat/multilingual-e5-largeBAAI/bge-m3などが有力なモデルですが、これらのパラメタ数は
0.5B程度でLLMに比べると非常に小さいものです。

LLMを使った埋め込みモデルは、高い性能を持つものとしてはmistral-7bをベースとしたintfloat/e5-mistral-7b-instructが2023/12あたりに公開されており、他にも論文・モデルが公開されてきています。

  • GritLM/GritLM-8x7B テキスト生成と埋め込みを同時に行えるように訓練したモデル。Mistral 7BとMixtral 8x7Bをベースにしたモデルを公開

  • Salesforce/SFR-Embedding-Mistral e5-mistrall-instructをベースに多数のデータセットで追加学習したモデル。ブログ投稿ではマルチタスク、ハードネガティブの選び方、バッチサイズなどの要素の影響が解説されている

  • BMRetriever 医療文書ドメイン向けに4サイズのLLMを埋め込みモデルとして訓練したモデル群。410M(Pythia)、1B(Pythia)、2B(Gemma)、7B(BioMistral)の4種類のモデルを公開。

  • nvidia/NV-Embed-v1 Mistra 7Bをベースに、last token poolingの代わりにlatent attentionを取り入れMTEBで暫定1位のモデル

LLMと言えば、ここ数か月日本語LLMの界隈でもモデルマージが流行していますが、埋め込みモデルについてはあまりマージの話を聞きません。

個人的に以前、e5-mistral-instructとMistral 7Bの日本語追加学習モデルであるstabilityai/japanese-stablelm-base-gamma-7bをマージしてみたところ日本語評価タスクでの性能が向上することを確認しており、埋め込みモデルのマージも効果がありそうだと思っていました。

2. Sentence-Transformersのアップデート

sentence-transformersは埋め込みモデルの学習、推論を行うデファクトのライブラリです。
2022年のv2.2以降アップデートが止まっていましたが、2024年の1月にv2.3が出て以来かなりのスピードで開発が進んでいます(嬉しい)。
v2.3ではCachedMultipleNegativesRankingLossが追加されましたが、これは限られたVRAMでも大きなバッチサイズで訓練することを可能にするGradCache が実装されたものです。
埋め込みモデルの学習ではバッチ内のテキストを負例として共有するバッチ単位の対象学習をすることが多く、大きなバッチサイズを使用することで、より多くの負例が得られ、より良く学習できると考えられています。
このアプデにより、小規模な学習資源でも大きなバッチサイズでLLMの埋め込み学習を行うことが非常に簡単になりました

3. JMTEBのリリース

6タスク・16データセットで構成される日本語テキスト埋め込みベンチマークJMTEBが公開されました。
MTEBの日本語サブセットや、公開されている日本語データが整理され、データセットとして公開されています。

今回の実験でもsbintuitions/JMTEBで公開いただいているデータセットをいくつか使わせていただきました。

4. 中規模サイズの日本語対応LLMのリリース

Mistral 7Bをベースにした埋め込みモデルは精度は良くても、大きすぎてユースケースが限られます(例えば、リアルタイムのRAGで使いたい場合は生成用のLLMを同時に使うため、VRAMの取り合いになる)。
一方で、BERTをベースとした埋め込みモデルは精度に限界がある、かもしれません。特に、LLMを使った埋め込みモデルではinstructionを使って埋め込む方法(※)が良く使われますが、そのような指示の理解能力はなんとなくLLMと大きく差が出そうな気がします。
※ テキストに Classify Amazon reviews into positive or negative sentiment などの指示的なprefixを付けて訓練することで、多様なタスクに同時に対応できるようにする方法

そこで、精度とモデルサイズのバランスを取る選択肢としてt、1B~3B程度の大きさで精度の高い埋め込みモデルがあると、ユースケース次第では嬉しいかもと考えています。

今回は、1B~3B程度で日本語対応しているモデルとしてGemma 2Bを選択しました。
つい数日前にQwen2 1.5Bが公開されており、こちらも気になっています。

実験

今回は、Gemma 2Bをベースモデルとして、検索、NLI、分類の3タスクを想定してそれぞれのタスクに特化したモデルを学習し、後でマージすることで良いとこ取りができるか、を試してみました。

データセット

anchorに対し、positive1件とhard negativeを1件紐づけたtripletデータセットとして作成しました。

  • NLIhppRC/simple-simcse-jaのデータセットを使用) 33k件

    • 学習:shunk031/jsnli、cl-nagoya/nu-mnli、cl-nagoya/nu-snli、hpprc/janli

      • hard negative:entailmentのテキストをpositive、neutralまたはcontradictionの1件をnegativeとして使用

    • 評価:JSTS

  • 検索 65k件

    • 学習:cl-nagoya/auto-wiki-qaのサブセットと、個人的に作成した長めのクエリを持つWikipedia質問データセット

    • 評価:Miracl

  • 分類(JMTEBのデータセットを使用) 22k件

    • 学習:amazon_review_classification、massive_intent_classification、massive_scenario_classification

      • hard negative:異なるカテゴリを持つデータからランダムに1件使用

    • 評価:livedoor_news

学習

sentence-transformers v2.7を使用し、1xRTX3090の環境で、CachedMultipleNegativesRankingLossを使うことで大きなバッチサイズ(1024)での対照学習を試行しました。
また、instructionは使わずにme5-largeと同じスタイルで"query:"と"passage: "のprefixを付けて学習しました。

パラメタは以下の通りです。(記載のないものはライブラリのデフォルト)

  • CachedMultipleNegativesRankingLossを使用

  • LoRA(r=16, alpha=32)

  • batch_size=1024

  • max_length=128(NLI), 256(分類), 512(検索)

検索タスクはある程度長いpassageを学習したいのでmax_lengthを長くする必要がありますが、max_length=512ではかなり学習に時間がかかり、検索タスク65k件の学習には180時間程度を要しました。
(おおむねmax_lengthに線形に遅くなるようで、max_length=128の4倍時間がかかる)

学習結果

単独タスクの学習結果

単独タスクで学習した3つのモデルと、e5-mistral(ベースライン)との比較を行います。

NLIタスクのみで学習したモデル(青)は、Miracl(検索)評価が0.5未満と低い値ですが、JSTSの評価は最も高くなっています


検索タスクのみで学習したモデル(オレンジ)はすべての評価タスクで高い数値
となりました。ただ、Miraclでの評価はe5-mistralよりかなり低く、既存モデルと比べて競えるレベルには達しておらずこの点は不満です。
知る限り、ベースモデルから日本語のデータで対照学習をしてe5シリーズに近い検索性能を達成している公開済みの埋め込みモデルはなく、検索性能を高めるのはなかなか難しいのかもしれません。
とはいえ、この図のように学習終盤でもまだ評価値は上昇していたので、データを増やせば解決するかもしれません。

分類タスクで学習したモデル(緑)はすべてのタスクで低い数値となりました。学習データの大半をamazon_review_classificationが占めるのですが、これはレビューの点数を分類するタスクで、他のタスクとはかなり趣旨が異なります。
例えば「この傘は使いやすかった、デザインも好き」(5点)というテキストに対して、「この傘、見た目は良かったけどすぐ壊れて使えなくなった」(1点)は遠ざけて、「うちの観葉植物に使ったらすくすく育ちました!」(5点)を近づける必要があります。
このようなタスクと、検索タスクやトピック分類のようなタスクを同時にこなすには、instructionをつけてタスクの区別をモデルに教えてあげないと難しそうな気がします(今回は検索タスク以外は"query: "を付けて学習したので、モデルにタスクを区別する手掛かりがなかった)

マージした結果

3つのモデルをマージして評価します。
mergekitを使用し、マージ手法はTIESを使いました。ここはあまり探索しておらず、より良い手法やパラメタの探索の余地があります。

単独タスクの結果の良かったNLIと検索モデルのマージ(赤)は、すべてのタスクでマージ元のモデルを上回りました
NLIモデルはMiraclの評価結果がかなり低いので、検索モデルにマージすることで邪魔してしまうと思ったのですが、逆に上昇したのは意外でした。

3モデルをマージしたモデル(紫)もかなり良好な結果です。
分類タスクで学習した(緑)は数値が悪いですが、マージしてもあまり邪魔しないようです。livedoor_newsの評価結果は2モデルのマージよりわずかに向上しています。

余談ですが、埋め込みモデルとして訓練したGemmaは、GemmaForCausalLMではなくGemmaModelクラスとして保存されています。
このままではmergekitで扱えないため、マージ前にGemmaForCausalLMとして読み込んで保存しなおす必要がありました。

複数タスクを同時に学習させた結果

複数タスクを同時に学習させる方法と、単一タスクを学習したモデルのマージのどちらが良くなるか比較したかったのですが、力尽きました。
max_lengthを短くして高速に学習したいNLIと、max_lengthを長くしたい検索を同時に学習させると、NLIの学習に数倍の時間がかかり効率が悪くなります。
効率を上げるためにタスクを順番に学習させたり、なにか工夫して一つのバッチに近い長さのものを集めると良いかもしれませんが…

少し趣旨は違いますが、SFT-Embeddingの記事では同じタスクのデータを使ってバッチを作ることでin-batch negativeの難しさを上げて学習効果を高める工夫(Task-Homogenous Batching)をしているようです。

今後やってみたいこと

  • 今回失敗したamazon_review_classificationでの学習など、多様なタスクに対応するためにinstructionを使って学習する

  • Qwen2 1.5Bをベースモデルとした学習

  • sentence-transformersのv3で導入されたマルチGPU学習

  • 今回届かなかった、e5シリーズを超える検索性能を達成するための工夫(特にアイデアなし)

  • JMTEBの全タスクでの評価と考察


この記事が気に入ったらサポートをしてみませんか?