不幸的是,从 p_θ(x) 取出样本是很困难的,因此必须借助 MCMC 来使用梯度估计器。最早期的一些 EBM 就是使用这种方法训练的。
尽管这样的小进展已经积累了很长时间,但最近有些研究开始使用这种方法来在高维数据上训练大规模 EBM,而且使用了深度神经网络来对其进行参数化。近期的这些成功使用基于随机梯度 Langevin 动态(SGLD)的采样器,结果已经接近等式 (2) 的预期,其取出样本的方式为:
新提出的基于联合能量的模型(Joint Energy Based Model)
在现代机器学习中,有 K 个类别的分类问题通常是使用一个参数函数来解决,即 f_θ : R^D → R^K,其能将每个数据点 x ∈ R^D 映射成被称为 logit 的实数值。使用所谓的 softmax 迁移函数,可将这些 logit 用于对类别分布执行参数化:
研究者在本文中给出了一个关键性的观察,即也可以略微重新解读从 f_θ 获得的 logit 来定义 p(x,y) 和 p(x)。无需改变 f_θ,可通过下式复用这些 logit 来为数据点 x 和标签 y 的联合分布定义一个基于能量的模型:
通过将 y 边缘化,也可为 x 获得一个非归一化的密度模型:
注意,现在任意分类器的 logit 都可被重新用于定义数据点 x 处的能量函数:
由此,研究者就找到了每个标准的判别模型中隐藏的生成模型!因为这种方法提出将分类器重新解读为基于联合能量的模型(Joint Energy based Model),所以他们将该方法称为 JEM。
下图 1 给出了该框架的概况,其中分类器的 logit 会被重新解读,以定义数据点和标签的联合数据 密度以及数据点单独的数据密度。
图 1:新方法 JEM 的可视化,其可从分类器架构定义一个联合 EBM
优化
那么,这种对分类器架构的新解读方法能在保留模型强大判别能力的同时也获得生成模型的优势吗?
因为 p(y|x) 的模型参数化是相对 y 进行归一化的,因此最大化其似然是很简单的,就如同在标准的分类器训练中一样。又因为 p(x) 和 p(x, y) 的模型未归一化,因此最大化它们的似然并不容易。在这样的模型下,以最大化数据的似然为目标来训练 f_θ 的方法有很多。我们可以将等式 (2) 的梯度估计器应用于等式 (5) 的联合分布下的似然。使用等式 (6) 和 (4),可将该似然分解为:
鉴于这项研究的目标是将 EBM 训练整合进标准的分类设置中,所涉分布为 p(y|x)。因此,研究者提出使用等式 (8) 的因式分解来确保该分布的优化使用的目标是无偏差的。他们使用了标准的交叉熵来优化 p(y|x),使用了带 SGLD 的等式 (2) 来优化 log p(x),其中梯度是根据
得到的。
应用
为了展示 JEM 相比于标准分类器的优势,研究者进行了全面的实验研究。首先,新方法的表现在判别式建模和生成式建模上都与当前最佳方法媲美。更有意思的是,他们还观察到一些与判别式模型的实际应用相关的好处,包括不确定性量化的改善、对分布外数据的检测、对对抗样本的鲁棒性。人们很久以前就预期生成模型能够提供这些好处,但从来没有在这样的规模上展现这一点。
实验中使用的所有架构都基于 Wide Residual Networks,其中移除了批归一化以确保模型的输出是输入的确定性函数。这将 WRN-28-10 在 CIFAR-10 上的分类误差从 4.2% 提升到了 6.4%,将其在 SVHN 上的分类误差从 2.3% 提升到了 3.4%。
所有的模型都是用同样的方法训练的,它们的超参数也都一样,都是在 CIFAR-10 上调节得到的。有趣的是,这里找到的 SGLD 采样器参数可以在各种数据集和模型架构上实现很好的泛化。此外,所有模型都在单个 GPU 上训练完成,耗时大约 36 小时。
混合建模
首先,研究者表明给定的分类器架构可以作为 EBM 训练,而且能同时实现与分类器和生成模型都相媲美的表现。他们在 CIFAR-10、SVHN 和 CIFAR-100 上训练了 JEM,并与其它混合模型以及单独的生成模型和判别模型进行了比较。结果发现 JEM 能在两个任务上同时取得接近最佳表现的结果,优于其它混合模型(下表 1)。
表 1:CIFAR-10 混合建模的结果。
鉴于这种方法无法计算归一化的似然,所以研究者提出使用 inception 分数(IS)和 Frechet Inception Distance(FID)来表示结果的质量。结果发现,JEM 能在这些指标上与当前最佳的生成模型相媲美。新提出的模型在 SVHN 和 CIFAR-100 上分别实现了 96.7% 和 72.2% 的准确度。下图 2 和 3 展示了 JEM 的样本。
图 2:CIFAR-10 类-条件样本。
图 3:类-条件样本。
JEM 的训练目标是最大化等式 (8) 中的似然因式分解。这是为了确保不会把偏差加进 log p(y|x) 的估计中,这在新提出的设置中可以确切地计算出来。在控制变量研究中,为最大化这一目标而训练的 JEM 的判别性能有显著的下降(见表 1 第 4 行)。
校准
如果一个分类器的预测置信度 max_y p(y|x) 与其误分类率是一致的,那么就认为这个分类器是已校准的。因此,当一个经过校准的分类器以 0.9 的置信度预测标签 y 时,它应该有 90% 的几率是正确的。对于要在真实世界场景中部署的模型而言,这是一个非常重要的特性,因为在现实场景中,不正确的决策输出可能造成灾难性的后果。在实际应用时,经过良好校准但不够准确的分类器可能比更准确但校准差的模型更加有用。研究者发现 JEM 能在显著提升分类性能的同时维持较高的准确度。
研究者重点关注了在 CIFAR-100 上的表现,因为当前最佳的分类器的准确度大约为 80%。他们在这个数据集上训练了 JEM,并将其与没有 EBM 训练的同样架构的基准进行了比较。基准模型得到的准确度为 74.2%,JEM 得到的准确度为 72.2%(参考一下,ResNet-110 得到的准确度为 74.8%)。下图 4 给出了结果。
图 4:CIFAR-100 校准结果。ECE 是指预期校准误差。
检测分布外数据
通常而言,分布外(out-of-distribution,OOD)检测是二元分类问题,模型的目标是得到一个分布 s_θ(x) ∈R,其中 x 是查询,θ 是可学习参数的集合。有很多不同的 OOD 检测方法都可以使用 JEM。
输入密度
如下表 2 第 2 列所示,JEM 为分布内数据分配的似然总是比 OOD 数据高。JEM 相比于 IGEBM 进一步提升的一个可能解释是其有能力在训练过程中整合有标注的信息,同时还能推导 p(x) 的一个原理模型。
表 2:OOD 检测的直方图。所有模型都是在 CIFAR-10 上训练的。绿色对应于在分布内 CIFAR-10 数据上的分数,红色对应在 OOD 数据集上的分数。
预测分布
很多成功方法都为 OOD 检测使用了分类器的预测分布。JEM 是一种很有竞争力的分类器,实验发现其表现足以媲美优秀的基准分类器,并且显著优于其它生成模型。下表 3 给出了结果(中行 )。
表 3:OOD 检测结果。所测模型是在 CIFAR-10 训练的,结果是 AUROC 指标。
一种新分数:近似质量(Approximate Mass)
对于在经典数据集之外的高似然数据点,研究者预期其周围的密度会快速变化,因此其对数密度的梯度范数相比于经典数据集中的样本会很大(否则它会处于高质量的区域)。基于这一数量,他们提出了一种新的 OOD 分数:
对于 EBM(JEM 和 IGEBM),研究者发现这种预测器的表现显著优于我们自己的和其它的生成式模型的似然——见表 2 第 3 列。对于易处理的似然方法,他们发现这种预测器与模型的似然是反相关的(anti-correlated),它们对 OOD 检测而言都不可靠。结果见表 3(底行)。
鲁棒性
作者使用了一种基于梯度的优化流程来生成样本,从而激活特定的高层面网络激活,然后优化网络的权重以最小化所生成的样本对该激活的影响。围绕数据,对抗训练和网络激活的梯度的正则化之间的进一步关联已经被推导出来。
有了这些关联,人们可能会疑惑从 EBM 推导出来的分类器是否比标准模型能更稳健地处理对抗样本。类似地,作者发现 JEM 能在无损判别性能的前提下实现相当不错的稳健性。
通过 EBM 训练提升鲁棒性
在基于 CIFAR-10 训练的模型上,研究者执行了大量强力的对抗攻击。他们执行了一次白盒 PGD 攻击,通过采样流程向攻击者提供了对梯度的访问权。另外,研究者还执行了一些无梯度的黑盒 攻击、边界攻击和暴力式逐点攻击。下图 5 给出了 PGD 实验的结果。所有的攻击都是针对 L2 和 L∞ 范数进行的,他们测试了在输入中执行 0、1、10 步采样的 JEM。
实验表明,新模型的鲁棒性显著优于使用标准分类器训练得到的基准模型。在这两个范数上,JEM 的表现与当前最佳的对抗训练方法相当(但略差一些),也和 Salman et al. (2019) 提出的当前最佳的经过认证的鲁棒性方法(图 5 中的 RandAdvSmooth)相媲美。
图 5:使用 PGD 攻击的对抗稳健性结果。JEM 能带来相当可观的鲁棒性提升。
鲁棒性不强模型的另一种常见失败模式是它们往往会以高置信度分类无意义的输入。为了分析这一性质,研究者遵照 Schott et al. (2018) 的方法进行了测试。下图 6 给出了结果。基准方法会有信心地分类非结构化的噪声图像。JEM 不能有信心地分类无意义的图像,所以可以明显看到图中出现了汽车属性和自然图像属性。
图 6:远端对抗(Distal Adversarials)结果。
机器之心「SOTA模型」:22大领域、127个任务,机器学习 SOTA 研究一网打尽。
视频
直播
美图
博客
看点
政务
搞笑
八卦
情感
旅游
佛学
众测