构建卷积神经网络 (CNN) 识别手写数字

承接上一篇《利用 Softmax 回归识别手写数字》，本文一步步构建一个卷积神经网络 (CNN)，同样用于识别 MNIST 手写数字。该模型效果显著，在测试集上的准确率高达 99.31%。

完整代码

# encoding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix, classification_report
from tensorflow.examples.tutorials.mnist import input_data


def _weight_variable(shape):
    """Create a weight variable drawn from a truncated normal (stddev 0.1)."""
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))


def _bias_variable(size):
    """Create a bias variable of length ``size`` filled with the constant 0.1."""
    return tf.Variable(tf.constant(0.1, shape=[size]))


def _conv_relu(inputs, kernel, bias):
    """Apply a stride-1, SAME-padded 2-D convolution, add the bias, then ReLU."""
    convolved = tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(convolved + bias)


def _max_pool_2x2(inputs):
    """Apply 2x2 max pooling with stride 2 and SAME padding."""
    return tf.nn.max_pool(inputs, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def define_cnn(learning_rate=0.001):
    """Build a two-convolution-layer CNN classifier in a fresh TensorFlow graph.

    Layer stack: reshape to 28x28x1 -> conv 5x5x32 + ReLU -> 2x2 max pool ->
    conv 5x5x64 + ReLU -> 2x2 max pool -> fully connected 1024 + ReLU ->
    dropout -> fully connected 10 (logits).

    Args:
        learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        A list ``[graph, x, y, keep_prob, loss, train, predict, accuracy]``:
        the graph, the image placeholder (``(None, 784)``), the label
        placeholder (``(None, 10)``), the dropout keep-probability
        placeholder, the mean cross-entropy loss, the training op, the
        predicted class index per example, and the mean accuracy against
        the arg-max of ``y``.
    """
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=(None, 784), name='image')
        y = tf.placeholder(tf.float32, shape=(None, 10), name='label')

        # Flat 784-vectors become 28x28 single-channel images.
        with tf.name_scope('reshape'):
            images = tf.reshape(x, [-1, 28, 28, 1])

        with tf.name_scope('conv1'):
            kernel1 = _weight_variable([5, 5, 1, 32])
            bias1 = _bias_variable(32)
            conv1 = _conv_relu(images, kernel1, bias1)

        with tf.name_scope('pool1'):
            pool1 = _max_pool_2x2(conv1)

        with tf.name_scope('conv2'):
            kernel2 = _weight_variable([5, 5, 32, 64])
            bias2 = _bias_variable(64)
            conv2 = _conv_relu(pool1, kernel2, bias2)

        with tf.name_scope('pool2'):
            pool2 = _max_pool_2x2(conv2)

        # Two stride-2 pools shrink 28x28 to 7x7, with 64 channels.
        with tf.name_scope('fc1'):
            flat_size = 7 * 7 * 64
            fc1_weights = _weight_variable([flat_size, 1024])
            fc1_bias = _bias_variable(1024)
            flattened = tf.reshape(pool2, [-1, flat_size])
            fc1 = tf.nn.relu(tf.matmul(flattened, fc1_weights) + fc1_bias)

        # The keep probability is fed at run time by the caller.
        with tf.name_scope('dropout'):
            keep_prob = tf.placeholder(tf.float32)
            fc1_dropped = tf.nn.dropout(fc1, keep_prob)

        with tf.name_scope('fc2'):
            fc2_weights = _weight_variable([1024, 10])
            fc2_bias = _bias_variable(10)
            logits = tf.matmul(fc1_dropped, fc2_weights) + fc2_bias

        with tf.name_scope('loss'):
            per_example_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)
            loss = tf.reduce_mean(per_example_loss)

        with tf.name_scope('train'):
            train = tf.train.AdamOptimizer(learning_rate).minimize(loss)

        with tf.name_scope('predict'):
            probabilities = tf.nn.softmax(logits)
            predict = tf.argmax(probabilities, 1)

        with tf.name_scope('accuracy'):
            true_class = tf.argmax(y, 1)
            is_correct = tf.equal(predict, true_class)
            accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

    return [graph, x, y, keep_prob, loss, train, predict, accuracy]


def main(argv):
    """Train the CNN on MNIST and report its test-set performance.

    Reads its settings (data directory, learning rate, batch size, number of
    training steps, display interval) from the module-level ``FLAGS``. Every
    ``display_step`` training steps it prints the loss and accuracy on the
    full test set; after training it prints a classification report and a
    confusion matrix for the test set.

    Args:
        argv: Command-line arguments passed by the caller; not used here.
    """
    dataset = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    graph, x, y, keep_prob, loss, train, predict, accuracy = define_cnn(FLAGS.learning_rate)

    # Evaluation always runs with dropout disabled (keep probability 1.0).
    eval_feed = {x: dataset.test.images, y: dataset.test.labels, keep_prob: 1.0}
    progress_line = 'Batch: {}, loss: {:.5f}, accuracy: {:.2f}%'
    digits = range(10)

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())

        for step in range(1, FLAGS.train_step + 1):
            batch_images, batch_labels = dataset.train.next_batch(FLAGS.batch_size)
            # Training keeps each fc1 activation with probability 0.5.
            train_feed = {x: batch_images, y: batch_labels, keep_prob: 0.5}
            sess.run(train, feed_dict=train_feed)

            if step % FLAGS.display_step != 0:
                continue
            test_loss, test_accuracy = sess.run([loss, accuracy], feed_dict=eval_feed)
            print(progress_line.format(step, test_loss, 100 * test_accuracy))

        # One-hot test labels are turned into class indices for scikit-learn.
        expected = np.argmax(dataset.test.labels, 1)
        predicted = sess.run(predict, feed_dict=eval_feed)
        print(classification_report(expected, predicted, labels=digits))
        print(confusion_matrix(expected, predicted, labels=digits))


# Script entry point: configure logging, declare the command-line flags, then
# hand control to tf.app.run, which is given `main` as the function to run.
if __name__ == '__main__':
    # NOTE(review): presumably silences TensorFlow's C++ INFO/WARNING logs;
    # confirm it still takes effect when set after `import tensorflow`.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    tf.logging.set_verbosity(tf.logging.INFO)
    # Command-line flags with their default values; `main` reads them via FLAGS.
    tf.flags.DEFINE_string('data_dir', 'MNIST_data', 'input data directory')
    tf.flags.DEFINE_integer('batch_size', 50, 'batch size')
    tf.flags.DEFINE_integer('train_step', 20000, 'training steps')
    tf.flags.DEFINE_integer('display_step', 100, 'display step')
    # This default (1e-4) is what `main` passes to define_cnn, overriding
    # define_cnn's own default of 0.001.
    tf.flags.DEFINE_float('learning_rate', 1e-4, 'learning rate')
    # Module-level global: `main` looks up FLAGS by this name.
    FLAGS = tf.flags.FLAGS
    tf.app.run(main)

计算图

计算图

执行程序输出

Batch: 100, loss: 0.52098, accuracy: 83.99%
Batch: 200, loss: 0.32583, accuracy: 90.25%
Batch: 300, loss: 0.25579, accuracy: 92.39%
Batch: 400, loss: 0.21144, accuracy: 93.84%
Batch: 500, loss: 0.18473, accuracy: 94.77%
......
Batch: 19600, loss: 0.02193, accuracy: 99.27%
Batch: 19700, loss: 0.02318, accuracy: 99.22%
Batch: 19800, loss: 0.02183, accuracy: 99.25%
Batch: 19900, loss: 0.02433, accuracy: 99.24%
Batch: 20000, loss: 0.02342, accuracy: 99.31%
             precision    recall  f1-score   support

          0       0.99      1.00      0.99       980
          1       1.00      1.00      1.00      1135
          2       1.00      1.00      1.00      1032
          3       0.99      1.00      0.99      1010
          4       0.99      1.00      0.99       982
          5       0.99      0.99      0.99       892
          6       1.00      0.99      0.99       958
          7       0.99      0.99      0.99      1028
          8       0.99      0.99      0.99       974
          9       0.99      0.99      0.99      1009

avg / total       0.99      0.99      0.99     10000

[[ 977    0    0    0    0    0    1    1    1    0]
 [   0 1130    1    2    0    0    0    1    1    0]
 [   1    0 1029    0    0    0    0    2    0    0]
 [   0    0    0 1006    0    2    0    2    0    0]
 [   0    0    0    0  978    0    0    0    0    4]
 [   1    0    0    7    0  881    1    0    1    1]
 [   3    2    0    0    1    1  950    0    1    0]
 [   0    0    2    1    0    0    0 1021    1    3]
 [   2    0    2    2    0    1    0    2  962    3]
 [   1    0    0    1    5    3    0    1    1  997]]

参考文献