击败 LLM 推理中的非确定性
可复现性是科学进步的基石。然而,要让大语言模型(LLM)给出可复现的结果却出奇地困难。
例如,你可能会发现多次向 ChatGPT 提同一个问题会得到不同的答案。这本身并不奇怪,因为从语言模型获取结果涉及“采样”:把模型输出转成概率分布,并按概率选取一个 token 的过程。
更令人吃惊的是,即便我们把温度调到 0(这意味着 LLM 总是选择概率最高的 token,即所谓的贪心采样,从而在理论上使采样成为确定性的),LLM 的 API 在实践中仍然不是确定性的(过往讨论见这里、这里或这里)。即使在你自己的硬件上、使用 vLLM 或 SGLang 等开源推理库运行推理,采样依然不是确定的(见这里或这里)。
但为什么 LLM 的推理引擎不是确定性的呢?一个常见假说是:浮点数的“非结合性”和并发执行的组合,会导致根据“哪个并发核心先完成”而出现非确定性。我们称之为 LLM 推理非确定性的“并发 + 浮点”假说。例如,一篇近期的 arXiv 预印本写道:
GPU 中的浮点运算表现出非结合性,意味着由于有限精度和舍入误差,。这一性质直接影响 Transformer 架构中注意力分数和 logits 的计算——当多个线程并行执行时,运算顺序不同会产生不同的结果。
你还会在其他地方看到“并发 + 浮点”假说被反复提及,比如这里(“存在速度权衡。为了让端点更快,会使用 GPU,它们进行[非确定性的]并行计算。任何现代 GPU 上的神经网络计算都可能受到这一点影响。”),或这里(“由于 GPU 高度并行化,加法或乘法的顺序在每次执行中可能不同,这会层层传递并造成输出上的微小差异。”)。
虽然这个假说并非完全错误,但它并不能展示全貌。比如,即便在 GPU 上,对同一数据重复运行同一个矩阵乘法,总是会得到按位相等的结果。我们确实在用浮点数,我们的 GPU 也确实高度并行。那为什么在这项测试中看不到非确定性呢?
A = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
B = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
ref = torch.mm(A, B)
for _ in range(1000):
    assert (torch.mm(A, B) - ref).abs().max().item() == 0
要理解 LLM 推理非确定性的真正原因,我们必须往更深处看。
不幸的是,连“LLM 推理是确定的”这句话的定义都不好下。也许让人困惑,但下面这些陈述可以同时为真:
- 有些 GPU 上的 kernel(内核)是非确定的。
- 然而,语言模型前向传播中用到的所有 kernel 都是确定的。
- 此外,一个 LLM 推理服务器(如 vLLM)的前向传播也可以被称作是确定的。
- 然而,从任何使用该推理服务器的人的视角看,结果又是非确定的。
在这篇文章中,我们将解释为何“并发 + 浮点”假说偏离了靶心,揭示 LLM 推理非确定性的真正罪魁祸首,并说明如何击败非确定性,获得真正可复现的 LLM 推理结果。
原罪:浮点数的非结合性
在谈非确定性之前,先解释为何会有数值差异。毕竟,我们通常把机器学习模型看作是遵循交换律、结合律等结构性规则的数学函数。难道不应该存在一个“数学上正确”的结果,让我们的机器学习库去给出吗?
罪魁祸首是浮点数的非结合性。也就是说,对浮点数:
(0.1 + 1e20) - 1e20
>>> 0
0.1 + (1e20 - 1e20)
>>> 0.1
具有讽刺意味的是,正是“破坏结合性”让浮点数变得有用。
浮点数之所以有用,是因为它们允许“动态”的精度级别。为了便于说明,我们使用十进制(而不是二进制),把浮点数写成 的形式,并假设尾数 3 位、指数 1 位。
例如,数值 3450 可以被精确表示为 。更小的值如 0.486 可以表示为 。通过这种方式,浮点数既能表示很小也能表示很大的值。在科学里,我们会说浮点数帮我们保持固定数量的“有效数字”。
如果把两个指数相同的浮点数相加,它看起来和整数加法很像。例如,123()+ 456()得到 579()。
但如果把指数不同的两个浮点数相加,比如 1230 和 23.4,会发生什么?精确结果是 1253.4。然而我们一次只能保留 3 位精度。因此浮点加法会“丢弃”最后两位,得到 (即 1250)。
1.23 × 10²
3.45 × 10¹
=
1.575 × 10²
精确值:1575
表示 1230 需要 3 位有效数字,表示 23.4 也需要 3 位。但是把这两个数加起来会得到一个需要 5 位有效数字的数(1253.4)。我们的浮点格式只能把末尾的 34 丢弃。从某种意义上说,我们等价于把原本的 23.4 先四舍五入成了 20.0 再去相加。
此时,信息已经丢失。注意,只要把不同“尺度”(指数不同)的两个浮点数相加,这种情况就可能发生。而加法中出现不同指数几乎是家常便饭。事实上,如果我们能保证永远不需要不同的指数,我们大可只用整数!
换句话说,只要改变浮点数相加的顺序,我们就可能得到完全不同的结果。一个极端例子是,下面这个数组根据求和顺序的不同,可能出现 102 种不同的结果。
import random
vals = [1e-10, 1e-5, 1e-2, 1]
vals = vals + [-v for v in vals]
results = []
random.seed(42)
for _ in range(10000):
    random.shuffle(vals)
    results.append(sum(vals))
results = sorted(set(results))
print(f"There are {len(results)} unique results: {results}")
# Output:
# There are 102 unique results: [-8.326672684688674e-17, -7.45931094670027e-17, ..., 8.326672684688674e-17]
虽然这是产生“结果不完全相同”的根本原因,但它并不能直接回答非确定性从何而来。它并没有帮我们理解:为何会以不同顺序相加、何时会发生、以及如何避免。
答案藏在 kernel 的实现方式里。
为什么 kernel 不总是按相同顺序相加?
如上所述,一个常见解释是“并发 + 浮点”假说。该假说认为:如果并发线程完成的顺序是非确定的,而累加的顺序又取决于线程完成的顺序(比如用原子加法 atomic add),那么累加顺序也会是非确定的。
让人困惑的是,尽管这确实会导致某些 kernel 非确定,但“并发(和原子加法)”最终与 LLM 推理的非确定性无关!为了说明真正的罪魁祸首,我们先解释为什么现代 GPU 内核很少需要原子加法。
什么时候需要原子加法?
通常,GPU 会在很多“核心”(SM)上并发地启动一个程序。由于核心之间没有固有的同步,如果核心之间需要通信就会遇到挑战。比如,如果所有核心都必须累加到同一个元素上,你可以使用“原子加法”(有时也称为“fetch-and-add”)。原子加法是“非确定的”——累加的顺序纯粹取决于哪个核心先完成。
具体而言,想象你用 100 个核心去约简一个 100 维向量(例如 torch.sum())。虽然可以并行加载这 100 个元素,但我们最终必须把它们约简到一个元素。一种方式是使用某种“原子加”原语,由硬件保证所有加法都会被处理,但不保证顺序。
原子加法可以保证每个核心的贡献都会反映在最终的和中。然而,它对这些贡献被相加的“顺序”不做任何保证。顺序完全取决于哪个核心先完成,这是一个非确定的属性。因此,多次执行同一个并行程序可能得到非确定的输出。
这通常就是人们所说的“非确定性”——把同一个 kernel 用完全相同的输入运行两次,却得到不同的输出。这被称为“运行到运行(run-to-run)非确定性”,即你用完全相同的依赖环境运行同一段 Python 脚本两次,但得到不同结果。
尽管并发的原子加法确实会让一个 kernel 变得非确定,但对绝大多数 kernel 来说,原子加法并非必需。事实上,在一个典型的 LLM 前向传播里,通常一个原子加法都用不到。
这也许让人意外,因为并行化一个约简确实可能受益于原子加法。之所以最终不需要,主要有两个原因:
- “批次”维度上通常有足够的并行性,我们不需要在约简维度上并行化。比如,不是约简一个 100 维向量,而是并行约简 500 个 100 维向量。这样我们就可以让每个核心各自独立地把一个向量完全约简掉,每个核心处理一个不同的向量。
- 随着时间推移,大多数神经网络库已经采用了各种策略,在不牺牲性能的情况下实现确定性。例如,我们可以做“分裂(或树形)约简”,先把 100 元素的约简拆成 5 个 20 元素的约简(从而获得 5 路并行)。然后,为了合并剩下的 5 个元素,可以单独做一个“清理”约简(不并行,但元素很少所以很便宜),或者使用信号量(semaphore)(确保并发线程块以确定的顺序累加)。关于信号量策略的描述见这里。
由于以上两点,避免使用原子加法对绝大多数神经网络运算来说几乎没有性能损失。
仍有少数常见操作在避免原子加法时会有显著性能代价。比如 PyTorch 中的 scatter_add(a[b] += c)。不过在 LLM 中常用、且需要原子加法的基本只有 FlashAttention 的反向传播。有趣的是:你知道广泛使用的 Triton 版 FlashAttention 反向传播在算法上其实和 Tri Dao 的 FlashAttention-2论文不同吗?标准的 Triton 实现会在反向传播里做额外的重计算,从而避免原子加法,但代价是多了约 40% 的 FLOPs!
然而,LLM 的前向传播里并没有需要原子加法的操作。因此,LLM 的前向传播事实上是“运行到运行确定”的。
模型确定性用户请求其他用户请求输出 从推理服务器的视角看,它是确定的。给定完全相同的用户请求,它总会给出相同的确定性输出。
维基百科写道:“确定性的算法是指,给定某个输入时,总是产生相同输出的算法。”在这里,给定完全相同的输入(即推理服务器正在处理的确切请求),前向传播总会产生完全相同的输出。
不过,前向传播自身“确定”,并不代表包含它的整个系统就是确定的。比如,如果我们的请求输出依赖于并行的其他用户请求(如 batch-norm 会跨样本),那对每个独立请求而言,整体的 LLM 推理依然是非确定的!
事实证明,我们的请求输出确实“依赖”并行的其他请求。但这并不是在跨 batch 泄漏信息,而是因为我们的前向传播缺乏“批次不变性(batch invariance)”,导致某个请求的输出会依赖于该次前向传播的批大小。
批次不变性与“确定性”
为了说明批次不变性,我们先简化系统,只看矩阵乘法(matmul)。你可以假设所有 matmul 实现都是“运行到运行确定”的(这不完全准确,但大多数常见的实现确实如此)。然而,它们并非“对批次不变”。换言之,当批大小变化时,批内的每个元素都可能得到不同的结果。
从数学视角看,这是个相当罕见的性质。矩阵乘法在批维度上应当是“独立”的——批里的其他元素,或批有多大,都不应该影响到批内某个特定元素的计算结果。
然而,经验上我们可以观察到,并非如此。
import torch
torch.set_default_device('cuda') 
B = 2048
D = 4096
a = torch.linspace(-1000, 1000, B*D).reshape(B, D)
b = torch.linspace(-1000, 1000, D*D).reshape(D, D)
# 取批中的第一个元素,做一次矩阵-向量乘
out1 = torch.mm(a[:1], b)
# 做一次矩阵-矩阵乘后,再取批中的第一个元素
out2 = torch.mm(a, b)[:1]
print((out1 - out2).abs().max()) # tensor(1669.2500, device='cuda:0')
注意,这里的确是“运行到运行确定”的:如果你多次运行脚本,它会确定性地返回相同的结果。它并非“硬件/软件版本不变”——不同的 GPU/PyTorch 版本可能返回不同的值,但对给定环境会确定地返回同一个值。
然而,当一个“对批次不不变”的 kernel 被用作更大推理系统的一部分时,系统可能整体变得非确定。当你向推理端点发请求时,服务器此刻的负载从用户视角看是“非确定的”。负载决定了 kernel 运行时的批大小,从而改变了每个独立请求的最终结果!
模型确定性非确定性用户请求其他用户请求输出 尽管从推理服务器本身可以声称“确定”,但对单个用户而言故事就不同了。对单个用户来说,并行的其他用户并不是系统的“输入”,而是系统的一个非确定属性。这使得从每个用户的视角看,LLM 推理是“非确定的”。
如果把 kernel 对某个属性(如批大小)不具不变性的事实,与该属性自身的非确定性(如服务器负载会变化)组合起来,你就得到一个非确定的系统。
换句话说,几乎所有 LLM 推理端点之所以非确定,首要原因是负载(也就是批大小)在非确定地变化! 这种非确定性并非 GPU 独有——部署在 CPU 或 TPU 上的 LLM 推理端点也会有同样来源的非确定性。
所以,如果我们希望在推理服务器中避免非确定性,就必须让 kernel 实现“对批次不变”。要理解如何做到这一点,先看看 kernel 为什么一开始就没有批次不变性。
如何让 kernel 具有批次不变性?
要让一个 Transformer 的实现具备批次不变性,我们必须让每个 kernel 都具备批次不变性。幸运的是,可以假设每个逐点(pointwise)操作都是批次不变的。尽管在 PyTorch 等库中这基本成立,但这并非天生如此。比如,有些 CPU 上的 kernel 会在数组的部分区域使用向量化指令、在其他区域使用非向量化指令,而这些指令的数值并不总是按位相同。因此,我们只需要担心 3 类涉及到约简的操作——RMSNorm、矩阵乘法和注意力(与并行相关的跨设备约简超出本文讨论范围,但原理相同。一个有用的事实是:在 Blackwell 以及装有 CUDA 12.8+ 的 Hopper 上,NVLink-Sharp 的交换机内约简是确定性的。像很多事情一样,这类信息可以在 NCCL 的github issues里找到)。
恰好这三者在“难度”上也大致从低到高排列。每一个都需要额外考虑,才能在保持合理性能的同时实现批次不变性。我们先从 RMSNorm 说起。
批次不变的 RMSNorm
数据并行的 RMSNorm 理想情况下,我们希望在并行化策略里尽量避免核心之间的通信。一种方式是把每个批元素分配给一个核心,这样能保证每个约简都完全发生在某一个核心里。这就是“数据并行”策略,因为我们只是沿着不需要通信的维度并行化。在这个示意中,我们有四行数据与四个核心,核心被充分占满。RMSNorm 可以实现为:
# x: [batch_size, hidden_dim]
# weight: [hidden_dim]
def rms_norm(x, weight):
    return x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * weight
批次不变性的要求是:无论 kernel 的批大小如何,每个元素的约简顺序都必须固定。 注意,这并不意味着我们必须始终采用同一种约简策略。比如,如果改变了约简的元素数量,即便约简策略改变,也仍然可以保持批次不变性。The Quack 的博文有一些很好的例子,展示了各种约简策略的层次(如线程级约简、warp 级、block 级、cluster 级等)。
因此,只有当批大小会影响我们所用的约简策略时,才会破坏批次不变性。
我们来看 RMSNorm 的标准并行化策略。一般而言,并行算法受益于尽可能减少核心之间的通信。本文里,你可以把“核心”理解为 SM。更具体地说,这里重要的性质是:我们内核启动的线程块数要大于 SM 的数量。于是,一个起步策略就是像上图所示,把每个批元素分配给一个核心。
增大批大小不会影响约简策略;如果批大小为 200 就足以提供内核所需的并行性,那么批大小为 2000 肯定也足够。
面向更大批次的数据并行 RMSNorm 把数据并行策略扩展到更大的批次相当直接——不是让每个核心只处理一行,而是让每个核心按顺序处理不同的行。这能保持批次不变性,因为每个批元素的约简策略完全一致。
另一方面,减小批大小会带来挑战。由于我们把每个批元素分配给一个核心,批大小继续降低最终会导致核心数多于批元素数,从而有些核心空闲。
遇到这种情况,一个优秀的内核工程师会使用前一节提到的解决方案之一(原子加法或分裂约简),以维持良好的并行性与性能。不幸的是,这会改变约简策略,使该内核不再具有批次不变性。
分裂约简的 RMSNorm 如果批大小很小,数据并行策略可能不足以填满核心。在这种情况下,把一次约简“分裂”到多个核心上可能更高效,从而充分利用 GPU。但这会丧失批次不变性,因为每个元素的约简不再按同一顺序进行。
最简单的解决办法就是忽略这些情况。这并非完全离谱——小批大小意味着内核执行得本来就很快,因此一点点变慢未必是灾难性的。
如果我们确实不得不优化这种用例,一种做法是始终使用一种在极小批大小下也有足够并行性的约简策略。这样的策略在大批大小下会产生“过多的并行性”,但能在全区间内提供尚可(虽非峰值)的性能,同时保持批次不变性。
批次不变的矩阵乘法
数据并行的 Matmul 与 RMSNorm 类似,matmul 的标准并行化策略同样是“数据并行”,把整个约简保留在一个核心里。最直观的思路是把输出张量切分为 2D 小块(tile),每个小块分配给不同的核心。每个核心在自己的 tile 内计算对应的点积,整个约简依然发生在单个核心内。
与 RMSNorm 不同的是,由于算术强度与高效使用 tensor core 的额外约束,我们需要以 2D 小块的方式切分,而不是按单个输出元素切分,才能得到高效的 matmul 内核。
从本质上讲,矩阵乘法可以看作是“逐点运算 + 约简”。那么,只要把并行化方式改为按输出切分为多个小块,就能得到一个类似的数据并行策略,使每个约简都落在单个核心上。
同样地,我们的“批”维度(M 与 N)有可能太小,从而被迫沿着约简维(K)做切分。尽管有两个“批”维,matmul 为了高效利用 tensor core,每个核心还必须有足够多的“工作量”。例如,一个 [1024, K] × [K, 1024] 的 matmul,如果标准 2D tile 是 [128, 128],那么数据并行策略最多只能把它分到 64 个核心上,这不足以打满 GPU。
在 matmul 中沿约简维切分被称为 Split-K Matmul。就像 RMSNorm 一样,使用该策略会打破批次不变性。另一个有趣的 matmul 并行化策略是 stream-k。stream-k 更“缺乏不变性”:多数 matmul 库虽然不是批次不变,但至少是“批内位置不变”(也就是说改变元素在批内的位置不会影响数值)。而 stream-k 连批内位置不变也不满足!其核心思想是:针对不同的输出 tile,用不同方式沿 k 切分能获得更干净的负载均衡;但这样做也使 kernel 不再满足批内位置不变。
Split-K Matmul 如果批维很小,我们可能没有足够的并行性,就需要 Split-K。如下例,我们把每个约简拆分到两个核心上,分别累加后在末尾合并。但这样可以让我们依然利用 8 个核心。
matmul 还有一个额外复杂点——tensor core 指令。与仅需逐行操作的约简不同,高效的矩阵乘法内核需要以一个完整“tile”为单位操作。
每条 tensor core 指令(比如 wgmma.mma_async.sync.aligned.m64n128k16)内部的约简顺序可能不同。选择不同指令的一个原因是批大小很小。比如,如果我们用一条处理 256 长度 tile 的 PTX 指令,但批大小只有 32,那这几乎浪费了所有算力!在批大小为 1 时,最快的内核通常根本不会用 tensor core。
填充后的 Tensor Core 指令 如果批大小太小,以至于连一个 2D tile 都放不下,那么切换到更小的 tensor core 指令或干脆不用 tensor core 往往最有效!然而这两种选择都会让 kernel 失去批次不变性。
因此,确保 matmul 的批次不变性的最简单方式是:编译一个固定的内核配置,并对所有形状都用这一份。虽然会损失一些性能,但在 LLM 推理里通常不是灾难性的。特别是,Split-K 最“需要”的情形是 M 与 N 都很小;而在 LLM 里,N(模型维度)往往很大!
即便取得了批次不变性,我们的性能也只比 cuBLAS 慢大约 20%。注意,这里我们用的 Triton 内核也并未刻意优化(例如没有用 TMA)。不过,一些性能模式反而说明了我们为了批次不变性在哪些地方丢了性能。首先,在极小批大小处,由于指令规模过大且并行性不足,我们会损失较多性能。其次,随着批大小增大,会出现“拼图形”的锯齿模式,这是由量化效应(tile 量化与 wave 量化)造成的,而这类效应通常可以通过改变 tile 尺寸来缓解。关于这些量化效应的更多内容见这里。
批次不变的注意力
FlashAttention2 策略 我们沿 Q 维并行,同时沿 K/V 维做约简。这意味着整个约简都能保留在单个核心里,再次得到一个数据并行策略。
在确保 matmul 具备批次不变性之后,注意力带来了两个额外的“皱褶”——这也恰如其分,因为注意力里有两个 matmul。
- 不同于 RMSNorm 与 matmul 只沿特征维约简,注意力需要同时沿特征维与序列维约简。
- 由于上述原因,注意力必须处理多种影响序列处理方式的推理优化(如分块 prefill、前缀缓存等)。
因此,要在 LLM 推理中实现确定性,我们的数值必须同时对“同时处理多少请求”和“推理引擎如何切片每个请求”两者不敏感、保持不变。
我们先走一遍注意力的标准并行策略,它最初在 FlashAttention2 中提出。与 RMSNorm 和 Matmul 类似,默认策略是“数据并行”。由于我们要沿着 K/V 张量做约简,数据并行策略只能沿着查询张量(Q)并行。
例如,取决于推理引擎的选择,一个序列可能被分成若干部分来处理(如分块 prefill),也可能一次性处理(如果 prefill 没有被拆分)。为了实现“批次不变”,必须保证:某个查询 token 的约简顺序不取决于该序列当前与它一起被并行处理的 token 数量。如果你把 KV 缓存里的 K/V 与当前正在处理的 K/V 分开来算(比如 vLLM 的Triton 注意力内核),就无法实现这一点。比如,当处理一个序列的第 1000 个查询 token 时,无论 KV 缓存里有 0 个 token(prefill 阶段)还是 999 个 token(解码阶段),约简顺序都必须一致。
带 KV 缓存的 FlashAttention 之所以“单独处理 KV 缓存”和“当前 KV 值”会破坏批次不变性,微妙之处在于“边界条件”。想象块大小是 32,但我们目前的 KV 缓存里有 80 个元素。然后我们又计算了 48 个不进入缓存的新元素。在这种情况下,计算“P cache”需要三个块(两个满块 + 一个掩码块),计算“P”需要两个块(一个满块 + 一个掩码块)。总计 5 个块来完成约简,而实际只需要处理 4 个块(128 个元素),这样肯定会改变约简顺序。
比如,如果我们没有 KV 缓存、而是一次性处理 128 个元素,那么这两种情形的数值必须完全一致,注意力才能实现“批次不变”。
为了解决这一点,我们可以在进入注意力内核之前就更新 KV 缓存与页表,确保无论当前处理多少 token,我们的 keys 与 values 都采用一致的布局。
在这一附加改动(以及前一节里提到的那些,如固定 tile 尺寸)到位后,我们就能实现一个批次不变的注意力实现!
然而,这里有个大问题。不同于矩阵乘法,在 LLM 推理里我们经常会遇到需要“分裂约简”的注意力形状(常被称为 Split-KV 或 FlashDecoding)。因为如果不沿约简维并行,我们只能沿批维、头维以及“查询长度”维并行。在解码阶段,查询长度非常小;除非批很大,否则往往无法打满 GPU。
不幸的是,这种情况不像 RMSNorm/Matmul 那样可以“直接忽略”。比如,如果 KV 缓存非常长,即便只处理一个请求,注意力内核也可能需要很长时间。
固定分裂数量的 Split-KV(即 FlashDecode)策略 如果查询长度非常小(如解码时),我们可能会面临几乎没有并行性的内核。这时我们又需要沿约简维——这里是 KV 维——做切分。常见策略是先算出打满 GPU 所需的并行度,再把 KV 维平均分成这么多份。例如,KV 长度是 1000、需要 4 个分裂时,每个核心处理 250 个元素。
但这同样会破坏批次不变性,因为我们具体的约简策略取决于某次请求里我们正在同时处理多少个查询 token。
此外,注意力里常用的分裂约简策略在批次不变性上也有挑战。比如,FlashInfer 的“平衡调度算法”会选择仍能打满 GPU 的最大分裂尺寸,从而让约简策略不再“批次不变”。不过,与 RMSNorm/Matmul 不同,单纯“固定分裂数量”还不够。
为了实现批次不变性,我们必须采用“固定分裂尺寸”的策略。换言之,不固定“分裂数量”,而是固定每个分裂的“长度/大小”,最后得到一个“分裂数量可变”的实现。如此一来,无论一次性处理多少个 token,我们都能保证执行完全相同的约简顺序。为此需要对 FlexAttention 做一些内部改动,这些改动未包含在我们目前的代码发布中;我们会在不久之后上游合入!
固定尺寸的 Split-KV 策略 与前一种策略唯一的区别在于:我们的分裂现在是“固定尺寸”的。比如,KV 长度为 1000 时,不是均分成四段 250,而是分成三段固定长度 256、以及一段长度 232。
这使我们能够保持批次不变性,因为约简策略不再取决于一次性处理多少个查询 token!
实现
我们基于 vLLM,结合其 FlexAttention 后端与 torch.Library,演示了确定性的推理。通过 torch.Library,我们可以以非侵入的方式替换掉大部分相关的 PyTorch 运算符。你可以在 thinking-machines-lab/batch-invariant-ops 找到这套“批次不变”内核的库,以及如何以“确定性”模式运行 vLLM 的示例。
实验
完成式(completion)到底有多不确定?
我们使用 Qwen/Qwen3-235B-A22B-Instruct-2507,在温度 0 下,针对提示语 “Tell me about Richard Feynman”(非思维链模式)采样 1000 个完成式,每个生成 1000 个 token。令人惊讶的是,我们得到了多达 80 种不同的完成式,出现次数最多的一种出现了 78 次。
观察这些完成式的分歧位置,我们发现前 102 个 token 实际上都是一致的!第一次分歧发生在第 103 个 token。所有完成式都会生成 “Feynman was born on May 11, 1918, in” 这段文字。然而,其中 992 个继续生成 “Queens, New York”,而另外 8 个则生成 “New York City”。
另一方面,当我们启用“批次不变”的内核时,1000 个完成式全部一致。这正是我们从采样器的数学预期,但如果没有批次不变内核,我们无法实现这种确定性结果。
性能
在这里我们没有对批次不变内核做大量性能优化。不过,我们可以做一些实验来确认其性能仍然可用。
我们用一张 GPU 跑 Qwen-3-8B 搭建一个 API 服务器,发出 1000 个序列的请求,输出长度在 90 到 110 之间。
| Configuration | Time (seconds) | 
|---|---|
| vLLM default | 26 | 
| Unoptimized Deterministic vLLM | 55 | 
| + Improved Attention Kernel | 42 | 
绝大部分的变慢来自 vLLM 中 FlexAttention 集成目前还没有被深入优化。尽管如此,我们看到整体性能并非“灾难性”的。
真正的 on-policy 强化学习(RL)
正如研究者所指出的,训练与推理之间不同的数值行为,会把我们的 on-policy RL 隐式地变成 off-policy RL。
当然,如果连两次完全相同的推理请求都无法做到按位相同,那训练与推理更不可能按位相同。确定性推理让我们也可以修改训练栈,使采样与训练按位一致,从而得到真正的 on-policy RL。
我们在 Bigmath 上,以 RLVR 设置、最大 rollout 长度 4096,使用 Qwen 2.5-VL instruct 8B 初始化 RL 策略进行实验。
如果在训练时不做 off-policy 校正(如重要性采样加权),训练中期奖励会崩塌;而加入 off-policy 校正项则能让训练顺利推进。但如果我们让采样器与训练器按位一致,那么就是完全的 on-policy(即 KL 散度为 0),同样可以稳定训练。
我们也画出了采样器与训练器之间 logprob 的 KL 散度曲线,三种运行配置有显著不同表现:使用重要性加权时,KL 大致在 0.001 附近,并偶尔有尖峰;而不使用重要性加权时,最终会出现一次 KL 的尖峰,和奖励崩溃的时间大致一致;当然,“真正的 On-Policy RL” 时,KL 始终为 0,表明训练策略与采样策略没有任何偏离。
注意,那条“无重要性加权”的曲线在第 318 步附近出现了显著的 loss 尖峰,并伴随 logprob KL 的对应尖峰。与此同时,使用 off-policy 校正或“真正的 On-Policy”都能让 RL 平稳进行。蓝色那条“True On-Policy”的线不是绘图 bug——它就是一条在 0 上的平线。
结论
现代软件系统包含多层抽象。在机器学习中,当我们遇到非确定性与细微的数值差异时,往往会忍不住遮遮掩掩地“糊过去”。反正我们的系统本来就是“概率性的”,再多一点非确定性又有什么关系?把失败单测里的 atol/rtol 调大一点又有什么关系?训练器与采样器之间的 logprob 差异大概也不是真正的 bug,对吧?
我们拒绝这种失败主义。只要多做一点工作,我们可以理解非确定性的根源,乃至真正解决它们!我们希望这篇文章能为社区提供一个坚实的理解框架,帮助大家在推理系统中消除非确定性,并激励更多人对自己的系统做到“知其所以然”。
引用
请按如下方式引用本工作:
He, Horace and Thinking Machines Lab, "Defeating Nondeterminism in LLM Inference", 
Thinking Machines Lab: Connectionism, Sep 2025.
或使用 BibTeX:
@article{he2025nondeterminism,
  author = {Horace He and Thinking Machines Lab},
  title = {Defeating Nondeterminism in LLM Inference},
  journal = {Thinking Machines Lab: Connectionism},
  year = {2025},
  note = {https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/},
  doi = {10.64434/tml.20250910}
}