我正在尝试定义二分类问题的损失函数。但是,目标标签不是硬标签0,1,而是0~1之间的一个 float 。
Pytorch 中的 torch.nn.CrossEntropy 不支持软标签,所以我想自己写一个交叉熵函数。
我的函数是这样的
def cross_entropy(self, pred, target):
loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
return loss
def step(self, batch: Any):
x, y = batch
logits = self.forward(x)
loss = self.criterion(logits, y)
preds = logits
# torch.argmax(logits, dim=1)
return loss, preds, y
然而它根本不起作用。
谁能给我一个建议,我的损失函数有没有错误?
最佳答案
好像BCELoss
和健壮的版本 BCEWithLogitsLoss
正在“开箱即用”地处理模糊目标。他们不希望 target
是二进制的“0 到 1 之间的任何数字都可以。
请阅读文档。
https://stackoverflow.com/questions/70429846/
相关文章:
c++ - 无法构建 Boost Spirit 示例 conjure2
javascript - Typescript reducer 的 switch case type
reactjs - 如何在 React 中使用 map 输出嵌套对象的内容?
android - 我们如何保存和恢复 Android StateFlow 的状态?
ios - FaSTLane Match 环境变量未被 build_app 获取
python - 是否可以在 discord.py 中对不同的前缀使用不同的命令?