自从谷歌的BERT预训练模型横空出世,预训练-下游任务微调的方式便成了自然语言处理任务的灵丹妙药。然而,复杂度高、显存消耗大等问题一直困扰着BERT等预训练模型的优化;由于BERT中Transformer(多层自注意力)关于输入文本长度L有的O()的时间空间复杂度,长文本消耗显存陡然增加。
想象一下,一位工程师兴致勃勃地将数据在设计好的下游任务上微调,满怀期待地盼望着结果的提升,却因为其中的一些长文本使得显存溢出或超过位置嵌入(position embedding)最大长度,该是一件多么沮丧的事情。
解决这个问题最直接的方法是滑动窗口(sliding window)对每个512(通常BERT位置嵌入的最大长度)字符的窗口分别预测,最终合并不同窗口的结果的方式随着具体下游任务的不同略有差异,例如阅读理解问答可以输出各段中总评分最高的小段(span)作为答案。然而,如果问题需要长程注意力,也就是两个关键的句子分布在段落中相距较远位置的时候,这种方法的效果就会大打折扣,下图就是一个例子。
解决这个问题的另一种思路是优化Transformer结构,这一条思路的工作有很多,例如Longformer [1]、BlockBert、最近的BigBird等…… 但是这些工作通常只是将文本长度从512扩展几倍(基于现有的硬件条件),让BERT一次“看到”更多的文本;然而,人类并不需要如此强的瞬时阅读能力——实际上人类同时在工作记忆里存储的元素通常只有5-7个——也能阅读并理解长文本,那么人类是如何做到的呢?
认知中的工作记忆和调度“工作记忆的核心是一个中央处理机制,它协调来自于多种来源的信息”, 并且“它发挥一个有限容量的注意力系统的作用,这个系统能选择和操作控制过程和策略”, 这是工作记忆的提出者Baddeley [2] 在他1992年《Science》著作中的论断。事实上,人脑正是通过回忆和注意力,协调长期记忆和短期记忆(工作记忆)的使用策略来完成对长文本的理解。
CogLTX的工作流程受到人的认知过程启发,我们用同样的方法来处理长文本。如果将BERT的512输入字符限制比作人的工作记忆,那么既然人思考问题时能够找到关键的少量信息,并在工作记忆中推理出结果,BERT的512也应该远远足够,关键是对于特定的问题,我们要最终用的真正关键的那部分信息。
MemRecall关键信息抽取对于关键信息的认识本身也是智能的重要部分,这并非易事。最直观的想法是通过信息检索的办法(例如BM25)来抽取关键句,但是仔细一想就会发现这其实是不可行的,因为下游任务的不确定性,无法建模成信息检索的形式。例如,文本分类任务如果用BM25去检索,则无法定义查询(query)是什么。因此抽取的模型也要与任务息息相关。
在模型训练时,我们考虑两种情况:第一种是阅读理解问答这样的任务,由于信息句可以从答案所在句推断出来,因此是监督学习。此时评分机和推理机的训练(finetuning)都比较简单,只需将真正的关键句和一些负样本信息句组合,然后像正常BERT那样训练即可;第二种是文本分类这种,数据集中往往不会提供关键句的标注,这就需要我们自己推断。
关键句的一个特性是,如果缺少关键句将不能推断到正确答案,因此我们先用词向量等方法初始化关键句标签后,再训练中调整关键句标签,如果某个句子剔出后损失函数骤然增加那么就必然是关键句,如果可有可无则不是,根据这个方法在调整关键句标签后可重新进行下一轮训练,具体算法如下:
文章在NewsQA、HotpotQA问答数据集,20NewsGroup文本分类和Alibaba淘外文本多标签分类等几个任务上进行试验,结果均超过或类似于目前最好的模型效果,具体数据在论文中列举。同时,CogLTX牺牲了部分推理的时间,换取了与文本长度无关的训练空间开销。
对于BERT处理长文本时遇到的困境,通常的做法都会考虑轻量化Transformer的思路,然而如果能从人类处理信息的方式得到启发,另辟蹊径从下游任务微调的流程上考虑,更直接地解决这个问题。