見出し画像

【初心者向け】Databricksで始めるLLM - LLMを動かしてみるところからベクターDBやRAGまで※サンプルコードあり

三菱UFJフィナンシャル・グループ(以下MUFG)の戦略子会社であるJapan Digital Design(以下JDD)でMUFG AI Studio(以下M-AIS)に所属する浦田です。JDDに入社してから初めてのエントリになります。

はじめに

今回は、Databricks環境を使ってLLM(大規模言語モデル)を動かしてみる初心者向けガイドをお届けします。LLMの導入が進む中、実際に手を動かして試すことが難しいと感じる方も多いのではないでしょうか。この記事では、日本で開発された「Swallow」モデルを使い、簡単にLLMを動かす方法から、ベクトルDBやRAG(Retrieval-Augmented Generation)といった先進的な技術の基本を紹介します。

Databricks上で実行可能なサンプルコードも用意しているので、ぜひお手元の環境で試してみてください。


1. DatabricksでLLMを動かす

まずは、Databricks環境でLLMを動かすための基本設定から始めます。今回使用するのは、東工大と産総研の研究チームが開発した「Swallow」というモデルです。このモデルは、英語の言語理解や対話に強いMeta社のLlama 2をベースに、日本語対応能力を大幅に向上させた大規模言語モデルです。

モデル: tokyotech-llm/Swallow-7b-instruct-hf

使用するインスタンス: g5.xlarge [A10G] (1GPU, 24GiB(GPU), 4vCPU, 16GiB)
Runtime: 14.3 LTS ML (Apache Spark 3.5.0, GPU, Scala 2.12)

まずはこの設定を使い、Swallowモデルをロードして実際に簡単な対話を行ってみましょう。

クラスター環境の設定

DatabricksでSwallowを動かすためのクラスター構成は以下の通りです。GPUを使って効率的に大規模モデルを動かせるよう、g5.xlargeインスタンスを選択しています。

クラスター構成例:

  • インスタンスタイプ: g5.xlarge(GPU 1つ、GPUメモリ 24GiB、CPU 4vCPU、メモリ 16GiB)

  • ランタイム: 14.3 LTS ML(Apache Spark 3.5.0、Scala 2.12)

この設定でSwallowモデルを動かすと、対話ベースの質問に応答させることが可能です。

サンプルコード

早速動かしてみましょう。以下のコードですぐにLLMが動いてしまいます。

"""
tokyotech-llm/Swallow-7b-instruct-hf
"""

import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "tokyotech-llm/Swallow-7b-instruct-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")


PROMPT_DICT = {
    "prompt_input": (
        "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:"

    ),
    "prompt_no_input": (
        "以下に、あるタスクを説明する指示があります。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 応答:"
    ),
}

def create_prompt(instruction, input=None):
    """
    Generates a prompt based on the given instruction and an optional input.
    If input is provided, it uses the 'prompt_input' template from PROMPT_DICT.
    If no input is provided, it uses the 'prompt_no_input' template.

    Args:
        instruction (str): The instruction describing the task.
        input (str, optional): Additional input providing context for the task. Default is None.

    Returns:
        str: The generated prompt.
    """
    if input:
        # Use the 'prompt_input' template when additional input is provided
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input)
    else:
        # Use the 'prompt_no_input' template when no additional input is provided
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)

# Example usage
instruction_example = "以下のトピックに関する詳細な情報を提供してください。"
input_example = "東京工業大学の主なキャンパスについて教えてください"
prompt = create_prompt(instruction_example, input_example)

input_ids = tokenizer.encode(
    prompt,
    add_special_tokens=False,
    return_tensors="pt"
)

tokens = model.generate(
    input_ids.to(device=model.device),
    max_new_tokens=128,
    temperature=0.99,
    top_p=0.95,
    do_sample=True,
)

out = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(out)

出力は以下になります。しっかり回答できていますね!

※回答の文章が途中で途切れてしまっていますが、理由は、model.generate関数で指定した生成トークン数 (max_new_tokens=128) によるものです。このパラメータが、生成される新しいトークン数の上限を設定しているため、128トークンが生成されると文章が途中でも強制的に停止します。途切れないようにするための方法としては、max_new_tokens を増やす、stop_token を使用する、eos_token_id の活用などを検討すると良いと思います。

以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。リクエストを適切に完了するための回答を記述してください。

### 指示:
以下のトピックに関する詳細な情報を提供してください。

### 入力:
東京工業大学の主なキャンパスについて教えてください

### 応答:日本で最も有名な大学である東京工業大学は、本拠地を構える東京工業大学のキャンパスが主に3つあります。
1.大岡山キャンパス
2.すずかけ台キャンパス
3.田町キャンパス

本キャンパスである大岡山キャンパスは、東京都目黒区に位置し、1991年に開設されました。すずかけ台キャンパスは神奈川県横浜市のすずかけ台に位置し、1993年に開設されました。最後の田町キャンパスは東京都港区に位置し、2005

2. テキストをベクトル化し、ベクターDBを触ってみよう

次に、テキストをベクトル化し、ベクターDBに保存してベクトル検索を行う方法を見ていきましょう。LLM界隈では、テキストを数値ベクトルに変換する「Embedding」が広く使われています。今回は多言語対応の埋め込みモデルとして「Multilingual-E5」を使用します。またベクトルを効率的に扱うためのデータベースとして「ベクターDB」が注目されており、その中でも人気の高い「Chroma」を使用します。

モデル: intfloat/multilingual-e5-base
ベクターDB: Chroma

Chromaとは?

Chromaは、オープンソースのベクターDBとして知られ、特にAIや機械学習アプリケーションでの使用を想定して開発されています。テキストや画像などの非構造データを数値ベクトル(高次元ベクトル)に変換して保存し、そのベクトルに基づいて類似のデータを効率的に検索できるのがChromaの強みです。

例えば、LLMを使ったアプリケーションで過去の会話履歴やドキュメント検索を行う際、関連性の高いテキストを迅速に取り出すためにChromaが利用されます。従来のSQLデータベースでは難しい「類似検索」も、Chromaのベクトル検索なら簡単に実現できます。

このサンプルでは、任意のテキストをEmbeddingモデルでベクトル化し、そのベクトルをChromaベクターDBに保存して、類似のテキストを効率的に検索する方法を学びます。

サンプルコード

では実際に動かすためのコードを確認していきましょう。
まずは必要なパッケージをインストールします。

%pip install chromadb==0.4.24 \
    langchain-core==0.1.48 \
    langchain==0.1.17 \
    langsmith==0.1.52 \
    langchain-chroma==0.1.0 \
    langchain_community==0.0.36 \
    fugashi==1.3.2 \
    unidic-lite==1.0.8 \
    ipadic==1.0.0

dbutils.library.restartPython()

テキストデータの準備

今回はこちらのMUFGの株価について書かれたニュースのテキストデータを使っていきます。

document = """
本記事では三菱UFJフィナンシャル・グループ <8306> の株価について解説します。

【株価チャート】1年間と1ヶ月間の株価推移をチェックする

2024年5月2日時点における過去1年間の株価の推移や最高値など、今後の投資判断や株価分析に役立つ情報をお届けするので、参考にしてください。

※編集部注:外部配信先ではハイパーリンクや図表などの画像を全部閲覧できない場合があります。その際はLIMO内でご確認ください。

2024年5月2日の三菱UFJフィナンシャル・グループの株価は1554円
2024年5月2日時点、三菱UFJフィナンシャル・グループの株価(調整後終値)は1554円となっています。

また、1年前(2023年5月2日)の株価は862.6円です。

仮に、2023年5月2日時点で三菱UFJフィナンシャル・グループの株を取得していた場合、リターンは+80.15%となります。

2023年5月2日時点で100万円を投資した場合のリターンは+80.15万円です。

※リターン計算において配当、株主優待は考慮していません

三菱UFJフィナンシャル・グループの1年間の株価推移【2024年5月2日時点】
三菱UFJフィナンシャル・グループの株価は過去1年間で以下の通りに推移しています。

※株価チャートは配信先によっては表示されていない可能性があるため、ご覧になりたい場合はLIMOにて記事をご参照ください。

過去1年間における株価の最高価格と最低価格は次のようになりました。

 ・最高価格:1632円
 ・最低価格:848.3円
仮に最低価格で取得し、最高価格で売却できた場合、過去1年間で実現できた最高リターンは+92.38%です。

100万円が192.38万円となった計算です。

過去1年間における三菱UFJフィナンシャル・グループの株価の最大上昇日と最大下落日は?
過去1年間、三菱UFJフィナンシャル・グループの株価が対前日比で最も上昇した(もしくは最も下落率が低かった)のは2023年7月28日でした。

変化率は対前日比+5.28%です。

また、最も下落した(もしくは最も上昇率が低かった)のは2023年10月4日でした。

変化率は対前日比▲5.37%です。

株を買うには証券口座が必要!おすすめの証券会社は?
株を購入するには証券口座を開設する必要があります。

証券口座は銀行や証券会社などで開設できますが、なかでもおすすめなのはネット証券です。

ネット証券はオンライン専門の証券会社です。

口座開設から実際の購入まですべてオンラインで完結するので、店舗や人件費のコストがかからない分安く取引できる傾向にあります。
"""

テキストデータの分割

ニューステキストを均等なチャンクサイズに分割します。langchainに分割する便利関数が用意されているのでこれを使うだけです。

rom langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=30)
splits = text_splitter.split_text(document)

埋め込み用モデルのダウンロード

多言語対応の埋め込みモデル「Multilingual-E5」をダウンロードしてきます

from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings

# Embeddingの定義
embedding_model_name = "intfloat/multilingual-e5-base"
embedding_model = SentenceTransformerEmbeddings(model_name=embedding_model_name)

テキストデータの埋め込みとベクターDBへの保存

# Embed
vectorstore = Chroma.from_texts(
    texts=splits, 
    embedding=embedding_model
)

これだけで既にベクターDBへ保存されているので、テキトーな質問文に対して一番近い文章を取得してみましょう。

query = "2024年5月2日時点での三菱UFJフィナンシャル・グループ <8306> の株価はいくらですか?"
vector_top = vectorstore.similarity_search(query=query, k=3)
vector_top
[Document(page_content='2024年5月2日の三菱UFJフィナンシャル・グループの株価は1554円\n2024年5月2日時点、三菱UFJフィナンシャル・グループの株価(調整後終値)は1554円となっています。'),
 Document(page_content='本記事では三菱UFJフィナンシャル・グループ <8306> の株価について解説します。\n\n【株価チャート】1年間と1ヶ月間の株価推移をチェックする'),
 Document(page_content='また、1年前(2023年5月2日)の株価は862.6円です。\n\n仮に、2023年5月2日時点で三菱UFJフィナンシャル・グループの株を取得していた場合、リターンは+80.15%となります。')]

見事、株価に関する記述を引っ張ってくることに成功していますね。これが次のRAGの原型になります。


3. RAGを体験してみよう

最後に、RAG(Retrieval-Augmented Generation)を使って、LLMの回答精度を向上させるテクニックを紹介します。RAGは、LLMに与える指示文に関連する情報を事前に取得し、その情報を埋め込んで応答の精度を上げる手法です。

今回のサンプルノートブックでは、以下のステップでRAGの効果を実感していただけます。

  1. LLMが知らない質問を投げてみる(例えば、最新のニュースに関する質問など)

  2. その質問に対して、LLMが自信満々に間違った回答を生成する様子を確認

  3. 最新のニュース記事を参照させることで、正しい回答に改善されることを体験

サンプルコード

%pip install chromadb==0.4.24 \
    langchain-core==0.1.48 \
    langchain==0.1.17 \
    langsmith==0.1.52 \
    langchain-chroma==0.1.0 \
    langchain_community==0.0.36 \
    fugashi==1.3.2 \
    unidic-lite==1.0.8 \
    ipadic==1.0.0

dbutils.library.restartPython()

LLMが知り得ないであろう質問を尋ねてみる

まずはLLMが知り得ないであろう質問を尋ねてみて、LLMがテキトーな回答をしてくることを確認しましょう。 LLMが知り得ないであろう質問としては、「2024年5月2日時点での三菱UFJフィナンシャル・グループの株価はいくらですか?」という質問をしてみます。

import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "tokyotech-llm/Swallow-7b-instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")


PROMPT_DICT = {
    "prompt_input": (
        "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:"

    ),
    "prompt_no_input": (
        "以下に、あるタスクを説明する指示があります。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 応答:"
    ),
}

def create_prompt(instruction, input=None):
    """
    Generates a prompt based on the given instruction and an optional input.
    If input is provided, it uses the 'prompt_input' template from PROMPT_DICT.
    If no input is provided, it uses the 'prompt_no_input' template.

    Args:
        instruction (str): The instruction describing the task.
        input (str, optional): Additional input providing context for the task. Default is None.

    Returns:
        str: The generated prompt.
    """
    if input:
        # Use the 'prompt_input' template when additional input is provided
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input)
    else:
        # Use the 'prompt_no_input' template when no additional input is provided
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)

def ask(query: str):
    instruction="以下のトピックに関する詳細な情報を提供してください。"
    prompt = create_prompt(instruction, query)

    input_ids = tokenizer.encode(
        prompt,
        add_special_tokens=False,
        return_tensors="pt"
    )

    tokens = model.generate(
        input_ids.to(device=model.device),
        max_new_tokens=128,
        temperature=0.99,
        top_p=0.95,
        do_sample=True,
    )

    out = tokenizer.decode(tokens[0], skip_special_tokens=True)
    answer = out.split("応答:")[-1]
    return answer


query = "2024年5月2日時点での三菱UFJフィナンシャル・グループの株価はいくらですか?"  
answer = ask(query)
print(f"answer = {answer}")
answer = 2024年5月2日現在、三菱UFJフィナンシャル・グループの株価は1株あたり858円です。

回答はおそらく間違えた回答が返って来ていると思います(実行のたびに回答が変わる)。 ちなみに正解は1554円です。 テキトーな答えを生成しまっていることが確認できます。

RAGを使って回答精度を上げる

ここからは、RAG(Retrieval-Augmented Generation)を使って回答の精度を向上させます。先ほど確認した質問に対する回答精度を、RAGを利用してさらに改善していきます。ここで参照するデータとしては、ベクターDBの章で使用したものをそのまま活用します。

テキストデータの埋め込みとベクターDBへの保存

まず、テキストデータをEmbeddingモデルでベクトル化し、そのベクトルをベクターDBに保存します。この準備によって、関連する情報を迅速に検索できる状態が整います。

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings

# Split
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=30)
splits = text_splitter.split_text(document)

# Embeddingの定義
embedding_model_name = "intfloat/multilingual-e5-base"
embedding_model = SentenceTransformerEmbeddings(model_name=embedding_model_name)

# Embed
vectorstore = Chroma.from_texts(
    texts=splits, 
    embedding=embedding_model
)

Retieverの定義

次に、as_retriever()メソッドを使って、ベクターDBでの検索機能を定義します。Retrieverは、ユーザーからのクエリに関連する文書や情報をデータベースから引き出す役割を担います。

retriever = vectorstore.as_retriever()

RAGのパイプライン構築

Langchainライブラリには、RAGのパイプラインを構築するためのモジュールが揃っているため、それらを組み合わせていくだけで簡単に実装が可能です。プロンプトのテンプレートには、モデルごとに「お作法」があり、公式ドキュメントに記載されているテンプレートを使用することを推奨します。独自のプロンプトを使用すると、期待通りの回答が得られない可能性が高いため注意が必要です。

from langchain.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain import HuggingFacePipeline
from transformers import AutoTokenizer, pipeline

template = """\
以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。
リクエストを適切に完了するための回答を記述してください。\n\n
### 指示:\n以下のトピックに関する詳細な情報を提供してください。その際に次に続く文章を参考にしてください\n{context}\n\n### 入力:\n{question}\n\n### 応答::
"""

prompt = ChatPromptTemplate.from_template(template)

llm_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    max_length=1024,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id
)
llm = HuggingFacePipeline(
    pipeline = llm_pipeline, 
    model_kwargs = {"temperature": 0}
)

# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Chain
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

RAG実行

では実際に先ほどの質問を参照文章を付与した上で行ってみます。

query = "2024年5月2日時点での三菱UFJフィナンシャル・グループの株価はいくらですか?"
answer = rag_chain.invoke(query)
answer
'Human: 以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。\nリクエストを適切に完了するための回答を記述してください。\n\n\n### 指示:\n以下のトピックに関する詳細な情報を提供してください。その際に次に続く文章を参考にしてください\n2024年5月2日の三菱UFJフィナンシャル・グループの株価は1554円\n2024年5月2日時点、三菱UFJフィナンシャル・グループの株価(調整後終値)は1554円となっています。\n\n三菱UFJフィナンシャル・グループの1年間の株価推移【2024年5月2日時点】\n三菱UFJフィナンシャル・グループの株価は過去1年間で以下の通りに推移しています。\n\nまた、1年前(2023年5月2日)の株価は862.6円です。\n\n仮に、2023年5月2日時点で三菱UFJフィナンシャル・グループの株を取得していた場合、リターンは+80.15%となります。\n\n過去1年間における三菱UFJフィナンシャル・グループの株価の最大上昇日と最大下落日は?\n\n### 入力:\n2024年5月2日時点での三菱UFJフィナンシャル・グループの株価はいくらですか?\n\n### 応答::\n2024年5月2日時点、三菱UFJフィナンシャル・グループの株価は1554円となっています。'

このように、RAGを使うことで、LLMが知識として持っていない最新情報を補完し、より正確な回答を得ることができるようになりました。コンテキストにはRetrieverが取得した関連情報が含まれていることも確認できます。

以上で、RAGの体験は終了です。


最後に

この記事では、Databricks上でSwallowモデルを使ったLLMの動作確認から、テキストベクトル化、ベクターDB、RAGの活用方法まで紹介しました。今後、さらに高度なLLMアプリケーションを構築するための基本的な知識を学べたのではないでしょうか。

LLMの応用はこれからますます広がっていく分野ですので、ぜひ今回の内容をベースに、より実践的なプロジェクトに取り組んでみてください。

以上、「Databricksで始めるLLM - LLMを動かしてみるところからベクターDBやRAGまで」でした。
最後までご覧いただきありがとうございました。


Japan Digital Design株式会社では、一緒に働いてくださる仲間を募集中です。カジュアル面談も実施しておりますので下記リンク先からお気軽にお問合せください。

この記事に関するお問い合わせはこちら


M-AIS
Takuma Urata