見出し画像

Turi Createによるアニメのレコメンド

「Turi Create」によるアニメのレコメンドの実装方法をまとめました。

・Turi Create 6.4.1
・Python 3.7

1. Turi Create

「Turi Create」は、Appleが提供する機械学習フレームワークです。

以下のタスクおよびアルゴリズムの、学習および推論を行うことができます。

◎ タスクベースの機能

・画像分類
・類似画像検出
・物体検出
・画風変換
・活動分類
・テキスト分類
・レコメンド
・ワンショット物体検出
・手書き分類
・音声分類

◎ アルゴリズムベースの機能

・分類
・回帰
・クラスタリング
・グラフ分析
・テキスト分析

2. Turi Createのインストール

「Google Colab」で以下のコマンドを実行します。

!pip install turicreate

3. データセットの準備

今回は、Kaggleの「Anime Recommendations Database」を使います。Kaggleのサイトにログインして、「rating.csv」と「anime.csv」を取得し、「Google Colab」にアップロードしてください。

CSVの内容は、次のとおりです。

◎ ratings.csv

・user_id :  ユーザーID
・anime_id :  アニメのユーニークID
・rating : アニメのレーティング (-1 or 1〜10)

◎ anime.csv

・anime_id : アニメのユニークID
・name : アニメのタイトル
・genre : アニメのカテゴリ
・type : アニメのメディアタイプ (TV, OVA, Movie)
・episodes : アニメのエピソード数
・rating : アニメの平均レーティング (-1 or 1〜10)
・members : アニメのグループに参加するユーザー数

4. データセットの読み込み

データセットを読み込むコードは、次のとおりです。

(1) 「rating.csv」を読み込む。

import turicreate as tc

# rating.csvの読み込み
actions = tc.SFrame.read_csv('./rating.csv')
print(actions)
+---------+----------+--------+
| user_id | anime_id | rating |
+---------+----------+--------+
|    1    |    20    |   -1   |
|    1    |    24    |   -1   |
|    1    |    79    |   -1   |
|    1    |   226    |   -1   |
|    1    |   241    |   -1   |
|    1    |   355    |   -1   |
|    1    |   356    |   -1   |
|    1    |   442    |   -1   |
|    1    |   487    |   -1   |
|    1    |   846    |   -1   |
+---------+----------+--------+
[7813737 rows x 3 columns]

(2) 「anime.csv」を読み込む。

# anime.csvの読み込み
items = tc.SFrame.read_csv('./anime.csv')
print(items)
+----------+-------------------------------+--------------------------------+
| anime_id |              name             |             genre              |
+----------+-------------------------------+--------------------------------+
|  32281   |         Kimi no Na wa.        | Drama, Romance, School, Su...  |
|   5114   | Fullmetal Alchemist: Broth... | Action, Adventure, Drama, ...  |
|  28977   |            Gintama°           | Action, Comedy, Historical...  |
|   9253   |          Steins;Gate          |        Sci-Fi, Thriller        |
|   9969   |         Gintama'         | Action, Comedy, Historical...  |
|  32935   | Haikyuu!!: Karasuno Koukou... | Comedy, Drama, School, Sho...  |
|  11061   |     Hunter x Hunter (2011)    | Action, Adventure, Shounen...  |
|   820    |      Ginga Eiyuu Densetsu     | Drama, Military, Sci-Fi, Space |
|  15335   | Gintama Movie: Kanketsu-he... | Action, Comedy, Historical...  |
|  15417   |    Gintama': Enchousen   | Action, Comedy, Historical...  |
+----------+-------------------------------+--------------------------------+
+-------+----------+--------+---------+
|  type | episodes | rating | members |
+-------+----------+--------+---------+
| Movie |    1     |  9.37  |  200630 |
|   TV  |    64    |  9.26  |  793665 |
|   TV  |    51    |  9.25  |  114262 |
|   TV  |    24    |  9.17  |  673572 |
|   TV  |    51    |  9.16  |  151266 |
|   TV  |    10    |  9.15  |  93351  |
|   TV  |   148    |  9.13  |  425855 |
|  OVA  |   110    |  9.11  |  80679  |
| Movie |    1     |  9.1   |  72534  |
|   TV  |    13    |  9.11  |  81109  |
+-------+----------+--------+---------+
[12294 rows x 7 columns]

5. レコメンドモデルの作成と実行

レコメンドモデルの作成と実行を行うコードは、次のとおりです。

# 訓練データと検証データに分割
train_data, valid_data = tc.recommender.util.random_split_by_user(actions, 'user_id', 'anime_id')
   
# レコメンドモデルの作成
model = tc.recommender.create(train_data, 'user_id', 'anime_id')

# レコメンドモデルの実行
results = model.recommend([1])
print(results)
+---------+----------+---------------------+------+
| user_id | anime_id |        score        | rank |
+---------+----------+---------------------+------+
|    1    |  15583   |  0.1314304383074651 |  1   |
|    1    |  10110   | 0.11426670238619945 |  2   |
|    1    |  15809   | 0.11381574536933274 |  3   |
|    1    |  15315   | 0.10413592760680151 |  4   |
|    1    |  17247   | 0.09412194275465167 |  5   |
|    1    |   9181   | 0.09371922397222675 |  6   |
|    1    |  13161   |  0.0894805950219514 |  7   |
|    1    |  17895   | 0.08931980709560582 |  8   |
|    1    |   8769   |  0.0871320752823939 |  9   |
|    1    |  12293   | 0.08446551297531753 |  10  |
+---------+----------+---------------------+------+
[10 rows x 4 columns]

引数[user_id]model.recommend()を呼ぶことで、ユーザー毎のレコメンド(推奨事項)を取得できます。引数を指定しない場合は、全ユーザーのレコメンドが返ります。ただし、ユーザーがレーティングしたアイテムは除外されます。

model.recommend()のパラメータは、次のとおりです。

・users : SArray, SFrame, list - ユーザー群
・k : int - レコメンド数
・exclude : SFrame - 除外
・items : SArray, SFrame, or list - アイテム群
・new_observation_data : SFrame - 新規観察データ
・new_user_data : SFrame - 新規ユーザーデータ
・new_item_data : SFrame - 新規アイテムデータ
・exclude_known : bool - 既知アイテムの除外
・diversity : non-negative float - 多様性 (1〜3)
・random_seed : int - 乱数シード
・verbose : bool - 状況出力

わかりやすさのため、アニメIDをアニメ名に変換してみます。

# アニメIDをアニメ名に変換
results['anime_id'] = results['anime_id'].apply(
    lambda anime_id: items[items['anime_id'] == anime_id]['name'][0])
print(results)
+---------+-------------------------------+---------------------+------+
| user_id |            anime_id           |        score        | rank |
+---------+-------------------------------+---------------------+------+
|    1    |          Date A Live          |  0.1314304383074651 |  1   |
|    1    |          Mayo Chiki!          | 0.11426670238619945 |  2   |
|    1    |      Hataraku Maou-sama!      | 0.11381574536933274 |  3   |
|    1    | Mondaiji-tachi ga Isekai k... | 0.10413592760680151 |  4   |
|    1    |  Machine-Doll wa Kizutsukanai | 0.09412194275465167 |  5   |
|    1    |        Motto To LOVE-Ru       | 0.09371922397222675 |  6   |
|    1    |  Hagure Yuusha no Aesthetica  |  0.0894805950219514 |  7   |
|    1    |          Golden Time          | 0.08931980709560582 |  8   |
|    1    | Ore no Imouto ga Konnani K... |  0.0871320752823939 |  9   |
|    1    | Campione!: Matsurowanu Kam... | 0.08446551297531753 |  10  |
+---------+-------------------------------+---------------------+------+

ユーザID「1」へのレコメンドは、以下のアニメであることがわかります。

1. デート・ア・ライブ
2. まよチキ!
3. はたらく魔王さま!
4. 問題児たちが異世界から来るそうですよ?
5. 機巧少女は傷つかない
6. もっとTo LOVEる−とらぶる−
7. はぐれ勇者の鬼蓄美学
8. ゴールデンタイム
9. 俺の妹がこんなに可愛いわけがない
10. カンピオーネ! ~まつろわぬ神々と神殺しの魔王

6. 新規ユーザー向けのレコメンド

model.recommend()の引数に未レーティングなユーザーIDを使うと、デフォルトで人気のあるアイテムをレコメンドします。

# 新規ユーザー向けのレコメンドの作成
results = model.recommend([999999]) # 未レーティングなID
results['anime_id'] = results['anime_id'].apply(
    lambda anime_id: items[items['anime_id'] == anime_id]['name'][0])
print(results)
+---------+------------------------+---------------------+------+
| user_id |        anime_id        |        score        | rank |
+---------+------------------------+---------------------+------+
|  999999 |      Angel Beats!      |  0.3071678948402405 |  1   |
|  999999 | Highschool of the Dead | 0.28175392031669616 |  2   |
|  999999 |    No Game No Life     |  0.2742947745323181 |  3   |
|  999999 |      Guilty Crown      |  0.271802704334259  |  4   |
|  999999 |    Mirai Nikki (TV)    |  0.2692363142967224 |  5   |
|  999999 |   Shingeki no Kyojin   |  0.2642093980312347 |  6   |
|  999999 |        Noragami        |  0.2622223353385925 |  7   |
|  999999 |  Hataraku Maou-sama!   | 0.26127488374710084 |  8   |
|  999999 |    Sword Art Online    |  0.2531563758850098 |  9   |
|  999999 |      Steins;Gate       | 0.24980076789855957 |  10  |
+---------+------------------------+---------------------+------+
[10 rows x 4 columns]

新規ユーザへのレコメンドは、以下のアニメであることがわかります。

1. Angel Beats!
2. 学園黙示録 HIGHSCHOOL OF THE DEAD
3. ノーゲーム・ノーライフ
4. ギルティクラウン
5. 未来日記
6. 進撃の巨人
7. ノラガミ
8. はたらく魔王さま!
9. ソードアート・オンライン
10. STEINS;GATE​

7. 類似アイテムの検索

類似アイテムの検索行うコードは、次のとおりです。

# 類似アイテムの検索similar_items = model.get_similar_items([16498, 22535]) # [進撃の巨人, 寄生獣]
similar_items['similar'] = similar_items['similar'].apply(
    lambda anime_id: items[items['anime_id'] == anime_id]['name'][0])
print(similar_items)
+----------+-------------------------------+---------------------+------+
| anime_id |            similar            |        score        | rank |
+----------+-------------------------------+---------------------+------+
|  16498   |        Sword Art Online       |  0.5374723076820374 |  1   |
|  16498   |           Death Note          | 0.47845208644866943 |  2   |
|  16498   |        Mirai Nikki (TV)       | 0.42451924085617065 |  3   |
|  16498   |          Angel Beats!         | 0.40856748819351196 |  4   |
|  16498   | Fullmetal Alchemist: Broth... | 0.38907521963119507 |  5   |
|  16498   |        No Game No Life        | 0.38764047622680664 |  6   |
|  16498   |          Tokyo Ghoul          |  0.3853837847709656 |  7   |
|  16498   |         Ao no Exorcist        |  0.3807450532913208 |  8   |
|  16498   | Code Geass: Hangyaku no Le... | 0.37176352739334106 |  9   |
|  16498   |           Elfen Lied          |  0.3689039349555969 |  10  |
+----------+-------------------------------+---------------------+------+

「進撃の巨人」「寄生獣」に類似するのは、以下のアニメであることがわかります。

1. ソードアート・オンライン
2. DEATH NOTE
3. 未来日記
4. Angel Beats!
5. 鋼の錬金術師 FULLMETAL ALCHEMIS
6. ノーゲーム・ノーライフ
7. 東京喰種トーキョーグール
8. 青の祓魔師
9. コードギアス 反逆のルルーシュ
10. エルフェンリート

8. モデルの保存と読み込み

モデルの保存と読み込みを行うコードは、次のとおりです。

# モデルの保存
model.save("recommendations.model")
# モデルの読み込みmodel = tc.load_model("recommendations.model")

【おまけ】 英語のアニメタイトルを日本語に変換するスクリプト

英語のアニメタイトルを日本語に変換するスクリプトは次のとおりです。「Wikipedia <英語のアニメタイトル>」のGoogle検索結果で判定しています。

import requests as web
import bs4

# 日本語タイトルに変換
def get_jp_title(en_title):  
    # キーワードで検索
    keyward_list = ['wikipedia',en_title]
    url = 'https://www.google.co.jp/search?num=1&q=' + ' '.join(keyward_list)
    resp = web.get(url)
    resp.raise_for_status()
    #print ('url:', url)

    # HTMLのパース
    soup = bs4.BeautifulSoup(resp.text, "html.parser")
    link_elem = soup.select('.kCrYT>a') # Googleの仕様変更により変更する必要あり
    leng = len(link_elem)    
    if leng == 0:
        print('エラー: リンクが見つからない - ', title)
        return ''

    # タイトルの取得
    title = link_elem[0].get_text() # タイトル
   
    # Wikipediaの削除
    if title.find(' - ウィキペディア') >= 0:
        title = title.split(' - ウィキペディア')[0]
    elif title.find(' - Wikipedia') >= 0:
        title = title.split(' - Wikipedia')[0]
    else:
        print('エラー:検索失敗 - ', en_title)
        return en_title
       
    # カッコの削除
    if title.find("("):
        title = title.split('(')[0]
    if title.find(" ("):
        title = title.split(' (')[0]        
    return title

import time
output = []
with open('popular100.csv') as f:
    lines = f.readlines()
    for line in lines:
        strs = line.strip().split(',')
       
        if len(strs) >= 2:
            jp_title = get_jp_title(strs[1])
            if jp_title != '':
                print(strs[0]+','+strs[1]+','+jp_title)
                output.append(strs[0]+','+strs[1]+','+jp_title+'\n')
               
            # サーバーに負荷をかけないためのスリープ
            time.sleep(5)
        else:
            print(strs)
       
# CSVに出力
with open('popular100_jp.csv','w') as f:
    f.writelines(output)



この記事が気に入ったらサポートをしてみませんか?