こんにちは、エンジニアのさもです。
今回は、前回使った畳み込みニューラルネットワークを改造して遊んでみたいと思います。
最終的にこんな結果になります。
スポンサーリンク
やること
やりたいことは、手書き文字を綺麗なフォント文字に変換することです。
アイデアとしては、入力が手書き文字で、教師データがフォント文字(を数値化して配列にしたもの)とし、2乗誤差を損失関数として学習させます。
グラフはこんな感じです。
学習には、mnistの学習データと、教師データにはHandy Georgeというフォントを使わせていただきました。
これをひとつひとつペイントで28×28サイズの画像に分けて使います。
実装
特に変わったところはないので、さらっといきます
ライブラリのインポート
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data %matplotlib inline np.random.seed(20171001) tf.set_random_seed(20171001) mnist = input_data.read_data_sets("./data/", one_hot=True)
畳み込み・プーリングの定義
cfilter_size = 16 input_size = 14*14*cfilter_size x = tf.placeholder(tf.float32, [None, 784]) x_image = tf.reshape(x, [-1,28,28,1]) cfilter = tf.Variable(tf.truncated_normal([5,5,1,cfilter_size],stddev=0.1)) cfiltered = tf.nn.conv2d(x_image, cfilter, strides=[1,1,1,1], padding='SAME') pfiltered = tf.nn.max_pool(cfiltered, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') pfiltered_flat = tf.reshape(pfiltered, [-1, input_size])
全結合、出力層の定義
hidden_size = 1024 w1 = tf.Variable(tf.truncated_normal([input_size, hidden_size])) b1 = tf.Variable(tf.zeros([hidden_size])) h = tf.nn.relu(tf.matmul(pfiltered_flat,w1) + b1) w0 = tf.Variable(tf.zeros([hidden_size, 784])) b0 = tf.Variable(tf.zeros([784])) y = tf.matmul(h, w0) + b0 t = tf.placeholder(tf.float32, [None, 784])
損失関数、学習の定義
loss = tf.reduce_sum(tf.square(t-y)) train_step = tf.train.AdamOptimizer().minimize(loss)
教師データの準備
from PIL import Image import os # Handy George filenames = os.listdir('./teacher2') imgs = [] for name in filenames: img = Image.open('./teacher2/' + name).convert('L') img.thumbnail((28, 28)) img = np.array(img, dtype=np.float32) img = 1-np.array(img / 255) img = img.reshape(1, 784) imgs.append(img) imgs = np.array(imgs) def label2im(labels): limgs = [] for label in labels: limgs.append(imgs[np.argmax(label)]) return np.array(limgs).reshape((-1, 784))
label2imは、mnistのラベルデータを受け取って、対応するフォント画像(を数値化し、配列にしたも)を返します。
word2vecみたいにおしゃれな名前を付けてみました。
学習
sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) i = 0 imgs = np.array(imgs).reshape([10, 784]) sess.run(pfiltered_flat, feed_dict={x: imgs}) for _ in range(20000): i += 1 batch_xs, batch_ts = mnist.train.next_batch(100) batch_ts = label2im(batch_ts) sess.run(train_step, feed_dict={x: batch_xs, t: batch_ts}) if i % 1000 == 0: loss_val = sess.run(loss, feed_dict = {x: batch_xs, t: batch_ts}) print("step: %d, loss: %f" % (i, loss_val))
そこそこ時間がかかります。
mnist.test.imagesで試してみる
mnistのテストデータを対象に、変換してみたいと思います。
from PIL import Image c = 1 fig = plt.figure(figsize=(6, 8)) for img in mnist.test.images[0:10]: res = sess.run(y, feed_dict={x: [img]})[0] res = res - res[np.argmin(res)] res = res / res[np.argmax(res)] subplot = fig.add_subplot(5, 4, 2 * c - 1) subplot.set_xticks([]) subplot.set_yticks([]) subplot.set_title('before:%d' % (c-1)) subplot.imshow(img.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest") subplot = fig.add_subplot(5, 4, 2*c) subplot.set_xticks([]) subplot.set_yticks([]) subplot.set_title('after:%d' % (c-1)) subplot.imshow(res.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest") c += 1
実行結果がこちらです
見て分かりますが、左がmnistのデータで、右がそれを変換したものです。
自分の手書き文字で試してみる
from PIL import Image import os filenames = os.listdir('./sample/test_samples') c = 1 fig = plt.figure(figsize=(6, 8)) labels = [] imgs = [] for name in filenames: img = Image.open("./sample/test_samples/" + name).convert('L') img.thumbnail((28, 28)) img = np.array(img, dtype=np.float32) img = 1-np.array(img / 255) img = img.reshape(1, 784) res = sess.run(y, feed_dict={x: img})[0] res = res - res[np.argmin(res)] res = res / res[np.argmax(res)] subplot = fig.add_subplot(5, 4, 2 * c - 1) subplot.set_xticks([]) subplot.set_yticks([]) subplot.set_title('before:%d' % (c-1)) subplot.imshow(img.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest") subplot = fig.add_subplot(5, 4, 2*c) subplot.set_xticks([]) subplot.set_yticks([]) subplot.set_title('after:%d' % (c-1)) subplot.imshow(res.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest") c += 1
結果はこちらです。
9が苦手なのは一緒なんですね。
最後に
画像が分類できれば、それに対応するフォントを出すだけでいいのですが、画像の変換みたいなことが出来ないかなと思いやってみました。
mnistの画像に対しては上手くいくのですが、自前の画像だとなかなか上手くいかないときがありますね。
mnistが持つ何らかの特徴にオーバーフィッティングしているということでしょうか。
以前どこかで見た、絵画を有名画家風に変換するみたいなこともやってみたいですね。
以上、勉強の合間にちょっと遊んでみました。
読者登録をしていただけると、ブログを続ける励みになりますので、よろしくお願いします。