8.5. Автоматическое дифференцирование назад#

Автоматическое дифференцирование назад (обратное, backward mode, reverse accumulation) это вид автоматического дифференцирования, в котором вычисления производных распространяются от результата функции к её аргументам. С историей открытия и авторства этого метода дифференцирования можно ознакомиться в работе [Gri12].

Мы приведём общий случай позже, а сейчас продемонстрируем алгоритм на примере функции из Раздела 8.3.

\[f(x_1, x_2, x_3) = \frac{x_1 x_2 + \sin{x_3}}{x_3}\]
../_images/graph-example.svg

Рис. 8.8 Граф вычислений функции \(f(x_1, x_2, x_3) = (x_1 x_2 + \sin{x_3})/x_3\).#

В отличие от дифференцирования вперёд, поставим задачу вычисления не одной, а сразу всех частных производных функции \(f\) по аргументам (т.е. найдём градиент)

\[\begin{split}\nabla f (x_1, x_2, x_3) = \begin{bmatrix} \partial f / \partial x_1 \\ \partial f / \partial x_2 \\ \partial f / \partial x_3 \end{bmatrix}.\end{split}\]

Будем считать функцию \(f\) сложной (композитной) функцией промежуточных значений \(x_j\) (\(j = 1, \dots, 7\)) и воспользуемся правилом дифференцирования сложной функции. Чтобы найти градиент \(\nabla f\), нам понадобится найти все частные производные вида

\[\frac{\partial f}{\partial x_j}, \quad j = 1, \dots, 7,\]

при \(j = 1, 2, 3\) искомые производные составляют градиент.

Шаг 0. Совершим обход графа вычислений (Рисунок 8.8) для нахождения (только) промежуточных значений

\[x_1, \dots, x_7.\]

На Рисунке 8.9 показан результат этого обхода при вычислении \(f(1, 2, 3)\).

../_images/graph-backward-first-traverse.svg

Рис. 8.9 Прямой обход графа вычислений \(f(1, 2, 3)\) для нахождения промежуточных значений (подписаны рядом с вершинами).#

Запомним промежуточные значения \(x_j\) и запомним граф вычислений, он понадобится далее. На этом первый этап вычислений завершается.

Предупреждение

В дальнейшем мы будем «переопределять» функцию \(f\) для упрощения нотации. Более корректным было бы вводить новую функцию (со своим именем) на каждом шаге. Например, при вычислении \(\partial f / \partial x_3\) на шаге 5 было бы корректней определить функцию \(g\)

\[g = g(x_5(x_3), x_7(x_3)) = g(x_3),\]

производная которой совпадает с производной \(\partial f / \partial x_3\). Вместо записи выше мы используем запись

\[f = f(x_5(x_3), x_7(x_3)).\]

Шаг 1. Мы будем искать производные «справа-налево». Произведём инициализацию алгоритма и вычислим одну (тривиальную) производную

\[\frac{\partial f}{\partial x_7} = 1. \quad \color{gray} \bigg\rvert \ \frac{\partial f}{\partial x_7}(1, 2, 3) = 1.\]

Шаг 2. Найдём производную \(\partial f / \partial x_6\). Значение \(x_6\) явно влияет только на \(x_7 = x_6 / x_3\), поэтому (для вычисления производной) посчитаем, что \(f\) это сложная функция вида

\[f = f(x_7(x_6)),\]

тогда

\[\frac{\partial f}{\partial x_6} = \frac{\partial f}{\partial x_7} \frac{\partial x_7}{\partial x_6} = 1 \times \frac{1}{x_3} = \frac{1}{x_3} . \quad \color{gray} \bigg\rvert \ \frac{\partial f}{\partial x_6}(1, 2, 3) = \frac{1}{3}.\]

Здесь производная \(\partial f / \partial x_7\) нам известна с шага 1, а производная \(\partial x_7 / \partial x_6\) определяется из явной связи \(x_7 = x_6 / x_3\). Отметим, что эта явная связь нам известна из графа вычислений, сохранённом на шаге 0.

Шаг 3. Найдём производную \(\partial f / \partial x_5\). Значение \(x_5\) явно влияет только на \(x_6\), поэтому представим \(f\) в виде

\[f = f(x_6(x_5)),\]

тогда

\[\frac{\partial f}{\partial x_5} = \frac{\partial f}{\partial x_6} \frac{\partial x_6}{\partial x_5} = \frac{1}{x_3} \times 1 = \frac{1}{x_3}. \quad \color{gray} \bigg\rvert \ \frac{\partial f}{\partial x_5}(1, 2, 3) = \frac{1}{3}.\]

Значение производной \(\partial f / \partial x_6\) нам известно с шага 2. В свою очередь, связь \(x_6 = x_4 + x_5\) хранится в графе вычислений, что позволяет вычислить производную \(\partial x_6 / \partial x_5 = 1\).

Шаг 4. Найдём производную \(\partial f / \partial x_4\). Значение \(x_4\) явно влияет только на \(x_6 = x_4 + x_5\), и этот шаг аналогичен шагу 3.

\[\begin{split}\begin{align} f &= f(x_6(x_4)) \\ \frac{\partial f}{\partial x_4} &= \frac{\partial f}{\partial x_6} \frac{\partial x_6}{\partial x_4} = \frac{1}{x_3} \times 1 = \frac{1}{x_3}. \quad \color{gray} \bigg\rvert \ \frac{\partial f}{\partial x_4}(1, 2, 3) = \frac{1}{3}. \end{align}\end{split}\]

Шаг 5. Найдём производную \(\partial f / \partial x_3\). Значение \(x_3\) явно влияет на \(x_5\) и \(x_7\), поэтому для вычисления производной представим \(f\) в виде

\[f = f(x_5(x_3), x_7(x_3)),\]

тогда производная имеет вид

\[\begin{split}\frac{\partial f}{\partial x_3} = \frac{\partial f}{\partial x_5} \frac{\partial x_5}{\partial x_3} + \frac{\partial f}{\partial x_7} \frac{\partial x_7}{\partial x_3} = \frac{1}{x_3} \times \cos{x_3} + 1 \times \bigg[ - \frac{x_6}{x^2_3} \bigg] = \frac{\cos{x_3}}{x_3} - \frac{x_6}{x^2_3}. \\ \color{gray} \bigg\rvert \ \frac{\partial f}{\partial x_3}(1, 2, 3) = \frac{\cos{3}}{3} - \frac{2 + \sin{3}}{9}.\end{split}\]

Здесь мы воспользовались правилом дифференцирования сложной функции. Значения производных \(\partial f / \partial x_5\) и \(\partial f / \partial x_7\) известны с шагов 3 и 1, соответственно. Для нахождения производных \(\partial x_5 / \partial x_3\) и \(\partial x_7 / \partial x_3\) мы пользуемся графом вычислений, в котором хранятся явные связи \(x_5 = \sin{x_3}\) и \(x_7 = x_6 / x_3\). Значение \(x_6\) известно с шага 0.

Шаг 6. Найдём производную \(\partial f / \partial x_2\). Значение \(x_2\) явно влияет только на \(x_4\), поэтому для нахождения производной представим \(f\) в виде

\[f = f(x_4(x_2)),\]

тогда производная имеет вид

\[\frac{\partial f}{\partial x_2} = \frac{\partial f}{\partial x_4} \frac{\partial x_4}{\partial x_2} = \frac{1}{x_3} \times x_1 = \frac{x_1}{x_3}. \quad \color{gray} \bigg\rvert \ \frac{\partial f}{\partial x_2}(1, 2, 3) = \frac{1}{3}\]

Значение производной \(\partial f / \partial x_4\) известно с шага 4, а связь \(x_4 = x_1 x_2\) хранится в графе вычислений, что позволяет посчитать \(\partial x_4 / \partial x_2 = x_1\).

Шаг 7. Последний шаг, вычислим производную \(\partial f / \partial x_1\). Значение \(x_1\) влияет только на \(x_4\) и этот шаг аналогичен шагу 6.

\[\begin{split}\begin{align} f &= f(x_4(x_1)) \\ \frac{\partial f}{\partial x_1} &= \frac{\partial f}{\partial x_4} \frac{\partial x_4}{\partial x_1} = \frac{1}{x_3} \times x_2 = \frac{x_2}{x_3}. \quad \color{gray} \bigg\rvert \ \frac{\partial f}{\partial x_1}(1, 2, 3) = \frac{2}{3} \end{align}\end{split}\]

Итак, задача вычисления градиента в некоторой (одной) точке решена. Отметим, что она решалась в два прохода по графу вычислений: прямому и обратному. При обратном проходе мы интенсивно пользовались только одним правилом — правилом дифференцирования сложной функции.

8.5.1. Общий случай#

В общем случае для вычисления градиента функции вида \(f: \real^n \to \real\) ставится вычислить производные вида

(8.8)#\[\frac{\partial f}{\partial x_j}, \quad j = 1, \dots, V,\ V \ge n,\]

где \(x_j\) это промежуточные значения (вершины графа вычислений). Первые \(n\) штук \(x_j\) совпадают с аргументами функции \(f\), а последний \(x_V \equiv f\).

Для вычисления одной производной \(\partial f / \partial x_j\) функция \(f\) представляется в виде

\[f = f(\{ x_k(x_j) \}_{k \in K}),\]

где \(\{ x_k(x_j) \}_{k \in K}\) это вершины, зависящие явно от \(x_j\)\(K\) множество индексов таких вершин). На графе вычислений \(x_k\) соответствуют тем вершинам, у которых есть сток (ребро) из \(x_j\) в \(x_k\) (см. Рисунок 8.10).

../_images/graph-backward-general.svg

Рис. 8.10 Общая схема графа вычислений производной \(\partial f / \partial x_j\) в автоматическом дифференцировании назад. Явные связи между вершинами (т.е. дуги графа) показаны прямыми стрелками, а неявные — волнистыми.#

В этом представлении производная \(\partial f / \partial x_j\) вычисляется следующим образом

\[\frac{\partial f}{\partial x_j} = \sum_{k \in K} \frac{\partial f}{\partial x_k} \frac{\partial x_k}{\partial x_j}.\]

Производные \(\partial f / \partial x_k\) известны с предыдущих шагов обратного прохода по графу. В свою очередь, производные \(\partial x_k / \partial x_j\) вычисляются на шаге по явным связям \(x_k(x_j)\), хранящимся в графе вычислений.

8.5.2. Быстродействие и применимость#

Рассмотрим два крайних случая.

В задаче вычисления градиента функции вида \(f: \real^n \to \real\) (8.8) автоматическому дифференцированию назад требуется два прохода. По этой причине дифференцирование назад предпочтительней дифференцирования вперёд для вычисления градиента.

Напротив, в задаче вычисления производных вида \(\partial f_i / \partial x\) (строки матрицы Якоби) для функции \(f: \real \to \real^m\) автоматическому дифференцированию назад требуется уже \(m + 1\) проходов. Один прямой проход и по проходу на каждую производную.

В общем случае функции вида \(f: \real^n \to \real^m\) автоматическое дифференцирование назад эффективно для вычисления матрицы Якоби при \(n \gg m\).

Кроме того, дифференцированию назад требуется хранить граф вычислений, поэтому возникает требование к памяти. Остро это требование проявляется в задачах с большими графами вычислений, например, при минимизации функции ошибки в машинном обучении.

Примечание

Воспользоваться автоматическим дифференцированием назад в Julia можно, например, с помощью пакетов Zygote.jl [Inn18] или ReverseDiff.jl.