Preserve integer/bool tensor dtype in to_device() (fixes #36) #80
Open
lonexreb wants to merge 2 commits into NVlabs:main from
Conversation
When `dtype` is provided, `to_device` previously cast every tensor in the payload, including `input_ids` and `attention_mask`, to that dtype. With mixed-precision inference (e.g. `dtype=torch.bfloat16`) this turns token IDs into floats and breaks subsequent embedding lookups.

Only apply `dtype` to floating-point tensors; integer and boolean tensors keep their original dtype and are still moved to `device`. Behavior is unchanged when `dtype` is `None` (the path used by `test_inference.py`).

Fixes NVlabs#36

Signed-off-by: lonexreb <[email protected]>
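A minimal sketch of the guarded cast described above (the recursion over mappings and sequences follows the contract stated in this PR, but the exact structure of the real helper may differ):

```python
import torch

def to_device(data, device, dtype=None):
    """Recursively move tensors in a payload to `device`; when `dtype`
    is given, cast only floating-point tensors (sketch, not the exact diff)."""
    if isinstance(data, torch.Tensor):
        if dtype is None or not torch.is_floating_point(data):
            # int64 input_ids / attention_mask and bool masks: move only
            return data.to(device=device)
        return data.to(device=device, dtype=dtype)
    if isinstance(data, dict):
        return {k: to_device(v, device, dtype) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype) for v in data)
    return data  # non-tensor leaves (str, int) pass through unchanged
```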
Five scenarios covering the issue NVlabs#36 contract:

- `dtype=None` preserves all original dtypes
- `dtype=<float>` keeps integer tensors as-is and casts floats
- bool tensors are preserved when a float dtype is requested
- recursion descends into nested mappings and sequences
- non-tensor leaves (str, int) pass through unchanged

Signed-off-by: lonexreb <[email protected]>
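As a sketch, those five scenarios could look like the following (the `helper` import path and test names are illustrative; `"cpu"` stands in for `"cuda"` so the snippet runs anywhere):

```python
import torch
from helper import to_device  # hypothetical import path

def test_dtype_none_preserves_dtypes():
    out = to_device({"ids": torch.ones(2, dtype=torch.int64)}, "cpu")
    assert out["ids"].dtype == torch.int64

def test_float_dtype_keeps_ints_and_casts_floats():
    batch = {
        "input_ids": torch.ones(2, dtype=torch.int64),
        "pixel_values": torch.ones(2, dtype=torch.float32),
    }
    out = to_device(batch, "cpu", torch.bfloat16)
    assert out["input_ids"].dtype == torch.int64
    assert out["pixel_values"].dtype == torch.bfloat16

def test_bool_preserved_under_float_dtype():
    out = to_device({"mask": torch.ones(2, dtype=torch.bool)}, "cpu", torch.bfloat16)
    assert out["mask"].dtype == torch.bool

def test_recursion_into_nested_containers():
    out = to_device({"a": [torch.ones(1, dtype=torch.float32)]}, "cpu", torch.bfloat16)
    assert out["a"][0].dtype == torch.bfloat16

def test_non_tensor_leaves_pass_through():
    out = to_device({"name": "vlm", "n": 3}, "cpu", torch.bfloat16)
    assert out == {"name": "vlm", "n": 3}
```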
Summary
Fixes #36.
`helper.to_device()` currently casts every tensor it encounters to the requested `dtype`. Under mixed-precision inference (e.g. `dtype=torch.bfloat16`), this also rewrites `input_ids` and `attention_mask` from `int64` to `bfloat16`, which then crashes the embedding lookup in the VLM. This PR restricts the dtype cast to floating-point tensors. Integer and boolean tensors keep their original dtype and are still moved to `device`.

Behavior

| Call | `dtype` arg | Result |
| --- | --- | --- |
| `test_inference.py:52` (`to_device(model_inputs, "cuda")`) | `None` | unchanged: all tensors keep their original dtype |
| `to_device(batch, "cuda", torch.bfloat16)` | `torch.bfloat16` | floats cast to `bfloat16`; ints/bools preserved |

The reporter (@aniekannn) suggested the same shape of fix in the issue.
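For illustration, the failure mode reproduces in a few standalone lines (this is not the repo's code, just the underlying PyTorch behavior):

```python
import torch

# Embedding lookup requires integer indices; a blanket bfloat16 cast breaks it.
emb = torch.nn.Embedding(10, 4)
ids = torch.tensor([1, 2, 3])           # int64 token IDs
print(emb(ids).shape)                   # torch.Size([3, 4])
try:
    emb(ids.to(torch.bfloat16))         # what the old unconditional cast produced
except RuntimeError as e:
    print("embedding lookup fails:", e)
```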
Test plan
- The device-only call path in `test_inference.py` is unaffected (dtype is `None`, so the original `data.to(device=device)` path is taken).
- Ran a payload with a float tensor and an `int64` tensor through `to_device(..., dtype=torch.bfloat16)` and confirmed the float is cast and the int stays `int64`.
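A quick manual check along those lines (again with the hypothetical `helper` import path; swap `"cpu"` for `"cuda"` on a GPU machine):

```python
import torch
from helper import to_device  # hypothetical import path

batch = {
    "input_ids": torch.tensor([1, 2, 3], dtype=torch.int64),
    "pixel_values": torch.randn(1, 3),
}
out = to_device(batch, "cpu", torch.bfloat16)
assert out["input_ids"].dtype == torch.int64        # preserved
assert out["pixel_values"].dtype == torch.bfloat16  # cast
```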