れおなちずむ

素粒子物理、量子計算、機械学習、計算機科学とかの話をします

f-divergenceと汎関数微分

こんにちは🐬🐬🐬

情報理論の分野ではしばしばKullback-Leibler divergence(KL divergence)とよばれる量が登場します。

$$ D_{KL}(p||q) = \int p(x)\log \frac{p(x)}{q(x)}dx $$

このKL divergenceは、任意の確率分布$p,q$に対して常に非負の値をとることが知られています。
実際、常に$x-1 \geq \log x$が成り立つことに注目すると、

$$ D_{KL}= -\int_x p \log\frac{q}{p} \geq -\int_x p\left( \frac{q}{p} - 1 \right) = \int_x (p - q)= 0 $$

です*1。等号が成立するのは$x=1$、すなわち$p=q$のときであり、それに限られることもわかります。

距離$d$に必要な4公理

  • 非負性    $d(x,y) \geq 0$
  • 完備性    $d(x,y)=0 \Leftrightarrow x=y$
  • 対称性    $d(x,y) = d(y,x)$
  • 三角不等式  $d(x,y) + d(y,z) \geq d(x,z)$

さて、非負の値であることを見るに、KL divergenceという量が一見異なる確率分布$p,q$同士の距離(metric)を与えるかのように思えるのですが、実際には距離に要求される4つの公理のうち残りの2つを満たさないことがわかります。
そこで、距離の公理を弱めて、非負性と完備性を満たすことのみを要請したものを、一般にdivergenceと呼びます。
実のところ、KL divergenceはdivergenceとよばれる広範な対象のうちの1つに過ぎないのです。

f-divergence

定義から分かるとおり、divergenceはKL divergence以外にも無数にあります。
そのうちの一つが今回紹介するf-divergenceです。f-divergenceというのは、$\log$の部分を一般の関数$f$に置き換えたものです:

$$ D_{f}(p||q) = \int q(x)f\left(\frac{p(x)}{q(x)}\right)dx $$

この定義によれば、f-divergenceは$f(w)= w \log w$と選んだ場合に通常のKL divergenceと一致します。
一方$f(w) = -\log w$と選んだ場合、これはKL divergenceにおいて引数の$p,q$を逆にしたものと一致します。これを文字通りinverse KL divergenceと呼びます。
ほかにも、適当に$f$を選ぶことで、divergenceを作ることができるのですが、そのためには当然$f$の選び方として$D_{f}$がdivergenceの公理を満たすようなものを選ばなければいけません。
もし$f$に必要とされる条件が分かれば、わたしたちはdivergenceという量を$f$というかなり一般的な形式でもって、ある程度一貫性を備えながら捉えることができるはずです。
そこで、$D_f$がdivergenceになるために必要な$f$に関する条件について考えてみます。

fの必要条件

まず、完備性から、直ちに$f(w)=0 \Leftrightarrow w=1$がわかりますね。
それと、$p \approx q$のときの$D_{f}$の振る舞いを調べてみます:

$$ D_f(p||p+\delta p) = D_f(p||p) + \left. \int_x\delta q\frac{\delta}{\delta q}D_f(p||q)\right|_{q=p} + \frac{1}{2}\left. \left( \int_x \delta q\frac{\delta}{\delta q} \right)^2 D_f(p||q)\right|_{q=p} + \cdots $$

ここでやっていることは本質的には多変数関数をテイラー展開しているのと同じなのですが、今の場合、微小変化をとっているのは関数の値ではなく、関数そのものであることがポイントです。
なぜなら$D_f$の中にある積分の存在によって、$q$の形そのものの変化が$D_f$の値に直接影響を与えるからです。このようなものを汎関数(functional)と呼びます。汎関数の微小変化を調べるときは、単なる微分ではなく、汎関数微分というものを用います。
汎関数微分$\displaystyle \frac{\delta I}{\delta q}$は、汎関数$I(q)$があったときに、その1次の微小変化

$$ \delta I(q) = \int_x \delta q(x) \sigma(x) $$

に対する1次の係数$\sigma(x)$を対応させることで定義されます。すなわち、

$$ \frac{\delta}{\delta q(x)}I(q) = \sigma(x) $$

です。$D_f(p||q)$は$p,q$に関する2変数汎関数なので、本来$d\mathbf{x}\cdot\nabla$と書くところを、ここでは$\displaystyle \int_x\delta q\frac{\delta}{\delta q}$と書いたというわけです。

さて、定義に従って実際に計算をしていきますが、第1項は明らかに$D_f(p||p)=0$です。
第2項および第3項は

$$ \left. \frac{\delta}{\delta q(x)}D_f(p||q) \right|_{q=p} = \left[f\left(\frac{p(x)}{q(x)}\right) + q(x)\left(-\frac{p(x)}{q^2(x)}\right)f'\left(\frac{p(x)}{q(x)}\right)\right]_{q=p} = -f'(1) $$

$$ \left. \frac{\delta}{\delta q(x)} \frac{\delta}{\delta q(y)}D_f(p||q) \right|_{q=p} = \left[q(x)\left(-\frac{p(x)}{q^2(x)}\right)^2f''\left(\frac{p(x)}{q(x)}\right)\delta(x-y)\right]_{q=p} = \frac{\delta(x-y)}{p(x)}f''(1) $$

なので、

$$ D_f(p||p+\delta p) = -f'(1) \int_x \delta p(x) + \frac{1}{2} \int_x \int_y \delta p(x) \delta p(y) \frac{\delta(x-y)}{p(x)} f''(1) + O(\delta p^3) $$

となります。ここで、確率分布$q(x) = p(x) + \delta p(x)$に対して$\int_x q(x) = \int_x p(x) = 1$が成り立つので、$\int_x \delta p(x) = 0$となることに注意します。
結局$p \approx q$での$D_f$の振る舞いは

$$ D_f(p||p+\delta p) = \frac{1}{2} \int_x \int_y \delta p(x) \delta p(y) \frac{\delta(x-y)}{p(x)} f''(1) + O(\delta p^3) $$

で表されます。ここで、

$$ g_p(x,y) \equiv \frac{\delta(x-y)}{p(x)} f''(1) $$

と定義すると、$g_p$は点$p$における計量テンソルのようなものになっていることに気づきます。実際、$\delta$関数の性質から、$g_p(x,y)=g_p(y,x)$が成り立っていることが確かめられます。

$D_f(p||q)$が任意の$p,q$に対して非負の値をとるならば、明らかに$p \approx q$のときであっても$D_f \geq0$でなければなりません。このときの$D_f$の振る舞いは完全に2次の項

$$ \frac{1}{2} \int_x \int_y \delta p(x) \delta p(y) g_p(x,y) $$

によって決まるので、計量テンソル$g_p(x,y)$は正定値である必要があります。$\delta$関数の存在によって$g_p(x,y)$は対称行列なので、$g_p(x,y)$が正定値であることは、$f''(1)>0$と同値です。言い換えればこれは$f(w)$が$w=1$において下に凸であるということを意味します。

最後に、ある定数$a$に対して、

$$ D_{f}(p||q) = \int_x q(x)f\left(\frac{p(x)}{q(x)}\right) \geq 0 = a\int_x q(x)\left( \frac{p(x)}{q(x)} - 1 \right) $$

が成り立つ必要があるので、$f$に対し恒等的に

$$ f(w) \geq a(w-1) $$

が成り立たてば、$D_{f} \geq 0$が保証されることになります。


ここまでで$D_f$がdivergenceであるためには、少なくとも$f$が次の3条件を満たす必要があることがわかりました。

  • $f(w)=0 \Leftrightarrow w=1$
  • $f''(1)>0$
  • $\exists a,f(w) \geq a(w-1)$

下の2つの条件を満たす$f$として、最もよく用いられるのが凸関数です。凸関数は凸性と呼ばれる嬉しい性質を持っていて、統計力学や情報幾何学の分野では頻出します。
勘違いされがちですが、$f$が凸関数であることは、$D_f$がdivergenceであるための必要十分条件ではありません。凸関数でなくても$D_f$をdivergenceにするような$f$が存在するということは、上の議論からわかるはずです。
いずれにしろ「$w=1$でのみ$0$になる凸関数」がf-divergenceとして適当な$f(w)$であるということになります。

汎関数微分

$f$に必要な条件が分かったところで、汎関数汎関数微分について補足しておきます。

汎関数汎関数微分に馴染みのない人にとっては、なぜこんなものを引き合いに出す必要があるのか、よく理解できなかったかもしれません。
しかし、本来確率分布$p(x)$に要請されるのは$\displaystyle \int_x p(x) = 1$だとか$p(x) \geq 0$だとか、極めて少ない条件のみだったことを思い出すべきです。一般に$p(x)$に存在している自由パラメータというのは非加算無限個あるのです。
このように、確率分布や場などの連続的なものを丸ごと1つの解析的な対象として扱う場合には必然的に汎関数汎関数微分が現れます*2。 確率分布という対象を扱う以上、汎関数による表現は避けることができないのです。
わたしたちが普段汎関数にお目にかかることが無いのは、$p(x)$として離散分布を考えたり、あるいは分布関数を制御するパラメータとして有限個の変数$\{\theta_i\}$を定めるなどすることで、考察の対象とする$p(x)$の種類を敢えて制限しているからです。 指数型分布族とよばれるような確率分布の類はまさにその最たる例です。 したがって上の話は、普通$p(x)$に制限をかけた上でやるはずの操作を統一的に行ったということに過ぎません。

念のためパラメータが有限の場合との主な対応関係を書いておきます。$\delta_{ij}$はKroneckerのデルタです。

$$ \begin{array}{rcl} \displaystyle (x,y) & \longleftrightarrow &\displaystyle (i,j) \\ \displaystyle \int_x p(x) q(x) & \longleftrightarrow &\displaystyle \boldsymbol{p}\cdot \boldsymbol{q}=\sum_i p_i q_i \\ \displaystyle \delta(x-y) & \longleftrightarrow &\displaystyle \delta_{ij} \\ \displaystyle \frac{\delta}{\delta p(x)} & \longleftrightarrow &\displaystyle \frac{\partial}{\partial p_i} \end{array} $$

離散確率分布$\boldsymbol{p}=(p_i)$を考える場合は汎関数微分は$\displaystyle \frac{\partial}{\partial \boldsymbol{p}}$に対応するのですが、パラメトリックな分布関数$p(x|\boldsymbol{\theta})$を考えたいならばこれは$\displaystyle \frac{\partial}{\partial \boldsymbol{\theta}}$となります。もちろん、パラメータとして関数$\varphi(x)$をもつような確率分布$p(x|\varphi)$を仮定するならば、やはり汎関数微分$\displaystyle \frac{\delta}{\delta \varphi(x)}$が現れます。
おそらく機械学習界隈の人にとっては、$p(x|\boldsymbol{\theta})$を用いた表記が一番馴染みのある書き方なのではないでしょうか。

Fisher情報計量🐟

ところで、先ほど定義した計量テンソル$g_p(x,y)$をKL divergenceの場合について計算すると、$f''(w)=\frac{1}{w}$なので、

$$ \begin{array}{lll} g_p(x,y) &=\displaystyle p(x)\frac{\delta(x-y)}{p(x)p(y)} \\ &=\displaystyle \int_v p(v) \frac{\delta}{\delta p(x)}L(v|p) \frac{\delta}{\delta p(y)} L(v|p) \\ &=\displaystyle \mathbb{E}\left[ \frac{\delta}{\delta p(x)}L(v|p) \frac{\delta}{\delta p(y)} L(v|p) \right] \\ \end{array} $$

となります。ただし、

$$ L(x|p) \equiv \int_z \delta(x-z) \log p(z) = \log p(x) $$

と定義しました。$L(x|p)$は対数尤度関数というやつです。
この$g_p(x,y)$をFisher情報計量(Fisher information metric)と呼びます。Fisher情報計量は情報幾何学においてuniqueなLevi-Civita接続をもたらすRiemann計量です。
異なる$f$によって定義されたdivergenceに対しても同様の計量テンソルが計算できましたが、これらもFisher情報計量の定数倍に一致していました。$D_f(p||q)$の$p \approx q$での振る舞いは、どの$f$に対しても定数倍を除いて同じなのです。

divergenceと幾何

ここまでf-divergenceについて色々と考察してきたのですが、divergenceという量そのものにきちんとした説明を与えるには、情報幾何学の力を借りる必要があります。
1つのdivergenceを考えるということはある確率分布のクラスのなす多様体に1つの接続(connection)を導入することを意味します。接続を与えるということは、多様体上の異なる点(確率分布)同士の関係をどのような対称性・不変性に基づいて特徴付けるかということに直結するので、divergenceの選び方はそれによって測ろうとしている確率分布のクラスと密接に関係してくるのです。数多のdivergenceの中でも頻繁にKL divergenceが取り沙汰されるのは、これが指数型分布族とよばれる最も多用される確率分布のクラスに関連付けられるからです。
divergenceが距離にならないのは、つまるところこれによって誘導される多様体上の接続がFisher情報計量を保たないからに他なりません。上で議論した通りFisher情報計量は、あくまで確率分布$p,q$同士が非常に「近い」場合における関係を述べるに過ぎません。そもそもf-divergenceに対応する計量はどれもFisher情報計量と高々定数倍の違いしかないのでした。「遠い」確率分布同士の関係を述べるうえで、Fisher情報計量が保たれる必然性はないのです。
このように、情報幾何学ではFisher情報計量よりもむしろdivergenceが重要な役割を果たします。ある確率分布のクラスに対して適切なdivergenceを与えることで、そこに有意義な幾何学的構造が生まれるのです*3

*1:$\int_x$は$\int dx$の省略です。

*2:(連続)無限次元の数学的対象に対する数学的に厳密な基礎付けは未だに確立していません。
場の量子論ではこういう数学が本質的なためやむを得ず汎関数を使っていますが、やはり至る所で数学的にデリケートな問題が生じるので、定性的な見方を加えることでどうにかこれを回避しているという状況です。

*3:確率分布全体のなす最も一般的な空間を考えれば情報に関する最も一般的な幾何学を与えるはずですが、これは言うまでもなく連続無限次元の空間になります。
ところが残念なことに、甘利さんによれば、確率分布全体のなす集合はきちんとした多様体にはならないそうです。ここでもやはり連続性に由来する数学的な困難が生じているのだと思います。