見出し画像

TenorFlow Lite 入門 / Androidによるテキスト分類

1. Androidによる姿勢推定

「TensorFlow Lite」を使ってAndroidでテキスト分類を行います。入力した映画の感想(英語)がネガティブとポジティブのどちらであるかを判定します。

2. バージョン

・compileSdkVersion 29
・minSdkVersion 26
・targetSdkVersion 29
・tensorflow-lite:0.1.7

3. 依存関係の追加

「build.gradle(Mudule:app)」に、「TensorFlow Lite」のプロジェクトの依存関係を追加します。

android {
    <<省略>>

    aaptOptions {
        noCompress "tflite"
    }
    compileOptions {
        sourceCompatibility = '1.8'
        targetCompatibility = '1.8'
    }
}
dependencies {
    <<省略>>

    // TensorFlow Lite
    implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
    implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
    implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }
}

4. マニフェストファイルの設定

「CAMERA」のパーミッションを追加します。

<uses-permission android:name="android.permission.CAMERA" />

5. アセットの準備

プロジェクトの「app/src/main/assets」に、「TensorFlow Lite text classification sample」のページからダウンロードしたモデルとラベルと単語辞書を追加します。

・text_classification.tflite
・text_classification_labels.txt
・text_classification_vocab.txt

6. レイアウトの設定

「activity_main.xml」に「ScrollView」と「TextView」と「EditText」と「Button」を追加します。

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <RelativeLayout android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:background="#FFFFFF">
        <TextureView
            android:id="@+id/texture_view"
            android:layout_width="match_parent"
            android:layout_height="match_parent"/>

        <ImageView
            android:id="@+id/image_view"
            android:layout_width="640px"
            android:layout_height="640px"
            android:layout_margin="16dp"
            android:visibility="visible"
            app:srcCompat="@mipmap/ic_launcher" />
    </RelativeLayout>
</androidx.constraintlayout.widget.ConstraintLayout>

7. UIの作成

画像分類を行うUIを作成します。
以下の処理を行なっています。

・パーミッション
・カメラのプレビューと解析
・TextClassificationInterpriterにテキストを渡して推論(後ほど説明)

◎ MainActivity.java

package net.npaka.textclassificationex;

import androidx.appcompat.app.AppCompatActivity;

import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.ScrollView;
import android.widget.TextView;

import java.util.List;

public class MainActivity extends AppCompatActivity {
    // UI
    private TextView resultTextView;
    private EditText inputEditText;
    private ScrollView scrollView;

    // 推論
    private TextClassificationInterpreter interpreter;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // 推論
        this.interpreter = new TextClassificationInterpreter(this);

        // UI
        this.resultTextView = findViewById(R.id.result_text_view);
        this.inputEditText = findViewById(R.id.input_edit_text);
        this.scrollView = findViewById(R.id.scroll_view);
        Button button = findViewById(R.id.button);
        button.setOnClickListener((View v) -> {
            // 推論
            String inputText = inputEditText.getText().toString();
            List<Result> results = interpreter.predict(inputText);

            // 結果表示
            String textToShow = "Input: " + inputText + "\nOutput:\n";
            for (int i = 0; i < results.size(); i++) {
                Result result = results.get(i);
                textToShow += String.format("    %s: %s\n", result.title, result.confidence);
            }
            textToShow += "---------\n";
            resultTextView.append(textToShow);
            inputEditText.getText().clear();
            scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN));
        });
    }
}

8. テキスト分類

テキストを受け取り、テキスト分類の結果を返します。

◎ TextClassificationInterpreter.java

package net.npaka.textclassificationex;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;

// テキスト分類インタプリタ
public class TextClassificationInterpreter {
    // 定数
    private static final int SENTENCE_LEN = 256;  // 入力文の最大長
    private static final String SIMPLE_SPACE_OR_PUNCTUATION = " |\\,|\\.|\\!|\\?|\n"; // 単語分割する区切り文字

    // ImdbDataSet dicの予約値
    // dic["<PAD>"] = 0      パディング
    // dic["<START>"] = 1    文頭
    // dic["<UNKNOWN>"] = 2  不明 (OOV)
    private static final String START = "<START>";
    private static final String PAD = "<PAD>";
    private static final String UNKNOWN = "<UNKNOWN>";

    // 推論
    private Context context;
    private Map<String, Integer> dic = new HashMap<>();
    private List<String> labels = new ArrayList<>();
    private Interpreter interpreter;

    // コンストラクタ
    public TextClassificationInterpreter(Context context) {
        this.context = context;
        loadModel("text_classification.tflite");
        loadDictionary("text_classification_vocab.txt");
        loadLabels("text_classification_labels.txt");
    }

    // モデルの読み込み
    private synchronized void loadModel(String path) {
        try {
            AssetFileDescriptor fd = this.context.getAssets().openFd(path);
            FileInputStream in = new FileInputStream(fd.getFileDescriptor());
            FileChannel fc = in.getChannel();
            long startOffset = fd.getStartOffset();
            long declaredLength = fd.getDeclaredLength();
            ByteBuffer buffer = fc.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
            this.interpreter = new Interpreter(buffer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // 辞書の読み込み
    private synchronized void loadDictionary(String path) {
        try {
            InputStream ins = this.context.getAssets().open(path);
            BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
            while (reader.ready()) {
                List<String> line = Arrays.asList(reader.readLine().split(" "));
                if (line.size() < 2) {
                    continue;
                }
                dic.put(line.get(0), Integer.parseInt(line.get(1)));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // ラベルの読み込み
    private synchronized void loadLabels(String path) {
        try {
            InputStream ins = this.context.getAssets().open(path);
            BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
            while (reader.ready()) {
                labels.add(reader.readLine());
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // 推論
    public synchronized List<Result> predict(String text) {
        // 推論
        float[][] input = makeInput(text);
        float[][] output = new float[1][labels.size()];
        this.interpreter.run(input, output);

        // 結果の取得
        PriorityQueue<Result> pq = new PriorityQueue<>(
            3, (lhs, rhs) -> Float.compare(rhs.confidence, lhs.confidence));
        for (int i = 0; i < labels.size(); i++) {
            pq.add(new Result("" + i, labels.get(i), output[0][i]));
        }
        final ArrayList<Result> results = new ArrayList<>();
        while (!pq.isEmpty()) {
            results.add(pq.poll());
        }
        return results;
    }

    // 入力の生成
    private float[][] makeInput(String text) {
        float[] tmp = new float[SENTENCE_LEN];
        int index = 0;

        // 文字列分割
        List<String> array = Arrays.asList(text.split(SIMPLE_SPACE_OR_PUNCTUATION));

        // START
        if (dic.containsKey(START)) {
            tmp[index++] = dic.get(START);
        }

        // 本文
        for (String word : array) {
            if (index >= SENTENCE_LEN) {
                break;
            }
            tmp[index++] = dic.containsKey(word) ? dic.get(word) : (int) dic.get(UNKNOWN);
        }

        // PAD
        Arrays.fill(tmp, index, SENTENCE_LEN - 1, (int) dic.get(PAD));
        float[][] ans = {tmp};
        return ans;
    }
}

◎ Result

package net.npaka.textclassificationex;

public class Result {
    public final String id;
    public final String title;
    public final Float confidence;

    // コンストラクタ
    public Result(final String id, final String title, final Float confidence) {
        this.id = id;
        this.title = title;
        this.confidence = confidence;
    }
}


この記事が気に入ったらサポートをしてみませんか?