Helper functions to get the n largest/smallest indices and elements along an axis
Here's a helper function to select the indices of the n largest elements along a generic axis of a generic ndarray, making use of np.argpartition and np.take_along_axis:
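```python
import numpy as np

def take_largest_indices_along_axis(ar, n, axis):
    # Keep only the last n positions along `axis`
    s = ar.ndim * [slice(None)]
    s[axis] = slice(-n, None)
    # kth=-n moves the indices of the n largest values into those positions (unordered)
    idx = np.argpartition(ar, kth=-n, axis=axis)[tuple(s)]
    # Sort the n candidates by value, then flip so the largest comes first
    sidx = np.take_along_axis(ar, idx, axis=axis).argsort(axis=axis)
    return np.flip(np.take_along_axis(idx, sidx, axis=axis), axis=axis)
```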
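The function relies on np.argpartition guaranteeing only a partial order, which is why the extra argsort and flip are needed. A minimal 1-D sketch of that behaviour (the array values are arbitrary, picked for illustration): with kth=-n, the last n slots hold the indices of the n largest values, but in no particular order.

```python
import numpy as np

a = np.array([44, 47, 64, 67, 67, 9, 83, 21])
n = 3

# kth=-n: the element placed at position -n is the one a full sort would put there,
# and everything after it is >= it, so the last n slots are the n largest (any order)
idx = np.argpartition(a, kth=-n)
top_idx = idx[-n:]
print(sorted(a[top_idx].tolist()))  # [67, 67, 83]

# Sorting only those n candidates by value, then flipping, gives largest-first indices
ordered = np.flip(top_idx[np.argsort(a[top_idx])])
print(a[ordered].tolist())  # [83, 67, 67]
```

Only n values get fully sorted, so for large axes this is cheaper than sorting everything.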
Extending this to get the indices of the n smallest elements:
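```python
def take_smallest_indices_along_axis(ar, n, axis):
    # Keep only the first n positions along `axis`
    s = ar.ndim * [slice(None)]
    s[axis] = slice(None, n)
    # kth=n-1 moves the indices of the n smallest values into those positions;
    # unlike kth=n, it stays in range when n equals ar.shape[axis]
    idx = np.argpartition(ar, kth=n - 1, axis=axis)[tuple(s)]
    # Sort the n candidates by value, smallest first
    sidx = np.take_along_axis(ar, idx, axis=axis).argsort(axis=axis)
    return np.take_along_axis(idx, sidx, axis=axis)
```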
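One edge case worth knowing about for the smallest-n variant: np.argpartition requires kth to be a valid index along the axis, so kth=n raises when n equals the axis length, whereas kth=n-1 selects the same n smallest entries (everything before position n-1 is <= the element placed there) and stays in range. A small sketch on a made-up 2x3 array:

```python
import numpy as np

a = np.array([[5, 1, 4],
              [2, 9, 3]])
n = a.shape[1]  # ask for every element along axis 1

try:
    np.argpartition(a, kth=n, axis=1)  # kth == axis length is out of range
    raised = False
except ValueError as exc:
    raised = True
    print("kth=n fails:", exc)

# kth=n-1 is a valid index for any 1 <= n <= a.shape[1]
idx = np.argpartition(a, kth=n - 1, axis=1)[:, :n]
print(np.sort(np.take_along_axis(a, idx, axis=1), axis=1))
# [[1 4 5]
#  [2 3 9]]

# For n smaller than the axis length it picks the same n smallest values as kth=n would
idx2 = np.argpartition(a, kth=1, axis=1)[:, :2]
print(np.sort(np.take_along_axis(a, idx2, axis=1), axis=1))
# [[1 4]
#  [2 3]]
```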
Extending these to select the n largest or smallest elements themselves only takes a simple call to np.take_along_axis:
```python
def take_largest_along_axis(ar, n, axis):
    # Gather the values at the n largest indices, largest first
    idx = take_largest_indices_along_axis(ar, n, axis)
    return np.take_along_axis(ar, idx, axis=axis)

def take_smallest_along_axis(ar, n, axis):
    # Gather the values at the n smallest indices, smallest first
    idx = take_smallest_indices_along_axis(ar, n, axis)
    return np.take_along_axis(ar, idx, axis=axis)
```
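To see what np.take_along_axis does in these helpers, here is a small hand-made example (both arrays are invented for illustration): each row of the index array selects positions from the matching row of ar, which is equivalent to fancy indexing with a broadcast column of row numbers.

```python
import numpy as np

ar = np.array([[44, 47, 64],
               [ 9, 83, 21]])
# Per-row column indices: row 0 picks columns 2 then 0, row 1 picks columns 1 then 2
idx = np.array([[2, 0],
                [1, 2]])

picked = np.take_along_axis(ar, idx, axis=1)
print(picked)
# [[64 44]
#  [83 21]]

# Same result with plain fancy indexing: a column of row numbers broadcasts against idx
rows = np.arange(ar.shape[0])[:, None]
print(np.array_equal(picked, ar[rows, idx]))  # True
```

The advantage of np.take_along_axis is that the same call works for any axis without building the broadcast index arrays by hand.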
Sample runs
```
# Sample setup
In [200]: np.random.seed(0)
     ...: ar = np.random.randint(0,99,(5,5))

In [201]: ar
Out[201]:
array([[44, 47, 64, 67, 67],
       [ 9, 83, 21, 36, 87],
       [70, 88, 88, 12, 58],
       [65, 39, 87, 46, 88],
       [81, 37, 25, 77, 72]])
```
Take the n largest indices and elements along an axis:

```
In [202]: take_largest_indices_along_axis(ar, n=2, axis=0)
Out[202]:
array([[4, 2, 2, 4, 3],
       [2, 1, 3, 0, 1]])

In [203]: take_largest_indices_along_axis(ar, n=2, axis=1)
Out[203]:
array([[4, 3],
       [4, 1],
       [2, 1],
       [4, 2],
       [0, 3]])

In [251]: take_largest_along_axis(ar, n=2, axis=0)
Out[251]:
array([[81, 88, 88, 77, 88],
       [70, 83, 87, 67, 87]])

In [252]: take_largest_along_axis(ar, n=2, axis=1)
Out[252]:
array([[67, 67],
       [87, 83],
       [88, 88],
       [88, 87],
       [81, 77]])
```
Take the n smallest indices and elements along an axis:

```
In [232]: take_smallest_indices_along_axis(ar, n=2, axis=0)
Out[232]:
array([[1, 4, 1, 2, 2],
       [0, 3, 4, 1, 0]])

In [233]: take_smallest_indices_along_axis(ar, n=2, axis=1)
Out[233]:
array([[0, 1],
       [0, 2],
       [3, 4],
       [1, 3],
       [2, 1]])

In [253]: take_smallest_along_axis(ar, n=2, axis=0)
Out[253]:
array([[ 9, 37, 21, 12, 58],
       [44, 39, 25, 36, 67]])

In [254]: take_smallest_along_axis(ar, n=2, axis=1)
Out[254]:
array([[44, 47],
       [ 9, 21],
       [12, 58],
       [39, 46],
       [25, 37]])
```
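A quick way to gain confidence in the helpers on other shapes is to compare them against a brute-force full sort. The sketch below reproduces the helpers so it runs on its own (the smallest-n one written with kth=n-1 so n can go up to the full axis length), and uses hypothetical reference_largest / reference_smallest functions built on np.sort. It compares values rather than indices, because tied values (like the two 67s in the first row above) may come back in either order.

```python
import numpy as np

def take_largest_indices_along_axis(ar, n, axis):
    s = ar.ndim * [slice(None)]
    s[axis] = slice(-n, None)
    idx = np.argpartition(ar, kth=-n, axis=axis)[tuple(s)]
    sidx = np.take_along_axis(ar, idx, axis=axis).argsort(axis=axis)
    return np.flip(np.take_along_axis(idx, sidx, axis=axis), axis=axis)

def take_smallest_indices_along_axis(ar, n, axis):
    s = ar.ndim * [slice(None)]
    s[axis] = slice(None, n)
    # kth=n-1 (not kth=n) so that n may equal ar.shape[axis]
    idx = np.argpartition(ar, kth=n - 1, axis=axis)[tuple(s)]
    sidx = np.take_along_axis(ar, idx, axis=axis).argsort(axis=axis)
    return np.take_along_axis(idx, sidx, axis=axis)

def reference_largest(ar, n, axis):
    # Brute force: full sort, reverse to descending, keep the first n along `axis`
    return np.take(np.flip(np.sort(ar, axis=axis), axis=axis), np.arange(n), axis=axis)

def reference_smallest(ar, n, axis):
    # Brute force: full ascending sort, keep the first n along `axis`
    return np.take(np.sort(ar, axis=axis), np.arange(n), axis=axis)

rng = np.random.default_rng(0)
ar = rng.integers(0, 99, size=(4, 5, 6))  # small ints, so repeated values can occur

mismatches = []
for axis in range(ar.ndim):
    for n in range(1, ar.shape[axis] + 1):
        big = np.take_along_axis(ar, take_largest_indices_along_axis(ar, n, axis), axis=axis)
        small = np.take_along_axis(ar, take_smallest_indices_along_axis(ar, n, axis), axis=axis)
        if not np.array_equal(big, reference_largest(ar, n, axis)):
            mismatches.append(("largest", axis, n))
        if not np.array_equal(small, reference_smallest(ar, n, axis)):
            mismatches.append(("smallest", axis, n))

print(mismatches)  # []
```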
Solving our case here
For our case, let's assume the input is similarities, of shape (1000, 128), representing 1000 data points and 128 features, and that we want the n=10 largest features for each data point. Then it would be:
```python
take_largest_indices_along_axis(similarities, n=10, axis=1)  # indices
take_largest_along_axis(similarities, n=10, axis=1)          # elements
```
The final indices/values array would be of shape (1000, n).
Sample run with the given dataset shape:

```
In [257]: np.random.seed(0)
     ...: similarities = np.random.randint(0,99,(1000,128))

In [263]: take_largest_indices_along_axis(similarities, n=10, axis=1).shape
Out[263]: (1000, 10)

In [264]: take_largest_along_axis(similarities, n=10, axis=1).shape
Out[264]: (1000, 10)
```
If instead you want the n largest data points for each feature, i.e. a final indices/values array of shape (n, 128), then use axis=0.
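Both directions side by side, as a sketch on stand-in data (random floats in [0, 1) generated with np.random.default_rng, assumed here in place of real similarity scores). With axis=0, the first row of the result is each feature's maximum and values decrease down the axis:

```python
import numpy as np

def take_largest_indices_along_axis(ar, n, axis):
    s = ar.ndim * [slice(None)]
    s[axis] = slice(-n, None)
    idx = np.argpartition(ar, kth=-n, axis=axis)[tuple(s)]
    sidx = np.take_along_axis(ar, idx, axis=axis).argsort(axis=axis)
    return np.flip(np.take_along_axis(idx, sidx, axis=axis), axis=axis)

def take_largest_along_axis(ar, n, axis):
    idx = take_largest_indices_along_axis(ar, n, axis)
    return np.take_along_axis(ar, idx, axis=axis)

# Stand-in data: random float scores, not real similarities
rng = np.random.default_rng(0)
similarities = rng.random((1000, 128))

per_point = take_largest_along_axis(similarities, n=10, axis=1)    # top 10 features per data point
per_feature = take_largest_along_axis(similarities, n=10, axis=0)  # top 10 data points per feature
print(per_point.shape, per_feature.shape)  # (1000, 10) (10, 128)

# First row along axis=0 is each feature's maximum, and values decrease down the axis
print(np.array_equal(per_feature[0], similarities.max(axis=0)))  # True
print(bool(np.all(np.diff(per_feature, axis=0) <= 0)))           # True
```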