Skip to content

(おまけ) Hamiltonian Descent Methods の Implicit Method について

先日こんな記事を書いたが

Hamiltonian Descent Methods の Second Explicit Method を Chainer で実装して検証した

この中のアルゴリズムである Implicit Method について少し。

元論文は以下の DeepMind が開発した最適化手法である。

arxiv.org

最適化したい関数 \(f\) の変数 \(x(=\theta)\)、モーメンタムを \(p\)、運動エネルギー \(k\) としたとき、時間発展を以下のように記述できた。

$$\begin{align}
\dot{x_t} &= \nabla k(p_t) \tag{1} \\
\dot{p_t} &= – \nabla f(x_t) – \gamma p_t \tag{2}
\end{align}$$

\(\gamma\) は勢いの減衰係数であり、ドットは時間微分を表す。詳細は前のブログを参照のこと。ここで時間微分を次のように近似するのが Implicit Method である。

$$\displaystyle
\begin{align}
\dot{x_t} &\sim \frac{x_{t+1} – x_t}{\Delta t} = \nabla k(p_{t+1}) \tag{3} \\
\dot{p_t} &\sim \frac{p_{t+1} – p_t}{\Delta t} = – \nabla f(x_{t+1}) – \gamma p_{t+1} \tag{4}
\end{align}$$

以降、タイムステップ \(\Delta t = \epsilon, 1/(1+\gamma \epsilon) = \delta\) と置く。ここで論文中では全部すっ飛ばして以下のように \(x_{t+1}, p_{t+1}\) を求めればいいよと書いてある。

$$\displaystyle
\begin{align}
x_{t+1} &= \mathop{\rm argmin}\limits_x \left\{ \epsilon k^{*} \left( \frac{x – x_t}{\epsilon} \right) + \epsilon \delta f(x) – \delta \langle p_t, x \rangle \right\} \tag{5} \\
p_{t+1} &= \delta p_t – \epsilon \delta \nabla f(x_{t+1}) \tag{6}
\end{align}$$

あまりにも唐突すぎて ?????? だったので少し考えてみた。もっといい証明があれば是非教えて下さい!

式(6)は簡単だ。まず式(4)の両辺に \(\epsilon\) をかける。

$$\displaystyle
\begin{align}
p_{t+1} – p_{t} &= – \epsilon \nabla f(x_{t+1}) – \gamma \epsilon p_{t+1} \\
(1 + \gamma \epsilon) p_{t+1} &= p_{t} – \epsilon \nabla f(x_{t+1}) \\
p_{t+1} &= \delta p_t – \epsilon \delta \nabla f(x_{t+1}) \tag{7} \\
\end{align}$$

つまり、更新後のパラメータ \(x_{x+1}\) が決まれば、モーメンタム \(p_{t+1}\) も得られる。しかし当然 \(\theta’ = x_{t+1}\) を求めたいのだ。

ここで簡単にルジャンドル変換を導入しておく。ルジャンドル変換した関数を \(f^{*}\) で表すと定義は \(f^{*}(p) = – \min{ \left[ f(x) – \langle x, p \rangle \right] }\) である。意味的にはざっくり座標の空間と傾き(微分)の空間の交換であり、関数上の点の集合って、その位置での接線の集合と1対1対応してると考えれば分かりやすい。今回のように関数とモーメンタムの対応関係に似ており、何かうまくいきそうな予感がする(感覚)。事実、ハミルトニアンの導出にルジャンドル変換が登場し、微分 \(\dot{x}\) を速度 \(p\) に変換していい感じの系にしている。

本当に唐突だが、 \(\nabla \left( k^{*} \right) (x) = \left( \nabla k \right)^{-1} (x)\) という定理がある。これを使うと、 \(\nabla k^{*} \left( \nabla k(p) \right) = p\) となってしまう。何がしたいかというと、式(3) の両辺を \(\nabla k^{*}\) に突っ込んでやる。そうすると以下が得られる。

$$\displaystyle
\nabla k^{*} \left( \frac{x_{t+1} – x_t}{\epsilon} \right) = \nabla k^{*} \left( \nabla k(p_{t+1}) \right) = p_{t+1} \tag{8}$$

式(8) の \(p_{t+1}\) に式(7) を代入してやろう。

$$\displaystyle
\begin{align}
\nabla k^{*} \left( \frac{x_{t+1} – x_t}{\epsilon} \right) = \delta p_{t} – \epsilon \delta \nabla f(x_{t+1}) \\
\nabla k^{*} \left( \frac{x_{t+1} – x_t}{\epsilon} \right) + \epsilon \delta \nabla f(x_{t+1}) – \delta p_{t} = 0 \tag{9}
\end{align}$$

こうすることで \(p_{t+1}\) を消去することができ、 \(x_t, p_t\) から \(x_{t+1}\) を求めればいいことになった。…が、まだ終わらない…。まあ、つまり… \(F'(x) = \nabla k^{*} \left( \frac{x – x_t}{\epsilon} \right) + \epsilon \delta \nabla f(x) – \delta p_{t}\) という導関数が \(F'(x_{t+1}) = 0\) を満たすことが \(x_{t+1}\) の条件であり、そのような場合凸関数 \(F(x)\) が最小になる \(x\) を求めてやればいいことになる。それが式(5)の argmin の意味である。( \(f,k\) が簡単な関数であれば、こんなまどろっこしいことをしなくていいのだが…)。

ゴールが見えてきたので、答えから先に言うと \(F(x)\) は以下のような関数である。

$$\displaystyle
F(x) = \epsilon k^{*} \left( \frac{x – x_t}{\epsilon} \right) + \epsilon \delta f(x) – \delta \langle p_t, x \rangle \tag{10}$$

\(\langle p_t, x \rangle = \sum_n p_t^{(n)} x^{(n)}\) という内積である。(すっげールジャンドル変換みがあるので、もっと簡単に証明できるんじゃねーのかなーと思うのだがよく分からず)
では、微分して式(9)と同値であるか検証してみよう。

まず \(y = (x – x_t) / \epsilon\) とすると \(\mathrm{d} y / \mathrm{d} x = 1/\epsilon\) であり、 \(\nabla \langle p_t, x \rangle = \nabla \sum_n p_t^{(n)} x^{(n)} = p_t\) である。すると、

$$\displaystyle
\begin{align}
\nabla F(x) &= \epsilon \frac{\mathrm{d} y}{\mathrm{d} x} \nabla k^{*} (y) + \epsilon \delta \nabla f(x) – \delta \nabla \langle p_t, x \rangle \\
&= \epsilon \nabla k^{*} \left( \frac{x – x_t}{\epsilon} \right) + \epsilon \delta \nabla f(x) – \delta p_t \tag{11}
\end{align}$$

こうすることで \(x_{t+1}\) を算出できるようになる。

$$\displaystyle
\begin{align}
x_{t+1} &= \mathop{\rm argmin}\limits_x F(x) \\
&= \mathop{\rm argmin}\limits_x \left\{ \epsilon k^{*} \left( \frac{x – x_t}{\epsilon} \right) + \epsilon \delta f(x) – \delta \langle p_t, x \rangle \right\} \tag{12}
\end{align}$$

式(5)である。これで、 \(x_t, p_t\) から \(x_{t+1}, p_{t+1}\) が求めることができようになった (まあ、argmin があるので簡単に求まらないのだが…)。

そう、何が言いたいかというと、僕がこれだけのプロセスを経て理解した1つの式が、論文中ではさも自明かのように展開するのである。機械学習でご飯を食べることを諦めた瞬間である。…まあ世界最強の DeepMind だし………… (現実逃避)

諦めて 寝なさい

諦めて 寝なさい

Be First to Comment

    コメントを残す

    メールアドレスが公開されることはありません。 * が付いている欄は必須項目です