2.1 PyTorch 中的自动微分

Author

jshn9515

Published

April 5, 2026

在 1.3 节里,我们把计算图当成一条“责任链”:损失函数为什么是这个值,沿着链条往回追,就能追到每个参数到底“负了多少责任”。这一节我们换一个更工程的视角:框架是怎么把这条责任链自动搭起来,并且在需要的时候把梯度算出来的?

先把问题说得更直白一点:训练时我们要的是梯度,但我们手里只有一堆代码:加法、乘法、卷积、激活函数…。这些操作在前向传播里一行行执行,最后吐出一个 loss。那么,梯度从哪来?难道框架真的去推导一个巨大的符号表达式吗?

当然不是。深度学习框架做的事情更像是:

理解这套机制很关键。它不仅解释了“梯度是怎么来的”,还会直接影响我们后面遇到的许多现象:比如梯度为什么会累积?为什么中间变量默认没有 .grad 属性?为什么有些操作会切断梯度链条?以及显存与计算之间为什么总要做权衡。

import torch
import torch.autograd.functional as AF

print('PyTorch version:', torch.__version__)
PyTorch version: 2.11.0+xpu

2.1.1 计算图不是画出来的,是跑出来的

理解 PyTorch 的自动微分,最好的方式不是先背概念,而是先观察一件事:你只是在做前向计算,但计算图会在运行过程中自动搭建出来。

假设我们有这样一个简单的函数:

\[ z = \sin(x \cdot y) \]

我们可以把它拆解成几个基本的运算步骤:

  1. 计算向量内积:\(q = x \cdot y\)
  2. 计算正弦函数:\(z = \sin(q)\)

然后,我们告诉 PyTorch,在接下来的计算中,我们希望得到 z 关于 xy 的梯度。

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)

这里的 requires_grad=True 可以理解成一种声明:这些变量需要被“追责”。之后只要某个结果是由它们参与计算得到的,它就会自动带上可导属性,并在背后记录“我是谁算出来的,依赖了谁”。

现在做两步普通的前向计算:先算点积,再取正弦。

q = x.dot(y)
z = q.sin()
print('z.requires_grad:', z.requires_grad)
z.requires_grad: True

到这里你看到的依然只是数值计算,但 PyTorch 已经做了两件事:

  1. z 会自动变成需要梯度的结果(因为它依赖了需要梯度的 xy)。
  2. qz 的产生过程会被记录下来:zsin 得到,qdot 得到,而 q 又依赖 xy

先别急着管计算图长什么样。我们先看一个更直观的现象:在你调用反向传播之前,梯度并不会凭空出现。

print('x.grad:', x.grad)
print('y.grad:', y.grad)
x.grad: None
y.grad: None

这里是 None,而不是 0。原因也很简单:梯度是一种反向回溯的产物,只有当你明确发起回溯(比如调用 backward())时,PyTorch 才会沿着刚才记录的依赖关系,把梯度算出来并写回到叶子节点上。如果不调用,PyTorch 就不会去算梯度,自然也不会给你填上数值。

接下来我们就做这件事:从 z 开始反向传播,看看 .grad 是如何出现的,以及它和我们手算的结果是否一致。

2.1.2 backward 到底做了什么:从输出往回查账

上一节我们只做了前向计算,但 PyTorch 已经把依赖关系悄悄记录好了。现在我们真正关心的是:当你调用 backward() 时,框架究竟做了什么?算出来的梯度又是否可信?

还是沿用同一个例子:

\[ q = x \cdot y, \quad z = \sin(q) \]

如果我们手算梯度,我们就会得到:

\[ \frac{\partial z}{\partial x} = \frac{\partial z}{\partial q} \cdot \frac{\partial q}{\partial x} = \cos(q) \cdot y \] \[ \frac{\partial z}{\partial y} = \frac{\partial z}{\partial q} \cdot \frac{\partial q}{\partial y} = \cos(q) \cdot x \]

好的,现在让 PyTorch 来算。我们直接从输出 z 发起回溯:

z.backward()
print('x.grad:', x.grad)
print('y.grad:', y.grad)
x.grad: tensor([3.1666, 3.7999, 4.4332, 5.0666])
y.grad: tensor([0.6333, 1.2666, 1.9000, 2.5333])

此时 .grad 不再是 None,梯度已经被写回到了 xy 这两个叶子节点上。直觉上你可以这样理解 backward()

  1. z 为起点,默认认为 \(\frac{\partial z}{\partial z} = 1\)
  2. 然后沿着前向传播时记下来的依赖链往回走;
  3. 每走过一个算子节点,就用这个算子的局部求导规则把梯度继续往上游传递。

我们可以把它和手算结果对齐。比如:

assert torch.allclose(x.grad, y * x.dot(y).cos())
assert torch.allclose(y.grad, x * x.dot(y).cos())

到这里,自动微分的核心逻辑其实已经很清楚了。深度学习框架并不需要推导一个巨大的全局导数公式,它只需要知道每一步怎么求导,然后把这些局部规则按计算图的结构串起来。

如果再深入一点,其实 PyTorch 也把这条回溯链暴露了一部分给我们。比如:

print('z.grad_fn:', z.grad_fn.name())
print('q.grad_fn:', q.grad_fn.name())
print('x.grad_fn:', x.grad_fn)
print('y.grad_fn:', y.grad_fn)
z.grad_fn: SinBackward0
q.grad_fn: DotBackward0
x.grad_fn: None
y.grad_fn: None

我们通常会看到类似 SinBackward0 这样带有 Backward 的名字。它的含义可以粗略理解为:

  • z 不是凭空来的,它是某个算子(这里是 sin)产生的结果;
  • grad_fn 就是这个算子在反向传播时对应的梯度函数对象。

在计算反向传播时,PyTorch 从根节点开始,依次调用每个节点的导数算子,计算出各个输入变量的梯度,直到到达输入节点为止。例如,当我们调用 z.backward() 时,PyTorch 会首先调用 z 节点的导数算子 SinBackward0,计算出 \(\frac{\partial z}{\partial q}\),然后将该值传递给 q 节点的导数算子 DotBackward0,计算出 \(\frac{\partial q}{\partial x}\)\(\frac{\partial q}{\partial y}\),最终得到 \(\frac{\partial z}{\partial x}\)\(\frac{\partial z}{\partial y}\)。叶子节点(如 xy)没有导数算子,因为它们是计算图的起点,不需要进一步计算梯度。

更关键的是,grad_fn.next_functions 会指向它的上游依赖:

node_q = z.grad_fn.next_functions[0][0]
node_x = node_q.next_functions[0][0]
node_y = node_q.next_functions[1][0]
print('grad_fn of z.child -> q:', node_q.name())
print('grad_fn of q.child -> x:', node_x.name())
print('grad_fn of q.child -> y:', node_y.name())
grad_fn of z.child -> q: DotBackward0
grad_fn of q.child -> x: struct torch::autograd::AccumulateGrad
grad_fn of q.child -> y: struct torch::autograd::AccumulateGrad

它们描述的是,为了计算 z 的梯度,反向传播接下来应该去找谁、沿着哪些输入回溯。例如,在 SinBackward0 节点中,next_functions 会指向 DotBackward0 节点,因为 SinBackward0 的输入是 q,而 q 是通过 DotBackward0 计算得到的。同样地,在 DotBackward0 节点中,next_functions 会指向输入节点 xyAccumulateGrad 是一个特殊的节点类型,每个需要梯度的叶子节点前都会有一个对应的 AccumulateGrad 节点,负责把得到的梯度累加到叶子节点的 .grad 属性中。这也就是为什么 x.grady.grad 最终会在调用 backward() 后出现。

2.1.3 为什么非标量不能直接 backward?

上面的例子里,z 是一个标量,所以我们可以理直气壮地写 z.backward()。相信很多人第一次换成输出是向量或者矩阵时,会立刻撞到 PyTorch 的一条看起来很不讲理的限制:

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)
Z = x.outer(y)
try:
    Z.backward()  # This will raise an error because z is not a scalar
except RuntimeError as err:
    print('RuntimeError:', err)
RuntimeError: grad can be implicitly created only for scalar outputs

这不是 PyTorch 小气,而是反向传播的起点在非标量情况下不再唯一。

对标量 z,我们通常关心的是 \(\frac{\partial z}{\partial x}\)\(\frac{\partial z}{\partial y}\)。反向传播从输出出发,第一步就是设定 \(\frac{\partial z}{\partial z} = 1\)。这一步之所以合理,是因为标量输出的单位梯度没有歧义:我们就是要沿着 z 这个方向往回传。

但是,如果输出是向量或者矩阵 Z 呢?我们到底想要什么?

  • 是想要 Z 的每一个元素对 xy 的梯度吗?那会是一个更高阶的张量。
  • 还是想要某个标量函数,比如 Z 的和、均值、某个加权和,对 xy 的梯度?

也就是说,对非标量输出,反向传播必须先回答一句话:我们打算从哪个“方向”把梯度回传?

在数学上,这个“方向”就是一个与输出同形状的张量 v,表示从上游传下来的梯度:

\[ v = \frac{\partial L}{\partial Z} \]

然后 PyTorch 实际计算的是向量-雅可比积(VJP):

\[ \frac{\partial L}{\partial x} = v^\top \left(\frac{\partial Z}{\partial x}\right) \]

对于标量输出,v 自动为 1(等价于调用 Z.backward(),即把 \(L\) 取为 \(Z\));对于非标量输出,v 需要我们自己提供。

这里就有两种写法。

一种写法是,我们显式传入 gradient,表示我们想要从哪个方向回传梯度:

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)
Z = x.outer(y)
Z.backward(gradient=torch.ones_like(Z))
print('x.grad:', x.grad)
print('y.grad:', y.grad)
x.grad: tensor([26., 26., 26., 26.])
y.grad: tensor([10., 10., 10., 10.])

这里 torch.ones_like(Z) 就是告诉 PyTorch,我想让 \(L = \sum_{i,j} Z_{i,j}\),因为

\[ \frac{\partial L}{\partial Z_{i,j}} = 1 \]

所以传一个全 1 的梯度,就等价于“对所有元素求和后再 backward”。

还有另外一种写法,就是先把 Z 变成一个标量,再对这个标量调用 backward()

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)
Z = x.outer(y)
Z = Z.sum()  # Now Z is a scalar
Z.backward()
print('x.grad:', x.grad)
print('y.grad:', y.grad)
x.grad: tensor([26., 26., 26., 26.])
y.grad: tensor([10., 10., 10., 10.])

这两种写法在很多情况下是等价的。要么我们显式告诉 PyTorch 从哪个方向回传梯度,要么我们先把输出变成一个标量(比如求和),让它自己默认从这个标量的方向回传梯度。

2.1.4 高阶导数:让求导过程也变成计算的一部分

到目前为止,我们做的都是一阶梯度:给定一个标量输出(或者可以转换成标量输出)\(L\),求 \(\nabla_x L\)\(\nabla_y L\)。但有时候我们会需要更高阶的信息,比如二阶导数(Hessian 的某些方向)、曲率、或者用在一些正则项里。

那么这件事的关键点在于:如果你想对“梯度”再求导,那么“求梯度这件事”本身也必须是可微的。这就是 create_graph=True 的含义。在计算一阶导数时,不仅算出数值,还要把“算出这个导数的过程”记录成新的计算图。

可能这时候很多人就会有疑惑,为什么不用 backward() 呢?因为 backward() 的设计目标是训练模型:我们把梯度累积进叶子张量的 .grad 属性中,并且默认释放图来节省内存。但是,在做高阶导时,我们更希望:

  • 梯度作为一个张量返回(方便继续算)
  • 必要时保留 / 构建计算图(方便再求导)

因此更常用的是 torch.autograd.grad

我们还是用上面的例子:\(z = \sin(x \cdot y)\)。我们先求一阶导数 \(\frac{dz}{dx}\)\(\frac{dz}{dy}\),然后再对这个结果求导,看看二阶导数 \(\frac{d^2 z}{dx^2}\)\(\frac{d^2 z}{dy^2}\) 是什么样的。

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(4.0, requires_grad=True)
z = torch.sin(x * y)

dzdx, dzdy = torch.autograd.grad(z, (x, y), create_graph=True)
print('dz/dx:', dzdx)
print('dz/dy:', dzdy)
dz/dx: tensor(-0.5820, grad_fn=<MulBackward0>)
dz/dy: tensor(-0.2910, grad_fn=<MulBackward0>)

这里最重要的一行是 create_graph=True。如果没有它,dz/dxdz/dy 会被当成纯数值结果,不再保留它是怎么得到的,那我们就没法再对它求导。dz/dxdz/dy 的输出都包含了一个 grad_fn,说明他们允许自身被求导。

在计算高阶导数时,我们有时候希望在同一个计算图中前后对不同变量分别求导。但是,PyTorch 在调用一次 backward() 后默认会释放计算图来节省内存,这就导致我们无法在同一个图里连续求导。如果我们确实需要在同一次前向结果上做多次回溯,可以通过设置 retain_graph=True 来保留图:

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(4.0, requires_grad=True)
z = torch.sin(x * y)

dzdx, dzdy = torch.autograd.grad(z, (x, y), create_graph=True)
print('dz/dx:', dzdx)
print('dz/dy:', dzdy)

(d2zdx2,) = torch.autograd.grad(dzdx, x, retain_graph=True)
(d2zdy2,) = torch.autograd.grad(dzdy, y)
print('d2z/dx2:', d2zdx2)
print('d2z/dy2:', d2zdy2)
dz/dx: tensor(-0.5820, grad_fn=<MulBackward0>)
dz/dy: tensor(-0.2910, grad_fn=<MulBackward0>)
d2z/dx2: tensor(-15.8297)
d2z/dy2: tensor(-3.9574)

不过更常见的做法是,重新执行一次前向传播来得到一张新的计算图。retain_graph=True 通常是当我们确实要在同一个计算图上做多次梯度计算时才用,比如高阶导数实验或者某些正则项的计算。

2.1.5 VJP 和 JVP:反向模式与正向模式到底在算什么?

到目前为止我们一直在说“求梯度”。但严格来说,深度学习里绝大多数函数并不是从标量到标量,而是:

\[ f: \mathbb{R}^n \to \mathbb{R}^m \]

它的导数是一个雅可比矩阵(Jacobian):

\[ J = \frac{\partial f}{\partial x} \in \mathbb{R}^{m \times n} \]

真正的问题是,当 \(m,n\) 都很大时,我们几乎从来不会显式构造 \(J\)。我们真正想要的,框架实际计算的是 Jacobian 的乘积,要么乘在左边,要么乘在右边。

2.1.5.1 VJP:向量-雅可比积(反向模式)

给定上游梯度向量 \(v \in \mathbb{R}^m\)(可以理解为 \(\frac{\partial L}{\partial f}\)),反向模式计算的是:

\[ v^\top J \in \mathbb{R}^n \]

这就是 VJP(vector-Jacobian product)

把它翻译成训练时的语言就更熟悉了:

  • 我们有一个标量 loss\(L = \mathcal{L}(f(x))\)
  • 一个上游梯度:\(v = \frac{\partial L}{\partial f}\)
  • 进行反向传播:\(\frac{\partial L}{\partial x} = v^\top \frac{\partial f}{\partial x}\)

所以,平时我们调用 backward(),实际上就是在计算一个特殊的 VJP。

def vjp_func(x: torch.Tensor, y: torch.Tensor):
    return x.dot(y).sin()


x = torch.arange(1.0, 5.0)
y = torch.arange(5.0, 9.0)
out = AF.vjp(vjp_func, (x, y))
print('func(x,y):', out[0])
print('VJP output:', out[1])
func(x,y): tensor(0.7739)
VJP output: (tensor([3.1666, 3.7999, 4.4332, 5.0666]), tensor([0.6333, 1.2666, 1.9000, 2.5333]))

2.1.5.2 JVP:雅可比-向量积(正向模式)

正向模式则相反:给定一个输入方向 \(u \in \mathbb{R}^n\),计算:

\[ Ju \in \mathbb{R}^m \]

这就是 JVP(Jacobian-vector product)。从直觉上,它回答的问题是:如果我们在输入空间里沿某个方向 \(u\) 做一个微小的扰动,输出会沿着哪个方向变化?这在做敏感性分析、隐式层、某些二阶方法、以及一些物理/科学计算中非常常见。

def jvp_func(a: torch.Tensor, b: torch.Tensor):
    return a.dot(b).sin()


x = torch.arange(1.0, 5.0)
y = torch.arange(5.0, 9.0)
v_x = torch.full_like(x, 0.1)
v_y = torch.full_like(y, 0.2)
out = AF.jvp(jvp_func, (x, y), (v_x, v_y))
print('func(x,y):', out[0])
print('JVP output:', out[1])
func(x,y): tensor(0.7739)
JVP output: tensor(2.9133)

2.1.5.3 为什么深度学习里更常见的是 VJP

这个问题不是“谁更高级”,而是”规模匹配”。

  • 在深度学习训练中,通常 \(n\) 是参数维度(百万/亿级),\(m\) 是输出维度(通常是一个标量)
  • 我们真正想要的是 \(\nabla L \in \mathbb{R}^n\)

VJP 的复杂度大致和一次反向传播同量级,适合 \(n\) 很大但输出是标量/低维的场景。JVP 更适合输入维度相对小,但我们关心输出方向变化的场景。所以,我们会看到一个很经典的判断:如果输出是标量或低维向量,而且输入维度很大,那么反向模式(VJP)更合适;如果输入维度相对较小,输出维度很大,那么正向模式(JVP)可能更合适。

2.1.6 反向传播中的常见错误

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)

1. 重复调用 backward()

在同一个计算图上多次调用 backward() 会导致错误。PyTorch 在第一次反向传播结束后,会把这张图里只为反向传播服务的中间变量释放掉,以节省显存。所以当我们第二次再沿着同一张图回溯,就会发现“路标”已经被清理了。如果需要多次计算梯度,可以在第一次调用时设置 retain_graph=True

z = x.dot(y).sin()
z.backward()
try:
    z.backward()  # This will raise an error because gradients are already computed
except RuntimeError as err:
    print('RuntimeError:', err)
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
z = x.dot(y).sin()
z.backward(retain_graph=True)
z.backward()  # This works because we retained the graph

2. 尝试访问中间节点的梯度

只有叶子节点(即最初创建的变量)会存储梯度信息。中间节点的梯度不会被存储,因为如果每个中间变量都存梯度,显存会直接爆炸,而且训练真正需要的是参数梯度,而不是所有中间量的梯度。因此尝试访问它们的 .grad 属性会返回 None,并引发 UserWarning。如果需要保留中间节点的梯度,可以在创建这些节点时设置 q.retain_grad()

import warnings

q = x.dot(y)
z = q.sin()
z.backward()

with warnings.catch_warnings(record=True) as warns:
    print('q.grad:', q.grad)
    if len(warns) > 0:
        for warn in warns:
            print('UserWarning:', warn.message)
q.grad: None
UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more information. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\build\aten\src\ATen/core/TensorBody.h:499.)
q = x.dot(y)
q.retain_grad()
z = q.sin()
z.backward()
print('q.grad after `retain_grad`:', q.grad)  # Now q.grad is available
q.grad after `retain_grad`: tensor(0.6333)

3. 使用原地操作

PyTorch 里像 x.add_(1)x.relu_() 这种带下划线的操作,表示原地修改张量。不创建新张量,而是直接改 x 自己的内存。这在直觉上很省事,但在反向传播往往需要用到前向传播时的某些中间值。如果这些值在前向之后被我们就地改掉,那反向传播就可能失去计算梯度所需的信息。因此,在反向传播过程中,尽量避免使用原地操作,或者确保它们不会修改反向传播需要的中间变量。

z = x.dot(y)
try:
    x.relu_()
except RuntimeError as err:
    print('RuntimeError:', err)
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
z = x.dot(y)
x = x.relu()
z.backward()