Skip to content

Commit 4b7e942

Browse files
committed
Improve coverage
1 parent ff22a3a commit 4b7e942

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

test/test_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,10 @@ def test_batched_nms_implementations(self):
479479
err_msg = "The vanilla and the trick implementation yield different nms outputs."
480480
self.assertTrue(torch.allclose(keep_vanilla, keep_trick), err_msg)
481481

482+
# Also make sure an empty tensor is returned if boxes is empty
483+
empty = torch.empty((0,), dtype=torch.int64)
484+
self.assertTrue(torch.allclose(empty, ops.batched_nms(empty, None, None, None)))
485+
482486

483487
class DeformConvTester(OpTester, unittest.TestCase):
484488
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):

torchvision/ops/boxes.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,6 @@ def _batched_nms_vanilla(
9898
iou_threshold: float,
9999
) -> Tensor:
100100
# Based on Detectron2 implementation, just manually call nms() on each class independently
101-
if boxes.numel() == 0:
102-
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
103-
104101
keep_mask = torch.zeros_like(scores, dtype=torch.bool)
105102
for class_id in torch.unique(idxs):
106103
curr_indices = torch.where(idxs == class_id)[0]

0 commit comments

Comments
 (0)