見出し画像

Unity Barracuda 入門 / 物体検出

「Unity Barracuda」で物体検出を行う方法をまとめました。

・Unity 2019.3.0f1
・Barracuda 0.6.3

1. プロジェクトの作成

(1) Unityの3Dプロジェクトを作成。
(2) メニュー「Window → Package Manager」で「Package Manager」を開き、「Preview Package」を有効にし、「Barracuda」(0.6.3)をインストール。

2. モデルとラベルの準備

「Assets/Resources」に「TFClassify-Unity-Barracuda」からダウンロードした「モデル」と「ラベル」を追加します。

・tiny_yolo.onnx
・tiny_yolo_labels.txt

3. 物体検出の実装

物体検出の実装を行います。

(1) Hierarchyウィンドウに、「RawImage」を生成し、そこにスクリプト「WebCam」を追加。
1024x768の画面に400x400のRawImageを配置しました。

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;

using System;
using System.IO;
using Barracuda;
using UnityEngine;
using System.Linq;
using UnityEngine.UI;
using System.Collections;
using System.Threading.Tasks;
using System.Collections.Generic;

// Webカメラ
public class WebCam : MonoBehaviour
{
    // カメラ
    RawImage rawImage; // RawImage
    WebCamTexture webCamTexture; //Webカメラテクスチャ

    // 描画
    Texture2D lineTexture; // ラインテクスチャ
    GUIStyle guiStyle; // GUIスタイル

    // 情報
    IList<BoundingBox> boxes; // 検出したバウンディングボックス
    float shiftX = 512f-200f; // 描画先のX座標
    float shiftY = 384f-200f; // 描画先のY座標
    float scaleFactor = 400f/(float)Detector.IMAGE_SIZE; // 描画先のスケール

    // 推論
    public Detector detector; // 物体検出
    bool isWorking = false; // 処理中

    // スタート時に呼ばれる
    void Start ()
    {
        // Webカメラの開始
        this.rawImage = GetComponent<RawImage>();
        this.webCamTexture = new WebCamTexture();
        this.webCamTexture = new WebCamTexture(
            Detector.IMAGE_SIZE, Detector.IMAGE_SIZE, 30);
        this.rawImage.texture = this.webCamTexture;
        this.webCamTexture.Play();

        // ラインテクスチャ
        this.lineTexture = new Texture2D(1, 1);
        this.lineTexture.SetPixel(0, 0, Color.red);
        this.lineTexture.Apply();

        // GUIスタイル
        this.guiStyle = new GUIStyle();
        this.guiStyle.fontSize = 50;
        this.guiStyle.normal.textColor = Color.red;
    }

    // フレーム毎に呼ばれる
    private void Update()
    {
        // 物体検出
        TFDetect();
    }

    // 物体検出
    private void TFDetect()
    {
        if (this.isWorking)
        {
            return;
        }

        this.isWorking = true;

        // 画像の前処理
        StartCoroutine(ProcessImage(result =>
        {
            // 推論の実行
            StartCoroutine(this.detector.Predict(result, boxes =>
            {
                if (boxes.Count == 0)
                {
                    this.isWorking = false;
                    return;
                }
                this.boxes = boxes;

                // 未使用のアセットをアンロード
                Resources.UnloadUnusedAssets();
                this.isWorking = false;
            }));
        }));
    }

    // 画像の前処理
    private IEnumerator ProcessImage(System.Action<Color32[]> callback)
    {
        // 画像のクロップ(WebCamTexture → Texture2D)
        yield return StartCoroutine(CropSquare(webCamTexture, texture =>
            {
                // 画像のスケール(Texture2D → Texture2D)
                var scaled = Scaled(texture,
                    Detector.IMAGE_SIZE,
                    Detector.IMAGE_SIZE);

                // コールバックを返す
                callback(scaled.GetPixels32());
            }));
    }

    // 画像のクロップ(WebCamTexture → Texture2D)
    public static IEnumerator CropSquare(WebCamTexture texture, System.Action<Texture2D> callback)
    {
        // Texture2Dの準備
        var smallest = texture.width < texture.height ? texture.width : texture.height;
        var rect = new Rect(0, 0, smallest, smallest);
        Texture2D result = new Texture2D((int)rect.width, (int)rect.height);

        // 画像のクロップ
        if (rect.width != 0 && rect.height != 0)
        {
            result.SetPixels(texture.GetPixels(
                Mathf.FloorToInt((texture.width - rect.width) / 2),
                Mathf.FloorToInt((texture.height - rect.height) / 2),
                Mathf.FloorToInt(rect.width),
                Mathf.FloorToInt(rect.height)));
            yield return null;
            result.Apply();
        }

        yield return null;
        callback(result);
    }

    // 画像のスケール(Texture2D → Texture2D)
    public static Texture2D Scaled(Texture2D texture, int width, int height)
    {
        // リサイズ後のRenderTextureの生成
        var rt = RenderTexture.GetTemporary(width, height);
        Graphics.Blit(texture, rt);

        // リサイズ後のTexture2Dの生成
        var preRT = RenderTexture.active;
        RenderTexture.active = rt;
        var ret = new Texture2D(width, height);
        ret.ReadPixels(new Rect(0, 0, width, height), 0, 0);
        ret.Apply();
        RenderTexture.active = preRT;
        RenderTexture.ReleaseTemporary(rt);
        return ret;
    }

    // GUIの表示
    public void OnGUI()
    {
        if (this.boxes != null)
        {
            foreach (var box in this.boxes)
            {
                DrawBoundingBox(box, scaleFactor, shiftX, shiftY);
            }
        }
    }

    // バウンディングボックスの描画
    void DrawBoundingBox(BoundingBox box, float scaleFactor, float shiftX, float shiftY)
    {
        var x = box.Rect.x * scaleFactor + shiftX;
        var width = box.Rect.width * scaleFactor;
        var y = box.Rect.y * scaleFactor + shiftY;
        var height = box.Rect.height * scaleFactor;
        DrawRectangle(new Rect(x, y, width, height), 4, Color.red);
        DrawLabel(new Rect(x + 10, y + 10, 200, 20), $"{box.Label}: {(int)(box.Confidence * 100)}%");
    }

    // ラベルの描画
    void DrawLabel(Rect pos, string text)
    {
        GUI.Label(pos, text, this.guiStyle);
    }

    // 矩形の描画
    void DrawRectangle(Rect area, int frameWidth, Color color)
    {
        Rect lineArea = area;
        lineArea.height = frameWidth;
        GUI.DrawTexture(lineArea, lineTexture);
        lineArea.y = area.yMax - frameWidth;
        GUI.DrawTexture(lineArea, lineTexture);
        lineArea = area;
        lineArea.width = frameWidth;
        GUI.DrawTexture(lineArea, lineTexture);
        lineArea.x = area.xMax - frameWidth;
        GUI.DrawTexture(lineArea, lineTexture);
    }
}

◎ Webカメラの開始
Start()ではWebカメラの開始を行なっています。

◎ 画像分類の前処理
ProcessImage()では画像の前処理を行なっています。
Webカメラから取得した画像を、CropSquare()で短辺x短辺の正方形にした後、Scaled()で416x416の画像に変換しています。

◎ 推論の実行
「Classifier」のPredict()を呼んで推論の実行を行います。

◎ バウンディックスの描画
OnGUI()で推論結果として取得したバウンディックスボックスを描画します。

(2) Hierarchyウィンドウに、空のゲームオブジェクト「Detector」を生成し、そこにスクリプト「Detector」を追加。

using System;
using Barracuda;
using System.Linq;
using UnityEngine;
using System.Collections;
using System.Collections.Generic;
using System.Text.RegularExpressions;

// 物体検出
public class Detector : MonoBehaviour
{
    // アンカー
    private float[] Anchors = new float[]
    {
        1.08F, 1.19F, 3.42F, 4.41F, 6.63F, 11.38F, 9.42F, 5.11F, 16.62F, 10.52F
    };

    // リソース
    public NNModel modelFile; // モデル
    public TextAsset labelsFile; // ラベル

    // パラメータ
    public const int IMAGE_SIZE = 416; // 画像サイズ
    private const int IMAGE_MEAN = 0;
    private const float IMAGE_STD = 1f;
    private const string INPUT_NAME = "image";
    private const string OUTPUT_NAME = "grid";

    // 出力のパース
    private const int ROW_COUNT = 13; // 行
    private const int COL_COUNT = 13; // 列
    private const int BOXES_PER_CELL = 5; // セル毎のボックス数
    private const int BOX_INFO_FEATURE_COUNT = 5; // ボックス情報の特徴数
    private const int CLASS_COUNT = 20; // クラス数
    private const float CELL_WIDTH = 32; // セル幅
    private const float CELL_HEIGHT = 32; // セル高さ

    // 出力のフィルタリング
    private const float MINIMUM_CONFIDENCE = 0.3f; // 最小検出信頼度

    // 推論
    private IWorker worker; // ワーカー
    private string[] labels; // ラベル

    // スタート時に呼ばれる
    void Start()
    {
        // ラベルとモデルの読み込み
        this.labels = Regex.Split(this.labelsFile.text, "\n|\r|\r\n")
            .Where(s => !String.IsNullOrEmpty(s)).ToArray();
        var model = ModelLoader.Load(this.modelFile);

        // ワーカーの生成
        this.worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, model);
    }

    // 推論
    public IEnumerator Predict(Color32[] picture, System.Action<IList<BoundingBox>> callback)
    {
        // 入力テンソルの生成
        using (var tensor = TransformInput(picture, IMAGE_SIZE, IMAGE_SIZE))
        {
            // 入力の生成
            var inputs = new Dictionary<string, Tensor>();
            inputs.Add(INPUT_NAME, tensor);

            // 推論の実行
            yield return StartCoroutine(worker.ExecuteAsync(inputs));

            // 出力の生成
            var output = worker.PeekOutput(OUTPUT_NAME);
            var results = ParseOutputs(output);
            var boxes = FilterBoundingBoxes(results, 5, MINIMUM_CONFIDENCE);

            // 結果を返す
            callback(boxes);
        }
    }

    // 入力テンソルの生成
    public static Tensor TransformInput(Color32[] pic, int width, int height)
    {
        float[] floatValues = new float[width * height * 3];
        for (int i = 0; i < pic.Length; ++i)
        {
            var color = pic[i];
            floatValues[i * 3 + 0] = (color.r - IMAGE_MEAN) / IMAGE_STD;
            floatValues[i * 3 + 1] = (color.g - IMAGE_MEAN) / IMAGE_STD;
            floatValues[i * 3 + 2] = (color.b - IMAGE_MEAN) / IMAGE_STD;
        }
        return new Tensor(1, height, width, 3, floatValues);
    }

    // 出力のパース
    private IList<BoundingBox> ParseOutputs(Tensor output, float threshold = .3F)
    {
        var boxes = new List<BoundingBox>();
        for (int cy = 0; cy < COL_COUNT; cy++)
        {
            for (int cx = 0; cx < ROW_COUNT; cx++)
            {
                for (int box = 0; box < BOXES_PER_CELL; box++)
                {
                    var channel = (box * (CLASS_COUNT + BOX_INFO_FEATURE_COUNT));

                    // バウンディングボックスの寸法と信頼度の取得
                    var dimensions = GetBoundingBoxDimensions(output, cx, cy, channel);
                    float confidence = GetConfidence(output, cx, cy, channel);
                    if (confidence < threshold)
                    {
                        continue;
                    }

                    // スコアが最大のクラスのINDEXとスコアの取得
                    float[] predictedClasses = GetPredictedClasses(output, cx, cy, channel);
                    var (topResultIndex, topResultScore) = GetTopResult(predictedClasses);
                    var topScore = topResultScore * confidence;
                    if (topScore < threshold)
                    {
                        continue;
                    }

                    // バウンディングボックスをセルにマッピング
                    var mappedBoundingBox = MapBoundingBoxToCell(cx, cy, box, dimensions);
 
                    // バウンディングボックスの追加
                    var boundingBox = new BoundingBox();
                    boundingBox.Rect = new Rect(
                        (mappedBoundingBox.x - mappedBoundingBox.width / 2),
                        (mappedBoundingBox.y - mappedBoundingBox.height / 2),
                        mappedBoundingBox.width,
                        mappedBoundingBox.height);
                    boundingBox.Confidence = topScore;
                    boundingBox.Label = labels[topResultIndex];
                    boxes.Add(boundingBox);
                }
            }
        }
        return boxes;
    }

    // バウンディングボックスの抽出
    private Rect GetBoundingBoxDimensions(Tensor output, int x, int y, int channel)
    {
        return new Rect(
            output[0, x, y, channel],
            output[0, x, y, channel + 1],
            output[0, x, y, channel + 2],
            output[0, x, y, channel + 3]);
    }

    // 信頼度の抽出
    private float GetConfidence(Tensor output, int x, int y, int channel)
    {
        return Sigmoid(output[0, x, y, channel + 4]);
    }

    // 予測クラスの抽出
    private float[] GetPredictedClasses(Tensor output, int x, int y, int channel)
    {
        float[] predictedClasses = new float[CLASS_COUNT];
        int predictedClassOffset = channel + BOX_INFO_FEATURE_COUNT;
        for (int predictedClass = 0; predictedClass < CLASS_COUNT; predictedClass++)
        {
            predictedClasses[predictedClass] = output[0, x, y, predictedClass + predictedClassOffset];
        }
        return Softmax(predictedClasses);
    }

    // スコアが最大のクラスのINDEXとスコアの取得
    private ValueTuple<int, float> GetTopResult(float[] predictedClasses)
    {
        return predictedClasses
            .Select((predictedClass, index) => (Index: index, Value: predictedClass))
            .OrderByDescending(result => result.Value)
            .First();
    }

    // バウンディングボックスをセルにマッピング
    private Rect MapBoundingBoxToCell(int x, int y, int box, Rect dimensions)
    {
        return new Rect(
            ((float)y + Sigmoid(dimensions.x)) * CELL_WIDTH,
            ((float)x + Sigmoid(dimensions.y)) * CELL_HEIGHT,
            (float)Math.Exp(dimensions.width) * CELL_WIDTH * Anchors[box * 2],
            (float)Math.Exp(dimensions.height) * CELL_HEIGHT * Anchors[box * 2 + 1]);
    }

    // バウンディングボックスのフィルタリング
    private IList<BoundingBox> FilterBoundingBoxes(IList<BoundingBox> boxes, int limit, float threshold)
    {
        var activeCount = boxes.Count;
        var isActiveBoxes = new bool[boxes.Count];
        for (int i = 0; i < isActiveBoxes.Length; i++)
        {
            isActiveBoxes[i] = true;
        }
        var sortedBoxes = boxes.Select((b, i) => new { Box = b, Index = i })
            .OrderByDescending(b => b.Box.Confidence)
            .ToList();
        var results = new List<BoundingBox>();
        for (int i = 0; i < boxes.Count; i++)
        {
            if (isActiveBoxes[i])
            {
                var boxA = sortedBoxes[i].Box;
                results.Add(boxA);
                if (results.Count >= limit)
                {
                    break;
                }
                for (var j = i + 1; j < boxes.Count; j++)
                {
                    if (isActiveBoxes[j])
                    {
                        var boxB = sortedBoxes[j].Box;
                        if (IntersectionOverUnion(boxA.Rect, boxB.Rect) > threshold)
                        {
                            isActiveBoxes[j] = false;
                            activeCount--;
                            if (activeCount <= 0)
                            {
                                break;
                            }
                        }
                    }
                }
                if (activeCount <= 0)
                {
                    break;
                }
            }
        }
        return results;
    }

    // IoU(評価指標)の計算
    private float IntersectionOverUnion(Rect boundingBoxA, Rect boundingBoxB)
    {
        var areaA = boundingBoxA.width * boundingBoxA.height;
        if (areaA <= 0)
        {
            return 0;
        }
        var areaB = boundingBoxB.width * boundingBoxB.height;
        if (areaB <= 0)
        {
            return 0;
        }
        var minX = Math.Max(boundingBoxA.xMin, boundingBoxB.xMin);
        var minY = Math.Max(boundingBoxA.yMin, boundingBoxB.yMin);
        var maxX = Math.Min(boundingBoxA.xMax, boundingBoxB.xMax);
        var maxY = Math.Min(boundingBoxA.yMax, boundingBoxB.yMax);
        var intersectionArea = Math.Max(maxY - minY, 0) * Math.Max(maxX - minX, 0);
        return intersectionArea / (areaA + areaB - intersectionArea);
    }

    // シグモイド
    private float Sigmoid(float value)
    {
        var k = (float)Math.Exp(value);
        return k / (1.0f + k);
    }

    // ソフトマックス
    private float[] Softmax(float[] values)
    {
        var maxVal = values.Max();
        var exp = values.Select(v => Math.Exp(v - maxVal));
        var sumExp = exp.Sum();
        return exp.Select(v => (float)(v / sumExp)).ToArray();
    }
}

// バウンディングボックス
public class BoundingBox
{
    public string Label; // ラベル
    public float Confidence; // 信頼度
    public Rect Rect; //矩形
}

◎ モデルとラベルの読み込み
Start()でモデルとラベルの読み込みを行います。

◎ 推論の実行
Predict()で推論の実行を行います。
結果は「BoundingBox」のリストに格納されます。

(4) 「Detector」をWebCamの「Detector」にドラッグ&ドロップ。

(5) Assetsのモデルとラベルを「Detector」の「Model File」と「Labels File」にドラッグ&ドロップ。

(6) Webカメラのあるパソコン(またはスマートフォン)で、「Unity Editor」のPlayボタンを押して実行。
推論結果が画面に表示されます。

画像1


この記事が気に入ったらサポートをしてみませんか?