webエンジニアさもの挑戦

RubyやPython, JavaScriptなど勉強したことなど、IT関連の記事を書いています

ビットコイン取引高日本一の仮想通貨取引所 coincheck bitcoin

TensorFlow入門~MNIST for ML Beginners~

こんにちは、エンジニアのさもです。

TensorFlowで機械学習に入門してみた第4回目は、いよいよMNISTに挑戦します。

内容的には、TensorFlowチュートリアルの「MNIST For ML Beginners」に相当します。

また、mnistのテストデータでは92%ほどの正解率になりますが、本当に上手く分類できるのか、実際に私が書いた文字を判定してもらおうと思います。

スポンサーリンク

目次

MNIST・実装の概要

MNISTは手書き文字(0~9)を正しい数字に分類する、ディープラーニングでは有名なチュートリアル的な問題です。

今回はまず、隠れ層無しで実装してみます。

なので、実装的には前回紹介したロジスティック回帰とほぼ同じです。

前回と大きく異なる部分は、

  • 学習対象のデータは、tensorflowが用意する画像データ
  • 多分類なので、シグモイドではなく、ソフトマックス関数を使う
  • ミニバッチ法を使う

の3点です。

画像データをつかう

画像データを対象に学習を進めますが、実際は既に配列になっているものを利用するので、学習だけであれば特に意識すべきところはありません。

後で学習データの画像を見たりする場合には注意が必要です。

画像の表示にはpythonの画像ライブラリPILLOWを使います。PILLOWの基本的な使い方は、こちらを参考にしてみてください。

www.uosansatox.biz

分類に用いる関数

前回は1か0の2分類だったので、1になる確率をシグモイド関数を用いて求めました。

今回は、入力画像が10個の正解値にそれぞれどれくらいの確率で分類されるかを知るために、出力値を確率へ変換します。そのために、次のソフトマックス関数を使います。


f_{a}(Y) = \dfrac{e^{y_{a}}}{\sum e^{y_{i}}}

 Yは出力値の配列で、入力値の配列を X,パラメータの行列を W, バイアスを bとすると、

 Y = XW + b

となります。小文字の y_{i} Yの各成分で、 iには0~9が入ります。

 f_{3}(Y)は、画像が3だと判定された確率を表します。

誤差関数は、前回と同じクロスエントロピー関数ですが、少し見た目が異なります。


loss = -\sum_{t=0}^{9} labels(t) * \log(f_{t}(Y))

 labelsは正解値を表す、大きさ10のone-hotベクトルです。

したがって、もし正解値が3の場合は、


loss = -log(f_{3}(Y))

となります。 f_{3}(Y)の値が1の場合は、 lossは0となり、誤差が無いことを表します。

逆に確率が低くなると、 lossの値はどんどん大きくなります。これは、正解値だと判定される確率が低い場合に誤差が大きくなるという意味になります。

ミニバッチ法

画像の枚数が多いときは、一気に学習させずに、ミニバッチという単位で何枚かずつに分けて学習を進めていきます。

ミニバッチ法には、最小値でない「極小値」にはまってしまったときに、動けなくなってしまうことを防ぐ(可能性がある)というメリットがあります。

実装

今回もTensorBoardを使って可視化してみるので、使ったことが無いという方はこちらの記事を参考にしてみてください。

必要なライブラリのインポート

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
%matplotlib inline

np.random.seed(20171001)

mnistデータのロード

mnist = input_data.read_data_sets("./data/", one_hot=True)

ダウンロードしてくるフォルダなどはお使いの環境に合わせてください

変数の定義

with tf.name_scope('X'):
    x = tf.placeholder(tf.float32, [None, 784])
with tf.name_scope('W'):
    w = tf.Variable(tf.zeros([784, 10]))
with tf.name_scope('W0'):
    w0 = tf.Variable(tf.zeros([10]))
with tf.name_scope('P'):
    f = tf.matmul(x, w) + w0
    p = tf.nn.softmax(f)
with tf.name_scope('T'):
    t = tf.placeholder(tf.float32, [None, 10])
sess = tf.InteractiveSession()

誤差、学習方法、正解率の定義

with tf.name_scope('Loss'):
    loss = -tf.reduce_sum(t * tf.log(p))
with tf.name_scope('Train'):
    train_step = tf.train.AdamOptimizer().minimize(loss)
with tf.name_scope('Acc'):
    correct_prediction =  tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

TensorBoardで追跡する変数を定義

with tf.name_scope('summary'):
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('loss', loss)
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter('./logs', sess.graph)

学習

sess.run(tf.global_variables_initializer())

i = 0
for _ in range(20000):
    batch_xs, batch_ts = mnist.train.next_batch(100)
    __, summary = sess.run([train_step, merged], feed_dict={x: batch_xs, t: batch_ts})
    writer.add_summary(summary, _)
    if i % 1000 == 0:
        loss_val, acc_val = sess.run([loss, accuracy], feed_dict = {x: batch_xs, t: batch_ts})
        writer.add_summary(summary, _)
        print("step: %d, loss: %f, acc: %f" % (i, loss_val, acc_val))
    i += 1

mnist.train.next_batchで、ミニバッチ法を実現しています。next_batchは、いくつまで進んでいるか覚えていてくれるので、ループをまわすたびに新しいデータを返します。

学習を進めると、最終的に正解率は93%となりました。ですが、これは学習データへの正解率なので、テストデータに対しての正解率を見てみたいと思います。

テストデータで正解率を見る

sess.run(accuracy, feed_dict={x: mnist.test.images, t: mnist.test.labels})

コードはこれだけです。placeholderのありがたみがよく分かりますね!

結果は92.8%でした。

自分で書いた手書き文字を判定してみる

まず、準備として、paintなどで大きさ28×28の手書き文字を用意しておきます。

フォルダは、ソースコードからみて、./sample/test_samples/の場所に置きました。

from PIL import Image
import os

# フォルダ内のファイルの一覧を取得
filenames = os.listdir('./sample/test_samples')
c = 1
fig = plt.figure(figsize=(6, 5))
labels = []
imgs = []
for name in filenames:
    img = Image.open("./sample/test_samples/" + name) # 画像をオープンする
    img = img.convert('L')                            # グレースケールに変換
    img.thumbnail((28, 28))                           # 一応リサイズ 
    img = np.array(img, dtype=np.float32)             # 配列に変換
    img = 1-np.array(img / 255)                       # 0~1の値になるように変換
    img = img.reshape(1, 784)                         # 一列に並べる
    imgs.append(img)                                  # バッチに突っ込む
    label = np.array([0,0,0,0,0,0,0,0,0,0])           # 正解ラベルの準備
    label[c-1] = 1                                    # ファイル名をtest0とかにしてたのでラベル順に並んでる
    labels.append(label)                              # バッチに突っ込む
    test_p = sess.run(p, feed_dict={x: img})          # 確率を計算する
    subplot = fig.add_subplot(2, 5, c)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.set_title('%d' % (np.argmax(test_p)))     # タイトルに予測した数字を表示
    subplot.imshow(img.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest")
    c += 1
# 正解率の表示
print(sess.run(accuracy, feed_dict={x: np.array(imgs).reshape((10, 784)), t: labels}))

これを実行すると、以下のような結果になりました。

f:id:s-uotani-zetakansu:20171003172245p:plain

正解率は6割でした。mnistの「クセ」みたいなのがあるのか、テストデータの92%までは程遠いですね。

実は、6割というのも結構良い方でして、何も考えずに画像を用意してみたところ、正解率は2割程度でした。

どこかコードがおかしいのかとはまってしまいました。

学習したパラメータを可視化してみる

最後に、学習したパラメータがどのような値になっているのか見てみたいと思います。

パラメータは、784行10列の行列ですが、各列はそれぞれ一つの数字(ラベル)の確率を計算するノードへつながっています。

そこで、各ノードへつながる784の要素を持つ配列を元に画像を生成しています。

w_val = sess.run(w)
w_val = w_val.reshape([10, 784])
fig = plt.figure(figsize=(10, 5))
c = 0
for w_raw in w_val:
    amax = np.argmax(w_raw)
    amin = np.argmin(w_raw)
    w_raw = w_raw - w_raw[amin]
    w_raw = w_raw / w_raw[amax]
    subplot = fig.add_subplot(2, 5, c+1)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.set_title('%d' % c)
    subplot.imshow(w_raw.reshape((28, 28)), vmin=0, vmax = 1, cmap=plt.cm.gray_r, interpolation="nearest")
    c += 1

以下が出力結果になります。

f:id:s-uotani-zetakansu:20171003172737p:plain

出してみたところでよく分からないですが、何らかのパターンを抽出しているように見えなくも無いです。(笑)

まとめ

少しずつ機械学習っぽくなってきましたね。

tensorflowの書き方については少し慣れてきたかなという感じです。

次回以降は、MNISTの精度をどんどん上げていきます。

以上、MNISTへ初めて挑戦してみた記事でした。

読者登録をしていただけると、ブログを続ける励みになりますので、よろしくお願いします。

このシリーズの記事一覧