見出し画像

Modal 上でComfyUIで FLUX.1 schnellモデルをAPIから画像生成


概要

前回(↑)を発展させて、 FLUX.1 schnell モデルの画像生成を、複数のプロンプトを順次自動で行うようにします。
Modal 上で ComfyUI を動作させ、それをローカルPCからAPI経由で行います。
これによって同じWorkflow で異なるプロンプトを使って画像を複数生成し、色々試すのが楽にできることを意図しています。
※筆者の環境は Ubuntu22.04.04 LTSで、その環境でのみ確認をしています。Python 環境があり、Modalをターミナルから起動できれば問題ないと思います。

使用準備

  1. Modal 上でComfyUI を立ち上げられることを確認します。

  1. ComfyUI のFLUXのページの中程にある Flux Schnell の画像をローカルにダウンロードします。保存場所はどこでも大丈夫です。

  2. 立ち上げているComfyUIの画面に、保存した画像をドラッグ&ドロップします。

  3. 画面右側のメニューボタンが集まっているところの右上に歯車の形をしたボタンがありますので、そこをクリックします。

  4. そこに Enable dev option という欄がありますので、チェックをいれます。すると、メニューボタンに Save (API format) というボタンが増えます。

  5. Save (API format) というボタンをクリックし、ファイルを保存します。ファイルは flux1_schnell_workflow_api.json としましょう。このWorkflow で画像を生成します。自分自身で改良したWorkflowを使用することもできます。

  6. このファイルを開き内容を確認します。Workflow のJSONファイルを開くと番号が振られていますが、この番号がWorkflow の一つ一つのノードに対応しています。"class_type" が"CLIPTextEncode", "SaveImage" となっている番号を控えます。

  7. ブラウザの画面も閉じ、ComfyUI を立ち上げたターミナルもControl-C で終了します。

Modal で動作させるサーバーアプリケーション

Modal の ComfyUI Example を参考にします。
Image は前回同様に FLUX.1 のモデル、エンコーダ、VAEをダウンロードするように変更します。
ただし、APIからComfyUI を動作させて場合には、comfy_cli (1.0.36) ですと、画像生成に30秒以上かかる場合に、Timeoutして停止してしまいます。

Executing workflow: /root/53c1a16b708243e688eb9c450255a0b3.json
━━━                                        8% 0:00:30

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /usr/local/lib/python3.11/site-packages/websocket/[*socket.py:118](http://socket.py:118/) in recv     │
│                                                                              │
│   115 │   │   if sock.gettimeout() == 0:                                     │
│   116 │   │   │   bytes* = sock.recv(bufsize)                                │
│   117 │   │   else:                                                          │
│ ❱ 118 │   │   │   bytes_ = _recv()                                           │
│   119except TimeoutError:                                               │
│   120 │   │   raise WebSocketTimeoutException("Connection timed out")        │
│   121except socket.timeout as e:                                        │
│                                                                              │
│ ╭───────────────────────────────── locals ─────────────────────────────────╮ │
│ │   _recv = <function recv.<locals>._recv at 0x7ea9fce011c0>               │ │
│ │ bufsize = 2                                                              │ │
│ │    sock = <socket.socket fd=4, family=2, type=1, proto=6,                │ │
│ │           laddr=('127.0.0.1', 57641), raddr=('127.0.0.1', 8188)>         │ │
│ ╰──────────────────────────────────────────────────────────────────────────╯ │
│                                                                              │
│ /usr/local/lib/python3.11/site-packages/websocket/_socket.py:97 in _recv     │
│                                                                              │
│    94 │                                                                      │
│    95def _recv():                                                       │
│    96 │   │   try:                                                           │
│ ❱  97 │   │   │   return sock.recv(bufsize)                                  │
│    98 │   │   except SSLWantReadError:                                       │
│    99 │   │   │   pass                                                       │
│   100 │   │   except socket.error as exc:                                    │
│                                                                              │
│ ╭───────────────────────────────── locals ─────────────────────────────────╮ │
│ │    bufsize = 2                                                           │ │
│ │ error_code = None                                                        │ │
│ │       sock = <socket.socket fd=4, family=2, type=1, proto=6,             │ │
│ │              laddr=('127.0.0.1', 57641), raddr=('127.0.0.1', 8188)>      │ │
│ ╰──────────────────────────────────────────────────────────────────────────╯ │
╰──────────────────────────────────────────────────────────────────────────────╯
TimeoutError: timed out

APIではなく、ブラウザからでは問題がありませんので、これらの違いによるものだと推測しますが、現状ではcomfy_cli に修正が必要な様子です。

ご自身のcomfy_cli/command/run.py のexecute() の内部に存在する、以下のtimeout の部分を数字で上書きして修正します。筆者は 300 (=つまり300秒) を入力してTimeout は発生しなくなりました。

    execution = WorkflowExecution(
        workflow, host, port, verbose, progress, local_paths, timeout
    )

さて Image を完成させるために、変更したローカルの run.py をコピーします。 筆者の環境では、このファイルの一つ上の階層にVirtualenv のフォルダがありますのでそこの内部保存されており、 "../modal_comfy/lib/python3.10/site-pakcages" があり、その内部の run.py を変更して、それをコピーしています。
(もっと賢いやり方があれば、知りたい・・)

image = (
  ....
  .copy_local_file( ## comfy_cli/command/run.py::execute() must be modified with timeout of ~300s.
    local_path="../modal_comfy/lib/python3.10/site-packages/comfy_cli/command/run.py",
    remote_path=" /usr/local/lib/python3.11/site-packages/comfy_cli/command",
  )
)

その他、Appのname には適当な名前をつけます。こちらでつけた名前にAPI でアクセスするURLを対応させる必要があります。
@app.function と@modal.webserver のデコレータがある ui() はWebサービスを開始させています。

app = modal.App(name="comfy-api", image=image)

@app.function(
  gpu="T4",
  concurrency_limit=1,
  allow_concurrent_inputs=10,
  container_idle_timeout=600,
  timeout=3600,
)
@modal.web_server(8000, startup_timeout=60)
def ui():
  subprocess.Popen("comfy launch -- --listen 0.0.0.0 --port 8000", shell=True)

使用準備で保存したWorkflow のJSONファイルを同じフォルダに配置し、ファイル名を workflow_filename で指定し、それを mount します。
使用準備で控えた"class_type" が"CLIPTextEncode"となっている番号を
・workflow_data["6"]["inputs"]["text"] = item["prompt"] の 6 のところに、
・workflow_data["9"]["inputs"]["filename_prefix"] = client_idの9 のところに
入力します。
(もっと賢いやり方はある・・けどサボりました)

workflow_filename = "flux1_schnell_workflow_api.json"
### Background container for API
import json
import uuid
from pathlib import Path
from typing import Dict
@app.cls(
  concurrency_limit=1,
  allow_concurrent_inputs=10,
  container_idle_timeout=600,
  gpu="T4",
  mounts=[
    modal.Mount.from_local_file(
      Path(__file__).parent / workflow_filename,
      "/root/" + workflow_filename
    ),
  ],
)
class ComfyUI:
  @modal.enter()
  def launch_comfy_background(self):
    print(f"Comfy process is going to run in background.")
    cmd = "comfy launch --background"
    subprocess.run(cmd, shell=True, check=True)
    print(f"Comfy process is running.")
  
  @modal.method()
  def infer(self, workflow_path: str):
    print(f"Workflow is going to run.")
    cmd = f"comfy run --workflow {workflow_path} --wait"
    subprocess.run(cmd, shell=True, check=True)
    print(f"Workflow is done.")
    
    output_dir = "/root/comfy/ComfyUI/output"
    workflow = json.loads(Path(workflow_path).read_text())
    file_prefix = [
        node.get("inputs")
        for node in workflow.values()
        if node.get("class_type") == "SaveImage"
    ][0]["filename_prefix"]
    print(f"Output is saved in {output_dir} .")

    for f in Path(output_dir).iterdir():
      if f.name.startswith(file_prefix):
        return f.read_bytes()
  
  @modal.web_endpoint(method="POST")
  def api(self, item: Dict):
    from fastapi import Response
    print("API hooked.")
    workflow_data = json.loads(
      (Path("/root") / workflow_filename).read_text()
    )
    print(f"Prompt received: {item['prompt']}")
    workflow_data["6"]["inputs"]["text"] = item["prompt"]
    client_id = uuid.uuid4().hex
    workflow_data["9"]["inputs"]["filename_prefix"] = client_id
    new_workflow_file = f"{client_id}.json"
    print(f"Workflow to run is saved: {new_workflow_file}")
    json.dump(workflow_data, Path(new_workflow_file).open("w"))
    
    img_bytes=self.infer.local(new_workflow_file)
    
    return Response(img_bytes, media_type="image/jpeg")

ローカルで動作させるクライアントアプリケーション

サーバーアプリケーションのURL とプロンプトを渡し、HTTPリクエストを送って成功(200) レスポンスがあれば生成された画像データを受け取って、同じフォルダに Outputフォルダに保存します。

import argparse
from pathlib import Path
import time
import requests

OUTPUT_DIR = Path(__file__).parent / "output"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

def main(args: argparse.Namespace):
  url = args.modalurl
  if not url.endswith('/'):
    url = url + '/'

  print(f"URL: {url}")
  data = {
    "prompt": args.prompt,
  } 
  print(f"Client for FLUX.1 Schnell Workflow running at server-side.")
  print(f"Sending request to {url}")
  # print(f"Prompt: {data['prompt']}")
  print(f"Parameter: {data}")
  print("Waiting for response.")
  
  start_time = time.time()
  res = requests.post(url, json=data)
  if res.status_code == 200:
    end_time = time.time()
    print(f"Successfully finished in {end_time - start_time} s.")
    filename = OUTPUT_DIR / f"{slugify(args.prompt)}.png"
    filename.write_bytes(res.content) 
    print(f"Saved to '{filename} .")
  else:
    if res.status_code == 404:
      print(f"Workflow API not found at {url}")
    res.raise_for_status()
  
def parse_args(arglist: list[str]) -> argparse.Namespace:
  parser = argparse.ArgumentParser()
  
  parser.add_argument(
    "--modalurl",
    type=str,
    required=True,
    help="URL for modal ComfyUI app with the defined FLUX.1 Schnell image generation workflow.",
  )
  parser.add_argument(
    "--prompt",
    type=str,
    required=True,
    help="What to draw the by generative AI model."
  )
  return parser.parse_args(arglist[1:])

def slugify(s: str) -> str:
  return s.lower().replace(" ","-").replace(".","-").replace("/","-").replace(",","-")[:64]

import sys

if __name__ == "__main__":
  args = parse_args(sys.argv)
  main(args)

バッチ処理

バッチ処理を python で書くとこのような感じです。

ここから先は

915字 / 2画像

¥ 500

この記事が気に入ったらサポートをしてみませんか?