![見出し画像](https://assets.st-note.com/production/uploads/images/23170920/rectangle_large_type_2_bdc4f737857d07f1d40dc37779733efb.png?width=800)
TenorFlow Lite 入門 / Androidによるテキスト分類
1. Androidによる姿勢推定
「TensorFlow Lite」を使ってAndroidでテキスト分類を行います。入力した映画の感想(英語)がネガティブとポジティブのどちらであるかを判定します。
2. バージョン
・compileSdkVersion 29
・minSdkVersion 26
・targetSdkVersion 29
・tensorflow-lite:0.1.7
3. 依存関係の追加
「build.gradle(Mudule:app)」に、「TensorFlow Lite」のプロジェクトの依存関係を追加します。
android {
<<省略>>
aaptOptions {
noCompress "tflite"
}
compileOptions {
sourceCompatibility = '1.8'
targetCompatibility = '1.8'
}
}
dependencies {
<<省略>>
// TensorFlow Lite
implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }
}
4. マニフェストファイルの設定
「CAMERA」のパーミッションを追加します。
<uses-permission android:name="android.permission.CAMERA" />
5. アセットの準備
プロジェクトの「app/src/main/assets」に、「TensorFlow Lite text classification sample」のページからダウンロードしたモデルとラベルと単語辞書を追加します。
・text_classification.tflite
・text_classification_labels.txt
・text_classification_vocab.txt
6. レイアウトの設定
「activity_main.xml」に「ScrollView」と「TextView」と「EditText」と「Button」を追加します。
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<RelativeLayout android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="#FFFFFF">
<TextureView
android:id="@+id/texture_view"
android:layout_width="match_parent"
android:layout_height="match_parent"/>
<ImageView
android:id="@+id/image_view"
android:layout_width="640px"
android:layout_height="640px"
android:layout_margin="16dp"
android:visibility="visible"
app:srcCompat="@mipmap/ic_launcher" />
</RelativeLayout>
</androidx.constraintlayout.widget.ConstraintLayout>
7. UIの作成
画像分類を行うUIを作成します。
以下の処理を行なっています。
・パーミッション
・カメラのプレビューと解析
・TextClassificationInterpriterにテキストを渡して推論(後ほど説明)
◎ MainActivity.java
package net.npaka.textclassificationex;
import androidx.appcompat.app.AppCompatActivity;
import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.ScrollView;
import android.widget.TextView;
import java.util.List;
public class MainActivity extends AppCompatActivity {
// UI
private TextView resultTextView;
private EditText inputEditText;
private ScrollView scrollView;
// 推論
private TextClassificationInterpreter interpreter;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
// 推論
this.interpreter = new TextClassificationInterpreter(this);
// UI
this.resultTextView = findViewById(R.id.result_text_view);
this.inputEditText = findViewById(R.id.input_edit_text);
this.scrollView = findViewById(R.id.scroll_view);
Button button = findViewById(R.id.button);
button.setOnClickListener((View v) -> {
// 推論
String inputText = inputEditText.getText().toString();
List<Result> results = interpreter.predict(inputText);
// 結果表示
String textToShow = "Input: " + inputText + "\nOutput:\n";
for (int i = 0; i < results.size(); i++) {
Result result = results.get(i);
textToShow += String.format(" %s: %s\n", result.title, result.confidence);
}
textToShow += "---------\n";
resultTextView.append(textToShow);
inputEditText.getText().clear();
scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN));
});
}
}
8. テキスト分類
テキストを受け取り、テキスト分類の結果を返します。
◎ TextClassificationInterpreter.java
package net.npaka.textclassificationex;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;
// テキスト分類インタプリタ
public class TextClassificationInterpreter {
// 定数
private static final int SENTENCE_LEN = 256; // 入力文の最大長
private static final String SIMPLE_SPACE_OR_PUNCTUATION = " |\\,|\\.|\\!|\\?|\n"; // 単語分割する区切り文字
// ImdbDataSet dicの予約値
// dic["<PAD>"] = 0 パディング
// dic["<START>"] = 1 文頭
// dic["<UNKNOWN>"] = 2 不明 (OOV)
private static final String START = "<START>";
private static final String PAD = "<PAD>";
private static final String UNKNOWN = "<UNKNOWN>";
// 推論
private Context context;
private Map<String, Integer> dic = new HashMap<>();
private List<String> labels = new ArrayList<>();
private Interpreter interpreter;
// コンストラクタ
public TextClassificationInterpreter(Context context) {
this.context = context;
loadModel("text_classification.tflite");
loadDictionary("text_classification_vocab.txt");
loadLabels("text_classification_labels.txt");
}
// モデルの読み込み
private synchronized void loadModel(String path) {
try {
AssetFileDescriptor fd = this.context.getAssets().openFd(path);
FileInputStream in = new FileInputStream(fd.getFileDescriptor());
FileChannel fc = in.getChannel();
long startOffset = fd.getStartOffset();
long declaredLength = fd.getDeclaredLength();
ByteBuffer buffer = fc.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
this.interpreter = new Interpreter(buffer);
} catch (Exception e) {
e.printStackTrace();
}
}
// 辞書の読み込み
private synchronized void loadDictionary(String path) {
try {
InputStream ins = this.context.getAssets().open(path);
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
while (reader.ready()) {
List<String> line = Arrays.asList(reader.readLine().split(" "));
if (line.size() < 2) {
continue;
}
dic.put(line.get(0), Integer.parseInt(line.get(1)));
}
} catch (Exception e) {
e.printStackTrace();
}
}
// ラベルの読み込み
private synchronized void loadLabels(String path) {
try {
InputStream ins = this.context.getAssets().open(path);
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
while (reader.ready()) {
labels.add(reader.readLine());
}
} catch (Exception e) {
e.printStackTrace();
}
}
// 推論
public synchronized List<Result> predict(String text) {
// 推論
float[][] input = makeInput(text);
float[][] output = new float[1][labels.size()];
this.interpreter.run(input, output);
// 結果の取得
PriorityQueue<Result> pq = new PriorityQueue<>(
3, (lhs, rhs) -> Float.compare(rhs.confidence, lhs.confidence));
for (int i = 0; i < labels.size(); i++) {
pq.add(new Result("" + i, labels.get(i), output[0][i]));
}
final ArrayList<Result> results = new ArrayList<>();
while (!pq.isEmpty()) {
results.add(pq.poll());
}
return results;
}
// 入力の生成
private float[][] makeInput(String text) {
float[] tmp = new float[SENTENCE_LEN];
int index = 0;
// 文字列分割
List<String> array = Arrays.asList(text.split(SIMPLE_SPACE_OR_PUNCTUATION));
// START
if (dic.containsKey(START)) {
tmp[index++] = dic.get(START);
}
// 本文
for (String word : array) {
if (index >= SENTENCE_LEN) {
break;
}
tmp[index++] = dic.containsKey(word) ? dic.get(word) : (int) dic.get(UNKNOWN);
}
// PAD
Arrays.fill(tmp, index, SENTENCE_LEN - 1, (int) dic.get(PAD));
float[][] ans = {tmp};
return ans;
}
}
◎ Result
package net.npaka.textclassificationex;
public class Result {
public final String id;
public final String title;
public final Float confidence;
// コンストラクタ
public Result(final String id, final String title, final Float confidence) {
this.id = id;
this.title = title;
this.confidence = confidence;
}
}
この記事が気に入ったらサポートをしてみませんか?