Variational Inference

参考内容:Variational Inference: A Review for Statisticians、《深度学习》

背景

从贝叶斯派角度来看,对于一个问题可以分为贝叶斯推断和贝叶斯决策两个部分。由推断得到\(P(\theta|x)\),再据此进行决策/预测。从另一个角度来说,这也是一个Encoding的过程。
变分推断是一种近似推断的方法。在概率模型中,我们往往需要对难以计算或计算开销极大的分布进行近似。从贝叶斯派的角度来看,对未知量的推断实际上是对后验概率的计算,MCMC就是一种计算的方法。但相较于MCMC,变分推断在大数据下更加快速。
至于变分,我们把“自变量也是函数的函数”称为泛函,求泛函的极值问题称为变分问题。

基本思路

对于一个一般的问题,我们考虑一个观测变量\(\mathbf{x}=x_{1: n}\)和隐变量\(\mathbf{z}=z_{1: m}\)的联合概率分布 \[ p(\mathbf{z}, \mathbf{x})=p(\mathbf{z}) p(\mathbf{x} \mid \mathbf{z}) \] 从贝叶斯统计的角度而言,隐变量的分布能确定数据的分布。通过先验分布\(p(\mathbf{z})\)得到隐变量,并借助似然\(p(\mathbf{x} \mid \mathbf{z})\)得到观测变量

Inference in a Bayesian model amounts to conditioning on data and computing the posterior \(p(\mathbf{z} \mid \mathbf{x})\). In complex Bayesian models, this computation often requires approximate inference.

对于MCMC,我们首先基于\(z\)构造一个马尔可夫链 (其平稳分布就是\(p(\mathbf{z} \mid \mathbf{x})\)),而后从分布中采样,借此作为后验分布的近似。而变分推断的主要思想是将近似转化为一个优化问题,即找到一个分布(from a family of approximate densities)与后验分布的KL divergence尽可能小。

数学推导

变分推断的目标是在已知观测变量的情况下近似得到隐变量的条件密度。其核心思想是用优化的方法来解决这个问题。我们借助分布族,以“变分参数”参数化来得到隐变量的分布。而最优的变分参数则根据KL散度在分布族中寻找。最终将这一分布作为精确的条件密度的近似。

\(\mathbf{x}=x_{1: n}\)为观测变量,\(\mathbf{z}=z_{1: m}\)为隐变量,目标是计算给定观测变量下隐变量的条件概率分布\(p(\mathbf{z} \mid \mathbf{x})\)\[ p(\mathbf{z} \mid \mathbf{x})=\frac{p(\mathbf{z}, \mathbf{x})}{p(\mathbf{x})} \] 上式分母我们称之为evidence,可以借助联合概率密度进行计算 \[ p(\mathbf{x})=\int p(\mathbf{z}, \mathbf{x}) \mathrm{d} \mathbf{z} \] 在贝叶斯模型中,我们将模型所有的未知参数都表示为隐变量(latent variables),而上面evidence的积分在很多情况下需要指数时间,而不计算evidence又没法得到我们所需的条件概率,所以这也是变分推断的困难所在。

ELBO(The evidence lower bound)

在变分推断中,我们指定一个隐变量的分布族\(\mathscr{Q}\),其中每一个\(q(\mathbf{z}) \in \mathscr{Q}\)都是条件概率分布近似的候选项,而我们的目标就是找到一个最优,也就是KL散度最小的。因此我们的推断问题就对应于一个优化问题: \[ q^{*}(\mathbf{z})=\underset{q(\mathbf{z}) \in \mathscr{Q}}{\arg \min } \mathrm{KL}(q(\mathbf{z}) \| p(\mathbf{z} \mid \mathbf{x})) \] 同样,上式是无法直接计算的,因为 \[ \mathrm{KL}(q(\mathbf{z}) \| p(\mathbf{z} \mid \mathbf{x}))=\mathbb{E}[\log q(\mathbf{z})]-\mathbb{E}[\log p(\mathbf{z} \mid \mathbf{x})] \] \[ \mathrm{KL}(q(\mathbf{z}) \| p(\mathbf{z} \mid \mathbf{x}))=\mathbb{E}[\log q(\mathbf{z})]-\mathbb{E}[\log p(\mathbf{z}, \mathbf{x})]+\log p(\mathbf{x}) \]\(\log p(\mathbf{x})\)也就是evidence就是我们要绕过的。
因为我们没法直接计算KL散度,所以将优化目标进行转化 \[ \operatorname{ELBO}(q)=\mathbb{E}[\log p(\mathbf{z}, \mathbf{x})]-\mathbb{E}[\log q(\mathbf{z})] \] 上式就是evidence lower bound(ELBO),它等于负KL-divergence 加上\(\log p(\mathbf{x})\),最小化KL散度就相当于最大化ELBO,因为虽然\(\log p(\mathbf{x})\)无法求得,但在给定\(q(\mathbf{z})\)下它是一个常数,所以对\(q(\mathbf{z})\)的导数为0.

\[ \begin{aligned} \operatorname{ELBO}(q) &=\mathbb{E}[\log p(\mathbf{z})]+\mathbb{E}[\log p(\mathbf{x} \mid \mathbf{z})]-\mathbb{E}[\log q(\mathbf{z})] \\ &=\mathbb{E}[\log p(\mathbf{x} \mid \mathbf{z})]-\mathrm{KL}(q(\mathbf{z}) \| p(\mathbf{z})) \end{aligned} \] 重写ELBO后可以看到,第一项是似然的期望,它促使模型将它的隐藏变量集中于可以解释观察数据的配置上,第二项是隐变量变分分布与先验分布的KL divergence的相反数,它促使变分分布接近于先验分布,所以变分模型的目标函数是似然函数与先验分布的一种平衡。此外,上式的第一项正是EM算法中优化的expected complete log-likelihood。
为什么这式被称为evidence lower bound?
因为KL散度的非负性,所以对于任何\(q(\mathbf{z})\)都有\(\log p(\mathbf{x}) \geq \operatorname{ELBO}(q)\)

The mean-field variational family

将优化目标转化为最大化ELBO之后,我们要选择合适的变分族,而这一函数族的复杂度也决定了优化的复杂度。在这里我们引入平均场变分族(mean-field variational family),它假设隐变量之间都是相互独立的,这一假设是一个较强的假设,但我们可以将相互关联的隐变量归为一组,进而转化为相互独立的各组隐变量。引入这一假设之后就有: \[ q(\mathbf{z})=\prod_{j=1}^{m} q_{j}\left(z_{j}\right) \] 变分族的选择是与观测数据\(x\)无关的,在ELBO或者KL散度中才将数据与隐变量建立联系。
根据条件概率分布的链式法则我们有\(p\left(z_{1: m}, x_{1: n}\right)=p\left(x_{1: n}\right) \prod_{j=1}^{m} p\left(z_{j} \mid z_{1:(j-1)}, x_{1: n}\right)\)
变分分布的期望为\(E\left[\log q\left(z_{1: m}\right)\right]=\sum_{j=1}^{m} E_{j}\left[\log q\left(z_{j}\right)\right]\),代入ELBO的定义式就有 \[ E L B O=\log p\left(x_{1: n}\right)+\sum_{j=1}^{m} E\left[\log p\left(z_{j} \mid z_{1:(j-1)}, x_{1: n}\right)\right]-E_{j}\left[\log q\left(z_{j}\right)\right] \]

Coordinate ascent mean-field variational inference

CAVI是针对该优化问题的一个常用算法,它迭代优化每一项隐变量变分分布,保持另外的不变。具体而言,当我们将ELBO对\(z_{k}\)求导并令导数为0时,有: \[ \frac{d E L B O}{d q\left(z_{k}\right)}=E_{-k}\left[\log p\left(z_{k} \mid z_{-k}, x\right)\right]-\log q\left(z_{k}\right)-1=0 \] 也就是说 \[ q_{j}^{*}\left(z_{j}\right) \propto \exp \left\{\mathbb{E}_{-j}\left[\log p\left(z_{j} \mid \mathbf{z}_{-j}, \mathbf{x}\right)\right]\right\} \]

从ELBO导数到具体坐标上升法的更新法则需要较冗长的推导过程,在此省略了。事实上对于变分推断最出圈的变分自编码器中也不涉及梯度变分的步骤,因此这里也就不具体介绍。

再回首

最先接触变分推断是在大二的暑期,而现在再去回忆起变分推断时,首先想到的便是论文中intractable和estimation这两个词。在我看到这整个过程的核心点在于后验无法直接计算,进而利用另一个分布\(q\)进行近似,以及后续的优化问题和ELBO。无独有偶,EM算法中的核心推导部分也设计到了其中的ELBO项。这些算法背后都有着坚实而严谨的数学支撑,而我关注的另一个近似推断的算法SVGD也是由一个早已被数学家提出的stein discrepancy产生的想法。虽然在生成模型上GAN的风头早已盖过了VAE,但仅关注模型跑分的人又怎么会理解VAE中的奇妙的思想。周志华老师在深度森林的路上走出了自己独特的道路,相信未来AI的发展也不会仅限于深度神经网络这一条道路。

I do not accept rewards, but you can donate to the public welfare of China.
0%