近8年后,谷歌Transformer继任者「Titans」来了,上下文记忆瓶颈被打破

近8年后,谷歌Transformer继任者「Titans」来了,上下文记忆瓶颈被打破
2025年01月15日 17:08 机器之心Pro

正如论文一作所说,「新架构 Titans 既比 Transformer 和现代线性 RNN 更有效,也比 GPT-4 等超大型模型性能更强。」

终于,在 2017 年推出影响 AI 行业长达 8 年的 Transformer 架构之后,谷歌带来了全新的架构 Titans。这次,谷歌的重点是将推理领域非常重要的测试时(test-time)计算用在了记忆(memory)层面。

在谈到推出 Titans 的初衷时,论文一作 Ali Behrouz 表示,「注意力机制一直是大多数 LLM 进展的重要组成部分,不过它无法扩展到长上下文。因此,Titans 应运而出,它成为了一种同时具备注意力机制和元上下文记忆的结构,可以在测试时学习记忆。该架构可以将上下文窗口扩展到 200 万 tokens。」

图源:https://x.com/behrouz_ali/status/1878859086227255347

这意味着,谷歌 Transformer 迎来了它的「继任者」。

图源:https://x.com/mark_k/status/1878896628654022993

多年来,研究人员一直在广泛探究如何有效地利用循环模型和注意力机制,其中循环模型旨在将数据压缩到固定大小的记忆(称为隐状态)中,而注意力机制允许处理整个上下文窗口,捕捉所有 token 的直接依赖。不过,更准确的依赖建模往往伴随着二次成本,导致模型只能处理固定长度的上下文。

因此,谷歌提出了一种新的长期神经记忆模块(neural memory module),它能够学习记忆历史上下文,并帮助注意力机制在利用过去已久信息的同时处理当前上下文。结果表明,这种神经记忆具有快速并行化训练的优势,同时还能保持快速推理。

从记忆的角度来看,谷歌认为注意力机制虽然受限于上下文但可以更准确地建模依赖关系,因此可以起到短期记忆的作用;而神经记忆能够对数据进行记忆,起到了长期、更持久的记忆作用。基于这两个模块,谷歌引入了一个全新的系列架构 —— Titans,通过三种变体有效地将记忆融合到该系统架构中,它们分别是记忆作为上下文(Memory as a Context,MAC)、记忆作为门(Memory as a Gate,MAG)和记忆作为层(Memory as a Layer,MAL)

在语言建模、常识推理、基因组学和时序预测任务上的实验结果表明,Titans 架构比 Transformer 和近年来的现代线性循环模型更有效。另外,在大海捞针(needle-in-haystack)中,Titans 架构能够有效地扩展到超过 200 万 tokens 的上下文窗口,并且比基准模型实现了更高的准确性。

  • 论文标题:Titans: Learning to Memorize at Test Time

  • 论文地址:https://arxiv.org/pdf/2501.00663v1

另外,论文作者之一 Peilin Zhong 为谷歌 NYC 算法与优化团队的研究科学家,2021 年加入谷歌。他本科毕业于清华姚班,博士毕业于哥伦比亚大学。

目前,已经有人搞出了有关 Titans 架构的非官方实现,感兴趣的读者可以去看一下。

GitHub 地址:https://github.com/lucidrains/titans-pytorch

学习测试时记忆

谷歌详细介绍了长期神经记忆模块,它成为了一种可以在测试时学习记忆的元模型。

长期记忆

为了设计一个长期神经记忆模块,我们需要模型能够将过去历史的抽象编码到其参数中。因此,一个简单的思路是训练神经网络并期望它能够记住自己的训练数据,然而记忆几乎一直是神经网络中令人头疼的现象,它限制了模型的泛化能力,还引发隐私问题,因此导致测试时性能不佳。

基于此,谷歌认为需要一个在线元模型来学习如何在测试时记忆或忘记数据。在这种设置下,模型学习一个能够记忆的函数,但不会过拟合训练数据,从而在测试时实现更好的泛化性能。

学习过程和意外指标(Learning Process and Surprise Metric)。训练长期记忆的关键思路是将训练视为在线学习问题,其中将过去信息 x_1, …, x_t-1 压缩到长期神经记忆模块中。人类往往能够记住背离预期(令人惊讶)的事件,受此启发,模型意外可以简单定义为它相对于输入的梯度。梯度越大,输入数据与过去数据的偏差就越大。因此,使用这个意外分数,可以将记忆更新如下:

这一意外指标可以导致在重大意外时刻之后出现重要信息缺失。从人类记忆的角度来看,即使一个事件令人难忘,但它可能不会在长时间内持续让我们感到惊讶。为了改进这一现象,谷歌将意外指标分解为了(1)过去意外,它衡量最近过去的意外程度;(2)瞬时意外,它衡量传入数据的意外。

这些意外指标基于一个损失函数

,它就是我们的记忆在测试时学习充当的目标。也就是说,记忆模块是一个元模型,它基于损失函数来学习一个函数。

在本文中,谷歌则专注于联想记忆,目的是将过去的数据存储为键(keys)和值(values)对。类似于 Transformer,在给定 x_t 的情况下,谷歌使用两个线性层将 x_t 投影到键和值中:

接下来,谷歌希望记忆模块可以学习键和值之间的关联,为此将损失定义如下:

遗忘机制(Forgetting Mechanism)。在处理非常大的序列(比如百万 tokens)时,管理哪些过去信息应该被遗忘非常重要,即使使用深度或者非常大的矩阵值记忆时也是如此。因此,谷歌使用了一种自适应遗忘机制,允许记忆忘记不再需要的信息,从而更好地管理有限的记忆容量。也就是说,给定下一个 token x_t,谷歌将更新规则做如下修改:

记忆架构(Memory Architecture)。谷歌重点将具有 L_M≥1 层的简单 MLP 作为长期记忆架构,选择它们的原因在于希望能够更好地激励长期记忆设计以及将其融入架构的方法。谷歌表示,本文的架构开辟了一个新的研究方向,有助于设计更有效且高效记忆数据的神经架构。

检索记忆(Retrieving a Memory)。在探讨如何设计和训练一个可以在测试时学习记忆的长期记忆模块之后,剩下的关键问题便是如何从记忆中检索信息?谷歌仅仅使用了没有更新权重的前向传递(即推理)来检索与查询相对应的记忆。在形式上,给定一个输入 x_t,谷歌使用线性层 W_Q 来投影输入,即 q_t = x_tW_Q,并通过以下公式从记忆 y_t 中检索相应(或有用)的信息。

并行化长期记忆训练

理论上,长期记忆模块的训练需要

FLOPS,其中 N 为序列长度。不过在实践中,我们需要并行化训练过程并充分利用 TPU、GPU 等硬件加速器,同时需要张量化该过程并使用更多矩阵乘法(matmuls)。

接下来,谷歌表示,使用小批量梯度下降、数据学习率和权重衰减来计算内循环权重的方式可以重新来表示,以便它只使用矩阵乘法和求和(sums)。这里将序列拆分为大小为 b ≥ 1 的块,并将小批量梯度下降写做:

此外,谷歌解释了 M_t = W_t 为线性时的情况。对于 N_p ≥ 2 的 MLP,过程类似。使用本文的损失函数可以得到如下:

最后,谷歌扩展这一表示,以便可以合并动量项。在具有动量的块式梯度下降中,如果观察动量项则可以得到如下:

作为函数块的参数(Parameters as the Function of Chunks)。谷歌没有让参数 a_t、θ_t 和 η_t 依赖于输入,而是让它们成为函数块。尽管失去了表达能力,但可以帮助更快地训练。在这种情况下,谷歌在每个块中对每一个 a、θ 和 η 都使用了相同的值。在实验中,谷歌将这些参数作为了 token 的函数,并表示,这种简化(即作为块函数)可能是未来工作感兴趣的地方,以便以更高效的方式训练更大的模型。

下图 1 展示了如何并行并在使用矩阵乘法时完成神经记忆训练。

如何融合记忆?

接下来需要解决的一个重要问题是:如何有效且高效地将神经记忆融合到深度学习架构中

从记忆的角度来看,Transformer 中的 K 和 V 矩阵对可以解释为联想记忆块。由于它们对依赖关系的精确建模以及有限的上下文窗口,它们可以被用作短期记忆模块,以处理当前上下文窗口大小。另一方面,神经记忆能够不断从数据中学习并存储在其权重中,因而可以发挥长期记忆的作用。谷歌通过三个不同的 Titans 变体来回答以上问题。

记忆作为上下文(Memory as a Context,MAC)

Titans 的第一个变体 MAC 的架构设计如下图 2 所示,将记忆作为当前信息的上下文。

也就是说,给定一个长序列

,首先将序列分成固定大小的片段 S^(𝑖),其中 𝑖 = 1,...,𝑁/𝐶。给定传入片段 S^(𝑡),谷歌将它视为当前上下文,将其过去的片段视为历史信息。因此,谷歌让 M_𝑡-1 成为片段 S^(𝑡) 之前的长期记忆状态,使用输入上下文作为对记忆 M^𝑡-1 的查询,以从长期记忆中检索相应的信息。谷歌如下所示检索与 S^(𝑡) 相对应的过去信息:

接下来将这些历史信息与持久记忆参数一起作为注意力模块的输入序列:

该架构具有两个关键优势:一是注意力模块同时具有历史和当前上下文,能够根据当前数据决定是否需要长期记忆信息,二是注意力模块帮助长期记忆只存储来自当前上下文的有用信息。这意味着,并非每个片段中的所有 token 都是有用的,记忆所有 token 可能会导致内存溢出。因此,注意力模块帮助记忆了解哪些信息是有用的,从而更好地管理内存容量。

另外,在测试时,(i)持久记忆参数是固定的,它们编码了有关任务的知识,不应改变;(ii)注意力模块权重是上下文学习器;(iii)长期记忆模块在测试时仍然学习(记忆)信息。也就是说,即使在测试时,神经记忆的权重也会更新,这是因为权重对过去已久的抽象进行了编码。

记忆作为门(Memory as a Gate,MAG)

Titans 第二个变体 MAG 的架构设计如下图 4 所示:

在其中一个分支中,谷歌直接使用输入数据来更新长期记忆;在第二个分支中,谷歌使用了滑动窗口注意力(SWA):

该架构的整体注意力掩码如下图 3b 所示,其中滑动窗口注意力(SWA)充当精确的短期记忆,而神经记忆模块充当模型的衰减记忆。该设计也可以看作是多头架构,其中各头的结构不同。

记忆作为层(Memory as a Layer,MAL)

Titans 的第三个变体 MAL 使用了深度神经网络,这种架构设计在文献中更为常见,其中混合模型堆叠具有完整或滑动窗口注意力的循环模型。

给定输入 x,可以得到以下:

其中 SW-Attn 是滑动窗口注意力。

无注意力记忆(Memory Without Attention)。从记忆的角度来看,谷歌期望记忆系统的每个组件都能独立工作,即使其他组件受到了干扰。因此,即使没有短期记忆(即注意力),长期记忆模块仍然应该是一个强大的模型。谷歌在实验中将这种变体称为 Titans (LMM)。

架构细节

在所有块中,谷歌使用了残差连接;在实现中,谷歌使用 SiLU (.) 激活函数作为计算查询、键和值的非线性激活,并使用

对查询和键进行归一化。

卷积(Convolution)。遵循最近的现代线性循环模型,谷歌在每个查询、键和值投影后都融合了一个 1D 深度可分离卷积层。这些 1D 卷积可以提升性能,并且计算高效。

门控(Gating)。谷歌还在最终输出投影之前利用线性层进行归一化和门控。

实验结果

谷歌在实验部分关注上述三种 Titans 变体,分别是 MAC、MAG 和 MAL,以及单独的神经记忆模块。对于每个模型,谷歌使用了四种尺寸的模型,参数分别是 (i) 170M、(ii) 340M、(iii) 400M 和 (iv) 760M。

语言建模

谷歌首先关注模型在语言建模和常识推理任务中的困惑度。下表 1 报告了 Titans 变体和三种不同大小(340M、400M 和 760M)基线的结果。在包括 Transformer++ 在内的非混合模型中,神经记忆模块在困惑度和准确度测量方面均取得了最佳性能。

谷歌还发现,Titans 的三种变体(MAC, MAG 和 MAL)都优于 Samba (Mamba + 注意力)和 Gated DeltaNet-H2(Gated DeltaNet + 注意力)。

大海捞针 

下表 2 结果显示,与基线相比,神经记忆模块均取得了最佳结果。

谷歌将这种卓越的表现归因于 Titans 与现有序列模型的三个关键差异:(1)与 TTT 相比,神经记忆能够通过使用动量和遗忘机制(即权重衰减)更好地处理记忆容量。因此,随着序列长度的增加,神经记忆的性能不会下降,呈现出一致的趋势;(2)与具有门控(遗忘)机制的 Mamba2 相比,Titans 具有深度非线性记忆,从而实现了更好的记忆管理。此外,与神经记忆和 DeltaNet 不同,Mamba2 无法移除记忆,因此在增加序列长度时,其性能会出现显著下降;(3)与 DeltaNet 相比,尽管它能够使用增量规则移除记忆,但无法擦除记忆,缺乏遗忘机制。

最终,正如预期的那样,使用 Titans 变体时能看到相当或更好的结果,其中最佳结果来自 MAC。

BABILong 基准

在微调设置中,谷歌将小型微调版本的 Titans (MAC) 与其他模型进行了比较。

Titans 和基线的结果如下图 6b 所示。Titans 的表现优于所有模型,甚至比 GPT4 这样的超大型模型还要好。此外,与基于 Transformer 的 RMT 等记忆模型相比,Titans 表现出更好的性能,这主要归功于其强大的记忆。

深度记忆的影响

接下来的实验评估了深度记忆对 wall-clock 训练时间和模型性能的影响。

下图 7 中报告了 Titans(LMM)和基线的困惑度与序列长度的关系。有趣的是,随着记忆深度的增加,该模型可以在所有序列长度上实现更好的困惑度。此外,当模型的参数量较少时,更深的记忆模块对序列长度的鲁棒性更强。随着参数量的增加,所有模型在较长的序列上都表现出更好的性能。

时序预测

为了展示记忆模块在更广泛任务中的有效性,谷歌评估了 Titans 在时序预测任务中的表现。结果如下表 3 所示,谷歌的神经记忆模块优于所有基线,包括基于 Mamba、线性和 Transformer 的架构。

DNA 建模

谷歌还进一步评估了神经记忆模块在 DNA 建模任务上的表现,结果如下 4 所示,相较于当前的 SOTA 架构,Titans(LMM)在不同的下游基因组任务中仍具有竞争力。

效率

谷歌还对 Titans 与当前 SOTA 序列模型的效率进行了比较,下图 9 显示了不同序列长度 x 批大小的模型的训练吞吐量。可以看到,谷歌神经记忆模块比 Mamba2 和 Gated DeltaNet 稍慢,不过 Titans (MAL) 比基线和神经记忆模块都要快。

更多技术细节和实验结果请参阅原论文。

新浪科技公众号
新浪科技公众号

“掌”握科技鲜闻 (微信搜索techsina或扫描左侧二维码关注)

创事记

科学探索

科学大家

苹果汇

众测

专题

官方微博

新浪科技 新浪数码 新浪手机 科学探索 苹果汇 新浪众测

公众号

新浪科技

新浪科技为你带来最新鲜的科技资讯

苹果汇

苹果汇为你带来最新鲜的苹果产品新闻

新浪众测

新酷产品第一时间免费试玩

新浪探索

提供最新的科学家新闻,精彩的震撼图片