見出し画像

CLIPSeg によるZero-Shot画像セグメンテーションを試す

CLIPSeg によるZero-Shot画像セグメンテーションを試したのでまとめました。

1. CLIPSeg

「CLIPSeg」は、学習なしに、ほぼすべての種類のオブジェクトを「画像セグメンテーション」できるAIモデルです。

画像セグメンテーションは、画像内に何があるか(分類)、オブジェクトが画像内のどこにあるか(検出)だけでなく、それらのオブジェクトの輪郭を知ることができます。

ロボットが物体の形を知り、オブジェクトを正しくつかむためや、画像インペイントと組み合わせて、画像の特定の部分を置き換えるために利用できます。

2. Colabでの実行

Google Colabでの実行手順は、次のとおりです。

(1) メニュー「編集→ノートブックの設定」で、「ハードウェアアクセラレータ」に「GPU」を選択。

(2)  パッケージのインストール。

# パッケージのインポート
!pip install -q transformers

(3) プロセッサとモデルの準備。

from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# プロセッサとモデルの準備
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

(4) 画像のアップロード。
左端のフォルダアイコンでファイル一覧を表示し、画像をアップロードします。

アップロードするのは、次の画像になります。

・image.jpg

(5) 画像の読み込み。

from PIL import Image

# 画像の読み込み
image = Image.open("image.jpg")
image

(6) プロンプトの準備。
ハンバーガー(hamburger)、フライドポテト(fries)、ドリンク(drink)の3つをセグメンテーションします。

# プロンプトの準備
prompts = ["hamburger", "fries", "drink"]

(7) 推論の実行。

import torch

# 推論の実行
inputs = processor(
    text=prompts, 
    images=[image] * len(prompts), 
    padding="max_length", 
    return_tensors="pt")
with torch.no_grad():
  outputs = model(**inputs)
preds = outputs.logits.unsqueeze(1)

(8) 結果の表示。

import matplotlib.pyplot as plt

# 結果の表示
_, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len(prompts))];
[ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];

画像のハンバーガー、フライドポテト、ドリンクの場所が認識できていることがわかります。

次回



この記事が気に入ったらサポートをしてみませんか?