見出し画像

ローカル環境でRAGを用いた文書生成とCTranslate2を用いた高速化

研究開発本部 海老原樹、後藤裕也

はじめに

シャープでは現在、LLM(Large Language Model)の開発、特にエッジデバイスで動作するLLM(以下、ローカルLLM)に関する様々な取り組みを行っております。本シリーズ記事ではそれらの取り組みの中から一部をご紹介させて頂きたいと考えております。(本シリーズのオープニング記事はこちら

本シリーズ記事の第一回は、「RAG(Retrieval Augmented Generation)」についてです。LLMを外部のデータベースと連携させる事により、学習されていない事柄についても正しく答えられるようになります。

今回はローカル環境でRAGによる文書生成を行うシステムを実装し、LLMにシャープの情報について正しく答えてもらおうと思います。さらに後半ではLLMの回答の高速化の実験を行います。
LLMに興味のある皆様の参考になれば幸いです。

要約

  • オープンソースの日本語言語モデルの「ELYZA-japanese-Llama-2-7b-instruct」とLangChainを用いてRAGによる文章生成をするプログラムを実装。

  • ELYZAをCTranslate2変換したモデルをLangChainに組み込み、文章生成の速度向上を確認。変換前のモデルに比べ、処理速度が1.5~1.8倍に向上。


RAGとは?


RAGによる文書生成の概要

RAG(Retrieval Augmented Generation)は、LLMに外部のデータベースや情報源を用いてテキストを生成させる技術です。
通常、LLMを用いてテキストを生成する際にはプロンプトと呼ばれる指示文を入力します。RAGにおいては通常のプロンプトに加え、プロンプトとの関連性が高い文章をデータベースから取得し、それをプロンプトに付け加えてテキスト生成を行います。インプットの部分を工夫してやる事で、モデル自体を一から学習し直さずとも所望の答えが得られるという訳ですね。
 
例えばChatGPTで簡単にRAGを試してみます。まずはシャープの本社の場所を聞いてみましょう。

ChatGPTにシャープの本社の場所を聞いてみた

シャープの本社の本社は大阪府堺市にあるので、この回答は誤りです。

続いて、プロンプトにシャープのWikipediaの概要にある文章を追加してみましょう。

プロンプトに追加情報を入れて聞いてみた

今度は正しい回答を得られました。
今回は質問に関する情報を検索してプロンプトに与える部分は手動で行いましたが、実際のRAGではこの部分もプログラムの中で行います。

RAGによる文書生成高速化

今回は、「ローカル環境で、高速で動く」ことを目標に、RAGによる文書生成を行うLLMのシステムを実装し、高速化の実験を行いました。

LLMにはオープンソースの日本語言語モデルの「ELYZA-japanese-Llama-2-7b-instruct」を使っています。

以降では、実験のコードと実行手順をご説明します。前半ではLangChainを使用して、RAGによる文書生成を行い、後半ではCtranslate2を使った文書生成の高速化を行います。

コードは主要な部分のみをご紹介し、詳しいコードは補足に掲載していますので、必要に応じてご参照ください。

実行環境について

  • OS: Ubuntu 20.04

  • RAM: 125 GB

  • GPU: NVIDIA GeForce RTX 3090(VRAM 24 GB)

  • CUDA 11.8

  • Nvida Driver 530.30.02

  • Python 3.10.3

前半: RAG による文書生成

手順

まずはLangChainを使用して、RAGによる文書生成を試してみます。LangchainはLLMを使ったアプリケーションの開発に便利なライブラリで、RAGのための機能も備わっています。

 RAGによる文書生成の手順は、大きく以下の3つのステップになります。

  1. ベクトルデータベースの準備

  2. 質問と関連する文章検索

  3. 追加情報を活用した文章生成

それでは、手を動かしていきましょう!

ベクトルデータベースの準備

最初のステップでは、LLMに与える追加情報のソースとなる外部のデータベースを準備します。今回は外部のデータベースとして、シャープのWikipediaのページから抽出したテキストを使いました。

HTMLのダウンロード、プレーンテキストの抽出にはtrafilaturaというライブラリを使いました。

以下が抽出したテキストの一部です。

(一部抜粋)
|シャープ株式会社
|特記事項:
シャープ株式会社(英: SHARP CORPORATION)は、大阪府堺市に本社を置く日本の電気機器メ
ーカー。台湾の鴻海精密工業(フォックスコングループ)の子会社。日経平均株価の構成銘柄
の一つ[3]。
1912年、早川徳次が東京市本所区松井町(現・東京都江東区新大橋)に金属加工業を設立す
る。関東大震災により工場を消失後、1924年に大阪府東成郡田辺町(現・大阪府大阪市阿倍野
区)に早川金属工業研究所を設立する。1935年に改組し、株式会社早川金属工業研究所を設立
し、1936年に早川金属工業株式会社、1942年に早川電機工業株式会社、1970年にブランドと
して使われていたシャープ株式会社に社名変更する。2016年に大阪府堺市堺区匠町に本社移
転。
歴史[編集]
- 1912年 - 早川徳次が東京で創業した。徳尾錠というベルトのバックルの発明が始まりであ
る。
- 1915年 - 金属製繰出鉛筆(早川式繰出鉛筆)を発明。販売開始後、商品名をエバー・レディ・
シャープ・ペンシルに変えた。アメリカで爆発的にヒット。現在の社名はこれに由来する。
- 1923年 - 関東大震災によりシャープペンシル工場を焼失。早川は家族もすべて失い、大阪
へ移り再起を図った。明治時代、明治政府から藩債処分の影響で大打撃を与えられた大阪市で
あったが当時は経済的にも回復していた。
- 1925年 - 鉱石ラジオをシャープの名前で発売。戦前の主力商品となる。
- 戦後、総合家電では松下電器産業やソニーが台頭し、営業・販売力においてこの2社に圧倒 
 的な差を付けられていた上、シャープ製のテレビ(ブラウン管はアメリカ等海外製)が突然
 発火して大火事になった事件などもあり、低迷の時代が続く。
- 1962年 - 日本の家電企業で初めて家庭向け量産型の電子レンジを発売(当初は業務用)。
 1966年には世界初のターンテーブル方式の電子レンジを開発する。

上記のテキストデータをEmbedding(埋め込み)モデルを用いてベクトルに変換します。

Embeddingモデルとは、テキストや画像のようなデータをその特徴を表す数値ベクトルに変換するモデルです。質問も同様にベクトル化することで、テキストの類似性に基づいて検索することができるようになります(「ベクトル検索」)。

Embeddingモデルにはintfloat/multilingual-e5-largeを使用しました。(そのほかの候補としてはOpenAIのAPIを介して使用できるtext-embedding-ada-002がありますが、ローカル環境で実行できるものを作るという目的上、今回はintfloat/multilingual-e5-largeを選択しています。)

また、ベクトル検索のライブラリはMeta製のFaissを使用しました。FaissはGPUを活用することができ、高速で動作することが特徴です。

前置きが長くなりましたが、コードの説明です。まずはテキストデータ(sharp_wiki.txt)をある程度の長さで分割します。この分割された文章(チャンク)が検索時に扱う文章の単位になります。

from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

loader = TextLoader("sharp_wiki.txt", encoding="utf-8")
document = loader.load()

# 全文章を決まった長さの文章(チャンク)に分割して、文章データベースを作成
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=300,
    chunk_overlap=20,
)

splitted_texts = text_splitter.split_documents(document)
print(f"チャンクの総数:{len(splitted_texts)}")
print(f"チャンクされた文章の確認(参考に7番目にチャンクされたデータを確認):\n{splitted_texts[6]}")
チャンクの総数:226
チャンクされた文章の確認(参考に7番目にチャンクされたデータを確認):
page_content='1912年、早川徳次が東京市本所区松井町(現・東京都江東区新大橋)に金属
加工業を設立する。関東大震災により工場を消失後、1924年に大阪府東成郡田辺町(現・大阪
府大阪市阿倍野区)に早川金属工業研究所を設立する。1935年に改組し、株式会社早川金属工
業研究所を設立し、1936年に早川金属工業株式会社、1942年に早川'
 metadata={'source': 'sharp_wiki.txt'}

226個のチャンクが生成されました。
それぞれのチャンクをEmbeddingモデルを使ってベクトル化し、ベクトルデータベースを作成します。

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# 文章からベクトルに変換するためのモデルを用意
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
# 文章データベースからベクトルデータベースを作成。チャンク単位で文章からベクトルに変換。
db = FAISS.from_documents(splitted_texts, embeddings)

質問と関連する文章検索

ベクトルデータベースの準備ができたので、早速ベクトル検索をしてみます。質問は「シャープの本社はどこにありますか。」です。
一番類似する文章を3つ出力し、ついでに処理時間を計測しました。

import time

question = "シャープの本社はどこにありますか。"

start = time.time()
# 質問に対して、データベース中の類似度上位3件を抽出。質問の文章はこの関数でベクトル化され利用される
docs = db.similarity_search(question, k=3)
elapsed_time = time.time() - start
print(f"処理時間[s]: {elapsed_time:.2f}")
for i in range(len(docs)):
    print(docs[i])
処理時間[s]: 0.01
コンテンツ: シャープ
|
大阪府堺市の本社
|種類
|株式会社
|機関設計
|監査等委員会設置会社[1]
|市場情報
|本社所在地
|
日本
〒590-8522
大阪府堺市堺区匠町1番地
|設立
|1935年(昭和10年)51日(株式会社早川金属工業研究所)
|業種
|電気機器
|法人番号
|6120001005484
|事業内容
|エレクトロニクス、電子部品
|代表者
|

コンテンツ: 高橋興三(七代目社長)
戴正呉(八代目社長)
浅田篤(元副社長)
|外部リンク
|シャープ株式会社
|特記事項:
シャープ株式会社(英: SHARP CORPORATION)は、大阪府堺市に本社を置く日本の電気機器メーカー。台湾の鴻海精密工業(フォックスコングループ)の子会社。日経平均株価の構成銘柄の一つ[3]。

コンテンツ: 出典[編集]
- ^ 組織図 - シャープ株式会社
- ^ a b c d e f “第123期有価証券報告書” (PDF). シャープ株式会社 (2017年6月21日). 2017年10月2日閲覧。
- ^ 構成銘柄一覧:日経平均株価 Nikkei Inc. 2021年10月8日閲覧。
- ^ シャープ外資傘下へ 社員・ゆかりの地、揺れる思い [要検証]
- ^ シャープ本社 長年親しまれた大阪・阿倍野から堺へ移転 THE PAGE 2016年7月1日、2022年8月7日閲覧。

処理速度は0.01秒、速いですね。文章は類似度が高い順番に出力されています。
確かに、全ての文章に現在の本社の場所である「大阪府堺市」または「堺」が見つかります。1つ目の文書はより詳しい住所まで載っていますね。

文書生成

LLMには「ELYZA-japanese-Llama-2-7b-instruct」を使用します。オープンソースの日本語言語モデルの中でもトップクラスの性能を誇り、それでいて少々高スペックなPCであればローカル環境での動作が可能です。

モデルのロード(半精度浮動小数点(float16)でロード)、RAGのためのLangChainのインタフェースの用意については補足のコードをご参照ください。

ELYZA用のプロンプトは以下になります。下記の{context}の部分に外部データベースの関連する文章が、{question}には質問が入ります。

<s>[INST] <<SYS>>
参考情報を元に、ユーザーからの質問に簡潔に正確に答えてください。
<</SYS>>

{context}
ユーザからの質問は次のとおりです。{question} [/INST]

まずはRAGを使わずに(外部データの情報なしに)文書生成をしてみましょう。質問は「シャープの本社はどこにありますか。」で、期待する回答は「大阪府堺市(堺区匠町1番地)」です。

RAGなし
処理時間[s]: 0.85
 シャープの本社は、大阪府大阪市北区堂島1-1-23にあります。

惜しい、けど不正解ですね。
ではRAGを使って外部情報を参考に文書生成をしてみましょう。プロンプトにベクトル検索した結果の3つの文章と質問を入れ、LLMに回答してもらいます。

RAGあり
処理時間[s]: 0.82
 シャープの本社は大阪府堺市にあります。
トークン数: 28

正しい住所を教えてくれました!
一緒に出力されている「トークン数」の「トークン」とはLLMがテキストデータを処理する際の単語の単位であり、出力する文章が長い=トークン数が多いほど、LLMが全ての文章を出力するまでの時間が長くなります。今回の結果では、1トークン当たり0.030秒の処理速度です。

また、プログラムの実行中、GPUのメモリは最大で約17.3GBが使用されていました(このうち約3.8GBはベクトルデータベースに関連する機能で占めています)。GPUを活用しているということもあり、短い文章の出力であれば上記の実行速度でも十分に高速だと感じます。
しかし、出力する文章が長くなるケースだとより速い処理速度が求められることも考えられます。また、入力文章や出力文章が長いほど、文章生成時にLLMが占めるメモリは大きくなるため、省メモリ化も望まれます。

そこで、後半ではCTranslate2を使ってLLMの高速処理・省メモリ化を行う実験をします。

後半: CTranslate2を用いた高速化

手順

CTranslate2は、Transformerモデルのメモリ削減および高速な推論のためのC++とPythonのライブラリです。

以降はRAGによる文章生成を高速化する実験を、以下の3つのステップで進めます。

  1. モデルのCTranslate2変換

  2.  LangChainで動かすための準備

  3.   文書生成

モデルのCTranslate2変換

HuggingfaceのレポジトリからELYZA-japanese-Llama-2-7b-instructモデルをダウンロードし、Ctranslate2で変換します。モデルのダウンロードと変換には少々時間がかかります。

ct2-transformers-converter --model elyza/ELYZA-japanese-Llama-2-7b-instruct --quantization int8 --output_dir ct2_model

変換されたモデルが「ct2_model」ディレクトリに保存されています。元々のモデルのファイルサイズは約13.5GB(半精度浮動小数点)でしたが、保存されたモデルのファイルサイズ(model.bin)は約6.6GBになっています。
以降はこの変換されたモデルを便宜上「ct2モデル」と呼ぶことにします。

LangChainで動かすための準備

Langchainにはct2モデルを動かすためのクラス(langchain.llms.CTranslate2)が用意されていますが、ct2モデルの文書生成時の挙動を制御するためにlangchain.llms.CTranslate2を継承した自作のクラスを作成します。
(LLMの出力にプロンプトを含まないように設定しています。)

from typing import Optional, List, Any

from langchain.llms import CTranslate2
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.schema import Generation, LLMResult

## ELYZA LLama2 + Ctranslate2 (7B)
class ElyzaCT2LLM(CTranslate2):
    generator_params = {
        "max_length": 256,
        "sampling_topk": 20,
        "sampling_temperature": 0.7,
        "include_prompt_in_result": False,
    }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        encoded_prompts = self.tokenizer(prompts, add_special_tokens=False)["input_ids"]
        tokenized_prompts = [
            self.tokenizer.convert_ids_to_tokens(encoded_prompt)
            for encoded_prompt in encoded_prompts
        ]

        # 指定したパラメータで文書生成を制御
        results = self.client.generate_batch(tokenized_prompts, **self.generator_params)

        sequences = [result.sequences_ids[0] for result in results]
        decoded_sequences = [self.tokenizer.decode(seq) for seq in sequences]

        generations = []
       for text in decoded_sequences:
            generations.append([Generation(text=text)])

        return LLMResult(generations=generations)

自作したクラスを使ってモデルをロードします。

model_name = "elyza/ELYZA-japanese-Llama-2-7b-instruct"
llm = ElyzaCT2LLM(
    model_path="ct2_model",
    tokenizer_name=model_name,
    device="cuda",
    device_index=[0],
    compute_type="int8",
) 

ここで用意したモデルを使い、前半と同様にRAGによる文書生成を行います。

文書生成

ct2モデルに対して、前半と同様にシャープの本社の場所を聞いてみましょう。RAGを使って外部情報を参考に文書生成した結果がこちらです。

RAGあり
処理時間[s]: 0.70
シャープの本社は、大阪府堺市堺区匠町1番地にあります。
トークン数: 42

おみごと、今度も正解です。ct2モデルも正しい答えを返してくれました。
文章の生成速度は1トークン当たり約0.017秒です。

これはCTranslate2変換前の(半精度浮動小数点の)モデルと比べ、2倍近く速い処理速度になります。
プログラムの実行中、GPUのメモリは最大で約10.8GBが使用されていました。こちらも、CTranslate2を行っていない時とモデルと比べ低く抑えられており、CTranslate2による変換の量子化の効果が表れています。
 
では、他の質問も試してみましょう。少々難易度高めです。
質問:「今年は2023年ですが、シャープは創業何周年ですか。」
期待する答えはもちろん、111周年です!
RAGによる文書生成の回答を見てみましょう。まずはCTranslate2変換なしのモデルの回答。

RAGあり
処理時間[s]: 0.85
 2023年はシャープの創業88周年にあたります。
トークン数: 26

次にct2モデル。

RAGあり
処理時間[s]: 0.54
シャープは2023年で創業88年を迎えます。
トークン数: 25

処理速度はct2モデルの方が約1.5倍速いです。文章の長さ等によって処理速度にばらつきがありますが、ct2モデルを使った場合はCTranslate2モデル変換前のモデルより約1.5~1.8倍速いという結果になりました。

両者とも答えは同じなのですが、期待したものとは違います。なぜなのかと思い、ベクトル検索の結果を見てみると、検索された以下の情報をもとに文章を生成しているようでした。

シャープ
|
大阪府堺市の本社
|種類
|株式会社
|機関設計
|監査等委員会設置会社[1]
|市場情報
|本社所在地
|
日本\n〒590-8522
大阪府堺市堺区匠町1番地
|設立
|1935年(昭和10年)51日(株式会社早川金属工業研究所)
|業種
|電気機器
|法人番号
|6120001005484
|事業内容
|エレクトロニクス、電子部品
|代表者

設立は1935年という記載があります。こちらは株式会社としての設立なので、参照してほしい情報とは違います。
ですが2023年-1935年=88年という正しい計算をしているELYZAはさすがですね!!(正直びっくり)
創業に関してはWikipediaにも記載がありますが(「1912年 - 早川徳次が東京で創業した。徳尾錠というベルトのバックルの発明が始まりである。」)、創業時は「シャープ」という名前ではなかったためか、「今年は2023年ですが、シャープは創業何周年ですか。」という質問ではベクトル検索には引っ掛かりませんでした。
ベクトルデータベースを構築する際はベクトル検索のことを意識し、ベクトル検索がしやすい構造にするためにテキストの前処理を行うと効果的でありそうです。

まとめ

今回はRAGによる文書生成をLangChainを使って実装し、さらにLLMをCTranslate2で変換することで処理速度の向上させる実験を行いました。
RAGによる文書生成ではローカルでの動作を確認し、シャープの住所について正確な返答を得ることができました。また、CTranslate2変換によって約1.5倍~1.8倍の処理速度向上を確認しました。
RAGはLLMを直接学習させることなく、新たな知識を付与する手法でしたが、ファインチューニングと組み合わせることによって、LLMを目的に合わせてより自由にカスタマイズすることができます。
次回はそのLLMのファインチューニングに関する記事を公開予定です!

参考文献

補足

使用したコードの詳細です。

事前準備

ライブラリのインストール

pip install ctranslate2 faiss-gpu langchain sentence-transformers tiktoken trafilatura torch transformers==4.30.2

使用したライブラリのバージョンは以下の通りです。

ctranslate2==3.20.0
faiss-gpu==1.7.2
langchain==0.0.312
sentence-transformers==2.2.2
tiktoken==0.5.1
trafilatura==1.6.2
torch==2.1.0
transformers==4.30.2

LLMのCTranslate2変換

ct2-transformers-converter --model elyza/ELYZA-japanese-Llama-2-7b-instruct --quantization int8 --output_dir ct2_model

メモリを多く使うので、必要に応じて--low_cpu_mem_usageオプションをつけてください。

RAGによる文書生成高速化

###################################################
# 前半: RAGによる文書生成
###################################################
# テキストデータベースの準備

from trafilatura import fetch_url, extract

url = "https://ja.wikipedia.org/wiki/シャープ"
filename = "sharp_wiki.txt"

# ウェブページをダウンロード
http_response = fetch_url(url)

# HTML から本文を抽出
html_content = extract(http_response)

# どのようなデータが取得出来ているか確認するために文章の先頭 300 文字を表示
print(f"ダウンロードしたHTMLの文章先頭部分:\n{html_content[:300]}")

# HTMLの本文を保存
with open(filename, "w", encoding="utf-8") as f:
    f.write(html_content)

###################################################
# ベクトルデータベースの準備

from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

loader = TextLoader(filename, encoding="utf-8")
document = loader.load()

# 全文章を決まった長さの文章(チャンク)に分割して、文章データベースを作成
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=300,
    chunk_overlap=20,
)

splitted_texts = text_splitter.split_documents(document)
print(f"チャンクの総数:{len(splitted_texts)}")
print(f"チャンクされた文章の確認(参考に7番目にチャンクされたデータを確認):\n{splitted_texts[6]}")

# ベクトルデータベースの作成
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# 文章からベクトルに変換するためのモデルを用意
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
# 文章データベースからベクトルデータベースを作成。チャンク単位で文章からベクトルに変換。
db = FAISS.from_documents(splitted_texts, embeddings)

###################################################
# ベクトル検索

import time

question = "シャープの本社はどこにありますか。"

start = time.time()
# 質問に対して、データベース中の類似度上位3件を抽出。質問の文章はこの関数でベクトル化され利用される
docs = db.similarity_search(question, k=3)
elapsed_time = time.time() - start
print(f"処理時間[s]: {elapsed_time:.2f}")
for i in range(len(docs)):
    print(docs[i])

###################################################
# RAG の準備

# 生成モデルの準備
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda:0"
model_name = "elyza/ELYZA-japanese-Llama-2-7b-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to(device)

# RAG のためのLangChainのインタフェース準備
from transformers import pipeline
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    top_k=20,
    temperature=0.7,
    device=device,
)
llm = HuggingFacePipeline(pipeline=pipe)

# プロンプトの準備(ELYZA 用)
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = "参考情報を元に、ユーザーからの質問に簡潔に正確に答えてください。"
text = "{context}\nユーザからの質問は次のとおりです。{question}"
template = "{bos_token}{b_inst} {system}{prompt} {e_inst} ".format(
    bos_token=tokenizer.bos_token,
    b_inst=B_INST,
    system=f"{B_SYS}{DEFAULT_SYSTEM_PROMPT}{E_SYS}",
    prompt=text,
    e_inst=E_INST,
)
rag_prompt_custom = PromptTemplate(
    template=template, input_variables=["context", "question"]
)

# チェーンの準備
chain = load_qa_chain(llm, chain_type="stuff", prompt=rag_prompt_custom)

###################################################
# 生成

# RAG ありの場合
start = time.time()
# ベクトル検索結果の上位3件と質問内容を入力として、elyzaで文章生成
inputs = {"input_documents": docs, "question": question}
output = chain.run(inputs)
elapsed_time = time.time() - start
print("RAGあり")
print(f"処理時間[s]: {elapsed_time:.2f}")
print(f"出力内容:\n{output}")
print(f"トークン数: {llm.get_num_tokens(output)}")

# RAG なしの場合
# 質問内容のみを入力として、elyzaで文章生成
inputs = template.format(context="", question=question)
start = time.time()
output = llm(inputs)
elapsed_time = time.time() - start
print("RAGなし")
print(f"処理時間[s]: {elapsed_time:.2f}")
print(f"出力内容:\n{output}")
print(f"トークン数: {llm.get_num_tokens(output)}")

###################################################
# メモリの解放

del model, tokenizer, pipe, llm, chain
torch.cuda.empty_cache()

###################################################
# 後半: CTranslate2を用いた高速化編
###################################################
# RAG の準備

# 生成モデルの準備
from typing import Optional, List, Any

from langchain.llms import CTranslate2
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.schema import Generation, LLMResult

## ELYZA LLama2 + Ctranslate2 (7B)
class ElyzaCT2LLM(CTranslate2):
    generator_params = {
        "max_length": 256,
        "sampling_topk": 20,
        "sampling_temperature": 0.7,
        "include_prompt_in_result": False,
}

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        encoded_prompts = self.tokenizer(prompts, add_special_tokens=False)["input_ids"]
        tokenized_prompts = [
            self.tokenizer.convert_ids_to_tokens(encoded_prompt)
            for encoded_prompt in encoded_prompts
        ]

        # 指定したパラメータで文書生成を制御
        results = self.client.generate_batch(tokenized_prompts, **self.generator_params)

        sequences = [result.sequences_ids[0] for result in results]
        decoded_sequences = [self.tokenizer.decode(seq) for seq in sequences]

        generations = []
        for text in decoded_sequences:
            generations.append([Generation(text=text)])

        return LLMResult(generations=generations)

model_name = "elyza/ELYZA-japanese-Llama-2-7b-instruct"
llm_ct2 = ElyzaCT2LLM(
    model_path="ct2_model",
    tokenizer_name=model_name,
    device="cuda",
    device_index=[0],
    compute_type="int8",
)

# RAG のためのLangChainのインタフェース準備
chain = load_qa_chain(llm_ct2, chain_type="stuff", prompt=rag_prompt_custom)

###################################################
# 生成

# RAG ありの場合
start = time.time()
inputs = {"input_documents": docs, "question": question}
output = chain.run(inputs)
elapsed_time = time.time() - start
print("RAGあり")
print(f"処理時間[s]: {elapsed_time:.2f}")
print(f"出力内容:\n{output}")
print(f"トークン数: {llm_ct2.get_num_tokens(output)}")

# RAG なしの場合
inputs = template.format(context="", question=question)
start = time.time()
output = llm_ct2(inputs)
elapsed_time = time.time() - start
print("RAGなし")
print(f"処理時間[s]: {elapsed_time:.2f}")
print(f"出力内容:\n{output}")

この記事が気に入ったらサポートをしてみませんか?