見出し画像

新聞のパズルをプログラムで解く(数独編)

デジタル版の新聞紙面からパズルの画像をキャプチャして、プログラムで扱いやすいように変換して、(ついでに)解いてしまおうというおはなしです.今回は数独(ナンプレ)です.
今どき文字認識なんてAPI一発で...ということかも知れませんが、趣味なので手作り感のある方法でやっています.文字認識AI入ってます(笑)

前提ソフトウェア

Python 3, Tensorflow, OpenCV2

前提とするパズル画像

上下左右に若干の白い余白があるようにスクリーンショットなどでパズル部分を切り取った画像を用意します.

画像を用意(例としてhttps://noukatu.com/sudoku-printable/1198 から借用しました)
だいたいこのような感じで画素数が500x500以上くらいになるようにキャプチャしてください.

プログラムの流れ

  1. 画像から各マスの数字を読み取り9x9の行列に変換する

    • 画像を読み込んで下処理をする(トリミング、白黒化など)

    • 各マスごとに画像を切り出し、そこに書かれている数字が何かを識別する判定機を作る(Tensorflow使います)

    • 判定機を使って画像からパズル(9x9の行列)を作る

  2. 数独ソルバでパズルを解く

画像の読み込みと下処理

画像を読み込んで白黒の2値画像に変換する

今回はOpenCV2を使います.

import cv2

def read_file(file):
    """ 数独の画像を読んで白黒2値に変換する """
    image = cv2.imread(file, 0)
    threshold = 150
    _, imgage = cv2.threshold(image, threshold, 255, cv2.THRESH_BINARY)
    return image

余白をトリミングする

画像全体を9x9に分割して各マスを切り出したいので余白サイズに左右されないように、まず最初に余白をトリミングします.あまり良い方法が思い浮かばず、競プロの時のようにアドホックに書いてしまいました.

from itertools import groupby

def remove_margin(image):
    """ 画像の外側のマージン部分をトリミングする """
    def margins(l):
        ret = []
        for v, ll in groupby(l):
            if v == -1:
                ret.append(len(list(ll)))
        assert len(ret) == 2
        return ret

    H, W = image.shape
    rows = []
    for h in range(H):
        r = image[h:h+1, :]
        if r.min() > 200:
            rows.append(-1)
        else:
            rows.append(1)
    top, bottom = margins(rows)

    cols = []
    for w in range(W):
        c = image[:, w:w+1]
        if c.min() > 200:
            cols.append(-1)
        else:
            cols.append(1)
    left, right = margins(cols)

    image = image[top:H-bottom, left:W-right]
    return image

9x9のマスの画像を切り出す

左右上下の余白がトリミングされた画像を9x9の等分に分割して、各マスの画像を切り出します.

def get_cells(image):
    """ 9x9のマスを切り取る """
    h, w = image.shape
    w, h = w//9, h//9
    ret = []
    for r in range(9):
        for c in range(9):
            cell = image[r*h:r*h+h,c*w:c*w+w]
            ret.append(cell)
    return ret

マスの画像から少しづつずらしたクロッピング画像を作る

1枚のマスの画像から少しずつずらしたクロッピング画像を作り、学習用データの「水増し」をします.境界線の含まれ方の影響を受けにくくする効果を期待しています.

PX = 28                       # モデルで使う画像サイズ

def crop(image, center=False):
    """ マス画像の外周部をトリミングした画像をいくつか作る """
    h, w = image.shape
    cy, cx = h//2, w//2
    hh, wh = int(h/2 * 0.8), int(w/2 * 0.8)
    if center:
        img = image[cy - hh:cy+hh, cx-wh:cx+wh]
        img = cv2.resize(img, (PX, PX))
        return img

    ret = []
    for dy in range(-4, 5):
        for dx in range(-4, 5):
            y1, y2 = max(0, (cy+dy) - hh), min(h, (cy+dy) + hh)
            x1, x2 = max(0, (cx+dx) - wh), min(w, (cx+dx) + wh)
            img = image[y1:y2, x1:x2]
            img = cv2.resize(img, (PX, PX))
            ret.append(img)
    return ret

教師データを用意する

1〜9の数字が必ず出現するようなパズルを学習用に用意し(最低1枚でもなんとかなると思いますが)、人力でラベルデータを作ります.パズル1枚分のデータは下のような感じになります.空白部分は0としてください.

label = [0,0,0,8,0,0,0,5,0,0,6,0,0,7,0,3,0,0,4,0,0,0,0,6,0,0,0,
         8,0,0,5,0,0,0,1,0,0,0,7,0,4,0,9,0,0,0,4,0,0,0,7,0,0,6,
         0,0,0,1,0,0,0,0,8,0,0,5,0,2,0,0,7,0,0,9,0,0,0,3,0,0,0]

数字画像の分類器を作る

ここまでで前処理などの準備が出来ましたので、画像を入力に0〜9のどれかを推論する分類器を作成します.

モデルを定義する

今回はTensorflow(Keras)を使います.画像サイズとクラス数がMNISTと同じなのでチュートリアルなどがそのまま使えるはずです.CNNを使いましたが、この程度の大きさであればGPU無くてもそれほど時間はかかりません.

import tensorflow as tf
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout

model = tf.keras.models.Sequential()
model.add(Conv2D(64, (3, 3), activation='relu', input_shape=(PX, PX, 1)))
model.add(MaxPooling2D(2, 2))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(2, 2))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

学習データの作成

今まで準備した関数を使って学習データを作成します.この例では'sudoku.png'という名前のファイルを1枚だけ使っています.
xdataに学習データ、ydataに教師データが格納されます.

def one_hot_encoding(l):
    return np.identity(10)[l]

image = read_file(myDrive+'sudoku.png')
image = remove_margin(image)
cells = get_cells(image)

xdata = []
ydata = []
for r in range(9):
    for c in range(9):
        idx = r*9 + c
        img = cells[idx]
        for x in crop(img):
            xdata.append(np.array(x))
            ydata.append(one_hot_encoding(label[idx]))
xdata = np.array(xdata)
ydata = np.array(ydata)

学習

これで数字画像の分類器ができるはずです.

model.fit(xdata, ydata, epochs=10)

学習したモデルを使って新しいパズル画像を変換する

学習で使った画像とは異なるパズルの画像を正しく変換できるか確認します.学習で使ったときと同様の前処理をして、9x9のマスの画像をまとめてバッチにしてmodelに推論してもらいます.

image = read_file(sudoku1.png')
image = remove_margin(image)
cells = get_cells(image)

size = xdata[0].shape[0]
print(cells[0].shape)
input = []
for r in range(9):
    for c in range(9):
        img = cells[r*9 + c]
        img = crop(img, center=True)
        input.append(np.array(img))
input = np.array(input)
pred = model.predict(input)

確率分布の形で出力されるのでargmaxをとって0〜9の値にします.2次元のリストに推論結果を保存して、数独ソルバで解けるようにします.

puzzle = []
for r in range(9):
    row = []
    for c in range(9):
        v = np.argmax(pred[r*9 + c])
        row.append(v)
    puzzle.append(row)

for r in puzzle:
    print(*r)

数独ソルバ

バックトラッキングで全探索します(力技です・・).

import random

def is_legal(pzl, r, c, n):
    """ pzl[r][c]にnを書き込めるか? """
    for i in range(9):
        if pzl[i][c] == n:
            return False
    for j in range(9):
        if pzl[r][j] == n:
            return False

    r, c = (r//3)*3, (c//3)*3
    for i in range(r, r+3):
        for j in range(c, c+3):
            if pzl[i][j] == n:
                return False
    return True

def solver(pzl):
    """ 数独ソルバ """
    for r in range(9):
        for c in range(9):
            if pzl[r][c] == 0:
                nums = list(range(1, 10))
                random.shuffle(nums)
                for n in nums:
                    if is_legal(pzl, r, c, n):
                        pzl[r][c] = n
                        if solver(pzl):
                            return True
                pzl[r][c] = 0
                return False
    return True


if solver(puzzle):
    print('...解けました!')
else:
    print('解けません.パズルのデータがおかしい??')
for r in puzzle:
    print(*r)


おわりに

とても少ないサンプル数で確認しただけなので、入力画像のちょっとした違いなどで認識ミスをする可能性があります.そんな時は数年前のAI開発の現場の雰囲気(チューニングを頑張るとか)が味わえると笑って許してください.

コード全体をまとめておきます.

from itertools import groupby
import numpy as np
import cv2


def read_file(file):
    """ 数独の画像を読んで白黒2値に変換する """
    image = cv2.imread(file, 0)
    threshold = 150
    _, imgage = cv2.threshold(image, threshold, 255, cv2.THRESH_BINARY)
    return image


def remove_margin(image):
    """ 画像の外側のマージン部分をトリミングする """
    def margins(l):
        ret = []
        for v, ll in groupby(l):
            if v == -1:
                ret.append(len(list(ll)))
        assert len(ret) == 2
        return ret

    H, W = image.shape
    rows = []
    for h in range(H):
        r = image[h:h+1, :]
        if r.min() > 200:
            rows.append(-1)
        else:
            rows.append(1)
    top, bottom = margins(rows)

    cols = []
    for w in range(W):
        c = image[:, w:w+1]
        if c.min() > 200:
            cols.append(-1)
        else:
            cols.append(1)
    left, right = margins(cols)

    image = image[top:H-bottom, left:W-right]
    return image


def get_cells(image):
    """ 9x9のマスを切り取る """
    h, w = image.shape
    w, h = w//9, h//9
    ret = []
    for r in range(9):
        for c in range(9):
            cell = image[r*h:r*h+h,c*w:c*w+w]
            ret.append(cell)
    return ret


PX = 28    # モデルで使う画像サイズ

def crop(image, center=False):
    """ マス画像の外周部をトリミングした画像をいくつか作る """
    h, w = image.shape
    cy, cx = h//2, w//2
    hh, wh = int(h/2 * 0.8), int(w/2 * 0.8)
    if center:
        img = image[cy - hh:cy+hh, cx-wh:cx+wh]
        img = cv2.resize(img, (PX, PX))
        return img

    ret = []
    for dy in range(-4, 5):
        for dx in range(-4, 5):
            y1, y2 = max(0, (cy+dy) - hh), min(h, (cy+dy) + hh)
            x1, x2 = max(0, (cx+dx) - wh), min(w, (cx+dx) + wh)
            img = image[y1:y2, x1:x2]
            img = cv2.resize(img, (PX, PX))
            ret.append(img)
    return ret

# モデルの定義
import tensorflow as tf
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout

model = tf.keras.models.Sequential()
model.add(Conv2D(64, (3, 3), activation='relu', input_shape=(PX, PX, 1)))
model.add(MaxPooling2D(2, 2))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(2, 2))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# 学習データの準備
# 教師データのラベル情報(ここは書き換えてください)
label = [0,0,0,8,0,0,0,5,0,0,6,0,0,7,0,3,0,0,4,0,0,0,0,6,0,0,0,
         8,0,0,5,0,0,0,1,0,0,0,7,0,4,0,9,0,0,0,4,0,0,0,7,0,0,6,
         0,0,0,1,0,0,0,0,8,0,0,5,0,2,0,0,7,0,0,9,0,0,0,3,0,0,0]

def one_hot_encoding(l):
    return np.identity(10)[l]

image = read_file('sudoku.png')
image = remove_margin(image)
cells = get_cells(image)

xdata = []
ydata = []
for r in range(9):
    for c in range(9):
        idx = r*9 + c
        img = cells[idx]
        for x in crop(img):
            xdata.append(np.array(x))
            ydata.append(one_hot_encoding(label[idx]))
xdata = np.array(xdata)
ydata = np.array(ydata)

# 学習
model.fit(xdata, ydata, epochs=10)

# 学習済モデルで新しいパズルを変換する
image = read_file('sudoku1.png')
image = remove_margin(image)
cells = get_cells(image)

size = xdata[0].shape[0]
print(cells[0].shape)
input = []
for r in range(9):
    for c in range(9):
        img = cells[r*9 + c]
        img = crop(img, center=True)
        input.append(np.array(img))
input = np.array(input)
pred = model.predict(input)

puzzle = []
for r in range(9):
    row = []
    for c in range(9):
        v = np.argmax(pred[r*9 + c])
        row.append(v)
    puzzle.append(row)

# 変換したパズルを解く
import random

def is_legal(pzl, r, c, n):
    """ pzl[r][c]にnを書き込めるか? """
    for i in range(9):
        if pzl[i][c] == n:
            return False
    for j in range(9):
        if pzl[r][j] == n:
            return False

    r, c = (r//3)*3, (c//3)*3
    for i in range(r, r+3):
        for j in range(c, c+3):
            if pzl[i][j] == n:
                return False
    return True

def solver(pzl):
    """ 数独ソルバ """
    for r in range(9):
        for c in range(9):
            if pzl[r][c] == 0:
                nums = list(range(1, 10))
                random.shuffle(nums)
                for n in nums:
                    if is_legal(pzl, r, c, n):
                        pzl[r][c] = n
                        if solver(pzl):
                            return True
                pzl[r][c] = 0
                return False
    return True

# 数独ソルバを呼び出す
if solver(puzzle):
    print('...解けました!')
else:
    print('解けません.パズルのデータがおかしいようです..')
for r in puzzle:
    print(*r)




この記事が気に入ったらサポートをしてみませんか?