8.5. Автоматическое дифференцирование назад#
Автоматическое дифференцирование назад (обратное, backward mode, reverse accumulation) это вид автоматического дифференцирования, в котором вычисления производных распространяются от результата функции к её аргументам. С историей открытия и авторства этого метода дифференцирования можно ознакомиться в работе [Gri12].
Мы приведём общий случай позже, а сейчас продемонстрируем алгоритм на примере функции из Раздела 8.3.
В отличие от дифференцирования вперёд, поставим задачу вычисления не одной, а сразу всех частных производных функции \(f\) по аргументам (т.е. найдём градиент)
Будем считать функцию \(f\) сложной (композитной) функцией промежуточных значений \(x_j\) (\(j = 1, \dots, 7\)) и воспользуемся правилом дифференцирования сложной функции. Чтобы найти градиент \(\nabla f\), нам понадобится найти все частные производные вида
при \(j = 1, 2, 3\) искомые производные составляют градиент.
Шаг 0. Совершим обход графа вычислений (Рисунок 8.8) для нахождения (только) промежуточных значений
На Рисунке 8.9 показан результат этого обхода при вычислении \(f(1, 2, 3)\).
Запомним промежуточные значения \(x_j\) и запомним граф вычислений, он понадобится далее. На этом первый этап вычислений завершается.
Предупреждение
В дальнейшем мы будем «переопределять» функцию \(f\) для упрощения нотации. Более корректным было бы вводить новую функцию (со своим именем) на каждом шаге. Например, при вычислении \(\partial f / \partial x_3\) на шаге 5 было бы корректней определить функцию \(g\)
производная которой совпадает с производной \(\partial f / \partial x_3\). Вместо записи выше мы используем запись
Шаг 1. Мы будем искать производные «справа-налево». Произведём инициализацию алгоритма и вычислим одну (тривиальную) производную
Шаг 2. Найдём производную \(\partial f / \partial x_6\). Значение \(x_6\) явно влияет только на \(x_7 = x_6 / x_3\), поэтому (для вычисления производной) посчитаем, что \(f\) это сложная функция вида
тогда
Здесь производная \(\partial f / \partial x_7\) нам известна с шага 1, а производная \(\partial x_7 / \partial x_6\) определяется из явной связи \(x_7 = x_6 / x_3\). Отметим, что эта явная связь нам известна из графа вычислений, сохранённом на шаге 0.
Шаг 3. Найдём производную \(\partial f / \partial x_5\). Значение \(x_5\) явно влияет только на \(x_6\), поэтому представим \(f\) в виде
тогда
Значение производной \(\partial f / \partial x_6\) нам известно с шага 2. В свою очередь, связь \(x_6 = x_4 + x_5\) хранится в графе вычислений, что позволяет вычислить производную \(\partial x_6 / \partial x_5 = 1\).
Шаг 4. Найдём производную \(\partial f / \partial x_4\). Значение \(x_4\) явно влияет только на \(x_6 = x_4 + x_5\), и этот шаг аналогичен шагу 3.
Шаг 5. Найдём производную \(\partial f / \partial x_3\). Значение \(x_3\) явно влияет на \(x_5\) и \(x_7\), поэтому для вычисления производной представим \(f\) в виде
тогда производная имеет вид
Здесь мы воспользовались правилом дифференцирования сложной функции. Значения производных \(\partial f / \partial x_5\) и \(\partial f / \partial x_7\) известны с шагов 3 и 1, соответственно. Для нахождения производных \(\partial x_5 / \partial x_3\) и \(\partial x_7 / \partial x_3\) мы пользуемся графом вычислений, в котором хранятся явные связи \(x_5 = \sin{x_3}\) и \(x_7 = x_6 / x_3\). Значение \(x_6\) известно с шага 0.
Шаг 6. Найдём производную \(\partial f / \partial x_2\). Значение \(x_2\) явно влияет только на \(x_4\), поэтому для нахождения производной представим \(f\) в виде
тогда производная имеет вид
Значение производной \(\partial f / \partial x_4\) известно с шага 4, а связь \(x_4 = x_1 x_2\) хранится в графе вычислений, что позволяет посчитать \(\partial x_4 / \partial x_2 = x_1\).
Шаг 7. Последний шаг, вычислим производную \(\partial f / \partial x_1\). Значение \(x_1\) влияет только на \(x_4\) и этот шаг аналогичен шагу 6.
Итак, задача вычисления градиента в некоторой (одной) точке решена. Отметим, что она решалась в два прохода по графу вычислений: прямому и обратному. При обратном проходе мы интенсивно пользовались только одним правилом — правилом дифференцирования сложной функции.
8.5.1. Общий случай#
В общем случае для вычисления градиента функции вида \(f: \real^n \to \real\) ставится вычислить производные вида
где \(x_j\) это промежуточные значения (вершины графа вычислений). Первые \(n\) штук \(x_j\) совпадают с аргументами функции \(f\), а последний \(x_V \equiv f\).
Для вычисления одной производной \(\partial f / \partial x_j\) функция \(f\) представляется в виде
где \(\{ x_k(x_j) \}_{k \in K}\) это вершины, зависящие явно от \(x_j\) (а \(K\) множество индексов таких вершин). На графе вычислений \(x_k\) соответствуют тем вершинам, у которых есть сток (ребро) из \(x_j\) в \(x_k\) (см. Рисунок 8.10).
В этом представлении производная \(\partial f / \partial x_j\) вычисляется следующим образом
Производные \(\partial f / \partial x_k\) известны с предыдущих шагов обратного прохода по графу. В свою очередь, производные \(\partial x_k / \partial x_j\) вычисляются на шаге по явным связям \(x_k(x_j)\), хранящимся в графе вычислений.
8.5.2. Быстродействие и применимость#
Рассмотрим два крайних случая.
В задаче вычисления градиента функции вида \(f: \real^n \to \real\) (8.8) автоматическому дифференцированию назад требуется два прохода. По этой причине дифференцирование назад предпочтительней дифференцирования вперёд для вычисления градиента.
Напротив, в задаче вычисления производных вида \(\partial f_i / \partial x\) (строки матрицы Якоби) для функции \(f: \real \to \real^m\) автоматическому дифференцированию назад требуется уже \(m + 1\) проходов. Один прямой проход и по проходу на каждую производную.
В общем случае функции вида \(f: \real^n \to \real^m\) автоматическое дифференцирование назад эффективно для вычисления матрицы Якоби при \(n \gg m\).
Кроме того, дифференцированию назад требуется хранить граф вычислений, поэтому возникает требование к памяти. Остро это требование проявляется в задачах с большими графами вычислений, например, при минимизации функции ошибки в машинном обучении.
Примечание
Воспользоваться автоматическим дифференцированием назад в Julia можно, например, с помощью пакетов Zygote.jl [Inn18] или ReverseDiff.jl.