Improve TensorRT-LLM Functionality#487

Merged
seanshi-scale merged 21 commits into main from seanshi/20240409-tensorrtllm-improvements
May 15, 2024
Conversation

Contributor

@seanshi-scale seanshi-scale commented Apr 10, 2024

Pull Request Summary

Changes to get TensorRT-LLM to work with Mixtral:

  1. Update the bundled TensorRT-LLM code and build process to a newer version
  2. Add workarounds to mitigate some tokenization issues

Note: the returned logprobs are still incorrect; this hasn't been investigated yet.

Test Plan and Usage Guide

Deployed a weights-only quantized Mixtral model, which works.

@seanshi-scale seanshi-scale self-assigned this Apr 30, 2024
output = self.tokenizer.decode(
    tokens[:seq_len],
    skip_special_tokens=self.skip_special_tokens)
# Adapted from https://github.com/triton-inference-server/tensorrtllm_backend/pull/423

Contributor Author

Differs from NVIDIA here

item_flat_ids += ids
item_offsets.append(len(ids))

# Add a case where ids[0] decodes to empty string, then add another set of ids here

Contributor Author

Differs from NVIDIA here

Contributor

why do we do this?

Contributor Author

This is a partial patch to make stop sequence behavior more functional. When I was trying stop sequences out, I noticed cases where the stop sequence was being ignored. The root cause is that TRT tokenizes each stop sequence and then looks for that exact sequence of tokens, but what the model returns isn't always quite that sequence: TRT's tokenization of the stop sequence can include an extra empty token. So I patched it so that we also look for the original sequence minus the empty token.

E.g. if you pass in stop text that tokenizes to [1, 2, 3], the model can output the sequence [..., 2, 3], where [2, 3] also decodes to that text.
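
The workaround described in that comment can be sketched with made-up token ids (the empty-token id and all sequences below are hypothetical; in the real backend they come from the model's tokenizer):

```python
# Toy illustration of the stop-sequence mismatch: the tokenizer turns the
# stop text into [1, 2, 3] with a leading token that decodes to "", but
# the model emits [..., 2, 3] without it.
EMPTY_TOKEN = 1  # hypothetical id that decodes to the empty string

def stop_word_variants(stop_ids):
    """Return the tokenized stop sequence, plus the same sequence with a
    leading empty-string token stripped, so either form is matched."""
    variants = [stop_ids]
    if stop_ids and stop_ids[0] == EMPTY_TOKEN:
        variants.append(stop_ids[1:])
    return variants

def ends_with_stop(output_ids, stop_ids):
    # Check whether the output ends with any variant of the stop sequence.
    return any(
        len(v) > 0 and output_ids[-len(v):] == v
        for v in stop_word_variants(stop_ids)
    )

# Matched even though the leading empty token never appears in the output.
assert ends_with_stop([7, 8, 2, 3], [1, 2, 3])
```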

Contributor

even this doesn't seem right? stop sequence could be a part of [2,3], e.g. tok2 = abc, tok3 = def, stop sequence = cdef. i think best is to compare in postprocessing with string?

Contributor Author

yup, this isn't a complete fix for stop sequences for sure unfortunately; think we'd need to spend more time to see if it's possible with the current framework and how to do it
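
For reference, the string-level postprocessing comparison the reviewer suggests could look something like this sketch (the function name and shape are made up, not code from this PR):

```python
def apply_stop_sequences(text, stop_sequences):
    """Truncate decoded text at the earliest stop sequence, if any.
    Comparing strings sidesteps token-boundary issues, e.g. a stop
    sequence that spans parts of two tokens."""
    cut = len(text)
    for stop in stop_sequences:
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]

# "cdef" spans the tokens "abc" and "def", so a token-id comparison
# would miss it; the string comparison does not.
assert apply_stop_sequences("xyabcdefgh", ["cdef"]) == "xyab"
```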

fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response(
    status=TaskStatus.SUCCESS,
    result={
        "result": '{"context_logits":0.0,"cum_log_probs":0.0,"generation_logits":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":[0.0,0.0,0.0,0.0,0.0],"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":" Machine learning is a branch"}'

Contributor Author

may need to figure out why the log probs are not returned properly

@seanshi-scale seanshi-scale marked this pull request as ready for review May 9, 2024 23:00
@seanshi-scale seanshi-scale changed the title Seanshi/20240409 tensorrtllm improvements Improve TensorRT-LLM Functionality May 9, 2024
@seanshi-scale seanshi-scale requested a review from squeakymouse May 9, 2024 23:02
for beam_idx, tokens in enumerate(beam_tokens):
    seq_len = sequence_lengths[batch_idx][beam_idx]
    output = self.tokenizer.decode(
        tokens[:seq_len], skip_special_tokens=self.skip_special_tokens

Contributor

why do we restrict to [:seq_len]? what is in tokens outside of seq_len?
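
One plausible answer, not confirmed in the thread, is that the backend returns fixed-size output buffers padded out to the maximum output length, with the valid length per beam reported separately in sequence_lengths. A sketch (pad id and values below are made up):

```python
# Assumption: tokens_batch has shape [batch][beam][max_output_len] and is
# padded past the real sequence end; sequence_lengths gives valid counts.
PAD_ID = 0  # hypothetical pad token id

tokens_batch = [[[5, 6, 7, PAD_ID, PAD_ID]]]  # one batch entry, one beam
sequence_lengths = [[3]]                      # valid tokens per beam

for batch_idx, beam_tokens in enumerate(tokens_batch):
    for beam_idx, tokens in enumerate(beam_tokens):
        seq_len = sequence_lengths[batch_idx][beam_idx]
        valid = tokens[:seq_len]  # drop the padding before decoding
        assert valid == [5, 6, 7]
```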


# Adapted from https://github.com/triton-inference-server/tensorrtllm_backend/pull/423
# This is somewhat of a hack: add a space before the output if the first token starts with a space
# This may add a space in front of the first token though when we don't want it.
token_id_string = self.tokenizer.convert_ids_to_tokens(

Contributor

this is expensive though? we're effectively decoding twice. if we're only checking string for first token, can we just do self.tokenizer.convert_ids_to_tokens( tokens[0], skip_special_tokens=self.skip_special_tokens )?

Contributor Author

changed this to just look at the first token
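
A toy version of that first-token check (assumption: a SentencePiece-style tokenizer marks a word-initial space with the "\u2581" metasymbol, which decode() strips from the first token; the vocab, ids, and helper below are made up stand-ins for the tokenizer API):

```python
SP_SPACE = "\u2581"  # SentencePiece space metasymbol

def convert_id_to_token(token_id):
    # Stand-in for tokenizer.convert_ids_to_tokens; the vocab is made up.
    vocab = {10: SP_SPACE + "Machine", 12: "ing"}
    return vocab[token_id]

def fix_leading_space(decoded, first_token_id):
    # Only inspect the first token, rather than re-decoding everything.
    if convert_id_to_token(first_token_id).startswith(SP_SPACE):
        return " " + decoded
    return decoded

# The decoded text lost its leading space; the token string reveals it.
assert fix_leading_space("Machine learning", 10) == " Machine learning"
```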

for batch_idx, beam_tokens in enumerate(tokens_batch):
    for beam_idx, tokens in enumerate(beam_tokens):
        seq_len = sequence_lengths[batch_idx][beam_idx]
        output = self.tokenizer.decode(

Contributor

should we check for stop token here?

Contributor Author

I don't remember seeing stop tokens in the output when testing at least.

@seanshi-scale seanshi-scale merged commit 1470aac into main May 15, 2024
@seanshi-scale seanshi-scale deleted the seanshi/20240409-tensorrtllm-improvements branch May 15, 2024 21:03