Conversation

@alykhantejani
Contributor

In response to issue #1939, where the following failed due to broadcasting, because o and t have different shapes:

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

sigmoid = nn.Sigmoid()

t = np.round(np.random.rand(64))
o = np.random.rand(64,1) - 0.5

t = Variable(torch.Tensor(t))
o = Variable(torch.Tensor(o))

print(nn.BCEWithLogitsLoss()(o, t))
print(nn.BCELoss()(sigmoid(o), t)) # Different numbers
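For context, the discrepancy comes from NumPy-style broadcasting: an output of shape (64, 1) combined element-wise with a target of shape (64,) silently broadcasts to (64, 64), so the loss is averaged over 64 * 64 pairs instead of 64. A minimal sketch of the mechanism:

```python
import numpy as np

t = np.zeros(64)        # target, shape (64,)
o = np.zeros((64, 1))   # output, shape (64, 1)

# Element-wise arithmetic broadcasts (64, 1) against (64,) to (64, 64),
# which is why the two loss calls above produce different numbers.
print((o - t).shape)  # (64, 64)
```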

Contributor

@apaszke left a comment


I don't think it's a good solution. It's better to have an input/target shape spec.

@alykhantejani
Contributor Author

@apaszke you mean assert that target.is_same_size(input)?

@apaszke
Contributor

apaszke commented Jun 29, 2017

Yeah I guess that if it's an element-wise loss then they should match exactly.
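A minimal sketch of such a strict check (a hypothetical helper, with shapes written as plain tuples here rather than the torch.Size returned by .size(), purely for illustration):

```python
def check_same_size(input_shape, target_shape):
    # Hypothetical helper mirroring the proposed target.is_same_size(input)
    # assertion: an element-wise loss requires identical shapes, with no
    # implicit broadcasting.
    if input_shape != target_shape:
        raise ValueError(
            "Target size {} must be the same as input size {}".format(
                target_shape, input_shape))

check_same_size((64, 1), (64, 1))  # OK
# check_same_size((64, 1), (64,)) would raise ValueError
```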

@alykhantejani
Contributor Author

This was my first thought too (as the docstring also says they should be the same size - the same for BCELoss).

cc @martinarjovsky who had some thoughts on this.

@alykhantejani
Contributor Author

(btw I think there's still an issue in my weight sizing in the test - I'll fix that once we reach a conclusion on how we should fix this issue generally)

@alykhantejani
Contributor Author

alykhantejani commented Jun 30, 2017

@apaszke if we add a check in binary_cross_entropy_with_logits to raise if target.size() != input.size(), then we should also add this to binary_cross_entropy as their interfaces should match.

However, this will break backwards compatibility as I think many people use it in the following way:

sigmoid = nn.Sigmoid()
target = Variable(torch.rand(64))
output = Variable(torch.rand(64, 1) - 0.5)
nn.BCELoss()(sigmoid(output), target)

@alykhantejani
Contributor Author

Looking at the interface of BCELoss a bit more, it is also quite inconsistent, e.g.:

target = Variable(torch.rand(5))
output = Variable(torch.rand(5, 1) - 0.5)
nn.BCELoss()(nn.Sigmoid()(output), target)  # this works fine
weight = torch.ones(1)
nn.BCELoss(weight)(nn.Sigmoid()(output), target)  # this fails as target.size() != input.size()
nn.BCELoss(weight)(nn.Sigmoid()(output), target.unsqueeze(1))  # this works

I'm in favour of asserting target.size() == input.size(), as it leads to simpler code; happy to amend this PR if you agree @apaszke (but it will break backwards compatibility of BCELoss with single-dim targets).

@apaszke
Contributor

apaszke commented Jun 30, 2017

How about adding strict checks to with_logits version and deprecation warnings to the ones that we had earlier?
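That compromise could look roughly like this (a sketch with a hypothetical helper name, not the actual patch; shapes again shown as plain tuples for illustration):

```python
import warnings

def warn_if_mismatched(input_shape, target_shape):
    # Sketch of the compromise: warn rather than raise on a shape mismatch,
    # keeping existing code working while flagging the deprecated usage.
    if input_shape != target_shape:
        warnings.warn(
            "Using a target size ({}) that is different to the input size "
            "({}) is deprecated.".format(target_shape, input_shape))

warn_if_mismatched((64, 1), (64,))  # emits a UserWarning
```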

@alykhantejani
Contributor Author

Sounds sensible - I'll update this PR with that, then.

@alykhantejani
Contributor Author

@apaszke I've made the changes and added the warning in functional.binary_cross_entropy and in torch/nn/_functions/thnn/loss.BCELoss. Should I just add the warning in BCECriterion.c instead?

@martinarjovsky
Contributor

martinarjovsky commented Jul 1, 2017

Hello! Adding checks for sizes would be good. What I think is really necessary is that the behaviour of the 'with_logits' version is exactly the same as the one on probabilities [for BCE and non-binary CE]. The main thing that worried me was that a shape mismatch affected one and not the other. Cheers :)

@soumith merged commit 4575870 into pytorch:master Jul 2, 2017
@soumith
Contributor

soumith commented Jul 2, 2017

thanks Alykhan!

for each minibatch.
"""
if not target.is_same_size(input):
    warnings.warn("Using a target size ({}) that is different to the input size ({}) is deprecated. "


@alykhantejani deleted the fix_broadcasting_issues_in_bceloss branch July 4, 2017 21:14
houseroad added a commit to houseroad/pytorch that referenced this pull request Apr 19, 2019
…09c7db (pytorch#19454)

Summary:
Pull Request resolved: pytorch#19454

Previous import was ad7313470a9119d7e1afda7edf1d654497ee80ab

Included changes:
- **[83dd6265](onnx/onnx@83dd6265)**: Add NonMaxSuppression operator (pytorch#1703) <Hector Li>
- **[31ca5d6f](onnx/onnx@31ca5d6f)**: add node tests for quantized ops (pytorch#1944) <Ashwini Khade>
- **[e6076c1d](onnx/onnx@e6076c1d)**: Fix test stat coverage script (pytorch#1948) <Raymond Yang>
- **[ad036405](onnx/onnx@ad036405)**: Add IsInf to detect infinity values (pytorch#1884) <Wei-Sheng Chin>

Differential Revision: D15010015

fbshipit-source-id: 9778757752785fe3169ad2ac606b37299aa69da6
facebook-github-bot pushed a commit that referenced this pull request Apr 22, 2019
zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019