
The numpy.argsort docs state:

Returns:
index_array : ndarray, int
    Array of indices that sort a along the specified axis. If a is one-dimensional, a[index_array] yields a sorted a.

How can I apply the result of numpy.argsort for a multidimensional array to get back a sorted array? (NOT just a 1-D or 2-D array; it could be an N-dimensional array where N is known only at runtime)

>>> import numpy as np
>>> np.random.seed(123)
>>> A = np.random.randn(3,2)
>>> A
array([[-1.0856306 ,  0.99734545],
       [ 0.2829785 , -1.50629471],
       [-0.57860025,  1.65143654]])
>>> i=np.argsort(A,axis=-1)
>>> A[i]
array([[[-1.0856306 ,  0.99734545],
        [ 0.2829785 , -1.50629471]],

       [[ 0.2829785 , -1.50629471],
        [-1.0856306 ,  0.99734545]],

       [[-1.0856306 ,  0.99734545],
        [ 0.2829785 , -1.50629471]]])

For me it's not just a matter of using sort() instead; I have another array B and I want to order B using the results of np.argsort(A) along the appropriate axis. Consider the following example:

>>> A = np.array([[3,2,1],[4,0,6]])
>>> B = np.array([[3,1,4],[1,5,9]])
>>> i = np.argsort(A,axis=-1)
>>> BsortA = ???             
# should result in [[4,1,3],[5,1,9]]
# so that corresponding elements of B and sort(A) stay together

It looks like this functionality is already an enhancement request in numpy.

Jason S
  • https://stackoverflow.com/questions/37878946/indexing-one-array-by-another-in-numpy very closely related but that answer does not help me answer this question for the general N-dimensional case – Jason S Oct 31 '17 at 21:45
  • `take_along_axis` is now part of `numpy` – hpaulj Jan 25 '19 at 13:12

3 Answers


The numpy issue #8708 has a sample implementation of take_along_axis that does what I need; I'm not sure if it's efficient for large arrays but it seems to work.

import numpy as np

def take_along_axis(arr, ind, axis):
    """
    ... here means a "pack" of dimensions, possibly empty

    arr: array_like of shape (A..., M, B...)
        source array
    ind: array_like of shape (A..., K..., B...)
        indices to take along each 1d slice of `arr`
    axis: int
        index of the axis with dimension M

    out: array_like of shape (A..., K..., B...)
        out[a..., k..., b...] = arr[a..., ind[a..., k..., b...], b...]
    """
    if axis < 0:
        if axis >= -arr.ndim:
            axis += arr.ndim
        else:
            raise IndexError('axis out of range')
    ind_shape = (1,) * ind.ndim
    ins_ndim = ind.ndim - (arr.ndim - 1)   #inserted dimensions

    dest_dims = list(range(axis)) + [None] + list(range(axis+ins_ndim, ind.ndim))

    # could also call np.ix_ here with some dummy arguments, then throw those results away
    inds = []
    for dim, n in zip(dest_dims, arr.shape):
        if dim is None:
            inds.append(ind)
        else:
            ind_shape_dim = ind_shape[:dim] + (-1,) + ind_shape[dim+1:]
            inds.append(np.arange(n).reshape(ind_shape_dim))

    return arr[tuple(inds)]

which yields

>>> A = np.array([[3,2,1],[4,0,6]])
>>> B = np.array([[3,1,4],[1,5,9]])
>>> i = A.argsort(axis=-1)
>>> take_along_axis(A,i,axis=-1)
array([[1, 2, 3],
       [0, 4, 6]])
>>> take_along_axis(B,i,axis=-1)
array([[4, 1, 3],
       [5, 1, 9]])
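
As a comment on the question notes, `take_along_axis` has since become part of NumPy itself. The following is a minimal sketch, assuming NumPy 1.15 or later, of the same example using the built-in `np.take_along_axis`, followed by a check on a 4-D array with the axis picked at runtime. The names `A_sorted`, `B_by_A`, and `C` are illustrative only.

```python
import numpy as np

A = np.array([[3, 2, 1], [4, 0, 6]])
B = np.array([[3, 1, 4], [1, 5, 9]])

# Indices that sort each row of A
i = np.argsort(A, axis=-1)

# Apply those indices to A itself and to the companion array B
A_sorted = np.take_along_axis(A, i, axis=-1)
B_by_A = np.take_along_axis(B, i, axis=-1)
print(A_sorted)
print(B_by_A)

# The same call works for any number of dimensions and any valid axis
C = np.random.RandomState(0).standard_normal((2, 3, 4, 5))
for axis in range(-C.ndim, C.ndim):
    j = np.argsort(C, axis=axis)
    assert np.array_equal(np.take_along_axis(C, j, axis=axis),
                          np.sort(C, axis=axis))
```

Because the index array from `argsort` has the same shape as its input, it can be reused on any other array of that shape, which is what the `B` case needs.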
Jason S
  • I guess one more error check for a large positive axis number could be added. – Divakar Nov 01 '17 at 05:43
  • @Divakar: Or better yet, use `axis = np.core.multiarray.normalize_axis_index(axis, arr.ndim)`, which does all of that for you. [numpy/numpy#8714](https://github.com/numpy/numpy/pull/8714) has a more complete implementation than the one taken from issue 8708 – Eric Nov 01 '17 at 06:46
  • @Eric That seems useful. Thanks. Added. – Divakar Nov 01 '17 at 07:02

This argsort produces a (3,2) array

In [453]: idx=np.argsort(A,axis=-1)
In [454]: idx
Out[454]: 
array([[0, 1],
       [1, 0],
       [0, 1]], dtype=int32)

As you note, applying this to A to get the equivalent of np.sort(A, axis=-1) isn't obvious. The iterative solution is to sort each row (a 1-D case) with:

In [459]: np.array([x[i] for i,x in zip(idx,A)])
Out[459]: 
array([[-1.0856306 ,  0.99734545],
       [-1.50629471,  0.2829785 ],
       [-0.57860025,  1.65143654]])

This is probably not the fastest solution, but it is probably the clearest, and it is a good starting point for working out a vectorized one.
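
The row-by-row idea can be stretched to N dimensions by looping over every position in the leading axes. This is only a sketch under that assumption (indices taken along the last axis); the helper name `take_last_axis_loop` is hypothetical, and the loop will be slow for large arrays.

```python
import numpy as np

def take_last_axis_loop(arr, idx):
    """Apply last-axis indices to arr, one 1-D slice at a time."""
    out = np.empty(idx.shape, dtype=arr.dtype)
    # np.ndindex visits every index tuple over the leading dimensions
    for pos in np.ndindex(*arr.shape[:-1]):
        # arr[pos] and idx[pos] are both 1-D here, so this is the 1-D case
        out[pos] = arr[pos][idx[pos]]
    return out

C = np.random.RandomState(1).randn(2, 3, 4)
idx_C = np.argsort(C, axis=-1)
sorted_C_loop = take_last_axis_loop(C, idx_C)
```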

The tuple(inds) built by the take_along_axis solution is:

(array([[0],
        [1],
        [2]]), 
 array([[0, 1],
        [1, 0],
        [0, 1]], dtype=int32))
In [470]: A[_]
Out[470]: 
array([[-1.0856306 ,  0.99734545],
       [-1.50629471,  0.2829785 ],
       [-0.57860025,  1.65143654]])

In other words:

In [472]: A[np.arange(3)[:,None], idx]
Out[472]: 
array([[-1.0856306 ,  0.99734545],
       [-1.50629471,  0.2829785 ],
       [-0.57860025,  1.65143654]])

The first part is what np.ix_ would construct, but np.ix_ accepts only 1-D sequences, so it rejects the 2-D idx.
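
A small sketch of that point, using the integer array A from the question rather than the random one above: np.ix_ raises on a 2-D argument, but it can still build the row index if it is given a dummy 1-D range for the last axis, which is then thrown away. The names `rows`, `_unused`, and `ix_accepts_2d` are illustrative.

```python
import numpy as np

A = np.array([[3, 2, 1], [4, 0, 6]])
idx = np.argsort(A, axis=-1)          # shape (2, 3)

# np.ix_ only takes 1-D sequences, so passing the 2-D idx fails
try:
    np.ix_(np.arange(A.shape[0]), idx)
    ix_accepts_2d = True
except ValueError:
    ix_accepts_2d = False

# Pass a dummy 1-D range for the last axis and keep only the row index
rows, _unused = np.ix_(np.arange(A.shape[0]), np.arange(A.shape[1]))
sorted_A = A[rows, idx]               # rows has shape (2, 1) and broadcasts against idx
print(sorted_A)
```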


Looks like I explored this topic a couple of years ago, in "argsort for a multidimensional ndarray":

a[np.arange(np.shape(a)[0])[:,np.newaxis], np.argsort(a)]

There I tried to explain what is going on. The take_along_axis function does the same sort of thing, but constructs the indexing tuple for the general case (any number of dimensions and any axis). Generalizing to more dimensions while keeping axis=-1 should be easy.

For the first axis, A[np.argsort(A,axis=0),np.arange(2)] works.
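
One way the axis=-1 generalization might look, as a sketch rather than a tuned implementation: build one broadcastable arange per leading axis and append the index array for the last axis. The helper name `take_last_axis` is hypothetical, and it assumes idx has the same number of dimensions as arr. The first-axis expression above is checked at the end on the question's integer array.

```python
import numpy as np

def take_last_axis(arr, idx):
    """Index arr along its last axis with idx (same ndim as arr)."""
    lead = []
    for d, n in enumerate(arr.shape[:-1]):
        # arange(n) shaped so that it varies only along axis d
        shape = [1] * arr.ndim
        shape[d] = n
        lead.append(np.arange(n).reshape(shape))
    return arr[tuple(lead) + (idx,)]

C = np.random.RandomState(0).randn(2, 3, 4)
sorted_C = take_last_axis(C, np.argsort(C, axis=-1))

# First-axis case for a 2-D array, as in the line above
A = np.array([[3, 2, 1], [4, 0, 6]])
first_axis = A[np.argsort(A, axis=0), np.arange(A.shape[1])]
print(first_axis)
```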

hpaulj

We just need to use advanced indexing to index along all axes with index arrays. We can use np.ogrid to create open grids of range arrays along all axes, and then replace the one for the input axis with the input indices. Finally, we index into the data array with those indices to get the desired output. Essentially, we would have -

# Inputs : arr, ind, axis
# list() keeps idx mutable even where np.ogrid returns a tuple
idx = list(np.ogrid[tuple(map(slice, ind.shape))])
idx[axis] = ind
out = arr[tuple(idx)]

To make it reusable and add error checks, let's create two functions: one to build those indices, and a second to take the data array and simply index it. The idea behind the first function is to produce indices that can be re-used to index any array with the required number of dimensions and lengths along each axis.

Hence, the implementations would be -

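import numpy as np

def advindex_allaxes(ind, axis):
    # Validate axis and map a negative value to its non-negative equivalent
    # (on NumPy 2.x the public spelling is numpy.lib.array_utils.normalize_axis_index)
    axis = np.core.multiarray.normalize_axis_index(axis, ind.ndim)
    # list() keeps idx mutable even where np.ogrid returns a tuple
    idx = list(np.ogrid[tuple(map(slice, ind.shape))])
    idx[axis] = ind
    return tuple(idx)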

def take_along_axis(arr, ind, axis):
    return arr[advindex_allaxes(ind, axis)]

Sample runs -

In [161]: A = np.array([[3,2,1],[4,0,6]])

In [162]: B = np.array([[3,1,4],[1,5,9]])

In [163]: i = A.argsort(axis=-1)

In [164]: take_along_axis(A,i,axis=-1)
Out[164]: 
array([[1, 2, 3],
       [0, 4, 6]])

In [165]: take_along_axis(B,i,axis=-1)
Out[165]: 
array([[4, 1, 3],
       [5, 1, 9]])

Relevant one.

Divakar
  • Your implementation of `take_along_axis` is not equivalent to the one in the other answer in cases when `ind.ndim != arr.ndim` – Eric Nov 01 '17 at 07:09
  • @Eric Yeah, this one indexes along all axes, hence that function name. If that's what you were referring to? – Divakar Nov 01 '17 at 07:10
  • No, I mean that `idx[axis] = ind` should be `idx[axis:axis+ins_ndim] = [ind]`. Your code is only correct when `ins_ndim == 1` (the case the question asks for). That's probably fine, but you should add an `assert arr.ndim == ind.ndim` to avoid unexpected behaviour – Eric Nov 01 '17 at 07:11
  • @Eric Not really sure why we need that. Works fine for me. Take a look here - https://ideone.com/SK22fa Also, I am indexing along only one axis i.e. `axis` accepts a scalar only, maybe that's confusing you? – Divakar Nov 01 '17 at 07:54
  • All your tests have `ind.ndim == arr.ndim`, so of course they pass. My point is simply that this implementation of `take_along_axis` does not behave in the same way as the implementation in the numpy PR in cases where that condition does not hold. – Eric Nov 01 '17 at 08:15
  • @Eric And that's why I said this one indexes along all axes, hence that function name : `advindex_allaxes` to get indexing tuple :) I guess that PR is more generic as that's intended to cover reduced indices with `argmin, argmax`, not this one. – Divakar Nov 01 '17 at 08:17
  • @Eric And since I am using `ind` to index along all axes of `arr`, I am assuming `ind.ndim == arr.ndim`. – Divakar Nov 01 '17 at 08:25
  • Nothing like encoding your assumptions in `assert`s to ensure they're not violated :) – Eric Nov 01 '17 at 08:27
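
The case raised in these comments, where the indices come from a reduction such as argmax and so have one dimension fewer than arr, can be sketched as follows. This uses NumPy's built-in `np.take_along_axis` (assuming NumPy 1.15 or later), not either implementation above, and restores the reduced axis with `np.expand_dims` so that `ind.ndim == arr.ndim` holds. The names `ind_full` and `maxima` are illustrative.

```python
import numpy as np

arr = np.array([[3, 2, 1], [4, 0, 6]])
axis = 1

# argmax drops the reduced axis: ind has shape (2,)
ind = np.argmax(arr, axis=axis)

# Put the axis back with length 1 so ind matches arr's dimensionality
ind_full = np.expand_dims(ind, axis)                   # shape (2, 1)
assert ind_full.ndim == arr.ndim

# Take along the axis, then drop the length-1 axis again
maxima = np.take_along_axis(arr, ind_full, axis=axis)  # shape (2, 1)
maxima = np.squeeze(maxima, axis=axis)                 # shape (2,)
print(maxima)
```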