TensorFlow でのデータ入力方法まとめ ②

あゆた Machine Learning 担当の佐々木です。

前回 に引き続き、 今日も Tensorflow のモデルにデータを食わせる手法についてのお話です。 前回の記事で基本的なやり方を一通りおさえたので、今回は feed_dictQueue を用いたハイブリッドなアプローチに ついて書いていこうと思います。

…という訳で以下、feed_dict を 使って Queue にデータを入れていく仕組みを実装していきます。 この記事も基本的にはこの記事 に書いてあることの要約 + 解説なので、英語が読める方は本家の方を御覧ください。


はじめに

前回の記事でも Queue を用いた実装を行いましたが、エンキューをシングルプロセスで行っていたので実はあまり速くない? ものになっていました。なので今回はその点を改善すべく、エンキューの処理をマルチスレッド並列化でやっていこうと思います。 また、エンキューのオペレーションには feed_dict を用います。


共通なところ

まずはいつもと変わらない、おなじみのところのコードを晒しておきます。
ここはほぼ前回の記事で作成したものと同じなので、そちらをご覧になられた方はスキップして頂いて大丈夫です。

def get_mnist_data():  
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    return mnist

class Model(object):  
    def __init__(self, custom_runner):
        self.custom_runner = custom_runner  # CustomRunner class's instance


    # --- Define utility functions ----
    # Weight and biases
    def weight_variable(self, shape):
        initial = tf.truncated_normal(shape, stddev= 0.1)
        return tf.Variable(initial)

    def bias_variable(self, shape):
        initial = tf.constant(0.1, shape= shape)
        return tf.Variable(initial)

    # Convolution and pooling
    def conv2d(self, input_, filter_):
        return tf.nn.conv2d(input_, filter_, strides=[1,1,1,1], padding= 'SAME')

    def max_pool_2x2(self, input_):
        return tf.nn.max_pool(input_, ksize= [1,2,2,1], strides= [1,2,2,1], padding= 'SAME')


    # ---- Get Tensorflow equations ----
    def model(self):
        # ---- Define feeding data operations ----
        -- !! ここはあとで書く !! --

        # ---- Get model ----
        model_y = self.cnn(x)

        # ---- Train and evaluate the model ----
        cross_entropy = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(logits= model_y, labels= y) )
        train_op      = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

        # ---- Run tensorflow ----
        config = tf.ConfigProto( intra_op_parallelism_threads= NUM_THREADS )

        sess = tf.Session( config= config )
        init = tf.global_variables_initializer()
        sess.run( init )

        return sess, train_op


    # ---- Convolutional Neural Network ----
    def cnn(self, x_image):

        # First convolutional layer
        w_conv1 = self.weight_variable([5, 5, 1, 32])
        b_conv1 = self.bias_variable([32])

        x_image = tf.reshape(x_image, [-1, 28, 28, 1])
        h_conv1 = tf.nn.relu( self.conv2d(x_image, w_conv1) + b_conv1 )
        h_pool1 = self.max_pool_2x2(h_conv1)

        # Second convolutional layer
        w_conv2 = self.weight_variable([5, 5, 32, 64])
        b_conv2 = self.bias_variable([64])

        h_conv2 = tf.nn.relu( self.conv2d(h_pool1, w_conv2) + b_conv2 )
        h_pool2 = self.max_pool_2x2(h_conv2)

        # Density connected layer
        w_fc1 = self.weight_variable([7 * 7 * 64, 1024])
        b_fc1 = self.bias_variable([1024])

        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
        h_fc1        = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

        # Readout layer
        w_fc2 = self.weight_variable([1024, 10])
        b_fc2 = self.bias_variable([10])

        model_y = tf.matmul(h_fc1, w_fc2) + b_fc2

        return model_y

データ入力に関する部分

さて、いつもと変わらない部分を見ていただいたうえで、今度はデータ入力に関わる部分を実装していきます。 本記事のキモになる部分ですね!
以下、一連のコードは CustomRunner クラスの中に記述していきます。

''' Custom Queue Runners  
'''  
class CustomRunner(object):  
    ''' This class manages the background threads needed to fill
        a queue full of data.
    '''
    def __init__(self, mnist):
        # Queue で一度に取り出すデータ件数の単位。batch_size 個づつ取り出していく。
        self.batch_size   = 128

        # MNIST data
        self.train_images = mnist.train.images
        self.train_labels = mnist.train.labels
        self.test_images  = mnist.test.images
        self.test_labels  = mnist.test.labels

        # Define placeholder
        self.dataX = tf.placeholder( dtype= tf.float32, shape= [None, 784] )
        self.dataY = tf.placeholder( dtype= tf.int64,   shape= [None,  10] )

        # Queue の定義
        self.queue = tf.RandomShuffleQueue( shapes= [ [784], [10] ],
                                            dtypes= [ tf.float32, tf.int64 ],
                                            capacity= 20000,
                                            min_after_dequeue= 1000 )

        # EnQueue のオペレーション
        self.enqueue_op = self.queue.enqueue_many( [self.dataX, self.dataY] )
tf.RandomShuffleQueue

・shapes
Queue に入れるデータの構造を定義します。ここでは MNIST の画像データと、10クラス分類のラベルを保持するので、[ [784], [10] ] になっています。

・dtypes
shapes で指定した各データの型を格納します。

・capacity
Queue に入るデータ件数の上限

・min_after_dequeue
DeQueue した後に残る最小のデータ件数

queue.enqueue_many()

これは tf.RandomShuffleQueue オブジェクトのメソッドです。
enqueue_many() を使うと、リストで指定した各データを複数個一度にエンキューできます。 複数個一度にとは、len(self.dataX) が != 1 でも動くよってことです。


入力用データのジェネレータを作成

enqueue_many() へ入れるデータを生成するため、呼び出す度にエンキュー 1回分のデータを生成するジェネレータを作ります。

def data_iterator(self):  
    ''' A simple data iterator '''
    batch_idx = 0
    while True:
        # shuffle labels and features
        idxs = np.arange(0, len(self.train_images))
        np.random.shuffle(idxs)
        shuf_features = self.train_images[ idxs ]
        shuf_labels   = self.train_labels[ idxs ]

        for batch_idx in range(0, len(self.train_images), self.batch_size):
            images_batch = shuf_features[ batch_idx : batch_idx + self.batch_size ] / 255
            images_batch = images_batch.astype('float32')
            labels_batch = shuf_labels[ batch_idx : batch_idx + self.batch_size ]
            yield images_batch, labels_batch


マルチスレッドでデータを突っ込む

マルチスレッドでエンキューする部分を作っていきます。
まずは各スレッドで動作させるメソッドを作ります。

def thread_main(self, sess):  
    ''' Function run on alternate thread. Basically, keep adding data to the queue.
    '''
    for dataX, dataY in self.data_iterator():
        sess.run(self.enqueue_op, feed_dict= { self.dataX:dataX, self.dataY:dataY })

data_iterator() が無限ループでデータを生成し続けるので、生成されたデータを随時 enqueue_op していくだけです。

def start_threads(self, sess, n_threads= 4):  
    ''' Start background threads to read queue
    '''
    threads = []
    for n in range(n_threads):
        t = threading.Thread( target= self.thread_main, args= (sess,) )
        t.daemon = True  # thread will close when parent quits
        t.start()
        threads.append(t)
    return threads

python のマルチスレッドは上記のように書けばいけます。

標準で用意されている threading モジュールを import して使います。
threading.Thread() の第一引数には起動したスレッドで実行するメソッドを、第二引数にはそのメソッドに渡す引数をタプルで 渡します。daemon = True とすると、親プロセスが終了した時に自動でスレッドを閉じてくれます。

キューからの取り出し
def get_inputs(self):  
    ''' Return's tensors containing batch of images and labels
    '''
    images_batch, labels_batch = self.queue.dequeue_many( self.batch_size )
    return images_batch, labels_batch

キューからデータを取り出す際は、上記のようにすればおkです。
デキューのメソッドは Tensorflow の tf.RandomShuffleQueue クラスで用意されているので、それを利用するだけです。 dequeue_many() を用いると引数で入れたバッチサイズ個のデータを一度に取得してきます。


モデルの変更

def model(self):  
    # ---- Define feeding data operations ----
    with tf.device('/cpu:0'):
        x, y = self.custom_runner.get_inputs()

    # ---- Get model ----
    model_y = self.cnn(x)

    -- 以下略 --

ここまででデータの入出力を行う一連のメソッドは用意できたので、あとはこれを呼び出して利用する部分を作成していきます。 …といっても、ただデキューのメソッドを呼び出して、データを取り出すだけです。GPU のリソースを計算に集中させるために、 with tf.device('/cpu:0') で Queue の操作に関する処理は CPU で行うよう指定しておくのがポイントです。


動かしてみる

おなじみのところ
# MNIST data
mnist = get_mnist_data()

# Define tendorflow formula
custom_runner  = CustomRunner( mnist )  
model          = Model( custom_runner )  
sess, train_op = model.model()  

MNIST データの取得から Tensorflow の式を取得するまで。


学習を実施
# start QueueRunner related methods
tf.train.start_queue_runners(sess= sess)  
custom_runner.start_threads(sess, n_threads= 8)

# Training
for i in tqdm( range(20000) ):  
    sess.run( train_op )

tf.train.start_queue_runners() でキューを実行し、それからマルチスレッドでエンキューするメソッドを起動します。 ここでのポイントは tf.train.start_queue_runners() を先に起動しておくことです。キューが動いていない状態で エンキューするとエラーが出るので、ここでは順番をきちんと守りましょう。

for 文のところで使っている tqdm はプログレスバーを表示するライブラリで、1秒あたりのループ数も表示されるので学習がどの程度の速さで実行されているのか把握するのに役立ちます。使用する際は pip install tqdm し、from tqdm import tqdm すればおkです。

tf.train.start_queue_runners() を起動しておけば、あとは学習のオペレーション(train_op) を sess.run() するだけです。


参考文献

[1] https://indico.io/blog/tensorflow-data-input-part2-extensions/
[2] https://intheweb.io/tensorflow-denodetaru-li-fang-fa-matome-2/