第 3 部分 - 面向图像、视频和时间序列的 Mamba 状态空间模型

添加图片注释,不超过 140 字(可选)

米amba,这个被认为可以取代强大的 Transformer 的模型,从最初在深度学习中使用状态空间模型 (SSM) 的想法已经走了很长一段路。

Mamba 为状态空间模型添加了选择性,从而实现了与 Transformer 类似的性能,同时保持了 SSM 的亚二次工作复杂度。其高效的选择性扫描比标准实现快 40 倍,并且与 Transformer 相比,吞吐量可提高 5 倍。

与我一起深入研究 Mamba,我们将发现选择性如何解决以前 SSM 的局限性,Mamba 如何克服这些变化带来的新障碍,以及如何将 Mamba 融入现代深度学习架构。

我们以 S4 模型的介绍结束了第 2 部分。通过约束状态矩阵A具有某种结构,可以非常高效地计算卷积核,使 S4 的速度比以前的 SSM 快 30 倍,内存使用量减少 400 倍,因此使其成为 Transformers 的可行替代方案。 我们说过,由于 LTI 系统的循环和卷积表示的二元性,我们的结构化状态空间模型可以得到有效训练和有效推断。❗ 但问题是:LTI 系统平等对待所有输入,因此无法做出依赖数据的决策。

2. 为 SSM 引入选择性 Mamba 通过引入数据相关选择机制解决了这个问题,但正如我们将看到的,这是有代价的。稍后会详细介绍。让我们从一些直觉开始。 2.1 总体情况 序列建模的一个基本问题是决定上下文(即先前的输入)在多大程度上被压缩到模型的状态中。 请记住,RNN之所以能够实现快速推理,是因为它处于高度压缩的状态,无法有效地对长距离依赖关系进行建模。另一方面, Transformer 则完全不压缩上下文。每个 token 可以随时关注其他 token,从而可以非常有效地对长距离依赖关系进行建模,但代价是二次工作复杂度O(L²)。 结构化 SSM使用 HiPPO 矩阵来近似上下文,从而产生一些压缩,同时由于其卷积表示,它可以作为 RNN 更有效地训练。

图 1:RNN、SSM、Transformer 和 Mamba 模型的效率与性能

我们想要的是介于两者之间的一种方案,或者说两全其美。一种高性能且高效的模型。剧透:Mamba 可能是一个有希望的候选者。 2.2 什么是选择性? 我们不想像 Transformer 那样关注整个历史记录,也不想像 RNN 和 SSM 那样将整个历史记录压缩为单一状态,而是想有选择地只压缩相关数据,从而获得高性能和高效的模型。这里的高效不仅意味着推理速度,还与训练模型所需的其他资源(如内存和数据)有关。 让我们看一下这个预测我的名字的简单例子:

图 2:选择性地将焦点放在输入子集上。

到目前为止,所有矩阵A、B和C都与输入数据无关。它们不会随着信号的不同时间步长而改变,因此我们说我们有一个线性和时不变 (LTI) 系统,其中循环和卷积表示相等。

添加图片注释,不超过 140 字(可选)

Mamba 通过使矩阵B和C以及离散化参数 ∆ 依赖于时间来改变这一点。这意味着对于每个输入标记,我们最终都会得到一组不同的参数!回想一下,矩阵A是从 HiPPO 矩阵派生的结构化矩阵,它根据先前状态的历史来近似状态随时间的变化。Mamba 保持这种状态。 2.3 选择性如何影响张量维数 使 SSM 具有选择性会在矩阵中引入新维度。明确地说,我们现在需要考虑长度为L 的序列中的每个标记,并将其并行化为批处理大小B。

添加图片注释,不超过 140 字(可选)

这里有几点需要注意:

  1. 矩阵A的形状为DxN,并且不会因单个标记而改变。因此它仍然是时不变的。我们将在第 3 章中进一步讨论矩阵A。

  2. 矩阵A确实通过离散化步骤 Δ 执行选择,该步骤现在取决于输入。因此,离散化矩阵的形状为BxLxDxN。

  3. 在 SSM 内部,矩阵B和C仍然是分别应用于输入标记或产生输出标记的向量,只是它们现在被堆叠以构建张量。

  4. 离散化矩阵明显大于非离散化矩阵。更多内容请参见第 4 章。

  5. 这些是 SSM 方程中使用的矩阵的形状,而不是学习到的参数!它们隐藏在其他地方。我们将在第 5 章中寻找它们。

❓ 但这一切到底意味着什么?所有这些矩阵背后的直觉是什么?为什么只有一些是选择性的? 2.4 所有这些矩阵背后的直觉是什么? 矩阵 A:状态到状态矩阵决定从前一个状态传播到下一个状态的内容。 矩阵 B:输入到状态矩阵转换任何传入的标记。它必须有选择性地过滤掉不相关的数据或关注重要数据,以仅允许相关数据进入新状态。它类似于某种门控机制。 矩阵 C:状态到输出矩阵将状态转换为输出标记。必须有选择性地决定需要从状态中获取哪些信息来实现输出。因此,与矩阵B一样,矩阵C也与门控机制有一些联系。 矩阵 Δ:离散化参数通常是指将连续输入信号转换为离散序列的步长。虽然在经典 SSM 中,这可能是直观的,但在深度学习设置中却并非如此,尤其是在处理已离散化的数据或其他模态(如文本标记)时。 那么,让我们看看极端情况:如果Δ→∞ ,系统会更多地(或更长时间地)关注输入,因此在状态方程中,状态更新受输入贡献的影响较大,而受先前状态的影响较小。如果Δ→0,则情况相反。

3. 再来谈谈矩阵A! 再次,我们必须讨论状态矩阵A。当你查看矩阵A的维度时,你有没有注意到什么? 我们总是说矩阵A是NxN的方阵,但在 Mamba 中它的形状是DxN。发生了什么? 3.1 对角状态空间模型 好吧,从 S4 论文到 Mamba 发布,发生了很多事情。特别是,对角状态空间 (DSS)模型通过删除低秩项,进一步将 S4 的 DPLR(对角加低秩)公式简化为一个简单的对角矩阵。当我们在本系列第 2 部分讨论 S4 时,我们说,使用对角矩阵A来计算卷积核K实际上是理想情况,因为它可以大大减少计算量!

添加图片注释,不超过 140 字(可选)

在S4 作者的后续论文中,他们对其工作原理进行了彻底的分析,并进一步为真正的对角线情况提出了多种新的初始化方案,从而形成了 S4D 模型。 长话短说,Mamba 使用 S4D 论文的附注中提到的 S4D-Real 初始化,因为它具有与原始 S4-LegS HiPPO 矩阵相同的频谱。我们将在第 5 章中进一步讨论初始化。 3.2 Mamba 矩阵 A 但现在让我们回到 Mamba 的矩阵A。回想一下本系列第 2 部分,如果我们有一个NxN的对角矩阵,我们只需要存储N 个值,这意味着我们可以将其存储为一个向量。Mamba的DxN矩阵实际上只是表示对角矩阵的向量的D倍!

图 6:根据矩阵 A 创建 SSM 的对角矩阵。

这使我们能够观察到有关 Mamba 和以前的状态空间模型的另一个有趣的事实:D中的每个通道都是独立处理的,这可以解释为每个通道都有自己的 SSM!当我们分别研究 MambaMixer 和 Mamba-2 时,这一事实在第 6 部分和第 7 部分中变得重要。

4. 无需卷积表征的快速训练 不同时间步长具有不同的参数意味着我们的模型现在是随时间变化的。因此,卷积表示不再有效,我们失去了快速训练的好处。Mamba 论文的很大一部分是关于如何克服这个问题的。他们介绍了三种策略,从而形成了他们的选择性扫描算法:(1)并行关联扫描,(2)核融合和(3)梯度重新计算。 注意:这些想法并不完全是新的。S5层(序列建模的简化状态空间层)应用了并行扫描,而H3(Hungry Hungry Hippos)已经与核融合配合使用以加快训练速度。 4.1 并行联想扫描 回想一下,结构化状态空间模型是一个时不变系统 (LTI),因为更新状态所涉及的矩阵(A和B)对于每个新输入都是相同的。这允许使用递归表示或卷积表示。在 Mamba 中,这种情况发生了变化,因为离散化矩阵A和B现在依赖于输入。 这意味着我们不能使用卷积表示来加速训练,而只能使用循环表示,这次对于每个新输入和状态更新都有不同的矩阵:

图 7:选择性状态空间模型的递归表示。

循环执行的工作复杂度为O(L),这意味着将输入序列的长度L加倍需要两倍的工作来计算最终输出。循环的一个大问题是,在训练期间,我们不想按顺序计算所有状态,因为我们可以事先访问所有输入的标签。但 在推理过程中,尤其是对于自回归采样任务(例如下一个单词预测),这完全没问题。 ❓问题是,我们如何才能加快状态的顺序更新速度,其中每个状态都强烈依赖于前一个状态?我们不可能并行运行,对吧? 这种特殊的递归类型,即每个新状态都是前一个状态和输入的总和,称为扫描。这是科学和工程中一个常见的问题,甚至在我们有并行计算机之前,1986 年和 1990 年,并行化扫描的技术就已经公布了。直到 2005 年,这些算法才在真正的 GPU 上进行测试。

图 8:并行扫描算法。

并行扫描的工作复杂度为O(log L),小于O(L) 。输入序列L越长,与循环执行相比,加速比越高。 回想一下,原始变压器的工作复杂度为O(L²)。 4.2 内核融合 您可能知道,任何计算机都有不同类型的内存。一般来说,距离处理单元越远,内存越大,读取数据所需的时间越长(主要是由于 PCB 和其他电子元件的属性不允许更高的频率)。 GPU 中通常发生的情况是,变量从较大但较远的 HBM(高带宽内存)读取到较快的 SRAM 中以执行计算,然后将结果写回到 HBM。如果计算需要多个步骤,则重复读取、处理和写入的过程,直到完成整个计算。如果计算实际计算所需的时间少于读取和写入所需的时间,则称为内存受限,又称 I/O 受限。

图 9:GPU 利用率与内存 I/O。

I/O 密集型操作对计算硬件的使用效率非常低,因为它们大部分时间都在等待空闲状态的数据。基本上,您要避免过于频繁地读取大量数据。 ❓ 但我们能做什么呢?或者更确切地说:Mamba 的作者们做了什么? 全局内存中不存储离散化矩阵A和B,而只存储连续时间矩阵。矩阵A、B、C和Δ从 HBM 读入 SRAM ,然后在 SRAM 中执行离散化。 这里有一个例子,说明在读取连续时间矩阵或读取离散化矩阵时需要读取多少个值,考虑到前面描述的矩阵形状:

图 10:离散矩阵与连续矩阵需要读取的值的数量。

此外,还应用了内核融合来减少 HBM 的读写次数。具体来说,它们将离散化、选择性扫描和与矩阵C的乘法融合到单个内核中。因此,它们读取整个前向传递所有状态所需的所有数据,处理这些数据并写回形状为BxLxD的最终输出y 。形状为BxLxDxN的大状态没有实现,以避免内存需求和读取大量数据。

图 11:内核融合对 GPU 利用率的影响。

了解了核融合及其所涉及的不同内存后,我们终于可以看看 Mamba 论文中的标志性插图,该插图展示了不同矩阵之间的相互作用,以计算新状态和输出:

图 12:具有硬件感知状态扩展的选择性状态空间模型。

绿色表示这些值存储在 HBM 中,而橙色表示它们保存在 SRAM 中。在第 5 章的后面,我们将调整此图像以显示实际情况并找到训练过的参数。 4.3 梯度的重新计算 之前我们看到,通过减少 I/O 操作以及隐藏状态不保存在 HBM 中,我们可以通过核融合来加快模型推理速度。但是,要计算梯度并反向传播权重更新的损失,需要这些中间状态。 由于核融合非常有效,因此在反向传播过程中根据需要重新计算中间状态比从 HBM写入和读取形状为BxLxDxN的大型中间状态要高效得多。 4.4 它有多快? 在花费了如此多的精力来提高 SSM 的速度之后,仍然存在一个问题:与我们所知的最有效的注意力实现FlashAttention-2相比,它的速度有多快。

图 13:Mamba 的效率基准

5. Mamba 作为通用模型 我们现在有了选择性 SSM,这是一种高效、快速的状态空间模型,它随输入序列长度L线性扩展,并且由于每个输入都使用自己的一组参数进行处理,因此它也具有选择性。我们唯一要做的就是将其包装成一个块,并将多个块堆叠在一起以构建深度学习模型;就像我们在 Transformer-Decoder 中堆叠多头 Transformer 块一样。 5.1 曼巴区块 让我们看看使用 Mamba 的自回归下一个词预测模型是什么样的: 请注意,维度名称和模块名称与Mamba Block 的官方代码实现相同,以便更好地理解如何配置以及发生了什么。

图 14:Mamba-Block 的详细框图。在此示例中:L=4、d_model=3、expand=2、d_inner=6。图片由Sascha Kirch提供。

输入投影只是将每个L个token 投影到d_model- 维空间中。头部使用Mamba Blocks堆栈中的一个或多个输出 token 进行预测。 在Mamba Block中,我们首先将residuals上一层的添加到中,hidden_states 然后将它们的总和作为下一个输出residuals。我们使用 RMS Layer Normalization 进行归一化,然后从d_model维度线性投影到2*d_inner维度,其中d_inner定义为expand*d_model。代码内部expand默认为 2。然后将 的输出in_proj分成 2 个不同的路径,这两条路径现在都是维度[1, d_inner],并且都包含 SiLU 非线性。 选择性SSM 块是我们的状态空间模型,其中已知状态方程由矩阵A、B、C和Δ参数化。请注意,图 4 中的向量维度 D对应于d_inner。 在选择性 SSM 块之后,将两个路径组合起来,从 投影回d_inner以d_model形成最终输出。 请注意,两个线性投影都通过L个标记进行广播,这意味着它们对每个标记使用相同的参数! 5.2 可学习的参数在哪里? 虽然在 Mamba Block 内部,学习到的参数的位置非常清楚(即输入投影in_proj、输出投影out_proj和 1D 卷积conv1d),但学习到的参数在选择性 SSM Block 内隐藏的位置仍然是一个谜。 我认为,如果我们调整图 12,我们就能最好地观察发生的情况,以显示选择性 SSM(单个标记且无批处理轴)内部真正发生的情况:

图 15:具有硬件感知状态扩展的选择性状态空间模型,包括离散化矩阵和可学习的线性投影

同样,这个数字上有很多内容,所以我们从两个观察开始:

  1. 矩阵A没有下标t,表明它不具有时间相关的选择性,但离散化的矩阵A具有。

  2. 矩阵A和B存储在 GPU 的 HBM 中,而离散化矩阵仅在 GPU SRAM 内部实现,如第 4.2 节所述。

现在让我们看看发生了什么。请注意,时间相关矩阵B和C是时间tx_proj时输入x的线性投影的结果。同样,使用单个线性层,可能是出于性能原因。的输出进一步从投影到形成最终的离散化矩阵Δ。我们直接优化的只是矩阵A(实际上,我们在训练期间跟踪log(A)以确保数值稳定性)。x_projdt_rankd_inner 总而言之,在选择性 SSM 中,我们不训练状态方程中使用的矩阵,而是训练产生这些矩阵的线性投影的权重。我们直接优化的只是矩阵A。 那么,只剩下一个问题需要回答…… 5.3 我们如何初始化模型? 这不是一个简单的问题,因为不仅阅读论文由于主题复杂而变得复杂,而且代码也相当复杂。 首先要注意的是:有一个独立的 Mamba Block 的初始化,然后有一个重新初始化,这取决于将多个块堆叠在一起的深度学习模型的最终架构,即通过使用函数_init_weights()。 我们首先来看一下 Mamba Block 本身。 所有模块都使用 pytorch 相应模块类型的默认初始化器进行初始化(例如,Conv1D 和 Linear 的统一初始化),但有两个例外:矩阵A和dt_proj,用于构建Δ 的线性投影。 矩阵A使用前面提到的 S4D-Real 初始化来相对直接地进行初始化:

图 16:S4D-Real 状态矩阵 A 的初始化。

另一方面,时间表有点复杂。我建议你查看“如何训练你的河马”论文以了解更多详细信息。简短版本是: 偏差应通过在( Δmin=0.001,Δmax=0.1dt_proj )范围内进行对数均匀采样来初始化。该论文中有一个非常重要且容易忽视的旁注:如果将1/Δmin设置为大约在序列长度范围内,则可以提高性能。因此 0.001 大约为 10³ 个样本。 例如,查看本文中序列长度为L=16,384的基准数据集上的图:

图 17:时间尺度的选择对 Path-X 基准验证准确率的影响。

最后,无需过多介绍细节,下面是带有_init_weights()一些注释的函数,用于在作者提供的MixerModel和MambaLMHeadModel中使用已初始化的权重时覆盖它们: 这个函数实际上在这两个类中都会被调用,这意味着它们会覆盖初始初始化时覆盖的权重😵但玩笑归玩笑,如果在时间压力下工作并找到性能最高的解决方案,那么并不总是有足够的时间来编写干净的代码。还有另一个权衡……

图 18:注释 Mamba 存储库中的 _init_weights() 函数,该函数用于在初始化模型权重后覆盖它们。

6. 结束 Mamba 6.1 总结 借助 Mamba,我们现在得到了一个与 Transformer 一样强大、与 RNN 一样高效的序列模型。 我们发现,虽然通过添加选择性我们失去了循环和卷积表示的二元性,但我们能够通过实施选择性扫描算法再次加快模型速度。 Transformer 是一种通用架构,适用于您能想到的所有数据模式。到目前为止,在谈论 Mamba 时,我们只谈论序列类数据,例如文本标记或模拟信号。❓ 但是,我们可以将 Mamba 应用于非序列数据吗?比如说图像? 谢谢你! 🤝我想感谢你们, 我故事的读者,感谢你们一直以来的支持、为我的故事鼓掌以及关注我,让我不会错过任何一篇最新文章。 正是你们激励我继续写作,深入研究复杂的主题,并通过写作成为更好的工程师/研究人员。 你们太棒了! 6.2 继续阅读第 4 部分 Vision Mamba:类似 Vision Transformer 但更胜一筹 🔜我目前正在撰写该系列的第 4 部分。发布后我会在此处粘贴链接! 在第 4 部分中,我们将了解 Mamba 类模型如何应用于非序列数据类型(例如图像)。我们将重新审视 Vision Transformer 的某些设计选择(例如,块大小、图像分辨率和位置编码),并查看在使用 Mamba 而不是注意力机制时它们是否仍然有效。 Vision Mamba 最终比DeiT快 2.8 倍,并节省 86.8% 的 GPU 内存,我们将看到这是如何实现的。 敬请期待!

感谢关注雲闪世界。(Aws解决方案架构师vs开发人员&GCP解决方案架构师vs开发人员)

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐