前々からすごいと思ってたハミルトニアン降下法 (Hamiltonian Descent Methods) だけど、最近さらに解説・実装も充実してきたので自分も検証してみようと思った次第。
[1809.05042] Hamiltonian Descent Methods
まえがき
例えばディープラーニングで推論モデルを構築する場合、関数が複雑すぎて最適なパラメータを得るのが非常に難しいという問題が存在する。そんな時、関数の勾配を使って最適な値を探索するのが一般的 (つまり二次関数なんかを考えれば簡単で、最小値を求めたければ傾いてる方向に転がっていけばいい)。ハミルトニアン降下法の面白いところは、この関数コロコロを運動方程式使えばええやんという「お、おぅ…」な手法という(ざっくり)。
例えば関数 \(f\) を最小化する最適パラメータ \(x_*\) とは \(x_* = \mathrm{argmin} \left[ f(x) \right]\) である。ディープラーニングでは、損失関数を定義して推論と正解の差を \(f\) にして \(\theta = x\) を最小化したりする。Maddison et al. の実験では、普通の勾配法と比べて、すさまじい勢いで最適値に降下していく様子が確認できる。
今回構築された手法は数学的にも非常によく考察されていて、そこを非常に分かりやすく解説している方がいらっしゃるので、自分は実装に必要な算数を紹介するに留めようと思う。(この記事は本当に素晴らしいので一読の価値アリ)
ざっくり理論の話
自分は(一応)物理出身なのでハミルトニアンを使った解析は馴染み深いのだけど、厳密さより分かりやすさ重視でいく。
ハミルトニアンとはその系の全エネルギーを記述したもの、つまりこの場合、運動エネルギー + 位置エネルギー(ポテンシャル)である。今回の場合、位置エネルギーが \(f(x)\) である。位置エネルギーの話を思い出すと、ゼロ点を決めないといけないのでこの場合 \(f(x_*)\) となる。
運動エネルギーは運動量(モーメンタム) \(p\) で記述される \(k(p)\) となる。 古典力学の場合、運動量は \(mv\) 、運動エネルギーは \(k(p) = 1/2mv^2 = p^2/2m\) である。\(f, k\) を使ってハミルトニアン \(H\) を書き下すと次のようになる。
$$\begin{equation}
H(x, p) = k(p) + f(x) – f(x_*) \tag{1}
\end{equation}$$
ハミルトニアンを使うと、運動方程式は次のように非常にわかりやすいフォームとなる。
$$\begin{align}
\dot{x} &= \frac{\partial H}{\partial p} \tag{2} \\
\dot{p} &= -\frac{\partial H}{\partial x} \tag{3}
\end{align}$$
ここでドットは時間微分を表す。古典力学で見ると、\(\dot{x} = p/m = v\)、例えば自由落下だと \(\dot{p} = -mg\) という馴染み深い式を意味していることが分かる。
そろそろ実装に向けた定式化を。タイムステップを刻みながらコロコロ転がすので時刻 \(t\) でのパラメータ \(x_t\) 、モーメンタム \(p_t\) とすると、ハミルトニアンの運動方程式を以下のように書くことができる。
$$\begin{align}
\dot{x_t} &= \nabla k(p_t) \tag{4} \\
\dot{p_t} &= -\nabla f(x_t) \tag{5}
\end{align}$$
運動エネルギーの微分(転がる勢い)の方向にパラメータを更新し、ポテンシャルの微分(関数の傾き)の方向にモーメンタムを変化させるという非常にシンプルな式である。
ところでこのままだと “エネルギー保存の法則” が働き、落ちない(笑)。そこで散逸項、つまり摩擦や空気抵抗みたいなやつを導入して、モーメンタムが強いほど速度を抑制する。その係数を \(\gamma\) とするとモーメンタムの更新式は以下となる。
$$\begin{equation}
\dot{p_t} = -\nabla f(x_t) – \gamma p_t \tag{6}
\end{equation}$$
こうすることでグルグル回るだけだったハミルトニアンが最小値に落ちていくようになる (Conformal Hamiltonian Field)。
アルゴリズムの話
普通ならこれを積分すれば終わりでいいのだが、関数が複雑すぎてある程度近似しながら離散的に評価していく必要がある。ここでタイムステップ \(\Delta t = \epsilon\) とすると更新式(傾き×刻み幅を足す積分)は、
$$\begin{align}
x_{t+1} &= x_t + \epsilon \nabla k(p_{t+1}) \tag{7} \\
p_{t+1} &= p_t + \epsilon \left[ -\nabla f(x_{t+1}) – \gamma p_{t+1} \tag{8} \right]
\end{align}$$
これが Implicit Method (IM) になる。あらかじめ転がる先( \(t+1\) )のポテンシャルを予見しながら更新するので、最も汎用性が高い。その代わりその予見に最適化が必要なので、計算が重くなることは想像に難くない。
そこで2つの Explicit Method はそれぞれ仮定を入れて計算を簡単にする。First Explicit Method (FEM) は、 \(x_t\) で一度止めて評価し、一方 Second Explicit Method (SEM) は \(p_t\) を止めて評価する。したがって、FEM は最適化する関数が急激に変化しないことを要請し、SEM はモーメンタムが急激に変化しないことを要請する。まとめるとこんな感じ
- IM : \((x_{t+1}, p_{t+1})\) で評価
- FEM : \((x_{t}, p_{t+1})\) で評価 ( \(\nabla f\) が緩やか)
- SEM : \((x_{t+1}, p_{t})\) で評価 ( \(\nabla k\) が緩やか)
簡単な式変形をしておく。式(8) を \(p_{t+1}\) について整理すると、
$$\begin{align}
(1+\gamma \epsilon) p_{t+1} &= p_t – \epsilon \nabla f(x_{t+1}) \\
p_{t+1} &= \delta p_t – \epsilon \delta \nabla f(x_{t+1}) \tag{9}
\end{align}$$
ここで、 \(\delta = 1/(1+\gamma \epsilon)\) である。
FEM は \((x_{t}, p_{t+1})\) で評価するので、 \(x\) の更新は式(7)そのままで、 \(p\) の更新式は式(9)から以下のようになる。
$$\begin{equation}
p_{t+1} = \delta p_t – \epsilon \delta \nabla f(x_{t}) \tag{10}
\end{equation}$$
逆に SEM は \((x_{t+1}, p_{t})\) で評価するので、 \(x\) の更新式は式(7)の右辺が \(p_t\) となり、 \(p\) の更新式は次のように求まる。
$$\begin{align}
p_{t+1} &= p_t + \epsilon \left[ -\nabla f(x_{t+1}) – \gamma p_{t} \right] \\
p_{t+1} &= (1-\epsilon \gamma) p_t – \epsilon \nabla f(x_{t+1}) \tag{11}
\end{align}$$
まとめると以下のようになる。
$${\displaystyle
\begin{eqnarray}
\left\{
\begin{array}{l}
x_{t+1; \mathrm{FEM}} &= x_t + \epsilon \nabla k(p_{t+1}) \\
p_{t+1; \mathrm{FEM}} &= \delta p_t – \epsilon \delta \nabla f(x_{t}) \tag{12}
\end{array}
\right.
\end{eqnarray}
}$$
$${\displaystyle
\begin{eqnarray}
\left\{
\begin{array}{l}
x_{t+1; \mathrm{SEM}} &= x_t + \epsilon \nabla k(p_{t}) \\
p_{t+1; \mathrm{SEM}} &= (1-\epsilon \gamma) p_t – \epsilon \nabla f(x_{t+1}) \tag{13}
\end{array}
\right.
\end{eqnarray}
}$$
この論文の本質は、実は非常にシンプルである。普通の勾配法であれば、関数の微分で直接パラメータが更新されるところを、モーメンタムを更新し、またその勾配でパラメータを更新するということが起こっている。
運動エネルギーについて
\(f\) は最適化する関数だが、モーメンタムを決める \(k\) には任意性がある。論文中では、 \(f\) を冪乗関数と仮定して最適な \(k\) を求めていた(詳しい解説については他に任せる)。関数が複雑な場合は、”相対論的 (relativistic)” な運動エネルギーを使うといい。
これは特殊相対性理論の運動エネルギー \(k(p) = \sqrt{\| p \|^2 c^2 + m^2 c^4} – m c^2\) (合ってるっけ?)から、
$$\displaystyle
\begin{align}
k(p) &= \sqrt{\| p \|_2^2 + 1} – 1 \tag{14} \\
\nabla k(p) &= \frac{p}{\sqrt{\| p \|_2^2 + 1}} \tag{15}
\end{align}$$
こうすることで、真ん中のようにいい感じに落ちていく。
実装の話
FEM については、Chainer で素晴らしい実験結果が出ている。今回の実装は、こちらの記事を参考にさせて頂きました。
しかし FEM は、謎の関数 \(f\) に対して \(\nabla f\) が緩やかであることを要請してしまうので、むしろハンドリングのしやすい関数系を仮定した \(\nabla k\) が緩やかであることを要請する SEM の方がナイーブなアプローチであるように思う (程度の問題かもしれないが。ちゃんと数学的に解析するのはまた)。
ここで実は、CuPy を使った GPU 活用について記述した前回の記事につながってくる。
Chainer チュートリアルも公開されたし GPU がもうちょっと楽しくなる CuPy カーネル入門
\(x,p\) の更新は独立で計算されるので、CuPy の ElementwiseKernel が活用でき、\(\nabla k(p)\) の分母は二乗和なので ReductionKernel が使える。
optimizer.UpdateRule を継承した HamiltonianExplicitRule において、SEM の update_core_gpu は次のように記述できる(はず)。
def update_core_gpu(self, param): grad = param.grad hp = self.hyperparam p = self.state['p'] if HamiltonianExplicitRule._kernel_x is None: HamiltonianExplicitRule._kernel_x = cp.ElementwiseKernel( 'T epsilon, T p, T denomp, T param', 'T x', 'x = param + epsilon * p / denomp', 'Hamiltonian_x') if HamiltonianExplicitRule._kernel_r is None: HamiltonianExplicitRule._kernel_r = cp.ReductionKernel( 'T p', 'T denomp', 'p * p', 'a + b', 'denomp = sqrt(a)', '1', 'relativistic') if HamiltonianExplicitRule._kernel_p is None: HamiltonianExplicitRule._kernel_p = cp.ElementwiseKernel( 'T delta, T epsilon, T grad, T p0', 'T p1', 'p1 = p0 * (2.0 - (1.0 / delta)) - epsilon * grad', 'Hamiltonian_p') else: # p p = HamiltonianExplicitRule._kernel_p(hp.delta, hp.epsilon, grad, p) # x denomp = HamiltonianExplicitRule._kernel_r(p) param.data = HamiltonianExplicitRule._kernel_x(hp.epsilon, p, denomp, param.data)
SEM は先にパラメータ更新をして勾配を求めて、モーメンタムの変化量を計算する。結局実装を考えてみたところ、勾配計算は次のミニバッチでもそれほど変わらないという仮定をして更新をすることにしたら FEM とあまり変わらない式に…。大きな違いは散逸項によるモーメンタムの減衰が更新前に基づいて行われるか、更新後に基づいて行われるかというぐらいだろう(あまり大きな違いも無い気がする…)。
MNIST で実験してみた
実際に、MNIST を使って実験してみる。実験条件は上述の FEM の実験に合わせて、タイムステップ \(\epsilon = 1\) , 散逸項の係数 \(\gamma = 0.6\) とし、中間ユニット500の3層ニューラルネット、100枚のミニバッチ学習で200 epoch 回す。
Optimizer は、Adam, SGD (モーメンタム=0.9, 学習率=0.01), FEM, SEM を比較する。
問題設定が簡単すぎたか普通に SGD がいい結果になっているが、ハミルトニアン降下法もちゃんと落ちている。予想通り FEM に比べて SEM がよく落ちているが、この結果はうまくいきすぎて懐疑的になってしまう。今回の実装では本質的に散逸項の違いぐらいしかないしかないので、\(\gamma\) の調整で変わってくるかもしれない。今回のパラメータではこんな感じだが、MNIST ぐらい簡単な問題だと SGD と FEM, SEM はパラメータの調整でなんとでもなって、あまり差は無さそう。
実行時間と性能についても比較する。
Adam | SGD | FEM | SEM | |
---|---|---|---|---|
実行時間 (sec) | 1834 | 1765 | 1884 | 1877 |
validation loss | 0.41 | 0.070 | 0.075 | 0.077 |
validation accuracy (%) | 98.25 | 98.1 | 98.20 | 98.25 |
ハミルトニアン降下法のパフォーマンスはその他の手法と見劣りしないが、train loss の落ち方に対して少し過学習気味になっているきらいがある。
ハミルトニアン降下法については、結構適用先は限られるのかと思っていたので意外とちゃんと使える技術であることに驚き。複雑な問題に対して、関数の空間が特徴的な ResNet とかでやると何か分かるかも。応用技術とかがこれからさらに発展していくと面白いかもしれない。

基礎からのベイズ統計学: ハミルトニアンモンテカルロ法による実践的入門
- 作者: 豊田秀樹
- 出版社/メーカー: 朝倉書店
- 発売日: 2015/06/25
- メディア: 単行本
- この商品を含むブログ (6件) を見る