@@ -862,7 +862,7 @@ def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
862862 self .real_shape_to_captured_size [bs ] = end
863863 self .real_shape_to_captured_size [self .max_capture_size ] = self .max_capture_size
864864
865- def _set_cudagraph_sizes (self , max_num_seqs : int = 0 ):
865+ def _set_cudagraph_sizes (self , max_capture_size : int = 0 ):
866866 """
867867 Calculate a series of candidate capture sizes,
868868 and then extract a portion of them as the capture list for the CUDA graph based on user input.
@@ -874,7 +874,7 @@ def _set_cudagraph_sizes(self, max_num_seqs: int = 0):
874874 # Shape [288, 320, ... 992, 1024]
875875 draft_capture_sizes += [32 * i for i in range (9 , 33 )]
876876
877- draft_capture_sizes .append (max_num_seqs )
877+ draft_capture_sizes .append (max_capture_size )
878878 self .cudagraph_capture_sizes = sorted (draft_capture_sizes )
879879
880880 def to_json_string (self ):
@@ -1391,19 +1391,22 @@ def __init__(
13911391 self .cache_config : CacheConfig = cache_config # type: ignore
13921392 self .plas_attention_config : Optional [PlasAttentionConfig ] = plas_attention_config
13931393 self .structured_outputs_config : StructuredOutputsConfig = structured_outputs_config
1394- # Initialize cuda graph capture list
1395- if self .graph_opt_config .cudagraph_capture_sizes is None :
1396- self .graph_opt_config ._set_cudagraph_sizes (max_num_seqs = self .scheduler_config .max_num_seqs )
13971394
1395+ # Initialize cuda graph capture list
1396+ max_capture_shape = self .scheduler_config .max_num_seqs
1397+ if self .speculative_config is not None and self .speculative_config .method == "mtp" :
1398+ max_capture_shape = self .scheduler_config .max_num_seqs * (
1399+ self .speculative_config .num_speculative_tokens + 1
1400+ )
1401+ assert max_capture_shape % 2 == 0 , "CUDAGraph only supports capturing even token nums in MTP scenarios."
13981402 if self .graph_opt_config .cudagraph_only_prefill :
1399- self .graph_opt_config .init_with_cudagrpah_size (max_capture_size = 512 )
1400- elif self .speculative_config is not None and self .speculative_config .method == "mtp" :
1401- max_shape = self .scheduler_config .max_num_seqs * (self .speculative_config .num_speculative_tokens + 1 )
1402- if max_shape % 2 == 1 :
1403- max_shape = max_shape + 1
1404- self .graph_opt_config .init_with_cudagrpah_size (max_capture_size = min (512 , max_shape ))
1403+ max_capture_shape = 512
14051404 else :
1406- self .graph_opt_config .init_with_cudagrpah_size (max_capture_size = self .scheduler_config .max_num_seqs )
1405+ max_capture_shape = min (512 , max_capture_shape )
1406+
1407+ if self .graph_opt_config .cudagraph_capture_sizes is None :
1408+ self .graph_opt_config ._set_cudagraph_sizes (max_capture_size = max_capture_shape )
1409+ self .graph_opt_config .init_with_cudagrpah_size (max_capture_size = max_capture_shape )
14071410
14081411 self .tokenizer = tokenizer
14091412 self .ips = ips
0 commit comments