to_device() incorrectly casts integer tensors when dtype is already provided #36

@aniekannn

Description

Hi NV team! Thanks for sharing this work!

I noticed that the to_device() function in helper.py currently casts every tensor to whatever dtype is given. That includes token tensors like input_ids and attention_mask, which should stay integer-typed for Hugging Face models. Under mixed precision these fields get converted to floats, which breaks embedding lookups (nn.Embedding requires integer indices). Would it be possible to cast a tensor to dtype only if it is already floating-point, by checking is_floating_point() first? This preserves critical integer/bool fields while still allowing correct mixed-precision inference for the other tensors. Something like:

import collections.abc

import torch


def to_device(data, device=None, dtype=None):
    if isinstance(data, torch.Tensor):
        # Only cast tensors that are already floating-point; integer/bool
        # tensors (input_ids, attention_mask, ...) keep their dtype.
        if dtype is not None and data.is_floating_point():
            return data.to(device=device, dtype=dtype)
        return data.to(device=device)
    # Recurse into dicts (e.g. tokenizer output) and lists/tuples.
    elif isinstance(data, collections.abc.Mapping):
        return {k: to_device(v, device=device, dtype=dtype) for k, v in data.items()}
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
        return [to_device(x, device=device, dtype=dtype) for x in data]
    else:
        return data
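To illustrate the intended behavior without requiring torch, here's a minimal stdlib-only sketch using a hypothetical MockTensor stand-in for torch.Tensor (tracking only dtype and device): integer fields keep their dtype, while floating-point fields are cast.

```python
import collections.abc


class MockTensor:
    """Hypothetical stand-in for torch.Tensor; tracks dtype and device only."""

    def __init__(self, dtype, device="cpu"):
        self.dtype = dtype
        self.device = device

    def is_floating_point(self):
        return self.dtype in ("float16", "bfloat16", "float32", "float64")

    def to(self, device=None, dtype=None):
        return MockTensor(dtype or self.dtype, device or self.device)


def to_device(data, device=None, dtype=None):
    if isinstance(data, MockTensor):
        # Cast only floating-point tensors; move everything else unchanged.
        if dtype is not None and data.is_floating_point():
            return data.to(device=device, dtype=dtype)
        return data.to(device=device)
    elif isinstance(data, collections.abc.Mapping):
        return {k: to_device(v, device=device, dtype=dtype) for k, v in data.items()}
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
        return [to_device(x, device=device, dtype=dtype) for x in data]
    else:
        return data


# A batch shaped like typical Hugging Face tokenizer output.
batch = {
    "input_ids": MockTensor("int64"),
    "attention_mask": MockTensor("int64"),
    "pixel_values": MockTensor("float32"),
}
out = to_device(batch, device="cuda", dtype="float16")
# input_ids stays int64 (on cuda); pixel_values becomes float16.
```

With the original unconditional cast, input_ids would come back as float16 and the subsequent embedding lookup would fail.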
