numpy 中 argmax 函数的使用

先来看函数定义:

1
2
3
4
argmax(a, axis=None, out=None)
# a 表示 array
# axis 表示指定的轴,默认是 None,表示把 array 平铺,
# out 默认为 None,如果指定,那么返回的结果会插入其中

axis 默认为 None

二维 array 情况

1
2
3
4
5
6
import numpy as np

a = np.array([[2,5,6],[7,6,1]])
print(np.argmax(a))

3

输出结果为 3,因为 a 里面 7 是最大的,如果没有指定 axis,默认就是 None,相当于把 array 平铺为:[2,5,6,7,6,1],那么结果就是 3,因为索引 3 对应的值最大

三维 array 的情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import numpy as np

b = np.random.randint(20, size=[3, 2, 2])
print(b)
print(np.argmax(b))

[[[ 8 5]
[18 9]]

[[18 8]
[18 15]]

[[10 14]
[14 8]]]
2

这个里面,18 最大,把它平铺,18 对应的索引就是 2,那么 np.argmax(b)就是 2.

给定 axis 的情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import numpy as np

check = np.asarray(np.random.randint(1, 10, size=[2, 3, 4]))
print(check)
print(check.shape)

print(np.argmax(check, axis=0))
print(np.argmax(check, axis=1))
print(np.argmax(check, axis=2))

[[[8 5 6 5]
[4 8 7 5]
[5 9 5 5]]

[[5 3 8 7]
[1 6 2 7]
[9 5 6 1]]]

(2, 3, 4)
[[0 0 1 1]
[0 0 0 1]
[1 0 1 0]]
[[0 2 1 0]
[2 1 0 0]]
[[0 1 1]
[2 3 0]]

axis=0

先说 axis=0 的情况,相当于将原来的矩阵根据 1 维展开,由于第一维的维度是 2(check.shape=(2, 3, 4)),所以可以得到两组元素,将两个元素平铺在一起进行比较。

1
2
3
4
5
[[[8 5 6 5][4 8 7 5][5 9 5 5]]      [[8 5 6 5][4 8 7 5][5 9 5 5]]
->
[[5 3 8 7][1 6 2 7][9 5 6 1]]] [[5 3 8 7][1 6 2 7][9 5 6 1]]
↓ ↓ ↓
[[0 0 1 1][0 0 0 1][1 0 1 0]]

也就是说去掉了第一层的[],然后将里面的两个元素拿出来进行比较。最后将比较的结果还原回矩阵。

axis=1

再说 axis=1 的情况,相当于将原来的矩阵根据 2 维展开,由于第二维的维度是 3(check.shape=(2, 3, 4)),所以可以得到三组元素,将三个元素平铺在一起进行比较。

1
2
3
4
5
6
7
8
9
10
11
12
                                                                        [0 2 1 0]    ->       ↓

[8 5 6 5]
[[[8 5 6 5][4 8 7 5][5 9 5 5]] [[8 5 6 5][4 8 7 5][5 9 5 5]] [4 8 7 5]
[5 9 5 5] [[0 2 1 0]
[2 1 0 0]]
-> ->
[5 3 8 7]
[[5 3 8 7][1 6 2 7][9 5 6 1]]] [[5 3 8 7][1 6 2 7][9 5 6 1]] [1 6 2 7]
[9 5 6 1]

[2 1 0 0] -> ↑

也就是说去掉了第一层和第二层的[],然后将拿到的元素按照原来第二层的的分组进行分组,比较大小,获得索引,然后将比较的结果还原回矩阵。

axis=2

再说 axis=2 的情况,相当于将原来的矩阵根据 3 维展开,由于第三维的维度是 4(check.shape=(2, 3, 4)),所以可以得到四组元素,将四组元素平铺在一起进行比较。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
                                                                                           [0 1 1]    ->   ↓

[8 4 5]
[8 5 6 5] [5 8 9]
[[[8 5 6 5][4 8 7 5][5 9 5 5]] [[8 5 6 5][4 8 7 5][5 9 5 5]] [4 8 7 5] [6 7 5]
[5 9 5 5] [5 5 5]
[[0 1 1]
-> -> -> [2 3 0]]
[5 3 8 7] [5 1 9]
[[5 3 8 7][1 6 2 7][9 5 6 1]]] [[5 3 8 7][1 6 2 7][9 5 6 1]] [1 6 2 7] [3 6 5]
[9 5 6 1] [8 2 6]
[7 7 1]

[2 3 0] -> ↑

也就是说去掉了所有的[],然后将拿到的元素按照第三层的分组进行分组,比较大小,获得索引,然后将比较的结果还原回矩阵。
总结为一句话,a…i…k…n,设 axis=k,那么沿着第 k 个下标的位置进行操作。

参考链接:

https://blog.x-fei.me/posts/argmax-function-and-its-usage/

本文章首发于个人博客 LLLibra146’s blog
本文作者:LLLibra146
版权声明:本博客所有文章除特别声明外,均采用 © BY-NC-ND 许可协议。非商用转载请注明出处!严禁商业转载!
本文链接https://blog.d77.xyz/archives/bd46bf7.html