TensorFlow.js 入門 / 物体検出
「TensorFlow.js」を使って、ブラウザで「物体検出」を行います。Chromeで動作確認しています。
1. 物体検出
「TensorFlow.js」による物体検出のコードは、次のとおり。
<!-- TensorFlow.jsの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- coco-ssdモデルの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"> </script>
<!-- テストに使用する画像 -->
<img id="img" src="cat.jpg"/>
<script>
// 画像の取得
const img = document.getElementById('img');
// モデルの読み込み
cocoSsd.load().then(model => {
// 物体検出
model.detect(img).then(predictions => {
console.log('Predictions: ', predictions);
});
});
</script>
用意した画像(cat.jpg)に応じて、JavaScriptコンソールに次のような結果が出力されます。
2. パッケージのインポート
<script>でパッケージをインポートします。
<!-- TensorFlow.jsの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- coco-ssdモデルの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"> </script>
3. 物体検出モデルの読み込み
物体検出モデル(coco-ssd)の読み込みを行うには、cocoSsd.load()を使います。
export interface ModelConfig {
base?: ObjectDetectionBaseModel;
modelUrl?: string;
}
cocoSsd.load(config: ModelConfig = {});
・ base : ベースCNNモデルの指定。 (mobilenet_v1 / mobilenet_v2 / lite_mobilenet_v2)
・ modelUrl : モデルのカスタムURLを指定。
4. 物体検出の実行
物体検出を実行するには、model.detect()を使います。
model.detect(
img: tf.Tensor3D | ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement,
maxDetectionSize: number
)
・img:テンソル、または画像要素
・maxNumBoxes:検出するバウンディングボックスの最大数。 デフォルトは20。
結果は次のように出力されます。
[{
bbox: [x, y, width, height],
class: "person",
score: 0.8380282521247864
}, {
bbox: [x, y, width, height],
class: "kite",
score: 0.74644153267145157
}]
5. Webカメラを使った物体検出
Webカメラを使った物体検出の例は、次のとおり。
<html>
<head>
<!-- TensorFlow.jsの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- coco-ssdモデルの読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"></script>
<script>
// テンソルをキャンバスに描画
const renderToCanvas = async (ctx, a) => {
const [height, width] = a.shape
const imageData = new ImageData(width, height)
const data = await a.data()
for (let i = 0; i < height * width; ++i) {
const j = i * 4
const k = i * 3
imageData.data[j + 0] = data[k + 0]
imageData.data[j + 1] = data[k + 1]
imageData.data[j + 2] = data[k + 2]
imageData.data[j + 3] = 255
}
ctx.putImageData(imageData, 0, 0)
}
// バウンディングボックスの描画
const drawBBox = (ctx, bbox, name) => {
// 枠の描画
ctx.strokeStyle = 'red'
ctx.fillStyle = 'red'
ctx.strokeRect(bbox[0], bbox[1], bbox[2], bbox[3])
ctx.fillRect(bbox[0], bbox[1]-20, bbox[2], 20)
// 名前の描画
ctx.fillStyle = "white"
ctx.font = 'bold 20px sans-serif'
ctx.textAlign = "left"
ctx.textBaseline = "top"
ctx.fillText(name, bbox[0]+8, bbox[1]-20, bbox[2])
}
// 物体検出の開始
const startDetect = () => {
cocoSsd.load()
.then(model => {
const webcamElement = document.getElementById('webcam')
window.requestAnimationFrame(onFrame.bind(null, model, webcamElement))
})
}
// フレーム毎に呼ばれる
const onFrame = async (model, webcamElement) => {
// 画像分類
const tensor = tf.browser.fromPixels(webcamElement)
const predictions = await model.detect(tensor)
// キャンバスの準備
const canvas = document.getElementById('canvas')
const [height, width] = tensor.shape
canvas.width = width
canvas.height = height
// キャンバスの描画
const ctx = canvas.getContext('2d')
await renderToCanvas(ctx, tensor)
for (let i = 0; i < predictions.length; i++) {
drawBBox(ctx, predictions[i].bbox, predictions[i].class)
}
// 次フレーム
setTimeout(() => {
window.requestAnimationFrame(onFrame.bind(null, model, webcamElement))
}, 1000)
}
// Webカメラの開始
const constraints = {
audio: false,
video: true
}
navigator.mediaDevices.getUserMedia(constraints)
// 成功時に呼ばれる
.then((stream) => {
const video = document.querySelector('video')
video.srcObject = stream
// 物体検出の開始
startDetect()
})
// エラー時に呼ばれる
.catch((error) => {
const errorMsg = document.querySelector('#errorMsg')
errorMsg.innerHTML += `<p>${error.name}</p>`
})
</script>
</head>
<body>
<video id="webcam" width="320" height="240" autoplay playsinline></video>
<canvas id="canvas"></canvas>
<div id="errorMsg"></div>
</body>
</html>
この記事が気に入ったらサポートをしてみませんか?