Сердце языка Julia — множественная диспетчеризация

Ву-ху. Первый пост про язык Julia, и сразу в сердце... Рассказываю про core design фичу — multiple dispatch. Примеры, что это такое и как работает. По полочкам! А в послесловии дружеский подзатыльник Пайтону.

Система типов в Julia

Чтобы понять, как устроена диспетчеризация в Julia, надо познакомиться с системой типов.

В Julia типы организованы в иерархию типа дерево.
Any
└─ Number
   ├─ Complex
   │  ├─ Complex{Int64}
   │  ├─ Complex{Float64}
   │  └─ ...
   └─ Real
      ├─ AbstractFloat
      │  ├─ BigFloat
      │  ├─ Float16
      │  ├─ Float32
      │  └─ Float64
      ├─ ...
      ├─ Integer
      │  ├─ Bool
      │  ├─ Signed
      │  │  ├─ BigInt
      │  │  ├─ Int128
      │  │  ├─ Int16
      │  │  ├─ Int32
      │  │  ├─ Int64
      │  │  └─ Int8
      │  └─ Unsigned
      │     ├─ UInt128
      │     ├─ UInt16
      │     ├─ UInt32
      │     ├─ UInt64
      │     └─ UInt8
      └─ Rational
         ├─ Rational{Int32}
         ├─ Rational{Int64}
         └─ ...

Это часть дерева типов, она содержит (почти) все типы для чисел «из коробки». Например, цепочка от Int64 до Any такая.

julia> Int64 <: Signed <: Integer <: Real <: Number <: Any
true

julia> supertypes(Int64)
(Int64, Signed, Integer, Real, Number, Any)

Корень дерева это тип Any. А все типы в дереве делятся на два вида: абстрактный (abstract) и конкретный (concrete).

Про конкретные типы компилятору известно всё, включая их устройство в памяти. С ними можно считать в рантайме (runtime; по дефолту Julia компилируется just-in-time). И они являются листьями в дереве типов.

Абстрактные типы используются для упорядочивания типов в дерево. Отношение между типами выстраивается как «является подтипом такого-то типа», но Julia не объектно-ориентированный язык, здесь нет привычного для ООП наследования (inheritance). Например, нельзя со 100% уверенностью написать функцию для чисел Number и быть уверенным, что она будет работать для каждого подтипа.

Однако, создавать функции для абстрактных типов данных можно и полезно. На этом строится поддержка обобщённой (generic) парадигмы. Когда компилятор встречает вызов такой функции, он проверяет, все ли есть, чтобы её выполнить для конкретных аргументов. Если чего-то не хватает, увы, ошибка. А если всё есть, то компилятор создаёт настолько оптимизированный код, насколько может.

Диспетчеризация

Я начну с примера. В языке Julia можно так.

julia> f(x, y) = "default";

julia> f(x::T, y::T) where {T} = "default when x and y have same type";

julia> f(x::Int, y::Int) = "x is Int, y is Int";

julia> f(x::Int, y::Float64) = "x is Int, y is Float64";

julia> f(x, y, z) = "oh my, there are x, y and even z!";

julia> f("a", 1)
"default"

julia> f("a", "b")
"default when x and y have same type"

julia> f(1, 2)
"x is Int, y is Int"

julia> f(1, 2.0)
"x is Int, y is Float64"

julia> f(1, 2, 3)
"oh my, there are x, y and even z!"

Выше определена одна функция f.

julia> f
f (generic function with 5 methods)

И пять методов для неё.

julia> methods(f)
 [1] f(x::Int64, y::Float64)
 [2] f(x::Int64, y::Int64)
 [3] f(x, y, z)
 [4] f(x::T, y::T) where T
 [5] f(x, y)
Julia не объектно-ориентированный язык: здесь методы принадлежат не объектам, а функциям.

Когда функция вызывается, то диспетчер просматривает, какие есть методы у функции и выбирает тот, который лучшим образом подходит. «Выбрать лучшим образом» значит выбрать метод, типы аргументов которого наиболее «близки» к типам передаваемых аргументов. В случае вызова f(1, 2) подходят сразу три метода:

  • f(x, y),
  • f(x::T, y::T) where {T},
  • f(x::Int64, y::Int64).

Но вызывается последний, потому что типы аргументов (числа 1 и 2) наиболее близки к паре (Int64, Int64) (вообще, они совпадают с ними). К тому, что значит «близость» я вернусь позднее.

Итак, диспетчеризация это процесс выбора метода для конкретного вызова функции.

А что значит «множественная диспетчеризация» (multiple dispatch)?

Множественная диспетчеризация это вид диспетчеризации, который учитывает типы нескольких аргументов.

В случае Julia учитываются типы всех позиционных аргументов (positional arguments), а вот диспетчеризацию для аргументов по ключу (keyword arguments) не завезли.

Диспатч в Python

В Python тоже есть диспетчеризация, но она одинарная (single dispatch). В Python класс это пространство имён, а когда происходит вызов, например, x + y, то под капотом интерпретатор делает примерно следующее.

x + y
x.__add__(y)
type(x).__add__(x, y)

Где type(x) превращается в int, list или чем там x является в рантайме. Это и есть диспетчеризация, но по типу только первого аргумента. (Если вы раньше не знали, зачем писать в Python методах self, то теперь видите? 😏)

Что ещё можно почерпнуть из примеров выше?

Декларация не нужна

Не обязательно декларировать типы аргументов функции.

julia> f(x, y) = "default";

На самом деле, декларация здесь есть, но неявная. Этот пример эквивалентен такому.

julia> f(x::Any, y::Any) = "default";

Диспатч на абстрактных типах

Можно диспетчеризовываться не только на конкретных, но и на абстрактных типах.

julia> f(x, y) = "default";  # f(x::Any, y::Any)

julia> f(x::Int, y::Int) = "x is Int, y is Int";

julia> f(x::Real, y::Real) = "x and y are real numbers";
Чуть-чуть про компиляцию

Когда происходит вызов функции и метод выбран, компилятор создаёт машинный код для него. Повторная компиляция для вызова функции от тех же типов больше не требуется, можно сразу использовать машинный код.

Если при этом тип возвращаемого значения не зависит от значений аргументов (а определяется только их типами), то машинный код будет эффективным. Функции, которые написаны таким образом, называются стабильными по типу (type stable). Если ситуация обратная, то быстродействие падает примерно до уровня Python, потому что значения боксятся, и их тип приходится проверять в рантайме.

Типичный пример — квадратный корень sqrt(x). Для неотрицательных действительных чисел он всегда возвращает float-число. А вот отрицательные нужно обернуть в комплексное число. Будь иначе, стабильность бы была потеряна, ведь тогда sqrt(1) возвращал бы float, а sqrt(-1) — complex, в то время как аргумент в обоих случаях имеет тип Int.

Диспатч на количестве аргументов

Методы могут иметь разное количество аргументов, а диспетчеризация это учитывает.

julia> f(x, y) = "default";

julia> f(x, y, z) = "oh my, there are x, y and even z!";

Generic programming included

Julia поддерживает обобщённую (generic) парадигму.

julia> f(x::T, y::T) where {T} = "default when x and y have same type";

Здесь в одной строчке определяется семейство методов, у которых два аргумента, имеющих одинаковый тип (и он назван T). То есть метод подходит для (Int, Int), (Float64, Float64), (String, String) и так далее.

Кстати, тип аргументов доступен как в рантайме так и во время компиляции.

julia> atruntime(x) = typeof(x);

julia> atcompiletime(x::T) where {T} = T;

julia> atruntime(π)
Irrational{:π}

julia> atcompiletime(π)
Irrational{:π}

Знать типы во время компиляции бывает необходимо, чтобы дешёво извлечь информацию об аргументах. Например, так можно узнать размерность массива ndims(x) или тип его элементов eltype(x), поэтому что информация об этом зашита в типе.

Как выбирается метод?

Перейдём к тому, как выбирается метод среди нескольких.

Пример попроще

Пусть у нас есть функция с одним аргументом и несколькими методами.

  • g(x)
  • g(x::Number)
  • g(x::Float64)

Если вызвать функцию от целого числа g(1), то применится метод g(::Number). Логика следующая.

  • Подходят два метода: g(::Any) и g(::Number).
  • Тип Int64 в дереве ближе к Number, чем к Any.
  • Поэтому берём его.

Пример посложнее

А что если аргументов несколько?

julia> f(x, y) = "default";

julia> f(x::Number, y::Number) = "Number & Number";

julia> f(x::Int, y::Number) = "Int & Number";

julia> f(1, "a")
"default"

julia> f(1, 2)
"Int & Number"

julia> f(1, 1.5)
"Int & Number"

julia> f(1.5, 1.5)
"Number & Number"

Здесь у функции три метода.

С первым вызовом f(1, "a") всё понятно, второй аргумент это строка, и только один метод подходит из трёх.

Для оставшихся случаев давайте обратимся к дереву типов. Я намеренно оставил только те типы, которые участвуют в оставшихся случаях: Any, Number, Int64 (он же Int) и Float64.

Any
└─ Number
   └─ Real
      ├─ AbstractFloat
      │  └─ Float64
      └─ Integer
         └─ Signed
            └─ Int64 (это Int на моей машине)

Для вызова f(1, 2) подходят все три метода. Давайте посмотрим насколько тип аргумента при вызове x = 1::Int «удалён» от декларируемого типа первого аргумента в каждом методе.

  • Для метода f(x, y) декларируемый тип x-а это тип Any. Расстояние от Int64 до Any в дереве типов равняется пяти.
  • Для метода f(x::Number, y::Number) подсчёт даёт четыре (расстояние между Int64 и Number).
  • Для метода f(x::Int, y::Number) подсчёт даёт ноль.

Повторим то же самое для второго аргумента y. Получим расстояния 5, 4 и 4, соответственно.

Сложим теперь эти расстояния.

  • У f(x, y) расстояние \(10 = 5 + 5\).
  • У f(x::Number, y::Number) расстояние \(8 = 4 + 4\).
  • У f(x::Int, y::Number) расстояние \(4 = 0 + 4\).

В итоге получаем, что расстояние от третьего метода f(x::Int, y::Number) до аргументов (1::Int, 2::Int) самое маленькое. Поэтому он и выбирается. В этом смысле (точнее, метрике) он самый близкий.

Аналогично получается для вызова f(1, 1.5), но теперь тип второго аргумента это Float64.

  • У f(x, y) расстояние \(9 = 5 + 4\).
  • У f(x::Number, y::Number) расстояние \(7 = 4 + 3\).
  • У f(x::Int, y::Number) расстояние \(3 = 0 + 3\). Он и выбирается.

Для последнего вызова f(1.5, 1.5) оба аргумента имеют тип Float64, и подходят только два метода.

  • У f(x, y) расстояние \(8 = 4 + 4\).
  • У f(x::Number, y::Number) расстояние \(6 = 3 + 3\).

Поэтому выбирается метод f(x::Number, y::Number).

Общий случай (почти)

Если эти примеры обобщить, то мы имеем дело с пространством, состоящем из типов. Расстояние между двумя типами измеряется по дереву с учётом того, что можно двигаться только от листьев к корню. Если достичь одного типа из другого так нельзя, то расстояние бесконечное. (Такая вот топология получается.)

Я выражу это языком... Julia.

julia> function ρ(x, y)
           indx = findfirst(==(x), supertypes(y))
           !isnothing(indx) && return indx - 1

           indy = findfirst(==(y), supertypes(x))
           !isnothing(indy) && return indy - 1

           return -1
       end
ρ (generic function with 1 method)

julia> ρ(Float64, Any)
4

julia> ρ(Float64, Number)
3

Здесь сначала пробуем найти «где тип x среди супертипов игрека». Если нашли, значит x находится в ветви от y до Any, тогда и возвращаем позицию x минус 1, это и будет расстоянием от y до x. А если не нашли, пробуем искать y в ветви от x до Any. Если и так не нашли, возвращаем -1, как признак бесконечности. Примеры я привёл те, что считал вручную для f(1.5, 1.5).

С одним аргументом разобрались, разберёмся с несколькими. Несколько аргументов я буду подавать в виде кортежей (tuple).

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
julia> function ρ(x::Tuple, y::Tuple)
           length(x) != length(y) && return -1

           dists = ρ.(x, y)

           -1 in dists && return -1

           return sum(dists)
       end
ρ (generic function with 2 methods)

julia> ρ((Any, Any), (Int, Int))
10

julia> ρ((Number, Number), (Int, Int))
8

julia> ρ((Int, Number), (Int, Int))
4
  • Строка 2 обрабатывает случай вообще разных пространств. Это когда в методе и вызове не совпадает число аргументов.
  • Строка 4 считает расстояние по отдельности. Тут используется broadcast.
  • В строке 6 происходит «если расстояние между какими-то типами бесконечное, то и всё расстояние бесконечное».
  • В строке 8 формула метрики для небесконечного случая. В итоге она похожа на 1-норму \(\sum |x_{i} - y_{i}|\).
  • В строках 12, 15 и 18 то, что мы считали вручную для f(1, 2).

Вот так вот и определяется «близость» типов. А точнее, близость сигнатуры вызова к сигнатуре метода.

Подытожим

Выбор метода при вызове функции осуществляется через просмотр дерева типов. Диспетчер проверяет число аргументов, и если оно правильное, то ищет подходящие методы. Если таких методов несколько, то выбирается тот, чья сигнатура ближе к сигнатуре вызова. Близость определяется расстоянием между типами в дереве.

За кадром остаётся много технических деталей. Я не разработчик языка Julia и не знаю, как они имплементированы. На моей практике те примеры и эвристики, что я привёл, покрывают 80% работы при написании кода на Julia. В оставшихся 20% то, с чем я не сталкивался, то, что становится интуитивно понятным с опытом, и то, что приходится гуглить (обычно на форуме, коммунити супер).

Если понравилось, присоединяйтесь к каналу в телеге. А ещё можете финансово поддержать выпуск новых материалов по Julia или вообще.

Напоследок я оставлю пример, который люблю.


P.S. Пример обобщённого программирования в Julia

В Julia очень много маленьких функций, которые помогают писать generic код. Например, самописный generic сумматор выглядит так.

julia> function mysum(x)
           acc = zero(eltype(x))
           for xi in x
               acc += xi
           end
           return acc
       end;

julia> mysum([1, 2, 3])
6

julia> mysum(1:10)
55

Клёво, да? Первый пример для массива из трёх элементов. А во втором 1:10 это арифметическая прогрессия от 1 до 10 с единичным шагом.

Так, стоп. А зачем нам суммировать все элементы прогрессии, если можно найти сумму за \(O(1)\) так

\begin{equation*} \frac{(x_{1} + x_{n}) \times n}{2} \end{equation*}

Исправляем!

julia> mysum(x::AbstractRange{<:Number}) = (first(x) + last(x)) * length(x) / 2;

julia> mysum(1:10)
55.0

Почти готово, но есть косяк — сумма прогрессии целых чисел это всегда целое число, не дробное. Исправим, накинув ещё один метод!

julia> mysum(x::AbstractRange{<:Integer}) = (first(x) + last(x)) * length(x) ÷ 2;

julia> mysum(1:10)
55

Вот теперь хорошо. В дальнейшем можно накинуть низкоуровневых оптимизаций: накидать потоков и SIMD инструкций (если компилятор не сделает это за нас), но щас не об этом.

Итого, сумма элементов массива, как ей и положено, считается за \(O(n)\), а сумма арифметической прогрессии за \(O(1)\).

julia> using BenchmarkTools

julia> @btime mysum(1:10^6);
  0.791 ns (0 allocations: 0 bytes)

julia> @btime mysum(1:10^8);
  0.791 ns (0 allocations: 0 bytes)

А Python так может?

К сожалению, нет. В Python отсутствует magick метод __sum__.

In [1]: %timeit sum(range(1, 10**6 + 1))
9.52 ms ± 36.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [2]: %timeit sum(range(1, 10**8 + 1))
951 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Плак-плак. Но оно может Пайтону и не надо.

На этом послесловие всё.