Trainerを使ってみよう¶
Trainer
を使うと学習ループを陽に書く必要がなくなります。またいろいろな便利なExtentionを使うことで可視化やログの保存などが楽になります。
[29]:
# Install Chainer and CuPy!
!curl https://colab.chainer.org/install | sh -
Reading package lists... Done
Building dependency tree
Reading state information... Done
libcusparse8.0 is already the newest version (8.0.61-1).
libnvrtc8.0 is already the newest version (8.0.61-1).
libnvtoolsext1 is already the newest version (8.0.61-1).
0 upgraded, 0 newly installed, 0 to remove and 1 not upgraded.
Requirement already satisfied: cupy-cuda80==4.0.0b3 from https://github.com/kmaehashi/chainer-colab/releases/download/2018-02-06/cupy_cuda80-4.0.0b3-cp36-cp36m-linux_x86_64.whl in /usr/local/lib/python3.6/dist-packages
Requirement already satisfied: numpy>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from cupy-cuda80==4.0.0b3)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from cupy-cuda80==4.0.0b3)
Requirement already satisfied: fastrlock>=0.3 in /usr/local/lib/python3.6/dist-packages (from cupy-cuda80==4.0.0b3)
Requirement already satisfied: chainer==4.0.0b3 in /usr/local/lib/python3.6/dist-packages
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from chainer==4.0.0b3)
Requirement already satisfied: numpy>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer==4.0.0b3)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer==4.0.0b3)
Requirement already satisfied: protobuf>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from chainer==4.0.0b3)
Requirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from protobuf>=3.0.0->chainer==4.0.0b3)
1. データセットの準備¶
[ ]:
from chainer.datasets import mnist
train, test = mnist.get_mnist()
2. Iteratorの準備¶
[ ]:
from chainer import iterators
batchsize = 128
train_iter = iterators.SerialIterator(train, batchsize)
test_iter = iterators.SerialIterator(test, batchsize, False, False)
3. Modelの準備¶
ここでは、先程と同じモデルを再度用います。
[ ]:
import chainer
import chainer.links as L
import chainer.functions as F
class MLP(chainer.Chain):
def __init__(self, n_mid_units=100, n_out=10):
super(MLP, self).__init__()
with self.init_scope():
self.l1=L.Linear(None, n_mid_units)
self.l2=L.Linear(None, n_mid_units)
self.l3=L.Linear(None, n_out)
def __call__(self, x):
h1 = F.relu(self.l1(x))
h2 = F.relu(self.l2(h1))
return self.l3(h2)
gpu_id = 0 # Set to -1 if you don't have a GPU
model = MLP()
if gpu_id >= 0:
model.to_gpu(gpu_id)
4. Updaterの準備¶
Trainerは学習に必要な全てのものをひとまとめにするクラスです。それは主に以下のようなものを保持します。
- Updater
- Iterator
- Dataset
- Optimizer
- Model
- Iterator
Trainer
オブジェクトを作成するときに渡すのは基本的にUpdater
だけですが、Updater
は中にIterator
とOptimizer
を持っています。Iterator
からはデータセットにアクセスすることができ、Optimizer
は中でモデルへの参照を保持しているので、モデルのパラメータを更新することができます。つまり、Updater
が内部で
- データセットからデータを取り出し(Iterator)
- モデルに渡してロスを計算し(Model = Optimizer.target)
- Optimizerを使ってモデルのパラメータを更新する(Optimizer)
という一連の学習の主要部分を行うことができるということです。では、Updater
オブジェクトを作成してみます。
[ ]:
from chainer import optimizers
from chainer import training
max_epoch = 10
# モデルをClassifierで包んで、ロスの計算などをモデルに含める
model = L.Classifier(model)
if gpu_id >= 0:
model.to_gpu(gpu_id)
# 最適化手法の選択
optimizer = optimizers.SGD()
optimizer.setup(model)
# UpdaterにIteratorとOptimizerを渡す
updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)
NOTE¶
ここで、上で定義したモデルのオブジェクトをL.Classifier
に渡して、新しいChain
にしています。L.Classifier
はChain
を継承したクラスで、渡されたChain
をpredictor
というプロパティに保存します。()
アクセサでデータとラベルを渡すと、中で__call__
が実行され、まず渡されたデータの方をpredictor
に通し、その出力y
と同じく渡されていたラベルを、コンストラクタのlossfun
引数で指定されたロス関数に渡して、その出力Variable
を返します。lossfun
はデフォルトでsoftmax_cross_entropy
に指定されています。
StandardUpdater
は前述のUpdaterの行う処理を遂行する最もシンプルなクラスです。この他にも複数のGPUを用いるためのParallelUpdater
などが用意されています。
5. Trainerの設定¶
最後に、Trainer
の設定を行います。Trainer
のオブジェクトを作成する際に必須となるのは、先程作成したUpdater
オブジェクトだけですが、二番目の引数stop_trigger
に学習をどのタイミングで終了するかを表す(長さ, 単位)
という形のタプルを与えると、指定したタイミングで学習を自動的に終了することができます。長さには任意の整数、単位には'epoch'
か'iteration'
のいずれかの文字列を指定できます。stop_trigger
を指定しない場合、学習は自動的には止まりません。
[ ]:
# TrainerにUpdaterを渡す
trainer = training.Trainer(updater, (max_epoch, 'epoch'),
out='mnist_result')
out
引数では、この次に説明するExtension
を使って、ログファイルやロスの変化の過程を描画したグラフの画像ファイルなどを保存するディレクトリを指定しています。
6. TrainerにExtensionを追加する¶
Trainer
を使う利点として、
- ログを自動的にファイルに保存(
LogReport
) - ターミナルに定期的にロスなどの情報を表示(
PrintReport
) - ロスを定期的にグラフで可視化して画像として保存(
PlotReport
) - 定期的にモデルやOptimizerの状態を自動シリアライズ(
snapshot
/snapshot_object
) - 学習の進捗を示すプログレスバーを表示(
ProgressBar
) - モデルの構造をGraphvizのdot形式で保存(
dump_graph
)
などなどの様々な便利な機能を簡単に利用することができる点があります。これらの機能を利用するには、Trainer
オブジェクトに対してextend
メソッドを使って追加したいExtension
のオブジェクトを渡してやるだけです。では実際に幾つかのExtension
を追加してみましょう。
[ ]:
from chainer.training import extensions
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
trainer.extend(extensions.dump_graph('main/loss'))
LogReport
¶
epoch
やiteration
ごとのloss
, accuracy
などを自動的に集計し、Trainer
のout
引数で指定した出力ディレクトリにlog
というファイル名で保存します。
PrintReport
¶
Reporter
によって集計された値を標準出力に出力します。このときどの値を出力するかを、リストの形で与えます。
PlotReport
¶
引数のリストで指定された値の変遷をmatplotlib
ライブラリを使ってグラフに描画し、出力ディレクトリにfile_name
引数で指定されたファイル名で画像として保存します。
snapshot
¶
Trainer
のout
引数で指定した出力ディレクトリにTrainer
オブジェクトを指定されたタイミング(デフォルトでは1エポックごと)に保存します。Trainer
オブジェクトは上述のようにUpdater
を持っており、この中にOptimizer
とモデルが保持されているため、このExtension
でスナップショットをとっておけば、学習の復帰や学習済みモデルを使った推論などが学習終了後にも可能になります。
snapshot_object
¶
しかし、Trainer
ごと保存した場合、しばしば中身のモデルだけ取り出すのが面倒な場合があります。そこで、snapshot_object
を使って指定したオブジェクト(ここではClassifier
で包まれたモデル)だけを、Trainer
とは別に保存するようにします。Classifier
は第1引数に渡されたChain
オブジェクトを自身のpredictor
というプロパティとして保持してロスの計算を行うChain
であり、Classifier
はそもそもモデル以外にパラメータを持たないので、ここでは後々学習済みモデルを推論に使うことを見越してmodel.predictor
を保存対象として指定しています。
Evaluator
¶
評価用のデータセットのIterator
と、学習に使うモデルのオブジェクトを渡しておくことで、学習中のモデルを指定されたタイミングで評価用データセットを用いて評価します。
dump_graph
¶
指定されたVariable
オブジェクトから辿れる計算グラフをGraphvizのdot形式で保存します。保存先はTrainer
のout
引数で指定した出力ディレクトリです。
これらのExtension
は、ここで紹介した以外にも、例えばtrigger
によって個別に作動するタイミングを指定できるなどのいくつかのオプションを持っており、より柔軟に組み合わせることができます。詳しくは公式のドキュメントを見てください:Trainer extensions。
7. 学習を開始する¶
学習を開始するには、Trainer
オブジェクトのメソッドrun
を呼ぶだけです。
[36]:
trainer.run()
epoch main/loss main/accuracy validation/main/loss validation/main/accuracy elapsed_time
1 1.55444 0.617104 0.793979 0.818335 2.89638
2 0.6137 0.843384 0.469892 0.873517 6.23138
3 0.438097 0.88088 0.377211 0.895767 9.54007
4 0.373241 0.896234 0.336282 0.904866 13.0217
5 0.338238 0.904568 0.307831 0.912085 16.3388
6 0.314541 0.910048 0.288398 0.918216 19.7134
7 0.297155 0.915528 0.275047 0.921974 23.0603
8 0.282538 0.920022 0.262142 0.924644 26.3866
9 0.270557 0.922625 0.252849 0.927809 29.772
10 0.259736 0.92544 0.244242 0.928501 33.1618
保存されているロスのグラフを確認してみましょう。
[37]:
from IPython.display import Image
Image(filename='mnist_result/loss.png')
[37]:

精度のグラフも見てみましょう。
[38]:
Image(filename='mnist_result/accuracy.png')
[38]:

ついでに、dump_graph
というExtension
が出力した計算グラフを、Graphviz
を使って画像化して見てみましょう。
[39]:
!apt-get install graphviz -y
!dot -Tpng mnist_result/cg.dot -o mnist_result/cg.png
Reading package lists... Done
Building dependency tree
Reading state information... Done
graphviz is already the newest version (2.38.0-16ubuntu2).
0 upgraded, 0 newly installed, 0 to remove and 1 not upgraded.
[40]:
Image(filename='mnist_result/cg.png')
[40]:

上から下へ向かって、データやパラメータがどのようなFunction
に渡されて計算が行われ、ロスが出力されたかが分かります。
8. 学習済みモデルで推論する¶
[41]:
import numpy as np
from chainer import serializers
from chainer.cuda import to_gpu
from chainer.cuda import to_cpu
model = MLP()
serializers.load_npz('mnist_result/model_epoch-10', model)
%matplotlib inline
import matplotlib.pyplot as plt
x, t = test[0]
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()
print('label:', t)
if gpu_id >= 0:
model.to_gpu(gpu_id)
x = to_gpu(x[None, ...])
y = model(x)
y = to_cpu(y.data)
else:
x = x[None, ...]
y = model(x)
y = y.data
print('predicted_label:', y.argmax(axis=1)[0])

label: 7
predicted_label: 7
無事正解できました。