———— 相逢意气为君饮,系马高楼垂柳边。

概率图模型在机器学习中应用非常广泛,近年来随着深度学习的发展更加流行,本文着重介绍两种有向图模型,分别是变分自编码器 (variational auto-encoder) 和生成对抗网络 (generative adversarial network)。

变分自编码器 (Variational Auto-Encoder)

本文从应用的角度讲起,变分自编码器 (VAE) 和生成对抗网络 (GAN) 常用来生成数据。首先得说一说自编码器 (auto-encoder) 模型,自编码器有编码器和解码器两部分组成。给定一批原始数据,比如图片或文本序列等,我们将原始数据通过编码器映射成一个隐藏空间的潜在变量,然后再通过一个解码器将潜在变量映射回原空间试图恢复数据,也可以看做是生成样本。

自编码器是基于原始数据直接采样,这种方式生成的样本和原始数据并无太大差异,针对这一缺陷,变分自编码器直接对原始数据的分布进行采样。既然有变分,那么一定少不了 KL 散度的计算,因为我们的目的是对数据分布进行采样,所以就需要知道这个数据分布,在已知一系列样本 [公式] 的前提下,去推出变量 [公式] 的分布,即 [公式] ,根据贝叶斯公式,有

[公式]

直接计算这个后验概率是很困难的,根据变分法,我们假设一个 [公式] 的分布 [公式] 去逼近真实分布 [公式] ,用 KL 散度度量这个两个分布的距离,然后最小化即可。

[公式]

上式第二步到第三步根据归一化条件 [公式] 所得,最终所得的公式第一项 [公式]  [公式] 无关,第二项可以写成 KL 散度的形式 [公式] ,第三项是似然概率的负对数 [公式]  [公式] 服从分布 [公式] 时的数学期望 [公式] 。这个结果和前面贝叶斯神经网络的结论是类似的,最小化 [公式] ,相当于分别最小化 [公式]  [公式] ,前者是为了让我们假设的概率分布 [公式] 和先验分布 [公式] 尽可能接近,后者可以通过蒙特卡罗采样实现,在分布 [公式] 上多次采样,然后使被重构的样本中重构 [公式] 的几率最大,这就是变分自编码器的全过程。

生成对抗网络 (Generative Adversarial Network)

生成对抗网络 (GAN) 是一种深度学习框架,包含生成模型和判别模型两部分,和变分自编码器一样,它也是用来生成数据的,只不过是通过相互对抗的方式训练模型的,就像是左右互搏一样。首先生成模型生成与样本分布一致的数据,目标是欺骗判别模型,让判别模型认为生成的数据是真实的;判别模型则需要分辨出生成模型生成的数据与真实样本。不断重复训练,最终使得生成模型能够生成非常逼近真实样本的数据,令判别模型无法区分。生成对抗网络在图像,文本,语音等数据上的应用都很广泛,是近几年最热门的模型之一。

生成模型的输入是低维度的随机噪声,输出是表示数据的高维度张量。判别模型的输入是高维度的数据张量;输出是判别向量。在训练阶段,生成模型输出的数据样本会输入给判别模型,由判别模型判断生成的数据是否已经足够逼近真实的样本数据。模型训练完成之后,在测试阶段就可以给生成模型输入低维度的随机噪声,让生成模型输出高维度的样本数据。

训练过程中生成模型 (G) 和判别模型 (D) 相互对抗,用 L(G,D) 代表损失函数,其中判别模型试图最小化误差,生成模型试图最大化误差。最终的误差函数如下:

[公式]

其中 [公式] 是样本数据, [公式] 是随机噪声。可知生成模型和判别模型对误差都有影响,任何一个变动都会导致误差变动,所以,GAN 是采用交替训练的方法来训练的,即固定一个模型,训练另外一个模型。在训练过程中,先固定生成模型的参数,然后优化判别模型的参数,首先让生成模型生成一批样本数据,将它们标记为生成样本,然后与真实的样本数据 [公式] 一起输入判别模型。判别模型的任务是将二者区分开,我们可以看做一个二分类问题,使用卷积神经网络等模型很容易完成。经过训练,判别模型就能够将真实样本与生成的数据区分开,可得到一个判别模型 D1;然后固定判别模型 D1 的参数,优化生成模型的参数。生成模型的优化目标是降低判别模型的准确率,所以应根据判别模型的辨别结果调整生成模型的参数,直到生成模型能够产生让判别模型 D1 无法区分的生成数据,至此可得到一个生成模型 G1。然后循环执行固定一个模型参数优化另一个模型的操作,交替训练升级生成模型和判别模型,每经过一轮训练,就会提高一些模型的准确率,最终得到生成模型 G2、G3、G4、⋯、Gn 和对应的判别模型 D2、D3、D4、⋯、Dn。经过以上 n 轮训练,不管是生成模型还是判别模型的性能都会得到极大的提升,判别模型能够区分稍有瑕疵的生成数据。为了能够欺骗判别模型,生成模型必须能够生成能够以假乱真的数据,最终 GAN 具备了生成足够逼真的高维度数据的能力。

参考资料:读懂生成对抗神经网络 GAN,看这文就够了