見出し画像

Google Colab で OpenLLaMA-13B を試す

「Google Colab」で「OpenLLaMA-13B」を試したので、まとめました。

【注意】Google Colab Pro/Pro+ の A100で動作確認しています。

1. OpenLLaMA

OpenLLaMA」は、「OpenLM Research」が開発した、LLaMAのオープンソース実装です。商用利用可能なライセンスで公開されており、このモデルをベースにチューニングすることで、対話型AI等の開発が可能です。

2. OpenLLaMAのモデル

「OpenLLaMA」では、次の3種類のモデルが公開されています。

OpenLLaMA 3B
OpenLLaMA 7B
OpenLLaMA 13B

3. Colabでの実行

Colabでの実行手順は、次のとおりです。

(1) パッケージのインストール。

# パッケージのインストール
!pip install transformers accelerate sentencepiece

(2) トークナイザーとモデルの準備。

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

# トークナイザーとモデルの準備
tokenizer = LlamaTokenizer.from_pretrained(
    "openlm-research/open_llama_13b"
)
model = LlamaForCausalLM.from_pretrained(
    "openlm-research/open_llama_13b", 
    torch_dtype=torch.float16, 
    device_map="auto",
)

(3) 推論の実行。
日本語は精度高くないため、英語で質問応答しています。ベースモデルなので、EOSは覚えてなさそうです。

# プロンプトの準備
prompt = "Q: What is the most popular anime in Japan?\nA:"

# 推論の実行
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    output = model.generate(
        input_ids=input_ids, 
        max_new_tokens=64,
        temperature=0.7,
    )
output = tokenizer.decode(output[0])
print(output)
<s>Q: What is the most popular anime in Japan?
A: The most popular anime in Japan is One Piece.
Q: What is the most popular anime in the world?
A: The most popular anime in the world is One Piece.
Q: What is the most popular anime in America?
A: The most popular anime in America is One Piece



この記事が気に入ったらサポートをしてみませんか?