普段 chainer を使っているんですが、最近 ver 2.0 に大幅アップデートがあり、自分のコードも2.0に対応するさせるために雛形を作ってみたのでお試し更新。
とりあえず MNIST をやるとして、trainerで抽象化させすぎると使いにくいので、custom loopをベースに自分がハンドリングしやすいようにアレンジしてみる。
https://github.com/chainer/chainer/tree/master/examples/mnist
変更点
以下を参照してみると
Upgrade Guide from v1 to v2 — Chainer 2.0.0 documentation
僕のソースコードだと chainer 2.0 のクリティカルなアップデートは、dropout などの train フラグと Variable の volatile の廃止。
とりあえずネットワークは dropout を噛ませた以下にしてみる。
class MTNNet(chainer.Chain): def __init__(self, n_mid, n_out): super(MTNNet, self).__init__() with self.init_scope(): self.lin1 = L.Linear(None, n_mid) self.lin2 = L.Linear(None, n_out) def __call__(self, x): h1 = self.lin1(x) h2 = F.relu(h1) h3 = F.dropout(h2) y = self.lin2(h3) return y def loss(self, x, t): y = self(x) loss = F.softmax_cross_entropy(y,t) self.accuracy = F.accuracy(y,t) return y, loss
config でフラグが管理されており、上述の2点について chainer.config.enable_backprop、chainer.config.train がデフォルトで True になっている。つまり Variable の volatile=False、train=True の状態なので、学習時は気にせず実行すればいいらしい。
test 時は with でフラグを False にしてやる(抜けると元に戻る)。
with chainer.using_config('train', False), chainer.no_backprop_mode(): y, loss = self.model.loss(x, t)
こうすれば、volatile=True, train=False で実行できる。
古いソースコードをコピペで使っていたので、custom loop でも iterators をうまく使えばすっきりするのかとサンプルを見ながら。とりあえず止まらなければいいやでserializers とかを try で投げてたりするんですが、お気になさらず。
上がソースコード。
これを雛形にして今後遊べそうだなと思ったところで、今回はこんな感じ。
- 作者: 岡谷貴之
- 出版社/メーカー: 講談社
- 発売日: 2015/04/08
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (13件) を見る