Skip to content

Commit 70c65f0

Browse files
committed
unpack input data
1 parent 3a27397 commit 70c65f0

File tree

1 file changed

+99
-63
lines changed

1 file changed

+99
-63
lines changed

python/tvm/topi/cuda/nms.py

Lines changed: 99 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tvm import te
2222

2323
from tvm.tir import if_then_else
24-
from .sort import argsort, argsort_thrust
24+
from .sort import argsort, argsort_thrust, is_thrust_available
2525

2626

2727
def cuda_atomic_add_rule(op):
@@ -338,7 +338,9 @@ def nms_ir(
338338
sorted_index,
339339
valid_count,
340340
indices,
341-
out,
341+
out_bboxes,
342+
out_scores,
343+
out_class_ids,
342344
box_indices,
343345
num_valid_boxes,
344346
max_output_size,
@@ -458,9 +460,13 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
458460
sorted_index = ib.buffer_ptr(sorted_index)
459461
valid_count = ib.buffer_ptr(valid_count)
460462
indices = ib.buffer_ptr(indices)
461-
num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
462-
out = ib.buffer_ptr(out)
463+
464+
# outputs
465+
out_bboxes = ib.buffer_ptr(out_bboxes)
466+
out_scores = ib.buffer_ptr(out_scores)
467+
out_class_ids = ib.buffer_ptr(out_class_ids)
463468
box_indices = ib.buffer_ptr(box_indices)
469+
num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
464470

465471
if isinstance(iou_threshold, float):
466472
iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
@@ -483,36 +489,55 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
483489
ib.scope_attr(tx, "thread_extent", nthread_tx)
484490
ib.scope_attr(bx, "thread_extent", nthread_bx)
485491
i = by
486-
base_idx = i * num_anchors * box_data_length
492+
base_src_idx = i * num_anchors * box_data_length
493+
base_bbox_idx = i * num_anchors * 4
494+
487495
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
488496
# Reorder output
489497
nkeep = if_then_else(
490498
tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
491499
)
492500
j = bx * max_threads + tx
493501
with ib.if_scope(j < nkeep):
502+
src_idx = base_src_idx + sorted_index[i * num_anchors + j] * box_data_length
494503
# Fill in out with sorted boxes
495-
with ib.for_range(0, box_data_length) as k:
496-
out[(base_idx + j * box_data_length + k)] = data[
497-
(base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
498-
]
504+
with ib.for_range(0, 4) as k:
505+
out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx + coord_start + k]
506+
507+
out_scores[i * num_anchors + j] = data[src_idx + score_index]
508+
509+
if id_index >= 0:
510+
out_class_ids[i * num_anchors + j] = data[src_idx + id_index]
511+
499512
with ib.else_scope():
500513
# Indices > nkeep are discarded
501514
# Only needed for return_indices = False case
502515
if return_indices is False:
503516
with ib.if_scope(j < num_anchors):
504-
with ib.for_range(0, box_data_length) as k:
505-
out[(base_idx + j * box_data_length + k)] = -1.0
517+
with ib.for_range(0, 4) as k:
518+
out_bboxes[(base_bbox_idx + j * 4 + k)] = -1.0
519+
520+
out_scores[i, j] = -1.0
521+
522+
if id_index >= 0:
523+
out_class_ids[i, j] = -1.0
506524

507525
if return_indices:
508526
with ib.if_scope(j < num_anchors):
509527
box_indices[i * num_anchors + j] = -1
510528

511529
with ib.else_scope():
512530
with ib.if_scope(j < valid_count[i]):
513-
with ib.for_range(0, box_data_length) as k:
514-
offset = base_idx + j * box_data_length + k
515-
out[offset] = data[offset]
531+
src_offset = base_src_idx + j * box_data_length
532+
533+
with ib.for_range(0, 4) as k:
534+
out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset + coord_start + k]
535+
536+
out_scores[i * num_anchors + j] = data[src_offset + score_index]
537+
538+
if id_index >= 0:
539+
out_class_ids[i * num_anchors + j] = data[src_offset + id_index]
540+
516541
box_indices[i * num_anchors + j] = j
517542

518543
with ib.new_scope():
@@ -526,7 +551,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
526551

527552
i = by
528553

529-
base_idx = i * num_anchors * box_data_length
554+
base_bbox_idx = i * num_anchors * 4
530555
num_valid_boxes_local = ib.allocate(
531556
"int32", (1,), name="num_valid_boxes_local", scope="local"
532557
)
@@ -549,37 +574,35 @@ def nms_inner_loop(ib, j):
549574

550575
num_valid_boxes_local[0] += 1
551576

552-
offset_j = j * box_data_length
577+
offset_j = j * 4
553578
num_iter_per_thread = ceil_div(valid_count[i] - (j + 1), nthread_tx)
554579

555580
with ib.for_range(0, num_iter_per_thread) as _k:
556581
k = j + 1 + _k * nthread_tx + tx
557-
offset_k = k * box_data_length
582+
offset_k = k * 4
558583

559584
with ib.if_scope(
560585
tvm.tir.all(
561586
k < num_anchors,
562-
out[base_idx + offset_k + score_index] > 0, # is the box k still valid?
587+
out_scores[i, k] > 0, # is the box k still valid?
563588
tvm.tir.any(
564589
force_suppress > 0,
565590
id_index < 0,
566-
out[base_idx + offset_k + id_index]
567-
== out[base_idx + offset_j + id_index],
591+
out_class_ids[i, k] == out_class_ids[i, j],
568592
),
569593
)
570594
):
571595
iou = calculate_overlap(
572-
out,
573-
base_idx + offset_j + coord_start,
574-
base_idx + offset_k + coord_start,
596+
out_bboxes,
597+
base_bbox_idx + offset_j,
598+
base_bbox_idx + offset_k,
575599
)
576600
with ib.if_scope(iou >= iou_threshold):
577601
# invalidate the box k
578-
out[base_idx + offset_k + score_index] = -1.0
579-
with ib.if_scope(id_index >= 0):
580-
out[base_idx + offset_k + id_index] = -1.0
602+
out_scores[i, k] = -1.0
603+
if return_indices is False and id_index >= 0:
604+
out_class_ids[i, k] = -1.0
581605

582-
# Make sure to do the next loop in a lock step
583606
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
584607

585608
if isinstance(max_output_size, int):
@@ -589,7 +612,7 @@ def nms_inner_loop(ib, j):
589612
# Apply nms
590613
with ib.for_range(0, valid_count[i]) as j:
591614
# Proceed to the inner loop if the box j is still valid
592-
with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0):
615+
with ib.if_scope(out_scores[i, j] > -1.0):
593616
with ib.if_scope(max_output_size > 0):
594617
# No need to do more iteration if we already reach max_output_size boxes
595618
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
@@ -638,6 +661,33 @@ def _fetch_score_ir(data, score, axis):
638661
return ib.get()
639662

640663

664+
def _get_sorted_indices(data, data_buf, score_index, score_shape):
665+
score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
666+
score_tensor = te.extern(
667+
[score_shape],
668+
[data],
669+
lambda ins, outs: _fetch_score_ir(
670+
ins[0],
671+
outs[0],
672+
score_index,
673+
),
674+
dtype=[data.dtype],
675+
in_buffers=[data_buf],
676+
out_buffers=[score_buf],
677+
name="fetch_score",
678+
tag="fetch_score",
679+
)
680+
681+
if is_thrust_available():
682+
sort_tensor = argsort_thrust(
683+
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype="int32"
684+
)
685+
else:
686+
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
687+
688+
return sort_tensor
689+
690+
641691
def non_max_suppression(
642692
data,
643693
valid_count,
@@ -736,54 +786,35 @@ def non_max_suppression(
736786
valid_count_buf = tvm.tir.decl_buffer(
737787
valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4
738788
)
739-
score_axis = score_index
789+
740790
score_shape = (batch_size, num_anchors)
741791
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
742-
score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
743-
score_tensor = te.extern(
744-
[score_shape],
745-
[data],
746-
lambda ins, outs: _fetch_score_ir(
747-
ins[0],
748-
outs[0],
749-
score_axis,
750-
),
751-
dtype=[data.dtype],
752-
in_buffers=[data_buf],
753-
out_buffers=[score_buf],
754-
name="fetch_score",
755-
tag="fetch_score",
756-
)
757-
target = tvm.target.Target.current()
758-
if (
759-
target
760-
and target.kind.name == "cuda"
761-
and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
762-
):
763-
sort_tensor = argsort_thrust(
764-
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
765-
)
766-
else:
767-
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)
768792

793+
sort_tensor = _get_sorted_indices(data, data_buf, score_index, score_shape)
769794
sort_tensor_buf = tvm.tir.decl_buffer(
770795
sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
771796
)
772797

773-
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
774798
indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)
775799

776-
out, box_indices, num_valid_boxes = te.extern(
777-
[data.shape, score_shape, [batch_size, 1]],
800+
bbox_shape = (batch_size, num_anchors, 4)
801+
class_id_shape = score_shape
802+
box_indices_shape = score_shape
803+
num_valid_boxes_shape = (batch_size, 1)
804+
805+
out_bboxes, out_scores, out_sorted_ids, box_indices, num_valid_boxes = te.extern(
806+
[bbox_shape, score_shape, class_id_shape, box_indices_shape, num_valid_boxes_shape],
778807
[data, sort_tensor, valid_count, indices],
779808
lambda ins, outs: nms_ir(
780809
ins[0],
781810
ins[1],
782811
ins[2],
783812
ins[3],
784-
outs[0],
785-
outs[1],
786-
outs[2],
813+
outs[0], # sorted bbox
814+
outs[1], # sorted scores
815+
outs[2], # sorted class ids
816+
outs[3], # box_indices
817+
outs[4], # num_valid_boxes
787818
max_output_size,
788819
iou_threshold,
789820
force_suppress,
@@ -793,7 +824,7 @@ def non_max_suppression(
793824
score_index,
794825
return_indices,
795826
),
796-
dtype=[data.dtype, "int32", "int32"],
827+
dtype=[data.dtype, "float32", "float32", "int32", "int32"],
797828
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
798829
name="nms",
799830
tag="nms",
@@ -802,4 +833,9 @@ def non_max_suppression(
802833
if return_indices:
803834
return [box_indices, num_valid_boxes]
804835

805-
return out
836+
# TODO: do concat
837+
return out_bboxes
838+
# if id_index >= 0:
839+
# return concatenate([out_bboxes
840+
841+
# return out

0 commit comments

Comments
 (0)