变分推断介绍

Author: Steven Date: May 16, 2018 Updated On: May 12, 2020
Categories: 概率图模型
2.8k words in total, 11 minutes required.

本文部分内容总结自引用[1]

变分推断,是一种在概率图模型中进行概率推断的近似方法。相比于基于采样的随机化方法,变分推断是一种确定性逼近方法。更多关于概率图推断的介绍,可以参见概率图模型总览

1. 理论知识

变分推断的思想的要点可以概括如下:

  • 使用已知的简单分布来逼近需推断的复杂分布;
  • 限制近似分布的类型;
  • 得到一种局部最优、但具有确定解的近似后验分布

简单地说,原始目标是根据已有数据推断需要的分布p;当p不容易表达、不能直接求解时,可以尝试用变分推断即,寻找容易表达和求解的分布q,当qp的差距很小的时候(技术上而言,是KL散度距离最小),q就可以作为p的近似分布,成为输出结果。

以下先结合文献[2] 15.5.2节中的一个例子来解答下变分推断的学习目标、及其在学习任务中具体的思想和用途。

假定隐变量z直接和N个可观测的变量x=x1,,xN相连,那么,所有可观测的变量的联合分布的概率密度函数可以表示为:

p(xΘ)=Ni=1zp(xi,zΘ)

对数似然可以写为:

lnp(xΘ)=Ni=1ln{zp(xi,zΘ)}

那么,上述例子中的推断和学习任务分别是在给定观测样本x的情况下计算出概率分布p(zx,Θ)和分布的参数Θ

在含有隐变量z时,上述问题的求解可以使用EM算法:

  • E步,根据t时刻参数Θtp(zx,Θt)进行推断,并对以上的联合似然函数p(x,zΘ)进行计算;
  • M步,基于E步计算的结果进行最大化寻优,即在z被当前参数和观测确定的情况下,对上述的对数似然求最大化

    Θt+1=argmaxΘQ(Θ;Θt)=argmaxΘzp(zx,Θt)lnp(x,zΘ)

    上式中,最大化的一项,实际上是对数联合似然函数lnp(x,zΘ)在分布p(zx,Θt)下的期望。

    当分布p(zx,Θt)和变量z的后验分布相等时,上式中最大化的期望值Q(Θ;Θt)可以近似于对数似然函数。

因此,通过E步和M步的迭代,最终可以获得稳定参数Θ,从而也可以获得z的分布。

但是,p(zx,Θt)未必一定是z的真实分布,而是一个近似值。若将近似分布表示为q(z),则可以有下列公式成立:

lnp(x)=L(q)+KL(qp)
L(q)=q(z)ln{p(x,z)q(z)}dz
KL(qp)=q(z)lnp(zx)q(z)dz

上述公式看起来很复杂,但是试着把后面两个公式带入到上面公式中,就可以发现其实是贝叶斯公式p(x)=p(x,z)p(zx)在符合q分布的变量z在积分上的一种表达形式。

接下来,考虑到z可能模型复杂而难以完成E步中p(zx,Θt)的推断,此时,就可以借助变分推断,假设z服从一个简单的分布:

q(z)=Mi=1qi(zi)

即假设复杂的多变量z可拆解为一系列相互独立的多变量zi。并且,还可以假设每个分布qi相对简单或有很好的结构。
考虑到上述对数似然的形式,我们假设每个分布符合指数族分布(易于积分求解),那么对于每一个独立的变量子集zj,其最优的分量分布qj应该满足:

lnqj(zj)=Eij[lnp(x,z)]+const
Eij[lnp(x,z)]=p(x,z)ijqidzi
const为一个常数。

对上述公式进行转换,可以得到一个最优分量分布(最接近真实情形)的表达式:

qj(zj)=exp(Eij[lnp(x,z)])/(exp(Eij[lnp(x,z)])dzj

通过上式可以看出,在对变量zj的最优分布qj估计时,融合了除zj外其他变量zij的信息,这是通过联合似然函数p(x,z)zj之外的隐变量求期望得到的,因此变分推断也被成为平均场(mean field)逼近方法。

实践中对于变分推断的使用:

  • 首先,对隐变量进行拆解,假设各个分量服从何种分布
  • 再利用上述最优分布求解,对隐变量的后验概率分布进行估计
  • 通过EM方法迭代求解,得到最终概率图模型的推断和参数估计

2. 通过实例进行理解

变分推断的思想,即采用简易的分布来近似复杂的隐变量分布,从而实现在观测变量之下,通过EM方法迭代对隐变量和观测变量的联合概率分布的参数估计进行求解。当然,变分推断的公式推导有些难以理解,虽然出现了KL散度的概念,但目前为止,如何实现对其的优化来完成分布的逼近依然没有得到解释,这一节,我们通过对于引用[1]中的回答进行review来进行理解。

2.1 简易理解

变分推断中分布逼近的示例变分推断中分布逼近的示例

上图中,为了对原始目标分布p进行求解,我们选择了两个高斯分布(简易好解释),来分别衡量它们和目标分布的相似性,并选择相似性高的分布来逼近p

2.2 求解思路

理解变分推断的步骤:

  1. 拥有两部分输入:数据x,模型p(z,x)
  2. 需要推断的是后验概率p(zx),但不能直接求
  3. 构造后验概率p(zx)的近似分布q(z;v)
  4. 不断缩小qp之间的距离直至收敛 - 使用EM算法

以下分别解释下上述4个步骤中重要的问题。

2.2.1 模型和输入确定

变分推断要解决的问题,简单来说,专家利用他们的知识,给出合理的模型假设p(z,x),其中包括隐含变量z和观察值变量x。隐含变量z在通常情况下不止一个,并且相互之间存在依赖关系,这也是问题难求解的原因之一。

为了理解隐含变量和观察值的关系,一个很重要的概念叫做“生成过程模型”。我们认为,观察值是从已知的隐含变量组成的层次结构中生成出来的。

以高斯混合模型问题举例。我们有5个相互独立的高斯分布,分别从中生成很多数据点,这些数据点混合在一起,组成了一个数据集。当我们转换角度,单从每一个数据点出发,考虑它是如何被生成的呢?生成过程分两步,第一步,从5个颜色类中选一个(比如粉红色),然后,再根据这个类对应的高斯分布,生成了这个点在空间中的位置。隐含变量有两个,第一个是5个高斯分布的参数u,第二个是每个点属于哪个高斯分布cuc共同组成隐含变量zuc之间也存在依赖关系。

2.2.2 后验概率求解

后验概率p(zx)即基于现有数据集合x,推断隐含变量的分布情况。

利用高斯混合模型的例子来说,就是求得每个高斯分布的参数u的概率和每个数据点的颜色的概率c。根据贝叶斯公式,p(zx)=p(z,x)/p(x)。根据专家提供的生成模型,可知p(z,x)部分(可以写出表达式并且方便优化),但是边缘概率p(x)是不能求得的,当z连续时边缘概率需要对所有可能的z求积分,不好求。当z离散时,计算复杂性随着x的增加而指数增长。

2.2.3 近似逼近后验概率

此时需要构造q(z;v),并且不断更新v,使得q(z;v)更接近p(zx)q(z;v)意思是z是变量,vz的概率分布q的参数。所以在构造q的时候也分两步,第一,概率分布的选择。第二,参数的选择。

第一步,我们在选择q的概率分布时,通常会直观选择p可能的概率分布,这样能够更好地保证qp的相似程度。

例如高斯混合模型中,原始假设p服从高斯分布,则构造的q依然服从高斯分布。

第二步,通过改变v,使得q不断逼近p

变分推断等价于对于KL散度的优化变分推断等价于对于KL散度的优化

2.2.4 优化问题的求解

优化目标很明确,减小KL散度的值即可。由于KL的表达式中依然有一部分不可求的后验概率,可以使用ELBO(Evidence Lower BOund)来进行替代,ELBO中只包括联合概率p(z,x)q(z;v),从而摆脱后验概率。

lnp(x)=L(q)+KL(qp)
L(q)=q(z)ln{p(x,z)q(z)}dz
KL(qp)=q(z)lnp(zx)q(z)dz

ELBO就是L(q)!!!给定数据集后(x),最小化KL等价于最大化ELBO,因此ELBO的最大化过程结束时,对应获得的q(z;v),就成为了最后输出。

对于L(q)的最大化,就是对z进行拆解、假设其各分量分布简单的情况下完成的。具体推导,可以参见引用[2] 14.5.2节中的内容。

引用


  1. 1.如何简单易懂地理解变分推断.
  2. 2.《机器学习》,周志华著,清华大学出版社.