webエンジニアの日常

RubyやPython, JSなど、IT関連の記事を書いています

初めてのTensorFlow入門~手書き文字を綺麗なフォントに変換する~

こんにちは、エンジニアのさもです。

今回は、前回使った畳み込みニューラルネットワークを改造して遊んでみたいと思います。

最終的にこんな結果になります。

f:id:s-uotani-zetakansu:20171006175023p:plain

スポンサーリンク

やること

やりたいことは、手書き文字を綺麗なフォント文字に変換することです。

アイデアとしては、入力が手書き文字で、教師データがフォント文字(を数値化して配列にしたもの)とし、2乗誤差を損失関数として学習させます。

グラフはこんな感じです。

f:id:s-uotani-zetakansu:20171006173410p:plain

学習には、mnistの学習データと、教師データにはHandy Georgeというフォントを使わせていただきました。

f:id:s-uotani-zetakansu:20171006174241p:plain

これをひとつひとつペイントで28×28サイズの画像に分けて使います。

実装

特に変わったところはないので、さらっといきます

ライブラリのインポート

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
%matplotlib inline

np.random.seed(20171001)
tf.set_random_seed(20171001)
mnist = input_data.read_data_sets("./data/", one_hot=True)

畳み込み・プーリングの定義

cfilter_size = 16
input_size = 14*14*cfilter_size

x = tf.placeholder(tf.float32, [None, 784])
x_image = tf.reshape(x, [-1,28,28,1])

cfilter = tf.Variable(tf.truncated_normal([5,5,1,cfilter_size],stddev=0.1))
cfiltered = tf.nn.conv2d(x_image, cfilter, strides=[1,1,1,1], padding='SAME')

pfiltered = tf.nn.max_pool(cfiltered, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
pfiltered_flat = tf.reshape(pfiltered, [-1, input_size])

全結合、出力層の定義

hidden_size = 1024

w1 = tf.Variable(tf.truncated_normal([input_size, hidden_size]))
b1 = tf.Variable(tf.zeros([hidden_size]))
h =  tf.nn.relu(tf.matmul(pfiltered_flat,w1) + b1)

w0 = tf.Variable(tf.zeros([hidden_size, 784]))
b0 = tf.Variable(tf.zeros([784]))
y = tf.matmul(h, w0) + b0

t = tf.placeholder(tf.float32, [None, 784])

損失関数、学習の定義

loss = tf.reduce_sum(tf.square(t-y))
train_step = tf.train.AdamOptimizer().minimize(loss)

教師データの準備

from PIL import Image
import os
# Handy George
filenames = os.listdir('./teacher2')
imgs = []
for name in filenames:
    img = Image.open('./teacher2/' + name).convert('L')
    img.thumbnail((28, 28))
    img = np.array(img, dtype=np.float32)
    img = 1-np.array(img / 255)
    img = img.reshape(1, 784)
    imgs.append(img)
imgs = np.array(imgs)

def label2im(labels):
    limgs = []
    for label in labels:
        limgs.append(imgs[np.argmax(label)])
    return np.array(limgs).reshape((-1, 784))

label2imは、mnistのラベルデータを受け取って、対応するフォント画像(を数値化し、配列にしたも)を返します。

word2vecみたいにおしゃれな名前を付けてみました。

学習

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
i = 0
imgs = np.array(imgs).reshape([10, 784])
sess.run(pfiltered_flat, feed_dict={x: imgs})
for _ in range(20000):
    i += 1
    batch_xs, batch_ts = mnist.train.next_batch(100)
    batch_ts = label2im(batch_ts)
    sess.run(train_step, feed_dict={x: batch_xs, t: batch_ts})
    if i % 1000 == 0:
        loss_val = sess.run(loss, feed_dict = {x: batch_xs, t: batch_ts})
        print("step: %d, loss: %f" % (i, loss_val))

そこそこ時間がかかります。

mnist.test.imagesで試してみる

mnistのテストデータを対象に、変換してみたいと思います。

from PIL import Image
c = 1
fig = plt.figure(figsize=(6, 8))
for img in mnist.test.images[0:10]:
    res = sess.run(y, feed_dict={x: [img]})[0]
    res = res - res[np.argmin(res)]
    res = res / res[np.argmax(res)]

    subplot = fig.add_subplot(5, 4, 2 * c - 1)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.set_title('before:%d' % (c-1))
    subplot.imshow(img.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest")
    
    subplot = fig.add_subplot(5, 4, 2*c)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.set_title('after:%d' % (c-1))
    subplot.imshow(res.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest")
    c += 1

実行結果がこちらです

f:id:s-uotani-zetakansu:20171006175023p:plain

見て分かりますが、左がmnistのデータで、右がそれを変換したものです。

自分の手書き文字で試してみる

from PIL import Image
import os
filenames = os.listdir('./sample/test_samples')
c = 1
fig = plt.figure(figsize=(6, 8))
labels = []
imgs = []
for name in filenames:
    img = Image.open("./sample/test_samples/" + name).convert('L')
    img.thumbnail((28, 28))
    img = np.array(img, dtype=np.float32)
    img = 1-np.array(img / 255)
    img = img.reshape(1, 784)
    res = sess.run(y, feed_dict={x: img})[0]
    res = res - res[np.argmin(res)]
    res = res / res[np.argmax(res)]

    subplot = fig.add_subplot(5, 4, 2 * c - 1)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.set_title('before:%d' % (c-1))
    subplot.imshow(img.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest")
    
    subplot = fig.add_subplot(5, 4, 2*c)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.set_title('after:%d' % (c-1))
    subplot.imshow(res.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest")
    c += 1

結果はこちらです。

f:id:s-uotani-zetakansu:20171006175201p:plain

9が苦手なのは一緒なんですね。

最後に

画像が分類できれば、それに対応するフォントを出すだけでいいのですが、画像の変換みたいなことが出来ないかなと思いやってみました。

mnistの画像に対しては上手くいくのですが、自前の画像だとなかなか上手くいかないときがありますね。

mnistが持つ何らかの特徴にオーバーフィッティングしているということでしょうか。

以前どこかで見た、絵画を有名画家風に変換するみたいなこともやってみたいですね。

以上、勉強の合間にちょっと遊んでみました。

読者登録をしていただけると、ブログを続ける励みになりますので、よろしくお願いします。