numpy.take_along_axis

numpy.take_along_axis(arr, indices, axis)

Функция take_along_axis() сопоставляет одномерные массивы индексов с соответствующими полными срезами исходного массива вдоль указанной оси и возвращает найденные элементы. Доступно в NumPy начиная с версии 1.15.0.

В общем, на деле все гораздо проще, так что если ничего не понятно, то лучше сразу перейдите к примерам.

Параметры:
arr - массив NumPy или массивоподобный объект.
Исходный массив.
indices - массив NumPy (необязательный параметр).
Массив индексов, который должен быть либо транслируемым по массиву arr либо содержать столькоже одномерных массивов сколько их в индексируемом массиве вдоль указанной в параметре axis оси.
axis - целое число (необязательный параметр).
Определяет ось вдоль которой извлекаются элементы с указанными в одномерных массивах индексами. По умолчанию axis = None, что соответствует извелечению элементов из сжатого до одной оси представления массива a.
Возвращает:
ndarray - массив NumPy
массив элементов исходного массива выбранных в соответствии с индексами одномерных массивов из полного среза вдоль указанной оси исходного массива.
Смотрите так же:
take, put_along_axis, compress


Примеры

Что бы разобраться, давайте создадим небольшой квадратный массив и поэкспериментируем с ним:

>>> import numpy as np
>>> 
>>> a = np.arange(16).reshape(4, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

Теперь придумаем два массива индексов той же размерности но разной формы:

>>> ind_r = np.array([[3, 0, 2, 1]])
>>> ind_r
array([[3, 0, 2, 1]])
>>> 
>>> ind_c = np.array([[3], [0], [2], [1]])
>>> ind_c
array([[3],
       [0],
       [2],
       [1]])

А теперь посмотрим на то, что нам выдаст функция take_along_axis(), для каждой оси исходного массива:

>>> np.take_along_axis(a, ind_r, axis = 0)
array([[12,  1, 10,  7]])

Все довольно просто - из каждого столбца вытащен элемент с соответствующим индексом из ind_r. Посмотрим что будет с тем же массивом индексов, но для другой оси:

>>> np.take_along_axis(a, ind_r, axis = 1)
array([[ 3,  0,  2,  1],
       [ 7,  4,  6,  5],
       [11,  8, 10,  9],
       [15, 12, 14, 13]])

Теперь мы получили массив той же формы что и исходный, но с столбцами переставленными в соответствии с индексами из ind_r.

Теперь сделаем все тоже самое для ind_c:

>>> np.take_along_axis(a, ind_c, axis = 0)    #  переставленные строки
array([[12, 13, 14, 15],
       [ 0,  1,  2,  3],
       [ 8,  9, 10, 11],
       [ 4,  5,  6,  7]])
>>> 
>>> np.take_along_axis(a, ind_c, axis = 1)    #  соответствующий элемент из каждой строки
array([[ 3],
       [ 4],
       [10],
       [13]])

Теперь может возникнуть вопрос: "На кой ляд все это надо?". Отвечаю: "Сделать все тоже самое простыми и очевидными способами не получится." К тому же это очень удобно, когда дело касается отображений элементов одного массива на другой. Например у нас есть два массива. Один массив будет имитировать какие-то случайные данные:

>>> a = np.random.randint(0, 20, size = (4, 4))
>>> a
array([[ 0,  1,  5,  0],
       [19,  8, 14, 12],
       [ 4, 15,  2, 12],
       [10,  7, 11, 14]])

Второй массив, будет трехмерным и в нем каждому элементу из массива a будет соответствовать какой-то набор значений:

>>> b = np.arange(3*4*4).reshape(3, 4, 4)
>>> b
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]],

       [[16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [28, 29, 30, 31]],

       [[32, 33, 34, 35],
        [36, 37, 38, 39],
        [40, 41, 42, 43],
        [44, 45, 46, 47]]])

Так элементу a[0, 0] будут соответствовать элементы из b[:, 0, 0] (т.е. 0, 16, 32). Теперь, допустим, мы хотим найти максимальный элемент в каждой строке массива a:

>>> a
array([[ 0,  1,  5,  0],
       [19,  8, 14, 12],
       [ 4, 15,  2, 12],
       [10,  7, 11, 14]])
>>> 
>>> max_a_r = np.argmax(a, axis = 1)
>>> max_a_r
array([2, 0, 1, 3], dtype=int32)
>>> 
>>> a[np.r_[:4], max_a_r].reshape(4, 1)
array([[ 5],
       [19],
       [15],
       [14]])

Обратите внимание на строчку a[np.r_[:4], max_a_r].reshape(4, 1), которую мы могли бы заменить на:

>>> np.take_along_axis(a, np.expand_dims(max_a_r, axis = 1), axis = 1)
array([[ 5],
       [19],
       [15],
       [14]])

Команда np.expand_dims(max_a_r, axis = 1) просто добавляет новое измерение справа:

>>> max_a_r
array([2, 0, 1, 3], dtype=int32)
>>> 
>>> np.expand_dims(max_a_r, axis = 1)
array([[2],
       [0],
       [1],
       [3]], dtype=int32)

Ну и наконец-то мы можем заняться нашим отображением

>>> c = np.take_along_axis(b, max_a_r.reshape(1, 4, 1), axis = 2)
>>> c
array([[[ 2],
        [ 4],
        [ 9],
        [15]],

       [[18],
        [20],
        [25],
        [31]],

       [[34],
        [36],
        [41],
        [47]]])

Что бы было легче разобраться с правильностью примера, удалим лишнюю ось и еще раз выведем максимальные элементы в a и массив b:

>>> a
array([[ 0,  1,  5,  0],
       [19,  8, 14, 12],
       [ 4, 15,  2, 12],
       [10,  7, 11, 14]])
>>> 
>>> a[np.r_[:4], max_a_r]
array([ 5, 19, 15, 14])
>>> 
>>> np.squeeze(c)
array([[ 2,  4,  9, 15],
       [18, 20, 25, 31],
       [34, 36, 41, 47]])
>>> 
>>> b
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]],

       [[16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [28, 29, 30, 31]],

       [[32, 33, 34, 35],
        [36, 37, 38, 39],
        [40, 41, 42, 43],
        [44, 45, 46, 47]]])

В общем, я немного сомневаюсь, в том что пример оказался простым и наглядным, однако, при некоторой практике, данная функция может оказаться весьма удобной и полезной. Тем более что она замечательно подходит для использования функций возвращающих индексы элементов, например таких как sort(), argsort(), argpartition и т.д.

Что ж попробую привести еще оди пример, допустим мы хотим вытащить из каждой строки массива a минимальный и максимальный элементы:

>>> a_min_r = np.expand_dims(np.argmin(a, axis=1), axis=1)
>>> a_max_r = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> 
>>> ind = np.hstack((a_min_r, a_max_r))
>>> ind
array([[0, 2],
       [1, 0],
       [2, 1],
       [1, 3]], dtype=int32)
>>> 
>>> 
>>> np.take_along_axis(a, ind, axis=1)
array([[ 0,  5],
       [ 8, 19],
       [ 2, 15],
       [ 7, 14]])
>>> 
>>> a
array([[ 0,  1,  5,  0],
       [19,  8, 14, 12],
       [ 4, 15,  2, 12],
       [10,  7, 11, 14]])

Как видим, все работает, и, довольно неплохо.