Quiet-STaR¶
约 7651 个字 预计阅读时间 38 分钟
Info
原论文出处:Quiet-STaR
引入背景¶
当人们写作和交谈时,常常会停顿思考。尽管以推理为中心的作品通常将推理框架视为回答问题或完成代理任务的方法,但推理在几乎所有书面文本中都是隐含的。例如,这适用于证明中未在行间明示的步骤,或对话中潜在的心智理论。文本中有很多含义是隐藏在行间的,理解这些隐含信息对深入理解文本至关重要。研究表明,语言模型(LM)在处理各种任务时(如常识推理、定理证明和编程)也表现出这种现象。
对于目前允许LM从其推理中学习的方法(例如,STaR , Zelikman等,2022)则专注于解决单个任务或预定义的任务集合(例如,Wei等,2021b)。虽然关于文本的推理以预测后续文本的一致性被证明能够提高LM在多种任务上的表现,但这些研究都是专注依赖于精心策划的数据集,以提供特定的推理任务或在某些情况下提供推理本身。
在STAR中,有用的思维通过从少量示例中的问答(推断理由)的过程和导致正确答案的示例中学习。这是一个高度受限的设置。而在理想情况下,语言模型可以在任意文本中学习推断未表述的理由。
所以本文提出了Quiet-STaR,这是STaR的一种推广,在这种方法中,语言模型学习在每个token生成时生成推理步骤(rationales),以解释未来文本,从而改善其预测。
Quiet-STaR
Quiet-STaR,它可以被理解为“安静地”应用STaR,即训练模型在说话之前思考。总体而言,Quiet-STaR通过在生成每个token后生成推理(rationals)来解释未来文本(思考),将带有和不带有推理的未来文本预测结合(交谈),然后利用REINFORCE学习生成更好的推理(学习)。
解决的关键挑战
- 生成后续文本的计算成本
- 语言模型最初不知道如何生成或使用内部思维
- 需要预测的不仅仅是单个下一个词汇(即要把后续所有可能生成的token都要考虑在内)
作出的主要贡献
- 将STaR推广为从多样化的非结构化文本数据中学习推理。这是第一部明确训练LM从文本中进行一般推理的工作,而不是在策划的推理任务或推理任务集合上进行训练。
- 提出并实现了一种并行采样算法(parallel sampling algorithm),使我们的训练过程可扩展,从给定字符串的所有token位置生成推理。
- 在每个思考的开始和结束引入自定义meta-token,以便让LM学习何时应该生成推理以及何时应该根据该推理做出预测。
- 使用一个混合头来回溯性地确定通过在当前token推理产生的对下一个token预测的概率应当多大程度的被结合到基于基础模型预测下一个token的概率。(我的理解就是做了对推理的和原来的logits通过MLP做了一个加权和)
- 提出并证明了一种非短视的损失函数(即在语言建模中包含多个未来词汇的预测)能够提升思维效果。
- 在多个任务上,展示了思考使得LM在预测困难token时表现得比仅在相同网页文本上训练的模型更好,且思考的越多效果越好。
相关工作¶
语言模型中的推理
这部分主要讨论了如何通过推理训练语言模型以解决复杂任务,重点介绍了一些关键研究成果。以下是要点总结:
-
推理训练的优势:通过先训练语言模型进行推理,再回答问题,可以提升模型的性能。Rajani等(2019)的研究表明,微调后的模型在输出人类推理过程后,能够在常识推理问题上优于直接训练回答的模型。
-
辅助框架和推理链:Shwartz等(2020)展示了在提供某些框架时,语言模型能够生成有用的推理链,而不需要额外监督。Nye等(2021)进一步发现,随着模型能力增强,所需的框架(scaffolding)减少。Kojima等(2022)则证明了该行为可以在零样本(zero-shot)情况下实现。
-
强制推理与信心机制:Wang和Zhou(2024)展示了,在常识问答中,可以通过防止模型生成不确定的答案,强制模型利用推理链。然而,这种方法主要适用于问答数据集,且依赖启发式方法来判断模型是否输出了答案。
-
基于目标文本的质量估计:类似于TRICE方法(Phan等,2023),作者通过目标文本对推理过程的对数似然改进来评估推理质量,但不引入复杂的控制变量。
训练语言模型进行推理
这段主要讨论了训练语言模型进行推理的两种主要方法及其优缺点。以下是要点总结:
-
基于已挖掘的推理轨迹:研究者们通过训练语言模型使用挖掘出的推理轨迹或类似数据来提高推理能力。这种方法虽然有效,但存在一些缺点,如需要手动注释,依赖注释者的能力,且训练数据的分布与语言模型自然生成的文本不同。此外,这种方法成本高,难以扩展,无法解决比注释者能解决的问题更难的任务。
-
基于模型自生成的推理:另一种方法依赖于语言模型自己生成的推理,类似于自我对弈的概念。例如,自学推理器(Self-Taught Reasoner, STaR)通过迭代训练模型使用其生成的推理来解决越来越复杂的问题。后续研究还通过引入更多假设或信息(如多数票算法)来改进这些方法,尽管性能有所下降。其他研究进一步扩展了自学推理器的结果,如基于“过程监督”过滤不正确的推理轨迹,以及通过训练验证器来引导生成过程来提高性能。
-
数学推理中的约束学习:相关工作还探索了在数学推理的受限环境中学习中间推理的过程,确保推理中的陈述是有效的数学陈述。
元令牌
这部分讨论了自定义元令牌(meta-tokens)在神经网络中的应用和效果,特别是针对语言模型的优化。以下是核心要点的总结:
- 自定义令牌的功能:近期研究表明,自定义令牌可以在神经网络中执行特定功能,因此被称为“功能向量”(function vectors)。其中,prompt-tuning(Lester等,2021)和prefix-tuning(Li & Liang,2021)通过优化提示中的嵌入向量来更好地完成任务。
- 元令牌的应用:一些研究将元令牌用于压缩长提示,以提高效率。例如,Mu等(2024)优化了一个令牌,使得后续令牌无法访问之前的令牌(即上下文压缩令牌),但依然能为未来的令牌提供足够的信息。
- 影响注意力和控制复杂行为的令牌:尽管该研究不专注于压缩,但也面临类似的问题,即学习一个令牌来影响注意力并控制复杂的下游行为。相关研究(Goyal等,2023)表明,通过学习一个“暂停”令牌(将每个令牌视为两个令牌),可以提升语言模型性能。然而,这个暂停令牌并不是初始化推理过程的,而是表示整个思维过程的完成。
- 推理在语言中的重要性:研究发现,推理在语言模型中的作用更为显著,能够有效提升模型的表现
引入推理步骤(rationales)的动机¶
在这篇文章中是通过引入中间推理步骤(rationales)来帮助模型更好地预测未来文本,尤其提高模型在推理任务中的表现。
主要做法是在每对观察到的序列token之间插入一个辅助的推理步骤变量(rationale variable)。这个推理变量代表了中间的思维过程或推理链。然后,语言模型的目标是通过优化参数\(\theta\)使得模型能够生成这些中间推理过程(rationales)。
以上是优化参数\(\theta\)的目标函数,意思是模型不仅要根据之前的输入token \(x_{0:i}\) 来预测后续token \(x_{i:n}\),还要生成一个中间的推理过程\(rationale_\theta\),用以提升对后续序列的预测能力。
事实上引入rationale并不会在比如生成任务上对于已经训练好可以正确建立语言分布的最优语言模型更有优势。但在推理任务中,模型却会从这些中间的推理链中受益。
引入动机主要在于,作者认为推理的作用在于帮助模型将复杂的计算分解为更小的步骤。通过这种方式,模型学会了哪些分解和规划步骤对预测未来文本有效。而事实上,之前的研究(如Nye等人,2021;Zelikman等人,2022)表明推理任务中的推理链确实可以提升模型的表现。
本文引入非短视损失(non-myopic)事实上也是出于这么一个考量,每一步的推理去预测更长的目标会更有效地引导模型学习如何分解复杂问题。
Quiet-STaR 模型详解¶
总体介绍¶
Quiet-STaR模型有三个主要步骤:
并行推理生成(think):对于输入序列中的𝑛个token \(x_i\),并行生成长度为\(t\)的\(r\)个推理\(c_i = (c_{i1}, \dots, c_{it})\),最终生成\(n×r\)个推理候选。插入学习到的<|startofthought|> 和<|endofthought|> token,标记每个推理的开始和结束。
混合后推理与基础预测(talk):通过每个推理生成后的隐藏状态,训练一个“混合头”(shallow MLP),产生权重,决定在后推理预测的下一个词的logits与基础语言模型预测的logits之间如何融合。这一方法在微调早期通过引入推理来缓解分布偏移问题。
优化推理生成(learn):通过优化推理生成的参数(例如起始/结束token及语言模型权重),来增加推理使得未来文本更可能出现的概率。使用REINFORCE算法,根据推理对未来词预测的影响,提供学习信号。为了减少方差,采用教师强制技巧(teacher-forcing trick),将损失中不仅包括思维后的词预测,还包括后续词预测的可能性。
并行推理生成(think)¶
因为本文基本思想是在每个token后都去生成rationales,所以如果按顺序对每个token进行生成不可避免地会产生多次前向传播,这对于长序列来说在时间上的花费是让人难以接受的。
所以很自然的联想到对于每个token并行同步生成rationales将会极大减少计算所需的时间。
但事实上,对于每个token生成的rationales中的后续token来说,它并不能直接与原始序列相连,每个词应该还是按照token自回归的顺序生成适合其前缀token的上下文。简单举个文中提到的例子,对于序列 “
所以为了高效的实现并行推理,且可以很好的将每一次生成的"反事实"延续集合的元素分离开,本文的方法就是先会让模型缓存每次的前向传递结果,并将对角注意力掩码(diagonal attention mask)连接到之前的注意力掩码中。
如上图所示,此图很清楚的画出本文实现并行推理生成的原理。先解释下这个图(我个人认为论文中这个图直观理解可能会陷入误区),左图是呈现了Quiet-STaR生成token的可视化过程,这个很好理解,主要就是体现了并行化生成rationales的过程,不多加说明。看右图,这里需要结合图的描述与左图去进行理解,其中灰色部分其实是每个token可以允许注意到的部分,白色区域才是掩码部分。
具体理解是要纵向看这张图,假设输入的Base Text是abcd,然后在并行生成Thought Token1时,a允许注意到a本身,b则是a,b,以此类推,去生成a',b',c',d'。在生成Thought Token2时a'允许注意到a和a',去生成a'',b'允许注意到a,b,b'去生成b''。依次类推再去生成c''和d''。注意到,生成Thought Token2的过程其实就是在继承上次的对角掩码的同时,构建一个对角注意力掩码(diagonal attention mask),只让比如c'注意到a,b,c的同时还只能关注到"反事实"延续集合中的它本身。这样通过这种巧妙的掩码机制就完成了并行生成和分离"反事实"延续集合元素这么两大要求。对于后续的Thought tokenk,其实就不断连接对角注意力掩码即可,分析同Thought token2。
Faster Parallel Sampling
这一部分是文中在附录补充,可以用以参考的加速并行推理算法。总体可以分为三个关键点:
1.通过把注意力机制的成对计算(pairwise dot-products),简化为逐元素(elementwise)计算。只需要对每个元素计算它自己在对角线上的注意力,而不是与整个序列的每个元素进行比较。对比到图中就是将第一步的三角掩码改为对角掩码。
2.可以通过自己设定一个index操作去返回一个掩码矩阵,实现选择性地生成某些令牌的延续,从而避免不必要的计算。
3.既然实现了选择性地对某些token进行生成推理链,那么就可能存在生成推理链时相对位置编码进行改进的情况。
其实总体而言,文中实现并行推理还是基于之前描述,这里提出的是文中略带一笔的想法,并未详细地对怎么设计index,什么情况下需要对相对位置编码修改这样的问题提出实现方案。
混合后推理与基础预测(talk)¶
在使用预训练模型时,由于预训练模型并没有被设计来处理“思考”(thoughts)这种额外的信息,刚开始生成的推理会与模型的原始数据分布不一致(out of distribution)。这会对语言建模的性能产生负面影响。简单来说,预训练模型并不习惯处理“思维过程”的信息,因此需要一个平滑的过渡来适应这种新信息。
为了平滑引入推理的过渡,作者提出了一种插值方法,在语言模型的预测中引入“推理”与“非推理”的预测混合。也就是说,模型需要根据推理信息来调整其预测的权重,平衡“有推理”和“没有推理”的结果。
插值的实现方式:
-
每个token生成后,推理过程结束时会有一个特殊的标记(end-of-thought token)。对于每个词,模型会考虑这个“推理结束”标记的隐藏状态(hidden state),以及该词本身的隐藏状态。
-
混合头(mixing head)的作用是基于这两个隐藏状态,输出一个权重(weight),该权重决定后续的语言模型预测(带推理的)结果与原始语言模型预测(不带推理的)结果之间的比例。换句话说,权重越大,模型越倾向于使用带有推理的预测结果。
混合头使用一个浅层的多层感知机(MLP)来生成每个token的权重。这个 MLP 的输出是一个标量(scalar),表示每个token的混合权重。
执行细节
混合头的输入是一个向量,这个向量的大小是语言模型隐藏状态(hidden state)大小的两倍。这是因为输入向量是通过将两种预测(带有推理的预测和不带推理的预测)拼接(concatenate)在一起得到的。
混合头使用了一个三层的 MLP,作为计算权重的核心部分。MLP 的每一层使用了 ReLU 激活函数(Rectified Linear Unit),最终输出一个标量用于对两类logits的加权。
优化推理生成(learn)¶
优化推理开始和结束token¶
优化这些token的表示非常重要,尤其是< startofthought >token,因为它会影响模型如何生成推理链。然而,这个优化过程具有挑战性,原因是这些推理token是离散的,难以通过简单的方式嵌入到连续的模型训练中。
为了解决这个问题,作者选择将推理开始和结束令牌的嵌入初始化为类似于文本中的破折号 ——,这种符号在文本数据中通常表示一个停顿或思考的开始,借此利用语言模型已经具备的先验知识。
为了更快地优化这些令牌的嵌入,作者给这些令牌的嵌入增加了一个超参数权重(hyperparameter weight)。简单来说,模型在每次训练中调整这些令牌的嵌入时,会给它们的梯度施加一个额外的权重,从而加速它们的学习。
例子理解
假设我们在训练模型生成文本,而
非短视评分和教师强制¶
本文作者并不希望思维过程(thoughts)在每个词的预测中都发挥作用,也就是说,希望模型的奖励机制应该更多地依赖后续的语义内容,而不是紧紧依赖于思维之后生成的下一个具体token。
然而这样的话,这里就会存在一个问题,就是基于之前对模型思维并行生成的描述来看,其实在推理过程中只有与当前预测(下一个token)直接相关的推理链才会收到梯度更新。(下面是我的理解,这里的收到梯度更新应该是只有比如输入a,b,c序列,计算损失时只有与可以预测出b的a的相关推理链可以用于进行梯度更新)这样的话就会有一点问题,对于其它生成的推理链,可能存在只是对当前预测下一个token作用不大,但对于后续一系列推理发挥有效作用的推理链就会被废弃掉。
然后有一个解决方案就是我让模型不要只是去关注比如可以预测出b的a的相关推理链,而是多关注几个后续的token,比如b,c,d,让只要后续的token有预测正确的推理链都加入梯度更新。这样的话可以更充分的考虑多个推理链。但是这种方式也会带来新的问题:
- 更高的熵:模型会更随机地生成文本,导致语言模型的熵增加,生成的文本质量下降。
- 忽视前面的词:如果过分关注未来的词,模型可能会忽略之前已经生成的上下文信息,从而导致文本连贯性降低。
最终本文选择的解决方案是在利用并行注意力掩码机制的同时,引入教师强制的方法。具体看下图:
像图中一样,比如对于f的推理链更新,会先设置在更新参数的损失函数中包含几个提前预测的未来的token的超参。图中设定的超参是3。教师强制的作用就体现在虚线部分,是强制给模型插入正确的下一个token或者开始思考的提示符。这样的话就保证了每一次预测下一个token时,模型前面传入的token都是正确的。不会出现语言模型熵太高导致文本质量下降的问题。
然后在每个预测token的阶段,比如说f预测\(\hat{g}\),这里就是用混合头机制,将对于模型生成的分布与插入推理后生成的分布进行加权。值得一提的是,因为这里是为了优化f的推理链,所以比如预测未来第二个token h的时候基于推理链预测的\(\hat{h}_T\),是基于f推理链在插入正确答案g的情况下生成的。后续就是利用\(\hat{g}、\hat{h}、\hat{i}\)的概率分布去计算损失函数以及更新参数。
目标函数¶
先说结论,我认为论文中的伪代码和这一部分的描述都存在问题,应该是写错了。但思想本身在前面已经体现,后续我会按照我的理解对其公式改进的进行介绍。先说一下为什么我认为原文伪代码部分和这一部分的写作有错误的地方。
原文和伪代码
从原文的描述结合之前的理解我们可以知道,\(r_j\)应该是代表了第j个推理链(rationale)的奖励得分。但从伪代码我们可以看见j迭代的是序列长度,也就是说在伪代码中它代表的不是某一token的第j条推理链,而是第j个token。这样的话伪代码和原文其实就矛盾了,因为这样的话\(p^{talk}_{j:j+n_{true}}\)就无法代表像文中说的a particular rationale。有很明显的指代不明。
所以为了延续文中的思想,很明显的我们需要引入一个新的迭代变量k,它用来表示第j个token的第k条rationale \(T_{j,k}\)。新的公式如下:
很明显此时\(\log \bar{p}^{\text{talk}}_{j:j+n_{\text{true}}} \left( X_{j+1:j+n_{\text{true}}+1} \right)\)就是在对k个rationale预测的第\(j+n_{true}\)位置token的概率求均值。
文中还提到为了进一步在基础LM上优化对下一步的预测结果,又引入了一个损失函数\(\mathcal{L}_{j,k}^{NLL}\),根据伪代码它的计算公式如下:
文中最终使用的是去除掉r小于0的奖励更迭,因为发现训练会更稳定,尽管会引入些许偏差。
所以最终对于伪代码的修改只要增加一个k的循环,次数是\(n_{thoughts}\)次。
需要优化的参数: 模型参数,以及开始和结束token的embeddings。
实验¶
这一部分就是简单记录,详情看文章。
实验的研究目标:
-
探讨Quiet-STaR是否对那些确实需要推理的 token 的预测有帮助。
-
评估两个方面:
- Quiet-STaR 是否提高了模型在需要推理的数据集上直接预测答案的能力;
- 额外思考对各 token 的影响分布。
-
实验数据与结果:
- 主要在 OpenWebMath 数据集上进行实验,该数据集重点抓取了技术性网页,作者认为这类网页中有更多需要推理的 token。
- 还在 C4 数据集上评估了 Quiet-STaR,该数据集文本更加多样化,实验结果表明效果有显著但较小的提升。
-
-
评估Quiet-STaR 算法在不同下游任务中的表现
- CommonsenseQA 和 GSM8K 的实验结果:
- CommonsenseQA:与基础模型相比,Quiet-STaR 将性能提升了 10.9%。随着模型中推理链长度的增加,性能也随之提升,表明更多的推理步骤有助于提高直接问答的效果。
- GSM8K:Quiet-STaR 将性能提升了 5.0%,同样地,随着推理链长度增加,性能得到进一步提升。
- 在 C4 数据集上的实验结果:
- 在 C4 数据集上进行的 Quiet-STaR 训练,GSM8K 的性能从 5.9% 提升到 8.1%,CommonsenseQA 的性能从 36.3% 提升到 42.6%,但提升幅度较小。
- 这些实验基于训练 Mistral 7B 模型,使用 16 个推理链 token 和 4 个真实 token 进行推理。
- 与 Pause Tokens 的对比:
- Pause Tokens(暂停 token)是 Goyal et al. (2023) 提出的一种方法,与 Quiet-STaR 的区别在于每个 token 由两个 token 表示,其中一个作为“暂停”进行推理。
- 结果显示,在 CommonsenseQA 上,Pause Tokens 的微调提升了从 26.9% 到 28.8%,但在 GSM8K 上表现有所下降。
- 总体上,Pause Tokens 的额外推理对性能并不总是有益,相反,Quiet-STaR 通过多 token 推理链获得了更好的推理效果。
- 与 Pause Tokens 的微调不同,Quiet-STaR 的实验并未在下游任务上进行微调。
- 总体下游表现
- 训练语言模型预测一般文本中隐藏的“次文本”,可以显著提升模型的推理能力,甚至在那些没有明确训练的数据集上也能提升表现。
- 推理链越长,性能提升越明显,而且 Quiet-STaR 在推理效果上优于受限的暂停 token 方法,这表明 Quiet-STaR 成功地教会了模型利用自身生成的推理链更深入地推理输入内容。
- CommonsenseQA 和 GSM8K 的实验结果:
-
Quiet-STaR 与 Chain-of-Thought (CoT)
- 方法之间的关系:
- Chain-of-Thought 提示是一种让模型“自言自语”进行推理的方式,用户会主动提示模型生成逐步的推理,使用普通的生成分布来实现 (Kojima et al., 2022)。
- Quiet-STaR 则是让模型在每个 token 上“安静地思考”,并通过训练使这些推理分布变得有用。
- 结合使用 Quiet-STaR 和 CoT:
- 在实验中,作者探索了使用 Quiet-STaR 的内部推理链,同时生成显式的 Chain-of-Thought 推理。
- 由于作者的目标是通用推理,因此使用了零样本提示(“Let’s think step by step.”)而不包含上下文示例。
- 结果表明,内部推理链有助于生成更结构化和连贯的 Chain-of-Thought 推理。实验显示,GSM8K 测试集上的 8 个样本的多数投票准确率 (cot-maj@8) 从 40.6% 提升到 47.7%。
- 思考链的贡献:
- Quiet-STaR 中生成的思考链有助于引导模型进行更深入和更相关的推理,特别是当需要理解复杂的上下文或解释时。例如,化学反应和数学命题的推理链都帮助模型更好地理解了推理过程,并生成了相关内容。
- 有时候思考链会生成一些推理,但并不一定直接用于预测。即便如此,这些思考过程依然对模型理解输入文本的语义有所帮助。
- 方法之间的关系:
有待改进¶
- 当前模型是在就是已经训练好的LM基础上构建的推理学习框架,但我们还不能很确定这些技术是否在从头开始训练模型时同样有效。
- 作者目前只将 Quiet-STaR 应用于一个拥有 70 亿参数的模型,尽管它是一个强大的模型。预计如果应用于更好的模型,可能会带来更显著的改进(这在推理任务中经常被观察到)。但却未在多个其他模型上试验过。
- Quiet-STaR 需要生成许多 token 才能生成每个额外的 token,这导致了相当大的计算开销。但这种计算开销也可以看作是一种优势:通常语言模型只能基于当前上下文生成下一个 token,虽然有方法提高采样质量,但没有通用的方法能够利用额外的计算资源来提升下一个 token 的预测能力。而 Quiet-STaR 提供了这种可能性。
- 在当前实现中,模型还不支持动态地预测何时生成或结束推理链。
- 作者提到一种改进方法:如果混合头部(mixing head)来自于基础语言模型的预测,而不是思考之后的结果,那么可以在生成推理链之前应用一个阈值,以防止生成那些不会被纳入的推理链。这种方法能够提高模型的计算效率。
- 但作者也认为,这样的改进会更加困难,因为在生成推理链之后评估其有用性要比在生成之前进行预测更容易。