見出し画像

一枚の画像を読み込み、NSFW/SFWかどうか、版権キャラクターの可能性があるか判定するスクリプト

需要あるかなと思って書いたら結構反響あったのでメモ。
まずこれを動かす環境作り。
誰のPCでも動くようにCPU版にします

python -m venv venv
venv\Scripts\activate
pip install opencv-python-headless numpy pillow onnxruntime huggingface-hub

以下のスクリプトを作成
ContentSafetyAnalyzer.py

import csv
import os
from pathlib import Path

import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
from huggingface_hub import hf_hub_download

# 画像のサイズ設定
IMAGE_SIZE = 448

def preprocess_image(image):
    image = np.array(image)
    image = image[:, :, ::-1]  # BGRからRGBへ変換

    # 画像を正方形にするためのパディングを追加
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

    # サイズに合わせた補間方法を選択
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    image = image.astype(np.float32)
    return image

def process_image(image_path, input_name, ort_sess, rating_tags, character_tags, general_tags, thresh):
    try:
        image = Image.open(image_path)
        image = image.convert("RGB") if image.mode != "RGB" else image
        image = preprocess_image(image)
    except Exception as e:
        print(f"画像を読み込めません: {image_path}, エラー: {e}")
        return

    img = np.array([image])
    prob = ort_sess.run(None, {input_name: img})[0][0]  # ONNXモデルからの出力

    # NSFW/SFW判定
    tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
    max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
    max_sfw_score = tag_confidences.get("general", 0)

    if max_nsfw_score > max_sfw_score:
        print("NSFWの可能性が高いです")
    else:
        print("SFWの可能性が高いです")

    # 版権キャラクターの可能性を評価
    character_tags_with_probs = []
    for i, p in enumerate(prob[4:]):
        if p >= thresh and i >= len(general_tags):
            tag_index = i - len(general_tags)
            if tag_index < len(character_tags):
                tag_name = character_tags[tag_index]
                prob_percent = round(p * 100, 2)  # 確率をパーセンテージに変換
                character_tags_with_probs.append((tag_name, f"{prob_percent}%"))

    if character_tags_with_probs:
        print(f"版権キャラクター: {character_tags_with_probs}の可能性があります")
    else:
        print("版権キャラクターの可能性が低いと思われます")

def main(MODEL_ID, image_path, thresh):
    print("Hugging Faceからモデルをダウンロード中")
    onnx_path = hf_hub_download(MODEL_ID, "model.onnx")
    csv_path = hf_hub_download(MODEL_ID, "selected_tags.csv")

    print("ONNXモデルを実行中")
    print(f"ONNXモデルのパス: {onnx_path}")

    ort_sess = ort.InferenceSession(onnx_path)

    with open(csv_path, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        header = next(reader)
        rows = list(reader)
    assert header == ["tag_id", "name", "category", "count"], f"CSVフォーマットが期待と異なります: {header}"

    rating_tags = [row[1] for row in rows if row[2] == "9"]
    character_tags = [row[1] for row in rows if row[2] == "4"]
    general_tags = [row[1] for row in rows[1:] if row[2] == "0"]

    process_image(image_path, ort_sess.get_inputs()[0].name, ort_sess, rating_tags, character_tags, general_tags, thresh)

    print("処理完了!")

if __name__ == "__main__":
    MODEL_ID = "SmilingWolf/wd-vit-tagger-v3"
    image_path = "E:/desktop/test.jpg"  # 画像のパス
    thresh = 0.35  # 閾値の設定
    main(MODEL_ID, image_path, thresh)

実行する

python ContentSafetyAnalyzer.py
まんまみーや!

この記事が気に入ったらサポートをしてみませんか?