2.2 PyTorch 中的梯度记录与控制

Author

jshn9515

Published

April 5, 2026

在 2.1 节里,我们回答了一个问题:梯度是怎么被算出来的?前向传播时记账,反向传播时查账,Autograd 把计算图搭起来,再沿着图把梯度传回去。

但是,当我们写代码时,很快会遇到另一个更现实的问题:这本账,到底要不要记?

训练时当然要记,因为我们需要反向传播。可是在验证、推理、只做特征提取、或者只是想跑一遍模型看看输出的时候,记账反而是在浪费。它会保存中间结果、构建计算图、占显存,还可能让你不小心把一段本来只想算数值的代码拖进反向传播里。

所以这一节我们换个视角:不再讨论怎么求导,而是讨论哪些计算会被 Autograd 记录,哪些会被忽略。PyTorch 给了我们几种很直接的开关:torch.no_grad()torch.enable_grad(),以及更偏推理优化的 torch.inference_mode()。它们不会改变你算出来的数值结果,但会改变这段计算有没有计算图、能不能反传、以及会花多少内存和开销。

这也代表了 PyTorch 的一个设计理念:计算怎么做是算子的事,要不要记账是 Autograd 的事。接下来我们就从最常见的 no_grad() 开始,了解这些梯度记录模式。

import torch
import torch.nn as nn
import torch.nn.functional as F

print('PyTorch version:', torch.__version__)
PyTorch version: 2.11.0+xpu

2.2.1 torch.no_grad():暂停记账

在默认情况下,只要张量的 requires_grad=True,并且我们对它进行了运算,PyTorch 就会自动构建计算图。也就是说,只要你在可求导的范围里做计算,Autograd 就会默默地帮你把账记下来。但是有时候,我们并不需要这本账。

比如,在验证模型性能时,我们通常不需要计算梯度,因为我们不会进行反向传播。或者,在推理阶段,我们只关心模型的输出结果,而不关心它是怎么算出来的。这时候,如果我们继续让 Autograd 记账,不仅浪费内存,还可能导致性能下降。如果还让 Autograd 构建计算图,那就是多此一举。

因此,PyTorch 提供了 torch.no_grad() 上下文管理器(也可以当函数的装饰器),让我们可以明确告诉 Autograd:在这个代码块里,我们不需要它记账。

我们来看一个最直观的对比。在默认模式下:

model = nn.Linear(6, 4)
x = torch.randn(10, 6)
y = torch.randn(10, 4)

y_pred = model(x)
print('`y_pred.requires_grad` before `no_grad()`:', y_pred.requires_grad)
`y_pred.requires_grad` before `no_grad()`: True

输出会是 True,因为模型参数默认 requires_grad=True,所以结果自动进入计算图。

现在我们把这段前向传播放进 no_grad() 里:

with torch.no_grad():
    y_pred = model(x)

print('`y_pred.requires_grad` inside `no_grad()`:', y_pred.requires_grad)
`y_pred.requires_grad` inside `no_grad()`: False

这一次的输出是 False

注意,在 no_grad() 模式里,所有前向传播正常执行,只是得到的结果不再被 Autograd 追踪。而一旦某个张量不再被追踪,后续所有基于它的计算也都不再被追踪。如果我们此时对一个不再被追踪的张量调用 backward(),就会报错,因为这个张量压根不在计算图里,自然也就没法执行反向传播。

loss = F.mse_loss(y_pred, y)
try:
    loss.backward()
except RuntimeError as err:
    print('RuntimeError:', err)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

这里的 loss 虽然不在 no_grad() 里,但它是基于 y_pred 计算出来的,而 y_pred 已经不被追踪了,并且另外一个输入 y 也并没有请求梯度。所以,loss 也就被排除在整个计算图里。如果我们对这个 loss 调用 backward(),就会报错。

可能有的人会以为,no_grad() 会把某些张量的 requires_grad 属性改成 False,但其实并不是。no_grad() 只是告诉 Autograd 不要追踪这个代码块里的计算,但它并不会修改张量本身的属性。你可以在 no_grad() 外创建一个新的张量,它的 requires_grad 仍然是 True,但是这并不妨碍它在 no_grad() 里被当成普通张量来使用,不会被追踪。

x = torch.randn(10, 6, requires_grad=True)

with torch.no_grad():
    print('`x.requires_grad` inside `no_grad()`:', x.requires_grad)
    y_pred = model(x)
    print('`y_pred.requires_grad` inside `no_grad()`:', y_pred.requires_grad)
`x.requires_grad` inside `no_grad()`: True
`y_pred.requires_grad` inside `no_grad()`: False

所以,no_grad() 并没有阻止张量本身的“可导资格”,它只是阻止在这个上下文里产生的新计算被记录。也就是说,张量的 requires_grad 属性是一种“能力声明”,表示“我有资格被追踪”,而 no_grad() 是一种“行为控制”,表示“在这个上下文里,不要追踪任何计算”,这两者是相互独立的。

此外,在 no_grad() 里创建的新张量,如果我们后续想让它重新加入自动微分系统,仍然可以通过调用 requires_grad_() 方法来实现。比如:

with torch.no_grad():
    x = torch.randn(10, 6)
    print('`x.requires_grad` inside `no_grad()`:', x.requires_grad)

x.requires_grad_()
print('`x.requires_grad` after `requires_grad_()`:', x.requires_grad)
`x.requires_grad` inside `no_grad()`: False
`x.requires_grad` after `requires_grad_()`: True

也就是说,no_grad() 是暂时关闭记录,而不是永久剥夺张量的“可导资格”。在内部,仍然维护了一系列计数器,来确保当我们后续需要重新开启梯度记录时,能够正确地恢复状态。但是,这仍然会引入一定的计算和显存开销。这一点和后面我们要讲的 inference_mode() 会形成一个非常重要的对比。在 inference_mode() 中,PyTorch 不仅不追踪,还会彻底关闭一些与 Autograd 相关的功能,使得我们无法再通过 requires_grad_() 来重新开启梯度记录。

从一个更底层的角度理解 no_grad() 就会发现,在 PyTorch 里,计算是数值层面的行为,记录是自动微分系统的行为。而 no_grad() 只影响后者。这也是为什么我们常在模型验证、推理部署、参数更新里用到它。

那么,接下来一个自然的问题是:如果梯度可以被关闭,那么它能不能在局部重新开启?如果我们在推理阶段的某个小步骤突然需要梯度怎么办?这就引出了下一节:torch.enable_grad()

2.2.2 torch.enable_grad():重新开始记账

在上一节里,我们看到 no_grad() 可以让 Autograd 暂停记录。那么一个自然的问题来了:如果我们已经在 no_grad() 里了,能不能只对其中一小段计算重新开启梯度?

答案是可以的。这就是 enable_grad() 的作用。

当然,也可以在外层用 enable_grad() 来开启梯度,然后在内层用 no_grad() 来关闭,这些都是可以嵌套使用的。但是,在默认模式下开启 enable_grad() 等同于啥也没干,所以也就懒得写了。

还是先来看一个简单的例子:

x = torch.randn(10, 6, requires_grad=True)

with torch.no_grad():
    y = x * 3  # Does not record computation graph
    print('`y.requires_grad` in `no_grad()`:', y.requires_grad)

    with torch.enable_grad():
        z = x * 4  # Enables gradient tracking
        print('`z.requires_grad` in `enable_grad()`:', z.requires_grad)

# Only z will have gradients tracked
z.backward(gradient=torch.ones_like(z))
`y.requires_grad` in `no_grad()`: False
`z.requires_grad` in `enable_grad()`: True

这里发生的事情非常关键:外层 no_grad() 关闭了自动微分记录,内层 enable_grad() 又在局部恢复了记录。而在退出内层的 enable_grad() 之后,外层的 no_grad() 仍然有效,所以后续的计算又回到了不被追踪的状态。这说明梯度模式是栈式管理的。进入一个上下文,就压入一种模式;退出这个上下文,就恢复之前的模式。

那么,这有什么意义呢?

很多时候,我们的代码路径是共用的。比如,推理阶段的大部分前向都不需要梯度,但某个中间步骤需要做敏感性分析;或者某些调试代码希望临时计算一个梯度。如果没有 enable_grad(),我们就不得不把整段代码拆开,或者频繁地在外层切换状态。但有了 enable_grad(),我们就可以在需要的地方局部开启,而不影响整体的推理流程。

当然,还有一个更通用的接口,就是 torch.set_grad_enabled(),它可以接受一个布尔值参数,来直接设置当前的梯度模式。no_grad()enable_grad() 其实就是这个接口的一个特例。

x = torch.randn(10, 6)
is_training = False

with torch.set_grad_enabled(is_training):
    y_pred = model(x)

is_training=True 时,等价于 enable_grad();当 is_training=False 时,等价于 no_grad()。这使得代码逻辑更加统一,写条件控制也更方便。

到目前为止,我们已经介绍了两种常用的梯度控制上下文:no_grad()enable_grad()。它们分别用于关闭和开启梯度记录,并且可以嵌套使用,形成一个灵活的栈式管理系统。接下来,我们还会介绍一个更专门针对推理优化的上下文:torch.inference_mode(),它在性能和内存效率上比 no_grad() 更进一步。

2.2.3 torch.inference_mode():干脆以后都别记账了

在前面两节中,我们已经有了一个相当灵活的机制:

  • no_grad() 可以关闭梯度记录;
  • enable_grad() 可以局部恢复梯度记录;
  • set_grad_enabled() 是一个更通用的接口,可以直接设置当前的梯度模式;
  • 梯度模式是可嵌套、可恢复的。

从表面上看,似乎已经足够了。那为什么 PyTorch 还要再提供一个 inference_mode()

答案在于一个更深层的问题:如果我们不仅知道当前不需要梯度,还知道这段计算以后永远不可能参与反向传播,那么框架是不是可以做得更激进一点?把所有和梯度有关的开销全部去掉?

这就是 inference_mode() 的设计动机1

no_grad() 模式里,PyTorch 还会维护版本计数器(version counter)和视图追踪(view tracking),以及一些用于确保梯度正确性的内部检查。这些机制在训练时是必要的。它们可以防止原地操作破坏图结构,或者共享内存导致梯度错误。但在纯推理阶段,它们其实是额外开销。既然这一段代码的结果永远不会参与梯度计算,框架就可以不再维护与梯度相关的版本检查与视图追踪,做更激进的内存优化。因此,inference_mode() 通常会比 no_grad() 更快、更省内存。

但是,它是不可逆的。

我们知道,在 no_grad() 里创建的张量,可以在之后重新启用梯度:

with torch.no_grad():
    x = torch.randn(10, 6)

x.requires_grad_()  # we can still enable gradients for x
print('`x.requires_grad` after `requires_grad_`:', x.requires_grad)
`x.requires_grad` after `requires_grad_`: True

但在 inference_mode() 中创建的张量,如果我们尝试设置 requires_grad=True,就会直接报错:

with torch.inference_mode():
    x = torch.randn(10, 6)

try:
    x.requires_grad_()
except RuntimeError as err:
    print('RuntimeError:', err)
RuntimeError: Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.

因为 inference_mode() 不是暂时关闭记录,而是创建了一种特殊的推理张量(inference tensor)。这类张量被标记为“永远不会进入自动微分系统”。即使你之后开启梯度模式,它们也不会被纳入计算图。所以,no_grad() 是暂时关闭,而 inference_mode() 是永久关闭。如果我们能够确定这段代码永远只用于推理,那就可以使用 inference_mode()

2.2.4 不同梯度模式下的行为对比

到这里,我们其实已经看到三种不同的梯度语义:默认模式、no_grad() 模式和 inference_mode() 模式。它们代表的是三种不同强度的语义承诺,也对应不同使用场景下的性能权衡。

在默认模式下,Autograd 必须假设当前的任何计算,都可能参与反向传播。因此它会:

  • 构建完整计算图
  • 保存反向传播所需的中间结果
  • 维护版本计数器和视图一致性检查

这种模式是默认模式,灵活,但代价最高。通常用于模型训练阶段的前向传播。

当我们进入 no_grad() 时,我们表达的是一种阶段性声明:这一段计算当前不参与反向传播。

于是,在这种模式下,Autograd 可以做出一些优化:

  • 不再构建计算图
  • 不再保存中间结果
  • 但仍然保留 Autograd 的内部一致性机制
  • 退出该上下文后可以恢复正常梯度模式

这种模式是暂时关闭。灵活性仍然存在,但性能已经有了明显提升。大多用于参数验证或模型评估阶段。

inference_mode() 则是更强的承诺:这一段计算永远不会参与梯度计算。基于这个前提,Autograd 可以做出更激进的优化:

  • 不构建计算图
  • 跳过与梯度相关的版本检查与视图追踪
  • 在该模式下创建的张量无法再重新加入自动微分系统

这是不可逆的关闭。这种模式的优化最激进,但限制也最大。适用于纯推理、模型评估和数据处理等场景。

Footnotes

  1. inference_mode() 是在 PyTorch 1.9 版本引入的,专门针对推理阶段的性能优化。关于它的具体实现,可以参考 RFC-0011-InferenceMode↩︎

Reuse