2121from tvm import te
2222
2323from tvm .tir import if_then_else
24- from .sort import argsort , argsort_thrust
24+ from .sort import argsort , argsort_thrust , is_thrust_available
2525
2626
2727def cuda_atomic_add_rule (op ):
@@ -338,7 +338,9 @@ def nms_ir(
338338 sorted_index ,
339339 valid_count ,
340340 indices ,
341- out ,
341+ out_bboxes ,
342+ out_scores ,
343+ out_class_ids ,
342344 box_indices ,
343345 num_valid_boxes ,
344346 max_output_size ,
@@ -458,9 +460,13 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
458460 sorted_index = ib .buffer_ptr (sorted_index )
459461 valid_count = ib .buffer_ptr (valid_count )
460462 indices = ib .buffer_ptr (indices )
461- num_valid_boxes = ib .buffer_ptr (num_valid_boxes )
462- out = ib .buffer_ptr (out )
463+
464+ # outputs
465+ out_bboxes = ib .buffer_ptr (out_bboxes )
466+ out_scores = ib .buffer_ptr (out_scores )
467+ out_class_ids = ib .buffer_ptr (out_class_ids )
463468 box_indices = ib .buffer_ptr (box_indices )
469+ num_valid_boxes = ib .buffer_ptr (num_valid_boxes )
464470
465471 if isinstance (iou_threshold , float ):
466472 iou_threshold = tvm .tir .FloatImm ("float32" , iou_threshold )
@@ -483,36 +489,55 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
483489 ib .scope_attr (tx , "thread_extent" , nthread_tx )
484490 ib .scope_attr (bx , "thread_extent" , nthread_bx )
485491 i = by
486- base_idx = i * num_anchors * box_data_length
492+ base_src_idx = i * num_anchors * box_data_length
493+ base_bbox_idx = i * num_anchors * 4
494+
487495 with ib .if_scope (tvm .tir .all (iou_threshold > 0 , valid_count [i ] > 0 )):
488496 # Reorder output
489497 nkeep = if_then_else (
490498 tvm .tir .all (top_k > 0 , top_k < valid_count [i ]), top_k , valid_count [i ]
491499 )
492500 j = bx * max_threads + tx
493501 with ib .if_scope (j < nkeep ):
502+ src_idx = base_src_idx + sorted_index [i * num_anchors + j ] * box_data_length
494503 # Fill in out with sorted boxes
495- with ib .for_range (0 , box_data_length ) as k :
496- out [(base_idx + j * box_data_length + k )] = data [
497- (base_idx + sorted_index [i * num_anchors + j ] * box_data_length + k )
498- ]
504+ with ib .for_range (0 , 4 ) as k :
505+ out_bboxes [(base_bbox_idx + j * 4 + k )] = data [src_idx + coord_start + k ]
506+
507+ out_scores [i * num_anchors + j ] = data [src_idx + score_index ]
508+
509+ if id_index >= 0 :
510+ out_class_ids [i * num_anchors + j ] = data [src_idx + id_index ]
511+
499512 with ib .else_scope ():
500513 # Indices > nkeep are discarded
501514 # Only needed for return_indices = False case
502515 if return_indices is False :
503516 with ib .if_scope (j < num_anchors ):
504- with ib .for_range (0 , box_data_length ) as k :
505- out [(base_idx + j * box_data_length + k )] = - 1.0
517+ with ib .for_range (0 , 4 ) as k :
518+ out_bboxes [(base_bbox_idx + j * 4 + k )] = - 1.0
519+
520+ out_scores [i , j ] = - 1.0
521+
522+ if id_index >= 0 :
523+ out_class_ids [i , j ] = - 1.0
506524
507525 if return_indices :
508526 with ib .if_scope (j < num_anchors ):
509527 box_indices [i * num_anchors + j ] = - 1
510528
511529 with ib .else_scope ():
512530 with ib .if_scope (j < valid_count [i ]):
513- with ib .for_range (0 , box_data_length ) as k :
514- offset = base_idx + j * box_data_length + k
515- out [offset ] = data [offset ]
531+ src_offset = base_src_idx + j * box_data_length
532+
533+ with ib .for_range (0 , 4 ) as k :
534+ out_bboxes [base_bbox_idx + j * 4 + k ] = data [src_offset + coord_start + k ]
535+
536+ out_scores [i * num_anchors + j ] = data [src_offset + score_index ]
537+
538+ if id_index >= 0 :
539+ out_class_ids [i * num_anchors + j ] = data [src_offset + id_index ]
540+
516541 box_indices [i * num_anchors + j ] = j
517542
518543 with ib .new_scope ():
@@ -526,7 +551,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
526551
527552 i = by
528553
529- base_idx = i * num_anchors * box_data_length
554+ base_bbox_idx = i * num_anchors * 4
530555 num_valid_boxes_local = ib .allocate (
531556 "int32" , (1 ,), name = "num_valid_boxes_local" , scope = "local"
532557 )
@@ -549,37 +574,35 @@ def nms_inner_loop(ib, j):
549574
550575 num_valid_boxes_local [0 ] += 1
551576
552- offset_j = j * box_data_length
577+ offset_j = j * 4
553578 num_iter_per_thread = ceil_div (valid_count [i ] - (j + 1 ), nthread_tx )
554579
555580 with ib .for_range (0 , num_iter_per_thread ) as _k :
556581 k = j + 1 + _k * nthread_tx + tx
557- offset_k = k * box_data_length
582+ offset_k = k * 4
558583
559584 with ib .if_scope (
560585 tvm .tir .all (
561586 k < num_anchors ,
562- out [ base_idx + offset_k + score_index ] > 0 , # is the box k still valid?
587+ out_scores [ i , k ] > 0 , # is the box k still valid?
563588 tvm .tir .any (
564589 force_suppress > 0 ,
565590 id_index < 0 ,
566- out [base_idx + offset_k + id_index ]
567- == out [base_idx + offset_j + id_index ],
591+ out_class_ids [i , k ] == out_class_ids [i , j ],
568592 ),
569593 )
570594 ):
571595 iou = calculate_overlap (
572- out ,
573- base_idx + offset_j + coord_start ,
574- base_idx + offset_k + coord_start ,
596+ out_bboxes ,
597+ base_bbox_idx + offset_j ,
598+ base_bbox_idx + offset_k ,
575599 )
576600 with ib .if_scope (iou >= iou_threshold ):
577601 # invalidate the box k
578- out [ base_idx + offset_k + score_index ] = - 1.0
579- with ib . if_scope ( id_index >= 0 ) :
580- out [ base_idx + offset_k + id_index ] = - 1.0
602+ out_scores [ i , k ] = - 1.0
603+ if return_indices is False and id_index >= 0 :
604+ out_class_ids [ i , k ] = - 1.0
581605
582- # Make sure to do the next loop in a lock step
583606 ib .emit (tvm .tir .Call (None , "tir.tvm_storage_sync" , tvm .runtime .convert (["shared" ])))
584607
585608 if isinstance (max_output_size , int ):
@@ -589,7 +612,7 @@ def nms_inner_loop(ib, j):
589612 # Apply nms
590613 with ib .for_range (0 , valid_count [i ]) as j :
591614 # Proceed to the inner loop if the box j is still valid
592- with ib .if_scope (out [ base_idx + ( j * box_data_length ) + score_index ] > - 1.0 ):
615+ with ib .if_scope (out_scores [ i , j ] > - 1.0 ):
593616 with ib .if_scope (max_output_size > 0 ):
594617 # No need to do more iteration if we already reach max_output_size boxes
595618 with ib .if_scope (num_valid_boxes_local [0 ] < max_output_size ):
@@ -638,6 +661,33 @@ def _fetch_score_ir(data, score, axis):
638661 return ib .get ()
639662
640663
664+ def _get_sorted_indices (data , data_buf , score_index , score_shape ):
665+ score_buf = tvm .tir .decl_buffer (score_shape , data .dtype , "score_buf" , data_alignment = 8 )
666+ score_tensor = te .extern (
667+ [score_shape ],
668+ [data ],
669+ lambda ins , outs : _fetch_score_ir (
670+ ins [0 ],
671+ outs [0 ],
672+ score_index ,
673+ ),
674+ dtype = [data .dtype ],
675+ in_buffers = [data_buf ],
676+ out_buffers = [score_buf ],
677+ name = "fetch_score" ,
678+ tag = "fetch_score" ,
679+ )
680+
681+ if is_thrust_available ():
682+ sort_tensor = argsort_thrust (
683+ score_tensor , valid_count = None , axis = 1 , is_ascend = False , dtype = "int32"
684+ )
685+ else :
686+ sort_tensor = argsort (score_tensor , axis = 1 , is_ascend = False , dtype = "int32" )
687+
688+ return sort_tensor
689+
690+
641691def non_max_suppression (
642692 data ,
643693 valid_count ,
@@ -736,54 +786,35 @@ def non_max_suppression(
736786 valid_count_buf = tvm .tir .decl_buffer (
737787 valid_count .shape , valid_count_dtype , "valid_count_buf" , data_alignment = 4
738788 )
739- score_axis = score_index
789+
740790 score_shape = (batch_size , num_anchors )
741791 data_buf = tvm .tir .decl_buffer (data .shape , data .dtype , "data_buf" , data_alignment = 8 )
742- score_buf = tvm .tir .decl_buffer (score_shape , data .dtype , "score_buf" , data_alignment = 8 )
743- score_tensor = te .extern (
744- [score_shape ],
745- [data ],
746- lambda ins , outs : _fetch_score_ir (
747- ins [0 ],
748- outs [0 ],
749- score_axis ,
750- ),
751- dtype = [data .dtype ],
752- in_buffers = [data_buf ],
753- out_buffers = [score_buf ],
754- name = "fetch_score" ,
755- tag = "fetch_score" ,
756- )
757- target = tvm .target .Target .current ()
758- if (
759- target
760- and target .kind .name == "cuda"
761- and tvm .get_global_func ("tvm.contrib.thrust.sort_nms" , allow_missing = True )
762- ):
763- sort_tensor = argsort_thrust (
764- score_tensor , valid_count = None , axis = 1 , is_ascend = False , dtype = valid_count_dtype
765- )
766- else :
767- sort_tensor = argsort (score_tensor , axis = 1 , is_ascend = False , dtype = valid_count_dtype )
768792
793+ sort_tensor = _get_sorted_indices (data , data_buf , score_index , score_shape )
769794 sort_tensor_buf = tvm .tir .decl_buffer (
770795 sort_tensor .shape , sort_tensor .dtype , "sort_tensor_buf" , data_alignment = 8
771796 )
772797
773- data_buf = tvm .tir .decl_buffer (data .shape , data .dtype , "data_buf" , data_alignment = 8 )
774798 indices_buf = tvm .tir .decl_buffer (indices .shape , indices .dtype , "indices_buf" , data_alignment = 8 )
775799
776- out , box_indices , num_valid_boxes = te .extern (
777- [data .shape , score_shape , [batch_size , 1 ]],
800+ bbox_shape = (batch_size , num_anchors , 4 )
801+ class_id_shape = score_shape
802+ box_indices_shape = score_shape
803+ num_valid_boxes_shape = (batch_size , 1 )
804+
805+ out_bboxes , out_scores , out_sorted_ids , box_indices , num_valid_boxes = te .extern (
806+ [bbox_shape , score_shape , class_id_shape , box_indices_shape , num_valid_boxes_shape ],
778807 [data , sort_tensor , valid_count , indices ],
779808 lambda ins , outs : nms_ir (
780809 ins [0 ],
781810 ins [1 ],
782811 ins [2 ],
783812 ins [3 ],
784- outs [0 ],
785- outs [1 ],
786- outs [2 ],
813+ outs [0 ], # sorted bbox
814+ outs [1 ], # sorted scores
815+ outs [2 ], # sorted class ids
816+ outs [3 ], # box_indices
817+ outs [4 ], # num_valid_boxes
787818 max_output_size ,
788819 iou_threshold ,
789820 force_suppress ,
@@ -793,7 +824,7 @@ def non_max_suppression(
793824 score_index ,
794825 return_indices ,
795826 ),
796- dtype = [data .dtype , "int32" , "int32" ],
827+ dtype = [data .dtype , "float32" , "float32" , " int32" , "int32" ],
797828 in_buffers = [data_buf , sort_tensor_buf , valid_count_buf , indices_buf ],
798829 name = "nms" ,
799830 tag = "nms" ,
@@ -802,4 +833,9 @@ def non_max_suppression(
802833 if return_indices :
803834 return [box_indices , num_valid_boxes ]
804835
805- return out
836+ # TODO: do concat
837+ return out_bboxes
838+ # if id_index >= 0:
839+ # return concatenate([out_bboxes
840+
841+ # return out
0 commit comments