try ai
科普
编辑
分享
反馈
  • 反向模式微分

反向模式微分

SciencePedia玻尔百科
核心要点
  • 反向模式微分通过一次反向传播,高效地计算单个输出对大量输入的导数。
  • 这种高计算效率的代价是需要大量内存来存储中间值,这一权衡可以通过检查点(checkpointing)等技术进行管理。
  • 许多学科都使用了相同的核心原理,在机器学习中称为反向传播,在物理学和工程学中称为伴随状态法。
  • 其工作原理是构建一个计算图,并利用链式法则将敏感度(或称“伴随值”)从输出反向传播至输入。

引言

想象一个复杂的系统,其中一个最终结果取决于数百万个输入参数。你如何能有效地确定每个参数对该结果的影响?这是从机器学习到气候科学等领域的一个根本性挑战。反向模式微分为此问题提供了一个优雅且惊人高效的解决方案。它是一种计算策略,如同神谕一般,在一次反向扫描中揭示输出对所有输入的敏感度。本文旨在揭开这项强大技术的神秘面纱。在接下来的章节中,你将首先探索反向模式微分背后的“原理与机制”,了解计算图和链式法则是如何被用来反向传播影响的。然后,你将踏上“应用与跨学科联系”的旅程,发现这同一个思想,在反向传播和伴随状态法等不同名称下,如何成为推动科学和工程革命的引擎。

原理与机制

想象一下,你刚按照一个复杂的配方烤好了一个蛋糕。你尝了一口,觉得有点太甜了。现在难点来了:每种配料——糖、面粉、香草——对最终的甜度贡献了多少?你可以尝试每次只改变一种配料来烤一个新蛋糕,但这极其浪费。有没有一种方法,仅凭这一次品尝,就能推断出最终味道对你投入的每一种原料的敏感度?

这就是反向模式微分的核心魔力。它是一种计算策略,使我们能够以惊人的效率找到一个最终输出相对于大量输入的导数。让我们层层剥茧,看看这个非凡的过程是如何运作的。

计算图:数字的“配方”

从本质上讲,计算机上计算的任何函数,无论多么复杂,都只是一系列简单的基本运算:加法、乘法、正弦、余弦等等。我们可以将这个序列可视化为一个有向图,我们称之为​​计算图​​或​​记录带​​(tape)。可以把它想象成一个详细的、分步的配方,其中每条指令都利用先前指令的结果来产生一个新的中间结果。

让我们考虑一个函数,如 f(x,y)=xsin⁡(y)+yexp⁡(x)f(x,y) = x\sin(y) + y\exp(x)f(x,y)=xsin(y)+yexp(x)。计算机不会一次性理解这个函数,而是将其分解:

  1. 令 v1=xv_1 = xv1​=x 且 v2=yv_2 = yv2​=y。
  2. 计算 v3=sin⁡(v2)v_3 = \sin(v_2)v3​=sin(v2​)。
  3. 计算 v4=exp⁡(v1)v_4 = \exp(v_1)v4​=exp(v1​)。
  4. 计算 v5=v1⋅v3v_5 = v_1 \cdot v_3v5​=v1​⋅v3​。
  5. 计算 v6=v2⋅v4v_6 = v_2 \cdot v_4v6​=v2​⋅v4​。
  6. 最后,计算结果 f=v7=v5+v6f = v_7 = v_5 + v_6f=v7​=v5​+v6​。

这个操作序列就是我们的记录带。第一个阶段,称为​​前向传播​​(forward pass),很简单:我们选定输入值,比如 x=2x=2x=2 和 y=0.5y=0.5y=0.5,然后只需遵循配方,沿途计算并记录每个中间变量 viv_ivi​ 的值。这就像是烤蛋糕的过程。

反向传播:追溯影响之流

天才之处就在这里。我们想知道,如果我们稍微调整初始输入 xxx 和 yyy,最终输出 fff 会改变多少。这正是导数 ∂f∂x\frac{\partial f}{\partial x}∂x∂f​ 和 ∂f∂y\frac{\partial f}{\partial y}∂y∂f​ 告诉我们的。在反向模式中,我们从末端开始,沿着记录带向后计算这些导数。

我们为图中的每个变量 viv_ivi​ 引入一个新量,称为其​​伴随值​​(adjoint),记作 vˉi\bar{v}_ivˉi​。伴随值被定义为最终输出 fff 对该变量的导数:

vˉi=∂f∂vi\bar{v}_i = \frac{\partial f}{\partial v_i}vˉi​=∂vi​∂f​

伴随值 vˉi\bar{v}_ivˉi​ 衡量了最终结果 fff 对中间值 viv_ivi​ 微小变化的敏感度。它告诉我们 viv_ivi​ 对最终结果有多大的“影响力”。我们的目标是找到 xˉ\bar{x}xˉ 和 yˉ\bar{y}yˉ​,这正是我们所寻求的梯度。

这个过程始于将最终输出的伴随值设为1,因为 fˉ=∂f∂f=1\bar{f} = \frac{\partial f}{\partial f} = 1fˉ​=∂f∂f​=1。输出对自身的影响,自然是一对一的。现在,我们沿着配方往回走。对于每个运算,我们局部地使用链式法则,将这种影响传播回该步骤用作输入的变量。

考虑我们例子中的最后一步:f=v7=v5+v6f = v_7 = v_5 + v_6f=v7​=v5​+v6​。v7v_7v7​ 的影响(即 vˉ7=1\bar{v}_7=1vˉ7​=1)是如何分配给 v5v_5v5​ 和 v6v_6v6​ 的?链式法则告诉我们:

vˉ5=∂f∂v5=∂f∂v7∂v7∂v5=vˉ7⋅1=1\bar{v}_5 = \frac{\partial f}{\partial v_5} = \frac{\partial f}{\partial v_7} \frac{\partial v_7}{\partial v_5} = \bar{v}_7 \cdot 1 = 1vˉ5​=∂v5​∂f​=∂v7​∂f​∂v5​∂v7​​=vˉ7​⋅1=1
vˉ6=∂f∂v6=∂f∂v7∂v7∂v6=vˉ7⋅1=1\bar{v}_6 = \frac{\partial f}{\partial v_6} = \frac{\partial f}{\partial v_7} \frac{\partial v_7}{\partial v_6} = \bar{v}_7 \cdot 1 = 1vˉ6​=∂v6​∂f​=∂v7​∂f​∂v6​∂v7​​=vˉ7​⋅1=1

所以,影响通过加法直接传递了过去。现在让我们再往后退一步,回到运算 v5=v1⋅v3v_5 = v_1 \cdot v_3v5​=v1​⋅v3​。我们已经知道 v5v_5v5​ 的影响是 vˉ5=1\bar{v}_5 = 1vˉ5​=1。这如何传播到 v1v_1v1​ 和 v3v_3v3​?同样,链式法则是我们的指南:

从此路径对 v1 的影响=∂f∂v5∂v5∂v1=vˉ5⋅v3=1⋅v3\text{从此路径对 } v_1 \text{ 的影响} = \frac{\partial f}{\partial v_5}\frac{\partial v_5}{\partial v_1} = \bar{v}_5 \cdot v_3 = 1 \cdot v_3从此路径对 v1​ 的影响=∂v5​∂f​∂v1​∂v5​​=vˉ5​⋅v3​=1⋅v3​
从此路径对 v3 的影响=∂f∂v5∂v5∂v3=vˉ5⋅v1=1⋅v1\text{从此路径对 } v_3 \text{ 的影响} = \frac{\partial f}{\partial v_5}\frac{\partial v_5}{\partial v_3} = \bar{v}_5 \cdot v_1 = 1 \cdot v_1从此路径对 v3​ 的影响=∂v5​∂f​∂v3​∂v5​​=vˉ5​⋅v1​=1⋅v1​

请注意一个关键点:为了在反向传播中计算伴随值,我们需要在前向传播期间计算出的变量的实际值(v1,v3v_1, v_3v1​,v3​ 等)。这就是为什么我们必须记录整个记录带。

一个变量可能通过多条路径影响输出。例如,在函数 L=(wx+b)(wx)L = (wx+b)(wx)L=(wx+b)(wx) 中,中间变量 v1=wxv_1 = wxv1​=wx 被使用了两次。当我们反向传播伴随值时,v1v_1v1​ 将从它被使用的两个地方都接收到影响的贡献。其总影响,即最终的伴随值,就是来自所有路径影响的总和。这种累积是根本性的。这个过程继续进行,一步步地反向传播和累积伴随值,直到我们到达输入节点。xˉ\bar{x}xˉ 和 yˉ\bar{y}yˉ​ 的最终值就是我们所寻求的梯度。这整个反向传播过程,为我们提供了相对于所有输入的梯度,被称为​​反向传播​​(backward pass)。

巨大的权衡:计算与内存

你可能会问:“为什么要这么麻烦?为什么不逐一扰动每个输入,看看会发生什么?” 那就相当于为每种配料都烤一个新蛋糕。对于一个有 nnn 个输入和 mmm 个输出的函数,那种“暴力”方法(类似于​​前向模式微分​​)大约需要原始函数评估 nnn 倍的工作量才能找到所有导数。

反向模式彻底改变了这一点。它只需要一次前向传播(建立记录带)和一次反向传播,就能得到一个输出相对于所有输入的导数。总成本与输出数量 mmm 成正比。

让我们具体化这一点:

  • ​​获取完整梯度的前向模式成本:​​ 与 nnn(输入数量)成正比。
  • ​​获取完整梯度的反向模式成本:​​ 与 mmm(输出数量)成正比。

这导出了一个简单而深刻的经验法则:

  • 如果你有许多输出和少量输入(m≫nm \gg nm≫n),使用前向模式。这就像一枚火箭,其轨道(mmm 个输出)取决于几个初始设置(nnn 个输入)。
  • 如果你有许多输入和少量输出(n≫mn \gg mn≫m),使用反向模式。这正是现代机器学习中的情况,其中单个损失函数(m=1m=1m=1)可能取决于数百万甚至数十亿的模型参数(nnn)。反向模式可以一次性计算出相对于所有这数百万个参数的梯度!正是这种不可思议的效率推动了深度学习革命。

但这种能力是有代价的:​​内存​​。正如我们所见,反向传播需要前向传播的中间值。对于一个有十亿步的计算,你需要存储十亿个值。这可能是一个巨大的内存负担。

聪明的工程师们用一种叫做​​检查点技术​​(checkpointing)的方法解决了这个问题。你不是保存配方的每一步,而是只保存几个关键的“检查点”。然后,在反向传播期间,当你需要两个检查点之间的中间步骤时,你只需从最后一个保存的检查点快速重新计算它们。这是一个经典的​​空间-时间权衡​​:你多做一点计算来节省大量内存。

现实世界是复杂的:分支与拐点

当我们的计算配方不是一条直线时会发生什么?如果它有条件分支,比如一个 if-then-else 语句呢?假设我们的程序说,“if v1<exp⁡(x)v_1 < \exp(x)v1​<exp(x), then do this, otherwise, do that”。反向模式的原理优雅地处理了这种情况。在前向传播期间,一条特定的路径被采纳。在反向传播期间,影响之流只沿着实际执行的路径回溯。未被采纳的路径所贡献的导数为零。

那么对于不完全平滑的函数呢?许多重要函数,比如在神经网络中无处不在的​​ReLU​​函数(ReLU⁡(x)=max⁡(0,x)\operatorname{ReLU}(x) = \max(0, x)ReLU(x)=max(0,x)),都有一个“拐点”,在该点导数没有严格定义。即使在这里,框架也可以被扩展。在这些点上,我们可以使用一个叫做​​次梯度​​(subgradient)的概念,它是梯度的有效推广。我们可以简单地在该点为导数定义一个合理的值(例如,对于ReLU,在x=0x=0x=0处,我们可以选择111、000或0.50.50.5),反向模式的机制就能照常工作。

万法归一:雅可比之舞

虽然分步记录带的比喻对于建立直觉很有帮助,但其背后存在着更深层次的数学统一性。反向传播中伴随值的传播可以用线性代数优美地描述。我们计算图中的每个基本步骤都有一个相关的局部​​雅可比矩阵​​。伴随向量 vˉ\bar{v}vˉ 跨越一个步骤的反向传播在数学上等同于将其乘以该步骤雅可比矩阵的转置。

因此,整个反向传播过程就是一系列​​雅可比转置向量积​​,从输出开始,向后移动到输入:

∇f(x)=Jg1(x)⊤⋯Jgk−1(yk−2)⊤Jgk(yk−1)⊤yˉk\nabla f(x) = J_{g_1}(x)^{\top} \cdots J_{g_{k-1}}(y_{k-2})^{\top} J_{g_k}(y_{k-1})^{\top} \bar{y}_k∇f(x)=Jg1​​(x)⊤⋯Jgk−1​​(yk−2​)⊤Jgk​​(yk−1​)⊤yˉ​k​

这就是为什么反向模式也被称为​​伴随模式​​(adjoint mode)。它揭示了值的正向传播与敏感度的反向传播之间美妙的对偶性。正是这种优雅而强大的机制,使我们能够有效地驾驭现代计算问题的高维景观,将理解影响这一棘手的任务,转变为一次简单的反向旅程。

应用与跨学科联系

你是否曾想过,是否存在一种神奇的神谕?一个能够对任何复杂系统——无论是活细胞、地球气候、金融市场还是人工大脑——精确地告诉你应该转动哪些旋钮、转动多少,才能让系统表现得更好的神谕?如果你想设计一种更有效的药物,你应该改变分子的哪些部分?如果你想让机器人走得更平稳,你应该调整哪些电机指令?事实证明,这样的神谕,或者至少是它的一个强大近似,确实存在。它不是魔法,而是反向模式微分的原理,它是我们这个时代最重要的计算思想之一。

在上一章中,我们剖析了这个非凡工具的力学原理,看到它如何巧妙地反向应用链式法则,以求得单个输出对大量输入的梯度。现在,让我们踏上一段旅程,看看这个原理在实践中的应用。我们将发现,这一个单一、优雅的思想,构成了一条统一的线索,贯穿于现代科学和工程领域惊人多样化的图景中,常常以不同的名称出现,但总是执行着相同的基本任务:为改进提供路线图。

现代机器学习的核心

反向模式微分最引人注目和广为人知的应用,可能是在机器学习领域,它在那里以​​反向传播​​(backpropagation)闻名。其核心是,机器学习模型的“学习”不过是一个系统性的纠错过程。想象一个简单的线性模型试图预测房价。它获取房屋的一些特征(面积、位置),将它们乘以一些权重,加上一个偏置,然后做出一个猜测。最初的猜测几乎肯定是错误的。猜测值与实际价格之间的差异就是误差。

关键问题是:这个误差应该归咎于谁?是“面积”的权重太高了?还是偏置项太低了?反向传播通过将误差视为一条信息来回答这个问题。它接收这个单一的误差值,并将其通过产生猜测的同一计算序列反向传播。在每一步,它都使用局部导数来确定计算中的每个参数应承担多少“责任”。对最终输出有较大影响的参数将承担较大份额的责任。这个“责任”正是误差对该参数的梯度。一旦每个参数知道了自己应承担的责任份额,它就可以朝着减小误差的方向稍微调整自己。用数千栋房屋重复这个过程数百万次,模型就学会了做出准确的预测。

这种方法的强大之处在于其可扩展性。无论我们的模型有两个参数,还是像现代大型语言模型那样有数万亿个参数,同样的逻辑都适用。一次反向传播责任的计算成本与一次前向传播做出预测的成本惊人地相似。这种效率是推动整个深度学习革命的引擎。这个思想远远超出了简单的回归;它被用来训练像高斯混合模型这样的复杂统计模型,以发现数据中的隐藏结构,其中需要优化的函数——对数似然——是许多数学运算的复杂组合。反向模式自动微分驯服了这种复杂性,将一项艰巨的分析任务变成了一项自动化的计算任务。

物理学家的秘密:伴随状态法

远在计算机科学家创造“反向传播”这个术语之前,物理学家、气象学家和工程师们就已经发现了完全相同的原理,并给它起了自己的名字:​​伴随状态法​​(adjoint-state method)。这揭示了一种美妙的思想趋同,不同的领域在处理类似的大规模优化问题时,独立地得出了同样优雅的解决方案。

考虑一下分子动力学的世界。一位化学家想要模拟一个蛋白质的折叠过程,这个过程由数千个原子之间的势能所支配。总势能是一个单一的标量值,但它取决于所有原子的 3N3N3N 个坐标。作用于每个原子上的力——其运动的真正驱动力——正是这个总能量相对于该原子坐标的负梯度。如何才能有效地计算这个梯度呢?这和之前是同样的问题:一个输出(能量),许多输入(坐标)。伴随方法,即我们伪装起来的反向模式微分,解决了这个问题。它在一次反向传播中计算出所有原子上的所有力,其计算成本与仅计算一次总能量相当。

在像地球物理反演这样的问题中,这种方法的力量变得真正令人惊叹。地球科学家想要创建一幅地球地下的地图——以寻找石油储量或理解断层线。他们通过测量来自地震或地表受控爆炸的地震波来实现这一目标。他们的“模型”是波动方程的计算机模拟,其参数是巨大三维网格中每一点的岩石地震波速,可能涉及数百万个参数。他们定义一个“失配”函数,一个衡量模拟地震波与真实世界测量值差异有多大的标量值。为了改进他们的地球地图,他们需要这个失配函数相对于所有那一百万个岩石波速参数的梯度。用任何其他方法计算这个梯度都是行不通的。

但使用伴随状态法,这就变得可行。计算过程非常优美:接收器处的失配值在模拟的地球中随时间向后传播。这个“伴随波”会重新聚焦到最可能导致失配的地下模型区域。这就好像一个人通过模拟的物理过程向后发送一个搜索查询,问:“对地球结构做出什么改变,才能让我的模拟更好地匹配现实?”

对隐式函数和求解器求导

到目前为止,我们的系统都是显式的:一系列直接从输入到输出的计算。但是当关系是隐式的,由一个必须被求解的方程定义时,会发生什么?例如,在一个电网中,每个节点的电压不是由一个简单的公式给出;它们是一个大型线性方程组 Gv=iGv = iGv=i 的解,该方程组平衡了整个网络中的电流流动。

假设我们想知道一个关键位置(如医院)的电压如何受到一个参数变化的影响,比如数英里外一条输电线电阻的变化。这是一个敏感度问题。我们可以用一个稍微扰动的电阻再次求解整个系统,但有一种更优雅的方法,它再次依赖于伴随思想。

核心思想是直接对求解器本身进行微分。在数学上,我们可以找到一个最终标量(如一个节点的电压)相对于线性系统矩阵中一个参数的敏感度,即 A(θ)x(θ)=bA(\theta)x(\theta) = bA(θ)x(θ)=b。反向模式方法并不计算电网中所有电压的变化。相反,它求解一个相关的线性系统,称为*伴随系统*。这个伴随系统的解,一个“伴随电压”向量,直接告诉我们目标量对系统中任何地方变化的敏感度。这是一种外科手术般精确的方法,让我们能够针对一个输出提出一个具体问题,并一次性获得所有输入的完整敏感度图。这一原理对于设计稳健的工程系统(从电网到飞机机翼)以及在金融模型中进行风险分析至关重要,在这些模型中,投资组合的价值是由复杂的、相互关联的模拟决定的。

新前沿:可微编程

反向模式自动微分是一种通用的算法微分工具,这一认识激发了一种令人兴奋的新范式:​​可微编程​​(differentiable programming)。其愿景是构建复杂的程序,而不仅仅是数学函数,并能够对它们进行端到端的微分。

一个惊人的例子是​​神经微分方程(Neural ODE)​​。建模动态系统(如系统生物学中的蛋白质相互作用)的科学家通常使用微分方程。神经微分方程用一个神经网络取代了微分方程内部手工设计的函数。它直接从数据中学习运动定律。你怎么可能训练这样的东西?你需要通过常微分方程(ODE)的解进行反向传播。如果天真地通过数值ODE求解器的所有微小步骤进行反向传播,将需要存储每一步的状态,导致巨大的内存成本。

解决方案再次是伴随敏感度方法。它通过求解第二个随时间反向的伴随ODE来计算梯度。令人震惊的结果是,内存成本是恒定的——它不依赖于求解器所采取的步数!这一突破使得在复杂和长时间运行的动态现象上训练优雅的连续时间模型成为可能,无缝地融合了深度学习和经典科学的世界。

这仅仅是个开始。可微编程的目标是使任何算法——从物理模拟器和图形渲染器到优化求解器——都成为一个更大的、可学习系统中的构建块。

一门实用的艺术

尽管这个“神谕”功能强大,但实现它是一门实用的艺术,而非黑魔法。在一个真实的科学代码中,比如一个大规模的有限元程序,需要进行权衡取舍。开发者可以​​手动编写​​伴随求解器。这通常能产生最佳性能和最低的内存占用,但在软件工程上是一项巨大的工程量——难以编写,更难调试,而且是维护的噩梦,因为对原始代码的任何更改都需要对伴随代码进行相应的更改。

或者,可以使用​​符号微分​​系统从高层数学方程自动生成导数代码。这极大地降低了维护负担,但可能会遭受“表达式膨胀”之苦,导致代码庞大、编译缓慢,并且无法处理程序中非符号表达的部分,比如对外部库的调用。

​​算法微分工具​​提供了第三种方式。它们直接对源代码进行操作,有望自动化整个过程。虽然它们相对于参数数量实现了同样低的成本,但它们通常带有自身的开销。“记录带”用于记录计算过程,可能会消耗大量内存,而减轻这种情况的策略,如检查点技术,则是以额外的计算来换取内存。

没有唯一的最佳答案。选择取决于问题、现有软件和开发者的资源。但可以肯定的是,其基本原理——反向的链式法则——是相同的。它证明了一个简单数学思想的力量。从一年级微积分课上学到的一个法则,涌现出一个足以撬动世界、教导机器、揭示自然规律并设计工程未来的计算杠杆。这是发现行为与计算行为之间深刻而美妙的联系。