本文逐步分析 Gumbel-Softmax 的背景以及它所解决的问题。
1. 背景与问题描述
图片首先描述了一个分类问题,这是在机器学习,尤其是深度学习中非常常见的场景。分类问题的目标是预测某个输入样本属于哪一类。
假设模型的输出是一个概率分布,例如 [0.2, 0.4, 0.1, 0.2, 0.1],表示该样本属于5个类别的概率分别是20%、40%、10%、20%和10%。在分类任务中,通常我们会通过 argmax 选择概率最大的类别作为预测结果。在这个例子中,第二个类别的概率最大,所以 argmax 会选择类别2作为最终预测。
但是,这种方法仅适用于预测,即我们最终只关心哪个类别概率最大,并不涉及采样问题。
2. 采样问题
然而,当我们不只是需要预测某个具体类别,而是需要从这个概率分布中采样出一个类别时,问题就变得复杂了。采样的意思是我们要根据这些概率来随机选取一个类别,而不是总是选择概率最大的那个类别。
3. 采样的数学公式
最常见的采样方法使用了 onehot 和 max 函数,公式如下:
[ \mathbf{z} = \text{onehot}\left(\max\left{i \mid \pi_1 + \pi_2 + \cdots + \pi_{i-1} \leq u\right}\right) ]
公式解释:
- ( i ) 是类别的索引,取值范围为 ( 1, 2, \ldots, x )。
- ( u ) 是一个从均匀分布 ( U(0,1) ) 中随机采样的值,也就是说, ( u ) 是在0到1之间的一个随机数。
- ( \pi_1, \pi_2, \ldots ) 是类别的累计概率,代表从第一个类别开始,依次累加前面所有类别的概率。
步骤解释:
- 累积概率的范围不断增大,直到超过随机值 ( u )。
- 然后我们选择当前 ( i ) 所在的类别。
- 最终用
onehot编码方式将这个类别表示为一个向量,其中对应 ( i ) 类别的位置为1,其他位置为0。
4. 为什么 max 函数不可导?
采样过程中的这个 max 函数有一个致命的问题:它不可导(不可微分)。在机器学习模型,特别是在神经网络中,我们需要对模型的损失函数进行微分(求导)以进行反向传播并更新模型参数。然而,max 函数是一个离散的操作,它只能输出一个具体的索引,而无法对输入进行连续的微分,这就使得我们无法计算梯度,进而不能对模型进行优化。
5. Gumbel-Softmax 的作用
Gumbel-Softmax 提供了一种方法,能够在保留概率采样的同时,使得采样过程是可导的(可微的)。它通过加入Gumbel分布的噪声并结合Softmax操作,使得整个采样过程可以用连续的方式近似,从而可以进行梯度计算。这解决了采样过程中不可导的问题,使得我们能够在训练过程中对离散的类别变量进行有效的优化。
总结
- 采样问题:我们希望从一个概率分布中随机选取一个类别,而不是总是选择概率最大的那个。
- 传统方法的问题:通过
max和onehot操作进行采样虽然简单,但max操作不可导,无法用于梯度计算。 - Gumbel-Softmax 的贡献:它允许我们在保持采样性质的同时,使整个过程是可微分的,从而可以进行反向传播和参数优化。