2. pip install chainerする
3. 以下を実行
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import chainer | |
from chainer import functions as F | |
from chainer import links as L | |
from chainer import training | |
from chainer.training import extensions | |
BATCH_SIZE = 50 | |
EPOCH = 10 | |
class MLP(chainer.Chain): | |
def __init__(self): | |
super(MLP, self).__init__( | |
l1 = L.Linear(784, 1000), | |
l2 = L.Linear(1000, 1000), | |
l3 = L.Linear(1000, 10) | |
) | |
def __call__(self, x): | |
h = F.dropout(F.relu(self.l1(x)), ratio=0.3) | |
h = F.dropout(F.relu(self.l2(h)), ratio=0.3) | |
return self.l3(h) | |
model = L.Classifier(MLP()) | |
optimizer = chainer.optimizers.Adam() | |
optimizer.setup(model) | |
train, test = chainer.datasets.get_mnist() | |
train_iter = chainer.iterators.SerialIterator(train, batch_size=BATCH_SIZE) | |
test_iter = chainer.iterators.SerialIterator(test, batch_size=BATCH_SIZE, repeat=False, shuffle=False) | |
updater = training.StandardUpdater(train_iter, optimizer) | |
trainer = training.Trainer(updater, (EPOCH, 'epoch'), out='result') | |
trainer.extend(extensions.Evaluator(test_iter, model)) | |
trainer.extend(extensions.LogReport()) | |
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy'])) | |
trainer.extend(extensions.ProgressBar()) | |
trainer.run() |
以前Chainer使ったときは学習の設定とかなんだかんだ面倒くさいところがあったけど、
MNISTはデータを直接取りに行ける + trainer使うと学習の部分を簡単にかけるので、こういうMNISTとかはかなり簡単にかけるようになった印象。
30行もあればMNISTの学習ができてしまうというのは、デモンストレーションとかでは非常に良いのでは?
0 件のコメント:
コメントを投稿