numpy 中 argmax 函数的使用
先来看函数定义:
1 | argmax(a, axis=None, out=None) |
axis 默认为 None
二维 array 情况
1 | import numpy as np |
输出结果为 3,因为 a 里面 7 是最大的,如果没有指定 axis,默认就是 None,相当于把 array 平铺为:[2,5,6,7,6,1],那么结果就是 3,因为索引 3 对应的值最大
三维 array 的情况
1 | import numpy as np |
这个里面,18 最大,把它平铺,18 对应的索引就是 2,那么 np.argmax(b)就是 2.
给定 axis 的情况
1 | import numpy as np |
axis=0
先说 axis=0 的情况,相当于将原来的矩阵根据 1 维展开,由于第一维的维度是 2(check.shape=(2, 3, 4)),所以可以得到两组元素,将两个元素平铺在一起进行比较。
1 | [[[8 5 6 5][4 8 7 5][5 9 5 5]] [[8 5 6 5][4 8 7 5][5 9 5 5]] |
也就是说去掉了第一层的[]
,然后将里面的两个元素拿出来进行比较。最后将比较的结果还原回矩阵。
axis=1
再说 axis=1 的情况,相当于将原来的矩阵根据 2 维展开,由于第二维的维度是 3(check.shape=(2, 3, 4)),所以可以得到三组元素,将三个元素平铺在一起进行比较。
1 | [0 2 1 0] -> ↓ |
也就是说去掉了第一层和第二层的[]
,然后将拿到的元素按照原来第二层的的分组进行分组,比较大小,获得索引,然后将比较的结果还原回矩阵。
axis=2
再说 axis=2 的情况,相当于将原来的矩阵根据 3 维展开,由于第三维的维度是 4(check.shape=(2, 3, 4)),所以可以得到四组元素,将四组元素平铺在一起进行比较。
1 | [0 1 1] -> ↓ |
也就是说去掉了所有的[]
,然后将拿到的元素按照第三层的分组进行分组,比较大小,获得索引,然后将比较的结果还原回矩阵。
总结为一句话,a…i…k…n,设 axis=k,那么沿着第 k 个下标的位置进行操作。
参考链接:
https://blog.x-fei.me/posts/argmax-function-and-its-usage/
本文章首发于个人博客 LLLibra146’s blog
本文作者:LLLibra146
版权声明:本博客所有文章除特别声明外,均采用 © BY-NC-ND 许可协议。非商用转载请注明出处!严禁商业转载!
本文链接:https://blog.d77.xyz/archives/bd46bf7.html