pytorch - Pytorch 中软标签的交叉熵

我正在尝试定义二分类问题的损失函数。但是,目标标签不是硬标签0,1,而是0~1之间的一个 float 。

Pytorch 中的 torch.nn.CrossEntropy 不支持软标签,所以我想自己写一个交叉熵函数。

我的函数是这样的

def cross_entropy(self, pred, target):
    loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
    return loss

def step(self, batch: Any):
    x, y = batch
    logits = self.forward(x)
    loss = self.criterion(logits, y)
    preds = logits
    # torch.argmax(logits, dim=1)
    return loss, preds, y

然而它根本不起作用。

谁能给我一个建议,我的损失函数有没有错误?

最佳答案

好像BCELoss和健壮的版本 BCEWithLogitsLoss正在“开箱即用”地处理模糊目标。他们不希望 target 是二进制的“0 到 1 之间的任何数字都可以。
请阅读文档。

https://stackoverflow.com/questions/70429846/

相关文章:

c++ - 无法构建 Boost Spirit 示例 conjure2

javascript - Typescript reducer 的 switch case type

reactjs - 如何在 React 中使用 map 输出嵌套对象的内容?

android - 我们如何保存和恢复 Android StateFlow 的状态?

c# - 在 Visual Studio 中替换后删除空行

ios - FaSTLane Match 环境变量未被 build_app 获取

typescript - 数组中泛型的联合

python - 是否可以在 discord.py 中对不同的前缀使用不同的命令?

java - ThreadPoolExecutor.shutdownNow() 没有在 Thread

ruby-on-rails - 无法使用 octokit 连接到 Github 帐户