Skip to content

“Optimizing Millions of Hyperparameters by Implicit Differentiation” ハイパーパラメータ最適化の概要とPyTorchで実装練習

面白い論文を読んだので、久しぶりにちょっとブログを書いてみる。Chainer が残念ながら開発中止となり最近 PyTorch に移行したので、その練習を兼ねて実装もしてみた。

ニューラルネットを学習する時、ハイパーパラメータも同時に最適化したい… という場合がよくある。しかし、ただでさえ多いパラメータにハイパーパラメータも最適化するとなると途方に暮れることが多い。この論文は、ちょっとウィットな手法で効率的に最適化できるよというもの。

“Optimizing Millions of Hyperparameters by Implicit Differentiation” (Lorraine et al. 2019)
https://arxiv.org/abs/1911.02590

問題設定

ニューラルネットのパラメータを \(\boldsymbol w\)、今回最適化したいハイパーパラメータを \(\boldsymbol \lambda\) として、損失関数は学習データについて \(\mathcal{L}_{\mathrm{T}}(\boldsymbol \lambda, \boldsymbol{w}) \)、検証データについて \(\mathcal{L}_{\mathrm{V}}(\boldsymbol \lambda, \boldsymbol{w})\) とする。

最終的に求めたいのは、\(\boldsymbol{w}^* = \mathop{\rm argmin}\limits_{\boldsymbol{w}} \mathcal{L}_{\mathrm{T}}(\boldsymbol \lambda, \boldsymbol{w})\) における \(\boldsymbol \lambda^* = \mathop{\rm argmin}\limits_{\boldsymbol \lambda} \mathcal{L}_{\mathrm{V}}(\boldsymbol \lambda, \boldsymbol{w}^*) \) である。

なので、勾配降下法に使える \(\displaystyle \frac{\partial \mathcal{L}_{\mathrm{V}}(\boldsymbol \lambda, \boldsymbol{w}^*)}{\partial \boldsymbol \lambda} \) ( hypergradient : ハイパー勾配?) が欲しいことになるが、ある \(\boldsymbol \lambda\) から学習データで \(\boldsymbol{w}^*\) を最適化し、検証データで \(\boldsymbol \lambda^*\) に最適化するという二重の最適化になり超複雑ということになる。

左が学習データの損失関数、右が検証データの損失関数の空間。往々にして異なる分布である。 (Lorraine et al. 2019)

ハイパー勾配を求めたい

ハイパー勾配は、\(\boldsymbol \lambda\) に依存した直接の変位と、\(\boldsymbol{w}^* (\boldsymbol \lambda)\) による間接の変位の和で記述される。分かりやすく全微分で表現すると、

$$\begin{align}
\frac{\mathrm{d} \mathcal{L}_{\mathrm{V}}(\boldsymbol \lambda, \boldsymbol{w}^*)}{\mathrm{d} \boldsymbol \lambda} &= \frac{\partial \mathcal{L}_{\mathrm{V}}(\boldsymbol \lambda, \boldsymbol{w}^*)}{\partial \boldsymbol \lambda} + \frac{\partial \mathcal{L}_{\mathrm{V}}(\boldsymbol \lambda, \boldsymbol{w}^*)}{\partial \boldsymbol{w}^*} \frac{\partial \boldsymbol{w}^* (\boldsymbol \lambda)}{\partial \boldsymbol \lambda}
\end{align}$$

問題は \(\displaystyle \frac{\partial \boldsymbol{w}^* (\boldsymbol \lambda)}{\partial \boldsymbol \lambda}\) (\(\boldsymbol \lambda\) を変えた時、最適な \(\boldsymbol{w}\) はどう変わるか) を測定するのがきつい。ので、2つの定理を使う。

  • コーシーの陰関数定理
    \(\displaystyle \frac{\partial y}{\partial x} (x) = -\frac{J_{f,x}}{J_{f,y}} = -\left[ \frac{\partial f}{\partial y} (x, y(x)) \right]^{-1} \frac{\partial f}{\partial x} (x, y(x))\)
  • ノイマン級数展開 \(\displaystyle \left[ I – A \right]^{-1} = \sum_{n=0}^{\infty} A^n\)

ノイマン級数は、等比数列の和 \(1/(1-x) = 1+x+x^2 \ldots \) の行列への拡張であり、左から \(\left[ I – A \right]\) を作用させると一致することがわかる。(ただし、\(A\) の無限大乗が小さいことが要求されるが)

コーシーの陰関数定理が示しているものを円の方程式から類推する。コーシーの陰関数定理は \(f = x^2 + y^2 – 1 (=0)\) という関数に張り付いた局所的な \(x\) と \(y\) の関係を示している。(局所的に) 円の上半分では \(\displaystyle y = \sqrt{-x^2+1}\) であり \(\displaystyle \frac{\partial y}{\partial x} = -\frac{x}{y}\), \(\displaystyle \frac{\partial f}{\partial x} = 2x\), \(\displaystyle \frac{\partial f}{\partial y} = 2y\) なので、定理通りである。

すると \(\displaystyle f = \frac{\partial \mathcal{L}_T}{\partial \boldsymbol{w}} (=0)\) という関数に張り付いた \(x = \boldsymbol \lambda\), \(y(x) = \boldsymbol{w}^* (\boldsymbol \lambda)\) に拡張できそうな気がしてくる。ハイパー勾配は次のようになる。(ただし、\(f\) が微分でさらに微分するとなると巨大な Hessian がここで登場する)

$$\displaystyle \begin{align} \frac{\partial y}{\partial x} &=~- \left[ \frac{\partial f}{\partial y} (x, y(x)) \right]^{-1} \frac{\partial f}{\partial x} (x, y(x)) \ \frac{\partial \boldsymbol{w}^*}{\partial \boldsymbol \lambda} \\ &=~- \left. \left[ \frac{\partial^2 \mathcal{L}_T}{\partial \boldsymbol{w} \partial \boldsymbol{w}^T} \right]^{-1} \frac{\partial^2 \mathcal{L}_T}{\partial \boldsymbol{w} \partial \boldsymbol \lambda^T} \right|_{\boldsymbol{w}^*(\boldsymbol \lambda)}
\end{align}$$

Hessian の逆行列 (\(\mathcal{O} (w^3)\)) をどう求めればいいか世界中の研究者の頭を悩ませているが、ノイマン級数 (\(\left[ I – A \right]^{-1} = \sum_{n=0}^{\infty} A^n\)) で展開するのがこの手法のミソ。

$$\displaystyle \begin{align}
\frac{\partial \boldsymbol{w}^*}{\partial \boldsymbol \lambda} &=~- \left. \left[ \frac{\partial^2 \mathcal{L}_T}{\partial \boldsymbol{w} \partial \boldsymbol{w}^T} \right]^{-1} \frac{\partial^2 \mathcal{L}_T}{\partial \boldsymbol{w} \partial \boldsymbol \lambda^T} \right|_{\boldsymbol{w}^*(\boldsymbol \lambda)} \\
&=~- \left. \lim_{i \rightarrow \infty} \sum^{i}{j=0} \left[ I – \frac{\partial^2 \mathcal{L}_T}{\partial \boldsymbol{w} \partial \boldsymbol{w}^T} \right]^j \left[ \frac{\partial^2 \mathcal{L}_T}{\partial \boldsymbol{w} \partial \boldsymbol \lambda^T} \right] \right|_{\boldsymbol{w}^*(\boldsymbol \lambda)} \\ &\sim~- \left. \sum_{j < i} \left( I – \alpha \frac{\partial^2 \mathcal{L}T }{\partial \boldsymbol{w} \partial \boldsymbol{w}^T} \right)^j \frac{\partial^2 \mathcal{L}_T}{\partial \boldsymbol{w} \partial \boldsymbol \lambda^T} \right|_{\boldsymbol{w}^*(\boldsymbol \lambda)}
\end{align}$$

ノイマン級数のおかげで、少ないメモリと計算量で結構よさげな近似になっていることが分かる。ただし、直接微分できない離散的なハイパーパラメータに適用はできないので工夫が必要になる手法である。\(\alpha\) は学習率であり収束性を保証するために入れるらしい(射影の方向が合っていればいいというお気持ちだろうか)。

実験

まとめると、ハイパーパラメータは次のように更新をすればいい。

$$\begin{align}
\lambda \leftarrow \lambda – \alpha \left[ \frac{\partial \mathcal{L}_\mathrm{V}}{\partial \boldsymbol \lambda} + \left. \frac{\partial \mathcal{L}_V}{\partial \boldsymbol{w}} \sum^{i}_{j=0} \left( I – \beta \frac{\partial^2 \mathcal{L}_T} {\partial \boldsymbol{w} \partial \boldsymbol{w}^T} \right)^j \frac{\partial^2 \mathcal{L}_T}{\partial \boldsymbol{w} \partial {\boldsymbol \lambda}^T} \right|_{\boldsymbol{w}^*} \right]
\end{align}$$

Hessian さえちゃんと計算できれば実にシンプル!

ということで実験してみた。Iris データセットに対して学習データ60、検証データ60、テストデータ30として、クロスエントロピーロスの変化を見る。分類モデルは中間ユニット 1024 の3層ニューラルネット。ハイパーパラメータは論文同様にそれぞれの重みに対する weight decay (17280個)。全バッチ学習で100 epoch パラメータを最適化しハイパー勾配を計算するという iteration を100回繰り返す。

実験は、ハイパーパラメータ最適化をせず (no HO) 一律の weight decay (WD) を適用したもの。ハイパーパラメータ最適化をして、ハイパー勾配の学習率(\(\alpha\))、ノイマン級数の学習率(\(\beta\))、ノイマン級数の次数 (\(i\)) を色々振ってみた。

学習データの損失関数の変化
検証データに対する損失関数の変化
テストデータに対する損失関数の変化

面白いことに、ハイパーパラメータ最適化無しの場合過学習気味だが、ちょうどいいぐらいの近似(緑)をすることで、テストデータに対する汎化性能が非常に高くなっている。ちなみにだいたい最適化された WD は平均 0.00003 ぐらいになった。論文の他の実験もめちゃめちゃ面白いので読んでみてください。

とはいえ、weight decay がハイパーパラメータだと重みのシュリンクに圧をかけるのであんまりフェアじゃないなぁという印象もある。あと、学習率とかもハイパーパラメータなのだが、損失関数の空間に作用するわけではなく、これこそ最適化されてほしい…。しかし、結構簡単で計算コストも軽いので明日から使ってみたい手法である。

今回実験したコードは以下に置きました。
https://github.com/mocchi-tam/pytorch-HO-implicit-diff

…なんか面白い職場ないかなぁ(小声)

One Comment

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です