测试时训练层Test-Time Training(TTT),一种新的序列建模层

自注意力机制在长上下文中表现良好,但其复杂度为二次方。现有的RNN层具有线性复杂度,但其在长上下文中的性能受限于其隐藏状态的表达能力。论文Learning to (Learn at Test Time): RNNs with Expressive Hidden States提出了一类新的序列建模层,具有线性复杂度和强表达能力的隐藏状态。关键思想是使隐藏状态本身成为一个机器学习模型,并将更新规则视为自监督学习的一步。由于隐藏状态在测试序列上也会通过训练进行更新,这些层被称为测试时训练(Test-Time Training, TTT)层。论文考虑了两种实现方式:TTT-Linear和TTT-MLP,分别以线性模型和两层MLP作为隐藏状态。论文在125M到1.3B参数的规模上评估了这些实现,并与一个强大的Transformer和现代RNN(Mamba)进行了比较。TTT-Linear和TTT-MLP都能匹敌或超越基线。类似于Transformer,它们可以通过考虑更多的tokens来不断降低困惑度,而Mamba在16k上下文之后无法继续降低。在初步的系统优化下,TTT-Linear在8k上下文中已经比Transformer更快,并在实际运行时间上与Mamba相当。TTT-MLP仍然面临内存I/O的挑战,但在长上下文中表现出更大的潜力,指向未来研究的一个有希望的方向。

论文作者为Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin。作者来自Stanford University, UC San Diego, UC Berkeley和Meta AI。

论文概要内容如下:

一、主要贡献

  1. TTT层的提出:提出了一种新的序列建模层,其隐藏状态是一个模型,更新规则是自监督学习。该视角将层的前向传播视为一个训练循环,为未来研究开辟了新方向。
  2. TTT-Linear的性能:在评估中,TTT-Linear在125M到1.3B参数范围内的表现超过了Transformer和Mamba。
  3. 硬件效率改进:通过mini-batch TTT和对偶形式的设计,提高了TTT层的硬件效率,使得TTT-Linear成为LLM的一个实用构件。

二、方法

  1. TTT层概念:所有序列建模层都可以视为存储历史上下文到隐藏状态中,更新规则则是自监督学习的一步。TTT层的前向传播过程中,隐藏状态会根据输入序列不断更新。
  2. 自监督任务:TTT层的核心在于自监督任务的设计,该任务决定了隐藏状态从测试序列中学习的特征。通过优化自监督任务以实现最终的下一个词预测目标。
  3. mini-batch TTT:提出了mini-batch梯度下降,通过在小批量内并行计算多个梯度,提高了计算效率。
  4. 对偶形式:通过对偶形式计算Wb和输出序列,进一步提高了计算效率。

三、实验结果

  1. 短上下文评估(2k和8k):在2k上下文中,TTT-Linear、Mamba和Transformer性能相当;在8k上下文中,TTT-Linear和TTT-MLP显著优于Mamba。
  2. 长上下文评估(1k到32k):在长上下文评估中,TTT-Linear和TTT-MLP在32k上下文时的表现优于Mamba。TTT-MLP在更长上下文和更大规模的模型中表现出更大的潜力。

四、结论

TTT层通过在测试时进行训练,能够在保持线性复杂度的同时,利用更长的上下文,提高序列建模的性能。TTT-Linear和TTT-MLP的表现表明,这种方法在现有序列建模技术上有显著的改进,并为未来的研究方向提供了新的思路。

五、未来研究方向

  1. 进一步优化自监督任务:探索更复杂的自监督任务,以提高TTT层的性能。
  2. 探索不同的模型和优化器:研究更复杂的模型(如深度神经网络)和优化器(如Adam)在TTT层中的应用。

这篇论文通过提出新的序列建模方法,展示了在保持线性复杂度的同时,如何利用更长的上下文来提高模型性能,具有重要的研究意义和应用潜力。


P.S., 论文研究成果的相关代码实现:
1. Jax for training and testing: https://github.com/test-time-training/ttt-lm-jax
2. Pytorch inference code: https://github.com/test-time-training/ttt-lm-pytorch

发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注