numpy.take

numpy.take(a, indices, axis=None, out=None, mode='raise')

Функция take() возвращает элементы массива с указанными индексами вдоль указанной оси.

Параметры:
a - массив NumPy или массивоподобный объект.
Исходный массив.
indices - массив NumPy, массивоподобный объект или целое число.
Индексы извлекаемых элементов.
axis - целое число (необязательный параметр).
Определяет ось вдоль которой извлекаются элементы с указанным индексом. По умолчанию axis = None, что соответствует извелечению элементов из сжатого до одной оси представления массива a.
out - массив NumPy (необязательный параметр).
Позволяет сразу напрямую поместить результат в указанный массив, при условии, что он имеет подходящую форму и тип данных.
mode - {'raise', 'wrap', 'clip'} (необязательный параметр).

Определяет метод обработки индексов, которые выходят за пределы формы исходного массива. Если указан один режим, то он применяется ко всем массивам в multi_index, но можно указать кортеж режимов аналогичной длинны, что бы обрабатывать каждый массив отдельно.

  • 'raise' - вызывать исключение (по умолчанию);
  • 'wrap' - обогнуть вокруг оси, т.е. циклически сместиться по ней;
  • 'clip' - обрезает до диапазона индексов исходного массива, причем отрицательные индексы обрезаются до 0.
Возвращает:
ndarray - массив NumPy
массив элементов исходного массива выбранных в соответствии с индексами вдоль указанной оси исходного массива.
Смотрите так же:
take_along_axis, compress

Замечание

Данная функция так же реализована в виде метода базового класса ndarray.take() с аналогичной сигнатурой (ndarray.take(indices, axis=None, out=None, mode='raise')) и принципом работы.



Примеры

Данная функция позволяет получать элементы массива по индексам которые расположены вдоль определенной оси. Рассмотрим простой пример на трехмерном массиве:

>>> import numpy as np
>>> 
>>> a = np.arange(3*4*6).reshape(3, 4, 6)
>>> a
array([[[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]],

       [[24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35],
        [36, 37, 38, 39, 40, 41],
        [42, 43, 44, 45, 46, 47]],

       [[48, 49, 50, 51, 52, 53],
        [54, 55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64, 65],
        [66, 67, 68, 69, 70, 71]]])

Допустим, теперь мы хотим вытащить из каждого подмассива первую и последнюю строчку:

>>> a[:, [0, 3], :]
array([[[ 0,  1,  2,  3,  4,  5],
        [18, 19, 20, 21, 22, 23]],

       [[24, 25, 26, 27, 28, 29],
        [42, 43, 44, 45, 46, 47]],

       [[48, 49, 50, 51, 52, 53],
        [66, 67, 68, 69, 70, 71]]])

Функция take(), позволяет сделать тоже самое, но в более удобной и читабельной форме (по крайней мере для чересчур многомерных массивов):

>>> np.take(a, [0, 3], axis = 1)
array([[[ 0,  1,  2,  3,  4,  5],
        [18, 19, 20, 21, 22, 23]],

       [[24, 25, 26, 27, 28, 29],
        [42, 43, 44, 45, 46, 47]],

       [[48, 49, 50, 51, 52, 53],
        [66, 67, 68, 69, 70, 71]]])

Например, мы можем с легкостью, используя один и тот же индекс вытащить первый и последний подмассив:

>>> np.take(a, [0, -1], axis = 0)
array([[[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]],

       [[48, 49, 50, 51, 52, 53],
        [54, 55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64, 65],
        [66, 67, 68, 69, 70, 71]]])

Или первый и последний столбец каждого подмассива:

>>> np.take(a, [0, -1], axis = 2)
array([[[ 0,  5],
        [ 6, 11],
        [12, 17],
        [18, 23]],

       [[24, 29],
        [30, 35],
        [36, 41],
        [42, 47]],

       [[48, 53],
        [54, 59],
        [60, 65],
        [66, 71]]])

Параметр indices может принимать другие массивы:

>>> b = np.array([[0, 2, 4], [1, 3, 5]])
>>> np.take(a[1], b, axis = 1)
array([[[24, 26, 28],
        [25, 27, 29]],

       [[30, 32, 34],
        [31, 33, 35]],

       [[36, 38, 40],
        [37, 39, 41]],

       [[42, 44, 46],
        [43, 45, 47]]])

Если параметр axis не указан, то извлекаются элементы из массива сжатого до одной оси:

>>> np.take(a, [0, -1])
array([ 0, 71])
>>> 
>>> np.take(a[1], [0, -1])
array([24, 47])

Если массив индексов содержит элементы чьи значения выходят за пределы исходного массива то вызывается исключение, но такое поведение желательно не всегда. Например эти значения могут просто обрезаться до максимальных размеров исходного массива:

>>> a = np.arange(7)
>>> a
array([0, 1, 2, 3, 4, 5, 6])
>>> 
>>> ind = [0, 1, 2, 10]
>>> 
>>> np.take(a, ind, mode = 'clip')
array([0, 1, 2, 6])

В примере выше, мы должны были увидеть сообщение об ошибке, так как последний элемент в ind явно превышает дину индексируемого массива. Но благодаря режиму 'clip' 10 было урезано до 6.

Режим 'wrap' в случае превышения размеров массива позволяет вести отсчет циклически смещаясь по оси:

>>> np.take(a, ind, mode = 'wrap')
array([0, 1, 2, 3])

Параметр out позволяет выгружать результат напрямую в указанный массив, минуя операцию присваивания, что позволяет несколько ускорить выполнение программы, особенно в циклах:

>>> b = np.random.randint(0, 10, size = (2, 3, 4))
>>> b
array([[[9, 1, 9, 6],
        [9, 7, 9, 4],
        [2, 3, 7, 9]],

       [[2, 0, 2, 6],
        [3, 5, 4, 6],
        [5, 7, 4, 0]]])
>>> 
>>> c = np.empty((2, 2, 1, 4), dtype = np.int)
>>> c
array([[[[777, 111, 111, 111]],

        [[777, 777, 111, 111]]],


       [[[  8, 777, 777, 111]],

        [[ 12,  13, 777, 777]]]])
>>> 
>>> np.take(b, [[-1], [1]], axis = 1, out = c)
array([[[[2, 3, 7, 9]],

        [[9, 7, 9, 4]]],


       [[[5, 7, 4, 0]],

        [[3, 5, 4, 6]]]])
>>> 
>>> c
array([[[[2, 3, 7, 9]],

        [[9, 7, 9, 4]]],


       [[[5, 7, 4, 0]],

        [[3, 5, 4, 6]]]])

Обратите внимание на то, что форма и тип данных как возвращаемого массива, так и массива в который помещается результат полностью совпадают.

Данная функция также реализована в виде метода базового класса ndarray.take():

>>> b.take([[-1], [1]], axis = 2)
array([[[[6],
         [1]],

        [[4],
         [7]],

        [[9],
         [3]]],


       [[[6],
         [0]],

        [[6],
         [5]],

        [[0],
         [7]]]])