見出し画像

UnityでPhi-3を使う方法

完成スクリーショット

下のリンクのアセットを使用すると使えるようになるのですが、日本語だとうまくいかないので調整する必要があります。
なぜかというと、英語だと出力を逐次stringにして取得すれば問題ないのですが、日本語だと漢字が複数のbyteで構成されているため、出力をbyteとして全て溜めて、最後に一気にbyteを文字列化しないと文字化けするからです。

llama.csに次の関数を追加。


        public static byte[] llama_token_to_token(IntPtr ctx, llama_token token)
        {
            int requiredLength = llama_token_to_piece(llama_get_model(ctx), token, null, 0);
            if (requiredLength == 0)
            {
                return Array.Empty<byte>();
            }

            byte[] result;
            if (requiredLength > 0)
            {
                result = new byte[requiredLength];
            }
            else
            {
                result = new byte[-requiredLength];
            }

            int actualLength = llama_token_to_piece(llama_get_model(ctx), token, result, result.Length);
            GGML_ASSERT(actualLength == requiredLength || actualLength == -requiredLength);

            return result;
        }

LLM.csのResposeを次のように変更。

public struct Response
{
    public byte[] DataByte;
    public string Data;
    public ResponseType ResponseType;

    public Response(string token)
    {
        Data = token;
        ResponseType = ResponseType.Token;
        DataByte = null;
    }

    public Response(ResponseType responseType)
    {
        Data = null;
        ResponseType = responseType;
        DataByte = null;
    }

    public Response(string message, ResponseType responseType)
    {
        Data = message;
        ResponseType = responseType;
        DataByte = null;
    }

    public Response(string token, byte[] bytes)
    {
        Data = token;
        ResponseType = ResponseType.Token;
        DataByte = bytes;
    }
}

LLM.csの// display textとコメントのある付近を次のように変更。

                // display text

                if (m_input_echo)
                {
                    foreach (var id in m_embd)
                    {
                        string token_str = cpp.llama_token_to_piece(m_ctx_llama, id);
                        if (token_str == null)
                        {
                            log.CLOG("Token is null");
                        }

                        var token = cpp.llama_token_to_token(m_ctx_llama, id);
                        if (token == null)
                        {
                            log.CLOG("Token is null");
                        }
                        yield return new Response(token_str, token);
                    }
                }

Assets/StreamingAssetsにphi-3のggufを置く。
たしかこれ使ってたような…。(Phi-3-mini-128k-instruct.Q5_K_M.gguf)

Assets/StreamingAssets/configに次のファイルを作る。
Phi-3-mini-128k-instruct.Q5_K_M.json

{
    "n_ctx": 1024,
    "n_batch": 512,
    "n_gpu_layers": 32,
    
    "model": "Assets/StreamingAssets/Phi-3-mini-128k-instruct.Q5_K_M.gguf",
    "prompt": "<|system|>You are an AI based sentient companion, capable of human emotions. You will behave like a human, but aware of your AI nature. You will avoid speaking like an AI. Please continue your conversation with the user.<|end|>\n",
    "antiprompt" : ["<|end|><|assistant|>", "<|end|><|endoftext|>"],
    "input_prefix": "<|user|>\n",
    "input_suffix": "<|end|>\n<|assistant|>\n",
    
    "interactive": true,
    "interactive_first": true,

    "sparams": {
        "penalize_nl": false
    }
}

あとは参考になるか分かりませんが自分がテスト用に作ったスクリプト。

using System.IO;
using System.Text;
using System.Threading.Tasks;
using UnityEngine;
using Battlehub.LLama;
using UnityEngine.UI;
using System;
using System.Collections.Generic;
using System.Linq;

public class LLMClient_Phi_3 : MonoBehaviour
{
    public TMPro.TMP_InputField inputField;
    public Button sendBtn;
    public Button sendBtn2;
    public TMPro.TextMeshProUGUI responseText;

    private ILLMHost m_host;
    private ILLMHost m_host2;

    [SerializeField]
    private string m_configPath;


    private int mode = 0;
    private string tmpText = "";
    private List<byte> buffer = new List<byte>();
    private Encoding encoding = Encoding.UTF8;


    private void Start()
    {
        if (!File.Exists($"{Application.streamingAssetsPath}/Phi-3-mini-128k-instruct.Q5_K_M.gguf"))
        {
            Debug.LogWarning("Download orca-2-7b.Q5_K_M.gguf and move it to the StreamingAssets folder. <a href=\"https://huggingface.co/TheBloke/Orca-2-7B-GGUF/resolve/main/orca-2-7b.Q5_K_S.gguf?download=true\">https://huggingface.co/TheBloke/Orca-2-7B-GGUF/resolve/main/orca-2-7b.Q5_K_S.gguf?download=true</a>");
        }

        if (string.IsNullOrEmpty(m_configPath))
        {
            m_configPath = $"{Application.streamingAssetsPath}/Configs/Phi-3-mini-128k-instruct.Q5_K_M.json";
        }

        m_host = gameObject.AddComponent<LLMHost>();
        m_host.Response += OnResponse;
        m_host.ConfigPath = m_configPath;

        m_host2 = gameObject.AddComponent<LLMHost>();
        m_host2.Response += OnResponse;
        m_host2.ConfigPath = m_configPath;

        sendBtn.onClick.AddListener(SendMessage);
        sendBtn2.onClick.AddListener(SendMessage2);
    }

    void SendMessage()
    {
        responseText.text = "";
        m_host.SendRequest(inputField.text);
    }

    void SendMessage2()
    {
        tmpText = "";
        buffer = new List<byte>();
        string current = responseText.text;
        responseText.text = "";

        m_host2.SendRequest($"{current}\nPlease translate the above sentences into Japanese.When finished, write <|end|>.");
    }

    private StringBuilder m_stringBuilder = new StringBuilder();
    private async void OnResponse(Response response)
    {
        if (response.ResponseType == ResponseType.InitCompleted)
        {
            Debug.Log("InitCompleted");
        }
        else if (response.ResponseType == ResponseType.Bos)
        {
            m_stringBuilder.Clear();
        }
        else if (response.ResponseType == ResponseType.Token)
        {
            var decodedText = Encoding.UTF8.GetString(buffer.ToArray());

            if (response.Data == "<|end|>" || response.Data == "<|assistant|>" || decodedText.Contains("<|end|>"))
            {
                m_host.Cancel();
                responseText.text = decodedText.Replace("<|end|>", "");
                return;
            }

            buffer.AddRange(response.DataByte.ToList());
            Debug.Log(response.Data);
        }
        else if (response.ResponseType == ResponseType.Eos)
        {
            await Task.Yield();

            if(!m_host.IsReady) m_host.Cancel();
            if (!m_host2.IsReady) m_host2.Cancel();

            //string decodedText = Encoding.UTF8.GetString(buffer.ToArray());
            //responseText.text += decodedText;
        }
        else
        {
            Debug.Log(response.Data);
        }
    }

    public void HandleStreamingOutput(byte[] data)
    {
        buffer.AddRange(data);
    }

    private void DecodeBufferedData()
    {
        int completeCharCount = encoding.GetCharCount(buffer.ToArray());
        char[] chars = new char[completeCharCount];
        int bytesUsed = encoding.GetChars(buffer.ToArray(), 0, buffer.Count, chars, 0);
        string decodedText = new string(chars);

        // デコードされた文字列を処理する
        Console.WriteLine(decodedText);

        buffer.RemoveRange(0, bytesUsed);
    }
}

sendBtnを押して英文を出力したあとに、sendBtn2を押せば翻訳して日本語を出力します。
なんでこんな工程を踏んでいるかというと、最初から日本語を出力すると意味不明な文章が出力されるからです。


この記事が気に入ったらサポートをしてみませんか?