先来看函数定义:
1 2 3 4
| argmax(a, axis=None, out=None)
|
axis 默认为 None
二维 array 情况
1 2 3 4 5 6
| import numpy as np
a = np.array([[2,5,6],[7,6,1]]) print(np.argmax(a))
3
|
输出结果为 3,因为 a 里面 7 是最大的,如果没有指定 axis,默认就是 None,相当于把 array 平铺为:[2,5,6,7,6,1],那么结果就是 3,因为索引 3 对应的值最大
三维 array 的情况
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| import numpy as np
b = np.random.randint(20, size=[3, 2, 2]) print(b) print(np.argmax(b))
[[[ 8 5] [18 9]]
[[18 8] [18 15]]
[[10 14] [14 8]]] 2
|
这个里面,18 最大,把它平铺,18 对应的索引就是 2,那么 np.argmax(b)就是 2.
给定 axis 的情况
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| import numpy as np
check = np.asarray(np.random.randint(1, 10, size=[2, 3, 4])) print(check) print(check.shape)
print(np.argmax(check, axis=0)) print(np.argmax(check, axis=1)) print(np.argmax(check, axis=2))
[[[8 5 6 5] [4 8 7 5] [5 9 5 5]]
[[5 3 8 7] [1 6 2 7] [9 5 6 1]]]
(2, 3, 4) [[0 0 1 1] [0 0 0 1] [1 0 1 0]] [[0 2 1 0] [2 1 0 0]] [[0 1 1] [2 3 0]]
|
axis=0
先说 axis=0 的情况,相当于将原来的矩阵根据 1 维展开,由于第一维的维度是 2(check.shape=(2, 3, 4)),所以可以得到两组元素,将两个元素平铺在一起进行比较。
1 2 3 4 5
| [[[8 5 6 5][4 8 7 5][5 9 5 5]] [[8 5 6 5][4 8 7 5][5 9 5 5]] -> [[5 3 8 7][1 6 2 7][9 5 6 1]]] [[5 3 8 7][1 6 2 7][9 5 6 1]] ↓ ↓ ↓ [[0 0 1 1][0 0 0 1][1 0 1 0]]
|
也就是说去掉了第一层的[]
,然后将里面的两个元素拿出来进行比较。最后将比较的结果还原回矩阵。
axis=1
再说 axis=1 的情况,相当于将原来的矩阵根据 2 维展开,由于第二维的维度是 3(check.shape=(2, 3, 4)),所以可以得到三组元素,将三个元素平铺在一起进行比较。
1 2 3 4 5 6 7 8 9 10 11 12
| [0 2 1 0] -> ↓ ↑ [8 5 6 5] [[[8 5 6 5][4 8 7 5][5 9 5 5]] [[8 5 6 5][4 8 7 5][5 9 5 5]] [4 8 7 5] [5 9 5 5] [[0 2 1 0] [2 1 0 0]] -> -> [5 3 8 7] [[5 3 8 7][1 6 2 7][9 5 6 1]]] [[5 3 8 7][1 6 2 7][9 5 6 1]] [1 6 2 7] [9 5 6 1] ↓ [2 1 0 0] -> ↑
|
也就是说去掉了第一层和第二层的[]
,然后将拿到的元素按照原来第二层的的分组进行分组,比较大小,获得索引,然后将比较的结果还原回矩阵。
axis=2
再说 axis=2 的情况,相当于将原来的矩阵根据 3 维展开,由于第三维的维度是 4(check.shape=(2, 3, 4)),所以可以得到四组元素,将四组元素平铺在一起进行比较。
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| [0 1 1] -> ↓ ↑ [8 4 5] [8 5 6 5] [5 8 9] [[[8 5 6 5][4 8 7 5][5 9 5 5]] [[8 5 6 5][4 8 7 5][5 9 5 5]] [4 8 7 5] [6 7 5] [5 9 5 5] [5 5 5] [[0 1 1] -> -> -> [2 3 0]] [5 3 8 7] [5 1 9] [[5 3 8 7][1 6 2 7][9 5 6 1]]] [[5 3 8 7][1 6 2 7][9 5 6 1]] [1 6 2 7] [3 6 5] [9 5 6 1] [8 2 6] [7 7 1] ↓ [2 3 0] -> ↑
|
也就是说去掉了所有的[]
,然后将拿到的元素按照第三层的分组进行分组,比较大小,获得索引,然后将比较的结果还原回矩阵。
总结为一句话,a…i…k…n,设 axis=k,那么沿着第 k 个下标的位置进行操作。
参考链接:
https://blog.x-fei.me/posts/argmax-function-and-its-usage/
本文章首发于个人博客 LLLibra146’s blog
本文作者:LLLibra146
更多文章请关注公众号 (LLLibra146):![LLLibra146]()
版权声明:本博客所有文章除特别声明外,均采用 © BY-NC-ND 许可协议。非商用转载请注明出处!严禁商业转载!
本文链接:
https://blog.d77.xyz/archives/bd46bf7.html