大型语言模型(LLMs)的成功,某种程度上要归功于Transformer架构在自然语言处理任务上的突破。该架构最初是为了克服循环模型的sequential training问题而提出的。这些年来,Transformer已经成为LLMs普遍采用的架构。
然而,Transformer的训练并行性是以低效推理为代价的:每一步的复杂度为O(N)且键值缓存受内存限制,让Transformer不适合部署。不断增长的序列长度会增加GPU内存消耗和延迟,并降低推理速度。
研究者们一直在努力开发下一代架构,希望保留训练并行性和Transformer的性能,同时实现高效的O(1)推理。针对这个问题,此前的方案都没能同时实现这几点,至少与Transformer相比没有显示出绝对的优势。
现在,微软研究院和清华大学的研究者已经在这个问题上取得了重大突破。论文链接:https://arxiv.org/pdf/2307.08621.pdf
在这项工作中,研究者提出了retentive网络(RetNet),同时实现了低成本推理、高效长序列建模、媲美Transformer的性能和并行模型训练,打破了「不可能三角」。具体来说,RetNet引入了一种多尺度retention机制来替代多头注意力,它有三种计算范式:并行、循环和分块循环表征。
首先,并行表征使训练并行化,以充分利用GPU设备。其次,循环表征法在内存和计算方面实现了高效的O(1)推理。部署成本和延迟可以显著降低,同时无需键值缓存技巧,大大简化了实现过程。此外,分块循环表征法能够执行高效的长序列建模。研究者对每个局部块进行并行编码以提高计算速度,同时对全局块进行循环编码以节省GPU内存。
论文进行了大量实验来对比RetNet和Transformer及其变体。实验结果表明,RetNet在scaling曲线和上下文学习方面始终具有竞争力。此外,RetNet的推理成本与长度无关。对于7B模型和8k序列长度,RetNet的解码速度是带键值缓存的Transformers的8.4倍,内存节省70%。
在训练过程中,RetNet也能够比标准Transformer节省25-50%的内存,实现7倍的加速,并在高度优化的FlashAttention方面具有优势。此外,RetNet的推理延迟对批大小不敏感,从而实现了巨大的吞吐量。