[OPS, IMP] New batched_nms implementation#3426
Conversation
fmassa
left a comment
There was a problem hiding this comment.
Thanks!
ONNX failures are related, and I think the easiest workaround is to simplify the checks be independent on the device
torchvision/ops/boxes.py
Outdated
| iou_threshold: float, | ||
| ) -> Tensor: | ||
| # Based on Detectron2 implementation | ||
| result_mask = scores.new_zeros(scores.size(), dtype=torch.bool) |
There was a problem hiding this comment.
should this be torch.zeros_like(scores, dtype = torch.bool)?
|
Looks like a nice boost. :) Does it make sense to check that this fix does not have any negative side-effect on the validation metrics of all pre-trained models? I think we should confirm by re-estimating all the validation stats and by retraining some of the models just before merge. |
|
The two implementations are equivalent in terms of results (they both rely on |
|
Looks like |
…ng to return vanilla now
Codecov Report
@@ Coverage Diff @@
## master #3426 +/- ##
==========================================
+ Coverage 78.70% 78.75% +0.04%
==========================================
Files 105 105
Lines 9735 9748 +13
Branches 1563 1565 +2
==========================================
+ Hits 7662 7677 +15
+ Misses 1582 1581 -1
+ Partials 491 490 -1
Continue to review full report at Codecov.
|
| boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1) | ||
| assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2 | ||
| assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2 |
There was a problem hiding this comment.
nit: it might be preferable to construct boxes which are always well-formed.
So something like
boxes = torch.rand(num_boxes, 4)
boxes[:, 2:] += boxes[:, :2]is generally better when constructing bounding boxes. It might be good to move this to a helper function btw
Summary: * new batched_nms implem * flake8 * hopefully fix torchscipt tests * Use where instead of nonzero * Use same threshold (4k) for CPU and GPU * Remove use of argsort * use views again * remove print * trying stuff, I don't know what's going on * previous passed onnx checks so the error isn't in _vanilla func. Trying to return vanilla now * add tracing decorators * cleanup * wip * ignore new path with ONNX * use vanilla if tracing...???? * Remove script_if_tracing decorator as it was conflicting with _is_tracing * flake8 * Improve coverage Reviewed By: NicolasHug, cpuhrsch Differential Revision: D26945728 fbshipit-source-id: 118a41e03da2939a726e5bd18f5f77b7c0ce6339 Co-authored-by: Francisco Massa <[email protected]>

Closes #1311
This PR introduces a new implementation of
batched_nms, which is faster than the current one in some cases (refer to benchmarks in issue)