Improve TensorRT-LLM Functionality #487
Conversation
…rrtllm backend repo, wip
    output = self.tokenizer.decode(
        tokens[:seq_len],
        skip_special_tokens=self.skip_special_tokens)
    # Adapted from https://github.com/triton-inference-server/tensorrtllm_backend/pull/423
Differs from NVIDIA here
    item_flat_ids += ids
    item_offsets.append(len(ids))

    # Add a case where ids[0] decodes to empty string, then add another set of ids here
Differs from NVIDIA here
This is a partial patch to make some of the stop sequence behavior more functional. When I was trying stop sequences out, I noticed cases where the stop sequence was being ignored. The root cause is that TRT tokenizes the stop sequence and looks for that exact sequence of tokens, but what the model returns isn't always quite that sequence: TRT's tokenization of the stop sequence can contain an extra token that decodes to the empty string. So I patched it so that we also look for the original sequence minus the empty token.
e.g. if you pass in stop-sequence text that tokenizes to [1,2,3], the model can output the sequence [..., 2,3], where [2,3] also decodes to the same text.
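The mismatch described above can be sketched with a toy matcher. This is a minimal illustration, not the actual backend code; the function name, the token ids, and the `empty_token_id` parameter are all hypothetical:

```python
# Toy sketch of the token-level stop-sequence mismatch (hypothetical names).
def find_stop(output_ids, stop_ids, empty_token_id):
    """Return True if output_ids ends with stop_ids, or with stop_ids
    minus a leading token that decodes to the empty string."""
    def ends_with(seq, suffix):
        return len(seq) >= len(suffix) and seq[-len(suffix):] == suffix

    if ends_with(output_ids, stop_ids):
        return True
    # Patch: tokenizing the stop text in isolation may prepend a token
    # that decodes to ""; also try matching without that token.
    if stop_ids and stop_ids[0] == empty_token_id:
        return ends_with(output_ids, stop_ids[1:])
    return False

# Stop text tokenizes to [1, 2, 3] where id 1 decodes to "";
# the model emits [..., 2, 3], which decodes to the same stop text.
assert [9, 2, 3][-3:] != [1, 2, 3]                          # exact match misses
assert find_stop([9, 2, 3], [1, 2, 3], empty_token_id=1)    # patched match hits
```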
Even this doesn't seem right? The stop sequence could be part of [2,3], e.g. tok2 = abc, tok3 = def, stop sequence = cdef. I think the best option is to compare against the string in postprocessing?
Yup, this isn't a complete fix for stop sequences, unfortunately; I think we'd need to spend more time to see whether it's possible with the current framework and how to do it.
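A string-based postprocessing check, along the lines suggested above, could look like the sketch below. This is not part of the PR; the function name is hypothetical, and a real implementation would also need to handle streaming and stop text split across decode calls:

```python
def truncate_at_stop(decoded_text, stop_sequences):
    """Compare stop sequences against the decoded string rather than
    token ids, so matches that straddle token boundaries are caught."""
    cut = len(decoded_text)
    for stop in stop_sequences:
        idx = decoded_text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return decoded_text[:cut]

# tok2 = "abc", tok3 = "def": the stop text "cdef" straddles the token
# boundary, so a token-id comparison never matches, but the string does.
assert truncate_at_stop("abcdef", ["cdef"]) == "ab"
```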
    fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response(
        status=TaskStatus.SUCCESS,
        result={
            "result": '{"context_logits":0.0,"cum_log_probs":0.0,"generation_logits":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":[0.0,0.0,0.0,0.0,0.0],"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":" Machine learning is a branch"}'
may need to figure out why the log probs are not returned properly
    for beam_idx, tokens in enumerate(beam_tokens):
        seq_len = sequence_lengths[batch_idx][beam_idx]
        output = self.tokenizer.decode(
            tokens[:seq_len], skip_special_tokens=self.skip_special_tokens
Why do we restrict to [:seq_len]? What's in tokens outside of seq_len?
    # Adapted from https://github.com/triton-inference-server/tensorrtllm_backend/pull/423
    # This is somewhat of a hack: add a space before the output if the first token starts with a space
    # This may add a space in front of the first token though when we don't want it.
    token_id_string = self.tokenizer.convert_ids_to_tokens(
This is expensive, though? We're effectively decoding twice. If we're only checking the string for the first token, can we just do self.tokenizer.convert_ids_to_tokens(tokens[0], skip_special_tokens=self.skip_special_tokens)?
changed this to just look at the first token
    for batch_idx, beam_tokens in enumerate(tokens_batch):
        for beam_idx, tokens in enumerate(beam_tokens):
            seq_len = sequence_lengths[batch_idx][beam_idx]
            output = self.tokenizer.decode(
should we check for stop token here?
I don't remember seeing stop tokens in the output when testing at least.
Pull Request Summary
Changes to get tensorrtllm to work with Mixtral
Note: the logprobs returned still aren't correct; haven't investigated.
Test Plan and Usage Guide
Deployed a weights-only quantized Mixtral model, which works