我需要在数组中每行保留最多 N (3) 个值。
a=np.array([[1,2,3,4],[8,7,6,5],[5,3,1,2]])
a
Out[135]:
array([[1, 2, 3, 4],
[8, 7, 6, 5],
[5, 3, 1, 2]])
它们的索引可以用np.partition来识别:
n=3
np.argpartition(a, -n, axis=1)[:,-n:]
Out[136]:
array([[1, 2, 3],
[2, 1, 0],
[3, 0, 1]], dtype=int64)
所以,我的问题是: 我应该如何保留这些索引的值并将其他索引设置为零以获得:
Out[136]:
array([[0, 2, 3, 4],
[8, 7, 6, 0],
[5, 3, 0, 2]])
最佳答案
a=np.array([[1,2,3,4],[8,7,6,5],[5,3,1,2]])
n=3
mask = np.argpartition(a, -n, axis=1) < a.shape[1] - n
a[mask] = 0
https://stackoverflow.com/questions/68290728/