Preserve integer/bool tensor dtype in to_device() (fixes #36)#80

Open
lonexreb wants to merge 2 commits into NVlabs:main from lonexreb:fix/to-device-preserve-integer-tensors

Conversation


@lonexreb lonexreb commented May 4, 2026

Summary

Fixes #36.

helper.to_device() currently casts every tensor it encounters to the requested dtype. Under mixed-precision inference (e.g. dtype=torch.bfloat16), this also rewrites input_ids and attention_mask from int64 to bfloat16, which then crashes the embedding lookup in the VLM.
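The failure mode is easy to reproduce in isolation: `nn.Embedding` requires integer indices, so token IDs that have been cast to a float dtype raise at lookup time. A minimal repro, independent of the helper itself:

```python
import torch

emb = torch.nn.Embedding(10, 4)
ids = torch.tensor([1, 2, 3])   # int64 token IDs, as produced by a tokenizer

emb(ids)  # fine: integer indices

try:
    # what the old to_device effectively did under dtype=torch.bfloat16
    emb(ids.to(torch.bfloat16))
except RuntimeError as exc:
    print("embedding lookup rejects float indices:", type(exc).__name__)
```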

This PR restricts the dtype cast to floating-point tensors. Integer and boolean tensors keep their original dtype and are still moved to device.
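A sketch of the restricted cast (the function name and recursion shape are assumptions about the helper, not a copy of the diff):

```python
import torch

def to_device(data, device, dtype=None):
    """Move tensors to `device`; apply `dtype` only to floating-point tensors."""
    if isinstance(data, torch.Tensor):
        if dtype is not None and torch.is_floating_point(data):
            return data.to(device=device, dtype=dtype)
        # integer/bool tensors (and dtype=None) keep their original dtype
        return data.to(device=device)
    if isinstance(data, dict):
        return {k: to_device(v, device, dtype) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype) for v in data)
    return data  # non-tensor leaves pass through unchanged
```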

Behavior

| Caller | `dtype` arg | Before | After |
| --- | --- | --- | --- |
| `test_inference.py:52` (`to_device(model_inputs, "cuda")`) | `None` | device-only move | unchanged |
| Mixed-precision call (`to_device(batch, "cuda", torch.bfloat16)`) | non-`None` | every tensor cast, which breaks embedding lookups | floats cast to `bfloat16`; ints/bools preserved |

The reporter (@aniekannn) suggested the same shape of fix in the issue.

Test plan

  • Existing device-only call path in test_inference.py is unaffected (dtype is None, so the original data.to(device=device) path is taken).
  • Manual smoke: pass a dict containing a float tensor and an int64 tensor through to_device(..., dtype=torch.bfloat16) and confirm the float is cast and the int stays int64.
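The manual smoke check above can be scripted. This uses a minimal stand-in for the patched helper (flat-dict case only, names are illustrative) rather than importing it from the repo:

```python
import torch

def to_device(data, device, dtype=None):
    # minimal stand-in for the patched helper, flat dicts only
    out = {}
    for name, t in data.items():
        if dtype is not None and torch.is_floating_point(t):
            out[name] = t.to(device=device, dtype=dtype)
        else:
            out[name] = t.to(device=device)
    return out

batch = {
    "pixel_values": torch.zeros(2, 3),  # float32
    "input_ids": torch.tensor([101, 102], dtype=torch.int64),
    "attention_mask": torch.ones(2, dtype=torch.bool),
}

out = to_device(batch, "cpu", dtype=torch.bfloat16)
assert out["pixel_values"].dtype == torch.bfloat16  # float is cast
assert out["input_ids"].dtype == torch.int64        # int stays int64
assert out["attention_mask"].dtype == torch.bool    # bool preserved

# dtype=None path: all dtypes unchanged, device-only move
out2 = to_device(batch, "cpu")
assert out2["pixel_values"].dtype == torch.float32
```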

lonexreb added 2 commits May 4, 2026 03:48
When `dtype` is provided, `to_device` previously cast every tensor in the
payload — including `input_ids` and `attention_mask` — to that dtype. With
mixed-precision inference (e.g. `dtype=torch.bfloat16`) this turns token IDs
into floats and breaks subsequent embedding lookups.

Only apply `dtype` to floating-point tensors; integer and boolean tensors
keep their original dtype and are still moved to `device`. Behavior is
unchanged when `dtype` is `None` (the path used by `test_inference.py`).

Fixes NVlabs#36

Signed-off-by: lonexreb <[email protected]>
Five scenarios covering the issue NVlabs#36 contract:
- dtype=None preserves all original dtypes
- dtype=<float> keeps integer tensors as-is and casts floats
- bool tensors are preserved when a float dtype is requested
- recursion descends into nested mappings and sequences
- non-tensor leaves (str, int) pass through unchanged

Signed-off-by: lonexreb <[email protected]>

Development

Successfully merging this pull request may close these issues.

to_device() incorrectly casts integer tensors when dtype is already provided
