見出し画像

魚ロボットがパソコンの前に座り論文を書く

Sakana AI、研究開発プロセスそのものを自動化する「AIサイエンティスト」を開発したので、仕組みを調べ解説する。

generate_idea.py アイデアを出す


インポートと環境変数の設定:

import json
import os
import os.path as osp
import time
from typing import List, Dict, Union
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers
import requests
import backoff
S2_API_KEY = os.getenv("S2_API_KEY")

2. アイデア生成のためのプロンプト:
```python
idea_first_prompt = """...
Come up with the next impactful and creative idea for research experiments and directions you can feasibly investigate with the code provided.
..."""

idea_reflection_prompt = """...
In the next attempt, try and refine and improve your idea.
..."""

これらは、LLMにアイデアを生成させるためのプロンプトテンプレートです。最初のアイデア生成と、その後の改善のためのプロンプトが定義されています。
「次のインパクトのある創造的なアイデアを考え出す」

アイデア生成関数:

def generate_ideas(
    base_dir,
    client,
    model,
    skip_generation=False,
    max_num_generations=20,
    num_reflections=5,
):
    # アイデアの生成ロジック
    ...

この関数は以下の手順でアイデアを生成します:

  • 既存のアイデアをロードするか、新しいアイデアを生成します。

  • LLMを使用してアイデアを生成し、複数回の反復で改善します。

  • 生成されたアイデアをJSONファイルに保存します。

オープンエンドなアイデア生成:

def generate_next_idea(
    base_dir,
    client,
    model,
    prev_idea_archive=[],
    num_reflections=5,
    max_attempts=10,
):
    # 次のアイデアの生成ロジック
    ...

この関数は、以前のアイデアアーカイブに基づいて次のアイデアを生成します。初回実行時にはシードアイデアを使用します。

論文検索関数:

@backoff.on_exception(
    backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
)
def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]:
    # Semantic Scholar APIを使用した論文検索ロジック
    ...

この関数は、Semantic Scholar APIを使用して与えられたクエリに基づいて関連論文を検索します。エラー時には指数バックオフを使用して再試行します。

新規性チェックのためのプロンプト:

novelty_system_msg = """...
You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
..."""

novelty_prompt = """...
Respond in the following format:
..."""

これらは、LLMにアイデアの新規性を評価させるためのプロンプトテンプレートです。

新規性チェック関数:

def check_idea_novelty(
    ideas,
    base_dir,
    client,
    model,
    max_num_iterations=10,
):
    # アイデアの新規性チェックロジック
    ...

この関数は、生成された各アイデアの新規性をチェックします:

  • LLMを使用して検索クエリを生成します。

  • 論文を検索し、結果を分析します。

  • アイデアが新規かどうかを決定します

メイン実行部分:

if __name__ == "__main__":
    # コマンドライン引数の解析
    parser = argparse.ArgumentParser(description="Generate AI scientist ideas")
    ...

    # クライアントの作成
    if args.model == "claude-3-5-sonnet-20240620":
        import anthropic
        ...
    elif args.model == "gpt-4o-2024-05-13" or args.model == "hybrid":
        import openai
        ...
    ...

    # アイデア生成と新規性チェックの実行
    ideas = generate_ideas(...)
    if args.check_novelty:
        ideas = check_idea_novelty(...)

この部分では:

  • コマンドライン引数を解析して、使用するモデルやタスクを設定します。

  • 適切なAI/LLMクライアント(Anthropic、OpenAI、Deepseek、LLaMAなど)を初期化します。

  • アイデア生成プロセスを実行し、オプションで新規性チェックを行います。

このスクリプトは、AI研究者が新しい実験アイデアを自動的に生成し、それらの新規性を評価するためのツールとして機能します。LLMを活用して創造的なアイデアを生成し、既存の研究文献との比較を行うことで、潜在的に価値のある新しい研究方向を特定することを目的としています。

バックオフ関数:

def on_backoff(details):
    print(
        f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
        f"calling function {details['target'].__name__} at {time.strftime('%X')}"
    )

この関数は、APIリクエストが失敗した際のバックオフ(一時的な待機)をログに記録します。これにより、リクエストの再試行時間や試行回数を追跡できます。

アイデアの構造:


生成されるアイデアは以下のような構造を持つJSONオブジェクトです:

{
    "Name": "アイデアの短い識別子",
    "Title": "アイデアのタイトル",
    "Experiment": "実装の概要",
    "Interestingness": "1から10の評価",
    "Feasibility": "1から10の評価",
    "Novelty": "1から10の評価"
}
  1. LLMとの対話:

text, msg_history = get_response_from_llm(
    idea_first_prompt.format(...),
    client=client,
    model=model,
    system_message=idea_system_prompt,
    msg_history=msg_history,
)

この部分では、フォーマットされたプロンプトをLLMに送信し、応答を取得しています。`msg_history`を使用することで、会話の文脈を維持しています。

  1. JSON抽出:

json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"

LLMの応答からJSONデータを抽出し、抽出が成功したことを確認しています。

新規性チェックのロジック:

if "decision made: novel" in text.lower():
    print("Decision made: novel after round", j)
    novel = True
    break
if "decision made: not novel" in text.lower():
    print("Decision made: not novel after round", j)
    break

LLMの応答を分析し、アイデアが新規であるかどうかを判断しています。

論文検索結果の処理:

paper_strings = []
for i, paper in enumerate(papers):
    paper_strings.append(
        """{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
            i=i,
            title=paper["title"],
            authors=paper["authors"],
            venue=paper["venue"],
            year=paper["year"],
            cites=paper["citationCount"],
            abstract=paper["abstract"],
        )
    )
papers_str = "\n\n".join(paper_strings)

検索された論文の情報を整形し、LLMに提供可能な形式に変換しています。

  1. 結果の保存:

results_file = osp.join(base_dir, "ideas.json")
with open(results_file, "w") as f:
    json.dump(ideas, f, indent=4)

生成されたアイデアと新規性チェックの結果をJSONファイルに保存しています。

  1. コマンドライン引数の処理:

parser.add_argument(
    "--experiment",
    type=str,
    default="nanoGPT",
    help="Experiment to run AI Scientist on.",
)
parser.add_argument(
    "--model",
    type=str,
    default="gpt-4o-2024-05-13",
    choices=[
        "claude-3-5-sonnet-20240620",
        "gpt-4o-2024-05-13",
        "deepseek-coder-v2-0724",
        "llama3.1-405b",
    ],
    help="Model to use for AI Scientist.",
)

これらの引数により、ユーザーは実行時に特定の実験タイプやAIモデルを指定できます。

モデル選択

クライアント初期化:

if args.model == "claude-3-5-sonnet-20240620":
    import anthropic
    print(f"Using Anthropic API with model {args.model}.")
    client_model = "claude-3-5-sonnet-20240620"
    client = anthropic.Anthropic()
elif args.model == "gpt-4o-2024-05-13" or args.model == "hybrid":
    import openai
    print(f"Using OpenAI API with model {args.model}.")
    client_model = "gpt-4o-2024-05-13"
    client = openai.OpenAI()
# ... 他のモデルのケース ...

このセクションでは、コマンドライン引数で指定されたモデルに基づいて、適切なAPIクライアントを初期化しています。各モデルに対応する適切なライブラリ(anthropic, openai等)をインポートし、クライアントオブジェクトを作成しています。

  1. ディレクトリ設定:

base_dir = osp.join("templates", args.experiment)
results_dir = osp.join("results", args.experiment)

実験テンプレートと結果を保存するディレクトリを設定しています。

アイデア生成の実行:

ideas = generate_ideas(
    base_dir,
    client=client,
    model=client_model,
    skip_generation=args.skip_idea_generation,
    max_num_generations=MAX_NUM_GENERATIONS,
    num_reflections=NUM_REFLECTIONS,
)

`generate_ideas`関数を呼び出してアイデアを生成しています。引数には、ベースディレクトリ、APIクライアント、モデル名、生成をスキップするかどうか、最大生成回数、各アイデアの改善回数が含まれます。

  1. 新規性チェックの実行(オプション):

if args.check_novelty:
    ideas = check_idea_novelty(
        ideas,
        base_dir=base_dir,
        client=client,
        model=client_model,
    )

コマンドライン引数で指定された場合、生成されたアイデアの新規性をチェックします。

  1. エラーハンドリング:
    スクリプト全体を通して、try-except文を使用してエラーを捕捉し、処理を続行できるようにしています。例:

try:
    # アイデア生成や新規性チェックのロジック
except Exception as e:
    print(f"Error: {e}")
    continue
  1. LLMとのインタラクション:
    LLMとのやり取りは複数回行われ、各ラウンドで以下のようなプロセスが繰り返されます:

  • プロンプトの生成

  • LLMへの送信

  • 応答の解析

  • 必要に応じた追加情報(論文検索結果など)の提供

  1. アイデアの反復的改善:

for j in range(num_reflections - 1):
    print(f"Iteration {j + 2}/{num_reflections}")
    text, msg_history = get_response_from_llm(
        idea_reflection_prompt.format(...),
        client=client,
        model=model,
        system_message=idea_system_prompt,
        msg_history=msg_history,
    )

この部分では、生成されたアイデアを複数回にわたって改善しています。LLMに前回の結果を提供し、改善を求めています。

  1. 早期終了条件:

if "I am done" in text:
    print(f"Idea generation converged after {j + 2} iterations.")
    break

LLMが改善の必要がないと判断した場合、反復プロセスを早期に終了します。

  1. 論文検索のバックオフ機能:

@backoff.on_exception(
    backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
)
def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]:
    # ...

この`@backoff.on_exception`デコレータは、HTTPエラーが発生した場合に指数関数的なバックオフ戦略を適用します。これにより、一時的なネットワーク問題や API の制限による失敗を自動的に処理し、再試行します。

Semantic Scholar APIの使用:

rsp = requests.get(
    "https://api.semanticscholar.org/graph/v1/paper/search",
    headers={"X-API-KEY": S2_API_KEY},
    params={
        "query": query,
        "limit": result_limit,
        "fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
    },
)

この部分では、Semantic Scholar APIを使用して論文を検索しています。クエリ、結果の制限、必要なフィールドを指定しています。

  1. 新規性チェックのロジック詳細:

for j in range(max_num_iterations):
    try:
        # LLMとのやり取り
        # ...
        if "decision made: novel" in text.lower():
            novel = True
            break
        if "decision made: not novel" in text.lower():
            break
        
        # 論文検索
        query = json_output["Query"]
        papers = search_for_papers(query, result_limit=10)
        # ...
    except Exception as e:
        print(f"Error: {e}")
        continue

このループでは、LLMとの対話と論文検索を繰り返し行い、アイデアの新規性を判断しています。LLMが明確な決定を下すか、最大反復回数に達するまで続けます。

  1. アイデアの保存と更新:

idea["novel"] = novel
# ...
with open(results_file, "w") as f:
    json.dump(ideas, f, indent=4)

新規性チェックの結果をアイデアオブジェクトに追加し、更新されたアイデアリストをJSONファイルに保存しています。

  1. プロンプトのカスタマイズ:

with open(osp.join(base_dir, "prompt.json"), "r") as f:
    prompt = json.load(f)
idea_system_prompt = prompt["system"]

システムプロンプトを外部JSONファイルから読み込んでいます。これにより、異なる実験タイプに対して異なるプロンプトを簡単に設定できます。

  1. コード注入:

with open(osp.join(base_dir, "experiment.py"), "r") as f:
    code = f.read()

実験用のPythonコードを読み込み、アイデア生成のプロンプトに注入しています。これにより、LLMは具体的な実験セットアップに基づいてアイデアを生成できます。

  1. 並行処理の可能性:
    現在のスクリプトは逐次的に実行されていますが、`generate_ideas`や`check_idea_novelty`関数を並列化することで、処理速度を向上させる可能性があります。

  2. セキュリティ上の考慮事項:
    APIキーは環境変数から読み込まれており、コード内にハードコードされていません。これはセキュリティのベストプラクティスに従っています。

  3. 拡張性:
    このスクリプトは、新しいAIモデルや実験タイプを簡単に追加できるように設計されています。新しいモデルを追加するには、適切なAPIクライアントの初期化を`if name == "main":`ブロックに追加するだけです。

  4. ログ記録:
    スクリプトは重要なステップや決定をコンソールに出力しています。より詳細なログ記録を実装することで、プロセスの追跡と分析がさらに容易になるでしょう。

このスクリプトは、AI研究の自動化における先進的なアプローチを示しています。LLMの創造的能力と既存の科学文献データベースを組み合わせることで、新しい研究アイデアの生成とその新規性の評価を自動化しています。これにより、研究者は潜在的に価値のある新しい研究方向を効率的に特定し、探求することができます。将来的には、このようなツールが研究プロセスの初期段階を大幅に加速し、より革新的で影響力のある研究成果につながる可能性があります。

この記事が気に入ったらサポートをしてみませんか?