見出し画像

OpenAI CLIPの使い方

「OpenAI CLIP」の使い方をまとめました。

1. OpenAI CLIP

OpenAI CLIP」は、OpenAIが開発した、画像とテキストの関連性をランク付けするニューラルネットワークです。従来の「教師あり学習」の画像分類では決められたラベルのみで分類するのに対し、「OpenAI CLIP」では推論時に自由にラベルを指定して画像分類することができます。

「GTP-2」や「GTP-3」で使われている「Zero-shot Learning」の技術を応用することによって、この機能を実現しています。画像分類の精度も、ImageNetのラベル付きデータを使用せずに、ResNet50と同等を実現しています。

2. OpenAI CLIPのインストール

「Google Colab」に「OpenAI CLIP」をインストールする手順は、次のとおりです。

(1) Google Colabのメニュー「編集 → ノートブックの設定」で「GPU」を選択。
(2) CUDAのバージョン取得。

# CUDAのバージョン取得
import subprocess
CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)
if CUDA_version == "10.0":
   torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
   torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
   torch_version_suffix = ""
else:
   torch_version_suffix = "+cu110"

(3) PyTorchと「OpenAI CLIP」のインストール。

# PyTorchとCLIPのインストール
!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex
!pip install git+https://github.com/openai/CLIP.git
!pip install ftfy regex tqdm

3. 推論

(1) 画像のアップロード。

・test.png

画像1

(2) 推論の実行。

import torch
import clip
from PIL import Image

# モデルの読み込み
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# 画像とテキストの準備
image = preprocess(Image.open("test.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a human", "a dog", "a cat"]).to(device)

with torch.no_grad():
    # 画像とテキストのエンコード
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    # 推論
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# 類似率の出力
print("Label probs:", probs)
Label probs: [[0.0957  0.01057 0.8936 ]]

「a human」(人間)や「a dog」(犬)より「a cat」(猫)の関連性が高いことがわかります。

4. 追加の実験

【実験1】 テキストを「cute」「cool」「scary」に変更して推論。

text = clip.tokenize(["cute", "cool", "scary"]).to(device)
Label probs: [[0.9546  0.02176 0.0239 ]]

「cool」(おとなしい)や「scary」(おっかない)より「cute」(かわいい)の関連性が高いことがわかります。

【実験2】 テキストを「black」「white」「red」に変更して推論。

text = clip.tokenize(["black", "white", "red"]).to(device)
Label probs: [[0.05106 0.9336  0.01533]]

「white」(白)が関連性が高く、「black」(黒)の関連性が少しだけで、「red」(赤)の関連性がないことがわかります。

5. clipモジュールのAPI

clipモジュールは、次のメソッドをサポートします。

・clip.available_models() : 使用可能なCLIPモデル名の取得。
・clip.load(name, device=...) : モデルの読み込み。
・clip.tokenize(text: Union[str, List[str]], context_length=77) : テキストのトークン化。

clip.load()によって返されるモデルは、次のメソッドをサポートします。

・model.encode_image(image: Tensor) : 画像のエンコード。
・model.encode_text(text: Tensor) : テキストのエンコード。
・model(image: Tensor, text: Tensor) : 類似度を推論。



この記事が気に入ったらサポートをしてみませんか?