見出し画像

Gradioの`ChatInterface`こと始め その8: mlx_lmのstream_generateを使うスクリプト

MLXやmlx-lmがどんどんバージョンアップしてきて、mlx_lmのなかにgenerate_stepや、stream_generateが組み込まれてきました。
参考リンク  

それに合わせて、stream_generateを採用したスクリプトへと変更してみました(こちらのリンクも参考)。ただmlxフォーマットのモデルでないとスムーズに動かない印象です。
そのあたりは各自確認してください。

import gradio as gr
from mlx_lm import load, stream_generate

# mlx形式のモデルでないと、streamingがうまくいかない様子

model_1 = "mlx-community/calm3-22b-chat-4bit"
model_2 = "mlx-community/Llama-3.1-70B-Japanese-Instruct-2407-4bit"


selected_model = model_2 #あるいはモデルをローカルに収納している場所へのparh 
model, tokenizer = load(selected_model)


def predict(message, history, system_message, tokens, temp):
    conversation = []


    for human, assistant in history:
        conversation.append({'role': 'user', 'content': human})
        conversation.append({'role': 'assistant', 'content': assistant})
        
    conversation.append({'role': 'user', 'content': message})
    conversation.insert(0, {'role': 'system', 'content': system_message})
     
    prompt = tokenizer.apply_chat_template(conversation,
                                           tokenize=False,
                                           add_generation_prompt=True)


    #print  (f"入力される最終プロンプトはこんな感じ:\n{prompt}") 
    
    generation_args = {
        "max_tokens": tokens,
        "temp": temp,
        "repetition_penalty": 1.2,
        "repetition_context_size": 20,
        "top_p": 0.95,
    }

    
    partial_message = ""
    for text in stream_generate(model, tokenizer, prompt, **generation_args):
        if text is not None:
            partial_message += text
            yield partial_message


demo = gr.ChatInterface(predict, 
    title=selected_model,
    description="MLX Chat",
    additional_inputs=[
        gr.Textbox("あなたは誠実で優秀な日本人のアシスタントです。特に指示がない限り日本語で応答します。", lines=5, max_lines=50, label="System Prompt"), 
        gr.Slider(100, 3000, value=1200, label="Tokens to generate"),
        gr.Slider(0, 1, value=0.8, label="Temperture")
                        ]
                       )

if __name__ == "__main__":
    demo.launch()

とりあえず、pip install -U mlx と pip install -U mlx_lm を忘れないでください。

スクリプト内にあるように、下記の2つのモデルが動くのは確かめました。
"mlx-community/calm3-22b-chat-4bit"
"mlx-community/Llama-3.1-70B-Japanese-Instruct-2407-4bit"

上二つのモデルをmlxフォーマットでアップロードしてもらえていて非常に助かってます。

Calm3-22B-Chatの日本語生成なかなか綺麗な印象でした。


#AI #AIとやってみた #やってみた #mlx #MacbookPro #ローカルLLM #Huggingface

この記事を最後までご覧いただき、ありがとうございます!もしも私の活動を応援していただけるなら、大変嬉しく思います。