Getting Max Element Index From The Array Along A Particular Axis
NumPy is widely used to work with arrays (multidimensional and masked) and matrices. It provides a large collection of functions and methods for operating on arrays, such as statistical, mathematical, and logical operations, shape manipulation, linear algebra, and much more.
Argmax function
numpy.argmax() is a NumPy function that returns the indices of the maximum values along an axis of the given array.
Syntax
numpy.argmax(a, axis=None, out=None)
Parameters:
a
- The input array we will work on.
axis
- Optional. We can specify an axis such as 0 or 1 to find the index of the maximum value vertically (down each column) or horizontally (along each row). By default (None), the index is computed on the flattened array.
out
- Optional; None by default. If provided, the result is inserted into this array, which must have the appropriate shape and dtype.
Return value
An array of integers containing the indices of the max values in the array a. It has the same shape as a.shape with the dimension along axis removed. When no axis is given, a single integer index into the flattened array is returned.
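A quick way to see this shape rule is to inspect the result for each axis of a small 3D array. The array below is just an arbitrary example built with numpy.arange():

```python
import numpy as np

# An arbitrary 3D array with shape (2, 3, 4), used only to show result shapes
a = np.arange(24).reshape(2, 3, 4)

print(np.argmax(a, axis=0).shape)  # (3, 4): axis 0 removed
print(np.argmax(a, axis=1).shape)  # (2, 4): axis 1 removed
print(np.argmax(a, axis=2).shape)  # (2, 3): axis 2 removed

# Without an axis, a single index into the flattened array is returned
print(np.argmax(a))  # 23
```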
Finding the index of the max element
Let’s look at a basic example of finding the index of the max element in an array.
Working with a 1D array without specifying the axis
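A minimal sketch of this case, using a hypothetical sample array whose largest value sits at index 2:

```python
import numpy as np

# Hypothetical sample data: the largest value, 21, is at index 2
arr = np.array([3, 7, 21, 5, 9])

# No axis given, so argmax() scans the whole (already 1D) array
max_index = np.argmax(arr)
print("MAX ELEMENT INDEX:", max_index)
```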
Output
MAX ELEMENT INDEX: 2
Working with a 2D array without specifying the axis
When we call argmax() on a 2D array without specifying the axis, NumPy treats the array as if it were flattened into 1D, so the returned index is the position of the max element in the flattened array.
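A minimal sketch that produces the output below, with the printed array hard-coded. The last two lines go slightly beyond the example: numpy.unravel_index() converts the flat index back into a (row, column) pair.

```python
import numpy as np

# The 4x4 array shown in the output below
arr = np.array([[5, 5, 4, 12],
                [12, 15, 13, 0],
                [11, 13, 2, 6],
                [6, 8, 8, 9]])
print("INPUT ARRAY:")
print(arr)

# No axis: the index refers to the flattened (row-major) array
max_index = np.argmax(arr)
print("MAX ELEMENT INDEX:", max_index)

# Map the flat index back to 2D coordinates
row, col = np.unravel_index(max_index, arr.shape)
print(f"ROW: {row}, COLUMN: {col}")
```

Here the flat index 5 corresponds to row 1, column 1, where the value 15 sits.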
Output
INPUT ARRAY:
[[ 5 5 4 12]
[12 15 13 0]
[11 13 2 6]
[ 6 8 8 9]]
MAX ELEMENT INDEX: 5
Finding the index of the max element along the axis
Things will change when we specify the axis and try to find the index of the max element along it.
When the axis is 0
When we set axis=0, the function finds the index of the max element vertically, that is, down each column of the multidimensional array. Let's understand it with the illustration below.
In the illustration above, argmax() returns the index of the max element in the 1st column, which is 1, then the index for the 2nd column, which is again 1, and it repeats the process for the 3rd and 4th columns.
Code example
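A sketch that produces the output below, with the printed input array hard-coded:

```python
import numpy as np

# The 4x4 array shown in the output below
arr = np.array([[8, 6, 10, 3],
                [4, 5, 9, 1],
                [6, 15, 13, 13],
                [4, 14, 15, 13]])
print("INPUT ARRAY:")
print(arr)

# axis=0: compare values down each column, giving one index per column
col_max_idx = np.argmax(arr, axis=0)
print("MAX ELEMENT INDEX:", col_max_idx)
```

In the last column, 13 appears in rows 2 and 3, and argmax() reports row 2, the first occurrence.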
Output
INPUT ARRAY:
[[ 8 6 10 3]
[ 4 5 9 1]
[ 6 15 13 13]
[ 4 14 15 13]]
MAX ELEMENT INDEX: [0 2 3 2]
When the axis is 1
When we set axis=1, the function finds the index of the max element horizontally, that is, along each row of the multidimensional array. Let's understand it with the illustration below.
In the illustration above, argmax() returns the index of the max element in the 1st row, which is 2, then the index for the 2nd row, which is 1, and it repeats the process for the 3rd and 4th rows.
Code example
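A sketch that produces the output below, again hard-coding the printed 4x3 input array:

```python
import numpy as np

# The 4x3 array shown in the output below
arr = np.array([[7, 8, 0],
                [3, 0, 11],
                [7, 6, 0],
                [10, 8, 1]])
print("INPUT ARRAY:")
print(arr)

# axis=1: compare values along each row, giving one index per row
row_max_idx = np.argmax(arr, axis=1)
print("MAX ELEMENT INDEX:", row_max_idx)
```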
Output
INPUT ARRAY:
[[ 7 8 0]
[ 3 0 11]
[ 7 6 0]
[10 8 1]]
MAX ELEMENT INDEX: [1 2 0 0]
Multiple occurrences of the highest value
Sometimes a multidimensional array contains multiple occurrences of the highest value along a particular axis. What happens then?
The function returns the index of the first occurrence of the highest value along that axis.
Illustration showing multiple occurrences of the highest values along axis 0
Illustration showing multiple occurrences of the highest values along axis 1
Code example
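A sketch that produces the output below. The input array itself isn't printed, so here it is rebuilt from the flattened array in the output and split into 3 rows of 5, since the axis=0 result has five entries and the axis=1 result has three:

```python
import numpy as np

# 3x5 array rebuilt from the flattened array printed in the output below
arr = np.array([[2, 14, 9, 4, 5],
                [7, 14, 53, 10, 4],
                [91, 2, 41, 6, 91]])

# Column 1 holds 14 twice (rows 0 and 1): the first occurrence, row 0, wins
print("MAX ELEMENT INDEX:", np.argmax(arr, axis=0))

# Row 2 holds 91 twice (columns 0 and 4): the first occurrence, column 0, wins
print("MAX ELEMENT INDEX:", np.argmax(arr, axis=1))

# Flattened, 91 appears at positions 10 and 14, so 10 is returned
print("The array is flattened into 1D array:", arr.flatten())
print("MAX ELEMENT INDEX:", np.argmax(arr))
```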
Output
MAX ELEMENT INDEX: [2 0 1 1 2]
MAX ELEMENT INDEX: [1 2 0]
The array is flattened into 1D array: [ 2 14 9 4 5 7 14 53 10 4 91 2 41 6 91]
MAX ELEMENT INDEX: 10
Explanation
In the above code, finding the indices of the max elements along axis 0 gives [2 0 1 1 2]. In the 2nd column, the highest value 14 appears at both index 0 and index 1; we get 0 because the 14 at index 0 occurs first.
The same goes for the second output, obtained with axis=1: in the 3rd row, the highest value 91 appears at both index 0 and index 4, and since the 91 at index 0 occurs first, we get 0.
Likewise, in the flattened array the two 91s sit at indices 10 and 14, so argmax() returns 10.
Using the out parameter
The out parameter of numpy.argmax() is optional and defaults to None.
The out parameter lets us store the result (the indices of the max elements along a particular axis) in an existing NumPy array. That array must match the shape of the result, which is the input's shape with the dimension along axis removed, and it must have an integer dtype, since the results are indices.
Code Example
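A sketch that produces the output below, with the printed input array hard-coded. It uses dtype=np.intp, the integer type NumPy uses for indices, so the out array matches what argmax() produces; the same out_array is reused for both calls, so its contents are overwritten by the second one.

```python
import numpy as np

# Integer array to receive the indices; shape (4,) matches the result of
# reducing a 4x4 array along one axis
out_array = np.zeros(4, dtype=np.intp)
print("ARRAY w/ ZEROES:", out_array)

# The 4x4 array shown in the output below
arr = np.array([[4, 2, 14, 15],
                [6, 15, 2, 1],
                [13, 6, 13, 3],
                [5, 1, 13, 9]])
print("INPUT ARRAY:")
print(arr)

# Row-wise max indices are written into out_array in place
np.argmax(arr, axis=1, out=out_array)
print("AXIS 1:", out_array)

# Column-wise max indices overwrite the previous contents
np.argmax(arr, axis=0, out=out_array)
print("AXIS 0:", out_array)
```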
Output
ARRAY w/ ZEROES: [0 0 0 0]
INPUT ARRAY:
[[ 4 2 14 15]
[ 6 15 2 1]
[13 6 13 3]
[ 5 1 13 9]]
AXIS 1: [3 1 0 2]
AXIS 0: [2 1 0 0]
Explanation
We created an array filled with zeroes named out_array, with the shape (4,), which matches the result of reducing the 4x4 input array along either axis, and with an integer dtype. We then used the numpy.argmax() function to get the indices of the max elements along axes 1 and 0, storing them in out_array each time.
By default, numpy.zeros() creates an array with the float dtype. That’s why we set the dtype explicitly in the above code: argmax() produces integer indices, and NumPy rejects a float array as out.
If we didn’t specify the dtype in the above code, it would throw an error.
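A sketch of the failing variant, with the input array printed below hard-coded. Unlike the output below, which shows an uncaught error, this version catches the TypeError and prints it so the script finishes:

```python
import numpy as np

# np.zeros() without a dtype creates a float64 array
out_array = np.zeros(4)
print("ARRAY w/ ZEROES:", out_array)

# The 4x4 array shown in the output below
arr = np.array([[14, 9, 3, 4],
                [9, 2, 4, 8],
                [5, 1, 9, 1],
                [6, 0, 10, 7]])
print("INPUT ARRAY:")
print(arr)

try:
    np.argmax(arr, axis=1, out=out_array)
except TypeError as err:
    # argmax() writes integer indices, so a float64 out array is rejected
    print("TypeError:", err)
    error_raised = True
else:
    error_raised = False
```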
Output
ARRAY w/ ZEROES: [0. 0. 0. 0.]
INPUT ARRAY:
[[14 9 3 4]
[ 9 2 4 8]
[ 5 1 9 1]
[ 6 0 10 7]]
TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Conclusion
That was an overview of the argmax() function in NumPy. Let’s review what we’ve learned:
- numpy.argmax() returns the index of the highest value in an array. If the maximum value occurs more than once (in a multidimensional or flattened array), the function returns the index of its first occurrence.
- We can specify the axis parameter when working with a multidimensional array to get the result along a particular axis. With axis=0, we get the indices of the highest values vertically (down each column), and with axis=1, horizontally (along each row).
- We can store the output in another array specified in the out parameter; however, that array must have the same shape as the result and an integer dtype.
That’s all for now
Keep Coding✌✌
Originally published at https://geekpython.in.