从基础理解扩散模型的连续时间变分目标

让我逐步讲解扩散模型的基本概念,直到弄清楚“连续时间变分目标是交叉熵损失的加权积分”这一结论的含义。

1. 概率生成模型的基础

生成模型的目标

生成模型试图学习数据的概率分布 ,以便能够从这个分布中采样新的数据点。在语言模型的情况下,我们希望生成看起来像真实文本的新句子。

最大似然估计

训练生成模型的常见方法是最大似然估计,即最大化观测数据的概率:

  • 给定数据集
  • 我们希望找到模型参数 ,使得 最大化
  • 或等价地,最大化对数似然

2. 变分推断与证据下界(ELBO)

复杂模型的挑战

直接计算 在很多模型中是不可行的,特别是有潜变量的模型。

变分推断的基本思想

变分推断引入一个近似后验分布 ,用来近似真实后验分布

  1. 我们将对数似然分解为:
  2. 其中 ELBO (Evidence Lower BOund) 是:
  3. 由于 KL 散度始终非负,ELBO 提供了对数似然的下界:
  4. 最大化 ELBO 相当于:
    • 最大化重构项
    • 最小化 KL 散度

3. 扩散模型的基本原理

扩散模型的思想

扩散模型定义了一个逐渐将数据转换为噪声的前向过程,然后学习其逆过程:

  1. 前向过程(扩散过程)
    • 从数据 开始
    • 逐步添加噪声,得到
    • 最终达到近似纯噪声的状态
  2. 反向过程(生成过程)
    • 从噪声 开始
    • 逐步去噪,得到
    • 最终得到生成样本
离散时间扩散模型

在离散时间扩散中:

  • 定义 个时间步
  • 前向过程
  • 反向过程
离散扩散的 ELBO

对于离散时间扩散模型,ELBO 通常可以写作(一种常见形式):

(注:上面提供的原始公式可能略有不同,这里展示了一个更标准的离散ELBO分解形式。如果严格按照原文的 ∑ᵢ₌₂ᵀ,则公式为:)

4. 连续时间扩散模型

从离散到连续

当时间步 趋向无穷时,我们可以得到连续时间扩散过程:

  • 离散时间步 变为连续时间
  • 离散转移概率变为随机微分方程(SDE)或常微分方程(ODE)
连续时间 ELBO

在连续时间极限下,ELBO 变为积分形式(示意性,具体形式依赖于SDE/ODE选择):

(注:[...] 部分代表与得分匹配或类似项相关的期望。)

5. 掩码扩散模型的特点

掩码扩散的前向过程

在掩码扩散中,前向过程不是添加高斯噪声,而是将令牌(token)替换为特殊的掩码标记:

  • 数据 中的每个令牌以某个概率被替换为掩码
  • 概率随时间增加,到时间 时几乎所有令牌都被掩码
离散状态空间
  • 每个令牌有 个可能的状态(词汇表大小)
  • 加上掩码状态,共 个状态
  • 前向过程定义为特定的(通常是均匀的)转移矩阵

6. 论文中的连续时间变分目标

掩码扩散的 ELBO 简化

论文作者通过数学推导,将掩码扩散的连续时间 ELBO (记为 ) 简化为:

(注:这里使用了更标准的 LaTeX 格式,, , , , , , , 。)

这个公式的解读
  1. :指示函数,只有当 处的令牌被掩码(状态为 )时才为 1,否则为 0。这意味着损失只在被掩码的位置计算。
  2. :这是一个(多类别)交叉熵损失项。 是模型在时间 预测的原始令牌(在 中)的概率分布, (通常表示为 one-hot 向量) 是真实的原始令牌。该项衡量了模型预测与真实令牌之间的差距。
  3. :这是一个与时间相关的权重函数,由掩码调度 (通常表示未被掩码的概率)决定。它对不同时间点的交叉熵损失进行加权。
  4. 积分 :在整个连续时间轴 上累积(积分)加权的期望交叉熵损失。
简化的实际意义
  1. 直观理解:模型本质上是在学习一个去噪(或去掩码)任务:给定一个部分被掩码的句子 和时间 ,预测被掩码位置的原始令牌
  2. 计算效率:最终的损失函数形式简洁,训练时通常只需要对去噪模型 进行一次评估,然后根据 计算损失。
  3. 训练稳定性:这种形式避免了在原始 ELBO 中可能出现的高方差估计问题。

7. 交叉熵损失的加权积分

现在我们可以清晰地理解原句的含义:“连续时间变分目标是交叉熵损失的加权积分”:

  1. 变分目标 - 指的是 ,它是最大化对数似然 的(负)证据下界(或等价地,最小化这个目标)。
  2. 交叉熵损失 - 由项 代表,它衡量了模型在掩码位置预测的令牌分布与真实令牌之间的差异。
  3. 加权 - 通过时间依赖的系数 实现,该系数根据掩码调度调整不同时间点损失的重要性。
  4. 积分 - 通过 实现,将所有时间点上的加权交叉熵损失(在期望下)累加起来,形成最终的总变分目标。

这一简洁而优美的表达式是相关研究(如 D3PM、MaskGIT 等思想的连续化)的重要理论贡献,因为它将复杂的变分推断过程与直观的交叉熵损失联系起来,极大地简化了掩码式离散扩散模型的理论分析和训练实现,并促进了其性能的提升。