シュッといい感じの評価してくれるエージェントの Adala を触ってみた

ちゃっす(/・ω・)/



Label Studio さんのブログで紹介されてた Adala というのが気になったので触ってみたぞい☆


GitHub はこちら(/・ω・)/


アーキ図


まぁ実際に動かした結果を見た方が早いので Let's do it !!


Google Colab で動かしたよ(/・ω・)/


まずはいんすとーる

!pip install adala


今回は二値分類のタスクを解かせてみるぞい☆
ということで教師データを用意するでござんす

import pandas as pd

df = pd.DataFrame([
    ["寿司", "Traditional"],
    ["ラーメン", "Modern"],
    ["天ぷら", "Traditional"],
    ["ハンバーガー", "Modern"],
    ["刺身", "Traditional"],
    ["カレーライス", "Modern"],
    ["おでん", "Traditional"],
    ["パスタ", "Modern"],
    ["餃子", "Modern"],
    ["納豆", "Traditional"],
    ["うなぎの蒲焼", "Traditional"],
], columns=["text", "ground_truth"])

df

んで、dataset としてぶち込む

from adala.datasets import DataFrameDataset

dataset = DataFrameDataset(df=df)


そしてエージェントの設定でござる

from adala.agents import Agent
from adala.environments import BasicEnvironment
from adala.skills import ClassificationSkill
from adala.runtimes import OpenAIRuntime
from rich import print
import os

os.environ["OPENAI_API_KEY"] = "sk-"

agent = Agent(
    # define the agent's labeling skill that should classify text onto 2 categories
    skills=ClassificationSkill(
        name='traditional_or_modern_detection',
        description='Understanding traditional dish and modern dish statements from text.',
        instructions='Classify a dish name as either expressing "Traditional" or "Modern" statements.',
        labels=['Traditional', 'Modern'],
        input_data_field='text'
    ),
    
    # basic environment extracts ground truth signal from the input records
    environment=BasicEnvironment(
        ground_truth_dataset=dataset,
        ground_truth_column='ground_truth'
    ),
    
    runtimes = {
        # You can specify your OPENAI API KEY here via `OpenAIRuntime(..., api_key='your-api-key')`
        'openai': OpenAIRuntime(model='gpt-3.5-turbo-instruct'),
        'openai-gpt3': OpenAIRuntime(model='gpt-3.5-turbo'),
        'openai-gpt4': OpenAIRuntime(model='gpt-4'),
    },
    default_runtime='openai',
    
    # NOTE! If you don't have an access to gpt4 - replace it with "openai-gpt3"
    default_teacher_runtime='openai-gpt4'
)

print(agent)

print 結果はこちらよん

Agent Instance

Environment: BasicEnvironment
Skills: traditional_or_modern_detection
Runtimes: openai, openai-gpt3, openai-gpt4
Default Runtime: openai
Default Teacher Runtime: openai-gpt4

では学びなさい(/・ω・)/

learning_experience = agent.learn(learning_iterations=3, accuracy_threshold=0.95)


で、ここからがミソであるが Option として指定した Interation だけ評価、分析、改善を繰り返すのである(/・ω・)/


評価(五件しか表示されないけど11件の教師データを見てるぞ☆)

100%|██████████| 11/11 [00:02<00:00,  3.93it/s]

=> Iteration #0: Comparing to ground truth, analyzing and improving ...
Comparing predictions to ground truth data ...
                                                                                                                   
  text           ground_truth   traditional_or_modern_de…   score                       ground_truth__x__traditi…  
 ───────────────────────────────────────────────────────────────────────────────────────────────────────────────── 
  寿司           Traditional    Traditional                 {'Traditional':             True                       
                                                            -0.14485012, 'Modern':                                 
                                                            -2.003607}                                             
  ラーメン       Modern         Traditional                 {'Traditional':             False                      
                                                            -0.12320776999999997,                                  
                                                            'Modern': -2.1548545}                                  
  天ぷら         Traditional    Traditional                 {'Traditional':             True                       
                                                            -0.02719422399999996,                                  
                                                            'Modern': -3.6183197}                                  
  ハンバーガー   Modern         Modern                      {'Traditional':             True                       
                                                            -4.058615, 'Modern':                                   
                                                            -0.017423895999999977}                                 
  刺身           Traditional    Traditional                 {'Traditional':             True                       
                                                            -0.4935383199999999,                                   
                                                            'Modern': -0.9427952}                                  
                                                                                       

分析

Analyze evaluation experience ...
100%|██████████| 3/3 [00:00<00:00, 157.31it/s]
100%|██████████| 3/3 [00:15<00:00,  5.02s/it]
Number of errors: 3
Accuracy = 72.73%
Improve "traditional_or_modern_detection" skill based on analysis ...
Updated instructions for skill "traditional_or_modern_detection":

Based on the dish name provided, determine whether it represents a "Traditional" or "Modern" culinary style.

Examples:

Input: ラーメン
Output: Modern

Input: 餃子
Output: Modern

Input: カレーライス
Output: Modern

改善結果を適用

Re-apply traditional_or_modern_detection skill to dataset ...
100%|██████████| 11/11 [00:03<00:00,  3.66it/s]


Take2

=> Iteration #1: Comparing to ground truth, analyzing and improving ...
Comparing predictions to ground truth data ...
                                                                                                                   
  text           ground_truth   traditional_or_modern_de…   score                       ground_truth__x__traditi…  
 ───────────────────────────────────────────────────────────────────────────────────────────────────────────────── 
  寿司           Traditional    Traditional                 {'Traditional':             True                       
                                                            -0.00504399839999998,                                  
                                                            'Modern': -5.292077}                                   
  ラーメン       Modern         Modern                      {'Traditional':             True                       
                                                            -4.4393187, 'Modern':                                  
                                                            -0.011874314000000028}                                 
  天ぷら         Traditional    Traditional                 {'Traditional':             True                       
                                                            -0.009848873999999952,                                 
                                                            'Modern': -4.62532}                                    
  ハンバーガー   Modern         Modern                      {'Traditional':             True                       
                                                            -5.9589763, 'Modern':                                  
                                                            -0.0025858853000000357}                                
  刺身           Traditional    Traditional                 {'Traditional':             True                       
                                                            -0.0013666658000000445,                                
                                                            'Modern': -6.5960474}                                  
                                                                                                                   
Analyze evaluation experience ...
Number of errors: 0
Accuracy = 100.00%
Accuracy threshold reached (1.0 >= 0.95)
Train is done!


ここで Accuracy が 100 になったので終了!!



ちなみに、推論は Runtimes として Option に設定したものは使用されるけれど分析、改善は Teacher Runtime が実施するぞ☆
デフォルトは GPT4 だぞ☆



でまぁ結果も見れるざんす(/・ω・)/

learning_experience.predictions


推論結果よ


ではテストしマッスルか

test_df = pd.DataFrame([
    "カリフォルニアロール",
    "うどん",
    "アイスクリーム",
    "おにぎり"
], columns=['text'])
test_df


いざ(/・ω・)/

result = agent.apply_skills(test_df)
result.predictions


結果

それっぽい



んでまぁ今回はクイックスタートをちょちょっといじってシュッと試しただけだけなので教師データ(Ground Truth) は最初に設定したけれど、必要に応じて request_feedback を実装すれば別のデータソース(例えば Label Studio)の結果を取り込んで実行できると思いますわ~(/・ω・)/

    @abstractmethod
    def request_feedback(self, skill: BaseSkill, experience: ShortTermMemory):
        """Request user feedback using predictions and update internal ground truth set."""



ちなみに今使えるスキルはこんな感じみたい(/・ω・)/

👉 Available skills



まだできたばっかりなので今後に期待(/・ω・)/



というわけでおしまい。

この記事が気に入ったらサポートをしてみませんか?