見出し画像

TensorFlow.js 入門 / 物体検出

「TensorFlow.js」を使って、ブラウザで「物体検出」を行います。Chromeで動作確認しています。

1. 物体検出

「TensorFlow.js」による物体検出のコードは、次のとおり。

<!-- TensorFlow.jsの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

<!-- coco-ssdモデルの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"> </script>

<!-- テストに使用する画像 -->
<img id="img" src="cat.jpg"/>

<script>
  // 画像の取得
  const img = document.getElementById('img');

  // モデルの読み込み
  cocoSsd.load().then(model => {
    // 物体検出
    model.detect(img).then(predictions => {
      console.log('Predictions: ', predictions);
    });
  });
</script>

用意した画像(cat.jpg)に応じて、JavaScriptコンソールに次のような結果が出力されます。

画像2

2. パッケージのインポート

<script>でパッケージをインポートします。

<!-- TensorFlow.jsの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

<!-- coco-ssdモデルの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"> </script>

3. 物体検出モデルの読み込み

物体検出モデル(coco-ssd)の読み込みを行うには、cocoSsd.load()を使います。

export interface ModelConfig {
  base?: ObjectDetectionBaseModel;
  modelUrl?: string;
}

cocoSsd.load(config: ModelConfig = {});

base : ベースCNNモデルの指定。 (mobilenet_v1 / mobilenet_v2 / lite_mobilenet_v2)
modelUrl : モデルのカスタムURLを指定。

4. 物体検出の実行

物体検出を実行するには、model.detect()を使います。

model.detect(
  img: tf.Tensor3D | ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement, 
  maxDetectionSize: number
)

img:テンソル、または画像要素
maxNumBoxes:検出するバウンディングボックスの最大数。 デフォルトは20。

結果は次のように出力されます。

[{
  bbox: [x, y, width, height],
  class: "person",
  score: 0.8380282521247864
 }, {
  bbox: [x, y, width, height],
  class: "kite",
  score: 0.74644153267145157
}]

5. Webカメラを使った物体検出

Webカメラを使った物体検出の例は、次のとおり。

<html>
  <head>
    <!-- TensorFlow.jsの読み込み -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

    <!-- coco-ssdモデルの読み込み -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"></script>

    <script>
      // テンソルをキャンバスに描画
      const renderToCanvas = async (ctx, a) => {
        const [height, width] = a.shape
        const imageData = new ImageData(width, height)
        const data = await a.data()
        for (let i = 0; i < height * width; ++i) {
          const j = i * 4
          const k = i * 3
          imageData.data[j + 0] = data[k + 0]
          imageData.data[j + 1] = data[k + 1]
          imageData.data[j + 2] = data[k + 2]
          imageData.data[j + 3] = 255
        }
        ctx.putImageData(imageData, 0, 0)
      }

      // バウンディングボックスの描画
      const drawBBox = (ctx, bbox, name) => {
        // 枠の描画
        ctx.strokeStyle = 'red'
        ctx.fillStyle = 'red'
        ctx.strokeRect(bbox[0], bbox[1], bbox[2], bbox[3])
        ctx.fillRect(bbox[0], bbox[1]-20, bbox[2], 20)

        // 名前の描画
        ctx.fillStyle = "white"
        ctx.font = 'bold 20px sans-serif'
        ctx.textAlign = "left"
        ctx.textBaseline = "top"
        ctx.fillText(name, bbox[0]+8, bbox[1]-20, bbox[2])
      }

      // 物体検出の開始
      const startDetect = () => {
        cocoSsd.load()
          .then(model => {
            const webcamElement = document.getElementById('webcam')
            window.requestAnimationFrame(onFrame.bind(null, model, webcamElement))
          })
      }

      // フレーム毎に呼ばれる
      const onFrame = async (model, webcamElement) => {
        // 画像分類
        const tensor = tf.browser.fromPixels(webcamElement)
        const predictions = await model.detect(tensor)

        // キャンバスの準備
        const canvas = document.getElementById('canvas')
        const [height, width] = tensor.shape
        canvas.width = width
        canvas.height = height

        // キャンバスの描画
        const ctx = canvas.getContext('2d')
        await renderToCanvas(ctx, tensor)
        for (let i = 0; i < predictions.length; i++) {
          drawBBox(ctx, predictions[i].bbox, predictions[i].class)
        }

        // 次フレーム
        setTimeout(() => {
          window.requestAnimationFrame(onFrame.bind(null, model, webcamElement))
        }, 1000)
      }

      // Webカメラの開始
      const constraints = {
        audio: false,
        video: true
      }
      navigator.mediaDevices.getUserMedia(constraints)
        // 成功時に呼ばれる
        .then((stream) => {
            const video = document.querySelector('video')
            video.srcObject = stream

            // 物体検出の開始
            startDetect()
        })
        // エラー時に呼ばれる
        .catch((error) => {
            const errorMsg = document.querySelector('#errorMsg')
            errorMsg.innerHTML += `<p>${error.name}</p>`
        })
    </script>
  </head>
  <body>
    <video id="webcam" width="320" height="240" autoplay playsinline></video>
    <canvas id="canvas"></canvas>
    <div id="errorMsg"></div>
  </body>
</html>


この記事が気に入ったらサポートをしてみませんか?