python - 理解 NumPy 的 einsum

我正在努力理解究竟是如何 einsum 作品。我查看了文档和一些示例,但似乎并没有坚持。
这是我们在类里面看过的一个例子:

C = np.einsum("ij,jk->ki", A, B)
对于两个数组:AB .
我认为这需要 A^T * B ,但我不确定(它正在对其中一个进行转置,对吗?)。谁能告诉我这里到底发生了什么(通常在使用 einsum 时)?

最佳答案

(注意:此答案基于我不久前写的关于 einsum 的简短 blog post。)
什么einsum做?
假设我们有两个多维数组,AB .现在让我们假设我们想要...

  • AB以一种特殊的方式来创造新的产品系列;然后也许
  • 沿特定轴对这个新数组求和;然后也许
  • 按特定顺序转置新数组的轴。

  • 很有可能 einsummultiply 等 NumPy 函数的组合相比,它将帮助我们更快、更高效地完成此操作。 , sumtranspose会同意。
    怎么样einsum工作?
    这是一个简单(但并非完全微不足道)的示例。取以下两个数组:
    A = np.array([0, 1, 2])
    
    B = np.array([[ 0,  1,  2,  3],
                  [ 4,  5,  6,  7],
                  [ 8,  9, 10, 11]])
    
    我们将相乘 AB逐元素,然后沿新数组的行求和。在“正常”的 NumPy 中,我们会这样写:
    >>> (A[:, np.newaxis] * B).sum(axis=1)
    array([ 0, 22, 76])
    
    所以在这里,对 A 的索引操作排列两个数组的第一个轴,以便可以广播乘法。然后将产品数组的行相加以返回答案。
    现在,如果我们想使用 einsum相反,我们可以写:
    >>> np.einsum('i,ij->i', A, B)
    array([ 0, 22, 76])
    
    签名字符串'i,ij->i'是这里的关键,需要稍微解释一下。你可以把它想成两半。在左侧(-> 的左侧),我们标记了两个输入数组。在 -> 的右侧,我们已经标记了我们想要结束的数组。
    下面是接下来发生的事情:
  • A有一个轴;我们已经给它贴上了标签 i .和 B有两个轴;我们将轴 0 标记为 i和轴 1 为 j .
  • 来自 重复 标签i在两个输入数组中,我们告诉 einsum这两个轴应该是 乘以 一起。换句话说,我们乘以数组 A与数组的每一列 B ,就像 A[:, np.newaxis] * B做。
  • 请注意 j在我们想要的输出中没有作为标 checkout 现;我们刚刚用过 i (我们希望以一维数组结束)。来自 省略 标签,我们告诉 einsum总和 沿着这个轴。换句话说,我们对乘积的行求和,就像 .sum(axis=1)做。

  • 这基本上就是您使用 einsum 所需要知道的全部内容。 .稍微玩一下会有所帮助;如果我们在输出中保留两个标签,'i,ij->ij' ,我们得到一个二维的产品数组(与 A[:, np.newaxis] * B 相同)。如果我们说没有输出标签,'i,ij-> ,我们得到一个单一的数字(与做 (A[:, np.newaxis] * B).sum() 相同)。
    关于 einsum 的伟大之处然而,它并没有先构建一个临时的产品阵列;它只是对产品进行汇总。这可以大大节省内存使用。
    一个稍微大一点的例子
    为了解释点积,这里有两个新数组:
    A = array([[1, 1, 1],
               [2, 2, 2],
               [5, 5, 5]])
    
    B = array([[0, 1, 0],
               [1, 1, 0],
               [1, 1, 1]])
    
    我们将使用 np.einsum('ij,jk->ik', A, B) 计算点积.这是一张显示 A 标签的图片和 B以及我们从函数中得到的输出数组:

    你可以看到标签 j重复 - 这意味着我们将 A 的行相乘列 B .此外,标签 j不包括在输出中 - 我们正在对这些乘积求和。标签 ik保留用于输出,所以我们得到一个二维数组。
    将此结果与标签 j 所在的数组进行比较可能会更清楚。没有总结。下面,在左侧,您可以看到写入 np.einsum('ij,jk->ijk', A, B) 所产生的 3D 数组。 (即我们保留了标签 j ):

    求和轴 j给出预期的点积,如右图所示。
    一些练习
    获得更多感受 einsum ,使用下标表示法实现熟悉的 NumPy 数组操作会很有用。任何涉及乘法和求和轴组合的内容都可以使用 einsum 编写。 .
    设 A 和 B 是两个长度相同的一维数组。例如,A = np.arange(10)B = np.arange(5, 15) .
  • A的总和可以写成:
    np.einsum('i->', A)
    
  • 逐元素乘法,A * B ,可以写成:
    np.einsum('i,i->i', A, B)
    
  • 内积或点积,np.inner(A, B)np.dot(A, B) ,可以写成:
    np.einsum('i,i->', A, B) # or just use 'i,i'
    
  • 外积,np.outer(A, B) ,可以写成:
    np.einsum('i,j->ij', A, B)
    

  • 对于二维数组,CD ,假设轴是兼容的长度(两者长度相同或其中之一的长度为 1),以下是一些示例:
  • C的踪迹(主对角线的总和),np.trace(C) ,可以写成:
    np.einsum('ii', C)
    
  • C 的逐元素乘法和 D 的转置, C * D.T ,可以写成:
    np.einsum('ij,ji->ij', C, D)
    
  • C 的每个元素相乘通过数组 D (制作 4D 阵列),C[:, :, None, None] * D ,可以写成:
    np.einsum('ij,kl->ijkl', C, D)    
    
  • https://stackoverflow.com/questions/26089893/

    相关文章:

    python - 类型错误 : Missing 1 required positional argu

    python - functools partial 是如何做到的?

    linux - 我可以在已编译的二进制文件中更改 'rpath' 吗?

    python - 我如何捕捉一个像异常一样的 numpy 警告(不仅仅是为了测试)?

    linux - 使用 bash 历史记录获取先前的命令,复制它,然后 'run' 它但命令注释

    linux - Git 和硬链接(hard link)

    linux - 如何删除 CLOSE_WAIT 套接字连接

    c - 如何从 C 程序中获得 100% 的 CPU 使用率

    python - 为什么 apt-get 功能在 Mac OS X v10.9 (Mavericks

    linux - 在 printf 中使用颜色