Trainerを使ってみよう

Trainerを使うと学習ループを陽に書く必要がなくなります。またいろいろな便利なExtentionを使うことで可視化やログの保存などが楽になります。

[29]:
# Install Chainer and CuPy!

!curl https://colab.chainer.org/install | sh -
Reading package lists... Done
Building dependency tree
Reading state information... Done
libcusparse8.0 is already the newest version (8.0.61-1).
libnvrtc8.0 is already the newest version (8.0.61-1).
libnvtoolsext1 is already the newest version (8.0.61-1).
0 upgraded, 0 newly installed, 0 to remove and 1 not upgraded.
Requirement already satisfied: cupy-cuda80==4.0.0b3 from https://github.com/kmaehashi/chainer-colab/releases/download/2018-02-06/cupy_cuda80-4.0.0b3-cp36-cp36m-linux_x86_64.whl in /usr/local/lib/python3.6/dist-packages
Requirement already satisfied: numpy>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from cupy-cuda80==4.0.0b3)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from cupy-cuda80==4.0.0b3)
Requirement already satisfied: fastrlock>=0.3 in /usr/local/lib/python3.6/dist-packages (from cupy-cuda80==4.0.0b3)
Requirement already satisfied: chainer==4.0.0b3 in /usr/local/lib/python3.6/dist-packages
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from chainer==4.0.0b3)
Requirement already satisfied: numpy>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer==4.0.0b3)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer==4.0.0b3)
Requirement already satisfied: protobuf>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from chainer==4.0.0b3)
Requirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from protobuf>=3.0.0->chainer==4.0.0b3)

1. データセットの準備

[ ]:
from chainer.datasets import mnist

train, test = mnist.get_mnist()

2. Iteratorの準備

[ ]:
from chainer import iterators

batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize)
test_iter = iterators.SerialIterator(test, batchsize, False, False)

3. Modelの準備

ここでは、先程と同じモデルを再度用います。

[ ]:
import chainer
import chainer.links as L
import chainer.functions as F

class MLP(chainer.Chain):

    def __init__(self, n_mid_units=100, n_out=10):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1=L.Linear(None, n_mid_units)
            self.l2=L.Linear(None, n_mid_units)
            self.l3=L.Linear(None, n_out)


    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

gpu_id = 0  # Set to -1 if you don't have a GPU

model = MLP()
if gpu_id >= 0:
    model.to_gpu(gpu_id)

4. Updaterの準備

Trainerは学習に必要な全てのものをひとまとめにするクラスです。それは主に以下のようなものを保持します。

  • Updater
    • Iterator
      • Dataset
    • Optimizer
      • Model

Trainerオブジェクトを作成するときに渡すのは基本的にUpdaterだけですが、Updaterは中にIteratorOptimizerを持っています。Iteratorからはデータセットにアクセスすることができ、Optimizerは中でモデルへの参照を保持しているので、モデルのパラメータを更新することができます。つまり、Updaterが内部で

  1. データセットからデータを取り出し(Iterator)
  2. モデルに渡してロスを計算し(Model = Optimizer.target)
  3. Optimizerを使ってモデルのパラメータを更新する(Optimizer)

という一連の学習の主要部分を行うことができるということです。では、Updaterオブジェクトを作成してみます。

[ ]:
from chainer import optimizers
from chainer import training

max_epoch = 10

# モデルをClassifierで包んで、ロスの計算などをモデルに含める
model = L.Classifier(model)

if gpu_id >= 0:
    model.to_gpu(gpu_id)

# 最適化手法の選択
optimizer = optimizers.SGD()
optimizer.setup(model)

# UpdaterにIteratorとOptimizerを渡す
updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)

NOTE

ここで、上で定義したモデルのオブジェクトをL.Classifierに渡して、新しいChainにしています。L.ClassifierChainを継承したクラスで、渡されたChainpredictorというプロパティに保存します。()アクセサでデータとラベルを渡すと、中で__call__が実行され、まず渡されたデータの方をpredictorに通し、その出力yと同じく渡されていたラベルを、コンストラクタのlossfun引数で指定されたロス関数に渡して、その出力Variableを返します。lossfunはデフォルトでsoftmax_cross_entropyに指定されています。

StandardUpdaterは前述のUpdaterの行う処理を遂行する最もシンプルなクラスです。この他にも複数のGPUを用いるためのParallelUpdaterなどが用意されています。

5. Trainerの設定

最後に、Trainerの設定を行います。Trainerのオブジェクトを作成する際に必須となるのは、先程作成したUpdaterオブジェクトだけですが、二番目の引数stop_triggerに学習をどのタイミングで終了するかを表す(長さ, 単位)という形のタプルを与えると、指定したタイミングで学習を自動的に終了することができます。長さには任意の整数、単位には'epoch''iteration'のいずれかの文字列を指定できます。stop_triggerを指定しない場合、学習は自動的には止まりません。

[ ]:
# TrainerにUpdaterを渡す
trainer = training.Trainer(updater, (max_epoch, 'epoch'),
                           out='mnist_result')

out引数では、この次に説明するExtensionを使って、ログファイルやロスの変化の過程を描画したグラフの画像ファイルなどを保存するディレクトリを指定しています。

6. TrainerにExtensionを追加する

Trainerを使う利点として、

  • ログを自動的にファイルに保存(LogReport)
  • ターミナルに定期的にロスなどの情報を表示(PrintReport
  • ロスを定期的にグラフで可視化して画像として保存(PlotReport)
  • 定期的にモデルやOptimizerの状態を自動シリアライズ(snapshot/snapshot_object
  • 学習の進捗を示すプログレスバーを表示(ProgressBar
  • モデルの構造をGraphvizのdot形式で保存(dump_graph

などなどの様々な便利な機能を簡単に利用することができる点があります。これらの機能を利用するには、Trainerオブジェクトに対してextendメソッドを使って追加したいExtensionのオブジェクトを渡してやるだけです。では実際に幾つかのExtensionを追加してみましょう。

[ ]:
from chainer.training import extensions

trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
trainer.extend(extensions.dump_graph('main/loss'))

LogReport

epochiterationごとのloss, accuracyなどを自動的に集計し、Trainerout引数で指定した出力ディレクトリにlogというファイル名で保存します。

PrintReport

Reporterによって集計された値を標準出力に出力します。このときどの値を出力するかを、リストの形で与えます。

PlotReport

引数のリストで指定された値の変遷をmatplotlibライブラリを使ってグラフに描画し、出力ディレクトリにfile_name引数で指定されたファイル名で画像として保存します。

snapshot

Trainerout引数で指定した出力ディレクトリにTrainerオブジェクトを指定されたタイミング(デフォルトでは1エポックごと)に保存します。Trainerオブジェクトは上述のようにUpdaterを持っており、この中にOptimizerとモデルが保持されているため、このExtensionでスナップショットをとっておけば、学習の復帰や学習済みモデルを使った推論などが学習終了後にも可能になります。

snapshot_object

しかし、Trainerごと保存した場合、しばしば中身のモデルだけ取り出すのが面倒な場合があります。そこで、snapshot_objectを使って指定したオブジェクト(ここではClassifierで包まれたモデル)だけを、Trainerとは別に保存するようにします。Classifierは第1引数に渡されたChainオブジェクトを自身のpredictorというプロパティとして保持してロスの計算を行うChainであり、Classifierはそもそもモデル以外にパラメータを持たないので、ここでは後々学習済みモデルを推論に使うことを見越してmodel.predictorを保存対象として指定しています。

Evaluator

評価用のデータセットのIteratorと、学習に使うモデルのオブジェクトを渡しておくことで、学習中のモデルを指定されたタイミングで評価用データセットを用いて評価します。

dump_graph

指定されたVariableオブジェクトから辿れる計算グラフをGraphvizのdot形式で保存します。保存先はTrainerout引数で指定した出力ディレクトリです。


これらのExtensionは、ここで紹介した以外にも、例えばtriggerによって個別に作動するタイミングを指定できるなどのいくつかのオプションを持っており、より柔軟に組み合わせることができます。詳しくは公式のドキュメントを見てください:Trainer extensions

7. 学習を開始する

学習を開始するには、Trainerオブジェクトのメソッドrunを呼ぶだけです。

[36]:
trainer.run()
epoch       main/loss   main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
1           1.55444     0.617104       0.793979              0.818335                  2.89638
2           0.6137      0.843384       0.469892              0.873517                  6.23138
3           0.438097    0.88088        0.377211              0.895767                  9.54007
4           0.373241    0.896234       0.336282              0.904866                  13.0217
5           0.338238    0.904568       0.307831              0.912085                  16.3388
6           0.314541    0.910048       0.288398              0.918216                  19.7134
7           0.297155    0.915528       0.275047              0.921974                  23.0603
8           0.282538    0.920022       0.262142              0.924644                  26.3866
9           0.270557    0.922625       0.252849              0.927809                  29.772
10          0.259736    0.92544        0.244242              0.928501                  33.1618

保存されているロスのグラフを確認してみましょう。

[37]:
from IPython.display import Image
Image(filename='mnist_result/loss.png')
[37]:
../../../../_images/notebook_hands_on_chainer_begginers_hands_on_12_Try_Trainer_class_20_0.png

精度のグラフも見てみましょう。

[38]:
Image(filename='mnist_result/accuracy.png')
[38]:
../../../../_images/notebook_hands_on_chainer_begginers_hands_on_12_Try_Trainer_class_22_0.png

ついでに、dump_graphというExtensionが出力した計算グラフを、Graphvizを使って画像化して見てみましょう。

[39]:
!apt-get install graphviz -y
!dot -Tpng mnist_result/cg.dot -o mnist_result/cg.png
Reading package lists... Done
Building dependency tree
Reading state information... Done
graphviz is already the newest version (2.38.0-16ubuntu2).
0 upgraded, 0 newly installed, 0 to remove and 1 not upgraded.
[40]:
Image(filename='mnist_result/cg.png')
[40]:
../../../../_images/notebook_hands_on_chainer_begginers_hands_on_12_Try_Trainer_class_25_0.png

上から下へ向かって、データやパラメータがどのようなFunctionに渡されて計算が行われ、ロスが出力されたかが分かります。

8. 学習済みモデルで推論する

[41]:
import numpy as np
from chainer import serializers
from chainer.cuda import to_gpu
from chainer.cuda import to_cpu

model = MLP()
serializers.load_npz('mnist_result/model_epoch-10', model)

%matplotlib inline
import matplotlib.pyplot as plt

x, t = test[0]
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()
print('label:', t)

if gpu_id >= 0:
    model.to_gpu(gpu_id)
    x = to_gpu(x[None, ...])
    y = model(x)
    y = to_cpu(y.data)
else:
    x = x[None, ...]
    y = model(x)
    y = y.data

print('predicted_label:', y.argmax(axis=1)[0])
../../../../_images/notebook_hands_on_chainer_begginers_hands_on_12_Try_Trainer_class_28_0.png
label: 7
predicted_label: 7

無事正解できました。