webエンジニアの日常

RubyやPython, JSなど、IT関連の記事を書いています

初めてのTensorFlow入門~畳み込みニューラルネットワーク~

こんにちは、エンジニアのさもです。

前回は隠れ層を一層追加した、単層ニューラルネットワークを構築しました。

今回はいよいよ畳み込み層とプーリング層を追加して、畳み込みニューラルネットワークを構築していきたいと思います。

コードはこちらの書籍をお手本にしています。

TensorFlowで学ぶディープラーニング入門 ~畳み込みニューラルネットワーク徹底解説~

TensorFlowで学ぶディープラーニング入門 ~畳み込みニューラルネットワーク徹底解説~

スポンサーリンク

畳み込みニューラルネットワークの概要

なぜ畳み込みとかするのか

前回までの実装では、画像(画素値)を一列に並べた配列につて学習を行ってきました。

ですが、これだと画像の横の関連は学習できても、縦の関連は学習できません。

そこで、画像の形(二次配列)で情報を学習させる必要があります。

そこで編み出されたのが、畳み込みニューラルネットワークです。

畳み込みニューラルネットワークでは、画像そのものを学習させるのではなく、画像から抽出した特徴を使って学習を行います。

画像から特徴を抽出する処理が、畳み込み層の役割です。

畳み込み

畳み込みは、一言で言えば、フィルターです。

画像処理でぼかしフィルタや鮮鋭化フィルタがありますよね。あれのことです。

フィルタに関してはいくつか過去に記事を書いていたので、そちらを参照してみてください。

www.uosansatox.biz

画像処理関連の記事の目次はこちらに作りました。

www.uosansatox.biz

特徴を抽出させるだけなら、私たちでフィルターを考えて、特徴抽出した画像を前回使った単層ニューラルネットワークで学習させればよいです。

ですが、対象の画像を判別するのに最適なフィルタを毎回考えられるわけではありません。

そこで、フィルタも学習パラメータにしてしまおうというのが、畳み込みニューラルネットワークのアイデアです。

畳み込みの処理を言葉で説明すると以下のような感じになります。

* 画像より少し小さいサイズの二次配列を用意する<-これがフィルタ
# 画像の全ての画素について以下を行う
* 画素とその周辺(フィルタと同じサイズの小領域)とフィルタを掛け算する(同じ場所にある値同士)
* 掛け算した結果の合計を取る
* 新しい二次配列を用意し、対象としている画素とおんなじ場所に合計値を入れる

イメージとしてはこんな感じです。

f:id:s-uotani-zetakansu:20171005165930p:plain

畳み込みで使うフィルタ(これはカーネルという)は普通複数枚用意します。今回は16枚です。

TensorFlowでの流れはこんな感じになります。

f:id:s-uotani-zetakansu:20171005164644p:plain

プーリング

畳み込み層でめでたく特徴抽出することが出来ました。特徴抽出した結果の画像を特徴マップと言います。

プーリング層では、特徴が少しずれた位置にあっても正しく判定できるように、特徴の要約を行います。

具体的には、特徴マップを同じサイズの小領域に分けて、それぞれの領域の最大の値を集めた、新しい画像を作ります

f:id:s-uotani-zetakansu:20171005173442p:plain

プーリング層ではパラメータは無いので、処理のみになります。

f:id:s-uotani-zetakansu:20171005173531p:plain

実装

プーリング層の出力はまだ二次配列なので、これを1次配列へ変換します。

変換後の配列が、前回までの実装の入力になるわけです。

実装としては、前回の実装でいうと入力 Xの定義の前に、畳み込み層とプーリング層を追加することが主な変更点です。

ライブラリのインポート

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
%matplotlib inline

np.random.seed(20171001)
tf.set_random_seed(20171001)
mnist = input_data.read_data_sets("./data/", one_hot=True)

畳み込み層・プーリング層の定義

cfilter_size = 16
input_size = 14*14*cfilter_size

with tf.name_scope('Input'):
    x = tf.placeholder(tf.float32, [None, 784])
    x_image = tf.reshape(x, [-1,28,28,1])
with tf.name_scope('Conv'):
    cfilter = tf.Variable(tf.truncated_normal([5,5,1,cfilter_size],stddev=0.1))
    cfiltered = tf.nn.conv2d(x_image, cfilter, strides=[1,1,1,1], padding='SAME')
with tf.name_scope('Pooling'):
    pfiltered = tf.nn.max_pool(cfiltered, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
    pfiltered_flat = tf.reshape(pfiltered, [-1, input_size])

cfilter_sizeはフィルタの枚数です。

フィルタの大きさは、5×5にしています。

TensorFlowでは畳み込み処理が関数として用意されている(tf.nn.conv2d)ので、それを使います。

また、プーリング処理も、tf.nn.max_poolとして用意されています。最大値のほかに、小領域の平均を取る、平均値プーリングを使うことも出来ます。

プーリング層の最後で、二次配列から一次配列へ変形しています。

隠れ層以降の変数の定義

hidden_size = 1024

with tf.name_scope('Hidden'):
    w1 = tf.Variable(tf.truncated_normal([input_size, hidden_size]))
    b1 = tf.Variable(tf.zeros([hidden_size]))
    z = tf.nn.relu(tf.matmul(pfiltered_flat, w1) + b1)
with tf.name_scope('Output'):
    w0 = tf.Variable(tf.zeros([hidden_size, 10]))
    b0 = tf.Variable(tf.zeros([10]))
    p = tf.nn.softmax(tf.matmul(z, w0) + b0)

with tf.name_scope('T'):
    t = tf.placeholder(tf.float32, [None, 10])

損失関数・学習方法・正解率の定義

with tf.name_scope('Loss'):
    loss = -tf.reduce_sum(t * tf.log(p))
with tf.name_scope('Train'):
    train_step = tf.train.AdamOptimizer().minimize(loss)
with tf.name_scope('Acc'):
    correct_prediction =  tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

TensorBoardで表示するサマリの定義

sess = tf.InteractiveSession()
with tf.name_scope('summary'):
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('loss', loss)
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter('./logs', sess.graph)

学習

sess.run(tf.global_variables_initializer())
i = 0
for _ in range(2000):
    batch_xs, batch_ts = mnist.train.next_batch(100)
    __, summary = sess.run([train_step, merged], feed_dict={x: batch_xs, t: batch_ts})
    
    if i % 100 == 0:
        loss_val, acc_val = sess.run([loss, accuracy], feed_dict = {x: batch_xs, t: batch_ts})
        writer.add_summary(summary, _)
        print("step: %d, loss: %f, acc: %f, test: %f" % (i, loss_val, acc_val, test_acc))
    i += 1

学習を実行すると、最終的に学習データに対して、正解率100%, テストデータに対して98%という結果になりました。

自分で書いた文字の判別

from PIL import Image
import os
filenames = os.listdir('./sample/test_samples')
c = 1
fig = plt.figure(figsize=(6, 5))
labels = []
imgs = []
for name in filenames:
    img = Image.open("./sample/test_samples/" + name).convert('L')
    img.thumbnail((28, 28))
    img = np.array(img, dtype=np.float32)
    img = 1-np.array(img / 255)
    img = img.reshape(1, 784)
    imgs.append(img)
    label = np.array([0,0,0,0,0,0,0,0,0,0])
    label[c-1] = 1
    labels.append(label)
    test_p = sess.run(p, feed_dict={x: img})
    subplot = fig.add_subplot(2, 5, c)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.set_title('%d' % (np.argmax(test_p)))
    subplot.imshow(img.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest")
    c += 1
print(sess.run(accuracy, feed_dict={x: np.array(imgs).reshape((10, 784)), t: labels}))

以下実行結果です。

f:id:s-uotani-zetakansu:20171005175714p:plain

正解率がとうとう9割になりました。やっぱり9が苦手みたいですね。

ちなみに、こちらの9なら上手く判定してくれました。

f:id:s-uotani-zetakansu:20171005175803p:plain

畳み込みフィルタの可視化

畳み込み層で学習したフィルタも可視化してみます

fil = sess.run(cfilter).reshape([5,5,16]).transpose([2,0,1])
c = 1
fig = plt.figure(figsize=(8, 3))
for f in fil:
    flatten_f = f.reshape([25, -1])
    flatten_f = flatten_f - flatten_f[np.argmin(f)]
    flatten_f = flatten_f / flatten_f[np.argmax(flatten_f)]
    test_p = sess.run(p, feed_dict={x: img})
    subplot = fig.add_subplot(2, 8, c)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.set_title('%d' % c)
    subplot.imshow(flatten_f.reshape((5, 5)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest")
    c += 1

f:id:s-uotani-zetakansu:20171005180029p:plain

抽出された特徴の可視化

続いて、上記のフィルタをかけた画像がどうなるのか、見てみたいと思います

number = 3
filtered_img = sess.run(cfiltered, feed_dict={x: np.array([imgs[number]]).reshape((1, 784))}).reshape([28,28,16]).transpose([2,0,1])
c = 1
fig = plt.figure(figsize=(8, 3))
for f in filtered_img:
    flatten_f = f.reshape([784, -1])
    flatten_f = flatten_f - flatten_f[np.argmin(f)]
    flatten_f = flatten_f / flatten_f[np.argmax(flatten_f)]
    test_p = sess.run(p, feed_dict={x: img})
    subplot = fig.add_subplot(2, 8, c)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.set_title('%d' % c)
    subplot.imshow(flatten_f.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest")
    c += 1

f:id:s-uotani-zetakansu:20171005180242p:plain

9番、14番、16番あたりのフィルタが顕著ですが、全体的にエッジ抽出みたいなことをしてますね。

斜め方向のエッジを抽出しているのが多いように見えます。

最後に

とうとう畳み込みニューラルネットワークを実装してしまいました。

自分で書いた手書き文字をもう少し多くして実験してみたいですね。

学習時間は思ったほどかかりませんでした。

巷ではよく機械学習にはGPCが必須なんてことを耳にします(私はこれを聞いて二の足を踏んでました。)が、層が浅いうちはノートパソコンのスペックでも十分実験できました。

次回は、畳み込み層とプーリング層、全結合層をもう一層ずつ追加して、実験してみたいと思います。

まだ予定ですが、次々回以降、kaggleに挑戦するか、Kerasで実装してみるか、はたまたRNNを実装してみるか、どれかの方向に進みます。

読者登録をしていただけると、ブログを続ける励みになりますので、よろしくお願いします。