python - 如何在 Python 中实现 Softmax 函数

来自 Udacity's deep learning class , y_i 的 softmax 就是简单的指数除以整个 Y 向量的指数之和:

其中S(y_i)y_i的softmax函数,e是指数,j是没有。输入向量 Y 中的列数。

我尝试了以下方法:

import numpy as np

def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

scores = [3.0, 1.0, 0.2]
print(softmax(scores))

返回:

[ 0.8360188   0.11314284  0.05083836]

但建议的解决方案是:

def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    return np.exp(x) / np.sum(np.exp(x), axis=0)

它产生与第一个实现相同的输出,即使第一个实现显式地获取每列的差值和最大值,然后除以总和。

有人能用数学方法说明原因吗?一个正确一个错误?

实现在代码和时间复杂度方面是否相似?哪个更高效?

最佳答案

它们都是正确的,但从数值稳定性的角度来看,你的更好。

你开始

e ^ (x - max(x)) / sum(e^(x - max(x))

通过使用 a^(b - c) = (a^b)/(a^c) 我们有这个事实

= e ^ x / (e ^ max(x) * sum(e ^ x / e ^ max(x)))

= e ^ x / sum(e ^ x)

这是另一个答案所说的。您可以将 max(x) 替换为任何变量,它会取消。

https://stackoverflow.com/questions/34968722/

相关文章:

linux - Linux 内核中可能/不太可能的宏是如何工作的,它们有什么好处?

python - 如何在 Django 中按日期范围过滤查询对象?

linux - 如何从 Linux shell 运行具有与当前不同工作目录的程序?

linux - 如何在 Linux 上按名称而不是 PID 杀死进程?

python - 使用多处理 Pool.map() 时无法 pickle

python - 在 Python 中使用多处理时我应该如何记录?

python - 超出相对导入中的顶级包错误

linux - 在一行中执行组合多个 Linux 命令

c - 什么是 LD_PRELOAD 技巧?

python - 为什么 PEP-8 指定最大行长度为 79 个字符?