2016年9月28日水曜日

ChainerでMNISTするまで

1. なるべく最新のPythonを用意する
2. pip install chainerする
3. 以下を実行

import chainer
from chainer import functions as F
from chainer import links as L
from chainer import training
from chainer.training import extensions
BATCH_SIZE = 50
EPOCH = 10
class MLP(chainer.Chain):
def __init__(self):
super(MLP, self).__init__(
l1 = L.Linear(784, 1000),
l2 = L.Linear(1000, 1000),
l3 = L.Linear(1000, 10)
)
def __call__(self, x):
h = F.dropout(F.relu(self.l1(x)), ratio=0.3)
h = F.dropout(F.relu(self.l2(h)), ratio=0.3)
return self.l3(h)
model = L.Classifier(MLP())
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)
train, test = chainer.datasets.get_mnist()
train_iter = chainer.iterators.SerialIterator(train, batch_size=BATCH_SIZE)
test_iter = chainer.iterators.SerialIterator(test, batch_size=BATCH_SIZE, repeat=False, shuffle=False)
updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (EPOCH, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()
view raw short-mnist.py hosted with ❤ by GitHub


以前Chainer使ったときは学習の設定とかなんだかんだ面倒くさいところがあったけど、
MNISTはデータを直接取りに行ける + trainer使うと学習の部分を簡単にかけるので、こういうMNISTとかはかなり簡単にかけるようになった印象。

30行もあればMNISTの学習ができてしまうというのは、デモンストレーションとかでは非常に良いのでは?