@@ -1183,6 +1183,12 @@ def sample_async(
11831183 model_outputs : dict [str , torch .Tensor ],
11841184 num_context_logits_prefix_sum : list [int ],
11851185 resource_manager : Optional [ResourceManager ] = None ) -> SampleState :
1186+ # NB: The sampler is either called directly by PyExecutor, for the target model,
1187+ # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
1188+ # case there are 1 + get_draft_token_length(request) tokens per request. In the
1189+ # latter case, there is always only 1 token per request because draft
1190+ # tokens are sampled one-by-one.
1191+
11861192 requests = scheduled_requests .all_requests ()
11871193 new_tokens = self .store .new_tokens
11881194 log_probs_host = self .log_probs_host (scheduled_requests )
@@ -1332,8 +1338,6 @@ def _sample_batched_by_strategy(
13321338 requests , pin_memory = True )
13331339 generator_cuda = self .get_generator (cuda_device )
13341340
1335- # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens
1336- #
13371341 # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits,
13381342 # breaking _process_draft_tokens_rejection_sampling.
13391343 needs_d2t = "d2t" in model_outputs
@@ -1459,15 +1463,16 @@ def _sample_batched_by_strategy(
14591463 (batch_req_indices , batch_next_tokens_cuda_int ,
14601464 batch_softmax_cuda ), = batched_results
14611465
1462- # FIXME: This should be done in ModelDrafter.prepare_draft_tokens, but for performance
1463- # parity py_draft_tokens might need to be replaced / backed by a torch.Tensor, so
1464- # that d2t can be applied in a batched manner similar to the code below.
1466+ # NB: 'd2t' contains offsets for transforming draft vocab token IDs into
1467+ # the target vocab. This is used by Eagle3ForCausalLM, whose input domain
1468+ # is the target vocab, whereas the output logits correspond to the draft
1469+ # vocab. Since the inputs/outputs are linked by TorchSampler.update_requests,
1470+ # they currently need to be handled within TorchSampler. Changing the model
1471+ # outputs to use the target vocab would require inflating the logit tensors,
1472+ # which is inefficient. Changing the inputs to use the draft vocab, might
1473+ # be cleaner, but would require applying 'd2t' in multiple locations:
1474+ # Prefill, Eagle3ForCausalLM embeddings, ModelDrafter
14651475 if needs_d2t :
1466- # NB: The sampler is either called directly by PyExecutor, for the target model,
1467- # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
1468- # case there are 1 + get_draft_token_length(request) tokens per request. In the
1469- # latter case, only there is always only 1 token per request because draft
1470- # tokens are sampled one-by-one.
14711476 self ._apply_d2t (batch_next_tokens_cuda_int , model_outputs )
14721477
14731478 return _BatchedSamplingResult (
@@ -1909,7 +1914,6 @@ def sample_async(
19091914 num_context_logits_prefix_sum : list [int ],
19101915 resource_manager : Optional [ResourceManager ] = None
19111916 ) -> SampleStateTRTLLM :
1912-
19131917 batch_size = scheduled_requests .batch_size
19141918 beam_width = self .beam_width (scheduled_requests .all_requests ())
19151919 if (batch_size > 1 and beam_width > 1
0 commit comments