見出し画像

Google Colab で CodeGemma を試す

「Google Colab」で「CodeGemma」を試したので、まとめました。


1. CodeGemma

CodeGemma」は、コードタスク用のモデルです。次の3種類のモデルが提供されています。

google/codegemma-2b :  高速コード補完用
google/codegemma-7b : コード補完とコード生成用
google/codegemma-7b-it : コード生成とチャットと指示用

2. コード補完

Colabでのコード補完の手順は、次のとおりです。

(1) パッケージのインストール。

# パッケージのインストール
!pip install transformers accelerate

(2) 「HuggingFace」からAPIキーを取得し、Colabのシークレットマネージャーの「HF_TOKEN」に登録。

(3) トークナイザーとモデルの準備。
今回は、「google/codegemma-2b」を使います。

from transformers import GemmaTokenizer, AutoModelForCausalLM

# トークナイザーとモデルの準備
tokenizer = GemmaTokenizer.from_pretrained(
    "google/codegemma-2b"
)
model = AutoModelForCausalLM.from_pretrained(
    "google/codegemma-2b",
    device_map="auto",
    torch_dtype="auto",
)

(4) 推論の実行。

# プロンプトの準備
prompt = '''\
<|fim_prefix|>import datetime
def calculate_age(birth_year):
    """Calculates a person's age based on their birth year."""
    current_year = datetime.date.today().year
    <|fim_suffix|>
    return age<|fim_middle|>\
'''

# 推論の実行
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
prompt_len = inputs["input_ids"].shape[-1]
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0][prompt_len:]))
age = current_year - birth_year<|file_separator|><eos>

プロンプトで使用するスペシャルトークンは、次のとおりです。

・<|fim_prefix|> : 補完前のコンテキストの先頭に配置
・<|fim_suffix|> : サフィックスの前に配置。コード補完の生成場所
・<|fim_middle|> : モデルに生成を促す場所に配置
・<|file_separator|> : 複数ファイルのセパレータ

3. コード生成

(1) 推論の実行。

# プロンプトの準備
# n番目のフィボナッチ数を計算するPython関数を書いてください。
prompt = "Write me a Python function to calculate the nth fibonacci number."

# 推論の実行
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
prompt_len = inputs["input_ids"].shape[-1]
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0][prompt_len:]))
py
<|fim_prefix|><|fim_suffix|><|fim_middle|>def fibonacci(n):
▁▁▁▁if n == 0:
▁▁▁▁▁▁▁▁return 0
▁▁▁▁elif n == 1:
▁▁▁▁▁▁▁▁return 1
▁▁▁▁else:
▁▁▁▁▁▁▁▁return fibonacci(n-1) + fibonacci(n-2)
▁▁▁▁
n = int(input("Enter the nth term: "))
print("The nth term of the Fibonacci series is:", fibonacci(n))
<|file_separator|><eos>



この記事が気に入ったらサポートをしてみませんか?