Google Colab で SAM 2 を試す
「Google Colab」で「SAM 2」を試したのでまとめました。
1. SAM 2
「SAM 2」(Segment Anything Model 2) は、画像や動画のセグメンテーションを行うためのAIモデルです。目的のオブジェクトを示す情報 (XY座標など) が与えられた場合に、オブジェクトマスクを予測します。
具体的に何ができるかは、以下のデモページが参考になります。
2. セットアップ
セットアップの手順は、次のとおりです。
(1) パッケージのインストール。
# パッケージのインストール
!git clone https://github.com/facebookresearch/segment-anything-2.git
%cd ./segment-anything-2
!pip install -e .
(2) チェックポイントのダウンロード。
# チェックポイントのダウンロード
%cd checkpoints
!./download_ckpts.sh
%cd ..
3. 画像のセグメンテーション
画像のセグメンテーションの手順は、次のとおりです。
(1) 左端のフォルダアイコンで、画像を「segment-anything-2」フォルダにアップロード。
・sample.jpg
(2) 画像の準備。
from PIL import Image
import numpy as np
# 画像の準備
image = Image.open("sample.jpg")
image = np.array(image.convert("RGB"))
(3) ユーティリティ関数の準備。
画像上にに各種情報を表示する関数になります。
# マスクの表示
def show_mask(mask, ax, random_color=False, borders = True):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
if borders:
import cv2
contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# Try to smooth contours
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
ax.imshow(mask_image)
# 点群の表示
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
# ボックスの表示
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
# マスク群の表示
def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca(), borders=borders)
if point_coords is not None:
assert input_labels is not None
show_points(point_coords, input_labels, plt.gca())
if box_coords is not None:
# boxes
show_box(box_coords, plt.gca())
if len(scores) > 1:
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.show()
(4) Predictorの準備。
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Predictorの準備
checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
(5) Predictorに画像を指定。
SAM2ImagePredictor.set_image()で画像を指定します。
# Predictorに画像を指定
predictor.set_image(image)
(6) 目的のオブジェクトの指定。
目的のオブジェクト上の位置 (x, y) と、ラベル (1: 前景点、0:背景点) を選択します。
# 目的のオブジェクトの指定
input_point = np.array([[500, 300]]) # 位置
input_label = np.array([1]) # ラベル (1:前景点、0:背景点)
(7) 目的のオブジェクトの確認。
☆で位置が示されます。
import matplotlib.pyplot as plt
# 目的のオブジェクトの確認
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
(8) Predictorで予測。
# Predictorで予測
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind] # マスク
scores = scores[sorted_ind] # スコア
logits = logits[sorted_ind] # ロジット
multimask_output=True (デフォルト) の場合、3つのマスクを出力します。スコアはマスクの品質の推定値です。単一のポイントなど、あいまいなプロンプトの場合、単一のマスクのみが必要な場合でも、multimask_output=True を使用することが推奨されます。
(9) マスクの確認。
# マスクの確認
show_masks(image, masks, scores,
point_coords=input_point,
input_labels=input_label,
borders=True
)
4. 動画のセグメンテーション
動画のセグメンテーションの手順は、次のとおりです。
(1) 左端のフォルダアイコンで、動画を「segment-anything-2」フォルダにアップロード。
・sample.mp4
(2) 動画をJPEGフレームのリストに変換。
「videos」フォルダに00000.jpg〜00089.jpgが生成されます。
# 動画をJPEGフレームのリストに変換
!mkdir videos
!ffmpeg -i sample.mp4 -q:v 2 -start_number 0 ./videos/'%05d.jpg'
(3) JPEGフレーム名の準備
import os
# JPEGフレーム名の準備
video_dir = "./videos"
frame_names = [
p for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
(4) ユーティリティ関数の準備。
画像上にに各種情報を表示する関数になります。
# マスクの表示
def show_mask(mask, ax, obj_id=None, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
# 点群の表示
def show_points(coords, labels, ax, marker_size=200):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
# ボックスの表示
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
(5) Predictorの準備。
from sam2.build_sam import build_sam2_video_predictor
# Predictorの準備
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
(6) 動画の推論状態の初期化。
「SAM 2」では、インタラクティブな動画セグメンテーションにステートフル推論が必要なため、動画の推論状態を初期化する必要があります。
# 動画の推論状態の初期化
inference_state = predictor.init_state(video_path=video_dir)
(7) 動画の推論状態のリセット。
このinference_stateを使用してトラッキングを実行した場合、reset_state でリセットする必要があります。
# 動画の推論状態のリセット
predictor.reset_state(inference_state)
(8) 目的のオブジェクトの指定。
import numpy as np
# 目的のオブジェクトの指定
ann_frame_idx = 0
ann_obj_id = 1
input_point = np.array([[800, 400]], dtype=np.float32) # 位置
input_label = np.array([1], np.int32) # ラベル (1:前景点、0:背景点)
(9) 指定フレームをPredictorで予測。
import cv2
import matplotlib.pyplot as plt
# 指定フレームをPredictorで予測
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=input_point,
labels=input_label,
)
# 確認
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
image = cv2.imread(video_dir + "/" + frame_names[ann_frame_idx])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
(10) 全フレームをPredictorで予測。
# 全フレームをPredictorで予測
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
(11) 結果画像フレームの保存。
「output」フォルダに00000.jpg〜00089.jpgが保存されます。
!mkdir output
# 結果画像フレームの保存
plt.close("all")
for out_frame_idx in range(len(frame_names)):
plt.figure(figsize=(6, 4))
plt.title(f"frame {out_frame_idx}")
plt.axis("off")
plt.tight_layout(pad=0)
# 元画像の描画
image = cv2.imread(video_dir + "/" + frame_names[out_frame_idx])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.imshow(image)
# マスクの描画
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
# 結果画像フレームの保存
basename = os.path.basename(frame_names[out_frame_idx])
output_frame = os.path.join("output", basename)
plt.savefig(output_frame)
plt.close()
(12) 結果画像フレームを動画に変換。
# 結果画像フレームを動画に変換
!ffmpeg -framerate 24 -i ./output/%05d.jpg -c:v libx264 -r 30 -pix_fmt yuv420p ./output.mp4
この記事が気に入ったらサポートをしてみませんか?